MLIR: lib/ExecutionEngine/LevelZeroRuntimeWrappers.cpp Source File (original) (raw)

1

2

3

4

5

6

7

8

9

10

11

12

13#include "llvm/ADT/Twine.h"

14

15#include "level_zero/ze_api.h"

16#include

17#include

18#include

19#include

20#include

21#include

22#include <unordered_set>

23#include

24

25namespace {

/// Invokes `func` and returns whatever it returns.
///
/// Any exception escaping `func` is reported on stderr and the process is
/// aborted, so no exception ever propagates to the caller.
///
/// \param func Callable taking no arguments.
/// \return The result of `func()` (may be `void`).
template <typename F>
auto catchAll(F &&func) {
  try {
    return func();
  } catch (const std::exception &e) {
    std::cerr << "An exception was thrown: " << e.what() << std::endl;
    std::abort();
  } catch (...) {
    std::cerr << "An unknown exception was thrown." << std::endl;
    std::abort();
  }
}

38

/// Evaluates the Level Zero call `call` exactly once. If it does not return
/// ZE_RESULT_SUCCESS, prints the result code together with the driver's last
/// error description (when one could be retrieved) to stderr and aborts.
///
/// `errorString` starts out as nullptr and is only streamed when the query
/// actually filled it in; streaming an unset or null `const char *` would be
/// undefined behaviour.
#define L0_SAFE_CALL(call)                                                     \
  {                                                                            \
    ze_result_t status = (call);                                               \
    if (status != ZE_RESULT_SUCCESS) {                                         \
      const char *errorString = nullptr;                                       \
      zeDriverGetLastErrorDescription(nullptr, &errorString);                  \
      std::cerr << "L0 error " << status << ": "                               \
                << (errorString ? errorString                                  \
                                : "<no error description available>")         \
                << std::endl;                                                  \
      std::abort();                                                            \
    }                                                                          \
  }

49}

50

51

52

53

54

55

56

57

58static ze_driver_handle_t getDriver(uint32_t idx = 0) {

59 ze_init_driver_type_desc_t driver_type = {};

60 driver_type.stype = ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC;

61 driver_type.flags = ZE_INIT_DRIVER_TYPE_FLAG_GPU;

62 driver_type.pNext = nullptr;

63 uint32_t driverCount{0};

64 thread_local static std::vector<ze_driver_handle_t> drivers;

65 thread_local static bool isDriverInitialised{false};

66 if (isDriverInitialised && idx < drivers.size())

67 return drivers[idx];

68 L0_SAFE_CALL(zeInitDrivers(&driverCount, nullptr, &driver_type));

69 if (!driverCount)

70 throw std::runtime_error("No L0 drivers found.");

71 drivers.resize(driverCount);

72 L0_SAFE_CALL(zeInitDrivers(&driverCount, drivers.data(), &driver_type));

73 if (idx >= driverCount)

74 throw std::runtime_error((llvm::Twine("Requested driver idx out-of-bound, "

75 "number of availabe drivers: ") +

76 std::to_string(driverCount))

77 .str());

78 isDriverInitialised = true;

79 return drivers[idx];

80}

81

82static ze_device_handle_t getDevice(const uint32_t driverIdx = 0,

83 const int32_t devIdx = 0) {

84 thread_local static ze_device_handle_t l0Device;

85 thread_local int32_t currDevIdx{-1};

86 thread_local uint32_t currDriverIdx{0};

87 if (currDriverIdx == driverIdx && currDevIdx == devIdx)

88 return l0Device;

89 auto driver = getDriver(driverIdx);

90 uint32_t deviceCount{0};

91 L0_SAFE_CALL(zeDeviceGet(driver, &deviceCount, nullptr));

92 if (!deviceCount)

93 throw std::runtime_error("getDevice failed: did not find L0 device.");

94 if (static_cast<int>(deviceCount) < devIdx + 1)

95 throw std::runtime_error("getDevice failed: devIdx out-of-bounds.");

96 std::vector<ze_device_handle_t> devices(deviceCount);

97 L0_SAFE_CALL(zeDeviceGet(driver, &deviceCount, devices.data()));

98 l0Device = devices[devIdx];

99 currDriverIdx = driverIdx;

100 currDevIdx = devIdx;

101 return l0Device;

102}

