// Source-viewer residue (not part of the TableGen input): blob: 7c8c596cb728f7786ad927a0de5d13df1f0e1421 [file] [log] [blame]
#ifndef LLO_OPS
#define LLO_OPS
include "mlir/IR/OpBase.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/Pass/PassBase.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
// Comparison directions; the `foreach` loops over this list stamp out one
// vector-compare op per entry (mnemonic uses the lower-cased spelling).
defvar cmp_directions = [
  "Eq", "Ne", "Ge", "Gt", "Le", "Lt"
];
// The `llo` dialect. Generated C++ lives in ::mlir::llo.
def LLO_Dialect : Dialect {
  let cppNamespace = "::mlir::llo";
  let name = "llo";
  // Rely on the generated attribute printer/parser.
  let useDefaultAttributePrinterParser = 1;
}
// Base class for attributes of the LLO dialect. `mnemonic_` is forwarded to
// the AttrDef `mnemonic` field.
class LLO_Attr<string name, string mnemonic_, list<Trait> traits = []>
    : AttrDef<LLO_Dialect, name, traits> {
  let mnemonic = mnemonic_;
}
// Base class for types of the LLO dialect. `mnemonic_` is forwarded to the
// TypeDef `mnemonic` field.
class LLO_Type<string name, string mnemonic_, list<Trait> traits = []>
    : TypeDef<LLO_Dialect, name, traits> {
  let mnemonic = mnemonic_;
}
// Base class for enum attributes of the LLO dialect.
class LLO_EnumAttr<EnumAttrInfo enum, string name>
    : EnumAttr<LLO_Dialect, enum, name> {
  // Name of the LLO enum this attribute is modeling. Unset here; each
  // concrete attribute in this file assigns it.
  string lloEnumType = ?;
}
// Base class for every op of the LLO dialect. On top of the standard ODS
// fields it declares three extra fields used by the instruction lowering.
// NOTE(review): the consumer of these fields is not visible in this file.
class LLO_Op<string mnemonic, list<Trait> traits = []>
    : Op<LLO_Dialect, mnemonic, traits> {
  // On by default; individual ops set it to 0.
  bit hasInstructionLowering = 1;
  // Optional C++ snippets; both are unset unless an op assigns them.
  string customInstructionLowering = ?;
  string builderMethod = ?;
}
// TODO(apaszke): Model resources more precisely
// Until then, all memory effects below are declared on DefaultResource.

// Allocation effects (attached to op results).
def VmemAlloc : MemAlloc<DefaultResource>;
def SmemAlloc : MemAlloc<DefaultResource>;

