【Hackathon 8th No.32】 Adam-mini fine-tuning algorithm reproduction by megemini · Pull Request #10413 · PaddlePaddle/PaddleNLP

Update 20250424

1. Algorithm implementation

First of all, is there a bug in the existing AdamWMini implementation? The original algorithm, here:

(screenshot of the original AdamWMini parameter-update code)

should it instead be written as follows?

    p += (mom1 / denom) * (-(lr / (1.0 - beta1_pow)))
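
For reference, here is a minimal sketch of the bias-corrected update that the proposed line is part of. It is written against the public paddle API purely for illustration; the function and argument names are local to this sketch, not the optimizer's internals:

    import paddle

    def adamw_step(p, grad, mom1, mom2, lr, beta1, beta2, eps, beta1_pow, beta2_pow):
        """One bias-corrected AdamW-style update step (illustration only, weight decay omitted)."""
        mom1 = beta1 * mom1 + (1.0 - beta1) * grad           # first-moment EMA
        mom2 = beta2 * mom2 + (1.0 - beta2) * grad * grad    # second-moment EMA
        denom = mom2.sqrt() / ((1.0 - beta2_pow) ** 0.5) + eps
        # dividing the step size by (1.0 - beta1_pow) is the first-moment bias correction
        p = p + (mom1 / denom) * (-(lr / (1.0 - beta1_pow)))
        return p, mom1, mom2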

Then, as mentioned earlier, after partitioning by block here, ...

2. GPU memory analysis

The test environment is an AI Studio 32 GB V100 instance.

Test command:

The sft_argument.json setup does not fit in GPU memory, so LoRA is used instead.

The AdamWMini algorithm implemented here mainly reduces the GPU memory used by moment2. In the original AdamW optimizer, moment2 has the same shape as param, and the memory usage looks like this:

[2025-04-24 13:24:56,538] [ INFO] - loss: 2.19820595, learning_rate: 0.0003, global_step: 6, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.4724, interval_samples_per_second: 2.1169, interval_steps_per_second: 2.1169, ppl: 9.008836689307477, progress_or_epoch: 0.012

Without any optimizer state, i.e., with the memory for moment1 and moment2 removed entirely:

[2025-04-24 13:30:41,708] [ INFO] - loss: 2.19820595, learning_rate: 0.0003, global_step: 6, current_memory_allocated: 15.238787412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.429, interval_samples_per_second: 2.3307, interval_steps_per_second: 2.3307, ppl: 9.008836689307477, progress_or_epoch: 0.012

So the memory used by AdamW's moment1/moment2 is 15.395037412643433 - 15.238787412643433 = 0.15625 GB, which is consistent with 20,971,520 trainable parameters × 2 fp32 moments × 4 bytes = 0.15625 GB.

With the AdamWMini algorithm, for most of the parameters being optimized moment2 has shape [param.shape[0], 1], and the memory usage looks like this:

[2025-04-24 14:40:40,786] [ INFO] - loss: 2.19820595, learning_rate: 0.0003, global_step: 6, current_memory_allocated: 15.321604490280151, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.6338, interval_samples_per_second: 1.5779, interval_steps_per_second: 1.5779, ppl: 9.008836689307477, progress_or_epoch: 0.012

So the memory saved by the optimization is (15.395037412643433 - 15.321604490280151) / 0.15625 = 0.469970703, i.e. roughly 47% of the original moment1/moment2 footprint (moment1 is kept at full size, so the theoretical ceiling is about 50%).

Since only LoRA fine-tuning is used here, the overall memory savings are not very pronounced. If resources allow, it would be worth looking at the savings for pretraining as well, though AI Studio probably cannot handle that.

In addition, for comparison, here is the memory usage of the original AdamWMini, which does not partition by block:

[2025-04-24 14:18:27,326] [ INFO] - loss: 2.19820595, learning_rate: 0.0003, global_step: 6, current_memory_allocated: 15.31701922416687, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.6674, interval_samples_per_second: 1.4984, interval_steps_per_second: 1.4984, ppl: 9.008836689307477, progress_or_epoch: 0.012

Memory saved in that case: (15.395037412643433 - 15.31701922416687) / 0.15625 = 0.499316406, i.e. roughly 50% of the original moment1/moment2 footprint.

The original AdamWMini uses less memory because it does not partition by block: every moment2 has shape [1]. However, that is inconsistent with the algorithm as implemented by the Adam-mini authors, so it is not used as a reference here.
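
To make the three shape choices concrete, here is a small illustrative sketch (the 4096 × 4096 weight is hypothetical; only the shapes and the resulting sizes matter):

    import paddle

    param = paddle.zeros([4096, 4096], dtype="float32")  # hypothetical weight

    m2_adamw = paddle.zeros_like(param)                            # AdamW: same shape as param
    m2_block = paddle.zeros([param.shape[0], 1], dtype="float32")  # this PR: one value per block (row)
    m2_old   = paddle.zeros([1], dtype="float32")                  # old AdamWMini: a single scalar

    # the per-block moment2 is updated with the mean of the squared gradients of its block,
    # e.g. m2_block = beta2 * m2_block + (1 - beta2) * (grad * grad).mean(axis=-1, keepdim=True)
    for name, t in [("adamw", m2_adamw), ("per-block", m2_block), ("scalar", m2_old)]:
        print(f"{name:>9s}  shape={t.shape}  bytes={t.numel().item() * 4}")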

3. Problems with the algorithm

While comparing the optimization accuracy against AdamW, a problem showed up: updating param directly from Python, instead of going through the C++ kernel, does not seem to behave correctly.

To print the parameter state inside the algorithm, the following code was inserted into the optimizer's update step:

    if param.name == 'lo_ra_linear_223.w_2':    
        print(">>>>>", master_weight is None)
        print('>'*20, 
            'm1', moment1.sum(), 'm2', moment2.sum(), 
            'b1', beta1_pow_acc.sum(), 'b2', beta2_pow_acc.sum(), 
            'master_weight', master_weight.sum(), 'param', param_and_grad[0].sum(), 'grad', param_and_grad[1].sum())

With the AdamW algorithm:

[2025-04-24 15:19:23,985] [ INFO] - Total num train samples = 500 [2025-04-24 15:19:23,991] [ DEBUG] - Number of trainable parameters = 20,971,520 (per device) W0424 15:19:24.498792 301809 multiply_fwd_func.cc:76] got different data type, run type promotion automatically, this may cause data type been changed. W0424 15:19:24.506394 301809 gpu_resources.cc:306] WARNING: device: 0. The installed Paddle is compiled with CUDNN 8.9, but CUDNN version in your machine is 8.9, which may cause serious incompatible bug. Please recompile or reinstall Paddle with compatible CUDNN version. Found inf or nan, current scale is: 32768.0, decrease to: 32768.00.5 [2025-04-24 15:19:25,264] [ WARNING] - optimizer not run, scale_before: 32768.0, scale_after: 16384.0 [2025-04-24 15:19:25,268] [ INFO] - loss: 4.01541185, learning_rate: 0.0, global_step: 1, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 1.2765, interval_samples_per_second: 0.7834, interval_steps_per_second: 0.7834, ppl: 55.44612618781211, progress_or_epoch: 0.002 Found inf or nan, current scale is: 16384.0, decrease to: 16384.00.5 [2025-04-24 15:19:25,657] [ WARNING] - optimizer not run, scale_before: 16384.0, scale_after: 8192.0 [2025-04-24 15:19:25,661] [ INFO] - loss: 2.49645019, learning_rate: 0.0, global_step: 2, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.3921, interval_samples_per_second: 2.5504, interval_steps_per_second: 2.5504, ppl: 12.139325087796644, progress_or_epoch: 0.004 Found inf or nan, current scale is: 8192.0, decrease to: 8192.00.5 [2025-04-24 15:19:26,041] [ WARNING] - optimizer not run, scale_before: 8192.0, scale_after: 4096.0 [2025-04-24 15:19:26,044] [ INFO] - loss: 1.44202387, learning_rate: 0.0, global_step: 3, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.3833, interval_samples_per_second: 2.609, interval_steps_per_second: 2.609, ppl: 4.229246606564234, progress_or_epoch: 0.006 Found inf or nan, current scale is: 4096.0, decrease to: 4096.00.5 [2025-04-24 15:19:26,426] [ WARNING] - optimizer not run, scale_before: 4096.0, scale_after: 2048.0 [2025-04-24 15:19:26,429] [ INFO] - loss: 5.85809898, learning_rate: 0.0, global_step: 4, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.3853, interval_samples_per_second: 2.5952, interval_steps_per_second: 2.5952, ppl: 350.05804374322315, progress_or_epoch: 0.008 Found inf or nan, current scale is: 2048.0, decrease to: 2048.0*0.5 [2025-04-24 15:19:26,831] [ WARNING] - optimizer not run, scale_before: 2048.0, scale_after: 1024.0 [2025-04-24 15:19:26,836] [ INFO] - loss: 1.94955564, learning_rate: 0.0, global_step: 5, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.4064, interval_samples_per_second: 2.4606, interval_steps_per_second: 2.4606, ppl: 7.025565006800808, progress_or_epoch: 0.01

False

m1 Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, 0.) m2 Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, 0.) b1 Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, 0.89999998) b2 Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, 0.99900001) master_weight Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, 0.) param Tensor(shape=[], dtype=float16, place=Place(gpu:0), stop_gradient=True, 0.) grad Tensor(shape=[], dtype=float16, place=Place(gpu:0), stop_gradient=True, 0.03613281) [2025-04-24 15:19:27,365] [ INFO] - loss: 2.19820595, learning_rate: 0.0003, global_step: 6, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.5298, interval_samples_per_second: 1.8877, interval_steps_per_second: 1.8877, ppl: 9.008836689307477, progress_or_epoch: 0.012 Found inf or nan, current scale is: 1024.0, decrease to: 1024.00.5 [2025-04-24 15:19:27,869] [ WARNING] - optimizer not run, scale_before: 1024.0, scale_after: 512.0 [2025-04-24 15:19:27,875] [ INFO] - loss: 3.99188352, learning_rate: 0.0003, global_step: 7, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.5089, interval_samples_per_second: 1.9651, interval_steps_per_second: 1.9651, ppl: 54.15679877261727, progress_or_epoch: 0.014 Found inf or nan, current scale is: 512.0, decrease to: 512.00.5 [2025-04-24 15:19:28,377] [ WARNING] - optimizer not run, scale_before: 512.0, scale_after: 256.0 [2025-04-24 15:19:28,381] [ INFO] - loss: 2.8999517, learning_rate: 0.0003, global_step: 8, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.5069, interval_samples_per_second: 1.9729, interval_steps_per_second: 1.9729, ppl: 18.173267579420518, progress_or_epoch: 0.016 Found inf or nan, current scale is: 256.0, decrease to: 256.00.5 [2025-04-24 15:19:28,875] [ WARNING] - optimizer not run, scale_before: 256.0, scale_after: 128.0 [2025-04-24 15:19:28,879] [ INFO] - loss: 4.1729598, learning_rate: 0.0003, global_step: 9, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.498, interval_samples_per_second: 2.0082, interval_steps_per_second: 2.0082, ppl: 64.90728064956862, progress_or_epoch: 0.018 False m1 Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, 0.00361453) m2 Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, 0.00000435) b1 Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, 0.80999994) b2 Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, 0.99800104) master_weight Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, 0.) param Tensor(shape=[], dtype=float16, place=Place(gpu:0), stop_gradient=True, 0.) 
grad Tensor(shape=[], dtype=float16, place=Place(gpu:0), stop_gradient=True, 0.01861572) [2025-04-24 15:19:29,404] [ INFO] - loss: 1.64721501, learning_rate: 0.0002994, global_step: 10, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.5251, interval_samples_per_second: 1.9044, interval_steps_per_second: 1.9044, ppl: 5.192498614806668, progress_or_epoch: 0.02 Found inf or nan, current scale is: 128.0, decrease to: 128.00.5 [2025-04-24 15:19:29,912] [ WARNING] - optimizer not run, scale_before: 128.0, scale_after: 64.0 [2025-04-24 15:19:29,917] [ INFO] - loss: 3.22934556, learning_rate: 0.0002994, global_step: 11, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.5125, interval_samples_per_second: 1.9513, interval_steps_per_second: 1.9513, ppl: 25.263118364607866, progress_or_epoch: 0.022 False m1 Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, 0.00511430) m2 Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, 0.00000622) b1 Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, 0.72899991) b2 Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, 0.99700308) master_weight Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, -0.13641964) param Tensor(shape=[], dtype=float16, place=Place(gpu:0), stop_gradient=True, -0.13635254) grad Tensor(shape=[], dtype=float16, place=Place(gpu:0), stop_gradient=True, 0.03591919) [2025-04-24 15:19:30,438] [ INFO] - loss: 3.001441, learning_rate: 0.0002988, global_step: 12, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.521, interval_samples_per_second: 1.9193, interval_steps_per_second: 1.9193, ppl: 20.114501045532172, progress_or_epoch: 0.024 False m1 Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, 0.00819490) m2 Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, 0.00000712) b1 Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, 0.65609992) b2 Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, 0.99600607) master_weight Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, -0.22575189) param Tensor(shape=[], dtype=float16, place=Place(gpu:0), stop_gradient=True, -0.22570801) grad Tensor(shape=[], dtype=float16, place=Place(gpu:0), stop_gradient=True, -0.05255127) [2025-04-24 15:19:30,962] [ INFO] - loss: 2.2637279, learning_rate: 0.0002982, global_step: 13, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.5247, interval_samples_per_second: 1.9058, interval_steps_per_second: 1.9058, ppl: 9.618880636929768, progress_or_epoch: 0.026

Whereas with AdamWMini:

[2025-04-24 15:16:38,172] [ INFO] - Total num train samples = 500 [2025-04-24 15:16:38,175] [ DEBUG] - Number of trainable parameters = 20,971,520 (per device) W0424 15:16:38.704833 296812 multiply_fwd_func.cc:76] got different data type, run type promotion automatically, this may cause data type been changed. W0424 15:16:38.712059 296812 gpu_resources.cc:306] WARNING: device: 0. The installed Paddle is compiled with CUDNN 8.9, but CUDNN version in your machine is 8.9, which may cause serious incompatible bug. Please recompile or reinstall Paddle with compatible CUDNN version. Found inf or nan, current scale is: 32768.0, decrease to: 32768.00.5 [2025-04-24 15:16:39,594] [ WARNING] - optimizer not run, scale_before: 32768.0, scale_after: 16384.0 [2025-04-24 15:16:39,599] [ INFO] - loss: 4.01541185, learning_rate: 0.0, global_step: 1, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 1.4227, interval_samples_per_second: 0.7029, interval_steps_per_second: 0.7029, ppl: 55.44612618781211, progress_or_epoch: 0.002 Found inf or nan, current scale is: 16384.0, decrease to: 16384.00.5 [2025-04-24 15:16:40,021] [ WARNING] - optimizer not run, scale_before: 16384.0, scale_after: 8192.0 [2025-04-24 15:16:40,024] [ INFO] - loss: 2.49645019, learning_rate: 0.0, global_step: 2, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.425, interval_samples_per_second: 2.3529, interval_steps_per_second: 2.3529, ppl: 12.139325087796644, progress_or_epoch: 0.004 Found inf or nan, current scale is: 8192.0, decrease to: 8192.00.5 [2025-04-24 15:16:40,433] [ WARNING] - optimizer not run, scale_before: 8192.0, scale_after: 4096.0 [2025-04-24 15:16:40,436] [ INFO] - loss: 1.44202387, learning_rate: 0.0, global_step: 3, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.4121, interval_samples_per_second: 2.4269, interval_steps_per_second: 2.4269, ppl: 4.229246606564234, progress_or_epoch: 0.006 Found inf or nan, current scale is: 4096.0, decrease to: 4096.00.5 [2025-04-24 15:16:40,840] [ WARNING] - optimizer not run, scale_before: 4096.0, scale_after: 2048.0 [2025-04-24 15:16:40,843] [ INFO] - loss: 5.85809898, learning_rate: 0.0, global_step: 4, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.4068, interval_samples_per_second: 2.4583, interval_steps_per_second: 2.4583, ppl: 350.05804374322315, progress_or_epoch: 0.008 Found inf or nan, current scale is: 2048.0, decrease to: 2048.0*0.5 [2025-04-24 15:16:41,249] [ WARNING] - optimizer not run, scale_before: 2048.0, scale_after: 1024.0 [2025-04-24 15:16:41,252] [ INFO] - loss: 1.94955564, learning_rate: 0.0, global_step: 5, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.409, interval_samples_per_second: 2.445, interval_steps_per_second: 2.445, ppl: 7.025565006800808, progress_or_epoch: 0.01

False

m1 Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, 0.) m2 Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, 0.) b1 Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, 0.89999998) b2 Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, 0.99900001) master_weight Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, 0.) param Tensor(shape=[], dtype=float16, place=Place(gpu:0), stop_gradient=True, 0.) grad Tensor(shape=[], dtype=float16, place=Place(gpu:0), stop_gradient=True, 0.03613281) [2025-04-24 15:16:41,970] [ INFO] - loss: 2.19820595, learning_rate: 0.0003, global_step: 6, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.7184, interval_samples_per_second: 1.3921, interval_steps_per_second: 1.3921, ppl: 9.008836689307477, progress_or_epoch: 0.012 Found inf or nan, current scale is: 1024.0, decrease to: 1024.00.5 [2025-04-24 15:16:42,371] [ WARNING] - optimizer not run, scale_before: 1024.0, scale_after: 512.0 [2025-04-24 15:16:42,374] [ INFO] - loss: 3.99188352, learning_rate: 0.0003, global_step: 7, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.4037, interval_samples_per_second: 2.4771, interval_steps_per_second: 2.4771, ppl: 54.15679877261727, progress_or_epoch: 0.014 Found inf or nan, current scale is: 512.0, decrease to: 512.00.5 [2025-04-24 15:16:42,773] [ WARNING] - optimizer not run, scale_before: 512.0, scale_after: 256.0 [2025-04-24 15:16:42,776] [ INFO] - loss: 2.8999517, learning_rate: 0.0003, global_step: 8, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.4018, interval_samples_per_second: 2.4887, interval_steps_per_second: 2.4887, ppl: 18.173267579420518, progress_or_epoch: 0.016 Found inf or nan, current scale is: 256.0, decrease to: 256.00.5 [2025-04-24 15:16:43,173] [ WARNING] - optimizer not run, scale_before: 256.0, scale_after: 128.0 [2025-04-24 15:16:43,176] [ INFO] - loss: 4.1729598, learning_rate: 0.0003, global_step: 9, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.4001, interval_samples_per_second: 2.4994, interval_steps_per_second: 2.4994, ppl: 64.90728064956862, progress_or_epoch: 0.018 False m1 Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, 0.00362164) m2 Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, 0.00000072) b1 Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, 0.80999994) b2 Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, 0.99800104) master_weight Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, 0.) param Tensor(shape=[], dtype=float16, place=Place(gpu:0), stop_gradient=True, 0.) 
grad Tensor(shape=[], dtype=float16, place=Place(gpu:0), stop_gradient=True, 0.01861572) [2025-04-24 15:16:43,868] [ INFO] - loss: 1.64721501, learning_rate: 0.0002994, global_step: 10, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.6926, interval_samples_per_second: 1.4438, interval_steps_per_second: 1.4438, ppl: 5.192498614806668, progress_or_epoch: 0.02 Found inf or nan, current scale is: 128.0, decrease to: 128.00.5 [2025-04-24 15:16:44,266] [ WARNING] - optimizer not run, scale_before: 128.0, scale_after: 64.0 [2025-04-24 15:16:44,269] [ INFO] - loss: 26.33701515, learning_rate: 0.0002994, global_step: 11, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.4004, interval_samples_per_second: 2.4975, interval_steps_per_second: 2.4975, ppl: 274170263505.28873, progress_or_epoch: 0.022 Found inf or nan, current scale is: 64.0, decrease to: 64.0*0.5 [2025-04-24 15:16:44,670] [ WARNING] - optimizer not run, scale_before: 64.0, scale_after: 32.0 [2025-04-24 15:16:44,673] [ INFO] - loss: 25.63891411, learning_rate: 0.0002994, global_step: 12, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.4043, interval_samples_per_second: 2.4735, interval_steps_per_second: 2.4735, ppl: 136407710588.60149, progress_or_epoch: 0.024 False m1 Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, 0.00512069) m2 Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, 0.00000083) b1 Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, 0.72899991) b2 Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True, 0.99700308) master_weight Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True, -1044.12109375) param Tensor(shape=[], dtype=float16, place=Place(gpu:0), stop_gradient=True, -1044.) grad Tensor(shape=[], dtype=float16, place=Place(gpu:0), stop_gradient=True, 0.) [2025-04-24 15:16:45,375] [ INFO] - loss: 21.70650673, learning_rate: 0.0002988, global_step: 13, current_memory_allocated: 15.395037412643433, current_memory_reserved: 16.195343732833862, max_memory_allocated: 16.186050415039062, max_memory_reserved: 16.195343732833862, interval_runtime: 0.7015, interval_samples_per_second: 1.4254, interval_steps_per_second: 1.4254, ppl: 2673105467.692815, progress_or_epoch: 0.026

As can be seen, the parameter update clearly goes wrong; the result no longer matches the C++ implementation.

Using in-place operators here also fails to align with AdamW, so it feels as if updating the tensor from Python is the problem? Please help take a look.
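
For completeness, one pitfall that may be worth ruling out when a parameter is updated from pure Python (this is only a sketch of the suspected symptom, not a confirmed diagnosis; paddle.assign with output= is one way to force the write-back):

    import paddle

    w = paddle.to_tensor([1.0, 2.0, 3.0])

    def update_by_rebinding(p, lr=0.1):
        # only rebinds the local name `p`; the caller's (optimizer's) tensor is untouched
        p = p - lr * paddle.ones_like(p)

    def update_in_place(p, lr=0.1):
        # writes the result back into `p`, so the change is visible to the caller
        paddle.assign(p - lr * paddle.ones_like(p), output=p)

    update_by_rebinding(w)
    print(w.numpy())  # still [1. 2. 3.]
    update_in_place(w)
    print(w.numpy())  # now  [0.9 1.9 2.9]

Any intermediate expression that produces a new tensor needs an explicit write-back like this before the stored parameter (and its C++-side state) reflects the change.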

@DrownFish19 could you please help confirm? Thanks!

Attached: the configuration used for the tests:

{ "model_name_or_path": "meta-llama/Meta-Llama-3-8B", "dataset_name_or_path": "./data", "output_dir": "./checkpoints/lora_ckpts", "per_device_train_batch_size": 1, "gradient_accumulation_steps": 1, "per_device_eval_batch_size": 1, "eval_accumulation_steps":16, "num_train_epochs": 1, "learning_rate": 3e-04, "warmup_steps": 1, "logging_steps": 1, "evaluation_strategy": "epoch", "save_strategy": "epoch", "src_length": 1024, "max_length": 2048, "bf16": false, "fp16": true, "fp16_opt_level": "O2", "do_train": true, "do_eval": true, "disable_tqdm": true, "load_best_model_at_end": true, "eval_with_do_generation": false, "metric_for_best_model": "accuracy", "recompute": true, "save_total_limit": 1, "tensor_parallel_degree": 1, "pipeline_parallel_degree": 1, "sharding": "stage1", "lora": true, "zero_padding": false, "use_flash_attention": false, "unified_checkpoint": true, "pissa": false, "use_mora": false, "optim": "adamw_mini" }