Added more tests and fixed issues in the initial cache implementation. Now enabled to
work as a data cache as well with memory backing.

PiperOrigin-RevId: 679177125
Change-Id: I5f7e9869e98a5d12b866ece94e297d90d8cd4e4c
diff --git a/mpact/sim/util/memory/cache.cc b/mpact/sim/util/memory/cache.cc
index 732841d..f8ee594 100644
--- a/mpact/sim/util/memory/cache.cc
+++ b/mpact/sim/util/memory/cache.cc
@@ -189,9 +189,9 @@
   (void)CacheLookup(address, db->size<uint8_t>(), /*is_read=*/true);
   if (memory_ == nullptr) return;
 
-  auto *cache_context = new CacheContext{context, db, inst, db->latency()};
-  context->IncRef();
-  inst->IncRef();
+  auto *cache_context = new CacheContext(context, db, inst, db->latency());
+  if (context) context->IncRef();
+  if (inst) inst->IncRef();
   db->set_latency(0);
   memory_->Load(address, db, cache_inst_, cache_context);
   cache_context->DecRef();
@@ -209,8 +209,8 @@
   if (memory_ == nullptr) return;
 
   auto *cache_context = new CacheContext(context, db, inst, db->latency());
-  context->IncRef();
-  inst->IncRef();
+  if (context) context->IncRef();
+  if (inst) inst->IncRef();
   db->set_latency(0);
   memory_->Load(address_db, mask_db, el_size, db, cache_inst_, cache_context);
   cache_context->DecRef();
@@ -221,17 +221,19 @@
   (void)CacheLookup(address, db->size<uint8_t>(), /*is_read=*/true);
   if (tagged_memory_ == nullptr) return;
 
-  auto *cache_context = new CacheContext{context, db, inst, db->latency()};
-  context->IncRef();
-  inst->IncRef();
+  auto *cache_context =
+      new CacheContext(context, db, tags, inst, db->latency());
+  if (context) context->IncRef();
+  if (inst) inst->IncRef();
   db->set_latency(0);
-  tagged_memory_->Load(address, db, cache_inst_, cache_context);
+  tagged_memory_->Load(address, db, tags, cache_inst_, cache_context);
   cache_context->DecRef();
 }
 
 void Cache::Store(uint64_t address, DataBuffer *db) {
   (void)CacheLookup(address, db->size<uint8_t>(), /*is_read=*/false);
   if (memory_ == nullptr) return;
+
   memory_->Store(address, db);
 }
 
@@ -266,7 +268,7 @@
   auto *og_inst = cache_context->inst;
   // Reset the db latency to the original value.
   db->set_latency(cache_context->latency);