103

104

105static ze_context_handle_t getContext(ze_driver_handle_t driver) {

106 thread_local static ze_context_handle_t context;

107 thread_local static bool isContextInitialised{false};

108 if (isContextInitialised)

109 return context;

110 ze_context_desc_t ctxtDesc = {ZE_STRUCTURE_TYPE_CONTEXT_DESC, nullptr, 0};

111 L0_SAFE_CALL(zeContextCreate(driver, &ctxtDesc, &context));

112 isContextInitialised = true;

113 return context;

114}

115

116

117

118

119

126

128 void operator()(ze_command_list_handle_t cmdList) const {

129 if (cmdList)

131 }

132};

134 std::unique_ptr<std::remove_pointer<ze_context_handle_t>::type,

137 std::unique_ptr<std::remove_pointer<ze_command_list_handle_t>::type,

140 ze_driver_handle_t driver{nullptr};

141 ze_device_handle_t device{nullptr};

143

144

146

147

150

154

157

158

159 uint32_t computeEngineOrdinal = -1u, copyEngineOrdinal = -1u;

160 ze_device_properties_t deviceProperties{};

162 uint32_t queueGroupCount = 0;

163 L0_SAFE_CALL(zeDeviceGetCommandQueueGroupProperties(

164 device, &queueGroupCount, nullptr));

165 std::vector<ze_command_queue_group_properties_t> queueGroupProperties(

166 queueGroupCount);

167 L0_SAFE_CALL(zeDeviceGetCommandQueueGroupProperties(

168 device, &queueGroupCount, queueGroupProperties.data()));

169

170 for (uint32_t queueGroupIdx = 0; queueGroupIdx < queueGroupCount;

171 ++queueGroupIdx) {

172 const auto &group = queueGroupProperties[queueGroupIdx];

173 if (group.flags & ZE_COMMAND_QUEUE_GROUP_PROPERTY_FLAG_COMPUTE)

174 computeEngineOrdinal = queueGroupIdx;

175 else if (group.flags & ZE_COMMAND_QUEUE_GROUP_PROPERTY_FLAG_COPY) {

176 copyEngineOrdinal = queueGroupIdx;

178 }

179 if (copyEngineOrdinal != -1u && computeEngineOrdinal != -1u)

180 break;

181 }

182

183

184 if (copyEngineOrdinal == -1u)

185 copyEngineOrdinal = computeEngineOrdinal;

186

187 assert(copyEngineOrdinal != -1u && computeEngineOrdinal != -1u &&

188 "Expected two engines to be available.");

189

190

191 ze_command_queue_desc_t cmdQueueDesc{

192 ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC,

193 nullptr,

194 copyEngineOrdinal,

195 0,

196 0,

197 ZE_COMMAND_QUEUE_MODE_ASYNCHRONOUS,

198 ZE_COMMAND_QUEUE_PRIORITY_NORMAL};

199

200 ze_command_list_handle_t rawCmdListCopy = nullptr;

202 &cmdQueueDesc, &rawCmdListCopy));

204

205

206 cmdQueueDesc.ordinal = computeEngineOrdinal;

207 ze_command_list_handle_t rawCmdListCompute = nullptr;

209 context.get(), device, &cmdQueueDesc, &rawCmdListCompute));

211 }

214

218};

219

226

228 void operator()(ze_event_pool_handle_t pool) const {

229 if (pool)

231 }

232};

233

235 std::unique_ptr<std::remove_pointer<ze_event_handle_t>::type,

238 std::unique_ptr<std::remove_pointer<ze_event_pool_handle_t>::type,

240

