blob: 8421d3f74002fd815a697d76928d1b98d09b0cc5 [file] [log] [blame]
#include <string>
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/InitLLVM.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Main.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
#include "mlir/TableGen/Operator.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/ADT/StringRef.h"
#include "third_party/llvm/llvm-project/llvm/include/llvm/TableGen/Error.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/Support/LogicalResult.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/TableGen/Argument.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/TableGen/Format.h"
using llvm::StringRef;
using mlir::tblgen::Operator;
using mlir::tblgen::NamedAttribute;
using mlir::tblgen::NamedTypeConstraint;
using mlir::tblgen::Argument;
using mlir::failure;
using mlir::success;
// Returns the LLO/XLA instruction name for `op`: its C++ class name with the
// last two characters dropped. Those characters are expected to be the "Op"
// suffix of the ODS class name; the suffix is not checked here.
static StringRef XlaInstructionName(const Operator& op) {
  const StringRef class_name = op.getCppClassName();
  return class_name.drop_back(2);  // Strip the trailing "Op".
}
// Returns the C++ type used in generated lowering code to hold the value of
// attribute `nattr` of `op`.
//
// - If the attribute's def has an `lloEnumType` field, that type is returned
//   verbatim.
// - I64Attr / I32Attr map to int64_t / int32_t, wrapped in std::optional<>
//   when the attribute is optional.
// - Any other attribute kind is reported as a fatal TableGen error at the
//   op's location.
static std::string AttrType(const Operator& op, const NamedAttribute& nattr) {
  if (auto ty = nattr.attr.getDef().getValueAsOptionalString("lloEnumType")) {
    return ty->str();
  }
  StringRef def_name = nattr.attr.getAttrDefName();
  std::string type;
  if (def_name == "I64Attr") {
    type = "int64_t";
  } else if (def_name == "I32Attr") {
    type = "int32_t";
  } else {
    // Reject unsupported kinds here rather than returning an empty type,
    // which would only surface later as uncompilable generated code.
    PrintFatalError(
        op.getLoc(), "Unknown attribute (can't infer C++ type): " + def_name);
  }
  if (nattr.attr.isOptional()) {
    type = "std::optional<" + type + ">";
  }
  return type;
}
static std::string AttrValue(const Operator& op,
const NamedAttribute& nattr,
const std::string& generic_attr) {
auto& attr_def = nattr.attr.getDef();
if (auto ty = attr_def.getValueAsOptionalString("lloEnumType")) {
auto cppNamespace = attr_def.getValueAsString("cppNamespace");
auto cppClass = attr_def.getValueAsString("cppClassName");
return llvm::formatv("static_cast<{0}>(cast<{2}::{3}>({1}).getValue())",
*ty, generic_attr, cppNamespace, cppClass)
.str();
}
StringRef def_name = nattr.attr.getAttrDefName();
if (def_name == "I64Attr" || def_name == "I32Attr") {
std::string value =
"cast<mlir::IntegerAttr>(" + generic_attr + ").getInt()";
if (nattr.attr.hasDefaultValue()) {
value = generic_attr + " ? " + value + " : " +
nattr.attr.getDefaultValue().str();
} else if (nattr.attr.isOptional()) {
auto type = AttrType(op, nattr);
value = generic_attr + " ? " + type + "(" + value + ") : std::nullopt";
}
return value;
}
PrintFatalError(
op.getLoc(), "Unknown attribute (can't get value): " + def_name);
}
// Emits the body of the generic instruction-lowering lambda for `op`.
//
// The emitted code declares one local per ODS argument, in order:
// attributes are typed and extracted via AttrType/AttrValue, and operands are
// mapped through GetLloValue (optional operands become nullptr when absent).
// It then returns
// `::xla::jellyfish::LloInstruction::Create<Name>(args..., region)`.
//
// Returns failure, after reporting an error at the op's location, when the op
// cannot be lowered generically:
// - it has variadic operands/results, or more than one variable-length
//   operand, without AttrSizedOperandSegments;
// - it has a variadic operand;
// - it has an argument that is neither an attribute nor an operand.
static mlir::LogicalResult BuildOperatorGeneric(llvm::raw_ostream& os,
                                                const Operator& op) {
  bool has_operand_segments = false;
  if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
    has_operand_segments = true;
    os << " auto operand_sizes = "
          "cast<DenseI32ArrayAttr>(op->getAttr(\"operandSegmentSizes\"))."
          "asArrayRef();\n";
  } else if (op.isVariadic() || op.getNumVariableLengthOperands() > 1) {
    // Without segment sizes the generated code cannot tell which runtime
    // operands belong to which variable-length ODS operand.
    llvm::PrintError(
        op.getLoc(),
        "generic LLO lowering does not support variadic operands/results or "
        "multiple variable-length operands without AttrSizedOperandSegments; "
        "provide customInstructionLowering or builderMethod instead");
    return failure();
  }
  os << " auto operand_it = op->getOperands().begin();\n";
  os << " (void)operand_it;\n";
  // Position of the current operand in the ODS operand list; indexes
  // operandSegmentSizes.
  int operand_ix = 0;
  for (const Argument& arg : op.getArgs()) {
    if (auto nattr = arg.dyn_cast<NamedAttribute*>()) {
      os.indent(4);
      auto attr_type = AttrType(op, *nattr);
      auto attr_value = AttrValue(
          op, *nattr,
          llvm::formatv("op->getAttr(\"{0}\")", nattr->name).str());
      os << attr_type << " " << nattr->name << " = " << attr_value;
    } else if (auto operand = arg.dyn_cast<NamedTypeConstraint*>()) {
      // Checked before emitting anything for this operand so that no partial
      // declaration is written on failure.
      if (operand->isVariadic()) {
        llvm::PrintError(
            op.getLoc(),
            llvm::formatv(
                "generic LLO lowering does not support variadic operand '{0}'",
                operand->name)
                .str());
        return failure();
      }
      os.indent(4);
      os << "::xla::jellyfish::LloValue* " << operand->name << " = ";
      if (!operand->isOptional()) {
        os << "GetLloValue(*operand_it++, value_map)";
      } else if (has_operand_segments) {
        os << "(operand_sizes[" << operand_ix
           << "] == 1) ? GetLloValue(*operand_it++, value_map) : nullptr";
      } else {
        // At most one variable-length operand is allowed here (checked above),
        // so the optional operand is present exactly when the op has its full
        // ODS operand count.
        os << "(op->getNumOperands() == " << op.getNumOperands()
           << ") ? GetLloValue(*operand_it++, value_map) : nullptr";
      }
      ++operand_ix;
    } else {
      llvm::PrintError(op.getLoc(),
                       "generic LLO lowering supports only attribute and "
                       "operand arguments");
      return failure();
    }
    os << ";\n";
  }
  os << " return ::xla::jellyfish::LloInstruction::Create"
     << XlaInstructionName(op) << "(";
  for (const Argument& arg : op.getArgs()) {
    if (auto nattr = arg.dyn_cast<NamedAttribute*>()) {
      os << nattr->name;
    } else if (auto operand = arg.dyn_cast<NamedTypeConstraint*>()) {
      os << operand->name;
    } else {
      // Defensive only: the loop above already rejected (and diagnosed) any
      // other argument kind.
      return failure();
    }
    os << ", ";
  }
  os << "region);\n";
  return success();
}
// Emits the `if (auto op = llvm::dyn_cast<Op>(raw_op)) { ... }` clause of
// AppendInstructionGenerated for `op`. Emits nothing when the record's
// `hasInstructionLowering` bit is unset.
//
// Format strings from the record may use $_value_map, $_region, $_context
// (-> `ctx`), $_builder (-> `b`) and $_self (-> `op`). The lowering source is
// chosen as follows:
// - `builderMethod`: expanded in place. If the op has one result, the
//   expression's value is stored in value_map for it.
// - otherwise, `customInstructionLowering`, or else the generic body from
//   BuildOperatorGeneric, becomes the body of a lambda that returns a
//   std::unique_ptr<LloInstruction>. The instruction is appended with
//   b.Instruction(...), and the emitted code returns failure if the
//   instruction is null.
// Ops with more than one result are a fatal error.
static mlir::LogicalResult BuildOperator(llvm::raw_ostream& os,
                                         const Operator& op) {
  if (!op.getDef().getValueAsBit("hasInstructionLowering")) {
    return success();
  }
  os << " if (auto op = llvm::dyn_cast<" << op.getCppNamespace()
     << "::" << op.getCppClassName() << ">(raw_op)) {\n";
  mlir::tblgen::FmtContext ctx;
  ctx.addSubst("_value_map", "value_map")
      .addSubst("_region", "region")
      .addSubst("_context", "ctx")
      .withBuilder("b")
      .withSelf("op");
  auto builder_method = op.getDef().getValueAsOptionalString("builderMethod");
  if (!builder_method) {
    // Default by-reference capture: the lambda body can then use every name
    // the substitutions above expand to, including `ctx` for $_context.
    // `ctx` was missing from the previous explicit capture list.
    os << " auto instruction = [&]() -> "
          "std::unique_ptr<::xla::jellyfish::LloInstruction> {\n";
    if (auto impl =
            op.getDef().getValueAsOptionalString("customInstructionLowering")) {
      os << mlir::tblgen::tgfmt(*impl, &ctx);
    } else if (failed(BuildOperatorGeneric(os, op))) {
      return failure();
    }
    os << " }();\n";
    os << " if (!instruction) return ::mlir::failure();\n";
  }
  os.indent(4);
  if (op.getNumResults() == 1) {
    os << "value_map[op.getResult()] = ";
  } else if (op.getNumResults() != 0) {
    // PrintFatalError does not return.
    PrintFatalError(op.getLoc(), "Multi-result operations are unsupported");
  }
  if (builder_method) {
    os << mlir::tblgen::tgfmt(*builder_method, &ctx) << ";\n";
  } else {
    os << "b.Instruction(std::move(instruction));\n";
  }
  os << " return ::mlir::success();\n";
  os << " }\n";
  return success();
}
// TableGen backend entry point. Writes:
// 1. `AppendInstructionGenerated`, which tries the lowering clause of every
//    LLO_Op record (see BuildOperator) and returns failure if none matches;
// 2. for every LLO_EnumAttr record, one static_assert per enumerant of its
//    `enum`, checking that `<lloEnumType>::<symbol>` has the enumerant's
//    integer value.
// Returns true on error (TableGenMain convention), false on success.
static bool OperatorWritersMain(llvm::raw_ostream& os,
                                llvm::RecordKeeper& records) {
  llvm::emitSourceFileHeader("LLO Builder bindings", os);
  os << '\n'
     << R"-(mlir::LogicalResult AppendInstructionGenerated(
BuilderContext *ctx,
::xla::jellyfish::LloRegionBuilder& b,
Operation* raw_op,
llvm::DenseMap<mlir::Value, ::xla::jellyfish::LloValue*>& value_map) {)-"
     << '\n'
     << " auto region = b.region();\n";
  for (const auto* op_record : records.getAllDerivedDefinitions("LLO_Op")) {
    const Operator llo_op(op_record);
    if (failed(BuildOperator(os, llo_op))) return true;
  }
  os << " return ::mlir::failure();\n"
     << "}\n";

  for (const auto* enum_attr :
       records.getAllDerivedDefinitions("LLO_EnumAttr")) {
    const StringRef llo_type = enum_attr->getValueAsString("lloEnumType");
    const auto* enum_record = enum_attr->getValueAsDef("enum");
    const auto enumerants = enum_record->getValueAsListOfDefs("enumerants");
    os << '\n' << "// Verifying values of " << llo_type << '\n';
    for (const auto* enumerant : enumerants) {
      const auto value = enumerant->getValueAsInt("value");
      const StringRef symbol = enumerant->getValueAsString("symbol");
      os << llvm::formatv(
          "static_assert(static_cast<int>({0}::{1}) == {2}, \"Mismatched value "
          "of {0}::{1}\");\n",
          llo_type, symbol, value);
    }
  }
  return false;
}
// Tool entry point: sets up LLVM process state, parses the standard TableGen
// command-line options, and runs OperatorWritersMain as the backend.
int main(int argc, char** argv) {
  llvm::InitLLVM init_llvm(argc, argv);
  llvm::cl::ParseCommandLineOptions(argc, argv);
  return llvm::TableGenMain(argv[0], &OperatorWritersMain);
}