【Hackathon 8th No.12】在 PaddleScience 中实现 SOAP 优化器 by BeingGod · Pull Request #1102 · PaddlePaddle/PaddleScience (original) (raw)

为验证 SOAP 实现正确性,基于 MLP demo + SOAP 对torch与paddle 的 loss 进行对比

前 20iter loss对比如下:

λ beinggod-workstation /workspace/hackathon/SOAP/SOAP python soap_paddle.py grep: warning: GREP_OPTIONS is deprecated; please use an alias or script W0313 05:29:07.938906 1188667 gpu_resources.cc:119] Please NOTE: device: 0, GPU Compute Capability: 6.1, Driver API Version: 12.6, Runtime API Version: 11.8 W0313 05:29:07.939640 1188667 gpu_resources.cc:164] device: 0, cuDNN Version: 8.9. pddle param name: weight, mean: 0.006312752142548561, sum: 103.42813110351562 torch param name: weight, mean: 0.006312752142548561, sum: 103.42813110351562 [ITERATION] 1/20 loss paddle: 12.870917, loss torch: 12.870915 [ITERATION] 2/20 loss paddle: 15.275839, loss torch: 15.275839 [ITERATION] 3/20 loss paddle: 9.223783, loss torch: 9.223782 [ITERATION] 4/20 loss paddle: 9.283824, loss torch: 9.283824 [ITERATION] 5/20 loss paddle: 7.028357, loss torch: 7.028356 [ITERATION] 6/20 loss paddle: 7.973010, loss torch: 7.973009 [ITERATION] 7/20 loss paddle: 11.567774, loss torch: 11.567776 [ITERATION] 8/20 loss paddle: 5.823763, loss torch: 5.823766 [ITERATION] 9/20 loss paddle: 12.174599, loss torch: 12.174601 [ITERATION] 10/20 loss paddle: 8.206469, loss torch: 8.206469 [ITERATION] 11/20 loss paddle: 7.991440, loss torch: 7.991440 [ITERATION] 12/20 loss paddle: 7.984601, loss torch: 7.984600 [ITERATION] 13/20 loss paddle: 9.944571, loss torch: 9.944571 [ITERATION] 14/20 loss paddle: 10.886424, loss torch: 10.886503 [ITERATION] 15/20 loss paddle: 8.474144, loss torch: 8.474236 [ITERATION] 16/20 loss paddle: 7.847643, loss torch: 7.847647 [ITERATION] 17/20 loss paddle: 10.080597, loss torch: 10.080593 [ITERATION] 18/20 loss paddle: 9.120567, loss torch: 9.120558 [ITERATION] 19/20 loss paddle: 7.173148, loss torch: 7.173133 [ITERATION] 20/20 loss paddle: 7.039429, loss torch: 7.039426 λ beinggod-workstation /workspace/hackathon/SOAP/SOAP

loss误差在1e-5,可以认为二者实现是对齐的

测试代码:

import paddle import paddle.device

from itertools import chain from collections import defaultdict

import paddle.optimizer import paddle.utils

import torch import numpy as np

from itertools import chain

import torch.utils import torch.utils.data

Parts of the code are modifications of Pypaddle's AdamW optimizer

Parts of the code are modifications of code from https://github.com/jiaweizzhao/GaLore/blob/master/galore_paddle/galore_projector.py

seed = 1234 paddle.device.set_device('gpu') paddle.seed(seed)

torch.manual_seed(seed) torch.cuda.manual_seed(seed)

class SOAP_paddle(paddle.optimizer.Optimizer): """ Implements SOAP algorithm (https://arxiv.org/abs/2409.11321).

Parameters:
    params (`list|tuple`):
        Iterable of parameters to optimize or dictionaries defining parameter groups.
    lr (`float`, *optional*, defaults to 0.003):
        The learning rate to use.
    betas (`Tuple[float,float]`, *optional*, defaults to `(0.95, 0.95)`):
        Adam's betas parameters (b1, b2).
    shampoo_beta (`float`, *optional*, defaults to -1):
        If >= 0, use this beta for the preconditioner (L and R in paper, state['GG'] below) moving average instead of betas[1].
    eps (`float`, *optional*, defaults to 1e-08):
        Adam's epsilon for numerical stability.
    weight_decay (`float`, *optional*, defaults to 0.01): weight decay coefficient.
    precondition_frequency (`int`, *optional*, defaults to 10):
        How often to update the preconditioner.
    max_precond_dim (`int`, *optional*, defaults to 10000):
        Maximum dimension of the preconditioner.
        Set to 10000, so that we exclude most common vocab sizes while including layers.
    merge_dims (`bool`, *optional*, defaults to `False`):
        Whether or not to merge dimensions of the preconditioner.
    precondition_1d (`bool`, *optional*, defaults to `False`):
        Whether or not to precondition 1D gradients.
    normalize_grads (`bool`, *optional*, defaults to `False`):
        Whether or not to normalize gradients per layer. 
        Helps at large precondition_frequency (~100 in our experiments), 
        but hurts performance at small precondition_frequency (~10 in our experiments).
    data_format (`str`, *optional*, defaults to `channels_first`):
        Data format of the input for convolutional layers.
        Should be "channels_last" for data_format of NHWC and "channels_first" for NCHW.
    correct_bias (`bool`, *optional*, defaults to `True`):
        Whether or not to use bias correction in Adam.
    name (str, optional): Normally there is no need for user to set this property.
        For more information, please refer to :ref:`api_guide_Name`.
        The default value is None.
"""