241

242

243

246

249 std::unordered_map<ze_event_handle_t, UniqueZeEvent> takenEvents;

250

251

252

253

258

262

265

266

269

271 assert(takenEvents.empty() && "Some events were not released");

272 }

273

275 ze_event_pool_desc_t eventPoolDesc = {};

276 eventPoolDesc.flags = ZE_EVENT_POOL_FLAG_HOST_VISIBLE;

277 eventPoolDesc.count = numEvents;

278

279 ze_event_pool_handle_t rawPool = nullptr;

280 L0_SAFE_CALL(zeEventPoolCreate(rtCtx->context.get(), &eventPoolDesc, 1,

281 &rtCtx->device, &rawPool));

282

285 }

286

288 ze_event_handle_t rawEvent = nullptr;

289

291

294 rawEvent = uniqueEvent.get();

295 takenEvents[rawEvent] = std::move(uniqueEvent);

296 } else {

298 throw std::runtime_error("DynamicEventPool: reached max events limit");

299 }

302

303 ze_event_desc_t eventDesc = {

304 ZE_STRUCTURE_TYPE_EVENT_DESC, nullptr,

306 ZE_EVENT_SCOPE_FLAG_DEVICE, ZE_EVENT_SCOPE_FLAG_HOST};

307

308 ze_event_handle_t newEvent = nullptr;

310 zeEventCreate(eventPools.back().get(), &eventDesc, &newEvent));

311

313 rawEvent = newEvent;

315 }

316

317 return rawEvent;

318 }

319

323 "Attempting to release unknown or already released event");

324

328 }

329};

330

333 return rtContext;

334}

335

338 return dynEventPool;

339}

340

342

345

348

353

354 void sync(ze_event_handle_t explicitEvent = nullptr) {

355 ze_event_handle_t syncEvent{nullptr};

356 if (!explicitEvent) {

358 syncEvent = lastImplicitEventPtr ? *lastImplicitEventPtr : nullptr;

359 } else {

360 syncEvent = explicitEvent;

361 }

362 if (syncEvent)

364 syncEvent, std::numeric_limits<uint64_t>::max()));

365

366

370 }

371

372 template

374 ze_event_handle_t newImplicitEvent = dynEventPool.takeEvent();

376 const uint32_t numWaitEvents = lastImplicitEventPtr ? 1 : 0;

377 std::forward(op)(newImplicitEvent, numWaitEvents,

378 lastImplicitEventPtr);

380 }

381};

382

383static ze_module_handle_t loadModule(const void *data, size_t dataSize) {

384 assert(data);

385 ze_module_handle_t zeModule;

386 ze_module_desc_t desc = {ZE_STRUCTURE_TYPE_MODULE_DESC,

387 nullptr,

388 ZE_MODULE_FORMAT_IL_SPIRV,

389 dataSize,

390 (const uint8_t *)data,

391 nullptr,

392 nullptr};

393 ze_module_build_log_handle_t buildLogHandle;

396 &zeModule, &buildLogHandle);

397 if (result != ZE_RESULT_SUCCESS) {

398 std::cerr << "Error creating module, error code: " << result << std::endl;

399 size_t logSize = 0;

400 L0_SAFE_CALL(zeModuleBuildLogGetString(buildLogHandle, &logSize, nullptr));

401 std::string buildLog(" ", logSize);

403 zeModuleBuildLogGetString(buildLogHandle, &logSize, buildLog.data()));

404 std::cerr << "Build log:\n" << buildLog << std::endl;

405 std::abort();

406 }

407 return zeModule;

408}

409

410

411

412

413

417

419 if (stream)

420 stream->sync();

421}

422

424

426 ze_event_handle_t event) {

427 assert(stream && "Invalid stream");

428 assert(event && "Invalid event");

429 stream->sync(event);

430}

431

435

439

442 zeEventHostSynchronize(event, std::numeric_limits<uint64_t>::max()));

444}

445

453

