blob: 34b7fe59bf308347195f6d3370dd8a520f60b86c [file]
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mpact/sim/util/memory/memory_watcher.h"
#include <cstdint>
#include <utility>
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "mpact/sim/util/memory/memory_interface.h"
namespace mpact {
namespace sim {
namespace util {
MemoryWatcher::MemoryWatcher(MemoryInterface* memory) : memory_(memory) {}
// Methods to insert store and load watch callbacks. The methods check that
// the address range is well formed (start <= end), and that there is no
// overlapping range in the btree map.
absl::Status MemoryWatcher::SetStoreWatchCallback(const AddressRange& range,
Callback callback) {
if (range.start > range.end) {
return absl::InternalError(absl::StrCat("Illegal store watch range: start ",
absl::Hex(range.start), " > end ",
absl::Hex(range.end)));
}
if (st_watch_actions_.contains(range)) {
return absl::InternalError(absl::StrCat(
"Store watch range [", absl::Hex(range.start), ", ",
absl::Hex(range.end), "] overlaps with an existing watch range"));
}
st_watch_actions_.insert({range, std::move(callback)});
return absl::OkStatus();
}
absl::Status MemoryWatcher::ClearStoreWatchCallback(uint64_t address) {
auto ptr = st_watch_actions_.find(AddressRange(address));
if (ptr == st_watch_actions_.end()) {
return absl::NotFoundError(absl::StrCat(
"ClearStoreWatchCallback: Error: No range found that contains: ",
absl::Hex(address)));
}
st_watch_actions_.erase(ptr);
return absl::OkStatus();
}
absl::Status MemoryWatcher::SetLoadWatchCallback(const AddressRange& range,
Callback callback) {
if (range.start > range.end) {
return absl::InternalError(absl::StrCat("Illegal store watch range: start ",
absl::Hex(range.start), " > end ",
absl::Hex(range.end)));
}
if (ld_watch_actions_.contains(range)) {
return absl::InternalError(absl::StrCat(
"Load watch range [", absl::Hex(range.start), ", ",
absl::Hex(range.end), "] overlaps with an existing watch range"));
}
ld_watch_actions_.insert({range, std::move(callback)});
return absl::OkStatus();
}
absl::Status MemoryWatcher::ClearLoadWatchCallback(uint64_t address) {
auto ptr = ld_watch_actions_.find(AddressRange(address));
if (ptr == ld_watch_actions_.end()) {
return absl::NotFoundError(absl::StrCat(
"ClearLoadWatchCallback: Error: No range found that contains: ",
absl::Hex(address)));
}
ld_watch_actions_.erase(ptr);
return absl::OkStatus();
}
// Each of the overridden methods for Loads and Stores checks if the address
// is in a range that is being watched. If it is, the load/store action callback
// is called before/after the load/store is forwarded to the interface.
// Single address.
void MemoryWatcher::Load(uint64_t address, DataBuffer* db, Instruction* inst,
ReferenceCount* context) {
if (!ld_watch_actions_.empty()) {
auto [first_match, last] = ld_watch_actions_.equal_range(
AddressRange(address, address + db->size<uint8_t>() - 1));
while (first_match != last) {
first_match->second(address, db->size<uint8_t>());
first_match++;
}
}
memory_->Load(address, db, inst, context);
}
// Gather load (multiple addresses and a mask vector).
void MemoryWatcher::Load(DataBuffer* address_db, DataBuffer* mask_db,
int el_size, DataBuffer* db, Instruction* inst,
ReferenceCount* context) {
if (!ld_watch_actions_.empty()) {
int num_entries = mask_db->size<bool>();
auto addresses = address_db->Get<uint64_t>();
auto masks = mask_db->Get<bool>();
for (int i = 0; i < num_entries; i++) {
// Need to check each unmasked address.
if (!masks[i]) continue;
uint64_t address = addresses[i];
auto [first_match, last] = ld_watch_actions_.equal_range(
AddressRange(address, address + el_size - 1));
while (first_match != last) {
first_match->second(address, el_size);
first_match++;
}
}
}
memory_->Load(address_db, mask_db, el_size, db, inst, context);
}
// Single address store.
void MemoryWatcher::Store(uint64_t address, DataBuffer* db) {
memory_->Store(address, db);
if (!st_watch_actions_.empty()) {
auto [first_match, last] = st_watch_actions_.equal_range(
AddressRange(address, address + db->size<uint8_t>() - 1));
while (first_match != last) {
first_match->second(address, db->size<uint8_t>());
first_match++;
}
}
}
// Scatter store (multiple addresses and a mask vector).
void MemoryWatcher::Store(DataBuffer* address_db, DataBuffer* mask_db,
int el_size, DataBuffer* db) {
memory_->Store(address_db, mask_db, el_size, db);
if (!st_watch_actions_.empty()) {
int num_entries = mask_db->size<bool>();
auto addresses = address_db->Get<uint64_t>();
auto masks = mask_db->Get<bool>();
for (int i = 0; i < num_entries; i++) {
// Need to check each unmasked address.
if (!masks[i]) continue;
uint64_t address = addresses[i];
auto [first_match, last] = st_watch_actions_.equal_range(
AddressRange(address, address + el_size - 1));
while (first_match != last) {
first_match->second(address, el_size);
first_match++;
}
}
}
}
} // namespace util
} // namespace sim
} // namespace mpact