// Read/write effect traits (attached to ops).
def VmemRead : MemoryEffects<[MemRead<DefaultResource>]>;
def VmemWrite : MemoryEffects<[MemWrite<DefaultResource>]>;
// A vector type with the native vector-register shape. Accepted are:
//   * 8x128 vectors whose element type is 32 bits wide, and
//   * 8x128xP vectors where P == 32 / (element bitwidth).
// The element type is checked to be an integer or float *before* its width is
// queried: getIntOrFloatBitWidth() asserts on any other element type (e.g.
// `index`), and the verifier should reject such vectors rather than abort.
// Predicates inside And<> are joined with `&&`, so the guard short-circuits.
def LLO_Vector : Type<
  And<[IsVectorTypePred,
       CPred<"::llvm::cast<::mlir::VectorType>($_self).getElementType().isIntOrFloat()">,
       Or<[
         And<[
           CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>{8, 128}">,
           CPred<"::llvm::cast<::mlir::VectorType>($_self).getElementType().getIntOrFloatBitWidth() == 32">
         ]>,
         CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>{"
               "8, 128, 32 / ::llvm::cast<::mlir::VectorType>($_self).getElementType().getIntOrFloatBitWidth()}">,
       ]>
  ]>,
  "native-sized vreg", "::mlir::VectorType">;
// A vector type with one of the mask shapes 8x128, 8x128x2 or 8x128x4. The
// element type is not constrained by this predicate.
// Only some of those types are actually allowed, depending on the target.
// The summary string is non-empty so that generated verifier diagnostics
// ("... must be <summary>, but got ...") name the expected type.
def LLO_VectorMask : Type<
  And<[IsVectorTypePred,
       Or<[
         CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>{8, 128}">,
         CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>{8, 128, 2}">,
         CPred<"::llvm::cast<::mlir::VectorType>($_self).getShape() == ArrayRef<int64_t>{8, 128, 4}">,
       ]>
  ]>,
  "native-sized vreg mask", "::mlir::VectorType">;
// `llo.alloca_vmem`: takes a `num_words` attribute and returns an i32 result
// tagged with the VmemAlloc effect. Sets hasInstructionLowering to 0.
def LLO_AllocaVmemOp : LLO_Op<"alloca_vmem"> {
  let hasInstructionLowering = 0;
  let arguments = (ins I32Attr:$num_words);
  let results = (outs Res<I32, "", [VmemAlloc]>:$result);
  let assemblyFormat = [{
    $num_words attr-dict `:` type($result)
  }];
}
// `llo.alloca_smem`: takes a `num_words` attribute and returns an i32 result
// tagged with the SmemAlloc effect. Sets hasInstructionLowering to 0.
def LLO_AllocaSmemOp : LLO_Op<"alloca_smem"> {
  let hasInstructionLowering = 0;
  let arguments = (ins I32Attr:$num_words);
  let results = (outs Res<I32, "", [SmemAlloc]>:$result);
  let assemblyFormat = [{
    $num_words attr-dict `:` type($result)
  }];
}
// `llo.alloca_sflag`: takes a `num_flags` attribute and returns an i32. The
// builder snippet calls AllocateScopedSflags with that attribute's value.
def LLO_AllocaSyncFlagOp : LLO_Op<"alloca_sflag"> {
  let arguments = (ins I32Attr:$num_flags);
  let results = (outs I32:$result);
  let builderMethod =
      [{$_builder.AllocateScopedSflags($_self.getNumFlags())}];
  let assemblyFormat = [{
    $num_flags attr-dict `:` type($result)
  }];
}
// TODO(apaszke): Expose more cases
// I32-backed enum of gain latch modes. Builder snippets in this file
// static_cast the value to ::xla::jellyfish::GainLatchMode, so the numeric
// values presumably mirror that enum -- verify before renumbering.
def LLO_GainLatchMode : I32EnumAttr<"GainLatchMode", "Gain latch mode", [
    I32EnumAttrCase<"kNoXposeF32", 0, "f32">,
    I32EnumAttrCase<"kXposeF32", 1, "xpose.f32">,
    I32EnumAttrCase<"kNoXposeHiF32", 2, "hi.f32">,
    I32EnumAttrCase<"kXposeHiF32", 3, "xpose.hi.f32">,
    I32EnumAttrCase<"kNoXposeLowF32", 4, "low.f32">,
    I32EnumAttrCase<"kXposeLowF32", 5, "xpose.low.f32">,
    I32EnumAttrCase<"kNoXposeSoftMiddleEightF32", 6, "soft_middle_eight.f32">,
    I32EnumAttrCase<"kXposeSoftMiddleEightF32", 7, "xpose.soft_middle_eight.f32">,
    I32EnumAttrCase<"kNoXposeSoftLowEightF32", 8, "soft_low_eight.f32">,
    I32EnumAttrCase<"kXposeSoftLowEightF32", 9, "xpose.soft_low_eight.f32">,
    I32EnumAttrCase<"kXposePackedBf16", 10, "xpose.packed.bf16">,
    I32EnumAttrCase<"kNoXposePackedBf16", 11, "packed.bf16">,
    I32EnumAttrCase<"kNoXposeS8", 20, "s8">,
    I32EnumAttrCase<"kXposeS8", 21, "xpose.s8">,
  ]> {
  let cppNamespace = "::mlir::llo";
  // The attribute class is provided by the LLO_EnumAttr wrapper instead.
  let genSpecializedAttr = 0;
}
// Dialect attribute wrapping LLO_GainLatchMode; written as `<value>`.
def LLO_GainLatchModeEnum
    : LLO_EnumAttr<LLO_GainLatchMode, "gain_latch_mode"> {
  let lloEnumType = "::xla::jellyfish::GainLatchMode";
  let assemblyFormat = [{`<` $value `>`}];
}
// `llo.vlatch`: takes a vector, a gain latch mode and an MXU id (default 0);
// produces no results.
def LLO_VectorLatchOp : LLO_Op<"vlatch"> {
  let results = (outs);
  let arguments = (ins
    LLO_Vector:$input,
    LLO_GainLatchModeEnum:$gain_latch_mode,
    DefaultValuedAttr<I32Attr, "0">:$mxu_id
  );
  // Even though this is very close to the VectorLatch LLO instruction,
  // we use the builder to support non-native modes.
  let builderMethod = [{
    $_builder.Vlatch(
        GetLloValue($_self.getInput(), $_value_map),
        static_cast<::xla::jellyfish::GainLatchMode>(op.getGainLatchMode()),
        op.getMxuId());
  }];
}
// I32-backed enum of matrix multiplication modes. Builder snippets in this
// file static_cast the value to ::xla::jellyfish::MatmulMode.
def LLO_MatmulMode : I32EnumAttr<"MatmulMode", "Matrix multiplication mode", [
    I32EnumAttrCase<"kRound", 0, "round">,
    I32EnumAttrCase<"kHigh", 1, "high">,
    I32EnumAttrCase<"kLow", 2, "low">,
    I32EnumAttrCase<"kSoftMiddleEight", 3, "soft_middle_eight">,
    I32EnumAttrCase<"kSoftLowEight", 4, "soft_low_eight">,
  ]> {
  let cppNamespace = "::mlir::llo";
  let genSpecializedAttr = 0;
}
// Dialect attribute wrapping LLO_MatmulMode; written as `<value>`.
def LLO_MatmulModeEnum
    : LLO_EnumAttr<LLO_MatmulMode, "matmul_mode"> {
  let lloEnumType = "::xla::jellyfish::MatmulMode";
  let assemblyFormat = [{`<` $value `>`}];
}
// I32-backed enum of matmul data formats. Note the gap in the numbering
// (kS8 is 6). Builder snippets in this file static_cast the value to
// ::xla::jellyfish::MatmulDataFormat.
def LLO_MatmulDataFormat
    : I32EnumAttr<"MatmulDataFormat", "Matrix multiplication data format", [
    I32EnumAttrCase<"kInvalid", 0, "invalid">,
    I32EnumAttrCase<"kF32", 1, "f32">,
    I32EnumAttrCase<"kBf16", 2, "bf16">,
    I32EnumAttrCase<"kS8", 6, "s8">,
  ]> {
  let cppNamespace = "::mlir::llo";
  let genSpecializedAttr = 0;
}
// Dialect attribute wrapping LLO_MatmulDataFormat; written as `<value>`.
def LLO_MatmulDataFormatEnum
    : LLO_EnumAttr<LLO_MatmulDataFormat, "matmul_data_format"> {
  let lloEnumType = "::xla::jellyfish::MatmulDataFormat";
  let assemblyFormat = [{`<` $value `>`}];
}
// `llo.vmatmul`: takes a vector, a matmul mode, an MXU id (default 0), a
// `dwg` flag (default false) and a data format (default kF32); no results.
// The builder snippet picks between two Vmatmul calls: for kF32 it forwards
// the matmul mode, for every other data format it forwards the data format
// (and `mode` is not passed). `dwg` selects DoneWithGainsMode kNormal/kNone.
def LLO_VectorMatmulOp : LLO_Op<"vmatmul"> {
  let results = (outs);
  let arguments = (ins
    LLO_Vector:$input,
    LLO_MatmulModeEnum:$mode,
    DefaultValuedAttr<I32Attr, "0">:$mxu_id,
    DefaultValuedAttr<BoolAttr, "false">:$dwg,
    DefaultValuedAttr<LLO_MatmulDataFormatEnum, "MatmulDataFormat::kF32">:$data_format
  );
  // Even though this is very close to the VectorMatmul LLO instruction,
  // we use the builder to support non-native matmul modes.
  let builderMethod = [{
    $_self.getDataFormat() == MatmulDataFormat::kF32
      ? $_builder.Vmatmul(
          GetLloValue(op.getInput(), $_value_map),
          static_cast<::xla::jellyfish::MatmulMode>(op.getMode()),
          op.getDwg()
            ? xla::jellyfish::DoneWithGainsMode::kNormal
            : xla::jellyfish::DoneWithGainsMode::kNone,
          op.getMxuId())
      : $_builder.Vmatmul(
          GetLloValue(op.getInput(), $_value_map),
          static_cast<::xla::jellyfish::MatmulDataFormat>(op.getDataFormat()),
          op.getDwg()
            ? xla::jellyfish::DoneWithGainsMode::kNormal
            : xla::jellyfish::DoneWithGainsMode::kNone,
          op.getMxuId());
  }];
}
// GLC+ only.
// `llo.vmatprep.subr`: takes a vector, a gain latch mode, a data format
// (default kF32) and an MXU id (default 0); no results. The builder snippet
// feeds the `.first` member of VprepareForLatch's result into VmatprepSubr.
def LLO_VectorMatprepSubrOp : LLO_Op<"vmatprep.subr"> {
  let results = (outs);
  let arguments = (ins
    LLO_Vector:$input,
    LLO_GainLatchModeEnum:$gain_latch_mode,
    DefaultValuedAttr<LLO_MatmulDataFormatEnum, "MatmulDataFormat::kF32">:$data_format,
    DefaultValuedAttr<I32Attr, "0">:$mxu_id
  );
  let builderMethod = [{
    $_builder.VmatprepSubr(
        $_builder.VprepareForLatch(
            GetLloValue(op.getInput(), $_value_map),
            static_cast<::xla::jellyfish::GainLatchMode>(op.getGainLatchMode())
        ).first,
        static_cast<::xla::jellyfish::MatmulDataFormat>(op.getDataFormat()),
        op.getMxuId());
  }];
}
// GLC+ only.
// `llo.vmatprep.mubr`: takes a vector, a matmul mode, a data format (default
// kF32) and an MXU id (default 0); no results. The builder snippet feeds the
// `.first` member of VprepareForMatmul's result into VmatprepMubr.
def LLO_VectorMatprepMubrOp : LLO_Op<"vmatprep.mubr"> {
  let results = (outs);
  let arguments = (ins
    LLO_Vector:$input,
    LLO_MatmulModeEnum:$mode,
    DefaultValuedAttr<LLO_MatmulDataFormatEnum, "MatmulDataFormat::kF32">:$data_format,
    DefaultValuedAttr<I32Attr, "0">:$mxu_id
  );
  let builderMethod = [{
    $_builder.VmatprepMubr(
        $_builder.VprepareForMatmul(
            GetLloValue(op.getInput(), $_value_map),
            static_cast<::xla::jellyfish::MatmulMode>(op.getMode())).first,
        static_cast<::xla::jellyfish::MatmulDataFormat>(op.getDataFormat()),
        op.getMxuId());
  }];
}
// GLC+ only.
// `llo.vlatchi`: takes a vector, an integer `latch_variant`, a gain latch
// mode and an MXU id (default 0); no results. The builder snippet dispatches
// on `latch_variant`: 3 -> Vlatch3, 2 -> Vlatch2, anything else -> Vlatch1.
// A hand-written verifier is requested via hasVerifier (not visible here).
def LLO_VectorLatchIOp : LLO_Op<"vlatchi"> {
  let hasVerifier = 1;
  let results = (outs);
  let arguments = (ins
    LLO_Vector:$input,
    I32Attr:$latch_variant,
    LLO_GainLatchModeEnum:$gain_latch_mode,
    DefaultValuedAttr<I32Attr, "0">:$mxu_id
  );
  let builderMethod = [{
    $_self.getLatchVariant() == 3
      ? $_builder.Vlatch3(
          GetLloValue($_self.getInput(), $_value_map),
          static_cast<::xla::jellyfish::GainLatchMode>(op.getGainLatchMode()),
          op.getMxuId())
      : ($_self.getLatchVariant() == 2
          ? $_builder.Vlatch2(
              GetLloValue($_self.getInput(), $_value_map),
              static_cast<::xla::jellyfish::GainLatchMode>(op.getGainLatchMode()),
              op.getMxuId())
          : $_builder.Vlatch1(
              GetLloValue($_self.getInput(), $_value_map),
              static_cast<::xla::jellyfish::GainLatchMode>(op.getGainLatchMode()),
              op.getMxuId())
        );
  }];
}
// GLC+ only.
// `llo.vmatmul.mubr`: takes a vector, a matmul mode, an MXU id (default 0),
// a `dwg` flag (default false) and a data format (default kF32); no results.
// The builder snippet feeds the `.first` member of VprepareForMatmul's result
// into VmatmulMubr and always passes std::nullopt for mrb_addr.
def LLO_VectorMatmulMubrOp : LLO_Op<"vmatmul.mubr"> {
  let results = (outs);
  let arguments = (ins
    LLO_Vector:$input,
    LLO_MatmulModeEnum:$mode,
    DefaultValuedAttr<I32Attr, "0">:$mxu_id,
    DefaultValuedAttr<BoolAttr, "false">:$dwg,
    DefaultValuedAttr<LLO_MatmulDataFormatEnum, "MatmulDataFormat::kF32">:$data_format
  );
  let builderMethod = [{
    $_builder.VmatmulMubr(
        $_builder.VprepareForMatmul(
            GetLloValue(op.getInput(), $_value_map),
            static_cast<::xla::jellyfish::MatmulMode>(op.getMode())).first,
        static_cast<::xla::jellyfish::MatmulDataFormat>(op.getDataFormat()),
        op.getDwg()
          ? xla::jellyfish::DoneWithGainsMode::kNormal
          : xla::jellyfish::DoneWithGainsMode::kNone,
        op.getMxuId(),
        /*mrb_addr=*/std::nullopt);
  }];
}
// TODO(apaszke): Add support for the data format attribute
// NOTE(review): `data_format` is already declared and forwarded below, so the
// TODO above may be stale -- confirm.
// `llo.vmatres`: takes a data format (default kF32) and an MXU id (default 0)
// and returns a vector. The lowering always passes 0 for mrb_addr.
def LLO_VectorMatresOp : LLO_Op<"vmatres"> {
  let results = (outs LLO_Vector:$result);
  let arguments = (ins
    DefaultValuedAttr<LLO_MatmulDataFormatEnum, "MatmulDataFormat::kF32">:$data_format,
    DefaultValuedAttr<I32Attr, "0">:$mxu_id
  );
  let customInstructionLowering = [{
    return ::xla::jellyfish::LloInstruction::CreateVectorMatres(
        static_cast<::xla::jellyfish::MatmulDataFormat>(op.getDataFormat()),
        /*mrb_addr=*/0,
        op.getMxuId(), b.region());
  }];
}
// TODO(apaszke): Support the transposed variant
// `llo.vdwg`: takes only an MXU id (default 0); no results. The lowering
// always uses DoneWithGainsMode::kNormal.
def LLO_VectorDoneWithGainsOp : LLO_Op<"vdwg"> {
  let results = (outs);
  let arguments = (ins
    DefaultValuedAttr<I32Attr, "0">:$mxu_id
  );
  let customInstructionLowering = [{
    return ::xla::jellyfish::LloInstruction::CreateVectorDoneWithGains(
        ::xla::jellyfish::DoneWithGainsMode::kNormal,
        op.getMxuId(),
        b.region());
  }];
}
// I32-backed enum of transpose modes.
def LLO_VxposeMode : I32EnumAttr<"VxposeMode", "Transpose mode", [
    I32EnumAttrCase<"kB32", 0, "b32">,
    I32EnumAttrCase<"kCompressedB16", 1, "b16.c">,
    I32EnumAttrCase<"kCompressedB8", 2, "b8.c">,
  ]> {
  let cppNamespace = "::mlir::llo";
  let genSpecializedAttr = 0;
}
// Dialect attribute wrapping LLO_VxposeMode; written as `<value>`.
def LLO_VxposeModeEnum
    : LLO_EnumAttr<LLO_VxposeMode, "vxpose_mode"> {
  let lloEnumType = "::xla::jellyfish::VxposeMode";
  let assemblyFormat = [{`<` $value `>`}];
}
// `llo.vxpose`: takes a vector, a transpose mode, an i32 `width` operand,
// chunk id / chunk count attributes, an XLU id (default 0) and an optional
// source bus; no results. No custom format or lowering snippet is set here.
def LLO_VectorTransposeOp : LLO_Op<"vxpose"> {
  let results = (outs);
  let arguments = (ins
    LLO_Vector:$input,
    LLO_VxposeModeEnum:$mode,
    I32:$width,
    I32Attr:$chunk_id,
    I32Attr:$number_of_chunks,
    DefaultValuedAttr<I32Attr, "0">:$xlu_id,
    OptionalAttr<I32Attr>:$source_bus
  );
}
// `llo.vxpose.c.b16`: like the unary transpose but with two vector operands
// (`first`, `second`) and no mode attribute; no results.
def LLO_VectorTransposeBinaryCompressedB16Op : LLO_Op<"vxpose.c.b16"> {
  let results = (outs);
  let arguments = (ins
    LLO_Vector:$first,
    LLO_Vector:$second,
    I32:$width,
    I32Attr:$chunk_id,
    I32Attr:$number_of_chunks,
    DefaultValuedAttr<I32Attr, "0">:$xlu_id,
    OptionalAttr<I32Attr>:$source_bus
  );
}
// `llo.vxpose.result`: takes an XLU id (default 0) and returns a vector.
def LLO_VectorTransposeResultOp : LLO_Op<"vxpose.result"> {
  let results = (outs LLO_Vector:$result);
  let arguments = (ins
    DefaultValuedAttr<I32Attr, "0">:$xlu_id
  );
}
// `llo.saddr_scaled`: takes an i32 address, an i32 offset in words and a
// `multiplier_in_bytes` attribute; returns an i32. The builder snippet
// forwards all three to AddrScaled.
def LLO_AddrScaledOp : LLO_Op<"saddr_scaled", [SameOperandsAndResultType]> {
  let results = (outs I32:$result);
  let arguments = (ins
    I32:$address,
    I32:$offset_in_words,
    I32Attr:$multiplier_in_bytes
  );
  let builderMethod = [{
    $_builder.AddrScaled(
        GetLloValue(op.getAddress(), $_value_map),
        GetLloValue(op.getOffsetInWords(), $_value_map),
        op.getMultiplierInBytes());
  }];
}
// Shared definition of the `llo.saddr.<space>` ops. Each op takes an i32
// address and an i32 offset in words and returns an i32; the lowering calls
// LloInstruction::CreateScalarAddress with the given memory space.
//   space_mnemonic: suffix appended to "saddr." to form the op mnemonic.
//   memory_space:   enumerator name inside ::xla::jellyfish::MemorySpace.
class LLO_ScalarAddressOp<string space_mnemonic, string memory_space>
    : LLO_Op<"saddr." # space_mnemonic, [SameOperandsAndResultType]> {
  let arguments = (ins
    I32:$address,
    I32:$offset_in_words
  );
  let results = (outs I32:$result);
  let customInstructionLowering = [{
    return ::xla::jellyfish::LloInstruction::CreateScalarAddress(
        GetLloValue(op.getAddress(), $_value_map),
        GetLloValue(op.getOffsetInWords(), $_value_map),
        ::xla::jellyfish::MemorySpace::}] # memory_space # [{,
        b.region());
  }];
}

