blob: cf7261685e0dcb0a9b6ae9315c37e375de78aa43 [file] [log] [blame]
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "cheriot/riscv_cheriot_vector_fp_reduction_instructions.h"

#include <cmath>
#include <cstring>
#include <functional>

#include "absl/log/log.h"
#include "cheriot/cheriot_state.h"
#include "cheriot/riscv_cheriot_vector_instruction_helpers.h"
#include "mpact/sim/generic/type_helpers.h"
#include "riscv//riscv_fp_host.h"
#include "riscv//riscv_fp_state.h"
#include "riscv//riscv_vector_state.h"
namespace mpact {
namespace sim {
namespace cheriot {
using ::mpact::sim::generic::FPTypeInfo;
using ::mpact::sim::riscv::ScopedFPStatus;
// Each reduction instruction below passes RiscVBinaryReductionVectorOp a
// lambda that takes an accumulator and a source element and returns the
// updated accumulator. Each partial sum is stored to a separate entry in the
// destination vector.
// Floating-point sum reduction (vfredosum.vs).
//
// Folds the source elements into an accumulator with `sum + element`, using
// float arithmetic when SEW is 4 bytes and double arithmetic when SEW is 8
// bytes. If the FP state reports an invalid rounding mode, or SEW is any other
// width, a vector exception is set and an error is logged instead.
void Vfredosum(const Instruction *inst) {
  auto *cheriot_state = static_cast<CheriotState *>(inst->state());
  auto *fp_state = cheriot_state->rv_fp();
  auto *vector_state = cheriot_state->rv_vector();
  if (!fp_state->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    vector_state->set_vector_exception();
    return;
  }
  const int element_bytes = vector_state->selected_element_width();
  ScopedFPStatus host_fp_guard(fp_state->host_fp_interface());
  if (element_bytes == 4) {
    RiscVBinaryReductionVectorOp<float, float, float>(
        vector_state, inst,
        [](float sum, float element) -> float { return sum + element; });
    return;
  }
  if (element_bytes == 8) {
    RiscVBinaryReductionVectorOp<double, double, double>(
        vector_state, inst,
        [](double sum, double element) -> double { return sum + element; });
    return;
  }
  vector_state->set_vector_exception();
  LOG(ERROR) << "Illegal SEW value";
}
// Widening floating-point sum reduction (vfwredosum.vs).
//
// Each single-precision source element (SEW of 4 bytes) is converted to
// double and added to a double-precision accumulator. If the FP state reports
// an invalid rounding mode, or SEW is not 4 bytes, a vector exception is set
// and an error is logged instead.
void Vfwredosum(const Instruction *inst) {
  auto *cheriot_state = static_cast<CheriotState *>(inst->state());
  auto *fp_state = cheriot_state->rv_fp();
  auto *vector_state = cheriot_state->rv_vector();
  if (!fp_state->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    vector_state->set_vector_exception();
    return;
  }
  const int element_bytes = vector_state->selected_element_width();
  ScopedFPStatus host_fp_guard(fp_state->host_fp_interface());
  if (element_bytes != 4) {
    vector_state->set_vector_exception();
    LOG(ERROR) << "Illegal SEW value";
    return;
  }
  RiscVBinaryReductionVectorOp<double, float, double>(
      vector_state, inst, [](double sum, float element) -> double {
        return sum + static_cast<double>(element);
      });
}
// Templated helper function for vfmin and vfmax instructions.
template <typename T>
inline T MaxMinHelper(T vs2, T vs1, std::function<T(T, T)> operation) {
// If either operand is a signaling NaN or if both operands are NaNs, then
// return a canonical (non-signaling) NaN.
if (FPTypeInfo<T>::IsSNaN(vs1) || FPTypeInfo<T>::IsSNaN(vs2) ||
(FPTypeInfo<T>::IsNaN(vs2) && FPTypeInfo<T>::IsNaN(vs1))) {
typename FPTypeInfo<T>::UIntType c_nan = FPTypeInfo<T>::kCanonicalNaN;
return *reinterpret_cast<T *>(&c_nan);
}
// If either operand is a NaN return the other.
if (FPTypeInfo<T>::IsNaN(vs2)) return vs1;
if (FPTypeInfo<T>::IsNaN(vs1)) return vs2;
// Return the min/max of the two operands.
return operation(vs2, vs1);
}
namespace {
// Returns the smaller of two non-NaN values. A plain comparison treats -0.0
// and +0.0 as equal; RISC-V vector FP min (minimumNumber) orders -0.0 below
// +0.0, so a pair of zeros yields the negative one if present.
template <typename T>
T OrderedFpMin(T acc, T vs2) {
  if (acc == T(0) && vs2 == T(0)) return std::signbit(acc) ? acc : vs2;
  return (acc > vs2) ? vs2 : acc;
}
}  // namespace

// FP min reduction (vfredmin.vs).
//
// Reduces the source elements to their minimum via
// RiscVBinaryReductionVectorOp. NaN operands are resolved by MaxMinHelper;
// non-NaN pairs use OrderedFpMin. SEW of 4 bytes uses float and 8 bytes uses
// double. If the FP state reports an invalid rounding mode, or SEW is any
// other width, a vector exception is set and an error is logged instead.
void Vfredmin(const Instruction *inst) {
  auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp();
  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
  // Checked even though min never rounds; kept consistent with the other
  // vector FP semantic functions in this file.
  if (!rv_fp->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    rv_vector->set_vector_exception();
    return;
  }
  int sew = rv_vector->selected_element_width();
  ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface());
  switch (sew) {
    case 4:
      RiscVBinaryReductionVectorOp<float, float, float>(
          rv_vector, inst, [](float acc, float vs2) -> float {
            return MaxMinHelper<float>(acc, vs2, OrderedFpMin<float>);
          });
      return;
    case 8:
      RiscVBinaryReductionVectorOp<double, double, double>(
          rv_vector, inst, [](double acc, double vs2) -> double {
            return MaxMinHelper<double>(acc, vs2, OrderedFpMin<double>);
          });
      return;
    default:
      rv_vector->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      return;
  }
}
namespace {
// Returns the larger of two non-NaN values. A plain comparison treats -0.0
// and +0.0 as equal; RISC-V vector FP max (maximumNumber) orders +0.0 above
// -0.0, so a pair of zeros yields the positive one if present.
template <typename T>
T OrderedFpMax(T acc, T vs2) {
  if (acc == T(0) && vs2 == T(0)) return std::signbit(acc) ? vs2 : acc;
  return (acc < vs2) ? vs2 : acc;
}
}  // namespace

// FP max reduction (vfredmax.vs).
//
// Reduces the source elements to their maximum via
// RiscVBinaryReductionVectorOp. NaN operands are resolved by MaxMinHelper;
// non-NaN pairs use OrderedFpMax. SEW of 4 bytes uses float and 8 bytes uses
// double. If the FP state reports an invalid rounding mode, or SEW is any
// other width, a vector exception is set and an error is logged instead.
void Vfredmax(const Instruction *inst) {
  auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp();
  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
  // Checked even though max never rounds; kept consistent with the other
  // vector FP semantic functions in this file.
  if (!rv_fp->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    rv_vector->set_vector_exception();
    return;
  }
  int sew = rv_vector->selected_element_width();
  ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface());
  switch (sew) {
    case 4:
      RiscVBinaryReductionVectorOp<float, float, float>(
          rv_vector, inst, [](float acc, float vs2) -> float {
            return MaxMinHelper<float>(acc, vs2, OrderedFpMax<float>);
          });
      return;
    case 8:
      RiscVBinaryReductionVectorOp<double, double, double>(
          rv_vector, inst, [](double acc, double vs2) -> double {
            return MaxMinHelper<double>(acc, vs2, OrderedFpMax<double>);
          });
      return;
    default:
      rv_vector->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      return;
  }
}
} // namespace cheriot
} // namespace sim
} // namespace mpact