RiscVZfh: Simplify semantic functions for fp conversions PiperOrigin-RevId: 755095617 Change-Id: Icf8c0a6c9278df7749bc00c642c38e6680feea18
diff --git a/riscv/riscv_zfh_instructions.cc b/riscv/riscv_zfh_instructions.cc index 8f20a82..cc4d4aa 100644 --- a/riscv/riscv_zfh_instructions.cc +++ b/riscv/riscv_zfh_instructions.cc
@@ -17,6 +17,7 @@ #include <cstdint> #include <functional> #include <limits> +#include <type_traits> #include "absl/base/casts.h" #include "absl/log/log.h" @@ -35,9 +36,48 @@ namespace riscv { using HalfFP = ::mpact::sim::generic::HalfFP; +using ::mpact::sim::generic::IsMpactFp; namespace { +template <typename T> +struct DataTypeRegValue {}; + +template <> +struct DataTypeRegValue<float> { + using type = RVFpRegister::ValueType; +}; + +template <> +struct DataTypeRegValue<double> { + using type = RVFpRegister::ValueType; +}; + +template <> +struct DataTypeRegValue<HalfFP> { + using type = RVFpRegister::ValueType; +}; + +template <> +struct DataTypeRegValue<int32_t> { + using type = RVXRegister::ValueType; +}; + +template <> +struct DataTypeRegValue<uint32_t> { + using type = RVXRegister::ValueType; +}; + +template <> +struct DataTypeRegValue<int64_t> { + using type = RVXRegister::ValueType; +}; + +template <> +struct DataTypeRegValue<uint64_t> { + using type = RVXRegister::ValueType; +}; + // Convert from half precision to single or double precision. template <typename T> inline T ConvertFromHalfFP(HalfFP half_fp, uint32_t &fflags) { @@ -96,37 +136,77 @@ void RiscVZfhCvtHelper( const Instruction *instruction, std::function<Result(Argument, FPRoundingMode, uint32_t &)> operation) { + RiscVCsrDestinationOperand *fflags_dest = + static_cast<RiscVCsrDestinationOperand *>(instruction->Destination(1)); + using DstRegValue = typename DataTypeRegValue<Result>::type; uint32_t fflags = 0; - RiscVFPState *fp_state = - static_cast<RiscVState *>(instruction->state())->rv_fp(); + + Argument lhs; + if constexpr (IsMpactFp<Argument>::value) { + lhs = GetNaNBoxedSource<RVFpRegister::ValueType, Argument>(instruction, 0); + if (FPTypeInfo<Argument>::IsSNaN(lhs)) { + fflags_dest->GetRiscVCsr()->SetBits(*FPExceptions::kInvalidOp); + } + } else { + lhs = generic::GetInstructionSource<Argument>(instruction, 0); + } + // Get the rounding mode. int rm_value = generic::GetInstructionSource<int>(instruction, 1); + auto *rv_fp = static_cast<RiscVState *>(instruction->state())->rv_fp(); // If the rounding mode is dynamic, read it from the current state. if (rm_value == *FPRoundingMode::kDynamic) { - if (!fp_state->rounding_mode_valid()) { + if (!rv_fp->rounding_mode_valid()) { LOG(ERROR) << "Invalid rounding mode"; return; } - rm_value = *(fp_state->GetRoundingMode()); + rm_value = *rv_fp->GetRoundingMode(); } - FPRoundingMode rm = static_cast<FPRoundingMode>(rm_value); - RiscVCsrDestinationOperand *fflags_dest = - static_cast<RiscVCsrDestinationOperand *>(instruction->Destination(1)); - RiscVUnaryFloatNaNBoxOp<RVFpRegister::ValueType, RVFpRegister::ValueType, - Result, Argument>( - instruction, [fp_state, rm, &fflags, &operation](Argument a) -> Result { - Result result; - if (zfh_internal::UseHostFlagsForConversion()) { - result = operation(a, rm, fflags); - } else { - ScopedFPStatus set_fpstatus(fp_state->host_fp_interface(), rm); - result = operation(a, rm, fflags); - } - return result; - }); + + Result dest_value; + if (zfh_internal::UseHostFlagsForConversion()) { + ScopedFPStatus set_fp_status(rv_fp->host_fp_interface(), rm_value); + dest_value = operation(lhs, static_cast<FPRoundingMode>(rm_value), fflags); + } else { + ScopedFPRoundingMode scoped_rm(rv_fp->host_fp_interface(), rm_value); + dest_value = operation(lhs, static_cast<FPRoundingMode>(rm_value), fflags); + } if (!zfh_internal::UseHostFlagsForConversion()) { fflags_dest->GetRiscVCsr()->SetBits(fflags); } + auto *reg = static_cast<generic::RegisterDestinationOperand<DstRegValue> *>( + instruction->Destination(0)) + ->GetRegister(); + + if (sizeof(DstRegValue) > sizeof(Result) && IsMpactFp<Result>::value) { + // If the floating point value is narrower than the register, the upper + // bits have to be set to all ones. + using UReg = typename std::make_unsigned<DstRegValue>::type; + using UInt = typename FPTypeInfo<Result>::UIntType; + auto dest_u_value = *reinterpret_cast<UInt *>(&dest_value); + UReg reg_value = std::numeric_limits<UReg>::max(); + int shift = 8 * sizeof(Result); + reg_value = (reg_value << shift) | dest_u_value; + reg->data_buffer()->template Set<DstRegValue>(0, reg_value); + return; + } + reg->data_buffer()->template Set<Result>(0, dest_value); +} + +template <typename T> +inline HalfFP ConvertToHalfFP(T input_value, FPRoundingMode rm, + uint32_t &fflags); + +template <> +inline HalfFP ConvertToHalfFP(float input_value, FPRoundingMode rm, + uint32_t &fflags) { + return ConvertSingleToHalfFP(input_value, rm, fflags); +} + +template <> +inline HalfFP ConvertToHalfFP(double input_value, FPRoundingMode rm, + uint32_t &fflags) { + return ConvertDoubleToHalfFP(input_value, rm, fflags); } } // namespace @@ -191,43 +271,33 @@ // Convert from half precision to single precision. void RiscVZfhCvtSh(const Instruction *instruction) { - uint32_t fflags = 0; - RiscVCsrDestinationOperand *fflags_dest = - static_cast<RiscVCsrDestinationOperand *>(instruction->Destination(1)); - RiscVUnaryFloatNaNBoxOp<RVFpRegister::ValueType, RVFpRegister::ValueType, - float, HalfFP>( - instruction, [&fflags](HalfFP a) -> float { + RiscVZfhCvtHelper<float, HalfFP>( + instruction, [](HalfFP a, FPRoundingMode rm, uint32_t &fflags) -> float { return ConvertFromHalfFP<float>(a, fflags); }); - fflags_dest->GetRiscVCsr()->SetBits(fflags); } // Convert from single precision to half precision. void RiscVZfhCvtHs(const Instruction *instruction) { RiscVZfhCvtHelper<HalfFP, float>( instruction, [](float a, FPRoundingMode rm, uint32_t &fflags) -> HalfFP { - return ConvertSingleToHalfFP(a, rm, fflags); + return ConvertToHalfFP(a, rm, fflags); }); } // Convert from half precision to double precision. void RiscVZfhCvtDh(const Instruction *instruction) { - uint32_t fflags = 0; - RiscVCsrDestinationOperand *fflags_dest = - static_cast<RiscVCsrDestinationOperand *>(instruction->Destination(1)); - RiscVUnaryFloatNaNBoxOp<RVFpRegister::ValueType, RVFpRegister::ValueType, - double, HalfFP>( - instruction, [&fflags](HalfFP a) -> double { + RiscVZfhCvtHelper<double, HalfFP>( + instruction, [](HalfFP a, FPRoundingMode rm, uint32_t &fflags) -> double { return ConvertFromHalfFP<double>(a, fflags); }); - fflags_dest->GetRiscVCsr()->SetBits(fflags); } // Convert from double precision to half precision. void RiscVZfhCvtHd(const Instruction *instruction) { RiscVZfhCvtHelper<HalfFP, double>( instruction, [](double a, FPRoundingMode rm, uint32_t &fflags) -> HalfFP { - return ConvertDoubleToHalfFP(a, rm, fflags); + return ConvertToHalfFP(a, rm, fflags); }); }