blob: e4d92fa268f663cc88e0b647502107f9cc04bef3 [file] [log] [blame]
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "cheriot/riscv_cheriot_vector_fp_instructions.h"
#include <cmath>
#include <cstdint>
#include <cstring>
#include <functional>
#include <tuple>
#include "absl/log/log.h"
#include "cheriot/cheriot_state.h"
#include "cheriot/riscv_cheriot_vector_instruction_helpers.h"
#include "mpact/sim/generic/type_helpers.h"
#include "riscv//riscv_fp_host.h"
#include "riscv//riscv_fp_info.h"
#include "riscv//riscv_fp_state.h"
#include "riscv//riscv_vector_state.h"
namespace mpact {
namespace sim {
namespace cheriot {
using ::mpact::sim::generic::FPTypeInfo;
using ::mpact::sim::riscv::FPExceptions;
using ::mpact::sim::riscv::ScopedFPStatus;
// Vector floating point add: vd[i] = vs2[i] + vs1[i].
// Supports SEW of 4 bytes (float) and 8 bytes (double). An invalid rounding
// mode or any other SEW is logged and flagged with set_vector_exception().
void Vfadd(const Instruction *inst) {
  auto *state = static_cast<CheriotState *>(inst->state());
  auto *fp_state = state->rv_fp();
  auto *vector_state = state->rv_vector();
  if (!fp_state->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    vector_state->set_vector_exception();
    return;
  }
  const int sew = vector_state->selected_element_width();
  // Host FP status is held for the duration of the element-wise operation.
  ScopedFPStatus fp_status(fp_state->host_fp_interface());
  if (sew == 4) {
    RiscVBinaryVectorOp<float, float, float>(
        vector_state, inst,
        [](float vs2, float vs1) -> float { return vs2 + vs1; });
  } else if (sew == 8) {
    RiscVBinaryVectorOp<double, double, double>(
        vector_state, inst,
        [](double vs2, double vs1) -> double { return vs2 + vs1; });
  } else {
    LOG(ERROR) << "Vfadd: Illegal sew (" << sew << ")";
    vector_state->set_vector_exception();
  }
}
// Vector floating point subtract: vd[i] = vs2[i] - vs1[i].
// Supports SEW of 4 bytes (float) and 8 bytes (double). An invalid rounding
// mode or any other SEW is logged and flagged with set_vector_exception().
void Vfsub(const Instruction *inst) {
  auto *state = static_cast<CheriotState *>(inst->state());
  auto *fp_state = state->rv_fp();
  auto *vector_state = state->rv_vector();
  if (!fp_state->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    vector_state->set_vector_exception();
    return;
  }
  const int sew = vector_state->selected_element_width();
  // Host FP status is held for the duration of the element-wise operation.
  ScopedFPStatus fp_status(fp_state->host_fp_interface());
  if (sew == 4) {
    RiscVBinaryVectorOp<float, float, float>(
        vector_state, inst,
        [](float vs2, float vs1) -> float { return vs2 - vs1; });
  } else if (sew == 8) {
    RiscVBinaryVectorOp<double, double, double>(
        vector_state, inst,
        [](double vs2, double vs1) -> double { return vs2 - vs1; });
  } else {
    LOG(ERROR) << "Vfsub: Illegal sew (" << sew << ")";
    vector_state->set_vector_exception();
  }
}
// Vector reverse floating point subtract: vd[i] = vs1[i] - vs2[i], where the
// vs1 operand is the scalar rs1 in the .vf form.
// Supports SEW of 4 bytes (float) and 8 bytes (double). An invalid rounding
// mode or any other SEW is logged and flagged with set_vector_exception().
void Vfrsub(const Instruction *inst) {
  auto *state = static_cast<CheriotState *>(inst->state());
  auto *fp_state = state->rv_fp();
  auto *vector_state = state->rv_vector();
  if (!fp_state->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    vector_state->set_vector_exception();
    return;
  }
  const int sew = vector_state->selected_element_width();
  // Host FP status is held for the duration of the element-wise operation.
  ScopedFPStatus fp_status(fp_state->host_fp_interface());
  if (sew == 4) {
    RiscVBinaryVectorOp<float, float, float>(
        vector_state, inst,
        [](float vs2, float vs1) -> float { return vs1 - vs2; });
  } else if (sew == 8) {
    RiscVBinaryVectorOp<double, double, double>(
        vector_state, inst,
        [](double vs2, double vs1) -> double { return vs1 - vs2; });
  } else {
    LOG(ERROR) << "Vfrsub: Illegal sew (" << sew << ")";
    vector_state->set_vector_exception();
  }
}
// Widening vector floating point add: vd[i] = double(vs2[i]) + double(vs1[i]).
// Only SEW of 4 bytes is supported (float sources, double destination). An
// invalid rounding mode or any other SEW is logged and flagged with
// set_vector_exception().
void Vfwadd(const Instruction *inst) {
  auto *state = static_cast<CheriotState *>(inst->state());
  auto *fp_state = state->rv_fp();
  auto *vector_state = state->rv_vector();
  if (!fp_state->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    vector_state->set_vector_exception();
    return;
  }
  const int sew = vector_state->selected_element_width();
  // Host FP status is held for the duration of the element-wise operation.
  ScopedFPStatus fp_status(fp_state->host_fp_interface());
  if (sew != 4) {
    LOG(ERROR) << "Vfwadd: Illegal sew (" << sew << ")";
    vector_state->set_vector_exception();
    return;
  }
  RiscVBinaryVectorOp<double, float, float>(
      vector_state, inst, [](float vs2, float vs1) -> double {
        return static_cast<double>(vs2) + static_cast<double>(vs1);
      });
}
// Widening vector floating point subtract:
// vd[i] = double(vs2[i]) - double(vs1[i]).
// Only SEW of 4 bytes is supported (float sources, double destination). An
// invalid rounding mode or any other SEW is logged and flagged with
// set_vector_exception().
void Vfwsub(const Instruction *inst) {
  auto *state = static_cast<CheriotState *>(inst->state());
  auto *fp_state = state->rv_fp();
  auto *vector_state = state->rv_vector();
  if (!fp_state->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    vector_state->set_vector_exception();
    return;
  }
  const int sew = vector_state->selected_element_width();
  // Host FP status is held for the duration of the element-wise operation.
  ScopedFPStatus fp_status(fp_state->host_fp_interface());
  if (sew != 4) {
    LOG(ERROR) << "Vfwsub: Illegal sew (" << sew << ")";
    vector_state->set_vector_exception();
    return;
  }
  RiscVBinaryVectorOp<double, float, float>(
      vector_state, inst, [](float vs2, float vs1) -> double {
        return static_cast<double>(vs2) - static_cast<double>(vs1);
      });
}
// Widening vector floating point add with a wide vs2 operand:
// vd[i] = vs2[i] + double(vs1[i]), where vs2 and vd hold doubles and vs1
// holds floats. Only SEW of 4 bytes is supported. An invalid rounding mode or
// any other SEW is logged and flagged with set_vector_exception().
void Vfwaddw(const Instruction *inst) {
  auto *state = static_cast<CheriotState *>(inst->state());
  auto *fp_state = state->rv_fp();
  auto *vector_state = state->rv_vector();
  if (!fp_state->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    vector_state->set_vector_exception();
    return;
  }
  const int sew = vector_state->selected_element_width();
  // Host FP status is held for the duration of the element-wise operation.
  ScopedFPStatus fp_status(fp_state->host_fp_interface());
  if (sew != 4) {
    LOG(ERROR) << "Vfwaddw: Illegal sew (" << sew << ")";
    vector_state->set_vector_exception();
    return;
  }
  RiscVBinaryVectorOp<double, double, float>(
      vector_state, inst, [](double vs2_d, float vs1) -> double {
        return vs2_d + static_cast<double>(vs1);
      });
}
// Widening vector floating point subtract with a wide vs2 operand:
// vd[i] = vs2[i] - double(vs1[i]), where vs2 and vd hold doubles and vs1
// holds floats. Only SEW of 4 bytes is supported. An invalid rounding mode or
// any other SEW is logged and flagged with set_vector_exception().
void Vfwsubw(const Instruction *inst) {
  auto *state = static_cast<CheriotState *>(inst->state());
  auto *fp_state = state->rv_fp();
  auto *vector_state = state->rv_vector();
  if (!fp_state->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    vector_state->set_vector_exception();
    return;
  }
  const int sew = vector_state->selected_element_width();
  // Host FP status is held for the duration of the element-wise operation.
  ScopedFPStatus fp_status(fp_state->host_fp_interface());
  if (sew != 4) {
    LOG(ERROR) << "Vfwsubw: Illegal sew (" << sew << ")";
    vector_state->set_vector_exception();
    return;
  }
  RiscVBinaryVectorOp<double, double, float>(
      vector_state, inst, [](double vs2_d, float vs1) -> double {
        return vs2_d - static_cast<double>(vs1);
      });
}
// Vector floating point multiply: vd[i] = vs2[i] * vs1[i].
// Supports SEW of 4 bytes (float) and 8 bytes (double). An invalid rounding
// mode or any other SEW is logged and flagged with set_vector_exception().
void Vfmul(const Instruction *inst) {
  auto *state = static_cast<CheriotState *>(inst->state());
  auto *fp_state = state->rv_fp();
  auto *vector_state = state->rv_vector();
  if (!fp_state->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    vector_state->set_vector_exception();
    return;
  }
  const int sew = vector_state->selected_element_width();
  // Host FP status is held for the duration of the element-wise operation.
  ScopedFPStatus fp_status(fp_state->host_fp_interface());
  if (sew == 4) {
    RiscVBinaryVectorOp<float, float, float>(
        vector_state, inst,
        [](float vs2, float vs1) -> float { return vs2 * vs1; });
  } else if (sew == 8) {
    RiscVBinaryVectorOp<double, double, double>(
        vector_state, inst,
        [](double vs2, double vs1) -> double { return vs2 * vs1; });
  } else {
    LOG(ERROR) << "Vfmul: Illegal sew (" << sew << ")";
    vector_state->set_vector_exception();
  }
}
// Vector floating point divide: vd[i] = vs2[i] / vs1[i].
// Supports SEW of 4 bytes (float) and 8 bytes (double). An invalid rounding
// mode or any other SEW is logged and flagged with set_vector_exception().
void Vfdiv(const Instruction *inst) {
  auto *state = static_cast<CheriotState *>(inst->state());
  auto *fp_state = state->rv_fp();
  auto *vector_state = state->rv_vector();
  if (!fp_state->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    vector_state->set_vector_exception();
    return;
  }
  const int sew = vector_state->selected_element_width();
  // Host FP status is held for the duration of the element-wise operation.
  ScopedFPStatus fp_status(fp_state->host_fp_interface());
  if (sew == 4) {
    RiscVBinaryVectorOp<float, float, float>(
        vector_state, inst,
        [](float vs2, float vs1) -> float { return vs2 / vs1; });
  } else if (sew == 8) {
    RiscVBinaryVectorOp<double, double, double>(
        vector_state, inst,
        [](double vs2, double vs1) -> double { return vs2 / vs1; });
  } else {
    LOG(ERROR) << "Vfdiv: Illegal sew (" << sew << ")";
    vector_state->set_vector_exception();
  }
}
// Vector floating point reverse divide: vd[i] = vs1[i] / vs2[i].
// Supports SEW of 4 bytes (float) and 8 bytes (double). An invalid rounding
// mode or any other SEW is logged and flagged with set_vector_exception().
void Vfrdiv(const Instruction *inst) {
  auto *state = static_cast<CheriotState *>(inst->state());
  auto *fp_state = state->rv_fp();
  auto *vector_state = state->rv_vector();
  if (!fp_state->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    vector_state->set_vector_exception();
    return;
  }
  const int sew = vector_state->selected_element_width();
  // Host FP status is held for the duration of the element-wise operation.
  ScopedFPStatus fp_status(fp_state->host_fp_interface());
  if (sew == 4) {
    RiscVBinaryVectorOp<float, float, float>(
        vector_state, inst,
        [](float vs2, float vs1) -> float { return vs1 / vs2; });
  } else if (sew == 8) {
    RiscVBinaryVectorOp<double, double, double>(
        vector_state, inst,
        [](double vs2, double vs1) -> double { return vs1 / vs2; });
  } else {
    LOG(ERROR) << "Vfrdiv: Illegal sew (" << sew << ")";
    vector_state->set_vector_exception();
  }
}
// Widening floating point multiply.
void Vfwmul(const Instruction *inst) {
auto *rv_fp = static_cast<CheriotState *>(inst->state())->rv_fp();
auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
if (!rv_fp->rounding_mode_valid()) {
LOG(ERROR) << "Invalid rounding mode";
rv_vector->set_vector_exception();
return;
}
int sew = rv_vector->selected_element_width();
ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface());
switch (sew) {
case 4:
return RiscVBinaryVectorOp<double, float, float>(
rv_vector, inst, [](float vs2, float vs1) -> double {
double vs2_d = static_cast<double>(vs2);
double vs1_d = static_cast<double>(vs1);
return (vs2_d * vs1_d);
});
default:
LOG(ERROR) << "Vfwadd: Illegal sew (" << sew << ")";
rv_vector->set_vector_exception();
return;
}
}
// Vector fused multiply-add overwriting the multiplicand (vfmadd):
// vd[i] = (vs1[i] * vd[i]) + vs2[i], computed with a single rounding by
// std::fma. Supports SEW of 4 bytes (float) and 8 bytes (double). An invalid
// rounding mode or any other SEW is logged and flagged with
// set_vector_exception().
void Vfmadd(const Instruction *inst) {
  auto *state = static_cast<CheriotState *>(inst->state());
  auto *fp_state = state->rv_fp();
  auto *vector_state = state->rv_vector();
  if (!fp_state->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    vector_state->set_vector_exception();
    return;
  }
  const int sew = vector_state->selected_element_width();
  // Host FP status is held for the duration of the element-wise operation.
  ScopedFPStatus fp_status(fp_state->host_fp_interface());
  if (sew == 4) {
    RiscVTernaryVectorOp<float, float, float>(
        vector_state, inst, [](float vs2, float vs1, float vd) -> float {
          return std::fma(vs1, vd, vs2);
        });
  } else if (sew == 8) {
    RiscVTernaryVectorOp<double, double, double>(
        vector_state, inst, [](double vs2, double vs1, double vd) -> double {
          return std::fma(vs1, vd, vs2);
        });
  } else {
    LOG(ERROR) << "Vfmadd: Illegal sew (" << sew << ")";
    vector_state->set_vector_exception();
  }
}
// Vector negated fused multiply-add overwriting the multiplicand (vfnmadd):
// vd[i] = -(vs1[i] * vd[i]) - vs2[i], computed with a single rounding by
// std::fma. Supports SEW of 4 bytes (float) and 8 bytes (double). An invalid
// rounding mode or any other SEW is logged and flagged with
// set_vector_exception().
void Vfnmadd(const Instruction *inst) {
  auto *state = static_cast<CheriotState *>(inst->state());
  auto *fp_state = state->rv_fp();
  auto *vector_state = state->rv_vector();
  if (!fp_state->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    vector_state->set_vector_exception();
    return;
  }
  const int sew = vector_state->selected_element_width();
  // Host FP status is held for the duration of the element-wise operation.
  ScopedFPStatus fp_status(fp_state->host_fp_interface());
  if (sew == 4) {
    RiscVTernaryVectorOp<float, float, float>(
        vector_state, inst, [](float vs2, float vs1, float vd) -> float {
          return std::fma(-vs1, vd, -vs2);
        });
  } else if (sew == 8) {
    RiscVTernaryVectorOp<double, double, double>(
        vector_state, inst, [](double vs2, double vs1, double vd) -> double {
          return std::fma(-vs1, vd, -vs2);
        });
  } else {
    LOG(ERROR) << "Vfnmadd: Illegal sew (" << sew << ")";
    vector_state->set_vector_exception();
  }
}
// Vector fused multiply-subtract overwriting the multiplicand (vfmsub):
// vd[i] = (vs1[i] * vd[i]) - vs2[i], computed with a single rounding by
// std::fma. Supports SEW of 4 bytes (float) and 8 bytes (double). An invalid
// rounding mode or any other SEW is logged and flagged with
// set_vector_exception().
void Vfmsub(const Instruction *inst) {
  auto *state = static_cast<CheriotState *>(inst->state());
  auto *fp_state = state->rv_fp();
  auto *vector_state = state->rv_vector();
  if (!fp_state->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    vector_state->set_vector_exception();
    return;
  }
  const int sew = vector_state->selected_element_width();
  // Host FP status is held for the duration of the element-wise operation.
  ScopedFPStatus fp_status(fp_state->host_fp_interface());
  if (sew == 4) {
    RiscVTernaryVectorOp<float, float, float>(
        vector_state, inst, [](float vs2, float vs1, float vd) -> float {
          return std::fma(vs1, vd, -vs2);
        });
  } else if (sew == 8) {
    RiscVTernaryVectorOp<double, double, double>(
        vector_state, inst, [](double vs2, double vs1, double vd) -> double {
          return std::fma(vs1, vd, -vs2);
        });
  } else {
    LOG(ERROR) << "Vfmsub: Illegal sew (" << sew << ")";
    vector_state->set_vector_exception();
  }
}
// Vector negated fused multiply-subtract overwriting the multiplicand
// (vfnmsub): vd[i] = -(vs1[i] * vd[i]) + vs2[i], computed with a single
// rounding by std::fma. Supports SEW of 4 bytes (float) and 8 bytes (double).
// An invalid rounding mode or any other SEW is logged and flagged with
// set_vector_exception().
void Vfnmsub(const Instruction *inst) {
  auto *state = static_cast<CheriotState *>(inst->state());
  auto *fp_state = state->rv_fp();
  auto *vector_state = state->rv_vector();
  if (!fp_state->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    vector_state->set_vector_exception();
    return;
  }
  const int sew = vector_state->selected_element_width();
  // Host FP status is held for the duration of the element-wise operation.
  ScopedFPStatus fp_status(fp_state->host_fp_interface());
  if (sew == 4) {
    RiscVTernaryVectorOp<float, float, float>(
        vector_state, inst, [](float vs2, float vs1, float vd) -> float {
          return std::fma(-vs1, vd, vs2);
        });
  } else if (sew == 8) {
    RiscVTernaryVectorOp<double, double, double>(
        vector_state, inst, [](double vs2, double vs1, double vd) -> double {
          return std::fma(-vs1, vd, vs2);
        });
  } else {
    LOG(ERROR) << "Vfnmsub: Illegal sew (" << sew << ")";
    vector_state->set_vector_exception();
  }
}
// Vector fused multiply-accumulate into vd (vfmacc):
// vd[i] = (vs1[i] * vs2[i]) + vd[i], computed with a single rounding by
// std::fma. Supports SEW of 4 bytes (float) and 8 bytes (double). An invalid
// rounding mode or any other SEW is logged and flagged with
// set_vector_exception().
void Vfmacc(const Instruction *inst) {
  auto *state = static_cast<CheriotState *>(inst->state());
  auto *fp_state = state->rv_fp();
  auto *vector_state = state->rv_vector();
  if (!fp_state->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    vector_state->set_vector_exception();
    return;
  }
  const int sew = vector_state->selected_element_width();
  // Host FP status is held for the duration of the element-wise operation.
  ScopedFPStatus fp_status(fp_state->host_fp_interface());
  if (sew == 4) {
    RiscVTernaryVectorOp<float, float, float>(
        vector_state, inst, [](float vs2, float vs1, float vd) -> float {
          return std::fma(vs1, vs2, vd);
        });
  } else if (sew == 8) {
    RiscVTernaryVectorOp<double, double, double>(
        vector_state, inst, [](double vs2, double vs1, double vd) -> double {
          return std::fma(vs1, vs2, vd);
        });
  } else {
    LOG(ERROR) << "Vfmacc: Illegal sew (" << sew << ")";
    vector_state->set_vector_exception();
  }
}
// Vector negated fused multiply-accumulate into vd (vfnmacc):
// vd[i] = -(vs1[i] * vs2[i]) - vd[i], computed with a single rounding by
// std::fma. Supports SEW of 4 bytes (float) and 8 bytes (double). An invalid
// rounding mode or any other SEW is logged and flagged with
// set_vector_exception().
void Vfnmacc(const Instruction *inst) {
  auto *state = static_cast<CheriotState *>(inst->state());
  auto *fp_state = state->rv_fp();
  auto *vector_state = state->rv_vector();
  if (!fp_state->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    vector_state->set_vector_exception();
    return;
  }
  const int sew = vector_state->selected_element_width();
  // Host FP status is held for the duration of the element-wise operation.
  ScopedFPStatus fp_status(fp_state->host_fp_interface());
  if (sew == 4) {
    RiscVTernaryVectorOp<float, float, float>(
        vector_state, inst, [](float vs2, float vs1, float vd) -> float {
          return std::fma(-vs1, vs2, -vd);
        });
  } else if (sew == 8) {
    RiscVTernaryVectorOp<double, double, double>(
        vector_state, inst, [](double vs2, double vs1, double vd) -> double {
          return std::fma(-vs1, vs2, -vd);
        });
  } else {
    LOG(ERROR) << "Vfnmacc: Illegal sew (" << sew << ")";
    vector_state->set_vector_exception();
  }
}
// Vector fused multiply-subtract of vd (vfmsac):
// vd[i] = (vs1[i] * vs2[i]) - vd[i], computed with a single rounding by
// std::fma. Supports SEW of 4 bytes (float) and 8 bytes (double). An invalid
// rounding mode or any other SEW is logged and flagged with
// set_vector_exception().
void Vfmsac(const Instruction *inst) {
  auto *state = static_cast<CheriotState *>(inst->state());
  auto *fp_state = state->rv_fp();
  auto *vector_state = state->rv_vector();
  if (!fp_state->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    vector_state->set_vector_exception();
    return;
  }
  const int sew = vector_state->selected_element_width();
  // Host FP status is held for the duration of the element-wise operation.
  ScopedFPStatus fp_status(fp_state->host_fp_interface());
  if (sew == 4) {
    RiscVTernaryVectorOp<float, float, float>(
        vector_state, inst, [](float vs2, float vs1, float vd) -> float {
          return std::fma(vs1, vs2, -vd);
        });
  } else if (sew == 8) {
    RiscVTernaryVectorOp<double, double, double>(
        vector_state, inst, [](double vs2, double vs1, double vd) -> double {
          return std::fma(vs1, vs2, -vd);
        });
  } else {
    LOG(ERROR) << "Vfmsac: Illegal sew (" << sew << ")";
    vector_state->set_vector_exception();
  }
}
// Vector negated fused multiply-subtract of vd (vfnmsac):
// vd[i] = -(vs1[i] * vs2[i]) + vd[i], computed with a single rounding by
// std::fma. Supports SEW of 4 bytes (float) and 8 bytes (double). An invalid
// rounding mode or any other SEW is logged and flagged with
// set_vector_exception().
void Vfnmsac(const Instruction *inst) {
  auto *state = static_cast<CheriotState *>(inst->state());
  auto *fp_state = state->rv_fp();
  auto *vector_state = state->rv_vector();
  if (!fp_state->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    vector_state->set_vector_exception();
    return;
  }
  const int sew = vector_state->selected_element_width();
  // Host FP status is held for the duration of the element-wise operation.
  ScopedFPStatus fp_status(fp_state->host_fp_interface());
  if (sew == 4) {
    RiscVTernaryVectorOp<float, float, float>(
        vector_state, inst, [](float vs2, float vs1, float vd) -> float {
          return std::fma(-vs1, vs2, vd);
        });
  } else if (sew == 8) {
    RiscVTernaryVectorOp<double, double, double>(
        vector_state, inst, [](double vs2, double vs1, double vd) -> double {
          return std::fma(-vs1, vs2, vd);
        });
  } else {
    LOG(ERROR) << "Vfnmsac: Illegal sew (" << sew << ")";
    vector_state->set_vector_exception();
  }
}
// Widening vector multiply-accumulate (vfwmacc):
// vd[i] = (double(vs1[i]) * double(vs2[i])) + vd[i], with float sources and a
// double accumulator. Only SEW of 4 bytes is supported. An invalid rounding
// mode or any other SEW is logged and flagged with set_vector_exception().
void Vfwmacc(const Instruction *inst) {
  auto *state = static_cast<CheriotState *>(inst->state());
  auto *fp_state = state->rv_fp();
  auto *vector_state = state->rv_vector();
  if (!fp_state->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    vector_state->set_vector_exception();
    return;
  }
  const int sew = vector_state->selected_element_width();
  // Host FP status is held for the duration of the element-wise operation.
  ScopedFPStatus fp_status(fp_state->host_fp_interface());
  if (sew != 4) {
    LOG(ERROR) << "Vfwmacc: Illegal sew (" << sew << ")";
    vector_state->set_vector_exception();
    return;
  }
  RiscVTernaryVectorOp<double, float, float>(
      vector_state, inst, [](float vs2, float vs1, double vd) -> double {
        return (static_cast<double>(vs1) * static_cast<double>(vs2)) + vd;
      });
}
// Widening vector negated multiply-accumulate (vfwnmacc):
// vd[i] = -(double(vs1[i]) * double(vs2[i])) - vd[i], with float sources and
// a double accumulator. Only SEW of 4 bytes is supported. An invalid rounding
// mode or any other SEW is logged and flagged with set_vector_exception().
void Vfwnmacc(const Instruction *inst) {
  auto *state = static_cast<CheriotState *>(inst->state());
  auto *fp_state = state->rv_fp();
  auto *vector_state = state->rv_vector();
  if (!fp_state->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    vector_state->set_vector_exception();
    return;
  }
  const int sew = vector_state->selected_element_width();
  // Host FP status is held for the duration of the element-wise operation.
  ScopedFPStatus fp_status(fp_state->host_fp_interface());
  if (sew != 4) {
    LOG(ERROR) << "Vfwnmacc: Illegal sew (" << sew << ")";
    vector_state->set_vector_exception();
    return;
  }
  RiscVTernaryVectorOp<double, float, float>(
      vector_state, inst, [](float vs2, float vs1, double vd) -> double {
        return (-(static_cast<double>(vs1) * static_cast<double>(vs2))) - vd;
      });
}
// Widening vector multiply-subtract of vd (vfwmsac):
// vd[i] = (double(vs1[i]) * double(vs2[i])) - vd[i], with float sources and a
// double accumulator. Only SEW of 4 bytes is supported. An invalid rounding
// mode or any other SEW is logged and flagged with set_vector_exception().
void Vfwmsac(const Instruction *inst) {
  auto *state = static_cast<CheriotState *>(inst->state());
  auto *fp_state = state->rv_fp();
  auto *vector_state = state->rv_vector();
  if (!fp_state->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    vector_state->set_vector_exception();
    return;
  }
  const int sew = vector_state->selected_element_width();
  // Host FP status is held for the duration of the element-wise operation.
  ScopedFPStatus fp_status(fp_state->host_fp_interface());
  if (sew != 4) {
    LOG(ERROR) << "Vfwmsac: Illegal sew (" << sew << ")";
    vector_state->set_vector_exception();
    return;
  }
  RiscVTernaryVectorOp<double, float, float>(
      vector_state, inst, [](float vs2, float vs1, double vd) -> double {
        return (static_cast<double>(vs1) * static_cast<double>(vs2)) - vd;
      });
}
// Widening vector negated multiply-subtract of vd (vfwnmsac):
// vd[i] = -(double(vs1[i]) * double(vs2[i])) + vd[i], with float sources and
// a double accumulator. Only SEW of 4 bytes is supported. An invalid rounding
// mode or any other SEW is logged and flagged with set_vector_exception().
void Vfwnmsac(const Instruction *inst) {
  auto *state = static_cast<CheriotState *>(inst->state());
  auto *fp_state = state->rv_fp();
  auto *vector_state = state->rv_vector();
  if (!fp_state->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    vector_state->set_vector_exception();
    return;
  }
  const int sew = vector_state->selected_element_width();
  // Host FP status is held for the duration of the element-wise operation.
  ScopedFPStatus fp_status(fp_state->host_fp_interface());
  if (sew != 4) {
    LOG(ERROR) << "Vfwnmsac: Illegal sew (" << sew << ")";
    vector_state->set_vector_exception();
    return;
  }
  RiscVTernaryVectorOp<double, float, float>(
      vector_state, inst, [](float vs2, float vs1, double vd) -> double {
        return (-(static_cast<double>(vs1) * static_cast<double>(vs2))) + vd;
      });
}
// Sign injection (vfsgnj): the result is vs2 with its sign bit replaced by the
// sign bit of vs1. Operates on the raw integer bit patterns for SEW of 4 or 8
// bytes; it does not check the rounding mode or set up a ScopedFPStatus. Any
// other SEW is logged and flagged with set_vector_exception().
void Vfsgnj(const Instruction *inst) {
  auto *vector_state =
      static_cast<CheriotState *>(inst->state())->rv_vector();
  const int sew = vector_state->selected_element_width();
  if (sew == 4) {
    RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>(
        vector_state, inst, [](uint32_t vs2, uint32_t vs1) -> uint32_t {
          constexpr uint32_t kSignBit = uint32_t{1} << 31;
          return (vs2 & ~kSignBit) | (vs1 & kSignBit);
        });
  } else if (sew == 8) {
    RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>(
        vector_state, inst, [](uint64_t vs2, uint64_t vs1) -> uint64_t {
          constexpr uint64_t kSignBit = uint64_t{1} << 63;
          return (vs2 & ~kSignBit) | (vs1 & kSignBit);
        });
  } else {
    LOG(ERROR) << "Vfsgnj: Illegal sew (" << sew << ")";
    vector_state->set_vector_exception();
  }
}
// Negated sign injection (vfsgnjn): the result is vs2 with its sign bit
// replaced by the inverted sign bit of vs1. Operates on the raw integer bit
// patterns for SEW of 4 or 8 bytes; it does not check the rounding mode or set
// up a ScopedFPStatus. Any other SEW is logged and flagged with
// set_vector_exception().
void Vfsgnjn(const Instruction *inst) {
  auto *vector_state =
      static_cast<CheriotState *>(inst->state())->rv_vector();
  const int sew = vector_state->selected_element_width();
  if (sew == 4) {
    RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>(
        vector_state, inst, [](uint32_t vs2, uint32_t vs1) -> uint32_t {
          constexpr uint32_t kSignBit = uint32_t{1} << 31;
          return (vs2 & ~kSignBit) | (~vs1 & kSignBit);
        });
  } else if (sew == 8) {
    RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>(
        vector_state, inst, [](uint64_t vs2, uint64_t vs1) -> uint64_t {
          constexpr uint64_t kSignBit = uint64_t{1} << 63;
          return (vs2 & ~kSignBit) | (~vs1 & kSignBit);
        });
  } else {
    LOG(ERROR) << "Vfsgnjn: Illegal sew (" << sew << ")";
    vector_state->set_vector_exception();
  }
}
// XOR sign injection (vfsgnjx): the result is vs2 with its sign bit replaced
// by sign(vs2) XOR sign(vs1). Operates on the raw integer bit patterns for SEW
// of 4 or 8 bytes; it does not check the rounding mode or set up a
// ScopedFPStatus. Any other SEW is logged and flagged with
// set_vector_exception().
void Vfsgnjx(const Instruction *inst) {
  auto *vector_state =
      static_cast<CheriotState *>(inst->state())->rv_vector();
  const int sew = vector_state->selected_element_width();
  if (sew == 4) {
    RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>(
        vector_state, inst, [](uint32_t vs2, uint32_t vs1) -> uint32_t {
          constexpr uint32_t kSignBit = uint32_t{1} << 31;
          return (vs2 & ~kSignBit) | ((vs2 ^ vs1) & kSignBit);
        });
  } else if (sew == 8) {
    // The magnitude part has a clear sign bit, so combining with '|' is
    // equivalent to the '^' used previously for this width.
    RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>(
        vector_state, inst, [](uint64_t vs2, uint64_t vs1) -> uint64_t {
          constexpr uint64_t kSignBit = uint64_t{1} << 63;
          return (vs2 & ~kSignBit) | ((vs2 ^ vs1) & kSignBit);
        });
  } else {
    LOG(ERROR) << "Vfsgnjx: Illegal sew (" << sew << ")";
    vector_state->set_vector_exception();
  }
}
// Shared element semantics for the vfmin/vfmax instructions.
//
//   vs2, vs1:  the two element operands.
//   operation: picks one of two non-NaN operands (the min or the max). Any
//              callable usable as T(T, T) is accepted. Taking it as a template
//              parameter instead of std::function avoids a type-erased call
//              per element and lets the compiler inline it.
//
// Returns a tuple of (result, exception flags):
//   * kInvalidOp is reported if either operand is a signaling NaN.
//   * If both operands are NaN, the canonical NaN is returned.
//   * If exactly one operand is NaN, the other operand is returned.
//   * For two zeros, -0 is treated as less than +0.
template <typename T, typename Operation>
inline std::tuple<T, uint32_t> MaxMinHelper(T vs2, T vs1,
                                            Operation operation) {
  uint32_t flag = 0;
  if (FPTypeInfo<T>::IsSNaN(vs1) || FPTypeInfo<T>::IsSNaN(vs2)) {
    flag = static_cast<uint32_t>(FPExceptions::kInvalidOp);
  }
  if (FPTypeInfo<T>::IsNaN(vs2) && FPTypeInfo<T>::IsNaN(vs1)) {
    // Copy the canonical NaN bit pattern into a T. Dereferencing a
    // reinterpret_cast'ed pointer here would violate strict aliasing.
    const auto c_nan = FPTypeInfo<T>::kCanonicalNaN;
    static_assert(sizeof(c_nan) == sizeof(T),
                  "canonical NaN bit pattern must match the FP type size");
    T canonical_nan;
    std::memcpy(&canonical_nan, &c_nan, sizeof(T));
    return std::make_tuple(canonical_nan, flag);
  }
  // If exactly one operand is a NaN, return the other.
  if (FPTypeInfo<T>::IsNaN(vs2)) return std::make_tuple(vs1, flag);
  if (FPTypeInfo<T>::IsNaN(vs1)) return std::make_tuple(vs2, flag);
  if ((vs2 == 0.0) && (vs1 == 0.0)) {
    // +0 and -0 compare equal, so compare sign-carrying stand-ins to decide
    // which zero the operation selects.
    const T tmp2 = std::signbit(vs2) ? static_cast<T>(-1) : static_cast<T>(1);
    const T tmp1 = std::signbit(vs1) ? static_cast<T>(-1) : static_cast<T>(1);
    return std::make_tuple(operation(tmp2, tmp1) == tmp2 ? vs2 : vs1, flag);
  }
  return std::make_tuple(operation(vs2, vs1), flag);
}
// Vector floating point minimum (vfmin): vd[i] = min(vs2[i], vs1[i]), with the
// NaN and signed-zero handling implemented by MaxMinHelper. The helper's
// exception flags are handed to RiscVBinaryVectorOpWithFflags (presumably
// accrued into fflags there). Supports SEW of 4 or 8 bytes; any other SEW is
// logged and flagged with set_vector_exception().
void Vfmin(const Instruction *inst) {
  auto *vector_state =
      static_cast<CheriotState *>(inst->state())->rv_vector();
  const int sew = vector_state->selected_element_width();
  if (sew == 4) {
    RiscVBinaryVectorOpWithFflags<float, float, float>(
        vector_state, inst,
        [](float vs2, float vs1) -> std::tuple<float, uint32_t> {
          return MaxMinHelper<float>(
              vs2, vs1, [](float lhs, float rhs) -> float {
                return (rhs < lhs) ? rhs : lhs;
              });
        });
  } else if (sew == 8) {
    RiscVBinaryVectorOpWithFflags<double, double, double>(
        vector_state, inst,
        [](double vs2, double vs1) -> std::tuple<double, uint32_t> {
          return MaxMinHelper<double>(
              vs2, vs1, [](double lhs, double rhs) -> double {
                return (rhs < lhs) ? rhs : lhs;
              });
        });
  } else {
    LOG(ERROR) << "Vfmin: Illegal sew (" << sew << ")";
    vector_state->set_vector_exception();
  }
}
// Vector floating point maximum (vfmax): vd[i] = max(vs2[i], vs1[i]), with the
// NaN and signed-zero handling implemented by MaxMinHelper. The helper's
// exception flags are handed to RiscVBinaryVectorOpWithFflags (presumably
// accrued into fflags there). Supports SEW of 4 or 8 bytes; any other SEW is
// logged and flagged with set_vector_exception().
void Vfmax(const Instruction *inst) {
  auto *vector_state =
      static_cast<CheriotState *>(inst->state())->rv_vector();
  const int sew = vector_state->selected_element_width();
  if (sew == 4) {
    RiscVBinaryVectorOpWithFflags<float, float, float>(
        vector_state, inst,
        [](float vs2, float vs1) -> std::tuple<float, uint32_t> {
          return MaxMinHelper<float>(
              vs2, vs1, [](float lhs, float rhs) -> float {
                return (rhs > lhs) ? rhs : lhs;
              });
        });
  } else if (sew == 8) {
    RiscVBinaryVectorOpWithFflags<double, double, double>(
        vector_state, inst,
        [](double vs2, double vs1) -> std::tuple<double, uint32_t> {
          return MaxMinHelper<double>(
              vs2, vs1, [](double lhs, double rhs) -> double {
                return (rhs > lhs) ? rhs : lhs;
              });
        });
  } else {
    LOG(ERROR) << "Vfmax: Illegal sew (" << sew << ")";
    vector_state->set_vector_exception();
  }
}
// Vector floating point merge (vfmerge): vd[i] = mask[i] ? vs1[i] : vs2[i].
// Elements are moved as raw bit patterns, so no FP rounding or flags are
// involved. Supports SEW of 4 or 8 bytes; any other SEW is logged and flagged
// with set_vector_exception().
void Vfmerge(const Instruction *inst) {
  auto *vector_state =
      static_cast<CheriotState *>(inst->state())->rv_vector();
  const int sew = vector_state->selected_element_width();
  if (sew == 4) {
    RiscVMaskBinaryVectorOp<uint32_t, uint32_t, uint32_t>(
        vector_state, inst,
        [](uint32_t vs2, uint32_t vs1, bool mask) -> uint32_t {
          if (mask) return vs1;
          return vs2;
        });
  } else if (sew == 8) {
    RiscVMaskBinaryVectorOp<uint64_t, uint64_t, uint64_t>(
        vector_state, inst,
        [](uint64_t vs2, uint64_t vs1, bool mask) -> uint64_t {
          if (mask) return vs1;
          return vs2;
        });
  } else {
    LOG(ERROR) << "Vfmerge: Illegal sew (" << sew << ")";
    vector_state->set_vector_exception();
  }
}
} // namespace cheriot
} // namespace sim
} // namespace mpact