zfh: add fused multiply-add semantic functions (fmadd.h, fmsub.h, fnmadd.h, fnmsub.h)

PiperOrigin-RevId: 757923526
Change-Id: I870093d42c020626bdde16e06a029842e655ffd0
diff --git a/riscv/riscv_zfh_instructions.cc b/riscv/riscv_zfh_instructions.cc
index f4b8ad2..484173e 100644
--- a/riscv/riscv_zfh_instructions.cc
+++ b/riscv/riscv_zfh_instructions.cc
@@ -445,6 +445,59 @@
   fflags_dest->GetRiscVCsr()->SetBits(fflags);
 }
 
+// Ternary helper: runs a three-operand HalfFP operation in IntermediateType.
+template <typename Argument, typename IntermediateType>
+void RiscVZfhTernaryHelper(
+    const Instruction *instruction,
+    std::function<IntermediateType(IntermediateType, IntermediateType,
+                                   IntermediateType)>
+        operation) {
+  RiscVCsrDestinationOperand *fflags_dest =
+      static_cast<RiscVCsrDestinationOperand *>(instruction->Destination(1));
+  uint32_t fflags = 0;
+  // RiscVTernaryFloatNaNBoxOp handles the NaN boxed register reads and the
+  // write. The operation runs in a native datatype, so the lambda converts
+  // from/to half precision float values before and after the operation.
+  RiscVTernaryFloatNaNBoxOp<RVFpRegister::ValueType, HalfFP, Argument>(
+      instruction,
+      [instruction, &operation, &fflags](Argument a, Argument b,
+                                         Argument c) -> HalfFP {
+        RiscVFPState *rv_fp =
+            static_cast<RiscVState *>(instruction->state())->rv_fp();
+        int rm_value = generic::GetInstructionSource<int>(instruction, 3);
+        // Resolve a dynamic rounding mode (source 3) from the FP state.
+        if (rm_value == *FPRoundingMode::kDynamic) {
+          if (!rv_fp->rounding_mode_valid()) {
+            LOG(ERROR) << "Invalid rounding mode";
+          }
+          rm_value = *(rv_fp->GetRoundingMode());
+        }
+        FPRoundingMode rm = static_cast<FPRoundingMode>(rm_value);
+        IntermediateType argument1 =
+            ConvertFromHalfFP<IntermediateType>(a, fflags);
+        IntermediateType argument2 =
+            ConvertFromHalfFP<IntermediateType>(b, fflags);
+        IntermediateType argument3 =
+            ConvertFromHalfFP<IntermediateType>(c, fflags);
+        IntermediateType result;
+        {
+          ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface(), rm_value);
+          result = operation(argument1, argument2, argument3);
+        }
+        // To get the correct fflags we need both the flags from the operation
+        // and the conversion flags. Read rv_fp->fflags() after the scoped
+        // status block closes and merge it into the conversion flags.
+        fflags |= rv_fp->fflags()->GetUint32();
+        {
+          // ConvertToHalfFP pollutes the host flags, so wrap it in a
+          // ScopedFPRoundingMode to restore the host flags afterwards.
+          ScopedFPRoundingMode scoped_rm(rv_fp->host_fp_interface(), rm_value);
+          return ConvertToHalfFP(result, rm, fflags);
+        }
+      });
+  fflags_dest->GetRiscVCsr()->SetBits(fflags);
+}
+
 }  // namespace
 
 namespace RV32 {
@@ -743,6 +796,38 @@
       });
 }
 
