Internal changes to support a half precision datatype.

PiperOrigin-RevId: 752410882
Change-Id: Ia4a56b14b02db8c6d05ddc3e5831601f7b3b3749
diff --git a/riscv/riscv_instruction_helpers.h b/riscv/riscv_instruction_helpers.h
index 09a4e03..f408d13 100644
--- a/riscv/riscv_instruction_helpers.h
+++ b/riscv/riscv_instruction_helpers.h
@@ -483,7 +483,8 @@
     ScopedFPStatus set_fp_status(rv_fp->host_fp_interface(), rm_value);
     dest_value = operation(lhs);
   }
-  if (std::isnan(dest_value) && std::signbit(dest_value)) {
+  if (FPTypeInfo<Result>::IsNaN(dest_value) &&
+      FPTypeInfo<Result>::SignBit(dest_value)) {
     ResUint res_value = *reinterpret_cast<ResUint *>(&dest_value);
     res_value &= FPTypeInfo<Result>::kInfMask;
     dest_value = *reinterpret_cast<Result *>(&res_value);
@@ -601,7 +602,7 @@
     ScopedFPStatus fp_status(rv_fp->host_fp_interface(), rm_value);
     dest_value = operation(lhs, rhs);
   }
-  if (std::isnan(dest_value)) {
+  if (FPTypeInfo<Result>::IsNaN(dest_value)) {
     *reinterpret_cast<typename FPTypeInfo<Result>::UIntType *>(&dest_value) =
         FPTypeInfo<Result>::kCanonicalNaN;
   }
diff --git a/riscv/test/BUILD b/riscv/test/BUILD
index 6aaf46a..4332d87 100644
--- a/riscv/test/BUILD
+++ b/riscv/test/BUILD
@@ -62,6 +62,7 @@
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
         "@com_google_mpact-sim//mpact/sim/generic:instruction",
+        "@com_google_mpact-sim//mpact/sim/generic:type_helpers",
         "@com_google_mpact-sim//mpact/sim/util/memory",
     ],
 )
diff --git a/riscv/test/riscv_fp_test_base.h b/riscv/test/riscv_fp_test_base.h
index fb115c5..81c8b07 100644
--- a/riscv/test/riscv_fp_test_base.h
+++ b/riscv/test/riscv_fp_test_base.h
@@ -32,6 +32,7 @@
 #include "absl/types/span.h"
 #include "googlemock/include/gmock/gmock.h"
 #include "mpact/sim/generic/instruction.h"
+#include "mpact/sim/generic/type_helpers.h"
 #include "mpact/sim/util/memory/flat_demand_memory.h"
 #include "riscv/riscv_fp_host.h"
 #include "riscv/riscv_fp_info.h"
@@ -44,6 +45,11 @@
 namespace riscv {
 namespace test {
 
+using ::mpact::sim::generic::ConvertHalfToSingle;
+using ::mpact::sim::generic::FloatingPointToString;
+using ::mpact::sim::generic::HalfFP;
+using ::mpact::sim::generic::IsMpactFp;
+
 using ::mpact::sim::generic::Instruction;
 using ::mpact::sim::riscv::FPRoundingMode;
 using ::mpact::sim::util::FlatDemandMemory;
@@ -97,6 +103,7 @@
     IntType uint_val = *reinterpret_cast<IntType *>(&value);
     return IsNaN(value) && (((1ULL << (kSigSize - 1)) & uint_val) != 0);
   }
+  static bool IsInf(T value) { return std::isinf(value); }
 };
 
 template <>
@@ -124,6 +131,46 @@
     IntType uint_val = *reinterpret_cast<IntType *>(&value);
     return IsNaN(value) && (((1ULL << (kSigSize - 1)) & uint_val) != 0);
   }