455 bool isShared) {

456 return catchAll([&]() {

457 void *memPtr = nullptr;

458 constexpr size_t alignment{64};

459 ze_device_mem_alloc_desc_t deviceDesc = {};

460 deviceDesc.stype = ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC;

461 if (isShared) {

462 ze_host_mem_alloc_desc_t hostDesc = {};

463 hostDesc.stype = ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC;

465 &hostDesc, size, alignment,

467 } else {

470 &memPtr));

471 }

472 if (!memPtr)

473 throw std::runtime_error("mem allocation failed!");

474 return memPtr;

475 });

476}

477

479 stream->sync();

480 if (ptr)

482}

483

484extern "C" void mgpuMemcpy(void *dst, void *src, size_t sizeBytes,

486 stream->enqueueOp([&](ze_event_handle_t newEvent, uint32_t numWaitEvents,

487 ze_event_handle_t *waitEvents) {

490 numWaitEvents, waitEvents));

491 });

492}

493

494template <typename PATTERN_TYPE>

495static void mgpuMemset(void *dst, PATTERN_TYPE value, size_t count,

498 auto listType =

502 stream->enqueueOp([&](ze_event_handle_t newEvent, uint32_t numWaitEvents,

503 ze_event_handle_t *waitEvents) {

505 listType, dst, &value, sizeof(PATTERN_TYPE),

506 count * sizeof(PATTERN_TYPE), newEvent, numWaitEvents, waitEvents));

507 });

508}

509extern "C" void mgpuMemset32(void *dst, unsigned int value, size_t count,

512}

513

514extern "C" void mgpuMemset16(void *dst, unsigned short value, size_t count,

517}

518

520 size_t gpuBlobSize) {

521 return catchAll([&]() { return loadModule(data, gpuBlobSize); });

522}

523

525 const char *name) {

526 assert(module && name);

527 ze_kernel_handle_t zeKernel;

528 ze_kernel_desc_t desc = {};

529 desc.pKernelName = name;

530 L0_SAFE_CALL(zeKernelCreate(module, &desc, &zeKernel));

531 return zeKernel;

532}

533

535 size_t gridY, size_t gridZ, size_t blockX,

536 size_t blockY, size_t blockZ,

538 void **params, void ** ,

539 size_t paramsCount) {

540

541 if (sharedMemBytes > 0) {

542 paramsCount = paramsCount - 1;

544 zeKernelSetArgumentValue(kernel, paramsCount, sharedMemBytes, nullptr));

545 }

546 for (size_t i = 0; i < paramsCount; ++i)

547 L0_SAFE_CALL(zeKernelSetArgumentValue(kernel, static_cast<uint32_t>(i),

548 sizeof(void *), params[i]));

549 L0_SAFE_CALL(zeKernelSetGroupSize(kernel, blockX, blockY, blockZ));

550 ze_group_count_t dispatch;

551 dispatch.groupCountX = static_cast<uint32_t>(gridX);

552 dispatch.groupCountY = static_cast<uint32_t>(gridY);

553 dispatch.groupCountZ = static_cast<uint32_t>(gridZ);

554 stream->enqueueOp([&](ze_event_handle_t newEvent, uint32_t numWaitEvents,

555 ze_event_handle_t *waitEvents) {

556 L0_SAFE_CALL(zeCommandListAppendLaunchKernel(

558 numWaitEvents, waitEvents));

559 });

560}

561

565

567 catchAll([&]() {

568

569

572 });

573}

std::unique_ptr< std::remove_pointer< ze_event_handle_t >::type, ZeEventDeleter > UniqueZeEvent

Definition LevelZeroRuntimeWrappers.cpp:234

void mgpuSetDefaultDevice(int32_t devIdx)

Definition LevelZeroRuntimeWrappers.cpp:566

static ze_module_handle_t loadModule(const void *data, size_t dataSize)

Definition LevelZeroRuntimeWrappers.cpp:383

