blob: c956c8db3e198a403030226a201688146d6b7fef [file]
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <signal.h>
#include <cstdint>
#include <fstream>
#include <iomanip>
#include <ios>
#include <iostream>
#include <memory>
#include <optional>
#include <ostream>
#include <string>
#include <vector>
#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
#include "absl/log/log.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "mpact/sim/generic/core_debug_interface.h"
#include "mpact/sim/generic/type_helpers.h"
#include "mpact/sim/util/memory/atomic_memory.h"
#include "mpact/sim/util/memory/flat_demand_memory.h"
#include "mpact/sim/util/memory/memory_watcher.h"
#include "mpact/sim/util/program_loader/elf_program_loader.h"
#include "riscv/riscv64_decoder.h"
#include "riscv/riscv_register.h"
#include "riscv/riscv_state.h"
#include "riscv/riscv_test_mem_watcher.h"
#include "riscv/riscv_top.h"
// This top level is customized to execute tests generated by
// https://github.com/riscv-software-src/riscv-tests and report the results.
using HaltReason = ::mpact::sim::generic::CoreDebugInterface::HaltReason;
using HaltReasonValueType =
::mpact::sim::generic::CoreDebugInterface::HaltReasonValueType;
using AddressRange = ::mpact::sim::util::MemoryWatcher::AddressRange;
using ::mpact::sim::generic::operator*; // NOLINT: clang-tidy false positive.
using ::mpact::sim::riscv::RiscV64Decoder;
using ::mpact::sim::riscv::RiscVFPState;
using ::mpact::sim::riscv::RiscVState;
using ::mpact::sim::riscv::RiscVTop;
using ::mpact::sim::riscv::RiscVXlen;
using ::mpact::sim::riscv::RV64Register;
using ::mpact::sim::riscv::RVFpRegister;
namespace {
// Names of the ELF symbols that bracket the riscv torture signature region.
// The region between them is written out when --dump_signature is given.
constexpr char kBeginSignature[]{"begin_signature"};
constexpr char kEndSignature[]{"end_signature"};
}  // namespace
// --dump_signature: when set, main() writes the memory between the
// `begin_signature` and `end_signature` ELF symbols to this file as hex,
// 16 bytes per line (highest address first within each line).
ABSL_FLAG(std::optional<std::string>, dump_signature, std::nullopt,
"Dump signature file name (riscv torture test)");
// --log_commits: when true, main() prints one trace line to stderr after each
// step. The line shows privilege mode, pc, instruction word, the 64-bit
// registers written, and the test memory watcher's trace.
ABSL_FLAG(bool, log_commits, false, "Log commits similar to spike");
// --max_cycles: main() requests a halt once more than this many Step(1) calls
// have completed. Values <= 0 (the default is -1) mean no limit.
// NOTE(review): this counts simulator steps, which presumably correspond to
// instructions rather than clock cycles — confirm.
ABSL_FLAG(int64_t, max_cycles, -1, "Max cycles to simulate");
int main(int argc, char** argv) {
auto arg_vec = absl::ParseCommandLine(argc, argv);
if (arg_vec.size() > 2) {
std::cerr << "Only a single input file allowed" << std::endl;
return -1;
}
std::string full_file_name = arg_vec[1];
std::string file_name =
full_file_name.substr(full_file_name.find_last_of('/') + 1);
std::string file_basename = file_name.substr(0, file_name.find_first_of('.'));
mpact::sim::util::FlatDemandMemory memory;
auto* watcher = new mpact::sim::util::MemoryWatcher(&memory);
auto* test_watcher = new mpact::sim::riscv::RiscVTestMemWatcher(watcher);
auto* atomic_memory = new mpact::sim::util::AtomicMemory(test_watcher);
// Load the elf segments into memory.
mpact::sim::util::ElfProgramLoader elf_loader(&memory);
auto load_result = elf_loader.LoadProgram(full_file_name);
if (!load_result.ok()) {
std::cerr << "Error while loading '" << full_file_name
<< "': " << load_result.status().message();
return -1;
}
// Set up architectural state and decoder.
RiscVState rv_state("RiscV64", RiscVXlen::RV64, test_watcher, atomic_memory);
// For floating point support add the fp state.
RiscVFPState rv_fp_state(rv_state.csr_set(), &rv_state);
rv_state.set_rv_fp(&rv_fp_state);
// Create the instruction decoder.
RiscV64Decoder rv_decoder(&rv_state, watcher);
RiscVTop riscv_top("RiscV64Sim", &rv_state, &rv_decoder);
// Initialize the PC to the entry point.
uint64_t entry_point = load_result.value();
auto pc_write = riscv_top.WriteRegister("pc", entry_point);
if (!pc_write.ok()) {
std::cerr << "Error writing to pc: " << pc_write.message();
return -1;
}
// The test result gets stored to label <tohost>, so set a watchpoint there.
auto result = elf_loader.GetSymbol("tohost");
if (!result.ok()) {
std::cerr << "Cannot find symbol 'tohost'";
return -1;
}
auto tohost = result.value().first;
auto status = watcher->SetStoreWatchCallback(
AddressRange(tohost), [&riscv_top](uint64_t, int) -> void {
riscv_top.RequestHalt(RiscVTop::HaltReason::kUserRequest, nullptr);
});
if (!status.ok()) {
std::cerr << "Cannot set watcher callback";
return -1;
}
// Run the executable.
HaltReasonValueType halt_reason;
bool ok = true;
uint64_t pc = entry_point;
bool commit_trace = absl::GetFlag(FLAGS_log_commits);
auto* register_map = riscv_top.state()->registers();
mpact::sim::generic::DataBuffer* inst_db =
riscv_top.state()->db_factory()->Allocate<uint32_t>(1);
int64_t count = 0;
int64_t max_count = absl::GetFlag(FLAGS_max_cycles);
do {
ok = false;
auto status = riscv_top.Step(1);
if (!status.ok()) break;
count++;
if (max_count > 0 && count > max_count) {
riscv_top.RequestHalt(RiscVTop::HaltReason::kUserRequest, nullptr);
}
auto halt_status = riscv_top.GetLastHaltReason();
if (!halt_status.ok()) break;
halt_reason = halt_status.value();
ok = true;
if (commit_trace) {
// This IncRef's the inst instance. Need to DecRef it when we are done.
auto* inst = riscv_top.GetInstruction(pc).value();
std::string trace_str;
absl::StrAppend(
&trace_str, "core 0: ", *(riscv_top.state()->privilege_mode()),
" 0x", absl::Hex(inst->address(), absl::PadSpec::kZeroPad16));
memory.Load(inst->address(), inst_db, nullptr, nullptr);
absl::StrAppend(
&trace_str, " (0x",
absl::Hex(inst_db->Get<uint32_t>(0), absl::PadSpec::kZeroPad8), ")");
for (int i = 0; i < inst->DestinationsSize(); ++i) {
auto* dest = inst->Destination(i);
if (dest == nullptr) continue;
auto name = dest->AsString();
if (name == "pc") continue;
if (name == "x0") continue;
auto iter = register_map->find(name);
if (iter == register_map->end()) {
continue;
} else {
auto* db = iter->second->data_buffer();
auto size = db->size<uint8_t>();
if (size != sizeof(uint64_t)) {
continue;
}
absl::StrAppend(
&trace_str, " ", absl::StrFormat("%-3s", name), " 0x",
absl::Hex(db->Get<uint64_t>(0), absl::PadSpec::kZeroPad16));
}
}
auto* child = inst->child();
while (child != nullptr) {
for (int i = 0; i < child->DestinationsSize(); ++i) {
auto* dest = child->Destination(i);
if (dest == nullptr) continue;
auto name = dest->AsString();
if (name == "pc") continue;
if (name == "x0") continue;
auto iter = register_map->find(name);
if (iter == register_map->end()) {
continue;
} else {
auto* db = iter->second->data_buffer();
auto size = db->size<uint8_t>();
if (size != sizeof(uint64_t)) {
absl::StrAppend(&trace_str, " size issue: ", name);
continue;
}
absl::StrAppend(
&trace_str, " ", absl::StrFormat("%-3s", name), " 0x",
absl::Hex(db->Get<uint64_t>(0), absl::PadSpec::kZeroPad16));
}
}
child = child->next();
}
trace_str.append(test_watcher->trace_str());
test_watcher->clear_trace_str();
std::cerr << trace_str << std::endl;
inst->DecRef();
}
pc = riscv_top.ReadRegister("pc").value();
} while (halt_reason == *HaltReason::kNone);
if (!ok) {
std::cerr << "Failure in stepping or obtaining halt reason";
return -1;
}
// Read PC, see where we halted.
auto pc_read = riscv_top.ReadRegister("pc");
if (!pc_read.ok()) {
std::cerr << "Failed to read pc: " << pc_read.status().message();
return -1;
}
int ret = -1;
if (halt_reason == *HaltReason::kUserRequest) {
auto db = riscv_top.state()->db_factory()->Allocate<uint32_t>(1);
memory.Load(tohost, db, nullptr, nullptr);
auto value = db->Get<uint32_t>(0);
db->DecRef();
if (value == 1) {
std::cerr << "PASS (" << value << ")\n";
ret = 0;
} else {
std::cerr << "FAIL (" << value << ")\n";
ret = -1;
}
}
// Check to see if we need to dump the riscv torture signature section.
if (absl::GetFlag(FLAGS_dump_signature).has_value()) {
std::string file_name = absl::GetFlag(FLAGS_dump_signature).value();
std::fstream sig_file(file_name.c_str(), std::ios_base::out);
auto begin_res = elf_loader.GetSymbol(kBeginSignature);
auto end_res = elf_loader.GetSymbol(kEndSignature);
if (!begin_res.ok() || !end_res.ok()) {
std::cerr << "Unable to find signature symbols";
} else {
uint64_t begin_sig = begin_res.value().first;
uint64_t end_sig = end_res.value().first;
uint64_t length = end_sig - begin_sig;
uint8_t* buffer = new uint8_t[length];
auto status = riscv_top.ReadMemory(begin_sig, buffer, length);
sig_file << std::setfill('0') << std::hex;
for (int i = 0; i < length; i += 16) {
for (int j = 16; j > 0; j--) {
if (i + j <= length) {
sig_file << std::setw(2)
<< static_cast<uint16_t>(buffer[i + j - 1]);
} else {
sig_file << std::setw(2) << 0;
}
}
sig_file << std::endl;
}
delete[] buffer;
sig_file.close();
}
}
delete atomic_memory;
delete watcher;
return ret;
}