| #ifndef LLO_OPS |
| #define LLO_OPS |
| |
| include "mlir/IR/OpBase.td" |
| include "mlir/IR/EnumAttr.td" |
| include "mlir/IR/AttrTypeBase.td" |
| include "mlir/Pass/PassBase.td" |
| include "mlir/Interfaces/ControlFlowInterfaces.td" |
| include "mlir/Interfaces/SideEffectInterfaces.td" |
| include "mlir/Interfaces/InferTypeOpInterface.td" |
| |
// Comparison directions. The `foreach` loops that instantiate
// LLO_VectorCompareOp paste each entry into the def name, the (lower-cased)
// mnemonic and the "k"-prefixed comparison direction.
defvar cmp_directions = [
  "Eq", "Ne", "Ge", "Gt", "Le", "Lt"
];
| |
// The `llo` dialect. Generated C++ is placed in ::mlir::llo and attributes
// use the default ODS-generated printer/parser.
def LLO_Dialect : Dialect {
  let name = "llo";
  let cppNamespace = "::mlir::llo";
  let useDefaultAttributePrinterParser = 1;
}
| |
// Base class for attributes of the LLO dialect. `mnemonic_` is forwarded to
// the AttrDef `mnemonic` field.
class LLO_Attr<string name, string mnemonic_, list<Trait> traits = []>
    : AttrDef<LLO_Dialect, name, traits> {
  let mnemonic = mnemonic_;
}
| |
// Base class for types of the LLO dialect. `mnemonic_` is forwarded to the
// TypeDef `mnemonic` field.
class LLO_Type<string name, string mnemonic_, list<Trait> traits = []>
    : TypeDef<LLO_Dialect, name, traits> {
  let mnemonic = mnemonic_;
}
| |
// Base class for enum attributes of the LLO dialect. Subclasses set
// `lloEnumType` in addition to the usual EnumAttr fields.
class LLO_EnumAttr<EnumAttrInfo enum, string name>
    : EnumAttr<LLO_Dialect, enum, name> {
  // Name of the LLO enum this attribute is modeling.
  string lloEnumType;
}
| |
// Base class for all LLO ops. On top of the standard `Op` fields it declares
// three extra fields that individual ops may set.
// NOTE(review): these are presumably consumed by a project-specific TableGen
// backend that generates the lowering code; that backend is not visible here.
class LLO_Op<string mnemonic, list<Trait> traits = []>
    : Op<LLO_Dialect, mnemonic, traits> {
  // C++ snippet; ops that set it write a `return ...Create...(...)` statement.
  string customInstructionLowering;
  // Defaults to 1; a few ops (e.g. allocas, constant) reset it to 0.
  bit hasInstructionLowering = 1;
  // C++ expression on `$_builder`; unset (`?`) unless an op overrides it.
  string builderMethod = ?;
}
| |
// TODO(apaszke): Model resources more precisely
// All four effects below are declared on DefaultResource. The *Alloc defs are
// attached to op results via Res<...>, while VmemRead/VmemWrite are
// MemoryEffects traits placed in an op's trait list.
def VmemAlloc : MemAlloc<DefaultResource>;
def SmemAlloc : MemAlloc<DefaultResource>;
def VmemRead : MemoryEffects<[MemRead<DefaultResource>]>;
def VmemWrite : MemoryEffects<[MemWrite<DefaultResource>]>;
| |
// Vector type constraint ("native-sized vreg"). Accepts a builtin vector whose
// element type is an integer or float and whose shape is either
//   * 8x128 with a 32-bit element type, or
//   * 8x128xP where P == 32 / (element bit width).
def LLO_Vector : Type<
  And<[
    IsVectorTypePred,
    // getIntOrFloatBitWidth() asserts on element types that are neither
    // integer nor float (e.g. `index`), so reject those before the width
    // checks below run; And<> evaluates its predicates left to right.
    CPred<"::llvm::cast<::mlir::VectorType>($_self).getElementType().isIntOrFloat()">,
    Or<[
      And<[
        CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>{8, 128}">,
        CPred<"::llvm::cast<::mlir::VectorType>($_self).getElementType().getIntOrFloatBitWidth() == 32">
      ]>,
      CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>{"
            "8, 128, 32 / ::llvm::cast<::mlir::VectorType>($_self).getElementType().getIntOrFloatBitWidth()}">,
    ]>
  ]>,
  "native-sized vreg", "::mlir::VectorType">;
| |
// Vector mask type constraint. Only the shape is checked here (8x128, 8x128x2
// or 8x128x4); the element type is not constrained by this predicate.
// Only some of those types are actually allowed, depending on the target.
def LLO_VectorMask : Type<
  And<[
    IsVectorTypePred,
    Or<[
      CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>{8, 128}">,
      CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>{8, 128, 2}">,
      CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>{8, 128, 4}">,
    ]>
  ]>,
  // A non-empty summary keeps verifier diagnostics readable
  // ("... must be native-sized vector mask, but got ...").
  "native-sized vector mask", "::mlir::VectorType">;
| |
// `llo.alloca_vmem`: takes an i32 `num_words` attribute and returns an i32
// result annotated with the VmemAlloc effect. Resets the LLO_Op default by
// setting hasInstructionLowering = 0.
def LLO_AllocaVmemOp : LLO_Op<"alloca_vmem"> {
  let arguments = (ins I32Attr:$num_words);
  let results = (outs Res<I32, "", [VmemAlloc]>:$result);
  let hasInstructionLowering = 0;
  let assemblyFormat = [{ $num_words attr-dict `:` type($result) }];
}
| |
// `llo.alloca_smem`: takes an i32 `num_words` attribute and returns an i32
// result annotated with the SmemAlloc effect. Resets the LLO_Op default by
// setting hasInstructionLowering = 0.
def LLO_AllocaSmemOp : LLO_Op<"alloca_smem"> {
  let arguments = (ins I32Attr:$num_words);
  let results = (outs Res<I32, "", [SmemAlloc]>:$result);
  let hasInstructionLowering = 0;
  let assemblyFormat = [{ $num_words attr-dict `:` type($result) }];
}
| |
// `llo.alloca_sflag`: takes an i32 `num_flags` attribute and returns an i32.
// The builder expression calls AllocateScopedSflags with the attribute value.
// Unlike the vmem/smem allocas, the result carries no alloc effect.
def LLO_AllocaSyncFlagOp : LLO_Op<"alloca_sflag"> {
  let arguments = (ins I32Attr:$num_flags);
  let results = (outs I32:$result);
  let builderMethod = "$_builder.AllocateScopedSflags($_self.getNumFlags())";
  let assemblyFormat = [{ $num_flags attr-dict `:` type($result) }];
}
| |
// Gain latch mode enum (C++ enum ::mlir::llo::GainLatchMode). No specialized
// attribute is generated here; LLO_GainLatchModeEnum wraps it instead.
// Ops static_cast this enum to ::xla::jellyfish::GainLatchMode, so the
// numeric values below are significant.
// NOTE(review): kXposePackedBf16 = 10 / kNoXposePackedBf16 = 11 do not follow
// the even = NoXpose, odd = Xpose pattern of the other cases — confirm they
// match the target enum.
// TODO(apaszke): Expose more cases
def LLO_GainLatchMode : I32EnumAttr<"GainLatchMode", "Gain latch mode", [
    I32EnumAttrCase<"kNoXposeF32", 0, "f32">,
    I32EnumAttrCase<"kXposeF32", 1, "xpose.f32">,
    I32EnumAttrCase<"kNoXposeHiF32", 2, "hi.f32">,
    I32EnumAttrCase<"kXposeHiF32", 3, "xpose.hi.f32">,
    I32EnumAttrCase<"kNoXposeLowF32", 4, "low.f32">,
    I32EnumAttrCase<"kXposeLowF32", 5, "xpose.low.f32">,
    I32EnumAttrCase<"kNoXposeSoftMiddleEightF32", 6, "soft_middle_eight.f32">,
    I32EnumAttrCase<"kXposeSoftMiddleEightF32", 7, "xpose.soft_middle_eight.f32">,
    I32EnumAttrCase<"kNoXposeSoftLowEightF32", 8, "soft_low_eight.f32">,
    I32EnumAttrCase<"kXposeSoftLowEightF32", 9, "xpose.soft_low_eight.f32">,
    I32EnumAttrCase<"kXposePackedBf16", 10, "xpose.packed.bf16">,
    I32EnumAttrCase<"kNoXposePackedBf16", 11, "packed.bf16">,
    I32EnumAttrCase<"kNoXposeS8", 20, "s8">,
    I32EnumAttrCase<"kXposeS8", 21, "xpose.s8">,
  ]> {
  let cppNamespace = "::mlir::llo";
  let genSpecializedAttr = 0;
}
| |
// Attribute wrapper for LLO_GainLatchMode, printed as `<value>` under the
// `gain_latch_mode` mnemonic; models ::xla::jellyfish::GainLatchMode.
def LLO_GainLatchModeEnum
    : LLO_EnumAttr<LLO_GainLatchMode, "gain_latch_mode"> {
  let lloEnumType = "::xla::jellyfish::GainLatchMode";
  let assemblyFormat = "`<` $value `>`";
}
| |
// `llo.vlatch`: no results. The builder expression forwards the input value,
// the gain latch mode (cast to ::xla::jellyfish::GainLatchMode) and `mxu_id`
// (default 0) to `Vlatch`.
def LLO_VectorLatchOp : LLO_Op<"vlatch"> {
  let arguments = (ins
    LLO_Vector:$input,
    LLO_GainLatchModeEnum:$gain_latch_mode,
    DefaultValuedAttr<I32Attr, "0">:$mxu_id
  );
  let results = (outs);
  // Even though this is very close to the VectorLatch LLO instruction,
  // we use the builder to support non-native modes.
  let builderMethod = [{
    $_builder.Vlatch(
        GetLloValue($_self.getInput(), $_value_map),
        static_cast<::xla::jellyfish::GainLatchMode>(op.getGainLatchMode()),
        op.getMxuId());
  }];
}
| |
// Matrix multiplication mode enum (C++ enum ::mlir::llo::MatmulMode). Ops
// static_cast it to ::xla::jellyfish::MatmulMode, so the numeric values are
// significant. No specialized attribute is generated here.
def LLO_MatmulMode : I32EnumAttr<"MatmulMode", "Matrix multiplication mode", [
    I32EnumAttrCase<"kRound", 0, "round">,
    I32EnumAttrCase<"kHigh", 1, "high">,
    I32EnumAttrCase<"kLow", 2, "low">,
    I32EnumAttrCase<"kSoftMiddleEight", 3, "soft_middle_eight">,
    I32EnumAttrCase<"kSoftLowEight", 4, "soft_low_eight">,
  ]> {
  let cppNamespace = "::mlir::llo";
  let genSpecializedAttr = 0;
}
| |
// Attribute wrapper for LLO_MatmulMode, printed as `<value>` under the
// `matmul_mode` mnemonic; models ::xla::jellyfish::MatmulMode.
def LLO_MatmulModeEnum
    : LLO_EnumAttr<LLO_MatmulMode, "matmul_mode"> {
  let lloEnumType = "::xla::jellyfish::MatmulMode";
  let assemblyFormat = "`<` $value `>`";
}
| |
// Matrix multiplication data format enum (C++ enum
// ::mlir::llo::MatmulDataFormat). Ops static_cast it to
// ::xla::jellyfish::MatmulDataFormat, so the (non-contiguous) numeric values
// are significant. No specialized attribute is generated here.
def LLO_MatmulDataFormat : I32EnumAttr<"MatmulDataFormat", "Matrix multiplication data format", [
    I32EnumAttrCase<"kInvalid", 0, "invalid">,
    I32EnumAttrCase<"kF32", 1, "f32">,
    I32EnumAttrCase<"kBf16", 2, "bf16">,
    I32EnumAttrCase<"kS8", 6, "s8">,
  ]> {
  let cppNamespace = "::mlir::llo";
  let genSpecializedAttr = 0;
}
| |
// Attribute wrapper for LLO_MatmulDataFormat, printed as `<value>` under the
// `matmul_data_format` mnemonic; models ::xla::jellyfish::MatmulDataFormat.
def LLO_MatmulDataFormatEnum
    : LLO_EnumAttr<LLO_MatmulDataFormat, "matmul_data_format"> {
  let lloEnumType = "::xla::jellyfish::MatmulDataFormat";
  let assemblyFormat = "`<` $value `>`";
}
| |
// `llo.vmatmul`: no results. The builder expression picks between two
// `Vmatmul` calls based on `data_format`:
//   * kF32 (the default): passes `mode` cast to ::xla::jellyfish::MatmulMode;
//   * otherwise: passes `data_format` cast to ::xla::jellyfish::MatmulDataFormat
//     (`mode` is not forwarded on this path).
// In both cases `dwg` selects DoneWithGainsMode::kNormal vs kNone.
def LLO_VectorMatmulOp : LLO_Op<"vmatmul"> {
  let arguments = (ins
    LLO_Vector:$input,
    LLO_MatmulModeEnum:$mode,
    DefaultValuedAttr<I32Attr, "0">:$mxu_id,
    DefaultValuedAttr<BoolAttr, "false">:$dwg,
    DefaultValuedAttr<LLO_MatmulDataFormatEnum, "MatmulDataFormat::kF32">:$data_format
  );
  let results = (outs);
  // Even though this is very close to the VectorMatmul LLO instruction,
  // we use the builder to support non-native matmul modes.
  let builderMethod = [{
    $_self.getDataFormat() == MatmulDataFormat::kF32
      ? $_builder.Vmatmul(
          GetLloValue(op.getInput(), $_value_map),
          static_cast<::xla::jellyfish::MatmulMode>(op.getMode()),
          op.getDwg()
            ? xla::jellyfish::DoneWithGainsMode::kNormal
            : xla::jellyfish::DoneWithGainsMode::kNone,
          op.getMxuId())
      : $_builder.Vmatmul(
          GetLloValue(op.getInput(), $_value_map),
          static_cast<::xla::jellyfish::MatmulDataFormat>(op.getDataFormat()),
          op.getDwg()
            ? xla::jellyfish::DoneWithGainsMode::kNormal
            : xla::jellyfish::DoneWithGainsMode::kNone,
          op.getMxuId());
  }];
}
| |
// GLC+ only.
// `llo.vmatprep.subr`: no results. The builder expression first runs
// `VprepareForLatch` on the input with the gain latch mode, then passes the
// `.first` member of its result, the data format and `mxu_id` to
// `VmatprepSubr`.
def LLO_VectorMatprepSubrOp : LLO_Op<"vmatprep.subr"> {
  let arguments = (ins
    LLO_Vector:$input,
    LLO_GainLatchModeEnum:$gain_latch_mode,
    DefaultValuedAttr<LLO_MatmulDataFormatEnum, "MatmulDataFormat::kF32">:$data_format,
    DefaultValuedAttr<I32Attr, "0">:$mxu_id
  );
  let results = (outs);
  let builderMethod = [{
    $_builder.VmatprepSubr(
        $_builder.VprepareForLatch(
            GetLloValue(op.getInput(), $_value_map),
            static_cast<::xla::jellyfish::GainLatchMode>(op.getGainLatchMode())
        ).first,
        static_cast<::xla::jellyfish::MatmulDataFormat>(op.getDataFormat()),
        op.getMxuId());
  }];
}
| |
// GLC+ only.
// `llo.vmatprep.mubr`: no results. The builder expression first runs
// `VprepareForMatmul` on the input with the matmul mode, then passes the
// `.first` member of its result, the data format and `mxu_id` to
// `VmatprepMubr`.
def LLO_VectorMatprepMubrOp : LLO_Op<"vmatprep.mubr"> {
  let arguments = (ins
    LLO_Vector:$input,
    LLO_MatmulModeEnum:$mode,
    DefaultValuedAttr<LLO_MatmulDataFormatEnum, "MatmulDataFormat::kF32">:$data_format,
    DefaultValuedAttr<I32Attr, "0">:$mxu_id
  );
  let results = (outs);
  let builderMethod = [{
    $_builder.VmatprepMubr(
        $_builder.VprepareForMatmul(
            GetLloValue(op.getInput(), $_value_map),
            static_cast<::xla::jellyfish::MatmulMode>(op.getMode())).first,
        static_cast<::xla::jellyfish::MatmulDataFormat>(op.getDataFormat()),
        op.getMxuId());
  }];
}
| |
// GLC+ only.
// `llo.vlatchi`: no results. `latch_variant` selects the builder call:
// 3 -> Vlatch3, 2 -> Vlatch2, any other value -> Vlatch1. All three receive
// the input, the gain latch mode and `mxu_id`.
// The op declares a C++ verifier (hasVerifier = 1) whose body is not in this
// file; presumably it restricts `latch_variant` — confirm.
def LLO_VectorLatchIOp : LLO_Op<"vlatchi"> {
  let arguments = (ins
    LLO_Vector:$input,
    I32Attr:$latch_variant,
    LLO_GainLatchModeEnum:$gain_latch_mode,
    DefaultValuedAttr<I32Attr, "0">:$mxu_id
  );
  let results = (outs);
  let hasVerifier = 1;
  let builderMethod = [{
    $_self.getLatchVariant() == 3
      ? $_builder.Vlatch3(
          GetLloValue($_self.getInput(), $_value_map),
          static_cast<::xla::jellyfish::GainLatchMode>(op.getGainLatchMode()),
          op.getMxuId())
      : ($_self.getLatchVariant() == 2
        ? $_builder.Vlatch2(
            GetLloValue($_self.getInput(), $_value_map),
            static_cast<::xla::jellyfish::GainLatchMode>(op.getGainLatchMode()),
            op.getMxuId())
        : $_builder.Vlatch1(
            GetLloValue($_self.getInput(), $_value_map),
            static_cast<::xla::jellyfish::GainLatchMode>(op.getGainLatchMode()),
            op.getMxuId())
      );
  }];
}
| |
// GLC+ only.
// `llo.vmatmul.mubr`: no results. The builder expression runs
// `VprepareForMatmul` on the input with the matmul mode and passes the
// `.first` member of its result to `VmatmulMubr`, together with the data
// format, the done-with-gains mode derived from `dwg`, `mxu_id` and an empty
// `mrb_addr`.
def LLO_VectorMatmulMubrOp : LLO_Op<"vmatmul.mubr"> {
  let arguments = (ins
    LLO_Vector:$input,
    LLO_MatmulModeEnum:$mode,
    DefaultValuedAttr<I32Attr, "0">:$mxu_id,
    DefaultValuedAttr<BoolAttr, "false">:$dwg,
    DefaultValuedAttr<LLO_MatmulDataFormatEnum, "MatmulDataFormat::kF32">:$data_format
  );
  let results = (outs);
  let builderMethod = [{
    $_builder.VmatmulMubr(
        $_builder.VprepareForMatmul(
            GetLloValue(op.getInput(), $_value_map),
            static_cast<::xla::jellyfish::MatmulMode>(op.getMode())).first,
        static_cast<::xla::jellyfish::MatmulDataFormat>(op.getDataFormat()),
        op.getDwg()
          ? xla::jellyfish::DoneWithGainsMode::kNormal
          : xla::jellyfish::DoneWithGainsMode::kNone,
        op.getMxuId(),
        /*mrb_addr=*/std::nullopt);
  }];
}
| |
// `llo.vmatres`: takes no operands and yields one LLO_Vector. Lowered via
// LloInstruction::CreateVectorMatres with the data format, a fixed
// `mrb_addr` of 0, `mxu_id` and the builder's region.
// TODO(apaszke): Add support for the data format attribute
// NOTE(review): `data_format` is already declared and forwarded below, so the
// TODO above may be stale — confirm what is still missing.
def LLO_VectorMatresOp : LLO_Op<"vmatres"> {
  let arguments = (ins
    DefaultValuedAttr<LLO_MatmulDataFormatEnum, "MatmulDataFormat::kF32">:$data_format,
    DefaultValuedAttr<I32Attr, "0">:$mxu_id
  );
  let results = (outs LLO_Vector:$result);
  let customInstructionLowering = [{
    return ::xla::jellyfish::LloInstruction::CreateVectorMatres(
        static_cast<::xla::jellyfish::MatmulDataFormat>(op.getDataFormat()),
        /*mrb_addr=*/0,
        op.getMxuId(), b.region());
  }];
}
| |
// `llo.vdwg`: no operands, no results. Lowered via
// LloInstruction::CreateVectorDoneWithGains, always with
// DoneWithGainsMode::kNormal, plus `mxu_id` and the builder's region.
// TODO(apaszke): Support the transposed variant
def LLO_VectorDoneWithGainsOp : LLO_Op<"vdwg"> {
  let arguments = (ins
    DefaultValuedAttr<I32Attr, "0">:$mxu_id
  );
  let results = (outs);
  let customInstructionLowering = [{
    return ::xla::jellyfish::LloInstruction::CreateVectorDoneWithGains(
        ::xla::jellyfish::DoneWithGainsMode::kNormal,
        op.getMxuId(),
        b.region());
  }];
}
| |
// Transpose mode enum (C++ enum ::mlir::llo::VxposeMode). No specialized
// attribute is generated here; LLO_VxposeModeEnum wraps it instead.
def LLO_VxposeMode : I32EnumAttr<"VxposeMode", "Transpose mode", [
    I32EnumAttrCase<"kB32", 0, "b32">,
    I32EnumAttrCase<"kCompressedB16", 1, "b16.c">,
    I32EnumAttrCase<"kCompressedB8", 2, "b8.c">,
  ]> {
  let cppNamespace = "::mlir::llo";
  let genSpecializedAttr = 0;
}
| |
// Attribute wrapper for LLO_VxposeMode, printed as `<value>` under the
// `vxpose_mode` mnemonic; models ::xla::jellyfish::VxposeMode.
def LLO_VxposeModeEnum
    : LLO_EnumAttr<LLO_VxposeMode, "vxpose_mode"> {
  let lloEnumType = "::xla::jellyfish::VxposeMode";
  let assemblyFormat = "`<` $value `>`";
}
| |
// `llo.vxpose`: no results. Takes one vector and an i32 `width` operand plus
// mode / chunk / XLU attributes. Sets neither builderMethod nor
// customInstructionLowering, so it relies on the LLO_Op defaults.
def LLO_VectorTransposeOp : LLO_Op<"vxpose"> {
  let arguments = (ins
    LLO_Vector:$input,
    LLO_VxposeModeEnum:$mode,
    I32:$width,
    I32Attr:$chunk_id,
    I32Attr:$number_of_chunks,
    DefaultValuedAttr<I32Attr, "0">:$xlu_id,
    OptionalAttr<I32Attr>:$source_bus
  );
  let results = (outs);
}
| |
// `llo.vxpose.c.b16`: no results. Same shape as `llo.vxpose` but with two
// vector operands (`first`, `second`) and no mode attribute. Relies on the
// LLO_Op defaults for lowering.
def LLO_VectorTransposeBinaryCompressedB16Op : LLO_Op<"vxpose.c.b16"> {
  let arguments = (ins
    LLO_Vector:$first,
    LLO_Vector:$second,
    I32:$width,
    I32Attr:$chunk_id,
    I32Attr:$number_of_chunks,
    DefaultValuedAttr<I32Attr, "0">:$xlu_id,
    OptionalAttr<I32Attr>:$source_bus
  );
  let results = (outs);
}
| |
// `llo.vxpose.result`: no operands; yields one LLO_Vector for the XLU named
// by `xlu_id` (default 0). Relies on the LLO_Op defaults for lowering.
def LLO_VectorTransposeResultOp : LLO_Op<"vxpose.result"> {
  let arguments = (ins
    DefaultValuedAttr<I32Attr, "0">:$xlu_id
  );
  let results = (outs LLO_Vector:$result);
}
| |
// `llo.saddr_scaled`: i32 address and i32 offset in, i32 out. The builder
// expression forwards both operands and the `multiplier_in_bytes` attribute
// to `AddrScaled`.
def LLO_AddrScaledOp : LLO_Op<"saddr_scaled", [SameOperandsAndResultType]> {
  let arguments = (ins
    I32:$address,
    I32:$offset_in_words,
    I32Attr:$multiplier_in_bytes
  );
  let results = (outs I32:$result);
  let builderMethod = [{
    $_builder.AddrScaled(
        GetLloValue(op.getAddress(), $_value_map),
        GetLloValue(op.getOffsetInWords(), $_value_map),
        op.getMultiplierInBytes());
  }];
}
| |
// TODO(apaszke): Deduplicate ScalarAddress ops.
// `llo.saddr.vmem`: i32 address and i32 offset in, i32 out. Lowered via
// LloInstruction::CreateScalarAddress with MemorySpace::kVmem.
def LLO_ScalarAddressVmemOp : LLO_Op<"saddr.vmem", [SameOperandsAndResultType]> {
  let arguments = (ins
    I32:$address,
    I32:$offset_in_words
  );
  let results = (outs I32:$result);
  let customInstructionLowering = [{
    return ::xla::jellyfish::LloInstruction::CreateScalarAddress(
        GetLloValue(op.getAddress(), $_value_map),
        GetLloValue(op.getOffsetInWords(), $_value_map),
        ::xla::jellyfish::MemorySpace::kVmem,
        b.region());
  }];
}
| |
// `llo.saddr.smem`: i32 address and i32 offset in, i32 out. Lowered via
// LloInstruction::CreateScalarAddress with MemorySpace::kSmem.
def LLO_ScalarAddressSmemOp : LLO_Op<"saddr.smem", [SameOperandsAndResultType]> {
  let arguments = (ins
    I32:$address,
    I32:$offset_in_words
  );
  let results = (outs I32:$result);
  let customInstructionLowering = [{
    return ::xla::jellyfish::LloInstruction::CreateScalarAddress(
        GetLloValue(op.getAddress(), $_value_map),
        GetLloValue(op.getOffsetInWords(), $_value_map),
        ::xla::jellyfish::MemorySpace::kSmem,
        b.region());
  }];
}
| |
// `llo.saddr.sflag`: i32 address and i32 offset in, i32 out. Lowered via
// LloInstruction::CreateScalarAddress with MemorySpace::kSflag.
def LLO_ScalarAddressSflagOp : LLO_Op<"saddr.sflag", [SameOperandsAndResultType]> {
  let arguments = (ins
    I32:$address,
    I32:$offset_in_words
  );
  let results = (outs I32:$result);
  let customInstructionLowering = [{
    return ::xla::jellyfish::LloInstruction::CreateScalarAddress(
        GetLloValue(op.getAddress(), $_value_map),
        GetLloValue(op.getOffsetInWords(), $_value_map),
        ::xla::jellyfish::MemorySpace::kSflag,
        b.region());
  }];
}
| |
// `llo.saddr.cmem`: i32 address and i32 offset in, i32 out. Lowered via
// LloInstruction::CreateScalarAddress with MemorySpace::kCmem.
def LLO_ScalarAddressCmemOp : LLO_Op<"saddr.cmem", [SameOperandsAndResultType]> {
  let arguments = (ins
    I32:$address,
    I32:$offset_in_words
  );
  let results = (outs I32:$result);
  let customInstructionLowering = [{
    return ::xla::jellyfish::LloInstruction::CreateScalarAddress(
        GetLloValue(op.getAddress(), $_value_map),
        GetLloValue(op.getOffsetInWords(), $_value_map),
        ::xla::jellyfish::MemorySpace::kCmem,
        b.region());
  }];
}
| |
// `llo.saddr.hbm`: i32 address and i32 offset in, i32 out. Lowered via
// LloInstruction::CreateScalarAddress with MemorySpace::kHbm.
def LLO_ScalarAddressHbmOp : LLO_Op<"saddr.hbm", [SameOperandsAndResultType]> {
  let arguments = (ins
    I32:$address,
    I32:$offset_in_words
  );
  let results = (outs I32:$result);
  let customInstructionLowering = [{
    return ::xla::jellyfish::LloInstruction::CreateScalarAddress(
        GetLloValue(op.getAddress(), $_value_map),
        GetLloValue(op.getOffsetInWords(), $_value_map),
        ::xla::jellyfish::MemorySpace::kHbm,
        b.region());
  }];
}
| |
// `llo.vector_load_slane_stride`: marked VmemRead. Loads from an i32 address
// with the given sublane stride and count; `sublane_mask` is optional and the
// builder receives nullptr when it is absent.
// NOTE(review): the result is AnyType here, whereas `llo.vector_load`
// constrains its result to LLO_Vector — confirm this is intended.
def LLO_VldWithArbitrarySlaneStrideOp
    : LLO_Op<"vector_load_slane_stride", [VmemRead]> {
  let arguments = (ins
    I32:$address,
    I64Attr:$sublane_stride,  // In sublane-sized units
    I64Attr:$sublane_count,
    Optional<UI32>:$sublane_mask
  );
  let results = (outs AnyType:$result);
  let assemblyFormat = [{
    $address (`masked` $sublane_mask^)? attr-dict `:` type($address) `->` type($result)
  }];
  let builderMethod = [{
    $_builder.LdWithArbitrarySlaneStride(
        GetLloValue(op.getAddress(), $_value_map),
        op.getSublaneStride(),
        op.getSublaneCount(),
        op.getSublaneMask() ? GetLloValue(op.getSublaneMask(), $_value_map) : nullptr);
  }];
}
| |
// `llo.vector_store_slane_stride`: marked VmemWrite, no results. Stores
// `to_store` at an i32 address with the given sublane stride and count;
// `sublane_mask` is optional and the builder receives nullptr when absent.
def LLO_VstWithArbitrarySlaneStrideOp
    : LLO_Op<"vector_store_slane_stride", [VmemWrite]> {
  let arguments = (ins
    I32:$address,
    LLO_Vector:$to_store,
    I64Attr:$sublane_stride,  // In sublane-sized units
    I64Attr:$sublane_count,
    Optional<UI32>:$sublane_mask
  );
  let results = (outs);
  let assemblyFormat = [{
    $to_store `into` $address (`masked` $sublane_mask^)? attr-dict `:`
    type($to_store) `into` type($address)
  }];
  let builderMethod = [{
    $_builder.VstWithArbitrarySlaneStride(
        GetLloValue(op.getAddress(), $_value_map),
        GetLloValue(op.getToStore(), $_value_map),
        op.getSublaneStride(),
        op.getSublaneCount(),
        op.getSublaneMask() ? GetLloValue(op.getSublaneMask(), $_value_map) : nullptr);
  }];
}
| |
// `llo.vector_store_masked_slane_stride`: marked VmemWrite, no results. Like
// the unmasked slane-stride store but with a mandatory vector `mask`; the
// optional `sublane_mask` is spelled `slmask` in the assembly and the builder
// receives nullptr when it is absent.
def LLO_VstMaskedWithArbitrarySlaneStrideOp
    : LLO_Op<"vector_store_masked_slane_stride", [VmemWrite]> {
  let arguments = (ins
    I32:$address,
    LLO_VectorMask:$mask,
    LLO_Vector:$to_store,
    I64Attr:$sublane_stride,  // In sublane-sized units
    I64Attr:$sublane_count,
    Optional<UI32>:$sublane_mask
  );
  let results = (outs);
  let assemblyFormat = [{
    $to_store `into` $address `masked` $mask (`slmask` $sublane_mask^)? attr-dict `:`
    type($to_store) `into` type($address) `,` type($mask)
  }];
  let builderMethod = [{
    $_builder.VstMaskedWithArbitrarySlaneStride(
        GetLloValue(op.getAddress(), $_value_map),
        GetLloValue(op.getMask(), $_value_map),
        GetLloValue(op.getToStore(), $_value_map),
        op.getSublaneStride(),
        op.getSublaneCount(),
        op.getSublaneMask() ? GetLloValue(op.getSublaneMask(), $_value_map) : nullptr);
  }];
}
| |
// `llo.vector_load`: marked VmemRead; yields one LLO_Vector. It has two
// optional operands (`displacement`, `sublane_mask`), hence
// AttrSizedOperandSegments. Both stride attributes default to 1.
def LLO_VectorLoadOp
    : LLO_Op<"vector_load", [AttrSizedOperandSegments, VmemRead]> {
  let arguments = (ins
    I32:$address,
    Optional<I32>:$displacement,  // In sublane-sized units
    Optional<UI32>:$sublane_mask,
    DefaultValuedAttr<I64Attr, "1">:$sublane_stride,  // In sublane-sized units
    DefaultValuedAttr<I64Attr, "1">:$sublanes_per_stride
  );
  let results = (outs LLO_Vector:$result);

  let assemblyFormat = [{
    $address (`+` $displacement^)? (`masked` $sublane_mask^)? attr-dict `:`
    type($address) `->` type($result)
  }];
}
| |
// `llo.vector_store`: marked VmemWrite, no results. It has two optional
// operands (`displacement`, `sublane_mask`), hence AttrSizedOperandSegments.
// Both stride attributes default to 1.
def LLO_VectorStoreOp
    : LLO_Op<"vector_store", [AttrSizedOperandSegments, VmemWrite]> {
  let arguments = (ins
    I32:$address,
    Optional<I32>:$displacement,  // In sublane-sized units
    LLO_Vector:$to_store,
    Optional<UI32>:$sublane_mask,
    DefaultValuedAttr<I64Attr, "1">:$sublane_stride,  // In sublane-sized units
    DefaultValuedAttr<I64Attr, "1">:$sublanes_per_stride
  );
  let results = (outs);
  let assemblyFormat = [{
    $to_store `into` $address (`+` $displacement^)? (`masked` $sublane_mask^)? attr-dict `:`
    type($to_store) `into` type($address)
  }];
}
| |
// `llo.vector_store_masked`: marked VmemWrite, no results. Like
// `llo.vector_store` but with a mandatory vector `mask`; the optional
// `sublane_mask` is spelled `slmask` in the assembly. Two optional operands,
// hence AttrSizedOperandSegments.
def LLO_VectorStoreMaskedOp
    : LLO_Op<"vector_store_masked", [AttrSizedOperandSegments, VmemWrite]> {
  let arguments = (ins
    I32:$address,
    Optional<I32>:$displacement,  // In sublane-sized units
    LLO_VectorMask:$mask,
    LLO_Vector:$to_store,
    DefaultValuedAttr<I64Attr, "1">:$sublane_stride,  // In sublane-sized units
    DefaultValuedAttr<I64Attr, "1">:$sublanes_per_stride,
    Optional<UI32>:$sublane_mask
  );
  let results = (outs);
  let assemblyFormat = [{
    $to_store `into` $address (`+` $displacement^)? `masked` $mask (`slmask` $sublane_mask^)? attr-dict `:`
    type($to_store) `into` type($address) `,` type($mask)
  }];
}
| |
// `llo.constant`: Pure op holding an arbitrary `value` attribute and yielding
// a result of any type. The value is printed as part of attr-dict. Resets the
// LLO_Op default by setting hasInstructionLowering = 0.
def LLO_ConstantOp : LLO_Op<"constant", [Pure]> {
  let arguments = (ins AnyAttr:$value);
  let results = (outs AnyType:$result);
  let hasInstructionLowering = 0;
  let assemblyFormat = "attr-dict `:` type($result)";
}
| |
| |
// Base class for cross-lane reductions. `instr` is pasted after
// `LloInstruction::Create` to name the factory for the reduction instruction;
// that instruction is then wrapped in CreateVectorXlaneResult for the same
// `xlu_id`.
// NOTE(review): `xlu_id` / `source_bus` are I64Attr here, but `source_bus` is
// converted to std::optional<int32_t> below and the transpose ops declare the
// same attributes as I32Attr — confirm the width mismatch is intended.
class LLO_CrossLaneReductionOp<string name, string instr>
    : LLO_Op<name, [Pure, SameOperandsAndResultType]> {
  let arguments = (ins
    LLO_Vector:$input,
    DefaultValuedAttr<I64Attr, "0">:$xlu_id,
    OptionalAttr<I64Attr>:$source_bus
  );
  let results = (outs LLO_Vector:$result);
  let assemblyFormat = [{
    $input attr-dict `:` type($input)
  }];
  let customInstructionLowering = [{
    return ::xla::jellyfish::LloInstruction::CreateVectorXlaneResult(
        b.Instruction(::xla::jellyfish::LloInstruction::Create}] # instr # [{(
            GetLloValue(op.getInput(), $_value_map),
            op.getXluId(),
            op.getSourceBus()
              ? std::optional<int32_t>(op.getSourceBus().value())
              : std::nullopt,
            b.region())),
        op.getXluId(), b.region());
  }];
}