// One op per memory space; def names and mnemonics are unchanged.
def LLO_ScalarAddressVmemOp : LLO_ScalarAddressOp<"vmem", "kVmem">;
def LLO_ScalarAddressSmemOp : LLO_ScalarAddressOp<"smem", "kSmem">;
def LLO_ScalarAddressSflagOp : LLO_ScalarAddressOp<"sflag", "kSflag">;
def LLO_ScalarAddressCmemOp : LLO_ScalarAddressOp<"cmem", "kCmem">;
def LLO_ScalarAddressHbmOp : LLO_ScalarAddressOp<"hbm", "kHbm">;
// `llo.vector_load_slane_stride` (declares the VmemRead effect): takes an
// i32 address, sublane stride / count attributes and an optional ui32
// sublane mask. The builder snippet passes nullptr when no mask is present.
// NOTE(review): the result is AnyType here, whereas `vector_load` returns
// LLO_Vector -- confirm this is intentional.
def LLO_VldWithArbitrarySlaneStrideOp
    : LLO_Op<"vector_load_slane_stride", [VmemRead]> {
  let results = (outs AnyType:$result);
  let arguments = (ins
    I32:$address,
    I64Attr:$sublane_stride,  // In sublane-sized units
    I64Attr:$sublane_count,
    Optional<UI32>:$sublane_mask
  );
  let builderMethod = [{
    $_builder.LdWithArbitrarySlaneStride(
        GetLloValue(op.getAddress(), $_value_map),
        op.getSublaneStride(),
        op.getSublaneCount(),
        op.getSublaneMask() ? GetLloValue(op.getSublaneMask(), $_value_map) : nullptr);
  }];
  let assemblyFormat = [{
    $address (`masked` $sublane_mask^)? attr-dict `:` type($address) `->` type($result)
  }];
}
// `llo.vector_store_slane_stride` (declares the VmemWrite effect): takes an
// i32 address, the vector to store, sublane stride / count attributes and an
// optional ui32 sublane mask; no results. The builder snippet passes nullptr
// when no mask is present.
def LLO_VstWithArbitrarySlaneStrideOp
    : LLO_Op<"vector_store_slane_stride", [VmemWrite]> {
  let results = (outs);
  let arguments = (ins
    I32:$address,
    LLO_Vector:$to_store,
    I64Attr:$sublane_stride,  // In sublane-sized units
    I64Attr:$sublane_count,
    Optional<UI32>:$sublane_mask
  );
  let builderMethod = [{
    $_builder.VstWithArbitrarySlaneStride(
        GetLloValue(op.getAddress(), $_value_map),
        GetLloValue(op.getToStore(), $_value_map),
        op.getSublaneStride(),
        op.getSublaneCount(),
        op.getSublaneMask() ? GetLloValue(op.getSublaneMask(), $_value_map) : nullptr);
  }];
  let assemblyFormat = [{
    $to_store `into` $address (`masked` $sublane_mask^)? attr-dict `:`
    type($to_store) `into` type($address)
  }];
}
// `llo.vector_store_masked_slane_stride` (declares the VmemWrite effect):
// same as the unmasked slane-stride store plus a mandatory vector mask; the
// optional ui32 sublane mask is spelled `slmask` in the assembly.
def LLO_VstMaskedWithArbitrarySlaneStrideOp
    : LLO_Op<"vector_store_masked_slane_stride", [VmemWrite]> {
  let results = (outs);
  let arguments = (ins
    I32:$address,
    LLO_VectorMask:$mask,
    LLO_Vector:$to_store,
    I64Attr:$sublane_stride,  // In sublane-sized units
    I64Attr:$sublane_count,
    Optional<UI32>:$sublane_mask
  );
  let builderMethod = [{
    $_builder.VstMaskedWithArbitrarySlaneStride(
        GetLloValue(op.getAddress(), $_value_map),
        GetLloValue(op.getMask(), $_value_map),
        GetLloValue(op.getToStore(), $_value_map),
        op.getSublaneStride(),
        op.getSublaneCount(),
        op.getSublaneMask() ? GetLloValue(op.getSublaneMask(), $_value_map) : nullptr);
  }];
  let assemblyFormat = [{
    $to_store `into` $address `masked` $mask (`slmask` $sublane_mask^)? attr-dict `:`
    type($to_store) `into` type($address) `,` type($mask)
  }];
}
// `llo.vector_load` (declares the VmemRead effect): takes an i32 address, an
// optional i32 displacement, an optional ui32 sublane mask and two stride
// attributes (both default 1); returns a vector. AttrSizedOperandSegments
// is needed because two operands are optional.
def LLO_VectorLoadOp
    : LLO_Op<"vector_load", [AttrSizedOperandSegments, VmemRead]> {
  let results = (outs LLO_Vector:$result);
  let arguments = (ins
    I32:$address,
    Optional<I32>:$displacement,  // In sublane-sized units
    Optional<UI32>:$sublane_mask,
    DefaultValuedAttr<I64Attr, "1">:$sublane_stride,  // In sublane-sized units
    DefaultValuedAttr<I64Attr, "1">:$sublanes_per_stride
  );
  let assemblyFormat = [{
    $address (`+` $displacement^)? (`masked` $sublane_mask^)? attr-dict `:`
    type($address) `->` type($result)
  }];
}
// `llo.vector_store` (declares the VmemWrite effect): takes an i32 address,
// an optional i32 displacement, the vector to store, an optional ui32
// sublane mask and two stride attributes (both default 1); no results.
def LLO_VectorStoreOp
    : LLO_Op<"vector_store", [AttrSizedOperandSegments, VmemWrite]> {
  let results = (outs);
  let arguments = (ins
    I32:$address,
    Optional<I32>:$displacement,  // In sublane-sized units
    LLO_Vector:$to_store,
    Optional<UI32>:$sublane_mask,
    DefaultValuedAttr<I64Attr, "1">:$sublane_stride,  // In sublane-sized units
    DefaultValuedAttr<I64Attr, "1">:$sublanes_per_stride
  );
  let assemblyFormat = [{
    $to_store `into` $address (`+` $displacement^)? (`masked` $sublane_mask^)? attr-dict `:`
    type($to_store) `into` type($address)
  }];
}
// `llo.vector_store_masked` (declares the VmemWrite effect): like
// `vector_store` but with a mandatory vector mask; the optional ui32 sublane
// mask is spelled `slmask` in the assembly. No results.
def LLO_VectorStoreMaskedOp
    : LLO_Op<"vector_store_masked", [AttrSizedOperandSegments, VmemWrite]> {
  let results = (outs);
  let arguments = (ins
    I32:$address,
    Optional<I32>:$displacement,  // In sublane-sized units
    LLO_VectorMask:$mask,
    LLO_Vector:$to_store,
    DefaultValuedAttr<I64Attr, "1">:$sublane_stride,  // In sublane-sized units
    DefaultValuedAttr<I64Attr, "1">:$sublanes_per_stride,
    Optional<UI32>:$sublane_mask
  );
  let assemblyFormat = [{
    $to_store `into` $address (`+` $displacement^)? `masked` $mask (`slmask` $sublane_mask^)? attr-dict `:`
    type($to_store) `into` type($address) `,` type($mask)
  }];
}
// `llo.constant` (Pure): carries an arbitrary `value` attribute and returns
// a result of any type. The value is printed as part of attr-dict. Sets
// hasInstructionLowering to 0.
def LLO_ConstantOp : LLO_Op<"constant", [Pure]> {
  let hasInstructionLowering = 0;
  let arguments = (ins AnyAttr:$value);
  let results = (outs AnyType:$result);
  let assemblyFormat = [{attr-dict `:` type($result)}];
}
// Base class for cross-lane reductions (Pure, operand and result share a
// type). `instr` is pasted after `LloInstruction::Create` to name the factory
// that issues the reduction; its result is wrapped in CreateVectorXlaneResult.
// NOTE(review): xlu_id / source_bus are I64 attributes here but I32 on the
// transpose and permute ops, and source_bus is converted to
// std::optional<int32_t> below -- confirm the width is intentional.
class LLO_CrossLaneReductionOp<string name, string instr>
    : LLO_Op<name, [Pure, SameOperandsAndResultType]> {
  let results = (outs LLO_Vector:$result);
  let arguments = (ins
    LLO_Vector:$input,
    DefaultValuedAttr<I64Attr, "0">:$xlu_id,
    OptionalAttr<I64Attr>:$source_bus
  );
  let customInstructionLowering = [{
    return ::xla::jellyfish::LloInstruction::CreateVectorXlaneResult(
        b.Instruction(::xla::jellyfish::LloInstruction::Create}] # instr # [{(
            GetLloValue(op.getInput(), $_value_map),
            op.getXluId(),
            op.getSourceBus()
              ? std::optional<int32_t>(op.getSourceBus().value())
              : std::nullopt,
            b.region())),
        op.getXluId(), b.region());
  }];
  let assemblyFormat = [{
    $input attr-dict `:` type($input)
  }];
}
// Cross-lane f32 reductions (sum and max).
def LLO_VectorAddReduceF32Op
    : LLO_CrossLaneReductionOp<"vadd.xlane.f32", "VectorAddReduceF32">;
def LLO_VectorMaxReduceF32Op
    : LLO_CrossLaneReductionOp<"vmax.xlane.f32", "VectorMaxReduceF32">;
// Base class for sublane reductions (Pure, operand and result share a type).
// The builder snippet calls `$_builder.Vslane<method_suffix>` on the single
// operand.
class LLO_SublaneReductionOp<string name, string method_suffix>
    : LLO_Op<name, [Pure, SameOperandsAndResultType]> {
  let results = (outs LLO_Vector:$result);
  let arguments = (ins
    LLO_Vector:$input
  );
  let builderMethod = !strconcat(
      "$_builder.Vslane", method_suffix,
      "(GetLloValue(op.getOperand(), $_value_map))");
  let assemblyFormat = [{
    $input attr-dict `:` type($input)
  }];
}
// Sublane f32 reductions (sum and max).
def LLO_VectorAddSublaneReduceF32Op
    : LLO_SublaneReductionOp<"vadd.slane.f32", "SumF32">;
def LLO_VectorMaxSublaneReduceF32Op
    : LLO_SublaneReductionOp<"vmax.slane.f32", "MaxF32">;
