Adds a new CSR type RiscVShadowCsr that allows a CSR to provide a more
restrictive view of another CSR.

Adds cycle/cycleh and instret/instreth as shadow CSRs

PiperOrigin-RevId: 705640436
Change-Id: I91f0c2396abc54e154cba387243ab0f8a136f70b
diff --git a/riscv/riscv_csr.h b/riscv/riscv_csr.h
index 6765b2c..b3b9480 100644
--- a/riscv/riscv_csr.h
+++ b/riscv/riscv_csr.h
@@ -69,12 +69,12 @@
   kTime = 0xc01,
   kInstret = 0xc02,
 
-  // Ignoring perf monitoring counters for now.
-
   kCycleH = 0xc80,
   kTimeH = 0xc81,
   kInstretH = 0x82,
 
+  // Ignoring perf monitoring counters for now.
+
   // Ignoring high bits of perf monitoring counters for now.
 
   // Supervisor trap setup.
@@ -388,6 +388,94 @@
   RiscVCsrClearBitsDb *clear_bits_target_;
 };
 
+// The shadow csr class is used to implement a more restricted view of another
+// CSR, for instance, making a read-only view version of a CSR that may be
+// accessible at lower privilege levels.
+template <typename T>
+class RiscVShadowCsr : public RiscVCsrInterface {
+ public:
+  RiscVShadowCsr(std::string name, RiscVCsrEnum index, T read_mask,
+                 T write_mask, ArchState *state, RiscVCsrInterface *csr)
+      : RiscVCsrInterface(name, static_cast<uint64_t>(index), state),
+        csr_(csr),
+        read_mask_(read_mask),
+        write_mask_(write_mask),
+        write_target_(new RiscVCsrWriteDb(this)),
+        set_bits_target_(new RiscVCsrSetBitsDb(this)),
+        clear_bits_target_(new RiscVCsrClearBitsDb(this)) {}
+  RiscVShadowCsr() = delete;
+  RiscVShadowCsr(const RiscVShadowCsr &) = delete;
+  RiscVShadowCsr &operator=(const RiscVShadowCsr &) = delete;
+
+  ~RiscVShadowCsr() override {
+    delete write_target_;
+    delete set_bits_target_;
+    delete clear_bits_target_;
+  }
+
+  // Return the value, modified as per read mask.
+  uint32_t AsUint32() override {
+    return static_cast<uint32_t>(static_cast<T>(csr_->AsUint32()) & read_mask_);
+  }
+  uint64_t AsUint64() override {
+    return static_cast<uint64_t>(static_cast<T>(csr_->AsUint64()) & read_mask_);
+  }
+  // Write the value, modified as per write mask.
+  void Write(uint32_t value) override {
+    if (write_mask_ != 0) {
+      csr_->Write((static_cast<T>(csr_->GetUint32()) & ~write_mask_) |
+                  (static_cast<T>(value) & write_mask_));
+    }
+  }
+  void Write(uint64_t value) override {
+    if (write_mask_ != 0) {
+      csr_->Write((static_cast<T>(csr_->GetUint64()) & ~write_mask_) |
+                  (static_cast<T>(value) & write_mask_));
+    }
+  }
+  // Set the bits that are set in value, leave other bits unchanged.
+  // Set the bits specified in the value. Don't change the other bits.
+  void SetBits(uint32_t value) override { Write(GetUint32() | value); }
+  void SetBits(uint64_t value) override { Write(GetUint64() | value); }
+  // Clear the bits specified in the value. Don't change the other bits.
+  void ClearBits(uint32_t value) override { Write(GetUint32() & ~value); }
+  void ClearBits(uint64_t value) override { Write(GetUint64() & ~value); }
+  // Return the value, ignoring the read mask.
+  uint32_t GetUint32() override { return csr_->GetUint32(); }
+  uint64_t GetUint64() override { return csr_->GetUint64(); }
+  // Sets the value, ignoring the write mask.
+  void Set(uint32_t value) override { csr_->Set(static_cast<T>(value)); }
+  void Set(uint64_t value) override { csr_->Set(static_cast<T>(value)); }
+  // Size of value.
+  size_t size() const override { return sizeof(T); }
+  // Set to reset value.
+  void Reset() override { /* Empty. */ }
+  // Operand creation interface.
+  generic::SourceOperandInterface *CreateSourceOperand() override;
+  generic::DestinationOperandInterface *CreateSetDestinationOperand(
+      int latency, std::string op_name) override;
+  generic::DestinationOperandInterface *CreateClearDestinationOperand(
+      int latency, std::string op_name) override;
+  generic::DestinationOperandInterface *CreateWriteDestinationOperand(
+      int latency, std::string op_name) override;
+
+  RiscVCsrWriteDb *write_target() const { return write_target_; }
+  RiscVCsrSetBitsDb *set_bits_target() const { return set_bits_target_; }
+  RiscVCsrClearBitsDb *clear_bits_target() const { return clear_bits_target_; }
+
+  RiscVCsrInterface *csr() const { return csr_; }
+  T read_mask() const { return read_mask_; }
+  T write_mask() const { return write_mask_; }
+
+ private:
+  RiscVCsrInterface *csr_;
+  T read_mask_;
+  T write_mask_;
+  RiscVCsrWriteDb *write_target_;
+  RiscVCsrSetBitsDb *set_bits_target_;
+  RiscVCsrClearBitsDb *clear_bits_target_;
+};
+
 using RiscV32SimpleCsr = RiscVSimpleCsr<uint32_t>;
 using RiscV64SimpleCsr = RiscVSimpleCsr<uint64_t>;
 