// Concrete cross-lane reductions: f32 add and f32 max.
def LLO_VectorAddReduceF32Op
    : LLO_CrossLaneReductionOp<"vadd.xlane.f32", "VectorAddReduceF32">;
def LLO_VectorMaxReduceF32Op
    : LLO_CrossLaneReductionOp<"vmax.xlane.f32", "VectorMaxReduceF32">;
| |
// Base class for sublane reductions. The builder expression calls
// `$_builder.Vslane<method_suffix>` on the single operand.
class LLO_SublaneReductionOp<string name, string method_suffix>
    : LLO_Op<name, [Pure, SameOperandsAndResultType]> {
  let arguments = (ins
    LLO_Vector:$input
  );
  let results = (outs LLO_Vector:$result);
  let assemblyFormat = [{
    $input attr-dict `:` type($input)
  }];
  let builderMethod = "$_builder.Vslane" # method_suffix #
                      "(GetLloValue(op.getOperand(), $_value_map))";
}

// Concrete sublane reductions: VslaneSumF32 and VslaneMaxF32.
def LLO_VectorAddSublaneReduceF32Op
    : LLO_SublaneReductionOp<"vadd.slane.f32", "SumF32">;
def LLO_VectorMaxSublaneReduceF32Op
    : LLO_SublaneReductionOp<"vmax.slane.f32", "MaxF32">;
