Compensate for instruction increment when writing to counter CSRs
- Derive RiscVPerformanceCounterCsr{High} from RiscVCounterCsr
- Decrement offset in Set to account for the retiring instruction.
- High::Set handles carry and propagates offset adjustments for high counter CSRs.
- Add unit tests for minstret and minstreth to verify write compensation.
PiperOrigin-RevId: 862918613
Change-Id: I04dc67e1137dc9d82402006411d91a6e3c894423
diff --git a/riscv/riscv_counter_csr.h b/riscv/riscv_counter_csr.h
index c20c269..f2a0a4f 100644
--- a/riscv/riscv_counter_csr.h
+++ b/riscv/riscv_counter_csr.h
@@ -59,6 +59,9 @@
// This is called to tie a cycle counter to the CSR.
void set_counter(SimpleCounter<uint64_t>* counter) { counter_ = counter; }
+ protected:
+ T offset_ = 0;
+
private:
inline T GetCounterValue() const {
if (counter_ == nullptr) return 0;
@@ -66,7 +69,28 @@
};
SimpleCounter<uint64_t>* counter_ = nullptr;
- T offset_ = 0;
+};
+
+// This class implements the performance counter CSRs, which have slightly
+// different behavior on write due to how they interact with the pipeline.
+template <typename T, typename S>
+class RiscVPerformanceCounterCsr : public RiscVCounterCsr<T, S> {
+ public:
+ RiscVPerformanceCounterCsr(std::string name, RiscVCsrEnum csr_enum, S* state)
+ : RiscVCounterCsr<T, S>(name, csr_enum, state) {}
+ RiscVPerformanceCounterCsr(const RiscVPerformanceCounterCsr&) = delete;
+ RiscVPerformanceCounterCsr& operator=(const RiscVPerformanceCounterCsr&) =
+ delete;
+ ~RiscVPerformanceCounterCsr() override = default;
+
+ void Set(uint32_t value) override {
+ RiscVCounterCsr<T, S>::Set(value);
+ --this->offset_;
+ }
+ void Set(uint64_t value) override {
+ RiscVCounterCsr<T, S>::Set(value);
+ --this->offset_;
+ }
};
// This class implements the "high" version of the CSR on 32-bit RiscV.
@@ -92,22 +116,63 @@
// Any value written to the CSR is used to create an offset from the current
// value of the counter.
void Set(uint32_t value) override {
- offset_ = value - (GetCounterValue() >> 32);
+ uint64_t counter_val = GetCounterValue();
+ uint32_t counter_low = static_cast<uint32_t>(counter_val);
+ uint32_t low_offset = low_csr_->offset_;
+ // Check for carry when reconstructing logical low value.
+ bool carry =
+ (static_cast<uint64_t>(counter_low) + low_offset) >= 0x100000000ULL;
+ offset_ = value - (counter_val >> 32) - (carry ? 1 : 0);
};
void Set(uint64_t value) override { Set(static_cast<uint32_t>(value)); };
// This is called to tie a cycle counter to the CSR.
void set_counter(SimpleCounter<uint64_t>* counter) { counter_ = counter; }
- private:
+ protected:
inline uint64_t GetCounterValue() const {
if (counter_ == nullptr) return 0;
return counter_->GetValue();
};
+ uint32_t get_low_offset() const { return low_csr_->offset_; }
+ void decr_low_offset() { low_csr_->offset_--; }
+ uint64_t offset_ = 0;
+
+ private:
RiscVCounterCsr<uint32_t, S>* low_csr_;
SimpleCounter<uint64_t>* counter_ = nullptr;
- uint64_t offset_ = 0;
+};
+
+// This class implements the "high" version of the performance counter CSRs,
+// which have slightly different behavior on write due to how they interact
+// with the pipeline.
+template <typename S>
+class RiscVPerformanceCounterCsrHigh : public RiscVCounterCsrHigh<S> {
+ public:
+ RiscVPerformanceCounterCsrHigh(std::string name, RiscVCsrEnum csr_enum,
+ S* state,
+ RiscVCounterCsr<uint32_t, S>* low_csr)
+ : RiscVCounterCsrHigh<S>(name, csr_enum, state, low_csr) {}
+ RiscVPerformanceCounterCsrHigh(const RiscVPerformanceCounterCsrHigh&) =
+ delete;
+ RiscVPerformanceCounterCsrHigh& operator=(
+ const RiscVPerformanceCounterCsrHigh&) = delete;
+ ~RiscVPerformanceCounterCsrHigh() override = default;
+
+ void Set(uint32_t value) override {
+ uint64_t counter_val = this->GetCounterValue();
+ uint32_t counter_low = static_cast<uint32_t>(counter_val);
+ uint32_t low_offset = this->get_low_offset();
+ // Check for carry when reconstructing logical low value.
+ bool carry =
+ (static_cast<uint64_t>(counter_low) + low_offset) >= 0x100000000ULL;
+ this->offset_ = value - (counter_val >> 32) - (carry ? 1 : 0);
+ if (this->get_low_offset() == 0) {
+ this->offset_--;
+ }
+ this->decr_low_offset();
+ }
};
} // namespace mpact::sim::riscv
diff --git a/riscv/riscv_state.cc b/riscv/riscv_state.cc
index 135bbf9..b53ffb8 100644
--- a/riscv/riscv_state.cc
+++ b/riscv/riscv_state.cc
@@ -231,35 +231,35 @@
nullptr);
// minstret/minstreth
- auto* minstret = CreateCsr<RiscVCounterCsr<T, RiscVState>>(
+ auto* minstret = CreateCsr<RiscVPerformanceCounterCsr<T, RiscVState>>(
state, csr_vec, "minstret", RiscVCsrEnum ::kMInstret, state);
CHECK_NE(minstret, nullptr);
if (std::is_same_v<T, uint32_t>) {
CHECK_NE(
- CreateCsr<RiscVCounterCsrHigh<RiscVState>>(
+ CreateCsr<RiscVPerformanceCounterCsrHigh<RiscVState>>(
state, csr_vec, "minstreth", RiscVCsrEnum::kMInstretH, state,
reinterpret_cast<RiscVCounterCsr<uint32_t, RiscVState>*>(minstret)),
nullptr);
}
// mcycle/mcycleh
- auto* mcycle = CreateCsr<RiscVCounterCsr<T, RiscVState>>(
+ auto* mcycle = CreateCsr<RiscVPerformanceCounterCsr<T, RiscVState>>(
state, csr_vec, "mcycle", RiscVCsrEnum::kMCycle, state);
CHECK_NE(mcycle, nullptr);
if (std::is_same_v<T, uint32_t>) {
CHECK_NE(
- CreateCsr<RiscVCounterCsrHigh<RiscVState>>(
+ CreateCsr<RiscVPerformanceCounterCsrHigh<RiscVState>>(
state, csr_vec, "mcycleh", RiscVCsrEnum::kMCycleH, state,
reinterpret_cast<RiscVCounterCsr<uint32_t, RiscVState>*>(mcycle)),
nullptr);
}
// cycle / cycleh
- auto* cycle = CreateCsr<RiscVCounterCsr<T, RiscVState>>(
+ auto* cycle = CreateCsr<RiscVPerformanceCounterCsr<T, RiscVState>>(
state, csr_vec, "cycle", RiscVCsrEnum::kCycle, state);
CHECK_NE(cycle, nullptr);
if (std::is_same_v<T, uint32_t>) {
CHECK_NE(
- CreateCsr<RiscVCounterCsrHigh<RiscVState>>(
+ CreateCsr<RiscVPerformanceCounterCsrHigh<RiscVState>>(
state, csr_vec, "cycleh", RiscVCsrEnum::kCycleH, state,
reinterpret_cast<RiscVCounterCsr<uint32_t, RiscVState>*>(cycle)),
nullptr);
@@ -278,12 +278,12 @@
}
// instret / instreth
- auto* instret = CreateCsr<RiscVCounterCsr<T, RiscVState>>(
+ auto* instret = CreateCsr<RiscVPerformanceCounterCsr<T, RiscVState>>(
state, csr_vec, "instret", RiscVCsrEnum::kInstret, state);
CHECK_NE(instret, nullptr);
if (std::is_same_v<T, uint32_t>) {
CHECK_NE(
- CreateCsr<RiscVCounterCsrHigh<RiscVState>>(
+ CreateCsr<RiscVPerformanceCounterCsrHigh<RiscVState>>(
state, csr_vec, "instreth", RiscVCsrEnum::kInstretH, state,
reinterpret_cast<RiscVCounterCsr<uint32_t, RiscVState>*>(instret)),
nullptr);
diff --git a/riscv/test/riscv_counter_csr_test.cc b/riscv/test/riscv_counter_csr_test.cc
index 0dc6dc4..abb053f 100644
--- a/riscv/test/riscv_counter_csr_test.cc
+++ b/riscv/test/riscv_counter_csr_test.cc
@@ -17,6 +17,8 @@
using ::mpact::sim::riscv::RiscVCounterCsr;
using ::mpact::sim::riscv::RiscVCounterCsrHigh;
using ::mpact::sim::riscv::RiscVCsrEnum;
+using ::mpact::sim::riscv::RiscVPerformanceCounterCsr;
+using ::mpact::sim::riscv::RiscVPerformanceCounterCsrHigh;
using ::mpact::sim::riscv::RiscVState;
class RiscVMCycleTest : public ::testing::Test {
@@ -29,8 +31,8 @@
// Verify the operation of 64 bit mcycle with counter increments.
TEST_F(RiscVMCycleTest, GetTest64) {
- RiscVCounterCsr<uint64_t, RiscVState> mcycle("mcycle", RiscVCsrEnum::kMCycle,
- nullptr);
+ RiscVPerformanceCounterCsr<uint64_t, RiscVState> mcycle(
+ "mcycle", RiscVCsrEnum::kMCycle, nullptr);
mcycle.set_counter(&counter_);
// Initial value should be zero.
EXPECT_EQ(mcycle.GetUint32(), 0);
@@ -39,19 +41,21 @@
EXPECT_EQ(mcycle.GetUint32(), 1);
EXPECT_EQ(mcycle.GetUint64(), 1);
mcycle.Write(100u);
+ counter_.Increment(1);
EXPECT_EQ(mcycle.GetUint32(), 100);
EXPECT_EQ(mcycle.GetUint64(), 100);
mcycle.Set(static_cast<uint64_t>(1000));
+ counter_.Increment(1);
EXPECT_EQ(mcycle.GetUint32(), 1000);
EXPECT_EQ(mcycle.GetUint64(), 1000);
}
// Verify the operation of 32 bit mcycle and mcycleh with counter increments.
TEST_F(RiscVMCycleTest, GetTest32) {
- RiscVCounterCsr<uint32_t, RiscVState> mcycle("mcycle", RiscVCsrEnum::kMCycle,
- nullptr);
- RiscVCounterCsrHigh<RiscVState> mcycleh("mcycleh", RiscVCsrEnum::kMCycleH,
- nullptr, &mcycle);
+ RiscVPerformanceCounterCsr<uint32_t, RiscVState> mcycle(
+ "mcycle", RiscVCsrEnum::kMCycle, nullptr);
+ RiscVPerformanceCounterCsrHigh<RiscVState> mcycleh(
+ "mcycleh", RiscVCsrEnum::kMCycleH, nullptr, &mcycle);
mcycle.set_counter(&counter_);
mcycleh.set_counter(&counter_);
// Initial value should be zero.
@@ -82,12 +86,13 @@
// Test that write to mcycle is reflected in the value of mcycle.
TEST_F(RiscVMCycleTest, SetTest64) {
- RiscVCounterCsr<uint64_t, RiscVState> mcycle("mcycle", RiscVCsrEnum::kMCycle,
- nullptr);
+ RiscVPerformanceCounterCsr<uint64_t, RiscVState> mcycle(
+ "mcycle", RiscVCsrEnum::kMCycle, nullptr);
mcycle.set_counter(&counter_);
EXPECT_EQ(mcycle.GetUint32(), 0);
EXPECT_EQ(mcycle.GetUint64(), 0);
mcycle.Write(100u);
+ counter_.Increment(1);
EXPECT_EQ(mcycle.GetUint32(), 100);
EXPECT_EQ(mcycle.GetUint64(), 100);
counter_.Increment(10);
@@ -98,10 +103,10 @@
// Test that write to mcycle and mcycleh is reflected in the value of mcycle and
// mcycleh.
TEST_F(RiscVMCycleTest, SetTest32) {
- RiscVCounterCsr<uint32_t, RiscVState> mcycle("mcycle", RiscVCsrEnum::kMCycle,
- nullptr);
- RiscVCounterCsrHigh<RiscVState> mcycleh("mcycleh", RiscVCsrEnum::kMCycleH,
- nullptr, &mcycle);
+ RiscVPerformanceCounterCsr<uint32_t, RiscVState> mcycle(
+ "mcycle", RiscVCsrEnum::kMCycle, nullptr);
+ RiscVPerformanceCounterCsrHigh<RiscVState> mcycleh(
+ "mcycleh", RiscVCsrEnum::kMCycleH, nullptr, &mcycle);
mcycle.set_counter(&counter_);
mcycleh.set_counter(&counter_);
EXPECT_EQ(mcycle.GetUint32(), 0);
@@ -109,7 +114,9 @@
EXPECT_EQ(mcycleh.GetUint32(), 0);
EXPECT_EQ(mcycleh.GetUint64(), 0);
mcycle.Write(100u);
+ counter_.Increment(1);
mcycleh.Write(200u);
+ counter_.Increment(1);
EXPECT_EQ(mcycle.GetUint32(), 100);
EXPECT_EQ(mcycle.GetUint64(), 100);
EXPECT_EQ(mcycleh.GetUint32(), 200);
@@ -126,4 +133,73 @@
EXPECT_EQ(mcycleh.GetUint64(), 201);
}
+// Verify the operation of minstret with counter increments and writes.
+// Specifically, ensure that writing to the CSR masks the increment of the
+// retiring instruction.
+class RiscVMInstretTest : public ::testing::Test {
+ protected:
+ RiscVMInstretTest() : counter_("instructions", 0) {};
+ ~RiscVMInstretTest() override = default;
+
+ SimpleCounter<uint64_t> counter_;
+};
+
+// Test that write to 64-bit minstret accounts for the instruction increment.
+TEST_F(RiscVMInstretTest, SetTest64) {
+ RiscVPerformanceCounterCsr<uint64_t, RiscVState> minstret(
+ "minstret", RiscVCsrEnum::kMInstret, nullptr);
+ minstret.set_counter(&counter_);
+
+ EXPECT_EQ(minstret.GetUint64(), 0);
+
+ // Write a value to the CSR.
+ minstret.Set(1000u);
+ // Simulate the instruction retirement (the instruction that did the write).
+ counter_.Increment(1);
+
+ // The value read should be exactly what was written, because writes to
+ // minstret don't increment the counter.
+ EXPECT_EQ(minstret.GetUint64(), 1000);
+
+ // Subsequent increments should work normally.
+ counter_.Increment(1);
+ EXPECT_EQ(minstret.GetUint64(), 1001);
+}
+
+// Test that writes to 32-bit minstret/minstreth account for the instruction
+// increment.
+TEST_F(RiscVMInstretTest, SetTest32) {
+ RiscVPerformanceCounterCsr<uint32_t, RiscVState> minstret(
+ "minstret", RiscVCsrEnum::kMInstret, nullptr);
+ RiscVPerformanceCounterCsrHigh<RiscVState> minstreth(
+ "minstreth", RiscVCsrEnum::kMInstretH, nullptr, &minstret);
+ minstret.set_counter(&counter_);
+ minstreth.set_counter(&counter_);
+
+ EXPECT_EQ(minstret.GetUint32(), 0);
+ EXPECT_EQ(minstreth.GetUint32(), 0);
+
+ // 1. Write Low CSR.
+ minstret.Set(100u);
+ counter_.Increment(1);
+ // Should equal 100 (increment compensated).
+ EXPECT_EQ(minstret.GetUint32(), 100);
+ EXPECT_EQ(minstreth.GetUint32(), 0);
+
+ // 2. Write High CSR.
+ minstreth.Set(5u);
+ counter_.Increment(1);
+ // Should equal 5 (increment compensated).
+ EXPECT_EQ(minstreth.GetUint32(), 5);
+ // Side effect: The increment compensation on the high write effectively
+ // stalls the low counter for this cycle. This is consistent with preserving
+ // the exact 64-bit value desired.
+ EXPECT_EQ(minstret.GetUint32(), 100);
+
+ // 3. Normal increment.
+ counter_.Increment(1);
+ EXPECT_EQ(minstret.GetUint32(), 101);
+ EXPECT_EQ(minstreth.GetUint32(), 5);
+}
+
} // namespace