@@ -521,6 +609,35 @@
   return new RiscVCsrSourceOperand(this);
 }
 
+template <typename T>
+generic::DestinationOperandInterface *
+RiscVShadowCsr<T>::CreateSetDestinationOperand(int latency,
+                                               std::string op_name) {
+  return new RiscVCsrDestinationOperand(this, this->set_bits_target(), latency,
+                                        op_name);
+}
+
+template <typename T>
+generic::DestinationOperandInterface *
+RiscVShadowCsr<T>::CreateClearDestinationOperand(int latency,
+                                                 std::string op_name) {
+  return new RiscVCsrDestinationOperand(this, this->clear_bits_target(),
+                                        latency, op_name);
+}
+
+template <typename T>
+generic::DestinationOperandInterface *
+RiscVShadowCsr<T>::CreateWriteDestinationOperand(int latency,
+                                                 std::string op_name) {
+  return new RiscVCsrDestinationOperand(this, this->write_target(), latency,
+                                        op_name);
+}
+
+template <typename T>
+generic::SourceOperandInterface *RiscVShadowCsr<T>::CreateSourceOperand() {
+  return new RiscVCsrSourceOperand(this);
+}
+
 }  // namespace riscv
 }  // namespace sim
 }  // namespace mpact
diff --git a/riscv/riscv_state.cc b/riscv/riscv_state.cc
index 8d71bad..36e4fe9 100644
--- a/riscv/riscv_state.cc
+++ b/riscv/riscv_state.cc
@@ -234,23 +234,23 @@
   auto *minstret = CreateCsr<RiscVCounterCsr<T, RiscVState>>(
       state, csr_vec, "minstret", RiscVCsrEnum ::kMInstret, state);
   CHECK_NE(minstret, nullptr);
+  RiscVCsrInterface *minstreth = nullptr;
   if (sizeof(T) == sizeof(uint32_t)) {
-    CHECK_NE(CreateCsr<RiscVCounterCsrHigh<RiscVState>>(
-                 state, csr_vec, "minstreth", RiscVCsrEnum::kMInstretH, state,
-                 reinterpret_cast<RiscVCounterCsr<uint32_t, RiscVState> *>(
-                     minstret)),
-             nullptr);
+    minstreth = CreateCsr<RiscVCounterCsrHigh<RiscVState>>(
+        state, csr_vec, "minstreth", RiscVCsrEnum::kMInstretH, state,
+        reinterpret_cast<RiscVCounterCsr<uint32_t, RiscVState> *>(minstret));
+    CHECK_NE(minstreth, nullptr);
   }
   // mcycle/mcycleh
   auto *mcycle = CreateCsr<RiscVCounterCsr<T, RiscVState>>(
       state, csr_vec, "mcycle", RiscVCsrEnum::kMCycle, state);
   CHECK_NE(mcycle, nullptr);
+  RiscVCsrInterface *mcycleh = nullptr;
   if (sizeof(T) == sizeof(uint32_t)) {
-    CHECK_NE(
-        CreateCsr<RiscVCounterCsrHigh<RiscVState>>(
-            state, csr_vec, "mcycleh", RiscVCsrEnum::kMCycleH, state,
-            reinterpret_cast<RiscVCounterCsr<uint32_t, RiscVState> *>(mcycle)),
-        nullptr);
+    mcycleh = CreateCsr<RiscVCounterCsrHigh<RiscVState>>(
+        state, csr_vec, "mcycleh", RiscVCsrEnum::kMCycleH, state,
+        reinterpret_cast<RiscVCounterCsr<uint32_t, RiscVState> *>(mcycle));
+    CHECK_NE(mcycleh, nullptr);
   }
 
   // Hypervisor level CSRs
@@ -318,6 +318,29 @@
 
   // User level CSRs
 