| |
// `llo.vrot.lane`: Pure, result type equals the input type. The builder
// expression calls `Vrotate` with the input, `amount` as an S32 immediate, the
// optional `stride` as an S32 immediate (std::nullopt when absent) and
// BitDataFormat::kB32, then wraps the result in `Vpermuteres`.
def LLO_VectorRotateOp : LLO_Op<"vrot.lane", [Pure, SameOperandsAndResultType]> {
  let arguments = (ins
    LLO_Vector:$input,
    SI32Attr:$amount,
    OptionalAttr<SI32Attr>:$stride
  );
  let results = (outs LLO_Vector:$result);
  let assemblyFormat = [{ $input `,` $amount attr-dict (`stride` $stride^)? `:` type($input) }];
  let builderMethod = [{$_builder.Vpermuteres($_builder.Vrotate(
      GetLloValue(op.getInput(), $_value_map),
      b.SimmS32(op.getAmount()),
      op.getStride()
        ? std::optional<::xla::jellyfish::LloValue*>(
            b.SimmS32(op.getStride().value()))
        : std::nullopt,
      ::xla::jellyfish::BitDataFormat::kB32));
  }];
}
// `llo.vrot.slane`: Pure, result type equals the input type. The builder
// expression forwards the input and the `amount` attribute to
// `VslaneRotateTZ`.
def LLO_VectorSublaneRotateOp : LLO_Op<"vrot.slane", [Pure, SameOperandsAndResultType]> {
  let arguments = (ins
    LLO_Vector:$input,
    I32Attr:$amount
  );
  let results = (outs LLO_Vector:$result);
  let builderMethod = [{ $_builder.VslaneRotateTZ(GetLloValue($_self.getInput(), $_value_map), $_self.getAmount()) }];
  let assemblyFormat = [{ $input `,` $amount attr-dict `:` type($input) }];
}
| |
| |
// `llo.vcmask`: Pure op producing a vector mask from four i64 attributes,
// printed as `[sublane_start:sublane_end][lane_start:lane_end]`. Both end
// bounds are inclusive. The builder expression forwards all four values to
// `CreateVmask`.
def LLO_VectorCreateMaskOp : LLO_Op<"vcmask", [Pure]> {
  let arguments = (ins
    I64Attr:$sublane_start,
    I64Attr:$sublane_end,  // inclusive
    I64Attr:$lane_start,
    I64Attr:$lane_end  // inclusive
  );
  let results = (outs LLO_VectorMask:$result);
  let assemblyFormat = [{
    `[` $sublane_start `:` $sublane_end `]` `[` $lane_start `:` $lane_end `]`
    attr-dict `:` type($result)
  }];
  let builderMethod = [{
    $_builder.CreateVmask(op.getSublaneStart(), op.getSublaneEnd(),
                          op.getLaneStart(), op.getLaneEnd())
  }];
}
| |
// `llo.vsmask`: Pure op producing a vector mask from an i32 `limit` operand.
// Relies on the LLO_Op defaults for lowering.
def LLO_VectorCreateSublaneMaskOp : LLO_Op<"vsmask", [Pure]> {
  let arguments = (ins
    I32:$limit
  );
  let results = (outs LLO_VectorMask:$result);
  let assemblyFormat = [{
    $limit attr-dict `:` type($limit) `->` type($result)
  }];
}
| |
// `llo.vlmask`: Pure op producing a vector mask from an i32 `limit` operand.
// Relies on the LLO_Op defaults for lowering.
def LLO_VectorCreateLaneMaskOp : LLO_Op<"vlmask", [Pure]> {
  let arguments = (ins
    I32:$limit
  );
  let results = (outs LLO_VectorMask:$result);
  let assemblyFormat = [{
    $limit attr-dict `:` type($limit) `->` type($result)
  }];
}
| |
// Base class for binary ops on vector masks. `opcode` is pasted after
// `::xla::jellyfish::LloOpcode::` and handed to
// LloInstruction::CreateVectorMaskBinop together with both operands.
class LLO_VectorMaskBinOp<string name, string opcode,
                          list<Trait> traits = [Pure, SameOperandsAndResultType]>
    : LLO_Op<name, traits> {
  let arguments = (ins
    LLO_VectorMask:$lhs,
    LLO_VectorMask:$rhs
  );
  let results = (outs LLO_VectorMask:$result);

  let assemblyFormat = [{ $lhs `,` $rhs attr-dict `:` type($result) }];
  let customInstructionLowering = [{
    return ::xla::jellyfish::LloInstruction::CreateVectorMaskBinop(
        ::xla::jellyfish::LloOpcode::}] # opcode # [{,
        GetLloValue($_self->getOperand(0), $_value_map),
        GetLloValue($_self->getOperand(1), $_value_map),
        $_region);
  }];
}
// Concrete mask binary ops: and, or, xor.
def LLO_VectorMaskAndOp : LLO_VectorMaskBinOp<"vmand", "kVectorMaskAnd">;
def LLO_VectorMaskOrOp : LLO_VectorMaskBinOp<"vmor", "kVectorMaskOr">;
def LLO_VectorMaskXorOp : LLO_VectorMaskBinOp<"vmxor", "kVectorMaskXor">;
| |
// Base class for unary ops on vector masks. `opcode` is pasted after
// `::xla::jellyfish::LloOpcode::` and handed to
// LloInstruction::CreateVectorMaskUnop together with the operand.
class LLO_VectorMaskUnOp<string name, string opcode,
                         list<Trait> traits = [Pure, SameOperandsAndResultType]>
    : LLO_Op<name, traits> {
  let arguments = (ins LLO_VectorMask:$operand);
  let results = (outs LLO_VectorMask:$result);

  let assemblyFormat = [{ $operand attr-dict `:` type($result) }];
  let customInstructionLowering = [{
    return ::xla::jellyfish::LloInstruction::CreateVectorMaskUnop(
        ::xla::jellyfish::LloOpcode::}] # opcode # [{,
        GetLloValue($_self->getOperand(0), $_value_map),
        $_region);
  }];
}
// Concrete mask unary op: negate.
def LLO_VectorMaskNegateOp : LLO_VectorMaskUnOp<"vmneg", "kVectorMaskNegate">;
| |
// `llo.vselect`: Pure op taking a vector mask and two vectors (`if_true`,
// `if_false`) and yielding one vector. All operand and result types are
// spelled out in the assembly. Relies on the LLO_Op defaults for lowering.
def LLO_VectorSelectOp : LLO_Op<"vselect", [Pure]> {
  let arguments = (ins
    LLO_VectorMask:$mask,
    LLO_Vector:$if_true,
    LLO_Vector:$if_false
  );
  let results = (outs LLO_Vector:$result);
  let assemblyFormat = [{
    $mask `,` $if_true `,` $if_false attr-dict `:`
    `(` type($mask) `,` type($if_true) `,` type($if_false) `)` `->` type($result)
  }];
}
| |
// Base class for binary vector ops. `opcode` is pasted after
// `::xla::jellyfish::LloOpcode::` and handed to
// LloInstruction::CreateVectorBinop together with both operands. Subclasses
// that set `builderMethod` pass a placeholder opcode string instead.
class LLO_VectorBinOp<string name, string opcode,
                      list<Trait> traits = [Pure, SameOperandsAndResultType]>
    : LLO_Op<name, traits> {
  let arguments = (ins
    LLO_Vector:$lhs,
    LLO_Vector:$rhs
  );
  let results = (outs LLO_Vector:$result);

  let assemblyFormat = [{ $lhs `,` $rhs attr-dict `:` type($result) }];
  let customInstructionLowering = [{
    return ::xla::jellyfish::LloInstruction::CreateVectorBinop(
        ::xla::jellyfish::LloOpcode::}] # opcode # [{,
        GetLloValue($_self->getOperand(0), $_value_map),
        GetLloValue($_self->getOperand(1), $_value_map),
        $_region);
  }];
}
| |
// Variant of LLO_VectorBinOp whose default traits are just [Pure] (no
// SameOperandsAndResultType), so the assembly spells out each operand type
// and the result type separately.
class LLO_VectorBinOpAnyType<string name, string opcode,
                             list<Trait> traits = [Pure]>
    : LLO_VectorBinOp<name, opcode, traits> {
  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
  }];
}
| |
// Binary vector ops lowered directly through the LloOpcode named in the
// second template argument.
def LLO_VectorAddF32Op
    : LLO_VectorBinOp<"vadd.f32", "kVectorAddF32">;