+// fmadd.h: half precision fused multiply add, evaluated in single precision.
+// Computes (rs1 * rs2) + rs3.
+void RiscVZfhFmadd(const Instruction *instruction) {
+  RiscVZfhTernaryHelper<HalfFP, float>(
+      instruction,
+      [](float a, float b, float c) -> float { return fma(a, b, c); });
+}
+
+// fmsub.h: half precision fused multiply subtract, evaluated in single
+// precision. Computes (rs1 * rs2) - rs3 by negating the addend passed to fma.
+void RiscVZfhFmsub(const Instruction *instruction) {
+  RiscVZfhTernaryHelper<HalfFP, float>(
+      instruction,
+      [](float a, float b, float c) -> float { return fma(a, b, -c); });
+}
+
+// fnmadd.h: negated half precision fused multiply add, evaluated in single
+// precision. Computes -(rs1 * rs2) - rs3 by negating rs1 and rs3 before fma.
+void RiscVZfhFnmadd(const Instruction *instruction) {
+  RiscVZfhTernaryHelper<HalfFP, float>(
+      instruction,
+      [](float a, float b, float c) -> float { return fma(-a, b, -c); });
+}
+
+// fnmsub.h: negated half precision fused multiply subtract, evaluated in
+// single precision. Computes -(rs1 * rs2) + rs3 by negating rs1 before fma.
+void RiscVZfhFnmsub(const Instruction *instruction) {
+  RiscVZfhTernaryHelper<HalfFP, float>(
+      instruction,
+      [](float a, float b, float c) -> float { return fma(-a, b, c); });
+}
+
 // TODO(b/409778536): Factor out generic unimplemented instruction semantic
 //                    function.
 void RV32VUnimplementedInstruction(const Instruction *instruction) {
diff --git a/riscv/riscv_zfh_instructions.h b/riscv/riscv_zfh_instructions.h
index 6337a31..20d5261 100644
--- a/riscv/riscv_zfh_instructions.h
+++ b/riscv/riscv_zfh_instructions.h
@@ -234,6 +234,46 @@
 //   fflags: Accrued Exception Flags field in FCSR
 void RiscVZfhFsqrt(const Instruction *instruction);
 
+// Source Operands (frd = (frs1 * frs2) + frs3):
+//   frs1: Float Register (first factor)
+//   frs2: Float Register (second factor)
+//   frs3: Float Register (addend)
+//   rm: Literal Operand (rounding mode)
+// Destination Operands:
+//   frd: Float Register (half precision result)
+//   fflags: Accrued Exception Flags field in FCSR
+void RiscVZfhFmadd(const Instruction *instruction);
+
+// Source Operands (frd = (frs1 * frs2) - frs3):
+//   frs1: Float Register (first factor)
+//   frs2: Float Register (second factor)
+//   frs3: Float Register (subtrahend)
+//   rm: Literal Operand (rounding mode)
+// Destination Operands:
+//   frd: Float Register (half precision result)
+//   fflags: Accrued Exception Flags field in FCSR
+void RiscVZfhFmsub(const Instruction *instruction);
+
+// Source Operands (frd = -(frs1 * frs2) - frs3):
+//   frs1: Float Register (first factor)
+//   frs2: Float Register (second factor)
+//   frs3: Float Register (subtrahend)
+//   rm: Literal Operand (rounding mode)
+// Destination Operands:
+//   frd: Float Register (half precision result)
+//   fflags: Accrued Exception Flags field in FCSR
+void RiscVZfhFnmadd(const Instruction *instruction);
+
+// Source Operands (frd = -(frs1 * frs2) + frs3):
+//   frs1: Float Register (first factor)
+//   frs2: Float Register (second factor)
+//   frs3: Float Register (addend)
+//   rm: Literal Operand (rounding mode)
+// Destination Operands:
+//   frd: Float Register (half precision result)
+//   fflags: Accrued Exception Flags field in FCSR
+void RiscVZfhFnmsub(const Instruction *instruction);
+
 }  // namespace riscv
 }  // namespace sim
 }  // namespace mpact
diff --git a/riscv/test/riscv_zfh_instructions_test.cc b/riscv/test/riscv_zfh_instructions_test.cc
index 65e01ce..dee6408 100644
--- a/riscv/test/riscv_zfh_instructions_test.cc
+++ b/riscv/test/riscv_zfh_instructions_test.cc
@@ -56,6 +56,7 @@
 using ::mpact::sim::generic::HalfFP;
 using ::mpact::sim::generic::ImmediateOperand;
 using ::mpact::sim::generic::Instruction;
+using ::mpact::sim::riscv::FPExceptions;
 using ::mpact::sim::riscv::FPRoundingMode;
 using ::mpact::sim::riscv::RiscVZfhCvtDh;
 using ::mpact::sim::riscv::RiscVZfhCvtHd;
