| // Copyright 2023 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // https://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
#include "cheriot/riscv_cheriot_vector_fp_reduction_instructions.h"

#include <cmath>
#include <cstring>
#include <functional>

#include "absl/log/log.h"
#include "cheriot/cheriot_state.h"
#include "cheriot/riscv_cheriot_vector_instruction_helpers.h"
#include "mpact/sim/generic/type_helpers.h"
#include "riscv//riscv_fp_host.h"
#include "riscv//riscv_fp_state.h"
#include "riscv//riscv_vector_state.h"
| |
| namespace mpact { |
| namespace sim { |
| namespace cheriot { |
| |
| using ::mpact::sim::generic::FPTypeInfo; |
| using ::mpact::sim::riscv::ScopedFPStatus; |
| |
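// Each reduction instruction below passes RiscVBinaryReductionVectorOp a
// binary function that takes the accumulator and a vector element value and
// returns the updated accumulator. Each partial sum is stored to a separate
// entry in the destination vector.
// NOTE(review): RISC-V vector reductions architecturally write a single scalar
// to element 0 of vd; confirm the statement above against
// RiscVBinaryReductionVectorOp.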
| |
| // Sum reduction. |
| void Vfredosum(const Instruction *inst) { |
| auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| if (!rv_fp->rounding_mode_valid()) { |
| LOG(ERROR) << "Invalid rounding mode"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| int sew = rv_vector->selected_element_width(); |
| ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); |
| switch (sew) { |
| case 4: |
| return RiscVBinaryReductionVectorOp<float, float, float>( |
| rv_vector, inst, |
| [](float acc, float vs2) -> float { return acc + vs2; }); |
| return; |
| case 8: |
| return RiscVBinaryReductionVectorOp<double, double, double>( |
| rv_vector, inst, |
| [](double acc, double vs2) -> double { return acc + vs2; }); |
| return; |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| void Vfwredosum(const Instruction *inst) { |
| auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| if (!rv_fp->rounding_mode_valid()) { |
| LOG(ERROR) << "Invalid rounding mode"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| int sew = rv_vector->selected_element_width(); |
| ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); |
| switch (sew) { |
| case 4: |
| return RiscVBinaryReductionVectorOp<double, float, double>( |
| rv_vector, inst, [](double acc, float vs2) -> double { |
| return acc + static_cast<double>(vs2); |
| }); |
| return; |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Templated helper function for vfmin and vfmax instructions. |
| template <typename T> |
| inline T MaxMinHelper(T vs2, T vs1, std::function<T(T, T)> operation) { |
| // If either operand is a signaling NaN or if both operands are NaNs, then |
| // return a canonical (non-signaling) NaN. |
| if (FPTypeInfo<T>::IsSNaN(vs1) || FPTypeInfo<T>::IsSNaN(vs2) || |
| (FPTypeInfo<T>::IsNaN(vs2) && FPTypeInfo<T>::IsNaN(vs1))) { |
| typename FPTypeInfo<T>::UIntType c_nan = FPTypeInfo<T>::kCanonicalNaN; |
| return *reinterpret_cast<T *>(&c_nan); |
| } |
| // If either operand is a NaN return the other. |
| if (FPTypeInfo<T>::IsNaN(vs2)) return vs1; |
| if (FPTypeInfo<T>::IsNaN(vs1)) return vs2; |
| // Return the min/max of the two operands. |
| return operation(vs2, vs1); |
| } |
| |
| // FP min reduction. |
| void Vfredmin(const Instruction *inst) { |
| auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| if (!rv_fp->rounding_mode_valid()) { |
| LOG(ERROR) << "Invalid rounding mode"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| int sew = rv_vector->selected_element_width(); |
| ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); |
| switch (sew) { |
| case 4: |
| return RiscVBinaryReductionVectorOp<float, float, float>( |
| rv_vector, inst, [](float acc, float vs2) -> float { |
| return MaxMinHelper<float>(acc, vs2, |
| [](float acc, float vs2) -> float { |
| return (acc > vs2) ? vs2 : acc; |
| }); |
| }); |
| return; |
| case 8: |
| return RiscVBinaryReductionVectorOp<double, double, double>( |
| rv_vector, inst, [](double acc, double vs2) -> double { |
| return MaxMinHelper<double>(acc, vs2, |
| [](double acc, double vs2) -> double { |
| return (acc > vs2) ? vs2 : acc; |
| }); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // FP max reduction. |
| void Vfredmax(const Instruction *inst) { |
| auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp(); |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| if (!rv_fp->rounding_mode_valid()) { |
| LOG(ERROR) << "Invalid rounding mode"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| int sew = rv_vector->selected_element_width(); |
| ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface()); |
| switch (sew) { |
| case 4: |
| return RiscVBinaryReductionVectorOp<float, float, float>( |
| rv_vector, inst, [](float acc, float vs2) -> float { |
| return MaxMinHelper<float>(acc, vs2, |
| [](float acc, float vs2) -> float { |
| return (acc < vs2) ? vs2 : acc; |
| }); |
| }); |
| return; |
| case 8: |
| return RiscVBinaryReductionVectorOp<double, double, double>( |
| rv_vector, inst, [](double acc, double vs2) -> double { |
| return MaxMinHelper<double>(acc, vs2, |
| [](double acc, double vs2) -> double { |
| return (acc < vs2) ? vs2 : acc; |
| }); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| } // namespace cheriot |
| } // namespace sim |
| } // namespace mpact |