// `llo.vrot.lane` (Pure, operand and result share a type): takes a vector, a
// signed `amount` attribute and an optional signed `stride` attribute. The
// builder snippet wraps Vrotate in Vpermuteres, materializes amount/stride
// via SimmS32, passes std::nullopt when no stride is given, and always uses
// BitDataFormat::kB32.
def LLO_VectorRotateOp : LLO_Op<"vrot.lane", [Pure, SameOperandsAndResultType]> {
  let results = (outs LLO_Vector:$result);
  let arguments = (ins
    LLO_Vector:$input,
    SI32Attr:$amount,
    OptionalAttr<SI32Attr>:$stride
  );
  let builderMethod = [{$_builder.Vpermuteres($_builder.Vrotate(
      GetLloValue(op.getInput(), $_value_map),
      b.SimmS32(op.getAmount()),
      op.getStride()
        ? std::optional<::xla::jellyfish::LloValue*>(
            b.SimmS32(op.getStride().value()))
        : std::nullopt,
      ::xla::jellyfish::BitDataFormat::kB32));
  }];
  let assemblyFormat = [{
    $input `,` $amount attr-dict (`stride` $stride^)? `:` type($input)
  }];
}
// `llo.vrot.slane` (Pure, operand and result share a type): takes a vector
// and an `amount` attribute; the builder snippet calls VslaneRotateTZ.
def LLO_VectorSublaneRotateOp : LLO_Op<"vrot.slane", [Pure, SameOperandsAndResultType]> {
  let results = (outs LLO_Vector:$result);
  let arguments = (ins
    LLO_Vector:$input,
    I32Attr:$amount
  );
  let builderMethod = [{
    $_builder.VslaneRotateTZ(GetLloValue($_self.getInput(), $_value_map), $_self.getAmount())
  }];
  let assemblyFormat = [{
    $input `,` $amount attr-dict `:` type($input)
  }];
}
// `llo.vcmask` (Pure): builds a vector mask from constant sublane and lane
// ranges; both range ends are inclusive. The builder snippet forwards the
// four bounds to CreateVmask.
def LLO_VectorCreateMaskOp : LLO_Op<"vcmask", [Pure]> {
  let results = (outs LLO_VectorMask:$result);
  let arguments = (ins
    I64Attr:$sublane_start,
    I64Attr:$sublane_end,  // inclusive
    I64Attr:$lane_start,
    I64Attr:$lane_end  // inclusive
  );
  let builderMethod = [{
    $_builder.CreateVmask(op.getSublaneStart(), op.getSublaneEnd(),
                          op.getLaneStart(), op.getLaneEnd())
  }];
  let assemblyFormat = [{
    `[` $sublane_start `:` $sublane_end `]` `[` $lane_start `:` $lane_end `]`
    attr-dict `:` type($result)
  }];
}
// `llo.vsmask` (Pure): takes an i32 `limit` operand and returns a vector
// mask. No lowering snippet is set here.
def LLO_VectorCreateSublaneMaskOp : LLO_Op<"vsmask", [Pure]> {
  let results = (outs LLO_VectorMask:$result);
  let arguments = (ins I32:$limit);
  let assemblyFormat = [{
    $limit attr-dict `:` type($limit) `->` type($result)
  }];
}
// `llo.vlmask` (Pure): takes an i32 `limit` operand and returns a vector
// mask. No lowering snippet is set here.
def LLO_VectorCreateLaneMaskOp : LLO_Op<"vlmask", [Pure]> {
  let results = (outs LLO_VectorMask:$result);
  let arguments = (ins I32:$limit);
  let assemblyFormat = [{
    $limit attr-dict `:` type($limit) `->` type($result)
  }];
}
// Base class for binary ops on vector masks. `opcode` is pasted after
// `LloOpcode::` in the CreateVectorMaskBinop call. By default the op is Pure
// and all operands share the result type.
class LLO_VectorMaskBinOp<string name, string opcode,
                          list<Trait> traits = [Pure, SameOperandsAndResultType]>
    : LLO_Op<name, traits> {
  let results = (outs LLO_VectorMask:$result);
  let arguments = (ins
    LLO_VectorMask:$lhs,
    LLO_VectorMask:$rhs
  );
  let customInstructionLowering = [{
    return ::xla::jellyfish::LloInstruction::CreateVectorMaskBinop(
        ::xla::jellyfish::LloOpcode::}] # opcode # [{,
        GetLloValue($_self->getOperand(0), $_value_map),
        GetLloValue($_self->getOperand(1), $_value_map),
        $_region);
  }];
  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` type($result)
  }];
}
// Binary mask ops: and / or / xor.
def LLO_VectorMaskAndOp
    : LLO_VectorMaskBinOp<"vmand", "kVectorMaskAnd">;
def LLO_VectorMaskOrOp
    : LLO_VectorMaskBinOp<"vmor", "kVectorMaskOr">;
def LLO_VectorMaskXorOp
    : LLO_VectorMaskBinOp<"vmxor", "kVectorMaskXor">;
// Base class for unary ops on vector masks. `opcode` is pasted after
// `LloOpcode::` in the CreateVectorMaskUnop call. By default the op is Pure
// and the operand shares the result type.
class LLO_VectorMaskUnOp<string name, string opcode,
                         list<Trait> traits = [Pure, SameOperandsAndResultType]>
    : LLO_Op<name, traits> {
  let results = (outs LLO_VectorMask:$result);
  let arguments = (ins LLO_VectorMask:$operand);
  let customInstructionLowering = [{
    return ::xla::jellyfish::LloInstruction::CreateVectorMaskUnop(
        ::xla::jellyfish::LloOpcode::}] # opcode # [{,
        GetLloValue($_self->getOperand(0), $_value_map),
        $_region);
  }];
  let assemblyFormat = [{
    $operand attr-dict `:` type($result)
  }];
}
// Unary mask op: negate.
def LLO_VectorMaskNegateOp
    : LLO_VectorMaskUnOp<"vmneg", "kVectorMaskNegate">;
// `llo.vselect` (Pure): takes a vector mask and two vectors (`if_true`,
// `if_false`) and returns a vector. All four types are spelled out in the
// assembly; no lowering snippet is set here.
def LLO_VectorSelectOp : LLO_Op<"vselect", [Pure]> {
  let results = (outs LLO_Vector:$result);
  let arguments = (ins
    LLO_VectorMask:$mask,
    LLO_Vector:$if_true,
    LLO_Vector:$if_false
  );
  let assemblyFormat = [{
    $mask `,` $if_true `,` $if_false attr-dict `:`
    `(` type($mask) `,` type($if_true) `,` type($if_false) `)` `->` type($result)
  }];
}
// Base class for binary vector ops. `opcode` is pasted after `LloOpcode::`
// in the CreateVectorBinop call. By default the op is Pure and all operands
// share the result type.
class LLO_VectorBinOp<string name, string opcode,
                      list<Trait> traits = [Pure, SameOperandsAndResultType]>
    : LLO_Op<name, traits> {
  let results = (outs LLO_Vector:$result);
  let arguments = (ins
    LLO_Vector:$lhs,
    LLO_Vector:$rhs
  );
  let customInstructionLowering = [{
    return ::xla::jellyfish::LloInstruction::CreateVectorBinop(
        ::xla::jellyfish::LloOpcode::}] # opcode # [{,
        GetLloValue($_self->getOperand(0), $_value_map),
        GetLloValue($_self->getOperand(1), $_value_map),
        $_region);
  }];
  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` type($result)
  }];
}
// Variant of LLO_VectorBinOp whose default traits drop
// SameOperandsAndResultType; consequently the assembly spells out the type of
// each operand and of the result.
class LLO_VectorBinOpAnyType<string name, string opcode,
                             list<Trait> traits = [Pure]>
    : LLO_VectorBinOp<name, opcode, traits> {
  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
  }];
}
// f32 arithmetic lowered directly through an LLO opcode.
def LLO_VectorAddF32Op
    : LLO_VectorBinOp<"vadd.f32", "kVectorAddF32">;
def LLO_VectorSubF32Op
    : LLO_VectorBinOp<"vsub.f32", "kVectorSubtractF32">;
def LLO_VectorMulF32Op
    : LLO_VectorBinOp<"vmul.f32", "kVectorMultiplyF32">;
def LLO_VectorMaxF32Op
    : LLO_VectorBinOp<"vmax.f32", "kVectorMaximumF32">;
def LLO_VectorMinF32Op
    : LLO_VectorBinOp<"vmin.f32", "kVectorMinimumF32">;
def LLO_VectorPowF32Op
    : LLO_VectorBinOp<"vpow.f32", "kVectorPowF32">;

// s32 add/sub lowered directly through an LLO opcode.
def LLO_VectorAddS32Op
    : LLO_VectorBinOp<"vadd.s32", "kVectorAddS32">;
def LLO_VectorSubS32Op
    : LLO_VectorBinOp<"vsub.s32", "kVectorSubtractS32">;
// `llo.vmul.s32`: the opcode argument is a placeholder ("UNUSED"); the op
// sets a builder snippet instead, which calls VmulU32 on both operands.
def LLO_VectorMulS32Op
    : LLO_VectorBinOp<"vmul.s32", "UNUSED"> {
  let builderMethod = [{
    $_builder.VmulU32(GetLloValue(op.getLhs(), $_value_map),
                      GetLloValue(op.getRhs(), $_value_map))
  }];
}
// `llo.vdiv.s32`: placeholder opcode; the builder snippet calls VdivS32 on
// both operands.
def LLO_VectorDivS32Op
    : LLO_VectorBinOp<"vdiv.s32", "UNUSED"> {
  let builderMethod = [{
    $_builder.VdivS32(GetLloValue(op.getLhs(), $_value_map),
                      GetLloValue(op.getRhs(), $_value_map))
  }];
}
// `llo.vrem.s32`: placeholder opcode; the builder snippet calls VremS32 on
// both operands.
def LLO_VectorRemS32Op
    : LLO_VectorBinOp<"vrem.s32", "UNUSED"> {
  let builderMethod = [{
    $_builder.VremS32(GetLloValue(op.getLhs(), $_value_map),
                      GetLloValue(op.getRhs(), $_value_map))
  }];
}
// Shifts: operand and result types are not tied together.
def LLO_VectorShiftLeftLogicalOp
    : LLO_VectorBinOpAnyType<"vshll", "kVectorShiftLeftLogical">;
def LLO_VectorShiftRightLogicalOp
    : LLO_VectorBinOpAnyType<"vshrl", "kVectorShiftRightLogical">;
def LLO_VectorShiftRightArithmeticOp
    : LLO_VectorBinOpAnyType<"vshra", "kVectorShiftRightArithmetic">;

// Bitwise u32 ops.
def LLO_VectorAndU32Op
    : LLO_VectorBinOp<"vand.u32", "kVectorAndU32">;
def LLO_VectorOrU32Op
    : LLO_VectorBinOp<"vor.u32", "kVectorOrU32">;
def LLO_VectorXOrU32Op
    : LLO_VectorBinOp<"vxor.u32", "kVectorXorU32">;
// `llo.vmax.s32`: placeholder opcode; the builder snippet calls VmaxS32 on
// both operands.
def LLO_VectorMaxS32Op
    : LLO_VectorBinOp<"vmax.s32", "UNUSED"> {
  let builderMethod = [{
    $_builder.VmaxS32(GetLloValue(op.getLhs(), $_value_map),
                      GetLloValue(op.getRhs(), $_value_map))
  }];
}
// `llo.vmin.s32`: placeholder opcode; the builder snippet calls VminS32 on
// both operands.
def LLO_VectorMinS32Op
    : LLO_VectorBinOp<"vmin.s32", "UNUSED"> {
  let builderMethod = [{
    $_builder.VminS32(GetLloValue(op.getLhs(), $_value_map),
                      GetLloValue(op.getRhs(), $_value_map))
  }];
}
// Base class for unary vector ops (Pure, operand and result share a type).
// `opcode` is pasted after `LloOpcode::` in the CreateVectorUnop call.
class LLO_VectorUnOp<string name, string opcode>
    : LLO_Op<name, [Pure, SameOperandsAndResultType]> {
  let results = (outs LLO_Vector:$result);
  let arguments = (ins LLO_Vector:$operand);
  let customInstructionLowering = [{
    return ::xla::jellyfish::LloInstruction::CreateVectorUnop(
        ::xla::jellyfish::LloOpcode::}] # opcode # [{,
        GetLloValue($_self->getOperand(0), $_value_map),
        $_region);
  }];
  let assemblyFormat = [{
    $operand attr-dict `:` type($result)
  }];
}
// Unary f32 ops that pass a placeholder opcode ("unused") and set a builder
// snippet instead.
def LLO_VectorRsqrtF32Op
    : LLO_VectorUnOp<"vrsqrt.f32", "unused"> {
  let builderMethod =
      [{$_builder.VrsqrtNr(GetLloValue(op.getOperand(), $_value_map))}];
}
def LLO_VectorRecipF32Op
    : LLO_VectorUnOp<"vrecip.f32", "unused"> {
  let builderMethod =
      [{$_builder.VrecpNr(GetLloValue(op.getOperand(), $_value_map))}];
}
def LLO_VectorSqrtF32Op
    : LLO_VectorUnOp<"vsqrt.f32", "unused"> {
  let builderMethod =
      [{$_builder.VsqrtF32(GetLloValue(op.getOperand(), $_value_map))}];
}
def LLO_VectorExpF32Op
    : LLO_VectorUnOp<"vexp.f32", "unused"> {
  let builderMethod =
      [{$_builder.Vexp(GetLloValue(op.getOperand(), $_value_map))}];
}
// vpow2.f32 is the one op in this group lowered directly through an opcode.
def LLO_VectorPow2F32Op
    : LLO_VectorUnOp<"vpow2.f32", "kVectorPow2F32AndPop">;

