RVA23: Add CSRs associated with extension Zicntr

PiperOrigin-RevId: 799776222
Change-Id: I63d5e5a1a4dd9868bc59c9d8f14d2187c49f1949
diff --git a/riscv/riscv_state.cc b/riscv/riscv_state.cc
index 0553d46..f7f6321 100644
--- a/riscv/riscv_state.cc
+++ b/riscv/riscv_state.cc
@@ -234,7 +234,7 @@
   auto* minstret = CreateCsr<RiscVCounterCsr<T, RiscVState>>(
       state, csr_vec, "minstret", RiscVCsrEnum ::kMInstret, state);
   CHECK_NE(minstret, nullptr);
-  if (sizeof(T) == sizeof(uint32_t)) {
+  if (std::is_same_v<T, uint32_t>) {
     CHECK_NE(
         CreateCsr<RiscVCounterCsrHigh<RiscVState>>(
             state, csr_vec, "minstreth", RiscVCsrEnum::kMInstretH, state,
@@ -245,7 +245,7 @@
   auto* mcycle = CreateCsr<RiscVCounterCsr<T, RiscVState>>(
       state, csr_vec, "mcycle", RiscVCsrEnum::kMCycle, state);
   CHECK_NE(mcycle, nullptr);
-  if (sizeof(T) == sizeof(uint32_t)) {
+  if (std::is_same_v<T, uint32_t>) {
     CHECK_NE(
         CreateCsr<RiscVCounterCsrHigh<RiscVState>>(
             state, csr_vec, "mcycleh", RiscVCsrEnum::kMCycleH, state,
@@ -253,6 +253,42 @@
         nullptr);
   }
 
+  // cycle / cycleh
+  auto* cycle = CreateCsr<RiscVCounterCsr<T, RiscVState>>(
+      state, csr_vec, "cycle", RiscVCsrEnum::kCycle, state);
+  CHECK_NE(cycle, nullptr);
+  if (std::is_same_v<T, uint32_t>) {
+    CHECK_NE(
+        CreateCsr<RiscVCounterCsrHigh<RiscVState>>(
+            state, csr_vec, "cycleh", RiscVCsrEnum::kCycleH, state,
+            reinterpret_cast<RiscVCounterCsr<uint32_t, RiscVState>*>(cycle)),
+        nullptr);
+  }
+
+  // time / timeh
+  auto* time = CreateCsr<RiscVCounterCsr<T, RiscVState>>(
+      state, csr_vec, "time", RiscVCsrEnum::kTime, state);
+  CHECK_NE(time, nullptr);
+  if (std::is_same_v<T, uint32_t>) {
+    CHECK_NE(
+        CreateCsr<RiscVCounterCsrHigh<RiscVState>>(
+            state, csr_vec, "timeh", RiscVCsrEnum::kTimeH, state,
+            reinterpret_cast<RiscVCounterCsr<uint32_t, RiscVState>*>(time)),
+        nullptr);
+  }
+
+  // instret / instreth
+  auto* instret = CreateCsr<RiscVCounterCsr<T, RiscVState>>(
+      state, csr_vec, "instret", RiscVCsrEnum::kInstret, state);
+  CHECK_NE(instret, nullptr);
+  if (std::is_same_v<T, uint32_t>) {
+    CHECK_NE(
+        CreateCsr<RiscVCounterCsrHigh<RiscVState>>(
+            state, csr_vec, "instreth", RiscVCsrEnum::kInstretH, state,
+            reinterpret_cast<RiscVCounterCsr<uint32_t, RiscVState>*>(instret)),
+        nullptr);
+  }
+
   // Hypervisor level CSRs
 
   // henvcfg
diff --git a/riscv/riscv_top.cc b/riscv/riscv_top.cc
index 1f96b13..3f6dbf1 100644
--- a/riscv/riscv_top.cc
+++ b/riscv/riscv_top.cc
@@ -29,6 +29,7 @@
 #include "absl/status/status.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
 #include "absl/synchronization/notification.h"
 #include "mpact/sim/generic/action_point_manager_base.h"
 #include "mpact/sim/generic/breakpoint_manager.h"
@@ -145,39 +146,13 @@
   }
 
   // Connect counters to instret(h) and mcycle(h) CSRs.
-  auto csr_res = state_->csr_set()->GetCsr("minstret");
-  CHECK_OK(csr_res.status()) << "Failed to get minstret CSR";
-  if (state_->xlen() == RiscVXlen::RV32) {
-    // Minstret/minstreth.
-    auto* minstret = reinterpret_cast<RiscVCounterCsr<uint32_t, RiscVState>*>(
-        csr_res.value());
-    minstret->set_counter(&counter_num_instructions_);
-    csr_res = state_->csr_set()->GetCsr("minstreth");
-    CHECK_OK(csr_res.status()) << "Failed to get minstret CSR";
-    auto* minstreth =
-        reinterpret_cast<RiscVCounterCsrHigh<RiscVState>*>(csr_res.value());
-    minstreth->set_counter(&counter_num_instructions_);
-    // Mcycle/mcycleh.
-    csr_res = state_->csr_set()->GetCsr("mcycle");
-    CHECK_OK(csr_res.status()) << "Failed to get mcycle CSR";
-    auto* mcycle = reinterpret_cast<RiscVCounterCsr<uint32_t, RiscVState>*>(
-        csr_res.value());
-    mcycle->set_counter(&counter_num_cycles_);
-    csr_res = state_->csr_set()->GetCsr("mcycleh");
-    CHECK_OK(csr_res.status()) << "Failed to get mcycleh CSR";
-    auto* mcycleh =
-        reinterpret_cast<RiscVCounterCsrHigh<RiscVState>*>(csr_res.value());
-    mcycleh->set_counter(&counter_num_cycles_);
-  } else {
-    // Minstret/minstreth.
-    auto* minstret = reinterpret_cast<RiscVCounterCsr<uint64_t, RiscVState>*>(
-        csr_res.value());
-    minstret->set_counter(&counter_num_instructions_);
-    // Mcycle/mcycleh.
-    auto* mcycle = reinterpret_cast<RiscVCounterCsr<uint64_t, RiscVState>*>(
-        csr_res.value());
-    mcycle->set_counter(&counter_num_cycles_);
-  }
+  CHECK_OK(SetCsrCounter("minstret", counter_num_instructions_));
+  CHECK_OK(SetCsrCounter("mcycle", counter_num_cycles_));
+
+  // Connect Zicntr counters to cycle(h), time(h), and instret(h) CSRs.
+  CHECK_OK(SetCsrCounter("instret", counter_num_instructions_));
+  CHECK_OK(SetCsrCounter("cycle", counter_num_cycles_));
+  CHECK_OK(SetCsrCounter("time", counter_num_cycles_));
 
   // Set up break and action points.
   rv_action_point_memory_interface_ = new RiscVActionPointMemoryInterface(
@@ -893,6 +868,34 @@
                                        static_cast<uint32_t>(to), 1};
 }
 
+absl::Status RiscVTop::SetCsrCounter(
+    absl::string_view name, generic::SimpleCounter<uint64_t>& counter) {
+  absl::StatusOr<RiscVCsrInterface*> csr_result =
+      state_->csr_set()->GetCsr(name);
+  if (!csr_result.ok()) {
+    return csr_result.status();
+  }
+  switch (state_->xlen()) {
+    case RiscVXlen::RV32:
+      dynamic_cast<RiscVCounterCsr<uint32_t, RiscVState>*>(*csr_result)
+          ->set_counter(&counter);
+      csr_result = state_->csr_set()->GetCsr(absl::StrCat(name, "h"));
+      if (!csr_result.ok()) {
+        return csr_result.status();
+      }
+      dynamic_cast<RiscVCounterCsrHigh<RiscVState>*>(*csr_result)
+          ->set_counter(&counter);
+      break;
+    case RiscVXlen::RV64:
+      dynamic_cast<RiscVCounterCsr<uint64_t, RiscVState>*>(*csr_result)
+          ->set_counter(&counter);
+      break;
+    default:
+      return absl::InternalError("Unknown Xlen value");
+  }
+  return absl::OkStatus();
+}
+
 void RiscVTop::EnableStatistics() {
   for (auto& [unused, counter_ptr] : counter_map()) {
     if (counter_ptr->GetName() == "pc") continue;
diff --git a/riscv/riscv_top.h b/riscv/riscv_top.h
index e51ad42..171ef9c 100644
--- a/riscv/riscv_top.h
+++ b/riscv/riscv_top.h
@@ -15,7 +15,6 @@
 #ifndef MPACT_RISCV_RISCV_RISCV_TOP_H_
 #define MPACT_RISCV_RISCV_RISCV_TOP_H_
 
-// #include <algorithm>
 #include <cstddef>
 #include <cstdint>
 #include <string>
@@ -25,6 +24,7 @@
 #include "absl/functional/any_invocable.h"
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
 #include "absl/synchronization/notification.h"
 #include "mpact/sim/generic/action_point_manager_base.h"
 #include "mpact/sim/generic/breakpoint_manager.h"
@@ -166,6 +166,11 @@
   void ICacheFetch(uint64_t address);
   // Branch tracing.
   void AddToBranchTrace(uint64_t from, uint64_t to);
+  // Add a counter to a CSR. This is used to connect the counters to the
+  // various CSRs that track the same quantity (e.g., cycle, time, retired
+  // instructions).
+  absl::Status SetCsrCounter(absl::string_view name,
+                             generic::SimpleCounter<uint64_t>& counter);
 
   // The DB factory is used to manage data buffers for memory read/writes.
   generic::DataBufferFactory db_factory_;
diff --git a/riscv/test/riscv_top_test.cc b/riscv/test/riscv_top_test.cc
index 7e0657f..a5d3709 100644
--- a/riscv/test/riscv_top_test.cc
+++ b/riscv/test/riscv_top_test.cc
@@ -584,6 +584,12 @@
   EXPECT_EQ(static_cast<int>(halt_result.value()),
             static_cast<int>(HaltReason::kSemihostHaltRequest));
   EXPECT_EQ("Hello world! 5\n", testing::internal::GetCapturedStdout());
+
+  // Verify that the counter CSRs are non-zero. Mutation testing found the
+  // coverage gap.
+  EXPECT_NE(riscv_top_->ReadRegister("cycle").value(), 0);
+  EXPECT_NE(riscv_top_->ReadRegister("time").value(), 0);
+  EXPECT_NE(riscv_top_->ReadRegister("instret").value(), 0);
 }
 
 }  // namespace