Adds a model of the Ibex HW invoker.

PiperOrigin-RevId: 709106784
Change-Id: Ie7d0ebf1156a0fa805908313f14b5ae8205986e0
diff --git a/cheriot/BUILD b/cheriot/BUILD
index 6bb599d..f4717c8 100644
--- a/cheriot/BUILD
+++ b/cheriot/BUILD
@@ -639,6 +639,27 @@
 )
 
 cc_library(
+    name = "cheriot_ibex_hw_revoker",
+    srcs = [
+        "cheriot_ibex_hw_revoker.cc",
+    ],
+    hdrs = [
+        "cheriot_ibex_hw_revoker.h",
+    ],
+    deps = [
+        ":cheriot_state",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/strings",
+        "@com_google_mpact-riscv//riscv:riscv_plic",
+        "@com_google_mpact-sim//mpact/sim/generic:core",
+        "@com_google_mpact-sim//mpact/sim/generic:counters",
+        "@com_google_mpact-sim//mpact/sim/generic:instruction",
+        "@com_google_mpact-sim//mpact/sim/util/memory",
+    ],
+)
+
+cc_library(
     name = "instrumentation",
     srcs = [
         "cheriot_instrumentation_control.cc",
diff --git a/cheriot/cheriot_ibex_hw_revoker.cc b/cheriot/cheriot_ibex_hw_revoker.cc
new file mode 100644
index 0000000..4df1238
--- /dev/null
+++ b/cheriot/cheriot_ibex_hw_revoker.cc
@@ -0,0 +1,324 @@
+// Copyright 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cheriot/cheriot_ibex_hw_revoker.h"
+
+#include <cstdint>
+#include <cstring>
+
+#include "absl/log/log.h"
+#include "cheriot/cheriot_register.h"
+#include "mpact/sim/generic/data_buffer.h"
+#include "mpact/sim/generic/instruction.h"
+#include "mpact/sim/generic/ref_count.h"
+#include "mpact/sim/util/memory/memory_interface.h"
+#include "mpact/sim/util/memory/tagged_memory_interface.h"
+#include "riscv//riscv_plic.h"
+
+namespace mpact {
+namespace sim {
+namespace cheriot {
+
+using ::mpact::sim::generic::DataBuffer;
+using ::mpact::sim::generic::Instruction;
+using ::mpact::sim::generic::ReferenceCount;
+using ::mpact::sim::riscv::RiscVPlicIrqInterface;
+using ::mpact::sim::util::MemoryInterface;
+using ::mpact::sim::util::TaggedMemoryInterface;
+
+CheriotIbexHWRevoker::CheriotIbexHWRevoker(RiscVPlicIrqInterface *plic_irq,
+                                           uint64_t heap_base,
+                                           uint64_t heap_size,
+                                           TaggedMemoryInterface *heap_memory,
+                                           uint64_t revocation_bits_base,
+                                           MemoryInterface *revocation_memory)
+    : plic_irq_(plic_irq),
+      heap_base_(heap_base),
+      heap_max_(heap_base + heap_size),
+      heap_memory_(heap_memory),
+      revocation_memory_(revocation_memory),
+      revocation_bits_base_(revocation_bits_base) {
+  cap_reg_ = new CheriotRegister(nullptr, "filter_cap");
+  db_ = db_factory_.Allocate<uint32_t>(2);
+  db_->Set<uint32_t>(0, 0);
+  db_->Set<uint32_t>(1, 0);
+  db_->set_latency(0);
+  cap_reg_->SetDataBuffer(db_);
+  db_->DecRef();
+  // Allocate data buffers used in loads/stores.
+  db_ = db_factory_.Allocate<uint8_t>(CheriotRegister::kCapabilitySizeInBytes);
+  db_->set_latency(0);
+  tag_db_ = db_factory_.Allocate<uint8_t>(1);
+  tag_db_->set_latency(0);
+  Reset();
+}
+
+CheriotIbexHWRevoker::CheriotIbexHWRevoker(uint64_t heap_base,
+                                           uint64_t heap_size,
+                                           TaggedMemoryInterface *heap_memory,
+                                           uint64_t revocation_bits_base,
+                                           MemoryInterface *revocation_memory)
+    : CheriotIbexHWRevoker(nullptr, heap_base, heap_size, heap_memory,
+                           revocation_bits_base, revocation_memory) {}
+
+CheriotIbexHWRevoker::~CheriotIbexHWRevoker() {
+  delete cap_reg_;
+  db_->DecRef();
+  tag_db_->DecRef();
+}
+
+// Reset state to initial values.
+void CheriotIbexHWRevoker::Reset() {
+  num_calls_ = 0;
+  start_address_ = 0;
+  end_address_ = 0;
+  go_ = 0;
+  epoch_ = 0;
+  interrupt_enable_ = 0;
+  interrupt_status_ = 0;
+}
+
+// This is called by the counter using the CounterValueSetInterface interface.
+void CheriotIbexHWRevoker::SetValue(const uint64_t &val) {
+  if (interrupt_status_) SetInterrupt(false);
+  if (!sweep_in_progress_) return;
+  num_calls_++;
+  if (num_calls_ >= period_) {
+    num_calls_ = 0;
+    Revoke();
+  }
+}
+
+// Reads from the MMRs.
+void CheriotIbexHWRevoker::Load(uint64_t address, DataBuffer *db,
+                                DataBuffer *tags, Instruction *inst,
+                                ReferenceCount *context) {
+  if (tags != nullptr) std::memset(tags->raw_ptr(), 0, tags->size<uint8_t>());
+  Load(address, db, inst, context);
+}
+
+// Reads from the MMRs.
+void CheriotIbexHWRevoker::Load(uint64_t address, DataBuffer *db,
+                                Instruction *inst, ReferenceCount *context) {
+  uint32_t offset = address & 0xffff;
+  switch (db->size<uint8_t>()) {
+    case 1:
+      db->Set<uint32_t>(0, static_cast<uint8_t>(Read(offset)));
+      break;
+    case 2:
+      db->Set<uint32_t>(0, static_cast<uint16_t>(Read(offset)));
+      break;
+    case 4:
+      db->Set<uint32_t>(0, static_cast<uint32_t>(Read(offset)));
+      break;
+    case 8:
+      db->Set<uint32_t>(0, static_cast<uint64_t>(Read(offset)));
+      break;
+    default:
+      ::memset(db->raw_ptr(), 0, sizeof(db->size<uint8_t>()));
+      break;
+  }
+  // Execute the instruction to process and write back the load data.
+  if (nullptr != inst) {
+    if (db->latency() > 0) {
+      inst->IncRef();
+      if (context != nullptr) context->IncRef();
+      inst->state()->function_delay_line()->Add(db->latency(),
+                                                [inst, context]() {
+                                                  inst->Execute(context);
+                                                  if (context != nullptr)
+                                                    context->DecRef();
+                                                  inst->DecRef();
+                                                });
+    } else {
+      inst->Execute(context);
+    }
+  }
+}
+
+// Vector load is not supported.
+void CheriotIbexHWRevoker::Load(DataBuffer *address_db, DataBuffer *mask_db,
+                                int el_size, DataBuffer *db, Instruction *inst,
+                                ReferenceCount *context) {
+  // This is left unimplemented. Vector load is not supported.
+  LOG(FATAL) << "Vector load not supported";  // Crash OK
+}
+
+// Writes to the MMRs.
+void CheriotIbexHWRevoker::Store(uint64_t address, DataBuffer *db,
+                                 DataBuffer *tags) {
+  Store(address, db);
+}
+
+// Writes to the MMRs.
+void CheriotIbexHWRevoker::Store(uint64_t address, DataBuffer *db) {
+  uint32_t offset = address & 0xffff;
+  switch (db->size<uint8_t>()) {
+    case 1:
+      return Write(offset, static_cast<uint32_t>(db->Get<uint8_t>(0)));
+    case 2:
+      return Write(offset, static_cast<uint32_t>(db->Get<uint16_t>(0)));
+    case 4:
+      return Write(offset, static_cast<uint32_t>(db->Get<uint32_t>(0)));
+    case 8:
+      return Write(offset, static_cast<uint32_t>(db->Get<uint32_t>(0)));
+      return Write(offset + 4, static_cast<uint32_t>(db->Get<uint32_t>(1)));
+    default:
+      return;
+  }
+}
+
+// Vector accesses are not supported.
+void CheriotIbexHWRevoker::Store(DataBuffer *address, DataBuffer *mask,
+                                 int el_size, DataBuffer *db) {
+  // This is left unimplemented. Vector store is not supported.
+  LOG(FATAL) << "Vector store not supported";  // Crash OK
+}
+
+uint32_t CheriotIbexHWRevoker::Read(uint32_t offset) {
+  uint32_t value = 0;
+  switch (offset) {
+    case kStartAddressOffset:  // start address
+      value = start_address_;
+      break;
+    case kEndAddressOffset:  // end address
+      value = end_address_;
+      break;
+    case kGoOffset:  // go
+      value = 0x5500'0000 | (go_ & 0x00ff'ffff);
+      break;
+    case kEpochOffset:  // epoch
+      value = (epoch_ << 1) | (sweep_in_progress_ ? 0b1 : 0b0);
+      break;
+    case kStatusOffset:  // stat q
+      value = interrupt_enable_ ? interrupt_status_ : 0;
+      break;
+    case kInterruptEnableOffset:  // interrupt enable
+      value = interrupt_enable_ & 0b1;
+      break;
+    default:
+      value = 0;
+      break;
+  }
+  return value;
+}
+
+// Vector store is not supported.
+void CheriotIbexHWRevoker::Write(uint32_t offset, uint32_t value) {
+  switch (offset) {
+    case kStartAddressOffset:  // start address
+      start_address_ = value;
+      break;
+    case kEndAddressOffset:  // end address
+      end_address_ = value;
+      break;
+    case kGoOffset:  // go
+      WriteGo();
+      go_ = value;
+      break;
+    // 0x000c:  Epoch is not writable.
+    case kStatusOffset:
+      interrupt_status_ = 0;
+      plic_irq_->SetIrq(false);
+      break;
+    case kInterruptEnableOffset:  // interrupt enable
+      interrupt_enable_ = value & 0b1;
+      ;
+      break;
+    default:
+      return;
+  }
+}
+
+void CheriotIbexHWRevoker::WriteGo() {
+  if (sweep_in_progress_) return;
+  sweep_in_progress_ = true;
+  current_cap_ = 0;
+  num_calls_ = 0;
+  epoch_ = 0;
+}
+
+void CheriotIbexHWRevoker::Revoke() {
+  if (!sweep_in_progress_) return;
+  // Increment the epoch.
+  epoch_++;
+  uint64_t cap_address = start_address_ + (current_cap_++ << 3);
+  // Align address to the capability size.
+  cap_address &= ~0b111ULL;
+  ProcessCapability(cap_address);
+  // Check to see if we have reached the end of the region.
+  if (cap_address >= end_address_) {
+    sweep_in_progress_ = false;
+    SetInterrupt(true);
+  }
+}
+
+// Process the capability at the given address.
+void CheriotIbexHWRevoker::ProcessCapability(uint64_t address) {
+  if ((address < start_address_) || (address >= end_address_)) return;
+  // Load the capability.
+  heap_memory_->Load(address, db_, tag_db_, nullptr, nullptr);
+  // If the tag is 0, no need to go on.
+  auto tag = tag_db_->Get<uint8_t>(0);
+  if (tag == 0) return;
+
+  // Expand the capability. Return if the tag is not valid.
+  cap_reg_->Expand(db_->Get<uint32_t>(0), db_->Get<uint32_t>(1), tag);
+  if (!cap_reg_->tag()) return;
+
+  // Check for revocation.
+  if (!MustRevoke(cap_reg_->base())) return;
+
+  // Invalidate and store the capability back to memory.
+  cap_reg_->Invalidate();
+  db_->Set<uint32_t>(0, cap_reg_->address());
+  db_->Set<uint32_t>(1, cap_reg_->Compress());
+  tag_db_->Set<uint8_t>(0, cap_reg_->tag());
+  heap_memory_->Store(address, db_, tag_db_);
+}
+
+// Check if the capability must be revoked.
+bool CheriotIbexHWRevoker::MustRevoke(uint64_t address) {
+  if (address < heap_base_) return false;
+  if (address >= heap_max_) return false;
+  // Compute the address of the byte containing the revocation information.
+  uint64_t offset = address - heap_base_;
+  // Shift by 3 for the size of the capability (8), and by 3 for 8 bits in a
+  // byte.
+  uint64_t revocation_offset = offset >> 6;
+  revocation_memory_->Load(revocation_bits_base_ + revocation_offset, tag_db_,
+                           nullptr, nullptr);
+  // Get the revocation bit.
+  uint8_t revocation_bits = tag_db_->Get<uint8_t>(0);
+  int bit_offset = (offset >> 3) & 0b111;
+  bool result = revocation_bits & (1 << bit_offset);
+  return result;
+}
+
+void CheriotIbexHWRevoker::SetInterrupt(bool value) {
+  if (!value) {
+    plic_irq_->SetIrq(false);
+    interrupt_status_ = 0;
+    return;
+  }
+  interrupt_status_ = static_cast<uint32_t>(value);
+
+  if (!interrupt_enable_) return;
+
+  if (plic_irq_ != nullptr) plic_irq_->SetIrq(value);
+}
+
+}  // namespace cheriot
+}  // namespace sim
+}  // namespace mpact
diff --git a/cheriot/cheriot_ibex_hw_revoker.h b/cheriot/cheriot_ibex_hw_revoker.h
new file mode 100644
index 0000000..2fd7f5d
--- /dev/null
+++ b/cheriot/cheriot_ibex_hw_revoker.h
@@ -0,0 +1,168 @@
+// Copyright 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MPACT_CHERIOT_CHERIOT_IBEX_HW_REVOKER_H_
+#define MPACT_CHERIOT_CHERIOT_IBEX_HW_REVOKER_H_
+
+#include <cstdint>
+
+#include "cheriot/cheriot_register.h"
+#include "mpact/sim/generic/counters_base.h"
+#include "mpact/sim/generic/data_buffer.h"
+#include "mpact/sim/generic/instruction.h"
+#include "mpact/sim/generic/ref_count.h"
+#include "mpact/sim/util/memory/memory_interface.h"
+#include "mpact/sim/util/memory/tagged_memory_interface.h"
+#include "riscv//riscv_plic.h"
+
+// This file contains the class declaration for the model of the Ibex HW
+// revoker for Cheriot. The HW revoker is a module that is used to invalidate
+// (or revoke the validity of) capabilities pointing to a freed portion of heap
+// memory. It is controlled by a set of memory mapped registers.
+//
+// The HW revoker is implemented as a counter value set object. It is bound
+// to a counter that is incremented whenever an instruction is executed, and,
+// when active, performs an action every 'period' number of increments
+// (configurable).
+//
+// The HW revoker is programmed using a memory interface. It supports non-vector
+// loads and stores only.
+//
+// The HW revoker is described in more detail in the following documents:
+// https://lowrisc.github.io/sonata-system/doc/ip/revoker.html
+// https://github.com/microsoft/cheriot-safe/blob/main/src/msft_cheri_subsystem/msftDvIp_mmreg.sv
+
+namespace mpact {
+namespace sim {
+namespace cheriot {
+
+using ::mpact::sim::generic::CounterValueSetInterface;
+using ::mpact::sim::generic::DataBuffer;
+using ::mpact::sim::generic::DataBufferFactory;
+using ::mpact::sim::generic::Instruction;
+using ::mpact::sim::generic::ReferenceCount;
+using ::mpact::sim::riscv::RiscVPlic;
+using ::mpact::sim::riscv::RiscVPlicIrqInterface;
+using ::mpact::sim::util::MemoryInterface;
+using ::mpact::sim::util::TaggedMemoryInterface;
+
+class CheriotIbexHWRevoker : public CounterValueSetInterface<uint64_t>,
+                             public TaggedMemoryInterface {
+ public:
+  static constexpr uint32_t kStartAddressOffset = 0x0000;
+  static constexpr uint32_t kEndAddressOffset = 0x0004;
+  static constexpr uint32_t kGoOffset = 0x0008;
+  static constexpr uint32_t kEpochOffset = 0x000c;
+  static constexpr uint32_t kStatusOffset = 0x0010;
+  static constexpr uint32_t kInterruptEnableOffset = 0x0014;
+
+  CheriotIbexHWRevoker() = delete;
+  CheriotIbexHWRevoker(const CheriotIbexHWRevoker &) = delete;
+  CheriotIbexHWRevoker &operator=(const CheriotIbexHWRevoker &) = delete;
+  CheriotIbexHWRevoker(RiscVPlicIrqInterface *plic_irq, uint64_t heap_base,
+                       uint64_t heap_size, TaggedMemoryInterface *heap_memory,
+                       uint64_t revocation_bits_base,
+                       MemoryInterface *revocation_memory);
+  CheriotIbexHWRevoker(uint64_t heap_base, uint64_t heap_size,
+                       TaggedMemoryInterface *heap_memory,
+                       uint64_t revocation_bits_base,
+                       MemoryInterface *revocation_memory);
+  ~CheriotIbexHWRevoker();
+  // Resets the interrupt controller.
+  void Reset();
+  // CounterValueSetInterface override. This is called when the value of the
+  // bound counter is modified.
+  void SetValue(const uint64_t &val) override;
+
+  // MemoryInterface overrides.
+  // Non-vector load method.
+  void Load(uint64_t address, DataBuffer *db, DataBuffer *tags,
+            Instruction *inst, ReferenceCount *context) override;
+  void Load(uint64_t address, DataBuffer *db, Instruction *inst,
+            ReferenceCount *context) override;
+  // Vector load method - this is stubbed out.
+  void Load(DataBuffer *address_db, DataBuffer *mask_db, int el_size,
+            DataBuffer *db, Instruction *inst,
+            ReferenceCount *context) override;
+  // Non-vector store method.
+  void Store(uint64_t address, DataBuffer *db, DataBuffer *tags) override;
+  void Store(uint64_t address, DataBuffer *dbs) override;
+  // Vector store method - this is stubbed out.
+  void Store(DataBuffer *address, DataBuffer *mask, int el_size,
+             DataBuffer *db) override;
+
+  // Getters & setters.
+  void set_plic_irq(RiscVPlicIrqInterface *plic_irq) { plic_irq_ = plic_irq; }
+  int period() const { return period_; }
+  void set_period(int period) { period_ = period; }
+  int cap_count() const { return cap_count_; }
+  void set_cap_count(int cap_count) { cap_count_ = cap_count; }
+  uint64_t revocation_bits_base() const { return revocation_bits_base_; }
+  void set_revocation_bits_base(uint64_t revocation_bits_base) {
+    revocation_bits_base_ = revocation_bits_base;
+  }
+
+ private:
+  // MMR read/write methods.
+  uint32_t Read(uint32_t offset);
+  void Write(uint32_t offset, uint32_t value);
+  void WriteGo();
+  // Perform an iteration of revocation.
+  void Revoke();
+  void ProcessCapability(uint64_t address);
+  bool MustRevoke(uint64_t address);
+  void SetInterrupt(bool value);
+  // The number of times SetValue is called before triggering a revocation
+  // operation.
+  int period_ = 1;
+  // Tracker for the number of times SetValue is called.
+  int num_calls_ = 0;
+  // The max number of capabilities to revoke in a single operation.
+  int cap_count_ = 0;
+  // Current capability index.
+  uint64_t current_cap_ = 0;
+  // Sweep in progress.
+  bool sweep_in_progress_ = false;
+  RiscVPlicIrqInterface *plic_irq_ = nullptr;
+  // Heap range.
+  uint64_t heap_base_ = 0;
+  uint64_t heap_max_ = 0;
+  // Memory interface to use for the tagged heap.
+  TaggedMemoryInterface *heap_memory_ = nullptr;
+  // Memory interface to use for the revocation bits.
+  MemoryInterface *revocation_memory_ = nullptr;
+  // Data buffers.
+  DataBuffer *db_ = nullptr;
+  DataBuffer *tag_db_ = nullptr;
+  // Capability register.
+  CheriotRegister *cap_reg_ = nullptr;
+  // Base address of the revocation bits.
+  uint64_t revocation_bits_base_ = 0;
+  // DB factory.
+  DataBufferFactory db_factory_;
+
+  // MMRs
+  uint64_t start_address_ = 0;
+  uint64_t end_address_ = 0;
+  uint32_t go_ = 0;
+  uint32_t epoch_ = 0;
+  uint32_t interrupt_enable_ = 0;
+  uint32_t interrupt_status_ = 0;
+};
+
+}  // namespace cheriot
+}  // namespace sim
+}  // namespace mpact
+
+#endif  // MPACT_CHERIOT_CHERIOT_IBEX_HW_REVOKER_H_
diff --git a/cheriot/cheriot_load_filter.cc b/cheriot/cheriot_load_filter.cc
index 8ef1ca9..f7f3439 100644
--- a/cheriot/cheriot_load_filter.cc
+++ b/cheriot/cheriot_load_filter.cc
@@ -89,7 +89,7 @@
   db_->Set<uint32_t>(0, cap_reg_->address());
   db_->Set<uint32_t>(1, cap_reg_->Compress());
   tag_db_->Set<uint8_t>(0, cap_reg_->tag());
-  tagged_memory_->Store(cap_address_, db_, tag_db_);
+  tagged_memory_->Store(address, db_, tag_db_);
 }
 
 // Check if the capability must be revoked.