def __init__(
    self,
    params,
    lr: float = 3e-3,
    betas=(0.95, 0.95),
    shampoo_beta: float= -1,
    eps: float = 1e-8,
    weight_decay: float = 0.01,
    precondition_frequency: int=10,
    max_precond_dim: int=10000, # 
    merge_dims: bool = False, # Merge dimensions till the product of the dimensions is less than or equal to max_precond_dim.
    precondition_1d: bool = False,
    normalize_grads: bool = False,
    data_format: str = "channels_first",
    correct_bias: bool = True,
    name: str = None,
):
    self._betas = betas
    self._shampoo_beta = shampoo_beta
    self._eps = eps
    self._precondition_frequency = precondition_frequency
    self._max_precond_dim = max_precond_dim
    self._merge_dims = merge_dims
    self._precondition_1d = precondition_1d
    self._normalize_grads = normalize_grads
    self._correct_bias = correct_bias
    self._weight_decay = weight_decay
    
    self.state = defaultdict(dict)

    super().__init__(learning_rate=lr, 
                     parameters=params,
                     weight_decay=weight_decay,
                     name=name)

    if isinstance(self._parameter_list[0],dict):
        raise TypeError(
            "The parameter groups is not supported on SOAP optimizer."
        )

    self._data_format = data_format

    
def merge_dims(self, grad, max_precond_dim):
    """
    Merges dimensions of the gradient tensor till the product of the dimensions is less than or equal to max_precond_dim.
    """
    assert self._data_format in ["channels_first", "channels_last"]
    if self._data_format == "channels_last" and grad.dim() == 4:
        grad = grad.transpose(0, 3, 1, 2)
    shape = grad.shape
    new_shape = []
    
    curr_shape = 1
    for sh in shape:
        temp_shape = curr_shape * sh
        if temp_shape > max_precond_dim:
            if curr_shape > 1:
                new_shape.append(curr_shape)
                curr_shape = sh
            else:
                new_shape.append(sh)
                curr_shape = 1
        else:
            curr_shape = temp_shape
    
    if curr_shape > 1 or len(new_shape)==0:
        new_shape.append(curr_shape)
    
    new_grad = grad.reshape(new_shape)
    return new_grad               

@paddle.base.framework.non_static_only
def step(self, closure = None):
    """
    Performs a single optimization step.

    Arguments:
        closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
    """
    with paddle.no_grad():
        if closure is None:
            loss = None
        else:
            closure = paddle.enable_grad()(closure)
            loss = closure()

        for p in self._parameter_list:
            if p.grad is None:
                continue
            grad = p.grad

            state = self.state[p]
            
            if "step" not in state:
                state["step"] = 0 
                
            # State initialization
            if "exp_avg" not in state:
                # Exponential moving average of gradient values
                state["exp_avg"] = paddle.zeros_like(grad)
                # Exponential moving average of squared gradient values
                state["exp_avg_sq"] = paddle.zeros_like(grad)
            
            if 'Q' not in state:
                self.init_preconditioner(
                    grad,
                    state,
                    precondition_frequency=self._precondition_frequency,
                    precondition_1d=self._precondition_1d,
                    shampoo_beta=(self._shampoo_beta if self._shampoo_beta >= 0 else self._betas[1]),
                    max_precond_dim=self._max_precond_dim,
                    merge_dims=self._merge_dims,
                )
                self.update_preconditioner(grad, state,
                                            max_precond_dim=self._max_precond_dim,
                                            merge_dims=self._merge_dims,
                                            precondition_1d=self._precondition_1d)
                continue # first step is skipped so that we never use the current gradients in the projection.
            
            # Projecting gradients to the eigenbases of Shampoo's preconditioner 
            # i.e. projecting to the eigenbases of matrices in state['GG']
            grad_projected = self.project(grad, state, merge_dims=self._merge_dims, 
                                            max_precond_dim=self._max_precond_dim)

            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
            beta1, beta2 = paddle.to_tensor(self._betas)

            state["step"] += 1

            # Decay the first and second moment running average coefficient
            # In-place operations to update the averages at the same time
            exp_avg.multiply_(beta1).add_((1.0 - beta1)*grad_projected)
            exp_avg_sq.multiply_(beta2).add_((1.0-beta2)*grad_projected.square())

            denom = exp_avg_sq.sqrt().add_(paddle.to_tensor(self._eps))
            
            # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner 
            # i.e. projecting to the eigenbases of matrices in state['GG']
            # exp_avg_projected = self.project(exp_avg, state, merge_dims=self._merge_dims"],
            #                                  max_precond_dim=self._max_precond_dim'])
            exp_avg_projected = exp_avg
            
            lr = self._learning_rate
            step_size = lr
            if self._correct_bias:
                bias_correction1 = 1.0 - beta1 ** (state["step"])
                bias_correction2 = 1.0 - beta2 ** (state["step"])
                step_size = step_size * (bias_correction2 ** .5) / bias_correction1

            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
            # to the original space
            norm_grad = self.project_back(exp_avg_projected / denom, state, merge_dims=self._merge_dims,
                                                max_precond_dim=self._max_precond_dim)

            if self._normalize_grads:
                norm_grad = norm_grad / (1e-30+paddle.mean(norm_grad**2)**0.5)
            
            p.add_(-step_size * norm_grad)
            

            # From AdamW code: Just adding the square of the weights to the loss function is *not*
            # the correct way of using L2 regularization/weight decay with Adam,
            # since that will interact with the m and v parameters in strange ways.
            #
            # Instead we want to decay the weights in a manner that doesn't interact
            # with the m/v parameters. This is equivalent to adding the square
            # of the weights to the loss with plain (non-momentum) SGD.
            # Add weight decay at the end (fixed version)
            if self._weight_decay > 0.0:
                p.add_((-lr * self._weight_decay) * p)
                
            # Update is done after the gradient step to avoid using current gradients in the projection.
            self.update_preconditioner(grad, state, 
                                            max_precond_dim=self._max_precond_dim,
                                            merge_dims=self._merge_dims,
                                            precondition_1d=self._precondition_1d)
        
    return loss