-  if (nullptr != inst) {
+  if (nullptr != og_inst) {
     if (db->latency() > 0) {
       og_inst->IncRef();
       og_inst->state()->function_delay_line()->Add(db->latency(),
diff --git a/mpact/sim/util/memory/cache.h b/mpact/sim/util/memory/cache.h
index 7d82c38..580e956 100644
--- a/mpact/sim/util/memory/cache.h
+++ b/mpact/sim/util/memory/cache.h
@@ -131,6 +131,15 @@
   void Store(DataBuffer *address, DataBuffer *mask, int el_size,
              DataBuffer *db) override;
   void Store(uint64_t address, DataBuffer *db, DataBuffer *tags) override;
+  // Setters for the memory interfaces.
+  void set_memory(MemoryInterface *memory) {
+    memory_ = memory;
+    tagged_memory_ = nullptr;
+  }
+  void set_tagged_memory(TaggedMemoryInterface *tagged_memory) {
+    tagged_memory_ = tagged_memory;
+    memory_ = tagged_memory;
+  }
 
  private:
   // This struct represents a cache line.
@@ -158,12 +167,12 @@
   // The cache.
   CacheLine *cache_lines_ = nullptr;
   // Shift amounts and mask used to compute the index from the address.
-  int block_shift_;
-  int set_shift_;
-  uint64_t index_mask_;
+  int block_shift_ = 0;
+  int set_shift_ = 0;
+  uint64_t index_mask_ = 0;
   // True if allocate cache line on write is enabled.
-  bool write_allocate_;
-  uint64_t num_sets_;
+  bool write_allocate_ = false;
+  uint64_t num_sets_ = 0;
   // Instruction object used to perform the writeback to the processor.
   Instruction *cache_inst_;
   CounterValueOutputBase<uint64_t> *cycle_counter_;
diff --git a/mpact/sim/util/memory/flat_demand_memory.h b/mpact/sim/util/memory/flat_demand_memory.h
index c050ada..4b34bdd 100644
--- a/mpact/sim/util/memory/flat_demand_memory.h
+++ b/mpact/sim/util/memory/flat_demand_memory.h
@@ -15,6 +15,8 @@
 #ifndef SIM_UTIL_MEMORY_FLAT_DEMAND_MEMORY_H_
 #define SIM_UTIL_MEMORY_FLAT_DEMAND_MEMORY_H_
 
+#include <cstdint>
+
 #include "absl/container/flat_hash_map.h"
 #include "mpact/sim/generic/data_buffer.h"
 #include "mpact/sim/generic/instruction.h"
diff --git a/mpact/sim/util/memory/test/cache_test.cc b/mpact/sim/util/memory/test/cache_test.cc
index 3c74907..34223ba 100644
--- a/mpact/sim/util/memory/test/cache_test.cc
+++ b/mpact/sim/util/memory/test/cache_test.cc
@@ -21,6 +21,8 @@
 #include "googletest/include/gtest/gtest.h"
 #include "mpact/sim/generic/counters.h"
 #include "mpact/sim/generic/data_buffer.h"
+#include "mpact/sim/util/memory/flat_demand_memory.h"
+#include "mpact/sim/util/memory/tagged_flat_demand_memory.h"
 
 namespace {
 
@@ -28,6 +30,10 @@
 using ::mpact::sim::generic::DataBufferFactory;
 using ::mpact::sim::generic::SimpleCounter;
 using ::mpact::sim::util::Cache;
+using ::mpact::sim::util::FlatDemandMemory;
+using ::mpact::sim::util::TaggedFlatDemandMemory;
+
+constexpr unsigned kTagGranule = 16;
 
 class CacheTest : public testing::Test {
  protected:
@@ -81,6 +87,18 @@
   EXPECT_EQ(read_hits_->GetValue(), (refs / 4) * 3);
 }
 
+TEST_F(CacheTest, DirectMappedWritesCold) {
+  // Create a cache 16kB, 16B blocks, direct mapped.
+  CHECK_OK(cache_->Configure("1k,16,1,true", &cycle_counter_));
+
+  for (uint64_t address = 0; address < 1024; address += 4) {
+    cache_->Store(address, db_);
+  }
+  uint64_t refs = 1024 / 4;
+  EXPECT_EQ(write_misses_->GetValue(), refs / 4);
+  EXPECT_EQ(write_hits_->GetValue(), (refs / 4) * 3);
+}
+
 TEST_F(CacheTest, DirectMappedReadsWarm) {
   // Create a cache 16kB, 16B blocks, direct mapped.
   CHECK_OK(cache_->Configure("1k,16,1,true", &cycle_counter_));
@@ -112,6 +130,37 @@
   EXPECT_EQ(read_hits_->GetValue(), (refs / 4) * 3);
 }
 
+TEST_F(CacheTest, DirectMappedWritesWarm) {
+  // Create a cache 16kB, 16B blocks, direct mapped.
+  CHECK_OK(cache_->Configure("1k,16,1,true", &cycle_counter_));
+
+  // Warm the cache.
+  for (uint64_t address = 0; address < 1024; address += 4) {
+    cache_->Store(address, db_);
+  }
+  // Clear the counters.
+  write_misses_->SetValue(0);
+  write_hits_->SetValue(0);
+
+  // Access the cache again. Should be all hits.
+  for (uint64_t address = 0; address < 1024; address += 4) {
+    cache_->Store(address, db_);
+  }
+  uint64_t refs = 1024 / 4;
+  EXPECT_EQ(write_misses_->GetValue(), 0);
+  EXPECT_EQ(write_hits_->GetValue(), refs);
+
+  // Clear the counters.
+  write_misses_->SetValue(0);
+  write_hits_->SetValue(0);
+  // Access the next 1k, should be like a cold cache.
+  for (uint64_t address = 1024; address < 2048; address += 4) {
+    cache_->Store(address, db_);
+  }
+  EXPECT_EQ(write_misses_->GetValue(), refs / 4);
+  EXPECT_EQ(write_hits_->GetValue(), (refs / 4) * 3);
+}
+
 TEST_F(CacheTest, TwoWayReads) {
   // Create a cache 16kB, 16B blocks, two way set associative.
   CHECK_OK(cache_->Configure("1k,16,2,true", &cycle_counter_));
@@ -134,4 +183,102 @@
   EXPECT_EQ(read_hits_->GetValue(), 2 * 512 / 16);
 }
 
+TEST_F(CacheTest, MemoryTest) {
+  FlatDemandMemory memory;
+  cache_->set_memory(&memory);
+  CHECK_OK(cache_->Configure("1k,16,1,true", &cycle_counter_));
+
+  DataBuffer *st_db1 = db_factory_.Allocate<uint8_t>(1);
+  DataBuffer *st_db2 = db_factory_.Allocate<uint16_t>(1);
+  DataBuffer *st_db4 = db_factory_.Allocate<uint32_t>(1);
+  DataBuffer *st_db8 = db_factory_.Allocate<uint64_t>(1);
+
+  st_db1->Set<uint8_t>(0, 0x0F);
+  st_db2->Set<uint16_t>(0, 0xA5A5);
+  st_db4->Set<uint32_t>(0, 0xDEADBEEF);
+  st_db8->Set<uint64_t>(0, 0x0F0F0F0F'A5A5A5A5);
+
+  cache_->Store(0x1000, st_db1);
+  cache_->Store(0x1002, st_db2);
+  cache_->Store(0x1004, st_db4);
+  cache_->Store(0x1008, st_db8);
+
+  DataBuffer *ld_db1 = db_factory_.Allocate<uint8_t>(1);
+  DataBuffer *ld_db2 = db_factory_.Allocate<uint16_t>(1);
+  DataBuffer *ld_db4 = db_factory_.Allocate<uint32_t>(1);
+  DataBuffer *ld_db8 = db_factory_.Allocate<uint64_t>(1);
+
+  cache_->Load(0x1000, ld_db1, nullptr, nullptr);
+  cache_->Load(0x1002, ld_db2, nullptr, nullptr);
+  cache_->Load(0x1004, ld_db4, nullptr, nullptr);
+  cache_->Load(0x1008, ld_db8, nullptr, nullptr);
+
+  EXPECT_EQ(ld_db1->Get<uint8_t>(0), st_db1->Get<uint8_t>(0));
+  EXPECT_EQ(ld_db2->Get<uint16_t>(0), st_db2->Get<uint16_t>(0));
+  EXPECT_EQ(ld_db4->Get<uint32_t>(0), st_db4->Get<uint32_t>(0));
+  EXPECT_EQ(ld_db8->Get<uint64_t>(0), st_db8->Get<uint64_t>(0));
+
+  ld_db1->DecRef();
+  ld_db2->DecRef();
+  ld_db4->DecRef();
+  ld_db8->DecRef();
+
+  st_db1->DecRef();
+  st_db2->DecRef();
+  st_db4->DecRef();
+  st_db8->DecRef();
+}
+
+TEST_F(CacheTest, TaggedMemoryTest) {
+  TaggedFlatDemandMemory memory(kTagGranule);
+  cache_->set_tagged_memory(&memory);
+  CHECK_OK(cache_->Configure("1k,16,1,true", &cycle_counter_));
+
+  DataBuffer *ld_data_db = db_factory_.Allocate<uint8_t>(kTagGranule * 16);
+  DataBuffer *ld_tag_db = db_factory_.Allocate<uint8_t>(16);
+  DataBuffer *st_data_db = db_factory_.Allocate<uint8_t>(kTagGranule * 16);
+  DataBuffer *st_tag_db = db_factory_.Allocate<uint8_t>(16);
+  cache_->Load(0x1000, ld_data_db, ld_tag_db, nullptr, nullptr);
+  // The loaded data should be all zeros.
+  for (int i = 0; i < 16; i++) {
+    for (int j = 0; j < kTagGranule; j++) {
+      EXPECT_EQ(ld_data_db->Get<uint8_t>(i * kTagGranule + j), 0);
+    }
+    EXPECT_EQ(ld_tag_db->Get<uint8_t>(i), 0);
+  }
+  // Write out known data.
+  for (int i = 0; i < st_data_db->size<uint8_t>(); i++) {
+    st_data_db->Set<uint8_t>(i, i);
+  }
+  for (int i = 0; i < st_tag_db->size<uint8_t>(); i++) {
+    st_tag_db->Set<uint8_t>(i, 1);
+  }
+  cache_->Store(0x1000, st_data_db, st_tag_db);
+  // Verify that the loaded data is equal to the stored data.
+  cache_->Load(0x1000, ld_data_db, ld_tag_db, nullptr, nullptr);
+  for (int i = 0; i < ld_data_db->size<uint8_t>(); i++) {
+    EXPECT_EQ(ld_data_db->Get<uint8_t>(i), i);
+  }
+  for (int i = 0; i < ld_tag_db->size<uint8_t>(); i++) {
+    EXPECT_EQ(ld_tag_db->Get<uint8_t>(i), 1);
+  }
+  // Clear every third tag and store them.
+  for (int i = 0; i < st_tag_db->size<uint8_t>(); i++) {
+    if (i % 3 == 0) st_tag_db->Set<uint8_t>(i, 0);
+  }
+  cache_->Store(0x1000, st_data_db, st_tag_db);
+  // Re-load and compare.
+  cache_->Load(0x1000, ld_data_db, ld_tag_db, nullptr, nullptr);
+  for (int i = 0; i < ld_data_db->size<uint8_t>(); i++) {
+    EXPECT_EQ(ld_data_db->Get<uint8_t>(i), i);
+  }
+  for (int i = 0; i < ld_tag_db->size<uint8_t>(); i++) {
+    EXPECT_EQ(ld_tag_db->Get<uint8_t>(i), i % 3 == 0 ? 0 : 1) << i;
+  }
+  ld_data_db->DecRef();
+  ld_tag_db->DecRef();
+  st_data_db->DecRef();
+  st_tag_db->DecRef();
+}
+
 }  // namespace