fix: Bugfix in Linear-to-AddMM Fusion Lowering Pass by gs-olive · Pull Request #1619 · pytorch/TensorRT (original) (raw)

  %4 : Tensor = prim::If(%true)
    block0():
      %res = aten::linear(%input, %weight, %biasNone)
      -> (%res)
    block1():
      %res = aten::linear(%input, %weight, %biasNone)
      -> (%res)

  %4 : Tensor = prim::If(%true)
    block0():
      %res = aten::linear(%input, %weight, %bias)
      -> (%res)
    block1():
      %res = aten::linear(%input, %invalid_weight, %bias)
      -> (%res)

=============== TRANSLATES TO ===============

%13 : int = prim::Constantvalue=1 %14 : Tensor = aten::t(%weight) %15 : Tensor = aten::matmul(%input, %14) %16 : Tensor = trt::const(%bias) %17 : Tensor = aten::add(%16, %15, %13) %3 : bool = prim::Constantvalue=1 %4 : Tensor = aten::t(%weight) %8 : int = prim::Constantvalue=1 %9 : Tensor = aten::t(%4) %10 : Tensor = aten::matmul(%input, %9) %11 : Tensor = trt::const(%bias) %12 : Tensor = aten::add(%11, %10, %8) %5 : Tensor = prim::If(%3) block0(): -> (%17) block1(): -> (%12)

=============== LEADING TO ===============

%10 : Tensor = aten::matmul(%input, %9): last dimension of input0 = 7 and second to last dimension of input1 = 8 but must match.