| #include <string> |
| |
| #include "base/googleinit.h" |
| #include "learning/deepmind/partir/compiler/rewrites/custom_call_registry.h" |
| #include "platforms/xla/mosaic/custom_call_kernel_name.h" |
| #include "third_party/absl/log/check.h" |
| #include "third_party/absl/status/statusor.h" |
| #include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h" |
| |
| namespace { |
| |
| // Add to the partir custom call registry a callback that extracts the name of |
| // the kernel from a tpu_custom_call call. This registration is performed at |
| // InitGoogle time, this works because the partir custom call registry is a |
| // static local variable. |
| REGISTER_MODULE_INITIALIZER(mosaic_partir_name_callback, { |
| CHECK_OK(deepmind::partir::CustomCallRegistry::RegisterOpNameCallback( |
| "tpu_custom_call", |
| [](mlir::Operation* op) -> absl::StatusOr<std::string> { |
| return xla::mosaic::CustomCallKernelName(op); |
| })); |
| }); |
| |
| } // namespace |