MLIR: include/mlir/ExecutionEngine/RunnerUtils.h Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16 #ifndef MLIR_EXECUTIONENGINE_RUNNERUTILS_H

17 #define MLIR_EXECUTIONENGINE_RUNNERUTILS_H

18

// Symbol-export annotation applied to the runner-utils entry points.
// A definition provided before this header is included takes precedence on
// every platform (previously only the Windows branch honored it, so a
// predefinition on other platforms caused a macro-redefinition diagnostic).
#ifndef MLIR_RUNNERUTILS_EXPORT
#ifdef _WIN32
#ifdef mlir_runner_utils_EXPORTS
// Building the library itself (presumably CMake's <target>_EXPORTS define for
// the shared-library target): export the symbols from the DLL.
#define MLIR_RUNNERUTILS_EXPORT __declspec(dllexport)
#else
// Consuming the library: import the symbols from the DLL.
#define MLIR_RUNNERUTILS_EXPORT __declspec(dllimport)
#endif // mlir_runner_utils_EXPORTS
#else
// Non-Windows: force default visibility so the symbols remain exported even
// if the library is compiled with hidden-by-default visibility.
#define MLIR_RUNNERUTILS_EXPORT __attribute__((visibility("default")))
#endif // _WIN32
#endif // MLIR_RUNNERUTILS_EXPORT

33

34 #include <assert.h>

35 #include

36 #include

37 #include

38 #include

39

42

43 template <typename T, typename StreamType>

45

46

47 os << "base@ = " << std::hex << std::showbase

48 << reinterpret_caststd::intptr\_t\(v.data) << std::dec << std::noshowbase

49 << " rank = " << v.rank << " offset = " << v.offset;

50 auto print = [&](const int64_t *ptr) {

51 if (v.rank == 0)

52 return;

53 os << ptr[0];

54 for (int64_t i = 1; i < v.rank; ++i)

55 os << ", " << ptr[i];

56 };

57 os << " sizes = [";

59 os << "] strides = [";

61 os << "]";

62 }

63

64 template <typename StreamType, typename T, int N>

66 static_assert(N >= 0, "Expected N > 0");

67 os << "MemRef ";

69 }

70

71 template <typename StreamType, typename T>

73 os << "Unranked MemRef ";

75 }

76

77

78

79

84

85 template <typename T, int M, int... Dims>

87

88 template <int... Dims>

90 static constexpr int value = 1;

91 };

92

93 template <int N, int... Dims>

96 };

97

/// Writes `count` space characters to `os`; writes nothing when `count` is
/// zero or negative. Each space is inserted with the formatted `operator<<`,
/// so the stream's formatting state is honored per character.
static inline void printSpace(std::ostream &os, int count) {
  // Counting down and stopping at zero never decrements past the start value,
  // so no signed overflow is possible even for extreme negative counts.
  for (int remaining = count; remaining > 0; --remaining)
    os << ' ';
}

103

104 template <typename T, int M, int... Dims>

107 };

108

109 template <typename T, int M, int... Dims>

112 static_assert(M > 0, "0 dimensioned tensor");

114 "Incorrect vector size!");

115

116 os << "(" << val[0];

117 if (M > 1)

118 os << ", ";

119 if (sizeof...(Dims) > 1)

120 os << "\n";

121

122 for (unsigned i = 1; i + 1 < M; ++i) {

124 os << val[i] << ", ";

125 if (sizeof...(Dims) > 1)

126 os << "\n";

127 }

128

129 if (M > 1) {

131 os << val[M - 1];

132 }

133 os << ")";

134 }

135

136 template <typename T, int M, int... Dims>

139 return os;

140 }

141

142 template

144 static void print(std::ostream &os, T *base, int64_t dim, int64_t rank,

145 int64_t offset, const int64_t *sizes,

146 const int64_t *strides);

147 static void printFirst(std::ostream &os, T *base, int64_t dim, int64_t rank,

148 int64_t offset, const int64_t *sizes,

149 const int64_t *strides);

150 static void printLast(std::ostream &os, T *base, int64_t dim, int64_t rank,

151 int64_t offset, const int64_t *sizes,

152 const int64_t *strides);

153 };

154

155 template

157 int64_t rank, int64_t offset,

158 const int64_t *sizes,