+  static bool IsInf(T value) { return std::isinf(value); }
+};
+
+template <>
+struct FPTypeInfo<HalfFP> {
+  using T = HalfFP;
+  using IntType = uint16_t;
+  static const int kExpBias = 15;
+  static const int kBitSize = sizeof(HalfFP) << 3;
+  static const int kExpSize = 5;
+  static const int kSigSize = kBitSize - kExpSize - 1;  // 10 from the spec.
+  static const IntType kExpMask = ((1ULL << kExpSize) - 1) << kSigSize;
+  static const IntType kSigMask = (1ULL << kSigSize) - 1;
+  static const IntType kQNaN = kExpMask | (1ULL << (kSigSize - 1)) | 1;
+  static const IntType kSNaN = kExpMask | 1;
+  static const IntType kPosInf = kExpMask;
+  static const IntType kNegInf = kExpMask | (1ULL << (kBitSize - 1));
+  static const IntType kPosZero = 0;
+  static const IntType kNegZero = 1ULL << (kBitSize - 1);
+  static const IntType kPosDenorm = 1ULL << (kSigSize - 2);
+  static const IntType kNegDenorm =
+      (1ULL << (kBitSize - 1)) | (1ULL << (kSigSize - 2));
+  static const IntType kCanonicalNaN = 0x7e00;
+  // std::isnan won't work for half precision.
+  static bool IsNaN(T wrapper) {
+    IntType exp = (wrapper.value & kExpMask) >> kSigSize;
+    IntType sig = wrapper.value & kSigMask;
+    return (exp == (1 << kExpSize) - 1) && (sig != 0);
+  }
+  static bool IsQNaN(T value) {
+    IntType uint_val = *reinterpret_cast<IntType *>(&value);
+    IntType significand_msb = (uint_val >> (kSigSize - 1)) & 1;
+    return IsNaN(value) && (significand_msb != 0);
+  }
+  // std::isinf won't work for half precision.
+  static bool IsInf(T wrapper) {
+    IntType exp = (wrapper.value & kExpMask) >> kSigSize;
+    IntType sig = wrapper.value & kSigMask;
+    return (exp == (1 << kExpSize) - 1) && (sig == 0);
+  }
 };
 
 // Templated helper function for classifying fp numbers.
@@ -220,6 +267,14 @@
   }
 }
 
+template <>
+inline void FPCompare<HalfFP>(HalfFP op, HalfFP reg, int delta_position,
+                              absl::string_view str) {
+  float op_float = ConvertHalfToSingle(op);
+  float reg_float = ConvertHalfToSingle(reg);
+  FPCompare<float>(op_float, reg_float, delta_position, str);
+}
+
 template <typename FP>
 FP OptimizationBarrier(FP op) {
   asm volatile("" : "+X"(op));
@@ -232,22 +287,19 @@
 // part of the enable_if construct.
 template <typename S, typename D>
 struct EqualSize {
-  static const bool value = sizeof(S) == sizeof(D) &&
-                            std::is_floating_point<S>::value &&
+  static const bool value = sizeof(S) == sizeof(D) && IsMpactFp<S>::value &&
                             std::is_integral<D>::value;
 };
 
 template <typename S, typename D>
 struct GreaterSize {
   static const bool value =
-      sizeof(S) > sizeof(D) &&
-      std::is_floating_point<S>::value &&std::is_integral<D>::value;
+      sizeof(S) > sizeof(D) && IsMpactFp<S>::value &&std::is_integral<D>::value;
 };
 
 template <typename S, typename D>
 struct LessSize {
-  static const bool value = sizeof(S) < sizeof(D) &&
-                            std::is_floating_point<S>::value &&
+  static const bool value = sizeof(S) < sizeof(D) && IsMpactFp<S>::value &&
                             std::is_integral<D>::value;
 };
 
@@ -461,7 +513,12 @@
     *reinterpret_cast<LhsInt *>(&lhs_span[6]) = FPTypeInfo<LHS>::kPosDenorm;
     *reinterpret_cast<LhsInt *>(&lhs_span[7]) = FPTypeInfo<LHS>::kNegDenorm;
     for (int i = 0; i < kTestValueLength; i++) {
-      SetRegisterValues<LHS, LhsRegisterType>({{kR1Name, lhs_span[i]}});
+      if constexpr (std::is_integral<LHS>::value) {
+        SetRegisterValues<LHS, LhsRegisterType>({{kR1Name, lhs_span[i]}});
+      } else {
+        SetNaNBoxedRegisterValues<LHS, LhsRegisterType>(
+            {{kR1Name, lhs_span[i]}});
+      }
 
       for (int rm : {0, 1, 2, 3, 4}) {
         rv_fp_->SetRoundingMode(static_cast<FPRoundingMode>(rm));
@@ -479,7 +536,8 @@
                            .first->data_buffer()
                            ->template Get<R>(0);
         FPCompare<R>(op_val, reg_val, delta_position,
-                     absl::StrCat(name, "  ", i, ": ", lhs_span[i]));
+                     absl::StrCat(name, "  ", i, ": ",
+                                  FloatingPointToString<LHS>(lhs_span[i])));
       }
     }
   }
