Compensate for instruction increment when writing to counter CSRs

- Derive RiscVPerformanceCounterCsr{High} from RiscVCounterCsr
- Decrement offset in Set to account for the retiring instruction.
- High::Set handles carry and propagates offset adjustments for high counter CSRs.
- Add unit tests for minstret and minstreth to verify write compensation.

PiperOrigin-RevId: 862918613
Change-Id: I04dc67e1137dc9d82402006411d91a6e3c894423
diff --git a/riscv/riscv_counter_csr.h b/riscv/riscv_counter_csr.h
index c20c269..f2a0a4f 100644
--- a/riscv/riscv_counter_csr.h
+++ b/riscv/riscv_counter_csr.h
@@ -59,6 +59,9 @@
   // This is called to tie a cycle counter to the CSR.
   void set_counter(SimpleCounter<uint64_t>* counter) { counter_ = counter; }
 
+ protected:
+  T offset_ = 0;
+
  private:
   inline T GetCounterValue() const {
     if (counter_ == nullptr) return 0;
@@ -66,7 +69,28 @@
   };
 
   SimpleCounter<uint64_t>* counter_ = nullptr;
-  T offset_ = 0;
+};
+
+// This class implements the performance counter CSRs, which have slightly
+// different behavior on write due to how they interact with the pipeline.
+template <typename T, typename S>
+class RiscVPerformanceCounterCsr : public RiscVCounterCsr<T, S> {
+ public:
+  RiscVPerformanceCounterCsr(std::string name, RiscVCsrEnum csr_enum, S* state)
+      : RiscVCounterCsr<T, S>(name, csr_enum, state) {}
+  RiscVPerformanceCounterCsr(const RiscVPerformanceCounterCsr&) = delete;
+  RiscVPerformanceCounterCsr& operator=(const RiscVPerformanceCounterCsr&) =
+      delete;
+  ~RiscVPerformanceCounterCsr() override = default;
+
+  void Set(uint32_t value) override {
+    RiscVCounterCsr<T, S>::Set(value);
+    --this->offset_;
+  }
+  void Set(uint64_t value) override {
+    RiscVCounterCsr<T, S>::Set(value);
+    --this->offset_;
+  }
 };
 
 // This class implements the "high" version of the CSR on 32-bit RiscV.
@@ -92,22 +116,63 @@
   // Any value written to the CSR is used to create an offset from the current
   // value of the counter.
   void Set(uint32_t value) override {
-    offset_ = value - (GetCounterValue() >> 32);
+    uint64_t counter_val = GetCounterValue();
+    uint32_t counter_low = static_cast<uint32_t>(counter_val);
+    uint32_t low_offset = low_csr_->offset_;
+    // Check for carry when reconstructing logical low value.
+    bool carry =
+        (static_cast<uint64_t>(counter_low) + low_offset) >= 0x100000000ULL;
+    offset_ = value - (counter_val >> 32) - (carry ? 1 : 0);
   };
   void Set(uint64_t value) override { Set(static_cast<uint32_t>(value)); };
 
   // This is called to tie a cycle counter to the CSR.
   void set_counter(SimpleCounter<uint64_t>* counter) { counter_ = counter; }
 
- private:
+ protected:
   inline uint64_t GetCounterValue() const {
     if (counter_ == nullptr) return 0;
     return counter_->GetValue();
   };
 
+  uint32_t get_low_offset() const { return low_csr_->offset_; }
+  void decr_low_offset() { low_csr_->offset_--; }
+  uint64_t offset_ = 0;
+
+ private:
   RiscVCounterCsr<uint32_t, S>* low_csr_;
   SimpleCounter<uint64_t>* counter_ = nullptr;
