Internal changes to support a half precision datatype.

PiperOrigin-RevId: 752410882
Change-Id: Ic2e446fe63771c06fbc772a77838c6e4165f6849
diff --git a/mpact/sim/generic/BUILD b/mpact/sim/generic/BUILD
index 5e6e266..b9ac8e6 100644
--- a/mpact/sim/generic/BUILD
+++ b/mpact/sim/generic/BUILD
@@ -147,6 +147,7 @@
     deps = [
         ":arch_state",
         ":core",
+        ":type_helpers",
         "@com_google_absl//absl/numeric:int128",
         "@com_google_absl//absl/types:span",
     ],
diff --git a/mpact/sim/generic/instruction.h b/mpact/sim/generic/instruction.h
index af4e7eb..8a553e7 100644
--- a/mpact/sim/generic/instruction.h
+++ b/mpact/sim/generic/instruction.h
@@ -26,6 +26,7 @@
 #include "mpact/sim/generic/arch_state.h"
 #include "mpact/sim/generic/operand_interface.h"
 #include "mpact/sim/generic/ref_count.h"
+#include "mpact/sim/generic/type_helpers.h"
 
 namespace mpact {
 namespace sim {
@@ -267,6 +268,11 @@
   return inst->Source(index)->AsInt16(0);
 }
 template <>
+inline HalfFP GetInstructionSource<HalfFP>(const Instruction *inst, int index) {
+  auto value = inst->Source(index)->AsUint16(0);
+  return HalfFP{.value = value};
+}
+template <>
 inline uint32_t GetInstructionSource<uint32_t>(const Instruction *inst,
                                                int index) {
   return inst->Source(index)->AsUint32(0);
diff --git a/mpact/sim/generic/type_helpers.h b/mpact/sim/generic/type_helpers.h
index c7a74b4..4816ed1 100644
--- a/mpact/sim/generic/type_helpers.h
+++ b/mpact/sim/generic/type_helpers.h
@@ -16,6 +16,7 @@
 #define MPACT_SIM_GENERIC_TYPE_HELPERS_H_
 
 #include <cstdint>
+#include <string>
 #include <type_traits>
 
 #include "absl/numeric/int128.h"
@@ -121,6 +122,11 @@
   using type = absl::uint128;
 };
 
+// Make a half floating point type since it isn't native to C++.
+struct HalfFP {
+  uint16_t value;
+};
+
 // Helper template for floating point type information. This allows the specific
 // information for each fp type to be easily extracted.
 template <typename T>
@@ -130,6 +136,43 @@
 };
 
 template <>
+struct FPTypeInfo<HalfFP> {
+  using T = HalfFP;
+  using UIntType = uint16_t;
+  using IntType = std::make_signed<UIntType>::type;
+  static constexpr int kBitSize = sizeof(HalfFP) << 3;
+  static constexpr int kExpSize = 5;
+  static constexpr int kExpBias = 15;
+  static constexpr int kSigSize = kBitSize - kExpSize - /*sign*/ 1;
+  static constexpr UIntType kInfMask = (1ULL << (kBitSize - 1)) - 1;
+  static constexpr UIntType kExpMask = ((1ULL << kExpSize) - 1) << kSigSize;
+  static constexpr UIntType kSigMask = (1ULL << kSigSize) - 1;
+  static constexpr UIntType kCanonicalNaN = 0b0'11111'1ULL << (kSigSize - 1);
+  static constexpr UIntType kPosInf = kExpMask;
+  static constexpr UIntType kNegInf = kExpMask | (1ULL << (kBitSize - 1));
+  static inline bool IsInf(T value) {
+    UIntType uint_val = *reinterpret_cast<UIntType *>(&value);
+    return (uint_val & kInfMask) == kPosInf;
+  }
+  static inline bool IsNaN(T value) {
+    UIntType uint_val = *reinterpret_cast<UIntType *>(&value);
+    return ((uint_val & kExpMask) == kExpMask) && ((uint_val & kSigMask) != 0);
+  }
+  static inline bool IsSNaN(T value) {
+    UIntType uint_val = *reinterpret_cast<UIntType *>(&value);
+    return IsNaN(value) && (((1 << (kSigSize - 1)) & uint_val) == 0);
+  }
+  static inline bool IsQNaN(T value) {
+    UIntType uint_val = *reinterpret_cast<UIntType *>(&value);
+    return IsNaN(value) && (((1 << (kSigSize - 1)) & uint_val) != 0);
+  }
+  static inline bool SignBit(T value) {
+    UIntType uint_val = *reinterpret_cast<UIntType *>(&value);
+    return 1 == (uint_val >> (kBitSize - 1));
+  }
+};
+
+template <>
 struct FPTypeInfo<float> {
   using T = float;
   using UIntType = uint32_t;
@@ -161,6 +204,10 @@
     UIntType uint_val = *reinterpret_cast<UIntType *>(&value);
     return IsNaN(value) && (((1 << (kSigSize - 1)) & uint_val) != 0);
   }
