RVA23: Add CSRs associated with extension Zicntr PiperOrigin-RevId: 799776222 Change-Id: I63d5e5a1a4dd9868bc59c9d8f14d2187c49f1949
diff --git a/riscv/riscv_state.cc b/riscv/riscv_state.cc index 0553d46..f7f6321 100644 --- a/riscv/riscv_state.cc +++ b/riscv/riscv_state.cc
@@ -234,7 +234,7 @@ auto* minstret = CreateCsr<RiscVCounterCsr<T, RiscVState>>( state, csr_vec, "minstret", RiscVCsrEnum ::kMInstret, state); CHECK_NE(minstret, nullptr); - if (sizeof(T) == sizeof(uint32_t)) { + if (std::is_same_v<T, uint32_t>) { CHECK_NE( CreateCsr<RiscVCounterCsrHigh<RiscVState>>( state, csr_vec, "minstreth", RiscVCsrEnum::kMInstretH, state, @@ -245,7 +245,7 @@ auto* mcycle = CreateCsr<RiscVCounterCsr<T, RiscVState>>( state, csr_vec, "mcycle", RiscVCsrEnum::kMCycle, state); CHECK_NE(mcycle, nullptr); - if (sizeof(T) == sizeof(uint32_t)) { + if (std::is_same_v<T, uint32_t>) { CHECK_NE( CreateCsr<RiscVCounterCsrHigh<RiscVState>>( state, csr_vec, "mcycleh", RiscVCsrEnum::kMCycleH, state, @@ -253,6 +253,42 @@ nullptr); } + // cycle / cycleh + auto* cycle = CreateCsr<RiscVCounterCsr<T, RiscVState>>( + state, csr_vec, "cycle", RiscVCsrEnum::kCycle, state); + CHECK_NE(cycle, nullptr); + if (std::is_same_v<T, uint32_t>) { + CHECK_NE( + CreateCsr<RiscVCounterCsrHigh<RiscVState>>( + state, csr_vec, "cycleh", RiscVCsrEnum::kCycleH, state, + reinterpret_cast<RiscVCounterCsr<uint32_t, RiscVState>*>(cycle)), + nullptr); + } + + // time / timeh + auto* time = CreateCsr<RiscVCounterCsr<T, RiscVState>>( + state, csr_vec, "time", RiscVCsrEnum::kTime, state); + CHECK_NE(time, nullptr); + if (std::is_same_v<T, uint32_t>) { + CHECK_NE( + CreateCsr<RiscVCounterCsrHigh<RiscVState>>( + state, csr_vec, "timeh", RiscVCsrEnum::kTimeH, state, + reinterpret_cast<RiscVCounterCsr<uint32_t, RiscVState>*>(time)), + nullptr); + } + + // instret / instreth + auto* instret = CreateCsr<RiscVCounterCsr<T, RiscVState>>( + state, csr_vec, "instret", RiscVCsrEnum::kInstret, state); + CHECK_NE(instret, nullptr); + if (std::is_same_v<T, uint32_t>) { + CHECK_NE( + CreateCsr<RiscVCounterCsrHigh<RiscVState>>( + state, csr_vec, "instreth", RiscVCsrEnum::kInstretH, state, + reinterpret_cast<RiscVCounterCsr<uint32_t, RiscVState>*>(instret)), + nullptr); + } + // Hypervisor level CSRs // henvcfg
diff --git a/riscv/riscv_top.cc b/riscv/riscv_top.cc index 1f96b13..3f6dbf1 100644 --- a/riscv/riscv_top.cc +++ b/riscv/riscv_top.cc
@@ -29,6 +29,7 @@ #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/notification.h" #include "mpact/sim/generic/action_point_manager_base.h" #include "mpact/sim/generic/breakpoint_manager.h" @@ -145,39 +146,13 @@ } // Connect counters to instret(h) and mcycle(h) CSRs. - auto csr_res = state_->csr_set()->GetCsr("minstret"); - CHECK_OK(csr_res.status()) << "Failed to get minstret CSR"; - if (state_->xlen() == RiscVXlen::RV32) { - // Minstret/minstreth. - auto* minstret = reinterpret_cast<RiscVCounterCsr<uint32_t, RiscVState>*>( - csr_res.value()); - minstret->set_counter(&counter_num_instructions_); - csr_res = state_->csr_set()->GetCsr("minstreth"); - CHECK_OK(csr_res.status()) << "Failed to get minstret CSR"; - auto* minstreth = - reinterpret_cast<RiscVCounterCsrHigh<RiscVState>*>(csr_res.value()); - minstreth->set_counter(&counter_num_instructions_); - // Mcycle/mcycleh. - csr_res = state_->csr_set()->GetCsr("mcycle"); - CHECK_OK(csr_res.status()) << "Failed to get mcycle CSR"; - auto* mcycle = reinterpret_cast<RiscVCounterCsr<uint32_t, RiscVState>*>( - csr_res.value()); - mcycle->set_counter(&counter_num_cycles_); - csr_res = state_->csr_set()->GetCsr("mcycleh"); - CHECK_OK(csr_res.status()) << "Failed to get mcycleh CSR"; - auto* mcycleh = - reinterpret_cast<RiscVCounterCsrHigh<RiscVState>*>(csr_res.value()); - mcycleh->set_counter(&counter_num_cycles_); - } else { - // Minstret/minstreth. - auto* minstret = reinterpret_cast<RiscVCounterCsr<uint64_t, RiscVState>*>( - csr_res.value()); - minstret->set_counter(&counter_num_instructions_); - // Mcycle/mcycleh. - auto* mcycle = reinterpret_cast<RiscVCounterCsr<uint64_t, RiscVState>*>( - csr_res.value()); - mcycle->set_counter(&counter_num_cycles_); - } + CHECK_OK(SetCsrCounter("minstret", counter_num_instructions_)); + CHECK_OK(SetCsrCounter("mcycle", counter_num_cycles_)); + + // Connect Zicntr counters to cycle(h), time(h), and instret(h) CSRs. + CHECK_OK(SetCsrCounter("instret", counter_num_instructions_)); + CHECK_OK(SetCsrCounter("cycle", counter_num_cycles_)); + CHECK_OK(SetCsrCounter("time", counter_num_cycles_)); // Set up break and action points. rv_action_point_memory_interface_ = new RiscVActionPointMemoryInterface( @@ -893,6 +868,34 @@ static_cast<uint32_t>(to), 1}; } +absl::Status RiscVTop::SetCsrCounter( + absl::string_view name, generic::SimpleCounter<uint64_t>& counter) { + absl::StatusOr<RiscVCsrInterface*> csr_result = + state_->csr_set()->GetCsr(name); + if (!csr_result.ok()) { + return csr_result.status(); + } + switch (state_->xlen()) { + case RiscVXlen::RV32: + dynamic_cast<RiscVCounterCsr<uint32_t, RiscVState>*>(*csr_result) + ->set_counter(&counter); + csr_result = state_->csr_set()->GetCsr(absl::StrCat(name, "h")); + if (!csr_result.ok()) { + return csr_result.status(); + } + dynamic_cast<RiscVCounterCsrHigh<RiscVState>*>(*csr_result) + ->set_counter(&counter); + break; + case RiscVXlen::RV64: + dynamic_cast<RiscVCounterCsr<uint64_t, RiscVState>*>(*csr_result) + ->set_counter(&counter); + break; + default: + return absl::InternalError("Unknown Xlen value"); + } + return absl::OkStatus(); +} + void RiscVTop::EnableStatistics() { for (auto& [unused, counter_ptr] : counter_map()) { if (counter_ptr->GetName() == "pc") continue;
diff --git a/riscv/riscv_top.h b/riscv/riscv_top.h index e51ad42..171ef9c 100644 --- a/riscv/riscv_top.h +++ b/riscv/riscv_top.h
@@ -15,7 +15,6 @@ #ifndef MPACT_RISCV_RISCV_RISCV_TOP_H_ #define MPACT_RISCV_RISCV_RISCV_TOP_H_ -// #include <algorithm> #include <cstddef> #include <cstdint> #include <string> @@ -25,6 +24,7 @@ #include "absl/functional/any_invocable.h" #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "absl/synchronization/notification.h" #include "mpact/sim/generic/action_point_manager_base.h" #include "mpact/sim/generic/breakpoint_manager.h" @@ -166,6 +166,11 @@ void ICacheFetch(uint64_t address); // Branch tracing. void AddToBranchTrace(uint64_t from, uint64_t to); + // Add a counter to a CSR. This is used to connect the counters to the + // various CSRs that track the same quantity (e.g., cycle, time, retired + // instructions). + absl::Status SetCsrCounter(absl::string_view name, + generic::SimpleCounter<uint64_t>& counter); // The DB factory is used to manage data buffers for memory read/writes. generic::DataBufferFactory db_factory_;
diff --git a/riscv/test/riscv_top_test.cc b/riscv/test/riscv_top_test.cc index 7e0657f..a5d3709 100644 --- a/riscv/test/riscv_top_test.cc +++ b/riscv/test/riscv_top_test.cc
@@ -584,6 +584,12 @@ EXPECT_EQ(static_cast<int>(halt_result.value()), static_cast<int>(HaltReason::kSemihostHaltRequest)); EXPECT_EQ("Hello world! 5\n", testing::internal::GetCapturedStdout()); + + // Verify that the counter CSRs are non-zero. Mutation testing found the + // coverage gap. + EXPECT_NE(riscv_top_->ReadRegister("cycle").value(), 0); + EXPECT_NE(riscv_top_->ReadRegister("time").value(), 0); + EXPECT_NE(riscv_top_->ReadRegister("instret").value(), 0); } } // namespace