// The remaining ops pass a placeholder opcode and set a builder snippet.
def LLO_VectorCosF32Op
    : LLO_VectorUnOp<"vcos.f32", "unused"> {
  let builderMethod =
      [{$_builder.Vcos(GetLloValue(op.getOperand(), $_value_map))}];
}
def LLO_VectorSinF32Op
    : LLO_VectorUnOp<"vsin.f32", "unused"> {
  let builderMethod =
      [{$_builder.Vsin(GetLloValue(op.getOperand(), $_value_map))}];
}
// Vtanh additionally takes the element type (xla::F32) as first argument.
def LLO_VectorTanhF32Op
    : LLO_VectorUnOp<"vtanh.f32", "unused"> {
  let builderMethod =
      [{$_builder.Vtanh(xla::F32, GetLloValue(op.getOperand(), $_value_map))}];
}
def LLO_VectorLogF32Op
    : LLO_VectorUnOp<"vln.f32", "unused"> {
  let builderMethod =
      [{$_builder.Vln(GetLloValue(op.getOperand(), $_value_map))}];
}
def LLO_VectorLog1pF32Op
    : LLO_VectorUnOp<"vln1p.f32", "unused"> {
  let builderMethod =
      [{$_builder.Vln1p(GetLloValue(op.getOperand(), $_value_map))}];
}
// Negation, absolute value and rounding; all pass a placeholder opcode and
// set a builder snippet.
def LLO_VectorNegF32Op
    : LLO_VectorUnOp<"vneg.f32", "unused"> {
  let builderMethod =
      [{$_builder.VnegF32(GetLloValue(op.getOperand(), $_value_map))}];
}
def LLO_VectorNegS32Op
    : LLO_VectorUnOp<"vneg.s32", "UNUSED"> {
  let builderMethod =
      [{$_builder.VnegS32(GetLloValue(op.getOperand(), $_value_map))}];
}
def LLO_VectorAbsF32Op
    : LLO_VectorUnOp<"vabs.f32", "unused"> {
  let builderMethod =
      [{$_builder.VabsF32(GetLloValue(op.getOperand(), $_value_map))}];
}
def LLO_VectorRoundF32Op
    : LLO_VectorUnOp<"vround.f32", "unused"> {
  let builderMethod =
      [{$_builder.VroundF32(GetLloValue(op.getOperand(), $_value_map))}];
}
def LLO_VectorRoundEvenF32Op
    : LLO_VectorUnOp<"vround.nearest_even.f32", "unused"> {
  let builderMethod =
      [{$_builder.VroundNearestEvenF32(GetLloValue(op.getOperand(), $_value_map))}];
}
// Base class for vector comparisons (Pure): two vectors in, a vector mask
// out. The lowering builds an ::xla::Comparison from three pasted names:
//   direction: enumerator of ::xla::ComparisonDirection,
//   xlaType:   name inside ::xla (element type),
//   order:     enumerator of ::xla::ComparisonOrder.
class LLO_VectorCompareOp<string name, string direction, string xlaType, string order>
    : LLO_Op<name, [Pure]> {
  let results = (outs LLO_VectorMask:$result);
  let arguments = (ins
    LLO_Vector:$lhs,
    LLO_Vector:$rhs
  );
  let customInstructionLowering = [{
    return ::xla::jellyfish::LloInstruction::CreateVectorCompare(
        GetLloValue($_self->getOperand(0), $_value_map),
        GetLloValue($_self->getOperand(1), $_value_map),
        ::xla::Comparison(
            ::xla::ComparisonDirection::}] # direction # [{,
            ::xla::}] # xlaType # [{,
            ::xla::ComparisonOrder::}] # order # [{),
        $_region);
  }];
  let assemblyFormat = [{
    $lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
  }];
}
// One s32 compare op per direction, e.g. LLO_VectorCmpEqS32Op with mnemonic
// "vcmp.eq.s32"; these use the kTotal comparison order.
foreach dir = cmp_directions in {
  def LLO_VectorCmp#dir#S32Op
      : LLO_VectorCompareOp<"vcmp." # !tolower(dir) # ".s32",
                            "k" # dir, "S32", "kTotal">;
}

// One f32 compare op per direction, e.g. LLO_VectorCmpEqF32Op with mnemonic
// "vcmp.eq.f32"; these use the kPartial comparison order.
foreach dir = cmp_directions in {
  def LLO_VectorCmp#dir#F32Op
      : LLO_VectorCompareOp<"vcmp." # !tolower(dir) # ".f32",
                            "k" # dir, "F32", "kPartial">;
}
// Base class for vector conversions (Pure). Operand and result types are
// independent and both appear in the assembly. `opcode` is pasted after
// `LloOpcode::` in the CreateVectorUnop call.
class LLO_VectorConvertOp<string name, string opcode>
    : LLO_Op<name, [Pure]> {
  let results = (outs LLO_Vector:$result);
  let arguments = (ins LLO_Vector:$operand);
  let customInstructionLowering = [{
    return ::xla::jellyfish::LloInstruction::CreateVectorUnop(
        ::xla::jellyfish::LloOpcode::}] # opcode # [{,
        GetLloValue($_self->getOperand(0), $_value_map),
        $_region);
  }];
  let assemblyFormat = [{
    $operand attr-dict `:` type($operand) `->` type($result)
  }];
}
// Conversions between s32 and f32.
def LLO_VectorConvertS32ToF32Op
    : LLO_VectorConvertOp<"vcvt.s32.f32",
                          "kVectorConvertS32ToF32">;
def LLO_VectorConvertF32ToS32Op
    : LLO_VectorConvertOp<"vcvt.f32.s32",
                          "kVectorConvertF32ToS32">;
def LLO_VectorConvertF32ToS32TowardsZeroPseudoOp
    : LLO_VectorConvertOp<"vcvt.f32.s32.to.zero.pseudo",
                          "kVectorConvertF32ToS32TowardsZeroPseudo">;
