MLIR: lib/Dialect/SPIRV/IR/MemoryOps.cpp Source File (original) (raw)
1
2
3
4
5
6
7
8
9
10
11
12
15
20
21 #include "llvm/ADT/StringExtras.h"
22 #include "llvm/Support/Casting.h"
23
25
27
28
29
30
31
32
33
34 template
37
39
40 return success();
41 }
42
43 spirv::MemoryAccess memoryAccessAttr;
44 StringAttr memoryAccessAttrName =
45 MemoryOpTy::getMemoryAccessAttrName(state.name);
46 if (spirv::parseEnumStrAttrspirv::MemoryAccessAttr(
47 memoryAccessAttr, parser, state, memoryAccessAttrName))
48 return failure();
49
50 if (spirv::bitEnumContainsAll(memoryAccessAttr,
51 spirv::MemoryAccess::Aligned)) {
52
54 StringAttr alignmentAttrName = MemoryOpTy::getAlignmentAttrName(state.name);
57 parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName,
58 state.attributes)) {
59 return failure();
60 }
61 }
63 }
64
65
66
67
68
69 template
72
74
75 return success();
76 }
77
78 spirv::MemoryAccess memoryAccessAttr;
79 StringRef memoryAccessAttrName =
80 MemoryOpTy::getSourceMemoryAccessAttrName(state.name);
81 if (spirv::parseEnumStrAttrspirv::MemoryAccessAttr(
82 memoryAccessAttr, parser, state, memoryAccessAttrName))
83 return failure();
84
85 if (spirv::bitEnumContainsAll(memoryAccessAttr,
86 spirv::MemoryAccess::Aligned)) {
87
89 StringAttr alignmentAttrName =
90 MemoryOpTy::getSourceAlignmentAttrName(state.name);
93 parser.parseAttribute(alignmentAttr, i32Type, alignmentAttrName,
94 state.attributes)) {
95 return failure();
96 }
97 }
99 }
100
101
102
103
104
105 template
109 std::optionalspirv::MemoryAccess memoryAccessAtrrValue = std::nullopt,
110 std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
111
112 printer << ", ";
113
114
115 if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
116 : memoryOp.getMemoryAccess())) {
117 elidedAttrs.push_back(memoryOp.getSourceMemoryAccessAttrName());
118
119 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
120
121 if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
122
123 if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
124 : memoryOp.getAlignment())) {
125 elidedAttrs.push_back(memoryOp.getSourceAlignmentAttrName());
126 printer << ", " << *alignment;
127 }
128 }
129 printer << "]";
130 }
131 elidedAttrs.push_back(spirv::attributeNamespirv::StorageClass());
132 }
133
134 template
138 std::optionalspirv::MemoryAccess memoryAccessAtrrValue = std::nullopt,
139 std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
140
141 if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
142 : memoryOp.getMemoryAccess())) {
143 elidedAttrs.push_back(memoryOp.getMemoryAccessAttrName());
144
145 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
146
147 if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
148
149 if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
150 : memoryOp.getAlignment())) {
151 elidedAttrs.push_back(memoryOp.getAlignmentAttrName());
152 printer << ", " << *alignment;
153 }
154 }
155 printer << "]";
156 }
157 elidedAttrs.push_back(spirv::attributeNamespirv::StorageClass());
158 }
159
160 template
163
164
165
166
167
169 llvm::castspirv::PointerType(ptr.getType()).getPointeeType()) {
170 return op.emitOpError("mismatch in result type and pointer type");
171 }
172 return success();
173 }
174
175 template
177
178
179
180 auto *op = memoryOp.getOperation();
181 auto memAccessAttr = op->getAttr(memoryOp.getMemoryAccessAttrName());
182 if (!memAccessAttr) {
183
184
185 if (op->getAttr(memoryOp.getAlignmentAttrName())) {
186 return memoryOp.emitOpError(
187 "invalid alignment specification without aligned memory access "
188 "specification");
189 }
190 return success();
191 }
192
193 auto memAccess = llvm::castspirv::MemoryAccessAttr(memAccessAttr);
194
195 if (!memAccess) {
196 return memoryOp.emitOpError("invalid memory access specifier: ")
197 << memAccessAttr;
198 }
199
200 if (spirv::bitEnumContainsAll(memAccess.getValue(),
201 spirv::MemoryAccess::Aligned)) {
202 if (!op->getAttr(memoryOp.getAlignmentAttrName())) {
203 return memoryOp.emitOpError("missing alignment value");
204 }
205 } else {
206 if (op->getAttr(memoryOp.getAlignmentAttrName())) {
207 return memoryOp.emitOpError(
208 "invalid alignment specification with non-aligned memory access "
209 "specification");
210 }
211 }
212 return success();
213 }
214
215
216
217
218
219 template
221
222
223
224 auto *op = memoryOp.getOperation();
225 auto memAccessAttr = op->getAttr(memoryOp.getSourceMemoryAccessAttrName());
226 if (!memAccessAttr) {
227
228
229 if (op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
230 return memoryOp.emitOpError(
231 "invalid alignment specification without aligned memory access "
232 "specification");
233 }
234 return success();
235 }
236
237 auto memAccess = llvm::castspirv::MemoryAccessAttr(memAccessAttr);
238
239 if (!memAccess) {
240 return memoryOp.emitOpError("invalid memory access specifier: ")
241 << memAccess;
242 }
243
244 if (spirv::bitEnumContainsAll(memAccess.getValue(),
245 spirv::MemoryAccess::Aligned)) {
246 if (!op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
247 return memoryOp.emitOpError("missing alignment value");
248 }
249 } else {
250 if (op->getAttr(memoryOp.getSourceAlignmentAttrName())) {
251 return memoryOp.emitOpError(
252 "invalid alignment specification with non-aligned memory access "
253 "specification");
254 }
255 }
256 return success();
257 }
258
259
260
261
262
264 auto ptrType = llvm::dyn_castspirv::PointerType(type);
265 if (!ptrType) {
266 emitError(baseLoc, "'spirv.AccessChain' op expected a pointer "
267 "to composite type, but provided ")
268 << type;
269 return nullptr;
270 }
271
272 auto resultType = ptrType.getPointeeType();
273 auto resultStorageClass = ptrType.getStorageClass();
274 int32_t index = 0;
275
276 for (auto indexSSA : indices) {
277 auto cType = llvm::dyn_castspirv::CompositeType(resultType);
278 if (!cType) {
280 baseLoc,
281 "'spirv.AccessChain' op cannot extract from non-composite type ")
282 << resultType << " with index " << index;
283 return nullptr;
284 }
285 index = 0;
286 if (llvm::isaspirv::StructType(resultType)) {
287 Operation *op = indexSSA.getDefiningOp();
288 if (!op) {
289 emitError(baseLoc, "'spirv.AccessChain' op index must be an "
290 "integer spirv.Constant to access "
291 "element of spirv.struct");
292 return nullptr;
293 }
294
295
296
299 baseLoc,
300 "'spirv.AccessChain' index must be an integer spirv.Constant to "
301 "access element of spirv.struct, but provided ")
303 return nullptr;
304 }
305 if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
306 emitError(baseLoc, "'spirv.AccessChain' op index ")
307 << index << " out of bounds for " << resultType;
308 return nullptr;
309 }
310 }
311 resultType = cType.getElementType(index);
312 }
314 }
315
319 assert(type && "Unable to deduce return type based on basePtr and indices");
320 build(builder, state, type, basePtr, indices);
321 }
322
323 template
325 printer << ' ' << op.getBasePtr() << '[' << indices
326 << "] : " << op.getBasePtr().getType() << ", " << indices.getTypes();
327 }
328
329 template
331 auto resultType = getElementPtrType(accessChainOp.getBasePtr().getType(),
332 indices, accessChainOp.getLoc());
333 if (!resultType)
334 return failure();
335
336 auto providedResultType =
337 llvm::dyn_castspirv::PointerType(accessChainOp.getType());
338 if (!providedResultType)
340 "result type must be a pointer, but provided")
341 << providedResultType;
342
343 if (resultType != providedResultType)
344 return accessChainOp.emitOpError("invalid result type: expected ")
345 << resultType << ", but provided " << providedResultType;
346
347 return success();
348 }
349
352 }
353
354
355
356
357
358 void LoadOp::build(OpBuilder &builder, OperationState &state, Value basePtr,
359 MemoryAccessAttr memoryAccess, IntegerAttr alignment) {
360 auto ptrType = llvm::castspirv::PointerType(basePtr.getType());
361 build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
362 alignment);
363 }
364
365 ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
366
367 spirv::StorageClass storageClass;
368 OpAsmParser::UnresolvedOperand ptrInfo;
369 Type elementType;
370 if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
371 parseMemoryAccessAttributes(parser, result) ||
372 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
373 parser.parseType(elementType)) {
374 return failure();
375 }
376
378 if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) {
379 return failure();
380 }
381
382 result.addTypes(elementType);
383 return success();
384 }
385
387 SmallVector<StringRef, 4> elidedAttrs;
388 StringRef sc = stringifyStorageClass(
389 llvm::castspirv::PointerType(getPtr().getType()).getStorageClass());
390 printer << " \"" << sc << "\" " << getPtr();
391
393
394 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
395 printer << " : " << getType();
396 }
397
399
400
401
403 return failure();
404 }
406 }
407
408
409
410
411
412 ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
413
414 spirv::StorageClass storageClass;
415 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
416 auto loc = parser.getCurrentLocation();
417 Type elementType;
419 parser.parseOperandList(operandInfo, 2) ||
420 parseMemoryAccessAttributes(parser, result) ||
421 parser.parseColon() || parser.parseType(elementType)) {
422 return failure();
423 }
424
426 if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
427 result.operands)) {
428 return failure();
429 }
430 return success();
431 }
432
434 SmallVector<StringRef, 4> elidedAttrs;
435 StringRef sc = stringifyStorageClass(
436 llvm::castspirv::PointerType(getPtr().getType()).getStorageClass());
437 printer << " \"" << sc << "\" " << getPtr() << ", " << getValue();
438
440
441 printer << " : " << getValue().getType();
442 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
443 }
444
446
447
449 return failure();
451 }
452
453
454
455
456
458 printer << ' ';
459
460 StringRef targetStorageClass = stringifyStorageClass(
461 llvm::castspirv::PointerType(getTarget().getType()).getStorageClass());
462 printer << " \"" << targetStorageClass << "\" " << getTarget() << ", ";
463
464 StringRef sourceStorageClass = stringifyStorageClass(
465 llvm::castspirv::PointerType(getSource().getType()).getStorageClass());
466 printer << " \"" << sourceStorageClass << "\" " << getSource();
467
468 SmallVector<StringRef, 4> elidedAttrs;
471 getSourceMemoryAccess(),
472 getSourceAlignment());
473
474 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
475
476 Type pointeeType =
477 llvm::castspirv::PointerType(getTarget().getType()).getPointeeType();
478 printer << " : " << pointeeType;
479 }
480
481 ParseResult CopyMemoryOp::parse(OpAsmParser &parser, OperationState &result) {
482 spirv::StorageClass targetStorageClass;
483 OpAsmParser::UnresolvedOperand targetPtrInfo;
484
485 spirv::StorageClass sourceStorageClass;
486 OpAsmParser::UnresolvedOperand sourcePtrInfo;
487
488 Type elementType;
489
491 parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
493 parser.parseOperand(sourcePtrInfo) ||
494 parseMemoryAccessAttributes(parser, result)) {
495 return failure();
496 }
497
498 if (!parser.parseOptionalComma()) {
499
500 if (parseSourceMemoryAccessAttributes(parser, result)) {
501 return failure();
502 }
503 }
504
505 if (parser.parseColon() || parser.parseType(elementType))
506 return failure();
507
508 if (parser.parseOptionalAttrDict(result.attributes))
509 return failure();
510
513
514 if (parser.resolveOperand(targetPtrInfo, targetPtrType, result.operands) ||
515 parser.resolveOperand(sourcePtrInfo, sourcePtrType, result.operands)) {
516 return failure();
517 }
518
519 return success();
520 }
521
523 Type targetType =
524 llvm::castspirv::PointerType(getTarget().getType()).getPointeeType();
525
526 Type sourceType =
527 llvm::castspirv::PointerType(getSource().getType()).getPointeeType();
528
529 if (targetType != sourceType)
530 return emitOpError("both operands must be pointers to the same type");
531
533 return failure();
534
535
536
537
538
539
540
541
542
544 }
545
546
547
548
549
550 void InBoundsPtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
551 Value basePtr, Value element,
552 ValueRange indices) {
553 auto type = getElementPtrType(basePtr.getType(), indices, state.location);
554 assert(type && "Unable to deduce return type based on basePtr and indices");
555 build(builder, state, type, basePtr, element, indices);
556 }
557
560 }
561
562
563
564
565
566 void PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
567 Value basePtr, Value element, ValueRange indices) {
568 auto type = getElementPtrType(basePtr.getType(), indices, state.location);
569 assert(type && "Unable to deduce return type based on basePtr and indices");
570 build(builder, state, type, basePtr, element, indices);
571 }
572
575 }
576
577
578
579
580
581 ParseResult VariableOp::parse(OpAsmParser &parser, OperationState &result) {
582
583 std::optionalOpAsmParser::UnresolvedOperand initInfo;
584 if (succeeded(parser.parseOptionalKeyword("init"))) {
585 initInfo = OpAsmParser::UnresolvedOperand();
586 if (parser.parseLParen() || parser.parseOperand(*initInfo) ||
587 parser.parseRParen())
588 return failure();
589 }
590
592 return failure();
593 }
594
595
597 if (parser.parseColon())
598 return failure();
599 auto loc = parser.getCurrentLocation();
600 if (parser.parseType(type))
601 return failure();
602
603 auto ptrType = llvm::dyn_castspirv::PointerType(type);
604 if (!ptrType)
605 return parser.emitError(loc, "expected spirv.ptr type");
606 result.addTypes(ptrType);
607
608
609 if (initInfo) {
610 if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(),
611 result.operands))
612 return failure();
613 }
614
615 auto attr = parser.getBuilder().getAttrspirv::StorageClassAttr(
616 ptrType.getStorageClass());
617 result.addAttribute(spirv::attributeNamespirv::StorageClass(), attr);
618
619 return success();
620 }
621
623 SmallVector<StringRef, 4> elidedAttrs{
624 spirv::attributeNamespirv::StorageClass()};
625
626 if (getNumOperands() != 0)
627 printer << " init(" << getInitializer() << ")";
628
630 printer << " : " << getType();
631 }
632
634
635
636
637 if (getStorageClass() != spirv::StorageClass::Function) {
638 return emitOpError(
639 "can only be used to model function-level variables. Use "
640 "spirv.GlobalVariable for module-level variables.");
641 }
642
643 auto pointerType = llvm::castspirv::PointerType(getPointer().getType());
644 if (getStorageClass() != pointerType.getStorageClass())
645 return emitOpError(
646 "storage class must match result pointer's storage class");
647
648 if (getNumOperands() != 0) {
649
650
651 auto *initOp = getOperand(0).getDefiningOp();
652 if (!initOp || !isa<spirv::ConstantOp,
653 spirv::ReferenceOfOp,
654 spirv::AddressOfOp>(initOp))
655 return emitOpError("initializer must be the result of a "
656 "constant or spirv.GlobalVariable op");
657 }
658
659 auto getDecorationAttr = [op = getOperation()](spirv::Decoration decoration) {
660 return op->getAttr(
661 llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration)));
662 };
663
664
665 for (auto decoration :
666 {spirv::Decoration::DescriptorSet, spirv::Decoration::Binding,
667 spirv::Decoration::BuiltIn}) {
668 if (auto attr = getDecorationAttr(decoration))
669 return emitOpError("cannot have '")
670 << llvm::convertToSnakeFromCamelCase(
671 stringifyDecoration(decoration))
672 << "' attribute (only allowed in spirv.GlobalVariable)";
673 }
674
675
676
677
678
679 auto pointeePtrType = dyn_castspirv::PointerType(getPointeeType());
680 if (!pointeePtrType) {
681 if (auto pointeeArrayType = dyn_castspirv::ArrayType(getPointeeType())) {
682 pointeePtrType =
683 dyn_castspirv::PointerType(pointeeArrayType.getElementType());
684 }
685 }
686
687 if (pointeePtrType && pointeePtrType.getStorageClass() ==
688 spirv::StorageClass::PhysicalStorageBuffer) {
689 bool hasAliasedPtr =
690 getDecorationAttr(spirv::Decoration::AliasedPointer) != nullptr;
691 bool hasRestrictPtr =
692 getDecorationAttr(spirv::Decoration::RestrictPointer) != nullptr;
693
694 if (!hasAliasedPtr && !hasRestrictPtr)
695 return emitOpError() << " with physical buffer pointer must be decorated "
696 "either 'AliasedPointer' or 'RestrictPointer'";
697
698 if (hasAliasedPtr && hasRestrictPtr)
699 return emitOpError()
700 << " with physical buffer pointer must have exactly one "
701 "aliasing decoration";
702 }
703
704 return success();
705 }
706
707 }
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Attributes are known-constant values of operations.
IntegerType getIntegerType(unsigned width)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
This class helps build Operations.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
Operation is the basic unit of execution within MLIR.
OperationName getName()
The name of an operation is the key identifier for it.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static PointerType get(Type pointeeType, StorageClass storageClass)
@ Type
An inlay hint that for a type annotation.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser, OperationState &state)
ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName=spirv::attributeName< EnumClass >())
Parses the next string attribute in parser as an enumerant of the given EnumClass.
static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp)
static void printSourceMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs, std::optional< spirv::MemoryAccess > memoryAccessAtrrValue=std::nullopt, std::optional< uint32_t > alignmentAttrValue=std::nullopt)
ParseResult parseMemoryAccessAttributes(OpAsmParser &parser, OperationState &state)
Parses optional memory access (a.k.a.
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc)
void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs)
static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr, Value val)
static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp)
static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer)
static void printMemoryAccessAttribute(MemoryOpTy memoryOp, OpAsmPrinter &printer, SmallVectorImpl< StringRef > &elidedAttrs, std::optional< spirv::MemoryAccess > memoryAccessAtrrValue=std::nullopt, std::optional< uint32_t > alignmentAttrValue=std::nullopt)
LogicalResult extractValueFromConstOp(Operation *op, int32_t &value)
ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state)
static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices)
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This represents an operation in an abstracted form, suitable for use with the builder APIs.