Internal changes to support a half precision datatype. PiperOrigin-RevId: 752410882 Change-Id: Ia4a56b14b02db8c6d05ddc3e5831601f7b3b3749
diff --git a/riscv/riscv_instruction_helpers.h b/riscv/riscv_instruction_helpers.h index 09a4e03..f408d13 100644 --- a/riscv/riscv_instruction_helpers.h +++ b/riscv/riscv_instruction_helpers.h
@@ -483,7 +483,8 @@ ScopedFPStatus set_fp_status(rv_fp->host_fp_interface(), rm_value); dest_value = operation(lhs); } - if (std::isnan(dest_value) && std::signbit(dest_value)) { + if (FPTypeInfo<Result>::IsNaN(dest_value) && + FPTypeInfo<Result>::SignBit(dest_value)) { ResUint res_value = *reinterpret_cast<ResUint *>(&dest_value); res_value &= FPTypeInfo<Result>::kInfMask; dest_value = *reinterpret_cast<Result *>(&res_value); @@ -601,7 +602,7 @@ ScopedFPStatus fp_status(rv_fp->host_fp_interface(), rm_value); dest_value = operation(lhs, rhs); } - if (std::isnan(dest_value)) { + if (FPTypeInfo<Result>::IsNaN(dest_value)) { *reinterpret_cast<typename FPTypeInfo<Result>::UIntType *>(&dest_value) = FPTypeInfo<Result>::kCanonicalNaN; }
diff --git a/riscv/test/BUILD b/riscv/test/BUILD index 6aaf46a..4332d87 100644 --- a/riscv/test/BUILD +++ b/riscv/test/BUILD
@@ -62,6 +62,7 @@ "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", "@com_google_mpact-sim//mpact/sim/generic:instruction", + "@com_google_mpact-sim//mpact/sim/generic:type_helpers", "@com_google_mpact-sim//mpact/sim/util/memory", ], )
diff --git a/riscv/test/riscv_fp_test_base.h b/riscv/test/riscv_fp_test_base.h index fb115c5..81c8b07 100644 --- a/riscv/test/riscv_fp_test_base.h +++ b/riscv/test/riscv_fp_test_base.h
@@ -32,6 +32,7 @@ #include "absl/types/span.h" #include "googlemock/include/gmock/gmock.h" #include "mpact/sim/generic/instruction.h" +#include "mpact/sim/generic/type_helpers.h" #include "mpact/sim/util/memory/flat_demand_memory.h" #include "riscv/riscv_fp_host.h" #include "riscv/riscv_fp_info.h" @@ -44,6 +45,11 @@ namespace riscv { namespace test { +using ::mpact::sim::generic::ConvertHalfToSingle; +using ::mpact::sim::generic::FloatingPointToString; +using ::mpact::sim::generic::HalfFP; +using ::mpact::sim::generic::IsMpactFp; + using ::mpact::sim::generic::Instruction; using ::mpact::sim::riscv::FPRoundingMode; using ::mpact::sim::util::FlatDemandMemory; @@ -97,6 +103,7 @@ IntType uint_val = *reinterpret_cast<IntType *>(&value); return IsNaN(value) && (((1ULL << (kSigSize - 1)) & uint_val) != 0); } + static bool IsInf(T value) { return std::isinf(value); } }; template <> @@ -124,6 +131,46 @@ IntType uint_val = *reinterpret_cast<IntType *>(&value); return IsNaN(value) && (((1ULL << (kSigSize - 1)) & uint_val) != 0); } + static bool IsInf(T value) { return std::isinf(value); } +}; + +template <> +struct FPTypeInfo<HalfFP> { + using T = HalfFP; + using IntType = uint16_t; + static const int kExpBias = 15; + static const int kBitSize = sizeof(HalfFP) << 3; + static const int kExpSize = 5; + static const int kSigSize = kBitSize - kExpSize - 1; // 10 from the spec. + static const IntType kExpMask = ((1ULL << kExpSize) - 1) << kSigSize; + static const IntType kSigMask = (1ULL << kSigSize) - 1; + static const IntType kQNaN = kExpMask | (1ULL << (kSigSize - 1)) | 1; + static const IntType kSNaN = kExpMask | 1; + static const IntType kPosInf = kExpMask; + static const IntType kNegInf = kExpMask | (1ULL << (kBitSize - 1)); + static const IntType kPosZero = 0; + static const IntType kNegZero = 1ULL << (kBitSize - 1); + static const IntType kPosDenorm = 1ULL << (kSigSize - 2); + static const IntType kNegDenorm = + (1ULL << (kBitSize - 1)) | (1ULL << (kSigSize - 2)); + static const IntType kCanonicalNaN = 0x7e00; + // std::isnan won't work for half precision. + static bool IsNaN(T wrapper) { + IntType exp = (wrapper.value & kExpMask) >> kSigSize; + IntType sig = wrapper.value & kSigMask; + return (exp == (1 << kExpSize) - 1) && (sig != 0); + } + static bool IsQNaN(T value) { + IntType uint_val = *reinterpret_cast<IntType *>(&value); + IntType significand_msb = (uint_val >> (kSigSize - 1)) & 1; + return IsNaN(value) && (significand_msb != 0); + } + // std::isinf won't work for half precision. + static bool IsInf(T wrapper) { + IntType exp = (wrapper.value & kExpMask) >> kSigSize; + IntType sig = wrapper.value & kSigMask; + return (exp == (1 << kExpSize) - 1) && (sig == 0); + } }; // Templated helper function for classifying fp numbers. @@ -220,6 +267,14 @@ } } +template <> +inline void FPCompare<HalfFP>(HalfFP op, HalfFP reg, int delta_position, + absl::string_view str) { + float op_float = ConvertHalfToSingle(op); + float reg_float = ConvertHalfToSingle(reg); + FPCompare<float>(op_float, reg_float, delta_position, str); +} + template <typename FP> FP OptimizationBarrier(FP op) { asm volatile("" : "+X"(op)); @@ -232,22 +287,19 @@ // part of the enable_if construct. template <typename S, typename D> struct EqualSize { - static const bool value = sizeof(S) == sizeof(D) && - std::is_floating_point<S>::value && + static const bool value = sizeof(S) == sizeof(D) && IsMpactFp<S>::value && std::is_integral<D>::value; }; template <typename S, typename D> struct GreaterSize { static const bool value = - sizeof(S) > sizeof(D) && - std::is_floating_point<S>::value &&std::is_integral<D>::value; + sizeof(S) > sizeof(D) && IsMpactFp<S>::value &&std::is_integral<D>::value; }; template <typename S, typename D> struct LessSize { - static const bool value = sizeof(S) < sizeof(D) && - std::is_floating_point<S>::value && + static const bool value = sizeof(S) < sizeof(D) && IsMpactFp<S>::value && std::is_integral<D>::value; }; @@ -461,7 +513,12 @@ *reinterpret_cast<LhsInt *>(&lhs_span[6]) = FPTypeInfo<LHS>::kPosDenorm; *reinterpret_cast<LhsInt *>(&lhs_span[7]) = FPTypeInfo<LHS>::kNegDenorm; for (int i = 0; i < kTestValueLength; i++) { - SetRegisterValues<LHS, LhsRegisterType>({{kR1Name, lhs_span[i]}}); + if constexpr (std::is_integral<LHS>::value) { + SetRegisterValues<LHS, LhsRegisterType>({{kR1Name, lhs_span[i]}}); + } else { + SetNaNBoxedRegisterValues<LHS, LhsRegisterType>( + {{kR1Name, lhs_span[i]}}); + } for (int rm : {0, 1, 2, 3, 4}) { rv_fp_->SetRoundingMode(static_cast<FPRoundingMode>(rm)); @@ -479,7 +536,8 @@ .first->data_buffer() ->template Get<R>(0); FPCompare<R>(op_val, reg_val, delta_position, - absl::StrCat(name, " ", i, ": ", lhs_span[i])); + absl::StrCat(name, " ", i, ": ", + FloatingPointToString<LHS>(lhs_span[i]))); } } } @@ -523,7 +581,12 @@ *reinterpret_cast<LhsInt *>(&lhs_span[6]) = FPTypeInfo<LHS>::kPosDenorm; *reinterpret_cast<LhsInt *>(&lhs_span[7]) = FPTypeInfo<LHS>::kNegDenorm; for (int i = 0; i < kTestValueLength; i++) { - SetNaNBoxedRegisterValues<LHS, LhsRegisterType>({{kR1Name, lhs_span[i]}}); + if constexpr (std::is_integral<LHS>::value) { + SetRegisterValues<LHS, LhsRegisterType>({{kR1Name, lhs_span[i]}}); + } else { + SetNaNBoxedRegisterValues<LHS, LhsRegisterType>( + {{kR1Name, lhs_span[i]}}); + } for (int rm : {0, 1, 2, 3, 4}) { rv_fp_->SetRoundingMode(static_cast<FPRoundingMode>(rm)); @@ -532,13 +595,13 @@ SetRegisterValues<R, DestRegisterType>({{kRdName, 0}}); inst->Execute(nullptr); - auto fflags = rv_fp_->fflags()->GetUint32(); + auto instruction_fflags = rv_fp_->fflags()->GetUint32(); R op_val; - uint32_t flag; + uint32_t test_operation_fflags; { ScopedFPRoundingMode scoped_rm(rv_fp_->host_fp_interface(), rm); - std::tie(op_val, flag) = operation(lhs_span[i], rm); + std::tie(op_val, test_operation_fflags) = operation(lhs_span[i], rm); } auto reg_val = state_->GetRegister<DestRegisterType>(kRdName) @@ -546,12 +609,14 @@ ->template Get<R>(0); FPCompare<R>( op_val, reg_val, delta_position, - absl::StrCat(name, " ", i, ": ", lhs_span[i], " rm: ", rm)); + absl::StrCat(name, " ", i, ": ", + FloatingPointToString<LHS>(lhs_span[i]), " rm: ", rm)); auto lhs_uint = *reinterpret_cast<LhsInt *>(&lhs_span[i]); auto op_val_uint = *reinterpret_cast<RInt *>(&op_val); - EXPECT_EQ(flag, fflags) - << name << "(" << lhs_span[i] << ") " << std::hex << name << "(0x" - << lhs_uint << ") == " << op_val << std::hex << " 0x" + EXPECT_EQ(test_operation_fflags, instruction_fflags) + << name << "(" << FloatingPointToString<LHS>(lhs_span[i]) << ") " + << std::hex << name << "(0x" << lhs_uint + << ") == " << FloatingPointToString<R>(op_val) << std::hex << " 0x" << op_val_uint << " rm: " << rm; } } @@ -623,9 +688,11 @@ auto reg_val = state_->GetRegister<DestRegisterType>(kRdName) .first->data_buffer() ->template Get<R>(0); - FPCompare<R>(op_val, reg_val, delta_position, - absl::StrCat(name, " ", i, ": ", lhs_span[i], " ", - rhs_span[i], " rm: ", rm)); + FPCompare<R>( + op_val, reg_val, delta_position, + absl::StrCat(name, " ", i, ": ", + FloatingPointToString<LHS>(lhs_span[i]), " ", + FloatingPointToString<RHS>(rhs_span[i]), " rm: ", rm)); } if (HasFailure()) return; } @@ -693,23 +760,25 @@ inst->Execute(nullptr); - auto fflags = rv_fp_->fflags()->GetUint32(); + auto instruction_fflags = rv_fp_->fflags()->GetUint32(); R op_val; - uint32_t flag; + uint32_t test_operation_fflags; { ScopedFPStatus set_fpstatus(rv_fp_->host_fp_interface()); - std::tie(op_val, flag) = operation(lhs_span[i], rhs_span[i]); + std::tie(op_val, test_operation_fflags) = + operation(lhs_span[i], rhs_span[i]); } auto reg_val = state_->GetRegister<DestRegisterType>(kRdName) .first->data_buffer() ->template Get<R>(0); - FPCompare<R>( - op_val, reg_val, delta_position, - absl::StrCat(name, " ", i, ": ", lhs_span[i], " ", rhs_span[i])); + FPCompare<R>(op_val, reg_val, delta_position, + absl::StrCat(name, " ", i, ": ", + FloatingPointToString<LHS>(lhs_span[i]), " ", + FloatingPointToString<RHS>(rhs_span[i]))); auto lhs_uint = *reinterpret_cast<LhsUInt *>(&lhs_span[i]); auto rhs_uint = *reinterpret_cast<RhsUInt *>(&rhs_span[i]); - EXPECT_EQ(flag, fflags) + EXPECT_EQ(test_operation_fflags, instruction_fflags) << std::hex << name << "(" << lhs_uint << ", " << rhs_uint << ")"; } } @@ -790,8 +859,10 @@ .first->data_buffer() ->template Get<R>(0); FPCompare<R>(op_val, reg_val, delta_position, - absl::StrCat(name, " ", i, ": ", lhs_span[i], " ", - mhs_span[i], " ", rhs_span[i])); + absl::StrCat(name, " ", i, ": ", + FloatingPointToString<LHS>(lhs_span[i]), " ", + FloatingPointToString<MHS>(mhs_span[i]), " ", + FloatingPointToString<RHS>(rhs_span[i]))); } } } @@ -865,7 +936,7 @@ inst->Execute(nullptr); // Get the fflags for the instruction execution. - auto fflags = rv_fp_->fflags()->GetUint32(); + auto instruction_fflags = rv_fp_->fflags()->GetUint32(); rv_fp_->fflags()->Write(static_cast<uint32_t>(0)); R op_val; { @@ -876,13 +947,16 @@ .first->data_buffer() ->template Get<R>(0); FPCompare<R>(op_val, reg_val, delta_position, - absl::StrCat(name, " ", i, ": ", lhs_span[i], " ", - mhs_span[i], " ", rhs_span[i])); + absl::StrCat(name, " ", i, ": ", + FloatingPointToString<LHS>(lhs_span[i]), " ", + FloatingPointToString<MHS>(mhs_span[i]), " ", + FloatingPointToString<RHS>(rhs_span[i]))); - auto flag = rv_fp_->fflags()->GetUint32(); - EXPECT_EQ(flag, fflags) - << absl::StrCat(name, " ", i, ": ", lhs_span[i], " ", mhs_span[i], - " ", rhs_span[i]); + auto test_operation_fflags = rv_fp_->fflags()->GetUint32(); + EXPECT_EQ(test_operation_fflags, instruction_fflags) << absl::StrCat( + name, " ", i, ": ", FloatingPointToString<LHS>(lhs_span[i]), " ", + FloatingPointToString<MHS>(mhs_span[i]), " ", + FloatingPointToString<RHS>(rhs_span[i])); } } }