Added more tests and fixed issues in the initial cache implementation. Now enabled to work as a data cache as well with memory backing. PiperOrigin-RevId: 679177125 Change-Id: I5f7e9869e98a5d12b866ece94e297d90d8cd4e4c
diff --git a/mpact/sim/util/memory/cache.cc b/mpact/sim/util/memory/cache.cc index 732841d..f8ee594 100644 --- a/mpact/sim/util/memory/cache.cc +++ b/mpact/sim/util/memory/cache.cc
@@ -189,9 +189,9 @@ (void)CacheLookup(address, db->size<uint8_t>(), /*is_read=*/true); if (memory_ == nullptr) return; - auto *cache_context = new CacheContext{context, db, inst, db->latency()}; - context->IncRef(); - inst->IncRef(); + auto *cache_context = new CacheContext(context, db, inst, db->latency()); + if (context) context->IncRef(); + if (inst) inst->IncRef(); db->set_latency(0); memory_->Load(address, db, cache_inst_, cache_context); cache_context->DecRef(); @@ -209,8 +209,8 @@ if (memory_ == nullptr) return; auto *cache_context = new CacheContext(context, db, inst, db->latency()); - context->IncRef(); - inst->IncRef(); + if (context) context->IncRef(); + if (inst) inst->IncRef(); db->set_latency(0); memory_->Load(address_db, mask_db, el_size, db, cache_inst_, cache_context); cache_context->DecRef(); @@ -221,17 +221,19 @@ (void)CacheLookup(address, db->size<uint8_t>(), /*is_read=*/true); if (tagged_memory_ == nullptr) return; - auto *cache_context = new CacheContext{context, db, inst, db->latency()}; - context->IncRef(); - inst->IncRef(); + auto *cache_context = + new CacheContext(context, db, tags, inst, db->latency()); + if (context) context->IncRef(); + if (inst) inst->IncRef(); db->set_latency(0); - tagged_memory_->Load(address, db, cache_inst_, cache_context); + tagged_memory_->Load(address, db, tags, cache_inst_, cache_context); cache_context->DecRef(); } void Cache::Store(uint64_t address, DataBuffer *db) { (void)CacheLookup(address, db->size<uint8_t>(), /*is_read=*/false); if (memory_ == nullptr) return; + memory_->Store(address, db); } @@ -266,7 +268,7 @@ auto *og_inst = cache_context->inst; // Reset the db latency to the original value. db->set_latency(cache_context->latency); - if (nullptr != inst) { + if (nullptr != og_inst) { if (db->latency() > 0) { og_inst->IncRef(); og_inst->state()->function_delay_line()->Add(db->latency(),
diff --git a/mpact/sim/util/memory/cache.h b/mpact/sim/util/memory/cache.h index 7d82c38..580e956 100644 --- a/mpact/sim/util/memory/cache.h +++ b/mpact/sim/util/memory/cache.h
@@ -131,6 +131,15 @@ void Store(DataBuffer *address, DataBuffer *mask, int el_size, DataBuffer *db) override; void Store(uint64_t address, DataBuffer *db, DataBuffer *tags) override; + // Setters for the memory interfaces. + void set_memory(MemoryInterface *memory) { + memory_ = memory; + tagged_memory_ = nullptr; + } + void set_tagged_memory(TaggedMemoryInterface *tagged_memory) { + tagged_memory_ = tagged_memory; + memory_ = tagged_memory; + } private: // This struct represents a cache line. @@ -158,12 +167,12 @@ // The cache. CacheLine *cache_lines_ = nullptr; // Shift amounts and mask used to compute the index from the address. - int block_shift_; - int set_shift_; - uint64_t index_mask_; + int block_shift_ = 0; + int set_shift_ = 0; + uint64_t index_mask_ = 0; // True if allocate cache line on write is enabled. - bool write_allocate_; - uint64_t num_sets_; + bool write_allocate_ = false; + uint64_t num_sets_ = 0; // Instruction object used to perform the writeback to the processor. Instruction *cache_inst_; CounterValueOutputBase<uint64_t> *cycle_counter_;
diff --git a/mpact/sim/util/memory/flat_demand_memory.h b/mpact/sim/util/memory/flat_demand_memory.h index c050ada..4b34bdd 100644 --- a/mpact/sim/util/memory/flat_demand_memory.h +++ b/mpact/sim/util/memory/flat_demand_memory.h
@@ -15,6 +15,8 @@ #ifndef SIM_UTIL_MEMORY_FLAT_DEMAND_MEMORY_H_ #define SIM_UTIL_MEMORY_FLAT_DEMAND_MEMORY_H_ +#include <cstdint> + #include "absl/container/flat_hash_map.h" #include "mpact/sim/generic/data_buffer.h" #include "mpact/sim/generic/instruction.h"
diff --git a/mpact/sim/util/memory/test/cache_test.cc b/mpact/sim/util/memory/test/cache_test.cc index 3c74907..34223ba 100644 --- a/mpact/sim/util/memory/test/cache_test.cc +++ b/mpact/sim/util/memory/test/cache_test.cc
@@ -21,6 +21,8 @@ #include "googletest/include/gtest/gtest.h" #include "mpact/sim/generic/counters.h" #include "mpact/sim/generic/data_buffer.h" +#include "mpact/sim/util/memory/flat_demand_memory.h" +#include "mpact/sim/util/memory/tagged_flat_demand_memory.h" namespace { @@ -28,6 +30,10 @@ using ::mpact::sim::generic::DataBufferFactory; using ::mpact::sim::generic::SimpleCounter; using ::mpact::sim::util::Cache; +using ::mpact::sim::util::FlatDemandMemory; +using ::mpact::sim::util::TaggedFlatDemandMemory; + +constexpr unsigned kTagGranule = 16; class CacheTest : public testing::Test { protected: @@ -81,6 +87,18 @@ EXPECT_EQ(read_hits_->GetValue(), (refs / 4) * 3); } +TEST_F(CacheTest, DirectMappedWritesCold) { + // Create a cache 16kB, 16B blocks, direct mapped. + CHECK_OK(cache_->Configure("1k,16,1,true", &cycle_counter_)); + + for (uint64_t address = 0; address < 1024; address += 4) { + cache_->Store(address, db_); + } + uint64_t refs = 1024 / 4; + EXPECT_EQ(write_misses_->GetValue(), refs / 4); + EXPECT_EQ(write_hits_->GetValue(), (refs / 4) * 3); +} + TEST_F(CacheTest, DirectMappedReadsWarm) { // Create a cache 16kB, 16B blocks, direct mapped. CHECK_OK(cache_->Configure("1k,16,1,true", &cycle_counter_)); @@ -112,6 +130,37 @@ EXPECT_EQ(read_hits_->GetValue(), (refs / 4) * 3); } +TEST_F(CacheTest, DirectMappedWritesWarm) { + // Create a cache 16kB, 16B blocks, direct mapped. + CHECK_OK(cache_->Configure("1k,16,1,true", &cycle_counter_)); + + // Warm the cache. + for (uint64_t address = 0; address < 1024; address += 4) { + cache_->Store(address, db_); + } + // Clear the counters. + write_misses_->SetValue(0); + write_hits_->SetValue(0); + + // Access the cache again. Should be all hits. + for (uint64_t address = 0; address < 1024; address += 4) { + cache_->Store(address, db_); + } + uint64_t refs = 1024 / 4; + EXPECT_EQ(write_misses_->GetValue(), 0); + EXPECT_EQ(write_hits_->GetValue(), refs); + + // Clear the counters. + write_misses_->SetValue(0); + write_hits_->SetValue(0); + // Access the next 1k, should be like a cold cache. + for (uint64_t address = 1024; address < 2048; address += 4) { + cache_->Store(address, db_); + } + EXPECT_EQ(write_misses_->GetValue(), refs / 4); + EXPECT_EQ(write_hits_->GetValue(), (refs / 4) * 3); +} + TEST_F(CacheTest, TwoWayReads) { // Create a cache 16kB, 16B blocks, two way set associative. CHECK_OK(cache_->Configure("1k,16,2,true", &cycle_counter_)); @@ -134,4 +183,102 @@ EXPECT_EQ(read_hits_->GetValue(), 2 * 512 / 16); } +TEST_F(CacheTest, MemoryTest) { + FlatDemandMemory memory; + cache_->set_memory(&memory); + CHECK_OK(cache_->Configure("1k,16,1,true", &cycle_counter_)); + + DataBuffer *st_db1 = db_factory_.Allocate<uint8_t>(1); + DataBuffer *st_db2 = db_factory_.Allocate<uint16_t>(1); + DataBuffer *st_db4 = db_factory_.Allocate<uint32_t>(1); + DataBuffer *st_db8 = db_factory_.Allocate<uint64_t>(1); + + st_db1->Set<uint8_t>(0, 0x0F); + st_db2->Set<uint16_t>(0, 0xA5A5); + st_db4->Set<uint32_t>(0, 0xDEADBEEF); + st_db8->Set<uint64_t>(0, 0x0F0F0F0F'A5A5A5A5); + + cache_->Store(0x1000, st_db1); + cache_->Store(0x1002, st_db2); + cache_->Store(0x1004, st_db4); + cache_->Store(0x1008, st_db8); + + DataBuffer *ld_db1 = db_factory_.Allocate<uint8_t>(1); + DataBuffer *ld_db2 = db_factory_.Allocate<uint16_t>(1); + DataBuffer *ld_db4 = db_factory_.Allocate<uint32_t>(1); + DataBuffer *ld_db8 = db_factory_.Allocate<uint64_t>(1); + + cache_->Load(0x1000, ld_db1, nullptr, nullptr); + cache_->Load(0x1002, ld_db2, nullptr, nullptr); + cache_->Load(0x1004, ld_db4, nullptr, nullptr); + cache_->Load(0x1008, ld_db8, nullptr, nullptr); + + EXPECT_EQ(ld_db1->Get<uint8_t>(0), st_db1->Get<uint8_t>(0)); + EXPECT_EQ(ld_db2->Get<uint16_t>(0), st_db2->Get<uint16_t>(0)); + EXPECT_EQ(ld_db4->Get<uint32_t>(0), st_db4->Get<uint32_t>(0)); + EXPECT_EQ(ld_db8->Get<uint64_t>(0), st_db8->Get<uint64_t>(0)); + + ld_db1->DecRef(); + ld_db2->DecRef(); + ld_db4->DecRef(); + ld_db8->DecRef(); + + st_db1->DecRef(); + st_db2->DecRef(); + st_db4->DecRef(); + st_db8->DecRef(); +} + +TEST_F(CacheTest, TaggedMemoryTest) { + TaggedFlatDemandMemory memory(kTagGranule); + cache_->set_tagged_memory(&memory); + CHECK_OK(cache_->Configure("1k,16,1,true", &cycle_counter_)); + + DataBuffer *ld_data_db = db_factory_.Allocate<uint8_t>(kTagGranule * 16); + DataBuffer *ld_tag_db = db_factory_.Allocate<uint8_t>(16); + DataBuffer *st_data_db = db_factory_.Allocate<uint8_t>(kTagGranule * 16); + DataBuffer *st_tag_db = db_factory_.Allocate<uint8_t>(16); + cache_->Load(0x1000, ld_data_db, ld_tag_db, nullptr, nullptr); + // The loaded data should be all zeros. + for (int i = 0; i < 16; i++) { + for (int j = 0; j < kTagGranule; j++) { + EXPECT_EQ(ld_data_db->Get<uint8_t>(i * kTagGranule + j), 0); + } + EXPECT_EQ(ld_tag_db->Get<uint8_t>(i), 0); + } + // Write out known data. + for (int i = 0; i < st_data_db->size<uint8_t>(); i++) { + st_data_db->Set<uint8_t>(i, i); + } + for (int i = 0; i < st_tag_db->size<uint8_t>(); i++) { + st_tag_db->Set<uint8_t>(i, 1); + } + cache_->Store(0x1000, st_data_db, st_tag_db); + // Verify that the loaded data is equal to the stored data. + cache_->Load(0x1000, ld_data_db, ld_tag_db, nullptr, nullptr); + for (int i = 0; i < ld_data_db->size<uint8_t>(); i++) { + EXPECT_EQ(ld_data_db->Get<uint8_t>(i), i); + } + for (int i = 0; i < ld_tag_db->size<uint8_t>(); i++) { + EXPECT_EQ(ld_tag_db->Get<uint8_t>(i), 1); + } + // Clear every third tag and store them. + for (int i = 0; i < st_tag_db->size<uint8_t>(); i++) { + if (i % 3 == 0) st_tag_db->Set<uint8_t>(i, 0); + } + cache_->Store(0x1000, st_data_db, st_tag_db); + // Re-load and compare. + cache_->Load(0x1000, ld_data_db, ld_tag_db, nullptr, nullptr); + for (int i = 0; i < ld_data_db->size<uint8_t>(); i++) { + EXPECT_EQ(ld_data_db->Get<uint8_t>(i), i); + } + for (int i = 0; i < ld_tag_db->size<uint8_t>(); i++) { + EXPECT_EQ(ld_tag_db->Get<uint8_t>(i), i % 3 == 0 ? 0 : 1) << i; + } + ld_data_db->DecRef(); + ld_tag_db->DecRef(); + st_data_db->DecRef(); + st_tag_db->DecRef(); +} + } // namespace