def init_preconditioner(self, grad, state, precondition_frequency=10, 
                        shampoo_beta=0.95, max_precond_dim=10000, precondition_1d=False,
                        merge_dims=False):
    """
    Initializes the preconditioner matrices (L and R in the paper).
    """
    state['GG'] = [] # Will hold all the preconditioner matrices (L and R in the paper).
    if grad.dim() == 1:
        if not precondition_1d or grad.shape[0] > max_precond_dim:
            state['GG'].append([])
        else:
            state['GG'].append(paddle.zeros([grad.shape[0], grad.shape[0]]))
    else:
        if merge_dims:
            grad = self.merge_dims(grad, max_precond_dim)

        for sh in grad.shape:
            if sh > max_precond_dim:
                state['GG'].append([])
            else:
                state['GG'].append(paddle.zeros([sh, sh]))
                
    state['Q'] = None # Will hold all the eigenbases of the preconditioner.
    state['precondition_frequency'] = precondition_frequency
    state['shampoo_beta'] = shampoo_beta          
    
def project(self, grad, state, merge_dims=False, max_precond_dim=10000):
    """
    Projects the gradient to the eigenbases of the preconditioner.
    """
    original_shape = grad.shape
    if merge_dims:
        if grad.dim() == 4 and self._data_format == 'channels_last':
            transposed_shape = grad.transpose(0, 3, 1, 2).shape
        grad = self.merge_dims(grad, max_precond_dim)
    
    for mat in state['Q']:
        if len(mat) > 0:
            grad = paddle.tensordot(
                    grad,
                    mat,
                    axes=[[0], [0]],
                )
        else:
            transpose_order = list(range(1, len(grad.shape))) + [0]
            grad = grad.transpose(transpose_order)
    
    if merge_dims:
        if self._data_format == 'channels_last' and len(original_shape) == 4:
            grad = grad.reshape(transposed_shape).transpose(0, 2, 3, 1)
        else:
            grad = grad.reshape(original_shape)
    return grad
    
def update_preconditioner(self, grad, state, 
                          max_precond_dim=10000, merge_dims=False, precondition_1d=False):
    """
    Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
    """
    if state["Q"] is not None:
        state["exp_avg"] = self.project_back(state["exp_avg"], state, merge_dims=merge_dims, max_precond_dim=max_precond_dim)
    if grad.dim() == 1:
        if precondition_1d and grad.shape[0] <= max_precond_dim:
            state['GG'][0].lerp_(grad.unsqueeze(1) @ grad.unsqueeze(0), 1-state['shampoo_beta'])
    else:
        if merge_dims:
            new_grad = self.merge_dims(grad, max_precond_dim)
            for idx, sh in enumerate(new_grad.shape):
                if sh <= max_precond_dim:
                    outer_product = paddle.tensordot(
                            new_grad,
                            new_grad,
                            axes=[[*chain(range(idx), range(idx + 1, len(new_grad.shape)))]] * 2,
                        )
                    state['GG'][idx].lerp_(outer_product, 1-state['shampoo_beta'])
        else:
            for idx, sh in enumerate(grad.shape):
                if sh <= max_precond_dim:
                    outer_product = paddle.tensordot(
                            grad,
                            grad,
                            # Contracts across all dimensions except for k.
                            axes=[[*chain(range(idx), range(idx + 1, len(grad.shape)))]] * 2,
                        )
                    state['GG'][idx].lerp_(outer_product, 1-state['shampoo_beta'])
    
    if state['Q'] is None:
        state['Q'] = self.get_orthogonal_matrix(state['GG'])
    if state['step'] > 0 and state['step'] % state['precondition_frequency'] == 0:
        state['Q'] = self.get_orthogonal_matrix_QR(state, max_precond_dim, merge_dims)
        # state['Q'] = self.get_fast_QR(state, max_precond_dim, merge_dims)             

    if state["step"] > 0:
        state["exp_avg"] = self.project(state["exp_avg"], state, merge_dims=merge_dims, max_precond_dim=max_precond_dim) 

