Update mtval and stval CSRs on traps - Save the trap_value to the appropriate xtval CSR during trap handling. - Track mtval and stval pointers in RiscVState for direct access. - Add unit tests to verify mtval and stval are correctly set on memory access faults. PiperOrigin-RevId: 862919471 Change-Id: I857f8ce9df453281886b84352f4957b872906de1
diff --git a/riscv/riscv_state.cc b/riscv/riscv_state.cc index b53ffb8..db6c198 100644 --- a/riscv/riscv_state.cc +++ b/riscv/riscv_state.cc
@@ -226,7 +226,7 @@ CsrInfo<uint64_t>::kMstatusInitialValue, state, misa); CHECK_NE(mstatus, nullptr); // mtval - CHECK_NE(CreateCsr<RiscVSimpleCsr<T>>(state, csr_vec, "mtval", + CHECK_NE(CreateCsr<RiscVSimpleCsr<T>>(state, state->mtval_, csr_vec, "mtval", RiscVCsrEnum::kMTval, 0, state), nullptr); @@ -389,7 +389,7 @@ RiscVCsrEnum::kSEpc, 0, state), nullptr); // stval - CHECK_NE(CreateCsr<RiscVSimpleCsr<T>>(state, csr_vec, "stval", + CHECK_NE(CreateCsr<RiscVSimpleCsr<T>>(state, state->stval_, csr_vec, "stval", RiscVCsrEnum::kSTval, 0, state), nullptr); @@ -707,14 +707,17 @@ RiscVCsrInterface* epc_csr = nullptr; RiscVCsrInterface* cause_csr = nullptr; RiscVCsrInterface* tvec_csr = nullptr; + RiscVCsrInterface* tval_csr = nullptr; if (destination_mode == PrivilegeMode::kMachine) { epc_csr = mepc_; cause_csr = mcause_; tvec_csr = mtvec_; + tval_csr = mtval_; } else if (destination_mode == PrivilegeMode::kSupervisor) { epc_csr = sepc_; cause_csr = scause_; tvec_csr = stvec_; + tval_csr = stval_; } else { LOG(ERROR) << "Invalid destination execution mode"; return; @@ -729,6 +732,8 @@ // Set xepc. epc_csr->Set(epc); + // Set xtval. + tval_csr->Set(trap_value); // Set xcause. cause_csr->Set(exception_code); auto current_xlen = xlen();
diff --git a/riscv/riscv_state.h b/riscv/riscv_state.h index 94196d4..a0666cf 100644 --- a/riscv/riscv_state.h +++ b/riscv/riscv_state.h
@@ -400,6 +400,7 @@ RiscVMIe* mie() const { return mie_; } RiscVCsrInterface* jvt() const { return jvt_; } RiscVCsrInterface* mtvec() const { return mtvec_; } + RiscVCsrInterface* mtval() const { return mtval_; } RiscVCsrInterface* mepc() const { return mepc_; } RiscVCsrInterface* mcause() const { return mcause_; } RiscVCsrInterface* medeleg() const { return medeleg_; } @@ -407,6 +408,7 @@ RiscVSIp* sip() const { return sip_; } RiscVSIe* sie() const { return sie_; } RiscVCsrInterface* stvec() const { return stvec_; } + RiscVCsrInterface* stval() const { return stval_; } RiscVCsrInterface* sepc() const { return sepc_; } RiscVCsrInterface* scause() const { return scause_; } RiscVCsrInterface* sideleg() const { return sideleg_; } @@ -452,6 +454,7 @@ RiscVPmp* pmp_ = nullptr; RiscVCsrInterface* jvt_ = nullptr; RiscVCsrInterface* mtvec_ = nullptr; + RiscVCsrInterface* mtval_ = nullptr; RiscVCsrInterface* mepc_ = nullptr; RiscVCsrInterface* mcause_ = nullptr; RiscVCsrInterface* medeleg_ = nullptr; @@ -459,6 +462,7 @@ RiscVSIp* sip_ = nullptr; RiscVSIe* sie_ = nullptr; RiscVCsrInterface* stvec_ = nullptr; + RiscVCsrInterface* stval_ = nullptr; RiscVCsrInterface* sepc_ = nullptr; RiscVCsrInterface* scause_ = nullptr; RiscVCsrInterface* sideleg_ = nullptr;
diff --git a/riscv/test/riscv_state_test.cc b/riscv/test/riscv_state_test.cc index 74f7cec..325a96d 100644 --- a/riscv/test/riscv_state_test.cc +++ b/riscv/test/riscv_state_test.cc
@@ -29,6 +29,8 @@ namespace { +using ::mpact::sim::riscv::ExceptionCode; +using ::mpact::sim::riscv::PrivilegeMode; using ::mpact::sim::riscv::RiscVCsrEnum; using ::mpact::sim::riscv::RiscVCsrInterface; using ::mpact::sim::riscv::RiscVState; @@ -79,8 +81,7 @@ uint64_t exception_code, uint64_t epc, const mpact::sim::riscv::Instruction* inst) -> bool { if (exception_code == - static_cast<uint64_t>( - mpact::sim::riscv::ExceptionCode::kLoadAccessFault)) { + static_cast<uint64_t>(ExceptionCode::kLoadAccessFault)) { std::cerr << "Load Access Fault" << std::endl; return true; } @@ -163,4 +164,36 @@ } } +TEST(RiscVStateTest, Mtval) { + FlatDemandMemory memory; + auto state = std::make_unique<RiscVState>("test", RiscVXlen::RV32, &memory); + state->set_max_physical_address(kMemAddr - 4); + auto* db = state->db_factory()->Allocate<uint32_t>(1); + // Create a dummy instruction so trap can dereference the address. + auto* dummy_inst = new mpact::sim::riscv::Instruction(0x0, nullptr); + dummy_inst->set_size(4); + state->LoadMemory(dummy_inst, kMemAddr, db, nullptr, nullptr); + EXPECT_EQ(state->mtval()->AsUint64(), kMemAddr); + db->DecRef(); + dummy_inst->DecRef(); +} + +TEST(RiscVStateTest, Stval) { + FlatDemandMemory memory; + auto state = std::make_unique<RiscVState>("test", RiscVXlen::RV32, &memory); + state->set_max_physical_address(kMemAddr - 4); + // Delegate LoadAccessFault to S-mode and set privilege to S-mode + state->medeleg()->Set(uint64_t{1} + << static_cast<int>(ExceptionCode::kLoadAccessFault)); + state->set_privilege_mode(PrivilegeMode::kSupervisor); + auto* db = state->db_factory()->Allocate<uint32_t>(1); + // Create a dummy instruction so trap can dereference the address. + auto* dummy_inst = new mpact::sim::riscv::Instruction(0x0, nullptr); + dummy_inst->set_size(4); + state->LoadMemory(dummy_inst, kMemAddr, db, nullptr, nullptr); + EXPECT_EQ(state->stval()->AsUint64(), kMemAddr); + db->DecRef(); + dummy_inst->DecRef(); +} + } // namespace