Update mtval and stval CSRs on traps

- Save the trap_value to the appropriate xtval CSR during trap handling.
- Track mtval and stval pointers in RiscVState for direct access.
- Add unit tests to verify mtval and stval are correctly set on memory access faults.

PiperOrigin-RevId: 862919471
Change-Id: I857f8ce9df453281886b84352f4957b872906de1
diff --git a/riscv/riscv_state.cc b/riscv/riscv_state.cc
index b53ffb8..db6c198 100644
--- a/riscv/riscv_state.cc
+++ b/riscv/riscv_state.cc
@@ -226,7 +226,7 @@
                 CsrInfo<uint64_t>::kMstatusInitialValue, state, misa);
   CHECK_NE(mstatus, nullptr);
   // mtval
-  CHECK_NE(CreateCsr<RiscVSimpleCsr<T>>(state, csr_vec, "mtval",
+  CHECK_NE(CreateCsr<RiscVSimpleCsr<T>>(state, state->mtval_, csr_vec, "mtval",
                                         RiscVCsrEnum::kMTval, 0, state),
            nullptr);
 
@@ -389,7 +389,7 @@
                                         RiscVCsrEnum::kSEpc, 0, state),
            nullptr);
   // stval
-  CHECK_NE(CreateCsr<RiscVSimpleCsr<T>>(state, csr_vec, "stval",
+  CHECK_NE(CreateCsr<RiscVSimpleCsr<T>>(state, state->stval_, csr_vec, "stval",
                                         RiscVCsrEnum::kSTval, 0, state),
            nullptr);
 
@@ -707,14 +707,17 @@
   RiscVCsrInterface* epc_csr = nullptr;
   RiscVCsrInterface* cause_csr = nullptr;
   RiscVCsrInterface* tvec_csr = nullptr;
+  RiscVCsrInterface* tval_csr = nullptr;
   if (destination_mode == PrivilegeMode::kMachine) {
     epc_csr = mepc_;
     cause_csr = mcause_;
     tvec_csr = mtvec_;
+    tval_csr = mtval_;
   } else if (destination_mode == PrivilegeMode::kSupervisor) {
     epc_csr = sepc_;
     cause_csr = scause_;
     tvec_csr = stvec_;
+    tval_csr = stval_;
   } else {
     LOG(ERROR) << "Invalid destination execution mode";
     return;
@@ -729,6 +732,8 @@
 
   // Set xepc.
   epc_csr->Set(epc);
+  // Set xtval.
+  tval_csr->Set(trap_value);
   // Set xcause.
   cause_csr->Set(exception_code);
   auto current_xlen = xlen();
diff --git a/riscv/riscv_state.h b/riscv/riscv_state.h
index 94196d4..a0666cf 100644
--- a/riscv/riscv_state.h
+++ b/riscv/riscv_state.h
@@ -400,6 +400,7 @@
   RiscVMIe* mie() const { return mie_; }
   RiscVCsrInterface* jvt() const { return jvt_; }
   RiscVCsrInterface* mtvec() const { return mtvec_; }
+  RiscVCsrInterface* mtval() const { return mtval_; }
   RiscVCsrInterface* mepc() const { return mepc_; }
   RiscVCsrInterface* mcause() const { return mcause_; }
   RiscVCsrInterface* medeleg() const { return medeleg_; }
@@ -407,6 +408,7 @@
   RiscVSIp* sip() const { return sip_; }
   RiscVSIe* sie() const { return sie_; }
   RiscVCsrInterface* stvec() const { return stvec_; }
+  RiscVCsrInterface* stval() const { return stval_; }
   RiscVCsrInterface* sepc() const { return sepc_; }
   RiscVCsrInterface* scause() const { return scause_; }
   RiscVCsrInterface* sideleg() const { return sideleg_; }
@@ -452,6 +454,7 @@
   RiscVPmp* pmp_ = nullptr;
   RiscVCsrInterface* jvt_ = nullptr;
   RiscVCsrInterface* mtvec_ = nullptr;
+  RiscVCsrInterface* mtval_ = nullptr;
   RiscVCsrInterface* mepc_ = nullptr;
   RiscVCsrInterface* mcause_ = nullptr;
   RiscVCsrInterface* medeleg_ = nullptr;
@@ -459,6 +462,7 @@
   RiscVSIp* sip_ = nullptr;
   RiscVSIe* sie_ = nullptr;
   RiscVCsrInterface* stvec_ = nullptr;
+  RiscVCsrInterface* stval_ = nullptr;
   RiscVCsrInterface* sepc_ = nullptr;
   RiscVCsrInterface* scause_ = nullptr;
   RiscVCsrInterface* sideleg_ = nullptr;
diff --git a/riscv/test/riscv_state_test.cc b/riscv/test/riscv_state_test.cc
index 74f7cec..325a96d 100644
--- a/riscv/test/riscv_state_test.cc
+++ b/riscv/test/riscv_state_test.cc
@@ -29,6 +29,8 @@
 
 namespace {
 
+using ::mpact::sim::riscv::ExceptionCode;
+using ::mpact::sim::riscv::PrivilegeMode;
 using ::mpact::sim::riscv::RiscVCsrEnum;
 using ::mpact::sim::riscv::RiscVCsrInterface;
 using ::mpact::sim::riscv::RiscVState;
@@ -79,8 +81,7 @@
                         uint64_t exception_code, uint64_t epc,
                         const mpact::sim::riscv::Instruction* inst) -> bool {
     if (exception_code ==
-        static_cast<uint64_t>(
-            mpact::sim::riscv::ExceptionCode::kLoadAccessFault)) {
+        static_cast<uint64_t>(ExceptionCode::kLoadAccessFault)) {
       std::cerr << "Load Access Fault" << std::endl;
       return true;
     }
@@ -163,4 +164,36 @@
   }
 }
 
+TEST(RiscVStateTest, Mtval) {
+  FlatDemandMemory memory;
+  auto state = std::make_unique<RiscVState>("test", RiscVXlen::RV32, &memory);
+  state->set_max_physical_address(kMemAddr - 4);
+  auto* db = state->db_factory()->Allocate<uint32_t>(1);
+  // Create a dummy instruction so trap can dereference the address.
+  auto* dummy_inst = new mpact::sim::riscv::Instruction(0x0, nullptr);
+  dummy_inst->set_size(4);
+  state->LoadMemory(dummy_inst, kMemAddr, db, nullptr, nullptr);
+  EXPECT_EQ(state->mtval()->AsUint64(), kMemAddr);
+  db->DecRef();
+  dummy_inst->DecRef();
+}
+
+TEST(RiscVStateTest, Stval) {
+  FlatDemandMemory memory;
+  auto state = std::make_unique<RiscVState>("test", RiscVXlen::RV32, &memory);
+  state->set_max_physical_address(kMemAddr - 4);
+  // Delegate LoadAccessFault to S-mode and set privilege to S-mode
+  state->medeleg()->Set(uint64_t{1}
+                        << static_cast<int>(ExceptionCode::kLoadAccessFault));
+  state->set_privilege_mode(PrivilegeMode::kSupervisor);
+  auto* db = state->db_factory()->Allocate<uint32_t>(1);
+  // Create a dummy instruction so trap can dereference the address.
+  auto* dummy_inst = new mpact::sim::riscv::Instruction(0x0, nullptr);
+  dummy_inst->set_size(4);
+  state->LoadMemory(dummy_inst, kMemAddr, db, nullptr, nullptr);
+  EXPECT_EQ(state->stval()->AsUint64(), kMemAddr);
+  db->DecRef();
+  dummy_inst->DecRef();
+}
+
 }  // namespace