def LLO_VectorSubF32Op
    : LLO_VectorBinOp<"vsub.f32", "kVectorSubtractF32">;
def LLO_VectorMulF32Op
    : LLO_VectorBinOp<"vmul.f32", "kVectorMultiplyF32">;
def LLO_VectorMaxF32Op
    : LLO_VectorBinOp<"vmax.f32", "kVectorMaximumF32">;
def LLO_VectorMinF32Op
    : LLO_VectorBinOp<"vmin.f32", "kVectorMinimumF32">;
def LLO_VectorPowF32Op
    : LLO_VectorBinOp<"vpow.f32", "kVectorPowF32">;
def LLO_VectorAddS32Op
    : LLO_VectorBinOp<"vadd.s32", "kVectorAddS32">;
def LLO_VectorSubS32Op
    : LLO_VectorBinOp<"vsub.s32", "kVectorSubtractS32">;
// s32 multiply / divide / remainder. Each sets `builderMethod` and passes the
// placeholder opcode "UNUSED" to the base class.
// Note that `vmul.s32` calls the builder's `VmulU32`.
def LLO_VectorMulS32Op : LLO_VectorBinOp<"vmul.s32", "UNUSED"> {
  let builderMethod = [{
    $_builder.VmulU32(GetLloValue(op.getLhs(), $_value_map),
                      GetLloValue(op.getRhs(), $_value_map))
  }];
}
def LLO_VectorDivS32Op : LLO_VectorBinOp<"vdiv.s32", "UNUSED"> {
  let builderMethod = [{
    $_builder.VdivS32(GetLloValue(op.getLhs(), $_value_map),
                      GetLloValue(op.getRhs(), $_value_map))
  }];
}
def LLO_VectorRemS32Op : LLO_VectorBinOp<"vrem.s32", "UNUSED"> {
  let builderMethod = [{
    $_builder.VremS32(GetLloValue(op.getLhs(), $_value_map),
                      GetLloValue(op.getRhs(), $_value_map))
  }];
}
// Shifts use LLO_VectorBinOpAnyType, so operand and result types are not
// required to match.
def LLO_VectorShiftLeftLogicalOp
    : LLO_VectorBinOpAnyType<"vshll", "kVectorShiftLeftLogical">;
def LLO_VectorShiftRightLogicalOp
    : LLO_VectorBinOpAnyType<"vshrl", "kVectorShiftRightLogical">;
def LLO_VectorShiftRightArithmeticOp
    : LLO_VectorBinOpAnyType<"vshra", "kVectorShiftRightArithmetic">;
// Bitwise u32 ops lowered directly through their LloOpcode.
def LLO_VectorAndU32Op
    : LLO_VectorBinOp<"vand.u32", "kVectorAndU32">;
def LLO_VectorOrU32Op
    : LLO_VectorBinOp<"vor.u32", "kVectorOrU32">;
def LLO_VectorXOrU32Op
    : LLO_VectorBinOp<"vxor.u32", "kVectorXorU32">;
// s32 max / min. Each sets `builderMethod` and passes the placeholder opcode
// "UNUSED" to the base class.
def LLO_VectorMaxS32Op : LLO_VectorBinOp<"vmax.s32", "UNUSED"> {
  let builderMethod = [{
    $_builder.VmaxS32(GetLloValue(op.getLhs(), $_value_map),
                      GetLloValue(op.getRhs(), $_value_map))
  }];
}
def LLO_VectorMinS32Op : LLO_VectorBinOp<"vmin.s32", "UNUSED"> {
  let builderMethod = [{
    $_builder.VminS32(GetLloValue(op.getLhs(), $_value_map),
                      GetLloValue(op.getRhs(), $_value_map))
  }];
}
| |
// Base class for unary vector ops whose result type equals the operand type.
// `opcode` is pasted after `::xla::jellyfish::LloOpcode::` and handed to
// LloInstruction::CreateVectorUnop together with the operand. Subclasses that
// set `builderMethod` pass a placeholder opcode string instead.
class LLO_VectorUnOp<string name, string opcode>
    : LLO_Op<name, [Pure, SameOperandsAndResultType]> {
  let arguments = (ins LLO_Vector:$operand);
  let results = (outs LLO_Vector:$result);

  let assemblyFormat = [{ $operand attr-dict `:` type($result) }];
  let customInstructionLowering = [{
    return ::xla::jellyfish::LloInstruction::CreateVectorUnop(
        ::xla::jellyfish::LloOpcode::}] # opcode # [{,
        GetLloValue($_self->getOperand(0), $_value_map),
        $_region);
  }];
}
| |
// f32 unary ops. All but `vpow2.f32` set a builder call and pass the
// placeholder opcode "unused"; `vpow2.f32` is lowered directly through
// kVectorPow2F32AndPop.
def LLO_VectorRsqrtF32Op : LLO_VectorUnOp<"vrsqrt.f32", "unused"> {
  let builderMethod =
      "$_builder.VrsqrtNr(GetLloValue(op.getOperand(), $_value_map))";
}
def LLO_VectorRecipF32Op : LLO_VectorUnOp<"vrecip.f32", "unused"> {
  let builderMethod =
      "$_builder.VrecpNr(GetLloValue(op.getOperand(), $_value_map))";
}
def LLO_VectorSqrtF32Op : LLO_VectorUnOp<"vsqrt.f32", "unused"> {
  let builderMethod =
      "$_builder.VsqrtF32(GetLloValue(op.getOperand(), $_value_map))";
}
def LLO_VectorExpF32Op : LLO_VectorUnOp<"vexp.f32", "unused"> {
  let builderMethod =
      "$_builder.Vexp(GetLloValue(op.getOperand(), $_value_map))";
}
def LLO_VectorPow2F32Op
    : LLO_VectorUnOp<"vpow2.f32", "kVectorPow2F32AndPop">;
