| // Copyright 2023 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // https://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "cheriot/riscv_cheriot_vector_fp_unary_instructions.h" |
| |
| #include <cmath> |
| #include <cstdint> |
| #include <limits> |
| #include <tuple> |
| |
| #include "absl/log/log.h" |
| #include "cheriot/cheriot_state.h" |
| #include "cheriot/riscv_cheriot_instruction_helpers.h" |
| #include "cheriot/riscv_cheriot_vector_instruction_helpers.h" |
| #include "mpact/sim/generic/instruction.h" |
| #include "mpact/sim/generic/type_helpers.h" |
| #include "riscv//riscv_fp_host.h" |
| #include "riscv//riscv_fp_info.h" |
| #include "riscv//riscv_fp_state.h" |
| #include "riscv//riscv_vector_state.h" |
| |
| namespace mpact { |
| namespace sim { |
| namespace cheriot { |
| |
| using ::mpact::sim::generic::FPTypeInfo; |
| using ::mpact::sim::riscv::FPExceptions; |
| using ::mpact::sim::riscv::ScopedFPStatus; |
| |
| // These tables contain the 7 bits of mantissa used by the approximated |
| // reciprocal square root and reciprocal instructions. |
| static const int kRecipSqrtMantissaTable[128] = { |
| 52, 51, 50, 48, 47, 46, 44, 43, 42, 41, 40, 39, 38, 36, 35, |
| 34, 33, 32, 31, 30, 30, 29, 28, 27, 26, 25, 24, 23, 23, 22, |
| 21, 20, 19, 19, 18, 17, 16, 16, 15, 14, 14, 13, 12, 12, 11, |
| 10, 10, 9, 9, 8, 7, 7, 6, 6, 5, 4, 4, 3, 3, 2, |
| 2, 1, 1, 0, 127, 125, 123, 121, 119, 118, 116, 114, 113, 111, 109, |
| 108, 106, 105, 103, 102, 100, 99, 97, 96, 95, 93, 92, 91, 20, 88, |
| 87, 86, 85, 84, 83, 82, 80, 79, 78, 77, 76, 75, 74, 73, 72, |
| 71, 70, 70, 69, 68, 67, 66, 65, 64, 63, 63, 62, 61, 60, 59, |
| 59, 58, 57, 56, 56, 55, 54, 53, |
| }; |
| |
| static const int kRecipMantissaTable[128] = { |
| 127, 125, 123, 121, 119, 117, 116, 114, 112, 110, 109, 107, 105, 104, 102, |
| 100, 99, 97, 96, 94, 93, 91, 90, 88, 87, 85, 84, 83, 81, 80, |
| 79, 77, 76, 75, 74, 72, 71, 70, 69, 68, 66, 65, 64, 63, 62, |
| 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, |
| 46, 45, 44, 43, 42, 41, 40, 40, 39, 38, 37, 36, 35, 35, 34, |
| 33, 32, 31, 31, 30, 29, 28, 28, 27, 26, 25, 25, 24, 23, 23, |
| 22, 21, 21, 20, 19, 19, 18, 17, 17, 16, 15, 15, 14, 14, 13, |
| 12, 12, 11, 11, 10, 9, 9, 8, 8, 7, 7, 6, 5, 5, 4, |
| 4, 3, 3, 2, 2, 1, 1, 0}; |
| |
| // Move float from scalar fp register to vector register(all elements). |
| void Vfmvvf(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| const int vl = rv_vector->vector_length(); |
| if (rv_vector->vstart() > 0) return; |
| if (vl == 0) return; |
| |
| const int sew = rv_vector->selected_element_width(); |
| auto dest_op = |
| static_cast<RV32VectorDestinationOperand *>(inst->Destination(0)); |
| auto dest_db = dest_op->CopyDataBuffer(); |
| switch (sew) { |
| case 4: |
| for (int i = 0; i < vl; ++i) { |
| dest_db->Set<uint32_t>( |
| i, generic::GetInstructionSource<uint32_t>(inst, 0, 0)); |
| } |
| break; |
| case 8: |
| for (int i = 0; i < vl; ++i) { |
| dest_db->Set<uint64_t>( |
| i, generic::GetInstructionSource<uint64_t>(inst, 0, 0)); |
| } |
| break; |
| default: |
| dest_db->DecRef(); |
| LOG(ERROR) << "Vfmv.s.f: Illegal sew (" << sew << ")"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| dest_db->Submit(); |
| rv_vector->clear_vstart(); |
| } |
| |
| // Move float from vector to scalar fp register(first element). |
| void Vfmvsf(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| if (rv_vector->vstart() > 0) return; |
| if (rv_vector->vector_length() == 0) return; |
| int sew = rv_vector->selected_element_width(); |
| auto dest_op = |
| static_cast<RV32VectorDestinationOperand *>(inst->Destination(0)); |
| auto dest_db = dest_op->CopyDataBuffer(); |
| switch (sew) { |
| case 4: |
| dest_db->Set<uint32_t>( |
| 0, generic::GetInstructionSource<uint32_t>(inst, 0, 0)); |
| break; |
| case 8: |
| dest_db->Set<uint64_t>( |
| 0, generic::GetInstructionSource<uint64_t>(inst, 0, 0)); |
| break; |
| default: |
| dest_db->DecRef(); |
| LOG(ERROR) << "Vfmv.s.f: Illegal sew (" << sew << ")"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| dest_db->Submit(); |
| rv_vector->clear_vstart(); |
| } |
| |
| // Move scalar floating point value to element 0 of vector register. |
| void Vfmvfs(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| auto dest_op = inst->Destination(0); |
| auto dest_db = dest_op->AllocateDataBuffer(); |
| int db_size = dest_db->size<uint8_t>(); |
| switch (sew) { |
| case 4: { |
| uint64_t value = generic::GetInstructionSource<uint32_t>(inst, 0, 0); |
| if (db_size == 4) { |
| dest_db->Set<uint32_t>(0, value); |
| } else if (db_size == 8) { |
| uint64_t val64 = 0xffff'ffff'0000'0000ULL | value; |
| dest_db->Set<uint64_t>(0, val64); |
| } else { |
| LOG(ERROR) << "Unexpected databuffer size in Vfmvfs"; |
| } |
| break; |
| } |
| case 8: |
| dest_db->Set<uint64_t>( |
| 0, generic::GetInstructionSource<uint64_t>(inst, 0, 0)); |
| break; |
| default: |
| dest_db->DecRef(); |
| LOG(ERROR) << "Vfmv.f.s: Illegal sew (" << sew << ")"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| dest_db->Submit(); |
| rv_vector->clear_vstart(); |
| } |
| |
| // Convert floating point to unsigned integer. |
| void Vfcvtxufv(const Instruction *inst) { |
| auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| if (!rv_fp->rounding_mode_valid()) { |
| LOG(ERROR) << "Invalid rounding mode"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 4: |
| return RiscVUnaryVectorOpWithFflags<uint32_t, float>( |
| rv_vector, inst, [](float vs2) -> std::tuple<uint32_t, uint32_t> { |
| return CvtHelper<float, uint32_t>(vs2); |
| }); |
| case 8: |
| return RiscVUnaryVectorOpWithFflags<uint64_t, double>( |
| rv_vector, inst, [](double vs2) -> std::tuple<uint64_t, uint32_t> { |
| return CvtHelper<double, uint64_t>(vs2); |
| }); |
| default: |
| LOG(ERROR) << "Vfcvt.xu.fv: Illegal sew (" << sew << ")"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| } |
| |
| // Convert floating point to signed integer. |
| void Vfcvtxfv(const Instruction *inst) { |
| auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| if (!rv_fp->rounding_mode_valid()) { |
| LOG(ERROR) << "Invalid rounding mode"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 4: |
| return RiscVUnaryVectorOpWithFflags<int32_t, float>( |
| rv_vector, inst, [](float vs2) -> std::tuple<int32_t, uint32_t> { |
| return CvtHelper<float, int32_t>(vs2); |
| }); |
| case 8: |
| return RiscVUnaryVectorOpWithFflags<int64_t, double>( |
| rv_vector, inst, [](double vs2) -> std::tuple<int64_t, uint32_t> { |
| return CvtHelper<double, int64_t>(vs2); |
| }); |
| default: |
| LOG(ERROR) << "Vfcvt.x.fv: Illegal sew (" << sew << ")"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| } |
| |
| // Convert unsigned integer to floating point. |
| void Vfcvtfxuv(const Instruction *inst) { |
| auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| if (!rv_fp->rounding_mode_valid()) { |
| LOG(ERROR) << "Invalid rounding mode"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| int sew = rv_vector->selected_element_width(); |
| ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); |
| switch (sew) { |
| case 4: |
| return RiscVUnaryVectorOp<float, uint32_t>( |
| rv_vector, inst, |
| [](uint32_t vs2) -> float { return static_cast<float>(vs2); }); |
| case 8: |
| return RiscVUnaryVectorOp<double, uint64_t>( |
| rv_vector, inst, |
| [](uint64_t vs2) -> double { return static_cast<double>(vs2); }); |
| default: |
| LOG(ERROR) << "Vfcvt.f.xuv: Illegal sew (" << sew << ")"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| } |
| |
| // Convert signed integer to floating point. |
| void Vfcvtfxv(const Instruction *inst) { |
| auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| if (!rv_fp->rounding_mode_valid()) { |
| LOG(ERROR) << "Invalid rounding mode"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| int sew = rv_vector->selected_element_width(); |
| ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); |
| switch (sew) { |
| case 4: |
| return RiscVUnaryVectorOp<float, int32_t>( |
| rv_vector, inst, |
| [](int32_t vs2) -> float { return static_cast<float>(vs2); }); |
| case 8: |
| return RiscVUnaryVectorOp<double, int64_t>( |
| rv_vector, inst, |
| [](int64_t vs2) -> double { return static_cast<double>(vs2); }); |
| default: |
| LOG(ERROR) << "Vfcvt.f.xv: Illegal sew (" << sew << ")"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| } |
| |
| // Convert floating point to unsigned integer with truncation. |
| void Vfcvtrtzxufv(const Instruction *inst) { |
| auto *rv_state = static_cast<CheriotState *>(inst->state()); |
| auto *rv_vector = rv_state->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 4: |
| return RiscVUnaryVectorOpWithFflags<uint32_t, float>( |
| rv_vector, inst, [](float vs2) -> std::tuple<uint32_t, uint32_t> { |
| return CvtHelper<float, uint32_t>(vs2); |
| }); |
| case 8: |
| return RiscVUnaryVectorOpWithFflags<uint64_t, double>( |
| rv_vector, inst, [](double vs2) -> std::tuple<uint64_t, uint32_t> { |
| return CvtHelper<double, uint64_t>(vs2); |
| }); |
| default: |
| LOG(ERROR) << "Vfcvt.rtz.xu.fv: Illegal sew (" << sew << ")"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| } |
| |
| // Convert floating point to signed integer with truncation. |
| void Vfcvtrtzxfv(const Instruction *inst) { |
| auto *rv_state = static_cast<CheriotState *>(inst->state()); |
| auto *rv_vector = rv_state->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 4: |
| return RiscVUnaryVectorOpWithFflags<int32_t, float>( |
| rv_vector, inst, [](float vs2) -> std::tuple<int32_t, uint32_t> { |
| return CvtHelper<float, int32_t>(vs2); |
| }); |
| case 8: |
| return RiscVUnaryVectorOpWithFflags<int64_t, double>( |
| rv_vector, inst, [](double vs2) -> std::tuple<int64_t, uint32_t> { |
| return CvtHelper<double, int64_t>(vs2); |
| }); |
| default: |
| LOG(ERROR) << "Vfcvt.rtz.x.fv: Illegal sew (" << sew << ")"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| } |
| |
| // Widening conversion of floating point to unsigned integer. |
| void Vfwcvtxufv(const Instruction *inst) { |
| auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| if (!rv_fp->rounding_mode_valid()) { |
| LOG(ERROR) << "Invalid rounding mode"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 4: |
| return RiscVUnaryVectorOpWithFflags<uint64_t, float>( |
| rv_vector, inst, [](float vs2) -> std::tuple<uint64_t, uint32_t> { |
| return CvtHelper<float, uint64_t>(vs2); |
| }); |
| default: |
| LOG(ERROR) << "Vfwcvt.xu.fv: Illegal sew (" << sew << ")"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| } |
| |
| // Widening conversion of floating point to signed integer. |
| void Vfwcvtxfv(const Instruction *inst) { |
| auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| if (!rv_fp->rounding_mode_valid()) { |
| LOG(ERROR) << "Invalid rounding mode"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 4: |
| return RiscVUnaryVectorOpWithFflags<int64_t, float>( |
| rv_vector, inst, [](float vs2) -> std::tuple<int64_t, uint32_t> { |
| return CvtHelper<float, int64_t>(vs2); |
| }); |
| default: |
| LOG(ERROR) << "Vfwcvt.x.fv: Illegal sew (" << sew << ")"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| } |
| |
| // Wideing conversion of floating point to floating point. |
| void Vfwcvtffv(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 4: |
| return RiscVUnaryVectorOp<double, float>( |
| rv_vector, inst, |
| [](float vs2) -> double { return static_cast<double>(vs2); }); |
| default: |
| LOG(ERROR) << "Vfwcvt.f.fv: Illegal sew (" << sew << ")"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| } |
| |
| // Widening conversion of unsigned integer to floating point. |
| void Vfwcvtfxuv(const Instruction *inst) { |
| auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| if (!rv_fp->rounding_mode_valid()) { |
| LOG(ERROR) << "Invalid rounding mode"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| int sew = rv_vector->selected_element_width(); |
| ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); |
| switch (sew) { |
| case 2: |
| return RiscVUnaryVectorOp<float, uint16_t>( |
| rv_vector, inst, |
| [](uint16_t vs2) -> float { return static_cast<float>(vs2); }); |
| case 4: |
| return RiscVUnaryVectorOp<double, uint32_t>( |
| rv_vector, inst, |
| [](uint32_t vs2) -> double { return static_cast<double>(vs2); }); |
| default: |
| LOG(ERROR) << "Vfwcvt.f.xuv: Illegal sew (" << sew << ")"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| } |
| |
| // Widening conversion of signed integer to floating point. |
| void Vfwcvtfxv(const Instruction *inst) { |
| auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| if (!rv_fp->rounding_mode_valid()) { |
| LOG(ERROR) << "Invalid rounding mode"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| int sew = rv_vector->selected_element_width(); |
| ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); |
| switch (sew) { |
| case 2: |
| return RiscVUnaryVectorOp<float, int16_t>( |
| rv_vector, inst, |
| [](int16_t vs2) -> float { return static_cast<float>(vs2); }); |
| case 4: |
| return RiscVUnaryVectorOp<double, int32_t>( |
| rv_vector, inst, |
| [](int32_t vs2) -> double { return static_cast<double>(vs2); }); |
| default: |
| LOG(ERROR) << "Vfwcvt.f.xuv: Illegal sew (" << sew << ")"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| } |
| |
| // Widening conversion of floating point to unsigned integer with truncation. |
| void Vfwcvtrtzxufv(const Instruction *inst) { |
| auto *rv_state = static_cast<CheriotState *>(inst->state()); |
| auto *rv_vector = rv_state->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 4: |
| return RiscVUnaryVectorOpWithFflags<uint64_t, float>( |
| rv_vector, inst, [](float vs2) -> std::tuple<uint64_t, uint32_t> { |
| return CvtHelper<float, uint64_t>(vs2); |
| }); |
| default: |
| LOG(ERROR) << "Vwfcvt.rtz.xu.fv: Illegal sew (" << sew << ")"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| } |
| |
| // Widening conversion of floating point to signed integer with truncation. |
| void Vfwcvtrtzxfv(const Instruction *inst) { |
| auto *rv_state = static_cast<CheriotState *>(inst->state()); |
| auto *rv_vector = rv_state->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 4: |
| return RiscVUnaryVectorOpWithFflags<int64_t, float>( |
| rv_vector, inst, [](float vs2) -> std::tuple<int64_t, uint32_t> { |
| return CvtHelper<float, int64_t>(vs2); |
| }); |
| default: |
| LOG(ERROR) << "Vwfcvt.rtz.x.fv: Illegal sew (" << sew << ")"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| } |
| |
| // Narrowing conversion of floating point to unsigned integer. |
| void Vfncvtxufw(const Instruction *inst) { |
| auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| if (!rv_fp->rounding_mode_valid()) { |
| LOG(ERROR) << "Invalid rounding mode"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 4: |
| return RiscVUnaryVectorOpWithFflags<uint16_t, float>( |
| rv_vector, inst, [](float vs2) -> std::tuple<uint16_t, uint32_t> { |
| return CvtHelper<float, uint16_t>(vs2); |
| }); |
| case 8: |
| return RiscVUnaryVectorOpWithFflags<uint32_t, double>( |
| rv_vector, inst, [](double vs2) -> std::tuple<uint32_t, uint32_t> { |
| return CvtHelper<double, uint32_t>(vs2); |
| }); |
| default: |
| LOG(ERROR) << "Vfncvt.xu.fw: Illegal sew (" << sew << ")"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| } |
| |
| // Narrowing conversion of floating point to signed integer. |
| void Vfncvtxfw(const Instruction *inst) { |
| auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| if (!rv_fp->rounding_mode_valid()) { |
| LOG(ERROR) << "Invalid rounding mode"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 4: |
| return RiscVUnaryVectorOpWithFflags<int16_t, float>( |
| rv_vector, inst, [](float vs2) -> std::tuple<int16_t, uint32_t> { |
| return CvtHelper<float, int16_t>(vs2); |
| }); |
| case 8: |
| return RiscVUnaryVectorOpWithFflags<int32_t, double>( |
| rv_vector, inst, [](double vs2) -> std::tuple<int32_t, uint32_t> { |
| return CvtHelper<double, int32_t>(vs2); |
| }); |
| default: |
| LOG(ERROR) << "Vfncvt.x.fw: Illegal sew (" << sew << ")"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| } |
| |
| // Narrowing conversion of floating point to floating point. |
| void Vfncvtffw(const Instruction *inst) { |
| auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| if (!rv_fp->rounding_mode_valid()) { |
| LOG(ERROR) << "Invalid rounding mode"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| int sew = rv_vector->selected_element_width(); |
| ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); |
| switch (sew) { |
| case 8: |
| return RiscVUnaryVectorOp<float, double>( |
| rv_vector, inst, |
| [](double vs2) -> float { return static_cast<float>(vs2); }); |
| default: |
| LOG(ERROR) << "Vfwcvt.f.fw: Illegal sew (" << sew << ")"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| } |
| |
| // Narrowing conversion of floating point to floating point rounding to odd. |
| void Vfncvtrodffw(const Instruction *inst) { |
| auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| if (!rv_fp->rounding_mode_valid()) { |
| LOG(ERROR) << "Invalid rounding mode"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| int sew = rv_vector->selected_element_width(); |
| // The rounding mode is round to odd, which means that the lsb of the new |
| // mantissa is either 1 or it is the logical or of all the bits to the right |
| // in the original width mantissa. |
| switch (sew) { |
| case 8: |
| return RiscVUnaryVectorOp<float, double>( |
| rv_vector, inst, [](double vs2) -> float { |
| if (FPTypeInfo<double>::IsNaN(vs2) || |
| FPTypeInfo<double>::IsInf(vs2)) { |
| return static_cast<float>(vs2); |
| } |
| using UIntD = typename FPTypeInfo<double>::UIntType; |
| using UIntF = typename FPTypeInfo<float>::UIntType; |
| UIntD uval = *reinterpret_cast<UIntD *>(&vs2); |
| int sig_diff = |
| FPTypeInfo<double>::kSigSize - FPTypeInfo<float>::kSigSize; |
| UIntD mask = (1ULL << sig_diff) - 1; |
| UIntF bit = (mask & uval) != 0; |
| auto res = static_cast<float>(vs2); |
| if (FPTypeInfo<float>::IsInf(res)) return res; |
| UIntF ures = *reinterpret_cast<UIntF *>(&res); |
| ures |= bit; |
| return *reinterpret_cast<float *>(&ures); |
| }); |
| default: |
| LOG(ERROR) << "Vfwcvt.rod.f.fw: Illegal sew (" << sew << ")"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| } |
| |
| // Narrowing conversion of unsigned integer to floating point. |
| void Vfncvtfxuw(const Instruction *inst) { |
| auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| if (!rv_fp->rounding_mode_valid()) { |
| LOG(ERROR) << "Invalid rounding mode"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| int sew = rv_vector->selected_element_width(); |
| ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); |
| switch (sew) { |
| case 8: |
| return RiscVUnaryVectorOp<float, uint64_t>( |
| rv_vector, inst, |
| [](uint64_t vs2) -> float { return static_cast<float>(vs2); }); |
| default: |
| LOG(ERROR) << "Vfncvt.f.xuw: Illegal sew (" << sew << ")"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| } |
| |
| // Narrowing conversion of signed integeer to floating point. |
| void Vfncvtfxw(const Instruction *inst) { |
| auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| if (!rv_fp->rounding_mode_valid()) { |
| LOG(ERROR) << "Invalid rounding mode"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| int sew = rv_vector->selected_element_width(); |
| ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); |
| switch (sew) { |
| case 8: |
| return RiscVUnaryVectorOp<float, int64_t>( |
| rv_vector, inst, |
| [](int64_t vs2) -> float { return static_cast<float>(vs2); }); |
| default: |
| LOG(ERROR) << "Vfncvt.f.xw: Illegal sew (" << sew << ")"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| } |
| |
| // Narrowing conversion of floating point to unsigned integer with truncation. |
| void Vfncvtrtzxufw(const Instruction *inst) { |
| auto *rv_state = static_cast<CheriotState *>(inst->state()); |
| auto *rv_vector = rv_state->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 4: |
| return RiscVUnaryVectorOpWithFflags<uint16_t, float>( |
| rv_vector, inst, [](float vs2) -> std::tuple<uint16_t, uint32_t> { |
| return CvtHelper<float, uint16_t>(vs2); |
| }); |
| case 8: |
| return RiscVUnaryVectorOpWithFflags<uint32_t, double>( |
| rv_vector, inst, [](double vs2) -> std::tuple<uint32_t, uint32_t> { |
| return CvtHelper<double, uint32_t>(vs2); |
| }); |
| default: |
| LOG(ERROR) << "Vfcvt.rtz.xu.fw: Illegal sew (" << sew << ")"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| } |
| |
| // Narrowing conversion of floating point to signed integer with truncation. |
| void Vfncvtrtzxfw(const Instruction *inst) { |
| auto *rv_state = static_cast<CheriotState *>(inst->state()); |
| auto *rv_vector = rv_state->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 4: |
| return RiscVUnaryVectorOpWithFflags<int16_t, float>( |
| rv_vector, inst, [](float vs2) -> std::tuple<int16_t, uint32_t> { |
| return CvtHelper<float, int16_t>(vs2); |
| }); |
| case 8: |
| return RiscVUnaryVectorOpWithFflags<int32_t, double>( |
| rv_vector, inst, [](double vs2) -> std::tuple<int32_t, uint32_t> { |
| return CvtHelper<double, int32_t>(vs2); |
| }); |
| default: |
| LOG(ERROR) << "Vfcvt.rtz.xu.fw: Illegal sew (" << sew << ")"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| } |
| |
| // Templated helper function to compute square root. |
| template <typename T> |
| inline std::tuple<T, uint32_t> SqrtHelper(T vs2) { |
| uint32_t flags = 0; |
| T res; |
| if (FPTypeInfo<T>::IsNaN(vs2) || vs2 < 0.0) { |
| auto value = FPTypeInfo<T>::kCanonicalNaN; |
| res = *reinterpret_cast<T *>(&value); |
| flags = *FPExceptions::kInvalidOp; |
| return std::tie(res, flags); |
| } |
| if (vs2 == 0.0) return std::tie(vs2, flags); |
| res = sqrt(vs2); |
| return std::tie(res, flags); |
| } |
| |
| // Square root. |
| void Vfsqrtv(const Instruction *inst) { |
| auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| if (!rv_fp->rounding_mode_valid()) { |
| LOG(ERROR) << "Invalid rounding mode"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| int sew = rv_vector->selected_element_width(); |
| uint32_t flags = 0; |
| { |
| ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); |
| switch (sew) { |
| case 4: |
| RiscVUnaryVectorOp<float, float>(rv_vector, inst, |
| [&flags](float vs2) -> float { |
| auto [res, f] = SqrtHelper(vs2); |
| flags |= f; |
| return res; |
| }); |
| break; |
| case 8: |
| RiscVUnaryVectorOp<double, double>(rv_vector, inst, |
| [&flags](double vs2) -> double { |
| auto [res, f] = SqrtHelper(vs2); |
| flags |= f; |
| return res; |
| }); |
| break; |
| default: |
| LOG(ERROR) << "Vffcvt.f.xuv: Illegal sew (" << sew << ")"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| } |
| auto *fflags = rv_fp->fflags(); |
| fflags->Write(flags | fflags->AsUint32()); |
| } |
| |
| // Templated helper function to compute the Reciprocal Square Root |
| // approximation for valid inputs. |
| template <typename T> |
| inline T RecipSqrt7(T value) { |
| using Uint = typename FPTypeInfo<T>::UIntType; |
| Uint uint_value = *reinterpret_cast<Uint *>(&value); |
| // The input value is positive. Negative values are already handled. |
| int norm_exponent = |
| (uint_value & FPTypeInfo<T>::kExpMask) >> FPTypeInfo<T>::kSigSize; |
| Uint norm_mantissa = uint_value & FPTypeInfo<T>::kSigMask; |
| if (norm_exponent == 0) { // The value is a denormal. |
| Uint mask = static_cast<Uint>(1) << (FPTypeInfo<T>::kSigSize - 1); |
| // Normalize the mantissa and exponent by shifting the mantissa left until |
| // the most significant bit is one. |
| while ((norm_mantissa & mask) == 0) { |
| norm_exponent--; |
| norm_mantissa <<= 1; |
| } |
| // Shift it left once more - so it becomes the "implied" bit, and not used |
| // in the lookup below. |
| norm_mantissa <<= 1; |
| } |
| int index = (norm_exponent & 0b1) << 6 | |
| ((norm_mantissa >> (FPTypeInfo<T>::kSigSize - 6)) & 0b11'1111); |
| Uint new_mantissa = static_cast<Uint>(kRecipSqrtMantissaTable[index]) |
| << (FPTypeInfo<T>::kSigSize - 7); |
| Uint new_exponent = (3 * FPTypeInfo<T>::kExpBias - 1 - norm_exponent) / 2; |
| Uint new_value = (new_exponent << FPTypeInfo<T>::kSigSize) | new_mantissa; |
| T new_fp_value = *reinterpret_cast<T *>(&new_value); |
| return new_fp_value; |
| } |
| |
| // Templated helper function to compute the Reciprocal Square Root |
| // approximation for all values. |
| template <typename T> |
| inline std::tuple<T, uint32_t> RecipSqrt7Helper(T value) { |
| auto fp_class = std::fpclassify(value); |
| T return_value = std::numeric_limits<T>::quiet_NaN(); |
| uint32_t fflags = 0; |
| switch (fp_class) { |
| case FP_INFINITE: |
| return_value = |
| std::signbit(value) ? std::numeric_limits<T>::quiet_NaN() : 0.0; |
| fflags = (uint32_t)FPExceptions::kInvalidOp; |
| break; |
| case FP_NAN: |
| // Just propagate the NaN. |
| return_value = std::numeric_limits<T>::quiet_NaN(); |
| fflags = (uint32_t)FPExceptions::kInvalidOp; |
| break; |
| case FP_ZERO: |
| return_value = std::signbit(value) ? -std::numeric_limits<T>::infinity() |
| : std::numeric_limits<T>::infinity(); |
| fflags = (uint32_t)FPExceptions::kDivByZero; |
| break; |
| case FP_SUBNORMAL: |
| case FP_NORMAL: |
| if (std::signbit(value)) { |
| return_value = std::numeric_limits<T>::quiet_NaN(); |
| fflags = (uint32_t)FPExceptions::kInvalidOp; |
| } else { |
| return_value = RecipSqrt7(value); |
| } |
| break; |
| default: |
| LOG(ERROR) << "RecipSqrt7Helper: Illegal fp_class (" << fp_class << ")"; |
| break; |
| } |
| return std::make_tuple(return_value, fflags); |
| } |
| |
| // Approximation of reciprocal square root to 7 bits mantissa. |
| void Vfrsqrt7v(const Instruction *inst) { |
| auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 4: |
| return RiscVUnaryVectorOpWithFflags<float, float>( |
| rv_vector, inst, [rv_fp](float vs2) -> std::tuple<float, uint32_t> { |
| ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); |
| return RecipSqrt7Helper(vs2); |
| }); |
| case 8: |
| return RiscVUnaryVectorOpWithFflags<double, double>( |
| rv_vector, inst, [rv_fp](double vs2) -> std::tuple<double, uint32_t> { |
| ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); |
| return RecipSqrt7Helper(vs2); |
| }); |
| default: |
| LOG(ERROR) << "vfrsqrt7.v: Illegal sew (" << sew << ")"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| } |
| |
| // Templated helper function to compute the Reciprocal approximation for valid |
| // normal floating point inputs. |
| template <typename T> |
| inline T Recip7(T value, FPRoundingMode rm) { |
| using Uint = typename FPTypeInfo<T>::UIntType; |
| using Int = typename FPTypeInfo<T>::IntType; |
| Uint uint_value = *reinterpret_cast<Uint *>(&value); |
| Int norm_exponent = |
| (uint_value & FPTypeInfo<T>::kExpMask) >> FPTypeInfo<T>::kSigSize; |
| Uint norm_mantissa = uint_value & FPTypeInfo<T>::kSigMask; |
| if (norm_exponent == 0) { // The value is a denormal. |
| Uint msb = static_cast<Uint>(1) << (FPTypeInfo<T>::kSigSize - 1); |
| // Normalize the mantissa and exponent by shifting the mantissa left until |
| // the most significant bit is one. |
| while (norm_mantissa && ((norm_mantissa & msb) == 0)) { |
| norm_exponent--; |
| norm_mantissa <<= 1; |
| } |
| // Shift it left once more - so it becomes the "implied" bit, and not used |
| // in the lookup below. |
| norm_mantissa <<= 1; |
| } |
| Int new_exponent = 2 * FPTypeInfo<T>::kExpBias - 1 - norm_exponent; |
| // If the exponent is too high, then return exceptional values. |
| if (new_exponent > 2 * FPTypeInfo<T>::kExpBias) { |
| switch (rm) { |
| case FPRoundingMode::kRoundDown: |
| return std::signbit(value) ? -std::numeric_limits<T>::infinity() |
| : std::numeric_limits<T>::max(); |
| case FPRoundingMode::kRoundTowardsZero: |
| return std::signbit(value) ? std::numeric_limits<T>::lowest() |
| : std::numeric_limits<T>::max(); |
| case FPRoundingMode::kRoundToNearestTiesToMax: |
| case FPRoundingMode::kRoundToNearest: |
| return std::signbit(value) ? -std::numeric_limits<T>::infinity() |
| : std::numeric_limits<T>::infinity(); |
| case FPRoundingMode::kRoundUp: |
| return std::signbit(value) ? std::numeric_limits<T>::lowest() |
| : std::numeric_limits<T>::infinity(); |
| default: |
| // kDynamic can't happen. |
| return std::numeric_limits<T>::quiet_NaN(); |
| } |
| } |
| // Perform table lookup and compute the new value using the new exponent. |
| int index = (norm_mantissa >> (FPTypeInfo<T>::kSigSize - 7)) & 0b111'1111; |
| Uint new_mantissa = static_cast<Uint>(kRecipMantissaTable[index]) |
| << (FPTypeInfo<T>::kSigSize - 7); |
| // If the new exponent is negative or 0, the result is denormal. First |
| // shift the mantissa right and or in the implied '1'. |
| if (new_exponent <= 0) { |
| new_mantissa = (new_mantissa >> 1) | 0b100'0000; |
| // If the exponent is less than 0, shift the mantissa right. |
| if (new_exponent < 0) { |
| new_mantissa >>= 1; |
| new_exponent = 0; |
| } |
| new_mantissa &= 0b111'1111; |
| } |
| Uint new_value = (new_exponent << FPTypeInfo<T>::kSigSize) | new_mantissa; |
| T new_fp_value = *reinterpret_cast<T *>(&new_value); |
| return value < 0.0 ? -new_fp_value : new_fp_value; |
| } |
| |
| // Templated helper function to compute the Reciprocal approximation for all |
| // values including non-normal floating point values. |
| template <typename T> |
| inline T Recip7Helper(T value, FPRoundingMode rm) { |
| auto fp_class = std::fpclassify(value); |
| |
| switch (fp_class) { |
| case FP_INFINITE: |
| // TODO: raise exception. |
| return std::signbit(value) ? -0.0 : 0; |
| case FP_NAN: |
| // Just propagate the NaN. |
| return std::numeric_limits<T>::quiet_NaN(); |
| case FP_ZERO: |
| return std::signbit(value) ? -std::numeric_limits<T>::infinity() |
| : std::numeric_limits<T>::infinity(); |
| case FP_SUBNORMAL: |
| case FP_NORMAL: |
| return Recip7(value, rm); |
| } |
| return std::numeric_limits<T>::quiet_NaN(); |
| } |
| |
| // Approximate reciprocal to 7 bits of mantissa. |
| void Vfrec7v(const Instruction *inst) { |
| auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); |
| auto rm = rv_fp->GetRoundingMode(); |
| switch (sew) { |
| case 4: |
| return RiscVUnaryVectorOp<float, float>( |
| rv_vector, inst, |
| [rm](float vs2) -> float { return Recip7Helper(vs2, rm); }); |
| case 8: |
| return RiscVUnaryVectorOp<double, double>( |
| rv_vector, inst, |
| [rm](double vs2) -> double { return Recip7Helper(vs2, rm); }); |
| default: |
| LOG(ERROR) << "vfrec7.v: Illegal sew (" << sew << ")"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| } |
| |
| // Classify floating point value. |
| void Vfclassv(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 4: |
| return RiscVUnaryVectorOp<uint32_t, float>( |
| rv_vector, inst, [](float vs2) -> uint32_t { |
| return static_cast<uint32_t>(ClassifyFP(vs2)); |
| }); |
| case 8: |
| return RiscVUnaryVectorOp<uint64_t, double>( |
| rv_vector, inst, [](double vs2) -> uint64_t { |
| return static_cast<uint64_t>(ClassifyFP(vs2)); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "vfclass.v: Illegal sew (" << sew << ")"; |
| return; |
| } |
| } |
| |
| } // namespace cheriot |
| } // namespace sim |
| } // namespace mpact |