def project_back(self, grad, state, merge_dims=False, max_precond_dim=10000):
    """
    Projects the gradient back to the original space.
    """
    original_shape = grad.shape
    if merge_dims:
        if self._data_format == 'channels_last' and grad.dim() == 4:
            transposed_shape = grad.transpose(0, 3, 1, 2).shape
        grad = self.merge_dims(grad, max_precond_dim)
    for mat in state['Q']:
        if len(mat) > 0:
            grad = paddle.tensordot(
                    grad,
                    mat,
                    axes=[[0], [1]],
                )
        else:
            transpose_order = list(range(1, len(grad.shape))) + [0]
            grad = grad.transpose(transpose_order)
            
    if merge_dims:
        if self._data_format == 'channels_last' and len(original_shape) == 4:
            grad = grad.reshape(transposed_shape).transpose(0, 2, 3, 1)
        else:
            grad = grad.reshape(original_shape)
    return grad
    

def get_orthogonal_matrix(self, mat):
    """
    Computes the eigenbases of the preconditioner using paddle.linalg.eigh decomposition.
    """
    matrix = []
    for m in mat:
        if len(m) == 0:
            matrix.append([])
            continue
        if m.data.dtype != paddle.float32:
            float_data = False
            original_type = m.data.dtype
            original_device = m.data.place
            matrix.append(m.data.to(paddle.float32))
        else:
            float_data = True
            matrix.append(m.data)
    
    final = []
    for m in matrix:
        if len(m) == 0:
            final.append([])
            continue
        # try:
        #     _, Q = paddle.linalg.eigh(m+1e-30*paddle.eye(m.shape[0]))
        # except:
        #     _, Q = paddle.linalg.eigh(m.to(paddle.float64)+1e-30*paddle.eye(m.shape[0]))
        #     Q = Q.to(m.dtype)
        _, Q = paddle.linalg.eigh(m+1e-30*paddle.eye(m.shape[0]))
        Q = paddle.flip(Q, [1])

        if not float_data:
            Q = Q.to(original_device, dtype=original_type)
        final.append(Q)
    return final
    

def get_orthogonal_matrix_QR(self, state, max_precond_dim=10000, merge_dims=False):
    """
    Computes the eigenbases of the preconditioner using one round of power iteration 
    followed by paddle.linalg.qr decomposition.
    """
    precond_list = state['GG']
    orth_list = state['Q']

    matrix = []
    orth_matrix = []
    for m,o in zip(precond_list, orth_list):
        if len(m) == 0:
            matrix.append([])
            orth_matrix.append([])
            continue
        if m.data.dtype != paddle.float32:
            float_data = False
            original_type = m.data.dtype
            original_device = m.data.place
            matrix.append(m.data.to(paddle.float32))
            orth_matrix.append(o.data.to(paddle.float32))
        else:
            float_data = True
            matrix.append(m.data.to(paddle.float32))
            orth_matrix.append(o.data.to(paddle.float32))
    
    orig_shape = state['exp_avg_sq'].shape
    if self._data_format == 'channels_last' and len(orig_shape) == 4:
        transposed_shape = state['exp_avg_sq'].transpose(0, 3, 1, 2).shape
    if merge_dims:
        exp_avg_sq = self.merge_dims(state['exp_avg_sq'], max_precond_dim)
    else:
        exp_avg_sq = state['exp_avg_sq']
        
    final = []
    for ind, (m,o) in enumerate(zip(matrix, orth_matrix)):
        if len(m)==0:
            final.append([])
            continue
        est_eig = paddle.diag(o.T @ m @ o)
        sort_idx = paddle.argsort(est_eig, descending=True)
        exp_avg_sq = exp_avg_sq.index_select(sort_idx, ind)
        o = o[:,sort_idx]
        power_iter = m @ o
        Q, _ = paddle.linalg.qr(power_iter)

        if not float_data:
            Q = Q.to(original_device, dtype=original_type)
        final.append(Q)
    
    if merge_dims:
        if self._data_format == 'channels_last' and len(orig_shape) == 4:
            exp_avg_sq = exp_avg_sq.reshape(transposed_shape).transpose(0, 2, 3, 1)
        else:
            exp_avg_sq = exp_avg_sq.reshape(orig_shape)
            
    state['exp_avg_sq'] = exp_avg_sq
    return final