159 const int64_t *strides) {

160 os << "[";

161 print(os, base, dim - 1, rank, offset, sizes + 1, strides + 1);

162

163 if (sizes[0] <= 1) {

164 os << "]";

165 return;

166 }

167 os << ", ";

168 if (dim > 1)

169 os << "\n";

170 }

171

172 template

174 int64_t rank, int64_t offset,

175 const int64_t *sizes, const int64_t *strides) {

176 if (dim == 0) {

177 os << base[offset];

178 return;

179 }

180 printFirst(os, base, dim, rank, offset, sizes, strides);

181 for (unsigned i = 1; i + 1 < sizes[0]; ++i) {

183 print(os, base, dim - 1, rank, offset + i * strides[0], sizes + 1,

184 strides + 1);

185 os << ", ";

186 if (dim > 1)

187 os << "\n";

188 }

189 if (sizes[0] <= 1)

190 return;

191 printLast(os, base, dim, rank, offset, sizes, strides);

192 }

193

194 template

196 int64_t rank, int64_t offset,

197 const int64_t *sizes,

198 const int64_t *strides) {

200 print(os, base, dim - 1, rank, offset + (sizes[0] - 1) * (*strides),

201 sizes + 1, strides + 1);

202 os << "]";

203 }

204

205 template <typename T, int N>

207 std::cout << "Memref ";

209 }

210

211 template

213 std::cout << "Unranked Memref ";

215 }

216

217 template

220 std::cout << " data = \n";

221 if (m.rank == 0)

222 std::cout << "[";

224 m.sizes, m.strides);

225 if (m.rank == 0)

226 std::cout << "]";

227 std::cout << '\n' << std::flush;

228 }

229

230 template <typename T, int N>

232 std::cout << "Memref ";

234 }

235

236 template

238 std::cout << "Unranked Memref ";

240 }

241

242

243

244 template

246

248

249

251

252

253 static bool verifyElem(T actual, T expected);

254

255

256 static int64_t verify(std::ostream &os, T *actualBasePtr, T *expectedBasePtr,

257 int64_t dim, int64_t offset, const int64_t *sizes,

258 const int64_t *strides, int64_t &printCounter);

259 };

260

261 template

263 T epsilon) {

264

265 if (!std::isfinite(actual) || !std::isfinite(expected))

266 return false;

267

268 T delta = std::abs(actual - expected);

269 return (delta <= epsilon * std::abs(expected));

270 }

271

272 template

274 return actual == expected;

275 }

276

277 template <>

279 double expected) {

280 return verifyRelErrorSmallerThan(actual, expected, 1e-12);

281 }

282

283 template <>

285 float expected) {

286 return verifyRelErrorSmallerThan(actual, expected, 1e-6f);

287 }

288

289 template

291 T *expectedBasePtr, int64_t dim,

292 int64_t offset, const int64_t *sizes,

293 const int64_t *strides,

294 int64_t &printCounter) {

295 int64_t errors = 0;

296

297 if (dim == 0) {

298 if (!verifyElem(actualBasePtr[offset], expectedBasePtr[offset])) {

299 if (printCounter < printLimit) {

300 os << actualBasePtr[offset] << " != " << expectedBasePtr[offset]

301 << " offset = " << offset << "\n";

302 printCounter++;

303 }

304 errors++;

305 }

306 } else {

307

308 for (int64_t i = 0; i < sizes[0]; ++i) {

309 errors +=

310 verify(os, actualBasePtr, expectedBasePtr, dim - 1,

311 offset + i * strides[0], sizes + 1, strides + 1, printCounter);

312 }

313 }

314 return errors;

315 }

316

317

318

319 template

322

323 for (int64_t i = 0; i < actual.rank; ++i) {

325 actual.sizes[i] != expected.sizes[i] ||

329 return -1;

330 }

331 }

332

333 int64_t printCounter = 0;

336 actual.strides, printCounter);

337 }

338

339

340

341 template

346 }

347

348 }

349

350

351

352

369

392

394

402

413

428

431

457

459 void *actualPtr,

460 void *expectedPtr);

462 void *actualPtr,

463 void *expectedPtr);

465 void *actualPtr,

466 void *expectedPtr);

468 void *actualPtr,

469 void *expectedPtr);

471 void *actualPtr,

472 void *expectedPtr);

474 void *actualPtr,

475 void *expectedPtr);

476

477 #endif

