zfh simplification * Remove platform-dependent conversions to half precision * Resolve double-to-half conversion bug where a cast followed by a round resulted in an incorrect value * Fix incorrect fflags expectations for inf / 0 * fmin / fmax don't use the rounding mode operand PiperOrigin-RevId: 756900912 Change-Id: Ie21c09b372d587de7ae2d18abe9beebb6470af68
diff --git a/riscv/BUILD b/riscv/BUILD index cfc03dc..962213f 100644 --- a/riscv/BUILD +++ b/riscv/BUILD
@@ -329,35 +329,15 @@ cc_library( name = "riscv_zfh_instructions", - srcs = select({ - "//third_party/bazel_platforms/cpu:aarch64": [ - "riscv_zfh_instructions.cc", - "riscv_zfh_instructions_arm.cc", - ], - "//conditions:default": [ - "riscv_zfh_instructions.cc", - "riscv_zfh_instructions_x86.cc", - ], - }), + srcs = ["riscv_zfh_instructions.cc"], hdrs = [ "riscv_instruction_helpers.h", "riscv_zfh_instructions.h", ], - copts = select({ - "//third_party/bazel_platforms/cpu:aarch64": [ - "-O3", - "-ffp-model=strict", - ], - "//buildenv/platforms/settings:macos_aarch64": [ - "-O3", - "-ffp-model=strict", - ], - "//conditions:default": [ - "-ffp-model=strict", - "-O3", - "-mf16c", - ], - }), + copts = [ + "-ffp-model=strict", + "-O3", + ], deps = [ ":riscv_fp_state", ":riscv_state",
diff --git a/riscv/riscv_zfh.isa b/riscv/riscv_zfh.isa index abd0fa8..c4514e7 100644 --- a/riscv/riscv_zfh.isa +++ b/riscv/riscv_zfh.isa
@@ -136,11 +136,11 @@ resources: {next_pc, frs1, frs2 : frd[0..]}, semfunc: "&RiscVZfhFdiv", disasm: "fdiv.h", "%frd, %frs1, %frs2"; - fmin_h{: frs1, frs2, rm : frd, fflags}, + fmin_h{: frs1, frs2 : frd, fflags}, resources: {next_pc, frs1, frs2 : frd[0..]}, semfunc: "&RiscVZfhFmin", disasm: "fmin.h", "%frd, %frs1, %frs2"; - fmax_h{: frs1, frs2, rm : frd, fflags}, + fmax_h{: frs1, frs2 : frd, fflags}, resources: {next_pc, frs1, frs2 : frd[0..]}, semfunc: "&RiscVZfhFmax", disasm: "fmax.h", "%frd, %frs1, %frs2";
diff --git a/riscv/riscv_zfh_instructions.cc b/riscv/riscv_zfh_instructions.cc index 928bf24..f4b8ad2 100644 --- a/riscv/riscv_zfh_instructions.cc +++ b/riscv/riscv_zfh_instructions.cc
@@ -85,6 +85,166 @@ using type = RV64Register::ValueType; }; +// This is a soft conversion from a float or double to a half precision value. +// It is not a direct conversion from the floating point format to the half +// format. Instead, it uses the floating point hardware to do the conversion. +// This is done to get the correct rounding behavior for free from the FPU. +template <typename T> +HalfFP ConvertToHalfFP(T input_value, FPRoundingMode rm, uint32_t &fflags) { + using UIntType = typename FPTypeInfo<T>::UIntType; + using IntType = typename FPTypeInfo<T>::IntType; + UIntType in_int = absl::bit_cast<UIntType>(input_value); + HalfFP half_fp = {.value = 0x0000}; + + // Extract the mantissa, exponent and sign. + UIntType mantissa = in_int & FPTypeInfo<T>::kSigMask; + UIntType exponent = + (in_int & FPTypeInfo<T>::kExpMask) >> FPTypeInfo<T>::kSigSize; + UIntType sign = in_int >> (FPTypeInfo<T>::kBitSize - 1); + + if (std::isnan(input_value)) { + half_fp.value = FPTypeInfo<HalfFP>::kCanonicalNaN; + if (FPTypeInfo<T>::IsSNaN(input_value)) { + fflags |= static_cast<UIntType>(FPExceptions::kInvalidOp); + } + return half_fp; + } + + if (std::isinf(input_value)) { + half_fp.value = FPTypeInfo<HalfFP>::kPosInf; + half_fp.value |= (sign & 1) << (FPTypeInfo<HalfFP>::kBitSize - 1); + return half_fp; + } + + if (in_int == 0 || in_int == 1ULL << (FPTypeInfo<T>::kBitSize - 1)) { + half_fp.value = + in_int >> (FPTypeInfo<T>::kBitSize - FPTypeInfo<HalfFP>::kBitSize); + return half_fp; + } + + IntType bias_diff = FPTypeInfo<T>::kExpBias - FPTypeInfo<HalfFP>::kExpBias; + IntType unbounded_half_exponent = static_cast<IntType>(exponent) - bias_diff; + IntType sig_size_diff = + FPTypeInfo<T>::kSigSize - FPTypeInfo<HalfFP>::kSigSize; + UIntType half_inf_exponent = ((1 << FPTypeInfo<HalfFP>::kExpSize) - 1); + UIntType source_type_inf_exponent = ((1 << FPTypeInfo<T>::kExpSize) - 1); + + // Create a temp float with the smallest normal exponent and input mantissa. 
+ T ftmp = absl::bit_cast<T>( + (sign << (FPTypeInfo<T>::kBitSize - 1)) | + (static_cast<UIntType>(1ULL) << FPTypeInfo<T>::kSigSize) | mantissa); + + // Create a divisor float that will be used for shifting the mantissa in a + // rounding aware way. The amount of shifting depends on if the result is + // subnormal or normal. + T fdiv = 0; + UIntType default_fdiv_exp = FPTypeInfo<T>::kExpBias + sig_size_diff; + UIntType fdiv_exp = default_fdiv_exp; + if (unbounded_half_exponent > 0) { + fdiv_exp = default_fdiv_exp; + } else if (unbounded_half_exponent < 0) { + // shift_count: emin - unbiased exponent + IntType shift_count = 1 - static_cast<int>(exponent) + bias_diff; + fdiv_exp = default_fdiv_exp + shift_count; + fdiv_exp = std::min(fdiv_exp, source_type_inf_exponent - 1); + } else { + fdiv_exp = default_fdiv_exp + 1; + } + fdiv = absl::bit_cast<T>(fdiv_exp << FPTypeInfo<T>::kSigSize); + + // Shift right by doing division. + T fres = ftmp / fdiv; + UIntType res = absl::bit_cast<UIntType>(fres); + + // Shift left by doing multiplication. + T fmultiply = absl::bit_cast<T>(default_fdiv_exp << FPTypeInfo<T>::kSigSize); + T fres2 = fres * fmultiply; + UIntType res2 = absl::bit_cast<UIntType>(fres2); + + // Update the exponent if rounding caused an increase. + IntType exp_diff = static_cast<IntType>((res2 >> FPTypeInfo<T>::kSigSize) & + source_type_inf_exponent) - + 1; + UIntType new_exponent = (exponent + exp_diff) & source_type_inf_exponent; + + UIntType half_exponent = 0; + if (unbounded_half_exponent > 0) { + half_exponent = new_exponent - bias_diff; + } else if (unbounded_half_exponent < 0) { + // Guaranteed subnormal. Nothing to do. + } else { + // This case could be normal or subnormal depending on the rounding result. 
+ half_exponent = (res2 >> FPTypeInfo<T>::kSigSize) & half_inf_exponent; + } + + UIntType half_mantissa = + (res2 >> sig_size_diff) & FPTypeInfo<HalfFP>::kSigMask; + if (unbounded_half_exponent < 0) { // Guaranteed Subnormal + half_mantissa = (res & (1 << FPTypeInfo<HalfFP>::kSigSize)) + ? ((res >> 1) & FPTypeInfo<HalfFP>::kSigMask) + : res & FPTypeInfo<HalfFP>::kSigMask; + } + + // Handle the rules for overflowing to infinity depending on the rounding + // mode. + if (half_exponent >= half_inf_exponent) { + fflags |= static_cast<uint32_t>(FPExceptions::kOverflow); + fflags |= static_cast<uint32_t>(FPExceptions::kInexact); + switch (rm) { + case FPRoundingMode::kRoundToNearest: + half_exponent = half_inf_exponent; + half_mantissa = 0; + break; + case FPRoundingMode::kRoundTowardsZero: + half_exponent = half_inf_exponent - 1; + half_mantissa = FPTypeInfo<HalfFP>::kSigMask; + break; + case FPRoundingMode::kRoundDown: + half_exponent = sign ? half_inf_exponent : half_inf_exponent - 1; + half_mantissa = sign ? 0 : FPTypeInfo<HalfFP>::kSigMask; + break; + case FPRoundingMode::kRoundUp: + half_exponent = sign ? half_inf_exponent - 1 : half_inf_exponent; + half_mantissa = sign ? FPTypeInfo<HalfFP>::kSigMask : 0; + break; + default: + half_exponent = half_inf_exponent; + half_mantissa = 0; + break; + } + } + + // Construct the half float. + half_fp.value = half_mantissa | + (half_exponent << FPTypeInfo<HalfFP>::kSigSize) | + (sign << (FPTypeInfo<HalfFP>::kBitSize - 1)); + + // Do an arithmetic reconstruction of the float to check for exactness. + T trailing_significand_float = static_cast<T>(half_mantissa); + T precision_factor = std::pow(2.0, -1.0 * FPTypeInfo<HalfFP>::kSigSize); + IntType unbiased_exponent = + (half_exponent == 0 ? 1 : half_exponent) - FPTypeInfo<HalfFP>::kExpBias; + T exponent_factor = std::pow(2.0, unbiased_exponent); + T sign_factor = sign == 1 ? -1.0 : 1.0; + T implicit_bit_adjustment = half_exponent == 0 ? 
0.0 : 1.0; + T reconstructed_value = ((trailing_significand_float * precision_factor) + + implicit_bit_adjustment) * + exponent_factor * sign_factor; + bool exact_conversion = reconstructed_value == input_value; + + // Handle flags for the specific underflow case. + if (!exact_conversion && (unbounded_half_exponent < 0 || + (unbounded_half_exponent == 0 && fres2 != ftmp))) { + fflags |= static_cast<uint32_t>(FPExceptions::kUnderflow); + } + + // Handle flags for the specific inexact case. + if (!exact_conversion && (fres2 != ftmp)) { + fflags |= static_cast<uint32_t>(FPExceptions::kInexact); + } + return half_fp; +} + // Convert from half precision to single or double precision. template <typename T> inline T ConvertFromHalfFP(HalfFP half_fp, uint32_t &fflags) { @@ -152,7 +312,7 @@ if constexpr (IsMpactFp<Argument>::value) { lhs = GetNaNBoxedSource<RVFpRegister::ValueType, Argument>(instruction, 0); if (FPTypeInfo<Argument>::IsSNaN(lhs)) { - fflags_dest->GetRiscVCsr()->SetBits(*FPExceptions::kInvalidOp); + fflags |= *FPExceptions::kInvalidOp; } } else { lhs = generic::GetInstructionSource<Argument>(instruction, 0); @@ -171,16 +331,11 @@ } Result dest_value; - if (zfh_internal::UseHostFlagsForConversion()) { - ScopedFPStatus set_fp_status(rv_fp->host_fp_interface(), rm_value); - dest_value = operation(lhs, static_cast<FPRoundingMode>(rm_value), fflags); - } else { + { ScopedFPRoundingMode scoped_rm(rv_fp->host_fp_interface(), rm_value); dest_value = operation(lhs, static_cast<FPRoundingMode>(rm_value), fflags); } - if (!zfh_internal::UseHostFlagsForConversion()) { - fflags_dest->GetRiscVCsr()->SetBits(fflags); - } + fflags_dest->GetRiscVCsr()->SetBits(fflags); auto *reg = static_cast<generic::RegisterDestinationOperand<DstRegValue> *>( instruction->Destination(0)) ->GetRegister(); @@ -200,67 +355,48 @@ reg->data_buffer()->template Set<Result>(0, dest_value); } -template <typename T> -inline HalfFP ConvertToHalfFP(T input_value, FPRoundingMode rm, - uint32_t 
&fflags); - -template <> -inline HalfFP ConvertToHalfFP(float input_value, FPRoundingMode rm, - uint32_t &fflags) { - return ConvertSingleToHalfFP(input_value, rm, fflags); -} - -template <> -inline HalfFP ConvertToHalfFP(double input_value, FPRoundingMode rm, - uint32_t &fflags) { - return ConvertDoubleToHalfFP(input_value, rm, fflags); -} - // Generic helper function enabling HalfFP operations in native datatypes. -template <typename Result, typename Argument> +template <typename Argument, typename IntermediateType> void RiscVZfhUnaryHelper( const Instruction *instruction, - std::function<Result(Argument, FPRoundingMode, uint32_t &)> operation) { - uint32_t fflags = 0; - RiscVFPState *rv_fp = - static_cast<RiscVState *>(instruction->state())->rv_fp(); - int rm_value = generic::GetInstructionSource<int>(instruction, 1); - - // If the rounding mode is dynamic, read it from the current state. - if (rm_value == *FPRoundingMode::kDynamic) { - if (!rv_fp->rounding_mode_valid()) { - LOG(ERROR) << "Invalid rounding mode"; - return; - } - rm_value = *(rv_fp->GetRoundingMode()); - } - FPRoundingMode rm = static_cast<FPRoundingMode>(rm_value); + std::function<IntermediateType(IntermediateType)> operation) { RiscVCsrDestinationOperand *fflags_dest = static_cast<RiscVCsrDestinationOperand *>(instruction->Destination(1)); - bool arguments_contain_snan = false; + uint32_t fflags = 0; RiscVUnaryFloatNaNBoxOp<RVFpRegister::ValueType, RVFpRegister::ValueType, - Result, Argument>( - instruction, - [rv_fp, rm, &fflags, &operation, - &arguments_contain_snan](Argument a) -> Result { - Result result; - if (FPTypeInfo<Argument>::IsSNaN(a)) { - arguments_contain_snan = true; + HalfFP, Argument>( + instruction, [instruction, &operation, &fflags](Argument a) -> HalfFP { + RiscVFPState *rv_fp = + static_cast<RiscVState *>(instruction->state())->rv_fp(); + int rm_value = generic::GetInstructionSource<int>(instruction, 1); + + // If the rounding mode is dynamic, read it from the current 
state. + if (rm_value == *FPRoundingMode::kDynamic) { + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + } + rm_value = *(rv_fp->GetRoundingMode()); } - if (zfh_internal::UseHostFlagsForConversion()) { - result = operation(a, rm, fflags); - } else { + FPRoundingMode rm = static_cast<FPRoundingMode>(rm_value); + IntermediateType argument1 = + ConvertFromHalfFP<IntermediateType>(a, fflags); + IntermediateType result; + { ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface(), rm); - result = operation(a, rm, fflags); + result = operation(argument1); } - return result; + // To get the correct fflags we need a combination of host flags from + // the operation and the conversion flags. Copy the host flags and merge + // them with the conversion flags. + fflags |= rv_fp->fflags()->GetUint32(); + { + // ConvertToHalfFP pollutes the host flags so we need to create a + // ScopedFPRoundingMode to restore the host flags. + ScopedFPRoundingMode scoped_rm(rv_fp->host_fp_interface(), rm_value); + return ConvertToHalfFP(result, rm, fflags); + } }); - if (!zfh_internal::UseHostFlagsForConversion()) { - fflags_dest->GetRiscVCsr()->SetBits(fflags); - } - if (arguments_contain_snan) { - fflags_dest->GetRiscVCsr()->SetBits(*FPExceptions::kInvalidOp); - } + fflags_dest->GetRiscVCsr()->SetBits(fflags); } // Generic helper function enabling HalfFP operations in native datatypes. @@ -269,69 +405,44 @@ const Instruction *instruction, std::function<IntermediateType(IntermediateType, IntermediateType)> operation) { - RiscVFPState *rv_fp = - static_cast<RiscVState *>(instruction->state())->rv_fp(); - int rm_value = generic::GetInstructionSource<int>(instruction, 2); - - // If the rounding mode is dynamic, read it from the current state. 
- if (rm_value == *FPRoundingMode::kDynamic) { - if (!rv_fp->rounding_mode_valid()) { - LOG(ERROR) << "Invalid rounding mode"; - return; - } - rm_value = *(rv_fp->GetRoundingMode()); - } - FPRoundingMode rm = static_cast<FPRoundingMode>(rm_value); RiscVCsrDestinationOperand *fflags_dest = static_cast<RiscVCsrDestinationOperand *>(instruction->Destination(1)); - uint32_t fflags = fflags_dest->GetRiscVCsr()->GetUint32(); - bool arguments_contain_snan = false; - IntermediateType b_emin = - std::pow(2.0, 1 - FPTypeInfo<IntermediateType>::kExpBias); - IntermediateType result; + uint32_t fflags = 0; RiscVBinaryFloatNaNBoxOp<RVFpRegister::ValueType, HalfFP, Argument>( instruction, - [&operation, &arguments_contain_snan, &fflags, rv_fp, &rm, &result]( - Argument a, Argument b) -> HalfFP { - IntermediateType a_f; - IntermediateType b_f; - if (FPTypeInfo<Argument>::IsSNaN(a)) { - a_f = absl::bit_cast<IntermediateType>( - FPTypeInfo<IntermediateType>::kPosInf | 1); - arguments_contain_snan = true; - } else { - a_f = ConvertFromHalfFP<IntermediateType>(a, fflags); + [instruction, &operation, &fflags](Argument a, Argument b) -> HalfFP { + RiscVFPState *rv_fp = + static_cast<RiscVState *>(instruction->state())->rv_fp(); + int rm_value = generic::GetInstructionSource<int>(instruction, 2); + // If the rounding mode is dynamic, read it from the current state. 
+ if (rm_value == *FPRoundingMode::kDynamic) { + if (!rv_fp->rounding_mode_valid()) { + LOG(ERROR) << "Invalid rounding mode"; + } + rm_value = *(rv_fp->GetRoundingMode()); } - if (FPTypeInfo<Argument>::IsSNaN(b)) { - b_f = absl::bit_cast<IntermediateType>( - FPTypeInfo<IntermediateType>::kPosInf | 1); - arguments_contain_snan = true; - } else { - b_f = ConvertFromHalfFP<IntermediateType>(b, fflags); - } - if (zfh_internal::UseHostFlagsForConversion()) { - result = operation(a_f, b_f); - } else { + FPRoundingMode rm = static_cast<FPRoundingMode>(rm_value); + IntermediateType argument1 = + ConvertFromHalfFP<IntermediateType>(a, fflags); + IntermediateType argument2 = + ConvertFromHalfFP<IntermediateType>(b, fflags); + IntermediateType result; + { ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface(), rm); - result = operation(a_f, b_f); + result = operation(argument1, argument2); } - if (!zfh_internal::UseHostFlagsForConversion()) { - fflags |= rv_fp->fflags()->GetUint32(); + // To get the correct fflags we need a combination of host flags from + // the operation and the conversion flags. Copy the host flags and merge + // them with the conversion flags. + fflags |= rv_fp->fflags()->GetUint32(); + { + // ConvertToHalfFP pollutes the host flags so we need to create a + // ScopedFPRoundingMode to restore the host flags. + ScopedFPRoundingMode scoped_rm(rv_fp->host_fp_interface(), rm_value); + return ConvertToHalfFP(result, rm, fflags); } - return ConvertToHalfFP(result, rm, fflags); }); - if (arguments_contain_snan) { - fflags_dest->GetRiscVCsr()->SetBits(*FPExceptions::kInvalidOp); - } - if (!zfh_internal::UseHostFlagsForConversion()) { - fflags_dest->GetRiscVCsr()->Write(fflags); - } - // When the result is less than b_emin before rounding we need to set the - // underflow flag. 
- if ((fflags_dest->GetRiscVCsr()->GetUint32() & *FPExceptions::kInexact) && - result != 0 && std::abs(result) < b_emin) { - fflags_dest->GetRiscVCsr()->SetBits(*FPExceptions::kUnderflow); - } + fflags_dest->GetRiscVCsr()->SetBits(fflags); } } // namespace @@ -592,14 +703,8 @@ // Calculate the square root of a half precision value. Do the operation in // single precision and then convert back to half precision. void RiscVZfhFsqrt(const Instruction *instruction) { - RiscVZfhUnaryHelper<HalfFP, HalfFP>( - instruction, [](HalfFP a, FPRoundingMode rm, uint32_t &fflags) -> HalfFP { - float input_f = ConvertFromHalfFP<float>(a, fflags); - if (!std::isnan(input_f) && input_f < 0) { - fflags |= static_cast<uint32_t>(FPExceptions::kInvalidOp); - } - return ConvertToHalfFP(std::sqrt(input_f), rm, fflags); - }); + RiscVZfhUnaryHelper<HalfFP, float>( + instruction, [](float a) -> float { return std::sqrt(a); }); } // The result is the exponent and significand of the first source with the
diff --git a/riscv/riscv_zfh_instructions.h b/riscv/riscv_zfh_instructions.h index 08e33cf..6337a31 100644 --- a/riscv/riscv_zfh_instructions.h +++ b/riscv/riscv_zfh_instructions.h
@@ -15,11 +15,8 @@ #ifndef THIRD_PARTY_MPACT_RISCV_RISCV_ZFH_INSTRUCTIONS_H_ #define THIRD_PARTY_MPACT_RISCV_RISCV_ZFH_INSTRUCTIONS_H_ -#include <cstdint> - #include "mpact/sim/generic/instruction.h" #include "mpact/sim/generic/type_helpers.h" -#include "riscv/riscv_fp_info.h" namespace mpact { namespace sim { @@ -154,13 +151,6 @@ // function. void RV32VUnimplementedInstruction(const Instruction *instruction); -HalfFP ConvertSingleToHalfFP(float, FPRoundingMode, uint32_t &); -HalfFP ConvertDoubleToHalfFP(double, FPRoundingMode, uint32_t &); - -namespace zfh_internal { -bool UseHostFlagsForConversion(); -} // namespace zfh_internal - // Source Operands: // frs1: Float Register // frs2: Float Register
diff --git a/riscv/riscv_zfh_instructions_arm.cc b/riscv/riscv_zfh_instructions_arm.cc deleted file mode 100644 index 74104d2..0000000 --- a/riscv/riscv_zfh_instructions_arm.cc +++ /dev/null
@@ -1,215 +0,0 @@ -// Copyright 2025 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <sys/types.h> - -#include <cassert> -#include <cmath> -#include <cstdint> - -#include "absl/base/casts.h" -#include "mpact/sim/generic/type_helpers.h" -#include "riscv/riscv_fp_info.h" -#include "riscv/riscv_instruction_helpers.h" -#include "riscv/riscv_zfh_instructions.h" - -namespace mpact { -namespace sim { -namespace riscv { - -using HalfFP = ::mpact::sim::generic::HalfFP; - -// TODO(b/401856759): Use arm intrinsics for fp32 -> fp16 and fp64 -> fp16 -// conversions. - -namespace { - -// This is a soft conversion from a float or double to a half precision value. -// It is not a direct conversion from the floating point format to the half -// format. Instead, it uses the floating point hardware to do the conversion. -// This is done to get the correct rounding behavior for free from the FPU. -template <typename T> -HalfFP SoftConvertToHalfFP(T input_value, FPRoundingMode rm, uint32_t &fflags) { - using UIntType = typename FPTypeInfo<T>::UIntType; - using IntType = typename FPTypeInfo<T>::IntType; - UIntType in_int = absl::bit_cast<UIntType>(input_value); - HalfFP half_fp = {.value = 0x0000}; - - // Extract the mantissa, exponent and sign. 
- UIntType mantissa = in_int & FPTypeInfo<T>::kSigMask; - UIntType exponent = - (in_int & FPTypeInfo<T>::kExpMask) >> FPTypeInfo<T>::kSigSize; - UIntType sign = in_int >> (FPTypeInfo<T>::kBitSize - 1); - - if (std::isnan(input_value)) { - half_fp.value = FPTypeInfo<HalfFP>::kCanonicalNaN; - if (FPTypeInfo<T>::IsSNaN(input_value)) { - fflags |= static_cast<UIntType>(FPExceptions::kInvalidOp); - } - return half_fp; - } - - if (std::isinf(input_value)) { - half_fp.value = FPTypeInfo<HalfFP>::kPosInf; - half_fp.value |= (sign & 1) << (FPTypeInfo<HalfFP>::kBitSize - 1); - return half_fp; - } - - if (in_int == 0 || in_int == 1ULL << (FPTypeInfo<T>::kBitSize - 1)) { - half_fp.value = - in_int >> (FPTypeInfo<T>::kBitSize - FPTypeInfo<HalfFP>::kBitSize); - return half_fp; - } - - IntType bias_diff = FPTypeInfo<T>::kExpBias - FPTypeInfo<HalfFP>::kExpBias; - IntType unbounded_half_exponent = static_cast<IntType>(exponent) - bias_diff; - IntType sig_size_diff = - FPTypeInfo<T>::kSigSize - FPTypeInfo<HalfFP>::kSigSize; - UIntType half_inf_exponent = ((1 << FPTypeInfo<HalfFP>::kExpSize) - 1); - UIntType source_type_inf_exponent = ((1 << FPTypeInfo<T>::kExpSize) - 1); - - // Create a temp float with the smallest normal exponent and input mantissa. - T ftmp = absl::bit_cast<T>( - (sign << (FPTypeInfo<T>::kBitSize - 1)) | - (static_cast<UIntType>(1ULL) << FPTypeInfo<T>::kSigSize) | mantissa); - - // Create a divisor float that will be used for shifting the mantissa in a - // rounding aware way. The amount of shifting depends on if the result is - // subnormal or normal. 
- T fdiv = 0; - UIntType default_fdiv_exp = FPTypeInfo<T>::kExpBias + sig_size_diff; - UIntType fdiv_exp = default_fdiv_exp; - if (unbounded_half_exponent > 0) { - fdiv_exp = default_fdiv_exp; - } else if (unbounded_half_exponent < 0) { - // shift_count: emin - unbiased exponent - IntType shift_count = 1 - static_cast<int>(exponent) + bias_diff; - fdiv_exp = default_fdiv_exp + shift_count; - fdiv_exp = std::min(fdiv_exp, source_type_inf_exponent - 1); - } else { - fdiv_exp = default_fdiv_exp + 1; - } - fdiv = absl::bit_cast<T>(fdiv_exp << FPTypeInfo<T>::kSigSize); - - // Shift right by doing division. - T fres = ftmp / fdiv; - UIntType res = absl::bit_cast<UIntType>(fres); - - // Shift left by doing multiplication. - T fmultiply = absl::bit_cast<T>(default_fdiv_exp << FPTypeInfo<T>::kSigSize); - T fres2 = fres * fmultiply; - UIntType res2 = absl::bit_cast<UIntType>(fres2); - - // Update the exponent if rounding caused an increase. - IntType exp_diff = static_cast<IntType>((res2 >> FPTypeInfo<T>::kSigSize) & - source_type_inf_exponent) - - 1; - UIntType new_exponent = (exponent + exp_diff) & source_type_inf_exponent; - - UIntType half_exponent = 0; - if (unbounded_half_exponent > 0) { - half_exponent = new_exponent - bias_diff; - } else if (unbounded_half_exponent < 0) { - // Guaranteed subnormal. Nothing to do. - } else { - // This case could be normal or subnormal depending on the rounding result. - half_exponent = (res2 >> FPTypeInfo<T>::kSigSize) & half_inf_exponent; - } - - UIntType half_mantissa = - (res2 >> sig_size_diff) & FPTypeInfo<HalfFP>::kSigMask; - if (unbounded_half_exponent < 0) { // Guaranteed Subnormal - half_mantissa = (res & (1 << FPTypeInfo<HalfFP>::kSigSize)) - ? ((res >> 1) & FPTypeInfo<HalfFP>::kSigMask) - : res & FPTypeInfo<HalfFP>::kSigMask; - } - - // Handle the rules for overflowing to infinity depending on the rounding - // mode. 
- if (half_exponent >= half_inf_exponent) { - fflags |= static_cast<uint32_t>(FPExceptions::kOverflow); - fflags |= static_cast<uint32_t>(FPExceptions::kInexact); - switch (rm) { - case FPRoundingMode::kRoundToNearest: - half_exponent = half_inf_exponent; - half_mantissa = 0; - break; - case FPRoundingMode::kRoundTowardsZero: - half_exponent = half_inf_exponent - 1; - half_mantissa = FPTypeInfo<HalfFP>::kSigMask; - break; - case FPRoundingMode::kRoundDown: - half_exponent = sign ? half_inf_exponent : half_inf_exponent - 1; - half_mantissa = sign ? 0 : FPTypeInfo<HalfFP>::kSigMask; - break; - case FPRoundingMode::kRoundUp: - half_exponent = sign ? half_inf_exponent - 1 : half_inf_exponent; - half_mantissa = sign ? FPTypeInfo<HalfFP>::kSigMask : 0; - break; - default: - half_exponent = half_inf_exponent; - half_mantissa = 0; - break; - } - } - - // Construct the half float. - half_fp.value = half_mantissa | - (half_exponent << FPTypeInfo<HalfFP>::kSigSize) | - (sign << (FPTypeInfo<HalfFP>::kBitSize - 1)); - - // Do an arithmetic reconstruction of the float to check for exactness. - T trailing_significand_float = static_cast<T>(half_mantissa); - T precision_factor = std::pow(2.0, -1.0 * FPTypeInfo<HalfFP>::kSigSize); - IntType unbiased_exponent = - (half_exponent == 0 ? 1 : half_exponent) - FPTypeInfo<HalfFP>::kExpBias; - T exponent_factor = std::pow(2.0, unbiased_exponent); - T sign_factor = sign == 1 ? -1.0 : 1.0; - T implicit_bit_adjustment = half_exponent == 0 ? 0.0 : 1.0; - T reconstructed_value = ((trailing_significand_float * precision_factor) + - implicit_bit_adjustment) * - exponent_factor * sign_factor; - bool exact_conversion = reconstructed_value == input_value; - - // Handle flags for the specific underflow case. - if (!exact_conversion && (unbounded_half_exponent < 0 || - (unbounded_half_exponent == 0 && fres2 != ftmp))) { - fflags |= static_cast<uint32_t>(FPExceptions::kUnderflow); - } - - // Handle flags for the specific inexact case. 
- if (!exact_conversion && (fres2 != ftmp)) { - fflags |= static_cast<uint32_t>(FPExceptions::kInexact); - } - return half_fp; -} -} // namespace - -HalfFP ConvertSingleToHalfFP(float input_value, FPRoundingMode rm, - uint32_t &fflags) { - return SoftConvertToHalfFP(input_value, rm, fflags); -} - -HalfFP ConvertDoubleToHalfFP(double input_value, FPRoundingMode rm, - uint32_t &fflags) { - return SoftConvertToHalfFP(input_value, rm, fflags); -} - -namespace zfh_internal { -bool UseHostFlagsForConversion() { return false; } -} // namespace zfh_internal - -} // namespace riscv -} // namespace sim -} // namespace mpact
diff --git a/riscv/riscv_zfh_instructions_x86.cc b/riscv/riscv_zfh_instructions_x86.cc deleted file mode 100644 index f99a0fa..0000000 --- a/riscv/riscv_zfh_instructions_x86.cc +++ /dev/null
@@ -1,74 +0,0 @@ -// Copyright 2025 Google LLC -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include <immintrin.h> -#include <sys/types.h> - -#include <cstdint> - -#include "mpact/sim/generic/type_helpers.h" -#include "riscv/riscv_fp_info.h" -#include "riscv/riscv_zfh_instructions.h" - -namespace mpact { -namespace sim { -namespace riscv { - -using HalfFP = ::mpact::sim::generic::HalfFP; - -HalfFP ConvertSingleToHalfFP(float input_value, FPRoundingMode rm, - uint32_t &fflags) { - HalfFP half_fp; - - // Get current MXCSR value. The simulator should have already configured the - // rounding mode so we simply pass it along to the intrinsic. 
- unsigned int mxcsr = _mm_getcsr(); - - // Extract rounding control bits (bits 13 and 14) - int rounding_control_bits = (mxcsr >> 13) & 0x3; - - switch (rounding_control_bits) { - case 0x0: // Round to nearest - half_fp.value = _cvtss_sh(input_value, 0); - break; - case 0x1: // Round down - half_fp.value = _cvtss_sh(input_value, 1); - break; - case 0x2: // Round up - half_fp.value = _cvtss_sh(input_value, 2); - break; - case 0x3: // Round towards zero - half_fp.value = _cvtss_sh(input_value, 3); - break; - default: // Default to nearest even if mode is not recognized - half_fp.value = _cvtss_sh(input_value, 0); - break; - } - - return half_fp; -} - -HalfFP ConvertDoubleToHalfFP(double input_value, FPRoundingMode rm, - uint32_t &fflags) { - float input_float = static_cast<float>(input_value); - return ConvertSingleToHalfFP(input_float, rm, fflags); -} - -namespace zfh_internal { -bool UseHostFlagsForConversion() { return true; } -} // namespace zfh_internal - -} // namespace riscv -} // namespace sim -} // namespace mpact
diff --git a/riscv/test/riscv_zfh_instructions_test.cc b/riscv/test/riscv_zfh_instructions_test.cc index 372ddbd..65e01ce 100644 --- a/riscv/test/riscv_zfh_instructions_test.cc +++ b/riscv/test/riscv_zfh_instructions_test.cc
@@ -797,6 +797,7 @@ mpact::sim::riscv::FPExceptions::kInvalidOp); result.value = FPTypeInfo<HalfFP>::kCanonicalNaN; } else if (!FPTypeInfo<HalfFP>::IsNaN(a) && + !FPTypeInfo<HalfFP>::IsInf(a) && (b.value == FPTypeInfo<HalfFP>::kPosZero || b.value == FPTypeInfo<HalfFP>::kNegZero)) { // Dividing by zero requires an exception for non-NaN dividend values.