class MLP_paddle(paddle.nn.Layer):

def __init__(self, in_features, out_features):
    super().__init__()
    self._internal_weight = paddle.randn([out_features, in_features])
    self.weight = paddle.create_parameter(shape=self._internal_weight.shape,
                    dtype=self._internal_weight.dtype,
                    default_initializer=paddle.nn.initializer.Assign(self._internal_weight))
    self.weight.stop_gradient = False

def forward(self, inp):
    return paddle.matmul(inp, self.weight.T)

Parts of the code are modifications of Pytorch's AdamW optimizer

Parts of the code are modifications of code from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/galore_projector.py

class SOAP(torch.optim.Optimizer): """ Implements SOAP algorithm (https://arxiv.org/abs/2409.11321).

Parameters:
    params (`Iterable[nn.parameter.Parameter]`):
        Iterable of parameters to optimize or dictionaries defining parameter groups.
    lr (`float`, *optional*, defaults to 0.003):
        The learning rate to use.
    betas (`Tuple[float,float]`, *optional*, defaults to `(0.95, 0.95)`):
        Adam's betas parameters (b1, b2).
    shampoo_beta (`float`, *optional*, defaults to -1):
        If >= 0, use this beta for the preconditioner (L and R in paper, state['GG'] below) moving average instead of betas[1].
    eps (`float`, *optional*, defaults to 1e-08):
        Adam's epsilon for numerical stability.
    weight_decay (`float`, *optional*, defaults to 0.01): weight decay coefficient.
    precondition_frequency (`int`, *optional*, defaults to 10):
        How often to update the preconditioner.
    max_precond_dim (`int`, *optional*, defaults to 10000):
        Maximum dimension of the preconditioner.
        Set to 10000, so that we exclude most common vocab sizes while including layers.
    merge_dims (`bool`, *optional*, defaults to `False`):
        Whether or not to merge dimensions of the preconditioner.
    precondition_1d (`bool`, *optional*, defaults to `False`):
        Whether or not to precondition 1D gradients.
    normalize_grads (`bool`, *optional*, defaults to `False`):
        Whether or not to normalize gradients per layer. 
        Helps at large precondition_frequency (~100 in our experiments), 
        but hurts performance at small precondition_frequency (~10 in our experiments).
    data_format (`str`, *optional*, defaults to `channels_first`):
        Data format of the input for convolutional layers.
        Should be "channels_last" for data_format of NHWC and "channels_first" for NCHW.
    correct_bias (`bool`, *optional*, defaults to `True`):
        Whether or not to use bias correction in Adam.
"""

def __init__(
    self,
    params,
    lr: float = 3e-3,
    betas=(0.95, 0.95),
    shampoo_beta: float= -1,
    eps: float = 1e-8,
    weight_decay: float = 0.01,
    precondition_frequency: int=10,
    max_precond_dim: int=10000, # 
    merge_dims: bool = False, # Merge dimensions till the product of the dimensions is less than or equal to max_precond_dim.
    precondition_1d: bool = False,
    normalize_grads: bool = False,
    data_format: str = "channels_first",
    correct_bias: bool = True,
):
    defaults = {
        "lr": lr,
        "betas": betas,
        "shampoo_beta": shampoo_beta,
        "eps": eps,
        "weight_decay": weight_decay,
        "precondition_frequency": precondition_frequency,
        "max_precond_dim": max_precond_dim,
        "merge_dims": merge_dims,
        "precondition_1d": precondition_1d,
        "normalize_grads": normalize_grads,
        "correct_bias": correct_bias,
    }
    super().__init__(params, defaults)
    self._data_format = data_format
    
def merge_dims(self, grad, max_precond_dim):
    """
    Merges dimensions of the gradient tensor till the product of the dimensions is less than or equal to max_precond_dim.
    """
    assert self._data_format in ["channels_first", "channels_last"]
    if self._data_format == "channels_last" and grad.dim() == 4:
        grad = grad.permute(0, 3, 1, 2)
    shape = grad.shape
    new_shape = []
    
    curr_shape = 1
    for sh in shape:
        temp_shape = curr_shape * sh
        if temp_shape > max_precond_dim:
            if curr_shape > 1:
                new_shape.append(curr_shape)
                curr_shape = sh
            else:
                new_shape.append(sh)
                curr_shape = 1
        else:
            curr_shape = temp_shape
    
    if curr_shape > 1 or len(new_shape)==0:
        new_shape.append(curr_shape)
    
    new_grad = grad.reshape(new_shape)
    return new_grad               