// f32 trigonometric / hyperbolic / logarithm ops. Each sets a builder call
// and passes the placeholder opcode "unused". `Vtanh` additionally receives
// xla::F32 as its first argument.
def LLO_VectorCosF32Op : LLO_VectorUnOp<"vcos.f32", "unused"> {
  let builderMethod =
      "$_builder.Vcos(GetLloValue(op.getOperand(), $_value_map))";
}
def LLO_VectorSinF32Op : LLO_VectorUnOp<"vsin.f32", "unused"> {
  let builderMethod =
      "$_builder.Vsin(GetLloValue(op.getOperand(), $_value_map))";
}
def LLO_VectorTanhF32Op : LLO_VectorUnOp<"vtanh.f32", "unused"> {
  let builderMethod =
      "$_builder.Vtanh(xla::F32, GetLloValue(op.getOperand(), $_value_map))";
}
def LLO_VectorLogF32Op : LLO_VectorUnOp<"vln.f32", "unused"> {
  let builderMethod =
      "$_builder.Vln(GetLloValue(op.getOperand(), $_value_map))";
}
def LLO_VectorLog1pF32Op : LLO_VectorUnOp<"vln1p.f32", "unused"> {
  let builderMethod =
      "$_builder.Vln1p(GetLloValue(op.getOperand(), $_value_map))";
}
// Negation, absolute value and rounding ops. Each sets a builder call and
// passes a placeholder opcode ("unused", or "UNUSED" for `vneg.s32`).
def LLO_VectorNegF32Op : LLO_VectorUnOp<"vneg.f32", "unused"> {
  let builderMethod =
      "$_builder.VnegF32(GetLloValue(op.getOperand(), $_value_map))";
}
def LLO_VectorNegS32Op : LLO_VectorUnOp<"vneg.s32", "UNUSED"> {
  let builderMethod =
      "$_builder.VnegS32(GetLloValue(op.getOperand(), $_value_map))";
}
def LLO_VectorAbsF32Op : LLO_VectorUnOp<"vabs.f32", "unused"> {
  let builderMethod =
      "$_builder.VabsF32(GetLloValue(op.getOperand(), $_value_map))";
}
def LLO_VectorRoundF32Op : LLO_VectorUnOp<"vround.f32", "unused"> {
  let builderMethod =
      "$_builder.VroundF32(GetLloValue(op.getOperand(), $_value_map))";
}
def LLO_VectorRoundEvenF32Op : LLO_VectorUnOp<"vround.nearest_even.f32", "unused"> {
  let builderMethod =
      "$_builder.VroundNearestEvenF32(GetLloValue(op.getOperand(), $_value_map))";
}
| |
// Base class for vector comparisons: two vectors in, one vector mask out.
// The lowering builds an ::xla::Comparison from three pasted names:
//   direction -> ::xla::ComparisonDirection::<direction>
//   xlaType   -> ::xla::<xlaType>
//   order     -> ::xla::ComparisonOrder::<order>
// and passes it to LloInstruction::CreateVectorCompare with both operands.
class LLO_VectorCompareOp<string name, string direction, string xlaType, string order>
    : LLO_Op<name, [Pure]> {
  let arguments = (ins
    LLO_Vector:$lhs,
    LLO_Vector:$rhs
  );
  let results = (outs LLO_VectorMask:$result);

  let assemblyFormat = [{ $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result) }];
  let customInstructionLowering = [{
    return ::xla::jellyfish::LloInstruction::CreateVectorCompare(
        GetLloValue($_self->getOperand(0), $_value_map),
        GetLloValue($_self->getOperand(1), $_value_map),
        ::xla::Comparison(
            ::xla::ComparisonDirection::}] # direction # [{,
            ::xla::}] # xlaType # [{,
            ::xla::ComparisonOrder::}] # order # [{),
        $_region);
  }];
}
| |
| // One S32 compare op per entry of cmp_directions (Eq, Ne, Ge, Gt, Le, Lt): |
| // def LLO_VectorCmp<Dir>S32Op, mnemonic "vcmp.<dir lowercased>.s32", direction |
| // "k<Dir>", using the kTotal comparison order. |
| foreach dir = cmp_directions in { |
| def LLO_VectorCmp#dir#S32Op : LLO_VectorCompareOp< |
| "vcmp."#!tolower(dir)#".s32", "k"#dir, "S32", "kTotal">; |
| } |
| // Same expansion for F32, using the kPartial comparison order. |
| foreach dir = cmp_directions in { |
| def LLO_VectorCmp#dir#F32Op : LLO_VectorCompareOp< |
| "vcmp."#!tolower(dir)#".f32", "k"#dir, "F32", "kPartial">; |
| } |
| |
| // Base class for vector conversions: one LLO_Vector operand, one LLO_Vector |
| // result, marked Pure. Operand and result types are printed separately, so |
| // they are not constrained to be equal by this class. |
| //   name   - op mnemonic. |
| //   opcode - spliced after `::xla::jellyfish::LloOpcode::` in the lowering, |
| //            which emits LloInstruction::CreateVectorUnop on operand 0. |
| class LLO_VectorConvertOp<string name, string opcode> |
| : LLO_Op<name, [Pure]> { |
| let arguments = (ins LLO_Vector:$operand); |
| let results = (outs LLO_Vector:$result); |
| |
| let assemblyFormat = [{ $operand attr-dict `:` type($operand) `->` type($result) }]; |
| let customInstructionLowering = [{ |
| return ::xla::jellyfish::LloInstruction::CreateVectorUnop( |
| ::xla::jellyfish::LloOpcode::}] # opcode # [{, |
| GetLloValue($_self->getOperand(0), $_value_map), |
| $_region); |
| }]; |
| } |
| |
| // Concrete vector conversions; each pairs a mnemonic with the LloOpcode |
| // enumerator name that LLO_VectorConvertOp splices into its lowering. |
| def LLO_VectorConvertS32ToF32Op |
| : LLO_VectorConvertOp<"vcvt.s32.f32", "kVectorConvertS32ToF32">; |
| def LLO_VectorConvertF32ToS32Op |
| : LLO_VectorConvertOp<"vcvt.f32.s32", "kVectorConvertF32ToS32">; |
| def LLO_VectorConvertF32ToS32TowardsZeroPseudoOp |
| : LLO_VectorConvertOp<"vcvt.f32.s32.to.zero.pseudo", "kVectorConvertF32ToS32TowardsZeroPseudo">; |
| |
| |
| // Enum of vector pack formats, stored as a 16-bit integer attribute with a |
| // uint16_t underlying C++ type. No specialized attribute class is generated |
| // here (genSpecializedAttr = 0). |
| def LLO_VpackFormat : IntEnumAttr<I16, "VpackFormat", "Vector pack format", [ |
| I32EnumAttrCase<"kInvalid", 0, "invalid">, |
| I32EnumAttrCase<"kCompressedBf16", 1, "compressed_bf16">, |
| I32EnumAttrCase<"kCompressedB16", 2, "compressed_b16">, |
| I32EnumAttrCase<"kCompressedB8", 3, "compressed_b8">, |
| I32EnumAttrCase<"kCompressedB4", 4, "compressed_b4">, |
| I32EnumAttrCase<"kCompressedB2", 5, "compressed_b2">, |
| I32EnumAttrCase<"kCompressedB1", 6, "compressed_b1">, |
| ]> { |
| let cppNamespace = "::mlir::llo"; |
| let underlyingType = "uint16_t"; |
| let genSpecializedAttr = 0; |
| } |
| |
| // Dialect attribute wrapping LLO_VpackFormat (mnemonic "vpack_format"), |
| // printed as `<value>`; lloEnumType names the LLO-side enum it models. |
| def LLO_VpackFormatEnum |
| : LLO_EnumAttr<LLO_VpackFormat, "vpack_format"> { |
| let lloEnumType = "::xla::jellyfish::VpackFormat"; |
| let assemblyFormat = "`<` $value `>`"; |
| } |
| |
| // Declares llo.vunpack (Pure): one LLO_Vector operand plus a sublane index |
| // and a pack format attribute, producing one LLO_Vector. The assemblyFormat |
| // names only $operand, so both attributes travel through attr-dict. |
| // The lowering emits LloInstruction::CreateVectorUnpack with opcode |
| // kVectorUnpack. |
| // NOTE(review): the static_cast assumes ::mlir::llo::VpackFormat and |
| // ::xla::jellyfish::VpackFormat share numeric values - confirm. |
| def LLO_VectorUnpackOp : LLO_Op<"vunpack", [Pure]> { |
| let arguments = (ins |
| I32Attr:$sublane_idx, |
| LLO_VpackFormatEnum:$format, |
| LLO_Vector:$operand |
| ); |
| let results = (outs LLO_Vector:$result); |
| |
| let assemblyFormat = [{ $operand attr-dict `:` type($operand) `->` type($result) }]; |
| let customInstructionLowering = [{ |
| return ::xla::jellyfish::LloInstruction::CreateVectorUnpack( |
| ::xla::jellyfish::LloOpcode::kVectorUnpack, |
| $_self.getSublaneIdx(), |
| static_cast<::xla::jellyfish::VpackFormat>($_self.getFormat()), |
| GetLloValue($_self->getOperand(0), $_value_map), |
| $_region); |
| }]; |
| } |
| |
| // Declares llo.vpack (Pure): two LLO_Vector operands ($high, $low) and a |
| // pack format attribute, producing one LLO_Vector. $format is not named in |
| // the assemblyFormat, so it travels through attr-dict. |
| // The lowering emits LloInstruction::CreateVectorPack with opcode |
| // kVectorPack, passing $high before $low. |
| // NOTE(review): the static_cast assumes ::mlir::llo::VpackFormat and |
| // ::xla::jellyfish::VpackFormat share numeric values - confirm. |
| def LLO_VectorPackOp : LLO_Op<"vpack", [Pure]> { |
| let arguments = (ins |
| LLO_VpackFormatEnum:$format, |
| LLO_Vector:$high, |
| LLO_Vector:$low |
| ); |
| let results = (outs LLO_Vector:$result); |
| |
| let assemblyFormat = [{ $high `,` $low attr-dict `:` type($high) `,` type($low) `->` type($result) }]; |
| let customInstructionLowering = [{ |
| return ::xla::jellyfish::LloInstruction::CreateVectorPack( |
| ::xla::jellyfish::LloOpcode::kVectorPack, |
| static_cast<::xla::jellyfish::VpackFormat>($_self.getFormat()), |
| GetLloValue($_self.getHigh(), $_value_map), |
| GetLloValue($_self.getLow(), $_value_map), |
| $_region); |
| }]; |
| } |
| |
| // 32-bit enum of bit data formats (b32, b16.compressed, b8.compressed). |
| // No specialized attribute class is generated here (genSpecializedAttr = 0). |
| def LLO_BitDataFormat : I32EnumAttr<"BitDataFormat", "Bit data format", [ |
| I32EnumAttrCase<"kB32", 0, "b32">, |
| I32EnumAttrCase<"kCompressedB16", 1, "b16.compressed">, |
| I32EnumAttrCase<"kCompressedB8", 2, "b8.compressed">, |
| ]> { |
| let cppNamespace = "::mlir::llo"; |
| let genSpecializedAttr = 0; |
| } |
| |
| // Dialect attribute wrapping LLO_BitDataFormat (mnemonic "bit_data_format"), |
| // printed as `<value>`; lloEnumType names the LLO-side enum it models. |
| def LLO_BitDataFormatEnum |
| : LLO_EnumAttr<LLO_BitDataFormat, "bit_data_format"> { |
| let lloEnumType = "::xla::jellyfish::BitDataFormat"; |
| let assemblyFormat = "`<` $value `>`"; |
| } |
| |
| // 32-bit enum with two permute modes: one_sublane (0) and all_sublanes (1). |
| // No specialized attribute class is generated here (genSpecializedAttr = 0). |
| def LLO_SetPermuteMode : I32EnumAttr<"SetPermuteMode", "Set permute mode", [ |
| I32EnumAttrCase<"kOneSublane", 0, "one_sublane">, |
| I32EnumAttrCase<"kAllSublanes", 1, "all_sublanes">, |
| ]> { |
| let cppNamespace = "::mlir::llo"; |
| let genSpecializedAttr = 0; |
| } |
| |
| // Dialect attribute wrapping LLO_SetPermuteMode (mnemonic "set_permute_mode"), |
| // printed as `<value>`; lloEnumType names the LLO-side enum it models. |
| def LLO_SetPermuteModeEnum |
| : LLO_EnumAttr<LLO_SetPermuteMode, "set_permute_mode"> { |
| let lloEnumType = "::xla::jellyfish::SetPermuteMode"; |
| let assemblyFormat = "`<` $value `>`"; |
| } |
| |
| // Declares llo.vsetperm: a pattern vector and a permute mode, plus an xlu_id |
| // attribute defaulting to 0 and an optional source_bus attribute. The op is |
| // not marked Pure and sets neither builderMethod nor customInstructionLowering |
| // in this definition. |
| def LLO_VectorSetPermutePatternOp : LLO_Op<"vsetperm"> { |
| let arguments = (ins |
| LLO_Vector:$pattern, |
| LLO_SetPermuteModeEnum:$mode, |
| DefaultValuedAttr<I32Attr,"0">:$xlu_id, |
| OptionalAttr<I32Attr>:$source_bus |
| ); |
| let results = (outs |
| LLO_Vector:$output // Not really, but we need to pass something to VectorPermuteOp |
| ); |
| // Note the literal `,` that precedes attr-dict in the textual form. |
| let assemblyFormat = [{ $pattern `,` $mode `,` attr-dict `:` type($pattern) `,` type($output) }]; |
| } |
| |
| // Declares llo.vperm: a source vector and a pattern vector, plus an xlu_id |
| // attribute defaulting to 0 and an optional source_bus attribute. The single |
| // result is named $request. The op is not marked Pure and sets neither |
| // builderMethod nor customInstructionLowering in this definition. |
| def LLO_VectorPermuteOp : LLO_Op<"vperm"> { |
| let arguments = (ins |
| LLO_Vector:$source, |
| LLO_Vector:$pattern, |
| DefaultValuedAttr<I32Attr,"0">:$xlu_id, |
| OptionalAttr<I32Attr>:$source_bus |
| ); |
| let results = (outs LLO_Vector:$request); |
| let assemblyFormat = [{ $source `,` $pattern `,` attr-dict `:` type($source) `,` type($pattern) `,` type($request) }]; |
| } |
| // Declares llo.vlbroadcast: a source vector, an i32 lane offset, an optional |
| // i32 sublane stride and a bit data format attribute. No assemblyFormat is |
| // given. The lowering calls $_builder.Vbroadcastlane, passing std::nullopt |
| // for the stride when the optional operand is absent. |
| // NOTE(review): the static_cast assumes ::mlir::llo::BitDataFormat and |
| // ::xla::jellyfish::BitDataFormat share numeric values - confirm. |
| def LLO_VectorLaneBroadcastOp : LLO_Op<"vlbroadcast"> { |
| let arguments = (ins |
| LLO_Vector:$source, |
| I32:$lane_offset, |
| Optional<I32>:$sublane_stride, |
| LLO_BitDataFormatEnum:$format |
| ); |
| let results = (outs LLO_Vector:$request); |
| let builderMethod = [{ |
| $_builder.Vbroadcastlane( |
| GetLloValue($_self.getSource(), $_value_map), |
| GetLloValue($_self.getLaneOffset(), $_value_map), |
| $_self.getSublaneStride() |
| ? std::optional<::xla::jellyfish::LloValue*>( |
| GetLloValue($_self.getSublaneStride(), $_value_map)) |
| : std::nullopt, |
| static_cast<::xla::jellyfish::BitDataFormat>($_self.getFormat()) |
| ); |
| }]; |
| } |
| |
| // Declares llo.vpermres: takes a $request vector and an xlu_id attribute |
| // (default 0) and yields one vector result. Not marked Pure. |
| def LLO_VectorPermuteResultOp : LLO_Op<"vpermres"> { |
| let results = (outs LLO_Vector:$result); |
| let arguments = (ins |
| LLO_Vector:$request, |
| DefaultValuedAttr<I32Attr,"0">:$xlu_id |
| ); |
| let assemblyFormat = [{ $request `,` attr-dict `:` type($request) `,` type($result) }]; |
| } |
| |
| // Declares llo.vbcast_sublane_chunk: a source vector and an i32 lane offset, |
| // producing one vector named $request. No assemblyFormat is given. |
| // The lowering delegates to fusion_util::BroadcastSublaneChunk, passing the |
| // source, the lane offset, the result element type converted through |
| // ToPrimitiveType, and the builder itself. |
| def LLO_VectorBroadcastSublaneChunkOp : LLO_Op<"vbcast_sublane_chunk"> { |
| let arguments = (ins |
| LLO_Vector:$source, |
| I32:$lane_offset |
| ); |
| let results = (outs LLO_Vector:$request); |
| let builderMethod = [{ |
| ::xla::jellyfish::fusion_util::BroadcastSublaneChunk( |
| GetLloValue($_self.getSource(), $_value_map), |
| GetLloValue($_self.getLaneOffset(), $_value_map), |
| ToPrimitiveType($_self.getType().getElementType(), |
| /*throw_error_on_signless=*/false), |
| $_builder |
| ) |
| }]; |
| } |
| |
| // Declares llo.vslaneid: no operands, one vector result, lowered through |
| // $_builder.Vslaneid(). Not marked Pure. |
| def LLO_VectorSublaneId : LLO_Op<"vslaneid"> { |
| let builderMethod = "$_builder.Vslaneid()"; |
| let assemblyFormat = [{ attr-dict `:` type($result) }]; |
| let results = (outs LLO_Vector:$result); |
| let arguments = (ins); |
| } |
| |
| // Declares llo.vxlaneid: no operands, one vector result, lowered through |
| // $_builder.Vxlaneid(). Not marked Pure. |
| def LLO_VectorLaneId : LLO_Op<"vxlaneid"> { |
| let builderMethod = "$_builder.Vxlaneid()"; |
| let assemblyFormat = [{ attr-dict `:` type($result) }]; |
| let results = (outs LLO_Vector:$result); |
| let arguments = (ins); |
| } |
| |
| // Declares llo.vlaneseq: no operands, one vector result, lowered through |
| // $_builder.Vlaneseq(). Not marked Pure. |
| def LLO_VectorLaneSeqOp : LLO_Op<"vlaneseq"> { |
| let builderMethod = "$_builder.Vlaneseq()"; |
| let assemblyFormat = [{ attr-dict `:` type($result) }]; |
| let results = (outs LLO_Vector:$result); |
| let arguments = (ins); |
| } |
| |
| // Declares llo.stov (Pure): one operand of any type, one LLO_Vector result. |
| // No assemblyFormat, builderMethod or customInstructionLowering is set here. |
| def LLO_ScalarToVectorOp : LLO_Op<"stov", [Pure]> { |
| let results = (outs LLO_Vector:$result); |
| let arguments = (ins AnyType:$scalar); |
| } |
| |
| // Declares llo.vtos (Pure): one LLO_Vector operand, one result of any type, |
| // lowered through $_builder.Vtos on the operand's LLO value. |
| def LLO_VectorToScalarOp : LLO_Op<"vtos", [Pure]> { |
| let builderMethod = "$_builder.Vtos(GetLloValue($_self.getVector(), $_value_map))"; |
| let results = (outs AnyType:$scalar); |
| let arguments = (ins LLO_Vector:$vector); |
| } |
| |
| // TODO(apaszke): Add support for other conversions |
| // Base class for scalar conversions: one operand and one result, both |
| // AnyType, marked Pure. No type constraint ties operand to result here. |
| //   name   - op mnemonic. |
| //   opcode - spliced after `::xla::jellyfish::LloOpcode::` in the lowering, |
| //            which emits LloInstruction::CreateScalarUnOp on operand 0. |
| class LLO_ScalarConvertOp<string name, string opcode> |
| : LLO_Op<name, [Pure]> { |
| let arguments = (ins AnyType:$operand); |
| let results = (outs AnyType:$result); |
| |
| let assemblyFormat = [{ $operand attr-dict `:` type($operand) `->` type($result) }]; |
| let customInstructionLowering = [{ |
| return ::xla::jellyfish::LloInstruction::CreateScalarUnOp( |
| ::xla::jellyfish::LloOpcode::}] # opcode # [{, |
| GetLloValue($_self->getOperand(0), $_value_map), |
| $_region); |
| }]; |
| } |
| |
| // Concrete scalar conversions; each pairs a mnemonic with the LloOpcode |
| // enumerator name that LLO_ScalarConvertOp splices into its lowering. |
| def LLO_ScalarConvertS32ToF32Op |
| : LLO_ScalarConvertOp<"scvt.s32.f32", "kScalarConvertS32ToF32">; |
| def LLO_ScalarConvertF32ToS32Op |
| : LLO_ScalarConvertOp<"scvt.f32.s32", "kScalarConvertF32ToS32">; |
| def LLO_ScalarConvertF32ToS32TowardsZeroPseudoOp |
| : LLO_ScalarConvertOp<"scvt.f32.s32.to.zero.pseudo", "kScalarConvertF32ToS32TowardsZeroPseudo">; |
| |
| // Base class for scalar unary ops whose operand and result share `type` |
| // (Pure, SameOperandsAndResultType). The textual form prints no types. |
| //   name   - op mnemonic. |
| //   opcode - spliced after `::xla::jellyfish::LloOpcode::` in the default |
| //            lowering, which emits LloInstruction::CreateScalarUnOp. |
| //   type   - type constraint applied to both operand and result. |
| class LLO_ScalarUnOp<string name, string opcode, Type type> |
| : LLO_Op<name, [Pure, SameOperandsAndResultType]> { |
| let arguments = (ins type:$operand); |
| let results = (outs type:$result); |
| |
| let assemblyFormat = [{ $operand attr-dict }]; |
| let customInstructionLowering = [{ |
| return ::xla::jellyfish::LloInstruction::CreateScalarUnOp( |
| ::xla::jellyfish::LloOpcode::}] # opcode # [{, |
| GetLloValue($_self->getOperand(0), $_value_map), |
| $_region); |
| }]; |
| } |
| |
| // i32 -> i32 unary op "sclz" using the default LLO_ScalarUnOp lowering with |
| // opcode kScalarCountLeadingZeros. |
| def LLO_ScalarCountLeadingZerosOp : LLO_ScalarUnOp<"sclz", "kScalarCountLeadingZeros", I32>; |
| |
| // Base class for scalar binary ops whose operands and result share `type` |
| // (Pure, SameOperandsAndResultType). The textual form prints no types. |
| //   name   - op mnemonic. |
| //   opcode - spliced after `::xla::jellyfish::LloOpcode::` in the default |
| //            lowering, which emits LloInstruction::CreateScalarBinOp on |
| //            operands 0 and 1. |
| //   type   - type constraint applied to both operands and the result. |
| class LLO_ScalarBinOp<string name, string opcode, Type type> |
| : LLO_Op<name, [Pure, SameOperandsAndResultType]> { |
| let arguments = (ins |
| type:$lhs, |
| type:$rhs |
| ); |
| let results = (outs type:$result); |
| |
| let assemblyFormat = [{ $lhs `,` $rhs attr-dict }]; |
| let customInstructionLowering = [{ |
| return ::xla::jellyfish::LloInstruction::CreateScalarBinOp( |
| ::xla::jellyfish::LloOpcode::}] # opcode # [{, |
| GetLloValue($_self->getOperand(0), $_value_map), |
| GetLloValue($_self->getOperand(1), $_value_map), |
| $_region); |
| }]; |
| } |
| |
| // Scalar arithmetic ops that use the default LLO_ScalarBinOp lowering with |
| // the given LloOpcode enumerator name. |
| // sadd.s32 additionally declares a C++ canonicalize method. |
| def LLO_ScalarAddS32Op : LLO_ScalarBinOp<"sadd.s32", "kScalarAddS32", I32> { |
| let hasCanonicalizeMethod = 1; |
| } |
| def LLO_ScalarAddF32Op : LLO_ScalarBinOp<"sadd.f32", "kScalarAddF32", F32>; |
| def LLO_ScalarSubF32Op : LLO_ScalarBinOp<"ssub.f32", "kScalarSubtractF32", F32>; |
| def LLO_ScalarSubS32Op : LLO_ScalarBinOp<"ssub.s32", "kScalarSubtractS32", I32>; |
| def LLO_ScalarMulF32Op : LLO_ScalarBinOp<"smul.f32", "kScalarMultiplyF32", F32>; |
| def LLO_ScalarMaxF32Op : LLO_ScalarBinOp<"smax.f32", "kScalarMaximumF32", F32>; |
| def LLO_ScalarMinF32Op : LLO_ScalarBinOp<"smin.f32", "kScalarMinimumF32", F32>; |
| // Signedness does not matter for two's complement multiplication. |
| def LLO_ScalarMulS32Op : LLO_ScalarBinOp<"smul.s32", "kScalarMultiplyU32", I32>; |
| // sdiv.s32 and srem.s32 pass an empty opcode and disable instruction lowering |
| // (hasInstructionLowering = 0). |
| def LLO_ScalarDivS32Op : LLO_ScalarBinOp<"sdiv.s32", "", I32> { |
| let hasInstructionLowering = 0; |
| } |
| def LLO_ScalarRemS32Op : LLO_ScalarBinOp<"srem.s32", "", I32> { |
| let hasInstructionLowering = 0; |
| } |
| // Scalar binary ops lowered through an explicit builderMethod call on |
| // $_builder, passing the LLO values of $lhs and $rhs. |
| // NOTE(review): the opcode argument is the placeholder "UNUSED"; presumably |
| // builderMethod takes precedence over the class's default lowering - confirm. |
| def LLO_ScalarMinS32Op : LLO_ScalarBinOp<"smin.s32", "UNUSED", I32> { |
| let builderMethod = [{ |
| $_builder.SminS32(GetLloValue(op.getLhs(), $_value_map), |
| GetLloValue(op.getRhs(), $_value_map)) |
| }]; |
| } |
| def LLO_ScalarMaxS32Op : LLO_ScalarBinOp<"smax.s32", "UNUSED", I32> { |
| let builderMethod = [{ |
| $_builder.SmaxS32(GetLloValue(op.getLhs(), $_value_map), |
| GetLloValue(op.getRhs(), $_value_map)) |
| }]; |
| } |
| def LLO_ScalarDivF32Op : LLO_ScalarBinOp<"sdiv.f32", "UNUSED", F32> { |
| let builderMethod = [{ |
| $_builder.SdivF32(GetLloValue(op.getLhs(), $_value_map), |
| GetLloValue(op.getRhs(), $_value_map)) |
| }]; |
| } |
| // i32 shift and bitwise ops using the default LLO_ScalarBinOp lowering. |
| def LLO_ScalarShllOp : LLO_ScalarBinOp<"shll", "kScalarShll", I32>; |
| def LLO_ScalarShrlOp : LLO_ScalarBinOp<"shrl", "kScalarShrl", I32>; |
| // NOTE(review): the mnemonic "shla" does not match the def name (Shra) or the |
| // opcode (kScalarShra); possibly a typo for "shra". Confirm before changing, |
| // since the mnemonic is part of the textual IR. |
| def LLO_ScalarShraOp : LLO_ScalarBinOp<"shla", "kScalarShra", I32>; |
| def LLO_ScalarBitwiseOrOp : LLO_ScalarBinOp<"sor", "kScalarBitwiseOr", I32>; |
| def LLO_ScalarBitwiseAndOp : LLO_ScalarBinOp<"sand", "kScalarBitwiseAnd", I32>; |
| def LLO_ScalarBitwiseXorOp : LLO_ScalarBinOp<"sxor", "kScalarBitwiseXor", I32>; |
| // F32 scalar unary ops lowered through an explicit builderMethod call on |
| // $_builder, passing the operand's LLO value. |
| // NOTE(review): the opcode argument is the placeholder "UNUSED"; presumably |
| // builderMethod takes precedence over the class's default lowering - confirm. |
| def LLO_ScalarNegF32Op : LLO_ScalarUnOp<"sneg.f32", "UNUSED", F32> { |
| let builderMethod = "$_builder.SnegF32(GetLloValue(op.getOperand(), $_value_map))"; |
| } |
| def LLO_ScalarSqrtF32Op : LLO_ScalarUnOp<"ssqrt.f32", "UNUSED", F32> { |
| let builderMethod = "$_builder.Ssqrt(GetLloValue(op.getOperand(), $_value_map))"; |
| } |
| def LLO_ScalarExpF32Op : LLO_ScalarUnOp<"sexp.f32", "UNUSED", F32> { |
| let builderMethod = "$_builder.Sexp(GetLloValue(op.getOperand(), $_value_map))"; |
| } |
| def LLO_ScalarLogF32Op : LLO_ScalarUnOp<"sln.f32", "UNUSED", F32> { |
| let builderMethod = "$_builder.Sln(GetLloValue(op.getOperand(), $_value_map))"; |
| } |
| def LLO_ScalarRoundF32Op : LLO_ScalarUnOp<"sround.f32", "UNUSED", F32> { |
| let builderMethod = "$_builder.SroundF32(GetLloValue(op.getOperand(), $_value_map))"; |
| } |
| def LLO_ScalarRoundEvenF32Op : LLO_ScalarUnOp<"sround.ties_to_even.f32", "UNUSED", F32> { |
| let builderMethod = "$_builder.SroundTiesToEvenF32(GetLloValue(op.getOperand(), $_value_map))"; |
| } |
| |
| // Base class for scalar comparisons: two operands of `type`, one i1 result, |
| // marked Pure. The textual form prints no types. |
| //   name      - op mnemonic. |
| //   direction - spliced after `::xla::Comparison::Direction::` in the lowering. |
| //   xlaType   - spliced after `::xla::` (the primitive type of the comparison). |
| //   type      - type constraint applied to both operands. |
| // Unlike LLO_VectorCompareOp, no comparison order is passed to |
| // ::xla::Comparison here. |
| class LLO_ScalarCompareOp<string name, string direction, string xlaType, Type type> |
| : LLO_Op<name, [Pure]> { |
| let arguments = (ins |
| type:$lhs, |
| type:$rhs |
| ); |
| let results = (outs I1:$result); |
| |
| let assemblyFormat = [{ $lhs `,` $rhs attr-dict }]; |
| let customInstructionLowering = [{ |
| return ::xla::jellyfish::LloInstruction::CreateScalarCompare( |
| GetLloValue($_self->getOperand(0), $_value_map), |
| GetLloValue($_self->getOperand(1), $_value_map), |
| ::xla::Comparison( |
| ::xla::Comparison::Direction::}] # direction # [{, |
| ::xla::}] # xlaType # [{), |
| $_region); |
| }]; |
| } |
| |
| // One S32 compare op per entry of cmp_directions (Eq, Ne, Ge, Gt, Le, Lt): |
| // def LLO_ScalarCmp<Dir>S32Op, mnemonic "s<dir lowercased>.s32", direction |
| // "k<Dir>", operating on i32 operands. |
| foreach dir = cmp_directions in { |
| def LLO_ScalarCmp#dir#S32Op : LLO_ScalarCompareOp< |
| "s"#!tolower(dir)#".s32", "k"#dir, "S32", I32>; |
| } |
| // Same expansion for F32 operands. |
| foreach dir = cmp_directions in { |
| def LLO_ScalarCmp#dir#F32Op : LLO_ScalarCompareOp< |
| "s"#!tolower(dir)#".f32", "k"#dir, "F32", F32>; |
| } |
| |
| // Declares llo.sselect (Pure): an i1 condition plus two values of any type, |
| // producing one result of any type. This definition adds no constraint |
| // tying the value types to the result type. |
| def LLO_ScalarSelectOp : LLO_Op<"sselect", [Pure]> { |
| let results = (outs AnyType:$result); |
| let arguments = (ins |
| I1:$condition, |
| AnyType:$if_true, |
| AnyType:$if_false |
| ); |
| let assemblyFormat = [{ |
| $condition `,` $if_true `,` $if_false attr-dict `:` |
| `(` type($if_true) `,` type($if_false) `)` `->` type($result) |
| }]; |
| } |
| |
| // Declares llo.sld: an i32 address with an optional i32 offset (printed as |
| // `+ offset` when present), producing one result of any type. |
| def LLO_ScalarLoadOp : LLO_Op<"sld"> { |
| let results = (outs AnyType:$result); |
| let arguments = (ins |
| I32:$address, |
| Optional<I32>:$offset |
| ); |
| let assemblyFormat = [{ $address (`+` $offset^)? attr-dict `:` type($result) }]; |
| } |
| |
| // Declares llo.sst: an i32 address and a value of any type; no results. |
| def LLO_ScalarStoreOp : LLO_Op<"sst"> { |
| let results = (outs); |
| let arguments = (ins |
| I32:$address, |
| AnyType:$to_store |
| ); |
| let assemblyFormat = [{ $address `,` $to_store attr-dict `:` type($to_store) }]; |
| } |
| |
| // Declares llo.por (Pure, SameOperandsAndResultType): two i1 operands and |
| // one i1 result. No builderMethod is set in this definition. |
| def LLO_PredicateOrOp : LLO_Op<"por", [Pure, SameOperandsAndResultType]> { |
| let results = (outs I1:$result); |
| let arguments = (ins |
| I1:$lhs, |
| I1:$rhs |
| ); |
| let assemblyFormat = [{ $lhs `,` $rhs attr-dict }]; |
| } |
| // Declares llo.pand (Pure, SameOperandsAndResultType): two i1 operands and |
| // one i1 result, lowered through $_builder.Pand on the operands' LLO values. |
| def LLO_PredicateAndOp : LLO_Op<"pand", [Pure, SameOperandsAndResultType]> { |
| let arguments = (ins |
| I1:$lhs, |
| I1:$rhs |
| ); |
| let results = (outs I1:$result); |
| let assemblyFormat = [{ $lhs `,` $rhs attr-dict }]; |
| let builderMethod = [{ |
| $_builder.Pand(GetLloValue($_self.getLhs(), $_value_map), |
| GetLloValue($_self.getRhs(), $_value_map)) |
| }]; |
| } |
| // Declares llo.pneg (Pure, SameOperandsAndResultType): one i1 operand and |
| // one i1 result. No builderMethod is set in this definition. |
| def LLO_PredicateNegateOp : LLO_Op<"pneg", [Pure, SameOperandsAndResultType]> { |
| let results = (outs I1:$result); |
| let arguments = (ins I1:$predicate); |
| let assemblyFormat = [{ $predicate attr-dict }]; |
| } |
| |
| // Declares llo.error.if: an i1 predicate operand and a string message |
| // attribute; no results and no assemblyFormat. Lowered through |
| // $_builder.ErrorIf with the predicate's LLO value and the message. |
| def LLO_ErrorIfOp : LLO_Op<"error.if"> { |
| let arguments = (ins |
| I1:$predicate, |
| StrAttr:$message |
| ); |
| let results = (outs); |
| let builderMethod = [{ |
| $_builder.ErrorIf(GetLloValue($_self.getPredicate(), $_value_map), $_self.getMessage()) |
| }]; |
| } |
| |
| // Declares llo.vbitcast (Pure): LLO_Vector in, LLO_Vector out. |
| // Instruction lowering is disabled (hasInstructionLowering = 0). |
| def LLO_VectorBitcastOp : LLO_Op<"vbitcast", [Pure]> { |
| let hasInstructionLowering = 0; |
| let results = (outs LLO_Vector:$output); |
| let arguments = (ins LLO_Vector:$input); |
| let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; |
| } |
| |
| // Declares llo.sbitcast (Pure): any type in, any type out. |
| // Instruction lowering is disabled (hasInstructionLowering = 0). |
| def LLO_ScalarBitcastOp : LLO_Op<"sbitcast", [Pure]> { |
| let hasInstructionLowering = 0; |
| let results = (outs AnyType:$output); |
| let arguments = (ins AnyType:$input); |
| let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }]; |
| } |
| |
| // Declares llo.enqueue_dma, lowered through $_builder.EnqueueDmaMultiStrided. |
| // The op carries AttrSizedOperandSegments because it mixes Optional operands |
| // ($src_flag, $chip_id, $core_index) with a Variadic one ($steps_per_stride). |
| // When $src_flag is absent the lowering passes $_builder.DummySyncFlag(). |
| // A C++ verifier is declared (hasVerifier = 1). |
| // NOTE(review): $chip_id / $core_index are not read directly in this snippet; |
| // presumably GetOptionalCoreLocation extracts them from the op - confirm. |
| def LLO_EnqueueDMAOp : LLO_Op<"enqueue_dma", [AttrSizedOperandSegments]> { |
| let arguments = (ins |
| I32:$src_addr, |
| Optional<I32>:$src_flag, |
| I32:$size_512b, // size in multiples of 512 bytes |
| I32:$dst_addr, |
| I32:$dst_flag, |
| // Major-to-minor, strides (and the last steps_per_stride) are in bytes. |
| DenseI32ArrayAttr:$src_strides, |
| DenseI32ArrayAttr:$dst_strides, |
| Variadic<I32>:$steps_per_stride, |
| Optional<I32>:$chip_id, |
| Optional<I32>:$core_index |
| ); |
| let builderMethod = [{ |
| $_builder.EnqueueDmaMultiStrided( |
| GetLloValue($_self.getSrcAddr(), $_value_map), |
| GetLloValue($_self.getDstAddr(), $_value_map), |
| GetOptionalCoreLocation($_context, $_builder, $_self, "enqueue-dma", $_value_map), |
| ::xla::jellyfish::LloMemUnit( |
| GetLloValue($_self.getSize_512b(), $_value_map), |
| ::xla::jellyfish::Granule::k512Byte), |
| GetDmaStrides($_builder, $_self.getSrcStrides(), $_self.getDstStrides(), $_self.getStepsPerStride(), $_value_map), |
| GetLloValue($_self.getDstFlag(), $_value_map), |
| $_self.getSrcFlag() ? GetLloValue($_self.getSrcFlag(), $_value_map) : $_builder.DummySyncFlag() |
| ); |
| }]; |
| let hasVerifier = 1; |
| } |
| |
| // Declares llo.dma_done: a size and a flag address (both i32) plus a boolean |
| // emit_sfence attribute; no results and no assemblyFormat. Lowered through |
| // $_builder.DmaDoneInGranules, wrapping the size in an LloMemUnit with |
| // Granule::k512Byte. |
| def LLO_DMADoneOp : LLO_Op<"dma_done"> { |
| let arguments = (ins |
| I32:$size_512b, // size in multiples of 512 bytes |
| I32:$flag_addr, |
| BoolAttr:$emit_sfence |
| ); |
| let builderMethod = [{ |
| $_builder.DmaDoneInGranules( |
| ::xla::jellyfish::LloMemUnit( |
| GetLloValue($_self.getSize_512b(), $_value_map), |
| ::xla::jellyfish::Granule::k512Byte), |
| GetLloValue($_self.getFlagAddr(), $_value_map), |
| $_self.getEmitSfence() |
| ); |
| }]; |
| } |
| |
| // Declares llo.chip_id: no operands, one i32 result, lowered through |
| // $_builder.ChipId(). Not marked Pure. |
| def LLO_ChipIdOp : LLO_Op<"chip_id"> { |
| let builderMethod = "$_builder.ChipId()"; |
| let assemblyFormat = [{ attr-dict `:` type($result) }]; |
| let results = (outs I32:$result); |
| let arguments = (ins); |
| } |
| |
| // Declares llo.core_index: no operands, one i32 result, lowered through |
| // $_builder.CoreIndex(). Not marked Pure. |
| def LLO_CoreIndexOp : LLO_Op<"core_index"> { |
| let builderMethod = "$_builder.CoreIndex()"; |
| let assemblyFormat = [{ attr-dict `:` type($result) }]; |
| let results = (outs I32:$result); |
| let arguments = (ins); |
| } |
| |
| // Declares llo.logical_device_id (Pure): no operands, one i32 result. |
| // Instruction lowering is disabled (hasInstructionLowering = 0). |
| def LLO_LogicalDeviceIdOp : LLO_Op<"logical_device_id", [Pure]> { |
| let hasInstructionLowering = 0; |
| let assemblyFormat = [{ attr-dict `:` type($result) }]; |
| let results = (outs I32:$result); |
| let arguments = (ins); |
| } |
| |
| |
| // Declares llo.vsync.read: an i32 flag address in, an i32 result out, |
| // lowered through $_builder.VsyncRead on the flag address's LLO value. |
| def LLO_VSyncRead : LLO_Op<"vsync.read"> { |
| let builderMethod = "$_builder.VsyncRead(GetLloValue($_self.getFlagAddr(), $_value_map))"; |
| let assemblyFormat = [{ $flag_addr attr-dict `:` type($result) }]; |
| let results = (outs I32:$result); |
| let arguments = (ins I32:$flag_addr); |
| } |
| |
| // Base class for wait ops on a flag: an i32 flag address and an i32 value, |
| // with no results declared. |
| //   cmp - comparison name; lowercased to form the mnemonic "vwait.<cmp>" and |
| //         spliced verbatim into the builder call as $_builder.Vwait<cmp>SV. |
| class LLO_VWaitOp<string cmp> : LLO_Op<"vwait." # !tolower(cmp)> { |
| let arguments = (ins |
| I32:$flag_addr, |
| I32:$value |
| ); |
| let builderMethod = [{ |
| $_builder.Vwait}] # cmp # [{SV( |
| GetLloValue($_self.getFlagAddr(), $_value_map), |
| GetLloValue($_self.getValue(), $_value_map) |
| ); |
| }]; |
| |
| } |
| |
| // Instantiations of LLO_VWaitOp: llo.vwait.eq -> VwaitEqSV and |
| // llo.vwait.ge -> VwaitGeSV. |
| def LLO_VWaitEqOp : LLO_VWaitOp<"Eq">; |
| def LLO_VWaitGeOp : LLO_VWaitOp<"Ge">; |
| |
| // Declares llo.vsync.set: an i32 flag address and an i32 value, with no |
| // results and no assemblyFormat. Lowered through $_builder.VsyncSet. |
| def LLO_VSyncSetOp : LLO_Op<"vsync.set"> { |
| let arguments = (ins |
| I32:$flag_addr, |
| I32:$value |
| ); |
| let builderMethod = [{ |
| $_builder.VsyncSet( |
| GetLloValue($_self.getFlagAddr(), $_value_map), |
| GetLloValue($_self.getValue(), $_value_map) |
| ); |
| }]; |
| } |
| |
| // Declares llo.vsync.add.remote: four i32 operands (flag address, chip id, |
| // core index, amount) and no results. Lowered through |
| // $_builder.VsyncAddRemote with the flag address, a core location and the |
| // amount. |
| // NOTE(review): $chip_id / $core_index are not read directly in this snippet; |
| // presumably GetCoreLocation extracts them from the op - confirm. |
| def LLO_VSyncAddRemoteOp : LLO_Op<"vsync.add.remote"> { |
| let arguments = (ins |
| I32:$flag_addr, |
| I32:$chip_id, |
| I32:$core_index, |
| I32:$amount |
| ); |
| let results = (outs); |
| let assemblyFormat = [{ $flag_addr `,` $chip_id `,` $core_index `,` $amount attr-dict }]; |
| let builderMethod = [{ |
| $_builder.VsyncAddRemote( |
| GetLloValue($_self.getFlagAddr(), $_value_map), |
| GetCoreLocation($_context, $_builder, $_self, "sync-add", $_value_map), |
| GetLloValue($_self.getAmount(), $_value_map) |
| ); |
| }]; |
| } |
| |
| // Declares llo.vsync.add: an i32 flag address and an i32 amount, with no |
| // results. Lowered through $_builder.VsyncAdd. |
| def LLO_VSyncAddOp : LLO_Op<"vsync.add"> { |
| let arguments = (ins |
| I32:$flag_addr, |
| I32:$amount |
| ); |
| let results = (outs); |
| let assemblyFormat = [{ $flag_addr `,` $amount attr-dict }]; |
| let builderMethod = [{ |
| $_builder.VsyncAdd( |
| GetLloValue($_self.getFlagAddr(), $_value_map), |
| GetLloValue($_self.getAmount(), $_value_map) |
| ); |
| }]; |
| } |
| |
| |
| // Declares llo.log.s32: an i32 operand and a string tag attribute; no |
| // results. Instruction lowering is disabled (hasInstructionLowering = 0). |
| def LLO_LogS32 : LLO_Op<"log.s32"> { |
| let hasInstructionLowering = 0; |
| let results = (outs); |
| let arguments = (ins |
| I32:$operand, |
| StrAttr:$tag |
| ); |
| } |
| // Declares llo.log.pred: an i1 operand and a string tag attribute; no |
| // results. Instruction lowering is disabled (hasInstructionLowering = 0). |
| def LLO_LogPred : LLO_Op<"log.pred"> { |
| let hasInstructionLowering = 0; |
| let results = (outs); |
| let arguments = (ins |
| I1:$operand, |
| StrAttr:$tag |
| ); |
| } |
| // Declares llo.log.vmask: an LLO_VectorMask operand and a string tag |
| // attribute; no results. Lowered through $_builder.LogVmask, passing the |
| // tag converted with .str(). |
| def LLO_LogVmaskOp : LLO_Op<"log.vmask"> { |
| let arguments = (ins |
| LLO_VectorMask:$operand, |
| StrAttr:$tag |
| ); |
| let results = (outs); |
| let builderMethod = "$_builder.LogVmask(GetLloValue($_self.getOperand(), $_value_map), $_self.getTag().str())"; |
| } |
| |
| // Declares llo.region: a single-block region op implicitly terminated by |
| // llo::YieldOp, with variadic results of any type. Memory effects are |
| // derived from the nested ops (RecursiveMemoryEffects). Instruction lowering |
| // is disabled (hasInstructionLowering = 0). |
| def LLO_RegionOp : LLO_Op<"region", [ |
| RecursiveMemoryEffects, SingleBlockImplicitTerminator<"llo::YieldOp">]> { |
| let results = (outs Variadic<AnyType>:$results); |
| let regions = (region AnyRegion:$region); |
| let hasInstructionLowering = 0; |
| } |
| |
| // Declares llo.trace: same region structure as llo.region (single block, |
| // implicit llo::YieldOp terminator, RecursiveMemoryEffects), plus a string |
| // message and an i32 level attribute. Instruction lowering is disabled |
| // (hasInstructionLowering = 0). |
| def LLO_TraceOp : LLO_Op<"trace", [RecursiveMemoryEffects, SingleBlockImplicitTerminator<"llo::YieldOp">]> { |
| let arguments = (ins StrAttr:$message, I32Attr:$level); |
| let results = (outs Variadic<AnyType>:$results); |
| let regions = (region AnyRegion:$region); |
| let hasInstructionLowering = 0; |
| } |
| |
| // Declares llo.yield: a Pure, ReturnLike terminator with variadic operands |
| // of any type. The operand list and its types are printed only when |
| // operands are present. Instruction lowering is disabled |
| // (hasInstructionLowering = 0). |
| def LLO_YieldOp : LLO_Op<"yield", [Pure, ReturnLike, Terminator]> { |
| let arguments = (ins Variadic<AnyType>:$results); |
| let assemblyFormat = [{ attr-dict ($results^ `:` type($results))? }]; |
| let hasInstructionLowering = 0; |
| } |
| |
| // Declares llo.trace_start: a string message and an i32 level attribute; |
| // no operands or results. Instruction lowering is disabled |
| // (hasInstructionLowering = 0). |
| def LLO_TraceStartOp : LLO_Op<"trace_start", []> { |
| let hasInstructionLowering = 0; |
| let results = (outs); |
| let arguments = (ins StrAttr:$message, I32Attr:$level); |
| } |
| |
| // Declares llo.trace_stop: no arguments and no results. Instruction |
| // lowering is disabled (hasInstructionLowering = 0). |
| def LLO_TraceStopOp : LLO_Op<"trace_stop", []> { |
| let hasInstructionLowering = 0; |
| let results = (outs); |
| let arguments = (ins); |
| } |
| |
| // Extensions that do not have any corresponding ops in LLO, but help with |
| // lowering and optimization. |
| |
| // TODO(apaszke): Implement a common optimization that folds the ops that follow |
| // the alloc-store-load-free pattern into loads |
| // Declares llo.sublane_reverse (Pure): LLO_Vector in, LLO_Vector out. |
| // Instruction lowering is disabled (hasInstructionLowering = 0). |
| def LLO_VectorSublaneReverseOp : LLO_Op<"sublane_reverse", [Pure]> { |
| let hasInstructionLowering = 0; |
| let results = (outs LLO_Vector:$result); |
| let arguments = (ins LLO_Vector:$operand); |
| } |
| |
| // Declares llo.vector_load_unpack (VmemRead effect): an i32 address, an i32 |
| // chunk id attribute and a packed type attribute, producing one LLO_Vector. |
| // $packed_type is not named in the assemblyFormat, so it travels through |
| // attr-dict. Lowered through $_builder.LdAndUnpack, using the target's full |
| // sublane count both for the sublane mask and as the last argument. |
| def LLO_VectorLoadAndUnpackOp : LLO_Op<"vector_load_unpack", [VmemRead]> { |
| let arguments = (ins |
| I32:$address, |
| I32Attr:$chunk_id, |
| TypeAttr:$packed_type |
| ); |
| let results = (outs LLO_Vector:$result); |
| let assemblyFormat = [{ |
| $address `,` $chunk_id attr-dict `:` type($result) |
| }]; |
| let builderMethod = [{ |
| $_builder.LdAndUnpack( |
| ToPrimitiveType($_self.getPackedType()), |
| GetLloValue($_self.getAddress(), $_value_map), |
| $_builder.SimmS32($_self.getChunkId()), |
| $_builder.SublaneMaskForSublaneCount($_builder.target().SublaneCount()), |
| $_builder.target().SublaneCount()); |
| }]; |
| } |
| |
| // Declares a store-and-pack op (VmemWrite effect): an i32 address, an i32 |
| // chunk id attribute, the vector to store and a packed type attribute; no |
| // results. Lowered through $_builder.StAndPackToX16, whose last argument is |
| // half the target's sublane count as an S32 immediate. |
| // NOTE(review): the mnemonic reads "vector_store_unpack_x16" although the def |
| // name and builder call both say "pack"; possibly a copy-paste slip. Confirm |
| // before changing, since the mnemonic is part of the textual IR. |
| def LLO_VectorStoreAndPackX16Op : LLO_Op<"vector_store_unpack_x16", [VmemWrite]> { |
| let arguments = (ins |
| I32:$address, |
| I32Attr:$chunk_id, |
| LLO_Vector:$to_store, |
| TypeAttr:$packed_type |
| ); |
| let results = (outs); |
| let assemblyFormat = [{ |
| $address `,` $chunk_id `,` $to_store attr-dict `:` type($to_store) |
| }]; |
| let builderMethod = [{ |
| $_builder.StAndPackToX16( |
| ToPrimitiveType($_self.getPackedType()), |
| GetLloValue($_self.getAddress(), $_value_map), |
| $_builder.SimmS32($_self.getChunkId()), |
| GetLloValue($_self.getToStore(), $_value_map), |
| $_builder.SimmS32($_builder.target().SublaneCount() / 2)); |
| }]; |
| } |
| |
| // Declares llo.slane.shuffle (Pure): a source vector and a dense i32 array |
| // of indices, producing one LLO_Vector. Lowered through |
| // $_builder.VslaneShfl, with the shuffle pattern built from the indices by |
| // the same builder. |
| def LLO_VectorSublaneShuffleOp : LLO_Op<"slane.shuffle", [Pure]> { |
| let arguments = (ins |
| LLO_Vector:$source, |
| DenseI32ArrayAttr:$indices |
| ); |
| let results = (outs |
| LLO_Vector:$result |
| ); |
| let assemblyFormat = [{ $source `[` $indices `]` attr-dict `:` type($source) `->` type($result) }]; |
| // Refer to the builder through the $_builder placeholder rather than a raw |
| // local name, so the snippet does not depend on how the generator names it. |
| let builderMethod = [{ |
| $_builder.VslaneShfl(GetLloValue($_self.getSource(), $_value_map), |
| $_builder.CreateShufflePattern($_self.getIndices())); |
| }]; |
| } |
| |
| // Declares llo.vslreplicate (Pure): a source vector and an i32 sublane |
| // operand, producing one LLO_Vector. Lowered through $_builder.VslaneToAll |
| // on the LLO values of both operands. |
| def LLO_VectorSublaneReplicateOp : LLO_Op<"vslreplicate", [Pure]> { |
| let arguments = (ins |
| LLO_Vector:$source, |
| I32:$sublane |
| ); |
| let results = (outs LLO_Vector:$result); |
| let assemblyFormat = [{ $source `,` $sublane attr-dict `:` type($source) `->` type($result)}]; |
| let builderMethod = [{$_builder.VslaneToAll(GetLloValue($_self.getSource(), $_value_map), |
| GetLloValue($_self.getSublane(), $_value_map)) }]; |
| } |
| |
| // Pass "elim-llo-ext", anchored on ::mlir::func::FuncOp. It is constructed |
| // via createEliminateLLOExtensionsPass() and depends on the LLO dialect. |
| def EliminateLLOExtensionsPass : Pass<"elim-llo-ext", "::mlir::func::FuncOp"> { |
| let constructor = "::mlir::llo::createEliminateLLOExtensionsPass()"; |
| let dependentDialects = ["::mlir::llo::LLODialect"]; |
| } |
| |
| // LLO extensions that are only valid within Mosaic programs and are lowered specially. |
| // Declares llo.iteration_bound: an i32 dim attribute in, an i32 result out. |
| // Instruction lowering is disabled (hasInstructionLowering = 0). |
| def LLO_GetIterationBoundOp : LLO_Op<"iteration_bound"> { |
| let hasInstructionLowering = 0; |
| let results = (outs I32:$result); |
| let arguments = (ins I32Attr:$dim); |
| let assemblyFormat = [{ $dim attr-dict `:` type($result) }]; |
| } |
| |
| // Declares llo.barrier_sflag: no operands, one i32 result ($sync_flag). |
| // Instruction lowering is disabled (hasInstructionLowering = 0). |
| def LLO_GetBarrierSyncFlagOp : LLO_Op<"barrier_sflag"> { |
| let hasInstructionLowering = 0; |
| let results = (outs I32:$sync_flag); |
| let arguments = (ins); |
| let assemblyFormat = [{ attr-dict `:` type($sync_flag) }]; |
| } |
| |
| |
| #endif // LLO_OPS |