@@ -64,10 +65,14 @@
 using ::mpact::sim::riscv::RiscVZfhFadd;
 using ::mpact::sim::riscv::RiscVZfhFdiv;
 using ::mpact::sim::riscv::RiscVZfhFlhChild;
+using ::mpact::sim::riscv::RiscVZfhFmadd;
 using ::mpact::sim::riscv::RiscVZfhFmax;
 using ::mpact::sim::riscv::RiscVZfhFmin;
+using ::mpact::sim::riscv::RiscVZfhFmsub;
 using ::mpact::sim::riscv::RiscVZfhFmul;
 using ::mpact::sim::riscv::RiscVZfhFMvhx;
+using ::mpact::sim::riscv::RiscVZfhFnmadd;
+using ::mpact::sim::riscv::RiscVZfhFnmsub;
 using ::mpact::sim::riscv::RiscVZfhFsgnj;
 using ::mpact::sim::riscv::RiscVZfhFsgnjn;
 using ::mpact::sim::riscv::RiscVZfhFsgnjx;
@@ -77,6 +82,7 @@
 using ::mpact::sim::riscv::RV64Register;
 using ::mpact::sim::riscv::RVFpRegister;
 using ::mpact::sim::riscv::ScopedFPRoundingMode;
+using ::mpact::sim::riscv::ScopedFPStatus;
 using ::mpact::sim::riscv::RV32::RiscVILhu;
 using ::mpact::sim::riscv::RV32::RiscVZfhCvtHw;
 using ::mpact::sim::riscv::RV32::RiscVZfhCvtHwu;
@@ -102,6 +108,43 @@
 const int kRoundingModeRoundDown = static_cast<int>(FPRoundingMode::kRoundDown);
 const int kRoundingModeRoundUp = static_cast<int>(FPRoundingMode::kRoundUp);
 
+// Test-only source operand that supplies the rounding mode to an instruction.
+// Every As*() accessor returns the stored mode. A dedicated operand is less
+// confusing than a register since the mode is part of the encoding.
+class TestRoundingModeSourceOperand
+    : public mpact::sim::generic::SourceOperandInterface {
+ public:
+  explicit TestRoundingModeSourceOperand()
+      : rounding_mode_(FPRoundingMode::kRoundToNearest) {}
+
+  void SetRoundingMode(FPRoundingMode rounding_mode) {
+    rounding_mode_ = rounding_mode;
+  }
+
+  bool AsBool(int) override { return static_cast<bool>(rounding_mode_); }
+  int8_t AsInt8(int) override { return static_cast<int8_t>(rounding_mode_); }
+  uint8_t AsUint8(int) override { return static_cast<uint8_t>(rounding_mode_); }
+  int16_t AsInt16(int) override { return static_cast<int16_t>(rounding_mode_); }
+  uint16_t AsUint16(int) override {
+    return static_cast<uint16_t>(rounding_mode_);
+  }
+  int32_t AsInt32(int) override { return static_cast<int32_t>(rounding_mode_); }
+  uint32_t AsUint32(int) override {
+    return static_cast<uint32_t>(rounding_mode_);
+  }
+  int64_t AsInt64(int) override { return static_cast<int64_t>(rounding_mode_); }
+  uint64_t AsUint64(int) override {
+    return static_cast<uint64_t>(rounding_mode_);
+  }
+
+  std::vector<int> shape() const override { return {1}; }
+  std::string AsString() const override { return std::string(""); }
+  std::any GetObject() const override { return std::any(); }
+
+ protected:
+  FPRoundingMode rounding_mode_;
+};
+
 class RVZfhInstructionTestBase : public RiscVFPInstructionTestBase {
  protected:
   template <typename AddressType, typename ValueType>
@@ -116,6 +159,14 @@
       absl::string_view name, Instruction *inst,
       absl::Span<const absl::string_view> reg_prefixes, int delta_position,
       std::function<std::tuple<R, uint32_t>(LHS, uint32_t)> operation);