MLIR_RUNNERUTILS_EXPORT int64_t _mlir_ciface_verifyMemRefF16(UnrankedMemRefType< f16 > *actual, UnrankedMemRefType< f16 > *expected)

MLIR_RUNNERUTILS_EXPORT int64_t verifyMemRefI32(int64_t rank, void *actualPtr, void *expectedPtr)

MLIR_RUNNERUTILS_EXPORT int64_t _mlir_ciface_verifyMemRefI32(UnrankedMemRefType< int32_t > *actual, UnrankedMemRefType< int32_t > *expected)

MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemrefShapeF32(UnrankedMemRefType< float > *m)

MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemrefF16(UnrankedMemRefType< f16 > *m)

#define MLIR_RUNNERUTILS_EXPORT

MLIR_RUNNERUTILS_EXPORT int64_t _mlir_ciface_verifyMemRefF32(UnrankedMemRefType< float > *actual, UnrankedMemRefType< float > *expected)

MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemrefC32(UnrankedMemRefType< impl::complex32 > *m)

MLIR_RUNNERUTILS_EXPORT void printMemrefF32(int64_t rank, void *ptr)

MLIR_RUNNERUTILS_EXPORT void printMemrefI64(int64_t rank, void *ptr)

MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemrefShapeF64(UnrankedMemRefType< double > *m)

MLIR_RUNNERUTILS_EXPORT int64_t verifyMemRefC64(int64_t rank, void *actualPtr, void *expectedPtr)

MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemrefShapeI8(UnrankedMemRefType< int8_t > *m)

MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemref1dC32(StridedMemRefType< impl::complex32, 1 > *m)

MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemref1dI32(StridedMemRefType< int32_t, 1 > *m)

void printUnrankedMemRefMetaData(StreamType &os, UnrankedMemRefType< T > &v)

MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemref1dI8(StridedMemRefType< int8_t, 1 > *m)

MLIR_RUNNERUTILS_EXPORT void printMemrefInd(int64_t rank, void *ptr)

MLIR_RUNNERUTILS_EXPORT void printMemrefC32(int64_t rank, void *ptr)

MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemrefC64(UnrankedMemRefType< impl::complex64 > *m)

MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemrefVector4x4xf32(StridedMemRefType< Vector2D< 4, 4, float >, 2 > *m)

MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemref1dInd(StridedMemRefType< impl::index_type, 1 > *m)

MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemrefI64(UnrankedMemRefType< int64_t > *m)

MLIR_RUNNERUTILS_EXPORT void printMemrefI32(int64_t rank, void *ptr)

MLIR_RUNNERUTILS_EXPORT int64_t _mlir_ciface_verifyMemRefC64(UnrankedMemRefType< impl::complex64 > *actual, UnrankedMemRefType< impl::complex64 > *expected)

MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemref0dF32(StridedMemRefType< float, 0 > *m)

MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemrefI16(UnrankedMemRefType< int16_t > *m)

MLIR_RUNNERUTILS_EXPORT int64_t verifyMemRefC32(int64_t rank, void *actualPtr, void *expectedPtr)

MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemrefShapeInd(UnrankedMemRefType< impl::index_type > *m)

MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemrefShapeI32(UnrankedMemRefType< int32_t > *m)

MLIR_RUNNERUTILS_EXPORT int64_t _mlir_ciface_verifyMemRefInd(UnrankedMemRefType< impl::index_type > *actual, UnrankedMemRefType< impl::index_type > *expected)

MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemrefShapeI64(UnrankedMemRefType< int64_t > *m)

MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemref1dF64(StridedMemRefType< double, 1 > *m)

MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemrefShapeC32(UnrankedMemRefType< impl::complex32 > *m)

MLIR_RUNNERUTILS_EXPORT int64_t _mlir_ciface_verifyMemRefBF16(UnrankedMemRefType< bf16 > *actual, UnrankedMemRefType< bf16 > *expected)

MLIR_RUNNERUTILS_EXPORT int64_t verifyMemRefF64(int64_t rank, void *actualPtr, void *expectedPtr)

MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemref1dF32(StridedMemRefType< float, 1 > *m)

MLIR_RUNNERUTILS_EXPORT int64_t _mlir_ciface_verifyMemRefC32(UnrankedMemRefType< impl::complex32 > *actual, UnrankedMemRefType< impl::complex32 > *expected)

MLIR_RUNNERUTILS_EXPORT int64_t verifyMemRefF32(int64_t rank, void *actualPtr, void *expectedPtr)