@@ -523,7 +581,12 @@
     *reinterpret_cast<LhsInt *>(&lhs_span[6]) = FPTypeInfo<LHS>::kPosDenorm;
     *reinterpret_cast<LhsInt *>(&lhs_span[7]) = FPTypeInfo<LHS>::kNegDenorm;
     for (int i = 0; i < kTestValueLength; i++) {
-      SetNaNBoxedRegisterValues<LHS, LhsRegisterType>({{kR1Name, lhs_span[i]}});
+      if constexpr (std::is_integral<LHS>::value) {
+        SetRegisterValues<LHS, LhsRegisterType>({{kR1Name, lhs_span[i]}});
+      } else {
+        SetNaNBoxedRegisterValues<LHS, LhsRegisterType>(
+            {{kR1Name, lhs_span[i]}});
+      }
 
       for (int rm : {0, 1, 2, 3, 4}) {
         rv_fp_->SetRoundingMode(static_cast<FPRoundingMode>(rm));
@@ -532,13 +595,13 @@
         SetRegisterValues<R, DestRegisterType>({{kRdName, 0}});
 
         inst->Execute(nullptr);
-        auto fflags = rv_fp_->fflags()->GetUint32();
+        auto instruction_fflags = rv_fp_->fflags()->GetUint32();
 
         R op_val;
-        uint32_t flag;
+        uint32_t test_operation_fflags;
         {
           ScopedFPRoundingMode scoped_rm(rv_fp_->host_fp_interface(), rm);
-          std::tie(op_val, flag) = operation(lhs_span[i], rm);
+          std::tie(op_val, test_operation_fflags) = operation(lhs_span[i], rm);
         }
 
         auto reg_val = state_->GetRegister<DestRegisterType>(kRdName)
@@ -546,12 +609,14 @@
                            ->template Get<R>(0);
         FPCompare<R>(
             op_val, reg_val, delta_position,
-            absl::StrCat(name, "  ", i, ": ", lhs_span[i], " rm: ", rm));
+            absl::StrCat(name, "  ", i, ": ",
+                         FloatingPointToString<LHS>(lhs_span[i]), " rm: ", rm));
         auto lhs_uint = *reinterpret_cast<LhsInt *>(&lhs_span[i]);
         auto op_val_uint = *reinterpret_cast<RInt *>(&op_val);
-        EXPECT_EQ(flag, fflags)
-            << name << "(" << lhs_span[i] << ")  " << std::hex << name << "(0x"
-            << lhs_uint << ") == " << op_val << std::hex << "  0x"
+        EXPECT_EQ(test_operation_fflags, instruction_fflags)
+            << name << "(" << FloatingPointToString<LHS>(lhs_span[i]) << ")  "
+            << std::hex << name << "(0x" << lhs_uint
+            << ") == " << FloatingPointToString<R>(op_val) << std::hex << "  0x"
             << op_val_uint << " rm: " << rm;
       }
     }
@@ -623,9 +688,11 @@
         auto reg_val = state_->GetRegister<DestRegisterType>(kRdName)
                            .first->data_buffer()
                            ->template Get<R>(0);
-        FPCompare<R>(op_val, reg_val, delta_position,
-                     absl::StrCat(name, "  ", i, ": ", lhs_span[i], "  ",
-                                  rhs_span[i], " rm: ", rm));
+        FPCompare<R>(
+            op_val, reg_val, delta_position,
+            absl::StrCat(name, "  ", i, ": ",
+                         FloatingPointToString<LHS>(lhs_span[i]), "  ",
+                         FloatingPointToString<RHS>(rhs_span[i]), " rm: ", rm));
       }
       if (HasFailure()) return;
     }
