blob: 8ed6fae61077c03810ec633ec37b93cdf9223f91 [file] [log] [blame] [edit]
#include <string>
#include "base/googleinit.h"
#include "learning/deepmind/partir/compiler/rewrites/custom_call_registry.h"
#include "platforms/xla/mosaic/custom_call_kernel_name.h"
#include "third_party/absl/log/check.h"
#include "third_party/absl/status/statusor.h"
#include "third_party/llvm/llvm-project/mlir/include/mlir/IR/Operation.h"
namespace {
// Add to the partir custom call registry a callback that extracts the name of
// the kernel from a tpu_custom_call call. This registration is performed at
// InitGoogle time, this works because the partir custom call registry is a
// static local variable.
REGISTER_MODULE_INITIALIZER(mosaic_partir_name_callback, {
CHECK_OK(deepmind::partir::CustomCallRegistry::RegisterOpNameCallback(
"tpu_custom_call",
[](mlir::Operation* op) -> absl::StatusOr<std::string> {
return xla::mosaic::CustomCallKernelName(op);
}));
});
} // namespace