Alex's LAVM code
diff --git a/LAVM/BUILD b/LAVM/BUILD
new file mode 100644
index 0000000..f85b21b
--- /dev/null
+++ b/LAVM/BUILD
@@ -0,0 +1,182 @@
+load("//third_party/flex:flex.bzl", "genlex")
+load("//third_party/bison:bison.bzl", "genyacc")
+load("//third_party/llvm/llvm-project/mlir:tblgen.bzl", "gentbl")
+
+package(
+ default_visibility = [":friends"],
+)
+
+package_group(
+ name = "friends",
+ packages = [
+ "//learning/brain/experimental/mlir/...",
+ ],
+)
+
+filegroup(
+ name = "LAVMDialectTdFiles",
+ srcs = [
+ "LAVMDialect.td",
+ "//third_party/llvm/llvm-project/mlir:OpBaseTdFiles",
+ ],
+)
+
+gentbl(
+ name = "LAVMDialectIncGen",
+ tbl_outs = [
+ (
+ "-gen-op-decls",
+ "LAVMDialect.h.inc",
+ ),
+ (
+ "-gen-op-defs",
+ "LAVMDialect.cpp.inc",
+ ),
+ ],
+ tblgen = "//third_party/llvm/llvm-project/mlir:mlir-tblgen",
+ td_file = "LAVMDialect.td",
+ td_srcs = [
+ ":LAVMDialectTdFiles",
+ ],
+)
+
+cc_library(
+ name = "LAVMDialect",
+ srcs = [
+ "LAVMDialect.cpp",
+ ],
+ hdrs = [
+ "LAVMDialect.h",
+ ],
+ deps = [
+ ":LAVMDialectIncGen",
+ "//third_party/llvm/llvm-project/llvm:support",
+ "//third_party/llvm/llvm-project/mlir:IR",
+ "//third_party/llvm/llvm-project/mlir:StandardOps",
+ "//third_party/llvm/llvm-project/mlir:Support",
+ ],
+)
+
+cc_library(
+ name = "LAVMDialectRegistration",
+ srcs = [
+ "LAVMDialectRegistration.cpp",
+ ],
+ deps = [
+ ":LAVMDialect",
+ "//third_party/llvm/llvm-project/mlir:IR",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "LAVMTarget",
+ srcs = [
+ "LAVMTarget.cpp",
+ "LAVMTargetParser.cpp",
+ ],
+ hdrs = [
+ "LAVMTarget.h",
+ "Lexer.h",
+ ],
+ deps = [
+ ":LAVMDialect",
+ ":md_grammar",
+ "//third_party/llvm/llvm-project/llvm:support",
+ "//third_party/llvm/llvm-project/mlir:IR",
+ ],
+)
+
+gentbl(
+ name = "LAVMPassIncGen",
+ tbl_outs = [
+ (
+ "-gen-rewriters",
+ "LAVMPass.cpp.inc",
+ ),
+ ],
+ tblgen = "//third_party/llvm/llvm-project/mlir:mlir-tblgen",
+ td_file = "LAVMPass.td",
+ td_srcs = [
+ ":LAVMDialectTdFiles",
+ "//third_party/llvm/llvm-project/mlir:StdOpsTdFiles",
+ ],
+)
+
+cc_library(
+ name = "LAVMPass",
+ srcs = [
+ "LAVMPass.cpp",
+ ],
+ deps = [
+ ":LAVMDialect",
+ ":LAVMPassIncGen",
+ ":LAVMTarget",
+ "//third_party/llvm/llvm-project/llvm:support",
+ "//third_party/llvm/llvm-project/mlir:IR",
+ "//third_party/llvm/llvm-project/mlir:Pass",
+ "//third_party/llvm/llvm-project/mlir:StandardOps",
+ "//third_party/llvm/llvm-project/mlir:Transforms",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "LAVMMain",
+ srcs = [
+ "LAVMMain.cpp",
+ ],
+ deps = [
+ "//third_party/llvm/llvm-project/llvm:support",
+ "//third_party/llvm/llvm-project/mlir:Parser",
+ "//third_party/llvm/llvm-project/mlir:Pass",
+ "//third_party/llvm/llvm-project/mlir:Support",
+ ],
+)
+
+cc_binary(
+ name = "lavm",
+ deps = [
+ ":LAVMDialectRegistration",
+ ":LAVMMain",
+ ":LAVMPass",
+ "//learning/brain/experimental/mlir/tpu:op_registration",
+ "//learning/brain/experimental/mlir/tpu:transforms",
+ "//third_party/llvm/llvm-project/llvm:support",
+ "//third_party/llvm/llvm-project/mlir:AllPassesAndDialectsNoRegistration",
+ "//third_party/llvm/llvm-project/mlir:MlirOptLib",
+ ],
+)
+
+# Machine-description grammar (scanning and parsing).
+
+genlex(
+ name = "md_scanner",
+ src = "md_scanner.lex",
+ out = "md_scanner.lex.cpp",
+ includes = ["md_parser.y.h"],
+ visibility = ["//visibility:private"],
+)
+
+genyacc(
+ name = "md_parser",
+ src = "md_parser.y",
+ header_out = "md_parser.y.h",
+ source_out = "md_parser.y.cpp",
+ visibility = ["//visibility:private"],
+)
+
+cc_library(
+ name = "md_grammar",
+ srcs = [
+ "md_grammar.cpp",
+ "md_parser.y.cpp",
+ "md_scanner.lex.cpp",
+ ],
+ hdrs = [
+ "md_grammar.h",
+ "md_parser.y.h",
+ ],
+ copts = ["-Wno-implicit-fallthrough"],
+ visibility = ["//visibility:private"],
+)
diff --git a/LAVM/LAVMDialect.cpp b/LAVM/LAVMDialect.cpp
new file mode 100644
index 0000000..9a9f6c6
--- /dev/null
+++ b/LAVM/LAVMDialect.cpp
@@ -0,0 +1,26 @@
+#include "LAVMDialect.h"
+
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/StandardTypes.h"
+
+using namespace mlir;
+using namespace mlir::lavm;
+
+//===----------------------------------------------------------------------===//
+// LAVMDialect
+//===----------------------------------------------------------------------===//
+
+LAVMDialect::LAVMDialect(MLIRContext *context)
+ : Dialect(getDialectNamespace(), context) {
+ addOperations<
+#define GET_OP_LIST
+#include "experimental/LAVM/LAVMDialect.cpp.inc"
+ >();
+}
+
+//===----------------------------------------------------------------------===//
+// TableGen'd op method definitions
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "experimental/LAVM/LAVMDialect.cpp.inc"
diff --git a/LAVM/LAVMDialect.h b/LAVM/LAVMDialect.h
new file mode 100644
index 0000000..fa3e50c
--- /dev/null
+++ b/LAVM/LAVMDialect.h
@@ -0,0 +1,22 @@
+#ifndef EXPERIMENTAL_LAVM_LAVMDIALECT_H_
+#define EXPERIMENTAL_LAVM_LAVMDIALECT_H_
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Dialect.h"
+
+namespace mlir {
+namespace lavm {
+
+class LAVMDialect : public Dialect {
+ public:
+ explicit LAVMDialect(MLIRContext *context);
+ static StringRef getDialectNamespace() { return "lavm"; }
+};
+
+#define GET_OP_CLASSES
+#include "experimental/LAVM/LAVMDialect.h.inc"
+
+} // end namespace lavm
+} // end namespace mlir
+
+#endif // EXPERIMENTAL_LAVM_LAVMDIALECT_H_
diff --git a/LAVM/LAVMDialect.td b/LAVM/LAVMDialect.td
new file mode 100644
index 0000000..f482e3c
--- /dev/null
+++ b/LAVM/LAVMDialect.td
@@ -0,0 +1,193 @@
+#ifndef LAVM_DIALECT
+#define LAVM_DIALECT
+
+include "mlir/IR/OpBase.td"
+
+def LAVM_Dialect : Dialect {
+ let name = "lavm";
+ let cppNamespace = "";
+}
+
+// Base class for LAVM dialect.
+class LAVM_Op<string mnemonic, list<OpTrait> traits = []> :
+ Op<LAVM_Dialect, mnemonic, traits> {
+ let printer = ?; // [{ return ::print(p, *this); }];
+ let verifier = ?; // [{ return ::verify(*this); }];
+ let parser = ?; // [{ return ::parse$cppClass(parser, result); }];
+ bit is_elementwise = 0;
+ // bit is_HL = 0;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Elementwise operations.
+////////////////////////////////////////////////////////////////////////////////
+
+// Unary:
+
+def LAVMAbsOp : LAVM_Op<"abs", [NoSideEffect, SameOperandsAndResultType]> {
+ let summary = "LAVM elementwise abs";
+ let description = [{ Elementwise abs. }];
+ let arguments = (ins AnyType:$lhs);
+ let results = (outs AnyType:$res);
+}
+
+def LAVMExpOp : LAVM_Op<"exp", [NoSideEffect, SameOperandsAndResultType]> {
+ let summary = "LAVM elementwise exp";
+ let description = [{ Elementwise exp. }];
+ let arguments = (ins AnyType:$lhs);
+ let results = (outs AnyType:$res);
+}
+
+def LAVMLogOp : LAVM_Op<"log", [NoSideEffect, SameOperandsAndResultType]> {
+ let summary = "LAVM elementwise log";
+ let description = [{ Elementwise log. }];
+ let arguments = (ins AnyType:$lhs);
+ let results = (outs AnyType:$res);
+}
+
+def LAVMNegOp : LAVM_Op<"neg", [NoSideEffect, SameOperandsAndResultType]> {
+ let summary = "LAVM elementwise neg";
+ let description = [{ Elementwise neg. }];
+ let arguments = (ins AnyType:$lhs);
+ let results = (outs AnyType:$res);
+}
+
+def LAVMNotOp : LAVM_Op<"not", [NoSideEffect, SameOperandsAndResultType]> {
+ let summary = "LAVM elementwise not";
+ let description = [{ Elementwise not. }];
+ let arguments = (ins AnyType:$lhs);
+ let results = (outs AnyType:$res);
+}
+
+// Binary:
+
+def LAVMAddOp : LAVM_Op<"add", [NoSideEffect, SameOperandsAndResultType]> {
+ let summary = "LAVM elementwise add";
+ let description = [{ Elementwise add. }];
+ let arguments = (ins AnyType:$lhs, AnyType:$rhs);
+ let results = (outs AnyType:$res);
+}
+
+def LAVMAndOp : LAVM_Op<"and", [NoSideEffect, SameOperandsAndResultType]> {
+ let summary = "LAVM elementwise and";
+ let description = [{ Elementwise and. }];
+ let arguments = (ins AnyType:$lhs, AnyType:$rhs);
+ let results = (outs AnyType:$res);
+}
+
+def LAVMDivOp : LAVM_Op<"div", [NoSideEffect, SameOperandsAndResultType]> {
+ let summary = "LAVM elementwise div";
+ let description = [{ Elementwise div. }];
+ let arguments = (ins AnyType:$lhs, AnyType:$rhs);
+ let results = (outs AnyType:$res);
+}
+
+def LAVMMaxOp : LAVM_Op<"max", [NoSideEffect, SameOperandsAndResultType]> {
+ let summary = "LAVM elementwise max";
+ let description = [{ Elementwise max. }];
+ let arguments = (ins AnyType:$lhs, AnyType:$rhs);
+ let results = (outs AnyType:$res);
+}
+
+def LAVMMinOp : LAVM_Op<"min", [NoSideEffect, SameOperandsAndResultType]> {
+ let summary = "LAVM elementwise min";
+ let description = [{ Elementwise min. }];
+ let arguments = (ins AnyType:$lhs, AnyType:$rhs);
+ let results = (outs AnyType:$res);
+}
+
+def LAVMMulOp : LAVM_Op<"mul", [NoSideEffect, SameOperandsAndResultType]> {
+ let summary = "LAVM elementwise mul";
+ let description = [{ Elementwise mul. }];
+ let arguments = (ins AnyType:$lhs, AnyType:$rhs);
+ let results = (outs AnyType:$res);
+}
+
+def LAVMOrOp : LAVM_Op<"or", [NoSideEffect, SameOperandsAndResultType]> {
+ let summary = "LAVM elementwise or";
+ let description = [{ Elementwise or. }];
+ let arguments = (ins AnyType:$lhs, AnyType:$rhs);
+ let results = (outs AnyType:$res);
+}
+
+def LAVMSubOp : LAVM_Op<"sub", [NoSideEffect, SameOperandsAndResultType]> {
+ let summary = "LAVM elementwise sub";
+ let description = [{ Elementwise sub. }];
+ let arguments = (ins AnyType:$lhs, AnyType:$rhs);
+ let results = (outs AnyType:$res);
+}
+
+def LAVMXorOp : LAVM_Op<"xor", [NoSideEffect, SameOperandsAndResultType]> {
+ let summary = "LAVM elementwise xor";
+ let description = [{ Elementwise xor. }];
+ let arguments = (ins AnyType:$lhs, AnyType:$rhs);
+ let results = (outs AnyType:$res);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Matrix operations.
+////////////////////////////////////////////////////////////////////////////////
+
+def LAVMDotOp : LAVM_Op<"dot", [NoSideEffect]> {
+ let summary = "LAVM dot product";
+ let description = [{ Dot product. }];
+ let arguments = (ins AnyType:$lhs, AnyType:$rhs);
+ let results = (outs AnyType:$res);
+}
+
+def LAVMMatmulOp : LAVM_Op<"matmul", [NoSideEffect]> {
+ let summary = "LAVM matmul";
+ let description = [{ Matmul. }];
+ let arguments = (ins AnyType:$lhs, AnyType:$rhs);
+ let results = (outs AnyType:$res);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Memory operations.
+////////////////////////////////////////////////////////////////////////////////
+
+def LAVMLoadOp : LAVM_Op<"load"> {
+ let summary = "LAVM load";
+ let description = [{ Load }];
+ let arguments = (ins AnyMemRef:$address, Variadic<Index>:$indices);
+ let results = (outs AnyType:$res);
+ let extraClassDeclaration = [{
+ static unsigned GetAddressOpIndex() { return 0; }
+ }];
+}
+
+def LAVMStoreOp : LAVM_Op<"store"> {
+ let summary = "LAVM store";
+ let description = [{ store }];
+ let arguments = (ins AnyType:$value, AnyMemRef:$address, Variadic<Index>:$indices);
+ let results = (outs);
+ let extraClassDeclaration = [{
+ static unsigned GetValueOpIndex() { return 0; }
+ static unsigned GetAddressOpIndex() { return 1; }
+ }];
+}
+
+def LAVMDmaOp : LAVM_Op<"dma"> {
+ let summary = "LAVM DMA";
+ let description = [{ DMA }];
+ let arguments = (ins AnyMemRef:$to, AnyMemRef:$from, AnyType:$size);
+ let results = (outs);
+ let extraClassDeclaration = [{
+ static unsigned GetToOpIndex() { return 0; }
+ static unsigned GetFromOpIndex() { return 1; }
+ static unsigned GetSizeOpIndex() { return 2; }
+ }];
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Other operations.
+////////////////////////////////////////////////////////////////////////////////
+
+def LAVMConstantOp : LAVM_Op<"constant", [NoSideEffect]> {
+ let summary = "LAVM constant";
+ let description = [{ Elementwise constant. }];
+ let arguments = (ins AnyAttr:$value);
+ let results = (outs AnyType:$res);
+}
+
+#endif // LAVM_DIALECT
diff --git a/LAVM/LAVMDialectRegistration.cpp b/LAVM/LAVMDialectRegistration.cpp
new file mode 100644
index 0000000..e7d6634
--- /dev/null
+++ b/LAVM/LAVMDialectRegistration.cpp
@@ -0,0 +1,4 @@
+#include "LAVMDialect.h"
+
+// Static initialization for LAVM dialect registration.
+static mlir::DialectRegistration<mlir::lavm::LAVMDialect> LAVM;
diff --git a/LAVM/LAVMMain.cpp b/LAVM/LAVMMain.cpp
new file mode 100644
index 0000000..a211981
--- /dev/null
+++ b/LAVM/LAVMMain.cpp
@@ -0,0 +1,110 @@
+//===- LAVMMain.cpp -------------------------------------------------------===//
+//
+// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Main entry function for LAVM when built as a standalone binary.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Support/InitLLVM.h"
+#include "llvm/Support/ToolOutputFile.h"
+#include "mlir/Parser.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/FileUtilities.h"
+#include "third_party/llvm/llvm-project/llvm/include/llvm/Support/CommandLine.h"
+
+static llvm::cl::opt<std::string> inputFilename(llvm::cl::Positional,
+ llvm::cl::desc("<input file>"),
+ llvm::cl::init("-"));
+
+static llvm::cl::opt<std::string> outputFilename(
+ "o", llvm::cl::desc("Output filename"), llvm::cl::value_desc("filename"),
+ llvm::cl::init("-"));
+
+// FIXME!!! discrepancy. This option specifies the location of the "target
+// description" file. However, because the "td" abbreviation is strongly
+// associated with tablegen, we use "md" instead, an abbreviation of "Machine
+// Description", and use the ".md" extension for the example target
+// description files.
+static llvm::cl::opt<std::string> mdFilename(
+ "md", llvm::cl::desc("Target machine description filename"),
+ llvm::cl::value_desc("filename"), llvm::cl::Required,
+ llvm::cl::ValueRequired);
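+
+// A hypothetical invocation (file names are placeholders; the LAVM pass itself
+// is selected through the pass-manager command-line options registered in
+// LAVMmain() below):
+//   lavm input.mlir -md target.md -o out.mlir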
+
+static llvm::cl::opt<bool> useLexYacc(
+ "use_lex_yacc",
+ llvm::cl::desc("Use lex/yacc to parse target machine description file"),
+ llvm::cl::init(false));
+
+namespace mlir {
+namespace lavm {
+
+// Invoked during registration of the LAVMPass.
+const std::string &GetMdFilename() { return mdFilename; }
+bool UseLexYacc() { return useLexYacc; }
+
+static int LAVMmain(int argc, char **argv) {
+ llvm::InitLLVM y(argc, argv);
+
+ // Register any pass manager command line options.
+ registerPassManagerCLOptions();
+ PassPipelineCLParser passPipeline("", "Compiler passes to run");
+
+ // Parse pass names in main to ensure static initialization completed.
+ llvm::cl::ParseCommandLineOptions(argc, argv, "LAVM driver\n");
+
+ // Set up the input file.
+ std::string errorMessage;
+ auto file = openInputFile(inputFilename, &errorMessage);
+ if (!file) {
+ llvm::errs() << errorMessage << "\n";
+ exit(1);
+ }
+
+ auto output = openOutputFile(outputFilename, &errorMessage);
+ if (!output) {
+ llvm::errs() << errorMessage << "\n";
+ exit(1);
+ }
+
+ // Tell sourceMgr about this buffer, which is what the parser will pick up.
+ llvm::SourceMgr sourceMgr;
+ sourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
+
+ MLIRContext context;
+ OwningModuleRef module(parseSourceFile(sourceMgr, &context));
+ if (!module) {
+ llvm::errs() << "Cannot parse input.\n";
+ exit(1);
+ }
+
+ // Apply any pass manager command line options.
+ PassManager pm(&context, /*verifyPasses=*/false);
+ applyPassManagerCLOptions(pm);
+
+ // Build the pipeline.
+ if (failed(passPipeline.addToPipeline(pm))) {
+ llvm::errs() << "Cannot build pipeline.\n";
+ exit(1);
+ }
+
+ // Run the pipeline.
+ if (failed(pm.run(*module))) {
+ llvm::errs() << "Pipeline run failed.\n";
+ exit(1);
+ }
+
+ // Print the output.
+ module->print(output->os());
+
+ return 0;
+}
+
+} // namespace lavm
+} // namespace mlir
+
+int main(int argc, char **argv) { return mlir::lavm::LAVMmain(argc, argv); }
diff --git a/LAVM/LAVMPass.cpp b/LAVM/LAVMPass.cpp
new file mode 100644
index 0000000..5a33c0b
--- /dev/null
+++ b/LAVM/LAVMPass.cpp
@@ -0,0 +1,104 @@
+#include "LAVMDialect.h"
+#include "LAVMTarget.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Pass/Pass.h"
+
+namespace mlir {
+namespace lavm {
+
+// These functions must be provided by the client for successful linking.
+// For the standalone 'lavm' binary, they are defined in the file that
+// contains the main() entry function.
+extern const std::string& GetMdFilename();
+extern bool UseLexYacc();
+
+#include "experimental/LAVM/LAVMPass.cpp.inc"
+
+class LAVMPass : public ModulePass<LAVMPass> {
+ public:
+ explicit LAVMPass(const std::string& mdFilename, bool use_lex_yacc)
+ : ModulePass<LAVMPass>(), target() {
+ if (!target.ParseFromFile(mdFilename.data(), use_lex_yacc)) {
+ llvm_unreachable("Errors. Exiting.");
+ }
+ }
+
+ private:
+ void XRegion(Region* region) {
+ for (Block& block : *region) {
+ XBlock(&block);
+ }
+ }
+
+ void XBlock(Block* block) {
+ for (Operation& operation : *block) {
+ XOperation(&operation);
+ }
+ }
+
+ void XOperation(Operation* operation) {
+ std::cerr << "\nOPERATION: ";
+ operation->print(llvm::errs());
+
+ Dialect* dialect = operation->getDialect();
+ if (dialect) {
+ if (dialect->getNamespace() == LAVMDialect::getDialectNamespace()) {
+ const LAVMOp* lavm_op =
+ target.GetLAVMOp(operation->getName().getStringRef().str());
+ if (lavm_op == nullptr) {
+ std::cerr << "\tNO INSTANTIATION RULE(S) PROVIDED, EXPAND TO OTHER "
+ "LAVM OP(S)\n";
+ } else {
+ std::cerr << "\tINSTANTIATE " << lavm_op->ToString() << "\n";
+ // FIXME!!! make all these expansions and instantiations opaque for
+ // the client (for example, declare them with auto or with
+ // target::expansions_type)
+ LAVMInstantiationList instantiations =
+ target.InferValidTypes(lavm_op);
+ std::cerr << "\n"
+ << LAVMInstantiation::ToDetailedString(instantiations,
+ *lavm_op, target);
+ }
+ } else {
+ std::cerr << "\tNON-LAVM DIALECT: '" << dialect->getNamespace().data()
+ << "'\n";
+ }
+ } else {
+ std::cerr << "\tHAS NO DIALECT\n";
+ }
+ std::cerr << "\n";
+
+ for (Region& region : operation->getRegions()) {
+ XRegion(®ion);
+ }
+ }
+
+ void runOnModule() override {
+ std::cerr << "\n" << target.ToDetailedString() << "\n";
+
+ // Attempt to raise target information from value to buffer level with
+ // target-specific types. Derive buffer operation sizes from the value-based
+ // target description. "Expand" an op into loads/op/stores even for cases
+ // when the target supports all the different types for each.
+ // ...
+
+ for (Operation& O : getModule()) {
+ XOperation(&O);
+ }
+ }
+
+ private:
+ LAVMTarget target;
+};
+
+std::unique_ptr<OpPassBase<ModuleOp>> createLAVMPass(
+ const std::string& mdFilename, bool use_lex_yacc) {
+ return std::make_unique<LAVMPass>(mdFilename, use_lex_yacc);
+}
+
+static PassRegistration<LAVMPass> pass("lavm", "Run LAVM pass", [] {
+ return std::make_unique<LAVMPass>(GetMdFilename(), UseLexYacc());
+});
+
+} // namespace lavm
+} // namespace mlir
diff --git a/LAVM/LAVMPass.td b/LAVM/LAVMPass.td
new file mode 100644
index 0000000..2d15845
--- /dev/null
+++ b/LAVM/LAVMPass.td
@@ -0,0 +1,18 @@
+#ifndef LAVM_PASS
+#define LAVM_PASS
+
+include "mlir/Dialect/StandardOps/IR/Ops.td"
+include "LAVMDialect.td"
+
+// Initial mapping from HL ops to LAVM ops
+// def : Pat<(LAIRHLAddOp $dst, $lhs, $rhs), (LAIRAddMemrefOp $dst, $lhs, $rhs)>;
+// def : Pat<(LAIRHLSubOp $dst, $lhs, $rhs), (LAIRSubMemrefOp $dst, $lhs, $rhs)>;
+// def : Pat<(LAIRHLMulOp $dst, $lhs, $rhs), (LAIRMulMemrefOp $dst, $lhs, $rhs)>;
+// def : Pat<(LAIRHLDotOp $dst, $lhs, $rhs), (LAIRDotMemrefOp $dst, $lhs, $rhs)>;
+// def : Pat<(LAIRHLMatmulOp $dst, $lhs, $rhs), (LAIRMatmulMemrefOp $dst, $lhs, $rhs)>;
+//
+// def UIsMemRefTypePred : TypeConstraint<CPred<"$_self.isa<MemRefType>()">, "memref">;
+//
+// def : Pat<(LAIRAddOp F32:$lhs, F32:$rhs), (AddFOp $lhs, $rhs)>;
+
+#endif // LAVM_PASS
diff --git a/LAVM/LAVMTarget.cpp b/LAVM/LAVMTarget.cpp
new file mode 100644
index 0000000..593056b
--- /dev/null
+++ b/LAVM/LAVMTarget.cpp
@@ -0,0 +1,626 @@
+#include "LAVMTarget.h"
+
+#include "LAVMDialect.h"
+#include "Lexer.h"
+#include "mlir/IR/Operation.h"
+
+namespace mlir {
+namespace lavm {
+
+////////////////////////////////////////////////////////////////////////////////
+// LAVMInstantiation
+////////////////////////////////////////////////////////////////////////////////
+
+// FIXME!!! these belong to LAVMInstantiation.cpp
+
+LAVMType LAVMInstantiation::GetTopLevelType() const { // FIXME!!! rename
+ // std::cerr << "Entered LAVMInstantiation::GetTopLevelType() with " <<
+ // ToString() << "\n";
+
+ // Just a reminder, it's already been enforced:
+ assert(lavm_op != nullptr && lavm_op->IsFlatOp());
+ LAVMTypeList list;
+ for (int32_t i = 0; i < lavm_op->GetNumOperands(); i++) {
+ const LAVMType* type = GetType(&lavm_op->Operand(i));
+ if (type == nullptr) {
+ // This variable is not used in the expansion, can be any type.
+ list.push_back(LAVMType::CreateAnyType());
+ } else {
+ list.push_back(*type);
+ }
+ }
+ const LAVMType domain_type = LAVMType::Create(list);
+ const LAVMType& range_type = GetType(expansion_op)->GetRangeType();
+ return LAVMType::CreateFunctionType(domain_type, range_type);
+}
+
+int32_t LAVMInstantiation::GetCost(const LAVMTarget* target) const {
+ int32_t cost = 0;
+ auto f = [this, &target, &cost](const LAVMOp* op,
+ const LAVMInstantiation* instantiation) {
+ if (op->IsOp()) {
+ const LAVMInstantiation* sub_instantiation =
+ instantiation->GetSubInstantiation(op);
+ if (sub_instantiation == nullptr) {
+ const LAVMType* type = instantiation->GetType(op);
+ assert(type);
+ // FIXME!!! create LAVMInstantiation::dfs()
+ const LAVMAnnotation* type_annotation =
+ target->GetTypeAnnotation(*op, *type);
+ if (type_annotation != nullptr) {
+ // std::cerr << "\nAnnotation for [" << op->ToString() << ", " <<
+ // type->ToString() << "]: "
+ // << type_annotation->ToString() << "\n";
+ cost += type_annotation->ToInt32();
+ } else {
+ cost++;
+ }
+ }
+ } else {
+ assert(op->IsName());
+ }
+ };
+ dfs(f);
+ return cost;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Interface
+////////////////////////////////////////////////////////////////////////////////
+
+const LAVMOp* LAVMTarget::GetTargetOp(const std::string& name) const {
+ return GetTargetOpImpl(name);
+}
+
+const LAVMOp* LAVMTarget::GetLAVMOp(const std::string& name) const {
+ return GetLAVMOpImpl(name);
+}
+
+bool LAVMTarget::ParseFromFile(const char* target_description_filename,
+ bool use_lex_yacc) {
+ return LAVMTarget::ParseFromFileImpl(target_description_filename,
+ use_lex_yacc);
+}
+
+LAVMInstantiationList LAVMTarget::InferValidTypes(const LAVMOp* lavm_op) const {
+ return LAVMTarget::InferValidTypesImpl(lavm_op);
+}
+
+bool LAVMTarget::SupportsDma(const LAVMTargetMemory& to,
+ const LAVMTargetMemory& from) const {
+ return SupportsDmaImpl(to, from);
+}
+
+bool LAVMTarget::SupportsLoad(const LAVMTargetMemory& mem) const {
+ return !GetLoadTypes(mem).empty();
+}
+
+bool LAVMTarget::SupportsStore(const LAVMTargetMemory& mem) const {
+ return !GetStoreTypes(mem).empty();
+}
+
+LAVMTypeList LAVMTarget::GetLoadTypes(const LAVMTargetMemory& mem) const {
+ return GetLoadTypesImpl(mem);
+}
+
+LAVMTypeList LAVMTarget::GetStoreTypes(const LAVMTargetMemory& mem) const {
+ return GetStoreTypesImpl(mem);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Serialize
+////////////////////////////////////////////////////////////////////////////////
+
+std::string LAVMTarget::ToString() const {
+ const std::string indent = " ";
+ std::string s;
+ s += "target {\n";
+ for (const auto& pair : target_op_types) {
+ const LAVMOp& op = pair.first;
+ for (const LAVMType& type : pair.second) {
+ s += indent + op.ToString() + " : " + type.ToString();
+ const LAVMAnnotation* type_annotation = GetTypeAnnotation(op, type);
+ if (type_annotation != nullptr) {
+ s += " @ " + *type_annotation;
+ }
+ s += ";\n";
+ }
+ }
+ s += "}\n";
+ s += "map {\n";
+ for (const auto& pair : expansions) {
+ const LAVMOp& op = pair.first;
+ for (const LAVMOp& target_op : pair.second) {
+ s += indent + op.ToString() + " : " + target_op.ToString() + ";\n";
+ }
+ }
+ s += "}\n";
+ s += "memory {\n";
+ for (const auto& pair : memory) {
+ const LAVMTargetMemory& mem = pair.second;
+ std::string attr = mem.ToStringAttributes();
+ if (!attr.empty()) {
+ s += indent + attr + ";\n";
+ }
+ }
+ for (const auto& pair : memory) {
+ const LAVMTargetMemory& mem = pair.second;
+ std::string transfers = mem.ToStringTransfers();
+ if (!transfers.empty()) {
+ s += indent + transfers + "\n"; // FIXME!!! inconsistent: no ";"
+ }
+ }
+ s += "}\n";
+ return s;
+}
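+
+// A hypothetical machine description in the serialized form produced by
+// ToString() above (op, type, and memory names are illustrative and not taken
+// from a real target; the authoritative .md grammar lives in md_parser.y):
+//
+//   target {
+//     add(%a, %b) : (f32, f32) -> f32 @ 2;
+//     load(%p) : memref<hbm> -> f32;
+//   }
+//   map {
+//     lavm.add(%a, %b) : add(%a, %b);
+//   }
+//   memory {
+//     hbm : size = 1024;
+//     $vmem : size = 128;
+//     hbm -> vmem;
+//   }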
+
+std::string LAVMTarget::ToDetailedString() const {
+ std::string s = "Target info:\n" + ToString();
+
+ // Dump all possible instantiations for LAVM ops.
+ for (auto it = ExpansionBegin(); it != ExpansionEnd(); ++it) {
+ const LAVMOp& lavm_op = it->first; // FIXME!!! make opaque
+ LAVMInstantiationList instantiations = InferValidTypes(&lavm_op);
+ s += "\n" +
+ LAVMInstantiation::ToDetailedString(instantiations, lavm_op, *this);
+ }
+
+ // Dump memory info.
+ s += "\nMemory info:\n";
+ bool emitCache = false;
+ do {
+ emitCache = !emitCache;
+ for (auto it = MemoryBegin(); it != MemoryEnd(); ++it) {
+ const LAVMTargetMemory& mem = it->second;
+ if (emitCache == mem.IsCache()) {
+ s += " ";
+ if (emitCache) {
+ s += "cache $";
+ }
+ s += mem.GetName() + " has size " + mem.GetSize();
+ if (SupportsLoad(mem)) {
+ s += ", supports load(s) of ";
+ bool First = true; // FIXME!!! rewrite this and other prints to hide
+ // this boolean
+ for (const LAVMType& type : GetLoadTypes(mem)) {
+ if (!First) {
+ s += ", ";
+ }
+ s += type.ToString();
+ First = false;
+ }
+ }
+ if (SupportsStore(mem)) {
+ s += ", supports store(s) of ";
+ bool First = true; // FIXME!!! rewrite this and other prints to hide
+ // this boolean
+ for (const LAVMType& type : GetStoreTypes(mem)) {
+ if (!First) {
+ s += ", ";
+ }
+ s += type.ToString();
+ First = false;
+ }
+ }
+ s += "\n";
+ }
+ }
+ } while (emitCache);
+
+ // Dump DMA info.
+ s += "\nSupported DMA:\n";
+ bool dma_empty = true;
+ for (auto it_to = MemoryBegin(); it_to != MemoryEnd(); ++it_to) {
+ const LAVMTargetMemory& to = it_to->second;
+ for (auto it_from = MemoryBegin(); it_from != MemoryEnd(); ++it_from) {
+ const LAVMTargetMemory& from = it_from->second;
+ if (SupportsDma(to, from)) {
+ s += " " + to.GetName() + " -> " + from.GetName() + "\n";
+ dma_empty = false;
+ }
+ }
+ }
+ if (dma_empty) {
+ s += " none.\n";
+ }
+ return s;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Helpers
+////////////////////////////////////////////////////////////////////////////////
+
+// FIXME!!! mess
+LAVMOpPtrList LAVMTarget::GetTargetExpansions(const LAVMOp* lavm_op) const {
+ LAVMOpPtrList mapping;
+ assert(lavm_op != nullptr); // remove? dereference will crash anyway
+ // FIXME!!! rename adjusted_op
+ const LAVMOp* adjusted_op = GetLAVMOp(lavm_op->GetOpName());
+ auto it = expansions.find(*adjusted_op);
+ if (it != expansions.end()) {
+ for (const LAVMOp& target_op : it->second) {
+ mapping.push_back(&target_op);
+ }
+ }
+ return mapping;
+}
+
+// FIXME!!! mess
+LAVMTypePtrList LAVMTarget::GetTypesSupportedByTargetOp(
+ const LAVMOp* op) const {
+ LAVMTypePtrList mapping;
+ const LAVMOp* target_op = GetTargetOp(op->GetOpName());
+ if (target_op != nullptr) {
+ auto it = target_op_types.find(*target_op);
+ for (const LAVMType& type : it->second) {
+ mapping.push_back(&type);
+ }
+ }
+ return mapping;
+}
+
+// This method maps an arbitrary (presumably target) op to its equivalent in
+// the target's "target op" table. Messy.
+//
+const LAVMOp* LAVMTarget::GetTargetOpImpl(const std::string& name) const {
+ const LAVMOpPtrList list = FilterTargetOpTypes(name);
+ if (list.size() == 1) {
+ return list.front();
+ }
+ if (!list.empty()) {
+ std::cerr << "Multiple target entries for op '" << name
+ << "' not supported yet:\n";
+ }
+ return nullptr;
+}
+
+const LAVMOp* LAVMTarget::GetLAVMOpImpl(const std::string& name) const {
+ const LAVMOpPtrList list = FilterExpansions(name);
+ if (list.size() == 1) {
+ return list.front();
+ }
+ if (!list.empty()) {
+ std::cerr << "Multiple map entries for op '" << name
+ << "' not supported yet:\n";
+ }
+ return nullptr;
+}
+
+// Returns a list of types that can be loaded directly from the passed-in memory
+// using a load instruction. NOTE: the types are inferred, therefore the return
+// type is a list of LAVMType instances rather than a list of pointers to
+// LAVMType.
+//
+LAVMTypeList LAVMTarget::GetLoadTypesImpl(const LAVMTargetMemory& mem) const {
+ LAVMTypeList list;
+ const LAVMOp* load = GetLAVMLoad();
+ if (load != nullptr) {
+ const LAVMInstantiationList instantiations = InferValidTypes(load);
+ // NOTE: LAVMLoadOp is not a subclass of LAVMOp
+ const int32_t address_operand = LAVMLoadOp::GetAddressOpIndex();
+ for (const LAVMInstantiation& instantiation : instantiations) {
+ const LAVMType type = instantiation.GetTopLevelType();
+ const LAVMType& domain_type = type.GetDomainType();
+ const LAVMType& range_type = type.GetRangeType();
+ const LAVMType& address_type = domain_type.GetType(address_operand);
+ const LAVMType& value_type = range_type;
+ assert(address_type.IsName());
+ if (address_type.IsPointerTo(mem)) {
+ // std::cerr << "Can load " << value_type.GetName() << " from " <<
+ // mem.GetName() << "\n";
+ list.push_back(value_type);
+ }
+ }
+ }
+ return list;
+}
+
+// Returns a list of types that can be stored directly to the passed-in memory
+// using a store instruction. NOTE: the types are inferred, therefore the return
+// type is a list of LAVMType instances rather than a list of pointers to
+// LAVMType.
+//
+LAVMTypeList LAVMTarget::GetStoreTypesImpl(const LAVMTargetMemory& mem) const {
+ LAVMTypeList list;
+ const LAVMOp* store = GetLAVMStore();
+ if (store != nullptr) {
+ const LAVMInstantiationList instantiations = InferValidTypes(store);
+ // NOTE: LAVMStoreOp is not a subclass of LAVMOp
+ const int32_t address_operand = LAVMStoreOp::GetAddressOpIndex();
+ const int32_t value_operand = LAVMStoreOp::GetValueOpIndex();
+ for (const LAVMInstantiation& instantiation : instantiations) {
+ const LAVMType type = instantiation.GetTopLevelType();
+ const LAVMType& domain_type = type.GetDomainType();
+ const LAVMType& address_type = domain_type.GetType(address_operand);
+ const LAVMType& value_type = domain_type.GetType(value_operand);
+ assert(address_type.IsName());
+ assert(value_type.IsName());
+ if (address_type.IsPointerTo(mem)) {
+ // std::cerr << "Can store " << value_type.GetName() << " to " <<
+ // mem.GetName() << "\n";
+ list.push_back(value_type);
+ }
+ }
+ }
+ return list;
+}
+
+bool LAVMTarget::SupportsDmaImpl(const LAVMTargetMemory& to,
+ const LAVMTargetMemory& from) const {
+ bool Supported = false;
+ const LAVMOp* dma = GetLAVMDma();
+ if (dma != nullptr) {
+ const LAVMInstantiationList instantiations = InferValidTypes(dma);
+ // NOTE: LAVMDmaOp is not a subclass of LAVMOp
+ const int32_t to_operand = LAVMDmaOp::GetToOpIndex();
+ const int32_t from_operand = LAVMDmaOp::GetFromOpIndex();
+ for (const LAVMInstantiation& instantiation : instantiations) {
+ const LAVMType type = instantiation.GetTopLevelType();
+ const LAVMType& domain_type = type.GetDomainType();
+ const LAVMType& to_type = domain_type.GetType(to_operand);
+ const LAVMType& from_type = domain_type.GetType(from_operand);
+ assert(to_type.IsName());
+ assert(from_type.IsName());
+ // std::cerr << to_type.GetName() << "(" << to.GetName() << ") && " <<
+ // from_type.GetName() << "(" << from.GetName() << ")\n";
+ if (to_type.IsPointerTo(to) && from_type.IsPointerTo(from)) {
+ assert(!Supported); // why require unique?
+ Supported = true;
+ }
+ }
+ }
+ return Supported;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Type inference
+////////////////////////////////////////////////////////////////////////////////
+
+// Check if all uses consume the same type according to the concrete
+// instantiation. For early mismatch detection, want_type can be passed in;
+// otherwise it defaults to nullptr, which means that any type common to all
+// uses is acceptable. Returns nullptr if there is a mismatch between types
+// already assigned to the uses, or between a type already assigned to a use
+// and the wanted type want_type, if passed in. Special rules apply if
+// want_type or any use type is AnyType.
+//
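+// For example (illustrative types): with want_type == nullptr, if one use's
+// argument slot has already been assigned f32 and another's AnyType, f32 is
+// returned; if one slot has f32 and another i32, the mismatch yields nullptr.
+//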
+/*static*/ const LAVMType* LAVMTarget::FindCommonUseType(
+ const LAVMOpArgList& uses, const LAVMInstantiation& instantiation,
+ const LAVMType* want_type) {
+ const LAVMType* type = want_type;
+ for (const LAVMOpArg& use : uses) {
+ const LAVMOp* use_op = use.first;
+ const int32_t use_op_arg_number = use.second;
+ if (use_op == nullptr) {
+ // Can use any type supported by the target.
+ // FIXME!!! check that the op is the root of the expansion tree.
+ // std::cerr << " no uses, any type supported by the target is
+ // ok, including " << type->ToString() << "\n";
+ } else {
+ // if non-nullptr, this is a sub-instantiation.
+ // const LAVMInstantiation* sub_instantiation =
+ // instantiation.GetSubInstantiation(use_op); if (sub_instantiation !=
+ // nullptr) {
+ // LAVMType sub_instantiation_type =
+ // sub_instantiation->GetTopLevelType(); // avoid mixing up the warning
+ // in this call and the output below.
+ // // std::cerr << " * FindCommonUseType(): use_op=" <<
+ // use_op->ToString()
+ // // << ", sub_instantiation type=" <<
+ // sub_instantiation_type.ToString()
+ // // << ", sub_instantiation=" <<
+ // sub_instantiation->ToString() << "\n";
+ // }
+ const LAVMType* assigned_type = instantiation.GetType(use_op);
+ assert(assigned_type->IsFunctionType());
+ // FIXME!!! awkward:
+ const LAVMType* assigned_use_type = &assigned_type->GetDomainType();
+ if (assigned_use_type->IsListType()) {
+ assigned_use_type = &assigned_use_type->GetType(use_op_arg_number);
+ } else {
+ assert(assigned_use_type->IsName());
+ assert(use_op_arg_number == 0);
+ }
+ // std::cerr << " compare types at uses\n"
+ // << " arg " << use_op_arg_number << " of '" <<
+ // use_op->ToString() << " : "
+ // << assigned_type->ToString() << "' has been
+ // assigned type " <<
+ // assigned_use_type->ToString() << "\n"
+ // << " name %" << op->GetName() << " so far has type
+ // "
+ // << (type == nullptr ? "NONE" :
+ // type->ToString()) << "\n";
+ if (type == nullptr || type->IsAny()) {
+ type = assigned_use_type;
+ } else if (!assigned_use_type->IsAny() && *assigned_use_type != *type) {
+ // any type!!!
+ return nullptr;
+ }
+ }
+ }
+ return type;
+}
+
+// First parameter 'list' is the *linearized* expansion in the dfs order.
+// The last parameter 'instantiation' is passed by value on purpose.
+template <typename F>
+void LAVMTarget::InstantiateList(const LAVMOpPtrList& list,
+ const LAVMOpUses& op_uses, F&& at_end,
+ LAVMInstantiation instantiation) const {
+ // std::cerr << "Entered InstantiateList()\n";
+ if (mt(list)) {
+ // Done!
+ // std::cerr << " Visited all nodes in the expansion, calling
+ // at_end()\n";
+ at_end(instantiation);
+ } else {
+ const LAVMOp* op = car(list);
+ auto it = op_uses.find(op);
+ assert(it != op_uses.end());
+ const LAVMOpArgList& uses = it->second;
+ if (op->IsName()) {
+ // std::cerr << " NAME node %" << op->GetName() << "\n";
+ const LAVMType* assigned_type = instantiation.GetType(op);
+ const LAVMType* type =
+ FindCommonUseType(uses, instantiation, assigned_type);
+ if (type == nullptr) {
+ // if (assigned_type == nullptr) {
+ // std::cerr << "no match (no common use type and no previously
+ // assigned type)\n";
+ // } else {
+ // std::cerr << "no match (no common use type or common use type
+ // mismatches already assigned type " << assigned_type->ToString() <<
+ // ")\n";
+ // }
+ } else if (assigned_type == nullptr) {
+ // std::cerr << "match (no previously assigned type, using common use
+ // type " << type->ToString() << ")\n";
+ instantiation.SetType(op, type);
+ InstantiateList(cdr(list), op_uses, at_end, instantiation);
+ } else {
+ // std::cerr << "match (already assigned type " <<
+ // assigned_type->ToString() << " matches common use type " <<
+ // type->ToString() << ")\n";
+ assert(*type == *assigned_type); // just a reminder, already guaranteed
+ // by FindCommonUseType()
+ assert(!type->IsAny()); // what is this case?
+ InstantiateList(cdr(list), op_uses, at_end, instantiation);
+ }
+ } else if (op->IsOp()) {
+ const LAVMTypePtrList target_op_types = GetTypesSupportedByTargetOp(op);
+ if (target_op_types.empty()) {
+ // This must be an lavm_op in the expansion sequence.
+ // This works, but is nuts. Allow InferValidTypes() to take
+ // instantiations seeded with the types known so far. Also FIXME!!!
+ // allow variable names to be different in the sub-expansions, just
+ // match them to what's used in the current expansion when computing the
+ // union of the instantiations.
+ // std::cerr << " < Instantiating nested expansion(s) of "
+ //           << op->ToString() << "\n";
+ // expansion trees only, no dags:
+ assert(instantiation.GetType(op) == nullptr);
+ assert(instantiation.GetSubInstantiation(op) == nullptr);
+ // Infer all legal sub-instantiations...
+ const LAVMOp* lavm_op = GetLAVMOp(op->GetOpName()); // FIXME!!! awkward
+ // FIXME!!! mess:
+ LAVMInstantiationList sub_instantiations = InferValidTypes(lavm_op);
+ if (sub_instantiations.empty()) {
+ // FIXME!!! detect this earlier and replace with unreachable()
+ std::cerr << "WARNING: expansion uses undefined operation/expansion "
+ << op->ToString() << "\n";
+ } else {
+ // std::cerr << " Found " << sub_instantiations.size() << "
+ // sub-instantiations of " << op->ToString() << ":\n"; for (const
+ // LAVMInstantiation& sub_instantiation : sub_instantiations) {
+ // LAVMType sub_instantiation_type =
+ // sub_instantiation.GetTopLevelType(); // avoid mixing up the
+ // warning in this call and the output below. std::cerr << " As " <<
+ // sub_instantiation.GetExpansionOp()->ToString() << ": "
+ // << sub_instantiation_type.ToString() << " " <<
+ // sub_instantiation.ToString() << "\n";
+ // }
+ // ... and for each one...
+ for (const LAVMInstantiation& sub_instantiation :
+ sub_instantiations) {
+ // ... match its types with the types picked so far in the current
+ // instantiation.
+ const LAVMType sub_instantiation_type =
+ sub_instantiation.GetTopLevelType();
+ const LAVMType* type = FindCommonUseType(
+ uses, instantiation, &sub_instantiation_type.GetRangeType());
+ if (type != nullptr) {
+ // std::cerr << " Feasible sub-instantiation as " <<
+ // sub_instantiation.GetExpansionOp()->ToString() << ": "
+ // << type->ToString() << " " <<
+ // sub_instantiation.ToString() << "\n";
+ // Override the type assigned in the previous iteration:
+ instantiation.SetType(op, &sub_instantiation_type);
+ // Override the sub-instantiation assigned in the previous
+ // iteration:
+ instantiation.SetSubInstantiation(op, sub_instantiation);
+ InstantiateList(cdr(list), op_uses, at_end, instantiation);
+ }
+ }
+ }
+ // std::cerr << " > Done with nested expansions of " <<
+ // op->ToString() << "\n";
+ } else {
+ // std::cerr << " OP node " << op->GetOpName() << "\n";
+ // Expansion trees only, no dags:
+ assert(instantiation.GetType(op) == nullptr);
+ // For all types supported by the target for this target op...
+ for (const LAVMType* target_type : target_op_types) {
+ // std::cerr << " Checking if target type " <<
+ // target_type->ToString()
+ // << " matches types already assigned to uses of " <<
+ // op->ToString() << "\n";
+ assert(target_type->IsFunctionType());
+ // Check if all uses consume this target type. Otherwise this specific
+ // instantiation has failed.
+ if (FindCommonUseType(uses, instantiation,
+ &target_type->GetRangeType())) {
+ // std::cerr << "match\n";
+ // Override the type assigned in the previous iteration:
+ instantiation.SetType(op, target_type);
+ // std::cerr << " " << target_type->ToString() << "\n";
+ InstantiateList(cdr(list), op_uses, at_end, instantiation);
+ }
+ }
+ // If we implement instantiation.RemoveType(op) here we should be able
+ // to pass the instantiation pointer as a parameter to the recursive
+ // calls instead of intentionally passing it by value to create its
+ // copy.
+ }
+ } else {
+ assert(false);
+ }
+ }
+ // std::cerr << "Exiting InstantiateList()\n";
+}
+
+// Infer all possible valid instantiations for the passed-in op.
+// Mutually recursive with LAVMTarget::InstantiateList()
+//
+LAVMInstantiationList LAVMTarget::InferValidTypesImpl(
+ const LAVMOp* lavm_op) const {
+ LAVMInstantiationList instantiations;
+
+ const LAVMOpPtrList expansions = GetTargetExpansions(lavm_op);
+ assert(!expansions.empty()); // empty rhs in the target section entry??? just
+ // return empty result in this case?
+ for (const LAVMOp* expansion : expansions) {
+ // Linearize the nodes in the expansion. FIXME!!! cache?
+ const LAVMOpPtrList dfs_order = LAVMOp::DfsOrder(expansion);
+
+ // Build use map for the expansion. FIXME!!! cache?
+ const LAVMOpUses op_uses = LAVMOp::BuildOpUses(expansion);
+ // FIXME!!! move to OpUses::ToString()
+ // std::cerr << "ZZZ op_uses:\n";
+ // for (const auto& I : op_uses) {
+ // const LAVMOp* op = I.first;
+ // for (const OpArg& use : I.second) {
+ // const LAVMOp* use_op = use.first;
+ // // const int32_t use_op_argument_number = use.second;
+ // if (use_op == nullptr) {
+ // // std::cerr << op->ToString() << " has no uses. Must be root of
+ // expansion tree.\n"; assert(op == target_expansion);
+ // } else {
+ // // std::cerr << op->ToString() << " used by arg " <<
+ // use_op_argument_number << " of " << use_op->ToString() << "\n";
+ // }
+ // }
+ // }
+
+ auto process_instantiation = [&](const LAVMInstantiation& instantiation) {
+ instantiations.push_back(instantiation);
+ instantiations.back().SetLAVMOp(lavm_op); // FIXME!!! hacky
+ instantiations.back().SetExpansionOp(expansion); // FIXME!!! hacky
+ };
+ InstantiateList(dfs_order, op_uses, process_instantiation);
+ }
+
+ return instantiations;
+}
+
+} // namespace lavm
+} // namespace mlir
diff --git a/LAVM/LAVMTarget.h b/LAVM/LAVMTarget.h
new file mode 100644
index 0000000..c2d78c2
--- /dev/null
+++ b/LAVM/LAVMTarget.h
@@ -0,0 +1,845 @@
+#ifndef EXPERIMENTAL_LAVM_LAVMTARGET_H_
+#define EXPERIMENTAL_LAVM_LAVMTARGET_H_
+
+#include <iostream>
+#include <map>
+#include <string>
+#include <vector>
+
+#include "Lexer.h"
+
+namespace mlir {
+class MLIRContext;
+class Operation;
+} // namespace mlir
+
+namespace mlir {
+namespace lavm {
+
+template <typename T>
+inline bool mt(const T& list) {
+ return list.empty();
+}
+template <typename T>
+inline typename T::value_type car(const T& list) {
+ assert(!mt(list));
+ return list.front();
+}
+template <typename T>
+inline T cdr(const T& list) {
+ assert(!mt(list));
+ return T(std::next(list.begin()), list.end());
+}
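+// For example (illustrative): given std::vector<int> v{1, 2, 3}, car(v) == 1,
+// cdr(v) == {2, 3}, and mt(cdr(cdr(cdr(v)))) is true.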
+// template <typename T>
+// inline T cons(typename T::value_type elt, const T& list = T{}) {
+// T new_list{elt};
+// new_list.insert(new_list.end(), list.begin(), list.end());
+// return new_list;
+// }
+
+////////////////////////////////////////////////////////////////////////////////
+// LAVMAnnotation class
+////////////////////////////////////////////////////////////////////////////////
+
+class LAVMAnnotation : public std::string {
+ public:
+ // Create:
+ static LAVMAnnotation Create(const std::string& annotation) {
+ return LAVMAnnotation(annotation);
+ }
+
+ // FIXME!! revisit this pile:
+ LAVMAnnotation() : std::string() {}
+ explicit LAVMAnnotation(const std::string& s) : std::string(s) {}
+ explicit LAVMAnnotation(const std::string&& s) : std::string(s) {}
+ LAVMAnnotation& operator=(const std::string& s) {
+ this->std::string::operator=(s);
+ return *this;
+ }
+ LAVMAnnotation& operator=(const std::string&& s) {
+ this->std::string::operator=(s);
+ return *this;
+ }
+
+ const std::string& ToString() const { return *this; }
+ int32_t ToInt32() const { return std::stoi(*this); }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+// LAVMTargetMemory class
+////////////////////////////////////////////////////////////////////////////////
+
+struct LAVMTargetMemory;
+using LAVMTargetMemoryList = std::vector<LAVMTargetMemory>;
+using LAVMTargetMemoryPtrList = std::vector<const LAVMTargetMemory*>;
+struct LAVMTargetMemory {
+ private:
+ using AttributeName = std::string;
+ using AttributeValue = std::string;
+ using AttributeMap = std::map<AttributeName, AttributeValue>;
+
+ public:
+ // Create:
+ static LAVMTargetMemory Create(const std::string& name, bool is_cache) {
+ LAVMTargetMemory mem;
+ mem.name = name;
+ mem.is_cache = is_cache;
+ return mem;
+ }
+
+ // Access:
+ bool IsCache() const { return is_cache; }
+ const std::string& GetName() const { return name; }
+ const std::string& GetSize() const {
+ return GetAttribute(kSizeName);
+ } // FIXME!!! change return value to int64_t
+
+ // Mutate:
+ void AddAttribute(const AttributeName& attribute_name,
+ const AttributeValue& attribute_value) {
+ attributes[attribute_name] = attribute_value;
+ }
+ static void AddTransfer(LAVMTargetMemory* from, LAVMTargetMemory* to) {
+ from->transfer_to.push_back(to);
+ to->transfer_from.push_back(from);
+ }
+
+ // Serialize:
+ std::string ToStringAttributes() const {
+ std::string s;
+ if (IsCache()) {
+ s += "$";
+ }
+ s += GetName();
+ s += " : ";
+ bool First = true;
+ for (int32_t i = 0; i < GetNumAttributes(); i++) {
+ auto it = GetAttribute(i);
+ if (!First) {
+ s += ", ";
+ }
+ s += it->first + " = " + it->second;
+ First = false;
+ }
+ return s;
+ }
+ std::string ToStringTransfers() const {
+ // FIXME!!! we cannot serialize this information in a single ToString()
+ // call, which probably means poor class design. It will likely become more
+ // complicated when instructions for transfers are supported.
+ std::string s;
+ bool First = true;
+ for (const LAVMTargetMemory* mem : transfer_to) {
+ if (!First) {
+ s += " ";
+ }
+ if (IsCache()) {
+ s += "$";
+ }
+ s += GetName();
+ s += " -> ";
+ if (mem->IsCache()) {
+ s += "$";
+ }
+ s += mem->GetName();
+ s += ";";
+ First = false;
+ }
+ if (!transfer_from.empty()) {
+ s += " # ";
+ bool First = true;
+ for (const LAVMTargetMemory* mem : transfer_from) {
+ if (!First) {
+ s += " ";
+ }
+ if (mem->IsCache()) {
+ s += "$";
+ }
+ s += mem->GetName();
+ s += " -> ";
+ if (IsCache()) {
+ s += "$";
+ }
+ s += GetName();
+ s += ";";
+ First = false;
+ }
+ }
+ return s;
+ }
+
+ private:
+ // Helpers.
+ int32_t GetNumAttributes() const { return attributes.size(); }
+ AttributeMap::const_iterator GetAttribute(int32_t n) const {
+ assert(n < GetNumAttributes());
+ auto it = attributes.begin();
+ std::advance(it, n);
+ return it;
+ }
+ const AttributeValue& GetAttribute(
+ const AttributeName& attribute_name) const {
+ auto it = attributes.find(attribute_name);
+ assert(it != attributes.end());
+ return it->second;
+ }
+
+ private:
+ static constexpr char kSizeName[] = "size";
+ bool is_cache = false;
+ std::string name = "";
+ std::map<AttributeName, AttributeValue> attributes = {};
+ LAVMTargetMemoryPtrList transfer_to = {};
+ LAVMTargetMemoryPtrList transfer_from = {};
+};
+
+////////////////////////////////////////////////////////////////////////////////
+// LAVMOp class
+////////////////////////////////////////////////////////////////////////////////
+
+struct LAVMOp;
+using LAVMOpList = std::vector<LAVMOp>;
+using LAVMOpPtrList = std::vector<const LAVMOp*>;
+using LAVMOpArg = std::pair<const LAVMOp*, int32_t>; // A use is an op and
+ // argument number
+using LAVMOpArgList = std::vector<LAVMOpArg>; // List of args
+using LAVMOpUses = std::map<const LAVMOp*, LAVMOpArgList>; // mapping from an
+ // operation to its
+ // uses
+struct LAVMOp {
+ // Create:
+ static LAVMOp Create(const std::string& name) {
+ LAVMOp op;
+ op.name = name;
+ return op;
+ }
+ static LAVMOp CreateAsName(const std::string& name) {
+ LAVMOp op;
+ op.name = name;
+ op.is_name = true;
+ return op;
+ }
+ static LAVMOpUses BuildOpUses(const LAVMOp* root) {
+ LAVMOpUses uses;
+ auto add_use = [&uses](const LAVMOp* op, const LAVMOp* parent,
+ int32_t parent_arg_no) {
+ uses[op].push_back(LAVMOpArg(parent, parent_arg_no));
+ };
+ root->dfs_with_parent(add_use);
+ return uses;
+ }
+ static LAVMOpPtrList DfsOrder(const LAVMOp* op) {
+ LAVMOpPtrList list;
+ op->dfs([&list](const LAVMOp* op) { list.push_back(op); });
+ return list;
+ }
+
+ // Access:
+ bool IsName() const { return is_name; }
+ bool IsOp() const { return !IsName(); }
+ // All operands are name ops:
+ bool IsFlatOp() const {
+ if (!IsOp()) {
+ return false;
+ }
+ for (int32_t i = 0; i < GetNumOperands(); i++) {
+ if (!Operand(i).IsName()) {
+ return false;
+ }
+ }
+ return true;
+ }
+ const std::string& GetName() const {
+ assert(IsName());
+ return name;
+ }
+ const std::string& GetOpName() const {
+ assert(IsOp());
+ return name;
+ }
+ int32_t GetNumOperands() const {
+ assert(IsOp());
+ return operands.size();
+ }
+ const LAVMOp& Operand(int32_t n) const {
+ assert(IsOp());
+ assert(n < GetNumOperands());
+ return operands.at(n);
+ }
+ // DFS visitors
+ template <typename F>
+ void dfs(F&& f) const {
+ if (IsName()) {
+ f(this);
+ } else if (IsOp()) {
+ f(this);
+ for (int32_t i = 0; i < GetNumOperands(); i++) {
+ Operand(i).dfs(f);
+ }
+ } else {
+ f(nullptr);
+ }
+ }
+ template <typename F>
+ void dfs_with_parent(F&& f, const LAVMOp* parent = nullptr,
+ int32_t parent_arg_no = 0) const {
+ if (IsName()) {
+ f(this, parent, parent_arg_no);
+ } else if (IsOp()) {
+ f(this, parent, parent_arg_no);
+ for (int32_t i = 0; i < GetNumOperands(); i++) {
+ Operand(i).dfs_with_parent(f, this, i);
+ }
+ } else {
+ f(nullptr, parent, parent_arg_no);
+ }
+ }
+
+ // Mutate:
+ void AddOperand(const LAVMOp& operand) {
+ assert(IsOp());
+ operands.push_back(operand);
+ }
+
+ // Compare:
+ bool operator==(const LAVMOp& other) const {
+ return is_name == other.is_name && name == other.name &&
+ operands == other.operands;
+ }
+ // Needed for using as a key in std::map.
+ bool operator<(const LAVMOp& other) const {
+ // FIXME!!! make more efficient
+ return ToString() < other.ToString();
+ }
+
+ // Serialize:
+ template <typename Prefix, typename Suffix>
+ std::string ToString(Prefix&& prefix, Suffix&& suffix) const {
+ std::string s = prefix(this);
+ if (IsName()) {
+ s += "%";
+ s += GetName();
+ } else if (IsOp()) {
+ s += GetOpName();
+ s += "(";
+ bool First = true;
+ for (int32_t i = 0; i < GetNumOperands(); i++) {
+ if (!First) {
+ s += ", ";
+ }
+ s += Operand(i).ToString(prefix, suffix);
+ First = false;
+ }
+ s += ")";
+ } else {
+ s += "<<unknown LAVMOp kind>>";
+ }
+ return s + suffix(this);
+ }
+
+ // Helpers:
+ // FIXME!!! this is similar to some helpers in the LAVMTarget class, perhaps,
+ // we should consider creating a separate LAVMUtils class/namespace.
+
+ // FIXME!!! this method does not belong here, move it to LAVMOpPtrList.
+ // NOTE: the comparison is performed by value, not by comparing pointers.
+ static bool Contains(const LAVMOpPtrList& list, const LAVMOp& op) {
+ auto equal = [&op](const LAVMOp* op_in_list) { return op == *op_in_list; };
+ return std::find_if(list.begin(), list.end(), equal) != list.end();
+ }
+
+ // Return the list of parameter operands (those for which IsName() is true).
+ // Recursively visits all operands when necessary. NOTE: the returned list
+ // does not contain duplicates.
+ LAVMOpPtrList FilterArguments() const {
+ LAVMOpPtrList list;
+ auto filter = [&list] (const LAVMOp* op) {
+ if (op->IsName()) {
+ // Check if we have a duplicate.
+ if (!Contains(list, *op)) {
+ list.push_back(op);
+ }
+ }
+ };
+ dfs(filter);
+ return list;
+ }
+
+ static constexpr auto bupkis = [](const LAVMOp*) {
+ return std::string();
+ }; // FIXME!!! find better place and rename
+ std::string ToString() const { return ToString(bupkis, bupkis); }
+
+ private:
+ bool is_name = false;
+ std::string name = "";
+ LAVMOpList operands = {}; // can be empty
+};
+
+////////////////////////////////////////////////////////////////////////////////
+// LAVMType class
+////////////////////////////////////////////////////////////////////////////////
+
+struct LAVMType;
+using LAVMTypeList = std::vector<LAVMType>;
+using LAVMAnnotatedTypeList = std::vector<std::pair<LAVMType, LAVMAnnotation>>;
+using LAVMTypePtrList = std::vector<const LAVMType*>;
+struct LAVMType {
+ // Create:
+ // FIXME!!! poorly defined meaning
+ static LAVMType Create(const LAVMTypeList& list) {
+ if (list.empty()) {
+ return CreateVoidType();
+ } else if (list.size() == 1) {
+ return list.front();
+ } else {
+ return CreateListType(list);
+ }
+ }
+ static LAVMType CreateNameType(const std::string& name) {
+ assert(name != kVoid && name != kAny);
+ LAVMType type;
+ type.name = name;
+ type.is_name = true;
+ return type;
+ }
+ static LAVMType CreateVoidType() {
+ LAVMType type;
+ type.name = kVoid;
+ type.is_name = true;
+ return type;
+ }
+ static LAVMType CreateAnyType() {
+ LAVMType type;
+ type.name = kAny;
+ type.is_name = true;
+ return type;
+ }
+ static LAVMType CreateListType(const LAVMTypeList& list) {
+ LAVMType type;
+ type.list = list;
+ return type;
+ }
+ static LAVMType CreateFunctionType(const LAVMType& domain,
+ const LAVMType& range) {
+ LAVMType type;
+ type.domain.push_back(domain);
+ type.range.push_back(range);
+ return type;
+ }
+
+ // Access:
+ bool IsName() const { return is_name; } // IsLeaf(), really
+ bool IsVoid() const { return IsName() && GetName() == kVoid; }
+ bool IsAny() const { return IsName() && GetName() == kAny; }
+ bool IsListType() const {
+ assert(IsSound());
+ return !IsName() && !list.empty();
+ }
+ bool IsFunctionType() const {
+ assert(IsSound());
+ return !IsName() && !domain.empty();
+ }
+ // FIXME!!! this is a temporary hack until .md syntax is determined.
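+ // For example, a leaf type named "memref<vmem>" (where "vmem" is a
+ // hypothetical memory name) is treated as a pointer into that memory.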
+ bool IsPointerTo(const LAVMTargetMemory& mem) const {
+ return GetName() == "memref<" + mem.GetName() + ">";
+ }
+ bool IsSound() const {
+ if (IsName()) {
+ return list.empty() && domain.empty() && range.empty();
+ } else {
+ return list.empty() != domain.empty() &&
+ domain.empty() == range.empty() &&
+ (domain.empty() || (domain.size() + range.size() == 2));
+ }
+ }
+ const std::string& GetName() const {
+ assert(IsName());
+ return name;
+ }
+ // if list
+ int32_t GetNumTypes() const {
+ assert(IsListType());
+ return list.size();
+ }
+ const LAVMType& GetType(int32_t n) const {
+ assert(IsSound());
+ if (IsListType()) {
+ assert(n < GetNumTypes());
+ return list.at(n);
+ }
+ assert(IsName() && n == 0);
+ return *this;
+ }
+ // if function
+ const LAVMType& GetDomainType() const {
+ assert(IsFunctionType());
+ return domain.front();
+ }
+ const LAVMType& GetRangeType() const {
+ assert(IsFunctionType());
+ return range.front();
+ }
+
+ // Compare:
+ bool operator==(const LAVMType& other) const {
+ assert(IsSound() && other.IsSound());
+ if (IsName() && other.IsName()) {
+ return GetName() == other.GetName();
+ } else if (IsListType() && other.IsListType()) {
+ return list == other.list;
+ } else if (IsFunctionType() && other.IsFunctionType()) {
+ return domain == other.domain && range == other.range;
+ }
+ return false;
+ }
+ bool operator!=(const LAVMType& other) const {
+ return !this->operator==(other);
+ }
+ bool operator<(const LAVMType& other) const {
+ // FIXME!!! make more efficient
+ return ToString() < other.ToString();
+ }
+
+ // Serialize:
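+ // Illustrative outputs (type names are hypothetical): a name type prints as
+ // "f32", a list type as "(f32, i32)", a function type as
+ // "(f32, i32) -> f32", and the void type as "()".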
+ std::string ToString() const {
+ std::string s;
+ if (IsVoid()) { // FIXME!!! should be built into GetName()
+ s += "()";
+ } else if (IsName()) { // note: includes IsAny()
+ s += GetName();
+ } else if (IsListType()) { // FIXME!!! create LAVMTypeList::ToString
+ s += "(";
+ bool First = true;
+ for (const LAVMType& t : list) {
+ if (!First) {
+ s += ", ";
+ }
+ s += t.ToString();
+ First = false;
+ }
+ s += ")";
+ } else if (IsFunctionType()) {
+ auto to_string = [](const LAVMType& type) {
+ return type.IsFunctionType() ? "[" + type.ToString() + "]"
+ : type.ToString();
+ };
+ s += to_string(GetDomainType()) + " -> " + to_string(GetRangeType());
+ } else {
+ s += "<<unknown LAVMType kind>>";
+ }
+ return s;
+ }
+
+ private:
+ static constexpr char kVoid[] = "";
+ static constexpr char kAny[] = "AnyType";
+ bool is_name = false; // similar to LAVMOp struct, here it means is_leaf
+ // FIXME!!! both should be renamed
+ std::string name = "";
+ LAVMTypeList list = {};
+ LAVMTypeList domain = {};
+ LAVMTypeList range = {}; // non-empty for function types; want LAVMType here,
+ // not LAVMTypeList
+};
+
+////////////////////////////////////////////////////////////////////////////////
+// LAVMInstantiation class
+////////////////////////////////////////////////////////////////////////////////
+
+class LAVMTarget;  // Needed for LAVMInstantiation::GetCost(). Possibly a sign
+                   // of a design problem.
+
+// Map operands to the types assigned to these operands. Can represent a
+// specific instantiation of an expansion tree.
+struct LAVMInstantiation;
+using LAVMInstantiationList = std::vector<LAVMInstantiation>;
+struct LAVMInstantiation {
+ // Access:
+ // FIXME!!! move definitions to a cpp file
+ bool empty() const { return ops_to_types.empty() && names_to_types.empty(); }
+ const LAVMType* GetType(const LAVMOp* op) const {
+ if (op->IsName()) {
+ auto it = names_to_types.find(op->GetName());
+ return it == names_to_types.end() ? nullptr : &it->second;
+ } else if (op->IsOp()) {
+ auto it = ops_to_types.find(op);
+ return it == ops_to_types.end() ? nullptr : &it->second;
+ } else {
+ assert(false);
+ return nullptr;
+ }
+ }
+ const LAVMOp* GetLAVMOp() const { return lavm_op; }
+ const LAVMOp* GetExpansionOp() const { return expansion_op; }
+ // FIXME!!! explain the meaning.
+ const LAVMInstantiation* GetSubInstantiation(const LAVMOp* op) const {
+ auto it = ops_to_sub_instantiations.find(op);
+ if (it == ops_to_sub_instantiations.end()) {
+ return nullptr;
+ }
+ const LAVMInstantiationList& list = it->second;
+ assert(list.size() == 1);
+ return &list.front();
+ }
+ // Get the top-level type of this instantiation: (arg types) -> result_type
+ // NOTE: this function returns a new instance of LAVMType.
+ // FIXME!!! rename
+ LAVMType GetTopLevelType() const;
+ int32_t GetCost(const LAVMTarget* target) const;
+ // DFS visitor
+ // FIXME!!! consider allowing f to have an arbitrary return type
+ template <typename F>
+ void dfs(F&& f) const {
+ auto g = [this, &f](const LAVMOp* op) {
+ f(op, this);
+ const LAVMInstantiation* sub_instantiation = GetSubInstantiation(op);
+ if (sub_instantiation != nullptr) {
+ sub_instantiation->dfs(f);
+ }
+ };
+ expansion_op->dfs(g);
+ }
+
+ // Mutate:
+ void SetLAVMOp(const LAVMOp* op) {
+ // assert(IsLAVMOp(op));
+ assert(op != nullptr && op->IsFlatOp());
+ lavm_op = op;
+ }
+ void SetExpansionOp(const LAVMOp* op) { expansion_op = op; }
+ void SetType(const LAVMOp* op, const LAVMType* type) {
+ if (op->IsName()) {
+ names_to_types[op->GetName()] = *type;
+ } else if (op->IsOp()) {
+ ops_to_types[op] = *type;
+ } else {
+ assert(false);
+ }
+ }
+ void SetSubInstantiation(const LAVMOp* op,
+ const LAVMInstantiation& sub_instantiation) {
+ // assert(op->IsLAVMOp());
+ ops_to_sub_instantiations[op] = LAVMInstantiationList{sub_instantiation};
+ }
+
+ // Serialize:
+ std::string ToString() const {
+ auto emit_trailing_type = [this](const LAVMOp* op) -> std::string {
+ const LAVMType* type = GetType(op);
+ if (type->IsFunctionType()) {
+ type = &type->GetRangeType();
+ }
+ const LAVMInstantiation* sub_instantiation = GetSubInstantiation(op);
+ if (sub_instantiation == nullptr) {
+ return " : " + type->ToString();
+ }
+ return " : " + type->ToString() + " " + sub_instantiation->ToString();
+ };
+ return "{ " + expansion_op->ToString(LAVMOp::bupkis, emit_trailing_type) +
+ " }";
+ }
+ // FIXME!!! this method does not belong here
+ static std::string ToDetailedString(
+ const LAVMInstantiationList& instantiations, const LAVMOp& lavm_op,
+ const LAVMTarget& target) {
+ std::string s = "Inferred " + std::to_string(instantiations.size()) +
+ " instantiations of " + lavm_op.ToString() + ":\n";
+ for (const LAVMInstantiation& instantiation : instantiations) {
+ assert(instantiation.GetLAVMOp() == &lavm_op);
+ const LAVMType instantiation_type = instantiation.GetTopLevelType();
+ s += " As " + instantiation.GetExpansionOp()->ToString() +
+ " with cost " + std::to_string(instantiation.GetCost(&target)) +
+ ": " + instantiation_type.ToString() + " " +
+ instantiation.ToString() + "\n";
+ }
+ return s;
+ }
+
+ private:
+ // LAVM op for which this instantiation is computed.
+ const LAVMOp* lavm_op = nullptr;
+ // Starting op for the expansion.
+ const LAVMOp* expansion_op = nullptr;
+ // Mapping from operations to types assigned to those operations.
+  // IMPORTANT FIXME!!! Distinct pointers are expected to be used for distinct
+  // appearances of an op in the expansion tree. This must always be kept in
+  // mind until the reliance on distinct pointer values is removed.
+ std::map<const LAVMOp*, LAVMType> ops_to_types;
+ // Mapping from input variables to types assigned to those input variables.
+ std::map<std::string, LAVMType> names_to_types;
+  // Mapping from lavm operations in the expansion tree to their
+  // sub-instantiations. One sub-instantiation per operation (list size must be
+  // 1). Same IMPORTANT FIXME as for ops_to_types.
+ std::map<const LAVMOp*, LAVMInstantiationList> ops_to_sub_instantiations;
+};
+
+////////////////////////////////////////////////////////////////////////////////
+// LAVMTarget class
+////////////////////////////////////////////////////////////////////////////////
+
+class LAVMTarget {
+ private:
+ using ExpansionMap = std::map<LAVMOp, LAVMOpList>;
+ using TargetOpTypeMap = std::map<LAVMOp, LAVMTypeList>;
+ using TargetOpTypeAnnotationMap =
+ std::map<std::pair<LAVMOp, LAVMType>, LAVMAnnotation>; // FIXME!!! hacky
+ using TargetMemoryMap =
+ std::map<std::string, LAVMTargetMemory>; // name -> mem
+
+ public:
+ // Interface:
+ //
+  // FIXME!!! Should all of the classes above be hidden and accessed only
+  // through the target class? Do we want 'pass-through' methods in the
+  // interface? On the one hand they do nothing; on the other hand, all
+  // information would be accessible via a single target object, and the
+  // instantiation itself (as well as the lavm op, type, etc.) could remain
+  // opaque to the user.
+
+ // Finding and checking specific LAVM ops.
+ const LAVMOp* GetTargetOp(const std::string& name) const;
+ const LAVMOp* GetLAVMOp(const std::string& name) const;
+ const LAVMOp* GetLAVMLoad() const { return GetLAVMOp(kLAVMLoadName); }
+ const LAVMOp* GetLAVMStore() const { return GetLAVMOp(kLAVMStoreName); }
+ const LAVMOp* GetLAVMDma() const { return GetLAVMOp(kLAVMDmaName); }
+
+ bool ParseFromFile(const char* target_description_filename,
+ bool use_lex_yacc);
+ LAVMInstantiationList InferValidTypes(const LAVMOp* lavm_op) const;
+ bool SupportsLoad(const LAVMTargetMemory& mem) const;
+ bool SupportsStore(const LAVMTargetMemory& mem) const;
+  // FIXME!!! factor out the code shared by these two methods.
+ LAVMTypeList GetLoadTypes(const LAVMTargetMemory& mem) const;
+ LAVMTypeList GetStoreTypes(const LAVMTargetMemory& mem) const;
+ bool SupportsDma(const LAVMTargetMemory& to,
+ const LAVMTargetMemory& from) const;
+ const LAVMAnnotation* GetTypeAnnotation(const LAVMOp& op,
+ const LAVMType& type) const {
+ // FIXME!!! this is fragile and confusing!!!
+ const LAVMOp* target_op = GetTargetOp(op.GetOpName());
+ assert(target_op != nullptr);
+ auto it =
+ target_op_type_annotations.find(std::make_pair(*target_op, type));
+ // std::cerr << "\nLooking for type annotation for ["
+ // << op.ToString() << "/" << target_op->ToString() << ", " <<
+ // type.ToString() << "]: "
+ // << (it == target_op_type_annotations.end() ? "NOT FOUND" : "found")
+ // << "\n";
+ return it == target_op_type_annotations.end() ? nullptr : &it->second;
+ }
+
+  // FIXME!!! Need IsTargetLoad() instead. How to implement it?
+ // bool IsLAVMLoad(const LAVMOp* op) const {
+ // if (op == GetLAVMLoad()) {
+ // return true;
+ // }
+ // assert(op->IsName() || op->GetOpName() != GetLAVMLoad()->GetOpName());
+ // return false;
+ // }
+
+ // Iterators over target properties.
+ // FIXME!!! these expose internal representation, consider hiding the details.
+ ExpansionMap::iterator ExpansionBegin() { return expansions.begin(); }
+ ExpansionMap::iterator ExpansionEnd() { return expansions.end(); }
+ ExpansionMap::const_iterator ExpansionBegin() const {
+ return expansions.begin();
+ }
+ ExpansionMap::const_iterator ExpansionEnd() const {
+ return expansions.end();
+ }
+ TargetMemoryMap::iterator MemoryBegin() { return memory.begin(); }
+ TargetMemoryMap::iterator MemoryEnd() { return memory.end(); }
+ TargetMemoryMap::const_iterator MemoryBegin() const {
+ return memory.begin();
+ }
+ TargetMemoryMap::const_iterator MemoryEnd() const { return memory.end(); }
+
+ // Serialize:
+ std::string ToString() const;
+ std::string ToDetailedString() const;
+
+ private:
+ // Helpers:
+ template <typename T>
+ LAVMOpPtrList Filter(const std::string& name, const T& map) const {
+ LAVMOpPtrList list;
+ for (const auto& pair : map) {
+ const LAVMOp& op = pair.first;
+ if (op.GetOpName() == name) {
+ list.push_back(&op);
+ }
+ }
+ return list;
+ }
+ LAVMOpPtrList FilterExpansions(const std::string& name) const {
+ return Filter(name, expansions);
+ }
+ LAVMOpPtrList FilterTargetOpTypes(const std::string& name) const {
+ return Filter(name, target_op_types);
+ }
+ LAVMOpPtrList GetTargetExpansions(const LAVMOp* lavm_op) const;
+ LAVMTypePtrList GetTypesSupportedByTargetOp(const LAVMOp* op) const;
+ const LAVMOp* GetTargetOpImpl(const std::string& name) const;
+ const LAVMOp* GetLAVMOpImpl(const std::string& name) const;
+ LAVMTypeList GetLoadTypesImpl(const LAVMTargetMemory& mem) const;
+ LAVMTypeList GetStoreTypesImpl(const LAVMTargetMemory& mem) const;
+ bool SupportsDmaImpl(const LAVMTargetMemory& to,
+ const LAVMTargetMemory& from) const;
+
+ // Type inference:
+ static const LAVMType* FindCommonUseType(
+ const LAVMOpArgList& uses, const LAVMInstantiation& instantiation,
+ const LAVMType* want_type = nullptr);
+ template <typename F>
+ void InstantiateList(const LAVMOpPtrList& list, const LAVMOpUses& op_uses,
+ F&& at_end, LAVMInstantiation instantiation = {}) const;
+ LAVMInstantiationList InferValidTypesImpl(const LAVMOp* lavm_op) const;
+
+ // Parsing (FIXME!!! Move these methods to LAVMTargetParser class):
+ Token ParseOp(Token token, LAVMOp* op) const;
+ Token ParseOpList(Token token, LAVMOpList* list) const;
+ Token ParseType(Token token, LAVMType* type) const;
+ Token ParseAnnotation(Token token, LAVMAnnotation* annotation) const;
+ Token ParseTypeList(Token token, LAVMTypeList* list) const;
+ Token ParseAnnotatedTypeList(Token token, LAVMAnnotatedTypeList* list) const;
+
+ Token ParseTargetSectionEntry(Token token);
+ Token ParseTargetSection(Token token);
+ Token ParseMapSectionEntry(Token token);
+ Token ParseMapSection(Token token);
+ Token ParseMemorySectionEntry(Token token);
+ Token ParseMemorySection(Token token);
+ Token SkipToEndOfSection(Token token) const;
+ Token SkipSection(Token token) const;
+ bool Validate();
+ bool Parse(const char* target_description);
+ bool ParseFromFileImpl(const char* target_description_filename,
+ bool use_lex_yacc);
+
+ private:
+ static constexpr char kLAVMLoadName[] = "lavm.load";
+ static constexpr char kLAVMStoreName[] = "lavm.store";
+ static constexpr char kLAVMDmaName[] = "lavm.dma";
+
+ // Available expansions of LAVM ops, result of parsing the 'map' section:
+ // lavm_op(%a, ...) -> {target_op(target_op(%a, ...), ...), ...}
+ ExpansionMap expansions;
+ // Available instantiations of target ops, result of parsing the 'target'
+ // section: target_op : type, type, ...;
+  // FIXME!!! instead of the value being a type list, an instantiation list
+  // could be used, which may erase the difference between the instantiations
+  // of LAVM ops and those of target ops.
+ TargetOpTypeMap target_op_types;
+ // Optional annotations attached to target op types:
+ // target_op : type : annotation, ...;
+ // FIXME!!! consider making annotations (and their parsing/printing) same as
+ // memory attributes.
+ TargetOpTypeAnnotationMap target_op_type_annotations;
+ // Available memory levels and their properties
+ TargetMemoryMap memory;
+};
+
+} // namespace lavm
+} // namespace mlir
+
+#endif // EXPERIMENTAL_LAVM_LAVMTARGET_H_
diff --git a/LAVM/LAVMTargetParser.cpp b/LAVM/LAVMTargetParser.cpp
new file mode 100644
index 0000000..38c6783
--- /dev/null
+++ b/LAVM/LAVMTargetParser.cpp
@@ -0,0 +1,777 @@
+#include <cstddef>
+#include <fstream>
+
+#include "LAVMTarget.h"
+#include "Lexer.h"
+#include "md_grammar.h"
+
+////////////////////////////////////////////////////////////////////////////////
+// New parsing format
+////////////////////////////////////////////////////////////////////////////////
+
+namespace mlir {
+namespace lavm {
+namespace {
+
+// Scratch state for the 'def' clause currently being parsed; filled in by the
+// def_clause_* callbacks that the bison rules invoke.
+struct {
+ void init() {
+ name = "";
+ mnemonics = "";
+ ins.clear();
+ outs.clear();
+ PatternNode::Destroy(pattern);
+ pattern = nullptr;
+ llvm = "";
+ }
+
+ std::string ToString() const {
+ std::string s;
+ s += "name = '" + name + "'";
+ s += ", mnemonics = '" + mnemonics + "'";
+ s += ", ins = (";
+ bool First = true;
+ for (const auto& in : ins) {
+ if (!First) {
+ s += ", ";
+ }
+ s += in.first + ":" + in.second;
+ First = false;
+ }
+ s += ")";
+ s += ", outs = (";
+ First = true;
+ for (const auto& out : outs) {
+ if (!First) {
+ s += ", ";
+ }
+ s += out.first + ":" + out.second;
+ First = false;
+ }
+ s += ")";
+ if (pattern != nullptr) {
+ s += ", pattern = '" + pattern->ToString() + "'";
+ }
+ s += ", llvm = '" + llvm + "'";
+ return s;
+ }
+
+ std::string name = "";
+ std::string mnemonics = "";
+ std::vector<std::pair<std::string, std::string>> ins = {};
+ std::vector<std::pair<std::string, std::string>> outs = {};
+ PatternNode* pattern = nullptr;
+ std::string llvm = "";
+} parsing_data;
+
+void DefClauseInit() {
+ // std::cerr << "def clause init\n";
+ parsing_data.init();
+}
+
+void DefClauseDone() {
+ std::cerr << "def clause " << parsing_data.ToString() << "\n";
+}
+
+void DefClauseSetName(const char* name) {
+ // std::cerr << "def clause set name = '" << name << "'\n";
+ parsing_data.name = std::string(name);
+}
+
+void DefClauseSetMnemonics(const char* mnemonics) {
+ // std::cerr << "def clause set mnemonics = '" << mnemonics << "'\n";
+ parsing_data.mnemonics = std::string(mnemonics);
+}
+
+void DefClauseAddInout(const char* storage, const char* name,
+ bool parsing_outs) {
+ // std::cerr << "def clause add inout = '" << storage << "':'" << name
+ // << "', parsing_outs = " << parsing_outs << "\n";
+ if (parsing_outs) {
+ parsing_data.outs.push_back(std::make_pair(std::string(storage),
+ std::string(name)));
+ } else {
+ parsing_data.ins.push_back(std::make_pair(std::string(storage),
+ std::string(name)));
+ }
+}
+
+void DefClauseSetPattern(PatternNode* pattern) {
+ // std::cerr << "def clause set pattern = '" << pattern->ToString() << "'\n";
+ // parsing_data now owns the pattern and is responsible for deleting it.
+ parsing_data.pattern = pattern;
+}
+
+void DefClauseSetLlvm(const char* llvm) {
+ // std::cerr << "def clause set llvm = '" << llvm << "'\n";
+ parsing_data.llvm = std::string(llvm);
+}
+
+} // namespace
+} // namespace lavm
+} // namespace mlir
+
+// Functions called from the bison rules:
+
+void def_clause_init() { mlir::lavm::DefClauseInit(); }
+void def_clause_done() { mlir::lavm::DefClauseDone(); }
+
+void def_clause_set_name(const char* name) {
+ mlir::lavm::DefClauseSetName(name);
+}
+
+void def_clause_set_mnemonics(const char* mnemonics) {
+ mlir::lavm::DefClauseSetMnemonics(mnemonics);
+}
+
+void def_clause_add_inout(const char* storage, const char* name,
+ bool parsing_outs) {
+ mlir::lavm::DefClauseAddInout(storage, name, parsing_outs);
+}
+
+void def_clause_set_pattern(PatternNode* pattern) {
+ mlir::lavm::DefClauseSetPattern(pattern);
+}
+
+void def_clause_set_llvm(const char* llvm) {
+ mlir::lavm::DefClauseSetLlvm(llvm);
+}
+
+namespace mlir {
+namespace lavm {
+
+// FIXME: consistently check for errors, such as end of file
+
+////////////////////////////////////////////////////////////////////////////////
+// Parsing: LAVMAnnotation
+////////////////////////////////////////////////////////////////////////////////
+
+Token LAVMTarget::ParseAnnotation(Token token,
+ LAVMAnnotation* annotation) const {
+ *annotation = LAVMAnnotation::Create(token.str());
+ return Lexer::NextToken(token, false);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Parsing: LAVMOp
+////////////////////////////////////////////////////////////////////////////////
+
+Token LAVMTarget::ParseOp(Token token, LAVMOp* op) const {
+ if (Lexer::is_name_marker(*token.data())) {
+ token = token.drop_front();
+ // FIXME!!! check for end/empty...
+ *op = LAVMOp::CreateAsName(token.str());
+ token = Lexer::NextToken(token, false);
+ } else {
+ *op = LAVMOp::Create(token.str());
+ token = Lexer::NextToken(token, false);
+ if (Lexer::is_end(token)) {
+ return token;
+ }
+ if (Lexer::Eat("(", &token, false)) {
+ // Parse nested list and push its elements as op's operands.
+ LAVMOpList nested_list;
+ token = ParseOpList(token, &nested_list);
+ for (const LAVMOp& I : nested_list) {
+ op->AddOperand(I);
+ }
+ Lexer::EatOrError(")", &token, false);
+ }
+ }
+ return token;
+}
+
+Token LAVMTarget::ParseOpList(Token token, LAVMOpList* list) const {
+ assert(!token.empty());
+ // FIXME!!! this is a hack to allow empty '()' for void ops
+ if (Lexer::is_close_paren(*token.data())) {
+ return token;
+ }
+ LAVMOp op;
+ token = ParseOp(token, &op);
+ list->push_back(op);
+ if (Lexer::is_end(token)) {
+ return token;
+ }
+ if (Lexer::is_list_separator(*token.data())) {
+ return ParseOpList(Lexer::NextToken(token, false), list);
+ }
+ return token;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Parsing: LAVMType
+////////////////////////////////////////////////////////////////////////////////
+
+Token LAVMTarget::ParseType(Token token, LAVMType* type) const {
+ // std::cerr << "Enter ParseType\n";
+ // std::cerr << "token size " << token.size() << ": '" <<
+ // Lexer::TokenToString(token) << "'\n";
+ auto parse_type = [this](Token token, LAVMType* type) {
+ if (Lexer::is_open_paren(*token.data())) {
+ Lexer::Eat("(", &token, false);
+ // Parse nested list
+ LAVMTypeList list;
+ token = ParseTypeList(token, &list);
+ if (!Lexer::EatOrError(")", &token, false)) {
+ return token;
+ }
+ *type = LAVMType::Create(list);
+ return token;
+ } else {
+ *type = LAVMType::CreateNameType(token.str());
+ return Lexer::NextToken(token, false);
+ }
+ };
+
+ LAVMType domain;
+ token = parse_type(token, &domain);
+ // std::cerr << "Parsed domain '" << domain.ToString() << "'\n";
+ // std::cerr << "token size " << token.size() << ": '" <<
+ // Lexer::TokenToString(token) << "'\n";
+ if (Lexer::is_arrow(token)) {
+ Lexer::Eat("->", &token, false);
+ // std::cerr << "ate ->\n";
+ // std::cerr << "token size " << token.size() << ": '" <<
+ // Lexer::TokenToString(token) << "'\n";
+ LAVMType range;
+ token = ParseType(token, &range);
+ // std::cerr << "Parsed range '" << range.ToString() << "'\n";
+ // std::cerr << "token size " << token.size() << ": '" <<
+ // Lexer::TokenToString(token) << "'\n";
+ *type = LAVMType::CreateFunctionType(domain, range);
+ // } else if (domain.size() == 1) {
+ // // std::cerr << "Simple type '" << domain.front().ToString() << "'\n";
+ // // std::cerr << "token size " << token.size() << ": '" <<
+ // Lexer::TokenToString(token) << "'\n";
+ // *type = domain.front();
+ } else {
+ *type = domain;
+ }
+ // std::cerr << "Exit ParseType\n";
+ // std::cerr << "token size " << token.size() << ": '" <<
+ // Lexer::TokenToString(token) << "'\n";
+ return token;
+}
+
+// FIXME!!! with the addition of ParseAnnotatedTypeList() this method has
+// (almost) become redundant.
+//
+Token LAVMTarget::ParseTypeList(Token token, LAVMTypeList* list) const {
+ // std::cerr << "Enter ParseTypeList\n";
+ // std::cerr << "token size " << token.size() << ": '" <<
+ // Lexer::TokenToString(token) << "'\n";
+ assert(!token.empty());
+ // FIXME!!! this is a hack to allow empty '()' for void types
+ if (Lexer::is_close_paren(*token.data())) {
+ return token;
+ }
+ LAVMType type;
+ token = ParseType(token, &type);
+ if (Lexer::is_end(token)) {
+ return token;
+ }
+
+ list->push_back(type);
+
+ // std::cerr << "Parsed type '" << type.ToString() << "'\n";
+ // std::cerr << "token size " << token.size() << ": '" <<
+ // Lexer::TokenToString(token) << "'\n";
+ if (Lexer::is_list_separator(*token.data())) {
+ // std::cerr << "Tail call to ParseTypeList\n";
+ return ParseTypeList(Lexer::NextToken(token, false), list);
+ }
+ // std::cerr << "Exit ParseTypeList\n";
+ return token;
+}
+
+Token LAVMTarget::ParseAnnotatedTypeList(Token token,
+ LAVMAnnotatedTypeList* list) const {
+ // std::cerr << "Enter ParseAnnotatedTypeList\n";
+ // std::cerr << "token size " << token.size() << ": '" <<
+ // Lexer::TokenToString(token) << "'\n";
+ assert(!token.empty());
+ do {
+ LAVMType type;
+ token = ParseType(token, &type);
+ if (Lexer::is_end(token)) {
+ return token;
+ }
+ LAVMAnnotation annotation;
+ if (Lexer::Eat("@", &token, false)) {
+ if (Lexer::is_end(token)) {
+ return token;
+ }
+ token = ParseAnnotation(token, &annotation);
+ }
+ // std::cerr << "Parsed type '" << type.ToString() << "'" <<
+ // (annotation.empty() ? "" : " with annotation " + annotation) << "\n";
+ // std::cerr << "token size " << token.size() << ": '" <<
+ // Lexer::TokenToString(token) << "'\n";
+ list->push_back(std::make_pair(type, annotation));
+ // FIXME!!! replace "," with Lexer::get_list_separator()
+ } while (Lexer::Eat(",", &token, false));
+
+ // std::cerr << "Exit ParseAnnotatedTypeList\n";
+ return token;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Parsing: "target" section. It has the following format:
+// target {
+// op(%arg0, ...) : T [ @ cost], ...; # zero or more arguments, one or more
+// # types, cost annotation optional
+// ...
+// }
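+//
+// For example (entries taken from test/coverage.md in this change):
+//   tgt.neg(%x) : vector<8x128xf32>, f32;
+//   tgt.matmul(%x, %y) : (vector<128x128xf32>, vector<128x128xf32>) -> vector<128x128xf32>;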
+////////////////////////////////////////////////////////////////////////////////
+
+Token LAVMTarget::ParseTargetSectionEntry(Token token) {
+ // std::cerr << "Enter ParseTargetEntry\n";
+ // std::cerr << "token size " << token.size() << ": '" <<
+ // Lexer::TokenToString(token) << "\n"; std::cerr << "token size " <<
+ // token.size() << ": '" << Lexer::TokenToString(token) << "' from '" <<
+ // token.data() << "'\n"; left side: one op, possibly with parameters
+ LAVMOpList lhs;
+ token = ParseOpList(token, &lhs);
+ if (Lexer::is_end(token)) {
+ // std::cerr << "Exit ParseTargetEntry, error\n";
+ return token;
+ }
+ if (lhs.size() != 1) {
+ Lexer::ParseError(
+ "Exactly one op expected on the left side of ':' in the target "
+ "section");
+ }
+ const LAVMOp& target_op = lhs.front();
+ if (!target_op.IsFlatOp()) {
+ Lexer::ParseError(
+ "Only flat ops are supported on the left side of ':' in the target "
+ "section for now");
+ return token;
+ }
+
+ // separator
+ if (!Lexer::EatOrError(":", &token, false)) {
+ // std::cerr << "Exit ParseTargetEntry, error\n";
+ return token;
+ }
+
+ // Right side: list of types
+ LAVMAnnotatedTypeList rhs;
+ token = ParseAnnotatedTypeList(token, &rhs);
+ if (Lexer::is_end(token)) {
+ // std::cerr << "Exit ParseTargetEntry, error\n";
+ return token;
+ }
+ // std::cerr << "Parsed types:\n";
+ // for (const LAVMType& t : rhs) {
+ // std::cerr << " " << t.ToString() << "\n";
+ // }
+
+ for (const auto& it : rhs) {
+ const LAVMType& type = it.first;
+ const LAVMAnnotation& annotation = it.second;
+ // FIXME: detect duplicates
+ if (type.IsFunctionType()) {
+ target_op_types[target_op].push_back(type);
+ } else {
+ // All operands and the result have same type.
+ LAVMTypeList list(target_op.GetNumOperands(), type);
+ const LAVMType domain_type = LAVMType::Create(list);
+ const LAVMType& range_type = type;
+ target_op_types[target_op].push_back(
+ LAVMType::CreateFunctionType(domain_type, range_type));
+ }
+ if (!annotation.empty()) {
+ // Note the use of .back(): becomes invalid if the duplicate detection is
+ // implemented above.
+ target_op_type_annotations[std::make_pair(
+ target_op, target_op_types[target_op].back())] = annotation;
+ }
+ }
+ return token;
+}
+
+Token LAVMTarget::ParseTargetSection(Token token) {
+ if (!Lexer::EatOrError("{", &token)) {
+ return token;
+ }
+ while (token != "}") {
+ if (Lexer::is_end(token)) {
+ Lexer::ParseError("Unexpected end of file inside the target section");
+ return token;
+ }
+ token = ParseTargetSectionEntry(token);
+ if (Lexer::is_end(token)) {
+ Lexer::ParseError("Unexpected end of file inside the target section");
+ return token;
+ }
+ if (!Lexer::EatOrError(";", &token)) {
+ return token;
+ }
+ }
+ Lexer::Eat("}", &token);
+ return token;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Parsing: "map" section. It has the following format:
+// map {
+// op(%arg0, ...) : expansion(%arg0, ...); # one flat op on the left side, one
+// # or more expansions on the right
+// ...
+// }
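+//
+// For example (entries taken from test/coverage.md in this change):
+//   lavm.sub(%a, %b) : lavm.add(%a, lavm.neg(%b));
+//   lavm.add(%a, %b) : tgt.add(%a, %b),
+//                      tgt.add(tgt.i2f(%a), tgt.i2f(%b));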
+////////////////////////////////////////////////////////////////////////////////
+
+Token LAVMTarget::ParseMapSectionEntry(Token token) {
+ // left side: one op, possibly with parameters
+ LAVMOpList lhs;
+ token = ParseOpList(token, &lhs);
+ if (Lexer::is_end(token)) {
+ return token;
+ }
+ if (lhs.size() != 1) {
+ Lexer::ParseError(
+ "Exactly one op expected on the left side of ':' in the map section");
+ }
+ const LAVMOp& lavm_op = lhs.front();
+ if (!lavm_op.IsFlatOp()) {
+ Lexer::ParseError(
+ "Only flat ops are supported on the left side of ':' in the map "
+ "section for now");
+ return token;
+ }
+
+ // separator
+ if (!Lexer::EatOrError(":", &token, false)) {
+ return token;
+ }
+
+ // Right side: list of ops
+ LAVMOpList rhs;
+ token = ParseOpList(token, &rhs);
+ if (Lexer::is_end(token)) {
+ return token;
+ }
+ for (const LAVMOp& target_op : rhs) {
+ // FIXME!!! check that rhs names are a subset of lhs names
+ expansions[lavm_op].push_back(target_op);
+ }
+ return token;
+}
+
+Token LAVMTarget::ParseMapSection(Token token) {
+ if (!Lexer::EatOrError("{", &token)) {
+ return token;
+ }
+ while (token != "}") {
+ if (Lexer::is_end(token)) {
+ Lexer::ParseError("Unexpected end of file inside the map section");
+ return token;
+ }
+ token = ParseMapSectionEntry(token);
+ if (Lexer::is_end(token)) {
+ Lexer::ParseError("Unexpected end of file inside the map section");
+ return token;
+ }
+ if (!Lexer::EatOrError(";", &token)) {
+ return token;
+ }
+ }
+ Lexer::Eat("}", &token);
+ return token;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Parsing: "memory" section
+////////////////////////////////////////////////////////////////////////////////
+
+Token LAVMTarget::ParseMemorySectionEntry(Token token) {
+ // Left side: memory name
+ const bool is_cache = Lexer::is_cache_marker(*token.data());
+ if (is_cache) {
+ token = token.drop_front();
+ // FIXME!!! check for end/empty...
+ }
+ std::string mem_name = token.str();
+ token = Lexer::NextToken(token, false);
+ if (Lexer::is_end(token)) {
+ return token;
+ }
+
+ // Right side: the list of the attributes and their values.
+ if (Lexer::Eat(":", &token, false)) {
+ if (Lexer::is_end(token)) {
+ return token;
+ }
+
+    // Create the memory entry if it has not been defined yet.
+ if (memory.find(mem_name) == memory.end()) {
+ memory[mem_name] = LAVMTargetMemory::Create(mem_name, is_cache);
+ }
+
+ LAVMTargetMemory& mem = memory[mem_name];
+ if (mem.IsCache() != is_cache) {
+ Lexer::ParseError(mem_name +
+ " is defined as both cache and non-cache memory.");
+ return token;
+ }
+
+ // Right side: 'attribute = value' list
+ bool First = true;
+ while (token != ";") {
+ if (!First) {
+ if (!Lexer::EatOrError(",", &token, false) || Lexer::is_end(token)) {
+ return token;
+ }
+ }
+ std::string attribute_name = token.str();
+ token = Lexer::NextToken(token, false);
+ if (Lexer::is_end(token)) {
+ return token;
+ }
+ if (!Lexer::EatOrError("=", &token, false) || Lexer::is_end(token)) {
+ return token;
+ }
+ std::string attribute_value = token.str();
+ token = Lexer::NextToken(token, false);
+
+ mem.AddAttribute(attribute_name, attribute_value);
+
+ First = false;
+ }
+ return token;
+ }
+
+ // Right side: data transfer between different kinds of memories.
+ // FIXME!!! this is incomplete and messy
+ if (Lexer::is_arrow(token)) {
+ bool from_is_cache = is_cache;
+ std::string from_mem_name = mem_name;
+ while (Lexer::Eat("->", &token, false)) {
+ bool to_is_cache = Lexer::is_cache_marker(*token.data());
+ if (to_is_cache) {
+ token = token.drop_front();
+ // FIXME!!! check for end/empty...
+ }
+ std::string to_mem_name = token.str();
+ token = Lexer::NextToken(token, false);
+ if (Lexer::is_end(token)) {
+ return token;
+ }
+
+ auto from_it = memory.find(from_mem_name);
+ auto to_it = memory.find(to_mem_name);
+ if (from_it == memory.end()) {
+ Lexer::ParseError(from_mem_name + " is not defined in " +
+ from_mem_name + " -> " + to_mem_name);
+ return token;
+ }
+ if (to_it == memory.end()) {
+ Lexer::ParseError(to_mem_name + " is not defined in " + from_mem_name +
+ " -> " + to_mem_name);
+ return token;
+ }
+ if (from_it->second.IsCache() != from_is_cache) {
+ Lexer::ParseError(from_mem_name + " has inconsistent cache markers.");
+ return token;
+ }
+ if (to_it->second.IsCache() != to_is_cache) {
+ Lexer::ParseError(to_mem_name + " has inconsistent cache markers.");
+ return token;
+ }
+
+ LAVMTargetMemory::AddTransfer(&from_it->second, &to_it->second);
+
+ from_is_cache = to_is_cache;
+ from_mem_name = to_mem_name;
+ }
+ return token;
+ }
+
+ Lexer::ParseErrorExpected("':' or '->'", token);
+ return token;
+}
+
+Token LAVMTarget::ParseMemorySection(Token token) {
+ if (!Lexer::EatOrError("{", &token)) {
+ return token;
+ }
+ while (token != "}") {
+ if (Lexer::is_end(token)) {
+ Lexer::ParseError("Unexpected end of file inside the memory section");
+ return token;
+ }
+ token = ParseMemorySectionEntry(token);
+ if (Lexer::is_end(token)) {
+ Lexer::ParseError("Unexpected end of file inside the memory section");
+ return token;
+ }
+ if (!Lexer::EatOrError(";", &token)) {
+ return token;
+ }
+ }
+ Lexer::Eat("}", &token);
+ return token;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Parsing: skip a section without parsing its content
+////////////////////////////////////////////////////////////////////////////////
+
+// Curly braces inside the section must match properly so that the end of the
+// section is correctly detected.
+//
+Token LAVMTarget::SkipToEndOfSection(Token token) const {
+ int32_t nesting_level = 1;
+ while (nesting_level > 0) {
+ if (Lexer::is_end(token)) {
+ Lexer::ParseError("Unexpected end of file within the skipped section");
+ return token;
+ }
+ if (token == "{") {
+ nesting_level++;
+ } else if (token == "}") {
+ nesting_level--;
+ }
+ token = Lexer::NextToken(token, false);
+ }
+ return token;
+}
+
+// Curly braces inside the section must match properly so that the end of the
+// section is correctly detected.
+//
+Token LAVMTarget::SkipSection(Token token) const {
+ if (!Lexer::EatOrError("{", &token)) {
+ return token;
+ }
+ return SkipToEndOfSection(token);
+}
+
+////////////////////////////////////////////////////////////////////////////////
+// Parsing
+////////////////////////////////////////////////////////////////////////////////
+
+// Check if the parsed target description is consistent.
+//
+// FIXME!!! this method may better fit elsewhere, but first we need a mechanism
+// to return/emit error messages/codes.
+//
+bool LAVMTarget::Validate() {
+ bool Success = true;
+
+ // Check the expansions in the map section.
+ for (auto it = ExpansionBegin(); it != ExpansionEnd(); ++it) {
+ const LAVMOp& lavm_op = it->first;
+ const LAVMOpList& target_op_list = it->second;
+
+ // Already enforced, but may need to update this code if that changes.
+ assert(lavm_op.IsFlatOp());
+
+ // Collect the arguments of the lavm_op.
+ const LAVMOpPtrList lavm_op_arguments = lavm_op.FilterArguments();
+
+ // Check for duplicate argument names.
+ if (lavm_op_arguments.size() != lavm_op.GetNumOperands()) {
+ std::cerr << "Duplicate argument(s) in the expansion of "
+ << lavm_op.ToString() << "\n";
+ Success = false;
+ }
+
+    // Check that each argument name used in the expansion has been declared by
+    // the lavm_op: "op(%a) : expansion(%b)" is an error. Emit a warning for
+    // arguments declared by the lavm_op but not used in the expansion.
+ for (const LAVMOp& target_op : target_op_list) {
+ const LAVMOpPtrList target_op_arguments = target_op.FilterArguments();
+ for (const LAVMOp* op : target_op_arguments) {
+ if (!LAVMOp::Contains(lavm_op_arguments, *op)) {
+ std::cerr << "Use of undeclared argument " << op->ToString()
+ << " in the expansion " << lavm_op.ToString()
+ << " : " << target_op.ToString() << "\n";
+ Success = false;
+ }
+ }
+ for (const LAVMOp* op : lavm_op_arguments) {
+ if (!LAVMOp::Contains(target_op_arguments, *op)) {
+ std::cerr << "WARNING: argument " << op->ToString()
+ << " is not used in the expansion " << lavm_op.ToString()
+ << " : " << target_op.ToString() << "\n";
+ }
+ }
+ }
+ }
+ return Success;
+}
+
+// Main parser entry.
+// FIXME: make the syntax uniform and replace this with generalized parsing.
+//
+bool LAVMTarget::Parse(const char* target_description) {
+ int32_t garbage_tokens_in_a_row = 0; // FIXME!!! find a better way
+ Token token = Lexer::FindToken(target_description);
+ while (!Lexer::is_end(token)) {
+ if (Lexer::Eat("target", &token, false)) {
+ token = ParseTargetSection(token);
+ garbage_tokens_in_a_row = 0;
+ } else if (Lexer::Eat("map", &token, false)) {
+ token = ParseMapSection(token);
+ garbage_tokens_in_a_row = 0;
+ } else if (Lexer::Eat("memory", &token, false)) {
+ token = ParseMemorySection(token);
+ garbage_tokens_in_a_row = 0;
+ } else if (Lexer::Eat("md", &token, false)) {
+ token = SkipSection(token);
+ garbage_tokens_in_a_row = 0;
+ } else if (Lexer::Eat("skip", &token, false)) {
+ token = SkipSection(token);
+ garbage_tokens_in_a_row = 0;
+ } else {
+ if (garbage_tokens_in_a_row == 0) {
+ Lexer::ParseWarning("Garbage in target description: '" +
+ Lexer::TokenToString(token) + "', ignored");
+ } else if (garbage_tokens_in_a_row == 1) {
+ Lexer::ParseWarning("More garbage, ignored");
+ }
+ token = Lexer::NextToken(token);
+ garbage_tokens_in_a_row++;
+ }
+ }
+ // FIXME!!! return false earlier if there were parsing errors.
+ return Validate();
+}
+
+bool LAVMTarget::ParseFromFileImpl(const char* target_description_filename,
+ bool use_lex_yacc) {
+ if (use_lex_yacc) {
+ return md_read_file(target_description_filename) == 0;
+ }
+ bool Success = false;
+ std::fstream file;
+ file.open(target_description_filename, std::fstream::in);
+ if (file.is_open()) {
+ file.seekg(0, file.end);
+ int32_t length = file.tellg();
+ file.seekg(0, file.beg);
+ if (length >= 0) {
+ std::string input(length + 1, '\0');
+ file.read(&input[0], length);
+ if (Parse(input.data())) {
+ Success = true;
+ } else {
+ std::cerr << "Errors in the target description file "
+ << target_description_filename << "\n";
+ }
+ } else {
+ std::cerr << "Empty target description file "
+ << target_description_filename << "\n";
+ }
+ file.close();
+ } else {
+ std::cerr << "Error opening target description file "
+ << target_description_filename << "\n";
+ }
+ return Success;
+}
+
+} // namespace lavm
+} // namespace mlir
diff --git a/LAVM/Lexer.h b/LAVM/Lexer.h
new file mode 100644
index 0000000..c5148c4
--- /dev/null
+++ b/LAVM/Lexer.h
@@ -0,0 +1,147 @@
+#ifndef EXPERIMENTAL_LAVM_LEXER_H_
+#define EXPERIMENTAL_LAVM_LEXER_H_
+
+#include <iostream>
+
+#include "llvm/ADT/StringRef.h"
+
+namespace mlir {
+namespace lavm {
+
+using Token = llvm::StringRef;
+
+class Lexer {
+ public:
+ static bool is_newline(char c) { return c == '\n'; }
+
+ static bool is_whitespace(char c) {
+ return is_newline(c) || c == ' ' || c == '\t';
+ }
+
+ static bool is_list_separator(char c) { return c == ','; }
+
+ static bool is_open_paren(char c) { return c == '('; }
+
+ static bool is_close_paren(char c) { return c == ')'; }
+
+ static bool is_open_angle(char c) {
+ return false; // c == '<';
+ }
+
+ static bool is_close_angle(char c) {
+ return false; // c == '>'; conflicts with "->"
+ }
+
+ static bool is_separator(char c) {
+ return is_list_separator(c) || is_open_paren(c) || is_close_paren(c) ||
+ is_open_angle(c) || is_close_angle(c) || c == ':';
+ }
+
+ static bool is_terminator(char c) { return c == ';'; }
+
+ static bool is_name_marker(char c) { return c == '%'; }
+
+ static bool is_cache_marker(char c) { return c == '$'; }
+
+ static bool is_comment(char c) { return c == '#'; }
+
+ static bool is_end(char c) { return c == '\0'; }
+
+ static bool is_end(Token token) { return token.empty(); }
+
+ static bool is_separator(Token token) {
+ return token.size() == 1 && is_separator(token.front());
+ }
+
+ static bool is_terminator(Token token) {
+ return token.size() == 1 && is_terminator(token.front());
+ }
+
+ static bool is_arrow(Token token) {
+ return token.size() == 2 && token.front() == '-' && token.back() == '>';
+ }
+
+ static std::string TokenToString(Token token) {
+ return is_end(token) ? "END-OF-FILE" : token.str();
+ }
+
+ static Token FindToken(const char* start_at,
+ bool skip_terminators_and_separators = true) {
+ const char* first = start_at;
+ if (is_end(*first)) {
+ return "";
+ }
+ while (is_whitespace(*first) ||
+ (skip_terminators_and_separators &&
+ (is_terminator(*first) || is_separator(*first)))) {
+ first++;
+ }
+ if (is_separator(*first) || is_terminator(*first)) {
+ assert(!skip_terminators_and_separators);
+ return Token(first, 1);
+ }
+ if (is_comment(*first)) {
+ // Skip the rest of the line
+ while (!is_end(*first) && !is_newline(*first)) {
+ first++;
+ }
+ // Pass either 'end' or 'newline' in.
+ return FindToken(first, skip_terminators_and_separators);
+ }
+ const char* last = first;
+ while (!is_end(*last) && !is_whitespace(*last) && !is_separator(*last) &&
+ !is_terminator(*last))
+ last++;
+ return Token(first, last - first);
+ }
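+
+  // For example (illustrative sketch): scanning the input "lavm.add(%a, %b) : t;"
+  // with skip_terminators_and_separators = false yields, via repeated
+  // FindToken/NextToken calls, the tokens
+  //   "lavm.add", "(", "%a", ",", "%b", ")", ":", "t", ";"
+  // followed by the empty end-of-input token.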
+
+ static Token NextToken(Token token,
+ bool skip_terminators_and_separators = true) {
+ const char* start_at = token.data() + token.size();
+ Token next = FindToken(start_at, skip_terminators_and_separators);
+ return next;
+ }
+
+ static bool Eat(const std::string& food, Token* token,
+ bool skip_terminators_and_separators = true) {
+ if (food == token->str()) {
+ *token = NextToken(*token, skip_terminators_and_separators);
+ return true;
+ }
+ return false;
+ }
+
+ static bool EatOrError(const std::string& food, Token* token,
+ bool skip_terminators_and_separators = true) {
+ return Eat(food, token, skip_terminators_and_separators) ||
+ ParseErrorExpected(food, *token);
+ }
+
+ static bool ParseError(const std::string& message) {
+ std::cerr << "ERROR: " << message << std::endl;
+ return false;
+ }
+
+ static bool ParseErrorExpected(const std::string& expected, Token token) {
+ ParseError("Expect to see '" + expected + "', instead '" +
+ TokenToString(token) + "' found");
+ return false;
+ }
+
+ static bool ParseWarning(const std::string& message) {
+    // FIXME? suppress warnings if any error messages have been emitted.
+ std::cerr << "WARNING: " << message << std::endl;
+ return false;
+ }
+
+ // static Type ParseType(Token token, MLIRContext* context) {
+ // // suppress "unexpected character" error message after the type token
+ // std::string t(token.begin(), token.end());
+ // return mlir::parseType(t, context);
+ // }
+};
+
+} // namespace lavm
+} // namespace mlir
+
+#endif // EXPERIMENTAL_LAVM_LEXER_H_
diff --git a/LAVM/OWNERS b/LAVM/OWNERS
new file mode 100644
index 0000000..feb9fe6
--- /dev/null
+++ b/LAVM/OWNERS
@@ -0,0 +1,8 @@
+set noparent
+
+ajcbik
+andydavis
+grosul
+ntv
+tatge
+qyi
diff --git a/LAVM/md_grammar.cpp b/LAVM/md_grammar.cpp
new file mode 100644
index 0000000..00bb606
--- /dev/null
+++ b/LAVM/md_grammar.cpp
@@ -0,0 +1,86 @@
+#include "md_grammar.h"
+
+#include <cassert>
+#include <cstdio>
+#include <cstdlib>  // free()
+
+#include "experimental/LAVM/md_parser.y.h"
+
+int md_line = 1;
+int md_column = 0;
+bool parsing_outs = false;
+
+std::vector<char*> md_symbols;
+
+int md_read_file(const char *filename) {
+ int ret = -1;
+ extern FILE *md_scannerin;
+ md_scannerin = fopen(filename, "r");
+ md_line = 1;
+ md_column = 0;
+ parsing_outs = false;
+ assert(md_symbols.empty());
+ if (md_scannerin) {
+ ret = yyparse();
+ fclose(md_scannerin);
+ }
+ if (!ret)
+ fprintf(stderr, "Successfully parsed '%s' (%d lines)\n",
+ filename, md_line - 1);
+ return ret;
+}
+
+void md_free_symbols() {
+  for (char* symbol : md_symbols) {
+    free(symbol);
+  }
+ md_symbols.clear();
+}
+
+/*static*/ void PatternNodePtrList::Destroy(PatternNodePtrList* list) {
+ if (list != nullptr) {
+ for (PatternNode* node : list->nodes) {
+ PatternNode::Destroy(node);
+ }
+ delete list;
+ }
+}
+
+std::string PatternNodePtrList::ToString() const {
+ std::string s;
+ bool First = true;
+ for (size_t i = 0; i < nodes.size(); i++) {
+ if (!First) {
+ s += ", ";
+ }
+ s += nodes[i]->ToString();
+ First = false;
+ }
+ return s;
+}
+
+/*static*/ void PatternNode::Destroy(PatternNode* node) {
+ if (node != nullptr) {
+ PatternNodePtrList::Destroy(node->pattern_arg_list);
+ delete node;
+ }
+}
+
+std::string PatternNode::ToString() const {
+ std::string s;
+ if (storage != nullptr && name != nullptr) {
+ s += "[" + std::string(storage) + ":" + std::string(name) + "]";
+ } else if (pattern_op) {
+ s += "(" + std::string(pattern_op);
+ if (!pattern_arg_list->empty()) {
+ s += " " + pattern_arg_list->ToString();
+ }
+ s += ")";
+ } else if (storage == nullptr && name == nullptr) {
+ s += "(empty pattern)";
+ } else {
+ assert(false);
+ s += "ERROR: inconsistent pattern";
+ }
+ return s;
+}
diff --git a/LAVM/md_grammar.h b/LAVM/md_grammar.h
new file mode 100644
index 0000000..7a42131
--- /dev/null
+++ b/LAVM/md_grammar.h
@@ -0,0 +1,80 @@
+#ifndef EXPERIMENTAL_LAVM_MD_GRAMMAR_H_
+#define EXPERIMENTAL_LAVM_MD_GRAMMAR_H_
+
+#include <string>
+#include <vector>
+
+// The PatternNodePtrList and PatternNode structs are used to represent the
+// patterns defined in the target description.
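+//
+// For example (illustrative sketch): the pattern "(lavm.trap PPR:$p)" from
+// test/llvm.md corresponds roughly to
+//   PatternNode::CreateInternal(
+//       "lavm.trap",
+//       PatternNodePtrList::Append(PatternNodePtrList::CreateEmpty(),
+//                                  PatternNode::CreateLeaf("PPR", "$p")));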
+struct PatternNode;
+
+struct PatternNodePtrList {
+ static PatternNodePtrList* CreateEmpty() {
+ return new PatternNodePtrList;
+ }
+
+ static PatternNodePtrList* Append(PatternNodePtrList* list,
+ PatternNode* node) {
+ list->nodes.push_back(node);
+ return list;
+ }
+
+ static void Destroy(PatternNodePtrList* list);
+
+ bool empty() const {
+ return nodes.empty();
+ }
+
+ std::string ToString() const;
+
+ std::vector<PatternNode*> nodes = {};
+};
+
+struct PatternNode {
+ static PatternNode* CreateLeaf(const char* storage, const char* name) {
+ PatternNode* node = new PatternNode;
+ node->storage = storage;
+ node->name = name;
+ return node;
+ }
+
+ static PatternNode* CreateInternal(const char* pattern_op,
+ PatternNodePtrList* pattern_arg_list) {
+ PatternNode* node = new PatternNode;
+ node->pattern_op = pattern_op;
+ node->pattern_arg_list = pattern_arg_list;
+ return node;
+ }
+
+ static void Destroy(PatternNode* node);
+
+ std::string ToString() const;
+
+ // leaf
+ const char* storage = nullptr;
+ const char* name = nullptr;
+ // internal node
+ const char* pattern_op = nullptr;
+ PatternNodePtrList* pattern_arg_list = nullptr;
+};
+
+// Tracks line/column in machine description file.
+extern int md_line;
+extern int md_column;
+
+// Tracks whether the list of outs (true) or the list of ins (false) is being
+// processed within a def clause.
+extern bool parsing_outs;
+
+// Collects all heap allocated symbols.
+extern std::vector<char *> md_symbols;
+
+// Helper method to start lex/yacc parsing on the given file.
+// Returns zero for successful parsing, nonzero on failure.
+// Note that the method is not thread-safe.
+int md_read_file(const char *filename);
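+//
+// Typical use (illustrative sketch; "some_target.md" is a placeholder name):
+//   if (md_read_file("some_target.md") != 0) { /* handle the parse failure */ }
+//   md_free_symbols();  // release the lexemes collected during parsing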
+
+// Frees all heap allocated symbols.
+void md_free_symbols();
+
+#endif // EXPERIMENTAL_LAVM_MD_GRAMMAR_H_
diff --git a/LAVM/md_parser.y b/LAVM/md_parser.y
new file mode 100644
index 0000000..4ca7029
--- /dev/null
+++ b/LAVM/md_parser.y
@@ -0,0 +1,217 @@
+%{
+
+#include "experimental/LAVM/md_grammar.h"
+
+#include <cstdio>
+
+extern int md_scannerleng;
+extern char* md_scannertext;
+void yyerror(const char* s);
+
+extern void def_clause_init();
+extern void def_clause_done();
+extern void def_clause_set_name(const char*);
+extern void def_clause_set_mnemonics(const char*);
+extern void def_clause_add_inout(const char*, const char*, bool);
+extern void def_clause_set_pattern(PatternNode*);
+extern void def_clause_set_llvm(const char*);
+
+// Route bison's yylex() calls to the flex-generated scanner entry point.
+int md_scannerlex();
+#define yylex md_scannerlex
+
+%}
+
+%union {
+ char* str;
+ struct PatternNode* pattern_node_ptr;
+ struct PatternNodePtrList* pattern_node_ptr_list_ptr;
+}
+
+// All grammar tokens.
+%token MD TARGET MAP MEMORY ARROW INTEGER
+%token <str> DEF TYPEEXPR TENSOR VAR ID QID STRING
+%token DEF_MNEMONICS DEF_OUTS DEF_INS DEF_PATTERN DEF_LLVM
+
+%type <str> id def_pattern_op mnemonics llvm str_lit
+%type <pattern_node_ptr> def_pattern def_pattern_arg
+%type <pattern_node_ptr_list_ptr> optional_def_pattern_arg_list
+%type <pattern_node_ptr_list_ptr> def_pattern_arg_list
+
+
+%start clauses
+
+%%
+
+clauses : clauses clause
+ |
+ ;
+
+clause : target_clause
+ | map_clause
+ | memory_clause
+ | md_clause
+ ;
+
+md_clause : MD '{' def_clause_list '}' { printf("md clause\n"); }
+ ;
+
+def_clause_list : def_clause_list def_clause
+ |
+ ;
+
+def_clause : DEF { def_clause_init(); } QID { def_clause_set_name($3); }
+ '{' def_list '}' { def_clause_done(); }
+ ;
+
+def_list : def_list def ';'
+ |
+ ;
+
+def : DEF_MNEMONICS '=' mnemonics
+ { def_clause_set_mnemonics($3); }
+ | DEF_OUTS '='
+ '(' { parsing_outs = true; } optional_inouts_list ')'
+ | DEF_INS '='
+ '(' { parsing_outs = false; } optional_inouts_list ')'
+ | DEF_PATTERN '=' def_pattern { def_clause_set_pattern($3); }
+ | DEF_LLVM '=' llvm { def_clause_set_llvm($3); }
+ ;
+
+mnemonics : str_lit { $$ = $1; }
+ ;
+
+optional_inouts_list : inouts_list
+ |
+ ;
+
+inouts_list : inouts_list ',' inout
+ | inout
+ ;
+
+inout : id ':' id { def_clause_add_inout($1, $3, parsing_outs); }
+ ;
+
+def_pattern : '(' def_pattern_op optional_def_pattern_arg_list ')'
+ { $$ = PatternNode::CreateInternal($2, $3); }
+ ;
+
+def_pattern_op : id
+ ;
+
+optional_def_pattern_arg_list : def_pattern_arg_list { $$ = $1; }
+ | { $$ = PatternNodePtrList::CreateEmpty(); }
+ ;
+
+def_pattern_arg_list : def_pattern_arg_list ',' def_pattern_arg
+ { $$ = PatternNodePtrList::Append($1, $3); }
+ | def_pattern_arg
+ { $$ = PatternNodePtrList::Append(
+ PatternNodePtrList::CreateEmpty(), $1); }
+ ;
+
+def_pattern_arg : def_pattern { $$ = $1; }
+ | id ':' id { $$ = PatternNode::CreateLeaf($1, $3); }
+ ;
+
+llvm : id { $$ = $1; }
+ ;
+
+target_clause : TARGET '{' target_list '}' { printf("target clause\n"); }
+ ;
+
+map_clause : MAP '{' map_list '}' { printf("map clause\n"); }
+ ;
+
+memory_clause : MEMORY '{' memory_list '}' { printf("memory clause\n"); }
+ ;
+
+target_list : target_list target
+ |
+ ;
+
+map_list : map_list map
+ |
+ ;
+
+memory_list : memory_list memory
+ |
+ ;
+
+target : id '(' optional_var_list ')' ':' toplevel_type_list ';'
+ ;
+
+map : id '(' optional_var_list ')' ':' pattern_list ';'
+ ;
+
+memory : id ':' xid '=' xid ';'
+ | id ARROW mem_red ';'
+ ;
+
+optional_var_list : var_list
+ |
+ ;
+
+var_list : var_list ',' VAR
+ | VAR
+ ;
+
+toplevel_type_list : toplevel_type_list ',' toplevel_type
+ | toplevel_type
+ ;
+
+optional_type_list : type_list
+ |
+ ;
+
+type_list : type_list ',' type
+ | type
+ ;
+
+pattern_list : pattern_list ',' pattern
+ | pattern
+ ;
+
+pattern : id '(' pattern_list ')'
+ | VAR
+ ;
+
+toplevel_type : type optional_cost
+ ;
+
+type : id
+ | TYPEEXPR
+ | TENSOR
+ | '(' optional_type_list ')' ARROW type
+ | '(' optional_type_list ')' ARROW '(' optional_type_list ')'
+ ;
+
+optional_cost : '@' INTEGER
+ |
+ ;
+
+mem_red : mem_red ARROW xid
+ | xid
+ ;
+
+id : ID { $$ = $1; }
+ | QID { $$ = $1; }
+ ;
+
+xid : id
+ | TYPEEXPR
+ | str_lit
+ ;
+
+str_lit : STRING { $$ = $1; }
+ ;
+
+%%
+
+// Error method.
+void yyerror(const char* s) {
+ const int col_start = md_column + 1 - md_scannerleng;
+ const int col_end = md_column;
+ fprintf(stderr, "Machine description %s %d:%d-%d\n",
+ s, md_line, col_start, col_end);
+}
diff --git a/LAVM/md_scanner.lex b/LAVM/md_scanner.lex
new file mode 100644
index 0000000..d588884
--- /dev/null
+++ b/LAVM/md_scanner.lex
@@ -0,0 +1,73 @@
+%{
+
+#include "experimental/LAVM/md_grammar.h"
+
+#include "md_parser.y.h"
+
+// Advances line/column bookkeeping.
+static void track(bool next) {
+ if (next) {
+ ++md_line;
+ md_column = 0;
+ } else {
+ md_column += md_scannerleng;
+ }
+}
+
+// Stores a copy of the current lexeme as the token's semantic value (yylval).
+static void store() {
+ char* str = strndup(md_scannertext, md_scannerleng);
+ md_symbols.push_back(str);
+ yylval.str = str;
+}
+
+%}
+
+whitespace [\t\f ]+
+letter [A-Za-z_$]
+digit [0-9]
+integer {digit}+
+identifier {letter}(({digit}|{letter})*)
+typeexpr {digit}(({digit}|{letter})*)
+qualified {identifier}"."{identifier}
+variable "%"{identifier}
+tensor {identifier}"<"({digit}|{letter})+">"
+str_lit \"([^"\n])*\"
+
+%option noyywrap
+
+%%
+
+ /* White space handling. */
+
+#.*\n { track(true); /* comment */ }
+{whitespace} { track(false); /* whitespace */ }
+\n { track(true); /* newline */ }
+
+ /* Reserved keywords and compound operators. */
+
+"md" { track(false); return MD; }
+"map" { track(false); return MAP; }
+"target" { track(false); return TARGET; }
+"memory" { track(false); return MEMORY; }
+"def" { track(false); return DEF; }
+"mnemonics" { track(false); return DEF_MNEMONICS; }
+"outs" { track(false); return DEF_OUTS; }
+"ins" { track(false); return DEF_INS; }
+"pattern" { track(false); return DEF_PATTERN; }
+"llvm" { track(false); return DEF_LLVM; }
+"->" { track(false); return ARROW; }
+
+ /* Language components and default action. */
+
+{integer} { track(false); store(); return INTEGER; }
+{identifier} { track(false); store(); return ID; }
+{typeexpr} { track(false); store(); return TYPEEXPR; }
+{qualified} { track(false); store(); return QID; }
+{variable} { track(false); store(); return VAR; }
+{tensor} { track(false); store(); return TENSOR; }
+{str_lit} { track(false); store(); return STRING; }
+
+. { track(false); return yytext[0]; }
+
+%%
diff --git a/LAVM/test/LAVM.mlir b/LAVM/test/LAVM.mlir
new file mode 100644
index 0000000..ba0b33f
--- /dev/null
+++ b/LAVM/test/LAVM.mlir
@@ -0,0 +1,33 @@
+func @elementwise(%arg0 : memref<?x99x?x?xf32>,
+ %arg1 : memref<?x99x?x?xf32>,
+ %arg2 : memref<?x99x?x?xf32>,
+ %m0 : memref<99x200xf32>,
+ %m1 : memref<200x11xf32>,
+ %m2 : memref<99x11xf32>,
+ %v0 : vector<99x99x101x103xf32>,
+ %aaa : memref<3x3x3xvector<8x128xf32>>) -> memref<?x99x?x?xf32>
+{
+ // "lair.HLadd"(%arg2, %arg0, %arg1) : (memref<?x99x?x?xf32>, memref<?x99x?x?xf32>, memref<?x99x?x?xf32>) -> ()
+ // "lair.HLsub"(%arg2, %arg0, %arg1) : (memref<?x99x?x?xf32>, memref<?x99x?x?xf32>, memref<?x99x?x?xf32>) -> ()
+ // "lair.HLmul"(%arg2, %arg0, %arg1) : (memref<?x99x?x?xf32>, memref<?x99x?x?xf32>, memref<?x99x?x?xf32>) -> ()
+ // "lair.HLdot"(%arg2, %arg0, %arg1) : (memref<?x99x?x?xf32>, memref<?x99x?x?xf32>, memref<?x99x?x?xf32>) -> ()
+ // "lair.HLmatmul"(%m2, %m0, %m1) : (memref<99x11xf32>, memref<99x200xf32>, memref<200x11xf32>) -> ()
+
+ %n = "lavm.neg"(%v0) : (vector<99x99x101x103xf32>) -> vector<99x99x101x103xf32>
+ %x = "lavm.add"(%v0, %v0) : (vector<99x99x101x103xf32>, vector<99x99x101x103xf32>) -> vector<99x99x101x103xf32>
+ %y = "lavm.sub"(%v0, %v0) : (vector<99x99x101x103xf32>, vector<99x99x101x103xf32>) -> vector<99x99x101x103xf32>
+ // %x = "lavm.add"(%arg1, %arg2) : (memref<?x99x?x?xf32>, memref<?x99x?x?xf32>) -> memref<?x99x?x?xf32>
+ %z = "lavm.mul"(%arg1, %arg2) : (memref<?x99x?x?xf32>, memref<?x99x?x?xf32>) -> memref<?x99x?x?xf32>
+ %w = "lavm.dot"(%arg1, %arg2) : (memref<?x99x?x?xf32>, memref<?x99x?x?xf32>) -> memref<?x99x?x?xf32>
+ %m = "lavm.matmul"(%m0, %m1) : (memref<99x200xf32>, memref<200x11xf32>) -> memref<99x11xf32>
+
+ // %c0 = constant 0 : index
+ // %vzero = tpu.vimm.f32 0.0
+ // affine.for %i0 = 0 to 3 {
+ // affine.for %i1 = 0 to 3 {
+ // affine.store %vzero, %aaa[%i0, %i1, %c0] : memref<3x3x3xvector<8x128xf32>>
+ // }
+ // }
+
+ return %arg2 : memref<?x99x?x?xf32>
+}
diff --git a/LAVM/test/coverage.md b/LAVM/test/coverage.md
new file mode 100644
index 0000000..2ce3868
--- /dev/null
+++ b/LAVM/test/coverage.md
@@ -0,0 +1,53 @@
+### Operations and types supported by the target. ###
+target {
+ tgt.i2f(%x) : (i32) -> f32, (vector<8x128xi32>) -> vector<8x128xf32>;
+ tgt.f2i(%x) : (f32) -> i32, (vector<8x128xf32>) -> vector<8x128xi32>;
+ tgt.neg(%x) : vector<8x128xf32>, f32; # vector<128xf32>
+
+ tgt.add(%x, %y) : vector<128xf32>, vector<8x128xf32>, f32;
+ tgt.sub(%x, %y) : vector<128xf32>, vector<8x128xf32>, f32;
+ tgt.mul(%x, %y) : vector<128xf32>, vector<8x128xf32>, f32;
+
+ tgt.matmul(%x, %y) : (vector<128x128xf32>, vector<128x128xf32>) -> vector<128x128xf32>;
+}
+
+### Map LAVM ops to previously defined target ops or their combinations. ###
+# Current restriction: lavm ops on the lhs must appear with the same argument
+# names and argument count.
+map {
+ lavm.neg(%a) : tgt.neg(%a);
+ lavm.neg(%a) : tgt.neg(tgt.i2f(%a));
+
+ lavm.add(%a, %b) : tgt.add(%a, %b),
+ tgt.add(tgt.i2f(%a), tgt.i2f(%b));
+
+ lavm.sub(%a, %b) : lavm.add(%a, lavm.neg(%b));
+
+ lavm.matmul(%a, %b) : tgt.matmul(%a, %b);
+}
+
+### Target memory description. ###
+memory {
+
+ # TPU-like memory:
+ HBM : size = 16G;
+ HBM : garbage = ignored;
+ VMEM : size = 16M;
+ SMEM : size = 16K;
+ CMEM : size = 16M;
+ HBM -> VMEM;
+ VMEM -> HBM;
+ HBM -> SMEM;
+ SMEM -> HBM;
+ HBM -> CMEM -> VMEM;
+ VMEM -> CMEM -> HBM;
+ # GPU-like memory:
+ GLOBAL : size = 8G;
+ SHARED : size = 16M;
+ LOCAL : size = 1M;
+ # CPU-like memory:
+ MEMORY : size = 64GB;
+ $L1 : size = 512K;
+ $L2 : size = 4M;
+ MEMORY -> $L2 -> $L1;
+ $L1 -> $L2 -> MEMORY;
+}
diff --git a/LAVM/test/coverage.mlir b/LAVM/test/coverage.mlir
new file mode 100644
index 0000000..41094c8
--- /dev/null
+++ b/LAVM/test/coverage.mlir
@@ -0,0 +1,11 @@
+func @coverage(%arg0 : vector<500x777xf32>,
+ %arg1 : vector<500x777xf32>,
+ %arg2 : vector<777x500xf32>) -> vector<500x500xf32>
+{
+ %0 = "lavm.mul"(%arg0, %arg0) : (vector<500x777xf32>, vector<500x777xf32>) -> vector<500x777xf32>
+ %1 = "lavm.add"(%arg1, %arg1) : (vector<500x777xf32>, vector<500x777xf32>) -> vector<500x777xf32>
+ %2 = "lavm.sub"(%0, %1) : (vector<500x777xf32>, vector<500x777xf32>) -> vector<500x777xf32>
+ %3 = "lavm.matmul"(%2, %arg2) : (vector<500x777xf32>, vector<777x500xf32>) -> vector<500x500xf32>
+ %4 = "lavm.neg"(%3) : (vector<500x500xf32>) -> vector<500x500xf32>
+ return %4 : vector<500x500xf32>
+}
diff --git a/LAVM/test/llvm.md b/LAVM/test/llvm.md
new file mode 100644
index 0000000..746bd62
--- /dev/null
+++ b/LAVM/test/llvm.md
@@ -0,0 +1,1970 @@
+# ###############################################################
+# Control flow
+
+def tpu.halt {
+ mnemonics = "_ = shalt${pred}";
+ outs = ();
+ ins = ();
+ pattern = (lavm.halt);
+ llvm = HALT;
+}
+
+# pseudo
+def tpu.trap {
+ mnemonics = "_ = #TRAP $p";
+ outs = ();
+ ins = (PPR:$p);
+ pattern = (lavm.trap PPR:$p);
+ llvm = TRAP;
+}
+
+# ###############################################################
+# Scalar ALU ops
+
+def tpu.addrr {
+ mnemonics = "$d = sadd.s32${pred} $x, $y";
+ outs = (GPR:$d);
+ ins = (GPR:$x, GPR:$y);
+ pattern = (set GPR:$d, (lavm.add (i32 GPR:$x), (i32 GPR:$y)));
+ llvm = ADDrr;
+}
+def tpu.addri {
+ mnemonics = "$d = sadd.s32${pred} $x, $y";
+ outs = (GPR:$d);
+ ins = (GPR:$x, i32imm:$y);
+ pattern = (set GPR:$d, (lavm.add (i32 GPR:$x), (i32 imm:$y)));
+ llvm = ADDri;
+}
++ SUB"ssub.s32", AND"sand.s32", OR"sor.u32", XOR"sxor.u32", MUL"smul.u32", SHL"sshll.u32", SRL"sshrl.u32", SRA"sshra.s32"
+
+def tpu.faddrr {
+ mnemonics = "$d = sadd.f32${pred} $x, $y";
+ outs = (GPR:$d);
+ ins = (GPR:$x, GPR:$y);
+ pattern = (set GPR:$d, (lavm.add (f32 GPR:$x), (f32 GPR:$y)));
+ llvm = FADDrr;
+}
+def tpu.faddri {
+ mnemonics = "$d = sadd.f32${pred} $x, $y";
+ outs = (GPR:$d);
+ ins = (GPR:$x, tpuf32imm:$y);
+ pattern = (set GPR:$d, (lavm.add (f32 GPR:$x), (f32 fpimm:$y)));
+ llvm = FADDri;
+}
++ FSUB"ssub.f32", FMUL"smul.f32", FMAX"smax.f32", FMIN"smin.f32"
+
+def tpu.clz {
+ mnemonics = "$Sd = sclz.u32${pred} $y";
+ outs = (GPR:$Sd);
+ ins = (GPR:$y);
+ pattern = (set GPR:$Sd, (lavm.ctlz (i32 GPR:$y)));
+ llvm = CLZ;
+}
+
+def tpu.mov {
+ mnemonics = "$Sd = smov.u32${pred} $y";
+ outs = (GPR:$Sd);
+ ins = (GPR:$y);
+ pattern = ;
+ llvm = MOV;
+}
+
+def tpu.imm {
+ mnemonics = "$Sd = simm.s32${pred} $y";
+ outs = (GPR:$Sd);
+ ins = (i32imm:$y);
+ pattern = (set GPR:$Sd, (i32 imm:$y));
+ llvm = IMM;
+}
+
+def tpu.fimm {
+ mnemonics = "$Sd = simm.f32${pred} $y";
+ outs = (GPR:$Sd);
+ ins = (tpuf32imm:$y);
+ pattern = (set GPR:$Sd, (f32 fpimm:$y));
+ llvm = FIMM;
+}
+
+# ###############################################################
+# Scalar comparison ops
+
+def tpu.cmpeqrr {
+ mnemonics = "$d = seq.s32${pred} $x, $y";
+ outs = (PPR:$d);
+ ins = (GPR:$x, GPR:$y);
+ pattern = (set PPR:$d, (lavm.eq (i32 GPR:$x), (i32 GPR:$y)));
+ llvm = CMPEQrr;
+}
+def tpu.cmpeqri {
+ mnemonics = "$d = seq.s32${pred} $x, $y";
+ outs = (PPR:$d);
+ ins = (GPR:$x, i32imm:$y);
+ pattern = (set PPR:$d, (lavm.eq (i32 GPR:$x), (i32 imm:$y)));
+ llvm = CMPEQri;
+}
+def tpu.fcmpeqrr {
+ mnemonics = "$d = seq.f32${pred} $x, $y";
+ outs = (PPR:$d);
+ ins = (GPR:$x, GPR:$y);
+ pattern = (set PPR:$d, (lavm.eq (f32 GPR:$x), (f32 GPR:$y)));
+ llvm = FCMPEQrr;
+}
+def tpu.fcmpeqri {
+ mnemonics = "$d = seq.f32${pred} $x, $y";
+ outs = (PPR:$d);
+ ins = (GPR:$x, tpuf32imm:$y);
+ pattern = (set PPR:$d, (lavm.eq (f32 GPR:$x), (f32 fpimm:$y)));
+ llvm = FCMPEQri;
+}
++ ne, gt, ge, lt, le
+
+multiclass FPComparePat<string OpName, PatFrag OpNode> {
+ def : Pat<(OpNode (f32 GPR:$x), (f32 fpimm:$y)),
+ (!cast<Instruction>(OpName#"ri") GPR:$x, tpuf32imm:$y)>;
+ def : Pat<(OpNode (f32 GPR:$x), (f32 GPR:$y)),
+ (!cast<Instruction>(OpName#"rr") GPR:$x, GPR:$y)>;
+}
+// Patterns for the cases where we don't care about unordered.
+defm : FPComparePat<"FCMPEQ", seteq>;
+defm : FPComparePat<"FCMPNE", setne>;
+defm : FPComparePat<"FCMPGT", setgt>;
+defm : FPComparePat<"FCMPGE", setge>;
+defm : FPComparePat<"FCMPLT", setlt>;
+defm : FPComparePat<"FCMPLE", setle>;
+
+defm CARRYOUT : IntCompareOp<"sc.u32", Carryout, 54>;
+} // Predicates = [NotBC]
+
+def : Pat<(i1 (int_tpu_addcarry (i32 GPR:$lhs), (i32 GPR:$rhs))),
+ (CARRYOUTrr GPR:$lhs, GPR:$rhs)>;
+def : Pat<(i1 (int_tpu_addcarry (i32 GPR:$lhs), (i32 imm:$rhs))),
+ (CARRYOUTri GPR:$lhs, i32imm:$rhs)>;
+
+defm WEIRD : TPUInstSany<62, (outs PPR:$Pd), (ins GPR:$Ss),
+ "$Pd = sweird.f32${pred} $Ss",
+ [(set PPR:$Pd, (int_tpu_weird_f32 (f32 GPR:$Ss)))]>, Requires<[NotBC]>;
+
+# ###############################################################
+# Scalar conversion ops
+
+ let Predicates = [NotBC] in {
+def FPTOSIrr : TPUInstP<(outs GPR:$Sd), (ins GPR:$x, GPR:$y),
+ "$Sd = scvt.f32.s32${pred} $x, $y",
+ [(set GPR:$Sd,
+ (int_tpu_cvt_fptosi (f32 GPR:$x), (i32 GPR:$y)))]>,
+ Bundle<B_Sany>, Sched<[WriteFPConvert]>;
+def FPTOSIri : TPUInstP<(outs GPR:$Sd), (ins GPR:$x, i32imm:$y),
+ "$Sd = scvt.f32.s32${pred} $x, $y",
+ [(set GPR:$Sd,
+ (int_tpu_cvt_fptosi (f32 GPR:$x), (i32 imm:$y)))]>,
+ Bundle<B_Sany>, BundleImmSy, Sched<[WriteFPConvert]>;
+def SITOFPr : TPUInstP<(outs GPR:$Sd), (ins GPR:$x),
+ "$Sd = scvt.s32.f32${pred} $x",
+ [(set GPR:$Sd, (sint_to_fp (i32 GPR:$x)))]>,
+ Bundle<B_Sany>, Sched<[WriteFPConvert]>;
+def SITOFPi : TPUInstP<(outs GPR:$Sd), (ins i32imm:$x),
+ "$Sd = scvt.s32.f32${pred} $x",
+ [(set GPR:$Sd, (sint_to_fp (i32 imm:$x)))]>,
+ Bundle<B_Sany>, BundleImmSy, Sched<[WriteFPConvert]>;
+def : Pat<(i32 (fp_to_sint (f32 GPR:$x))), (FPTOSIri GPR:$x, (i32 -1))>;
+} // Predicates = [NotBC]
+//===----------------------------------------------------------------------===//
+// Predicate manipulation ops
+//===----------------------------------------------------------------------===//
+def pnot : PatFrag<(ops node:$a),
+ (xor node:$a, (i1 -1))>;
+let Predicates = [NotBC] in {
+// The only predicate op is POR. It can negate any of its operands, so we can
+// create PMOV, PNOT, PSET and PCLEAR
+def POR : TPUInstP<(outs PPR:$Pd), (ins PPR:$Ps, PPR:$Pt),
+ "$Pd = por${pred} $Ps, $Pt",
+ [(set PPR:$Pd, (or PPR:$Ps, PPR:$Pt))]>, Bundle<B_Sany>;
+
+def PNAND : TPUInstP<(outs PPR:$Pd), (ins PPR:$Ps, PPR:$Pt),
+ "$Pd = por${pred} !$Ps, !$Pt",
+ [(set PPR:$Pd, (pnot (and PPR:$Ps, PPR:$Pt)))]>, Bundle<B_Sany>;
+
+def PMOV : TPUInstP<(outs PPR:$Pd), (ins PPR:$Ps),
+ "$Pd = por${pred} $Ps, $Ps",
+ []>, Bundle<B_Sany>;
+
+def : Pat<(pnot PPR:$x), (PNAND PPR:$x, PPR:$x)>;
+def : Pat<(setne PPR:$x, (i1 -1)), (PNAND PPR:$x, PPR:$x)>;
+
+// PORii takes two immediates; it is used for PSET or PCLEAR. We need two
+// operands to allow the MCParser to work correctly.
+def PORii : TPUInstP<(outs PPR:$Pd), (ins i1imm:$val0, i1imm:$val1),
+ "$Pd = por${pred} $val0, $val1",
+ []>, Bundle<B_Sany>;
+
+def : Pat<(i1 imm:$val), (PORii imm:$val, imm:$val)>;
+def : Pat<(i1 (trunc (i32 GPR:$x))), (CMPEQri (ANDri $x, (i32 1)), (i32 1))>;
+def : Pat<(i1 (and PPR:$Ps, PPR:$Pt)),(PNAND (PNAND $Ps, $Pt), (PNAND $Ps, $Pt))>;
+
+// DAG combine may convert XOR i1 %x, -1 to setcc.
+def : Pat<(i1 (setcc (i32 (zext PPR:$p)), (i32 1), SETNE)), (PNAND $p, $p)>;
+
+let Constraints = "$d = $a", isPseudo = 1 in {
+ def PSELrr : TPUInst<(outs PPR:$d), (ins PPR:$p, PPR:$a, PPR:$b),
+ "$d = #PSEL $p, $a, $b",
+ [(set PPR:$d, (select PPR:$p, PPR:$a, PPR:$b))]>,
+ Bundle<B_Sany>;
+}
+} // Predicates = [NotBC]
+
+def FPZero : PatFrag<(ops), (fpimm), [{
+ return cast<ConstantFPSDNode>(N)->isZero();
+}]>;
+
+// Match clamp(a, 0, b) - clamp a between [0, b].
+def Relu : PatFrag<(ops node:$a, node:$b),
+ (fmaximum (fminimum node:$a, node:$b),
+ (Splat FPZero))>;
+
+# ###############################################################
+# Load/store ops
+
+def tpu.sldri {
+ mnemonics = "$Sd = sld${pred} [smem:${Ss}+$imm]";
+ outs = (GPR:$Sd);
+ ins = (GPR:$Ss, i32imm:$imm);
+ pattern = (set GPR:$Sd, (i32 (lavm.sld (add GPR:$Ss, imm:$imm))));
+ llvm = SLDri;
+}
+def tpu.sldi {
+ mnemonics = "$Sd = sld${pred} [smem:$imm]";
+ outs = (GPR:$Sd);
+ ins = (i32imm:$imm);
+ pattern = (set GPR:$Sd, (i32 (lavm.sld (Wrapper tglobaladdr:$imm))));
+ llvm = SLDi;
+}
+def tpu.sldrr {
+ mnemonics = "$Sd = sld${pred} [smem:${Sx}+${Sy}]";
+ outs = (GPR:$Sd);
+ ins = (GPR:$Sx, GPR:$Sy);
+ pattern = (set GPR:$Sd, (i32 (lavm.sld (add GPR:$Sx, GPR:$Sy))));
+ llvm = SLDrr;
+}
++ (f32 (lavm.sld (Wrapper tglobaladdr:$imm))), (SLDi imm:$imm)
+ (f32 (lavm.sld (imm:$imm))), (SLDi imm:$imm)
+ (i32 (lavm.sld (imm:$imm))), (SLDi imm:$imm)
+ (i32 (lavm.sld GPR:$Ss)), (SLDri GPR:$Ss, (i32 0))
+ (f32 (lavm.sld (add GPR:$Ss, imm:$imm))), (SLDri GPR:$Ss, imm:$imm)
+ (f32 (lavm.sld (add GPR:$Sx, GPR:$Sy))), (SLDrr GPR:$Sx, GPR:$Sy)
+ (f32 (lavm.sld GPR:$Ss)), (SLDri GPR:$Ss, (i32 0))
+
+
+// Scalar store to smem.
+let mayStore = 1 in {
+def SSTr : TPUInstP<(outs), (ins GPR:$Sval, GPR:$Ss),
+ "[smem:${Ss}] = sst${pred} $Sval",
+ [(store_smem (i32 GPR:$Sval), GPR:$Ss)]>,
+ Bundle<B_SST>;
+def SSTi : TPUInstP<(outs), (ins GPR:$Sval, i32imm:$imm),
+ "[smem:${imm}] = sst${pred} $Sval",
+ [(store_smem (i32 GPR:$Sval), (Wrapper tglobaladdr:$imm))]>,
+ Bundle<B_SST>, BundleImmSy;
+}
+
+def : Pat<(store_smem (f32 GPR:$Sval), (Wrapper tglobaladdr:$imm)),
+ (SSTi GPR:$Sval, imm:$imm)>;
+def : Pat<(store_smem (f32 GPR:$Sval), (imm:$imm)),
+ (SSTi GPR:$Sval, imm:$imm)>;
+def : Pat<(store_smem (i32 GPR:$Sval), (imm:$imm)),
+ (SSTi GPR:$Sval, imm:$imm)>;
+def : Pat<(store_smem (f32 GPR:$Sval), (i32 GPR:$Ss)),
+ (SSTr GPR:$Sval, GPR:$Ss)>;
+
+
+//===----------------------------------------------------------------------===//
+// Branch ops
+//===----------------------------------------------------------------------===//
+// The BR pseudo is used from early to late code generation to represent the
+// branching point; the point at which control flow changes.
+//
+// The BRrel instruction is the actual branch instruction; it has a delay slot,
+// so it is inserted one cycle before the point at which the BR pseudo was scheduled.
+//
+// (brcond) SDNodes are custom selected to a BR.
+//===----------------------------------------------------------------------===//
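+// Illustrative sequence (a sketch, not actual compiler output): using the asm
+// strings defined below, a scheduled branch region looks roughly like
+//   (pc) = sbr.rel${pred} $target    // BRrel: real branch, issued a cycle early
+//   <delay-slot bundle>
+//   (pc) = #BR${pred} $target        // BR pseudo: where control flow changes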
+
+// Relative branch. This is used for both unconditional and conditional jumps.
+class RelTargetOperand<ValueType VT> : Operand<VT> {
+ let OperandType = "OPERAND_PCREL";
+}
+
+let Predicates = [NotBC] in {
+let isBranch = 1, isTerminator = 1 in {
+def BRrel : TPUInstP<(outs), (ins RelTargetOperand<OtherVT>:$target),
+ "(pc) = sbr.rel${pred} $target",
+ []>, Bundle<B_S0>, BundleImm<IMM_0>;
+}
+
+let isPseudo = 1, isBranch = 1, isTerminator = 1, isBarrier = 1 in {
+// Pseudo to model the actual change in control flow, after the delay slot ends.
+def BR : TPUInstP<(outs), (ins Operand<OtherVT>:$target),
+ "(pc) = #BR${pred} $target", [(br bb:$target)]>,
+ Bundle<B_S0>, BundleImm<IMM_0>;
+}
+} // Predicates = [NotBC]
+
+//===----------------------------------------------------------------------===//
+// Pseudo ops
+//===----------------------------------------------------------------------===//
+let Predicates = [NotBC] in {
+let isPseudo = 1 in {
+ // Pseudo-select instruction. Note that this is lowered to either a predicated
+ // IMM or MOV, so we don't have an ii version and it doesn't take a predicate.
+ let Constraints = "$d = $a" in {
+ def SELrr : TPUInst<(outs GPR:$d), (ins PPR:$p, GPR:$a, GPR:$b),
+ "$d = #SEL $p, $a, $b",
+ [(set GPR:$d, (select PPR:$p, (i32 GPR:$a), (i32 GPR:$b)))]>,
+ Bundle<B_Sany>;
+ def SELri : TPUInst<(outs GPR:$d), (ins PPR:$p, GPR:$a, i32imm:$b),
+ "$d = #SEL $p, $a, $b",
+ [(set GPR:$d, (select PPR:$p, (i32 GPR:$a), (i32 imm:$b)))]>,
+ Bundle<B_Sany>;
+ }
+ let Constraints = "$d = $b" in {
+ def SELir : TPUInst<(outs GPR:$d), (ins PPR:$p, i32imm:$a, GPR:$b),
+ "$d = #SEL $p, $a, $b",
+ [(set GPR:$d, (select PPR:$p, (i32 imm:$a), (i32 GPR:$b)))]>,
+ Bundle<B_Sany>;
+ }
+}
+def : Pat<(select PPR:$p, (f32 GPR:$a), (f32 GPR:$b)),
+ (SELrr PPR:$p, GPR:$a, GPR:$b)>;
+def : Pat<(select PPR:$p, (f32 GPR:$a), (f32 fpimm:$b)),
+ (SELri PPR:$p, GPR:$a, (ftoi $b))>;
+def : Pat<(select PPR:$p, (f32 fpimm:$a), (f32 GPR:$b)),
+ (SELir PPR:$p, (ftoi $a), GPR:$b)>;
+
+def : Pat<(i32 (zext PPR:$x)), (SELri PPR:$x, (IMM 1), 0)>;
+def : Pat<(i32 (sext PPR:$x)), (SELri PPR:$x, (IMM -1), 0)>;
+} // Predicates = [NotBC]
+//===----------------------------------------------------------------------===//
+// Misc ops
+//===----------------------------------------------------------------------===//
+let Predicates = [NotBC] in {
+
+let hasSideEffects = 1 in {
+def SNOP : TPUInstP<(outs), (ins), "_ = snop${pred}", [(int_tpu_nop)]>;
+// The hardware doesn't have a vnop; we encode it as vm0 vmov vm0, so it needs a
+// misc slot.
+def VNOP : TPUInstP<(outs), (ins), "_ = vnop${pred}", []>, Bundle<B_Misc>,
+ IsVectorInstruction;
+}
+
+let hasSideEffects = 1, mayStore = 1 in {
+def SFENCE : TPUInstP<(outs), (ins),
+ "_ = sfence${pred}",
+ []>, Bundle<B_Sany>;
+}
+
+
+# ###############################################################
+# DMA Operations
+
+def tpu.dmarrrr {
+ mnemonics = "[smem:${dst}], [sflag:${sflag}] = dma.local${pred} [hbm:${src}], ${len}";
+ outs = ();
+ ins = (GPR:$dst, GPR:$sflag, GPR:$src, GPR:$len);
+ pattern = (lavm.int_tpu_dma_hbm_to_smem GPR:$sflag, GPR:$src, GPR:$dst, (i32 GPR:$len));
+ llvm = DMA_HBM_TO_SMEMrrrr;
+}
+def tpu.dmarirr {
+ mnemonics = "[smem:${dst}], [sflag:${sflag}] = dma.local${pred} [hbm:${src}], ${len}";
+ outs = ();
+ ins = (GPR:$dst, i32imm:$sflag, GPR:$src, GPR:$len);
+ pattern = (lavm.int_tpu_dma_hbm_to_smem (Wrapper tglobaladdr:$sflag), GPR:$src, GPR:$dst, (i32 GPR:$len));
+ llvm = DMA_HBM_TO_SMEMrirr;
+}
+def tpu.dmariri {
+ mnemonics = "[smem:${dst}], [sflag:${sflag}] = dma.local${pred} [hbm:${src}], ${len}";
+ outs = ();
+ ins = (GPR:$dst, i32imm:$sflag, GPR:$src, i32imm:$len);
+ pattern = (lavm.int_tpu_dma_hbm_to_smem (Wrapper tglobaladdr:$sflag), GPR:$src, GPR:$dst, (i32 imm:$len));
+ llvm = DMA_HBM_TO_SMEMriri;
+}
+def tpu.dmarrri {
+ mnemonics = "[smem:${dst}], [sflag:${sflag}] = dma.local${pred} [hbm:${src}], ${len}";
+ outs = ();
+ ins = (GPR:$dst, GPR:$sflag, GPR:$src, i32imm:$len);
+ pattern = (lavm.int_tpu_dma_hbm_to_smem GPR:$sflag, GPR:$src, GPR:$dst, (i32 imm:$len));
+ llvm = DMA_HBM_TO_SMEMrrri;
+}
++ DMA_HBM_TO_TMEM, DMA_HBM_TO_VMEM, DMA_HBM_TO_SPMEM, DMA_HBM_TO_HBM, DMA_SMEM_TO_HBM, DMA_VMEM_TO_HBM,
+ DMA_TMEM_TO_HBM, DMA_SPMEM_TO_HBM, DMA_SPMEM_TO_SPMEM, DMA_TMEM_TO_SPMEM, DMA_SPMEM_TO_TMEM
+
+// TODO(thomasraoux): Mark those instructions as using all Vs slots.
+multiclass DMAStrided<string srcmem, string dstmem, DAGOperand FlagT,
+ PatFrag PatFlagType, DAGOperand LenT, DAGOperand PatLenType,
+ Intrinsic strided_intr =
+ !cast<Intrinsic>("int_tpu_dma_"#srcmem#"_to_"#dstmem#"_single_strided")> {
+ def "" : TPUInstP<(outs), (ins GPR:$dst, FlagT:$sflag, GPR:$src, LenT:$len,
+ GPR:$srcs, GPR:$dsts, GPR:$els),
+ "["#dstmem#":${dst}@${dsts}], [sflag:${sflag}] = dma.strided${pred} ["#srcmem#":${src}@${srcs}], length:${len}, elements_per_stride:${els}",
+ [(strided_intr PatFlagType, GPR:$src, GPR:$dst, (i32 PatLenType:$len),
+ GPR:$srcs, GPR:$dsts, GPR:$els)]>,
+ Bundle<B_Sboth>, Sched<[WriteDmaLocal]>;
+}
+
+multiclass DMAGeneral<string srcmem, string dstmem, DAGOperand LenT,
+ DAGOperand PatLenType, DAGOperand DescT, DAGOperand PatDescType,
+ Intrinsic general_intr =
+ !cast<Intrinsic>("int_tpu_dma_"#srcmem#"_to_"#dstmem#"_general")> {
+ def "" : TPUInstP<(outs), (ins GPR:$dst, GPR:$dstflags, GPR:$src,
+ GPR:$sflag, LenT:$len, DescT:$desc,
+ i32imm:$scount, GPR:$override),
+ "["#dstmem#":${dst}], [sflag:${dstflags}] = dma.general${pred} ["#srcmem#":${src}], [sflag:${sflag}], length:${len}, [smem:${desc}], stride_count:${scount}, ici_dest:${override}",
+ [(general_intr GPR:$dstflags, GPR:$src, GPR:$dst, PatLenType:$len, GPR:$sflag,
+ (i32 PatDescType:$desc), (i32 imm:$scount), GPR:$override)]>,
+ Bundle<B_Sboth>, Sched<[WriteDmaLocal]>;
+}
+
+multiclass DMA_Extended<string srcmem, string dstmem> {
+defm STRIDEDrr : DMAStrided<srcmem, dstmem, GPR, (i32 GPR:$sflag), GPR, GPR>;
+defm STRIDEDri :
+ DMAStrided<srcmem, dstmem, i32imm, (Wrapper tglobaladdr:$sflag), GPR, GPR>,
+ BundleImmSy;
+defm STRIDEDir : DMAStrided<srcmem, dstmem, GPR, (i32 GPR:$sflag), i32imm, imm>,
+ BundleImmSy;
+defm STRIDEDii :
+ DMAStrided<srcmem, dstmem, i32imm, (Wrapper tglobaladdr:$sflag), i32imm, imm>,
+ BundleImmSy<[IMM_OP_0, IMM_OP_1]>;
+
+defm GENERALrr : DMAGeneral<srcmem, dstmem, GPR, GPR, GPR, GPR>;
+defm GENERALri : DMAGeneral<srcmem, dstmem, i32imm, imm, GPR, GPR>, BundleImmSy;
+defm GENERALir : DMAGeneral<srcmem, dstmem, GPR, GPR, i32imm, imm>, BundleImmSy;
+defm GENERALii : DMAGeneral<srcmem, dstmem, i32imm, imm, i32imm, imm>, BundleImmSy<[IMM_OP_0, IMM_OP_1]>;
+}
+defm DMA_HBM_TO_SMEM : DMA_Extended<"hbm", "smem">;
+defm DMA_HBM_TO_VMEM : DMA_Extended<"hbm", "vmem">;
+defm DMA_SMEM_TO_HBM : DMA_Extended<"smem", "hbm">;
+defm DMA_VMEM_TO_HBM : DMA_Extended<"vmem", "hbm">;
+
+let hasSideEffects = 1 in {
+// DMA Descriptor
+def DMADescr : TPUInstP<(outs), (ins GPR:$desc),
+ "_ = dma.desc${pred} [smem:${desc}]",
+ [(int_tpu_dma_descriptor GPR:$desc)]>,
+ Bundle<B_S1>, Sched<[WriteDmaLocal]>;
+def DMADesci : TPUInstP<(outs), (ins i32imm:$desc),
+ "_ = dma.desc${pred} [smem:${desc}]",
+ [(int_tpu_dma_descriptor (i32 imm:$desc))]>,
+ Bundle<B_S1>, BundleImmSy, Sched<[WriteDmaLocal]>;
+}
+} // mayLoad = 1, mayStore = 1
+
+let hasSideEffects = 1, mayLoad = 1, mayStore = 1 in {
+// Note that this instruction is not predicated. The predicate goes at the end of
+// the operand uses list, and the length of this list is not known at compile time.
+def EVENT : TPUInst<(outs), (ins i32imm:$tag, reglist:$regs, variable_ops),
+ "_ = event $tag$regs", []>, Bundle<B_Misc>;
+def EVENT_NULLARY : TPUInst<(outs), (ins i32imm:$tag),
+ "_ = event $tag", []>, Bundle<B_Misc>;
+}
+
+def SDELAY : TPUInstP<(outs), (ins i32imm:$cycles),
+ "_ = sdelay${pred} $cycles", []>, Bundle<B_Sany>;
+def VDELAY : TPUInstP<(outs), (ins i32imm:$cycles),
+ "_ = vdelay${pred} $cycles", []>, Bundle<B_Misc>, IsVectorInstruction;
+// For delay longer than 8 cycles we need to use immediate slots.
+def VDELAY_LONG : TPUInstP<(outs), (ins i32imm:$cycles),
+ "_ = vdelay${pred} $cycles", []>, Bundle<B_Misc>,
+ BundleImmVy<[IMM_OP_0], IMM_2_to_5>, IsVectorInstruction;
+} // Predicates = [NotBC]
+
+//===----------------------------------------------------------------------===//
+// Vector ALU ops
+//===----------------------------------------------------------------------===//
+multiclass VIntALUOp<VIntALUOpEncoding enc, string Name, PatFrag OpNode, BundleSlot Slot, string XY> {
+ // Register-immediate - full 32-bit immediate.
+ defm ri : TPUInst<Slot, enc, (outs VPR_AGG:$Vd), (ins VPR_AGG:$x, i32imm:$y),
+ "$Vd = " # Name # "${pred} " # XY,
+ [(set (vNi32 VPR_AGG:$Vd), (OpNode (vNi32 VPR_AGG:$x),
+ (vNi32 (Splat imm:$y))))]>,
+ BundleImmVy, IsVectorInstruction;
+ // Register-scalar - splat a scalar into all lanes of a vector.
+ defm rs : TPUInst<Slot, enc, (outs VPR_AGG:$Vd), (ins VPR_AGG:$x, GPR:$y),
+ "$Vd = " # Name # "${pred} " # XY,
+ [(set (vNi32 VPR_AGG:$Vd), (OpNode (vNi32 VPR_AGG:$x),
+ (vNi32 (Splat (i32 GPR:$y)))))]>,
+ IsVectorInstruction;
+ // Register-register.
+ defm rr : TPUInst<Slot, enc, (outs VPR_AGG:$Vd), (ins VPR_AGG:$x, VPR_AGG:$y),
+ "$Vd = " # Name # "${pred} " # XY,
+ [(set (vNi32 VPR_AGG:$Vd), (OpNode (vNi32 VPR_AGG:$x),
+ (vNi32 VPR_AGG:$y)))]>,
+ IsVectorInstruction;
+}
+multiclass VFPALUOp<VIntALUOpEncoding enc, string Name, PatFrag OpNode, BundleSlot Slot, string XY> {
+ defm ri : TPUInst<Slot, enc, (outs VPR_AGG:$Vd), (ins VPR_AGG:$x, tpuf32imm:$y),
+ "$Vd = " # Name # "${pred} " # XY,
+ [(set (vNf32 VPR_AGG:$Vd), (OpNode (vNf32 VPR_AGG:$x),
+ (vNf32 (Splat fpimm:$y))))]>,
+ BundleImmVy, IsVectorInstruction;
+ defm rs : TPUInst<Slot, enc, (outs VPR_AGG:$Vd), (ins VPR_AGG:$x, GPR:$y),
+ "$Vd = " # Name # "${pred} " # XY,
+ [(set (vNf32 VPR_AGG:$Vd), (OpNode (vNf32 VPR_AGG:$x),
+ (vNf32 (Splat (f32 GPR:$y)))))]>,
+ IsVectorInstruction;
+ defm rr : TPUInst<Slot, enc, (outs VPR_AGG:$Vd), (ins VPR_AGG:$x, VPR_AGG:$y),
+ "$Vd = " # Name # "${pred} " # XY,
+ [(set (vNf32 VPR_AGG:$Vd), (OpNode (vNf32 VPR_AGG:$x),
+ (vNf32 VPR_AGG:$y)))]>,
+ IsVectorInstruction;
+}
+
+multiclass VIntALUOpXY<bits<6> opc, string Name, PatFrag OpNode, BundleSlot Slot = B_Vany> :
+ VIntALUOp<VIntALUOpEncoding<opc>, Name, OpNode, Slot, "$x, $y">;
+multiclass VIntALUOpYX<bits<6> opc, string Name, PatFrag OpNode, BundleSlot Slot = B_Vany> :
+ VIntALUOp<VIntALUOpEncoding<opc>, Name, OpNode, Slot, "$y, $x">;
+multiclass VFPALUOpXY<bits<6> opc, string Name, PatFrag OpNode, BundleSlot Slot = B_Vany> :
+ VFPALUOp<VIntALUOpEncoding<opc>, Name, OpNode, Slot, "$x, $y">;
+multiclass VFPALUOpYX<bits<6> opc, string Name, PatFrag OpNode, BundleSlot Slot = B_Vany> :
+ VFPALUOp<VIntALUOpEncoding<opc>, Name, OpNode, Slot, "$y, $x">;
+
+// Because operand order matters for VSUB (it is not commutative), define it separately.
+// See description for scalar ALU ops; the same applies here.
+defm VSUBir : TPUInst<B_Vany, VIntALUOpEncoding<1>, (outs VPR_AGG:$Vd), (ins i32imm:$y, VPR_AGG:$x),
+ "$Vd = vsub.s32${pred} $y, $x",
+ [(set VPR_AGG:$Vd, (sub (vNi32 (Splat imm:$y)),
+ (vNi32 VPR_AGG:$x)))],
+ YOpIdx1>,
+ BundleImmVy, IsVectorInstruction;
+defm VSUBsr : TPUInst<B_Vany, VIntALUOpEncoding<1>, (outs VPR_AGG:$Vd), (ins GPR:$y, VPR_AGG:$x),
+ "$Vd = vsub.s32${pred} $y, $x",
+ [(set VPR_AGG:$Vd, (sub (vNi32 (Splat GPR:$y)),
+ (vNi32 VPR_AGG:$x)))],
+ YOpIdx1>, IsVectorInstruction;
+defm VSUBrr : TPUInst<B_Vany, VIntALUOpEncoding<1>, (outs VPR_AGG:$Vd), (ins VPR_AGG:$y, VPR_AGG:$x),
+ "$Vd = vsub.s32${pred} $y, $x",
+ [(set VPR_AGG:$Vd, (sub (vNi32 VPR_AGG:$y),
+ (vNi32 VPR_AGG:$x)))],
+ YOpIdx1>, IsVectorInstruction;
+
+// Non-commutative version of VFPALUOpYX.
+multiclass VFPALUOpYX_NC<bits<6> opc, string Name, PatFrag OpNode, BundleSlot Slot = B_Vany> {
+defm ir : TPUInst<Slot, VIntALUOpEncoding<opc>, (outs VPR_AGG:$Vd), (ins tpuf32imm:$y, VPR_AGG:$x),
+ "$Vd = " # Name # "${pred} $y, $x",
+ [(set (vNf32 VPR_AGG:$Vd), (OpNode (vNf32 (Splat fpimm:$y)),
+ (vNf32 VPR_AGG:$x)))],
+ YOpIdx1>,
+ BundleImmVy, IsVectorInstruction;
+defm sr : TPUInst<Slot, VIntALUOpEncoding<opc>, (outs VPR_AGG:$Vd), (ins GPR:$y, VPR_AGG:$x),
+ "$Vd = " # Name # "${pred} $y, $x",
+ [(set (vNf32 VPR_AGG:$Vd), (OpNode (vNf32 (Splat GPR:$y)),
+ (vNf32 VPR_AGG:$x)))],
+ YOpIdx1>, IsVectorInstruction;
+defm rr : TPUInst<Slot, VIntALUOpEncoding<opc>, (outs VPR_AGG:$Vd), (ins VPR_AGG:$y, VPR_AGG:$x),
+ "$Vd = " # Name # "${pred} $y, $x",
+ [(set (vNf32 VPR_AGG:$Vd), (OpNode (vNf32 VPR_AGG:$y),
+ (vNf32 VPR_AGG:$x)))],
+ YOpIdx1>, IsVectorInstruction;
+}
+
+let isMoveReg = 1 in {
+defm VMOVr : TPUInstVany<VIntALUUnOpEncoding<31>, (outs VPR:$Vd), (ins VPR:$y),
+ "$Vd = vmov${pred} $y", [], YOpIdx1>, IsVectorInstruction;
+}
+defm VMOVs : TPUInstVany<VIntALUUnOpEncoding<31>, (outs VPR:$Vd), (ins GPR:$y),
+ "$Vd = vmov${pred} $y",
+ [(set VPR:$Vd, (vNf32 (Splat GPR:$y)))], YOpIdx1>, IsVectorInstruction;
+let isMoveImm = 1 in {
+defm VIMMF : TPUInstVany<VIntALUUnOpEncoding<31>, (outs VPR:$Vd), (ins tpuf32imm:$y),
+ "$Vd = vimm.f32${pred} $y",
+ [(set VPR:$Vd, (vNf32 (Splat fpimm:$y)))], YOpIdx1>,
+ BundleImmVy, IsVectorInstruction;
+// Need a special version for the integer case, as immediates are handled differently.
+defm VIMMI : TPUInstVany<VIntALUUnOpEncoding<31>, (outs VPR:$Vd), (ins i32imm:$y),
+ "$Vd = vimm.s32${pred} $y",
+ [(set VPR:$Vd, (vNi32 (Splat (i32 imm:$y))))], YOpIdx1>,
+ BundleImmVy, IsVectorInstruction;
+}
+def : Pat<(vNi32 (Splat GPR:$y)), (VMOVs GPR:$y)>;
+
+let Predicates = [HasVPU] in {
+defm VADD : VIntALUOpYX<0, "vadd.s32", add>;
+defm VAND : VIntALUOpYX<2, "vand.u32", and>;
+defm VOR : VIntALUOpYX<3, "vor.u32", or>;
+defm VXOR : VIntALUOpYX<4, "vxor.u32", xor>;
+defm VFADD : VFPALUOpYX<5, "vadd.f32", fadd, B_V1>, Sched<[WriteFadd]>;
+defm VFSUB : VFPALUOpYX_NC<6, "vsub.f32", fsub, B_V1>, Sched<[WriteFadd]>;
+defm VFMUL : VFPALUOpYX<7, "vmul.f32", fmul, B_V0>, Sched<[WriteFmul]>;
+defm VFMAX : VFPALUOpXY<8, "vmax.f32", fmaximum>;
+defm VFMIN : VFPALUOpXY<9, "vmin.f32", fminimum>;
+defm VSHL : VIntALUOpXY<10, "vshll.u32", shl, B_V1>;
+defm VSRL : VIntALUOpXY<11, "vshrl.u32", srl, B_V1>;
+defm VSRA : VIntALUOpXY<12, "vshra.s32", sra, B_V1>;
+defm VRSRA : VIntALUOpXY<13, "vrshra.s32", int_tpu_vrshra, B_V1>, Sched<[WriteVrshra]>;
+
+def : Pat<(vNi32 (int_tpu_shll (vNi32 VPR:$lhs), (vNi32 VPR:$rhs))),
+ (VSHLrr VPR:$lhs, VPR:$rhs)>;
+def : Pat<(vNi32 (int_tpu_shll (vNi32 VPR:$lhs), (vNi32 (Splat (i32 imm:$rhs))))),
+ (VSHLri VPR:$lhs, (i32 imm:$rhs))>;
+def : Pat<(vNi32 (int_tpu_shll (vNi32 VPR:$lhs), (vNi32 (Splat (i32 GPR:$rhs))))),
+ (VSHLrs VPR:$lhs, GPR:$rhs)>;
+def : Pat<(vNi32 (int_tpu_shrl (vNi32 VPR:$lhs), (vNi32 VPR:$rhs))),
+ (VSRLrr VPR:$lhs, VPR:$rhs)>;
+def : Pat<(vNi32 (int_tpu_shrl (vNi32 VPR:$lhs), (vNi32 (Splat (i32 imm:$rhs))))),
+ (VSRLri VPR:$lhs, (i32 imm:$rhs))>;
+def : Pat<(vNi32 (int_tpu_shrl (vNi32 VPR:$lhs), (vNi32 (Splat (i32 GPR:$rhs))))),
+ (VSRLrs VPR:$lhs, GPR:$rhs)>;
+def : Pat<(vNi32 (int_tpu_shra (vNi32 VPR:$lhs), (vNi32 VPR:$rhs))),
+ (VSRArr VPR:$lhs, VPR:$rhs)>;
+def : Pat<(vNi32 (int_tpu_shra (vNi32 VPR:$lhs), (vNi32 (Splat (i32 imm:$rhs))))),
+ (VSRAri VPR:$lhs, (i32 imm:$rhs))>;
+def : Pat<(vNi32 (int_tpu_shra (vNi32 VPR:$lhs), (vNi32 (Splat (i32 GPR:$rhs))))),
+ (VSRArs VPR:$lhs, (i32 GPR:$rhs))>;
+
+defm VCLAMPZ : VFPALUOpXY<30, "vclamp.gez.f32", Relu>;
+}
+let Predicates = [HasPxcVPU] in {
+defm VCLAMPS : VFPALUOpXY<54, "vclamps.f32", int_tpu_clamp_symmetric>;
+}
+
+// XOR can go in slot0 or 1 while fsub can only go in slot1.
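+// (For f32, fneg only flips the IEEE-754 sign bit, so xor with 0x80000000 is exact.)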
+def : Pat<(vNf32 (fneg VPR:$x)), (VXORri VPR:$x, (i32 0x80000000))>;
+
+//===----------------------------------------------------------------------===//
+// Vector conversion ops
+//===----------------------------------------------------------------------===//
+let Predicates = [NotBC, HasVPU] in {
+def VFPTOSIrr : TPUInstP<(outs VPR:$Sd), (ins VPR:$x, VPR:$y),
+ "$Sd = vcvt.f32.s32${pred} $x, $y",
+ [(set VPR:$Sd,
+ (int_tpu_cvt_fptosi (vNf32 VPR:$x), (vNi32 VPR:$y)))]>,
+ Bundle<B_Vany>, Sched<[WriteFPConvert]>, IsVectorInstruction;
+def VFPTOSIrs : TPUInstP<(outs VPR:$Sd), (ins VPR:$x, GPR:$y),
+ "$Sd = vcvt.f32.s32${pred} $x, $y",
+ [(set VPR:$Sd,
+ (int_tpu_cvt_fptosi (vNf32 VPR:$x), (vNi32 (Splat (i32 GPR:$y)))))]>,
+ Bundle<B_Vany>, Sched<[WriteFPConvert]>, IsVectorInstruction;
+def VFPTOSIri : TPUInstP<(outs VPR:$Sd), (ins VPR:$x, i32imm:$y),
+ "$Sd = vcvt.f32.s32${pred} $x, $y",
+ [(set VPR:$Sd,
+ (int_tpu_cvt_fptosi (vNf32 VPR:$x), (vNi32 (Splat (i32 imm:$y)))))]>,
+ Bundle<B_Vany>, BundleImmVy, Sched<[WriteFPConvert]>, IsVectorInstruction;
+def VSITOFPr : TPUInstP<(outs VPR:$Vd), (ins VPR:$Vs),
+ "$Vd = vcvt.s32.f32${pred} $Vs",
+ [(set VPR:$Vd, (sint_to_fp (vNi32 VPR:$Vs)))]>,
+ Bundle<B_Vany>, Sched<[WriteFPConvert]>, IsVectorInstruction;
+def : Pat<(vNi32 (fp_to_sint (vNf32 VPR:$x))), (VFPTOSIri VPR:$x, (i32 -1))>;
+}
+
+//===----------------------------------------------------------------------===//
+// Vector comparison ops
+//===----------------------------------------------------------------------===//
+multiclass VIntCompareOp<bits<6> opc, string Name, PatFrag OpNode> {
+ defm ri : TPUInst<B_Vany, VIntALUOpEncoding<opc>, (outs MPR:$Md), (ins VPR_AGG:$x, i32imm:$y),
+ "$Md = " # Name # "${pred} $x, $y",
+ [(set MPR:$Md, (OpNode (vNi32 VPR_AGG:$x), (Splat (i32 imm:$y))))]>,
+ BundleImmVy, IsVectorInstruction;
+ defm rs : TPUInst<B_Vany, VIntALUOpEncoding<opc>, (outs MPR:$Md), (ins VPR_AGG:$x, GPR:$y),
+ "$Md = " # Name # "${pred} $x, $y",
+ [(set MPR:$Md, (OpNode (vNi32 VPR_AGG:$x),
+ (vNi32 (Splat (i32 GPR:$y)))))]>,
+ IsVectorInstruction;
+ defm rr : TPUInst<B_Vany, VIntALUOpEncoding<opc>, (outs MPR:$Md), (ins VPR_AGG:$x, VPR_AGG:$y),
+ "$Md = " # Name # "${pred} $x, $y",
+ [(set MPR:$Md, (OpNode (vNi32 VPR_AGG:$x), (vNi32 VPR_AGG:$y)))]>,
+ IsVectorInstruction;
+}
+
+multiclass VFPCompareOp<bits<6> opc, string Name, PatFrag OpNode> {
+ defm ri : TPUInst<B_Vany, VIntALUOpEncoding<opc>, (outs MPR:$Md), (ins VPR_AGG:$x, tpuf32imm:$y),
+ "$Md = " # Name # "${pred} $x, $y",
+ [(set MPR:$Md, (OpNode (vNf32 VPR_AGG:$x), (Splat (f32 fpimm:$y))))]>,
+ BundleImmVy, IsVectorInstruction;
+ defm rs : TPUInst<B_Vany, VIntALUOpEncoding<opc>, (outs MPR:$Md), (ins VPR_AGG:$x, GPR:$y),
+ "$Md = " # Name # "${pred} $x, $y",
+ [(set MPR:$Md, (OpNode (vNf32 VPR_AGG:$x),
+ (vNf32 (Splat (f32 GPR:$y)))))]>,
+ IsVectorInstruction;
+ defm rr : TPUInst<B_Vany, VIntALUOpEncoding<opc>, (outs MPR:$Md), (ins VPR_AGG:$x, VPR_AGG:$y),
+ "$Md = " # Name # "${pred} $x, $y",
+ [(set MPR:$Md, (OpNode (vNf32 VPR_AGG:$x), (vNf32 VPR_AGG:$y)))]>,
+ IsVectorInstruction;
+}
+
+let Predicates = [HasVPU] in {
+defm VCMPEQ : VIntCompareOp<32, "veq.s32", seteq>;
+defm VCMPNE : VIntCompareOp<33, "vne.s32", setne>;
+defm VCMPGT : VIntCompareOp<34, "vgt.s32", setgt>;
+defm VCMPGE : VIntCompareOp<35, "vge.s32", setge>;
+defm VCMPLT : VIntCompareOp<36, "vlt.s32", setlt>;
+defm VCMPLE : VIntCompareOp<37, "vle.s32", setle>;
+defm VFCMPEQ : VFPCompareOp<40, "veq.f32", setoeq>;
+defm VFCMPNE : VFPCompareOp<41, "vne.f32", setune>;
+defm VFCMPGT : VFPCompareOp<42, "vgt.f32", setogt>;
+defm VFCMPGE : VFPCompareOp<43, "vge.f32", setoge>;
+defm VFCMPLT : VFPCompareOp<44, "vlt.f32", setolt>;
+defm VFCMPLE : VFPCompareOp<45, "vle.f32", setole>;
+}
+let Predicates = [NotBC, HasVPU] in {
+def VWEIRD : TPUInstP<(outs MPR:$Md), (ins VPR:$Vs),
+ "$Md = vweird.f32${pred} $Vs",
+ [(set MPR:$Md,
+ (int_tpu_weird (vNf32 VPR:$Vs)))]>, Bundle<B_Vany>,
+ IsVectorInstruction;
+}
+
+multiclass VFPComparePat<string OpName, PatFrag OpNode> {
+ def : Pat<(OpNode (vNf32 VPR:$x), (Splat (f32 fpimm:$y))),
+ (!cast<Instruction>(OpName#"ri") VPR:$x, tpuf32imm:$y)>;
+ def : Pat<(OpNode (vNf32 VPR:$x), (Splat (f32 GPR:$y))),
+ (!cast<Instruction>(OpName#"rs") VPR:$x, GPR:$y)>;
+ def : Pat<(OpNode (vNf32 VPR:$x), (vNf32 VPR:$y)),
+ (!cast<Instruction>(OpName#"rr") VPR:$x, VPR:$y)>;
+}
+// Patterns for the cases where we don't care about unordered.
+defm : VFPComparePat<"VFCMPEQ", seteq>;
+defm : VFPComparePat<"VFCMPNE", setne>;
+defm : VFPComparePat<"VFCMPGT", setgt>;
+defm : VFPComparePat<"VFCMPGE", setge>;
+defm : VFPComparePat<"VFCMPLT", setlt>;
+defm : VFPComparePat<"VFCMPLE", setle>;
+
+//===----------------------------------------------------------------------===//
+// Vector select and mask manipulation ops
+//===----------------------------------------------------------------------===//
+let Predicates = [HasVPU] in {
+// Note the order of these operands - the first operand is a VPR/GPR/imm32,
+// the second operand is always a VPR.
+defm VSELir : TPUInst<B_Vany, VIntALUVmselEncoding, (outs VPR:$Vd), (ins MPR:$m, i32imm:$y, VPR:$x),
+ "$Vd = vsel${pred} $m, $y, $x",
+ [(set VPR:$Vd, (vselect MPR:$m, (vNi32 (Splat imm:$y)),
+ (vNi32 VPR:$x)))]>,
+ BundleImmVy, IsVectorInstruction;
+defm VSELsr : TPUInst<B_Vany, VIntALUVmselEncoding, (outs VPR:$Vd), (ins MPR:$m, GPR:$y, VPR:$x),
+ "$Vd = vsel${pred} $m, $y, $x",
+ [(set VPR:$Vd, (vselect MPR:$m, (vNi32 (Splat GPR:$y)),
+ (vNi32 VPR:$x)))]>,
+ IsVectorInstruction;
+defm VSELrr : TPUInst<B_Vany, VIntALUVmselEncoding, (outs VPR:$Vd), (ins MPR:$m, VPR:$y, VPR:$x),
+ "$Vd = vsel${pred} $m, $y, $x",
+ [(set VPR:$Vd, (vselect MPR:$m, (vNi32 VPR:$y),
+ (vNi32 VPR:$x)))]>, IsVectorInstruction;
+}
+// Always make vsel use an integer immediate, since the instruction is
+// type-agnostic and we don't want to print an integer as a float.
+def : Pat<(vNf32 (vselect MPR:$m, (vNf32 (Splat fpimm:$y)), VPR:$x)),
+ (VSELir MPR:$m, (ftoi $y), VPR:$x)>;
+// Because rr and sr are type-agnostic, define the vNf32 variants as patterns.
+def : Pat<(vNf32 (vselect MPR:$m, VPR:$y, VPR:$x)),
+ (VSELrr MPR:$m, VPR:$y, VPR:$x)>;
+def : Pat<(vNf32 (vselect MPR:$m, (Splat GPR:$y), VPR:$x)),
+ (VSELsr MPR:$m, GPR:$y, VPR:$x)>;
+
+let Predicates = [NotBC] in {
+// Vector select with scalar predicate.
+let isPseudo = 1 in {
+ // Pseudo-select instruction. Note that this is lowered to either a predicated
+ // VMOVr or VMOVi.
+ let Constraints = "$d = $a" in {
+ def tcSELrr : TPUInst<(outs VPR:$d), (ins PPR:$p, VPR:$a, VPR:$b),
+ "$d = #SEL $p, $a, $b",
+ [(set VPR:$d, (select PPR:$p,
+ (vNi32 VPR:$a), (vNi32 VPR:$b)))]>,
+ Bundle<B_Sany>;
+ def tcSELri : TPUInst<(outs VPR:$d), (ins PPR:$p, VPR:$a, i32imm:$b),
+ "$d = #SEL $p, $a, $b",
+ [(set VPR:$d, (select PPR:$p, (vNi32 VPR:$a),
+ (vNi32 (Splat imm:$b))))]>,
+ Bundle<B_Sany>;
+ def tcSELrif : TPUInst<(outs VPR:$d), (ins PPR:$p, VPR:$a, tpuf32imm:$b),
+ "$d = #SEL $p, $a, $b",
+ [(set VPR:$d, (select PPR:$p, (vNf32 VPR:$a),
+ (Splat (f32 fpimm:$b))))]>,
+ Bundle<B_Sany>;
+ }
+ let Constraints = "$d = $b" in {
+ def tcSELir : TPUInst<(outs VPR:$d), (ins PPR:$p, i32imm:$a, VPR:$b),
+ "$d = #SEL $p, $a, $b",
+ [(set VPR:$d, (select PPR:$p,
+ (vNi32 (Splat imm:$a)), (vNi32 VPR:$b)))]>,
+ Bundle<B_Sany>;
+ def tcSELirf : TPUInst<(outs VPR:$d), (ins PPR:$p, tpuf32imm:$a, VPR:$b),
+ "$d = #SEL $p, $a, $b",
+ [(set VPR:$d, (select PPR:$p,
+ (vNf32 (Splat (f32 fpimm:$a))), (vNf32 VPR:$b)))]>,
+ Bundle<B_Sany>;
+ }
+}
+
+def : Pat<(vNf32 (select PPR:$p, VPR:$y, VPR:$x)),
+ (tcSELrr PPR:$p, VPR:$y, VPR:$x)>;
+
+def VMMOV : TPUInstP<(outs MPR:$Md), (ins MPR:$Ms), "$Md = vmmov${pred} $Ms",
+ []>,
+ Bundle<B_Misc>;
+def VMNEG : TPUInstP<(outs MPR:$Md), (ins MPR:$Ms), "$Md = vmneg${pred} $Ms",
+ [(set MPR:$Md, (vnot MPR:$Ms))]>,
+ Bundle<B_Misc>;
+// TODO(jmolloy): Fix ISS so these aren't tied on VFC.
+let Constraints = "$Md = $Ms" in {
+def VMAND : TPUInstP<(outs MPR:$Md), (ins MPR:$Ms, MPR:$Mt),
+ "$Md = vmand${pred} $Mt",
+ [(set MPR:$Md, (and MPR:$Ms, MPR:$Mt))]>,
+ Bundle<B_Misc>;
+def VMOR : TPUInstP<(outs MPR:$Md), (ins MPR:$Ms, MPR:$Mt),
+ "$Md = vmor${pred} $Mt",
+ [(set MPR:$Md, (or MPR:$Ms, MPR:$Mt))]>,
+ Bundle<B_Misc>;
+def VMXOR : TPUInstP<(outs MPR:$Md), (ins MPR:$Ms, MPR:$Mt),
+ "$Md = vmxor${pred} $Mt",
+ [(set MPR:$Md, (xor MPR:$Ms, MPR:$Mt))]>,
+ Bundle<B_Misc>;
+}
+let isPseudo = 1 in {
+def VMZERO : TPUInstP<(outs MPR:$Md), (ins),
+ "$Md = #VMZERO${pred}",
+ [(set MPR:$Md, (vNi1 (Splat 0)))]>, Bundle<B_Misc>;
+}
+def : Pat<(xor MPR:$m, (Splat -1)), (VMNEG MPR:$m)>;
+def : Pat<(vNi1 (Splat -1)), (VMNEG (VMZERO))>;
+
+let isPseudo = 1, usesCustomInserter = 1 in {
+def VMREAD : TPUInstP<(outs VPR:$d), (ins MPR:$m),
+ "$d = #VMREAD${pred} $m",
+ [(set VPR:$d, (vNi32 (zext MPR:$m)))]>;
+}
+
+let isPseudo = 1, isReMaterializable = 1 in {
+def VMLANE : TPUInstP<(outs MPR:$m), (ins i32imm:$lane),
+ "$m = #VMLANE${pred} $lane",
+ []>;
+}
+
+let mayLoad = 1, mayStore = 1 in {
+let isPush = 1 in
+ def VPUSH : TPUInstP<(outs V2SFPR:$v2s), (ins VPR:$v),
+ "$v2s = vpush${pred} $v", []>,
+ Bundle<B_VST>, Sched<[WriteV2SF]>;
+
+let isPop = 1 in
+ def SPOP : TPUInstP<(outs GPR:$s), (ins V2SFPR:$v2s),
+ "$s = spop${pred} $v2s", []>,
+ Bundle<B_Sany>, Sched<[WriteV2SFPop]>;
+} // mayLoad = 1, mayStore = 1
+} // Predicates = [NotBC]
+
+//===----------------------------------------------------------------------===//
+// Vector manipulation ops
+//===----------------------------------------------------------------------===//
+class UnaryOp<string Name, SDPatternOperator Intr, ValueTypeByHwMode VDstType,
+ ValueTypeByHwMode SrcType> :
+ TPUInstP<(outs VPR:$Vd), (ins VPR:$x), "$Vd = "#Name#"${pred} $x",
+ [(set (VDstType VPR:$Vd), (Intr (SrcType VPR:$x)))]>, Bundle<B_Vany>, IsVectorInstruction;
+class UnaryVFOp<string Name, SDPatternOperator Intr> : UnaryOp<Name, Intr, vNf32, vNf32>;
+class UnaryVIOp<string Name, SDPatternOperator Intr> : UnaryOp<Name, Intr, vNi32, vNi32>;
+
+let Predicates = [NotBC, HasVPU] in {
+// Iota for vectors - produces {0, 1, 2, 3, 4, 5, 6, 7, ...}
+def VLANESEQ : TPUInstP<(outs VPR:$Vd), (ins),
+ "$Vd = vlaneseq.u32${pred}",
+ [(set (vNi32 VPR:$Vd), (int_tpu_vlaneseq))]>,
+ Bundle<B_Vany>, IsVectorInstruction;
+def VROTDOWNr : UnaryVIOp<"vrot.slane.down", int_tpu_vrot_sublane_down>,
+ Sched<[WriteRotateSLane]>;
+def : Pat<(vNf32 (int_tpu_vrot_sublane_down (vNf32 VPR:$x))),
+ (VROTDOWNr (vNf32 VPR:$x))>;
+
+// There is no VROTDOWNri instruction in the ISA yet. This is a pseudo that will
+// be expanded to N VROTDOWNr's early.
+let isCodeGenOnly = 1, usesCustomInserter = 1 in {
+def VROTDOWNri : TPUInstP<(outs VPR:$Vd), (ins VPR:$x, i32imm:$n),
+ "$Vd = #VROTDOWNri${pred} $x, $n",
+ []>,
+ Bundle<B_Vany>, IsVectorInstruction;
+}
+def : Pat<(Vrotdown (vNf32 VPR:$x), imm:$n), (VROTDOWNri VPR:$x, imm:$n)>;
+def : Pat<(Vrotdown (vNi32 VPR:$x), imm:$n), (VROTDOWNri VPR:$x, imm:$n)>;
+
+def VBROADCASTr : TPUInstP<(outs VPR:$Vd), (ins VPR:$v, GPR:$i),
+ "$Vd = vbroadcast${pred} $v, $i",
+ [(set VPR:$Vd, (Vbroadcast (vNi32 VPR:$v), (i32 GPR:$i)))]>,
+ Bundle<B_Vany>, IsVectorInstruction;
+
+def VBROADCASTi : TPUInstP<(outs VPR:$Vd), (ins VPR:$v, i32imm:$i),
+ "$Vd = vbroadcast${pred} $v, $i",
+ [(set VPR:$Vd, (Vbroadcast (vNi32 VPR:$v), (i32 imm:$i)))]>,
+ Bundle<B_Vany>, IsVectorInstruction;
+def VPOPCNTr : UnaryVIOp<"vpcnt", ctpop>;
+def VCLZr : UnaryVIOp<"vclz", ctlz>;
+def VEXPONENTr : UnaryOp<"vf32.e.s32", int_tpu_exponent, vNi32, vNf32>;
+defm VCOMPOSE : VFPALUOpYX_NC<27, "vf32.f32", int_tpu_compose>, Sched<[WriteFloatCompose]>;
+def VSIGNIFICANDr : UnaryOp<"vf32.s.s32", int_tpu_significand, vNi32, vNf32>;
+defm VPACK : VFPALUOpYX_NC<28, "vpack.f32.f16", int_tpu_pack>, Sched<[WritePackingInst]>;
+defm VPACKC : VFPALUOpYX_NC<55, "vpackc.f32.f16", int_tpu_packc>, Sched<[WritePackingInst]>;
+def VUNPACKU : UnaryVFOp<"vunpacku", int_tpu_unpacku>, Sched<[WritePackingInst]>;
+def VUNPACKL : UnaryVFOp<"vunpackl", int_tpu_unpackl>, Sched<[WritePackingInst]>;
+
+// Pseudo instruction to read the zeroth vector element. Will be expanded to
+// vpush; spop.
+let isCodeGenOnly = 1, usesCustomInserter = 1 in {
+def VREAD : TPUInstP<(outs GPR:$d), (ins VPR:$v),
+ "$d = vread${pred} $v",
+ [(set GPR:$d, (f32 (extractelt (vNf32 VPR:$v), (i32 0))))]>;
+}
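+// Expected expansion of VREAD (a sketch based on the VPUSH/SPOP definitions above):
+//   $v2s = vpush${pred} $v
+//   $d = spop${pred} $v2s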
+def : Pat<(extractelt (vNi32 VPR:$v), (i32 0)), (VREAD VPR:$v)>;
+def : Pat<(extractelt (vNi1 MPR:$m), (i32 0)),
+ (CMPNEri (VREAD (VSELir MPR:$m, (i32 1), (VIMMI 0))), (i32 0))>;
+
+// These patterns only work for SparseCore. Rotating the lanes on TensorCore
+// requires a different sequence.
+let Predicates = [HasV8] in {
+def : Pat<(extractelt (vNi32 VPR:$v), imm:$n), (VREAD (VROTDOWNri VPR:$v, imm:$n))>;
+def : Pat<(extractelt (vNf32 VPR:$v), imm:$n), (VREAD (VROTDOWNri VPR:$v, imm:$n))>;
+}
+
+// Remove unnecessary bitcasts.
+def : Pat<(vNi32 (bitconvert (vNf32 VPR:$value))), (vNi32 VPR:$value)>;
+def : Pat<(vNf32 (bitconvert (vNi32 VPR:$value))), (vNf32 VPR:$value)>;
+def : Pat<(i32 (bitconvert (f32 GPR:$value))), (i32 GPR:$value)>;
+def : Pat<(f32 (bitconvert (i32 GPR:$value))), (f32 GPR:$value)>;
+
+def : Pat<(vNi32 (zext (vNi1 MPR:$x))), (VSELir MPR:$x, (i32 1), (VIMMI 0))>;
+def : Pat<(vNi32 (sext (vNi1 MPR:$x))), (VSELir MPR:$x, (i32 -1), (VIMMI 0))>;
+def : Pat<(vNi1 (trunc (vNi32 VPR:$x))),
+ (VCMPEQri (VANDri $x, (i32 1)), (i32 1))>;
+
+def : Pat<(int_tpu_make_restrict_ptr (i32 GPR:$value)),
+ (i32 GPR:$value)>;
+def : Pat<(int_tpu_make_restrict_ptr_f (i32 GPR:$value)),
+ (i32 GPR:$value)>;
+} // Predicates = [NotBC]
+
+
+//===----------------------------------------------------------------------===//
+// Vector extended unary (EUP)
+//===----------------------------------------------------------------------===//
+let mayLoad = 1, mayStore = 1, usesCustomInserter = 1 in {
+let isPush = 1 in {
+defm VRSQRT : TPUInstVany<VIntALUEupOpEncoding<48>, (outs ERFPR:$eup), (ins VPR:$x),
+ "$eup = vrsqrt.f32${pred} $x",
+ [(set ERFPR:$eup, (int_tpu_rsqrt (vNf32 VPR:$x)))],
+ YOpIdxNone>,
+ Sched<[WriteEup]>, IsVectorInstruction;
+defm VPOW2 : TPUInstVany<VIntALUEupOpEncoding<49>, (outs ERFPR:$eup), (ins VPR:$x),
+ "$eup = vpow2.f32${pred} $x",
+ [(set ERFPR:$eup, (int_tpu_pow2 (vNf32 VPR:$x)))],
+ YOpIdxNone>,
+ Sched<[WriteEup]>, IsVectorInstruction;
+defm VLOG2 : TPUInstVany<VIntALUEupOpEncoding<50>, (outs ERFPR:$eup), (ins VPR:$x),
+ "$eup = vlog2.f32${pred} $x",
+ [(set ERFPR:$eup, (int_tpu_log2 (vNf32 VPR:$x)))],
+ YOpIdxNone>,
+ Sched<[WriteEup]>, IsVectorInstruction;
+defm VTANH : TPUInstVany<VIntALUEupOpEncoding<51>, (outs ERFPR:$eup), (ins VPR:$x),
+ "$eup = vtanh.f32${pred} $x",
+ [(set ERFPR:$eup, (int_tpu_tanh (vNf32 VPR:$x)))],
+ YOpIdxNone>,
+ Sched<[WriteEup]>, IsVectorInstruction;
+defm VRCP : TPUInstVany<VIntALUEupOpEncoding<52>, (outs ERFPR:$eup), (ins VPR:$x),
+ "$eup = vrcp.f32${pred} $x",
+ [(set ERFPR:$eup, (int_tpu_rcp (vNf32 VPR:$x)))],
+ YOpIdxNone>,
+ Sched<[WriteEup]>, IsVectorInstruction;
+defm VPUSH_EUP : TPUInstVany<VIntALUEupOpEncoding<53>, (outs ERFPR:$eup), (ins VPR:$x),
+ "$eup = vpush${pred} $x",
+ [(set ERFPR:$eup, (int_tpu_eup_push (vNf32 VPR:$x)))],
+ YOpIdxNone>,
+ Sched<[WriteEup]>, IsVectorInstruction;
+} // isPush = 1
+
+let isPop = 1 in {
+
+defm VRES_EUP :
+ TPUInstVResAny<
+ VRES_EUPEncoding,
+ (outs VPR:$v), (ins ERFPR:$eup), "$v = vpop${pred} $eup",
+ [(set (vNf32 VPR:$v), (int_tpu_eup_pop (i32 ERFPR:$eup)))]>,
+ Sched<[WriteEupPop]>, IsVectorInstruction;
+
+} // isPop = 1
+} // mayLoad = 1, mayStore = 1, usesCustomInserter = 1
+
+//===----------------------------------------------------------------------===//
+// SFlag instructions
+//===----------------------------------------------------------------------===//
+// Workaround to allow defining patterns alongside instructions within a
+// multiclass. Some instruction properties are only set in the final definition,
+// so for the patterns to compile they inherit from this structure, which
+// declares the attributes that are set at the top-level definition and must
+// also be added to the patterns.
+class DummyInfo {
+ bit hasSideEffects;
+ bit mayStore;
+ bit mayLoad;
+}
+
+multiclass BaseSflagIntrinsicInst<string instr, Intrinsic intrinsic> {
+ def ii : TPUInstP<(outs), (ins i32imm:$targ, i32imm:$val),
+ instr,
+ [(intrinsic (Wrapper tglobaladdr:$targ), (i32 imm16:$val))]>;
+ def : Pat<(intrinsic (i32 imm:$targ), (i32 imm16:$val)),
+ (!cast<Instruction>(NAME#"ii") (i32 imm:$targ), (i32 imm:$val))>,
+ DummyInfo;
+ def ri : TPUInstP<(outs), (ins GPR:$targ, i32imm:$val),
+ instr,
+ [(intrinsic (i32 GPR:$targ), (i32 imm16:$val))]>;
+
+ def ir : TPUInstP<(outs), (ins i32imm:$targ, GPR:$val),
+ instr,
+ [(intrinsic (Wrapper tglobaladdr:$targ), (i32 GPR:$val))]>;
+ def : Pat<(intrinsic (i32 imm:$targ), (i32 GPR:$val)),
+ (!cast<Instruction>(NAME#"ir") (i32 imm:$targ), (i32 GPR:$val))>,
+ DummyInfo;
+ def rr : TPUInstP<(outs), (ins GPR:$targ, GPR:$val),
+ instr,
+ [(intrinsic (i32 GPR:$targ), (i32 GPR:$val))]>;
+}
+
+// FlagIntrinsicInst represents any other VSYNCSET/VSYNCADD that
+// is mapped to an intrinsic.
+multiclass FlagIntrinsicInst<string mnemonic, Intrinsic intrinsic> :
+ BaseSflagIntrinsicInst<"[sflag:${targ}] = " # mnemonic # "${pred} $val",
+ intrinsic>;
+
+multiclass WaitInst<string mnemonic, Intrinsic intrinsic> :
+ BaseSflagIntrinsicInst<"_ = " # mnemonic # "${pred} [sflag:${targ}], $val",
+ intrinsic>;
+
+multiclass WaitDoneInst<string mnemonic, Intrinsic intrinsic> {
+ def i : TPUInstP<(outs), (ins i32imm:$imm),
+ "_ = " # mnemonic # "${pred} [sflag:$imm]",
+ [(intrinsic (Wrapper tglobaladdr:$imm))]>;
+ def : Pat<(intrinsic (i32 imm:$imm)),
+ (!cast<Instruction>(NAME#"i") (i32 imm:$imm))>, DummyInfo;
+ def r : TPUInstP<(outs), (ins GPR:$targ),
+ "_ = " # mnemonic # "${pred} [sflag:$targ]",
+ [(intrinsic GPR:$targ)]>;
+}
+
+// SFlagStoreInst represents a pure VSYNCSET that maps to a store_sflag.
+multiclass SFlagStoreInst<string mnemonic> {
+ def ii : TPUInstP<(outs), (ins i32imm:$targ, i32imm:$imm),
+ "[sflag:${targ}] = " # mnemonic # "${pred} $imm",
+ [(store_sflag (i32 imm16:$imm), (Wrapper tglobaladdr:$targ))]>;
+ def : Pat<(store_sflag (i32 imm16:$imm), (i32 imm:$targ)),
+ (!cast<Instruction>(NAME#"ii") (i32 imm:$targ), (i32 imm:$imm))>,
+ DummyInfo;
+ def ri : TPUInstP<(outs), (ins GPR:$targ, i32imm:$imm),
+ "[sflag:${targ}] = " # mnemonic # "${pred} $imm",
+ [(store_sflag (i32 imm16:$imm), (i32 GPR:$targ))]>;
+ def ir : TPUInstP<(outs), (ins i32imm:$targ, GPR:$val),
+ "[sflag:${targ}] = " # mnemonic # "${pred} $val",
+ [(store_sflag (i32 GPR:$val), (Wrapper tglobaladdr:$targ))]>;
+ def : Pat<(store_sflag (i32 GPR:$val), (i32 imm:$targ)),
+ (!cast<Instruction>(NAME#"ir") (i32 imm:$targ), (i32 GPR:$val))>,
+ DummyInfo;
+ def rr : TPUInstP<(outs), (ins GPR:$targ, GPR:$val),
+ "[sflag:${targ}] = " # mnemonic # "${pred} $val",
+ [(store_sflag (i32 GPR:$val), (i32 GPR:$targ))]>;
+}
+
+multiclass SyncInst<string prefix>{
+let hasSideEffects = 1 in {
+defm EQ : WaitInst<prefix#"wait.eq", int_tpu_waiteq>;
+defm NE : WaitInst<prefix#"wait.ne", int_tpu_waitne>;
+defm GT : WaitInst<prefix#"wait.gt", int_tpu_waitgt>;
+defm GE : WaitInst<prefix#"wait.ge", int_tpu_waitge>;
+defm LT : WaitInst<prefix#"wait.lt", int_tpu_waitlt>;
+defm LE : WaitInst<prefix#"wait.le", int_tpu_waitle>;
+
+defm DONE : WaitDoneInst<prefix#"wait.done", int_tpu_waitdone>;
+defm NOTDONE : WaitDoneInst<prefix#"wait.notdone", int_tpu_waitnotdone>;
+}
+}
+
+multiclass SetAddSyncFlag<string setMnemonic, string addMnemonic> {
+let mayStore = 1 in {
+defm SET : SFlagStoreInst<setMnemonic#".s32">;
+defm SET_DONE : FlagIntrinsicInst<setMnemonic#".done.s32",
+ int_tpu_syncset_done>;
+defm SET_NOTDONE : FlagIntrinsicInst<setMnemonic#".notdone.s32",
+ int_tpu_syncset_notdone>;
+defm SET_REMOTE : FlagIntrinsicInst<setMnemonic#".remote.s32",
+ int_tpu_syncset_remote>;
+defm SET_REMOTE_DONE : FlagIntrinsicInst<setMnemonic#".remote.done.s32",
+ int_tpu_syncset_remote_done>;
+defm SET_REMOTE_DONEINV : FlagIntrinsicInst<setMnemonic#".remote.doneinv.s32",
+ int_tpu_syncset_remote_doneinv>;
+}
+
+let mayLoad = 1, mayStore = 1 in {
+defm ADD : FlagIntrinsicInst<addMnemonic#".s32",
+ int_tpu_syncadd>;
+defm ADD_DONE : FlagIntrinsicInst<addMnemonic#".done.s32",
+ int_tpu_syncadd_done>;
+defm ADD_NOTDONE : FlagIntrinsicInst<addMnemonic#".notdone.s32",
+ int_tpu_syncadd_notdone>;
+defm ADD_REMOTE : FlagIntrinsicInst<addMnemonic#".remote.s32",
+ int_tpu_syncadd_remote>;
+defm ADD_REMOTE_DONE : FlagIntrinsicInst<addMnemonic#".remote.done.s32",
+ int_tpu_syncadd_remote_done>;
+defm ADD_REMOTE_DONEINV : FlagIntrinsicInst<addMnemonic#".remote.doneinv.s32",
+ int_tpu_syncadd_remote_doneinv>;
+}
+}
+
+//===----------------------------------------------------------------------===//
+// Call related instructions.
+//===----------------------------------------------------------------------===//
+def SDT_TPUCall : SDTypeProfile<0, -1, [SDTCisVT<0, i32>]>;
+def SDT_TPUCallSeqStart : SDCallSeqStart<[SDTCisVT<0, i32>,
+ SDTCisVT<1, i32>]>;
+def SDT_TPUCallSeqEnd : SDCallSeqEnd<[SDTCisVT<0, i32>, SDTCisVT<1, i32>]>;
+
+def callseq_start : SDNode<"ISD::CALLSEQ_START", SDT_TPUCallSeqStart,
+ [SDNPHasChain, SDNPOutGlue, SDNPSideEffect]>;
+def callseq_end : SDNode<"ISD::CALLSEQ_END", SDT_TPUCallSeqEnd,
+ [SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
+ SDNPSideEffect]>;
+
+
+def call : SDNode<"TPUISD::CALL", SDT_TPUCall,
+ [SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
+ SDNPVariadic]>;
+
+def calltarget : Operand<i32>;
+let isCall=1 in {
+ def CALL : TPUInstP<(outs), (ins calltarget:$dst), "_ = call${pred} $dst", []>;
+}
+def : Pat<(call tglobaladdr:$dst), (CALL tglobaladdr:$dst)>;
+def : Pat<(call texternalsym:$dst), (CALL texternalsym:$dst)>;
+
+let hasSideEffects=1 in {
+def CALLSEQ_START :
+ TPUInst<(outs), (ins i32imm:$amt1, i32imm:$amt2),
+ "Callseq_start $amt1",
+ [(callseq_start timm:$amt1, timm:$amt2)]>;
+def CALLSEQ_END :
+ TPUInst<(outs), (ins i32imm:$amt1, i32imm:$amt2),
+ "Callseq_end $amt1",
+ [(callseq_end timm:$amt1, timm:$amt2)]>;
+}
+
+ ##################################### Tensor Core td file
+
+//===----------------------------------------------------------------------===//
+// Flag instructions
+//===----------------------------------------------------------------------===//
+let Predicates = [HasVectorSflags, NotBC] in {
+// Mark all WAIT/SYNC instructions as using BundleVs. Not all the derived
+// instructions use the Vs slot, but the bundle tracker only reserves the Vs
+// slot for scalar registers, so it is okay to be conservative, and this
+// simplifies sharing the declaration with SparseCore.
+defm tcWAIT : SyncInst<"v">, Bundle<B_Misc>, IsVectorInstruction;
+defm tcSYNC : SetAddSyncFlag<"vsyncset", "vsyncadd">, Bundle<B_Misc>, IsVectorInstruction;
+
+let isPseudo = 1, usesCustomInserter = 1 in {
+def VFREADi : TPUInstP<(outs GPR:$d), (ins i32imm:$imm),
+ "$d =\t#SFLAGREAD${pred} [sflag:$imm]",
+ [(set GPR:$d, (i32 (load_sflag (Wrapper tglobaladdr:$imm))))]>,
+ Bundle<B_Misc>;
+def VFREADr : TPUInstP<(outs GPR:$d), (ins GPR:$r),
+ "$d =\t#SFLAGREAD${pred} [sflag:$r]",
+ [(set GPR:$d, (i32 (load_sflag GPR:$r)))]>,
+ Bundle<B_Misc>;
+def VFREADDONEi : TPUInstP<(outs GPR:$d), (ins i32imm:$imm),
+ "$d =\t#SFLAGREAD.done${pred} [sflag:$imm]",
+ [(set GPR:$d, (i32 (int_tpu_syncdonemov (Wrapper tglobaladdr:$imm))))]>,
+ Bundle<B_Misc>;
+def VFREADDONEr : TPUInstP<(outs GPR:$d), (ins GPR:$r),
+ "$d =\t#SFLAGREAD.done${pred} [sflag:$r]",
+ [(set GPR:$d, (i32 (int_tpu_syncdonemov GPR:$r)))]>,
+ Bundle<B_Misc>;
+}
+
+let mayLoad = 1, mayStore = 1 in {
+let isPush = 1 in {
+ def VSYNCMOVEi : TPUInstP<(outs V2SFPR:$v2s), (ins i32imm:$imm),
+ "$v2s =\tvsyncmov${pred} [sflag:$imm]", []>,
+ Bundle<B_Misc>, Sched<[WriteSFlagV2SF]>, IsVectorInstruction;
+ def VSYNCMOVEr : TPUInstP<(outs V2SFPR:$v2s), (ins GPR:$r),
+ "$v2s =\tvsyncmov${pred} [sflag:$r]", []>,
+ Bundle<B_Misc>, Sched<[WriteSFlagV2SF]>, IsVectorInstruction;
+ def VSYNCMOVEDONEi : TPUInstP<(outs V2SFPR:$v2s), (ins i32imm:$imm),
+ "$v2s =\tvsyncmov.done${pred} [sflag:$imm]", []>,
+ Bundle<B_Misc>, Sched<[WriteSFlagV2SF]>, IsVectorInstruction;
+ def VSYNCMOVEDONEr : TPUInstP<(outs V2SFPR:$v2s), (ins GPR:$r),
+ "$v2s =\tvsyncmov.done${pred} [sflag:$r]", []>,
+ Bundle<B_Misc>, Sched<[WriteSFlagV2SF]>, IsVectorInstruction;
+} // isPush = 1
+} // mayLoad = 1, mayStore = 1
+} // Predicates = [HasVectorSflags, NotBC]
+
+class LdStInfo<string Inst> {
+ Instruction Opcode = !cast<Instruction>(Inst);
+ bit HasAddress = 0;
+ bit HasMask = 0;
+ bit HasStride = 0;
+ bit HasShuffle = 0;
+ bit HasVMask = 0;
+ bit HasLdReplicateEvenOdd = 0;
+ bit HasVsEvenOdd = 0;
+ bit HasIndex = 0;
+}
+
+def LdSTMemAccessTable : GenericTable {
+ let FilterClass = "LdStInfo";
+ let CppTypeName = "LdStInfoTy";
+ let Fields = ["Opcode", "HasAddress", "HasMask", "HasStride", "HasShuffle",
+ "HasVMask", "HasLdReplicateEvenOdd", "HasVsEvenOdd", "HasIndex"];
+ let PrimaryKey = ["Opcode"];
+ let PrimaryKeyName = "LdStInfo";
+}
+
+//===----------------------------------------------------------------------===//
+// Load
+//===----------------------------------------------------------------------===//
+let mayLoad = 1, isVMemLoadInstr = 1 in {
+class Vld<int immcount, dag iops, string ldtype, string address, string sublanespec,
+ string selector> : TPUInstP<(outs VPR:$Vd), iops,
+ "$Vd =\tvld"#ldtype#"${pred} ["#address#""#sublanespec#"]"#selector,[]>,
+ Bundle<B_VLD>,
+ BundleImmVy<!cond(!eq(immcount, 1): [IMM_OP_0],
+ !eq(immcount, 2): [IMM_OP_0, IMM_OP_1],
+ !eq(immcount, 3): [IMM_OP_0, IMM_OP_1, IMM_OP_2]),
+ IMM_2_to_5>, Sched<[WriteVLD]>, LdStInfo<NAME>, IsVectorInstruction;
+}
+
+// Two addressing modes.
+multiclass VldAddr<int immcount, dag iops, string ldtype, string sublanespec,
+ string selector> {
+ def i : Vld<!add(immcount, 1), !con((ins i32imm:$imm), iops),
+ ldtype, "vmem:$imm", sublanespec,
+ selector>;
+ def ri : Vld<!add(immcount, 1), !con((ins GPR:$Ss, i32imm:$imm), iops),
+ ldtype, "vmem:${Ss}+$imm", sublanespec,
+ selector>
+ { let HasAddress = 1; }
+}
+
+// May or may not have mask.
+multiclass VLdMask<int immcount, dag iops, string ldtype, string sublanespec,
+ string selector> {
+let HasMask = 1 in {
+ defm _MaskR : VldAddr<immcount, !con((ins GPR:$mask), iops),
+ ldtype, sublanespec#" sm:$mask", selector>;
+ defm _MaskI : VldAddr<!add(immcount, 1), !con((ins i32imm:$mask), iops),
+ ldtype, sublanespec#" sm:$mask", selector>;
+}
+ defm "" : VldAddr<immcount, iops, ldtype, sublanespec, selector>;
+}
+
+// May or may not have a stride.
+multiclass VLdStride<dag iops, string ldtype> {
+let HasStride = 1 in {
+ // stride.
+ defm _StrideR : VLdMask<0, !con((ins GPR:$stride), iops), ldtype, " ss:$stride", "">;
+ defm _StrideI : VLdMask<1, !con((ins i32imm:$stride), iops), ldtype, " ss:$stride", "">;
+}
+ // No stride
+ defm "" : VLdMask<0, iops, ldtype, "", "">;
+}
+
+multiclass tcVld_ {
+ // Strided ld.
+ defm "" : VLdStride<(ins), "">;
+let HasShuffle = 1 in {
+ // Shuffle.
+ defm _ShuffleR : VLdMask<0, (ins GPR:$selector), ".sshfl", "", ", $selector">;
+ defm _ShuffleI : VLdMask<1, (ins i32imm:$selector), ".sshfl", "", ", $selector">;
+}
+ // Use custom inserter to attach the right memory operands.
+let usesCustomInserter = 1, isIndexedLoadStore = 1, HasIndex = 1 in {
+ // Indexed ld.
+ defm _IAR0 : VLdStride<(ins IARPR0:$iar), ".iar0">;
+ defm _IAR1 : VLdStride<(ins IARPR1:$iar), ".iar1">;
+}
+}
+
+defm tcVLV : tcVld_;
+
+// Class to map pseudo IAR instructions to the opcode we lower them to after
+// bundle packing.
+class PseudoIAR {
+ Instruction PseudoOp = !cast<Instruction>(NAME);
+ Instruction NativeOp =
+ !cast<Instruction>(!subst("tcVSV_ODDEVEN", "tcVSV_IAR0",
+ !subst("tcVLV_REPLICATE_EVENODD", "tcVLV_IAR1", NAME)));
+}
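+// For example (following the !subst above), the pseudo "tcVLV_REPLICATE_EVENODDri"
+// maps to the native opcode "tcVLV_IAR1ri", and "tcVSV_ODDEVENri" maps to
+// "tcVSV_IAR0ri".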
+
+def PseudoIARInst : GenericTable {
+ let FilterClass = "PseudoIAR";
+ let CppTypeName = "PseudoIARTy";
+ let Fields = ["PseudoOp", "NativeOp"];
+ let PrimaryKey = ["PseudoOp"];
+ let PrimaryKeyName = "getPseudoIAROpcode";
+}
+
+let isPseudo = 1, HasLdReplicateEvenOdd = 1 in {
+defm tcVLV_REPLICATE_EVENODD : VLdStride<(ins IARPR1:$iar), "PSEUDO_IAR1">,
+ PseudoIAR;
+}
+
+multiclass Vld_pat<string Name, dag in_ops, dag out_ops>{
+ // We build the output dag children incrementally. To be able to merge all
+ // the nodes, they need to have the same root. Since we don't know the final
+ // instruction when we create the arguments, we use `outs` as a placeholder and
+ // substitute the instruction for it when creating the final pattern.
+ defm : MultiVTypePat<in_ops,
+ !foreach(v, out_ops, !subst(outs, !cast<Instruction>(Name), v))>;
+}
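+// For instance (a sketch of the substitution, not an additional pattern):
+// Vld_pat<"tcVLVri", (int_tpu_vld_strided (add GPR:$Ss, imm:$imm)), (outs GPR:$Ss, imm:$imm)>
+// replaces the placeholder `outs` root with the tcVLVri instruction, yielding a
+// result dag equivalent to (tcVLVri GPR:$Ss, imm:$imm).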
+
+multiclass LdAddressing<SDPatternOperator Intr, string Name, dag in_ops, dag out_ops> {
+ defm : Vld_pat<Name#"ri", !con((Intr (add GPR:$Ss, imm:$imm)), in_ops),
+ !con((outs GPR:$Ss, imm:$imm), out_ops)>;
+ defm : Vld_pat<Name#"ri", !con((Intr GPR:$Ss), in_ops),
+ !con((outs GPR:$Ss, PatLeaf<(i32 0)>), out_ops)>;
+ defm : Vld_pat<Name#"i", !con((Intr imm:$imm), in_ops),
+ !con((outs imm:$imm), out_ops)>;
+}
+
+multiclass LdPatMask<SDPatternOperator Intr, string Name, dag in_ops, dag out_ops> {
+ // No Mask
+ defm : LdAddressing<Intr, Name, !con((Intr (i32 255)), in_ops), out_ops>;
+ // Mask imm or register.
+ defm : LdAddressing<Intr, Name#"_MaskR", !con((Intr GPR:$mask), in_ops),
+ !con((outs GPR:$mask), out_ops)>;
+ defm : LdAddressing<Intr, Name#"_MaskI", !con((Intr imm:$mask), in_ops),
+ !con((outs imm:$mask), out_ops)>;
+}
+
+multiclass LdPatStride<SDPatternOperator Intr, string Name, dag in_ops, dag out_ops> {
+// No stride.
+defm : LdPatMask<Intr, Name,
+ !con((Intr (i32 1)), in_ops), out_ops>;
+// Stride imm or register.
+defm : LdPatMask<Intr, Name#"_StrideR",
+ !con((Intr GPR:$stride), in_ops), !con((outs GPR:$stride), out_ops)>;
+defm : LdPatMask<Intr, Name#"_StrideI",
+ !con((Intr imm:$stride), in_ops), !con((outs imm:$stride), out_ops)>;
+}
+
+// Strided ld
+defm : LdPatStride<int_tpu_vld_strided, "tcVLV", (int_tpu_vld_strided), (outs)>;
+
+// Indexed load
+defm : LdPatStride<int_tpu_vld_indexed, "tcVLV_IAR0",
+ (int_tpu_vld_indexed IARPR0:$iar, (i32 0)), (outs IARPR0:$iar)>;
+defm : LdPatStride<int_tpu_vld_indexed, "tcVLV_IAR1",
+ (int_tpu_vld_indexed IARPR1:$iar, (i32 1)), (outs IARPR1:$iar)>;
+
+// Shuffle imm or register.
+defm : LdPatMask<int_tpu_vld_shuffle, "tcVLV_ShuffleR",
+ (int_tpu_vld_shuffle GPR:$shuffle), (outs GPR:$shuffle)>;
+defm : LdPatMask<int_tpu_vld_shuffle, "tcVLV_ShuffleI",
+ (int_tpu_vld_shuffle imm:$shuffle), (outs imm:$shuffle)>;
+
+// Patterns for the native load instruction.
+defm : MultiVTypePat<(load_vmem (Wrapper tglobaladdr:$imm)), (tcVLVi imm:$imm)>;
+defm : MultiVTypePat<(load_vmem (imm:$imm)), (tcVLVi imm:$imm)>;
+defm : MultiVTypePat<(load_vmem (add GPR:$Ss, imm:$imm)), (tcVLVri GPR:$Ss, imm:$imm)>;
+defm : MultiVTypePat<(load_vmem GPR:$Ss), (tcVLVri GPR:$Ss, (i32 0))>;
+
+defm : LdPatStride<int_tpu_vld_replicate_evenodd_sublanes,
+ "tcVLV_REPLICATE_EVENODD",
+ (int_tpu_vld_replicate_evenodd_sublanes IARPR1:$iar),
+ (outs IARPR1:$iar)>;
+//===----------------------------------------------------------------------===//
+// Store
+//===----------------------------------------------------------------------===//
+let mayStore = 1, isVMemStoreInstr = 1 in {
+class Vst<int immcount, dag iops, string sttype, string address, string sublanespec,
+ string mask> : TPUInstP<(outs), !con((ins VPR:$Vs), iops),
+ "["#address#""#sublanespec#"] =\tvst"#sttype#"${pred}"#mask#" $Vs", []>,
+ Bundle<B_VST>, BundleImmVy<!cond(!eq(immcount, 1): [IMM_OP_0],
+ !eq(immcount, 2): [IMM_OP_0, IMM_OP_1],
+ !eq(immcount, 3): [IMM_OP_0, IMM_OP_1, IMM_OP_2]),
+ IMM_2_to_5>, LdStInfo<NAME>, IsVectorInstruction;
+}
+
+// Two addressing modes.
+multiclass VstAddr<int immcount, dag iops, string sttype, string sublanespec,
+ string mask> {
+ def i : Vst<!add(1, immcount), !con((ins i32imm:$imm), iops), sttype, "vmem:$imm", sublanespec, mask>;
+ def ri : Vst<!add(1, immcount), !con((ins GPR:$Ss, i32imm:$imm), iops), sttype, "vmem:${Ss}+$imm", sublanespec, mask>
+ { let HasAddress = 1; }
+}
+
+// May or may not have mask.
+multiclass VstMask<int immcount, dag iops, string sttype, string sublanespec, string vmask> {
+let HasMask = 1 in {
+ defm _MaskR : VstAddr<immcount, !con((ins GPR:$mask), iops), sttype, sublanespec#" sm:$mask",
+ vmask>;
+ defm _MaskI : VstAddr<!add(1, immcount), !con((ins i32imm:$mask), iops), sttype, sublanespec#" sm:$mask",
+ vmask>;
+}
+ // No mask
+ defm "" : VstAddr<immcount, iops, sttype, sublanespec, vmask>;
+}
+
+// May or may not have a stride.
+multiclass VstStride<dag iops, string sttype, string vmask> {
+let HasStride = 1 in {
+ defm _StrideR : VstMask<0, !con((ins GPR:$stride), iops), sttype, " ss:$stride", vmask>;
+ defm _StrideI : VstMask<1, !con((ins i32imm:$stride), iops), sttype, " ss:$stride", vmask>;
+}
+ // No stride
+ defm "" : VstMask<0, iops, sttype, "", vmask>;
+}
+
+// May or may not have vmask.
+multiclass VstVMask<dag iops, string sttype, string source> {
+let HasVMask = 1 in {
+ defm _VMask : VstStride<!con((ins MPR:$vmask), iops), sttype#".msk", source#" $vmask,">;
+}
+ // No VMask
+ defm "" : VstStride<iops, sttype, source>;
+}
+
+// May or may not have index.
+multiclass VstIndexed {
+// Use custom inserter to attach the right memory operands.
+let usesCustomInserter = 1, isIndexedLoadStore = 1, HasIndex = 1 in {
+ defm _IAR0 : VstVMask<(ins IARPR0:$iar), ".iar", " $iar,">;
+ defm _IAR1 : VstVMask<(ins IARPR1:$iar), ".iar", " $iar,">;
+}
+ // No iar
+ defm "" : VstVMask<(ins), "", "">;
+}
+
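+// Expanding the nested multiclasses above yields the full store matrix; the defm
+// below produces names such as tcVSVri, tcVSV_MaskRi, tcVSV_StrideR_MaskIri, and
+// tcVSV_IAR0_VMask_StrideI_MaskRri.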
+defm tcVSV : VstIndexed;
+
+let isPseudo = 1, HasVsEvenOdd = 1 in {
+defm tcVSV_ODDEVEN : VstVMask<(ins IARPR0:$iar), ".PSEUDO_IAR0", " $iar,">, PseudoIAR;
+}
+
+// Match to the right store.
+multiclass MultiTypeStore<SDPatternOperator Intr, string Name, dag in_ops,
+ dag out_ops> {
+ def : Pat<!con((Intr (vNf32 VPR:$Vs)), in_ops),
+ !con((!cast<Instruction>(Name) VPR:$Vs), out_ops)>;
+ def : Pat<!con((Intr (vNi32 VPR:$Vs)), in_ops),
+ !con((!cast<Instruction>(Name) VPR:$Vs), out_ops)>;
+}
+
+multiclass Vst_pat<SDPatternOperator Intr, string Name, dag in_ops, dag out_ops> {
+ defm : MultiTypeStore<Intr, Name, in_ops,
+ !foreach(v, out_ops, !subst(outs, !cast<Instruction>(Name), v))>;
+}
+
+multiclass StAddressing<SDPatternOperator Intr, string Name, dag in_ops, dag out_ops> {
+ defm : Vst_pat<Intr, Name#"ri", !con((Intr (add GPR:$Ss, imm:$imm)), in_ops),
+ !con((outs GPR:$Ss, imm:$imm), out_ops)>;
+ defm : Vst_pat<Intr, Name#"ri", !con((Intr GPR:$Ss), in_ops),
+ !con((outs GPR:$Ss, PatLeaf<(i32 0)>), out_ops)>;
+ defm : Vst_pat<Intr, Name#i, !con((Intr imm:$imm), in_ops),
+ !con((outs imm:$imm), out_ops)>;
+}
+
+multiclass StPatMask<SDPatternOperator Intr, string Name, dag in_ops, dag out_ops> {
+ // No Mask.
+ defm : StAddressing<Intr, Name, !con((Intr (i32 255)), in_ops), out_ops>;
+ // Mask Register or Imm.
+ defm : StAddressing<Intr, Name#"_MaskR", !con((Intr GPR:$mask), in_ops),
+ !con((outs GPR:$mask), out_ops)>;
+ defm : StAddressing<Intr, Name#"_MaskI", !con((Intr imm:$mask), in_ops),
+ !con((outs imm:$mask), out_ops)>;
+}
+
+multiclass StPatStride<SDPatternOperator Intr, string Name, dag in_ops, dag out_ops> {
+ // No stride.
+ defm : StPatMask<Intr, Name, !con((Intr (i32 1)), in_ops), out_ops>;
+ // Stride Register or Imm.
+ defm : StPatMask<Intr, Name#"_StrideR", !con((Intr GPR:$stride), in_ops),
+ !con((outs GPR:$stride), out_ops)>;
+ defm : StPatMask<Intr, Name#"_StrideI", !con((Intr imm:$stride), in_ops),
+ !con((outs imm:$stride), out_ops)>;
+}
+
+multiclass StPatVMask<SDPatternOperator Intr, string Name, dag in_ops, dag out_ops> {
+// No vector Mask.
+defm : StPatStride<Intr, Name, !con((Intr (vNi1 (Splat -1))), in_ops), out_ops>;
+// Vector register mask.
+defm : StPatStride<Intr, Name#"_VMask", !con((Intr (vNi1 MPR:$vmask)), in_ops),
+ !con((outs MPR:$vmask), out_ops)>;
+}
+
+// No IAR
+defm : StPatVMask<int_tpu_vst_strided, "tcVSV", (int_tpu_vst_strided), (outs)>;
+// IAR 0 and 1
+defm : StPatVMask<int_tpu_vst_indexed, "tcVSV_IAR0",
+ (int_tpu_vst_indexed IARPR0:$iar, (i32 0)), (outs IARPR0:$iar)>;
+defm : StPatVMask<int_tpu_vst_indexed, "tcVSV_IAR1",
+ (int_tpu_vst_indexed IARPR1:$iar, (i32 1)), (outs IARPR1:$iar)>;
+
+
+// Native store matching.
+multiclass MatchTcStoreType<RegisterClass VType> {
+ def : Pat<(store_vmem (VType VPR:$Vs), (Wrapper tglobaladdr:$imm)), (tcVSVi VPR:$Vs, imm:$imm)>;
+ def : Pat<(store_vmem (VType VPR:$Vs), (imm:$imm)), (tcVSVi VPR:$Vs, imm:$imm)>;
+ def : Pat<(store_vmem (VType VPR:$Vs), (i32 GPR:$Ss)), (tcVSVri VPR:$Vs, GPR:$Ss, (i32 0))>;
+ def : Pat<(store_vmem (VType VPR:$Vs), (add (i32 GPR:$Ss), imm:$imm)),
+ (tcVSVri VPR:$Vs, GPR:$Ss, imm:$imm)>;
+}
+defm : MatchTcStoreType<vNf32>;
+defm : MatchTcStoreType<vNi32>;
+
+defm : StPatVMask<int_tpu_vst_evenodd_sublanes, "tcVSV_ODDEVEN",
+ (int_tpu_vst_evenodd_sublanes IARPR0:$iar), (outs IARPR0:$iar)>;
+
+//===----------------------------------------------------------------------===//
+// Set IAR Intrinsics
+//===----------------------------------------------------------------------===//
+multiclass SetIAR<string postFix, PatFrag OpNode, int iarIndex,
+ DAGOperand iar = !cast<DAGOperand>("IARPR"#iarIndex),
+ SchedWrite Sch = !cast<SchedWrite>("WriteIar"#iarIndex)> {
+// Use custom inserter to attach the right memory operands.
+let usesCustomInserter = 1 in {
+ def "" : TPUInstP<(outs iar:$iar), (ins VPR:$vsrc),
+ "$iar =\tvsetiar."#postFix#"${pred} $vsrc",
+ [(set iar:$iar, (OpNode (vNi32 VPR:$vsrc), (i32 iarIndex)))]>,
+ Bundle<B_VST>, Sched<[Sch]>, IsVectorInstruction;
+}
+}
+
+multiclass SetIARMode<int iarIndex> {
+ defm _SET_LANE : SetIAR<"lane", int_tpu_set_lane_indexed, iarIndex>;
+ defm _SET_SUBLANE : SetIAR<"sublane", int_tpu_set_sublane_indexed, iarIndex>;
+ defm _SET_RAW : SetIAR<"raw", int_tpu_set_iar_raw, iarIndex>;
+}
+
+defm IAR0 : SetIARMode<0>;
+defm IAR1 : SetIARMode<1>;
+
+multiclass VectorTraceInst<Intrinsic intrinsic> {
+ // Note that vtrace takes two args, but LLO actually implements a simplified version which takes
+ // either an sreg or an immediate. The following implementation takes the same simplified approach,
+ // which may be revised later.
+ def r : TPUInstP<(outs), (ins GPR:$op),
+ "_ =\tvtrace${pred} $op",
+ [(intrinsic GPR:$op)]>,
+ Bundle<B_Misc>;
+ def i : TPUInstP<(outs), (ins i32imm:$imm),
+ "_ =\tvtrace${pred} $imm",
+ [(intrinsic imm:$imm)]>,
+ Bundle<B_Misc>;
+}
+multiclass VectorSetTracemark<Intrinsic intrinsic> {
+ def r : TPUInstP<(outs), (ins GPR:$op),
+ "(tm) =\tvsettm${pred} $op",
+ [(intrinsic GPR:$op)]>,
+ Bundle<B_Misc>;
+ def i : TPUInstP<(outs), (ins i32imm:$imm),
+ "(tm) =\tvsettm${pred} $imm",
+ [(intrinsic imm16:$imm)]>,
+ Bundle<B_Misc>;
+}
+let Predicates = [HasVectorSflags] in {
+let hasSideEffects = 1 in {
+defm tcVTRACE : VectorTraceInst<int_tpu_vtrace>, IsVectorInstruction;
+defm tcVSETTM : VectorSetTracemark<int_tpu_vsettm>, IsVectorInstruction;
+}
+
+} // Predicates = [HasVectorSflags]
+
+//===----------------------------------------------------------------------===//
+// MXU operations
+//===----------------------------------------------------------------------===//
+// Mat push/mul can be masked or not.
+multiclass MatOpMasked<int i, string Name, string IntrName,
+ DAGOperand fiforegsrc, DAGOperand fiforegdst, SchedWrite Schedule,
+ Intrinsic Intr = !cast<Intrinsic>(IntrName#"_f32")> {
+ def "" : TPUInstP<(outs fiforegdst:$fifodst), (ins VPR:$Vs, fiforegsrc:$fifosrc),
+ Name # ".f32${pred} $Vs",
+ [(set fiforegdst:$fifodst, (Intr (vNf32 VPR:$Vs),
+ (vNi1 (Splat -1)), (i32 i), (i32 fiforegsrc:$fifosrc)))]>,
+ Sched<[Schedule]>, ExtraPredicates<[HasMXU]>;
+ def m : TPUInstP<(outs fiforegdst:$fifodst), (ins VPR:$Vs, MPR:$m, fiforegsrc:$fifosrc),
+ Name # ".msk.f32${pred} $m, $Vs",
+ [(set fiforegdst:$fifodst, (Intr (vNf32 VPR:$Vs),
+ (vNi1 MPR:$m), (i32 i), (i32 fiforegsrc:$fifosrc)))]>,
+ Sched<[Schedule]>, ExtraPredicates<[HasMXU]>;
+}
+
+multiclass MatPush<int i, string IntrName, string OpName, string FifoName,
+ DAGOperand fiforeg = !cast<DAGOperand>(FifoName#i)> :
+ MatOpMasked<i, OpName, IntrName, fiforeg, fiforeg,
+ !cast<SchedWrite>("WriteMatPush"#i)>;
+
+// MatPush may transpose or not
+multiclass MatPushXPos<int i, string IntrName, string OpName> {
+ defm "" : MatPush<i, IntrName, OpName, "GSFNPR">;
+let OtherPredicates = [UseGsftForXpose] in {
+ defm _XPOS : MatPush<i, IntrName#"_xpose", OpName#".xpose", "GSFTPR">;
+}
+// JF/DF don't have a special latch for transpose matpush.
+let OtherPredicates = [UseGsfnForXpose] in {
+ defm _XPOS_JF : MatPush<i, IntrName#"_xpose", OpName#".xpose", "GSFNPR">;
+}
+}
+
+multiclass MatPushMode<int i, string IntrName, string OpName> {
+ defm "" : MatPushXPos<i, IntrName, OpName>;
+ defm _LOW : MatPushXPos<i, IntrName#"_low", OpName#".low">;
+ defm _HI : MatPushXPos<i, IntrName#"_hi", OpName#".hi">;
+ defm _PACKED : MatPushXPos<i, IntrName#"_packed", OpName#".packed">;
+}
+
+multiclass MatMul<int i, string IntrName, string OpName, string ScheduleName> :
+ MatOpMasked<i, OpName, IntrName, !cast<DAGOperand>("GMRPR"#i),
+ !cast<DAGOperand>("MRFPR"#i), !cast<SchedWrite>(ScheduleName#i)>;
+
+multiclass MatMulMode<int i, string IntrName, string OpName> {
+ defm "" : MatMul<i, IntrName, OpName, "WriteMatMulMxu">;
+ defm _LOW : MatMul<i, IntrName#"_low", OpName#".low", "WriteMatMulMxu">;
+ defm _HI : MatMul<i, IntrName#"_hi", OpName#".hi", "WriteMatMulMxu">;
+let isPackedMatMul = 1 in {
+ defm _PACKED : MatMul<i, IntrName#"_packed", OpName#".packed", "WriteMatMulMxuPacked">;
+}
+}
+
+multiclass Dwg<int i,
+ DAGOperand gmr = !cast<DAGOperand>("GMRPR"#i),
+ DAGOperand gsfn = !cast<DAGOperand>("GSFNPR"#i),
+ DAGOperand gsft = !cast<DAGOperand>("GSFTPR"#i)> {
+ def N : TPUInstP<(outs gmr:$gmr), (ins gsfn:$gsfn),
+ "$gmr =\tvdwg.f16${pred} $gsfn",
+ [(set gmr:$gmr, (int_tpu_vdwg (i32 i), (i32 gsfn:$gsfn)))]>;
+ def T : TPUInstP<(outs gmr:$gmr), (ins gsft:$gsft),
+ "$gmr =\tvdwg.f16${pred} $gsft",
+ [(set gmr:$gmr, (int_tpu_vdwg_xpose (i32 i), (i32 gsft:$gsft)))]>;
+}
+
+multiclass MXU<int i, DAGOperand mrf = !cast<DAGOperand>("MRFPR"#i)> {
+let mayLoad = 1, mayStore = 1, usesCustomInserter = 1 in {
+let Itinerary = IIC_MXU_PUSH in {
+ defm MATPUSH : MatPushMode<i, "int_tpu_vmatpush", "$fifodst =\tvmatpush">;
+}
+let Itinerary = IIC_MXU_MUL, isPush = 1 in {
+ defm MATMUL : MatMulMode<i, "int_tpu_vmatmul", "$fifodst =\tvmatmul">;
+}
+let isDwg = 1 in {
+ defm DWG : Dwg<i>, Bundle<B_VEX>;
+}
+def MATPOP : TPUInstP<(outs VPR:$Vd), (ins mrf:$mrf),
+ "$Vd =\tvmatres.8x128.f32${pred} $mrf",
+ [(set VPR:$Vd, (int_tpu_vmatres_f32 i, (i32 mrf:$mrf)))]>,
+ Sched<[!cast<SchedWrite>("WriteMatRes"#i)]>, Bundle<B_VResAny> { let isPop = 1; }
+}
+}
+// Define 4 MXUs for all platforms; we assume the user won't try to use more
+// MXUs than are available on the platform. We can add finer-grained predicates
+// later to be able to report user errors.
+foreach Index = 0-3 in {
+defm MXU#Index : MXU<Index>, IsVectorInstruction;
+}
+
+//===----------------------------------------------------------------------===//
+// XLU operations
+//===----------------------------------------------------------------------===//
+class PseudoInstMapping<string PseudoInst, string Inst> {
+ Instruction Pseudo = !cast<Instruction>(PseudoInst);
+ Instruction Lowered = !cast<Instruction>(Inst);
+}
+
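+// Emits a searchable C++ table (TableGen searchable-tables backend) mapping each
+// packed pseudo instruction to the real instruction it is expanded into post
+// bundle packing; lookups are keyed on the Pseudo field.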
+def PseudoInstTable : GenericTable {
+ let FilterClass = "PseudoInstMapping";
+ let CppTypeName = "PseudoInstMappingTy";
+ let Fields = ["Pseudo", "Lowered"];
+ let PrimaryKey = ["Pseudo"];
+ let PrimaryKeyName = "PseudoInstMapping";
+}
+
+multiclass Transpose<string Name, string PostFix, string Sch, string IntrName,
+ int busIndex, DAGOperand trf> {
+// Only immediate widths are supported for now. Having a non-constant width
+// makes it really hard to match the Pop instructions associated with a
+// transpose. Height is an argument: even though the hardware doesn't need it,
+// we force the user to pass it so that we can compute an accurate latency.
+def "" : TPUInstP<(outs trf:$trf), (ins VPR:$vsrc, i32imm:$width, i32imm:$height, trf:$trfsrc),
+ "$trf =\t"#Name#"."#busIndex#PostFix#"${pred} $vsrc, $width",
+ [(set trf:$trf, (!cast<Intrinsic>(IntrName) (vNi32 VPR:$vsrc),
+ (timm:$width), (timm:$height), (i32 busIndex), (i32 trf:$trfsrc)))]>,
+ Bundle<B_VEX>, Sched<[!cast<SchedWrite>(Sch#busIndex)]>;
+
+// Packed transpose needs to be broken down into two instructions going in the
+// same bundle. We emit a pseudo instruction with both sources and expand it
+// post bundle packing into a packed instruction and a vsupp instruction.
+let isPacked = 1 in {
+def _PACKED : TPUInstP<(outs trf:$trf), (ins VPR:$vsrc, i32imm:$width, i32imm:$height, trf:$trfsrc),
+ "$trf =\t"#Name#"."#busIndex#".packed"#PostFix#"${pred} $vsrc, $width",
+ []>, Bundle<B_VEX1>;
+let isPseudo = 1 in {
+def _PACKED_PSEUDO : TPUInstP<(outs trf:$trf), (ins VPR:$vsrclow, VPR:$vsrchigh,
+ i32imm:$width, i32imm:$height, trf:$trfsrc),
+ "$trf =\t"#Name#"PACKED"#busIndex#PostFix#"${pred} $vsrclow, $width, $vsrchigh",
+ [(set trf:$trf, (!cast<Intrinsic>(IntrName#"_packed") (vNi32 VPR:$vsrclow),
+ (vNi32 VPR:$vsrchigh), (timm:$width),
+ (timm:$height), (i32 busIndex), (i32 trf:$trfsrc)))]>,
+ Bundle<B_VEXBoth>, Sched<[!cast<SchedWrite>(Sch#"Packed"#busIndex)]>,
+ PseudoInstMapping<NAME#"_PACKED_PSEUDO", NAME#"_PACKED">;
+} // isPseudo = 1
+} // isPacked = 1
+}
+
+// Transpose can be segmented or not.
+multiclass TransposeSegmented<string PostFix, string Sch, string IntrName,
+ int busIndex, DAGOperand trf> {
+ defm "" : Transpose<"vxpose", PostFix, Sch, IntrName, busIndex, trf>;
+let isSegmented = 1 in {
+ defm _SEGMENTED : Transpose<"vsxpose", PostFix, Sch,
+ IntrName#"_segmented", busIndex, trf>;
+}
+}
+
+// One transpose and one transpose end
+multiclass TransposeEnd<string Sch, string IntrName, int busIndex, DAGOperand trf> {
+ defm "" : TransposeSegmented<"", Sch, IntrName, busIndex, trf>;
+let isTransposeEnd = 1, isPush = 1 in {
+ defm _END : TransposeSegmented<".end", Sch#"End", IntrName#"_end", busIndex, trf>;
+}
+}
+
+// Pattern for the float case.
+multiclass TransposeFloatPat<string Name, string IntrName, int busIndex,
+ DAGOperand trf> {
+ def : Pat<(!cast<Intrinsic>(IntrName) (vNf32 VPR:$vsrc), (timm:$width), (timm:$height),
+ (i32 busIndex), (i32 trf:$trfsrc)),
+ (!cast<Instruction>(Name) VPR:$vsrc, i32imm:$width,
+ i32imm:$height, trf:$trfsrc)>;
+ // Packed case.
+ def : Pat<(!cast<Intrinsic>(IntrName#"_packed") (vNf32 VPR:$vsrclow),
+ (vNf32 VPR:$vsrchigh), (timm:$width), (timm:$height),
+ (i32 busIndex), (i32 trf:$trfsrc)),
+ (!cast<Instruction>(Name#"_PACKED_PSEUDO") VPR:$vsrclow, VPR:$vsrchigh,
+ i32imm:$width, i32imm:$height, trf:$trfsrc)>;
+}
+
+// Pattern for segmented and normal case.
+multiclass TransposeSegmentedFloatPat<string Name, string Intrinsic, int busIndex,
+ DAGOperand trf> {
+ defm : TransposeFloatPat<Name, Intrinsic, busIndex, trf>;
+ defm : TransposeFloatPat<Name#"_SEGMENTED", Intrinsic#"_segmented", busIndex, trf>;
+}
+
+// Pattern for transpose and transpose_end intrinsics.
+multiclass TransposeEndFloatPat<string Name, string Intrinsic, int busIndex,
+ DAGOperand trf> {
+ defm : TransposeSegmentedFloatPat<Name, Intrinsic, busIndex, trf>;
+ defm : TransposeSegmentedFloatPat<Name#"_END", Intrinsic#"_end", busIndex, trf>;
+}
+
+multiclass RotateSource<int busIndex, DAGOperand trf, DAGOperand SrcT, DAGOperand PatType,
+ ImmSlotRequirement slots, list<ImmOperRequirement> operands> {
+def "" : TPUInstP<(outs trf:$trf), (ins VPR:$vsrc, SrcT:$amount),
+ "$trf =\tvrot."#busIndex#"${pred} $vsrc, $amount",
+ [(set trf:$trf, (int_tpu_vrotate (vNi32 VPR:$vsrc),
+ (i32 PatType:$amount), (i32 busIndex)))]>,
+ Bundle<B_VEX>, Sched<[!cast<SchedWrite>("WritePermute"#busIndex)]>, BundleImm<slots, operands>;
+let isPacked = 1 in {
+ def _PACKED : TPUInstP<(outs trf:$trf), (ins VPR:$vsrc, SrcT:$amount),
+ "$trf =\tvrot."#busIndex#".packed${pred} $vsrc, $amount",
+ []>, Bundle<B_VEX1>, BundleImm<slots, operands>;
+let isPseudo = 1 in {
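+// As with the packed transpose and permute, the pseudo carries both packed sources
+// and is expanded post bundle packing into the packed rotate plus a vsupp
+// instruction for the high bits.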
+def _PACKED_PSEUDO : TPUInstP<(outs trf:$trf), (ins VPR:$vsrclow, VPR:$vsrchigh, SrcT:$amount),
+ "$trf =\t#VROTPACKED_PSEUDO."#busIndex#".packed${pred} $vsrclow, $vsrchigh, $amount",
+ [(set trf:$trf, (int_tpu_vrotate_packed (vNi32 VPR:$vsrclow),
+ (vNi32 VPR:$vsrchigh), (i32 PatType:$amount), (i32 busIndex)))]>,
+ Bundle<B_VEXBoth>, Sched<[!cast<SchedWrite>("WritePermutePacked"#busIndex)]>,
+ BundleImm<slots, operands>,
+ PseudoInstMapping<NAME#"_PACKED_PSEUDO", NAME#"_PACKED">;
+} // isPseudo = 1
+} // isPacked = 1
+}
+
+multiclass Rotate<int busIndex, DAGOperand trf> {
+ defm r : RotateSource<busIndex, trf, GPR, GPR, IMM_NONE, []>;
+ // immediate amount case.
+ defm i : RotateSource<busIndex, trf, i32imm, imm, IMM_0_to_3, [IMM_OP_0]>;
+}
+
+// Pattern for the float case.
+multiclass RotateFloatPat<string Name, int busIndex, DAGOperand trf> {
+ def : Pat<(int_tpu_vrotate (vNf32 VPR:$vsrc), (i32 GPR:$amount), (i32 busIndex)),
+ (!cast<Instruction>(Name#r) VPR:$vsrc, GPR:$amount)>;
+ def : Pat<(int_tpu_vrotate (vNf32 VPR:$vsrc), (i32 imm:$amount), (i32 busIndex)),
+ (!cast<Instruction>(Name#i) VPR:$vsrc, imm:$amount)>;
+ // Packed pattern
+ def : Pat<(int_tpu_vrotate_packed (vNf32 VPR:$vsrclow), (vNf32 VPR:$vsrchigh),
+ (i32 GPR:$amount), (i32 busIndex)),
+ (!cast<Instruction>(Name#"r_PACKED_PSEUDO") VPR:$vsrclow, VPR:$vsrchigh, GPR:$amount)>;
+ def : Pat<(int_tpu_vrotate_packed (vNf32 VPR:$vsrclow), (vNf32 VPR:$vsrchigh),
+ (i32 imm:$amount), (i32 busIndex)),
+ (!cast<Instruction>(Name#"i_PACKED_PSEUDO") VPR:$vsrclow, VPR:$vsrchigh, imm:$amount)>;
+}
+
+multiclass XLaneInst<string Name, string IntrName, int busIndex, DAGOperand trf, DAGOperand spr,
+ SchedWrite Sch = !cast<SchedWrite>("WriteXLane"#busIndex)> {
+def "" : TPUInstP<(outs trf:$trf), (ins VPR:$vsrc),
+ "$trf =\t"#Name#".xlane."#busIndex#"${pred} $vsrc",
+ [(set trf:$trf, (!cast<Intrinsic>("int_tpu_xlane_"#IntrName)
+ (vNf32 VPR:$vsrc), (i32 busIndex)))]>,
+ Bundle<B_VEX>, Sched<[Sch]>;
+def _SEGMENTED : TPUInstP<(outs trf:$trf), (ins VPR:$vsrc, spr:$spr),
+ "$trf =\t"#Name#".xlane."#busIndex#"${pred}.seg.perm $vsrc",
+ [(set trf:$trf, (!cast<Intrinsic>("int_tpu_xlane_segmented_"#IntrName)
+ (vNf32 VPR:$vsrc), (i32 spr:$spr), (i32 busIndex)))]>,
+ Bundle<B_VEX>, Sched<[Sch]>;
+}
+
+multiclass SetPatternReg<string Name, string postFix, PatFrag OpNode, int busIndex,
+ DAGOperand trf, DAGOperand pcr, SchedWrite Sch> {
+ def "" : TPUInstP<(outs pcr:$pcr), (ins VPR:$vsrc),
+ "$pcr =\tvsetperm."#busIndex#"."#postFix#"${pred} $vsrc",
+ [(set pcr:$pcr, (OpNode (vNi32 VPR:$vsrc), (i32 busIndex)))]>,
+ Bundle<B_VEX>, Sched<[Sch]>;
+}
+
+multiclass SetPermute<string postFix, PatFrag OpNode, int busIndex,
+ DAGOperand trf, DAGOperand pcr, SchedWrite Sch> {
+ defm "" : SetPatternReg<"vsetperm", postFix, OpNode, busIndex, trf, pcr, Sch>;
+}
+
+multiclass Permute<int busIndex, DAGOperand trf, DAGOperand pcr> {
+ def "" : TPUInstP<(outs trf:$trf), (ins VPR:$vsrc, pcr:$pcr),
+ "$trf =\tvperm."#busIndex#"${pred} $vsrc",
+ [(set trf:$trf, (int_tpu_permute (vNi32 VPR:$vsrc),
+ (i32 pcr:$pcr), (i32 busIndex)))]>,
+ Bundle<B_VEX>, Sched<[!cast<SchedWrite>("WritePermute"#busIndex)]>;
+let isPacked = 1 in {
+ def _PACKED : TPUInstP<(outs trf:$trf), (ins VPR:$vsrc, pcr:$pcr),
+ "$trf =\tvperm."#busIndex#".packed${pred} $vsrc",
+ []>, Bundle<B_VEX1>;
+let isPseudo = 1 in {
+ // Packed instruction also generates a vsupp instruction for the high bits.
+ // It gets expanded post bundle packing.
+ def _PACKED_PSEUDO : TPUInstP<(outs trf:$trf), (ins VPR:$vsrclow, VPR:$vsrchigh, pcr:$pcr),
+ "$trf =\t#VPERMPACKED."#busIndex#"${pred} $vsrclow, $vsrchigh",
+ [(set trf:$trf, (int_tpu_permute_packed (vNi32 VPR:$vsrclow),
+ (vNi32 VPR:$vsrchigh), (i32 pcr:$pcr), (i32 busIndex)))]>,
+ Bundle<B_VEXBoth>, Sched<[!cast<SchedWrite>("WritePermutePacked"#busIndex)]>,
+ PseudoInstMapping<NAME#"_PACKED_PSEUDO", NAME#"_PACKED">;
+} // isPseudo = 1
+} // isPacked = 1
+}
+
+// Instruction for supplemental packed source.
+def XLUSUPP_PACKED : TPUInstP<(outs), (ins VPR:$vsrc),
+ "_ =\tvsupp $vsrc", []>, Bundle<B_VEX0>;
+
+multiclass PermuteFloatPat<string Name, int busIndex, DAGOperand trf, DAGOperand pcr> {
+ def : Pat<(int_tpu_permute (vNf32 VPR:$vsrc), (i32 pcr:$pcr), (i32 busIndex)),
+ (!cast<Instruction>(Name) VPR:$vsrc, pcr:$pcr)>;
+ def : Pat<(int_tpu_permute_packed (vNf32 VPR:$vsrclow), (vNf32 VPR:$vsrchigh), (i32 pcr:$pcr), (i32 busIndex)),
+ (!cast<Instruction>(Name#"_PACKED_PSEUDO") VPR:$vsrclow, VPR:$vsrchigh, pcr:$pcr)>;
+}
+
+multiclass XLUBus<int busIndex, DAGOperand trf, DAGOperand pcr, DAGOperand spr> {
+// Use custom inserter to attach the right memory operands.
+let usesCustomInserter = 1 in {
+// Transpose is not marked as push. We only model the transpose_end
+// instructions as pushing into the FIFO. That allows us to model transpose as
+// a normal FIFO. Transpose_end pushes a variable number of items based on its
+// width.
+let isTranspose = 1 in {
+defm TRANSPOSE : TransposeEnd<"WriteTranspose", "int_tpu_tc_transpose", busIndex, trf>;
+}
+let isPush = 1 in {
+let isPermute = 1 in {
+defm ROTATE : Rotate<busIndex, trf>;
+defm PERMUTE : Permute<busIndex, trf, pcr>;
+} // isPermute = 1
+let isReduce = 1 in {
+defm XLANE_ADD : XLaneInst<"vadd", "add", busIndex, trf, spr>;
+defm XLANE_MAX : XLaneInst<"vmax", "max", busIndex, trf, spr>;
+defm XLANE_MIN : XLaneInst<"vmin", "min", busIndex, trf, spr>;
+defm XLANE_MAXINDEX : XLaneInst<"vmax.index", "maxindex", busIndex, trf, spr>;
+defm XLANE_MININDEX : XLaneInst<"vmin.index", "minindex", busIndex, trf, spr>;
+} // isReduce = 1
+} // isPush = 1
+
+// SetPermute instructions are not FIFO.
+defm SETPERMUTE_U8 :
+ SetPermute<"u8", int_tpu_set_permute, busIndex, trf, pcr, WriteSetPermute>;
+defm SETSPR :
+ SetPatternReg<"vsetspr", "u1", int_tpu_set_spr, busIndex, trf, spr, WriteSetPermute>;
+defm SETPERMUTE_SUBLANE :
+ SetPermute<"all.u8", int_tpu_set_permute_sublane, busIndex, trf, pcr,
+ !cast<SchedWrite>("WriteSetPermuteAll"#busIndex)>;
+defm SETPERMUTE_BYTES :
+ SetPermute<"all.bytes.u32", int_tpu_set_permute_bytes, busIndex, trf, pcr,
+ !cast<SchedWrite>("WriteSetPermuteAll"#busIndex)>;
+} // usesCustomInserter = 1
+
+defm : TransposeEndFloatPat<NAME#"TRANSPOSE", "int_tpu_tc_transpose", busIndex, trf>;
+defm : RotateFloatPat<NAME#"ROTATE", busIndex, trf>;
+defm : PermuteFloatPat<NAME#"PERMUTE", busIndex, trf, pcr>;
+}
+
+multiclass XLUPop<int XLUIndex,
+ DAGOperand trf = !cast<DAGOperand>("TRFPR"#XLUIndex)> {
+let isPop = 1, usesCustomInserter = 1 in {
+ def Pop : TPUInstP<(outs VPR:$Vd), (ins trf:$trf),
+ "$Vd =\tvpop $trf",
+ [(set (vNi32 VPR:$Vd),
+ (int_tpu_tc_vtrfpop (i32 XLUIndex), (i32 trf:$trf)))]>,
+ Bundle<B_VResAny>, Sched<[!cast<SchedWrite>("WriteTrfPop"#XLUIndex)]>;
+}
+def : Pat<(vNf32 (int_tpu_tc_vtrfpop (i32 XLUIndex), (i32 trf:$trf))),
+ (!cast<Instruction>(NAME#Pop) trf:$trf)>;
+}
+
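+// Each transpose unit drives two buses: the TransposeUnit instantiations below
+// give XLU0 bus indices 0 and 2, and XLU1 bus indices 1 and 3.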
+multiclass TransposeUnit<int XluIndex,
+ DAGOperand trf = !cast<DAGOperand>("TRFPR"#XluIndex),
+ DAGOperand pcr = !cast<DAGOperand>("PCRPR"#XluIndex),
+ DAGOperand spr = !cast<DAGOperand>("SPRPR"#XluIndex)> {
+ defm B0 : XLUBus<XluIndex, trf, pcr, spr>;
+ defm B1 : XLUBus<!add(XluIndex, 2), trf, pcr, spr>;
+ defm "" : XLUPop<XluIndex>;
+}
+
+defm XLU0 : TransposeUnit<0>, IsVectorInstruction;
+defm XLU1 : TransposeUnit<1>, IsVectorInstruction;
+
+// Host interrupt
+let hasSideEffects = 1 in {
+def VINTr : TPUInstP<(outs), (ins GPR:$o),
+ "_ = vint${pred} $o",
+ [(int_tpu_tc_vint (i32 GPR:$o))]>,
+ Bundle<B_Misc>, IsVectorInstruction;
+def VINTi : TPUInstP<(outs), (ins i32imm:$o),
+ "_ = vint${pred} $o",
+ [(int_tpu_tc_vint (i32 imm:$o))]>,
+ Bundle<B_Misc>, BundleImm<IMM_0_to_3>, IsVectorInstruction;
+}
+
+// Pseudo random number generation
+let hasSideEffects = 1 in {
+def SetRngSeed : TPUInstP<(outs), (ins VPR:$Vx),
+ "_ = setrngseed${pred} $Vx",
+ [(int_tpu_tc_setrngseed (vNi32 VPR:$Vx))]>,
+ Bundle<B_Vany>, Sched<[WriteSetRngSeed]>, IsVectorInstruction;
+def GetRngSeed : TPUInstP<(outs VPR:$Vdst), (ins),
+ "$Vdst = getrngseed${pred}",
+ [(set VPR:$Vdst, (int_tpu_tc_getrngseed))]>,
+ Bundle<B_Vany>, Sched<[WriteGetRngState]>, IsVectorInstruction;
+def VRng : TPUInstP<(outs VPR:$Vdst), (ins),
+ "$Vdst = vrng.8x128.u32${pred}",
+ [(set VPR:$Vdst, (int_tpu_tc_vrng))]>,
+ Bundle<B_Vany>, Sched<[WriteRng]>, IsVectorInstruction;
+}
+
+let Predicates = [HasV1024] in {
+def VLANEMASK : TPUInstP<(outs MPR:$Md), (ins VPR:$Vs),
+ "$Md =\tvlmask${pred} $Vs",
+ [(set MPR:$Md,
+ (int_tpu_lane_mask (vNi32 VPR:$Vs)))]>,
+ Bundle<B_Vany>, IsVectorInstruction;
+def VSUBLANE_MASK : TPUInstP<(outs MPR:$Md), (ins VPR:$Vs),
+ "$Md =\tvsmask${pred} $Vs",
+ [(set MPR:$Md,
+ (int_tpu_sublane_mask (vNi32 VPR:$Vs)))]>,
+ Bundle<B_Vany>, IsVectorInstruction;
+}
diff --git a/LAVM/test/md.md b/LAVM/test/md.md
new file mode 100644
index 0000000..1f1a680
--- /dev/null
+++ b/LAVM/test/md.md
@@ -0,0 +1,252 @@
+# TODOs:
+#
+# focusing on:
+# - implement the "new" syntax for the target description
+# - allow nested ops on the left side: lavm.sub(tgt.neg(%a), %b) : lavm.add(%a, tgt.neg(%b));
+# - support operations with the same name but different numbers of arguments
+# - error on wrong number of arguments in map section and in target sections
+# - catch duplicate argument names in the target section
+# - catch mismatch in the argument count and the type in the target section
+#
+# functionality:
+# -- define memory description section
+# -- allow [parallel] loops / ops with regions?
+# -- generate machine description C++ class
+# -- generate patterns for structured op tiling
+# - generate td file with the target dialect?
+# - allow sequences in rhs of map section: (op0(args...), op1(args), ...) so that we can chain buffer-level ops
+# - have type containers? (as in vreg, sreg, etc.)
+# - allow immediates
+# - support naming: lavm.add(%a, %b) : tgt.add(tgt.neg(%a):$n, tgt.neg($n)));
+# - allow instantiations for specific types only: lavm.add(%a : f32, %b : f32) -> f32 : zzz;
+# - how to express instructions operating on memory directly?
+# - allow arbitrary variable names and various parameter counts(?)
+# - support default values for arguments: lavm.add(%a, %b = 42)...
+# - detect infinite cycles in expansions
+# - how to represent denorm/NaN/rounding/saturating behavior?
+# - add support for AnyType, which is helpful to define MEM <-> AnyType loads/stores and op(%a) : %a empty expansions
+# - ??? change syntax to (lavm.matmul %dst %a %b) : (lavm.store (tgt.matmul (lavm.load %a) (lavm.load %b)) %dst);
+# - !!! change syntax to be target op-centric instead of lavm-op centric (define target ops rather than mappings
+# from lavm ops)
+# - allow lavm ops to drop "lavm." prefix?
+# - add pre-checkin tests
+# - re-implement the target description parser
+#
+# error checking:
+# - emit error at parsing *and stop* when ';' is missing at the end of the line
+# (this is a frequent cause of the late assert 'GetTargetExpansions(): lavm_op != nullptr')
+# - target section: remove variable names in lhs? as in tgt.nuts(%, %) or just get rid of the args on the lhs?
+# - catch duplicate expansions or anything else duplicated on lhs or rhs, issue warning and ignore
+# - filter out duplicates that differ only in variable names
+#
+# other:
+# - rename 'map' section into 'expansions'?
+
+### Operations and types supported by the target. ###
+target {
+ tgt.nuts(%x, %y) : (aaa, bbb) -> ccc, (a, b, c) -> (x, y, z), (f32, vector<128xf32>) -> ();
+ tgt.nuts(%x, %y) : ((m) -> p) -> q;
+ tgt.nuts(%x, %y) : ((m) -> ((p) -> bar)) -> q;
+ tgt.nuts(%x, %y) : ((m) -> ((p) -> (bar) -> foo, (p) -> ((bar) -> foo), ((p) -> (bar)) -> foo)) -> ((q) -> r);
+ tgt.nuts(%x, %y) : (m, ((n) -> ()) -> p) -> ((q) -> r);
+
+ tgt.i2f(%x) : (i32) -> f32, (vector<8x128xi32>) -> vector<8x128xf32>;
+ tgt.f2i(%x) : (f32) -> i32, (vector<8x128xf32>) -> vector<8x128xi32>;
+ tgt.neg(%x) : vector<8x128xf32>, f32; # vector<128xf32>
+
+ tgt.add(%x, %y) : vector<128xf32> @ 2, vector<8x128xf32> @ 3, f32 @ 1;
+ tgt.sub(%x, %y) : vector<128xf32>, vector<8x128xf32>, f32;
+ tgt.mul(%x, %y) : f32;
+
+ tgt.dot(%x, %y) : (vector<128xf32>, vector<128xf32>) -> vector<128xf32> @ 42;
+ tgt.matmul(%x, %y) : (vector<128x128xf32>, vector<128x128xf32>) -> vector<128x128xf32> @ 99;
+ tgt.wait() : () -> ();
+
+ tgt.dma(%to, %from, %size) : (memref<HBM>, memref<VMEM>, i32) -> ();
+ tgt.dma(%to, %from, %size) : (memref<VMEM>, memref<HBM>, i32) -> ();
+ tgt.dma(%to, %from, %size) : (memref<HBM>, memref<SMEM>, i32) -> ();
+ tgt.dma(%to, %from, %size) : (memref<SMEM>, memref<HBM>, i32) -> ();
+ tgt.load(%address) : (memref<SMEM>) -> f32, (memref<SMEM>) -> i32;
+ tgt.load(%address) : (memref<VMEM>) -> vector<128xf32>, (memref<VMEM>) -> vector<128xi32>;
+ tgt.load(%address) : (memref<VMEM>) -> vector<8x128xf32>, (memref<VMEM>) -> vector<8x128xi32>;
+ tgt.store(%value, %address) : (f32, memref<SMEM>) -> (), (i32, memref<SMEM>) -> ();
+ tgt.store(%value, %address) : (vector<128xf32>, memref<VMEM>) -> (), (vector<128xi32>, memref<VMEM>) -> ();
+ tgt.store(%value, %address) : (vector<8x128xf32>, memref<VMEM>) -> (), (vector<8x128xi32>, memref<VMEM>) -> ();
+
+ # pseudo.combine(%x) : f32 -> vector<128xf32>;
+ # pseudo.extract(%x) : vector<128xf32> -> f32;
+ # pseudo.combine(%x) : vector<128xf32> -> vector<8x128xf32>;
+ # pseudo.extract(%x) : vector<8x128xf32> -> vector<128xf32>;
+ # pseudo.combine(%x) : vector<8x128xf32> -> vector<128x128xf32>;
+ # pseudo.extract(%x) : vector<128x128xf32> -> vector<8x128xf32>;
+
+ # pseudo.combine_f32_to_128xf32(%x) : f32 -> vector<128xf32>;
+ # pseudo.combine_f32_to_8x128xf32(%x) : f32 -> vector<8x128xf32>;
+ # pseudo.combine_f32_to_128x128xf32(%x) : f32 -> vector<128x128xf32>;
+ # pseudo.combine_128xf32_to_8x128xf32(%x) : vector<128xf32> -> vector<8x128xf32>;
+ # pseudo.combine_128xf32_to_128x128xf32(%x) : vector<128xf32> -> vector<128x128xf32>;
+ # pseudo.combine_8x128xf32_to_128x128xf32(%x) : vector<8x128xf32> -> vector<128x128xf32>;
+ # pseudo.extract_f32_from_128xf32(%x) : vector<128xf32> -> f32;
+ # pseudo.extract_f32_from_8x128xf32(%x) : vector<8x128xf32> -> f32;
+ # pseudo.extract_f32_from_128x128xf32(%x) : vector<128x128xf32> -> f32;
+ # pseudo.extract_128xf32_from_8x128xf32(%x) : vector<8x128xf32> -> vector<128xf32>;
+ # pseudo.extract_128xf32_from_128x128xf32(%x) : vector<128x128xf32> -> vector<128xf32>;
+ # pseudo.extract_8x128xf32_from_128x128xf32(%x) : vector<128x128xf32> -> vector<8x128xf32>;
+}
+
+### Map LAVM ops to previously defined target ops or their combinations. ###
+# The current restriction is that lavm ops on the lhs should appear with the same arg names and count.
+map {
+ lavm.neg(%a) : tgt.neg(%a);
+ lavm.neg(%a) : tgt.neg(tgt.i2f(%a));
+
+ # not yet supported
+ # lavm.none(%a) : %a;
+ # lavm.zadd(%a, %b) : tgt.add(lavm.none(%a), %b);
+
+ lavm.add(%a, %b) : tgt.add(%a, %b),
+ tgt.sub(%a, tgt.neg(%b)),
+ tgt.add(tgt.i2f(%a), tgt.i2f(%b)),
+ tgt.add(tgt.i2f(%b), %a), tgt.add(%b, tgt.i2f(%a)),
+ tgt.add(%b, tgt.add(tgt.add(%a, %b), tgt.neg(%b)));
+ lavm.add(%a, %b) : lavm.neg(tgt.add(lavm.neg(%a), lavm.neg(%b))),
+ lavm.neg(tgt.add(lavm.neg(tgt.f2i(%b)), tgt.neg(%a)));
+
+ lavm.sub(%a, %b) : lavm.add(%a, tgt.neg(%b));
+
+ # lavm.add_x(%dst, %a, %b) : # tgt.dma(
+ # tgt.store(tgt.add(tgt.load(tgt.dma(%a)),
+ # tgt.load(tgt.dma(%b))),
+ # %dst)
+ # # )
+ # ;
+ # lavm.matmul(%dst : memref<HBM>, %a : memref<HBM>, %b : memref<HBM>) :
+ # %aa = alloc() : memref<VMEM>,
+ # %bb = alloc() : memref<VMEM>,
+ # %cc = alloc() : memref<VMEM>,
+ # tgt.dma(%a, %aa),
+ # tgt.dma(%b, %bb),
+ # lavm.store(lavm.matmul(lavm.load(%aa), lavm.load(%bb)), %cc)
+ # tgt.dma(%cc, %dst); # dealloc???
+
+ lavm.dma(%to, %from, %size) : tgt.dma(%to, %from, %size);
+
+ lavm.load(%address) : tgt.load(%address);
+ # lavm.load(%address) : lavm.combine(tgt.load(%address));
+ # lavm.load(%address) : lavm.extract(tgt.load(%address));
+ lavm.store(%value, %address) : tgt.store(%value, %address);
+
+ # want:
+ lavm.matmul(%dst, %a, %b) : lavm.store(tgt.matmul(lavm.load(%a),
+ lavm.load(%b)),
+ %dst);
+
+ # ...no: get rid of these
+ # lavm.combine(%a) : pseudo.combine(%a);
+ # lavm.extract(%a) : pseudo.extract(%a);
+ # lavm.combine(%a) : pseudo.combine(pseudo.combine(%a));
+ # lavm.extract(%a) : pseudo.extract(pseudo.extract(%a));
+ # lavm.combine(%a) : pseudo.combine(pseudo.combine(pseudo.combine(%a)));
+ # lavm.extract(%a) : pseudo.extract(pseudo.extract(pseudo.extract(%a)));
+ #
+ # lavm.matmulY(%dst, %a, %b) : lavm.store(lavm.extract(
+ # tgt.matmul(lavm.combine(lavm.load(%a)),
+ # lavm.combine(lavm.load(%b)))),
+ # %dst);
+
+ # ... also no
+ # lavm.combineX(%a) : pseudo.combine_f32_to_128xf32(%a);
+ # lavm.combineX(%a) : pseudo.combine_f32_to_8x128xf32(%a);
+ # lavm.combineX(%a) : pseudo.combine_f32_to_128x128xf32(%a);
+ # lavm.combineX(%a) : pseudo.combine_128xf32_to_8x128xf32(%a);
+ # lavm.combineX(%a) : pseudo.combine_128xf32_to_128x128xf32(%a);
+ # lavm.combineX(%a) : pseudo.combine_8x128xf32_to_128x128xf32(%a);
+ # lavm.extractX(%a) : pseudo.extract_f32_from_128xf32(%a);
+ # lavm.extractX(%a) : pseudo.extract_f32_from_8x128xf32(%a);
+ # lavm.extractX(%a) : pseudo.extract_f32_from_128x128xf32(%a);
+ # lavm.extractX(%a) : pseudo.extract_128xf32_from_8x128xf32(%a);
+ # lavm.extractX(%a) : pseudo.extract_128xf32_from_128x128xf32(%a);
+ # lavm.extractX(%a) : pseudo.extract_8x128xf32_from_128x128xf32(%a);
+ #
+ # lavm.matmulX(%dst, %a, %b) : lavm.store(lavm.extractX(
+ # tgt.matmul(lavm.combineX(lavm.load(%a)),
+ # lavm.combineX(lavm.load(%b)))),
+ # %dst);
+
+ # lavm.add(%a, %b) : tgt.add(%a, %a); # should cause a warning that %b is unused (implemented in wrong place, too late)
+ # lavm.add(%a, %b) : tgt.add(lavm.neg(%a), lavm.neg(%a)); # should cause an earlier warning
+ # lavm.add(%a, %b) : tgt.boo(%a, %b); # should be an error because tgt.boo is undefined
+ # lavm.add(%a, %b) : tgt.add(tgt.boo(%a), %b); # should be an error: undefined tgt.boo (implemented in wrong place, too late)
+ # lavm.add(%a, %b) : tgt.add(%a, %c); # this is an error! should be caught (NYI)
+}
+
+### Target memory description. ###
+memory {
+
+ # TPU-like memory:
+ HBM : size = 16G;
+ HBM : garbage = ignored;
+ VMEM : size = 16M;
+ SMEM : size = 16K;
+ CMEM : size = 16M;
+ HBM -> VMEM;
+ VMEM -> HBM;
+ HBM -> SMEM;
+ SMEM -> HBM;
+ HBM -> CMEM -> VMEM;
+ VMEM -> CMEM -> HBM;
+ # GPU-like memory:
+ GLOBAL : size = 8G;
+ SHARED : size = 16M;
+ LOCAL : size = 1M;
+ # CPU-like memory:
+ MEMORY : size = 64GB;
+ $L1 : size = 512K;
+ $L2 : size = 4M;
+ MEMORY -> $L2 -> $L1;
+ $L1 -> $L2 -> MEMORY;
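+ # (The '->' lines above list the supported data-movement paths between memory
+ # spaces; see the dma/load/store notes below.)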
+
+ # specifying
+ # - on the mem side: specify what is the address, what is the value, stride, banks
+ # - on the register(?) side: size, stride, zero/sign-extended, mask, whether it can partially fill the target reg
+ # - load/store latency, size/granularity
+ # - dma latency, size/granularity
+ # - cache latency, size/granularity, policy, associativity, on which path
+ # some properties are the memory instructions' own properties; we can use attributes to represent them
+ #
+ # Interface: cache levels/sizes/latencies/lane sizes/policies
+ # memory spaces, dmas between memory spaces, sizes, etc.
+ #
+ # how to specify direct access to memory from other instructions? using implicit addresses? chained instructions?
+ # for ex. a chain of instructions that process in-memory data, each modifying the implicit pointer
+ # used in the next instruction.
+ #
+ # how to represent memory local to PEs?
+ # how to represent in-memory calculations?
+
+ # name [-> $name -> ...] -> 'reg : load instructions; # do not use register?
+ # 'reg [-> $name -> ...] -> name : store instructions;
+ # name -> name : dma instructions;
+
+ # TPU-like memory:
+ # SMEM -> 'reg : sload(%address);
+ # 'reg -> SMEM : sstore(%value, %address);
+ # VMEM -> 'reg : vload(%address); # stride, mask
+ # 'reg -> VMEM : vstore(%value, %address); # stride, mask
+ # HBM -> VMEM : dma(%from, %to, %size) size0^latency0, size1^latency1; # syntax?
+ # VMEM -> HBM : dma();
+ # HBM -> SMEM : dma();
+ # SMEM -> HBM : dma();
+
+ # GPU-like memory:
+ # GLOBAL -> 'reg : ld(%address);
+ # SHARED -> 'reg : lds(%address);
+ # LOCAL -> 'reg : ldl(%address);
+ # 'reg -> GLOBAL : st(%value, %address);
+ # 'reg -> SHARED : sts(%value, %address);
+ # 'reg -> LOCAL : stl(%value, %address);
+
+ # CPU-like memory:
+ # MEMORY -> $L2 -> $L1 -> 'reg : load(%address);
+ # 'reg -> $L1 -> $L2 -> MEMORY : store(%value, %address);
+}
diff --git a/LAVM/test/md_new_syntax.md b/LAVM/test/md_new_syntax.md
new file mode 100644
index 0000000..99118ae
--- /dev/null
+++ b/LAVM/test/md_new_syntax.md
@@ -0,0 +1,36 @@
+# To resolve:
+# ? how to define the container (storage) meanings (GPR, i32imm, etc.)
+# ? how to define the result types
+# ? whether need to distinguish between rr, ri, rs instruction variations
+#
+
+md {
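+ # Each def gives the assembly mnemonic, the outs/ins operand lists, the selection
+ # pattern for the corresponding lavm op, and (presumably) the matching LLVM
+ # backend instruction via 'llvm ='.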
+ def tgt.halt {
+ mnemonics = "_ = shalt${pred}";
+ outs = ();
+ ins = ();
+ pattern = (lavm.halt);
+ llvm = HALT;
+ }
+ def tgt.addrr {
+ mnemonics = "$d = sadd.s32${pred} $x, $y";
+ outs = (GPR:$d);
+ ins = (GPR:$x, GPR:$y);
+ pattern = (set (i32 GPR:$d), (lavm.add (i32 GPR:$x), (i32 GPR:$y)));
+ llvm = ADDrr;
+ }
+ def tgt.addri {
+ mnemonics = "$d = sadd.s32${pred} $x, $y";
+ outs = (GPR:$d);
+ ins = (GPR:$x, i32imm:$y);
+ pattern = (set (i32 GPR:$d), (lavm.add (i32 GPR:$x), (i32 imm:$y)));
+ llvm = ADDri;
+ }
+ def tgt.faddrr {
+ mnemonics = "$d = sadd.f32${pred} $x, $y";
+ outs = (GPR:$d);
+ ins = (GPR:$x, GPR:$y);
+ pattern = (set (f32 GPR:$d), (lavm.add (f32 GPR:$x), (f32 GPR:$y)));
+ llvm = FADDrr;
+ }
+}
diff --git a/LAVM/test/mt.mlir b/LAVM/test/mt.mlir
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/LAVM/test/mt.mlir
diff --git a/LAVM/test/simple.md b/LAVM/test/simple.md
new file mode 100644
index 0000000..e3d260b
--- /dev/null
+++ b/LAVM/test/simple.md
@@ -0,0 +1,10 @@
+target {
+ tgt.neg(%x) : f32;
+ tgt.add(%x, %y) : f32, i32;
+ tgt.i2f(%x) : (i32) -> f32;
+}
+
+map {
+ lavm.neg(%a) : tgt.neg(%a), tgt.neg(tgt.i2f(%a));
+ lavm.sub(%a, %b) : tgt.add(%a, lavm.neg(%b));
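+ # e.g. lavm.sub(%a, %b) expands to tgt.add(%a, tgt.neg(%b)) or
+ # tgt.add(%a, tgt.neg(tgt.i2f(%b))) once lavm.neg is itself expanded.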
+}
diff --git a/LAVM/test/xyz.md b/LAVM/test/xyz.md
new file mode 100644
index 0000000..c31cebe
--- /dev/null
+++ b/LAVM/test/xyz.md
@@ -0,0 +1,16 @@
+target {
+ pseudo.combine(%x) : (2x8) -> 8x8;
+ pseudo.extract(%x) : (8x8) -> 2x8;
+ tgt.load(%addr) : (mem) -> 2x8;
+ tgt.store(%val, %addr) : (2x8, mem) -> ();
+ tgt.xyz(%x, %y) : 8x8;
+}
+
+map {
+ lavm.load(%addr) : tgt.load(%addr);
+ lavm.store(%val, %addr) : tgt.store(%val, %addr);
+ lavm.xyz(%d, %a, %b) : lavm.store(pseudo.extract(
+ tgt.xyz(pseudo.combine(lavm.load(%a)),
+ pseudo.combine(lavm.load(%b)))),
+ %d);
+}