+  // instret/instreth
+  CHECK_NE(CreateCsr<RiscVShadowCsr<T>>(
+               state, csr_vec, "instret", RiscVCsrEnum ::kInstret,
+               std::numeric_limits<T>::max(), 0, state, minstret),
+           nullptr);
+  if (sizeof(T) == sizeof(uint32_t)) {
+    CHECK_NE(CreateCsr<RiscVShadowCsr<T>>(
+                 state, csr_vec, "instreth", RiscVCsrEnum::kInstretH,
+                 std::numeric_limits<T>::max(), 0, state, minstreth),
+             nullptr);
+  }
+  // cycle/cycleh
+  CHECK_NE(CreateCsr<RiscVShadowCsr<T>>(
+               state, csr_vec, "cycle", RiscVCsrEnum::kCycle,
+               std::numeric_limits<T>::max(), 0, state, mcycle),
+           nullptr);
+  if (sizeof(T) == sizeof(uint32_t)) {
+    CHECK_NE(CreateCsr<RiscVShadowCsr<T>>(
+                 state, csr_vec, "cycleh", RiscVCsrEnum::kCycleH,
+                 std::numeric_limits<T>::max(), 0, state, mcycleh),
+             nullptr);
+  }
+
   // ustatus
   CHECK_NE(CreateCsr<RiscVSimpleCsr<T>>(
                state, csr_vec, "ustatus", RiscVCsrEnum::kUStatus, 0,
diff --git a/riscv/riscv_top.cc b/riscv/riscv_top.cc
index 17885c7..ab18240 100644
--- a/riscv/riscv_top.cc
+++ b/riscv/riscv_top.cc
@@ -144,7 +144,7 @@
         << "Failed to register opcode counter";
   }
 
-  // Connect counters to instret(h) and mcycle(h) CSRs.
+  // Connect counters to minstret(h) and mcycle(h) CSRs.
   auto csr_res = state_->csr_set()->GetCsr("minstret");
   CHECK_OK(csr_res.status()) << "Failed to get minstret CSR";
   if (state_->xlen() == RiscVXlen::RV32) {
@@ -169,11 +169,13 @@
         reinterpret_cast<RiscVCounterCsrHigh<RiscVState> *>(csr_res.value());
     mcycleh->set_counter(&counter_num_cycles_);
   } else {
-    // Minstret/minstreth.
+    // Minstret.
+    csr_res = state_->csr_set()->GetCsr("minstret");
     auto *minstret = reinterpret_cast<RiscVCounterCsr<uint64_t, RiscVState> *>(
         csr_res.value());
     minstret->set_counter(&counter_num_instructions_);
-    // Mcycle/mcycleh.
+    // Mcycle
+    csr_res = state_->csr_set()->GetCsr("mcycle");
     auto *mcycle = reinterpret_cast<RiscVCounterCsr<uint64_t, RiscVState> *>(
         csr_res.value());
     mcycle->set_counter(&counter_num_cycles_);
diff --git a/riscv/test/riscv_csr_test.cc b/riscv/test/riscv_csr_test.cc
index 1f6650a..10801d3 100644
--- a/riscv/test/riscv_csr_test.cc
+++ b/riscv/test/riscv_csr_test.cc
@@ -30,6 +30,7 @@
 
 using ::mpact::sim::riscv::RiscV32SimpleCsr;
 using ::mpact::sim::riscv::RiscVCsrEnum;
+using ::mpact::sim::riscv::RiscVShadowCsr;
 using ::mpact::sim::riscv::RiscVState;
 using ::mpact::sim::riscv::RiscVXlen;
 using ::mpact::sim::util::FlatDemandMemory;
@@ -124,4 +125,25 @@
   delete csr;
 }
 
+// Test that the shadow csr constructs properly and with the expected values.
+TEST_F(RiscV32CsrTest, ShadowCsrConstruction) {
+  auto *csr0 = new RiscV32SimpleCsr(kCsrName0, RiscVCsrEnum::kMScratch,
+                                    kDeadBeef, state_);
+  EXPECT_EQ(csr0->name(), kCsrName0);
+  EXPECT_EQ(csr0->index(), static_cast<int>(RiscVCsrEnum::kMScratch));
+
+  auto *csr1 = new RiscVShadowCsr<uint32_t>(
+      kCsrName1, RiscVCsrEnum::kUScratch, kReadMask, kWriteMask, state_, csr0);
+  EXPECT_EQ(csr1->name(), kCsrName1);
+  EXPECT_EQ(csr1->index(), static_cast<int>(RiscVCsrEnum::kUScratch));
+  EXPECT_EQ(csr1->read_mask(), kReadMask);
+  EXPECT_EQ(csr1->write_mask(), kWriteMask);
+
+  EXPECT_EQ(csr1->AsUint32(), csr0->AsUint32() & kReadMask);
+  csr1->Write(kAllOnes);
+  EXPECT_EQ(csr0->AsUint32(), kDeadBeef | (kAllOnes & kWriteMask));
+  delete csr0;
+  delete csr1;
+}
+
 }  // namespace