-  uint64_t offset_ = 0;
+};
+
+// This class implements the "high" version of the performance counter CSRs,
+// which have slightly different behavior on write due to how they interact
+// with the pipeline.
+template <typename S>
+class RiscVPerformanceCounterCsrHigh : public RiscVCounterCsrHigh<S> {
+ public:
+  RiscVPerformanceCounterCsrHigh(std::string name, RiscVCsrEnum csr_enum,
+                                 S* state,
+                                 RiscVCounterCsr<uint32_t, S>* low_csr)
+      : RiscVCounterCsrHigh<S>(name, csr_enum, state, low_csr) {}
+  RiscVPerformanceCounterCsrHigh(const RiscVPerformanceCounterCsrHigh&) =
+      delete;
+  RiscVPerformanceCounterCsrHigh& operator=(
+      const RiscVPerformanceCounterCsrHigh&) = delete;
+  ~RiscVPerformanceCounterCsrHigh() override = default;
+
+  void Set(uint32_t value) override {
+    uint64_t counter_val = this->GetCounterValue();
+    uint32_t counter_low = static_cast<uint32_t>(counter_val);
+    uint32_t low_offset = this->get_low_offset();
+    // Check for carry when reconstructing logical low value.
+    bool carry =
+        (static_cast<uint64_t>(counter_low) + low_offset) >= 0x100000000ULL;
+    this->offset_ = value - (counter_val >> 32) - (carry ? 1 : 0);
+    if (this->get_low_offset() == 0) {
+      this->offset_--;
+    }
+    this->decr_low_offset();
+  }
 };
 
 }  // namespace mpact::sim::riscv
diff --git a/riscv/riscv_state.cc b/riscv/riscv_state.cc
index 135bbf9..b53ffb8 100644
--- a/riscv/riscv_state.cc
+++ b/riscv/riscv_state.cc
@@ -231,35 +231,35 @@
            nullptr);
 
   // minstret/minstreth