@torch.no_grad()
def step(self, closure = None):
    """
    Performs a single optimization step.

    Arguments:
        closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
    """
    if closure is None:
        loss = None
    else:
        loss = closure()
        
    for group in self.param_groups:
        for p in group["params"]:
            if p.grad is None:
                continue
            grad = p.grad

            state = self.state[p]
            
            if "step" not in state:
                state["step"] = 0 
                
            # State initialization
            if "exp_avg" not in state:
                # Exponential moving average of gradient values
                state["exp_avg"] = torch.zeros_like(grad)
                # Exponential moving average of squared gradient values
                state["exp_avg_sq"] = torch.zeros_like(grad)
            
            if 'Q' not in state:
                self.init_preconditioner(
                    grad,
                    state,
                    precondition_frequency=group['precondition_frequency'],
                    precondition_1d=group['precondition_1d'],
                    shampoo_beta=(group['shampoo_beta'] if group['shampoo_beta'] >= 0 else group["betas"][1]),
                    max_precond_dim=group['max_precond_dim'],
                    merge_dims=group["merge_dims"],
                )
                self.update_preconditioner(grad, state,
                                           max_precond_dim=group['max_precond_dim'],
                                           merge_dims=group["merge_dims"],
                                           precondition_1d=group["precondition_1d"])
                continue # first step is skipped so that we never use the current gradients in the projection.
            
            # Projecting gradients to the eigenbases of Shampoo's preconditioner 
            # i.e. projecting to the eigenbases of matrices in state['GG']
            grad_projected = self.project(grad, state, merge_dims=group["merge_dims"], 
                                          max_precond_dim=group['max_precond_dim'])

            exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
            beta1, beta2 = group["betas"]

            state["step"] += 1

            # Decay the first and second moment running average coefficient
            # In-place operations to update the averages at the same time
            exp_avg.mul_(beta1).add_(grad_projected, alpha=(1.0 - beta1))
            exp_avg_sq.mul_(beta2).add_(grad_projected.square(), alpha=(1.0 - beta2))

            denom = exp_avg_sq.sqrt().add_(group["eps"])
            
            # Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner 
            # i.e. projecting to the eigenbases of matrices in state['GG']
            # exp_avg_projected = self.project(exp_avg, state, merge_dims=group["merge_dims"],
            #                                  max_precond_dim=group['max_precond_dim'])
            exp_avg_projected = exp_avg
            
            step_size = group["lr"]
            if group["correct_bias"]:
                bias_correction1 = 1.0 - beta1 ** (state["step"])
                bias_correction2 = 1.0 - beta2 ** (state["step"])
                step_size = step_size * (bias_correction2 ** .5) / bias_correction1

            # Projecting back the preconditioned (by Adam) exponential moving average of gradients
            # to the original space
            norm_grad = self.project_back(exp_avg_projected / denom, state, merge_dims=group["merge_dims"],
                                             max_precond_dim=group['max_precond_dim'])

            if group["normalize_grads"]:
                norm_grad = norm_grad / (1e-30+torch.mean(norm_grad**2)**0.5)
            
            p.add_(norm_grad, alpha=-step_size)
            

            # From AdamW code: Just adding the square of the weights to the loss function is *not*
            # the correct way of using L2 regularization/weight decay with Adam,
            # since that will interact with the m and v parameters in strange ways.
            #
            # Instead we want to decay the weights in a manner that doesn't interact
            # with the m/v parameters. This is equivalent to adding the square
            # of the weights to the loss with plain (non-momentum) SGD.
            # Add weight decay at the end (fixed version)
            if group["weight_decay"] > 0.0:
                p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
                
            # Update is done after the gradient step to avoid using current gradients in the projection.
            self.update_preconditioner(grad, state, 
                                           max_precond_dim=group['max_precond_dim'],
                                           merge_dims=group["merge_dims"],
                                           precondition_1d=group["precondition_1d"])
    
    return loss

def init_preconditioner(self, grad, state, precondition_frequency=10, 
                        shampoo_beta=0.95, max_precond_dim=10000, precondition_1d=False,
                        merge_dims=False):
    """
    Initializes the preconditioner matrices (L and R in the paper).
    """
    state['GG'] = [] # Will hold all the preconditioner matrices (L and R in the paper).
    if grad.dim() == 1:
        if not precondition_1d or grad.shape[0] > max_precond_dim:
            state['GG'].append([])
        else:
            state['GG'].append(torch.zeros(grad.shape[0], grad.shape[0], device=grad.device))
    else:
        if merge_dims:
            grad = self.merge_dims(grad, max_precond_dim)

        for sh in grad.shape:
            if sh > max_precond_dim:
                state['GG'].append([])
            else:
                state['GG'].append(torch.zeros(sh, sh, device=grad.device))

    state['Q'] = None # Will hold all the eigenbases of the preconditioner.
    state['precondition_frequency'] = precondition_frequency
    state['shampoo_beta'] = shampoo_beta          
    
