Merge branch 'dnarayanan/dist_optimizer_refactor' into 'main' · NVIDIA/Megatron-LM@6ac4db0

```diff
@@ -2,7 +2,6 @@
 
 import logging
 from contextlib import contextmanager
-from typing import Dict
 
 import torch
 
@@ -12,7 +11,7 @@
 from ..transformer.transformer_config import TransformerConfig
 from ..utils import is_float8tensor, log_single_rank
 from .distributed_data_parallel_config import DistributedDataParallelConfig
-from .param_and_grad_buffer import BucketGroup, ParamAndGradBuffer, partition_buckets
+from .param_and_grad_buffer import _ParamAndGradBuffer, partition_buckets
 
 logger = logging.getLogger(__name__)
 
@@ -77,7 +76,6 @@ def __init__(
         if disable_bucketing:
             self.bucket_size = None
 
-        self.module = module
         self.param_to_bucket_group = {}
 
         # Group parameters by their gradient type.
@@ -101,7 +99,7 @@ def __init__(
             else:
                 expert_parallel_params.append(param)
 
-        def allocate_buffers_for_parameters(
+        def _allocate_buffers_for_parameters(
             input_params, data_parallel_group, gradient_scaling_factor
         ):
             param_and_grad_dtype_to_params = {}
@@ -110,8 +108,7 @@ def allocate_buffers_for_parameters(
 
             # Group parameters by their gradient type.
             for param in input_params:
-                if not param.requires_grad:
-                    continue
+                assert param.requires_grad
 
                 param_dtype = param.dtype
                 if is_float8tensor(param):
@@ -167,7 +164,7 @@ def allocate_buffers_for_parameters(
             buffers = []
             for (param_dtype, grad_dtype), params in param_and_grad_dtype_to_params.items():
                 buffers.append(
-                    ParamAndGradBuffer(
+                    _ParamAndGradBuffer(
                         self.ddp_config,
                         param_dtype,
                         grad_dtype,
@@ -187,9 +184,20 @@ def allocate_buffers_for_parameters(
             # because of the use of CUDA_DEVICE_MAX_CONNECTIONS=1, having multiple back-to-back
             # communications will prevent the overlap of the communication kernels with computation
             # kernels.
-            bucket_groups = partition_buckets(buffers)
+            # If bucketing is explicitly disabled, then put all buckets in a buffer into a single
+            # bucket group.
+            bucket_groups = partition_buckets(buffers, force_single_bucket_group=disable_bucketing)
+
+            # Set `next_param_gather_bucket_group` for different bucket groups by iterating through
+            # buckets in reverse order (since all-gathers happen in reverse order of buckets).
+            if self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather:
+                num_bucket_groups = len(bucket_groups)
+                for i in range(1, num_bucket_groups):
+                    bucket_groups[num_bucket_groups - i].next_param_gather_bucket_group = (
+                        bucket_groups[num_bucket_groups - i - 1]
+                    )
 
-            # Create map from param to BucketGroup, used in pre_hook.
+            # Create map from param to bucket group, used in pre_hook.
             for bucket_group in bucket_groups:
                 for bucket in bucket_group.buckets:
                     for param in bucket.params_list:
```
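The added block above chains bucket groups through `next_param_gather_bucket_group` so that, with the distributed optimizer and `overlap_param_gather` enabled, completing one group's parameter all-gather can trigger the dispatch of the next one, in reverse bucket order as the diff comment notes. The toy sketch below is illustrative only; `ToyBucketGroup` and its methods are invented stand-ins, not the real bucket-group class from `param_and_grad_buffer`, but they show the chaining and dispatch pattern this hunk sets up.

```python
# Illustrative sketch only: a fake bucket group that mimics the chaining above, where
# finishing one group's all-gather dispatches the next group's, overlapping communication
# for later-used parameters with compute on earlier ones.

class ToyBucketGroup:
    def __init__(self, name: str):
        self.name = name
        self.next_param_gather_bucket_group = None  # filled in externally, as in the diff
        self.dispatched = False

    def start_param_sync(self):
        # Stand-in for dispatching an asynchronous all-gather.
        self.dispatched = True
        print(f"dispatch all-gather for {self.name}")

    def finish_param_sync(self):
        # Wait on this group's all-gather, then kick off the next group in the chain.
        assert self.dispatched, f"{self.name} was never dispatched"
        print(f"wait on all-gather for {self.name}")
        nxt = self.next_param_gather_bucket_group
        if nxt is not None and not nxt.dispatched:
            nxt.start_param_sync()


bucket_groups = [ToyBucketGroup(f"group{i}") for i in range(3)]

# Same reverse-order chaining as in the diff above.
num_bucket_groups = len(bucket_groups)
for i in range(1, num_bucket_groups):
    bucket_groups[num_bucket_groups - i].next_param_gather_bucket_group = (
        bucket_groups[num_bucket_groups - i - 1]
    )

# All-gathers start from the last bucket group (all-gathers happen in reverse order of
# buckets, per the diff comment), and each wait triggers the dispatch of the next group.
bucket_groups[-1].start_param_sync()
for group in reversed(bucket_groups):
    group.finish_param_sync()
```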

```diff
@@ -214,15 +222,15 @@ def allocate_buffers_for_parameters(
             expert_gradient_scaling_factor = 1.0 / data_parallel_world_size
 
         # Allocate the param+grad buffers for dense params' grads.
-        self.buffers, self.bucket_groups = allocate_buffers_for_parameters(
+        self.buffers, self.bucket_groups = _allocate_buffers_for_parameters(
             dense_params,
             parallel_state.get_data_parallel_group(with_context_parallel=True),
             gradient_scaling_factor=gradient_scaling_factor,
         )
 
         # Allocate separate param+grad buffers for expert parallel params' grads.
         self.expert_parallel_buffers, self.expert_parallel_bucket_groups = (
-            allocate_buffers_for_parameters(
+            _allocate_buffers_for_parameters(
                 expert_parallel_params,
                 parallel_state.get_data_modulo_expert_parallel_group(with_context_parallel=True),
                 gradient_scaling_factor=expert_gradient_scaling_factor,
@@ -252,26 +260,93 @@ def unmap_weight_tensor(m):
                 param_tmp = param.expand_as(param)
                 # Get the gradient accumulator function.
                 grad_acc = param_tmp.grad_fn.next_functions[0][0]
-                grad_acc.register_hook(self._make_param_hook(param, self.param_to_bucket_group))
+                grad_acc.register_hook(self._make_backward_post_hook(param))
                 self.grad_accs.append(grad_acc)
 
+        self.use_forward_hook = (
+            self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather
+        )
+        self.remove_forward_pre_hook_handles = {}
+        if self.use_forward_hook:
+            self.enable_forward_pre_hook()
+        self.overlap_param_gather_with_optimizer_step = False
+
+    def enable_forward_pre_hook(self):
+        """
+        Enable forward pre-hooks needed for param all-gather overlap with forward compute.
+        """
+        assert self.use_forward_hook
+        assert len(self.remove_forward_pre_hook_handles) == 0
+        # Register forward pre-hook for all sub-modules.
+        for module in self.module.modules():
+            self.remove_forward_pre_hook_handles[module] = module.register_forward_pre_hook(
+                self._make_forward_pre_hook()
+            )
+
+    def disable_forward_pre_hook(self):
+        """
+        Disable forward pre-hooks needed for param all-gather overlap with forward compute.
+        """
+        assert self.use_forward_hook
+        # De-register forward pre-hook for all sub-modules.
+        for module in self.module.modules():
+            assert self.remove_forward_pre_hook_handles[module] is not None
+            self.remove_forward_pre_hook_handles[module].remove()
+            del self.remove_forward_pre_hook_handles[module]
+        assert len(self.remove_forward_pre_hook_handles) == 0
+
+        # Force synchronize parameters.
+        self.start_param_sync(force_sync=True)
+
     def forward(self, *inputs, **kwargs):
         """
         Calls the wrapped module's forward() method.
         """
         return self.module(*inputs, **kwargs)
 
-    def _make_param_hook(
-        self,
-        param: torch.nn.Parameter,
-        param_to_bucket_group: Dict[torch.nn.Parameter, BucketGroup],
-    ):
+    def _make_forward_pre_hook(self):
         """
-        Creates the all-reduce / reduce-scatter hook for backprop.
+        Create a forward pre-hook to wait on all-gather handles when necessary (i.e.,
+        when a module uses a parameter in a bucket with a still incomplete all-gather).
         """
 
-        def param_hook(*unused):
-            if param.requires_grad:
+        def hook(module, *unused):
+            assert (
+                self.use_forward_hook
+            ), "Should use pre-hook only when overlap_param_gather is True"
+
+            # Make sure all parameters in this module have been all-gathered as necessary.
+            for param in module.parameters(recurse=False):
+                # Skip parameters without an associated buffer (such parameters have a
+                # .requires_grad field equal to False).
+                if param not in self.param_to_bucket_group:
+                    continue
+                assert param.requires_grad
+
+                # If aligning param all-gather across pipeline stages, all-gather is dispatched
+                # by start_param_sync calls in core/pipeline_parallelism/schedules.py.
+                # If overlapping param all-gather with optimizer step, then all-gather has
+                # already been dispatched in optimizer step.
+                skip_next_bucket_dispatch = (
+                    self.ddp_config.align_param_gather
+                    or self.overlap_param_gather_with_optimizer_step
+                )
+                self.param_to_bucket_group[param].finish_param_sync(
+                    skip_next_bucket_dispatch=skip_next_bucket_dispatch
+                )
+
+        return hook
+
+    def _make_backward_post_hook(self, param: torch.nn.Parameter):
+        """
+        Creates a backward post-hook to dispatch an all-reduce / reduce-scatter when
+        ready (i.e., when all grads in a bucket have been computed in all microbatches
+        in a batch).
+        """
+
+        def hook(*unused):
+            if param in self.param_to_bucket_group:
+                assert param.requires_grad
                 if self.ddp_config.overlap_grad_reduce:
                     assert (
                         param.grad is not None
```
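The hook refactor above (replacing `_make_param_hook` with `_make_forward_pre_hook` and `_make_backward_post_hook`) builds on two standard PyTorch mechanisms: `nn.Module.register_forward_pre_hook`, which runs just before a sub-module's forward, and a hook on each parameter's gradient-accumulation node obtained with the `param.expand_as(param).grad_fn.next_functions[0][0]` trick visible in the diff, which fires once that parameter's gradient has been accumulated. The standalone sketch below demonstrates only those two mechanisms on a toy model; the print statements stand in for the `finish_param_sync` and `register_grad_ready` calls in the real hooks.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))

# 1) Forward pre-hook: runs right before each sub-module's forward. Megatron's version
#    waits on (and chains) the all-gather for that sub-module's params; here we just log.
def make_forward_pre_hook():
    def hook(module, args):
        print(f"pre-forward: {module.__class__.__name__}")
    return hook

handles = {m: m.register_forward_pre_hook(make_forward_pre_hook()) for m in model.modules()}

# 2) Backward post-hook on the gradient-accumulation node: fires once the parameter's
#    grad has been accumulated (same grad_fn trick as in the diff).
def make_backward_post_hook(param):
    def hook(*unused):
        print(f"grad ready for param of shape {tuple(param.shape)}")
    return hook

grad_accs = []
for param in model.parameters():
    if param.requires_grad:
        param_tmp = param.expand_as(param)                  # expose grad_fn
        grad_acc = param_tmp.grad_fn.next_functions[0][0]   # AccumulateGrad node
        grad_acc.register_hook(make_backward_post_hook(param))
        grad_accs.append(grad_acc)                          # keep nodes alive

model(torch.randn(3, 4)).sum().backward()

# The stored handles allow the pre-hooks to be removed again, which is what
# disable_forward_pre_hook() does with remove_forward_pre_hook_handles.
for handle in handles.values():
    handle.remove()
```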

```diff
@@ -283,9 +358,9 @@ def param_hook(*unused):
                 param.grad = None
 
                 if self.ddp_config.overlap_grad_reduce:
-                    param_to_bucket_group[param].register_grad_ready(param)
+                    self.param_to_bucket_group[param].register_grad_ready(param)
 
-        return param_hook
+        return hook
 
     @contextmanager
     def no_sync(self):
@@ -300,6 +375,28 @@ def no_sync(self):
             for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
                 bucket_group.is_last_microbatch = True
 
+    def start_param_sync(self, *unused, force_sync: bool = False, force_dispatch: bool = False):
+        """
+        Initiates param sync (all-gather) communication operations for all model parameters.
+
+        By default, when overlap_param_gather is set to True, dispatches asynchronous communication
+        calls; when overlap_param_gather is set to False, calls synchronous communication
+        ops. Can override this default behavior using flags below.
+
+        Args:
+            force_sync (bool, optional): force synchronous collective regardless of
+                other settings.
+            force_dispatch (bool, optional): force dispatch regardless of other settings.
+        """
+        if not force_sync:
+            # If overlapping param AG with optimizer step, AG should not be dispatched again
+            # in forward_backward_step.
+            if self.overlap_param_gather_with_optimizer_step and not force_dispatch:
+                return
+
+        for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
+            bucket_group.start_param_sync(force_sync=force_sync)
+
     def start_grad_sync(self, *unused):
         """
         Initiates grad sync (all-reduce or reduce-scatter) communication operations
```
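The new `start_param_sync` above exposes two override flags. Below is a hypothetical helper (not part of Megatron) illustrating how a caller might choose between them, assuming `ddp_model` is an instance of the `DistributedDataParallel` class modified in this diff.

```python
def refresh_params(ddp_model, blocking: bool = False, force_dispatch: bool = False):
    """Hypothetical helper, not Megatron code: drive the start_param_sync flags added above."""
    if blocking:
        # Synchronous all-gather regardless of other settings; this is what
        # disable_forward_pre_hook() does internally.
        ddp_model.start_param_sync(force_sync=True)
    elif force_dispatch:
        # Dispatch even if the all-gather is normally handled by the optimizer step
        # (overlap_param_gather_with_optimizer_step is True).
        ddp_model.start_param_sync(force_dispatch=True)
    else:
        # Default: async dispatch when overlap_param_gather is on, synchronous ops
        # otherwise; returns early if the optimizer step already dispatched the gathers.
        ddp_model.start_param_sync()
```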

```diff
@@ -312,11 +409,6 @@ def start_grad_sync(self, *unused):
         for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
             bucket_group.start_grad_sync()
 
-    def scale_gradients(self, scaling_factor: float) -> None:
-        """Scale all gradients inside the buffers by `scaling_factor`."""
-        for buffer in self.buffers + self.expert_parallel_buffers:
-            buffer.scale_gradients(scaling_factor)
-
     def finish_grad_sync(self):
         """
         Finishes grad sync (all-reduce or reduce-scatter) communication operations
@@ -329,6 +421,11 @@ def finish_grad_sync(self):
         for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
             bucket_group.finish_grad_sync()
 
+    def scale_gradients(self, scaling_factor: float):
+        """Scale all gradients inside the buffers by `scaling_factor`."""
+        for buffer in self.buffers + self.expert_parallel_buffers:
+            buffer.scale_gradients(scaling_factor)
+
     def zero_grad_buffer(self):
         """
         Zeros out all grad buffers. Needs to be called at the beginning of each
```