+
+  template <typename R, typename LHS, typename MHS, typename RHS>
+  void TernaryOpWithFflagsFPTestHelper(
+      absl::string_view name, Instruction *inst,
+      absl::Span<const absl::string_view> reg_prefixes, int delta_position,
+      std::function<std::tuple<R, uint32_t>(LHS, MHS, RHS)> operation);
+
+  uint32_t GetOperationFlags(std::function<void(void)> operation);
 };
 
 template <typename AddressType, typename ValueType>
@@ -153,6 +204,25 @@
   return observed_val;
 }
 
+// Runs `operation` under a ScopedFPStatus with the simulated fflags cleared
+// and returns the fflags value read afterwards, i.e. only the flags raised by
+// that operation. The simulated fflags are restored to their initial value
+// before returning.
+uint32_t RVZfhInstructionTestBase::GetOperationFlags(
+    std::function<void(void)> operation) {
+  uint32_t initial_fflags = rv_fp_->fflags()->GetUint32();
+  uint32_t delta_fflags = 0;
+  rv_fp_->fflags()->Write(static_cast<uint32_t>(0));
+  {
+    ScopedFPStatus sfpstatus(rv_fp_->host_fp_interface(),
+                             rv_fp_->GetRoundingMode());
+    operation();
+  }
+  delta_fflags = rv_fp_->fflags()->GetUint32();
+  rv_fp_->fflags()->Write(initial_fflags);
+  return delta_fflags;
+}
+
 // Helper for unary instructions that go between floats and integers.
 template <typename DestRegisterType, typename LhsRegisterType, typename R,
           typename LHS>
@@ -248,42 +318,127 @@
   }
 }
 