// Enum of vector pack formats with a 16-bit underlying C++ type (uint16_t).
// NOTE(review): the enum is declared over I16 while its cases are
// I32EnumAttrCase records -- confirm this mix is intended.
def LLO_VpackFormat : IntEnumAttr<I16, "VpackFormat", "Vector pack format", [
    I32EnumAttrCase<"kInvalid", 0, "invalid">,
    I32EnumAttrCase<"kCompressedBf16", 1, "compressed_bf16">,
    I32EnumAttrCase<"kCompressedB16", 2, "compressed_b16">,
    I32EnumAttrCase<"kCompressedB8", 3, "compressed_b8">,
    I32EnumAttrCase<"kCompressedB4", 4, "compressed_b4">,
    I32EnumAttrCase<"kCompressedB2", 5, "compressed_b2">,
    I32EnumAttrCase<"kCompressedB1", 6, "compressed_b1">,
  ]> {
  let underlyingType = "uint16_t";
  let cppNamespace = "::mlir::llo";
  let genSpecializedAttr = 0;
}
// Dialect attribute wrapping LLO_VpackFormat; written as `<value>`.
def LLO_VpackFormatEnum
    : LLO_EnumAttr<LLO_VpackFormat, "vpack_format"> {
  let lloEnumType = "::xla::jellyfish::VpackFormat";
  let assemblyFormat = [{`<` $value `>`}];
}
// `llo.vunpack` (Pure): takes a `sublane_idx` attribute, a pack format and a
// vector; returns a vector. Both attributes are printed via attr-dict. The
// lowering forwards them to CreateVectorUnpack with opcode kVectorUnpack.
def LLO_VectorUnpackOp : LLO_Op<"vunpack", [Pure]> {
  let results = (outs LLO_Vector:$result);
  let arguments = (ins
    I32Attr:$sublane_idx,
    LLO_VpackFormatEnum:$format,
    LLO_Vector:$operand
  );
  let customInstructionLowering = [{
    return ::xla::jellyfish::LloInstruction::CreateVectorUnpack(
        ::xla::jellyfish::LloOpcode::kVectorUnpack,
        $_self.getSublaneIdx(),
        static_cast<::xla::jellyfish::VpackFormat>($_self.getFormat()),
        GetLloValue($_self->getOperand(0), $_value_map),
        $_region);
  }];
  let assemblyFormat = [{
    $operand attr-dict `:` type($operand) `->` type($result)
  }];
}
// `llo.vpack` (Pure): takes a pack format and two vectors (`high`, `low`);
// returns a vector. The lowering forwards them to CreateVectorPack with
// opcode kVectorPack.
def LLO_VectorPackOp : LLO_Op<"vpack", [Pure]> {
  let results = (outs LLO_Vector:$result);
  let arguments = (ins
    LLO_VpackFormatEnum:$format,
    LLO_Vector:$high,
    LLO_Vector:$low
  );
  let customInstructionLowering = [{
    return ::xla::jellyfish::LloInstruction::CreateVectorPack(
        ::xla::jellyfish::LloOpcode::kVectorPack,
        static_cast<::xla::jellyfish::VpackFormat>($_self.getFormat()),
        GetLloValue($_self.getHigh(), $_value_map),
        GetLloValue($_self.getLow(), $_value_map),
        $_region);
  }];
  let assemblyFormat = [{
    $high `,` $low attr-dict `:` type($high) `,` type($low) `->` type($result)
  }];
}
// 32-bit enum of bit data formats; the specialized attribute is provided by
// LLO_BitDataFormatEnum rather than generated from this record.
def LLO_BitDataFormat :
    I32EnumAttr<"BitDataFormat", "Bit data format", [
      I32EnumAttrCase<"kB32", 0, "b32">,
      I32EnumAttrCase<"kCompressedB16", 1, "b16.compressed">,
      I32EnumAttrCase<"kCompressedB8", 2, "b8.compressed">,
    ]> {
  let cppNamespace = "::mlir::llo";
  let genSpecializedAttr = 0;
}
// Dialect attribute `bit_data_format` wrapping LLO_BitDataFormat; printed as
// `<value>` and tied to the LLO-side enum named in lloEnumType.
def LLO_BitDataFormatEnum :
    LLO_EnumAttr<LLO_BitDataFormat, "bit_data_format"> {
  let lloEnumType = "::xla::jellyfish::BitDataFormat";
  let assemblyFormat = "`<` $value `>`";
}
// 32-bit enum selecting the set-permute mode (one sublane vs. all sublanes).
def LLO_SetPermuteMode :
    I32EnumAttr<"SetPermuteMode", "Set permute mode", [
      I32EnumAttrCase<"kOneSublane", 0, "one_sublane">,
      I32EnumAttrCase<"kAllSublanes", 1, "all_sublanes">,
    ]> {
  let cppNamespace = "::mlir::llo";
  let genSpecializedAttr = 0;
}
// Dialect attribute `set_permute_mode` wrapping LLO_SetPermuteMode; printed
// as `<value>` and tied to the LLO-side enum named in lloEnumType.
def LLO_SetPermuteModeEnum :
    LLO_EnumAttr<LLO_SetPermuteMode, "set_permute_mode"> {
  let lloEnumType = "::xla::jellyfish::SetPermuteMode";
  let assemblyFormat = "`<` $value `>`";
}
// `llo.vsetperm`: takes a pattern vreg, a permute mode, an XLU id (defaults
// to 0) and an optional source bus. No purity trait is declared.
def LLO_VectorSetPermutePatternOp : LLO_Op<"vsetperm"> {
  let arguments = (ins
    LLO_Vector:$pattern,
    LLO_SetPermuteModeEnum:$mode,
    DefaultValuedAttr<I32Attr,"0">:$xlu_id,
    OptionalAttr<I32Attr>:$source_bus
  );
  let results = (outs
    LLO_Vector:$output  // Not really, but we need to pass something to VectorPermuteOp
  );

  let assemblyFormat = [{ $pattern `,` $mode `,` attr-dict `:` type($pattern) `,` type($output) }];
}
// `llo.vperm`: takes a source vreg and a pattern vreg plus an XLU id
// (defaults to 0) and an optional source bus; yields a `request` value.
def LLO_VectorPermuteOp : LLO_Op<"vperm"> {
  let arguments = (ins
    LLO_Vector:$source,
    LLO_Vector:$pattern,
    DefaultValuedAttr<I32Attr,"0">:$xlu_id,
    OptionalAttr<I32Attr>:$source_bus
  );
  let results = (outs LLO_Vector:$request);

  let assemblyFormat = [{ $source `,` $pattern `,` attr-dict `:` type($source) `,` type($pattern) `,` type($request) }];
}
// `llo.vlbroadcast`: lowered via the builder's Vbroadcastlane. The sublane
// stride operand is optional; when absent, std::nullopt is forwarded.
def LLO_VectorLaneBroadcastOp : LLO_Op<"vlbroadcast"> {
  let arguments = (ins
    LLO_Vector:$source,
    I32:$lane_offset,
    Optional<I32>:$sublane_stride,
    LLO_BitDataFormatEnum:$format
  );
  let results = (outs LLO_Vector:$request);

  let builderMethod = [{
    $_builder.Vbroadcastlane(
        GetLloValue($_self.getSource(), $_value_map),
        GetLloValue($_self.getLaneOffset(), $_value_map),
        $_self.getSublaneStride()
            ? std::optional<::xla::jellyfish::LloValue*>(
                  GetLloValue($_self.getSublaneStride(), $_value_map))
            : std::nullopt,
        static_cast<::xla::jellyfish::BitDataFormat>($_self.getFormat())
    );
  }];
}
// `llo.vpermres`: consumes a `request` vreg and an XLU id (defaults to 0)
// and produces the result vreg.
def LLO_VectorPermuteResultOp : LLO_Op<"vpermres"> {
  let arguments = (ins
    LLO_Vector:$request,
    DefaultValuedAttr<I32Attr,"0">:$xlu_id
  );
  let results = (outs LLO_Vector:$result);

  let assemblyFormat = [{ $request `,` attr-dict `:` type($request) `,` type($result) }];
}
// `llo.vbcast_sublane_chunk`: lowered by calling
// fusion_util::BroadcastSublaneChunk with the source, the lane offset and the
// primitive type derived from the result's element type.
// NOTE(review): the inline argument comment below reads
// `throw_error_throw_error_on_signless`; the stutter looks accidental --
// confirm against the ToPrimitiveType signature before changing it.
def LLO_VectorBroadcastSublaneChunkOp : LLO_Op<"vbcast_sublane_chunk"> {
  let arguments = (ins
    LLO_Vector:$source,
    I32:$lane_offset
  );
  let results = (outs LLO_Vector:$request);

  let builderMethod = [{
    ::xla::jellyfish::fusion_util::BroadcastSublaneChunk(
        GetLloValue($_self.getSource(), $_value_map),
        GetLloValue($_self.getLaneOffset(), $_value_map),
        ToPrimitiveType($_self.getType().getElementType(),
                        /*throw_error_throw_error_on_signless=*/false),
        $_builder
    )
  }];
}
// `llo.vslaneid`: operand-less op producing a vreg via the builder's
// Vslaneid().
def LLO_VectorSublaneId : LLO_Op<"vslaneid"> {
  let arguments = (ins);
  let results = (outs LLO_Vector:$result);

  let assemblyFormat = [{ attr-dict `:` type($result) }];
  let builderMethod = "$_builder.Vslaneid()";
}
// `llo.vxlaneid`: operand-less op producing a vreg via the builder's
// Vxlaneid().
def LLO_VectorLaneId : LLO_Op<"vxlaneid"> {
  let arguments = (ins);
  let results = (outs LLO_Vector:$result);

  let assemblyFormat = [{ attr-dict `:` type($result) }];
  let builderMethod = "$_builder.Vxlaneid()";
}
// `llo.vlaneseq`: operand-less op producing a vreg via the builder's
// Vlaneseq().
def LLO_VectorLaneSeqOp : LLO_Op<"vlaneseq"> {
  let arguments = (ins);
  let results = (outs LLO_Vector:$result);

  let assemblyFormat = [{ attr-dict `:` type($result) }];
  let builderMethod = "$_builder.Vlaneseq()";
}
// `llo.stov`: pure op from a scalar of any type to a native vreg. No custom
// assembly format or builder method is declared on this record.
def LLO_ScalarToVectorOp : LLO_Op<"stov", [Pure]> {
  let arguments = (ins AnyType:$scalar);
  let results = (outs LLO_Vector:$result);
}
// `llo.vtos`: pure op from a native vreg to a scalar of any type; lowered
// via the builder's Vtos().
def LLO_VectorToScalarOp : LLO_Op<"vtos", [Pure]> {
  let arguments = (ins LLO_Vector:$vector);
  let results = (outs AnyType:$scalar);

  let builderMethod = "$_builder.Vtos(GetLloValue($_self.getVector(), $_value_map))";
}
// TODO(apaszke): Add support for other conversions
// Base class for pure scalar conversion ops.
//   name:   op mnemonic.
//   opcode: LloOpcode enumerator name, spliced into the CreateScalarUnOp call.
// Operand and result types are unconstrained (AnyType).
class LLO_ScalarConvertOp<string name, string opcode> :
    LLO_Op<name, [Pure]> {
  let arguments = (ins AnyType:$operand);
  let results = (outs AnyType:$result);

  let assemblyFormat = [{ $operand attr-dict `:` type($operand) `->` type($result) }];

  let customInstructionLowering = [{
    return ::xla::jellyfish::LloInstruction::CreateScalarUnOp(
        ::xla::jellyfish::LloOpcode::}] # opcode # [{,
        GetLloValue($_self->getOperand(0), $_value_map),
        $_region);
  }];
}
// Scalar conversion ops, one per (mnemonic, LloOpcode enumerator) pair.
def LLO_ScalarConvertS32ToF32Op :
    LLO_ScalarConvertOp<"scvt.s32.f32", "kScalarConvertS32ToF32">;
def LLO_ScalarConvertF32ToS32Op :
    LLO_ScalarConvertOp<"scvt.f32.s32", "kScalarConvertF32ToS32">;
def LLO_ScalarConvertF32ToS32TowardsZeroPseudoOp :
    LLO_ScalarConvertOp<"scvt.f32.s32.to.zero.pseudo",
                        "kScalarConvertF32ToS32TowardsZeroPseudo">;
// Base class for pure scalar unary ops whose operand and result share `type`.
//   name:   op mnemonic.
//   opcode: LloOpcode enumerator name, spliced into the CreateScalarUnOp call.
//   type:   the single scalar type used for operand and result.
class LLO_ScalarUnOp<string name, string opcode, Type type> :
    LLO_Op<name, [Pure, SameOperandsAndResultType]> {
  let arguments = (ins type:$operand);
  let results = (outs type:$result);

  let assemblyFormat = [{ $operand attr-dict }];

  let customInstructionLowering = [{
    return ::xla::jellyfish::LloInstruction::CreateScalarUnOp(
        ::xla::jellyfish::LloOpcode::}] # opcode # [{,
        GetLloValue($_self->getOperand(0), $_value_map),
        $_region);
  }];
}
// `llo.sclz`: I32 unary op lowered with the kScalarCountLeadingZeros opcode.
def LLO_ScalarCountLeadingZerosOp :
    LLO_ScalarUnOp<"sclz", "kScalarCountLeadingZeros", I32>;
// Base class for pure scalar binary ops whose operands and result share
// `type`.
//   name:   op mnemonic.
//   opcode: LloOpcode enumerator name, spliced into the CreateScalarBinOp call.
//   type:   the single scalar type used for both operands and the result.
class LLO_ScalarBinOp<string name, string opcode, Type type> :
    LLO_Op<name, [Pure, SameOperandsAndResultType]> {
  let arguments = (ins type:$lhs, type:$rhs);
  let results = (outs type:$result);

  let assemblyFormat = [{ $lhs `,` $rhs attr-dict }];

  let customInstructionLowering = [{
    return ::xla::jellyfish::LloInstruction::CreateScalarBinOp(
        ::xla::jellyfish::LloOpcode::}] # opcode # [{,
        GetLloValue($_self->getOperand(0), $_value_map),
        GetLloValue($_self->getOperand(1), $_value_map),
        $_region);
  }];
}
// `llo.sadd.s32`: I32 addition; additionally declares a canonicalize method
// (implemented in C++ outside this file).
def LLO_ScalarAddS32Op :
    LLO_ScalarBinOp<"sadd.s32", "kScalarAddS32", I32> {
  let hasCanonicalizeMethod = 1;
}
// Scalar arithmetic ops that map one-to-one onto an LloOpcode enumerator.
def LLO_ScalarAddF32Op :
    LLO_ScalarBinOp<"sadd.f32", "kScalarAddF32", F32>;
def LLO_ScalarSubF32Op :
    LLO_ScalarBinOp<"ssub.f32", "kScalarSubtractF32", F32>;
def LLO_ScalarSubS32Op :
    LLO_ScalarBinOp<"ssub.s32", "kScalarSubtractS32", I32>;
def LLO_ScalarMulF32Op :
    LLO_ScalarBinOp<"smul.f32", "kScalarMultiplyF32", F32>;
def LLO_ScalarMaxF32Op :
    LLO_ScalarBinOp<"smax.f32", "kScalarMaximumF32", F32>;
def LLO_ScalarMinF32Op :
    LLO_ScalarBinOp<"smin.f32", "kScalarMinimumF32", F32>;
// Signedness does not matter for two's complement multiplication, so the
// s32 multiply reuses the U32 opcode.
def LLO_ScalarMulS32Op :
    LLO_ScalarBinOp<"smul.s32", "kScalarMultiplyU32", I32>;
// Signed I32 division and remainder. Both opt out of the generated
// instruction lowering (hasInstructionLowering = 0), so the empty opcode
// string is never spliced into emitted code.
def LLO_ScalarDivS32Op : LLO_ScalarBinOp<"sdiv.s32", "", I32> {
  let hasInstructionLowering = 0;
}
def LLO_ScalarRemS32Op : LLO_ScalarBinOp<"srem.s32", "", I32> {
  let hasInstructionLowering = 0;
}
// Scalar binary ops lowered through a builder helper instead of a single
// opcode; the "UNUSED" opcode string is a placeholder for the base class.
// The snippets refer to the op through the `$_self` placeholder, like the
// other builderMethod bodies in this file, rather than through the bare
// variable name `op`.
def LLO_ScalarMinS32Op : LLO_ScalarBinOp<"smin.s32", "UNUSED", I32> {
  let builderMethod = [{
    $_builder.SminS32(GetLloValue($_self.getLhs(), $_value_map),
                      GetLloValue($_self.getRhs(), $_value_map))
  }];
}
def LLO_ScalarMaxS32Op : LLO_ScalarBinOp<"smax.s32", "UNUSED", I32> {
  let builderMethod = [{
    $_builder.SmaxS32(GetLloValue($_self.getLhs(), $_value_map),
                      GetLloValue($_self.getRhs(), $_value_map))
  }];
}
def LLO_ScalarDivF32Op : LLO_ScalarBinOp<"sdiv.f32", "UNUSED", F32> {
  let builderMethod = [{
    $_builder.SdivF32(GetLloValue($_self.getLhs(), $_value_map),
                      GetLloValue($_self.getRhs(), $_value_map))
  }];
}
// I32 shifts lowered with the kScalarShll / kScalarShrl opcodes.
def LLO_ScalarShllOp :
    LLO_ScalarBinOp<"shll", "kScalarShll", I32>;
def LLO_ScalarShrlOp :
    LLO_ScalarBinOp<"shrl", "kScalarShrl", I32>;
