MLIR: lib/Dialect/X86Vector/Transforms/AVXTranspose.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

22 #include "llvm/Support/Format.h"

23 #include "llvm/Support/FormatVariadic.h"

24

25 using namespace mlir;

31

34 auto asmDialectAttr =

36 const auto *asmTp = "vblendps $0, $1, $2, {0}";

37 const auto *asmCstr =

38 "=x,x,x";

40 auto asmStr = llvm::formatv(asmTp, llvm::format_hex(mask, 2)).str();

41 auto asmOp = b.create<LLVM::InlineAsmOp>(

42 v1.getType(), asmVals, asmStr,

43 asmCstr, false,

45 asmDialectAttr,

46 ArrayAttr());

47 return asmOp.getResult(0);

48 }

49

52 return b.create<vector::ShuffleOp>(

54 }

55

58 return b.create<vector::ShuffleOp>(

60 }

61

62

63

64

67 uint8_t mask) {

68 uint8_t b01, b23, b45, b67;

69 MaskHelper::extractShuffle(mask, b01, b23, b45, b67);

71 b01, b23, b45 + 8, b67 + 8, b01 + 4, b23 + 4, b45 + 8 + 4, b67 + 8 + 4};

72 return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);

73 }

74

75

76

77

78

79

80

84 auto appendToMask = [&](uint8_t control) {

85 if (control == 0)

87 else if (control == 1)

89 else if (control == 2)

91 else if (control == 3)

92 llvm::append_range(shuffleMask, ArrayRef<int64_t>{12, 13, 14, 15});

93 else

94 llvm_unreachable("control > 3 : overflow");

95 };

96 uint8_t b03, b47;

97 MaskHelper::extractPermute(mask, b03, b47);

98 appendToMask(b03);

99 appendToMask(b47);

100 return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);

101 }

102

103

106 uint8_t mask) {

108 for (int i = 0; i < 8; ++i) {

109 bool isSet = mask & (1 << i);

110 shuffleMask.push_back(!isSet ? i : i + 8);

111 }

112 return b.create<vector::ShuffleOp>(v1, v2, shuffleMask);

113 }

114

115

118 #ifndef NDEBUG

120 assert(vs.size() == 4 && "expects 4 vectors");

121 assert(llvm::all_of(ValueRange{vs}.getTypes(),

122 [&](Type t) { return t == vt; }) &&

123 "expects all types to be vector<8xf32>");

124 #endif

125

130 Value s0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 1, 0>());

131 Value s1 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<3, 2, 3, 2>());

132 Value s2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 1, 0>());

133 Value s3 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<3, 2, 3, 2>());

138 }

139

140

144 (void)vt;

145 assert(vs.size() == 8 && "expects 8 vectors");

146 assert(llvm::all_of(ValueRange{vs}.getTypes(),

147 [&](Type t) { return t == vt; }) &&

148 "expects all types to be vector<8xf32>");

149

158

160 Value sh0 = mm256ShufflePs(ib, t0, t2, MaskHelper::shuffle<1, 0, 3, 2>());

161 Value sh2 = mm256ShufflePs(ib, t1, t3, MaskHelper::shuffle<1, 0, 3, 2>());

162 Value sh4 = mm256ShufflePs(ib, t4, t6, MaskHelper::shuffle<1, 0, 3, 2>());

163 Value sh6 = mm256ShufflePs(ib, t5, t7, MaskHelper::shuffle<1, 0, 3, 2>());

164

166 mm256BlendPsAsm(ib, t0, sh0, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());

168 mm256BlendPsAsm(ib, t2, sh0, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());

170 mm256BlendPsAsm(ib, t1, sh2, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());

172 mm256BlendPsAsm(ib, t3, sh2, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());

174 mm256BlendPsAsm(ib, t4, sh4, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());

176 mm256BlendPsAsm(ib, t6, sh4, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());

178 mm256BlendPsAsm(ib, t5, sh6, MaskHelper::blend<0, 0, 1, 1, 0, 0, 1, 1>());

180 mm256BlendPsAsm(ib, t7, sh6, MaskHelper::blend<1, 1, 0, 0, 1, 1, 0, 0>());

181

190 }

191

192

193

194

195

196

197

198

199

200

201

202

203

204

205

206

207

208

210 public:

212

214 int benefit)

216 loweringOptions(loweringOptions) {}

217

220 auto loc = op.getLoc();

221

222

223

224 VectorType srcType = op.getSourceVectorType();

225 if (!srcType.getElementType().isF32())

226 return rewriter.notifyMatchFailure(op, "Unsupported vector element type");

227

229 if (failed(srcGtOneDims))

231 op, "expected transposition on a 2D slice");

232

233

234

235 int64_t m = srcType.getDimSize(std::get<0>(srcGtOneDims.value()));

236 int64_t n = srcType.getDimSize(std::get<1>(srcGtOneDims.value()));

237

238 auto applyRewrite = [&]() {

241

242

243

244 auto flattenedType =

246 auto reshInputType = VectorType::get({m, n}, srcType.getElementType());

247 auto reshInput =

248 ib.create<vector::ShapeCastOp>(flattenedType, op.getVector());

249 reshInput = ib.create<vector::ShapeCastOp>(reshInputType, reshInput);

250

251

252

253 for (int64_t i = 0; i < m; ++i)

254 vs.push_back(ib.create<vector::ExtractOp>(reshInput, i));

255

256

257 if (m == 4)

259 if (m == 8)

261

262

263

264 Value res = ib.create<arith::ConstantOp>(reshInputType,

266 for (int64_t i = 0; i < m; ++i)

267 res = ib.create<vector::InsertOp>(vs[i], res, i);

268

269

270

271

272 res = ib.create<vector::ShapeCastOp>(flattenedType, res);

273 res = ib.create<vector::ShapeCastOp>(op.getResultVectorType(), res);

275 return success();

276 };

277

278 if (loweringOptions.transposeOptions.lower4x8xf32_ && m == 4 && n == 8)

279 return applyRewrite();

280 if (loweringOptions.transposeOptions.lower8x8xf32_ && m == 8 && n == 8)

281 return applyRewrite();

282 return failure();

283 }

284

285 private:

287 };

288

292 }

static llvm::ManagedStatic< PassManagerOptions > options

static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)

Walks the given type hierarchy with the given indices, potentially down to component granularity,...

Rewrite AVX2-specific vector.transpose, for the supported cases and depending on the TransposeLoweringOptions.

TransposeOpLowering(LoweringOptions loweringOptions, MLIRContext *context, int benefit)

LogicalResult matchAndRewrite(vector::TransposeOp op, PatternRewriter &rewriter) const override

TypedAttr getZeroAttr(Type type)

MLIRContext * getContext() const

ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...

OpTy create(Args &&...args)

Create an operation of specific op type at the current insertion point and location.

MLIRContext is the top-level object for a collection of MLIR operations.

A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...

std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)

Used to notify the listener that the IR failed to be rewritten because of a match failure,...

virtual void replaceOp(Operation *op, ValueRange newValues)

Replace the results of the given (original) operation with the specified list of values (replacements...

Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...

This class provides an abstraction over the different types of ranges over Values.

This class represents an instance of an SSA value in the MLIR system, representing a computable value...

Type getType() const

Return the type of this value.

FailureOr< std::pair< int, int > > isTranspose2DSlice(vector::TransposeOp op)

Returns two dims that are greater than one if the transposition is applied on a 2D slice.

Value mm256BlendPsAsm(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)

Methods in the inline_asm namespace emit calls to LLVM::InlineAsmOp.

Value mm256UnpackHiPs(ImplicitLocOpBuilder &b, Value v1, Value v2)

Lower to vector.shuffle v1, v2, [2, 10, 3, 11, 6, 14, 7, 15].

Value mm256BlendPs(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)

If bit i of mask is zero, take f32@i from v1 else take it from v2.

Value mm256Permute2f128Ps(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)

Value mm256ShufflePs(ImplicitLocOpBuilder &b, Value v1, Value v2, uint8_t mask)

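Result layout is a a b b | a a b b. Takes an 8-bit mask, 2 bits for each position, indexing a[0, 4) and b[0, 4) within each 128-bit lane (0:127 | 128:255): [a[b01], a[b23], b[b45], b[b67] | a[b01 + 4], a[b23 + 4], b[b45 + 4], b[b67 + 4]].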

Value mm256UnpackLoPs(ImplicitLocOpBuilder &b, Value v1, Value v2)

Methods in the intrin namespace emulate clang's impl. of X86 intrinsics.

void transpose8x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef< Value > vs)

8x8xf32-specific AVX2 transpose lowering.

void populateSpecializedTransposeLoweringPatterns(RewritePatternSet &patterns, LoweringOptions options=LoweringOptions(), int benefit=10)

Insert specialized transpose lowering patterns.

void transpose4x8xf32(ImplicitLocOpBuilder &ib, MutableArrayRef< Value > vs)

Generic lowerings may either use intrin or inline_asm depending on needs.

Include the generated interface declarations.

const FrozenRewritePatternSet & patterns

auto get(MLIRContext *context, Ts &&...params)

Helper method that injects context only if needed, this helps unify some of the attribute constructio...

OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...

Options for controlling specialized AVX2 lowerings.