zfh simplification

* Remove platform dependant conversions to half precision
* Resolve double to half conversion bug when cast followed by round resulted in the incorrect value
* Fix inf / 0 incorrect fflags expectations
* fmin / fmax dont use the rounding mode operand

PiperOrigin-RevId: 756900912
Change-Id: Ie21c09b372d587de7ae2d18abe9beebb6470af68
diff --git a/riscv/BUILD b/riscv/BUILD
index cfc03dc..962213f 100644
--- a/riscv/BUILD
+++ b/riscv/BUILD
@@ -329,35 +329,15 @@
 
 cc_library(
     name = "riscv_zfh_instructions",
-    srcs = select({
-        "//third_party/bazel_platforms/cpu:aarch64": [
-            "riscv_zfh_instructions.cc",
-            "riscv_zfh_instructions_arm.cc",
-        ],
-        "//conditions:default": [
-            "riscv_zfh_instructions.cc",
-            "riscv_zfh_instructions_x86.cc",
-        ],
-    }),
+    srcs = ["riscv_zfh_instructions.cc"],
     hdrs = [
         "riscv_instruction_helpers.h",
         "riscv_zfh_instructions.h",
     ],
-    copts = select({
-        "//third_party/bazel_platforms/cpu:aarch64": [
-            "-O3",
-            "-ffp-model=strict",
-        ],
-        "//buildenv/platforms/settings:macos_aarch64": [
-            "-O3",
-            "-ffp-model=strict",
-        ],
-        "//conditions:default": [
-            "-ffp-model=strict",
-            "-O3",
-            "-mf16c",
-        ],
-    }),
+    copts = [
+        "-ffp-model=strict",
+        "-O3",
+    ],
     deps = [
         ":riscv_fp_state",
         ":riscv_state",
diff --git a/riscv/riscv_zfh.isa b/riscv/riscv_zfh.isa
index abd0fa8..c4514e7 100644
--- a/riscv/riscv_zfh.isa
+++ b/riscv/riscv_zfh.isa
@@ -136,11 +136,11 @@
       resources: {next_pc, frs1, frs2 : frd[0..]},
       semfunc: "&RiscVZfhFdiv",
       disasm: "fdiv.h", "%frd, %frs1, %frs2";
-    fmin_h{: frs1, frs2, rm : frd, fflags},
+    fmin_h{: frs1, frs2 : frd, fflags},
       resources: {next_pc, frs1, frs2 : frd[0..]},
       semfunc: "&RiscVZfhFmin",
       disasm: "fmin.h", "%frd, %frs1, %frs2";
-    fmax_h{: frs1, frs2, rm : frd, fflags},
+    fmax_h{: frs1, frs2 : frd, fflags},
       resources: {next_pc, frs1, frs2 : frd[0..]},
       semfunc: "&RiscVZfhFmax",
       disasm: "fmax.h", "%frd, %frs1, %frs2";
diff --git a/riscv/riscv_zfh_instructions.cc b/riscv/riscv_zfh_instructions.cc
index 928bf24..f4b8ad2 100644
--- a/riscv/riscv_zfh_instructions.cc
+++ b/riscv/riscv_zfh_instructions.cc
@@ -85,6 +85,166 @@
   using type = RV64Register::ValueType;
 };
 
+// This is a soft conversion from a float or double to a half precision value.
+// It is not a direct conversion from the floating point format to the half
+// format. Instead, it uses the floating point hardware to do the conversion.
+// This is done to get the correct rounding behavior for free from the FPU.
+template <typename T>
+HalfFP ConvertToHalfFP(T input_value, FPRoundingMode rm, uint32_t &fflags) {
+  using UIntType = typename FPTypeInfo<T>::UIntType;
+  using IntType = typename FPTypeInfo<T>::IntType;
+  UIntType in_int = absl::bit_cast<UIntType>(input_value);
+  HalfFP half_fp = {.value = 0x0000};
+
+  // Extract the mantissa, exponent and sign.
+  UIntType mantissa = in_int & FPTypeInfo<T>::kSigMask;
+  UIntType exponent =
+      (in_int & FPTypeInfo<T>::kExpMask) >> FPTypeInfo<T>::kSigSize;
+  UIntType sign = in_int >> (FPTypeInfo<T>::kBitSize - 1);
+
+  if (std::isnan(input_value)) {
+    half_fp.value = FPTypeInfo<HalfFP>::kCanonicalNaN;
+    if (FPTypeInfo<T>::IsSNaN(input_value)) {
+      fflags |= static_cast<UIntType>(FPExceptions::kInvalidOp);
+    }
+    return half_fp;
+  }
+
+  if (std::isinf(input_value)) {
+    half_fp.value = FPTypeInfo<HalfFP>::kPosInf;
+    half_fp.value |= (sign & 1) << (FPTypeInfo<HalfFP>::kBitSize - 1);
+    return half_fp;
+  }
+
+  if (in_int == 0 || in_int == 1ULL << (FPTypeInfo<T>::kBitSize - 1)) {
+    half_fp.value =
+        in_int >> (FPTypeInfo<T>::kBitSize - FPTypeInfo<HalfFP>::kBitSize);
+    return half_fp;
+  }
+
+  IntType bias_diff = FPTypeInfo<T>::kExpBias - FPTypeInfo<HalfFP>::kExpBias;
+  IntType unbounded_half_exponent = static_cast<IntType>(exponent) - bias_diff;
+  IntType sig_size_diff =
+      FPTypeInfo<T>::kSigSize - FPTypeInfo<HalfFP>::kSigSize;
+  UIntType half_inf_exponent = ((1 << FPTypeInfo<HalfFP>::kExpSize) - 1);
+  UIntType source_type_inf_exponent = ((1 << FPTypeInfo<T>::kExpSize) - 1);
+
+  // Create a temp float with the smallest normal exponent and input mantissa.
+  T ftmp = absl::bit_cast<T>(
+      (sign << (FPTypeInfo<T>::kBitSize - 1)) |
+      (static_cast<UIntType>(1ULL) << FPTypeInfo<T>::kSigSize) | mantissa);
+
+  // Create a divisor float that will be used for shifting the mantissa in a
+  // rounding aware way. The amount of shifting depends on if the result is
+  // subnormal or normal.
+  T fdiv = 0;
+  UIntType default_fdiv_exp = FPTypeInfo<T>::kExpBias + sig_size_diff;
+  UIntType fdiv_exp = default_fdiv_exp;
+  if (unbounded_half_exponent > 0) {
+    fdiv_exp = default_fdiv_exp;
+  } else if (unbounded_half_exponent < 0) {
+    // shift_count: emin - unbiased exponent
+    IntType shift_count = 1 - static_cast<int>(exponent) + bias_diff;
+    fdiv_exp = default_fdiv_exp + shift_count;
+    fdiv_exp = std::min(fdiv_exp, source_type_inf_exponent - 1);
+  } else {
+    fdiv_exp = default_fdiv_exp + 1;
+  }
+  fdiv = absl::bit_cast<T>(fdiv_exp << FPTypeInfo<T>::kSigSize);
+
+  // Shift right by doing division.
+  T fres = ftmp / fdiv;
+  UIntType res = absl::bit_cast<UIntType>(fres);
+
+  // Shift left by doing multiplication.
+  T fmultiply = absl::bit_cast<T>(default_fdiv_exp << FPTypeInfo<T>::kSigSize);
+  T fres2 = fres * fmultiply;
+  UIntType res2 = absl::bit_cast<UIntType>(fres2);
+
+  // Update the exponent if rounding caused an increase.
+  IntType exp_diff = static_cast<IntType>((res2 >> FPTypeInfo<T>::kSigSize) &
+                                          source_type_inf_exponent) -
+                     1;
+  UIntType new_exponent = (exponent + exp_diff) & source_type_inf_exponent;
+
+  UIntType half_exponent = 0;
+  if (unbounded_half_exponent > 0) {
+    half_exponent = new_exponent - bias_diff;
+  } else if (unbounded_half_exponent < 0) {
+    // Guaranteed subnormal. Nothing to do.
+  } else {
+    // This case could be normal or subnormal depending on the rounding result.
+    half_exponent = (res2 >> FPTypeInfo<T>::kSigSize) & half_inf_exponent;
+  }
+
+  UIntType half_mantissa =
+      (res2 >> sig_size_diff) & FPTypeInfo<HalfFP>::kSigMask;
+  if (unbounded_half_exponent < 0) {  // Guaranteed Subnormal
+    half_mantissa = (res & (1 << FPTypeInfo<HalfFP>::kSigSize))
+                        ? ((res >> 1) & FPTypeInfo<HalfFP>::kSigMask)
+                        : res & FPTypeInfo<HalfFP>::kSigMask;
+  }
+
+  // Handle the rules for overflowing to infinity depending on the rounding
+  // mode.
+  if (half_exponent >= half_inf_exponent) {
+    fflags |= static_cast<uint32_t>(FPExceptions::kOverflow);
+    fflags |= static_cast<uint32_t>(FPExceptions::kInexact);
+    switch (rm) {
+      case FPRoundingMode::kRoundToNearest:
+        half_exponent = half_inf_exponent;
+        half_mantissa = 0;
+        break;
+      case FPRoundingMode::kRoundTowardsZero:
+        half_exponent = half_inf_exponent - 1;
+        half_mantissa = FPTypeInfo<HalfFP>::kSigMask;
+        break;
+      case FPRoundingMode::kRoundDown:
+        half_exponent = sign ? half_inf_exponent : half_inf_exponent - 1;
+        half_mantissa = sign ? 0 : FPTypeInfo<HalfFP>::kSigMask;
+        break;
+      case FPRoundingMode::kRoundUp:
+        half_exponent = sign ? half_inf_exponent - 1 : half_inf_exponent;
+        half_mantissa = sign ? FPTypeInfo<HalfFP>::kSigMask : 0;
+        break;
+      default:
+        half_exponent = half_inf_exponent;
+        half_mantissa = 0;
+        break;
+    }
+  }
+
+  // Construct the half float.
+  half_fp.value = half_mantissa |
+                  (half_exponent << FPTypeInfo<HalfFP>::kSigSize) |
+                  (sign << (FPTypeInfo<HalfFP>::kBitSize - 1));
+
+  // Do an arithmetic reconstruction of the float to check for exactness.
+  T trailing_significand_float = static_cast<T>(half_mantissa);
+  T precision_factor = std::pow(2.0, -1.0 * FPTypeInfo<HalfFP>::kSigSize);
+  IntType unbiased_exponent =
+      (half_exponent == 0 ? 1 : half_exponent) - FPTypeInfo<HalfFP>::kExpBias;
+  T exponent_factor = std::pow(2.0, unbiased_exponent);
+  T sign_factor = sign == 1 ? -1.0 : 1.0;
+  T implicit_bit_adjustment = half_exponent == 0 ? 0.0 : 1.0;
+  T reconstructed_value = ((trailing_significand_float * precision_factor) +
+                           implicit_bit_adjustment) *
+                          exponent_factor * sign_factor;
+  bool exact_conversion = reconstructed_value == input_value;
+
+  // Handle flags for the specific underflow case.
+  if (!exact_conversion && (unbounded_half_exponent < 0 ||
+                            (unbounded_half_exponent == 0 && fres2 != ftmp))) {
+    fflags |= static_cast<uint32_t>(FPExceptions::kUnderflow);
+  }
+
+  // Handle flags for the specific inexact case.
+  if (!exact_conversion && (fres2 != ftmp)) {
+    fflags |= static_cast<uint32_t>(FPExceptions::kInexact);
+  }
+  return half_fp;
+}
+
 // Convert from half precision to single or double precision.
 template <typename T>
 inline T ConvertFromHalfFP(HalfFP half_fp, uint32_t &fflags) {
@@ -152,7 +312,7 @@
   if constexpr (IsMpactFp<Argument>::value) {
     lhs = GetNaNBoxedSource<RVFpRegister::ValueType, Argument>(instruction, 0);
     if (FPTypeInfo<Argument>::IsSNaN(lhs)) {
-      fflags_dest->GetRiscVCsr()->SetBits(*FPExceptions::kInvalidOp);
+      fflags |= *FPExceptions::kInvalidOp;
     }
   } else {
     lhs = generic::GetInstructionSource<Argument>(instruction, 0);
@@ -171,16 +331,11 @@
   }
 
   Result dest_value;
-  if (zfh_internal::UseHostFlagsForConversion()) {
-    ScopedFPStatus set_fp_status(rv_fp->host_fp_interface(), rm_value);
-    dest_value = operation(lhs, static_cast<FPRoundingMode>(rm_value), fflags);
-  } else {
+  {
     ScopedFPRoundingMode scoped_rm(rv_fp->host_fp_interface(), rm_value);
     dest_value = operation(lhs, static_cast<FPRoundingMode>(rm_value), fflags);
   }
-  if (!zfh_internal::UseHostFlagsForConversion()) {
-    fflags_dest->GetRiscVCsr()->SetBits(fflags);
-  }
+  fflags_dest->GetRiscVCsr()->SetBits(fflags);
   auto *reg = static_cast<generic::RegisterDestinationOperand<DstRegValue> *>(
                   instruction->Destination(0))
                   ->GetRegister();
@@ -200,67 +355,48 @@
   reg->data_buffer()->template Set<Result>(0, dest_value);
 }
 
-template <typename T>
-inline HalfFP ConvertToHalfFP(T input_value, FPRoundingMode rm,
-                              uint32_t &fflags);
-
-template <>
-inline HalfFP ConvertToHalfFP(float input_value, FPRoundingMode rm,
-                              uint32_t &fflags) {
-  return ConvertSingleToHalfFP(input_value, rm, fflags);
-}
-
-template <>
-inline HalfFP ConvertToHalfFP(double input_value, FPRoundingMode rm,
-                              uint32_t &fflags) {
-  return ConvertDoubleToHalfFP(input_value, rm, fflags);
-}
-
 // Generic helper function enabling HalfFP operations in native datatypes.
-template <typename Result, typename Argument>
+template <typename Argument, typename IntermediateType>
 void RiscVZfhUnaryHelper(
     const Instruction *instruction,
-    std::function<Result(Argument, FPRoundingMode, uint32_t &)> operation) {
-  uint32_t fflags = 0;
-  RiscVFPState *rv_fp =
-      static_cast<RiscVState *>(instruction->state())->rv_fp();
-  int rm_value = generic::GetInstructionSource<int>(instruction, 1);
-
-  // If the rounding mode is dynamic, read it from the current state.
-  if (rm_value == *FPRoundingMode::kDynamic) {
-    if (!rv_fp->rounding_mode_valid()) {
-      LOG(ERROR) << "Invalid rounding mode";
-      return;
-    }
-    rm_value = *(rv_fp->GetRoundingMode());
-  }
-  FPRoundingMode rm = static_cast<FPRoundingMode>(rm_value);
+    std::function<IntermediateType(IntermediateType)> operation) {
   RiscVCsrDestinationOperand *fflags_dest =
       static_cast<RiscVCsrDestinationOperand *>(instruction->Destination(1));
-  bool arguments_contain_snan = false;
+  uint32_t fflags = 0;
   RiscVUnaryFloatNaNBoxOp<RVFpRegister::ValueType, RVFpRegister::ValueType,
-                          Result, Argument>(
-      instruction,
-      [rv_fp, rm, &fflags, &operation,
-       &arguments_contain_snan](Argument a) -> Result {
-        Result result;
-        if (FPTypeInfo<Argument>::IsSNaN(a)) {
-          arguments_contain_snan = true;
+                          HalfFP, Argument>(
+      instruction, [instruction, &operation, &fflags](Argument a) -> HalfFP {
+        RiscVFPState *rv_fp =
+            static_cast<RiscVState *>(instruction->state())->rv_fp();
+        int rm_value = generic::GetInstructionSource<int>(instruction, 1);
+
+        // If the rounding mode is dynamic, read it from the current state.
+        if (rm_value == *FPRoundingMode::kDynamic) {
+          if (!rv_fp->rounding_mode_valid()) {
+            LOG(ERROR) << "Invalid rounding mode";
+          }
+          rm_value = *(rv_fp->GetRoundingMode());
         }
-        if (zfh_internal::UseHostFlagsForConversion()) {
-          result = operation(a, rm, fflags);
-        } else {
+        FPRoundingMode rm = static_cast<FPRoundingMode>(rm_value);
+        IntermediateType argument1 =
+            ConvertFromHalfFP<IntermediateType>(a, fflags);
+        IntermediateType result;
+        {
           ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface(), rm);
-          result = operation(a, rm, fflags);
+          result = operation(argument1);
         }
-        return result;
+        // To get the correct fflags we need a combination of host flags from
+        // the operation and the conversion flags. Copy the host flags and merge
+        // them with the conversion flags.
+        fflags |= rv_fp->fflags()->GetUint32();
+        {
+          // ConvertToHalfFP pollutes the host flags so we need to create a
+          // ScopedFPRoundingMode to restore the host flags.
+          ScopedFPRoundingMode scoped_rm(rv_fp->host_fp_interface(), rm_value);
+          return ConvertToHalfFP(result, rm, fflags);
+        }
       });
-  if (!zfh_internal::UseHostFlagsForConversion()) {
-    fflags_dest->GetRiscVCsr()->SetBits(fflags);
-  }
-  if (arguments_contain_snan) {
-    fflags_dest->GetRiscVCsr()->SetBits(*FPExceptions::kInvalidOp);
-  }
+  fflags_dest->GetRiscVCsr()->SetBits(fflags);
 }
 
 // Generic helper function enabling HalfFP operations in native datatypes.
@@ -269,69 +405,44 @@
     const Instruction *instruction,
     std::function<IntermediateType(IntermediateType, IntermediateType)>
         operation) {
-  RiscVFPState *rv_fp =
-      static_cast<RiscVState *>(instruction->state())->rv_fp();
-  int rm_value = generic::GetInstructionSource<int>(instruction, 2);
-
-  // If the rounding mode is dynamic, read it from the current state.
-  if (rm_value == *FPRoundingMode::kDynamic) {
-    if (!rv_fp->rounding_mode_valid()) {
-      LOG(ERROR) << "Invalid rounding mode";
-      return;
-    }
-    rm_value = *(rv_fp->GetRoundingMode());
-  }
-  FPRoundingMode rm = static_cast<FPRoundingMode>(rm_value);
   RiscVCsrDestinationOperand *fflags_dest =
       static_cast<RiscVCsrDestinationOperand *>(instruction->Destination(1));
-  uint32_t fflags = fflags_dest->GetRiscVCsr()->GetUint32();
-  bool arguments_contain_snan = false;
-  IntermediateType b_emin =
-      std::pow(2.0, 1 - FPTypeInfo<IntermediateType>::kExpBias);
-  IntermediateType result;
+  uint32_t fflags = 0;
   RiscVBinaryFloatNaNBoxOp<RVFpRegister::ValueType, HalfFP, Argument>(
       instruction,
-      [&operation, &arguments_contain_snan, &fflags, rv_fp, &rm, &result](
-          Argument a, Argument b) -> HalfFP {
-        IntermediateType a_f;
-        IntermediateType b_f;
-        if (FPTypeInfo<Argument>::IsSNaN(a)) {
-          a_f = absl::bit_cast<IntermediateType>(
-              FPTypeInfo<IntermediateType>::kPosInf | 1);
-          arguments_contain_snan = true;
-        } else {
-          a_f = ConvertFromHalfFP<IntermediateType>(a, fflags);
+      [instruction, &operation, &fflags](Argument a, Argument b) -> HalfFP {
+        RiscVFPState *rv_fp =
+            static_cast<RiscVState *>(instruction->state())->rv_fp();
+        int rm_value = generic::GetInstructionSource<int>(instruction, 2);
+        // If the rounding mode is dynamic, read it from the current state.
+        if (rm_value == *FPRoundingMode::kDynamic) {
+          if (!rv_fp->rounding_mode_valid()) {
+            LOG(ERROR) << "Invalid rounding mode";
+          }
+          rm_value = *(rv_fp->GetRoundingMode());
         }
-        if (FPTypeInfo<Argument>::IsSNaN(b)) {
-          b_f = absl::bit_cast<IntermediateType>(
-              FPTypeInfo<IntermediateType>::kPosInf | 1);
-          arguments_contain_snan = true;
-        } else {
-          b_f = ConvertFromHalfFP<IntermediateType>(b, fflags);
-        }
-        if (zfh_internal::UseHostFlagsForConversion()) {
-          result = operation(a_f, b_f);
-        } else {
+        FPRoundingMode rm = static_cast<FPRoundingMode>(rm_value);
+        IntermediateType argument1 =
+            ConvertFromHalfFP<IntermediateType>(a, fflags);
+        IntermediateType argument2 =
+            ConvertFromHalfFP<IntermediateType>(b, fflags);
+        IntermediateType result;
+        {
           ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface(), rm);
-          result = operation(a_f, b_f);
+          result = operation(argument1, argument2);
         }
-        if (!zfh_internal::UseHostFlagsForConversion()) {
-          fflags |= rv_fp->fflags()->GetUint32();
+        // To get the correct fflags we need a combination of host flags from
+        // the operation and the conversion flags. Copy the host flags and merge
+        // them with the conversion flags.
+        fflags |= rv_fp->fflags()->GetUint32();
+        {
+          // ConvertToHalfFP pollutes the host flags so we need to create a
+          // ScopedFPRoundingMode to restore the host flags.
+          ScopedFPRoundingMode scoped_rm(rv_fp->host_fp_interface(), rm_value);
+          return ConvertToHalfFP(result, rm, fflags);
         }
-        return ConvertToHalfFP(result, rm, fflags);
       });
-  if (arguments_contain_snan) {
-    fflags_dest->GetRiscVCsr()->SetBits(*FPExceptions::kInvalidOp);
-  }
-  if (!zfh_internal::UseHostFlagsForConversion()) {
-    fflags_dest->GetRiscVCsr()->Write(fflags);
-  }
-  // When the result is less than b_emin before rounding we need to set the
-  // underflow flag.
-  if ((fflags_dest->GetRiscVCsr()->GetUint32() & *FPExceptions::kInexact) &&
-      result != 0 && std::abs(result) < b_emin) {
-    fflags_dest->GetRiscVCsr()->SetBits(*FPExceptions::kUnderflow);
-  }
+  fflags_dest->GetRiscVCsr()->SetBits(fflags);
 }
 
 }  // namespace
@@ -592,14 +703,8 @@
 // Calculate the square root of a half precision value. Do the operation in
 // single precision and then convert back to half precision.
 void RiscVZfhFsqrt(const Instruction *instruction) {
-  RiscVZfhUnaryHelper<HalfFP, HalfFP>(
-      instruction, [](HalfFP a, FPRoundingMode rm, uint32_t &fflags) -> HalfFP {
-        float input_f = ConvertFromHalfFP<float>(a, fflags);
-        if (!std::isnan(input_f) && input_f < 0) {
-          fflags |= static_cast<uint32_t>(FPExceptions::kInvalidOp);
-        }
-        return ConvertToHalfFP(std::sqrt(input_f), rm, fflags);
-      });
+  RiscVZfhUnaryHelper<HalfFP, float>(
+      instruction, [](float a) -> float { return std::sqrt(a); });
 }
 
 // The result is the exponent and significand of the first source with the
diff --git a/riscv/riscv_zfh_instructions.h b/riscv/riscv_zfh_instructions.h
index 08e33cf..6337a31 100644
--- a/riscv/riscv_zfh_instructions.h
+++ b/riscv/riscv_zfh_instructions.h
@@ -15,11 +15,8 @@
 #ifndef THIRD_PARTY_MPACT_RISCV_RISCV_ZFH_INSTRUCTIONS_H_
 #define THIRD_PARTY_MPACT_RISCV_RISCV_ZFH_INSTRUCTIONS_H_
 
-#include <cstdint>
-
 #include "mpact/sim/generic/instruction.h"
 #include "mpact/sim/generic/type_helpers.h"
-#include "riscv/riscv_fp_info.h"
 
 namespace mpact {
 namespace sim {
@@ -154,13 +151,6 @@
 //                    function.
 void RV32VUnimplementedInstruction(const Instruction *instruction);
 
-HalfFP ConvertSingleToHalfFP(float, FPRoundingMode, uint32_t &);
-HalfFP ConvertDoubleToHalfFP(double, FPRoundingMode, uint32_t &);
-
-namespace zfh_internal {
-bool UseHostFlagsForConversion();
-}  // namespace zfh_internal
-
 // Source Operands:
 //   frs1: Float Register
 //   frs2: Float Register
diff --git a/riscv/riscv_zfh_instructions_arm.cc b/riscv/riscv_zfh_instructions_arm.cc
deleted file mode 100644
index 74104d2..0000000
--- a/riscv/riscv_zfh_instructions_arm.cc
+++ /dev/null
@@ -1,215 +0,0 @@
-// Copyright 2025 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <sys/types.h>
-
-#include <cassert>
-#include <cmath>
-#include <cstdint>
-
-#include "absl/base/casts.h"
-#include "mpact/sim/generic/type_helpers.h"
-#include "riscv/riscv_fp_info.h"
-#include "riscv/riscv_instruction_helpers.h"
-#include "riscv/riscv_zfh_instructions.h"
-
-namespace mpact {
-namespace sim {
-namespace riscv {
-
-using HalfFP = ::mpact::sim::generic::HalfFP;
-
-// TODO(b/401856759): Use arm intrinsics for fp32 -> fp16 and fp64 -> fp16
-//                    conversions.
-
-namespace {
-
-// This is a soft conversion from a float or double to a half precision value.
-// It is not a direct conversion from the floating point format to the half
-// format. Instead, it uses the floating point hardware to do the conversion.
-// This is done to get the correct rounding behavior for free from the FPU.
-template <typename T>
-HalfFP SoftConvertToHalfFP(T input_value, FPRoundingMode rm, uint32_t &fflags) {
-  using UIntType = typename FPTypeInfo<T>::UIntType;
-  using IntType = typename FPTypeInfo<T>::IntType;
-  UIntType in_int = absl::bit_cast<UIntType>(input_value);
-  HalfFP half_fp = {.value = 0x0000};
-
-  // Extract the mantissa, exponent and sign.
-  UIntType mantissa = in_int & FPTypeInfo<T>::kSigMask;
-  UIntType exponent =
-      (in_int & FPTypeInfo<T>::kExpMask) >> FPTypeInfo<T>::kSigSize;
-  UIntType sign = in_int >> (FPTypeInfo<T>::kBitSize - 1);
-
-  if (std::isnan(input_value)) {
-    half_fp.value = FPTypeInfo<HalfFP>::kCanonicalNaN;
-    if (FPTypeInfo<T>::IsSNaN(input_value)) {
-      fflags |= static_cast<UIntType>(FPExceptions::kInvalidOp);
-    }
-    return half_fp;
-  }
-
-  if (std::isinf(input_value)) {
-    half_fp.value = FPTypeInfo<HalfFP>::kPosInf;
-    half_fp.value |= (sign & 1) << (FPTypeInfo<HalfFP>::kBitSize - 1);
-    return half_fp;
-  }
-
-  if (in_int == 0 || in_int == 1ULL << (FPTypeInfo<T>::kBitSize - 1)) {
-    half_fp.value =
-        in_int >> (FPTypeInfo<T>::kBitSize - FPTypeInfo<HalfFP>::kBitSize);
-    return half_fp;
-  }
-
-  IntType bias_diff = FPTypeInfo<T>::kExpBias - FPTypeInfo<HalfFP>::kExpBias;
-  IntType unbounded_half_exponent = static_cast<IntType>(exponent) - bias_diff;
-  IntType sig_size_diff =
-      FPTypeInfo<T>::kSigSize - FPTypeInfo<HalfFP>::kSigSize;
-  UIntType half_inf_exponent = ((1 << FPTypeInfo<HalfFP>::kExpSize) - 1);
-  UIntType source_type_inf_exponent = ((1 << FPTypeInfo<T>::kExpSize) - 1);
-
-  // Create a temp float with the smallest normal exponent and input mantissa.
-  T ftmp = absl::bit_cast<T>(
-      (sign << (FPTypeInfo<T>::kBitSize - 1)) |
-      (static_cast<UIntType>(1ULL) << FPTypeInfo<T>::kSigSize) | mantissa);
-
-  // Create a divisor float that will be used for shifting the mantissa in a
-  // rounding aware way. The amount of shifting depends on if the result is
-  // subnormal or normal.
-  T fdiv = 0;
-  UIntType default_fdiv_exp = FPTypeInfo<T>::kExpBias + sig_size_diff;
-  UIntType fdiv_exp = default_fdiv_exp;
-  if (unbounded_half_exponent > 0) {
-    fdiv_exp = default_fdiv_exp;
-  } else if (unbounded_half_exponent < 0) {
-    // shift_count: emin - unbiased exponent
-    IntType shift_count = 1 - static_cast<int>(exponent) + bias_diff;
-    fdiv_exp = default_fdiv_exp + shift_count;
-    fdiv_exp = std::min(fdiv_exp, source_type_inf_exponent - 1);
-  } else {
-    fdiv_exp = default_fdiv_exp + 1;
-  }
-  fdiv = absl::bit_cast<T>(fdiv_exp << FPTypeInfo<T>::kSigSize);
-
-  // Shift right by doing division.
-  T fres = ftmp / fdiv;
-  UIntType res = absl::bit_cast<UIntType>(fres);
-
-  // Shift left by doing multiplication.
-  T fmultiply = absl::bit_cast<T>(default_fdiv_exp << FPTypeInfo<T>::kSigSize);
-  T fres2 = fres * fmultiply;
-  UIntType res2 = absl::bit_cast<UIntType>(fres2);
-
-  // Update the exponent if rounding caused an increase.
-  IntType exp_diff = static_cast<IntType>((res2 >> FPTypeInfo<T>::kSigSize) &
-                                          source_type_inf_exponent) -
-                     1;
-  UIntType new_exponent = (exponent + exp_diff) & source_type_inf_exponent;
-
-  UIntType half_exponent = 0;
-  if (unbounded_half_exponent > 0) {
-    half_exponent = new_exponent - bias_diff;
-  } else if (unbounded_half_exponent < 0) {
-    // Guaranteed subnormal. Nothing to do.
-  } else {
-    // This case could be normal or subnormal depending on the rounding result.
-    half_exponent = (res2 >> FPTypeInfo<T>::kSigSize) & half_inf_exponent;
-  }
-
-  UIntType half_mantissa =
-      (res2 >> sig_size_diff) & FPTypeInfo<HalfFP>::kSigMask;
-  if (unbounded_half_exponent < 0) {  // Guaranteed Subnormal
-    half_mantissa = (res & (1 << FPTypeInfo<HalfFP>::kSigSize))
-                        ? ((res >> 1) & FPTypeInfo<HalfFP>::kSigMask)
-                        : res & FPTypeInfo<HalfFP>::kSigMask;
-  }
-
-  // Handle the rules for overflowing to infinity depending on the rounding
-  // mode.
-  if (half_exponent >= half_inf_exponent) {
-    fflags |= static_cast<uint32_t>(FPExceptions::kOverflow);
-    fflags |= static_cast<uint32_t>(FPExceptions::kInexact);
-    switch (rm) {
-      case FPRoundingMode::kRoundToNearest:
-        half_exponent = half_inf_exponent;
-        half_mantissa = 0;
-        break;
-      case FPRoundingMode::kRoundTowardsZero:
-        half_exponent = half_inf_exponent - 1;
-        half_mantissa = FPTypeInfo<HalfFP>::kSigMask;
-        break;
-      case FPRoundingMode::kRoundDown:
-        half_exponent = sign ? half_inf_exponent : half_inf_exponent - 1;
-        half_mantissa = sign ? 0 : FPTypeInfo<HalfFP>::kSigMask;
-        break;
-      case FPRoundingMode::kRoundUp:
-        half_exponent = sign ? half_inf_exponent - 1 : half_inf_exponent;
-        half_mantissa = sign ? FPTypeInfo<HalfFP>::kSigMask : 0;
-        break;
-      default:
-        half_exponent = half_inf_exponent;
-        half_mantissa = 0;
-        break;
-    }
-  }
-
-  // Construct the half float.
-  half_fp.value = half_mantissa |
-                  (half_exponent << FPTypeInfo<HalfFP>::kSigSize) |
-                  (sign << (FPTypeInfo<HalfFP>::kBitSize - 1));
-
-  // Do an arithmetic reconstruction of the float to check for exactness.
-  T trailing_significand_float = static_cast<T>(half_mantissa);
-  T precision_factor = std::pow(2.0, -1.0 * FPTypeInfo<HalfFP>::kSigSize);
-  IntType unbiased_exponent =
-      (half_exponent == 0 ? 1 : half_exponent) - FPTypeInfo<HalfFP>::kExpBias;
-  T exponent_factor = std::pow(2.0, unbiased_exponent);
-  T sign_factor = sign == 1 ? -1.0 : 1.0;
-  T implicit_bit_adjustment = half_exponent == 0 ? 0.0 : 1.0;
-  T reconstructed_value = ((trailing_significand_float * precision_factor) +
-                           implicit_bit_adjustment) *
-                          exponent_factor * sign_factor;
-  bool exact_conversion = reconstructed_value == input_value;
-
-  // Handle flags for the specific underflow case.
-  if (!exact_conversion && (unbounded_half_exponent < 0 ||
-                            (unbounded_half_exponent == 0 && fres2 != ftmp))) {
-    fflags |= static_cast<uint32_t>(FPExceptions::kUnderflow);
-  }
-
-  // Handle flags for the specific inexact case.
-  if (!exact_conversion && (fres2 != ftmp)) {
-    fflags |= static_cast<uint32_t>(FPExceptions::kInexact);
-  }
-  return half_fp;
-}
-}  // namespace
-
-HalfFP ConvertSingleToHalfFP(float input_value, FPRoundingMode rm,
-                             uint32_t &fflags) {
-  return SoftConvertToHalfFP(input_value, rm, fflags);
-}
-
-HalfFP ConvertDoubleToHalfFP(double input_value, FPRoundingMode rm,
-                             uint32_t &fflags) {
-  return SoftConvertToHalfFP(input_value, rm, fflags);
-}
-
-namespace zfh_internal {
-bool UseHostFlagsForConversion() { return false; }
-}  // namespace zfh_internal
-
-}  // namespace riscv
-}  // namespace sim
-}  // namespace mpact
diff --git a/riscv/riscv_zfh_instructions_x86.cc b/riscv/riscv_zfh_instructions_x86.cc
deleted file mode 100644
index f99a0fa..0000000
--- a/riscv/riscv_zfh_instructions_x86.cc
+++ /dev/null
@@ -1,74 +0,0 @@
-// Copyright 2025 Google LLC
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     https://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include <immintrin.h>
-#include <sys/types.h>
-
-#include <cstdint>
-
-#include "mpact/sim/generic/type_helpers.h"
-#include "riscv/riscv_fp_info.h"
-#include "riscv/riscv_zfh_instructions.h"
-
-namespace mpact {
-namespace sim {
-namespace riscv {
-
-using HalfFP = ::mpact::sim::generic::HalfFP;
-
-HalfFP ConvertSingleToHalfFP(float input_value, FPRoundingMode rm,
-                             uint32_t &fflags) {
-  HalfFP half_fp;
-
-  // Get current MXCSR value. The simulator should have already configured the
-  // rounding mode so we simply pass it along to the intrinsic.
-  unsigned int mxcsr = _mm_getcsr();
-
-  // Extract rounding control bits (bits 13 and 14)
-  int rounding_control_bits = (mxcsr >> 13) & 0x3;
-
-  switch (rounding_control_bits) {
-    case 0x0:  // Round to nearest
-      half_fp.value = _cvtss_sh(input_value, 0);
-      break;
-    case 0x1:  // Round down
-      half_fp.value = _cvtss_sh(input_value, 1);
-      break;
-    case 0x2:  // Round up
-      half_fp.value = _cvtss_sh(input_value, 2);
-      break;
-    case 0x3:  // Round towards zero
-      half_fp.value = _cvtss_sh(input_value, 3);
-      break;
-    default:  // Default to nearest even if mode is not recognized
-      half_fp.value = _cvtss_sh(input_value, 0);
-      break;
-  }
-
-  return half_fp;
-}
-
-HalfFP ConvertDoubleToHalfFP(double input_value, FPRoundingMode rm,
-                             uint32_t &fflags) {
-  float input_float = static_cast<float>(input_value);
-  return ConvertSingleToHalfFP(input_float, rm, fflags);
-}
-
-namespace zfh_internal {
-bool UseHostFlagsForConversion() { return true; }
-}  // namespace zfh_internal
-
-}  // namespace riscv
-}  // namespace sim
-}  // namespace mpact
diff --git a/riscv/test/riscv_zfh_instructions_test.cc b/riscv/test/riscv_zfh_instructions_test.cc
index 372ddbd..65e01ce 100644
--- a/riscv/test/riscv_zfh_instructions_test.cc
+++ b/riscv/test/riscv_zfh_instructions_test.cc
@@ -797,6 +797,7 @@
               mpact::sim::riscv::FPExceptions::kInvalidOp);
           result.value = FPTypeInfo<HalfFP>::kCanonicalNaN;
         } else if (!FPTypeInfo<HalfFP>::IsNaN(a) &&
+                   !FPTypeInfo<HalfFP>::IsInf(a) &&
                    (b.value == FPTypeInfo<HalfFP>::kPosZero ||
                     b.value == FPTypeInfo<HalfFP>::kNegZero)) {
           // Dividing by zero requires an exception for non-NaN dividend values.