// I32 shift lowered with the kScalarShra opcode. The mnemonic is "shra" so
// that it agrees with the def name, the opcode and the shll/shrl siblings.
def LLO_ScalarShraOp : LLO_ScalarBinOp<"shra", "kScalarShra", I32>;
// I32 bitwise ops, each lowered with the matching kScalarBitwise* opcode.
def LLO_ScalarBitwiseOrOp :
    LLO_ScalarBinOp<"sor", "kScalarBitwiseOr", I32>;
def LLO_ScalarBitwiseAndOp :
    LLO_ScalarBinOp<"sand", "kScalarBitwiseAnd", I32>;
def LLO_ScalarBitwiseXorOp :
    LLO_ScalarBinOp<"sxor", "kScalarBitwiseXor", I32>;
// F32 scalar unary ops lowered through builder helpers; the "UNUSED" opcode
// string is a placeholder for the base class. The snippets refer to the op
// through the `$_self` placeholder, like the other builderMethod bodies in
// this file, rather than through the bare variable name `op`.
def LLO_ScalarNegF32Op : LLO_ScalarUnOp<"sneg.f32", "UNUSED", F32> {
  let builderMethod = "$_builder.SnegF32(GetLloValue($_self.getOperand(), $_value_map))";
}
def LLO_ScalarSqrtF32Op : LLO_ScalarUnOp<"ssqrt.f32", "UNUSED", F32> {
  let builderMethod = "$_builder.Ssqrt(GetLloValue($_self.getOperand(), $_value_map))";
}
def LLO_ScalarExpF32Op : LLO_ScalarUnOp<"sexp.f32", "UNUSED", F32> {
  let builderMethod = "$_builder.Sexp(GetLloValue($_self.getOperand(), $_value_map))";
}
def LLO_ScalarLogF32Op : LLO_ScalarUnOp<"sln.f32", "UNUSED", F32> {
  let builderMethod = "$_builder.Sln(GetLloValue($_self.getOperand(), $_value_map))";
}
def LLO_ScalarRoundF32Op : LLO_ScalarUnOp<"sround.f32", "UNUSED", F32> {
  let builderMethod = "$_builder.SroundF32(GetLloValue($_self.getOperand(), $_value_map))";
}
def LLO_ScalarRoundEvenF32Op : LLO_ScalarUnOp<"sround.ties_to_even.f32", "UNUSED", F32> {
  let builderMethod = "$_builder.SroundTiesToEvenF32(GetLloValue($_self.getOperand(), $_value_map))";
}
// Base class for pure scalar comparisons producing an I1.
//   name:      op mnemonic.
//   direction: ::xla::Comparison::Direction enumerator name (e.g. "kEq").
//   xlaType:   name in the ::xla namespace passed as the second argument of
//              ::xla::Comparison (e.g. "S32").
//   type:      type of both operands.
class LLO_ScalarCompareOp<string name, string direction, string xlaType,
                          Type type> :
    LLO_Op<name, [Pure]> {
  let arguments = (ins type:$lhs, type:$rhs);
  let results = (outs I1:$result);

  let assemblyFormat = [{ $lhs `,` $rhs attr-dict }];

  let customInstructionLowering = [{
    return ::xla::jellyfish::LloInstruction::CreateScalarCompare(
        GetLloValue($_self->getOperand(0), $_value_map),
        GetLloValue($_self->getOperand(1), $_value_map),
        ::xla::Comparison(
            ::xla::Comparison::Direction::}] # direction # [{,
            ::xla::}] # xlaType # [{),
        $_region);
  }];
}
// One I32 comparison op per entry of cmp_directions; the mnemonic is
// "s" + lowercase direction + ".s32".
foreach dir = cmp_directions in {
  def LLO_ScalarCmp#dir#S32Op :
      LLO_ScalarCompareOp<"s"#!tolower(dir)#".s32", "k"#dir, "S32", I32>;
}
// One F32 comparison op per entry of cmp_directions; the mnemonic is
// "s" + lowercase direction + ".f32".
foreach dir = cmp_directions in {
  def LLO_ScalarCmp#dir#F32Op :
      LLO_ScalarCompareOp<"s"#!tolower(dir)#".f32", "k"#dir, "F32", F32>;
}
// `llo.sselect`: pure op with an I1 condition and two candidate values.
// The value and result types are unconstrained (AnyType) at the ODS level.
def LLO_ScalarSelectOp : LLO_Op<"sselect", [Pure]> {
  let arguments = (ins
    I1:$condition,
    AnyType:$if_true,
    AnyType:$if_false
  );
  let results = (outs AnyType:$result);

  let assemblyFormat = [{
    $condition `,` $if_true `,` $if_false attr-dict `:`
    `(` type($if_true) `,` type($if_false) `)` `->` type($result)
  }];
}
// `llo.sld`: takes an I32 address and an optional I32 offset (printed as
// `+ offset` when present) and yields a value of any type.
def LLO_ScalarLoadOp : LLO_Op<"sld"> {
  let arguments = (ins
    I32:$address,
    Optional<I32>:$offset
  );
  let results = (outs AnyType:$result);

  let assemblyFormat = [{ $address (`+` $offset^)? attr-dict `:` type($result) }];
}
// `llo.sst`: takes an I32 address and a value of any type; no results.
def LLO_ScalarStoreOp : LLO_Op<"sst"> {
  let arguments = (ins
    I32:$address,
    AnyType:$to_store
  );
  let results = (outs);

  let assemblyFormat = [{ $address `,` $to_store attr-dict `:` type($to_store) }];
}
// `llo.por`: pure I1 x I1 -> I1 op. No builder method is declared here.
def LLO_PredicateOrOp : LLO_Op<"por", [Pure, SameOperandsAndResultType]> {
  let arguments = (ins I1:$lhs, I1:$rhs);
  let results = (outs I1:$result);

  let assemblyFormat = [{ $lhs `,` $rhs attr-dict }];
}
// `llo.pand`: pure I1 x I1 -> I1 op lowered via the builder's Pand().
def LLO_PredicateAndOp : LLO_Op<"pand", [Pure, SameOperandsAndResultType]> {
  let arguments = (ins I1:$lhs, I1:$rhs);
  let results = (outs I1:$result);

  let assemblyFormat = [{ $lhs `,` $rhs attr-dict }];

  let builderMethod = [{
    $_builder.Pand(GetLloValue($_self.getLhs(), $_value_map),
                   GetLloValue($_self.getRhs(), $_value_map))
  }];
}
// `llo.pneg`: pure I1 -> I1 op on a single predicate operand.
def LLO_PredicateNegateOp :
    LLO_Op<"pneg", [Pure, SameOperandsAndResultType]> {
  let arguments = (ins I1:$predicate);
  let results = (outs I1:$result);

  let assemblyFormat = [{ $predicate attr-dict }];
}
// `llo.error.if`: takes an I1 predicate and a string message; lowered via
// the builder's ErrorIf(). Produces no results.
def LLO_ErrorIfOp : LLO_Op<"error.if"> {
  let arguments = (ins
    I1:$predicate,
    StrAttr:$message
  );
  let results = (outs);

  let builderMethod = [{
    $_builder.ErrorIf(GetLloValue($_self.getPredicate(), $_value_map), $_self.getMessage())
  }];
}
// `llo.vbitcast`: pure vreg -> vreg op with no generated instruction
// lowering.
def LLO_VectorBitcastOp : LLO_Op<"vbitcast", [Pure]> {
  let arguments = (ins LLO_Vector:$input);
  let results = (outs LLO_Vector:$output);

  let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }];
  let hasInstructionLowering = 0;
}
// `llo.sbitcast`: pure op between unconstrained (AnyType) scalar types with
// no generated instruction lowering.
def LLO_ScalarBitcastOp : LLO_Op<"sbitcast", [Pure]> {
  let arguments = (ins AnyType:$input);
  let results = (outs AnyType:$output);

  let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }];
  let hasInstructionLowering = 0;
}
// `llo.enqueue_dma`: lowered via the builder's EnqueueDmaMultiStrided.
// Operand segment sizes are carried in an attribute because src_flag,
// steps_per_stride, chip_id and core_index are optional or variadic.
// When src_flag is absent the lowering passes $_builder.DummySyncFlag().
// chip_id / core_index are not read directly in the snippet; presumably
// GetOptionalCoreLocation consumes them -- confirm in the C++ helper.
// A C++ verifier is declared (hasVerifier = 1).
def LLO_EnqueueDMAOp : LLO_Op<"enqueue_dma", [AttrSizedOperandSegments]> {
  let arguments = (ins
    I32:$src_addr,
    Optional<I32>:$src_flag,
    I32:$size_512b,  // size in multiples of 512 bytes
    I32:$dst_addr,
    I32:$dst_flag,
    // Major-to-minor, strides (and the last steps_per_stride) are in bytes.
    DenseI32ArrayAttr:$src_strides,
    DenseI32ArrayAttr:$dst_strides,
    Variadic<I32>:$steps_per_stride,
    Optional<I32>:$chip_id,
    Optional<I32>:$core_index
  );

  // The op is referenced only through `$_self` so the snippet does not depend
  // on the name of the generator's local variable.
  let builderMethod = [{
    $_builder.EnqueueDmaMultiStrided(
        GetLloValue($_self.getSrcAddr(), $_value_map),
        GetLloValue($_self.getDstAddr(), $_value_map),
        GetOptionalCoreLocation($_context, $_builder, $_self, "enqueue-dma", $_value_map),
        ::xla::jellyfish::LloMemUnit(
            GetLloValue($_self.getSize_512b(), $_value_map),
            ::xla::jellyfish::Granule::k512Byte),
        GetDmaStrides($_builder, $_self.getSrcStrides(), $_self.getDstStrides(), $_self.getStepsPerStride(), $_value_map),
        GetLloValue($_self.getDstFlag(), $_value_map),
        $_self.getSrcFlag() ? GetLloValue($_self.getSrcFlag(), $_value_map) : $_builder.DummySyncFlag()
    );
  }];

  let hasVerifier = 1;
}
// `llo.dma_done`: lowered via the builder's DmaDoneInGranules, with the size
// wrapped in an LloMemUnit of 512-byte granules.
def LLO_DMADoneOp : LLO_Op<"dma_done"> {
  let arguments = (ins
    I32:$size_512b,  // size in multiples of 512 bytes
    I32:$flag_addr,
    BoolAttr:$emit_sfence
  );

  let builderMethod = [{
    $_builder.DmaDoneInGranules(
        ::xla::jellyfish::LloMemUnit(
            GetLloValue($_self.getSize_512b(), $_value_map),
            ::xla::jellyfish::Granule::k512Byte),
        GetLloValue($_self.getFlagAddr(), $_value_map),
        $_self.getEmitSfence()
    );
  }];
}
// `llo.chip_id`: operand-less op yielding an I32 via the builder's ChipId().
def LLO_ChipIdOp : LLO_Op<"chip_id"> {
  let arguments = (ins);
  let results = (outs I32:$result);

  let assemblyFormat = [{ attr-dict `:` type($result) }];
  let builderMethod = "$_builder.ChipId()";
}
// `llo.core_index`: operand-less op yielding an I32 via the builder's
// CoreIndex().
def LLO_CoreIndexOp : LLO_Op<"core_index"> {
  let arguments = (ins);
  let results = (outs I32:$result);

  let assemblyFormat = [{ attr-dict `:` type($result) }];
  let builderMethod = "$_builder.CoreIndex()";
}
// `llo.logical_device_id`: pure, operand-less op yielding an I32; it has no
// generated instruction lowering.
def LLO_LogicalDeviceIdOp : LLO_Op<"logical_device_id", [Pure]> {
  let arguments = (ins);
  let results = (outs I32:$result);

  let assemblyFormat = [{ attr-dict `:` type($result) }];
  let hasInstructionLowering = 0;
}
// `llo.vsync.read`: takes an I32 flag address and yields an I32 via the
// builder's VsyncRead().
def LLO_VSyncRead : LLO_Op<"vsync.read"> {
  let arguments = (ins I32:$flag_addr);
  let results = (outs I32:$result);

  let assemblyFormat = [{ $flag_addr attr-dict `:` type($result) }];
  let builderMethod = "$_builder.VsyncRead(GetLloValue($_self.getFlagAddr(), $_value_map))";
}
// Base class for `llo.vwait.<cmp>` ops.
//   cmp: comparison name (e.g. "Eq"); lowercased for the mnemonic and spliced
//        verbatim into the builder call name `Vwait<cmp>SV`.
class LLO_VWaitOp<string cmp> : LLO_Op<"vwait." # !tolower(cmp)> {
  let arguments = (ins
    I32:$flag_addr,
    I32:$value
  );

  let builderMethod = [{
    $_builder.Vwait}] # cmp # [{SV(
        GetLloValue($_self.getFlagAddr(), $_value_map),
        GetLloValue($_self.getValue(), $_value_map)
    );
  }];
}
// Concrete wait ops: `llo.vwait.eq` and `llo.vwait.ge`.
def LLO_VWaitEqOp :
    LLO_VWaitOp<"Eq">;