def project(self, grad, state, merge_dims=False, max_precond_dim=10000):
    """
    Projects the gradient to the eigenbases of the preconditioner.
    """
    original_shape = grad.shape
    if merge_dims:
        if grad.dim() == 4 and self._data_format == 'channels_last':
            permuted_shape = grad.permute(0, 3, 1, 2).shape
        grad = self.merge_dims(grad, max_precond_dim)

    for mat in state['Q']:
        if len(mat) > 0:
            grad = torch.tensordot(
                    grad,
                    mat,
                    dims=[[0], [0]],
                )
        else:
            permute_order = list(range(1, len(grad.shape))) + [0]
            grad = grad.permute(permute_order)
    
    if merge_dims:
        if self._data_format == 'channels_last' and len(original_shape) == 4:
            grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1)
        else:
            grad = grad.reshape(original_shape)
    return grad
    
def update_preconditioner(self, grad, state, 
                          max_precond_dim=10000, merge_dims=False, precondition_1d=False):
    """
    Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
    """
    if state["Q"] is not None:
        state["exp_avg"] = self.project_back(state["exp_avg"], state, merge_dims=merge_dims, max_precond_dim=max_precond_dim)
    if grad.dim() == 1:
        if precondition_1d and grad.shape[0] <= max_precond_dim:
            state['GG'][0].lerp_(grad.unsqueeze(1) @ grad.unsqueeze(0), 1-state['shampoo_beta'])
    else:
        if merge_dims:
            new_grad = self.merge_dims(grad, max_precond_dim)
            for idx, sh in enumerate(new_grad.shape):
                if sh <= max_precond_dim:
                    outer_product = torch.tensordot(
                            new_grad,
                            new_grad,
                            dims=[[*chain(range(idx), range(idx + 1, len(new_grad.shape)))]] * 2,
                        )
                    state['GG'][idx].lerp_(outer_product, 1-state['shampoo_beta'])
        else:
            for idx, sh in enumerate(grad.shape):
                if sh <= max_precond_dim:
                    outer_product = torch.tensordot(
                            grad,
                            grad,
                            # Contracts across all dimensions except for k.
                            dims=[[*chain(range(idx), range(idx + 1, len(grad.shape)))]] * 2,
                        )
                    state['GG'][idx].lerp_(outer_product, 1-state['shampoo_beta'])
                 
    if state['Q'] is None:
        state['Q'] = self.get_orthogonal_matrix(state['GG'])
    if state['step'] > 0 and state['step'] % state['precondition_frequency'] == 0:
        state['Q'] = self.get_orthogonal_matrix_QR(state, max_precond_dim, merge_dims)
        # state['Q'] = self.get_fast_QR(state, max_precond_dim, merge_dims)             

    if state["step"] > 0:
        state["exp_avg"] = self.project(state["exp_avg"], state, merge_dims=merge_dims, max_precond_dim=max_precond_dim) 

def project_back(self, grad, state, merge_dims=False, max_precond_dim=10000):
    """
    Projects the gradient back to the original space.
    """
    original_shape = grad.shape
    if merge_dims:
        if self._data_format == 'channels_last' and grad.dim() == 4:
            permuted_shape = grad.permute(0, 3, 1, 2).shape
        grad = self.merge_dims(grad, max_precond_dim)
    for mat in state['Q']:
        if len(mat) > 0:
            grad = torch.tensordot(
                    grad,
                    mat,
                    dims=[[0], [1]],
                )
        else:
            permute_order = list(range(1, len(grad.shape))) + [0]
            grad = grad.permute(permute_order)
            
    if merge_dims:
        if self._data_format == 'channels_last' and len(original_shape) == 4:
            grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1)
        else:
            grad = grad.reshape(original_shape)
    return grad
    

def get_orthogonal_matrix(self, mat):
    """
    Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
    """
    matrix = []
    for m in mat:
        if len(m) == 0:
            matrix.append([])
            continue
        if m.data.dtype != torch.float:
            float_data = False
            original_type = m.data.dtype
            original_device = m.data.device
            matrix.append(m.data.float())
        else:
            float_data = True
            matrix.append(m.data)
    
    final = []
    for m in matrix:
        if len(m) == 0:
            final.append([])
            continue
        try:
            _, Q = torch.linalg.eigh(m+1e-30*torch.eye(m.shape[0], device=m.device))
        except:
            _, Q = torch.linalg.eigh(m.to(torch.float64)+1e-30*torch.eye(m.shape[0], device=m.device))
            Q = Q.to(m.dtype)
        Q = torch.flip(Q, [1])

        if not float_data:
            Q = Q.to(original_device).type(original_type)
        final.append(Q)
    return final
    

