Add vector and floating-point instruction support for CHERIoT (experimental).

PiperOrigin-RevId: 660400409
Change-Id: I0bf291905b7242e45cfcc3d0522b2c51b5dd00e9
diff --git a/cheriot/BUILD b/cheriot/BUILD index 4bb61f5..fceadee 100644 --- a/cheriot/BUILD +++ b/cheriot/BUILD
@@ -24,15 +24,69 @@ exports_files([ "riscv_cheriot.bin_fmt", "riscv_cheriot.isa", + "riscv_cheriot_f.bin_fmt", + "riscv_cheriot_f.isa", + "riscv_cheriot_rvv.isa", + "riscv_cheriot_vector.bin_fmt", + "riscv_cheriot_vector_fp.bin_fmt", + "riscv_cheriot_vector.isa", ]) +config_setting( + name = "arm_cpu", + values = {"cpu": "arm"}, +) + +config_setting( + name = "darwin_arm64_cpu", + values = {"cpu": "darwin_arm64"}, +) + mpact_isa_decoder( name = "riscv_cheriot_isa", src = "riscv_cheriot.isa", includes = [], isa_name = "RiscVCheriot", deps = [ - ":riscv_cheriot", + ":riscv_cheriot_instructions", + "@com_google_absl//absl/functional:bind_front", + ], +) + +mpact_isa_decoder( + name = "riscv_cheriot_rvv_isa", + src = "riscv_cheriot_rvv.isa", + includes = [ + "riscv_cheriot.isa", + "riscv_cheriot_f.isa", + "riscv_cheriot_vector.isa", + "riscv_cheriot_vector_fp.isa", + ], + isa_name = "RiscVCheriotRVV", + prefix = "riscv_cheriot_rvv", + deps = [ + ":riscv_cheriot_instructions", + ":riscv_cheriot_vector", + "@com_google_absl//absl/functional:bind_front", + ], +) + +mpact_isa_decoder( + name = "riscv_cheriot_rvv_fp_isa", + src = "riscv_cheriot_rvv.isa", + includes = [ + "riscv_cheriot.isa", + "riscv_cheriot_f.isa", + "riscv_cheriot_vector.isa", + "riscv_cheriot_vector_fp.isa", + ], + isa_name = "RiscVCheriotRVVFp", + prefix = "riscv_cheriot_rvv_fp", + deps = [ + ":riscv_cheriot_f", + ":riscv_cheriot_instructions", + ":riscv_cheriot_vector", + ":riscv_cheriot_vector_fp", "@com_google_absl//absl/functional:bind_front", ], ) @@ -48,34 +102,91 @@ ], ) +mpact_bin_fmt_decoder( + name = "riscv_cheriot_rvv_bin_fmt", + src = "riscv_cheriot_rvv.bin_fmt", + decoder_name = "RiscVCheriotRVV", + includes = [ + "riscv_cheriot.bin_fmt", + "riscv_cheriot_f.bin_fmt", + "riscv_cheriot_vector.bin_fmt", + "riscv_cheriot_vector_fp.bin_fmt", + ], + deps = [ + ":riscv_cheriot_rvv_isa", + ], +) + +mpact_bin_fmt_decoder( + name = "riscv_cheriot_rvv_fp_bin_fmt", + src = "riscv_cheriot_rvv.bin_fmt", + decoder_name = "RiscVCheriotRVVFp", + includes = [ + "riscv_cheriot.bin_fmt", + "riscv_cheriot_f.bin_fmt", + "riscv_cheriot_vector.bin_fmt", + "riscv_cheriot_vector_fp.bin_fmt", + ], + prefix = "riscv_cheriot_rvv_fp", + deps = [ + ":riscv_cheriot_rvv_fp_isa", + ], +) + cc_library( - name = "riscv_cheriot", + name = "riscv_cheriot_instructions", srcs = [ - "cheriot_register.cc", - "cheriot_state.cc", "riscv_cheriot_a_instructions.cc", "riscv_cheriot_i_instructions.cc", "riscv_cheriot_instructions.cc", "riscv_cheriot_m_instructions.cc", - "riscv_cheriot_minstret.cc", "riscv_cheriot_priv_instructions.cc", "riscv_cheriot_zicsr_instructions.cc", ], hdrs = [ - "cheriot_register.h", - "cheriot_state.h", "riscv_cheriot_a_instructions.h", - "riscv_cheriot_csr_enum.h", "riscv_cheriot_i_instructions.h", - "riscv_cheriot_instruction_helpers.h", "riscv_cheriot_instructions.h", "riscv_cheriot_m_instructions.h", - "riscv_cheriot_minstret.h", "riscv_cheriot_priv_instructions.h", "riscv_cheriot_zicsr_instructions.h", ], tags = ["not_run:arm"], deps = [ + ":cheriot_state", + ":instruction_helpers", + "@com_google_absl//absl/log", + "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_mpact-riscv//riscv:riscv_state", + "@com_google_mpact-riscv//riscv:stoull_wrapper", + "@com_google_mpact-sim//mpact/sim/generic:arch_state", + "@com_google_mpact-sim//mpact/sim/generic:core", + "@com_google_mpact-sim//mpact/sim/generic:instruction", + 
"@com_google_mpact-sim//mpact/sim/generic:type_helpers", + "@com_google_mpact-sim//mpact/sim/util/memory", + ], +) + +cc_library( + name = "cheriot_state", + srcs = [ + "cheriot_register.cc", + "cheriot_state.cc", + "cheriot_vector_true_operand.cc", + "riscv_cheriot_minstret.cc", + ], + hdrs = [ + "cheriot_register.h", + "cheriot_state.h", + "cheriot_vector_true_operand.h", + "riscv_cheriot_csr_enum.h", + "riscv_cheriot_minstret.h", + "riscv_cheriot_register_aliases.h", + ], + tags = ["not_run:arm"], + deps = [ "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/functional:any_invocable", @@ -86,8 +197,8 @@ "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_mpact-riscv//riscv:riscv_fp_state", "@com_google_mpact-riscv//riscv:riscv_state", - "@com_google_mpact-riscv//riscv:stoull_wrapper", "@com_google_mpact-sim//mpact/sim/generic:arch_state", "@com_google_mpact-sim//mpact/sim/generic:core", "@com_google_mpact-sim//mpact/sim/generic:counters", @@ -98,6 +209,183 @@ ) cc_library( + name = "riscv_cheriot_f", + srcs = [ + "riscv_cheriot_f_instructions.cc", + ], + hdrs = [ + "riscv_cheriot_f_instructions.h", + ], + tags = ["not_run:arm"], + deps = [ + ":cheriot_state", + ":instruction_helpers", + "@com_google_absl//absl/functional:bind_front", + "@com_google_mpact-riscv//riscv:riscv_fp_state", + "@com_google_mpact-riscv//riscv:riscv_state", + "@com_google_mpact-sim//mpact/sim/generic:arch_state", + "@com_google_mpact-sim//mpact/sim/generic:instruction", + "@com_google_mpact-sim//mpact/sim/generic:type_helpers", + ], +) + +cc_library( + name = "cheriot_vector_state", + srcs = [ + "cheriot_vector_state.cc", + ], + hdrs = [ + "cheriot_vector_state.h", + ], + tags = ["not_run:arm"], + deps = [ + ":cheriot_state", + "@com_google_absl//absl/log", + "@com_google_mpact-riscv//riscv:riscv_state", + ], +) + +cc_library( + name = "riscv_cheriot_vector", + srcs = [ + "riscv_cheriot_vector_memory_instructions.cc", + "riscv_cheriot_vector_opi_instructions.cc", + "riscv_cheriot_vector_opm_instructions.cc", + "riscv_cheriot_vector_permute_instructions.cc", + "riscv_cheriot_vector_reduction_instructions.cc", + "riscv_cheriot_vector_unary_instructions.cc", + ], + hdrs = [ + "riscv_cheriot_vector_memory_instructions.h", + "riscv_cheriot_vector_opi_instructions.h", + "riscv_cheriot_vector_opm_instructions.h", + "riscv_cheriot_vector_permute_instructions.h", + "riscv_cheriot_vector_reduction_instructions.h", + "riscv_cheriot_vector_unary_instructions.h", + ], + copts = [ + "-O3", + "-ffp-model=strict", + ] + select({ + "darwin_arm64_cpu": [], + "//conditions:default": ["-fprotect-parens"], + }), + deps = [ + ":cheriot_state", + ":cheriot_vector_state", + ":instruction_helpers", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@com_google_mpact-riscv//riscv:riscv_state", + "@com_google_mpact-sim//mpact/sim/generic:arch_state", + "@com_google_mpact-sim//mpact/sim/generic:core", + "@com_google_mpact-sim//mpact/sim/generic:instruction", + "@com_google_mpact-sim//mpact/sim/generic:type_helpers", + ], +) + +cc_library( + name = "riscv_cheriot_vector_fp", + srcs = [ + "riscv_cheriot_vector_fp_compare_instructions.cc", + "riscv_cheriot_vector_fp_instructions.cc", + "riscv_cheriot_vector_fp_reduction_instructions.cc", + 
"riscv_cheriot_vector_fp_unary_instructions.cc", + ], + hdrs = [ + "riscv_cheriot_vector_fp_compare_instructions.h", + "riscv_cheriot_vector_fp_instructions.h", + "riscv_cheriot_vector_fp_reduction_instructions.h", + "riscv_cheriot_vector_fp_unary_instructions.h", + ], + copts = [ + "-O3", + "-ffp-model=strict", + ] + select({ + "darwin_arm64_cpu": [], + "//conditions:default": ["-fprotect-parens"], + }), + deps = [ + ":cheriot_state", + ":cheriot_vector_state", + ":instruction_helpers", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_mpact-riscv//riscv:riscv_fp_state", + "@com_google_mpact-riscv//riscv:riscv_state", + "@com_google_mpact-sim//mpact/sim/generic:arch_state", + "@com_google_mpact-sim//mpact/sim/generic:core", + "@com_google_mpact-sim//mpact/sim/generic:instruction", + "@com_google_mpact-sim//mpact/sim/generic:type_helpers", + ], +) + +cc_library( + name = "instruction_helpers", + hdrs = [ + "riscv_cheriot_instruction_helpers.h", + "riscv_cheriot_vector_instruction_helpers.h", + ], + deps = [ + ":cheriot_state", + ":cheriot_vector_state", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + "@com_google_mpact-riscv//riscv:riscv_fp_state", + "@com_google_mpact-riscv//riscv:riscv_state", + "@com_google_mpact-sim//mpact/sim/generic:arch_state", + "@com_google_mpact-sim//mpact/sim/generic:core", + "@com_google_mpact-sim//mpact/sim/generic:instruction", + "@com_google_mpact-sim//mpact/sim/generic:type_helpers", + ], +) + +cc_library( + name = "cheriot_getter_helpers", + hdrs = [ + "cheriot_getter_helpers.h", + ], + deps = [ + ":cheriot_state", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_mpact-riscv//riscv:riscv_state", + "@com_google_mpact-sim//mpact/sim/generic:arch_state", + "@com_google_mpact-sim//mpact/sim/generic:core", + "@com_google_mpact-sim//mpact/sim/generic:instruction", + ], +) + +cc_library( + name = "cheriot_getters", + hdrs = [ + "cheriot_f_getters.h", + "cheriot_getters.h", + "cheriot_rvv_fp_getters.h", + "cheriot_rvv_getters.h", + "riscv_cheriot_encoding_common.h", + ], + deps = [ + ":cheriot_getter_helpers", + ":cheriot_state", + ":riscv_cheriot_bin_fmt", + ":riscv_cheriot_isa", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/strings", + "@com_google_mpact-riscv//riscv:riscv_state", + "@com_google_mpact-sim//mpact/sim/generic:core", + ], +) + +cc_library( name = "riscv_cheriot_decoder", srcs = [ "cheriot_decoder.cc", @@ -106,10 +394,10 @@ hdrs = [ "cheriot_decoder.h", "riscv_cheriot_encoding.h", - "riscv_cheriot_register_aliases.h", ], deps = [ - ":riscv_cheriot", + ":cheriot_getters", + ":cheriot_state", ":riscv_cheriot_bin_fmt", ":riscv_cheriot_isa", "@com_google_absl//absl/container:flat_hash_map", @@ -127,6 +415,66 @@ ) cc_library( + name = "riscv_cheriot_rvv_decoder", + srcs = [ + "cheriot_rvv_decoder.cc", + "riscv_cheriot_rvv_encoding.cc", + ], + hdrs = [ + "cheriot_rvv_decoder.h", + "riscv_cheriot_register_aliases.h", + "riscv_cheriot_rvv_encoding.h", + ], + deps = [ + ":cheriot_getters", + ":cheriot_state", + ":riscv_cheriot_rvv_bin_fmt", + ":riscv_cheriot_rvv_isa", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log", + 
"@com_google_absl//absl/strings", + "@com_google_mpact-riscv//riscv:riscv_state", + "@com_google_mpact-sim//mpact/sim/generic:arch_state", + "@com_google_mpact-sim//mpact/sim/generic:core", + "@com_google_mpact-sim//mpact/sim/generic:instruction", + "@com_google_mpact-sim//mpact/sim/generic:program_error", + "@com_google_mpact-sim//mpact/sim/generic:type_helpers", + "@com_google_mpact-sim//mpact/sim/util/memory", + ], +) + +cc_library( + name = "riscv_cheriot_rvv_fp_decoder", + srcs = [ + "cheriot_rvv_fp_decoder.cc", + "riscv_cheriot_rvv_fp_encoding.cc", + ], + hdrs = [ + "cheriot_rvv_fp_decoder.h", + "riscv_cheriot_register_aliases.h", + "riscv_cheriot_rvv_fp_encoding.h", + ], + deps = [ + ":cheriot_getters", + ":cheriot_state", + ":riscv_cheriot_rvv_fp_bin_fmt", + ":riscv_cheriot_rvv_fp_isa", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", + "@com_google_mpact-riscv//riscv:riscv_state", + "@com_google_mpact-sim//mpact/sim/generic:arch_state", + "@com_google_mpact-sim//mpact/sim/generic:core", + "@com_google_mpact-sim//mpact/sim/generic:instruction", + "@com_google_mpact-sim//mpact/sim/generic:program_error", + "@com_google_mpact-sim//mpact/sim/generic:type_helpers", + "@com_google_mpact-sim//mpact/sim/util/memory", + ], +) + +cc_library( name = "cheriot_top", srcs = [ "cheriot_top.cc", @@ -137,8 +485,7 @@ copts = ["-O3"], deps = [ ":cheriot_debug_interface", - ":riscv_cheriot", - ":riscv_cheriot_decoder", + ":cheriot_state", ":riscv_cheriot_isa", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/functional:any_invocable", @@ -189,8 +536,8 @@ copts = ["-O3"], deps = [ ":cheriot_debug_interface", + ":cheriot_state", ":cheriot_top", - ":riscv_cheriot", ":riscv_cheriot_decoder", ":riscv_cheriot_isa", "@com_google_absl//absl/container:btree", @@ -217,11 +564,14 @@ ], copts = ["-O3"], deps = [ + ":cheriot_state", ":cheriot_top", ":debug_command_shell", ":instrumentation", - ":riscv_cheriot", ":riscv_cheriot_decoder", + ":riscv_cheriot_instructions", + ":riscv_cheriot_rvv_decoder", + ":riscv_cheriot_rvv_fp_decoder", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/flags:parse", "@com_google_absl//absl/flags:usage", @@ -236,6 +586,7 @@ "@com_google_mpact-riscv//riscv:riscv_arm_semihost", "@com_google_mpact-riscv//riscv:riscv_clint", "@com_google_mpact-riscv//riscv:stoull_wrapper", + "@com_google_mpact-sim//mpact/sim/generic:core", "@com_google_mpact-sim//mpact/sim/generic:core_debug_interface", "@com_google_mpact-sim//mpact/sim/generic:counters", "@com_google_mpact-sim//mpact/sim/generic:instruction", @@ -258,7 +609,7 @@ "cheriot_load_filter.h", ], deps = [ - ":riscv_cheriot", + ":cheriot_state", "@com_google_mpact-sim//mpact/sim/generic:core", "@com_google_mpact-sim//mpact/sim/generic:counters", "@com_google_mpact-sim//mpact/sim/util/memory", @@ -335,10 +686,10 @@ deps = [ ":cheriot_debug_info", ":cheriot_debug_interface", + ":cheriot_state", ":cheriot_top", ":debug_command_shell", ":instrumentation", - ":riscv_cheriot", ":riscv_cheriot_decoder", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/functional:bind_front", @@ -377,7 +728,7 @@ "test_rig_packets.h", ], deps = [ - ":riscv_cheriot", + ":cheriot_state", ":riscv_cheriot_bin_fmt", ":riscv_cheriot_decoder", ":riscv_cheriot_isa",
diff --git a/cheriot/cheriot_decoder.cc b/cheriot/cheriot_decoder.cc index 63e6103..805d5fd 100644 --- a/cheriot/cheriot_decoder.cc +++ b/cheriot/cheriot_decoder.cc
@@ -22,7 +22,6 @@ #include "cheriot/riscv_cheriot_encoding.h" #include "cheriot/riscv_cheriot_enums.h" #include "mpact/sim/generic/instruction.h" -#include "mpact/sim/generic/program_error.h" #include "mpact/sim/generic/type_helpers.h" #include "mpact/sim/util/memory/memory_interface.h" #include "riscv//riscv_state.h" @@ -38,10 +37,6 @@ CheriotDecoder::CheriotDecoder(CheriotState *state, util::MemoryInterface *memory) : state_(state), memory_(memory) { - // Get a handle to the internal error in the program error controller. - decode_error_ = state->program_error_controller()->GetProgramError( - generic::ProgramErrorController::kInternalErrorName); - // Need a data buffer to load instructions from memory. Allocate a single // buffer that can be reused for each instruction word. inst_db_ = db_factory_.Allocate<uint32_t>(1);
diff --git a/cheriot/cheriot_decoder.h b/cheriot/cheriot_decoder.h index 7912734..f219457 100644 --- a/cheriot/cheriot_decoder.h +++ b/cheriot/cheriot_decoder.h
@@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MPACT_CHERIOT__CHERIOT_DECODER_H_ -#define MPACT_CHERIOT__CHERIOT_DECODER_H_ +#ifndef MPACT_CHERIOT_CHERIOT_DECODER_H_ +#define MPACT_CHERIOT_CHERIOT_DECODER_H_ #include <cstdint> #include <memory> @@ -28,7 +28,6 @@ #include "mpact/sim/generic/data_buffer.h" #include "mpact/sim/generic/decoder_interface.h" #include "mpact/sim/generic/instruction.h" -#include "mpact/sim/generic/program_error.h" #include "mpact/sim/util/memory/memory_interface.h" namespace mpact { @@ -81,7 +80,6 @@ private: CheriotState *state_; util::MemoryInterface *memory_; - std::unique_ptr<generic::ProgramError> decode_error_; generic::DataBufferFactory db_factory_; generic::DataBuffer *inst_db_; isa32::RiscVCheriotEncoding *cheriot_encoding_; @@ -93,4 +91,4 @@ } // namespace sim } // namespace mpact -#endif // MPACT_CHERIOT__CHERIOT_DECODER_H_ +#endif // MPACT_CHERIOT_CHERIOT_DECODER_H_
diff --git a/cheriot/cheriot_f_getters.h b/cheriot/cheriot_f_getters.h new file mode 100644 index 0000000..c4565ba --- /dev/null +++ b/cheriot/cheriot_f_getters.h
@@ -0,0 +1,116 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MPACT_CHERIOT_CHERIOT_F_GETTERS_H_ +#define MPACT_CHERIOT_CHERIOT_F_GETTERS_H_ + +#include <cstdint> +#include <string> + +#include "absl/container/flat_hash_map.h" +#include "absl/functional/any_invocable.h" +#include "absl/strings/str_cat.h" +#include "cheriot/cheriot_getter_helpers.h" +#include "cheriot/riscv_cheriot_encoding_common.h" +#include "mpact/sim/generic/immediate_operand.h" +#include "mpact/sim/generic/literal_operand.h" +#include "mpact/sim/generic/operand_interface.h" +#include "riscv//riscv_register.h" +#include "riscv//riscv_state.h" + +namespace mpact { +namespace sim { +namespace cheriot { + +using ::mpact::sim::cheriot::RiscVCheriotEncodingCommon; +using ::mpact::sim::generic::DestinationOperandInterface; +using ::mpact::sim::generic::ImmediateOperand; +using ::mpact::sim::generic::IntLiteralOperand; +using ::mpact::sim::generic::SourceOperandInterface; +using ::mpact::sim::riscv::RiscVState; +using ::mpact::sim::riscv::RVFpRegister; + +using SourceOpGetterMap = + absl::flat_hash_map<int, absl::AnyInvocable<SourceOperandInterface *()>>; +using DestOpGetterMap = + absl::flat_hash_map<int, + absl::AnyInvocable<DestinationOperandInterface *(int)>>; + +template <typename Enum, typename Extractors> +void AddCheriotFSourceGetters(SourceOpGetterMap &getter_map, + RiscVCheriotEncodingCommon *common) { + Insert(getter_map, *Enum::kFrs1, [common]() { + int num = Extractors::RType::ExtractRs1(common->inst_word()); + return GetRegisterSourceOp<RVFpRegister>( + common->state(), absl::StrCat(RiscVState::kFregPrefix, num)); + }); + Insert(getter_map, *Enum::kFrs2, [common]() { + int num = Extractors::RType::ExtractRs2(common->inst_word()); + return GetRegisterSourceOp<RVFpRegister>( + common->state(), absl::StrCat(RiscVState::kFregPrefix, num)); + }); + Insert(getter_map, *Enum::kFrs3, [common]() { + int num = Extractors::R4Type::ExtractRs3(common->inst_word()); + return GetRegisterSourceOp<RVFpRegister>( + common->state(), absl::StrCat(RiscVState::kFregPrefix, num)); + }); + Insert(getter_map, *Enum::kFs1, [common]() { + int num = Extractors::RType::ExtractRs1(common->inst_word()); + return GetRegisterSourceOp<RVFpRegister>( + common->state(), absl::StrCat(RiscVState::kFregPrefix, num)); + }); + Insert(getter_map, *Enum::kRm, [common]() -> SourceOperandInterface * { + uint32_t rm = (common->inst_word() >> 12) & 0x7; + switch (rm) { + case 0: + return new generic::IntLiteralOperand<0>(); + case 1: + return new generic::IntLiteralOperand<1>(); + case 2: + return new generic::IntLiteralOperand<2>(); + case 3: + return new generic::IntLiteralOperand<3>(); + case 4: + return new generic::IntLiteralOperand<4>(); + case 5: + return new generic::IntLiteralOperand<5>(); + case 6: + return new generic::IntLiteralOperand<6>(); + case 7: + return new generic::IntLiteralOperand<7>(); + default: + return nullptr; + } + }); +} + +template <typename Enum, typename Extractors> +void 
AddCheriotFDestGetters(DestOpGetterMap &getter_map, + RiscVCheriotEncodingCommon *common) { + Insert(getter_map, *Enum::kFrd, [common](int latency) { + int num = Extractors::RType::ExtractRd(common->inst_word()); + return GetRegisterDestinationOp<RVFpRegister>( + common->state(), absl::StrCat(RiscVState::kFregPrefix, num), latency); + }); + Insert(getter_map, *Enum::kFflags, [common](int latency) { + return GetCSRSetBitsDestinationOp<uint32_t>(common->state(), "fflags", + latency, ""); + }); +} + +} // namespace cheriot +} // namespace sim +} // namespace mpact + +#endif // MPACT_CHERIOT_CHERIOT_F_GETTERS_H_
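The *Enum::kFrs1 spelling in these getters deserves a note: it relies on an operator* overload for scoped enums that yields the underlying integer, which mpact supplies in mpact/sim/generic/type_helpers.h. A self-contained sketch of the idiom (the enum here is hypothetical, and the overload is only roughly the shape of the real one):

#include <cstdio>
#include <type_traits>

// Roughly the shape of the operator* that type_helpers.h provides.
template <typename E, typename = std::enable_if_t<std::is_enum_v<E>>>
constexpr std::underlying_type_t<E> operator*(E e) {
  return static_cast<std::underlying_type_t<E>>(e);
}

enum class SourceOpEnum : int { kNone = 0, kFrs1 = 1, kFrs2 = 2 };

int main() {
  // The getter maps are keyed by int, so *Enum::kFrs2 gives the map key.
  std::printf("kFrs2 -> %d\n", *SourceOpEnum::kFrs2);
  return 0;
}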
diff --git a/cheriot/cheriot_getter_helpers.h b/cheriot/cheriot_getter_helpers.h new file mode 100644 index 0000000..d6546cb --- /dev/null +++ b/cheriot/cheriot_getter_helpers.h
@@ -0,0 +1,160 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MPACT_CHERIOT_CHERIOT_GETTER_HELPERS_H_ +#define MPACT_CHERIOT_CHERIOT_GETTER_HELPERS_H_ + +#include <string> +#include <vector> + +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "cheriot/cheriot_register.h" +#include "cheriot/cheriot_state.h" +#include "mpact/sim/generic/operand_interface.h" +#include "mpact/sim/generic/register.h" +#include "riscv//riscv_register.h" + +namespace mpact { +namespace sim { +namespace cheriot { + +using ::mpact::sim::generic::DestinationOperandInterface; +using ::mpact::sim::generic::SourceOperandInterface; +using ::mpact::sim::riscv::RV32VectorDestinationOperand; +using ::mpact::sim::riscv::RV32VectorSourceOperand; + +constexpr int kNumRegTable[8] = {8, 1, 2, 1, 4, 1, 2, 1}; + +template <typename M, typename E, typename G> +inline void Insert(M &map, E entry, G getter) { + map.insert(std::make_pair(static_cast<int>(entry), getter)); +} + +// Generic helper functions to create register operands. +template <typename RegType> +inline DestinationOperandInterface *GetRegisterDestinationOp( + CheriotState *state, std::string name, int latency) { + auto *reg = state->GetRegister<RegType>(name).first; + return reg->CreateDestinationOperand(latency); +} + +template <typename RegType> +inline DestinationOperandInterface *GetRegisterDestinationOp( + CheriotState *state, std::string name, int latency, std::string op_name) { + auto *reg = state->GetRegister<RegType>(name).first; + return reg->CreateDestinationOperand(latency, op_name); +} + +template <typename T> +inline DestinationOperandInterface *GetCSRSetBitsDestinationOp( + CheriotState *state, std::string name, int latency, std::string op_name) { + auto result = state->csr_set()->GetCsr(name); + if (!result.ok()) { + LOG(ERROR) << "No such CSR '" << name << "'"; + return nullptr; + } + auto *csr = result.value(); + auto *op = csr->CreateSetDestinationOperand(latency, op_name); + return op; +} + +template <typename RegType> +inline SourceOperandInterface *GetRegisterSourceOp(CheriotState *state, + std::string name) { + auto *reg = state->GetRegister<RegType>(name).first; + auto *op = reg->CreateSourceOperand(); + return op; +} + +template <typename RegType> +inline SourceOperandInterface *GetRegisterSourceOp(CheriotState *state, + std::string name, + std::string op_name) { + auto *reg = state->GetRegister<RegType>(name).first; + auto *op = reg->CreateSourceOperand(op_name); + return op; +} + +template <typename RegType> +inline void GetVRegGroup(CheriotState *state, int reg_num, + std::vector<generic::RegisterBase *> *vreg_group) { + // The number of registers in a vector register group depends on the register + // index: 0, 8, 16, 24 each have 8 registers, 4, 12, 20, 28 each have 4, + // 2, 6, 10, 14, 18, 22, 26, 30 each have two, and all odd numbered register + // groups have only 1. 
+ int num_regs = kNumRegTable[reg_num % 8]; + for (int i = 0; i < num_regs; i++) { + auto vreg_name = absl::StrCat(CheriotState::kVregPrefix, reg_num + i); + vreg_group->push_back(state->GetRegister<RegType>(vreg_name).first); + } +} +template <typename RegType> +inline SourceOperandInterface *GetVectorRegisterSourceOp(CheriotState *state, + int reg_num) { + std::vector<generic::RegisterBase *> vreg_group; + GetVRegGroup<RegType>(state, reg_num, &vreg_group); + auto *v_src_op = new RV32VectorSourceOperand( + absl::Span<generic::RegisterBase *>(vreg_group), + absl::StrCat(CheriotState::kVregPrefix, reg_num)); + return v_src_op; +} + +template <typename RegType> +inline DestinationOperandInterface *GetVectorRegisterDestinationOp( + CheriotState *state, int latency, int reg_num) { + std::vector<generic::RegisterBase *> vreg_group; + GetVRegGroup<RegType>(state, reg_num, &vreg_group); + auto *v_dst_op = new RV32VectorDestinationOperand( + absl::Span<generic::RegisterBase *>(vreg_group), latency, + absl::StrCat(CheriotState::kVregPrefix, reg_num)); + return v_dst_op; +} + +template <typename RegType> +inline SourceOperandInterface *GetVectorMaskRegisterSourceOp( + CheriotState *state, int reg_num) { + // Mask register groups only have a single register. + std::vector<generic::RegisterBase *> vreg_group; + vreg_group.push_back(state + ->GetRegister<RegType>( + absl::StrCat(CheriotState::kVregPrefix, reg_num)) + .first); + auto *v_src_op = new RV32VectorSourceOperand( + absl::Span<generic::RegisterBase *>(vreg_group), + absl::StrCat(CheriotState::kVregPrefix, reg_num)); + return v_src_op; +} + +template <typename RegType> +inline DestinationOperandInterface *GetVectorMaskRegisterDestinationOp( + CheriotState *state, int latency, int reg_num) { + // Mask register groups only have a single register. + std::vector<generic::RegisterBase *> vreg_group; + vreg_group.push_back(state + ->GetRegister<RegType>( + absl::StrCat(CheriotState::kVregPrefix, reg_num)) + .first); + auto *v_dst_op = new RV32VectorDestinationOperand( + absl::Span<generic::RegisterBase *>(vreg_group), latency, + absl::StrCat(CheriotState::kVregPrefix, reg_num)); + return v_dst_op; +} + +} // namespace cheriot +} // namespace sim +} // namespace mpact + +#endif // MPACT_CHERIOT_CHERIOT_GETTER_HELPERS_H_
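The group sizing in GetVRegGroup can be sanity-checked in isolation: the table is indexed by reg_num % 8, so v0/v8/v16/v24 head groups of eight, v4/v12/v20/v28 groups of four, the remaining even registers groups of two, and odd registers stand alone. A quick standalone check (not part of the change):

#include <cstdio>

constexpr int kNumRegTable[8] = {8, 1, 2, 1, 4, 1, 2, 1};

int main() {
  // Mirrors GetVRegGroup: group size = kNumRegTable[reg_num % 8].
  for (int n : {0, 4, 6, 8, 13, 24}) {
    std::printf("v%d starts a group of %d register(s)\n", n, kNumRegTable[n % 8]);
  }
  return 0;
}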
diff --git a/cheriot/cheriot_getters.h b/cheriot/cheriot_getters.h new file mode 100644 index 0000000..efc2af2 --- /dev/null +++ b/cheriot/cheriot_getters.h
@@ -0,0 +1,391 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MPACT_CHERIOT_CHERIOT_GETTERS_H_ +#define MPACT_CHERIOT_CHERIOT_GETTERS_H_ + +#include <cstdint> +#include <string> + +#include "absl/container/flat_hash_map.h" +#include "absl/functional/any_invocable.h" +#include "absl/strings/str_cat.h" +#include "cheriot/cheriot_getter_helpers.h" +#include "cheriot/cheriot_register.h" +#include "cheriot/cheriot_state.h" +#include "cheriot/riscv_cheriot_encoding_common.h" +#include "cheriot/riscv_cheriot_register_aliases.h" +#include "mpact/sim/generic/immediate_operand.h" +#include "mpact/sim/generic/literal_operand.h" +#include "mpact/sim/generic/operand_interface.h" + +namespace mpact { +namespace sim { +namespace cheriot { + +using ::mpact::sim::cheriot::RiscVCheriotEncodingCommon; +using ::mpact::sim::generic::DestinationOperandInterface; +using ::mpact::sim::generic::ImmediateOperand; +using ::mpact::sim::generic::IntLiteralOperand; +using ::mpact::sim::generic::SourceOperandInterface; + +using SourceOpGetterMap = + absl::flat_hash_map<int, absl::AnyInvocable<SourceOperandInterface *()>>; +using DestOpGetterMap = + absl::flat_hash_map<int, + absl::AnyInvocable<DestinationOperandInterface *(int)>>; + +template <typename Enum, typename Extractors> +void AddCheriotSourceGetters(SourceOpGetterMap &getter_map, + RiscVCheriotEncodingCommon *common) { + // Source operand getters. 
+ Insert(getter_map, *Enum::kAAq, [common]() -> SourceOperandInterface * { + if (Extractors::Inst32Format::ExtractAq(common->inst_word())) { + return new IntLiteralOperand<1>(); + } + return new IntLiteralOperand<0>(); + }); + Insert(getter_map, *Enum::kARl, [common]() -> SourceOperandInterface * { + if (Extractors::Inst32Format::ExtractRl(common->inst_word())) { + return new generic::IntLiteralOperand<1>(); + } + return new generic::IntLiteralOperand<0>(); + }); + Insert(getter_map, *Enum::kBImm12, [common]() { + return new ImmediateOperand<int32_t>( + Extractors::Inst32Format::ExtractBImm(common->inst_word())); + }); + Insert(getter_map, *Enum::kC2, [common]() { + return GetRegisterSourceOp<CheriotRegister>(common->state(), "c2", "csp"); + }); + Insert(getter_map, *Enum::kC3cs1, [common]() { + auto num = Extractors::CS::ExtractRs1(common->inst_word()); + return GetRegisterSourceOp<CheriotRegister>( + common->state(), absl::StrCat(CheriotState::kCregPrefix, num), + kCRegisterAliases[num]); + }); + Insert(getter_map, *Enum::kC3cs2, [common]() { + auto num = Extractors::CS::ExtractRs2(common->inst_word()); + return GetRegisterSourceOp<CheriotRegister>( + common->state(), absl::StrCat(CheriotState::kCregPrefix, num), + kCRegisterAliases[num]); + }); + Insert(getter_map, *Enum::kC3rs1, [common]() { + auto num = Extractors::CS::ExtractRs1(common->inst_word()); + return GetRegisterSourceOp<CheriotRegister>( + common->state(), absl::StrCat(CheriotState::kXregPrefix, num), + kXRegisterAliases[num]); + }); + Insert(getter_map, *Enum::kC3rs2, [common]() { + auto num = Extractors::CS::ExtractRs2(common->inst_word()); + return GetRegisterSourceOp<CheriotRegister>( + common->state(), absl::StrCat(CheriotState::kXregPrefix, num), + kXRegisterAliases[num]); + }); + Insert(getter_map, *Enum::kCcs2, [common]() { + auto num = Extractors::CSS::ExtractRs2(common->inst_word()); + return GetRegisterSourceOp<CheriotRegister>( + common->state(), absl::StrCat(CheriotState::kCregPrefix, num), + kCRegisterAliases[num]); + }); + Insert(getter_map, *Enum::kCgp, [common]() { + return GetRegisterSourceOp<CheriotRegister>(common->state(), "c3", "c3"); + }); + Insert(getter_map, *Enum::kCSRUimm5, [common]() { + return new ImmediateOperand<uint32_t>( + Extractors::Inst32Format::ExtractIUimm5(common->inst_word())); + }); + Insert(getter_map, *Enum::kCrs1, [common]() { + auto num = Extractors::CR::ExtractRs1(common->inst_word()); + return GetRegisterSourceOp<CheriotRegister>( + common->state(), absl::StrCat(CheriotState::kXregPrefix, num), + kXRegisterAliases[num]); + }); + Insert(getter_map, *Enum::kCrs2, [common]() { + auto num = Extractors::CR::ExtractRs2(common->inst_word()); + return GetRegisterSourceOp<CheriotRegister>( + common->state(), absl::StrCat(CheriotState::kXregPrefix, num), + kXRegisterAliases[num]); + }); + Insert(getter_map, *Enum::kCs1, [common]() { + auto num = Extractors::RType::ExtractRs1(common->inst_word()); + return GetRegisterSourceOp<CheriotRegister>( + common->state(), absl::StrCat(CheriotState::kCregPrefix, num), + kCRegisterAliases[num]); + }); + Insert(getter_map, *Enum::kCs2, [common]() { + auto num = Extractors::RType::ExtractRs2(common->inst_word()); + return GetRegisterSourceOp<CheriotRegister>( + common->state(), absl::StrCat(CheriotState::kCregPrefix, num), + kCRegisterAliases[num]); + }); + Insert(getter_map, *Enum::kCsr, [common]() { + auto csr_indx = Extractors::IType::ExtractUImm12(common->inst_word()); + auto res = common->state()->csr_set()->GetCsr(csr_indx); + if (!res.ok()) { + 
return new ImmediateOperand<uint32_t>(csr_indx); + } + auto *csr = res.value(); + return new ImmediateOperand<uint32_t>(csr_indx, csr->name()); + }); + Insert(getter_map, *Enum::kICbImm8, [common]() { + return new ImmediateOperand<int32_t>( + Extractors::Inst16Format::ExtractBimm(common->inst_word())); + }); + Insert(getter_map, *Enum::kICiImm6, [common]() { + return new ImmediateOperand<int32_t>( + Extractors::CI::ExtractImm6(common->inst_word())); + }); + Insert(getter_map, *Enum::kICiImm612, [common]() { + return new ImmediateOperand<int32_t>( + Extractors::Inst16Format::ExtractImm18(common->inst_word())); + }); + Insert(getter_map, *Enum::kICiUimm6, [common]() { + return new ImmediateOperand<uint32_t>( + Extractors::Inst16Format::ExtractUimm6(common->inst_word())); + }); + Insert(getter_map, *Enum::kICiUimm6x4, [common]() { + return new ImmediateOperand<uint32_t>( + Extractors::Inst16Format::ExtractCiImmW(common->inst_word())); + }); + Insert(getter_map, *Enum::kICiImm6x16, [common]() { + return new ImmediateOperand<int32_t>( + Extractors::Inst16Format::ExtractCiImm10(common->inst_word())); + }); + Insert(getter_map, *Enum::kICiUimm6x8, [common]() { + return new ImmediateOperand<uint32_t>( + Extractors::Inst16Format::ExtractCiImmD(common->inst_word())); + }); + Insert(getter_map, *Enum::kICiwUimm8x4, [common]() { + return new ImmediateOperand<uint32_t>( + Extractors::Inst16Format::ExtractCiwImm10(common->inst_word())); + }); + Insert(getter_map, *Enum::kICjImm11, [common]() { + return new ImmediateOperand<int32_t>( + Extractors::Inst16Format::ExtractJimm(common->inst_word())); + }); + Insert(getter_map, *Enum::kIClUimm5x4, [common]() { + return new ImmediateOperand<uint32_t>( + Extractors::Inst16Format::ExtractClImmW(common->inst_word())); + }); + Insert(getter_map, *Enum::kIClUimm5x8, [common]() { + return new ImmediateOperand<uint32_t>( + Extractors::Inst16Format::ExtractClImmD(common->inst_word())); + }); + Insert(getter_map, *Enum::kICshUimm6, [common]() { + return new ImmediateOperand<uint32_t>( + Extractors::CSH::ExtractUimm6(common->inst_word())); + }); + Insert(getter_map, *Enum::kICshImm6, [common]() { + return new ImmediateOperand<uint32_t>( + Extractors::CSH::ExtractImm6(common->inst_word())); + }); + Insert(getter_map, *Enum::kICssUimm6x4, [common]() { + return new ImmediateOperand<uint32_t>( + Extractors::Inst16Format::ExtractCssImmW(common->inst_word())); + }); + Insert(getter_map, *Enum::kICssUimm6x8, [common]() { + return new ImmediateOperand<uint32_t>( + Extractors::Inst16Format::ExtractCssImmD(common->inst_word())); + }); + Insert(getter_map, *Enum::kIImm12, [common]() { + return new ImmediateOperand<int32_t>( + Extractors::Inst32Format::ExtractImm12(common->inst_word())); + }); + Insert(getter_map, *Enum::kIUimm5, [common]() { + return new ImmediateOperand<uint32_t>( + Extractors::I5Type::ExtractRUimm5(common->inst_word())); + }); + Insert(getter_map, *Enum::kIUimm12, [common]() { + return new ImmediateOperand<uint32_t>( + Extractors::Inst32Format::ExtractUImm12(common->inst_word())); + }); + Insert(getter_map, *Enum::kJImm12, [common]() { + return new ImmediateOperand<int32_t>( + Extractors::Inst32Format::ExtractImm12(common->inst_word())); + }); + Insert(getter_map, *Enum::kJImm20, [common]() { + return new ImmediateOperand<int32_t>( + Extractors::Inst32Format::ExtractJImm(common->inst_word())); + }); + Insert(getter_map, *Enum::kPcc, [common]() { + return GetRegisterSourceOp<CheriotRegister>(common->state(), "pcc", "pcc"); + }); + Insert(getter_map, *Enum::kRd, 
[common]() -> SourceOperandInterface * { + int num = Extractors::RType::ExtractRd(common->inst_word()); + if (num == 0) return new generic::IntLiteralOperand<0>({1}); + return GetRegisterSourceOp<CheriotRegister>( + common->state(), absl::StrCat(CheriotState::kXregPrefix, num), + kXRegisterAliases[num]); + }); + Insert(getter_map, *Enum::kRs1, [common]() -> SourceOperandInterface * { + int num = Extractors::RType::ExtractRs1(common->inst_word()); + if (num == 0) return new generic::IntLiteralOperand<0>({1}); + return GetRegisterSourceOp<CheriotRegister>( + common->state(), absl::StrCat(CheriotState::kXregPrefix, num), + kXRegisterAliases[num]); + }); + Insert(getter_map, *Enum::kRs2, [common]() -> SourceOperandInterface * { + int num = Extractors::RType::ExtractRs2(common->inst_word()); + if (num == 0) return new generic::IntLiteralOperand<0>({1}); + return GetRegisterSourceOp<CheriotRegister>( + common->state(), absl::StrCat(CheriotState::kXregPrefix, num), + kXRegisterAliases[num]); + }); + Insert(getter_map, *Enum::kSImm12, [common]() { + return new ImmediateOperand<int32_t>( + Extractors::SType::ExtractSImm(common->inst_word())); + }); + Insert(getter_map, *Enum::kScr, [common]() -> SourceOperandInterface * { + int csr_indx = Extractors::RType::ExtractRs2(common->inst_word()); + std::string csr_name; + switch (csr_indx) { + case 28: + csr_name = "mtcc"; + break; + case 29: + csr_name = "mtdc"; + break; + case 30: + csr_name = "mscratchc"; + break; + case 31: + csr_name = "mepcc"; + break; + default: + return nullptr; + } + auto res = common->state()->csr_set()->GetCsr(csr_name); + if (!res.ok()) { + return GetRegisterSourceOp<CheriotRegister>(common->state(), csr_name, + csr_name); + } + auto *csr = res.value(); + auto *op = csr->CreateSourceOperand(); + return op; + }); + Insert(getter_map, *Enum::kSImm20, [common]() { + return new ImmediateOperand<int32_t>( + Extractors::UType::ExtractSImm(common->inst_word())); + }); + Insert(getter_map, *Enum::kUImm20, [common]() { + return new ImmediateOperand<int32_t>( + Extractors::Inst32Format::ExtractUImm(common->inst_word())); + }); + Insert(getter_map, *Enum::kX0, + []() { return new generic::IntLiteralOperand<0>({1}); }); + Insert(getter_map, *Enum::kX2, [common]() { + return GetRegisterSourceOp<CheriotRegister>( + common->state(), absl::StrCat(CheriotState::kXregPrefix, 2), + kXRegisterAliases[2]); + }); +} + +template <typename Enum, typename Extractors> +void AddCheriotDestGetters(DestOpGetterMap &getter_map, + RiscVCheriotEncodingCommon *common) { + // Destination operand getters. 
+ Insert(getter_map, *Enum::kC2, [common](int latency) { + return GetRegisterDestinationOp<CheriotRegister>(common->state(), "c2", + latency, "csp"); + }); + Insert(getter_map, *Enum::kC3cd, [common](int latency) { + int num = Extractors::CL::ExtractRd(common->inst_word()); + return GetRegisterDestinationOp<CheriotRegister>( + common->state(), absl::StrCat(CheriotState::kCregPrefix, num), latency, + kCRegisterAliases[num]); + }); + Insert(getter_map, *Enum::kC3rd, [common](int latency) { + int num = Extractors::CL::ExtractRd(common->inst_word()); + if (num == 0) { + return GetRegisterDestinationOp<CheriotRegister>(common->state(), + "X0Dest", latency); + } + return GetRegisterDestinationOp<CheriotRegister>( + common->state(), absl::StrCat(CheriotState::kXregPrefix, num), latency, + kXRegisterAliases[num]); + }); + Insert(getter_map, *Enum::kC3rs1, [common](int latency) { + int num = Extractors::CL::ExtractRs1(common->inst_word()); + return GetRegisterDestinationOp<CheriotRegister>( + common->state(), absl::StrCat(CheriotState::kXregPrefix, num), latency, + kXRegisterAliases[num]); + }); + Insert(getter_map, *Enum::kCd, [common](int latency) { + int num = Extractors::RType::ExtractRd(common->inst_word()); + if (num == 0) { + return GetRegisterDestinationOp<CheriotRegister>(common->state(), + "X0Dest", latency); + } + return GetRegisterDestinationOp<CheriotRegister>( + common->state(), absl::StrCat(CheriotState::kCregPrefix, num), latency, + kCRegisterAliases[num]); + }); + Insert(getter_map, *Enum::kCsr, [common](int latency) { + return GetRegisterDestinationOp<CheriotRegister>( + common->state(), CheriotState::kCsrName, latency); + }); + Insert(getter_map, *Enum::kScr, + [common](int latency) -> DestinationOperandInterface * { + int csr_indx = Extractors::RType::ExtractRs2(common->inst_word()); + std::string csr_name; + switch (csr_indx) { + case 28: + csr_name = "mtcc"; + break; + case 29: + csr_name = "mtdc"; + break; + case 30: + csr_name = "mscratchc"; + break; + case 31: + csr_name = "mepcc"; + break; + default: + return nullptr; + } + auto res = common->state()->csr_set()->GetCsr(csr_name); + if (!res.ok()) { + return GetRegisterDestinationOp<CheriotRegister>( + common->state(), csr_name, latency); + } + auto *csr = res.value(); + auto *op = csr->CreateWriteDestinationOperand(latency, csr_name); + return op; + }); + Insert(getter_map, *Enum::kRd, + [common](int latency) -> DestinationOperandInterface * { + int num = Extractors::RType::ExtractRd(common->inst_word()); + if (num == 0) { + return GetRegisterDestinationOp<CheriotRegister>(common->state(), + "X0Dest", 0); + } else { + return GetRegisterDestinationOp<CheriotRegister>( + common->state(), absl::StrCat(CheriotState::kXregPrefix, num), + latency, kXRegisterAliases[num]); + } + }); + Insert(getter_map, *Enum::kX1, [common](int latency) { + return GetRegisterDestinationOp<CheriotRegister>( + common->state(), absl::StrCat(CheriotState::kXregPrefix, 1), latency, + kXRegisterAliases[1]); + }); +} + +} // namespace cheriot +} // namespace sim +} // namespace mpact + +#endif // MPACT_CHERIOT_CHERIOT_GETTERS_H_
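For context on how these maps are consumed: the generated encoding class parses an instruction word, then invokes the getter registered for each operand enum the opcode needs, so field extraction happens lazily, per operand. A simplified standalone analogue (std::function stands in for absl::AnyInvocable, and the types are toy versions of the operand interfaces, not the generated code):

#include <cstdint>
#include <cstdio>
#include <functional>
#include <unordered_map>

// Toy stand-in for SourceOperandInterface.
struct SourceOp { int value; };

enum class SrcEnum : int { kRs1 = 1, kRs2 = 2 };

int main() {
  uint32_t inst_word = 0x00c58533;  // hypothetical 32-bit encoding
  std::unordered_map<int, std::function<SourceOp *()>> getters;
  // Getters capture the decoding context; nothing is extracted until the
  // decoder asks for this specific operand.
  getters[static_cast<int>(SrcEnum::kRs1)] = [&inst_word]() {
    return new SourceOp{static_cast<int>((inst_word >> 15) & 0x1f)};
  };
  auto *op = getters[static_cast<int>(SrcEnum::kRs1)]();
  std::printf("rs1 = x%d\n", op->value);
  delete op;
  return 0;
}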
diff --git a/cheriot/cheriot_rvv_decoder.cc b/cheriot/cheriot_rvv_decoder.cc new file mode 100644 index 0000000..bd576d8 --- /dev/null +++ b/cheriot/cheriot_rvv_decoder.cc
@@ -0,0 +1,103 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cheriot/cheriot_rvv_decoder.h" + +#include <cstdint> +#include <string> + +#include "cheriot/cheriot_state.h" +#include "cheriot/riscv_cheriot_rvv_decoder.h" +#include "cheriot/riscv_cheriot_rvv_encoding.h" +#include "cheriot/riscv_cheriot_rvv_enums.h" +#include "mpact/sim/generic/instruction.h" +#include "mpact/sim/generic/type_helpers.h" +#include "mpact/sim/util/memory/memory_interface.h" +#include "riscv//riscv_state.h" + +namespace mpact { +namespace sim { +namespace cheriot { + +using RV_EC = ::mpact::sim::riscv::ExceptionCode; + +using ::mpact::sim::generic::operator*; // NOLINT: is used below (clang error). + +CheriotRVVDecoder::CheriotRVVDecoder(CheriotState *state, + util::MemoryInterface *memory) + : state_(state), memory_(memory) { + // Need a data buffer to load instructions from memory. Allocate a single + // buffer that can be reused for each instruction word. + inst_db_ = db_factory_.Allocate<uint32_t>(1); + // Allocate the isa factory class, the top level isa decoder instance, and + // the encoding parser. + cheriot_rvv_isa_factory_ = new CheriotRVVIsaFactory(); + cheriot_rvv_isa_ = new isa32_rvv::RiscVCheriotRVVInstructionSet( + state, cheriot_rvv_isa_factory_); + cheriot_rvv_encoding_ = new isa32_rvv::RiscVCheriotRVVEncoding(state); +} + +CheriotRVVDecoder::~CheriotRVVDecoder() { + delete cheriot_rvv_isa_; + delete cheriot_rvv_isa_factory_; + delete cheriot_rvv_encoding_; + inst_db_->DecRef(); +} + +generic::Instruction *CheriotRVVDecoder::DecodeInstruction(uint64_t address) { + // First check that the address is aligned properly. If not, create and return + // an instruction object that will raise an exception. + if (address & 0x1) { + auto *inst = new generic::Instruction(0, state_); + inst->set_size(1); + inst->SetDisassemblyString("Misaligned instruction address"); + inst->set_opcode(*isa32_rvv::OpcodeEnum::kNone); + inst->set_address(address); + inst->set_semantic_function([this](generic::Instruction *inst) { + state_->Trap(/*is_interrupt*/ false, inst->address(), + *RV_EC::kInstructionAddressMisaligned, inst->address() ^ 0x1, + inst); + }); + return inst; + } + + // If the address is greater than the max address, return an instruction + // that will raise an exception. + if (address > state_->max_physical_address()) { + auto *inst = new generic::Instruction(0, state_); + inst->set_size(0); + inst->SetDisassemblyString("Instruction access fault"); + inst->set_opcode(*isa32_rvv::OpcodeEnum::kNone); + inst->set_address(address); + inst->set_semantic_function([this](generic::Instruction *inst) { + state_->Trap(/*is_interrupt*/ false, inst->address(), + *RV_EC::kInstructionAccessFault, inst->address(), nullptr); + }); + return inst; + } + + // Read the instruction word from memory and parse it in the encoding parser. 
+ memory_->Load(address, inst_db_, nullptr, nullptr); + uint32_t iword = inst_db_->Get<uint32_t>(0); + cheriot_rvv_encoding_->ParseInstruction(iword); + + // Call the isa decoder to obtain a new instruction object for the instruction + // word that was parsed above. + auto *instruction = cheriot_rvv_isa_->Decode(address, cheriot_rvv_encoding_); + return instruction; +} + +} // namespace cheriot +} // namespace sim +} // namespace mpact
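A sketch of how the new decoder slots into a fetch loop. The constructor and DecodeInstruction signatures come from this change; the state and memory construction below is an assumption (in practice CheriotTop owns this wiring), so treat it as illustrative rather than exact:

#include <cstdint>

#include "cheriot/cheriot_rvv_decoder.h"
#include "cheriot/cheriot_state.h"
#include "mpact/sim/util/memory/flat_demand_memory.h"

using ::mpact::sim::cheriot::CheriotRVVDecoder;
using ::mpact::sim::cheriot::CheriotState;
using ::mpact::sim::util::FlatDemandMemory;

int main() {
  FlatDemandMemory memory;                 // assumed default-constructible
  CheriotState state("cheriot", &memory);  // assumed constructor shape
  CheriotRVVDecoder decoder(&state, &memory);
  // The decoder never returns nullptr: misaligned or out-of-range fetches
  // come back as instructions whose semantic function raises the trap.
  auto *inst = decoder.DecodeInstruction(0x0);
  inst->DecRef();  // generic::Instruction is reference counted
  return 0;
}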
diff --git a/cheriot/cheriot_rvv_decoder.h b/cheriot/cheriot_rvv_decoder.h new file mode 100644 index 0000000..6731ab1 --- /dev/null +++ b/cheriot/cheriot_rvv_decoder.h
@@ -0,0 +1,95 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MPACT_CHERIOT_CHERIOT_RVV_DECODER_H_ +#define MPACT_CHERIOT_CHERIOT_RVV_DECODER_H_ + +#include <cstdint> +#include <memory> + +#include "cheriot/cheriot_state.h" +#include "cheriot/riscv_cheriot_rvv_decoder.h" +#include "cheriot/riscv_cheriot_rvv_encoding.h" +#include "cheriot/riscv_cheriot_rvv_enums.h" +#include "mpact/sim/generic/arch_state.h" +#include "mpact/sim/generic/data_buffer.h" +#include "mpact/sim/generic/decoder_interface.h" +#include "mpact/sim/generic/instruction.h" +#include "mpact/sim/util/memory/memory_interface.h" + +namespace mpact { +namespace sim { +namespace cheriot { + +using ::mpact::sim::generic::ArchState; + +// This is the factory class needed by the generated decoder. It is responsible +// for creating the decoder for each slot instance. Since the riscv architecture +// only has a single slot, it's a pretty simple class. +class CheriotRVVIsaFactory + : public isa32_rvv::RiscVCheriotRVVInstructionSetFactory { + public: + std::unique_ptr<isa32_rvv::RiscvCheriotRvvSlot> CreateRiscvCheriotRvvSlot( + ArchState *state) override { + return std::make_unique<isa32_rvv::RiscvCheriotRvvSlot>(state); + } +}; + +// This class implements the generic DecoderInterface and provides a bridge +// to the (isa specific) generated decoder classes. +class CheriotRVVDecoder : public generic::DecoderInterface { + public: + using SlotEnum = isa32_rvv::SlotEnum; + using OpcodeEnum = isa32_rvv::OpcodeEnum; + + CheriotRVVDecoder(CheriotState *state, util::MemoryInterface *memory); + CheriotRVVDecoder() = delete; + ~CheriotRVVDecoder() override; + + // This will always return a valid instruction that can be executed. In the + // case of a decode error, the semantic function in the instruction object + // instance will raise an internal simulator error when executed. + generic::Instruction *DecodeInstruction(uint64_t address) override; + + // Return the number of opcodes supported by this decoder. + int GetNumOpcodes() const override { + return static_cast<int>(OpcodeEnum::kPastMaxValue); + } + // Return the name of the opcode at the given index. + const char *GetOpcodeName(int index) const override { + return isa32_rvv::kOpcodeNames[index]; + } + + // Getter. + isa32_rvv::RiscVCheriotRVVEncoding *cheriot_rvv_encoding() const { + return cheriot_rvv_encoding_; + } + + private: + CheriotState *state_; + util::MemoryInterface *memory_; + generic::DataBufferFactory db_factory_; + generic::DataBuffer *inst_db_; + isa32_rvv::RiscVCheriotRVVEncoding *cheriot_rvv_encoding_; + isa32_rvv::RiscVCheriotRVVInstructionSetFactory *cheriot_rvv_isa_factory_; + isa32_rvv::RiscVCheriotRVVInstructionSet *cheriot_rvv_isa_; +}; + +} // namespace cheriot +} // namespace sim +} // namespace mpact + +#endif // MPACT_CHERIOT_CHERIOT_RVV_DECODER_H_
diff --git a/cheriot/cheriot_rvv_fp_decoder.cc b/cheriot/cheriot_rvv_fp_decoder.cc new file mode 100644 index 0000000..d6d3b74 --- /dev/null +++ b/cheriot/cheriot_rvv_fp_decoder.cc
@@ -0,0 +1,104 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cheriot/cheriot_rvv_fp_decoder.h" + +#include <cstdint> +#include <string> + +#include "cheriot/cheriot_state.h" +#include "cheriot/riscv_cheriot_rvv_fp_decoder.h" +#include "cheriot/riscv_cheriot_rvv_fp_encoding.h" +#include "cheriot/riscv_cheriot_rvv_fp_enums.h" +#include "mpact/sim/generic/instruction.h" +#include "mpact/sim/generic/type_helpers.h" +#include "mpact/sim/util/memory/memory_interface.h" +#include "riscv//riscv_state.h" + +namespace mpact { +namespace sim { +namespace cheriot { + +using RV_EC = ::mpact::sim::riscv::ExceptionCode; + +using ::mpact::sim::generic::operator*; // NOLINT: is used below (clang error). + +CheriotRVVFPDecoder::CheriotRVVFPDecoder(CheriotState *state, + util::MemoryInterface *memory) + : state_(state), memory_(memory) { + // Need a data buffer to load instructions from memory. Allocate a single + // buffer that can be reused for each instruction word. + inst_db_ = db_factory_.Allocate<uint32_t>(1); + // Allocate the isa factory class, the top level isa decoder instance, and + // the encoding parser. + cheriot_rvv_fp_isa_factory_ = new CheriotRVVFPIsaFactory(); + cheriot_rvv_fp_isa_ = new isa32_rvv_fp::RiscVCheriotRVVFpInstructionSet( + state, cheriot_rvv_fp_isa_factory_); + cheriot_rvv_fp_encoding_ = new isa32_rvv_fp::RiscVCheriotRVVFPEncoding(state); +} + +CheriotRVVFPDecoder::~CheriotRVVFPDecoder() { + delete cheriot_rvv_fp_isa_; + delete cheriot_rvv_fp_isa_factory_; + delete cheriot_rvv_fp_encoding_; + inst_db_->DecRef(); +} + +generic::Instruction *CheriotRVVFPDecoder::DecodeInstruction(uint64_t address) { + // First check that the address is aligned properly. If not, create and return + // an instruction object that will raise an exception. + if (address & 0x1) { + auto *inst = new generic::Instruction(0, state_); + inst->set_size(1); + inst->SetDisassemblyString("Misaligned instruction address"); + inst->set_opcode(*isa32_rvv_fp::OpcodeEnum::kNone); + inst->set_address(address); + inst->set_semantic_function([this](generic::Instruction *inst) { + state_->Trap(/*is_interrupt*/ false, inst->address(), + *RV_EC::kInstructionAddressMisaligned, inst->address() ^ 0x1, + inst); + }); + return inst; + } + + // If the address is greater than the max address, return an instruction + // that will raise an exception. + if (address > state_->max_physical_address()) { + auto *inst = new generic::Instruction(0, state_); + inst->set_size(0); + inst->SetDisassemblyString("Instruction access fault"); + inst->set_opcode(*isa32_rvv_fp::OpcodeEnum::kNone); + inst->set_address(address); + inst->set_semantic_function([this](generic::Instruction *inst) { + state_->Trap(/*is_interrupt*/ false, inst->address(), + *RV_EC::kInstructionAccessFault, inst->address(), nullptr); + }); + return inst; + } + + // Read the instruction word from memory and parse it in the encoding parser. 
+ memory_->Load(address, inst_db_, nullptr, nullptr); + uint32_t iword = inst_db_->Get<uint32_t>(0); + cheriot_rvv_fp_encoding_->ParseInstruction(iword); + + // Call the isa decoder to obtain a new instruction object for the instruction + // word that was parsed above. + auto *instruction = + cheriot_rvv_fp_isa_->Decode(address, cheriot_rvv_fp_encoding_); + return instruction; +} + +} // namespace cheriot +} // namespace sim +} // namespace mpact
diff --git a/cheriot/cheriot_rvv_fp_decoder.h b/cheriot/cheriot_rvv_fp_decoder.h new file mode 100644 index 0000000..3693dea --- /dev/null +++ b/cheriot/cheriot_rvv_fp_decoder.h
@@ -0,0 +1,96 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MPACT_CHERIOT_CHERIOT_RVV_FP_DECODER_H_ +#define MPACT_CHERIOT_CHERIOT_RVV_FP_DECODER_H_ + +#include <cstdint> +#include <memory> + +#include "cheriot/cheriot_state.h" +#include "cheriot/riscv_cheriot_rvv_fp_decoder.h" +#include "cheriot/riscv_cheriot_rvv_fp_encoding.h" +#include "cheriot/riscv_cheriot_rvv_fp_enums.h" +#include "mpact/sim/generic/arch_state.h" +#include "mpact/sim/generic/data_buffer.h" +#include "mpact/sim/generic/decoder_interface.h" +#include "mpact/sim/generic/instruction.h" +#include "mpact/sim/util/memory/memory_interface.h" + +namespace mpact { +namespace sim { +namespace cheriot { + +using ::mpact::sim::generic::ArchState; + +// This is the factory class needed by the generated decoder. It is responsible +// for creating the decoder for each slot instance. Since the riscv architecture +// only has a single slot, it's a pretty simple class. +class CheriotRVVFPIsaFactory + : public isa32_rvv_fp::RiscVCheriotRVVFpInstructionSetFactory { + public: + std::unique_ptr<isa32_rvv_fp::RiscvCheriotRvvFpSlot> + CreateRiscvCheriotRvvFpSlot(ArchState *state) override { + return std::make_unique<isa32_rvv_fp::RiscvCheriotRvvFpSlot>(state); + } +}; + +// This class implements the generic DecoderInterface and provides a bridge +// to the (isa specific) generated decoder classes. +class CheriotRVVFPDecoder : public generic::DecoderInterface { + public: + using SlotEnum = isa32_rvv_fp::SlotEnum; + using OpcodeEnum = isa32_rvv_fp::OpcodeEnum; + + CheriotRVVFPDecoder(CheriotState *state, util::MemoryInterface *memory); + CheriotRVVFPDecoder() = delete; + ~CheriotRVVFPDecoder() override; + + // This will always return a valid instruction that can be executed. In the + // case of a decode error, the semantic function in the instruction object + // instance will raise an internal simulator error when executed. + generic::Instruction *DecodeInstruction(uint64_t address) override; + + // Return the number of opcodes supported by this decoder. + int GetNumOpcodes() const override { + return static_cast<int>(OpcodeEnum::kPastMaxValue); + } + // Return the name of the opcode at the given index. + const char *GetOpcodeName(int index) const override { + return isa32_rvv_fp::kOpcodeNames[index]; + } + + // Getter. + isa32_rvv_fp::RiscVCheriotRVVFPEncoding *cheriot_rvv_fp_encoding() const { + return cheriot_rvv_fp_encoding_; + } + + private: + CheriotState *state_; + util::MemoryInterface *memory_; + generic::DataBufferFactory db_factory_; + generic::DataBuffer *inst_db_; + isa32_rvv_fp::RiscVCheriotRVVFPEncoding *cheriot_rvv_fp_encoding_; + isa32_rvv_fp::RiscVCheriotRVVFpInstructionSetFactory + *cheriot_rvv_fp_isa_factory_; + isa32_rvv_fp::RiscVCheriotRVVFpInstructionSet *cheriot_rvv_fp_isa_; +}; + +} // namespace cheriot +} // namespace sim +} // namespace mpact + +#endif // MPACT_CHERIOT_CHERIOT_RVV_FP_DECODER_H_
diff --git a/cheriot/cheriot_rvv_fp_getters.h b/cheriot/cheriot_rvv_fp_getters.h new file mode 100644 index 0000000..c3eee6b --- /dev/null +++ b/cheriot/cheriot_rvv_fp_getters.h
@@ -0,0 +1,66 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MPACT_CHERIOT_CHERIOT_RVV_FP_GETTERS_H_ +#define MPACT_CHERIOT_CHERIOT_RVV_FP_GETTERS_H_ + +#include "absl/container/flat_hash_map.h" +#include "absl/functional/any_invocable.h" +#include "absl/strings/str_cat.h" +#include "cheriot/cheriot_getter_helpers.h" +#include "cheriot/riscv_cheriot_encoding_common.h" +#include "mpact/sim/generic/operand_interface.h" +#include "riscv//riscv_register.h" +#include "riscv//riscv_state.h" + +namespace mpact { +namespace sim { +namespace cheriot { +using ::mpact::sim::cheriot::RiscVCheriotEncodingCommon; +using ::mpact::sim::generic::DestinationOperandInterface; +using ::mpact::sim::generic::SourceOperandInterface; +using ::mpact::sim::riscv::RiscVState; +using ::mpact::sim::riscv::RVFpRegister; + +using SourceOpGetterMap = + absl::flat_hash_map<int, absl::AnyInvocable<SourceOperandInterface *()>>; +using DestOpGetterMap = + absl::flat_hash_map<int, + absl::AnyInvocable<DestinationOperandInterface *(int)>>; + +template <typename Enum, typename Extractors> +void AddCheriotRVVFPSourceGetters(SourceOpGetterMap &getter_map, + RiscVCheriotEncodingCommon *common) { + Insert(getter_map, *Enum::kFs1, [common]() { + int num = Extractors::VArith::ExtractRs1(common->inst_word()); + return GetRegisterSourceOp<RVFpRegister>( + common->state(), absl::StrCat(RiscVState::kFregPrefix, num)); + }); +} + +template <typename Enum, typename Extractors> +void AddCheriotRVVFPDestGetters(DestOpGetterMap &getter_map, + RiscVCheriotEncodingCommon *common) { + Insert(getter_map, *Enum::kFd, [common](int latency) { + int num = Extractors::VArith::ExtractRd(common->inst_word()); + return GetRegisterDestinationOp<RVFpRegister>( + common->state(), absl::StrCat(RiscVState::kFregPrefix, num), latency); + }); +} + +} // namespace cheriot +} // namespace sim +} // namespace mpact + +#endif // MPACT_CHERIOT_CHERIOT_RVV_FP_GETTERS_H_
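Note: these helpers are intended to be called from the RVV + FP encoding class constructor with the generated enum and extractor types. A minimal sketch of the expected call sites (the type names here mirror the scalar pattern in riscv_cheriot_encoding.cc further below, and are assumptions for the generated RVV + FP variant, not part of this diff):

  // Inside the RVV + FP encoding class constructor (illustrative only):
  AddCheriotRVVFPSourceGetters<SourceOpEnum, Extractors>(source_op_getters_, this);
  AddCheriotRVVFPDestGetters<DestOpEnum, Extractors>(dest_op_getters_, this);

where `this` is an encoding object deriving from RiscVCheriotEncodingCommon, so the lambdas can read state() and inst_word() at decode time.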
diff --git a/cheriot/cheriot_rvv_getters.h b/cheriot/cheriot_rvv_getters.h new file mode 100644 index 0000000..303a2cb --- /dev/null +++ b/cheriot/cheriot_rvv_getters.h
@@ -0,0 +1,124 @@
+// Copyright 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MPACT_CHERIOT_CHERIOT_RVV_GETTERS_H_
+#define MPACT_CHERIOT_CHERIOT_RVV_GETTERS_H_
+
+#include <cstdint>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/functional/any_invocable.h"
+#include "cheriot/cheriot_getter_helpers.h"
+#include "cheriot/cheriot_state.h"
+#include "cheriot/cheriot_vector_true_operand.h"
+#include "cheriot/riscv_cheriot_encoding_common.h"
+#include "mpact/sim/generic/immediate_operand.h"
+#include "mpact/sim/generic/literal_operand.h"
+#include "mpact/sim/generic/operand_interface.h"
+#include "riscv//riscv_register.h"
+
+namespace mpact {
+namespace sim {
+namespace cheriot {
+
+using ::mpact::sim::cheriot::RiscVCheriotEncodingCommon;
+using ::mpact::sim::generic::DestinationOperandInterface;
+using ::mpact::sim::generic::ImmediateOperand;
+using ::mpact::sim::generic::IntLiteralOperand;
+using ::mpact::sim::generic::SourceOperandInterface;
+using ::mpact::sim::riscv::RV32VectorTrueOperand;
+using ::mpact::sim::riscv::RVVectorRegister;
+
+using SourceOpGetterMap =
+    absl::flat_hash_map<int, absl::AnyInvocable<SourceOperandInterface *()>>;
+using DestOpGetterMap =
+    absl::flat_hash_map<int,
+                        absl::AnyInvocable<DestinationOperandInterface *(int)>>;
+
+template <typename Enum, typename Extractors>
+void AddCheriotRVVSourceGetters(SourceOpGetterMap &getter_map,
+                                RiscVCheriotEncodingCommon *common) {
+  // The literal 1 does not depend on the instruction word, so no capture.
+  Insert(getter_map, *Enum::kConst1, []() -> SourceOperandInterface * {
+    return new IntLiteralOperand<1>();
+  });
+  Insert(getter_map, *Enum::kNf, [common]() -> SourceOperandInterface * {
+    auto imm = Extractors::VMem::ExtractNf(common->inst_word());
+    return new ImmediateOperand<uint32_t>(imm);
+  });
+  // Note: simm5 is a signed immediate, uimm5 an unsigned one.
+  Insert(getter_map, *Enum::kSimm5, [common]() -> SourceOperandInterface * {
+    auto imm = Extractors::VArith::ExtractSimm5(common->inst_word());
+    return new ImmediateOperand<int32_t>(imm);
+  });
+  Insert(getter_map, *Enum::kUimm5, [common]() -> SourceOperandInterface * {
+    auto imm = Extractors::VArith::ExtractUimm5(common->inst_word());
+    return new ImmediateOperand<uint32_t>(imm);
+  });
+  Insert(getter_map, *Enum::kVd, [common]() -> SourceOperandInterface * {
+    auto num = Extractors::VArith::ExtractVd(common->inst_word());
+    return GetVectorRegisterSourceOp<RVVectorRegister>(common->state(), num);
+  });
+  Insert(getter_map, *Enum::kVm, [common]() -> SourceOperandInterface * {
+    auto vm = Extractors::VArith::ExtractVm(common->inst_word());
+    return new ImmediateOperand<uint32_t>(vm);
+  });
+  Insert(getter_map, *Enum::kVmask, [common]() -> SourceOperandInterface * {
+    auto vm = Extractors::VArith::ExtractVm(common->inst_word());
+    if (vm == 1) {
+      // Unmasked, return the True mask.
+      return new CheriotVectorTrueOperand(common->state());
+    }
+    // Masked. Return the mask register (v0).
+    return GetVectorMaskRegisterSourceOp<RVVectorRegister>(common->state(), 0);
+  });
+  Insert(getter_map, *Enum::kVmaskTrue, [common]() -> SourceOperandInterface * {
+    return new CheriotVectorTrueOperand(common->state());
+  });
+  Insert(getter_map, *Enum::kVs1, [common]() -> SourceOperandInterface * {
+    auto num = Extractors::VArith::ExtractVs1(common->inst_word());
+    return GetVectorRegisterSourceOp<RVVectorRegister>(common->state(), num);
+  });
+  Insert(getter_map, *Enum::kVs2, [common]() -> SourceOperandInterface * {
+    auto num = Extractors::VArith::ExtractVs2(common->inst_word());
+    return GetVectorRegisterSourceOp<RVVectorRegister>(common->state(), num);
+  });
+  Insert(getter_map, *Enum::kVs3, [common]() -> SourceOperandInterface * {
+    auto num = Extractors::VMem::ExtractVs3(common->inst_word());
+    return GetVectorRegisterSourceOp<RVVectorRegister>(common->state(), num);
+  });
+  Insert(getter_map, *Enum::kZimm10, [common]() -> SourceOperandInterface * {
+    auto imm = Extractors::VConfig::ExtractZimm10(common->inst_word());
+    return new ImmediateOperand<uint32_t>(imm);
+  });
+  Insert(getter_map, *Enum::kZimm11, [common]() -> SourceOperandInterface * {
+    auto imm = Extractors::VConfig::ExtractZimm11(common->inst_word());
+    return new ImmediateOperand<uint32_t>(imm);
+  });
+}
+
+template <typename Enum, typename Extractors>
+void AddCheriotRVVDestGetters(DestOpGetterMap &getter_map,
+                              RiscVCheriotEncodingCommon *common) {
+  Insert(getter_map, *Enum::kVd,
+         [common](int latency) -> DestinationOperandInterface * {
+           auto num = Extractors::VArith::ExtractVd(common->inst_word());
+           return GetVectorRegisterDestinationOp<RVVectorRegister>(
+               common->state(), latency, num);
+         });
+}
+
+} // namespace cheriot
+} // namespace sim
+} // namespace mpact
+
+#endif // MPACT_CHERIOT_CHERIOT_RVV_GETTERS_H_
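Note: these maps are consumed by the encoding classes when the generated decoder asks for operands. A minimal sketch of the lookup side, under the assumption that it follows the scalar pattern in riscv_cheriot_encoding.cc below (the function name is illustrative):

  SourceOperandInterface *GetRvvSourceOp(SourceOpGetterMap &getter_map, int op) {
    auto iter = getter_map.find(op);  // op is a dereferenced *Enum value.
    if (iter == getter_map.end()) return nullptr;
    return iter->second();  // Invokes the stored AnyInvocable getter.
  }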
diff --git a/cheriot/cheriot_state.cc b/cheriot/cheriot_state.cc index 1842f29..882e1cf 100644 --- a/cheriot/cheriot_state.cc +++ b/cheriot/cheriot_state.cc
@@ -442,15 +442,8 @@
                               DataBuffer *mask_db, int el_size, DataBuffer *db,
                               Instruction *child_inst, ReferenceCount *context) {
-  // Check for alignment.
-  uint64_t mask = el_size - 1;
-  for (auto address : address_db->Get<uint64_t>()) {
-    if ((address & mask) != 0) {
-      Trap(/*is_interrupt*/ false, address, *EC::kLoadAddressMisaligned,
-           inst == nullptr ? 0 : inst->address(), inst);
-      return;
-    }
-  }
+  // For now, we don't check for alignment on vector memory accesses.
+
   // Check for physical address violation.
   for (auto address : address_db->Get<uint64_t>()) {
     if (address < min_physical_address_ || address > max_physical_address_) {
@@ -496,15 +489,8 @@
 void CheriotState::StoreMemory(const Instruction *inst, DataBuffer *address_db,
                                DataBuffer *mask_db, int el_size,
                                DataBuffer *db) {
-  // Check for alignment.
-  uint64_t mask = el_size - 1;
-  for (auto address : address_db->Get<uint64_t>()) {
-    if ((address & mask) != 0) {
-      Trap(/*is_interrupt*/ false, address, *EC::kStoreAddressMisaligned,
-           inst == nullptr ? 0 : inst->address(), inst);
-      return;
-    }
-  }
+  // For now, we don't check for alignment on vector memory accesses.
+
   // Check for physical address violation.
   for (auto address : address_db->Get<uint64_t>()) {
     if (address < min_physical_address_ || address > max_physical_address_) {
@@ -525,6 +511,10 @@
   tagged_memory_->Store(address_db, mask_db, el_size, db);
 }
 
+void CheriotState::DbgStoreMemory(uint64_t address, DataBuffer *db) {
+  tagged_memory_->Store(address, db);
+}
+
 void CheriotState::DbgLoadMemory(uint64_t address, DataBuffer *db) {
   tagged_memory_->Load(address, db, nullptr, nullptr);
 }
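Note the behavioral change in the two hunks above: misaligned element addresses in these vector LoadMemory/StoreMemory paths no longer raise kLoadAddressMisaligned/kStoreAddressMisaligned traps. Only the physical address range check remains, so callers that still need misalignment traps must check alignment before invoking these methods.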
diff --git a/cheriot/cheriot_state.h b/cheriot/cheriot_state.h index c14c73e..6d0ec7e 100644 --- a/cheriot/cheriot_state.h +++ b/cheriot/cheriot_state.h
@@ -38,8 +38,11 @@ #include "mpact/sim/util/memory/memory_interface.h" #include "mpact/sim/util/memory/tagged_memory_interface.h" #include "riscv//riscv_csr.h" +#include "riscv//riscv_fp_state.h" #include "riscv//riscv_misa.h" +#include "riscv//riscv_register.h" #include "riscv//riscv_state.h" +#include "riscv//riscv_vector_state.h" #include "riscv//riscv_xip_xie.h" #include "riscv//riscv_xstatus.h" @@ -61,14 +64,17 @@ using ::mpact::sim::riscv::PrivilegeMode; using ::mpact::sim::riscv::RiscVCsrInterface; using ::mpact::sim::riscv::RiscVCsrSet; +using ::mpact::sim::riscv::RiscVFPState; using ::mpact::sim::riscv::RiscVMIe; using ::mpact::sim::riscv::RiscVMIp; using ::mpact::sim::riscv::RiscVMIsa; using ::mpact::sim::riscv::RiscVMStatus; using ::mpact::sim::riscv::RiscVSimpleCsr; +using ::mpact::sim::riscv::RVVectorRegister; // Forward declare the CHERIoT register type. class CheriotRegister; +class CheriotVectorState; // CHERIoT exception codes. These are used in addition to the ones defined for // vanilla RiscV. @@ -159,6 +165,9 @@ std::vector<RiscVCsrInterface *> &); friend void CreateCsrs<uint64_t>(CheriotState *, std::vector<RiscVCsrInterface *> &); + auto static constexpr kVregPrefix = + ::mpact::sim::riscv::RiscVState::kVregPrefix; + // Memory footprint of a capability register. static constexpr int kCapabilitySizeInBytes = 8; // Pc name. @@ -191,6 +200,21 @@ return std::make_pair(AddRegister<RegisterType>(name), true); } + // Specialization for RiscV vector registers. + template <> + std::pair<RVVectorRegister *, bool> GetRegister<RVVectorRegister>( + absl::string_view name) { + int vector_byte_width = vector_register_width(); + if (vector_byte_width == 0) return std::make_pair(nullptr, false); + auto ptr = registers()->find(std::string(name)); + if (ptr != registers()->end()) + return std::make_pair(static_cast<RVVectorRegister *>(ptr->second), + false); + // Create a new register and return a pointer to the object. + return std::make_pair( + AddRegister<RVVectorRegister>(name, vector_byte_width), true); + } + // Add register alias. template <typename RegisterType> absl::Status AddRegisterAlias(absl::string_view current_name, @@ -228,6 +252,7 @@ // Debug memory methods. void DbgLoadMemory(uint64_t address, DataBuffer *db); + void DbgStoreMemory(uint64_t address, DataBuffer *db); // Called by the fence instruction semantic function to signal a fence // operation. 
void Fence(const Instruction *inst, int fm, int predecessor, int successor); @@ -341,6 +366,12 @@ on_trap_ = std::move(callback); } + RiscVFPState *rv_fp() { return rv_fp_; } + void set_rv_fp(RiscVFPState *rv_fp) { rv_fp_ = rv_fp; } + CheriotVectorState *rv_vector() { return rv_vector_; } + void set_rv_vector(CheriotVectorState *rv_vector) { rv_vector_ = rv_vector; } + void set_vector_register_width(int value) { vector_register_width_ = value; } + int vector_register_width() const { return vector_register_width_; } RiscVMStatus *mstatus() { return mstatus_; } RiscVMIsa *misa() { return misa_; } RiscVMIp *mip() { return mip_; } @@ -382,6 +413,7 @@ CheriotRegister *pcc_ = nullptr; CheriotRegister *cgp_ = nullptr; bool branch_ = false; + int vector_register_width_ = 0; uint64_t max_physical_address_; uint64_t min_physical_address_ = 0; int num_tags_per_load_; @@ -396,6 +428,8 @@ absl::AnyInvocable<bool(const Instruction *)> on_wfi_; absl::AnyInvocable<bool(const Instruction *)> on_cease_; std::vector<RiscVCsrInterface *> csr_vec_; + RiscVFPState *rv_fp_ = nullptr; + CheriotVectorState *rv_vector_ = nullptr; // For interrupt handling. bool is_interrupt_available_ = false; int interrupt_handler_depth_ = 0;
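A minimal sketch of how the new vector hooks above fit together (the 64-byte register width and the "v0" register name are illustrative assumptions; CheriotVectorState is defined in cheriot_vector_state.h below):

  // Assumes an already constructed CheriotState named state.
  auto *rvv = new CheriotVectorState(&state, /*byte_length=*/64);
  // The CheriotVectorState constructor calls state.set_rv_vector(rvv) and
  // state.set_vector_register_width(64), so the GetRegister specialization
  // above can lazily create vector registers of the right width:
  auto *v0 = state.GetRegister<RVVectorRegister>("v0").first;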
diff --git a/cheriot/cheriot_test_rig_decoder.cc b/cheriot/cheriot_test_rig_decoder.cc index c1ff08f..3a77023 100644 --- a/cheriot/cheriot_test_rig_decoder.cc +++ b/cheriot/cheriot_test_rig_decoder.cc
@@ -17,9 +17,6 @@ #include <cstdint> #include <string> -#include "absl/log/log.h" -#include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" #include "cheriot/cheriot_state.h" #include "cheriot/riscv_cheriot_bin_decoder.h" #include "cheriot/riscv_cheriot_decoder.h" @@ -171,9 +168,9 @@ rs2 = 0; break; case FormatEnum::kCI: // 2 reg operands: rd, rs1. - // cnop, caddi, cli, caddi16sp, clui, cslli, clwsp, cldsp. + // cnop, caddi, cli, caddi16sp, clui, cslli, clwsp, clcsp. rd = encoding::c_i::ExtractRd(inst_word16); - if ((opcode == OpcodeEnum::kClwsp) || (opcode == OpcodeEnum::kCldsp)) { + if ((opcode == OpcodeEnum::kClwsp) || (opcode == OpcodeEnum::kClcsp)) { rs1 = 2; } else { rs1 = encoding::c_i::ExtractRs1(inst_word16);
diff --git a/cheriot/cheriot_top.cc b/cheriot/cheriot_top.cc index c357f0e..243d112 100644 --- a/cheriot/cheriot_top.cc +++ b/cheriot/cheriot_top.cc
@@ -31,10 +31,8 @@ #include "absl/strings/str_format.h" #include "absl/synchronization/notification.h" #include "cheriot/cheriot_debug_interface.h" -#include "cheriot/cheriot_decoder.h" #include "cheriot/cheriot_register.h" #include "cheriot/cheriot_state.h" -#include "cheriot/riscv_cheriot_enums.h" #include "cheriot/riscv_cheriot_register_aliases.h" #include "mpact/sim/generic/action_point_manager_base.h" #include "mpact/sim/generic/breakpoint_manager.h" @@ -59,8 +57,6 @@ namespace sim { namespace cheriot { -constexpr char kCheriotName[] = "CherIoT"; - using ::mpact::sim::generic::ActionPointManagerBase; using ::mpact::sim::generic::BreakpointManager; using ::mpact::sim::riscv::RiscVActionPointMemoryInterface; @@ -68,7 +64,7 @@ using PB = ::mpact::sim::cheriot::CheriotRegister::PermissionBits; CheriotTop::CheriotTop(std::string name, CheriotState *state, - CheriotDecoder *decoder) + DecoderInterface *decoder) : Component(name), state_(state), cheriot_decoder_(decoder),
diff --git a/cheriot/cheriot_top.h b/cheriot/cheriot_top.h index 6c0f5dd..660d188 100644 --- a/cheriot/cheriot_top.h +++ b/cheriot/cheriot_top.h
@@ -28,7 +28,6 @@ #include "absl/status/statusor.h" #include "absl/synchronization/notification.h" #include "cheriot/cheriot_debug_interface.h" -#include "cheriot/cheriot_decoder.h" #include "cheriot/cheriot_register.h" #include "cheriot/cheriot_state.h" #include "mpact/sim/generic/action_point_manager_base.h" @@ -51,6 +50,7 @@ using ::mpact::sim::generic::ActionPointManagerBase; using ::mpact::sim::generic::BreakpointManager; +using ::mpact::sim::generic::DecoderInterface; using ::mpact::sim::riscv::RiscVActionPointMemoryInterface; struct BranchTraceEntry { @@ -68,7 +68,7 @@ using RunStatus = generic::CoreDebugInterface::RunStatus; using HaltReason = generic::CoreDebugInterface::HaltReason; - CheriotTop(std::string name, CheriotState *state, CheriotDecoder *decoder); + CheriotTop(std::string name, CheriotState *state, DecoderInterface *decoder); ~CheriotTop() override; // Methods inherited from CoreDebugInterface.
diff --git a/cheriot/cheriot_vector_state.cc b/cheriot/cheriot_vector_state.cc new file mode 100644 index 0000000..f5fc5a4 --- /dev/null +++ b/cheriot/cheriot_vector_state.cc
@@ -0,0 +1,193 @@
+// Copyright 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cheriot/cheriot_vector_state.h"
+
+#include <cstdint>
+
+#include "absl/log/log.h"
+#include "cheriot/cheriot_state.h"
+#include "riscv//riscv_csr.h"
+
+namespace mpact {
+namespace sim {
+namespace cheriot {
+namespace {
+
+constexpr char kVlName[] = "vl";
+constexpr uint32_t kVlReadMask = 0xffff'ffff;
+constexpr uint32_t kVlWriteMask = 0;
+constexpr uint32_t kVlInitial = 0;
+
+constexpr char kVtypeName[] = "vtype";
+constexpr uint32_t kVtypeReadMask = 0xffff'ffff;
+constexpr uint32_t kVtypeWriteMask = 0;
+constexpr uint32_t kVtypeInitial = 0;
+
+constexpr char kVlenbName[] = "vlenb";
+constexpr uint32_t kVlenbReadMask = 0xffff'ffff;
+constexpr uint32_t kVlenbWriteMask = 0;
+
+constexpr char kVstartName[] = "vstart";
+constexpr uint32_t kVstartReadMask = 0xffff'ffff;
+constexpr uint32_t kVstartWriteMask = 0;
+constexpr uint32_t kVstartInitial = 0;
+
+constexpr char kVxsatName[] = "vxsat";
+constexpr uint32_t kVxsatReadMask = 1;
+constexpr uint32_t kVxsatWriteMask = 1;
+constexpr uint32_t kVxsatInitial = 0;
+
+constexpr char kVxrmName[] = "vxrm";
+constexpr uint32_t kVxrmReadMask = 3;
+constexpr uint32_t kVxrmWriteMask = 3;
+constexpr uint32_t kVxrmInitial = 0;
+
+constexpr char kVcsrName[] = "vcsr";
+constexpr uint32_t kVcsrReadMask = 7;
+constexpr uint32_t kVcsrWriteMask = 7;
+constexpr uint32_t kVcsrInitial = 0;
+
+// Helper function that logs the message of a non-OK status, to avoid
+// repeating that code for each AddCsr call below.
+static inline void LogIfError(absl::Status status) {
+  if (status.ok()) return;
+  LOG(ERROR) << status.message();
+}
+
+} // namespace
+
+using ::mpact::sim::riscv::RiscVCsrEnum;
+using ::mpact::sim::riscv::RiscVSimpleCsr;
+
+CheriotVl::CheriotVl(CheriotVectorState* vector_state)
+    : RiscVSimpleCsr<uint32_t>(kVlName, RiscVCsrEnum::kVl, kVlInitial,
+                               kVlReadMask, kVlWriteMask,
+                               vector_state->state()),
+      vector_state_(vector_state) {}
+
+uint32_t CheriotVl::AsUint32() { return vector_state_->vector_length(); }
+
+CheriotVtype::CheriotVtype(CheriotVectorState* vector_state)
+    : RiscVSimpleCsr<uint32_t>(kVtypeName, RiscVCsrEnum::kVtype, kVtypeInitial,
+                               kVtypeReadMask, kVtypeWriteMask,
+                               vector_state->state()),
+      vector_state_(vector_state) {}
+
+uint32_t CheriotVtype::AsUint32() { return vector_state_->vtype(); }
+
+CheriotVstart::CheriotVstart(CheriotVectorState* vector_state)
+    : RiscVSimpleCsr<uint32_t>(kVstartName, RiscVCsrEnum::kVstart,
+                               kVstartInitial, kVstartReadMask,
+                               kVstartWriteMask, vector_state->state()),
+      vector_state_(vector_state) {}
+
+uint32_t CheriotVstart::AsUint32() { return vector_state_->vstart(); }
+
+void CheriotVstart::Write(uint32_t value) { vector_state_->set_vstart(value); }
+
+CheriotVxsat::CheriotVxsat(CheriotVectorState* vector_state)
+    : RiscVSimpleCsr<uint32_t>(kVxsatName, RiscVCsrEnum::kVxsat, kVxsatInitial,
+                               kVxsatReadMask, kVxsatWriteMask,
+                               vector_state->state()),
+      vector_state_(vector_state) {}
+
+uint32_t CheriotVxsat::AsUint32() { return vector_state_->vxsat() ? 1 : 0; }
+
+void CheriotVxsat::Write(uint32_t value) {
+  vector_state_->set_vxsat(value & 1);
+}
+
+CheriotVxrm::CheriotVxrm(CheriotVectorState* vector_state)
+    : RiscVSimpleCsr<uint32_t>(kVxrmName, RiscVCsrEnum::kVxrm, kVxrmInitial,
+                               kVxrmReadMask, kVxrmWriteMask,
+                               vector_state->state()),
+      vector_state_(vector_state) {}
+
+uint32_t CheriotVxrm::AsUint32() { return vector_state_->vxrm(); }
+
+void CheriotVxrm::Write(uint32_t value) {
+  vector_state_->set_vxrm(value & kVxrmWriteMask);
+}
+
+CheriotVcsr::CheriotVcsr(CheriotVectorState* vector_state)
+    : RiscVSimpleCsr<uint32_t>(kVcsrName, RiscVCsrEnum::kVcsr, kVcsrInitial,
+                               kVcsrReadMask, kVcsrWriteMask,
+                               vector_state->state()),
+      vector_state_(vector_state) {}
+
+uint32_t CheriotVcsr::AsUint32() {
+  const uint32_t vxrm_shifted = (vector_state_->vxrm() & kVxrmWriteMask) << 1;
+  const uint32_t vxsat = vector_state_->vxsat() ? 1 : 0;
+  return vxrm_shifted | vxsat;
+}
+
+void CheriotVcsr::Write(uint32_t value) {
+  const uint32_t vxrm = (value >> 1) & kVxrmWriteMask;
+  const uint32_t vxsat = value & 1;
+  vector_state_->set_vxrm(vxrm);
+  vector_state_->set_vxsat(vxsat);
+}
+
+// Constructor for the vector state class. Takes the parent CHERIoT state and
+// the vector register length in bytes.
+CheriotVectorState::CheriotVectorState(CheriotState* state, int byte_length)
+    : vector_register_byte_length_(byte_length),
+      vl_csr_(this),
+      vtype_csr_(this),
+      vlenb_csr_(kVlenbName, RiscVCsrEnum::kVlenb, vector_register_byte_length_,
+                 kVlenbReadMask, kVlenbWriteMask, state),
+      vstart_csr_(this),
+      vxsat_csr_(this),
+      vxrm_csr_(this),
+      vcsr_csr_(this) {
+  state_ = state;
+  state->set_rv_vector(this);
+  state->set_vector_register_width(byte_length);
+
+  LogIfError(state->csr_set()->AddCsr(&vl_csr_));
+  LogIfError(state->csr_set()->AddCsr(&vtype_csr_));
+  LogIfError(state->csr_set()->AddCsr(&vlenb_csr_));
+  LogIfError(state->csr_set()->AddCsr(&vstart_csr_));
+  LogIfError(state->csr_set()->AddCsr(&vxsat_csr_));
+  LogIfError(state->csr_set()->AddCsr(&vxrm_csr_));
+  LogIfError(state->csr_set()->AddCsr(&vcsr_csr_));
+}
+
+// This function parses the vector type, as used in the vset* instructions,
+// and sets the internal vector state accordingly.
+void CheriotVectorState::SetVectorType(uint32_t vtype) {
+  static const int lmul8_values[8] = {8, 16, 32, 64, 0, 1, 2, 4};
+  static const int sew_values[8] = {8, 16, 32, 64, 0, 0, 0, 0};
+  set_vtype(vtype);
+  // The vtype value is divided into the following fields:
+  //   [2..0]: vector length multiplier.
+  //   [5..3]: element width specifier.
+  //   [6]:    vector tail agnostic bit.
+  //   [7]:    vector mask agnostic bit.
+  // Extract the lmul.
+  set_vector_length_multiplier(lmul8_values[(vtype & 0b111)]);
+  // Extract the sew and convert from bits to bytes.
+  set_selected_element_width(sew_values[(vtype >> 3) & 0b111] >> 3);
+  // Extract the tail and mask agnostic flags.
+  set_vector_tail_agnostic(static_cast<bool>((vtype >> 6) & 0b1));
+  set_vector_mask_agnostic(static_cast<bool>((vtype >> 7) & 0b1));
+  // Compute the new max vector length.
+  max_vector_length_ = vector_register_byte_length() *
+                       vector_length_multiplier() /
+                       (8 * selected_element_width());
+}
+
+} // namespace cheriot
+} // namespace sim
+} // namespace mpact
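As a worked example of SetVectorType (a 64-byte vector register length is chosen arbitrarily): vtype = 0b0'0'010'001 selects lmul8_values[0b001] = 16 (LMUL = 2) and sew_values[0b010] = 32 bits, i.e. 4 bytes, so max_vector_length_ = 64 * 16 / (8 * 4) = 32, which matches VLMAX = (VLEN / SEW) * LMUL = (512 / 32) * 2.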
diff --git a/cheriot/cheriot_vector_state.h b/cheriot/cheriot_vector_state.h new file mode 100644 index 0000000..a2898ea --- /dev/null +++ b/cheriot/cheriot_vector_state.h
@@ -0,0 +1,198 @@
+// Copyright 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MPACT_CHERIOT_CHERIOT_RISCV_VECTOR_STATE_H_
+#define MPACT_CHERIOT_CHERIOT_RISCV_VECTOR_STATE_H_
+
+#include <cstdint>
+
+#include "riscv//riscv_csr.h"
+
+// This file contains the definition of the vector state class. This class
+// is used by the vector instructions to obtain information about the state
+// and configuration of the vector unit. This class is also used to provide
+// values that are read from CSRs, and updated by values written to CSRs.
+
+namespace mpact {
+namespace sim {
+namespace cheriot {
+
+using ::mpact::sim::riscv::RiscVSimpleCsr;
+
+class CheriotState;
+class CheriotVectorState;
+
+// Implementation of the 'vl' CSR.
+class CheriotVl : public RiscVSimpleCsr<uint32_t> {
+ public:
+  explicit CheriotVl(CheriotVectorState* vector_state);
+
+  // Overrides. Note that this CSR is read-only.
+  uint32_t AsUint32() override;
+  uint64_t AsUint64() override { return static_cast<uint64_t>(AsUint32()); }
+
+ private:
+  const CheriotVectorState* const vector_state_;
+};
+
+// Implementation of the 'vtype' CSR.
+class CheriotVtype : public RiscVSimpleCsr<uint32_t> {
+ public:
+  explicit CheriotVtype(CheriotVectorState* vector_state);
+
+  // Overrides. Note that this CSR is read-only.
+  uint32_t AsUint32() override;
+  uint64_t AsUint64() override { return static_cast<uint64_t>(AsUint32()); }
+
+ private:
+  const CheriotVectorState* const vector_state_;
+};
+
+// Implementation of the 'vstart' CSR.
+class CheriotVstart : public RiscVSimpleCsr<uint32_t> {
+ public:
+  explicit CheriotVstart(CheriotVectorState* vector_state);
+
+  // Overrides.
+  uint32_t AsUint32() override;
+  uint64_t AsUint64() override { return static_cast<uint64_t>(AsUint32()); }
+  void Write(uint32_t value) override;
+  void Write(uint64_t value) override { Write(static_cast<uint32_t>(value)); }
+
+ private:
+  CheriotVectorState* const vector_state_;
+};
+
+// Implementation of the 'vxsat' CSR.
+class CheriotVxsat : public RiscVSimpleCsr<uint32_t> {
+ public:
+  explicit CheriotVxsat(CheriotVectorState* vector_state);
+
+  // Overrides.
+  uint32_t AsUint32() override;
+  uint64_t AsUint64() override { return static_cast<uint64_t>(AsUint32()); }
+  void Write(uint32_t value) override;
+  void Write(uint64_t value) override { Write(static_cast<uint32_t>(value)); }
+
+ private:
+  CheriotVectorState* const vector_state_;
+};
+
+// Implementation of the 'vxrm' CSR.
+class CheriotVxrm : public RiscVSimpleCsr<uint32_t> {
+ public:
+  explicit CheriotVxrm(CheriotVectorState* vector_state);
+
+  // Overrides.
+  uint32_t AsUint32() override;
+  uint64_t AsUint64() override { return static_cast<uint64_t>(AsUint32()); }
+  void Write(uint32_t value) override;
+  void Write(uint64_t value) override { Write(static_cast<uint32_t>(value)); }
+
+ private:
+  CheriotVectorState* const vector_state_;
+};
+
+// Implementation of the 'vcsr' CSR. This CSR mirrors the bits in 'vxsat' and
+// 'vxrm' as follows:
+//
+//   bits 2:1 - vxrm
+//   bits 0:0 - vxsat
+class CheriotVcsr : public RiscVSimpleCsr<uint32_t> {
+ public:
+  explicit CheriotVcsr(CheriotVectorState* vector_state);
+
+  // Overrides.
+  uint32_t AsUint32() override;
+  uint64_t AsUint64() override { return static_cast<uint64_t>(AsUint32()); }
+  void Write(uint32_t value) override;
+  void Write(uint64_t value) override { Write(static_cast<uint32_t>(value)); }
+
+ private:
+  CheriotVectorState* const vector_state_;
+};
+
+class CheriotVectorState {
+ public:
+  CheriotVectorState(CheriotState* state, int byte_length);
+
+  void SetVectorType(uint32_t vtype);
+
+  // Public getters and setters.
+  int vstart() const { return vstart_; }
+  void clear_vstart() { vstart_ = 0; }
+  void set_vstart(int value) { vstart_ = value; }
+  int vector_length() const { return vector_length_; }
+  void set_vector_length(int value) { vector_length_ = value; }
+  bool vector_tail_agnostic() const { return vector_tail_agnostic_; }
+  bool vector_mask_agnostic() const { return vector_mask_agnostic_; }
+  int vector_length_multiplier() const { return vector_length_multiplier_; }
+  int selected_element_width() const { return selected_element_width_; }
+  bool vector_exception() const { return vector_exception_; }
+  void clear_vector_exception() { vector_exception_ = false; }
+  void set_vector_exception() { vector_exception_ = true; }
+  uint32_t vtype() const { return vtype_; }
+  void set_vtype(uint32_t value) { vtype_ = value; }
+  int vector_register_byte_length() const {
+    return vector_register_byte_length_;
+  }
+  int max_vector_length() const { return max_vector_length_; }
+  bool vxsat() const { return vxsat_; }
+  void set_vxsat(bool value) { vxsat_ = value; }
+  int vxrm() const { return vxrm_; }
+  void set_vxrm(int value) { vxrm_ = value & 0x3; }
+
+  const CheriotState* state() const { return state_; }
+  CheriotState* state() { return state_; }
+
+ private:
+  // The vector length multiplier is scaled by 8 so that the values 1/8, 1/4,
+  // 1/2, 1, 2, 4, 8 have the integer representation 1, 2, 4, 8, 16, 32, 64.
+  void set_vector_length_multiplier(int value) {
+    vector_length_multiplier_ = value;
+  }
+  void set_selected_element_width(int value) {
+    selected_element_width_ = value;
+  }
+  void set_vector_tail_agnostic(bool value) { vector_tail_agnostic_ = value; }
+  void set_vector_mask_agnostic(bool value) { vector_mask_agnostic_ = value; }
+
+  CheriotState* state_ = nullptr;
+  uint32_t vtype_ = 0;
+  bool vector_exception_ = false;
+  int vector_register_byte_length_ = 0;
+  int vstart_ = 0;
+  int max_vector_length_ = 0;
+  int vector_length_ = 0;
+  int vector_length_multiplier_ = 8;
+  // Selected element width (SEW) in bytes.
+  int selected_element_width_ = 1;
+  bool vector_tail_agnostic_ = false;
+  bool vector_mask_agnostic_ = false;
+  bool vxsat_ = false;
+  int vxrm_ = 0;
+
+  CheriotVl vl_csr_;
+  CheriotVtype vtype_csr_;
+  RiscVSimpleCsr<uint32_t> vlenb_csr_;
+  CheriotVstart vstart_csr_;
+  CheriotVxsat vxsat_csr_;
+  CheriotVxrm vxrm_csr_;
+  CheriotVcsr vcsr_csr_;
+};
+
+} // namespace cheriot
+} // namespace sim
+} // namespace mpact
+#endif // MPACT_CHERIOT_CHERIOT_RISCV_VECTOR_STATE_H_
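For example, writing 0b101 to vcsr sets vxrm = 0b10 and vxsat = 1 through the two setters, and a subsequent vcsr read reassembles 0b101 ((0b10 << 1) | 1) from the mirrored fields.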
diff --git a/cheriot/cheriot_vector_true_operand.cc b/cheriot/cheriot_vector_true_operand.cc new file mode 100644 index 0000000..549020d --- /dev/null +++ b/cheriot/cheriot_vector_true_operand.cc
@@ -0,0 +1,40 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cheriot/cheriot_vector_true_operand.h"
+
+#include <cstdint>
+#include <limits>
+
+#include "cheriot/cheriot_state.h"
+#include "riscv//riscv_register.h"
+
+namespace mpact {
+namespace sim {
+namespace cheriot {
+
+CheriotVectorTrueOperand::CheriotVectorTrueOperand(CheriotState *state)
+    : RV32VectorSourceOperand(
+          state->GetRegister<RVVectorRegister>(kName).first) {
+  // Ensure the register value is all ones.
+  auto *reg = state->GetRegister<RVVectorRegister>(kName).first;
+  auto data = reg->data_buffer()->Get<uint64_t>();
+  for (auto &value : data) {
+    value = std::numeric_limits<uint64_t>::max();
+  }
+}
+
+} // namespace cheriot
+} // namespace sim
+} // namespace mpact
diff --git a/cheriot/cheriot_vector_true_operand.h b/cheriot/cheriot_vector_true_operand.h new file mode 100644 index 0000000..2a880c1 --- /dev/null +++ b/cheriot/cheriot_vector_true_operand.h
@@ -0,0 +1,57 @@
+// Copyright 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MPACT_CHERIOT_CHERIOT_VECTOR_TRUE_OPERAND_H_
+#define MPACT_CHERIOT_CHERIOT_VECTOR_TRUE_OPERAND_H_
+
+#include <cstdint>
+#include <string>
+
+#include "riscv//riscv_register.h"
+
+// This file defines a Cheriot version of the RV32VectorTrueOperand source
+// operand.
+
+namespace mpact {
+namespace sim {
+namespace cheriot {
+
+using ::mpact::sim::riscv::RV32VectorSourceOperand;
+
+class CheriotState;
+
+class CheriotVectorTrueOperand : public RV32VectorSourceOperand {
+ public:
+  explicit CheriotVectorTrueOperand(CheriotState *state);
+
+  CheriotVectorTrueOperand() = delete;
+  // Every element of this operand reads as all ones.
+  bool AsBool(int) final { return true; }
+  int8_t AsInt8(int) final { return 0xff; }
+  uint8_t AsUint8(int) final { return 0xff; }
+  int16_t AsInt16(int) final { return 0xffff; }
+  uint16_t AsUint16(int) final { return 0xffff; }
+  int32_t AsInt32(int) final { return 0xffff'ffff; }
+  uint32_t AsUint32(int) final { return 0xffff'ffff; }
+  int64_t AsInt64(int) final { return 0xffff'ffff'ffff'ffffULL; }
+  uint64_t AsUint64(int) final { return 0xffff'ffff'ffff'ffffULL; }
+  std::string AsString() const override { return ""; }
+
+ private:
+  static constexpr char kName[] = "__VectorTrue__";
+};
+
+} // namespace cheriot
+} // namespace sim
+} // namespace mpact
+
+#endif // MPACT_CHERIOT_CHERIOT_VECTOR_TRUE_OPERAND_H_
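This operand is what the kVmask getter in cheriot_rvv_getters.h above returns for unmasked (vm == 1) encodings: every element reads as all ones, so instruction semantic functions can treat unmasked execution as a mask that enables every lane, without a separate unmasked code path.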
diff --git a/cheriot/mpact_cheriot.cc b/cheriot/mpact_cheriot.cc index 9174c63..a15cfee 100644 --- a/cheriot/mpact_cheriot.cc +++ b/cheriot/mpact_cheriot.cc
@@ -41,11 +41,14 @@
 #include "absl/time/time.h"
 #include "cheriot/cheriot_decoder.h"
 #include "cheriot/cheriot_instrumentation_control.h"
+#include "cheriot/cheriot_rvv_decoder.h"
+#include "cheriot/cheriot_rvv_fp_decoder.h"
 #include "cheriot/cheriot_top.h"
 #include "cheriot/debug_command_shell.h"
 #include "cheriot/riscv_cheriot_minstret.h"
 #include "mpact/sim/generic/core_debug_interface.h"
 #include "mpact/sim/generic/counters.h"
+#include "mpact/sim/generic/decoder_interface.h"
 #include "mpact/sim/generic/instruction.h"
 #include "mpact/sim/proto/component_data.pb.h"
 #include "mpact/sim/util/memory/atomic_memory.h"
@@ -67,7 +70,10 @@
 using AddressRange = mpact::sim::util::MemoryWatcher::AddressRange;
 using ::mpact::sim::cheriot::CheriotDecoder;
 using ::mpact::sim::cheriot::CheriotInstrumentationControl;
+using ::mpact::sim::cheriot::CheriotRVVDecoder;
+using ::mpact::sim::cheriot::CheriotRVVFPDecoder;
 using ::mpact::sim::cheriot::CheriotState;
+using ::mpact::sim::generic::DecoderInterface;
 using ::mpact::sim::proto::ComponentData;
 using ::mpact::sim::util::InstructionProfiler;
 using ::mpact::sim::util::TaggedMemoryUseProfiler;
@@ -137,6 +143,12 @@
 // Enable memory use profiling.
 ABSL_FLAG(bool, mem_profile, false, "Enable memory use profiling");
 
+// Enable RiscV Vector instructions.
+ABSL_FLAG(bool, rvv, false, "Enable RVV");
+
+// Enable RiscV Vector instructions + FP.
+ABSL_FLAG(bool, rvv_fp, false, "Enable RVV + FP");
+
 constexpr char kStackEndSymbolName[] = "__stack_end";
 constexpr char kStackSizeSymbolName[] = "__stack_size";
@@ -256,10 +268,20 @@
   }
   CheriotState cheriot_state("CherIoT", data_memory,
                              static_cast<AtomicMemoryOpInterface *>(router));
-  CheriotDecoder cheriot_decoder(&cheriot_state,
-                                 static_cast<MemoryInterface *>(router));
-  CheriotTop cheriot_top("Cheriot", &cheriot_state, &cheriot_decoder);
+  DecoderInterface *decoder = nullptr;
+  if (absl::GetFlag(FLAGS_rvv_fp)) {
+    decoder = new CheriotRVVFPDecoder(&cheriot_state,
+                                      static_cast<MemoryInterface *>(router));
+  } else if (absl::GetFlag(FLAGS_rvv)) {
+    decoder = new CheriotRVVDecoder(&cheriot_state,
+                                    static_cast<MemoryInterface *>(router));
+  } else {
+    decoder = new CheriotDecoder(&cheriot_state,
+                                 static_cast<MemoryInterface *>(router));
+  }
+
+  CheriotTop cheriot_top("Cheriot", &cheriot_state, decoder);
 
   // Enable instruction profiling if the flag is set.
   InstructionProfiler *inst_profiler = nullptr;
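With these flags, the decoder selection above can be exercised as follows (hypothetical command lines, assuming the binary keeps the mpact_cheriot name):

  mpact_cheriot --rvv program.elf     # scalar + vector decoder
  mpact_cheriot --rvv_fp program.elf  # scalar + vector + FP decoder

With neither flag set, the original CheriotDecoder is used. Since --rvv_fp is tested first, it takes precedence when both flags are given.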
diff --git a/cheriot/riscv_cheriot.bin_fmt b/cheriot/riscv_cheriot.bin_fmt index db6799c..0d2cdf2 100644 --- a/cheriot/riscv_cheriot.bin_fmt +++ b/cheriot/riscv_cheriot.bin_fmt
@@ -1,3 +1,17 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + // RiscV 32 bit CHERIoT instruction decoder. decoder RiscVCheriot { namespace mpact::sim::cheriot::encoding; @@ -435,9 +449,9 @@ instruction group RiscVCheriotInst16[16] : Inst16Format { caddi4spn : CIW: func3 == 0b000, op == 0b00, imm8 != 0; clw : CL : func3 == 0b010, op == 0b00; - cld : CL : func3 == 0b011, op == 0b00; + clc : CL : func3 == 0b011, op == 0b00; csw : CS : func3 == 0b110, op == 0b00; - csd : CS : func3 == 0b111, op == 0b00; + csc : CS : func3 == 0b111, op == 0b00; cnop : CI : func3 == 0b000, imm1 == 0, rs1 == 0, imm5 == 0, op == 0b01; chint : CI : func3 == 0b000, imm6 != 0, rs1 == 0, op == 0b01; caddi : CI : func3 == 0b000, imm6 != 0, rd != 0, op == 0b01; @@ -465,13 +479,13 @@ chint : CI : func3 == 0b000, imm1 == 0, rs1 == 0, imm5 != 0, op == 0b10; chint : CI : func3 == 0b000, imm6 == 0, op == 0b10; clwsp : CI : func3 == 0b010, rd != 0, op == 0b10; - cldsp : CI : func3 == 0b011, rd != 0, op == 0b10; + clcsp : CI : func3 == 0b011, rd != 0, op == 0b10; cmv : CR : func4 == 0b1000, rs1 != 0, rs2 != 0, op == 0b10; cebreak : Inst16Format : func3 == 0b100, bits == 0b1'00000'00000, op == 0b10; cadd : CR : func4 == 0b1001, rs1 != 0, rs2 != 0, op == 0b10; chint : CR : func4 == 0b1001, rs1 == 0, rs2 != 0, op == 0b10; cswsp : CSS: func3 == 0b110, op == 0b10; - csdsp : CSS: func3 == 0b111, op == 0b10; + cscsp : CSS: func3 == 0b111, op == 0b10; cheriot_cj : CJ : func3 == 0b101, op == 0b01; cheriot_cjal : CJ : func3 == 0b001, op == 0b01; cheriot_cjr : CR : func4 == 0b1000, rs1 > 1, rs2 == 0, op == 0b10;
diff --git a/cheriot/riscv_cheriot.isa b/cheriot/riscv_cheriot.isa index fea34c3..5c14c15 100644 --- a/cheriot/riscv_cheriot.isa +++ b/cheriot/riscv_cheriot.isa
@@ -1,4 +1,18 @@ -// This file contains the ISA description for the RiscV32G architecture. +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file contains the ISA description for the Cheriot architecture. includes { #include "absl/functional/bind_front.h" @@ -459,11 +473,11 @@ clwsp{(: c2, I_ci_uimm6x4 : ), (: : rd)}, disasm: "c.lw", "%rd, %I_ci_uimm6x4(%c2)", semfunc: "&RiscVILw", "&RiscVILwChild"; - cflwsp{(: c2, I_ci_uimm6x4 : ), (: : frd)}, - disasm: "c.flw", "%frd, %I_ci_uimm6x4(%c2)", - semfunc: "&RiscVILw", "&RiscVILwChild"; + // cflwsp{(: c2, I_ci_uimm6x4 : ), (: : frd)}, + // disasm: "c.flw", "%frd, %I_ci_uimm6x4(%c2)", + // semfunc: "&RiscVILw", "&RiscVILwChild"; // Reused for clc - cldsp{(: c2, I_ci_uimm6x8 : ), (: : cd)}, + clcsp{(: c2, I_ci_uimm6x8 : ), (: : cd)}, disasm: "c.clc", "%cd, %I_ci_uimm6x8(%c2)", semfunc: "&CheriotCLc", "&CheriotCLcChild"; // cfldsp{(: x2, I_ci_uimm6x8 : ), (: : drd)}, @@ -472,11 +486,11 @@ cswsp{: c2, I_css_uimm6x4, crs2 : }, disasm: "c.csw", "%crs2, %I_css_uimm6x4(%c2)", semfunc: "&RiscVISw"; - cfswsp{: c2, I_css_uimm6x4, cfrs2 : }, - disasm: ".cfsw", "%cfrs2, %I_css_uimm6x4(%c2)", - semfunc: "&RiscVISw"; + // cfswsp{: c2, I_css_uimm6x4, cfrs2 : }, + // disasm: ".cfsw", "%cfrs2, %I_css_uimm6x4(%c2)", + // semfunc: "&RiscVISw"; // Reused for csc - csdsp{: c2, I_css_uimm6x8, ccs2 : }, + cscsp{: c2, I_css_uimm6x8, ccs2 : }, disasm: "c.csc", "%ccs2, %I_css_uimm6x8(%c2)", semfunc: "&CheriotCSc"; // cfsdsp{: x2, I_css_uimm6x8, rdrs2 : }, @@ -486,7 +500,7 @@ disasm: "c.clw", "%c3rd, %I_cl_uimm5x4(%c3rs1)", semfunc: "&RiscVILw", "&RiscVILwChild"; // Reused for clc - cld{(: c3cs1, I_cl_uimm5x8 : ), (: : c3cd)}, + clc{(: c3cs1, I_cl_uimm5x8 : ), (: : c3cd)}, disasm: "c.clc", "%c3cd, %I_cl_uimm5x8(%c3cs1)", semfunc: "&CheriotCLc", "&CheriotCLcChild"; // cfld{(: c3rs1, I_cl_uimm5x8 : ), (: : c3drd)}, @@ -496,7 +510,7 @@ disasm: "c.csw", "%c3cs2, %I_cl_uimm5x4(%c3rs1)", semfunc: "&RiscVISw"; // Reused for csc - csd{: c3cs1, I_cl_uimm5x8, c3cs2 : }, + csc{: c3cs1, I_cl_uimm5x8, c3cs2 : }, disasm: "c.csc", "%c3cs2, %I_cl_uimm5x8(%c3cs1)", semfunc: "&CheriotCSc"; // cfsd{: c3rs1, I_cl_uimm5x8, c3drs2 : },
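Note on the renames in the two files above: the compressed doubleword load/store encodings (the cld/csd and cldsp/csdsp slots, whose RV32 FP forms are now commented out) are repurposed as 8-byte capability loads and stores (clc/csc, clcsp/cscsp), consistent with the 8-byte capability footprint (kCapabilitySizeInBytes) defined in cheriot_state.h.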
diff --git a/cheriot/riscv_cheriot_encoding.cc b/cheriot/riscv_cheriot_encoding.cc index df9b3cd..be376d7 100644 --- a/cheriot/riscv_cheriot_encoding.cc +++ b/cheriot/riscv_cheriot_encoding.cc
@@ -15,484 +15,43 @@ #include "cheriot/riscv_cheriot_encoding.h" #include <cstdint> -#include <string> #include "absl/log/log.h" #include "absl/strings/str_cat.h" +#include "cheriot/cheriot_getters.h" #include "cheriot/cheriot_register.h" #include "cheriot/cheriot_state.h" #include "cheriot/riscv_cheriot_bin_decoder.h" #include "cheriot/riscv_cheriot_decoder.h" +#include "cheriot/riscv_cheriot_encoding_common.h" #include "cheriot/riscv_cheriot_enums.h" -#include "cheriot/riscv_cheriot_register_aliases.h" -#include "mpact/sim/generic/immediate_operand.h" -#include "mpact/sim/generic/literal_operand.h" -#include "mpact/sim/generic/type_helpers.h" -#include "riscv//riscv_register.h" namespace mpact { namespace sim { namespace cheriot { namespace isa32 { -using ::mpact::sim::generic::operator*; // NOLINT: is used below (clang error). -using ::mpact::sim::riscv::RVFpRegister; - -// Generic helper functions to create register operands. -template <typename RegType> -inline DestinationOperandInterface *GetRegisterDestinationOp( - CheriotState *state, std::string name, int latency) { - auto *reg = state->GetRegister<RegType>(name).first; - return reg->CreateDestinationOperand(latency); -} - -template <typename RegType> -inline DestinationOperandInterface *GetRegisterDestinationOp( - CheriotState *state, std::string name, int latency, std::string op_name) { - auto *reg = state->GetRegister<RegType>(name).first; - return reg->CreateDestinationOperand(latency, op_name); -} - -template <typename T> -inline DestinationOperandInterface *GetCSRSetBitsDestinationOp( - CheriotState *state, std::string name, int latency, std::string op_name) { - auto result = state->csr_set()->GetCsr(name); - if (!result.ok()) { - LOG(ERROR) << "No such CSR '" << name << "'"; - return nullptr; - } - auto *csr = result.value(); - auto *op = csr->CreateSetDestinationOperand(latency, op_name); - return op; -} - -template <typename RegType> -inline SourceOperandInterface *GetRegisterSourceOp(CheriotState *state, - std::string name) { - auto *reg = state->GetRegister<RegType>(name).first; - auto *op = reg->CreateSourceOperand(); - return op; -} - -template <typename RegType> -inline SourceOperandInterface *GetRegisterSourceOp(CheriotState *state, - std::string name, - std::string op_name) { - auto *reg = state->GetRegister<RegType>(name).first; - auto *op = reg->CreateSourceOperand(op_name); - return op; -} +using Extractors = ::mpact::sim::cheriot::encoding::Extractors; RiscVCheriotEncoding::RiscVCheriotEncoding(CheriotState *state) - : state_(state) { - InitializeSourceOperandGetters(); - InitializeDestinationOperandGetters(); -} - -void RiscVCheriotEncoding::InitializeSourceOperandGetters() { - // Source operand getters. 
- source_op_getters_.emplace( - *SourceOpEnum::kAAq, [this]() -> SourceOperandInterface * { - if (encoding::inst32_format::ExtractAq(inst_word_)) { - return new generic::IntLiteralOperand<1>(); - } - return new generic::IntLiteralOperand<0>(); - }); - source_op_getters_.emplace( - *SourceOpEnum::kARl, [this]() -> SourceOperandInterface * { - if (encoding::inst32_format::ExtractRl(inst_word_)) { - return new generic::IntLiteralOperand<1>(); - } - return new generic::IntLiteralOperand<0>(); - }); - source_op_getters_.emplace(*SourceOpEnum::kBImm12, [this]() { - return new generic::ImmediateOperand<int32_t>( - encoding::inst32_format::ExtractBImm(inst_word_)); - }); - source_op_getters_.emplace(*SourceOpEnum::kC2, [this]() { - return GetRegisterSourceOp<CheriotRegister>(state_, "c2", "csp"); - }); - source_op_getters_.emplace(*SourceOpEnum::kC3cs1, [this]() { - auto num = encoding::c_s::ExtractRs1(inst_word_); - return GetRegisterSourceOp<CheriotRegister>( - state_, absl::StrCat(CheriotState::kCregPrefix, num), - kCRegisterAliases[num]); - }); - source_op_getters_.emplace(*SourceOpEnum::kC3cs2, [this]() { - auto num = encoding::c_s::ExtractRs2(inst_word_); - return GetRegisterSourceOp<CheriotRegister>( - state_, absl::StrCat(CheriotState::kCregPrefix, num), - kCRegisterAliases[num]); - }); - source_op_getters_.emplace(*SourceOpEnum::kC3rs1, [this]() { - auto num = encoding::c_s::ExtractRs1(inst_word_); - return GetRegisterSourceOp<CheriotRegister>( - state_, absl::StrCat(CheriotState::kXregPrefix, num), - kXRegisterAliases[num]); - }); - source_op_getters_.emplace(*SourceOpEnum::kC3rs2, [this]() { - auto num = encoding::c_s::ExtractRs2(inst_word_); - return GetRegisterSourceOp<CheriotRegister>( - state_, absl::StrCat(CheriotState::kXregPrefix, num), - kXRegisterAliases[num]); - }); - source_op_getters_.emplace(*SourceOpEnum::kCcs2, [this]() { - auto num = encoding::c_s_s::ExtractRs2(inst_word_); - return GetRegisterSourceOp<CheriotRegister>( - state_, absl::StrCat(CheriotState::kCregPrefix, num), - kCRegisterAliases[num]); - }); - source_op_getters_.emplace(*SourceOpEnum::kCgp, [this]() { - return GetRegisterSourceOp<CheriotRegister>(state_, "c3", "c3"); - }); - source_op_getters_.emplace(*SourceOpEnum::kCSRUimm5, [this]() { - return new generic::ImmediateOperand<uint32_t>( - encoding::inst32_format::ExtractIUimm5(inst_word_)); - }); - source_op_getters_.emplace(*SourceOpEnum::kCfrs2, [this]() { - auto num = encoding::c_r::ExtractRs2(inst_word_); - return GetRegisterSourceOp<RVFpRegister>( - state_, absl::StrCat(CheriotState::kFregPrefix, num), - kFRegisterAliases[num]); - }); - source_op_getters_.emplace(*SourceOpEnum::kCrs1, [this]() { - auto num = encoding::c_r::ExtractRs1(inst_word_); - return GetRegisterSourceOp<CheriotRegister>( - state_, absl::StrCat(CheriotState::kXregPrefix, num), - kXRegisterAliases[num]); - }); - source_op_getters_.emplace(*SourceOpEnum::kCrs2, [this]() { - auto num = encoding::c_r::ExtractRs2(inst_word_); - return GetRegisterSourceOp<CheriotRegister>( - state_, absl::StrCat(CheriotState::kXregPrefix, num), - kXRegisterAliases[num]); - }); - source_op_getters_.emplace(*SourceOpEnum::kCs1, [this]() { - auto num = encoding::r_type::ExtractRs1(inst_word_); - return GetRegisterSourceOp<CheriotRegister>( - state_, absl::StrCat(CheriotState::kCregPrefix, num), - kCRegisterAliases[num]); - }); - source_op_getters_.emplace(*SourceOpEnum::kCs2, [this]() { - auto num = encoding::r_type::ExtractRs2(inst_word_); - return GetRegisterSourceOp<CheriotRegister>( - state_, 
absl::StrCat(CheriotState::kCregPrefix, num), - kCRegisterAliases[num]); - }); - source_op_getters_.emplace(*SourceOpEnum::kCsr, [this]() { - auto csr_indx = encoding::i_type::ExtractUImm12(inst_word_); - auto res = state_->csr_set()->GetCsr(csr_indx); - if (!res.ok()) { - return new generic::ImmediateOperand<uint32_t>(csr_indx); - } - auto *csr = res.value(); - return new generic::ImmediateOperand<uint32_t>(csr_indx, csr->name()); - }); - /* - source_op_getters_.emplace(*SourceOpEnum::kFrs1, [this]() { - int num = encoding::r_type::ExtractRs1(inst_word_); - return GetRegisterSourceOp<RVFpRegister>( - state_, absl::StrCat(CheriotState::kFregPrefix, num), - kFRegisterAliases[num]); - }); - source_op_getters_.emplace(*SourceOpEnum::kFrs2, [this]() { - int num = encoding::r_type::ExtractRs2(inst_word_); - return GetRegisterSourceOp<RVFpRegister>( - state_, absl::StrCat(CheriotState::kFregPrefix, num), - kFRegisterAliases[num]); - }); - source_op_getters_.emplace(*SourceOpEnum::kFrs3, [this]() { - int num = encoding::r4_type::ExtractRs3(inst_word_); - return GetRegisterSourceOp<RVFpRegister>( - state_, absl::StrCat(CheriotState::kFregPrefix, num), - kFRegisterAliases[num]); - }); - */ - source_op_getters_.emplace(*SourceOpEnum::kICbImm8, [this]() { - return new generic::ImmediateOperand<int32_t>( - encoding::inst16_format::ExtractBimm(inst_word_)); - }); - source_op_getters_.emplace(*SourceOpEnum::kICiImm6, [this]() { - return new generic::ImmediateOperand<int32_t>( - encoding::c_i::ExtractImm6(inst_word_)); - }); - source_op_getters_.emplace(*SourceOpEnum::kICiImm612, [this]() { - return new generic::ImmediateOperand<int32_t>( - encoding::inst16_format::ExtractImm18(inst_word_)); - }); - source_op_getters_.emplace(*SourceOpEnum::kICiUimm6, [this]() { - return new generic::ImmediateOperand<uint32_t>( - encoding::inst16_format::ExtractUimm6(inst_word_)); - }); - source_op_getters_.emplace(*SourceOpEnum::kICiUimm6x4, [this]() { - return new generic::ImmediateOperand<uint32_t>( - encoding::inst16_format::ExtractCiImmW(inst_word_)); - }); - source_op_getters_.emplace(*SourceOpEnum::kICiImm6x16, [this]() { - return new generic::ImmediateOperand<int32_t>( - encoding::inst16_format::ExtractCiImm10(inst_word_)); - }); - source_op_getters_.emplace(*SourceOpEnum::kICiUimm6x8, [this]() { - return new generic::ImmediateOperand<uint32_t>( - encoding::inst16_format::ExtractCiImmD(inst_word_)); - }); - source_op_getters_.emplace(*SourceOpEnum::kICiwUimm8x4, [this]() { - return new generic::ImmediateOperand<uint32_t>( - encoding::inst16_format::ExtractCiwImm10(inst_word_)); - }); - source_op_getters_.emplace(*SourceOpEnum::kICjImm11, [this]() { - return new generic::ImmediateOperand<int32_t>( - encoding::inst16_format::ExtractJimm(inst_word_)); - }); - source_op_getters_.emplace(*SourceOpEnum::kIClUimm5x4, [this]() { - return new generic::ImmediateOperand<uint32_t>( - encoding::inst16_format::ExtractClImmW(inst_word_)); - }); - source_op_getters_.emplace(*SourceOpEnum::kIClUimm5x8, [this]() { - return new generic::ImmediateOperand<uint32_t>( - encoding::inst16_format::ExtractClImmD(inst_word_)); - }); - source_op_getters_.emplace(*SourceOpEnum::kICshUimm6, [this]() { - return new generic::ImmediateOperand<uint32_t>( - encoding::c_s_h::ExtractUimm6(inst_word_)); - }); - source_op_getters_.emplace(*SourceOpEnum::kICshImm6, [this]() { - return new generic::ImmediateOperand<uint32_t>( - encoding::c_s_h::ExtractImm6(inst_word_)); - }); - source_op_getters_.emplace(*SourceOpEnum::kICssUimm6x4, [this]() { - return 
new generic::ImmediateOperand<uint32_t>( - encoding::inst16_format::ExtractCssImmW(inst_word_)); - }); - source_op_getters_.emplace(*SourceOpEnum::kICssUimm6x8, [this]() { - return new generic::ImmediateOperand<uint32_t>( - encoding::inst16_format::ExtractCssImmD(inst_word_)); - }); - source_op_getters_.emplace(*SourceOpEnum::kIImm12, [this]() { - return new generic::ImmediateOperand<int32_t>( - encoding::inst32_format::ExtractImm12(inst_word_)); - }); - source_op_getters_.emplace(*SourceOpEnum::kIUimm5, [this]() { - return new generic::ImmediateOperand<uint32_t>( - encoding::i5_type::ExtractRUimm5(inst_word_)); - }); - source_op_getters_.emplace(*SourceOpEnum::kIUimm12, [this]() { - return new generic::ImmediateOperand<uint32_t>( - encoding::inst32_format::ExtractUImm12(inst_word_)); - }); - source_op_getters_.emplace(*SourceOpEnum::kJImm12, [this]() { - return new generic::ImmediateOperand<int32_t>( - encoding::inst32_format::ExtractImm12(inst_word_)); - }); - source_op_getters_.emplace(*SourceOpEnum::kJImm20, [this]() { - return new generic::ImmediateOperand<int32_t>( - encoding::inst32_format::ExtractJImm(inst_word_)); - }); - source_op_getters_.emplace(*SourceOpEnum::kPcc, [this]() { - return GetRegisterSourceOp<CheriotRegister>(state_, "pcc", "pcc"); - }); - /* - source_op_getters_.emplace(*SourceOpEnum::kRm, - [this]() -> SourceOperandInterface * { - uint32_t rm = (inst_word_ >> 12) & 0x7; - switch (rm) { - case 0: - return new generic::IntLiteralOperand<0>(); - case 1: - return new generic::IntLiteralOperand<1>(); - case 2: - return new generic::IntLiteralOperand<2>(); - case 3: - return new generic::IntLiteralOperand<3>(); - case 4: - return new generic::IntLiteralOperand<4>(); - case 5: - return new generic::IntLiteralOperand<5>(); - case 6: - return new generic::IntLiteralOperand<6>(); - case 7: - return new generic::IntLiteralOperand<7>(); - default: - return nullptr; - } - }); - */ - source_op_getters_.emplace( - *SourceOpEnum::kRd, [this]() -> SourceOperandInterface * { - int num = encoding::r_type::ExtractRd(inst_word_); - if (num == 0) return new generic::IntLiteralOperand<0>({1}); - return GetRegisterSourceOp<CheriotRegister>( - state_, absl::StrCat(CheriotState::kXregPrefix, num), - kXRegisterAliases[num]); - }); - source_op_getters_.emplace( - *SourceOpEnum::kRs1, [this]() -> SourceOperandInterface * { - int num = encoding::r_type::ExtractRs1(inst_word_); - if (num == 0) return new generic::IntLiteralOperand<0>({1}); - return GetRegisterSourceOp<CheriotRegister>( - state_, absl::StrCat(CheriotState::kXregPrefix, num), - kXRegisterAliases[num]); - }); - source_op_getters_.emplace( - *SourceOpEnum::kRs2, [this]() -> SourceOperandInterface * { - int num = encoding::r_type::ExtractRs2(inst_word_); - if (num == 0) return new generic::IntLiteralOperand<0>({1}); - return GetRegisterSourceOp<CheriotRegister>( - state_, absl::StrCat(CheriotState::kXregPrefix, num), - kXRegisterAliases[num]); - }); - source_op_getters_.emplace(*SourceOpEnum::kSImm12, [this]() { - return new generic::ImmediateOperand<int32_t>( - encoding::s_type::ExtractSImm(inst_word_)); - }); - source_op_getters_.emplace( - *SourceOpEnum::kScr, [this]() -> SourceOperandInterface * { - int csr_indx = encoding::r_type::ExtractRs2(inst_word_); - std::string csr_name; - switch (csr_indx) { - case 28: - csr_name = "mtcc"; - break; - case 29: - csr_name = "mtdc"; - break; - case 30: - csr_name = "mscratchc"; - break; - case 31: - csr_name = "mepcc"; - break; - default: - return nullptr; - } - auto res = 
state_->csr_set()->GetCsr(csr_name); - if (!res.ok()) { - return GetRegisterSourceOp<CheriotRegister>(state_, csr_name, - csr_name); - } - auto *csr = res.value(); - auto *op = csr->CreateSourceOperand(); - return op; - }); - source_op_getters_.emplace(*SourceOpEnum::kSImm20, [this]() { - return new generic::ImmediateOperand<int32_t>( - encoding::u_type::ExtractSImm(inst_word_)); - }); - source_op_getters_.emplace(*SourceOpEnum::kUImm20, [this]() { - return new generic::ImmediateOperand<int32_t>( - encoding::inst32_format::ExtractUImm(inst_word_)); - }); - source_op_getters_.emplace(*SourceOpEnum::kX0, []() { - return new generic::IntLiteralOperand<0>({1}); - }); - source_op_getters_.emplace(*SourceOpEnum::kX2, [this]() { - return GetRegisterSourceOp<CheriotRegister>( - state_, absl::StrCat(CheriotState::kXregPrefix, 2), - kXRegisterAliases[2]); - }); + : RiscVCheriotEncodingCommon(state) { source_op_getters_.emplace(*SourceOpEnum::kNone, []() { return nullptr; }); -} - -void RiscVCheriotEncoding::InitializeDestinationOperandGetters() { - // Destination operand getters. - dest_op_getters_.emplace(*DestOpEnum::kC2, [this](int latency) { - return GetRegisterDestinationOp<CheriotRegister>(state_, "c2", latency, - "csp"); - }); - dest_op_getters_.emplace(*DestOpEnum::kC3cd, [this](int latency) { - int num = encoding::c_l::ExtractRd(inst_word_); - return GetRegisterDestinationOp<CheriotRegister>( - state_, absl::StrCat(CheriotState::kCregPrefix, num), latency, - kCRegisterAliases[num]); - }); - dest_op_getters_.emplace(*DestOpEnum::kC3rd, [this](int latency) { - int num = encoding::c_l::ExtractRd(inst_word_); - if (num == 0) { - return GetRegisterDestinationOp<CheriotRegister>(state_, "X0Dest", - latency); - } - return GetRegisterDestinationOp<CheriotRegister>( - state_, absl::StrCat(CheriotState::kXregPrefix, num), latency, - kXRegisterAliases[num]); - }); - dest_op_getters_.emplace(*DestOpEnum::kC3rs1, [this](int latency) { - int num = encoding::c_l::ExtractRs1(inst_word_); - return GetRegisterDestinationOp<CheriotRegister>( - state_, absl::StrCat(CheriotState::kXregPrefix, num), latency, - kXRegisterAliases[num]); - }); - dest_op_getters_.emplace(*DestOpEnum::kCd, [this](int latency) { - int num = encoding::r_type::ExtractRd(inst_word_); - if (num == 0) { - return GetRegisterDestinationOp<CheriotRegister>(state_, "X0Dest", - latency); - } - return GetRegisterDestinationOp<CheriotRegister>( - state_, absl::StrCat(CheriotState::kCregPrefix, num), latency, - kCRegisterAliases[num]); - }); - dest_op_getters_.emplace(*DestOpEnum::kCsr, [this](int latency) { - return GetRegisterDestinationOp<CheriotRegister>( - state_, CheriotState::kCsrName, latency); - }); - dest_op_getters_.emplace(*DestOpEnum::kFrd, [this](int latency) { - int num = encoding::r_type::ExtractRd(inst_word_); - return GetRegisterDestinationOp<RVFpRegister>( - state_, absl::StrCat(CheriotState::kFregPrefix, num), latency, - kFRegisterAliases[num]); - }); - dest_op_getters_.emplace( - *DestOpEnum::kScr, [this](int latency) -> DestinationOperandInterface * { - int csr_indx = encoding::r_type::ExtractRs2(inst_word_); - std::string csr_name; - switch (csr_indx) { - case 28: - csr_name = "mtcc"; - break; - case 29: - csr_name = "mtdc"; - break; - case 30: - csr_name = "mscratchc"; - break; - case 31: - csr_name = "mepcc"; - break; - default: - return nullptr; - } - auto res = state_->csr_set()->GetCsr(csr_name); - if (!res.ok()) { - return GetRegisterDestinationOp<CheriotRegister>(state_, csr_name, - latency); - } - auto *csr = 
res.value(); - auto *op = csr->CreateWriteDestinationOperand(latency, csr_name); - return op; - }); - dest_op_getters_.emplace( - *DestOpEnum::kRd, [this](int latency) -> DestinationOperandInterface * { - int num = encoding::r_type::ExtractRd(inst_word_); - if (num == 0) { - return GetRegisterDestinationOp<CheriotRegister>(state_, "X0Dest", 0); - } else { - return GetRegisterDestinationOp<RVFpRegister>( - state_, absl::StrCat(CheriotState::kXregPrefix, num), latency, - kXRegisterAliases[num]); - } - }); - dest_op_getters_.emplace(*DestOpEnum::kX1, [this](int latency) { - return GetRegisterDestinationOp<CheriotRegister>( - state_, absl::StrCat(CheriotState::kXregPrefix, 1), latency, - kXRegisterAliases[1]); - }); - /* - dest_op_getters_.emplace(*DestOpEnum::kFflags, [this](int latency) { - return GetCSRSetBitsDestinationOp<uint32_t>(state_, "fflags", latency, ""); - }); - */ dest_op_getters_.emplace(*DestOpEnum::kNone, [](int latency) { return nullptr; }); + // Add Cheriot ISA source and destination operand getters. + AddCheriotSourceGetters<SourceOpEnum, Extractors>(source_op_getters_, this); + AddCheriotDestGetters<DestOpEnum, Extractors>(dest_op_getters_, this); + // Verify that all source and destination op enum values have a getter. + for (int i = *SourceOpEnum::kNone; i < *SourceOpEnum::kPastMaxValue; ++i) { + if (source_op_getters_.find(i) == source_op_getters_.end()) { + LOG(ERROR) << "No getter for source op enum value " << i; + } + } + for (int i = *DestOpEnum::kNone; i < *DestOpEnum::kPastMaxValue; ++i) { + if (dest_op_getters_.find(i) == dest_op_getters_.end()) { + LOG(ERROR) << "No getter for destination op enum value " << i; + } + } } // Parse the instruction word to determine the opcode.
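The refactor above replaces the hand-written InitializeSourceOperandGetters/InitializeDestinationOperandGetters bodies with the shared AddCheriotSourceGetters/AddCheriotDestGetters templates (provided by cheriot/cheriot_getters.h), plus a verification pass that logs any operand enum value left without a getter. A minimal sketch of the registration pattern such a helper follows; the exact template signature and the Extractors member names are assumptions here, not the actual cheriot_getters.h contents:

    // Hypothetical sketch only; see cheriot/cheriot_getters.h for the real code.
    template <typename SourceOpEnum, typename Extractors, typename Getters,
              typename Encoding>
    void AddCheriotSourceGetters(Getters &getters, Encoding *enc) {
      // One lambda per operand enum value; each extracts a field from the
      // current instruction word and wraps it in an operand object.
      getters.emplace(*SourceOpEnum::kRs1, [enc]() -> SourceOperandInterface * {
        int num = Extractors::RType::ExtractRs1(enc->inst_word());
        if (num == 0) return new generic::IntLiteralOperand<0>({1});
        return GetRegisterSourceOp<CheriotRegister>(
            enc->state(), absl::StrCat(CheriotState::kXregPrefix, num),
            kXRegisterAliases[num]);
      });
      // ... one emplace per remaining SourceOpEnum value ...
    }

The lambda body mirrors the kRs1 getter deleted above; the difference is that state and instruction word are now reached through the shared RiscVCheriotEncodingCommon accessors, so the same getters can be reused by the scalar, vector, and vector+fp encodings.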
diff --git a/cheriot/riscv_cheriot_encoding.h b/cheriot/riscv_cheriot_encoding.h index de50ca2..b5af035 100644 --- a/cheriot/riscv_cheriot_encoding.h +++ b/cheriot/riscv_cheriot_encoding.h
@@ -24,6 +24,7 @@ #include "cheriot/cheriot_state.h" #include "cheriot/riscv_cheriot_bin_decoder.h" #include "cheriot/riscv_cheriot_decoder.h" +#include "cheriot/riscv_cheriot_encoding_common.h" #include "cheriot/riscv_cheriot_enums.h" namespace mpact { @@ -38,7 +39,8 @@ // instructions) and the instruction representation. This class provides methods // to return the opcode, source operands, and destination operands for // instructions according to the operand fields in the encoding. -class RiscVCheriotEncoding : public RiscVCheriotEncodingBase { +class RiscVCheriotEncoding : public RiscVCheriotEncodingCommon, + public RiscVCheriotEncodingBase { public: explicit RiscVCheriotEncoding(CheriotState *state); @@ -96,15 +98,8 @@ using DestOpGetterMap = absl::flat_hash_map< int, absl::AnyInvocable<DestinationOperandInterface *(int)>>; - // These two methods initialize the source and destination operand getter - // arrays. - void InitializeSourceOperandGetters(); - void InitializeDestinationOperandGetters(); - SourceOpGetterMap source_op_getters_; DestOpGetterMap dest_op_getters_; - CheriotState *state_; - uint32_t inst_word_; OpcodeEnum opcode_; FormatEnum format_; };
diff --git a/cheriot/riscv_cheriot_encoding_common.h b/cheriot/riscv_cheriot_encoding_common.h new file mode 100644 index 0000000..7a355b9 --- /dev/null +++ b/cheriot/riscv_cheriot_encoding_common.h
@@ -0,0 +1,47 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MPACT_CHERIOT_RISCV_CHERIOT_ENCODING_COMMON_H_ +#define MPACT_CHERIOT_RISCV_CHERIOT_ENCODING_COMMON_H_ + +#include <cstdint> + +namespace mpact { +namespace sim { +namespace cheriot { + +class CheriotState; + +// This class provides a common interface for accessing the state and +// instruction word for the RiscVCheriotEncoding classes (scalar, vector, vector +// + fp). + +class RiscVCheriotEncodingCommon { + public: + explicit RiscVCheriotEncodingCommon(CheriotState *state) : state_(state) {} + + // Accessors. + CheriotState *state() const { return state_; } + uint32_t inst_word() const { return inst_word_; } + + protected: + CheriotState *state_; + uint32_t inst_word_; +}; + +} // namespace cheriot +} // namespace sim +} // namespace mpact + +#endif // MPACT_CHERIOT_RISCV_CHERIOT_ENCODING_COMMON_H_
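Each decoder variant mixes this base in through multiple inheritance: the base holds the state pointer and the most recently parsed instruction word, and the derived encoding fills in inst_word_ on every parse. A minimal sketch of the intended usage (the class name here is illustrative; the real subclasses are RiscVCheriotEncoding and RiscVCheriotRVVEncoding elsewhere in this change):

    // Illustrative subclass of the new common base.
    class MyEncoding : public RiscVCheriotEncodingCommon,
                       public RiscVCheriotEncodingBase {
     public:
      explicit MyEncoding(CheriotState *state)
          : RiscVCheriotEncodingCommon(state) {}
      void ParseInstruction(uint32_t inst_word) {
        inst_word_ = inst_word;  // Inherited; shared with the operand getters.
        // ... decode opcode and format from inst_word_ ...
      }
    };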
diff --git a/cheriot/riscv_cheriot_f.bin_fmt b/cheriot/riscv_cheriot_f.bin_fmt new file mode 100644 index 0000000..f2f5741 --- /dev/null +++ b/cheriot/riscv_cheriot_f.bin_fmt
@@ -0,0 +1,43 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +instruction group RiscVFInst32[32] : Inst32Format { + // RiscV32 single precision floating point instructions. + flw : IType : func3 == 0b010, opcode == 0b000'0111; + fsw : SType : func3 == 0b010, opcode == 0b010'0111; + fmadd_s : R4Type : func2 == 0b00, opcode == 0b100'0011; + fmsub_s : R4Type : func2 == 0b00, opcode == 0b100'0111; + fnmsub_s : R4Type : func2 == 0b00, opcode == 0b100'1011; + fnmadd_s : R4Type : func2 == 0b00, opcode == 0b100'1111; + fadd_s : RType : func7 == 0b000'0000, opcode == 0b101'0011; + fsub_s : RType : func7 == 0b000'0100, opcode == 0b101'0011; + fmul_s : RType : func7 == 0b000'1000, opcode == 0b101'0011; + fdiv_s : RType : func7 == 0b000'1100, opcode == 0b101'0011; + fsqrt_s : RType : func7 == 0b010'1100, rs2 == 0, opcode == 0b101'0011; + fsgnj_s : RType : func7 == 0b001'0000, func3 == 0b000, opcode == 0b101'0011; + fsgnjn_s : RType : func7 == 0b001'0000, func3 == 0b001, opcode == 0b101'0011; + fsgnjx_s : RType : func7 == 0b001'0000, func3 == 0b010, opcode == 0b101'0011; + fmin_s : RType : func7 == 0b001'0100, func3 == 0b000, opcode == 0b101'0011; + fmax_s : RType : func7 == 0b001'0100, func3 == 0b001, opcode == 0b101'0011; + fcvt_ws : RType : func7 == 0b110'0000, rs2 == 0, opcode == 0b101'0011; + fcvt_wus : RType : func7 == 0b110'0000, rs2 == 1, opcode == 0b101'0011; + fmv_xw : RType : func7 == 0b111'0000, rs2 == 0, func3 == 0b000, opcode == 0b101'0011; + fcmpeq_s : RType : func7 == 0b101'0000, func3 == 0b010, opcode == 0b101'0011; + fcmplt_s : RType : func7 == 0b101'0000, func3 == 0b001, opcode == 0b101'0011; + fcmple_s : RType : func7 == 0b101'0000, func3 == 0b000, opcode == 0b101'0011; + fclass_s : RType : func7 == 0b111'0000, rs2 == 0, func3 == 0b001, opcode == 0b101'0011; + fcvt_sw : RType : func7 == 0b110'1000, rs2 == 0, opcode == 0b101'0011; + fcvt_swu : RType : func7 == 0b110'1000, rs2 == 1, opcode == 0b101'0011; + fmv_wx : RType : func7 == 0b111'1000, rs2 == 0, func3 == 0b000, opcode == 0b101'0011; +};
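Each constraint list above compiles down to a mask-and-compare over the fixed fields. For example, the fadd_s rule pins func7 (bits 31..25) to zero and opcode (bits 6..0) to 0b101'0011, leaving rs1, rs2, rd, and the rounding mode free. A hand-written equivalent of the generated match, for illustration only (the bin_fmt tool emits the real decoder):

    // Matches fadd.s: func7 == 0 and opcode == 0x53; all other fields free.
    inline bool IsFaddS(uint32_t iword) {
      constexpr uint32_t kMask = 0xFE00'007FU;   // func7 and opcode fields.
      constexpr uint32_t kValue = 0x0000'0053U;  // func7 = 0, opcode = 0x53.
      return (iword & kMask) == kValue;
    }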
diff --git a/cheriot/riscv_cheriot_f.isa b/cheriot/riscv_cheriot_f.isa new file mode 100644 index 0000000..fda14dc --- /dev/null +++ b/cheriot/riscv_cheriot_f.isa
@@ -0,0 +1,132 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +// RiscV32 F (single precision floating point) instructions. +slot riscv_cheriot_f { + includes { + #include "cheriot/riscv_cheriot_f_instructions.h" + } + default size = 4; + default latency = global_latency; + resources TwoOp = { next_pc, frs1 : frd[0..]}; + resources ThreeOp = { next_pc, frs1, frs2 : frd[0..]}; + resources FourOp = { next_pc, frs1, frs2, frs3 : frd[0..]}; + opcodes { + flw{(: rs1, I_imm12 : ), (: : frd)}, + resources: { next_pc, rs1 : frd[0..]}, + semfunc: "&RiscVILw", "&RiscVIFlwChild", + disasm: "flw", "%frd, %I_imm12(%rs1)"; + fsw{: rs1, S_imm12, frs2}, + resources: { next_pc, rs1, frs2}, + semfunc: "&RV32::RiscVFSw", + disasm: "fsw", "%frs2, %S_imm12(%rs1)"; + fadd_s{: frs1, frs2, rm : frd}, + resources: ThreeOp, + semfunc: "&RiscVFAdd", + disasm: "fadd", "%frd, %frs1, %frs2"; + fsub_s{: frs1, frs2, rm : frd}, + resources: ThreeOp, + semfunc: "&RiscVFSub", + disasm: "fsub", "%frd, %frs1, %frs2"; + fmul_s{: frs1, frs2, rm : frd}, + resources: ThreeOp, + semfunc: "&RiscVFMul", + disasm: "fmul", "%frd, %frs1, %frs2"; + fdiv_s{: frs1, frs2, rm : frd}, + resources: ThreeOp, + semfunc: "&RiscVFDiv", + disasm: "fdiv", "%frd, %frs1, %frs2"; + fsqrt_s{: frs1, rm : frd}, + resources: TwoOp, + semfunc: "&RiscVFSqrt", + disasm: "fsqrt", "%frd, %frs1"; + fmin_s{: frs1, frs2 : frd, fflags}, + resources: ThreeOp, + semfunc: "&RiscVFMin", + disasm: "fmin", "%frd, %frs1, %frs2"; + fmax_s{: frs1, frs2 : frd, fflags}, + resources: ThreeOp, + semfunc: "&RiscVFMax", + disasm: "fmax", "%frd, %frs1, %frs2"; + fmadd_s{: frs1, frs2, frs3, rm : frd, fflags}, + resources: FourOp, + semfunc: "&RiscVFMadd", + disasm: "fmadd", "%frd, %frs1, %frs2, %frs3"; + fmsub_s{: frs1, frs2, frs3, rm : frd, fflags}, + resources: FourOp, + semfunc: "&RiscVFMsub", + disasm: "fmsub", "%frd, %frs1, %frs2, %frs3"; + fnmadd_s{: frs1, frs2, frs3, rm : frd, fflags}, + resources: FourOp, + semfunc: "&RiscVFNmadd", + disasm: "fnmadd", "%frd, %frs1, %frs2, %frs3"; + fnmsub_s{: frs1, frs2, frs3, rm : frd, fflags}, + resources: FourOp, + semfunc: "&RiscVFNmsub", + disasm: "fnmsub", "%frd, %frs1, %frs2, %frs3"; + fcvt_ws{: frs1, rm : rd, fflags}, + resources: TwoOp, + semfunc: "&RV32::RiscVFCvtWs", + disasm: "fcvt.w.s", "%rd, %frs1"; + fcvt_sw{: rs1, rm : frd}, + resources: TwoOp, + semfunc: "&RiscVFCvtSw", + disasm: "fcvt.s.w", "%frd, %rs1"; + fcvt_wus{: frs1, rm : rd, fflags}, + resources: TwoOp, + semfunc: "&RV32::RiscVFCvtWus", + disasm: "fcvt.wu.s", "%rd, %frs1"; + fcvt_swu{: rs1, rm : frd}, + resources: TwoOp, + semfunc: "&RiscVFCvtSwu", + disasm: "fcvt.s.wu", "%frd, %rs1"; + fsgnj_s{: frs1, frs2 : frd}, + resources: ThreeOp, + semfunc: "&RiscVFSgnj", + disasm: "fsgnj.s", "%frd, %frs1, %frs2"; + fsgnjn_s{: frs1, frs2 : frd}, + resources: ThreeOp, + semfunc: "&RiscVFSgnjn", + disasm: "fsgnjn.s", "%frd, %frs1, %frs2"; + fsgnjx_s{: frs1, frs2 : frd}, + resources: ThreeOp, + semfunc: "&RiscVFSgnjx", + disasm: "fsgnjx.s", "%frd, %frs1, %frs2"; + fmv_xw{: frs1 : rd}, + resources: { next_pc, frs1 : rd[0..]}, + disasm: "fmv.x.w", "%rd, %frs1", + semfunc: "&RV32::RiscVFMvxw"; + fmv_wx{: rs1 : frd}, + resources: { next_pc, rs1 : frd[0..]}, + disasm: "fmv.w.x", "%frd, %rs1", + semfunc: "&RiscVFMvwx"; + fcmpeq_s{: frs1, frs2 : rd, fflags}, + resources: { next_pc, frs1, frs2 : rd[0..]}, + semfunc: "&RV32::RiscVFCmpeq", + disasm: "fcmpeq", "%rd, %frs1, %frs2"; + fcmplt_s{: frs1, frs2 : rd, fflags}, + resources: { next_pc, frs1, frs2 : rd[0..]}, + semfunc: "&RV32::RiscVFCmplt", + disasm: "fcmplt", "%rd, %frs1, %frs2"; + fcmple_s{: frs1, frs2 : rd, fflags}, + resources: { next_pc, frs1, frs2 : rd[0..]}, + semfunc: "&RV32::RiscVFCmple", + disasm: "fcmple", "%rd, %frs1, %frs2"; + fclass_s{: frs1 : rd}, + resources: { next_pc, frs1 : rd[0..]}, + semfunc: "&RV32::RiscVFClass", + disasm: "fclass", "%rd, %frs1"; + } +} \ No newline at end of file
diff --git a/cheriot/riscv_cheriot_f_instructions.cc b/cheriot/riscv_cheriot_f_instructions.cc index a5d9b15..26e2dd6 100644 --- a/cheriot/riscv_cheriot_f_instructions.cc +++ b/cheriot/riscv_cheriot_f_instructions.cc
@@ -17,9 +17,7 @@ #include <cmath> #include <cstdint> #include <functional> -#include <iostream> #include <limits> -#include <tuple> #include <type_traits> #include "cheriot/cheriot_register.h" @@ -27,6 +25,7 @@ #include "cheriot/riscv_cheriot_instruction_helpers.h" #include "mpact/sim/generic/register.h" #include "mpact/sim/generic/type_helpers.h" +#include "riscv//riscv_fp_info.h" #include "riscv//riscv_register.h" #include "riscv//riscv_state.h" @@ -34,6 +33,8 @@ namespace sim { namespace cheriot { +using ::mpact::sim::generic::FPTypeInfo; +using ::mpact::sim::riscv::FPExceptions; using ::mpact::sim::riscv::LoadContext; // The following instruction semantic functions implement the single precision @@ -63,22 +64,20 @@ // Convert float to signed 32 bit integer. template <typename XInt> static inline void RVFCvtWs(const Instruction *instruction) { - RiscVConvertFloatWithFflagsOp<CheriotRegister, XInt, float, int32_t>( - instruction); + RVCheriotConvertFloatWithFflagsOp<XInt, float, int32_t>(instruction); } // Convert float to unsigned 32 bit integer. template <typename XInt> static inline void RVFCvtWus(const Instruction *instruction) { - RiscVConvertFloatWithFflagsOp<CheriotRegister, XInt, float, uint32_t>( - instruction); + RVCheriotConvertFloatWithFflagsOp<XInt, float, uint32_t>(instruction); } // Single precision compare equal. template <typename XRegister> static inline void RVFCmpeq(const Instruction *instruction) { - RVCheriotBinaryNaNBoxOp<typename XRegister::ValueType, - typename XRegister::ValueType, float>( + RVCheriotBinaryOp<typename XRegister::ValueType, + typename XRegister::ValueType, float>( instruction, [instruction](float a, float b) -> typename XRegister::ValueType { if (FPTypeInfo<float>::IsSNaN(a) || FPTypeInfo<float>::IsSNaN(b)) { @@ -93,8 +92,8 @@ // Single precision compare less than. template <typename XRegister> static inline void RVFCmplt(const Instruction *instruction) { - RVCheriotBinaryNaNBoxOp<typename XRegister::ValueType, - typename XRegister::ValueType, float>( + RVCheriotBinaryOp<typename XRegister::ValueType, + typename XRegister::ValueType, float>( instruction, [instruction](float a, float b) -> typename XRegister::ValueType { if (FPTypeInfo<float>::IsNaN(a) || FPTypeInfo<float>::IsNaN(b)) { @@ -109,8 +108,8 @@ // Single precision compare less than or equal. template <typename XRegister> static inline void RVFCmple(const Instruction *instruction) { - RVCheriotBinaryNaNBoxOp<typename XRegister::ValueType, - typename XRegister::ValueType, float>( + RVCheriotBinaryOp<typename XRegister::ValueType, + typename XRegister::ValueType, float>( instruction, [instruction](float a, float b) -> typename XRegister::ValueType { if (FPTypeInfo<float>::IsNaN(a) || FPTypeInfo<float>::IsNaN(b)) { @@ -152,29 +151,29 @@ // Basic arithmetic instructions.
void RiscVFAdd(const Instruction *instruction) { - RiscVBinaryFloatNaNBoxOp<FPRegister::ValueType, float, float>( + RVCheriotBinaryFloatNaNBoxOp<FPRegister::ValueType, float, float>( instruction, [](float a, float b) { return a + b; }); } void RiscVFSub(const Instruction *instruction) { - RiscVBinaryFloatNaNBoxOp<FPRegister::ValueType, float, float>( + RVCheriotBinaryFloatNaNBoxOp<FPRegister::ValueType, float, float>( instruction, [](float a, float b) { return a - b; }); } void RiscVFMul(const Instruction *instruction) { - RiscVBinaryFloatNaNBoxOp<FPRegister::ValueType, float, float>( + RVCheriotBinaryFloatNaNBoxOp<FPRegister::ValueType, float, float>( instruction, [](float a, float b) { return a * b; }); } void RiscVFDiv(const Instruction *instruction) { - RiscVBinaryFloatNaNBoxOp<FPRegister::ValueType, float, float>( + RVCheriotBinaryFloatNaNBoxOp<FPRegister::ValueType, float, float>( instruction, [](float a, float b) { return a / b; }); } // Square root uses the library square root. void RiscVFSqrt(const Instruction *instruction) { - RiscVUnaryFloatNaNBoxOp<FPRegister::ValueType, FPRegister::ValueType, float, - float>(instruction, [](float a) -> float { + RVCheriotUnaryFloatNaNBoxOp<FPRegister::ValueType, FPRegister::ValueType, + float, float>(instruction, [](float a) -> float { float res = sqrt(a); if (std::isnan(res)) return *reinterpret_cast<const float *>( @@ -239,131 +238,66 @@ void RiscVFMadd(const Instruction *instruction) { using T = float; - RiscVTernaryFloatNaNBoxOp<FPRegister::ValueType, T, T>( + RVCheriotTernaryFloatNaNBoxOp<FPRegister::ValueType, T, T>( instruction, [instruction](T a, T b, T c) -> T { - // Propagate any NaNs. - if (FPTypeInfo<T>::IsNaN(a)) return internal::CanonicalizeNaN(a); - if (FPTypeInfo<T>::IsNaN(b)) return internal::CanonicalizeNaN(b); if ((std::isinf(a) && (b == 0.0)) || ((std::isinf(b) && (a == 0.0)))) { auto *flag_db = instruction->Destination(1)->AllocateDataBuffer(); flag_db->Set<uint32_t>(0, *FPExceptions::kInvalidOp); flag_db->Submit(); } - if (FPTypeInfo<T>::IsNaN(c)) return internal::CanonicalizeNaN(c); - if (std::isinf(c) && !std::isinf(a) && !std::isinf(b)) return c; - if (c == 0.0) { - if ((a == 0.0 && !std::isinf(b)) || (b == 0.0 && !std::isinf(a))) { - FPUInt c_sign = *reinterpret_cast<FPUInt *>(&c) >> - (FPTypeInfo<T>::kBitSize - 1); - FPUInt ua = *reinterpret_cast<FPUInt *>(&a); - FPUInt ub = *reinterpret_cast<FPUInt *>(&b); - FPUInt prod_sign = (ua ^ ub) >> (FPTypeInfo<T>::kBitSize - 1); - if (prod_sign != c_sign) return 0.0; - return c; - } - return internal::CanonicalizeNaN(a * b); - } - return internal::CanonicalizeNaN((a * b) + c); + return internal::CanonicalizeNaN(fma(a, b, c)); }); } void RiscVFMsub(const Instruction *instruction) { using T = float; - RiscVTernaryFloatNaNBoxOp<FPRegister::ValueType, T, T>( + RVCheriotTernaryFloatNaNBoxOp<FPRegister::ValueType, T, T>( instruction, [instruction](T a, T b, T c) -> T { - if (FPTypeInfo<T>::IsNaN(a)) return internal::CanonicalizeNaN(a); - if (FPTypeInfo<T>::IsNaN(b)) return internal::CanonicalizeNaN(b); if ((std::isinf(a) && (b == 0.0)) || ((std::isinf(b) && (a == 0.0)))) { auto *flag_db = instruction->Destination(1)->AllocateDataBuffer(); flag_db->Set<uint32_t>(0, *FPExceptions::kInvalidOp); flag_db->Submit(); } - if (FPTypeInfo<T>::IsNaN(c)) return internal::CanonicalizeNaN(c); - if (std::isinf(c) && !std::isinf(a) && !std::isinf(b)) return -c; - if (c == 0.0) { - if ((a == 0.0 && !std::isinf(b)) || (b == 0.0 && !std::isinf(a))) { - FPUInt c_sign = -*reinterpret_cast<FPUInt 
*>(&c) >> - (FPTypeInfo<T>::kBitSize - 1); - FPUInt ua = *reinterpret_cast<FPUInt *>(&a); - FPUInt ub = *reinterpret_cast<FPUInt *>(&b); - FPUInt prod_sign = (ua ^ ub) >> (FPTypeInfo<T>::kBitSize - 1); - if (prod_sign == c_sign) return 0.0; - return -c; - } - return internal::CanonicalizeNaN(a * b); - } - return internal::CanonicalizeNaN((a * b) - c); + return internal::CanonicalizeNaN(fma(a, b, -c)); }); } void RiscVFNmadd(const Instruction *instruction) { using T = float; - RiscVTernaryFloatNaNBoxOp<FPRegister::ValueType, T, T>( + RVCheriotTernaryFloatNaNBoxOp<FPRegister::ValueType, T, T>( instruction, [instruction](T a, T b, T c) -> T { - if (FPTypeInfo<T>::IsNaN(a)) return internal::CanonicalizeNaN(a); - if (FPTypeInfo<T>::IsNaN(b)) return internal::CanonicalizeNaN(b); if ((std::isinf(a) && (b == 0.0)) || ((std::isinf(b) && (a == 0.0)))) { auto *flag_db = instruction->Destination(1)->AllocateDataBuffer(); flag_db->Set<uint32_t>(0, *FPExceptions::kInvalidOp); flag_db->Submit(); } - if (FPTypeInfo<T>::IsNaN(c)) return internal::CanonicalizeNaN(c); - if (std::isinf(c) && !std::isinf(a) && !std::isinf(b)) return -c; - if (c == 0.0) { - if ((a == 0.0 && !std::isinf(b)) || (b == 0.0 && !std::isinf(a))) { - FPUInt c_sign = *reinterpret_cast<FPUInt *>(&c) >> - (FPTypeInfo<T>::kBitSize - 1); - FPUInt ua = *reinterpret_cast<FPUInt *>(&a); - FPUInt ub = *reinterpret_cast<FPUInt *>(&b); - FPUInt prod_sign = (ua ^ ub) >> (FPTypeInfo<T>::kBitSize - 1); - if (prod_sign != c_sign) return 0.0; - return -c; - } - return internal::CanonicalizeNaN(-a * b); - } - return internal::CanonicalizeNaN(-((a * b) + c)); + return internal::CanonicalizeNaN(fma(-a, b, -c)); }); } void RiscVFNmsub(const Instruction *instruction) { using T = float; - RiscVTernaryFloatNaNBoxOp<FPRegister::ValueType, T, T>( + RVCheriotTernaryFloatNaNBoxOp<FPRegister::ValueType, T, T>( instruction, [instruction](T a, T b, T c) -> T { - if (FPTypeInfo<T>::IsNaN(a)) return internal::CanonicalizeNaN(a); - if (FPTypeInfo<T>::IsNaN(b)) return internal::CanonicalizeNaN(b); if ((std::isinf(a) && (b == 0.0)) || ((std::isinf(b) && (a == 0.0)))) { auto *flag_db = instruction->Destination(1)->AllocateDataBuffer(); flag_db->Set<uint32_t>(0, *FPExceptions::kInvalidOp); flag_db->Submit(); } - if (FPTypeInfo<T>::IsNaN(c)) return internal::CanonicalizeNaN(c); - if (std::isinf(c) && !std::isinf(a) && !std::isinf(b)) return c; - if (c == 0.0) { - if ((a == 0.0 && !std::isinf(b)) || (b == 0.0 && !std::isinf(a))) { - FPUInt c_sign = -*reinterpret_cast<FPUInt *>(&c) >> - (FPTypeInfo<T>::kBitSize - 1); - FPUInt ua = *reinterpret_cast<FPUInt *>(&a); - FPUInt ub = *reinterpret_cast<FPUInt *>(&b); - FPUInt prod_sign = (ua ^ ub) >> (FPTypeInfo<T>::kBitSize - 1); - if (prod_sign != c_sign) return 0.0; - return c; - } - return internal::CanonicalizeNaN(-a * b); - } - return internal::CanonicalizeNaN(-((a * b) - c)); + return internal::CanonicalizeNaN(fma(-a, b, c)); }); } // Set sign of the first operand to that of the second. void RiscVFSgnj(const Instruction *instruction) { - RiscVBinaryNaNBoxOp<FPRegister::ValueType, FPUInt, FPUInt>( + RVCheriotBinaryNaNBoxOp<FPRegister::ValueType, FPUInt, FPUInt>( instruction, [](FPUInt a, FPUInt b) { return (a & 0x7fff'ffff) | (b & 0x8000'0000); }); } // Set the sign of the first operand to the opposite of the second. 
void RiscVFSgnjn(const Instruction *instruction) { - RiscVBinaryNaNBoxOp<FPRegister::ValueType, FPUInt, FPUInt>( + RVCheriotBinaryNaNBoxOp<FPRegister::ValueType, FPUInt, FPUInt>( instruction, [](FPUInt a, FPUInt b) { return (a & 0x7fff'ffff) | (~b & 0x8000'0000); }); @@ -372,7 +306,7 @@ // Set the sign of the first operand to the xor of the signs of the two // operands. void RiscVFSgnjx(const Instruction *instruction) { - RiscVBinaryNaNBoxOp<FPRegister::ValueType, FPUInt, FPUInt>( + RVCheriotBinaryNaNBoxOp<FPRegister::ValueType, FPUInt, FPUInt>( instruction, [](FPUInt a, FPUInt b) { return (a & 0x7fff'ffff) | ((a ^ b) & 0x8000'0000); }); @@ -380,22 +314,24 @@ // Convert signed 32 bit integer to float. void RiscVFCvtSw(const Instruction *instruction) { - RiscVUnaryFloatNaNBoxOp<FPRegister::ValueType, uint32_t, float, int32_t>( + RVCheriotUnaryFloatNaNBoxOp<FPRegister::ValueType, uint32_t, float, int32_t>( instruction, [](int32_t a) -> float { return static_cast<float>(a); }); } // Convert unsigned 32 bit integer to float. void RiscVFCvtSwu(const Instruction *instruction) { - RiscVUnaryFloatNaNBoxOp<FPRegister::ValueType, uint32_t, float, uint32_t>( + RVCheriotUnaryFloatNaNBoxOp<FPRegister::ValueType, uint32_t, float, uint32_t>( instruction, [](uint32_t a) -> float { return static_cast<float>(a); }); } // Single precision move instruction from integer to fp register file. void RiscVFMvwx(const Instruction *instruction) { - RiscVUnaryNaNBoxOp<FPRegister::ValueType, uint32_t, uint32_t, uint32_t>( + RVCheriotUnaryNaNBoxOp<FPRegister::ValueType, uint32_t, uint32_t, uint32_t>( instruction, [](uint32_t a) -> uint32_t { return a; }); } +namespace RV32 { + using XRegister = CheriotRegister; using XUint = typename std::make_unsigned<XRegister::ValueType>::type; using XInt = typename std::make_signed<XRegister::ValueType>::type; @@ -447,6 +383,8 @@ [](float a) -> uint32_t { return static_cast<uint32_t>(ClassifyFP(a)); }); } +} // namespace RV32 + } // namespace cheriot } // namespace sim } // namespace mpact
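The fused multiply-add rewrite above drops the hand-rolled NaN and signed-zero case analysis in favor of the host's fma, which rounds once and already implements the IEEE-754 special cases; the only extra work kept is raising the invalid flag for inf * 0. A small standalone check of that behavior, assuming an IEEE-754-conformant host libm:

    #include <cmath>
    #include <cstdio>
    #include <limits>

    int main() {
      const float inf = std::numeric_limits<float>::infinity();
      // inf * 0 inside a fused multiply-add is an invalid operation and
      // produces NaN, which CanonicalizeNaN then maps to the canonical qNaN.
      float r = std::fma(inf, 0.0f, 1.0f);
      std::printf("fma(inf, 0, 1) -> %f (isnan = %d)\n", r, std::isnan(r));
      return 0;
    }

A further benefit of the single-rounding fma form is that it matches the RISC-V definition of fmadd.s exactly, where the unfused (a * b) + c expression could double-round.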
diff --git a/cheriot/riscv_cheriot_f_instructions.h b/cheriot/riscv_cheriot_f_instructions.h index 12c30f9..b7a680f 100644 --- a/cheriot/riscv_cheriot_f_instructions.h +++ b/cheriot/riscv_cheriot_f_instructions.h
@@ -65,6 +65,8 @@ // The move instruction takes a single register source operand and a single void RiscVFMvwx(const Instruction *instruction); +namespace RV32 { + // Store float instruction semantic function, source operand 0 is the base // register, source operand 1 is the offset, while source operand 2 is the value // to be stored referred to by rs2. @@ -86,6 +88,8 @@ // and a single destination register operand. void RiscVFClass(const Instruction *instruction); +} // namespace RV32 + } // namespace cheriot } // namespace sim } // namespace mpact
diff --git a/cheriot/riscv_cheriot_fp_state.h b/cheriot/riscv_cheriot_fp_state.h deleted file mode 100644 index 138e139..0000000 --- a/cheriot/riscv_cheriot_fp_state.h +++ /dev/null
@@ -1,137 +0,0 @@ -/* - * Copyright 2024 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MPACT_CHERIOT__RISCV_CHERIOT_FP_STATE_H_ -#define MPACT_CHERIOT__RISCV_CHERIOT_FP_STATE_H_ - -#include <cstdint> - -#include "riscv//riscv_csr.h" -#include "riscv//riscv_fp_host.h" -#include "riscv//riscv_fp_info.h" - -// This file contains code that manages the fp state of the RiscV processor. - -namespace mpact { -namespace sim { -namespace cheriot { - -class RiscVCheriotFPState; -class CheriotState; - -using ::mpact::sim::riscv::FPRoundingMode; -using ::mpact::sim::riscv::HostFloatingPointInterface; -using ::mpact::sim::riscv::RiscVSimpleCsr; - -// Floating point CSR. -class RiscVFcsr : public RiscVSimpleCsr<uint32_t> { - public: - RiscVFcsr() = delete; - explicit RiscVFcsr(RiscVCheriotFPState *fp_state); - ~RiscVFcsr() override = default; - - // Overrides. - uint32_t AsUint32() override; - uint64_t AsUint64() override; - void Write(uint32_t value) override; - void Write(uint64_t value) override; - - private: - RiscVCheriotFPState *fp_state_; -}; - -// Floating point rounding mode csr. -class RiscVFrm : public RiscVSimpleCsr<uint32_t> { - public: - RiscVFrm() = delete; - explicit RiscVFrm(RiscVCheriotFPState *fp_state); - ~RiscVFrm() override = default; - - // Overrides. - uint32_t AsUint32() override; - uint64_t AsUint64() override { return AsUint32(); } - void Write(uint32_t value) override; - void Write(uint64_t value) override { Write(static_cast<uint32_t>(value)); } - uint32_t GetUint32() override; - uint64_t GetUint64() override { return GetUint32(); } - void Set(uint32_t value) override; - void Set(uint64_t value) override { Set(static_cast<uint32_t>(value)); } - - private: - RiscVCheriotFPState *fp_state_; -}; - -// Floating point status flags csr. -class RiscVFflags : public RiscVSimpleCsr<uint32_t> { - public: - RiscVFflags() = delete; - explicit RiscVFflags(RiscVCheriotFPState *fp_state); - ~RiscVFflags() override = default; - - // Overrides. - uint32_t AsUint32() override; - uint64_t AsUint64() override { return AsUint32(); } - void Write(uint32_t value) override; - void Write(uint64_t value) override { Write(static_cast<uint32_t>(value)); } - uint32_t GetUint32() override; - uint64_t GetUint64() override { return GetUint32(); } - void Set(uint32_t value) override; - void Set(uint64_t value) override { Set(static_cast<uint32_t>(value)); } - - private: - RiscVCheriotFPState *fp_state_; -}; - -class RiscVCheriotFPState { - public: - RiscVCheriotFPState() = delete; - RiscVCheriotFPState(const RiscVCheriotFPState &) = delete; - explicit RiscVCheriotFPState(CheriotState *rv_state); - ~RiscVCheriotFPState(); - - FPRoundingMode GetRoundingMode() const; - - void SetRoundingMode(FPRoundingMode mode); - - bool rounding_mode_valid() const { return rounding_mode_valid_; } - - // FP CSRs. - RiscVFcsr *fcsr() const { return fcsr_; } - RiscVFrm *frm() const { return frm_; } - RiscVFflags *fflags() const { return fflags_; } - // Parent state. 
- CheriotState *rv_state() const { return rv_state_; } - // Host interface. - HostFloatingPointInterface *host_fp_interface() const { - return host_fp_interface_; - } - - private: - CheriotState *rv_state_; - RiscVFcsr *fcsr_ = nullptr; - RiscVFrm *frm_ = nullptr; - RiscVFflags *fflags_ = nullptr; - HostFloatingPointInterface *host_fp_interface_; - - bool rounding_mode_valid_ = true; - FPRoundingMode rounding_mode_ = FPRoundingMode::kRoundToNearest; -}; - -} // namespace cheriot -} // namespace sim -} // namespace mpact - -#endif // MPACT_CHERIOT__RISCV_CHERIOT_FP_STATE_H_
diff --git a/cheriot/riscv_cheriot_instruction_helpers.h b/cheriot/riscv_cheriot_instruction_helpers.h index ee5af8d..28fd75a 100644 --- a/cheriot/riscv_cheriot_instruction_helpers.h +++ b/cheriot/riscv_cheriot_instruction_helpers.h
@@ -33,6 +33,9 @@ #include "mpact/sim/generic/operand_interface.h" #include "mpact/sim/generic/register.h" #include "mpact/sim/generic/type_helpers.h" +#include "riscv//riscv_fp_host.h" +#include "riscv//riscv_fp_info.h" +#include "riscv//riscv_fp_state.h" #include "riscv//riscv_state.h" namespace mpact { @@ -41,7 +44,12 @@ using ::mpact::sim::generic::Instruction; using ::mpact::sim::generic::operator*; +using ::mpact::sim::generic::FPTypeInfo; using ::mpact::sim::generic::RegisterBase; +using ::mpact::sim::riscv::FPExceptions; +using ::mpact::sim::riscv::FPRoundingMode; +using ::mpact::sim::riscv::ScopedFPStatus; + using CapReg = CheriotRegister; using PB = ::mpact::sim::cheriot::CheriotRegister::PermissionBits; @@ -62,6 +70,256 @@ cap_reg->set_is_null(); } +// Templated helper function for convert instruction semantic functions. +template <typename From, typename To> +inline std::tuple<To, uint32_t> CvtHelper(From value) { + constexpr From kMax = static_cast<From>(std::numeric_limits<To>::max()); + constexpr From kMin = static_cast<From>(std::numeric_limits<To>::min()); + + if (FPTypeInfo<From>::IsNaN(value)) { + return std::make_tuple(std::numeric_limits<To>::max(), + *FPExceptions::kInvalidOp); + } + if (value > kMax) { + return std::make_tuple(std::numeric_limits<To>::max(), + *FPExceptions::kInvalidOp); + } + if (value < kMin) { + if (std::is_unsigned<To>::value && (value > -1.0)) { + using SignedTo = typename std::make_signed<To>::type; + SignedTo signed_val = static_cast<SignedTo>(value); + if (signed_val == 0) { + return std::make_tuple(0, *FPExceptions::kInexact); + } + } + return std::make_tuple(std::numeric_limits<To>::min(), + *FPExceptions::kInvalidOp); + } + + auto output_value = static_cast<To>(value); + return std::make_tuple(output_value, 0); +} + +// Generic helper function for floating op instructions that do not require +// NaN boxing since they produce non fp-values, but set fflags. +template <typename Result, typename From, typename To> +inline void RVCheriotConvertFloatWithFflagsOp(const Instruction *instruction) { + constexpr To kMax = std::numeric_limits<To>::max(); + constexpr To kMin = std::numeric_limits<To>::min(); + + From lhs = generic::GetInstructionSource<From>(instruction, 0); + + uint32_t flags = 0; + uint32_t rm = generic::GetInstructionSource<uint32_t>(instruction, 1); + // Dynamic rounding mode will get rounding mode from the global state. + if (rm == *FPRoundingMode::kDynamic) { + auto *rv_fp = static_cast<CheriotState *>(instruction->state())->rv_fp(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + return; + } + rm = *rv_fp->GetRoundingMode(); + } + To value = 0; + if (FPTypeInfo<From>::IsNaN(lhs)) { + value = std::numeric_limits<To>::max(); + flags = *FPExceptions::kInvalidOp; + } else if (lhs == 0.0) { + value = 0; + } else { + // static_cast<>() doesn't necessarily round, so will have to force + // rounding before converting to the integer type if necessary. 
+ using FromUint = typename FPTypeInfo<From>::UIntType; + auto constexpr kBias = FPTypeInfo<From>::kExpBias; + auto constexpr kExpMask = FPTypeInfo<From>::kExpMask; + auto constexpr kSigSize = FPTypeInfo<From>::kSigSize; + auto constexpr kSigMask = FPTypeInfo<From>::kSigMask; + auto constexpr kBitSize = FPTypeInfo<From>::kBitSize; + FromUint lhs_u = *reinterpret_cast<FromUint *>(&lhs); + const bool sign = (lhs_u & (1ULL << (kBitSize - 1))) != 0; + FromUint exp = kExpMask & lhs_u; + int exp_value = exp >> kSigSize; + int unbiased_exp = exp_value - kBias; + FromUint sig = kSigMask & lhs_u; + + // Get the fraction part of the number and right-shift it so that only one + // fraction bit remains, in the form "<integer_value>.<fraction>". + // e.g., 1.75 -> 0b1.11 + // (float32) -> base = 0b110'0000'0000'0000'0000'0000'(23bits) + // exp = 127 (unbiased exp = 0) + // After the right shift leaves 1 bit in the fraction part: + // base = 0b1 + // rightshift_compressed = 1 (compresses the shifted-out bits) + int right_shift = exp ? kSigSize - 1 - unbiased_exp : 1; + uint64_t base = sig; + bool rightshift_compressed = 0; + if (exp == 0) { + // Denormalized value. + // Format: (-1)^sign * 2 ^(exp - bias + 1) * 0.{sig} + // The fraction part is too small; compress it to a single bit if nonzero. + rightshift_compressed = base != 0; + base = 0; + flags = *FPExceptions::kInexact; + } else { + // Normalized value. + // Format: (-1)^sign * 2 ^(exp - bias) * 1.{sig} + // base = 1.{sig} (total (1 + `kSigSize`) bits) + base |= 1ULL << kSigSize; + + // Right shift all base part out. + if (right_shift > (kBitSize - 1)) { + rightshift_compressed = base != 0; + flags = *FPExceptions::kInexact; + base = 0; + } else if (right_shift > 0) { + // Right shift to leave only 1 bit in the fraction part, compressing + // the shifted-out bits. + right_shift = std::min(right_shift, kBitSize); + uint64_t right_shifted_sig_mask = (1ULL << right_shift) - 1; + rightshift_compressed = (base & right_shifted_sig_mask) != 0; + base >>= right_shift; + } + } + + // Handle fraction part rounding. + if (right_shift >= 0) { + switch (rm) { + case *FPRoundingMode::kRoundToNearest: + // Tie condition (fraction is exactly 0.5). + if (rightshift_compressed == 0 && base & 0b1) { + // <odd>.5 -> <odd>.5 + 0.5 = even + if ((base & 0b11) == 0b11) { + flags = *FPExceptions::kInexact; + base += 0b01; + } + } else if (base & 0b1 || rightshift_compressed) { + // Not a tie; rounding to the nearest integer amounts to adding + // 0.5 (= base + 0b01) and then dropping the fraction part. + base += 0b01; + } + break; + case *FPRoundingMode::kRoundTowardsZero: + // Rounding towards zero just drops the fraction part. + // 1.2 -> 1.0, -1.5 -> -1.0, -0.7 -> 0.0 + break; + case *FPRoundingMode::kRoundDown: + // A positive float just drops the fraction part. + // A negative float with a fraction part subtracts 1 (= base + 0b10) + // and then drops the fraction part. + // e.g., 1.2 -> 1.0, -1.5 -> -2.0, -0.7 -> -1.0 + if (sign && (base & 0b1 || rightshift_compressed)) { + base += 0b10; + } + break; + case *FPRoundingMode::kRoundUp: + // A positive float adds 1 (= base + 0b10) and then drops the + // fraction part. + // A negative float just drops the fraction part. + // e.g., 1.2 -> 2.0, -1.5 -> -1.0, -0.7 -> 0.0 + if (!sign && (base & 0b1 || rightshift_compressed)) { + base += 0b10; + } + break; + case *FPRoundingMode::kRoundToNearestTiesToMax: + // Round to the nearest integer, with ties away from zero.
+ // e.g., 1.5 -> 2.0, -1.5 -> -2.0, -0.7 -> -1.0 + if (base & 0b1 || rightshift_compressed) { + base += 0b1; + } + break; + default: + LOG(ERROR) << "Invalid rounding mode"; + return; + } + } + uint64_t unsigned_value; + // Handle base with fraction part and store it to `unsigned_value`. + if (right_shift >= 0) { + // Set the inexact flag if the floating value has a fraction part. + if (base & 0b1 || rightshift_compressed) { + flags = *FPExceptions::kInexact; + } + unsigned_value = base >> 1; + } else { + // Handle base without a fraction part, which needs a left shift. + int left_shift = -right_shift - 1; + auto prev = unsigned_value = base; + while (left_shift) { + unsigned_value <<= 1; + // Check if overflow happened and set the flag. + if (prev > unsigned_value) { + flags = *FPExceptions::kInvalidOp; + unsigned_value = sign ? kMin : kMax; + break; + } + prev = unsigned_value; + --left_shift; + } + } + + // Handle the case where the value is out of range, then finish the + // conversion by applying the sign. + if (std::is_signed<To>::value) { + // Positive value that exceeds the max value. + if (!sign && unsigned_value > kMax) { + flags = *FPExceptions::kInvalidOp; + value = kMax; + } else if (sign && (unsigned_value > 0 && -unsigned_value < kMin)) { + // Negative value that exceeds the min value. + flags = *FPExceptions::kInvalidOp; + value = kMin; + } else { + value = sign ? -((To)unsigned_value) : unsigned_value; + } + } else { + // Positive value that exceeds the max value. + if (unsigned_value > kMax) { + flags = *FPExceptions::kInvalidOp; + value = sign ? kMin : kMax; + } else if (sign && unsigned_value != 0) { + // The float is negative, which is out of range for an unsigned value. + flags = *FPExceptions::kInvalidOp; + value = kMin; + } else { + value = sign ? -((To)unsigned_value) : unsigned_value; + } + } + } + using SignedTo = typename std::make_signed<To>::type; + // The final value is sign-extended to the register width, even if it is a + // conversion to an unsigned value. + SignedTo signed_value = static_cast<SignedTo>(value); + Result dest_value = static_cast<Result>(signed_value); + WriteCapIntResult(instruction, 0, dest_value); + if (flags) { + auto *flag_db = instruction->Destination(1)->AllocateDataBuffer(); + flag_db->Set<uint32_t>(0, flags); + flag_db->Submit(); + } +} + +// Helper function to read a NaN-boxed source value, converting it to NaN if +// it isn't formatted properly. +template <typename RegValue, typename Argument> +inline Argument GetNaNBoxedSource(const Instruction *instruction, int arg) { + if (sizeof(RegValue) <= sizeof(Argument)) { + return generic::GetInstructionSource<Argument>(instruction, arg); + } else { + using SInt = typename std::make_signed<RegValue>::type; + using UInt = typename std::make_unsigned<RegValue>::type; + SInt val = generic::GetInstructionSource<SInt>(instruction, arg); + UInt uval = static_cast<UInt>(val); + UInt mask = std::numeric_limits<UInt>::max() << (sizeof(Argument) * 8); + if (((mask & uval) != mask)) { + return *reinterpret_cast<const Argument *>( + &FPTypeInfo<Argument>::kCanonicalNaN); + } + return generic::GetInstructionSource<Argument>(instruction, arg); + } +} + template <typename Register, typename Result, typename Argument> inline void RiscVBinaryOp(const Instruction *instruction, std::function<Result(Argument, Argument)> operation) { @@ -255,6 +513,304 @@ db->DecRef(); } +// Generic helper function for binary instructions with NaN boxing.
This is +// used for those instructions that produce results in fp registers, but are +// not really executing an fp operation that requires rounding. +template <typename RegValue, typename Result, typename Argument> +inline void RVCheriotBinaryNaNBoxOp( + const Instruction *instruction, + std::function<Result(Argument, Argument)> operation) { + Argument lhs = GetNaNBoxedSource<RegValue, Argument>(instruction, 0); + Argument rhs = GetNaNBoxedSource<RegValue, Argument>(instruction, 1); + Result dest_value = operation(lhs, rhs); + auto *reg = static_cast<generic::RegisterDestinationOperand<RegValue> *>( + instruction->Destination(0)) + ->GetRegister(); + // Check to see if we need to NaN box the result. + if (sizeof(RegValue) > sizeof(Result)) { + // If the floating point value is narrower than the register, the upper + // bits have to be set to all ones. + using UReg = typename std::make_unsigned<RegValue>::type; + using UInt = typename FPTypeInfo<Result>::UIntType; + auto dest_u_value = *reinterpret_cast<UInt *>(&dest_value); + UReg reg_value = std::numeric_limits<UReg>::max(); + int shift = 8 * (sizeof(RegValue) - sizeof(Result)); + reg_value = (reg_value << shift) | dest_u_value; + reg->data_buffer()->template Set<RegValue>(0, reg_value); + return; + } + reg->data_buffer()->template Set<Result>(0, dest_value); +} + +// Generic helper function for unary instructions with NaN boxing. +template <typename DstRegValue, typename SrcRegValue, typename Result, + typename Argument> +inline void RVCheriotUnaryNaNBoxOp(const Instruction *instruction, + std::function<Result(Argument)> operation) { + Argument lhs = GetNaNBoxedSource<SrcRegValue, Argument>(instruction, 0); + Result dest_value = operation(lhs); + auto *reg = static_cast<generic::RegisterDestinationOperand<DstRegValue> *>( + instruction->Destination(0)) + ->GetRegister(); + // Check to see if we need to NaN box the result. + if (sizeof(DstRegValue) > sizeof(Result)) { + // If the floating point value is narrower than the register, the upper + // bits have to be set to all ones. + using UReg = typename std::make_unsigned<DstRegValue>::type; + using UInt = typename FPTypeInfo<Result>::UIntType; + auto dest_u_value = *reinterpret_cast<UInt *>(&dest_value); + UReg reg_value = std::numeric_limits<UReg>::max(); + int shift = 8 * (sizeof(DstRegValue) - sizeof(Result)); + reg_value = (reg_value << shift) | dest_u_value; + WriteCapIntResult(instruction, 0, reg_value); + reg->data_buffer()->template Set<DstRegValue>(0, reg_value); + return; + } + reg->data_buffer()->template Set<Result>(0, dest_value); +} + +// Generic helper function for unary floating point instructions. The main +// difference is that it handles rounding mode and performs NaN boxing. +template <typename DstRegValue, typename SrcRegValue, typename Result, + typename Argument> +inline void RVCheriotUnaryFloatNaNBoxOp( + const Instruction *instruction, std::function<Result(Argument)> operation) { + using ResUint = typename FPTypeInfo<Result>::UIntType; + Argument lhs = GetNaNBoxedSource<SrcRegValue, Argument>(instruction, 0); + // Get the rounding mode. + int rm_value = generic::GetInstructionSource<int>(instruction, 1); + + // If the rounding mode is dynamic, read it from the current state. 
+ auto *rv_fp = static_cast<CheriotState *>(instruction->state())->rv_fp(); + if (rm_value == *FPRoundingMode::kDynamic) { + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + return; + } + rm_value = *(rv_fp->GetRoundingMode()); + } + Result dest_value; + { + ScopedFPStatus set_fp_status(rv_fp->host_fp_interface(), rm_value); + dest_value = operation(lhs); + } + if (std::isnan(dest_value) && std::signbit(dest_value)) { + ResUint res_value = *reinterpret_cast<ResUint *>(&dest_value); + res_value &= FPTypeInfo<Result>::kInfMask; + dest_value = *reinterpret_cast<Result *>(&res_value); + } + auto *dest = instruction->Destination(0); + auto *reg_dest = + static_cast<generic::RegisterDestinationOperand<DstRegValue> *>(dest); + auto *reg = reg_dest->GetRegister(); + // Check to see if we need to NaN box the result. + if (sizeof(DstRegValue) > sizeof(Result)) { + // If the floating point Value is narrower than the register, the upper + // bits have to be set to all ones. + using UReg = typename std::make_unsigned<DstRegValue>::type; + using UInt = typename FPTypeInfo<Result>::UIntType; + auto dest_u_value = *reinterpret_cast<UInt *>(&dest_value); + UReg reg_value = std::numeric_limits<UReg>::max(); + int shift = 8 * (sizeof(DstRegValue) - sizeof(Result)); + reg_value = (reg_value << shift) | dest_u_value; + reg->data_buffer()->template Set<DstRegValue>(0, reg_value); + return; + } + reg->data_buffer()->template Set<Result>(0, dest_value); +} + +// Generic helper function for floating op instructions that do not require +// NaN boxing since they produce non fp-values. +template <typename Result, typename Argument> +inline void RVCheriotUnaryFloatOp(const Instruction *instruction, + std::function<Result(Argument)> operation) { + Argument lhs = generic::GetInstructionSource<Argument>(instruction, 0); + // Get the rounding mode. + int rm_value = generic::GetInstructionSource<int>(instruction, 1); + + auto *rv_fp = static_cast<CheriotState *>(instruction->state())->rv_fp(); + // If the rounding mode is dynamic, read it from the current state. + if (rm_value == *FPRoundingMode::kDynamic) { + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + return; + } + rm_value = *rv_fp->GetRoundingMode(); + } + Result dest_value; + { + ScopedFPStatus set_fp_status(rv_fp->host_fp_interface(), rm_value); + dest_value = operation(lhs); + } + auto *dest = instruction->Destination(0); + using UInt = typename FPTypeInfo<Result>::UIntType; + auto *reg_dest = + static_cast<generic::RegisterDestinationOperand<UInt> *>(dest); + auto *reg = reg_dest->GetRegister(); + reg->data_buffer()->template Set<Result>(0, dest_value); +} + +// Generic helper function for floating op instructions that do not require +// NaN boxing since they produce non fp-values, but set fflags. +template <typename Result, typename Argument> +inline void RVCheriotUnaryFloatWithFflagsOp( + const Instruction *instruction, + std::function<Result(Argument, uint32_t &)> operation) { + Argument lhs = generic::GetInstructionSource<Argument>(instruction, 0); + // Get the rounding mode. + int rm_value = generic::GetInstructionSource<int>(instruction, 1); + + auto *rv_fp = static_cast<CheriotState *>(instruction->state())->rv_fp(); + // If the rounding mode is dynamic, read it from the current state. 
+ if (rm_value == *FPRoundingMode::kDynamic) { + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + return; + } + rm_value = *rv_fp->GetRoundingMode(); + } + uint32_t flag = 0; + Result dest_value; + { + ScopedFPStatus set_fp_status(rv_fp->host_fp_interface(), rm_value); + dest_value = operation(lhs, flag); + } + auto *dest = instruction->Destination(0); + using UInt = typename FPTypeInfo<Result>::UIntType; + auto *reg_dest = + static_cast<generic::RegisterDestinationOperand<UInt> *>(dest); + auto *reg = reg_dest->GetRegister(); + reg->data_buffer()->template Set<Result>(0, dest_value); + auto *flag_db = instruction->Destination(1)->AllocateDataBuffer(); + flag_db->Set<uint32_t>(0, flag); + flag_db->Submit(); +} + +// Generic helper function for binary floating point instructions. The main +// difference is that it handles rounding mode. +template <typename Register, typename Result, typename Argument> +inline void RVCheriotBinaryFloatNaNBoxOp( + const Instruction *instruction, + std::function<Result(Argument, Argument)> operation) { + Argument lhs = GetNaNBoxedSource<Register, Argument>(instruction, 0); + Argument rhs = GetNaNBoxedSource<Register, Argument>(instruction, 1); + // Argument lhs = generic::GetInstructionSource<Argument>(instruction, 0); + // Argument rhs = generic::GetInstructionSource<Argument>(instruction, 1); + + // Get the rounding mode. + int rm_value = generic::GetInstructionSource<int>(instruction, 2); + + auto *rv_fp = static_cast<CheriotState *>(instruction->state())->rv_fp(); + // If the rounding mode is dynamic, read it from the current state. + if (rm_value == *FPRoundingMode::kDynamic) { + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + return; + } + rm_value = *rv_fp->GetRoundingMode(); + } + Result dest_value; + { + ScopedFPStatus fp_status(rv_fp->host_fp_interface(), rm_value); + dest_value = operation(lhs, rhs); + } + if (std::isnan(dest_value)) { + *reinterpret_cast<typename FPTypeInfo<Result>::UIntType *>(&dest_value) = + FPTypeInfo<Result>::kCanonicalNaN; + } + auto *reg = static_cast<generic::RegisterDestinationOperand<Register> *>( + instruction->Destination(0)) + ->GetRegister(); + // Check to see if we need to NaN box the result. + if (sizeof(Register) > sizeof(Result)) { + // If the floating point value is narrower than the register, the upper + // bits have to be set to all ones. + using UReg = typename std::make_unsigned<Register>::type; + using UInt = typename FPTypeInfo<Result>::UIntType; + auto dest_u_value = *reinterpret_cast<UInt *>(&dest_value); + UReg reg_value = std::numeric_limits<UReg>::max(); + int shift = 8 * (sizeof(Register) - sizeof(Result)); + reg_value = (reg_value << shift) | dest_u_value; + reg->data_buffer()->template Set<Register>(0, reg_value); + return; + } + reg->data_buffer()->template Set<Result>(0, dest_value); +} + +// Generic helper function for ternary floating point instructions. +template <typename Register, typename Result, typename Argument> +inline void RVCheriotTernaryFloatNaNBoxOp( + const Instruction *instruction, + std::function<Result(Argument, Argument, Argument)> operation) { + Argument rs1 = generic::GetInstructionSource<Argument>(instruction, 0); + Argument rs2 = generic::GetInstructionSource<Argument>(instruction, 1); + Argument rs3 = generic::GetInstructionSource<Argument>(instruction, 2); + // Get the rounding mode. 
+ int rm_value = generic::GetInstructionSource<int>(instruction, 3); + + auto *rv_fp = static_cast<CheriotState *>(instruction->state())->rv_fp(); + // If the rounding mode is dynamic, read it from the current state. + if (rm_value == *FPRoundingMode::kDynamic) { + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + return; + } + rm_value = *rv_fp->GetRoundingMode(); + } + Result dest_value; + { + ScopedFPStatus fp_status(rv_fp->host_fp_interface(), rm_value); + dest_value = operation(rs1, rs2, rs3); + } + auto *reg = static_cast<generic::RegisterDestinationOperand<Register> *>( + instruction->Destination(0)) + ->GetRegister(); + // Check to see if we need to NaN box the result. + if (sizeof(Register) > sizeof(Result)) { + // If the floating point value is narrower than the register, the upper + // bits have to be set to all ones. + using UReg = typename std::make_unsigned<Register>::type; + using UInt = typename FPTypeInfo<Result>::UIntType; + auto dest_u_value = *reinterpret_cast<UInt *>(&dest_value); + UReg reg_value = std::numeric_limits<UReg>::max(); + int shift = 8 * (sizeof(Register) - sizeof(Result)); + reg_value = (reg_value << shift) | dest_u_value; + reg->data_buffer()->template Set<Register>(0, reg_value); + return; + } + reg->data_buffer()->template Set<Result>(0, dest_value); +} + +// Helper function to classify floating point values. +template <typename T> +typename FPTypeInfo<T>::UIntType ClassifyFP(T val) { + using UIntType = typename FPTypeInfo<T>::UIntType; + auto int_value = *reinterpret_cast<UIntType *>(&val); + UIntType sign = int_value >> (FPTypeInfo<T>::kBitSize - 1); + UIntType exp_mask = (1 << FPTypeInfo<T>::kExpSize) - 1; + UIntType exp = (int_value >> FPTypeInfo<T>::kSigSize) & exp_mask; + UIntType sig = + int_value & ((static_cast<UIntType>(1) << FPTypeInfo<T>::kSigSize) - 1); + if (exp == 0) { // The number is denormal or zero. + if (sig == 0) { // The number is zero. + return sign ? 1 << 3 : 1 << 4; + } else { // subnormal. + return sign ? 1 << 2 : 1 << 5; + } + } else if (exp == exp_mask) { // The number is infinity or NaN. + if (sig == 0) { // infinity + return sign ? 1 : 1 << 7; + } else { + if ((sig >> (FPTypeInfo<T>::kSigSize - 1)) != 0) { // Quiet NaN. + return 1 << 9; + } else { // signaling NaN. + return 1 << 8; + } + } + } + return sign ? 1 << 1 : 1 << 6; +} + } // namespace cheriot } // namespace sim } // namespace mpact
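Several of the helpers above share the same NaN-boxing convention: when a float result is narrower than the destination register, the value occupies the low bits and every upper bit must be one, and a source operand that fails that test reads as the canonical NaN. A standalone illustration of the check performed by GetNaNBoxedSource, assuming a 64-bit register holding a 32-bit float (on RV32 CHERIoT itself the sizes match and this path is skipped):

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    int main() {
      float f = 1.5f;
      uint32_t bits;
      std::memcpy(&bits, &f, sizeof(bits));
      const uint64_t boxed = (~uint64_t{0} << 32) | bits;  // Upper bits all 1.
      const uint64_t stale = bits;                         // Upper bits all 0.
      const uint64_t mask = ~uint64_t{0} << 32;            // As in the helper.
      std::printf("boxed ok: %d\n", (boxed & mask) == mask);  // 1: use value.
      std::printf("stale ok: %d\n", (stale & mask) == mask);  // 0: read as qNaN.
      return 0;
    }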
diff --git a/cheriot/riscv_cheriot_rvv.bin_fmt b/cheriot/riscv_cheriot_rvv.bin_fmt new file mode 100644 index 0000000..757f4aa --- /dev/null +++ b/cheriot/riscv_cheriot_rvv.bin_fmt
@@ -0,0 +1,44 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Cheriot RVV FP binary decoder. +decoder RiscVCheriotRVVFp { + namespace mpact::sim::cheriot::encoding_rvv_fp; + opcode_enum = "cheriot::isa32_rvv_fp::OpcodeEnum"; + includes { + #include "cheriot/riscv_cheriot_rvv_fp_decoder.h" + } + // Group these instruction groups in the same decoder function. + RiscVCheriotRVVFPInst32 = {RiscVFInst32, RiscVVInst32, RiscVVFPInst32, RiscVCheriotInst32}; + // Keep this separate (different base format). + RiscVCheriotRVVFPInst16 = {RiscVCheriotInst16}; +}; + +// Cheriot RVV binary decoder. +decoder RiscVCheriotRVV { + namespace mpact::sim::cheriot::encoding_rvv; + opcode_enum = "cheriot::isa32_rvv::OpcodeEnum"; + includes { + #include "cheriot/riscv_cheriot_rvv_decoder.h" + } + // Group these instruction groups in the same decoder function. + RiscVCheriotRVVInst32 = {RiscVVInst32, RiscVCheriotInst32}; + // Keep this separate (different base format). + RiscVCheriotRVVInst16 = {RiscVCheriotInst16}; +}; + +#include "cheriot/riscv_cheriot.bin_fmt" +#include "cheriot/riscv_cheriot_f.bin_fmt" +#include "cheriot/riscv_cheriot_vector.bin_fmt" +#include "cheriot/riscv_cheriot_vector_fp.bin_fmt" \ No newline at end of file
diff --git a/cheriot/riscv_cheriot_rvv.isa b/cheriot/riscv_cheriot_rvv.isa new file mode 100644 index 0000000..071e515 --- /dev/null +++ b/cheriot/riscv_cheriot_rvv.isa
@@ -0,0 +1,48 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cheriot/riscv_cheriot_vector.isa" +#include "cheriot/riscv_cheriot_vector_fp.isa" +#include "cheriot/riscv_cheriot_f.isa" +#include "cheriot/riscv_cheriot.isa" + +includes { + #include "absl/functional/bind_front.h" +} + +disasm widths = {-18}; + +isa RiscVCheriotRVV { + namespace mpact::sim::cheriot::isa32_rvv; + slots { riscv_cheriot_rvv; } +} + +isa RiscVCheriotRVVFp { + namespace mpact::sim::cheriot::isa32_rvv_fp; + slots { riscv_cheriot_rvv_fp; } +} + +slot riscv_cheriot_rvv : riscv32_cheriot, riscv_cheriot_vector { + default size = 4; + default opcode = + disasm: "Illegal instruction at 0x%(@:08x)", + semfunc: "&RiscVIllegalInstruction"; +} + +slot riscv_cheriot_rvv_fp : riscv32_cheriot, riscv_cheriot_vector, riscv_cheriot_vector_fp, riscv_cheriot_f { + default size = 4; + default opcode = + disasm: "Illegal instruction at 0x%(@:08x)", + semfunc: "&RiscVIllegalInstruction"; +} \ No newline at end of file
diff --git a/cheriot/riscv_cheriot_rvv_encoding.cc b/cheriot/riscv_cheriot_rvv_encoding.cc new file mode 100644 index 0000000..b1b3f3c --- /dev/null +++ b/cheriot/riscv_cheriot_rvv_encoding.cc
@@ -0,0 +1,111 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cheriot/riscv_cheriot_rvv_encoding.h" + +#include <cstdint> + +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" +#include "cheriot/cheriot_getters.h" +#include "cheriot/cheriot_register.h" +#include "cheriot/cheriot_rvv_getters.h" +#include "cheriot/cheriot_state.h" +#include "cheriot/riscv_cheriot_encoding_common.h" +#include "cheriot/riscv_cheriot_rvv_bin_decoder.h" +#include "cheriot/riscv_cheriot_rvv_decoder.h" +#include "cheriot/riscv_cheriot_rvv_enums.h" +#include "mpact/sim/generic/type_helpers.h" + +namespace mpact { +namespace sim { +namespace cheriot { +namespace isa32_rvv { + +using Extractors = ::mpact::sim::cheriot::encoding_rvv::Extractors; + +RiscVCheriotRVVEncoding::RiscVCheriotRVVEncoding(CheriotState *state) + : RiscVCheriotEncodingCommon(state) { + source_op_getters_.emplace(*SourceOpEnum::kNone, []() { return nullptr; }); + dest_op_getters_.emplace(*DestOpEnum::kNone, + [](int latency) { return nullptr; }); + // Add Cheriot ISA source and destination operand getters. + AddCheriotSourceGetters<SourceOpEnum, Extractors>(source_op_getters_, this); + AddCheriotDestGetters<DestOpEnum, Extractors>(dest_op_getters_, this); + // Add non-fp RVV source and destination operand getters. + AddCheriotRVVSourceGetters<SourceOpEnum, Extractors>(source_op_getters_, + this); + AddCheriotRVVDestGetters<DestOpEnum, Extractors>(dest_op_getters_, this); + // Verify that all source and destination op enum values have a getter. + for (int i = *SourceOpEnum::kNone; i < *SourceOpEnum::kPastMaxValue; ++i) { + if (source_op_getters_.find(i) == source_op_getters_.end()) { + LOG(ERROR) << "No getter for source op enum value " << i; + } + } + for (int i = *DestOpEnum::kNone; i < *DestOpEnum::kPastMaxValue; ++i) { + if (dest_op_getters_.find(i) == dest_op_getters_.end()) { + LOG(ERROR) << "No getter for destination op enum value " << i; + } + } +} + +// Parse the instruction word to determine the opcode. 
+void RiscVCheriotRVVEncoding::ParseInstruction(uint32_t inst_word) { + inst_word_ = inst_word; + if ((inst_word_ & 0x3) == 3) { + auto [opcode, format] = mpact::sim::cheriot::encoding_rvv:: + DecodeRiscVCheriotRVVInst32WithFormat(inst_word_); + opcode_ = opcode; + format_ = format; + return; + } + + auto [opcode, format] = + mpact::sim::cheriot::encoding_rvv::DecodeRiscVCheriotRVVInst16WithFormat( + static_cast<uint16_t>(inst_word_ & 0xffff)); + opcode_ = opcode; + format_ = format; +} + +DestinationOperandInterface *RiscVCheriotRVVEncoding::GetDestination( + SlotEnum, int, OpcodeEnum opcode, DestOpEnum dest_op, int dest_no, + int latency) { + int index = static_cast<int>(dest_op); + auto iter = dest_op_getters_.find(index); + if (iter == dest_op_getters_.end()) { + LOG(ERROR) << absl::StrCat("No getter for destination op enum value ", + index, " for instruction ", + kOpcodeNames[static_cast<int>(opcode)]); + return nullptr; + } + return (iter->second)(latency); +} + +SourceOperandInterface *RiscVCheriotRVVEncoding::GetSource( + SlotEnum, int, OpcodeEnum opcode, SourceOpEnum source_op, int source_no) { + int index = static_cast<int>(source_op); + auto iter = source_op_getters_.find(index); + if (iter == source_op_getters_.end()) { + LOG(ERROR) << absl::StrCat("No getter for source op enum value ", index, + " for instruction ", + kOpcodeNames[static_cast<int>(opcode)]); + return nullptr; + } + return (iter->second)(); +} + +} // namespace isa32_rvv +} // namespace cheriot +} // namespace sim +} // namespace mpact
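A note on ParseInstruction above: the check `(inst_word_ & 0x3) == 3` is the standard RISC-V length encoding, where low bits 0b11 mark a 32-bit instruction and any other value a 16-bit compressed one, which is why the 16-bit decoder receives only the low halfword. A minimal sketch of the rule (hypothetical helper, not part of this change):
#include <cstdint>
// Returns the instruction size in bytes implied by the RISC-V base length
// encoding: bits [1:0] == 0b11 means a 32-bit instruction here; anything
// else selects the 16-bit compressed decoder.
inline int InstructionLengthBytes(uint32_t inst_word) {
  return ((inst_word & 0x3) == 0x3) ? 4 : 2;
}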
diff --git a/cheriot/riscv_cheriot_rvv_encoding.h b/cheriot/riscv_cheriot_rvv_encoding.h new file mode 100644 index 0000000..2ac1bde --- /dev/null +++ b/cheriot/riscv_cheriot_rvv_encoding.h
@@ -0,0 +1,112 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MPACT_CHERIOT_RISCV_CHERIOT_RVV_ENCODING_H_ +#define MPACT_CHERIOT_RISCV_CHERIOT_RVV_ENCODING_H_ + +#include <cstdint> + +#include "absl/container/flat_hash_map.h" +#include "absl/functional/any_invocable.h" +#include "cheriot/cheriot_state.h" +#include "cheriot/riscv_cheriot_encoding_common.h" +#include "cheriot/riscv_cheriot_rvv_bin_decoder.h" +#include "cheriot/riscv_cheriot_rvv_decoder.h" +#include "cheriot/riscv_cheriot_rvv_enums.h" + +namespace mpact { +namespace sim { +namespace cheriot { +namespace isa32_rvv { + +using ::mpact::sim::cheriot::encoding_rvv::FormatEnum; + +// This class provides the interface between the generated instruction decoder +// framework (which is agnostic of the actual bit representation of +// instructions) and the instruction representation. It provides methods to +// return the opcode, source operands, and destination operands for +// instructions according to the operand fields in the encoding. +class RiscVCheriotRVVEncoding : public RiscVCheriotEncodingCommon, + public RiscVCheriotRVVEncodingBase { + public: + explicit RiscVCheriotRVVEncoding(CheriotState *state); + + // Parses an instruction and determines the opcode. + void ParseInstruction(uint32_t inst_word); + + // RiscV32 CHERIoT has a single slot type and single entry, so the following + // methods ignore those parameters. + + // Returns the opcode in the current instruction representation. + OpcodeEnum GetOpcode(SlotEnum, int) override { return opcode_; } + + // Returns the instruction format in the current instruction representation. + FormatEnum GetFormat(SlotEnum, int) { return format_; } + + // There is no predicate, so return nullptr. + PredicateOperandInterface *GetPredicate(SlotEnum, int, OpcodeEnum, + PredOpEnum) override { + return nullptr; + } + + // Currently no resources modeled for RiscV CHERIoT. + ResourceOperandInterface *GetSimpleResourceOperand( + SlotEnum, int, OpcodeEnum, SimpleResourceVector &resource_vec, + int end) override { + return nullptr; + } + + ResourceOperandInterface *GetComplexResourceOperand( + SlotEnum, int, OpcodeEnum, ComplexResourceEnum resource, int begin, + int end) override { + return nullptr; + } + + // The following method returns a source operand that corresponds to the + // particular operand field. + SourceOperandInterface *GetSource(SlotEnum, int, OpcodeEnum, SourceOpEnum op, + int source_no) override; + + // The following method returns a destination operand that corresponds to the + // particular operand field. + DestinationOperandInterface *GetDestination(SlotEnum, int, OpcodeEnum, + DestOpEnum op, int dest_no, + int latency) override; + // This method returns latency for any destination operand for which the + // latency specifier in the .isa file is '*'. Since there are none, just + // return 0. 
+ int GetLatency(SlotEnum, int, OpcodeEnum, DestOpEnum, int) override { + return 0; + } + + private: + using SourceOpGetterMap = + absl::flat_hash_map<int, absl::AnyInvocable<SourceOperandInterface *()>>; + using DestOpGetterMap = absl::flat_hash_map< + int, absl::AnyInvocable<DestinationOperandInterface *(int)>>; + + SourceOpGetterMap source_op_getters_; + DestOpGetterMap dest_op_getters_; + OpcodeEnum opcode_; + FormatEnum format_; +}; + +} // namespace isa32_rvv +} // namespace cheriot +} // namespace sim +} // namespace mpact + +#endif // MPACT_CHERIOT_RISCV_CHERIOT_RVV_ENCODING_H_
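The getter maps declared above are keyed by the operand enum's integer value and hold factory callables. A self-contained sketch of the lookup-and-invoke pattern used by GetSource, with a stand-in operand type (names here are illustrative, not from this change):
#include <cstdio>
#include "absl/container/flat_hash_map.h"
#include "absl/functional/any_invocable.h"
struct SourceOperandInterface {};  // stand-in for the real interface
int main() {
  // Map from operand enum value (as int) to a callable that builds the
  // operand, mirroring SourceOpGetterMap above.
  absl::flat_hash_map<int, absl::AnyInvocable<SourceOperandInterface *()>>
      getters;
  // Register a getter for enum value 0 (kNone): returns no operand.
  getters.emplace(0, []() -> SourceOperandInterface * { return nullptr; });
  // Look up and invoke, as GetSource does; a missing entry is an error.
  auto iter = getters.find(0);
  if (iter == getters.end()) {
    std::printf("no getter registered\n");
    return 1;
  }
  SourceOperandInterface *op = (iter->second)();
  std::printf("getter returned %p\n", static_cast<void *>(op));
  return 0;
}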
diff --git a/cheriot/riscv_cheriot_rvv_fp_encoding.cc b/cheriot/riscv_cheriot_rvv_fp_encoding.cc new file mode 100644 index 0000000..c7f3930 --- /dev/null +++ b/cheriot/riscv_cheriot_rvv_fp_encoding.cc
@@ -0,0 +1,120 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cheriot/riscv_cheriot_rvv_fp_encoding.h" + +#include <cstdint> + +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" +#include "cheriot/cheriot_f_getters.h" +#include "cheriot/cheriot_getters.h" +#include "cheriot/cheriot_register.h" +#include "cheriot/cheriot_rvv_fp_getters.h" +#include "cheriot/cheriot_rvv_getters.h" +#include "cheriot/cheriot_state.h" +#include "cheriot/riscv_cheriot_encoding_common.h" +#include "cheriot/riscv_cheriot_rvv_fp_bin_decoder.h" +#include "cheriot/riscv_cheriot_rvv_fp_decoder.h" +#include "cheriot/riscv_cheriot_rvv_fp_enums.h" +#include "mpact/sim/generic/type_helpers.h" + +namespace mpact { +namespace sim { +namespace cheriot { +namespace isa32_rvv_fp { + +using Extractors = ::mpact::sim::cheriot::encoding_rvv_fp::Extractors; + +RiscVCheriotRVVFPEncoding::RiscVCheriotRVVFPEncoding(CheriotState *state) + : RiscVCheriotEncodingCommon(state) { + source_op_getters_.emplace(*SourceOpEnum::kNone, []() { return nullptr; }); + dest_op_getters_.emplace(*DestOpEnum::kNone, + [](int latency) { return nullptr; }); + // Add Cheriot ISA source and destination operand getters. + AddCheriotSourceGetters<SourceOpEnum, Extractors>(source_op_getters_, this); + AddCheriotDestGetters<DestOpEnum, Extractors>(dest_op_getters_, this); + // Add RVV source and destination operand getters. + AddCheriotRVVSourceGetters<SourceOpEnum, Extractors>(source_op_getters_, + this); + AddCheriotRVVDestGetters<DestOpEnum, Extractors>(dest_op_getters_, this); + // Add RVV FP source and destination operand getters. + AddCheriotRVVFPSourceGetters<SourceOpEnum, Extractors>(source_op_getters_, + this); + AddCheriotRVVFPDestGetters<DestOpEnum, Extractors>(dest_op_getters_, this); + // Add FP source and destination operand getters. + AddCheriotFSourceGetters<SourceOpEnum, Extractors>(source_op_getters_, this); + AddCheriotFDestGetters<DestOpEnum, Extractors>(dest_op_getters_, this); + // Verify that all source and destination op enum values have a getter. + for (int i = *SourceOpEnum::kNone; i < *SourceOpEnum::kPastMaxValue; ++i) { + if (source_op_getters_.find(i) == source_op_getters_.end()) { + LOG(ERROR) << "No getter for source op enum value " << i; + } + } + for (int i = *DestOpEnum::kNone; i < *DestOpEnum::kPastMaxValue; ++i) { + if (dest_op_getters_.find(i) == dest_op_getters_.end()) { + LOG(ERROR) << "No getter for destination op enum value " << i; + } + } +} + +// Parse the instruction word to determine the opcode. 
+void RiscVCheriotRVVFPEncoding::ParseInstruction(uint32_t inst_word) { + inst_word_ = inst_word; + if ((inst_word_ & 0x3) == 3) { + auto [opcode, format] = mpact::sim::cheriot::encoding_rvv_fp:: + DecodeRiscVCheriotRVVFPInst32WithFormat(inst_word_); + opcode_ = opcode; + format_ = format; + return; + } + + auto [opcode, format] = mpact::sim::cheriot::encoding_rvv_fp:: + DecodeRiscVCheriotRVVFPInst16WithFormat( + static_cast<uint16_t>(inst_word_ & 0xffff)); + opcode_ = opcode; + format_ = format; +} + +DestinationOperandInterface *RiscVCheriotRVVFPEncoding::GetDestination( + SlotEnum, int, OpcodeEnum opcode, DestOpEnum dest_op, int dest_no, + int latency) { + int index = static_cast<int>(dest_op); + auto iter = dest_op_getters_.find(index); + if (iter == dest_op_getters_.end()) { + LOG(ERROR) << absl::StrCat("No getter for destination op enum value ", + index, " for instruction ", + kOpcodeNames[static_cast<int>(opcode)]); + return nullptr; + } + return (iter->second)(latency); +} + +SourceOperandInterface *RiscVCheriotRVVFPEncoding::GetSource( + SlotEnum, int, OpcodeEnum opcode, SourceOpEnum source_op, int source_no) { + int index = static_cast<int>(source_op); + auto iter = source_op_getters_.find(index); + if (iter == source_op_getters_.end()) { + LOG(ERROR) << absl::StrCat("No getter for source op enum value ", index, + " for instruction ", + kOpcodeNames[static_cast<int>(opcode)]); + return nullptr; + } + return (iter->second)(); +} + +} // namespace isa32_rvv_fp +} // namespace cheriot +} // namespace sim +} // namespace mpact
diff --git a/cheriot/riscv_cheriot_rvv_fp_encoding.h b/cheriot/riscv_cheriot_rvv_fp_encoding.h new file mode 100644 index 0000000..7f52735 --- /dev/null +++ b/cheriot/riscv_cheriot_rvv_fp_encoding.h
@@ -0,0 +1,112 @@ +/* + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MPACT_CHERIOT_RISCV_CHERIOT_RVV_FP_ENCODING_H_ +#define MPACT_CHERIOT_RISCV_CHERIOT_RVV_FP_ENCODING_H_ + +#include <cstdint> + +#include "absl/container/flat_hash_map.h" +#include "absl/functional/any_invocable.h" +#include "cheriot/cheriot_state.h" +#include "cheriot/riscv_cheriot_encoding_common.h" +#include "cheriot/riscv_cheriot_rvv_fp_bin_decoder.h" +#include "cheriot/riscv_cheriot_rvv_fp_decoder.h" +#include "cheriot/riscv_cheriot_rvv_fp_enums.h" + +namespace mpact { +namespace sim { +namespace cheriot { +namespace isa32_rvv_fp { + +using ::mpact::sim::cheriot::encoding_rvv_fp::FormatEnum; + +// This class provides the interface between the generated instruction decoder +// framework (which is agnostic of the actual bit representation of +// instructions) and the instruction representation. It provides methods to +// return the opcode, source operands, and destination operands for +// instructions according to the operand fields in the encoding. +class RiscVCheriotRVVFPEncoding : public RiscVCheriotEncodingCommon, + public RiscVCheriotRVVFpEncodingBase { + public: + explicit RiscVCheriotRVVFPEncoding(CheriotState *state); + + // Parses an instruction and determines the opcode. + void ParseInstruction(uint32_t inst_word); + + // RiscV32 CHERIoT has a single slot type and single entry, so the following + // methods ignore those parameters. + + // Returns the opcode in the current instruction representation. + OpcodeEnum GetOpcode(SlotEnum, int) override { return opcode_; } + + // Returns the instruction format in the current instruction representation. + FormatEnum GetFormat(SlotEnum, int) { return format_; } + + // There is no predicate, so return nullptr. + PredicateOperandInterface *GetPredicate(SlotEnum, int, OpcodeEnum, + PredOpEnum) override { + return nullptr; + } + + // Currently no resources modeled for RiscV CHERIoT. + ResourceOperandInterface *GetSimpleResourceOperand( + SlotEnum, int, OpcodeEnum, SimpleResourceVector &resource_vec, + int end) override { + return nullptr; + } + + ResourceOperandInterface *GetComplexResourceOperand( + SlotEnum, int, OpcodeEnum, ComplexResourceEnum resource, int begin, + int end) override { + return nullptr; + } + + // The following method returns a source operand that corresponds to the + // particular operand field. + SourceOperandInterface *GetSource(SlotEnum, int, OpcodeEnum, SourceOpEnum op, + int source_no) override; + + // The following method returns a destination operand that corresponds to the + // particular operand field. + DestinationOperandInterface *GetDestination(SlotEnum, int, OpcodeEnum, + DestOpEnum op, int dest_no, + int latency) override; + // This method returns latency for any destination operand for which the + // latency specifier in the .isa file is '*'. Since there are none, just + // return 0. 
+ int GetLatency(SlotEnum, int, OpcodeEnum, DestOpEnum, int) override { + return 0; + } + + private: + using SourceOpGetterMap = + absl::flat_hash_map<int, absl::AnyInvocable<SourceOperandInterface *()>>; + using DestOpGetterMap = absl::flat_hash_map< + int, absl::AnyInvocable<DestinationOperandInterface *(int)>>; + + SourceOpGetterMap source_op_getters_; + DestOpGetterMap dest_op_getters_; + OpcodeEnum opcode_; + FormatEnum format_; +}; + +} // namespace isa32_rvv_fp +} // namespace cheriot +} // namespace sim +} // namespace mpact + +#endif // MPACT_CHERIOT_RISCV_CHERIOT_RVV_FP_ENCODING_H_
diff --git a/cheriot/riscv_cheriot_vector.bin_fmt b/cheriot/riscv_cheriot_vector.bin_fmt new file mode 100644 index 0000000..8970e9e --- /dev/null +++ b/cheriot/riscv_cheriot_vector.bin_fmt
@@ -0,0 +1,445 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Non-floating point vector instruction encodings. + +format VMem[32] : Inst32Format { + fields: + unsigned nf[3]; + unsigned mew[1]; + unsigned mop[2]; + unsigned vm[1]; + unsigned rs2[5]; + unsigned rs1[5]; + unsigned width[3]; + unsigned vd[5]; + unsigned opcode[7]; + overlays: + unsigned lumop[5] = rs2; + unsigned sumop[5] = rs2; + unsigned vs2[5] = rs2; + unsigned vs3[5] = vd; +}; + +format VArith[32] : Inst32Format { + fields: + unsigned func6[6]; + unsigned vm[1]; + unsigned vs2[5]; + unsigned vs1[5]; + unsigned func3[3]; + unsigned vd[5]; + unsigned opcode[7]; + overlays: + unsigned uimm5[5] = vs1; + signed simm5[5] = vs1; + unsigned rd[5] = vd; + unsigned rs1[5] = vs1; + unsigned vd_mask[5] = vd; +}; + +format VConfig[32] : Inst32Format { + fields: + unsigned top12[12]; + unsigned rs1[5]; + unsigned func3[3]; + unsigned rd[5]; + unsigned opcode[7]; + overlays: + signed zimm11[11] = top12[10..0]; + unsigned func1[1] = top12[11]; + unsigned func2[2] = top12[11..10]; + unsigned func7[7] = top12[11..5]; + signed zimm10[10] = top12[9..0]; + unsigned uimm5[5] = rs1; + unsigned rs2[5] = top12[4..0]; +}; + +instruction group RiscVVInst32[32] : Inst32Format { + //opcfg : VArith : func6 == 0bxxx'xxx, func3 == 0b111, opcode == 0b101'0111; + vsetvli_xn : VConfig : rs1 != 0, func1 == 0, func3 == 0b111, opcode == 0b101'0111; + vsetvli_nz : VConfig : rd != 0, rs1 == 0, func1 == 0, func3 == 0b111, opcode == 0b101'0111; + vsetvli_zz : VConfig : rd == 0, rs1 == 0, func1 == 0, func3 == 0b111, opcode == 0b101'0111; + vsetivli : VConfig : func2 == 0b11, func3 == 0b111, opcode == 0b101'0111; + vsetvl_xn : VConfig : rs1 != 0, func7 == 0b100'0000, func3 == 0b111, opcode == 0b101'0111; + vsetvl_nz : VConfig : rd != 0, rs1 == 0, func7 == 0b100'0000, func3 == 0b111, opcode == 0b101'0111; + vsetvl_zz : VConfig : rd == 0, rs1 == 0, func7 == 0b100'0000, func3 == 0b111, opcode == 0b101'0111; + + // Unit stride, masked (vm=0). + vle8 : VMem : vm == 0, nf == 0, mew == 0, mop == 0b00, lumop == 0b00000, width == 0b000, opcode == 0b000'0111; + vle16 : VMem : vm == 0, nf == 0, mew == 0, mop == 0b00, lumop == 0b00000, width == 0b101, opcode == 0b000'0111; + vle32 : VMem : vm == 0, nf == 0, mew == 0, mop == 0b00, lumop == 0b00000, width == 0b110, opcode == 0b000'0111; + vle64 : VMem : vm == 0, nf == 0, mew == 0, mop == 0b00, lumop == 0b00000, width == 0b111, opcode == 0b000'0111; + // Unit stride, unmasked (vm=1). + vle8_vm1 : VMem : vm == 1, nf == 0, mew == 0, mop == 0b00, lumop == 0b00000, width == 0b000, opcode == 0b000'0111; + vle16_vm1 : VMem : vm == 1, nf == 0, mew == 0, mop == 0b00, lumop == 0b00000, width == 0b101, opcode == 0b000'0111; + vle32_vm1 : VMem : vm == 1, nf == 0, mew == 0, mop == 0b00, lumop == 0b00000, width == 0b110, opcode == 0b000'0111; + vle64_vm1 : VMem : vm == 1, nf == 0, mew == 0, mop == 0b00, lumop == 0b00000, width == 0b111, opcode == 0b000'0111; + // Mask load. 
+ vlm : VMem : nf == 0, mew == 0, mop == 0b00, lumop == 0b01011, width == 0b000, opcode == 0b000'0111; + // Unit stride, fault first. + vle8ff : VMem : nf == 0, mew == 0, mop == 0b00, lumop == 0b10000, width == 0b000, opcode == 0b000'0111; + vle16ff : VMem : nf == 0, mew == 0, mop == 0b00, lumop == 0b10000, width == 0b101, opcode == 0b000'0111; + vle32ff : VMem : nf == 0, mew == 0, mop == 0b00, lumop == 0b10000, width == 0b110, opcode == 0b000'0111; + vle64ff : VMem : nf == 0, mew == 0, mop == 0b00, lumop == 0b10000, width == 0b111, opcode == 0b000'0111; + // Unit stride, whole register load. + vl1re8 : VMem : nf == 0, mop == 0b00, vm == 1, lumop == 0b01000, width == 0b000, opcode == 0b000'0111; + vl1re16 : VMem : nf == 0, mop == 0b00, vm == 1, lumop == 0b01000, width == 0b101, opcode == 0b000'0111; + vl1re32 : VMem : nf == 0, mop == 0b00, vm == 1, lumop == 0b01000, width == 0b110, opcode == 0b000'0111; + vl1re64 : VMem : nf == 0, mop == 0b00, vm == 1, lumop == 0b01000, width == 0b111, opcode == 0b000'0111; + vl2re8 : VMem : nf == 1, mop == 0b00, vm == 1, lumop == 0b01000, width == 0b000, opcode == 0b000'0111; + vl2re16 : VMem : nf == 1, mop == 0b00, vm == 1, lumop == 0b01000, width == 0b101, opcode == 0b000'0111; + vl2re32 : VMem : nf == 1, mop == 0b00, vm == 1, lumop == 0b01000, width == 0b110, opcode == 0b000'0111; + vl2re64 : VMem : nf == 1, mop == 0b00, vm == 1, lumop == 0b01000, width == 0b111, opcode == 0b000'0111; + vl4re8 : VMem : nf == 3, mop == 0b00, vm == 1, lumop == 0b01000, width == 0b000, opcode == 0b000'0111; + vl4re16 : VMem : nf == 3, mop == 0b00, vm == 1, lumop == 0b01000, width == 0b101, opcode == 0b000'0111; + vl4re32 : VMem : nf == 3, mop == 0b00, vm == 1, lumop == 0b01000, width == 0b110, opcode == 0b000'0111; + vl4re64 : VMem : nf == 3, mop == 0b00, vm == 1, lumop == 0b01000, width == 0b111, opcode == 0b000'0111; + vl8re8 : VMem : nf == 7, mop == 0b00, vm == 1, lumop == 0b01000, width == 0b000, opcode == 0b000'0111; + vl8re16 : VMem : nf == 7, mop == 0b00, vm == 1, lumop == 0b01000, width == 0b101, opcode == 0b000'0111; + vl8re32 : VMem : nf == 7, mop == 0b00, vm == 1, lumop == 0b01000, width == 0b110, opcode == 0b000'0111; + vl8re64 : VMem : nf == 7, mop == 0b00, vm == 1, lumop == 0b01000, width == 0b111, opcode == 0b000'0111; + // Vector load strided. + vlse8 : VMem : nf == 0, mew == 0, mop == 0b10, width == 0b000, opcode == 0b000'0111; + vlse16 : VMem : nf == 0, mew == 0, mop == 0b10, width == 0b101, opcode == 0b000'0111; + vlse32 : VMem : nf == 0, mew == 0, mop == 0b10, width == 0b110, opcode == 0b000'0111; + vlse64 : VMem : nf == 0, mew == 0, mop == 0b10, width == 0b111, opcode == 0b000'0111; + // Vector load indexed, unordered. + vluxei8 : VMem : nf == 0, mew == 0, mop == 0b01, width == 0b000, opcode == 0b000'0111; + vluxei16: VMem : nf == 0, mew == 0, mop == 0b01, width == 0b101, opcode == 0b000'0111; + vluxei32: VMem : nf == 0, mew == 0, mop == 0b01, width == 0b110, opcode == 0b000'0111; + vluxei64: VMem : nf == 0, mew == 0, mop == 0b01, width == 0b111, opcode == 0b000'0111; + // Vector load indexed, ordered. + vloxei8 : VMem : nf == 0, mew == 0, mop == 0b11, width == 0b000, opcode == 0b000'0111; + vloxei16: VMem : nf == 0, mew == 0, mop == 0b11, width == 0b101, opcode == 0b000'0111; + vloxei32: VMem : nf == 0, mew == 0, mop == 0b11, width == 0b110, opcode == 0b000'0111; + vloxei64: VMem : nf == 0, mew == 0, mop == 0b11, width == 0b111, opcode == 0b000'0111; + // Vector segment load, unit stride. 
+ vlsege8: VMem : nf != 0, mew == 0, mop == 0b00, lumop == 0b00000, width == 0b000, opcode == 0b000'0111; + vlsege16: VMem : nf != 0, mew == 0, mop == 0b00, lumop == 0b00000, width == 0b101, opcode == 0b000'0111; + vlsege32: VMem : nf != 0, mew == 0, mop == 0b00, lumop == 0b00000, width == 0b110, opcode == 0b000'0111; + vlsege64: VMem : nf != 0, mew == 0, mop == 0b00, lumop == 0b00000, width == 0b111, opcode == 0b000'0111; + // Vector segment load, strided. + vlssege8: VMem : nf != 0, mew == 0, mop == 0b10, width == 0b000, opcode == 0b000'0111; + vlssege16: VMem : nf != 0, mew == 0, mop == 0b10, width == 0b101, opcode == 0b000'0111; + vlssege32: VMem : nf != 0, mew == 0, mop == 0b10, width == 0b110, opcode == 0b000'0111; + vlssege64: VMem : nf != 0, mew == 0, mop == 0b10, width == 0b111, opcode == 0b000'0111; + // Vector segment load, indexed, unordered. + vluxsegei8: VMem : nf != 0, mew == 0, mop == 0b01, width == 0b000, opcode == 0b000'0111; + vluxsegei16: VMem : nf != 0, mew == 0, mop == 0b01, width == 0b101, opcode == 0b000'0111; + vluxsegei32: VMem : nf != 0, mew == 0, mop == 0b01, width == 0b110, opcode == 0b000'0111; + vluxsegei64: VMem : nf != 0, mew == 0, mop == 0b01, width == 0b111, opcode == 0b000'0111; + // Vector segment load, indexed, ordered. + vloxsegei8: VMem : nf != 0, mew == 0, mop == 0b11, width == 0b000, opcode == 0b000'0111; + vloxsegei16: VMem : nf != 0, mew == 0, mop == 0b11, width == 0b101, opcode == 0b000'0111; + vloxsegei32: VMem : nf != 0, mew == 0, mop == 0b11, width == 0b110, opcode == 0b000'0111; + vloxsegei64: VMem : nf != 0, mew == 0, mop == 0b11, width == 0b111, opcode == 0b000'0111; + + + // VECTOR STORES + + // Unit stride. + vse8 : VMem : mew == 0, mop == 0b00, sumop == 0b00000, width == 0b000, opcode == 0b010'0111; + vse16 : VMem : mew == 0, mop == 0b00, sumop == 0b00000, width == 0b101, opcode == 0b010'0111; + vse32 : VMem : mew == 0, mop == 0b00, sumop == 0b00000, width == 0b110, opcode == 0b010'0111; + vse64 : VMem : mew == 0, mop == 0b00, sumop == 0b00000, width == 0b111, opcode == 0b010'0111; + // Mask store. + vsm : VMem : mew == 0, mop == 0b00, sumop == 0b01011, width == 0b000, opcode == 0b010'0111; + // Unit stride, fault first. + vse8ff : VMem : mew == 0, mop == 0b00, sumop == 0b10000, width == 0b000, opcode == 0b010'0111; + vse16ff : VMem : mew == 0, mop == 0b00, sumop == 0b10000, width == 0b101, opcode == 0b010'0111; + vse32ff : VMem : mew == 0, mop == 0b00, sumop == 0b10000, width == 0b110, opcode == 0b010'0111; + vse64ff : VMem : mew == 0, mop == 0b00, sumop == 0b10000, width == 0b111, opcode == 0b010'0111; + // Unit stride, whole register store. 
+ vs1re8 : VMem : nf == 0, mop == 0b00, vm == 1, sumop == 0b01000, width == 0b000, opcode == 0b010'0111; + vs1re16 : VMem : nf == 0, mop == 0b00, vm == 1, sumop == 0b01000, width == 0b101, opcode == 0b010'0111; + vs1re32 : VMem : nf == 0, mop == 0b00, vm == 1, sumop == 0b01000, width == 0b110, opcode == 0b010'0111; + vs1re64 : VMem : nf == 0, mop == 0b00, vm == 1, sumop == 0b01000, width == 0b111, opcode == 0b010'0111; + vs2re8 : VMem : nf == 1, mop == 0b00, vm == 1, sumop == 0b01000, width == 0b000, opcode == 0b010'0111; + vs2re16 : VMem : nf == 1, mop == 0b00, vm == 1, sumop == 0b01000, width == 0b101, opcode == 0b010'0111; + vs2re32 : VMem : nf == 1, mop == 0b00, vm == 1, sumop == 0b01000, width == 0b110, opcode == 0b010'0111; + vs2re64 : VMem : nf == 1, mop == 0b00, vm == 1, sumop == 0b01000, width == 0b111, opcode == 0b010'0111; + vs4re8 : VMem : nf == 3, mop == 0b00, vm == 1, sumop == 0b01000, width == 0b000, opcode == 0b010'0111; + vs4re16 : VMem : nf == 3, mop == 0b00, vm == 1, sumop == 0b01000, width == 0b101, opcode == 0b010'0111; + vs4re32 : VMem : nf == 3, mop == 0b00, vm == 1, sumop == 0b01000, width == 0b110, opcode == 0b010'0111; + vs4re64 : VMem : nf == 3, mop == 0b00, vm == 1, sumop == 0b01000, width == 0b111, opcode == 0b010'0111; + vs8re8 : VMem : nf == 7, mop == 0b00, vm == 1, sumop == 0b01000, width == 0b000, opcode == 0b010'0111; + vs8re16 : VMem : nf == 7, mop == 0b00, vm == 1, sumop == 0b01000, width == 0b101, opcode == 0b010'0111; + vs8re32 : VMem : nf == 7, mop == 0b00, vm == 1, sumop == 0b01000, width == 0b110, opcode == 0b010'0111; + vs8re64 : VMem : nf == 7, mop == 0b00, vm == 1, sumop == 0b01000, width == 0b111, opcode == 0b010'0111; + // Store strided. + vsse8 : VMem : mew == 0, mop == 0b10, width == 0b000, opcode == 0b010'0111; + vsse16 : VMem : mew == 0, mop == 0b10, width == 0b101, opcode == 0b010'0111; + vsse32 : VMem : mew == 0, mop == 0b10, width == 0b110, opcode == 0b010'0111; + vsse64 : VMem : mew == 0, mop == 0b10, width == 0b111, opcode == 0b010'0111; + // Store indexed, unordered. + vsuxei8 : VMem : mew == 0, mop == 0b01, width == 0b000, opcode == 0b010'0111; + vsuxei16: VMem : mew == 0, mop == 0b01, width == 0b101, opcode == 0b010'0111; + vsuxei32: VMem : mew == 0, mop == 0b01, width == 0b110, opcode == 0b010'0111; + vsuxei64: VMem : mew == 0, mop == 0b01, width == 0b111, opcode == 0b010'0111; + // Store indexed, ordered. + vsoxei8 : VMem : mew == 0, mop == 0b11, width == 0b000, opcode == 0b010'0111; + vsoxei16: VMem : mew == 0, mop == 0b11, width == 0b101, opcode == 0b010'0111; + vsoxei32: VMem : mew == 0, mop == 0b11, width == 0b110, opcode == 0b010'0111; + vsoxei64: VMem : mew == 0, mop == 0b11, width == 0b111, opcode == 0b010'0111; + // Vector segment store, unit stride. + vssege8: VMem : nf != 0, mew == 0, mop == 0b00, sumop == 0b00000, width == 0b000, opcode == 0b010'0111; + vssege16: VMem : nf != 0, mew == 0, mop == 0b00, sumop == 0b00000, width == 0b101, opcode == 0b010'0111; + vssege32: VMem : nf != 0, mew == 0, mop == 0b00, sumop == 0b00000, width == 0b110, opcode == 0b010'0111; + vssege64: VMem : nf != 0, mew == 0, mop == 0b00, sumop == 0b00000, width == 0b111, opcode == 0b010'0111; + // Vector segment store, strided. 
+ vsssege8: VMem : nf != 0, mew == 0, mop == 0b10, width == 0b000, opcode == 0b010'0111; + vsssege16: VMem : nf != 0, mew == 0, mop == 0b10, width == 0b101, opcode == 0b010'0111; + vsssege32: VMem : nf != 0, mew == 0, mop == 0b10, width == 0b110, opcode == 0b010'0111; + vsssege64: VMem : nf != 0, mew == 0, mop == 0b10, width == 0b111, opcode == 0b010'0111; + // Vector segment store, indexed, unordered. + vsuxsegei8: VMem : nf != 0, mew == 0, mop == 0b01, width == 0b000, opcode == 0b010'0111; + vsuxsegei16: VMem : nf != 0, mew == 0, mop == 0b01, width == 0b101, opcode == 0b010'0111; + vsuxsegei32: VMem : nf != 0, mew == 0, mop == 0b01, width == 0b110, opcode == 0b010'0111; + vsuxsegei64: VMem : nf != 0, mew == 0, mop == 0b01, width == 0b111, opcode == 0b010'0111; + // Vector segment store, indexed, ordered. + vsoxsegei8: VMem : nf != 0, mew == 0, mop == 0b11, width == 0b000, opcode == 0b010'0111; + vsoxsegei16: VMem : nf != 0, mew == 0, mop == 0b11, width == 0b101, opcode == 0b010'0111; + vsoxsegei32: VMem : nf != 0, mew == 0, mop == 0b11, width == 0b110, opcode == 0b010'0111; + vsoxsegei64: VMem : nf != 0, mew == 0, mop == 0b11, width == 0b111, opcode == 0b010'0111; + + // Integer: OPIVV, OPIVX, OPIVI + //opivv : VArith : func6 == 0bxxx'xxx, func3 == 0b000, opcode == 0b101'0111; + //opivx : VArith : func6 == 0bxxx'xxx, func3 == 0b100, opcode == 0b101'0111; + //opivi : VArith : func6 == 0bxxx'xxx, func3 == 0b011, opcode == 0b101'0111; + + vadd_vv : VArith : func6 == 0b000'000, func3 == 0b000, opcode == 0b101'0111; + vadd_vx : VArith : func6 == 0b000'000, func3 == 0b100, opcode == 0b101'0111; + vadd_vi : VArith : func6 == 0b000'000, func3 == 0b011, opcode == 0b101'0111; + vsub_vv : VArith : func6 == 0b000'010, func3 == 0b000, opcode == 0b101'0111; + vsub_vx : VArith : func6 == 0b000'010, func3 == 0b100, opcode == 0b101'0111; + vrsub_vx : VArith : func6 == 0b000'011, func3 == 0b100, opcode == 0b101'0111; + vrsub_vi : VArith : func6 == 0b000'011, func3 == 0b011, opcode == 0b101'0111; + vminu_vv : VArith : func6 == 0b000'100, func3 == 0b000, opcode == 0b101'0111; + vminu_vx : VArith : func6 == 0b000'100, func3 == 0b100, opcode == 0b101'0111; + vmin_vv : VArith : func6 == 0b000'101, func3 == 0b000, opcode == 0b101'0111; + vmin_vx : VArith : func6 == 0b000'101, func3 == 0b100, opcode == 0b101'0111; + vmaxu_vv : VArith : func6 == 0b000'110, func3 == 0b000, opcode == 0b101'0111; + vmaxu_vx : VArith : func6 == 0b000'110, func3 == 0b100, opcode == 0b101'0111; + vmax_vv : VArith : func6 == 0b000'111, func3 == 0b000, opcode == 0b101'0111; + vmax_vx : VArith : func6 == 0b000'111, func3 == 0b100, opcode == 0b101'0111; + vand_vv : VArith : func6 == 0b001'001, func3 == 0b000, opcode == 0b101'0111; + vand_vx : VArith : func6 == 0b001'001, func3 == 0b100, opcode == 0b101'0111; + vand_vi : VArith : func6 == 0b001'001, func3 == 0b011, opcode == 0b101'0111; + vor_vv : VArith : func6 == 0b001'010, func3 == 0b000, opcode == 0b101'0111; + vor_vx : VArith : func6 == 0b001'010, func3 == 0b100, opcode == 0b101'0111; + vor_vi : VArith : func6 == 0b001'010, func3 == 0b011, opcode == 0b101'0111; + vxor_vv : VArith : func6 == 0b001'011, func3 == 0b000, opcode == 0b101'0111; + vxor_vx : VArith : func6 == 0b001'011, func3 == 0b100, opcode == 0b101'0111; + vxor_vi : VArith : func6 == 0b001'011, func3 == 0b011, opcode == 0b101'0111; + vrgather_vv : VArith : func6 == 0b001'100, func3 == 0b000, opcode == 0b101'0111; + vrgather_vx : VArith : func6 == 0b001'100, func3 == 0b100, opcode == 0b101'0111; + vrgather_vi : VArith : 
func6 == 0b001'100, func3 == 0b011, opcode == 0b101'0111; + vslideup_vx : VArith : func6 == 0b001'110, func3 == 0b100, opcode == 0b101'0111; + vslideup_vi : VArith : func6 == 0b001'110, func3 == 0b011, opcode == 0b101'0111; + vrgatherei16_vv : VArith : func6 == 0b001'110, func3 == 0b000, opcode == 0b101'0111; + vslidedown_vx : VArith : func6 == 0b001'111, func3 == 0b100, opcode == 0b101'0111; + vslidedown_vi : VArith : func6 == 0b001'111, func3 == 0b011, opcode == 0b101'0111; + vadc_vv : VArith : func6 == 0b010'000, vd != 0, vm == 0, func3 == 0b000, opcode == 0b101'0111; + vadc_vx : VArith : func6 == 0b010'000, vd != 0, vm == 0, func3 == 0b100, opcode == 0b101'0111; + vadc_vi : VArith : func6 == 0b010'000, vd != 0, vm == 0, func3 == 0b011, opcode == 0b101'0111; + vmadc_vv : VArith : func6 == 0b010'001, func3 == 0b000, opcode == 0b101'0111; + vmadc_vx : VArith : func6 == 0b010'001, func3 == 0b100, opcode == 0b101'0111; + vmadc_vi : VArith : func6 == 0b010'001, func3 == 0b011, opcode == 0b101'0111; + vsbc_vv : VArith : func6 == 0b010'010, vd != 0, vm == 0, func3 == 0b000, opcode == 0b101'0111; + vsbc_vx : VArith : func6 == 0b010'010, vd != 0, vm == 0, func3 == 0b100, opcode == 0b101'0111; + vmsbc_vv : VArith : func6 == 0b010'011, func3 == 0b000, opcode == 0b101'0111; + vmsbc_vx : VArith : func6 == 0b010'011, func3 == 0b100, opcode == 0b101'0111; + vmerge_vv : VArith : func6 == 0b010'111, vm == 0, func3 == 0b000, opcode == 0b101'0111; + vmerge_vx : VArith : func6 == 0b010'111, vm == 0, func3 == 0b100, opcode == 0b101'0111; + vmerge_vi : VArith : func6 == 0b010'111, vm == 0, func3 == 0b011, opcode == 0b101'0111; + vmv_vv : VArith : func6 == 0b010'111, vm == 1, vs2 == 0, func3 == 0b000, opcode == 0b101'0111; + vmv_vx : VArith : func6 == 0b010'111, vm == 1, vs2 == 0, func3 == 0b100, opcode == 0b101'0111; + vmv_vi : VArith : func6 == 0b010'111, vm == 1, vs2 == 0, func3 == 0b011, opcode == 0b101'0111; + vmseq_vv : VArith : func6 == 0b011'000, func3 == 0b000, opcode == 0b101'0111; + vmseq_vx : VArith : func6 == 0b011'000, func3 == 0b100, opcode == 0b101'0111; + vmseq_vi : VArith : func6 == 0b011'000, func3 == 0b011, opcode == 0b101'0111; + vmsne_vv : VArith : func6 == 0b011'001, func3 == 0b000, opcode == 0b101'0111; + vmsne_vx : VArith : func6 == 0b011'001, func3 == 0b100, opcode == 0b101'0111; + vmsne_vi : VArith : func6 == 0b011'001, func3 == 0b011, opcode == 0b101'0111; + vmsltu_vv : VArith : func6 == 0b011'010, func3 == 0b000, opcode == 0b101'0111; + vmsltu_vx : VArith : func6 == 0b011'010, func3 == 0b100, opcode == 0b101'0111; + vmslt_vv : VArith : func6 == 0b011'011, func3 == 0b000, opcode == 0b101'0111; + vmslt_vx : VArith : func6 == 0b011'011, func3 == 0b100, opcode == 0b101'0111; + vmsleu_vv : VArith : func6 == 0b011'100, func3 == 0b000, opcode == 0b101'0111; + vmsleu_vx : VArith : func6 == 0b011'100, func3 == 0b100, opcode == 0b101'0111; + vmsleu_vi : VArith : func6 == 0b011'100, func3 == 0b011, opcode == 0b101'0111; + vmsle_vv : VArith : func6 == 0b011'101, func3 == 0b000, opcode == 0b101'0111; + vmsle_vx : VArith : func6 == 0b011'101, func3 == 0b100, opcode == 0b101'0111; + vmsle_vi : VArith : func6 == 0b011'101, func3 == 0b011, opcode == 0b101'0111; + vmsgtu_vx : VArith : func6 == 0b011'110, func3 == 0b100, opcode == 0b101'0111; + vmsgtu_vi : VArith : func6 == 0b011'110, func3 == 0b011, opcode == 0b101'0111; + vmsgt_vx : VArith : func6 == 0b011'111, func3 == 0b100, opcode == 0b101'0111; + vmsgt_vi : VArith : func6 == 0b011'111, func3 == 0b011, opcode == 0b101'0111; + vsaddu_vv : VArith 
: func6 == 0b100'000, func3 == 0b000, opcode == 0b101'0111; + vsaddu_vx : VArith : func6 == 0b100'000, func3 == 0b100, opcode == 0b101'0111; + vsaddu_vi : VArith : func6 == 0b100'000, func3 == 0b011, opcode == 0b101'0111; + vsadd_vv : VArith : func6 == 0b100'001, func3 == 0b000, opcode == 0b101'0111; + vsadd_vx : VArith : func6 == 0b100'001, func3 == 0b100, opcode == 0b101'0111; + vsadd_vi : VArith : func6 == 0b100'001, func3 == 0b011, opcode == 0b101'0111; + vssubu_vv : VArith : func6 == 0b100'010, func3 == 0b000, opcode == 0b101'0111; + vssubu_vx : VArith : func6 == 0b100'010, func3 == 0b100, opcode == 0b101'0111; + vssub_vv : VArith : func6 == 0b100'011, func3 == 0b000, opcode == 0b101'0111; + vssub_vx : VArith : func6 == 0b100'011, func3 == 0b100, opcode == 0b101'0111; + vsll_vv : VArith : func6 == 0b100'101, func3 == 0b000, opcode == 0b101'0111; + vsll_vx : VArith : func6 == 0b100'101, func3 == 0b100, opcode == 0b101'0111; + vsll_vi : VArith : func6 == 0b100'101, func3 == 0b011, opcode == 0b101'0111; + vsmul_vv : VArith : func6 == 0b100'111, func3 == 0b000, opcode == 0b101'0111; + vsmul_vx : VArith : func6 == 0b100'111, func3 == 0b100, opcode == 0b101'0111; + vmv1r_vi : VArith : func6 == 0b100'111, uimm5 == 0, func3 == 0b011, opcode == 0b101'0111; + vmv2r_vi : VArith : func6 == 0b100'111, uimm5 == 1, func3 == 0b011, opcode == 0b101'0111; + vmv4r_vi : VArith : func6 == 0b100'111, uimm5 == 3, func3 == 0b011, opcode == 0b101'0111; + vmv8r_vi : VArith : func6 == 0b100'111, uimm5 == 7, func3 == 0b011, opcode == 0b101'0111; + vsrl_vv : VArith : func6 == 0b101'000, func3 == 0b000, opcode == 0b101'0111; + vsrl_vx : VArith : func6 == 0b101'000, func3 == 0b100, opcode == 0b101'0111; + vsrl_vi : VArith : func6 == 0b101'000, func3 == 0b011, opcode == 0b101'0111; + vsra_vv : VArith : func6 == 0b101'001, func3 == 0b000, opcode == 0b101'0111; + vsra_vx : VArith : func6 == 0b101'001, func3 == 0b100, opcode == 0b101'0111; + vsra_vi : VArith : func6 == 0b101'001, func3 == 0b011, opcode == 0b101'0111; + vssrl_vv : VArith : func6 == 0b101'010, func3 == 0b000, opcode == 0b101'0111; + vssrl_vx : VArith : func6 == 0b101'010, func3 == 0b100, opcode == 0b101'0111; + vssrl_vi : VArith : func6 == 0b101'010, func3 == 0b011, opcode == 0b101'0111; + vssra_vv : VArith : func6 == 0b101'011, func3 == 0b000, opcode == 0b101'0111; + vssra_vx : VArith : func6 == 0b101'011, func3 == 0b100, opcode == 0b101'0111; + vssra_vi : VArith : func6 == 0b101'011, func3 == 0b011, opcode == 0b101'0111; + vnsrl_vv : VArith : func6 == 0b101'100, func3 == 0b000, opcode == 0b101'0111; + vnsrl_vx : VArith : func6 == 0b101'100, func3 == 0b100, opcode == 0b101'0111; + vnsrl_vi : VArith : func6 == 0b101'100, func3 == 0b011, opcode == 0b101'0111; + vnsra_vv : VArith : func6 == 0b101'101, func3 == 0b000, opcode == 0b101'0111; + vnsra_vx : VArith : func6 == 0b101'101, func3 == 0b100, opcode == 0b101'0111; + vnsra_vi : VArith : func6 == 0b101'101, func3 == 0b011, opcode == 0b101'0111; + vnclipu_vv : VArith : func6 == 0b101'110, func3 == 0b000, opcode == 0b101'0111; + vnclipu_vx : VArith : func6 == 0b101'110, func3 == 0b100, opcode == 0b101'0111; + vnclipu_vi : VArith : func6 == 0b101'110, func3 == 0b011, opcode == 0b101'0111; + vnclip_vv : VArith : func6 == 0b101'111, func3 == 0b000, opcode == 0b101'0111; + vnclip_vx : VArith : func6 == 0b101'111, func3 == 0b100, opcode == 0b101'0111; + vnclip_vi : VArith : func6 == 0b101'111, func3 == 0b011, opcode == 0b101'0111; + vwredsumu_vv : VArith : func6 == 0b110'000, func3 == 0b000, opcode == 
0b101'0111; + vwredsum_vv : VArith : func6 == 0b110'001, func3 == 0b000, opcode == 0b101'0111; + + // Integer: OPMVV, OPMVX + //opmvv : VArith : func6 == 0bxxx'xxx, func3 == 0b010, opcode == 0b101'0111; + //opmvx : VArith : func6 == 0bxxx'xxx, func3 == 0b110, opcode == 0b101'0111; + + vredsum_vv : VArith : func6 == 0b000'000, func3 == 0b010, opcode == 0b101'0111; + vredand_vv : VArith : func6 == 0b000'001, func3 == 0b010, opcode == 0b101'0111; + vredor_vv : VArith : func6 == 0b000'010, func3 == 0b010, opcode == 0b101'0111; + vredxor_vv : VArith : func6 == 0b000'011, func3 == 0b010, opcode == 0b101'0111; + vredminu_vv : VArith : func6 == 0b000'100, func3 == 0b010, opcode == 0b101'0111; + vredmin_vv : VArith : func6 == 0b000'101, func3 == 0b010, opcode == 0b101'0111; + vredmaxu_vv : VArith : func6 == 0b000'110, func3 == 0b010, opcode == 0b101'0111; + vredmax_vv : VArith : func6 == 0b000'111, func3 == 0b010, opcode == 0b101'0111; + vaaddu_vv : VArith : func6 == 0b001'000, func3 == 0b010, opcode == 0b101'0111; + vaaddu_vx : VArith : func6 == 0b001'000, func3 == 0b110, opcode == 0b101'0111; + vaadd_vv : VArith : func6 == 0b001'001, func3 == 0b010, opcode == 0b101'0111; + vaadd_vx : VArith : func6 == 0b001'001, func3 == 0b110, opcode == 0b101'0111; + vasubu_vv : VArith : func6 == 0b001'010, func3 == 0b010, opcode == 0b101'0111; + vasubu_vx : VArith : func6 == 0b001'010, func3 == 0b110, opcode == 0b101'0111; + vasub_vv : VArith : func6 == 0b001'011, func3 == 0b010, opcode == 0b101'0111; + vasub_vx : VArith : func6 == 0b001'011, func3 == 0b110, opcode == 0b101'0111; + vslide1up_vx : VArith : func6 == 0b001'110, func3 == 0b110, opcode == 0b101'0111; + vslide1down_vx : VArith : func6 == 0b001'111, func3 == 0b110, opcode == 0b101'0111; + vcompress_vv : VArith : func6 == 0b010'111, func3 == 0b010, opcode == 0b101'0111; + vmandnot_vv : VArith : func6 == 0b011'000, func3 == 0b010, opcode == 0b101'0111; + vmand_vv : VArith : func6 == 0b011'001, func3 == 0b010, opcode == 0b101'0111; + vmor_vv : VArith : func6 == 0b011'010, func3 == 0b010, opcode == 0b101'0111; + vmxor_vv : VArith : func6 == 0b011'011, func3 == 0b010, opcode == 0b101'0111; + vmornot_vv : VArith : func6 == 0b011'100, func3 == 0b010, opcode == 0b101'0111; + vmnand_vv : VArith : func6 == 0b011'101, func3 == 0b010, opcode == 0b101'0111; + vmnor_vv : VArith : func6 == 0b011'110, func3 == 0b010, opcode == 0b101'0111; + vmxnor_vv : VArith : func6 == 0b011'111, func3 == 0b010, opcode == 0b101'0111; + + vdivu_vv : VArith : func6 == 0b100'000, func3 == 0b010, opcode == 0b101'0111; + vdivu_vx : VArith : func6 == 0b100'000, func3 == 0b110, opcode == 0b101'0111; + vdiv_vv : VArith : func6 == 0b100'001, func3 == 0b010, opcode == 0b101'0111; + vdiv_vx : VArith : func6 == 0b100'001, func3 == 0b110, opcode == 0b101'0111; + vremu_vv : VArith : func6 == 0b100'010, func3 == 0b010, opcode == 0b101'0111; + vremu_vx : VArith : func6 == 0b100'010, func3 == 0b110, opcode == 0b101'0111; + vrem_vv : VArith : func6 == 0b100'011, func3 == 0b010, opcode == 0b101'0111; + vrem_vx : VArith : func6 == 0b100'011, func3 == 0b110, opcode == 0b101'0111; + vmulhu_vv : VArith : func6 == 0b100'100, func3 == 0b010, opcode == 0b101'0111; + vmulhu_vx : VArith : func6 == 0b100'100, func3 == 0b110, opcode == 0b101'0111; + vmul_vv : VArith : func6 == 0b100'101, func3 == 0b010, opcode == 0b101'0111; + vmul_vx : VArith : func6 == 0b100'101, func3 == 0b110, opcode == 0b101'0111; + vmulhsu_vv : VArith : func6 == 0b100'110, func3 == 0b010, opcode == 0b101'0111; + vmulhsu_vx : VArith : func6 
== 0b100'110, func3 == 0b110, opcode == 0b101'0111; + vmulh_vv : VArith : func6 == 0b100'111, func3 == 0b010, opcode == 0b101'0111; + vmulh_vx : VArith : func6 == 0b100'111, func3 == 0b110, opcode == 0b101'0111; + vmadd_vv : VArith : func6 == 0b101'001, func3 == 0b010, opcode == 0b101'0111; + vmadd_vx : VArith : func6 == 0b101'001, func3 == 0b110, opcode == 0b101'0111; + vnmsub_vv : VArith : func6 == 0b101'011, func3 == 0b010, opcode == 0b101'0111; + vnmsub_vx : VArith : func6 == 0b101'011, func3 == 0b110, opcode == 0b101'0111; + vmacc_vv : VArith : func6 == 0b101'101, func3 == 0b010, opcode == 0b101'0111; + vmacc_vx : VArith : func6 == 0b101'101, func3 == 0b110, opcode == 0b101'0111; + vnmsac_vv : VArith : func6 == 0b101'111, func3 == 0b010, opcode == 0b101'0111; + vnmsac_vx : VArith : func6 == 0b101'111, func3 == 0b110, opcode == 0b101'0111; + vwaddu_vv : VArith : func6 == 0b110'000, func3 == 0b010, opcode == 0b101'0111; + vwaddu_vx : VArith : func6 == 0b110'000, func3 == 0b110, opcode == 0b101'0111; + vwadd_vv : VArith : func6 == 0b110'001, func3 == 0b010, opcode == 0b101'0111; + vwadd_vx : VArith : func6 == 0b110'001, func3 == 0b110, opcode == 0b101'0111; + vwsubu_vv : VArith : func6 == 0b110'010, func3 == 0b010, opcode == 0b101'0111; + vwsubu_vx : VArith : func6 == 0b110'010, func3 == 0b110, opcode == 0b101'0111; + vwsub_vv : VArith : func6 == 0b110'011, func3 == 0b010, opcode == 0b101'0111; + vwsub_vx : VArith : func6 == 0b110'011, func3 == 0b110, opcode == 0b101'0111; + vwaddu_w_vv : VArith : func6 == 0b110'100, func3 == 0b010, opcode == 0b101'0111; + vwaddu_w_vx : VArith : func6 == 0b110'100, func3 == 0b110, opcode == 0b101'0111; + vwadd_w_vv : VArith : func6 == 0b110'101, func3 == 0b010, opcode == 0b101'0111; + vwadd_w_vx : VArith : func6 == 0b110'101, func3 == 0b110, opcode == 0b101'0111; + vwsubu_w_vv : VArith : func6 == 0b110'110, func3 == 0b010, opcode == 0b101'0111; + vwsubu_w_vx : VArith : func6 == 0b110'110, func3 == 0b110, opcode == 0b101'0111; + vwsub_w_vv : VArith : func6 == 0b110'111, func3 == 0b010, opcode == 0b101'0111; + vwsub_w_vx : VArith : func6 == 0b110'111, func3 == 0b110, opcode == 0b101'0111; + vwmulu_vv : VArith : func6 == 0b111'000, func3 == 0b010, opcode == 0b101'0111; + vwmulu_vx : VArith : func6 == 0b111'000, func3 == 0b110, opcode == 0b101'0111; + vwmulsu_vv : VArith : func6 == 0b111'010, func3 == 0b010, opcode == 0b101'0111; + vwmulsu_vx : VArith : func6 == 0b111'010, func3 == 0b110, opcode == 0b101'0111; + vwmul_vv : VArith : func6 == 0b111'011, func3 == 0b010, opcode == 0b101'0111; + vwmul_vx : VArith : func6 == 0b111'011, func3 == 0b110, opcode == 0b101'0111; + vwmaccu_vv : VArith : func6 == 0b111'100, func3 == 0b010, opcode == 0b101'0111; + vwmaccu_vx : VArith : func6 == 0b111'100, func3 == 0b110, opcode == 0b101'0111; + vwmacc_vv : VArith : func6 == 0b111'101, func3 == 0b010, opcode == 0b101'0111; + vwmacc_vx : VArith : func6 == 0b111'101, func3 == 0b110, opcode == 0b101'0111; + vwmaccus_vv : VArith : func6 == 0b111'110, func3 == 0b010, opcode == 0b101'0111; + vwmaccus_vx : VArith : func6 == 0b111'110, func3 == 0b110, opcode == 0b101'0111; + vwmaccsu_vv : VArith : func6 == 0b111'111, func3 == 0b010, opcode == 0b101'0111; + vwmaccsu_vx : VArith : func6 == 0b111'111, func3 == 0b110, opcode == 0b101'0111; + + // VWXUNARY0 vv: VArith : func6 == 0b010'000, func3 == 0b010, opcode == 0b101'0111; + vmv_x_s : VArith : func6 == 0b010'000, vs1 == 0b00000, func3 == 0b010, opcode == 0b101'0111; + vcpop : VArith : func6 == 0b010'000, vs1 == 0b10000, func3 == 
0b010, opcode == 0b101'0111; + vfirst : VArith : func6 == 0b010'000, vs1 == 0b10001, func3 == 0b010, opcode == 0b101'0111; + + // VRXUNARY0 vx: VArith : func6 == 0b010'000, func3 == 0b110, opcode == 0b101'0111; + vmv_s_x : VArith : func6 == 0b010'000, vs2 == 0, func3 == 0b110, opcode == 0b101'0111; + + // VXUNARY0 vv : VArith : func6 == 0b010'010, func3 == 0b010, opcode == 0b101'0111; + vzext_vf8: VArith : func6 == 0b010'010, vs1 == 0b00010, func3 == 0b010, opcode == 0b101'0111; + vsext_vf8: VArith : func6 == 0b010'010, vs1 == 0b00011, func3 == 0b010, opcode == 0b101'0111; + vzext_vf4: VArith : func6 == 0b010'010, vs1 == 0b00100, func3 == 0b010, opcode == 0b101'0111; + vsext_vf4: VArith : func6 == 0b010'010, vs1 == 0b00101, func3 == 0b010, opcode == 0b101'0111; + vzext_vf2: VArith : func6 == 0b010'010, vs1 == 0b00110, func3 == 0b010, opcode == 0b101'0111; + vsext_vf2: VArith : func6 == 0b010'010, vs1 == 0b00111, func3 == 0b010, opcode == 0b101'0111; + + // VMUNARY vv : VArith : func6 == 0b010'100, func3 == 0b010, opcode == 0b101'0111; + vmsbf : VArith : func6 == 0b010'100, vs1 == 0b00001, func3 == 0b010, opcode == 0b101'0111; + vmsof : VArith : func6 == 0b010'100, vs1 == 0b00010, func3 == 0b010, opcode == 0b101'0111; + vmsif : VArith : func6 == 0b010'100, vs1 == 0b00011, func3 == 0b010, opcode == 0b101'0111; + viota : VArith : func6 == 0b010'100, vs1 == 0b10000, func3 == 0b010, opcode == 0b101'0111; + vid : VArith : func6 == 0b010'100, vs1 == 0b10001, func3 == 0b010, opcode == 0b101'0111; +};
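A quick sanity check on the VArith layout above, assuming the fields pack most-significant first (func6[31:26] vm[25] vs2[24:20] vs1[19:15] func3[14:12] vd[11:7] opcode[6:0]): the sketch below encodes vadd.vv v1, v2, v3 and is illustrative only, not part of this change.
#include <cstdint>
#include <cstdio>
// Packs the VArith fields into a 32-bit word, MSB field first.
uint32_t EncodeVArith(uint32_t func6, uint32_t vm, uint32_t vs2, uint32_t vs1,
                      uint32_t func3, uint32_t vd, uint32_t opcode) {
  return (func6 << 26) | (vm << 25) | (vs2 << 20) | (vs1 << 15) |
         (func3 << 12) | (vd << 7) | opcode;
}
int main() {
  // vadd.vv v1, v2, v3, unmasked (vm=1): func6=0b000'000, func3=0b000 (OPIVV).
  uint32_t word = EncodeVArith(0b000000, 1, /*vs2=*/2, /*vs1=*/3, 0b000,
                               /*vd=*/1, 0b1010111);
  std::printf("vadd.vv v1, v2, v3 -> 0x%08x\n",
              static_cast<unsigned>(word));  // prints 0x022180d7
  return 0;
}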
diff --git a/cheriot/riscv_cheriot_vector.isa b/cheriot/riscv_cheriot_vector.isa new file mode 100644 index 0000000..a4c00d8 --- /dev/null +++ b/cheriot/riscv_cheriot_vector.isa
@@ -0,0 +1,1094 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file contains the non-floating point vector instruction definitions. + +// First disasm field is 18 char wide and left justified. +disasm widths = {-18}; + +slot riscv_cheriot_vector { + includes { + #include "cheriot/riscv_cheriot_vector_memory_instructions.h" + #include "cheriot/riscv_cheriot_vector_opi_instructions.h" + #include "cheriot/riscv_cheriot_vector_opm_instructions.h" + #include "cheriot/riscv_cheriot_vector_permute_instructions.h" + #include "cheriot/riscv_cheriot_vector_reduction_instructions.h" + #include "cheriot/riscv_cheriot_vector_unary_instructions.h" + #include "absl/functional/bind_front.h" + } + default size = 4; + default latency = 0; + default opcode = + disasm: "Unimplemented instruction at 0x%(@:08x)", + semfunc: "&RV32VUnimplementedInstruction"; + opcodes { + // Configuration. + vsetvli_xn{: rs1, zimm11: rd}, + disasm: "vsetvli", "%rd, %rs1, %zimm11", + semfunc: "absl::bind_front(&Vsetvl, /*rd_zero*/ false, /*rs1_zero*/ false)"; + vsetvli_nz{: rs1, zimm11: rd}, + disasm: "vsetvli", "%rd, %rs1, %zimm11", + semfunc: "absl::bind_front(&Vsetvl, /*rd_zero*/false, /*rs1_zero*/ true)"; + vsetvli_zz{: rs1, zimm11: rd}, + disasm: "vsetvli", "%rd, %rs1, %zimm11", + semfunc: "absl::bind_front(&Vsetvl, /*rd_zero*/true, /*rs1_zero*/ true)"; + vsetivli{: uimm5, zimm10: rd}, + disasm: "vsetivli", "%rd, %uimm5, %zimm10", + semfunc: "absl::bind_front(&Vsetvl, /*rd_zero*/false, /*rs1_zero*/ false)"; + vsetvl_xn{: rs1, rs2: rd}, + disasm: "vsetvl", "%rd, %rs1, %rs2", + semfunc: "absl::bind_front(&Vsetvl, /*rd_zero*/false, /*rs1_zero*/ false)"; + vsetvl_nz{: rs1, rs2: rd}, + disasm: "vsetvl", "%rd, %rs1, %rs2", + semfunc: "absl::bind_front(&Vsetvl, /*rd_zero*/false, /*rs1_zero*/ true)"; + vsetvl_zz{: rs1, rs2: rd}, + disasm: "vsetvl", "%rd, %rs1, %rs2", + semfunc: "absl::bind_front(&Vsetvl, /*rd_zero*/true, /*rs1_zero*/ true)"; + + // VECTOR LOADS + + // Unit stride loads, masked (vm=0) + vle8{(: rs1, vmask :), (: : vd )}, + disasm: "vle8.v", "%vd, (%rs1), %vmask", + semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 1)", "&VlChild"; + vle16{(: rs1, vmask :), (: : vd )}, + disasm: "vle16.v", "%vd, (%rs1), %vmask", + semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 2)", "&VlChild"; + vle32{(: rs1, vmask :), ( : : vd) }, + disasm: "vle32.v", "%vd, (%rs1), %vmask", + semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 4)", "&VlChild"; + vle64{(: rs1, vmask :), ( : : vd) }, + disasm: "vle64.v", "%vd, (%rs1), %vmask", + semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 8)", "&VlChild"; + + // Unit stride loads, unmasked (vm=1) + vle8_vm1{(: rs1, vmask_true :), (: : vd )}, + disasm: "vle8.v", "%vd, (%rs1)", + semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 1)", "&VlChild"; + vle16_vm1{(: rs1, vmask_true :), (: : vd )}, + disasm: "vle16.v", "%vd, (%rs1)", + semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 2)", 
"&VlChild"; + vle32_vm1{(: rs1, vmask_true :), ( : : vd) }, + disasm: "vle32.v", "%vd, (%rs1)", + semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 4)", "&VlChild"; + vle64_vm1{(: rs1, vmask_true :), ( : : vd) }, + disasm: "vle64.v", "%vd, (%rs1)", + semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 8)", "&VlChild"; + + // Vector strided loads + vlse8{(: rs1, rs2, vmask :), (: : vd)}, + disasm: "vlse8.v", "%vd, (%rs1), %rs2, %vmask", + semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 1)", "&VlChild"; + vlse16{(: rs1, rs2, vmask :), (: : vd)}, + disasm: "vlse16.v", "%vd, (%rs1), %rs2, %vmask", + semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 2)", "&VlChild"; + vlse32{(: rs1, rs2, vmask :), (: : vd)}, + disasm: "vlse32.v", "%vd, (%rs1), %rs2, %vmask", + semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 4)", "&VlChild"; + vlse64{(: rs1, rs2, vmask :), (: : vd)}, + disasm: "vlse64.v", "%vd, (%rs1), %rs2, %vmask", + semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 8)", "&VlChild"; + + // Vector mask load + vlm{(: rs1 :), (: : vd)}, + disasm: "vlm.v", "%vd, (%rs1)", + semfunc: "&Vlm", "&VlChild"; + + // Unit stride vector load, fault first + vle8ff{(: rs1, vmask:), (: : vd)}, + disasm: "vle8ff.v", "%vd, (%rs1), %vmask", + semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 1)", "&VlChild"; + vle16ff{(: rs1, vmask:), (: : vd)}, + disasm: "vle16ff.v", "%vd, (%rs1), %vmask", + semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 2)", "&VlChild"; + vle32ff{(: rs1, vmask:), (: : vd)}, + disasm: "vle32ff.v", "%vd, (%rs1), %vmask", + semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 4)", "&VlChild"; + vle64ff{(: rs1, vmask:), (: : vd)}, + disasm: "vle64ff.v", "%vd, (%rs1), %vmask", + semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 8)", "&VlChild"; + + // Vector register load + vl1re8{(: rs1 :), (: : vd)}, + disasm: "vl1re8.v", "%vd, (%rs1)", + semfunc: "absl::bind_front(&VlRegister, /*num_regs*/ 1, /*element_width*/ 1)", "&VlChild"; + vl1re16{(: rs1 :), (: : vd)}, + disasm: "vl1re16.v", "%vd, (%rs1)", + semfunc: "absl::bind_front(&VlRegister, /*num_regs*/ 1, /*element_width*/ 2)", "&VlChild"; + vl1re32{(: rs1 :), (: : vd)}, + disasm: "vl1re32.v", "%vd, (%rs1)", + semfunc: "absl::bind_front(&VlRegister, /*num_regs*/ 1, /*element_width*/ 4)", "&VlChild"; + vl1re64{(: rs1 :), (: : vd)}, + disasm: "vl1re64.v", "%vd, (%rs1)", + semfunc: "absl::bind_front(&VlRegister, /*num_regs*/ 1, /*element_width*/ 8)", "&VlChild"; + vl2re8{(: rs1 :), (: : vd)}, + disasm: "vl2re8.v", "%vd, (%rs1)", + semfunc: "absl::bind_front(&VlRegister, /*num_regs*/ 2, /*element_width*/ 1)", "&VlChild"; + vl2re16{(: rs1 :), (: : vd)}, + disasm: "vl2re16.v", "%vd, (%rs1)", + semfunc: "absl::bind_front(&VlRegister, /*num_regs*/ 2, /*element_width*/ 2)", "&VlChild"; + vl2re32{(: rs1 :), (: : vd)}, + disasm: "vl2re32.v", "%vd, (%rs1)", + semfunc: "absl::bind_front(&VlRegister, /*num_regs*/ 2, /*element_width*/ 4)", "&VlChild"; + vl2re64{(: rs1 :), (: : vd)}, + disasm: "vl2re64.v", "%vd, (%rs1)", + semfunc: "absl::bind_front(&VlRegister, /*num_regs*/ 2, /*element_width*/ 8)", "&VlChild"; + vl4re8{(: rs1 :), (: : vd)}, + disasm: "vl4re8.v", "%vd, (%rs1)", + semfunc: "absl::bind_front(&VlRegister, /*num_regs*/ 4, /*element_width*/ 1)", "&VlChild"; + vl4re16{(: rs1 :), (: : vd)}, + disasm: "vl4re16.v", "%vd, (%rs1)", + semfunc: "absl::bind_front(&VlRegister, /*num_regs*/ 4, /*element_width*/ 2)", "&VlChild"; + vl4re32{(: rs1 :), (: : vd)}, + disasm: 
"vl4re32.v", "%vd, (%rs1)", + semfunc: "absl::bind_front(&VlRegister, /*num_regs*/ 4, /*element_width*/ 4)", "&VlChild"; + vl4re64{(: rs1 :), (: : vd)}, + disasm: "vl4re64.v", "%vd, (%rs1)", + semfunc: "absl::bind_front(&VlRegister, /*num_regs*/ 4, /*element_width*/ 8)", "&VlChild"; + vl8re8{(: rs1 :), (: : vd)}, + disasm: "vl8re8.v", "%vd, (%rs1)", + semfunc: "absl::bind_front(&VlRegister, /*num_regs*/ 8, /*element_width*/ 1)", "&VlChild"; + vl8re16{(: rs1 :), (: : vd)}, + disasm: "vl8re16.v", "%vd, (%rs1)", + semfunc: "absl::bind_front(&VlRegister, /*num_regs*/ 8, /*element_width*/ 2)", "&VlChild"; + vl8re32{(: rs1 :), (: : vd)}, + disasm: "vl8re32.v", "%vd, (%rs1)", + semfunc: "absl::bind_front(&VlRegister, /*num_regs*/ 8, /*element_width*/ 4)", "&VlChild"; + vl8re64{(: rs1 :), (: : vd)}, + disasm: "vl8re64.v", "%vd, (%rs1)", + semfunc: "absl::bind_front(&VlRegister, /*num_regs*/ 8, /*element_width*/ 8)", "&VlChild"; + + // Vector load, indexed, unordered. + vluxei8{(: rs1, vs2, vmask:), (: : vd)}, + disasm: "vluxei8.v", "%vd, (%rs1), %vs2, %vmask", + semfunc: "absl::bind_front(&VlIndexed, /*index_width*/ 1)", "&VlChild"; + vluxei16{(: rs1, vs2, vmask:), (: : vd)}, + disasm: "vluxei16.v", "%vd, (%rs1), %vs2, %vmask", + semfunc: "absl::bind_front(&VlIndexed, /*index_width*/ 2)", "&VlChild"; + vluxei32{(: rs1, vs2, vmask:), (: : vd)}, + disasm: "vluxei32.v", "%vd, (%rs1), %vs2, %vmask", + semfunc: "absl::bind_front(&VlIndexed, /*index_width*/ 4)", "&VlChild"; + vluxei64{(: rs1, vs2, vmask:), (: : vd)}, + disasm: "vluxei64.v", "%vd, (%rs1), %vs2, %vmask", + semfunc: "absl::bind_front(&VlIndexed, /*index_width*/ 8)", "&VlChild"; + + // Vector load, indexed, ordered. + vloxei8{(: rs1, vs2, vmask:), (: : vd)}, + disasm: "vloxei8.v", "%vd, (%rs1), %vs2, %vmask", + semfunc: "absl::bind_front(&VlIndexed, /*index_width*/ 1)", "&VlChild"; + vloxei16{(: rs1, vs2, vmask:), (: : vd)}, + disasm: "vloxei16.v", "%vd, (%rs1), %vs2, %vmask", + semfunc: "absl::bind_front(&VlIndexed, /*index_width*/ 2)", "&VlChild"; + vloxei32{(: rs1, vs2, vmask:), (: : vd)}, + disasm: "vloxei32.v", "%vd, (%rs1), %vs2, %vmask", + semfunc: "absl::bind_front(&VlIndexed, /*index_width*/ 4)", "&VlChild"; + vloxei64{(: rs1, vs2, vmask:), (: : vd)}, + disasm: "vloxei64.v", "%vd, (%rs1), %vs2, %vmask", + semfunc: "absl::bind_front(&VlIndexed, /*index_width*/ 8)", "&VlChild"; + + // Vector unit-stride segment load + vlsege8{(: rs1, vmask, nf:), (: nf : vd)}, + disasm: "vlseg%nf\\e.v", "%vd, (%rs1), %vmask", + semfunc: "absl::bind_front(&VlSegment, /*element_width*/ 1)", + "absl::bind_front(&VlSegmentChild, /*element_width*/ 1)"; + vlsege16{(: rs1, vmask, nf:), (: nf : vd)}, + disasm: "vlseg%nf\\e.v", "%vd, (%rs1), %vmask", + semfunc: "absl::bind_front(&VlSegment, /*element_width*/ 2)", + "absl::bind_front(&VlSegmentChild, /*element_width*/ 2)"; + vlsege32{(: rs1, vmask, nf:), (: nf : vd)}, + disasm: "vlseg%nf\\e.v", "%vd, (%rs1), %vmask", + semfunc: "absl::bind_front(&VlSegment, /*element_width*/ 4)", + "absl::bind_front(&VlSegmentChild, /*element_width*/ 4)"; + vlsege64{(: rs1, vmask, nf:), (: nf : vd)}, + disasm: "vlseg%nf\\e.v", "%vd, (%rs1), %vmask", + semfunc: "absl::bind_front(&VlSegment, /*element_width*/ 8)", + "absl::bind_front(&VlSegmentChild, /*element_width*/ 8)"; + + // Vector strided segment load. 
+ vlssege8{(: rs1, rs2, vmask, nf: ), (: nf : vd)}, + disasm: "vlssg%nf\\e8.v", "%vd, (%rs1), %rs2, %vmask", + semfunc: "absl::bind_front(&VlSegmentStrided, /*element_width*/ 1)", + "absl::bind_front(&VlSegmentChild, /*element_width*/ 1)"; + vlssege16{(: rs1, rs2, vmask, nf: ), (: nf : vd)},/*element_width*/ + disasm: "vlssg%nf\\e16.v", "%vd, (%rs1), %rs2, %vmask", + semfunc: "absl::bind_front(&VlSegmentStrided, /*element_width*/ 2)", + "absl::bind_front(&VlSegmentChild, /*element_width*/ 2)"; + vlssege32{(: rs1, rs2, vmask, nf: ), (: nf : vd)}, + disasm: "vlssg%nf\\e32.v", "%vd, (%rs1), %rs2, %vmask", + semfunc: "absl::bind_front(&VlSegmentStrided, /*element_width*/ 4)", + "absl::bind_front(&VlSegmentChild, /*element_width*/ 4)"; + vlssege64{(: rs1, rs2, vmask, nf: ), (: nf : vd)}, + disasm: "vlssg%nf\\e64.v", "%vd, (%rs1), %rs2, %vmask", + semfunc: "absl::bind_front(&VlSegmentStrided, /*element_width*/ 8)", + "absl::bind_front(&VlSegmentChild, /*element_width*/ 8)"; + + // Vector indexed segment load unordered. + vluxsegei8{(: rs1, vs2, vmask, nf :), (: nf : vd)}, + disasm: "vluxseg%nf\\ei1.v", "%vd, (%rs1), %vs2, %vmask", + semfunc: "absl::bind_front(&VlSegmentIndexed, /*index_width*/ 1)", + "absl::bind_front(&VlSegmentChild, /*element_width*/ 1)"; + vluxsegei16{(: rs1, vs2, vmask, nf :), (: nf : vd)}, + disasm: "vluxseg%nf\\ei2.v", "%vd, (%rs1), %vs2, %vmask", + semfunc: "absl::bind_front(&VlSegmentIndexed, /*index_width*/ 2)", + "absl::bind_front(&VlSegmentChild, /*element_width*/ 2)"; + vluxsegei32{(: rs1, vs2, vmask, nf :), (: nf : vd)}, + disasm: "vluxseg%nf\\ei4.v", "%vd, (%rs1), %vs2, %vmask", + semfunc: "absl::bind_front(&VlSegmentIndexed, /*index_width*/ 4)", + "absl::bind_front(&VlSegmentChild, /*element_width*/ 4)"; + vluxsegei64{(: rs1, vs2, vmask, nf :), (: nf : vd)}, + disasm: "vluxseg%nf\\ei8.v", "%vd, (%rs1), %vs2, %vmask", + semfunc: "absl::bind_front(&VlSegmentIndexed, /*index_width*/ 8)", + "absl::bind_front(&VlSegmentChild, /*element_width*/ 8)"; + + // Vector indexed segment load ordered. + + vloxsegei8{(: rs1, vs2, vmask, nf :), (: nf : vd)}, + disasm: "vluxseg%nf\\ei1.v", "%vd, (%rs1), %vs2, %vmask", + semfunc: "absl::bind_front(&VlSegmentIndexed, /*index_width*/ 1)", + "absl::bind_front(&VlSegmentChild, /*element_width*/ 1)"; + vloxsegei16{(: rs1, vs2, vmask, nf :), (: nf : vd)}, + disasm: "vluxseg%nf\\ei2.v", "%vd, (%rs1), %vs2, %vmask", + semfunc: "absl::bind_front(&VlSegmentIndexed, /*index_width*/ 2)", + "absl::bind_front(&VlSegmentChild, /*element_width*/ 2)"; + vloxsegei32{(: rs1, vs2, vmask, nf :), (: nf : vd)}, + disasm: "vluxseg%nf\\ei4.v", "%vd, (%rs1), %vs2, %vmask", + semfunc: "absl::bind_front(&VlSegmentIndexed, /*index_width*/ 4)", + "absl::bind_front(&VlSegmentChild, /*element_width*/ 4)"; + vloxsegei64{(: rs1, vs2, vmask, nf :), (: nf : vd)}, + disasm: "vluxseg%nf\\ei8.v", "%vd, (%rs1), %vs2, %vmask", + semfunc: "absl::bind_front(&VlSegmentIndexed, /*index_width*/ 8)", + "absl::bind_front(&VlSegmentChild, /*element_width*/ 8)"; + + // VECTOR STORES + + // Vector store, unit stride. 
+    vse8{: vs3, rs1, const1, vmask : },
+      disasm: "vse8.v", "%vs3, (%rs1), %vmask",
+      semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 1)";
+    vse16{: vs3, rs1, const1, vmask : },
+      disasm: "vse16.v", "%vs3, (%rs1), %vmask",
+      semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 2)";
+    vse32{: vs3, rs1, const1, vmask : },
+      disasm: "vse32.v", "%vs3, (%rs1), %vmask",
+      semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 4)";
+    vse64{: vs3, rs1, const1, vmask : },
+      disasm: "vse64.v", "%vs3, (%rs1), %vmask",
+      semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 8)";
+
+    // Vector store mask
+    vsm{: vs3, rs1, const1, vmask_true:},
+      disasm: "vsm.v", "%vs3, (%rs1)",
+      semfunc: "absl::bind_front(&Vsm)";
+
+    // Vector store, unit stride, fault first.
+    vse8ff{: vs3, rs1, const1, vmask:},
+      disasm: "vse8ff.v", "%vs3, (%rs1), %vmask",
+      semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 1)";
+    vse16ff{: vs3, rs1, const1, vmask:},
+      disasm: "vse16ff.v", "%vs3, (%rs1), %vmask",
+      semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 2)";
+    vse32ff{: vs3, rs1, const1, vmask:},
+      disasm: "vse32ff.v", "%vs3, (%rs1), %vmask",
+      semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 4)";
+    vse64ff{: vs3, rs1, const1, vmask:},
+      disasm: "vse64ff.v", "%vs3, (%rs1), %vmask",
+      semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 8)";
+
+    // Vector store register.
+    vs1re8{(: vs3, rs1 :)},
+      disasm: "vs1re8.v", "%vs3, (%rs1)",
+      semfunc: "absl::bind_front(&VsRegister, /*num_regs*/ 1)";
+    vs1re16{(: vs3, rs1 :)},
+      disasm: "vs1re16.v", "%vs3, (%rs1)",
+      semfunc: "absl::bind_front(&VsRegister, /*num_regs*/ 1)";
+    vs1re32{(: vs3, rs1 :)},
+      disasm: "vs1re32.v", "%vs3, (%rs1)",
+      semfunc: "absl::bind_front(&VsRegister, /*num_regs*/ 1)";
+    vs1re64{(: vs3, rs1 :)},
+      disasm: "vs1re64.v", "%vs3, (%rs1)",
+      semfunc: "absl::bind_front(&VsRegister, /*num_regs*/ 1)";
+    vs2re8{(: vs3, rs1 :)},
+      disasm: "vs2re8.v", "%vs3, (%rs1)",
+      semfunc: "absl::bind_front(&VsRegister, /*num_regs*/ 2)";
+    vs2re16{(: vs3, rs1 :)},
+      disasm: "vs2re16.v", "%vs3, (%rs1)",
+      semfunc: "absl::bind_front(&VsRegister, /*num_regs*/ 2)";
+    vs2re32{(: vs3, rs1 :)},
+      disasm: "vs2re32.v", "%vs3, (%rs1)",
+      semfunc: "absl::bind_front(&VsRegister, /*num_regs*/ 2)";
+    vs2re64{(: vs3, rs1 :)},
+      disasm: "vs2re64.v", "%vs3, (%rs1)",
+      semfunc: "absl::bind_front(&VsRegister, /*num_regs*/ 2)";
+    vs4re8{(: vs3, rs1 :)},
+      disasm: "vs4re8.v", "%vs3, (%rs1)",
+      semfunc: "absl::bind_front(&VsRegister, /*num_regs*/ 4)";
+    vs4re16{(: vs3, rs1 :)},
+      disasm: "vs4re16.v", "%vs3, (%rs1)",
+      semfunc: "absl::bind_front(&VsRegister, /*num_regs*/ 4)";
+    vs4re32{(: vs3, rs1 :)},
+      disasm: "vs4re32.v", "%vs3, (%rs1)",
+      semfunc: "absl::bind_front(&VsRegister, /*num_regs*/ 4)";
+    vs4re64{(: vs3, rs1 :)},
+      disasm: "vs4re64.v", "%vs3, (%rs1)",
+      semfunc: "absl::bind_front(&VsRegister, /*num_regs*/ 4)";
+    vs8re8{(: vs3, rs1 :)},
+      disasm: "vs8re8.v", "%vs3, (%rs1)",
+      semfunc: "absl::bind_front(&VsRegister, /*num_regs*/ 8)";
+    vs8re16{(: vs3, rs1 :)},
+      disasm: "vs8re16.v", "%vs3, (%rs1)",
+      semfunc: "absl::bind_front(&VsRegister, /*num_regs*/ 8)";
+    vs8re32{(: vs3, rs1 :)},
+      disasm: "vs8re32.v", "%vs3, (%rs1)",
+      semfunc: "absl::bind_front(&VsRegister, /*num_regs*/ 8)";
+    vs8re64{(: vs3, rs1 :)},
+      disasm: "vs8re64.v", "%vs3, (%rs1)",
+      semfunc: "absl::bind_front(&VsRegister, /*num_regs*/ 8)";
+
+    // Vector store, strided.
+    vsse8{: vs3, rs1, rs2, vmask : },
+      disasm: "vsse8.v", "%vs3, (%rs1), %rs2, %vmask",
+      semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 1)";
+    vsse16{: vs3, rs1, rs2, vmask : },
+      disasm: "vsse16.v", "%vs3, (%rs1), %rs2, %vmask",
+      semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 2)";
+    vsse32{: vs3, rs1, rs2, vmask : },
+      disasm: "vsse32.v", "%vs3, (%rs1), %rs2, %vmask",
+      semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 4)";
+    vsse64{: vs3, rs1, rs2, vmask : },
+      disasm: "vsse64.v", "%vs3, (%rs1), %rs2, %vmask",
+      semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 8)";
+
+    // Vector store, indexed, unordered.
+    vsuxei8{: vs3, rs1, vs2, vmask : },
+      disasm: "vsuxei8.v", "%vs3, (%rs1), %vs2, %vmask",
+      semfunc: "absl::bind_front(&VsIndexed, /*index_width*/ 1)";
+    vsuxei16{: vs3, rs1, vs2, vmask : },
+      disasm: "vsuxei16.v", "%vs3, (%rs1), %vs2, %vmask",
+      semfunc: "absl::bind_front(&VsIndexed, /*index_width*/ 2)";
+    vsuxei32{: vs3, rs1, vs2, vmask : },
+      disasm: "vsuxei32.v", "%vs3, (%rs1), %vs2, %vmask",
+      semfunc: "absl::bind_front(&VsIndexed, /*index_width*/ 4)";
+    vsuxei64{: vs3, rs1, vs2, vmask : },
+      disasm: "vsuxei64.v", "%vs3, (%rs1), %vs2, %vmask",
+      semfunc: "absl::bind_front(&VsIndexed, /*index_width*/ 8)";
+
+    // Vector store, indexed, ordered.
+    vsoxei8{: vs3, rs1, vs2, vmask : },
+      disasm: "vsoxei8.v", "%vs3, (%rs1), %vs2, %vmask",
+      semfunc: "absl::bind_front(&VsIndexed, /*index_width*/ 1)";
+    vsoxei16{: vs3, rs1, vs2, vmask : },
+      disasm: "vsoxei16.v", "%vs3, (%rs1), %vs2, %vmask",
+      semfunc: "absl::bind_front(&VsIndexed, /*index_width*/ 2)";
+    vsoxei32{: vs3, rs1, vs2, vmask : },
+      disasm: "vsoxei32.v", "%vs3, (%rs1), %vs2, %vmask",
+      semfunc: "absl::bind_front(&VsIndexed, /*index_width*/ 4)";
+    vsoxei64{: vs3, rs1, vs2, vmask : },
+      disasm: "vsoxei64.v", "%vs3, (%rs1), %vs2, %vmask",
+      semfunc: "absl::bind_front(&VsIndexed, /*index_width*/ 8)";
+
+    // Vector unit-stride segment store.
+    vssege8{(: vs3, rs1, vmask, nf:)},
+      disasm: "vsseg%nf\\e8.v", "%vs3, (%rs1), %vmask",
+      semfunc: "absl::bind_front(&VsSegment, /*element_width*/ 1)";
+    vssege16{(: vs3, rs1, vmask, nf:)},
+      disasm: "vsseg%nf\\e16.v", "%vs3, (%rs1), %vmask",
+      semfunc: "absl::bind_front(&VsSegment, /*element_width*/ 2)";
+    vssege32{(: vs3, rs1, vmask, nf:)},
+      disasm: "vsseg%nf\\e32.v", "%vs3, (%rs1), %vmask",
+      semfunc: "absl::bind_front(&VsSegment, /*element_width*/ 4)";
+    vssege64{(: vs3, rs1, vmask, nf:)},
+      disasm: "vsseg%nf\\e64.v", "%vs3, (%rs1), %vmask",
+      semfunc: "absl::bind_front(&VsSegment, /*element_width*/ 8)";
+
+    // Vector strided segment store.
+    vsssege8{(: vs3, rs1, rs2, vmask, nf: )},
+      disasm: "vssseg%nf\\e8.v", "%vs3, (%rs1), %rs2, %vmask",
+      semfunc: "absl::bind_front(&VsSegmentStrided, /*element_width*/ 1)";
+    vsssege16{(: vs3, rs1, rs2, vmask, nf: )},
+      disasm: "vssseg%nf\\e16.v", "%vs3, (%rs1), %rs2, %vmask",
+      semfunc: "absl::bind_front(&VsSegmentStrided, /*element_width*/ 2)";
+    vsssege32{(: vs3, rs1, rs2, vmask, nf: )},
+      disasm: "vssseg%nf\\e32.v", "%vs3, (%rs1), %rs2, %vmask",
+      semfunc: "absl::bind_front(&VsSegmentStrided, /*element_width*/ 4)";
+    vsssege64{(: vs3, rs1, rs2, vmask, nf: )},
+      disasm: "vssseg%nf\\e64.v", "%vs3, (%rs1), %rs2, %vmask",
+      semfunc: "absl::bind_front(&VsSegmentStrided, /*element_width*/ 8)";
+
+    // Vector indexed segment store unordered.
+    vsuxsegei8{(: vs3, rs1, vs2, vmask, nf :)},
+      disasm: "vsuxseg%nf\\ei8.v", "%vs3, (%rs1), %vs2, %vmask",
+      semfunc: "absl::bind_front(&VsSegmentIndexed, /*index_width*/ 1)";
+    vsuxsegei16{(: vs3, rs1, vs2, vmask, nf :)},
+      disasm: "vsuxseg%nf\\ei16.v", "%vs3, (%rs1), %vs2, %vmask",
+      semfunc: "absl::bind_front(&VsSegmentIndexed, /*index_width*/ 2)";
+    vsuxsegei32{(: vs3, rs1, vs2, vmask, nf :)},
+      disasm: "vsuxseg%nf\\ei32.v", "%vs3, (%rs1), %vs2, %vmask",
+      semfunc: "absl::bind_front(&VsSegmentIndexed, /*index_width*/ 4)";
+    vsuxsegei64{(: vs3, rs1, vs2, vmask, nf :)},
+      disasm: "vsuxseg%nf\\ei64.v", "%vs3, (%rs1), %vs2, %vmask",
+      semfunc: "absl::bind_front(&VsSegmentIndexed, /*index_width*/ 8)";
+
+    // Vector indexed segment store ordered.
+    vsoxsegei8{(: vs3, rs1, vs2, vmask, nf :)},
+      disasm: "vsoxseg%nf\\ei8.v", "%vs3, (%rs1), %vs2, %vmask",
+      semfunc: "absl::bind_front(&VsSegmentIndexed, /*index_width*/ 1)";
+    vsoxsegei16{(: vs3, rs1, vs2, vmask, nf :)},
+      disasm: "vsoxseg%nf\\ei16.v", "%vs3, (%rs1), %vs2, %vmask",
+      semfunc: "absl::bind_front(&VsSegmentIndexed, /*index_width*/ 2)";
+    vsoxsegei32{(: vs3, rs1, vs2, vmask, nf :)},
+      disasm: "vsoxseg%nf\\ei32.v", "%vs3, (%rs1), %vs2, %vmask",
+      semfunc: "absl::bind_front(&VsSegmentIndexed, /*index_width*/ 4)";
+    vsoxsegei64{(: vs3, rs1, vs2, vmask, nf :)},
+      disasm: "vsoxseg%nf\\ei64.v", "%vs3, (%rs1), %vs2, %vmask",
+      semfunc: "absl::bind_front(&VsSegmentIndexed, /*index_width*/ 8)";
+
+    // Integer OPIVV, OPIVX, OPIVI.
+    vadd_vv{: vs2, vs1, vmask : vd},
+      disasm: "vadd.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vadd";
+    vadd_vx{: vs2, rs1, vmask : vd},
+      disasm: "vadd.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vadd";
+    vadd_vi{: vs2, simm5, vmask : vd},
+      disasm: "vadd.vi", "%vd, %vs2, %simm5, %vmask",
+      semfunc: "&Vadd";
+    vsub_vv{: vs2, vs1, vmask : vd},
+      disasm: "vsub.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vsub";
+    vsub_vx{: vs2, rs1, vmask : vd},
+      disasm: "vsub.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vsub";
+    vrsub_vx{: vs2, rs1, vmask : vd},
+      disasm: "vrsub.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vrsub";
+    vrsub_vi{: vs2, simm5, vmask : vd},
+      disasm: "vrsub.vi", "%vd, %vs2, %simm5, %vmask",
+      semfunc: "&Vrsub";
+    vminu_vv{: vs2, vs1, vmask : vd},
+      disasm: "vminu.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vminu";
+    vminu_vx{: vs2, rs1, vmask : vd},
+      disasm: "vminu.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vminu";
+    vmin_vv{: vs2, vs1, vmask : vd},
+      disasm: "vmin.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vmin";
+    vmin_vx{: vs2, rs1, vmask : vd},
+      disasm: "vmin.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vmin";
+    vmaxu_vv{: vs2, vs1, vmask : vd},
+      disasm: "vmaxu.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vmaxu";
+    vmaxu_vx{: vs2, rs1, vmask : vd},
+      disasm: "vmaxu.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vmaxu";
+    vmax_vv{: vs2, vs1, vmask : vd},
+      disasm: "vmax.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vmax";
+    vmax_vx{: vs2, rs1, vmask : vd},
+      disasm: "vmax.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vmax";
+    vand_vv{: vs2, vs1, vmask : vd},
+      disasm: "vand.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vand";
+    vand_vx{: vs2, rs1, vmask : vd},
+      disasm: "vand.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vand";
+    vand_vi{: vs2, simm5, vmask : vd},
+      disasm: "vand.vi", "%vd, %vs2, %simm5, %vmask",
+      semfunc: "&Vand";
+    vor_vv{: vs2, vs1, vmask : vd},
+      disasm: "vor.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vor";
+    vor_vx{: vs2, rs1, vmask : vd},
+      disasm: "vor.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vor";
+    vor_vi{: vs2, simm5, vmask : vd},
+      disasm: "vor.vi", "%vd, %vs2, %simm5, %vmask",
+      semfunc: "&Vor";
+    vxor_vv{: vs2, vs1, vmask : vd},
+      disasm: "vxor.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vxor";
+    vxor_vx{: vs2, rs1, vmask : vd},
+      disasm: "vxor.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vxor";
+    vxor_vi{: vs2, simm5, vmask : vd},
+      disasm: "vxor.vi", "%vd, %vs2, %simm5, %vmask",
+      semfunc: "&Vxor";
+    vrgather_vv{: vs2, vs1, vmask : vd},
+      disasm: "vrgather.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vrgather";
+    vrgather_vx{: vs2, rs1, vmask : vd},
+      disasm: "vrgather.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vrgather";
+    vrgather_vi{: vs2, uimm5, vmask : vd},
+      disasm: "vrgather.vi", "%vd, %vs2, %uimm5, %vmask",
+      semfunc: "&Vrgather";
+    vrgatherei16_vv{: vs2, vs1, vmask : vd},
+      disasm: "vrgatherei16.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vrgatherei16";
+    vslideup_vx{: vs2, rs1, vmask : vd},
+      disasm: "vslideup.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vslideup";
+    vslideup_vi{: vs2, uimm5, vmask : vd},
+      disasm: "vslideup.vi", "%vd, %vs2, %uimm5, %vmask",
+      semfunc: "&Vslideup";
+    vslidedown_vx{: vs2, rs1, vmask : vd},
+      disasm: "vslidedown.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vslidedown";
+    vslidedown_vi{: vs2, uimm5, vmask : vd},
+      disasm: "vslidedown.vi", "%vd, %vs2, %uimm5, %vmask",
+      semfunc: "&Vslidedown";
+    vadc_vv{: vs2, vs1, vmask : vd},
+      disasm: "vadc.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vadc";
+    vadc_vx{: vs2, rs1, vmask : vd},
+      disasm: "vadc.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vadc";
+    vadc_vi{: vs2, simm5, vmask : vd},
+      disasm: "vadc.vi", "%vd, %vs2, %simm5, %vmask",
+      semfunc: "&Vadc";
+    vmadc_vv{: vs2, vs1, vmask, vm : vd},
+      disasm: "vmadc.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vmadc";
+    vmadc_vx{: vs2, rs1, vmask, vm : vd},
+      disasm: "vmadc.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vmadc";
+    vmadc_vi{: vs2, simm5, vmask, vm : vd},
+      disasm: "vmadc.vi", "%vd, %vs2, %simm5, %vmask",
+      semfunc: "&Vmadc";
+    vsbc_vv{: vs2, vs1, vmask : vd},
+      disasm: "vsbc.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vsbc";
+    vsbc_vx{: vs2, rs1, vmask : vd},
+      disasm: "vsbc.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vsbc";
+    vmsbc_vv{: vs2, vs1, vmask, vm : vd},
+      disasm: "vmsbc.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vmsbc";
+    vmsbc_vx{: vs2, rs1, vmask, vm : vd},
+      disasm: "vmsbc.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vmsbc";
+    vmerge_vv{: vs2, vs1, vmask : vd},
+      disasm: "vmerge.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vmerge";
+    vmerge_vx{: vs2, rs1, vmask : vd},
+      disasm: "vmerge.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vmerge";
+    vmerge_vi{: vs2, simm5, vmask : vd},
+      disasm: "vmerge.vi", "%vd, %vs2, %simm5, %vmask",
+      semfunc: "&Vmerge";
+    vmv_vv{: vs2, vs1, vmask_true : vd},
+      disasm: "vmv.v.v", "%vd, %vs1",
+      semfunc: "&Vmerge";
+    vmv_vx{: vs2, rs1, vmask_true : vd},
+      disasm: "vmv.v.x", "%vd, %rs1",
+      semfunc: "&Vmerge";
+    vmv_vi{: vs2, simm5, vmask_true : vd},
+      disasm: "vmv.v.i", "%vd, %simm5",
+      semfunc: "&Vmerge";
+    vmseq_vv{: vs2, vs1, vmask : vd},
+      disasm: "vmseq.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vmseq";
+    vmseq_vx{: vs2, rs1, vmask : vd},
+      disasm: "vmseq.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vmseq";
+    vmseq_vi{: vs2, simm5, vmask : vd},
+      disasm: "vmseq.vi", "%vd, %vs2, %simm5, %vmask",
+      semfunc: "&Vmseq";
+    vmsne_vv{: vs2, vs1, vmask : vd},
+      disasm: "vmsne.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vmsne";
+    vmsne_vx{: vs2, rs1, vmask : vd},
+      disasm: "vmsne.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vmsne";
+    vmsne_vi{: vs2, simm5, vmask : vd},
+      disasm: "vmsne.vi", "%vd, %vs2, %simm5, %vmask",
+      semfunc: "&Vmsne";
+    vmsltu_vv{: vs2, vs1, vmask : vd},
+      disasm: "vmsltu.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vmsltu";
+    vmsltu_vx{: vs2, rs1, vmask : vd},
+      disasm: "vmsltu.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vmsltu";
+    vmslt_vv{: vs2, vs1, vmask : vd},
+      disasm: "vmslt.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vmslt";
+    vmslt_vx{: vs2, rs1, vmask : vd},
+      disasm: "vmslt.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vmslt";
+    vmsleu_vv{: vs2, vs1, vmask : vd},
+      disasm: "vmsleu.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vmsleu";
+    vmsleu_vx{: vs2, rs1, vmask : vd},
+      disasm: "vmsleu.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vmsleu";
+    vmsleu_vi{: vs2, simm5, vmask : vd},
+      disasm: "vmsleu.vi", "%vd, %vs2, %simm5, %vmask",
+      semfunc: "&Vmsleu";
+    vmsle_vv{: vs2, vs1, vmask : vd},
+      disasm: "vmsle.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vmsle";
+    vmsle_vx{: vs2, rs1, vmask : vd},
+      disasm: "vmsle.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vmsle";
+    vmsle_vi{: vs2, simm5, vmask : vd},
+      disasm: "vmsle.vi", "%vd, %vs2, %simm5, %vmask",
+      semfunc: "&Vmsle";
+    vmsgtu_vx{: vs2, rs1, vmask : vd},
+      disasm: "vmsgtu.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vmsgtu";
+    vmsgtu_vi{: vs2, simm5, vmask : vd},
+      disasm: "vmsgtu.vi", "%vd, %vs2, %simm5, %vmask",
+      semfunc: "&Vmsgtu";
+    vmsgt_vx{: vs2, rs1, vmask : vd},
+      disasm: "vmsgt.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vmsgt";
+    vmsgt_vi{: vs2, simm5, vmask : vd},
+      disasm: "vmsgt.vi", "%vd, %vs2, %simm5, %vmask",
+      semfunc: "&Vmsgt";
+    vsaddu_vv{: vs2, vs1, vmask : vd},
+      disasm: "vsaddu.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vsaddu";
+    vsaddu_vx{: vs2, rs1, vmask : vd},
+      disasm: "vsaddu.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vsaddu";
+    vsaddu_vi{: vs2, simm5, vmask : vd},
+      disasm: "vsaddu.vi", "%vd, %vs2, %simm5, %vmask",
+      semfunc: "&Vsaddu";
+    vsadd_vv{: vs2, vs1, vmask : vd},
+      disasm: "vsadd.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vsadd";
+    vsadd_vx{: vs2, rs1, vmask : vd},
+      disasm: "vsadd.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vsadd";
+    vsadd_vi{: vs2, simm5, vmask : vd},
+      disasm: "vsadd.vi", "%vd, %vs2, %simm5, %vmask",
+      semfunc: "&Vsadd";
+    vssubu_vv{: vs2, vs1, vmask : vd},
+      disasm: "vssubu.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vssubu";
+    vssubu_vx{: vs2, rs1, vmask : vd},
+      disasm: "vssubu.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vssubu";
+    vssub_vv{: vs2, vs1, vmask : vd},
+      disasm: "vssub.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vssub";
+    vssub_vx{: vs2, rs1, vmask : vd},
+      disasm: "vssub.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vssub";
+    vsll_vv{: vs2, vs1, vmask : vd},
+      disasm: "vsll.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vsll";
+    vsll_vx{: vs2, rs1, vmask : vd},
+      disasm: "vsll.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vsll";
+    vsll_vi{: vs2, simm5, vmask : vd},
+      disasm: "vsll.vi", "%vd, %vs2, %simm5, %vmask",
+      semfunc: "&Vsll";
+    vsmul_vv{: vs2, vs1, vmask : vd},
+      disasm: "vsmul.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vsmul";
+    vsmul_vx{: vs2, rs1, vmask : vd},
+      disasm: "vsmul.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vsmul";
+    vmv1r_vi{: vs2 : vd},
+      disasm: "vmv1r.v", "%vd, %vs2",
+      semfunc: "absl::bind_front(&Vmvr, 1)";
+    vmv2r_vi{: vs2 : vd},
+      disasm: "vmv2r.v", "%vd, %vs2",
+      semfunc: "absl::bind_front(&Vmvr, 2)";
+    vmv4r_vi{: vs2 : vd},
+      disasm: "vmv4r.v", "%vd, %vs2",
+      semfunc: "absl::bind_front(&Vmvr, 4)";
+    vmv8r_vi{: vs2 : vd},
+      disasm: "vmv8r.v", "%vd, %vs2",
+      semfunc: "absl::bind_front(&Vmvr, 8)";
+    vsrl_vv{: vs2, vs1, vmask : vd},
+      disasm: "vsrl.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vsrl";
+    vsrl_vx{: vs2, rs1, vmask : vd},
+      disasm: "vsrl.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vsrl";
+    vsrl_vi{: vs2, simm5, vmask : vd},
+      disasm: "vsrl.vi", "%vd, %vs2, %simm5, %vmask",
+      semfunc: "&Vsrl";
+    vsra_vv{: vs2, vs1, vmask : vd},
+      disasm: "vsra.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vsra";
+    vsra_vx{: vs2, rs1, vmask : vd},
+      disasm: "vsra.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vsra";
+    vsra_vi{: vs2, simm5, vmask : vd},
+      disasm: "vsra.vi", "%vd, %vs2, %simm5, %vmask",
+      semfunc: "&Vsra";
+    vssrl_vv{: vs2, vs1, vmask : vd},
+      disasm: "vssrl.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vssrl";
+    vssrl_vx{: vs2, rs1, vmask : vd},
+      disasm: "vssrl.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vssrl";
+    vssrl_vi{: vs2, uimm5, vmask : vd},
+      disasm: "vssrl.vi", "%vd, %vs2, %uimm5, %vmask",
+      semfunc: "&Vssrl";
+    vssra_vv{: vs2, vs1, vmask : vd},
+      disasm: "vssra.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vssra";
+    vssra_vx{: vs2, rs1, vmask : vd},
+      disasm: "vssra.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vssra";
+    vssra_vi{: vs2, uimm5, vmask : vd},
+      disasm: "vssra.vi", "%vd, %vs2, %uimm5, %vmask",
+      semfunc: "&Vssra";
+    vnsrl_vv{: vs2, vs1, vmask : vd},
+      disasm: "vnsrl.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vnsrl";
+    vnsrl_vx{: vs2, rs1, vmask : vd},
+      disasm: "vnsrl.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vnsrl";
+    vnsrl_vi{: vs2, uimm5, vmask : vd},
+      disasm: "vnsrl.vi", "%vd, %vs2, %uimm5, %vmask",
+      semfunc: "&Vnsrl";
+    vnsra_vv{: vs2, vs1, vmask : vd},
+      disasm: "vnsra.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vnsra";
+    vnsra_vx{: vs2, rs1, vmask : vd},
+      disasm: "vnsra.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vnsra";
+    vnsra_vi{: vs2, uimm5, vmask : vd},
+      disasm: "vnsra.vi", "%vd, %vs2, %uimm5, %vmask",
+      semfunc: "&Vnsra";
+    vnclipu_vv{: vs2, vs1, vmask : vd},
+      disasm: "vnclipu.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vnclipu";
+    vnclipu_vx{: vs2, rs1, vmask : vd},
+      disasm: "vnclipu.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vnclipu";
+    vnclipu_vi{: vs2, uimm5, vmask : vd},
+      disasm: "vnclipu.vi", "%vd, %vs2, %uimm5, %vmask",
+      semfunc: "&Vnclipu";
+    vnclip_vv{: vs2, vs1, vmask : vd},
+      disasm: "vnclip.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vnclip";
+    vnclip_vx{: vs2, rs1, vmask : vd},
+      disasm: "vnclip.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vnclip";
+    vnclip_vi{: vs2, uimm5, vmask : vd},
+      disasm: "vnclip.vi", "%vd, %vs2, %uimm5, %vmask",
+      semfunc: "&Vnclip";
+    vwredsumu_vv{: vs2, vs1, vmask : vd},
+      disasm: "vwredsumu.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vwredsumu";
+    vwredsum_vv{: vs2, vs1, vmask : vd},
+      disasm: "vwredsum.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vwredsum";
+
+    // Integer OPMVV, OPMVX.
+    vredsum_vv{: vs2, vs1, vmask : vd},
+      disasm: "vredsum.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vredsum";
+    vredand_vv{: vs2, vs1, vmask : vd},
+      disasm: "vredand.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vredand";
+    vredor_vv{: vs2, vs1, vmask : vd},
+      disasm: "vredor.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vredor";
+    vredxor_vv{: vs2, vs1, vmask : vd},
+      disasm: "vredxor.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vredxor";
+    vredminu_vv{: vs2, vs1, vmask : vd},
+      disasm: "vredminu.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vredminu";
+    vredmin_vv{: vs2, vs1, vmask : vd},
+      disasm: "vredmin.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vredmin";
+    vredmaxu_vv{: vs2, vs1, vmask : vd},
+      disasm: "vredmaxu.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vredmaxu";
+    vredmax_vv{: vs2, vs1, vmask : vd},
+      disasm: "vredmax.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vredmax";
+    vaaddu_vv{: vs2, vs1, vmask : vd},
+      disasm: "vaaddu.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vaaddu";
+    vaaddu_vx{: vs2, rs1, vmask : vd},
+      disasm: "vaaddu.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vaaddu";
+    vaadd_vv{: vs2, vs1, vmask : vd},
+      disasm: "vaadd.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vaadd";
+    vaadd_vx{: vs2, rs1, vmask : vd},
+      disasm: "vaadd.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vaadd";
+    vasubu_vv{: vs2, vs1, vmask : vd},
+      disasm: "vasubu.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vasubu";
+    vasubu_vx{: vs2, rs1, vmask : vd},
+      disasm: "vasubu.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vasubu";
+    vasub_vv{: vs2, vs1, vmask : vd},
+      disasm: "vasub.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vasub";
+    vasub_vx{: vs2, rs1, vmask : vd},
+      disasm: "vasub.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vasub";
+    vslide1up_vx{: vs2, rs1, vmask : vd},
+      disasm: "vslide1up.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vslide1up";
+    vslide1down_vx{: vs2, rs1, vmask : vd},
+      disasm: "vslide1down.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vslide1down";
+    vcompress_vv{: vs2, vs1 : vd},
+      disasm: "vcompress.vv", "%vd, %vs2, %vs1",
+      semfunc: "&Vcompress";
+    vmandnot_vv{: vs2, vs1 : vd},
+      disasm: "vmandnot.vv", "%vd, %vs2, %vs1",
+      semfunc: "&Vmandnot";
+    vmand_vv{: vs2, vs1 : vd},
+      disasm: "vmand.vv", "%vd, %vs2, %vs1",
+      semfunc: "&Vmand";
+    vmor_vv{: vs2, vs1 : vd},
+      disasm: "vmor.vv", "%vd, %vs2, %vs1",
+      semfunc: "&Vmor";
+    vmxor_vv{: vs2, vs1 : vd},
+      disasm: "vmxor.vv", "%vd, %vs2, %vs1",
+      semfunc: "&Vmxor";
+    vmornot_vv{: vs2, vs1 : vd},
+      disasm: "vmornot.vv", "%vd, %vs2, %vs1",
+      semfunc: "&Vmornot";
+    vmnand_vv{: vs2, vs1 : vd},
+      disasm: "vmnand.vv", "%vd, %vs2, %vs1",
+      semfunc: "&Vmnand";
+    vmnor_vv{: vs2, vs1 : vd},
+      disasm: "vmnor.vv", "%vd, %vs2, %vs1",
+      semfunc: "&Vmnor";
+    vmxnor_vv{: vs2, vs1 : vd},
+      disasm: "vmxnor.vv", "%vd, %vs2, %vs1",
+      semfunc: "&Vmxnor";
+    vdivu_vv{: vs2, vs1, vmask : vd},
+      disasm: "vdivu.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vdivu";
+    vdivu_vx{: vs2, rs1, vmask : vd},
+      disasm: "vdivu.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vdivu";
+    vdiv_vv{: vs2, vs1, vmask : vd},
+      disasm: "vdiv.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vdiv";
+    vdiv_vx{: vs2, rs1, vmask : vd},
+      disasm: "vdiv.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vdiv";
+    vremu_vv{: vs2, vs1, vmask : vd},
+      disasm: "vremu.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vremu";
+    vremu_vx{: vs2, rs1, vmask : vd},
+      disasm: "vremu.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vremu";
+    vrem_vv{: vs2, vs1, vmask : vd},
+      disasm: "vrem.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vrem";
+    vrem_vx{: vs2, rs1, vmask : vd},
+      disasm: "vrem.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vrem";
+    vmulhu_vv{: vs2, vs1, vmask : vd},
+      disasm: "vmulhu.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vmulhu";
+    vmulhu_vx{: vs2, rs1, vmask : vd},
+      disasm: "vmulhu.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vmulhu";
+    vmul_vv{: vs2, vs1, vmask : vd},
+      disasm: "vmul.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vmul";
+    vmul_vx{: vs2, rs1, vmask : vd},
+      disasm: "vmul.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vmul";
+    vmulhsu_vv{: vs2, vs1, vmask : vd},
+      disasm: "vmulhsu.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vmulhsu";
+    vmulhsu_vx{: vs2, rs1, vmask : vd},
+      disasm: "vmulhsu.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vmulhsu";
+    vmulh_vv{: vs2, vs1, vmask : vd},
+      disasm: "vmulh.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vmulh";
+    vmulh_vx{: vs2, rs1, vmask : vd},
+      disasm: "vmulh.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vmulh";
+    vmadd_vv{: vs2, vs1, vd, vmask : vd},
+      disasm: "vmadd.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vmadd";
+    vmadd_vx{: vs2, rs1, vd, vmask : vd},
+      disasm: "vmadd.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vmadd";
+    vnmsub_vv{: vs2, vs1, vd, vmask : vd},
+      disasm: "vnmsub.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vnmsub";
+    vnmsub_vx{: vs2, rs1, vd, vmask : vd},
+      disasm: "vnmsub.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vnmsub";
+    vmacc_vv{: vs2, vs1, vd, vmask : vd},
+      disasm: "vmacc.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vmacc";
+    vmacc_vx{: vs2, rs1, vd, vmask : vd},
+      disasm: "vmacc.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vmacc";
+    vnmsac_vv{: vs2, vs1, vd, vmask : vd},
+      disasm: "vnmsac.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vnmsac";
+    vnmsac_vx{: vs2, rs1, vd, vmask : vd},
+      disasm: "vnmsac.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vnmsac";
+    vwaddu_vv{: vs2, vs1, vmask : vd},
+      disasm: "vwaddu.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vwaddu";
+    vwaddu_vx{: vs2, rs1, vmask : vd},
+      disasm: "vwaddu.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vwaddu";
+    vwadd_vv{: vs2, vs1, vmask : vd},
+      disasm: "vwadd.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vwadd";
+    vwadd_vx{: vs2, rs1, vmask : vd},
+      disasm: "vwadd.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vwadd";
+    vwsubu_vv{: vs2, vs1, vmask : vd},
+      disasm: "vwsubu.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vwsubu";
+    vwsubu_vx{: vs2, rs1, vmask : vd},
+      disasm: "vwsubu.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vwsubu";
+    vwsub_vv{: vs2, vs1, vmask : vd},
+      disasm: "vwsub.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vwsub";
+    vwsub_vx{: vs2, rs1, vmask : vd},
+      disasm: "vwsub.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vwsub";
+    vwaddu_w_vv{: vs2, vs1, vmask : vd},
+      disasm: "vwaddu.wv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vwadduw";
+    vwaddu_w_vx{: vs2, rs1, vmask : vd},
+      disasm: "vwaddu.wx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vwadduw";
+    vwadd_w_vv{: vs2, vs1, vmask : vd},
+      disasm: "vwadd.wv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vwaddw";
+    vwadd_w_vx{: vs2, rs1, vmask : vd},
+      disasm: "vwadd.wx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vwaddw";
+    vwsubu_w_vv{: vs2, vs1, vmask : vd},
+      disasm: "vwsubu.wv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vwsubuw";
+    vwsubu_w_vx{: vs2, rs1, vmask : vd},
+      disasm: "vwsubu.wx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vwsubuw";
+    vwsub_w_vv{: vs2, vs1, vmask : vd},
+      disasm: "vwsub.wv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vwsubw";
+    vwsub_w_vx{: vs2, rs1, vmask : vd},
+      disasm: "vwsub.wx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vwsubw";
+    vwmulu_vv{: vs2, vs1, vmask : vd},
+      disasm: "vwmulu.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vwmulu";
+    vwmulu_vx{: vs2, rs1, vmask : vd},
+      disasm: "vwmulu.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vwmulu";
+    vwmulsu_vv{: vs2, vs1, vmask : vd},
+      disasm: "vwmulsu.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vwmulsu";
+    vwmulsu_vx{: vs2, rs1, vmask : vd},
+      disasm: "vwmulsu.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vwmulsu";
+    vwmul_vv{: vs2, vs1, vmask : vd},
+      disasm: "vwmul.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vwmul";
+    vwmul_vx{: vs2, rs1, vmask : vd},
+      disasm: "vwmul.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vwmul";
+    vwmaccu_vv{: vs2, vs1, vd, vmask : vd},
+      disasm: "vwmaccu.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vwmaccu";
+    vwmaccu_vx{: vs2, rs1, vd, vmask : vd},
+      disasm: "vwmaccu.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vwmaccu";
+    vwmacc_vv{: vs2, vs1, vd, vmask : vd},
+      disasm: "vwmacc.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vwmacc";
+    vwmacc_vx{: vs2, rs1, vd, vmask : vd},
+      disasm: "vwmacc.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vwmacc";
+    vwmaccus_vv{: vs2, vs1, vd, vmask : vd},
+      disasm: "vwmaccus.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vwmaccus";
+    vwmaccus_vx{: vs2, rs1, vd, vmask : vd},
+      disasm: "vwmaccus.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vwmaccus";
+    vwmaccsu_vv{: vs2, vs1, vd, vmask : vd},
+      disasm: "vwmaccsu.vv", "%vd, %vs2, %vs1, %vmask",
+      semfunc: "&Vwmaccsu";
+    vwmaccsu_vx{: vs2, rs1, vd, vmask : vd},
+      disasm: "vwmaccsu.vx", "%vd, %vs2, %rs1, %vmask",
+      semfunc: "&Vwmaccsu";
+
+    // VWXUNARY0
+    vmv_x_s{: vs2 : rd},
+      disasm: "vmv.x.s", "%rd, %vs2",
+      semfunc: "&VmvToScalar";
+    vcpop{: vs2, vmask : rd},
+      disasm: "vcpop.m", "%rd, %vs2, %vmask",
+      semfunc: "&Vcpop";
+    vfirst{: vs2, vmask : rd},
+      disasm: "vfirst.m", "%rd, %vs2, %vmask",
+      semfunc: "&Vfirst";
+    // VRXUNARY0
+    vmv_s_x{: rs1 : vd},
+      disasm: "vmv.s.x", "%vd, %rs1",
+      semfunc: "&VmvFromScalar";
+    // VXUNARY0
+    vzext_vf8{: vs2, vmask : vd},
+      disasm: "vzext.vf8", "%vd, %vs2, %vmask",
+      semfunc: "&Vzext8";
+    vsext_vf8{: vs2, vmask : vd},
+      disasm: "vsext.vf8", "%vd, %vs2, %vmask",
+      semfunc: "&Vsext8";
+    vzext_vf4{: vs2, vmask : vd},
+      disasm: "vzext.vf4", "%vd, %vs2, %vmask",
+      semfunc: "&Vzext4";
+    vsext_vf4{: vs2, vmask : vd},
+      disasm: "vsext.vf4", "%vd, %vs2, %vmask",
+      semfunc: "&Vsext4";
+    vzext_vf2{: vs2, vmask : vd},
+      disasm: "vzext.vf2", "%vd, %vs2, %vmask",
+      semfunc: "&Vzext2";
+    vsext_vf2{: vs2, vmask : vd},
+      disasm: "vsext.vf2", "%vd, %vs2, %vmask",
+      semfunc: "&Vsext2";
+    // VMUNARY0
+    vmsbf{: vs2, vmask : vd},
+      disasm: "vmsbf.m", "%vd, %vs2, %vmask",
+      semfunc: "&Vmsbf";
+    vmsof{: vs2, vmask : vd},
+      disasm: "vmsof.m", "%vd, %vs2, %vmask",
+      semfunc: "&Vmsof";
+    vmsif{: vs2, vmask : vd},
+      disasm: "vmsif.m", "%vd, %vs2, %vmask",
+      semfunc: "&Vmsif";
+    viota{: vs2, vmask : vd},
+      disasm: "viota.m", "%vd, %vs2, %vmask",
+      semfunc: "&Viota";
+    vid{: vmask : vd},
+      disasm: "vid.v", "%vd, %vmask",
+      semfunc: "&Vid";
+  }
+}
+
diff --git a/cheriot/riscv_cheriot_vector_fp.bin_fmt b/cheriot/riscv_cheriot_vector_fp.bin_fmt new file mode 100644 index 0000000..d8354ad --- /dev/null +++ b/cheriot/riscv_cheriot_vector_fp.bin_fmt
@@ -0,0 +1,133 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Cheriot RiscV vector floating point instruction encodings. + +instruction group RiscVVFPInst32[32] : Inst32Format { + // FP: OPFVV, OPFVF + //opfvv : VArith : func6 == 0bxxx'xxx, func3 == 0b001, opcode == 0b101'0111; + //opfvf : VArith : func6 == 0bxxx'xxx, func3 == 0b101, opcode == 0b101'0111; + + vfadd_vv : VArith : func6 == 0b000'000, func3 == 0b001, opcode == 0b101'0111; + vfadd_vf : VArith : func6 == 0b000'000, func3 == 0b101, opcode == 0b101'0111; + vfredusum_vv : VArith : func6 == 0b000'001, func3 == 0b001, opcode == 0b101'0111; + vfsub_vv : VArith : func6 == 0b000'010, func3 == 0b001, opcode == 0b101'0111; + vfsub_vf : VArith : func6 == 0b000'010, func3 == 0b101, opcode == 0b101'0111; + vfredosum_vv : VArith : func6 == 0b000'011, func3 == 0b001, opcode == 0b101'0111; + vfmin_vv : VArith : func6 == 0b000'100, func3 == 0b001, opcode == 0b101'0111; + vfmin_vf : VArith : func6 == 0b000'100, func3 == 0b101, opcode == 0b101'0111; + vfredmin_vv : VArith : func6 == 0b000'101, func3 == 0b001, opcode == 0b101'0111; + vfmax_vv : VArith : func6 == 0b000'110, func3 == 0b001, opcode == 0b101'0111; + vfmax_vf : VArith : func6 == 0b000'110, func3 == 0b101, opcode == 0b101'0111; + vfredmax_vv : VArith : func6 == 0b000'111, func3 == 0b001, opcode == 0b101'0111; + vfsgnj_vv : VArith : func6 == 0b001'000, func3 == 0b001, opcode == 0b101'0111; + vfsgnj_vf : VArith : func6 == 0b001'000, func3 == 0b101, opcode == 0b101'0111; + vfsgnjn_vv : VArith : func6 == 0b001'001, func3 == 0b001, opcode == 0b101'0111; + vfsgnjn_vf : VArith : func6 == 0b001'001, func3 == 0b101, opcode == 0b101'0111; + vfsgnjx_vv : VArith : func6 == 0b001'010, func3 == 0b001, opcode == 0b101'0111; + vfsgnjx_vf : VArith : func6 == 0b001'010, func3 == 0b101, opcode == 0b101'0111; + vfslide1up_vf : VArith : func6 == 0b001'110, func3 == 0b101, opcode == 0b101'0111; + vfslide1down_vf : VArith : func6 == 0b001'111, func3 == 0b101, opcode == 0b101'0111; + vfmv_vf : VArith : func6 == 0b010'111, vm == 1, vs2 == 0, func3 == 0b101, opcode == 0b101'0111; + vfmerge_vf : VArith : func6 == 0b010'111, vm == 0, func3 == 0b101, opcode == 0b101'0111; + vmfeq_vv : VArith : func6 == 0b011'000, func3 == 0b001, opcode == 0b101'0111; + vmfeq_vf : VArith : func6 == 0b011'000, func3 == 0b101, opcode == 0b101'0111; + vmfle_vv : VArith : func6 == 0b011'001, func3 == 0b001, opcode == 0b101'0111; + vmfle_vf : VArith : func6 == 0b011'001, func3 == 0b101, opcode == 0b101'0111; + vmflt_vv : VArith : func6 == 0b011'011, func3 == 0b001, opcode == 0b101'0111; + vmflt_vf : VArith : func6 == 0b011'011, func3 == 0b101, opcode == 0b101'0111; + vmfne_vv : VArith : func6 == 0b011'100, func3 == 0b001, opcode == 0b101'0111; + vmfne_vf : VArith : func6 == 0b011'100, func3 == 0b101, opcode == 0b101'0111; + vmfgt_vf : VArith : func6 == 0b011'101, func3 == 0b101, opcode == 0b101'0111; + vmfge_vf : VArith : func6 == 0b011'111, func3 == 0b101, opcode == 0b101'0111; + vfdiv_vv 
: VArith : func6 == 0b100'000, func3 == 0b001, opcode == 0b101'0111; + vfdiv_vf : VArith : func6 == 0b100'000, func3 == 0b101, opcode == 0b101'0111; + vfrdiv_vf : VArith : func6 == 0b100'001, func3 == 0b101, opcode == 0b101'0111; + vfmul_vv : VArith : func6 == 0b100'100, func3 == 0b001, opcode == 0b101'0111; + vfmul_vf : VArith : func6 == 0b100'100, func3 == 0b101, opcode == 0b101'0111; + vfrsub_vf : VArith : func6 == 0b100'111, func3 == 0b101, opcode == 0b101'0111; + vfmadd_vv : VArith : func6 == 0b101'000, func3 == 0b001, opcode == 0b101'0111; + vfmadd_vf : VArith : func6 == 0b101'000, func3 == 0b101, opcode == 0b101'0111; + vfnmadd_vv : VArith : func6 == 0b101'001, func3 == 0b001, opcode == 0b101'0111; + vfnmadd_vf : VArith : func6 == 0b101'001, func3 == 0b101, opcode == 0b101'0111; + vfmsub_vv : VArith : func6 == 0b101'010, func3 == 0b001, opcode == 0b101'0111; + vfmsub_vf : VArith : func6 == 0b101'010, func3 == 0b101, opcode == 0b101'0111; + vfnmsub_vv : VArith : func6 == 0b101'011, func3 == 0b001, opcode == 0b101'0111; + vfnmsub_vf : VArith : func6 == 0b101'011, func3 == 0b101, opcode == 0b101'0111; + vfmacc_vv : VArith : func6 == 0b101'100, func3 == 0b001, opcode == 0b101'0111; + vfmacc_vf : VArith : func6 == 0b101'100, func3 == 0b101, opcode == 0b101'0111; + vfnmacc_vv : VArith : func6 == 0b101'101, func3 == 0b001, opcode == 0b101'0111; + vfnmacc_vf : VArith : func6 == 0b101'101, func3 == 0b101, opcode == 0b101'0111; + vfmsac_vv : VArith : func6 == 0b101'110, func3 == 0b001, opcode == 0b101'0111; + vfmsac_vf : VArith : func6 == 0b101'110, func3 == 0b101, opcode == 0b101'0111; + vfnmsac_vv : VArith : func6 == 0b101'111, func3 == 0b001, opcode == 0b101'0111; + vfnmsac_vf : VArith : func6 == 0b101'111, func3 == 0b101, opcode == 0b101'0111; + vfwadd_vv : VArith : func6 == 0b110'000, func3 == 0b001, opcode == 0b101'0111; + vfwadd_vf : VArith : func6 == 0b110'000, func3 == 0b101, opcode == 0b101'0111; + vfwredusum_vv : VArith : func6 == 0b110'001, func3 == 0b001, opcode == 0b101'0111; + vfwsub_vv : VArith : func6 == 0b110'010, func3 == 0b001, opcode == 0b101'0111; + vfwsub_vf : VArith : func6 == 0b110'010, func3 == 0b101, opcode == 0b101'0111; + vfwredosum_vv : VArith : func6 == 0b110'011, func3 == 0b001, opcode == 0b101'0111; + vfwadd_w_vv : VArith : func6 == 0b110'100, func3 == 0b001, opcode == 0b101'0111; + vfwadd_w_vf : VArith : func6 == 0b110'100, func3 == 0b101, opcode == 0b101'0111; + vfwsub_w_vv : VArith : func6 == 0b110'110, func3 == 0b001, opcode == 0b101'0111; + vfwsub_w_vf : VArith : func6 == 0b110'110, func3 == 0b101, opcode == 0b101'0111; + vfwmul_vv : VArith : func6 == 0b111'000, func3 == 0b001, opcode == 0b101'0111; + vfwmul_vf : VArith : func6 == 0b111'000, func3 == 0b101, opcode == 0b101'0111; + vfwmacc_vv : VArith : func6 == 0b111'100, func3 == 0b001, opcode == 0b101'0111; + vfwmacc_vf : VArith : func6 == 0b111'100, func3 == 0b101, opcode == 0b101'0111; + vfwnmacc_vv : VArith : func6 == 0b111'101, func3 == 0b001, opcode == 0b101'0111; + vfwnmacc_vf : VArith : func6 == 0b111'101, func3 == 0b101, opcode == 0b101'0111; + vfwmsac_vv : VArith : func6 == 0b111'110, func3 == 0b001, opcode == 0b101'0111; + vfwmsac_vf : VArith : func6 == 0b111'110, func3 == 0b101, opcode == 0b101'0111; + vfwnmsac_vv : VArith : func6 == 0b111'111, func3 == 0b001, opcode == 0b101'0111; + vfwnmsac_vf : VArith : func6 == 0b111'111, func3 == 0b101, opcode == 0b101'0111; + + // VWFUNARY0 vv: VArith : func6 == 0b010'000, func3 == 0b001, opcode == 0b101'0111; + vfmv_f_s : VArith : func6 == 0b010'000, 
vs1 == 0, func3 == 0b001, opcode == 0b101'0111; + + // VRFUNARY0 vf: VArith : func6 == 0b010'000, func3 == 0b101, opcode == 0b101'0111; + vfmv_s_f : VArith : func6 == 0b010'000, vs2 == 0, func3 == 0b101, opcode == 0b101'0111; + + // VFUNARY0 vv: VArith : func6 == 0b010'010, func3 == 0b001, opcode == 0b101'0111; + vfcvt_xu_f_v : VArith : func6 == 0b010'010, vs1 == 0b00000, func3 == 0b001, opcode == 0b101'0111; + vfcvt_x_f_v : VArith : func6 == 0b010'010, vs1 == 0b00001, func3 == 0b001, opcode == 0b101'0111; + vfcvt_f_xu_v : VArith : func6 == 0b010'010, vs1 == 0b00010, func3 == 0b001, opcode == 0b101'0111; + vfcvt_f_x_v : VArith : func6 == 0b010'010, vs1 == 0b00011, func3 == 0b001, opcode == 0b101'0111; + vfcvt_rtz_xu_f_v : VArith : func6 == 0b010'010, vs1 == 0b00110, func3 == 0b001, opcode == 0b101'0111; + vfcvt_rtz_x_f_v : VArith : func6 == 0b010'010, vs1 == 0b00111, func3 == 0b001, opcode == 0b101'0111; + + vfwcvt_xu_f_v : VArith : func6 == 0b010'010, vs1 == 0b01000, func3 == 0b001, opcode == 0b101'0111; + vfwcvt_x_f_v : VArith : func6 == 0b010'010, vs1 == 0b01001, func3 == 0b001, opcode == 0b101'0111; + vfwcvt_f_xu_v : VArith : func6 == 0b010'010, vs1 == 0b01010, func3 == 0b001, opcode == 0b101'0111; + vfwcvt_f_x_v : VArith : func6 == 0b010'010, vs1 == 0b01011, func3 == 0b001, opcode == 0b101'0111; + vfwcvt_f_f_v : VArith : func6 == 0b010'010, vs1 == 0b01100, func3 == 0b001, opcode == 0b101'0111; + vfwcvt_rtz_xu_f_v: VArith : func6 == 0b010'010, vs1 == 0b01110, func3 == 0b001, opcode == 0b101'0111; + vfwcvt_rtz_x_f_v : VArith : func6 == 0b010'010, vs1 == 0b01111, func3 == 0b001, opcode == 0b101'0111; + + vfncvt_xu_f_w : VArith : func6 == 0b010'010, vs1 == 0b10000, func3 == 0b001, opcode == 0b101'0111; + vfncvt_x_f_w : VArith : func6 == 0b010'010, vs1 == 0b10001, func3 == 0b001, opcode == 0b101'0111; + vfncvt_f_xu_w : VArith : func6 == 0b010'010, vs1 == 0b10010, func3 == 0b001, opcode == 0b101'0111; + vfncvt_f_x_w : VArith : func6 == 0b010'010, vs1 == 0b10011, func3 == 0b001, opcode == 0b101'0111; + vfncvt_f_f_w : VArith : func6 == 0b010'010, vs1 == 0b10100, func3 == 0b001, opcode == 0b101'0111; + vfncvt_rod_f_f_w : VArith : func6 == 0b010'010, vs1 == 0b10101, func3 == 0b001, opcode == 0b101'0111; + vfncvt_rtz_xu_f_w: VArith : func6 == 0b010'010, vs1 == 0b10110, func3 == 0b001, opcode == 0b101'0111; + vfncvt_rtz_x_f_w : VArith : func6 == 0b010'010, vs1 == 0b10111, func3 == 0b001, opcode == 0b101'0111; + + // VFUNARY1 vv: VArith : func6 == 0b010'011, func3 == 0b001, opcode == 0b101'0111; + vfsqrt_v : VArith : func6 == 0b010'011, vs1 == 0b00000, func3 == 0b001, opcode == 0b101'0111; + vfrsqrt7_v : VArith : func6 == 0b010'011, vs1 == 0b00100, func3 == 0b001, opcode == 0b101'0111; + vfrec7_v : VArith : func6 == 0b010'011, vs1 == 0b00101, func3 == 0b001, opcode == 0b101'0111; + vfclass_v : VArith : func6 == 0b010'011, vs1 == 0b10000, func3 == 0b001, opcode == 0b101'0111; +};
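For readers decoding these rows by hand: each line constrains fixed fields of a 32-bit word. Assuming the standard RVV layout (func6 in bits 31:26, vm in bit 25, vs2 in bits 24:20, vs1 in bits 19:15, func3 in bits 14:12, vd in bits 11:7, opcode in bits 6:0), a hand-rolled matcher for the first row, vfadd_vv, looks like the sketch below; the generated decoder is table-driven rather than written this way.

#include <cstdint>

// Field extraction under the assumed RVV 32-bit layout described above.
inline uint32_t Func6(uint32_t iword) { return (iword >> 26) & 0x3F; }
inline uint32_t Func3(uint32_t iword) { return (iword >> 12) & 0x7; }
inline uint32_t Opcode(uint32_t iword) { return iword & 0x7F; }

// Matches the vfadd_vv row: func6 == 0b000'000, func3 == 0b001 (OPFVV),
// opcode == 0b101'0111 (the OP-V major opcode).
inline bool IsVfaddVv(uint32_t iword) {
  return Func6(iword) == 0b000'000 && Func3(iword) == 0b001 &&
         Opcode(iword) == 0b101'0111;
}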
diff --git a/cheriot/riscv_cheriot_vector_fp.isa b/cheriot/riscv_cheriot_vector_fp.isa new file mode 100644 index 0000000..867df70 --- /dev/null +++ b/cheriot/riscv_cheriot_vector_fp.isa
@@ -0,0 +1,342 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Definitions of RiscV vector floating point instructions. + +// First disasm field is 18 char wide and left justified. +disasm widths = {-18}; + +slot riscv_cheriot_vector_fp { + includes { + #include "cheriot/riscv_cheriot_vector_fp_compare_instructions.h" + #include "cheriot/riscv_cheriot_vector_fp_instructions.h" + #include "cheriot/riscv_cheriot_vector_fp_reduction_instructions.h" + #include "cheriot/riscv_cheriot_vector_fp_unary_instructions.h" + #include "absl/functional/bind_front.h" + } + default size = 4; + default latency = 0; + default opcode = + disasm: "Unimplemented instruction at 0x%(@:08x)", + semfunc: "&RV32VUnimplementedInstruction"; + opcodes { + // Floating point, OPFVV, OPFVF. + vfadd_vv{: vs2, vs1, vmask : vd}, + disasm: "vfadd.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vfadd"; + vfadd_vf{: vs2, fs1, vmask : vd}, + disasm: "vfadd.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vfadd"; + vfredusum_vv{: vs2, vs1, vmask : vd}, + disasm: "vfredusum.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vfredosum"; + vfsub_vv{: vs2, vs1, vmask : vd}, + disasm: "vfsub.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vfsub"; + vfsub_vf{: vs2, fs1, vmask : vd}, + disasm: "vfsub.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vfsub"; + vfredosum_vv{: vs2, vs1, vmask : vd}, + disasm: "vfredosum.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vfredosum"; + vfmin_vv{: vs2, vs1, vmask: vd, fflags}, + disasm: "vfmin.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vfmin"; + vfmin_vf{: vs2, fs1, vmask : vd, fflags}, + disasm: "vfmin.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vfmin"; + vfredmin_vv{: vs2, vs1, vmask: vd}, + disasm: "vfredmin.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vfredmin"; + vfmax_vv{: vs2, vs1, vmask: vd, fflags}, + disasm: "vfmax.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vfmax"; + vfmax_vf{: vs2, fs1, vmask : vd, fflags}, + disasm: "vfmax.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vfmax"; + vfredmax_vv{: vs2, vs1, vmask: vd}, + disasm: "vfredmax.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vfredmax"; + vfsgnj_vv{: vs2, vs1, vmask: vd}, + disasm: "vfsgnj.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vfsgnj"; + vfsgnj_vf{: vs2, fs1, vmask : vd}, + disasm: "vfsgnj.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vfsgnj"; + vfsgnjn_vv{: vs2, vs1, vmask: vd}, + disasm: "vfsgnjn.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vfsgnjn"; + vfsgnjn_vf{: vs2, fs1, vmask : vd}, + disasm: "vfsgnjn.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vfsgnjn"; + vfsgnjx_vv{: vs2, vs1, vmask: vd}, + disasm: "vfsgnjx.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vfsgnjx"; + vfsgnjx_vf{: vs2, fs1, vmask : vd}, + disasm: "vfsgnjx.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vfsgnjx"; + vfslide1up_vf{: vs2, fs1, vmask : vd}, + disasm: "vfslide1up.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vfslide1up"; + vfslide1down_vf{: vs2, fs1, vmask : vd}, + disasm: "vfslide1down.vf", "%vd, %vs2, %fs1, %vmask", + 
semfunc: "&Vfslide1down"; + vfmv_vf{: fs1, vmask : vd}, + disasm: "vfmv.vf", "%vd, %fs1, %vmask", + semfunc: "&Vfmvvf"; + vfmerge_vf{: vs2, vs1, vmask : vd}, + disasm: "vfmerge.vf", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vfmerge"; + vmfeq_vv{: vs2, vs1, vmask : vd}, + disasm: "vmfeq.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vmfeq"; + vmfeq_vf{: vs2, fs1, vmask : vd}, + disasm: "vmfeq.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vmfeq"; + vmfle_vv{: vs2, vs1, vmask : vd}, + disasm: "vmfle.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vmfle"; + vmfle_vf{: vs2, fs1, vmask : vd}, + disasm: "vmfle.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vmfle"; + vmflt_vv{: vs2, vs1, vmask : vd}, + disasm: "vmflt.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vmflt"; + vmflt_vf{: vs2, fs1, vmask : vd}, + disasm: "vmflt.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vmflt"; + vmfne_vv{: vs2, vs1, vmask : vd}, + disasm: "vmfne.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vmfne"; + vmfne_vf{: vs2, fs1, vmask : vd}, + disasm: "vmfne.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vmfne"; + vmfgt_vf{: vs2, fs1, vmask : vd}, + disasm: "vmfgt.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vmfgt"; + vmfge_vf{: vs2, fs1, vmask : vd}, + disasm: "vmfge.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vmfge"; + vfdiv_vv{: vs2, vs1, vmask : vd}, + disasm: "vfdiv.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vfdiv"; + vfdiv_vf{: vs2, fs1, vmask : vd}, + disasm: "vfdiv.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vfdiv"; + vfrdiv_vf{: vs2, fs1, vmask : vd}, + disasm: "vfrdiv.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vfrdiv"; + vfmul_vv{: vs2, vs1, vmask : vd}, + disasm: "vfmul.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vfmul"; + vfmul_vf{: vs2, fs1, vmask : vd}, + disasm: "vfmul.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vfmul"; + vfrsub_vf{: vs2, fs1, vmask : vd}, + disasm: "vfrsub.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vfrsub"; + vfmadd_vv{: vs2, vs1, vd, vmask: vd}, + disasm: "vfmadd.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vfmadd"; + vfmadd_vf{: vs2, fs1, vd, vmask: vd}, + disasm: "vfmadd.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vfmadd"; + vfnmadd_vv{: vs2, vs1, vd, vmask: vd}, + disasm: "vfnmadd.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vfnmadd"; + vfnmadd_vf{: vs2, fs1, vd, vmask: vd}, + disasm: "vfnmadd.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vfnmadd"; + vfmsub_vv{: vs2, vs1, vd, vmask: vd}, + disasm: "vfmsub.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vfmsub"; + vfmsub_vf{: vs2, fs1, vd, vmask: vd}, + disasm: "vfmsub.v", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vfmsub"; + vfnmsub_vv{: vs2, vs1, vd, vmask: vd}, + disasm: "vfnmsub.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vfnmsub"; + vfnmsub_vf{: vs2, fs1, vd, vmask: vd}, + disasm: "vfnmsub.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vfnmsub"; + vfmacc_vv{: vs2, vs1, vd, vmask: vd}, + disasm: "vfmacc.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vfmacc"; + vfmacc_vf{: vs2, fs1, vd, vmask: vd}, + disasm: "vfmacc.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vfmacc"; + vfnmacc_vv{: vs2, vs1, vd, vmask: vd}, + disasm: "vfnmacc.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vfnmacc"; + vfnmacc_vf{: vs2, fs1, vd, vmask: vd}, + disasm: "vfnmacc.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vfnmacc"; + vfmsac_vv{: vs2, vs1, vd, vmask: vd}, + disasm: "vfmsac.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vfmsac"; + vfmsac_vf{: vs2, fs1, vd, vmask: vd}, + disasm: "vfmsac.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vfmsac"; + vfnmsac_vv{: vs2, vs1, vd, vmask: 
vd}, + disasm: "vfnmsac.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vfnmsac"; + vfnmsac_vf{: vs2, fs1, vd, vmask: vd}, + disasm: "vfnmsac.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vfnmsac"; + vfwadd_vv{: vs2, vs1, vmask: vd}, + disasm: "vfwadd.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vfwadd"; + vfwadd_vf{: vs2, fs1, vmask: vd}, + disasm: "vfwadd.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vfwadd"; + vfwredusum_vv{: vs2, vs1, vmask : vd}, + disasm: "vfwredusum.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vfwredosum"; + vfwsub_vv{: vs2, vs1, vmask: vd}, + disasm: "vfwsub.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vfwsub"; + vfwsub_vf{: vs2, fs1, vmask: vd}, + disasm: "vfwsub.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vfwsub"; + vfwredosum_vv{: vs2, vs1, vmask : vd}, + disasm: "vfwredosum.vv", "%vd, %vs2, %vs1, %vmask"; + vfwadd_w_vv{: vs2, vs1, vmask: vd}, + disasm: "vfwadd.w.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vfwadd"; + vfwadd_w_vf{: vs2, fs1, vmask : vd}, + disasm: "vfwadd.w.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vfwadd"; + vfwsub_w_vv{: vs2, vs1, vmask: vd}, + disasm: "vfwsub.w.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vfwsub"; + vfwsub_w_vf{: vs2, fs1, vmask : vd}, + disasm: "vfwsub.w.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vfwsub"; + vfwmul_vv{: vs2, vs1, vmask: vd}, + disasm: "vfwmul.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vfwmul"; + vfwmul_vf{: vs2, fs1, vmask : vd}, + disasm: "vfwmul.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vfwmul"; + vfwmacc_vv{: vs2, vs1, vd, vmask: vd}, + disasm: "vfwmacc.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vfwmacc"; + vfwmacc_vf{: vs2, fs1, vd, vmask: vd}, + disasm: "vfwmacc.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vfwmacc"; + vfwnmacc_vv{: vs2, vs1, vd, vmask: vd}, + disasm: "vfwnmacc.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vfwnmacc"; + vfwnmacc_vf{: vs2, fs1, vd, vmask: vd}, + disasm: "vfwnmacc.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vfwnmacc"; + vfwmsac_vv{: vs2, vs1, vd, vmask: vd}, + disasm: "vfwmsac.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vfwmsac"; + vfwmsac_vf{: vs2, fs1, vd, vmask: vd}, + disasm: "vfwmsac.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vfwmsac"; + vfwnmsac_vv{: vs2, vs1, vd, vmask: vd}, + disasm: "vfwnmsac.vv", "%vd, %vs2, %vs1, %vmask", + semfunc: "&Vfwnmsac"; + vfwnmsac_vf{: vs2, fs1, vd, vmask: vd}, + disasm: "vfwnmsac.vf", "%vd, %vs2, %fs1, %vmask", + semfunc: "&Vfwnmsac"; + // VWFUNARY0 + vfmv_f_s{: vs2 : fd}, + disasm: "vfmv.f.s", "%fd, %vs2", + semfunc: "&Vfmvfs"; + // VRFUNARY0 + vfmv_s_f{: fs1 : vd}, + disasm: "vfmv.s.f", "%vd, %fs1", + semfunc: "&Vfmvsf"; + // VFUNARY0 + vfcvt_xu_f_v{: vs2, vmask: vd, fflags}, + disasm: "vfcvt.xu.f.v", "%vd, %vs2, %vmask", + semfunc: "&Vfcvtxufv"; + vfcvt_x_f_v{: vs2, vmask: vd, fflags}, + disasm: "vfcvt.x.f.v", "%vd, %vs2, %vmask", + semfunc: "&Vfcvtxfv"; + vfcvt_f_xu_v{: vs2, vmask: vd}, + disasm: "vfcvt.xu.v", "%vd, %vs2, %vmask", + semfunc: "&Vfcvtfxuv"; + vfcvt_f_x_v{: vs2, vmask: vd}, + disasm: "vfcvt.x.v", "%vd, %vs2, %vmask", + semfunc: "&Vfcvtfxv"; + vfcvt_rtz_xu_f_v{: vs2, vmask: vd, fflags}, + disasm: "vfcvt.rtz.xu.f.v", "%vd, %vs2, %vmask", + semfunc: "&Vfcvtrtzxufv"; + vfcvt_rtz_x_f_v{: vs2, vmask: vd, fflags}, + disasm: "vfcvt.rtz.x.f.v", "%vd, %vs2, %vmask", + semfunc: "&Vfcvtrtzxfv"; + vfwcvt_xu_f_v{: vs2, vmask: vd, fflags}, + disasm: "vfwcvt.xu.f.v", "%vd, %vs2, %vmask", + semfunc: "&Vfwcvtxufv"; + vfwcvt_x_f_v{: vs2, vmask: vd, fflags}, + disasm: "vfwcvt.x.f.v", "%vd, %vs2, %vmask", + semfunc: "&Vfwcvtxfv"; 
+ vfwcvt_f_xu_v{: vs2, vmask: vd}, + disasm: "vfwcvt.f.xu.v", "%vd, %vs2, %vmask", + semfunc: "&Vfwcvtfxuv"; + vfwcvt_f_x_v{: vs2, vmask: vd}, + disasm: "vfwcvt.f.x.v", "%vd, %vs2, %vmask", + semfunc: "&Vfwcvtfxv"; + vfwcvt_f_f_v{: vs2, vmask: vd}, + disasm: "vfwcvt.f.f.v", "%vd, %vs2, %vmask", + semfunc: "&Vfwcvtffv"; + vfwcvt_rtz_xu_f_v{: vs2, vmask: vd, fflags}, + disasm: "vfwcvt.rtz.xu.f.v", "%vd, %vs2, %vmask", + semfunc: "&Vfwcvtrtzxufv"; + vfwcvt_rtz_x_f_v{: vs2, vmask: vd, fflags}, + disasm: "vfwcvt.rtz.x.f.v", "%vd, %vs2, %vmask", + semfunc: "&Vfwcvtrtzxfv"; + vfncvt_xu_f_w{: vs2, vmask: vd, fflags}, + disasm: "vfncvt.xu.f.w", "%vd, %vs2, %vmask", + semfunc: "&Vfncvtxufw"; + vfncvt_x_f_w{: vs2, vmask: vd, fflags}, + disasm: "vfncvt.x.f.w", "%vd, %vs2, %vmask", + semfunc: "&Vfncvtxfw"; + vfncvt_f_xu_w{: vs2, vmask: vd}, + disasm: "vfncvt.f.xu.w", "%vd, %vs2, %vmask", + semfunc: "&Vfncvtfxuw"; + vfncvt_f_x_w{: vs2, vmask: vd}, + disasm: "vfncvt.f.x.w", "%vd, %vs2, %vmask", + semfunc: "&Vfncvtfxw"; + vfncvt_f_f_w{: vs2, vmask: vd}, + disasm: "vfncvt.f.f.w", "%vd, %vs2, %vmask", + semfunc: "&Vfncvtffw"; + vfncvt_rod_f_f_w{: vs2, vmask: vd}, + disasm: "vfncvt.rod.f.f.w", "%vd, %vs2, %vmask", + semfunc: "&Vfncvtrodffw"; + vfncvt_rtz_xu_f_w{: vs2, vmask: vd, fflags}, + disasm: "vfncvt.rtz.xu.f.w", "%vd, %vs2, %vmask", + semfunc: "&Vfncvtrtzxufw"; + vfncvt_rtz_x_f_w{: vs2, vmask: vd, fflags}, + disasm: "vfncvt.rtz.x.f.w", "%vd, %vs2, %vmask", + semfunc: "&Vfncvtrtzxfw"; + // VFUNARY1 + vfsqrt_v{: vs2, vmask: vd, fflags}, + disasm: "vfsqrt.v", "%vd, %vs2, %vmask", + semfunc: "&Vfsqrtv"; + vfrsqrt7_v{: vs2, vmask: vd, fflags}, + disasm: "vfrsqrt7.v", "%vd, %vs2, %vmask", + semfunc: "&Vfrsqrt7v"; + vfrec7_v{: vs2, vmask: vd}, + disasm: "vfrec7.v", "%vd, %vs2, %vmask", + semfunc: "&Vfrec7v"; + vfclass_v{: vs2, vmask: vd}, + disasm: "vfclass.v", "%vd, %vs2, %vmask", + semfunc: "&Vfclassv"; + } +} \ No newline at end of file
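A note on the operand lists above: `vmask` names the mask source, and entries whose destination list includes `fflags` additionally accrue IEEE exception flags. The per-element structure these semantic functions share is masked predication over the active elements; a simplified sketch follows (illustrative names only; the real helpers also honor vstart, vl, and the tail/mask policy from vtype, and capture flags via ScopedFPStatus).

#include <cstddef>
#include <vector>

// Simplified model of a masked elementwise FP op such as vfadd.vv. The real
// semantic functions wrap the loop in ScopedFPStatus so host IEEE flags and
// rounding mode are honored; that plumbing is omitted here.
template <typename T, typename Op>
void MaskedApplySketch(const std::vector<T>& vs2, const std::vector<T>& vs1,
                       const std::vector<bool>& mask, std::vector<T>* vd,
                       Op op) {
  for (std::size_t i = 0; i < vd->size(); ++i) {
    if (!mask[i]) continue;         // Inactive elements are left undisturbed.
    (*vd)[i] = op(vs2[i], vs1[i]);  // Active elements receive the result.
  }
}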
diff --git a/cheriot/riscv_cheriot_vector_fp_compare_instructions.cc b/cheriot/riscv_cheriot_vector_fp_compare_instructions.cc new file mode 100644 index 0000000..2747d7b --- /dev/null +++ b/cheriot/riscv_cheriot_vector_fp_compare_instructions.cc
@@ -0,0 +1,148 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cheriot/riscv_cheriot_vector_fp_compare_instructions.h" + +#include "absl/log/log.h" +#include "cheriot/cheriot_state.h" +#include "cheriot/cheriot_vector_state.h" +#include "cheriot/riscv_cheriot_vector_instruction_helpers.h" + +namespace mpact { +namespace sim { +namespace cheriot { + +// Vector floating point compare equal. +void Vmfeq(const Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 4: + return RiscVBinaryVectorMaskOp<float, float>( + rv_vector, inst, + [](float vs2, float vs1) -> bool { return vs2 == vs1; }); + case 8: + return RiscVBinaryVectorMaskOp<double, double>( + rv_vector, inst, + [](double vs2, double vs1) -> bool { return vs2 == vs1; }); + default: + LOG(ERROR) << "Vmfeq: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Vector floating point compare less than or equal. +void Vmfle(const Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 4: + return RiscVBinaryVectorMaskOp<float, float>( + rv_vector, inst, + [](float vs2, float vs1) -> bool { return vs2 <= vs1; }); + case 8: + return RiscVBinaryVectorMaskOp<double, double>( + rv_vector, inst, + [](double vs2, double vs1) -> bool { return vs2 <= vs1; }); + default: + LOG(ERROR) << "Vmfle: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Vector floating compare less than. +void Vmflt(const Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 4: + return RiscVBinaryVectorMaskOp<float, float>( + rv_vector, inst, + [](float vs2, float vs1) -> bool { return vs2 < vs1; }); + case 8: + return RiscVBinaryVectorMaskOp<double, double>( + rv_vector, inst, + [](double vs2, double vs1) -> bool { return vs2 < vs1; }); + default: + LOG(ERROR) << "Vmflt: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Vector floating point compare not equal. +void Vmfne(const Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 4: + return RiscVBinaryVectorMaskOp<float, float>( + rv_vector, inst, + [](float vs2, float vs1) -> bool { return vs2 != vs1; }); + case 8: + return RiscVBinaryVectorMaskOp<double, double>( + rv_vector, inst, + [](double vs2, double vs1) -> bool { return vs2 != vs1; }); + default: + LOG(ERROR) << "Vmfne: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Vector floating point compare greater than. 
+void Vmfgt(const Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 4: + return RiscVBinaryVectorMaskOp<float, float>( + rv_vector, inst, + [](float vs2, float vs1) -> bool { return vs2 > vs1; }); + case 8: + return RiscVBinaryVectorMaskOp<double, double>( + rv_vector, inst, + [](double vs2, double vs1) -> bool { return vs2 > vs1; }); + default: + LOG(ERROR) << "Vmfgt: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Vector floating point compare greater than or equal. +void Vmfge(const Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 4: + return RiscVBinaryVectorMaskOp<float, float>( + rv_vector, inst, + [](float vs2, float vs1) -> bool { return vs2 >= vs1; }); + case 8: + return RiscVBinaryVectorMaskOp<double, double>( + rv_vector, inst, + [](double vs2, double vs1) -> bool { return vs2 >= vs1; }); + default: + LOG(ERROR) << "Vmfge: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +} // namespace cheriot +} // namespace sim +} // namespace mpact
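Note: for readers unfamiliar with the helper used above, the following is a minimal standalone sketch of the mask-producing compare pattern. ApplyCompareMask is an illustrative stand-in, not the simulator's RiscVBinaryVectorMaskOp, which additionally honors vstart, vl, LMUL register grouping, and the source mask operand.

#include <cstdint>
#include <cstdio>
#include <vector>

// Sketch: apply a binary predicate elementwise, producing one mask bit
// (stored here as one byte) per element, as the compares above do.
template <typename T, typename F>
std::vector<uint8_t> ApplyCompareMask(const std::vector<T> &vs2,
                                      const std::vector<T> &vs1, F cmp) {
  std::vector<uint8_t> mask(vs2.size());
  for (size_t i = 0; i < vs2.size(); ++i) {
    mask[i] = cmp(vs2[i], vs1[i]) ? 1 : 0;
  }
  return mask;
}

int main() {
  std::vector<float> vs2 = {1.0f, 2.0f, 3.0f};
  std::vector<float> vs1 = {2.0f, 2.0f, 2.0f};
  auto mask = ApplyCompareMask(vs2, vs1,
                               [](float a, float b) { return a <= b; });
  for (auto bit : mask) printf("%d", bit);  // Prints 110 (vmfle semantics).
  printf("\n");
  return 0;
}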
diff --git a/cheriot/riscv_cheriot_vector_fp_compare_instructions.h b/cheriot/riscv_cheriot_vector_fp_compare_instructions.h new file mode 100644 index 0000000..9b5f83e --- /dev/null +++ b/cheriot/riscv_cheriot_vector_fp_compare_instructions.h
@@ -0,0 +1,45 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MPACT_CHERIOT_RISCV_CHERIOT_VECTOR_FP_COMPARE_INSTRUCTIONS_H_ +#define MPACT_CHERIOT_RISCV_CHERIOT_VECTOR_FP_COMPARE_INSTRUCTIONS_H_ + +#include "mpact/sim/generic/instruction.h" + +// This file declares the vector floating point compare instructions. + +namespace mpact { +namespace sim { +namespace cheriot { + +using Instruction = ::mpact::sim::generic::Instruction; + +// Each of these instructions takes three source operands and one destination +// operand. Source operand 0 is a vector register group; source operand 1 is +// either a vector register group or a scalar floating point register (Vmfgt +// and Vmfge only take the scalar register); source operand 2 is the vector +// mask register. Destination operand 0 is a vector register treated as the +// destination mask register. +void Vmfeq(const Instruction *inst); +void Vmfle(const Instruction *inst); +void Vmflt(const Instruction *inst); +void Vmfne(const Instruction *inst); +void Vmfgt(const Instruction *inst); +void Vmfge(const Instruction *inst); + +} // namespace cheriot +} // namespace sim +} // namespace mpact + +#endif  // MPACT_CHERIOT_RISCV_CHERIOT_VECTOR_FP_COMPARE_INSTRUCTIONS_H_
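Note: the compare lambdas in the .cc above rely on the host's IEEE-754 comparison semantics for NaN operands, where ==, <, and <= are false and != is true when either operand is NaN. A quick plain-C++ sanity check, independent of the simulator:

#include <cstdio>
#include <limits>

int main() {
  float nan = std::numeric_limits<float>::quiet_NaN();
  // IEEE-754 unordered comparisons: ==, <, <= yield false; != yields true.
  printf("%d %d %d %d\n", nan == nan, nan < 1.0f, nan <= 1.0f, nan != nan);
  // Prints: 0 0 0 1
  return 0;
}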
diff --git a/cheriot/riscv_cheriot_vector_fp_instructions.cc b/cheriot/riscv_cheriot_vector_fp_instructions.cc new file mode 100644 index 0000000..e4d92fa --- /dev/null +++ b/cheriot/riscv_cheriot_vector_fp_instructions.cc
@@ -0,0 +1,846 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cheriot/riscv_cheriot_vector_fp_instructions.h" + +#include <cmath> +#include <cstdint> +#include <functional> +#include <tuple> + +#include "absl/log/log.h" +#include "cheriot/cheriot_state.h" +#include "cheriot/riscv_cheriot_vector_instruction_helpers.h" +#include "mpact/sim/generic/type_helpers.h" +#include "riscv//riscv_fp_host.h" +#include "riscv//riscv_fp_info.h" +#include "riscv//riscv_fp_state.h" +#include "riscv//riscv_vector_state.h" + +namespace mpact { +namespace sim { +namespace cheriot { + +using ::mpact::sim::generic::FPTypeInfo; +using ::mpact::sim::riscv::FPExceptions; +using ::mpact::sim::riscv::ScopedFPStatus; + +// Floating point add. +void Vfadd(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 4: + return RiscVBinaryVectorOp<float, float, float>( + rv_vector, inst, + [](float vs2, float vs1) -> float { return vs2 + vs1; }); + case 8: + return RiscVBinaryVectorOp<double, double, double>( + rv_vector, inst, + [](double vs2, double vs1) -> double { return vs2 + vs1; }); + default: + LOG(ERROR) << "Vfadd: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Floating point subtract. +void Vfsub(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 4: + return RiscVBinaryVectorOp<float, float, float>( + rv_vector, inst, + [](float vs2, float vs1) -> float { return vs2 - vs1; }); + case 8: + return RiscVBinaryVectorOp<double, double, double>( + rv_vector, inst, + [](double vs2, double vs1) -> double { return vs2 - vs1; }); + default: + LOG(ERROR) << "Vfsub: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Reverse floating point subtract (rs1 - vs2). 
+void Vfrsub(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 4: + return RiscVBinaryVectorOp<float, float, float>( + rv_vector, inst, + [](float vs2, float vs1) -> float { return vs1 - vs2; }); + case 8: + return RiscVBinaryVectorOp<double, double, double>( + rv_vector, inst, + [](double vs2, double vs1) -> double { return vs1 - vs2; }); + default: + LOG(ERROR) << "Vfrsub: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Widening floating point add. +void Vfwadd(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 4: + return RiscVBinaryVectorOp<double, float, float>( + rv_vector, inst, [](float vs2, float vs1) -> double { + double vs2_d = static_cast<double>(vs2); + double vs1_d = static_cast<double>(vs1); + return (vs2_d + vs1_d); + }); + default: + LOG(ERROR) << "Vfwadd: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Widening floating point subtract. +void Vfwsub(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 4: + return RiscVBinaryVectorOp<double, float, float>( + rv_vector, inst, [](float vs2, float vs1) -> double { + double vs2_d = static_cast<double>(vs2); + double vs1_d = static_cast<double>(vs1); + return (vs2_d - vs1_d); + }); + default: + LOG(ERROR) << "Vfwsub: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Widening floating point add with wide operand (vs2). +void Vfwaddw(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 4: + return RiscVBinaryVectorOp<double, double, float>( + rv_vector, inst, [](double vs2_d, float vs1) -> double { + double vs1_d = static_cast<double>(vs1); + return (vs2_d + vs1_d); + }); + default: + LOG(ERROR) << "Vfwaddw: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Widening floating point subtract with wide operand (vs2). 
+void Vfwsubw(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 4: + return RiscVBinaryVectorOp<double, double, float>( + rv_vector, inst, [](double vs2_d, float vs1) -> double { + double vs1_d = static_cast<double>(vs1); + return (vs2_d - vs1_d); + }); + default: + LOG(ERROR) << "Vfwsubw: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Floating point multiply. +void Vfmul(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 4: + return RiscVBinaryVectorOp<float, float, float>( + rv_vector, inst, + [](float vs2, float vs1) -> float { return vs2 * vs1; }); + case 8: + return RiscVBinaryVectorOp<double, double, double>( + rv_vector, inst, + [](double vs2, double vs1) -> double { return vs2 * vs1; }); + default: + LOG(ERROR) << "Vfmul: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Floating point division vs2/vs1. +void Vfdiv(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 4: + return RiscVBinaryVectorOp<float, float, float>( + rv_vector, inst, + [](float vs2, float vs1) -> float { return vs2 / vs1; }); + case 8: + return RiscVBinaryVectorOp<double, double, double>( + rv_vector, inst, + [](double vs2, double vs1) -> double { return vs2 / vs1; }); + default: + LOG(ERROR) << "Vfdiv: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Floating point reverse division vs1/vs2. +void Vfrdiv(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 4: + return RiscVBinaryVectorOp<float, float, float>( + rv_vector, inst, + [](float vs2, float vs1) -> float { return vs1 / vs2; }); + case 8: + return RiscVBinaryVectorOp<double, double, double>( + rv_vector, inst, + [](double vs2, double vs1) -> double { return vs1 / vs2; }); + default: + LOG(ERROR) << "Vfrdiv: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Widening floating point multiply. 
+void Vfwmul(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 4: + return RiscVBinaryVectorOp<double, float, float>( + rv_vector, inst, [](float vs2, float vs1) -> double { + double vs2_d = static_cast<double>(vs2); + double vs1_d = static_cast<double>(vs1); + return (vs2_d * vs1_d); + }); + default: + LOG(ERROR) << "Vfwmul: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Floating point multiply and add vs2. +void Vfmadd(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 4: + return RiscVTernaryVectorOp<float, float, float>( + rv_vector, inst, [](float vs2, float vs1, float vd) -> float { + return std::fma(vs1, vd, vs2); + }); + case 8: + return RiscVTernaryVectorOp<double, double, double>( + rv_vector, inst, [](double vs2, double vs1, double vd) -> double { + return std::fma(vs1, vd, vs2); + }); + default: + LOG(ERROR) << "Vfmadd: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Negated floating point multiply and add vs2. +void Vfnmadd(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 4: + return RiscVTernaryVectorOp<float, float, float>( + rv_vector, inst, [](float vs2, float vs1, float vd) -> float { + return std::fma(-vs1, vd, -vs2); + }); + case 8: + return RiscVTernaryVectorOp<double, double, double>( + rv_vector, inst, [](double vs2, double vs1, double vd) -> double { + return std::fma(-vs1, vd, -vs2); + }); + default: + LOG(ERROR) << "Vfnmadd: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Floating point multiply and subtract vs2. 
+void Vfmsub(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 4: + return RiscVTernaryVectorOp<float, float, float>( + rv_vector, inst, [](float vs2, float vs1, float vd) -> float { + return std::fma(vs1, vd, -vs2); + }); + case 8: + return RiscVTernaryVectorOp<double, double, double>( + rv_vector, inst, [](double vs2, double vs1, double vd) -> double { + return std::fma(vs1, vd, -vs2); + }); + default: + LOG(ERROR) << "Vfmsub: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Negated floating point multiply and subtract vs2. +void Vfnmsub(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 4: + return RiscVTernaryVectorOp<float, float, float>( + rv_vector, inst, [](float vs2, float vs1, float vd) -> float { + return std::fma(-vs1, vd, vs2); + }); + case 8: + return RiscVTernaryVectorOp<double, double, double>( + rv_vector, inst, [](double vs2, double vs1, double vd) -> double { + return std::fma(-vs1, vd, vs2); + }); + default: + LOG(ERROR) << "Vfnmsub: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Floating point multiply and accumulate vd. +void Vfmacc(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 4: + return RiscVTernaryVectorOp<float, float, float>( + rv_vector, inst, [](float vs2, float vs1, float vd) -> float { + return std::fma(vs1, vs2, vd); + }); + case 8: + return RiscVTernaryVectorOp<double, double, double>( + rv_vector, inst, [](double vs2, double vs1, double vd) -> double { + return std::fma(vs1, vs2, vd); + }); + default: + LOG(ERROR) << "Vfmacc: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Negated floating point multiply and accumulate vd. 
+void Vfnmacc(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 4: + return RiscVTernaryVectorOp<float, float, float>( + rv_vector, inst, [](float vs2, float vs1, float vd) -> float { + return std::fma(-vs1, vs2, -vd); + }); + case 8: + return RiscVTernaryVectorOp<double, double, double>( + rv_vector, inst, [](double vs2, double vs1, double vd) -> double { + return std::fma(-vs1, vs2, -vd); + }); + default: + LOG(ERROR) << "Vfnmacc: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Floating point multiply and subtract vd. +void Vfmsac(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 4: + return RiscVTernaryVectorOp<float, float, float>( + rv_vector, inst, [](float vs2, float vs1, float vd) -> float { + return std::fma(vs1, vs2, -vd); + }); + case 8: + return RiscVTernaryVectorOp<double, double, double>( + rv_vector, inst, [](double vs2, double vs1, double vd) -> double { + return std::fma(vs1, vs2, -vd); + }); + default: + LOG(ERROR) << "Vfmsac: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Negated floating point multiply and subtract vd. +void Vfnmsac(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 4: + return RiscVTernaryVectorOp<float, float, float>( + rv_vector, inst, [](float vs2, float vs1, float vd) -> float { + return std::fma(-vs1, vs2, vd); + }); + case 8: + return RiscVTernaryVectorOp<double, double, double>( + rv_vector, inst, [](double vs2, double vs1, double vd) -> double { + return std::fma(-vs1, vs2, vd); + }); + default: + LOG(ERROR) << "Vfnmsac: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Widening floating point multiply and accumulate vd. 
+void Vfwmacc(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 4: + return RiscVTernaryVectorOp<double, float, float>( + rv_vector, inst, [](float vs2, float vs1, double vd) -> double { + double vs1_d = vs1; + double vs2_d = vs2; + return ((vs1_d * vs2_d) + vd); + }); + default: + LOG(ERROR) << "Vfwmacc: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Widening negated floating point multiply and accumulate vd. +void Vfwnmacc(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 4: + return RiscVTernaryVectorOp<double, float, float>( + rv_vector, inst, [](float vs2, float vs1, double vd) -> double { + double vs1_d = vs1; + double vs2_d = vs2; + return (-(vs1_d * vs2_d)) - vd; + }); + default: + LOG(ERROR) << "Vfwnmacc: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Widening floating point multiply and subtract vd. +void Vfwmsac(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 4: + return RiscVTernaryVectorOp<double, float, float>( + rv_vector, inst, [](float vs2, float vs1, double vd) -> double { + double vs1_d = vs1; + double vs2_d = vs2; + return ((vs1_d * vs2_d) - vd); + }); + default: + LOG(ERROR) << "Vfwmsac: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Widening negated floating point multiply and subtract vd. +void Vfwnmsac(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 4: + return RiscVTernaryVectorOp<double, float, float>( + rv_vector, inst, [](float vs2, float vs1, double vd) -> double { + double vs1_d = vs1; + double vs2_d = vs2; + return (-(vs1_d * vs2_d)) + vd; + }); + default: + LOG(ERROR) << "Vfwnmsac: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Change the sign of vs2 to the sign of vs1. 
+void Vfsgnj(const Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 4: + return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>( + rv_vector, inst, [](uint32_t vs2, uint32_t vs1) -> uint32_t { + return (vs2 & 0x7fff'ffff) | (vs1 & 0x8000'0000); + }); + case 8: + return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>( + rv_vector, inst, [](uint64_t vs2, uint64_t vs1) -> uint64_t { + return (vs2 & 0x7fff'ffff'ffff'ffff) | + (vs1 & 0x8000'0000'0000'0000); + }); + default: + LOG(ERROR) << "Vfsgnj: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Change the sign of vs2 to the negation of the sign of vs1. +void Vfsgnjn(const Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 4: + return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>( + rv_vector, inst, [](uint32_t vs2, uint32_t vs1) -> uint32_t { + return (vs2 & 0x7fff'ffff) | (~vs1 & 0x8000'0000); + }); + case 8: + return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>( + rv_vector, inst, [](uint64_t vs2, uint64_t vs1) -> uint64_t { + return (vs2 & 0x7fff'ffff'ffff'ffff) | + (~vs1 & 0x8000'0000'0000'0000); + }); + default: + LOG(ERROR) << "Vfsgnjn: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Change the sign of vs2 to the xor of the sign of the two operands. +void Vfsgnjx(const Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 4: + return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>( + rv_vector, inst, [](uint32_t vs2, uint32_t vs1) -> uint32_t { + return (vs2 & 0x7fff'ffff) | ((vs1 ^ vs2) & 0x8000'0000); + }); + case 8: + return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>( + rv_vector, inst, [](uint64_t vs2, uint64_t vs1) -> uint64_t { + return (vs2 & 0x7fff'ffff'ffff'ffff) | + ((vs1 ^ vs2) & 0x8000'0000'0000'0000); + }); + default: + LOG(ERROR) << "Vfsgnjx: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Templated helper function for vfmin and vfmax instructions. +template <typename T> +inline std::tuple<T, uint32_t> MaxMinHelper(T vs2, T vs1, + std::function<T(T, T)> operation) { + // If either operand is a signaling NaN or if both operands are NaNs, then + // return a canonical (non-signaling) NaN. + uint32_t flag = 0; + if (FPTypeInfo<T>::IsSNaN(vs1) || FPTypeInfo<T>::IsSNaN(vs2)) { + flag = static_cast<uint32_t>(FPExceptions::kInvalidOp); + } + if (FPTypeInfo<T>::IsNaN(vs2) && FPTypeInfo<T>::IsNaN(vs1)) { + auto c_nan = FPTypeInfo<T>::kCanonicalNaN; + return std::make_tuple(*reinterpret_cast<T *>(&c_nan), flag); + } + // If either operand is a NaN return the other. + if (FPTypeInfo<T>::IsNaN(vs2)) return std::tie(vs1, flag); + if (FPTypeInfo<T>::IsNaN(vs1)) return std::tie(vs2, flag); + // Return the min/max of the two operands. + if ((vs2 == 0.0) && (vs1 == 0.0)) { + T tmp2 = std::signbit(vs2) ? -1.0 : 1.0; + T tmp1 = std::signbit(vs1) ? -1.0 : 1.0; + return std::make_tuple(operation(tmp2, tmp1) == tmp2 ? vs2 : vs1, 0); + } + return std::make_tuple(operation(vs2, vs1), flag); +} + +// Vector floating point min. 
+void Vfmin(const Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 4: + return RiscVBinaryVectorOpWithFflags<float, float, float>( + rv_vector, inst, + [](float vs2, float vs1) -> std::tuple<float, uint32_t> { + using T = float; + return MaxMinHelper<T>(vs2, vs1, [](T vs2, T vs1) -> T { + return (vs1 < vs2) ? vs1 : vs2; + }); + }); + case 8: + return RiscVBinaryVectorOpWithFflags<double, double, double>( + rv_vector, inst, + [](double vs2, double vs1) -> std::tuple<double, uint32_t> { + using T = double; + return MaxMinHelper<T>(vs2, vs1, [](T vs2, T vs1) -> T { + return (vs1 < vs2) ? vs1 : vs2; + }); + }); + default: + LOG(ERROR) << "Vfmin: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Vector floating point max. +void Vfmax(const Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 4: + return RiscVBinaryVectorOpWithFflags<float, float, float>( + rv_vector, inst, + [](float vs2, float vs1) -> std::tuple<float, uint32_t> { + using T = float; + return MaxMinHelper<T>(vs2, vs1, [](T vs2, T vs1) -> T { + return (vs1 > vs2) ? vs1 : vs2; + }); + }); + case 8: + return RiscVBinaryVectorOpWithFflags<double, double, double>( + rv_vector, inst, + [](double vs2, double vs1) -> std::tuple<double, uint32_t> { + using T = double; + return MaxMinHelper<T>(vs2, vs1, [](T vs2, T vs1) -> T { + return (vs1 > vs2) ? vs1 : vs2; + }); + }); + default: + LOG(ERROR) << "Vfmax: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Vector fp merge. +void Vfmerge(const Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 4: + return RiscVMaskBinaryVectorOp<uint32_t, uint32_t, uint32_t>( + rv_vector, inst, + [](uint32_t vs2, uint32_t vs1, bool mask) -> uint32_t { + return mask ? vs1 : vs2; + }); + case 8: + return RiscVMaskBinaryVectorOp<uint64_t, uint64_t, uint64_t>( + rv_vector, inst, + [](uint64_t vs2, uint64_t vs1, bool mask) -> uint64_t { + return mask ? vs1 : vs2; + }); + default: + rv_vector->set_vector_exception(); + LOG(ERROR) << "Vfmerge: Illegal sew (" << sew << ")"; + return; + } +} + +} // namespace cheriot +} // namespace sim +} // namespace mpact
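Note: the signed-zero and NaN handling in MaxMinHelper above is easy to get wrong, so here is a condensed standalone model of the same edge cases for float only. MinModel and its canonical-NaN constant are illustrative stand-ins, not the simulator's helper, which is templated and also reports the invalid-operation flag for signaling NaNs.

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

float MinModel(float vs2, float vs1) {
  // Both NaN: return the canonical quiet NaN (0x7fc00000 for float).
  if (std::isnan(vs2) && std::isnan(vs1)) {
    uint32_t c_nan = 0x7fc00000u;
    float result;
    std::memcpy(&result, &c_nan, sizeof(result));
    return result;
  }
  // One NaN: return the other operand.
  if (std::isnan(vs2)) return vs1;
  if (std::isnan(vs1)) return vs2;
  // +0.0 and -0.0 compare equal, so pick by sign bit: -0.0 is the minimum.
  if (vs2 == 0.0f && vs1 == 0.0f) {
    return std::signbit(vs2) ? vs2 : vs1;
  }
  return vs1 < vs2 ? vs1 : vs2;
}

int main() {
  printf("%g\n", MinModel(0.0f, -0.0f));  // -0, not +0.
  printf("%g\n", MinModel(NAN, 2.5f));    // 2.5, not NaN.
  return 0;
}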
diff --git a/cheriot/riscv_cheriot_vector_fp_instructions.h b/cheriot/riscv_cheriot_vector_fp_instructions.h new file mode 100644 index 0000000..46d56b5 --- /dev/null +++ b/cheriot/riscv_cheriot_vector_fp_instructions.h
@@ -0,0 +1,85 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MPACT_CHERIOT_RISCV_CHERIOT_VECTOR_FP_INSTRUCTIONS_H_ +#define MPACT_CHERIOT_RISCV_CHERIOT_VECTOR_FP_INSTRUCTIONS_H_ + +#include "mpact/sim/generic/instruction.h" + +// This file declares the main binary and ternary floating point instruction +// semantic functions. + +namespace mpact { +namespace sim { +namespace cheriot { + +using Instruction = ::mpact::sim::generic::Instruction; + +// Vector floating point arithmetic instructions. Each of these instructions +// takes three source operands and one destination operand. Source 0 is a +// vector register group, source 1 is either a vector register group or a +// scalar register, and source 2 is the mask register. Destination 0 is a +// vector register group. +void Vfadd(const Instruction *inst); +void Vfsub(const Instruction *inst); +void Vfrsub(const Instruction *inst); +void Vfwadd(const Instruction *inst); +void Vfwsub(const Instruction *inst); +void Vfwaddw(const Instruction *inst); +void Vfwsubw(const Instruction *inst); +void Vfmul(const Instruction *inst); +void Vfdiv(const Instruction *inst); +void Vfrdiv(const Instruction *inst); +void Vfwmul(const Instruction *inst); + +// Vector floating point multiply and add/subtract instructions. Each of these +// instructions takes four source operands and one destination operand. Source +// 0 is a vector register group, source 1 is either a vector register group or +// a scalar register, source 2 is a vector register group, and source 3 is the +// mask register. Destination 0 is a vector register group. +void Vfmadd(const Instruction *inst); +void Vfnmadd(const Instruction *inst); +void Vfmsub(const Instruction *inst); +void Vfnmsub(const Instruction *inst); +void Vfmacc(const Instruction *inst); +void Vfnmacc(const Instruction *inst); +void Vfmsac(const Instruction *inst); +void Vfnmsac(const Instruction *inst); +void Vfwmacc(const Instruction *inst); +void Vfwnmacc(const Instruction *inst); +void Vfwmsac(const Instruction *inst); +void Vfwnmsac(const Instruction *inst); + +// Vector floating point sign modification instructions. Each of these +// instructions takes three source operands and one destination operand. +// Source 0 is a vector register group, source 1 is either a vector register +// group or a scalar register, and source 2 is the mask register. Destination 0 +// is a vector register group. +void Vfsgnj(const Instruction *inst); +void Vfsgnjn(const Instruction *inst); +void Vfsgnjx(const Instruction *inst); + +// Vector selection instructions. Each of these instructions takes three source +// operands and one destination operand. Source 0 is a vector register group, +// source 1 is either a vector register group or a scalar register, and source +// 2 is the mask register. Destination 0 is a vector register group. 
+void Vfmin(const Instruction *inst); +void Vfmax(const Instruction *inst); +void Vfmerge(const Instruction *inst); + +} // namespace cheriot +} // namespace sim +} // namespace mpact + +#endif // MPACT_CHERIOT_RISCV_CHERIOT_VECTOR_FP_INSTRUCTIONS_H_
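Note: the *madd and *macc families declared above differ only in which operand serves as the FMA addend: the madd/msub forms multiply vs1 by the destination vd and add or subtract vs2, while the macc/msac forms multiply vs1 by vs2 and accumulate into vd, matching the lambdas in the .cc. A minimal plain-C++ illustration (values arbitrary):

#include <cmath>
#include <cstdio>

int main() {
  double vs2 = 2.0, vs1 = 3.0, vd = 4.0;
  // vfmadd: vd = vs1 * vd + vs2 (vd is a multiplicand).
  printf("vfmadd: %g\n", std::fma(vs1, vd, vs2));  // 3*4 + 2 = 14
  // vfmacc: vd = vs1 * vs2 + vd (vd is the addend).
  printf("vfmacc: %g\n", std::fma(vs1, vs2, vd));  // 3*2 + 4 = 10
  return 0;
}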
diff --git a/cheriot/riscv_cheriot_vector_fp_reduction_instructions.cc b/cheriot/riscv_cheriot_vector_fp_reduction_instructions.cc new file mode 100644 index 0000000..cf72616 --- /dev/null +++ b/cheriot/riscv_cheriot_vector_fp_reduction_instructions.cc
@@ -0,0 +1,182 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cheriot/riscv_cheriot_vector_fp_reduction_instructions.h" + +#include <functional> + +#include "absl/log/log.h" +#include "cheriot/cheriot_state.h" +#include "cheriot/riscv_cheriot_vector_instruction_helpers.h" +#include "mpact/sim/generic/type_helpers.h" +#include "riscv//riscv_fp_host.h" +#include "riscv//riscv_fp_state.h" +#include "riscv//riscv_vector_state.h" + +namespace mpact { +namespace sim { +namespace cheriot { + +using ::mpact::sim::generic::FPTypeInfo; +using ::mpact::sim::riscv::ScopedFPStatus; + +// These reduction instructions take an accumulator and a value and return +// the result of the reduction operation. Each partial sum is stored to a +// separate entry in the destination vector. + +// Sum reduction. +void Vfredosum(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 4: + return RiscVBinaryReductionVectorOp<float, float, float>( + rv_vector, inst, + [](float acc, float vs2) -> float { return acc + vs2; }); + case 8: + return RiscVBinaryReductionVectorOp<double, double, double>( + rv_vector, inst, + [](double acc, double vs2) -> double { return acc + vs2; }); + default: + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal SEW value"; + return; + } +} + +void Vfwredosum(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 4: + return RiscVBinaryReductionVectorOp<double, float, double>( + rv_vector, inst, [](double acc, float vs2) -> double { + return acc + static_cast<double>(vs2); + }); + default: + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal SEW value"; + return; + } +} + +// Templated helper function for vfmin and vfmax instructions. +template <typename T> +inline T MaxMinHelper(T vs2, T vs1, std::function<T(T, T)> operation) { + // If either operand is a signaling NaN or if both operands are NaNs, then + // return a canonical (non-signaling) NaN. + if (FPTypeInfo<T>::IsSNaN(vs1) || FPTypeInfo<T>::IsSNaN(vs2) || + (FPTypeInfo<T>::IsNaN(vs2) && FPTypeInfo<T>::IsNaN(vs1))) { + typename FPTypeInfo<T>::UIntType c_nan = FPTypeInfo<T>::kCanonicalNaN; + return *reinterpret_cast<T *>(&c_nan); + } + // If either operand is a NaN return the other. 
+ if (FPTypeInfo<T>::IsNaN(vs2)) return vs1; + if (FPTypeInfo<T>::IsNaN(vs1)) return vs2; + // Return the min/max of the two operands. + return operation(vs2, vs1); +} + +// FP min reduction. +void Vfredmin(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 4: + return RiscVBinaryReductionVectorOp<float, float, float>( + rv_vector, inst, [](float acc, float vs2) -> float { + return MaxMinHelper<float>(acc, vs2, + [](float acc, float vs2) -> float { + return (acc > vs2) ? vs2 : acc; + }); + }); + case 8: + return RiscVBinaryReductionVectorOp<double, double, double>( + rv_vector, inst, [](double acc, double vs2) -> double { + return MaxMinHelper<double>(acc, vs2, + [](double acc, double vs2) -> double { + return (acc > vs2) ? vs2 : acc; + }); + }); + default: + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal SEW value"; + return; + } +} + +// FP max reduction. +void Vfredmax(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 4: + return RiscVBinaryReductionVectorOp<float, float, float>( + rv_vector, inst, [](float acc, float vs2) -> float { + return MaxMinHelper<float>(acc, vs2, + [](float acc, float vs2) -> float { + return (acc < vs2) ? vs2 : acc; + }); + }); + case 8: + return RiscVBinaryReductionVectorOp<double, double, double>( + rv_vector, inst, [](double acc, double vs2) -> double { + return MaxMinHelper<double>(acc, vs2, + [](double acc, double vs2) -> double { + return (acc < vs2) ? vs2 : acc; + }); + }); + default: + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal SEW value"; + return; + } +} + +} // namespace cheriot +} // namespace sim +} // namespace mpact
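Note: vfredosum above is the ordered sum reduction: elements are accumulated strictly in element order, which matters because floating point addition is not associative. A standalone sketch of why ordering is observable (OrderedSum is illustrative; the real helper also applies the mask and writes the scalar result to element 0 of the destination group):

#include <cstdio>
#include <vector>

// Accumulate strictly left to right: acc + vs2[0] + vs2[1] + ...
float OrderedSum(float acc, const std::vector<float> &vs2) {
  for (float v : vs2) acc += v;
  return acc;
}

int main() {
  // In float, 1e8 + 1 rounds back to 1e8 (ulp at 1e8 is 8), so the ordered
  // sum loses the 1; summing {1.0, 1e8, -1e8} instead would keep it.
  std::vector<float> v = {1e8f, 1.0f, -1e8f};
  printf("%g\n", OrderedSum(0.0f, v));  // Prints 0, not 1.
  return 0;
}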
diff --git a/cheriot/riscv_cheriot_vector_fp_reduction_instructions.h b/cheriot/riscv_cheriot_vector_fp_reduction_instructions.h new file mode 100644 index 0000000..bfb4b6c --- /dev/null +++ b/cheriot/riscv_cheriot_vector_fp_reduction_instructions.h
@@ -0,0 +1,42 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MPACT_CHERIOT_RISCV_CHERIOT_VECTOR_FP_REDUCTION_INSTRUCTIONS_H_ +#define MPACT_CHERIOT_RISCV_CHERIOT_VECTOR_FP_REDUCTION_INSTRUCTIONS_H_ + +#include "mpact/sim/generic/instruction.h" + +// This file declares the semantic functions for the vector floating point +// reduction instructions. + +namespace mpact { +namespace sim { +namespace cheriot { + +using Instruction = ::mpact::sim::generic::Instruction; + +// Each of these instruction semantic functions takes three source operands +// and one destination operand. Source 0 is a vector register group, source 1 +// is a vector register, and source 2 is the vector mask register. Destination +// operand 0 is a vector register group. +void Vfredosum(const Instruction *inst); +void Vfwredosum(const Instruction *inst); +void Vfredmin(const Instruction *inst); +void Vfredmax(const Instruction *inst); + +} // namespace cheriot +} // namespace sim +} // namespace mpact + +#endif  // MPACT_CHERIOT_RISCV_CHERIOT_VECTOR_FP_REDUCTION_INSTRUCTIONS_H_
diff --git a/cheriot/riscv_cheriot_vector_fp_unary_instructions.cc b/cheriot/riscv_cheriot_vector_fp_unary_instructions.cc new file mode 100644 index 0000000..06db4e3 --- /dev/null +++ b/cheriot/riscv_cheriot_vector_fp_unary_instructions.cc
@@ -0,0 +1,968 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cheriot/riscv_cheriot_vector_fp_unary_instructions.h" + +#include <cmath> +#include <cstdint> +#include <limits> +#include <tuple> + +#include "absl/log/log.h" +#include "cheriot/cheriot_state.h" +#include "cheriot/riscv_cheriot_instruction_helpers.h" +#include "cheriot/riscv_cheriot_vector_instruction_helpers.h" +#include "mpact/sim/generic/instruction.h" +#include "mpact/sim/generic/type_helpers.h" +#include "riscv//riscv_fp_host.h" +#include "riscv//riscv_fp_info.h" +#include "riscv//riscv_fp_state.h" +#include "riscv//riscv_vector_state.h" + +namespace mpact { +namespace sim { +namespace cheriot { + +using ::mpact::sim::generic::FPTypeInfo; +using ::mpact::sim::riscv::FPExceptions; +using ::mpact::sim::riscv::ScopedFPStatus; + +// These tables contain the 7 bits of mantissa used by the approximated +// reciprocal square root and reciprocal instructions. +static const int kRecipSqrtMantissaTable[128] = { + 52, 51, 50, 48, 47, 46, 44, 43, 42, 41, 40, 39, 38, 36, 35, + 34, 33, 32, 31, 30, 30, 29, 28, 27, 26, 25, 24, 23, 23, 22, + 21, 20, 19, 19, 18, 17, 16, 16, 15, 14, 14, 13, 12, 12, 11, + 10, 10, 9, 9, 8, 7, 7, 6, 6, 5, 4, 4, 3, 3, 2, + 2, 1, 1, 0, 127, 125, 123, 121, 119, 118, 116, 114, 113, 111, 109, + 108, 106, 105, 103, 102, 100, 99, 97, 96, 95, 93, 92, 91, 90, 88, + 87, 86, 85, 84, 83, 82, 80, 79, 78, 77, 76, 75, 74, 73, 72, + 71, 70, 70, 69, 68, 67, 66, 65, 64, 63, 63, 62, 61, 60, 59, + 59, 58, 57, 56, 56, 55, 54, 53, +}; + +static const int kRecipMantissaTable[128] = { + 127, 125, 123, 121, 119, 117, 116, 114, 112, 110, 109, 107, 105, 104, 102, + 100, 99, 97, 96, 94, 93, 91, 90, 88, 87, 85, 84, 83, 81, 80, + 79, 77, 76, 75, 74, 72, 71, 70, 69, 68, 66, 65, 64, 63, 62, + 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, + 46, 45, 44, 43, 42, 41, 40, 40, 39, 38, 37, 36, 35, 35, 34, + 33, 32, 31, 31, 30, 29, 28, 28, 27, 26, 25, 25, 24, 23, 23, + 22, 21, 21, 20, 19, 19, 18, 17, 17, 16, 15, 15, 14, 14, 13, + 12, 12, 11, 11, 10, 9, 9, 8, 8, 7, 7, 6, 5, 5, 4, + 4, 3, 3, 2, 2, 1, 1, 0}; + +// Move float from scalar fp register to vector register (all elements). 
+void Vfmvvf(const Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + const int vl = rv_vector->vector_length(); + if (rv_vector->vstart() > 0) return; + if (vl == 0) return; + + const int sew = rv_vector->selected_element_width(); + auto dest_op = + static_cast<RV32VectorDestinationOperand *>(inst->Destination(0)); + auto dest_db = dest_op->CopyDataBuffer(); + switch (sew) { + case 4: + for (int i = 0; i < vl; ++i) { + dest_db->Set<uint32_t>( + i, generic::GetInstructionSource<uint32_t>(inst, 0, 0)); + } + break; + case 8: + for (int i = 0; i < vl; ++i) { + dest_db->Set<uint64_t>( + i, generic::GetInstructionSource<uint64_t>(inst, 0, 0)); + } + break; + default: + dest_db->DecRef(); + LOG(ERROR) << "Vfmv.v.f: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } + dest_db->Submit(); + rv_vector->clear_vstart(); +} + +// Move float from scalar fp register to element 0 of the vector register. +void Vfmvsf(const Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (rv_vector->vstart() > 0) return; + if (rv_vector->vector_length() == 0) return; + int sew = rv_vector->selected_element_width(); + auto dest_op = + static_cast<RV32VectorDestinationOperand *>(inst->Destination(0)); + auto dest_db = dest_op->CopyDataBuffer(); + switch (sew) { + case 4: + dest_db->Set<uint32_t>( + 0, generic::GetInstructionSource<uint32_t>(inst, 0, 0)); + break; + case 8: + dest_db->Set<uint64_t>( + 0, generic::GetInstructionSource<uint64_t>(inst, 0, 0)); + break; + default: + dest_db->DecRef(); + LOG(ERROR) << "Vfmv.s.f: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } + dest_db->Submit(); + rv_vector->clear_vstart(); +} + +// Move element 0 of the vector register to the scalar fp register. +void Vfmvfs(const Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + auto dest_op = inst->Destination(0); + auto dest_db = dest_op->AllocateDataBuffer(); + int db_size = dest_db->size<uint8_t>(); + switch (sew) { + case 4: { + uint64_t value = generic::GetInstructionSource<uint32_t>(inst, 0, 0); + if (db_size == 4) { + dest_db->Set<uint32_t>(0, value); + } else if (db_size == 8) { + uint64_t val64 = 0xffff'ffff'0000'0000ULL | value; + dest_db->Set<uint64_t>(0, val64); + } else { + LOG(ERROR) << "Unexpected databuffer size in Vfmvfs"; + } + break; + } + case 8: + dest_db->Set<uint64_t>( + 0, generic::GetInstructionSource<uint64_t>(inst, 0, 0)); + break; + default: + dest_db->DecRef(); + LOG(ERROR) << "Vfmv.f.s: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } + dest_db->Submit(); + rv_vector->clear_vstart(); +} + +// Convert floating point to unsigned integer. 
+void Vfcvtxufv(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 4: + return RiscVUnaryVectorOpWithFflags<uint32_t, float>( + rv_vector, inst, [](float vs2) -> std::tuple<uint32_t, uint32_t> { + return CvtHelper<float, uint32_t>(vs2); + }); + case 8: + return RiscVUnaryVectorOpWithFflags<uint64_t, double>( + rv_vector, inst, [](double vs2) -> std::tuple<uint64_t, uint32_t> { + return CvtHelper<double, uint64_t>(vs2); + }); + default: + LOG(ERROR) << "Vfcvt.xu.fv: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Convert floating point to signed integer. +void Vfcvtxfv(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 4: + return RiscVUnaryVectorOpWithFflags<int32_t, float>( + rv_vector, inst, [](float vs2) -> std::tuple<int32_t, uint32_t> { + return CvtHelper<float, int32_t>(vs2); + }); + case 8: + return RiscVUnaryVectorOpWithFflags<int64_t, double>( + rv_vector, inst, [](double vs2) -> std::tuple<int64_t, uint32_t> { + return CvtHelper<double, int64_t>(vs2); + }); + default: + LOG(ERROR) << "Vfcvt.x.fv: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Convert unsigned integer to floating point. +void Vfcvtfxuv(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 4: + return RiscVUnaryVectorOp<float, uint32_t>( + rv_vector, inst, + [](uint32_t vs2) -> float { return static_cast<float>(vs2); }); + case 8: + return RiscVUnaryVectorOp<double, uint64_t>( + rv_vector, inst, + [](uint64_t vs2) -> double { return static_cast<double>(vs2); }); + default: + LOG(ERROR) << "Vfcvt.f.xuv: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Convert signed integer to floating point. 
+void Vfcvtfxv(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 4: + return RiscVUnaryVectorOp<float, int32_t>( + rv_vector, inst, + [](int32_t vs2) -> float { return static_cast<float>(vs2); }); + case 8: + return RiscVUnaryVectorOp<double, int64_t>( + rv_vector, inst, + [](int64_t vs2) -> double { return static_cast<double>(vs2); }); + default: + LOG(ERROR) << "Vfcvt.f.xv: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Convert floating point to unsigned integer with truncation. +void Vfcvtrtzxufv(const Instruction *inst) { + auto *rv_state = static_cast<CheriotState *>(inst->state()); + auto *rv_vector = rv_state->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 4: + return RiscVUnaryVectorOpWithFflags<uint32_t, float>( + rv_vector, inst, [](float vs2) -> std::tuple<uint32_t, uint32_t> { + return CvtHelper<float, uint32_t>(vs2); + }); + case 8: + return RiscVUnaryVectorOpWithFflags<uint64_t, double>( + rv_vector, inst, [](double vs2) -> std::tuple<uint64_t, uint32_t> { + return CvtHelper<double, uint64_t>(vs2); + }); + default: + LOG(ERROR) << "Vfcvt.rtz.xu.fv: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Convert floating point to signed integer with truncation. +void Vfcvtrtzxfv(const Instruction *inst) { + auto *rv_state = static_cast<CheriotState *>(inst->state()); + auto *rv_vector = rv_state->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 4: + return RiscVUnaryVectorOpWithFflags<int32_t, float>( + rv_vector, inst, [](float vs2) -> std::tuple<int32_t, uint32_t> { + return CvtHelper<float, int32_t>(vs2); + }); + case 8: + return RiscVUnaryVectorOpWithFflags<int64_t, double>( + rv_vector, inst, [](double vs2) -> std::tuple<int64_t, uint32_t> { + return CvtHelper<double, int64_t>(vs2); + }); + default: + LOG(ERROR) << "Vfcvt.rtz.x.fv: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Widening conversion of floating point to unsigned integer. +void Vfwcvtxufv(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 4: + return RiscVUnaryVectorOpWithFflags<uint64_t, float>( + rv_vector, inst, [](float vs2) -> std::tuple<uint64_t, uint32_t> { + return CvtHelper<float, uint64_t>(vs2); + }); + default: + LOG(ERROR) << "Vfwcvt.xu.fv: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Widening conversion of floating point to signed integer. 
+void Vfwcvtxfv(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 4: + return RiscVUnaryVectorOpWithFflags<int64_t, float>( + rv_vector, inst, [](float vs2) -> std::tuple<int64_t, uint32_t> { + return CvtHelper<float, int64_t>(vs2); + }); + default: + LOG(ERROR) << "Vfwcvt.x.fv: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Widening conversion of floating point to floating point. +void Vfwcvtffv(const Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 4: + return RiscVUnaryVectorOp<double, float>( + rv_vector, inst, + [](float vs2) -> double { return static_cast<double>(vs2); }); + default: + LOG(ERROR) << "Vfwcvt.f.fv: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Widening conversion of unsigned integer to floating point. +void Vfwcvtfxuv(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 2: + return RiscVUnaryVectorOp<float, uint16_t>( + rv_vector, inst, + [](uint16_t vs2) -> float { return static_cast<float>(vs2); }); + case 4: + return RiscVUnaryVectorOp<double, uint32_t>( + rv_vector, inst, + [](uint32_t vs2) -> double { return static_cast<double>(vs2); }); + default: + LOG(ERROR) << "Vfwcvt.f.xuv: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Widening conversion of signed integer to floating point. +void Vfwcvtfxv(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 2: + return RiscVUnaryVectorOp<float, int16_t>( + rv_vector, inst, + [](int16_t vs2) -> float { return static_cast<float>(vs2); }); + case 4: + return RiscVUnaryVectorOp<double, int32_t>( + rv_vector, inst, + [](int32_t vs2) -> double { return static_cast<double>(vs2); }); + default: + LOG(ERROR) << "Vfwcvt.f.xv: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Widening conversion of floating point to unsigned integer with truncation. 
+void Vfwcvtrtzxufv(const Instruction *inst) { + auto *rv_state = static_cast<CheriotState *>(inst->state()); + auto *rv_vector = rv_state->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 4: + return RiscVUnaryVectorOpWithFflags<uint64_t, float>( + rv_vector, inst, [](float vs2) -> std::tuple<uint64_t, uint32_t> { + return CvtHelper<float, uint64_t>(vs2); + }); + default: + LOG(ERROR) << "Vfwcvt.rtz.xu.fv: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Widening conversion of floating point to signed integer with truncation. +void Vfwcvtrtzxfv(const Instruction *inst) { + auto *rv_state = static_cast<CheriotState *>(inst->state()); + auto *rv_vector = rv_state->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 4: + return RiscVUnaryVectorOpWithFflags<int64_t, float>( + rv_vector, inst, [](float vs2) -> std::tuple<int64_t, uint32_t> { + return CvtHelper<float, int64_t>(vs2); + }); + default: + LOG(ERROR) << "Vfwcvt.rtz.x.fv: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Narrowing conversion of floating point to unsigned integer. +void Vfncvtxufw(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 4: + return RiscVUnaryVectorOpWithFflags<uint16_t, float>( + rv_vector, inst, [](float vs2) -> std::tuple<uint16_t, uint32_t> { + return CvtHelper<float, uint16_t>(vs2); + }); + case 8: + return RiscVUnaryVectorOpWithFflags<uint32_t, double>( + rv_vector, inst, [](double vs2) -> std::tuple<uint32_t, uint32_t> { + return CvtHelper<double, uint32_t>(vs2); + }); + default: + LOG(ERROR) << "Vfncvt.xu.fw: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Narrowing conversion of floating point to signed integer. +void Vfncvtxfw(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 4: + return RiscVUnaryVectorOpWithFflags<int16_t, float>( + rv_vector, inst, [](float vs2) -> std::tuple<int16_t, uint32_t> { + return CvtHelper<float, int16_t>(vs2); + }); + case 8: + return RiscVUnaryVectorOpWithFflags<int32_t, double>( + rv_vector, inst, [](double vs2) -> std::tuple<int32_t, uint32_t> { + return CvtHelper<double, int32_t>(vs2); + }); + default: + LOG(ERROR) << "Vfncvt.x.fw: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Narrowing conversion of floating point to floating point. 
+void Vfncvtffw(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 8: + return RiscVUnaryVectorOp<float, double>( + rv_vector, inst, + [](double vs2) -> float { return static_cast<float>(vs2); }); + default: + LOG(ERROR) << "Vfncvt.f.fw: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Narrowing conversion of floating point to floating point rounding to odd. +void Vfncvtrodffw(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + // The rounding mode is round to odd, which means that the lsb of the new + // mantissa is either 1 or it is the logical or of all the bits to the right + // in the original width mantissa. + switch (sew) { + case 8: + return RiscVUnaryVectorOp<float, double>( + rv_vector, inst, [](double vs2) -> float { + if (FPTypeInfo<double>::IsNaN(vs2) || + FPTypeInfo<double>::IsInf(vs2)) { + return static_cast<float>(vs2); + } + using UIntD = typename FPTypeInfo<double>::UIntType; + using UIntF = typename FPTypeInfo<float>::UIntType; + UIntD uval = *reinterpret_cast<UIntD *>(&vs2); + int sig_diff = + FPTypeInfo<double>::kSigSize - FPTypeInfo<float>::kSigSize; + UIntD mask = (1ULL << sig_diff) - 1; + UIntF bit = (mask & uval) != 0; + auto res = static_cast<float>(vs2); + if (FPTypeInfo<float>::IsInf(res)) return res; + UIntF ures = *reinterpret_cast<UIntF *>(&res); + ures |= bit; + return *reinterpret_cast<float *>(&ures); + }); + default: + LOG(ERROR) << "Vfncvt.rod.f.fw: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Narrowing conversion of unsigned integer to floating point. +void Vfncvtfxuw(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 8: + return RiscVUnaryVectorOp<float, uint64_t>( + rv_vector, inst, + [](uint64_t vs2) -> float { return static_cast<float>(vs2); }); + default: + LOG(ERROR) << "Vfncvt.f.xuw: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Narrowing conversion of signed integer to floating point. 
+void Vfncvtfxw(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 8: + return RiscVUnaryVectorOp<float, int64_t>( + rv_vector, inst, + [](int64_t vs2) -> float { return static_cast<float>(vs2); }); + default: + LOG(ERROR) << "Vfncvt.f.xw: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Narrowing conversion of floating point to unsigned integer with truncation. +void Vfncvtrtzxufw(const Instruction *inst) { + auto *rv_state = static_cast<CheriotState *>(inst->state()); + auto *rv_vector = rv_state->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 4: + return RiscVUnaryVectorOpWithFflags<uint16_t, float>( + rv_vector, inst, [](float vs2) -> std::tuple<uint16_t, uint32_t> { + return CvtHelper<float, uint16_t>(vs2); + }); + case 8: + return RiscVUnaryVectorOpWithFflags<uint32_t, double>( + rv_vector, inst, [](double vs2) -> std::tuple<uint32_t, uint32_t> { + return CvtHelper<double, uint32_t>(vs2); + }); + default: + LOG(ERROR) << "Vfncvt.rtz.xu.fw: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Narrowing conversion of floating point to signed integer with truncation. +void Vfncvtrtzxfw(const Instruction *inst) { + auto *rv_state = static_cast<CheriotState *>(inst->state()); + auto *rv_vector = rv_state->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 4: + return RiscVUnaryVectorOpWithFflags<int16_t, float>( + rv_vector, inst, [](float vs2) -> std::tuple<int16_t, uint32_t> { + return CvtHelper<float, int16_t>(vs2); + }); + case 8: + return RiscVUnaryVectorOpWithFflags<int32_t, double>( + rv_vector, inst, [](double vs2) -> std::tuple<int32_t, uint32_t> { + return CvtHelper<double, int32_t>(vs2); + }); + default: + LOG(ERROR) << "Vfncvt.rtz.x.fw: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Templated helper function to compute square root. +template <typename T> +inline std::tuple<T, uint32_t> SqrtHelper(T vs2) { + uint32_t flags = 0; + T res; + if (FPTypeInfo<T>::IsNaN(vs2) || vs2 < 0.0) { + auto value = FPTypeInfo<T>::kCanonicalNaN; + res = *reinterpret_cast<T *>(&value); + flags = *FPExceptions::kInvalidOp; + return std::tie(res, flags); + } + if (vs2 == 0.0) return std::tie(vs2, flags); + res = sqrt(vs2); + return std::tie(res, flags); +} + +// Square root. 
+void Vfsqrtv(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + uint32_t flags = 0; + { + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + switch (sew) { + case 4: + RiscVUnaryVectorOp<float, float>(rv_vector, inst, + [&flags](float vs2) -> float { + auto [res, f] = SqrtHelper(vs2); + flags |= f; + return res; + }); + break; + case 8: + RiscVUnaryVectorOp<double, double>(rv_vector, inst, + [&flags](double vs2) -> double { + auto [res, f] = SqrtHelper(vs2); + flags |= f; + return res; + }); + break; + default: + LOG(ERROR) << "Vfsqrt.v: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } + } + auto *fflags = rv_fp->fflags(); + fflags->Write(flags | fflags->AsUint32()); +} + +// Templated helper function to compute the Reciprocal Square Root +// approximation for valid inputs. +template <typename T> +inline T RecipSqrt7(T value) { + using Uint = typename FPTypeInfo<T>::UIntType; + Uint uint_value = *reinterpret_cast<Uint *>(&value); + // The input value is positive. Negative values are already handled. + int norm_exponent = + (uint_value & FPTypeInfo<T>::kExpMask) >> FPTypeInfo<T>::kSigSize; + Uint norm_mantissa = uint_value & FPTypeInfo<T>::kSigMask; + if (norm_exponent == 0) { // The value is a denormal. + Uint mask = static_cast<Uint>(1) << (FPTypeInfo<T>::kSigSize - 1); + // Normalize the mantissa and exponent by shifting the mantissa left until + // the most significant bit is one. + while ((norm_mantissa & mask) == 0) { + norm_exponent--; + norm_mantissa <<= 1; + } + // Shift it left once more - so it becomes the "implied" bit, and not used + // in the lookup below. + norm_mantissa <<= 1; + } + int index = (norm_exponent & 0b1) << 6 | + ((norm_mantissa >> (FPTypeInfo<T>::kSigSize - 6)) & 0b11'1111); + Uint new_mantissa = static_cast<Uint>(kRecipSqrtMantissaTable[index]) + << (FPTypeInfo<T>::kSigSize - 7); + Uint new_exponent = (3 * FPTypeInfo<T>::kExpBias - 1 - norm_exponent) / 2; + Uint new_value = (new_exponent << FPTypeInfo<T>::kSigSize) | new_mantissa; + T new_fp_value = *reinterpret_cast<T *>(&new_value); + return new_fp_value; +} + +// Templated helper function to compute the Reciprocal Square Root +// approximation for all values. +template <typename T> +inline std::tuple<T, uint32_t> RecipSqrt7Helper(T value) { + auto fp_class = std::fpclassify(value); + T return_value = std::numeric_limits<T>::quiet_NaN(); + uint32_t fflags = 0; + switch (fp_class) { + case FP_INFINITE: + return_value = + std::signbit(value) ? std::numeric_limits<T>::quiet_NaN() : 0.0; + fflags = (uint32_t)FPExceptions::kInvalidOp; + break; + case FP_NAN: + // Just propagate the NaN. + return_value = std::numeric_limits<T>::quiet_NaN(); + fflags = (uint32_t)FPExceptions::kInvalidOp; + break; + case FP_ZERO: + return_value = std::signbit(value) ? 
-std::numeric_limits<T>::infinity() + : std::numeric_limits<T>::infinity(); + fflags = (uint32_t)FPExceptions::kDivByZero; + break; + case FP_SUBNORMAL: + case FP_NORMAL: + if (std::signbit(value)) { + return_value = std::numeric_limits<T>::quiet_NaN(); + fflags = (uint32_t)FPExceptions::kInvalidOp; + } else { + return_value = RecipSqrt7(value); + } + break; + default: + LOG(ERROR) << "RecipSqrt7Helper: Illegal fp_class (" << fp_class << ")"; + break; + } + return std::make_tuple(return_value, fflags); +} + +// Approximation of reciprocal square root to 7 bits mantissa. +void Vfrsqrt7v(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 4: + return RiscVUnaryVectorOpWithFflags<float, float>( + rv_vector, inst, [rv_fp](float vs2) -> std::tuple<float, uint32_t> { + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + return RecipSqrt7Helper(vs2); + }); + case 8: + return RiscVUnaryVectorOpWithFflags<double, double>( + rv_vector, inst, [rv_fp](double vs2) -> std::tuple<double, uint32_t> { + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + return RecipSqrt7Helper(vs2); + }); + default: + LOG(ERROR) << "vfrsqrt7.v: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Templated helper function to compute the Reciprocal approximation for valid +// normal floating point inputs. +template <typename T> +inline T Recip7(T value, FPRoundingMode rm) { + using Uint = typename FPTypeInfo<T>::UIntType; + using Int = typename FPTypeInfo<T>::IntType; + Uint uint_value = *reinterpret_cast<Uint *>(&value); + Int norm_exponent = + (uint_value & FPTypeInfo<T>::kExpMask) >> FPTypeInfo<T>::kSigSize; + Uint norm_mantissa = uint_value & FPTypeInfo<T>::kSigMask; + if (norm_exponent == 0) { // The value is a denormal. + Uint msb = static_cast<Uint>(1) << (FPTypeInfo<T>::kSigSize - 1); + // Normalize the mantissa and exponent by shifting the mantissa left until + // the most significant bit is one. + while (norm_mantissa && ((norm_mantissa & msb) == 0)) { + norm_exponent--; + norm_mantissa <<= 1; + } + // Shift it left once more - so it becomes the "implied" bit, and not used + // in the lookup below. + norm_mantissa <<= 1; + } + Int new_exponent = 2 * FPTypeInfo<T>::kExpBias - 1 - norm_exponent; + // If the exponent is too high, then return exceptional values. + if (new_exponent > 2 * FPTypeInfo<T>::kExpBias) { + switch (rm) { + case FPRoundingMode::kRoundDown: + return std::signbit(value) ? -std::numeric_limits<T>::infinity() + : std::numeric_limits<T>::max(); + case FPRoundingMode::kRoundTowardsZero: + return std::signbit(value) ? std::numeric_limits<T>::lowest() + : std::numeric_limits<T>::max(); + case FPRoundingMode::kRoundToNearestTiesToMax: + case FPRoundingMode::kRoundToNearest: + return std::signbit(value) ? -std::numeric_limits<T>::infinity() + : std::numeric_limits<T>::infinity(); + case FPRoundingMode::kRoundUp: + return std::signbit(value) ? std::numeric_limits<T>::lowest() + : std::numeric_limits<T>::infinity(); + default: + // kDynamic can't happen. + return std::numeric_limits<T>::quiet_NaN(); + } + } + // Perform table lookup and compute the new value using the new exponent. 
+ int index = (norm_mantissa >> (FPTypeInfo<T>::kSigSize - 7)) & 0b111'1111; + Uint new_mantissa = static_cast<Uint>(kRecipMantissaTable[index]) + << (FPTypeInfo<T>::kSigSize - 7); + // If the new exponent is negative or 0, the result is denormal. First + // shift the mantissa right and or in the implied '1'. + if (new_exponent <= 0) { + new_mantissa = (new_mantissa >> 1) | 0b100'0000; + // If the exponent is less than 0, shift the mantissa right. + if (new_exponent < 0) { + new_mantissa >>= 1; + new_exponent = 0; + } + new_mantissa &= 0b111'1111; + } + Uint new_value = (new_exponent << FPTypeInfo<T>::kSigSize) | new_mantissa; + T new_fp_value = *reinterpret_cast<T *>(&new_value); + return value < 0.0 ? -new_fp_value : new_fp_value; +} + +// Templated helper function to compute the Reciprocal approximation for all +// values including non-normal floating point values. +template <typename T> +inline T Recip7Helper(T value, FPRoundingMode rm) { + auto fp_class = std::fpclassify(value); + + switch (fp_class) { + case FP_INFINITE: + // TODO: raise exception. + return std::signbit(value) ? -0.0 : 0; + case FP_NAN: + // Just propagate the NaN. + return std::numeric_limits<T>::quiet_NaN(); + case FP_ZERO: + return std::signbit(value) ? -std::numeric_limits<T>::infinity() + : std::numeric_limits<T>::infinity(); + case FP_SUBNORMAL: + case FP_NORMAL: + return Recip7(value, rm); + } + return std::numeric_limits<T>::quiet_NaN(); +} + +// Approximate reciprocal to 7 bits of mantissa. +void Vfrec7v(const Instruction *inst) { + auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); + auto rm = rv_fp->GetRoundingMode(); + switch (sew) { + case 4: + return RiscVUnaryVectorOp<float, float>( + rv_vector, inst, + [rm](float vs2) -> float { return Recip7Helper(vs2, rm); }); + case 8: + return RiscVUnaryVectorOp<double, double>( + rv_vector, inst, + [rm](double vs2) -> double { return Recip7Helper(vs2, rm); }); + default: + LOG(ERROR) << "vfrec7.v: Illegal sew (" << sew << ")"; + rv_vector->set_vector_exception(); + return; + } +} + +// Classify floating point value. +void Vfclassv(const Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 4: + return RiscVUnaryVectorOp<uint32_t, float>( + rv_vector, inst, [](float vs2) -> uint32_t { + return static_cast<uint32_t>(ClassifyFP(vs2)); + }); + case 8: + return RiscVUnaryVectorOp<uint64_t, double>( + rv_vector, inst, [](double vs2) -> uint64_t { + return static_cast<uint64_t>(ClassifyFP(vs2)); + }); + default: + rv_vector->set_vector_exception(); + LOG(ERROR) << "vfclass.v: Illegal sew (" << sew << ")"; + return; + } +} + +} // namespace cheriot +} // namespace sim +} // namespace mpact
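Note on the round-to-odd narrowing above: the lambda in Vfncvtrodffw forces the low bit of the narrowed mantissa to record whether any of the discarded wide-mantissa bits were set, which is what keeps a later re-rounding of the result from double rounding. A minimal standalone sketch of the same technique, using std::memcpy for the bit casts instead of the reinterpret_cast style of the simulator code (names here are illustrative and not part of the patch):

    #include <cmath>
    #include <cstdint>
    #include <cstring>

    // Narrow a double to float with round-to-odd: OR the sticky bit formed
    // from the 29 discarded significand bits into the lsb of the result.
    float RoundToOddNarrow(double wide) {
      if (std::isnan(wide) || std::isinf(wide)) return static_cast<float>(wide);
      uint64_t uval;
      std::memcpy(&uval, &wide, sizeof(uval));
      constexpr int kSigDiff = 52 - 23;  // double vs float significand bits.
      const uint32_t sticky = (uval & ((1ULL << kSigDiff) - 1)) != 0;
      float narrow = static_cast<float>(wide);
      if (std::isinf(narrow)) return narrow;  // Overflowed to infinity.
      uint32_t ures;
      std::memcpy(&ures, &narrow, sizeof(ures));
      ures |= sticky;  // Force the lsb to record the discarded bits.
      std::memcpy(&narrow, &ures, sizeof(narrow));
      return narrow;
    }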
diff --git a/cheriot/riscv_cheriot_vector_fp_unary_instructions.h b/cheriot/riscv_cheriot_vector_fp_unary_instructions.h new file mode 100644 index 0000000..ad607d3 --- /dev/null +++ b/cheriot/riscv_cheriot_vector_fp_unary_instructions.h
@@ -0,0 +1,115 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MPACT_CHERIOT_RISCV_CHERIOT_VECTOR_FP_UNARY_INSTRUCTIONS_H_ +#define MPACT_CHERIOT_RISCV_CHERIOT_VECTOR_FP_UNARY_INSTRUCTIONS_H_ + +#include "mpact/sim/generic/instruction.h" + +// This file lists the semantic functions for RiscV vector unary floating point +// instructions. + +namespace mpact { +namespace sim { +namespace cheriot { + +using Instruction = ::mpact::sim::generic::Instruction; + +// Splat a single floating point value from the scalar fp register to vector +// elements [0, vl-1]. +void Vfmvvf(const Instruction *inst); +// Move a single floating point value from the scalar fp register to vector[0]. +void Vfmvsf(const Instruction *inst); +// Move a single floating point value from vector[0] to the scalar fp register. +void Vfmvfs(const Instruction *inst); + +// Each of the following semantic functions takes 2 source operands and 1 +// destination operand. Source operand 0 is a vector register group, source +// operand 1 is a vector mask register, and destination operand 0 is a vector +// register group. + +// Vector element conversion instructions. These convert same sized values +// from/to signed/unsigned integer and floating point. The 'rtz' versions use +// truncation (round to zero), whereas the others use the dynamically set +// rounding mode. + +// FP to unsigned integer. +void Vfcvtxufv(const Instruction *inst); +// FP to signed integer. +void Vfcvtxfv(const Instruction *inst); +// Unsigned integer to FP. +void Vfcvtfxuv(const Instruction *inst); +// Signed integer to FP. +void Vfcvtfxv(const Instruction *inst); +// FP to unsigned integer using round to zero. +void Vfcvtrtzxufv(const Instruction *inst); +// FP to signed integer using round to zero. +void Vfcvtrtzxfv(const Instruction *inst); + +// Vector element widening conversion instructions. These convert values from/to +// signed/unsigned integer and floating point, where the resulting value is 2x +// the width of the source value. The 'rtz' versions use truncation (round to +// zero), whereas the others use the dynamically set rounding mode. + +// FP to wider unsigned integer. +void Vfwcvtxufv(const Instruction *inst); +// FP to wider signed integer. +void Vfwcvtxfv(const Instruction *inst); +// FP to next wider FP. +void Vfwcvtffv(const Instruction *inst); +// Unsigned integer to wider FP. +void Vfwcvtfxuv(const Instruction *inst); +// Signed integer to wider FP. +void Vfwcvtfxv(const Instruction *inst); +// FP to wider unsigned integer using round to zero. +void Vfwcvtrtzxufv(const Instruction *inst); +// FP to wider signed integer using round to zero. +void Vfwcvtrtzxfv(const Instruction *inst); + +// Vector element narrowing conversion instructions. These convert values from/to +// signed/unsigned integer and floating point, where the resulting value is 1/2x +// the width of the source value. 
The 'rtz' versions use truncation (round to +// zero), the 'rod' version uses 'round to odd', whereas the others use the +// dynamically set rounding mode. + +// FP to narrower unsigned integer. +void Vfncvtxufw(const Instruction *inst); +// FP to narrower signed integer. +void Vfncvtxfw(const Instruction *inst); +// FP to next narrower FP. +void Vfncvtffw(const Instruction *inst); +// FP to next narrower FP with round to odd. +void Vfncvtrodffw(const Instruction *inst); +// Unsigned integer to narrower FP. +void Vfncvtfxuw(const Instruction *inst); +// Signed integer to narrower FP. +void Vfncvtfxw(const Instruction *inst); +// FP to narrower unsigned integer using round to zero. +void Vfncvtrtzxufw(const Instruction *inst); +// FP to narrower signed integer using round to zero. +void Vfncvtrtzxfw(const Instruction *inst); + +// Vector element square root instruction. +void Vfsqrtv(const Instruction *inst); +// Vector element approximate reciprocal square root instruction. +void Vfrsqrt7v(const Instruction *inst); +// Vector element approximate reciprocal instruction. +void Vfrec7v(const Instruction *inst); +// Vector element floating point value classify instruction. +void Vfclassv(const Instruction *inst); + +} // namespace cheriot +} // namespace sim +} // namespace mpact + +#endif // MPACT_CHERIOT_RISCV_CHERIOT_VECTOR_FP_UNARY_INSTRUCTIONS_H_
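For orientation, the semantic function names above are the RVV mnemonics with the dots removed; a few representative mappings (the binding of these functions to opcodes is assumed to live in the riscv_cheriot_rvv .isa/.bin_fmt sources):

    vfcvt.xu.f.v      -> Vfcvtxufv     (SEW FP   -> SEW unsigned int)
    vfcvt.rtz.x.f.v   -> Vfcvtrtzxfv   (SEW FP   -> SEW signed int, truncating)
    vfwcvt.f.f.v      -> Vfwcvtffv     (SEW FP   -> 2*SEW FP)
    vfncvt.rod.f.f.w  -> Vfncvtrodffw  (2*SEW FP -> SEW FP, round to odd)
    vfrsqrt7.v        -> Vfrsqrt7v     (7-bit reciprocal square root estimate)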
diff --git a/cheriot/riscv_cheriot_vector_instruction_helpers.h b/cheriot/riscv_cheriot_vector_instruction_helpers.h new file mode 100644 index 0000000..b7640e0 --- /dev/null +++ b/cheriot/riscv_cheriot_vector_instruction_helpers.h
@@ -0,0 +1,748 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MPACT_CHERIOT_RISCV_CHERIOT_VECTOR_INSTRUCTION_HELPERS_H_ +#define MPACT_CHERIOT_RISCV_CHERIOT_VECTOR_INSTRUCTION_HELPERS_H_ + +#include <algorithm> +#include <cstdint> +#include <functional> +#include <limits> +#include <optional> +#include <tuple> + +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" +#include "cheriot/cheriot_vector_state.h" +#include "mpact/sim/generic/instruction.h" +#include "mpact/sim/generic/type_helpers.h" +#include "riscv//riscv_fp_host.h" +#include "riscv//riscv_fp_info.h" +#include "riscv//riscv_register.h" +#include "riscv//riscv_state.h" + +namespace mpact { +namespace sim { +namespace cheriot { + +using ::mpact::sim::cheriot::CheriotVectorState; +using ::mpact::sim::generic::FPTypeInfo; +using ::mpact::sim::generic::GetInstructionSource; +using ::mpact::sim::generic::Instruction; +using ::mpact::sim::riscv::FPExceptions; +using ::mpact::sim::riscv::RV32VectorDestinationOperand; +using ::mpact::sim::riscv::RV32VectorSourceOperand; +using ::mpact::sim::riscv::ScopedFPStatus; +using ::mpact::sim::riscv::VectorLoadContext; + +// This helper function handles instructions that target a vector mask and +// consume the mask value in the operation itself, such as carry generation +// from add with carry. +// Note that this function writes every destination mask bit, regardless of +// the mask value. +template <typename Vs2, typename Vs1> +void RiscVSetMaskBinaryVectorMaskOp(CheriotVectorState *rv_vector, + const Instruction *inst, + std::function<bool(Vs2, Vs1, bool)> op) { + if (rv_vector->vector_exception()) return; + auto *dest_op = + static_cast<RV32VectorDestinationOperand *>(inst->Destination(0)); + // Get the vector start element index and compute where to start + // the operation. + const int num_elements = rv_vector->vector_length(); + const int vector_index = rv_vector->vstart(); + // Allocate data buffer for the new register data. + auto *dest_db = dest_op->CopyDataBuffer(); + auto dest_span = dest_db->Get<uint8_t>(); + // Determine if it's vector-vector or vector-scalar. + const bool vector_scalar = inst->Source(1)->shape()[0] == 1; + // Get the vector mask. + auto *mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(2)); + bool vm_unmasked_bit = false; + if (inst->SourcesSize() > 3) { + vm_unmasked_bit = GetInstructionSource<bool>(inst, 3); + } + const bool mask_used = !vm_unmasked_bit; + auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); + for (int i = vector_index; i < num_elements; i++) { + const int mask_index = i >> 3; + const int mask_offset = i & 0b111; + const bool mask_value = ((mask_span[mask_index] >> mask_offset) & 0b1) != 0; + const Vs2 vs2 = GetInstructionSource<Vs2>(inst, 0, i); + const Vs1 vs1 = GetInstructionSource<Vs1>(inst, 1, vector_scalar ? 0 : i); + + // Clear the masked register bit. 
+ dest_span[mask_index] &= ~(1 << mask_offset); + + // Mask value is used only when `vm_unmasked_bit` is 0. + dest_span[mask_index] |= + (op(vs2, vs1, mask_used & mask_value) << mask_offset); + } + // Submit the destination db . + dest_db->Submit(); + rv_vector->clear_vstart(); +} + +// This helper function handles the case of instructions that target a vector +// mask and uses the mask value in the instruction, such as carry generation +// from add with carry. +template <typename Vs2, typename Vs1> +void RiscVMaskBinaryVectorMaskOp(CheriotVectorState *rv_vector, + const Instruction *inst, + std::function<bool(Vs2, Vs1, bool)> op) { + if (rv_vector->vector_exception()) return; + auto *dest_op = + static_cast<RV32VectorDestinationOperand *>(inst->Destination(0)); + // Get the vector start element index and compute where to start + // the operation. + int num_elements = rv_vector->vector_length(); + int vector_index = rv_vector->vstart(); + // Allocate data buffer for the new register data. + auto *dest_db = dest_op->CopyDataBuffer(); + auto dest_span = dest_db->Get<uint8_t>(); + // Determine if it's vector-vector or vector-scalar. + bool vector_scalar = inst->Source(1)->shape()[0] == 1; + // Get the vector mask. + auto *mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(2)); + bool vm_unmasked_bit = false; + if (inst->SourcesSize() > 3) { + vm_unmasked_bit = GetInstructionSource<bool>(inst, 3); + } + const bool mask_used = !vm_unmasked_bit; + auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); + for (int i = vector_index; i < num_elements; i++) { + int mask_index = i >> 3; + int mask_offset = i & 0b111; + bool mask_value = ((mask_span[mask_index] >> mask_offset) & 0b1) != 0; + if (mask_used && !mask_value) { + continue; + } + + Vs2 vs2 = GetInstructionSource<Vs2>(inst, 0, i); + Vs1 vs1 = GetInstructionSource<Vs1>(inst, 1, vector_scalar ? 0 : i); + + // Clear the masked register bit. + dest_span[mask_index] &= ~(1 << mask_offset); + + // Mask value is used only when `vm_unmasked_bit` is 0. + dest_span[mask_index] |= + (op(vs2, vs1, mask_used & mask_value) << mask_offset); + } + // Submit the destination db . + dest_db->Submit(); + rv_vector->clear_vstart(); +} + +// This helper function handles the case of vector mask +// operations. +template <typename Vs2, typename Vs1> +void RiscVBinaryVectorMaskOp(CheriotVectorState *rv_vector, + const Instruction *inst, + std::function<bool(Vs2, Vs1)> op) { + RiscVMaskBinaryVectorMaskOp<Vs2, Vs1>( + rv_vector, inst, [op](Vs2 vs2, Vs1 vs1, bool mask_value) -> bool { + if (mask_value) { + return op(vs2, vs1); + } + return false; + }); +} + +// This helper function handles the case of nullary vector +// operations. It implements all the checking necessary for both widening and +// narrowing operations. +template <typename Vd> +void RiscVMaskNullaryVectorOp(CheriotVectorState *rv_vector, + const Instruction *inst, + std::function<Vd(bool)> op) { + if (rv_vector->vector_exception()) return; + int num_elements = rv_vector->vector_length(); + int elements_per_vector = + rv_vector->vector_register_byte_length() / sizeof(Vd); + int max_regs = (num_elements + elements_per_vector - 1) / elements_per_vector; + auto *dest_op = + static_cast<RV32VectorDestinationOperand *>(inst->Destination(0)); + // Verify that there are enough registers in the destination operand. 
+ if (dest_op->size() < max_regs) { + rv_vector->set_vector_exception(); + LOG(ERROR) << absl::StrCat( + "Vector destination '", dest_op->AsString(), "' has fewer registers (", + dest_op->size(), ") than required by the operation (", max_regs, ")"); + return; + } + // There are 2 types of instructions with different numbers of source + // operands. + // 1. inst vd, vs2, vmask (viota instruction) + // 2. inst vd, vmask (vid instruction) + RV32VectorSourceOperand *vs2_op = nullptr; + RV32VectorSourceOperand *mask_op = nullptr; + if (inst->SourcesSize() > 1) { + vs2_op = static_cast<RV32VectorSourceOperand *>(inst->Source(0)); + mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(1)); + } else { + mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(0)); + } + auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); + // Get the vector start element index and compute where to start + // the operation. + int vector_index = rv_vector->vstart(); + int start_reg = vector_index / elements_per_vector; + int item_index = vector_index % elements_per_vector; + // Iterate over the number of registers to write. + for (int reg = start_reg; (reg < max_regs) && (vector_index < num_elements); + reg++) { + // Allocate data buffer for the new register data. + auto *dest_db = dest_op->CopyDataBuffer(reg); + auto dest_span = dest_db->Get<Vd>(); + // Write data into register subject to masking. + int element_count = std::min(elements_per_vector, num_elements); + for (int i = item_index; + (i < element_count) && (vector_index < num_elements); i++) { + // Get the mask value. + int mask_index = vector_index >> 3; + int mask_offset = vector_index & 0b111; + bool mask_value = ((mask_span[mask_index] >> mask_offset) & 0b1) != 0; + bool operation_mask = mask_value; + // Instruction with rs2 operand checks vs2 bit value. + if (vs2_op != nullptr) { + const auto rs2_span = + vs2_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); + const bool rs2_value = ((rs2_span[mask_index] >> mask_offset) & 0b1); + // If rs2 is set, then the operation is performed. + operation_mask &= rs2_value; + } + + auto result = op(operation_mask); + if (mask_value) { + dest_span[i] = result; + } + vector_index++; + } + // Submit the destination db. + dest_db->Submit(); + item_index = 0; + } + rv_vector->clear_vstart(); +} + +// This helper function handles the case of unary vector +// operations. It implements all the checking necessary for both widening and +// narrowing operations. +template <typename Vd, typename Vs2> +void RiscVUnaryVectorOp(CheriotVectorState *rv_vector, const Instruction *inst, + std::function<Vd(Vs2)> op) { + if (rv_vector->vector_exception()) return; + int num_elements = rv_vector->vector_length(); + int lmul = rv_vector->vector_length_multiplier(); + int sew = rv_vector->selected_element_width(); + int lmul_vd = lmul * sizeof(Vd) / sew; + int lmul_vs2 = lmul * sizeof(Vs2) / sew; + if (lmul_vd > 64 || lmul_vd == 0) { + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal lmul value vd (" << lmul_vd << ")"; + return; + } + if (lmul_vs2 > 64 || lmul_vs2 == 0) { + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal lmul value vs2 (" << lmul_vs2 << ")"; + return; + } + int elements_per_vector = + rv_vector->vector_register_byte_length() / sizeof(Vd); + int max_regs = (num_elements + elements_per_vector - 1) / elements_per_vector; + auto *dest_op = + static_cast<RV32VectorDestinationOperand *>(inst->Destination(0)); + // Verify that there are enough registers in the destination operand. 
+ if (dest_op->size() < max_regs) { + rv_vector->set_vector_exception(); + LOG(ERROR) << absl::StrCat( + "Vector destination '", dest_op->AsString(), "' has fewer registers (", + dest_op->size(), ") than required by the operation (", max_regs, ")"); + return; + } + // Get the vector mask. + auto *mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(1)); + auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); + // Get the vector start element index and compute where to start + // the operation. + int vector_index = rv_vector->vstart(); + int start_reg = vector_index / elements_per_vector; + int item_index = vector_index % elements_per_vector; + // Iterate over the number of registers to write. + for (int reg = start_reg; (reg < max_regs) && (vector_index < num_elements); + reg++) { + // Allocate data buffer for the new register data. + auto *dest_db = dest_op->CopyDataBuffer(reg); + auto dest_span = dest_db->Get<Vd>(); + // Write data into register subject to masking. + int element_count = std::min(elements_per_vector, num_elements); + for (int i = item_index; + (i < element_count) && (vector_index < num_elements); i++) { + // Get the mask value. + int mask_index = vector_index >> 3; + int mask_offset = vector_index & 0b111; + bool mask_value = ((mask_span[mask_index] >> mask_offset) & 0b1) != 0; + if (mask_value) { + // Compute result. + Vs2 vs2 = GetInstructionSource<Vs2>(inst, 0, vector_index); + dest_span[i] = op(vs2); + } + vector_index++; + } + // Submit the destination db . + dest_db->Submit(); + item_index = 0; + } + rv_vector->clear_vstart(); +} + +// This helper function handles the case of unary vector operations that set +// fflags. It implements all the checking necessary for both widening and +// narrowing operations. +template <typename Vd, typename Vs2> +void RiscVUnaryVectorOpWithFflags( + CheriotVectorState *rv_vector, const Instruction *inst, + std::function<std::tuple<Vd, uint32_t>(Vs2)> op) { + if (rv_vector->vector_exception()) return; + int num_elements = rv_vector->vector_length(); + int lmul = rv_vector->vector_length_multiplier(); + int sew = rv_vector->selected_element_width(); + int lmul_vd = lmul * sizeof(Vd) / sew; + int lmul_vs2 = lmul * sizeof(Vs2) / sew; + if (lmul_vd > 64 || lmul_vd == 0) { + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal lmul value vd (" << lmul_vd << ")"; + return; + } + if (lmul_vs2 > 64 || lmul_vs2 == 0) { + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal lmul_value vs2 (" << lmul_vs2 << ")"; + return; + } + int elements_per_vector = + rv_vector->vector_register_byte_length() / sizeof(Vd); + int max_regs = (num_elements + elements_per_vector - 1) / elements_per_vector; + auto *dest_op = + static_cast<RV32VectorDestinationOperand *>(inst->Destination(0)); + // Verify that there are enough registers in the destination operand. + if (dest_op->size() < max_regs) { + rv_vector->set_vector_exception(); + LOG(ERROR) << absl::StrCat( + "Vector destination '", dest_op->AsString(), "' has fewer registers (", + dest_op->size(), ") than required by the operation (", max_regs, ")"); + return; + } + // Get the vector mask. + auto *mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(1)); + auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); + // Get the vector start element index and compute where to start + // the operation. 
+ int vector_index = rv_vector->vstart(); + int start_reg = vector_index / elements_per_vector; + int item_index = vector_index % elements_per_vector; + // Iterate over the number of registers to write. + uint32_t fflags = 0; + for (int reg = start_reg; (reg < max_regs) && (vector_index < num_elements); + reg++) { + // Allocate data buffer for the new register data. + auto *dest_db = dest_op->CopyDataBuffer(reg); + auto dest_span = dest_db->Get<Vd>(); + // Write data into register subject to masking. + int element_count = std::min(elements_per_vector, num_elements); + for (int i = item_index; + (i < element_count) && (vector_index < num_elements); i++) { + // Get the mask value. + int mask_index = vector_index >> 3; + int mask_offset = vector_index & 0b111; + bool mask_value = ((mask_span[mask_index] >> mask_offset) & 0b1) != 0; + if (mask_value) { + // Compute result. + Vs2 vs2 = GetInstructionSource<Vs2>(inst, 0, vector_index); + auto [value, flag] = op(vs2); + dest_span[i] = value; + fflags |= flag; + } + vector_index++; + } + // Submit the destination db . + dest_db->Submit(); + item_index = 0; + } + auto *flag_db = inst->Destination(1)->AllocateDataBuffer(); + flag_db->Set<uint32_t>(0, fflags); + flag_db->Submit(); + rv_vector->clear_vstart(); +} + +// This helper function handles the case of mask + two source operand vector +// operations. It implements all the checking necessary for both widening and +// narrowing operations. +template <typename Vd, typename Vs2, typename Vs1> +void RiscVMaskBinaryVectorOp( + CheriotVectorState *rv_vector, const Instruction *inst, + std::function<std::optional<Vd>(Vs2, Vs1, bool)> op) { + if (rv_vector->vector_exception()) return; + int num_elements = rv_vector->vector_length(); + int lmul = rv_vector->vector_length_multiplier(); + int sew = rv_vector->selected_element_width(); + int lmul_vd = lmul * sizeof(Vd) / sew; + int lmul_vs2 = lmul * sizeof(Vs2) / sew; + int lmul_vs1 = lmul * sizeof(Vs1) / sew; + if (lmul_vd > 64 || lmul_vs2 > 64 || lmul_vs1 > 64) { + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal lmul value"; + return; + } + if (lmul_vd == 0 || lmul_vs2 == 0 || lmul_vs1 == 0) { + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal lmul_value"; + return; + } + int elements_per_vector = + rv_vector->vector_register_byte_length() / sizeof(Vd); + int max_regs = (num_elements + elements_per_vector - 1) / elements_per_vector; + auto *dest_op = + static_cast<RV32VectorDestinationOperand *>(inst->Destination(0)); + // Verify that there are enough registers in the destination operand. + if (dest_op->size() < max_regs) { + rv_vector->set_vector_exception(); + LOG(ERROR) << absl::StrCat( + "Vector destination '", dest_op->AsString(), "' has fewer registers (", + dest_op->size(), ") than required by the operation (", max_regs, ")"); + return; + } + // Get the vector mask. + auto *mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(2)); + auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); + // Get the vector start element index and compute where to start + // the operation. + int vector_index = rv_vector->vstart(); + int start_reg = vector_index / elements_per_vector; + int item_index = vector_index % elements_per_vector; + // Determine if it's vector-vector or vector-scalar. + bool vector_scalar = inst->Source(1)->shape()[0] == 1; + // Iterate over the number of registers to write. 
+ bool exception = false; + for (int reg = start_reg; + !exception && (reg < max_regs) && (vector_index < num_elements); reg++) { + // Allocate data buffer for the new register data. + auto *dest_db = dest_op->CopyDataBuffer(reg); + auto dest_span = dest_db->Get<Vd>(); + // Write data into register subject to masking. + int element_count = std::min(elements_per_vector, num_elements); + for (int i = item_index; + (i < element_count) && (vector_index < num_elements); i++) { + // Get the mask value. + int mask_index = vector_index >> 3; + int mask_offset = vector_index & 0b111; + bool mask_value = ((mask_span[mask_index] >> mask_offset) & 0b1) != 0; + // Compute result. + Vs2 vs2 = GetInstructionSource<Vs2>(inst, 0, vector_index); + Vs1 vs1 = GetInstructionSource<Vs1>(inst, 1, + (vector_scalar ? 0 : vector_index)); + auto value = op(vs2, vs1, mask_value); + if (value.has_value()) { + dest_span[i] = value.value(); + } else if (mask_value) { + // If there is no value returned, but the mask_value is true, check + // to see if there was an exception. + if (rv_vector->vector_exception()) { + rv_vector->set_vstart(vector_index); + exception = true; + break; + } + } + vector_index++; + } + // Submit the destination db . + dest_db->Submit(); + item_index = 0; + } + rv_vector->clear_vstart(); +} + +// This helper function handles the case of two source operand vector +// operations. It implements all the checking necessary for both widening and +// narrowing operations. +template <typename Vd, typename Vs2, typename Vs1> +void RiscVBinaryVectorOp(CheriotVectorState *rv_vector, const Instruction *inst, + std::function<Vd(Vs2, Vs1)> op) { + RiscVMaskBinaryVectorOp<Vd, Vs2, Vs1>( + rv_vector, inst, + [op](Vs2 vs2, Vs1 vs1, bool mask_value) -> std::optional<Vd> { + if (mask_value) { + return op(vs2, vs1); + } + return std::nullopt; + }); +} + +template <typename Vd, typename Vs2, typename Vs1> +void RiscVBinaryVectorOpWithFflags( + CheriotVectorState *rv_vector, const Instruction *inst, + std::function<std::tuple<Vd, uint32_t>(Vs2, Vs1)> op) { + if (rv_vector->vector_exception()) return; + int num_elements = rv_vector->vector_length(); + int lmul = rv_vector->vector_length_multiplier(); + int sew = rv_vector->selected_element_width(); + int lmul_vd = lmul * sizeof(Vd) / sew; + int lmul_vs2 = lmul * sizeof(Vs2) / sew; + int lmul_vs1 = lmul * sizeof(Vs1) / sew; + if (lmul_vd > 64 || lmul_vs2 > 64 || lmul_vs1 > 64) { + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal lmul value"; + return; + } + if (lmul_vd == 0 || lmul_vs2 == 0 || lmul_vs1 == 0) { + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal lmul_value"; + return; + } + int elements_per_vector = + rv_vector->vector_register_byte_length() / sizeof(Vd); + int max_regs = (num_elements + elements_per_vector - 1) / elements_per_vector; + auto *dest_op = + static_cast<RV32VectorDestinationOperand *>(inst->Destination(0)); + // Verify that there are enough registers in the destination operand. + if (dest_op->size() < max_regs) { + rv_vector->set_vector_exception(); + LOG(ERROR) << absl::StrCat( + "Vector destination '", dest_op->AsString(), "' has fewer registers (", + dest_op->size(), ") than required by the operation (", max_regs, ")"); + return; + } + // Get the vector mask. + auto *mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(2)); + auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); + // Get the vector start element index and compute where to start + // the operation. 
+ int vector_index = rv_vector->vstart(); + int start_reg = vector_index / elements_per_vector; + int item_index = vector_index % elements_per_vector; + // Determine if it's vector-vector or vector-scalar. + bool vector_scalar = inst->Source(1)->shape()[0] == 1; + // Iterate over the number of registers to write. + bool exception = false; + uint32_t fflags = 0; + for (int reg = start_reg; + !exception && (reg < max_regs) && (vector_index < num_elements); reg++) { + // Allocate data buffer for the new register data. + auto *dest_db = dest_op->CopyDataBuffer(reg); + auto dest_span = dest_db->Get<Vd>(); + // Write data into register subject to masking. + int element_count = std::min(elements_per_vector, num_elements); + for (int i = item_index; + (i < element_count) && (vector_index < num_elements); i++) { + // Get the mask value. + int mask_index = vector_index >> 3; + int mask_offset = vector_index & 0b111; + bool mask_value = ((mask_span[mask_index] >> mask_offset) & 0b1) != 0; + // Compute result. + Vs2 vs2 = GetInstructionSource<Vs2>(inst, 0, vector_index); + Vs1 vs1 = GetInstructionSource<Vs1>(inst, 1, + (vector_scalar ? 0 : vector_index)); + if (mask_value) { + auto [value, flag] = op(vs2, vs1); + dest_span[i] = value; + fflags |= flag; + if (rv_vector->vector_exception()) { + rv_vector->set_vstart(vector_index); + exception = true; + break; + } + } + vector_index++; + } + // Submit the destination dbs. + dest_db->Submit(); + item_index = 0; + } + auto *flag_db = inst->Destination(1)->AllocateDataBuffer(); + flag_db->Set<uint32_t>(0, fflags); + flag_db->Submit(); + rv_vector->clear_vstart(); +} + +// This helper function handles three source operand vector operations. It +// implements all the checking necessary for both widening and narrowing +// operations. +template <typename Vd, typename Vs2, typename Vs1> +void RiscVTernaryVectorOp(CheriotVectorState *rv_vector, + const Instruction *inst, + std::function<Vd(Vs2, Vs1, Vd)> op) { + if (rv_vector->vector_exception()) return; + int num_elements = rv_vector->vector_length(); + int lmul = rv_vector->vector_length_multiplier(); + int sew = rv_vector->selected_element_width(); + int lmul_vd = lmul * sizeof(Vd) / sew; + int lmul_vs2 = lmul * sizeof(Vs2) / sew; + int lmul_vs1 = lmul * sizeof(Vs1) / sew; + if (lmul_vd > 64 || lmul_vs2 > 64 || lmul_vs1 > 64) { + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal lmul value"; + return; + } + if (lmul_vd == 0 || lmul_vs2 == 0 || lmul_vs1 == 0) { + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal lmul_value"; + return; + } + int elements_per_vector = + rv_vector->vector_register_byte_length() / sizeof(Vd); + int max_regs = (num_elements + elements_per_vector - 1) / elements_per_vector; + auto *dest_op = + static_cast<RV32VectorDestinationOperand *>(inst->Destination(0)); + // Verify that there are enough registers in the destination operand. + if (dest_op->size() < max_regs) { + rv_vector->set_vector_exception(); + LOG(ERROR) << absl::StrCat( + "Vector destination '", dest_op->AsString(), "' has fewer registers (", + dest_op->size(), ") than required by the operation (", max_regs, ")"); + return; + } + // Get the vector mask. + auto *mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(3)); + auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); + // Get the vector start element index and compute where to start + // the operation. 
+ int vector_index = rv_vector->vstart(); + int start_reg = vector_index / elements_per_vector; + int item_index = vector_index % elements_per_vector; + // Determine if it's vector-vector or vector-scalar. + bool vector_scalar = inst->Source(1)->shape()[0] == 1; + // Iterate over the number of registers to write. + for (int reg = start_reg; (reg < max_regs) && (vector_index < num_elements); + reg++) { + // Allocate data buffer for the new register data. + auto *dest_db = dest_op->CopyDataBuffer(reg); + auto dest_span = dest_db->Get<Vd>(); + // Write data into register subject to masking. + int element_count = std::min(elements_per_vector, num_elements); + for (int i = item_index; + (i < element_count) && (vector_index < num_elements); i++) { + // Get the mask value. + int mask_index = vector_index >> 3; + int mask_offset = vector_index & 0b111; + bool mask_value = ((mask_span[mask_index] >> mask_offset) & 0b1) != 0; + // Compute result. + Vs2 vs2 = GetInstructionSource<Vs2>(inst, 0, vector_index); + Vs1 vs1 = GetInstructionSource<Vs1>(inst, 1, + (vector_scalar ? 0 : vector_index)); + Vd vd = GetInstructionSource<Vd>(inst, 2, vector_index); + if (mask_value) { + dest_span[i] = op(vs2, vs1, vd); + } + vector_index++; + } + // Submit the destination db. + dest_db->Submit(); + item_index = 0; + } + rv_vector->clear_vstart(); +} + +// The reduction instructions take Vs1[0], and all the elements (subject to +// masking) from Vs2 and apply the reduction operation to produce a single +// element that is written to Vd[0]. +template <typename Vd, typename Vs2, typename Vs1> +void RiscVBinaryReductionVectorOp(CheriotVectorState *rv_vector, + const Instruction *inst, + std::function<Vd(Vd, Vs2)> op) { + if (rv_vector->vector_exception()) return; + if (rv_vector->vstart()) { + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + int lmul = rv_vector->vector_length_multiplier(); + int lmul_vd = lmul * sizeof(Vd) / sew; + int lmul_vs2 = lmul * sizeof(Vs2) / sew; + int lmul_vs1 = lmul * sizeof(Vs1) / sew; + if (lmul_vd > 64 || lmul_vs2 > 64 || lmul_vs1 > 64) { + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal lmul value"; + return; + } + if (lmul_vd == 0 || lmul_vs2 == 0 || lmul_vs1 == 0) { + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal lmul value"; + return; + } + int num_elements = rv_vector->vector_length(); + // Get the vector mask. + auto *mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(2)); + auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); + Vd accumulator = + static_cast<Vd>(generic::GetInstructionSource<Vs1>(inst, 1, 0)); + for (int i = 0; i < num_elements; i++) { + int mask_index = i >> 3; + int mask_offset = i & 0b111; + bool mask_value = (mask_span[mask_index] >> mask_offset) & 0b1; + if (mask_value) { + accumulator = + op(accumulator, generic::GetInstructionSource<Vs2>(inst, 0, i)); + } + } + auto *dest_op = + static_cast<RV32VectorDestinationOperand *>(inst->Destination(0)); + auto dest_db = dest_op->CopyDataBuffer(); + dest_db->Set<Vd>(0, accumulator); + dest_db->Submit(); + rv_vector->clear_vstart(); +} + +template <typename T> +T GetRoundingBit(int rounding_mode, T rounding_bits, int size) { + switch (rounding_mode) { + case 0: // Round-to-nearest-up (add +0.5 lsb) + if (size < 2) return 0; + return (rounding_bits >> (size - 2)) & 0b1; + case 1: { // Round-to-nearest-even + T v_d_minus_1 = (size < 2) ? 0 : (rounding_bits >> (size - 2)) & 0b1; + T v_d = (size == 0) ? 
0 : (rounding_bits >> (size - 1)) & 0b1; + T v_d_minus_2_0 = (size < 3) + ? 0 + : (rounding_bits & ~(std::numeric_limits<T>::max() + << (size - 2))) != 0; + return v_d_minus_1 & (v_d_minus_2_0 | v_d); + } + case 2: // Round-down (truncate). + return 0; + case 3: { // Round-to-odd. + T v_d_minus_1_0 = (size < 2) + ? 0 + : (rounding_bits & ~(std::numeric_limits<T>::max() + << (size - 1))) != 0; + T v_d = (rounding_bits >> (size - 1)) & 0b1; + return (!v_d) & v_d_minus_1_0; + } + default: + LOG(ERROR) << "GetRoundingBit: Invalid value for rounding mode"; + break; + } + return 0; +} + +template <typename T> +T RoundOff(CheriotVectorState *rv_vector, T value, int size) { + auto rm = rv_vector->vxrm(); + auto ret = (value >> size) + GetRoundingBit<T>(rm, value, size + 1); + return ret; +} + +} // namespace cheriot +} // namespace sim +} // namespace mpact + +#endif // MPACT_CHERIOT_RISCV_CHERIOT_VECTOR_INSTRUCTION_HELPERS_H_
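The fixed-point rounding helpers at the end of the file implement the four vxrm rounding modes on the bits that RoundOff shifts out. A tiny standalone illustration of the round-to-nearest-up case (hypothetical values, not from the patch): shifting 0b1011 (decimal 11) right by 2 discards 0b11, i.e. 0.75 in fixed point, so 2.75 rounds up to 3.

    #include <cstdint>
    #include <iostream>

    int main() {
      uint32_t value = 0b1011;  // Represents 2.75 with two fraction bits.
      int size = 2;             // Number of bits shifted out.
      // vxrm == 0 (round-to-nearest-up): add the most significant discarded
      // bit, mirroring GetRoundingBit's case 0.
      uint32_t rounded = (value >> size) + ((value >> (size - 1)) & 0b1);
      std::cout << rounded << "\n";  // Prints 3.
      return 0;
    }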
diff --git a/cheriot/riscv_cheriot_vector_memory_instructions.cc b/cheriot/riscv_cheriot_vector_memory_instructions.cc new file mode 100644 index 0000000..66f3b41 --- /dev/null +++ b/cheriot/riscv_cheriot_vector_memory_instructions.cc
@@ -0,0 +1,1518 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cheriot/riscv_cheriot_vector_memory_instructions.h" + +#include <algorithm> +#include <any> +#include <cstdint> + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "cheriot/cheriot_register.h" +#include "cheriot/cheriot_state.h" +#include "cheriot/cheriot_vector_state.h" +#include "cheriot/riscv_cheriot_instruction_helpers.h" +#include "mpact/sim/generic/instruction.h" +#include "mpact/sim/generic/register.h" +#include "riscv//riscv_register.h" +#include "riscv//riscv_state.h" + +namespace mpact { +namespace sim { +namespace cheriot { + +using generic::GetInstructionSource; +using ::mpact::sim::generic::RegisterBase; +using ::mpact::sim::riscv::RV32VectorDestinationOperand; +using ::mpact::sim::riscv::RV32VectorSourceOperand; +using ::mpact::sim::riscv::VectorLoadContext; +using CapReg = CheriotRegister; + +// Helper to get capability register source and destination registers. +static inline CapReg *GetCapSource(const Instruction *instruction, int i) { + return static_cast<CapReg *>( + std::any_cast<RegisterBase *>(instruction->Source(i)->GetObject())); +} + +static inline bool CheckCapForMemoryAccess(const Instruction *instruction, + CapReg *cap_reg, + CheriotState *state) { + // Check for tag unset. + if (!cap_reg->tag()) { + state->HandleCheriRegException(instruction, instruction->address(), + ExceptionCode::kCapExTagViolation, cap_reg); + return false; + } + // Check for sealed. + if (cap_reg->IsSealed()) { + state->HandleCheriRegException(instruction, instruction->address(), + ExceptionCode::kCapExSealViolation, cap_reg); + return false; + } + // Check for permissions. + if (!cap_reg->HasPermission(CheriotRegister::kPermitLoad)) { + state->HandleCheriRegException(instruction, instruction->address(), + ExceptionCode::kCapExPermitLoadViolation, + cap_reg); + return false; + } + return true; +} + +static inline bool CheckCapBounds(const Instruction *instruction, + uint64_t address, int el_width, + CapReg *cap_reg, CheriotState *state) { + // Check for bounds. + if (!cap_reg->IsInBounds(address, el_width)) { + state->HandleCheriRegException(instruction, instruction->address(), + ExceptionCode::kCapExBoundsViolation, + cap_reg); + return false; + } + return true; +} + +// Helper function used by the load child instructions (non segment loads) that +// writes the loaded data into the registers. +template <typename T> +absl::Status WriteBackLoadData(int vector_register_byte_length, + const Instruction *inst) { + // Get values from context. 
+ auto *context = static_cast<VectorLoadContext *>(inst->context()); + auto masks = context->mask_db->Get<bool>(); + auto values = context->value_db->Get<T>(); + int vector_start = context->vstart; + int vector_length = context->vlength; + + int element_size = sizeof(T); + int elements_per_vector = vector_register_byte_length / element_size; + int max_regs = + (vector_length + elements_per_vector - 1) / elements_per_vector; + // Verify that the dest_op has enough registers. Else signal error. + auto *dest_op = + static_cast<RV32VectorDestinationOperand *>(inst->Destination(0)); + if (dest_op->size() < max_regs) { + // TODO: signal error. + return absl::InternalError("Not enough registers in destination operand"); + } + // Compute the number of values to be written. + int value_count = masks.size(); + if (vector_length - vector_start != value_count) { + // TODO: signal error. + return absl::InternalError( + absl::StrCat("The number of mask elements (", value_count, + ") differs from the number of elements to write (", + vector_length - vector_start, ")")); + } + int load_data_index = 0; + int start_reg = vector_start / elements_per_vector; + int item_index = vector_start % elements_per_vector; + // Iterate over the number of registers to write. + for (int reg = start_reg; (reg < max_regs) && (value_count > 0); reg++) { + // Allocate data buffer for the new register data. + auto *dest_db = dest_op->CopyDataBuffer(reg); + auto dest_span = dest_db->Get<T>(); + // Write data into register subject to masking. + int count = std::min(elements_per_vector - item_index, value_count); + for (int i = item_index; i < count; i++) { + if (masks[load_data_index + i]) { + dest_span[i] = values[load_data_index + i]; + } + } + value_count -= count; + load_data_index += count; + dest_db->Submit(0); + item_index = 0; + } + return absl::OkStatus(); +} + +// Helper function used by the load child instructions (for segment loads) that +// writes the loaded data into the registers. +template <typename T> +absl::Status WriteBackSegmentLoadData(int vector_register_byte_length, + const Instruction *inst) { + // The number of fields in each segment. + int num_fields = GetInstructionSource<uint32_t>(inst, 0) + 1; + // Get values from context. + auto *context = static_cast<VectorLoadContext *>(inst->context()); + auto masks = context->mask_db->Get<bool>(); + auto values = context->value_db->Get<T>(); + int start_segment = context->vstart; + int vector_length = context->vlength; + + int element_size = sizeof(T); + int num_segments = masks.size() / num_fields; + // Number of registers written for each field. + int max_elements_per_vector = + std::min(vector_register_byte_length / element_size, num_segments); + int num_regs = + std::max(1, num_segments * element_size / vector_register_byte_length); + // Total number of registers written. + int total_regs = num_fields * num_regs; + // Verify that the dest_op has enough registers. Else signal error. + auto *dest_op = + static_cast<RV32VectorDestinationOperand *>(inst->Destination(0)); + if (dest_op->size() < total_regs) { + return absl::InternalError("Not enough registers in destination operand"); + } + // Compute the number of segments to be written. + if (vector_length - start_segment != num_segments) { + return absl::InternalError( + absl::StrCat("The number of mask elements (", num_segments, + ") differs from the number of elements to write (", + vector_length - start_segment, ")")); + } + int load_data_index = 0; + // Data is organized by field. 
+  for (int field = 0; field < num_fields; field++) {
+    int start_reg =
+        field * num_regs + (start_segment / max_elements_per_vector);
+    int offset = start_segment % max_elements_per_vector;
+    int remaining_data = num_segments;
+    for (int reg = start_reg; reg < start_reg + num_regs; reg++) {
+      auto *dest_db = dest_op->CopyDataBuffer(reg);
+      auto span = dest_db->Get<T>();
+      int max_entry =
+          std::min(remaining_data + offset, max_elements_per_vector);
+      for (int i = offset; i < max_entry; i++) {
+        if (masks[load_data_index]) {
+          span[i] = values[load_data_index];
+        }
+        load_data_index++;
+        remaining_data--;
+      }
+      offset = 0;
+      dest_db->Submit(0);
+    }
+  }
+  return absl::OkStatus();
+}
+
+// This models the vsetvl set of instructions. The immediate and register
+// versions are all modeled by the same function. Flags are bound during decode
+// to the first two parameters to specify whether rd or rs1 is x0.
+void Vsetvl(bool rd_zero, bool rs1_zero, const Instruction *inst) {
+  auto *rv_state = static_cast<CheriotState *>(inst->state());
+  auto *rv_vector = rv_state->rv_vector();
+  uint32_t vtype = GetInstructionSource<uint32_t>(inst, 1) & 0b1'1'111'111;
+  // Get previous vtype.
+  uint32_t prev_vtype = rv_vector->vtype();
+  // Get previous max length.
+  int old_max_length = rv_vector->max_vector_length();
+  // Set the new vector type.
+  rv_vector->SetVectorType(vtype);
+  auto new_max_length = rv_vector->max_vector_length();
+  uint32_t vl = new_max_length;
+  if (rs1_zero && rd_zero) {  // If rs1 and rd are both zero.
+    // If max_length changed, then there's an error, otherwise, vector length
+    // is now vl.
+    if (old_max_length != new_max_length) {
+      // ERROR: cannot change max_vector_length.
+      // Revert, then set error flag.
+      rv_vector->SetVectorType(prev_vtype);
+      rv_vector->set_vector_exception();
+      return;
+    }
+    rv_vector->set_vector_length(new_max_length);
+    return;
+  }
+  if (!rs1_zero) {  // There is a requested vector length.
+    uint32_t avl = GetInstructionSource<uint32_t>(inst, 0);
+    // vl defaults to the new maximum vector length (VLMAX).
+    if (avl <= new_max_length) {
+      // If the requested AVL is no greater than VLMAX, use it directly.
+      vl = avl;
+    }
+
+    // The RISCV spec has the following constraint when VLMAX < AVL < 2 * VLMAX:
+    //   ceil(AVL / 2) <= vl <= VLMAX
+    //
+    // This allows vl to be assigned to half of the requested AVL value;
+    // however, vl may be assigned to VLMAX instead. SiFive implementations of
+    // the RISCV vector engine set vl to VLMAX in this case, which is the same
+    // approach followed here.
+  }
+  rv_vector->set_vector_length(vl);
+  if (!rd_zero) {  // Update register if there is a writable destination.
+    WriteCapIntResult<uint32_t>(inst, 0, vl);
+  }
+}
+
+// Vector load - models both strided and unit stride. Strides can be positive,
+// zero, or negative.
+
+// Source(0): base address.
+// Source(1): vector mask register, vector constant {1..} if not masked.
+// Destination(0): vector destination register.
+void VlUnitStrided(int element_width, const Instruction *inst) {
+  auto *state = static_cast<CheriotState *>(inst->state());
+  auto *rv_vector = state->rv_vector();
+  int start = rv_vector->vstart();
+  auto cap_reg = GetCapSource(inst, 0);
+  if (!CheckCapForMemoryAccess(inst, cap_reg, state)) return;
+  uint64_t base = cap_reg->address();
+  int emul = element_width * rv_vector->vector_length_multiplier() /
+             rv_vector->selected_element_width();
+  if ((emul > 64) || (emul == 0)) {
+    // TODO: signal vector error.
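+    // Note (assumption): vector_length_multiplier() appears to return LMUL
+    // scaled by 8 (values 1..64 correspond to LMUL 1/8..8), which is why the
+    // legal range for emul here is [1, 64].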
+ LOG(WARNING) << "EMUL (" << emul << ") out of range"; + return; + } + + // Compute total number of elements to be loaded. + int num_elements = rv_vector->vector_length(); + int num_elements_loaded = num_elements - start; + + // Allocate address data buffer. + auto *db_factory = inst->state()->db_factory(); + auto *address_db = db_factory->Allocate<uint64_t>(num_elements_loaded); + + // Allocate the value data buffer that the loaded data is returned in. + auto *value_db = db_factory->Allocate(num_elements_loaded * element_width); + + // Get the source mask (stored in a single vector register). + auto *src_mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(1)); + auto src_masks = src_mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); + + // Allocate a byte mask data buffer for the load. + auto *mask_db = db_factory->Allocate<bool>(num_elements_loaded); + + // Get the spans for addresses and masks. + auto addresses = address_db->Get<uint64_t>(); + auto masks = mask_db->Get<bool>(); + + // The vector mask in the vector register is a bit mask. The mask used in + // the LoadMemory call is a bool mask so convert the bit masks to bool masks + // and compute the element addresses. + for (int i = start; i < num_elements; i++) { + int index = i >> 3; + int offset = i & 0b111; + addresses[i - start] = base + i * element_width; + masks[i - start] = ((src_masks[index] >> offset) & 0b1) != 0; + if (masks[i - start]) { + if (!CheckCapBounds(inst, addresses[i - start], element_width, cap_reg, + state)) { + address_db->DecRef(); + mask_db->DecRef(); + value_db->DecRef(); + return; + } + } + } + + // Set up the context, and submit the load. + auto *context = new VectorLoadContext(value_db, mask_db, element_width, start, + rv_vector->vector_length()); + value_db->set_latency(0); + state->LoadMemory(inst, address_db, mask_db, element_width, value_db, + inst->child(), context); + // Release the context and address_db. The others will be released elsewhere. + context->DecRef(); + address_db->DecRef(); + rv_vector->clear_vstart(); +} + +// Source(0): base address. +// Source(1): stride size bytes. +// Source(2): vector mask register, vector constant {1..} if not masked. +// Destination(0): vector destination register. +void VlStrided(int element_width, const Instruction *inst) { + auto *state = static_cast<CheriotState *>(inst->state()); + auto *rv_vector = state->rv_vector(); + int start = rv_vector->vstart(); + auto cap_reg = GetCapSource(inst, 0); + if (!CheckCapForMemoryAccess(inst, cap_reg, state)) return; + uint64_t base = cap_reg->address(); + int64_t stride = GetInstructionSource<int64_t>(inst, 1); + int emul = element_width * rv_vector->vector_length_multiplier() / + rv_vector->selected_element_width(); + if ((emul > 64) || (emul == 0)) { + // TODO: signal vector error. + LOG(WARNING) << "EMUL (" << emul << ") out of range"; + return; + } + + // Compute total number of elements to be loaded. + int num_elements = rv_vector->vector_length(); + int num_elements_loaded = num_elements - start; + + // Allocate address data buffer. + auto *db_factory = inst->state()->db_factory(); + auto *address_db = db_factory->Allocate<uint64_t>(num_elements_loaded); + + // Allocate the value data buffer that the loaded data is returned in. + auto *value_db = db_factory->Allocate(num_elements_loaded * element_width); + + // Get the source mask (stored in a single vector register). 
+  auto *src_mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(2));
+  auto src_masks = src_mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
+
+  // Allocate a byte mask data buffer for the load.
+  auto *mask_db = db_factory->Allocate<bool>(num_elements_loaded);
+
+  // Get the spans for addresses and masks.
+  auto addresses = address_db->Get<uint64_t>();
+  auto masks = mask_db->Get<bool>();
+
+  // The vector mask in the vector register is a bit mask. The mask used in
+  // the LoadMemory call is a bool mask so convert the bit masks to bool masks
+  // and compute the element addresses.
+  for (int i = start; i < num_elements; i++) {
+    int index = i >> 3;
+    int offset = i & 0b111;
+    addresses[i - start] = base + i * stride;
+    masks[i - start] = ((src_masks[index] >> offset) & 0b1) != 0;
+    if (masks[i - start]) {
+      if (!CheckCapBounds(inst, addresses[i - start], element_width, cap_reg,
+                          state)) {
+        address_db->DecRef();
+        mask_db->DecRef();
+        value_db->DecRef();
+        return;
+      }
+    }
+  }
+
+  // Set up the context, and submit the load.
+  auto *context = new VectorLoadContext(value_db, mask_db, element_width, start,
+                                        rv_vector->vector_length());
+  value_db->set_latency(0);
+  state->LoadMemory(inst, address_db, mask_db, element_width, value_db,
+                    inst->child(), context);
+  // Release the context and address_db. The others will be released elsewhere.
+  context->DecRef();
+  address_db->DecRef();
+  rv_vector->clear_vstart();
+}
+
+// Vector load vector-mask. This is simple, just a single register.
+
+// Source(0): base address.
+// Destination(0): vector destination register (for the child instruction).
+void Vlm(const Instruction *inst) {
+  auto *state = static_cast<CheriotState *>(inst->state());
+  auto *rv_vector = state->rv_vector();
+  int start = rv_vector->vstart();
+  auto cap_reg = GetCapSource(inst, 0);
+  if (!CheckCapForMemoryAccess(inst, cap_reg, state)) return;
+  uint64_t base = cap_reg->address();
+  // Compute the number of bytes to be loaded.
+  int num_bytes = rv_vector->vector_register_byte_length();
+  int num_bytes_loaded = num_bytes - start;
+  // Allocate address data buffer.
+  auto *db_factory = inst->state()->db_factory();
+  auto *address_db = db_factory->Allocate<uint64_t>(num_bytes_loaded);
+  // Allocate the value data buffer that the loaded data is returned in.
+  auto *value_db = db_factory->Allocate<uint8_t>(num_bytes_loaded);
+  // Allocate a byte mask data buffer.
+  auto *mask_db = db_factory->Allocate<bool>(num_bytes_loaded);
+  // Get the spans for addresses and masks.
+  auto masks = mask_db->Get<bool>();
+  auto addresses = address_db->Get<uint64_t>();
+  // Set up addresses, mark all mask elements as true.
+  for (int i = start; i < num_bytes; i++) {
+    addresses[i - start] = base + i;
+    masks[i - start] = true;
+    if (!CheckCapBounds(inst, addresses[i - start], 1, cap_reg, state)) {
+      address_db->DecRef();
+      mask_db->DecRef();
+      value_db->DecRef();
+      return;
+    }
+  }
+  // Set up the context, and submit the load.
+  auto *context =
+      new VectorLoadContext(value_db, mask_db, sizeof(uint8_t), start,
+                            rv_vector->vector_register_byte_length());
+  value_db->set_latency(0);
+  state->LoadMemory(inst, address_db, mask_db, sizeof(uint8_t), value_db,
+                    inst->child(), context);
+  // Release the context and address db.
+  address_db->DecRef();
+  context->DecRef();
+  rv_vector->clear_vstart();
+}
+
+// Vector load indexed (ordered and unordered).
Index values are not scaled by +// element size, as the index values can also be treated as multiple base +// addresses with the base address acting as a common offset. Index values are +// treated as unsigned integers, and are zero extended from the element size to +// the internal address size (or truncated in case the internal XLEN is < index +// element size). + +// Source(0) base address. +// Source(1) index vector. +// Source(2) masks. +// Destination(0): vector destination register (for the child instruction). +void VlIndexed(int index_width, const Instruction *inst) { + auto *state = static_cast<CheriotState *>(inst->state()); + auto *rv_vector = state->rv_vector(); + int start = rv_vector->vstart(); + auto cap_reg = GetCapSource(inst, 0); + if (!CheckCapForMemoryAccess(inst, cap_reg, state)) return; + uint64_t base = cap_reg->address(); + int element_width = rv_vector->selected_element_width(); + int lmul = rv_vector->vector_length_multiplier(); + auto *index_op = static_cast<RV32VectorSourceOperand *>(inst->Source(1)); + int index_emul = index_width * lmul / element_width; + // Validate that emul has a legal value. + if ((index_emul > 64) || (index_emul == 0)) { + // TODO: signal vector error. + LOG(WARNING) << absl::StrCat( + "Vector load indexed: emul (index) out of range: ", index_emul); + rv_vector->set_vector_exception(); + return; + } + + // Compute the number of bytes and elements to be loaded. + int num_elements = rv_vector->vector_length(); + int num_elements_loaded = num_elements - start; + int num_bytes_loaded = num_elements_loaded * element_width; + + // Allocate address data buffer. + auto *db_factory = inst->state()->db_factory(); + auto *address_db = db_factory->Allocate<uint64_t>(num_elements_loaded); + auto addresses = address_db->Get<uint64_t>(); + + // Allocate the value data buffer that the loaded data is returned in. + auto *value_db = db_factory->Allocate(num_bytes_loaded); + + // Get the source mask (stored in a single vector register). + auto *src_mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(2)); + auto src_masks = src_mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); + + // Allocate a byte mask data buffer for the load. + auto *mask_db = db_factory->Allocate<bool>(num_elements); + auto masks = mask_db->Get<bool>(); + + // Convert the bit masks to byte masks and compute the element addresses. + // The index elements are treated as unsigned values. + for (int i = start; i < num_elements; i++) { + int mask_index = i >> 3; + int mask_offset = i & 0b111; + uint64_t offset; + switch (index_width) { + case 1: + offset = index_op->AsUint8(i); + break; + case 2: + offset = index_op->AsUint16(i); + break; + case 4: + offset = index_op->AsUint32(i); + break; + case 8: + offset = index_op->AsUint64(i); + break; + default: + offset = 0; + LOG(ERROR) << absl::StrCat("Illegal index width (", index_width, ")"); + rv_vector->set_vector_exception(); + break; + } + addresses[i - start] = base + offset; + masks[i - start] = ((src_masks[mask_index] >> mask_offset) & 0b1) != 0; + if (masks[i - start]) { + if (!CheckCapBounds(inst, addresses[i - start], element_width, cap_reg, + state)) { + address_db->DecRef(); + mask_db->DecRef(); + value_db->DecRef(); + return; + } + } + } + + // Set up context and submit load. 
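+  // (Ownership note, inferred from the DecRef calls: LoadMemory takes its own
+  // references, so the context and address_db can be released here, while
+  // value_db and mask_db are released later by the child/context machinery.)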
+  auto *context = new VectorLoadContext(value_db, mask_db, element_width, start,
+                                        rv_vector->vector_length());
+  value_db->set_latency(0);
+  state->LoadMemory(inst, address_db, mask_db, element_width, value_db,
+                    inst->child(), context);
+  // Release the context and address db.
+  address_db->DecRef();
+  context->DecRef();
+  rv_vector->clear_vstart();
+}
+
+// Vector load whole register(s). The number of registers is passed as
+// a parameter to this function - bound to the called function object by the
+// instruction decoder. Simple function, no masks, no differentiation between
+// element sizes.
+// Source(0): base address.
+// Destination(0): vector destination register (for the child instruction).
+void VlRegister(int num_regs, int element_width_bytes,
+                const Instruction *inst) {
+  auto *state = static_cast<CheriotState *>(inst->state());
+  auto *rv_vector = state->rv_vector();
+  auto cap_reg = GetCapSource(inst, 0);
+  if (!CheckCapForMemoryAccess(inst, cap_reg, state)) return;
+  uint64_t base = cap_reg->address();
+  int num_elements =
+      rv_vector->vector_register_byte_length() * num_regs / element_width_bytes;
+  // Allocate data buffers.
+  auto *db_factory = inst->state()->db_factory();
+  auto *data_db = db_factory->Allocate(num_elements * element_width_bytes);
+  auto *address_db = db_factory->Allocate<uint64_t>(num_elements);
+  auto *mask_db = db_factory->Allocate<bool>(num_elements);
+  // Get spans for addresses and masks.
+  auto addresses = address_db->Get<uint64_t>();
+  auto masks = mask_db->Get<bool>();
+
+  // Compute addresses and set masks to true.
+  // Note that the width of each load operation is `element_width_bytes`, not
+  // SEW (selected element width). The SEW is the width of each vector element
+  // in the vector register, while the element width here is the width of the
+  // data being loaded; the two may differ.
+  for (int i = 0; i < num_elements; i++) {
+    addresses[i] = base + i * element_width_bytes;
+    masks[i] = true;
+    if (!CheckCapBounds(inst, addresses[i], element_width_bytes, cap_reg,
+                        state)) {
+      address_db->DecRef();
+      mask_db->DecRef();
+      data_db->DecRef();
+      return;
+    }
+  }
+
+  // Set up context and submit load.
+  auto *context = new VectorLoadContext(data_db, mask_db, element_width_bytes,
+                                        0, num_elements);
+  data_db->set_latency(0);
+  state->LoadMemory(inst, address_db, mask_db, element_width_bytes, data_db,
+                    inst->child(), context);
+  // Release the context and address db.
+  address_db->DecRef();
+  context->DecRef();
+  rv_vector->clear_vstart();
+}
+
+// Vector load segment, unit stride. The stride is the size of each segment,
+// i.e., number of fields * element size. The first field of each segment is
+// loaded into the first register, the second into the second, etc. If there
+// are more segments than elements in the vector register, adjacent vector
+// registers are grouped together. So the first field goes in the first register
+// group, etc.
+// Source(0): base address
+// Source(1): mask
+// Source(2): number of fields - 1
+// Destination(0): vector destination register (for the child instruction).
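+// Example (illustrative values): with number of fields = 3 and 32-bit
+// elements, memory at the base address holds x0,y0,z0, x1,y1,z1, ...; the x
+// values are written to the first destination register (group), the y values
+// to the second, and the z values to the third.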
+void VlSegment(int element_width, const Instruction *inst) { + auto *state = static_cast<CheriotState *>(inst->state()); + auto *rv_vector = state->rv_vector(); + int start = rv_vector->vstart(); + auto cap_reg = GetCapSource(inst, 0); + if (!CheckCapForMemoryAccess(inst, cap_reg, state)) return; + uint64_t base = cap_reg->address(); + auto src_mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(1)); + auto src_masks = src_mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); + int num_fields = GetInstructionSource<int32_t>(inst, 2) + 1; + // Effective vector length multiplier. + int emul = (element_width * rv_vector->vector_length_multiplier()) / + rv_vector->selected_element_width(); + if (emul * num_fields > 64) { + // This is a reserved encoding error. + // If > 64, it means that the number of registers required is > 8. + // TODO: signal error. + rv_vector->set_vector_exception(); + return; + } + int num_segments = rv_vector->vector_length(); + int segment_stride = num_fields * element_width; + int num_elements = num_fields * num_segments; + // Set up data buffers. + auto *db_factory = inst->state()->db_factory(); + auto *data_db = db_factory->Allocate(num_elements * element_width); + auto *address_db = db_factory->Allocate<uint64_t>(num_elements); + auto *mask_db = db_factory->Allocate<bool>(num_elements); + // Get spans for addresses and masks. + auto addresses = address_db->Get<uint64_t>(); + auto masks = mask_db->Get<bool>(); + + for (int i = start; i < num_segments; i++) { + int mask_index = i >> 3; + int mask_offset = i & 0b111; + bool mask_value = ((src_masks[mask_index] >> mask_offset) & 0x1) != 0; + for (int field = 0; field < num_fields; field++) { + masks[field * num_segments + i] = mask_value; + addresses[field * num_segments + i] = + base + i * segment_stride + field * element_width; + if (masks[field * num_segments + i]) { + if (!CheckCapBounds(inst, addresses[field * num_segments + i], + element_width, cap_reg, state)) { + address_db->DecRef(); + mask_db->DecRef(); + data_db->DecRef(); + return; + } + } + } + } + auto *context = new VectorLoadContext(data_db, mask_db, element_width, start, + num_segments); + data_db->set_latency(0); + state->LoadMemory(inst, address_db, mask_db, element_width, data_db, + inst->child(), context); + // Release the context and address db. + address_db->DecRef(); + context->DecRef(); + rv_vector->clear_vstart(); +} + +// Vector load strided adds a byte address stride to the base address for each +// segment. Note, the stride offset is not scaled by the segment size. +// Source(0): base address +// Source(1): stride +// Source(2): mask +// Source(3): number of fields - 1 +// Destination(0): vector destination register (for the child instruction). +void VlSegmentStrided(int element_width, const Instruction *inst) { + auto *state = static_cast<CheriotState *>(inst->state()); + auto *rv_vector = state->rv_vector(); + int start = rv_vector->vstart(); + auto cap_reg = GetCapSource(inst, 0); + if (!CheckCapForMemoryAccess(inst, cap_reg, state)) return; + uint64_t base = cap_reg->address(); + int64_t segment_stride = GetInstructionSource<int64_t>(inst, 1); + auto src_mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(2)); + auto src_masks = src_mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); + int num_fields = GetInstructionSource<int32_t>(inst, 3) + 1; + // Effective vector length multiplier. 
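+  // That is, EMUL = (element width / SEW) * LMUL, computed in the 8x-scaled
+  // form described earlier, so legal values fall in [1, 64].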
+  int emul = (element_width * rv_vector->vector_length_multiplier()) /
+             rv_vector->selected_element_width();
+  if (emul * num_fields > 64) {
+    // This is a reserved encoding error.
+    // If > 64, it means that the number of registers required is > 8.
+    // TODO: signal error.
+    rv_vector->set_vector_exception();
+    return;
+  }
+  int num_segments = rv_vector->vector_length();
+  int num_elements = num_fields * num_segments;
+  // Set up data buffers.
+  auto *db_factory = inst->state()->db_factory();
+  auto *data_db = db_factory->Allocate(num_elements * element_width);
+  auto *address_db = db_factory->Allocate<uint64_t>(num_elements);
+  auto *mask_db = db_factory->Allocate<bool>(num_elements);
+  // Get the spans for addresses and masks.
+  auto addresses = address_db->Get<uint64_t>();
+  auto masks = mask_db->Get<bool>();
+  for (int i = start; i < num_segments; i++) {
+    int mask_index = i >> 3;
+    int mask_offset = i & 0b111;
+    bool mask_value = ((src_masks[mask_index] >> mask_offset) & 0x1) != 0;
+    for (int field = 0; field < num_fields; field++) {
+      masks[field * num_segments + i] = mask_value;
+      addresses[field * num_segments + i] =
+          base + i * segment_stride + field * element_width;
+      if (masks[field * num_segments + i]) {
+        if (!CheckCapBounds(inst, addresses[field * num_segments + i],
+                            element_width, cap_reg, state)) {
+          address_db->DecRef();
+          mask_db->DecRef();
+          data_db->DecRef();
+          return;
+        }
+      }
+    }
+  }
+  // Allocate the context and submit the load.
+  auto *context = new VectorLoadContext(data_db, mask_db, element_width, start,
+                                        num_segments);
+  data_db->set_latency(0);
+  state->LoadMemory(inst, address_db, mask_db, element_width, data_db,
+                    inst->child(), context);
+  // Release the context and address db.
+  address_db->DecRef();
+  context->DecRef();
+  rv_vector->clear_vstart();
+}
+
+// Vector load segment, indexed. Similar to the other segment loads, except
+// that the offset to the base address comes from a vector of indices. Each
+// offset is a byte address, and is not scaled by the segment size.
+// Source(0): base address
+// Source(1): index vector
+// Source(2): mask
+// Source(3): number of fields - 1
+// Destination(0): vector destination register (for the child instruction).
+void VlSegmentIndexed(int index_width, const Instruction *inst) {
+  auto *state = static_cast<CheriotState *>(inst->state());
+  auto *rv_vector = state->rv_vector();
+  int start = rv_vector->vstart();
+  auto cap_reg = GetCapSource(inst, 0);
+  if (!CheckCapForMemoryAccess(inst, cap_reg, state)) return;
+  uint64_t base = cap_reg->address();
+  auto *index_op = static_cast<RV32VectorSourceOperand *>(inst->Source(1));
+  auto src_mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(2));
+  auto src_masks = src_mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
+  int num_fields = GetInstructionSource<int32_t>(inst, 3) + 1;
+  int element_width = rv_vector->selected_element_width();
+  // Effective vector length multiplier.
+  int lmul8 = rv_vector->vector_length_multiplier();
+  // Validate lmul.
+  if (lmul8 * num_fields > 64) {
+    LOG(WARNING) << "Vector segment load indexed: too many registers";
+    rv_vector->set_vector_exception();
+    return;
+  }
+  // The index emul is scaled from the lmul by the relative size of the index
+  // element to the SEW (selected element width).
+  int index_emul = (index_width * lmul8) / element_width;
+  // Validate that index_emul has a legal value.
+  if ((index_emul > 64) || (index_emul == 0)) {
+    // TODO: signal vector error.
+    LOG(WARNING) << absl::StrCat(
+        "Vector load indexed: emul (index) out of range: ", index_emul);
+    rv_vector->set_vector_exception();
+    return;
+  }
+  int num_segments = rv_vector->vector_length();
+  int num_elements = num_fields * num_segments;
+
+  // Set up data buffers.
+  auto *db_factory = inst->state()->db_factory();
+  auto *data_db = db_factory->Allocate(num_elements * element_width);
+  auto *address_db = db_factory->Allocate<uint64_t>(num_elements);
+  auto *mask_db = db_factory->Allocate<bool>(num_elements);
+  // Get the spans for the addresses and masks.
+  auto addresses = address_db->Get<uint64_t>();
+  auto masks = mask_db->Get<bool>();
+
+  for (int i = start; i < num_segments; i++) {
+    // The mask value is per segment.
+    int mask_index = i >> 3;
+    int mask_offset = i & 0b111;
+    bool mask_value = ((src_masks[mask_index] >> mask_offset) & 0x1) != 0;
+    // Read the index value.
+    uint64_t offset;
+    switch (index_width) {
+      case 1:
+        offset = index_op->AsUint8(i);
+        break;
+      case 2:
+        offset = index_op->AsUint16(i);
+        break;
+      case 4:
+        offset = index_op->AsUint32(i);
+        break;
+      case 8:
+        offset = index_op->AsUint64(i);
+        break;
+      default:
+        offset = 0;
+        // TODO: signal error.
+        LOG(ERROR) << "Internal error - illegal value for index_width";
+        rv_vector->set_vector_exception();
+        return;
+    }
+    for (int field = 0; field < num_fields; field++) {
+      masks[field * num_segments + i] = mask_value;
+      addresses[field * num_segments + i] =
+          base + offset + field * element_width;
+      if (masks[field * num_segments + i]) {
+        if (!CheckCapBounds(inst, addresses[field * num_segments + i],
+                            element_width, cap_reg, state)) {
+          address_db->DecRef();
+          mask_db->DecRef();
+          data_db->DecRef();
+          return;
+        }
+      }
+    }
+  }
+  auto *context = new VectorLoadContext(data_db, mask_db, element_width, start,
+                                        num_segments);
+  data_db->set_latency(0);
+  state->LoadMemory(inst, address_db, mask_db, element_width, data_db,
+                    inst->child(), context);
+  // Release the context and address db.
+  address_db->DecRef();
+  context->DecRef();
+  rv_vector->clear_vstart();
+}
+
+// Child instruction used for non-segment vector loads. This function is really
+// only used to select a type-specific version of the helper function that
+// writes back the load data.
+void VlChild(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  absl::Status status;
+  int byte_length = rv_vector->vector_register_byte_length();
+  switch (static_cast<VectorLoadContext *>(inst->context())->element_width) {
+    case 1:
+      status = WriteBackLoadData<uint8_t>(byte_length, inst);
+      break;
+    case 2:
+      status = WriteBackLoadData<uint16_t>(byte_length, inst);
+      break;
+    case 4:
+      status = WriteBackLoadData<uint32_t>(byte_length, inst);
+      break;
+    case 8:
+      status = WriteBackLoadData<uint64_t>(byte_length, inst);
+      break;
+    default:
+      LOG(ERROR) << "Illegal element width";
+      return;
+  }
+  if (!status.ok()) {
+    LOG(WARNING) << status.message();
+    rv_vector->set_vector_exception();
+  }
+}
+
+// Child instruction used for segment vector loads. This function is really
+// only used to select a type-specific version of the helper function that
+// writes back the load data.
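+// (Pattern note: the parent load semantic function computes the addresses and
+// issues the LoadMemory request; this child executes when the data returns and
+// commits it to the destination registers using the saved context.)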
+void VlSegmentChild(int element_width, const Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + absl::Status status; + int byte_length = rv_vector->vector_register_byte_length(); + switch (static_cast<VectorLoadContext *>(inst->context())->element_width) { + case 1: + status = WriteBackSegmentLoadData<uint8_t>(byte_length, inst); + break; + case 2: + status = WriteBackSegmentLoadData<uint16_t>(byte_length, inst); + break; + case 4: + status = WriteBackSegmentLoadData<uint32_t>(byte_length, inst); + break; + case 8: + status = WriteBackSegmentLoadData<uint64_t>(byte_length, inst); + break; + default: + LOG(ERROR) << "Illegal element width"; + return; + } + if (!status.ok()) { + LOG(WARNING) << status.message(); + rv_vector->set_vector_exception(); + } +} + +// Templated helper function for vector stores. +template <typename T> +void StoreVectorStrided(int vector_length, int vstart, int emul, + const Instruction *inst) { + auto *state = static_cast<CheriotState *>(inst->state()); + auto cap_reg = GetCapSource(inst, 1); + if (!CheckCapForMemoryAccess(inst, cap_reg, state)) return; + uint64_t base = cap_reg->address(); + int64_t stride = GetInstructionSource<int64_t>(inst, 2); + auto *src_mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(3)); + auto src_masks = src_mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); + + // Compute total number of elements to be stored. + int num_elements = vector_length; + // Allocate data buffers. + auto *db_factory = inst->state()->db_factory(); + auto *address_db = db_factory->Allocate<uint64_t>(num_elements); + auto addresses = address_db->Get<uint64_t>(); + auto *store_data_db = db_factory->Allocate(num_elements * sizeof(T)); + auto *mask_db = db_factory->Allocate<bool>(num_elements); + + // Get the spans for addresses and masks. + auto store_data = store_data_db->Get<T>(); + auto masks = mask_db->Get<bool>(); + + // Convert the bit masks to byte masks. Set up addresses. + for (int i = vstart; i < num_elements; i++) { + int mask_index = i >> 3; + int mask_offset = i & 0b111; + addresses[i - vstart] = base + i * stride; + masks[i - vstart] = ((src_masks[mask_index] >> mask_offset) & 0b1) != 0; + store_data[i - vstart] = GetInstructionSource<T>(inst, 0, i); + if (masks[i - vstart]) { + if (!CheckCapBounds(inst, addresses[i - vstart], sizeof(T), cap_reg, + state)) { + address_db->DecRef(); + mask_db->DecRef(); + store_data_db->DecRef(); + return; + } + } + } + // Perform the store. + state->StoreMemory(inst, address_db, mask_db, sizeof(T), store_data_db); + address_db->DecRef(); + mask_db->DecRef(); + store_data_db->DecRef(); +} + +// Vector store - strided. +// Source(0): store data. +// Source(1): base address. +// Source(2): stride. +// Source(3): vector mask register, vector constant {1..} if not masked. +void VsStrided(int element_width, const Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int emul = element_width * rv_vector->vector_length_multiplier() / + rv_vector->selected_element_width(); + // Validate that emul has a legal value. 
+  if ((emul > 64) || (emul == 0)) {
+    LOG(WARNING) << absl::StrCat("Illegal emul value for vector store (", emul,
+                                 ")");
+    rv_vector->set_vector_exception();
+    return;
+  }
+  int vlength = rv_vector->vector_length();
+  int vstart = rv_vector->vstart();
+  switch (element_width) {
+    case 1:
+      StoreVectorStrided<uint8_t>(vlength, vstart, emul, inst);
+      break;
+    case 2:
+      StoreVectorStrided<uint16_t>(vlength, vstart, emul, inst);
+      break;
+    case 4:
+      StoreVectorStrided<uint32_t>(vlength, vstart, emul, inst);
+      break;
+    case 8:
+      StoreVectorStrided<uint64_t>(vlength, vstart, emul, inst);
+      break;
+    default:
+      break;
+  }
+  rv_vector->clear_vstart();
+}
+
+// Store vector mask. Single vector register store.
+// Source(0): store data
+// Source(1): base address
+void Vsm(const Instruction *inst) {
+  auto *state = static_cast<CheriotState *>(inst->state());
+  auto *rv_vector = state->rv_vector();
+  auto cap_reg = GetCapSource(inst, 1);
+  if (!CheckCapForMemoryAccess(inst, cap_reg, state)) return;
+  uint64_t base = cap_reg->address();
+  int start = rv_vector->vstart();
+  // Compute the number of bytes to be stored.
+  int num_bytes = rv_vector->vector_register_byte_length();
+  int num_bytes_stored = num_bytes - start;
+  // Allocate data buffers.
+  auto *db_factory = inst->state()->db_factory();
+  auto *address_db = db_factory->Allocate<uint64_t>(num_bytes_stored);
+  auto *store_data_db = db_factory->Allocate(num_bytes_stored);
+  auto *mask_db = db_factory->Allocate<bool>(num_bytes_stored);
+  // Get the spans for addresses, masks, and store data.
+  auto addresses = address_db->Get<uint64_t>();
+  auto masks = mask_db->Get<bool>();
+  auto store_data = store_data_db->Get<uint8_t>();
+  // Set up addresses and store data; all mask elements are true.
+  for (int i = start; i < num_bytes; i++) {
+    addresses[i - start] = base + i;
+    masks[i - start] = true;
+    store_data[i - start] = GetInstructionSource<uint8_t>(inst, 0, i);
+    if (!CheckCapBounds(inst, addresses[i - start], sizeof(uint8_t), cap_reg,
+                        state)) {
+      address_db->DecRef();
+      mask_db->DecRef();
+      store_data_db->DecRef();
+      return;
+    }
+  }
+  state->StoreMemory(inst, address_db, mask_db, sizeof(uint8_t), store_data_db);
+  address_db->DecRef();
+  mask_db->DecRef();
+  store_data_db->DecRef();
+  rv_vector->clear_vstart();
+}
+
+// Vector store indexed. Index values are not scaled by
+// element size, as the index values can also be treated as multiple base
+// addresses with the base address acting as a common offset. Index values are
+// treated as unsigned integers, and are zero extended from the element size to
+// the internal address size (or truncated in case the internal XLEN is < index
+// element size).
+// Source(0): store data.
+// Source(1): base address.
+// Source(2): offset vector.
+// Source(3): mask.
+void VsIndexed(int index_width, const Instruction *inst) {
+  auto *state = static_cast<CheriotState *>(inst->state());
+  auto *rv_vector = state->rv_vector();
+  auto cap_reg = GetCapSource(inst, 1);
+  if (!CheckCapForMemoryAccess(inst, cap_reg, state)) return;
+  uint64_t base = cap_reg->address();
+  int start = rv_vector->vstart();
+  int num_elements = rv_vector->vector_length();
+  int num_elements_stored = num_elements - start;
+  int element_width = rv_vector->selected_element_width();
+  int lmul8 = rv_vector->vector_length_multiplier();
+  int index_emul = index_width * lmul8 / element_width;
+  // Validate that index_emul has a legal value.
+  if ((index_emul > 64) || (index_emul == 0)) {
+    // TODO: signal vector error.
+    rv_vector->set_vector_exception();
+    return;
+  }
+
+  auto *index_op = static_cast<RV32VectorSourceOperand *>(inst->Source(2));
+
+  // Allocate data buffers.
+  auto *db_factory = inst->state()->db_factory();
+  auto *address_db = db_factory->Allocate<uint64_t>(num_elements_stored);
+  auto *value_db = db_factory->Allocate(num_elements_stored * element_width);
+  auto *mask_db = db_factory->Allocate<bool>(num_elements_stored);
+
+  // Get the source mask (stored in a single vector register).
+  auto *src_mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(3));
+  auto src_masks = src_mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
+
+  // Get the spans for addresses and masks.
+  auto masks = mask_db->Get<bool>();
+  auto addresses = address_db->Get<uint64_t>();
+
+  // Convert the bit masks to byte masks and compute the element addresses.
+  for (int i = start; i < num_elements; i++) {
+    int mask_index = i >> 3;
+    int mask_offset = i & 0b111;
+    uint64_t offset;
+    switch (index_width) {
+      case 1:
+        offset = index_op->AsUint8(i);
+        break;
+      case 2:
+        offset = index_op->AsUint16(i);
+        break;
+      case 4:
+        offset = index_op->AsUint32(i);
+        break;
+      case 8:
+        offset = index_op->AsUint64(i);
+        break;
+      default:
+        // TODO: signal error.
+        LOG(ERROR) << "Illegal value for index type width";
+        address_db->DecRef();
+        mask_db->DecRef();
+        value_db->DecRef();
+        return;
+    }
+    addresses[i - start] = base + offset;
+    masks[i - start] = ((src_masks[mask_index] >> mask_offset) & 0b1) != 0;
+    switch (element_width) {
+      case 1:
+        value_db->Set<uint8_t>(i - start,
+                               GetInstructionSource<uint8_t>(inst, 0, i));
+        break;
+      case 2:
+        value_db->Set<uint16_t>(i - start,
+                                GetInstructionSource<uint16_t>(inst, 0, i));
+        break;
+      case 4:
+        value_db->Set<uint32_t>(i - start,
+                                GetInstructionSource<uint32_t>(inst, 0, i));
+        break;
+      case 8:
+        value_db->Set<uint64_t>(i - start,
+                                GetInstructionSource<uint64_t>(inst, 0, i));
+        break;
+      default:
+        // TODO: signal error.
+        LOG(ERROR) << "Illegal value for element width";
+        break;
+    }
+    if (masks[i - start]) {
+      if (!CheckCapBounds(inst, addresses[i - start], element_width, cap_reg,
+                          state)) {
+        address_db->DecRef();
+        mask_db->DecRef();
+        value_db->DecRef();
+        return;
+      }
+    }
+  }
+
+  // Submit the store.
+  state->StoreMemory(inst, address_db, mask_db, element_width, value_db);
+  address_db->DecRef();
+  mask_db->DecRef();
+  value_db->DecRef();
+  rv_vector->clear_vstart();
+}
+
+// Store whole vector register(s). The number of registers is bound at decode.
+// Source(0): store data.
+// Source(1): base address.
+void VsRegister(int num_regs, const Instruction *inst) {
+  auto *state = static_cast<CheriotState *>(inst->state());
+  auto *rv_vector = state->rv_vector();
+  auto cap_reg = GetCapSource(inst, 1);
+  if (!CheckCapForMemoryAccess(inst, cap_reg, state)) return;
+  uint64_t base = cap_reg->address();
+  int num_elements =
+      rv_vector->vector_register_byte_length() * num_regs / sizeof(uint64_t);
+  // Allocate data buffers.
+  auto *db_factory = inst->state()->db_factory();
+  auto *data_db = db_factory->Allocate<uint64_t>(num_elements);
+  auto *address_db = db_factory->Allocate<uint64_t>(num_elements);
+  auto *mask_db = db_factory->Allocate<bool>(num_elements);
+  // Get the address, mask, and data spans.
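+  // Note: whole-register stores are modeled as 64-bit accesses regardless of
+  // SEW, hence the uint64_t element size used for the buffers and addresses.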
+  auto addresses = address_db->Get<uint64_t>();
+  auto masks = mask_db->Get<bool>();
+  auto data = data_db->Get<uint64_t>();
+  for (int i = 0; i < num_elements; i++) {
+    addresses[i] = base + i * sizeof(uint64_t);
+    masks[i] = true;
+    data[i] = GetInstructionSource<uint64_t>(inst, 0, i);
+    if (!CheckCapBounds(inst, addresses[i], sizeof(uint64_t), cap_reg, state)) {
+      address_db->DecRef();
+      mask_db->DecRef();
+      data_db->DecRef();
+      return;
+    }
+  }
+  // Submit the store.
+  state->StoreMemory(inst, address_db, mask_db, sizeof(uint64_t), data_db);
+  address_db->DecRef();
+  mask_db->DecRef();
+  data_db->DecRef();
+  rv_vector->clear_vstart();
+}
+
+// Vector store segment (unit stride). This stores the segments contiguously
+// in memory in a sequential manner.
+void VsSegment(int element_width, const Instruction *inst) {
+  auto *state = static_cast<CheriotState *>(inst->state());
+  auto *rv_vector = state->rv_vector();
+  auto cap_reg = GetCapSource(inst, 1);
+  if (!CheckCapForMemoryAccess(inst, cap_reg, state)) return;
+  uint64_t base_address = cap_reg->address();
+  int start = rv_vector->vstart();
+  auto src_mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(2));
+  auto src_masks = src_mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
+  int num_fields = GetInstructionSource<int32_t>(inst, 3) + 1;
+  // Effective vector length multiplier.
+  int emul = (element_width * rv_vector->vector_length_multiplier()) /
+             rv_vector->selected_element_width();
+  if (emul * num_fields > 64) {
+    // This is a reserved encoding error.
+    // If > 64, it means that the number of registers required is > 8.
+    // TODO: signal error.
+    LOG(ERROR) << "Reserved encoding error";
+    rv_vector->set_vector_exception();
+    return;
+  }
+  int num_segments = rv_vector->vector_length();
+  int num_elements = num_fields * num_segments;
+  int num_elements_per_reg =
+      rv_vector->vector_register_byte_length() / element_width;
+  int reg_mul = std::max(1, emul / 8);
+  // Set up data buffers.
+  auto *db_factory = inst->state()->db_factory();
+  auto *data_db = db_factory->Allocate(num_elements * element_width);
+  auto *address_db = db_factory->Allocate<uint64_t>(num_elements);
+  auto *mask_db = db_factory->Allocate<bool>(num_elements);
+  // Get spans for addresses and masks.
+  auto addresses = address_db->Get<uint64_t>();
+  auto masks = mask_db->Get<bool>();
+  auto data1 = data_db->Get<uint8_t>();
+  auto data2 = data_db->Get<uint16_t>();
+  auto data4 = data_db->Get<uint32_t>();
+  auto data8 = data_db->Get<uint64_t>();
+
+  auto *data_op = static_cast<RV32VectorSourceOperand *>(inst->Source(0));
+  uint64_t address = base_address;
+  int count = 0;
+  for (int segment = start; segment < num_segments; segment++) {
+    // Masks are applied on a segment basis.
+    int mask_index = segment >> 3;
+    int mask_offset = segment & 0b111;
+    bool mask_value = ((src_masks[mask_index] >> mask_offset) & 0x1) != 0;
+    // If the segments span multiple registers, compute the register offset
+    // from the current segment number (upper bits).
+    int reg_offset = segment / num_elements_per_reg;
+    for (int field = 0; field < num_fields; field++) {
+      // Compute register offset number within register group.
+      int reg_no = field * reg_mul + reg_offset;
+      // Compute element address and set mask value.
+      addresses[count] = address;
+      address += element_width;
+      masks[count] = mask_value;
+      if (!mask_value) {
+        // If mask is false, just increment count and go to next field.
+ count++; + continue; + } + if (!CheckCapBounds(inst, addresses[count], element_width, cap_reg, + state)) { + address_db->DecRef(); + mask_db->DecRef(); + data_db->DecRef(); + return; + } + // Write store data from register db to data db. + auto *reg_db = data_op->GetRegister(reg_no)->data_buffer(); + switch (element_width) { + case 1: + data1[count] = reg_db->Get<uint8_t>(segment % num_elements_per_reg); + break; + case 2: + data2[count] = reg_db->Get<uint16_t>(segment % num_elements_per_reg); + break; + case 4: + data4[count] = reg_db->Get<uint32_t>(segment % num_elements_per_reg); + break; + case 8: + data8[count] = reg_db->Get<uint64_t>(segment % num_elements_per_reg); + break; + default: + break; + } + count++; + } + } + state->StoreMemory(inst, address_db, mask_db, element_width, data_db); + // Release the dbs. + address_db->DecRef(); + mask_db->DecRef(); + data_db->DecRef(); + rv_vector->clear_vstart(); +} + +// Vector strided segment store. This stores each segment contiguously at +// locations separated by the segment stride. +void VsSegmentStrided(int element_width, const Instruction *inst) { + auto *state = static_cast<CheriotState *>(inst->state()); + auto *rv_vector = state->rv_vector(); + auto cap_reg = GetCapSource(inst, 1); + if (!CheckCapForMemoryAccess(inst, cap_reg, state)) return; + uint64_t base_address = cap_reg->address(); + int start = rv_vector->vstart(); + int64_t segment_stride = GetInstructionSource<int64_t>(inst, 2); + auto src_mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(3)); + auto src_masks = src_mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); + int num_fields = GetInstructionSource<int32_t>(inst, 4) + 1; + // Effective vector length multiplier. + int emul = (element_width * rv_vector->vector_length_multiplier()) / + rv_vector->selected_element_width(); + if (emul * num_fields > 64) { + // This is a reserved encoding error. + // If > 64, it means that the number of registers required is > 8. + // TODO: signal error. + LOG(ERROR) << "Reserved encoding error"; + rv_vector->set_vector_exception(); + return; + } + int num_segments = rv_vector->vector_length(); + int num_elements = num_fields * num_segments; + int num_elements_per_reg = + rv_vector->vector_register_byte_length() / element_width; + int reg_mul = std::max(1, emul / 8); + // Set up data buffers. + auto *db_factory = inst->state()->db_factory(); + auto *data_db = db_factory->Allocate(num_elements * element_width); + auto *address_db = db_factory->Allocate<uint64_t>(num_elements); + auto *mask_db = db_factory->Allocate<bool>(num_elements); + // Get spans for addresses and masks. + auto addresses = address_db->Get<uint64_t>(); + auto masks = mask_db->Get<bool>(); + auto data1 = data_db->Get<uint8_t>(); + auto data2 = data_db->Get<uint16_t>(); + auto data4 = data_db->Get<uint32_t>(); + auto data8 = data_db->Get<uint64_t>(); + + auto *data_op = static_cast<RV32VectorSourceOperand *>(inst->Source(0)); + uint64_t segment_address = base_address; + int count = 0; + for (int segment = start; segment < num_segments; segment++) { + // Masks are applied on a segment basis. + int mask_index = segment >> 3; + int mask_offset = segment & 0b111; + bool mask_value = ((src_masks[mask_index] >> mask_offset) & 0x1) != 0; + // If the segments span multiple registers, compute the register offset + // from the current segment number (upper bits). 
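+    // For example (illustrative): with 4 elements per register, segment 9 maps
+    // to register 2 of each field's register group, at element index 1
+    // (9 / 4 = 2, 9 % 4 = 1).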
+ int reg_offset = segment / num_elements_per_reg; + uint64_t field_address = segment_address; + for (int field = 0; field < num_fields; field++) { + // Compute register offset number within register group. + int reg_no = field * reg_mul + reg_offset; + // Compute element address and set mask value. + addresses[count] = field_address; + field_address += element_width; + masks[count] = mask_value; + if (!mask_value) { + // If mask is false, just increment count and go to next field. + count++; + continue; + } + if (!CheckCapBounds(inst, addresses[count], element_width, cap_reg, + state)) { + address_db->DecRef(); + mask_db->DecRef(); + data_db->DecRef(); + return; + } + // Write store data from register db to data db. + auto *reg_db = data_op->GetRegister(reg_no)->data_buffer(); + switch (element_width) { + case 1: + data1[count] = reg_db->Get<uint8_t>(segment % num_elements_per_reg); + break; + case 2: + data2[count] = reg_db->Get<uint16_t>(segment % num_elements_per_reg); + break; + case 4: + data4[count] = reg_db->Get<uint32_t>(segment % num_elements_per_reg); + break; + case 8: + data8[count] = reg_db->Get<uint64_t>(segment % num_elements_per_reg); + break; + default: + break; + } + count++; + } + segment_address += segment_stride; + } + state->StoreMemory(inst, address_db, mask_db, element_width, data_db); + // Release the dbs. + address_db->DecRef(); + mask_db->DecRef(); + data_db->DecRef(); + rv_vector->clear_vstart(); +} + +// Vector indexed segment store. This instruction stores each segment +// contiguously at an address formed by adding the index value for that +// segment (from the index vector source operand) to the base address. +void VsSegmentIndexed(int index_width, const Instruction *inst) { + auto *state = static_cast<CheriotState *>(inst->state()); + auto *rv_vector = state->rv_vector(); + auto cap_reg = GetCapSource(inst, 1); + if (!CheckCapForMemoryAccess(inst, cap_reg, state)) return; + uint64_t base_address = cap_reg->address(); + int start = rv_vector->vstart(); + auto src_mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(3)); + auto src_masks = src_mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); + int num_fields = GetInstructionSource<int32_t>(inst, 4) + 1; + int element_width = rv_vector->selected_element_width(); + // Effective vector length multiplier. + int lmul = rv_vector->vector_length_multiplier(); + int emul = index_width * lmul / element_width; + if (lmul * num_fields > 64) { + // This is a reserved encoding error. + // If > 64, it means that the number of registers required is > 8. + // TODO: signal error. + LOG(ERROR) << "Reserved encoding error - lmul * num_fields out of range"; + rv_vector->set_vector_exception(); + return; + } + if (emul == 0 || emul > 64) { + // This is a reserved encoding error. + // If > 64, it means that the number of registers required is > 8. + // TODO: signal error. + LOG(ERROR) << "Reserved encoding error - emul out of range."; + rv_vector->set_vector_exception(); + return; + } + int num_segments = rv_vector->vector_length(); + int num_elements = num_fields * num_segments; + int num_elements_per_reg = + rv_vector->vector_register_byte_length() / element_width; + int reg_mul = std::max(1, lmul / 8); + // Set up data buffers. + auto *db_factory = inst->state()->db_factory(); + auto *data_db = db_factory->Allocate(num_elements * element_width); + auto *address_db = db_factory->Allocate<uint64_t>(num_elements); + auto *mask_db = db_factory->Allocate<bool>(num_elements); + // Get spans for addresses and masks. 
+ auto addresses = address_db->Get<uint64_t>(); + auto masks = mask_db->Get<bool>(); + auto data1 = data_db->Get<uint8_t>(); + auto data2 = data_db->Get<uint16_t>(); + auto data4 = data_db->Get<uint32_t>(); + auto data8 = data_db->Get<uint64_t>(); + + auto *data_op = static_cast<RV32VectorSourceOperand *>(inst->Source(0)); + int count = 0; + for (int segment = start; segment < num_segments; segment++) { + // Masks are applied on a segment basis. + int mask_index = segment >> 3; + int mask_offset = segment & 0b111; + bool mask_value = ((src_masks[mask_index] >> mask_offset) & 0x1) != 0; + // If the segments span multiple registers, compute the register offset + // from the current segment number (upper bits). + int reg_offset = segment / num_elements_per_reg; + int64_t index_value; + switch (index_width) { + case 1: + index_value = GetInstructionSource<int8_t>(inst, 2, segment); + break; + case 2: + index_value = GetInstructionSource<int16_t>(inst, 2, segment); + break; + case 4: + index_value = GetInstructionSource<int32_t>(inst, 2, segment); + break; + case 8: + index_value = GetInstructionSource<int64_t>(inst, 2, segment); + break; + default: + LOG(ERROR) << "Invalid index width: " << index_width << "."; + rv_vector->set_vector_exception(); + return; + } + uint64_t field_address = base_address + index_value; + for (int field = 0; field < num_fields; field++) { + // Compute register offset number within register group. + int reg_no = field * reg_mul + reg_offset; + // Compute element address and set mask value. + addresses[count] = field_address; + field_address += element_width; + masks[count] = mask_value; + if (!mask_value) { + // If mask is false, just increment count and go to next field. + count++; + continue; + } + if (!CheckCapBounds(inst, addresses[count], element_width, cap_reg, + state)) { + address_db->DecRef(); + mask_db->DecRef(); + data_db->DecRef(); + return; + } + // Write store data from register db to data db. + auto *reg_db = data_op->GetRegister(reg_no)->data_buffer(); + switch (element_width) { + case 1: + data1[count] = reg_db->Get<uint8_t>(segment % num_elements_per_reg); + break; + case 2: + data2[count] = reg_db->Get<uint16_t>(segment % num_elements_per_reg); + break; + case 4: + data4[count] = reg_db->Get<uint32_t>(segment % num_elements_per_reg); + break; + case 8: + data8[count] = reg_db->Get<uint64_t>(segment % num_elements_per_reg); + break; + default: + LOG(ERROR) << "Invalid element width: " << element_width << "."; + return; + } + count++; + } + } + state->StoreMemory(inst, address_db, mask_db, element_width, data_db); + // Release the dbs. + address_db->DecRef(); + mask_db->DecRef(); + data_db->DecRef(); + rv_vector->clear_vstart(); +} + +} // namespace cheriot +} // namespace sim +} // namespace mpact
diff --git a/cheriot/riscv_cheriot_vector_memory_instructions.h b/cheriot/riscv_cheriot_vector_memory_instructions.h new file mode 100644 index 0000000..ab3cb41 --- /dev/null +++ b/cheriot/riscv_cheriot_vector_memory_instructions.h
@@ -0,0 +1,140 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MPACT_CHERIOT_RISCV_CHERIOT_VECTOR_MEMORY_INSTRUCTIONS_H_
+#define MPACT_CHERIOT_RISCV_CHERIOT_VECTOR_MEMORY_INSTRUCTIONS_H_
+
+#include "mpact/sim/generic/instruction.h"
+
+// This file declares the semantic functions used to implement RiscV vector
+// load store instructions.
+
+namespace mpact {
+namespace sim {
+namespace cheriot {
+
+using Instruction = ::mpact::sim::generic::Instruction;
+
+// Set vector length. rd_zero/rs1_zero is true if the corresponding operand is
+// register x0.
+// The instruction takes 2 instruction source scalar operands: source operand 0
+// is the requested vector length, source operand 1 is the requested vector
+// configuration value. The destination operand is a scalar.
+void Vsetvl(bool rd_zero, bool rs1_zero, const Instruction *inst);
+
+// Vector load semantic functions.
+// Load with unit stride (the stride is the element width). The instruction
+// takes 2 source and 1 destination operands. Source operand 0 is a scalar base
+// address, source operand 1 is the vector mask (either a vector register, or a
+// constant). Destination operand 0 is assigned to the child instruction and is
+// a vector register (group).
+void VlUnitStrided(int element_width, const Instruction *inst);
+// Load with constant stride, the parameter specifies the width of the vector
+// elements. This instruction takes 3 source and 1 destination operands. Source
+// operand 0 is a scalar base address, source operand 1 is a scalar stride,
+// source operand 2 is the vector mask (either a vector register, or a
+// constant). Destination operand 0 is assigned to the child instruction and
+// is a vector register (group).
+void VlStrided(int element_width, const Instruction *inst);
+// Load vector mask. This instruction takes 1 source and 1 destination operand.
+// The source operand is a scalar base address, the destination operand is the
+// vector register to write the mask to.
+void Vlm(const Instruction *inst);
+// Indexed vector load (ordered and unordered). This instruction takes 3 source
+// and 1 destination operands. Source operand 0 is a scalar base address, source
+// operand 1 is a vector register (group) of indices, source operand 2 is the
+// vector mask. Destination operand 0 is assigned to the child instruction and
+// is a vector register (group).
+void VlIndexed(int index_width, const Instruction *inst);
+// Load vector register(s). Takes a parameter specifying how many registers to
+// load. This instruction takes 1 source and 1 destination operand. Source
+// operand 0 is a scalar base address. Destination operand 0 is assigned to the
+// child instruction and is a vector register (group).
+void VlRegister(int num_regs, int element_width_bytes, const Instruction *inst);
+// Child instruction semantic function for non-segment loads, responsible for
+// writing load data back to the target register(s). It takes a single
+// destination operand; destination operand 0 is a vector register (group).
+void VlChild(const Instruction *inst);
+// Load segment, unit stride. The function takes one parameter that specifies
+// the element width. The instruction takes 3 source operands and 1 destination
+// operand. Source operand 0 is a scalar base address, source operand 1 is
+// the vector mask, and source operand 2 is the number of fields - 1.
+// Destination operand 0 is assigned to the child instruction and is a vector
+// register (group).
+void VlSegment(int element_width, const Instruction *inst);
+// Load segment strided. The function takes one parameter that specifies
+// the element width. The instruction takes 4 source operands and 1 destination
+// operand. Source operand 0 is a scalar base address, source operand 1 is a
+// scalar stride, source operand 2 is the vector mask, and source operand 3 is
+// the number of fields - 1. Destination operand 0 is assigned to the child
+// instruction and is a vector register (group).
+void VlSegmentStrided(int element_width, const Instruction *inst);
+// Load segment indexed. The function takes one parameter that specifies
+// the index element width. The instruction takes 4 source operands and 1
+// destination operand. Source operand 0 is a scalar base address, source
+// operand 1 is a vector register (group) of indices, source operand 2 is the
+// vector mask, and source operand 3 is the number of fields - 1. Destination
+// operand 0 is assigned to the child instruction and is a vector register
+// (group).
+void VlSegmentIndexed(int index_width, const Instruction *inst);
+// Child instruction semantic function for segment loads, responsible for
+// writing load data back to the target register(s). It takes a single
+// destination operand; destination operand 0 is a vector register (group).
+void VlSegmentChild(int element_width, const Instruction *inst);
+
+// Vector store semantic functions.
+
+// Store strided. The function takes one parameter that specifies the element
+// width. The instruction takes 4 source operands. Source 0 is the store data
+// vector register (group), source 1 is the scalar base address, source 2 is the
+// stride, and source 3 is the vector mask.
+void VsStrided(int element_width, const Instruction *inst);
+// Store vector mask. This instruction takes 2 source operands. Source 0 is the
+// vector mask register to be stored, the second is the scalar base address.
+void Vsm(const Instruction *inst);
+// Store indexed. The function takes one parameter that specifies the index
+// element width. The instruction takes 4 source operands. Source 0 is the
+// store data vector register (group), source 1 is the scalar base address,
+// source 2 is a vector register (group) of indices, and source 3 is the vector
+// mask.
+void VsIndexed(int index_width, const Instruction *inst);
+// Store vector register (group). This function takes one parameter that
+// specifies the number of registers to store. The instruction takes 2 source
+// operands. Source 0 is the source vector register (group), the second is the
+// scalar base address.
+void VsRegister(int num_regs, const Instruction *inst);
+// Store segment, unit stride. The function takes one parameter that specifies
+// the element width. The instruction takes 4 source operands. Source operand 0
+// is the store data, source operand 1 is the scalar base address, source
+// operand 2 is the vector mask, and source operand 3 is the number of
+// fields - 1.
+void VsSegment(int element_width, const Instruction *inst);
+// Store segment, strided. The function takes one parameter that specifies
+// the element width. The instruction takes 5 source operands. Source operand 0
+// is the store data, source operand 1 is the scalar base address, source
+// operand 2 is the segment stride, source operand 3 is the vector mask, and
+// source operand 4 is the number of fields - 1.
+void VsSegmentStrided(int element_width, const Instruction *inst);
+// Store segment indexed. The function takes one parameter that specifies
+// the index element width. The instruction takes 5 source operands. Source
+// operand 0 is the store data, source operand 1 is a scalar base address,
+// source operand 2 is a vector register (group) of indices, source operand 3
+// is the vector mask, and source operand 4 is the number of fields - 1.
+void VsSegmentIndexed(int index_width, const Instruction *inst);
+
+}  // namespace cheriot
+}  // namespace sim
+}  // namespace mpact
+
+#endif  // MPACT_CHERIOT_RISCV_CHERIOT_VECTOR_MEMORY_INSTRUCTIONS_H_
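For context on the element_width/num_regs parameters declared above: they are
intended to be fixed per opcode when the decoder binds semantic functions to
instructions, so one function body serves several encodings. A minimal sketch
of such a binding, assuming absl::bind_front is available; the opcode names
and the SemanticFunction alias here are illustrative, not taken from this
patch:

    #include <functional>

    #include "absl/functional/bind_front.h"
    #include "mpact/sim/generic/instruction.h"

    using ::mpact::sim::generic::Instruction;
    // Hypothetical alias; the generated decoder defines its own callable type.
    using SemanticFunction = std::function<void(const Instruction *)>;

    // The element width (in bytes) is implied by the opcode, so e.g. vle8.v
    // and vle32.v can share the VlUnitStrided body with different bound
    // widths.
    SemanticFunction vle8_fn = absl::bind_front(
        &mpact::sim::cheriot::VlUnitStrided, /*element_width=*/1);
    SemanticFunction vle32_fn = absl::bind_front(
        &mpact::sim::cheriot::VlUnitStrided, /*element_width=*/4);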
diff --git a/cheriot/riscv_cheriot_vector_opi_instructions.cc b/cheriot/riscv_cheriot_vector_opi_instructions.cc new file mode 100644 index 0000000..e2fa037 --- /dev/null +++ b/cheriot/riscv_cheriot_vector_opi_instructions.cc
@@ -0,0 +1,1411 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cheriot/riscv_cheriot_vector_opi_instructions.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <cstring>
+#include <limits>
+#include <type_traits>
+
+#include "absl/log/log.h"
+#include "cheriot/cheriot_state.h"
+#include "cheriot/cheriot_vector_state.h"
+#include "cheriot/riscv_cheriot_vector_instruction_helpers.h"
+#include "mpact/sim/generic/type_helpers.h"
+#include "riscv//riscv_register.h"
+
+// This file contains the instruction semantic functions for most of the
+// vector instructions in the OPIVV, OPIVX, and OPIVI encoding spaces. The
+// exceptions are the vector element permute instructions and a couple of
+// reduction instructions.
+
+namespace mpact {
+namespace sim {
+namespace cheriot {
+
+using ::mpact::sim::generic::MakeUnsigned;
+using ::mpact::sim::generic::WideType;
+using riscv::RV32VectorDestinationOperand;
+using riscv::RV32VectorSourceOperand;
+using std::numeric_limits;
+
+// Vector arithmetic operations.
+
+// Vector add.
+void Vadd(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>(
+          rv_vector, inst,
+          [](uint8_t vs2, uint8_t vs1) -> uint8_t { return vs2 + vs1; });
+    case 2:
+      return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>(
+          rv_vector, inst,
+          [](uint16_t vs2, uint16_t vs1) -> uint16_t { return vs2 + vs1; });
+    case 4:
+      return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>(
+          rv_vector, inst,
+          [](uint32_t vs2, uint32_t vs1) -> uint32_t { return vs2 + vs1; });
+    case 8:
+      return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>(
+          rv_vector, inst,
+          [](uint64_t vs2, uint64_t vs1) -> uint64_t { return vs2 + vs1; });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Vector subtract.
+void Vsub(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>(
+          rv_vector, inst,
+          [](uint8_t vs2, uint8_t vs1) -> uint8_t { return vs2 - vs1; });
+    case 2:
+      return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>(
+          rv_vector, inst,
+          [](uint16_t vs2, uint16_t vs1) -> uint16_t { return vs2 - vs1; });
+    case 4:
+      return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>(
+          rv_vector, inst,
+          [](uint32_t vs2, uint32_t vs1) -> uint32_t { return vs2 - vs1; });
+    case 8:
+      return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>(
+          rv_vector, inst,
+          [](uint64_t vs2, uint64_t vs1) -> uint64_t { return vs2 - vs1; });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Vector reverse subtract.
+void Vrsub(Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 1: + return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>( + rv_vector, inst, + [](uint8_t vs2, uint8_t vs1) -> uint8_t { return vs1 - vs2; }); + case 2: + return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>( + rv_vector, inst, + [](uint16_t vs2, uint16_t vs1) -> uint16_t { return vs1 - vs2; }); + case 4: + return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>( + rv_vector, inst, + [](uint32_t vs2, uint32_t vs1) -> uint32_t { return vs1 - vs2; }); + case 8: + return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>( + rv_vector, inst, + [](uint64_t vs2, uint64_t vs1) -> uint64_t { return vs1 - vs2; }); + default: + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal SEW value"; + return; + } +} + +// Vector logical operations. + +// Vector and. +void Vand(Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 1: + return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>( + rv_vector, inst, + [](uint8_t vs2, uint8_t vs1) -> uint8_t { return vs2 & vs1; }); + case 2: + return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>( + rv_vector, inst, + [](uint16_t vs2, uint16_t vs1) -> uint16_t { return vs2 & vs1; }); + case 4: + return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>( + rv_vector, inst, + [](uint32_t vs2, uint32_t vs1) -> uint32_t { return vs2 & vs1; }); + case 8: + return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>( + rv_vector, inst, + [](uint64_t vs2, uint64_t vs1) -> uint64_t { return vs2 & vs1; }); + default: + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal SEW value"; + return; + } +} + +// Vector or. +void Vor(Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 1: + return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>( + rv_vector, inst, + [](uint8_t vs2, uint8_t vs1) -> uint8_t { return vs2 | vs1; }); + case 2: + return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>( + rv_vector, inst, + [](uint16_t vs2, uint16_t vs1) -> uint16_t { return vs2 | vs1; }); + case 4: + return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>( + rv_vector, inst, + [](uint32_t vs2, uint32_t vs1) -> uint32_t { return vs2 | vs1; }); + case 8: + return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>( + rv_vector, inst, + [](uint64_t vs2, uint64_t vs1) -> uint64_t { return vs2 | vs1; }); + default: + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal SEW value"; + return; + } +} + +// Vector xor. 
+void Vxor(Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 1: + return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>( + rv_vector, inst, + [](uint8_t vs2, uint8_t vs1) -> uint8_t { return vs2 ^ vs1; }); + case 2: + return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>( + rv_vector, inst, + [](uint16_t vs2, uint16_t vs1) -> uint16_t { return vs2 ^ vs1; }); + case 4: + return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>( + rv_vector, inst, + [](uint32_t vs2, uint32_t vs1) -> uint32_t { return vs2 ^ vs1; }); + case 8: + return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>( + rv_vector, inst, + [](uint64_t vs2, uint64_t vs1) -> uint64_t { return vs2 ^ vs1; }); + default: + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal SEW value"; + return; + } +} + +// Vector shift operations. + +// Vector shift left logical. +void Vsll(Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 1: + return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>( + rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint8_t { + return vs2 << (vs1 & 0b111); + }); + case 2: + return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>( + rv_vector, inst, [](uint16_t vs2, uint16_t vs1) -> uint16_t { + return vs2 << (vs1 & 0b1111); + }); + case 4: + return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>( + rv_vector, inst, [](uint32_t vs2, uint32_t vs1) -> uint32_t { + return vs2 << (vs1 & 0b1'1111); + }); + case 8: + return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>( + rv_vector, inst, [](uint64_t vs2, uint64_t vs1) -> uint64_t { + return vs2 << (vs1 & 0b11'1111); + }); + default: + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal SEW value"; + return; + } +} + +// Vector shift right logical. +void Vsrl(Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 1: + return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>( + rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint8_t { + return vs2 >> (vs1 & 0b111); + }); + case 2: + return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>( + rv_vector, inst, [](uint16_t vs2, uint16_t vs1) -> uint16_t { + return vs2 >> (vs1 & 0b1111); + }); + case 4: + return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>( + rv_vector, inst, [](uint32_t vs2, uint32_t vs1) -> uint32_t { + return vs2 >> (vs1 & 0b1'1111); + }); + case 8: + return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>( + rv_vector, inst, [](uint64_t vs2, uint64_t vs1) -> uint64_t { + return vs2 >> (vs1 & 0b11'1111); + }); + default: + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal SEW value"; + return; + } +} + +// Vector shift right arithmetic. 
+void Vsra(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVBinaryVectorOp<int8_t, int8_t, int8_t>(
+          rv_vector, inst, [](int8_t vs2, int8_t vs1) -> int8_t {
+            return vs2 >> (vs1 & 0b111);
+          });
+    case 2:
+      return RiscVBinaryVectorOp<int16_t, int16_t, int16_t>(
+          rv_vector, inst, [](int16_t vs2, int16_t vs1) -> int16_t {
+            return vs2 >> (vs1 & 0b1111);
+          });
+    case 4:
+      return RiscVBinaryVectorOp<int32_t, int32_t, int32_t>(
+          rv_vector, inst, [](int32_t vs2, int32_t vs1) -> int32_t {
+            return vs2 >> (vs1 & 0b1'1111);
+          });
+    case 8:
+      return RiscVBinaryVectorOp<int64_t, int64_t, int64_t>(
+          rv_vector, inst, [](int64_t vs2, int64_t vs1) -> int64_t {
+            return vs2 >> (vs1 & 0b11'1111);
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Vector narrowing shift operations. Narrow from sew * 2 to sew.
+
+// Vector narrowing shift right logical. Source op 0 is shifted right
+// by source op 1 and the result is 1/2 the size of source op 0.
+void Vnsrl(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  // If the vector length multiplier (x8) is greater than 32, the wider
+  // (sew * 2) source operands would exceed the available register group.
+  if (rv_vector->vector_length_multiplier() > 32) {
+    rv_vector->set_vector_exception();
+    LOG(ERROR) << "Vector length multiplier out of range for narrowing shift";
+    return;
+  }
+  switch (sew) {
+    case 1:
+      return RiscVBinaryVectorOp<uint8_t, uint16_t, uint8_t>(
+          rv_vector, inst, [](uint16_t vs2, uint8_t vs1) -> uint8_t {
+            return static_cast<uint8_t>(vs2 >> (vs1 & 0b1111));
+          });
+    case 2:
+      return RiscVBinaryVectorOp<uint16_t, uint32_t, uint16_t>(
+          rv_vector, inst, [](uint32_t vs2, uint16_t vs1) -> uint16_t {
+            return static_cast<uint16_t>(vs2 >> (vs1 & 0b1'1111));
+          });
+    case 4:
+      return RiscVBinaryVectorOp<uint32_t, uint64_t, uint32_t>(
+          rv_vector, inst, [](uint64_t vs2, uint32_t vs1) -> uint32_t {
+            return static_cast<uint32_t>(vs2 >> (vs1 & 0b11'1111));
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value for narrowing shift right: " << sew;
+      return;
+  }
+}
+
+// Vector narrowing shift right arithmetic. Source op 0 is shifted right
+// by source op 1 and the result is 1/2 the size of source op 0.
+void Vnsra(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  // If the vector length multiplier (x8) is greater than 32, that means that
+  // the source values (sew * 2) would exceed the available register group.
+  if (rv_vector->vector_length_multiplier() > 32) {
+    rv_vector->set_vector_exception();
+    LOG(ERROR) << "Vector length multiplier out of range for narrowing shift";
+    return;
+  }
+  // Note, sew cannot be 64 bits, as there is no support for operations on
+  // 128 bit quantities.
+ switch (sew) { + case 1: + return RiscVBinaryVectorOp<int8_t, int16_t, int8_t>( + rv_vector, inst, [](int16_t vs2, int8_t vs1) -> int8_t { + return vs2 >> (vs1 & 0b1111); + }); + case 2: + return RiscVBinaryVectorOp<int16_t, int32_t, int16_t>( + rv_vector, inst, [](int32_t vs2, int16_t vs1) -> int16_t { + return vs2 >> (vs1 & 0b1'1111); + }); + case 4: + return RiscVBinaryVectorOp<int32_t, int64_t, int32_t>( + rv_vector, inst, [](int64_t vs2, int32_t vs1) -> int32_t { + return vs2 >> (vs1 & 0b11'1111); + }); + default: + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal SEW value for narrowing shift right: " << sew; + return; + } +} + +// Vector unsigned min. +void Vminu(Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 1: + return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>( + rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint8_t { + return std::min(vs2, vs1); + }); + case 2: + return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>( + rv_vector, inst, [](uint16_t vs2, uint16_t vs1) -> uint16_t { + return std::min(vs2, vs1); + }); + case 4: + return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>( + rv_vector, inst, [](uint32_t vs2, uint32_t vs1) -> uint32_t { + return std::min(vs2, vs1); + }); + case 8: + return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>( + rv_vector, inst, [](uint64_t vs2, uint64_t vs1) -> uint64_t { + return std::min(vs2, vs1); + }); + default: + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal SEW value"; + return; + } +} + +// Vector signed min. +void Vmin(Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 1: + return RiscVBinaryVectorOp<int8_t, int8_t, int8_t>( + rv_vector, inst, + [](int8_t vs2, int8_t vs1) -> int8_t { return std::min(vs2, vs1); }); + case 2: + return RiscVBinaryVectorOp<int16_t, int16_t, int16_t>( + rv_vector, inst, [](int16_t vs2, int16_t vs1) -> int16_t { + return std::min(vs2, vs1); + }); + case 4: + return RiscVBinaryVectorOp<int32_t, int32_t, int32_t>( + rv_vector, inst, [](int32_t vs2, int32_t vs1) -> int32_t { + return std::min(vs2, vs1); + }); + case 8: + return RiscVBinaryVectorOp<int64_t, int64_t, int64_t>( + rv_vector, inst, [](int64_t vs2, int64_t vs1) -> int64_t { + return std::min(vs2, vs1); + }); + default: + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal SEW value"; + return; + } +} + +// Vector unsigned max. 
+void Vmaxu(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>(
+          rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint8_t {
+            return std::max(vs2, vs1);
+          });
+    case 2:
+      return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>(
+          rv_vector, inst, [](uint16_t vs2, uint16_t vs1) -> uint16_t {
+            return std::max(vs2, vs1);
+          });
+    case 4:
+      return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>(
+          rv_vector, inst, [](uint32_t vs2, uint32_t vs1) -> uint32_t {
+            return std::max(vs2, vs1);
+          });
+    case 8:
+      return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>(
+          rv_vector, inst, [](uint64_t vs2, uint64_t vs1) -> uint64_t {
+            return std::max(vs2, vs1);
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Vector signed max.
+void Vmax(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVBinaryVectorOp<int8_t, int8_t, int8_t>(
+          rv_vector, inst,
+          [](int8_t vs2, int8_t vs1) -> int8_t { return std::max(vs2, vs1); });
+    case 2:
+      return RiscVBinaryVectorOp<int16_t, int16_t, int16_t>(
+          rv_vector, inst, [](int16_t vs2, int16_t vs1) -> int16_t {
+            return std::max(vs2, vs1);
+          });
+    case 4:
+      return RiscVBinaryVectorOp<int32_t, int32_t, int32_t>(
+          rv_vector, inst, [](int32_t vs2, int32_t vs1) -> int32_t {
+            return std::max(vs2, vs1);
+          });
+    case 8:
+      return RiscVBinaryVectorOp<int64_t, int64_t, int64_t>(
+          rv_vector, inst, [](int64_t vs2, int64_t vs1) -> int64_t {
+            return std::max(vs2, vs1);
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Vector compare instructions.
+
+// Set equal.
+void Vmseq(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVBinaryVectorMaskOp<uint8_t, uint8_t>(
+          rv_vector, inst,
+          [](uint8_t vs2, uint8_t vs1) -> bool { return (vs2 == vs1); });
+    case 2:
+      return RiscVBinaryVectorMaskOp<uint16_t, uint16_t>(
+          rv_vector, inst,
+          [](uint16_t vs2, uint16_t vs1) -> bool { return (vs2 == vs1); });
+    case 4:
+      return RiscVBinaryVectorMaskOp<uint32_t, uint32_t>(
+          rv_vector, inst,
+          [](uint32_t vs2, uint32_t vs1) -> bool { return (vs2 == vs1); });
+    case 8:
+      return RiscVBinaryVectorMaskOp<uint64_t, uint64_t>(
+          rv_vector, inst,
+          [](uint64_t vs2, uint64_t vs1) -> bool { return (vs2 == vs1); });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Set not equal.
+void Vmsne(Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 1: + return RiscVBinaryVectorMaskOp<uint8_t, uint8_t>( + rv_vector, inst, + [](uint8_t vs2, uint8_t vs1) -> bool { return (vs2 != vs1); }); + case 2: + return RiscVBinaryVectorMaskOp<uint16_t, uint16_t>( + rv_vector, inst, + [](uint16_t vs2, uint16_t vs1) -> bool { return (vs2 != vs1); }); + case 4: + return RiscVBinaryVectorMaskOp<uint32_t, uint32_t>( + rv_vector, inst, + [](uint32_t vs2, uint32_t vs1) -> bool { return (vs2 != vs1); }); + case 8: + return RiscVBinaryVectorMaskOp<uint64_t, uint64_t>( + rv_vector, inst, + [](uint64_t vs2, uint64_t vs1) -> bool { return (vs2 != vs1); }); + default: + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal SEW value"; + return; + } +} + +// Set less than unsigned. +void Vmsltu(Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 1: + return RiscVBinaryVectorMaskOp<uint8_t, uint8_t>( + rv_vector, inst, + [](uint8_t vs2, uint8_t vs1) -> bool { return (vs2 < vs1); }); + case 2: + return RiscVBinaryVectorMaskOp<uint16_t, uint16_t>( + rv_vector, inst, + [](uint16_t vs2, uint16_t vs1) -> bool { return (vs2 < vs1); }); + case 4: + return RiscVBinaryVectorMaskOp<uint32_t, uint32_t>( + rv_vector, inst, + [](uint32_t vs2, uint32_t vs1) -> bool { return (vs2 < vs1); }); + case 8: + return RiscVBinaryVectorMaskOp<uint64_t, uint64_t>( + rv_vector, inst, + [](uint64_t vs2, uint64_t vs1) -> bool { return (vs2 < vs1); }); + default: + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal SEW value"; + return; + } +} + +// Set less than. +void Vmslt(Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 1: + return RiscVBinaryVectorMaskOp<int8_t, int8_t>( + rv_vector, inst, + [](int8_t vs2, int8_t vs1) -> bool { return (vs2 < vs1); }); + case 2: + return RiscVBinaryVectorMaskOp<int16_t, int16_t>( + rv_vector, inst, + [](int16_t vs2, int16_t vs1) -> bool { return (vs2 < vs1); }); + case 4: + return RiscVBinaryVectorMaskOp<int32_t, int32_t>( + rv_vector, inst, + [](int32_t vs2, int32_t vs1) -> bool { return (vs2 < vs1); }); + case 8: + return RiscVBinaryVectorMaskOp<int64_t, int64_t>( + rv_vector, inst, + [](int64_t vs2, int64_t vs1) -> bool { return (vs2 < vs1); }); + default: + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal SEW value"; + return; + } +} + +// Set less than or equal unsigned. 
+void Vmsleu(Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 1: + return RiscVBinaryVectorMaskOp<uint8_t, uint8_t>( + rv_vector, inst, + [](uint8_t vs2, uint8_t vs1) -> bool { return (vs2 <= vs1); }); + case 2: + return RiscVBinaryVectorMaskOp<uint16_t, uint16_t>( + rv_vector, inst, + [](uint16_t vs2, uint16_t vs1) -> bool { return (vs2 <= vs1); }); + case 4: + return RiscVBinaryVectorMaskOp<uint32_t, uint32_t>( + rv_vector, inst, + [](uint32_t vs2, uint32_t vs1) -> bool { return (vs2 <= vs1); }); + case 8: + return RiscVBinaryVectorMaskOp<uint64_t, uint64_t>( + rv_vector, inst, + [](uint64_t vs2, uint64_t vs1) -> bool { return (vs2 <= vs1); }); + default: + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal SEW value"; + return; + } +} + +// Set less than or equal. +void Vmsle(Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 1: + return RiscVBinaryVectorMaskOp<int8_t, int8_t>( + rv_vector, inst, + [](int8_t vs2, int8_t vs1) -> bool { return (vs2 <= vs1); }); + case 2: + return RiscVBinaryVectorMaskOp<int16_t, int16_t>( + rv_vector, inst, + [](int16_t vs2, int16_t vs1) -> bool { return (vs2 <= vs1); }); + case 4: + return RiscVBinaryVectorMaskOp<int32_t, int32_t>( + rv_vector, inst, + [](int32_t vs2, int32_t vs1) -> bool { return (vs2 <= vs1); }); + case 8: + return RiscVBinaryVectorMaskOp<int64_t, int64_t>( + rv_vector, inst, + [](int64_t vs2, int64_t vs1) -> bool { return (vs2 <= vs1); }); + default: + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal SEW value"; + return; + } +} + +// Set greater than unsigned. +void Vmsgtu(Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 1: + return RiscVBinaryVectorMaskOp<uint8_t, uint8_t>( + rv_vector, inst, + [](uint8_t vs2, uint8_t vs1) -> bool { return (vs2 > vs1); }); + case 2: + return RiscVBinaryVectorMaskOp<uint16_t, uint16_t>( + rv_vector, inst, + [](uint16_t vs2, uint16_t vs1) -> bool { return (vs2 > vs1); }); + case 4: + return RiscVBinaryVectorMaskOp<uint32_t, uint32_t>( + rv_vector, inst, + [](uint32_t vs2, uint32_t vs1) -> bool { return (vs2 > vs1); }); + case 8: + return RiscVBinaryVectorMaskOp<uint64_t, uint64_t>( + rv_vector, inst, + [](uint64_t vs2, uint64_t vs1) -> bool { return (vs2 > vs1); }); + default: + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal SEW value"; + return; + } +} + +// Set greater than. 
+void Vmsgt(Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 1: + return RiscVBinaryVectorMaskOp<int8_t, int8_t>( + rv_vector, inst, + [](int8_t vs2, int8_t vs1) -> bool { return (vs2 > vs1); }); + case 2: + return RiscVBinaryVectorMaskOp<int16_t, int16_t>( + rv_vector, inst, + [](int16_t vs2, int16_t vs1) -> bool { return (vs2 > vs1); }); + case 4: + return RiscVBinaryVectorMaskOp<int32_t, int32_t>( + rv_vector, inst, + [](int32_t vs2, int32_t vs1) -> bool { return (vs2 > vs1); }); + case 8: + return RiscVBinaryVectorMaskOp<int64_t, int64_t>( + rv_vector, inst, + [](int64_t vs2, int64_t vs1) -> bool { return (vs2 > vs1); }); + default: + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal SEW value"; + return; + } +} + +// Saturated unsigned addition. +void Vsaddu(Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 1: + return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>( + rv_vector, inst, [rv_vector](uint8_t vs2, uint8_t vs1) -> uint8_t { + uint8_t sum = vs2 + vs1; + if (sum < vs1) { + sum = numeric_limits<uint8_t>::max(); + rv_vector->set_vxsat(true); + } + return sum; + }); + case 2: + return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>( + rv_vector, inst, [rv_vector](uint16_t vs2, uint16_t vs1) -> uint16_t { + uint16_t sum = vs2 + vs1; + if (sum < vs1) { + sum = numeric_limits<uint16_t>::max(); + rv_vector->set_vxsat(true); + } + return sum; + }); + case 4: + return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>( + rv_vector, inst, [rv_vector](uint32_t vs2, uint32_t vs1) -> uint32_t { + uint32_t sum = vs2 + vs1; + if (sum < vs1) { + sum = numeric_limits<uint32_t>::max(); + rv_vector->set_vxsat(true); + } + return sum; + }); + case 8: + return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>( + rv_vector, inst, [rv_vector](uint64_t vs2, uint64_t vs1) -> uint64_t { + uint64_t sum = vs2 + vs1; + if (sum < vs1) { + sum = numeric_limits<uint64_t>::max(); + rv_vector->set_vxsat(true); + } + return sum; + }); + default: + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal SEW value"; + return; + } +} + +// Helper function for Vsadd. +// Uses unsigned arithmetic for the addition to avoid signed overflow, which, +// when compiled with --config=asan, will trigger an exception. +template <typename T> +inline T VsaddHelper(T vs2, T vs1, CheriotVectorState *rv_vector) { + using UT = typename std::make_unsigned<T>::type; + UT uvs2 = static_cast<UT>(vs2); + UT uvs1 = static_cast<UT>(vs1); + UT usum = uvs2 + uvs1; + T sum = static_cast<T>(usum); + if (((vs2 ^ vs1) >= 0) && ((sum ^ vs2) < 0)) { + rv_vector->set_vxsat(true); + return vs2 > 0 ? numeric_limits<T>::max() : numeric_limits<T>::min(); + } + return sum; +} + +// Saturated signed addition. 
+void Vsadd(Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 1: + return RiscVBinaryVectorOp<int8_t, int8_t, int8_t>( + rv_vector, inst, [rv_vector](int8_t vs2, int8_t vs1) -> int8_t { + return VsaddHelper(vs2, vs1, rv_vector); + }); + case 2: + return RiscVBinaryVectorOp<int16_t, int16_t, int16_t>( + rv_vector, inst, [rv_vector](int16_t vs2, int16_t vs1) -> int16_t { + return VsaddHelper(vs2, vs1, rv_vector); + }); + case 4: + return RiscVBinaryVectorOp<int32_t, int32_t, int32_t>( + rv_vector, inst, [rv_vector](int32_t vs2, int32_t vs1) -> int32_t { + return VsaddHelper(vs2, vs1, rv_vector); + }); + case 8: + return RiscVBinaryVectorOp<int64_t, int64_t, int64_t>( + rv_vector, inst, [rv_vector](int64_t vs2, int64_t vs1) -> int64_t { + return VsaddHelper(vs2, vs1, rv_vector); + }); + default: + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal SEW value"; + return; + } +} + +// Saturated unsigned subtract. +void Vssubu(Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 1: + return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>( + rv_vector, inst, [rv_vector](uint8_t vs2, uint8_t vs1) -> uint8_t { + uint8_t diff = vs2 - vs1; + if (vs2 < vs1) { + diff = 0; + rv_vector->set_vxsat(true); + } + return diff; + }); + case 2: + return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>( + rv_vector, inst, [rv_vector](uint16_t vs2, uint16_t vs1) -> uint16_t { + uint16_t diff = vs2 - vs1; + if (vs2 < vs1) { + diff = 0; + rv_vector->set_vxsat(true); + } + return diff; + }); + case 4: + return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>( + rv_vector, inst, [rv_vector](uint32_t vs2, uint32_t vs1) -> uint32_t { + uint32_t diff = vs2 - vs1; + if (vs2 < vs1) { + diff = 0; + rv_vector->set_vxsat(true); + } + return diff; + }); + case 8: + return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>( + rv_vector, inst, [rv_vector](uint64_t vs2, uint64_t vs1) -> uint64_t { + uint64_t diff = vs2 - vs1; + if (vs2 < vs1) { + diff = 0; + rv_vector->set_vxsat(true); + } + return diff; + }); + default: + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal SEW value"; + return; + } +} + +template <typename T> +T VssubHelper(T vs2, T vs1, CheriotVectorState *rv_vector) { + using UT = typename std::make_unsigned<T>::type; + UT uvs2 = static_cast<UT>(vs2); + UT uvs1 = static_cast<UT>(vs1); + UT udiff = uvs2 - uvs1; + T diff = static_cast<T>(udiff); + if (((vs2 ^ vs1) < 0) && ((diff ^ vs1) >= 0)) { + rv_vector->set_vxsat(true); + return vs1 < 0 ? numeric_limits<T>::max() : numeric_limits<T>::min(); + } + return diff; +} + +// Saturated signed subtract. 
+void Vssub(Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 1: + return RiscVBinaryVectorOp<int8_t, int8_t, int8_t>( + rv_vector, inst, [rv_vector](int8_t vs2, int8_t vs1) -> int8_t { + return VssubHelper(vs2, vs1, rv_vector); + }); + case 2: + return RiscVBinaryVectorOp<int16_t, int16_t, int16_t>( + rv_vector, inst, [rv_vector](int16_t vs2, int16_t vs1) -> int16_t { + return VssubHelper(vs2, vs1, rv_vector); + }); + case 4: + return RiscVBinaryVectorOp<int32_t, int32_t, int32_t>( + rv_vector, inst, [rv_vector](int32_t vs2, int32_t vs1) -> int32_t { + return VssubHelper(vs2, vs1, rv_vector); + }); + case 8: + return RiscVBinaryVectorOp<int64_t, int64_t, int64_t>( + rv_vector, inst, [rv_vector](int64_t vs2, int64_t vs1) -> int64_t { + return VssubHelper(vs2, vs1, rv_vector); + }); + default: + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal SEW value"; + return; + } +} + +// Add/Subtract with carry, carry generation. +void Vadc(Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 1: + return RiscVMaskBinaryVectorOp<uint8_t, uint8_t, uint8_t>( + rv_vector, inst, [](uint8_t vs2, uint8_t vs1, bool mask) -> uint8_t { + return vs2 + vs1 + static_cast<uint8_t>(mask); + }); + case 2: + return RiscVMaskBinaryVectorOp<uint16_t, uint16_t, uint16_t>( + rv_vector, inst, + [](uint16_t vs2, uint16_t vs1, bool mask) -> uint16_t { + return vs2 + vs1 + static_cast<uint16_t>(mask); + }); + case 4: + return RiscVMaskBinaryVectorOp<uint32_t, uint32_t, uint32_t>( + rv_vector, inst, + [](uint32_t vs2, uint32_t vs1, bool mask) -> uint32_t { + return vs2 + vs1 + static_cast<uint32_t>(mask); + }); + case 8: + return RiscVMaskBinaryVectorOp<uint64_t, uint64_t, uint64_t>( + rv_vector, inst, + [](uint64_t vs2, uint64_t vs1, bool mask) -> uint64_t { + return vs2 + vs1 + static_cast<uint64_t>(mask); + }); + default: + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal SEW value"; + return; + } +} + +// Add with carry - carry generation. +void Vmadc(Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 1: + return RiscVSetMaskBinaryVectorMaskOp<uint8_t, uint8_t>( + rv_vector, inst, [](uint8_t vs2, uint8_t vs1, bool mask) -> bool { + uint16_t sum = static_cast<uint16_t>(vs2) + + static_cast<uint16_t>(vs1) + + static_cast<uint16_t>(mask); + sum >>= 8; + return sum; + }); + case 2: + return RiscVSetMaskBinaryVectorMaskOp<uint16_t, uint16_t>( + rv_vector, inst, [](uint16_t vs2, uint16_t vs1, bool mask) -> bool { + uint32_t sum = static_cast<uint32_t>(vs2) + + static_cast<uint32_t>(vs1) + + static_cast<uint32_t>(mask); + sum >>= 16; + return sum != 0; + }); + case 4: + return RiscVSetMaskBinaryVectorMaskOp<uint32_t, uint32_t>( + rv_vector, inst, [](uint32_t vs2, uint32_t vs1, bool mask) -> bool { + uint64_t sum = static_cast<uint64_t>(vs2) + + static_cast<uint64_t>(vs1) + + static_cast<uint64_t>(mask); + sum >>= 32; + return sum != 0; + }); + case 8: + return RiscVSetMaskBinaryVectorMaskOp<uint64_t, uint64_t>( + rv_vector, inst, [](uint64_t vs2, uint64_t vs1, bool mask) -> bool { + // Compute carry by doing two additions. First get the carry out + // from adding the low byte. 
+            uint64_t carry =
+                ((vs1 & 0xff) + (vs2 & 0xff) + static_cast<uint64_t>(mask)) >>
+                8;
+            // Now add the high 7 bytes together with the carry from the low
+            // byte addition.
+            uint64_t sum = (vs1 >> 8) + (vs2 >> 8) + carry;
+            // The carry out is in the high byte.
+            sum >>= 56;
+            return sum != 0;
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Subtract with borrow.
+void Vsbc(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVMaskBinaryVectorOp<uint8_t, uint8_t, uint8_t>(
+          rv_vector, inst, [](uint8_t vs2, uint8_t vs1, bool mask) -> uint8_t {
+            return vs2 - vs1 - static_cast<uint8_t>(mask);
+          });
+    case 2:
+      return RiscVMaskBinaryVectorOp<uint16_t, uint16_t, uint16_t>(
+          rv_vector, inst,
+          [](uint16_t vs2, uint16_t vs1, bool mask) -> uint16_t {
+            return vs2 - vs1 - static_cast<uint16_t>(mask);
+          });
+    case 4:
+      return RiscVMaskBinaryVectorOp<uint32_t, uint32_t, uint32_t>(
+          rv_vector, inst,
+          [](uint32_t vs2, uint32_t vs1, bool mask) -> uint32_t {
+            return vs2 - vs1 - static_cast<uint32_t>(mask);
+          });
+    case 8:
+      return RiscVMaskBinaryVectorOp<uint64_t, uint64_t, uint64_t>(
+          rv_vector, inst,
+          [](uint64_t vs2, uint64_t vs1, bool mask) -> uint64_t {
+            return vs2 - vs1 - static_cast<uint64_t>(mask);
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Subtract with borrow - borrow generation.
+void Vmsbc(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVSetMaskBinaryVectorMaskOp<uint8_t, uint8_t>(
+          rv_vector, inst, [](uint8_t vs2, uint8_t vs1, bool mask) -> bool {
+            return static_cast<uint16_t>(vs2) <
+                   static_cast<uint16_t>(mask) + static_cast<uint16_t>(vs1);
+          });
+    case 2:
+      return RiscVSetMaskBinaryVectorMaskOp<uint16_t, uint16_t>(
+          rv_vector, inst, [](uint16_t vs2, uint16_t vs1, bool mask) -> bool {
+            return static_cast<uint32_t>(vs2) <
+                   static_cast<uint32_t>(mask) + static_cast<uint32_t>(vs1);
+          });
+    case 4:
+      return RiscVSetMaskBinaryVectorMaskOp<uint32_t, uint32_t>(
+          rv_vector, inst, [](uint32_t vs2, uint32_t vs1, bool mask) -> bool {
+            return static_cast<uint64_t>(vs2) <
+                   static_cast<uint64_t>(mask) + static_cast<uint64_t>(vs1);
+          });
+    case 8:
+      return RiscVSetMaskBinaryVectorMaskOp<uint64_t, uint64_t>(
+          rv_vector, inst, [](uint64_t vs2, uint64_t vs1, bool mask) -> bool {
+            if (vs2 < vs1) return true;
+            if (vs2 == vs1) return mask;
+            return false;
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Vector merge.
+void Vmerge(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVMaskBinaryVectorOp<uint8_t, uint8_t, uint8_t>(
+          rv_vector, inst, [](uint8_t vs2, uint8_t vs1, bool mask) -> uint8_t {
+            return mask ? vs1 : vs2;
+          });
+    case 2:
+      return RiscVMaskBinaryVectorOp<uint16_t, uint16_t, uint16_t>(
+          rv_vector, inst,
+          [](uint16_t vs2, uint16_t vs1, bool mask) -> uint16_t {
+            return mask ? vs1 : vs2;
+          });
+    case 4:
+      return RiscVMaskBinaryVectorOp<uint32_t, uint32_t, uint32_t>(
+          rv_vector, inst,
+          [](uint32_t vs2, uint32_t vs1, bool mask) -> uint32_t {
+            return mask ? vs1 : vs2;
+          });
+    case 8:
+      return RiscVMaskBinaryVectorOp<uint64_t, uint64_t, uint64_t>(
+          rv_vector, inst,
+          [](uint64_t vs2, uint64_t vs1, bool mask) -> uint64_t {
+            return mask ? vs1 : vs2;
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Vector move register(s).
+void Vmvr(int num_regs, Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  if (rv_vector->vector_exception()) return;
+
+  auto *src_op = static_cast<RV32VectorSourceOperand *>(inst->Source(0));
+  auto *dest_op =
+      static_cast<RV32VectorDestinationOperand *>(inst->Destination(0));
+  if (src_op->size() < num_regs) {
+    LOG(ERROR) << "Vmvr: source operand has fewer registers than requested";
+    rv_vector->set_vector_exception();
+    return;
+  }
+  if (dest_op->size() < num_regs) {
+    LOG(ERROR)
+        << "Vmvr: destination operand has fewer registers than requested";
+    rv_vector->set_vector_exception();
+    return;
+  }
+  int sew = rv_vector->selected_element_width();
+  int num_elements_per_vector = rv_vector->vector_register_byte_length() / sew;
+  int vstart = rv_vector->vstart();
+  int start_reg = vstart / num_elements_per_vector;
+  for (int i = start_reg; i < num_regs; i++) {
+    auto *src_db = src_op->GetRegister(i)->data_buffer();
+    auto *dest_db = dest_op->AllocateDataBuffer(i);
+    std::memcpy(dest_db->raw_ptr(), src_db->raw_ptr(),
+                dest_db->size<uint8_t>());
+    dest_db->Submit();
+  }
+  rv_vector->clear_vstart();
+}
+
+// Templated helper function for shift right with rounding.
+template <typename T>
+T VssrHelper(CheriotVectorState *rv_vector, T vs2, T vs1) {
+  using UT = typename MakeUnsigned<T>::type;
+  int rm = rv_vector->vxrm();
+  int max_shift = (sizeof(T) << 3) - 1;
+  int shift_amount = static_cast<int>(vs1 & max_shift);
+  // Create mask for the bits that will be shifted out + 1.
+  UT round_bits = vs2;
+  if (shift_amount < max_shift) {
+    UT mask = ~(numeric_limits<UT>::max() << (shift_amount + 1));
+    round_bits = vs2 & mask;
+  }
+  vs2 >>= shift_amount;
+  vs2 += static_cast<T>(GetRoundingBit(rm, round_bits, shift_amount + 1));
+  return vs2;
+}
+
+// Logical shift right with rounding.
+void Vssrl(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>(
+          rv_vector, inst, [rv_vector](uint8_t vs2, uint8_t vs1) -> uint8_t {
+            return VssrHelper(rv_vector, vs2, vs1);
+          });
+    case 2:
+      return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>(
+          rv_vector, inst, [rv_vector](uint16_t vs2, uint16_t vs1) -> uint16_t {
+            return VssrHelper(rv_vector, vs2, vs1);
+          });
+    case 4:
+      return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>(
+          rv_vector, inst, [rv_vector](uint32_t vs2, uint32_t vs1) -> uint32_t {
+            return VssrHelper(rv_vector, vs2, vs1);
+          });
+    case 8:
+      return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>(
+          rv_vector, inst, [rv_vector](uint64_t vs2, uint64_t vs1) -> uint64_t {
+            return VssrHelper(rv_vector, vs2, vs1);
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Arithmetic shift right with rounding.
+void Vssra(Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 1: + return RiscVBinaryVectorOp<int8_t, int8_t, int8_t>( + rv_vector, inst, [rv_vector](int8_t vs2, int8_t vs1) -> int8_t { + return VssrHelper(rv_vector, vs2, vs1); + }); + case 2: + return RiscVBinaryVectorOp<int16_t, int16_t, int16_t>( + rv_vector, inst, [rv_vector](int16_t vs2, int16_t vs1) -> int16_t { + return VssrHelper(rv_vector, vs2, vs1); + }); + case 4: + return RiscVBinaryVectorOp<int32_t, int32_t, int32_t>( + rv_vector, inst, [rv_vector](int32_t vs2, int32_t vs1) -> int32_t { + return VssrHelper(rv_vector, vs2, vs1); + }); + case 8: + return RiscVBinaryVectorOp<int64_t, int64_t, int64_t>( + rv_vector, inst, [rv_vector](int64_t vs2, int64_t vs1) -> int64_t { + return VssrHelper(rv_vector, vs2, vs1); + }); + default: + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal SEW value"; + return; + } +} + +// Templated helper function for shift right with rounding and saturation. +template <typename DT, typename WT, typename T> +T VnclipHelper(CheriotVectorState *rv_vector, WT vs2, T vs1) { + using WUT = typename std::make_unsigned<WT>::type; + int rm = rv_vector->vxrm(); + int max_shift = (sizeof(WT) << 3) - 1; + int shift_amount = vs1 & ((sizeof(WT) << 3) - 1); + // Create mask for the bits that will be shifted out + 1. + WUT mask = vs2; + if (shift_amount < max_shift) { + mask = ~(numeric_limits<WUT>::max() << (shift_amount + 1)); + } + WUT round_bits = vs2 & mask; + // Perform the rounded shift. + vs2 = + (vs2 >> shift_amount) + GetRoundingBit(rm, round_bits, shift_amount + 1); + // Saturate if needed. + if (vs2 > numeric_limits<DT>::max()) { + rv_vector->set_vxsat(true); + return numeric_limits<DT>::max(); + } + if (vs2 < numeric_limits<DT>::min()) { + rv_vector->set_vxsat(true); + return numeric_limits<DT>::min(); + } + return static_cast<DT>(vs2); +} + +// Arithmetic shift right and narrowing from 2*sew to sew with rounding and +// signed saturation. +void Vnclip(Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int lmul8 = rv_vector->vector_length_multiplier(); + // This is a narrowing operation and sew is that of the narrow data type. + // Thus if lmul > 32, then emul for the wider data type is illegal. + if (lmul8 > 32) { + LOG(ERROR) << "Illegal lmul value"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 1: + return RiscVBinaryVectorOp<int8_t, int16_t, int8_t>( + rv_vector, inst, [rv_vector](int16_t vs2, int8_t vs1) -> int8_t { + return VnclipHelper<int8_t, int16_t, int8_t>(rv_vector, vs2, vs1); + }); + case 2: + return RiscVBinaryVectorOp<int16_t, int32_t, int16_t>( + rv_vector, inst, [rv_vector](int32_t vs2, int16_t vs1) -> int16_t { + return VnclipHelper<int16_t, int32_t, int16_t>(rv_vector, vs2, vs1); + }); + case 4: + return RiscVBinaryVectorOp<int32_t, int64_t, int32_t>( + rv_vector, inst, [rv_vector](int64_t vs2, int32_t vs1) -> int32_t { + return VnclipHelper<int32_t, int64_t, int32_t>(rv_vector, vs2, vs1); + }); + case 8: + // There is no valid sew * 2 = 16. + default: + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal SEW value"; + return; + } +} + +// Logical shift right and narrowing from 2*sew to sew with rounding and +// unsigned saturation. 
+void Vnclipu(Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int lmul8 = rv_vector->vector_length_multiplier(); + // This is a narrowing operation and sew is that of the narrow data type. + // Thus if lmul > 32, then emul for the wider data type is illegal. + if (lmul8 > 32) { + LOG(ERROR) << "Illegal lmul value"; + rv_vector->set_vector_exception(); + return; + } + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 1: + return RiscVBinaryVectorOp<uint8_t, uint16_t, uint8_t>( + rv_vector, inst, [rv_vector](uint16_t vs2, uint8_t vs1) -> uint8_t { + return VnclipHelper<uint8_t, uint16_t, uint8_t>(rv_vector, vs2, + vs1); + }); + case 2: + return RiscVBinaryVectorOp<uint16_t, uint32_t, uint16_t>( + rv_vector, inst, [rv_vector](uint32_t vs2, uint16_t vs1) -> uint16_t { + return VnclipHelper<uint16_t, uint32_t, uint16_t>(rv_vector, vs2, + vs1); + }); + case 4: + return RiscVBinaryVectorOp<uint32_t, uint64_t, uint32_t>( + rv_vector, inst, [rv_vector](uint64_t vs2, uint32_t vs1) -> uint32_t { + return VnclipHelper<uint32_t, uint64_t, uint32_t>(rv_vector, vs2, + vs1); + }); + case 8: + // There is no valid sew * 2 = 16. + default: + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal SEW value"; + return; + } +} + +// Perform a signed multiply from T to wider int type. Shift that result +// right by sizeof(T) * 8 - 1 and round. Saturate if needed to fit into T. +template <typename T> +T VsmulHelper(CheriotVectorState *rv_vector, T vs2, T vs1) { + using WT = typename WideType<T>::type; + WT vd_w; + WT vs2_w = static_cast<WT>(vs2); + WT vs1_w = static_cast<WT>(vs1); + vd_w = vs2_w * vs1_w; + vd_w = VssrHelper<WT>(rv_vector, vd_w, sizeof(T) * 8 - 1); + if (vd_w < numeric_limits<T>::min()) { + rv_vector->set_vxsat(true); + return numeric_limits<T>::min(); + } + if (vd_w > numeric_limits<T>::max()) { + rv_vector->set_vxsat(true); + return numeric_limits<T>::max(); + } + return static_cast<T>(vd_w); +} + +// Vector fractional multiply with rounding and saturation. +void Vsmul(Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + switch (sew) { + case 1: + return RiscVBinaryVectorOp<int8_t, int8_t, int8_t>( + rv_vector, inst, [rv_vector](int8_t vs2, int8_t vs1) -> int8_t { + return VsmulHelper<int8_t>(rv_vector, vs2, vs1); + }); + case 2: + return RiscVBinaryVectorOp<int16_t, int16_t, int16_t>( + rv_vector, inst, [rv_vector](int16_t vs2, int16_t vs1) -> int16_t { + return VsmulHelper<int16_t>(rv_vector, vs2, vs1); + }); + case 4: + return RiscVBinaryVectorOp<int32_t, int32_t, int32_t>( + rv_vector, inst, [rv_vector](int32_t vs2, int32_t vs1) -> int32_t { + return VsmulHelper<int32_t>(rv_vector, vs2, vs1); + }); + case 8: + return RiscVBinaryVectorOp<int64_t, int64_t, int64_t>( + rv_vector, inst, [rv_vector](int64_t vs2, int64_t vs1) -> int64_t { + return VsmulHelper<int64_t>(rv_vector, vs2, vs1); + }); + default: + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal SEW value"; + return; + } +} + +} // namespace cheriot +} // namespace sim +} // namespace mpact
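A note for readers tracing VssrHelper and VnclipHelper above: both delegate
the vxrm-controlled rounding decision to GetRoundingBit, whose definition
lives in riscv_cheriot_vector_instruction_helpers.h and is not part of this
diff. The following self-contained sketch of the four standard RVV
fixed-point rounding modes (rnu, rne, rdn, rod) shows the intended behavior;
the in-tree helper's exact name and signature may differ:

    #include <cstdint>

    // Rounding increment to add after shifting `value` right by `shift`
    // bits, per the RVV vxrm encoding (0 = rnu, 1 = rne, 2 = rdn, 3 = rod).
    // Assumes 0 < shift < 64; `value` is the pre-shift quantity.
    uint64_t RoundingIncrementSketch(int vxrm, uint64_t value, int shift) {
      const uint64_t d = (value >> (shift - 1)) & 1;  // Bit below result lsb.
      const uint64_t rest =
          (shift > 1) ? (value & ((1ULL << (shift - 1)) - 1)) : 0;
      switch (vxrm) {
        case 0:  // rnu: round to nearest, ties round up.
          return d;
        case 1:  // rne: round to nearest, ties round to even.
          return d &
                 (static_cast<uint64_t>(rest != 0) | ((value >> shift) & 1));
        case 2:  // rdn: round down (truncate).
          return 0;
        case 3:  // rod: set the lsb if any bit was shifted out.
          return (~(value >> shift) & 1) &
                 static_cast<uint64_t>((d | rest) != 0);
        default:
          return 0;
      }
    }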
diff --git a/cheriot/riscv_cheriot_vector_opi_instructions.h b/cheriot/riscv_cheriot_vector_opi_instructions.h new file mode 100644 index 0000000..87ccbd9 --- /dev/null +++ b/cheriot/riscv_cheriot_vector_opi_instructions.h
@@ -0,0 +1,236 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MPACT_CHERIOT_RISCV_CHERIOT_VECTOR_OPI_INSTRUCTIONS_H_ +#define MPACT_CHERIOT_RISCV_CHERIOT_VECTOR_OPI_INSTRUCTIONS_H_ + +#include "mpact/sim/generic/instruction.h" + +namespace mpact { +namespace sim { +namespace cheriot { + +// This file declares the vector instruction semantic functions for most of the +// vector instructions in the OPIVV, OPIVX, and OPIVI encoding spaces. The +// exceptions are vector permute instructions and some vector reduction +// instructions. + +using Instruction = ::mpact::sim::generic::Instruction; + +// Integer vector operations. + +// Element wide vector add. This instruction takes three source operands and +// a vector destination operand. Source 0 is the vs2 vector source. Source 1 +// is either vs1 (vector), rs1 (scalar), or an immediate. Source 2 is a vector +// mask operand. +void Vadd(Instruction *inst); +// Element wide vector subtract. This instruction takes three source operands +// and a vector destination operand. Source 0 is the vs2 vector source. Source 1 +// is either vs1 (vector) or rs1 (scalar). Source 2 is a vector +// mask operand. +void Vsub(Instruction *inst); +// Element wide vector reverse subtract. This instruction takes three source +// operands and a vector destination operand. Source 0 is the vs2 vector source. +// Source 1 is either rs1 (scalar), or an immediate. Source 2 is a vector mask +// operand. +void Vrsub(Instruction *inst); +// Element wide bitwise and. This instruction takes three source operands and +// a vector destination operand. Source 0 is the vs2 vector source. Source 1 +// is either vs1 (vector), rs1 (scalar), or an immediate. Source 2 is a vector +// mask operand. +void Vand(Instruction *inst); +// Element wide bitwise or. This instruction takes three source operands and +// a vector destination operand. Source 0 is the vs2 vector source. Source 1 +// is either vs1 (vector), rs1 (scalar), or an immediate. Source 2 is a vector +// mask operand. +void Vor(Instruction *inst); +// Element wide bitwise xor. This instruction takes three source operands and +// a vector destination operand. Source 0 is the vs2 vector source. Source 1 +// is either vs1 (vector), rs1 (scalar), or an immediate. Source 2 is a vector +// mask operand. +void Vxor(Instruction *inst); +// Element wide logical left shift. This instruction takes three source operands +// and a vector destination operand. Source 0 is the vs2 vector source. Source 1 +// is either vs1 (vector), rs1 (scalar), or an immediate. Source 2 is a vector +// mask operand. +void Vsll(Instruction *inst); +// Element wide logical right shift. This instruction takes three source +// operands and a vector destination operand. Source 0 is the vs2 vector source. +// Source 1 is either vs1 (vector), rs1 (scalar), or an immediate. Source 2 is a +// vector mask operand. +void Vsrl(Instruction *inst); +// Element wide arithmetic right shift. 
This instruction takes three source
+// operands and a vector destination operand. Source 0 is the vs2 vector
+// source. Source 1 is either vs1 (vector), rs1 (scalar), or an immediate.
+// Source 2 is a vector mask operand.
+void Vsra(Instruction *inst);
+// Element wide narrowing logical right shift. This instruction takes three
+// source operands and a vector destination operand. Source 0 is the vs2
+// vector source. Source 1 is either vs1 (vector), rs1 (scalar), or an
+// immediate. Source 2 is a vector mask operand.
+void Vnsrl(Instruction *inst);
+// Element wide narrowing arithmetic right shift. This instruction takes three
+// source operands and a vector destination operand. Source 0 is the vs2
+// vector source. Source 1 is either vs1 (vector), rs1 (scalar), or an
+// immediate. Source 2 is a vector mask operand.
+void Vnsra(Instruction *inst);
+// Vector signed min (pairwise). This instruction takes three source operands
+// and a vector destination operand. Source 0 is the vs2 vector source. Source
+// 1 is either vs1 (vector) or rs1 (scalar). Source 2 is a vector mask
+// operand.
+void Vmin(Instruction *inst);
+// Vector unsigned min (pairwise). This instruction takes three source
+// operands and a vector destination operand. Source 0 is the vs2 vector
+// source. Source 1 is either vs1 (vector) or rs1 (scalar). Source 2 is a
+// vector mask operand.
+void Vminu(Instruction *inst);
+// Vector signed max (pairwise). This instruction takes three source operands
+// and a vector destination operand. Source 0 is the vs2 vector source. Source
+// 1 is either vs1 (vector) or rs1 (scalar). Source 2 is a vector mask
+// operand.
+void Vmax(Instruction *inst);
+// Vector unsigned max (pairwise). This instruction takes three source
+// operands and a vector destination operand. Source 0 is the vs2 vector
+// source. Source 1 is either vs1 (vector) or rs1 (scalar). Source 2 is a
+// vector mask operand.
+void Vmaxu(Instruction *inst);
+// Vector mask set equal. This instruction takes three source operands and a
+// vector destination operand. Source 0 is the vs2 vector source. Source 1 is
+// either vs1 (vector), rs1 (scalar), or an immediate. Source 2 is a vector
+// mask operand.
+void Vmseq(Instruction *inst);
+// Vector mask set not equal. This instruction takes three source operands and
+// a vector destination operand. Source 0 is the vs2 vector source. Source 1
+// is either vs1 (vector), rs1 (scalar), or an immediate. Source 2 is a vector
+// mask operand.
+void Vmsne(Instruction *inst);
+// Vector mask set less than unsigned. This instruction takes three source
+// operands and a vector destination operand. Source 0 is the vs2 vector
+// source. Source 1 is either vs1 (vector) or rs1 (scalar). Source 2 is a
+// vector mask operand.
+void Vmsltu(Instruction *inst);
+// Vector mask set less than signed. This instruction takes three source
+// operands and a vector destination operand. Source 0 is the vs2 vector
+// source. Source 1 is either vs1 (vector) or rs1 (scalar). Source 2 is a
+// vector mask operand.
+void Vmslt(Instruction *inst);
+// Vector mask set less or equal unsigned. This instruction takes three source
+// operands and a vector destination operand. Source 0 is the vs2 vector
+// source. Source 1 is either vs1 (vector), rs1 (scalar), or an immediate.
+// Source 2 is a vector mask operand.
+void Vmsleu(Instruction *inst);
+// Vector mask set less or equal signed. This instruction takes three source
+// operands and a vector destination operand. Source 0 is the vs2 vector
+// source. Source 1 is either vs1 (vector), rs1 (scalar), or an immediate.
+// Source 2 is a vector mask operand.
+void Vmsle(Instruction *inst);
+// Vector mask set greater than unsigned. This instruction takes three source
+// operands and a vector destination operand. Source 0 is the vs2 vector
+// source. Source 1 is either rs1 (scalar) or an immediate. Source 2 is a
+// vector mask operand.
+void Vmsgtu(Instruction *inst);
+// Vector mask set greater than signed. This instruction takes three source
+// operands and a vector destination operand. Source 0 is the vs2 vector
+// source. Source 1 is either rs1 (scalar) or an immediate. Source 2 is a
+// vector mask operand.
+void Vmsgt(Instruction *inst);
+// Vector saturating unsigned add. This instruction takes three source
+// operands and a vector destination operand. Source 0 is the vs2 vector
+// source. Source 1 is either vs1 (vector), rs1 (scalar), or an immediate.
+// Source 2 is a vector mask operand.
+void Vsaddu(Instruction *inst);
+// Vector saturating signed add. This instruction takes three source operands
+// and a vector destination operand. Source 0 is the vs2 vector source. Source
+// 1 is either vs1 (vector), rs1 (scalar), or an immediate. Source 2 is a
+// vector mask operand.
+void Vsadd(Instruction *inst);
+// Vector saturating unsigned subtract. This instruction takes three source
+// operands and a vector destination operand. Source 0 is the vs2 vector
+// source. Source 1 is either vs1 (vector) or rs1 (scalar). Source 2 is a
+// vector mask operand.
+void Vssubu(Instruction *inst);
+// Vector saturating signed subtract. This instruction takes three source
+// operands and a vector destination operand. Source 0 is the vs2 vector
+// source. Source 1 is either vs1 (vector) or rs1 (scalar). Source 2 is a
+// vector mask operand.
+void Vssub(Instruction *inst);
+// Vector add with carry. This instruction takes three source operands and a
+// vector destination operand. Source 0 is the vs2 vector source. Source 1 is
+// either vs1 (vector), rs1 (scalar), or an immediate. Source 2 is a vector
+// mask operand that contains the carry in values.
+void Vadc(Instruction *inst);
+// Vector add with carry - carry generate. This instruction takes three source
+// operands and a vector destination operand. Source 0 is the vs2 vector
+// source. Source 1 is either vs1 (vector), rs1 (scalar), or an immediate.
+// Source 2 is a vector mask operand that contains the carry in values. The
+// output of this instruction is the carry outs of each element wise addition.
+// It is stored in the format of the vector flags.
+void Vmadc(Instruction *inst);
+// Vector subtract with borrow. This instruction takes three source operands
+// and a vector destination operand. Source 0 is the vs2 vector source. Source
+// 1 is either vs1 (vector) or rs1 (scalar). Source 2 is a vector mask operand
+// that contains the borrow values.
+void Vsbc(Instruction *inst);
+// Vector subtract with borrow - borrow generate. This instruction takes three
+// source operands and a vector destination operand. Source 0 is the vs2
+// vector source. Source 1 is either vs1 (vector) or rs1 (scalar). Source 2 is
+// a vector mask operand that contains the borrow values. The output of this
+// instruction is the borrow outs of each element wise subtraction.
It
+// is stored in the format of the vector flags.
+void Vmsbc(Instruction *inst);
+// Vector pairwise merge. This instruction takes three source operands and a
+// vector destination operand. Source 0 is the vs2 vector source. Source 1 is
+// either vs1 (vector), rs1 (scalar), or an immediate. Source 2 is a vector mask
+// operand. This semantic function also captures the functionality of vmv.vv,
+// vmv.vx, and vmv.vi, in which case vs2 is register group v0, and the mask
+// is all ones.
+void Vmerge(Instruction *inst);
+// Vector register move. This instruction takes one source operand and a
+// vector destination operand. Source 0 is the vs2 vector source. The num_regs
+// value is part of the opcode and should be bound to the semantic function at
+// decode.
+void Vmvr(int num_regs, Instruction *inst);
+// Vector logical right shift with rounding. This instruction takes three
+// source operands and a vector destination operand. Source 0 is the vs2 vector
+// source. Source 1 is either vs1 (vector), rs1 (scalar), or an immediate.
+// Source 2 is a vector mask operand.
+void Vssrl(Instruction *inst);
+// Vector arithmetic right shift with rounding. This instruction takes three
+// source operands and a vector destination operand. Source 0 is the vs2 vector
+// source. Source 1 is either vs1 (vector), rs1 (scalar), or an immediate.
+// Source 2 is a vector mask operand.
+void Vssra(Instruction *inst);
+// Vector logical right shift with rounding and (unsigned) saturation from
+// SEW * 2 to SEW wide elements. This instruction takes three source operands
+// and a vector destination operand. Source 0 is the vs2 vector source.
+// Source 1 is either vs1 (vector), rs1 (scalar), or an immediate. Source 2 is
+// a vector mask operand.
+void Vnclipu(Instruction *inst);
+// Vector arithmetic right shift with rounding and (signed) saturation from
+// SEW * 2 to SEW wide elements. This instruction takes three source operands
+// and a vector destination operand. Source 0 is the vs2 vector source.
+// Source 1 is either vs1 (vector), rs1 (scalar), or an immediate. Source 2 is
+// a vector mask operand.
+void Vnclip(Instruction *inst);
+// Vector fractional multiply with rounding and saturation. This instruction
+// takes three source operands and a vector destination operand. Source 0 is
+// the vs2 vector source. Source 1 is either vs1 (vector) or rs1 (scalar).
+// Source 2 is a vector mask operand.
+void Vsmul(Instruction *inst);
+
+}  // namespace cheriot
+}  // namespace sim
+}  // namespace mpact
+
+#endif  // MPACT_CHERIOT_RISCV_CHERIOT_VECTOR_OPI_INSTRUCTIONS_H_
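Editorial note: the carry/borrow-generate semantics described above are easiest to see in scalar form. The sketch below is illustrative only and not part of the commit; it models how a vmadc-style operation packs one carry-out bit per element into a mask-format result, assuming SEW=8 and ignoring masking and vstart for brevity.

#include <cstddef>
#include <cstdint>
#include <vector>

// Minimal model of vmadc.vv at SEW=8: for each element, compute the carry-out
// of vs2[i] + vs1[i] and store it as bit i of the mask-format result.
std::vector<uint8_t> VmadcModel(const std::vector<uint8_t> &vs2,
                                const std::vector<uint8_t> &vs1) {
  std::vector<uint8_t> mask((vs2.size() + 7) / 8, 0);
  for (std::size_t i = 0; i < vs2.size(); ++i) {
    uint16_t sum = static_cast<uint16_t>(vs2[i]) + vs1[i];
    if (sum > 0xff) mask[i / 8] |= 1 << (i % 8);  // One carry bit per element.
  }
  return mask;
}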
diff --git a/cheriot/riscv_cheriot_vector_opm_instructions.cc b/cheriot/riscv_cheriot_vector_opm_instructions.cc
new file mode 100644
index 0000000..78617ef
--- /dev/null
+++ b/cheriot/riscv_cheriot_vector_opm_instructions.cc
@@ -0,0 +1,1302 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cheriot/riscv_cheriot_vector_opm_instructions.h"
+
+#include <cstdint>
+#include <functional>
+#include <limits>
+#include <type_traits>
+
+#include "absl/base/casts.h"
+#include "absl/log/log.h"
+#include "cheriot/cheriot_state.h"
+#include "cheriot/cheriot_vector_state.h"
+#include "cheriot/riscv_cheriot_vector_instruction_helpers.h"
+#include "mpact/sim/generic/type_helpers.h"
+#include "riscv//riscv_register.h"
+
+namespace mpact {
+namespace sim {
+namespace cheriot {
+
+using ::mpact::sim::generic::WideType;
+
+// Helper function used to factor out some code from Vaadd* instructions.
+template <typename T>
+inline T VaaddHelper(CheriotVectorState *rv_vector, T vs2, T vs1) {
+  // Perform the addition using a wider type, then shift and round.
+  using WT = typename WideType<T>::type;
+  WT vs2_w = static_cast<WT>(vs2);
+  WT vs1_w = static_cast<WT>(vs1);
+  auto res = RoundOff(rv_vector, vs2_w + vs1_w, 1);
+  return static_cast<T>(res);
+}
+
+// Average unsigned add. The two sources are added, then shifted right by one
+// and rounded.
+void Vaaddu(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>(
+          rv_vector, inst, [rv_vector](uint8_t vs2, uint8_t vs1) -> uint8_t {
+            return VaaddHelper(rv_vector, vs2, vs1);
+          });
+    case 2:
+      return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>(
+          rv_vector, inst, [rv_vector](uint16_t vs2, uint16_t vs1) -> uint16_t {
+            return VaaddHelper(rv_vector, vs2, vs1);
+          });
+    case 4:
+      return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>(
+          rv_vector, inst, [rv_vector](uint32_t vs2, uint32_t vs1) -> uint32_t {
+            return VaaddHelper(rv_vector, vs2, vs1);
+          });
+    case 8:
+      return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>(
+          rv_vector, inst, [rv_vector](uint64_t vs2, uint64_t vs1) -> uint64_t {
+            return VaaddHelper(rv_vector, vs2, vs1);
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Average signed add. The two sources are added, then shifted right by one
+// and rounded.
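+// For example, with the rounding mode set to round-to-nearest-up, averaging 5
+// and 6 computes (5 + 6) >> 1 with the dropped bit rounded up, yielding 6
+// rather than the truncated 5.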
+void Vaadd(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVBinaryVectorOp<int8_t, int8_t, int8_t>(
+          rv_vector, inst, [rv_vector](int8_t vs2, int8_t vs1) -> int8_t {
+            return VaaddHelper(rv_vector, vs2, vs1);
+          });
+    case 2:
+      return RiscVBinaryVectorOp<int16_t, int16_t, int16_t>(
+          rv_vector, inst, [rv_vector](int16_t vs2, int16_t vs1) -> int16_t {
+            return VaaddHelper(rv_vector, vs2, vs1);
+          });
+    case 4:
+      return RiscVBinaryVectorOp<int32_t, int32_t, int32_t>(
+          rv_vector, inst, [rv_vector](int32_t vs2, int32_t vs1) -> int32_t {
+            return VaaddHelper(rv_vector, vs2, vs1);
+          });
+    case 8:
+      return RiscVBinaryVectorOp<int64_t, int64_t, int64_t>(
+          rv_vector, inst, [rv_vector](int64_t vs2, int64_t vs1) -> int64_t {
+            return VaaddHelper(rv_vector, vs2, vs1);
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Helper function for Vasub* instructions. Subtract using a wider type, then
+// round.
+template <typename T>
+inline T VasubHelper(CheriotVectorState *rv_vector, T vs2, T vs1) {
+  using WT = typename WideType<T>::type;
+  WT vs2_w = static_cast<WT>(vs2);
+  WT vs1_w = static_cast<WT>(vs1);
+  auto res = RoundOff(rv_vector, vs2_w - vs1_w, 1);
+  return static_cast<T>(res);
+}
+
+// Averaging unsigned subtract - subtract then shift right by 1 and round.
+void Vasubu(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>(
+          rv_vector, inst, [rv_vector](uint8_t vs2, uint8_t vs1) -> uint8_t {
+            return VasubHelper(rv_vector, vs2, vs1);
+          });
+    case 2:
+      return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>(
+          rv_vector, inst, [rv_vector](uint16_t vs2, uint16_t vs1) -> uint16_t {
+            return VasubHelper(rv_vector, vs2, vs1);
+          });
+    case 4:
+      return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>(
+          rv_vector, inst, [rv_vector](uint32_t vs2, uint32_t vs1) -> uint32_t {
+            return VasubHelper(rv_vector, vs2, vs1);
+          });
+    case 8:
+      return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>(
+          rv_vector, inst, [rv_vector](uint64_t vs2, uint64_t vs1) -> uint64_t {
+            return VasubHelper(rv_vector, vs2, vs1);
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Averaging signed subtract. Subtract then shift right by 1 and round.
+void Vasub(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVBinaryVectorOp<int8_t, int8_t, int8_t>(
+          rv_vector, inst, [rv_vector](int8_t vs2, int8_t vs1) -> int8_t {
+            return VasubHelper(rv_vector, vs2, vs1);
+          });
+    case 2:
+      return RiscVBinaryVectorOp<int16_t, int16_t, int16_t>(
+          rv_vector, inst, [rv_vector](int16_t vs2, int16_t vs1) -> int16_t {
+            return VasubHelper(rv_vector, vs2, vs1);
+          });
+    case 4:
+      return RiscVBinaryVectorOp<int32_t, int32_t, int32_t>(
+          rv_vector, inst, [rv_vector](int32_t vs2, int32_t vs1) -> int32_t {
+            return VasubHelper(rv_vector, vs2, vs1);
+          });
+    case 8:
+      return RiscVBinaryVectorOp<int64_t, int64_t, int64_t>(
+          rv_vector, inst, [rv_vector](int64_t vs2, int64_t vs1) -> int64_t {
+            return VasubHelper(rv_vector, vs2, vs1);
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Mask operands only operate on a single vector register. This helper function
+// is used by the following bitwise mask manipulation instruction semantic
+// functions.
+static inline void BitwiseMaskBinaryOp(
+    CheriotVectorState *rv_vector, const Instruction *inst,
+    std::function<uint8_t(uint8_t, uint8_t)> op) {
+  if (rv_vector->vector_exception()) return;
+  int vstart = rv_vector->vstart();
+  int vlen = rv_vector->vector_length();
+  // Get spans for vector source and destination registers.
+  auto *vs2_op = static_cast<RV32VectorSourceOperand *>(inst->Source(0));
+  auto vs2_span = vs2_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
+  auto *vs1_op = static_cast<RV32VectorSourceOperand *>(inst->Source(1));
+  auto vs1_span = vs1_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
+  auto *vd_op =
+      static_cast<RV32VectorDestinationOperand *>(inst->Destination(0));
+  auto *vd_db = vd_op->CopyDataBuffer();
+  auto vd_span = vd_db->Get<uint8_t>();
+  // Compute start and end locations.
+  int start_byte = vstart / 8;
+  int start_offset = vstart % 8;
+  uint8_t start_mask = 0b1111'1111 << start_offset;
+  int end_byte = (vlen - 1) / 8;
+  int end_offset = (vlen - 1) % 8;
+  uint8_t end_mask = 0b1111'1111 >> (7 - end_offset);
+  // The start byte is handled first, applying a mask so that any bits that
+  // precede vstart keep their previous values.
+  vd_span[start_byte] =
+      (op(vs2_span[start_byte], vs1_span[start_byte]) & start_mask) |
+      (vd_span[start_byte] & ~start_mask);
+  // Perform the bitwise operation on each byte between start and end.
+  for (int i = start_byte + 1; i < end_byte; i++) {
+    vd_span[i] = op(vs2_span[i], vs1_span[i]);
+  }
+  // Perform the bitwise operation with a mask on the end byte.
+  vd_span[end_byte] = (op(vs2_span[end_byte], vs1_span[end_byte]) & end_mask) |
+                      (vd_span[end_byte] & ~end_mask);
+  vd_db->Submit();
+  rv_vector->clear_vstart();
+}
+
+// Bitwise vector mask instructions. Each operation is clear from its name.
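+// For instance, Vmandnot computes vd = vs2 & ~vs1 over the active mask bits
+// (vstart through vl - 1); bits outside that range are left unchanged by the
+// start and end byte masking in the helper above.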
+void Vmandnot(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  BitwiseMaskBinaryOp(rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint8_t {
+    return vs2 & ~vs1;
+  });
+}
+
+void Vmand(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  BitwiseMaskBinaryOp(rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint8_t {
+    return vs2 & vs1;
+  });
+}
+void Vmor(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  BitwiseMaskBinaryOp(rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint8_t {
+    return vs2 | vs1;
+  });
+}
+void Vmxor(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  BitwiseMaskBinaryOp(rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint8_t {
+    return vs2 ^ vs1;
+  });
+}
+void Vmornot(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  BitwiseMaskBinaryOp(rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint8_t {
+    return vs2 | ~vs1;
+  });
+}
+void Vmnand(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  BitwiseMaskBinaryOp(rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint8_t {
+    return ~(vs2 & vs1);
+  });
+}
+void Vmnor(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  BitwiseMaskBinaryOp(rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint8_t {
+    return ~(vs2 | vs1);
+  });
+}
+void Vmxnor(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  BitwiseMaskBinaryOp(rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint8_t {
+    return ~(vs2 ^ vs1);
+  });
+}
+
+// Vector unsigned divide. Note that, just like the scalar divide instruction,
+// a divide by zero does not cause an exception; instead it returns all 1s.
+void Vdivu(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>(
+          rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint8_t {
+            if (vs1 == 0) return ~vs1;
+            return vs2 / vs1;
+          });
+    case 2:
+      return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>(
+          rv_vector, inst, [](uint16_t vs2, uint16_t vs1) -> uint16_t {
+            if (vs1 == 0) return ~vs1;
+            return vs2 / vs1;
+          });
+    case 4:
+      return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>(
+          rv_vector, inst, [](uint32_t vs2, uint32_t vs1) -> uint32_t {
+            if (vs1 == 0) return ~vs1;
+            return vs2 / vs1;
+          });
+    case 8:
+      return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>(
+          rv_vector, inst, [](uint64_t vs2, uint64_t vs1) -> uint64_t {
+            if (vs1 == 0) return ~vs1;
+            return vs2 / vs1;
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Signed divide. Divide by 0 returns all 1s. If the largest magnitude negative
+// number is divided by -1 (the overflow case), it returns that negative
+// number.
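+// For example, at SEW=8 that overflow case is -128 / -1, which returns -128
+// instead of the unrepresentable +128.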
+void Vdiv(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVBinaryVectorOp<int8_t, int8_t, int8_t>(
+          rv_vector, inst, [](int8_t vs2, int8_t vs1) -> int8_t {
+            if (vs1 == 0) return static_cast<int8_t>(-1);
+            if ((vs1 == -1) && (vs2 == std::numeric_limits<int8_t>::min())) {
+              return std::numeric_limits<int8_t>::min();
+            }
+            return vs2 / vs1;
+          });
+    case 2:
+      return RiscVBinaryVectorOp<int16_t, int16_t, int16_t>(
+          rv_vector, inst, [](int16_t vs2, int16_t vs1) -> int16_t {
+            if (vs1 == 0) return static_cast<int16_t>(-1);
+            if ((vs1 == -1) && (vs2 == std::numeric_limits<int16_t>::min())) {
+              return std::numeric_limits<int16_t>::min();
+            }
+            return vs2 / vs1;
+          });
+    case 4:
+      return RiscVBinaryVectorOp<int32_t, int32_t, int32_t>(
+          rv_vector, inst, [](int32_t vs2, int32_t vs1) -> int32_t {
+            if (vs1 == 0) return static_cast<int32_t>(-1);
+            if ((vs1 == -1) && (vs2 == std::numeric_limits<int32_t>::min())) {
+              return std::numeric_limits<int32_t>::min();
+            }
+            return vs2 / vs1;
+          });
+    case 8:
+      return RiscVBinaryVectorOp<int64_t, int64_t, int64_t>(
+          rv_vector, inst, [](int64_t vs2, int64_t vs1) -> int64_t {
+            if (vs1 == 0) return static_cast<int64_t>(-1);
+            if ((vs1 == -1) && (vs2 == std::numeric_limits<int64_t>::min())) {
+              return std::numeric_limits<int64_t>::min();
+            }
+            return vs2 / vs1;
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Unsigned remainder. If the divisor is 0, it returns the dividend.
+void Vremu(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>(
+          rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint8_t {
+            if (vs1 == 0) return vs2;
+            return vs2 % vs1;
+          });
+    case 2:
+      return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>(
+          rv_vector, inst, [](uint16_t vs2, uint16_t vs1) -> uint16_t {
+            if (vs1 == 0) return vs2;
+            return vs2 % vs1;
+          });
+    case 4:
+      return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>(
+          rv_vector, inst, [](uint32_t vs2, uint32_t vs1) -> uint32_t {
+            if (vs1 == 0) return vs2;
+            return vs2 % vs1;
+          });
+    case 8:
+      return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>(
+          rv_vector, inst, [](uint64_t vs2, uint64_t vs1) -> uint64_t {
+            if (vs1 == 0) return vs2;
+            return vs2 % vs1;
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Signed remainder. If the divisor is 0, it returns the dividend.
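+// For example, rem(7, 0) returns 7, mirroring the scalar RISC-V remainder
+// convention.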
+void Vrem(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVBinaryVectorOp<int8_t, int8_t, int8_t>(
+          rv_vector, inst, [](int8_t vs2, int8_t vs1) -> int8_t {
+            if (vs1 == 0) return vs2;
+            // x % -1 is 0 for all x; checking for it also avoids overflow for
+            // the minimum value.
+            if (vs1 == -1) return 0;
+            return vs2 % vs1;
+          });
+    case 2:
+      return RiscVBinaryVectorOp<int16_t, int16_t, int16_t>(
+          rv_vector, inst, [](int16_t vs2, int16_t vs1) -> int16_t {
+            if (vs1 == 0) return vs2;
+            if (vs1 == -1) return 0;
+            return vs2 % vs1;
+          });
+    case 4:
+      return RiscVBinaryVectorOp<int32_t, int32_t, int32_t>(
+          rv_vector, inst, [](int32_t vs2, int32_t vs1) -> int32_t {
+            if (vs1 == 0) return vs2;
+            if (vs1 == -1) return 0;
+            return vs2 % vs1;
+          });
+    case 8:
+      return RiscVBinaryVectorOp<int64_t, int64_t, int64_t>(
+          rv_vector, inst, [](int64_t vs2, int64_t vs1) -> int64_t {
+            if (vs1 == 0) return vs2;
+            if (vs1 == -1) return 0;
+            return vs2 % vs1;
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Helper function for multiply high. It promotes the two arguments to wider
+// types, performs the multiplication, and returns the high half of the result.
+template <typename T>
+inline T VmulHighHelper(T vs2, T vs1) {
+  using WT = typename WideType<T>::type;
+  WT vs2_w = static_cast<WT>(vs2);
+  WT vs1_w = static_cast<WT>(vs1);
+  WT prod = vs2_w * vs1_w;
+  prod >>= sizeof(T) * 8;
+  return static_cast<T>(prod);
+}
+
+// Multiply high, unsigned.
+void Vmulhu(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>(
+          rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint8_t {
+            return VmulHighHelper(vs2, vs1);
+          });
+    case 2:
+      return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>(
+          rv_vector, inst, [](uint16_t vs2, uint16_t vs1) -> uint16_t {
+            return VmulHighHelper(vs2, vs1);
+          });
+    case 4:
+      return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>(
+          rv_vector, inst, [](uint32_t vs2, uint32_t vs1) -> uint32_t {
+            return VmulHighHelper(vs2, vs1);
+          });
+    case 8:
+      return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>(
+          rv_vector, inst, [](uint64_t vs2, uint64_t vs1) -> uint64_t {
+            return VmulHighHelper(vs2, vs1);
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Signed multiply. Note that signed and unsigned multiply operations have the
+// same result for the low half of the product.
+void Vmul(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>(
+          rv_vector, inst,
+          [](uint8_t vs2, uint8_t vs1) -> uint8_t { return vs2 * vs1; });
+    case 2:
+      return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>(
+          rv_vector, inst, [](uint16_t vs2, uint16_t vs1) -> uint16_t {
+            uint32_t vs2_32 = vs2;
+            uint32_t vs1_32 = vs1;
+            return static_cast<uint16_t>(vs2_32 * vs1_32);
+          });
+    case 4:
+      return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>(
+          rv_vector, inst,
+          [](uint32_t vs2, uint32_t vs1) -> uint32_t { return vs2 * vs1; });
+    case 8:
+      // The 64 bit version is treated a little differently. Because the vs1
+      // operand may come from a register which may be 32 bits wide, it's first
+      // converted to int64_t.
The multiplication is then performed on
+      // unsigned values to avoid signed multiply overflow, and the result is
+      // returned as a signed number.
+      return RiscVBinaryVectorOp<int64_t, int64_t, int64_t>(
+          rv_vector, inst, [](int64_t vs2, int64_t vs1) -> int64_t {
+            uint64_t vs2_u = vs2;
+            uint64_t vs1_u = vs1;
+            uint64_t prod = vs2_u * vs1_u;
+            return prod;
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Helper for signed-unsigned multiplication returning the high half.
+template <typename T>
+inline typename std::make_signed<T>::type VmulHighSUHelper(
+    typename std::make_signed<T>::type vs2,
+    typename std::make_unsigned<T>::type vs1) {
+  using WT = typename WideType<T>::type;
+  using WST = typename WideType<typename std::make_signed<T>::type>::type;
+  WST vs2_w = static_cast<WST>(vs2);
+  WT vs1_w = static_cast<WT>(vs1);
+  WST prod = vs2_w * vs1_w;
+  prod >>= sizeof(T) * 8;
+  return static_cast<typename std::make_signed<T>::type>(prod);
+}
+
+// Multiply signed unsigned and return the high half.
+void Vmulhsu(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVBinaryVectorOp<int8_t, int8_t, uint8_t>(
+          rv_vector, inst, [](int8_t vs2, uint8_t vs1) -> int8_t {
+            return VmulHighSUHelper<int8_t>(vs2, vs1);
+          });
+    case 2:
+      return RiscVBinaryVectorOp<int16_t, int16_t, uint16_t>(
+          rv_vector, inst, [](int16_t vs2, uint16_t vs1) -> int16_t {
+            return VmulHighSUHelper<int16_t>(vs2, vs1);
+          });
+    case 4:
+      return RiscVBinaryVectorOp<int32_t, int32_t, uint32_t>(
+          rv_vector, inst, [](int32_t vs2, uint32_t vs1) -> int32_t {
+            return VmulHighSUHelper<int32_t>(vs2, vs1);
+          });
+    case 8:
+      return RiscVBinaryVectorOp<int64_t, int64_t, uint64_t>(
+          rv_vector, inst, [](int64_t vs2, uint64_t vs1) -> int64_t {
+            return VmulHighSUHelper<int64_t>(vs2, vs1);
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Signed multiply, return high half.
+void Vmulh(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVBinaryVectorOp<int8_t, int8_t, int8_t>(
+          rv_vector, inst, [](int8_t vs2, int8_t vs1) -> int8_t {
+            return VmulHighHelper(vs2, vs1);
+          });
+    case 2:
+      return RiscVBinaryVectorOp<int16_t, int16_t, int16_t>(
+          rv_vector, inst, [](int16_t vs2, int16_t vs1) -> int16_t {
+            return VmulHighHelper(vs2, vs1);
+          });
+    case 4:
+      return RiscVBinaryVectorOp<int32_t, int32_t, int32_t>(
+          rv_vector, inst, [](int32_t vs2, int32_t vs1) -> int32_t {
+            return VmulHighHelper(vs2, vs1);
+          });
+    case 8:
+      return RiscVBinaryVectorOp<int64_t, int64_t, int64_t>(
+          rv_vector, inst, [](int64_t vs2, int64_t vs1) -> int64_t {
+            return VmulHighHelper(vs2, vs1);
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Multiply-add.
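+// Computes vd[i] = (vs1[i] * vd[i]) + vs2[i], keeping the low SEW bits of the
+// product.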
+void Vmadd(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVTernaryVectorOp<uint8_t, uint8_t, uint8_t>(
+          rv_vector, inst, [](uint8_t vs2, uint8_t vs1, uint8_t vd) -> uint8_t {
+            uint8_t prod = vs1 * vd;
+            return prod + vs2;
+          });
+    case 2:
+      return RiscVTernaryVectorOp<uint16_t, uint16_t, uint16_t>(
+          rv_vector, inst,
+          [](uint16_t vs2, uint16_t vs1, uint16_t vd) -> uint16_t {
+            uint32_t vs2_32 = vs2;
+            uint32_t vs1_32 = vs1;
+            uint32_t vd_32 = vd;
+            return static_cast<uint16_t>(vs1_32 * vd_32 + vs2_32);
+          });
+    case 4:
+      return RiscVTernaryVectorOp<uint32_t, uint32_t, uint32_t>(
+          rv_vector, inst,
+          [](uint32_t vs2, uint32_t vs1, uint32_t vd) -> uint32_t {
+            return vs1 * vd + vs2;
+          });
+    case 8:
+      return RiscVTernaryVectorOp<uint64_t, uint64_t, uint64_t>(
+          rv_vector, inst,
+          [](uint64_t vs2, uint64_t vs1, uint64_t vd) -> uint64_t {
+            return vs1 * vd + vs2;
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Negated multiply and add.
+void Vnmsub(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVTernaryVectorOp<uint8_t, uint8_t, uint8_t>(
+          rv_vector, inst, [](uint8_t vs2, uint8_t vs1, uint8_t vd) -> uint8_t {
+            return -(vs1 * vd) + vs2;
+          });
+    case 2:
+      return RiscVTernaryVectorOp<uint16_t, uint16_t, uint16_t>(
+          rv_vector, inst,
+          [](uint16_t vs2, uint16_t vs1, uint16_t vd) -> uint16_t {
+            uint32_t vs2_32 = vs2;
+            uint32_t vs1_32 = vs1;
+            uint32_t vd_32 = vd;
+            return static_cast<uint16_t>(-(vs1_32 * vd_32) + vs2_32);
+          });
+    case 4:
+      return RiscVTernaryVectorOp<uint32_t, uint32_t, uint32_t>(
+          rv_vector, inst,
+          [](uint32_t vs2, uint32_t vs1, uint32_t vd) -> uint32_t {
+            return -(vs1 * vd) + vs2;
+          });
+    case 8:
+      return RiscVTernaryVectorOp<uint64_t, uint64_t, uint64_t>(
+          rv_vector, inst,
+          [](uint64_t vs2, uint64_t vs1, uint64_t vd) -> uint64_t {
+            return -(vs1 * vd) + vs2;
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Multiply add overwriting the sum.
+void Vmacc(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVTernaryVectorOp<uint8_t, uint8_t, uint8_t>(
+          rv_vector, inst, [](uint8_t vs2, uint8_t vs1, uint8_t vd) -> uint8_t {
+            return vs1 * vs2 + vd;
+          });
+    case 2:
+      return RiscVTernaryVectorOp<uint16_t, uint16_t, uint16_t>(
+          rv_vector, inst,
+          [](uint16_t vs2, uint16_t vs1, uint16_t vd) -> uint16_t {
+            uint32_t vs2_32 = vs2;
+            uint32_t vs1_32 = vs1;
+            uint32_t vd_32 = vd;
+            return static_cast<uint16_t>(vs1_32 * vs2_32 + vd_32);
+          });
+    case 4:
+      return RiscVTernaryVectorOp<uint32_t, uint32_t, uint32_t>(
+          rv_vector, inst,
+          [](uint32_t vs2, uint32_t vs1, uint32_t vd) -> uint32_t {
+            return vs1 * vs2 + vd;
+          });
+    case 8:
+      return RiscVTernaryVectorOp<uint64_t, uint64_t, uint64_t>(
+          rv_vector, inst,
+          [](uint64_t vs2, uint64_t vs1, uint64_t vd) -> uint64_t {
+            return vs1 * vs2 + vd;
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Negated multiply add, overwriting sum.
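+// Computes vd[i] = -(vs1[i] * vs2[i]) + vd[i].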
+void Vnmsac(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVTernaryVectorOp<uint8_t, uint8_t, uint8_t>(
+          rv_vector, inst, [](uint8_t vs2, uint8_t vs1, uint8_t vd) -> uint8_t {
+            return -(vs1 * vs2) + vd;
+          });
+    case 2:
+      return RiscVTernaryVectorOp<uint16_t, uint16_t, uint16_t>(
+          rv_vector, inst,
+          [](uint16_t vs2, uint16_t vs1, uint16_t vd) -> uint16_t {
+            uint32_t vs2_32 = vs2;
+            uint32_t vs1_32 = vs1;
+            uint32_t vd_32 = vd;
+            return static_cast<uint16_t>(-(vs1_32 * vs2_32) + vd_32);
+          });
+    case 4:
+      return RiscVTernaryVectorOp<uint32_t, uint32_t, uint32_t>(
+          rv_vector, inst,
+          [](uint32_t vs2, uint32_t vs1, uint32_t vd) -> uint32_t {
+            return -(vs1 * vs2) + vd;
+          });
+    case 8:
+      return RiscVTernaryVectorOp<uint64_t, uint64_t, uint64_t>(
+          rv_vector, inst,
+          [](uint64_t vs2, uint64_t vs1, uint64_t vd) -> uint64_t {
+            return -(vs1 * vs2) + vd;
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Widening unsigned add.
+void Vwaddu(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVBinaryVectorOp<uint16_t, uint8_t, uint8_t>(
+          rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint16_t {
+            return static_cast<uint16_t>(vs2) + static_cast<uint16_t>(vs1);
+          });
+    case 2:
+      return RiscVBinaryVectorOp<uint32_t, uint16_t, uint16_t>(
+          rv_vector, inst, [](uint16_t vs2, uint16_t vs1) -> uint32_t {
+            return static_cast<uint32_t>(vs2) + static_cast<uint32_t>(vs1);
+          });
+    case 4:
+      return RiscVBinaryVectorOp<uint64_t, uint32_t, uint32_t>(
+          rv_vector, inst, [](uint32_t vs2, uint32_t vs1) -> uint64_t {
+            return static_cast<uint64_t>(vs2) + static_cast<uint64_t>(vs1);
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Widening unsigned subtract.
+void Vwsubu(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVBinaryVectorOp<uint16_t, uint8_t, uint8_t>(
+          rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint16_t {
+            return static_cast<uint16_t>(vs2) - static_cast<uint16_t>(vs1);
+          });
+    case 2:
+      return RiscVBinaryVectorOp<uint32_t, uint16_t, uint16_t>(
+          rv_vector, inst, [](uint16_t vs2, uint16_t vs1) -> uint32_t {
+            return static_cast<uint32_t>(vs2) - static_cast<uint32_t>(vs1);
+          });
+    case 4:
+      return RiscVBinaryVectorOp<uint64_t, uint32_t, uint32_t>(
+          rv_vector, inst, [](uint32_t vs2, uint32_t vs1) -> uint64_t {
+            return static_cast<uint64_t>(vs2) - static_cast<uint64_t>(vs1);
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Widening signed add.
+void Vwadd(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  // LMUL8 cannot be 64.
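+// vector_length_multiplier() returns 8 * LMUL, so a value above 32 (LMUL > 4)
+// would require the widened destination group to exceed the maximum LMUL of 8.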
+  if (rv_vector->vector_length_multiplier() > 32) {
+    rv_vector->set_vector_exception();
+    LOG(ERROR)
+        << "Vector length multiplier out of range for widening operation";
+    return;
+  }
+  // The values are first sign-extended to the wide signed type, then an
+  // unsigned addition is performed, since unsigned overflow is well defined,
+  // unlike signed overflow.
+  switch (sew) {
+    case 1:
+      return RiscVBinaryVectorOp<uint16_t, int8_t, int8_t>(
+          rv_vector, inst, [](int8_t vs2, int8_t vs1) -> uint16_t {
+            return static_cast<uint16_t>(static_cast<int16_t>(vs2)) +
+                   static_cast<uint16_t>(static_cast<int16_t>(vs1));
+          });
+    case 2:
+      return RiscVBinaryVectorOp<uint32_t, int16_t, int16_t>(
+          rv_vector, inst, [](int16_t vs2, int16_t vs1) -> uint32_t {
+            return static_cast<uint32_t>(static_cast<int32_t>(vs2)) +
+                   static_cast<uint32_t>(static_cast<int32_t>(vs1));
+          });
+    case 4:
+      return RiscVBinaryVectorOp<uint64_t, int32_t, int32_t>(
+          rv_vector, inst, [](int32_t vs2, int32_t vs1) -> uint64_t {
+            return static_cast<uint64_t>(static_cast<int64_t>(vs2)) +
+                   static_cast<uint64_t>(static_cast<int64_t>(vs1));
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Widening signed subtract.
+void Vwsub(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  // LMUL8 cannot be 64.
+  if (rv_vector->vector_length_multiplier() > 32) {
+    rv_vector->set_vector_exception();
+    LOG(ERROR)
+        << "Vector length multiplier out of range for widening operation";
+    return;
+  }
+  // The values are first sign-extended to the wide signed type, then an
+  // unsigned subtraction is performed, since unsigned overflow is well
+  // defined, unlike signed overflow.
+  switch (sew) {
+    case 1:
+      return RiscVBinaryVectorOp<uint16_t, int8_t, int8_t>(
+          rv_vector, inst, [](int8_t vs2, int8_t vs1) -> uint16_t {
+            return static_cast<uint16_t>(static_cast<int16_t>(vs2)) -
+                   static_cast<uint16_t>(static_cast<int16_t>(vs1));
+          });
+    case 2:
+      return RiscVBinaryVectorOp<uint32_t, int16_t, int16_t>(
+          rv_vector, inst, [](int16_t vs2, int16_t vs1) -> uint32_t {
+            return static_cast<uint32_t>(static_cast<int32_t>(vs2)) -
+                   static_cast<uint32_t>(static_cast<int32_t>(vs1));
+          });
+    case 4:
+      return RiscVBinaryVectorOp<uint64_t, int32_t, int32_t>(
+          rv_vector, inst, [](int32_t vs2, int32_t vs1) -> uint64_t {
+            return static_cast<uint64_t>(static_cast<int64_t>(vs2)) -
+                   static_cast<uint64_t>(static_cast<int64_t>(vs1));
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Widening unsigned add with wide source.
+void Vwadduw(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  // LMUL8 cannot be 64.
+  if (rv_vector->vector_length_multiplier() > 32) {
+    rv_vector->set_vector_exception();
+    LOG(ERROR)
+        << "Vector length multiplier out of range for widening operation";
+    return;
+  }
+  switch (sew) {
+    case 1:
+      return RiscVBinaryVectorOp<uint16_t, uint16_t, uint8_t>(
+          rv_vector, inst, [](uint16_t vs2, uint8_t vs1) -> uint16_t {
+            return vs2 + static_cast<uint16_t>(vs1);
+          });
+    case 2:
+      return RiscVBinaryVectorOp<uint32_t, uint32_t, uint16_t>(
+          rv_vector, inst, [](uint32_t vs2, uint16_t vs1) -> uint32_t {
+            return vs2 + static_cast<uint32_t>(vs1);
+          });
+    case 4:
+      return RiscVBinaryVectorOp<uint64_t, uint64_t, uint32_t>(
+          rv_vector, inst, [](uint64_t vs2, uint32_t vs1) -> uint64_t {
+            return vs2 + static_cast<uint64_t>(vs1);
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Widening unsigned subtract with wide source.
+void Vwsubuw(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  // LMUL8 cannot be 64.
+  if (rv_vector->vector_length_multiplier() > 32) {
+    rv_vector->set_vector_exception();
+    LOG(ERROR)
+        << "Vector length multiplier out of range for widening operation";
+    return;
+  }
+  switch (sew) {
+    case 1:
+      return RiscVBinaryVectorOp<uint16_t, uint16_t, uint8_t>(
+          rv_vector, inst, [](uint16_t vs2, uint8_t vs1) -> uint16_t {
+            return vs2 - static_cast<uint16_t>(vs1);
+          });
+    case 2:
+      return RiscVBinaryVectorOp<uint32_t, uint32_t, uint16_t>(
+          rv_vector, inst, [](uint32_t vs2, uint16_t vs1) -> uint32_t {
+            return vs2 - static_cast<uint32_t>(vs1);
+          });
+    case 4:
+      return RiscVBinaryVectorOp<uint64_t, uint64_t, uint32_t>(
+          rv_vector, inst, [](uint64_t vs2, uint32_t vs1) -> uint64_t {
+            return vs2 - static_cast<uint64_t>(vs1);
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Widening signed add with wide source.
+void Vwaddw(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  // LMUL8 cannot be 64.
+  if (rv_vector->vector_length_multiplier() > 32) {
+    rv_vector->set_vector_exception();
+    LOG(ERROR)
+        << "Vector length multiplier out of range for widening operation";
+    return;
+  }
+  switch (sew) {
+    case 1:
+      return RiscVBinaryVectorOp<uint16_t, uint16_t, int8_t>(
+          rv_vector, inst, [](uint16_t vs2, int8_t vs1) -> uint16_t {
+            return vs2 + static_cast<uint16_t>(static_cast<int16_t>(vs1));
+          });
+    case 2:
+      return RiscVBinaryVectorOp<uint32_t, uint32_t, int16_t>(
+          rv_vector, inst, [](uint32_t vs2, int16_t vs1) -> uint32_t {
+            return vs2 + static_cast<uint32_t>(static_cast<int32_t>(vs1));
+          });
+    case 4:
+      return RiscVBinaryVectorOp<uint64_t, uint64_t, int32_t>(
+          rv_vector, inst, [](uint64_t vs2, int32_t vs1) -> uint64_t {
+            return vs2 + static_cast<uint64_t>(static_cast<int64_t>(vs1));
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Widening signed subtract with wide source.
+void Vwsubw(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  // LMUL8 cannot be 64.
+  if (rv_vector->vector_length_multiplier() > 32) {
+    rv_vector->set_vector_exception();
+    LOG(ERROR)
+        << "Vector length multiplier out of range for widening operation";
+    return;
+  }
+  switch (sew) {
+    case 1:
+      return RiscVBinaryVectorOp<uint16_t, uint16_t, int8_t>(
+          rv_vector, inst, [](uint16_t vs2, int8_t vs1) -> uint16_t {
+            return vs2 - static_cast<uint16_t>(static_cast<int16_t>(vs1));
+          });
+    case 2:
+      return RiscVBinaryVectorOp<uint32_t, uint32_t, int16_t>(
+          rv_vector, inst, [](uint32_t vs2, int16_t vs1) -> uint32_t {
+            return vs2 - static_cast<uint32_t>(static_cast<int32_t>(vs1));
+          });
+    case 4:
+      return RiscVBinaryVectorOp<uint64_t, uint64_t, int32_t>(
+          rv_vector, inst, [](uint64_t vs2, int32_t vs1) -> uint64_t {
+            return vs2 - static_cast<uint64_t>(static_cast<int64_t>(vs1));
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Widening multiply helper function. Factors out some code.
+template <typename T>
+inline typename WideType<T>::type VwmulHelper(T vs2, T vs1) {
+  using WT = typename WideType<T>::type;
+  WT vs2_w = static_cast<WT>(vs2);
+  WT vs1_w = static_cast<WT>(vs1);
+  WT prod = vs2_w * vs1_w;
+  return prod;
+}
+
+// Unsigned widening multiply.
+void Vwmulu(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  // LMUL8 cannot be 64.
+  if (rv_vector->vector_length_multiplier() > 32) {
+    rv_vector->set_vector_exception();
+    LOG(ERROR)
+        << "Vector length multiplier out of range for widening operation";
+    return;
+  }
+  switch (sew) {
+    case 1:
+      return RiscVBinaryVectorOp<uint16_t, uint8_t, uint8_t>(
+          rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint16_t {
+            return VwmulHelper<uint8_t>(vs2, vs1);
+          });
+    case 2:
+      return RiscVBinaryVectorOp<uint32_t, uint16_t, uint16_t>(
+          rv_vector, inst, [](uint16_t vs2, uint16_t vs1) -> uint32_t {
+            return VwmulHelper<uint16_t>(vs2, vs1);
+          });
+    case 4:
+      return RiscVBinaryVectorOp<uint64_t, uint32_t, uint32_t>(
+          rv_vector, inst, [](uint32_t vs2, uint32_t vs1) -> uint64_t {
+            return VwmulHelper<uint32_t>(vs2, vs1);
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Widening signed-unsigned multiply helper function.
+template <typename T>
+inline typename WideType<typename std::make_signed<T>::type>::type
+VwmulSuHelper(typename std::make_signed<T>::type vs2,
+              typename std::make_unsigned<T>::type vs1) {
+  using WST = typename WideType<typename std::make_signed<T>::type>::type;
+  using WT = typename WideType<typename std::make_unsigned<T>::type>::type;
+  WST vs2_w = static_cast<WST>(vs2);
+  WT vs1_w = static_cast<WT>(vs1);
+  WST prod = vs2_w * vs1_w;
+  return prod;
+}
+
+// Widening multiply signed-unsigned.
+void Vwmulsu(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  // LMUL8 cannot be 64.
+  if (rv_vector->vector_length_multiplier() > 32) {
+    rv_vector->set_vector_exception();
+    LOG(ERROR)
+        << "Vector length multiplier out of range for widening operation";
+    return;
+  }
+  switch (sew) {
+    case 1:
+      return RiscVBinaryVectorOp<int16_t, int8_t, uint8_t>(
+          rv_vector, inst, [](int8_t vs2, uint8_t vs1) -> int16_t {
+            return VwmulSuHelper<int8_t>(vs2, vs1);
+          });
+    case 2:
+      return RiscVBinaryVectorOp<int32_t, int16_t, uint16_t>(
+          rv_vector, inst, [](int16_t vs2, uint16_t vs1) -> int32_t {
+            return VwmulSuHelper<int16_t>(vs2, vs1);
+          });
+    case 4:
+      return RiscVBinaryVectorOp<int64_t, int32_t, uint32_t>(
+          rv_vector, inst, [](int32_t vs2, uint32_t vs1) -> int64_t {
+            return VwmulSuHelper<int32_t>(vs2, vs1);
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Widening signed multiply.
+void Vwmul(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  // LMUL8 cannot be 64.
+  if (rv_vector->vector_length_multiplier() > 32) {
+    rv_vector->set_vector_exception();
+    LOG(ERROR)
+        << "Vector length multiplier out of range for widening operation";
+    return;
+  }
+  switch (sew) {
+    case 1:
+      return RiscVBinaryVectorOp<int16_t, int8_t, int8_t>(
+          rv_vector, inst, [](int8_t vs2, int8_t vs1) -> int16_t {
+            return VwmulHelper<int8_t>(vs2, vs1);
+          });
+    case 2:
+      return RiscVBinaryVectorOp<int32_t, int16_t, int16_t>(
+          rv_vector, inst, [](int16_t vs2, int16_t vs1) -> int32_t {
+            return VwmulHelper<int16_t>(vs2, vs1);
+          });
+    case 4:
+      return RiscVBinaryVectorOp<int64_t, int32_t, int32_t>(
+          rv_vector, inst, [](int32_t vs2, int32_t vs1) -> int64_t {
+            return VwmulHelper<int32_t>(vs2, vs1);
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Widening multiply accumulate helper function.
+template <typename Vd, typename Vs2, typename Vs1>
+Vd VwmaccHelper(Vs2 vs2, Vs1 vs1, Vd vd) {
+  Vd vs2_w = static_cast<Vd>(vs2);
+  Vd vs1_w = static_cast<Vd>(vs1);
+  Vd prod = vs2_w * vs1_w;
+  using UVd = typename std::make_unsigned<Vd>::type;
+  Vd res = absl::bit_cast<UVd>(prod) + absl::bit_cast<UVd>(vd);
+  return res;
+}
+
+// Unsigned widening multiply and add.
+void Vwmaccu(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  // LMUL8 cannot be 64.
+  if (rv_vector->vector_length_multiplier() > 32) {
+    rv_vector->set_vector_exception();
+    LOG(ERROR)
+        << "Vector length multiplier out of range for widening operation";
+    return;
+  }
+  switch (sew) {
+    case 1:
+      return RiscVTernaryVectorOp<uint16_t, uint8_t, uint8_t>(
+          rv_vector, inst,
+          [](uint8_t vs2, uint8_t vs1, uint16_t vd) -> uint16_t {
+            return VwmaccHelper(vs2, vs1, vd);
+          });
+    case 2:
+      return RiscVTernaryVectorOp<uint32_t, uint16_t, uint16_t>(
+          rv_vector, inst,
+          [](uint16_t vs2, uint16_t vs1, uint32_t vd) -> uint32_t {
+            return VwmaccHelper(vs2, vs1, vd);
+          });
+    case 4:
+      return RiscVTernaryVectorOp<uint64_t, uint32_t, uint32_t>(
+          rv_vector, inst,
+          [](uint32_t vs2, uint32_t vs1, uint64_t vd) -> uint64_t {
+            return VwmaccHelper(vs2, vs1, vd);
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Widening signed multiply and add.
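+// Computes vd[i] = sext(vs2[i]) * sext(vs1[i]) + vd[i], with the vd elements
+// held at twice the source element width.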
+void Vwmacc(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  // LMUL8 cannot be 64.
+  if (rv_vector->vector_length_multiplier() > 32) {
+    rv_vector->set_vector_exception();
+    LOG(ERROR)
+        << "Vector length multiplier out of range for widening operation";
+    return;
+  }
+  switch (sew) {
+    case 1:
+      return RiscVTernaryVectorOp<int16_t, int8_t, int8_t>(
+          rv_vector, inst, [](int8_t vs2, int8_t vs1, int16_t vd) -> int16_t {
+            return VwmaccHelper(vs2, vs1, vd);
+          });
+    case 2:
+      return RiscVTernaryVectorOp<int32_t, int16_t, int16_t>(
+          rv_vector, inst, [](int16_t vs2, int16_t vs1, int32_t vd) -> int32_t {
+            return VwmaccHelper(vs2, vs1, vd);
+          });
+    case 4:
+      return RiscVTernaryVectorOp<int64_t, int32_t, int32_t>(
+          rv_vector, inst, [](int32_t vs2, int32_t vs1, int64_t vd) -> int64_t {
+            return VwmaccHelper(vs2, vs1, vd);
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Widening unsigned-signed multiply and add.
+void Vwmaccus(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  // LMUL8 cannot be 64.
+  if (rv_vector->vector_length_multiplier() > 32) {
+    rv_vector->set_vector_exception();
+    LOG(ERROR)
+        << "Vector length multiplier out of range for widening operation";
+    return;
+  }
+  switch (sew) {
+    case 1:
+      return RiscVTernaryVectorOp<int16_t, int8_t, uint8_t>(
+          rv_vector, inst, [](int8_t vs2, uint8_t vs1, int16_t vd) -> int16_t {
+            return VwmaccHelper(vs2, vs1, vd);
+          });
+    case 2:
+      return RiscVTernaryVectorOp<int32_t, int16_t, uint16_t>(
+          rv_vector, inst,
+          [](int16_t vs2, uint16_t vs1, int32_t vd) -> int32_t {
+            return VwmaccHelper(vs2, vs1, vd);
+          });
+    case 4:
+      return RiscVTernaryVectorOp<int64_t, int32_t, uint32_t>(
+          rv_vector, inst,
+          [](int32_t vs2, uint32_t vs1, int64_t vd) -> int64_t {
+            return VwmaccHelper(vs2, vs1, vd);
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Widening signed-unsigned multiply and add.
+void Vwmaccsu(const Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  // LMUL8 cannot be 64.
+  if (rv_vector->vector_length_multiplier() > 32) {
+    rv_vector->set_vector_exception();
+    LOG(ERROR)
+        << "Vector length multiplier out of range for widening operation";
+    return;
+  }
+  switch (sew) {
+    case 1:
+      return RiscVTernaryVectorOp<int16_t, uint8_t, int8_t>(
+          rv_vector, inst, [](uint8_t vs2, int8_t vs1, int16_t vd) -> int16_t {
+            return VwmaccHelper(vs2, vs1, vd);
+          });
+    case 2:
+      return RiscVTernaryVectorOp<int32_t, uint16_t, int16_t>(
+          rv_vector, inst,
+          [](uint16_t vs2, int16_t vs1, int32_t vd) -> int32_t {
+            return VwmaccHelper(vs2, vs1, vd);
+          });
+    case 4:
+      return RiscVTernaryVectorOp<int64_t, uint32_t, int32_t>(
+          rv_vector, inst,
+          [](uint32_t vs2, int32_t vs1, int64_t vd) -> int64_t {
+            return VwmaccHelper(vs2, vs1, vd);
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+}  // namespace cheriot
+}  // namespace sim
+}  // namespace mpact
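Editorial note: a recurring idiom in the file above is doing the wide arithmetic on unsigned values and converting back, because unsigned wrap-around is well defined in C++ while signed overflow is undefined behavior. A minimal standalone sketch of the idiom as used in VwmaccHelper (illustrative, not part of the commit):

#include <cstdint>

#include "absl/base/casts.h"

// Widening multiply-accumulate in the style of VwmaccHelper: the 64-bit
// product of two int32_t values cannot overflow, but adding the accumulator
// can, so the addition is done on unsigned values and the bits reinterpreted.
int64_t WideningMacc(int32_t vs2, int32_t vs1, int64_t vd) {
  int64_t prod = static_cast<int64_t>(vs2) * static_cast<int64_t>(vs1);
  uint64_t sum = absl::bit_cast<uint64_t>(prod) + absl::bit_cast<uint64_t>(vd);
  return absl::bit_cast<int64_t>(sum);  // Two's complement wrap-around.
}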
diff --git a/cheriot/riscv_cheriot_vector_opm_instructions.h b/cheriot/riscv_cheriot_vector_opm_instructions.h
new file mode 100644
index 0000000..9dbd61f
--- /dev/null
+++ b/cheriot/riscv_cheriot_vector_opm_instructions.h
@@ -0,0 +1,180 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MPACT_CHERIOT_RISCV_CHERIOT_VECTOR_OPM_INSTRUCTIONS_H_
+#define MPACT_CHERIOT_RISCV_CHERIOT_VECTOR_OPM_INSTRUCTIONS_H_
+
+#include "mpact/sim/generic/instruction.h"
+
+namespace mpact {
+namespace sim {
+namespace cheriot {
+
+using Instruction = ::mpact::sim::generic::Instruction;
+
+// Integer vector operations.
+// Integer average unsigned add. This instruction takes 3 source operands and
+// one destination. Source 0 is vs2, source 1 is vs1 (or rs1), and source 2
+// is vector mask. The destination operand is a vector register group.
+void Vaaddu(const Instruction *inst);
+// Integer average signed add. This instruction takes 3 source operands and
+// one destination. Source 0 is vs2, source 1 is vs1 (or rs1), and source 2
+// is vector mask. The destination operand is a vector register group.
+void Vaadd(const Instruction *inst);
+// Integer average unsigned subtract. This instruction takes 3 source operands
+// and one destination. Source 0 is vs2, source 1 is vs1 (or rs1), and source 2
+// is vector mask. The destination operand is a vector register group.
+void Vasubu(const Instruction *inst);
+// Integer average signed subtract. This instruction takes 3 source operands
+// and one destination. Source 0 is vs2, source 1 is vs1 (or rs1), and source 2
+// is vector mask. The destination operand is a vector register group.
+void Vasub(const Instruction *inst);
+// The following instructions are vector mask logical operations. Each takes
+// two source operands, vs2 and vs1 vector registers, and one destination
+// operand, vd vector destination register.
+void Vmandnot(const Instruction *inst);
+void Vmand(const Instruction *inst);
+void Vmor(const Instruction *inst);
+void Vmxor(const Instruction *inst);
+void Vmornot(const Instruction *inst);
+void Vmnand(const Instruction *inst);
+void Vmnor(const Instruction *inst);
+void Vmxnor(const Instruction *inst);
+// Integer unsigned division. This instruction takes 3 source operands and
+// one destination. Source 0 is vs2, source 1 is vs1 (or rs1), and source 2
+// is vector mask. The destination operand is a vector register group.
+void Vdivu(const Instruction *inst);
+// Integer signed division. This instruction takes 3 source operands and
+// one destination. Source 0 is vs2, source 1 is vs1 (or rs1), and source 2
+// is vector mask. The destination operand is a vector register group.
+void Vdiv(const Instruction *inst);
+// Integer unsigned remainder. This instruction takes 3 source operands and
+// one destination. Source 0 is vs2, source 1 is vs1 (or rs1), and source 2
+// is vector mask. The destination operand is a vector register group.
+void Vremu(const Instruction *inst);
+// Integer signed remainder. This instruction takes 3 source operands and
+// one destination. Source 0 is vs2, source 1 is vs1 (or rs1), and source 2
+// is vector mask. The destination operand is a vector register group.
+void Vrem(const Instruction *inst);
+// Integer unsigned multiply high. This instruction takes 3 source operands and
+// one destination. Source 0 is vs2, source 1 is vs1 (or rs1), and source 2
+// is vector mask. The destination operand is a vector register group.
+void Vmulhu(const Instruction *inst);
+// Integer signed multiply. This instruction takes 3 source operands and
+// one destination. Source 0 is vs2, source 1 is vs1 (or rs1), and source 2
+// is vector mask. The destination operand is a vector register group.
+void Vmul(const Instruction *inst);
+// Integer signed-unsigned multiply high. This instruction takes 3 source
+// operands and one destination. Source 0 is vs2, source 1 is vs1 (or rs1), and
+// source 2 is vector mask. The destination operand is a vector register group.
+void Vmulhsu(const Instruction *inst);
+// Integer signed multiply high. This instruction takes 3 source operands and
+// one destination. Source 0 is vs2, source 1 is vs1 (or rs1), and source 2
+// is vector mask. The destination operand is a vector register group.
+void Vmulh(const Instruction *inst);
+// Integer multiply add (vs1 * vd) + vs2. This instruction takes 4 source
+// operands and one destination operand. Source 0 is vs2, source 1 is vs1 (or
+// rs1), source 2 is Vd as a source operand, and source 3 is vector mask. The
+// destination operand is the Vd register group.
+void Vmadd(const Instruction *inst);
+// Integer multiply subtract -(vs1 * vd) + vs2. This instruction takes 4
+// source operands and one destination operand. Source 0 is vs2, source 1 is
+// vs1 (or rs1), source 2 is Vd as a source operand, and source 3 is vector
+// mask. The destination operand is the Vd register group.
+void Vnmsub(const Instruction *inst);
+// Integer multiply add (vs1 * vs2) + vd. This instruction takes 4 source
+// operands and one destination operand. Source 0 is vs2, source 1 is vs1 (or
+// rs1), source 2 is Vd as a source operand, and source 3 is vector mask. The
+// destination operand is the Vd register group.
+void Vmacc(const Instruction *inst);
+// Integer multiply subtract -(vs1 * vs2) + vd. This instruction takes 4 source
+// operands and one destination operand. Source 0 is vs2, source 1 is vs1 (or
+// rs1), source 2 is Vd as a source operand, and source 3 is vector mask. The
+// destination operand is the Vd register group.
+void Vnmsac(const Instruction *inst);
+// Integer widening unsigned addition. This instruction takes 3 source operands
+// and one destination. Source 0 is vs2, source 1 is vs1 (or rs1), and source 2
+// is vector mask. The destination operand is a vector register group.
+void Vwaddu(const Instruction *inst);
+// Integer widening signed addition. This instruction takes 3 source operands
+// and one destination. Source 0 is vs2, source 1 is vs1 (or rs1), and source 2
+// is vector mask. The destination operand is a vector register group.
+void Vwadd(const Instruction *inst);
+// Integer widening unsigned subtraction. This instruction takes 3 source
+// operands and one destination. Source 0 is vs2, source 1 is vs1 (or rs1), and
+// source 2 is vector mask. The destination operand is a vector register group.
+void Vwsubu(const Instruction *inst);
+// Integer widening signed subtraction. This instruction takes 3 source
+// operands and one destination. Source 0 is vs2, source 1 is vs1 (or rs1), and
+// source 2 is vector mask. The destination operand is a vector register group.
+void Vwsub(const Instruction *inst);
+// Integer widening unsigned addition with one wide source operand. This
+// instruction takes 3 source operands and one destination. Source 0 is vs2
+// (wide), source 1 is vs1 (or rs1), and source 2 is vector mask. The
+// destination operand is a vector register group.
+void Vwadduw(const Instruction *inst);
+// Integer widening signed addition with one wide source operand. This
+// instruction takes 3 source operands and one destination. Source 0 is vs2
+// (wide), source 1 is vs1 (or rs1), and source 2 is vector mask. The
+// destination operand is a vector register group.
+void Vwaddw(const Instruction *inst);
+// Integer widening unsigned subtraction with one wide source operand. This
+// instruction takes 3 source operands and one destination. Source 0 is vs2
+// (wide), source 1 is vs1 (or rs1), and source 2 is vector mask. The
+// destination operand is a vector register group.
+void Vwsubuw(const Instruction *inst);
+// Integer widening signed subtraction with one wide source operand. This
+// instruction takes 3 source operands and one destination. Source 0 is vs2
+// (wide), source 1 is vs1 (or rs1), and source 2 is vector mask. The
+// destination operand is a vector register group.
+void Vwsubw(const Instruction *inst);
+// Integer widening unsigned multiplication. This instruction takes 3 source
+// operands and one destination. Source 0 is vs2, source 1 is vs1 (or rs1), and
+// source 2 is vector mask. The destination operand is a vector register group.
+void Vwmulu(const Instruction *inst);
+// Integer widening signed by unsigned multiplication. This instruction takes 3
+// source operands and one destination. Source 0 is vs2, source 1 is vs1 (or
+// rs1), and source 2 is vector mask. The destination operand is a vector
+// register group.
+void Vwmulsu(const Instruction *inst);
+// Integer widening signed multiplication. This instruction takes 3 source
+// operands and one destination. Source 0 is vs2, source 1 is vs1 (or rs1), and
+// source 2 is vector mask. The destination operand is a vector register group.
+void Vwmul(const Instruction *inst);
+// Integer widening unsigned multiply and add (vs2 * vs1) + vd. This
+// instruction takes 4 source operands and one destination operand. Source 0 is
+// vs2, source 1 is vs1 (or rs1), source 2 is Vd as a source operand, and
+// source 3 is vector mask. The destination operand is the Vd register group.
+void Vwmaccu(const Instruction *inst);
+// Integer widening signed multiply and add (vs2 * vs1) + vd. This instruction
+// takes 4 source operands and one destination operand. Source 0 is vs2, source
+// 1 is vs1 (or rs1), source 2 is Vd as a source operand, and source 3 is
+// vector mask. The destination operand is the Vd register group.
+void Vwmacc(const Instruction *inst);
+// Integer widening unsigned by signed multiply and add (vs2 * vs1) + vd. This
+// instruction takes 4 source operands and one destination operand. Source 0 is
+// vs2, source 1 is vs1 (or rs1), source 2 is Vd as a source operand, and
+// source 3 is vector mask. The destination operand is the Vd register group.
+void Vwmaccus(const Instruction *inst);
+// Integer widening signed by unsigned multiply and add (vs2 * vs1) + vd. This
+// instruction takes 4 source operands and one destination operand. Source 0 is
+// vs2, source 1 is vs1 (or rs1), source 2 is Vd as a source operand, and
+// source 3 is vector mask. The destination operand is the Vd register group.
+void Vwmaccsu(const Instruction *inst); + +} // namespace cheriot +} // namespace sim +} // namespace mpact + +#endif // MPACT_CHERIOT_RISCV_CHERIOT_VECTOR_OPM_INSTRUCTIONS_H_
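As a cross-check on the operand comments above, the four multiply-add forms differ only in which value is multiplied and which is added. A scalar model of the per-element arithmetic (illustrative only; these helper names are not part of this change):

#include <cstdint>

// Per-element semantics of the integer multiply-add family.
inline int32_t MaddElement(int32_t vs1, int32_t vs2, int32_t vd) {
  return (vs1 * vd) + vs2;  // vmadd: vd is an input to the product.
}
inline int32_t NmsubElement(int32_t vs1, int32_t vs2, int32_t vd) {
  return -(vs1 * vd) + vs2;  // vnmsub: negated product, vs2 added.
}
inline int32_t MaccElement(int32_t vs1, int32_t vs2, int32_t vd) {
  return (vs1 * vs2) + vd;  // vmacc: vd accumulates the product.
}
inline int32_t NmsacElement(int32_t vs1, int32_t vs2, int32_t vd) {
  return -(vs1 * vs2) + vd;  // vnmsac: vd accumulates the negated product.
}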
diff --git a/cheriot/riscv_cheriot_vector_permute_instructions.cc b/cheriot/riscv_cheriot_vector_permute_instructions.cc new file mode 100644 index 0000000..b5bad6e --- /dev/null +++ b/cheriot/riscv_cheriot_vector_permute_instructions.cc
@@ -0,0 +1,465 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cheriot/riscv_cheriot_vector_permute_instructions.h"
+
+#include <algorithm>
+#include <cstdint>
+
+#include "absl/log/log.h"
+#include "absl/strings/str_cat.h"
+#include "absl/types/span.h"
+#include "cheriot/cheriot_register.h"
+#include "cheriot/cheriot_state.h"
+#include "cheriot/cheriot_vector_state.h"
+#include "mpact/sim/generic/data_buffer.h"
+#include "mpact/sim/generic/instruction.h"
+#include "riscv//riscv_register.h"
+
+namespace mpact {
+namespace sim {
+namespace cheriot {
+
+using ::mpact::sim::riscv::RV32VectorDestinationOperand;
+using ::mpact::sim::riscv::RV32VectorSourceOperand;
+
+// This helper function handles the vector gather operations.
+template <typename Vd, typename Vs2, typename Vs1>
+void VrgatherHelper(CheriotVectorState *rv_vector, Instruction *inst) {
+  if (rv_vector->vector_exception()) return;
+  int num_elements = rv_vector->vector_length();
+  int elements_per_vector =
+      rv_vector->vector_register_byte_length() / sizeof(Vd);
+  // Verify that the lmul is compatible with the index size.
+  int index_emul =
+      rv_vector->vector_length_multiplier() * sizeof(Vs1) / sizeof(Vd);
+  if (index_emul > 64) {
+    rv_vector->set_vector_exception();
+    return;
+  }
+  int max_regs = std::max(
+      1, (num_elements + elements_per_vector - 1) / elements_per_vector);
+  auto *dest_op =
+      static_cast<RV32VectorDestinationOperand *>(inst->Destination(0));
+  // Verify that there are enough registers in the destination operand.
+  if (dest_op->size() < max_regs) {
+    rv_vector->set_vector_exception();
+    LOG(ERROR) << absl::StrCat(
+        "Vector destination '", dest_op->AsString(), "' has fewer registers (",
+        dest_op->size(), ") than required by the operation (", max_regs, ")");
+    return;
+  }
+  // Get the vector mask.
+  auto *mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(2));
+  auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
+  // Get the vector start element index and compute where to start
+  // the operation.
+  int vector_index = rv_vector->vstart();
+  int start_reg = vector_index / elements_per_vector;
+  int item_index = vector_index % elements_per_vector;
+  // Determine if it's vector-vector or vector-scalar.
+  bool vector_scalar = inst->Source(1)->shape()[0] == 1;
+  auto src0_op = static_cast<RV32VectorSourceOperand *>(inst->Source(0));
+  int max_index = src0_op->size() * elements_per_vector;
+  // Iterate over the number of registers to write.
+  for (int reg = start_reg; (reg < max_regs) && (vector_index < num_elements);
+       reg++) {
+    // Allocate data buffer for the new register data.
+    auto *dest_db = dest_op->CopyDataBuffer(reg);
+    auto dest_span = dest_db->Get<Vd>();
+    // Write data into register subject to masking.
+    int element_count = std::min(elements_per_vector, num_elements);
+    for (int i = item_index;
+         (i < element_count) && (vector_index < num_elements); i++) {
+      // Get the mask value.
+      int mask_index = i >> 3;
+      int mask_offset = i & 0b111;
+      bool mask_value = ((mask_span[mask_index] >> mask_offset) & 0b1) != 0;
+      if (mask_value) {
+        // Compute result.
+        CheriotRegister::ValueType vs1;
+        if (vector_scalar) {
+          vs1 = generic::GetInstructionSource<CheriotRegister::ValueType>(
+              inst, 1, 0);
+        } else {
+          vs1 = generic::GetInstructionSource<Vs1>(inst, 1, vector_index);
+        }
+        Vs2 vs2 = 0;
+        if (vs1 < max_index) {
+          vs2 = generic::GetInstructionSource<Vs2>(inst, 0, vs1);
+        }
+        dest_span[i] = vs2;
+      }
+      vector_index++;
+    }
+    // Submit the destination db.
+    dest_db->Submit();
+    item_index = 0;
+  }
+  rv_vector->clear_vstart();
+}
+
+// Vector register gather.
+void Vrgather(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return VrgatherHelper<uint8_t, uint8_t, uint8_t>(rv_vector, inst);
+    case 2:
+      return VrgatherHelper<uint16_t, uint16_t, uint16_t>(rv_vector, inst);
+    case 4:
+      return VrgatherHelper<uint32_t, uint32_t, uint32_t>(rv_vector, inst);
+    case 8:
+      return VrgatherHelper<uint64_t, uint64_t, uint64_t>(rv_vector, inst);
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Vector register gather with 16 bit indices.
+void Vrgatherei16(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return VrgatherHelper<uint8_t, uint8_t, uint16_t>(rv_vector, inst);
+    case 2:
+      return VrgatherHelper<uint16_t, uint16_t, uint16_t>(rv_vector, inst);
+    case 4:
+      return VrgatherHelper<uint32_t, uint32_t, uint16_t>(rv_vector, inst);
+    case 8:
+      return VrgatherHelper<uint64_t, uint64_t, uint16_t>(rv_vector, inst);
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// This helper function handles the vector slide up/down instructions.
+template <typename Vd>
+void VSlideHelper(CheriotVectorState *rv_vector, Instruction *inst,
+                  int offset) {
+  if (rv_vector->vector_exception()) return;
+  int num_elements = rv_vector->vector_length();
+  int elements_per_vector =
+      rv_vector->vector_register_byte_length() / sizeof(Vd);
+  int max_regs = std::max(
+      1, (num_elements + elements_per_vector - 1) / elements_per_vector);
+  auto *dest_op =
+      static_cast<RV32VectorDestinationOperand *>(inst->Destination(0));
+  // Verify that there are enough registers in the destination operand.
+  if (dest_op->size() < max_regs) {
+    rv_vector->set_vector_exception();
+    LOG(ERROR) << absl::StrCat(
+        "Vector destination '", dest_op->AsString(), "' has fewer registers (",
+        dest_op->size(), ") than required by the operation (", max_regs, ")");
+    return;
+  }
+  // Get the vector mask.
+  auto *mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(2));
+  auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
+  // Get the vector start element index and compute where to start
+  // the operation.
+  int vector_index = rv_vector->vstart();
+  int start_reg = vector_index / elements_per_vector;
+  int item_index = vector_index % elements_per_vector;
+  // Iterate over the number of registers to write.
+  for (int reg = start_reg; (reg < max_regs) && (vector_index < num_elements);
+       reg++) {
+    // Allocate data buffer for the new register data.
+    auto *dest_db = dest_op->CopyDataBuffer(reg);
+    auto dest_span = dest_db->Get<Vd>();
+    // Write data into register subject to masking.
+    int element_count = std::min(elements_per_vector, num_elements);
+    for (int i = item_index;
+         (i < element_count) && (vector_index < num_elements); i++) {
+      // Get the mask value.
+      int mask_index = i >> 3;
+      int mask_offset = i & 0b111;
+      bool mask_value = ((mask_span[mask_index] >> mask_offset) & 0b1) != 0;
+      int src_index = vector_index - offset;
+      if ((src_index >= 0) && (mask_value)) {
+        // Compute result.
+        Vd src_value = 0;
+        if (src_index < rv_vector->max_vector_length()) {
+          src_value = generic::GetInstructionSource<Vd>(inst, 0, src_index);
+        }
+        dest_span[i] = src_value;
+      }
+      vector_index++;
+    }
+    // Submit the destination db.
+    dest_db->Submit();
+    item_index = 0;
+  }
+  rv_vector->clear_vstart();
+}
+
+void Vslideup(Instruction *inst) {
+  using ValueType = CheriotRegister::ValueType;
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  auto offset = generic::GetInstructionSource<ValueType>(inst, 1, 0);
+  int int_offset = static_cast<int>(offset);
+  if (offset > rv_vector->max_vector_length()) return;
+  // Slide up amount is positive.
+  switch (sew) {
+    case 1:
+      return VSlideHelper<uint8_t>(rv_vector, inst, int_offset);
+    case 2:
+      return VSlideHelper<uint16_t>(rv_vector, inst, int_offset);
+    case 4:
+      return VSlideHelper<uint32_t>(rv_vector, inst, int_offset);
+    case 8:
+      return VSlideHelper<uint64_t>(rv_vector, inst, int_offset);
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+void Vslidedown(Instruction *inst) {
+  using ValueType = CheriotRegister::ValueType;
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  auto offset = generic::GetInstructionSource<ValueType>(inst, 1, 0);
+  // Slide down amount is negative.
+  int int_offset = -static_cast<int>(offset);
+  switch (sew) {
+    case 1:
+      return VSlideHelper<uint8_t>(rv_vector, inst, int_offset);
+    case 2:
+      return VSlideHelper<uint16_t>(rv_vector, inst, int_offset);
+    case 4:
+      return VSlideHelper<uint32_t>(rv_vector, inst, int_offset);
+    case 8:
+      return VSlideHelper<uint64_t>(rv_vector, inst, int_offset);
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// This helper function handles the vector slide up/down 1 instructions.
+template <typename Vd>
+void VSlide1Helper(CheriotVectorState *rv_vector, Instruction *inst,
+                   int offset) {
+  if (rv_vector->vector_exception()) return;
+  int num_elements = rv_vector->vector_length();
+  int elements_per_vector =
+      rv_vector->vector_register_byte_length() / sizeof(Vd);
+  int max_regs = std::max(
+      1, (num_elements + elements_per_vector - 1) / elements_per_vector);
+  auto *dest_op =
+      static_cast<RV32VectorDestinationOperand *>(inst->Destination(0));
+  // Verify that there are enough registers in the destination operand.
+  if (dest_op->size() < max_regs) {
+    rv_vector->set_vector_exception();
+    LOG(ERROR) << absl::StrCat(
+        "Vector destination '", dest_op->AsString(), "' has fewer registers (",
+        dest_op->size(), ") than required by the operation (", max_regs, ")");
+    return;
+  }
+  // Get the vector mask.
+  auto *mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(2));
+  auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
+  // Get the vector start element index and compute where to start
+  // the operation.
+  int vector_index = rv_vector->vstart();
+  int start_reg = vector_index / elements_per_vector;
+  int item_index = vector_index % elements_per_vector;
+  auto slide_value = generic::GetInstructionSource<Vd>(inst, 1, 0);
+  // Iterate over the number of registers to write.
+  for (int reg = start_reg; (reg < max_regs) && (vector_index < num_elements);
+       reg++) {
+    // Allocate data buffer for the new register data.
+    auto *dest_db = dest_op->CopyDataBuffer(reg);
+    auto dest_span = dest_db->Get<Vd>();
+    // Write data into register subject to masking.
+    int element_count = std::min(elements_per_vector, num_elements);
+    for (int i = item_index;
+         (i < element_count) && (vector_index < num_elements); i++) {
+      // Get the mask value.
+      int mask_index = i >> 3;
+      int mask_offset = i & 0b111;
+      bool mask_value = ((mask_span[mask_index] >> mask_offset) & 0b1) != 0;
+      if (mask_value) {
+        // Compute result. Elements whose source index falls outside the
+        // vector body receive the slide-in scalar value.
+        Vd src_value = slide_value;
+        int src_index = vector_index - offset;
+        if ((src_index >= 0) && (src_index < num_elements)) {
+          src_value = generic::GetInstructionSource<Vd>(inst, 0, src_index);
+        }
+        dest_span[i] = src_value;
+      }
+      vector_index++;
+    }
+    // Submit the destination db.
+    dest_db->Submit();
+    item_index = 0;
+  }
+  rv_vector->clear_vstart();
+}
+
+void Vslide1up(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return VSlide1Helper<uint8_t>(rv_vector, inst, 1);
+    case 2:
+      return VSlide1Helper<uint16_t>(rv_vector, inst, 1);
+    case 4:
+      return VSlide1Helper<uint32_t>(rv_vector, inst, 1);
+    case 8:
+      return VSlide1Helper<uint64_t>(rv_vector, inst, 1);
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+void Vslide1down(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return VSlide1Helper<uint8_t>(rv_vector, inst, -1);
+    case 2:
+      return VSlide1Helper<uint16_t>(rv_vector, inst, -1);
+    case 4:
+      return VSlide1Helper<uint32_t>(rv_vector, inst, -1);
+    case 8:
+      return VSlide1Helper<uint64_t>(rv_vector, inst, -1);
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+void Vfslide1up(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 4:
+      return VSlide1Helper<uint32_t>(rv_vector, inst, 1);
+    case 8:
+      return VSlide1Helper<uint64_t>(rv_vector, inst, 1);
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+void Vfslide1down(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 4:
+      return VSlide1Helper<uint32_t>(rv_vector, inst, -1);
+    case 8:
+      return VSlide1Helper<uint64_t>(rv_vector, inst, -1);
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+template <typename Vd>
+void VCompressHelper(CheriotVectorState *rv_vector, Instruction *inst) {
+  if (rv_vector->vector_exception()) return;
+  int num_elements = rv_vector->vector_length();
+  int elements_per_vector =
+      rv_vector->vector_register_byte_length() / sizeof(Vd);
+  int max_regs = std::max(
+      1, (num_elements + elements_per_vector - 1) / elements_per_vector);
+  auto *dest_op =
+      static_cast<RV32VectorDestinationOperand *>(inst->Destination(0));
+  // Verify that there are enough registers in the destination operand.
+  if (dest_op->size() < max_regs) {
+    rv_vector->set_vector_exception();
+    LOG(ERROR) << absl::StrCat(
+        "Vector destination '", dest_op->AsString(), "' has fewer registers (",
+        dest_op->size(), ") than required by the operation (", max_regs, ")");
+    return;
+  }
+  // Get the vector mask.
+  auto *mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(1));
+  auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
+  // Get the vector start element index and compute where to start
+  // the operation.
+  int vector_index = rv_vector->vstart();
+  int dest_index = 0;
+  int prev_reg = -1;
+  absl::Span<Vd> dest_span;
+  generic::DataBuffer *dest_db = nullptr;
+  // Iterate over the input elements.
+  for (int i = vector_index; i < num_elements; i++) {
+    // Get mask value.
+    int mask_index = i >> 3;
+    int mask_offset = i & 0b111;
+    bool mask_value = ((mask_span[mask_index] >> mask_offset) & 0b1) != 0;
+    if (mask_value) {
+      // Compute destination register.
+      int reg = dest_index / elements_per_vector;
+      if (prev_reg != reg) {
+        // Submit previous data buffer if needed.
+        if (dest_db != nullptr) dest_db->Submit();
+        // Allocate a data buffer.
+        dest_db = dest_op->CopyDataBuffer(reg);
+        dest_span = dest_db->Get<Vd>();
+        prev_reg = reg;
+      }
+      // Copy the source value to the dest_index.
+      Vd src_value = generic::GetInstructionSource<Vd>(inst, 0, i);
+      dest_span[dest_index % elements_per_vector] = src_value;
+      ++dest_index;
+    }
+  }
+  if (dest_db != nullptr) dest_db->Submit();
+  rv_vector->clear_vstart();
+}
+
+void Vcompress(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return VCompressHelper<uint8_t>(rv_vector, inst);
+    case 2:
+      return VCompressHelper<uint16_t>(rv_vector, inst);
+    case 4:
+      return VCompressHelper<uint32_t>(rv_vector, inst);
+    case 8:
+      return VCompressHelper<uint64_t>(rv_vector, inst);
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+}  // namespace cheriot
+}  // namespace sim
+}  // namespace mpact
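Note that every masked loop above decodes the mask the same way: bit i of the mask register lives in byte i >> 3 at bit offset i & 0b111. A self-contained restatement of that convention (hypothetical helper, assuming the byte-packed mask layout used above):

#include <cstdint>

#include "absl/types/span.h"

// Returns true if element i is active under a byte-packed vector mask.
inline bool MaskBitSet(absl::Span<const uint8_t> mask_span, int i) {
  const int mask_index = i >> 3;      // Byte holding bit i.
  const int mask_offset = i & 0b111;  // Bit position within that byte.
  return ((mask_span[mask_index] >> mask_offset) & 0b1) != 0;
}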
diff --git a/cheriot/riscv_cheriot_vector_permute_instructions.h b/cheriot/riscv_cheriot_vector_permute_instructions.h new file mode 100644 index 0000000..9ccc612 --- /dev/null +++ b/cheriot/riscv_cheriot_vector_permute_instructions.h
@@ -0,0 +1,89 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MPACT_CHERIOT_RISCV_CHERIOT_VECTOR_PERMUTE_INSTRUCTIONS_H_ +#define MPACT_CHERIOT_RISCV_CHERIOT_VECTOR_PERMUTE_INSTRUCTIONS_H_ + +#include "mpact/sim/generic/instruction.h" + +namespace mpact { +namespace sim { +namespace cheriot { + +using Instruction = ::mpact::sim::generic::Instruction; + +// Vector register gather instruction. This instruction takes three source +// operands and one destination operand. Source 0 is the vector from which +// elements are gathered, source 1 is the index vector, and source 2 is the +// vector mask register. The destination operand is the target vector register. +void Vrgather(Instruction *inst); +// Vector register gather instruction with 16 bit indices. This instruction +// takes three source operands and one destination operand. Source 0 is the +// vector from which elements are gathered, source 1 is the index vector, and +// source 2 is the vector mask register. The destination operand is the target +// vector register. +void Vrgatherei16(Instruction *inst); +// Vector slide up instruction. This instruction takes three source operands +// and one destination operand. Source 0 is the vector source register that +// contains the values that are 'slid' up. Source 1 is a scalar register or +// immediate that specifies the number of 'entries' by which source 0 values +// are slid up. Source 2 is the vector mask register. The destination operand +// is the target vector register. +void Vslideup(Instruction *inst); +// Vector slide down instruction. This instruction takes three source operands +// and one destination operand. Source 0 is the vector source register that +// contains the values that are 'slid' down. Source 1 is a scalar register or +// immediate that specifies the number of 'entries' by which source 0 values +// are slid down. Source 2 is the vector mask register. The destination operand +// is the target vector register. +void Vslidedown(Instruction *inst); +// Vector slide up instruction. This instruction takes three source operands +// and one destination operand. Source 0 is the vector source register that +// contains the values that are 'slid' up by 1. Source 1 is a scalar register or +// immediate that specifies the value written into the 'empty' slot. Source 2 is +// the vector mask register. The destination operand is the target vector +// register. +void Vslide1up(Instruction *inst); +// Vector slide down instruction. This instruction takes three source operands +// and one destination operand. Source 0 is the vector source register that +// contains the values that are 'slid' down. Source 1 is a scalar register or +// immediate that specifies the value written into the 'empty' slot. Source 2 is +// the vector mask register. The destination operand is the target vector +// register. +void Vslide1down(Instruction *inst); +// Vector fp slide up instruction. 
This instruction takes three source
+// operands and one destination operand. Source 0 is the vector source
+// register that contains the values that are 'slid' up by 1. Source 1 is a
+// floating point register that specifies the value written into the 'empty'
+// slot. Source 2 is the vector mask register. The destination operand is the
+// target vector register.
+void Vfslide1up(Instruction *inst);
+// Vector fp slide down instruction. This instruction takes three source
+// operands and one destination operand. Source 0 is the vector source register
+// that contains the values that are 'slid' down. Source 1 is a floating point
+// register that specifies the value written into the 'empty' slot. Source 2 is
+// the vector mask register. The destination operand is the target vector
+// register.
+void Vfslide1down(Instruction *inst);
+// Vector compress instruction. This instruction takes two source operands and
+// one destination operand. Source 0 is the source value vector register.
+// Source 1 is a mask register, which specifies which elements of source 0
+// should be selected and packed into the destination register.
+void Vcompress(Instruction *inst);
+
+}  // namespace cheriot
+}  // namespace sim
+}  // namespace mpact
+
+#endif  // MPACT_CHERIOT_RISCV_CHERIOT_VECTOR_PERMUTE_INSTRUCTIONS_H_
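Since vcompress is the least obvious of these, a scalar model of its element-level behavior may help (a sketch only, ignoring vstart, register grouping, and tail policy; the name is illustrative):

#include <cstdint>
#include <vector>

// Scalar model of vcompress: elements of src whose mask bit is set are
// packed densely, in order, at the front of the result.
std::vector<uint32_t> CompressModel(const std::vector<uint32_t> &src,
                                    const std::vector<bool> &mask) {
  std::vector<uint32_t> dst;
  dst.reserve(src.size());
  for (size_t i = 0; i < src.size(); ++i) {
    if (mask[i]) dst.push_back(src[i]);
  }
  return dst;
}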
diff --git a/cheriot/riscv_cheriot_vector_reduction_instructions.cc b/cheriot/riscv_cheriot_vector_reduction_instructions.cc new file mode 100644 index 0000000..8be7f68 --- /dev/null +++ b/cheriot/riscv_cheriot_vector_reduction_instructions.cc
@@ -0,0 +1,355 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cheriot/riscv_cheriot_vector_reduction_instructions.h"
+
+#include <algorithm>
+#include <cstdint>
+
+#include "absl/log/log.h"
+#include "cheriot/cheriot_state.h"
+#include "cheriot/riscv_cheriot_vector_instruction_helpers.h"
+#include "mpact/sim/generic/instruction.h"
+
+namespace mpact {
+namespace sim {
+namespace cheriot {
+
+using ::mpact::sim::generic::Instruction;
+
+// Sum reduction.
+void Vredsum(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVBinaryReductionVectorOp<uint8_t, uint8_t, uint8_t>(
+          rv_vector, inst,
+          [](uint8_t acc, uint8_t vs2) -> uint8_t { return acc + vs2; });
+    case 2:
+      return RiscVBinaryReductionVectorOp<uint16_t, uint16_t, uint16_t>(
+          rv_vector, inst,
+          [](uint16_t acc, uint16_t vs2) -> uint16_t { return acc + vs2; });
+    case 4:
+      return RiscVBinaryReductionVectorOp<uint32_t, uint32_t, uint32_t>(
+          rv_vector, inst,
+          [](uint32_t acc, uint32_t vs2) -> uint32_t { return acc + vs2; });
+    case 8:
+      return RiscVBinaryReductionVectorOp<uint64_t, uint64_t, uint64_t>(
+          rv_vector, inst,
+          [](uint64_t acc, uint64_t vs2) -> uint64_t { return acc + vs2; });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// And reduction.
+void Vredand(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVBinaryReductionVectorOp<uint8_t, uint8_t, uint8_t>(
+          rv_vector, inst,
+          [](uint8_t acc, uint8_t vs2) -> uint8_t { return acc & vs2; });
+    case 2:
+      return RiscVBinaryReductionVectorOp<uint16_t, uint16_t, uint16_t>(
+          rv_vector, inst,
+          [](uint16_t acc, uint16_t vs2) -> uint16_t { return acc & vs2; });
+    case 4:
+      return RiscVBinaryReductionVectorOp<uint32_t, uint32_t, uint32_t>(
+          rv_vector, inst,
+          [](uint32_t acc, uint32_t vs2) -> uint32_t { return acc & vs2; });
+    case 8:
+      return RiscVBinaryReductionVectorOp<uint64_t, uint64_t, uint64_t>(
+          rv_vector, inst,
+          [](uint64_t acc, uint64_t vs2) -> uint64_t { return acc & vs2; });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Or reduction.
+void Vredor(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVBinaryReductionVectorOp<uint8_t, uint8_t, uint8_t>(
+          rv_vector, inst,
+          [](uint8_t acc, uint8_t vs2) -> uint8_t { return acc | vs2; });
+    case 2:
+      return RiscVBinaryReductionVectorOp<uint16_t, uint16_t, uint16_t>(
+          rv_vector, inst,
+          [](uint16_t acc, uint16_t vs2) -> uint16_t { return acc | vs2; });
+    case 4:
+      return RiscVBinaryReductionVectorOp<uint32_t, uint32_t, uint32_t>(
+          rv_vector, inst,
+          [](uint32_t acc, uint32_t vs2) -> uint32_t { return acc | vs2; });
+    case 8:
+      return RiscVBinaryReductionVectorOp<uint64_t, uint64_t, uint64_t>(
+          rv_vector, inst,
+          [](uint64_t acc, uint64_t vs2) -> uint64_t { return acc | vs2; });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Xor reduction.
+void Vredxor(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVBinaryReductionVectorOp<uint8_t, uint8_t, uint8_t>(
+          rv_vector, inst,
+          [](uint8_t acc, uint8_t vs2) -> uint8_t { return acc ^ vs2; });
+    case 2:
+      return RiscVBinaryReductionVectorOp<uint16_t, uint16_t, uint16_t>(
+          rv_vector, inst,
+          [](uint16_t acc, uint16_t vs2) -> uint16_t { return acc ^ vs2; });
+    case 4:
+      return RiscVBinaryReductionVectorOp<uint32_t, uint32_t, uint32_t>(
+          rv_vector, inst,
+          [](uint32_t acc, uint32_t vs2) -> uint32_t { return acc ^ vs2; });
+    case 8:
+      return RiscVBinaryReductionVectorOp<uint64_t, uint64_t, uint64_t>(
+          rv_vector, inst,
+          [](uint64_t acc, uint64_t vs2) -> uint64_t { return acc ^ vs2; });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Unsigned min reduction.
+void Vredminu(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVBinaryReductionVectorOp<uint8_t, uint8_t, uint8_t>(
+          rv_vector, inst, [](uint8_t acc, uint8_t vs2) -> uint8_t {
+            return std::min(acc, vs2);
+          });
+    case 2:
+      return RiscVBinaryReductionVectorOp<uint16_t, uint16_t, uint16_t>(
+          rv_vector, inst, [](uint16_t acc, uint16_t vs2) -> uint16_t {
+            return std::min(acc, vs2);
+          });
+    case 4:
+      return RiscVBinaryReductionVectorOp<uint32_t, uint32_t, uint32_t>(
+          rv_vector, inst, [](uint32_t acc, uint32_t vs2) -> uint32_t {
+            return std::min(acc, vs2);
+          });
+    case 8:
+      return RiscVBinaryReductionVectorOp<uint64_t, uint64_t, uint64_t>(
+          rv_vector, inst, [](uint64_t acc, uint64_t vs2) -> uint64_t {
+            return std::min(acc, vs2);
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Signed min reduction.
+void Vredmin(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVBinaryReductionVectorOp<int8_t, int8_t, int8_t>(
+          rv_vector, inst,
+          [](int8_t acc, int8_t vs2) -> int8_t { return std::min(acc, vs2); });
+    case 2:
+      return RiscVBinaryReductionVectorOp<int16_t, int16_t, int16_t>(
+          rv_vector, inst, [](int16_t acc, int16_t vs2) -> int16_t {
+            return std::min(acc, vs2);
+          });
+    case 4:
+      return RiscVBinaryReductionVectorOp<int32_t, int32_t, int32_t>(
+          rv_vector, inst, [](int32_t acc, int32_t vs2) -> int32_t {
+            return std::min(acc, vs2);
+          });
+    case 8:
+      return RiscVBinaryReductionVectorOp<int64_t, int64_t, int64_t>(
+          rv_vector, inst, [](int64_t acc, int64_t vs2) -> int64_t {
+            return std::min(acc, vs2);
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Unsigned max reduction.
+void Vredmaxu(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVBinaryReductionVectorOp<uint8_t, uint8_t, uint8_t>(
+          rv_vector, inst, [](uint8_t acc, uint8_t vs2) -> uint8_t {
+            return std::max(acc, vs2);
+          });
+    case 2:
+      return RiscVBinaryReductionVectorOp<uint16_t, uint16_t, uint16_t>(
+          rv_vector, inst, [](uint16_t acc, uint16_t vs2) -> uint16_t {
+            return std::max(acc, vs2);
+          });
+    case 4:
+      return RiscVBinaryReductionVectorOp<uint32_t, uint32_t, uint32_t>(
+          rv_vector, inst, [](uint32_t acc, uint32_t vs2) -> uint32_t {
+            return std::max(acc, vs2);
+          });
+    case 8:
+      return RiscVBinaryReductionVectorOp<uint64_t, uint64_t, uint64_t>(
+          rv_vector, inst, [](uint64_t acc, uint64_t vs2) -> uint64_t {
+            return std::max(acc, vs2);
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Signed max reduction.
+void Vredmax(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVBinaryReductionVectorOp<int8_t, int8_t, int8_t>(
+          rv_vector, inst,
+          [](int8_t acc, int8_t vs2) -> int8_t { return std::max(acc, vs2); });
+    case 2:
+      return RiscVBinaryReductionVectorOp<int16_t, int16_t, int16_t>(
+          rv_vector, inst, [](int16_t acc, int16_t vs2) -> int16_t {
+            return std::max(acc, vs2);
+          });
+    case 4:
+      return RiscVBinaryReductionVectorOp<int32_t, int32_t, int32_t>(
+          rv_vector, inst, [](int32_t acc, int32_t vs2) -> int32_t {
+            return std::max(acc, vs2);
+          });
+    case 8:
+      return RiscVBinaryReductionVectorOp<int64_t, int64_t, int64_t>(
+          rv_vector, inst, [](int64_t acc, int64_t vs2) -> int64_t {
+            return std::max(acc, vs2);
+          });
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Unsigned widening (SEW->SEW * 2) reduction.
+void Vwredsumu(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVBinaryReductionVectorOp<uint16_t, uint8_t, uint8_t>(
+          rv_vector, inst, [](uint16_t acc, uint8_t vs2) -> uint16_t {
+            return acc + static_cast<uint16_t>(vs2);
+          });
+    case 2:
+      return RiscVBinaryReductionVectorOp<uint32_t, uint16_t, uint16_t>(
+          rv_vector, inst, [](uint32_t acc, uint16_t vs2) -> uint32_t {
+            return acc + static_cast<uint32_t>(vs2);
+          });
+    case 4:
+      return RiscVBinaryReductionVectorOp<uint64_t, uint32_t, uint32_t>(
+          rv_vector, inst, [](uint64_t acc, uint32_t vs2) -> uint64_t {
+            return acc + static_cast<uint64_t>(vs2);
+          });
+    case 8:  // Widening from SEW == 8 would need a 128 bit accumulator.
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+// Signed widening (SEW->SEW * 2) reduction.
+void Vwredsum(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 1:
+      return RiscVBinaryReductionVectorOp<int16_t, int8_t, int8_t>(
+          rv_vector, inst, [](int16_t acc, int8_t vs2) -> int16_t {
+            return acc + static_cast<int16_t>(vs2);
+          });
+    case 2:
+      return RiscVBinaryReductionVectorOp<int32_t, int16_t, int16_t>(
+          rv_vector, inst, [](int32_t acc, int16_t vs2) -> int32_t {
+            return acc + static_cast<int32_t>(vs2);
+          });
+    case 4:
+      return RiscVBinaryReductionVectorOp<int64_t, int32_t, int32_t>(
+          rv_vector, inst, [](int64_t acc, int32_t vs2) -> int64_t {
+            return acc + static_cast<int64_t>(vs2);
+          });
+    case 8:  // Widening from SEW == 8 would need a 128 bit accumulator.
+    default:
+      rv_vector->set_vector_exception();
+      LOG(ERROR) << "Illegal SEW value";
+      return;
+  }
+}
+
+}  // namespace cheriot
+}  // namespace sim
+}  // namespace mpact
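The widening reductions accumulate in 2 * SEW precision, which is why SEW == 8 (bytes) falls through to the illegal default: it would require a 128-bit accumulator. A minimal illustration of the byte-wide case (hypothetical helper, not part of this change):

#include <cstdint>

// Sums 8-bit elements into a 16-bit accumulator, mirroring what vwredsumu
// does at an element width of one byte.
inline uint16_t WideningSumModel(const uint8_t *vs2, int n, uint16_t acc) {
  for (int i = 0; i < n; ++i) acc += static_cast<uint16_t>(vs2[i]);
  return acc;
}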
diff --git a/cheriot/riscv_cheriot_vector_reduction_instructions.h b/cheriot/riscv_cheriot_vector_reduction_instructions.h new file mode 100644 index 0000000..58c93ae --- /dev/null +++ b/cheriot/riscv_cheriot_vector_reduction_instructions.h
@@ -0,0 +1,58 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MPACT_CHERIOT_RISCV_CHERIOT_VECTOR_REDUCTION_INSTRUCTIONS_H_
+#define MPACT_CHERIOT_RISCV_CHERIOT_VECTOR_REDUCTION_INSTRUCTIONS_H_
+
+#include "mpact/sim/generic/instruction.h"
+
+namespace mpact {
+namespace sim {
+namespace cheriot {
+
+using Instruction = ::mpact::sim::generic::Instruction;
+
+// Each of these instruction semantic functions takes 3 sources. Source 0 is
+// vector register vs2, source 1 is vector register vs1, and source 2 is a
+// vector mask operand. There is a single vector destination operand.
+// Each reduction applies the reduction operation to element 0 of source
+// operand 1 (vs1) and all unmasked elements of source operand 0 (vs2). The
+// result is written to element 0 of destination operand vd.
+
+// Vector sum reduction.
+void Vredsum(Instruction *inst);
+// Vector and reduction.
+void Vredand(Instruction *inst);
+// Vector or reduction.
+void Vredor(Instruction *inst);
+// Vector xor reduction.
+void Vredxor(Instruction *inst);
+// Vector unsigned min reduction.
+void Vredminu(Instruction *inst);
+// Vector signed min reduction.
+void Vredmin(Instruction *inst);
+// Vector unsigned max reduction.
+void Vredmaxu(Instruction *inst);
+// Vector signed max reduction.
+void Vredmax(Instruction *inst);
+// Vector unsigned widening sum reduction. The result is 2 * SEW wide.
+void Vwredsumu(Instruction *inst);
+// Vector signed widening sum reduction. The result is 2 * SEW wide.
+void Vwredsum(Instruction *inst);
+
+}  // namespace cheriot
+}  // namespace sim
+}  // namespace mpact
+
+#endif  // MPACT_CHERIOT_RISCV_CHERIOT_VECTOR_REDUCTION_INSTRUCTIONS_H_
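To make that contract concrete, a scalar sketch of the fold these functions perform (assuming, per the comment above, that the accumulator is seeded from element 0 of vs1 and combined with every unmasked element of vs2; the helper name is illustrative):

#include <cstdint>
#include <functional>
#include <vector>

// Scalar model of a masked reduction: acc starts at vs1[0], folds in each
// unmasked vs2 element, and would be written back to vd[0].
inline uint32_t ReduceModel(
    uint32_t vs1_0, const std::vector<uint32_t> &vs2,
    const std::vector<bool> &mask,
    const std::function<uint32_t(uint32_t, uint32_t)> &op) {
  uint32_t acc = vs1_0;
  for (size_t i = 0; i < vs2.size(); ++i) {
    if (mask[i]) acc = op(acc, vs2[i]);
  }
  return acc;
}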
diff --git a/cheriot/riscv_cheriot_vector_unary_instructions.cc b/cheriot/riscv_cheriot_vector_unary_instructions.cc new file mode 100644 index 0000000..becf09a --- /dev/null +++ b/cheriot/riscv_cheriot_vector_unary_instructions.cc
@@ -0,0 +1,461 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cheriot/riscv_cheriot_vector_unary_instructions.h"
+
+#include <cstdint>
+#include <cstring>
+#include <functional>
+
+#include "absl/log/log.h"
+#include "absl/strings/str_cat.h"
+#include "cheriot/cheriot_register.h"
+#include "cheriot/cheriot_state.h"
+#include "cheriot/cheriot_vector_state.h"
+#include "cheriot/riscv_cheriot_instruction_helpers.h"
+#include "cheriot/riscv_cheriot_vector_instruction_helpers.h"
+#include "mpact/sim/generic/instruction.h"
+#include "mpact/sim/generic/type_helpers.h"
+
+namespace mpact {
+namespace sim {
+namespace cheriot {
+
+using SignedXregType =
+    ::mpact::sim::generic::SameSignedType<CheriotRegister::ValueType,
+                                          int64_t>::type;
+
+// Move element 0 of a vector register to a scalar (x) register.
+void VmvToScalar(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  if (rv_vector->vstart()) return;
+  if (rv_vector->vector_length() == 0) return;
+  int sew = rv_vector->selected_element_width();
+  SignedXregType value;
+  switch (sew) {
+    case 1:
+      value = static_cast<SignedXregType>(
+          generic::GetInstructionSource<int8_t>(inst, 0));
+      break;
+    case 2:
+      value = static_cast<SignedXregType>(
+          generic::GetInstructionSource<int16_t>(inst, 0));
+      break;
+    case 4:
+      value = static_cast<SignedXregType>(
+          generic::GetInstructionSource<int32_t>(inst, 0));
+      break;
+    case 8:
+      value = static_cast<SignedXregType>(
+          generic::GetInstructionSource<int64_t>(inst, 0));
+      break;
+    default:
+      LOG(ERROR) << absl::StrCat("Illegal SEW value (", sew, ") for Vmvxs");
+      rv_vector->set_vector_exception();
+      return;
+  }
+  WriteCapIntResult<SignedXregType>(inst, 0, value);
+}
+
+// Move a scalar (x) register value to element 0 of a vector register.
+void VmvFromScalar(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  if (rv_vector->vstart()) return;
+  if (rv_vector->vector_length() == 0) return;
+  int sew = rv_vector->selected_element_width();
+  auto *dest_db = inst->Destination(0)->AllocateDataBuffer();
+  std::memset(dest_db->raw_ptr(), 0, dest_db->size<uint8_t>());
+  switch (sew) {
+    case 1:
+      dest_db->Set<int8_t>(0, generic::GetInstructionSource<int8_t>(inst, 0));
+      break;
+    case 2:
+      dest_db->Set<int16_t>(0,
+                            generic::GetInstructionSource<int16_t>(inst, 0));
+      break;
+    case 4:
+      dest_db->Set<int32_t>(0,
+                            generic::GetInstructionSource<int32_t>(inst, 0));
+      break;
+    case 8:
+      dest_db->Set<int64_t>(0,
+                            generic::GetInstructionSource<int64_t>(inst, 0));
+      break;
+    default:
+      LOG(ERROR) << absl::StrCat("Illegal SEW value (", sew, ") for Vmvsx");
+      rv_vector->set_vector_exception();
+      return;
+  }
+  dest_db->Submit();
+}
+
+// Population count of vector mask register. The value is written to a scalar
+// register.
+void Vcpop(Instruction *inst) {
+  auto *rv_state = static_cast<CheriotState *>(inst->state());
+  auto *rv_vector = rv_state->rv_vector();
+  if (rv_vector->vstart()) {
+    rv_vector->set_vector_exception();
+    return;
+  }
+  int vlen = rv_vector->vector_length();
+  auto src_op = static_cast<RV32VectorSourceOperand *>(inst->Source(0));
+  auto src_span = src_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
+  auto mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(1));
+  auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
+  uint64_t count = 0;
+  for (int i = 0; i < vlen; i++) {
+    int index = i >> 3;
+    int offset = i & 0b111;
+    int mask_value = (mask_span[index] >> offset);
+    int src_value = (src_span[index] >> offset);
+    count += mask_value & src_value & 0b1;
+  }
+  WriteCapIntResult<uint32_t>(inst, 0, count);
+}
+
+// Find the first set bit of the vector mask register. The bit index (or -1 if
+// no bit is set) is written to a scalar register.
+void Vfirst(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  if (rv_vector->vstart()) {
+    rv_vector->set_vector_exception();
+    return;
+  }
+  auto src_op = static_cast<RV32VectorSourceOperand *>(inst->Source(0));
+  auto src_span = src_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
+  auto mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(1));
+  auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
+  // Initialize the element index to -1.
+  uint64_t element_index = -1LL;
+  int vlen = rv_vector->vector_length();
+  for (int i = 0; i < vlen; i++) {
+    int index = i >> 3;
+    int offset = i & 0b111;
+    int mask_value = (mask_span[index] >> offset);
+    int src_value = (src_span[index] >> offset);
+    if (mask_value & src_value & 0b1) {
+      element_index = i;
+      break;
+    }
+  }
+  WriteCapIntResult<uint32_t>(inst, 0, element_index);
+}
+
+// Vector integer sign and zero extension instructions.
+void Vzext2(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 2:
+      return RiscVUnaryVectorOp<uint16_t, uint8_t>(
+          rv_vector, inst,
+          [](uint8_t vs2) -> uint16_t { return static_cast<uint16_t>(vs2); });
+    case 4:
+      return RiscVUnaryVectorOp<uint32_t, uint16_t>(
+          rv_vector, inst,
+          [](uint16_t vs2) -> uint32_t { return static_cast<uint32_t>(vs2); });
+    case 8:
+      return RiscVUnaryVectorOp<uint64_t, uint32_t>(
+          rv_vector, inst,
+          [](uint32_t vs2) -> uint64_t { return static_cast<uint64_t>(vs2); });
+    default:
+      LOG(ERROR) << absl::StrCat("Illegal SEW value (", sew, ") for Vzext2");
+      rv_vector->set_vector_exception();
+      return;
+  }
+}
+
+void Vsext2(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 2:
+      return RiscVUnaryVectorOp<int16_t, int8_t>(
+          rv_vector, inst,
+          [](int8_t vs2) -> int16_t { return static_cast<int16_t>(vs2); });
+    case 4:
+      return RiscVUnaryVectorOp<int32_t, int16_t>(
+          rv_vector, inst,
+          [](int16_t vs2) -> int32_t { return static_cast<int32_t>(vs2); });
+    case 8:
+      return RiscVUnaryVectorOp<int64_t, int32_t>(
+          rv_vector, inst,
+          [](int32_t vs2) -> int64_t { return static_cast<int64_t>(vs2); });
+    default:
+      LOG(ERROR) << absl::StrCat("Illegal SEW value (", sew, ") for Vsext2");
+      rv_vector->set_vector_exception();
+      return;
+  }
+}
+
+void Vzext4(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 4:
+      return RiscVUnaryVectorOp<uint32_t, uint8_t>(
+          rv_vector, inst,
+          [](uint8_t vs2) -> uint32_t { return static_cast<uint32_t>(vs2); });
+    case 8:
+      return RiscVUnaryVectorOp<uint64_t, uint16_t>(
+          rv_vector, inst,
+          [](uint16_t vs2) -> uint64_t { return static_cast<uint64_t>(vs2); });
+    default:
+      LOG(ERROR) << absl::StrCat("Illegal SEW value (", sew, ") for Vzext4");
+      rv_vector->set_vector_exception();
+      return;
+  }
+}
+
+void Vsext4(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 4:
+      return RiscVUnaryVectorOp<int32_t, int8_t>(
+          rv_vector, inst,
+          [](int8_t vs2) -> int32_t { return static_cast<int32_t>(vs2); });
+    case 8:
+      return RiscVUnaryVectorOp<int64_t, int16_t>(
+          rv_vector, inst,
+          [](int16_t vs2) -> int64_t { return static_cast<int64_t>(vs2); });
+    default:
+      LOG(ERROR) << absl::StrCat("Illegal SEW value (", sew, ") for Vsext4");
+      rv_vector->set_vector_exception();
+      return;
+  }
+}
+
+void Vzext8(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 8:
+      return RiscVUnaryVectorOp<uint64_t, uint8_t>(
+          rv_vector, inst,
+          [](uint8_t vs2) -> uint64_t { return static_cast<uint64_t>(vs2); });
+    default:
+      LOG(ERROR) << absl::StrCat("Illegal SEW value (", sew, ") for Vzext8");
+      rv_vector->set_vector_exception();
+      return;
+  }
+}
+
+void Vsext8(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  int sew = rv_vector->selected_element_width();
+  switch (sew) {
+    case 8:
+      return RiscVUnaryVectorOp<int64_t, int8_t>(
+          rv_vector, inst,
+          [](int8_t vs2) -> int64_t { return static_cast<int64_t>(vs2); });
+    default:
+      LOG(ERROR) << absl::StrCat("Illegal SEW value (", sew, ") for Vsext8");
+      rv_vector->set_vector_exception();
+      return;
+  }
+}
+
+// Vector mask set-before-first mask bit.
+void Vmsbf(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  if (rv_vector->vstart()) {
+    rv_vector->set_vector_exception();
+    return;
+  }
+  int vlen = rv_vector->vector_length();
+  auto src_op = static_cast<RV32VectorSourceOperand *>(inst->Source(0));
+  auto src_span = src_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
+  auto mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(1));
+  auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
+  auto dest_op =
+      static_cast<RV32VectorDestinationOperand *>(inst->Destination(0));
+  auto *dest_db = dest_op->CopyDataBuffer(0);
+  auto dest_span = dest_db->Get<uint8_t>();
+  bool before_first = true;
+  int last = 0;
+  // Set the bits before the first active 1.
+  for (int i = 0; i < vlen; i++) {
+    last = i;
+    int index = i >> 3;
+    int offset = i & 0b111;
+    int mask_value = (mask_span[index] >> offset) & 0b1;
+    int src_value = (src_span[index] >> offset) & 0b1;
+    if (mask_value) {
+      before_first = before_first && (src_value == 0);
+      if (!before_first) break;
+      dest_span[index] |= 1 << offset;
+    }
+  }
+  // Clear the remaining bits.
+  for (int i = last; !before_first && (i < vlen); i++) {
+    int index = i >> 3;
+    int offset = i & 0b111;
+    dest_span[index] &= ~(1 << offset);
+  }
+  dest_db->Submit();
+  rv_vector->clear_vstart();
+}
+
+// Vector mask set-including-first mask bit.
+void Vmsif(Instruction *inst) {
+  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
+  if (rv_vector->vstart()) {
+    rv_vector->set_vector_exception();
+    return;
+  }
+  int vlen = rv_vector->vector_length();
+  auto src_op = static_cast<RV32VectorSourceOperand *>(inst->Source(0));
+  auto src_span = src_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
+  auto mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(1));
+  auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
+  auto dest_op =
+      static_cast<RV32VectorDestinationOperand *>(inst->Destination(0));
+  auto *dest_db = dest_op->CopyDataBuffer(0);
+  auto dest_span = dest_db->Get<uint8_t>();
+  uint8_t value = 1;
+  for (int i = 0; i < vlen; i++) {
+    int index = i >> 3;
+    int offset = i & 0b111;
+    int mask_value = (mask_span[index] >> offset) & 0b1;
+    int src_value = (src_span[index] >> offset) & 0b1;
+    if (mask_value) {
+      if (value) {
+        dest_span[index] |= 1 << offset;
+      } else {
+        dest_span[index] &= ~(1 << offset);
+      }
+      if (src_value) {
+        value = 0;
+      }
+    }
+  }
+  dest_db->Submit();
+  rv_vector->clear_vstart();
+}
+
+// Vector mask set-only-first mask bit.
+void Vmsof(Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + if (rv_vector->vstart()) { + rv_vector->set_vector_exception(); + return; + } + int vlen = rv_vector->vector_length(); + auto src_op = static_cast<RV32VectorSourceOperand *>(inst->Source(0)); + auto src_span = src_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); + auto mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(1)); + auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); + auto dest_op = + static_cast<RV32VectorDestinationOperand *>(inst->Destination(0)); + auto *dest_db = dest_op->CopyDataBuffer(0); + auto dest_span = dest_db->Get<uint8_t>(); + bool first = true; + for (int i = 0; i < vlen; i++) { + int index = i >> 3; + int offset = i & 0b111; + int mask_value = (mask_span[index] >> offset) & 0b1; + int src_value = (src_span[index] >> offset) & 0b1; + if (mask_value) { + if (first & src_value) { + dest_span[index] |= (1 << offset); + first = false; + } else { + dest_span[index] &= ~(1 << offset); + } + } + } + dest_db->Submit(); + rv_vector->clear_vstart(); +} + +// Vector iota. This instruction reads a source vector mask register and +// writes to each element of the destination vector register group the sum +// of all bits of elements in the mask register whose index is less than the +// element. This is subject to masking. +void Viota(Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + int count = 0; + switch (sew) { + case 1: + return RiscVMaskNullaryVectorOp<uint8_t>( + rv_vector, inst, [&count](bool mask) -> uint8_t { + return mask ? static_cast<uint8_t>(count++) + : static_cast<uint8_t>(count); + }); + case 2: + return RiscVMaskNullaryVectorOp<uint16_t>( + rv_vector, inst, [&count](bool mask) -> uint16_t { + return mask ? static_cast<uint16_t>(count++) + : static_cast<uint16_t>(count); + }); + case 4: + return RiscVMaskNullaryVectorOp<uint32_t>( + rv_vector, inst, [&count](bool mask) -> uint32_t { + return mask ? static_cast<uint32_t>(count++) + : static_cast<uint32_t>(count); + }); + case 8: + return RiscVMaskNullaryVectorOp<uint64_t>( + rv_vector, inst, [&count](bool mask) -> uint64_t { + return mask ? static_cast<uint64_t>(count++) + : static_cast<uint64_t>(count); + }); + default: + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal SEW value"; + return; + } +} + +// Writes the index of each active (mask true) element to the destination +// vector elements. 
+void Vid(Instruction *inst) { + auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); + int sew = rv_vector->selected_element_width(); + int index = 0; + switch (sew) { + case 1: + return RiscVMaskNullaryVectorOp<uint8_t>( + rv_vector, inst, [&index](bool mask) -> uint8_t { + uint64_t ret = index++; + return static_cast<uint8_t>(ret); + }); + case 2: + return RiscVMaskNullaryVectorOp<uint16_t>( + rv_vector, inst, [&index](bool mask) -> uint16_t { + uint64_t ret = index++; + return static_cast<uint16_t>(ret); + }); + case 4: + return RiscVMaskNullaryVectorOp<uint32_t>( + rv_vector, inst, [&index](bool mask) -> uint32_t { + uint64_t ret = index++; + return static_cast<uint32_t>(ret); + }); + case 8: + return RiscVMaskNullaryVectorOp<uint64_t>( + rv_vector, inst, [&index](bool mask) -> uint64_t { + uint64_t ret = index++; + return static_cast<uint64_t>(ret); + }); + default: + rv_vector->set_vector_exception(); + LOG(ERROR) << "Illegal SEW value"; + return; + } +} + +} // namespace cheriot +} // namespace sim +} // namespace mpact
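The viota/vid pair above is easiest to see as a prefix sum over the mask. A scalar sketch (masking of the destination writes omitted; the name is illustrative):

#include <cstdint>
#include <vector>

// Scalar model of viota: element i receives the number of set mask bits at
// indices strictly less than i.
std::vector<uint32_t> IotaModel(const std::vector<bool> &mask) {
  std::vector<uint32_t> out(mask.size());
  uint32_t count = 0;
  for (size_t i = 0; i < mask.size(); ++i) {
    out[i] = count;
    if (mask[i]) ++count;
  }
  return out;
}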
diff --git a/cheriot/riscv_cheriot_vector_unary_instructions.h b/cheriot/riscv_cheriot_vector_unary_instructions.h new file mode 100644 index 0000000..4043a11 --- /dev/null +++ b/cheriot/riscv_cheriot_vector_unary_instructions.h
@@ -0,0 +1,109 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MPACT_CHERIOT_RISCV_CHERIOT_VECTOR_UNARY_INSTRUCTIONS_H_
+#define MPACT_CHERIOT_RISCV_CHERIOT_VECTOR_UNARY_INSTRUCTIONS_H_
+
+#include "mpact/sim/generic/instruction.h"
+
+// This file declares the vector instruction semantic functions for the integer
+// unary vector instructions.
+namespace mpact {
+namespace sim {
+namespace cheriot {
+
+using Instruction = ::mpact::sim::generic::Instruction;
+
+// VWXUNARY0
+// Moves the value of element 0 of a vector register to a scalar register.
+// This instruction takes 1 source and 1 destination. One is a vector register,
+// the other is a scalar (x) register.
+void VmvToScalar(Instruction *inst);
+// Moves a scalar to element 0 of a vector register. This instruction
+// takes 1 source and 1 destination. One is a vector register, the other is
+// a scalar (x) register.
+void VmvFromScalar(Instruction *inst);
+// Does a population count on the vector mask in the source operand 0 (subject
+// to masking), and writes the result to the scalar register in the destination
+// operand. Operand 1 is a mask register. Only vlen bits are considered.
+void Vcpop(Instruction *inst);
+// Computes the index of the first set bit of the vector mask in the source
+// operand 0 (subject to masking), and writes the result to the scalar register
+// in the destination operand. Operand 1 is a mask register. Only vlen bits
+// are considered.
+void Vfirst(Instruction *inst);
+
+// VXUNARY0
+// Element wide zero extend from SEW/2 to SEW. This instruction takes two
+// source operands, and a vector destination operand. Source 0 is the vs2
+// vector source, and source 1 is a vector mask operand.
+void Vzext2(Instruction *inst);
+// Element wide sign extend from SEW/2 to SEW. This instruction takes two
+// source operands, and a vector destination operand. Source 0 is the vs2
+// vector source, and source 1 is a vector mask operand.
+void Vsext2(Instruction *inst);
+// Element wide zero extend from SEW/4 to SEW. This instruction takes two
+// source operands, and a vector destination operand. Source 0 is the vs2
+// vector source, and source 1 is a vector mask operand.
+void Vzext4(Instruction *inst);
+// Element wide sign extend from SEW/4 to SEW. This instruction takes two
+// source operands, and a vector destination operand. Source 0 is the vs2
+// vector source, and source 1 is a vector mask operand.
+void Vsext4(Instruction *inst);
+// Element wide zero extend from SEW/8 to SEW. This instruction takes two
+// source operands, and a vector destination operand. Source 0 is the vs2
+// vector source, and source 1 is a vector mask operand.
+void Vzext8(Instruction *inst);
+// Element wide sign extend from SEW/8 to SEW. This instruction takes two
+// source operands, and a vector destination operand. Source 0 is the vs2
+// vector source, and source 1 is a vector mask operand.
+void Vsext8(Instruction *inst);
+
+// VMUNARY0
+// Set before first mask bit. Takes a vector mask stored in a vector register
+// and produces a mask register with all active bits before the first set
+// active bit in the source mask set to 1. This instruction takes one vector
+// source operand, a mask register, and a vector destination operand. Source 0
+// is the vs2 register source, source 1 is the vector mask operand.
+void Vmsbf(Instruction *inst);
+// Set only first mask bit. Takes a vector mask stored in a vector register
+// and produces a mask register with only the bit that corresponds to the
+// first set active bit in the source mask set to 1. This instruction takes
+// one vector source operand, a mask register, and a vector destination
+// operand. Source 0 is the vs2 register source, source 1 is the vector mask
+// operand.
+void Vmsof(Instruction *inst);
+// Set including first mask bit. Takes a vector mask stored in a vector
+// register and produces a mask register with all active bits before and
+// including the first set active bit in the source mask set to 1. This
+// instruction takes one vector source operand, a mask register, and a vector
+// destination operand. Source 0 is the vs2 register source, source 1 is the
+// vector mask operand.
+void Vmsif(Instruction *inst);
+// Vector Iota instruction. Takes a vector mask stored in a vector register
+// and writes to each element of the destination vector register group the sum
+// of all bits of elements in the mask register whose index is less than the
+// element (parallel prefix sum). This instruction takes two sources and one
+// destination. Source 0 is the vs2 register source, source 1 is the vector
+// mask operand.
+void Viota(Instruction *inst);
+// Vector element index instruction. Writes the element index to the
+// destination vector element group (masking does not change the value written
+// to active elements, only which elements are written to). This instruction
+// takes 1 source (mask register) and one destination.
+void Vid(Instruction *inst);
+
+}  // namespace cheriot
+}  // namespace sim
+}  // namespace mpact
+
+#endif  // MPACT_CHERIOT_RISCV_CHERIOT_VECTOR_UNARY_INSTRUCTIONS_H_
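For the three set-first operations declared above, a scalar model over a plain bool vector shows how they differ only at the first set bit (active-element masking omitted; names are illustrative):

#include <vector>

// vmsbf: ones strictly before the first set bit.
std::vector<bool> MsbfModel(const std::vector<bool> &m) {
  std::vector<bool> out(m.size(), false);
  for (size_t i = 0; i < m.size() && !m[i]; ++i) out[i] = true;
  return out;
}
// vmsif: ones up to and including the first set bit.
std::vector<bool> MsifModel(const std::vector<bool> &m) {
  std::vector<bool> out(m.size(), false);
  for (size_t i = 0; i < m.size(); ++i) {
    out[i] = true;
    if (m[i]) break;
  }
  return out;
}
// vmsof: a one only at the first set bit.
std::vector<bool> MsofModel(const std::vector<bool> &m) {
  std::vector<bool> out(m.size(), false);
  for (size_t i = 0; i < m.size(); ++i) {
    if (m[i]) {
      out[i] = true;
      break;
    }
  }
  return out;
}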
diff --git a/cheriot/test/BUILD b/cheriot/test/BUILD index c4920c9..ebaa5e5 100644 --- a/cheriot/test/BUILD +++ b/cheriot/test/BUILD
@@ -16,12 +16,17 @@ package(default_applicable_licenses = ["//:license"]) +config_setting( + name = "darwin_arm64_cpu", + values = {"cpu": "darwin_arm64"}, +) + cc_test( name = "cheriot_register_test", size = "small", srcs = ["cheriot_register_test.cc"], deps = [ - "//cheriot:riscv_cheriot", + "//cheriot:cheriot_state", "@com_google_absl//absl/log:check", "@com_google_absl//absl/random", "@com_google_absl//absl/status", @@ -37,7 +42,7 @@ size = "small", srcs = ["cheriot_state_test.cc"], deps = [ - "//cheriot:riscv_cheriot", + "//cheriot:cheriot_state", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings:str_format", "@com_google_googletest//:gtest_main", @@ -54,7 +59,8 @@ "riscv_cheriot_instructions_test.cc", ], deps = [ - "//cheriot:riscv_cheriot", + "//cheriot:cheriot_state", + "//cheriot:riscv_cheriot_instructions", "@com_google_absl//absl/log:check", "@com_google_absl//absl/random", "@com_google_absl//absl/strings:str_format", @@ -75,7 +81,8 @@ "riscv_cheriot_i_instructions_test.cc", ], deps = [ - "//cheriot:riscv_cheriot", + "//cheriot:cheriot_state", + "//cheriot:riscv_cheriot_instructions", "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@com_google_mpact-sim//mpact/sim/generic:core", @@ -92,7 +99,8 @@ "riscv_cheriot_m_instructions_test.cc", ], deps = [ - "//cheriot:riscv_cheriot", + "//cheriot:cheriot_state", + "//cheriot:riscv_cheriot_instructions", "@com_google_absl//absl/random", "@com_google_googletest//:gtest_main", "@com_google_mpact-sim//mpact/sim/generic:core", @@ -110,7 +118,8 @@ ], tags = ["not_run:arm"], deps = [ - "//cheriot:riscv_cheriot", + "//cheriot:cheriot_state", + "//cheriot:riscv_cheriot_instructions", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", @@ -129,7 +138,7 @@ "riscv_cheriot_encoding_test.cc", ], deps = [ - "//cheriot:riscv_cheriot", + "//cheriot:cheriot_state", "//cheriot:riscv_cheriot_decoder", "//cheriot:riscv_cheriot_isa", "@com_google_googletest//:gtest_main", @@ -145,7 +154,8 @@ "riscv_cheriot_a_instructions_test.cc", ], deps = [ - "//cheriot:riscv_cheriot", + "//cheriot:cheriot_state", + "//cheriot:riscv_cheriot_instructions", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings:string_view", "@com_google_googletest//:gtest_main", @@ -163,7 +173,7 @@ ], deps = [ "//cheriot:cheriot_load_filter", - "//cheriot:riscv_cheriot", + "//cheriot:cheriot_state", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest_main", @@ -199,3 +209,314 @@ "@com_google_googletest//:gtest_main", ], ) + +cc_test( + name = "cheriot_rvv_fp_decoder_test", + size = "small", + srcs = ["cheriot_rvv_fp_decoder_test.cc"], + deps = [ + "//cheriot:riscv_cheriot_rvv_fp_decoder", + "@com_google_absl//absl/log:log_sink_registry", + "@com_google_googletest//:gtest_main", + "@com_google_mpact-sim//mpact/sim/util/other:log_sink", + ], +) + +cc_test( + name = "cheriot_rvv_decoder_test", + size = "small", + srcs = ["cheriot_rvv_decoder_test.cc"], + deps = [ + "//cheriot:riscv_cheriot_rvv_decoder", + "@com_google_absl//absl/log:log_sink_registry", + "@com_google_googletest//:gtest_main", + "@com_google_mpact-sim//mpact/sim/util/other:log_sink", + ], +) + +cc_library( + name = "riscv_cheriot_vector_instructions_test_base", + testonly = True, + hdrs = ["riscv_cheriot_vector_instructions_test_base.h"], + deps = [ + "//cheriot:cheriot_state", + "//cheriot:cheriot_vector_state", + "//cheriot:riscv_cheriot_vector", + 
"@com_google_absl//absl/functional:bind_front", + "@com_google_absl//absl/random", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_mpact-riscv//riscv:riscv_state", + "@com_google_mpact-sim//mpact/sim/generic:arch_state", + "@com_google_mpact-sim//mpact/sim/generic:core", + "@com_google_mpact-sim//mpact/sim/generic:instruction", + "@com_google_mpact-sim//mpact/sim/generic:type_helpers", + "@com_google_mpact-sim//mpact/sim/util/memory", + ], +) + +cc_library( + name = "riscv_vector_fp_test_utilities", + testonly = True, + hdrs = ["riscv_cheriot_vector_fp_test_utilities.h"], + deps = [ + ":riscv_cheriot_vector_instructions_test_base", + "//cheriot:cheriot_state", + "@com_google_absl//absl/random", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_for_library_testonly", + "@com_google_mpact-riscv//riscv:riscv_fp_state", + "@com_google_mpact-riscv//riscv:riscv_state", + "@com_google_mpact-sim//mpact/sim/generic:type_helpers", + ], +) + +cc_test( + name = "riscv_cheriot_vector_true_test", + size = "small", + srcs = [ + "riscv_cheriot_vector_true_test.cc", + ], + deps = [ + "//cheriot:cheriot_state", + "//cheriot:cheriot_vector_state", + "@com_google_googletest//:gtest_main", + "@com_google_mpact-riscv//riscv:riscv_state", + "@com_google_mpact-sim//mpact/sim/util/memory", + ], +) + +cc_test( + name = "riscv_cheriot_vector_memory_instructions_test", + size = "small", + srcs = [ + "riscv_cheriot_vector_memory_instructions_test.cc", + ], + deps = [ + "//cheriot:cheriot_state", + "//cheriot:cheriot_vector_state", + "//cheriot:riscv_cheriot_vector", + "@com_google_absl//absl/functional:bind_front", + "@com_google_absl//absl/log", + "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@com_google_mpact-riscv//riscv:riscv_state", + "@com_google_mpact-sim//mpact/sim/generic:arch_state", + "@com_google_mpact-sim//mpact/sim/generic:core", + "@com_google_mpact-sim//mpact/sim/generic:instruction", + "@com_google_mpact-sim//mpact/sim/util/memory", + ], +) + +cc_test( + name = "riscv_cheriot_vector_opi_instructions_test", + size = "small", + srcs = [ + "riscv_cheriot_vector_opi_instructions_test.cc", + ], + deps = [ + ":riscv_cheriot_vector_instructions_test_base", + "//cheriot:cheriot_state", + "//cheriot:cheriot_vector_state", + "//cheriot:riscv_cheriot_vector", + "@com_google_absl//absl/functional:bind_front", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@com_google_mpact-riscv//riscv:riscv_state", + "@com_google_mpact-sim//mpact/sim/generic:instruction", + ], +) + +cc_test( + name = "riscv_cheriot_vector_opm_instructions_test", + size = "small", + srcs = [ + "riscv_cheriot_vector_opm_instructions_test.cc", + ], + deps = [ + ":riscv_cheriot_vector_instructions_test_base", + "//cheriot:cheriot_state", + "//cheriot:cheriot_vector_state", + "//cheriot:riscv_cheriot_vector", + "@com_google_absl//absl/base", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/numeric:int128", + "@com_google_absl//absl/random", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@com_google_mpact-sim//mpact/sim/generic:instruction", + ], +) + +cc_test( + name = "riscv_cheriot_vector_reduction_instructions_test", + size = "small", + srcs = [ + "riscv_cheriot_vector_reduction_instructions_test.cc", + ], + deps = [ + 
":riscv_cheriot_vector_instructions_test_base", + "//cheriot:cheriot_state", + "//cheriot:cheriot_vector_state", + "//cheriot:riscv_cheriot_vector", + "@com_google_absl//absl/random", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", + "@com_google_mpact-riscv//riscv:riscv_state", + "@com_google_mpact-sim//mpact/sim/generic:instruction", + "@com_google_mpact-sim//mpact/sim/generic:type_helpers", + ], +) + +cc_test( + name = "riscv_cheriot_vector_unary_instructions_test", + size = "small", + srcs = [ + "riscv_cheriot_vector_unary_instructions_test.cc", + ], + deps = [ + ":riscv_cheriot_vector_instructions_test_base", + "//cheriot:cheriot_state", + "//cheriot:riscv_cheriot_vector", + "@com_google_absl//absl/random", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@com_google_mpact-riscv//riscv:riscv_state", + "@com_google_mpact-sim//mpact/sim/generic:instruction", + "@com_google_mpact-sim//mpact/sim/generic:type_helpers", + ], +) + +cc_test( + name = "riscv_cheriot_vector_fp_unary_instructions_test", + size = "small", + srcs = [ + "riscv_cheriot_vector_fp_unary_instructions_test.cc", + ], + copts = [ + "-ffp-model=strict", + ] + select({ + "darwin_arm64_cpu": [], + "//conditions:default": ["-fprotect-parens"], + }), + deps = [ + ":riscv_cheriot_vector_instructions_test_base", + ":riscv_vector_fp_test_utilities", + "//cheriot:cheriot_state", + "//cheriot:cheriot_vector_state", + "//cheriot:riscv_cheriot_vector_fp", + "@com_google_absl//absl/random", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@com_google_mpact-riscv//riscv:riscv_fp_state", + "@com_google_mpact-riscv//riscv:riscv_state", + "@com_google_mpact-sim//mpact/sim/generic:core", + "@com_google_mpact-sim//mpact/sim/generic:instruction", + "@com_google_mpact-sim//mpact/sim/generic:type_helpers", + ], +) + +cc_test( + name = "riscv_cheriot_vector_fp_instructions_test", + size = "small", + srcs = [ + "riscv_cheriot_vector_fp_instructions_test.cc", + ], + copts = [ + "-ffp-model=strict", + ] + select({ + "darwin_arm64_cpu": [], + "//conditions:default": ["-fprotect-parens"], + }), + deps = [ + ":riscv_cheriot_vector_instructions_test_base", + ":riscv_vector_fp_test_utilities", + "//cheriot:cheriot_state", + "//cheriot:cheriot_vector_state", + "//cheriot:riscv_cheriot_vector_fp", + "@com_google_absl//absl/log", + "@com_google_absl//absl/random", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@com_google_mpact-riscv//riscv:riscv_fp_state", + "@com_google_mpact-riscv//riscv:riscv_state", + "@com_google_mpact-sim//mpact/sim/generic:instruction", + "@com_google_mpact-sim//mpact/sim/generic:type_helpers", + ], +) + +cc_test( + name = "riscv_cheriot_vector_fp_compare_instructions_test", + size = "small", + srcs = [ + "riscv_cheriot_vector_fp_compare_instructions_test.cc", + ], + copts = [ + "-ffp-model=strict", + ] + select({ + "darwin_arm64_cpu": [], + "//conditions:default": ["-fprotect-parens"], + }), + deps = [ + ":riscv_cheriot_vector_instructions_test_base", + ":riscv_vector_fp_test_utilities", + "//cheriot:cheriot_state", + "//cheriot:riscv_cheriot_vector_fp", + "@com_google_absl//absl/random", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@com_google_mpact-riscv//riscv:riscv_fp_state", + "@com_google_mpact-riscv//riscv:riscv_state", + 
"@com_google_mpact-sim//mpact/sim/generic:instruction", + ], +) + +cc_test( + name = "riscv_cheriot_vector_fp_reduction_instructions_test", + size = "small", + srcs = [ + "riscv_cheriot_vector_fp_reduction_instructions_test.cc", + ], + copts = [ + "-ffp-model=strict", + ] + select({ + "darwin_arm64_cpu": [], + "//conditions:default": ["-fprotect-parens"], + }), + deps = [ + ":riscv_cheriot_vector_instructions_test_base", + ":riscv_vector_fp_test_utilities", + "//cheriot:cheriot_state", + "//cheriot:riscv_cheriot_vector_fp", + "@com_google_absl//absl/random", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@com_google_mpact-riscv//riscv:riscv_fp_state", + "@com_google_mpact-riscv//riscv:riscv_state", + "@com_google_mpact-sim//mpact/sim/generic:instruction", + ], +) + +cc_test( + name = "riscv_cheriot_vector_permute_instructions_test", + size = "small", + srcs = [ + "riscv_cheriot_vector_permute_instructions_test.cc", + ], + deps = [ + ":riscv_cheriot_vector_instructions_test_base", + "//cheriot:cheriot_state", + "//cheriot:riscv_cheriot_vector", + "@com_google_absl//absl/random", + "@com_google_googletest//:gtest_main", + "@com_google_mpact-riscv//riscv:riscv_state", + "@com_google_mpact-sim//mpact/sim/generic:instruction", + ], +)
diff --git a/cheriot/test/cheriot_rvv_decoder_test.cc b/cheriot/test/cheriot_rvv_decoder_test.cc new file mode 100644 index 0000000..d3aa53d --- /dev/null +++ b/cheriot/test/cheriot_rvv_decoder_test.cc
@@ -0,0 +1,20 @@ +#include "cheriot/cheriot_rvv_decoder.h" + +#include "absl/log/log_sink_registry.h" +#include "googlemock/include/gmock/gmock.h" +#include "mpact/sim/util/other/log_sink.h" + +namespace { + +using ::mpact::sim::cheriot::CheriotRVVDecoder; +using ::mpact::sim::util::LogSink; + +TEST(RiscvCheriotRvvDecoderTest, Instantiation) { + LogSink log_sink; + absl::AddLogSink(&log_sink); + CheriotRVVDecoder decoder(nullptr, nullptr); + EXPECT_EQ(log_sink.num_error(), 0); + absl::RemoveLogSink(&log_sink); +} + +} // namespace
diff --git a/cheriot/test/cheriot_rvv_fp_decoder_test.cc b/cheriot/test/cheriot_rvv_fp_decoder_test.cc new file mode 100644 index 0000000..68cfce7 --- /dev/null +++ b/cheriot/test/cheriot_rvv_fp_decoder_test.cc
@@ -0,0 +1,20 @@ +#include "cheriot/cheriot_rvv_fp_decoder.h" + +#include "absl/log/log_sink_registry.h" +#include "googlemock/include/gmock/gmock.h" +#include "mpact/sim/util/other/log_sink.h" + +namespace { + +using ::mpact::sim::cheriot::CheriotRVVFPDecoder; +using ::mpact::sim::util::LogSink; + +TEST(RiscvCheriotRvvFpDecoderTest, Instantiation) { + LogSink log_sink; + absl::AddLogSink(&log_sink); + CheriotRVVFPDecoder decoder(nullptr, nullptr); + EXPECT_EQ(log_sink.num_error(), 0); + absl::RemoveLogSink(&log_sink); +} + +} // namespace
diff --git a/cheriot/test/riscv_cheriot_encoding_test.cc b/cheriot/test/riscv_cheriot_encoding_test.cc index 31604ba..4226c71 100644 --- a/cheriot/test/riscv_cheriot_encoding_test.cc +++ b/cheriot/test/riscv_cheriot_encoding_test.cc
@@ -36,8 +36,6 @@ // RV32I constexpr uint32_t kLui = 0b0000000000000000000000000'0110111; -constexpr uint32_t kJal = 0b00000000000000000000'00000'1101111; -constexpr uint32_t kJalr = 0b00000000000'00000'000'00000'1100111; constexpr uint32_t kBeq = 0b0000000'00000'00000'000'00000'1100011; constexpr uint32_t kBne = 0b0000000'00000'00000'001'00000'1100011; constexpr uint32_t kBlt = 0b0000000'00000'00000'100'00000'1100011; @@ -204,8 +202,6 @@ }; constexpr int kRdValue = 1; -constexpr int kSuccValue = 0xf; -constexpr int kPredValue = 0xf; static uint32_t SetRd(uint32_t iword, uint32_t rdval) { return (iword | ((rdval & 0x1f) << 7)); @@ -219,14 +215,6 @@ return (iword | ((rsval & 0x1f) << 20)); } -static uint32_t SetPred(uint32_t iword, uint32_t pred) { - return (iword | ((pred & 0xf) << 24)); -} - -static uint32_t SetSucc(uint32_t iword, uint32_t succ) { - return (iword | ((succ & 0xf) << 20)); -} - static uint32_t Set16Rd(uint32_t iword, uint32_t val) { return (iword | ((val & 0x1f) << 7)); } @@ -412,25 +400,25 @@ enc_->ParseInstruction(Set16Rd(kClwsp, 1)); EXPECT_EQ(enc_->GetOpcode(SlotEnum::kRiscv32Cheriot, 0), OpcodeEnum::kClwsp); enc_->ParseInstruction(Set16Rd(kCldsp, 1)); - EXPECT_EQ(enc_->GetOpcode(SlotEnum::kRiscv32Cheriot, 0), OpcodeEnum::kCldsp); + EXPECT_EQ(enc_->GetOpcode(SlotEnum::kRiscv32Cheriot, 0), OpcodeEnum::kClcsp); // enc_->ParseInstruction(kCdldsp); // EXPECT_EQ(enc_->GetOpcode(SlotEnum::kRiscv32Cheriot, 0), // OpcodeEnum::kCfldsp); enc_->ParseInstruction(kCswsp); EXPECT_EQ(enc_->GetOpcode(SlotEnum::kRiscv32Cheriot, 0), OpcodeEnum::kCswsp); enc_->ParseInstruction(kCsdsp); - EXPECT_EQ(enc_->GetOpcode(SlotEnum::kRiscv32Cheriot, 0), OpcodeEnum::kCsdsp); + EXPECT_EQ(enc_->GetOpcode(SlotEnum::kRiscv32Cheriot, 0), OpcodeEnum::kCscsp); // enc_->ParseInstruction(kCdsdsp); // EXPECT_EQ(enc_->GetOpcode(SlotEnum::kRiscv32Cheriot, 0), // OpcodeEnum::kCfsdsp); enc_->ParseInstruction(kClw); EXPECT_EQ(enc_->GetOpcode(SlotEnum::kRiscv32Cheriot, 0), OpcodeEnum::kClw); enc_->ParseInstruction(kCld); - EXPECT_EQ(enc_->GetOpcode(SlotEnum::kRiscv32Cheriot, 0), OpcodeEnum::kCld); + EXPECT_EQ(enc_->GetOpcode(SlotEnum::kRiscv32Cheriot, 0), OpcodeEnum::kClc); enc_->ParseInstruction(kCsw); EXPECT_EQ(enc_->GetOpcode(SlotEnum::kRiscv32Cheriot, 0), OpcodeEnum::kCsw); enc_->ParseInstruction(kCsd); - EXPECT_EQ(enc_->GetOpcode(SlotEnum::kRiscv32Cheriot, 0), OpcodeEnum::kCsd); + EXPECT_EQ(enc_->GetOpcode(SlotEnum::kRiscv32Cheriot, 0), OpcodeEnum::kCsc); // enc_->ParseInstruction(kCdsd); // EXPECT_EQ(enc_->GetOpcode(SlotEnum::kRiscv32Cheriot, 0), // OpcodeEnum::kCdsd);
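The expectation changes in this hunk track the CHERIoT compressed encodings: the bit patterns the test assembles as kCldsp, kCsdsp, kCld, and kCsd (the RV64 c.ldsp/c.sdsp/c.ld/c.sd encodings) decode on CHERIoT as capability loads and stores, so the expected opcodes become kClcsp, kCscsp, kClc, and kCsc.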
diff --git a/cheriot/test/riscv_cheriot_vector_fp_compare_instructions_test.cc b/cheriot/test/riscv_cheriot_vector_fp_compare_instructions_test.cc new file mode 100644 index 0000000..1fd6a2b --- /dev/null +++ b/cheriot/test/riscv_cheriot_vector_fp_compare_instructions_test.cc
@@ -0,0 +1,444 @@
+// Copyright 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cheriot/riscv_cheriot_vector_fp_compare_instructions.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <functional>
+
+#include "absl/random/random.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "cheriot/test/riscv_cheriot_vector_fp_test_utilities.h"
+#include "cheriot/test/riscv_cheriot_vector_instructions_test_base.h"
+#include "googlemock/include/gmock/gmock.h"
+#include "mpact/sim/generic/instruction.h"
+#include "riscv//riscv_fp_host.h"
+#include "riscv//riscv_fp_state.h"
+#include "riscv//riscv_register.h"
+
+// This file contains the tests of the instruction semantic functions for
+// RiscV vector floating point compare instructions.
+namespace {
+
+using Instruction = ::mpact::sim::generic::Instruction;
+
+// The semantic functions.
+using ::mpact::sim::cheriot::Vmfeq;
+using ::mpact::sim::cheriot::Vmfge;
+using ::mpact::sim::cheriot::Vmfgt;
+using ::mpact::sim::cheriot::Vmfle;
+using ::mpact::sim::cheriot::Vmflt;
+using ::mpact::sim::cheriot::Vmfne;
+
+// Needed types.
+using ::absl::Span;
+using ::mpact::sim::riscv::RVFpRegister;
+using ::mpact::sim::riscv::ScopedFPStatus;
+
+class RiscVCheriotFPCompareInstructionsTest
+    : public RiscVCheriotFPInstructionsTestBase {
+ public:
+  // Helper function for testing binary mask vector-vector instructions that
+  // use the mask bit.
+  template <typename Vs2, typename Vs1>
+  void BinaryMaskFPOpWithMaskTestHelperVV(
+      absl::string_view name, int sew, Instruction *inst,
+      std::function<uint8_t(Vs2, Vs1, bool)> operation) {
+    int byte_sew = sew / 8;
+    if (byte_sew != sizeof(Vs2) && byte_sew != sizeof(Vs1)) {
+      FAIL() << name << ": selected element width != any operand types"
+             << " sew: " << sew << " Vs2: " << sizeof(Vs2)
+             << " Vs1: " << sizeof(Vs1);
+      return;
+    }
+    // Number of elements per vector register.
+    constexpr int vs2_size = kVectorLengthInBytes / sizeof(Vs2);
+    constexpr int vs1_size = kVectorLengthInBytes / sizeof(Vs1);
+    // Input values for 8 registers.
+    Vs2 vs2_value[vs2_size * 8];
+    auto vs2_span = Span<Vs2>(vs2_value);
+    Vs1 vs1_value[vs1_size * 8];
+    auto vs1_span = Span<Vs1>(vs1_value);
+    AppendVectorRegisterOperands({kVs2, kVs1, kVmask}, {kVd});
+    // Initialize input values.
+    FillArrayWithRandomFPValues<Vs2>(vs2_span);
+    FillArrayWithRandomFPValues<Vs1>(vs1_span);
+    // Overwrite the first few values of the input data with infinities,
+    // zeros, denormals and NaNs.
+    using Vs2Int = typename FPTypeInfo<Vs2>::IntType;
+    *reinterpret_cast<Vs2Int *>(&vs2_span[0]) = FPTypeInfo<Vs2>::kQNaN;
+    *reinterpret_cast<Vs2Int *>(&vs2_span[1]) = FPTypeInfo<Vs2>::kSNaN;
+    *reinterpret_cast<Vs2Int *>(&vs2_span[2]) = FPTypeInfo<Vs2>::kPosInf;
+    *reinterpret_cast<Vs2Int *>(&vs2_span[3]) = FPTypeInfo<Vs2>::kNegInf;
+    *reinterpret_cast<Vs2Int *>(&vs2_span[4]) = FPTypeInfo<Vs2>::kPosZero;
+    *reinterpret_cast<Vs2Int *>(&vs2_span[5]) = FPTypeInfo<Vs2>::kNegZero;
+    *reinterpret_cast<Vs2Int *>(&vs2_span[6]) = FPTypeInfo<Vs2>::kPosDenorm;
+    *reinterpret_cast<Vs2Int *>(&vs2_span[7]) = FPTypeInfo<Vs2>::kNegDenorm;
+    // Make every third value the same (at least if the types are the same size).
+    for (int i = 0; i < std::min(vs1_size, vs2_size); i += 3) {
+      vs1_span[i] = static_cast<Vs1>(vs2_span[i]);
+    }
+
+    SetVectorRegisterValues<uint8_t>(
+        {{kVmaskName, Span<const uint8_t>(kA5Mask)}});
+    // Modify the first mask bits to use each of the special floating point
+    // values.
+    vreg_[kVmask]->data_buffer()->Set<uint8_t>(0, 0xff);
+
+    for (int i = 0; i < 8; i++) {
+      auto vs2_name = absl::StrCat("v", kVs2 + i);
+      SetVectorRegisterValues<Vs2>(
+          {{vs2_name, vs2_span.subspan(vs2_size * i, vs2_size)}});
+      auto vs1_name = absl::StrCat("v", kVs1 + i);
+      SetVectorRegisterValues<Vs1>(
+          {{vs1_name, vs1_span.subspan(vs1_size * i, vs1_size)}});
+    }
+    for (int lmul_index = 0; lmul_index < 7; lmul_index++) {
+      for (int vstart_count = 0; vstart_count < 4; vstart_count++) {
+        for (int vlen_count = 0; vlen_count < 4; vlen_count++) {
+          ClearVectorRegisterGroup(kVd, 8);
+          int lmul8 = kLmul8Values[lmul_index];
+          int lmul8_vs2 = lmul8 * sizeof(Vs2) / byte_sew;
+          int num_values = lmul8 * kVectorLengthInBytes / (8 * byte_sew);
+          int vstart = 0;
+          if (vstart_count > 0) {
+            vstart = absl::Uniform(absl::IntervalOpen, bitgen_, 0, num_values);
+          }
+          // Set vlen, but leave vlen high at least once.
+          int vlen = 1024;
+          if (vlen_count > 0) {
+            vlen = absl::Uniform(absl::IntervalOpenClosed, bitgen_, vstart,
+                                 num_values);
+          }
+          num_values = std::min(num_values, vlen);
+          ASSERT_TRUE(vlen > vstart);
+          // Configure vector unit for different lmul settings.
+          uint32_t vtype = (kSewSettingsByByteSize[byte_sew] << 3) |
+                           kLmulSettings[lmul_index];
+          ConfigureVectorUnit(vtype, vlen);
+          rv_vector_->set_vstart(vstart);
+
+          inst->Execute();
+          if ((lmul8_vs2 < 1) || (lmul8_vs2 > 64)) {
+            EXPECT_TRUE(rv_vector_->vector_exception());
+            rv_vector_->clear_vector_exception();
+            continue;
+          }
+
+          EXPECT_FALSE(rv_vector_->vector_exception());
+          EXPECT_EQ(rv_vector_->vstart(), 0);
+          auto dest_span = vreg_[kVd]->data_buffer()->Get<uint8_t>();
+          for (int i = 0; i < kVectorLengthInBytes * 8; i++) {
+            int mask_index = i >> 3;
+            int mask_offset = i & 0b111;
+            bool mask_value = true;
+            if (mask_index > 0) {
+              mask_value = ((kA5Mask[mask_index] >> mask_offset) & 0b1) != 0;
+            }
+            uint8_t inst_value = dest_span[i >> 3];
+            inst_value = (inst_value >> mask_offset) & 0b1;
+            if ((i >= vstart) && (i < num_values)) {
+              // Set rounding mode and perform the computation.
+              ScopedFPStatus set_fpstatus(rv_fp_->host_fp_interface());
+              uint8_t expected_value =
+                  operation(vs2_value[i], vs1_value[i], mask_value);
+              auto int_vs2_val =
+                  *reinterpret_cast<typename FPTypeInfo<Vs2>::IntType *>(
+                      &vs2_value[i]);
+              auto int_vs1_val =
+                  *reinterpret_cast<typename FPTypeInfo<Vs1>::IntType *>(
+                      &vs1_value[i]);
+              EXPECT_EQ(expected_value, inst_value)
+                  << absl::StrCat(name, "[", i, "] op(", vs2_value[i], "[0x",
                                  absl::Hex(int_vs2_val), "], ", vs1_value[i],
+                                  "[0x", absl::Hex(int_vs1_val), "])");
+            } else {
+              EXPECT_EQ(0, inst_value) << absl::StrCat(
+                  name, "[", i, "] 0 != reg[][", i, "] lmul8(", lmul8,
+                  ") vstart(", vstart, ") num_values(", num_values, ")");
+            }
+          }
+          if (HasFailure()) return;
+        }
+      }
+    }
+  }
+
+  // Helper function for testing binary mask vector-vector instructions that do
+  // not use the mask bit.
+  template <typename Vs2, typename Vs1>
+  void BinaryMaskFPOpTestHelperVV(absl::string_view name, int sew,
+                                  Instruction *inst,
+                                  std::function<uint8_t(Vs2, Vs1)> operation) {
+    BinaryMaskFPOpWithMaskTestHelperVV<Vs2, Vs1>(
+        name, sew, inst,
+        [operation](Vs2 vs2, Vs1 vs1, bool mask_value) -> uint8_t {
+          if (mask_value) {
+            return operation(vs2, vs1);
+          }
+          return 0;
+        });
+  }
+
+  // Helper function for testing mask vector-scalar/immediate instructions that
+  // use the mask bit.
+  template <typename Vs2, typename Fs1>
+  void BinaryMaskFPOpWithMaskTestHelperVX(
+      absl::string_view name, int sew, Instruction *inst,
+      std::function<uint8_t(Vs2, Fs1, bool)> operation) {
+    int byte_sew = sew / 8;
+    if (byte_sew != sizeof(Vs2) && byte_sew != sizeof(Fs1)) {
+      FAIL() << name << ": selected element width != any operand types"
+             << " sew: " << sew << " Vs2: " << sizeof(Vs2)
+             << " Fs1: " << sizeof(Fs1);
+      return;
+    }
+    // Number of elements per vector register.
+    constexpr int vs2_size = kVectorLengthInBytes / sizeof(Vs2);
+    // Input values for 8 registers.
+    Vs2 vs2_value[vs2_size * 8];
+    auto vs2_span = Span<Vs2>(vs2_value);
+    AppendVectorRegisterOperands({kVs2}, {});
+    AppendRegisterOperands({kFs1Name}, {});
+    AppendVectorRegisterOperands({kVmask}, {kVd});
+    // Initialize input values.
+    FillArrayWithRandomValues<Vs2>(vs2_span);
+    SetVectorRegisterValues<uint8_t>(
+        {{kVmaskName, Span<const uint8_t>(kA5Mask)}});
+    for (int i = 0; i < 8; i++) {
+      auto vs2_name = absl::StrCat("v", kVs2 + i);
+      SetVectorRegisterValues<Vs2>(
+          {{vs2_name, vs2_span.subspan(vs2_size * i, vs2_size)}});
+    }
+    for (int lmul_index = 0; lmul_index < 7; lmul_index++) {
+      for (int vstart_count = 0; vstart_count < 4; vstart_count++) {
+        for (int vlen_count = 0; vlen_count < 4; vlen_count++) {
+          ClearVectorRegisterGroup(kVd, 8);
+          int lmul8 = kLmul8Values[lmul_index];
+          int lmul8_vs2 = lmul8 * sizeof(Vs2) / byte_sew;
+          int num_values = lmul8 * kVectorLengthInBytes / (8 * byte_sew);
+          int vstart = 0;
+          if (vstart_count > 0) {
+            vstart = absl::Uniform(absl::IntervalOpen, bitgen_, 0, num_values);
+          }
+          // Set vlen, but leave vlen high at least once.
+          int vlen = 1024;
+          if (vlen_count > 0) {
+            vlen = absl::Uniform(absl::IntervalOpenClosed, bitgen_, vstart,
+                                 num_values);
+          }
+          num_values = std::min(num_values, vlen);
+          ASSERT_TRUE(vlen > vstart);
+          // Configure vector unit for different lmul settings.
+          uint32_t vtype = (kSewSettingsByByteSize[byte_sew] << 3) |
+                           kLmulSettings[lmul_index];
+          ConfigureVectorUnit(vtype, vlen);
+          rv_vector_->set_vstart(vstart);
+
+          // Generate a new rs1 value.
+          Fs1 fs1_value = RandomFPValue<Fs1>();
+          // NaN box the value: if the register value type is wider than the
+          // floating point data type, then the upper bits of the register
+          // are all set to 1s.
+          typename RVFpRegister::ValueType fs1_reg_value =
+              NaNBox<Fs1, typename RVFpRegister::ValueType>(fs1_value);
+          SetRegisterValues<typename RVFpRegister::ValueType, RVFpRegister>(
+              {{kFs1Name, fs1_reg_value}});
+
+          inst->Execute();
+          if ((lmul8_vs2 < 1) || (lmul8_vs2 > 64)) {
+            EXPECT_TRUE(rv_vector_->vector_exception());
+            rv_vector_->clear_vector_exception();
+            continue;
+          }
+
+          EXPECT_FALSE(rv_vector_->vector_exception());
+          EXPECT_EQ(rv_vector_->vstart(), 0);
+          auto dest_span = vreg_[kVd]->data_buffer()->Get<uint8_t>();
+          for (int i = 0; i < kVectorLengthInBytes * 8; i++) {
+            int mask_index = i >> 3;
+            int mask_offset = i & 0b111;
+            bool mask_value = ((kA5Mask[mask_index] >> mask_offset) & 0b1) != 0;
+            uint8_t inst_value = dest_span[i >> 3];
+            inst_value = (inst_value >> mask_offset) & 0b1;
+            if ((i >= vstart) && (i < num_values)) {
+              // Set rounding mode and perform the computation.
+
+              ScopedFPStatus set_fpstatus(rv_fp_->host_fp_interface());
+              uint8_t expected_value =
+                  operation(vs2_value[i], fs1_value, mask_value);
+              auto int_vs2_val =
+                  *reinterpret_cast<typename FPTypeInfo<Vs2>::IntType *>(
+                      &vs2_value[i]);
+              auto int_fs1_val =
+                  *reinterpret_cast<typename FPTypeInfo<Fs1>::IntType *>(
+                      &fs1_value);
+              EXPECT_EQ(expected_value, inst_value)
+                  << absl::StrCat(name, "[", i, "] op(", vs2_value[i], "[0x",
+                                  absl::Hex(int_vs2_val), "], ", fs1_value,
+                                  "[0x", absl::Hex(int_fs1_val), "])");
+            } else {
+              EXPECT_EQ(0, inst_value) << absl::StrCat(
+                  name, " 0 != reg[0][", i, "] lmul8(", lmul8, ")");
+            }
+          }
+          if (HasFailure()) return;
+        }
+      }
+    }
+  }
+
+  // Helper function for testing mask vector-scalar instructions that do not
+  // use the mask bit.
+  template <typename Vs2, typename Fs1>
+  void BinaryMaskFPOpTestHelperVX(absl::string_view name, int sew,
+                                  Instruction *inst,
+                                  std::function<uint8_t(Vs2, Fs1)> operation) {
+    BinaryMaskFPOpWithMaskTestHelperVX<Vs2, Fs1>(
+        name, sew, inst,
+        [operation](Vs2 vs2, Fs1 fs1, bool mask_value) -> uint8_t {
+          if (mask_value) {
+            return operation(vs2, fs1);
+          }
+          return 0;
+        });
+  }
+};
+
+// Testing vector floating point compare instructions.
+
+// Vector floating point compare equal.
+TEST_F(RiscVCheriotFPCompareInstructionsTest, Vmfeq) {
+  SetSemanticFunction(&Vmfeq);
+  BinaryMaskFPOpTestHelperVV<float, float>(
+      "Vmfeq_vv32", /*sew*/ 32, instruction_,
+      [](float vs2, float vs1) -> uint8_t { return (vs2 == vs1) ? 1 : 0; });
+  ResetInstruction();
+  SetSemanticFunction(&Vmfeq);
+  BinaryMaskFPOpTestHelperVX<float, float>(
+      "Vmfeq_vx32", /*sew*/ 32, instruction_,
+      [](float vs2, float vs1) -> uint8_t { return (vs2 == vs1) ? 1 : 0; });
+  ResetInstruction();
+  SetSemanticFunction(&Vmfeq);
+  BinaryMaskFPOpTestHelperVV<double, double>(
+      "Vmfeq_vv64", /*sew*/ 64, instruction_,
+      [](double vs2, double vs1) -> uint8_t { return (vs2 == vs1) ? 1 : 0; });
+  ResetInstruction();
+  SetSemanticFunction(&Vmfeq);
+  BinaryMaskFPOpTestHelperVX<double, double>(
+      "Vmfeq_vx64", /*sew*/ 64, instruction_,
+      [](double vs2, double vs1) -> uint8_t { return (vs2 == vs1) ? 1 : 0; });
+}
+
+// Vector floating point compare less than or equal.
+TEST_F(RiscVCheriotFPCompareInstructionsTest, Vmfle) { + SetSemanticFunction(&Vmfle); + BinaryMaskFPOpTestHelperVV<float, float>( + "Vmfle_vv32", /*sew*/ 32, instruction_, + [](float vs2, float vs1) -> uint8_t { return (vs2 <= vs1) ? 1 : 0; }); + ResetInstruction(); + SetSemanticFunction(&Vmfle); + BinaryMaskFPOpTestHelperVX<float, float>( + "Vmfle_vx32", /*sew*/ 32, instruction_, + [](float vs2, float vs1) -> uint8_t { return (vs2 <= vs1) ? 1 : 0; }); + ResetInstruction(); + SetSemanticFunction(&Vmfle); + BinaryMaskFPOpTestHelperVV<double, double>( + "Vmfle_vv64", /*sew*/ 64, instruction_, + [](double vs2, double vs1) -> uint8_t { return (vs2 <= vs1) ? 1 : 0; }); + ResetInstruction(); + SetSemanticFunction(&Vmfle); + BinaryMaskFPOpTestHelperVX<double, double>( + "Vmfle_vx64", /*sew*/ 64, instruction_, + [](double vs2, double vs1) -> uint8_t { return (vs2 <= vs1) ? 1 : 0; }); +} + +// Vector floating point compare less than. +TEST_F(RiscVCheriotFPCompareInstructionsTest, Vmflt) { + SetSemanticFunction(&Vmflt); + BinaryMaskFPOpTestHelperVV<float, float>( + "Vmflt_vv32", /*sew*/ 32, instruction_, + [](float vs2, float vs1) -> uint8_t { return (vs2 < vs1) ? 1 : 0; }); + ResetInstruction(); + SetSemanticFunction(&Vmflt); + BinaryMaskFPOpTestHelperVX<float, float>( + "Vmflt_vx32", /*sew*/ 32, instruction_, + [](float vs2, float vs1) -> uint8_t { return (vs2 < vs1) ? 1 : 0; }); + ResetInstruction(); + SetSemanticFunction(&Vmflt); + BinaryMaskFPOpTestHelperVV<double, double>( + "Vmflt_vv64", /*sew*/ 64, instruction_, + [](double vs2, double vs1) -> uint8_t { return (vs2 < vs1) ? 1 : 0; }); + ResetInstruction(); + SetSemanticFunction(&Vmflt); + BinaryMaskFPOpTestHelperVX<double, double>( + "Vmflt_vx64", /*sew*/ 64, instruction_, + [](double vs2, double vs1) -> uint8_t { return (vs2 < vs1) ? 1 : 0; }); +} + +// Vector floating point compare not equal. +TEST_F(RiscVCheriotFPCompareInstructionsTest, Vmfne) { + SetSemanticFunction(&Vmfne); + BinaryMaskFPOpTestHelperVV<float, float>( + "Vmfne_vv32", /*sew*/ 32, instruction_, + [](float vs2, float vs1) -> uint8_t { return (vs2 != vs1) ? 1 : 0; }); + ResetInstruction(); + SetSemanticFunction(&Vmfne); + BinaryMaskFPOpTestHelperVX<float, float>( + "Vmfne_vx32", /*sew*/ 32, instruction_, + [](float vs2, float vs1) -> uint8_t { return (vs2 != vs1) ? 1 : 0; }); + ResetInstruction(); + SetSemanticFunction(&Vmfne); + BinaryMaskFPOpTestHelperVV<double, double>( + "Vmfne_vv64", /*sew*/ 64, instruction_, + [](double vs2, double vs1) -> uint8_t { return (vs2 != vs1) ? 1 : 0; }); + ResetInstruction(); + SetSemanticFunction(&Vmfne); + BinaryMaskFPOpTestHelperVX<double, double>( + "Vmfne_vx64", /*sew*/ 64, instruction_, + [](double vs2, double vs1) -> uint8_t { return (vs2 != vs1) ? 1 : 0; }); +} + +// Vector floating point compare greater than (used for Vector-Scalar +// comparisons). +TEST_F(RiscVCheriotFPCompareInstructionsTest, Vmfgt) { + SetSemanticFunction(&Vmfgt); + BinaryMaskFPOpTestHelperVX<float, float>( + "Vmfgt_vx32", /*sew*/ 32, instruction_, + [](float vs2, float vs1) -> uint8_t { return (vs2 > vs1) ? 1 : 0; }); + ResetInstruction(); + SetSemanticFunction(&Vmfgt); + BinaryMaskFPOpTestHelperVX<double, double>( + "Vmfgt_vx64", /*sew*/ 64, instruction_, + [](double vs2, double vs1) -> uint8_t { return (vs2 > vs1) ? 1 : 0; }); +} + +// Vector floating point compare greater than or equal (used for Vector-Scalar +// comparisons). 
+TEST_F(RiscVCheriotFPCompareInstructionsTest, Vmfge) { + SetSemanticFunction(&Vmfge); + BinaryMaskFPOpTestHelperVX<float, float>( + "Vmfge_vx32", /*sew*/ 32, instruction_, + [](float vs2, float vs1) -> uint8_t { return (vs2 >= vs1) ? 1 : 0; }); + ResetInstruction(); + SetSemanticFunction(&Vmfge); + BinaryMaskFPOpTestHelperVX<double, double>( + "Vmfge_vx64", /*sew*/ 64, instruction_, + [](double vs2, double vs1) -> uint8_t { return (vs2 >= vs1) ? 1 : 0; }); +} + +} // namespace
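The NaN boxing that the VX helper above relies on (and that recurs in the fp instruction tests below) can be summarized with a short sketch. NaNBoxSketch is a hypothetical stand-in for the test base's NaNBox<From, To>(), written against the behavior stated in the comments (all upper register bits set to 1) and assuming a little-endian host:

#include <cstdint>
#include <cstring>

// Illustrative NaN boxing: place a narrow FP value into a wider register
// image whose upper bits are all 1s, so a full-width read sees a NaN.
template <typename From, typename To>
To NaNBoxSketch(From value) {
  static_assert(sizeof(To) >= sizeof(From), "To must be at least as wide");
  To reg;
  std::memset(&reg, 0xff, sizeof(To));      // All bits set to 1.
  std::memcpy(&reg, &value, sizeof(From));  // Low bytes hold the FP value.
  return reg;
}

For instance, NaNBoxSketch<float, uint64_t>(1.0f) yields 0xffffffff3f800000, which reads back as a NaN if interpreted as a double.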
diff --git a/cheriot/test/riscv_cheriot_vector_fp_instructions_test.cc b/cheriot/test/riscv_cheriot_vector_fp_instructions_test.cc new file mode 100644 index 0000000..00f78c5 --- /dev/null +++ b/cheriot/test/riscv_cheriot_vector_fp_instructions_test.cc
@@ -0,0 +1,1376 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cheriot/riscv_cheriot_vector_fp_instructions.h"
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+#include <functional>
+#include <tuple>
+#include <vector>
+
+#include "absl/random/random.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "cheriot/test/riscv_cheriot_vector_fp_test_utilities.h"
+#include "cheriot/test/riscv_cheriot_vector_instructions_test_base.h"
+#include "googlemock/include/gmock/gmock.h"
+#include "mpact/sim/generic/instruction.h"
+#include "mpact/sim/generic/type_helpers.h"
+#include "riscv//riscv_fp_host.h"
+#include "riscv//riscv_fp_info.h"
+#include "riscv//riscv_fp_state.h"
+#include "riscv//riscv_register.h"
+
+namespace {
+
+using Instruction = ::mpact::sim::generic::Instruction;
+using ::mpact::sim::generic::operator*;  // NOLINT: used below.
+
+// Functions to test.
+using ::mpact::sim::cheriot::Vfadd;
+using ::mpact::sim::cheriot::Vfdiv;
+using ::mpact::sim::cheriot::Vfmacc;
+using ::mpact::sim::cheriot::Vfmadd;
+using ::mpact::sim::cheriot::Vfmax;
+using ::mpact::sim::cheriot::Vfmerge;
+using ::mpact::sim::cheriot::Vfmin;
+using ::mpact::sim::cheriot::Vfmsac;
+using ::mpact::sim::cheriot::Vfmsub;
+using ::mpact::sim::cheriot::Vfmul;
+using ::mpact::sim::cheriot::Vfnmacc;
+using ::mpact::sim::cheriot::Vfnmadd;
+using ::mpact::sim::cheriot::Vfnmsac;
+using ::mpact::sim::cheriot::Vfnmsub;
+using ::mpact::sim::cheriot::Vfrdiv;
+using ::mpact::sim::cheriot::Vfrsub;
+using ::mpact::sim::cheriot::Vfsgnj;
+using ::mpact::sim::cheriot::Vfsgnjn;
+using ::mpact::sim::cheriot::Vfsgnjx;
+using ::mpact::sim::cheriot::Vfsub;
+using ::mpact::sim::cheriot::Vfwadd;
+using ::mpact::sim::cheriot::Vfwaddw;
+using ::mpact::sim::cheriot::Vfwmacc;
+using ::mpact::sim::cheriot::Vfwmsac;
+using ::mpact::sim::cheriot::Vfwmul;
+using ::mpact::sim::cheriot::Vfwnmacc;
+using ::mpact::sim::cheriot::Vfwnmsac;
+using ::mpact::sim::cheriot::Vfwsub;
+using ::mpact::sim::cheriot::Vfwsubw;
+
+using ::absl::Span;
+using ::mpact::sim::riscv::FPExceptions;
+using ::mpact::sim::riscv::FPRoundingMode;
+using ::mpact::sim::riscv::RVFpRegister;
+using ::mpact::sim::riscv::ScopedFPStatus;
+
+// Test fixture for vector fp instructions.
+class RiscVCheriotFPInstructionsTest
+    : public RiscVCheriotFPInstructionsTestBase {
+ public:
+  // Floating point tests need to use the fp special values (inf, NaN,
+  // etc.) during testing, not just random values.
+ template <typename Vd, typename Vs2, typename Vs1> + void TernaryOpFPTestHelperVV(absl::string_view name, int sew, + Instruction *inst, int delta_position, + std::function<Vd(Vs2, Vs1, Vd)> operation) { + int byte_sew = sew / 8; + if (byte_sew != sizeof(Vd) && byte_sew != sizeof(Vs2) && + byte_sew != sizeof(Vs1)) { + FAIL() << name << ": selected element width != any operand types" + << " sew: " << sew << " Vd: " << sizeof(Vd) + << " Vs2: " << sizeof(Vs2) << " Vs1: " << sizeof(Vs1); + return; + } + // Number of elements per vector register. + constexpr int vs2_size = kVectorLengthInBytes / sizeof(Vs2); + constexpr int vs1_size = kVectorLengthInBytes / sizeof(Vs1); + constexpr int vd_size = kVectorLengthInBytes / sizeof(Vd); + // Input values for 8 registers. + Vs2 vs2_value[vs2_size * 8]; + Vs1 vs1_value[vs1_size * 8]; + Vd vd_value[vd_size * 8]; + auto vs2_span = Span<Vs2>(vs2_value); + auto vs1_span = Span<Vs1>(vs1_value); + auto vd_span = Span<Vd>(vd_value); + AppendVectorRegisterOperands({kVs2, kVs1, kVd, kVmask}, {kVd}); + SetVectorRegisterValues<uint8_t>( + {{kVmaskName, Span<const uint8_t>(kA5Mask)}}); + // Iterate across different lmul values. + for (int lmul_index = 0; lmul_index < 7; lmul_index++) { + // Initialize input values. + FillArrayWithRandomFPValues<Vs2>(vs2_span); + FillArrayWithRandomFPValues<Vs1>(vs1_span); + FillArrayWithRandomFPValues<Vd>(vd_span); + using Vs2Int = typename FPTypeInfo<Vs2>::IntType; + using Vs1Int = typename FPTypeInfo<Vs1>::IntType; + using VdInt = typename FPTypeInfo<Vd>::IntType; + // Overwrite the first few values of the input data with infinities, + // zeros, denormals and NaNs. + *reinterpret_cast<Vs2Int *>(&vs2_span[0]) = FPTypeInfo<Vs2>::kQNaN; + *reinterpret_cast<Vs2Int *>(&vs2_span[1]) = FPTypeInfo<Vs2>::kSNaN; + *reinterpret_cast<Vs2Int *>(&vs2_span[2]) = FPTypeInfo<Vs2>::kPosInf; + *reinterpret_cast<Vs2Int *>(&vs2_span[3]) = FPTypeInfo<Vs2>::kNegInf; + *reinterpret_cast<Vs2Int *>(&vs2_span[4]) = FPTypeInfo<Vs2>::kPosZero; + *reinterpret_cast<Vs2Int *>(&vs2_span[5]) = FPTypeInfo<Vs2>::kNegZero; + *reinterpret_cast<Vs2Int *>(&vs2_span[6]) = FPTypeInfo<Vs2>::kPosDenorm; + *reinterpret_cast<Vs2Int *>(&vs2_span[7]) = FPTypeInfo<Vs2>::kNegDenorm; + *reinterpret_cast<VdInt *>(&vd_span[0]) = FPTypeInfo<Vd>::kQNaN; + *reinterpret_cast<VdInt *>(&vd_span[1]) = FPTypeInfo<Vd>::kSNaN; + *reinterpret_cast<VdInt *>(&vd_span[2]) = FPTypeInfo<Vd>::kPosInf; + *reinterpret_cast<VdInt *>(&vd_span[3]) = FPTypeInfo<Vd>::kNegInf; + *reinterpret_cast<VdInt *>(&vd_span[4]) = FPTypeInfo<Vd>::kPosZero; + *reinterpret_cast<VdInt *>(&vd_span[5]) = FPTypeInfo<Vd>::kNegZero; + *reinterpret_cast<VdInt *>(&vd_span[6]) = FPTypeInfo<Vd>::kPosDenorm; + *reinterpret_cast<VdInt *>(&vd_span[7]) = FPTypeInfo<Vd>::kNegDenorm; + if (lmul_index == 4) { + *reinterpret_cast<Vs1Int *>(&vs1_span[0]) = FPTypeInfo<Vs1>::kQNaN; + *reinterpret_cast<Vs1Int *>(&vs1_span[1]) = FPTypeInfo<Vs1>::kSNaN; + *reinterpret_cast<Vs1Int *>(&vs1_span[2]) = FPTypeInfo<Vs1>::kPosInf; + *reinterpret_cast<Vs1Int *>(&vs1_span[3]) = FPTypeInfo<Vs1>::kNegInf; + *reinterpret_cast<Vs1Int *>(&vs1_span[4]) = FPTypeInfo<Vs1>::kPosZero; + *reinterpret_cast<Vs1Int *>(&vs1_span[5]) = FPTypeInfo<Vs1>::kNegZero; + *reinterpret_cast<Vs1Int *>(&vs1_span[6]) = FPTypeInfo<Vs1>::kPosDenorm; + *reinterpret_cast<Vs1Int *>(&vs1_span[7]) = FPTypeInfo<Vs1>::kNegDenorm; + } else if (lmul_index == 5) { + *reinterpret_cast<Vs1Int *>(&vs1_span[7]) = FPTypeInfo<Vs1>::kQNaN; + *reinterpret_cast<Vs1Int *>(&vs1_span[6]) = 
FPTypeInfo<Vs1>::kSNaN; + *reinterpret_cast<Vs1Int *>(&vs1_span[5]) = FPTypeInfo<Vs1>::kPosInf; + *reinterpret_cast<Vs1Int *>(&vs1_span[4]) = FPTypeInfo<Vs1>::kNegInf; + *reinterpret_cast<Vs1Int *>(&vs1_span[3]) = FPTypeInfo<Vs1>::kPosZero; + *reinterpret_cast<Vs1Int *>(&vs1_span[2]) = FPTypeInfo<Vs1>::kNegZero; + *reinterpret_cast<Vs1Int *>(&vs1_span[1]) = FPTypeInfo<Vs1>::kPosDenorm; + *reinterpret_cast<Vs1Int *>(&vs1_span[0]) = FPTypeInfo<Vs1>::kNegDenorm; + } else if (lmul_index == 6) { + *reinterpret_cast<Vs1Int *>(&vs1_span[0]) = FPTypeInfo<Vs1>::kQNaN; + *reinterpret_cast<Vs1Int *>(&vs1_span[1]) = FPTypeInfo<Vs1>::kSNaN; + *reinterpret_cast<Vs1Int *>(&vs1_span[2]) = FPTypeInfo<Vs1>::kNegInf; + *reinterpret_cast<Vs1Int *>(&vs1_span[3]) = FPTypeInfo<Vs1>::kPosInf; + *reinterpret_cast<Vs1Int *>(&vs1_span[4]) = FPTypeInfo<Vs1>::kNegZero; + *reinterpret_cast<Vs1Int *>(&vs1_span[5]) = FPTypeInfo<Vs1>::kPosZero; + *reinterpret_cast<Vs1Int *>(&vs1_span[6]) = FPTypeInfo<Vs1>::kNegDenorm; + *reinterpret_cast<Vs1Int *>(&vs1_span[7]) = FPTypeInfo<Vs1>::kPosDenorm; + } + // Modify the first mask bits to use each of the special floating point + // values. + vreg_[kVmask]->data_buffer()->Set<uint8_t>(0, 0xff); + // Set values for all 8 vector registers in the vector register group. + for (int i = 0; i < 8; i++) { + auto vs2_name = absl::StrCat("v", kVs2 + i); + auto vs1_name = absl::StrCat("v", kVs1 + i); + SetVectorRegisterValues<Vs2>( + {{vs2_name, vs2_span.subspan(vs2_size * i, vs2_size)}}); + SetVectorRegisterValues<Vs1>( + {{vs1_name, vs1_span.subspan(vs1_size * i, vs1_size)}}); + } + int lmul8 = kLmul8Values[lmul_index]; + int lmul8_vd = lmul8 * sizeof(Vd) / byte_sew; + int lmul8_vs2 = lmul8 * sizeof(Vs2) / byte_sew; + int lmul8_vs1 = lmul8 * sizeof(Vs1) / byte_sew; + int num_reg_values = lmul8 * kVectorLengthInBytes / (8 * byte_sew); + // Configure vector unit for different lmul settings. + uint32_t vtype = + (kSewSettingsByByteSize[byte_sew] << 3) | kLmulSettings[lmul_index]; + int vstart = 0; + // Try different vstart values (updated at the bottom of the loop). + for (int vstart_count = 0; vstart_count < 4; vstart_count++) { + int vlen = 1024; + // Try different vector lengths (updated at the bottom of the loop). + for (int vlen_count = 0; vlen_count < 4; vlen_count++) { + ASSERT_TRUE(vlen > vstart); + int num_values = std::min(num_reg_values, vlen); + ConfigureVectorUnit(vtype, vlen); + // Iterate across rounding modes. + for (int rm : {0, 1, 2, 3, 4}) { + rv_fp_->SetRoundingMode(static_cast<FPRoundingMode>(rm)); + rv_vector_->set_vstart(vstart); + + // Reset Vd values, since the previous instruction execution + // overwrites them. 
+            for (int i = 0; i < 8; i++) {
+              auto vd_name = absl::StrCat("v", kVd + i);
+              SetVectorRegisterValues<Vd>(
+                  {{vd_name, vd_span.subspan(vd_size * i, vd_size)}});
+            }
+
+            inst->Execute();
+            if (lmul8_vd < 1 || lmul8_vd > 64) {
+              EXPECT_TRUE(rv_vector_->vector_exception())
+                  << "lmul8: vd: " << lmul8_vd;
+              rv_vector_->clear_vector_exception();
+              continue;
+            }
+
+            if (lmul8_vs2 < 1 || lmul8_vs2 > 64) {
+              EXPECT_TRUE(rv_vector_->vector_exception())
+                  << "lmul8: vs2: " << lmul8_vs2;
+              rv_vector_->clear_vector_exception();
+              continue;
+            }
+
+            if (lmul8_vs1 < 1 || lmul8_vs1 > 64) {
+              EXPECT_TRUE(rv_vector_->vector_exception())
+                  << "lmul8: vs1: " << lmul8_vs1;
+              rv_vector_->clear_vector_exception();
+              continue;
+            }
+            EXPECT_FALSE(rv_vector_->vector_exception());
+            EXPECT_EQ(rv_vector_->vstart(), 0);
+            int count = 0;
+            for (int reg = kVd; reg < kVd + 8; reg++) {
+              for (int i = 0; i < kVectorLengthInBytes / sizeof(Vd); i++) {
+                int mask_index = count >> 3;
+                int mask_offset = count & 0b111;
+                bool mask_value = true;
+                // The first 8 bits of the mask are set to true above, so only
+                // read the mask value after the first byte.
+                if (mask_index > 0) {
+                  mask_value =
+                      ((kA5Mask[mask_index] >> mask_offset) & 0b1) != 0;
+                }
+                auto reg_val = vreg_[reg]->data_buffer()->Get<Vd>(i);
+                auto int_reg_val =
+                    *reinterpret_cast<typename FPTypeInfo<Vd>::IntType *>(
+                        &reg_val);
+                auto int_vd_val =
+                    *reinterpret_cast<typename FPTypeInfo<Vd>::IntType *>(
+                        &vd_value[count]);
+                if ((count >= vstart) && mask_value && (count < num_values)) {
+                  ScopedFPStatus set_fpstatus(rv_fp_->host_fp_interface());
+                  auto op_val = operation(vs2_value[count], vs1_value[count],
+                                          vd_value[count]);
+                  auto int_op_val =
+                      *reinterpret_cast<typename FPTypeInfo<Vd>::IntType *>(
+                          &op_val);
+                  auto int_vs2_val =
+                      *reinterpret_cast<typename FPTypeInfo<Vs2>::IntType *>(
+                          &vs2_value[count]);
+                  auto int_vs1_val =
+                      *reinterpret_cast<typename FPTypeInfo<Vs1>::IntType *>(
+                          &vs1_value[count]);
+                  FPCompare<Vd>(
+                      op_val, reg_val, delta_position,
+                      absl::StrCat(
+                          name, "[", count, "] op(", vs2_value[count], "[0x",
+                          absl::Hex(int_vs2_val), "], ", vs1_value[count],
+                          "[0x", absl::Hex(int_vs1_val), "], ", vd_value[count],
+                          "[0x", absl::Hex(int_vd_val),
+                          "]) = ", absl::Hex(int_op_val), " != reg[", reg, "][",
+                          i, "] (", reg_val, " [0x", absl::Hex(int_reg_val),
+                          "]) lmul8(", lmul8,
+                          ") rm = ", *(rv_fp_->GetRoundingMode())));
+                } else {
+                  EXPECT_THAT(int_vd_val, int_reg_val) << absl::StrCat(
+                      name, " 0 != reg[", reg, "][", i, "] (", reg_val,
+                      " [0x", absl::Hex(int_reg_val), "]) lmul8(", lmul8, ")");
+                }
+                count++;
+              }
+              if (HasFailure()) return;
+            }
+          }
+          vlen = absl::Uniform(absl::IntervalOpenClosed, bitgen_, vstart,
+                               num_reg_values);
+        }
+        vstart = absl::Uniform(absl::IntervalOpen, bitgen_, 0, num_reg_values);
+      }
+    }
+  }
+
+  // Floating point tests need to use the fp special values (inf, NaN,
+  // etc.) during testing, not just random values. This function handles
+  // vector-scalar instructions.
+  template <typename Vd, typename Vs2, typename Fs1>
+  void TernaryOpFPTestHelperVX(absl::string_view name, int sew,
+                               Instruction *inst, int delta_position,
+                               std::function<Vd(Vs2, Fs1, Vd)> operation) {
+    int byte_sew = sew / 8;
+    if (byte_sew != sizeof(Vd) && byte_sew != sizeof(Vs2) &&
+        byte_sew != sizeof(Fs1)) {
+      FAIL() << name << ": selected element width != any operand types"
+             << " sew: " << sew << " Vd: " << sizeof(Vd)
+             << " Vs2: " << sizeof(Vs2) << " Fs1: " << sizeof(Fs1);
+      return;
+    }
+    // Number of elements per vector register.
+    constexpr int vs2_size = kVectorLengthInBytes / sizeof(Vs2);
+    constexpr int vd_size = kVectorLengthInBytes / sizeof(Vd);
+    // Input values for 8 registers.
+    Vs2 vs2_value[vs2_size * 8];
+    Vd vd_value[vd_size * 8];
+    auto vs2_span = Span<Vs2>(vs2_value);
+    auto vd_span = Span<Vd>(vd_value);
+    AppendVectorRegisterOperands({kVs2}, {kVd});
+    AppendRegisterOperands({kFs1Name}, {});
+    AppendVectorRegisterOperands({kVd, kVmask}, {kVd});
+    SetVectorRegisterValues<uint8_t>(
+        {{kVmaskName, Span<const uint8_t>(kA5Mask)}});
+    // Iterate across different lmul values.
+    for (int lmul_index = 0; lmul_index < 7; lmul_index++) {
+      // Initialize input values.
+      FillArrayWithRandomFPValues<Vs2>(vs2_span);
+      using Vs2Int = typename FPTypeInfo<Vs2>::IntType;
+      using VdInt = typename FPTypeInfo<Vd>::IntType;
+      // Overwrite the first few values of the input data with infinities,
+      // zeros, denormals and NaNs.
+      *reinterpret_cast<Vs2Int *>(&vs2_span[0]) = FPTypeInfo<Vs2>::kQNaN;
+      *reinterpret_cast<Vs2Int *>(&vs2_span[1]) = FPTypeInfo<Vs2>::kSNaN;
+      *reinterpret_cast<Vs2Int *>(&vs2_span[2]) = FPTypeInfo<Vs2>::kPosInf;
+      *reinterpret_cast<Vs2Int *>(&vs2_span[3]) = FPTypeInfo<Vs2>::kNegInf;
+      *reinterpret_cast<Vs2Int *>(&vs2_span[4]) = FPTypeInfo<Vs2>::kPosZero;
+      *reinterpret_cast<Vs2Int *>(&vs2_span[5]) = FPTypeInfo<Vs2>::kNegZero;
+      *reinterpret_cast<Vs2Int *>(&vs2_span[6]) = FPTypeInfo<Vs2>::kPosDenorm;
+      *reinterpret_cast<Vs2Int *>(&vs2_span[7]) = FPTypeInfo<Vs2>::kNegDenorm;
+      *reinterpret_cast<VdInt *>(&vd_span[0]) = FPTypeInfo<Vd>::kQNaN;
+      *reinterpret_cast<VdInt *>(&vd_span[1]) = FPTypeInfo<Vd>::kSNaN;
+      *reinterpret_cast<VdInt *>(&vd_span[2]) = FPTypeInfo<Vd>::kPosInf;
+      *reinterpret_cast<VdInt *>(&vd_span[3]) = FPTypeInfo<Vd>::kNegInf;
+      *reinterpret_cast<VdInt *>(&vd_span[4]) = FPTypeInfo<Vd>::kPosZero;
+      *reinterpret_cast<VdInt *>(&vd_span[5]) = FPTypeInfo<Vd>::kNegZero;
+      *reinterpret_cast<VdInt *>(&vd_span[6]) = FPTypeInfo<Vd>::kPosDenorm;
+      *reinterpret_cast<VdInt *>(&vd_span[7]) = FPTypeInfo<Vd>::kNegDenorm;
+      // Modify the first mask bits to use each of the special floating point
+      // values.
+      vreg_[kVmask]->data_buffer()->Set<uint8_t>(0, 0xff);
+      // Set values for all 8 vector registers in the vector register group.
+      for (int i = 0; i < 8; i++) {
+        auto vs2_name = absl::StrCat("v", kVs2 + i);
+        SetVectorRegisterValues<Vs2>(
+            {{vs2_name, vs2_span.subspan(vs2_size * i, vs2_size)}});
+      }
+      int lmul8 = kLmul8Values[lmul_index];
+      int lmul8_vd = lmul8 * sizeof(Vd) / byte_sew;
+      int lmul8_vs2 = lmul8 * sizeof(Vs2) / byte_sew;
+      int lmul8_vs1 = lmul8 * sizeof(Fs1) / byte_sew;
+      int num_reg_values = lmul8 * kVectorLengthInBytes / (8 * byte_sew);
+      // Configure vector unit for different lmul settings.
+      uint32_t vtype =
+          (kSewSettingsByByteSize[byte_sew] << 3) | kLmulSettings[lmul_index];
+      int vstart = 0;
+      // Try different vstart values (updated at the bottom of the loop).
+      for (int vstart_count = 0; vstart_count < 4; vstart_count++) {
+        int vlen = 1024;
+        // Try different vector lengths (updated at the bottom of the loop).
+        for (int vlen_count = 0; vlen_count < 4; vlen_count++) {
+          ASSERT_TRUE(vlen > vstart);
+          int num_values = std::min(num_reg_values, vlen);
+          ConfigureVectorUnit(vtype, vlen);
+          // Generate a new rs1 value.
+          Fs1 fs1_value = RandomFPValue<Fs1>();
+          // NaN box the value: if the register value type is wider than the
+          // floating point data type, then the upper bits of the register
+          // are all set to 1s.
+          typename RVFpRegister::ValueType fs1_reg_value =
+              NaNBox<Fs1, typename RVFpRegister::ValueType>(fs1_value);
+          SetRegisterValues<typename RVFpRegister::ValueType, RVFpRegister>(
+              {{kFs1Name, fs1_reg_value}});
+          // Iterate across rounding modes.
+          for (int rm : {0, 1, 2, 3, 4}) {
+            rv_fp_->SetRoundingMode(static_cast<FPRoundingMode>(rm));
+            rv_vector_->set_vstart(vstart);
+            ClearVectorRegisterGroup(kVd, 8);
+
+            // Reset Vd values, since the previous instruction execution
+            // overwrites them.
+            for (int i = 0; i < 8; i++) {
+              auto vd_name = absl::StrCat("v", kVd + i);
+              SetVectorRegisterValues<Vd>(
+                  {{vd_name, vd_span.subspan(vd_size * i, vd_size)}});
+            }
+
+            inst->Execute();
+            if (lmul8_vd < 1 || lmul8_vd > 64) {
+              EXPECT_TRUE(rv_vector_->vector_exception())
+                  << "lmul8: vd: " << lmul8_vd;
+              rv_vector_->clear_vector_exception();
+              continue;
+            }
+
+            if (lmul8_vs2 < 1 || lmul8_vs2 > 64) {
+              EXPECT_TRUE(rv_vector_->vector_exception())
+                  << "lmul8: vs2: " << lmul8_vs2;
+              rv_vector_->clear_vector_exception();
+              continue;
+            }
+
+            if (lmul8_vs1 < 1 || lmul8_vs1 > 64) {
+              EXPECT_TRUE(rv_vector_->vector_exception())
+                  << "lmul8: vs1: " << lmul8_vs1;
+              rv_vector_->clear_vector_exception();
+              continue;
+            }
+            EXPECT_FALSE(rv_vector_->vector_exception());
+            EXPECT_EQ(rv_vector_->vstart(), 0);
+            int count = 0;
+            for (int reg = kVd; reg < kVd + 8; reg++) {
+              for (int i = 0; i < kVectorLengthInBytes / sizeof(Vd); i++) {
+                int mask_index = count >> 3;
+                int mask_offset = count & 0b111;
+                bool mask_value = true;
+                // The first 8 bits of the mask are set to true above, so only
+                // read the mask value after the first byte from the constant
+                // mask values.
+                if (mask_index > 0) {
+                  mask_value =
+                      ((kA5Mask[mask_index] >> mask_offset) & 0b1) != 0;
+                }
+                auto reg_val = vreg_[reg]->data_buffer()->Get<Vd>(i);
+                auto int_reg_val =
+                    *reinterpret_cast<typename FPTypeInfo<Vd>::IntType *>(
+                        &reg_val);
+                auto int_vd_val =
+                    *reinterpret_cast<typename FPTypeInfo<Vd>::IntType *>(
+                        &vd_value[count]);
+                if ((count >= vstart) && mask_value && (count < num_values)) {
+                  // Set rounding mode and perform the computation.
+
+                  ScopedFPStatus set_fpstatus(rv_fp_->host_fp_interface());
+                  auto op_val =
+                      operation(vs2_value[count], fs1_value, vd_value[count]);
+                  // Extract the integer view of the fp values.
+ auto int_op_val = + *reinterpret_cast<typename FPTypeInfo<Vd>::IntType *>( + &op_val); + auto int_vs2_val = + *reinterpret_cast<typename FPTypeInfo<Vs2>::IntType *>( + &vs2_value[count]); + auto int_fs1_val = + *reinterpret_cast<typename FPTypeInfo<Fs1>::IntType *>( + &fs1_value); + FPCompare<Vd>( + op_val, reg_val, delta_position, + absl::StrCat( + name, "[", count, "] op(", vs2_value[count], "[0x", + absl::Hex(int_vs2_val), "], ", fs1_value, "[0x", + absl::Hex(int_fs1_val), "], ", vd_value[count], "[0x", + absl::Hex(int_vd_val), "]) = ", op_val, "[0x", + absl::Hex(int_op_val), "] ", " != reg[", reg, "][", i, + "] (", reg_val, " [0x", absl::Hex(int_reg_val), + "]) lmul8(", lmul8, + ") rm = ", *(rv_fp_->GetRoundingMode()))); + } else { + EXPECT_EQ(int_vd_val, int_reg_val) << absl::StrCat( + name, " ", vd_value[count], " [0x", + absl::Hex(int_vd_val), "] != reg[", reg, "][", i, "] (", + reg_val, " [0x", absl::Hex(int_reg_val), "]) lmul8(", + lmul8, ")"); + } + count++; + } + if (HasFailure()) return; + } + } + vlen = absl::Uniform(absl::IntervalOpenClosed, bitgen_, vstart, + num_reg_values); + } + vstart = absl::Uniform(absl::IntervalOpen, bitgen_, 0, num_reg_values); + } + } + } +}; + +// Test fp add. +TEST_F(RiscVCheriotFPInstructionsTest, Vfadd) { + // Vector-vector. + SetSemanticFunction(&Vfadd); + BinaryOpFPTestHelperVV<float, float, float>( + "Vfadd_vv32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1) -> float { return vs2 + vs1; }); + ResetInstruction(); + SetSemanticFunction(&Vfadd); + BinaryOpFPTestHelperVV<double, double, double>( + "Vfadd_vv64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [](double vs2, double vs1) -> double { return vs2 + vs1; }); + // Vector-scalar. + ResetInstruction(); + SetSemanticFunction(&Vfadd); + BinaryOpFPTestHelperVX<float, float, float>( + "Vfadd_vx32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1) -> float { return vs2 + vs1; }); + ResetInstruction(); + SetSemanticFunction(&Vfadd); + BinaryOpFPTestHelperVX<double, double, double>( + "Vfadd_vx64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [](double vs2, double vs1) -> double { return vs2 + vs1; }); +} + +// Test fp sub. +TEST_F(RiscVCheriotFPInstructionsTest, Vfsub) { + // Vector-vector. + SetSemanticFunction(&Vfsub); + BinaryOpFPTestHelperVV<float, float, float>( + "Vfsub_vv32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1) -> float { return vs2 - vs1; }); + ResetInstruction(); + SetSemanticFunction(&Vfsub); + BinaryOpFPTestHelperVV<double, double, double>( + "Vfsub_vv64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [](double vs2, double vs1) -> double { return vs2 - vs1; }); + // Vector-scalar. + ResetInstruction(); + SetSemanticFunction(&Vfsub); + BinaryOpFPTestHelperVX<float, float, float>( + "Vfsub_vx32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1) -> float { return vs2 - vs1; }); + ResetInstruction(); + SetSemanticFunction(&Vfsub); + BinaryOpFPTestHelperVX<double, double, double>( + "Vfsub_vx64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [](double vs2, double vs1) -> double { return vs2 - vs1; }); +} + +// Test fp reverse sub. +TEST_F(RiscVCheriotFPInstructionsTest, Vfrsub) { + // Vector-vector. 
+ SetSemanticFunction(&Vfrsub); + BinaryOpFPTestHelperVV<float, float, float>( + "Vfrsub_vv32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1) -> float { return vs1 - vs2; }); + ResetInstruction(); + SetSemanticFunction(&Vfrsub); + BinaryOpFPTestHelperVV<double, double, double>( + "Vfrsub_vv64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [](double vs2, double vs1) -> double { return vs1 - vs2; }); + // Vector-scalar. + ResetInstruction(); + SetSemanticFunction(&Vfrsub); + BinaryOpFPTestHelperVX<float, float, float>( + "Vfrsub_vx32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1) -> float { return vs1 - vs2; }); + ResetInstruction(); + SetSemanticFunction(&Vfrsub); + BinaryOpFPTestHelperVX<double, double, double>( + "Vfrsub_vx64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [](double vs2, double vs1) -> double { return vs1 - vs2; }); +} + +// Test fp widening add. +TEST_F(RiscVCheriotFPInstructionsTest, Vfwadd) { + // Vector-vector. + SetSemanticFunction(&Vfwadd); + BinaryOpFPTestHelperVV<double, float, float>( + "Vfwadd_vv32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1) -> double { + return static_cast<double>(vs2) + static_cast<double>(vs1); + }); + // Vector-scalar. + ResetInstruction(); + SetSemanticFunction(&Vfwadd); + BinaryOpFPTestHelperVX<double, float, float>( + "Vfwadd_vx32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1) -> double { + return static_cast<double>(vs2) + static_cast<double>(vs1); + }); +} + +// Test fp widening subtract. +TEST_F(RiscVCheriotFPInstructionsTest, Vfwsub) { + // Vector-vector. + SetSemanticFunction(&Vfwsub); + BinaryOpFPTestHelperVV<double, float, float>( + "Vfwsub_vv32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1) -> double { + return static_cast<double>(vs2) - static_cast<double>(vs1); + }); + // Vector-scalar. + ResetInstruction(); + SetSemanticFunction(&Vfwsub); + BinaryOpFPTestHelperVX<double, float, float>( + "Vfwsub_vx32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1) -> double { + return static_cast<double>(vs2) - static_cast<double>(vs1); + }); +} + +// Test fp widening add with wide operand. +TEST_F(RiscVCheriotFPInstructionsTest, Vfwaddw) { + // Vector-vector. + SetSemanticFunction(&Vfwaddw); + BinaryOpFPTestHelperVV<double, double, float>( + "Vfwaddw_vv32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](double vs2, float vs1) -> double { + return vs2 + static_cast<double>(vs1); + }); + // Vector-scalar. + ResetInstruction(); + SetSemanticFunction(&Vfwaddw); + BinaryOpFPTestHelperVX<double, double, float>( + "Vfwaddw_vx32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](double vs2, float vs1) -> double { + return vs2 + static_cast<double>(vs1); + }); +} + +// Test fp widening subtract with wide operand. +TEST_F(RiscVCheriotFPInstructionsTest, Vfwsubw) { + // Vector-vector. + SetSemanticFunction(&Vfwsubw); + BinaryOpFPTestHelperVV<double, double, float>( + "Vfwsubw_vv32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](double vs2, float vs1) -> double { + return vs2 - static_cast<double>(vs1); + }); + // Vector-scalar. + ResetInstruction(); + SetSemanticFunction(&Vfwsubw); + BinaryOpFPTestHelperVX<double, double, float>( + "Vfwsubw_vx32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](double vs2, float vs1) -> double { + return vs2 - static_cast<double>(vs1); + }); +} + +// Test fp multiply. 
+TEST_F(RiscVCheriotFPInstructionsTest, Vfmul) { + // Vector-vector. + SetSemanticFunction(&Vfmul); + BinaryOpFPTestHelperVV<float, float, float>( + "Vfmul_vv32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1) -> float { return vs2 * vs1; }); + ResetInstruction(); + SetSemanticFunction(&Vfmul); + BinaryOpFPTestHelperVV<double, double, double>( + "Vfmul_vv64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [](double vs2, double vs1) -> double { return vs2 * vs1; }); + // Vector-scalar. + ResetInstruction(); + SetSemanticFunction(&Vfmul); + BinaryOpFPTestHelperVX<float, float, float>( + "Vfmul_vx32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1) -> float { return vs2 * vs1; }); + ResetInstruction(); + SetSemanticFunction(&Vfmul); + BinaryOpFPTestHelperVX<double, double, double>( + "Vfmul_vx64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [](double vs2, double vs1) -> double { return vs2 * vs1; }); +} + +// Test fp divide. +TEST_F(RiscVCheriotFPInstructionsTest, Vfdiv) { + // Vector-vector. + SetSemanticFunction(&Vfdiv); + BinaryOpFPTestHelperVV<float, float, float>( + "Vfdiv_vv32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1) -> float { return vs2 / vs1; }); + ResetInstruction(); + SetSemanticFunction(&Vfdiv); + BinaryOpFPTestHelperVV<double, double, double>( + "Vfdiv_vv64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [](double vs2, double vs1) -> double { return vs2 / vs1; }); + // Vector-scalar. + ResetInstruction(); + SetSemanticFunction(&Vfdiv); + BinaryOpFPTestHelperVX<float, float, float>( + "Vfdiv_vx32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1) -> float { return vs2 / vs1; }); + ResetInstruction(); + SetSemanticFunction(&Vfdiv); + BinaryOpFPTestHelperVX<double, double, double>( + "Vfdiv_vx64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [](double vs2, double vs1) -> double { return vs2 / vs1; }); +} + +// Test fp reverse divide. +TEST_F(RiscVCheriotFPInstructionsTest, Vfrdiv) { + // Vector-vector. + SetSemanticFunction(&Vfrdiv); + BinaryOpFPTestHelperVV<float, float, float>( + "Vfrdiv_vv32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1) -> float { return vs1 / vs2; }); + ResetInstruction(); + SetSemanticFunction(&Vfrdiv); + BinaryOpFPTestHelperVV<double, double, double>( + "Vfrdiv_vv64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [](double vs2, double vs1) -> double { return vs1 / vs2; }); + // Vector-scalar. + ResetInstruction(); + SetSemanticFunction(&Vfrdiv); + BinaryOpFPTestHelperVX<float, float, float>( + "Vfrdiv_vx32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1) -> float { return vs1 / vs2; }); + ResetInstruction(); + SetSemanticFunction(&Vfrdiv); + BinaryOpFPTestHelperVX<double, double, double>( + "Vfrdiv_vx64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [](double vs2, double vs1) -> double { return vs1 / vs2; }); +} + +// Test fp widening multiply. +TEST_F(RiscVCheriotFPInstructionsTest, Vfwmul) { + // Vector-vector. + SetSemanticFunction(&Vfwmul); + BinaryOpFPTestHelperVV<double, float, float>( + "Vfwmul_vv32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1) -> double { + return static_cast<double>(vs2) * static_cast<double>(vs1); + }); + // Vector-scalar. 
+ ResetInstruction(); + SetSemanticFunction(&Vfwmul); + BinaryOpFPTestHelperVX<double, float, float>( + "Vfwmul_vx32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1) -> double { + return static_cast<double>(vs2) * static_cast<double>(vs1); + }); +} + +// Test fp multiply add. +TEST_F(RiscVCheriotFPInstructionsTest, Vfmadd) { + // Vector-vector. + SetSemanticFunction(&Vfmadd); + TernaryOpFPTestHelperVV<float, float, float>( + "Vfmadd_vv32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1, float vd) -> float { + return std::fma(vs1, vd, vs2); + }); + ResetInstruction(); + SetSemanticFunction(&Vfmadd); + TernaryOpFPTestHelperVV<double, double, double>( + "Vfmadd_vv64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [](double vs2, double vs1, double vd) -> double { + return std::fma(vs1, vd, vs2); + }); + // Vector-scalar. + ResetInstruction(); + SetSemanticFunction(&Vfmadd); + TernaryOpFPTestHelperVX<float, float, float>( + "Vfmadd_vx32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1, float vd) -> float { + return std::fma(vs1, vd, vs2); + }); + ResetInstruction(); + SetSemanticFunction(&Vfmadd); + TernaryOpFPTestHelperVX<double, double, double>( + "Vfmadd_vx64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [](double vs2, double vs1, double vd) -> double { + return std::fma(vs1, vd, vs2); + }); +} + +// Test fp negated multiply add. +TEST_F(RiscVCheriotFPInstructionsTest, Vfnmadd) { + // Vector-vector. + SetSemanticFunction(&Vfnmadd); + TernaryOpFPTestHelperVV<float, float, float>( + "Vfnmadd_vv32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1, float vd) -> float { + return OptimizationBarrier(std::fma(-vs1, vd, -vs2)); + }); + ResetInstruction(); + SetSemanticFunction(&Vfnmadd); + TernaryOpFPTestHelperVV<double, double, double>( + "Vfnmadd_vv64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [](double vs2, double vs1, double vd) -> double { + return OptimizationBarrier(std::fma(-vs1, vd, -vs2)); + }); + // Vector-scalar. + ResetInstruction(); + SetSemanticFunction(&Vfnmadd); + TernaryOpFPTestHelperVX<float, float, float>( + "Vfnmadd_vx32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1, float vd) -> float { + return OptimizationBarrier(std::fma(-vs1, vd, -vs2)); + }); + ResetInstruction(); + SetSemanticFunction(&Vfnmadd); + TernaryOpFPTestHelperVX<double, double, double>( + "Vfnmadd_vx64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [](double vs2, double vs1, double vd) -> double { + return OptimizationBarrier(std::fma(-vs1, vd, -vs2)); + }); +} + +// Test fp multiply subtract. +TEST_F(RiscVCheriotFPInstructionsTest, Vfmsub) { + // Vector-vector. + SetSemanticFunction(&Vfmsub); + TernaryOpFPTestHelperVV<float, float, float>( + "Vfmsub_vv32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1, float vd) -> float { + return OptimizationBarrier(std::fma(vs1, vd, -vs2)); + }); + ResetInstruction(); + SetSemanticFunction(&Vfmsub); + TernaryOpFPTestHelperVV<double, double, double>( + "Vfmsub_vv64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [](double vs2, double vs1, double vd) -> double { + return OptimizationBarrier(std::fma(vs1, vd, -vs2)); + }); + // Vector-scalar. 
+  ResetInstruction();
+  SetSemanticFunction(&Vfmsub);
+  TernaryOpFPTestHelperVX<float, float, float>(
+      "Vfmsub_vx32", /*sew*/ 32, instruction_, /*delta_position*/ 32,
+      [](float vs2, float vs1, float vd) -> float {
+        return OptimizationBarrier(std::fma(vs1, vd, -vs2));
+      });
+  ResetInstruction();
+  SetSemanticFunction(&Vfmsub);
+  TernaryOpFPTestHelperVX<double, double, double>(
+      "Vfmsub_vx64", /*sew*/ 64, instruction_, /*delta_position*/ 64,
+      [](double vs2, double vs1, double vd) -> double {
+        return OptimizationBarrier(std::fma(vs1, vd, -vs2));
+      });
+}
+
+// Test fp negated multiply subtract.
+TEST_F(RiscVCheriotFPInstructionsTest, Vfnmsub) {
+  // Vector-vector.
+  SetSemanticFunction(&Vfnmsub);
+  TernaryOpFPTestHelperVV<float, float, float>(
+      "Vfnmsub_vv32", /*sew*/ 32, instruction_, /*delta_position*/ 32,
+      [](float vs2, float vs1, float vd) -> float {
+        return OptimizationBarrier(std::fma(-vs1, vd, vs2));
+      });
+  ResetInstruction();
+  SetSemanticFunction(&Vfnmsub);
+  TernaryOpFPTestHelperVV<double, double, double>(
+      "Vfnmsub_vv64", /*sew*/ 64, instruction_, /*delta_position*/ 64,
+      [](double vs2, double vs1, double vd) -> double {
+        return OptimizationBarrier(std::fma(-vs1, vd, vs2));
+      });
+  // Vector-scalar.
+  ResetInstruction();
+  SetSemanticFunction(&Vfnmsub);
+  TernaryOpFPTestHelperVX<float, float, float>(
+      "Vfnmsub_vx32", /*sew*/ 32, instruction_, /*delta_position*/ 32,
+      [](float vs2, float vs1, float vd) -> float {
+        return OptimizationBarrier(std::fma(-vs1, vd, vs2));
+      });
+  ResetInstruction();
+  SetSemanticFunction(&Vfnmsub);
+  TernaryOpFPTestHelperVX<double, double, double>(
+      "Vfnmsub_vx64", /*sew*/ 64, instruction_, /*delta_position*/ 64,
+      [](double vs2, double vs1, double vd) -> double {
+        return OptimizationBarrier(std::fma(-vs1, vd, vs2));
+      });
+}
+
+// Test fp multiply accumulate.
+TEST_F(RiscVCheriotFPInstructionsTest, Vfmacc) {
+  // Vector-vector.
+  SetSemanticFunction(&Vfmacc);
+  TernaryOpFPTestHelperVV<float, float, float>(
+      "Vfmacc_vv32", /*sew*/ 32, instruction_, /*delta_position*/ 32,
+      [](float vs2, float vs1, float vd) -> float {
+        return OptimizationBarrier(std::fma(vs1, vs2, vd));
+      });
+  ResetInstruction();
+  SetSemanticFunction(&Vfmacc);
+  TernaryOpFPTestHelperVV<double, double, double>(
+      "Vfmacc_vv64", /*sew*/ 64, instruction_, /*delta_position*/ 64,
+      [](double vs2, double vs1, double vd) -> double {
+        return OptimizationBarrier(std::fma(vs1, vs2, vd));
+      });
+  // Vector-scalar.
+  ResetInstruction();
+  SetSemanticFunction(&Vfmacc);
+  TernaryOpFPTestHelperVX<float, float, float>(
+      "Vfmacc_vx32", /*sew*/ 32, instruction_, /*delta_position*/ 32,
+      [](float vs2, float vs1, float vd) -> float {
+        return OptimizationBarrier(std::fma(vs1, vs2, vd));
+      });
+  ResetInstruction();
+  SetSemanticFunction(&Vfmacc);
+  TernaryOpFPTestHelperVX<double, double, double>(
+      "Vfmacc_vx64", /*sew*/ 64, instruction_, /*delta_position*/ 64,
+      [](double vs2, double vs1, double vd) -> double {
+        return OptimizationBarrier(std::fma(vs1, vs2, vd));
+      });
+}
+
+// Test fp negated multiply accumulate.
+TEST_F(RiscVCheriotFPInstructionsTest, Vfnmacc) {
+  // Vector-vector.
+ SetSemanticFunction(&Vfnmacc); + TernaryOpFPTestHelperVV<float, float, float>( + "Vfnmacc_vv32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1, float vd) -> float { + return OptimizationBarrier(std::fma(-vs1, vs2, -vd)); + }); + ResetInstruction(); + SetSemanticFunction(&Vfnmacc); + TernaryOpFPTestHelperVV<double, double, double>( + "Vfnmacc_vv64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [](double vs2, double vs1, double vd) -> double { + return OptimizationBarrier(std::fma(-vs1, vs2, -vd)); + }); + // Vector-scalar. + ResetInstruction(); + SetSemanticFunction(&Vfnmacc); + TernaryOpFPTestHelperVX<float, float, float>( + "Vfnmacc_vx32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1, float vd) -> float { + return OptimizationBarrier(std::fma(-vs1, vs2, -vd)); + }); + ResetInstruction(); + SetSemanticFunction(&Vfnmacc); + TernaryOpFPTestHelperVX<double, double, double>( + "Vfnmacc_vx64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [](double vs2, double vs1, double vd) -> double { + return OptimizationBarrier(std::fma(-vs1, vs2, -vd)); + }); +} + +// Test fp multiply subtract accumulate. +TEST_F(RiscVCheriotFPInstructionsTest, Vfmsac) { + // Vector-vector. + SetSemanticFunction(&Vfmsac); + TernaryOpFPTestHelperVV<float, float, float>( + "Vfmsac_vv32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1, float vd) -> float { + return OptimizationBarrier(std::fma(vs1, vs2, -vd)); + }); + ResetInstruction(); + SetSemanticFunction(&Vfmsac); + TernaryOpFPTestHelperVV<double, double, double>( + "Vfmsac_vv64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [](double vs2, double vs1, double vd) -> double { + return OptimizationBarrier(std::fma(vs1, vs2, -vd)); + }); + // Vector-scalar. + ResetInstruction(); + SetSemanticFunction(&Vfmsac); + TernaryOpFPTestHelperVX<float, float, float>( + "Vfmsac_vx32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1, float vd) -> float { + return OptimizationBarrier(std::fma(vs1, vs2, -vd)); + }); + ResetInstruction(); + SetSemanticFunction(&Vfmsac); + TernaryOpFPTestHelperVX<double, double, double>( + "Vfmsac_vx64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [](double vs2, double vs1, double vd) -> double { + return OptimizationBarrier(std::fma(vs1, vs2, -vd)); + }); +} + +// Test fp negated multiply subtract accumulate. +TEST_F(RiscVCheriotFPInstructionsTest, Vfnmsac) { + // Vector-vector. + SetSemanticFunction(&Vfnmsac); + TernaryOpFPTestHelperVV<float, float, float>( + "Vfnmsac_vv32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1, float vd) -> float { + return OptimizationBarrier(std::fma(-vs1, vs2, vd)); + }); + ResetInstruction(); + SetSemanticFunction(&Vfnmsac); + TernaryOpFPTestHelperVV<double, double, double>( + "Vfnmsac_vv64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [](double vs2, double vs1, double vd) -> double { + return OptimizationBarrier(std::fma(-vs1, vs2, vd)); + }); + // Vector-scalar. 
+  ResetInstruction();
+  SetSemanticFunction(&Vfnmsac);
+  TernaryOpFPTestHelperVX<float, float, float>(
+      "Vfnmsac_vx32", /*sew*/ 32, instruction_, /*delta_position*/ 32,
+      [](float vs2, float vs1, float vd) -> float {
+        return OptimizationBarrier(std::fma(-vs1, vs2, vd));
+      });
+  ResetInstruction();
+  SetSemanticFunction(&Vfnmsac);
+  TernaryOpFPTestHelperVX<double, double, double>(
+      "Vfnmsac_vx64", /*sew*/ 64, instruction_, /*delta_position*/ 64,
+      [](double vs2, double vs1, double vd) -> double {
+        return OptimizationBarrier(std::fma(-vs1, vs2, vd));
+      });
+}
+
+// Test fp widening multiply accumulate.
+TEST_F(RiscVCheriotFPInstructionsTest, Vfwmacc) {
+  // Vector-vector.
+  SetSemanticFunction(&Vfwmacc);
+  TernaryOpFPTestHelperVV<double, float, float>(
+      "Vfwmacc_vv32", /*sew*/ 32, instruction_, /*delta_position*/ 64,
+      [](float vs2, float vs1, double vd) -> double {
+        double vs1d = static_cast<double>(vs1);
+        double vs2d = static_cast<double>(vs2);
+        return OptimizationBarrier(vs1d * vs2d) + vd;
+      });
+  // Vector-scalar.
+  ResetInstruction();
+  SetSemanticFunction(&Vfwmacc);
+  TernaryOpFPTestHelperVX<double, float, float>(
+      "Vfwmacc_vx32", /*sew*/ 32, instruction_, /*delta_position*/ 64,
+      [](float vs2, float vs1, double vd) -> double {
+        double vs1d = static_cast<double>(vs1);
+        double vs2d = static_cast<double>(vs2);
+        return OptimizationBarrier(vs1d * vs2d) + vd;
+      });
+}
+
+// Test fp widening negated multiply accumulate.
+TEST_F(RiscVCheriotFPInstructionsTest, Vfwnmacc) {
+  // Vector-vector.
+  SetSemanticFunction(&Vfwnmacc);
+  TernaryOpFPTestHelperVV<double, float, float>(
+      "Vfwnmacc_vv32", /*sew*/ 32, instruction_, /*delta_position*/ 64,
+      [](float vs2, float vs1, double vd) -> double {
+        double vs1d = static_cast<double>(vs1);
+        double vs2d = static_cast<double>(vs2);
+        return -OptimizationBarrier(vs1d * vs2d) - vd;
+      });
+  // Vector-scalar.
+  ResetInstruction();
+  SetSemanticFunction(&Vfwnmacc);
+  TernaryOpFPTestHelperVX<double, float, float>(
+      "Vfwnmacc_vx32", /*sew*/ 32, instruction_, /*delta_position*/ 64,
+      [](float vs2, float vs1, double vd) -> double {
+        double vs1d = static_cast<double>(vs1);
+        double vs2d = static_cast<double>(vs2);
+        return -OptimizationBarrier(vs1d * vs2d) - vd;
+      });
+}
+
+// Test fp widening multiply subtract accumulate.
+TEST_F(RiscVCheriotFPInstructionsTest, Vfwmsac) {
+  // Vector-vector.
+  SetSemanticFunction(&Vfwmsac);
+  TernaryOpFPTestHelperVV<double, float, float>(
+      "Vfwmsac_vv32", /*sew*/ 32, instruction_, /*delta_position*/ 64,
+      [](float vs2, float vs1, double vd) -> double {
+        double vs1d = static_cast<double>(vs1);
+        double vs2d = static_cast<double>(vs2);
+        return OptimizationBarrier(vs1d * vs2d) - vd;
+      });
+  // Vector-scalar.
+  ResetInstruction();
+  SetSemanticFunction(&Vfwmsac);
+  TernaryOpFPTestHelperVX<double, float, float>(
+      "Vfwmsac_vx32", /*sew*/ 32, instruction_, /*delta_position*/ 64,
+      [](float vs2, float vs1, double vd) -> double {
+        double vs1d = static_cast<double>(vs1);
+        double vs2d = static_cast<double>(vs2);
+        return OptimizationBarrier(vs1d * vs2d) - vd;
+      });
+}
+
+// Test fp widening negated multiply subtract accumulate.
+TEST_F(RiscVCheriotFPInstructionsTest, Vfwnmsac) {
+  // Vector-vector.
+ SetSemanticFunction(&Vfwnmsac); + TernaryOpFPTestHelperVV<double, float, float>( + "Vfwnmsac_vv32", /*sew*/ 32, instruction_, /*delta_position*/ 64, + [](float vs2, float vs1, double vd) -> double { + double vs1d = static_cast<double>(vs1); + double vs2d = static_cast<double>(vs2); + return -OptimizationBarrier(vs1d * vs2d) + vd; + }); + // Vector-scalar. + ResetInstruction(); + SetSemanticFunction(&Vfwnmsac); + TernaryOpFPTestHelperVX<double, float, float>( + "Vfwnmsac_vx32", /*sew*/ 32, instruction_, /*delta_position*/ 64, + [](float vs2, float vs1, double vd) -> double { + double vs1d = static_cast<double>(vs1); + double vs2d = static_cast<double>(vs2); + return -OptimizationBarrier(vs1d * vs2d) + vd; + }); +} + +// Test vector floating point sign injection instructions. There are three +// of these. vfsgnj, vfsgnjn, and vfsgnjx. The instructions take the sign +// bit from vs1/fs1 and the other bits from vs2. The sign bit is either used +// as is, negated (n) or xor'ed (x). + +template <typename T> +inline T SignHelper( + T vs2, T vs1, + std::function<typename FPTypeInfo<T>::IntType( + typename FPTypeInfo<T>::IntType, typename FPTypeInfo<T>::IntType, + typename FPTypeInfo<T>::IntType)> + sign_op) { + using Int = typename FPTypeInfo<T>::IntType; + Int sign_mask = 1ULL << (FPTypeInfo<T>::kBitSize - 1); + Int vs2i = *reinterpret_cast<Int *>(&vs2); + Int vs1i = *reinterpret_cast<Int *>(&vs1); + Int resi = sign_op(vs2i, vs1i, sign_mask); + return *reinterpret_cast<T *>(&resi); +} + +// The sign is that of vs1. +TEST_F(RiscVCheriotFPInstructionsTest, Vfsgnj) { + // Vector-vector. + SetSemanticFunction(&Vfsgnj); + BinaryOpFPTestHelperVV<float, float, float>( + "Vfsgnj_vv32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1) -> float { + using Int = typename FPTypeInfo<float>::IntType; + return SignHelper(vs2, vs1, [](Int vs2i, Int vs1i, Int mask) -> Int { + return (vs2i & ~mask) | (vs1i & mask); + }); + }); + ResetInstruction(); + SetSemanticFunction(&Vfsgnj); + BinaryOpFPTestHelperVV<double, double, double>( + "Vfsgnj_vv64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [](double vs2, double vs1) -> double { + using Int = typename FPTypeInfo<double>::IntType; + return SignHelper(vs2, vs1, [](Int vs2i, Int vs1i, Int mask) -> Int { + return (vs2i & ~mask) | (vs1i & mask); + }); + }); + // Vector-scalar. + ResetInstruction(); + SetSemanticFunction(&Vfsgnj); + BinaryOpFPTestHelperVX<float, float, float>( + "Vfsgnj_vx32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1) -> float { + using Int = typename FPTypeInfo<float>::IntType; + return SignHelper(vs2, vs1, [](Int vs2i, Int vs1i, Int mask) -> Int { + return (vs2i & ~mask) | (vs1i & mask); + }); + }); + ResetInstruction(); + SetSemanticFunction(&Vfsgnj); + BinaryOpFPTestHelperVX<double, double, double>( + "Vfsgnj_vx64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [](double vs2, double vs1) -> double { + using Int = typename FPTypeInfo<double>::IntType; + return SignHelper(vs2, vs1, [](Int vs2i, Int vs1i, Int mask) -> Int { + return (vs2i & ~mask) | (vs1i & mask); + }); + }); +} + +// The sign is the negation of that of vs1. +TEST_F(RiscVCheriotFPInstructionsTest, Vfsgnjn) { + // Vector-vector. 
+ SetSemanticFunction(&Vfsgnjn); + BinaryOpFPTestHelperVV<float, float, float>( + "Vfsgnjn_vv32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1) -> float { + using Int = typename FPTypeInfo<float>::IntType; + return SignHelper(vs2, vs1, [](Int vs2i, Int vs1i, Int mask) -> Int { + return (vs2i & ~mask) | (~vs1i & mask); + }); + }); + ResetInstruction(); + SetSemanticFunction(&Vfsgnjn); + BinaryOpFPTestHelperVV<double, double, double>( + "Vfsgnjn_vv64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [](double vs2, double vs1) -> double { + using Int = typename FPTypeInfo<double>::IntType; + return SignHelper(vs2, vs1, [](Int vs2i, Int vs1i, Int mask) -> Int { + return (vs2i & ~mask) | (~vs1i & mask); + }); + }); + // Vector-scalar. + ResetInstruction(); + SetSemanticFunction(&Vfsgnjn); + BinaryOpFPTestHelperVX<float, float, float>( + "Vfsgnjn_vx32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1) -> float { + using Int = typename FPTypeInfo<float>::IntType; + return SignHelper(vs2, vs1, [](Int vs2i, Int vs1i, Int mask) -> Int { + return (vs2i & ~mask) | (~vs1i & mask); + }); + }); + ResetInstruction(); + SetSemanticFunction(&Vfsgnjn); + BinaryOpFPTestHelperVX<double, double, double>( + "Vfsgnjn_vx64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [](double vs2, double vs1) -> double { + using Int = typename FPTypeInfo<double>::IntType; + return SignHelper(vs2, vs1, [](Int vs2i, Int vs1i, Int mask) -> Int { + return (vs2i & ~mask) | (~vs1i & mask); + }); + }); +} + +// The sign is exclusive or of the signs of vs2 and vs1. +TEST_F(RiscVCheriotFPInstructionsTest, Vfsgnjx) { + // Vector-vector. + SetSemanticFunction(&Vfsgnjx); + BinaryOpFPTestHelperVV<float, float, float>( + "Vfsgnjx_vv32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1) -> float { + using Int = typename FPTypeInfo<float>::IntType; + return SignHelper(vs2, vs1, [](Int vs2i, Int vs1i, Int mask) -> Int { + return (vs2i & ~mask) | ((vs1i ^ vs2i) & mask); + }); + }); + ResetInstruction(); + SetSemanticFunction(&Vfsgnjx); + BinaryOpFPTestHelperVV<double, double, double>( + "Vfsgnjx_vv64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [](double vs2, double vs1) -> double { + using Int = typename FPTypeInfo<double>::IntType; + return SignHelper(vs2, vs1, [](Int vs2i, Int vs1i, Int mask) -> Int { + return (vs2i & ~mask) | ((vs1i ^ vs2i) & mask); + }); + }); + // Vector-scalar. 
+ ResetInstruction(); + SetSemanticFunction(&Vfsgnjx); + BinaryOpFPTestHelperVX<float, float, float>( + "Vfsgnjx_vx32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1) -> float { + using Int = typename FPTypeInfo<float>::IntType; + return SignHelper(vs2, vs1, [](Int vs2i, Int vs1i, Int mask) -> Int { + return (vs2i & ~mask) | ((vs1i ^ vs2i) & mask); + }); + }); + ResetInstruction(); + SetSemanticFunction(&Vfsgnjx); + BinaryOpFPTestHelperVX<double, double, double>( + "Vfsgnjx_vx64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [](double vs2, double vs1) -> double { + using Int = typename FPTypeInfo<double>::IntType; + return SignHelper(vs2, vs1, [](Int vs2i, Int vs1i, Int mask) -> Int { + return (vs2i & ~mask) | ((vs1i ^ vs2i) & mask); + }); + }); +} + +template <typename T> +bool is_snan(T value) { + using IntType = typename FPTypeInfo<T>::IntType; + IntType int_value = *reinterpret_cast<IntType *>(&value); + bool signal = (int_value & (1ULL << (FPTypeInfo<T>::kSigSize - 1))) == 0; + return std::isnan(value) && signal; +} + +template <typename T> +std::tuple<T, uint32_t> MaxMinHelper(T vs2, T vs1, + std::function<T(T, T)> operation) { + uint32_t flag = 0; + if (is_snan(vs2) || is_snan(vs1)) { + flag = static_cast<uint32_t>(FPExceptions::kInvalidOp); + } + if (std::isnan(vs2) && std::isnan(vs1)) { + // Canonical NaN. + auto canonical = FPTypeInfo<T>::kCanonicalNaN; + T canonical_fp = *reinterpret_cast<T *>(&canonical); + return std::tie(canonical_fp, flag); + } + if (std::isnan(vs2)) return std::tie(vs1, flag); + if (std::isnan(vs1)) return std::tie(vs2, flag); + if ((vs2 == 0.0) && (vs1 == 0.0)) { + T tmp2 = std::signbit(vs2) ? -1.0 : 1; + T tmp1 = std::signbit(vs1) ? -1.0 : 1; + return std::make_tuple(operation(tmp2, tmp1) == tmp2 ? vs2 : vs1, 0); + } + return std::make_tuple(operation(vs2, vs1), 0); +} + +TEST_F(RiscVCheriotFPInstructionsTest, Vfmax) { + // Vector-vector. + SetSemanticFunction(&Vfmax); + BinaryOpWithFflagsFPTestHelperVV<float, float, float>( + "Vfmax_vv32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1) -> std::tuple<float, uint32_t> { + using T = float; + auto tmp = MaxMinHelper<T>(vs2, vs1, [](T vs2, T vs1) -> T { + return (vs1 > vs2) ? vs1 : vs2; + }); + return tmp; + }); + ResetInstruction(); + SetSemanticFunction(&Vfmax); + BinaryOpWithFflagsFPTestHelperVV<double, double, double>( + "Vfmax_vv64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [](double vs2, double vs1) -> std::tuple<double, uint32_t> { + using T = double; + return MaxMinHelper<T>(vs2, vs1, [](T vs2, T vs1) -> T { + return (vs1 > vs2) ? vs1 : vs2; + }); + }); + // Vector-scalar. + ResetInstruction(); + SetSemanticFunction(&Vfmax); + BinaryOpWithFflagsFPTestHelperVX<float, float, float>( + "Vfmax_vx32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2, float vs1) -> std::tuple<float, uint32_t> { + using T = float; + return MaxMinHelper<T>(vs2, vs1, [](T vs2, T vs1) -> T { + return (vs1 > vs2) ? vs1 : vs2; + }); + }); + ResetInstruction(); + SetSemanticFunction(&Vfmax); + BinaryOpWithFflagsFPTestHelperVX<double, double, double>( + "Vfmax_vx64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [](double vs2, double vs1) -> std::tuple<double, uint32_t> { + using T = double; + return MaxMinHelper<T>(vs2, vs1, [](T vs2, T vs1) -> T { + return (vs1 > vs2) ? vs1 : vs2; + }); + }); +} + +TEST_F(RiscVCheriotFPInstructionsTest, Vfmin) { + // Vector-vector. 
+  SetSemanticFunction(&Vfmin);
+  BinaryOpWithFflagsFPTestHelperVV<float, float, float>(
+      "Vfmin_vv32", /*sew*/ 32, instruction_, /*delta_position*/ 32,
+      [](float vs2, float vs1) -> std::tuple<float, uint32_t> {
+        using T = float;
+        return MaxMinHelper<T>(vs2, vs1, [](T vs2, T vs1) -> T {
+          return (vs1 < vs2) ? vs1 : vs2;
+        });
+      });
+  ResetInstruction();
+  SetSemanticFunction(&Vfmin);
+  BinaryOpWithFflagsFPTestHelperVV<double, double, double>(
+      "Vfmin_vv64", /*sew*/ 64, instruction_, /*delta_position*/ 64,
+      [](double vs2, double vs1) -> std::tuple<double, uint32_t> {
+        using T = double;
+        return MaxMinHelper<T>(vs2, vs1, [](T vs2, T vs1) -> T {
+          return (vs1 < vs2) ? vs1 : vs2;
+        });
+      });
+  // Vector-scalar.
+  ResetInstruction();
+  SetSemanticFunction(&Vfmin);
+  BinaryOpWithFflagsFPTestHelperVX<float, float, float>(
+      "Vfmin_vx32", /*sew*/ 32, instruction_, /*delta_position*/ 32,
+      [](float vs2, float vs1) -> std::tuple<float, uint32_t> {
+        using T = float;
+        return MaxMinHelper<T>(vs2, vs1, [](T vs2, T vs1) -> T {
+          return (vs1 < vs2) ? vs1 : vs2;
+        });
+      });
+  ResetInstruction();
+  SetSemanticFunction(&Vfmin);
+  BinaryOpWithFflagsFPTestHelperVX<double, double, double>(
+      "Vfmin_vx64", /*sew*/ 64, instruction_, /*delta_position*/ 64,
+      [](double vs2, double vs1) -> std::tuple<double, uint32_t> {
+        using T = double;
+        return MaxMinHelper<T>(vs2, vs1, [](T vs2, T vs1) -> T {
+          return (vs1 < vs2) ? vs1 : vs2;
+        });
+      });
+}
+
+// Test fp merge. Vfmerge is vector-scalar only: each destination element is
+// fs1 where the mask bit is set, vs2 otherwise.
+TEST_F(RiscVCheriotFPInstructionsTest, Vfmerge) {
+  // Vector-scalar.
+  SetSemanticFunction(&Vfmerge);
+  BinaryOpFPWithMaskTestHelperVX<float, float, float>(
+      "Vfmerge_vx32", /*sew*/ 32, instruction_, /*delta_position*/ 32,
+      [](float vs2, float vs1, bool mask) -> float {
+        return mask ? vs1 : vs2;
+      });
+  ResetInstruction();
+  SetSemanticFunction(&Vfmerge);
+  BinaryOpFPWithMaskTestHelperVX<double, double, double>(
+      "Vfmerge_vx64", /*sew*/ 64, instruction_, /*delta_position*/ 64,
+      [](double vs2, double vs1, bool mask) -> double {
+        return mask ? vs1 : vs2;
+      });
+}
+
+}  // namespace
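The sign-injection reference lambdas above all perform the same top-bit manipulation, and with both operands equal the three instructions yield the usual fmv/fneg/fabs idioms. Below is a minimal standalone sketch of that lanewise semantics; the SignInject name and SignOp enum are illustrative only (not simulator API), and std::memcpy is used rather than reinterpret_cast to stay strict-aliasing safe.

#include <cstdint>
#include <cstring>
#include <iostream>

// Scalar model of the three RVV sign-injection ops for float; the vector
// instructions apply this lanewise. Magnitude bits come from vs2; the sign
// bit comes from vs1, optionally negated or xor'ed with vs2's sign.
enum class SignOp { kInject, kNegate, kXor };

float SignInject(float vs2, float vs1, SignOp op) {
  uint32_t v2, v1;
  std::memcpy(&v2, &vs2, sizeof(v2));
  std::memcpy(&v1, &vs1, sizeof(v1));
  constexpr uint32_t kSign = 1u << 31;
  uint32_t sign = 0;
  switch (op) {
    case SignOp::kInject: sign = v1 & kSign; break;         // vfsgnj
    case SignOp::kNegate: sign = ~v1 & kSign; break;        // vfsgnjn
    case SignOp::kXor:    sign = (v1 ^ v2) & kSign; break;  // vfsgnjx
  }
  uint32_t r = (v2 & ~kSign) | sign;
  float result;
  std::memcpy(&result, &r, sizeof(result));
  return result;
}

int main() {
  float x = -3.5f;
  std::cout << SignInject(x, x, SignOp::kInject) << "\n";  // -3.5 (fmv)
  std::cout << SignInject(x, x, SignOp::kNegate) << "\n";  // 3.5  (fneg)
  std::cout << SignInject(x, x, SignOp::kXor) << "\n";     // 3.5  (fabs)
  return 0;
}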
diff --git a/cheriot/test/riscv_cheriot_vector_fp_reduction_instructions_test.cc b/cheriot/test/riscv_cheriot_vector_fp_reduction_instructions_test.cc new file mode 100644 index 0000000..ba3ad17 --- /dev/null +++ b/cheriot/test/riscv_cheriot_vector_fp_reduction_instructions_test.cc
@@ -0,0 +1,248 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cheriot/riscv_cheriot_vector_fp_reduction_instructions.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <functional>
+#include <vector>
+
+#include "absl/random/random.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "cheriot/test/riscv_cheriot_vector_fp_test_utilities.h"
+#include "cheriot/test/riscv_cheriot_vector_instructions_test_base.h"
+#include "googlemock/include/gmock/gmock.h"
+#include "mpact/sim/generic/instruction.h"
+#include "riscv//riscv_fp_host.h"
+#include "riscv//riscv_fp_info.h"
+#include "riscv//riscv_fp_state.h"
+#include "riscv//riscv_register.h"
+
+namespace {
+
+using Instruction = ::mpact::sim::generic::Instruction;
+
+// Functions to test.
+
+using ::mpact::sim::cheriot::Vfredmax;
+using ::mpact::sim::cheriot::Vfredmin;
+using ::mpact::sim::cheriot::Vfredosum;
+using ::mpact::sim::cheriot::Vfwredosum;
+
+using ::absl::Span;
+using ::mpact::sim::riscv::FPRoundingMode;
+using ::mpact::sim::riscv::RVFpRegister;
+using ::mpact::sim::riscv::ScopedFPStatus;
+
+// Test fixture for floating point reduction instructions.
+class RiscVCheriotFPReductionInstructionsTest
+    : public RiscVCheriotFPInstructionsTestBase {
+ public:
+  // Helper function for floating point reduction operations.
+  template <typename Vd, typename Vs2, typename Vs1>
+  void ReductionOpFPTestHelper(absl::string_view name, int sew,
+                               Instruction *inst, int delta_position,
+                               std::function<Vd(Vs1, Vs2)> operation) {
+    int byte_sew = sew / 8;
+    if (byte_sew != sizeof(Vd) && byte_sew != sizeof(Vs2) &&
+        byte_sew != sizeof(Vs1)) {
+      FAIL() << name << ": selected element width != any operand types"
+             << " sew: " << sew << " Vd: " << sizeof(Vd)
+             << " Vs2: " << sizeof(Vs2) << " Vs1: " << sizeof(Vs1);
+      return;
+    }
+    // Number of elements per vector register.
+    constexpr int vs2_size = kVectorLengthInBytes / sizeof(Vs2);
+    constexpr int vs1_size = kVectorLengthInBytes / sizeof(Vs1);
+    // Input values for 8 registers.
+    Vs2 vs2_value[vs2_size * 8];
+    auto vs2_span = Span<Vs2>(vs2_value);
+    Vs1 vs1_value[vs1_size * 8];
+    auto vs1_span = Span<Vs1>(vs1_value);
+    AppendVectorRegisterOperands({kVs2, kVs1, kVmask}, {kVd});
+    auto mask_span = Span<const uint8_t>(kA5Mask);
+    SetVectorRegisterValues<uint8_t>({{kVmaskName, mask_span}});
+    // Iterate across the different lmul values.
+    for (int lmul_index = 0; lmul_index < 7; lmul_index++) {
+      // Initialize input values.
+ FillArrayWithRandomFPValues<Vs2>(vs2_span); + FillArrayWithRandomFPValues<Vs1>(vs1_span); + for (int i = 0; i < 8; i++) { + auto vs2_name = absl::StrCat("v", kVs2 + i); + auto vs1_name = absl::StrCat("v", kVs1 + i); + SetVectorRegisterValues<Vs2>( + {{vs2_name, vs2_span.subspan(vs2_size * i, vs2_size)}}); + SetVectorRegisterValues<Vs1>( + {{vs1_name, vs1_span.subspan(vs1_size * i, vs1_size)}}); + } + for (int vlen_count = 0; vlen_count < 4; vlen_count++) { + int lmul8 = kLmul8Values[lmul_index]; + int lmul8_vs2 = lmul8 * sizeof(Vs2) / byte_sew; + int lmul8_vs1 = lmul8 * sizeof(Vs1) / byte_sew; + int lmul8_vd = lmul8 * sizeof(Vd) / byte_sew; + int num_values = lmul8 * kVectorLengthInBytes / (8 * byte_sew); + // Set vlen, but leave vlen high at least once. + int vlen = 1024; + if (vlen_count > 0) { + vlen = + absl::Uniform(absl::IntervalOpenClosed, bitgen_, 0, num_values); + } + num_values = std::min(num_values, vlen); + // Configure vector unit for different lmul settings. + uint32_t vtype = + (kSewSettingsByByteSize[byte_sew] << 3) | kLmulSettings[lmul_index]; + ConfigureVectorUnit(vtype, vlen); + + // Iterate across rounding modes. + for (int rm : {0, 1, 2, 3, 4}) { + rv_fp_->SetRoundingMode(static_cast<FPRoundingMode>(rm)); + + ClearVectorRegisterGroup(kVd, 8); + + inst->Execute(); + + if (lmul8_vd < 1 || lmul8_vd > 64) { + EXPECT_TRUE(rv_vector_->vector_exception()) + << "lmul8: vd: " << lmul8_vd; + rv_vector_->clear_vector_exception(); + continue; + } + + if (lmul8_vs2 < 1 || lmul8_vs2 > 64) { + EXPECT_TRUE(rv_vector_->vector_exception()) + << "lmul8: vs2: " << lmul8_vs2; + rv_vector_->clear_vector_exception(); + continue; + } + + if (lmul8_vs1 < 1 || lmul8_vs1 > 64) { + EXPECT_TRUE(rv_vector_->vector_exception()) + << "lmul8: vs1: " << lmul8_vs1; + rv_vector_->clear_vector_exception(); + continue; + } + EXPECT_FALSE(rv_vector_->vector_exception()); + // Initialize the accumulator with the value from vs1[0]. + Vd accumulator = static_cast<Vd>(vs1_span[0]); + ScopedFPStatus set_fpstatus(rv_fp_->host_fp_interface()); + for (int i = 0; i < num_values; i++) { + int mask_index = i >> 3; + int mask_offset = i & 0b111; + bool mask_value = (mask_span[mask_index] >> mask_offset) & 0b1; + if (mask_value) { + accumulator = operation(accumulator, vs2_span[i]); + } + } + auto reg_val = vreg_[kVd]->data_buffer()->Get<Vd>(0); + FPCompare<Vd>(accumulator, reg_val, delta_position, ""); + } + } + } + } +}; + +// Test vector floating point sum reduction. +TEST_F(RiscVCheriotFPReductionInstructionsTest, Vfredosum) { + SetSemanticFunction(&Vfredosum); + ReductionOpFPTestHelper<float, float, float>( + "Vfredosum_32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float val0, float val1) -> float { return val0 + val1; }); + ResetInstruction(); + SetSemanticFunction(&Vfredosum); + ReductionOpFPTestHelper<double, double, double>( + "Vfredosum_64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [](double val0, double val1) -> double { return val0 + val1; }); +} + +// Test vector floating point widening sum reduction. 
+TEST_F(RiscVCheriotFPReductionInstructionsTest, Vfwredosum) {
+  SetSemanticFunction(&Vfwredosum);
+  ReductionOpFPTestHelper<double, float, double>(
+      "Vfwredosum_32", /*sew*/ 32, instruction_, /*delta_position*/ 64,
+      [](double val0, float val1) -> double {
+        return val0 + static_cast<double>(val1);
+      });
+}
+
+template <typename T>
+T MaxMinHelper(T vs2, T vs1, std::function<T(T, T)> operation) {
+  using UInt = typename FPTypeInfo<T>::IntType;
+  UInt vs2_uint = *reinterpret_cast<UInt *>(&vs2);
+  UInt vs1_uint = *reinterpret_cast<UInt *>(&vs1);
+  UInt mask = 1ULL << (FPTypeInfo<T>::kSigSize - 1);
+  bool nan_vs2 = std::isnan(vs2);
+  bool nan_vs1 = std::isnan(vs1);
+  if ((nan_vs2 && ((mask & vs2_uint) == 0)) ||
+      (nan_vs1 && ((mask & vs1_uint) == 0)) || (nan_vs2 && nan_vs1)) {
+    // Canonical NaN.
+    UInt canonical = ((1ULL << (FPTypeInfo<T>::kExpSize + 1)) - 1)
+                     << (FPTypeInfo<T>::kSigSize - 1);
+    T canonical_fp = *reinterpret_cast<T *>(&canonical);
+    return canonical_fp;
+  }
+  if (nan_vs2) return vs1;
+  if (nan_vs1) return vs2;
+  return operation(vs2, vs1);
+}
+
+// Test vector floating point min reduction.
+TEST_F(RiscVCheriotFPReductionInstructionsTest, Vfredmin) {
+  SetSemanticFunction(&Vfredmin);
+  ReductionOpFPTestHelper<float, float, float>(
+      "Vfredmin_32", /*sew*/ 32, instruction_, /*delta_position*/ 32,
+      [](float val0, float val1) -> float {
+        return MaxMinHelper<float>(val0, val1,
+                                   [](float val0, float val1) -> float {
+                                     return (val0 > val1) ? val1 : val0;
+                                   });
+      });
+  ResetInstruction();
+  SetSemanticFunction(&Vfredmin);
+  ReductionOpFPTestHelper<double, double, double>(
+      "Vfredmin_64", /*sew*/ 64, instruction_, /*delta_position*/ 64,
+      [](double val0, double val1) -> double {
+        return MaxMinHelper<double>(val0, val1,
+                                    [](double val0, double val1) -> double {
+                                      return (val0 > val1) ? val1 : val0;
+                                    });
+      });
+}
+
+// Test vector floating point max reduction.
+TEST_F(RiscVCheriotFPReductionInstructionsTest, Vfredmax) {
+  SetSemanticFunction(&Vfredmax);
+  ReductionOpFPTestHelper<float, float, float>(
+      "Vfredmax_32", /*sew*/ 32, instruction_, /*delta_position*/ 32,
+      [](float val0, float val1) -> float {
+        return MaxMinHelper<float>(val0, val1,
+                                   [](float val0, float val1) -> float {
+                                     return (val0 < val1) ? val1 : val0;
+                                   });
+      });
+  ResetInstruction();
+  SetSemanticFunction(&Vfredmax);
+  ReductionOpFPTestHelper<double, double, double>(
+      "Vfredmax_64", /*sew*/ 64, instruction_, /*delta_position*/ 64,
+      [](double val0, double val1) -> double {
+        return MaxMinHelper<double>(val0, val1,
+                                    [](double val0, double val1) -> double {
+                                      return (val0 < val1) ? val1 : val0;
+                                    });
+      });
+}
+
+}  // namespace
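The ordered-sum tests above depend on vfredosum folding elements strictly in index order, starting from an accumulator seeded with vs1[0]; because floating point addition is not associative, an out-of-order (tree) reduction can legitimately produce a different rounded result. Below is a standalone sketch of the ordered reference model the helper's accumulation loop implements; OrderedMaskedSum is an illustrative name, not a simulator function.

#include <cstdint>
#include <iostream>
#include <vector>

// Folds the active (mask bit set) elements of vs2 into the accumulator in
// strict element order, mirroring the reduction helper's reference loop.
double OrderedMaskedSum(double acc, const std::vector<double>& vs2,
                        const std::vector<uint8_t>& mask, int vl) {
  for (int i = 0; i < vl; ++i) {
    bool active = (mask[i >> 3] >> (i & 0b111)) & 0b1;
    if (active) acc += vs2[i];
  }
  return acc;
}

int main() {
  // Ordered evaluation loses the first +1.0 against 1e16, then cancels:
  // (((0 + 1e16) + 1.0) - 1e16) + 1.0 == 1.0, whereas a tree reduction
  // (1e16 - 1e16) + (1.0 + 1.0) would give 2.0.
  std::vector<double> vs2 = {1e16, 1.0, -1e16, 1.0};
  std::vector<uint8_t> mask = {0x0f};  // All four elements active.
  std::cout << OrderedMaskedSum(0.0, vs2, mask, 4) << "\n";  // Prints 1.
  return 0;
}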
diff --git a/cheriot/test/riscv_cheriot_vector_fp_test_utilities.h b/cheriot/test/riscv_cheriot_vector_fp_test_utilities.h new file mode 100644 index 0000000..6aba98e --- /dev/null +++ b/cheriot/test/riscv_cheriot_vector_fp_test_utilities.h
@@ -0,0 +1,948 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MPACT_CHERIOT_TEST_RISCV_CHERIOT_VECTOR_FP_TEST_UTILITIES_H_ +#define MPACT_CHERIOT_TEST_RISCV_CHERIOT_VECTOR_FP_TEST_UTILITIES_H_ + +#include <algorithm> +#include <cmath> +#include <cstdint> +#include <functional> +#include <tuple> +#include <type_traits> +#include <vector> + +#include "absl/random/random.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "cheriot/test/riscv_cheriot_vector_instructions_test_base.h" +#include "googlemock/include/gmock/gmock.h" +#include "mpact/sim/generic/type_helpers.h" +#include "riscv//riscv_fp_host.h" +#include "riscv//riscv_fp_info.h" +#include "riscv//riscv_fp_state.h" +#include "riscv//riscv_state.h" + +using ::mpact::sim::generic::operator*; +using ::mpact::sim::riscv::FPRoundingMode; +using ::mpact::sim::riscv::ScopedFPStatus; + +constexpr char kFs1Name[] = "f4"; +constexpr int kFs1 = 4; + +// Templated helper structs to provide information about floating point types. +template <typename T> +struct FPTypeInfo { + using IntType = typename std::make_unsigned<T>::type; + static const int kBitSize = 8 * sizeof(T); + static const int kExpSize = 0; + static const int kSigSize = 0; + static bool IsNaN(T value) { return false; } +}; + +template <> +struct FPTypeInfo<float> { + using T = float; + using IntType = uint32_t; + static const int kBitSize = sizeof(float) << 3; + static const int kExpSize = 8; + static const int kSigSize = kBitSize - kExpSize - 1; + static const IntType kExpMask = ((1ULL << kExpSize) - 1) << kSigSize; + static const IntType kSigMask = (1ULL << kSigSize) - 1; + static const IntType kQNaN = kExpMask | (1ULL << (kSigSize - 1)) | 1; + static const IntType kSNaN = kExpMask | 1; + static const IntType kPosInf = kExpMask; + static const IntType kNegInf = kExpMask | (1ULL << (kBitSize - 1)); + static const IntType kPosZero = 0; + static const IntType kNegZero = 1ULL << (kBitSize - 1); + static const IntType kPosDenorm = 1ULL << (kSigSize - 2); + static const IntType kNegDenorm = + (1ULL << (kBitSize - 1)) | (1ULL << (kSigSize - 2)); + static const IntType kCanonicalNaN = 0x7fc0'0000ULL; + static bool IsNaN(T value) { return std::isnan(value); } +}; + +template <> +struct FPTypeInfo<double> { + using T = double; + using IntType = uint64_t; + static const int kBitSize = sizeof(double) << 3; + static const int kExpSize = 11; + static const int kSigSize = kBitSize - kExpSize - 1; + static const IntType kExpMask = ((1ULL << kExpSize) - 1) << kSigSize; + static const IntType kSigMask = (1ULL << kSigSize) - 1; + static const IntType kQNaN = kExpMask | (1ULL << (kSigSize - 1)) | 1; + static const IntType kSNaN = kExpMask | 1; + static const IntType kPosInf = kExpMask; + static const IntType kNegInf = kExpMask | (1ULL << (kBitSize - 1)); + static const IntType kPosZero = 0; + static const IntType kNegZero = 1ULL << (kBitSize - 1); + static const IntType kPosDenorm = 
1ULL << (kSigSize - 2); + static const IntType kNegDenorm = + (1ULL << (kBitSize - 1)) | (1ULL << (kSigSize - 2)); + static const IntType kCanonicalNaN = 0x7ff8'0000'0000'0000ULL; + static bool IsNaN(T value) { return std::isnan(value); } +}; + +// These templated functions allow for comparison of values with a tolerance +// given for floating point types. The tolerance is stated as the bit position +// in the mantissa of the op, with 0 being the msb of the mantissa. If the +// bit position is beyond the mantissa, a comparison of equal is performed. +template <typename T> +inline void FPCompare(T op, T reg, int, absl::string_view str) { + EXPECT_EQ(reg, op) << str; +} + +template <> +inline void FPCompare<float>(float op, float reg, int delta_position, + absl::string_view str) { + using T = float; + using UInt = typename FPTypeInfo<T>::IntType; + if (!std::isnan(op) && !std::isinf(op) && + delta_position < FPTypeInfo<T>::kSigSize) { + T delta; + UInt exp = FPTypeInfo<T>::kExpMask >> FPTypeInfo<T>::kSigSize; + if (exp > delta_position) { + exp -= delta_position; + UInt udelta = exp << FPTypeInfo<T>::kSigSize; + delta = *reinterpret_cast<T *>(&udelta); + } else { + // Becomes a denormal + int diff = delta_position - exp; + UInt udelta = 1ULL << (FPTypeInfo<T>::kSigSize - 1 - diff); + delta = *reinterpret_cast<T *>(&udelta); + } + EXPECT_THAT(reg, testing::NanSensitiveFloatNear(op, delta)) << str; + } else { + EXPECT_THAT(reg, testing::NanSensitiveFloatEq(op)) << str; + } +} + +template <> +inline void FPCompare<double>(double op, double reg, int delta_position, + absl::string_view str) { + using T = double; + using UInt = typename FPTypeInfo<T>::IntType; + if (!std::isnan(op) && !std::isinf(op) && + delta_position < FPTypeInfo<T>::kSigSize) { + T delta; + UInt exp = FPTypeInfo<T>::kExpMask >> FPTypeInfo<T>::kSigSize; + if (exp > delta_position) { + exp -= delta_position; + UInt udelta = exp << FPTypeInfo<T>::kSigSize; + delta = *reinterpret_cast<T *>(&udelta); + } else { + // Becomes a denormal + int diff = delta_position - exp; + UInt udelta = 1ULL << (FPTypeInfo<T>::kSigSize - 1 - diff); + delta = *reinterpret_cast<T *>(&udelta); + } + EXPECT_THAT(reg, testing::NanSensitiveDoubleNear(op, delta)) << str; + } else { + EXPECT_THAT(reg, testing::NanSensitiveDoubleEq(op)) << str; + } +} + +template <typename FP> +FP OptimizationBarrier(FP op) { + asm volatile("" : "+X"(op)); + return op; +} + +namespace internal { + +// These are predicates used in the following NaNBox function definitions, as +// part of the enable_if construct. +template <typename S, typename D> +struct EqualSize { + static const bool value = sizeof(S) == sizeof(D) && + std::is_floating_point<S>::value && + std::is_integral<D>::value; +}; + +template <typename S, typename D> +struct GreaterSize { + static const bool value = + sizeof(S) > sizeof(D) && + std::is_floating_point<S>::value &&std::is_integral<D>::value; +}; + +template <typename S, typename D> +struct LessSize { + static const bool value = sizeof(S) < sizeof(D) && + std::is_floating_point<S>::value && + std::is_integral<D>::value; +}; + +} // namespace internal + +// Template functions to NaN box a floating point value when being assigned +// to a wider register. The first version places a smaller floating point value +// in a NaN box (all upper bits in the word are set to 1). + +// Enable_if is used to select the proper implementation for different S and D +// type combinations. 
It uses the SFINAE (substitution failure is not an error)
+// "feature" of C++ to hide the implementations that don't match the predicate
+// from overload resolution.
+
+template <typename S, typename D>
+inline typename std::enable_if<internal::LessSize<S, D>::value, D>::type NaNBox(
+    S value) {
+  using SInt = typename FPTypeInfo<S>::IntType;
+  SInt sval = *reinterpret_cast<SInt *>(&value);
+  D dval = (~static_cast<D>(0) << (sizeof(S) * 8)) | sval;
+  return *reinterpret_cast<D *>(&dval);
+}
+
+// This version does a straight copy, as the data types are the same size.
+template <typename S, typename D>
+inline typename std::enable_if<internal::EqualSize<S, D>::value, D>::type
+NaNBox(S value) {
+  return *reinterpret_cast<D *>(&value);
+}
+
+// Signal an error if the register is smaller than the floating point value.
+template <typename S, typename D>
+inline typename std::enable_if<internal::GreaterSize<S, D>::value, D>::type
+NaNBox(S value) {
+  // No return statement, so an error will be reported.
+}
+
+// Test fixture base class for floating point vector instructions.
+class RiscVCheriotFPInstructionsTestBase
+    : public RiscVCheriotVectorInstructionsTestBase {
+ public:
+  RiscVCheriotFPInstructionsTestBase() {
+    rv_fp_ = new mpact::sim::riscv::RiscVFPState(state_->csr_set(), state_);
+    state_->set_rv_fp(rv_fp_);
+  }
+  ~RiscVCheriotFPInstructionsTestBase() override {
+    state_->set_rv_fp(nullptr);
+    delete rv_fp_;
+  }
+
+  // Construct a random FP value by separately generating integer values for
+  // sign, exponent and mantissa.
+  template <typename T>
+  T RandomFPValue() {
+    using UInt = typename FPTypeInfo<T>::IntType;
+    UInt sign = absl::Uniform(absl::IntervalClosed, bitgen_, 0ULL, 1ULL);
+    UInt exp = absl::Uniform(absl::IntervalClosedOpen, bitgen_, 0ULL,
+                             1ULL << FPTypeInfo<T>::kExpSize);
+    UInt sig = absl::Uniform(absl::IntervalClosedOpen, bitgen_, 0ULL,
+                             1ULL << FPTypeInfo<T>::kSigSize);
+    UInt value = (sign & 1) << (FPTypeInfo<T>::kBitSize - 1) |
+                 (exp << FPTypeInfo<T>::kSigSize) | sig;
+    T val = *reinterpret_cast<T *>(&value);
+    return val;
+  }
+
+  // This method uses random values for each field in the fp number.
+  template <typename T>
+  void FillArrayWithRandomFPValues(absl::Span<T> span) {
+    for (auto &val : span) {
+      val = RandomFPValue<T>();
+    }
+  }
+
+  // Initializes the vs2 and vs1 operand spans with random values, then
+  // overwrites the leading elements with FP special values. The `count`
+  // argument (the lmul index at the call sites) selects how the special
+  // values are arranged in vs1.
+  template <typename Vs2, typename Vs1>
+  void InitializeInputs(absl::Span<Vs2> vs2_span, absl::Span<Vs1> vs1_span,
+                        absl::Span<uint8_t> mask_span, int count) {
+    // Initialize input values.
+    FillArrayWithRandomFPValues<Vs2>(vs2_span);
+    FillArrayWithRandomFPValues<Vs1>(vs1_span);
+    using Vs2Int = typename FPTypeInfo<Vs2>::IntType;
+    using Vs1Int = typename FPTypeInfo<Vs1>::IntType;
+    // Overwrite the first few values of the input data with infinities,
+    // zeros, denormals and NaNs.
+ *reinterpret_cast<Vs2Int *>(&vs2_span[0]) = FPTypeInfo<Vs2>::kQNaN; + *reinterpret_cast<Vs2Int *>(&vs2_span[1]) = FPTypeInfo<Vs2>::kSNaN; + *reinterpret_cast<Vs2Int *>(&vs2_span[2]) = FPTypeInfo<Vs2>::kPosInf; + *reinterpret_cast<Vs2Int *>(&vs2_span[3]) = FPTypeInfo<Vs2>::kNegInf; + *reinterpret_cast<Vs2Int *>(&vs2_span[4]) = FPTypeInfo<Vs2>::kPosZero; + *reinterpret_cast<Vs2Int *>(&vs2_span[5]) = FPTypeInfo<Vs2>::kNegZero; + *reinterpret_cast<Vs2Int *>(&vs2_span[6]) = FPTypeInfo<Vs2>::kPosDenorm; + *reinterpret_cast<Vs2Int *>(&vs2_span[7]) = FPTypeInfo<Vs2>::kNegDenorm; + if (count == 4) { + *reinterpret_cast<Vs1Int *>(&vs1_span[0]) = FPTypeInfo<Vs1>::kQNaN; + *reinterpret_cast<Vs1Int *>(&vs1_span[1]) = FPTypeInfo<Vs1>::kSNaN; + *reinterpret_cast<Vs1Int *>(&vs1_span[2]) = FPTypeInfo<Vs1>::kPosInf; + *reinterpret_cast<Vs1Int *>(&vs1_span[3]) = FPTypeInfo<Vs1>::kNegInf; + *reinterpret_cast<Vs1Int *>(&vs1_span[4]) = FPTypeInfo<Vs1>::kPosZero; + *reinterpret_cast<Vs1Int *>(&vs1_span[5]) = FPTypeInfo<Vs1>::kNegZero; + *reinterpret_cast<Vs1Int *>(&vs1_span[6]) = FPTypeInfo<Vs1>::kPosDenorm; + *reinterpret_cast<Vs1Int *>(&vs1_span[7]) = FPTypeInfo<Vs1>::kNegDenorm; + } else if (count == 5) { + *reinterpret_cast<Vs1Int *>(&vs1_span[7]) = FPTypeInfo<Vs1>::kQNaN; + *reinterpret_cast<Vs1Int *>(&vs1_span[6]) = FPTypeInfo<Vs1>::kSNaN; + *reinterpret_cast<Vs1Int *>(&vs1_span[5]) = FPTypeInfo<Vs1>::kPosInf; + *reinterpret_cast<Vs1Int *>(&vs1_span[4]) = FPTypeInfo<Vs1>::kNegInf; + *reinterpret_cast<Vs1Int *>(&vs1_span[3]) = FPTypeInfo<Vs1>::kPosZero; + *reinterpret_cast<Vs1Int *>(&vs1_span[2]) = FPTypeInfo<Vs1>::kNegZero; + *reinterpret_cast<Vs1Int *>(&vs1_span[1]) = FPTypeInfo<Vs1>::kPosDenorm; + *reinterpret_cast<Vs1Int *>(&vs1_span[0]) = FPTypeInfo<Vs1>::kNegDenorm; + } else if (count == 6) { + *reinterpret_cast<Vs1Int *>(&vs1_span[0]) = FPTypeInfo<Vs1>::kQNaN; + *reinterpret_cast<Vs1Int *>(&vs1_span[1]) = FPTypeInfo<Vs1>::kSNaN; + *reinterpret_cast<Vs1Int *>(&vs1_span[2]) = FPTypeInfo<Vs1>::kNegInf; + *reinterpret_cast<Vs1Int *>(&vs1_span[3]) = FPTypeInfo<Vs1>::kPosInf; + *reinterpret_cast<Vs1Int *>(&vs1_span[4]) = FPTypeInfo<Vs1>::kNegZero; + *reinterpret_cast<Vs1Int *>(&vs1_span[5]) = FPTypeInfo<Vs1>::kPosZero; + *reinterpret_cast<Vs1Int *>(&vs1_span[6]) = FPTypeInfo<Vs1>::kNegDenorm; + *reinterpret_cast<Vs1Int *>(&vs1_span[7]) = FPTypeInfo<Vs1>::kPosDenorm; + } + // Modify the first mask bits to use each of the special floating + // point values. + mask_span[0] = 0xff; + } + + // Floating point test needs to ensure to use the fp special values (inf, + // NaN etc.) during testing, not just random values. + template <typename Vd, typename Vs2, typename Vs1> + void BinaryOpFPTestHelperVV(absl::string_view name, int sew, + Instruction *inst, int delta_position, + std::function<Vd(Vs2, Vs1)> operation) { + int byte_sew = sew / 8; + if (byte_sew != sizeof(Vd) && byte_sew != sizeof(Vs2) && + byte_sew != sizeof(Vs1)) { + FAIL() << name << ": selected element width != any operand types" + << " sew: " << sew << " Vd: " << sizeof(Vd) + << " Vs2: " << sizeof(Vs2) << " Vs1: " << sizeof(Vs1); + return; + } + // Number of elements per vector register. + constexpr int vs2_size = kVectorLengthInBytes / sizeof(Vs2); + constexpr int vs1_size = kVectorLengthInBytes / sizeof(Vs1); + // Input values for 8 registers. 
+ Vs2 vs2_value[vs2_size * 8]; + Vs1 vs1_value[vs1_size * 8]; + auto vs2_span = Span<Vs2>(vs2_value); + auto vs1_span = Span<Vs1>(vs1_value); + AppendVectorRegisterOperands({kVs2, kVs1, kVmask}, {kVd}); + SetVectorRegisterValues<uint8_t>( + {{kVmaskName, Span<const uint8_t>(kA5Mask)}}); + // Iterate across different lmul values. + for (int lmul_index = 0; lmul_index < 7; lmul_index++) { + InitializeInputs<Vs2, Vs1>(vs2_span, vs1_span, + vreg_[kVmask]->data_buffer()->Get<uint8_t>(), + lmul_index); + // Set values for all 8 vector registers in the vector register group. + for (int i = 0; i < 8; i++) { + auto vs2_name = absl::StrCat("v", kVs2 + i); + auto vs1_name = absl::StrCat("v", kVs1 + i); + SetVectorRegisterValues<Vs2>( + {{vs2_name, vs2_span.subspan(vs2_size * i, vs2_size)}}); + SetVectorRegisterValues<Vs1>( + {{vs1_name, vs1_span.subspan(vs1_size * i, vs1_size)}}); + } + int lmul8 = kLmul8Values[lmul_index]; + int lmul8_vd = lmul8 * sizeof(Vd) / byte_sew; + int lmul8_vs2 = lmul8 * sizeof(Vs2) / byte_sew; + int lmul8_vs1 = lmul8 * sizeof(Vs1) / byte_sew; + int num_reg_values = lmul8 * kVectorLengthInBytes / (8 * byte_sew); + // Configure vector unit for different lmul settings. + uint32_t vtype = + (kSewSettingsByByteSize[byte_sew] << 3) | kLmulSettings[lmul_index]; + int vstart = 0; + // Try different vstart values (updated at the bottom of the loop). + for (int vstart_count = 0; vstart_count < 4; vstart_count++) { + int vlen = 1024; + // Try different vector lengths (updated at the bottom of the loop). + for (int vlen_count = 0; vlen_count < 4; vlen_count++) { + ASSERT_TRUE(vlen > vstart); + int num_values = std::min(num_reg_values, vlen); + ConfigureVectorUnit(vtype, vlen); + // Iterate across rounding modes. + for (int rm : {0, 1, 2, 3, 4}) { + rv_fp_->SetRoundingMode(static_cast<FPRoundingMode>(rm)); + rv_vector_->set_vstart(vstart); + ClearVectorRegisterGroup(kVd, 8); + + inst->Execute(); + if (lmul8_vd < 1 || lmul8_vd > 64) { + EXPECT_TRUE(rv_vector_->vector_exception()) + << "lmul8: vd: " << lmul8_vd; + rv_vector_->clear_vector_exception(); + continue; + } + + if (lmul8_vs2 < 1 || lmul8_vs2 > 64) { + EXPECT_TRUE(rv_vector_->vector_exception()) + << "lmul8: vs2: " << lmul8_vs2; + rv_vector_->clear_vector_exception(); + continue; + } + + if (lmul8_vs1 < 1 || lmul8_vs1 > 64) { + EXPECT_TRUE(rv_vector_->vector_exception()) + << "lmul8: vs1: " << lmul8_vs1; + rv_vector_->clear_vector_exception(); + continue; + } + EXPECT_FALSE(rv_vector_->vector_exception()); + EXPECT_EQ(rv_vector_->vstart(), 0); + int count = 0; + for (int reg = kVd; reg < kVd + 8; reg++) { + for (int i = 0; i < kVectorLengthInBytes / sizeof(Vd); i++) { + int mask_index = count >> 3; + int mask_offset = count & 0b111; + bool mask_value = true; + // The first 8 bits of the mask are set to true above, so + // only read the mask value after the first byte. 
+ if (mask_index > 0) { + mask_value = + ((kA5Mask[mask_index] >> mask_offset) & 0b1) != 0; + } + auto reg_val = vreg_[reg]->data_buffer()->Get<Vd>(i); + auto int_reg_val = + *reinterpret_cast<typename FPTypeInfo<Vd>::IntType *>( + ®_val); + if ((count >= vstart) && mask_value && (count < num_values)) { + ScopedFPStatus set_fpstatus(rv_fp_->host_fp_interface()); + auto op_val = operation(vs2_value[count], vs1_value[count]); + auto int_op_val = + *reinterpret_cast<typename FPTypeInfo<Vd>::IntType *>( + &op_val); + auto int_vs2_val = + *reinterpret_cast<typename FPTypeInfo<Vs2>::IntType *>( + &vs2_value[count]); + auto int_vs1_val = + *reinterpret_cast<typename FPTypeInfo<Vs1>::IntType *>( + &vs1_value[count]); + FPCompare<Vd>( + op_val, reg_val, delta_position, + absl::StrCat(name, "[", count, "] op(", vs2_value[count], + "[0x", absl::Hex(int_vs2_val), "], ", + vs1_value[count], "[0x", + absl::Hex(int_vs1_val), + "]) = ", absl::Hex(int_op_val), " != reg[", + reg, "][", i, "] (", reg_val, " [0x", + absl::Hex(int_reg_val), "]) lmul8(", lmul8, + ") rm = ", *(rv_fp_->GetRoundingMode()))); + } else { + EXPECT_EQ(0, reg_val) << absl::StrCat( + name, " 0 != reg[", reg, "][", i, "] (", reg_val, + " [0x", absl::Hex(int_reg_val), "]) lmul8(", lmul8, ")"); + } + count++; + } + if (HasFailure()) return; + } + } + vlen = absl::Uniform(absl::IntervalOpenClosed, bitgen_, vstart, + num_reg_values); + } + vstart = absl::Uniform(absl::IntervalOpen, bitgen_, 0, num_reg_values); + } + } + } + + // Floating point test needs to ensure to use the fp special values (inf, + // NaN etc.) during testing, not just random values. + template <typename Vd, typename Vs2, typename Vs1> + void BinaryOpWithFflagsFPTestHelperVV( + absl::string_view name, int sew, Instruction *inst, int delta_position, + std::function<std::tuple<Vd, uint32_t>(Vs2, Vs1)> operation) { + using VdInt = typename FPTypeInfo<Vd>::IntType; + using Vs2Int = typename FPTypeInfo<Vs2>::IntType; + using Vs1Int = typename FPTypeInfo<Vs1>::IntType; + int byte_sew = sew / 8; + if (byte_sew != sizeof(Vd) && byte_sew != sizeof(Vs2) && + byte_sew != sizeof(Vs1)) { + FAIL() << name << ": selected element width != any operand types" + << " sew: " << sew << " Vd: " << sizeof(Vd) + << " Vs2: " << sizeof(Vs2) << " Vs1: " << sizeof(Vs1); + return; + } + // Number of elements per vector register. + constexpr int vs2_size = kVectorLengthInBytes / sizeof(Vs2); + constexpr int vs1_size = kVectorLengthInBytes / sizeof(Vs1); + // Input values for 8 registers. + Vs2 vs2_value[vs2_size * 8]; + Vs1 vs1_value[vs1_size * 8]; + auto vs2_span = Span<Vs2>(vs2_value); + auto vs1_span = Span<Vs1>(vs1_value); + AppendVectorRegisterOperands({kVs2, kVs1, kVmask}, {kVd}); + auto *flag_op = rv_fp_->fflags()->CreateSetDestinationOperand(0, "fflags"); + instruction_->AppendDestination(flag_op); + SetVectorRegisterValues<uint8_t>( + {{kVmaskName, Span<const uint8_t>(kA5Mask)}}); + // Iterate across different lmul values. + for (int lmul_index = 0; lmul_index < 7; lmul_index++) { + InitializeInputs<Vs2, Vs1>(vs2_span, vs1_span, + vreg_[kVmask]->data_buffer()->Get<uint8_t>(), + lmul_index); + + // Set values for all 8 vector registers in the vector register group. 
+ for (int i = 0; i < 8; i++) { + auto vs2_name = absl::StrCat("v", kVs2 + i); + auto vs1_name = absl::StrCat("v", kVs1 + i); + SetVectorRegisterValues<Vs2>( + {{vs2_name, vs2_span.subspan(vs2_size * i, vs2_size)}}); + SetVectorRegisterValues<Vs1>( + {{vs1_name, vs1_span.subspan(vs1_size * i, vs1_size)}}); + } + int lmul8 = kLmul8Values[lmul_index]; + int lmul8_vd = lmul8 * sizeof(Vd) / byte_sew; + int lmul8_vs2 = lmul8 * sizeof(Vs2) / byte_sew; + int lmul8_vs1 = lmul8 * sizeof(Vs1) / byte_sew; + int num_reg_values = lmul8 * kVectorLengthInBytes / (8 * byte_sew); + // Configure vector unit for different lmul settings. + uint32_t vtype = + (kSewSettingsByByteSize[byte_sew] << 3) | kLmulSettings[lmul_index]; + int vstart = 0; + // Try different vstart values (updated at the bottom of the loop). + for (int vstart_count = 0; vstart_count < 4; vstart_count++) { + int vlen = 1024; + // Try different vector lengths (updated at the bottom of the loop). + for (int vlen_count = 0; vlen_count < 4; vlen_count++) { + ASSERT_TRUE(vlen > vstart); + int num_values = std::min(num_reg_values, vlen); + ConfigureVectorUnit(vtype, vlen); + // Iterate across rounding modes. + for (int rm : {0, 1, 2, 3, 4}) { + rv_fp_->SetRoundingMode(static_cast<FPRoundingMode>(rm)); + rv_vector_->set_vstart(vstart); + ClearVectorRegisterGroup(kVd, 8); + rv_fp_->fflags()->Write(static_cast<uint32_t>(0)); + + inst->Execute(); + if (lmul8_vd < 1 || lmul8_vd > 64) { + EXPECT_TRUE(rv_vector_->vector_exception()) + << "lmul8: vd: " << lmul8_vd; + rv_vector_->clear_vector_exception(); + continue; + } + + if (lmul8_vs2 < 1 || lmul8_vs2 > 64) { + EXPECT_TRUE(rv_vector_->vector_exception()) + << "lmul8: vs2: " << lmul8_vs2; + rv_vector_->clear_vector_exception(); + continue; + } + + if (lmul8_vs1 < 1 || lmul8_vs1 > 64) { + EXPECT_TRUE(rv_vector_->vector_exception()) + << "lmul8: vs1: " << lmul8_vs1; + rv_vector_->clear_vector_exception(); + continue; + } + EXPECT_FALSE(rv_vector_->vector_exception()); + EXPECT_EQ(rv_vector_->vstart(), 0); + int count = 0; + uint32_t fflags_test = 0; + for (int reg = kVd; reg < kVd + 8; reg++) { + for (int i = 0; i < kVectorLengthInBytes / sizeof(Vd); i++) { + int mask_index = count >> 3; + int mask_offset = count & 0b111; + bool mask_value = true; + // The first 8 bits of the mask are set to true above, so + // only read the mask value after the first byte. 
+ if (mask_index > 0) { + mask_value = + ((kA5Mask[mask_index] >> mask_offset) & 0b1) != 0; + } + auto reg_val = vreg_[reg]->data_buffer()->Get<Vd>(i); + auto int_reg_val = *reinterpret_cast<VdInt *>(®_val); + if ((count >= vstart) && mask_value && (count < num_values)) { + Vd op_val; + uint32_t flag; + { + ScopedFPStatus set_fpstatus(rv_fp_->host_fp_interface()); + auto [op_val_tmp, flag_tmp] = + operation(vs2_value[count], vs1_value[count]); + op_val = op_val_tmp; + flag = flag_tmp; + } + auto int_op_val = *reinterpret_cast<VdInt *>(&op_val); + auto int_vs2_val = + *reinterpret_cast<Vs2Int *>(&vs2_value[count]); + auto int_vs1_val = + *reinterpret_cast<Vs1Int *>(&vs1_value[count]); + FPCompare<Vd>( + op_val, reg_val, delta_position, + absl::StrCat(name, "[", count, "] op(", vs2_value[count], + "[0x", absl::Hex(int_vs2_val), "], ", + vs1_value[count], "[0x", + absl::Hex(int_vs1_val), + "]) = ", absl::Hex(int_op_val), " != reg[", + reg, "][", i, "] (", reg_val, " [0x", + absl::Hex(int_reg_val), "]) lmul8(", lmul8, + ") rm = ", *(rv_fp_->GetRoundingMode()))); + fflags_test |= flag; + } else { + EXPECT_EQ(0, reg_val) << absl::StrCat( + name, " 0 != reg[", reg, "][", i, "] (", reg_val, + " [0x", absl::Hex(int_reg_val), "]) lmul8(", lmul8, ")"); + } + count++; + } + } + uint32_t fflags = rv_fp_->fflags()->AsUint32(); + EXPECT_EQ(fflags, fflags_test) << name; + if (HasFailure()) return; + } + vlen = absl::Uniform(absl::IntervalOpenClosed, bitgen_, vstart, + num_reg_values); + } + vstart = absl::Uniform(absl::IntervalOpen, bitgen_, 0, num_reg_values); + } + } + } + + // Floating point test needs to ensure to use the fp special values (inf, + // NaN etc.) during testing, not just random values. + template <typename Vd, typename Vs2, typename Fs1> + void BinaryOpWithFflagsFPTestHelperVX( + absl::string_view name, int sew, Instruction *inst, int delta_position, + std::function<std::tuple<Vd, uint32_t>(Vs2, Fs1)> operation) { + using VdInt = typename FPTypeInfo<Vd>::IntType; + using Vs2Int = typename FPTypeInfo<Vs2>::IntType; + using Fs1Int = typename FPTypeInfo<Fs1>::IntType; + int byte_sew = sew / 8; + if (byte_sew != sizeof(Vd) && byte_sew != sizeof(Vs2) && + byte_sew != sizeof(Fs1)) { + FAIL() << name << ": selected element width != any operand types" + << " sew: " << sew << " Vd: " << sizeof(Vd) + << " Vs2: " << sizeof(Vs2) << " Fs1: " << sizeof(Fs1); + return; + } + // Number of elements per vector register. + constexpr int vs2_size = kVectorLengthInBytes / sizeof(Vs2); + // Input values for 8 registers. + Vs2 vs2_value[vs2_size * 8]; + auto vs2_span = Span<Vs2>(vs2_value); + AppendVectorRegisterOperands({kVs2}, {kVd}); + AppendRegisterOperands({kFs1Name}, {}); + auto *flag_op = rv_fp_->fflags()->CreateSetDestinationOperand(0, "fflags"); + instruction_->AppendDestination(flag_op); + AppendVectorRegisterOperands({kVmask}, {}); + SetVectorRegisterValues<uint8_t>( + {{kVmaskName, Span<const uint8_t>(kA5Mask)}}); + // Iterate across different lmul values. + for (int lmul_index = 0; lmul_index < 7; lmul_index++) { + // Initialize input values. + FillArrayWithRandomFPValues<Vs2>(vs2_span); + // Overwrite the first few values of the input data with infinities, + // zeros, denormals and NaNs. 
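+      // Elements [0..7] get qNaN, sNaN, +inf, -inf, +0, -0, and the
+      // positive and negative denormals, in that order, so each run covers
+      // every special-value code path.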
+ *reinterpret_cast<Vs2Int *>(&vs2_span[0]) = FPTypeInfo<Vs2>::kQNaN; + *reinterpret_cast<Vs2Int *>(&vs2_span[1]) = FPTypeInfo<Vs2>::kSNaN; + *reinterpret_cast<Vs2Int *>(&vs2_span[2]) = FPTypeInfo<Vs2>::kPosInf; + *reinterpret_cast<Vs2Int *>(&vs2_span[3]) = FPTypeInfo<Vs2>::kNegInf; + *reinterpret_cast<Vs2Int *>(&vs2_span[4]) = FPTypeInfo<Vs2>::kPosZero; + *reinterpret_cast<Vs2Int *>(&vs2_span[5]) = FPTypeInfo<Vs2>::kNegZero; + *reinterpret_cast<Vs2Int *>(&vs2_span[6]) = FPTypeInfo<Vs2>::kPosDenorm; + *reinterpret_cast<Vs2Int *>(&vs2_span[7]) = FPTypeInfo<Vs2>::kNegDenorm; + // Modify the first mask bits to use each of the special floating + // point values. + vreg_[kVmask]->data_buffer()->Set<uint8_t>(0, 0xff); + + // Set values for all 8 vector registers in the vector register group. + for (int i = 0; i < 8; i++) { + auto vs2_name = absl::StrCat("v", kVs2 + i); + SetVectorRegisterValues<Vs2>( + {{vs2_name, vs2_span.subspan(vs2_size * i, vs2_size)}}); + } + int lmul8 = kLmul8Values[lmul_index]; + int lmul8_vd = lmul8 * sizeof(Vd) / byte_sew; + int lmul8_vs2 = lmul8 * sizeof(Vs2) / byte_sew; + int lmul8_vs1 = lmul8 * sizeof(Fs1) / byte_sew; + int num_reg_values = lmul8 * kVectorLengthInBytes / (8 * byte_sew); + // Configure vector unit for different lmul settings. + uint32_t vtype = + (kSewSettingsByByteSize[byte_sew] << 3) | kLmulSettings[lmul_index]; + int vstart = 0; + // Try different vstart values (updated at the bottom of the loop). + for (int vstart_count = 0; vstart_count < 4; vstart_count++) { + int vlen = 1024; + // Try different vector lengths (updated at the bottom of the loop). + for (int vlen_count = 0; vlen_count < 4; vlen_count++) { + ASSERT_TRUE(vlen > vstart); + int num_values = std::min(num_reg_values, vlen); + ConfigureVectorUnit(vtype, vlen); + // Generate a new rs1 value. + Fs1 fs1_value = RandomFPValue<Fs1>(); + // Need to NaN box the value, that is, if the register value type + // is wider than the data type for a floating point value, the + // upper bits are all set to 1's. + typename RVFpRegister::ValueType fs1_reg_value = + NaNBox<Fs1, typename RVFpRegister::ValueType>(fs1_value); + SetRegisterValues<typename RVFpRegister::ValueType, RVFpRegister>( + {{kFs1Name, fs1_reg_value}}); + // Iterate across rounding modes. + for (int rm : {0, 1, 2, 3, 4}) { + rv_fp_->SetRoundingMode(static_cast<FPRoundingMode>(rm)); + rv_vector_->set_vstart(vstart); + ClearVectorRegisterGroup(kVd, 8); + rv_fp_->fflags()->Write(static_cast<uint32_t>(0)); + + inst->Execute(); + if (lmul8_vd < 1 || lmul8_vd > 64) { + EXPECT_TRUE(rv_vector_->vector_exception()) + << "lmul8: vd: " << lmul8_vd; + rv_vector_->clear_vector_exception(); + continue; + } + + if (lmul8_vs2 < 1 || lmul8_vs2 > 64) { + EXPECT_TRUE(rv_vector_->vector_exception()) + << "lmul8: vs2: " << lmul8_vs2; + rv_vector_->clear_vector_exception(); + continue; + } + + if (lmul8_vs1 < 1 || lmul8_vs1 > 64) { + EXPECT_TRUE(rv_vector_->vector_exception()) + << "lmul8: vs1: " << lmul8_vs1; + rv_vector_->clear_vector_exception(); + continue; + } + EXPECT_FALSE(rv_vector_->vector_exception()); + EXPECT_EQ(rv_vector_->vstart(), 0); + int count = 0; + uint32_t fflags_test = 0; + for (int reg = kVd; reg < kVd + 8; reg++) { + for (int i = 0; i < kVectorLengthInBytes / sizeof(Vd); i++) { + int mask_index = count >> 3; + int mask_offset = count & 0b111; + bool mask_value = true; + // The first 8 bits of the mask are set to true above, so + // only read the mask value after the first byte. 
+ if (mask_index > 0) { + mask_value = + ((kA5Mask[mask_index] >> mask_offset) & 0b1) != 0; + } + auto reg_val = vreg_[reg]->data_buffer()->Get<Vd>(i); + auto int_reg_val = *reinterpret_cast<VdInt *>(®_val); + if ((count >= vstart) && mask_value && (count < num_values)) { + Vd op_val; + uint32_t flag; + { + ScopedFPStatus set_fpstatus(rv_fp_->host_fp_interface()); + auto [op_val_tmp, flag_tmp] = + operation(vs2_value[count], fs1_value); + op_val = op_val_tmp; + flag = flag_tmp; + } + auto int_op_val = *reinterpret_cast<VdInt *>(&op_val); + auto int_vs2_val = + *reinterpret_cast<Vs2Int *>(&vs2_value[count]); + auto int_fs1_val = *reinterpret_cast<Fs1Int *>(&fs1_value); + FPCompare<Vd>( + op_val, reg_val, delta_position, + absl::StrCat(name, "[", count, "] op(", vs2_value[count], + "[0x", absl::Hex(int_vs2_val), "], ", + fs1_value, "[0x", absl::Hex(int_fs1_val), + "]) = ", absl::Hex(int_op_val), " != reg[", + reg, "][", i, "] (", reg_val, " [0x", + absl::Hex(int_reg_val), "]) lmul8(", lmul8, + ") rm = ", *(rv_fp_->GetRoundingMode()))); + fflags_test |= flag; + } else { + EXPECT_EQ(0, reg_val) << absl::StrCat( + name, " 0 != reg[", reg, "][", i, "] (", reg_val, + " [0x", absl::Hex(int_reg_val), "]) lmul8(", lmul8, ")"); + } + count++; + } + } + uint32_t fflags = rv_fp_->fflags()->AsUint32(); + EXPECT_EQ(fflags, fflags_test) << name; + if (HasFailure()) return; + } + vlen = absl::Uniform(absl::IntervalOpenClosed, bitgen_, vstart, + num_reg_values); + } + vstart = absl::Uniform(absl::IntervalOpen, bitgen_, 0, num_reg_values); + } + } + } + + // Floating point test needs to ensure to use the fp special values (inf, + // NaN etc.) during testing, not just random values. This function handles + // vector scalar instructions. + template <typename Vd, typename Vs2, typename Fs1> + void BinaryOpFPWithMaskTestHelperVX( + absl::string_view name, int sew, Instruction *inst, int delta_position, + std::function<Vd(Vs2, Fs1, bool)> operation) { + using VdInt = typename FPTypeInfo<Vd>::IntType; + using Vs2Int = typename FPTypeInfo<Vs2>::IntType; + using Fs1Int = typename FPTypeInfo<Fs1>::IntType; + int byte_sew = sew / 8; + if (byte_sew != sizeof(Vd) && byte_sew != sizeof(Vs2) && + byte_sew != sizeof(Fs1)) { + FAIL() << name << ": selected element width != any operand types" + << "sew: " << sew << " Vd: " << sizeof(Vd) + << " Vs2: " << sizeof(Vs2) << " Fs1: " << sizeof(Fs1); + return; + } + // Number of elements per vector register. + constexpr int vs2_size = kVectorLengthInBytes / sizeof(Vs2); + // Input values for 8 registers. + Vs2 vs2_value[vs2_size * 8]; + auto vs2_span = Span<Vs2>(vs2_value); + AppendVectorRegisterOperands({kVs2}, {kVd}); + AppendRegisterOperands({kFs1Name}, {}); + AppendVectorRegisterOperands({kVmask}, {}); + SetVectorRegisterValues<uint8_t>( + {{kVmaskName, Span<const uint8_t>(kA5Mask)}}); + // Iterate across different lmul values. + for (int lmul_index = 0; lmul_index < 7; lmul_index++) { + // Initialize input values. + FillArrayWithRandomFPValues<Vs2>(vs2_span); + // Overwrite the first few values of the input data with infinities, + // zeros, denormals and NaNs. 
+ *reinterpret_cast<Vs2Int *>(&vs2_span[0]) = FPTypeInfo<Vs2>::kQNaN; + *reinterpret_cast<Vs2Int *>(&vs2_span[1]) = FPTypeInfo<Vs2>::kSNaN; + *reinterpret_cast<Vs2Int *>(&vs2_span[2]) = FPTypeInfo<Vs2>::kPosInf; + *reinterpret_cast<Vs2Int *>(&vs2_span[3]) = FPTypeInfo<Vs2>::kNegInf; + *reinterpret_cast<Vs2Int *>(&vs2_span[4]) = FPTypeInfo<Vs2>::kPosZero; + *reinterpret_cast<Vs2Int *>(&vs2_span[5]) = FPTypeInfo<Vs2>::kNegZero; + *reinterpret_cast<Vs2Int *>(&vs2_span[6]) = FPTypeInfo<Vs2>::kPosDenorm; + *reinterpret_cast<Vs2Int *>(&vs2_span[7]) = FPTypeInfo<Vs2>::kNegDenorm; + // Modify the first mask bits to use each of the special floating + // point values. + vreg_[kVmask]->data_buffer()->Set<uint8_t>(0, 0xff); + // Set values for all 8 vector registers in the vector register group. + for (int i = 0; i < 8; i++) { + auto vs2_name = absl::StrCat("v", kVs2 + i); + SetVectorRegisterValues<Vs2>( + {{vs2_name, vs2_span.subspan(vs2_size * i, vs2_size)}}); + } + int lmul8 = kLmul8Values[lmul_index]; + int lmul8_vd = lmul8 * sizeof(Vd) / byte_sew; + int lmul8_vs2 = lmul8 * sizeof(Vs2) / byte_sew; + int lmul8_vs1 = lmul8 * sizeof(Fs1) / byte_sew; + int num_reg_values = lmul8 * kVectorLengthInBytes / (8 * byte_sew); + // Configure vector unit for different lmul settings. + uint32_t vtype = + (kSewSettingsByByteSize[byte_sew] << 3) | kLmulSettings[lmul_index]; + int vstart = 0; + // Try different vstart values (updated at the bottom of the loop). + for (int vstart_count = 0; vstart_count < 4; vstart_count++) { + int vlen = 1024; + // Try different vector lengths (updated at the bottom of the loop). + for (int vlen_count = 0; vlen_count < 4; vlen_count++) { + ASSERT_TRUE(vlen > vstart); + int num_values = std::min(num_reg_values, vlen); + ConfigureVectorUnit(vtype, vlen); + // Generate a new rs1 value. + Fs1 fs1_value = RandomFPValue<Fs1>(); + // Need to NaN box the value, that is, if the register value type + // is wider than the data type for a floating point value, the + // upper bits are all set to 1's. + typename RVFpRegister::ValueType fs1_reg_value = + NaNBox<Fs1, typename RVFpRegister::ValueType>(fs1_value); + SetRegisterValues<typename RVFpRegister::ValueType, RVFpRegister>( + {{kFs1Name, fs1_reg_value}}); + // Iterate across rounding modes. + for (int rm : {0, 1, 2, 3, 4}) { + rv_fp_->SetRoundingMode(static_cast<FPRoundingMode>(rm)); + rv_vector_->set_vstart(vstart); + ClearVectorRegisterGroup(kVd, 8); + + inst->Execute(); + if (lmul8_vd < 1 || lmul8_vd > 64) { + EXPECT_TRUE(rv_vector_->vector_exception()) + << "lmul8: vd: " << lmul8_vd; + rv_vector_->clear_vector_exception(); + continue; + } + + if (lmul8_vs2 < 1 || lmul8_vs2 > 64) { + EXPECT_TRUE(rv_vector_->vector_exception()) + << "lmul8: vs2: " << lmul8_vs2; + rv_vector_->clear_vector_exception(); + continue; + } + + if (lmul8_vs1 < 1 || lmul8_vs1 > 64) { + EXPECT_TRUE(rv_vector_->vector_exception()) + << "lmul8: vs1: " << lmul8_vs1; + rv_vector_->clear_vector_exception(); + continue; + } + EXPECT_FALSE(rv_vector_->vector_exception()); + EXPECT_EQ(rv_vector_->vstart(), 0); + int count = 0; + for (int reg = kVd; reg < kVd + 8; reg++) { + for (int i = 0; i < kVectorLengthInBytes / sizeof(Vd); i++) { + int mask_index = count >> 3; + int mask_offset = count & 0b111; + bool mask_value = true; + // The first 8 bits of the mask are set to true above, so + // only read the mask value after the first byte. 
+ if (mask_index > 0) { + mask_value = + ((kA5Mask[mask_index] >> mask_offset) & 0b1) != 0; + } + auto reg_val = vreg_[reg]->data_buffer()->Get<Vd>(i); + auto int_reg_val = *reinterpret_cast<VdInt *>(®_val); + if ((count >= vstart) && (count < num_values)) { + ScopedFPStatus set_fpstatus(rv_fp_->host_fp_interface()); + auto op_val = + operation(vs2_value[count], fs1_value, mask_value); + auto int_op_val = *reinterpret_cast<VdInt *>(&op_val); + auto int_vs2_val = + *reinterpret_cast<Vs2Int *>(&vs2_value[count]); + auto int_fs1_val = *reinterpret_cast<Fs1Int *>(&fs1_value); + FPCompare<Vd>( + op_val, reg_val, delta_position, + absl::StrCat(name, "[", count, "] op(", vs2_value[count], + "[0x", absl::Hex(int_vs2_val), "], ", + fs1_value, "[0x", absl::Hex(int_fs1_val), + "]) = ", absl::Hex(int_op_val), " != reg[", + reg, "][", i, "] (", reg_val, " [0x", + absl::Hex(int_reg_val), "]) lmul8(", lmul8, + ") rm = ", *(rv_fp_->GetRoundingMode()))); + } else { + EXPECT_EQ(0, reg_val) << absl::StrCat( + name, " 0 != reg[", reg, "][", i, "] (", reg_val, + " [0x", absl::Hex(int_reg_val), "]) lmul8(", lmul8, ")"); + } + count++; + } + if (HasFailure()) return; + } + } + vlen = absl::Uniform(absl::IntervalOpenClosed, bitgen_, vstart, + num_reg_values); + } + vstart = absl::Uniform(absl::IntervalOpen, bitgen_, 0, num_reg_values); + } + } + } + + // Templated helper function that tests FP vector-scalar instructions that do + // not use the value of the mask bit. + template <typename Vd, typename Vs2, typename Vs1> + void BinaryOpFPTestHelperVX(absl::string_view name, int sew, + Instruction *inst, int delta_position, + std::function<Vd(Vs2, Vs1)> operation) { + BinaryOpFPWithMaskTestHelperVX<Vd, Vs2, Vs1>( + name, sew, inst, delta_position, + [operation](Vs2 vs2, Vs1 vs1, bool mask_value) -> Vd { + if (mask_value) { + return operation(vs2, vs1); + } + return 0; + }); + } + + protected: + mpact::sim::riscv::RiscVFPState *rv_fp_; +}; + +#endif // MPACT_CHERIOT_TEST_RISCV_CHERIOT_VECTOR_FP_TEST_UTILITIES_H_
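Note on the NaN boxing performed by the VX helpers above: when the scalar fs1 value is narrower than the floating point register, the value occupies the low bits and every remaining upper bit is set to 1. A minimal standalone sketch of that widening, assuming a little-endian host; NaNBoxSketch is an illustrative name, not the NaNBox helper the tests actually call:

#include <cstdint>
#include <cstring>

// Boxes a narrow fp value into a wider register type per the RISC-V NaN
// boxing convention: low bits hold the value, all upper bits are 1s.
// Assumes a little-endian host, so memcpy fills the low-order bytes.
template <typename F, typename R>
R NaNBoxSketch(F value) {
  static_assert(sizeof(R) >= sizeof(F), "register narrower than value");
  R reg = ~static_cast<R>(0);            // Start from all ones.
  std::memcpy(&reg, &value, sizeof(F));  // Place the value in the low bits.
  return reg;
}

// Example: NaNBoxSketch<float, uint64_t>(1.0f) == 0xffff'ffff'3f80'0000.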
diff --git a/cheriot/test/riscv_cheriot_vector_fp_unary_instructions_test.cc b/cheriot/test/riscv_cheriot_vector_fp_unary_instructions_test.cc new file mode 100644 index 0000000..94315d7 --- /dev/null +++ b/cheriot/test/riscv_cheriot_vector_fp_unary_instructions_test.cc
@@ -0,0 +1,1005 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cheriot/riscv_cheriot_vector_fp_unary_instructions.h" + +#include <algorithm> +#include <cmath> +#include <cstdint> +#include <functional> +#include <limits> +#include <string> +#include <tuple> +#include <type_traits> +#include <vector> + +#include "absl/random/random.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "cheriot/test/riscv_cheriot_vector_fp_test_utilities.h" +#include "cheriot/test/riscv_cheriot_vector_instructions_test_base.h" +#include "googlemock/include/gmock/gmock.h" +#include "mpact/sim/generic/instruction.h" +#include "mpact/sim/generic/type_helpers.h" +#include "riscv//riscv_csr.h" +#include "riscv//riscv_fp_host.h" +#include "riscv//riscv_fp_info.h" +#include "riscv//riscv_fp_state.h" +#include "riscv//riscv_register.h" + +namespace { + +using Instruction = ::mpact::sim::generic::Instruction; +using ::mpact::sim::generic::operator*; // NOLINT: used below. +using ::mpact::sim::riscv::FPExceptions; + +// Functions to test. +using ::mpact::sim::cheriot::Vfclassv; +using ::mpact::sim::cheriot::Vfcvtfxuv; +using ::mpact::sim::cheriot::Vfcvtfxv; +using ::mpact::sim::cheriot::Vfcvtrtzxfv; +using ::mpact::sim::cheriot::Vfcvtrtzxufv; +using ::mpact::sim::cheriot::Vfcvtxfv; +using ::mpact::sim::cheriot::Vfcvtxufv; +using ::mpact::sim::cheriot::Vfmvfs; +using ::mpact::sim::cheriot::Vfmvsf; +using ::mpact::sim::cheriot::Vfncvtffw; +using ::mpact::sim::cheriot::Vfncvtfxuw; +using ::mpact::sim::cheriot::Vfncvtfxw; +using ::mpact::sim::cheriot::Vfncvtrodffw; +using ::mpact::sim::cheriot::Vfncvtrtzxfw; +using ::mpact::sim::cheriot::Vfncvtrtzxufw; +using ::mpact::sim::cheriot::Vfncvtxfw; +using ::mpact::sim::cheriot::Vfncvtxufw; +using ::mpact::sim::cheriot::Vfrec7v; +using ::mpact::sim::cheriot::Vfrsqrt7v; +using ::mpact::sim::cheriot::Vfsqrtv; +using ::mpact::sim::cheriot::Vfwcvtffv; +using ::mpact::sim::cheriot::Vfwcvtfxuv; +using ::mpact::sim::cheriot::Vfwcvtfxv; +using ::mpact::sim::cheriot::Vfwcvtrtzxfv; +using ::mpact::sim::cheriot::Vfwcvtrtzxufv; +using ::mpact::sim::cheriot::Vfwcvtxfv; +using ::mpact::sim::cheriot::Vfwcvtxufv; + +using ::absl::Span; +using ::mpact::sim::riscv::FPRoundingMode; +using ::mpact::sim::riscv::RiscVCsrInterface; +using ::mpact::sim::riscv::RiscVFPState; +using ::mpact::sim::riscv::RVFpRegister; +using ::mpact::sim::riscv::ScopedFPStatus; + +constexpr char kFs1Name[] = "f4"; +constexpr int kFs1 = 4; + +// Test fixture that extends the Base fixture for unary floating point +// instructions. 
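+// The fixture allocates a RiscVFPState and attaches it to the CheriotState
+// on construction, then detaches and deletes it on destruction, so every
+// test starts from a clean floating point environment.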
+class RiscVCheriotFPUnaryInstructionsTest + : public RiscVCheriotVectorInstructionsTestBase { + public: + RiscVCheriotFPUnaryInstructionsTest() { + rv_fp_ = new mpact::sim::riscv::RiscVFPState(state_->csr_set(), state_); + state_->set_rv_fp(rv_fp_); + } + ~RiscVCheriotFPUnaryInstructionsTest() override { + state_->set_rv_fp(nullptr); + delete rv_fp_; + } + + // This method uses random values for each field in the fp number. + template <typename T> + void FillArrayWithRandomFPValues(absl::Span<T> span) { + using UInt = typename FPTypeInfo<T>::IntType; + for (auto &val : span) { + UInt sign = absl::Uniform(absl::IntervalClosed, bitgen_, 0ULL, 1ULL); + UInt exp = absl::Uniform(absl::IntervalClosedOpen, bitgen_, 0ULL, + 1ULL << FPTypeInfo<T>::kExpSize); + UInt sig = absl::Uniform(absl::IntervalClosedOpen, bitgen_, 0ULL, + 1ULL << FPTypeInfo<T>::kSigSize); + UInt value = (sign & 1) << (FPTypeInfo<T>::kBitSize - 1) | + (exp << FPTypeInfo<T>::kSigSize) | sig; + val = *reinterpret_cast<T *>(&value); + } + } + + // Floating point test needs to ensure to use the fp special values (inf, NaN + // etc.) during testing, not just random values. + template <typename Vd, typename Vs2> + void UnaryOpFPTestHelperV(absl::string_view name, int sew, Instruction *inst, + int delta_position, + std::function<Vd(Vs2)> operation) { + int byte_sew = sew / 8; + if (byte_sew != sizeof(Vd) && byte_sew != sizeof(Vs2)) { + FAIL() << name << ": selected element width != any operand types" + << "sew: " << sew << " Vd: " << sizeof(Vd) + << " Vs2: " << sizeof(Vs2); + return; + } + // Number of elements per vector register. + constexpr int vs2_size = kVectorLengthInBytes / sizeof(Vs2); + // Input values for 8 registers. + Vs2 vs2_value[vs2_size * 8]; + auto vs2_span = Span<Vs2>(vs2_value); + AppendVectorRegisterOperands({kVs2, kVmask}, {kVd}); + SetVectorRegisterValues<uint8_t>( + {{kVmaskName, Span<const uint8_t>(kA5Mask)}}); + // Iterate across different lmul values. + for (int lmul_index = 0; lmul_index < 7; lmul_index++) { + // Initialize input values. + FillArrayWithRandomFPValues<Vs2>(vs2_span); + using Vs2Int = typename FPTypeInfo<Vs2>::IntType; + // Overwrite the first few values of the input data with infinities, + // zeros, denormals and NaNs. + *reinterpret_cast<Vs2Int *>(&vs2_span[0]) = FPTypeInfo<Vs2>::kQNaN; + *reinterpret_cast<Vs2Int *>(&vs2_span[1]) = FPTypeInfo<Vs2>::kSNaN; + *reinterpret_cast<Vs2Int *>(&vs2_span[2]) = FPTypeInfo<Vs2>::kPosInf; + *reinterpret_cast<Vs2Int *>(&vs2_span[3]) = FPTypeInfo<Vs2>::kNegInf; + *reinterpret_cast<Vs2Int *>(&vs2_span[4]) = FPTypeInfo<Vs2>::kPosZero; + *reinterpret_cast<Vs2Int *>(&vs2_span[5]) = FPTypeInfo<Vs2>::kNegZero; + *reinterpret_cast<Vs2Int *>(&vs2_span[6]) = FPTypeInfo<Vs2>::kPosDenorm; + *reinterpret_cast<Vs2Int *>(&vs2_span[7]) = FPTypeInfo<Vs2>::kNegDenorm; + *reinterpret_cast<Vs2Int *>(&vs2_span[8]) = 0x0119515e; + *reinterpret_cast<Vs2Int *>(&vs2_span[9]) = 0x0007fea3; + *reinterpret_cast<Vs2Int *>(&vs2_span[10]) = 0x800bc58f; + // Modify the first mask bits to use each of the special floating point + // values. + vreg_[kVmask]->data_buffer()->Set<uint8_t>(0, 0xff); + vreg_[kVmask]->data_buffer()->Set<uint8_t>(1, 0xa5 | 0x3); + // Set values for all 8 vector registers in the vector register group. 
+ for (int i = 0; i < 8; i++) { + auto vs2_name = absl::StrCat("v", kVs2 + i); + SetVectorRegisterValues<Vs2>( + {{vs2_name, vs2_span.subspan(vs2_size * i, vs2_size)}}); + } + int lmul8 = kLmul8Values[lmul_index]; + int lmul8_vd = lmul8 * sizeof(Vd) / byte_sew; + int lmul8_vs2 = lmul8 * sizeof(Vs2) / byte_sew; + int num_reg_values = lmul8 * kVectorLengthInBytes / (8 * byte_sew); + // Configure vector unit for different lmul settings. + uint32_t vtype = + (kSewSettingsByByteSize[byte_sew] << 3) | kLmulSettings[lmul_index]; + int vstart = 0; + // Try different vstart values (updated at the bottom of the loop). + for (int vstart_count = 0; vstart_count < 4; vstart_count++) { + int vlen = 1024; + // Try different vector lengths (updated at the bottom of the loop). + for (int vlen_count = 0; vlen_count < 4; vlen_count++) { + ASSERT_TRUE(vlen > vstart); + int num_values = std::min(num_reg_values, vlen); + ConfigureVectorUnit(vtype, vlen); + // Iterate across rounding modes. + for (int rm : {0, 1, 2, 3, 4}) { + rv_fp_->SetRoundingMode(static_cast<FPRoundingMode>(rm)); + rv_vector_->set_vstart(vstart); + ClearVectorRegisterGroup(kVd, 8); + + inst->Execute(); + if (lmul8_vd < 1 || lmul8_vd > 64) { + EXPECT_TRUE(rv_vector_->vector_exception()) + << "lmul8: vd: " << lmul8_vd; + rv_vector_->clear_vector_exception(); + continue; + } + + if (lmul8_vs2 < 1 || lmul8_vs2 > 64) { + EXPECT_TRUE(rv_vector_->vector_exception()) + << "lmul8: vs2: " << lmul8_vs2; + rv_vector_->clear_vector_exception(); + continue; + } + + EXPECT_FALSE(rv_vector_->vector_exception()); + EXPECT_EQ(rv_vector_->vstart(), 0); + int count = 0; + for (int reg = kVd; reg < kVd + 8; reg++) { + for (int i = 0; i < kVectorLengthInBytes / sizeof(Vd); i++) { + int mask_index = count >> 3; + int mask_offset = count & 0b111; + bool mask_value = true; + // The first 10 bits of the mask are set to true above, so only + // read the mask value for the entries after that. + if (count >= 10) { + mask_value = + ((kA5Mask[mask_index] >> mask_offset) & 0b1) != 0; + } + auto reg_val = vreg_[reg]->data_buffer()->Get<Vd>(i); + if ((count >= vstart) && mask_value && (count < num_values)) { + auto op_val = operation(vs2_value[count]); + // Do separate comparison if the result is a NaN. + auto int_reg_val = + *reinterpret_cast<typename FPTypeInfo<Vd>::IntType *>( + ®_val); + auto int_op_val = + *reinterpret_cast<typename FPTypeInfo<Vd>::IntType *>( + &op_val); + auto int_vs2_val = + *reinterpret_cast<typename FPTypeInfo<Vs2>::IntType *>( + &vs2_value[count]); + FPCompare<Vd>( + op_val, reg_val, delta_position, + absl::StrCat(name, "[", count, "] op(", vs2_value[count], + "[0x", absl::Hex(int_vs2_val), + "]) = ", absl::Hex(int_op_val), " != reg[", + reg, "][", i, "] (", reg_val, " [0x", + absl::Hex(int_reg_val), "]) lmul8(", lmul8, + ") rm = ", *(rv_fp_->GetRoundingMode()))); + } else { + EXPECT_EQ(0, reg_val) << absl::StrCat( + name, " 0 != reg[", reg, "][", i, "] (", reg_val, + " [0x", absl::Hex(reg_val), "]) lmul8(", lmul8, ")"); + } + count++; + } + } + } + if (HasFailure()) { + return; + } + vlen = absl::Uniform(absl::IntervalOpenClosed, bitgen_, vstart, + num_reg_values); + } + vstart = absl::Uniform(absl::IntervalOpen, bitgen_, 0, num_reg_values); + } + } + } + + // Floating point test needs to ensure to use the fp special values (inf, NaN + // etc.) during testing, not just random values. 
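+  // Purely random bit patterns rarely land on NaN, infinity, or a denormal,
+  // so the helper seeds those values explicitly; without them the
+  // flag-setting and canonical-NaN paths of an instruction would go largely
+  // untested.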
+ template <typename Vd, typename Vs2> + void UnaryOpWithFflagsFPTestHelperV( + absl::string_view name, int sew, Instruction *inst, int delta_position, + std::function<std::tuple<Vd, uint32_t>(Vs2)> operation) { + using VdInt = typename FPTypeInfo<Vd>::IntType; + using Vs2Int = typename FPTypeInfo<Vs2>::IntType; + int byte_sew = sew / 8; + if (byte_sew != sizeof(Vd) && byte_sew != sizeof(Vs2)) { + FAIL() << name << ": selected element width != any operand types" + << "sew: " << sew << " Vd: " << sizeof(Vd) + << " Vs2: " << sizeof(Vs2); + return; + } + // Number of elements per vector register. + constexpr int vs2_size = kVectorLengthInBytes / sizeof(Vs2); + // Input values for 8 registers. + Vs2 vs2_value[vs2_size * 8]; + auto vs2_span = Span<Vs2>(vs2_value); + AppendVectorRegisterOperands({kVs2, kVmask}, {kVd}); + auto *op = rv_fp_->fflags()->CreateSetDestinationOperand(0, "fflags"); + instruction_->AppendDestination(op); + SetVectorRegisterValues<uint8_t>( + {{kVmaskName, Span<const uint8_t>(kA5Mask)}}); + // Iterate across different lmul values. + for (int lmul_index = 0; lmul_index < 7; lmul_index++) { + // Initialize input values. + FillArrayWithRandomFPValues<Vs2>(vs2_span); + // Overwrite the first few values of the input data with infinities, + // zeros, denormals and NaNs. + *reinterpret_cast<Vs2Int *>(&vs2_span[0]) = FPTypeInfo<Vs2>::kQNaN; + *reinterpret_cast<Vs2Int *>(&vs2_span[1]) = FPTypeInfo<Vs2>::kSNaN; + *reinterpret_cast<Vs2Int *>(&vs2_span[2]) = FPTypeInfo<Vs2>::kPosInf; + *reinterpret_cast<Vs2Int *>(&vs2_span[3]) = FPTypeInfo<Vs2>::kNegInf; + *reinterpret_cast<Vs2Int *>(&vs2_span[4]) = FPTypeInfo<Vs2>::kPosZero; + *reinterpret_cast<Vs2Int *>(&vs2_span[5]) = FPTypeInfo<Vs2>::kNegZero; + *reinterpret_cast<Vs2Int *>(&vs2_span[6]) = FPTypeInfo<Vs2>::kPosDenorm; + *reinterpret_cast<Vs2Int *>(&vs2_span[7]) = FPTypeInfo<Vs2>::kNegDenorm; + *reinterpret_cast<Vs2Int *>(&vs2_span[8]) = 0x0119515e; + *reinterpret_cast<Vs2Int *>(&vs2_span[9]) = 0x0007fea3; + *reinterpret_cast<Vs2Int *>(&vs2_span[10]) = 0x800bc58f; + // Modify the first mask bits to use each of the special floating point + // values. + vreg_[kVmask]->data_buffer()->Set<uint8_t>(0, 0xff); + vreg_[kVmask]->data_buffer()->Set<uint8_t>(1, 0xa5 | 0x3); + // Set values for all 8 vector registers in the vector register group. + for (int i = 0; i < 8; i++) { + auto vs2_name = absl::StrCat("v", kVs2 + i); + SetVectorRegisterValues<Vs2>( + {{vs2_name, vs2_span.subspan(vs2_size * i, vs2_size)}}); + } + int lmul8 = kLmul8Values[lmul_index]; + int lmul8_vd = lmul8 * sizeof(Vd) / byte_sew; + int lmul8_vs2 = lmul8 * sizeof(Vs2) / byte_sew; + int num_reg_values = lmul8 * kVectorLengthInBytes / (8 * byte_sew); + // Configure vector unit for different lmul settings. + uint32_t vtype = + (kSewSettingsByByteSize[byte_sew] << 3) | kLmulSettings[lmul_index]; + int vstart = 0; + // Try different vstart values (updated at the bottom of the loop). + for (int vstart_count = 0; vstart_count < 4; vstart_count++) { + int vlen = 1024; + // Try different vector lengths (updated at the bottom of the loop). + for (int vlen_count = 0; vlen_count < 4; vlen_count++) { + ASSERT_TRUE(vlen > vstart); + int num_values = std::min(num_reg_values, vlen); + ConfigureVectorUnit(vtype, vlen); + // Iterate across rounding modes. + for (int rm : {0, 1, 2, 3, 4}) { + rv_vector_->set_vstart(vstart); + ClearVectorRegisterGroup(kVd, 8); + // Set rounding mode and clear flags. 
+ rv_fp_->SetRoundingMode(static_cast<FPRoundingMode>(rm)); + rv_fp_->fflags()->Write(0U); + + inst->Execute(); + + // Get the flags for the instruction execution. + uint32_t fflags = rv_fp_->fflags()->AsUint32(); + + if (lmul8_vd < 1 || lmul8_vd > 64) { + EXPECT_TRUE(rv_vector_->vector_exception()) + << "lmul8: vd: " << lmul8_vd; + rv_vector_->clear_vector_exception(); + continue; + } + + if (lmul8_vs2 < 1 || lmul8_vs2 > 64) { + EXPECT_TRUE(rv_vector_->vector_exception()) + << "lmul8: vs2: " << lmul8_vs2; + rv_vector_->clear_vector_exception(); + continue; + } + + EXPECT_FALSE(rv_vector_->vector_exception()); + EXPECT_EQ(rv_vector_->vstart(), 0); + int count = 0; + // Clear flags for the test execution. + rv_fp_->fflags()->Write(0U); + uint32_t fflags_test = 0; + for (int reg = kVd; reg < kVd + 8; reg++) { + for (int i = 0; i < kVectorLengthInBytes / sizeof(Vd); i++) { + int mask_index = count >> 3; + int mask_offset = count & 0b111; + bool mask_value = true; + // The first 10 bits of the mask are set to true above, so only + // read the mask value for the entries after that. + if (count >= 10) { + mask_value = + ((kA5Mask[mask_index] >> mask_offset) & 0b1) != 0; + } + auto reg_val = vreg_[reg]->data_buffer()->Get<Vd>(i); + if ((count >= vstart) && mask_value && (count < num_values)) { + Vd op_val; + uint32_t flag; + { + ScopedFPStatus set_fp_status(rv_fp_->host_fp_interface()); + auto [op_val_tmp, flag_tmp] = operation(vs2_value[count]); + op_val = op_val_tmp; + flag = flag_tmp; + } + fflags_test |= (rv_fp_->fflags()->AsUint32() | flag); + // Do separate comparison if the result is a NaN. + auto int_reg_val = *reinterpret_cast<VdInt *>(®_val); + auto int_op_val = *reinterpret_cast<VdInt *>(&op_val); + auto int_vs2_val = + *reinterpret_cast<Vs2Int *>(&vs2_value[count]); + FPCompare<Vd>( + op_val, reg_val, delta_position, + absl::StrCat(name, "[", count, "] op(", vs2_value[count], + "[0x", absl::Hex(int_vs2_val), + "]) = ", absl::Hex(int_op_val), " != reg[", + reg, "][", i, "] (", reg_val, " [0x", + absl::Hex(int_reg_val), "]) lmul8(", lmul8, + ") rm = ", *(rv_fp_->GetRoundingMode()))); + } else { + EXPECT_EQ(0, reg_val) << absl::StrCat( + name, " 0 != reg[", reg, "][", i, "] (", reg_val, + " [0x", absl::Hex(reg_val), "]) lmul8(", lmul8, ")"); + } + count++; + } + } + EXPECT_EQ(fflags, fflags_test) << name; + } + if (HasFailure()) { + return; + } + vlen = absl::Uniform(absl::IntervalOpenClosed, bitgen_, vstart, + num_reg_values); + } + vstart = absl::Uniform(absl::IntervalOpen, bitgen_, 0, num_reg_values); + } + } + } + + protected: + mpact::sim::riscv::RiscVFPState *rv_fp_ = nullptr; + RiscVCsrInterface *fflags_ = nullptr; +}; + +// Templated helper function for classifying fp numbers. +template <typename T> +typename FPTypeInfo<T>::IntType VfclassVHelper(T val) { + auto fp_class = std::fpclassify(val); + switch (fp_class) { + case FP_INFINITE: + return std::signbit(val) ? 1 : 1 << 7; + case FP_NAN: { + auto uint_val = + *reinterpret_cast<typename FPTypeInfo<T>::IntType *>(&val); + bool quiet_nan = (uint_val >> (FPTypeInfo<T>::kSigSize - 1)) & 1; + return quiet_nan ? 1 << 9 : 1 << 8; + } + case FP_ZERO: + return std::signbit(val) ? 1 << 3 : 1 << 4; + case FP_SUBNORMAL: + return std::signbit(val) ? 1 << 2 : 1 << 5; + case FP_NORMAL: + return std::signbit(val) ? 1 << 1 : 1 << 6; + } + return 0; +} + +// Test fp classify. 
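+// vfclass produces a one-hot 10-bit result for each element:
+//   bit 0: -infinity          bit 5: positive denormal
+//   bit 1: negative normal    bit 6: positive normal
+//   bit 2: negative denormal  bit 7: +infinity
+//   bit 3: -0                 bit 8: signaling NaN
+//   bit 4: +0                 bit 9: quiet NaN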
+TEST_F(RiscVCheriotFPUnaryInstructionsTest, Vfclassv) { + SetSemanticFunction(&Vfclassv); + UnaryOpFPTestHelperV<uint32_t, float>( + "Vfclassv_32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [](float vs2) -> uint32_t { return VfclassVHelper(vs2); }); + ResetInstruction(); + SetSemanticFunction(&Vfclassv); + UnaryOpFPTestHelperV<uint64_t, double>( + "Vfclassv_64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [](double vs2) -> uint64_t { return VfclassVHelper(vs2); }); +} + +// Test convert from unsigned integer to fp. +TEST_F(RiscVCheriotFPUnaryInstructionsTest, Vfcvtfxuv) { + SetSemanticFunction(&Vfcvtfxuv); + UnaryOpTestHelperV<float, uint32_t>( + "Vfcvt.f.xu.v_32", /*sew*/ 32, instruction_, + [](uint32_t vs2) -> float { return static_cast<float>(vs2); }); + ResetInstruction(); + SetSemanticFunction(&Vfcvtfxuv); + UnaryOpTestHelperV<double, uint64_t>( + "Vfcvt.f.xu.v_64", /*sew*/ 64, instruction_, + [](uint64_t vs2) -> double { return static_cast<double>(vs2); }); +} + +// Test convert from signed integer to fp. +TEST_F(RiscVCheriotFPUnaryInstructionsTest, Vfcvtfxv) { + SetSemanticFunction(&Vfcvtfxv); + UnaryOpTestHelperV<float, int32_t>( + "Vfcvt.f.x.v_32", /*sew*/ 32, instruction_, + [](int32_t vs2) -> float { return static_cast<float>(vs2); }); + ResetInstruction(); + SetSemanticFunction(&Vfcvtfxv); + UnaryOpTestHelperV<double, int64_t>( + "Vfcvt.f.x.v_64", /*sew*/ 64, instruction_, + [](int64_t vs2) -> double { return static_cast<double>(vs2); }); +} + +// Helper function for fp to integer conversions. +template <typename F, typename I> +std::tuple<I, uint32_t> ConvertHelper(F value, RiscVFPState *fp_state) { + constexpr F kMin = static_cast<F>(std::numeric_limits<I>::min()); + constexpr F kMax = static_cast<F>(std::numeric_limits<I>::max()); + ScopedFPStatus status(fp_state->host_fp_interface()); + auto fp_class = std::fpclassify(value); + switch (fp_class) { + case FP_INFINITE: + return std::make_tuple(std::signbit(value) + ? std::numeric_limits<I>::min() + : std::numeric_limits<I>::max(), + static_cast<uint32_t>(FPExceptions::kInvalidOp)); + case FP_NAN: + return std::make_tuple(std::numeric_limits<I>::max(), + static_cast<uint32_t>(FPExceptions::kInvalidOp)); + case FP_ZERO: + return std::make_tuple(0, 0); + case FP_SUBNORMAL: + case FP_NORMAL: + if (value > kMax) { + return std::make_tuple(std::numeric_limits<I>::max(), + static_cast<uint32_t>(FPExceptions::kInvalidOp)); + } + if (value < kMin) { + if (std::is_unsigned<I>::value) { + if ((value > -1.0) && + (static_cast<typename std::make_signed<I>::type>(value) == 0)) { + return std::make_tuple( + 0, static_cast<uint32_t>(FPExceptions::kInexact)); + } + } + if (value < kMin) { + return std::make_tuple( + std::numeric_limits<I>::min(), + static_cast<uint32_t>(FPExceptions::kInvalidOp)); + } + } + } + return std::make_tuple(static_cast<I>(value), 0); +} + +// Test convert from fp to signed integer with truncation. 
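+// The expected values come from ConvertHelper above: NaN and +infinity map
+// to the destination type's maximum and -infinity to its minimum, values
+// outside the destination range saturate (all of these set the invalid
+// operation flag), and values in (-1, 0) convert to 0 with only the inexact
+// flag when the destination is unsigned.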
+TEST_F(RiscVCheriotFPUnaryInstructionsTest, Vfcvtrtzxfv) { + SetSemanticFunction(&Vfcvtrtzxfv); + rv_fp_->SetRoundingMode(FPRoundingMode::kRoundTowardsZero); + UnaryOpWithFflagsFPTestHelperV<int32_t, float>( + "Vfcvt.rtz.x.f.v_32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [this](float vs2) -> std::tuple<int32_t, uint32_t> { + return ConvertHelper<float, int32_t>(vs2, this->rv_fp_); + }); + ResetInstruction(); + SetSemanticFunction(&Vfcvtrtzxfv); + rv_fp_->SetRoundingMode(FPRoundingMode::kRoundTowardsZero); + UnaryOpWithFflagsFPTestHelperV<int64_t, double>( + "Vfcvt.rtz.x.f.v_64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [this](double vs2) -> std::tuple<int64_t, uint32_t> { + return ConvertHelper<double, int64_t>(vs2, this->rv_fp_); + }); +} + +// Test convert from fp to unsigned integer with truncation. +TEST_F(RiscVCheriotFPUnaryInstructionsTest, Vfcvtrtzxufv) { + SetSemanticFunction(&Vfcvtrtzxufv); + rv_fp_->SetRoundingMode(FPRoundingMode::kRoundTowardsZero); + UnaryOpWithFflagsFPTestHelperV<uint32_t, float>( + "Vfcvt.rtz.xu.f.v_32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [this](float vs2) -> std::tuple<uint32_t, uint32_t> { + return ConvertHelper<float, uint32_t>(vs2, this->rv_fp_); + }); + ResetInstruction(); + SetSemanticFunction(&Vfcvtrtzxufv); + rv_fp_->SetRoundingMode(FPRoundingMode::kRoundTowardsZero); + UnaryOpWithFflagsFPTestHelperV<uint64_t, double>( + "Vfcvt.rtz.xu.f.v_64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [this](double vs2) -> std::tuple<uint64_t, uint32_t> { + return ConvertHelper<double, uint64_t>(vs2, this->rv_fp_); + }); +} + +// Test convert from fp to signed integer with rounding. +TEST_F(RiscVCheriotFPUnaryInstructionsTest, Vfcvtxfv) { + SetSemanticFunction(&Vfcvtxfv); + UnaryOpWithFflagsFPTestHelperV<int32_t, float>( + "Vfcvt.x.f.v_32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [this](float vs2) -> std::tuple<int32_t, uint32_t> { + return ConvertHelper<float, int32_t>(vs2, this->rv_fp_); + }); + ResetInstruction(); + SetSemanticFunction(&Vfcvtxfv); + UnaryOpWithFflagsFPTestHelperV<int64_t, double>( + "Vfcvt.x.f.v_64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [this](double vs2) -> std::tuple<int64_t, uint32_t> { + return ConvertHelper<double, int64_t>(vs2, this->rv_fp_); + }); +} + +// Test convert from fp to unsigned integer with rounding. +TEST_F(RiscVCheriotFPUnaryInstructionsTest, Vfcvtxufv) { + SetSemanticFunction(&Vfcvtxufv); + UnaryOpWithFflagsFPTestHelperV<uint32_t, float>( + "Vfcvt.xu.f.v_32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [this](float vs2) -> std::tuple<uint32_t, uint32_t> { + return ConvertHelper<float, uint32_t>(vs2, this->rv_fp_); + }); + ResetInstruction(); + SetSemanticFunction(&Vfcvtxufv); + UnaryOpWithFflagsFPTestHelperV<uint64_t, double>( + "Vfcvt.xu.f.v_64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [this](double vs2) -> std::tuple<uint64_t, uint32_t> { + return ConvertHelper<double, uint64_t>(vs2, this->rv_fp_); + }); +} + +// Test vfmv.f.s instruction - move element 0 to scalar fp register. 
+TEST_F(RiscVCheriotFPUnaryInstructionsTest, VfmvToScalar) {
+  SetSemanticFunction(&Vfmvfs);
+  AppendRegisterOperands({}, {kFs1Name});
+  AppendVectorRegisterOperands({kVs2}, {});
+  for (int byte_sew : {1, 2, 4, 8}) {
+    int vlen = kVectorLengthInBytes / byte_sew;
+    uint32_t vtype =
+        (kSewSettingsByByteSize[byte_sew] << 3) | kLmulSettingByLogSize[4];
+    ConfigureVectorUnit(vtype, vlen);
+    if (byte_sew < 4) {
+      instruction_->Execute();
+      EXPECT_TRUE(rv_vector_->vector_exception());
+      continue;
+    }
+    // Test 10 different values.
+    for (int i = 0; i < 10; i++) {
+      uint64_t value;
+      switch (byte_sew) {
+        case 4: {
+          auto val32 = RandomValue<uint32_t>();
+          value = 0xffff'ffff'0000'0000ULL | static_cast<uint64_t>(val32);
+          SetVectorRegisterValues<uint32_t>(
+              {{kVs2Name, absl::Span<const uint32_t>(&val32, 1)}});
+          break;
+        }
+        case 8: {
+          value = RandomValue<uint64_t>();
+          SetVectorRegisterValues<uint64_t>(
+              {{kVs2Name, absl::Span<const uint64_t>(&value, 1)}});
+          break;
+        }
+      }
+      instruction_->Execute();
+      EXPECT_EQ(freg_[kFs1]->data_buffer()->Get<RVFpRegister::ValueType>(0),
+                static_cast<RVFpRegister::ValueType>(value));
+    }
+  }
+}
+
+// Test vfmv.s.f instruction - move scalar fp register to element 0.
+TEST_F(RiscVCheriotFPUnaryInstructionsTest, VfmvFromScalar) {
+  SetSemanticFunction(&Vfmvsf);
+  AppendRegisterOperands({kFs1Name}, {});
+  AppendVectorRegisterOperands({}, {kVd});
+  for (int byte_sew : {1, 2, 4, 8}) {
+    int vlen = kVectorLengthInBytes / byte_sew;
+    uint32_t vtype =
+        (kSewSettingsByByteSize[byte_sew] << 3) | kLmulSettingByLogSize[4];
+    ConfigureVectorUnit(vtype, vlen);
+    if (byte_sew < 4) {
+      instruction_->Execute();
+      EXPECT_TRUE(rv_vector_->vector_exception());
+      continue;
+    }
+    // Test 10 different values.
+    for (int i = 0; i < 10; i++) {
+      auto value = RandomValue<RVFpRegister::ValueType>();
+      freg_[kFs1]->data_buffer()->Set<RVFpRegister::ValueType>(0, value);
+      instruction_->Execute();
+      switch (byte_sew) {
+        case 4:
+          EXPECT_EQ(vreg_[kVd]->data_buffer()->Get<uint32_t>(0),
+                    static_cast<uint32_t>(value));
+          break;
+        case 8:
+          EXPECT_EQ(vreg_[kVd]->data_buffer()->Get<uint64_t>(0),
+                    static_cast<uint64_t>(value));
+          break;
+      }
+    }
+  }
+}
+
+// Test narrowing convert from unsigned integer to fp.
+TEST_F(RiscVCheriotFPUnaryInstructionsTest, Vfncvtfxuw) {
+  SetSemanticFunction(&Vfncvtfxuw);
+  UnaryOpTestHelperV<float, uint64_t>(
+      "Vfncvt.f.xu.w_64", /*sew*/ 64, instruction_,
+      [](uint64_t vs2) -> float { return static_cast<float>(vs2); });
+}
+
+// Test narrowing convert from signed integer to fp.
+TEST_F(RiscVCheriotFPUnaryInstructionsTest, Vfncvtfxw) {
+  SetSemanticFunction(&Vfncvtfxw);
+  UnaryOpTestHelperV<float, int64_t>(
+      "Vfncvt.f.x.w_64", /*sew*/ 64, instruction_,
+      [](int64_t vs2) -> float { return static_cast<float>(vs2); });
+}
+
+// Test narrowing convert fp to fp.
+TEST_F(RiscVCheriotFPUnaryInstructionsTest, Vfncvtffw) {
+  SetSemanticFunction(&Vfncvtffw);
+  UnaryOpTestHelperV<float, double>(
+      "Vfncvt.f.f.w_64", /*sew*/ 64, instruction_,
+      [](double vs2) -> float { return static_cast<float>(vs2); });
+}
+
+// Test narrowing convert fp to fp with round to odd.
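+// Round-to-odd ("jamming") truncates the discarded significand bits and then
+// ORs a 1 into the result's least significant bit if any discarded bit was
+// set. This keeps the sticky information so a subsequent rounding of the
+// narrowed value cannot double-round.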
+TEST_F(RiscVCheriotFPUnaryInstructionsTest, Vfncvtrodffw) { + SetSemanticFunction(&Vfncvtrodffw); + UnaryOpTestHelperV<float, double>( + "Vfncvt.rod.f.f.w_64", /*sew*/ 64, instruction_, [](double vs2) -> float { + if (std::isnan(vs2) || std::isinf(vs2)) { + return static_cast<float>(vs2); + } + using UIntD = typename FPTypeInfo<double>::IntType; + using UIntF = typename FPTypeInfo<float>::IntType; + UIntD uval = *reinterpret_cast<UIntD *>(&vs2); + int diff = FPTypeInfo<double>::kSigSize - FPTypeInfo<float>::kSigSize; + UIntF bit = (uval & (FPTypeInfo<double>::kSigMask >> diff)) != 0; + float res = static_cast<float>(vs2); + // The narrowing conversion may have generated an infinity, so check + // for infinity before doing rounding. + if (std::isinf(res)) return res; + UIntF ures = *reinterpret_cast<UIntF *>(&res) | bit; + return *reinterpret_cast<float *>(&ures); + }); +} + +// Test narrowing convert from fp to signed integer with truncation. +TEST_F(RiscVCheriotFPUnaryInstructionsTest, Vfncvtrtzxfw) { + SetSemanticFunction(&Vfncvtrtzxfw); + rv_fp_->SetRoundingMode(FPRoundingMode::kRoundTowardsZero); + UnaryOpWithFflagsFPTestHelperV<int16_t, float>( + "Vfncvt.rtz.x.f.w_32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [this](float vs2) -> std::tuple<int16_t, uint32_t> { + return ConvertHelper<float, int16_t>(vs2, this->rv_fp_); + }); + ResetInstruction(); + SetSemanticFunction(&Vfncvtrtzxfw); + rv_fp_->SetRoundingMode(FPRoundingMode::kRoundTowardsZero); + UnaryOpWithFflagsFPTestHelperV<int32_t, double>( + "Vfncvt.rtz.x.f.w_64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [this](double vs2) -> std::tuple<int32_t, uint32_t> { + return ConvertHelper<double, int32_t>(vs2, this->rv_fp_); + }); +} + +// Test narrowing convert from fp to unsigned integer with truncation. +TEST_F(RiscVCheriotFPUnaryInstructionsTest, Vfncvtrtzxufw) { + SetSemanticFunction(&Vfncvtrtzxufw); + rv_fp_->SetRoundingMode(FPRoundingMode::kRoundTowardsZero); + UnaryOpWithFflagsFPTestHelperV<uint16_t, float>( + "Vfncvt.rtz.xu.f.w_32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [this](float vs2) -> std::tuple<uint16_t, uint32_t> { + return ConvertHelper<float, uint16_t>(vs2, this->rv_fp_); + }); + ResetInstruction(); + SetSemanticFunction(&Vfncvtrtzxufw); + rv_fp_->SetRoundingMode(FPRoundingMode::kRoundTowardsZero); + UnaryOpWithFflagsFPTestHelperV<uint32_t, double>( + "Vfncvt.rtz.xu.f.w_64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [this](double vs2) -> std::tuple<uint32_t, uint32_t> { + return ConvertHelper<double, uint32_t>(vs2, this->rv_fp_); + }); +} + +// Test narrowing convert fp to signed integer with rounding. +TEST_F(RiscVCheriotFPUnaryInstructionsTest, Vfncvtxfw) { + SetSemanticFunction(&Vfncvtxfw); + UnaryOpWithFflagsFPTestHelperV<int16_t, float>( + "Vfncvt.x.f.w_32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [this](float vs2) -> std::tuple<int16_t, uint32_t> { + return ConvertHelper<float, int16_t>(vs2, this->rv_fp_); + }); + ResetInstruction(); + SetSemanticFunction(&Vfncvtxfw); + UnaryOpWithFflagsFPTestHelperV<int32_t, double>( + "Vfncvt.x.f.w_64", /*sew*/ 64, instruction_, /*delta_position*/ 32, + [this](double vs2) -> std::tuple<int32_t, uint32_t> { + return ConvertHelper<double, int32_t>(vs2, this->rv_fp_); + }); +} + +// Test narrowing convert fp to unsigned integer with rounding. 
+TEST_F(RiscVCheriotFPUnaryInstructionsTest, Vfncvtxufw) { + SetSemanticFunction(&Vfncvtxufw); + UnaryOpWithFflagsFPTestHelperV<uint16_t, float>( + "Vfncvt.xu.f.w_32", /*sew*/ 32, instruction_, /*delta_position*/ 32, + [this](float vs2) -> std::tuple<uint16_t, uint32_t> { + return ConvertHelper<float, uint16_t>(vs2, this->rv_fp_); + }); + ResetInstruction(); + SetSemanticFunction(&Vfncvtxufw); + UnaryOpWithFflagsFPTestHelperV<uint32_t, double>( + "Vfncvt.xu.f.w_64", /*sew*/ 64, instruction_, /*delta_position*/ 64, + [this](double vs2) -> std::tuple<uint32_t, uint32_t> { + return ConvertHelper<double, uint32_t>(vs2, this->rv_fp_); + }); +} + +// Helper function for testing approximate reciprocal instruction. +template <typename T> +inline T Vrecip7vTestHelper(T vs2, RiscVFPState *rv_fp) { + using UInt = typename FPTypeInfo<T>::IntType; + if (FPTypeInfo<T>::IsNaN(vs2)) { + auto nan_value = FPTypeInfo<T>::kCanonicalNaN; + return *reinterpret_cast<T *>(&nan_value); + } + if (std::isinf(vs2)) { + return std::signbit(vs2) ? -0.0 : 0.0; + } + if (vs2 == 0.0) { + UInt value = + std::signbit(vs2) ? FPTypeInfo<T>::kNegInf : FPTypeInfo<T>::kPosInf; + return *reinterpret_cast<T *>(&value); + } + UInt uint_vs2 = *reinterpret_cast<UInt *>(&vs2); + auto exp = (uint_vs2 & FPTypeInfo<T>::kExpMask) >> FPTypeInfo<T>::kSigSize; + auto sig2 = + (uint_vs2 & FPTypeInfo<T>::kSigMask) >> (FPTypeInfo<T>::kSigSize - 2); + auto rm = rv_fp->GetRoundingMode(); + if ((exp == 0) && (sig2 == 0)) { // Denormal number. + if (std::signbit(vs2)) { + if ((rm == FPRoundingMode::kRoundTowardsZero) || + (rm == FPRoundingMode::kRoundUp)) { + return std::numeric_limits<T>::lowest(); + } else { + UInt value = FPTypeInfo<T>::kNegInf; + return *reinterpret_cast<T *>(&value); + } + } else { + if ((rm == FPRoundingMode::kRoundTowardsZero) || + (rm == FPRoundingMode::kRoundDown)) { + return std::numeric_limits<T>::max(); + } else { + UInt value = FPTypeInfo<T>::kPosInf; + return *reinterpret_cast<T *>(&value); + } + } + } + ScopedFPStatus status(rv_fp->host_fp_interface(), + FPRoundingMode::kRoundTowardsZero); + T value = 1.0 / vs2; + UInt uint_val = *reinterpret_cast<UInt *>(&value); + UInt mask = FPTypeInfo<T>::kSigMask >> 7; + uint_val = uint_val & ~mask; + return *reinterpret_cast<T *>(&uint_val); +} + +// Test approximate reciprocal instruction. +TEST_F(RiscVCheriotFPUnaryInstructionsTest, Vfrec7v) { + SetSemanticFunction(&Vfrec7v); + UnaryOpFPTestHelperV<float, float>( + "Vfrec7.v_32", /*sew*/ 32, instruction_, /*delta_position*/ 7, + [this](float vs2) -> float { + return Vrecip7vTestHelper(vs2, this->rv_fp_); + }); +} + +// Helper function for testing approximate reciprocal square root instruction. +template <typename T> +inline std::tuple<T, uint32_t> Vfrsqrt7vTestHelper(T vs2, RiscVFPState *rv_fp) { + using UInt = typename FPTypeInfo<T>::IntType; + T return_value; + uint32_t fflags = 0; + if (FPTypeInfo<T>::IsNaN(vs2) || (vs2 < 0.0)) { + auto nan_value = FPTypeInfo<T>::kCanonicalNaN; + return_value = *reinterpret_cast<T *>(&nan_value); + fflags = static_cast<uint32_t>(FPExceptions::kInvalidOp); + } else if (vs2 == 0.0) { + UInt value = + std::signbit(vs2) ? 
FPTypeInfo<T>::kNegInf : FPTypeInfo<T>::kPosInf;
+    return_value = *reinterpret_cast<T *>(&value);
+    fflags = static_cast<uint32_t>(FPExceptions::kDivByZero);
+  } else if (std::isinf(vs2)) {
+    return_value = 0.0;
+    fflags = static_cast<uint32_t>(FPExceptions::kInvalidOp);
+  } else {
+    ScopedFPStatus status(rv_fp->host_fp_interface(),
+                          FPRoundingMode::kRoundTowardsZero);
+    T value = 1.0 / sqrt(vs2);
+    UInt uint_val = *reinterpret_cast<UInt *>(&value);
+    UInt mask = FPTypeInfo<T>::kSigMask >> 7;
+    uint_val = uint_val & ~mask;
+    return_value = *reinterpret_cast<T *>(&uint_val);
+  }
+  return std::make_tuple(return_value, fflags);
+}
+
+// Test approximate reciprocal square root.
+TEST_F(RiscVCheriotFPUnaryInstructionsTest, Vfrsqrt7v) {
+  SetSemanticFunction(&Vfrsqrt7v);
+  UnaryOpWithFflagsFPTestHelperV<float, float>(
+      "Vfrsqrt7.v_32", /*sew*/ 32, instruction_,
+      /*delta_position*/ 7, [this](float vs2) -> std::tuple<float, uint32_t> {
+        return Vfrsqrt7vTestHelper(vs2, this->rv_fp_);
+      });
+  ResetInstruction();
+  SetSemanticFunction(&Vfrsqrt7v);
+  UnaryOpWithFflagsFPTestHelperV<double, double>(
+      "Vfrsqrt7.v_64", /*sew*/ 64, instruction_, /*delta_position*/ 7,
+      [this](double vs2) -> std::tuple<double, uint32_t> {
+        return Vfrsqrt7vTestHelper(vs2, this->rv_fp_);
+      });
+}
+
+// Helper function for testing square root instruction. The invalid operation
+// flag is expected only when the operand is not a quiet NaN.
+template <typename T>
+inline std::tuple<T, uint32_t> VfsqrtvTestHelper(T vs2) {
+  if (vs2 == 0.0) return std::make_tuple(vs2, 0);
+  if (std::isnan(vs2) || (vs2 < 0.0)) {
+    auto val = FPTypeInfo<T>::kCanonicalNaN;
+    uint32_t flags = 0;
+    if (!mpact::sim::generic::FPTypeInfo<T>::IsQNaN(vs2)) {
+      flags = (uint32_t)FPExceptions::kInvalidOp;
+    }
+    return std::make_tuple(*reinterpret_cast<const T *>(&val), flags);
+  }
+  T res = sqrt(vs2);
+
+  return std::make_tuple(res, 0);
+}
+
+// Test square root instruction.
+TEST_F(RiscVCheriotFPUnaryInstructionsTest, Vfsqrtv) {
+  SetSemanticFunction(&Vfsqrtv);
+  UnaryOpWithFflagsFPTestHelperV<float, float>(
+      "Vfsqrt.v_32", /*sew*/ 32, instruction_,
+      /*delta_position*/ 32, [](float vs2) -> std::tuple<float, uint32_t> {
+        return VfsqrtvTestHelper(vs2);
+      });
+  ResetInstruction();
+  SetSemanticFunction(&Vfsqrtv);
+  UnaryOpWithFflagsFPTestHelperV<double, double>(
+      "Vfsqrt.v_64", /*sew*/ 64, instruction_, /*delta_position*/ 64,
+      [](double vs2) -> std::tuple<double, uint32_t> {
+        return VfsqrtvTestHelper(vs2);
+      });
+}
+
+// Test widening convert fp to fp.
+TEST_F(RiscVCheriotFPUnaryInstructionsTest, Vfwcvtffv) {
+  SetSemanticFunction(&Vfwcvtffv);
+  UnaryOpTestHelperV<double, float>(
+      "Vfwcvt.f.f.v_32", /*sew*/ 32, instruction_,
+      [](float vs2) -> double { return static_cast<double>(vs2); });
+}
+
+// Test widening convert unsigned integer to fp.
+TEST_F(RiscVCheriotFPUnaryInstructionsTest, Vfwcvtfxuv) {
+  SetSemanticFunction(&Vfwcvtfxuv);
+  UnaryOpTestHelperV<float, uint16_t>(
+      "Vfwcvt.f.xu.v_16", /*sew*/ 16, instruction_,
+      [](uint16_t vs2) -> float { return static_cast<float>(vs2); });
+  ResetInstruction();
+  SetSemanticFunction(&Vfwcvtfxuv);
+  UnaryOpTestHelperV<double, uint32_t>(
+      "Vfwcvt.f.xu.v_32", /*sew*/ 32, instruction_,
+      [](uint32_t vs2) -> double { return static_cast<double>(vs2); });
+}
+
+// Test widening convert signed integer to fp.
+TEST_F(RiscVCheriotFPUnaryInstructionsTest, Vfwcvtfxv) {
+  SetSemanticFunction(&Vfwcvtfxv);
+  UnaryOpTestHelperV<float, int16_t>(
+      "Vfwcvt.f.x.v_16", /*sew*/ 16, instruction_,
+      [](int16_t vs2) -> float { return static_cast<float>(vs2); });
+  ResetInstruction();
+  SetSemanticFunction(&Vfwcvtfxv);
+  UnaryOpTestHelperV<double, int32_t>(
+      "Vfwcvt.f.x.v_32", /*sew*/ 32, instruction_,
+      [](int32_t vs2) -> double { return static_cast<double>(vs2); });
+}
+
+// Test widening convert fp to signed integer with truncation (round to zero).
+TEST_F(RiscVCheriotFPUnaryInstructionsTest, Vfwcvtrtzxfv) {
+  SetSemanticFunction(&Vfwcvtrtzxfv);
+  rv_fp_->SetRoundingMode(FPRoundingMode::kRoundTowardsZero);
+  UnaryOpWithFflagsFPTestHelperV<int64_t, float>(
+      "Vfwcvt.rtz.x.f.v_32", /*sew*/ 32, instruction_, /*delta_position*/ 64,
+      [this](float vs2) -> std::tuple<int64_t, uint32_t> {
+        return ConvertHelper<float, int64_t>(vs2, this->rv_fp_);
+      });
+}
+
+// Test widening convert fp to unsigned integer with truncation (round to
+// zero).
+TEST_F(RiscVCheriotFPUnaryInstructionsTest, Vfwcvtrtzxufv) {
+  SetSemanticFunction(&Vfwcvtrtzxufv);
+  rv_fp_->SetRoundingMode(FPRoundingMode::kRoundTowardsZero);
+  UnaryOpWithFflagsFPTestHelperV<uint64_t, float>(
+      "Vfwcvt.rtz.xu.f.v_32", /*sew*/ 32, instruction_, /*delta_position*/ 64,
+      [this](float vs2) -> std::tuple<uint64_t, uint32_t> {
+        return ConvertHelper<float, uint64_t>(vs2, this->rv_fp_);
+      });
+}
+
+// Test widening convert fp to signed integer with rounding.
+TEST_F(RiscVCheriotFPUnaryInstructionsTest, Vfwcvtxfv) {
+  SetSemanticFunction(&Vfwcvtxfv);
+  UnaryOpWithFflagsFPTestHelperV<int64_t, float>(
+      "Vfwcvt.x.f.v_32", /*sew*/ 32, instruction_, /*delta_position*/ 64,
+      [this](float vs2) -> std::tuple<int64_t, uint32_t> {
+        return ConvertHelper<float, int64_t>(vs2, this->rv_fp_);
+      });
+}
+
+// Test widening convert fp to unsigned integer with rounding.
+TEST_F(RiscVCheriotFPUnaryInstructionsTest, Vfwcvtxufv) {
+  SetSemanticFunction(&Vfwcvtxufv);
+  UnaryOpWithFflagsFPTestHelperV<uint64_t, float>(
+      "Vfwcvt.xu.f.v_32", /*sew*/ 32, instruction_, /*delta_position*/ 64,
+      [this](float vs2) -> std::tuple<uint64_t, uint32_t> {
+        return ConvertHelper<float, uint64_t>(vs2, this->rv_fp_);
+      });
+}
+
+}  // namespace
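Both 7-bit approximation helpers above (Vrecip7vTestHelper and Vfrsqrt7vTestHelper) form their expected values the same way: compute the full-precision result under round-towards-zero, then keep only the top seven significand bits. A standalone sketch of that final masking step for float, with an illustrative function name:

#include <cstdint>
#include <cstring>

// Clears all but the top 7 of a float's 23 significand bits, mirroring the
// kSigMask >> 7 masking in the vfrec7/vfrsqrt7 expectation helpers.
float KeepTop7SignificandBits(float value) {
  uint32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));
  constexpr uint32_t kSigMask = (1u << 23) - 1;  // 23 significand bits.
  bits &= ~(kSigMask >> 7);                      // Zero the low 16 bits.
  std::memcpy(&value, &bits, sizeof(value));
  return value;
}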
diff --git a/cheriot/test/riscv_cheriot_vector_instructions_test_base.h b/cheriot/test/riscv_cheriot_vector_instructions_test_base.h new file mode 100644 index 0000000..fc074e1 --- /dev/null +++ b/cheriot/test/riscv_cheriot_vector_instructions_test_base.h
@@ -0,0 +1,1234 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MPACT_RISCV_RISCV_TEST_RISCV_VECTOR_INSTRUCTIONS_TEST_BASE_H_ +#define MPACT_RISCV_RISCV_TEST_RISCV_VECTOR_INSTRUCTIONS_TEST_BASE_H_ + +#include <algorithm> +#include <cstdint> +#include <cstring> +#include <functional> +#include <ios> +#include <limits> +#include <string> +#include <tuple> +#include <vector> + +#include "absl/functional/bind_front.h" +#include "absl/random/random.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "cheriot/cheriot_register.h" +#include "cheriot/cheriot_state.h" +#include "cheriot/cheriot_vector_state.h" +#include "cheriot/riscv_cheriot_vector_memory_instructions.h" +#include "googlemock/include/gmock/gmock.h" +#include "mpact/sim/generic/data_buffer.h" +#include "mpact/sim/generic/immediate_operand.h" +#include "mpact/sim/generic/instruction.h" +#include "mpact/sim/generic/register.h" +#include "mpact/sim/generic/type_helpers.h" +#include "mpact/sim/util/memory/tagged_flat_demand_memory.h" +#include "riscv//riscv_register.h" + +// This file defines commonly used constants in the vector instruction tests +// as well as a base class for the vector instruction test fixtures. This base +// class contains methods that make it more convenient to write vector +// instruction test cases, and provide "harnesses" to test the functionality +// of individual instructions across different lmul values, vstart values, +// and vector length values. + +using ::absl::Span; +using ::mpact::sim::cheriot::CheriotRegister; +using ::mpact::sim::cheriot::CheriotState; +using ::mpact::sim::cheriot::CheriotVectorState; +using ::mpact::sim::cheriot::Vsetvl; +using ::mpact::sim::generic::ImmediateOperand; +using ::mpact::sim::generic::Instruction; +using ::mpact::sim::generic::NarrowType; +using ::mpact::sim::generic::RegisterBase; +using ::mpact::sim::generic::SameSignedType; +using ::mpact::sim::generic::WideType; +using ::mpact::sim::riscv::RiscVState; +using ::mpact::sim::riscv::RV32VectorDestinationOperand; +using ::mpact::sim::riscv::RV32VectorSourceOperand; +using ::mpact::sim::riscv::RVFpRegister; +using ::mpact::sim::riscv::RVVectorRegister; +using ::mpact::sim::util::TaggedFlatDemandMemory; +using ::std::tuple; + +// Constants used in the tests. 
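+// The tests model a 512-bit vector unit. The lmul and sew setting tables
+// below follow the RVV vtype encoding that the tests assemble as
+// (sew_setting << 3) | lmul_setting.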
+constexpr int kVectorLengthInBits = 512; +constexpr int kVectorLengthInBytes = kVectorLengthInBits / 8; +constexpr uint32_t kInstAddress = 0x1000; +constexpr uint32_t kDataLoadAddress = 0x1'0000; +constexpr uint32_t kDataStoreAddress = 0x2'0000; +constexpr char kRs1Name[] = "c1"; +constexpr int kRs1 = 1; +constexpr char kRs2Name[] = "c2"; +constexpr char kRs3Name[] = "c3"; +constexpr char kRdName[] = "c8"; +constexpr int kRd = 8; +constexpr int kVmask = 1; +constexpr char kVmaskName[] = "v1"; +constexpr int kVd = 8; +constexpr char kVdName[] = "v8"; +constexpr int kVs1 = 16; +constexpr char kVs1Name[] = "v16"; +constexpr int kVs2 = 24; +constexpr char kVs2Name[] = "v24"; + +// Setting bits and corresponding values for lmul and sew. +constexpr int kLmulSettings[7] = {0b101, 0b110, 0b111, 0b000, + 0b001, 0b010, 0b011}; +constexpr int kLmul8Values[7] = {1, 2, 4, 8, 16, 32, 64}; +constexpr int kLmulSettingByLogSize[] = {0, 0b101, 0b110, 0b111, + 0b000, 0b001, 0b010, 0b011}; +constexpr int kSewSettings[4] = {0b000, 0b001, 0b010, 0b011}; +constexpr int kSewValues[4] = {1, 2, 4, 8}; +constexpr int kSewSettingsByByteSize[] = {0, 0b000, 0b001, 0, 0b010, + 0, 0, 0, 0b011}; + +// Don't need to set every byte, as only the low bits are used for mask values. +constexpr uint8_t kA5Mask[kVectorLengthInBytes] = { + 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, + 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, + 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, + 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, + 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, + 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, +}; +// This is the base class for vector instruction test fixtures. It implements +// generic methods for testing and supporting testing of the RiscV vector +// instructions. +class RiscVCheriotVectorInstructionsTestBase : public testing::Test { + public: + RiscVCheriotVectorInstructionsTestBase() { + memory_ = new TaggedFlatDemandMemory(8); + state_ = new CheriotState("test", memory_); + rv_vector_ = new CheriotVectorState(state_, kVectorLengthInBytes); + instruction_ = new Instruction(kInstAddress, state_); + instruction_->set_size(4); + child_instruction_ = new Instruction(kInstAddress, state_); + child_instruction_->set_size(4); + // Initialize a portion of memory with a known pattern. + auto *db = state_->db_factory()->Allocate(8192); + auto span = db->Get<uint8_t>(); + for (int i = 0; i < 8192; i++) { + span[i] = i & 0xff; + } + memory_->Store(kDataLoadAddress - 4096, db); + db->DecRef(); + for (int i = 1; i < 32; i++) { + creg_[i] = + state_->GetRegister<CheriotRegister>(absl::StrCat("c", i)).first; + } + for (int i = 1; i < 32; i++) { + freg_[i] = state_->GetRegister<RVFpRegister>(absl::StrCat("f", i)).first; + } + for (int i = 1; i < 32; i++) { + vreg_[i] = + state_->GetRegister<RVVectorRegister>(absl::StrCat("v", i)).first; + } + } + + ~RiscVCheriotVectorInstructionsTestBase() override { + delete state_; + delete rv_vector_; + instruction_->DecRef(); + child_instruction_->DecRef(); + delete memory_; + } + + // Clear the instruction instance and allocate a new one. + void ResetInstruction() { + instruction_->DecRef(); + instruction_ = new Instruction(kInstAddress, state_); + instruction_->set_size(4); + } + + // Creates immediate operands with the values from the vector and appends them + // to the given instruction. 
+  template <typename T>
+  void AppendImmediateOperands(Instruction *inst,
+                               const std::vector<T> &values) {
+    for (auto value : values) {
+      auto *src = new ImmediateOperand<T>(value);
+      inst->AppendSource(src);
+    }
+  }
+
+  // Creates immediate operands with the values from the vector and appends
+  // them to the default instruction.
+  template <typename T>
+  void AppendImmediateOperands(const std::vector<T> &values) {
+    AppendImmediateOperands<T>(instruction_, values);
+  }
+
+  // Creates source and destination scalar register operands for the registers
+  // named in the two vectors and appends them to the given instruction.
+  void AppendRegisterOperands(Instruction *inst,
+                              const std::vector<std::string> &sources,
+                              const std::vector<std::string> &destinations) {
+    for (auto &reg_name : sources) {
+      auto *reg = state_->GetRegister<CheriotRegister>(reg_name).first;
+      inst->AppendSource(reg->CreateSourceOperand());
+    }
+    for (auto &reg_name : destinations) {
+      auto *reg = state_->GetRegister<CheriotRegister>(reg_name).first;
+      inst->AppendDestination(reg->CreateDestinationOperand(0));
+    }
+  }
+
+  // Creates source and destination scalar register operands for the registers
+  // named in the two vectors and appends them to the default instruction.
+  void AppendRegisterOperands(const std::vector<std::string> &sources,
+                              const std::vector<std::string> &destinations) {
+    AppendRegisterOperands(instruction_, sources, destinations);
+  }
+
+  // Returns the value of the named scalar register.
+  template <typename T>
+  T GetRegisterValue(absl::string_view reg_name) {
+    auto *reg = state_->GetRegister<CheriotRegister>(reg_name).first;
+    return reg->data_buffer()->Get<T>();
+  }
+
+  // Sets scalar register values. Takes a vector of tuples of register names
+  // and values, fetches each named register and sets it to the corresponding
+  // value.
+  template <typename T, typename RegisterType = CheriotRegister>
+  void SetRegisterValues(
+      const std::vector<tuple<std::string, const T>> &values) {
+    for (auto &[reg_name, value] : values) {
+      auto *reg = state_->GetRegister<RegisterType>(reg_name).first;
+      auto *db =
+          state_->db_factory()->Allocate<typename RegisterType::ValueType>(1);
+      db->template Set<T>(0, value);
+      reg->SetDataBuffer(db);
+      db->DecRef();
+    }
+  }
+
+  // Creates source and destination vector register operands for the register
+  // groups starting at the given register numbers and appends them to the
+  // given instruction.
+  void AppendVectorRegisterOperands(Instruction *inst,
+                                    const std::vector<int> &sources,
+                                    const std::vector<int> &destinations) {
+    for (auto &reg_no : sources) {
+      std::vector<RegisterBase *> reg_vec;
+      for (int i = 0; (i < 8) && (i + reg_no < 32); i++) {
+        std::string reg_name = absl::StrCat("v", i + reg_no);
+        reg_vec.push_back(
+            state_->GetRegister<RVVectorRegister>(reg_name).first);
+      }
+      auto *op = new RV32VectorSourceOperand(
+          absl::Span<RegisterBase *>(reg_vec), absl::StrCat("v", reg_no));
+      inst->AppendSource(op);
+    }
+    for (auto &reg_no : destinations) {
+      std::vector<RegisterBase *> reg_vec;
+      for (int i = 0; (i < 8) && (i + reg_no < 32); i++) {
+        std::string reg_name = absl::StrCat("v", i + reg_no);
+        reg_vec.push_back(
+            state_->GetRegister<RVVectorRegister>(reg_name).first);
+      }
+      auto *op = new RV32VectorDestinationOperand(
+          absl::Span<RegisterBase *>(reg_vec), 0, absl::StrCat("v", reg_no));
+      inst->AppendDestination(op);
+    }
+  }
+  // Creates source and destination vector register operands for the register
+  // groups starting at the given register numbers and appends them to the
+  // default instruction.
+ void AppendVectorRegisterOperands(const std::vector<int> &sources, + const std::vector<int> &destinations) { + AppendVectorRegisterOperands(instruction_, sources, destinations); + } + + // Returns the value of the named vector register. + template <typename T> + T GetVectorRegisterValue(absl::string_view reg_name) { + auto *reg = state_->GetRegister<RVVectorRegister>(reg_name).first; + return reg->data_buffer()->Get<T>(0); + } + + // Set a vector register value. Takes a vector of tuples of register names and + // spans of values, fetches each register and sets it to the corresponding + // value. + template <typename T> + void SetVectorRegisterValues( + const std::vector<tuple<std::string, Span<const T>>> &values) { + for (auto &[vreg_name, span] : values) { + auto *vreg = state_->GetRegister<RVVectorRegister>(vreg_name).first; + auto *db = state_->db_factory()->MakeCopyOf(vreg->data_buffer()); + db->template Set<T>(span); + vreg->SetDataBuffer(db); + db->DecRef(); + } + } + + // Initializes the semantic function of the instruction object. + void SetSemanticFunction(Instruction *inst, + Instruction::SemanticFunction fcn) { + inst->set_semantic_function(fcn); + } + + // Initializes the semantic function for the default instruction. + void SetSemanticFunction(Instruction::SemanticFunction fcn) { + instruction_->set_semantic_function(fcn); + } + + // Sets the default child instruction as the child of the default instruction. + void SetChildInstruction() { instruction_->AppendChild(child_instruction_); } + + // Initializes the semantic function for the default child instruction. + void SetChildSemanticFunction(Instruction::SemanticFunction fcn) { + child_instruction_->set_semantic_function(fcn); + } + + // Configure the vector unit according to the vtype and vlen values. + void ConfigureVectorUnit(uint32_t vtype, uint32_t vlen) { + Instruction *inst = new Instruction(state_); + AppendImmediateOperands<uint32_t>(inst, {vlen, vtype}); + SetSemanticFunction(inst, absl::bind_front(&Vsetvl, true, false)); + inst->Execute(nullptr); + inst->DecRef(); + } + + // Clear count registers in the register group, starting at start. + void ClearVectorRegisterGroup(int start, int count) { + for (int reg = start; (reg < start + count) && (reg < 32); reg++) { + memset(vreg_[reg]->data_buffer()->raw_ptr(), 0, kVectorLengthInBytes); + } + } + + // Create a random value in the valid range for the type. + template <typename T> + T RandomValue() { + return absl::Uniform(absl::IntervalClosed, bitgen_, + std::numeric_limits<T>::lowest(), + std::numeric_limits<T>::max()); + } + + // Fill the span with random values. + template <typename T> + void FillArrayWithRandomValues(absl::Span<T> span) { + for (auto &val : span) { + val = RandomValue<T>(); + } + } + + // Helper function for testing unary vector-vector instructions. + template <typename Vd, typename Vs2> + void UnaryOpTestHelperV(absl::string_view name, int sew, Instruction *inst, + std::function<Vd(Vs2)> operation) { + int byte_sew = sew / 8; + if (byte_sew != sizeof(Vd) && byte_sew != sizeof(Vs2)) { + FAIL() << name << ": selected element width != any operand types" + << "sew: " << sew << " Vd: " << sizeof(Vd) + << " Vs2: " << sizeof(Vs2); + return; + } + // Number of elements per vector register. + constexpr int vs2_size = kVectorLengthInBytes / sizeof(Vs2); + // Input values for 8 registers. + Vs2 vs2_value[vs2_size * 8]; + auto vs2_span = Span<Vs2>(vs2_value); + AppendVectorRegisterOperands({kVs2, kVmask}, {kVd}); + // Initialize input values. 
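+    // Note on the mask pattern: 0xa5 is 0b1010'0101, so within each group of
+    // eight elements the mask enables elements 0, 2, 5, and 7. Every register
+    // group therefore exercises both active and masked-off lanes.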
+ FillArrayWithRandomValues<Vs2>(vs2_span); + SetVectorRegisterValues<uint8_t>( + {{kVmaskName, Span<const uint8_t>(kA5Mask)}}); + for (int i = 0; i < 8; i++) { + auto vs2_name = absl::StrCat("v", kVs2 + i); + SetVectorRegisterValues<Vs2>( + {{vs2_name, vs2_span.subspan(vs2_size * i, vs2_size)}}); + } + for (int lmul_index = 0; lmul_index < 7; lmul_index++) { + for (int vstart_count = 0; vstart_count < 4; vstart_count++) { + for (int vlen_count = 0; vlen_count < 4; vlen_count++) { + int lmul8 = kLmul8Values[lmul_index]; + int lmul8_vd = lmul8 * sizeof(Vd) / byte_sew; + int lmul8_vs2 = lmul8 * sizeof(Vs2) / byte_sew; + int num_values = lmul8 * kVectorLengthInBytes / (8 * byte_sew); + int vstart = 0; + if (vstart_count > 0) { + vstart = absl::Uniform(absl::IntervalOpen, bitgen_, 0, num_values); + } + // Set vlen, but leave vlen high at least once. + int vlen = 1024; + if (vlen_count > 0) { + vlen = absl::Uniform(absl::IntervalOpenClosed, bitgen_, vstart, + num_values); + } + num_values = std::min(num_values, vlen); + ASSERT_TRUE(vlen > vstart); + // Configure vector unit for different lmul settings. + uint32_t vtype = (kSewSettingsByByteSize[byte_sew] << 3) | + kLmulSettings[lmul_index]; + ConfigureVectorUnit(vtype, vlen); + rv_vector_->set_vstart(vstart); + ClearVectorRegisterGroup(kVd, 8); + + inst->Execute(); + if (lmul8_vd < 1 || lmul8_vd > 64) { + EXPECT_TRUE(rv_vector_->vector_exception()) + << "lmul8: vd: " << lmul8_vd; + rv_vector_->clear_vector_exception(); + continue; + } + if (lmul8_vs2 < 1 || lmul8_vs2 > 64) { + EXPECT_TRUE(rv_vector_->vector_exception()) + << "lmul8: vs2: " << lmul8_vs2; + rv_vector_->clear_vector_exception(); + continue; + } + + EXPECT_FALSE(rv_vector_->vector_exception()); + EXPECT_EQ(rv_vector_->vstart(), 0); + int count = 0; + for (int reg = kVd; reg < kVd + 8; reg++) { + for (int i = 0; i < kVectorLengthInBytes / sizeof(Vd); i++) { + int mask_index = count >> 3; + int mask_offset = count & 0b111; + bool mask_value = + ((kA5Mask[mask_index] >> mask_offset) & 0b1) != 0; + if ((count >= vstart) && mask_value && (count < num_values)) { + EXPECT_EQ(operation(vs2_value[count]), + vreg_[reg]->data_buffer()->Get<Vd>(i)) + << absl::StrCat(name, "[", count, "] != reg[", reg, "][", i, + "] lmul8(", lmul8, ")"); + } else { + EXPECT_EQ(0, vreg_[reg]->data_buffer()->Get<Vd>(i)) + << absl::StrCat(name, " 0 != reg[", reg, "][", i, + "] lmul8(", lmul8, ")"); + } + count++; + } + } + } + } + } + } + + // Helper function for testing vector-vector instructions that use the value + // of the mask bit. + template <typename Vd, typename Vs2, typename Vs1> + void BinaryOpWithMaskTestHelperVV( + absl::string_view name, int sew, Instruction *inst, + std::function<Vd(Vs2, Vs1, bool)> operation) { + int byte_sew = sew / 8; + if (byte_sew != sizeof(Vd) && byte_sew != sizeof(Vs2) && + byte_sew != sizeof(Vs1)) { + FAIL() << name << ": selected element width != any operand types" + << "sew: " << sew << " Vd: " << sizeof(Vd) + << " Vs2: " << sizeof(Vs2) << " Vs1: " << sizeof(Vs1); + return; + } + // Number of elements per vector register. + constexpr int vs2_size = kVectorLengthInBytes / sizeof(Vs2); + constexpr int vs1_size = kVectorLengthInBytes / sizeof(Vs1); + // Input values for 8 registers. + Vs2 vs2_value[vs2_size * 8]; + auto vs2_span = Span<Vs2>(vs2_value); + Vs1 vs1_value[vs1_size * 8]; + auto vs1_span = Span<Vs1>(vs1_value); + AppendVectorRegisterOperands({kVs2, kVs1, kVmask}, {kVd}); + // Initialize input values. 
+ FillArrayWithRandomValues<Vs2>(vs2_span); + FillArrayWithRandomValues<Vs1>(vs1_span); + SetVectorRegisterValues<uint8_t>( + {{kVmaskName, Span<const uint8_t>(kA5Mask)}}); + for (int i = 0; i < 8; i++) { + auto vs1_name = absl::StrCat("v", kVs1 + i); + auto vs2_name = absl::StrCat("v", kVs2 + i); + SetVectorRegisterValues<Vs2>( + {{vs2_name, vs2_span.subspan(vs2_size * i, vs2_size)}}); + SetVectorRegisterValues<Vs1>( + {{vs1_name, vs1_span.subspan(vs1_size * i, vs1_size)}}); + } + // Iterate across the different lmul values. + for (int lmul_index = 0; lmul_index < 7; lmul_index++) { + // Try different vstart values. + for (int vstart_count = 0; vstart_count < 4; vstart_count++) { + for (int vlen_count = 0; vlen_count < 4; vlen_count++) { + int lmul8 = kLmul8Values[lmul_index]; + int lmul8_vd = lmul8 * sizeof(Vd) / byte_sew; + int lmul8_vs2 = lmul8 * sizeof(Vs2) / byte_sew; + int lmul8_vs1 = lmul8 * sizeof(Vs1) / byte_sew; + int num_values = lmul8 * kVectorLengthInBytes / (8 * byte_sew); + int vstart = 0; + if (vstart_count > 0) { + vstart = absl::Uniform(absl::IntervalOpen, bitgen_, 0, num_values); + } + // Set vlen, but leave vlen high at least once. + int vlen = 1024; + if (vlen_count > 0) { + vlen = absl::Uniform(absl::IntervalOpenClosed, bitgen_, vstart, + num_values); + } + num_values = std::min(num_values, vlen); + // Configure vector unit for different lmul settings. + uint32_t vtype = (kSewSettingsByByteSize[byte_sew] << 3) | + kLmulSettings[lmul_index]; + ConfigureVectorUnit(vtype, vlen); + rv_vector_->set_vstart(vstart); + ClearVectorRegisterGroup(kVd, 8); + + inst->Execute(); + + if ((std::min(std::min(lmul8_vs2, lmul8_vs1), lmul8_vd) < 1) || + (std::max(std::max(lmul8_vs2, lmul8_vs1), lmul8_vd) > 64)) { + EXPECT_TRUE(rv_vector_->vector_exception()); + rv_vector_->clear_vector_exception(); + continue; + } + + EXPECT_FALSE(rv_vector_->vector_exception()); + EXPECT_EQ(rv_vector_->vstart(), 0); + int count = 0; + for (int reg = kVd; reg < kVd + 8; reg++) { + for (int i = 0; i < kVectorLengthInBytes / sizeof(Vd); i++) { + int mask_index = count >> 3; + int mask_offset = count & 0b111; + bool mask_value = + ((kA5Mask[mask_index] >> mask_offset) & 0b1) != 0; + if ((count >= vstart) && (count < num_values)) { + EXPECT_EQ( + operation(vs2_value[count], vs1_value[count], mask_value), + vreg_[reg]->data_buffer()->Get<Vd>(i)) + << std::hex << (int64_t)vs2_value[count] << ", " + << (int64_t)vs1_value[count] << " " << std::dec + << (int64_t)vs2_value[count] << ", " + << (int64_t)vs1_value[count] << " " + << absl::StrCat(name, "[", count, "] != reg[", reg, "][", i, + "] lmul8(", lmul8, ") vstart(", vstart, + ")"); + } else { + EXPECT_EQ(0, vreg_[reg]->data_buffer()->Get<Vd>(i)) + << absl::StrCat(name, " 0 != reg[", reg, "][", i, + "] lmul8(", lmul8, ") vstart(", vstart, + ")"); + } + count++; + } + } + if (HasFailure()) return; + } + } + } + } + + // Helper function for testing vector-vector instructions that do not + // use the value of the mask bit. + template <typename Vd, typename Vs2, typename Vs1> + void BinaryOpTestHelperVV(absl::string_view name, int sew, Instruction *inst, + std::function<Vd(Vs2, Vs1)> operation) { + BinaryOpWithMaskTestHelperVV<Vd, Vs2, Vs1>( + name, sew, inst, [operation](Vs2 vs2, Vs1 vs1, bool mask_value) -> Vd { + if (mask_value) { + return operation(vs2, vs1); + } + return 0; + }); + } + + // Helper function for testing vector-scalar/immediate instructions that use + // the value of the mask bit. 
+ template <typename Vd, typename Vs2, typename Rs1> + void BinaryOpWithMaskTestHelperVX( + absl::string_view name, int sew, Instruction *inst, + std::function<Vd(Vs2, Rs1, bool)> operation) { + int byte_sew = sew / 8; + if (byte_sew != sizeof(Vd) && byte_sew != sizeof(Vs2) && + byte_sew != sizeof(Rs1)) { + FAIL() << name << ": selected element width != any operand types" + << "sew: " << sew << " Vd: " << sizeof(Vd) + << " Vs2: " << sizeof(Vs2) << " Rs1: " << sizeof(Rs1); + return; + } + // Number of elements per vector register. + constexpr int vs2_size = kVectorLengthInBytes / sizeof(Vs2); + // Input values for 8 registers. + Vs2 vs2_value[vs2_size * 8]; + auto vs2_span = Span<Vs2>(vs2_value); + AppendVectorRegisterOperands({kVs2}, {}); + AppendRegisterOperands({kRs1Name}, {}); + AppendVectorRegisterOperands({kVmask}, {kVd}); + // Initialize input values. + FillArrayWithRandomValues<Vs2>(vs2_span); + SetVectorRegisterValues<uint8_t>( + {{kVmaskName, Span<const uint8_t>(kA5Mask)}}); + for (int i = 0; i < 8; i++) { + auto vs2_name = absl::StrCat("v", kVs2 + i); + SetVectorRegisterValues<Vs2>( + {{vs2_name, vs2_span.subspan(vs2_size * i, vs2_size)}}); + } + // Iterate across the different lmul values. + for (int lmul_index = 0; lmul_index < 7; lmul_index++) { + // Try different vstart values. + for (int vstart_count = 0; vstart_count < 4; vstart_count++) { + for (int vlen_count = 0; vlen_count < 4; vlen_count++) { + int lmul8 = kLmul8Values[lmul_index]; + int lmul8_vd = lmul8 * sizeof(Vd) / byte_sew; + int lmul8_vs2 = lmul8 * sizeof(Vs2) / byte_sew; + int num_values = lmul8 * kVectorLengthInBytes / (8 * byte_sew); + // Set vstart, but leave vstart at 0 at least once. + int vstart = 0; + if (vstart_count > 0) { + vstart = absl::Uniform(absl::IntervalOpen, bitgen_, 0, num_values); + } + // Set vlen, but leave vlen high at least once. + int vlen = 1024; + if (vlen_count > 0) { + vlen = absl::Uniform(absl::IntervalOpenClosed, bitgen_, vstart, + num_values); + } + num_values = std::min(num_values, vlen); + ASSERT_TRUE(vlen > vstart); + // Configure vector unit for the different lmul settings. + uint32_t vtype = (kSewSettingsByByteSize[byte_sew] << 3) | + kLmulSettings[lmul_index]; + ConfigureVectorUnit(vtype, vlen); + rv_vector_->set_vstart(vstart); + ClearVectorRegisterGroup(kVd, 8); + + // Generate a new rs1 value. + CheriotRegister::ValueType rs1_reg_value = + RandomValue<CheriotRegister::ValueType>(); + SetRegisterValues<CheriotRegister::ValueType>( + {{kRs1Name, rs1_reg_value}}); + // Cast the value to the appropriate width, sign-extending if need + // be. + Rs1 rs1_value = static_cast<Rs1>( + static_cast<typename SameSignedType<CheriotRegister::ValueType, + Rs1>::type>(rs1_reg_value)); + + inst->Execute(); + if ((std::min(lmul8_vs2, lmul8_vd) < 1) || + (std::max(lmul8_vs2, lmul8_vd) > 64)) { + EXPECT_TRUE(rv_vector_->vector_exception()); + rv_vector_->clear_vector_exception(); + continue; + } + + EXPECT_FALSE(rv_vector_->vector_exception()); + EXPECT_EQ(rv_vector_->vstart(), 0); + int count = 0; + for (int reg = kVd; reg < kVd + 8; reg++) { + for (int i = 0; i < kVectorLengthInBytes / sizeof(Vd); i++) { + int mask_index = count >> 3; + int mask_offset = count & 0b111; + bool mask_value = + ((kA5Mask[mask_index] >> mask_offset) & 0b1) != 0; + // Compare elements that are between vstart and vlen for which + // the mask is true. 
+ if ((count >= vstart) && (count < num_values)) { + Vd expected_value = operation( + vs2_value[count], static_cast<Rs1>(rs1_value), mask_value); + Vd inst_value = vreg_[reg]->data_buffer()->Get<Vd>(i); + EXPECT_EQ(expected_value, inst_value) << absl::StrCat( + name, " [", count, "] != reg[", reg, "][", i, "] lmul8(", + lmul8, ") op(", absl::Hex(vs2_value[count]), ", ", + absl::Hex(static_cast<Rs1>(rs1_value)), + ") vreg: ", absl::Hex(inst_value)); + } else { + // The others should be zero. + EXPECT_EQ(0, vreg_[reg]->data_buffer()->Get<Vd>(i)) + << absl::StrCat(name, " 0 != reg[", reg, "][", i, + "] lmul8(", lmul8, ")"); + } + count++; + } + } + if (HasFailure()) return; + } + } + } + } + + // Templated helper function that tests vector-scalar instructions that do + // not use the value of the mask bit. + template <typename Vd, typename Vs2, typename Vs1> + void BinaryOpTestHelperVX(absl::string_view name, int sew, Instruction *inst, + std::function<Vd(Vs2, Vs1)> operation) { + BinaryOpWithMaskTestHelperVX<Vd, Vs2, Vs1>( + name, sew, inst, [operation](Vs2 vs2, Vs1 vs1, bool mask_value) -> Vd { + if (mask_value) { + return operation(vs2, vs1); + } + return 0; + }); + } + + // Helper function for testing vector-vector instructions that use the value + // of the mask bit. + template <typename Vd, typename Vs2, typename Vs1> + void TernaryOpWithMaskTestHelperVV( + absl::string_view name, int sew, Instruction *inst, + std::function<Vd(Vs2, Vs1, Vd, bool)> operation) { + int byte_sew = sew / 8; + if (byte_sew != sizeof(Vd) && byte_sew != sizeof(Vs2) && + byte_sew != sizeof(Vs1)) { + FAIL() << name << ": selected element width != any operand types" + << "sew: " << sew << " Vd: " << sizeof(Vd) + << " Vs2: " << sizeof(Vs2) << " Vs1: " << sizeof(Vs1); + return; + } + // Number of elements per vector register. + constexpr int vd_size = kVectorLengthInBytes / sizeof(Vd); + constexpr int vs2_size = kVectorLengthInBytes / sizeof(Vs2); + constexpr int vs1_size = kVectorLengthInBytes / sizeof(Vs1); + // Input values for 8 registers. + Vd vd_value[vd_size * 8]; + auto vd_span = Span<Vd>(vd_value); + Vs2 vs2_value[vs2_size * 8]; + auto vs2_span = Span<Vs2>(vs2_value); + Vs1 vs1_value[vs1_size * 8]; + auto vs1_span = Span<Vs1>(vs1_value); + AppendVectorRegisterOperands({kVs2, kVs1, kVd, kVmask}, {kVd}); + // Initialize input values. + FillArrayWithRandomValues<Vd>(vd_span); + FillArrayWithRandomValues<Vs2>(vs2_span); + FillArrayWithRandomValues<Vs1>(vs1_span); + SetVectorRegisterValues<uint8_t>( + {{kVmaskName, Span<const uint8_t>(kA5Mask)}}); + for (int i = 0; i < 8; i++) { + auto vs1_name = absl::StrCat("v", kVs1 + i); + auto vs2_name = absl::StrCat("v", kVs2 + i); + SetVectorRegisterValues<Vs2>( + {{vs2_name, vs2_span.subspan(vs2_size * i, vs2_size)}}); + SetVectorRegisterValues<Vs1>( + {{vs1_name, vs1_span.subspan(vs1_size * i, vs1_size)}}); + } + // Iterate across the different lmul values. + for (int lmul_index = 0; lmul_index < 7; lmul_index++) { + // Try different vstart values. 
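+      // The first iteration of each inner loop keeps vstart at 0 and vlen
+      // above the maximum element count, so the unmasked full-length case is
+      // always covered; later iterations draw random vstart values in
+      // (0, num_values) and random vlen values in (vstart, num_values].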
+ for (int vstart_count = 0; vstart_count < 4; vstart_count++) { + for (int vlen_count = 0; vlen_count < 4; vlen_count++) { + int lmul8 = kLmul8Values[lmul_index]; + int lmul8_vd = lmul8 * sizeof(Vd) / byte_sew; + int lmul8_vs2 = lmul8 * sizeof(Vs2) / byte_sew; + int lmul8_vs1 = lmul8 * sizeof(Vs1) / byte_sew; + int num_values = lmul8 * kVectorLengthInBytes / (8 * byte_sew); + int vstart = 0; + if (vstart_count > 0) { + vstart = absl::Uniform(absl::IntervalOpen, bitgen_, 0, num_values); + } + // Set vlen, but leave vlen high at least once. + int vlen = 1024; + if (vlen_count > 0) { + vlen = absl::Uniform(absl::IntervalOpenClosed, bitgen_, vstart, + num_values); + } + num_values = std::min(num_values, vlen); + // Configure vector unit for different lmul settings. + uint32_t vtype = (kSewSettingsByByteSize[byte_sew] << 3) | + kLmulSettings[lmul_index]; + ConfigureVectorUnit(vtype, vlen); + rv_vector_->set_vstart(vstart); + + // Reset Vd values, since the previous instruction execution + // overwrites them. + for (int i = 0; i < 8; i++) { + auto vd_name = absl::StrCat("v", kVd + i); + SetVectorRegisterValues<Vd>( + {{vd_name, vd_span.subspan(vd_size * i, vd_size)}}); + } + + inst->Execute(); + + if ((std::min(std::min(lmul8_vs2, lmul8_vs1), lmul8_vd) < 1) || + (std::max(std::max(lmul8_vs2, lmul8_vs1), lmul8_vd) > 64)) { + EXPECT_TRUE(rv_vector_->vector_exception()); + rv_vector_->clear_vector_exception(); + continue; + } + + EXPECT_FALSE(rv_vector_->vector_exception()); + int count = 0; + int reg_offset = count * byte_sew / kVectorLengthInBytes; + for (int reg = kVd + reg_offset; reg < kVd + 8; reg++) { + for (int i = 0; i < kVectorLengthInBytes / sizeof(Vd); i++) { + int mask_index = count >> 3; + int mask_offset = count & 0b111; + bool mask_value = + ((kA5Mask[mask_index] >> mask_offset) & 0b1) != 0; + if ((count >= vstart) && (count < num_values)) { + EXPECT_EQ(operation(vs2_value[count], vs1_value[count], + vd_value[count], mask_value), + vreg_[reg]->data_buffer()->Get<Vd>(i)) + << "mask: " << mask_value << " (" << std::hex + << (int64_t)vs2_value[count] << ", " + << (int64_t)vs1_value[count] << ") (" << std::dec + << (int64_t)vs2_value[count] << ", " + << (int64_t)vs1_value[count] << ", " + << (int64_t)vd_value[count] << ") " + << absl::StrCat(name, "[", count, "] != reg[", reg, "][", i, + "] lmul8(", lmul8, ") vstart(", vstart, + ")"); + } else { + EXPECT_EQ(vd_value[count], + vreg_[reg]->data_buffer()->Get<Vd>(i)) + << absl::StrCat(name, " 0 != reg[", reg, "][", i, + "] lmul8(", lmul8, ") vstart(", vstart, + ")"); + } + count++; + } + } + if (HasFailure()) return; + } + } + } + } + + // Helper function for testing vector-vector instructions that do not + // use the value of the mask bit. + template <typename Vd, typename Vs2, typename Vs1> + void TernaryOpTestHelperVV(absl::string_view name, int sew, Instruction *inst, + std::function<Vd(Vs2, Vs1, Vd)> operation) { + TernaryOpWithMaskTestHelperVV<Vd, Vs2, Vs1>( + name, sew, inst, + [operation](Vs2 vs2, Vs1 vs1, Vd vd, bool mask_value) -> Vd { + if (mask_value) { + return operation(vs2, vs1, vd); + } + return vd; + }); + } + + // Helper function for testing vector-scalar/immediate instructions that use + // the value of the mask bit. 
+ template <typename Vd, typename Vs2, typename Rs1> + void TernaryOpWithMaskTestHelperVX( + absl::string_view name, int sew, Instruction *inst, + std::function<Vd(Vs2, Rs1, Vd, bool)> operation) { + int byte_sew = sew / 8; + if (byte_sew != sizeof(Vd) && byte_sew != sizeof(Vs2) && + byte_sew != sizeof(Rs1)) { + FAIL() << name << ": selected element width != any operand types" + << "sew: " << sew << " Vd: " << sizeof(Vd) + << " Vs2: " << sizeof(Vs2) << " Rs1: " << sizeof(Rs1); + return; + } + // Number of elements per vector register. + constexpr int vd_size = kVectorLengthInBytes / sizeof(Vd); + constexpr int vs2_size = kVectorLengthInBytes / sizeof(Vs2); + // Input values for 8 registers. + Vd vd_value[vd_size * 8]; + auto vd_span = Span<Vd>(vd_value); + Vs2 vs2_value[vs2_size * 8]; + auto vs2_span = Span<Vs2>(vs2_value); + AppendVectorRegisterOperands({kVs2}, {}); + AppendRegisterOperands({kRs1Name}, {}); + AppendVectorRegisterOperands({kVd, kVmask}, {kVd}); + // Initialize input values. + FillArrayWithRandomValues<Vd>(vd_span); + FillArrayWithRandomValues<Vs2>(vs2_span); + SetVectorRegisterValues<uint8_t>( + {{kVmaskName, Span<const uint8_t>(kA5Mask)}}); + for (int i = 0; i < 8; i++) { + auto vs2_name = absl::StrCat("v", kVs2 + i); + SetVectorRegisterValues<Vs2>( + {{vs2_name, vs2_span.subspan(vs2_size * i, vs2_size)}}); + } + // Iterate across the different lmul values. + for (int lmul_index = 0; lmul_index < 7; lmul_index++) { + // Try different vstart values. + for (int vstart_count = 0; vstart_count < 4; vstart_count++) { + for (int vlen_count = 0; vlen_count < 4; vlen_count++) { + int lmul8 = kLmul8Values[lmul_index]; + int lmul8_vd = lmul8 * sizeof(Vd) / byte_sew; + int lmul8_vs2 = lmul8 * sizeof(Vs2) / byte_sew; + int num_values = lmul8 * kVectorLengthInBytes / (8 * byte_sew); + // Set vstart, but leave vstart at 0 at least once. + int vstart = 0; + if (vstart_count > 0) { + vstart = absl::Uniform(absl::IntervalOpen, bitgen_, 0, num_values); + } + // Set vlen, but leave vlen high at least once. + int vlen = 1024; + if (vlen_count > 0) { + vlen = absl::Uniform(absl::IntervalOpenClosed, bitgen_, vstart, + num_values); + } + num_values = std::min(num_values, vlen); + ASSERT_TRUE(vlen > vstart); + // Configure vector unit for the different lmul settings. + uint32_t vtype = (kSewSettingsByByteSize[byte_sew] << 3) | + kLmulSettings[lmul_index]; + ConfigureVectorUnit(vtype, vlen); + rv_vector_->set_vstart(vstart); + + // Reset Vd values, since the previous instruction execution + // overwrites them. + for (int i = 0; i < 8; i++) { + auto vd_name = absl::StrCat("v", kVd + i); + SetVectorRegisterValues<Vd>( + {{vd_name, vd_span.subspan(vd_size * i, vd_size)}}); + } + + // Generate a new rs1 value. + CheriotRegister::ValueType rs1_reg_value = + RandomValue<CheriotRegister::ValueType>(); + SetRegisterValues<CheriotRegister::ValueType>( + {{kRs1Name, rs1_reg_value}}); + // Cast the value to the appropriate width, sign-extending if need + // be. 
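+          // For example, with Rs1 = int8_t a register value of 0x0000'01ff
+          // becomes int32_t 0x1ff and is then narrowed to int8_t -1, so
+          // widening it again inside a semantic function reproduces the
+          // sign-extended 32-bit value.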
+ Rs1 rs1_value = static_cast<Rs1>( + static_cast<typename SameSignedType<CheriotRegister::ValueType, + Rs1>::type>(rs1_reg_value)); + + inst->Execute(); + if ((std::min(lmul8_vs2, lmul8_vd) < 1) || + (std::max(lmul8_vs2, lmul8_vd) > 64)) { + EXPECT_TRUE(rv_vector_->vector_exception()); + rv_vector_->clear_vector_exception(); + continue; + } + + EXPECT_FALSE(rv_vector_->vector_exception()); + EXPECT_EQ(rv_vector_->vstart(), 0); + int count = 0; + for (int reg = kVd; reg < kVd + 8; reg++) { + for (int i = 0; i < kVectorLengthInBytes / sizeof(Vd); i++) { + int mask_index = count >> 3; + int mask_offset = count & 0b111; + bool mask_value = + ((kA5Mask[mask_index] >> mask_offset) & 0b1) != 0; + // Compare elements that are between vstart and vlen for which + // the mask is true. + if ((count >= vstart) && (count < num_values)) { + Vd expected_value = + operation(vs2_value[count], static_cast<Rs1>(rs1_value), + vd_value[count], mask_value); + Vd inst_value = vreg_[reg]->data_buffer()->Get<Vd>(i); + EXPECT_EQ(expected_value, inst_value) + << "mask: " << mask_value << " (" << std::hex + << (int64_t)vs2_value[count] << ", " << (int64_t)rs1_value + << ") (" << std::dec << (int64_t)vs2_value[count] << ", " + << (int64_t)rs1_value << ", " << (int64_t)vd_value[count] + << ") " + << absl::StrCat(name, "[", count, "] != reg[", reg, "][", i, + "] lmul8(", lmul8, ") vstart(", vstart, + ")"); + } else { + // The others should be zero. + EXPECT_EQ(vd_span[count], vreg_[reg]->data_buffer()->Get<Vd>(i)) + << absl::StrCat(name, " 0 != reg[", reg, "][", i, + "] lmul8(", lmul8, ")"); + } + count++; + } + } + if (HasFailure()) return; + } + } + } + } + + // Templated helper function that tests vector-scalar instructions that do + // not use the value of the mask bit. + template <typename Vd, typename Vs2, typename Rs1> + void TernaryOpTestHelperVX(absl::string_view name, int sew, Instruction *inst, + std::function<Vd(Vs2, Rs1, Vd)> operation) { + TernaryOpWithMaskTestHelperVX<Vd, Vs2, Rs1>( + name, sew, inst, + [operation](Vs2 vs2, Rs1 rs1, Vd vd, bool mask_value) -> Vd { + if (mask_value) { + return operation(vs2, rs1, vd); + } + return vd; + }); + } + + // Helper function for testing binary mask vector-vector instructions that + // use the mask bit. + template <typename Vs2, typename Vs1> + void BinaryMaskOpWithMaskTestHelperVV( + absl::string_view name, int sew, Instruction *inst, + std::function<uint8_t(Vs2, Vs1, bool)> operation) { + int byte_sew = sew / 8; + if (byte_sew != sizeof(Vs2) && byte_sew != sizeof(Vs1)) { + FAIL() << name << ": selected element width != any operand types" + << "sew: " << sew << " Vs2: " << sizeof(Vs2) + << " Vs1: " << sizeof(Vs1); + return; + } + // Number of elements per vector register. + constexpr int vs2_size = kVectorLengthInBytes / sizeof(Vs2); + constexpr int vs1_size = kVectorLengthInBytes / sizeof(Vs1); + // Input values for 8 registers. + Vs2 vs2_value[vs2_size * 8]; + auto vs2_span = Span<Vs2>(vs2_value); + Vs1 vs1_value[vs1_size * 8]; + auto vs1_span = Span<Vs1>(vs1_value); + AppendVectorRegisterOperands({kVs2, kVs1, kVmask}, {kVd}); + // Initialize input values. + FillArrayWithRandomValues<Vs2>(vs2_span); + FillArrayWithRandomValues<Vs1>(vs1_span); + // Make every third value the same (at least if the types are same sized). 
+ for (int i = 0; i < std::min(vs1_size, vs2_size); i += 3) { + vs1_span[i] = static_cast<Vs1>(vs2_span[i]); + } + + SetVectorRegisterValues<uint8_t>( + {{kVmaskName, Span<const uint8_t>(kA5Mask)}}); + for (int i = 0; i < 8; i++) { + auto vs2_name = absl::StrCat("v", kVs2 + i); + SetVectorRegisterValues<Vs2>( + {{vs2_name, vs2_span.subspan(vs2_size * i, vs2_size)}}); + auto vs1_name = absl::StrCat("v", kVs1 + i); + SetVectorRegisterValues<Vs1>( + {{vs1_name, vs1_span.subspan(vs1_size * i, vs1_size)}}); + } + for (int lmul_index = 0; lmul_index < 7; lmul_index++) { + for (int vstart_count = 0; vstart_count < 4; vstart_count++) { + for (int vlen_count = 0; vlen_count < 4; vlen_count++) { + ClearVectorRegisterGroup(kVd, 8); + int lmul8 = kLmul8Values[lmul_index]; + int lmul8_vs2 = lmul8 * sizeof(Vs2) / byte_sew; + int num_values = lmul8 * kVectorLengthInBytes / (8 * byte_sew); + int vstart = 0; + if (vstart_count > 0) { + vstart = absl::Uniform(absl::IntervalOpen, bitgen_, 0, num_values); + } + // Set vlen, but leave vlen high at least once. + int vlen = 1024; + if (vlen_count > 0) { + vlen = absl::Uniform(absl::IntervalOpenClosed, bitgen_, vstart, + num_values); + } + num_values = std::min(num_values, vlen); + ASSERT_TRUE(vlen > vstart); + // Configure vector unit for different lmul settings. + uint32_t vtype = (kSewSettingsByByteSize[byte_sew] << 3) | + kLmulSettings[lmul_index]; + ConfigureVectorUnit(vtype, vlen); + rv_vector_->set_vstart(vstart); + + inst->Execute(); + if ((lmul8_vs2 < 1) || (lmul8_vs2 > 64)) { + EXPECT_TRUE(rv_vector_->vector_exception()); + rv_vector_->clear_vector_exception(); + continue; + } + + EXPECT_FALSE(rv_vector_->vector_exception()); + EXPECT_EQ(rv_vector_->vstart(), 0); + auto dest_span = vreg_[kVd]->data_buffer()->Get<uint8_t>(); + for (int i = 0; i < kVectorLengthInBytes * 8; i++) { + int mask_index = i >> 3; + int mask_offset = i & 0b111; + bool mask_value = ((kA5Mask[mask_index] >> mask_offset) & 0b1) != 0; + uint8_t inst_value = dest_span[i >> 3]; + inst_value = (inst_value >> mask_offset) & 0b1; + if ((i >= vstart) && (i < num_values)) { + uint8_t expected_value = + operation(vs2_value[i], vs1_value[i], mask_value); + EXPECT_EQ(expected_value, inst_value) << absl::StrCat( + name, "[", i, "] != reg[][", i, "] lmul8(", lmul8, + ") vstart(", vstart, ") num_values(", num_values, ")"); + } else { + EXPECT_EQ(0, inst_value) << absl::StrCat( + name, "[", i, "] 0 != reg[][", i, "] lmul8(", lmul8, + ") vstart(", vstart, ") num_values(", num_values, ")"); + } + } + if (HasFailure()) return; + } + } + } + } + + // Helper function for testing binary mask vector-vector instructions that do + // not use the mask bit. + template <typename Vs2, typename Vs1> + void BinaryMaskOpTestHelperVV(absl::string_view name, int sew, + Instruction *inst, + std::function<uint8_t(Vs2, Vs1)> operation) { + BinaryMaskOpWithMaskTestHelperVV<Vs2, Vs1>( + name, sew, inst, + [operation](Vs2 vs2, Vs1 vs1, bool mask_value) -> uint8_t { + if (mask_value) { + return operation(vs2, vs1); + } + return 0; + }); + } + + // Helper function for testing mask vector-scalar/immediate instructions that + // use the mask bit. 
+ template <typename Vs2, typename Rs1> + void BinaryMaskOpWithMaskTestHelperVX( + absl::string_view name, int sew, Instruction *inst, + std::function<uint8_t(Vs2, Rs1, bool)> operation) { + int byte_sew = sew / 8; + if (byte_sew != sizeof(Vs2) && byte_sew != sizeof(Rs1)) { + FAIL() << name << ": selected element width != any operand types" + << "sew: " << sew << " Vs2: " << sizeof(Vs2) + << " Rs1: " << sizeof(Rs1); + return; + } + // Number of elements per vector register. + constexpr int vs2_size = kVectorLengthInBytes / sizeof(Vs2); + // Input values for 8 registers. + Vs2 vs2_value[vs2_size * 8]; + auto vs2_span = Span<Vs2>(vs2_value); + AppendVectorRegisterOperands({kVs2}, {}); + AppendRegisterOperands({kRs1Name}, {}); + AppendVectorRegisterOperands({kVmask}, {kVd}); + // Initialize input values. + FillArrayWithRandomValues<Vs2>(vs2_span); + SetVectorRegisterValues<uint8_t>( + {{kVmaskName, Span<const uint8_t>(kA5Mask)}}); + for (int i = 0; i < 8; i++) { + auto vs2_name = absl::StrCat("v", kVs2 + i); + SetVectorRegisterValues<Vs2>( + {{vs2_name, vs2_span.subspan(vs2_size * i, vs2_size)}}); + } + for (int lmul_index = 0; lmul_index < 7; lmul_index++) { + for (int vstart_count = 0; vstart_count < 4; vstart_count++) { + for (int vlen_count = 0; vlen_count < 4; vlen_count++) { + ClearVectorRegisterGroup(kVd, 8); + int lmul8 = kLmul8Values[lmul_index]; + int lmul8_vs2 = lmul8 * sizeof(Vs2) / byte_sew; + int num_values = lmul8 * kVectorLengthInBytes / (8 * byte_sew); + int vstart = 0; + if (vstart_count > 0) { + vstart = absl::Uniform(absl::IntervalOpen, bitgen_, 0, num_values); + } + // Set vlen, but leave vlen high at least once. + int vlen = 1024; + if (vlen_count > 0) { + vlen = absl::Uniform(absl::IntervalOpenClosed, bitgen_, vstart, + num_values); + } + num_values = std::min(num_values, vlen); + ASSERT_TRUE(vlen > vstart); + // Configure vector unit for different lmul settings. + uint32_t vtype = (kSewSettingsByByteSize[byte_sew] << 3) | + kLmulSettings[lmul_index]; + ConfigureVectorUnit(vtype, vlen); + rv_vector_->set_vstart(vstart); + + // Generate a new rs1 value. + CheriotRegister::ValueType rs1_reg_value = + RandomValue<CheriotRegister::ValueType>(); + SetRegisterValues<CheriotRegister::ValueType>( + {{kRs1Name, rs1_reg_value}}); + // Cast the value to the appropriate width, sign-extending if need be. 
+          Rs1 rs1_value = static_cast<Rs1>(
+              static_cast<typename SameSignedType<CheriotRegister::ValueType,
+                                                  Rs1>::type>(rs1_reg_value));
+          inst->Execute();
+          if ((lmul8_vs2 < 1) || (lmul8_vs2 > 64)) {
+            EXPECT_TRUE(rv_vector_->vector_exception());
+            rv_vector_->clear_vector_exception();
+            continue;
+          }
+
+          EXPECT_FALSE(rv_vector_->vector_exception());
+          EXPECT_EQ(rv_vector_->vstart(), 0);
+          auto dest_span = vreg_[kVd]->data_buffer()->Get<uint8_t>();
+          for (int i = 0; i < kVectorLengthInBytes * 8; i++) {
+            int mask_index = i >> 3;
+            int mask_offset = i & 0b111;
+            bool mask_value = ((kA5Mask[mask_index] >> mask_offset) & 0b1) != 0;
+            uint8_t inst_value = dest_span[i >> 3];
+            inst_value = (inst_value >> mask_offset) & 0b1;
+            if ((i >= vstart) && (i < num_values)) {
+              uint8_t expected_value =
+                  operation(vs2_value[i], rs1_value, mask_value);
+              EXPECT_EQ(expected_value, inst_value) << absl::StrCat(
+                  name, "[", i, "] != reg[0][", i, "] lmul8(", lmul8, ")");
+            } else {
+              EXPECT_EQ(0, inst_value) << absl::StrCat(
+                  name, " 0 != reg[0][", i, "] lmul8(", lmul8, ")");
+            }
+          }
+          if (HasFailure()) return;
+        }
+      }
+    }
+  }
+
+  // Helper function for testing mask vector-scalar/immediate instructions
+  // that do not use the mask bit.
+  template <typename Vs2, typename Rs1>
+  void BinaryMaskOpTestHelperVX(absl::string_view name, int sew,
+                                Instruction *inst,
+                                std::function<uint8_t(Vs2, Rs1)> operation) {
+    BinaryMaskOpWithMaskTestHelperVX<Vs2, Rs1>(
+        name, sew, inst,
+        [operation](Vs2 vs2, Rs1 rs1, bool mask_value) -> uint8_t {
+          if (mask_value) {
+            return operation(vs2, rs1);
+          }
+          return 0;
+        });
+  }
+
+  // Helper function to compute the rounding increment for the current vxrm
+  // rounding mode, given the number of bits shifted out and their value.
+  template <typename T>
+  T RoundBits(int num_bits, T lost_bits) {
+    bool bit_d =
+        (num_bits == 0) ? false : ((lost_bits >> (num_bits - 1)) & 0b1) != 0;
+    bool bit_d_minus_1 =
+        (num_bits < 2) ? false : ((lost_bits >> (num_bits - 2)) & 0b1) != 0;
+    bool bits_d_minus_2_to_0 =
+        (num_bits < 3) ? false
+                       : (lost_bits & ~(std::numeric_limits<uint64_t>::max()
+                                        << (num_bits - 2))) != 0;
+    bool bits_d_minus_1_to_0 =
+        (num_bits < 2) ? false
+                       : (lost_bits & ~(std::numeric_limits<uint64_t>::max()
+                                        << (num_bits - 1))) != 0;
+    switch (rv_vector_->vxrm()) {
+      case 0:  // Round to nearest up (rnu).
+        return bit_d_minus_1;
+      case 1:  // Round to nearest even (rne).
+        return bit_d_minus_1 & (bits_d_minus_2_to_0 | bit_d);
+      case 2:  // Round down / truncate (rdn).
+        return 0;
+      case 3:  // Round to odd (rod).
+        return !bit_d & bits_d_minus_1_to_0;
+      default:
+        return 0;
+    }
+  }
+
+  CheriotVectorState *rv_vector() const { return rv_vector_; }
+  absl::Span<RVVectorRegister *> vreg() {
+    return absl::Span<RVVectorRegister *>(vreg_);
+  }
+  absl::Span<CheriotRegister *> creg() {
+    return absl::Span<CheriotRegister *>(creg_);
+  }
+  absl::BitGen &bitgen() { return bitgen_; }
+  Instruction *instruction() { return instruction_; }
+
+ protected:
+  CheriotRegister *creg_[32];
+  RVVectorRegister *vreg_[32];
+  RVFpRegister *freg_[32];
+  CheriotState *state_;
+  Instruction *instruction_;
+  Instruction *child_instruction_;
+  TaggedFlatDemandMemory *memory_;
+  CheriotVectorState *rv_vector_;
+  absl::BitGen bitgen_;
+};
+
+#endif  // MPACT_RISCV_RISCV_TEST_RISCV_VECTOR_INSTRUCTIONS_TEST_BASE_H_
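A test built on this fixture typically just binds the semantic function under test to the default instruction and delegates the heavy lifting to one of the helpers above. A minimal sketch of such a test follows; the derived fixture name and the Vadd semantic function are illustrative assumptions, not names defined in this change:

// Hypothetical per-instruction test; `Vadd` stands in for whatever semantic
// function the instruction under test dispatches to.
class RiscVCheriotVectorOpTest
    : public RiscVCheriotVectorInstructionsTestBase {};

TEST_F(RiscVCheriotVectorOpTest, Vadd8) {
  // The helper appends the vs2/vs1/mask/vd operands itself, so only the
  // semantic function has to be bound here.
  SetSemanticFunction(absl::bind_front(&Vadd));
  // Verify sew == 8 behavior across all lmul, vstart, and vlen combinations
  // against the scalar reference computation.
  BinaryOpTestHelperVV<uint8_t, uint8_t, uint8_t>(
      "Vadd8", /*sew=*/8, instruction(),
      [](uint8_t vs2, uint8_t vs1) -> uint8_t { return vs2 + vs1; });
}

The helper executes the bound instruction once per lmul/vstart/vlen combination and checks active, masked-off, and tail elements separately.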
diff --git a/cheriot/test/riscv_cheriot_vector_memory_instructions_test.cc b/cheriot/test/riscv_cheriot_vector_memory_instructions_test.cc new file mode 100644 index 0000000..17770bb --- /dev/null +++ b/cheriot/test/riscv_cheriot_vector_memory_instructions_test.cc
@@ -0,0 +1,2015 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cheriot/riscv_cheriot_vector_memory_instructions.h" + +#include <algorithm> +#include <cstdint> +#include <cstring> +#include <functional> +#include <ios> +#include <string> +#include <type_traits> +#include <vector> + +#include "absl/functional/bind_front.h" +#include "absl/log/log.h" +#include "absl/numeric/bits.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "cheriot/cheriot_register.h" +#include "cheriot/cheriot_state.h" +#include "cheriot/cheriot_vector_state.h" +#include "googlemock/include/gmock/gmock.h" +#include "mpact/sim/generic/data_buffer.h" +#include "mpact/sim/generic/immediate_operand.h" +#include "mpact/sim/generic/instruction.h" +#include "mpact/sim/generic/register.h" +#include "mpact/sim/util/memory/tagged_flat_demand_memory.h" +#include "riscv//riscv_register.h" + +// This file contains the test fixture and tests for testing RiscV vector +// memory instructions. + +namespace { + +using ::absl::Span; +using ::mpact::sim::cheriot::CheriotRegister; +using ::mpact::sim::cheriot::CheriotState; +using ::mpact::sim::cheriot::CheriotVectorState; +using ::mpact::sim::generic::ImmediateOperand; +using ::mpact::sim::generic::Instruction; +using ::mpact::sim::generic::RegisterBase; +using ::mpact::sim::riscv::RV32VectorDestinationOperand; +using ::mpact::sim::riscv::RV32VectorSourceOperand; +using ::mpact::sim::riscv::RVVectorRegister; +using ::mpact::sim::util::TaggedFlatDemandMemory; +using ::std::tuple; + +// Semantic functions. +using ::mpact::sim::cheriot::VlChild; +using ::mpact::sim::cheriot::VlIndexed; +using ::mpact::sim::cheriot::Vlm; +using ::mpact::sim::cheriot::VlRegister; +using ::mpact::sim::cheriot::VlSegment; +using ::mpact::sim::cheriot::VlSegmentChild; +using ::mpact::sim::cheriot::VlSegmentIndexed; +using ::mpact::sim::cheriot::VlSegmentStrided; +using ::mpact::sim::cheriot::VlStrided; +using ::mpact::sim::cheriot::VlUnitStrided; +using ::mpact::sim::cheriot::Vsetvl; +using ::mpact::sim::cheriot::VsIndexed; +using ::mpact::sim::cheriot::Vsm; +using ::mpact::sim::cheriot::VsRegister; +using ::mpact::sim::cheriot::VsSegment; +using ::mpact::sim::cheriot::VsSegmentIndexed; +using ::mpact::sim::cheriot::VsSegmentStrided; +using ::mpact::sim::cheriot::VsStrided; + +// Constants used in the tests. 
+constexpr int kVectorLengthInBits = 512; +constexpr int kVectorLengthInBytes = kVectorLengthInBits / 8; +constexpr uint32_t kInstAddress = 0x1000; +constexpr uint32_t kDataLoadAddress = 0x1'0000; +constexpr uint32_t kDataStoreAddress = 0x8'0000; +constexpr char kRs1Name[] = "c1"; +constexpr int kRs1 = 1; +constexpr char kRs2Name[] = "c2"; +constexpr char kRs3Name[] = "c3"; +constexpr char kRdName[] = "c8"; +constexpr int kRd = 8; +constexpr int kVmask = 1; +constexpr char kVmaskName[] = "v1"; +constexpr int kVd = 8; +constexpr int kVs1 = 16; +constexpr int kVs2 = 24; + +// Setting bits and corresponding values for lmul and sew. +constexpr int kLmulSettings[7] = {0b101, 0b110, 0b111, 0b000, + 0b001, 0b010, 0b011}; +constexpr int kLmul8Values[7] = {1, 2, 4, 8, 16, 32, 64}; +constexpr int kLmulSettingByLogSize[] = {0, 0b101, 0b110, 0b111, + 0b000, 0b001, 0b010, 0b011}; +constexpr int kSewSettings[4] = {0b000, 0b001, 0b010, 0b011}; +constexpr int kSewValues[4] = {1, 2, 4, 8}; +constexpr int kSewSettingsByByteSize[] = {0, 0b000, 0b001, 0, 0b010, + 0, 0, 0, 0b011}; + +// Don't need to set every byte, as only the low bits are used for mask values. +constexpr uint8_t kA5Mask[kVectorLengthInBytes] = { + 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, + 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, + 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, + 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, + 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, + 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, 0xa5, +}; + +// Test fixture class. This class allows for more convenient manipulations +// of instructions to test the semantic functions. +class RiscVCheriotVInstructionsTest : public testing::Test { + public: + RiscVCheriotVInstructionsTest() { + memory_ = new TaggedFlatDemandMemory(8); + state_ = new CheriotState("test", memory_); + rv_vector_ = new CheriotVectorState(state_, kVectorLengthInBytes); + instruction_ = new Instruction(kInstAddress, state_); + instruction_->set_size(4); + child_instruction_ = new Instruction(kInstAddress, state_); + child_instruction_->set_size(4); + // Initialize a portion of memory with a known pattern. + auto *db = state_->db_factory()->Allocate(8192); + auto span = db->Get<uint8_t>(); + for (int i = 0; i < 8192; i++) { + span[i] = i & 0xff; + } + memory_->Store(kDataLoadAddress - 4096, db); + db->DecRef(); + for (int i = 1; i < 32; i++) { + creg_[i] = + state_->GetRegister<CheriotRegister>(absl::StrCat("c", i)).first; + creg_[i]->ResetMemoryRoot(); + } + for (int i = 1; i < 32; i++) { + vreg_[i] = + state_->GetRegister<RVVectorRegister>(absl::StrCat("v", i)).first; + } + } + + ~RiscVCheriotVInstructionsTest() override { + delete state_; + delete rv_vector_; + instruction_->DecRef(); + child_instruction_->DecRef(); + delete memory_; + } + + // Creates immediate operands with the values from the vector and appends them + // to the given instruction. + template <typename T> + void AppendImmediateOperands(Instruction *inst, + const std::vector<T> &values) { + for (auto value : values) { + auto *src = new ImmediateOperand<T>(value); + inst->AppendSource(src); + } + } + + // Creates immediate operands with the values from the vector and appends them + // to the default instruction. 
+  template <typename T>
+  void AppendImmediateOperands(const std::vector<T> &values) {
+    AppendImmediateOperands<T>(instruction_, values);
+  }
+
+  // Creates source and destination scalar register operands for the registers
+  // named in the two vectors and appends them to the given instruction.
+  void AppendRegisterOperands(Instruction *inst,
+                              const std::vector<std::string> &sources,
+                              const std::vector<std::string> &destinations) {
+    for (auto &reg_name : sources) {
+      auto *reg = state_->GetRegister<CheriotRegister>(reg_name).first;
+      inst->AppendSource(reg->CreateSourceOperand());
+    }
+    for (auto &reg_name : destinations) {
+      auto *reg = state_->GetRegister<CheriotRegister>(reg_name).first;
+      inst->AppendDestination(reg->CreateDestinationOperand(0));
+    }
+  }
+
+  // Creates source and destination scalar register operands for the registers
+  // named in the two vectors and appends them to the default instruction.
+  void AppendRegisterOperands(const std::vector<std::string> &sources,
+                              const std::vector<std::string> &destinations) {
+    AppendRegisterOperands(instruction_, sources, destinations);
+  }
+
+  // Returns the value of the named scalar register.
+  template <typename T>
+  T GetRegisterValue(absl::string_view reg_name) {
+    auto *reg = state_->GetRegister<CheriotRegister>(reg_name).first;
+    return reg->data_buffer()->Get<T>();
+  }
+
+  // Sets scalar register values. Takes a vector of tuples of register names
+  // and values, fetches each named register and sets it to the corresponding
+  // value.
+  template <typename T>
+  void SetRegisterValues(
+      const std::vector<tuple<std::string, const T>> &values) {
+    for (auto &[reg_name, value] : values) {
+      auto *reg = state_->GetRegister<CheriotRegister>(reg_name).first;
+      auto *db = state_->db_factory()->Allocate<CheriotRegister::ValueType>(1);
+      db->Set<T>(0, value);
+      reg->SetDataBuffer(db);
+      db->DecRef();
+    }
+  }
+
+  // Creates source and destination vector register operands for the register
+  // groups starting at the given register numbers and appends them to the
+  // given instruction.
+  void AppendVectorRegisterOperands(Instruction *inst,
+                                    const std::vector<int> &sources,
+                                    const std::vector<int> &destinations) {
+    for (auto &reg_no : sources) {
+      std::vector<RegisterBase *> reg_vec;
+      for (int i = 0; (i < 8) && (i + reg_no < 32); i++) {
+        std::string reg_name = absl::StrCat("v", i + reg_no);
+        reg_vec.push_back(
+            state_->GetRegister<RVVectorRegister>(reg_name).first);
+      }
+      auto *op = new RV32VectorSourceOperand(
+          absl::Span<RegisterBase *>(reg_vec), absl::StrCat("v", reg_no));
+      inst->AppendSource(op);
+    }
+    for (auto &reg_no : destinations) {
+      std::vector<RegisterBase *> reg_vec;
+      for (int i = 0; (i < 8) && (i + reg_no < 32); i++) {
+        std::string reg_name = absl::StrCat("v", i + reg_no);
+        reg_vec.push_back(
+            state_->GetRegister<RVVectorRegister>(reg_name).first);
+      }
+      auto *op = new RV32VectorDestinationOperand(
+          absl::Span<RegisterBase *>(reg_vec), 0, absl::StrCat("v", reg_no));
+      inst->AppendDestination(op);
+    }
+  }
+
+  // Creates source and destination vector register operands for the register
+  // groups starting at the given register numbers and appends them to the
+  // default instruction.
+  void AppendVectorRegisterOperands(const std::vector<int> &sources,
+                                    const std::vector<int> &destinations) {
+    AppendVectorRegisterOperands(instruction_, sources, destinations);
+  }
+
+  // Returns the value of the named vector register.
+  template <typename T>
+  T GetVectorRegisterValue(absl::string_view reg_name) {
+    auto *reg = state_->GetRegister<RVVectorRegister>(reg_name).first;
+    return reg->data_buffer()->Get<T>(0);
+  }
+
+  // Set a vector register value.
Takes a vector of tuples of register names and + // spans of values, fetches each register and sets it to the corresponding + // value. + template <typename T> + void SetVectorRegisterValues( + const std::vector<tuple<std::string, Span<const T>>> &values) { + for (auto &[vreg_name, span] : values) { + auto *vreg = state_->GetRegister<RVVectorRegister>(vreg_name).first; + auto *db = state_->db_factory()->MakeCopyOf(vreg->data_buffer()); + db->template Set<T>(span); + vreg->SetDataBuffer(db); + db->DecRef(); + } + } + + // Initializes the semantic function of the instruction object. + void SetSemanticFunction(Instruction *inst, + Instruction::SemanticFunction fcn) { + inst->set_semantic_function(fcn); + } + + // Initializes the semantic function for the default instruction. + void SetSemanticFunction(Instruction::SemanticFunction fcn) { + instruction_->set_semantic_function(fcn); + } + + // Sets the default child instruction as the child of the default instruction. + void SetChildInstruction() { instruction_->AppendChild(child_instruction_); } + + // Initializes the semantic function for the default child instruction. + void SetChildSemanticFunction(Instruction::SemanticFunction fcn) { + child_instruction_->set_semantic_function(fcn); + } + + // Configure the vector unit according to the vtype and vlen values. + void ConfigureVectorUnit(uint32_t vtype, uint32_t vlen) { + Instruction *inst = new Instruction(state_); + AppendImmediateOperands<uint32_t>(inst, {vlen, vtype}); + SetSemanticFunction(inst, absl::bind_front(&Vsetvl, true, false)); + inst->Execute(nullptr); + inst->DecRef(); + } + + template <typename T> + T ComputeValue(int address) { + T value = 0; + uint8_t *ptr = reinterpret_cast<uint8_t *>(&value); + for (int j = 0; j < sizeof(T); j++) { + ptr[j] = (address + j) & 0xff; + } + return value; + } + + template <typename T> + void VectorLoadUnitStridedHelper() { + // Set up instructions. + AppendRegisterOperands({kRs1Name}, {}); + AppendVectorRegisterOperands({kVmask}, {}); + SetSemanticFunction(absl::bind_front(&VlUnitStrided, + /*element_width*/ sizeof(T))); + // Add the child instruction that performs the register write-back. + SetChildInstruction(); + SetChildSemanticFunction(&VlChild); + AppendVectorRegisterOperands(child_instruction_, {}, {kVd}); + // Set up register values. + SetRegisterValues<uint32_t>({{kRs1Name, kDataLoadAddress}}); + SetVectorRegisterValues<uint8_t>( + {{kVmaskName, Span<const uint8_t>(kA5Mask)}}); + // Iterate over different lmul values. + for (int lmul_index = 0; lmul_index < 7; lmul_index++) { + uint32_t vtype = + (kSewSettingsByByteSize[sizeof(T)] << 3) | kLmulSettings[lmul_index]; + int lmul8 = kLmul8Values[lmul_index]; + int num_values = kVectorLengthInBytes * lmul8 / (sizeof(T) * 8); + ConfigureVectorUnit(vtype, /*vlen*/ 1024); + // Execute instruction. + instruction_->Execute(nullptr); + + // Check register values. + int count = 0; + for (int reg = kVd; reg < kVd + 8; reg++) { + auto span = vreg_[reg]->data_buffer()->Get<T>(); + for (int i = 0; i < kVectorLengthInBytes / sizeof(T); i++) { + int mask_index = count / 8; + int mask_offset = count % 8; + bool mask = (kA5Mask[mask_index] >> mask_offset) & 0x1; + if (mask && (count < num_values)) { + // First compute the expected value, then compare it. 
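+            // The constructor filled memory from kDataLoadAddress - 4096 with
+            // the pattern (offset & 0xff), so the byte at kDataLoadAddress + n
+            // holds (4096 + n) & 0xff. ComputeValue rebuilds that value byte
+            // by byte; e.g. for T = uint32_t and count == 0 the bytes are
+            // 00 01 02 03, i.e. 0x03020100 on a little-endian host.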
+ T value = ComputeValue<T>(4096 + count * sizeof(T)); + EXPECT_EQ(value, span[i]) + << "element size " << sizeof(T) << " LMUL8 " << lmul8 + << " Count " << count << " Reg " << reg << " value " << i; + } else { + // The remainder of the values should be zero. + EXPECT_EQ(0, span[i]) + << "element size " << sizeof(T) << " LMUL8 " << lmul8 + << " Count " << count << " Reg " << reg << " value " << i; + } + count++; + } + } + } + } + + template <typename T> + void VectorLoadStridedHelper() { + const int strides[5] = {1, 4, 0, -1, -3}; + // Set up instructions. + AppendRegisterOperands({kRs1Name, kRs2Name}, {}); + AppendVectorRegisterOperands({kVmask}, {}); + SetSemanticFunction(absl::bind_front(&VlStrided, + /*element_width*/ sizeof(T))); + // Add the child instruction that performs the register write-back. + SetChildInstruction(); + SetChildSemanticFunction(&VlChild); + AppendVectorRegisterOperands(child_instruction_, {}, {kVd}); + // Set up register values. + SetRegisterValues<uint32_t>({{kRs1Name, kDataLoadAddress}}); + SetVectorRegisterValues<uint8_t>( + {{kVmaskName, Span<const uint8_t>(kA5Mask)}}); + // Iterate over different lmul values. + for (int lmul_index = 0; lmul_index < 7; lmul_index++) { + // Try different strides. + for (int s = 0; s < 5; s++) { + int32_t stride = strides[s] * sizeof(T); + SetRegisterValues<int32_t>({{kRs2Name, stride}}); + // Configure vector unit. + uint32_t vtype = (kSewSettingsByByteSize[sizeof(T)] << 3) | + kLmulSettings[lmul_index]; + int lmul8 = kLmul8Values[lmul_index]; + int num_values = kVectorLengthInBytes * lmul8 / (sizeof(T) * 8); + ConfigureVectorUnit(vtype, /*vlen*/ 1024); + // Execute instruction. + instruction_->Execute(nullptr); + + // Check register values. + int count = 0; + for (int reg = kVd; reg < kVd + 8; reg++) { + auto span = vreg_[reg]->data_buffer()->Get<T>(); + for (int i = 0; i < kVectorLengthInBytes / sizeof(T); i++) { + int mask_index = count / 8; + int mask_offset = count % 8; + bool mask = (kA5Mask[mask_index] >> mask_offset) & 0x1; + if (mask && (count < num_values)) { + // First compute the expected value, then compare it. + T value = ComputeValue<T>(4096 + count * stride); + EXPECT_EQ(value, span[i]) + << "element size " << sizeof(T) << " stride: " << stride + << " LMUL8 " << lmul8 << " Count " << count << " Reg " << reg + << " value " << i; + } else { + // The remainder of the values should be zero. + EXPECT_EQ(0, span[i]) + << "element size " << sizeof(T) << " stride: " << stride + << " LMUL8 " << lmul8 << " Count " << count << " Reg " << reg + << " value " << i; + } + count++; + } + } + } + } + } + + template <typename T> + T IndexValue(int i) { + T offset = ~i & 0b1111; + T val = (i & ~0b1111) | offset; + return val * sizeof(T); + } + + // Helper function for testing vector load indexed instructions. + template <typename IndexType, typename ValueType> + void VectorLoadIndexedHelper() { + // Set up instructions. + AppendRegisterOperands({kRs1Name}, {}); + AppendVectorRegisterOperands({kVs2, kVmask}, {}); + SetSemanticFunction(absl::bind_front(&VlIndexed, + /*index_width*/ sizeof(IndexType))); + // Add the child instruction that performs the register write-back. + SetChildInstruction(); + SetChildSemanticFunction(&VlChild); + AppendVectorRegisterOperands(child_instruction_, {}, {kVd}); + // Set up register values. + SetRegisterValues<uint32_t>({{kRs1Name, kDataLoadAddress}}); + SetVectorRegisterValues<uint8_t>( + {{kVmaskName, Span<const uint8_t>(kA5Mask)}}); + + // Iterate over different lmul values. 
+    for (int lmul_index = 0; lmul_index < 7; lmul_index++) {
+      // Configure vector unit.
+      uint32_t vtype = (kSewSettingsByByteSize[sizeof(ValueType)] << 3) |
+                       kLmulSettings[lmul_index];
+      ConfigureVectorUnit(vtype, /*vlen*/ 1024);
+      int lmul8 = kLmul8Values[lmul_index];
+      int index_emul8 = lmul8 * sizeof(IndexType) / sizeof(ValueType);
+
+      if ((index_emul8 == 0) || (index_emul8 > 64)) {
+        // The index vector length is illegal.
+        instruction_->Execute(nullptr);
+        EXPECT_TRUE(rv_vector_->vector_exception());
+        rv_vector_->clear_vector_exception();
+        continue;
+      }
+      EXPECT_FALSE(rv_vector_->vector_exception());
+
+      int num_values = kVectorLengthInBytes * lmul8 / (sizeof(ValueType) * 8);
+      // Set up index vector values.
+      int values_per_reg = kVectorLengthInBytes / sizeof(IndexType);
+      for (int i = 0; i < num_values; i++) {
+        int reg_index = kVs2 + i / values_per_reg;
+        int reg_offset = i % values_per_reg;
+        auto index_span = vreg_[reg_index]->data_buffer()->Get<IndexType>();
+        index_span[reg_offset] = IndexValue<ValueType>(i);
+      }
+      // Execute instruction.
+      instruction_->Execute(nullptr);
+
+      // Check register values.
+      int count = 0;
+      for (int reg = kVd; reg < kVd + 8; reg++) {
+        auto span = vreg_[reg]->data_buffer()->Get<ValueType>();
+        for (int i = 0; i < kVectorLengthInBytes / sizeof(ValueType); i++) {
+          int mask_index = count / 8;
+          int mask_offset = count % 8;
+          bool mask = (kA5Mask[mask_index] >> mask_offset) & 0x1;
+          if (mask && (count < num_values)) {
+            // Compare expected value.
+            auto value =
+                ComputeValue<ValueType>(4096 + IndexValue<ValueType>(count));
+            EXPECT_EQ(value, span[i])
+                << "element size " << sizeof(ValueType) << " index: "
+                << static_cast<int64_t>(IndexValue<ValueType>(count))
+                << " LMUL8 " << lmul8 << " Count " << count << " reg " << reg;
+          } else {
+            // The remainder of the values should be zero.
+            EXPECT_EQ(0, span[i])
+                << "element size " << sizeof(ValueType) << " index: "
+                << static_cast<int64_t>(IndexValue<ValueType>(count))
+                << " LMUL8 " << lmul8 << " count " << count << " reg " << reg;
+          }
+          count++;
+        }
+      }
+    }
+  }
+
+  // Helper function to test vector load segment (unit-stride) instructions.
+  template <typename T>
+  void VectorLoadSegmentHelper() {
+    // Set up instructions.
+    AppendRegisterOperands({kRs1Name}, {});
+    AppendVectorRegisterOperands({kVmask}, {});
+    AppendRegisterOperands({kRs3Name}, {});
+    SetVectorRegisterValues<uint8_t>(
+        {{kVmaskName, Span<const uint8_t>(kA5Mask)}});
+    SetSemanticFunction(absl::bind_front(&VlSegment,
+                                         /*element_width*/ sizeof(T)));
+    // Add the child instruction that performs the register write-back.
+    SetChildInstruction();
+    SetChildSemanticFunction(
+        absl::bind_front(&VlSegmentChild, /*element_width*/ sizeof(T)));
+    AppendRegisterOperands(child_instruction_, {kRs3Name}, {});
+    AppendVectorRegisterOperands(child_instruction_, {}, {kVd});
+    // Set up register values.
+    SetRegisterValues<uint32_t>({{kRs1Name, kDataLoadAddress}});
+    // Iterate over legal values in the nf field.
+    for (int nf = 1; nf < 8; nf++) {
+      int num_fields = nf + 1;
+      SetRegisterValues<int32_t>({{kRs3Name, nf}});
+      // Iterate over different lmul values.
+      for (int lmul_index = 0; lmul_index < 7; lmul_index++) {
+        // Configure vector unit, set the sew to the element width.
+        uint32_t vtype = (kSewSettingsByByteSize[sizeof(T)] << 3) |
+                         kLmulSettings[lmul_index];
+        int lmul8 = kLmul8Values[lmul_index];
+        int num_values = kVectorLengthInBytes * lmul8 / (sizeof(T) * 8);
+        ConfigureVectorUnit(vtype, /*vlen*/ 1024);
+        // Execute instruction.
+ instruction_->Execute(nullptr); + + if (lmul8 * num_fields > 64) { + EXPECT_TRUE(rv_vector_->vector_exception()); + rv_vector_->clear_vector_exception(); + continue; + } + EXPECT_FALSE(rv_vector_->vector_exception()); + + // Check register values. + int count = 0; + // Fields are in consecutive (groups of) registers. First compute the + // number of registers for each field. + int regs_per_field = ::std::max(1, lmul8 / 8); + for (int field = 0; field < num_fields; field++) { + int start_reg = kVd + field * regs_per_field; + int max_reg = start_reg + regs_per_field; + count = 0; + for (int reg = start_reg; reg < max_reg; reg++) { + auto span = vreg_[reg]->data_buffer()->Get<T>(); + int num_reg_elements = + std::min(kVectorLengthInBytes / sizeof(T), + kVectorLengthInBytes * lmul8 / (sizeof(T) * 8)); + for (int i = 0; i < num_reg_elements; i++) { + int mask_index = count / 8; + int mask_offset = count % 8; + bool mask = (kA5Mask[mask_index] >> mask_offset) & 0x1; + if (mask && (count < num_values)) { + int address = + 4096 + count * sizeof(T) * num_fields + field * sizeof(T); + T value = ComputeValue<T>(address); + EXPECT_EQ(value, span[i]) + << "element size " << sizeof(T) << " LMUL8 " << lmul8 + << " Count " << count << " Reg " << reg << " value " << i; + } else { + // The remainder of the values should be zero. + EXPECT_EQ(0, span[i]) + << "element size " << sizeof(T) << " LMUL8 " << lmul8 + << " Count " << count << " Reg " << reg << " value " << i; + } + count++; + } + } + } + } + } + } + + // Helper function to test vector load segment strided instructions. + template <typename T> + void VectorLoadStridedSegmentHelper() { + const int strides[5] = {1, 4, 0, -1, -3}; + // Set up instructions. + // Base address and stride. + AppendRegisterOperands({kRs1Name, kRs2Name}, {}); + // Vector mask register. + AppendVectorRegisterOperands({kVmask}, {}); + // Operand to hold the number of fields. + AppendRegisterOperands({kRs3Name}, {}); + // Initialize the mask. + SetVectorRegisterValues<uint8_t>( + {{kVmaskName, Span<const uint8_t>(kA5Mask)}}); + // Bind semantic function. + SetSemanticFunction(absl::bind_front(&VlSegmentStrided, + /*element_width*/ sizeof(T))); + // Add the child instruction that performs the register write-back. + SetChildInstruction(); + SetChildSemanticFunction( + absl::bind_front(&VlSegmentChild, /*element_width*/ sizeof(T))); + // Number of fields. + AppendRegisterOperands(child_instruction_, {kRs3Name}, {}); + // Destination vector register operand. + AppendVectorRegisterOperands(child_instruction_, {}, {kVd}); + + // Set up register values. + // Base address. + SetRegisterValues<uint32_t>({{kRs1Name, kDataLoadAddress}}); + // Iterate over legal values in the nf field. + for (int nf = 1; nf < 8; nf++) { + int num_fields = nf + 1; + // Set the number of fields. + SetRegisterValues<int32_t>({{kRs3Name, nf}}); + // Iterate over different lmul values. + for (int lmul_index = 0; lmul_index < 7; lmul_index++) { + // Configure vector unit, set the sew to the element width. + uint32_t vtype = (kSewSettingsByByteSize[sizeof(T)] << 3) | + kLmulSettings[lmul_index]; + int lmul8 = kLmul8Values[lmul_index]; + int num_values = kVectorLengthInBytes * lmul8 / (sizeof(T) * 8); + ConfigureVectorUnit(vtype, /*vlen*/ 1024); + + // Try different strides. + for (int s = 0; s < 5; s++) { + int32_t stride = strides[s] * num_fields * sizeof(T); + // Set the stride. + SetRegisterValues<int32_t>({{kRs2Name, stride}}); + // Execute instruction. 
+          instruction_->Execute(nullptr);
+
+          if (lmul8 * num_fields > 64) {
+            EXPECT_TRUE(rv_vector_->vector_exception());
+            rv_vector_->clear_vector_exception();
+            continue;
+          }
+          EXPECT_FALSE(rv_vector_->vector_exception());
+
+          // Check register values.
+          // Fields are in consecutive (groups of) registers. First compute
+          // the number of registers for each field.
+          int regs_per_field = ::std::max(1, lmul8 / 8);
+          for (int field = 0; field < num_fields; field++) {
+            int start_reg = kVd + field * regs_per_field;
+            int max_reg = start_reg + regs_per_field;
+            int count = 0;
+            for (int reg = start_reg; reg < max_reg; reg++) {
+              auto span = vreg_[reg]->data_buffer()->Get<T>();
+              int num_reg_elements =
+                  std::min(kVectorLengthInBytes / sizeof(T),
+                           kVectorLengthInBytes * lmul8 / (sizeof(T) * 8));
+              for (int i = 0; i < num_reg_elements; i++) {
+                int mask_index = count / 8;
+                int mask_offset = count % 8;
+                bool mask = (kA5Mask[mask_index] >> mask_offset) & 0x1;
+                if (mask && (count < num_values)) {
+                  // First compute the expected value, then compare it.
+                  int address = 4096 + stride * count + field * sizeof(T);
+                  T value = ComputeValue<T>(address);
+
+                  EXPECT_EQ(value, span[i])
+                      << "element size " << sizeof(T) << " stride: " << stride
+                      << " LMUL8 " << lmul8 << " Count " << count << " Reg "
+                      << reg << " value " << i;
+                } else {
+                  // The remainder of the values should be zero.
+                  EXPECT_EQ(0, span[i])
+                      << "element size " << sizeof(T) << " stride: " << stride
+                      << " LMUL8 " << lmul8 << " Count " << count << " Reg "
+                      << reg << " value " << i;
+                }
+                count++;
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
+  // Helper function to test vector load segment indexed instructions.
+  template <typename IndexType, typename ValueType>
+  void VectorLoadIndexedSegmentHelper() {
+    // Set up instructions.
+    // Base address.
+    AppendRegisterOperands({kRs1Name}, {});
+    // Vector index register and vector mask register.
+    AppendVectorRegisterOperands({kVs2, kVmask}, {});
+    // Operand to hold the number of fields.
+    AppendRegisterOperands({kRs3Name}, {});
+    // Initialize the mask.
+    SetVectorRegisterValues<uint8_t>(
+        {{kVmaskName, Span<const uint8_t>(kA5Mask)}});
+    // Bind semantic function.
+    SetSemanticFunction(absl::bind_front(&VlSegmentIndexed,
+                                         /*index_width*/ sizeof(IndexType)));
+    // Add the child instruction that performs the register write-back.
+    SetChildInstruction();
+    SetChildSemanticFunction(
+        absl::bind_front(&VlSegmentChild,
+                         /*element_width*/ sizeof(ValueType)));
+    // Number of fields.
+    AppendRegisterOperands(child_instruction_, {kRs3Name}, {});
+    // Destination vector register operand.
+    AppendVectorRegisterOperands(child_instruction_, {}, {kVd});
+
+    // Set up register values.
+    // Base address.
+    SetRegisterValues<uint32_t>({{kRs1Name, kDataLoadAddress}});
+    // Iterate over legal values in the nf field.
+    for (int nf = 1; nf < 8; nf++) {
+      int num_fields = nf + 1;
+      // Set the number of fields.
+      SetRegisterValues<int32_t>({{kRs3Name, nf}});
+      // Iterate over different lmul values.
+      for (int lmul_index = 0; lmul_index < 7; lmul_index++) {
+        rv_vector_->clear_vector_exception();
+        // Configure vector unit, set the sew to the element width.
+        uint32_t vtype = (kSewSettingsByByteSize[sizeof(ValueType)] << 3) |
+                         kLmulSettings[lmul_index];
+        int lmul8 = kLmul8Values[lmul_index];
+        ConfigureVectorUnit(vtype, /*vlen*/ 1024);
+        int index_emul8 = lmul8 * sizeof(IndexType) / sizeof(ValueType);
+        int num_values = kVectorLengthInBytes * lmul8 / (sizeof(ValueType) * 8);
+
+        if ((index_emul8 == 0) || (index_emul8 > 64)) {
+          // The index vector length is illegal.
+          instruction_->Execute(nullptr);
+          EXPECT_TRUE(rv_vector_->vector_exception());
+          continue;
+        }
+        if (lmul8 * num_fields > 64) {
+          instruction_->Execute(nullptr);
+          EXPECT_TRUE(rv_vector_->vector_exception());
+          continue;
+        }
+
+        // Set up index vector values.
+        int values_per_reg = kVectorLengthInBytes / sizeof(IndexType);
+        for (int i = 0; i < num_values; i++) {
+          int reg_index = kVs2 + i / values_per_reg;
+          int reg_offset = i % values_per_reg;
+          auto index_span = vreg_[reg_index]->data_buffer()->Get<IndexType>();
+          index_span[reg_offset] = IndexValue<ValueType>(i);
+        }
+
+        // Execute instruction.
+        instruction_->Execute(nullptr);
+        EXPECT_FALSE(rv_vector_->vector_exception());
+
+        // Check register values.
+        // Fields are in consecutive (groups of) registers. First compute the
+        // number of registers for each field.
+        int regs_per_field = ::std::max(1, lmul8 / 8);
+        for (int field = 0; field < num_fields; field++) {
+          int start_reg = kVd + field * regs_per_field;
+          int max_reg = start_reg + regs_per_field;
+          int count = 0;
+          for (int reg = start_reg; reg < max_reg; reg++) {
+            auto span = vreg_[reg]->data_buffer()->Get<ValueType>();
+            int num_reg_elements = std::min(
+                kVectorLengthInBytes / sizeof(ValueType),
+                kVectorLengthInBytes * lmul8 / (sizeof(ValueType) * 8));
+            for (int i = 0; i < num_reg_elements; i++) {
+              int mask_index = count / 8;
+              int mask_offset = count % 8;
+              bool mask = (kA5Mask[mask_index] >> mask_offset) & 0x1;
+              if (mask && (count < num_values)) {
+                // First compute the expected value, then compare it. Use the
+                // same IndexValue mapping as the index register setup above.
+                int address = 4096 + IndexValue<ValueType>(count) +
+                              field * sizeof(ValueType);
+                ValueType value = ComputeValue<ValueType>(address);
+
+                EXPECT_EQ(value, span[i])
+                    << "element size " << sizeof(ValueType) << " LMUL8 "
+                    << lmul8 << " Count " << count << " Reg " << reg
+                    << " value " << i;
+              } else {
+                // The remainder of the values should be zero.
+                EXPECT_EQ(0, span[i])
+                    << "element size " << sizeof(ValueType) << " LMUL8 "
+                    << lmul8 << " Count " << count << " Reg " << reg
+                    << " value " << i;
+              }
+              count++;
+            }
+          }
+        }
+      }
+    }
+  }
+
+  // Helper function to test vector store strided instructions.
+  template <typename T>
+  void VectorStoreStridedHelper() {
+    const int strides[5] = {1, 4, 8, -1, -3};
+    // Set up instructions.
+    AppendVectorRegisterOperands({kVs1}, {});
+    AppendRegisterOperands({kRs1Name, kRs2Name}, {});
+    AppendVectorRegisterOperands({kVmask}, {});
+    SetSemanticFunction(absl::bind_front(&VsStrided,
+                                         /*element_width*/ sizeof(T)));
+    // Set up register values.
+    SetRegisterValues<uint32_t>({{kRs1Name, kDataStoreAddress}});
+    SetVectorRegisterValues<uint8_t>(
+        {{kVmaskName, Span<const uint8_t>(kA5Mask)}});
+    // Set the store data register elements to be consecutive integers.
+    for (int reg = 0; reg < 8; reg++) {
+      auto reg_span = vreg_[reg + kVs1]->data_buffer()->Get<T>();
+      for (int i = 0; i < reg_span.size(); i++) {
+        reg_span[i] = static_cast<T>(reg * reg_span.size() + i + 1);
+      }
+    }
+    auto *clear_mem_db = state_->db_factory()->Allocate<T>(0x8000);
+    memset(clear_mem_db->raw_ptr(), 0, clear_mem_db->size<uint8_t>());
+    // Iterate over different lmul values.
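+    // Note that two of the strides above are negative, so stores can land
+    // below kDataStoreAddress. The zero buffer is therefore written back at
+    // kDataStoreAddress - 0x4000 after each run (see the end of the stride
+    // loop below), wiping the addresses on both sides of the base.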
+    for (int lmul_index = 0; lmul_index < 7; lmul_index++) {
+      // Try different strides.
+      for (int s = 0; s < 5; s++) {
+        int32_t stride = strides[s] * sizeof(T);
+        SetRegisterValues<int32_t>({{kRs2Name, stride}});
+        // Configure vector unit.
+        uint32_t vtype = (kSewSettingsByByteSize[sizeof(T)] << 3) |
+                         kLmulSettings[lmul_index];
+        int lmul8 = kLmul8Values[lmul_index];
+        int num_values = kVectorLengthInBytes * lmul8 / (sizeof(T) * 8);
+        ConfigureVectorUnit(vtype, /*vlen*/ 1024);
+        // Execute instruction.
+        instruction_->Execute(nullptr);
+
+        // Check memory values.
+        auto *data_db = state_->db_factory()->Allocate<T>(1);
+        uint64_t base = kDataStoreAddress;
+        T value = 1;
+        for (int i = 0; i < 8 * kVectorLengthInBytes / sizeof(T); i++) {
+          data_db->template Set<T>(0, 0);
+          state_->DbgLoadMemory(base, data_db);
+          int mask_index = i / 8;
+          int mask_offset = i % 8;
+          bool mask = (kA5Mask[mask_index] >> mask_offset) & 0x1;
+          if (mask && (i < num_values)) {
+            EXPECT_EQ(data_db->template Get<T>(0), static_cast<T>(value))
+                << "index: " << i << " element_size: " << sizeof(T)
+                << " lmul8: " << lmul8 << " stride: " << stride;
+          } else {
+            EXPECT_EQ(data_db->template Get<T>(0), 0) << "index: " << i;
+          }
+          base += stride;
+          value++;
+        }
+        data_db->DecRef();
+        // Clear memory.
+        state_->DbgStoreMemory(kDataStoreAddress - 0x4000, clear_mem_db);
+      }
+    }
+    clear_mem_db->DecRef();
+  }
+
+  template <typename IndexType, typename ValueType>
+  void VectorStoreIndexedHelper() {
+    // Set up instructions.
+    AppendVectorRegisterOperands({kVs1}, {});
+    AppendRegisterOperands({kRs1Name}, {});
+    AppendVectorRegisterOperands({kVs2, kVmask}, {});
+    SetSemanticFunction(absl::bind_front(&VsIndexed,
+                                         /*index_width*/ sizeof(IndexType)));
+
+    // Set up register values.
+    SetRegisterValues<uint32_t>({{kRs1Name, kDataStoreAddress}});
+    SetVectorRegisterValues<uint8_t>(
+        {{kVmaskName, Span<const uint8_t>(kA5Mask)}});
+
+    int values_per_reg = kVectorLengthInBytes / sizeof(ValueType);
+    int index_values_per_reg = kVectorLengthInBytes / sizeof(IndexType);
+    for (int reg = 0; reg < 8; reg++) {
+      for (int i = 0; i < values_per_reg; i++) {
+        vreg_[kVs1 + reg]->data_buffer()->Set<ValueType>(
+            i, reg * values_per_reg + i);
+      }
+    }
+
+    auto *data_db = state_->db_factory()->Allocate<ValueType>(1);
+    // Iterate over different lmul values.
+    for (int lmul_index = 0; lmul_index < 7; lmul_index++) {
+      // Configure vector unit.
+      uint32_t vtype = (kSewSettingsByByteSize[sizeof(ValueType)] << 3) |
+                       kLmulSettings[lmul_index];
+      ConfigureVectorUnit(vtype, /*vlen*/ 1024);
+      int lmul8 = kLmul8Values[lmul_index];
+      int index_emul8 = lmul8 * sizeof(IndexType) / sizeof(ValueType);
+      int num_values = kVectorLengthInBytes * lmul8 / (sizeof(ValueType) * 8);
+
+      // Skip when the number of values exceeds the number of distinct offsets
+      // the index type can encode. This only happens for uint8_t.
+      if (num_values > 256) continue;
+
+      // Check the index vector length.
+      if ((index_emul8 == 0) || (index_emul8 > 64)) {
+        // The index vector length is illegal.
+        instruction_->Execute(nullptr);
+        EXPECT_TRUE(rv_vector_->vector_exception());
+        rv_vector_->clear_vector_exception();
+        continue;
+      }
+
+      for (int i = 0; i < num_values; i++) {
+        int reg = i / index_values_per_reg;
+        int element = i % index_values_per_reg;
+        vreg_[kVs2 + reg]->data_buffer()->Set<IndexType>(
+            element, IndexValue<ValueType>(i));
+      }
+
+      // Execute instruction.
+      instruction_->Execute(nullptr);
+
+      // Check results.
+      EXPECT_FALSE(rv_vector_->vector_exception());
+
+      // Check memory values.
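+      // The loop below recomputes each element's target address with the
+      // same IndexValue() mapping used to fill the index registers, checks
+      // the value found there, and then zeroes that location so a stale
+      // value cannot satisfy a later lmul iteration.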
+      for (int i = 0; i < 8 * kVectorLengthInBytes / sizeof(ValueType); i++) {
+        uint64_t address = kDataStoreAddress + IndexValue<ValueType>(i);
+        state_->DbgLoadMemory(address, data_db);
+        int mask_index = i / 8;
+        int mask_offset = i % 8;
+        bool mask = (kA5Mask[mask_index] >> mask_offset) & 0x1;
+        if (mask && (i < num_values)) {
+          EXPECT_EQ(data_db->template Get<ValueType>(0), i)
+              << "reg[" << i / values_per_reg << "][" << i % values_per_reg
+              << "]";
+        } else {
+          EXPECT_EQ(data_db->template Get<ValueType>(0), 0)
+              << "reg[" << i / values_per_reg << "][" << i % values_per_reg
+              << "]";
+        }
+        // Clear the memory location.
+        data_db->template Set<ValueType>(0, 0);
+        state_->DbgStoreMemory(address, data_db);
+      }
+    }
+    data_db->DecRef();
+  }
+
+  // Helper function to test vector store segment (unit-stride) instructions.
+  template <typename T>
+  void VectorStoreSegmentHelper() {
+    // Set up instructions.
+    // Store data register.
+    AppendVectorRegisterOperands({kVs1}, {});
+    // Base address.
+    AppendRegisterOperands({kRs1Name}, {});
+    // Vector mask register.
+    AppendVectorRegisterOperands({kVmask}, {});
+    // Operand to hold the number of fields.
+    AppendRegisterOperands({kRs3Name}, {});
+    // Bind semantic function.
+    SetSemanticFunction(absl::bind_front(&VsSegment,
+                                         /*element_width*/ sizeof(T)));
+
+    // Set up register values.
+    // Set the store data register elements to be consecutive integers.
+    for (int reg = 0; reg < 8; reg++) {
+      auto reg_span = vreg_[reg + kVs1]->data_buffer()->Get<T>();
+      for (int i = 0; i < reg_span.size(); i++) {
+        reg_span[i] = static_cast<T>(reg * reg_span.size() + i + 1);
+      }
+    }
+    // Base address.
+    SetRegisterValues<uint32_t>({{kRs1Name, kDataStoreAddress}});
+    // Initialize the mask.
+    SetVectorRegisterValues<uint8_t>(
+        {{kVmaskName, Span<const uint8_t>(kA5Mask)}});
+
+    int num_values_per_register = kVectorLengthInBytes / sizeof(T);
+    // Can load all the data in one load, so set the data_db size accordingly.
+    auto *data_db =
+        state_->db_factory()->Allocate<uint8_t>(kVectorLengthInBytes * 8);
+    // Iterate over legal values in the nf field.
+    for (int nf = 1; nf < 8; nf++) {
+      int num_fields = nf + 1;
+      // Set the number of fields in the source operand.
+      SetRegisterValues<int32_t>({{kRs3Name, nf}});
+      // Iterate over different lmul values.
+      for (int lmul_index = 0; lmul_index < 7; lmul_index++) {
+        // Configure vector unit, set the sew to the element width.
+        uint32_t vtype = (kSewSettingsByByteSize[sizeof(T)] << 3) |
+                         kLmulSettings[lmul_index];
+        int lmul8 = kLmul8Values[lmul_index];
+        ConfigureVectorUnit(vtype, /*vlen*/ 1024);
+        // Clear the memory.
+        uint64_t base = kDataStoreAddress;
+        std::memset(data_db->raw_ptr(), 0, data_db->template size<uint8_t>());
+        state_->DbgStoreMemory(base, data_db);
+
+        // Execute instruction.
+        instruction_->Execute(nullptr);
+
+        if (lmul8 * num_fields > 64) {
+          EXPECT_TRUE(rv_vector_->vector_exception());
+          rv_vector_->clear_vector_exception();
+          continue;
+        }
+        int emul = sizeof(T) * lmul8 / rv_vector_->selected_element_width();
+        if (emul == 0 || emul > 64) {
+          EXPECT_TRUE(rv_vector_->vector_exception());
+          rv_vector_->clear_vector_exception();
+          continue;
+        }
+        EXPECT_FALSE(rv_vector_->vector_exception());
+        // Check memory values.
+        T value = 1;
+        int vlen = rv_vector_->vector_length();
+        int num_regs = std::max(1, lmul8 / 8);
+        // Load the store data.
+        state_->DbgLoadMemory(base, data_db);
+        // Iterate over fields.
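+        // Memory layout of a unit-stride segment store: the element for
+        // segment s, field f lands at byte offset
+        // (s * num_fields + f) * sizeof(T), which is exactly the index used
+        // in the Get<T>() call below. For example, with num_fields = 2 and
+        // T = uint16_t, segment 0 holds its fields at byte offsets 0 and 2,
+        // segment 1 at offsets 4 and 6.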
+        for (int field = 0; field < num_fields; field++) {
+          // Iterate over the registers used for each field.
+          int segment_no = 0;
+          for (int reg = 0; reg < num_regs; reg++) {
+            // Iterate over segments within each register.
+            for (int i = 0; i < num_values_per_register; i++) {
+              // Get the mask value.
+              int mask_index = segment_no >> 3;
+              int mask_offset = segment_no & 0b111;
+              bool mask = (kA5Mask[mask_index] >> mask_offset) & 0x1;
+              T mem_value =
+                  data_db->template Get<T>(segment_no * num_fields + field);
+              if (mask && (segment_no < vlen)) {
+                EXPECT_EQ(mem_value, static_cast<T>(value))
+                    << " segment_no: " << segment_no << " field: " << field
+                    << " index: " << i << " element_size: " << sizeof(T)
+                    << " lmul8: " << lmul8 << " mask: " << mask;
+                // Zero the checked location in the buffer.
+                data_db->template Set<T>(segment_no * num_fields + field, 0);
+              } else {
+                EXPECT_EQ(mem_value, 0)
+                    << " segment_no: " << segment_no << " field: " << field
+                    << " index: " << i << " element_size: " << sizeof(T)
+                    << " lmul8: " << lmul8 << " mask: " << mask;
+              }
+              value++;
+              segment_no++;
+            }
+          }
+        }
+        if (HasFailure()) {
+          data_db->DecRef();
+          return;
+        }
+      }
+    }
+    data_db->DecRef();
+  }
+
+  // Helper function to test vector store segment strided instructions.
+  template <typename T>
+  void VectorStoreStridedSegmentHelper() {
+    const int strides[5] = {1, 4, 8, -1, -3};
+    // Set up instructions.
+    // Store data register.
+    AppendVectorRegisterOperands({kVs1}, {});
+    // Base address and stride.
+    AppendRegisterOperands({kRs1Name, kRs2Name}, {});
+    // Vector mask register.
+    AppendVectorRegisterOperands({kVmask}, {});
+    // Operand to hold the number of fields.
+    AppendRegisterOperands({kRs3Name}, {});
+    // Bind semantic function.
+    SetSemanticFunction(absl::bind_front(&VsSegmentStrided,
+                                         /*element_width*/ sizeof(T)));
+
+    // Set up register values.
+    // Set the store data register elements to be consecutive integers.
+    for (int reg = 0; reg < 8; reg++) {
+      auto reg_span = vreg_[reg + kVs1]->data_buffer()->Get<T>();
+      for (int i = 0; i < reg_span.size(); i++) {
+        reg_span[i] = static_cast<T>(reg * reg_span.size() + i + 1);
+      }
+    }
+    // Base address.
+    SetRegisterValues<uint32_t>({{kRs1Name, kDataStoreAddress}});
+    // Initialize the mask.
+    SetVectorRegisterValues<uint8_t>(
+        {{kVmaskName, Span<const uint8_t>(kA5Mask)}});
+
+    int num_values_per_register = kVectorLengthInBytes / sizeof(T);
+
+    auto *data_db = state_->db_factory()->Allocate<T>(1);
+    // Iterate over legal values in the nf field.
+    for (int nf = 1; nf < 8; nf++) {
+      int num_fields = nf + 1;
+      // Set the number of fields in the source operand.
+      SetRegisterValues<int32_t>({{kRs3Name, nf}});
+      // Iterate over different lmul values.
+      for (int lmul_index = 0; lmul_index < 7; lmul_index++) {
+        // Configure vector unit, set the sew to the element width.
+        uint32_t vtype = (kSewSettingsByByteSize[sizeof(T)] << 3) |
+                         kLmulSettings[lmul_index];
+        int lmul8 = kLmul8Values[lmul_index];
+        ConfigureVectorUnit(vtype, /*vlen*/ 1024);
+
+        // Try different strides.
+        for (int s : strides) {
+          int32_t stride = s * num_fields * sizeof(T);
+          // Set the stride.
+          SetRegisterValues<int32_t>({{kRs2Name, stride}});
+          // Execute instruction.
+          instruction_->Execute(nullptr);
+
+          if (lmul8 * num_fields > 64) {
+            EXPECT_TRUE(rv_vector_->vector_exception());
+            rv_vector_->clear_vector_exception();
+            continue;
+          }
+          int emul = sizeof(T) * lmul8 / rv_vector_->selected_element_width();
+          if (emul == 0 || emul > 64) {
+            EXPECT_TRUE(rv_vector_->vector_exception());
+            rv_vector_->clear_vector_exception();
+            continue;
+          }
+          EXPECT_FALSE(rv_vector_->vector_exception());
+          // Check memory values.
+          uint64_t base = kDataStoreAddress;
+          T value = 1;
+          int vlen = rv_vector_->vector_length();
+          int num_regs = std::max(1, lmul8 / 8);
+          // Iterate over fields.
+          for (int field = 0; field < num_fields; field++) {
+            uint64_t address = base + field * sizeof(T);
+            // Iterate over the registers used for each field.
+            int segment_no = 0;
+            for (int reg = 0; reg < num_regs; reg++) {
+              // Iterate over segments within each register.
+              for (int i = 0; i < num_values_per_register; i++) {
+                // Load the data.
+                data_db->template Set<T>(0, 0);
+                state_->DbgLoadMemory(address, data_db);
+                // Get the mask value.
+                int mask_index = segment_no >> 3;
+                int mask_offset = segment_no & 0b111;
+                bool mask = (kA5Mask[mask_index] >> mask_offset) & 0x1;
+                if (mask && (segment_no < vlen)) {
+                  EXPECT_EQ(data_db->template Get<T>(0), static_cast<T>(value))
+                      << std::hex << "address: 0x" << address << std::dec
+                      << " segment_no: " << segment_no << " field: " << field
+                      << " index: " << i << " element_size: " << sizeof(T)
+                      << " lmul8: " << lmul8 << " stride: " << stride;
+                  // Zero the memory location.
+                  data_db->template Set<T>(0, 0);
+                  state_->StoreMemory(instruction_, address, data_db);
+                } else {
+                  EXPECT_EQ(data_db->template Get<T>(0), 0)
+                      << std::hex << "address: 0x" << address << std::dec
+                      << " segment_no: " << segment_no << " field: " << field
+                      << " index: " << i << " element_size: " << sizeof(T)
+                      << " lmul8: " << lmul8 << " stride: " << stride
+                      << " mask: " << mask;
+                }
+                value++;
+                address += stride;
+                segment_no++;
+              }
+            }
+          }
+          if (HasFailure()) {
+            data_db->DecRef();
+            return;
+          }
+        }
+      }
+    }
+    data_db->DecRef();
+  }
+
+  // Helper function to test vector store segment indexed instructions.
+  template <typename T, typename I>
+  void VectorStoreIndexedSegmentHelper() {
+    // Ensure that the IndexType is signed.
+    using IndexType = typename std::make_signed<I>::type;
+    // Set up instructions.
+    // Store data register.
+    AppendVectorRegisterOperands({kVs1}, {});
+    // Base address.
+    AppendRegisterOperands({kRs1Name}, {});
+    // Vector index register and vector mask register.
+    AppendVectorRegisterOperands({kVs2, kVmask}, {});
+    // Operand to hold the number of fields.
+    AppendRegisterOperands({kRs3Name}, {});
+    // Bind semantic function.
+    SetSemanticFunction(
+        absl::bind_front(&VsSegmentIndexed, /*index_width*/ sizeof(IndexType)));
+
+    // Set up register values.
+    // Set the store data register elements to be consecutive integers.
+    for (int reg = 0; reg < 8; reg++) {
+      auto reg_span = vreg_[reg + kVs1]->data_buffer()->Get<T>();
+      for (int i = 0; i < reg_span.size(); i++) {
+        reg_span[i] = static_cast<T>(reg * reg_span.size() + i + 1);
+      }
+    }
+    // Base address.
+    SetRegisterValues<uint32_t>({{kRs1Name, kDataStoreAddress}});
+    // Initialize the mask.
+    SetVectorRegisterValues<uint8_t>(
+        {{kVmaskName, Span<const uint8_t>(kA5Mask)}});
+    // Index values.
+ int index_values_per_reg = kVectorLengthInBytes / sizeof(IndexType); + int num_values_per_register = kVectorLengthInBytes / sizeof(T); + + auto *data_db = state_->db_factory()->Allocate<T>(1); + // Iterate over legal values in the nf field. + for (int num_fields = 1; num_fields < 8; num_fields++) { + // Set the number of fields in the source operand. + SetRegisterValues<int32_t>({{kRs3Name, num_fields - 1}}); + // Iterate over different lmul values. + for (int lmul_index = 0; lmul_index < 7; lmul_index++) { + // Configure vector unit, set the sew to the element width. + uint32_t vtype = (kSewSettingsByByteSize[sizeof(T)] << 3) | + kLmulSettings[lmul_index]; + int lmul8 = kLmul8Values[lmul_index]; + + // Set up Index vector. + // Max number of segments (for testing) is limited by the range of the + // index type. For byte indices, the range is only +/- 128 for each + // index. Since the index isn't scaled, this means that the max number + // of segments for byte indices is 256 / (sizeof(T) * nf) + int segment_size = sizeof(T) * num_fields; + int max_segments = + kVectorLengthInBytes * lmul8 / (8 * num_fields * sizeof(T)); + if (sizeof(IndexType) == 1) { + max_segments = std::min(256 / num_fields, max_segments); + } + ConfigureVectorUnit(vtype, max_segments); + + int emul8 = + sizeof(IndexType) * lmul8 / rv_vector_->selected_element_width(); + // Make sure not to write too many indices. At this point the emul + // value may still be "illegal", so avoid a "crash" due to writing + // the data_buffer out of range. + int max_indices = kVectorLengthInBytes * + std::min(8, std::max(1, emul8 / 8)) / + sizeof(IndexType); + if (max_indices > max_segments) { + max_indices = max_segments; + } + // Verify that the index space is large enough. That means that the + // index data type can contain enough index values to spread out all the + // stores to unique locations. + int num_values = kVectorLengthInBytes * lmul8 / (sizeof(T) * 8); + if (((num_values * sizeof(T)) >> 8) > (1 << (sizeof(IndexType) - 1))) { + LOG(WARNING) << "Index space is too small for the number of bytes: " + << num_values * sizeof(T); + continue; + } + for (int i = 0; i < max_indices; i++) { + int reg = i / index_values_per_reg; + int element = i % index_values_per_reg; + // Scale index by the segment size to avoid writing to the same + // location twice. + vreg_[kVs2 + reg]->data_buffer()->Set<IndexType>( + element, IndexValue<IndexType>(i) * segment_size); + } + // Execute instruction. + instruction_->Execute(nullptr); + + // Check for exceptions when they should be set, and verify no + // exception otherwise. + if (lmul8 * num_fields > 64) { + EXPECT_TRUE(rv_vector_->vector_exception()) + << "emul8: " << emul8 << " lmul8: " << lmul8; + rv_vector_->clear_vector_exception(); + continue; + } + if (emul8 == 0 || emul8 > 64) { + EXPECT_TRUE(rv_vector_->vector_exception()) + << "emul8: " << emul8 << " lmul8: " << lmul8; + rv_vector_->clear_vector_exception(); + continue; + } + EXPECT_FALSE(rv_vector_->vector_exception()) + << "emul8: " << emul8 << " lmul8: " << lmul8; + + // Check memory values. + uint64_t base = kDataStoreAddress; + T value = 1; + int vlen = rv_vector_->vector_length(); + // Iterate over fields. + for (int field = 0; field < num_fields; field++) { + // Expected value starts at 1 for the first register element and + // increments from there. The following computes the expected value + // of segment 0 for each field. 
+          value = field * kVectorLengthInBytes * std::max(1, lmul8 / 8) /
+                      sizeof(T) +
+                  1;
+          for (int segment = 0; segment < max_segments; segment++) {
+            int index_reg = segment / index_values_per_reg;
+            int index_no = segment % index_values_per_reg;
+            // Load the data.
+            int64_t index =
+                vreg_[kVs2 + index_reg]->data_buffer()->Get<IndexType>(
+                    index_no);
+            uint64_t address = base + field * sizeof(T) + index;
+            state_->DbgLoadMemory(address, data_db);
+            int element = segment % num_values_per_register;
+            // Get the mask value.
+            int mask_index = segment >> 3;
+            int mask_offset = segment & 0b111;
+            bool mask = (kA5Mask[mask_index] >> mask_offset) & 0x1;
+            if (mask && (segment < vlen)) {
+              EXPECT_EQ(data_db->template Get<T>(0), static_cast<T>(value))
+                  << std::hex << "address: 0x" << address << std::dec
+                  << " index: " << index << " segment_no: " << segment
+                  << " field: " << field << " i: " << element
+                  << " element_size: " << sizeof(T) << " lmul8: " << lmul8
+                  << " num_fields: " << num_fields;
+              // Zero the memory location.
+              data_db->template Set<T>(0, 0);
+              state_->StoreMemory(instruction_, address, data_db);
+            } else {
+              EXPECT_EQ(data_db->template Get<T>(0), 0)
+                  << std::hex << "address: 0x" << address << std::dec
+                  << " index: " << index << " segment_no: " << segment
+                  << " field: " << field << " i: " << element
+                  << " element_size: " << sizeof(T) << " lmul8: " << lmul8
+                  << " mask: " << mask;
+            }
+            value++;
+          }
+        }
+        if (HasFailure()) {
+          data_db->DecRef();
+          return;
+        }
+      }
+    }
+    data_db->DecRef();
+  }
+
+ protected:
+  CheriotRegister *creg_[32];
+  RVVectorRegister *vreg_[32];
+  CheriotState *state_;
+  Instruction *instruction_;
+  Instruction *child_instruction_;
+  TaggedFlatDemandMemory *memory_;
+  CheriotVectorState *rv_vector_;
+};
+
+// Test the vector configuration set instructions. There are three separate
+// versions depending on whether Rs1 is X0 and whether Rd is X0.
+// The first handles the case when Rs1 is not X0.
+TEST_F(RiscVCheriotVInstructionsTest, VsetvlNN) {
+  AppendRegisterOperands({kRs1Name, kRs2Name}, {kRdName});
+  SetSemanticFunction(absl::bind_front(&Vsetvl,
+                                       /*rd_zero*/ false,
+                                       /*rs1_zero*/ false));
+  for (int lmul = 0; lmul < 7; lmul++) {
+    for (int sew = 0; sew < 4; sew++) {
+      for (int vlen_select = 0; vlen_select < 2; vlen_select++) {
+        // Try vlen below max and above.
+        uint32_t vlen = (vlen_select == 0) ? 16 : 1024;
+        uint32_t vma = (lmul & 1) ? 0b1'0'000'000 : 0;
+        uint32_t vta = (sew & 1) ? 0b0'1'000'000 : 0;
+        uint32_t vtype =
+            vma | vta | (kSewSettings[sew] << 3) | kLmulSettings[lmul];
+
+        SetRegisterValues<uint32_t>(
+            {{kRs1Name, vlen}, {kRs2Name, vtype}, {kRdName, 0}});
+
+        // Execute instruction.
+        instruction_->Execute(nullptr);
+
+        // Check results.
+        uint32_t expected_vlen =
+            std::min<uint32_t>(vlen, kVectorLengthInBytes * kLmul8Values[lmul] /
+                                         (8 * kSewValues[sew]));
+        EXPECT_EQ(creg_[kRd]->data_buffer()->Get<uint32_t>(0), expected_vlen)
+            << "LMUL: " << kLmul8Values[lmul] << " SEW: " << kSewValues[sew]
+            << " AVL: " << vlen;
+        EXPECT_EQ(rv_vector_->vector_length(), expected_vlen);
+        EXPECT_EQ(rv_vector_->vector_mask_agnostic(), vma != 0);
+        EXPECT_EQ(rv_vector_->vector_tail_agnostic(), vta != 0);
+        EXPECT_EQ(rv_vector_->vector_length_multiplier(), kLmul8Values[lmul]);
+        EXPECT_EQ(rv_vector_->selected_element_width(), kSewValues[sew]);
+      }
+    }
+  }
+}
+
+// The case when Rd is X0, but not Rs1.
+TEST_F(RiscVCheriotVInstructionsTest, VsetvlZN) { + AppendRegisterOperands({kRs1Name, kRs2Name}, {kRdName}); + SetSemanticFunction(absl::bind_front(&Vsetvl, + /*rd_zero*/ true, /*rs1_zero*/ false)); + for (int lmul = 0; lmul < 7; lmul++) { + for (int sew = 0; sew < 4; sew++) { + for (int vlen_select = 0; vlen_select < 2; vlen_select++) { + // Try vlen below max and above. + uint32_t vlen = (vlen_select == 0) ? 16 : 1024; + uint32_t vma = (lmul & 1) ? 0b1'0'000'000 : 0; + uint32_t vta = (sew & 1) ? 0b0'1'000'000 : 0; + uint32_t vtype = + vma | vta | (kSewSettings[sew] << 3) | kLmulSettings[lmul]; + + SetRegisterValues<uint32_t>( + {{kRs1Name, vlen}, {kRs2Name, vtype}, {kRdName, 0}}); + + // Execute instruction. + instruction_->Execute(nullptr); + + // Check results. + uint32_t expected_vlen = + std::min<uint32_t>(vlen, kVectorLengthInBytes * kLmul8Values[lmul] / + (8 * kSewValues[sew])); + EXPECT_EQ(creg_[kRd]->data_buffer()->Get<uint32_t>(0), 0); + EXPECT_EQ(rv_vector_->vector_length(), expected_vlen); + EXPECT_EQ(rv_vector_->vector_mask_agnostic(), vma != 0); + EXPECT_EQ(rv_vector_->vector_tail_agnostic(), vta != 0); + EXPECT_EQ(rv_vector_->vector_length_multiplier(), kLmul8Values[lmul]); + EXPECT_EQ(rv_vector_->selected_element_width(), kSewValues[sew]); + } + } + } +} + +// The case when Rd is not X0, but Rs1 is X0. +TEST_F(RiscVCheriotVInstructionsTest, VsetvlNZ) { + AppendRegisterOperands({kRs1Name, kRs2Name}, {kRdName}); + SetSemanticFunction(absl::bind_front(&Vsetvl, + /*rd_zero*/ false, /*rs1_zero*/ true)); + for (int lmul = 0; lmul < 7; lmul++) { + for (int sew = 0; sew < 4; sew++) { + for (int vlen_select = 0; vlen_select < 2; vlen_select++) { + // Try vlen below max and above. + uint32_t vlen = (vlen_select == 0) ? 16 : 1024; + uint32_t vma = (lmul & 1) ? 0b1'0'000'000 : 0; + uint32_t vta = (sew & 1) ? 0b0'1'000'000 : 0; + uint32_t vtype = + vma | vta | (kSewSettings[sew] << 3) | kLmulSettings[lmul]; + + SetRegisterValues<uint32_t>( + {{kRs1Name, vlen}, {kRs2Name, vtype}, {kRdName, 0}}); + + // Execute instruction. + instruction_->Execute(nullptr); + + // Check results. + // In this case, vlen is vlen max. + uint32_t expected_vlen = + kVectorLengthInBytes * kLmul8Values[lmul] / (8 * kSewValues[sew]); + EXPECT_EQ(creg_[kRd]->data_buffer()->Get<uint32_t>(0), expected_vlen); + EXPECT_EQ(rv_vector_->vector_length(), expected_vlen); + EXPECT_EQ(rv_vector_->vector_mask_agnostic(), vma != 0); + EXPECT_EQ(rv_vector_->vector_tail_agnostic(), vta != 0); + EXPECT_EQ(rv_vector_->vector_length_multiplier(), kLmul8Values[lmul]); + EXPECT_EQ(rv_vector_->selected_element_width(), kSewValues[sew]); + } + } + } +} + +// The case when Rd and Rs1 are X0. In this case we are testing if an invalid +// vector operation happens, which occurs if the max vector length changes due +// to the new value of vtype. +TEST_F(RiscVCheriotVInstructionsTest, VsetvlZZ) { + AppendRegisterOperands({kRs1Name, kRs2Name}, {kRdName}); + SetSemanticFunction(absl::bind_front(&Vsetvl, + /*rd_zero*/ true, /*rs1_zero*/ true)); + // Iterate over vector lengths. + for (int vlen = 512; vlen > 8; vlen /= 2) { + // First set the appropriate vector type for sew = 1 byte. + uint32_t lmul8 = vlen * 8 / kVectorLengthInBytes; + ASSERT_LE(lmul8, 64); + ASSERT_GE(lmul8, 1); + int lmul8_log2 = absl::bit_width<uint32_t>(lmul8); + int lmul_setting = kLmulSettingByLogSize[lmul8_log2]; + // Set vtype for this vector length. 
+  rv_vector_->SetVectorType(lmul_setting);
+  int max_vector_length = rv_vector_->max_vector_length();
+  for (int lmul = 0; lmul < 7; lmul++) {
+    for (int sew = 0; sew < 4; sew++) {
+      // Clear any exception.
+      rv_vector_->clear_vector_exception();
+      // Set up the vtype to try to set using the instruction.
+      uint32_t vma = (lmul & 1) ? 0b1'0'000'000 : 0;
+      uint32_t vta = (sew & 1) ? 0b0'1'000'000 : 0;
+      uint32_t vtype =
+          vma | vta | (kSewSettings[sew] << 3) | kLmulSettings[lmul];
+
+      SetRegisterValues<uint32_t>(
+          {{kRs1Name, 0xdeadbeef}, {kRs2Name, vtype}, {kRdName, 0xdeadbeef}});
+
+      // Execute instruction.
+      instruction_->Execute(nullptr);
+
+      // Check results.
+      uint32_t new_vlen =
+          kVectorLengthInBytes * kLmul8Values[lmul] / (8 * kSewValues[sew]);
+      // If vlen changes, then we expect an error and no change.
+      if (new_vlen != max_vector_length) {
+        EXPECT_TRUE(rv_vector_->vector_exception())
+            << "vlen: " << max_vector_length
+            << " lmul: " << kLmul8Values[lmul] << " sew: " << kSewValues[sew];
+      } else {
+        // Otherwise, check that the values are as expected.
+        EXPECT_FALSE(rv_vector_->vector_exception());
+        EXPECT_EQ(rv_vector_->vector_length_multiplier(), kLmul8Values[lmul]);
+        EXPECT_EQ(rv_vector_->selected_element_width(), kSewValues[sew]);
+        EXPECT_EQ(rv_vector_->vector_mask_agnostic(), vma != 0);
+        EXPECT_EQ(rv_vector_->vector_tail_agnostic(), vta != 0);
+      }
+      // No change in registers or vector length.
+      EXPECT_EQ(creg_[kRs1]->data_buffer()->Get<uint32_t>(0), 0xdeadbeef);
+      EXPECT_EQ(creg_[kRd]->data_buffer()->Get<uint32_t>(0), 0xdeadbeef);
+      EXPECT_EQ(rv_vector_->max_vector_length(), max_vector_length);
+    }
+  }
+}
+
+// This tests the semantic function for the VleN and VlseN instructions. VleN
+// is just a unit stride Vlse.
+TEST_F(RiscVCheriotVInstructionsTest, Vle8) {
+  VectorLoadUnitStridedHelper<uint8_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, Vle16) {
+  VectorLoadUnitStridedHelper<uint16_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, Vle32) {
+  VectorLoadUnitStridedHelper<uint32_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, Vle64) {
+  VectorLoadUnitStridedHelper<uint64_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, Vlse8) {
+  VectorLoadStridedHelper<uint8_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, Vlse16) {
+  VectorLoadStridedHelper<uint16_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, Vlse32) {
+  VectorLoadStridedHelper<uint32_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, Vlse64) {
+  VectorLoadStridedHelper<uint64_t>();
+}
+
+// Test of vector load mask.
+TEST_F(RiscVCheriotVInstructionsTest, Vlm) {
+  // Set up operands and register values.
+  AppendRegisterOperands({kRs1Name}, {});
+  SetSemanticFunction(&Vlm);
+  SetChildInstruction();
+  AppendVectorRegisterOperands(child_instruction_, {}, {kVd});
+  SetChildSemanticFunction(&VlChild);
+  SetRegisterValues<uint32_t>({{kRs1Name, kDataLoadAddress}});
+  // Execute instruction.
+  instruction_->Execute(nullptr);
+  EXPECT_FALSE(rv_vector_->vector_exception());
+  auto span = vreg_[kVd]->data_buffer()->Get<uint8_t>();
+  for (int i = 0; i < kVectorLengthInBytes; i++) {
+    EXPECT_EQ(i & 0xff, span[i]) << "element: " << i;
+  }
+}
+
+// Test of vector load register. Loads 1, 2, 4 or 8 registers.
+TEST_F(RiscVCheriotVInstructionsTest, VlRegister) {
+  // Set up operands and register values.
+  AppendRegisterOperands({kRs1Name}, {});
+  SetChildInstruction();
+  AppendVectorRegisterOperands(child_instruction_, {}, {kVd});
+  SetChildSemanticFunction(&VlChild);
+  SetRegisterValues<uint32_t>({{kRs1Name, kDataLoadAddress}});
+  // Test 1, 2, 4 and 8 register versions.
+  for (int num_reg = 1; num_reg <= 8; num_reg *= 2) {
+    SetSemanticFunction(
+        absl::bind_front(&VlRegister, num_reg, /*element_width*/ 1));
+    // Execute instruction.
+    instruction_->Execute(nullptr);
+    // Check values.
+    for (int reg = kVd; reg < kVd + num_reg; reg++) {
+      auto span = vreg_[reg]->data_buffer()->Get<uint8_t>();
+      for (int i = 0; i < kVectorLengthInBytes; i++) {
+        EXPECT_EQ(span[i], i & 0xff)
+            << absl::StrCat("Reg: ", reg, " element ", i);
+      }
+    }
+  }
+}
+
+// Indexed loads directly encode the element width of the index value. The
+// width of the load value is determined by sew (selected element width).
+TEST_F(RiscVCheriotVInstructionsTest, VlIndexed8_8) {
+  VectorLoadIndexedHelper<uint8_t, uint8_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, VlIndexed8_16) {
+  VectorLoadIndexedHelper<uint8_t, uint16_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, VlIndexed8_32) {
+  VectorLoadIndexedHelper<uint8_t, uint32_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, VlIndexed8_64) {
+  VectorLoadIndexedHelper<uint8_t, uint64_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, VlIndexed16_8) {
+  VectorLoadIndexedHelper<uint16_t, uint8_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, VlIndexed16_16) {
+  VectorLoadIndexedHelper<uint16_t, uint16_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, VlIndexed16_32) {
+  VectorLoadIndexedHelper<uint16_t, uint32_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, VlIndexed16_64) {
+  VectorLoadIndexedHelper<uint16_t, uint64_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, VlIndexed32_8) {
+  VectorLoadIndexedHelper<uint32_t, uint8_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, VlIndexed32_16) {
+  VectorLoadIndexedHelper<uint32_t, uint16_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, VlIndexed32_32) {
+  VectorLoadIndexedHelper<uint32_t, uint32_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, VlIndexed32_64) {
+  VectorLoadIndexedHelper<uint32_t, uint64_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, VlIndexed64_8) {
+  VectorLoadIndexedHelper<uint64_t, uint8_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, VlIndexed64_16) {
+  VectorLoadIndexedHelper<uint64_t, uint16_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, VlIndexed64_32) {
+  VectorLoadIndexedHelper<uint64_t, uint32_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, VlIndexed64_64) {
+  VectorLoadIndexedHelper<uint64_t, uint64_t>();
+}
+
+// Test vector load segment unit stride.
+TEST_F(RiscVCheriotVInstructionsTest, Vlsege8) {
+  VectorLoadSegmentHelper<uint8_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, Vlsege16) {
+  VectorLoadSegmentHelper<uint16_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, Vlsege32) {
+  VectorLoadSegmentHelper<uint32_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, Vlsege64) {
+  VectorLoadSegmentHelper<uint64_t>();
+}
+
+// Test vector load segment, strided.
+TEST_F(RiscVCheriotVInstructionsTest, Vlssege8) {
+  VectorLoadStridedSegmentHelper<uint8_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, Vlssege16) {
+  VectorLoadStridedSegmentHelper<uint16_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, Vlssege32) {
+  VectorLoadStridedSegmentHelper<uint32_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, Vlssege64) {
+  VectorLoadStridedSegmentHelper<uint64_t>();
+}
+
+// Test vector load segment, indexed. Test each combination of index width
+// and value width.
+TEST_F(RiscVCheriotVInstructionsTest, Vluxsegei8_8) {
+  VectorLoadIndexedSegmentHelper<uint8_t, uint8_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, Vluxsegei8_16) {
+  VectorLoadIndexedSegmentHelper<uint8_t, uint16_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, Vluxsegei8_32) {
+  VectorLoadIndexedSegmentHelper<uint8_t, uint32_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, Vluxsegei8_64) {
+  VectorLoadIndexedSegmentHelper<uint8_t, uint64_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, Vluxsegei16_8) {
+  VectorLoadIndexedSegmentHelper<uint16_t, uint8_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, Vluxsegei16_16) {
+  VectorLoadIndexedSegmentHelper<uint16_t, uint16_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, Vluxsegei16_32) {
+  VectorLoadIndexedSegmentHelper<uint16_t, uint32_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, Vluxsegei16_64) {
+  VectorLoadIndexedSegmentHelper<uint16_t, uint64_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, Vluxsegei32_8) {
+  VectorLoadIndexedSegmentHelper<uint32_t, uint8_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, Vluxsegei32_16) {
+  VectorLoadIndexedSegmentHelper<uint32_t, uint16_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, Vluxsegei32_32) {
+  VectorLoadIndexedSegmentHelper<uint32_t, uint32_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, Vluxsegei32_64) {
+  VectorLoadIndexedSegmentHelper<uint32_t, uint64_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, Vluxsegei64_8) {
+  VectorLoadIndexedSegmentHelper<uint64_t, uint8_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, Vluxsegei64_16) {
+  VectorLoadIndexedSegmentHelper<uint64_t, uint16_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, Vluxsegei64_32) {
+  VectorLoadIndexedSegmentHelper<uint64_t, uint32_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, Vluxsegei64_64) {
+  VectorLoadIndexedSegmentHelper<uint64_t, uint64_t>();
+}
+
+// Test vector store strided.
+
+TEST_F(RiscVCheriotVInstructionsTest, Vsse8) {
+  VectorStoreStridedHelper<uint8_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, Vsse16) {
+  VectorStoreStridedHelper<uint16_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, Vsse32) {
+  VectorStoreStridedHelper<uint32_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, Vsse64) {
+  VectorStoreStridedHelper<uint64_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, Vsm) {
+  ConfigureVectorUnit(0b0'0'000'000, /*vlen*/ 1024);
+  // Set up operands and register values.
+  AppendVectorRegisterOperands({kVs1}, {});
+  AppendRegisterOperands({kRs1Name}, {});
+  SetSemanticFunction(&Vsm);
+  SetRegisterValues<uint32_t>({{kRs1Name, kDataStoreAddress}});
+  for (int i = 0; i < kVectorLengthInBytes; i++) {
+    vreg_[kVs1]->data_buffer()->Set<uint8_t>(i, i);
+  }
+  // Execute instruction.
+  instruction_->Execute(nullptr);
+
+  // Verify result.
+  EXPECT_FALSE(rv_vector_->vector_exception());
+  auto *data_db =
+      state_->db_factory()->Allocate<uint8_t>(kVectorLengthInBytes);
+  state_->DbgLoadMemory(kDataStoreAddress, data_db);
+  auto span = data_db->Get<uint8_t>();
+  for (int i = 0; i < kVectorLengthInBytes; i++) {
+    EXPECT_EQ(static_cast<int>(span[i]), i);
+  }
+  data_db->DecRef();
+}
+
+// Tests of indexed stores, cross product of index types with value types.
+TEST_F(RiscVCheriotVInstructionsTest, VsIndexed8_8) {
+  VectorStoreIndexedHelper<uint8_t, uint8_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, VsIndexed8_16) {
+  VectorStoreIndexedHelper<uint8_t, uint16_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, VsIndexed8_32) {
+  VectorStoreIndexedHelper<uint8_t, uint32_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, VsIndexed8_64) {
+  VectorStoreIndexedHelper<uint8_t, uint64_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, VsIndexed16_8) {
+  VectorStoreIndexedHelper<uint16_t, uint8_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, VsIndexed16_16) {
+  VectorStoreIndexedHelper<uint16_t, uint16_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, VsIndexed16_32) {
+  VectorStoreIndexedHelper<uint16_t, uint32_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, VsIndexed16_64) {
+  VectorStoreIndexedHelper<uint16_t, uint64_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, VsIndexed32_8) {
+  VectorStoreIndexedHelper<uint32_t, uint8_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, VsIndexed32_16) {
+  VectorStoreIndexedHelper<uint32_t, uint16_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, VsIndexed32_32) {
+  VectorStoreIndexedHelper<uint32_t, uint32_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, VsIndexed32_64) {
+  VectorStoreIndexedHelper<uint32_t, uint64_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, VsIndexed64_8) {
+  VectorStoreIndexedHelper<uint64_t, uint8_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, VsIndexed64_16) {
+  VectorStoreIndexedHelper<uint64_t, uint16_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, VsIndexed64_32) {
+  VectorStoreIndexedHelper<uint64_t, uint32_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, VsIndexed64_64) {
+  VectorStoreIndexedHelper<uint64_t, uint64_t>();
+}
+
+TEST_F(RiscVCheriotVInstructionsTest, VsRegister) {
+  ConfigureVectorUnit(0b0'0'000'000, /*vlen*/ 1024);
+  int num_elem = kVectorLengthInBytes / sizeof(uint64_t);
+  // Set up operands and register values.
+  AppendVectorRegisterOperands({kVs1}, {});
+  for (int reg = 0; reg < 8; reg++) {
+    for (int i = 0; i < num_elem; i++) {
+      vreg_[kVs1 + reg]->data_buffer()->Set<uint64_t>(i, reg * num_elem + i);
+    }
+  }
+  AppendRegisterOperands({kRs1Name}, {});
+  SetRegisterValues<uint32_t>({{kRs1Name, kDataStoreAddress}});
+
+  auto data_db = state_->db_factory()->Allocate(8 * kVectorLengthInBytes);
+  for (int num_regs = 1; num_regs <= 8; num_regs++) {
+    // Clear memory.
+    memset(data_db->raw_ptr(), 0, data_db->size<uint8_t>());
+    state_->DbgStoreMemory(kDataStoreAddress, data_db);
+
+    SetSemanticFunction(absl::bind_front(&VsRegister, num_regs));
+
+    // Execute instruction.
+    instruction_->Execute(nullptr);
+
+    // Verify results.
+ EXPECT_FALSE(rv_vector_->vector_exception()); + uint64_t base = kDataStoreAddress; + for (int reg = 0; reg < 8; reg++) { + state_->DbgLoadMemory(base, data_db); + auto span = data_db->Get<uint64_t>(); + for (int i = 0; i < num_elem; i++) { + if (reg < num_regs) { + EXPECT_EQ(span[i], reg * num_elem + i) + << "reg[" << reg << "][" << i << "]"; + } else { + EXPECT_EQ(span[i], 0) << "reg[" << reg << "][" << i << "]"; + } + } + base += kVectorLengthInBytes; + } + } + data_db->DecRef(); +} + +// Test vector store segment unit stride. +TEST_F(RiscVCheriotVInstructionsTest, VsSegment8) { + VectorStoreSegmentHelper<uint8_t>(); +} + +TEST_F(RiscVCheriotVInstructionsTest, VsSegment16) { + VectorStoreSegmentHelper<uint16_t>(); +} + +TEST_F(RiscVCheriotVInstructionsTest, VsSegment32) { + VectorStoreSegmentHelper<uint32_t>(); +} + +TEST_F(RiscVCheriotVInstructionsTest, VsSegment64) { + VectorStoreSegmentHelper<uint64_t>(); +} + +// Test vector store segment strided. +TEST_F(RiscVCheriotVInstructionsTest, VsSegmentStrided8) { + VectorStoreStridedSegmentHelper<uint8_t>(); +} +TEST_F(RiscVCheriotVInstructionsTest, VsSegmentStrided16) { + VectorStoreStridedSegmentHelper<uint16_t>(); +} +TEST_F(RiscVCheriotVInstructionsTest, VsSegmentStrided32) { + VectorStoreStridedSegmentHelper<uint32_t>(); +} +TEST_F(RiscVCheriotVInstructionsTest, VsSegmentStrided64) { + VectorStoreStridedSegmentHelper<uint64_t>(); +} + +// Test vector store segment indexed. Test each +// combination of element size and index size. +TEST_F(RiscVCheriotVInstructionsTest, VsSegmentIndexed8_8) { + VectorStoreIndexedSegmentHelper<uint8_t, int8_t>(); +} + +TEST_F(RiscVCheriotVInstructionsTest, VsSegmentIndexed8_16) { + VectorStoreIndexedSegmentHelper<uint8_t, int16_t>(); +} + +TEST_F(RiscVCheriotVInstructionsTest, VsSegmentIndexed8_32) { + VectorStoreIndexedSegmentHelper<uint8_t, int32_t>(); +} + +TEST_F(RiscVCheriotVInstructionsTest, VsSegmentIndexed8_64) { + VectorStoreIndexedSegmentHelper<uint8_t, int64_t>(); +} + +TEST_F(RiscVCheriotVInstructionsTest, VsSegmentIndexed16_8) { + VectorStoreIndexedSegmentHelper<uint16_t, int8_t>(); +} + +TEST_F(RiscVCheriotVInstructionsTest, VsSegmentIndexed16_16) { + VectorStoreIndexedSegmentHelper<uint16_t, int16_t>(); +} + +TEST_F(RiscVCheriotVInstructionsTest, VsSegmentIndexed16_32) { + VectorStoreIndexedSegmentHelper<uint16_t, int32_t>(); +} + +TEST_F(RiscVCheriotVInstructionsTest, VsSegmentIndexed16_64) { + VectorStoreIndexedSegmentHelper<uint16_t, int64_t>(); +} + +TEST_F(RiscVCheriotVInstructionsTest, VsSegmentIndexed32_8) { + VectorStoreIndexedSegmentHelper<uint32_t, int8_t>(); +} + +TEST_F(RiscVCheriotVInstructionsTest, VsSegmentIndexed32_16) { + VectorStoreIndexedSegmentHelper<uint32_t, int16_t>(); +} + +TEST_F(RiscVCheriotVInstructionsTest, VsSegmentIndexed32_32) { + VectorStoreIndexedSegmentHelper<uint32_t, int32_t>(); +} + +TEST_F(RiscVCheriotVInstructionsTest, VsSegmentIndexed32_64) { + VectorStoreIndexedSegmentHelper<uint32_t, int64_t>(); +} + +TEST_F(RiscVCheriotVInstructionsTest, VsSegmentIndexed64_8) { + VectorStoreIndexedSegmentHelper<uint64_t, int8_t>(); +} + +TEST_F(RiscVCheriotVInstructionsTest, VsSegmentIndexed64_16) { + VectorStoreIndexedSegmentHelper<uint64_t, int16_t>(); +} + +TEST_F(RiscVCheriotVInstructionsTest, VsSegmentIndexed64_32) { + VectorStoreIndexedSegmentHelper<uint64_t, int32_t>(); +} + +TEST_F(RiscVCheriotVInstructionsTest, VsSegmentIndexed64_64) { + VectorStoreIndexedSegmentHelper<uint64_t, int64_t>(); +} +} // namespace
diff --git a/cheriot/test/riscv_cheriot_vector_opi_instructions_test.cc b/cheriot/test/riscv_cheriot_vector_opi_instructions_test.cc new file mode 100644 index 0000000..268099f --- /dev/null +++ b/cheriot/test/riscv_cheriot_vector_opi_instructions_test.cc
@@ -0,0 +1,2381 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cheriot/riscv_cheriot_vector_opi_instructions.h"
+
+#include <cstdint>
+#include <limits>
+#include <vector>
+
+#include "absl/functional/bind_front.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "cheriot/cheriot_vector_state.h"
+#include "cheriot/test/riscv_cheriot_vector_instructions_test_base.h"
+#include "googlemock/include/gmock/gmock.h"
+#include "mpact/sim/generic/instruction.h"
+#include "riscv//riscv_register.h"
+
+// This file contains test cases for most of the RiscV OPIVV, OPIVX and OPIVI
+// instructions. The only instructions not covered by this file are the vector
+// permutation instructions.
+
+namespace {
+
+using ::absl::Span;
+using ::mpact::sim::cheriot::CheriotVectorState;
+using ::mpact::sim::generic::Instruction;
+using ::mpact::sim::generic::MakeUnsigned;
+using ::mpact::sim::generic::WideType;
+using ::mpact::sim::riscv::RV32Register;
+using ::mpact::sim::riscv::RVVectorRegister;
+
+// Semantic functions.
+using ::mpact::sim::cheriot::Vadc;
+using ::mpact::sim::cheriot::Vadd;
+using ::mpact::sim::cheriot::Vand;
+using ::mpact::sim::cheriot::Vmadc;
+using ::mpact::sim::cheriot::Vmax;
+using ::mpact::sim::cheriot::Vmaxu;
+using ::mpact::sim::cheriot::Vmerge;
+using ::mpact::sim::cheriot::Vmin;
+using ::mpact::sim::cheriot::Vminu;
+using ::mpact::sim::cheriot::Vmsbc;
+using ::mpact::sim::cheriot::Vmseq;
+using ::mpact::sim::cheriot::Vmsgt;
+using ::mpact::sim::cheriot::Vmsgtu;
+using ::mpact::sim::cheriot::Vmsle;
+using ::mpact::sim::cheriot::Vmsleu;
+using ::mpact::sim::cheriot::Vmslt;
+using ::mpact::sim::cheriot::Vmsltu;
+using ::mpact::sim::cheriot::Vmsne;
+using ::mpact::sim::cheriot::Vmvr;
+using ::mpact::sim::cheriot::Vnclip;
+using ::mpact::sim::cheriot::Vnclipu;
+using ::mpact::sim::cheriot::Vnsra;
+using ::mpact::sim::cheriot::Vnsrl;
+using ::mpact::sim::cheriot::Vor;
+using ::mpact::sim::cheriot::Vrsub;
+using ::mpact::sim::cheriot::Vsadd;
+using ::mpact::sim::cheriot::Vsaddu;
+using ::mpact::sim::cheriot::Vsbc;
+using ::mpact::sim::cheriot::Vsll;
+using ::mpact::sim::cheriot::Vsmul;
+using ::mpact::sim::cheriot::Vsra;
+using ::mpact::sim::cheriot::Vsrl;
+using ::mpact::sim::cheriot::Vssra;
+using ::mpact::sim::cheriot::Vssrl;
+using ::mpact::sim::cheriot::Vssub;
+using ::mpact::sim::cheriot::Vssubu;
+using ::mpact::sim::cheriot::Vsub;
+using ::mpact::sim::cheriot::Vxor;
+
+class RiscVCheriotVectorInstructionsTest
+    : public RiscVCheriotVectorInstructionsTestBase {};
+
+// Each instruction is tested for each element width, and for vector-vector
+// as well as vector-scalar (as applicable).
+
+// Vector add.
+// Vector-vector.
+TEST_F(RiscVCheriotVectorInstructionsTest, Vadd8VV) { + SetSemanticFunction(&Vadd); + BinaryOpTestHelperVV<uint8_t, uint8_t, uint8_t>( + "Vadd8", /*sew*/ 8, instruction_, + [](uint8_t val0, uint8_t val1) -> uint8_t { return val0 + val1; }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vadd16VV) { + SetSemanticFunction(&Vadd); + BinaryOpTestHelperVV<uint16_t, uint16_t, uint16_t>( + "Vadd16", /*sew*/ 16, instruction_, + [](uint16_t val0, uint16_t val1) -> uint16_t { return val0 + val1; }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vadd32VV) { + SetSemanticFunction(&Vadd); + BinaryOpTestHelperVV<uint32_t, uint32_t, uint32_t>( + "Vadd32", /*sew*/ 32, instruction_, + [](uint32_t val0, uint32_t val1) -> uint32_t { return val0 + val1; }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vadd64VV) { + SetSemanticFunction(&Vadd); + BinaryOpTestHelperVV<uint64_t, uint64_t, uint64_t>( + "Vadd64", /*sew*/ 64, instruction_, + [](uint64_t val0, uint64_t val1) -> uint64_t { return val0 + val1; }); +} + +// Vector-scalar. +TEST_F(RiscVCheriotVectorInstructionsTest, Vadd8VX) { + SetSemanticFunction(&Vadd); + + BinaryOpTestHelperVX<uint8_t, uint8_t, uint8_t>( + "Vadd8", /*sew*/ 8, instruction_, + [](uint8_t val0, uint8_t val1) -> uint8_t { + return val0 + static_cast<uint8_t>(val1); + }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vadd16VX) { + SetSemanticFunction(&Vadd); + BinaryOpTestHelperVX<uint16_t, uint16_t, uint16_t>( + "Vadd16", /*sew*/ 16, instruction_, + [](uint16_t val0, uint16_t val1) -> uint16_t { return val0 + val1; }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vadd32VX) { + SetSemanticFunction(&Vadd); + BinaryOpTestHelperVX<uint32_t, uint32_t, uint32_t>( + "Vadd32", /*sew*/ 32, instruction_, + [](uint32_t val0, uint32_t val1) -> uint32_t { return val0 + val1; }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vadd64VX) { + SetSemanticFunction(&Vadd); + BinaryOpTestHelperVX<uint64_t, uint64_t, uint64_t>( + "Vadd64", /*sew*/ 64, instruction_, + [](uint64_t val0, uint64_t val1) -> uint64_t { return val0 + val1; }); +} + +// Vector subtract. +// Vector-vector. +TEST_F(RiscVCheriotVectorInstructionsTest, Vsub8VV) { + SetSemanticFunction(&Vsub); + BinaryOpTestHelperVV<uint8_t, uint8_t, uint8_t>( + "Vsub8", /*sew*/ 8, instruction_, + [](uint8_t val0, uint8_t val1) -> uint8_t { return val0 - val1; }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vsub16VV) { + SetSemanticFunction(&Vsub); + BinaryOpTestHelperVV<uint16_t, uint16_t, uint16_t>( + "Vsub16", /*sew*/ 16, instruction_, + [](uint16_t val0, uint16_t val1) -> uint16_t { return val0 - val1; }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vsub32VV) { + SetSemanticFunction(&Vsub); + BinaryOpTestHelperVV<uint32_t, uint32_t, uint32_t>( + "Vsub32", /*sew*/ 32, instruction_, + [](uint32_t val0, uint32_t val1) -> uint32_t { return val0 - val1; }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vsub64VV) { + SetSemanticFunction(&Vsub); + BinaryOpTestHelperVV<uint64_t, uint64_t, uint64_t>( + "Vsub64", /*sew*/ 64, instruction_, + [](uint64_t val0, uint64_t val1) -> uint64_t { return val0 - val1; }); +} + +// Vector-scalar. 
+TEST_F(RiscVCheriotVectorInstructionsTest, Vsub8VX) { + SetSemanticFunction(&Vsub); + BinaryOpTestHelperVX<uint8_t, uint8_t, uint8_t>( + "Vsub8", /*sew*/ 8, instruction_, + [](uint8_t val0, uint8_t val1) -> uint8_t { return val0 - val1; }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vsub16VX) { + SetSemanticFunction(&Vsub); + BinaryOpTestHelperVX<uint16_t, uint16_t, uint16_t>( + "Vsub16", /*sew*/ 16, instruction_, + [](uint16_t val0, uint16_t val1) -> uint16_t { return val0 - val1; }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vsub32VX) { + SetSemanticFunction(&Vsub); + BinaryOpTestHelperVX<uint32_t, uint32_t, uint32_t>( + "Vsub32", /*sew*/ 32, instruction_, + [](uint32_t val0, uint32_t val1) -> uint32_t { return val0 - val1; }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vsub64VX) { + SetSemanticFunction(&Vsub); + BinaryOpTestHelperVX<uint64_t, uint64_t, uint64_t>( + "Vsub64", /*sew*/ 64, instruction_, + [](uint64_t val0, uint64_t val1) -> uint64_t { return val0 - val1; }); +} + +// Vector reverse subtract. +// Vector-Scalar only. +TEST_F(RiscVCheriotVectorInstructionsTest, Vrsub8VX) { + SetSemanticFunction(&Vrsub); + BinaryOpTestHelperVX<uint8_t, uint8_t, uint8_t>( + "Vrsub8", /*sew*/ 8, instruction_, + [](uint8_t val0, uint8_t val1) -> uint8_t { return val1 - val0; }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vrsub16VX) { + SetSemanticFunction(&Vrsub); + BinaryOpTestHelperVX<uint16_t, uint16_t, uint16_t>( + "Vrsub16", /*sew*/ 16, instruction_, + [](uint16_t val0, uint16_t val1) -> uint16_t { return val1 - val0; }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vrsub32VX) { + SetSemanticFunction(&Vrsub); + BinaryOpTestHelperVX<uint32_t, uint32_t, uint32_t>( + "Vrsub32", /*sew*/ 32, instruction_, + [](uint32_t val0, uint32_t val1) -> uint32_t { return val1 - val0; }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vrsub64VX) { + SetSemanticFunction(&Vrsub); + BinaryOpTestHelperVX<uint64_t, uint64_t, uint64_t>( + "Vrsub64", /*sew*/ 64, instruction_, + [](uint64_t val0, uint64_t val1) -> uint64_t { return val1 - val0; }); +} + +// Vector and. +// Vector-Vector. +TEST_F(RiscVCheriotVectorInstructionsTest, Vand8VV) { + SetSemanticFunction(&Vand); + + BinaryOpTestHelperVV<uint8_t, uint8_t, uint8_t>( + "Vand8", /*sew*/ 8, instruction_, + [](uint8_t val0, uint8_t val1) -> uint8_t { return val0 & val1; }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vand16VV) { + SetSemanticFunction(&Vand); + + BinaryOpTestHelperVV<uint16_t, uint16_t, uint16_t>( + "Vand16", /*sew*/ 16, instruction_, + [](uint16_t val0, uint16_t val1) -> uint16_t { return val0 & val1; }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vand32VV) { + SetSemanticFunction(&Vand); + + BinaryOpTestHelperVV<uint32_t, uint32_t, uint32_t>( + "Vand32", /*sew*/ 32, instruction_, + [](uint32_t val0, uint32_t val1) -> uint32_t { return val0 & val1; }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vand64VV) { + SetSemanticFunction(&Vand); + + BinaryOpTestHelperVV<uint64_t, uint64_t, uint64_t>( + "Vand64", /*sew*/ 64, instruction_, + [](uint64_t val0, uint64_t val1) -> uint64_t { return val0 & val1; }); +} + +// Vector-Scalar. 
+TEST_F(RiscVCheriotVectorInstructionsTest, Vand8VX) { + SetSemanticFunction(&Vand); + + BinaryOpTestHelperVX<uint8_t, uint8_t, uint8_t>( + "Vand8", /*sew*/ 8, instruction_, + [](uint8_t val0, uint8_t val1) -> uint8_t { return val0 & val1; }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vand16VX) { + SetSemanticFunction(&Vand); + + BinaryOpTestHelperVX<uint16_t, uint16_t, uint16_t>( + "Vand16", /*sew*/ 16, instruction_, + [](uint16_t val0, uint16_t val1) -> uint16_t { return val0 & val1; }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vand32VX) { + SetSemanticFunction(&Vand); + + BinaryOpTestHelperVX<uint32_t, uint32_t, uint32_t>( + "Vand32", /*sew*/ 32, instruction_, + [](uint32_t val0, uint32_t val1) -> uint32_t { return val0 & val1; }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vand64VX) { + SetSemanticFunction(&Vand); + + BinaryOpTestHelperVX<uint64_t, uint64_t, uint64_t>( + "Vand64", /*sew*/ 64, instruction_, + [](uint64_t val0, uint64_t val1) -> uint64_t { return val0 & val1; }); +} + +// Vector or. +// Vector-Vector. +TEST_F(RiscVCheriotVectorInstructionsTest, Vor8VV) { + SetSemanticFunction(&Vor); + + BinaryOpTestHelperVV<uint8_t, uint8_t, uint8_t>( + "Vor8", /*sew*/ 8, instruction_, + [](uint8_t val0, uint8_t val1) -> uint8_t { return val0 | val1; }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vor16VV) { + SetSemanticFunction(&Vor); + + BinaryOpTestHelperVV<uint16_t, uint16_t, uint16_t>( + "Vor16", /*sew*/ 16, instruction_, + [](uint16_t val0, uint16_t val1) -> uint16_t { return val0 | val1; }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vor32VV) { + SetSemanticFunction(&Vor); + + BinaryOpTestHelperVV<uint32_t, uint32_t, uint32_t>( + "Vor32", /*sew*/ 32, instruction_, + [](uint32_t val0, uint32_t val1) -> uint32_t { return val0 | val1; }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vor64VV) { + SetSemanticFunction(&Vor); + + BinaryOpTestHelperVV<uint64_t, uint64_t, uint64_t>( + "Vor64", /*sew*/ 64, instruction_, + [](uint64_t val0, uint64_t val1) -> uint64_t { return val0 | val1; }); +} + +// Vector-Scalar. +TEST_F(RiscVCheriotVectorInstructionsTest, Vor8VX) { + SetSemanticFunction(&Vor); + + BinaryOpTestHelperVX<uint8_t, uint8_t, uint8_t>( + "Vor8", /*sew*/ 8, instruction_, + [](uint8_t val0, uint8_t val1) -> uint8_t { return val0 | val1; }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vor16VX) { + SetSemanticFunction(&Vor); + + BinaryOpTestHelperVX<uint16_t, uint16_t, uint16_t>( + "Vor16", /*sew*/ 16, instruction_, + [](uint16_t val0, uint16_t val1) -> uint16_t { return val0 | val1; }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vor32VX) { + SetSemanticFunction(&Vor); + + BinaryOpTestHelperVX<uint32_t, uint32_t, uint32_t>( + "Vor32", /*sew*/ 32, instruction_, + [](uint32_t val0, uint32_t val1) -> uint32_t { return val0 | val1; }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vor64VX) { + SetSemanticFunction(&Vor); + + BinaryOpTestHelperVX<uint64_t, uint64_t, uint64_t>( + "Vor64", /*sew*/ 64, instruction_, + [](uint64_t val0, uint64_t val1) -> uint64_t { return val0 | val1; }); +} + +// Vector xor. +// Vector-Vector. 
+TEST_F(RiscVCheriotVectorInstructionsTest, Vxor8VV) { + SetSemanticFunction(&Vxor); + + BinaryOpTestHelperVV<uint8_t, uint8_t, uint8_t>( + "Vxor8", /*sew*/ 8, instruction_, + [](uint8_t val0, uint8_t val1) -> uint8_t { return val0 ^ val1; }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vxor16VV) { + SetSemanticFunction(&Vxor); + + BinaryOpTestHelperVV<uint16_t, uint16_t, uint16_t>( + "Vxor16", /*sew*/ 16, instruction_, + [](uint16_t val0, uint16_t val1) -> uint16_t { return val0 ^ val1; }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vxor32VV) { + SetSemanticFunction(&Vxor); + + BinaryOpTestHelperVV<uint32_t, uint32_t, uint32_t>( + "Vxor32", /*sew*/ 32, instruction_, + [](uint32_t val0, uint32_t val1) -> uint32_t { return val0 ^ val1; }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vxor64VV) { + SetSemanticFunction(&Vxor); + + BinaryOpTestHelperVV<uint64_t, uint64_t, uint64_t>( + "Vxor64", /*sew*/ 64, instruction_, + [](uint64_t val0, uint64_t val1) -> uint64_t { return val0 ^ val1; }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vxor8VX) { + SetSemanticFunction(&Vxor); + + BinaryOpTestHelperVX<uint8_t, uint8_t, uint8_t>( + "Vxor8", /*sew*/ 8, instruction_, + [](uint8_t val0, uint8_t val1) -> uint8_t { return val0 ^ val1; }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vxor16VX) { + SetSemanticFunction(&Vxor); + + BinaryOpTestHelperVX<uint16_t, uint16_t, uint16_t>( + "Vxor16", /*sew*/ 16, instruction_, + [](uint16_t val0, uint16_t val1) -> uint16_t { return val0 ^ val1; }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vxor32VX) { + SetSemanticFunction(&Vxor); + + BinaryOpTestHelperVX<uint32_t, uint32_t, uint32_t>( + "Vxor32", /*sew*/ 32, instruction_, + [](uint32_t val0, uint32_t val1) -> uint32_t { return val0 ^ val1; }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vxor64VX) { + SetSemanticFunction(&Vxor); + + BinaryOpTestHelperVX<uint64_t, uint64_t, uint64_t>( + "Vxor64", /*sew*/ 64, instruction_, + [](uint64_t val0, uint64_t val1) -> uint64_t { return val0 ^ val1; }); +} + +// Vector sll. +// Vector-Vector. +TEST_F(RiscVCheriotVectorInstructionsTest, Vsll8VV) { + SetSemanticFunction(&Vsll); + + BinaryOpTestHelperVV<uint8_t, uint8_t, uint8_t>( + "Vsll8", /*sew*/ 8, instruction_, + [](uint8_t val0, uint8_t val1) -> uint8_t { + return val0 << (val1 & 0b111); + }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vsll16VV) { + SetSemanticFunction(&Vsll); + + BinaryOpTestHelperVV<uint16_t, uint16_t, uint16_t>( + "Vsll16", /*sew*/ 16, instruction_, + [](uint16_t val0, uint16_t val1) -> uint16_t { + return val0 << (val1 & 0b1111); + }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vsll32VV) { + SetSemanticFunction(&Vsll); + + BinaryOpTestHelperVV<uint32_t, uint32_t, uint32_t>( + "Vsll32", /*sew*/ 32, instruction_, + [](uint32_t val0, uint32_t val1) -> uint32_t { + return val0 << (val1 & 0b1'1111); + }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vsll64VV) { + SetSemanticFunction(&Vsll); + + BinaryOpTestHelperVV<uint64_t, uint64_t, uint64_t>( + "Vsll64", /*sew*/ 64, instruction_, + [](uint64_t val0, uint64_t val1) -> uint64_t { + return val0 << (val1 & 0b11'1111); + }); +} + +// Vector-Scalar. 
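+// Note: per the RVV spec, only the low log2(SEW) bits of the shift amount
+// take part in a shift, so the expected-value lambdas mask val1 with 0b111
+// (sew 8) through 0b11'1111 (sew 64). For example, with sew 8 a shift amount
+// of 9 (0b1001) shifts by only 1.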
+TEST_F(RiscVCheriotVectorInstructionsTest, Vsll8VX) { + SetSemanticFunction(&Vsll); + + BinaryOpTestHelperVX<uint8_t, uint8_t, uint8_t>( + "Vsll8", /*sew*/ 8, instruction_, + [](uint8_t val0, uint8_t val1) -> uint8_t { + return val0 << (val1 & 0b111); + }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vsll16VX) { + SetSemanticFunction(&Vsll); + + BinaryOpTestHelperVX<uint16_t, uint16_t, uint16_t>( + "Vsll16", /*sew*/ 16, instruction_, + [](uint16_t val0, uint16_t val1) -> uint16_t { + return val0 << (val1 & 0b1111); + }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vsll32VX) { + SetSemanticFunction(&Vsll); + + BinaryOpTestHelperVX<uint32_t, uint32_t, uint32_t>( + "Vsll32", /*sew*/ 32, instruction_, + [](uint32_t val0, uint32_t val1) -> uint32_t { + return val0 << (val1 & 0b1'1111); + }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vsll64VX) { + SetSemanticFunction(&Vsll); + + BinaryOpTestHelperVX<uint64_t, uint64_t, uint64_t>( + "Vsll64", /*sew*/ 64, instruction_, + [](uint64_t val0, uint64_t val1) -> uint64_t { + return val0 << (val1 & 0b11'1111); + }); +} + +// Vector srl. +// Vector-Vector. +TEST_F(RiscVCheriotVectorInstructionsTest, Vsrl8VV) { + SetSemanticFunction(&Vsrl); + + BinaryOpTestHelperVV<uint8_t, uint8_t, uint8_t>( + "Vsrl8", /*sew*/ 8, instruction_, + [](uint8_t val0, uint8_t val1) -> uint8_t { + return val0 >> (val1 & 0b111); + }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vsrl16VV) { + SetSemanticFunction(&Vsrl); + + BinaryOpTestHelperVV<uint16_t, uint16_t, uint16_t>( + "Vsrl16", /*sew*/ 16, instruction_, + [](uint16_t val0, uint16_t val1) -> uint16_t { + return val0 >> (val1 & 0b1111); + }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vsrl32VV) { + SetSemanticFunction(&Vsrl); + + BinaryOpTestHelperVV<uint32_t, uint32_t, uint32_t>( + "Vsrl32", /*sew*/ 32, instruction_, + [](uint32_t val0, uint32_t val1) -> uint32_t { + return val0 >> (val1 & 0b1'1111); + }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vsrl64VV) { + SetSemanticFunction(&Vsrl); + + BinaryOpTestHelperVV<uint64_t, uint64_t, uint64_t>( + "Vsrl64", /*sew*/ 64, instruction_, + [](uint64_t val0, uint64_t val1) -> uint64_t { + return val0 >> (val1 & 0b11'1111); + }); +} + +// Vector-Scalar. +TEST_F(RiscVCheriotVectorInstructionsTest, Vsrl8VX) { + SetSemanticFunction(&Vsrl); + + BinaryOpTestHelperVX<uint8_t, uint8_t, uint8_t>( + "Vsrl8", /*sew*/ 8, instruction_, + [](uint8_t val0, uint8_t val1) -> uint8_t { + return val0 >> (val1 & 0b111); + }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vsrl16VX) { + SetSemanticFunction(&Vsrl); + + BinaryOpTestHelperVX<uint16_t, uint16_t, uint16_t>( + "Vsrl16", /*sew*/ 16, instruction_, + [](uint16_t val0, uint16_t val1) -> uint16_t { + return val0 >> (val1 & 0b1111); + }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vsrl32VX) { + SetSemanticFunction(&Vsrl); + + BinaryOpTestHelperVX<uint32_t, uint32_t, uint32_t>( + "Vsrl32", /*sew*/ 32, instruction_, + [](uint32_t val0, uint32_t val1) -> uint32_t { + return val0 >> (val1 & 0b1'1111); + }); +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vsrl64VX) { + SetSemanticFunction(&Vsrl); + + BinaryOpTestHelperVX<uint64_t, uint64_t, uint64_t>( + "Vsrl64", /*sew*/ 64, instruction_, + [](uint64_t val0, uint64_t val1) -> uint64_t { + return val0 >> (val1 & 0b11'1111); + }); +} + +// Vector sra. +// Vector-Vector. 
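+// The Vsra expectation lambdas below take the source element as a signed type
+// so that >> performs an arithmetic (sign-extending) shift, e.g.
+// int8_t(-64) >> 1 == -32.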
+TEST_F(RiscVCheriotVectorInstructionsTest, Vsra8VV) {
+  SetSemanticFunction(&Vsra);
+
+  BinaryOpTestHelperVV<uint8_t, int8_t, uint8_t>(
+      "Vsra8", /*sew*/ 8, instruction_,
+      [](int8_t val0, uint8_t val1) -> int8_t {
+        return val0 >> (val1 & 0b111);
+      });
+}
+
+TEST_F(RiscVCheriotVectorInstructionsTest, Vsra16VV) {
+  SetSemanticFunction(&Vsra);
+
+  BinaryOpTestHelperVV<uint16_t, int16_t, uint16_t>(
+      "Vsra16", /*sew*/ 16, instruction_,
+      [](int16_t val0, uint16_t val1) -> int16_t {
+        return val0 >> (val1 & 0b1111);
+      });
+}
+
+TEST_F(RiscVCheriotVectorInstructionsTest, Vsra32VV) {
+  SetSemanticFunction(&Vsra);
+
+  BinaryOpTestHelperVV<uint32_t, int32_t, uint32_t>(
+      "Vsra32", /*sew*/ 32, instruction_,
+      [](int32_t val0, uint32_t val1) -> int32_t {
+        return val0 >> (val1 & 0b1'1111);
+      });
+}
+
+TEST_F(RiscVCheriotVectorInstructionsTest, Vsra64VV) {
+  SetSemanticFunction(&Vsra);
+
+  BinaryOpTestHelperVV<uint64_t, int64_t, uint64_t>(
+      "Vsra64", /*sew*/ 64, instruction_,
+      [](int64_t val0, uint64_t val1) -> int64_t {
+        return val0 >> (val1 & 0b11'1111);
+      });
+}
+
+// Vector-Scalar.
+TEST_F(RiscVCheriotVectorInstructionsTest, Vsra8VX) {
+  SetSemanticFunction(&Vsra);
+
+  BinaryOpTestHelperVX<uint8_t, int8_t, uint8_t>(
+      "Vsra8", /*sew*/ 8, instruction_,
+      [](int8_t val0, uint8_t val1) -> int8_t {
+        return val0 >> (val1 & 0b111);
+      });
+}
+
+TEST_F(RiscVCheriotVectorInstructionsTest, Vsra16VX) {
+  SetSemanticFunction(&Vsra);
+
+  BinaryOpTestHelperVX<uint16_t, int16_t, uint16_t>(
+      "Vsra16", /*sew*/ 16, instruction_,
+      [](int16_t val0, uint16_t val1) -> int16_t {
+        return val0 >> (val1 & 0b1111);
+      });
+}
+
+TEST_F(RiscVCheriotVectorInstructionsTest, Vsra32VX) {
+  SetSemanticFunction(&Vsra);
+
+  BinaryOpTestHelperVX<uint32_t, int32_t, uint32_t>(
+      "Vsra32", /*sew*/ 32, instruction_,
+      [](int32_t val0, uint32_t val1) -> int32_t {
+        return val0 >> (val1 & 0b1'1111);
+      });
+}
+
+TEST_F(RiscVCheriotVectorInstructionsTest, Vsra64VX) {
+  SetSemanticFunction(&Vsra);
+
+  BinaryOpTestHelperVX<uint64_t, int64_t, uint64_t>(
+      "Vsra64", /*sew*/ 64, instruction_,
+      [](int64_t val0, uint64_t val1) -> int64_t {
+        return val0 >> (val1 & 0b11'1111);
+      });
+}
+
+// Vector narrowing srl.
+// Vector-Vector.
+TEST_F(RiscVCheriotVectorInstructionsTest, Vnsrl8VV) {
+  SetSemanticFunction(&Vnsrl);
+
+  BinaryOpTestHelperVV<uint8_t, uint16_t, uint8_t>(
+      "Vnsrl8", /*sew*/ 8, instruction_,
+      [](uint16_t val0, uint8_t val1) -> uint8_t {
+        return static_cast<uint8_t>(val0 >> (val1 & 0b1111));
+      });
+}
+
+TEST_F(RiscVCheriotVectorInstructionsTest, Vnsrl16VV) {
+  SetSemanticFunction(&Vnsrl);
+
+  BinaryOpTestHelperVV<uint16_t, uint32_t, uint16_t>(
+      "Vnsrl16", /*sew*/ 16, instruction_,
+      [](uint32_t val0, uint16_t val1) -> uint16_t {
+        return static_cast<uint16_t>(val0 >> (val1 & 0b1'1111));
+      });
+}
+
+TEST_F(RiscVCheriotVectorInstructionsTest, Vnsrl32VV) {
+  SetSemanticFunction(&Vnsrl);
+
+  BinaryOpTestHelperVV<uint32_t, uint64_t, uint32_t>(
+      "Vnsrl32", /*sew*/ 32, instruction_,
+      [](uint64_t val0, uint32_t val1) -> uint32_t {
+        return static_cast<uint32_t>(val0 >> (val1 & 0b11'1111));
+      });
+}
+
+// Vector-Scalar.
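+// The narrowing shifts above and below read a 2*SEW-wide source element, use
+// log2(2*SEW) bits of the shift amount (0b1111 for sew 8, etc.), and
+// truncate the shifted result to SEW bits.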
+TEST_F(RiscVCheriotVectorInstructionsTest, Vnsrl8VX) {
+  SetSemanticFunction(&Vnsrl);
+
+  BinaryOpTestHelperVX<uint8_t, uint16_t, uint8_t>(
+      "Vnsrl8", /*sew*/ 8, instruction_,
+      [](uint16_t val0, uint8_t val1) -> uint8_t {
+        return static_cast<uint8_t>(val0 >> (val1 & 0b1111));
+      });
+}
+
+TEST_F(RiscVCheriotVectorInstructionsTest, Vnsrl16VX) {
+  SetSemanticFunction(&Vnsrl);
+
+  BinaryOpTestHelperVX<uint16_t, uint32_t, uint16_t>(
+      "Vnsrl16", /*sew*/ 16, instruction_,
+      [](uint32_t val0, uint16_t val1) -> uint16_t {
+        return static_cast<uint16_t>(val0 >> (val1 & 0b1'1111));
+      });
+}
+
+TEST_F(RiscVCheriotVectorInstructionsTest, Vnsrl32VX) {
+  SetSemanticFunction(&Vnsrl);
+
+  BinaryOpTestHelperVX<uint32_t, uint64_t, uint32_t>(
+      "Vnsrl32", /*sew*/ 32, instruction_,
+      [](uint64_t val0, uint32_t val1) -> uint32_t {
+        return static_cast<uint32_t>(val0 >> (val1 & 0b11'1111));
+      });
+}
+
+// Vector narrowing sra.
+// Vector-Vector.
+TEST_F(RiscVCheriotVectorInstructionsTest, Vnsra8VV) {
+  SetSemanticFunction(&Vnsra);
+
+  BinaryOpTestHelperVV<uint8_t, uint16_t, uint8_t>(
+      "Vnsra8", /*sew*/ 8, instruction_,
+      [](int16_t val0, uint8_t val1) -> uint8_t {
+        return static_cast<uint8_t>(val0 >> (val1 & 0b1111));
+      });
+}
+
+TEST_F(RiscVCheriotVectorInstructionsTest, Vnsra16VV) {
+  SetSemanticFunction(&Vnsra);
+
+  BinaryOpTestHelperVV<uint16_t, uint32_t, uint16_t>(
+      "Vnsra16", /*sew*/ 16, instruction_,
+      [](int32_t val0, uint16_t val1) -> uint16_t {
+        return static_cast<uint16_t>(val0 >> (val1 & 0b1'1111));
+      });
+}
+
+TEST_F(RiscVCheriotVectorInstructionsTest, Vnsra32VV) {
+  SetSemanticFunction(&Vnsra);
+
+  BinaryOpTestHelperVV<uint32_t, uint64_t, uint32_t>(
+      "Vnsra32", /*sew*/ 32, instruction_,
+      [](int64_t val0, uint32_t val1) -> uint32_t {
+        return static_cast<uint32_t>(val0 >> (val1 & 0b11'1111));
+      });
+}
+
+// Vector-Scalar.
+TEST_F(RiscVCheriotVectorInstructionsTest, Vnsra8VX) {
+  SetSemanticFunction(&Vnsra);
+
+  BinaryOpTestHelperVX<uint8_t, uint16_t, uint8_t>(
+      "Vnsra8", /*sew*/ 8, instruction_,
+      [](int16_t val0, uint8_t val1) -> uint8_t {
+        return static_cast<uint8_t>(val0 >> (val1 & 0b1111));
+      });
+}
+
+TEST_F(RiscVCheriotVectorInstructionsTest, Vnsra16VX) {
+  SetSemanticFunction(&Vnsra);
+
+  BinaryOpTestHelperVX<uint16_t, uint32_t, uint16_t>(
+      "Vnsra16", /*sew*/ 16, instruction_,
+      [](int32_t val0, uint16_t val1) -> uint16_t {
+        return static_cast<uint16_t>(val0 >> (val1 & 0b1'1111));
+      });
+}
+
+TEST_F(RiscVCheriotVectorInstructionsTest, Vnsra32VX) {
+  SetSemanticFunction(&Vnsra);
+
+  BinaryOpTestHelperVX<uint32_t, uint64_t, uint32_t>(
+      "Vnsra32", /*sew*/ 32, instruction_,
+      [](int64_t val0, uint32_t val1) -> uint32_t {
+        return static_cast<uint32_t>(val0 >> (val1 & 0b11'1111));
+      });
+}
+
+// Vector unsigned min.
+// Vector-Vector.
+TEST_F(RiscVCheriotVectorInstructionsTest, Vminu8VV) {
+  SetSemanticFunction(&Vminu);
+  BinaryOpTestHelperVV<uint8_t, uint8_t, uint8_t>(
+      "Vminu8", /*sew*/ 8, instruction_,
+      [](uint8_t val0, uint8_t val1) -> uint8_t {
+        return (val0 < val1) ? val0 : val1;
+      });
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vminu16VV) {
+  SetSemanticFunction(&Vminu);
+  BinaryOpTestHelperVV<uint16_t, uint16_t, uint16_t>(
+      "Vminu16", /*sew*/ 16, instruction_,
+      [](uint16_t val0, uint16_t val1) -> uint16_t {
+        return (val0 < val1) ?
val0 : val1; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vminu32VV) { + SetSemanticFunction(&Vminu); + BinaryOpTestHelperVV<uint32_t, uint32_t, uint32_t>( + "Vminu32", /*sew*/ 32, instruction_, + [](uint32_t val0, uint32_t val1) -> uint32_t { + return (val0 < val1) ? val0 : val1; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vminu64VV) { + SetSemanticFunction(&Vminu); + BinaryOpTestHelperVV<uint64_t, uint64_t, uint64_t>( + "Vminu64", /*sew*/ 64, instruction_, + [](uint64_t val0, uint64_t val1) -> uint64_t { + return (val0 < val1) ? val0 : val1; + }); +} +// Vector-Scalar +TEST_F(RiscVCheriotVectorInstructionsTest, Vminu8VX) { + SetSemanticFunction(&Vminu); + BinaryOpTestHelperVX<uint8_t, uint8_t, uint8_t>( + "Vminu8", /*sew*/ 8, instruction_, + [](uint8_t val0, uint8_t val1) -> uint8_t { + return (val0 < val1) ? val0 : val1; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vminu16VX) { + SetSemanticFunction(&Vminu); + BinaryOpTestHelperVX<uint16_t, uint16_t, uint16_t>( + "Vminu16", /*sew*/ 16, instruction_, + [](uint16_t val0, uint16_t val1) -> uint16_t { + return (val0 < val1) ? val0 : val1; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vminu32VX) { + SetSemanticFunction(&Vminu); + BinaryOpTestHelperVX<uint32_t, uint32_t, uint32_t>( + "Vminu32", /*sew*/ 32, instruction_, + [](uint32_t val0, uint32_t val1) -> uint32_t { + return (val0 < val1) ? val0 : val1; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vminu64VX) { + SetSemanticFunction(&Vminu); + BinaryOpTestHelperVX<uint64_t, uint64_t, uint64_t>( + "Vminu64", /*sew*/ 64, instruction_, + [](uint64_t val0, uint64_t val1) -> uint64_t { + return (val0 < val1) ? val0 : val1; + }); +} + +// Vector signed min. +// Vector-Vector. +TEST_F(RiscVCheriotVectorInstructionsTest, Vmin8VV) { + SetSemanticFunction(&Vmin); + BinaryOpTestHelperVV<int8_t, int8_t, int8_t>( + "Vmin8", /*sew*/ 8, instruction_, [](int8_t val0, int8_t val1) -> int8_t { + return (val0 < val1) ? val0 : val1; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmin16VV) { + SetSemanticFunction(&Vmin); + BinaryOpTestHelperVV<int16_t, int16_t, int16_t>( + "Vmin16", /*sew*/ 16, instruction_, + [](int16_t val0, int16_t val1) -> int16_t { + return (val0 < val1) ? val0 : val1; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmin32VV) { + SetSemanticFunction(&Vmin); + BinaryOpTestHelperVV<int32_t, int32_t, int32_t>( + "Vmin32", /*sew*/ 32, instruction_, + [](int32_t val0, int32_t val1) -> int32_t { + return (val0 < val1) ? val0 : val1; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmin64VV) { + SetSemanticFunction(&Vmin); + BinaryOpTestHelperVV<int64_t, int64_t, int64_t>( + "Vmin64", /*sew*/ 64, instruction_, + [](int64_t val0, int64_t val1) -> int64_t { + return (val0 < val1) ? val0 : val1; + }); +} +// Vector-Scalar +TEST_F(RiscVCheriotVectorInstructionsTest, Vmin8VX) { + SetSemanticFunction(&Vmin); + BinaryOpTestHelperVX<int8_t, int8_t, int8_t>( + "Vmin8", /*sew*/ 8, instruction_, [](int8_t val0, int8_t val1) -> int8_t { + return (val0 < val1) ? val0 : val1; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmin16VX) { + SetSemanticFunction(&Vmin); + BinaryOpTestHelperVX<int16_t, int16_t, int16_t>( + "Vmin16", /*sew*/ 16, instruction_, + [](int16_t val0, int16_t val1) -> int16_t { + return (val0 < val1) ? 
val0 : val1;
+      });
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmin32VX) {
+  SetSemanticFunction(&Vmin);
+  BinaryOpTestHelperVX<int32_t, int32_t, int32_t>(
+      "Vmin32", /*sew*/ 32, instruction_,
+      [](int32_t val0, int32_t val1) -> int32_t {
+        return (val0 < val1) ? val0 : val1;
+      });
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmin64VX) {
+  SetSemanticFunction(&Vmin);
+  BinaryOpTestHelperVX<int64_t, int64_t, int64_t>(
+      "Vmin64", /*sew*/ 64, instruction_,
+      [](int64_t val0, int64_t val1) -> int64_t {
+        return (val0 < val1) ? val0 : val1;
+      });
+}
+
+// Vector unsigned max.
+// Vector-Vector.
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmaxu8VV) {
+  SetSemanticFunction(&Vmaxu);
+  BinaryOpTestHelperVV<uint8_t, uint8_t, uint8_t>(
+      "Vmaxu8", /*sew*/ 8, instruction_,
+      [](uint8_t val0, uint8_t val1) -> uint8_t {
+        return (val0 > val1) ? val0 : val1;
+      });
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmaxu16VV) {
+  SetSemanticFunction(&Vmaxu);
+  BinaryOpTestHelperVV<uint16_t, uint16_t, uint16_t>(
+      "Vmaxu16", /*sew*/ 16, instruction_,
+      [](uint16_t val0, uint16_t val1) -> uint16_t {
+        return (val0 > val1) ? val0 : val1;
+      });
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmaxu32VV) {
+  SetSemanticFunction(&Vmaxu);
+  BinaryOpTestHelperVV<uint32_t, uint32_t, uint32_t>(
+      "Vmaxu32", /*sew*/ 32, instruction_,
+      [](uint32_t val0, uint32_t val1) -> uint32_t {
+        return (val0 > val1) ? val0 : val1;
+      });
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmaxu64VV) {
+  SetSemanticFunction(&Vmaxu);
+  BinaryOpTestHelperVV<uint64_t, uint64_t, uint64_t>(
+      "Vmaxu64", /*sew*/ 64, instruction_,
+      [](uint64_t val0, uint64_t val1) -> uint64_t {
+        return (val0 > val1) ? val0 : val1;
+      });
+}
+// Vector-Scalar.
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmaxu8VX) {
+  SetSemanticFunction(&Vmaxu);
+  BinaryOpTestHelperVX<uint8_t, uint8_t, uint8_t>(
+      "Vmaxu8", /*sew*/ 8, instruction_,
+      [](uint8_t val0, uint8_t val1) -> uint8_t {
+        return (val0 > val1) ? val0 : val1;
+      });
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmaxu16VX) {
+  SetSemanticFunction(&Vmaxu);
+  BinaryOpTestHelperVX<uint16_t, uint16_t, uint16_t>(
+      "Vmaxu16", /*sew*/ 16, instruction_,
+      [](uint16_t val0, uint16_t val1) -> uint16_t {
+        return (val0 > val1) ? val0 : val1;
+      });
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmaxu32VX) {
+  SetSemanticFunction(&Vmaxu);
+  BinaryOpTestHelperVX<uint32_t, uint32_t, uint32_t>(
+      "Vmaxu32", /*sew*/ 32, instruction_,
+      [](uint32_t val0, uint32_t val1) -> uint32_t {
+        return (val0 > val1) ? val0 : val1;
+      });
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmaxu64VX) {
+  SetSemanticFunction(&Vmaxu);
+  BinaryOpTestHelperVX<uint64_t, uint64_t, uint64_t>(
+      "Vmaxu64", /*sew*/ 64, instruction_,
+      [](uint64_t val0, uint64_t val1) -> uint64_t {
+        return (val0 > val1) ? val0 : val1;
+      });
+}
+// Vector signed max.
+// Vector-Vector.
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmax8VV) {
+  SetSemanticFunction(&Vmax);
+  BinaryOpTestHelperVV<int8_t, int8_t, int8_t>(
+      "Vmax8", /*sew*/ 8, instruction_, [](int8_t val0, int8_t val1) -> int8_t {
+        return (val0 > val1) ? val0 : val1;
+      });
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmax16VV) {
+  SetSemanticFunction(&Vmax);
+  BinaryOpTestHelperVV<int16_t, int16_t, int16_t>(
+      "Vmax16", /*sew*/ 16, instruction_,
+      [](int16_t val0, int16_t val1) -> int16_t {
+        return (val0 > val1) ? val0 : val1;
+      });
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmax32VV) {
+  SetSemanticFunction(&Vmax);
+  BinaryOpTestHelperVV<int32_t, int32_t, int32_t>(
+      "Vmax32", /*sew*/ 32, instruction_,
+      [](int32_t val0, int32_t val1) -> int32_t {
+        return (val0 > val1) ? val0 : val1;
+      });
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmax64VV) {
+  SetSemanticFunction(&Vmax);
+  BinaryOpTestHelperVV<int64_t, int64_t, int64_t>(
+      "Vmax64", /*sew*/ 64, instruction_,
+      [](int64_t val0, int64_t val1) -> int64_t {
+        return (val0 > val1) ? val0 : val1;
+      });
+}
+// Vector-Scalar.
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmax8VX) {
+  SetSemanticFunction(&Vmax);
+  BinaryOpTestHelperVX<int8_t, int8_t, int8_t>(
+      "Vmax8", /*sew*/ 8, instruction_, [](int8_t val0, int8_t val1) -> int8_t {
+        return (val0 > val1) ? val0 : val1;
+      });
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmax16VX) {
+  SetSemanticFunction(&Vmax);
+  BinaryOpTestHelperVX<int16_t, int16_t, int16_t>(
+      "Vmax16", /*sew*/ 16, instruction_,
+      [](int16_t val0, int16_t val1) -> int16_t {
+        return (val0 > val1) ? val0 : val1;
+      });
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmax32VX) {
+  SetSemanticFunction(&Vmax);
+  BinaryOpTestHelperVX<int32_t, int32_t, int32_t>(
+      "Vmax32", /*sew*/ 32, instruction_,
+      [](int32_t val0, int32_t val1) -> int32_t {
+        return (val0 > val1) ? val0 : val1;
+      });
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmax64VX) {
+  SetSemanticFunction(&Vmax);
+  BinaryOpTestHelperVX<int64_t, int64_t, int64_t>(
+      "Vmax64", /*sew*/ 64, instruction_,
+      [](int64_t val0, int64_t val1) -> int64_t {
+        return (val0 > val1) ? val0 : val1;
+      });
+}
+
+// Integer compare instructions.
+
+// Vector mask set equal.
+// Vector-Vector.
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmseq8VV) {
+  SetSemanticFunction(&Vmseq);
+  BinaryMaskOpTestHelperVV<uint8_t, uint8_t>(
+      "Vmseq8", /*sew*/ 8, instruction_,
+      [](uint8_t val0, uint8_t val1) -> uint8_t {
+        return (val0 == val1) ? 1 : 0;
+      });
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmseq16VV) {
+  SetSemanticFunction(&Vmseq);
+  BinaryMaskOpTestHelperVV<uint16_t, uint16_t>(
+      "Vmseq16", /*sew*/ 16, instruction_,
+      [](uint16_t val0, uint16_t val1) -> uint16_t {
+        return (val0 == val1) ? 1 : 0;
+      });
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmseq32VV) {
+  SetSemanticFunction(&Vmseq);
+  BinaryMaskOpTestHelperVV<uint32_t, uint32_t>(
+      "Vmseq32", /*sew*/ 32, instruction_,
+      [](uint32_t val0, uint32_t val1) -> uint32_t {
+        return (val0 == val1) ? 1 : 0;
+      });
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmseq64VV) {
+  SetSemanticFunction(&Vmseq);
+  BinaryMaskOpTestHelperVV<uint64_t, uint64_t>(
+      "Vmseq64", /*sew*/ 64, instruction_,
+      [](uint64_t val0, uint64_t val1) -> uint64_t {
+        return (val0 == val1) ? 1 : 0;
+      });
+}
+// Vector-Scalar.
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmseq8VX) {
+  SetSemanticFunction(&Vmseq);
+  BinaryMaskOpTestHelperVX<uint8_t, uint8_t>(
+      "Vmseq8", /*sew*/ 8, instruction_,
+      [](uint8_t val0, uint8_t val1) -> uint8_t {
+        return (val0 == val1) ? 1 : 0;
+      });
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmseq16VX) {
+  SetSemanticFunction(&Vmseq);
+  BinaryMaskOpTestHelperVX<uint16_t, uint16_t>(
+      "Vmseq16", /*sew*/ 16, instruction_,
+      [](uint16_t val0, uint16_t val1) -> uint8_t {
+        return (val0 == val1) ?
1 : 0; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmseq32VX) { + SetSemanticFunction(&Vmseq); + BinaryMaskOpTestHelperVX<uint32_t, uint32_t>( + "Vmseq32", /*sew*/ 32, instruction_, + [](uint32_t val0, uint32_t val1) -> uint8_t { + return (val0 == val1) ? 1 : 0; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmseq64VX) { + SetSemanticFunction(&Vmseq); + BinaryMaskOpTestHelperVX<uint64_t, uint64_t>( + "Vmseq64", /*sew*/ 64, instruction_, + [](uint64_t val0, uint64_t val1) -> uint8_t { + return (val0 == val1) ? 1 : 0; + }); +} + +// Vector mask set not equal. +// Vector-Vector. +TEST_F(RiscVCheriotVectorInstructionsTest, Vmsne8VV) { + SetSemanticFunction(&Vmsne); + BinaryMaskOpTestHelperVV<uint8_t, uint8_t>( + "Vmsne8", /*sew*/ 8, instruction_, + [](uint8_t val0, uint8_t val1) -> uint8_t { + return (val0 != val1) ? 1 : 0; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmsne16VV) { + SetSemanticFunction(&Vmsne); + BinaryMaskOpTestHelperVV<uint16_t, uint16_t>( + "Vmsne16", /*sew*/ 16, instruction_, + [](uint16_t val0, uint16_t val1) -> uint16_t { + return (val0 != val1) ? 1 : 0; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmsne32VV) { + SetSemanticFunction(&Vmsne); + BinaryMaskOpTestHelperVV<uint32_t, uint32_t>( + "Vmsne32", /*sew*/ 32, instruction_, + [](uint32_t val0, uint32_t val1) -> uint32_t { + return (val0 != val1) ? 1 : 0; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmsne64VV) { + SetSemanticFunction(&Vmsne); + BinaryMaskOpTestHelperVV<uint64_t, uint64_t>( + "Vmsne64", /*sew*/ 64, instruction_, + [](uint64_t val0, uint64_t val1) -> uint64_t { + return (val0 != val1) ? 1 : 0; + }); +} +// Vector-Scalar. +TEST_F(RiscVCheriotVectorInstructionsTest, Vmsne8VX) { + SetSemanticFunction(&Vmsne); + BinaryMaskOpTestHelperVX<uint8_t, uint8_t>( + "Vmsne8", /*sew*/ 8, instruction_, + [](uint8_t val0, uint8_t val1) -> uint8_t { + return (val0 != val1) ? 1 : 0; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmsne16VX) { + SetSemanticFunction(&Vmsne); + BinaryMaskOpTestHelperVX<uint16_t, uint16_t>( + "Vmsne16", /*sew*/ 16, instruction_, + [](uint16_t val0, uint16_t val1) -> uint8_t { + return (val0 != val1) ? 1 : 0; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmsne32VX) { + SetSemanticFunction(&Vmsne); + BinaryMaskOpTestHelperVX<uint32_t, uint32_t>( + "Vmsne32", /*sew*/ 32, instruction_, + [](uint32_t val0, uint32_t val1) -> uint8_t { + return (val0 != val1) ? 1 : 0; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmsne64VX) { + SetSemanticFunction(&Vmsne); + BinaryMaskOpTestHelperVX<uint64_t, uint64_t>( + "Vmsne64", /*sew*/ 64, instruction_, + [](uint64_t val0, uint64_t val1) -> uint8_t { + return (val0 != val1) ? 1 : 0; + }); +} + +// Vector mask unsigned set less than. +// Vector-Vector. +TEST_F(RiscVCheriotVectorInstructionsTest, Vmsltu8VV) { + SetSemanticFunction(&Vmsltu); + BinaryMaskOpTestHelperVV<uint8_t, uint8_t>( + "Vmsltu8", /*sew*/ 8, instruction_, + [](uint8_t val0, uint8_t val1) -> uint8_t { + return (val0 < val1) ? 1 : 0; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmsltu16VV) { + SetSemanticFunction(&Vmsltu); + BinaryMaskOpTestHelperVV<uint16_t, uint16_t>( + "Vmsltu16", /*sew*/ 16, instruction_, + [](uint16_t val0, uint16_t val1) -> uint16_t { + return (val0 < val1) ? 
1 : 0; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmsltu32VV) { + SetSemanticFunction(&Vmsltu); + BinaryMaskOpTestHelperVV<uint32_t, uint32_t>( + "Vmsltu32", /*sew*/ 32, instruction_, + [](uint32_t val0, uint32_t val1) -> uint32_t { + return (val0 < val1) ? 1 : 0; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmsltu64VV) { + SetSemanticFunction(&Vmsltu); + BinaryMaskOpTestHelperVV<uint64_t, uint64_t>( + "Vmsltu64", /*sew*/ 64, instruction_, + [](uint64_t val0, uint64_t val1) -> uint64_t { + return (val0 < val1) ? 1 : 0; + }); +} +// Vector-Scalar. +TEST_F(RiscVCheriotVectorInstructionsTest, Vmsltu8VX) { + SetSemanticFunction(&Vmsltu); + BinaryMaskOpTestHelperVX<uint8_t, uint8_t>( + "Vmsltu8", /*sew*/ 8, instruction_, + [](uint8_t val0, uint8_t val1) -> uint8_t { + return (val0 < val1) ? 1 : 0; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmsltu16VX) { + SetSemanticFunction(&Vmsltu); + BinaryMaskOpTestHelperVX<uint16_t, uint16_t>( + "Vmsltu16", /*sew*/ 16, instruction_, + [](uint16_t val0, uint16_t val1) -> uint8_t { + return (val0 < val1) ? 1 : 0; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmsltu32VX) { + SetSemanticFunction(&Vmsltu); + BinaryMaskOpTestHelperVX<uint32_t, uint32_t>( + "Vmsltu32", /*sew*/ 32, instruction_, + [](uint32_t val0, uint32_t val1) -> uint8_t { + return (val0 < val1) ? 1 : 0; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmsltu64VX) { + SetSemanticFunction(&Vmsltu); + BinaryMaskOpTestHelperVX<uint64_t, uint64_t>( + "Vmsltu64", /*sew*/ 64, instruction_, + [](uint64_t val0, uint64_t val1) -> uint8_t { + return (val0 < val1) ? 1 : 0; + }); +} + +// Vector mask signed set less than. +// Vector-Vector. +TEST_F(RiscVCheriotVectorInstructionsTest, Vmslt8VV) { + SetSemanticFunction(&Vmslt); + BinaryMaskOpTestHelperVV<int8_t, int8_t>( + "Vmslt8", /*sew*/ 8, instruction_, + [](int8_t val0, int8_t val1) -> uint8_t { + return (val0 < val1) ? 1 : 0; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmslt16VV) { + SetSemanticFunction(&Vmslt); + BinaryMaskOpTestHelperVV<int16_t, int16_t>( + "Vmslt16", /*sew*/ 16, instruction_, + [](int16_t val0, int16_t val1) -> uint16_t { + return (val0 < val1) ? 1 : 0; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmslt32VV) { + SetSemanticFunction(&Vmslt); + BinaryMaskOpTestHelperVV<int32_t, int32_t>( + "Vmslt32", /*sew*/ 32, instruction_, + [](int32_t val0, int32_t val1) -> uint32_t { + return (val0 < val1) ? 1 : 0; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmslt64VV) { + SetSemanticFunction(&Vmslt); + BinaryMaskOpTestHelperVV<int64_t, int64_t>( + "Vmslt64", /*sew*/ 64, instruction_, + [](int64_t val0, int64_t val1) -> uint64_t { + return (val0 < val1) ? 1 : 0; + }); +} +// Vector-Scalar. +TEST_F(RiscVCheriotVectorInstructionsTest, Vmslt8VX) { + SetSemanticFunction(&Vmslt); + BinaryMaskOpTestHelperVX<int8_t, int8_t>( + "Vmslt8", /*sew*/ 8, instruction_, + [](int8_t val0, int8_t val1) -> uint8_t { + return (val0 < val1) ? 1 : 0; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmslt16VX) { + SetSemanticFunction(&Vmslt); + BinaryMaskOpTestHelperVX<int16_t, int16_t>( + "Vmslt16", /*sew*/ 16, instruction_, + [](int16_t val0, int16_t val1) -> uint8_t { + return (val0 < val1) ? 1 : 0; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmslt32VX) { + SetSemanticFunction(&Vmslt); + BinaryMaskOpTestHelperVX<int32_t, int32_t>( + "Vmslt32", /*sew*/ 32, instruction_, + [](int32_t val0, int32_t val1) -> uint8_t { + return (val0 < val1) ? 
1 : 0; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmslt64VX) { + SetSemanticFunction(&Vmslt); + BinaryMaskOpTestHelperVX<int64_t, int64_t>( + "Vmslt64", /*sew*/ 64, instruction_, + [](int64_t val0, int64_t val1) -> uint8_t { + return (val0 < val1) ? 1 : 0; + }); +} + +// Vector mask unsigned set less than or equal. +// Vector-Vector. +TEST_F(RiscVCheriotVectorInstructionsTest, Vmsleu8VV) { + SetSemanticFunction(&Vmsleu); + BinaryMaskOpTestHelperVV<uint8_t, uint8_t>( + "Vmsleu8", /*sew*/ 8, instruction_, + [](uint8_t val0, uint8_t val1) -> uint8_t { + return (val0 <= val1) ? 1 : 0; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmsleu16VV) { + SetSemanticFunction(&Vmsleu); + BinaryMaskOpTestHelperVV<uint16_t, uint16_t>( + "Vmsleu16", /*sew*/ 16, instruction_, + [](uint16_t val0, uint16_t val1) -> uint16_t { + return (val0 <= val1) ? 1 : 0; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmsleu32VV) { + SetSemanticFunction(&Vmsleu); + BinaryMaskOpTestHelperVV<uint32_t, uint32_t>( + "Vmsleu32", /*sew*/ 32, instruction_, + [](uint32_t val0, uint32_t val1) -> uint32_t { + return (val0 <= val1) ? 1 : 0; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmsleu64VV) { + SetSemanticFunction(&Vmsleu); + BinaryMaskOpTestHelperVV<uint64_t, uint64_t>( + "Vmsleu64", /*sew*/ 64, instruction_, + [](uint64_t val0, uint64_t val1) -> uint64_t { + return (val0 <= val1) ? 1 : 0; + }); +} +// Vector-Scalar. +TEST_F(RiscVCheriotVectorInstructionsTest, Vmsleu8VX) { + SetSemanticFunction(&Vmsleu); + BinaryMaskOpTestHelperVX<uint8_t, uint8_t>( + "Vmsleu8", /*sew*/ 8, instruction_, + [](uint8_t val0, uint8_t val1) -> uint8_t { + return (val0 <= val1) ? 1 : 0; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmsleu16VX) { + SetSemanticFunction(&Vmsleu); + BinaryMaskOpTestHelperVX<uint16_t, uint16_t>( + "Vmsleu16", /*sew*/ 16, instruction_, + [](uint16_t val0, uint16_t val1) -> uint8_t { + return (val0 <= val1) ? 1 : 0; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmsleu32VX) { + SetSemanticFunction(&Vmsleu); + BinaryMaskOpTestHelperVX<uint32_t, uint32_t>( + "Vmsleu32", /*sew*/ 32, instruction_, + [](uint32_t val0, uint32_t val1) -> uint8_t { + return (val0 <= val1) ? 1 : 0; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmsleu64VX) { + SetSemanticFunction(&Vmsleu); + BinaryMaskOpTestHelperVX<uint64_t, uint64_t>( + "Vmsleu64", /*sew*/ 64, instruction_, + [](uint64_t val0, uint64_t val1) -> uint8_t { + return (val0 <= val1) ? 1 : 0; + }); +} + +// Vector mask signed set less than or equal. +// Vector-Vector. +TEST_F(RiscVCheriotVectorInstructionsTest, Vmsle8VV) { + SetSemanticFunction(&Vmsle); + BinaryMaskOpTestHelperVV<int8_t, int8_t>( + "Vmsle8", /*sew*/ 8, instruction_, + [](int8_t val0, int8_t val1) -> uint8_t { + return (val0 <= val1) ? 1 : 0; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmsle16VV) { + SetSemanticFunction(&Vmsle); + BinaryMaskOpTestHelperVV<int16_t, int16_t>( + "Vmsle16", /*sew*/ 16, instruction_, + [](int16_t val0, int16_t val1) -> uint16_t { + return (val0 <= val1) ? 1 : 0; + }); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmsle32VV) { + SetSemanticFunction(&Vmsle); + BinaryMaskOpTestHelperVV<int32_t, int32_t>( + "Vmsle32", /*sew*/ 32, instruction_, + [](int32_t val0, int32_t val1) -> uint32_t { + return (val0 <= val1) ? 
1 : 0;
+      });
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmsle64VV) {
+  SetSemanticFunction(&Vmsle);
+  BinaryMaskOpTestHelperVV<int64_t, int64_t>(
+      "Vmsle64", /*sew*/ 64, instruction_,
+      [](int64_t val0, int64_t val1) -> uint64_t {
+        return (val0 <= val1) ? 1 : 0;
+      });
+}
+// Vector-Scalar.
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmsle8VX) {
+  SetSemanticFunction(&Vmsle);
+  BinaryMaskOpTestHelperVX<int8_t, int8_t>(
+      "Vmsle8", /*sew*/ 8, instruction_,
+      [](int8_t val0, int8_t val1) -> uint8_t {
+        return (val0 <= val1) ? 1 : 0;
+      });
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmsle16VX) {
+  SetSemanticFunction(&Vmsle);
+  BinaryMaskOpTestHelperVX<int16_t, int16_t>(
+      "Vmsle16", /*sew*/ 16, instruction_,
+      [](int16_t val0, int16_t val1) -> uint8_t {
+        return (val0 <= val1) ? 1 : 0;
+      });
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmsle32VX) {
+  SetSemanticFunction(&Vmsle);
+  BinaryMaskOpTestHelperVX<int32_t, int32_t>(
+      "Vmsle32", /*sew*/ 32, instruction_,
+      [](int32_t val0, int32_t val1) -> uint8_t {
+        return (val0 <= val1) ? 1 : 0;
+      });
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmsle64VX) {
+  SetSemanticFunction(&Vmsle);
+  BinaryMaskOpTestHelperVX<int64_t, int64_t>(
+      "Vmsle64", /*sew*/ 64, instruction_,
+      [](int64_t val0, int64_t val1) -> uint8_t {
+        return (val0 <= val1) ? 1 : 0;
+      });
+}
+
+// Vector mask unsigned set greater than.
+// Vector-Scalar only.
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmsgtu8VX) {
+  SetSemanticFunction(&Vmsgtu);
+  BinaryMaskOpTestHelperVX<uint8_t, uint8_t>(
+      "Vmsgtu8", /*sew*/ 8, instruction_,
+      [](uint8_t val0, uint8_t val1) -> uint8_t {
+        return (val0 > val1) ? 1 : 0;
+      });
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmsgtu16VX) {
+  SetSemanticFunction(&Vmsgtu);
+  BinaryMaskOpTestHelperVX<uint16_t, uint16_t>(
+      "Vmsgtu16", /*sew*/ 16, instruction_,
+      [](uint16_t val0, uint16_t val1) -> uint8_t {
+        return (val0 > val1) ? 1 : 0;
+      });
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmsgtu32VX) {
+  SetSemanticFunction(&Vmsgtu);
+  BinaryMaskOpTestHelperVX<uint32_t, uint32_t>(
+      "Vmsgtu32", /*sew*/ 32, instruction_,
+      [](uint32_t val0, uint32_t val1) -> uint8_t {
+        return (val0 > val1) ? 1 : 0;
+      });
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmsgtu64VX) {
+  SetSemanticFunction(&Vmsgtu);
+  BinaryMaskOpTestHelperVX<uint64_t, uint64_t>(
+      "Vmsgtu64", /*sew*/ 64, instruction_,
+      [](uint64_t val0, uint64_t val1) -> uint8_t {
+        return (val0 > val1) ? 1 : 0;
+      });
+}
+
+// Vector mask signed set greater than.
+// Vector-Scalar only.
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmsgt8VX) {
+  SetSemanticFunction(&Vmsgt);
+  BinaryMaskOpTestHelperVX<int8_t, int8_t>(
+      "Vmsgt8", /*sew*/ 8, instruction_,
+      [](int8_t val0, int8_t val1) -> uint8_t {
+        return (val0 > val1) ? 1 : 0;
+      });
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmsgt16VX) {
+  SetSemanticFunction(&Vmsgt);
+  BinaryMaskOpTestHelperVX<int16_t, int16_t>(
+      "Vmsgt16", /*sew*/ 16, instruction_,
+      [](int16_t val0, int16_t val1) -> uint8_t {
+        return (val0 > val1) ? 1 : 0;
+      });
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmsgt32VX) {
+  SetSemanticFunction(&Vmsgt);
+  BinaryMaskOpTestHelperVX<int32_t, int32_t>(
+      "Vmsgt32", /*sew*/ 32, instruction_,
+      [](int32_t val0, int32_t val1) -> uint8_t {
+        return (val0 > val1) ? 1 : 0;
+      });
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmsgt64VX) {
+  SetSemanticFunction(&Vmsgt);
+  BinaryMaskOpTestHelperVX<int64_t, int64_t>(
+      "Vmsgt64", /*sew*/ 64, instruction_,
+      [](int64_t val0, int64_t val1) -> uint8_t {
+        return (val0 > val1) ?
1 : 0; + }); +} + +// Vector unsigned saturated add. +template <typename T> +T VsadduHelper(T val0, T val1) { + T sum = val0 + val1; + if (sum < val1) { + sum = std::numeric_limits<T>::max(); + } + return sum; +} +// Vector-Vector. +TEST_F(RiscVCheriotVectorInstructionsTest, Vsaddu8VV) { + SetSemanticFunction(&Vsaddu); + BinaryOpTestHelperVV<uint8_t, uint8_t, uint8_t>( + "Vsaddu8", /*sew*/ 8, instruction_, VsadduHelper<uint8_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vsaddu16VV) { + SetSemanticFunction(&Vsaddu); + BinaryOpTestHelperVV<uint16_t, uint16_t, uint16_t>( + "Vsaddu16", /*sew*/ 16, instruction_, VsadduHelper<uint16_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vsaddu32VV) { + SetSemanticFunction(&Vsaddu); + BinaryOpTestHelperVV<uint32_t, uint32_t, uint32_t>( + "Vsaddu32", /*sew*/ 32, instruction_, VsadduHelper<uint32_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vsaddu64VV) { + SetSemanticFunction(&Vsaddu); + BinaryOpTestHelperVV<uint64_t, uint64_t, uint64_t>( + "Vsaddu64", /*sew*/ 64, instruction_, VsadduHelper<uint64_t>); +} + +// Vector-Scalar +TEST_F(RiscVCheriotVectorInstructionsTest, Vsaddu8VX) { + SetSemanticFunction(&Vsaddu); + BinaryOpTestHelperVX<uint8_t, uint8_t, uint8_t>( + "Vsaddu8", /*sew*/ 8, instruction_, VsadduHelper<uint8_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vsaddu16VX) { + SetSemanticFunction(&Vsaddu); + BinaryOpTestHelperVX<uint16_t, uint16_t, uint16_t>( + "Vsaddu16", /*sew*/ 16, instruction_, VsadduHelper<uint16_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vsaddu32VX) { + SetSemanticFunction(&Vsaddu); + BinaryOpTestHelperVX<uint32_t, uint32_t, uint32_t>( + "Vsaddu32", /*sew*/ 32, instruction_, VsadduHelper<uint32_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vsaddu64VX) { + SetSemanticFunction(&Vsaddu); + BinaryOpTestHelperVX<uint64_t, uint64_t, uint64_t>( + "Vsaddu64", /*sew*/ 64, instruction_, VsadduHelper<uint64_t>); +} + +// Vector signed saturated add. +template <typename T> +T VsaddHelper(T val0, T val1) { + using WT = typename WideType<T>::type; + WT wval0 = static_cast<WT>(val0); + WT wval1 = static_cast<WT>(val1); + WT wsum = wval0 + wval1; + if (wsum > std::numeric_limits<T>::max()) { + return std::numeric_limits<T>::max(); + } + if (wsum < std::numeric_limits<T>::min()) { + return std::numeric_limits<T>::min(); + } + T sum = static_cast<T>(wsum); + return sum; +} + +// Vector-Vector. 
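+// Note the two saturation strategies above: VsadduHelper detects unsigned
+// overflow via wraparound (an overflowed sum is smaller than either operand,
+// e.g. uint8_t 200 + 100 wraps to 44 and saturates to 255), while
+// VsaddHelper computes the signed sum in a doubly-wide type and clamps it to
+// the destination type's range (e.g. int8_t 100 + 100 = 200 clamps to 127).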
+TEST_F(RiscVCheriotVectorInstructionsTest, Vsadd8VV) { + SetSemanticFunction(&Vsadd); + BinaryOpTestHelperVV<int8_t, int8_t, int8_t>( + "Vsadd8", /*sew*/ 8, instruction_, VsaddHelper<int8_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vsadd16VV) { + SetSemanticFunction(&Vsadd); + BinaryOpTestHelperVV<int16_t, int16_t, int16_t>( + "Vsadd16", /*sew*/ 16, instruction_, VsaddHelper<int16_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vsadd32VV) { + SetSemanticFunction(&Vsadd); + BinaryOpTestHelperVV<int32_t, int32_t, int32_t>( + "Vsadd32", /*sew*/ 32, instruction_, VsaddHelper<int32_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vsadd64VV) { + SetSemanticFunction(&Vsadd); + BinaryOpTestHelperVV<int64_t, int64_t, int64_t>( + "Vsadd64", /*sew*/ 64, instruction_, VsaddHelper<int64_t>); +} + +// Vector-Scalar +TEST_F(RiscVCheriotVectorInstructionsTest, Vsadd8VX) { + SetSemanticFunction(&Vsadd); + BinaryOpTestHelperVX<int8_t, int8_t, int8_t>( + "Vsadd8", /*sew*/ 8, instruction_, VsaddHelper<int8_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vsadd16VX) { + SetSemanticFunction(&Vsadd); + BinaryOpTestHelperVX<int16_t, int16_t, int16_t>( + "Vsadd16", /*sew*/ 16, instruction_, VsaddHelper<int16_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vsadd32VX) { + SetSemanticFunction(&Vsadd); + BinaryOpTestHelperVX<int32_t, int32_t, int32_t>( + "Vsadd32", /*sew*/ 32, instruction_, VsaddHelper<int32_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vsadd64VX) { + SetSemanticFunction(&Vsadd); + BinaryOpTestHelperVX<int64_t, int64_t, int64_t>( + "Vsadd64", /*sew*/ 64, instruction_, VsaddHelper<int64_t>); +} + +// Vector unsigned saturated subtract. +// Vector-Vector. +template <typename T> +T SsubuHelper(T val0, T val1) { + T diff = val0 - val1; + if (val0 < val1) { + diff = 0; + } + return diff; +} + +TEST_F(RiscVCheriotVectorInstructionsTest, Vssubu8VV) { + SetSemanticFunction(&Vssubu); + BinaryOpTestHelperVV<uint8_t, uint8_t, uint8_t>( + "Vssubu8", /*sew*/ 8, instruction_, SsubuHelper<uint8_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vssubu16VV) { + SetSemanticFunction(&Vssubu); + BinaryOpTestHelperVV<uint16_t, uint16_t, uint16_t>( + "Vssubu16", /*sew*/ 16, instruction_, SsubuHelper<uint16_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vssubu32VV) { + SetSemanticFunction(&Vssubu); + BinaryOpTestHelperVV<uint32_t, uint32_t, uint32_t>( + "Vssubu32", /*sew*/ 32, instruction_, SsubuHelper<uint32_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vssubu64VV) { + SetSemanticFunction(&Vssubu); + BinaryOpTestHelperVV<uint64_t, uint64_t, uint64_t>( + "Vssubu64", /*sew*/ 64, instruction_, SsubuHelper<uint64_t>); +} + +// Vector-Scalar +TEST_F(RiscVCheriotVectorInstructionsTest, Vssubu8VX) { + SetSemanticFunction(&Vssubu); + BinaryOpTestHelperVX<uint8_t, uint8_t, uint8_t>( + "Vssubu8", /*sew*/ 8, instruction_, SsubuHelper<uint8_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vssubu16VX) { + SetSemanticFunction(&Vssubu); + BinaryOpTestHelperVX<uint16_t, uint16_t, uint16_t>( + "Vssubu16", /*sew*/ 16, instruction_, SsubuHelper<uint16_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vssubu32VX) { + SetSemanticFunction(&Vssubu); + BinaryOpTestHelperVX<uint32_t, uint32_t, uint32_t>( + "Vssubu32", /*sew*/ 32, instruction_, SsubuHelper<uint32_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vssubu64VX) { + SetSemanticFunction(&Vssubu); + BinaryOpTestHelperVX<uint64_t, uint64_t, uint64_t>( + "Vssubu64", /*sew*/ 64, instruction_, SsubuHelper<uint64_t>); +} + +// Vector signed 
saturated subtract. +template <typename T> +T VssubHelper(T val0, T val1) { + using UT = typename MakeUnsigned<T>::type; + UT uval0 = static_cast<UT>(val0); + UT uval1 = static_cast<UT>(val1); + UT udiff = uval0 - uval1; + T diff = static_cast<T>(udiff); + if (val0 < 0 && val1 >= 0 && diff >= 0) return std::numeric_limits<T>::min(); + if (val0 >= 0 && val1 < 0 && diff < 0) return std::numeric_limits<T>::max(); + return diff; +} +// Vector-Vector. +TEST_F(RiscVCheriotVectorInstructionsTest, Vssub8VV) { + SetSemanticFunction(&Vssub); + BinaryOpTestHelperVV<int8_t, int8_t, int8_t>( + "Vssub8", /*sew*/ 8, instruction_, VssubHelper<int8_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vssub16VV) { + SetSemanticFunction(&Vssub); + BinaryOpTestHelperVV<int16_t, int16_t, int16_t>( + "Vssub16", /*sew*/ 16, instruction_, VssubHelper<int16_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vssub32VV) { + SetSemanticFunction(&Vssub); + BinaryOpTestHelperVV<int32_t, int32_t, int32_t>( + "Vssub32", /*sew*/ 32, instruction_, VssubHelper<int32_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vssub64VV) { + SetSemanticFunction(&Vssub); + BinaryOpTestHelperVV<int64_t, int64_t, int64_t>( + "Vssub64", /*sew*/ 64, instruction_, VssubHelper<int64_t>); +} + +// Vector-Scalar +TEST_F(RiscVCheriotVectorInstructionsTest, Vssub8VX) { + SetSemanticFunction(&Vssub); + BinaryOpTestHelperVX<int8_t, int8_t, int8_t>( + "Vssub8", /*sew*/ 8, instruction_, VssubHelper<int8_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vssub16VX) { + SetSemanticFunction(&Vssub); + BinaryOpTestHelperVX<int16_t, int16_t, int16_t>( + "Vssub16", /*sew*/ 16, instruction_, VssubHelper<int16_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vssub32VX) { + SetSemanticFunction(&Vssub); + BinaryOpTestHelperVX<int32_t, int32_t, int32_t>( + "Vssub32", /*sew*/ 32, instruction_, VssubHelper<int32_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vssub64VX) { + SetSemanticFunction(&Vssub); + BinaryOpTestHelperVX<int64_t, int64_t, int64_t>( + "Vssub64", /*sew*/ 64, instruction_, VssubHelper<int64_t>); +} + +template <typename T> +T VadcHelper(T vs2, T vs1, bool mask) { + return vs2 + vs1 + mask; +} + +// Vector add with carry. +// Vector-Vector. +TEST_F(RiscVCheriotVectorInstructionsTest, Vadc8VV) { + SetSemanticFunction(&Vadc); + BinaryOpWithMaskTestHelperVV<uint8_t, uint8_t, uint8_t>( + "Vadc", /*sew*/ 8, instruction_, VadcHelper<uint8_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vadc16VV) { + SetSemanticFunction(&Vadc); + BinaryOpWithMaskTestHelperVV<uint16_t, uint16_t, uint16_t>( + "Vadc", /*sew*/ 16, instruction_, VadcHelper<uint16_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vadc32VV) { + SetSemanticFunction(&Vadc); + BinaryOpWithMaskTestHelperVV<uint32_t, uint32_t, uint32_t>( + "Vadc", /*sew*/ 32, instruction_, VadcHelper<uint32_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vadc64VV) { + SetSemanticFunction(&Vadc); + BinaryOpWithMaskTestHelperVV<uint64_t, uint64_t, uint64_t>( + "Vadc", /*sew*/ 64, instruction_, VadcHelper<uint64_t>); +} +// Vector-Scalar. 
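+// For Vadc the mask register supplies a per-element carry-in rather than
+// acting as a predicate, which is why VadcHelper adds the mask bit directly.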
+TEST_F(RiscVCheriotVectorInstructionsTest, Vadc8VX) { + SetSemanticFunction(&Vadc); + BinaryOpWithMaskTestHelperVX<uint8_t, uint8_t, uint8_t>( + "Vadc", /*sew*/ 8, instruction_, VadcHelper<uint8_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vadc16VX) { + SetSemanticFunction(&Vadc); + BinaryOpWithMaskTestHelperVX<uint16_t, uint16_t, uint16_t>( + "Vadc", /*sew*/ 16, instruction_, VadcHelper<uint16_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vadc32VX) { + SetSemanticFunction(&Vadc); + BinaryOpWithMaskTestHelperVX<uint32_t, uint32_t, uint32_t>( + "Vadc", /*sew*/ 32, instruction_, VadcHelper<uint32_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vadc64VX) { + SetSemanticFunction(&Vadc); + BinaryOpWithMaskTestHelperVX<uint64_t, uint64_t, uint64_t>( + "Vadc", /*sew*/ 64, instruction_, VadcHelper<uint64_t>); +} + +template <typename T> +uint8_t VmadcHelper(T vs2, T vs1, bool mask_value) { + T cin = ((vs2 & 0b1) + (vs1 & 0b1) + mask_value); + cin >>= 1; + vs2 >>= 1; + vs1 >>= 1; + T sum = vs2 + vs1 + cin; + sum >>= sizeof(T) * 8 - 1; + return sum; +} + +// Vector compute carry from add with carry. +// Vector-Vector. +TEST_F(RiscVCheriotVectorInstructionsTest, Vmadc8VV) { + SetSemanticFunction(&Vmadc); + BinaryMaskOpWithMaskTestHelperVV<uint8_t, uint8_t>( + "Vmadc", /*sew*/ 8, instruction_, VmadcHelper<uint8_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmadc16VV) { + SetSemanticFunction(&Vmadc); + BinaryMaskOpWithMaskTestHelperVV<uint16_t, uint16_t>( + "Vmadc", /*sew*/ 16, instruction_, VmadcHelper<uint16_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmadc32VV) { + SetSemanticFunction(&Vmadc); + BinaryMaskOpWithMaskTestHelperVV<uint32_t, uint32_t>( + "Vmadc", /*sew*/ 32, instruction_, VmadcHelper<uint32_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmadc64VV) { + SetSemanticFunction(&Vmadc); + BinaryMaskOpWithMaskTestHelperVV<uint64_t, uint64_t>( + "Vmadc", /*sew*/ 64, instruction_, VmadcHelper<uint64_t>); +} +// Vector-Scalar. +TEST_F(RiscVCheriotVectorInstructionsTest, Vmadc8VX) { + SetSemanticFunction(&Vmadc); + BinaryMaskOpWithMaskTestHelperVX<uint8_t, uint8_t>( + "Vmadc", /*sew*/ 8, instruction_, VmadcHelper<uint8_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmadc16VX) { + SetSemanticFunction(&Vmadc); + BinaryMaskOpWithMaskTestHelperVX<uint16_t, uint16_t>( + "Vmadc", /*sew*/ 16, instruction_, VmadcHelper<uint16_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmadc32VX) { + SetSemanticFunction(&Vmadc); + BinaryMaskOpWithMaskTestHelperVX<uint32_t, uint32_t>( + "Vmadc", /*sew*/ 32, instruction_, VmadcHelper<uint32_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmadc64VX) { + SetSemanticFunction(&Vmadc); + BinaryMaskOpWithMaskTestHelperVX<uint64_t, uint64_t>( + "Vmadc", /*sew*/ 64, instruction_, VmadcHelper<uint64_t>); +} + +template <typename T> +T VsbcHelper(T vs2, T vs1, bool mask) { + return vs2 - vs1 - mask; +} +// Vector subtract with borrow. +// Vector-Vector. 
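+// As with Vadc, the mask bit here acts as a per-element borrow-in.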
+TEST_F(RiscVCheriotVectorInstructionsTest, Vsbc8VV) { + SetSemanticFunction(&Vsbc); + BinaryOpWithMaskTestHelperVV<uint8_t, uint8_t, uint8_t>( + "Vsbc", /*sew*/ 8, instruction_, VsbcHelper<uint8_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vsbc16VV) { + SetSemanticFunction(&Vsbc); + BinaryOpWithMaskTestHelperVV<uint16_t, uint16_t, uint16_t>( + "Vsbc", /*sew*/ 16, instruction_, VsbcHelper<uint16_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vsbc32VV) { + SetSemanticFunction(&Vsbc); + BinaryOpWithMaskTestHelperVV<uint32_t, uint32_t, uint32_t>( + "Vsbc", /*sew*/ 32, instruction_, VsbcHelper<uint32_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vsbc64VV) { + SetSemanticFunction(&Vsbc); + BinaryOpWithMaskTestHelperVV<uint64_t, uint64_t, uint64_t>( + "Vsbc", /*sew*/ 64, instruction_, VsbcHelper<uint64_t>); +} +// Vector-Scalar. +TEST_F(RiscVCheriotVectorInstructionsTest, Vsbc8VX) { + SetSemanticFunction(&Vsbc); + BinaryOpWithMaskTestHelperVX<uint8_t, uint8_t, uint8_t>( + "Vsbc", /*sew*/ 8, instruction_, VsbcHelper<uint8_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vsbc16VX) { + SetSemanticFunction(&Vsbc); + BinaryOpWithMaskTestHelperVX<uint16_t, uint16_t, uint16_t>( + "Vsbc", /*sew*/ 16, instruction_, VsbcHelper<uint16_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vsbc32VX) { + SetSemanticFunction(&Vsbc); + BinaryOpWithMaskTestHelperVX<uint32_t, uint32_t, uint32_t>( + "Vsbc", /*sew*/ 32, instruction_, VsbcHelper<uint32_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vsbc64VX) { + SetSemanticFunction(&Vsbc); + BinaryOpWithMaskTestHelperVX<uint64_t, uint64_t, uint64_t>( + "Vsbc", /*sew*/ 64, instruction_, VsbcHelper<uint64_t>); +} + +template <typename T> +uint8_t VmsbcHelper(T vs2, T vs1, bool mask_value) { + if (vs2 == vs1) return mask_value; + if (vs2 < vs1) return 1; + return 0; +} + +// Vector compute carry from subtract with borrow. +// Vector-Vector. +TEST_F(RiscVCheriotVectorInstructionsTest, Vmsbc8VV) { + SetSemanticFunction(&Vmsbc); + BinaryMaskOpWithMaskTestHelperVV<uint8_t, uint8_t>( + "Vmsbc", /*sew*/ 8, instruction_, VmsbcHelper<uint8_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmsbc16VV) { + SetSemanticFunction(&Vmsbc); + BinaryMaskOpWithMaskTestHelperVV<uint16_t, uint16_t>( + "Vmsbc", /*sew*/ 16, instruction_, VmsbcHelper<uint16_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmsbc32VV) { + SetSemanticFunction(&Vmsbc); + BinaryMaskOpWithMaskTestHelperVV<uint32_t, uint32_t>( + "Vmsbc", /*sew*/ 32, instruction_, VmsbcHelper<uint32_t>); +} +TEST_F(RiscVCheriotVectorInstructionsTest, Vmsbc64VV) { + SetSemanticFunction(&Vmsbc); + BinaryMaskOpWithMaskTestHelperVV<uint64_t, uint64_t>( + "Vmsbc", /*sew*/ 64, instruction_, VmsbcHelper<uint64_t>); +} +// Vector-Scalar. 
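+// VmsbcHelper above yields the borrow-out of vs2 - vs1 - borrow-in: 1 when
+// vs2 < vs1, the incoming borrow when vs2 == vs1, and 0 otherwise.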
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmsbc8VX) {
+  SetSemanticFunction(&Vmsbc);
+  BinaryMaskOpWithMaskTestHelperVX<uint8_t, uint8_t>(
+      "Vmsbc", /*sew*/ 8, instruction_, VmsbcHelper<uint8_t>);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmsbc16VX) {
+  SetSemanticFunction(&Vmsbc);
+  BinaryMaskOpWithMaskTestHelperVX<uint16_t, uint16_t>(
+      "Vmsbc", /*sew*/ 16, instruction_, VmsbcHelper<uint16_t>);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmsbc32VX) {
+  SetSemanticFunction(&Vmsbc);
+  BinaryMaskOpWithMaskTestHelperVX<uint32_t, uint32_t>(
+      "Vmsbc", /*sew*/ 32, instruction_, VmsbcHelper<uint32_t>);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmsbc64VX) {
+  SetSemanticFunction(&Vmsbc);
+  BinaryMaskOpWithMaskTestHelperVX<uint64_t, uint64_t>(
+      "Vmsbc", /*sew*/ 64, instruction_, VmsbcHelper<uint64_t>);
+}
+
+// Vector merge.
+template <typename T>
+T VmergeHelper(T vs2, T vs1, bool mask_value) {
+  return mask_value ? vs1 : vs2;
+}
+// Vector-Vector.
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmerge8VV) {
+  SetSemanticFunction(&Vmerge);
+  BinaryOpWithMaskTestHelperVV<uint8_t, uint8_t, uint8_t>(
+      "Vmerge", /*sew*/ 8, instruction_, VmergeHelper<uint8_t>);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmerge16VV) {
+  SetSemanticFunction(&Vmerge);
+  BinaryOpWithMaskTestHelperVV<uint16_t, uint16_t, uint16_t>(
+      "Vmerge", /*sew*/ 16, instruction_, VmergeHelper<uint16_t>);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmerge32VV) {
+  SetSemanticFunction(&Vmerge);
+  BinaryOpWithMaskTestHelperVV<uint32_t, uint32_t, uint32_t>(
+      "Vmerge", /*sew*/ 32, instruction_, VmergeHelper<uint32_t>);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmerge64VV) {
+  SetSemanticFunction(&Vmerge);
+  BinaryOpWithMaskTestHelperVV<uint64_t, uint64_t, uint64_t>(
+      "Vmerge", /*sew*/ 64, instruction_, VmergeHelper<uint64_t>);
+}
+// Vector-Scalar.
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmerge8VX) {
+  SetSemanticFunction(&Vmerge);
+  BinaryOpWithMaskTestHelperVX<uint8_t, uint8_t, uint8_t>(
+      "Vmerge", /*sew*/ 8, instruction_, VmergeHelper<uint8_t>);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmerge16VX) {
+  SetSemanticFunction(&Vmerge);
+  BinaryOpWithMaskTestHelperVX<uint16_t, uint16_t, uint16_t>(
+      "Vmerge", /*sew*/ 16, instruction_, VmergeHelper<uint16_t>);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmerge32VX) {
+  SetSemanticFunction(&Vmerge);
+  BinaryOpWithMaskTestHelperVX<uint32_t, uint32_t, uint32_t>(
+      "Vmerge", /*sew*/ 32, instruction_, VmergeHelper<uint32_t>);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmerge64VX) {
+  SetSemanticFunction(&Vmerge);
+  BinaryOpWithMaskTestHelperVX<uint64_t, uint64_t, uint64_t>(
+      "Vmerge", /*sew*/ 64, instruction_, VmergeHelper<uint64_t>);
+}
+
+// This wrapper function factors out the main body of the Vmvr test.
+void VmvrWrapper(int num_reg, RiscVCheriotVectorInstructionsTest *tester,
+                 Instruction *inst) {
+  tester->SetSemanticFunction(absl::bind_front(&Vmvr, num_reg));
+  // Number of elements per vector register.
+  constexpr int vs2_size = kVectorLengthInBytes / sizeof(uint64_t);
+  // Input values for 8 registers.
+  uint64_t vs2_value[vs2_size * 8];
+  auto vs2_span = Span<uint64_t>(vs2_value);
+  tester->AppendVectorRegisterOperands({kVs2}, {kVd});
+  // Initialize input values.
+  tester->FillArrayWithRandomValues<uint64_t>(vs2_span);
+  for (int i = 0; i < 8; i++) {
+    auto vs2_name = absl::StrCat("v", kVs2 + i);
+    tester->SetVectorRegisterValues<uint64_t>(
+        {{vs2_name, vs2_span.subspan(vs2_size * i, vs2_size)}});
+  }
+  tester->ClearVectorRegisterGroup(kVd, 8);
+  inst->Execute();
+  EXPECT_FALSE(tester->rv_vector()->vector_exception());
+  int count = 0;
+  for (int reg = kVd; reg < kVd + 8; reg++) {
+    auto dest_span = tester->vreg()[reg]->data_buffer()->Get<uint64_t>();
+    for (int i = 0; i < kVectorLengthInBytes / sizeof(uint64_t); i++) {
+      if (reg < kVd + num_reg) {
+        EXPECT_EQ(vs2_span[count], dest_span[i])
+            << "count: " << count << " i: " << i;
+      } else {
+        EXPECT_EQ(0, dest_span[i]) << "count: " << count << " i: " << i;
+      }
+      count++;
+    }
+  }
+}
+
+// Vector move register.
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmvr1) {
+  VmvrWrapper(1, this, instruction_);
+}
+
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmvr2) {
+  VmvrWrapper(2, this, instruction_);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmvr4) {
+  VmvrWrapper(4, this, instruction_);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vmvr8) {
+  VmvrWrapper(8, this, instruction_);
+}
+
+// Templated helper functions for Vssr testing.
+template <typename T>
+T VssrHelper(RiscVCheriotVectorInstructionsTest *tester, T vs2, T vs1,
+             int rounding_mode) {
+  using UT = typename MakeUnsigned<T>::type;
+  int max_shift = (sizeof(T) << 3) - 1;
+  int shift_amount = static_cast<int>(vs1 & max_shift);
+  // Extract the bits that will be lost + 1.
+  UT lost_bits = vs2;
+  if (shift_amount < max_shift) {
+    lost_bits = vs2 & ~(std::numeric_limits<UT>::max() << (shift_amount + 1));
+  }
+  T result = vs2 >> shift_amount;
+  result += static_cast<T>(tester->RoundBits(shift_amount + 1, lost_bits));
+  return result;
+}
+
+// These wrapper functions simplify the test bodies, and make it a little
+// easier to avoid errors due to type and sew specifications.
+template <typename T>
+void VssrVVWrapper(absl::string_view base_name, Instruction *inst,
+                   RiscVCheriotVectorInstructionsTest *tester) {
+  // Iterate across rounding modes.
+  for (int rm = 0; rm < 4; rm++) {
+    tester->rv_vector()->set_vxrm(rm);
+    tester->BinaryOpTestHelperVV<T, T, T>(
+        absl::StrCat(base_name, "_", rm), /*sew*/ sizeof(T) * 8, inst,
+        [rm, tester](T vs2, T vs1) -> T {
+          return VssrHelper<T>(tester, vs2, vs1, rm);
+        });
+  }
+}
+template <typename T>
+void VssrVXWrapper(absl::string_view base_name, Instruction *inst,
+                   RiscVCheriotVectorInstructionsTest *tester) {
+  // Iterate across rounding modes.
+  for (int rm = 0; rm < 4; rm++) {
+    tester->rv_vector()->set_vxrm(rm);
+    tester->BinaryOpTestHelperVX<T, T, T>(
+        absl::StrCat(base_name, "_", rm), /*sew*/ sizeof(T) * 8, inst,
+        [rm, tester](T vs2, T vs1) -> T {
+          return VssrHelper<T>(tester, vs2, vs1, rm);
+        });
+  }
+}
+// Vector shift right logical with rounding.
+// Vector-Vector.
+TEST_F(RiscVCheriotVectorInstructionsTest, Vssrl8VV) {
+  SetSemanticFunction(&Vssrl);
+  VssrVVWrapper<uint8_t>("Vssrl", instruction_, this);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vssrl16VV) {
+  SetSemanticFunction(&Vssrl);
+  VssrVVWrapper<uint16_t>("Vssrl", instruction_, this);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vssrl32VV) {
+  SetSemanticFunction(&Vssrl);
+  VssrVVWrapper<uint32_t>("Vssrl", instruction_, this);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vssrl64VV) {
+  SetSemanticFunction(&Vssrl);
+  VssrVVWrapper<uint64_t>("Vssrl", instruction_, this);
+}
+// Vector-Scalar.
+TEST_F(RiscVCheriotVectorInstructionsTest, Vssrl8VX) {
+  SetSemanticFunction(&Vssrl);
+  VssrVXWrapper<uint8_t>("Vssrl", instruction_, this);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vssrl16VX) {
+  SetSemanticFunction(&Vssrl);
+  VssrVXWrapper<uint16_t>("Vssrl", instruction_, this);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vssrl32VX) {
+  SetSemanticFunction(&Vssrl);
+  VssrVXWrapper<uint32_t>("Vssrl", instruction_, this);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vssrl64VX) {
+  SetSemanticFunction(&Vssrl);
+  VssrVXWrapper<uint64_t>("Vssrl", instruction_, this);
+}
+
+// Vector shift right arithmetic with rounding.
+// Vector-Vector.
+TEST_F(RiscVCheriotVectorInstructionsTest, Vssra8VV) {
+  SetSemanticFunction(&Vssra);
+  VssrVVWrapper<int8_t>("Vssra", instruction_, this);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vssra16VV) {
+  SetSemanticFunction(&Vssra);
+  VssrVVWrapper<int16_t>("Vssra", instruction_, this);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vssra32VV) {
+  SetSemanticFunction(&Vssra);
+  VssrVVWrapper<int32_t>("Vssra", instruction_, this);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vssra64VV) {
+  SetSemanticFunction(&Vssra);
+  VssrVVWrapper<int64_t>("Vssra", instruction_, this);
+}
+// Vector-Scalar.
+TEST_F(RiscVCheriotVectorInstructionsTest, Vssra8VX) {
+  SetSemanticFunction(&Vssra);
+  VssrVXWrapper<int8_t>("Vssra", instruction_, this);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vssra16VX) {
+  SetSemanticFunction(&Vssra);
+  VssrVXWrapper<int16_t>("Vssra", instruction_, this);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vssra32VX) {
+  SetSemanticFunction(&Vssra);
+  VssrVXWrapper<int32_t>("Vssra", instruction_, this);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vssra64VX) {
+  SetSemanticFunction(&Vssra);
+  VssrVXWrapper<int64_t>("Vssra", instruction_, this);
+}
+
+// Templated helper functions for Vnclip/Vnclipu instructions.
+template <typename T, typename WideT>
+T VnclipHelper(RiscVCheriotVectorInstructionsTest *tester, WideT vs2, T vs1,
+               int rm, CheriotVectorState *rv_vector) {
+  auto vs1_w = static_cast<WideT>(vs1);
+  auto shifted = VssrHelper<WideT>(tester, vs2, vs1_w, rm);
+  if (shifted < std::numeric_limits<T>::min()) {
+    rv_vector->set_vxsat(true);
+    return std::numeric_limits<T>::min();
+  }
+  if (shifted > std::numeric_limits<T>::max()) {
+    rv_vector->set_vxsat(true);
+    return std::numeric_limits<T>::max();
+  }
+  return static_cast<T>(shifted);
+}
+
+template <typename T>
+void VnclipVVWrapper(absl::string_view base_name, Instruction *inst,
+                     RiscVCheriotVectorInstructionsTest *tester) {
+  using WT = typename WideType<T>::type;
+  for (int rm = 0; rm < 4; rm++) {
+    tester->rv_vector()->set_vxrm(rm);
+    tester->BinaryOpTestHelperVV<T, WT, T>(
+        absl::StrCat(base_name, "_", rm), sizeof(T) * 8, inst,
+        [rm, tester](WT vs2, T vs1) -> T {
+          return VnclipHelper<T, WT>(tester, vs2, vs1, rm, tester->rv_vector());
+        });
+  }
+}
+template <typename T>
+void VnclipVXWrapper(absl::string_view base_name, Instruction *inst,
+                     RiscVCheriotVectorInstructionsTest *tester) {
+  using WT = typename WideType<T>::type;
+  for (int rm = 0; rm < 4; rm++) {
+    tester->rv_vector()->set_vxrm(rm);
+    tester->BinaryOpTestHelperVX<T, WT, T>(
+        absl::StrCat(base_name, "_", rm), sizeof(T) * 8, inst,
+        [rm, tester](WT vs2, T vs1) -> T {
+          return VnclipHelper<T, WT>(tester, vs2, vs1, rm, tester->rv_vector());
+        });
+  }
+}
+// Vector narrowing shift right arithmetic with rounding and saturation.
+// Vector-Vector.
+TEST_F(RiscVCheriotVectorInstructionsTest, Vnclip8VV) {
+  SetSemanticFunction(&Vnclip);
+  VnclipVVWrapper<int8_t>("Vnclip", instruction_, this);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vnclip16VV) {
+  SetSemanticFunction(&Vnclip);
+  VnclipVVWrapper<int16_t>("Vnclip", instruction_, this);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vnclip32VV) {
+  SetSemanticFunction(&Vnclip);
+  VnclipVVWrapper<int32_t>("Vnclip", instruction_, this);
+}
+// Vector-Scalar.
+TEST_F(RiscVCheriotVectorInstructionsTest, Vnclip8VX) {
+  SetSemanticFunction(&Vnclip);
+  VnclipVXWrapper<int8_t>("Vnclip", instruction_, this);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vnclip16VX) {
+  SetSemanticFunction(&Vnclip);
+  VnclipVXWrapper<int16_t>("Vnclip", instruction_, this);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vnclip32VX) {
+  SetSemanticFunction(&Vnclip);
+  VnclipVXWrapper<int32_t>("Vnclip", instruction_, this);
+}
+
+// Vector narrowing shift right logical with rounding and saturation.
+// Vector-Vector.
+TEST_F(RiscVCheriotVectorInstructionsTest, Vnclipu8VV) {
+  SetSemanticFunction(&Vnclipu);
+  VnclipVVWrapper<uint8_t>("Vnclipu", instruction_, this);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vnclipu16VV) {
+  SetSemanticFunction(&Vnclipu);
+  VnclipVVWrapper<uint16_t>("Vnclipu", instruction_, this);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vnclipu32VV) {
+  SetSemanticFunction(&Vnclipu);
+  VnclipVVWrapper<uint32_t>("Vnclipu", instruction_, this);
+}
+// Vector-Scalar.
+TEST_F(RiscVCheriotVectorInstructionsTest, Vnclipu8VX) {
+  SetSemanticFunction(&Vnclipu);
+  VnclipVXWrapper<uint8_t>("Vnclipu", instruction_, this);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vnclipu16VX) {
+  SetSemanticFunction(&Vnclipu);
+  VnclipVXWrapper<uint16_t>("Vnclipu", instruction_, this);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vnclipu32VX) {
+  SetSemanticFunction(&Vnclipu);
+  VnclipVXWrapper<uint32_t>("Vnclipu", instruction_, this);
+}
+
+// Vector fractional multiply with rounding and saturation.
+template <typename T>
+T VsmulHelper(RiscVCheriotVectorInstructionsTest *tester, T vs2, T vs1, int rm,
+              CheriotVectorState *rv_vector) {
+  using WT = typename WideType<T>::type;
+  WT vs2_w = static_cast<WT>(vs2);
+  WT vs1_w = static_cast<WT>(vs1);
+  WT prod = vs2_w * vs1_w;
+  WT res = VssrHelper<WT>(tester, prod, sizeof(T) * 8 - 1, rm);
+  if (res > std::numeric_limits<T>::max()) {
+    return std::numeric_limits<T>::max();
+  }
+  if (res < std::numeric_limits<T>::min()) {
+    return std::numeric_limits<T>::min();
+  }
+  return static_cast<T>(res);
+}
+
+template <typename T>
+void VsmulVVWrapper(absl::string_view base_name, Instruction *inst,
+                    RiscVCheriotVectorInstructionsTest *tester) {
+  for (int rm = 0; rm < 4; rm++) {
+    tester->rv_vector()->set_vxrm(rm);
+    tester->BinaryOpTestHelperVV<T, T, T>(
+        absl::StrCat(base_name, "_", rm), sizeof(T) * 8, inst,
+        [rm, tester](T vs2, T vs1) -> T {
+          return VsmulHelper<T>(tester, vs2, vs1, rm, tester->rv_vector());
+        });
+  }
+}
+template <typename T>
+void VsmulVXWrapper(absl::string_view base_name, Instruction *inst,
+                    RiscVCheriotVectorInstructionsTest *tester) {
+  for (int rm = 0; rm < 4; rm++) {
+    tester->rv_vector()->set_vxrm(rm);
+    tester->BinaryOpTestHelperVX<T, T, T>(
+        absl::StrCat(base_name, "_", rm), sizeof(T) * 8, inst,
+        [rm, tester](T vs2, T vs1) -> T {
+          return VsmulHelper<T>(tester, vs2, vs1, rm, tester->rv_vector());
+        });
+  }
+}
+// Vector-Vector.
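+// Note: per the RVV spec, vsmul computes (vs2 * vs1) >> (SEW - 1) with
+// rounding and saturation; with two's complement inputs the only product
+// that can overflow is most-negative times most-negative, e.g. for int8_t,
+// (-128 * -128) >> 7 = 128, which saturates to 127 and sets vxsat.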
+TEST_F(RiscVCheriotVectorInstructionsTest, Vsmul8VV) {
+  SetSemanticFunction(&Vsmul);
+  VsmulVVWrapper<int8_t>("Vsmul", instruction_, this);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vsmul16VV) {
+  SetSemanticFunction(&Vsmul);
+  VsmulVVWrapper<int16_t>("Vsmul", instruction_, this);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vsmul32VV) {
+  SetSemanticFunction(&Vsmul);
+  VsmulVVWrapper<int32_t>("Vsmul", instruction_, this);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vsmul64VV) {
+  SetSemanticFunction(&Vsmul);
+  VsmulVVWrapper<int64_t>("Vsmul", instruction_, this);
+}
+// Vector-Scalar.
+TEST_F(RiscVCheriotVectorInstructionsTest, Vsmul8VX) {
+  SetSemanticFunction(&Vsmul);
+  VsmulVXWrapper<int8_t>("Vsmul", instruction_, this);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vsmul16VX) {
+  SetSemanticFunction(&Vsmul);
+  VsmulVXWrapper<int16_t>("Vsmul", instruction_, this);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vsmul32VX) {
+  SetSemanticFunction(&Vsmul);
+  VsmulVXWrapper<int32_t>("Vsmul", instruction_, this);
+}
+TEST_F(RiscVCheriotVectorInstructionsTest, Vsmul64VX) {
+  SetSemanticFunction(&Vsmul);
+  VsmulVXWrapper<int64_t>("Vsmul", instruction_, this);
+}
+}  // namespace
diff --git a/cheriot/test/riscv_cheriot_vector_opm_instructions_test.cc b/cheriot/test/riscv_cheriot_vector_opm_instructions_test.cc new file mode 100644 index 0000000..9ca5087 --- /dev/null +++ b/cheriot/test/riscv_cheriot_vector_opm_instructions_test.cc
@@ -0,0 +1,1650 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cheriot/riscv_cheriot_vector_opm_instructions.h"
+
+#include <cstdint>
+#include <functional>
+#include <ios>
+#include <type_traits>
+
+#include "absl/base/casts.h"
+#include "absl/log/check.h"
+#include "absl/numeric/int128.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "cheriot/test/riscv_cheriot_vector_instructions_test_base.h"
+#include "googlemock/include/gmock/gmock.h"
+#include "mpact/sim/generic/instruction.h"
+
+namespace {
+
+using Instruction = ::mpact::sim::generic::Instruction;
+using ::mpact::sim::generic::WideType;
+
+using ::mpact::sim::cheriot::Vaadd;
+using ::mpact::sim::cheriot::Vaaddu;
+using ::mpact::sim::cheriot::Vasub;
+using ::mpact::sim::cheriot::Vasubu;
+using ::mpact::sim::cheriot::Vdiv;
+using ::mpact::sim::cheriot::Vdivu;
+using ::mpact::sim::cheriot::Vmacc;
+using ::mpact::sim::cheriot::Vmadd;
+using ::mpact::sim::cheriot::Vmand;
+using ::mpact::sim::cheriot::Vmandnot;
+using ::mpact::sim::cheriot::Vmnand;
+using ::mpact::sim::cheriot::Vmnor;
+using ::mpact::sim::cheriot::Vmor;
+using ::mpact::sim::cheriot::Vmornot;
+using ::mpact::sim::cheriot::Vmul;
+using ::mpact::sim::cheriot::Vmulh;
+using ::mpact::sim::cheriot::Vmulhsu;
+using ::mpact::sim::cheriot::Vmulhu;
+using ::mpact::sim::cheriot::Vmxnor;
+using ::mpact::sim::cheriot::Vmxor;
+using ::mpact::sim::cheriot::Vnmsac;
+using ::mpact::sim::cheriot::Vnmsub;
+using ::mpact::sim::cheriot::Vrem;
+using ::mpact::sim::cheriot::Vremu;
+using ::mpact::sim::cheriot::Vwadd;
+using ::mpact::sim::cheriot::Vwaddu;
+using ::mpact::sim::cheriot::Vwadduw;
+using ::mpact::sim::cheriot::Vwaddw;
+using ::mpact::sim::cheriot::Vwmacc;
+using ::mpact::sim::cheriot::Vwmaccsu;
+using ::mpact::sim::cheriot::Vwmaccu;
+using ::mpact::sim::cheriot::Vwmaccus;
+using ::mpact::sim::cheriot::Vwmul;
+using ::mpact::sim::cheriot::Vwmulsu;
+using ::mpact::sim::cheriot::Vwmulu;
+using ::mpact::sim::cheriot::Vwsub;
+using ::mpact::sim::cheriot::Vwsubu;
+using ::mpact::sim::cheriot::Vwsubuw;
+using ::mpact::sim::cheriot::Vwsubw;
+
+// Derived test class - adds a test helper function for testing the logical
+// mask operation instructions.
+class RiscVCheriotVectorOpmInstructionsTest
+    : public RiscVCheriotVectorInstructionsTestBase {
+ protected:
+  void BinaryLogicalMaskOpTestHelper(absl::string_view name,
+                                     std::function<bool(bool, bool)> op) {
+    uint8_t vs2_value[kVectorLengthInBytes];
+    uint8_t vs1_value[kVectorLengthInBytes];
+    uint8_t vd_value[kVectorLengthInBytes];
+    FillArrayWithRandomValues<uint8_t>(vs2_value);
+    FillArrayWithRandomValues<uint8_t>(vs1_value);
+    FillArrayWithRandomValues<uint8_t>(vd_value);
+    AppendVectorRegisterOperands({kVs2, kVs1}, {kVd});
+    for (int vstart : {0, 7, 32, 100, 250, 384}) {
+      for (int vlen_pct : {10, 20, 50, 100}) {
+        int vlen =
+            (kVectorLengthInBytes * 8 - vstart) * vlen_pct / 100 + vstart;
+        CHECK_LE(vlen, kVectorLengthInBytes * 8);
+        // Configure the vector unit.
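+        // The vtype value below packs vlmul into bits [2:0] and vsew into
+        // bits [5:3]; the setting selects byte-sized (SEW = 8) elements and
+        // a single lmul entry from kLmulSettings, while vstart and vl are
+        // swept by the loop bounds above.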
+ uint32_t vtype = (kSewSettingsByByteSize[1] << 3) | kLmulSettings[6]; + ConfigureVectorUnit(vtype, vlen); + vlen = rv_vector_->vector_length(); + rv_vector_->set_vstart(vstart); + SetVectorRegisterValues<uint8_t>({{kVs2Name, vs2_value}, + {kVs1Name, vs1_value}, + {kVdName, vd_value}}); + instruction_->Execute(); + auto dst_span = vreg_[kVd]->data_buffer()->Get<uint8_t>(); + for (int i = 0; i < kVectorLengthInBytes * 8; i++) { + int mask_index = i >> 3; + int mask_offset = i & 0b111; + bool result = (dst_span[mask_index] >> mask_offset) & 0b1; + if ((i < vstart) || (i >= vlen)) { + bool vd = (vd_value[mask_index] >> mask_offset) & 0b1; + EXPECT_EQ(result, vd) << "[" << i << "] " << std::hex + << "vd: " << (int)vd_value[mask_index] + << " dst: " << (int)dst_span[mask_index]; + } else { + bool vs2 = (vs2_value[mask_index] >> mask_offset) & 0b1; + bool vs1 = (vs1_value[mask_index] >> mask_offset) & 0b1; + EXPECT_EQ(result, op(vs2, vs1)) + << "[" << i << "]: " << "op(" << vs2 << ", " << vs1 << ")"; + } + } + } + } + } +}; + +// Helper functions for averaging add and subtract. +template <typename T> +T VaaddHelper(RiscVCheriotVectorOpmInstructionsTest *tester, T vs2, T vs1) { + // Create two sums, lower nibble, and the upper part. Then combine after + // rounding. + T vs2_l = vs2 & 0xf; + T vs1_l = vs1 & 0xf; + T res_l = vs2_l + vs1_l; + T res = (vs2 >> 4) + (vs1 >> 4); + res_l += tester->RoundBits<T>(2, res_l) << 1; + // Add carry. + res += (res_l >> 4); + // Use unsigned type to avoid undefined behavior of left-shifting negative + // numbers. + using UT = typename std::make_unsigned<T>::type; + res = (absl::bit_cast<UT>(res) << 3) | ((res_l >> 1) & 0b111); + return res; +} + +template <typename T> +T VasubHelper(RiscVCheriotVectorOpmInstructionsTest *tester, T vs2, T vs1) { + // Create two diffs, lower nibble, and the upper part. Then combine after + // rounding. + T vs2_l = vs2 & 0xf; + T vs1_l = vs1 & 0xf; + T res_l = vs2_l - vs1_l; + T res_h = (vs2 >> 4) - (vs1 >> 4); + // Subtract borrow. + res_h -= ((res_l >> 4) & 0b1); + // Use unsigned type to avoid undefined behavior of left-shifting negative + // numbers. + using UT = typename std::make_unsigned<T>::type; + T res = (absl::bit_cast<UT>(res_h) << 3) | ((res_l >> 1) & 0b111); + res += tester->RoundBits<T>(2, res_l); + return res; +} + +// Vaaddu vector-vector test helper function. +template <typename T> +inline void VaadduVVHelper(RiscVCheriotVectorOpmInstructionsTest *tester) { + tester->SetSemanticFunction(&Vaaddu); + tester->BinaryOpTestHelperVV<T, T, T>( + absl::StrCat("Vaaddu", sizeof(T) * 8, "vv"), /*sew*/ sizeof(T) * 8, + tester->instruction(), [tester](T val0, T val1) -> T { + return VaaddHelper(tester, val0, val1); + }); +} + +// Vaaddu vector-scalar test helper function. +template <typename T> +inline void VaadduVXHelper(RiscVCheriotVectorOpmInstructionsTest *tester) { + tester->SetSemanticFunction(&Vaaddu); + tester->BinaryOpTestHelperVX<T, T, T>( + absl::StrCat("Vaaddu", sizeof(T) * 8, "vx"), /*sew*/ sizeof(T) * 8, + tester->instruction(), [tester](T val0, T val1) -> T { + return VaaddHelper(tester, val0, val1); + }); +} + +// Test Vaaddu instructions. +TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vaaddu) { + // vector-vector. + VaadduVVHelper<uint8_t>(this); + ResetInstruction(); + VaadduVVHelper<uint16_t>(this); + ResetInstruction(); + VaadduVVHelper<uint32_t>(this); + ResetInstruction(); + VaadduVVHelper<uint64_t>(this); + ResetInstruction(); + // vector-scalar. 
+  VaadduVXHelper<uint8_t>(this);
+  ResetInstruction();
+  VaadduVXHelper<uint16_t>(this);
+  ResetInstruction();
+  VaadduVXHelper<uint32_t>(this);
+  ResetInstruction();
+  VaadduVXHelper<uint64_t>(this);
+}
+
+// Vaadd vector-vector test helper function.
+template <typename T>
+inline void VaaddVVHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  tester->SetSemanticFunction(&Vaadd);
+  tester->BinaryOpTestHelperVV<T, T, T>(
+      absl::StrCat("Vaadd", sizeof(T) * 8, "vv"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [tester](T val0, T val1) -> T {
+        return VaaddHelper(tester, val0, val1);
+      });
+}
+
+// Vaadd vector-scalar test helper function.
+template <typename T>
+inline void VaaddVXHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  tester->SetSemanticFunction(&Vaadd);
+  tester->BinaryOpTestHelperVX<T, T, T>(
+      absl::StrCat("Vaadd", sizeof(T) * 8, "vx"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [tester](T val0, T val1) -> T {
+        return VaaddHelper(tester, val0, val1);
+      });
+}
+
+// Test Vaadd (signed) instructions.
+TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vaadd) {
+  // Vector-vector.
+  VaaddVVHelper<int8_t>(this);
+  ResetInstruction();
+  VaaddVVHelper<int16_t>(this);
+  ResetInstruction();
+  VaaddVVHelper<int32_t>(this);
+  ResetInstruction();
+  VaaddVVHelper<int64_t>(this);
+  // Vector-scalar.
+  ResetInstruction();
+  VaaddVXHelper<int8_t>(this);
+  ResetInstruction();
+  VaaddVXHelper<int16_t>(this);
+  ResetInstruction();
+  VaaddVXHelper<int32_t>(this);
+  ResetInstruction();
+  VaaddVXHelper<int64_t>(this);
+}
+
+// Vasubu vector-vector test helper function.
+template <typename T>
+inline void VasubuVVHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  tester->SetSemanticFunction(&Vasubu);
+  tester->BinaryOpTestHelperVV<T, T, T>(
+      absl::StrCat("Vasubu", sizeof(T) * 8, "vv"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [tester](T val0, T val1) -> T {
+        return VasubHelper(tester, val0, val1);
+      });
+}
+// Vasubu vector-scalar test helper function.
+template <typename T>
+inline void VasubuVXHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  tester->SetSemanticFunction(&Vasubu);
+  tester->BinaryOpTestHelperVX<T, T, T>(
+      absl::StrCat("Vasubu", sizeof(T) * 8, "vx"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [tester](T val0, T val1) -> T {
+        return VasubHelper(tester, val0, val1);
+      });
+}
+
+// Test Vasubu (unsigned) instructions.
+TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vasubu) {
+  // Vector-vector.
+  VasubuVVHelper<uint8_t>(this);
+  ResetInstruction();
+  VasubuVVHelper<uint16_t>(this);
+  ResetInstruction();
+  VasubuVVHelper<uint32_t>(this);
+  ResetInstruction();
+  VasubuVVHelper<uint64_t>(this);
+  ResetInstruction();
+  // Vector-scalar.
+  VasubuVXHelper<uint8_t>(this);
+  ResetInstruction();
+  VasubuVXHelper<uint16_t>(this);
+  ResetInstruction();
+  VasubuVXHelper<uint32_t>(this);
+  ResetInstruction();
+  VasubuVXHelper<uint64_t>(this);
+}
+
+// Vasub vector-vector test helper function.
+template <typename T>
+inline void VasubVVHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  tester->SetSemanticFunction(&Vasub);
+  tester->BinaryOpTestHelperVV<T, T, T>(
+      absl::StrCat("Vasub", sizeof(T) * 8, "vv"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [tester](T val0, T val1) -> T {
+        return VasubHelper(tester, val0, val1);
+      });
+}
+// Vasub vector-scalar test helper function.
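+// Note: per the RVV spec the averaging subtract is (vs2 - vs1) >> 1 with the
+// shifted-out bit rounded according to vxrm, so e.g. vasubu(10, 5) is 2.5
+// rounded, which is 3 under round-to-nearest-up (rnu) and 2 under round-down
+// (rdn); VasubHelper above reproduces this via the split nibble arithmetic.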
+template <typename T> +inline void VasubVXHelper(RiscVCheriotVectorOpmInstructionsTest *tester) { + tester->SetSemanticFunction(&Vasub); + tester->BinaryOpTestHelperVX<T, T, T>( + absl::StrCat("Vasub", sizeof(T) * 8, "vx"), /*sew*/ sizeof(T) * 8, + tester->instruction(), [tester](T val0, T val1) -> T { + return VasubHelper(tester, val0, val1); + }); +} + +// Test Vasub (signed) instructions. +TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vasub) { + // Vector-vector. + VasubVVHelper<int8_t>(this); + ResetInstruction(); + VasubVVHelper<int16_t>(this); + ResetInstruction(); + VasubVVHelper<int32_t>(this); + ResetInstruction(); + VasubVVHelper<int64_t>(this); + ResetInstruction(); + // Vector-scalar. + VasubVXHelper<int8_t>(this); + ResetInstruction(); + VasubVXHelper<int16_t>(this); + ResetInstruction(); + VasubVXHelper<int32_t>(this); + ResetInstruction(); + VasubVXHelper<int64_t>(this); +} + +// Testing instructions that perform logical operations on vector masks. +TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vmandnot) { + SetSemanticFunction(&Vmandnot); + BinaryLogicalMaskOpTestHelper( + "Vmandnot", [](bool vs2, bool vs1) -> bool { return vs2 && !vs1; }); +} + +TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vmand) { + SetSemanticFunction(&Vmand); + BinaryLogicalMaskOpTestHelper( + "Vmand", [](bool vs2, bool vs1) -> bool { return vs2 && vs1; }); +} + +TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vmor) { + SetSemanticFunction(&Vmor); + BinaryLogicalMaskOpTestHelper( + "Vmor", [](bool vs2, bool vs1) -> bool { return vs2 || vs1; }); +} + +TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vmxor) { + SetSemanticFunction(&Vmxor); + BinaryLogicalMaskOpTestHelper("Vmxor", [](bool vs2, bool vs1) -> bool { + return (vs1 && !vs2) || (!vs1 && vs2); + }); +} + +TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vmornot) { + SetSemanticFunction(&Vmornot); + BinaryLogicalMaskOpTestHelper( + "Vmornot", [](bool vs2, bool vs1) -> bool { return vs2 || !vs1; }); +} + +TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vmnand) { + SetSemanticFunction(&Vmnand); + BinaryLogicalMaskOpTestHelper( + "Vmnand", [](bool vs2, bool vs1) -> bool { return !(vs2 && vs1); }); +} + +TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vmnor) { + SetSemanticFunction(&Vmnor); + BinaryLogicalMaskOpTestHelper( + "Vmnor", [](bool vs2, bool vs1) -> bool { return !(vs2 || vs1); }); +} + +TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vmxnor) { + SetSemanticFunction(&Vmxnor); + BinaryLogicalMaskOpTestHelper("Vmxnor", [](bool vs2, bool vs1) -> bool { + return !((vs1 && !vs2) || (!vs1 && vs2)); + }); +} + +// Vdivu vector-vector test helper function. +template <typename T> +inline void VdivuVVHelper(RiscVCheriotVectorOpmInstructionsTest *tester) { + tester->SetSemanticFunction(&Vdivu); + tester->BinaryOpTestHelperVV<T, T, T>( + absl::StrCat("Vdivu", sizeof(T) * 8, "vv"), /*sew*/ sizeof(T) * 8, + tester->instruction(), [](T vs2, T vs1) -> T { + if (vs1 == 0) return ~vs1; + return vs2 / vs1; + }); +} +// Vdivu vector-scalar test helper function. +template <typename T> +inline void VdivuVXHelper(RiscVCheriotVectorOpmInstructionsTest *tester) { + tester->SetSemanticFunction(&Vdivu); + tester->BinaryOpTestHelperVX<T, T, T>( + absl::StrCat("Vdivu", sizeof(T) * 8, "vx"), /*sew*/ sizeof(T) * 8, + tester->instruction(), [](T vs2, T vs1) -> T { + if (vs1 == 0) return ~vs1; + return vs2 / vs1; + }); +} + +// Test vdivu instructions. +TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vdivu) { + // Vector-vector. 
+  VdivuVVHelper<uint8_t>(this);
+  ResetInstruction();
+  VdivuVVHelper<uint16_t>(this);
+  ResetInstruction();
+  VdivuVVHelper<uint32_t>(this);
+  ResetInstruction();
+  VdivuVVHelper<uint64_t>(this);
+  ResetInstruction();
+  // Vector-scalar.
+  VdivuVXHelper<uint8_t>(this);
+  ResetInstruction();
+  VdivuVXHelper<uint16_t>(this);
+  ResetInstruction();
+  VdivuVXHelper<uint32_t>(this);
+  ResetInstruction();
+  VdivuVXHelper<uint64_t>(this);
+}
+
+// Vdiv vector-vector test helper function.
+template <typename T>
+inline void VdivVVHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  tester->SetSemanticFunction(&Vdiv);
+  tester->BinaryOpTestHelperVV<T, T, T>(
+      absl::StrCat("Vdiv", sizeof(T) * 8, "vv"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [](T vs2, T vs1) -> T {
+        if (vs1 == 0) return ~vs1;
+        return vs2 / vs1;
+      });
+}
+// Vdiv vector-scalar test helper function.
+template <typename T>
+inline void VdivVXHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  tester->SetSemanticFunction(&Vdiv);
+  tester->BinaryOpTestHelperVX<T, T, T>(
+      absl::StrCat("Vdiv", sizeof(T) * 8, "vx"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [](T vs2, T vs1) -> T {
+        if (vs1 == 0) return ~vs1;
+        return vs2 / vs1;
+      });
+}
+
+// Test vdiv instructions.
+TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vdiv) {
+  // Vector-vector.
+  VdivVVHelper<int8_t>(this);
+  ResetInstruction();
+  VdivVVHelper<int16_t>(this);
+  ResetInstruction();
+  VdivVVHelper<int32_t>(this);
+  ResetInstruction();
+  VdivVVHelper<int64_t>(this);
+  ResetInstruction();
+  // Vector-scalar.
+  VdivVXHelper<int8_t>(this);
+  ResetInstruction();
+  VdivVXHelper<int16_t>(this);
+  ResetInstruction();
+  VdivVXHelper<int32_t>(this);
+  ResetInstruction();
+  VdivVXHelper<int64_t>(this);
+}
+
+// Vremu vector-vector test helper function.
+template <typename T>
+inline void VremuVVHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  tester->SetSemanticFunction(&Vremu);
+  tester->BinaryOpTestHelperVV<T, T, T>(
+      absl::StrCat("Vremu", sizeof(T) * 8, "vv"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [](T vs2, T vs1) -> T {
+        if (vs1 == 0) return vs2;
+        return vs2 % vs1;
+      });
+}
+// Vremu vector-scalar test helper function.
+template <typename T>
+inline void VremuVXHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  tester->SetSemanticFunction(&Vremu);
+  tester->BinaryOpTestHelperVX<T, T, T>(
+      absl::StrCat("Vremu", sizeof(T) * 8, "vx"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [](T vs2, T vs1) -> T {
+        if (vs1 == 0) return vs2;
+        return vs2 % vs1;
+      });
+}
+
+// Test vremu instructions.
+TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vremu) {
+  // Vector-vector.
+  VremuVVHelper<uint8_t>(this);
+  ResetInstruction();
+  VremuVVHelper<uint16_t>(this);
+  ResetInstruction();
+  VremuVVHelper<uint32_t>(this);
+  ResetInstruction();
+  VremuVVHelper<uint64_t>(this);
+  ResetInstruction();
+  // Vector-scalar.
+  VremuVXHelper<uint8_t>(this);
+  ResetInstruction();
+  VremuVXHelper<uint16_t>(this);
+  ResetInstruction();
+  VremuVXHelper<uint32_t>(this);
+  ResetInstruction();
+  VremuVXHelper<uint64_t>(this);
+}
+
+// Vrem vector-vector test helper function.
+template <typename T>
+inline void VremVVHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  tester->SetSemanticFunction(&Vrem);
+  tester->BinaryOpTestHelperVV<T, T, T>(
+      absl::StrCat("Vrem", sizeof(T) * 8, "vv"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [](T vs2, T vs1) -> T {
+        if (vs1 == 0) return vs2;
+        return vs2 % vs1;
+      });
+}
+// Vrem vector-scalar test helper function.
+template <typename T>
+inline void VremVXHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  tester->SetSemanticFunction(&Vrem);
+  tester->BinaryOpTestHelperVX<T, T, T>(
+      absl::StrCat("Vrem", sizeof(T) * 8, "vx"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [](T vs2, T vs1) -> T {
+        if (vs1 == 0) return vs2;
+        return vs2 % vs1;
+      });
+}
+
+// Test vrem instructions.
+TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vrem) {
+  // vector-vector.
+  VremVVHelper<int8_t>(this);
+  ResetInstruction();
+  VremVVHelper<int16_t>(this);
+  ResetInstruction();
+  VremVVHelper<int32_t>(this);
+  ResetInstruction();
+  VremVVHelper<int64_t>(this);
+  ResetInstruction();
+  // vector-scalar.
+  VremVXHelper<int8_t>(this);
+  ResetInstruction();
+  VremVXHelper<int16_t>(this);
+  ResetInstruction();
+  VremVXHelper<int32_t>(this);
+  ResetInstruction();
+  VremVXHelper<int64_t>(this);
+}
+
+// Vmulhu vector-vector test helper function.
+template <typename T>
+inline void VmulhuVVHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  tester->SetSemanticFunction(&Vmulhu);
+  tester->BinaryOpTestHelperVV<T, T, T>(
+      absl::StrCat("Vmulhu", sizeof(T) * 8, "vv"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [](T vs2, T vs1) -> T {
+        absl::uint128 vs2_w = static_cast<absl::uint128>(vs2);
+        absl::uint128 vs1_w = static_cast<absl::uint128>(vs1);
+        absl::uint128 vd_w = (vs2_w * vs1_w) >> (sizeof(T) * 8);
+        return static_cast<T>(vd_w);
+      });
+}
+// Vmulhu vector-scalar test helper function.
+template <typename T>
+inline void VmulhuVXHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  tester->SetSemanticFunction(&Vmulhu);
+  tester->BinaryOpTestHelperVX<T, T, T>(
+      absl::StrCat("Vmulhu", sizeof(T) * 8, "vx"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [](T vs2, T vs1) -> T {
+        absl::uint128 vs2_w = static_cast<absl::uint128>(vs2);
+        absl::uint128 vs1_w = static_cast<absl::uint128>(vs1);
+        absl::uint128 vd_w = (vs2_w * vs1_w) >> (sizeof(T) * 8);
+        return static_cast<T>(vd_w);
+      });
+}
+
+// Test vmulhu instructions.
+TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vmulhu) {
+  // vector-vector.
+  VmulhuVVHelper<uint8_t>(this);
+  ResetInstruction();
+  VmulhuVVHelper<uint16_t>(this);
+  ResetInstruction();
+  VmulhuVVHelper<uint32_t>(this);
+  ResetInstruction();
+  VmulhuVVHelper<uint64_t>(this);
+  ResetInstruction();
+  // vector-scalar.
+  VmulhuVXHelper<uint8_t>(this);
+  ResetInstruction();
+  VmulhuVXHelper<uint16_t>(this);
+  ResetInstruction();
+  VmulhuVXHelper<uint32_t>(this);
+  ResetInstruction();
+  VmulhuVXHelper<uint64_t>(this);
+}
+
+// Vmulh vector-vector test helper function.
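+// Note: vmulh keeps the upper SEW bits of the 2*SEW-bit signed product; the
+// absl::int128 intermediary keeps the 64x64-bit case exact. E.g., for int8_t,
+// -128 * -128 = 0x4000, so vmulh returns 0x40 = 64.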
+template <typename T>
+inline void VmulhVVHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  tester->SetSemanticFunction(&Vmulh);
+  tester->BinaryOpTestHelperVV<T, T, T>(
+      absl::StrCat("Vmulh", sizeof(T) * 8, "vv"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [](T vs2, T vs1) -> T {
+        absl::int128 vs2_w = static_cast<absl::int128>(vs2);
+        absl::int128 vs1_w = static_cast<absl::int128>(vs1);
+        absl::int128 vd_w = (vs2_w * vs1_w) >> (sizeof(T) * 8);
+        return static_cast<T>(vd_w);
+      });
+}
+
+// Vmulh vector-scalar test helper function.
+template <typename T>
+inline void VmulhVXHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  tester->SetSemanticFunction(&Vmulh);
+  tester->BinaryOpTestHelperVX<T, T, T>(
+      absl::StrCat("Vmulh", sizeof(T) * 8, "vx"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [](T vs2, T vs1) -> T {
+        absl::int128 vs2_w = static_cast<absl::int128>(vs2);
+        absl::int128 vs1_w = static_cast<absl::int128>(vs1);
+        absl::int128 vd_w = (vs2_w * vs1_w) >> (sizeof(T) * 8);
+        return static_cast<T>(vd_w);
+      });
+}
+
+// Test vmulh instructions.
+TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vmulh) {
+  // vector-vector.
+  VmulhVVHelper<int8_t>(this);
+  ResetInstruction();
+  VmulhVVHelper<int16_t>(this);
+  ResetInstruction();
+  VmulhVVHelper<int32_t>(this);
+  ResetInstruction();
+  VmulhVVHelper<int64_t>(this);
+  ResetInstruction();
+  // vector-scalar.
+  VmulhVXHelper<int8_t>(this);
+  ResetInstruction();
+  VmulhVXHelper<int16_t>(this);
+  ResetInstruction();
+  VmulhVXHelper<int32_t>(this);
+  ResetInstruction();
+  VmulhVXHelper<int64_t>(this);
+}
+
+// Vmul vector-vector test helper function.
+template <typename T>
+inline void VmulVVHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  using WT = typename WideType<T>::type;
+  tester->SetSemanticFunction(&Vmul);
+  tester->BinaryOpTestHelperVV<T, T, T>(
+      absl::StrCat("Vmul", sizeof(T) * 8, "vv"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [](T vs2, T vs1) -> T {
+        return static_cast<T>(static_cast<WT>(vs2) * static_cast<WT>(vs1));
+      });
+}
+
+// Vmul vector-scalar test helper function.
+template <typename T>
+inline void VmulVXHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  using WT = typename WideType<T>::type;
+  tester->SetSemanticFunction(&Vmul);
+  tester->BinaryOpTestHelperVX<T, T, T>(
+      absl::StrCat("Vmul", sizeof(T) * 8, "vx"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [](T vs2, T vs1) -> T {
+        return static_cast<T>(static_cast<WT>(vs2) * static_cast<WT>(vs1));
+      });
+}
+
+// Test vmul instructions.
+TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vmul) {
+  // vector-vector.
+  VmulVVHelper<int8_t>(this);
+  ResetInstruction();
+  VmulVVHelper<int16_t>(this);
+  ResetInstruction();
+  VmulVVHelper<int32_t>(this);
+  ResetInstruction();
+  VmulVVHelper<int64_t>(this);
+  ResetInstruction();
+  // vector-scalar.
+  VmulVXHelper<int8_t>(this);
+  ResetInstruction();
+  VmulVXHelper<int16_t>(this);
+  ResetInstruction();
+  VmulVXHelper<int32_t>(this);
+  ResetInstruction();
+  VmulVXHelper<int64_t>(this);
+}
+
+// Vmulhsu vector-vector test helper function.
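+// Note: vmulhsu multiplies a signed vs2 by an unsigned vs1 and keeps the high
+// half. E.g., int8_t -1 times uint8_t 255 is -255; the arithmetic shift right
+// by 8 then yields -1 (0xff).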
+template <typename T> +inline void VmulhsuVVHelper(RiscVCheriotVectorOpmInstructionsTest *tester) { + using ST = typename std::make_signed<T>::type; + tester->SetSemanticFunction(&Vmulhsu); + tester->BinaryOpTestHelperVV<T, ST, T>( + absl::StrCat("Vmulhsu", sizeof(T) * 8, "vv"), /*sew*/ sizeof(T) * 8, + tester->instruction(), [](ST vs2, T vs1) -> T { + absl::int128 vs2_w = static_cast<absl::int128>(vs2); + absl::int128 vs1_w = static_cast<absl::int128>(vs1); + absl::int128 res = (vs2_w * vs1_w) >> (sizeof(T) * 8); + return static_cast<ST>(res); + }); +} + +// Vmulhsu vector-scalar test helper function. +template <typename T> +inline void VmulhsuVXHelper(RiscVCheriotVectorOpmInstructionsTest *tester) { + using ST = typename std::make_signed<T>::type; + tester->SetSemanticFunction(&Vmulhsu); + tester->BinaryOpTestHelperVX<T, ST, T>( + absl::StrCat("Vmulhsu", sizeof(T) * 8, "vx"), /*sew*/ sizeof(T) * 8, + tester->instruction(), [](ST vs2, T vs1) -> T { + absl::int128 vs2_w = static_cast<absl::int128>(vs2); + absl::int128 vs1_w = static_cast<absl::int128>(vs1); + absl::int128 res = (vs2_w * vs1_w) >> (sizeof(T) * 8); + return static_cast<ST>(res); + }); +} + +// Test vmulhsu instructions. +TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vmulhsu) { + // vector-vector + VmulhsuVVHelper<uint8_t>(this); + ResetInstruction(); + VmulhsuVVHelper<uint16_t>(this); + ResetInstruction(); + VmulhsuVVHelper<uint32_t>(this); + ResetInstruction(); + VmulhsuVVHelper<uint64_t>(this); + ResetInstruction(); + // vector-scalar + VmulhsuVXHelper<uint8_t>(this); + ResetInstruction(); + VmulhsuVXHelper<uint16_t>(this); + ResetInstruction(); + VmulhsuVXHelper<uint32_t>(this); + ResetInstruction(); + VmulhsuVXHelper<uint64_t>(this); +} + +// Vmadd vector-vector test helper function. +template <typename T> +inline void VmaddVVHelper(RiscVCheriotVectorOpmInstructionsTest *tester) { + tester->SetSemanticFunction(&Vmadd); + tester->TernaryOpTestHelperVV<T, T, T>( + absl::StrCat("Vmadd", sizeof(T) * 8, "vv"), /*sew*/ sizeof(T) * 8, + tester->instruction(), [](T vs2, T vs1, T vd) { + if (sizeof(T) < 4) { + uint32_t vs1_32 = vs1; + uint32_t vs2_32 = vs2; + uint32_t vd_32 = vd; + return static_cast<T>((vs1_32 * vd_32) + vs2_32); + } + T res = vs1 * vd + vs2; + return res; + }); +} + +// Vmadd vector-scalar test helper function. +template <typename T> +inline void VmaddVXHelper(RiscVCheriotVectorOpmInstructionsTest *tester) { + tester->SetSemanticFunction(&Vmadd); + tester->TernaryOpTestHelperVX<T, T, T>( + absl::StrCat("Vmadd", sizeof(T) * 8, "vx"), /*sew*/ sizeof(T) * 8, + tester->instruction(), [](T vs2, T vs1, T vd) { + if (sizeof(T) < 4) { + uint32_t vs1_32 = vs1; + uint32_t vs2_32 = vs2; + uint32_t vd_32 = vd; + return static_cast<T>((vs1_32 * vd_32) + vs2_32); + } + T res = vs1 * vd + vs2; + return res; + }); +} + +// Test vmadd instructions. +TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vmadd) { + // vector-vector + VmaddVVHelper<uint8_t>(this); + ResetInstruction(); + VmaddVVHelper<uint16_t>(this); + ResetInstruction(); + VmaddVVHelper<uint32_t>(this); + ResetInstruction(); + VmaddVVHelper<uint64_t>(this); + ResetInstruction(); + // vector-scalar + VmaddVXHelper<uint8_t>(this); + ResetInstruction(); + VmaddVXHelper<uint16_t>(this); + ResetInstruction(); + VmaddVXHelper<uint32_t>(this); + ResetInstruction(); + VmaddVXHelper<uint64_t>(this); +} + +// Vnmsub vector-vector test helper function. 
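+// Note: vnmsub computes vd[i] = -(vs1[i] * vd[i]) + vs2[i]. As in the Vmadd
+// helpers above, the 8- and 16-bit cases are evaluated in uint32_t: the
+// narrow operands would otherwise promote to signed int, where e.g.
+// 0xffff * 0xffff overflows int (undefined behavior), while uint32_t
+// arithmetic wraps modulo 2^32 before truncating back to T.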
+template <typename T> +inline void VnmsubVVHelper(RiscVCheriotVectorOpmInstructionsTest *tester) { + tester->SetSemanticFunction(&Vnmsub); + tester->TernaryOpTestHelperVV<T, T, T>( + absl::StrCat("Vnmsub", sizeof(T) * 8, "vv"), /*sew*/ sizeof(T) * 8, + tester->instruction(), [](T vs2, T vs1, T vd) { + if (sizeof(T) < 4) { + uint32_t vs1_32 = vs1; + uint32_t vs2_32 = vs2; + uint32_t vd_32 = vd; + return static_cast<T>(-(vs1_32 * vd_32) + vs2_32); + } + T res = -(vs1 * vd) + vs2; + return res; + }); +} + +// Vnmsub vector-scalar test helper function. +template <typename T> +inline void VnmsubVXHelper(RiscVCheriotVectorOpmInstructionsTest *tester) { + tester->SetSemanticFunction(&Vnmsub); + tester->TernaryOpTestHelperVX<T, T, T>( + absl::StrCat("Vnmsub", sizeof(T) * 8, "vx"), /*sew*/ sizeof(T) * 8, + tester->instruction(), [](T vs2, T vs1, T vd) { + if (sizeof(T) < 4) { + uint32_t vs1_32 = vs1; + uint32_t vs2_32 = vs2; + uint32_t vd_32 = vd; + return static_cast<T>(-(vs1_32 * vd_32) + vs2_32); + } + T res = -(vs1 * vd) + vs2; + return res; + }); +} + +// Test vnmsub instructions. +TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vnmsub) { + // vector-vector + VnmsubVVHelper<uint8_t>(this); + ResetInstruction(); + VnmsubVVHelper<uint16_t>(this); + ResetInstruction(); + VnmsubVVHelper<uint32_t>(this); + ResetInstruction(); + VnmsubVVHelper<uint64_t>(this); + ResetInstruction(); + // vector-scalar + VnmsubVXHelper<uint8_t>(this); + ResetInstruction(); + VnmsubVXHelper<uint16_t>(this); + ResetInstruction(); + VnmsubVXHelper<uint32_t>(this); + ResetInstruction(); + VnmsubVXHelper<uint64_t>(this); +} + +// Vmacc vector-vector test helper function. +template <typename T> +inline void VmaccVVHelper(RiscVCheriotVectorOpmInstructionsTest *tester) { + tester->SetSemanticFunction(&Vmacc); + tester->TernaryOpTestHelperVV<T, T, T>( + absl::StrCat("Vmacc", sizeof(T) * 8, "vv"), /*sew*/ sizeof(T) * 8, + tester->instruction(), [](T vs2, T vs1, T vd) { + if (sizeof(T) < 4) { + uint32_t vs1_32 = vs1; + uint32_t vs2_32 = vs2; + uint32_t vd_32 = vd; + return static_cast<T>((vs1_32 * vs2_32) + vd_32); + } + T res = (vs1 * vs2) + vd; + return res; + }); +} + +// Vmacc vector-scalar test helper function. +template <typename T> +inline void VmaccVXHelper(RiscVCheriotVectorOpmInstructionsTest *tester) { + tester->SetSemanticFunction(&Vmacc); + tester->TernaryOpTestHelperVX<T, T, T>( + absl::StrCat("Vmacc", sizeof(T) * 8, "vx"), /*sew*/ sizeof(T) * 8, + tester->instruction(), [](T vs2, T vs1, T vd) { + if (sizeof(T) < 4) { + uint32_t vs1_32 = vs1; + uint32_t vs2_32 = vs2; + uint32_t vd_32 = vd; + return static_cast<T>((vs1_32 * vs2_32) + vd_32); + } + T res = (vs1 * vs2) + vd; + return res; + }); +} + +// Test vmacc instructions. +TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vmacc) { + // vector-vector + VmaccVVHelper<uint8_t>(this); + ResetInstruction(); + VmaccVVHelper<uint16_t>(this); + ResetInstruction(); + VmaccVVHelper<uint32_t>(this); + ResetInstruction(); + VmaccVVHelper<uint64_t>(this); + ResetInstruction(); + // vector-scalar + VmaccVXHelper<uint8_t>(this); + ResetInstruction(); + VmaccVXHelper<uint16_t>(this); + ResetInstruction(); + VmaccVXHelper<uint32_t>(this); + ResetInstruction(); + VmaccVXHelper<uint64_t>(this); +} + +// Vnmsac vector-vector test helper function. 
+template <typename T> +inline void VnmsacVVHelper(RiscVCheriotVectorOpmInstructionsTest *tester) { + tester->SetSemanticFunction(&Vnmsac); + tester->TernaryOpTestHelperVV<T, T, T>( + absl::StrCat("Vnmsac", sizeof(T) * 8, "vv"), /*sew*/ sizeof(T) * 8, + tester->instruction(), [](T vs2, T vs1, T vd) { + if (sizeof(T) < 4) { + uint32_t vs1_32 = vs1; + uint32_t vs2_32 = vs2; + uint32_t vd_32 = vd; + return static_cast<T>(-(vs1_32 * vs2_32) + vd_32); + } + T res = -(vs1 * vs2) + vd; + return res; + }); +} + +// Vnmsac vector-scalar test helper function. +template <typename T> +inline void VnmsacVXHelper(RiscVCheriotVectorOpmInstructionsTest *tester) { + tester->SetSemanticFunction(&Vnmsac); + tester->TernaryOpTestHelperVX<T, T, T>( + absl::StrCat("Vnmsac", sizeof(T) * 8, "vx"), /*sew*/ sizeof(T) * 8, + tester->instruction(), [](T vs2, T vs1, T vd) { + if (sizeof(T) < 4) { + uint32_t vs1_32 = vs1; + uint32_t vs2_32 = vs2; + uint32_t vd_32 = vd; + return static_cast<T>(-(vs1_32 * vs2_32) + vd_32); + } + T res = -(vs1 * vs2) + vd; + return res; + }); +} + +// Test vnmsac instructions. +TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vnmsac) { + // vector-vector + VnmsacVVHelper<uint8_t>(this); + ResetInstruction(); + VnmsacVVHelper<uint16_t>(this); + ResetInstruction(); + VnmsacVVHelper<uint32_t>(this); + ResetInstruction(); + VnmsacVVHelper<uint64_t>(this); + ResetInstruction(); + // vector-scalar + VnmsacVXHelper<uint8_t>(this); + ResetInstruction(); + VnmsacVXHelper<uint16_t>(this); + ResetInstruction(); + VnmsacVXHelper<uint32_t>(this); + ResetInstruction(); + VnmsacVXHelper<uint64_t>(this); +} + +// Vwaddu vector-vector test helper function. +template <typename T> +inline void VwadduVVHelper(RiscVCheriotVectorOpmInstructionsTest *tester) { + using WT = typename WideType<T>::type; + tester->SetSemanticFunction(&Vwaddu); + tester->BinaryOpTestHelperVV<WT, T, T>( + absl::StrCat("Vwaddu", sizeof(T) * 8, "vv"), /*sew*/ sizeof(T) * 8, + tester->instruction(), [](T vs2, T vs1) -> WT { + return static_cast<WT>(vs2) + static_cast<WT>(vs1); + }); +} + +// Vwaddu vector-scalar test helper function. +template <typename T> +inline void VwadduVXHelper(RiscVCheriotVectorOpmInstructionsTest *tester) { + using WT = typename WideType<T>::type; + tester->SetSemanticFunction(&Vwaddu); + tester->BinaryOpTestHelperVX<WT, T, T>( + absl::StrCat("Vwaddu", sizeof(T) * 8, "vx"), /*sew*/ sizeof(T) * 8, + tester->instruction(), [](T vs2, T vs1) -> WT { + return static_cast<WT>(vs2) + static_cast<WT>(vs1); + }); +} + +// Vector widening unsigned add. (sew * 2) = sew + sew +// There is no test for sew == 64 bits, as this is a widening operation, +// and 64 bit values are the max sized vector elements. +TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vwaddu) { + // vector-vector. + VwadduVVHelper<uint8_t>(this); + ResetInstruction(); + VwadduVVHelper<uint16_t>(this); + ResetInstruction(); + VwadduVVHelper<uint32_t>(this); + ResetInstruction(); + // vector-scalar. + VwadduVXHelper<uint8_t>(this); + ResetInstruction(); + VwadduVXHelper<uint16_t>(this); + ResetInstruction(); + VwadduVXHelper<uint32_t>(this); +} + +// Vwsubu vector-vector test helper function. 
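+// Note: the widening helpers convert each narrow operand to the 2*SEW type
+// before operating, so unsigned subtraction wraps in the wide type: e.g.,
+// vwsubu on uint8_t operands 5 - 10 yields uint16_t 0xfffb (65531) rather
+// than a truncated 8-bit result.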
+template <typename T>
+inline void VwsubuVVHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  using WT = typename WideType<T>::type;
+  tester->SetSemanticFunction(&Vwsubu);
+  tester->BinaryOpTestHelperVV<WT, T, T>(
+      absl::StrCat("Vwsubu", sizeof(T) * 8, "vv"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [](T vs2, T vs1) -> WT {
+        return static_cast<WT>(vs2) - static_cast<WT>(vs1);
+      });
+}
+
+// Vwsubu vector-scalar test helper function.
+template <typename T>
+inline void VwsubuVXHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  using WT = typename WideType<T>::type;
+  tester->SetSemanticFunction(&Vwsubu);
+  tester->BinaryOpTestHelperVX<WT, T, T>(
+      absl::StrCat("Vwsubu", sizeof(T) * 8, "vx"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [](T vs2, T vs1) -> WT {
+        return static_cast<WT>(vs2) - static_cast<WT>(vs1);
+      });
+}
+
+// Vector widening unsigned subtract. (sew * 2) = sew - sew
+// There is no test for sew == 64 bits, as this is a widening operation,
+// and 64 bit values are the max sized vector elements.
+TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vwsubu) {
+  // vector-vector.
+  VwsubuVVHelper<uint8_t>(this);
+  ResetInstruction();
+  VwsubuVVHelper<uint16_t>(this);
+  ResetInstruction();
+  VwsubuVVHelper<uint32_t>(this);
+  ResetInstruction();
+  // vector-scalar.
+  VwsubuVXHelper<uint8_t>(this);
+  ResetInstruction();
+  VwsubuVXHelper<uint16_t>(this);
+  ResetInstruction();
+  VwsubuVXHelper<uint32_t>(this);
+}
+
+// Vwadd vector-vector test helper function.
+template <typename T>
+inline void VwaddVVHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  using WT = typename WideType<T>::type;
+  tester->SetSemanticFunction(&Vwadd);
+  tester->BinaryOpTestHelperVV<WT, T, T>(
+      absl::StrCat("Vwadd", sizeof(T) * 8, "vv"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [](T vs2, T vs1) -> WT {
+        return static_cast<WT>(vs2) + static_cast<WT>(vs1);
+      });
+}
+
+// Vwadd vector-scalar test helper function.
+template <typename T>
+inline void VwaddVXHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  using WT = typename WideType<T>::type;
+  tester->SetSemanticFunction(&Vwadd);
+  tester->BinaryOpTestHelperVX<WT, T, T>(
+      absl::StrCat("Vwadd", sizeof(T) * 8, "vx"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [](T vs2, T vs1) -> WT {
+        return static_cast<WT>(vs2) + static_cast<WT>(vs1);
+      });
+}
+
+// Vector widening signed addition. (sew * 2) = sew + sew.
+// There is no test for sew == 64 bits, as this is a widening operation,
+// and 64 bit values are the max sized vector elements.
+TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vwadd) {
+  // vector-vector.
+  VwaddVVHelper<int8_t>(this);
+  ResetInstruction();
+  VwaddVVHelper<int16_t>(this);
+  ResetInstruction();
+  VwaddVVHelper<int32_t>(this);
+  ResetInstruction();
+  // vector-scalar.
+  VwaddVXHelper<int8_t>(this);
+  ResetInstruction();
+  VwaddVXHelper<int16_t>(this);
+  ResetInstruction();
+  VwaddVXHelper<int32_t>(this);
+}
+
+// Vwsub vector-vector test helper function.
+template <typename T>
+inline void VwsubVVHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  using WT = typename WideType<T>::type;
+  tester->SetSemanticFunction(&Vwsub);
+  tester->BinaryOpTestHelperVV<WT, T, T>(
+      absl::StrCat("Vwsub", sizeof(T) * 8, "vv"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [](T vs2, T vs1) -> WT {
+        WT vs2_w = vs2;
+        WT vs1_w = vs1;
+        WT res = vs2_w - vs1_w;
+        return res;
+      });
+}
+
+// Vwsub vector-scalar test helper function.
+template <typename T>
+inline void VwsubVXHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  using WT = typename WideType<T>::type;
+  tester->SetSemanticFunction(&Vwsub);
+  tester->BinaryOpTestHelperVX<WT, T, T>(
+      absl::StrCat("Vwsub", sizeof(T) * 8, "vx"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [](T vs2, T vs1) -> WT {
+        WT vs2_w = vs2;
+        WT vs1_w = vs1;
+        WT res = vs2_w - vs1_w;
+        return res;
+      });
+}
+
+// Vector widening signed subtract. (sew * 2) = sew - sew
+// There is no test for sew == 64 bits, as this is a widening operation,
+// and 64 bit values are the max sized vector elements.
+TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vwsub) {
+  // vector-vector.
+  VwsubVVHelper<int8_t>(this);
+  ResetInstruction();
+  VwsubVVHelper<int16_t>(this);
+  ResetInstruction();
+  VwsubVVHelper<int32_t>(this);
+  ResetInstruction();
+  // vector-scalar.
+  VwsubVXHelper<int8_t>(this);
+  ResetInstruction();
+  VwsubVXHelper<int16_t>(this);
+  ResetInstruction();
+  VwsubVXHelper<int32_t>(this);
+}
+
+// Vwadduw vector-vector test helper function.
+template <typename T>
+inline void VwadduwVVHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  using WT = typename WideType<T>::type;
+  tester->SetSemanticFunction(&Vwadduw);
+  tester->BinaryOpTestHelperVV<WT, WT, T>(
+      absl::StrCat("Vwadduw", sizeof(T) * 8, "vv"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(),
+      [](WT vs2, T vs1) -> WT { return vs2 + static_cast<WT>(vs1); });
+}
+
+// Vwadduw vector-scalar test helper function.
+template <typename T>
+inline void VwadduwVXHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  using WT = typename WideType<T>::type;
+  tester->SetSemanticFunction(&Vwadduw);
+  tester->BinaryOpTestHelperVX<WT, WT, T>(
+      absl::StrCat("Vwadduw", sizeof(T) * 8, "vx"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(),
+      [](WT vs2, T vs1) -> WT { return vs2 + static_cast<WT>(vs1); });
+}
+
+// Vector widening unsigned add. (sew * 2) = (sew * 2) + sew
+// There is no test for sew == 64 bits, as this is a widening operation,
+// and 64 bit values are the max sized vector elements.
+TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vwadduw) {
+  // vector-vector.
+  VwadduwVVHelper<uint8_t>(this);
+  ResetInstruction();
+  VwadduwVVHelper<uint16_t>(this);
+  ResetInstruction();
+  VwadduwVVHelper<uint32_t>(this);
+  ResetInstruction();
+  // vector-scalar.
+  VwadduwVXHelper<uint8_t>(this);
+  ResetInstruction();
+  VwadduwVXHelper<uint16_t>(this);
+  ResetInstruction();
+  VwadduwVXHelper<uint32_t>(this);
+}
+
+// Vwsubuw vector-vector test helper function.
+template <typename T>
+inline void VwsubuwVVHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  using WT = typename WideType<T>::type;
+  tester->SetSemanticFunction(&Vwsubuw);
+  tester->BinaryOpTestHelperVV<WT, WT, T>(
+      absl::StrCat("Vwsubuw", sizeof(T) * 8, "vv"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(),
+      [](WT vs2, T vs1) -> WT { return vs2 - static_cast<WT>(vs1); });
+}
+
+// Vwsubuw vector-scalar test helper function.
+template <typename T>
+inline void VwsubuwVXHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  using WT = typename WideType<T>::type;
+  tester->SetSemanticFunction(&Vwsubuw);
+  tester->BinaryOpTestHelperVX<WT, WT, T>(
+      absl::StrCat("Vwsubuw", sizeof(T) * 8, "vx"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(),
+      [](WT vs2, T vs1) -> WT { return vs2 - static_cast<WT>(vs1); });
+}
+
+// Vector widening unsigned subtract. (sew * 2) = (sew * 2) - sew
+// There is no test for sew == 64 bits, as this is a widening operation,
+// and 64 bit values are the max sized vector elements.
+TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vwsubuw) {
+  // vector-vector.
+  VwsubuwVVHelper<uint8_t>(this);
+  ResetInstruction();
+  VwsubuwVVHelper<uint16_t>(this);
+  ResetInstruction();
+  VwsubuwVVHelper<uint32_t>(this);
+  ResetInstruction();
+  // vector-scalar.
+  VwsubuwVXHelper<uint8_t>(this);
+  ResetInstruction();
+  VwsubuwVXHelper<uint16_t>(this);
+  ResetInstruction();
+  VwsubuwVXHelper<uint32_t>(this);
+}
+
+// Vwaddw vector-vector test helper function.
+template <typename T>
+inline void VwaddwVVHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  using WT = typename WideType<T>::type;
+  tester->SetSemanticFunction(&Vwaddw);
+  tester->BinaryOpTestHelperVV<WT, WT, T>(
+      absl::StrCat("Vwaddw", sizeof(T) * 8, "vv"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(),
+      [](WT vs2, T vs1) -> WT { return vs2 + static_cast<WT>(vs1); });
+}
+
+// Vwaddw vector-scalar test helper function.
+template <typename T>
+inline void VwaddwVXHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  using WT = typename WideType<T>::type;
+  tester->SetSemanticFunction(&Vwaddw);
+  tester->BinaryOpTestHelperVX<WT, WT, T>(
+      absl::StrCat("Vwaddw", sizeof(T) * 8, "vx"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(),
+      [](WT vs2, T vs1) -> WT { return vs2 + static_cast<WT>(vs1); });
+}
+
+// Vector widening signed add. (sew * 2) = (sew * 2) + sew
+// There is no test for sew == 64 bits, as this is a widening operation,
+// and 64 bit values are the max sized vector elements.
+TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vwaddw) {
+  // vector-vector.
+  VwaddwVVHelper<int8_t>(this);
+  ResetInstruction();
+  VwaddwVVHelper<int16_t>(this);
+  ResetInstruction();
+  VwaddwVVHelper<int32_t>(this);
+  ResetInstruction();
+  // vector-scalar.
+  VwaddwVXHelper<int8_t>(this);
+  ResetInstruction();
+  VwaddwVXHelper<int16_t>(this);
+  ResetInstruction();
+  VwaddwVXHelper<int32_t>(this);
+}
+
+// Vwsubw vector-vector test helper function.
+template <typename T>
+inline void VwsubwVVHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  using WT = typename WideType<T>::type;
+  tester->SetSemanticFunction(&Vwsubw);
+  tester->BinaryOpTestHelperVV<WT, WT, T>(
+      absl::StrCat("Vwsubw", sizeof(T) * 8, "vv"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(),
+      [](WT vs2, T vs1) -> WT { return vs2 - static_cast<WT>(vs1); });
+}
+
+// Vwsubw vector-scalar test helper function.
+template <typename T>
+inline void VwsubwVXHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  using WT = typename WideType<T>::type;
+  tester->SetSemanticFunction(&Vwsubw);
+  tester->BinaryOpTestHelperVX<WT, WT, T>(
+      absl::StrCat("Vwsubw", sizeof(T) * 8, "vx"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(),
+      [](WT vs2, T vs1) -> WT { return vs2 - static_cast<WT>(vs1); });
+}
+
+// Vector widening signed subtract. (sew * 2) = (sew * 2) - sew
+// There is no test for sew == 64 bits, as this is a widening operation,
+// and 64 bit values are the max sized vector elements.
+TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vwsubw) {
+  // vector-vector.
+  VwsubwVVHelper<int8_t>(this);
+  ResetInstruction();
+  VwsubwVVHelper<int16_t>(this);
+  ResetInstruction();
+  VwsubwVVHelper<int32_t>(this);
+  ResetInstruction();
+  // vector-scalar.
+  VwsubwVXHelper<int8_t>(this);
+  ResetInstruction();
+  VwsubwVXHelper<int16_t>(this);
+  ResetInstruction();
+  VwsubwVXHelper<int32_t>(this);
+}
+
+// Vwmulu vector-vector test helper function.
+template <typename T>
+inline void VwmuluVVHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  using WT = typename WideType<T>::type;
+  tester->SetSemanticFunction(&Vwmulu);
+  tester->BinaryOpTestHelperVV<WT, T, T>(
+      absl::StrCat("Vwmulu", sizeof(T) * 8, "vv"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [](T vs2, T vs1) -> WT {
+        return static_cast<WT>(vs2) * static_cast<WT>(vs1);
+      });
+}
+
+// Vwmulu vector-scalar test helper function.
+template <typename T>
+inline void VwmuluVXHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  using WT = typename WideType<T>::type;
+  tester->SetSemanticFunction(&Vwmulu);
+  tester->BinaryOpTestHelperVX<WT, T, T>(
+      absl::StrCat("Vwmulu", sizeof(T) * 8, "vx"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [](T vs2, T vs1) -> WT {
+        return static_cast<WT>(vs2) * static_cast<WT>(vs1);
+      });
+}
+
+// Vector widening unsigned multiply. (sew * 2) = sew * sew
+// There is no test for sew == 64 bits, as this is a widening operation,
+// and 64 bit values are the max sized vector elements.
+TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vwmulu) {
+  // vector-vector.
+  VwmuluVVHelper<uint8_t>(this);
+  ResetInstruction();
+  VwmuluVVHelper<uint16_t>(this);
+  ResetInstruction();
+  VwmuluVVHelper<uint32_t>(this);
+  ResetInstruction();
+  // vector-scalar.
+  VwmuluVXHelper<uint8_t>(this);
+  ResetInstruction();
+  VwmuluVXHelper<uint16_t>(this);
+  ResetInstruction();
+  VwmuluVXHelper<uint32_t>(this);
+}
+
+// Vwmul vector-vector test helper function.
+template <typename T>
+inline void VwmulVVHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  using WT = typename WideType<T>::type;
+  tester->SetSemanticFunction(&Vwmul);
+  tester->BinaryOpTestHelperVV<WT, T, T>(
+      absl::StrCat("Vwmul", sizeof(T) * 8, "vv"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [](T vs2, T vs1) -> WT {
+        return static_cast<WT>(vs2) * static_cast<WT>(vs1);
+      });
+}
+
+// Vwmul vector-scalar test helper function.
+template <typename T>
+inline void VwmulVXHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  using WT = typename WideType<T>::type;
+  tester->SetSemanticFunction(&Vwmul);
+  tester->BinaryOpTestHelperVX<WT, T, T>(
+      absl::StrCat("Vwmul", sizeof(T) * 8, "vx"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [](T vs2, T vs1) -> WT {
+        return static_cast<WT>(vs2) * static_cast<WT>(vs1);
+      });
+}
+
+// Vector widening signed multiply. (sew * 2) = sew * sew
+// There is no test for sew == 64 bits, as this is a widening operation,
+// and 64 bit values are the max sized vector elements.
+TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vwmul) {
+  // vector-vector.
+  VwmulVVHelper<int8_t>(this);
+  ResetInstruction();
+  VwmulVVHelper<int16_t>(this);
+  ResetInstruction();
+  VwmulVVHelper<int32_t>(this);
+  ResetInstruction();
+  // vector-scalar.
+  VwmulVXHelper<int8_t>(this);
+  ResetInstruction();
+  VwmulVXHelper<int16_t>(this);
+  ResetInstruction();
+  VwmulVXHelper<int32_t>(this);
+}
+
+// Vwmulsu vector-vector test helper function.
+template <typename T>
+inline void VwmulsuVVHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  using WT = typename WideType<T>::type;
+  using UT = typename std::make_unsigned<T>::type;
+  tester->SetSemanticFunction(&Vwmulsu);
+  tester->BinaryOpTestHelperVV<WT, T, T>(
+      absl::StrCat("Vwmulsu", sizeof(T) * 8, "vv"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [](T vs2, UT vs1) -> WT {
+        return static_cast<WT>(vs2) * static_cast<WT>(vs1);
+      });
+}
+
+// Vwmulsu vector-scalar test helper function.
+template <typename T>
+inline void VwmulsuVXHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  using WT = typename WideType<T>::type;
+  using UT = typename std::make_unsigned<T>::type;
+  tester->SetSemanticFunction(&Vwmulsu);
+  tester->BinaryOpTestHelperVX<WT, T, T>(
+      absl::StrCat("Vwmulsu", sizeof(T) * 8, "vx"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [](T vs2, UT vs1) -> WT {
+        return static_cast<WT>(vs2) * static_cast<WT>(vs1);
+      });
+}
+
+// Vector widening signed-unsigned multiply. (sew * 2) = sew * sew
+// There is no test for sew == 64 bits, as this is a widening operation,
+// and 64 bit values are the max sized vector elements.
+TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vwmulsu) {
+  // vector-vector.
+  VwmulsuVVHelper<int8_t>(this);
+  ResetInstruction();
+  VwmulsuVVHelper<int16_t>(this);
+  ResetInstruction();
+  VwmulsuVVHelper<int32_t>(this);
+  ResetInstruction();
+  // vector-scalar.
+  VwmulsuVXHelper<int8_t>(this);
+  ResetInstruction();
+  VwmulsuVXHelper<int16_t>(this);
+  ResetInstruction();
+  VwmulsuVXHelper<int32_t>(this);
+}
+
+// Vwmaccu vector-vector test helper function.
+template <typename T>
+inline void VwmaccuVVHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  using WT = typename WideType<T>::type;
+  tester->SetSemanticFunction(&Vwmaccu);
+  tester->TernaryOpTestHelperVV<WT, T, T>(
+      absl::StrCat("Vwmaccu", sizeof(T) * 8, "vv"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [](T vs2, T vs1, WT vd) {
+        return static_cast<WT>(vs2) * static_cast<WT>(vs1) + vd;
+      });
+}
+
+// Vwmaccu vector-scalar test helper function.
+template <typename T>
+inline void VwmaccuVXHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  using WT = typename WideType<T>::type;
+  tester->SetSemanticFunction(&Vwmaccu);
+  tester->TernaryOpTestHelperVX<WT, T, T>(
+      absl::StrCat("Vwmaccu", sizeof(T) * 8, "vx"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [](T vs2, T vs1, WT vd) {
+        return static_cast<WT>(vs2) * static_cast<WT>(vs1) + vd;
+      });
+}
+
+// Test vwmaccu instructions.
+TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vwmaccu) {
+  // vector-vector
+  VwmaccuVVHelper<uint8_t>(this);
+  ResetInstruction();
+  VwmaccuVVHelper<uint16_t>(this);
+  ResetInstruction();
+  VwmaccuVVHelper<uint32_t>(this);
+  ResetInstruction();
+  // vector-scalar
+  VwmaccuVXHelper<uint8_t>(this);
+  ResetInstruction();
+  VwmaccuVXHelper<uint16_t>(this);
+  ResetInstruction();
+  VwmaccuVXHelper<uint32_t>(this);
+}
+
+// Vwmacc vector-vector test helper function.
+template <typename T>
+inline void VwmaccVVHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  using WT = typename WideType<T>::type;
+  tester->SetSemanticFunction(&Vwmacc);
+  tester->TernaryOpTestHelperVV<WT, T, T>(
+      absl::StrCat("Vwmacc", sizeof(T) * 8, "vv"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [](T vs2, T vs1, WT vd) -> WT {
+        WT vs1_w = vs1;
+        WT vs2_w = vs2;
+        WT prod = vs1_w * vs2_w;
+        using UWT = typename std::make_unsigned<WT>::type;
+        WT res = absl::bit_cast<UWT>(prod) + absl::bit_cast<UWT>(vd);
+        return res;
+      });
+}
+
+// Vwmacc vector-scalar test helper function.
+template <typename T>
+inline void VwmaccVXHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  using WT = typename WideType<T>::type;
+  tester->SetSemanticFunction(&Vwmacc);
+  tester->TernaryOpTestHelperVX<WT, T, T>(
+      absl::StrCat("Vwmacc", sizeof(T) * 8, "vx"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [](T vs2, T vs1, WT vd) -> WT {
+        WT vs1_w = vs1;
+        WT vs2_w = vs2;
+        WT prod = vs1_w * vs2_w;
+        using UWT = typename std::make_unsigned<WT>::type;
+        WT res = absl::bit_cast<UWT>(prod) + absl::bit_cast<UWT>(vd);
+        return res;
+      });
+}
+
+// Test vwmacc instructions.
+TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vwmacc) {
+  // vector-vector
+  VwmaccVVHelper<int8_t>(this);
+  ResetInstruction();
+  VwmaccVVHelper<int16_t>(this);
+  ResetInstruction();
+  VwmaccVVHelper<int32_t>(this);
+  ResetInstruction();
+  // vector-scalar
+  VwmaccVXHelper<int8_t>(this);
+  ResetInstruction();
+  VwmaccVXHelper<int16_t>(this);
+  ResetInstruction();
+  VwmaccVXHelper<int32_t>(this);
+}
+
+// Vwmaccus vector-vector test helper function.
+template <typename T>
+inline void VwmaccusVVHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  using WT = typename WideType<T>::type;
+  using UT = typename std::make_unsigned<T>::type;
+  tester->SetSemanticFunction(&Vwmaccus);
+  tester->TernaryOpTestHelperVV<WT, T, UT>(
+      absl::StrCat("Vwmaccus", sizeof(T) * 8, "vv"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [](T vs2, UT vs1, WT vd) -> WT {
+        using UWT = typename std::make_unsigned<WT>::type;
+        UWT vs1_w = vs1;
+        WT vs2_w = vs2;
+        WT prod = vs1_w * vs2_w;
+        WT res = absl::bit_cast<UWT>(prod) + absl::bit_cast<UWT>(vd);
+        return res;
+      });
+}
+
+// Vwmaccus vector-scalar test helper function.
+template <typename T>
+inline void VwmaccusVXHelper(RiscVCheriotVectorOpmInstructionsTest *tester) {
+  using WT = typename WideType<T>::type;
+  using UT = typename std::make_unsigned<T>::type;
+  tester->SetSemanticFunction(&Vwmaccus);
+  tester->TernaryOpTestHelperVX<WT, T, UT>(
+      absl::StrCat("Vwmaccus", sizeof(T) * 8, "vx"), /*sew*/ sizeof(T) * 8,
+      tester->instruction(), [](T vs2, UT vs1, WT vd) -> WT {
+        using UWT = typename std::make_unsigned<WT>::type;
+        UWT vs1_w = vs1;
+        WT vs2_w = vs2;
+        WT prod = vs1_w * vs2_w;
+        WT res = absl::bit_cast<UWT>(prod) + absl::bit_cast<UWT>(vd);
+        return res;
+      });
+}
+
+// Test vwmaccus instructions.
+TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vwmaccus) {
+  // vector-vector
+  VwmaccusVVHelper<int8_t>(this);
+  ResetInstruction();
+  VwmaccusVVHelper<int16_t>(this);
+  ResetInstruction();
+  VwmaccusVVHelper<int32_t>(this);
+  ResetInstruction();
+  // vector-scalar
+  VwmaccusVXHelper<int8_t>(this);
+  ResetInstruction();
+  VwmaccusVXHelper<int16_t>(this);
+  ResetInstruction();
+  VwmaccusVXHelper<int32_t>(this);
+}
+
+// Vwmaccsu vector-vector test helper function.
+template <typename T> +inline void VwmaccsuVVHelper(RiscVCheriotVectorOpmInstructionsTest *tester) { + using WT = typename WideType<T>::type; + using UT = typename std::make_unsigned<T>::type; + tester->SetSemanticFunction(&Vwmaccsu); + tester->TernaryOpTestHelperVV<WT, UT, T>( + absl::StrCat("Vwmaccsu", sizeof(T) * 8, "vv"), /*sew*/ sizeof(T) * 8, + tester->instruction(), [](UT vs2, T vs1, WT vd) -> WT { + using UWT = typename std::make_unsigned<WT>::type; + WT vs1_w = vs1; + UWT vs2_w = vs2; + WT prod = vs1_w * vs2_w; + WT res = absl::bit_cast<UWT>(prod) + absl::bit_cast<UWT>(vd); + return res; + }); +} + +// Vwmaccsu vector-scalar test helper function. +template <typename T> +inline void VwmaccsuVXHelper(RiscVCheriotVectorOpmInstructionsTest *tester) { + using WT = typename WideType<T>::type; + using UT = typename std::make_unsigned<T>::type; + tester->SetSemanticFunction(&Vwmaccsu); + tester->TernaryOpTestHelperVX<WT, UT, T>( + absl::StrCat("Vwmaccsu", sizeof(T) * 8, "vx"), /*sew*/ sizeof(T) * 8, + tester->instruction(), [](UT vs2, T vs1, WT vd) -> WT { + using UWT = typename std::make_unsigned<WT>::type; + WT vs1_w = vs1; + UWT vs2_w = vs2; + WT prod = vs1_w * vs2_w; + WT res = absl::bit_cast<UWT>(prod) + absl::bit_cast<UWT>(vd); + return res; + }); +} + +// Test vwmaccsu instructions. +TEST_F(RiscVCheriotVectorOpmInstructionsTest, Vwmaccsu) { + // vector-vector + VwmaccsuVVHelper<int8_t>(this); + ResetInstruction(); + VwmaccsuVVHelper<int16_t>(this); + ResetInstruction(); + VwmaccsuVVHelper<int32_t>(this); + ResetInstruction(); + // vector-scalar + VwmaccsuVXHelper<int8_t>(this); + ResetInstruction(); + VwmaccsuVXHelper<int16_t>(this); + ResetInstruction(); + VwmaccsuVXHelper<int32_t>(this); +} + +} // namespace
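For readers skimming the diff, the following standalone sketch (not part of the change; names are hypothetical) shows why the Vwmacc expectation lambdas above route the accumulate through the unsigned image of the wide type: signed overflow is undefined behavior in C++, while unsigned arithmetic wraps, matching the modular semantics the instruction implements.

#include <cstdint>

#include "absl/base/casts.h"

// Hypothetical, minimal recreation of the Vwmacc expectation for sew == 8:
// widen, multiply (the int16_t product cannot overflow), then add in the
// unsigned image of the wide type so the sum can wrap without UB.
int16_t ExpectedVwmacc8(int8_t vs2, int8_t vs1, int16_t vd) {
  int16_t prod = static_cast<int16_t>(vs2) * static_cast<int16_t>(vs1);
  uint16_t sum =
      absl::bit_cast<uint16_t>(prod) + absl::bit_cast<uint16_t>(vd);
  return absl::bit_cast<int16_t>(sum);
}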
diff --git a/cheriot/test/riscv_cheriot_vector_permute_instructions_test.cc b/cheriot/test/riscv_cheriot_vector_permute_instructions_test.cc new file mode 100644 index 0000000..8fd7ad5 --- /dev/null +++ b/cheriot/test/riscv_cheriot_vector_permute_instructions_test.cc
@@ -0,0 +1,616 @@
+// Copyright 2024 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cheriot/riscv_cheriot_vector_permute_instructions.h"
+
+#include <cstdint>
+#include <vector>
+
+#include "absl/random/random.h"
+#include "cheriot/cheriot_register.h"
+#include "cheriot/test/riscv_cheriot_vector_instructions_test_base.h"
+#include "googlemock/include/gmock/gmock.h"
+#include "mpact/sim/generic/instruction.h"
+#include "riscv//riscv_register.h"
+
+// This file contains tests for the RiscV vector permute instructions.
+
+namespace {
+
+using ::mpact::sim::cheriot::CheriotRegister;
+using ::mpact::sim::generic::Instruction;
+using ::mpact::sim::riscv::RVVectorRegister;
+
+using ::mpact::sim::cheriot::Vcompress;
+using ::mpact::sim::cheriot::Vrgather;
+using ::mpact::sim::cheriot::Vrgatherei16;
+using ::mpact::sim::cheriot::Vslide1down;
+using ::mpact::sim::cheriot::Vslide1up;
+using ::mpact::sim::cheriot::Vslidedown;
+using ::mpact::sim::cheriot::Vslideup;
+
+class RiscVCheriotVectorPermuteInstructionsTest
+    : public RiscVCheriotVectorInstructionsTestBase {};
+
+// Helper function for vector-vector vrgather instructions.
+template <typename T, typename I>
+void VrgatherVVHelper(RiscVCheriotVectorPermuteInstructionsTest *tester,
+                      Instruction *inst) {
+  auto *rv_vector = tester->rv_vector();
+  // Configure vector unit for sew and maximum lmul.
+  uint32_t vtype = 0;
+  int max_regs = 8;
+  if (sizeof(I) > sizeof(T)) {
+    // This happens for vrgatherei16 when sew is 8.
+    vtype = (kSewSettingsByByteSize[sizeof(T)] << 3) | kLmulSettingByLogSize[6];
+    max_regs = 4;
+  } else {
+    vtype = (kSewSettingsByByteSize[sizeof(T)] << 3) | kLmulSettingByLogSize[7];
+  }
+  tester->ConfigureVectorUnit(vtype, 2048);
+
+  int vlen = rv_vector->vector_length();
+  int num_values_per_reg = kVectorLengthInBytes / sizeof(T);
+  int max_index = num_values_per_reg * max_regs;
+  int num_indices_per_reg = kVectorLengthInBytes / sizeof(I);
+  // Initialize vs2 and vs1 to random values.
+  for (int reg = kVs2; reg < kVs2 + max_regs; reg++) {
+    auto span = tester->vreg()[reg]->data_buffer()->Get<T>();
+    for (int i = 0; i < num_values_per_reg; i++) {
+      span[i] = tester->RandomValue<T>();
+    }
+  }
+  for (int reg = kVs1; reg < kVs1 + max_regs; reg++) {
+    auto span = tester->vreg()[reg]->data_buffer()->Get<I>();
+    for (int i = 0; i < num_indices_per_reg; i++) {
+      span[i] =
+          absl::Uniform(absl::IntervalClosed, tester->bitgen(), 0, 2 * vlen);
+    }
+  }
+  tester->SetVectorRegisterValues<uint8_t>({{kVmaskName, kA5Mask}});
+
+  inst->Execute();
+  for (int i = 0; i < vlen; i++) {
+    int value_reg_offset = i / num_values_per_reg;
+    int value_elem_index = i % num_values_per_reg;
+    int index_reg_offset = i / num_indices_per_reg;
+    int index_elem_index = i % num_indices_per_reg;
+    int mask_index = i >> 3;
+    int mask_offset = i & 0b111;
+    bool mask_value = (kA5Mask[mask_index] >> mask_offset) & 0b1;
+    // Get the destination value.
+    T dst = tester->vreg()[kVd + value_reg_offset]->data_buffer()->Get<T>(
+        value_elem_index);
+    if (mask_value) {
+      // If it's an active element, get the index value.
+      I indx = tester->vreg()[kVs1 + index_reg_offset]->data_buffer()->Get<I>(
+          index_elem_index);
+      if (indx >= max_index) {  // If the index is out of range, compare to 0.
+        EXPECT_EQ(0, dst);
+      } else {  // Else, get the src value at that index and compare.
+        int reg = kVs2 + indx / num_values_per_reg;
+        int element = indx % num_values_per_reg;
+        T src = tester->vreg()[reg]->data_buffer()->Get<T>(element);
+        EXPECT_EQ(src, dst);
+      }
+    } else {  // Inactive values are unchanged.
+      T src = tester->vreg()[kVd + value_reg_offset]->data_buffer()->Get<T>(
+          value_elem_index);
+      EXPECT_EQ(src, dst) << "index: " << i << " offset: " << value_reg_offset
+                          << " elem: " << value_elem_index;
+    }
+  }
+}
+
+// Helper function for vector-scalar vrgather instructions.
+template <typename T>
+void VrgatherVSHelper(RiscVCheriotVectorPermuteInstructionsTest *tester,
+                      Instruction *inst) {
+  auto *rv_vector = tester->rv_vector();
+  // Configure vector unit.
+  uint32_t vtype =
+      (kSewSettingsByByteSize[sizeof(T)] << 3) | kLmulSettingByLogSize[7];
+  tester->ConfigureVectorUnit(vtype, 2048);
+  int vlen = rv_vector->vector_length();
+  int num_values_per_reg = kVectorLengthInBytes / sizeof(T);
+  int max_index = num_values_per_reg * 8;
+  // Initialize vs2 to random values.
+  for (int reg = kVs2; reg < kVs2 + 8; reg++) {
+    auto span = tester->vreg()[reg]->data_buffer()->Get<T>();
+    for (int i = 0; i < num_values_per_reg; i++) {
+      span[i] = tester->RandomValue<T>();
+    }
+  }
+  tester->SetVectorRegisterValues<uint8_t>({{kVmaskName, kA5Mask}});
+  // Try 20 different index values randomly.
+  for (int num = 0; num < 20; num++) {
+    // Set a random index value.
+    CheriotRegister::ValueType index_value =
+        absl::Uniform(absl::IntervalClosed, tester->bitgen(), 0, 2 * vlen);
+    tester->SetRegisterValues<CheriotRegister::ValueType>(
+        {{kRs1Name, index_value}});
+    // Execute the instruction.
+    inst->Execute();
+    for (int i = 0; i < vlen; i++) {
+      int value_reg_offset = i / num_values_per_reg;
+      int value_elem_index = i % num_values_per_reg;
+      int mask_index = i >> 3;
+      int mask_offset = i & 0b111;
+      bool mask_value = (kA5Mask[mask_index] >> mask_offset) & 0b1;
+      // Get the destination value.
+      T dst = tester->vreg()[kVd + value_reg_offset]->data_buffer()->Get<T>(
+          value_elem_index);
+      if (mask_value) {  // If it's an active value.
+        // If the index is out of range, it's 0.
+        if (index_value >= max_index) {
+          EXPECT_EQ(0, dst) << "max: " << max_index
+                            << " indx: " << index_value;
+        } else {  // Otherwise, get the src value from vs2 and compare.
+          int reg = index_value / num_values_per_reg;
+          int element = index_value % num_values_per_reg;
+          T src = tester->vreg()[kVs2 + reg]->data_buffer()->Get<T>(element);
+          EXPECT_EQ(src, dst);
+        }
+      } else {  // Inactive values are unchanged.
+        T src = tester->vreg()[kVd + value_reg_offset]->data_buffer()->Get<T>(
+            value_elem_index);
+        EXPECT_EQ(src, dst);
+      }
+    }
+  }
+}
+
+// Test vrgather instruction for 1, 2, 4, and 8 byte SEWs - vector index.
+TEST_F(RiscVCheriotVectorPermuteInstructionsTest, VrgatherVV8) { + SetSemanticFunction(&Vrgather); + AppendVectorRegisterOperands({kVs2, kVs1, kVmask}, {kVd}); + VrgatherVVHelper<uint8_t, uint8_t>(this, instruction_); +} + +TEST_F(RiscVCheriotVectorPermuteInstructionsTest, VrgatherVV16) { + SetSemanticFunction(&Vrgather); + AppendVectorRegisterOperands({kVs2, kVs1, kVmask}, {kVd}); + VrgatherVVHelper<uint16_t, uint16_t>(this, instruction_); +} + +TEST_F(RiscVCheriotVectorPermuteInstructionsTest, VrgatherVV32) { + SetSemanticFunction(&Vrgather); + AppendVectorRegisterOperands({kVs2, kVs1, kVmask}, {kVd}); + VrgatherVVHelper<uint32_t, uint32_t>(this, instruction_); +} + +TEST_F(RiscVCheriotVectorPermuteInstructionsTest, VrgatherVV64) { + SetSemanticFunction(&Vrgather); + AppendVectorRegisterOperands({kVs2, kVs1, kVmask}, {kVd}); + VrgatherVVHelper<uint64_t, uint64_t>(this, instruction_); +} + +// Test vrgather instruction for 1, 2, 4, and 8 byte SEWs - scalar index. +TEST_F(RiscVCheriotVectorPermuteInstructionsTest, VrgatherVS8) { + SetSemanticFunction(&Vrgather); + AppendVectorRegisterOperands({kVs2}, {}); + AppendRegisterOperands({kRs1Name}, {}); + AppendVectorRegisterOperands({kVmask}, {kVd}); + VrgatherVSHelper<uint8_t>(this, instruction_); +} + +TEST_F(RiscVCheriotVectorPermuteInstructionsTest, VrgatherVS16) { + SetSemanticFunction(&Vrgather); + AppendVectorRegisterOperands({kVs2}, {}); + AppendRegisterOperands({kRs1Name}, {}); + AppendVectorRegisterOperands({kVmask}, {kVd}); + VrgatherVSHelper<uint16_t>(this, instruction_); +} + +TEST_F(RiscVCheriotVectorPermuteInstructionsTest, VrgatherVS32) { + SetSemanticFunction(&Vrgather); + AppendVectorRegisterOperands({kVs2}, {}); + AppendRegisterOperands({kRs1Name}, {}); + AppendVectorRegisterOperands({kVmask}, {kVd}); + VrgatherVSHelper<uint32_t>(this, instruction_); +} + +TEST_F(RiscVCheriotVectorPermuteInstructionsTest, VrgatherVS64) { + SetSemanticFunction(&Vrgather); + AppendVectorRegisterOperands({kVs2}, {}); + AppendRegisterOperands({kRs1Name}, {}); + AppendVectorRegisterOperands({kVmask}, {kVd}); + VrgatherVSHelper<uint64_t>(this, instruction_); +} + +// Test vrgatherei16 instruction for SEW values of 1, 2, 4, and 8 bytes. +TEST_F(RiscVCheriotVectorPermuteInstructionsTest, Vrgatherei16VV8) { + SetSemanticFunction(&Vrgatherei16); + AppendVectorRegisterOperands({kVs2, kVs1, kVmask}, {kVd}); + VrgatherVVHelper<uint8_t, uint16_t>(this, instruction_); +} + +TEST_F(RiscVCheriotVectorPermuteInstructionsTest, Vrgatherei16VV16) { + SetSemanticFunction(&Vrgatherei16); + AppendVectorRegisterOperands({kVs2, kVs1, kVmask}, {kVd}); + VrgatherVVHelper<uint16_t, uint16_t>(this, instruction_); +} + +TEST_F(RiscVCheriotVectorPermuteInstructionsTest, Vrgatherei16VV32) { + SetSemanticFunction(&Vrgatherei16); + AppendVectorRegisterOperands({kVs2, kVs1, kVmask}, {kVd}); + VrgatherVVHelper<uint32_t, uint16_t>(this, instruction_); +} + +TEST_F(RiscVCheriotVectorPermuteInstructionsTest, Vrgatherei16VV64) { + SetSemanticFunction(&Vrgatherei16); + AppendVectorRegisterOperands({kVs2, kVs1, kVmask}, {kVd}); + VrgatherVVHelper<uint64_t, uint16_t>(this, instruction_); +} + +// Helper function for slideup/down instructions. 
+template <typename T>
+void SlideHelper(RiscVCheriotVectorPermuteInstructionsTest *tester,
+                 Instruction *inst, bool is_slide_up) {
+  auto *rv_vector = tester->rv_vector();
+  uint32_t vtype =
+      (kSewSettingsByByteSize[sizeof(T)] << 3) | kLmulSettingByLogSize[7];
+  tester->ConfigureVectorUnit(vtype, 2048);
+  int vlen = rv_vector->vector_length();
+  int max_vlen = rv_vector->max_vector_length();
+  int num_values_per_reg = kVectorLengthInBytes / sizeof(T);
+  // Initialize vs2 to random values.
+  for (int reg = 0; reg < 8; reg++) {
+    auto src_span = tester->vreg()[kVs2 + reg]->data_buffer()->Get<T>();
+    for (int i = 0; i < num_values_per_reg; i++) {
+      src_span[i] = tester->RandomValue<T>();
+    }
+  }
+  tester->SetVectorRegisterValues<uint8_t>({{kVmaskName, kA5Mask}});
+  // Try 20 different shift values randomly.
+  for (int num = 0; num < 20; num++) {
+    CheriotRegister::ValueType shift_value =
+        absl::Uniform(absl::IntervalClosed, tester->bitgen(), 0, 2 * vlen);
+    tester->SetRegisterValues<CheriotRegister::ValueType>(
+        {{kRs1Name, shift_value}});
+    // Initialize the destination register set.
+    int value_indx = 0;
+    for (int reg = 0; reg < 8; reg++) {
+      auto dst_span = tester->vreg()[kVd + reg]->data_buffer()->Get<T>();
+      for (int i = 0; i < num_values_per_reg; i++) {
+        dst_span[i] = value_indx++;
+      }
+    }
+    inst->Execute();
+    for (int i = 0; i < vlen; i++) {
+      int value_reg_offset = i / num_values_per_reg;
+      int value_elem_index = i % num_values_per_reg;
+      int mask_index = i >> 3;
+      int mask_offset = i & 0b111;
+      bool mask_value = (kA5Mask[mask_index] >> mask_offset) & 0b1;
+      T dst = tester->vreg()[kVd + value_reg_offset]->data_buffer()->Get<T>(
+          value_elem_index);
+      if (is_slide_up) {  // For slide up instruction.
+        if ((i < shift_value) || (!mask_value)) {
+          // If it's an inactive element, or the index is less than the shift
+          // amount, the element is unchanged.
+          T value = static_cast<T>(i);
+          EXPECT_EQ(value, dst) << "indx: " << i;
+        } else {
+          // Active elements are shifted up by 'shift_value'.
+          int src_reg_offset = (i - shift_value) / num_values_per_reg;
+          int src_reg_index = (i - shift_value) % num_values_per_reg;
+          T src = tester->vreg()[kVs2 + src_reg_offset]->data_buffer()->Get<T>(
+              src_reg_index);
+          EXPECT_EQ(src, dst) << "indx: " << i;
+        }
+      } else {  // This is slide down.
+        if (mask_value) {
+          if (i + shift_value >= max_vlen) {
+            // Active elements originating beyond max_vlen are 0.
+            EXPECT_EQ(0, dst) << "indx: " << i;
+          } else {
+            // Active elements are shifted down by 'shift_value'.
+            int src_reg_offset = (i + shift_value) / num_values_per_reg;
+            int src_reg_index = (i + shift_value) % num_values_per_reg;
+            T src =
+                tester->vreg()[kVs2 + src_reg_offset]->data_buffer()->Get<T>(
+                    src_reg_index);
+            EXPECT_EQ(src, dst) << "indx: " << i;
+          }
+        } else {
+          // All inactive elements are unchanged.
+          T src = tester->vreg()[kVd + value_reg_offset]->data_buffer()->Get<T>(
+              value_elem_index);
+          EXPECT_EQ(src, dst) << "indx: " << i;
+        }
+      }
+    }
+  }
+}
+
+// Test vslideup instruction for SEW values of 1, 2, 4, and 8 bytes.
+TEST_F(RiscVCheriotVectorPermuteInstructionsTest, Vslideup8) {
+  SetSemanticFunction(&Vslideup);
+  AppendVectorRegisterOperands({kVs2}, {});
+  AppendRegisterOperands({kRs1Name}, {});
+  AppendVectorRegisterOperands({kVmask}, {kVd});
+  SlideHelper<uint8_t>(this, instruction_, /*is_slide_up*/ true);
+}
+
+TEST_F(RiscVCheriotVectorPermuteInstructionsTest, Vslideup16) {
+  SetSemanticFunction(&Vslideup);
+  AppendVectorRegisterOperands({kVs2}, {});
+  AppendRegisterOperands({kRs1Name}, {});
+  AppendVectorRegisterOperands({kVmask}, {kVd});
+  SlideHelper<uint16_t>(this, instruction_, /*is_slide_up*/ true);
+}
+
+TEST_F(RiscVCheriotVectorPermuteInstructionsTest, Vslideup32) {
+  SetSemanticFunction(&Vslideup);
+  AppendVectorRegisterOperands({kVs2}, {});
+  AppendRegisterOperands({kRs1Name}, {});
+  AppendVectorRegisterOperands({kVmask}, {kVd});
+  SlideHelper<uint32_t>(this, instruction_, /*is_slide_up*/ true);
+}
+
+TEST_F(RiscVCheriotVectorPermuteInstructionsTest, Vslideup64) {
+  SetSemanticFunction(&Vslideup);
+  AppendVectorRegisterOperands({kVs2}, {});
+  AppendRegisterOperands({kRs1Name}, {});
+  AppendVectorRegisterOperands({kVmask}, {kVd});
+  SlideHelper<uint64_t>(this, instruction_, /*is_slide_up*/ true);
+}
+
+// Test vslidedown instruction for SEW values of 1, 2, 4, and 8 bytes.
+TEST_F(RiscVCheriotVectorPermuteInstructionsTest, Vslidedown8) {
+  SetSemanticFunction(&Vslidedown);
+  AppendVectorRegisterOperands({kVs2}, {});
+  AppendRegisterOperands({kRs1Name}, {});
+  AppendVectorRegisterOperands({kVmask}, {kVd});
+  SlideHelper<uint8_t>(this, instruction_, /*is_slide_up*/ false);
+}
+
+TEST_F(RiscVCheriotVectorPermuteInstructionsTest, Vslidedown16) {
+  SetSemanticFunction(&Vslidedown);
+  AppendVectorRegisterOperands({kVs2}, {});
+  AppendRegisterOperands({kRs1Name}, {});
+  AppendVectorRegisterOperands({kVmask}, {kVd});
+  SlideHelper<uint16_t>(this, instruction_, /*is_slide_up*/ false);
+}
+
+TEST_F(RiscVCheriotVectorPermuteInstructionsTest, Vslidedown32) {
+  SetSemanticFunction(&Vslidedown);
+  AppendVectorRegisterOperands({kVs2}, {});
+  AppendRegisterOperands({kRs1Name}, {});
+  AppendVectorRegisterOperands({kVmask}, {kVd});
+  SlideHelper<uint32_t>(this, instruction_, /*is_slide_up*/ false);
+}
+
+TEST_F(RiscVCheriotVectorPermuteInstructionsTest, Vslidedown64) {
+  SetSemanticFunction(&Vslidedown);
+  AppendVectorRegisterOperands({kVs2}, {});
+  AppendRegisterOperands({kRs1Name}, {});
+  AppendVectorRegisterOperands({kVmask}, {kVd});
+  SlideHelper<uint64_t>(this, instruction_, /*is_slide_up*/ false);
+}
+
+template <typename T>
+void Slide1Helper(RiscVCheriotVectorPermuteInstructionsTest *tester,
+                  Instruction *inst, bool is_slide_up) {
+  auto *rv_vector = tester->rv_vector();
+  uint32_t vtype =
+      (kSewSettingsByByteSize[sizeof(T)] << 3) | kLmulSettingByLogSize[7];
+  tester->ConfigureVectorUnit(vtype, 2048);
+  int vlen = rv_vector->vector_length();
+  int max_vlen = rv_vector->max_vector_length();
+  int num_values_per_reg = kVectorLengthInBytes / sizeof(T);
+  // Initialize vs2 to random values.
+  for (int reg = kVs2; reg < kVs2 + 8; reg++) {
+    auto span = tester->vreg()[reg]->data_buffer()->Get<T>();
+    for (int i = 0; i < num_values_per_reg; i++) {
+      span[i] = tester->RandomValue<T>();
+    }
+  }
+  tester->SetVectorRegisterValues<uint8_t>({{kVmaskName, kA5Mask}});
+  // Try 20 different fill-in values randomly.
+  for (int num = 0; num < 20; num++) {
+    CheriotRegister::ValueType fill_in_value =
+        absl::Uniform(absl::IntervalClosed, tester->bitgen(), 0, 2 * vlen);
+    tester->SetRegisterValues<CheriotRegister::ValueType>(
+        {{kRs1Name, fill_in_value}});
+    fill_in_value = static_cast<T>(fill_in_value);
+    inst->Execute();
+    for (int i = 0; i < vlen; i++) {
+      int value_reg_offset = i / num_values_per_reg;
+      int value_elem_index = i % num_values_per_reg;
+      int mask_index = i >> 3;
+      int mask_offset = i & 0b111;
+      bool mask_value = (kA5Mask[mask_index] >> mask_offset) & 0b1;
+      T dst = tester->vreg()[kVd + value_reg_offset]->data_buffer()->Get<T>(
+          value_elem_index);
+      if (is_slide_up) {
+        if (!mask_value) {  // Inactive elements are unchanged.
+          T src = tester->vreg()[kVd + value_reg_offset]->data_buffer()->Get<T>(
+              value_elem_index);
+          EXPECT_EQ(src, dst) << "i: " << i;
+        } else {
+          if (i == 0) {  // The first value should match the fill-in.
+            EXPECT_EQ(fill_in_value, dst) << "i: " << i;
+          } else {  // Other values are shifted by 1.
+            int src_reg_offset = (i - 1) / num_values_per_reg;
+            int src_reg_index = (i - 1) % num_values_per_reg;
+            T src =
+                tester->vreg()[kVs2 + src_reg_offset]->data_buffer()->Get<T>(
+                    src_reg_index);
+            EXPECT_EQ(src, dst) << "i: " << i;
+          }
+        }
+      } else {  // This is slide down.
+        if (mask_value) {
+          if (i + 1 >= max_vlen) {  // The last value should match the fill-in.
+            EXPECT_EQ(fill_in_value, dst);
+          } else {  // Other elements are shifted by 1.
+            int src_reg_offset = (i + 1) / num_values_per_reg;
+            int src_reg_index = (i + 1) % num_values_per_reg;
+            T src =
+                tester->vreg()[kVs2 + src_reg_offset]->data_buffer()->Get<T>(
+                    src_reg_index);
+            EXPECT_EQ(src, dst) << "i: " << i;
+          }
+        } else {  // Inactive elements are unchanged.
+          T src = tester->vreg()[kVd + value_reg_offset]->data_buffer()->Get<T>(
+              value_elem_index);
+          EXPECT_EQ(src, dst) << "i: " << i;
+        }
+      }
+    }
+  }
+}
+
+// Test vslide1up instruction for SEW values of 1, 2, 4, and 8 bytes.
+TEST_F(RiscVCheriotVectorPermuteInstructionsTest, Vslide1up8) {
+  SetSemanticFunction(&Vslide1up);
+  AppendVectorRegisterOperands({kVs2}, {});
+  AppendRegisterOperands({kRs1Name}, {});
+  AppendVectorRegisterOperands({kVmask}, {kVd});
+  Slide1Helper<uint8_t>(this, instruction_, /*is_slide_up*/ true);
+}
+
+TEST_F(RiscVCheriotVectorPermuteInstructionsTest, Vslide1up16) {
+  SetSemanticFunction(&Vslide1up);
+  AppendVectorRegisterOperands({kVs2}, {});
+  AppendRegisterOperands({kRs1Name}, {});
+  AppendVectorRegisterOperands({kVmask}, {kVd});
+  Slide1Helper<uint16_t>(this, instruction_, /*is_slide_up*/ true);
+}
+
+TEST_F(RiscVCheriotVectorPermuteInstructionsTest, Vslide1up32) {
+  SetSemanticFunction(&Vslide1up);
+  AppendVectorRegisterOperands({kVs2}, {});
+  AppendRegisterOperands({kRs1Name}, {});
+  AppendVectorRegisterOperands({kVmask}, {kVd});
+  Slide1Helper<uint32_t>(this, instruction_, /*is_slide_up*/ true);
+}
+
+TEST_F(RiscVCheriotVectorPermuteInstructionsTest, Vslide1up64) {
+  SetSemanticFunction(&Vslide1up);
+  AppendVectorRegisterOperands({kVs2}, {});
+  AppendRegisterOperands({kRs1Name}, {});
+  AppendVectorRegisterOperands({kVmask}, {kVd});
+  Slide1Helper<uint64_t>(this, instruction_, /*is_slide_up*/ true);
+}
+
+// Test vslide1down instruction for SEW values of 1, 2, 4, and 8 bytes.
+TEST_F(RiscVCheriotVectorPermuteInstructionsTest, Vslide1down8) {
+  SetSemanticFunction(&Vslide1down);
+  AppendVectorRegisterOperands({kVs2}, {});
+  AppendRegisterOperands({kRs1Name}, {});
+  AppendVectorRegisterOperands({kVmask}, {kVd});
+  Slide1Helper<uint8_t>(this, instruction_, /*is_slide_up*/ false);
+}
+
+TEST_F(RiscVCheriotVectorPermuteInstructionsTest, Vslide1down16) {
+  SetSemanticFunction(&Vslide1down);
+  AppendVectorRegisterOperands({kVs2}, {});
+  AppendRegisterOperands({kRs1Name}, {});
+  AppendVectorRegisterOperands({kVmask}, {kVd});
+  Slide1Helper<uint16_t>(this, instruction_, /*is_slide_up*/ false);
+}
+
+TEST_F(RiscVCheriotVectorPermuteInstructionsTest, Vslide1down32) {
+  SetSemanticFunction(&Vslide1down);
+  AppendVectorRegisterOperands({kVs2}, {});
+  AppendRegisterOperands({kRs1Name}, {});
+  AppendVectorRegisterOperands({kVmask}, {kVd});
+  Slide1Helper<uint32_t>(this, instruction_, /*is_slide_up*/ false);
+}
+
+TEST_F(RiscVCheriotVectorPermuteInstructionsTest, Vslide1down64) {
+  SetSemanticFunction(&Vslide1down);
+  AppendVectorRegisterOperands({kVs2}, {});
+  AppendRegisterOperands({kRs1Name}, {});
+  AppendVectorRegisterOperands({kVmask}, {kVd});
+  Slide1Helper<uint64_t>(this, instruction_, /*is_slide_up*/ false);
+}
+
+template <typename T>
+void CompressHelper(RiscVCheriotVectorPermuteInstructionsTest *tester,
+                    Instruction *inst) {
+  auto *rv_vector = tester->rv_vector();
+  uint32_t vtype =
+      (kSewSettingsByByteSize[sizeof(T)] << 3) | kLmulSettingByLogSize[7];
+  tester->ConfigureVectorUnit(vtype, 2048);
+  int vlen = rv_vector->vector_length();
+  int num_values_per_reg = kVectorLengthInBytes / sizeof(T);
+  auto vd_span = tester->vreg()[kVd]->data_buffer()->Get<T>();
+  std::vector<T> origin_vd_values(vd_span.begin(), vd_span.end());
+  // Initialize vs2 to random values.
+  for (int reg = kVs2; reg < kVs2 + 8; reg++) {
+    auto span = tester->vreg()[reg]->data_buffer()->Get<T>();
+    for (int i = 0; i < num_values_per_reg; i++) {
+      span[i] = tester->RandomValue<T>();
+    }
+  }
+  tester->SetVectorRegisterValues<uint8_t>({{kVmaskName, kA5Mask}});
+  inst->Execute();
+  // First check all the elements that were compressed (mask bit true).
+  int offset = 0;
+  for (int i = 0; i < vlen; i++) {
+    int value_reg_offset = i / num_values_per_reg;
+    int value_elem_index = i % num_values_per_reg;
+    int mask_index = i >> 3;
+    int mask_offset = i & 0b111;
+    bool mask_value = (kA5Mask[mask_index] >> mask_offset) & 0b1;
+    if (mask_value) {
+      T src = tester->vreg()[kVs2 + value_reg_offset]->data_buffer()->Get<T>(
+          value_elem_index);
+      int dst_reg_index = offset / num_values_per_reg;
+      int dst_element_index = offset % num_values_per_reg;
+      T dst = tester->vreg()[kVd + dst_reg_index]->data_buffer()->Get<T>(
+          dst_element_index);
+      EXPECT_EQ(src, dst) << "index: " << i;
+      offset++;
+    }
+  }
+  // The remaining elements are unchanged.
+  for (int i = offset; i < vlen; i++) {
+    int value_reg_index = i / num_values_per_reg;
+    int value_elem_index = i % num_values_per_reg;
+    T src = origin_vd_values[value_elem_index];
+    T dst = tester->vreg()[kVd + value_reg_index]->data_buffer()->Get<T>(
+        value_elem_index);
+    EXPECT_EQ(src, dst) << "index: " << i;
+  }
+}
+
+// Test compress instruction for SEW of 8, 16, 32, and 64.
+TEST_F(RiscVCheriotVectorPermuteInstructionsTest, Vcompress8) { + SetSemanticFunction(&Vcompress); + AppendVectorRegisterOperands({kVs2, kVmask}, {kVd}); + CompressHelper<uint8_t>(this, instruction_); +} + +TEST_F(RiscVCheriotVectorPermuteInstructionsTest, Vcompress16) { + SetSemanticFunction(&Vcompress); + AppendVectorRegisterOperands({kVs2, kVmask}, {kVd}); + CompressHelper<uint16_t>(this, instruction_); +} + +TEST_F(RiscVCheriotVectorPermuteInstructionsTest, Vcompress32) { + SetSemanticFunction(&Vcompress); + AppendVectorRegisterOperands({kVs2, kVmask}, {kVd}); + CompressHelper<uint32_t>(this, instruction_); +} + +TEST_F(RiscVCheriotVectorPermuteInstructionsTest, Vcompress64) { + SetSemanticFunction(&Vcompress); + AppendVectorRegisterOperands({kVs2, kVmask}, {kVd}); + CompressHelper<uint64_t>(this, instruction_); +} + +} // namespace
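As a note on the mask arithmetic that recurs throughout these helpers: an RVV mask register holds one bit per element, packed eight to a byte with the low bit first, so element i maps to mask byte i >> 3 and bit i & 0b111. A minimal sketch under that convention (MaskBit is a hypothetical helper, not part of the test base):

#include <cstdint>

// Returns true if element 'i' is active under 'mask', where the mask
// packs one bit per element, little-endian within each byte.
inline bool MaskBit(const uint8_t *mask, int i) {
  return (mask[i >> 3] >> (i & 0b111)) & 0b1;
}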
diff --git a/cheriot/test/riscv_cheriot_vector_reduction_instructions_test.cc b/cheriot/test/riscv_cheriot_vector_reduction_instructions_test.cc new file mode 100644 index 0000000..b7b6b28 --- /dev/null +++ b/cheriot/test/riscv_cheriot_vector_reduction_instructions_test.cc
@@ -0,0 +1,412 @@
+// Copyright 2023 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "cheriot/riscv_cheriot_vector_reduction_instructions.h"
+
+#include <algorithm>
+#include <cstdint>
+#include <functional>
+#include <vector>
+
+#include "absl/random/random.h"
+#include "absl/strings/string_view.h"
+#include "cheriot/test/riscv_cheriot_vector_instructions_test_base.h"
+#include "googlemock/include/gmock/gmock.h"
+#include "mpact/sim/generic/instruction.h"
+#include "mpact/sim/generic/type_helpers.h"
+#include "riscv//riscv_register.h"
+
+namespace {
+using ::absl::Span;
+using ::mpact::sim::generic::Instruction;
+using ::mpact::sim::generic::WideType;
+using ::mpact::sim::riscv::RV32Register;
+using ::mpact::sim::riscv::RVVectorRegister;
+
+using ::mpact::sim::cheriot::Vredand;
+using ::mpact::sim::cheriot::Vredmax;
+using ::mpact::sim::cheriot::Vredmaxu;
+using ::mpact::sim::cheriot::Vredmin;
+using ::mpact::sim::cheriot::Vredminu;
+using ::mpact::sim::cheriot::Vredor;
+using ::mpact::sim::cheriot::Vredsum;
+using ::mpact::sim::cheriot::Vredxor;
+using ::mpact::sim::cheriot::Vwredsum;
+using ::mpact::sim::cheriot::Vwredsumu;
+
+class RiscVCheriotVectorReductionInstructionsTest
+    : public RiscVCheriotVectorInstructionsTestBase {
+ public:
+  template <typename Vd, typename Vs2>
+  void ReductionOpTestHelper(absl::string_view name, int sew, Instruction *inst,
+                             std::function<Vd(Vd, Vs2)> operation) {
+    int byte_sew = sew / 8;
+    if (byte_sew != sizeof(Vd) && byte_sew != sizeof(Vs2)) {
+      FAIL() << name << ": selected element width != any operand types"
+             << " sew: " << sew << " Vd: " << sizeof(Vd)
+             << " Vs2: " << sizeof(Vs2);
+      return;
+    }
+    // Number of elements per vector register.
+    constexpr int vs2_size = kVectorLengthInBytes / sizeof(Vs2);
+    // Input values for 8 registers.
+    Vs2 vs2_value[vs2_size * 8];
+    auto vs2_span = Span<Vs2>(vs2_value);
+    Vs2 vs1_value[vs2_size] = {};
+    auto vs1_span = Span<Vs2>(vs1_value);
+    AppendVectorRegisterOperands({kVs2, kVs1, kVmask}, {kVd});
+    // Initialize input values. Only vs1[0] is used by the reductions.
+    FillArrayWithRandomValues<Vs2>(vs2_span);
+    vs1_span[0] = RandomValue<Vs2>();
+    auto mask_span = Span<const uint8_t>(kA5Mask);
+    SetVectorRegisterValues<uint8_t>({{kVmaskName, mask_span}});
+    SetVectorRegisterValues<Vs2>({{kVs1Name, Span<const Vs2>(vs1_span)}});
+    // Write the input values into the vs2 register group.
+    for (int i = 0; i < 8; i++) {
+      auto vs2_name = absl::StrCat("v", kVs2 + i);
+      SetVectorRegisterValues<Vs2>(
+          {{vs2_name, vs2_span.subspan(vs2_size * i, vs2_size)}});
+    }
+    // Iterate across the different lmul values.
+    for (int lmul_index = 0; lmul_index < 7; lmul_index++) {
+      for (int vlen_count = 0; vlen_count < 4; vlen_count++) {
+        int lmul8 = kLmul8Values[lmul_index];
+        int lmul8_vs2 = lmul8 * sizeof(Vs2) / byte_sew;
+        int lmul8_vd = lmul8 * sizeof(Vd) / byte_sew;
+        int num_values = lmul8 * kVectorLengthInBytes / (8 * byte_sew);
+        // Set vlen to a random value, but leave it at the maximum at least
+        // once.
+ int vlen = 1024; + if (vlen_count > 0) { + vlen = + absl::Uniform(absl::IntervalOpenClosed, bitgen_, 0, num_values); + } + num_values = std::min(num_values, vlen); + // Configure vector unit for different lmul settings. + uint32_t vtype = + (kSewSettingsByByteSize[byte_sew] << 3) | kLmulSettings[lmul_index]; + ConfigureVectorUnit(vtype, vlen); + ClearVectorRegisterGroup(kVd, 8); + + inst->Execute(); + + if ((lmul8_vs2 < 1) || (lmul8_vs2 > 64)) { + EXPECT_TRUE(rv_vector_->vector_exception()); + rv_vector_->clear_vector_exception(); + continue; + } + + if ((lmul8_vd < 1) || (lmul8_vd > 64)) { + EXPECT_TRUE(rv_vector_->vector_exception()); + rv_vector_->clear_vector_exception(); + continue; + } + + EXPECT_FALSE(rv_vector_->vector_exception()); + Vd accumulator = static_cast<Vd>(vs1_span[0]); + for (int i = 0; i < num_values; i++) { + int mask_index = i >> 3; + int mask_offset = i & 0b111; + bool mask_value = (mask_span[mask_index] >> mask_offset) & 0b1; + if (mask_value) { + accumulator = operation(accumulator, vs2_span[i]); + } + } + EXPECT_EQ(accumulator, vreg_[kVd]->data_buffer()->Get<Vd>(0)); + } + } + } +}; + +// Test functions for vector reduction instruction semantic functions. The +// vector reduction instructions take two vector source operands and a mask +// operand, and write to the first element of a destination vector operand. + +// Vector sum reduction. +TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vredsum8) { + using T = uint8_t; + SetSemanticFunction(&Vredsum); + ReductionOpTestHelper<T, T>("Vredsum", /*sew*/ sizeof(T) * 8, instruction_, + [](T val0, T val1) -> T { return val0 + val1; }); +} +TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vredsum16) { + using T = uint16_t; + SetSemanticFunction(&Vredsum); + ReductionOpTestHelper<T, T>("Vredsum", /*sew*/ sizeof(T) * 8, instruction_, + [](T val0, T val1) -> T { return val0 + val1; }); +} +TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vredsum32) { + using T = uint32_t; + SetSemanticFunction(&Vredsum); + ReductionOpTestHelper<T, T>("Vredsum", /*sew*/ sizeof(T) * 8, instruction_, + [](T val0, T val1) -> T { return val0 + val1; }); +} +TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vredsum64) { + using T = uint64_t; + SetSemanticFunction(&Vredsum); + ReductionOpTestHelper<T, T>("Vredsum", /*sew*/ sizeof(T) * 8, instruction_, + [](T val0, T val1) -> T { return val0 + val1; }); +} + +// Vector and reduction. +TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vredand8) { + using T = uint8_t; + SetSemanticFunction(&Vredand); + ReductionOpTestHelper<T, T>("Vredand", /*sew*/ sizeof(T) * 8, instruction_, + [](T val0, T val1) -> T { return val0 & val1; }); +} +TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vredand16) { + using T = uint16_t; + SetSemanticFunction(&Vredand); + ReductionOpTestHelper<T, T>("Vredand", /*sew*/ sizeof(T) * 8, instruction_, + [](T val0, T val1) -> T { return val0 & val1; }); +} +TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vredand32) { + using T = uint32_t; + SetSemanticFunction(&Vredand); + ReductionOpTestHelper<T, T>("Vredand", /*sew*/ sizeof(T) * 8, instruction_, + [](T val0, T val1) -> T { return val0 & val1; }); +} +TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vredand64) { + using T = uint64_t; + SetSemanticFunction(&Vredand); + ReductionOpTestHelper<T, T>("Vredand", /*sew*/ sizeof(T) * 8, instruction_, + [](T val0, T val1) -> T { return val0 & val1; }); +} + +// Vector or reduction. 
+TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vredor8) { + using T = uint8_t; + SetSemanticFunction(&Vredor); + ReductionOpTestHelper<T, T>("Vredor", /*sew*/ sizeof(T) * 8, instruction_, + [](T val0, T val1) -> T { return val0 | val1; }); +} +TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vredor16) { + using T = uint16_t; + SetSemanticFunction(&Vredor); + ReductionOpTestHelper<T, T>("Vredor", /*sew*/ sizeof(T) * 8, instruction_, + [](T val0, T val1) -> T { return val0 | val1; }); +} +TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vredor32) { + using T = uint32_t; + SetSemanticFunction(&Vredor); + ReductionOpTestHelper<T, T>("Vredor", /*sew*/ sizeof(T) * 8, instruction_, + [](T val0, T val1) -> T { return val0 | val1; }); +} +TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vredor64) { + using T = uint64_t; + SetSemanticFunction(&Vredor); + ReductionOpTestHelper<T, T>("Vredor", /*sew*/ sizeof(T) * 8, instruction_, + [](T val0, T val1) -> T { return val0 | val1; }); +} + +// Vector xor reduction. +TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vredxor8) { + using T = uint8_t; + SetSemanticFunction(&Vredxor); + ReductionOpTestHelper<T, T>("Vredxor", /*sew*/ sizeof(T) * 8, instruction_, + [](T val0, T val1) -> T { return val0 ^ val1; }); +} +TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vredxor16) { + using T = uint16_t; + SetSemanticFunction(&Vredxor); + ReductionOpTestHelper<T, T>("Vredxor", /*sew*/ sizeof(T) * 8, instruction_, + [](T val0, T val1) -> T { return val0 ^ val1; }); +} +TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vredxor32) { + using T = uint32_t; + SetSemanticFunction(&Vredxor); + ReductionOpTestHelper<T, T>("Vredxor", /*sew*/ sizeof(T) * 8, instruction_, + [](T val0, T val1) -> T { return val0 ^ val1; }); +} +TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vredxor64) { + using T = uint64_t; + SetSemanticFunction(&Vredxor); + ReductionOpTestHelper<T, T>("Vredxor", /*sew*/ sizeof(T) * 8, instruction_, + [](T val0, T val1) -> T { return val0 ^ val1; }); +} + +// Vector unsigned min reduction. +TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vredminu8) { + using T = uint8_t; + SetSemanticFunction(&Vredminu); + ReductionOpTestHelper<T, T>( + "Vredminu", /*sew*/ sizeof(T) * 8, instruction_, + [](T val0, T val1) -> T { return val0 < val1 ? val0 : val1; }); +} +TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vredminu16) { + using T = uint16_t; + SetSemanticFunction(&Vredminu); + ReductionOpTestHelper<T, T>( + "Vredminu", /*sew*/ sizeof(T) * 8, instruction_, + [](T val0, T val1) -> T { return val0 < val1 ? val0 : val1; }); +} +TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vredminu32) { + using T = uint32_t; + SetSemanticFunction(&Vredminu); + ReductionOpTestHelper<T, T>( + "Vredminu", /*sew*/ sizeof(T) * 8, instruction_, + [](T val0, T val1) -> T { return val0 < val1 ? val0 : val1; }); +} +TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vredminu64) { + using T = uint64_t; + SetSemanticFunction(&Vredminu); + ReductionOpTestHelper<T, T>( + "Vredminu", /*sew*/ sizeof(T) * 8, instruction_, + [](T val0, T val1) -> T { return val0 < val1 ? val0 : val1; }); +} + +// Vector signed min reduction. +TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vredmin8) { + using T = int8_t; + SetSemanticFunction(&Vredmin); + ReductionOpTestHelper<T, T>( + "Vredmin", /*sew*/ sizeof(T) * 8, instruction_, + [](T val0, T val1) -> T { return val0 < val1 ? 
val0 : val1; }); +} +TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vredmin16) { + using T = int16_t; + SetSemanticFunction(&Vredmin); + ReductionOpTestHelper<T, T>( + "Vredmin", /*sew*/ sizeof(T) * 8, instruction_, + [](T val0, T val1) -> T { return val0 < val1 ? val0 : val1; }); +} +TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vredmin32) { + using T = int32_t; + SetSemanticFunction(&Vredmin); + ReductionOpTestHelper<T, T>( + "Vredmin", /*sew*/ sizeof(T) * 8, instruction_, + [](T val0, T val1) -> T { return val0 < val1 ? val0 : val1; }); +} +TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vredmin64) { + using T = int64_t; + SetSemanticFunction(&Vredmin); + ReductionOpTestHelper<T, T>( + "Vredmin", /*sew*/ sizeof(T) * 8, instruction_, + [](T val0, T val1) -> T { return val0 < val1 ? val0 : val1; }); +} + +// Vector unsigned max reduction. +TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vredmaxu8) { + using T = uint8_t; + SetSemanticFunction(&Vredmaxu); + ReductionOpTestHelper<T, T>( + "Vredmaxu", /*sew*/ sizeof(T) * 8, instruction_, + [](T val0, T val1) -> T { return val0 > val1 ? val0 : val1; }); +} +TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vredmaxu16) { + using T = uint16_t; + SetSemanticFunction(&Vredmaxu); + ReductionOpTestHelper<T, T>( + "Vredmaxu", /*sew*/ sizeof(T) * 8, instruction_, + [](T val0, T val1) -> T { return val0 > val1 ? val0 : val1; }); +} +TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vredmaxu32) { + using T = uint32_t; + SetSemanticFunction(&Vredmaxu); + ReductionOpTestHelper<T, T>( + "Vredmaxu", /*sew*/ sizeof(T) * 8, instruction_, + [](T val0, T val1) -> T { return val0 > val1 ? val0 : val1; }); +} +TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vredmaxu64) { + using T = uint64_t; + SetSemanticFunction(&Vredmaxu); + ReductionOpTestHelper<T, T>( + "Vredmaxu", /*sew*/ sizeof(T) * 8, instruction_, + [](T val0, T val1) -> T { return val0 > val1 ? val0 : val1; }); +} + +// Vector signed max reduction. +TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vredmax8) { + using T = int8_t; + SetSemanticFunction(&Vredmax); + ReductionOpTestHelper<T, T>( + "Vredmax", /*sew*/ sizeof(T) * 8, instruction_, + [](T val0, T val1) -> T { return val0 > val1 ? val0 : val1; }); +} +TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vredmax16) { + using T = int16_t; + SetSemanticFunction(&Vredmax); + ReductionOpTestHelper<T, T>( + "Vredmax", /*sew*/ sizeof(T) * 8, instruction_, + [](T val0, T val1) -> T { return val0 > val1 ? val0 : val1; }); +} +TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vredmax32) { + using T = int32_t; + SetSemanticFunction(&Vredmax); + ReductionOpTestHelper<T, T>( + "Vredmax", /*sew*/ sizeof(T) * 8, instruction_, + [](T val0, T val1) -> T { return val0 > val1 ? val0 : val1; }); +} +TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vredmax64) { + using T = int64_t; + SetSemanticFunction(&Vredmax); + ReductionOpTestHelper<T, T>( + "Vredmax", /*sew*/ sizeof(T) * 8, instruction_, + [](T val0, T val1) -> T { return val0 > val1 ? val0 : val1; }); +} + +// Vector widening unsigned sum reduction. 
+TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vwredsumu8) {
+  using T = uint8_t;
+  using WT = WideType<T>::type;
+  SetSemanticFunction(&Vwredsumu);
+  ReductionOpTestHelper<WT, T>(
+      "Vwredsumu", /*sew*/ sizeof(T) * 8, instruction_,
+      [](WT val0, T val1) -> WT { return val0 + static_cast<WT>(val1); });
+}
+TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vwredsumu16) {
+  using T = uint16_t;
+  using WT = WideType<T>::type;
+  SetSemanticFunction(&Vwredsumu);
+  ReductionOpTestHelper<WT, T>(
+      "Vwredsumu", /*sew*/ sizeof(T) * 8, instruction_,
+      [](WT val0, T val1) -> WT { return val0 + static_cast<WT>(val1); });
+}
+TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vwredsumu32) {
+  using T = uint32_t;
+  using WT = WideType<T>::type;
+  SetSemanticFunction(&Vwredsumu);
+  ReductionOpTestHelper<WT, T>(
+      "Vwredsumu", /*sew*/ sizeof(T) * 8, instruction_,
+      [](WT val0, T val1) -> WT { return val0 + static_cast<WT>(val1); });
+}
+
+// Vector widening signed sum reduction.
+TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vwredsum8) {
+  using T = int8_t;
+  using WT = WideType<T>::type;
+  SetSemanticFunction(&Vwredsum);
+  ReductionOpTestHelper<WT, T>(
+      "Vwredsum", /*sew*/ sizeof(T) * 8, instruction_,
+      [](WT val0, T val1) -> WT { return val0 + static_cast<WT>(val1); });
+}
+TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vwredsum16) {
+  using T = int16_t;
+  using WT = WideType<T>::type;
+  SetSemanticFunction(&Vwredsum);
+  ReductionOpTestHelper<WT, T>(
+      "Vwredsum", /*sew*/ sizeof(T) * 8, instruction_,
+      [](WT val0, T val1) -> WT { return val0 + static_cast<WT>(val1); });
+}
+TEST_F(RiscVCheriotVectorReductionInstructionsTest, Vwredsum32) {
+  using T = int32_t;
+  using WT = WideType<T>::type;
+  SetSemanticFunction(&Vwredsum);
+  ReductionOpTestHelper<WT, T>(
+      "Vwredsum", /*sew*/ sizeof(T) * 8, instruction_,
+      [](WT val0, T val1) -> WT { return val0 + static_cast<WT>(val1); });
+}
+
+}  // namespace
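To make the reduction semantics the helper above checks concrete: the accumulator is seeded from vs1[0], folds in only the active (mask bit set) elements of vs2 up to vl, and the result lands in vd[0]. A minimal sketch for vredsum with sew == 32 (ExpectedVredsum32 is a hypothetical function mirroring the helper's expectation loop, not simulator API):

#include <cstdint>

uint32_t ExpectedVredsum32(uint32_t vs1_0, const uint32_t *vs2,
                           const uint8_t *mask, int vl) {
  uint32_t acc = vs1_0;  // vd[0] = vs1[0] reduced with active vs2 elements.
  for (int i = 0; i < vl; ++i) {
    if ((mask[i >> 3] >> (i & 0b111)) & 0b1) acc += vs2[i];
  }
  return acc;
}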
diff --git a/cheriot/test/riscv_cheriot_vector_true_test.cc b/cheriot/test/riscv_cheriot_vector_true_test.cc new file mode 100644 index 0000000..05b73f2 --- /dev/null +++ b/cheriot/test/riscv_cheriot_vector_true_test.cc
@@ -0,0 +1,66 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <cstdint> + +#include "cheriot/cheriot_state.h" +#include "cheriot/cheriot_vector_state.h" +#include "cheriot/cheriot_vector_true_operand.h" +#include "googlemock/include/gmock/gmock.h" +#include "mpact/sim/util/memory/tagged_flat_demand_memory.h" + +namespace { + +using ::mpact::sim::cheriot::CheriotState; +using ::mpact::sim::cheriot::CheriotVectorState; +using ::mpact::sim::cheriot::CheriotVectorTrueOperand; +using ::mpact::sim::util::TaggedFlatDemandMemory; + +constexpr int kVLengthInBytes = 64; +// Test fixture. +class CheriotVectorTrueTest : public testing::Test { + protected: + CheriotVectorTrueTest() : memory_(8) { + state_ = new CheriotState("test", &memory_); + vstate_ = new CheriotVectorState(state_, kVLengthInBytes); + } + ~CheriotVectorTrueTest() override { + delete state_; + delete vstate_; + } + + TaggedFlatDemandMemory memory_; + CheriotState *state_; + CheriotVectorState *vstate_; +}; + +TEST_F(CheriotVectorTrueTest, Initial) { + auto *op = new CheriotVectorTrueOperand(state_); + for (int i = 0; i < op->shape()[0]; ++i) { + EXPECT_EQ(op->AsUint8(i), 0xff) << "element: " << i; + } + delete op; +} + +TEST_F(CheriotVectorTrueTest, Register) { + auto *op = new CheriotVectorTrueOperand(state_); + auto *reg = op->GetRegister(0); + auto span = reg->data_buffer()->Get<uint8_t>(); + for (int i = 0; i < op->shape()[0]; ++i) { + EXPECT_EQ(span[i], 0xff) << "element: " << i; + } + delete op; +} + +} // namespace
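For context on why both tests expect 0xff in every byte: the all-true operand stands in for the mask operand when an instruction is unmasked, so every element reads as active. A minimal usage sketch under that assumption (ElementActive is hypothetical; only the AsUint8 accessor shown in the test above is assumed):

#include <cstdint>

#include "cheriot/cheriot_vector_true_operand.h"

using ::mpact::sim::cheriot::CheriotVectorTrueOperand;

// With every byte of the operand reading 0xff, any element's mask bit is
// set, so this always returns true for in-range elements.
bool ElementActive(CheriotVectorTrueOperand *op, int i) {
  return (op->AsUint8(i >> 3) >> (i & 0b111)) & 0b1;
}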
diff --git a/cheriot/test/riscv_cheriot_vector_unary_instructions_test.cc b/cheriot/test/riscv_cheriot_vector_unary_instructions_test.cc new file mode 100644 index 0000000..29a4a69 --- /dev/null +++ b/cheriot/test/riscv_cheriot_vector_unary_instructions_test.cc
@@ -0,0 +1,668 @@ +// Copyright 2023 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "cheriot/riscv_cheriot_vector_unary_instructions.h" + +#include <cstdint> +#include <cstring> +#include <ios> +#include <vector> + +#include "absl/random/random.h" +#include "absl/types/span.h" +#include "cheriot/cheriot_register.h" +#include "cheriot/test/riscv_cheriot_vector_instructions_test_base.h" +#include "googlemock/include/gmock/gmock.h" +#include "mpact/sim/generic/instruction.h" +#include "mpact/sim/generic/type_helpers.h" +#include "riscv//riscv_register.h" + +namespace { + +using ::absl::Span; +using ::mpact::sim::cheriot::CheriotRegister; +using ::mpact::sim::generic::Instruction; +using ::mpact::sim::generic::SameSignedType; +using ::mpact::sim::riscv::RVVectorRegister; + +using ::mpact::sim::cheriot::Vcpop; +using ::mpact::sim::cheriot::Vfirst; +using ::mpact::sim::cheriot::Vid; +using ::mpact::sim::cheriot::Viota; +using ::mpact::sim::cheriot::Vmsbf; +using ::mpact::sim::cheriot::Vmsif; +using ::mpact::sim::cheriot::Vmsof; +using ::mpact::sim::cheriot::VmvFromScalar; +using ::mpact::sim::cheriot::VmvToScalar; +using ::mpact::sim::cheriot::Vsext2; +using ::mpact::sim::cheriot::Vsext4; +using ::mpact::sim::cheriot::Vsext8; +using ::mpact::sim::cheriot::Vzext2; +using ::mpact::sim::cheriot::Vzext4; +using ::mpact::sim::cheriot::Vzext8; + +using SignedXregType = + SameSignedType<CheriotRegister::ValueType, int64_t>::type; + +constexpr uint8_t k5AMask[kVectorLengthInBytes] = { + 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, + 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, + 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, + 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, + 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, + 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, 0x5a, +}; + +constexpr uint8_t kE7Mask[kVectorLengthInBytes] = { + 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, + 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, + 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, + 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, + 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, + 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, 0xe7, +}; + +constexpr uint8_t kAllOnesMask[kVectorLengthInBytes] = { + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, +}; + +class RiscVCheriotVectorUnaryInstructionsTest + : public RiscVCheriotVectorInstructionsTestBase {}; + +// Test move vector element 0 to scalar register. 
+TEST_F(RiscVCheriotVectorUnaryInstructionsTest, VmvToScalar) {
+  SetSemanticFunction(&VmvToScalar);
+  AppendRegisterOperands({}, {kRs1Name});
+  AppendVectorRegisterOperands({kVs2}, {});
+  for (int byte_sew : {1, 2, 4, 8}) {
+    int vlen = kVectorLengthInBytes / byte_sew;
+    uint32_t vtype =
+        (kSewSettingsByByteSize[byte_sew] << 3) | kLmulSettingByLogSize[4];
+    ConfigureVectorUnit(vtype, vlen);
+    // Test 10 different values.
+    for (int i = 0; i < 10; i++) {
+      int64_t value = 0;
+      switch (byte_sew) {
+        case 1: {
+          auto val8 = RandomValue<int8_t>();
+          value = static_cast<int64_t>(val8);
+          SetVectorRegisterValues<int8_t>(
+              {{kVs2Name, absl::Span<const int8_t>(&val8, 1)}});
+          break;
+        }
+        case 2: {
+          auto val16 = RandomValue<int16_t>();
+          value = static_cast<int64_t>(val16);
+          SetVectorRegisterValues<int16_t>(
+              {{kVs2Name, absl::Span<const int16_t>(&val16, 1)}});
+          break;
+        }
+        case 4: {
+          auto val32 = RandomValue<int32_t>();
+          value = static_cast<int64_t>(val32);
+          SetVectorRegisterValues<int32_t>(
+              {{kVs2Name, absl::Span<const int32_t>(&val32, 1)}});
+          break;
+        }
+        case 8: {
+          auto val64 = RandomValue<int64_t>();
+          value = val64;
+          SetVectorRegisterValues<int64_t>(
+              {{kVs2Name, absl::Span<const int64_t>(&val64, 1)}});
+          break;
+        }
+      }
+      instruction_->Execute();
+      EXPECT_EQ(creg_[kRs1]->data_buffer()->Get<SignedXregType>(0),
+                static_cast<SignedXregType>(value));
+    }
+  }
+}
+
+// Test move scalar to vector element 0.
+TEST_F(RiscVCheriotVectorUnaryInstructionsTest, VmvFromScalar) {
+  SetSemanticFunction(&VmvFromScalar);
+  AppendRegisterOperands({kRs1Name}, {});
+  AppendVectorRegisterOperands({}, {kVs2});
+  for (int byte_sew : {1, 2, 4, 8}) {
+    int vlen = kVectorLengthInBytes / byte_sew;
+    uint32_t vtype =
+        (kSewSettingsByByteSize[byte_sew] << 3) | kLmulSettingByLogSize[4];
+    ConfigureVectorUnit(vtype, vlen);
+    // Test 10 different values.
+    for (int i = 0; i < 10; i++) {
+      auto value = RandomValue<SignedXregType>();
+      SetRegisterValues<SignedXregType>({{kRs1Name, value}});
+      instruction_->Execute();
+      switch (byte_sew) {
+        case 1:
+          EXPECT_EQ(vreg_[kVs2]->data_buffer()->Get<int8_t>(0),
+                    static_cast<int8_t>(value));
+          break;
+        case 2:
+          EXPECT_EQ(vreg_[kVs2]->data_buffer()->Get<int16_t>(0),
+                    static_cast<int16_t>(value));
+          break;
+        case 4:
+          EXPECT_EQ(vreg_[kVs2]->data_buffer()->Get<int32_t>(0),
+                    static_cast<int32_t>(value));
+          break;
+        case 8:
+          EXPECT_EQ(vreg_[kVs2]->data_buffer()->Get<int64_t>(0),
+                    static_cast<int64_t>(value));
+          break;
+      }
+    }
+  }
+}
+
+// Test vector mask population count.
+TEST_F(RiscVCheriotVectorUnaryInstructionsTest, Vcpop) {
+  uint32_t vtype = (kSewSettingsByByteSize[1] << 3) | kLmulSettingByLogSize[7];
+  SetSemanticFunction(&Vcpop);
+  AppendVectorRegisterOperands({kVs2, kVmask}, {});
+  AppendRegisterOperands({}, {kRdName});
+  for (int vlen : {1, 8, 32, 48, 127, 200}) {
+    ConfigureVectorUnit(vtype, vlen);
+    // All 1s for mask and vector.
+    SetVectorRegisterValues<uint8_t>(
+        {{kVs2Name, kAllOnesMask}, {kVmaskName, kAllOnesMask}});
+    instruction_->Execute();
+    EXPECT_EQ(creg_[kRd]->data_buffer()->Get<CheriotRegister::ValueType>(0),
+              vlen);
+
+    // The mask is the inverse of the vector, so the population count is 0.
+    SetVectorRegisterValues<uint8_t>(
+        {{kVs2Name, kA5Mask}, {kVmaskName, k5AMask}});
+    instruction_->Execute();
+    EXPECT_EQ(creg_[kRd]->data_buffer()->Get<CheriotRegister::ValueType>(0), 0);
+  }
+}
+
+// Test vector mask find first set.
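+// Per the RVV spec, vfirst.m writes the index of the first active element of
+// vs2 whose bit is set to the destination register, or -1 if no such bit
+// exists. Illustrative example (values chosen here, not from the fixture):
+// vs2 = 0b0001'0000 with an all-ones mask yields 4.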
+TEST_F(RiscVCheriotVectorUnaryInstructionsTest, Vfirst) {
+  SetSemanticFunction(&Vfirst);
+  AppendVectorRegisterOperands({kVs2, kVmask}, {});
+  AppendRegisterOperands({}, {kRdName});
+  uint8_t reg_value[kVectorLengthInBytes];
+  // Set vtype to byte vector, and vector lmul to 8.
+  uint32_t vtype = (kSewSettingsByByteSize[1] << 3) | kLmulSettingByLogSize[7];
+  ConfigureVectorUnit(vtype, kVectorLengthInBytes * 8);
+
+  // Clear the reg_value array.
+  std::memset(reg_value, 0, kVectorLengthInBytes);
+  // Set the register values.
+  SetVectorRegisterValues<uint8_t>(
+      {{kVs2Name, reg_value}, {kVmaskName, kAllOnesMask}});
+  // Execute the instruction. The result should be -1.
+  instruction_->Execute();
+  EXPECT_EQ(creg_[kRd]->data_buffer()->Get<SignedXregType>(0), -1);
+
+  // Pick a random location 20 times and set that bit to 1. Test first that
+  // the correct value is returned, then clear the mask bit that corresponds to
+  // that value, and ensure that the result is now -1.
+  for (int i = 0; i < 20; i++) {
+    // Clear the reg_value array.
+    std::memset(reg_value, 0, kVectorLengthInBytes);
+    // Pick a random bit index to set.
+    uint32_t index = absl::Uniform(absl::IntervalClosed, bitgen_, 0,
+                                   kVectorLengthInBytes * 8 - 1);
+    // Compute the byte and bit index to be set.
+    auto byte_index = index >> 3;
+    auto bit_index = index & 0b111;
+    // Set the bit in the source register.
+    reg_value[byte_index] |= 1 << bit_index;
+    // Set the register values.
+    SetVectorRegisterValues<uint8_t>(
+        {{kVs2Name, reg_value}, {kVmaskName, kAllOnesMask}});
+    // Execute the instruction. The result should be the index value.
+    instruction_->Execute();
+    EXPECT_EQ(creg_[kRd]->data_buffer()->Get<CheriotRegister::ValueType>(0),
+              index);
+
+    // Clear the mask bit for the bit that was set.
+    auto mask_db = vreg_[kVmask]->data_buffer()->Get<uint8_t>();
+    mask_db[byte_index] &= ~(1 << bit_index);
+    // Execute the instruction. The result should be -1.
+    instruction_->Execute();
+    EXPECT_EQ(creg_[kRd]->data_buffer()->Get<SignedXregType>(0), -1);
+  }
+}
+
+// Zero extension from sew/2 to sew. Test for sew of 16, 32 and 64 bits.
+TEST_F(RiscVCheriotVectorUnaryInstructionsTest, Vzext2_16) {
+  SetSemanticFunction(&Vzext2);
+  UnaryOpTestHelperV<uint16_t, uint8_t>(
+      "Vzext2_16", /*sew*/ 16, instruction_,
+      [](uint8_t vs2) -> uint16_t { return vs2; });
+}
+
+TEST_F(RiscVCheriotVectorUnaryInstructionsTest, Vzext2_32) {
+  SetSemanticFunction(&Vzext2);
+  UnaryOpTestHelperV<uint32_t, uint16_t>(
+      "Vzext2_32", /*sew*/ 32, instruction_,
+      [](uint16_t vs2) -> uint32_t { return vs2; });
+}
+
+TEST_F(RiscVCheriotVectorUnaryInstructionsTest, Vzext2_64) {
+  SetSemanticFunction(&Vzext2);
+  UnaryOpTestHelperV<uint64_t, uint32_t>(
+      "Vzext2_64", /*sew*/ 64, instruction_,
+      [](uint32_t vs2) -> uint64_t { return vs2; });
+}
+
+// Sign extension from sew/2 to sew. Testing for sew of 16, 32 and 64.
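+// Illustrative contrast with the zero-extension cases above (values chosen
+// here, not taken from the fixture): for an input element 0xff, vsext.vf2 at
+// sew = 16 widens to 0xffff (-1), while vzext.vf2 widens the same bits to
+// 0x00ff (255).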
+TEST_F(RiscVCheriotVectorUnaryInstructionsTest, Vsext2_16) {
+  SetSemanticFunction(&Vsext2);
+  UnaryOpTestHelperV<int16_t, int8_t>("Vsext2_16", /*sew*/ 16, instruction_,
+                                      [](int8_t vs2) -> int16_t {
+                                        int16_t res = static_cast<int16_t>(vs2);
+                                        return res;
+                                      });
+}
+
+TEST_F(RiscVCheriotVectorUnaryInstructionsTest, Vsext2_32) {
+  SetSemanticFunction(&Vsext2);
+  UnaryOpTestHelperV<int32_t, int16_t>(
+      "Vsext2_32", /*sew*/ 32, instruction_, [](int16_t vs2) -> int32_t {
+        int32_t res = static_cast<int32_t>(vs2);
+        return res;
+      });
+}
+
+TEST_F(RiscVCheriotVectorUnaryInstructionsTest, Vsext2_64) {
+  SetSemanticFunction(&Vsext2);
+  UnaryOpTestHelperV<int64_t, int32_t>(
+      "Vsext2_64", /*sew*/ 64, instruction_, [](int32_t vs2) -> int64_t {
+        int64_t res = static_cast<int64_t>(vs2);
+        return res;
+      });
+}
+
+// Zero extension from sew/4 to sew. Testing for sew of 32 and 64.
+TEST_F(RiscVCheriotVectorUnaryInstructionsTest, Vzext4_32) {
+  SetSemanticFunction(&Vzext4);
+  UnaryOpTestHelperV<uint32_t, uint8_t>(
+      "Vzext4_32", /*sew*/ 32, instruction_,
+      [](uint8_t vs2) -> uint32_t { return vs2; });
+}
+
+TEST_F(RiscVCheriotVectorUnaryInstructionsTest, Vzext4_64) {
+  SetSemanticFunction(&Vzext4);
+  UnaryOpTestHelperV<uint64_t, uint16_t>(
+      "Vzext4_64", /*sew*/ 64, instruction_,
+      [](uint16_t vs2) -> uint64_t { return vs2; });
+}
+
+// Sign extension from sew/4 to sew. Testing for sew of 32 and 64.
+TEST_F(RiscVCheriotVectorUnaryInstructionsTest, Vsext4_32) {
+  SetSemanticFunction(&Vsext4);
+  UnaryOpTestHelperV<int32_t, int8_t>("Vsext4_32", /*sew*/ 32, instruction_,
+                                      [](int8_t vs2) -> int32_t {
+                                        int32_t res = static_cast<int32_t>(vs2);
+                                        return res;
+                                      });
+}
+
+TEST_F(RiscVCheriotVectorUnaryInstructionsTest, Vsext4_64) {
+  SetSemanticFunction(&Vsext4);
+  UnaryOpTestHelperV<int64_t, int16_t>(
+      "Vsext4_64", /*sew*/ 64, instruction_, [](int16_t vs2) -> int64_t {
+        int64_t res = static_cast<int64_t>(vs2);
+        return res;
+      });
+}
+
+// Zero extension from sew/8 to sew. Testing for sew of 64.
+TEST_F(RiscVCheriotVectorUnaryInstructionsTest, Vzext8_64) {
+  SetSemanticFunction(&Vzext8);
+  UnaryOpTestHelperV<uint64_t, uint8_t>(
+      "Vzext8_64", /*sew*/ 64, instruction_,
+      [](uint8_t vs2) -> uint64_t { return vs2; });
+}
+
+// Sign extension from sew/8 to sew. Testing for sew of 64.
+TEST_F(RiscVCheriotVectorUnaryInstructionsTest, Vsext8_64) {
+  SetSemanticFunction(&Vsext8);
+  UnaryOpTestHelperV<int64_t, int8_t>("Vsext8_64", /*sew*/ 64, instruction_,
+                                      [](int8_t vs2) -> int64_t {
+                                        int64_t res = static_cast<int64_t>(vs2);
+                                        return res;
+                                      });
+}
+
+// Test "set before first mask bit".
+TEST_F(RiscVCheriotVectorUnaryInstructionsTest, Vmsbf) {
+  SetSemanticFunction(&Vmsbf);
+  AppendVectorRegisterOperands({kVs2, kVmask}, {kVd});
+  uint8_t reg_value[kVectorLengthInBytes];
+  // Set vtype to byte vector, and vector lmul to 8.
+  uint32_t vtype = (kSewSettingsByByteSize[1] << 3) | kLmulSettingByLogSize[7];
+  ConfigureVectorUnit(vtype, kVectorLengthInBytes * 8);
+
+  // Clear the reg_value array.
+  std::memset(reg_value, 0, kVectorLengthInBytes);
+  // Set the register values.
+  SetVectorRegisterValues<uint8_t>(
+      {{kVs2Name, reg_value}, {kVmaskName, kAllOnesMask}});
+  // Execute the instruction. With no set bit in vs2 there is no first set
+  // bit, so the result should be all 1s.
+  instruction_->Execute();
+  auto dest_span = vreg_[kVd]->data_buffer()->Get<uint8_t>();
+  for (int i = 0; i < kVectorLengthInBytes; i++) {
+    EXPECT_EQ(dest_span[i], 0b1111'1111) << "Index: " << i;
+  }
+
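+  // Semantics note (per the RVV spec): vmsbf.m sets all active destination
+  // bits strictly before the first active set bit of vs2; vmsif.m also sets
+  // the bit at that position, and vmsof.m sets only that bit. Illustrative
+  // example (chosen here, not from the fixture): vs2 = 0b0000'1000 with an
+  // all-ones mask gives vmsbf = 0b0000'0111, vmsif = 0b0000'1111, and
+  // vmsof = 0b0000'1000.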
+  // Pick a random location 20 times, set that bit in both vs2 and the mask,
+  // and check that the expected mask is produced.
+  for (int i = 0; i < 20; i++) {
+    // Initialize the source and mask registers to fixed bit patterns.
+    SetVectorRegisterValues<uint8_t>(
+        {{kVs2Name, kA5Mask}, {kVmaskName, k5AMask}});
+    // Pick a random bit index to set.
+    uint32_t index = absl::Uniform(absl::IntervalClosedOpen, bitgen_, 0,
+                                   kVectorLengthInBytes * 8);
+    // Compute the byte and bit index to be set.
+    auto byte_index = index >> 3;
+    auto bit_index = index & 0b111;
+    // Set the bit in both the source and mask registers.
+    auto mask_span = vreg_[kVmask]->data_buffer()->Get<uint8_t>();
+    auto src_span = vreg_[kVs2]->data_buffer()->Get<uint8_t>();
+    mask_span[byte_index] |= 1 << bit_index;
+    src_span[byte_index] |= 1 << bit_index;
+    // Execute the instruction and check the resulting mask.
+    instruction_->Execute();
+    dest_span = vreg_[kVd]->data_buffer()->Get<uint8_t>();
+    // Active bits before the byte containing the index should all be set;
+    // the check leaves inactive bits unconstrained.
+    for (int j = 0; j < byte_index; j++) {
+      EXPECT_EQ(dest_span[j],
+                (0b1111'1111 & mask_span[j]) | (dest_span[j] & ~mask_span[j]));
+    }
+    // In the byte containing the index, only the active bits below the index
+    // should be set.
+    EXPECT_EQ(dest_span[byte_index],
+              (((1 << bit_index) - 1) & mask_span[byte_index]) |
+                  (dest_span[byte_index] & ~mask_span[byte_index]));
+    // Active bits after the byte containing the index should be clear.
+    for (int j = byte_index + 1; j < kVectorLengthInBytes; j++) {
+      EXPECT_EQ(dest_span[j], dest_span[j] & ~mask_span[j]);
+    }
+  }
+}
+
+// Test "set only first mask bit".
+TEST_F(RiscVCheriotVectorUnaryInstructionsTest, Vmsof) {
+  SetSemanticFunction(&Vmsof);
+  AppendVectorRegisterOperands({kVs2, kVmask}, {kVd});
+  uint8_t reg_value[kVectorLengthInBytes];
+  // Set vtype to byte vector, and vector lmul to 8.
+  uint32_t vtype = (kSewSettingsByByteSize[1] << 3) | kLmulSettingByLogSize[7];
+  ConfigureVectorUnit(vtype, kVectorLengthInBytes * 8);
+
+  // Clear the reg_value array.
+  std::memset(reg_value, 0, kVectorLengthInBytes);
+  // Set the register values.
+  SetVectorRegisterValues<uint8_t>(
+      {{kVs2Name, reg_value}, {kVmaskName, kAllOnesMask}});
+  // Execute the instruction. With no set bit in vs2 the result should be
+  // all 0s.
+  instruction_->Execute();
+  auto dest_span = vreg_[kVd]->data_buffer()->Get<uint8_t>();
+  for (int i = 0; i < kVectorLengthInBytes; i++) {
+    EXPECT_EQ(dest_span[i], 0) << "Index: " << i;
+  }
+
+  // Pick a random location 20 times, set that bit in both vs2 and the mask,
+  // and check that the expected mask is produced.
+  for (int i = 0; i < 20; i++) {
+    // Initialize the source and mask registers to fixed bit patterns.
+    SetVectorRegisterValues<uint8_t>(
+        {{kVs2Name, kA5Mask}, {kVmaskName, k5AMask}});
+    // Pick a random bit index to set.
+    uint32_t index = absl::Uniform(absl::IntervalClosedOpen, bitgen_, 0,
+                                   kVectorLengthInBytes * 8);
+    // Compute the byte and bit index to be set.
+    auto byte_index = index >> 3;
+    auto bit_index = index & 0b111;
+    // Set the bit in both the source and mask registers.
+    auto mask_span = vreg_[kVmask]->data_buffer()->Get<uint8_t>();
+    auto src_span = vreg_[kVs2]->data_buffer()->Get<uint8_t>();
+    mask_span[byte_index] |= 1 << bit_index;
+    src_span[byte_index] |= 1 << bit_index;
+    // Execute the instruction and check the resulting mask.
+    instruction_->Execute();
+    dest_span = vreg_[kVd]->data_buffer()->Get<uint8_t>();
+    // Active bits before the byte containing the index should be clear.
+    for (int j = 0; j < byte_index; j++) {
+      EXPECT_EQ(dest_span[j], dest_span[j] & ~mask_span[j]);
+    }
+    // In the byte containing the index, only the bit at the index should be
+    // set among the active bits.
+    EXPECT_EQ(dest_span[byte_index],
+              ((1 << bit_index) & mask_span[byte_index]) |
+                  (dest_span[byte_index] & ~mask_span[byte_index]))
+        << " dest: " << std::hex << static_cast<int>(dest_span[byte_index])
+        << " mask: " << static_cast<int>(mask_span[byte_index]);
+    // Active bits after the byte containing the index should be clear.
+    for (int j = byte_index + 1; j < kVectorLengthInBytes; j++) {
+      EXPECT_EQ(dest_span[j], dest_span[j] & ~mask_span[j]);
+    }
+  }
+}
+
+// Test "set including first mask bit".
+TEST_F(RiscVCheriotVectorUnaryInstructionsTest, Vmsif) {
+  SetSemanticFunction(&Vmsif);
+  AppendVectorRegisterOperands({kVs2, kVmask}, {kVd});
+  uint8_t reg_value[kVectorLengthInBytes];
+  // Set vtype to byte vector, and vector lmul to 8.
+  uint32_t vtype = (kSewSettingsByByteSize[1] << 3) | kLmulSettingByLogSize[7];
+  ConfigureVectorUnit(vtype, kVectorLengthInBytes * 8);
+
+  // Clear the reg_value array.
+  std::memset(reg_value, 0, kVectorLengthInBytes);
+  // Set the register values.
+  SetVectorRegisterValues<uint8_t>(
+      {{kVs2Name, reg_value}, {kVmaskName, kAllOnesMask}});
+  // Execute the instruction. With no set bit in vs2, the result should be
+  // all 1s.
+  instruction_->Execute();
+  auto dest_span = vreg_[kVd]->data_buffer()->Get<uint8_t>();
+  for (int i = 0; i < kVectorLengthInBytes; i++) {
+    EXPECT_EQ(dest_span[i], 0b1111'1111) << "Index: " << i;
+  }
+
+  // Pick a random location 20 times, set that bit in both vs2 and the mask,
+  // and check that the expected mask is produced.
+  for (int i = 0; i < 20; i++) {
+    // Initialize the source and mask registers to fixed bit patterns.
+    SetVectorRegisterValues<uint8_t>(
+        {{kVs2Name, kA5Mask}, {kVmaskName, k5AMask}});
+    // Pick a random bit index to set.
+    uint32_t index = absl::Uniform(absl::IntervalClosedOpen, bitgen_, 0,
+                                   kVectorLengthInBytes * 8);
+    // Compute the byte and bit index to be set.
+    auto byte_index = index >> 3;
+    auto bit_index = index & 0b111;
+    // Set the bit in both the source and mask registers.
+    auto mask_span = vreg_[kVmask]->data_buffer()->Get<uint8_t>();
+    auto src_span = vreg_[kVs2]->data_buffer()->Get<uint8_t>();
+    mask_span[byte_index] |= 1 << bit_index;
+    src_span[byte_index] |= 1 << bit_index;
+    // Execute the instruction and check the resulting mask.
+    instruction_->Execute();
+    dest_span = vreg_[kVd]->data_buffer()->Get<uint8_t>();
+    // Active bits before the byte containing the index should all be set;
+    // the check leaves inactive bits unconstrained.
+    for (int j = 0; j < byte_index; j++) {
+      EXPECT_EQ(dest_span[j],
+                (0b1111'1111 & mask_span[j]) | (dest_span[j] & ~mask_span[j]));
+    }
+    // In the byte containing the index, the active bits up to and including
+    // the index should be set.
+    EXPECT_EQ(dest_span[byte_index],
+              (((1 << (bit_index + 1)) - 1) & mask_span[byte_index]) |
+                  (dest_span[byte_index] & ~mask_span[byte_index]));
+    // Active bits after the byte containing the index should be clear.
+    for (int j = byte_index + 1; j < kVectorLengthInBytes; j++) {
+      EXPECT_EQ(dest_span[j], dest_span[j] & ~mask_span[j]);
+    }
+  }
+}
+
+// Helper function for testing Viota instructions.
+template <typename T>
+void TestViota(RiscVCheriotVectorUnaryInstructionsTest *tester,
+               Instruction *inst) {
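+  // Semantics note (per the RVV spec): viota.m writes, to each active
+  // destination element i, the number of set bits of vs2 at active positions
+  // j < i. Illustrative example (chosen here, not from the fixture): with an
+  // all-ones mask and vs2 bits 1,0,1,1,..., the first four results are
+  // 0, 1, 1, 2.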
+  // Set up vector unit.
+  int byte_sew = sizeof(T);
+  uint32_t vtype =
+      (kSewSettingsByByteSize[byte_sew] << 3) | kLmulSettingByLogSize[7];
+  tester->ConfigureVectorUnit(vtype, 1024);
+  int vlen = tester->rv_vector()->vector_length();
+  int num_reg =
+      (vlen * byte_sew + kVectorLengthInBytes - 1) / kVectorLengthInBytes;
+  int num_per_reg = kVectorLengthInBytes / byte_sew;
+
+  // Set up instruction.
+  tester->SetSemanticFunction(&Viota);
+  tester->AppendVectorRegisterOperands({kVs2, kVmask}, {kVd});
+  tester->SetVectorRegisterValues<uint8_t>({{kVmaskName, kE7Mask}});
+  int count = vlen;
+  // Initialize vs2 to random values.
+  auto span = tester->vreg()[kVs2]->data_buffer()->Get<T>();
+  for (int i = 0; i < num_per_reg; i++) {
+    span[i] = tester->RandomValue<T>();
+  }
+
+  // Initialize the destination registers to descending values so that
+  // untouched (inactive) elements can be detected.
+  for (int reg = kVd; reg < kVd + num_reg; reg++) {
+    auto reg_span = tester->vreg()[reg]->data_buffer()->Get<T>();
+    for (int i = 0; i < num_per_reg; i++) {
+      reg_span[i] = static_cast<T>(count--);
+    }
+  }
+
+  // Execute instruction.
+  inst->Execute();
+
+  // Check results.
+  const auto mask_span = tester->vreg()[kVmask]->data_buffer()->Get<uint8_t>();
+  const auto rs2_span = tester->vreg()[kVs2]->data_buffer()->Get<uint8_t>();
+  count = 0;
+  for (int i = 0; i < vlen; i++) {
+    int reg = kVd + i / num_per_reg;
+    int reg_index = i % num_per_reg;
+    auto value = tester->vreg()[reg]->data_buffer()->Get<T>(reg_index);
+    int mask_index = i >> 3;
+    int mask_offset = i & 0b111;
+    int mask_value = (mask_span[mask_index] >> mask_offset) & 0b1;
+    const bool rs2_bit = (rs2_span[mask_index] >> mask_offset) & 0b1;
+    if (mask_value) {
+      // Active elements get the running count of active set bits in vs2.
+      EXPECT_EQ(value, static_cast<T>(count)) << "active index: " << i;
+      if (rs2_bit) {
+        count++;
+      }
+    } else {
+      // Inactive elements keep their initialized value.
+      EXPECT_EQ(value, static_cast<T>(vlen - i)) << "inactive index: " << i;
+    }
+  }
+}
+
+// Test the viota instruction.
+TEST_F(RiscVCheriotVectorUnaryInstructionsTest, Viota8) {
+  TestViota<uint8_t>(this, instruction_);
+}
+
+TEST_F(RiscVCheriotVectorUnaryInstructionsTest, Viota16) {
+  TestViota<uint16_t>(this, instruction_);
+}
+
+TEST_F(RiscVCheriotVectorUnaryInstructionsTest, Viota32) {
+  TestViota<uint32_t>(this, instruction_);
+}
+
+TEST_F(RiscVCheriotVectorUnaryInstructionsTest, Viota64) {
+  TestViota<uint64_t>(this, instruction_);
+}
+
+// Helper function for testing Vid instructions.
+template <typename T>
+void TestVid(RiscVCheriotVectorUnaryInstructionsTest *tester,
+             Instruction *inst) {
+  // Initialize the vector unit. The unit must be configured before the
+  // effective vector length is read.
+  int byte_sew = sizeof(T);
+  uint32_t vtype =
+      (kSewSettingsByByteSize[byte_sew] << 3) | kLmulSettingByLogSize[7];
+  tester->ConfigureVectorUnit(vtype, 1024);
+  int vlen = tester->rv_vector()->vector_length();
+  int num_reg =
+      (vlen * byte_sew + kVectorLengthInBytes - 1) / kVectorLengthInBytes;
+  int num_per_reg = kVectorLengthInBytes / byte_sew;
+
+  // Configure the instruction.
+  tester->SetSemanticFunction(&Vid);
+  tester->AppendVectorRegisterOperands({kVmask}, {kVd});
+  tester->SetVectorRegisterValues<uint8_t>({{kVmaskName, kE7Mask}});
+  // Initialize the destination registers to descending values so that
+  // untouched (inactive) elements can be detected.
+  int count = vlen;
+  for (int reg = kVd; reg < kVd + num_reg; reg++) {
+    auto reg_span = tester->vreg()[reg]->data_buffer()->Get<T>();
+    for (int i = 0; i < num_per_reg; i++) {
+      reg_span[i] = static_cast<T>(count--);
+    }
+  }
+
+  // Execute the instruction.
+  inst->Execute();
+
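+  // Semantics note (per the RVV spec): vid.v writes each active element's own
+  // index to the destination and takes no vs2 source. Inactive elements are
+  // expected below to keep their initialized value (vlen - i).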
+  // Check the results. The mask register is read byte-wise regardless of the
+  // element type T.
+  auto mask_span = tester->vreg()[kVmask]->data_buffer()->Get<uint8_t>();
+  for (int i = 0; i < vlen; i++) {
+    int reg = kVd + i / num_per_reg;
+    int reg_index = i % num_per_reg;
+    auto value = tester->vreg()[reg]->data_buffer()->Get<T>(reg_index);
+    int mask_index = i >> 3;
+    int mask_offset = i & 0b111;
+    int mask_value = (mask_span[mask_index] >> mask_offset) & 0b1;
+    if (mask_value) {
+      EXPECT_EQ(value, static_cast<T>(i)) << "active index: " << i;
+    } else {
+      EXPECT_EQ(value, static_cast<T>(vlen - i)) << "inactive index: " << i;
+    }
+  }
+}
+
+// Test the vid instruction.
+TEST_F(RiscVCheriotVectorUnaryInstructionsTest, Vid8) {
+  TestVid<uint8_t>(this, instruction_);
+}
+
+TEST_F(RiscVCheriotVectorUnaryInstructionsTest, Vid16) {
+  TestVid<uint16_t>(this, instruction_);
+}
+
+TEST_F(RiscVCheriotVectorUnaryInstructionsTest, Vid32) {
+  TestVid<uint32_t>(this, instruction_);
+}
+
+TEST_F(RiscVCheriotVectorUnaryInstructionsTest, Vid64) {
+  TestVid<uint64_t>(this, instruction_);
+}
+}  // namespace