static L0RTContextWrapper & getRtContext()

Definition LevelZeroRuntimeWrappers.cpp:331

void mgpuMemset16(void *dst, unsigned short value, size_t count, StreamWrapper *stream)

Definition LevelZeroRuntimeWrappers.cpp:514

#define L0_SAFE_CALL(call)

Definition LevelZeroRuntimeWrappers.cpp:39

static void mgpuMemset(void *dst, PATTERN_TYPE value, size_t count, StreamWrapper *stream)

Definition LevelZeroRuntimeWrappers.cpp:495

static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)

Definition LevelZeroRuntimeWrappers.cpp:82

static DynamicEventPool & getDynamicEventPool()

Definition LevelZeroRuntimeWrappers.cpp:336

std::unique_ptr< std::remove_pointer< ze_context_handle_t >::type, ZeContextDeleter > UniqueZeContext

Definition LevelZeroRuntimeWrappers.cpp:133

void * mgpuMemAlloc(uint64_t size, StreamWrapper *stream, bool isShared)

Definition LevelZeroRuntimeWrappers.cpp:454

void mgpuStreamDestroy(StreamWrapper *stream)

Definition LevelZeroRuntimeWrappers.cpp:423

ze_module_handle_t mgpuModuleLoad(const void *data, size_t gpuBlobSize)

Definition LevelZeroRuntimeWrappers.cpp:519

void mgpuEventSynchronize(ze_event_handle_t event)

Definition LevelZeroRuntimeWrappers.cpp:440

void mgpuModuleUnload(ze_module_handle_t module)

Definition LevelZeroRuntimeWrappers.cpp:562

static ze_driver_handle_t getDriver(uint32_t idx=0)

Definition LevelZeroRuntimeWrappers.cpp:58

void mgpuMemset32(void *dst, unsigned int value, size_t count, StreamWrapper *stream)

Definition LevelZeroRuntimeWrappers.cpp:509

StreamWrapper * mgpuStreamCreate()

Definition LevelZeroRuntimeWrappers.cpp:414

std::unique_ptr< std::remove_pointer< ze_command_list_handle_t >::type, ZeCommandListDeleter > UniqueZeCommandList

Definition LevelZeroRuntimeWrappers.cpp:136

void mgpuEventDestroy(ze_event_handle_t event)

Definition LevelZeroRuntimeWrappers.cpp:436

void mgpuStreamSynchronize(StreamWrapper *stream)

Definition LevelZeroRuntimeWrappers.cpp:418

ze_kernel_handle_t mgpuModuleGetFunction(ze_module_handle_t module, const char *name)

Definition LevelZeroRuntimeWrappers.cpp:524

void mgpuMemcpy(void *dst, void *src, size_t sizeBytes, StreamWrapper *stream)

Definition LevelZeroRuntimeWrappers.cpp:484

void mgpuStreamWaitEvent(StreamWrapper *stream, ze_event_handle_t event)

Definition LevelZeroRuntimeWrappers.cpp:425

void mgpuMemFree(void *ptr, StreamWrapper *stream)

Definition LevelZeroRuntimeWrappers.cpp:478

void mgpuEventRecord(ze_event_handle_t event, StreamWrapper *stream)

Definition LevelZeroRuntimeWrappers.cpp:446

ze_event_handle_t mgpuEventCreate()

Definition LevelZeroRuntimeWrappers.cpp:432

std::unique_ptr< std::remove_pointer< ze_event_pool_handle_t >::type, ZeEventPoolDeleter > UniqueZeEventPool

Definition LevelZeroRuntimeWrappers.cpp:237

void mgpuLaunchKernel(ze_kernel_handle_t kernel, size_t gridX, size_t gridY, size_t gridZ, size_t blockX, size_t blockY, size_t blockZ, size_t sharedMemBytes, StreamWrapper *stream, void **params, void **, size_t paramsCount)

Definition LevelZeroRuntimeWrappers.cpp:534

