zfh semantic functions: sqrt, convert, compare, classify PiperOrigin-RevId: 755967699 Change-Id: I4ea9bca29de12f2ad93137bd885e0bae14b614a6
diff --git a/riscv/riscv_instruction_helpers.h b/riscv/riscv_instruction_helpers.h index f408d13..8999576 100644 --- a/riscv/riscv_instruction_helpers.h +++ b/riscv/riscv_instruction_helpers.h
@@ -78,7 +78,10 @@ constexpr To kMin = std::numeric_limits<To>::min(); From lhs = generic::GetInstructionSource<From>(instruction, 0); - + using FromUint = typename FPTypeInfo<From>::UIntType; + FromUint lhs_u = *reinterpret_cast<FromUint *>(&lhs); + auto constexpr kExpMask = FPTypeInfo<From>::kExpMask; + auto constexpr kSigMask = FPTypeInfo<From>::kSigMask; uint32_t flags = 0; uint32_t rm = generic::GetInstructionSource<uint32_t>(instruction, 1); // Dynamic rounding mode will get rounding mode from the global state. @@ -91,10 +94,13 @@ rm = *rv_fp->GetRoundingMode(); } To value = 0; - if (FPTypeInfo<From>::IsNaN(lhs)) { + if (FPTypeInfo<From>::IsNaN(lhs) || lhs_u == FPTypeInfo<From>::kPosInf) { value = std::numeric_limits<To>::max(); flags = *FPExceptions::kInvalidOp; - } else if (lhs == 0.0) { + } else if (lhs_u == FPTypeInfo<From>::kNegInf) { + value = std::numeric_limits<To>::min(); + flags = *FPExceptions::kInvalidOp; + } else if ((lhs_u & (kExpMask | kSigMask)) == 0) { // lhs == 0.0 value = 0; } else { // static_cast<>() doesn't necessarily round, so will have to force
diff --git a/riscv/riscv_zfh.isa b/riscv/riscv_zfh.isa index f906733..abd0fa8 100644 --- a/riscv/riscv_zfh.isa +++ b/riscv/riscv_zfh.isa
@@ -144,6 +144,26 @@ resources: {next_pc, frs1, frs2 : frd[0..]}, semfunc: "&RiscVZfhFmax", disasm: "fmax.h", "%frd, %frs1, %frs2"; + fsqrt_h{: frs1, rm : frd, fflags}, + resources: {next_pc, frs1 : frd[0..]}, + semfunc: "&RiscVZfhFsqrt", + disasm: "fsqrt.h", "%frd, %frs1"; + fcvt_hw{: rs1, rm : frd, fflags}, + resources: {next_pc, frs1 : frd[0..]}, + semfunc: "&RV32::RiscVZfhCvtHw", + disasm: "fcvt.h.w", "%frd, %rs1"; + fcvt_wh{: frs1, rm : rd, fflags}, + resources: {next_pc, frs1 : rd[0..]}, + semfunc: "&RV32::RiscVZfhCvtWh", + disasm: "fcvt.w.h", "%rd, %frs1"; + fcvt_hwu{: rs1, rm : frd, fflags}, + resources: {next_pc, frs1 : frd[0..]}, + semfunc: "&RV32::RiscVZfhCvtHwu", + disasm: "fcvt.h.wu", "%frd, %rs1"; + fcvt_wuh{: frs1, rm : rd, fflags}, + resources: {next_pc, frs1 : rd[0..]}, + semfunc: "&RV32::RiscVZfhCvtWuh", + disasm: "fcvt.wu.h", "%rd, %frs1"; fsgnj_h{: frs1, frs2 : frd }, resources: {next_pc, frs1, frs2 : frd[0..]}, semfunc: "&RiscVZfhFsgnj", @@ -156,5 +176,21 @@ resources: {next_pc, frs1, frs2 : frd[0..]}, semfunc: "&RiscVZfhFsgnjx", disasm: "fsgnjnx.h", "%frd, %frs1, %frs2"; + fcmpeq_h{: frs1, frs2 : rd, fflags}, + resources: { next_pc, frs1, frs2 : rd[0..]}, + semfunc: "&RV32::RiscVZfhFcmpeq", + disasm: "feq.h", "%rd, %frs1, %frs2"; + fcmplt_h{: frs1, frs2 : rd, fflags}, + resources: { next_pc, frs1, frs2 : rd[0..]}, + semfunc: "&RV32::RiscVZfhFcmplt", + disasm: "flt.h", "%rd, %frs1, %frs2"; + fcmple_h{: frs1, frs2 : rd, fflags}, + resources: { next_pc, frs1, frs2 : rd[0..]}, + semfunc: "&RV32::RiscVZfhFcmple", + disasm: "fle.h", "%rd, %frs1, %frs2"; + fclass_h{: frs1 : rd}, + resources: { next_pc, frs1 : rd[0..]}, + semfunc: "&RV32::RiscVZfhFclass", + disasm: "fclass.h", "%rd, %frs1"; } }
diff --git a/riscv/riscv_zfh_instructions.cc b/riscv/riscv_zfh_instructions.cc index 4391530..928bf24 100644 --- a/riscv/riscv_zfh_instructions.cc +++ b/riscv/riscv_zfh_instructions.cc
@@ -67,22 +67,22 @@ template <> struct DataTypeRegValue<int32_t> { - using type = RVXRegister::ValueType; + using type = RV32Register::ValueType; }; template <> struct DataTypeRegValue<uint32_t> { - using type = RVXRegister::ValueType; + using type = RV32Register::ValueType; }; template <> struct DataTypeRegValue<int64_t> { - using type = RVXRegister::ValueType; + using type = RV64Register::ValueType; }; template <> struct DataTypeRegValue<uint64_t> { - using type = RVXRegister::ValueType; + using type = RV64Register::ValueType; }; // Convert from half precision to single or double precision. @@ -216,6 +216,54 @@ return ConvertDoubleToHalfFP(input_value, rm, fflags); } +// Generic helper function enabling HalfFP operations in native datatypes. +template <typename Result, typename Argument> +void RiscVZfhUnaryHelper( + const Instruction *instruction, + std::function<Result(Argument, FPRoundingMode, uint32_t &)> operation) { + uint32_t fflags = 0; + RiscVFPState *rv_fp = + static_cast<RiscVState *>(instruction->state())->rv_fp(); + int rm_value = generic::GetInstructionSource<int>(instruction, 1); + + // If the rounding mode is dynamic, read it from the current state. + if (rm_value == *FPRoundingMode::kDynamic) { + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + return; + } + rm_value = *(rv_fp->GetRoundingMode()); + } + FPRoundingMode rm = static_cast<FPRoundingMode>(rm_value); + RiscVCsrDestinationOperand *fflags_dest = + static_cast<RiscVCsrDestinationOperand *>(instruction->Destination(1)); + bool arguments_contain_snan = false; + RiscVUnaryFloatNaNBoxOp<RVFpRegister::ValueType, RVFpRegister::ValueType, + Result, Argument>( + instruction, + [rv_fp, rm, &fflags, &operation, + &arguments_contain_snan](Argument a) -> Result { + Result result; + if (FPTypeInfo<Argument>::IsSNaN(a)) { + arguments_contain_snan = true; + } + if (zfh_internal::UseHostFlagsForConversion()) { + result = operation(a, rm, fflags); + } else { + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface(), rm); + result = operation(a, rm, fflags); + } + return result; + }); + if (!zfh_internal::UseHostFlagsForConversion()) { + fflags_dest->GetRiscVCsr()->SetBits(fflags); + } + if (arguments_contain_snan) { + fflags_dest->GetRiscVCsr()->SetBits(*FPExceptions::kInvalidOp); + } +} + +// Generic helper function enabling HalfFP operations in native datatypes. template <typename Argument, typename IntermediateType> void RiscVZfhBinaryHelper( const Instruction *instruction, @@ -302,6 +350,109 @@ }); } +// Convert from half precision to integer. +void RiscVZfhCvtWh(const Instruction *instruction) { + RiscVConvertFloatWithFflagsOp<typename RV32Register::ValueType, HalfFP, + int32_t>(instruction); +} + +// Convert from integer to half precision. +void RiscVZfhCvtHw(const Instruction *instruction) { + RiscVZfhCvtHelper<HalfFP, int32_t>( + instruction, + [](int32_t a, FPRoundingMode rm, uint32_t &fflags) -> HalfFP { + float input_float = static_cast<float>(a); + return ConvertToHalfFP(input_float, rm, fflags); + }); +} + +// Convert from unsigned integer to half precision. +void RiscVZfhCvtHwu(const Instruction *instruction) { + RiscVZfhCvtHelper<HalfFP, uint32_t>( + instruction, + [](uint32_t a, FPRoundingMode rm, uint32_t &fflags) -> HalfFP { + float input_float = static_cast<float>(a); + return ConvertToHalfFP(input_float, rm, fflags); + }); +} + +// Convert from half precision to unsigned integer. +void RiscVZfhCvtWuh(const Instruction *instruction) { + RiscVConvertFloatWithFflagsOp<typename RV32Register::ValueType, HalfFP, + uint32_t>(instruction); +} + +// Compare two half precision values for equality. +void RiscVZfhFcmpeq(const Instruction *instruction) { + RiscVBinaryFloatNaNBoxOp<RVFpRegister::ValueType, uint64_t, HalfFP>( + instruction, [](HalfFP a, HalfFP b) -> uint64_t { + float a_f; + float b_f; + uint32_t unused_fflags = 0; + if (FPTypeInfo<HalfFP>::IsSNaN(a)) { + a_f = absl::bit_cast<float>(FPTypeInfo<float>::kPosInf | 1); + } else { + a_f = ConvertFromHalfFP<float>(a, unused_fflags); + } + if (FPTypeInfo<HalfFP>::IsSNaN(b)) { + b_f = absl::bit_cast<float>(FPTypeInfo<float>::kPosInf | 1); + } else { + b_f = ConvertFromHalfFP<float>(b, unused_fflags); + } + return a_f == b_f ? 1 : 0; + }); +} + +// Compare two half precision values for less than. +void RiscVZfhFcmplt(const Instruction *instruction) { + RiscVBinaryFloatNaNBoxOp<RVFpRegister::ValueType, uint64_t, HalfFP>( + instruction, [](HalfFP a, HalfFP b) -> uint64_t { + float a_f; + float b_f; + uint32_t unused_fflags = 0; + if (FPTypeInfo<HalfFP>::IsNaN(a)) { + a_f = absl::bit_cast<float>(FPTypeInfo<float>::kPosInf | 1); + } else { + a_f = ConvertFromHalfFP<float>(a, unused_fflags); + } + if (FPTypeInfo<HalfFP>::IsNaN(b)) { + b_f = absl::bit_cast<float>(FPTypeInfo<float>::kPosInf | 1); + } else { + b_f = ConvertFromHalfFP<float>(b, unused_fflags); + } + return a_f < b_f ? 1 : 0; + }); +} + +// Compare two half precision values for less than or equal to. +void RiscVZfhFcmple(const Instruction *instruction) { + RiscVBinaryFloatNaNBoxOp<RVFpRegister::ValueType, uint64_t, HalfFP>( + instruction, [](HalfFP a, HalfFP b) -> uint64_t { + float a_f; + float b_f; + uint32_t unused_fflags = 0; + if (FPTypeInfo<HalfFP>::IsNaN(a)) { + a_f = absl::bit_cast<float>(FPTypeInfo<float>::kPosInf | 1); + } else { + a_f = ConvertFromHalfFP<float>(a, unused_fflags); + } + if (FPTypeInfo<HalfFP>::IsNaN(b)) { + b_f = absl::bit_cast<float>(FPTypeInfo<float>::kPosInf | 1); + } else { + b_f = ConvertFromHalfFP<float>(b, unused_fflags); + } + return a_f <= b_f ? 1 : 0; + }); +} + +// Classify a half precision value. +void RiscVZfhFclass(const Instruction *instruction) { + RiscVUnaryOp<RV32Register, uint32_t, HalfFP>( + instruction, [](HalfFP a) -> uint32_t { + return static_cast<uint32_t>(ClassifyFP(a)); + }); +} + } // namespace RV32 namespace RV64 { @@ -438,6 +589,19 @@ }); } +// Calculate the square root of a half precision value. Do the operation in +// single precision and then convert back to half precision. +void RiscVZfhFsqrt(const Instruction *instruction) { + RiscVZfhUnaryHelper<HalfFP, HalfFP>( + instruction, [](HalfFP a, FPRoundingMode rm, uint32_t &fflags) -> HalfFP { + float input_f = ConvertFromHalfFP<float>(a, fflags); + if (!std::isnan(input_f) && input_f < 0) { + fflags |= static_cast<uint32_t>(FPExceptions::kInvalidOp); + } + return ConvertToHalfFP(std::sqrt(input_f), rm, fflags); + }); +} + // The result is the exponent and significand of the first source with the // sign bit of the second source. void RiscVZfhFsgnj(const Instruction *instruction) {
diff --git a/riscv/riscv_zfh_instructions.h b/riscv/riscv_zfh_instructions.h index 27a717d..08e33cf 100644 --- a/riscv/riscv_zfh_instructions.h +++ b/riscv/riscv_zfh_instructions.h
@@ -34,6 +34,69 @@ // Destination Operands: // rd: Integer Register void RiscVZfhFMvxh(const Instruction *instruction); + +// Source Operands: +// frs1: Float Register +// rm: Literal Operand (rounding mode) +// Destination Operands: +// rd: Integer Register +// fflags: Accrued Exception Flags field in FCSR +void RiscVZfhCvtWh(const Instruction *instruction); + +// Source Operands: +// rs1: Integer Register +// rm: Literal Operand (rounding mode) +// Destination Operands: +// frd: Float Register +// fflags: Accrued Exception Flags field in FCSR +void RiscVZfhCvtHw(const Instruction *instruction); + +// Source Operands: +// rs1: Integer Register +// rm: Literal Operand (rounding mode) +// Destination Operands: +// frd: Float Register +// fflags: Accrued Exception Flags field in FCSR +void RiscVZfhCvtHwu(const Instruction *instruction); + +// Source Operands: +// frs1: Float Register +// rm: Literal Operand (rounding mode) +// Destination Operands: +// rd: Integer Register +// fflags: Accrued Exception Flags field in FCSR +void RiscVZfhCvtWuh(const Instruction *instruction); + +// Source Operands: +// frs1: Float Register +// frs2: Float Register +// Destination Operands: +// rd: Integer Register +// fflags: Accrued Exception Flags field in FCSR +void RiscVZfhFcmpeq(const Instruction *instruction); + +// Source Operands: +// frs1: Float Register +// frs2: Float Register +// Destination Operands: +// rd: Integer Register +// fflags: Accrued Exception Flags field in FCSR +void RiscVZfhFcmplt(const Instruction *instruction); + +// Source Operands: +// frs1: Float Register +// frs2: Float Register +// Destination Operands: +// rd: Integer Register +// fflags: Accrued Exception Flags field in FCSR +void RiscVZfhFcmple(const Instruction *instruction); + +// Source Operands: +// frs1: Float Register +// Destination Operands: +// rd: Integer Register +void RiscVZfhFclass(const Instruction *instruction); + } // namespace RV32 namespace RV64 { @@ -173,6 +236,14 @@ // frd: Float Register void RiscVZfhFsgnjx(const Instruction *instruction); +// Source Operands: +// frs1: Float Register +// rm: Literal Operand (rounding mode) +// Destination Operands: +// frd: Float Register +// fflags: Accrued Exception Flags field in FCSR +void RiscVZfhFsqrt(const Instruction *instruction); + } // namespace riscv } // namespace sim } // namespace mpact
diff --git a/riscv/test/BUILD b/riscv/test/BUILD index fd65642..436a112 100644 --- a/riscv/test/BUILD +++ b/riscv/test/BUILD
@@ -226,6 +226,10 @@ "//riscv:riscv_state", "//riscv:riscv_zfh_instructions", "@com_google_absl//absl/base", + "@com_google_absl//absl/random:distributions", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", "@com_google_googletest//:gtest_main", "@com_google_mpact-sim//mpact/sim/generic:core", "@com_google_mpact-sim//mpact/sim/generic:instruction",
diff --git a/riscv/test/riscv_zfh_instructions_test.cc b/riscv/test/riscv_zfh_instructions_test.cc index 7f8aa01..372ddbd 100644 --- a/riscv/test/riscv_zfh_instructions_test.cc +++ b/riscv/test/riscv_zfh_instructions_test.cc
@@ -21,12 +21,19 @@ #include <cassert> #include <cmath> #include <cstdint> +#include <functional> #include <ios> +#include <limits> #include <string> #include <tuple> +#include <type_traits> #include <vector> #include "absl/base/casts.h" +#include "absl/random/distributions.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" #include "googlemock/include/gmock/gmock.h" #include "mpact/sim/generic/data_buffer.h" #include "mpact/sim/generic/immediate_operand.h" @@ -64,15 +71,28 @@ using ::mpact::sim::riscv::RiscVZfhFsgnj; using ::mpact::sim::riscv::RiscVZfhFsgnjn; using ::mpact::sim::riscv::RiscVZfhFsgnjx; +using ::mpact::sim::riscv::RiscVZfhFsqrt; using ::mpact::sim::riscv::RiscVZfhFsub; using ::mpact::sim::riscv::RV32Register; using ::mpact::sim::riscv::RV64Register; using ::mpact::sim::riscv::RVFpRegister; -using ::mpact::sim::riscv::RVXRegister; +using ::mpact::sim::riscv::ScopedFPRoundingMode; using ::mpact::sim::riscv::RV32::RiscVILhu; +using ::mpact::sim::riscv::RV32::RiscVZfhCvtHw; +using ::mpact::sim::riscv::RV32::RiscVZfhCvtHwu; +using ::mpact::sim::riscv::RV32::RiscVZfhCvtWh; +using ::mpact::sim::riscv::RV32::RiscVZfhCvtWuh; +using ::mpact::sim::riscv::RV32::RiscVZfhFclass; +using ::mpact::sim::riscv::RV32::RiscVZfhFcmpeq; +using ::mpact::sim::riscv::RV32::RiscVZfhFcmple; +using ::mpact::sim::riscv::RV32::RiscVZfhFcmplt; using ::mpact::sim::riscv::RV32::RiscVZfhFMvxh; + +using ::mpact::sim::riscv::test::FloatingPointToString; +using ::mpact::sim::riscv::test::FPCompare; using ::mpact::sim::riscv::test::FpConversionsTestHelper; using ::mpact::sim::riscv::test::FPTypeInfo; +using ::mpact::sim::riscv::test::kTestValueLength; using ::mpact::sim::riscv::test::RiscVFPInstructionTestBase; const int kRoundingModeRoundToNearest = @@ -89,6 +109,13 @@ template <typename ReturnType, typename IntegerRegister> ReturnType LoadHalfHelper(typename IntegerRegister::ValueType, int16_t); + + template <typename DestRegisterType, typename LhsRegisterType, typename R, + typename LHS> + void UnaryOpWithFflagsMixedTestHelper( + absl::string_view name, Instruction *inst, + absl::Span<const absl::string_view> reg_prefixes, int delta_position, + std::function<std::tuple<R, uint32_t>(LHS, uint32_t)> operation); }; template <typename AddressType, typename ValueType> @@ -126,6 +153,101 @@ return observed_val; } +// Helper for unary instructions that go between floats and integers. +template <typename DestRegisterType, typename LhsRegisterType, typename R, + typename LHS> +void RVZfhInstructionTestBase::UnaryOpWithFflagsMixedTestHelper( + absl::string_view name, Instruction *inst, + absl::Span<const absl::string_view> reg_prefixes, int delta_position, + std::function<std::tuple<R, uint32_t>(LHS, uint32_t)> operation) { + using LhsInt = typename FPTypeInfo<LHS>::IntType; + using RInt = typename FPTypeInfo<R>::IntType; + LHS lhs_values[kTestValueLength]; + auto lhs_span = absl::Span<LHS>(lhs_values); + const std::string kR1Name = absl::StrCat(reg_prefixes[0], 1); + const std::string kRdName = absl::StrCat(reg_prefixes[1], 5); + // This is used for the rounding mode operand. + const std::string kRmName = absl::StrCat("x", 10); + if (kR1Name[0] == 'x') { + AppendRegisterOperands<RV32Register>({kR1Name}, {}); + } else { + AppendRegisterOperands<RVFpRegister>({kR1Name}, {}); + } + if (kRdName[0] == 'x') { + AppendRegisterOperands<RV32Register>({}, {kRdName}); + } else { + AppendRegisterOperands<RVFpRegister>({}, {kRdName}); + } + AppendRegisterOperands<RV32Register>({kRmName}, {}); + auto *flag_op = rv_fp_->fflags()->CreateSetDestinationOperand(0, "fflags"); + instruction_->AppendDestination(flag_op); + if constexpr (std::is_integral<LHS>::value) { + for (auto &lhs : lhs_span) { + lhs = absl::Uniform(absl::IntervalClosed, bitgen_, + std::numeric_limits<LHS>::min(), + std::numeric_limits<LHS>::max()); + } + *reinterpret_cast<LHS *>(&lhs_span[0]) = 0; + *reinterpret_cast<LHS *>(&lhs_span[1]) = 1; + *reinterpret_cast<LHS *>(&lhs_span[2]) = 2; + *reinterpret_cast<LHS *>(&lhs_span[3]) = 4; + *reinterpret_cast<LHS *>(&lhs_span[4]) = 8; + *reinterpret_cast<LHS *>(&lhs_span[5]) = 16; + *reinterpret_cast<LHS *>(&lhs_span[6]) = 1024; + *reinterpret_cast<LHS *>(&lhs_span[7]) = 65000; + } else { + FillArrayWithRandomFPValues<LHS>(lhs_span); + *reinterpret_cast<LhsInt *>(&lhs_span[0]) = FPTypeInfo<LHS>::kQNaN; + *reinterpret_cast<LhsInt *>(&lhs_span[1]) = FPTypeInfo<LHS>::kSNaN; + *reinterpret_cast<LhsInt *>(&lhs_span[2]) = FPTypeInfo<LHS>::kPosInf; + *reinterpret_cast<LhsInt *>(&lhs_span[3]) = FPTypeInfo<LHS>::kNegInf; + *reinterpret_cast<LhsInt *>(&lhs_span[4]) = FPTypeInfo<LHS>::kPosZero; + *reinterpret_cast<LhsInt *>(&lhs_span[5]) = FPTypeInfo<LHS>::kNegZero; + *reinterpret_cast<LhsInt *>(&lhs_span[6]) = FPTypeInfo<LHS>::kPosDenorm; + *reinterpret_cast<LhsInt *>(&lhs_span[7]) = FPTypeInfo<LHS>::kNegDenorm; + } + for (int i = 0; i < kTestValueLength; i++) { + if constexpr (std::is_integral<LHS>::value) { + SetRegisterValues<LHS, LhsRegisterType>({{kR1Name, lhs_span[i]}}); + } else { + SetNaNBoxedRegisterValues<LHS, LhsRegisterType>({{kR1Name, lhs_span[i]}}); + } + + for (int rm : {0, 1, 2, 3, 4}) { + rv_fp_->SetRoundingMode(static_cast<FPRoundingMode>(rm)); + rv_fp_->fflags()->Write(static_cast<uint32_t>(0)); + SetRegisterValues<int, RV32Register>({{kRmName, rm}, {}}); + SetRegisterValues<typename DestRegisterType::ValueType, DestRegisterType>( + {{kRdName, 0}}); + + inst->Execute(nullptr); + auto instruction_fflags = rv_fp_->fflags()->GetUint32(); + + R op_val; + uint32_t test_operation_fflags; + { + ScopedFPRoundingMode scoped_rm(rv_fp_->host_fp_interface(), rm); + std::tie(op_val, test_operation_fflags) = operation(lhs_span[i], rm); + } + + auto reg_val = state_->GetRegister<DestRegisterType>(kRdName) + .first->data_buffer() + ->template Get<R>(0); + FPCompare<R>( + op_val, reg_val, delta_position, + absl::StrCat(name, " ", i, ": ", + FloatingPointToString<LHS>(lhs_span[i]), " rm: ", rm)); + LhsInt lhs_uint = absl::bit_cast<LhsInt>(lhs_span[i]); + RInt op_val_uint = absl::bit_cast<RInt>(op_val); + EXPECT_EQ(test_operation_fflags, instruction_fflags) + << name << "(" << FloatingPointToString<LHS>(lhs_span[i]) << ") " + << std::hex << name << "(0x" << lhs_uint + << ") == " << FloatingPointToString<R>(op_val) << std::hex << " 0x" + << op_val_uint << " rm: " << rm; + } + } +} + // A source operand that is used to set the rounding mode. This is less // confusing than using a register source operand since the rounding mode is // part of the instruction encoding. @@ -767,6 +889,7 @@ }); } +// Test sign injection for half precision values. TEST_F(RV32ZfhInstructionTest, RiscVZfhFsgnj) { SetSemanticFunction(&RiscVZfhFsgnj); BinaryOpWithFflagsFPTestHelper<HalfFP, HalfFP, HalfFP>( @@ -781,6 +904,7 @@ }); } +// Test sign injection for half precision values with the opposite sign bit. TEST_F(RV32ZfhInstructionTest, RiscVZfhFsgnjn) { SetSemanticFunction(&RiscVZfhFsgnjn); BinaryOpWithFflagsFPTestHelper<HalfFP, HalfFP, HalfFP>( @@ -795,6 +919,7 @@ }); } +// Test sign injection for half precision values with the xor sign bit.. TEST_F(RV32ZfhInstructionTest, RiscVZfhFsgnjx) { SetSemanticFunction(&RiscVZfhFsgnjx); BinaryOpWithFflagsFPTestHelper<HalfFP, HalfFP, HalfFP>( @@ -809,6 +934,218 @@ }); } +// Test square root for half precision values. +TEST_F(RV32ZfhInstructionTest, RiscVZfhFsqrt) { + SetSemanticFunction(&RiscVZfhFsqrt); + UnaryOpWithFflagsFPTestHelper<HalfFP, HalfFP>( + "fsqrt.h", instruction_, {"f", "f"}, 32, + [](HalfFP input_half, int rm) -> std::tuple<HalfFP, uint32_t> { + uint32_t fflags = 0; + double input_double_f = FpConversionsTestHelper(input_half) + .ConvertWithFlags<double>(fflags); + HalfFP result; + if (FPTypeInfo<HalfFP>::IsNaN(input_half)) { + result.value = FPTypeInfo<HalfFP>::kCanonicalNaN; + if (!FPTypeInfo<HalfFP>::IsQNaN(input_half)) { + fflags |= static_cast<uint32_t>( + mpact::sim::riscv::FPExceptions::kInvalidOp); + } + } else if (std::isinf(input_double_f) && input_double_f > 0) { + result.value = FPTypeInfo<HalfFP>::kPosInf; + } else if (input_double_f < 0) { + result.value = FPTypeInfo<HalfFP>::kCanonicalNaN; + fflags |= static_cast<uint32_t>( + mpact::sim::riscv::FPExceptions::kInvalidOp); + } else { + result = FpConversionsTestHelper(std::sqrt(input_double_f)) + .ConvertWithFlags<HalfFP>( + fflags, static_cast<FPRoundingMode>(rm)); + } + return std::make_tuple(result, fflags); + }); +} + +// Test conversion from signed 32 bit integer to half precision. +TEST_F(RV32ZfhInstructionTest, RiscVZfhCvtHw) { + SetSemanticFunction(&RiscVZfhCvtHw); + UnaryOpWithFflagsMixedTestHelper<RVFpRegister, RV32Register, HalfFP, int32_t>( + "fcvt.h.w", instruction_, {"x", "f"}, 32, + [](int32_t input_int, int rm) -> std::tuple<HalfFP, uint32_t> { + uint32_t fflags = 0; + HalfFP result = FpConversionsTestHelper(static_cast<double>(input_int)) + .ConvertWithFlags<HalfFP>( + fflags, static_cast<FPRoundingMode>(rm)); + return std::make_tuple(result, fflags); + }); +} + +// Test conversion from half precision to signed 32 bit integer. +TEST_F(RV32ZfhInstructionTest, RiscVZfhCvtWh) { + SetSemanticFunction(&RiscVZfhCvtWh); + UnaryOpWithFflagsMixedTestHelper<RV32Register, RVFpRegister, int32_t, HalfFP>( + "fcvt.w.h", instruction_, {"f", "x"}, 32, + [this](HalfFP input, int rm) -> std::tuple<int32_t, uint32_t> { + uint32_t fflags = 0; + double input_double = + FpConversionsTestHelper(input).ConvertWithFlags<double>(fflags); + const int32_t val = + RoundToInteger<double, int32_t>(input_double, rm, fflags); + return std::make_tuple(val, fflags); + }); +} + +// Test conversion from unsigned 32 bit integer to half precision. +TEST_F(RV32ZfhInstructionTest, RiscVZfhCvtHwu) { + SetSemanticFunction(&RiscVZfhCvtHwu); + UnaryOpWithFflagsMixedTestHelper<RVFpRegister, RV32Register, HalfFP, + uint32_t>( + "fcvt.h.wu", instruction_, {"x", "f"}, 32, + [](uint32_t input_int, int rm) -> std::tuple<HalfFP, uint32_t> { + uint32_t fflags = 0; + HalfFP result = FpConversionsTestHelper(static_cast<double>(input_int)) + .ConvertWithFlags<HalfFP>( + fflags, static_cast<FPRoundingMode>(rm)); + return std::make_tuple(result, fflags); + }); +} + +// Test conversion from half precision to unsigned 32 bit integer. +TEST_F(RV32ZfhInstructionTest, RiscVZfhCvtWuh) { + SetSemanticFunction(&RiscVZfhCvtWuh); + UnaryOpWithFflagsMixedTestHelper<RV32Register, RVFpRegister, uint32_t, + HalfFP>( + "fcvt.wu.h", instruction_, {"f", "x"}, 32, + [this](HalfFP input, int rm) -> std::tuple<uint32_t, uint32_t> { + uint32_t fflags = 0; + double input_double = + FpConversionsTestHelper(input).ConvertWithFlags<double>(fflags); + const uint32_t val = + RoundToInteger<double, uint32_t>(input_double, rm, fflags); + return std::make_tuple(val, fflags); + }); +} + +// Test equality comparison for half precision values. +TEST_F(RV32ZfhInstructionTest, RiscVZfhFcmpeq) { + SetSemanticFunction(&RiscVZfhFcmpeq); + BinaryOpWithFflagsFPTestHelper<uint64_t, HalfFP, HalfFP>( + "feq.h", instruction_, {"f", "f", "x"}, 32, + [](HalfFP a, HalfFP b) -> std::tuple<uint64_t, uint32_t> { + uint32_t fflags = 0; + double a_f = + FpConversionsTestHelper(a).ConvertWithFlags<double>(fflags); + double b_f = + FpConversionsTestHelper(b).ConvertWithFlags<double>(fflags); + uint64_t result = a_f == b_f ? 1 : 0; + if (std::isnan(a_f) || std::isnan(b_f)) { + result = 0; + } + return std::make_tuple(result, fflags); + }); +} + +// Test less than comparison for half precision values. +TEST_F(RV32ZfhInstructionTest, RiscVZfhFcmplt) { + SetSemanticFunction(&RiscVZfhFcmplt); + BinaryOpWithFflagsFPTestHelper<uint64_t, HalfFP, HalfFP>( + "flt.h", instruction_, {"f", "f", "x"}, 32, + [](HalfFP a, HalfFP b) -> std::tuple<uint64_t, uint32_t> { + uint32_t fflags = 0; + double a_f = + FpConversionsTestHelper(a).ConvertWithFlags<double>(fflags); + double b_f = + FpConversionsTestHelper(b).ConvertWithFlags<double>(fflags); + uint64_t result = a_f < b_f ? 1 : 0; + if (std::isnan(a_f) || std::isnan(b_f)) { + result = 0; + // LT is a signaling comparison, so the invalid operation flag is + // set. + fflags |= static_cast<uint32_t>( + mpact::sim::riscv::FPExceptions::kInvalidOp); + } + return std::make_tuple(result, fflags); + }); +} + +// Test less than or equal to comparison for half precision values. +TEST_F(RV32ZfhInstructionTest, RiscVZfhFcmple) { + SetSemanticFunction(&RiscVZfhFcmple); + BinaryOpWithFflagsFPTestHelper<uint64_t, HalfFP, HalfFP>( + "fle.h", instruction_, {"f", "f", "x"}, 32, + [](HalfFP a, HalfFP b) -> std::tuple<uint64_t, uint32_t> { + uint32_t fflags = 0; + double a_f = + FpConversionsTestHelper(a).ConvertWithFlags<double>(fflags); + double b_f = + FpConversionsTestHelper(b).ConvertWithFlags<double>(fflags); + uint64_t result = a_f <= b_f ? 1 : 0; + if (std::isnan(a_f) || std::isnan(b_f)) { + result = 0; + // LE is a signaling comparison, so the invalid operation flag is + // set. + fflags |= static_cast<uint32_t>( + mpact::sim::riscv::FPExceptions::kInvalidOp); + } + return std::make_tuple(result, fflags); + }); +} + +// Test classification of half precision values. +TEST_F(RV32ZfhInstructionTest, RiscVZfhFclass) { + SetSemanticFunction(&RiscVZfhFclass); + UnaryOpWithFflagsMixedTestHelper<RV32Register, RVFpRegister, uint32_t, + HalfFP>( + "fclass.h", instruction_, {"f", "x"}, 32, + [](HalfFP input, int rm) -> std::tuple<uint32_t, uint32_t> { + uint16_t sign_mask = + ~(FPTypeInfo<HalfFP>::kExpMask | FPTypeInfo<HalfFP>::kSigMask); + uint16_t sign = input.value & sign_mask; + int shift = -1; + + switch (input.value & FPTypeInfo<HalfFP>::kExpMask) { + case 0: + if (input.value & FPTypeInfo<HalfFP>::kSigMask) { + if (sign) { + shift = 2; // Negative subnormal + } else { + shift = 5; // Positive subnormal + } + } else { + if (sign) { + shift = 3; // Negative zero + } else { + shift = 4; // Positive zero + } + } + break; + case FPTypeInfo<HalfFP>::kExpMask: + if (input.value & FPTypeInfo<HalfFP>::kSigMask) { + if (FPTypeInfo<HalfFP>::IsQNaN(input)) { + shift = 9; // Quiet NaN + } else { + shift = 8; // Signaling NaN + } + } else { // Inf + if (sign) { + shift = 0; // Negative infinity + } else { + shift = 7; // Positive infinity + } + } + break; + default: + if (sign) { + shift = 1; // Negative normal + } else { + shift = 6; // Positive normal + } + break; + } + EXPECT_GE(shift, 0) << "The test didn't set the expected result."; + return std::make_tuple(1 << shift, 0); + }); +} + class RV64ZfhInstructionTest : public RVZfhInstructionTestBase {}; // Move half precision from a float register to an integer register. The IEEE754