+  static inline bool SignBit(T value) {
+    UIntType uint_val = *reinterpret_cast<UIntType *>(&value);
+    return 1 == (uint_val >> (kBitSize - 1));
+  }
 };
 
 template <>
@@ -195,6 +242,66 @@
     UIntType uint_val = *reinterpret_cast<UIntType *>(&value);
     return IsNaN(value) && (((1ULL << (kSigSize - 1)) & uint_val) != 0);
   }
+  static inline bool SignBit(T value) {
+    UIntType uint_val = *reinterpret_cast<UIntType *>(&value);
+    return 1 == (uint_val >> (kBitSize - 1));
+  }
+};
+
+template <>
+struct FPTypeInfo<uint16_t> {
+  using T = uint16_t;
+  using UIntType = uint16_t;
+  using IntType = std::make_signed<UIntType>::type;
+  static constexpr int kBitSize = sizeof(HalfFP) << 3;
+  static constexpr int kExpSize = 5;
+  static constexpr int kExpBias = 15;
+  static constexpr int kSigSize = kBitSize - kExpSize - /*sign*/ 1;
+  static constexpr UIntType kInfMask = (1ULL << (kBitSize - 1)) - 1;
+  static constexpr UIntType kExpMask = ((1ULL << kExpSize) - 1) << kSigSize;
+  static constexpr UIntType kSigMask = (1ULL << kSigSize) - 1;
+  static constexpr UIntType kCanonicalNaN = 0b0'11111'1ULL << (kSigSize - 1);
+  static constexpr UIntType kPosInf = kExpMask;
+  static constexpr UIntType kNegInf = kExpMask | (1ULL << (kBitSize - 1));
+  static inline bool IsInf(T value) { return (value & kInfMask) == kPosInf; }
+  static inline bool IsNaN(T value) {
+    return ((value & kExpMask) == kExpMask) && ((value & kSigMask) != 0);
+  }
+  static inline bool IsSNaN(T value) {
+    return IsNaN(value) && (((1 << (kSigSize - 1)) & value) == 0);
+  }
+  static inline bool IsQNaN(T value) {
+    return IsNaN(value) && (((1 << (kSigSize - 1)) & value) != 0);
+  }
+  static inline bool SignBit(T value) { return 1 == (value >> (kBitSize - 1)); }
+};
+
+template <>
+struct FPTypeInfo<int16_t> {
+  using T = int16_t;
+  using UIntType = uint16_t;
+  using IntType = std::make_signed<UIntType>::type;
+  static constexpr int kBitSize = sizeof(HalfFP) << 3;
+  static constexpr int kExpSize = 5;
+  static constexpr int kExpBias = 15;
+  static constexpr int kSigSize = kBitSize - kExpSize - /*sign*/ 1;
+  static constexpr UIntType kInfMask = (1ULL << (kBitSize - 1)) - 1;
+  static constexpr UIntType kExpMask = ((1ULL << kExpSize) - 1) << kSigSize;
+  static constexpr UIntType kSigMask = (1ULL << kSigSize) - 1;
+  static constexpr UIntType kCanonicalNaN = 0b0'11111'1ULL << (kSigSize - 1);
+  static constexpr UIntType kPosInf = kExpMask;
+  static constexpr UIntType kNegInf = kExpMask | (1ULL << (kBitSize - 1));
+  static inline bool IsInf(T value) { return (value & kInfMask) == kPosInf; }
+  static inline bool IsNaN(T value) {
+    return ((value & kExpMask) == kExpMask) && ((value & kSigMask) != 0);
+  }
+  static inline bool IsSNaN(T value) {
+    return IsNaN(value) && (((1 << (kSigSize - 1)) & value) == 0);
+  }
+  static inline bool IsQNaN(T value) {
+    return IsNaN(value) && (((1 << (kSigSize - 1)) & value) != 0);
+  }
+  static inline bool SignBit(T value) { return 1 == (value >> (kBitSize - 1)); }
 };
 
 template <>