def LLO_VWaitGeOp :
    LLO_VWaitOp<"Ge">;
// `llo.vsync.set`: takes an I32 flag address and an I32 value; lowered via
// the builder's VsyncSet().
def LLO_VSyncSetOp : LLO_Op<"vsync.set"> {
  let arguments = (ins
    I32:$flag_addr,
    I32:$value
  );

  let builderMethod = [{
    $_builder.VsyncSet(
        GetLloValue($_self.getFlagAddr(), $_value_map),
        GetLloValue($_self.getValue(), $_value_map)
    );
  }];
}
// `llo.vsync.add.remote`: lowered via the builder's VsyncAddRemote with the
// flag address, a core location from GetCoreLocation, and the amount.
// chip_id / core_index are not read directly in the snippet; presumably
// GetCoreLocation consumes them -- confirm in the C++ helper.
def LLO_VSyncAddRemoteOp : LLO_Op<"vsync.add.remote"> {
  let arguments = (ins
    I32:$flag_addr,
    I32:$chip_id,
    I32:$core_index,
    I32:$amount
  );
  let results = (outs);

  let assemblyFormat = [{ $flag_addr `,` $chip_id `,` $core_index `,` $amount attr-dict }];

  let builderMethod = [{
    $_builder.VsyncAddRemote(
        GetLloValue($_self.getFlagAddr(), $_value_map),
        GetCoreLocation($_context, $_builder, $_self, "sync-add", $_value_map),
        GetLloValue($_self.getAmount(), $_value_map)
    );
  }];
}
// `llo.vsync.add`: takes an I32 flag address and an I32 amount; lowered via
// the builder's VsyncAdd().
def LLO_VSyncAddOp : LLO_Op<"vsync.add"> {
  let arguments = (ins
    I32:$flag_addr,
    I32:$amount
  );
  let results = (outs);

  let assemblyFormat = [{ $flag_addr `,` $amount attr-dict }];

  let builderMethod = [{
    $_builder.VsyncAdd(
        GetLloValue($_self.getFlagAddr(), $_value_map),
        GetLloValue($_self.getAmount(), $_value_map)
    );
  }];
}
// `llo.log.s32`: takes an I32 operand and a string tag; no results and no
// generated instruction lowering.
def LLO_LogS32 : LLO_Op<"log.s32"> {
  let arguments = (ins
    I32:$operand,
    StrAttr:$tag
  );
  let results = (outs);

  let hasInstructionLowering = 0;
}
// `llo.log.pred`: takes an I1 operand and a string tag; no results and no
// generated instruction lowering.
def LLO_LogPred : LLO_Op<"log.pred"> {
  let arguments = (ins
    I1:$operand,
    StrAttr:$tag
  );
  let results = (outs);

  let hasInstructionLowering = 0;
}
// `llo.log.vmask`: takes a vector mask and a string tag; lowered via the
// builder's LogVmask(), passing the tag as a std::string.
def LLO_LogVmaskOp : LLO_Op<"log.vmask"> {
  let arguments = (ins
    LLO_VectorMask:$operand,
    StrAttr:$tag
  );
  let results = (outs);

  let builderMethod = "$_builder.LogVmask(GetLloValue($_self.getOperand(), $_value_map), $_self.getTag().str())";
}
// `llo.region`: single-block region op implicitly terminated by
// llo::YieldOp. Memory effects are derived recursively from the body, and
// there is no generated instruction lowering.
def LLO_RegionOp :
    LLO_Op<"region", [RecursiveMemoryEffects,
                      SingleBlockImplicitTerminator<"llo::YieldOp">]> {
  let results = (outs Variadic<AnyType>:$results);
  let regions = (region AnyRegion:$region);

  let hasInstructionLowering = 0;
}
// `llo.trace`: single-block region op implicitly terminated by llo::YieldOp,
// carrying a message string and an I32 level. No generated instruction
// lowering.
def LLO_TraceOp :
    LLO_Op<"trace", [RecursiveMemoryEffects,
                     SingleBlockImplicitTerminator<"llo::YieldOp">]> {
  let arguments = (ins StrAttr:$message, I32Attr:$level);
  let results = (outs Variadic<AnyType>:$results);
  let regions = (region AnyRegion:$region);

  let hasInstructionLowering = 0;
}
// `llo.yield`: pure, return-like terminator forwarding any number of values;
// the `: types` suffix is printed only when operands are present.
def LLO_YieldOp : LLO_Op<"yield", [Pure, ReturnLike, Terminator]> {
  let arguments = (ins Variadic<AnyType>:$results);

  let assemblyFormat = [{ attr-dict ($results^ `:` type($results))? }];
  let hasInstructionLowering = 0;
}
// `llo.trace_start`: carries a message string and an I32 level; no results
// and no generated instruction lowering.
def LLO_TraceStartOp : LLO_Op<"trace_start", []> {
  let arguments = (ins StrAttr:$message, I32Attr:$level);
  let results = (outs);

  let hasInstructionLowering = 0;
}
// `llo.trace_stop`: no operands, attributes or results; no generated
// instruction lowering.
def LLO_TraceStopOp : LLO_Op<"trace_stop", []> {
  let arguments = (ins);
  let results = (outs);

  let hasInstructionLowering = 0;
}
// Extensions that do not have any corresponding ops in LLO, but help with
// lowering and optimization.
// TODO(apaszke): Implement a common optimization that folds the ops that follow
// the alloc-store-load-free pattern into loads
// `llo.sublane_reverse`: pure vreg -> vreg extension op with no generated
// instruction lowering.
def LLO_VectorSublaneReverseOp : LLO_Op<"sublane_reverse", [Pure]> {
  let arguments = (ins LLO_Vector:$operand);
  let results = (outs LLO_Vector:$result);

  let hasInstructionLowering = 0;
}
// `llo.vector_load_unpack`: reads vmem (VmemRead effect) and is lowered via
// the builder's LdAndUnpack, using the packed type, the address, the chunk
// id as an S32 immediate, and a sublane mask / count taken from the target's
// SublaneCount().
def LLO_VectorLoadAndUnpackOp : LLO_Op<"vector_load_unpack", [VmemRead]> {
  let arguments = (ins
    I32:$address,
    I32Attr:$chunk_id,
    TypeAttr:$packed_type
  );
  let results = (outs LLO_Vector:$result);

  let assemblyFormat = [{
    $address `,` $chunk_id attr-dict `:` type($result)
  }];

  let builderMethod = [{
    $_builder.LdAndUnpack(
        ToPrimitiveType($_self.getPackedType()),
        GetLloValue($_self.getAddress(), $_value_map),
        $_builder.SimmS32($_self.getChunkId()),
        $_builder.SublaneMaskForSublaneCount($_builder.target().SublaneCount()),
        $_builder.target().SublaneCount());
  }];
}
// Store-and-pack extension op: writes vmem (VmemWrite effect) and is lowered
// via the builder's StAndPackToX16; the last argument is an S32 immediate of
// half the target's SublaneCount().
// NOTE(review): the mnemonic reads "vector_store_unpack_x16" although the op
// packs -- looks like a copy from the load/unpack op; confirm before renaming
// since that would change the textual IR.
def LLO_VectorStoreAndPackX16Op :
    LLO_Op<"vector_store_unpack_x16", [VmemWrite]> {
  let arguments = (ins
    I32:$address,
    I32Attr:$chunk_id,
    LLO_Vector:$to_store,
    TypeAttr:$packed_type
  );
  let results = (outs);

  let assemblyFormat = [{
    $address `,` $chunk_id `,` $to_store attr-dict `:` type($to_store)
  }];

  let builderMethod = [{
    $_builder.StAndPackToX16(
        ToPrimitiveType($_self.getPackedType()),
        GetLloValue($_self.getAddress(), $_value_map),
        $_builder.SimmS32($_self.getChunkId()),
        GetLloValue($_self.getToStore(), $_value_map),
        $_builder.SimmS32($_builder.target().SublaneCount() / 2));
  }];
}
// `llo.slane.shuffle`: pure op taking a source vreg and a dense I32 index
// array; lowered via the builder's VslaneShfl with a pattern created from
// the indices by CreateShufflePattern.
def LLO_VectorSublaneShuffleOp : LLO_Op<"slane.shuffle", [Pure]> {
  let arguments = (ins
    LLO_Vector:$source,
    DenseI32ArrayAttr:$indices
  );
  let results = (outs LLO_Vector:$result);

  let assemblyFormat = [{ $source `[` $indices `]` attr-dict `:` type($source) `->` type($result) }];

  // The builder is referenced only through `$_builder` so the snippet does
  // not depend on the name of the generator's local variable.
  let builderMethod = [{
    $_builder.VslaneShfl(GetLloValue($_self.getSource(), $_value_map),
                         $_builder.CreateShufflePattern($_self.getIndices()));
  }];
}
// `llo.vslreplicate`: pure op taking a source vreg and an I32 sublane value;
// lowered via the builder's VslaneToAll().
def LLO_VectorSublaneReplicateOp : LLO_Op<"vslreplicate", [Pure]> {
  let arguments = (ins
    LLO_Vector:$source,
    I32:$sublane
  );
  let results = (outs LLO_Vector:$result);

  let assemblyFormat = [{ $source `,` $sublane attr-dict `:` type($source) `->` type($result)}];

  let builderMethod = [{$_builder.VslaneToAll(GetLloValue($_self.getSource(), $_value_map),
                          GetLloValue($_self.getSublane(), $_value_map)) }];
}
// Pass `elim-llo-ext`, anchored on func::FuncOp. It depends on the LLO
// dialect and is constructed by createEliminateLLOExtensionsPass().
def EliminateLLOExtensionsPass :
    Pass<"elim-llo-ext", "::mlir::func::FuncOp"> {
  let constructor = "::mlir::llo::createEliminateLLOExtensionsPass()";
  let dependentDialects = ["::mlir::llo::LLODialect"];
}
// LLO extensions that are only valid within Mosaic programs and are lowered specially.
// `llo.iteration_bound`: takes an I32 `dim` attribute and yields an I32;
// no generated instruction lowering.
def LLO_GetIterationBoundOp : LLO_Op<"iteration_bound"> {
  let arguments = (ins I32Attr:$dim);
  let results = (outs I32:$result);

  let assemblyFormat = [{ $dim attr-dict `:` type($result) }];
  let hasInstructionLowering = 0;
}
// `llo.barrier_sflag`: operand-less op yielding an I32 sync flag; no
// generated instruction lowering.
def LLO_GetBarrierSyncFlagOp : LLO_Op<"barrier_sflag"> {
  let arguments = (ins);
  let results = (outs I32:$sync_flag);

  let assemblyFormat = [{ attr-dict `:` type($sync_flag) }];
  let hasInstructionLowering = 0;
}
#endif // LLO_OPS