| |
| #include <string> |
| |
| #include "llvm/Support/CommandLine.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include "llvm/Support/InitLLVM.h" |
| #include "llvm/TableGen/Error.h" |
| #include "llvm/TableGen/Main.h" |
| #include "llvm/TableGen/Record.h" |
| #include "llvm/TableGen/TableGenBackend.h" |
| #include "mlir/TableGen/Operator.h" |
| #include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h" |
| #include "third_party/llvm/llvm-project/llvm/include/llvm/TableGen/Error.h" |
| #include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LogicalResult.h" |
| #include "third_party/llvm/llvm-project/mlir/include/mlir/TableGen/Argument.h" |
| #include "third_party/llvm/llvm-project/mlir/include/mlir/TableGen/Format.h" |
| |
| using llvm::StringRef; |
| using mlir::tblgen::Operator; |
| using mlir::tblgen::NamedAttribute; |
| using mlir::tblgen::NamedTypeConstraint; |
| using mlir::tblgen::Argument; |
| using mlir::failure; |
| using mlir::success; |
| |
| static StringRef XlaInstructionName(const Operator& op) { |
| return op.getCppClassName().drop_back(2); // Drop the Op suffix. |
| } |
| |
| static std::string AttrType(const Operator& op, const NamedAttribute& nattr) { |
| if (auto ty = nattr.attr.getDef().getValueAsOptionalString("lloEnumType")) { |
| return ty->str(); |
| } |
| StringRef def_name = nattr.attr.getAttrDefName(); |
| std::string type; |
| if (def_name == "I64Attr") { |
| type = "int64_t"; |
| } |
| if (def_name == "I32Attr") { |
| type = "int32_t"; |
| } |
| if (nattr.attr.isOptional()) { |
| type = "std::optional<" + type + ">"; |
| } |
| return type; |
| PrintFatalError( |
| op.getLoc(), "Unknown attribute (can't infer C++ type): " + def_name); |
| } |
| |
| static std::string AttrValue(const Operator& op, |
| const NamedAttribute& nattr, |
| const std::string& generic_attr) { |
| auto& attr_def = nattr.attr.getDef(); |
| if (auto ty = attr_def.getValueAsOptionalString("lloEnumType")) { |
| auto cppNamespace = attr_def.getValueAsString("cppNamespace"); |
| auto cppClass = attr_def.getValueAsString("cppClassName"); |
| return llvm::formatv("static_cast<{0}>(cast<{2}::{3}>({1}).getValue())", |
| *ty, generic_attr, cppNamespace, cppClass) |
| .str(); |
| } |
| StringRef def_name = nattr.attr.getAttrDefName(); |
| if (def_name == "I64Attr" || def_name == "I32Attr") { |
| std::string value = |
| "cast<mlir::IntegerAttr>(" + generic_attr + ").getInt()"; |
| if (nattr.attr.hasDefaultValue()) { |
| value = generic_attr + " ? " + value + " : " + |
| nattr.attr.getDefaultValue().str(); |
| } else if (nattr.attr.isOptional()) { |
| auto type = AttrType(op, nattr); |
| value = generic_attr + " ? " + type + "(" + value + ") : std::nullopt"; |
| } |
| return value; |
| } |
| PrintFatalError( |
| op.getLoc(), "Unknown attribute (can't get value): " + def_name); |
| } |
| |
| static mlir::LogicalResult BuildOperatorGeneric(llvm::raw_ostream& os, |
| const Operator& op) { |
| bool has_operand_segments = false; |
| if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) { |
| has_operand_segments = true; |
| os << " auto operand_sizes = " |
| "cast<DenseI32ArrayAttr>(op->getAttr(\"operandSegmentSizes\"))." |
| "asArrayRef();\n"; |
| } else { |
| if (op.isVariadic() || op.getNumVariableLengthOperands() > 1) { |
| return failure(); |
| } |
| } |
| os << " auto operand_it = op->getOperands().begin();\n"; |
| os << " (void)operand_it;\n"; |
| int operand_ix = 0; |
| for (const Argument& arg : op.getArgs()) { |
| os.indent(4); |
| if (auto nattr = arg.dyn_cast<NamedAttribute*>()) { |
| auto attr_type = AttrType(op, *nattr); |
| auto attr_value = AttrValue( |
| op, *nattr, llvm::formatv("op->getAttr(\"{0}\")", nattr->name).str()); |
| os << attr_type << " " << nattr->name << " = " << attr_value; |
| } else if (auto operand = arg.dyn_cast<NamedTypeConstraint*>()) { |
| os << "::xla::jellyfish::LloValue* " << operand->name << " = "; |
| if (operand->isOptional() && has_operand_segments) { |
| os << "(operand_sizes[" << operand_ix |
| << "] == 1) ? GetLloValue(*operand_it++, value_map) : nullptr"; |
| } else if (operand->isOptional()) { |
| os << "(op->getNumOperands() == " << op.getNumOperands() |
| << ") ? GetLloValue(*operand_it++, value_map) : nullptr"; |
| } else { |
| if (operand->isVariadic()) { |
| return failure(); |
| } |
| os << "GetLloValue(*operand_it++, value_map)"; |
| } |
| ++operand_ix; |
| } |
| os << ";\n"; |
| } |
| os << " return ::xla::jellyfish::LloInstruction::Create" |
| << XlaInstructionName(op) << "("; |
| for (const Argument& arg : op.getArgs()) { |
| if (auto nattr = arg.dyn_cast<NamedAttribute*>()) { |
| os << nattr->name; |
| } else if (auto operand = arg.dyn_cast<NamedTypeConstraint*>()) { |
| os << operand->name; |
| } else { |
| return failure(); |
| } |
| os << ", "; |
| } |
| os << "region);\n"; |
| return success(); |
| } |
| |
| static mlir::LogicalResult BuildOperator(llvm::raw_ostream& os, |
| const Operator& op) { |
| if (!op.getDef().getValueAsBit("hasInstructionLowering")) { |
| return success(); |
| } |
| os << " if (auto op = llvm::dyn_cast<" << op.getCppNamespace() |
| << "::" << op.getCppClassName() << ">(raw_op)) {\n"; |
| mlir::tblgen::FmtContext ctx; |
| ctx.addSubst("_value_map", "value_map") |
| .addSubst("_region", "region") |
| .addSubst("_context", "ctx") |
| .withBuilder("b") |
| .withSelf("op"); |
| auto builder_method = op.getDef().getValueAsOptionalString("builderMethod"); |
| if (!builder_method) { |
| os << " auto instruction = [&b, &op, ®ion, &value_map]() -> " |
| "std::unique_ptr<::xla::jellyfish::LloInstruction> {\n"; |
| if (auto impl = |
| op.getDef().getValueAsOptionalString("customInstructionLowering")) { |
| os << mlir::tblgen::tgfmt(*impl, &ctx); |
| } else { |
| if (failed(BuildOperatorGeneric(os, op))) { |
| return failure(); |
| } |
| } |
| os << " }();\n"; |
| os << " if (!instruction) return ::mlir::failure();\n"; |
| } |
| os.indent(4); |
| if (op.getNumResults() == 1) { |
| os << "value_map[op.getResult()] = "; |
| } else if (op.getNumResults() != 0) { |
| PrintFatalError(op.getLoc(), "Multi-result operations are unsupported"); |
| return failure(); |
| } |
| if (builder_method) { |
| os << mlir::tblgen::tgfmt(*builder_method, &ctx) << ";\n"; |
| } else { |
| os << "b.Instruction(std::move(instruction));\n"; |
| } |
| os << " return ::mlir::success();\n"; |
| os << " }\n"; |
| return success(); |
| } |
| |
| static bool OperatorWritersMain(llvm::raw_ostream& os, |
| llvm::RecordKeeper& records) { |
| llvm::emitSourceFileHeader("LLO Builder bindings", os); |
| os << '\n'; |
| os << R"-(mlir::LogicalResult AppendInstructionGenerated( |
| BuilderContext *ctx, |
| ::xla::jellyfish::LloRegionBuilder& b, |
| Operation* raw_op, |
| llvm::DenseMap<mlir::Value, ::xla::jellyfish::LloValue*>& value_map) {)-"; |
| os << '\n'; |
| os << " auto region = b.region();\n"; |
| for (const auto* def : records.getAllDerivedDefinitions("LLO_Op")) { |
| Operator op(def); |
| if (failed(BuildOperator(os, op))) { |
| return true; |
| } |
| } |
| os << " return ::mlir::failure();\n"; |
| os << "}\n"; |
| for (const auto* def : records.getAllDerivedDefinitions("LLO_EnumAttr")) { |
| auto llo_type = def->getValueAsString("lloEnumType"); |
| auto *enum_def = def->getValueAsDef("enum"); |
| auto cases = enum_def->getValueAsListOfDefs("enumerants"); |
| os << '\n'; |
| os << "// Verifying values of " << llo_type << '\n'; |
| for (llvm::Record* c : cases) { |
| auto val = c->getValueAsInt("value"); |
| auto sym = c->getValueAsString("symbol"); |
| os << llvm::formatv( |
| "static_assert(static_cast<int>({0}::{1}) == {2}, \"Mismatched value " |
| "of {0}::{1}\");\n", |
| llo_type, sym, val); |
| } |
| } |
| return false; |
| } |
| |
| int main(int argc, char** argv) { |
| llvm::InitLLVM y(argc, argv); |
| llvm::cl::ParseCommandLineOptions(argc, argv); |
| return TableGenMain(argv[0], &OperatorWritersMain); |
| } |