No public description

PiperOrigin-RevId: 785945988
Change-Id: I4fdc18c1e7fe502c091834cc0232e5b8f82eb846
diff --git a/riscv/riscv_zfh_instructions.cc b/riscv/riscv_zfh_instructions.cc
index c0dfcd8..fc8baac 100644
--- a/riscv/riscv_zfh_instructions.cc
+++ b/riscv/riscv_zfh_instructions.cc
@@ -84,12 +84,67 @@
   using type = RV64Register::ValueType;
 };
 
+// Convert from half precision to single or double precision.
+template <typename T>
+inline T ConvertFromHalfFP(HalfFP half_fp, uint32_t &fflags) {
+  using UIntType = typename FPTypeInfo<T>::UIntType;
+  using HalfFPUIntType = typename FPTypeInfo<HalfFP>::UIntType;
+  HalfFPUIntType in_int = half_fp.value;
+
+  if (FPTypeInfo<HalfFP>::IsNaN(half_fp)) {
+    if (FPTypeInfo<HalfFP>::IsSNaN(half_fp)) {
+      fflags |= static_cast<uint32_t>(FPExceptions::kInvalidOp);
+    }
+    UIntType uint_value = FPTypeInfo<T>::kCanonicalNaN;
+    return absl::bit_cast<T>(uint_value);
+  }
+
+  if (FPTypeInfo<HalfFP>::IsInf(half_fp)) {
+    UIntType uint_value = FPTypeInfo<T>::kPosInf;
+    UIntType sign = in_int >> (FPTypeInfo<HalfFP>::kBitSize - 1);
+    uint_value |= sign << (FPTypeInfo<T>::kBitSize - 1);
+    return absl::bit_cast<T>(uint_value);
+  }
+
+  if ((in_int == 0) || (in_int == (1 << (FPTypeInfo<HalfFP>::kBitSize - 1)))) {
+    UIntType uint_value =
+        static_cast<UIntType>(in_int)
+        << (FPTypeInfo<T>::kBitSize - FPTypeInfo<HalfFP>::kBitSize);
+    return absl::bit_cast<T>(uint_value);
+  }
+
+  UIntType in_sign = FPTypeInfo<HalfFP>::SignBit(half_fp);
+  UIntType in_exp =
+      (in_int & FPTypeInfo<HalfFP>::kExpMask) >> FPTypeInfo<HalfFP>::kSigSize;
+  UIntType in_sig = in_int & FPTypeInfo<HalfFP>::kSigMask;
+  UIntType out_int = 0;
+  UIntType out_sig = in_sig;
+  if ((in_exp == 0) && (in_sig != 0)) {
+    // Handle subnormal half precision inputs. They always result in a normal
+    // float or double. Calculate how much shifting is needed to move the MSB to
+    // the location of the implicit bit. Then it can be handled as a normal
+    // value from here on.
+    int32_t shift_count =
+        (1 + FPTypeInfo<HalfFP>::kSigSize) -
+        (std::numeric_limits<UIntType>::digits - absl::countl_zero(out_sig));
+    out_sig = (out_sig << shift_count) & FPTypeInfo<HalfFP>::kSigMask;
+    in_exp = 1 - shift_count;
+  }
+  out_int |= in_sign << (FPTypeInfo<T>::kBitSize - 1);
+  out_int |= (in_exp + FPTypeInfo<T>::kExpBias - FPTypeInfo<HalfFP>::kExpBias)
+             << FPTypeInfo<T>::kSigSize;
+  out_int |=
+      out_sig << (FPTypeInfo<T>::kSigSize - FPTypeInfo<HalfFP>::kSigSize);
+  return absl::bit_cast<T>(out_int);
+}
+
 // This is a soft conversion from a float or double to a half precision value.
 // It is not a direct conversion from the floating point format to the half
 // format. Instead, it uses the floating point hardware to do the conversion.
 // This is done to get the correct rounding behavior for free from the FPU.
 template <typename T>