-// A source operand that is used to set the rounding mode. This is less
-// confusing than using a register source operand since the rounding mode is
-// part of the instruction encoding.
-class TestRoundingModeSourceOperand
-    : public mpact::sim::generic::SourceOperandInterface {
- public:
-  explicit TestRoundingModeSourceOperand()
-      : rounding_mode_(FPRoundingMode::kRoundToNearest) {}
-
-  void SetRoundingMode(FPRoundingMode rounding_mode) {
-    rounding_mode_ = rounding_mode;
+template <typename R, typename LHS, typename MHS, typename RHS>
+void RVZfhInstructionTestBase::TernaryOpWithFflagsFPTestHelper(
+    absl::string_view name, Instruction *inst,
+    absl::Span<const absl::string_view> reg_prefixes, int delta_position,
+    std::function<std::tuple<R, uint32_t>(LHS, MHS, RHS)> operation) {
+  using LhsRegisterType = RVFpRegister;
+  using MhsRegisterType = RVFpRegister;
+  using RhsRegisterType = RVFpRegister;
+  using DestRegisterType = RVFpRegister;
+  LHS lhs_values[kTestValueLength];
+  MHS mhs_values[kTestValueLength];
+  RHS rhs_values[kTestValueLength];
+  auto lhs_span = absl::Span<LHS>(lhs_values);
+  auto mhs_span = absl::Span<MHS>(mhs_values);
+  auto rhs_span = absl::Span<RHS>(rhs_values);
+  const std::string kR1Name = absl::StrCat(reg_prefixes[0], 1);
+  const std::string kR2Name = absl::StrCat(reg_prefixes[1], 2);
+  const std::string kR3Name = absl::StrCat(reg_prefixes[2], 3);
+  const std::string kRdName = absl::StrCat(reg_prefixes[3], 5);
+  if (kR1Name[0] == 'x') {
+    AppendRegisterOperands<RV32Register>({kR1Name}, {});
+  } else {
+    AppendRegisterOperands<RVFpRegister>({kR1Name}, {});
   }
-
-  bool AsBool(int) override { return static_cast<bool>(rounding_mode_); }
-  int8_t AsInt8(int) override { return static_cast<int8_t>(rounding_mode_); }
-  uint8_t AsUint8(int) override { return static_cast<uint8_t>(rounding_mode_); }
-  int16_t AsInt16(int) override { return static_cast<int16_t>(rounding_mode_); }
-  uint16_t AsUint16(int) override {
-    return static_cast<uint16_t>(rounding_mode_);
+  if (kR2Name[0] == 'x') {
+    AppendRegisterOperands<RV32Register>({kR2Name}, {});
+  } else {
+    AppendRegisterOperands<RVFpRegister>({kR2Name}, {});
   }
-  int32_t AsInt32(int) override { return static_cast<int32_t>(rounding_mode_); }
-  uint32_t AsUint32(int) override {
-    return static_cast<uint32_t>(rounding_mode_);
+  if (kR3Name[0] == 'x') {
+    AppendRegisterOperands<RV32Register>({kR3Name}, {});
+  } else {
+    AppendRegisterOperands<RVFpRegister>({kR3Name}, {});
   }
-  int64_t AsInt64(int) override { return static_cast<int64_t>(rounding_mode_); }
-  uint64_t AsUint64(int) override {
-    return static_cast<uint64_t>(rounding_mode_);
+  if (kRdName[0] == 'x') {
+    AppendRegisterOperands<RV32Register>({}, {kRdName});
+  } else {
+    AppendRegisterOperands<RVFpRegister>({}, {kRdName});
   }
+  TestRoundingModeSourceOperand *rm_source_operand =
+      new TestRoundingModeSourceOperand();
+  instruction_->AppendSource(rm_source_operand);
+  auto *flag_op = rv_fp_->fflags()->CreateSetDestinationOperand(0, "fflags");
+  instruction_->AppendDestination(flag_op);
+  FillArrayWithRandomFPValues<LHS>(lhs_span);
+  FillArrayWithRandomFPValues<MHS>(mhs_span);
+  FillArrayWithRandomFPValues<RHS>(rhs_span);
+  using LhsInt = typename FPTypeInfo<LHS>::IntType;
+  *reinterpret_cast<LhsInt *>(&lhs_span[0]) = FPTypeInfo<LHS>::kQNaN;
+  *reinterpret_cast<LhsInt *>(&lhs_span[1]) = FPTypeInfo<LHS>::kSNaN;
+  *reinterpret_cast<LhsInt *>(&lhs_span[2]) = FPTypeInfo<LHS>::kPosInf;
+  *reinterpret_cast<LhsInt *>(&lhs_span[3]) = FPTypeInfo<LHS>::kNegInf;
+  *reinterpret_cast<LhsInt *>(&lhs_span[4]) = FPTypeInfo<LHS>::kPosZero;
+  *reinterpret_cast<LhsInt *>(&lhs_span[5]) = FPTypeInfo<LHS>::kNegZero;
+  *reinterpret_cast<LhsInt *>(&lhs_span[6]) = FPTypeInfo<LHS>::kPosDenorm;
+  *reinterpret_cast<LhsInt *>(&lhs_span[7]) = FPTypeInfo<LHS>::kNegDenorm;
+  using MhsInt = typename FPTypeInfo<MHS>::IntType;
+  *reinterpret_cast<MhsInt *>(&mhs_span[8 + 0]) = FPTypeInfo<MHS>::kQNaN;
+  *reinterpret_cast<MhsInt *>(&mhs_span[8 + 1]) = FPTypeInfo<MHS>::kSNaN;
+  *reinterpret_cast<MhsInt *>(&mhs_span[8 + 2]) = FPTypeInfo<MHS>::kPosInf;
+  *reinterpret_cast<MhsInt *>(&mhs_span[8 + 3]) = FPTypeInfo<MHS>::kNegInf;
+  *reinterpret_cast<MhsInt *>(&mhs_span[8 + 4]) = FPTypeInfo<MHS>::kPosZero;
+  *reinterpret_cast<MhsInt *>(&mhs_span[8 + 5]) = FPTypeInfo<MHS>::kNegZero;
+  *reinterpret_cast<MhsInt *>(&mhs_span[8 + 6]) = FPTypeInfo<MHS>::kPosDenorm;
+  *reinterpret_cast<MhsInt *>(&mhs_span[8 + 7]) = FPTypeInfo<MHS>::kNegDenorm;
+  using RhsInt = typename FPTypeInfo<RHS>::IntType;
+  *reinterpret_cast<RhsInt *>(&rhs_span[16 + 0]) = FPTypeInfo<RHS>::kQNaN;
+  *reinterpret_cast<RhsInt *>(&rhs_span[16 + 1]) = FPTypeInfo<RHS>::kSNaN;
+  *reinterpret_cast<RhsInt *>(&rhs_span[16 + 2]) = FPTypeInfo<RHS>::kPosInf;
+  *reinterpret_cast<RhsInt *>(&rhs_span[16 + 3]) = FPTypeInfo<RHS>::kNegInf;
+  *reinterpret_cast<RhsInt *>(&rhs_span[16 + 4]) = FPTypeInfo<RHS>::kPosZero;
+  *reinterpret_cast<RhsInt *>(&rhs_span[16 + 5]) = FPTypeInfo<RHS>::kNegZero;
+  *reinterpret_cast<RhsInt *>(&rhs_span[16 + 6]) = FPTypeInfo<RHS>::kPosDenorm;
+  *reinterpret_cast<RhsInt *>(&rhs_span[16 + 7]) = FPTypeInfo<RHS>::kNegDenorm;
+  for (int i = 0; i < kTestValueLength; i++) {
+    SetNaNBoxedRegisterValues<LHS, LhsRegisterType>({{kR1Name, lhs_span[i]}});
+    SetNaNBoxedRegisterValues<MHS, MhsRegisterType>({{kR2Name, mhs_span[i]}});
+    SetNaNBoxedRegisterValues<RHS, RhsRegisterType>({{kR3Name, rhs_span[i]}});
 
-  std::vector<int> shape() const override { return {1}; }
-  std::string AsString() const override { return std::string(""); }
-  std::any GetObject() const override { return std::any(); }
+    for (int rm : {0, 1, 2, 3, 4}) {
+      rv_fp_->SetRoundingMode(static_cast<FPRoundingMode>(rm));
+      rm_source_operand->SetRoundingMode(static_cast<FPRoundingMode>(rm));
+      rv_fp_->fflags()->Write(static_cast<uint32_t>(0));
+      SetRegisterValues<DestRegisterType::ValueType, DestRegisterType>(
+          {{kRdName, 0}});
 
- protected:
-  FPRoundingMode rounding_mode_;
-};
+      inst->Execute(nullptr);
+      // Capture the fflags the instruction set, then clear them (below).
+      auto instruction_fflags = rv_fp_->fflags()->GetUint32();
+      rv_fp_->fflags()->Write(static_cast<uint32_t>(0));
+      R op_val;
+      uint32_t test_operation_fflags;
+      {
+        ScopedFPRoundingMode scoped_rm(rv_fp_->host_fp_interface(), rm);
+        std::tie(op_val, test_operation_fflags) =
+            operation(lhs_span[i], mhs_span[i], rhs_span[i]);
+      }
+      auto reg_val = state_->GetRegister<DestRegisterType>(kRdName)
+                         .first->data_buffer()
+                         ->template Get<R>(0);
+      FPCompare<R>(
+          op_val, reg_val, delta_position,
+          absl::StrCat(name, "  ", i, " (rm=", rm,
+                       ") : ", FloatingPointToString<LHS>(lhs_span[i]), "(",
+                       absl::StrFormat("%#x", lhs_span[i].value), ") ",
+                       FloatingPointToString<MHS>(mhs_span[i]), "(",
+                       absl::StrFormat("%#x", mhs_span[i].value), ") ",
+                       FloatingPointToString<RHS>(rhs_span[i]), "(",
+                       absl::StrFormat("%#x", rhs_span[i].value), ") "));
+
+      EXPECT_EQ(test_operation_fflags, instruction_fflags)
+          << absl::StrCat(name, "  ", i, " (rm=", rm,
+                          ") : ", FloatingPointToString<LHS>(lhs_span[i]), "(",
+                          absl::StrFormat("%#x", lhs_span[i].value), ") ",
+                          FloatingPointToString<MHS>(mhs_span[i]), "(",
+                          absl::StrFormat("%#x", mhs_span[i].value), ") ",
+                          FloatingPointToString<RHS>(rhs_span[i]), "(",
+                          absl::StrFormat("%#x", rhs_span[i].value), ") ");
+    }
+  }
+}
 
 class RV32ZfhInstructionTest : public RVZfhInstructionTestBase {
  protected:
@@ -1147,6 +1302,146 @@
       });
 }
 
+// Test fmadd.h, (rs1 * rs2) + rs3, against a double precision reference.
+TEST_F(RV32ZfhInstructionTest, RiscVZfhFmadd) {
+  SetSemanticFunction(&RiscVZfhFmadd);
+  TernaryOpWithFflagsFPTestHelper<HalfFP, HalfFP, HalfFP, HalfFP>(
+      "fmadd.h", instruction_, {"f", "f", "f", "f"}, 10,
+      [this](HalfFP a, HalfFP b, HalfFP c) -> std::tuple<HalfFP, uint32_t> {
+        FPRoundingMode rm = rv_fp_->GetRoundingMode();
+        uint32_t fflags = 0;
+        double a_f =
+            FpConversionsTestHelper(a).ConvertWithFlags<double>(fflags);
+        double b_f =
+            FpConversionsTestHelper(b).ConvertWithFlags<double>(fflags);
+        double c_f =
+            FpConversionsTestHelper(c).ConvertWithFlags<double>(fflags);
+        // Done outside GetOperationFlags so its host flags are not collected.
+        double product_f = a_f * b_f;
+        HalfFP result;
+        if ((std::isinf(a_f) && b_f == 0) || (std::isinf(b_f) && a_f == 0)) {
+          fflags |= *FPExceptions::kInvalidOp;
+          result.value = FPTypeInfo<HalfFP>::kCanonicalNaN;
+        } else if (std::isinf(product_f) && std::isinf(c_f) &&
+                   (std::signbit(product_f) != std::signbit(c_f))) {
+          fflags |= *FPExceptions::kInvalidOp;
+          result.value = FPTypeInfo<HalfFP>::kCanonicalNaN;
+        } else {
+          double result_f;
+          fflags |= GetOperationFlags(
+              [&result_f, &product_f, &c_f] { result_f = product_f + c_f; });
+          result = FpConversionsTestHelper(result_f).ConvertWithFlags<HalfFP>(
+              fflags, rm);
+        }
+        return std::make_tuple(result, fflags);
+      });
+}
+
+// Test fmsub.h, (rs1 * rs2) - rs3, against a double precision reference.
+TEST_F(RV32ZfhInstructionTest, RiscVZfhFmsub) {
+  SetSemanticFunction(&RiscVZfhFmsub);
+  TernaryOpWithFflagsFPTestHelper<HalfFP, HalfFP, HalfFP, HalfFP>(
+      "fmsub.h", instruction_, {"f", "f", "f", "f"}, 10,
+      [this](HalfFP a, HalfFP b, HalfFP c) -> std::tuple<HalfFP, uint32_t> {
+        FPRoundingMode rm = rv_fp_->GetRoundingMode();
+        uint32_t fflags = 0;
+        double a_f =
+            FpConversionsTestHelper(a).ConvertWithFlags<double>(fflags);
+        double b_f =
+            FpConversionsTestHelper(b).ConvertWithFlags<double>(fflags);
+        double c_f =
+            FpConversionsTestHelper(c).ConvertWithFlags<double>(fflags);
+        // Done outside GetOperationFlags so its host flags are not collected.
+        double product_f = a_f * b_f;
+        HalfFP result;
+        if ((std::isinf(a_f) && b_f == 0) || (std::isinf(b_f) && a_f == 0)) {
+          fflags |= *FPExceptions::kInvalidOp;
+          result.value = FPTypeInfo<HalfFP>::kCanonicalNaN;
+        } else if (std::isinf(product_f) && std::isinf(c_f) &&
+                   (std::signbit(product_f) == std::signbit(c_f))) {
+          fflags |= *FPExceptions::kInvalidOp;
+          result.value = FPTypeInfo<HalfFP>::kCanonicalNaN;
+        } else {
+          double result_f;
+          fflags |= GetOperationFlags(
+              [&result_f, &product_f, &c_f] { result_f = product_f - c_f; });
+          result = FpConversionsTestHelper(result_f).ConvertWithFlags<HalfFP>(
+              fflags, rm);
+        }
+        return std::make_tuple(result, fflags);
+      });
+}
+
+// Test fnmadd.h, -(rs1 * rs2) - rs3, against a double precision reference.
+TEST_F(RV32ZfhInstructionTest, RiscVZfhFnmadd) {
+  SetSemanticFunction(&RiscVZfhFnmadd);
+  TernaryOpWithFflagsFPTestHelper<HalfFP, HalfFP, HalfFP, HalfFP>(
+      "fnmadd.h", instruction_, {"f", "f", "f", "f"}, 10,
+      [this](HalfFP a, HalfFP b, HalfFP c) -> std::tuple<HalfFP, uint32_t> {
+        FPRoundingMode rm = rv_fp_->GetRoundingMode();
+        uint32_t fflags = 0;
+        double a_f =
+            FpConversionsTestHelper(a).ConvertWithFlags<double>(fflags);
+        double b_f =
+            FpConversionsTestHelper(b).ConvertWithFlags<double>(fflags);
+        double c_f =
+            FpConversionsTestHelper(c).ConvertWithFlags<double>(fflags);
+        // Done outside GetOperationFlags so its host flags are not collected.
+        double product_f = -(a_f * b_f);
+        HalfFP result;
+        if ((std::isinf(a_f) && b_f == 0) || (std::isinf(b_f) && a_f == 0)) {
+          fflags |= *FPExceptions::kInvalidOp;
+          result.value = FPTypeInfo<HalfFP>::kCanonicalNaN;
+        } else if (std::isinf(product_f) && std::isinf(c_f) &&
+                   (std::signbit(product_f) == std::signbit(c_f))) {
+          fflags |= *FPExceptions::kInvalidOp;
+          result.value = FPTypeInfo<HalfFP>::kCanonicalNaN;
+        } else {
+          double result_f;
+          fflags |= GetOperationFlags(
+              [&result_f, &product_f, &c_f] { result_f = product_f - c_f; });
+          result = FpConversionsTestHelper(result_f).ConvertWithFlags<HalfFP>(
+              fflags, rm);
+        }
+        return std::make_tuple(result, fflags);
+      });
+}
+
+// Test fnmsub.h, -(rs1 * rs2) + rs3, against a double precision reference.
+TEST_F(RV32ZfhInstructionTest, RiscVZfhFnmsub) {
+  SetSemanticFunction(&RiscVZfhFnmsub);
+  TernaryOpWithFflagsFPTestHelper<HalfFP, HalfFP, HalfFP, HalfFP>(
+      "fnmsub.h", instruction_, {"f", "f", "f", "f"}, 10,
+      [this](HalfFP a, HalfFP b, HalfFP c) -> std::tuple<HalfFP, uint32_t> {
+        FPRoundingMode rm = rv_fp_->GetRoundingMode();
+        uint32_t fflags = 0;
+        double a_f =
+            FpConversionsTestHelper(a).ConvertWithFlags<double>(fflags);
+        double b_f =
+            FpConversionsTestHelper(b).ConvertWithFlags<double>(fflags);
+        double c_f =
+            FpConversionsTestHelper(c).ConvertWithFlags<double>(fflags);
+        // Done outside GetOperationFlags so its host flags are not collected.
+        double product_f = -(a_f * b_f);
+        HalfFP result;
+        if ((std::isinf(a_f) && b_f == 0) || (std::isinf(b_f) && a_f == 0)) {
+          fflags |= *FPExceptions::kInvalidOp;
+          result.value = FPTypeInfo<HalfFP>::kCanonicalNaN;
+        } else if (std::isinf(product_f) && std::isinf(c_f) &&
+                   (std::signbit(product_f) != std::signbit(c_f))) {
+          fflags |= *FPExceptions::kInvalidOp;
+          result.value = FPTypeInfo<HalfFP>::kCanonicalNaN;
+        } else {
+          double result_f;
+          fflags |= GetOperationFlags(
+              [&result_f, &product_f, &c_f] { result_f = product_f + c_f; });
+          result = FpConversionsTestHelper(result_f).ConvertWithFlags<HalfFP>(
+              fflags, rm);
+        }
+        return std::make_tuple(result, fflags);
+      });
+}
+
 class RV64ZfhInstructionTest : public RVZfhInstructionTestBase {};
 
 // Move half precision from a float register to an integer register. The IEEE754