@@ -693,23 +760,25 @@
 
         inst->Execute(nullptr);
 
-        auto fflags = rv_fp_->fflags()->GetUint32();
+        auto instruction_fflags = rv_fp_->fflags()->GetUint32();
 
         R op_val;
-        uint32_t flag;
+        uint32_t test_operation_fflags;
         {
           ScopedFPStatus set_fpstatus(rv_fp_->host_fp_interface());
-          std::tie(op_val, flag) = operation(lhs_span[i], rhs_span[i]);
+          std::tie(op_val, test_operation_fflags) =
+              operation(lhs_span[i], rhs_span[i]);
         }
         auto reg_val = state_->GetRegister<DestRegisterType>(kRdName)
                            .first->data_buffer()
                            ->template Get<R>(0);
-        FPCompare<R>(
-            op_val, reg_val, delta_position,
-            absl::StrCat(name, "  ", i, ": ", lhs_span[i], "  ", rhs_span[i]));
+        FPCompare<R>(op_val, reg_val, delta_position,
+                     absl::StrCat(name, "  ", i, ": ",
+                                  FloatingPointToString<LHS>(lhs_span[i]), "  ",
+                                  FloatingPointToString<RHS>(rhs_span[i])));
         auto lhs_uint = *reinterpret_cast<LhsUInt *>(&lhs_span[i]);
         auto rhs_uint = *reinterpret_cast<RhsUInt *>(&rhs_span[i]);
-        EXPECT_EQ(flag, fflags)
+        EXPECT_EQ(test_operation_fflags, instruction_fflags)
             << std::hex << name << "(" << lhs_uint << ", " << rhs_uint << ")";
       }
     }
@@ -790,8 +859,10 @@
                            .first->data_buffer()
                            ->template Get<R>(0);
         FPCompare<R>(op_val, reg_val, delta_position,
-                     absl::StrCat(name, "  ", i, ": ", lhs_span[i], "  ",
-                                  mhs_span[i], " ", rhs_span[i]));
+                     absl::StrCat(name, "  ", i, ": ",
+                                  FloatingPointToString<LHS>(lhs_span[i]), "  ",
+                                  FloatingPointToString<MHS>(mhs_span[i]), "  ",
+                                  FloatingPointToString<RHS>(rhs_span[i])));
       }
     }
   }
@@ -865,7 +936,7 @@
 
         inst->Execute(nullptr);
         // Get the fflags for the instruction execution.
-        auto fflags = rv_fp_->fflags()->GetUint32();
+        auto instruction_fflags = rv_fp_->fflags()->GetUint32();
         rv_fp_->fflags()->Write(static_cast<uint32_t>(0));
         R op_val;
         {
@@ -876,13 +947,16 @@
                            .first->data_buffer()
                            ->template Get<R>(0);
         FPCompare<R>(op_val, reg_val, delta_position,
-                     absl::StrCat(name, "  ", i, ": ", lhs_span[i], "  ",
-                                  mhs_span[i], " ", rhs_span[i]));
+                     absl::StrCat(name, "  ", i, ": ",
+                                  FloatingPointToString<LHS>(lhs_span[i]), "  ",
+                                  FloatingPointToString<MHS>(mhs_span[i]), "  ",
+                                  FloatingPointToString<RHS>(rhs_span[i])));
 
-        auto flag = rv_fp_->fflags()->GetUint32();
-        EXPECT_EQ(flag, fflags)
-            << absl::StrCat(name, "  ", i, ": ", lhs_span[i], "  ", mhs_span[i],
-                            " ", rhs_span[i]);
+        auto test_operation_fflags = rv_fp_->fflags()->GetUint32();
+        EXPECT_EQ(test_operation_fflags, instruction_fflags) << absl::StrCat(
+            name, "  ", i, ": ", FloatingPointToString<LHS>(lhs_span[i]), " ",
+            FloatingPointToString<MHS>(mhs_span[i]), " ",
+            FloatingPointToString<RHS>(rhs_span[i]));
       }
     }
   }