Internal changes to support a half precision datatype. PiperOrigin-RevId: 752410882 Change-Id: Ic2e446fe63771c06fbc772a77838c6e4165f6849
diff --git a/mpact/sim/generic/BUILD b/mpact/sim/generic/BUILD index 5e6e266..b9ac8e6 100644 --- a/mpact/sim/generic/BUILD +++ b/mpact/sim/generic/BUILD
@@ -147,6 +147,7 @@ deps = [ ":arch_state", ":core", + ":type_helpers", "@com_google_absl//absl/numeric:int128", "@com_google_absl//absl/types:span", ],
diff --git a/mpact/sim/generic/instruction.h b/mpact/sim/generic/instruction.h index af4e7eb..8a553e7 100644 --- a/mpact/sim/generic/instruction.h +++ b/mpact/sim/generic/instruction.h
@@ -26,6 +26,7 @@ #include "mpact/sim/generic/arch_state.h" #include "mpact/sim/generic/operand_interface.h" #include "mpact/sim/generic/ref_count.h" +#include "mpact/sim/generic/type_helpers.h" namespace mpact { namespace sim { @@ -267,6 +268,11 @@ return inst->Source(index)->AsInt16(0); } template <> +inline HalfFP GetInstructionSource<HalfFP>(const Instruction *inst, int index) { + auto value = inst->Source(index)->AsUint16(0); + return HalfFP{.value = value}; +} +template <> inline uint32_t GetInstructionSource<uint32_t>(const Instruction *inst, int index) { return inst->Source(index)->AsUint32(0);
diff --git a/mpact/sim/generic/type_helpers.h b/mpact/sim/generic/type_helpers.h index c7a74b4..4816ed1 100644 --- a/mpact/sim/generic/type_helpers.h +++ b/mpact/sim/generic/type_helpers.h
@@ -16,6 +16,7 @@ #define MPACT_SIM_GENERIC_TYPE_HELPERS_H_ #include <cstdint> +#include <string> #include <type_traits> #include "absl/numeric/int128.h" @@ -121,6 +122,11 @@ using type = absl::uint128; }; +// Make a half floating point type since it isn't native to C++. +struct HalfFP { + uint16_t value; +}; + // Helper template for floating point type information. This allows the specific // information for each fp type to be easily extracted. template <typename T> @@ -130,6 +136,43 @@ }; template <> +struct FPTypeInfo<HalfFP> { + using T = HalfFP; + using UIntType = uint16_t; + using IntType = std::make_signed<UIntType>::type; + static constexpr int kBitSize = sizeof(HalfFP) << 3; + static constexpr int kExpSize = 5; + static constexpr int kExpBias = 15; + static constexpr int kSigSize = kBitSize - kExpSize - /*sign*/ 1; + static constexpr UIntType kInfMask = (1ULL << (kBitSize - 1)) - 1; + static constexpr UIntType kExpMask = ((1ULL << kExpSize) - 1) << kSigSize; + static constexpr UIntType kSigMask = (1ULL << kSigSize) - 1; + static constexpr UIntType kCanonicalNaN = 0b0'11111'1ULL << (kSigSize - 1); + static constexpr UIntType kPosInf = kExpMask; + static constexpr UIntType kNegInf = kExpMask | (1ULL << (kBitSize - 1)); + static inline bool IsInf(T value) { + UIntType uint_val = *reinterpret_cast<UIntType *>(&value); + return (uint_val & kInfMask) == kPosInf; + } + static inline bool IsNaN(T value) { + UIntType uint_val = *reinterpret_cast<UIntType *>(&value); + return ((uint_val & kExpMask) == kExpMask) && ((uint_val & kSigMask) != 0); + } + static inline bool IsSNaN(T value) { + UIntType uint_val = *reinterpret_cast<UIntType *>(&value); + return IsNaN(value) && (((1 << (kSigSize - 1)) & uint_val) == 0); + } + static inline bool IsQNaN(T value) { + UIntType uint_val = *reinterpret_cast<UIntType *>(&value); + return IsNaN(value) && (((1 << (kSigSize - 1)) & uint_val) != 0); + } + static inline bool SignBit(T value) { + UIntType uint_val = *reinterpret_cast<UIntType *>(&value); + return 1 == (uint_val >> (kBitSize - 1)); + } +}; + +template <> struct FPTypeInfo<float> { using T = float; using UIntType = uint32_t; @@ -161,6 +204,10 @@ UIntType uint_val = *reinterpret_cast<UIntType *>(&value); return IsNaN(value) && (((1 << (kSigSize - 1)) & uint_val) != 0); } + static inline bool SignBit(T value) { + UIntType uint_val = *reinterpret_cast<UIntType *>(&value); + return 1 == (uint_val >> (kBitSize - 1)); + } }; template <> @@ -195,6 +242,66 @@ UIntType uint_val = *reinterpret_cast<UIntType *>(&value); return IsNaN(value) && (((1ULL << (kSigSize - 1)) & uint_val) != 0); } + static inline bool SignBit(T value) { + UIntType uint_val = *reinterpret_cast<UIntType *>(&value); + return 1 == (uint_val >> (kBitSize - 1)); + } +}; + +template <> +struct FPTypeInfo<uint16_t> { + using T = uint16_t; + using UIntType = uint16_t; + using IntType = std::make_signed<UIntType>::type; + static constexpr int kBitSize = sizeof(HalfFP) << 3; + static constexpr int kExpSize = 5; + static constexpr int kExpBias = 15; + static constexpr int kSigSize = kBitSize - kExpSize - /*sign*/ 1; + static constexpr UIntType kInfMask = (1ULL << (kBitSize - 1)) - 1; + static constexpr UIntType kExpMask = ((1ULL << kExpSize) - 1) << kSigSize; + static constexpr UIntType kSigMask = (1ULL << kSigSize) - 1; + static constexpr UIntType kCanonicalNaN = 0b0'11111'1ULL << (kSigSize - 1); + static constexpr UIntType kPosInf = kExpMask; + static constexpr UIntType kNegInf = kExpMask | (1ULL << (kBitSize - 1)); + static inline bool IsInf(T value) { return (value & kInfMask) == kPosInf; } + static inline bool IsNaN(T value) { + return ((value & kExpMask) == kExpMask) && ((value & kSigMask) != 0); + } + static inline bool IsSNaN(T value) { + return IsNaN(value) && (((1 << (kSigSize - 1)) & value) == 0); + } + static inline bool IsQNaN(T value) { + return IsNaN(value) && (((1 << (kSigSize - 1)) & value) != 0); + } + static inline bool SignBit(T value) { return 1 == (value >> (kBitSize - 1)); } +}; + +template <> +struct FPTypeInfo<int16_t> { + using T = int16_t; + using UIntType = uint16_t; + using IntType = std::make_signed<UIntType>::type; + static constexpr int kBitSize = sizeof(HalfFP) << 3; + static constexpr int kExpSize = 5; + static constexpr int kExpBias = 15; + static constexpr int kSigSize = kBitSize - kExpSize - /*sign*/ 1; + static constexpr UIntType kInfMask = (1ULL << (kBitSize - 1)) - 1; + static constexpr UIntType kExpMask = ((1ULL << kExpSize) - 1) << kSigSize; + static constexpr UIntType kSigMask = (1ULL << kSigSize) - 1; + static constexpr UIntType kCanonicalNaN = 0b0'11111'1ULL << (kSigSize - 1); + static constexpr UIntType kPosInf = kExpMask; + static constexpr UIntType kNegInf = kExpMask | (1ULL << (kBitSize - 1)); + static inline bool IsInf(T value) { return (value & kInfMask) == kPosInf; } + static inline bool IsNaN(T value) { + return ((value & kExpMask) == kExpMask) && ((value & kSigMask) != 0); + } + static inline bool IsSNaN(T value) { + return IsNaN(value) && (((1 << (kSigSize - 1)) & value) == 0); + } + static inline bool IsQNaN(T value) { + return IsNaN(value) && (((1 << (kSigSize - 1)) & value) != 0); + } + static inline bool SignBit(T value) { return 1 == (value >> (kBitSize - 1)); } }; template <> @@ -223,6 +330,7 @@ static inline bool IsQNaN(T value) { return IsNaN(value) && (((1ULL << (kSigSize - 1)) & value) != 0); } + static inline bool SignBit(T value) { return 1 == (value >> (kBitSize - 1)); } }; template <> @@ -251,6 +359,7 @@ static inline bool IsQNaN(T value) { return IsNaN(value) && (((1ULL << (kSigSize - 1)) & value) != 0); } + static inline bool SignBit(T value) { return 1 == (value >> (kBitSize - 1)); } }; template <> @@ -279,6 +388,7 @@ static inline bool IsQNaN(T value) { return IsNaN(value) && (((1ULL << (kSigSize - 1)) & value) != 0); } + static inline bool SignBit(T value) { return 1 == (value >> (kBitSize - 1)); } }; template <> @@ -307,8 +417,83 @@ static inline bool IsQNaN(T value) { return IsNaN(value) && (((1ULL << (kSigSize - 1)) & value) != 0); } + static inline bool SignBit(T value) { return 1 == (value >> (kBitSize - 1)); } }; +inline float ConvertHalfToSingle(HalfFP half) { + uint32_t float_uint; + bool sign = half.value >> (FPTypeInfo<HalfFP>::kBitSize - 1); + + if (half.value == FPTypeInfo<HalfFP>::kPosInf) { + float_uint = FPTypeInfo<float>::kPosInf; + return *reinterpret_cast<float *>(&float_uint); + } + + if (half.value == FPTypeInfo<HalfFP>::kNegInf) { + float_uint = FPTypeInfo<float>::kNegInf; + return *reinterpret_cast<float *>(&float_uint); + } + + if (FPTypeInfo<HalfFP>::IsNaN(half)) { + float_uint = FPTypeInfo<float>::kCanonicalNaN; + float_uint |= sign << (FPTypeInfo<float>::kBitSize - 1); + return *reinterpret_cast<float *>(&float_uint); + } + + if (half.value == 0) { + float_uint = 0; + return *reinterpret_cast<float *>(&float_uint); + } + + if (half.value == 1 << (FPTypeInfo<HalfFP>::kBitSize - 1)) { + float_uint = 1 << (FPTypeInfo<float>::kBitSize - 1); + return *reinterpret_cast<float *>(&float_uint); + } + + uint32_t exp = (half.value & FPTypeInfo<HalfFP>::kExpMask) >> + (FPTypeInfo<HalfFP>::kSigSize); + uint32_t sig = half.value & FPTypeInfo<HalfFP>::kSigMask; + if (exp == 0 && sig != 0) { + // Subnormal value. + int32_t shift_count = 0; + while ((sig & (1 << FPTypeInfo<HalfFP>::kSigSize)) == 0) { + sig <<= 1; + shift_count++; + } + sig &= FPTypeInfo<HalfFP>::kSigMask; + exp = 1 - shift_count; + } + exp += FPTypeInfo<float>::kExpBias - FPTypeInfo<HalfFP>::kExpBias; + sig <<= FPTypeInfo<float>::kSigSize - FPTypeInfo<HalfFP>::kSigSize; + float_uint = (exp << FPTypeInfo<float>::kSigSize) | sig; + float_uint |= sign << (FPTypeInfo<float>::kBitSize - 1); + return *reinterpret_cast<float *>(&float_uint); +} + +// A replacement for std::is_floating_point that works for half precision. +template <typename T> +struct IsMpactFp { + static constexpr bool value = std::is_floating_point<T>::value; +}; + +template <> +struct IsMpactFp<HalfFP> { + static constexpr bool value = true; +}; + +// A helper to print the contents of a floating point that also works for half +// precision. +template <typename T> +inline std::string FloatingPointToString(T floating_point) { + return std::to_string(floating_point); +} + +template <> +inline std::string FloatingPointToString(HalfFP floating_point) { + // Convert to float and then convert to string. + return FloatingPointToString<float>(ConvertHalfToSingle(floating_point)); +} + // This templated helper function defines the dereference '*' operand for // enumeration class types and uses it to cast the enum class member to the // underlying type. That means for an enum class E and member e, this