blob: 00dac5ec38dea9254bea188f0fe045d20beb7047 [file]
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <signal.h>
#include <cstdint>
#include <fstream>
#include <ios>
#include <iostream>
#include <memory>
#include <optional>
#include <ostream>
#include <string>
#include <vector>
#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
#include "absl/log/log.h"
#include "absl/strings/str_cat.h"
#include "mpact/sim/generic/core_debug_interface.h"
#include "mpact/sim/generic/type_helpers.h"
#include "mpact/sim/util/memory/atomic_memory.h"
#include "mpact/sim/util/memory/flat_demand_memory.h"
#include "mpact/sim/util/memory/memory_watcher.h"
#include "mpact/sim/util/program_loader/elf_program_loader.h"
#include "riscv/riscv32_decoder.h"
#include "riscv/riscv_fp_state.h"
#include "riscv/riscv_register.h"
#include "riscv/riscv_register_aliases.h"
#include "riscv/riscv_state.h"
#include "riscv/riscv_top.h"
// This top level is customized to execute tests generated by
// https://github.com/riscv-software-src/riscv-tests and report the results.
using HaltReason = ::mpact::sim::generic::CoreDebugInterface::HaltReason;
using AddressRange = ::mpact::sim::util::MemoryWatcher::AddressRange;
using ::mpact::sim::generic::operator*; // NOLINT: clang-tidy false positive.
using ::mpact::sim::riscv::RiscV32Decoder;
using ::mpact::sim::riscv::RiscVFPState;
using ::mpact::sim::riscv::RiscVState;
using ::mpact::sim::riscv::RiscVTop;
using ::mpact::sim::riscv::RiscVXlen;
using ::mpact::sim::riscv::RV32Register;
using ::mpact::sim::riscv::RVFpRegister;
// Names of the ELF symbols that bracket the signature region. When the
// --dump_signature flag is set, the memory from the first symbol's address up
// to the second symbol's address is written out after the run.
static constexpr char kBeginSignature[]{"begin_signature"};
static constexpr char kEndSignature[]{"end_signature"};
// --dump_signature=<path>: optional output file for the signature region (the
// memory between the kBeginSignature and kEndSignature symbols). The file is
// written after the simulation halts. When the flag is unset (std::nullopt), no
// signature file is produced.
ABSL_FLAG(std::optional<std::string>, dump_signature, std::nullopt,
"Dump signature file name (riscv torture test)");
int main(int argc, char **argv) {
auto arg_vec = absl::ParseCommandLine(argc, argv);
if (arg_vec.size() > 2) {
std::cerr << "Only a single input file allowed" << std::endl;
return -1;
}
std::string full_file_name = arg_vec[1];
std::string file_name =
full_file_name.substr(full_file_name.find_last_of('/') + 1);
std::string file_basename = file_name.substr(0, file_name.find_first_of('.'));
mpact::sim::util::FlatDemandMemory memory;
auto *watcher = new mpact::sim::util::MemoryWatcher(&memory);
auto *atomic_memory = new mpact::sim::util::AtomicMemory(watcher);
// Load the elf segments into memory.
mpact::sim::util::ElfProgramLoader elf_loader(&memory);
auto load_result = elf_loader.LoadProgram(full_file_name);
if (!load_result.ok()) {
std::cerr << "Error while loading '" << full_file_name
<< "': " << load_result.status().message();
return -1;
}
// Set up architectural state and decoder.
RiscVState rv_state("RiscV32", RiscVXlen::RV32, watcher, atomic_memory);
// For floating point support add the fp state.
RiscVFPState rv_fp_state(rv_state.csr_set(), &rv_state);
rv_state.set_rv_fp(&rv_fp_state);
// Create the instruction decoder.
RiscV32Decoder rv_decoder(&rv_state, watcher);
// Make sure the architectural and abi register aliases are added.
std::string reg_name;
for (int i = 0; i < 32; i++) {
reg_name = absl::StrCat(RiscVState::kXregPrefix, i);
(void)rv_state.AddRegister<RV32Register>(reg_name);
(void)rv_state.AddRegisterAlias<RV32Register>(
reg_name, ::mpact::sim::riscv::kXRegisterAliases[i]);
}
for (int i = 0; i < 32; i++) {
reg_name = absl::StrCat(RiscVState::kFregPrefix, i);
(void)rv_state.AddRegister<RVFpRegister>(reg_name);
(void)rv_state.AddRegisterAlias<RVFpRegister>(
reg_name, ::mpact::sim::riscv::kFRegisterAliases[i]);
}
RiscVTop riscv_top("RiscV32TestSim", &rv_state, &rv_decoder);
// Initialize the PC to the entry point.
uint32_t entry_point = load_result.value();
auto pc_write = riscv_top.WriteRegister("pc", entry_point);
if (!pc_write.ok()) {
std::cerr << "Error writing to pc: " << pc_write.message();
return -1;
}
// The test result gets stored to label <tohost>, so set a watchpoint there.
auto result = elf_loader.GetSymbol("tohost");
if (!result.ok()) {
std::cerr << "Cannot find symbol 'tohost'";
return -1;
}
auto tohost = result.value().first;
auto status = watcher->SetStoreWatchCallback(
AddressRange(tohost), [&riscv_top](uint64_t, int) -> void {
riscv_top.RequestHalt(RiscVTop::HaltReason::kUserRequest, nullptr);
});
if (!status.ok()) {
std::cerr << "Cannot set watcher callback";
return -1;
}
// Run the executable.
auto run_status = riscv_top.Run();
if (!run_status.ok()) {
std::cerr << run_status.message() << std::endl;
return -1;
}
// Wait for halt.
auto wait_status = riscv_top.Wait();
if (!wait_status.ok()) {
std::cerr << wait_status.message() << std::endl;
return -1;
}
// Get halt reason.
auto halt_reason = riscv_top.GetLastHaltReason();
if (!halt_reason.ok()) {
std::cerr << "Failed to get halt reason: "
<< halt_reason.status().message();
return -1;
}
// Read PC, see where we halted.
auto pc_read = riscv_top.ReadRegister("pc");
if (!pc_read.ok()) {
std::cerr << "Failed to read pc: " << pc_read.status().message();
return -1;
}
int ret = -1;
if (halt_reason.value() == *HaltReason::kUserRequest) {
auto db = riscv_top.state()->db_factory()->Allocate<uint32_t>(1);
memory.Load(tohost, db, nullptr, nullptr);
auto value = db->Get<uint32_t>(0);
db->DecRef();
if (value == 1) {
std::cerr << "PASS (" << value << ")\n";
ret = 0;
} else {
std::cerr << "FAIL (" << value << ")\n";
ret = -1;
}
}
// Check to see if we need to dump the riscv torture signature section.
if (absl::GetFlag(FLAGS_dump_signature).has_value()) {
std::string file_name = absl::GetFlag(FLAGS_dump_signature).value();
std::fstream sig_file(file_name.c_str(), std::ios_base::out);
auto begin_res = elf_loader.GetSymbol(kBeginSignature);
auto end_res = elf_loader.GetSymbol(kEndSignature);
if (!begin_res.ok() || !end_res.ok()) {
std::cerr << "Unable to find signature symbols";
} else {
uint64_t begin_sig = begin_res.value().first;
uint64_t end_sig = end_res.value().first;
uint64_t length = end_sig - begin_sig;
uint64_t *buffer = new uint64_t[length >> 3];
auto status = riscv_top.ReadMemory(begin_sig, buffer, length);
for (int i = 0; i < length >> 3; ++i) {
sig_file << std::hex << buffer[i] << std::endl;
}
delete[] buffer;
sig_file.close();
}
}
delete atomic_memory;
delete watcher;
return ret;
}