-  auto* minstret = CreateCsr<RiscVCounterCsr<T, RiscVState>>(
+  auto* minstret = CreateCsr<RiscVPerformanceCounterCsr<T, RiscVState>>(
       state, csr_vec, "minstret", RiscVCsrEnum ::kMInstret, state);
   CHECK_NE(minstret, nullptr);
   if (std::is_same_v<T, uint32_t>) {
     CHECK_NE(
-        CreateCsr<RiscVCounterCsrHigh<RiscVState>>(
+        CreateCsr<RiscVPerformanceCounterCsrHigh<RiscVState>>(
             state, csr_vec, "minstreth", RiscVCsrEnum::kMInstretH, state,
             reinterpret_cast<RiscVCounterCsr<uint32_t, RiscVState>*>(minstret)),
         nullptr);
   }
   // mcycle/mcycleh
-  auto* mcycle = CreateCsr<RiscVCounterCsr<T, RiscVState>>(
+  auto* mcycle = CreateCsr<RiscVPerformanceCounterCsr<T, RiscVState>>(
       state, csr_vec, "mcycle", RiscVCsrEnum::kMCycle, state);
   CHECK_NE(mcycle, nullptr);
   if (std::is_same_v<T, uint32_t>) {
     CHECK_NE(
-        CreateCsr<RiscVCounterCsrHigh<RiscVState>>(
+        CreateCsr<RiscVPerformanceCounterCsrHigh<RiscVState>>(
             state, csr_vec, "mcycleh", RiscVCsrEnum::kMCycleH, state,
             reinterpret_cast<RiscVCounterCsr<uint32_t, RiscVState>*>(mcycle)),
         nullptr);
   }
 
   // cycle / cycleh
-  auto* cycle = CreateCsr<RiscVCounterCsr<T, RiscVState>>(
+  auto* cycle = CreateCsr<RiscVPerformanceCounterCsr<T, RiscVState>>(
       state, csr_vec, "cycle", RiscVCsrEnum::kCycle, state);
   CHECK_NE(cycle, nullptr);
   if (std::is_same_v<T, uint32_t>) {
     CHECK_NE(
-        CreateCsr<RiscVCounterCsrHigh<RiscVState>>(
+        CreateCsr<RiscVPerformanceCounterCsrHigh<RiscVState>>(
             state, csr_vec, "cycleh", RiscVCsrEnum::kCycleH, state,
             reinterpret_cast<RiscVCounterCsr<uint32_t, RiscVState>*>(cycle)),
         nullptr);
@@ -278,12 +278,12 @@
   }
 
   // instret / instreth
-  auto* instret = CreateCsr<RiscVCounterCsr<T, RiscVState>>(
+  auto* instret = CreateCsr<RiscVPerformanceCounterCsr<T, RiscVState>>(
       state, csr_vec, "instret", RiscVCsrEnum::kInstret, state);
   CHECK_NE(instret, nullptr);
   if (std::is_same_v<T, uint32_t>) {
     CHECK_NE(
-        CreateCsr<RiscVCounterCsrHigh<RiscVState>>(
+        CreateCsr<RiscVPerformanceCounterCsrHigh<RiscVState>>(
             state, csr_vec, "instreth", RiscVCsrEnum::kInstretH, state,
             reinterpret_cast<RiscVCounterCsr<uint32_t, RiscVState>*>(instret)),
         nullptr);
diff --git a/riscv/test/riscv_counter_csr_test.cc b/riscv/test/riscv_counter_csr_test.cc
index 0dc6dc4..abb053f 100644
--- a/riscv/test/riscv_counter_csr_test.cc
+++ b/riscv/test/riscv_counter_csr_test.cc
@@ -17,6 +17,8 @@
 using ::mpact::sim::riscv::RiscVCounterCsr;
 using ::mpact::sim::riscv::RiscVCounterCsrHigh;
 using ::mpact::sim::riscv::RiscVCsrEnum;
+using ::mpact::sim::riscv::RiscVPerformanceCounterCsr;
+using ::mpact::sim::riscv::RiscVPerformanceCounterCsrHigh;
 using ::mpact::sim::riscv::RiscVState;
 
 class RiscVMCycleTest : public ::testing::Test {
@@ -29,8 +31,8 @@
 
 // Verify the operation of 64 bit mcycle with counter increments.
 TEST_F(RiscVMCycleTest, GetTest64) {
-  RiscVCounterCsr<uint64_t, RiscVState> mcycle("mcycle", RiscVCsrEnum::kMCycle,
-                                               nullptr);
+  RiscVPerformanceCounterCsr<uint64_t, RiscVState> mcycle(
+      "mcycle", RiscVCsrEnum::kMCycle, nullptr);
   mcycle.set_counter(&counter_);
   // Initial value should be zero.
   EXPECT_EQ(mcycle.GetUint32(), 0);
@@ -39,19 +41,21 @@
   EXPECT_EQ(mcycle.GetUint32(), 1);
   EXPECT_EQ(mcycle.GetUint64(), 1);
   mcycle.Write(100u);
+  counter_.Increment(1);
   EXPECT_EQ(mcycle.GetUint32(), 100);
   EXPECT_EQ(mcycle.GetUint64(), 100);
   mcycle.Set(static_cast<uint64_t>(1000));
+  counter_.Increment(1);
   EXPECT_EQ(mcycle.GetUint32(), 1000);
   EXPECT_EQ(mcycle.GetUint64(), 1000);
 }
 
 // Verify the operation of 32 bit mcycle and mcycleh with counter increments.
 TEST_F(RiscVMCycleTest, GetTest32) {
-  RiscVCounterCsr<uint32_t, RiscVState> mcycle("mcycle", RiscVCsrEnum::kMCycle,
-                                               nullptr);
-  RiscVCounterCsrHigh<RiscVState> mcycleh("mcycleh", RiscVCsrEnum::kMCycleH,
-                                          nullptr, &mcycle);
+  RiscVPerformanceCounterCsr<uint32_t, RiscVState> mcycle(
+      "mcycle", RiscVCsrEnum::kMCycle, nullptr);
+  RiscVPerformanceCounterCsrHigh<RiscVState> mcycleh(
+      "mcycleh", RiscVCsrEnum::kMCycleH, nullptr, &mcycle);
   mcycle.set_counter(&counter_);
   mcycleh.set_counter(&counter_);
   // Initial value should be zero.
@@ -82,12 +86,13 @@
 
 // Test that write to mcycle is reflected in the value of mcycle.
 TEST_F(RiscVMCycleTest, SetTest64) {
-  RiscVCounterCsr<uint64_t, RiscVState> mcycle("mcycle", RiscVCsrEnum::kMCycle,
-                                               nullptr);
+  RiscVPerformanceCounterCsr<uint64_t, RiscVState> mcycle(
+      "mcycle", RiscVCsrEnum::kMCycle, nullptr);
   mcycle.set_counter(&counter_);
   EXPECT_EQ(mcycle.GetUint32(), 0);
   EXPECT_EQ(mcycle.GetUint64(), 0);
   mcycle.Write(100u);
+  counter_.Increment(1);
   EXPECT_EQ(mcycle.GetUint32(), 100);
   EXPECT_EQ(mcycle.GetUint64(), 100);
   counter_.Increment(10);
@@ -98,10 +103,10 @@
 // Test that write to mcycle and mcycleh is reflected in the value of mcycle and
 // mcycleh.
 TEST_F(RiscVMCycleTest, SetTest32) {
-  RiscVCounterCsr<uint32_t, RiscVState> mcycle("mcycle", RiscVCsrEnum::kMCycle,
-                                               nullptr);
-  RiscVCounterCsrHigh<RiscVState> mcycleh("mcycleh", RiscVCsrEnum::kMCycleH,
-                                          nullptr, &mcycle);
+  RiscVPerformanceCounterCsr<uint32_t, RiscVState> mcycle(
+      "mcycle", RiscVCsrEnum::kMCycle, nullptr);
+  RiscVPerformanceCounterCsrHigh<RiscVState> mcycleh(
+      "mcycleh", RiscVCsrEnum::kMCycleH, nullptr, &mcycle);
   mcycle.set_counter(&counter_);
   mcycleh.set_counter(&counter_);
   EXPECT_EQ(mcycle.GetUint32(), 0);
@@ -109,7 +114,9 @@
   EXPECT_EQ(mcycleh.GetUint32(), 0);
   EXPECT_EQ(mcycleh.GetUint64(), 0);
   mcycle.Write(100u);
+  counter_.Increment(1);
   mcycleh.Write(200u);
+  counter_.Increment(1);
   EXPECT_EQ(mcycle.GetUint32(), 100);
   EXPECT_EQ(mcycle.GetUint64(), 100);
   EXPECT_EQ(mcycleh.GetUint32(), 200);
@@ -126,4 +133,73 @@
   EXPECT_EQ(mcycleh.GetUint64(), 201);
 }
 
+// Verify the operation of minstret with counter increments and writes.
+// Specifically, ensure that writing to the CSR masks the increment of the
+// retiring instruction.
+class RiscVMInstretTest : public ::testing::Test {
+ protected:
+  RiscVMInstretTest() : counter_("instructions", 0) {};
+  ~RiscVMInstretTest() override = default;
+
+  SimpleCounter<uint64_t> counter_;
+};
+
+// Test that write to 64-bit minstret accounts for the instruction increment.
+TEST_F(RiscVMInstretTest, SetTest64) {
+  RiscVPerformanceCounterCsr<uint64_t, RiscVState> minstret(
+      "minstret", RiscVCsrEnum::kMInstret, nullptr);
+  minstret.set_counter(&counter_);
+
+  EXPECT_EQ(minstret.GetUint64(), 0);
+
+  // Write a value to the CSR.
+  minstret.Set(1000u);
+  // Simulate the instruction retirement (the instruction that did the write).
+  counter_.Increment(1);
+
+  // The value read should be exactly what was written, because writes to
+  // minstret don't increment the counter.
+  EXPECT_EQ(minstret.GetUint64(), 1000);
+
+  // Subsequent increments should work normally.
+  counter_.Increment(1);
+  EXPECT_EQ(minstret.GetUint64(), 1001);
+}
+
+// Test that writes to 32-bit minstret/minstreth account for the instruction
+// increment.
+TEST_F(RiscVMInstretTest, SetTest32) {
+  RiscVPerformanceCounterCsr<uint32_t, RiscVState> minstret(
+      "minstret", RiscVCsrEnum::kMInstret, nullptr);
+  RiscVPerformanceCounterCsrHigh<RiscVState> minstreth(
+      "minstreth", RiscVCsrEnum::kMInstretH, nullptr, &minstret);
+  minstret.set_counter(&counter_);
+  minstreth.set_counter(&counter_);
+
+  EXPECT_EQ(minstret.GetUint32(), 0);
+  EXPECT_EQ(minstreth.GetUint32(), 0);
+
+  // 1. Write Low CSR.
+  minstret.Set(100u);
+  counter_.Increment(1);
+  // Should equal 100 (increment compensated).
+  EXPECT_EQ(minstret.GetUint32(), 100);
+  EXPECT_EQ(minstreth.GetUint32(), 0);
+
+  // 2. Write High CSR.
+  minstreth.Set(5u);
+  counter_.Increment(1);
+  // Should equal 5 (increment compensated).
+  EXPECT_EQ(minstreth.GetUint32(), 5);
+  // Side effect: The increment compensation on the high write effectively
+  // stalls the low counter for this cycle. This is consistent with preserving
+  // the exact 64-bit value desired.
+  EXPECT_EQ(minstret.GetUint32(), 100);
+
+  // 3. Normal increment.
+  counter_.Increment(1);
+  EXPECT_EQ(minstret.GetUint32(), 101);
+  EXPECT_EQ(minstreth.GetUint32(), 5);
+}
+
 }  // namespace