@@ -223,6 +330,7 @@
   static inline bool IsQNaN(T value) {
     return IsNaN(value) && (((1ULL << (kSigSize - 1)) & value) != 0);
   }
+  static inline bool SignBit(T value) { return 1 == (value >> (kBitSize - 1)); }
 };
 
 template <>
@@ -251,6 +359,7 @@
   static inline bool IsQNaN(T value) {
     return IsNaN(value) && (((1ULL << (kSigSize - 1)) & value) != 0);
   }
+  static inline bool SignBit(T value) { return 1 == (value >> (kBitSize - 1)); }
 };
 
 template <>
@@ -279,6 +388,7 @@
   static inline bool IsQNaN(T value) {
     return IsNaN(value) && (((1ULL << (kSigSize - 1)) & value) != 0);
   }
+  static inline bool SignBit(T value) { return 1 == (value >> (kBitSize - 1)); }
 };
 
 template <>
@@ -307,8 +417,83 @@
   static inline bool IsQNaN(T value) {
     return IsNaN(value) && (((1ULL << (kSigSize - 1)) & value) != 0);
   }
+  static inline bool SignBit(T value) { return 1 == (value >> (kBitSize - 1)); }
 };
 
+inline float ConvertHalfToSingle(HalfFP half) {
+  uint32_t float_uint;
+  bool sign = half.value >> (FPTypeInfo<HalfFP>::kBitSize - 1);
+
+  if (half.value == FPTypeInfo<HalfFP>::kPosInf) {
+    float_uint = FPTypeInfo<float>::kPosInf;
+    return *reinterpret_cast<float *>(&float_uint);
+  }
+
+  if (half.value == FPTypeInfo<HalfFP>::kNegInf) {
+    float_uint = FPTypeInfo<float>::kNegInf;
+    return *reinterpret_cast<float *>(&float_uint);
+  }
+
+  if (FPTypeInfo<HalfFP>::IsNaN(half)) {
+    float_uint = FPTypeInfo<float>::kCanonicalNaN;
+    float_uint |= sign << (FPTypeInfo<float>::kBitSize - 1);
+    return *reinterpret_cast<float *>(&float_uint);
+  }
+
+  if (half.value == 0) {
+    float_uint = 0;
+    return *reinterpret_cast<float *>(&float_uint);
+  }
+
+  if (half.value == 1 << (FPTypeInfo<HalfFP>::kBitSize - 1)) {
+    float_uint = 1 << (FPTypeInfo<float>::kBitSize - 1);
+    return *reinterpret_cast<float *>(&float_uint);
+  }
+
+  uint32_t exp = (half.value & FPTypeInfo<HalfFP>::kExpMask) >>
+                 (FPTypeInfo<HalfFP>::kSigSize);
+  uint32_t sig = half.value & FPTypeInfo<HalfFP>::kSigMask;
+  if (exp == 0 && sig != 0) {
+    // Subnormal value.
+    int32_t shift_count = 0;
+    while ((sig & (1 << FPTypeInfo<HalfFP>::kSigSize)) == 0) {
+      sig <<= 1;
+      shift_count++;
+    }
+    sig &= FPTypeInfo<HalfFP>::kSigMask;
+    exp = 1 - shift_count;
+  }
+  exp += FPTypeInfo<float>::kExpBias - FPTypeInfo<HalfFP>::kExpBias;
+  sig <<= FPTypeInfo<float>::kSigSize - FPTypeInfo<HalfFP>::kSigSize;
+  float_uint = (exp << FPTypeInfo<float>::kSigSize) | sig;
+  float_uint |= sign << (FPTypeInfo<float>::kBitSize - 1);
+  return *reinterpret_cast<float *>(&float_uint);
+}
+
+// A replacement for std::is_floating_point that works for half precision.
+template <typename T>
+struct IsMpactFp {
+  static constexpr bool value = std::is_floating_point<T>::value;
+};
+
+template <>
+struct IsMpactFp<HalfFP> {
+  static constexpr bool value = true;
+};
+
+// A helper to print the contents of a floating point that also works for half
+// precision.
+template <typename T>
+inline std::string FloatingPointToString(T floating_point) {
+  return std::to_string(floating_point);
+}
+
+template <>
+inline std::string FloatingPointToString(HalfFP floating_point) {
+  // Convert to float and then convert to string.
+  return FloatingPointToString<float>(ConvertHalfToSingle(floating_point));
+}
+
 // This templated helper function defines the dereference '*' operand for
 // enumeration class types and uses it to cast the enum class member to the
 // underlying type. That means for an enum class E and member e, this