zfh: fused multiply add semantic functions PiperOrigin-RevId: 757923526 Change-Id: I870093d42c020626bdde16e06a029842e655ffd0
diff --git a/riscv/riscv_zfh_instructions.cc b/riscv/riscv_zfh_instructions.cc index f4b8ad2..484173e 100644 --- a/riscv/riscv_zfh_instructions.cc +++ b/riscv/riscv_zfh_instructions.cc
@@ -445,6 +445,59 @@ fflags_dest->GetRiscVCsr()->SetBits(fflags); } +// Generic helper function enabling HalfFP operations in native datatypes. +template <typename Argument, typename IntermediateType> +void RiscVZfhTernaryHelper( + const Instruction *instruction, + std::function<IntermediateType(IntermediateType, IntermediateType, + IntermediateType)> + operation) { + RiscVCsrDestinationOperand *fflags_dest = + static_cast<RiscVCsrDestinationOperand *>(instruction->Destination(1)); + uint32_t fflags = 0; + // RiscVTernaryFloatNaNBoxOp will handle the register NaN boxed reads and + // write. The operation is in a native datatype so we will handle conversions + // from/to half precision float values before and after the operation. + RiscVTernaryFloatNaNBoxOp<RVFpRegister::ValueType, HalfFP, Argument>( + instruction, + [instruction, &operation, &fflags](Argument a, Argument b, + Argument c) -> HalfFP { + RiscVFPState *rv_fp = + static_cast<RiscVState *>(instruction->state())->rv_fp(); + int rm_value = generic::GetInstructionSource<int>(instruction, 3); + // If the rounding mode is dynamic, read it from the current state. + if (rm_value == *FPRoundingMode::kDynamic) { + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + } + rm_value = *(rv_fp->GetRoundingMode()); + } + FPRoundingMode rm = static_cast<FPRoundingMode>(rm_value); + IntermediateType argument1 = + ConvertFromHalfFP<IntermediateType>(a, fflags); + IntermediateType argument2 = + ConvertFromHalfFP<IntermediateType>(b, fflags); + IntermediateType argument3 = + ConvertFromHalfFP<IntermediateType>(c, fflags); + IntermediateType result; + { + ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface(), rm_value); + result = operation(argument1, argument2, argument3); + } + // To get the correct fflags we need a combination of host flags from + // the operation and the conversion flags. Copy the host flags and merge + // them with the conversion flags. + fflags |= rv_fp->fflags()->GetUint32(); + { + // ConvertToHalfFP pollutes the host flags so we need to create a + // ScopedFPRoundingMode to restore the host flags. + ScopedFPRoundingMode scoped_rm(rv_fp->host_fp_interface(), rm_value); + return ConvertToHalfFP(result, rm, fflags); + } + }); + fflags_dest->GetRiscVCsr()->SetBits(fflags); +} + } // namespace namespace RV32 { @@ -743,6 +796,38 @@ }); } +// Fused multiply add in half precision. Do the operation in single precision. +// (rs1 * rs2) + rs3 +void RiscVZfhFmadd(const Instruction *instruction) { + RiscVZfhTernaryHelper<HalfFP, float>( + instruction, + [](float a, float b, float c) -> float { return fma(a, b, c); }); +} + +// Fused multiply add in half precision. Do the operation in single precision. +// (rs1 * rs2) - rs3 +void RiscVZfhFmsub(const Instruction *instruction) { + RiscVZfhTernaryHelper<HalfFP, float>( + instruction, + [](float a, float b, float c) -> float { return fma(a, b, -c); }); +} + +// Fused multiply add in half precision. Do the operation in single precision. +// -(rs1 * rs2) - rs3 +void RiscVZfhFnmadd(const Instruction *instruction) { + RiscVZfhTernaryHelper<HalfFP, float>( + instruction, + [](float a, float b, float c) -> float { return fma(-a, b, -c); }); +} + +// Fused multiply add in half precision. Do the operation in single precision. +// -(rs1 * rs2) + rs3 +void RiscVZfhFnmsub(const Instruction *instruction) { + RiscVZfhTernaryHelper<HalfFP, float>( + instruction, + [](float a, float b, float c) -> float { return fma(-a, b, c); }); +} + // TODO(b/409778536): Factor out generic unimplemented instruction semantic // function. void RV32VUnimplementedInstruction(const Instruction *instruction) {
diff --git a/riscv/riscv_zfh_instructions.h b/riscv/riscv_zfh_instructions.h index 6337a31..20d5261 100644 --- a/riscv/riscv_zfh_instructions.h +++ b/riscv/riscv_zfh_instructions.h
@@ -234,6 +234,46 @@ // fflags: Accrued Exception Flags field in FCSR void RiscVZfhFsqrt(const Instruction *instruction); +// Source Operands: +// frs1: Float Register +// frs2: Float Register +// frs3: Float Register +// rm: Literal Operand (rounding mode) +// Destination Operands: +// frd: Float Register +// fflags: Accrued Exception Flags field in FCSR +void RiscVZfhFmadd(const Instruction *instruction); + +// Source Operands: +// frs1: Float Register +// frs2: Float Register +// frs3: Float Register +// rm: Literal Operand (rounding mode) +// Destination Operands: +// frd: Float Register +// fflags: Accrued Exception Flags field in FCSR +void RiscVZfhFmsub(const Instruction *instruction); + +// Source Operands: +// frs1: Float Register +// frs2: Float Register +// frs3: Float Register +// rm: Literal Operand (rounding mode) +// Destination Operands: +// frd: Float Register +// fflags: Accrued Exception Flags field in FCSR +void RiscVZfhFnmadd(const Instruction *instruction); + +// Source Operands: +// frs1: Float Register +// frs2: Float Register +// frs3: Float Register +// rm: Literal Operand (rounding mode) +// Destination Operands: +// frd: Float Register +// fflags: Accrued Exception Flags field in FCSR +void RiscVZfhFnmsub(const Instruction *instruction); + } // namespace riscv } // namespace sim } // namespace mpact
diff --git a/riscv/test/riscv_zfh_instructions_test.cc b/riscv/test/riscv_zfh_instructions_test.cc index 65e01ce..dee6408 100644 --- a/riscv/test/riscv_zfh_instructions_test.cc +++ b/riscv/test/riscv_zfh_instructions_test.cc
@@ -56,6 +56,7 @@ using ::mpact::sim::generic::HalfFP; using ::mpact::sim::generic::ImmediateOperand; using ::mpact::sim::generic::Instruction; +using ::mpact::sim::riscv::FPExceptions; using ::mpact::sim::riscv::FPRoundingMode; using ::mpact::sim::riscv::RiscVZfhCvtDh; using ::mpact::sim::riscv::RiscVZfhCvtHd; @@ -64,10 +65,14 @@ using ::mpact::sim::riscv::RiscVZfhFadd; using ::mpact::sim::riscv::RiscVZfhFdiv; using ::mpact::sim::riscv::RiscVZfhFlhChild; +using ::mpact::sim::riscv::RiscVZfhFmadd; using ::mpact::sim::riscv::RiscVZfhFmax; using ::mpact::sim::riscv::RiscVZfhFmin; +using ::mpact::sim::riscv::RiscVZfhFmsub; using ::mpact::sim::riscv::RiscVZfhFmul; using ::mpact::sim::riscv::RiscVZfhFMvhx; +using ::mpact::sim::riscv::RiscVZfhFnmadd; +using ::mpact::sim::riscv::RiscVZfhFnmsub; using ::mpact::sim::riscv::RiscVZfhFsgnj; using ::mpact::sim::riscv::RiscVZfhFsgnjn; using ::mpact::sim::riscv::RiscVZfhFsgnjx; @@ -77,6 +82,7 @@ using ::mpact::sim::riscv::RV64Register; using ::mpact::sim::riscv::RVFpRegister; using ::mpact::sim::riscv::ScopedFPRoundingMode; +using ::mpact::sim::riscv::ScopedFPStatus; using ::mpact::sim::riscv::RV32::RiscVILhu; using ::mpact::sim::riscv::RV32::RiscVZfhCvtHw; using ::mpact::sim::riscv::RV32::RiscVZfhCvtHwu; @@ -102,6 +108,43 @@ const int kRoundingModeRoundDown = static_cast<int>(FPRoundingMode::kRoundDown); const int kRoundingModeRoundUp = static_cast<int>(FPRoundingMode::kRoundUp); +// A source operand that is used to set the rounding mode. This is less +// confusing than using a register source operand since the rounding mode is +// part of the instruction encoding. +class TestRoundingModeSourceOperand + : public mpact::sim::generic::SourceOperandInterface { + public: + explicit TestRoundingModeSourceOperand() + : rounding_mode_(FPRoundingMode::kRoundToNearest) {} + + void SetRoundingMode(FPRoundingMode rounding_mode) { + rounding_mode_ = rounding_mode; + } + + bool AsBool(int) override { return static_cast<bool>(rounding_mode_); } + int8_t AsInt8(int) override { return static_cast<int8_t>(rounding_mode_); } + uint8_t AsUint8(int) override { return static_cast<uint8_t>(rounding_mode_); } + int16_t AsInt16(int) override { return static_cast<int16_t>(rounding_mode_); } + uint16_t AsUint16(int) override { + return static_cast<uint16_t>(rounding_mode_); + } + int32_t AsInt32(int) override { return static_cast<int32_t>(rounding_mode_); } + uint32_t AsUint32(int) override { + return static_cast<uint32_t>(rounding_mode_); + } + int64_t AsInt64(int) override { return static_cast<int64_t>(rounding_mode_); } + uint64_t AsUint64(int) override { + return static_cast<uint64_t>(rounding_mode_); + } + + std::vector<int> shape() const override { return {1}; } + std::string AsString() const override { return std::string(""); } + std::any GetObject() const override { return std::any(); } + + protected: + FPRoundingMode rounding_mode_; +}; + class RVZfhInstructionTestBase : public RiscVFPInstructionTestBase { protected: template <typename AddressType, typename ValueType> @@ -116,6 +159,14 @@ absl::string_view name, Instruction *inst, absl::Span<const absl::string_view> reg_prefixes, int delta_position, std::function<std::tuple<R, uint32_t>(LHS, uint32_t)> operation); + + template <typename R, typename LHS, typename MHS, typename RHS> + void TernaryOpWithFflagsFPTestHelper( + absl::string_view name, Instruction *inst, + absl::Span<const absl::string_view> reg_prefixes, int delta_position, + std::function<std::tuple<R, uint32_t>(LHS, MHS, RHS)> operation); + + uint32_t GetOperationFlags(std::function<void(void)> operation); }; template <typename AddressType, typename ValueType> @@ -153,6 +204,25 @@ return observed_val; } +// Helper for statements that set the host FPU flags. This is used to +// determine the flags set by an operation. Enables targeted capture of flags +// that are set by an operation. The Simulation flags are restored to the +// initial state after the operation is executed. +uint32_t RVZfhInstructionTestBase::GetOperationFlags( + std::function<void(void)> operation) { + uint32_t initial_fflags = rv_fp_->fflags()->GetUint32(); + uint32_t delta_fflags = 0; + rv_fp_->fflags()->Write(static_cast<uint32_t>(0)); + { + ScopedFPStatus sfpstatus(rv_fp_->host_fp_interface(), + rv_fp_->GetRoundingMode()); + operation(); + } + delta_fflags = rv_fp_->fflags()->GetUint32(); + rv_fp_->fflags()->Write(initial_fflags); + return delta_fflags; +} + // Helper for unary instructions that go between floats and integers. template <typename DestRegisterType, typename LhsRegisterType, typename R, typename LHS> @@ -248,42 +318,127 @@ } } -// A source operand that is used to set the rounding mode. This is less -// confusing than using a register source operand since the rounding mode is -// part of the instruction encoding. -class TestRoundingModeSourceOperand - : public mpact::sim::generic::SourceOperandInterface { - public: - explicit TestRoundingModeSourceOperand() - : rounding_mode_(FPRoundingMode::kRoundToNearest) {} - - void SetRoundingMode(FPRoundingMode rounding_mode) { - rounding_mode_ = rounding_mode; +template <typename R, typename LHS, typename MHS, typename RHS> +void RVZfhInstructionTestBase::TernaryOpWithFflagsFPTestHelper( + absl::string_view name, Instruction *inst, + absl::Span<const absl::string_view> reg_prefixes, int delta_position, + std::function<std::tuple<R, uint32_t>(LHS, MHS, RHS)> operation) { + using LhsRegisterType = RVFpRegister; + using MhsRegisterType = RVFpRegister; + using RhsRegisterType = RVFpRegister; + using DestRegisterType = RVFpRegister; + LHS lhs_values[kTestValueLength]; + MHS mhs_values[kTestValueLength]; + RHS rhs_values[kTestValueLength]; + auto lhs_span = absl::Span<LHS>(lhs_values); + auto mhs_span = absl::Span<MHS>(mhs_values); + auto rhs_span = absl::Span<RHS>(rhs_values); + const std::string kR1Name = absl::StrCat(reg_prefixes[0], 1); + const std::string kR2Name = absl::StrCat(reg_prefixes[1], 2); + const std::string kR3Name = absl::StrCat(reg_prefixes[2], 3); + const std::string kRdName = absl::StrCat(reg_prefixes[3], 5); + if (kR1Name[0] == 'x') { + AppendRegisterOperands<RV32Register>({kR1Name}, {}); + } else { + AppendRegisterOperands<RVFpRegister>({kR1Name}, {}); } - - bool AsBool(int) override { return static_cast<bool>(rounding_mode_); } - int8_t AsInt8(int) override { return static_cast<int8_t>(rounding_mode_); } - uint8_t AsUint8(int) override { return static_cast<uint8_t>(rounding_mode_); } - int16_t AsInt16(int) override { return static_cast<int16_t>(rounding_mode_); } - uint16_t AsUint16(int) override { - return static_cast<uint16_t>(rounding_mode_); + if (kR2Name[0] == 'x') { + AppendRegisterOperands<RV32Register>({kR2Name}, {}); + } else { + AppendRegisterOperands<RVFpRegister>({kR2Name}, {}); } - int32_t AsInt32(int) override { return static_cast<int32_t>(rounding_mode_); } - uint32_t AsUint32(int) override { - return static_cast<uint32_t>(rounding_mode_); + if (kR3Name[0] == 'x') { + AppendRegisterOperands<RV32Register>({kR3Name}, {}); + } else { + AppendRegisterOperands<RVFpRegister>({kR3Name}, {}); } - int64_t AsInt64(int) override { return static_cast<int64_t>(rounding_mode_); } - uint64_t AsUint64(int) override { - return static_cast<uint64_t>(rounding_mode_); + if (kRdName[0] == 'x') { + AppendRegisterOperands<RV32Register>({}, {kRdName}); + } else { + AppendRegisterOperands<RVFpRegister>({}, {kRdName}); } + TestRoundingModeSourceOperand *rm_source_operand = + new TestRoundingModeSourceOperand(); + instruction_->AppendSource(rm_source_operand); + auto *flag_op = rv_fp_->fflags()->CreateSetDestinationOperand(0, "fflags"); + instruction_->AppendDestination(flag_op); + FillArrayWithRandomFPValues<LHS>(lhs_span); + FillArrayWithRandomFPValues<MHS>(mhs_span); + FillArrayWithRandomFPValues<RHS>(rhs_span); + using LhsInt = typename FPTypeInfo<LHS>::IntType; + *reinterpret_cast<LhsInt *>(&lhs_span[0]) = FPTypeInfo<LHS>::kQNaN; + *reinterpret_cast<LhsInt *>(&lhs_span[1]) = FPTypeInfo<LHS>::kSNaN; + *reinterpret_cast<LhsInt *>(&lhs_span[2]) = FPTypeInfo<LHS>::kPosInf; + *reinterpret_cast<LhsInt *>(&lhs_span[3]) = FPTypeInfo<LHS>::kNegInf; + *reinterpret_cast<LhsInt *>(&lhs_span[4]) = FPTypeInfo<LHS>::kPosZero; + *reinterpret_cast<LhsInt *>(&lhs_span[5]) = FPTypeInfo<LHS>::kNegZero; + *reinterpret_cast<LhsInt *>(&lhs_span[6]) = FPTypeInfo<LHS>::kPosDenorm; + *reinterpret_cast<LhsInt *>(&lhs_span[7]) = FPTypeInfo<LHS>::kNegDenorm; + using MhsInt = typename FPTypeInfo<MHS>::IntType; + *reinterpret_cast<MhsInt *>(&mhs_span[8 + 0]) = FPTypeInfo<MHS>::kQNaN; + *reinterpret_cast<MhsInt *>(&mhs_span[8 + 1]) = FPTypeInfo<MHS>::kSNaN; + *reinterpret_cast<MhsInt *>(&mhs_span[8 + 2]) = FPTypeInfo<MHS>::kPosInf; + *reinterpret_cast<MhsInt *>(&mhs_span[8 + 3]) = FPTypeInfo<MHS>::kNegInf; + *reinterpret_cast<MhsInt *>(&mhs_span[8 + 4]) = FPTypeInfo<MHS>::kPosZero; + *reinterpret_cast<MhsInt *>(&mhs_span[8 + 5]) = FPTypeInfo<MHS>::kNegZero; + *reinterpret_cast<MhsInt *>(&mhs_span[8 + 6]) = FPTypeInfo<MHS>::kPosDenorm; + *reinterpret_cast<MhsInt *>(&mhs_span[8 + 7]) = FPTypeInfo<MHS>::kNegDenorm; + using RhsInt = typename FPTypeInfo<RHS>::IntType; + *reinterpret_cast<RhsInt *>(&rhs_span[16 + 0]) = FPTypeInfo<RHS>::kQNaN; + *reinterpret_cast<RhsInt *>(&rhs_span[16 + 1]) = FPTypeInfo<RHS>::kSNaN; + *reinterpret_cast<RhsInt *>(&rhs_span[16 + 2]) = FPTypeInfo<RHS>::kPosInf; + *reinterpret_cast<RhsInt *>(&rhs_span[16 + 3]) = FPTypeInfo<RHS>::kNegInf; + *reinterpret_cast<RhsInt *>(&rhs_span[16 + 4]) = FPTypeInfo<RHS>::kPosZero; + *reinterpret_cast<RhsInt *>(&rhs_span[16 + 5]) = FPTypeInfo<RHS>::kNegZero; + *reinterpret_cast<RhsInt *>(&rhs_span[16 + 6]) = FPTypeInfo<RHS>::kPosDenorm; + *reinterpret_cast<RhsInt *>(&rhs_span[16 + 7]) = FPTypeInfo<RHS>::kNegDenorm; + for (int i = 0; i < kTestValueLength; i++) { + SetNaNBoxedRegisterValues<LHS, LhsRegisterType>({{kR1Name, lhs_span[i]}}); + SetNaNBoxedRegisterValues<MHS, MhsRegisterType>({{kR2Name, mhs_span[i]}}); + SetNaNBoxedRegisterValues<RHS, RhsRegisterType>({{kR3Name, rhs_span[i]}}); - std::vector<int> shape() const override { return {1}; } - std::string AsString() const override { return std::string(""); } - std::any GetObject() const override { return std::any(); } + for (int rm : {0, 1, 2, 3, 4}) { + rv_fp_->SetRoundingMode(static_cast<FPRoundingMode>(rm)); + rm_source_operand->SetRoundingMode(static_cast<FPRoundingMode>(rm)); + rv_fp_->fflags()->Write(static_cast<uint32_t>(0)); + SetRegisterValues<DestRegisterType::ValueType, DestRegisterType>( + {{kRdName, 0}}); - protected: - FPRoundingMode rounding_mode_; -}; + inst->Execute(nullptr); + // Get the fflags for the instruction execution. + auto instruction_fflags = rv_fp_->fflags()->GetUint32(); + rv_fp_->fflags()->Write(static_cast<uint32_t>(0)); + R op_val; + uint32_t test_operation_fflags; + { + ScopedFPRoundingMode scoped_rm(rv_fp_->host_fp_interface(), rm); + std::tie(op_val, test_operation_fflags) = + operation(lhs_span[i], mhs_span[i], rhs_span[i]); + } + auto reg_val = state_->GetRegister<DestRegisterType>(kRdName) + .first->data_buffer() + ->template Get<R>(0); + FPCompare<R>( + op_val, reg_val, delta_position, + absl::StrCat(name, " ", i, " (rm=", rm, + ") : ", FloatingPointToString<LHS>(lhs_span[i]), "(", + absl::StrFormat("%#x", lhs_span[i].value), ") ", + FloatingPointToString<MHS>(mhs_span[i]), "(", + absl::StrFormat("%#x", mhs_span[i].value), ") ", + FloatingPointToString<RHS>(rhs_span[i]), "(", + absl::StrFormat("%#x", rhs_span[i].value), ") ")); + + EXPECT_EQ(test_operation_fflags, instruction_fflags) + << absl::StrCat(name, " ", i, " (rm=", rm, + ") : ", FloatingPointToString<LHS>(lhs_span[i]), "(", + absl::StrFormat("%#x", lhs_span[i].value), ") ", + FloatingPointToString<MHS>(mhs_span[i]), "(", + absl::StrFormat("%#x", mhs_span[i].value), ") ", + FloatingPointToString<RHS>(rhs_span[i]), "(", + absl::StrFormat("%#x", rhs_span[i].value), ") "); + } + } +} class RV32ZfhInstructionTest : public RVZfhInstructionTestBase { protected: @@ -1147,6 +1302,146 @@ }); } +// Test fused multiply add for half precision values. +TEST_F(RV32ZfhInstructionTest, RiscVZfhFmadd) { + SetSemanticFunction(&RiscVZfhFmadd); + TernaryOpWithFflagsFPTestHelper<HalfFP, HalfFP, HalfFP, HalfFP>( + "fmadd.h", instruction_, {"f", "f", "f", "f"}, 10, + [this](HalfFP a, HalfFP b, HalfFP c) -> std::tuple<HalfFP, uint32_t> { + FPRoundingMode rm = rv_fp_->GetRoundingMode(); + uint32_t fflags = 0; + double a_f = + FpConversionsTestHelper(a).ConvertWithFlags<double>(fflags); + double b_f = + FpConversionsTestHelper(b).ConvertWithFlags<double>(fflags); + double c_f = + FpConversionsTestHelper(c).ConvertWithFlags<double>(fflags); + // Don't collect any host flags from the product operation. + double product_f = a_f * b_f; + HalfFP result; + if ((std::isinf(a_f) && b_f == 0) || (std::isinf(b_f) && a_f == 0)) { + fflags |= *FPExceptions::kInvalidOp; + result.value = FPTypeInfo<HalfFP>::kCanonicalNaN; + } else if (std::isinf(product_f) && std::isinf(c_f) && + (std::signbit(product_f) != std::signbit(c_f))) { + fflags |= *FPExceptions::kInvalidOp; + result.value = FPTypeInfo<HalfFP>::kCanonicalNaN; + } else { + double result_f; + fflags |= GetOperationFlags( + [&result_f, &product_f, &c_f] { result_f = product_f + c_f; }); + result = FpConversionsTestHelper(result_f).ConvertWithFlags<HalfFP>( + fflags, rm); + } + return std::make_tuple(result, fflags); + }); +} + +// Test fused multiply subtract for half precision values. +TEST_F(RV32ZfhInstructionTest, RiscVZfhFmsub) { + SetSemanticFunction(&RiscVZfhFmsub); + TernaryOpWithFflagsFPTestHelper<HalfFP, HalfFP, HalfFP, HalfFP>( + "fmsub.h", instruction_, {"f", "f", "f", "f"}, 10, + [this](HalfFP a, HalfFP b, HalfFP c) -> std::tuple<HalfFP, uint32_t> { + FPRoundingMode rm = rv_fp_->GetRoundingMode(); + uint32_t fflags = 0; + double a_f = + FpConversionsTestHelper(a).ConvertWithFlags<double>(fflags); + double b_f = + FpConversionsTestHelper(b).ConvertWithFlags<double>(fflags); + double c_f = + FpConversionsTestHelper(c).ConvertWithFlags<double>(fflags); + // Don't collect any host flags from the product operation. + double product_f = a_f * b_f; + HalfFP result; + if ((std::isinf(a_f) && b_f == 0) || (std::isinf(b_f) && a_f == 0)) { + fflags |= *FPExceptions::kInvalidOp; + result.value = FPTypeInfo<HalfFP>::kCanonicalNaN; + } else if (std::isinf(product_f) && std::isinf(c_f) && + (std::signbit(product_f) == std::signbit(c_f))) { + fflags |= *FPExceptions::kInvalidOp; + result.value = FPTypeInfo<HalfFP>::kCanonicalNaN; + } else { + double result_f; + fflags |= GetOperationFlags( + [&result_f, &product_f, &c_f] { result_f = product_f - c_f; }); + result = FpConversionsTestHelper(result_f).ConvertWithFlags<HalfFP>( + fflags, rm); + } + return std::make_tuple(result, fflags); + }); +} + +// Test negative fused multiply add for half precision values. +TEST_F(RV32ZfhInstructionTest, RiscVZfhFnmadd) { + SetSemanticFunction(&RiscVZfhFnmadd); + TernaryOpWithFflagsFPTestHelper<HalfFP, HalfFP, HalfFP, HalfFP>( + "fnmadd.h", instruction_, {"f", "f", "f", "f"}, 10, + [this](HalfFP a, HalfFP b, HalfFP c) -> std::tuple<HalfFP, uint32_t> { + FPRoundingMode rm = rv_fp_->GetRoundingMode(); + uint32_t fflags = 0; + double a_f = + FpConversionsTestHelper(a).ConvertWithFlags<double>(fflags); + double b_f = + FpConversionsTestHelper(b).ConvertWithFlags<double>(fflags); + double c_f = + FpConversionsTestHelper(c).ConvertWithFlags<double>(fflags); + // Don't collect any host flags from the product operation. + double product_f = -(a_f * b_f); + HalfFP result; + if ((std::isinf(a_f) && b_f == 0) || (std::isinf(b_f) && a_f == 0)) { + fflags |= *FPExceptions::kInvalidOp; + result.value = FPTypeInfo<HalfFP>::kCanonicalNaN; + } else if (std::isinf(product_f) && std::isinf(c_f) && + (std::signbit(product_f) == std::signbit(c_f))) { + fflags |= *FPExceptions::kInvalidOp; + result.value = FPTypeInfo<HalfFP>::kCanonicalNaN; + } else { + double result_f; + fflags |= GetOperationFlags( + [&result_f, &product_f, &c_f] { result_f = product_f - c_f; }); + result = FpConversionsTestHelper(result_f).ConvertWithFlags<HalfFP>( + fflags, rm); + } + return std::make_tuple(result, fflags); + }); +} + +// Test negative fused multiply subtract for half precision values. +TEST_F(RV32ZfhInstructionTest, RiscVZfhFnmsub) { + SetSemanticFunction(&RiscVZfhFnmsub); + TernaryOpWithFflagsFPTestHelper<HalfFP, HalfFP, HalfFP, HalfFP>( + "fnmsub.h", instruction_, {"f", "f", "f", "f"}, 10, + [this](HalfFP a, HalfFP b, HalfFP c) -> std::tuple<HalfFP, uint32_t> { + FPRoundingMode rm = rv_fp_->GetRoundingMode(); + uint32_t fflags = 0; + double a_f = + FpConversionsTestHelper(a).ConvertWithFlags<double>(fflags); + double b_f = + FpConversionsTestHelper(b).ConvertWithFlags<double>(fflags); + double c_f = + FpConversionsTestHelper(c).ConvertWithFlags<double>(fflags); + // Don't collect any host flags from the product operation. + double product_f = -(a_f * b_f); + HalfFP result; + if ((std::isinf(a_f) && b_f == 0) || (std::isinf(b_f) && a_f == 0)) { + fflags |= *FPExceptions::kInvalidOp; + result.value = FPTypeInfo<HalfFP>::kCanonicalNaN; + } else if (std::isinf(product_f) && std::isinf(c_f) && + (std::signbit(product_f) != std::signbit(c_f))) { + fflags |= *FPExceptions::kInvalidOp; + result.value = FPTypeInfo<HalfFP>::kCanonicalNaN; + } else { + double result_f; + fflags |= GetOperationFlags( + [&result_f, &product_f, &c_f] { result_f = product_f + c_f; }); + result = FpConversionsTestHelper(result_f).ConvertWithFlags<HalfFP>( + fflags, rm); + } + return std::make_tuple(result, fflags); + }); +} + class RV64ZfhInstructionTest : public RVZfhInstructionTestBase {}; // Move half precision from a float register to an integer register. The IEEE754