RiscVZfh: Simplify semantic functions for fp conversions

PiperOrigin-RevId: 755095617
Change-Id: Icf8c0a6c9278df7749bc00c642c38e6680feea18
diff --git a/riscv/riscv_zfh_instructions.cc b/riscv/riscv_zfh_instructions.cc
index 8f20a82..cc4d4aa 100644
--- a/riscv/riscv_zfh_instructions.cc
+++ b/riscv/riscv_zfh_instructions.cc
@@ -17,6 +17,7 @@
 #include <cstdint>
 #include <functional>
 #include <limits>
+#include <type_traits>
 
 #include "absl/base/casts.h"
 #include "absl/log/log.h"
@@ -35,9 +36,48 @@
 namespace riscv {
 
 using HalfFP = ::mpact::sim::generic::HalfFP;
+using ::mpact::sim::generic::IsMpactFp;
 
 namespace {
 
+template <typename T>
+struct DataTypeRegValue {};
+
+template <>
+struct DataTypeRegValue<float> {
+  using type = RVFpRegister::ValueType;
+};
+
+template <>
+struct DataTypeRegValue<double> {
+  using type = RVFpRegister::ValueType;
+};
+
+template <>
+struct DataTypeRegValue<HalfFP> {
+  using type = RVFpRegister::ValueType;
+};
+
+template <>
+struct DataTypeRegValue<int32_t> {
+  using type = RVXRegister::ValueType;
+};
+
+template <>
+struct DataTypeRegValue<uint32_t> {
+  using type = RVXRegister::ValueType;
+};
+
+template <>
+struct DataTypeRegValue<int64_t> {
+  using type = RVXRegister::ValueType;
+};
+
+template <>
+struct DataTypeRegValue<uint64_t> {
+  using type = RVXRegister::ValueType;
+};
+
 // Convert from half precision to single or double precision.
 template <typename T>
 inline T ConvertFromHalfFP(HalfFP half_fp, uint32_t &fflags) {
@@ -96,37 +136,77 @@
 void RiscVZfhCvtHelper(
     const Instruction *instruction,
     std::function<Result(Argument, FPRoundingMode, uint32_t &)> operation) {
+  RiscVCsrDestinationOperand *fflags_dest =
+      static_cast<RiscVCsrDestinationOperand *>(instruction->Destination(1));
+  using DstRegValue = typename DataTypeRegValue<Result>::type;
   uint32_t fflags = 0;
-  RiscVFPState *fp_state =
-      static_cast<RiscVState *>(instruction->state())->rv_fp();
+
+  Argument lhs;
+  if constexpr (IsMpactFp<Argument>::value) {
+    lhs = GetNaNBoxedSource<RVFpRegister::ValueType, Argument>(instruction, 0);
+    if (FPTypeInfo<Argument>::IsSNaN(lhs)) {
+      fflags_dest->GetRiscVCsr()->SetBits(*FPExceptions::kInvalidOp);
+    }
+  } else {
+    lhs = generic::GetInstructionSource<Argument>(instruction, 0);
+  }
+  // Get the rounding mode.
   int rm_value = generic::GetInstructionSource<int>(instruction, 1);
 
+  auto *rv_fp = static_cast<RiscVState *>(instruction->state())->rv_fp();
   // If the rounding mode is dynamic, read it from the current state.
   if (rm_value == *FPRoundingMode::kDynamic) {
-    if (!fp_state->rounding_mode_valid()) {
+    if (!rv_fp->rounding_mode_valid()) {
       LOG(ERROR) << "Invalid rounding mode";
       return;
     }
-    rm_value = *(fp_state->GetRoundingMode());
+    rm_value = *rv_fp->GetRoundingMode();
   }
-  FPRoundingMode rm = static_cast<FPRoundingMode>(rm_value);
-  RiscVCsrDestinationOperand *fflags_dest =
-      static_cast<RiscVCsrDestinationOperand *>(instruction->Destination(1));
-  RiscVUnaryFloatNaNBoxOp<RVFpRegister::ValueType, RVFpRegister::ValueType,
-                          Result, Argument>(
-      instruction, [fp_state, rm, &fflags, &operation](Argument a) -> Result {
-        Result result;
-        if (zfh_internal::UseHostFlagsForConversion()) {
-          result = operation(a, rm, fflags);
-        } else {
-          ScopedFPStatus set_fpstatus(fp_state->host_fp_interface(), rm);
-          result = operation(a, rm, fflags);
-        }
-        return result;
-      });
+
+  Result dest_value;
+  if (zfh_internal::UseHostFlagsForConversion()) {
+    ScopedFPStatus set_fp_status(rv_fp->host_fp_interface(), rm_value);
+    dest_value = operation(lhs, static_cast<FPRoundingMode>(rm_value), fflags);
+  } else {
+    ScopedFPRoundingMode scoped_rm(rv_fp->host_fp_interface(), rm_value);
+    dest_value = operation(lhs, static_cast<FPRoundingMode>(rm_value), fflags);
+  }
   if (!zfh_internal::UseHostFlagsForConversion()) {
     fflags_dest->GetRiscVCsr()->SetBits(fflags);
   }
+  auto *reg = static_cast<generic::RegisterDestinationOperand<DstRegValue> *>(
+                  instruction->Destination(0))
+                  ->GetRegister();
+
+  if (sizeof(DstRegValue) > sizeof(Result) && IsMpactFp<Result>::value) {
+    // If the floating point value is narrower than the register, the upper
+    // bits have to be set to all ones.
+    using UReg = typename std::make_unsigned<DstRegValue>::type;
+    using UInt = typename FPTypeInfo<Result>::UIntType;
+    auto dest_u_value = *reinterpret_cast<UInt *>(&dest_value);
+    UReg reg_value = std::numeric_limits<UReg>::max();
+    int shift = 8 * sizeof(Result);
+    reg_value = (reg_value << shift) | dest_u_value;
+    reg->data_buffer()->template Set<DstRegValue>(0, reg_value);
+    return;
+  }
+  reg->data_buffer()->template Set<Result>(0, dest_value);
+}
+
+template <typename T>
+inline HalfFP ConvertToHalfFP(T input_value, FPRoundingMode rm,
+                              uint32_t &fflags);
+
+template <>
+inline HalfFP ConvertToHalfFP(float input_value, FPRoundingMode rm,
+                              uint32_t &fflags) {
+  return ConvertSingleToHalfFP(input_value, rm, fflags);
+}
+
+template <>
+inline HalfFP ConvertToHalfFP(double input_value, FPRoundingMode rm,
+                              uint32_t &fflags) {
+  return ConvertDoubleToHalfFP(input_value, rm, fflags);
 }
 
 }  // namespace
@@ -191,43 +271,33 @@
 
 // Convert from half precision to single precision.
 void RiscVZfhCvtSh(const Instruction *instruction) {
-  uint32_t fflags = 0;
-  RiscVCsrDestinationOperand *fflags_dest =
-      static_cast<RiscVCsrDestinationOperand *>(instruction->Destination(1));
-  RiscVUnaryFloatNaNBoxOp<RVFpRegister::ValueType, RVFpRegister::ValueType,
-                          float, HalfFP>(
-      instruction, [&fflags](HalfFP a) -> float {
+  RiscVZfhCvtHelper<float, HalfFP>(
+      instruction, [](HalfFP a, FPRoundingMode rm, uint32_t &fflags) -> float {
         return ConvertFromHalfFP<float>(a, fflags);
       });
-  fflags_dest->GetRiscVCsr()->SetBits(fflags);
 }
 
 // Convert from single precision to half precision.
 void RiscVZfhCvtHs(const Instruction *instruction) {
   RiscVZfhCvtHelper<HalfFP, float>(
       instruction, [](float a, FPRoundingMode rm, uint32_t &fflags) -> HalfFP {
-        return ConvertSingleToHalfFP(a, rm, fflags);
+        return ConvertToHalfFP(a, rm, fflags);
       });
 }
 
 // Convert from half precision to double precision.
 void RiscVZfhCvtDh(const Instruction *instruction) {
-  uint32_t fflags = 0;
-  RiscVCsrDestinationOperand *fflags_dest =
-      static_cast<RiscVCsrDestinationOperand *>(instruction->Destination(1));
-  RiscVUnaryFloatNaNBoxOp<RVFpRegister::ValueType, RVFpRegister::ValueType,
-                          double, HalfFP>(
-      instruction, [&fflags](HalfFP a) -> double {
+  RiscVZfhCvtHelper<double, HalfFP>(
+      instruction, [](HalfFP a, FPRoundingMode rm, uint32_t &fflags) -> double {
         return ConvertFromHalfFP<double>(a, fflags);
       });
-  fflags_dest->GetRiscVCsr()->SetBits(fflags);
 }
 
 // Convert from double precision to half precision.
 void RiscVZfhCvtHd(const Instruction *instruction) {
   RiscVZfhCvtHelper<HalfFP, double>(
       instruction, [](double a, FPRoundingMode rm, uint32_t &fflags) -> HalfFP {
-        return ConvertDoubleToHalfFP(a, rm, fflags);
+        return ConvertToHalfFP(a, rm, fflags);
       });
 }