Merge branch 'dnarayanan/dist_optimizer_refactor' into 'main' · NVIDIA/Megatron-LM@6ac4db0
@@ -2,7 +2,6 @@

 import logging
 from contextlib import contextmanager
-from typing import Dict

 import torch

@@ -12,7 +11,7 @@
 from ..transformer.transformer_config import TransformerConfig
 from ..utils import is_float8tensor, log_single_rank
 from .distributed_data_parallel_config import DistributedDataParallelConfig
-from .param_and_grad_buffer import BucketGroup, ParamAndGradBuffer, partition_buckets
+from .param_and_grad_buffer import _ParamAndGradBuffer, partition_buckets

 logger = logging.getLogger(__name__)

@@ -77,7 +76,6 @@ def __init__(
         if disable_bucketing:
             self.bucket_size = None

-        self.module = module
         self.param_to_bucket_group = {}

         # Group parameters by their gradient type.
@@ -101,7 +99,7 @@ def __init__(
             else:
                 expert_parallel_params.append(param)

-        def allocate_buffers_for_parameters(
+        def _allocate_buffers_for_parameters(
            input_params, data_parallel_group, gradient_scaling_factor
        ):
            param_and_grad_dtype_to_params = {}
@@ -110,8 +108,7 @@ def allocate_buffers_for_parameters(

             # Group parameters by their gradient type.
             for param in input_params:
-                if not param.requires_grad:
-                    continue
+                assert param.requires_grad

                 param_dtype = param.dtype
                 if is_float8tensor(param):
@@ -167,7 +164,7 @@ def allocate_buffers_for_parameters(
             buffers = []
             for (param_dtype, grad_dtype), params in param_and_grad_dtype_to_params.items():
                 buffers.append(
-                    ParamAndGradBuffer(
+                    _ParamAndGradBuffer(
                         self.ddp_config,
                         param_dtype,
                         grad_dtype,
@@ -187,9 +184,20 @@ def allocate_buffers_for_parameters(
             # because of the use of CUDA_DEVICE_MAX_CONNECTIONS=1, having multiple back-to-back
             # communications will prevent the overlap of the communication kernels with computation
             # kernels.
-            bucket_groups = partition_buckets(buffers)
+            # If bucketing is explicitly disabled, then put all buckets in a buffer into a single
+            # bucket group.
+            bucket_groups = partition_buckets(buffers, force_single_bucket_group=disable_bucketing)
+
+            # Set `next_param_gather_bucket_group` for different bucket groups by iterating through
+            # buckets in reverse order (since all-gathers happen in reverse order of buckets).
+            if self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather:
+                num_bucket_groups = len(bucket_groups)
+                for i in range(1, num_bucket_groups):
+                    bucket_groups[num_bucket_groups - i].next_param_gather_bucket_group = (
+                        bucket_groups[num_bucket_groups - i - 1]
+                    )

-            # Create map from param to BucketGroup, used in pre_hook.
+            # Create map from param to bucket group, used in pre_hook.
             for bucket_group in bucket_groups:
                 for bucket in bucket_group.buckets:
                     for param in bucket.params_list:
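Editorial note: a minimal, self-contained sketch of the prefetch chain wired up in the hunk above. ToyBucketGroup, its print statements, and the driver loop are hypothetical stand-ins (not Megatron-LM classes); the reverse-order wiring itself mirrors the added lines, and, per the comment in the hunk, all-gathers run in reverse bucket order, so each completed all-gather can dispatch the group the forward pass will need next.

class ToyBucketGroup:
    def __init__(self, name):
        self.name = name
        self.next_param_gather_bucket_group = None
        self.dispatched = False

    def start_param_sync(self):
        # Dispatch this group's (pretend) all-gather at most once.
        if not self.dispatched:
            self.dispatched = True
            print(f"dispatch all-gather for {self.name}")

    def finish_param_sync(self):
        print(f"wait on all-gather for {self.name}")
        # Prefetch the following group's all-gather so it overlaps with compute.
        if self.next_param_gather_bucket_group is not None:
            self.next_param_gather_bucket_group.start_param_sync()


bucket_groups = [ToyBucketGroup(f"group{i}") for i in range(3)]

# Same reverse-order wiring as in the hunk above.
num_bucket_groups = len(bucket_groups)
for i in range(1, num_bucket_groups):
    bucket_groups[num_bucket_groups - i].next_param_gather_bucket_group = (
        bucket_groups[num_bucket_groups - i - 1]
    )

# Buckets are consumed in reverse order, so the last group is dispatched first;
# each finish then kicks off the next dispatch in the chain.
bucket_groups[-1].start_param_sync()
for group in reversed(bucket_groups):
    group.finish_param_sync()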
@@ -214,15 +222,15 @@ def allocate_buffers_for_parameters(
                expert_gradient_scaling_factor = 1.0 / data_parallel_world_size

        # Allocate the param+grad buffers for dense params' grads.
-        self.buffers, self.bucket_groups = allocate_buffers_for_parameters(
+        self.buffers, self.bucket_groups = _allocate_buffers_for_parameters(
            dense_params,
            parallel_state.get_data_parallel_group(with_context_parallel=True),
            gradient_scaling_factor=gradient_scaling_factor,
        )

        # Allocate separate param+grad buffers for expert parallel params' grads.
        self.expert_parallel_buffers, self.expert_parallel_bucket_groups = (
-            allocate_buffers_for_parameters(
+            _allocate_buffers_for_parameters(
                expert_parallel_params,
                parallel_state.get_data_modulo_expert_parallel_group(with_context_parallel=True),
                gradient_scaling_factor=expert_gradient_scaling_factor,
@@ -252,26 +260,93 @@ def unmap_weight_tensor(m):
                param_tmp = param.expand_as(param)
                # Get the gradient accumulator function.
                grad_acc = param_tmp.grad_fn.next_functions[0][0]
-                grad_acc.register_hook(self._make_param_hook(param, self.param_to_bucket_group))
+                grad_acc.register_hook(self._make_backward_post_hook(param))
                self.grad_accs.append(grad_acc)

+        self.use_forward_hook = (
+            self.ddp_config.use_distributed_optimizer and self.ddp_config.overlap_param_gather
+        )
+        self.remove_forward_pre_hook_handles = {}
+        if self.use_forward_hook:
+            self.enable_forward_pre_hook()
+        self.overlap_param_gather_with_optimizer_step = False
+
+    def enable_forward_pre_hook(self):
+        """
+        Enable forward pre-hooks needed for param all-gather overlap with forward compute.
+        """
+        assert self.use_forward_hook
+        assert len(self.remove_forward_pre_hook_handles) == 0
+        # Register forward pre-hook for all sub-modules.
+        for module in self.module.modules():
+            self.remove_forward_pre_hook_handles[module] = module.register_forward_pre_hook(
+                self._make_forward_pre_hook()
+            )
+
+    def disable_forward_pre_hook(self):
+        """
+        Disable forward pre-hooks needed for param all-gather overlap with forward compute.
+        """
+        assert self.use_forward_hook
+        # De-register forward pre-hook for all sub-modules.
+        for module in self.module.modules():
+            assert self.remove_forward_pre_hook_handles[module] is not None
+            self.remove_forward_pre_hook_handles[module].remove()
+            del self.remove_forward_pre_hook_handles[module]
+        assert len(self.remove_forward_pre_hook_handles) == 0
+
+        # Force synchronize parameters.
+        self.start_param_sync(force_sync=True)
+
    def forward(self, *inputs, **kwargs):
        """
        Calls the wrapped module's forward() method.
        """
        return self.module(*inputs, **kwargs)

-    def _make_param_hook(
-        self,
-        param: torch.nn.Parameter,
-        param_to_bucket_group: Dict[torch.nn.Parameter, BucketGroup],
-    ):
+    def _make_forward_pre_hook(self):
        """
-        Creates the all-reduce / reduce-scatter hook for backprop.
+        Create a forward pre-hook to wait on all-gather handles when necessary (i.e.,
+        when a module uses a parameter in a bucket with a still incomplete all-gather).
        """

-        def param_hook(*unused):
-            if param.requires_grad:
+        def hook(module, *unused):
+            assert (
+                self.use_forward_hook
+            ), "Should use pre-hook only when overlap_param_gather is True"
+
+            # Make sure all parameters in this module have been all-gathered as necessary.
+            for param in module.parameters(recurse=False):
+                # Skip parameters without an associated buffer (such parameters have a
+                # .requires_grad field equal to False).
+                if param not in self.param_to_bucket_group:
+                    continue
+                assert param.requires_grad
+
+                # If aligning param all-gather across pipeline stages, all-gather is dispatched
+                # by start_param_sync calls in core/pipeline_parallelism/schedules.py.
+                # If overlapping param all-gather with optimizer step, then all-gather has
+                # already been dispatched in optimizer step.
+                skip_next_bucket_dispatch = (
+                    self.ddp_config.align_param_gather
+                    or self.overlap_param_gather_with_optimizer_step
+                )
+                self.param_to_bucket_group[param].finish_param_sync(
+                    skip_next_bucket_dispatch=skip_next_bucket_dispatch
+                )
+
+        return hook
+
+    def _make_backward_post_hook(self, param: torch.nn.Parameter):
+        """
+        Creates a backward post-hook to dispatch an all-reduce / reduce-scatter when
+        ready (i.e., when all grads in a bucket have been computed in all microbatches
+        in a batch).
+        """
+
+        def hook(*unused):
+            if param in self.param_to_bucket_group:
+                assert param.requires_grad
                if self.ddp_config.overlap_grad_reduce:
                    assert (
                        param.grad is not None
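Editorial note: a simplified sketch of the pattern the new forward pre-hook relies on, assuming torch.distributed has been initialized. ToyBucketGroup, make_forward_pre_hook, and enable_hooks below are hypothetical stand-ins for the real bucket-group machinery; they only show that an async all-gather returns a work handle, and that a per-module pre-hook waits on that handle right before the module's own parameters are used, which is what lets the gather overlap with earlier forward compute.

import torch
import torch.distributed as dist
import torch.nn as nn


class ToyBucketGroup:
    def __init__(self, param_buffer: torch.Tensor, shard: torch.Tensor, group=None):
        self.param_buffer = param_buffer  # full gathered parameters
        self.shard = shard                # this rank's shard
        self.group = group
        self.handle = None

    def start_param_sync(self):
        # async_op=True returns immediately with a work handle.
        self.handle = dist.all_gather_into_tensor(
            self.param_buffer, self.shard, group=self.group, async_op=True
        )

    def finish_param_sync(self):
        # Block only if the all-gather has not completed yet.
        if self.handle is not None:
            self.handle.wait()
            self.handle = None


def make_forward_pre_hook(param_to_bucket_group):
    def hook(module, args):
        # Wait only for the bucket groups owning this module's own parameters.
        for param in module.parameters(recurse=False):
            bucket_group = param_to_bucket_group.get(param)
            if bucket_group is not None:
                bucket_group.finish_param_sync()
    return hook


def enable_hooks(model: nn.Module, param_to_bucket_group):
    # Registration mirrors the shape of enable_forward_pre_hook(): one hook per sub-module,
    # with the returned handles kept so the hooks can be removed later.
    return {
        m: m.register_forward_pre_hook(make_forward_pre_hook(param_to_bucket_group))
        for m in model.modules()
    }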
@@ -283,9 +358,9 @@ def param_hook(*unused):
                    param.grad = None

                if self.ddp_config.overlap_grad_reduce:
-                    param_to_bucket_group[param].register_grad_ready(param)
+                    self.param_to_bucket_group[param].register_grad_ready(param)

-        return param_hook
+        return hook

    @contextmanager
    def no_sync(self):
@@ -300,6 +375,28 @@ def no_sync(self):
            for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
                bucket_group.is_last_microbatch = True

+    def start_param_sync(self, *unused, force_sync: bool = False, force_dispatch: bool = False):
+        """
+        Initiates param sync (all-gather) communication operations for all model parameters.
+
+        By default, when overlap_param_gather is set to True, dispatches asynchronous communication
+        calls; when overlap_param_gather is set to False, calls synchronous communication
+        ops. Can override this default behavior using flags below.
+
+        Args:
+            force_sync (bool, optional): force synchronous collective regardless of
+                other settings.
+            force_dispatch (bool, optional): force dispatch regardless of other settings.
+        """
+        if not force_sync:
+            # If overlapping param AG with optimizer step, AG should not be dispatched again
+            # in forward_backward_step.
+            if self.overlap_param_gather_with_optimizer_step and not force_dispatch:
+                return
+
+        for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
+            bucket_group.start_param_sync(force_sync=force_sync)
+
    def start_grad_sync(self, *unused):
        """
        Initiates grad sync (all-reduce or reduce-scatter) communication operations
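Editorial note: the flag interaction in the new start_param_sync can be summarized by the small helper below. should_dispatch is hypothetical (not part of the commit); it just restates the added control flow: force_sync always issues the collectives (as blocking ops), and force_dispatch only matters when param all-gather is being overlapped with the optimizer step, in which case the regular call is a no-op.

def should_dispatch(overlap_with_optimizer_step: bool,
                    force_sync: bool = False,
                    force_dispatch: bool = False) -> bool:
    if force_sync:
        # force_sync always issues (blocking) collectives.
        return True
    if overlap_with_optimizer_step and not force_dispatch:
        # The optimizer step already dispatched the all-gathers; skip.
        return False
    return True


assert should_dispatch(overlap_with_optimizer_step=True) is False
assert should_dispatch(overlap_with_optimizer_step=True, force_dispatch=True) is True
assert should_dispatch(overlap_with_optimizer_step=True, force_sync=True) is True
assert should_dispatch(overlap_with_optimizer_step=False) is True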
@@ -312,11 +409,6 @@ def start_grad_sync(self, *unused):
        for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
            bucket_group.start_grad_sync()

-    def scale_gradients(self, scaling_factor: float) -> None:
-        """Scale all gradients inside the buffers by `scaling_factor`."""
-        for buffer in self.buffers + self.expert_parallel_buffers:
-            buffer.scale_gradients(scaling_factor)
-
    def finish_grad_sync(self):
        """
        Finishes grad sync (all-reduce or reduce-scatter) communication operations
@@ -329,6 +421,11 @@ def finish_grad_sync(self):
        for bucket_group in self.bucket_groups + self.expert_parallel_bucket_groups:
            bucket_group.finish_grad_sync()

+    def scale_gradients(self, scaling_factor: float):
+        """Scale all gradients inside the buffers by `scaling_factor`."""
+        for buffer in self.buffers + self.expert_parallel_buffers:
+            buffer.scale_gradients(scaling_factor)
+
    def zero_grad_buffer(self):
        """
        Zeros out all grad buffers. Needs to be called at the beginning of each