MLIR_RUNNERUTILS_EXPORT int64_t _mlir_ciface_verifyMemRefF64(UnrankedMemRefType< double > *actual, UnrankedMemRefType< double > *expected)

MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemrefBF16(UnrankedMemRefType< bf16 > *m)

MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemrefF32(UnrankedMemRefType< float > *m)

MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemref2dF32(StridedMemRefType< float, 2 > *m)

MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemrefI32(UnrankedMemRefType< int32_t > *m)

MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemref1dC64(StridedMemRefType< impl::complex64, 1 > *m)

void printMemRefMetaData(StreamType &os, const DynamicMemRefType< T > &v)

MLIR_RUNNERUTILS_EXPORT void printMemrefC64(int64_t rank, void *ptr)

MLIR_RUNNERUTILS_EXPORT void printMemrefF64(int64_t rank, void *ptr)

MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemrefShapeC64(UnrankedMemRefType< impl::complex64 > *m)

MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemrefF64(UnrankedMemRefType< double > *m)

MLIR_RUNNERUTILS_EXPORT int64_t _mlir_ciface_verifyMemRefI64(UnrankedMemRefType< int64_t > *actual, UnrankedMemRefType< int64_t > *expected)

MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemrefInd(UnrankedMemRefType< impl::index_type > *m)

MLIR_RUNNERUTILS_EXPORT int64_t _mlir_ciface_verifyMemRefI16(UnrankedMemRefType< int16_t > *actual, UnrankedMemRefType< int16_t > *expected)

MLIR_RUNNERUTILS_EXPORT int64_t verifyMemRefInd(int64_t rank, void *actualPtr, void *expectedPtr)

MLIR_RUNNERUTILS_EXPORT int64_t _mlir_ciface_nanoTime()

MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemref3dF32(StridedMemRefType< float, 3 > *m)

MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemref4dF32(StridedMemRefType< float, 4 > *m)

MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemref1dI64(StridedMemRefType< int64_t, 1 > *m)

MLIR_RUNNERUTILS_EXPORT int64_t _mlir_ciface_verifyMemRefI8(UnrankedMemRefType< int8_t > *actual, UnrankedMemRefType< int8_t > *expected)

MLIR_RUNNERUTILS_EXPORT void _mlir_ciface_printMemrefI8(UnrankedMemRefType< int8_t > *m)

static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)

static void printSpace(std::ostream &os, int count)

std::complex< double > complex64

std::ostream & operator<<(std::ostream &os, const Vector< T, M, Dims... > &v)

void printMemRefShape(StridedMemRefType< T, N > &m)

int64_t verifyMemRef(const DynamicMemRefType< T > &actual, const DynamicMemRefType< T > &expected)

Verify the equivalence of two dynamic memrefs and return the number of errors, or -1 if the shapes of the two memrefs do not match.

std::complex< float > complex32

void printMemRef(const DynamicMemRefType< T > &m)

Fraction abs(const Fraction &f)

LogicalResult verify(Operation *op, bool verifyRecursively=true)

Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...

StridedMemRef descriptor type with static rank.

static void print(std::ostream &os, T *base, int64_t dim, int64_t rank, int64_t offset, const int64_t *sizes, const int64_t *strides)

static void printLast(std::ostream &os, T *base, int64_t dim, int64_t rank, int64_t offset, const int64_t *sizes, const int64_t *strides)

static void printFirst(std::ostream &os, T *base, int64_t dim, int64_t rank, int64_t offset, const int64_t *sizes, const int64_t *strides)

Verify that the results of two computations are equivalent up to a small numerical error and return the number of errors.

static constexpr int printLimit

Maximum number of errors printed by the verifier.

static bool verifyRelErrorSmallerThan(T actual, T expected, T epsilon)

Verify that the relative difference of the values is at most epsilon, i.e. |actual - expected| <= epsilon * |expected|; non-finite values never pass.

static int64_t verify(std::ostream &os, T *actualBasePtr, T *expectedBasePtr, int64_t dim, int64_t offset, const int64_t *sizes, const int64_t *strides, int64_t &printCounter)

Verify the data element-by-element and return the number of errors.

static bool verifyElem(T actual, T expected)

Verify that the values are equal (integers) or close (floating-point).

static constexpr int value

static void print(std::ostream &os, const Vector< T, M, Dims... > &val)