MLIR: lib/ExecutionEngine/LevelZeroRuntimeWrappers.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
13#include "llvm/ADT/Twine.h"
14
15#include "level_zero/ze_api.h"
16#include
17#include
18#include
19#include
20#include
21#include
22#include <unordered_set>
23#include
24
25namespace {
26template
27auto catchAll(F &&func) {
28 try {
29 return func();
30 } catch (const std::exception &e) {
31 std::cerr << "An exception was thrown: " << e.what() << std::endl;
32 std::abort();
33 } catch (...) {
34 std::cerr << "An unknown exception was thrown." << std::endl;
35 std::abort();
36 }
37}
38
39#define L0_SAFE_CALL(call) \
40 { \
41 ze_result_t status = (call); \
42 if (status != ZE_RESULT_SUCCESS) { \
43 const char *errorString; \
44 zeDriverGetLastErrorDescription(NULL, &errorString); \
45 std::cerr << "L0 error " << status << ": " << errorString << std::endl; \
46 std::abort(); \
47 } \
48 }
49}
50
51
52
53
54
55
56
57
58static ze_driver_handle_t getDriver(uint32_t idx = 0) {
59 ze_init_driver_type_desc_t driver_type = {};
60 driver_type.stype = ZE_STRUCTURE_TYPE_INIT_DRIVER_TYPE_DESC;
61 driver_type.flags = ZE_INIT_DRIVER_TYPE_FLAG_GPU;
62 driver_type.pNext = nullptr;
63 uint32_t driverCount{0};
64 thread_local static std::vector<ze_driver_handle_t> drivers;
65 thread_local static bool isDriverInitialised{false};
66 if (isDriverInitialised && idx < drivers.size())
67 return drivers[idx];
68 L0_SAFE_CALL(zeInitDrivers(&driverCount, nullptr, &driver_type));
69 if (!driverCount)
70 throw std::runtime_error("No L0 drivers found.");
71 drivers.resize(driverCount);
72 L0_SAFE_CALL(zeInitDrivers(&driverCount, drivers.data(), &driver_type));
73 if (idx >= driverCount)
74 throw std::runtime_error((llvm::Twine("Requested driver idx out-of-bound, "
75 "number of availabe drivers: ") +
76 std::to_string(driverCount))
77 .str());
78 isDriverInitialised = true;
79 return drivers[idx];
80}
81
82static ze_device_handle_t getDevice(const uint32_t driverIdx = 0,
83 const int32_t devIdx = 0) {
84 thread_local static ze_device_handle_t l0Device;
85 thread_local int32_t currDevIdx{-1};
86 thread_local uint32_t currDriverIdx{0};
87 if (currDriverIdx == driverIdx && currDevIdx == devIdx)
88 return l0Device;
89 auto driver = getDriver(driverIdx);
90 uint32_t deviceCount{0};
91 L0_SAFE_CALL(zeDeviceGet(driver, &deviceCount, nullptr));
92 if (!deviceCount)
93 throw std::runtime_error("getDevice failed: did not find L0 device.");
94 if (static_cast<int>(deviceCount) < devIdx + 1)
95 throw std::runtime_error("getDevice failed: devIdx out-of-bounds.");
96 std::vector<ze_device_handle_t> devices(deviceCount);
97 L0_SAFE_CALL(zeDeviceGet(driver, &deviceCount, devices.data()));
98 l0Device = devices[devIdx];
99 currDriverIdx = driverIdx;
100 currDevIdx = devIdx;
101 return l0Device;
102}
103
104
105static ze_context_handle_t getContext(ze_driver_handle_t driver) {
106 thread_local static ze_context_handle_t context;
107 thread_local static bool isContextInitialised{false};
108 if (isContextInitialised)
109 return context;
110 ze_context_desc_t ctxtDesc = {ZE_STRUCTURE_TYPE_CONTEXT_DESC, nullptr, 0};
111 L0_SAFE_CALL(zeContextCreate(driver, &ctxtDesc, &context));
112 isContextInitialised = true;
113 return context;
114}
115
116
117
118
119
126
128 void operator()(ze_command_list_handle_t cmdList) const {
129 if (cmdList)
131 }
132};
134 std::unique_ptr<std::remove_pointer<ze_context_handle_t>::type,
137 std::unique_ptr<std::remove_pointer<ze_command_list_handle_t>::type,
140 ze_driver_handle_t driver{nullptr};
141 ze_device_handle_t device{nullptr};
143
144
146
147
150
154
157
158
159 uint32_t computeEngineOrdinal = -1u, copyEngineOrdinal = -1u;
160 ze_device_properties_t deviceProperties{};
162 uint32_t queueGroupCount = 0;
163 L0_SAFE_CALL(zeDeviceGetCommandQueueGroupProperties(
164 device, &queueGroupCount, nullptr));
165 std::vector<ze_command_queue_group_properties_t> queueGroupProperties(
166 queueGroupCount);
167 L0_SAFE_CALL(zeDeviceGetCommandQueueGroupProperties(
168 device, &queueGroupCount, queueGroupProperties.data()));
169
170 for (uint32_t queueGroupIdx = 0; queueGroupIdx < queueGroupCount;
171 ++queueGroupIdx) {
172 const auto &group = queueGroupProperties[queueGroupIdx];
173 if (group.flags & ZE_COMMAND_QUEUE_GROUP_PROPERTY_FLAG_COMPUTE)
174 computeEngineOrdinal = queueGroupIdx;
175 else if (group.flags & ZE_COMMAND_QUEUE_GROUP_PROPERTY_FLAG_COPY) {
176 copyEngineOrdinal = queueGroupIdx;
178 }
179 if (copyEngineOrdinal != -1u && computeEngineOrdinal != -1u)
180 break;
181 }
182
183
184 if (copyEngineOrdinal == -1u)
185 copyEngineOrdinal = computeEngineOrdinal;
186
187 assert(copyEngineOrdinal != -1u && computeEngineOrdinal != -1u &&
188 "Expected two engines to be available.");
189
190
191 ze_command_queue_desc_t cmdQueueDesc{
192 ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC,
193 nullptr,
194 copyEngineOrdinal,
195 0,
196 0,
197 ZE_COMMAND_QUEUE_MODE_ASYNCHRONOUS,
198 ZE_COMMAND_QUEUE_PRIORITY_NORMAL};
199
200 ze_command_list_handle_t rawCmdListCopy = nullptr;
202 &cmdQueueDesc, &rawCmdListCopy));
204
205
206 cmdQueueDesc.ordinal = computeEngineOrdinal;
207 ze_command_list_handle_t rawCmdListCompute = nullptr;
209 context.get(), device, &cmdQueueDesc, &rawCmdListCompute));
211 }
214
218};
219
226
228 void operator()(ze_event_pool_handle_t pool) const {
229 if (pool)
231 }
232};
233
235 std::unique_ptr<std::remove_pointer<ze_event_handle_t>::type,
238 std::unique_ptr<std::remove_pointer<ze_event_pool_handle_t>::type,
240
241
242
243
246
249 std::unordered_map<ze_event_handle_t, UniqueZeEvent> takenEvents;
250
251
252
253
258
262
265
266
269
271 assert(takenEvents.empty() && "Some events were not released");
272 }
273
275 ze_event_pool_desc_t eventPoolDesc = {};
276 eventPoolDesc.flags = ZE_EVENT_POOL_FLAG_HOST_VISIBLE;
277 eventPoolDesc.count = numEvents;
278
279 ze_event_pool_handle_t rawPool = nullptr;
280 L0_SAFE_CALL(zeEventPoolCreate(rtCtx->context.get(), &eventPoolDesc, 1,
281 &rtCtx->device, &rawPool));
282
285 }
286
288 ze_event_handle_t rawEvent = nullptr;
289
291
294 rawEvent = uniqueEvent.get();
295 takenEvents[rawEvent] = std::move(uniqueEvent);
296 } else {
298 throw std::runtime_error("DynamicEventPool: reached max events limit");
299 }
302
303 ze_event_desc_t eventDesc = {
304 ZE_STRUCTURE_TYPE_EVENT_DESC, nullptr,
306 ZE_EVENT_SCOPE_FLAG_DEVICE, ZE_EVENT_SCOPE_FLAG_HOST};
307
308 ze_event_handle_t newEvent = nullptr;
310 zeEventCreate(eventPools.back().get(), &eventDesc, &newEvent));
311
313 rawEvent = newEvent;
315 }
316
317 return rawEvent;
318 }
319
323 "Attempting to release unknown or already released event");
324
328 }
329};
330
333 return rtContext;
334}
335
338 return dynEventPool;
339}
340
342
345
348
353
354 void sync(ze_event_handle_t explicitEvent = nullptr) {
355 ze_event_handle_t syncEvent{nullptr};
356 if (!explicitEvent) {
358 syncEvent = lastImplicitEventPtr ? *lastImplicitEventPtr : nullptr;
359 } else {
360 syncEvent = explicitEvent;
361 }
362 if (syncEvent)
364 syncEvent, std::numeric_limits<uint64_t>::max()));
365
366
370 }
371
372 template
374 ze_event_handle_t newImplicitEvent = dynEventPool.takeEvent();
376 const uint32_t numWaitEvents = lastImplicitEventPtr ? 1 : 0;
377 std::forward(op)(newImplicitEvent, numWaitEvents,
378 lastImplicitEventPtr);
380 }
381};
382
383static ze_module_handle_t loadModule(const void *data, size_t dataSize) {
384 assert(data);
385 ze_module_handle_t zeModule;
386 ze_module_desc_t desc = {ZE_STRUCTURE_TYPE_MODULE_DESC,
387 nullptr,
388 ZE_MODULE_FORMAT_IL_SPIRV,
389 dataSize,
390 (const uint8_t *)data,
391 nullptr,
392 nullptr};
393 ze_module_build_log_handle_t buildLogHandle;
396 &zeModule, &buildLogHandle);
397 if (result != ZE_RESULT_SUCCESS) {
398 std::cerr << "Error creating module, error code: " << result << std::endl;
399 size_t logSize = 0;
400 L0_SAFE_CALL(zeModuleBuildLogGetString(buildLogHandle, &logSize, nullptr));
401 std::string buildLog(" ", logSize);
403 zeModuleBuildLogGetString(buildLogHandle, &logSize, buildLog.data()));
404 std::cerr << "Build log:\n" << buildLog << std::endl;
405 std::abort();
406 }
407 return zeModule;
408}
409
410
411
412
413
417
419 if (stream)
420 stream->sync();
421}
422
424
426 ze_event_handle_t event) {
427 assert(stream && "Invalid stream");
428 assert(event && "Invalid event");
429 stream->sync(event);
430}
431
435
439
442 zeEventHostSynchronize(event, std::numeric_limits<uint64_t>::max()));
444}
445
453
455 bool isShared) {
456 return catchAll([&]() {
457 void *memPtr = nullptr;
458 constexpr size_t alignment{64};
459 ze_device_mem_alloc_desc_t deviceDesc = {};
460 deviceDesc.stype = ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC;
461 if (isShared) {
462 ze_host_mem_alloc_desc_t hostDesc = {};
463 hostDesc.stype = ZE_STRUCTURE_TYPE_HOST_MEM_ALLOC_DESC;
465 &hostDesc, size, alignment,
467 } else {
470 &memPtr));
471 }
472 if (!memPtr)
473 throw std::runtime_error("mem allocation failed!");
474 return memPtr;
475 });
476}
477
479 stream->sync();
480 if (ptr)
482}
483
484extern "C" void mgpuMemcpy(void *dst, void *src, size_t sizeBytes,
486 stream->enqueueOp([&](ze_event_handle_t newEvent, uint32_t numWaitEvents,
487 ze_event_handle_t *waitEvents) {
490 numWaitEvents, waitEvents));
491 });
492}
493
494template <typename PATTERN_TYPE>
495static void mgpuMemset(void *dst, PATTERN_TYPE value, size_t count,
498 auto listType =
502 stream->enqueueOp([&](ze_event_handle_t newEvent, uint32_t numWaitEvents,
503 ze_event_handle_t *waitEvents) {
505 listType, dst, &value, sizeof(PATTERN_TYPE),
506 count * sizeof(PATTERN_TYPE), newEvent, numWaitEvents, waitEvents));
507 });
508}
509extern "C" void mgpuMemset32(void *dst, unsigned int value, size_t count,
512}
513
514extern "C" void mgpuMemset16(void *dst, unsigned short value, size_t count,
517}
518
520 size_t gpuBlobSize) {
521 return catchAll([&]() { return loadModule(data, gpuBlobSize); });
522}
523
525 const char *name) {
526 assert(module && name);
527 ze_kernel_handle_t zeKernel;
528 ze_kernel_desc_t desc = {};
529 desc.pKernelName = name;
530 L0_SAFE_CALL(zeKernelCreate(module, &desc, &zeKernel));
531 return zeKernel;
532}
533
535 size_t gridY, size_t gridZ, size_t blockX,
536 size_t blockY, size_t blockZ,
538 void **params, void ** ,
539 size_t paramsCount) {
540
541 if (sharedMemBytes > 0) {
542 paramsCount = paramsCount - 1;
544 zeKernelSetArgumentValue(kernel, paramsCount, sharedMemBytes, nullptr));
545 }
546 for (size_t i = 0; i < paramsCount; ++i)
547 L0_SAFE_CALL(zeKernelSetArgumentValue(kernel, static_cast<uint32_t>(i),
548 sizeof(void *), params[i]));
549 L0_SAFE_CALL(zeKernelSetGroupSize(kernel, blockX, blockY, blockZ));
550 ze_group_count_t dispatch;
551 dispatch.groupCountX = static_cast<uint32_t>(gridX);
552 dispatch.groupCountY = static_cast<uint32_t>(gridY);
553 dispatch.groupCountZ = static_cast<uint32_t>(gridZ);
554 stream->enqueueOp([&](ze_event_handle_t newEvent, uint32_t numWaitEvents,
555 ze_event_handle_t *waitEvents) {
556 L0_SAFE_CALL(zeCommandListAppendLaunchKernel(
558 numWaitEvents, waitEvents));
559 });
560}
561
565
567 catchAll([&]() {
568
569
572 });
573}
std::unique_ptr< std::remove_pointer< ze_event_handle_t >::type, ZeEventDeleter > UniqueZeEvent
Definition LevelZeroRuntimeWrappers.cpp:234
void mgpuSetDefaultDevice(int32_t devIdx)
Definition LevelZeroRuntimeWrappers.cpp:566
static ze_module_handle_t loadModule(const void *data, size_t dataSize)
Definition LevelZeroRuntimeWrappers.cpp:383
static L0RTContextWrapper & getRtContext()
Definition LevelZeroRuntimeWrappers.cpp:331
void mgpuMemset16(void *dst, unsigned short value, size_t count, StreamWrapper *stream)
Definition LevelZeroRuntimeWrappers.cpp:514
#define L0_SAFE_CALL(call)
Definition LevelZeroRuntimeWrappers.cpp:39
static void mgpuMemset(void *dst, PATTERN_TYPE value, size_t count, StreamWrapper *stream)
Definition LevelZeroRuntimeWrappers.cpp:495
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
Definition LevelZeroRuntimeWrappers.cpp:82
static DynamicEventPool & getDynamicEventPool()
Definition LevelZeroRuntimeWrappers.cpp:336
std::unique_ptr< std::remove_pointer< ze_context_handle_t >::type, ZeContextDeleter > UniqueZeContext
Definition LevelZeroRuntimeWrappers.cpp:133
void * mgpuMemAlloc(uint64_t size, StreamWrapper *stream, bool isShared)
Definition LevelZeroRuntimeWrappers.cpp:454
void mgpuStreamDestroy(StreamWrapper *stream)
Definition LevelZeroRuntimeWrappers.cpp:423
ze_module_handle_t mgpuModuleLoad(const void *data, size_t gpuBlobSize)
Definition LevelZeroRuntimeWrappers.cpp:519
void mgpuEventSynchronize(ze_event_handle_t event)
Definition LevelZeroRuntimeWrappers.cpp:440
void mgpuModuleUnload(ze_module_handle_t module)
Definition LevelZeroRuntimeWrappers.cpp:562
static ze_driver_handle_t getDriver(uint32_t idx=0)
Definition LevelZeroRuntimeWrappers.cpp:58
void mgpuMemset32(void *dst, unsigned int value, size_t count, StreamWrapper *stream)
Definition LevelZeroRuntimeWrappers.cpp:509
StreamWrapper * mgpuStreamCreate()
Definition LevelZeroRuntimeWrappers.cpp:414
std::unique_ptr< std::remove_pointer< ze_command_list_handle_t >::type, ZeCommandListDeleter > UniqueZeCommandList
Definition LevelZeroRuntimeWrappers.cpp:136
void mgpuEventDestroy(ze_event_handle_t event)
Definition LevelZeroRuntimeWrappers.cpp:436
void mgpuStreamSynchronize(StreamWrapper *stream)
Definition LevelZeroRuntimeWrappers.cpp:418
ze_kernel_handle_t mgpuModuleGetFunction(ze_module_handle_t module, const char *name)
Definition LevelZeroRuntimeWrappers.cpp:524
void mgpuMemcpy(void *dst, void *src, size_t sizeBytes, StreamWrapper *stream)
Definition LevelZeroRuntimeWrappers.cpp:484
void mgpuStreamWaitEvent(StreamWrapper *stream, ze_event_handle_t event)
Definition LevelZeroRuntimeWrappers.cpp:425
void mgpuMemFree(void *ptr, StreamWrapper *stream)
Definition LevelZeroRuntimeWrappers.cpp:478
void mgpuEventRecord(ze_event_handle_t event, StreamWrapper *stream)
Definition LevelZeroRuntimeWrappers.cpp:446
ze_event_handle_t mgpuEventCreate()
Definition LevelZeroRuntimeWrappers.cpp:432
std::unique_ptr< std::remove_pointer< ze_event_pool_handle_t >::type, ZeEventPoolDeleter > UniqueZeEventPool
Definition LevelZeroRuntimeWrappers.cpp:237
void mgpuLaunchKernel(ze_kernel_handle_t kernel, size_t gridX, size_t gridY, size_t gridZ, size_t blockX, size_t blockY, size_t blockZ, size_t sharedMemBytes, StreamWrapper *stream, void **params, void **, size_t paramsCount)
Definition LevelZeroRuntimeWrappers.cpp:534
void createNewPool(size_t numEvents)
Definition LevelZeroRuntimeWrappers.cpp:274
L0RTContextWrapper * rtCtx
Definition LevelZeroRuntimeWrappers.cpp:257
size_t maxEventsCount
Definition LevelZeroRuntimeWrappers.cpp:254
DynamicEventPool & operator=(const DynamicEventPool &)=delete
static constexpr size_t numEventsPerPool
Definition LevelZeroRuntimeWrappers.cpp:245
void releaseEvent(ze_event_handle_t event)
Definition LevelZeroRuntimeWrappers.cpp:320
DynamicEventPool(DynamicEventPool &&) noexcept=default
DynamicEventPool(L0RTContextWrapper *rtCtx)
Definition LevelZeroRuntimeWrappers.cpp:259
std::vector< UniqueZeEventPool > eventPools
Definition LevelZeroRuntimeWrappers.cpp:247
std::unordered_map< ze_event_handle_t, UniqueZeEvent > takenEvents
Definition LevelZeroRuntimeWrappers.cpp:249
ze_event_handle_t takeEvent()
Definition LevelZeroRuntimeWrappers.cpp:287
size_t currentEventsCnt
Definition LevelZeroRuntimeWrappers.cpp:256
std::vector< UniqueZeEvent > availableEvents
Definition LevelZeroRuntimeWrappers.cpp:248
DynamicEventPool(const DynamicEventPool &)=delete
size_t currentEventsLimit
Definition LevelZeroRuntimeWrappers.cpp:255
Definition LevelZeroRuntimeWrappers.cpp:139
UniqueZeCommandList immCmdListCopy
Definition LevelZeroRuntimeWrappers.cpp:148
L0RTContextWrapper()=default
L0RTContextWrapper & operator=(const L0RTContextWrapper &)=delete
uint32_t copyEngineMaxMemoryFillPatternSize
Definition LevelZeroRuntimeWrappers.cpp:149
ze_device_handle_t device
Definition LevelZeroRuntimeWrappers.cpp:141
L0RTContextWrapper(L0RTContextWrapper &&) noexcept=default
UniqueZeContext context
Definition LevelZeroRuntimeWrappers.cpp:142
L0RTContextWrapper(const uint32_t driverIdx=0, const int32_t devIdx=0)
Definition LevelZeroRuntimeWrappers.cpp:152
ze_driver_handle_t driver
Definition LevelZeroRuntimeWrappers.cpp:140
UniqueZeCommandList immCmdListCompute
Definition LevelZeroRuntimeWrappers.cpp:145
L0RTContextWrapper(const L0RTContextWrapper &)=delete
~StreamWrapper()
Definition LevelZeroRuntimeWrappers.cpp:347
ze_event_handle_t * getLastImplicitEventPtr()
Definition LevelZeroRuntimeWrappers.cpp:349
void sync(ze_event_handle_t explicitEvent=nullptr)
Definition LevelZeroRuntimeWrappers.cpp:354
StreamWrapper(DynamicEventPool &dynEventPool)
Definition LevelZeroRuntimeWrappers.cpp:346
std::deque< ze_event_handle_t > implicitEventStack
Definition LevelZeroRuntimeWrappers.cpp:343
DynamicEventPool & dynEventPool
Definition LevelZeroRuntimeWrappers.cpp:344
void enqueueOp(Func &&op)
Definition LevelZeroRuntimeWrappers.cpp:373
void operator()(ze_command_list_handle_t cmdList) const
Definition LevelZeroRuntimeWrappers.cpp:128
Definition LevelZeroRuntimeWrappers.cpp:120
void operator()(ze_context_handle_t ctx) const
Definition LevelZeroRuntimeWrappers.cpp:121
void operator()(ze_event_handle_t event) const
Definition LevelZeroRuntimeWrappers.cpp:221
void operator()(ze_event_pool_handle_t pool) const
Definition LevelZeroRuntimeWrappers.cpp:228