def get_orthogonal_matrix_QR(self, state, max_precond_dim=10000, merge_dims=False):
    """
    Computes the eigenbases of the preconditioner using one round of power iteration 
    followed by torch.linalg.qr decomposition.
    """
    precond_list = state['GG']
    orth_list = state['Q']

    matrix = []
    orth_matrix = []
    for m,o in zip(precond_list, orth_list):
        if len(m) == 0:
            matrix.append([])
            orth_matrix.append([])
            continue
        if m.data.dtype != torch.float:
            float_data = False
            original_type = m.data.dtype
            original_device = m.data.device
            matrix.append(m.data.float())
            orth_matrix.append(o.data.float())
        else:
            float_data = True
            matrix.append(m.data.float())
            orth_matrix.append(o.data.float())
    
    orig_shape = state['exp_avg_sq'].shape
    if self._data_format == 'channels_last' and len(orig_shape) == 4:
        permuted_shape = state['exp_avg_sq'].permute(0, 3, 1, 2).shape
    if merge_dims:
        exp_avg_sq = self.merge_dims(state['exp_avg_sq'], max_precond_dim)
    else:
        exp_avg_sq = state['exp_avg_sq']
        
    final = []
    for ind, (m,o) in enumerate(zip(matrix, orth_matrix)):
        if len(m)==0:
            final.append([])
            continue
        est_eig = torch.diag(o.T @ m @ o)
        sort_idx = torch.argsort(est_eig, descending=True)
        exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
        o = o[:,sort_idx]
        power_iter = m @ o
        Q, _ = torch.linalg.qr(power_iter)

        if not float_data:
            Q = Q.to(original_device).type(original_type)
        final.append(Q)
    
    if merge_dims:
        if self._data_format == 'channels_last' and len(orig_shape) == 4:
            exp_avg_sq = exp_avg_sq.reshape(permuted_shape).permute(0, 2, 3, 1)
        else:
            exp_avg_sq = exp_avg_sq.reshape(orig_shape)
            
    state['exp_avg_sq'] = exp_avg_sq
    return final

class MLP_torch(torch.nn.Module):

def __init__(self, in_features, out_features, device=torch.cuda.current_device()):
    super().__init__()
    self.weight = torch.nn.Parameter(torch.randn([out_features, in_features],device=device))

def forward(self, inp):
    # return torch.matmul(inp, self.weight.t())
    return torch.matmul(inp, self.weight.t())

class SampleDataSet(torch.utils.data.Dataset):

def __init__(self, samples, hidden_state):
    super().__init__()

    assert hidden_state % 64 == 0 and hidden_state >= 64

    self._data = torch.rand([samples, hidden_state])
    self._label = torch.rand([samples, hidden_state//64])

def __len__(self):
    return self._data.size(0)

def __getitem__(self, index):
    return self._data[index], self._label[index]
    

if name == "main": samples = 20 hidden_state = 1024

batch_size = 1

sample_dataset = SampleDataSet(samples, hidden_state)
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size)

model_torch = MLP_torch(hidden_state, hidden_state//64)
model_paddle = MLP_paddle(hidden_state, hidden_state//64)

for name, param in model_torch.named_parameters():
    print(f"pddle param name: {name}, mean: {param.mean().item()}, sum: {param.sum().item()}")

for name, param in model_torch.named_parameters():
    print(f"torch param name: {name}, mean: {param.mean().item()}, sum: {param.sum().item()}")

lr = 0.03
weight_decay=0
optimizer_paddle = SOAP_paddle(model_paddle.parameters(), lr, weight_decay=weight_decay)
criterion_paddle = paddle.nn.L1Loss(reduction='mean')

optimizer_torch = SOAP(model_torch.parameters(), lr,weight_decay=weight_decay)
criterion_torch = torch.nn.L1Loss(reduction='mean')

params_paddle = model_paddle.parameters()
params_torch = [param for param in model_torch.parameters()]

for param_paddle, param_torch in zip(params_paddle, params_torch):
    param_paddle.set_value(param_torch.detach().cpu().numpy())

for param_paddle, param_torch in zip(params_paddle, params_torch):
    # check param
    np.testing.assert_allclose(param_torch.detach().cpu().numpy(), param_paddle.numpy(), atol=0)

stop = 20
for iter,(inp,label) in enumerate(sample_dataloader):
    inp_numpy,label_numpy = inp.numpy(),label.numpy()
    
    out_paddle = model_paddle(paddle.to_tensor(inp_numpy))
    loss_paddle = criterion_paddle(out_paddle, paddle.to_tensor(label_numpy))
    loss_paddle.backward()

    out_torch = model_torch(torch.from_numpy(inp_numpy).to(torch.cuda.current_device()))
    loss_torch = criterion_torch(out_torch, torch.from_numpy(label_numpy).to(torch.cuda.current_device()))
    loss_torch.backward()


    optimizer_paddle.step()
    optimizer_torch.step()
    
    state_paddle = optimizer_paddle.state 
    state_torch = optimizer_torch.state 


    optimizer_paddle.clear_grad()
    optimizer_torch.zero_grad()

    print(f"[ITERATION] {(iter+1)}/{len(sample_dataloader)} loss paddle: {loss_paddle.item():.6f}, loss torch: {loss_torch.item():.6f}")

    if iter >= stop:
        break