-HalfFP ConvertToHalfFP(T input_value, FPRoundingMode rm, uint32_t &fflags) {
+inline HalfFP ConvertToHalfFP(T input_value, FPRoundingMode rm,
+                              uint32_t &fflags) {
   using UIntType = typename FPTypeInfo<T>::UIntType;
   using IntType = typename FPTypeInfo<T>::IntType;
   UIntType in_int = absl::bit_cast<UIntType>(input_value);
@@ -115,7 +170,7 @@
     return half_fp;
   }
 
-  if (in_int == 0 || in_int == 1ULL << (FPTypeInfo<T>::kBitSize - 1)) {
+  if ((in_int == 0) || (in_int == (1ULL << (FPTypeInfo<T>::kBitSize - 1)))) {
     half_fp.value =
         in_int >> (FPTypeInfo<T>::kBitSize - FPTypeInfo<HalfFP>::kBitSize);
     return half_fp;
@@ -218,22 +273,14 @@
                   (half_exponent << FPTypeInfo<HalfFP>::kSigSize) |
                   (sign << (FPTypeInfo<HalfFP>::kBitSize - 1));
 
-  // Do an arithmetic reconstruction of the float to check for exactness.
-  T trailing_significand_float = static_cast<T>(half_mantissa);
-  T precision_factor = std::pow(2.0, -1.0 * FPTypeInfo<HalfFP>::kSigSize);
-  IntType unbiased_exponent =
-      (half_exponent == 0 ? 1 : half_exponent) - FPTypeInfo<HalfFP>::kExpBias;
-  T exponent_factor = std::pow(2.0, unbiased_exponent);
-  T sign_factor = sign == 1 ? -1.0 : 1.0;
-  T implicit_bit_adjustment = half_exponent == 0 ? 0.0 : 1.0;
-  T reconstructed_value = ((trailing_significand_float * precision_factor) +
-                           implicit_bit_adjustment) *
-                          exponent_factor * sign_factor;
+  uint32_t temp_fflags = 0;
+  T reconstructed_value = ConvertFromHalfFP<T>(half_fp, temp_fflags);
   bool exact_conversion = reconstructed_value == input_value;
 
   // Handle flags for the specific underflow case.
-  if (!exact_conversion && (unbounded_half_exponent < 0 ||
-                            (unbounded_half_exponent == 0 && fres2 != ftmp))) {
+  if (!exact_conversion &&
+      ((unbounded_half_exponent < 0) ||
+       ((unbounded_half_exponent == 0) && (fres2 != ftmp)))) {
     fflags |= static_cast<uint32_t>(FPExceptions::kUnderflow);
   }
 
@@ -244,60 +291,6 @@
   return half_fp;
 }
 
-// Convert from half precision to single or double precision.
-template <typename T>
-inline T ConvertFromHalfFP(HalfFP half_fp, uint32_t &fflags) {
-  using UIntType = typename FPTypeInfo<T>::UIntType;
-  using HalfFPUIntType = typename FPTypeInfo<HalfFP>::UIntType;
-  HalfFPUIntType in_int = half_fp.value;
-
-  if (FPTypeInfo<HalfFP>::IsNaN(half_fp)) {
-    if (FPTypeInfo<HalfFP>::IsSNaN(half_fp)) {
-      fflags |= static_cast<uint32_t>(FPExceptions::kInvalidOp);
-    }
-    UIntType uint_value = FPTypeInfo<T>::kCanonicalNaN;
-    return absl::bit_cast<T>(uint_value);
-  }
-
-  if (FPTypeInfo<HalfFP>::IsInf(half_fp)) {
-    UIntType uint_value = FPTypeInfo<T>::kPosInf;
-    UIntType sign = in_int >> (FPTypeInfo<HalfFP>::kBitSize - 1);
-    uint_value |= sign << (FPTypeInfo<T>::kBitSize - 1);
-    return absl::bit_cast<T>(uint_value);
-  }
-
-  if (in_int == 0 || in_int == 1 << (FPTypeInfo<HalfFP>::kBitSize - 1)) {
-    UIntType uint_value =
-        static_cast<UIntType>(in_int)
-        << (FPTypeInfo<T>::kBitSize - FPTypeInfo<HalfFP>::kBitSize);
-    return absl::bit_cast<T>(uint_value);
-  }
-
-  UIntType in_sign = FPTypeInfo<HalfFP>::SignBit(half_fp);
-  UIntType in_exp =
-      (in_int & FPTypeInfo<HalfFP>::kExpMask) >> FPTypeInfo<HalfFP>::kSigSize;
-  UIntType in_sig = in_int & FPTypeInfo<HalfFP>::kSigMask;
-  UIntType out_int = 0;
-  UIntType out_sig = in_sig;
-  if (in_exp == 0 && in_sig != 0) {
-    // Handle subnormal half precision inputs. They always result in a normal
-    // float or double. Calculate how much shifting is needed move the MSB to
-    // the location of the implicit bit. Then it can be handled as a normal
-    // value from here on.
-    int32_t shift_count =
-        (1 + FPTypeInfo<HalfFP>::kSigSize) -
-        (std::numeric_limits<UIntType>::digits - absl::countl_zero(out_sig));
-    out_sig = (out_sig << shift_count) & FPTypeInfo<HalfFP>::kSigMask;
-    in_exp = 1 - shift_count;
-  }
-  out_int |= in_sign << (FPTypeInfo<T>::kBitSize - 1);
-  out_int |= (in_exp + FPTypeInfo<T>::kExpBias - FPTypeInfo<HalfFP>::kExpBias)
-             << FPTypeInfo<T>::kSigSize;
-  out_int |=
-      out_sig << (FPTypeInfo<T>::kSigSize - FPTypeInfo<HalfFP>::kSigSize);
-  return absl::bit_cast<T>(out_int);
-}
-
 template <typename Result, typename Argument>
 void RiscVZfhCvtHelper(
     const Instruction *instruction,
diff --git a/riscv/test/riscv_zfh_instructions_test.cc b/riscv/test/riscv_zfh_instructions_test.cc
index 154e3a4..1630e05 100644
--- a/riscv/test/riscv_zfh_instructions_test.cc
+++ b/riscv/test/riscv_zfh_instructions_test.cc
@@ -521,7 +521,10 @@
   UnaryOpWithFflagsMixedTestHelper<RVFpRegister, XRegister, HalfFP,
                                    SourceIntegerType>(
       name, this->instruction_, {"x", "f"}, 32,
-      [](SourceIntegerType input_int, int rm) -> std::tuple<HalfFP, uint32_t> {
+      [this](SourceIntegerType input_int,
+             int rm) -> std::tuple<HalfFP, uint32_t> {
+        ScopedFPRoundingMode scoped_rm(this->rv_fp_->host_fp_interface(),
+                                       FPRoundingMode::kRoundToNearest);
         uint32_t fflags = 0;
         HalfFP result = FpConversionsTestHelper(static_cast<double>(input_int))
                             .ConvertWithFlags<HalfFP>(
@@ -540,6 +543,8 @@
       name, this->instruction_, {"f", "x"}, 32,
       [this](HalfFP input,
              int rm) -> std::tuple<DestinationIntegerType, uint32_t> {
+        ScopedFPRoundingMode scoped_rm(this->rv_fp_->host_fp_interface(),
+                                       FPRoundingMode::kRoundToNearest);
         uint32_t fflags = 0;
         double input_double =
             FpConversionsTestHelper(input).ConvertWithFlags<double>(fflags);
@@ -832,7 +837,9 @@
   SetSemanticFunction(&RiscVZfhCvtHs);
   UnaryOpWithFflagsFPTestHelper<HalfFP, float>(
       "fcvt.h.s", instruction_, {"f", "f"}, 32,
-      [](float input_float, int rm) -> std::tuple<HalfFP, uint32_t> {
+      [this](float input_float, int rm) -> std::tuple<HalfFP, uint32_t> {
+        ScopedFPRoundingMode scoped_rm(this->rv_fp_->host_fp_interface(),
+                                       FPRoundingMode::kRoundToNearest);
         uint32_t fflags = 0;
         HalfFP half_result = FpConversionsTestHelper(input_float)
                                  .ConvertWithFlags<HalfFP>(
@@ -1086,6 +1093,8 @@
       "fadd.h", instruction_, {"f", "f", "f"}, 32,
       [this](HalfFP a, HalfFP b) -> std::tuple<HalfFP, uint32_t> {
         FPRoundingMode rm = rv_fp_->GetRoundingMode();
+        ScopedFPRoundingMode scoped_rm(this->rv_fp_->host_fp_interface(),
+                                       FPRoundingMode::kRoundToNearest);
         uint32_t fflags = 0;
         double a_f =
             FpConversionsTestHelper(a).ConvertWithFlags<double>(fflags);
@@ -1112,6 +1121,8 @@
       "fsub.h", instruction_, {"f", "f", "f"}, 32,
       [this](HalfFP a, HalfFP b) -> std::tuple<HalfFP, uint32_t> {
         FPRoundingMode rm = rv_fp_->GetRoundingMode();
+        ScopedFPRoundingMode scoped_rm(this->rv_fp_->host_fp_interface(),
+                                       FPRoundingMode::kRoundToNearest);
         uint32_t fflags = 0;
         double a_f =
             FpConversionsTestHelper(a).ConvertWithFlags<double>(fflags);
@@ -1138,6 +1149,8 @@
       "fmul.h", instruction_, {"f", "f", "f"}, 32,
       [this](HalfFP a, HalfFP b) -> std::tuple<HalfFP, uint32_t> {
         FPRoundingMode rm = rv_fp_->GetRoundingMode();
+        ScopedFPRoundingMode scoped_rm(this->rv_fp_->host_fp_interface(),
+                                       FPRoundingMode::kRoundToNearest);
         uint32_t fflags = 0;
         double a_f =
             FpConversionsTestHelper(a).ConvertWithFlags<double>(fflags);
@@ -1164,6 +1177,8 @@
       "fdiv.h", instruction_, {"f", "f", "f"}, 32,
       [this](HalfFP a, HalfFP b) -> std::tuple<HalfFP, uint32_t> {
         FPRoundingMode rm = rv_fp_->GetRoundingMode();
+        ScopedFPRoundingMode scoped_rm(this->rv_fp_->host_fp_interface(),
+                                       FPRoundingMode::kRoundToNearest);
         uint32_t fflags = 0;
         double a_f =
             FpConversionsTestHelper(a).ConvertWithFlags<double>(fflags);
@@ -1199,6 +1214,8 @@
       "fmin.h", instruction_, {"f", "f", "f"}, 32,
       [this](HalfFP a, HalfFP b) -> std::tuple<HalfFP, uint32_t> {
         FPRoundingMode rm = rv_fp_->GetRoundingMode();
+        ScopedFPRoundingMode scoped_rm(this->rv_fp_->host_fp_interface(),
+                                       FPRoundingMode::kRoundToNearest);
         uint32_t fflags = 0;
         double a_f =
             FpConversionsTestHelper(a).ConvertWithFlags<double>(fflags);
@@ -1239,6 +1256,8 @@
       "fmax.h", instruction_, {"f", "f", "f"}, 32,
       [this](HalfFP a, HalfFP b) -> std::tuple<HalfFP, uint32_t> {
         FPRoundingMode rm = rv_fp_->GetRoundingMode();
+        ScopedFPRoundingMode scoped_rm(this->rv_fp_->host_fp_interface(),
+                                       FPRoundingMode::kRoundToNearest);
         uint32_t fflags = 0;
         double a_f =
             FpConversionsTestHelper(a).ConvertWithFlags<double>(fflags);
@@ -1321,7 +1340,9 @@
   SetSemanticFunction(&RiscVZfhFsqrt);
   UnaryOpWithFflagsFPTestHelper<HalfFP, HalfFP>(
       "fsqrt.h", instruction_, {"f", "f"}, 32,
-      [](HalfFP input_half, int rm) -> std::tuple<HalfFP, uint32_t> {
+      [this](HalfFP input_half, int rm) -> std::tuple<HalfFP, uint32_t> {
+        ScopedFPRoundingMode scoped_rm(this->rv_fp_->host_fp_interface(),
+                                       FPRoundingMode::kRoundToNearest);
         uint32_t fflags = 0;
         double input_double_f = FpConversionsTestHelper(input_half)
                                     .ConvertWithFlags<double>(fflags);
@@ -1402,6 +1423,8 @@
       "fmadd.h", instruction_, {"f", "f", "f", "f"}, 10,
       [this](HalfFP a, HalfFP b, HalfFP c) -> std::tuple<HalfFP, uint32_t> {
         FPRoundingMode rm = rv_fp_->GetRoundingMode();
+        ScopedFPRoundingMode scoped_rm(this->rv_fp_->host_fp_interface(),
+                                       FPRoundingMode::kRoundToNearest);
         uint32_t fflags = 0;
         double a_f =
             FpConversionsTestHelper(a).ConvertWithFlags<double>(fflags);
@@ -1437,6 +1460,8 @@
       "fmsub.h", instruction_, {"f", "f", "f", "f"}, 10,
       [this](HalfFP a, HalfFP b, HalfFP c) -> std::tuple<HalfFP, uint32_t> {
         FPRoundingMode rm = rv_fp_->GetRoundingMode();
+        ScopedFPRoundingMode scoped_rm(this->rv_fp_->host_fp_interface(),
+                                       FPRoundingMode::kRoundToNearest);
         uint32_t fflags = 0;
         double a_f =
             FpConversionsTestHelper(a).ConvertWithFlags<double>(fflags);
@@ -1472,6 +1497,8 @@
       "fnmadd.h", instruction_, {"f", "f", "f", "f"}, 10,
       [this](HalfFP a, HalfFP b, HalfFP c) -> std::tuple<HalfFP, uint32_t> {
         FPRoundingMode rm = rv_fp_->GetRoundingMode();
+        ScopedFPRoundingMode scoped_rm(this->rv_fp_->host_fp_interface(),
+                                       FPRoundingMode::kRoundToNearest);
         uint32_t fflags = 0;
         double a_f =
             FpConversionsTestHelper(a).ConvertWithFlags<double>(fflags);
@@ -1507,6 +1534,8 @@
       "fnmsub.h", instruction_, {"f", "f", "f", "f"}, 10,
       [this](HalfFP a, HalfFP b, HalfFP c) -> std::tuple<HalfFP, uint32_t> {
         FPRoundingMode rm = rv_fp_->GetRoundingMode();
+        ScopedFPRoundingMode scoped_rm(this->rv_fp_->host_fp_interface(),
+                                       FPRoundingMode::kRoundToNearest);
         uint32_t fflags = 0;
         double a_f =
             FpConversionsTestHelper(a).ConvertWithFlags<double>(fflags);
@@ -1663,6 +1692,8 @@
   UnaryOpWithFflagsMixedTestHelper<RV64Register, RVFpRegister, int64_t, HalfFP>(
       "fcvt.l.h", instruction_, {"f", "x"}, 32,
       [this](HalfFP input, int rm) -> std::tuple<int64_t, uint32_t> {
+        ScopedFPRoundingMode scoped_rm(this->rv_fp_->host_fp_interface(),
+                                       FPRoundingMode::kRoundToNearest);
         uint32_t fflags = 0;
         double input_double =
             FpConversionsTestHelper(input).ConvertWithFlags<double>(fflags);
@@ -1679,6 +1710,8 @@
                                    HalfFP>(
       "fcvt.lu.h", instruction_, {"f", "x"}, 32,
       [this](HalfFP input, int rm) -> std::tuple<uint64_t, uint32_t> {
+        ScopedFPRoundingMode scoped_rm(this->rv_fp_->host_fp_interface(),
+                                       FPRoundingMode::kRoundToNearest);
         uint32_t fflags = 0;
         double input_double =
             FpConversionsTestHelper(input).ConvertWithFlags<double>(fflags);