diff --git a/cheriot/test/BUILD b/cheriot/test/BUILD
index 4ba3cbb..5cbf103 100644
--- a/cheriot/test/BUILD
+++ b/cheriot/test/BUILD
@@ -541,3 +541,20 @@
         "@com_google_mpact-sim//mpact/sim/generic:instruction",
     ],
 )
+
+cc_test(
+    name = "cheriot_ibex_hw_revoker_test",
+    size = "small",
+    srcs = [
+        "cheriot_ibex_hw_revoker_test.cc",
+    ],
+    deps = [
+        "//cheriot:cheriot_ibex_hw_revoker",
+        "//cheriot:cheriot_state",
+        "@com_google_googletest//:gtest_main",
+        "@com_google_mpact-riscv//riscv:riscv_plic",
+        "@com_google_mpact-sim//mpact/sim/generic:core",
+        "@com_google_mpact-sim//mpact/sim/generic:instruction",
+        "@com_google_mpact-sim//mpact/sim/util/memory",
+    ],
+)
diff --git a/cheriot/test/cheriot_ibex_hw_revoker_test.cc b/cheriot/test/cheriot_ibex_hw_revoker_test.cc
new file mode 100644
index 0000000..cb5c7f6
--- /dev/null
+++ b/cheriot/test/cheriot_ibex_hw_revoker_test.cc
@@ -0,0 +1,357 @@
+// Copyright 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cheriot/cheriot_ibex_hw_revoker.h"
+
+#include <sys/types.h>
+
+#include <cstdint>
+#include <cstring>
+
+#include "cheriot/cheriot_register.h"
+#include "googlemock/include/gmock/gmock.h"
+#include "mpact/sim/generic/data_buffer.h"
+#include "mpact/sim/generic/instruction.h"
+#include "mpact/sim/util/memory/flat_demand_memory.h"
+#include "mpact/sim/util/memory/tagged_flat_demand_memory.h"
+#include "mpact/sim/util/memory/tagged_memory_interface.h"
+#include "riscv//riscv_plic.h"
+
+// This file contains unit tests for the CheriotIbexHWRevoker class.
+
+namespace {
+
+using ::mpact::sim::cheriot::CheriotIbexHWRevoker;
+using ::mpact::sim::cheriot::CheriotRegister;
+using ::mpact::sim::generic::DataBuffer;
+using ::mpact::sim::generic::DataBufferFactory;
+using ::mpact::sim::generic::Instruction;
+using ::mpact::sim::generic::ReferenceCount;
+using ::mpact::sim::riscv::RiscVPlicIrqInterface;
+using ::mpact::sim::util::FlatDemandMemory;
+using ::mpact::sim::util::TaggedFlatDemandMemory;
+using ::mpact::sim::util::TaggedMemoryInterface;
+
+constexpr uint64_t kRevocationBase = 0x200'0000;
+constexpr uint64_t kHeapBase = 0x8001'0000;
+constexpr uint64_t kSweepBase = 0x8000'0000;
+
+// Mock plic source interface.
+class MockPlicSource : public RiscVPlicIrqInterface {
+ public:
+  MockPlicSource() = default;
+  ~MockPlicSource() override = default;
+  void SetIrq(bool irq_value) override { irq_value_ = irq_value; }
+
+  bool irq_value() const { return irq_value_; }
+  void set_irq_value(bool value) { irq_value_ = value; }
+
+ private:
+  bool irq_value_ = false;
+};
+
+// This class is used to capture memory load/store addresses.
+class MemoryViewer : public TaggedMemoryInterface {
+ public:
+  MemoryViewer() = delete;
+  MemoryViewer(TaggedMemoryInterface *memory) : memory_(memory) {}
+  ~MemoryViewer() override = default;
+
+  void Load(uint64_t address, DataBuffer *db, DataBuffer *tags,
+            Instruction *inst, ReferenceCount *context) override {
+    ld_address_ = address;
+    memory_->Load(address, db, tags, inst, context);
+  }
+  void Load(uint64_t address, DataBuffer *db, Instruction *inst,
+            ReferenceCount *context) override {
+    ld_address_ = address;
+    memory_->Load(address, db, inst, context);
+  }
+  void Load(DataBuffer *address_db, DataBuffer *mask_db, int el_size,
+            DataBuffer *db, Instruction *inst,
+            ReferenceCount *context) override {
+    ld_address_ = address_db->Get<uint64_t>(0);
+    memory_->Load(address_db, mask_db, el_size, db, inst, context);
+  }
+  void Store(uint64_t address, DataBuffer *db, DataBuffer *tags) override {
+    st_address_ = address;
+    memory_->Store(address, db, tags);
+  }
+  void Store(uint64_t address, DataBuffer *db) override {
+    st_address_ = address;
+    memory_->Store(address, db);
+  }
+  void Store(DataBuffer *address_db, DataBuffer *mask_db, int el_size,
+             DataBuffer *db) override {
+    st_address_ = address_db->Get<uint64_t>(0);
+    memory_->Store(address_db, mask_db, el_size, db);
+  }
+
+  uint64_t ld_address() const { return ld_address_; }
+  uint64_t st_address() const { return st_address_; }
+
+ private:
+  TaggedMemoryInterface *memory_ = nullptr;
+  uint64_t ld_address_ = 0;
+  uint64_t st_address_ = 0;
+};
+
+class CheriotIbexHwRevokerTest : public ::testing::Test {
+ protected:
+  CheriotIbexHwRevokerTest() {
+    db1_ = db_factory_.Allocate<uint8_t>(1);
+    db4_ = db_factory_.Allocate<uint32_t>(1);
+    db8_ = db_factory_.Allocate<uint32_t>(2);
+    db128_ = db_factory_.Allocate<uint8_t>(128);
+    plic_irq_ = new MockPlicSource();
+    heap_memory_ = new TaggedFlatDemandMemory(8);
+    memory_viewer_ = new MemoryViewer(heap_memory_);
+    revocation_memory_ = new FlatDemandMemory();
+    revoker_ =
+        new CheriotIbexHWRevoker(plic_irq_, kHeapBase, 0x8000, memory_viewer_,
+                                 kRevocationBase, revocation_memory_);
+    cap_reg_ = new CheriotRegister(nullptr, "cap");
+    cap_db_ = db_factory_.Allocate<uint32_t>(1);
+    cap_reg_->SetDataBuffer(cap_db_);
+  }
+
+  ~CheriotIbexHwRevokerTest() override {
+    db1_->DecRef();
+    db4_->DecRef();
+    db8_->DecRef();
+    db128_->DecRef();
+    cap_db_->DecRef();
+    delete plic_irq_;
+    delete revoker_;
+    delete heap_memory_;
+    delete memory_viewer_;
+    delete revocation_memory_;
+    delete cap_reg_;
+  }
+
+  // Call to advance the revoker.
+  void AdvanceRevoker() { revoker_->SetValue(0); }
+
+  // Convenience method to set the revocation bit for the given address.
+  void RevokeAddress(uint64_t address) {
+    if (address < kHeapBase) return;
+    uint64_t offset = address - kHeapBase;
+    offset >>= 3;
+    auto bit = offset & 0x7;
+    offset >>= 3;
+    revocation_memory_->Load(kRevocationBase + offset, db1_, nullptr, nullptr);
+    uint8_t val = db1_->Get<uint8_t>(0);
+    val |= 1 << bit;
+    db1_->Set<uint8_t>(0, val);
+    revocation_memory_->Store(kRevocationBase + offset, db1_);
+  }
+
+  // This clears the revocation bits for the memory range [kHeapBase, kHeapBase
+  // + 0x8000].
+  void ClearRevocationBits() {
+    std::memset(db128_->raw_ptr(), 0, db128_->size<uint8_t>());
+    uint64_t address = kRevocationBase;
+    for (uint64_t i = 0; i < 0x100; ++i) {
+      revocation_memory_->Store(address, db128_);
+      address += db128_->size<uint8_t>();
+    }
+  }
+
+  // The following methods are convenience methods for accessing the MMRs of
+  // the hw revoker using the revoker's memory interface.
+  void SetStartAddress(uint32_t address) {
+    db4_->Set<uint32_t>(0, address);
+    revoker_->Store(CheriotIbexHWRevoker::kStartAddressOffset, db4_);
+  }
+  uint32_t GetStartAddress() {
+    revoker_->Load(CheriotIbexHWRevoker::kStartAddressOffset, db4_, nullptr,
+                   nullptr);
+    return db4_->Get<uint32_t>(0);
+  }
+  void SetEndAddress(uint32_t address) {
+    db4_->Set<uint32_t>(0, address);
+    revoker_->Store(CheriotIbexHWRevoker::kEndAddressOffset, db4_);
+  }
+  uint32_t GetEndAddress() {
+    revoker_->Load(CheriotIbexHWRevoker::kEndAddressOffset, db4_, nullptr,
+                   nullptr);
+    return db4_->Get<uint32_t>(0);
+  }
+  void SetGo(uint32_t go) {
+    db4_->Set<uint32_t>(0, go);
+    revoker_->Store(CheriotIbexHWRevoker::kGoOffset, db4_);
+  }
+  uint32_t GetGo() {
+    revoker_->Load(CheriotIbexHWRevoker::kGoOffset, db4_, nullptr, nullptr);
+    return db4_->Get<uint32_t>(0);
+  }
+  void SetEpoch(uint32_t epoch) {
+    db4_->Set<uint32_t>(0, epoch);
+    revoker_->Store(CheriotIbexHWRevoker::kEpochOffset, db4_);
+  }
+  uint32_t GetEpoch() {
+    revoker_->Load(CheriotIbexHWRevoker::kEpochOffset, db4_, nullptr, nullptr);
+    return db4_->Get<uint32_t>(0);
+  }
+  void SetStatus(uint32_t status) {
+    db4_->Set<uint32_t>(0, status);
+    revoker_->Store(CheriotIbexHWRevoker::kStatusOffset, db4_);
+  }
+  uint32_t GetStatus() {
+    revoker_->Load(CheriotIbexHWRevoker::kStatusOffset, db4_, nullptr, nullptr);
+    return db4_->Get<uint32_t>(0);
+  }
+  void SetInterruptEnable(uint32_t enable) {
+    db4_->Set<uint32_t>(0, enable);
+    revoker_->Store(CheriotIbexHWRevoker::kInterruptEnableOffset, db4_);
+  }
+  uint32_t GetInterruptEnable() {
+    revoker_->Load(CheriotIbexHWRevoker::kInterruptEnableOffset, db4_, nullptr,
+                   nullptr);
+    return db4_->Get<uint32_t>(0);
+  }
+
+  // Convenience method to write a valid capability to memory with the given
+  // base.
+  void WriteCapability(uint64_t address, uint64_t base) {
+    cap_reg_->ResetMemoryRoot();
+    cap_reg_->SetAddress(base);
+    cap_reg_->SetBounds(base, 0x10);
+    db8_->Set<uint32_t>(0, cap_reg_->address());
+    db8_->Set<uint32_t>(1, cap_reg_->Compress());
+    db1_->Set<uint8_t>(0, true);
+    heap_memory_->Store(address, db8_, db1_);
+  }
+
+  // Convenience method to read a capability from memory with the given base.
+  CheriotRegister *ReadCapability(uint64_t address) {
+    heap_memory_->Load(address, db8_, db1_, nullptr, nullptr);
+    cap_reg_->Expand(db8_->Get<uint32_t>(0), db8_->Get<uint32_t>(1),
+                     db1_->Get<uint8_t>(0));
+    return cap_reg_;
+  }
+
+  uint64_t GetLoadAddress() { return memory_viewer_->ld_address(); }
+  uint64_t GetStoreAddress() { return memory_viewer_->st_address(); }
+  MockPlicSource *plic_irq() { return plic_irq_; }
+
+ private:
+  CheriotRegister *cap_reg_ = nullptr;
+  DataBufferFactory db_factory_;
+  DataBuffer *db1_;
+  DataBuffer *db4_;
+  DataBuffer *db8_;
+  DataBuffer *db128_;
+  DataBuffer *cap_db_;
+  CheriotIbexHWRevoker *revoker_ = nullptr;
+  MockPlicSource *plic_irq_ = nullptr;
+  TaggedFlatDemandMemory *heap_memory_ = nullptr;
+  FlatDemandMemory *revocation_memory_ = nullptr;
+  MemoryViewer *memory_viewer_ = nullptr;
+};
+
+// Initial state should all be clear.
+TEST_F(CheriotIbexHwRevokerTest, TestInitial) {
+  EXPECT_EQ(GetStartAddress(), 0);
+  EXPECT_EQ(GetEndAddress(), 0);
+  EXPECT_EQ(GetGo(), 0x5500'0000);
+  EXPECT_EQ(GetEpoch(), 0);
+  EXPECT_EQ(GetStatus(), 0);
+  EXPECT_EQ(GetInterruptEnable(), 0);
+}
+
+// No valid capabilities in the sweep range.
+TEST_F(CheriotIbexHwRevokerTest, RevokeNone) {
+  SetStartAddress(kSweepBase);
+  SetEndAddress(kSweepBase + 0x100);
+  SetGo(1);
+  EXPECT_EQ(GetGo(), 0x5500'0001);
+  // Expect zero status.
+  EXPECT_EQ(GetStatus(), 0);
+  // Expect sweep to be started.
+  EXPECT_EQ(GetEpoch(), 1);
+  // Step through 256/8 - 1 capabilities.
+  int num = 0x100 / 8;
+  for (int i = 0; i < num; ++i) {
+    AdvanceRevoker();
+    EXPECT_EQ(GetLoadAddress(), kSweepBase + (i << 3));
+    EXPECT_EQ(GetEpoch(), ((i + 1) << 1) | 1);
+    EXPECT_EQ(GetStatus(), 0);
+  }
+  // Step through the next capability. The sweep should be done.
+  AdvanceRevoker();
+  // Notice the in progress bit is cleared.
+  EXPECT_EQ(GetEpoch(), ((num + 1) << 1) | 0);
+  // Interrupt status should be 0, as interrupt enable is off.
+  EXPECT_EQ(GetStatus(), 0);
+}
+
+TEST_F(CheriotIbexHwRevokerTest, RevokeOne) {
+  // Write a capability at the sweep base.
+  for (auto offset = 0; offset < 0x100; offset += 0x8) {
+    WriteCapability(kSweepBase + offset, kHeapBase + offset);
+    auto *cap = ReadCapability(kSweepBase);
+    EXPECT_TRUE(cap->tag());
+  }
+  // Revoke one capability.
+  RevokeAddress(kHeapBase + 0x20);
+  // Set the sweep range to include the capability.
+  SetStartAddress(kSweepBase);
+  SetEndAddress(kSweepBase + 0x100);
+  SetGo(1);
+  // Expect zero status.
+  EXPECT_EQ(GetStatus(), 0);
+  // Expect sweep to be started.
+  EXPECT_EQ(GetEpoch(), 1);
+  // Step through the sweep.
+  while ((GetEpoch() & 0x1) == 1) {
+    AdvanceRevoker();
+  }
+  // Since interrupt enable is not set, the status should be zero.
+  EXPECT_EQ(GetStatus(), 0);
+  // Verify that only the one revoked capability was invalidated.
+  for (auto offset = 0; offset < 0x100; offset += 0x8) {
+    auto *cap = ReadCapability(kSweepBase + offset);
+    if (offset == 0x20) {
+      EXPECT_FALSE(cap->tag());
+    } else {
+      EXPECT_TRUE(cap->tag());
+    }
+  }
+}
+
+TEST_F(CheriotIbexHwRevokerTest, RevokeWithInterrupt) {
+  // Write a capability at the sweep base.
+  for (auto offset = 0; offset < 0x100; offset += 0x8) {
+    WriteCapability(kSweepBase + offset, kHeapBase + offset);
+    auto *cap = ReadCapability(kSweepBase);
+    EXPECT_TRUE(cap->tag());
+  }
+  // Revoke one capability.
+  RevokeAddress(kHeapBase + 0x20);
+  // Set the sweep range to include the capability.
+  SetStartAddress(kSweepBase);
+  SetEndAddress(kSweepBase + 0x100);
+  // Enable interrupt.
+  SetInterruptEnable(1);
+  SetGo(1);
+  while ((GetEpoch() & 0x1) == 1) {
+    AdvanceRevoker();
+  }
+  EXPECT_EQ(GetStatus(), 1);
+  // Verify that the interrupt was set.
+  EXPECT_TRUE(plic_irq()->irq_value());
+}
+
+}  // namespace