void createNewPool(size_t numEvents)

Definition LevelZeroRuntimeWrappers.cpp:274

L0RTContextWrapper * rtCtx

Definition LevelZeroRuntimeWrappers.cpp:257

size_t maxEventsCount

Definition LevelZeroRuntimeWrappers.cpp:254

DynamicEventPool & operator=(const DynamicEventPool &)=delete

static constexpr size_t numEventsPerPool

Definition LevelZeroRuntimeWrappers.cpp:245

void releaseEvent(ze_event_handle_t event)

Definition LevelZeroRuntimeWrappers.cpp:320

DynamicEventPool(DynamicEventPool &&) noexcept=default

DynamicEventPool(L0RTContextWrapper *rtCtx)

Definition LevelZeroRuntimeWrappers.cpp:259

std::vector< UniqueZeEventPool > eventPools

Definition LevelZeroRuntimeWrappers.cpp:247

std::unordered_map< ze_event_handle_t, UniqueZeEvent > takenEvents

Definition LevelZeroRuntimeWrappers.cpp:249

ze_event_handle_t takeEvent()

Definition LevelZeroRuntimeWrappers.cpp:287

size_t currentEventsCnt

Definition LevelZeroRuntimeWrappers.cpp:256

std::vector< UniqueZeEvent > availableEvents

Definition LevelZeroRuntimeWrappers.cpp:248

DynamicEventPool(const DynamicEventPool &)=delete

size_t currentEventsLimit

Definition LevelZeroRuntimeWrappers.cpp:255

Definition LevelZeroRuntimeWrappers.cpp:139

UniqueZeCommandList immCmdListCopy

Definition LevelZeroRuntimeWrappers.cpp:148

L0RTContextWrapper()=default

L0RTContextWrapper & operator=(const L0RTContextWrapper &)=delete

uint32_t copyEngineMaxMemoryFillPatternSize

Definition LevelZeroRuntimeWrappers.cpp:149

ze_device_handle_t device

Definition LevelZeroRuntimeWrappers.cpp:141

L0RTContextWrapper(L0RTContextWrapper &&) noexcept=default

UniqueZeContext context

Definition LevelZeroRuntimeWrappers.cpp:142

L0RTContextWrapper(const uint32_t driverIdx=0, const int32_t devIdx=0)

Definition LevelZeroRuntimeWrappers.cpp:152

ze_driver_handle_t driver

Definition LevelZeroRuntimeWrappers.cpp:140

UniqueZeCommandList immCmdListCompute

Definition LevelZeroRuntimeWrappers.cpp:145

L0RTContextWrapper(const L0RTContextWrapper &)=delete

~StreamWrapper()

Definition LevelZeroRuntimeWrappers.cpp:347

ze_event_handle_t * getLastImplicitEventPtr()

Definition LevelZeroRuntimeWrappers.cpp:349

void sync(ze_event_handle_t explicitEvent=nullptr)

Definition LevelZeroRuntimeWrappers.cpp:354

StreamWrapper(DynamicEventPool &dynEventPool)

Definition LevelZeroRuntimeWrappers.cpp:346

std::deque< ze_event_handle_t > implicitEventStack

Definition LevelZeroRuntimeWrappers.cpp:343

DynamicEventPool & dynEventPool

Definition LevelZeroRuntimeWrappers.cpp:344

void enqueueOp(Func &&op)

Definition LevelZeroRuntimeWrappers.cpp:373

void operator()(ze_command_list_handle_t cmdList) const

Definition LevelZeroRuntimeWrappers.cpp:128

Definition LevelZeroRuntimeWrappers.cpp:120

void operator()(ze_context_handle_t ctx) const

Definition LevelZeroRuntimeWrappers.cpp:121

void operator()(ze_event_handle_t event) const

Definition LevelZeroRuntimeWrappers.cpp:221

void operator()(ze_event_pool_handle_t pool) const

Definition LevelZeroRuntimeWrappers.cpp:228