Prevent minstret increment on exceptions

- Update RiscVState::Trap to ensure the minstret CSR is not incremented when an exception occurs.
- Ensure that non-interrupt traps explicitly maintain the current minstret value.
- Add a unit test to verify that illegal instructions do not increment the minstret counter.

PiperOrigin-RevId: 862920200
Change-Id: I92e3fc3048b25d1cf975e6e2ed6e2ca8ed780bc3
diff --git a/riscv/riscv_state.cc b/riscv/riscv_state.cc
index db6c198..29f6ea5 100644
--- a/riscv/riscv_state.cc
+++ b/riscv/riscv_state.cc
@@ -665,6 +665,19 @@
 void RiscVState::Trap(bool is_interrupt, uint64_t trap_value,
                       uint64_t exception_code, uint64_t epc,
                       const Instruction* inst) {
+  if (!is_interrupt) {
+    auto minstret_res =
+        csr_set()->GetCsr(static_cast<uint64_t>(RiscVCsrEnum::kMInstret));
+    if (minstret_res.ok()) {
+      // If an exception causes a trap, the instruction did not retire.
+      // The minstret counter is implemented using RiscVPerformanceCounterCsr
+      // which increments its value as part of instruction processing before it
+      // is known if it will cause a trap. The write below corrects for this
+      // by decrementing the minstret value by 1. This is a side-effect of
+      // RiscVPerformanceCounterCsr::Set().
+      (*minstret_res)->Set((*minstret_res)->AsUint64());
+    }
+  }
   if (on_trap_ != nullptr) {
     bool res = on_trap_(is_interrupt, trap_value, exception_code, epc, inst);
     if (res) return;
diff --git a/riscv/riscv_state.h b/riscv/riscv_state.h
index a0666cf..bc7b2e8 100644
--- a/riscv/riscv_state.h
+++ b/riscv/riscv_state.h
@@ -446,6 +446,7 @@
   PrivilegeMode privilege_mode_ = PrivilegeMode::kMachine;
   // Flag set on branch instructions.
   bool branch_ = false;
+
   // Handles to frequently used CSRs.
   RiscVMStatus* mstatus_ = nullptr;
   RiscVMIsa* misa_ = nullptr;
diff --git a/riscv/test/riscv_top_test.cc b/riscv/test/riscv_top_test.cc
index 058b6d2..0e136c7 100644
--- a/riscv/test/riscv_top_test.cc
+++ b/riscv/test/riscv_top_test.cc
@@ -627,4 +627,37 @@
       << absl::StrJoin(failed_perf_counter_csr_names, ",");
 }
 
+// Test that executing an illegal instruction does not increment minstret.
+TEST_F(RiscVTopTest, IllegalInstructionTrap) {
+  uint32_t illegal_instruction = 0;
+  EXPECT_OK(
+      riscv_top_->WriteMemory(0x1000, &illegal_instruction, sizeof(uint32_t)));
+  EXPECT_OK(riscv_top_->WriteRegister("pc", 0x1000));
+  EXPECT_OK(riscv_top_->WriteRegister("minstret", 1));
+
+  bool trap_called = false;
+  state_->set_on_trap([&trap_called](bool is_interrupt, uint64_t trap_value,
+                                     uint64_t exception_code, uint64_t epc,
+                                     const Instruction* inst) {
+    trap_called = true;
+    EXPECT_FALSE(is_interrupt);
+    EXPECT_EQ(exception_code,
+              *mpact::sim::riscv::ExceptionCode::kIllegalInstruction);
+    EXPECT_EQ(epc, 0x1000);
+    return false;  // in order to exercise default trap handling
+  });
+
+  auto minstret_before = riscv_top_->ReadRegister("minstret");
+  EXPECT_OK(minstret_before.status());
+  EXPECT_EQ(minstret_before.value(), 0);
+
+  auto res = riscv_top_->Step(1);
+  EXPECT_OK(res.status());
+
+  EXPECT_TRUE(trap_called);
+  auto minstret_after = riscv_top_->ReadRegister("minstret");
+  EXPECT_OK(minstret_after.status());
+  EXPECT_EQ(minstret_after.value(), 0);
+}
+
 }  // namespace