// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cheriot/riscv_cheriot_vector_fp_instructions.h"

#include <cmath>
#include <cstdint>
#include <cstring>
#include <functional>
#include <tuple>

#include "absl/log/log.h"
#include "cheriot/cheriot_state.h"
#include "cheriot/riscv_cheriot_vector_instruction_helpers.h"
#include "mpact/sim/generic/type_helpers.h"
#include "riscv//riscv_fp_host.h"
#include "riscv//riscv_fp_info.h"
#include "riscv//riscv_fp_state.h"

namespace mpact {
namespace sim {
namespace cheriot {

using ::mpact::sim::generic::FPTypeInfo;
using ::mpact::sim::riscv::FPExceptions;
using ::mpact::sim::riscv::ScopedFPStatus;

// Vector floating point add: each result element is vs2 + vs1.
// Signals a vector exception when the rounding mode is invalid or when the
// selected element width is neither 4 (float) nor 8 (double) bytes.
void Vfadd(const Instruction* inst) {
  auto* cheriot_state = static_cast<CheriotState*>(inst->state());
  auto* rv_fp = cheriot_state->rv_fp();
  auto* rv_vector = cheriot_state->rv_vector();
  if (!rv_fp->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    rv_vector->set_vector_exception();
    return;
  }
  const int sew = rv_vector->selected_element_width();
  // Held for the rest of the function so it wraps the element computation.
  ScopedFPStatus fp_status_scope(rv_fp->host_fp_interface());
  if (sew == 4) {
    RiscVBinaryVectorOp<float, float, float>(
        rv_vector, inst,
        [](float src2, float src1) -> float { return src2 + src1; });
  } else if (sew == 8) {
    RiscVBinaryVectorOp<double, double, double>(
        rv_vector, inst,
        [](double src2, double src1) -> double { return src2 + src1; });
  } else {
    LOG(ERROR) << "Vfadd: Illegal sew (" << sew << ")";
    rv_vector->set_vector_exception();
  }
}

// Vector floating point subtract: each result element is vs2 - vs1.
// Signals a vector exception when the rounding mode is invalid or when the
// selected element width is neither 4 (float) nor 8 (double) bytes.
void Vfsub(const Instruction* inst) {
  auto* cheriot_state = static_cast<CheriotState*>(inst->state());
  auto* rv_fp = cheriot_state->rv_fp();
  auto* rv_vector = cheriot_state->rv_vector();
  if (!rv_fp->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    rv_vector->set_vector_exception();
    return;
  }
  const int sew = rv_vector->selected_element_width();
  // Held for the rest of the function so it wraps the element computation.
  ScopedFPStatus fp_status_scope(rv_fp->host_fp_interface());
  if (sew == 4) {
    RiscVBinaryVectorOp<float, float, float>(
        rv_vector, inst,
        [](float src2, float src1) -> float { return src2 - src1; });
  } else if (sew == 8) {
    RiscVBinaryVectorOp<double, double, double>(
        rv_vector, inst,
        [](double src2, double src1) -> double { return src2 - src1; });
  } else {
    LOG(ERROR) << "Vfsub: Illegal sew (" << sew << ")";
    rv_vector->set_vector_exception();
  }
}

// Vector reverse floating point subtract: each result element is vs1 - vs2
// (rs1 - vs2 for the scalar form), i.e. Vfsub with the operands swapped.
// Signals a vector exception when the rounding mode is invalid or when the
// selected element width is neither 4 (float) nor 8 (double) bytes.
void Vfrsub(const Instruction* inst) {
  auto* cheriot_state = static_cast<CheriotState*>(inst->state());
  auto* rv_fp = cheriot_state->rv_fp();
  auto* rv_vector = cheriot_state->rv_vector();
  if (!rv_fp->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    rv_vector->set_vector_exception();
    return;
  }
  const int sew = rv_vector->selected_element_width();
  // Held for the rest of the function so it wraps the element computation.
  ScopedFPStatus fp_status_scope(rv_fp->host_fp_interface());
  if (sew == 4) {
    RiscVBinaryVectorOp<float, float, float>(
        rv_vector, inst,
        [](float src2, float src1) -> float { return src1 - src2; });
  } else if (sew == 8) {
    RiscVBinaryVectorOp<double, double, double>(
        rv_vector, inst,
        [](double src2, double src1) -> double { return src1 - src2; });
  } else {
    LOG(ERROR) << "Vfrsub: Illegal sew (" << sew << ")";
    rv_vector->set_vector_exception();
  }
}

// Widening vector floating point add: float operands vs2 and vs1 are
// converted to double and summed, producing double results.
// Only a selected element width of 4 is accepted; anything else, or an
// invalid rounding mode, signals a vector exception.
void Vfwadd(const Instruction* inst) {
  auto* cheriot_state = static_cast<CheriotState*>(inst->state());
  auto* rv_fp = cheriot_state->rv_fp();
  auto* rv_vector = cheriot_state->rv_vector();
  if (!rv_fp->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    rv_vector->set_vector_exception();
    return;
  }
  const int sew = rv_vector->selected_element_width();
  // Held for the rest of the function so it wraps the element computation.
  ScopedFPStatus fp_status_scope(rv_fp->host_fp_interface());
  if (sew != 4) {
    LOG(ERROR) << "Vfwadd: Illegal sew (" << sew << ")";
    rv_vector->set_vector_exception();
    return;
  }
  RiscVBinaryVectorOp<double, float, float>(
      rv_vector, inst, [](float src2, float src1) -> double {
        return static_cast<double>(src2) + static_cast<double>(src1);
      });
}

// Widening vector floating point subtract: float operands vs2 and vs1 are
// converted to double and the result vs2 - vs1 is produced as a double.
// Only a selected element width of 4 is accepted; anything else, or an
// invalid rounding mode, signals a vector exception.
void Vfwsub(const Instruction* inst) {
  auto* cheriot_state = static_cast<CheriotState*>(inst->state());
  auto* rv_fp = cheriot_state->rv_fp();
  auto* rv_vector = cheriot_state->rv_vector();
  if (!rv_fp->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    rv_vector->set_vector_exception();
    return;
  }
  const int sew = rv_vector->selected_element_width();
  // Held for the rest of the function so it wraps the element computation.
  ScopedFPStatus fp_status_scope(rv_fp->host_fp_interface());
  if (sew != 4) {
    LOG(ERROR) << "Vfwsub: Illegal sew (" << sew << ")";
    rv_vector->set_vector_exception();
    return;
  }
  RiscVBinaryVectorOp<double, float, float>(
      rv_vector, inst, [](float src2, float src1) -> double {
        return static_cast<double>(src2) - static_cast<double>(src1);
      });
}

// Widening vector floating point add with an already-wide vs2: vs2 elements
// are doubles, vs1 elements are floats converted to double, result is double.
// Only a selected element width of 4 is accepted; anything else, or an
// invalid rounding mode, signals a vector exception.
void Vfwaddw(const Instruction* inst) {
  auto* cheriot_state = static_cast<CheriotState*>(inst->state());
  auto* rv_fp = cheriot_state->rv_fp();
  auto* rv_vector = cheriot_state->rv_vector();
  if (!rv_fp->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    rv_vector->set_vector_exception();
    return;
  }
  const int sew = rv_vector->selected_element_width();
  // Held for the rest of the function so it wraps the element computation.
  ScopedFPStatus fp_status_scope(rv_fp->host_fp_interface());
  if (sew != 4) {
    LOG(ERROR) << "Vfwaddw: Illegal sew (" << sew << ")";
    rv_vector->set_vector_exception();
    return;
  }
  RiscVBinaryVectorOp<double, double, float>(
      rv_vector, inst, [](double wide_src2, float src1) -> double {
        return wide_src2 + static_cast<double>(src1);
      });
}

// Widening vector floating point subtract with an already-wide vs2: vs2
// elements are doubles, vs1 elements are floats converted to double, and the
// double result is vs2 - vs1.
// Only a selected element width of 4 is accepted; anything else, or an
// invalid rounding mode, signals a vector exception.
void Vfwsubw(const Instruction* inst) {
  auto* cheriot_state = static_cast<CheriotState*>(inst->state());
  auto* rv_fp = cheriot_state->rv_fp();
  auto* rv_vector = cheriot_state->rv_vector();
  if (!rv_fp->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    rv_vector->set_vector_exception();
    return;
  }
  const int sew = rv_vector->selected_element_width();
  // Held for the rest of the function so it wraps the element computation.
  ScopedFPStatus fp_status_scope(rv_fp->host_fp_interface());
  if (sew != 4) {
    LOG(ERROR) << "Vfwsubw: Illegal sew (" << sew << ")";
    rv_vector->set_vector_exception();
    return;
  }
  RiscVBinaryVectorOp<double, double, float>(
      rv_vector, inst, [](double wide_src2, float src1) -> double {
        return wide_src2 - static_cast<double>(src1);
      });
}

// Vector floating point multiply: each result element is vs2 * vs1.
// Signals a vector exception when the rounding mode is invalid or when the
// selected element width is neither 4 (float) nor 8 (double) bytes.
void Vfmul(const Instruction* inst) {
  auto* cheriot_state = static_cast<CheriotState*>(inst->state());
  auto* rv_fp = cheriot_state->rv_fp();
  auto* rv_vector = cheriot_state->rv_vector();
  if (!rv_fp->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    rv_vector->set_vector_exception();
    return;
  }
  const int sew = rv_vector->selected_element_width();
  // Held for the rest of the function so it wraps the element computation.
  ScopedFPStatus fp_status_scope(rv_fp->host_fp_interface());
  if (sew == 4) {
    RiscVBinaryVectorOp<float, float, float>(
        rv_vector, inst,
        [](float src2, float src1) -> float { return src2 * src1; });
  } else if (sew == 8) {
    RiscVBinaryVectorOp<double, double, double>(
        rv_vector, inst,
        [](double src2, double src1) -> double { return src2 * src1; });
  } else {
    LOG(ERROR) << "Vfmul: Illegal sew (" << sew << ")";
    rv_vector->set_vector_exception();
  }
}

// Vector floating point divide: each result element is vs2 / vs1.
// Signals a vector exception when the rounding mode is invalid or when the
// selected element width is neither 4 (float) nor 8 (double) bytes.
void Vfdiv(const Instruction* inst) {
  auto* cheriot_state = static_cast<CheriotState*>(inst->state());
  auto* rv_fp = cheriot_state->rv_fp();
  auto* rv_vector = cheriot_state->rv_vector();
  if (!rv_fp->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    rv_vector->set_vector_exception();
    return;
  }
  const int sew = rv_vector->selected_element_width();
  // Held for the rest of the function so it wraps the element computation.
  ScopedFPStatus fp_status_scope(rv_fp->host_fp_interface());
  if (sew == 4) {
    RiscVBinaryVectorOp<float, float, float>(
        rv_vector, inst,
        [](float src2, float src1) -> float { return src2 / src1; });
  } else if (sew == 8) {
    RiscVBinaryVectorOp<double, double, double>(
        rv_vector, inst,
        [](double src2, double src1) -> double { return src2 / src1; });
  } else {
    LOG(ERROR) << "Vfdiv: Illegal sew (" << sew << ")";
    rv_vector->set_vector_exception();
  }
}

// Vector reverse floating point divide: each result element is vs1 / vs2,
// i.e. Vfdiv with the operands swapped.
// Signals a vector exception when the rounding mode is invalid or when the
// selected element width is neither 4 (float) nor 8 (double) bytes.
void Vfrdiv(const Instruction* inst) {
  auto* cheriot_state = static_cast<CheriotState*>(inst->state());
  auto* rv_fp = cheriot_state->rv_fp();
  auto* rv_vector = cheriot_state->rv_vector();
  if (!rv_fp->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    rv_vector->set_vector_exception();
    return;
  }
  const int sew = rv_vector->selected_element_width();
  // Held for the rest of the function so it wraps the element computation.
  ScopedFPStatus fp_status_scope(rv_fp->host_fp_interface());
  if (sew == 4) {
    RiscVBinaryVectorOp<float, float, float>(
        rv_vector, inst,
        [](float src2, float src1) -> float { return src1 / src2; });
  } else if (sew == 8) {
    RiscVBinaryVectorOp<double, double, double>(
        rv_vector, inst,
        [](double src2, double src1) -> double { return src1 / src2; });
  } else {
    LOG(ERROR) << "Vfrdiv: Illegal sew (" << sew << ")";
    rv_vector->set_vector_exception();
  }
}

// Widening floating point multiply.
void Vfwmul(const Instruction* inst) {
  auto* rv_fp = static_cast<CheriotState*>(inst->state())->rv_fp();
  auto* rv_vector = static_cast<CheriotState*>(inst->state())->rv_vector();
  if (!rv_fp->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    rv_vector->set_vector_exception();
    return;
  }
  int sew = rv_vector->selected_element_width();
  ScopedFPStatus set_fpstatus(rv_fp->host_fp_interface());
  switch (sew) {
    case 4:
      return RiscVBinaryVectorOp<double, float, float>(
          rv_vector, inst, [](float vs2, float vs1) -> double {
            double vs2_d = static_cast<double>(vs2);
            double vs1_d = static_cast<double>(vs1);
            return (vs2_d * vs1_d);
          });
    default:
      LOG(ERROR) << "Vfwadd: Illegal sew (" << sew << ")";
      rv_vector->set_vector_exception();
      return;
  }
}

// Vector fused multiply-add overwriting the multiplicand:
// result = fma(vs1, vd, vs2), i.e. (vs1 * vd) + vs2 with a single rounding.
// Signals a vector exception when the rounding mode is invalid or when the
// selected element width is neither 4 (float) nor 8 (double) bytes.
void Vfmadd(const Instruction* inst) {
  auto* cheriot_state = static_cast<CheriotState*>(inst->state());
  auto* rv_fp = cheriot_state->rv_fp();
  auto* rv_vector = cheriot_state->rv_vector();
  if (!rv_fp->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    rv_vector->set_vector_exception();
    return;
  }
  const int sew = rv_vector->selected_element_width();
  // Held for the rest of the function so it wraps the element computation.
  ScopedFPStatus fp_status_scope(rv_fp->host_fp_interface());
  if (sew == 4) {
    RiscVTernaryVectorOp<float, float, float>(
        rv_vector, inst, [](float src2, float src1, float dst) -> float {
          return std::fma(src1, dst, src2);
        });
  } else if (sew == 8) {
    RiscVTernaryVectorOp<double, double, double>(
        rv_vector, inst, [](double src2, double src1, double dst) -> double {
          return std::fma(src1, dst, src2);
        });
  } else {
    LOG(ERROR) << "Vfmadd: Illegal sew (" << sew << ")";
    rv_vector->set_vector_exception();
  }
}

// Vector negated fused multiply-add overwriting the multiplicand:
// result = fma(-vs1, vd, -vs2), i.e. -(vs1 * vd) - vs2 with a single rounding.
// Signals a vector exception when the rounding mode is invalid or when the
// selected element width is neither 4 (float) nor 8 (double) bytes.
void Vfnmadd(const Instruction* inst) {
  auto* cheriot_state = static_cast<CheriotState*>(inst->state());
  auto* rv_fp = cheriot_state->rv_fp();
  auto* rv_vector = cheriot_state->rv_vector();
  if (!rv_fp->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    rv_vector->set_vector_exception();
    return;
  }
  const int sew = rv_vector->selected_element_width();
  // Held for the rest of the function so it wraps the element computation.
  ScopedFPStatus fp_status_scope(rv_fp->host_fp_interface());
  if (sew == 4) {
    RiscVTernaryVectorOp<float, float, float>(
        rv_vector, inst, [](float src2, float src1, float dst) -> float {
          return std::fma(-src1, dst, -src2);
        });
  } else if (sew == 8) {
    RiscVTernaryVectorOp<double, double, double>(
        rv_vector, inst, [](double src2, double src1, double dst) -> double {
          return std::fma(-src1, dst, -src2);
        });
  } else {
    LOG(ERROR) << "Vfnmadd: Illegal sew (" << sew << ")";
    rv_vector->set_vector_exception();
  }
}

// Vector fused multiply-subtract overwriting the multiplicand:
// result = fma(vs1, vd, -vs2), i.e. (vs1 * vd) - vs2 with a single rounding.
// Signals a vector exception when the rounding mode is invalid or when the
// selected element width is neither 4 (float) nor 8 (double) bytes.
void Vfmsub(const Instruction* inst) {
  auto* cheriot_state = static_cast<CheriotState*>(inst->state());
  auto* rv_fp = cheriot_state->rv_fp();
  auto* rv_vector = cheriot_state->rv_vector();
  if (!rv_fp->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    rv_vector->set_vector_exception();
    return;
  }
  const int sew = rv_vector->selected_element_width();
  // Held for the rest of the function so it wraps the element computation.
  ScopedFPStatus fp_status_scope(rv_fp->host_fp_interface());
  if (sew == 4) {
    RiscVTernaryVectorOp<float, float, float>(
        rv_vector, inst, [](float src2, float src1, float dst) -> float {
          return std::fma(src1, dst, -src2);
        });
  } else if (sew == 8) {
    RiscVTernaryVectorOp<double, double, double>(
        rv_vector, inst, [](double src2, double src1, double dst) -> double {
          return std::fma(src1, dst, -src2);
        });
  } else {
    LOG(ERROR) << "Vfmsub: Illegal sew (" << sew << ")";
    rv_vector->set_vector_exception();
  }
}

// Vector negated fused multiply-subtract overwriting the multiplicand:
// result = fma(-vs1, vd, vs2), i.e. -(vs1 * vd) + vs2 with a single rounding.
// Signals a vector exception when the rounding mode is invalid or when the
// selected element width is neither 4 (float) nor 8 (double) bytes.
void Vfnmsub(const Instruction* inst) {
  auto* cheriot_state = static_cast<CheriotState*>(inst->state());
  auto* rv_fp = cheriot_state->rv_fp();
  auto* rv_vector = cheriot_state->rv_vector();
  if (!rv_fp->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    rv_vector->set_vector_exception();
    return;
  }
  const int sew = rv_vector->selected_element_width();
  // Held for the rest of the function so it wraps the element computation.
  ScopedFPStatus fp_status_scope(rv_fp->host_fp_interface());
  if (sew == 4) {
    RiscVTernaryVectorOp<float, float, float>(
        rv_vector, inst, [](float src2, float src1, float dst) -> float {
          return std::fma(-src1, dst, src2);
        });
  } else if (sew == 8) {
    RiscVTernaryVectorOp<double, double, double>(
        rv_vector, inst, [](double src2, double src1, double dst) -> double {
          return std::fma(-src1, dst, src2);
        });
  } else {
    LOG(ERROR) << "Vfnmsub: Illegal sew (" << sew << ")";
    rv_vector->set_vector_exception();
  }
}

// Vector fused multiply-accumulate into vd:
// result = fma(vs1, vs2, vd), i.e. (vs1 * vs2) + vd with a single rounding.
// Signals a vector exception when the rounding mode is invalid or when the
// selected element width is neither 4 (float) nor 8 (double) bytes.
void Vfmacc(const Instruction* inst) {
  auto* cheriot_state = static_cast<CheriotState*>(inst->state());
  auto* rv_fp = cheriot_state->rv_fp();
  auto* rv_vector = cheriot_state->rv_vector();
  if (!rv_fp->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    rv_vector->set_vector_exception();
    return;
  }
  const int sew = rv_vector->selected_element_width();
  // Held for the rest of the function so it wraps the element computation.
  ScopedFPStatus fp_status_scope(rv_fp->host_fp_interface());
  if (sew == 4) {
    RiscVTernaryVectorOp<float, float, float>(
        rv_vector, inst, [](float src2, float src1, float acc) -> float {
          return std::fma(src1, src2, acc);
        });
  } else if (sew == 8) {
    RiscVTernaryVectorOp<double, double, double>(
        rv_vector, inst, [](double src2, double src1, double acc) -> double {
          return std::fma(src1, src2, acc);
        });
  } else {
    LOG(ERROR) << "Vfmacc: Illegal sew (" << sew << ")";
    rv_vector->set_vector_exception();
  }
}

// Vector negated fused multiply-accumulate into vd:
// result = fma(-vs1, vs2, -vd), i.e. -(vs1 * vs2) - vd with a single rounding.
// Signals a vector exception when the rounding mode is invalid or when the
// selected element width is neither 4 (float) nor 8 (double) bytes.
void Vfnmacc(const Instruction* inst) {
  auto* cheriot_state = static_cast<CheriotState*>(inst->state());
  auto* rv_fp = cheriot_state->rv_fp();
  auto* rv_vector = cheriot_state->rv_vector();
  if (!rv_fp->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    rv_vector->set_vector_exception();
    return;
  }
  const int sew = rv_vector->selected_element_width();
  // Held for the rest of the function so it wraps the element computation.
  ScopedFPStatus fp_status_scope(rv_fp->host_fp_interface());
  if (sew == 4) {
    RiscVTernaryVectorOp<float, float, float>(
        rv_vector, inst, [](float src2, float src1, float acc) -> float {
          return std::fma(-src1, src2, -acc);
        });
  } else if (sew == 8) {
    RiscVTernaryVectorOp<double, double, double>(
        rv_vector, inst, [](double src2, double src1, double acc) -> double {
          return std::fma(-src1, src2, -acc);
        });
  } else {
    LOG(ERROR) << "Vfnmacc: Illegal sew (" << sew << ")";
    rv_vector->set_vector_exception();
  }
}

// Vector fused multiply-subtract-accumulator:
// result = fma(vs1, vs2, -vd), i.e. (vs1 * vs2) - vd with a single rounding.
// Signals a vector exception when the rounding mode is invalid or when the
// selected element width is neither 4 (float) nor 8 (double) bytes.
void Vfmsac(const Instruction* inst) {
  auto* cheriot_state = static_cast<CheriotState*>(inst->state());
  auto* rv_fp = cheriot_state->rv_fp();
  auto* rv_vector = cheriot_state->rv_vector();
  if (!rv_fp->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    rv_vector->set_vector_exception();
    return;
  }
  const int sew = rv_vector->selected_element_width();
  // Held for the rest of the function so it wraps the element computation.
  ScopedFPStatus fp_status_scope(rv_fp->host_fp_interface());
  if (sew == 4) {
    RiscVTernaryVectorOp<float, float, float>(
        rv_vector, inst, [](float src2, float src1, float acc) -> float {
          return std::fma(src1, src2, -acc);
        });
  } else if (sew == 8) {
    RiscVTernaryVectorOp<double, double, double>(
        rv_vector, inst, [](double src2, double src1, double acc) -> double {
          return std::fma(src1, src2, -acc);
        });
  } else {
    LOG(ERROR) << "Vfmsac: Illegal sew (" << sew << ")";
    rv_vector->set_vector_exception();
  }
}

// Vector negated fused multiply-subtract-accumulator:
// result = fma(-vs1, vs2, vd), i.e. -(vs1 * vs2) + vd with a single rounding.
// Signals a vector exception when the rounding mode is invalid or when the
// selected element width is neither 4 (float) nor 8 (double) bytes.
void Vfnmsac(const Instruction* inst) {
  auto* cheriot_state = static_cast<CheriotState*>(inst->state());
  auto* rv_fp = cheriot_state->rv_fp();
  auto* rv_vector = cheriot_state->rv_vector();
  if (!rv_fp->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    rv_vector->set_vector_exception();
    return;
  }
  const int sew = rv_vector->selected_element_width();
  // Held for the rest of the function so it wraps the element computation.
  ScopedFPStatus fp_status_scope(rv_fp->host_fp_interface());
  if (sew == 4) {
    RiscVTernaryVectorOp<float, float, float>(
        rv_vector, inst, [](float src2, float src1, float acc) -> float {
          return std::fma(-src1, src2, acc);
        });
  } else if (sew == 8) {
    RiscVTernaryVectorOp<double, double, double>(
        rv_vector, inst, [](double src2, double src1, double acc) -> double {
          return std::fma(-src1, src2, acc);
        });
  } else {
    LOG(ERROR) << "Vfnmsac: Illegal sew (" << sew << ")";
    rv_vector->set_vector_exception();
  }
}

// Widening vector multiply-accumulate: float vs1 and vs2 are converted to
// double, multiplied, and the double accumulator vd is added.
// Only a selected element width of 4 is accepted; anything else, or an
// invalid rounding mode, signals a vector exception.
void Vfwmacc(const Instruction* inst) {
  auto* cheriot_state = static_cast<CheriotState*>(inst->state());
  auto* rv_fp = cheriot_state->rv_fp();
  auto* rv_vector = cheriot_state->rv_vector();
  if (!rv_fp->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    rv_vector->set_vector_exception();
    return;
  }
  const int sew = rv_vector->selected_element_width();
  // Held for the rest of the function so it wraps the element computation.
  ScopedFPStatus fp_status_scope(rv_fp->host_fp_interface());
  if (sew != 4) {
    LOG(ERROR) << "Vfwmacc: Illegal sew (" << sew << ")";
    rv_vector->set_vector_exception();
    return;
  }
  RiscVTernaryVectorOp<double, float, float>(
      rv_vector, inst, [](float src2, float src1, double acc) -> double {
        return (static_cast<double>(src1) * static_cast<double>(src2)) + acc;
      });
}

// Widening vector negated multiply-accumulate: float vs1 and vs2 are converted
// to double, and the result is -(vs1 * vs2) - vd with a double accumulator vd.
// Only a selected element width of 4 is accepted; anything else, or an
// invalid rounding mode, signals a vector exception.
void Vfwnmacc(const Instruction* inst) {
  auto* cheriot_state = static_cast<CheriotState*>(inst->state());
  auto* rv_fp = cheriot_state->rv_fp();
  auto* rv_vector = cheriot_state->rv_vector();
  if (!rv_fp->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    rv_vector->set_vector_exception();
    return;
  }
  const int sew = rv_vector->selected_element_width();
  // Held for the rest of the function so it wraps the element computation.
  ScopedFPStatus fp_status_scope(rv_fp->host_fp_interface());
  if (sew != 4) {
    LOG(ERROR) << "Vfwnmacc: Illegal sew (" << sew << ")";
    rv_vector->set_vector_exception();
    return;
  }
  RiscVTernaryVectorOp<double, float, float>(
      rv_vector, inst, [](float src2, float src1, double acc) -> double {
        return (-(static_cast<double>(src1) * static_cast<double>(src2))) -
               acc;
      });
}

// Widening vector multiply-subtract-accumulator: float vs1 and vs2 are
// converted to double, and the result is (vs1 * vs2) - vd with a double vd.
// Only a selected element width of 4 is accepted; anything else, or an
// invalid rounding mode, signals a vector exception.
void Vfwmsac(const Instruction* inst) {
  auto* cheriot_state = static_cast<CheriotState*>(inst->state());
  auto* rv_fp = cheriot_state->rv_fp();
  auto* rv_vector = cheriot_state->rv_vector();
  if (!rv_fp->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    rv_vector->set_vector_exception();
    return;
  }
  const int sew = rv_vector->selected_element_width();
  // Held for the rest of the function so it wraps the element computation.
  ScopedFPStatus fp_status_scope(rv_fp->host_fp_interface());
  if (sew != 4) {
    LOG(ERROR) << "Vfwmsac: Illegal sew (" << sew << ")";
    rv_vector->set_vector_exception();
    return;
  }
  RiscVTernaryVectorOp<double, float, float>(
      rv_vector, inst, [](float src2, float src1, double acc) -> double {
        return (static_cast<double>(src1) * static_cast<double>(src2)) - acc;
      });
}

// Widening vector negated multiply-subtract-accumulator: float vs1 and vs2 are
// converted to double, and the result is -(vs1 * vs2) + vd with a double vd.
// Only a selected element width of 4 is accepted; anything else, or an
// invalid rounding mode, signals a vector exception.
void Vfwnmsac(const Instruction* inst) {
  auto* cheriot_state = static_cast<CheriotState*>(inst->state());
  auto* rv_fp = cheriot_state->rv_fp();
  auto* rv_vector = cheriot_state->rv_vector();
  if (!rv_fp->rounding_mode_valid()) {
    LOG(ERROR) << "Invalid rounding mode";
    rv_vector->set_vector_exception();
    return;
  }
  const int sew = rv_vector->selected_element_width();
  // Held for the rest of the function so it wraps the element computation.
  ScopedFPStatus fp_status_scope(rv_fp->host_fp_interface());
  if (sew != 4) {
    LOG(ERROR) << "Vfwnmsac: Illegal sew (" << sew << ")";
    rv_vector->set_vector_exception();
    return;
  }
  RiscVTernaryVectorOp<double, float, float>(
      rv_vector, inst, [](float src2, float src1, double acc) -> double {
        return (-(static_cast<double>(src1) * static_cast<double>(src2))) +
               acc;
      });
}

// Vector sign injection: each result element keeps the magnitude bits of vs2
// and takes its sign bit from vs1. Works on raw bit patterns (unsigned
// integers), so no rounding-mode check or FP status scope is used.
// Any selected element width other than 4 or 8 signals a vector exception.
void Vfsgnj(const Instruction* inst) {
  auto* rv_vector = static_cast<CheriotState*>(inst->state())->rv_vector();
  const int sew = rv_vector->selected_element_width();
  if (sew == 4) {
    RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>(
        rv_vector, inst, [](uint32_t src2, uint32_t src1) -> uint32_t {
          constexpr uint32_t kSignBit = 0x8000'0000;
          return (src2 & ~kSignBit) | (src1 & kSignBit);
        });
  } else if (sew == 8) {
    RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>(
        rv_vector, inst, [](uint64_t src2, uint64_t src1) -> uint64_t {
          constexpr uint64_t kSignBit = 0x8000'0000'0000'0000;
          return (src2 & ~kSignBit) | (src1 & kSignBit);
        });
  } else {
    LOG(ERROR) << "Vfsgnj: Illegal sew (" << sew << ")";
    rv_vector->set_vector_exception();
  }
}

// Vector negated sign injection: each result element keeps the magnitude bits
// of vs2 and takes the inverted sign bit of vs1. Works on raw bit patterns
// (unsigned integers), so no rounding-mode check or FP status scope is used.
// Any selected element width other than 4 or 8 signals a vector exception.
void Vfsgnjn(const Instruction* inst) {
  auto* rv_vector = static_cast<CheriotState*>(inst->state())->rv_vector();
  const int sew = rv_vector->selected_element_width();
  if (sew == 4) {
    RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>(
        rv_vector, inst, [](uint32_t src2, uint32_t src1) -> uint32_t {
          constexpr uint32_t kSignBit = 0x8000'0000;
          // Flipping the isolated sign bit equals masking the complement.
          return (src2 & ~kSignBit) | ((src1 & kSignBit) ^ kSignBit);
        });
  } else if (sew == 8) {
    RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>(
        rv_vector, inst, [](uint64_t src2, uint64_t src1) -> uint64_t {
          constexpr uint64_t kSignBit = 0x8000'0000'0000'0000;
          return (src2 & ~kSignBit) | ((src1 & kSignBit) ^ kSignBit);
        });
  } else {
    LOG(ERROR) << "Vfsgnjn: Illegal sew (" << sew << ")";
    rv_vector->set_vector_exception();
  }
}

// Vector XOR sign injection: each result element keeps the magnitude bits of
// vs2 and gets a sign bit equal to sign(vs2) XOR sign(vs1). Works on raw bit
// patterns (unsigned integers), so no rounding-mode check or FP status scope
// is used. Any selected element width other than 4 or 8 signals a vector
// exception.
void Vfsgnjx(const Instruction* inst) {
  auto* rv_vector = static_cast<CheriotState*>(inst->state())->rv_vector();
  int sew = rv_vector->selected_element_width();
  switch (sew) {
    case 4:
      return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>(
          rv_vector, inst, [](uint32_t vs2, uint32_t vs1) -> uint32_t {
            return (vs2 & 0x7fff'ffff) | ((vs1 ^ vs2) & 0x8000'0000);
          });
    case 8:
      return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>(
          rv_vector, inst, [](uint64_t vs2, uint64_t vs1) -> uint64_t {
            // Magnitude and sign fields are disjoint; combine them with '|'
            // exactly as the 32-bit case does.
            return (vs2 & 0x7fff'ffff'ffff'ffff) |
                   ((vs1 ^ vs2) & 0x8000'0000'0000'0000);
          });
    default:
      LOG(ERROR) << "Vfsgnjx: Illegal sew (" << sew << ")";
      rv_vector->set_vector_exception();
      return;
  }
}

// Templated helper implementing the shared NaN / signed-zero handling of the
// vfmin and vfmax instructions. Returns a tuple of (result, exception flags).
//
//  * If either operand is a signaling NaN, the returned flags contain
//    FPExceptions::kInvalidOp (the result is still chosen by the rules below).
//  * If both operands are NaNs, the result is the canonical NaN
//    (the bit pattern FPTypeInfo<T>::kCanonicalNaN).
//  * If exactly one operand is a NaN, the result is the other operand.
//  * If both operands are zeros, -0.0 is ordered below +0.0 so that
//    `operation` can select by sign.
//  * Otherwise the result is operation(vs2, vs1).
//
// `operation` is any callable T(T, T) that returns one of its arguments (e.g.
// the smaller one for vfmin). It is a template parameter rather than a
// std::function so the per-element call can be inlined; callers that name
// only T explicitly (MaxMinHelper<T>(a, b, lambda)) still compile unchanged.
template <typename T, typename Operation>
inline std::tuple<T, uint32_t> MaxMinHelper(T vs2, T vs1,
                                            Operation operation) {
  uint32_t flag = 0;
  if (FPTypeInfo<T>::IsSNaN(vs1) || FPTypeInfo<T>::IsSNaN(vs2)) {
    flag = static_cast<uint32_t>(FPExceptions::kInvalidOp);
  }
  if (FPTypeInfo<T>::IsNaN(vs2) && FPTypeInfo<T>::IsNaN(vs1)) {
    // Copy the canonical NaN bit pattern into a T. Dereferencing a
    // reinterpret_cast'ed pointer to the integer constant would violate the
    // strict aliasing rules (undefined behavior).
    auto c_nan = FPTypeInfo<T>::kCanonicalNaN;
    static_assert(sizeof(c_nan) == sizeof(T),
                  "canonical NaN bit pattern must be the same size as T");
    T canonical_nan{};
    std::memcpy(&canonical_nan, &c_nan, sizeof(T));
    return std::make_tuple(canonical_nan, flag);
  }
  // If exactly one operand is a NaN, return the other one.
  if (FPTypeInfo<T>::IsNaN(vs2)) return std::make_tuple(vs1, flag);
  if (FPTypeInfo<T>::IsNaN(vs1)) return std::make_tuple(vs2, flag);
  if ((vs2 == 0.0) && (vs1 == 0.0)) {
    // +0.0 and -0.0 compare equal, so map each sign to -1/+1, let `operation`
    // pick between those, and return the matching original zero.
    T tmp2 = std::signbit(vs2) ? T(-1) : T(1);
    T tmp1 = std::signbit(vs1) ? T(-1) : T(1);
    return std::make_tuple(operation(tmp2, tmp1) == tmp2 ? vs2 : vs1, flag);
  }
  return std::make_tuple(operation(vs2, vs1), flag);
}

// Vector floating point minimum: per element, MaxMinHelper applies the NaN and
// signed-zero rules and otherwise selects the smaller of vs2 and vs1. The
// exception flags produced per element are passed to
// RiscVBinaryVectorOpWithFflags. Any selected element width other than 4
// (float) or 8 (double) signals a vector exception.
void Vfmin(const Instruction* inst) {
  auto* rv_vector = static_cast<CheriotState*>(inst->state())->rv_vector();
  const int sew = rv_vector->selected_element_width();
  if (sew == 4) {
    RiscVBinaryVectorOpWithFflags<float, float, float>(
        rv_vector, inst,
        [](float src2, float src1) -> std::tuple<float, uint32_t> {
          return MaxMinHelper<float>(src2, src1, [](float a, float b) -> float {
            return (b < a) ? b : a;
          });
        });
  } else if (sew == 8) {
    RiscVBinaryVectorOpWithFflags<double, double, double>(
        rv_vector, inst,
        [](double src2, double src1) -> std::tuple<double, uint32_t> {
          return MaxMinHelper<double>(src2, src1,
                                      [](double a, double b) -> double {
                                        return (b < a) ? b : a;
                                      });
        });
  } else {
    LOG(ERROR) << "Vfmin: Illegal sew (" << sew << ")";
    rv_vector->set_vector_exception();
  }
}

// Vector floating point maximum: per element, MaxMinHelper applies the NaN and
// signed-zero rules and otherwise selects the larger of vs2 and vs1. The
// exception flags produced per element are passed to
// RiscVBinaryVectorOpWithFflags. Any selected element width other than 4
// (float) or 8 (double) signals a vector exception.
void Vfmax(const Instruction* inst) {
  auto* rv_vector = static_cast<CheriotState*>(inst->state())->rv_vector();
  const int sew = rv_vector->selected_element_width();
  if (sew == 4) {
    RiscVBinaryVectorOpWithFflags<float, float, float>(
        rv_vector, inst,
        [](float src2, float src1) -> std::tuple<float, uint32_t> {
          return MaxMinHelper<float>(src2, src1, [](float a, float b) -> float {
            return (b > a) ? b : a;
          });
        });
  } else if (sew == 8) {
    RiscVBinaryVectorOpWithFflags<double, double, double>(
        rv_vector, inst,
        [](double src2, double src1) -> std::tuple<double, uint32_t> {
          return MaxMinHelper<double>(src2, src1,
                                      [](double a, double b) -> double {
                                        return (b > a) ? b : a;
                                      });
        });
  } else {
    LOG(ERROR) << "Vfmax: Illegal sew (" << sew << ")";
    rv_vector->set_vector_exception();
  }
}

// Vector floating point merge: each result element is vs1 where the mask bit
// is set and vs2 where it is clear. Elements are copied as raw bit patterns
// (unsigned integers). Any selected element width other than 4 or 8 signals a
// vector exception.
void Vfmerge(const Instruction* inst) {
  auto* rv_vector = static_cast<CheriotState*>(inst->state())->rv_vector();
  const int sew = rv_vector->selected_element_width();
  if (sew == 4) {
    RiscVMaskBinaryVectorOp<uint32_t, uint32_t, uint32_t>(
        rv_vector, inst,
        [](uint32_t src2, uint32_t src1, bool take_src1) -> uint32_t {
          if (take_src1) return src1;
          return src2;
        });
  } else if (sew == 8) {
    RiscVMaskBinaryVectorOp<uint64_t, uint64_t, uint64_t>(
        rv_vector, inst,
        [](uint64_t src2, uint64_t src1, bool take_src1) -> uint64_t {
          if (take_src1) return src1;
          return src2;
        });
  } else {
    rv_vector->set_vector_exception();
    LOG(ERROR) << "Vfmerge: Illegal sew (" << sew << ")";
  }
}

}  // namespace cheriot
}  // namespace sim
}  // namespace mpact
