// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cheriot/riscv_cheriot_vector_opm_instructions.h"

#include <cstdint>
#include <functional>
#include <limits>
#include <type_traits>

#include "absl/log/log.h"
#include "cheriot/cheriot_state.h"
#include "cheriot/cheriot_vector_state.h"
#include "cheriot/riscv_cheriot_vector_instruction_helpers.h"
#include "mpact/sim/generic/type_helpers.h"
#include "riscv//riscv_register.h"

namespace mpact {
namespace sim {
namespace cheriot {

using ::mpact::sim::generic::WideType;

// Element operation shared by the averaging-add instructions (Vaadd/Vaaddu).
// Both operands are first widened to WideType<T>, so their sum cannot overflow.
// The sum is handed to RoundOff with a shift of 1 (RoundOff also receives
// rv_vector, presumably to read the rounding mode — confirm), and the result is
// narrowed back to T.
template <typename T>
inline T VaaddHelper(CheriotVectorState* rv_vector, T vs2, T vs1) {
  using WT = typename WideType<T>::type;
  return static_cast<T>(
      RoundOff(rv_vector, static_cast<WT>(vs2) + static_cast<WT>(vs1), 1));
}

// Vaaddu: averaging unsigned add. For every active element, vs2 and vs1 are
// summed in a wider type, then shifted right by one with rounding (see
// VaaddHelper). An unsupported SEW raises a vector exception.
void Vaaddu(const Instruction* inst) {
  auto* rv_vector = static_cast<CheriotState*>(inst->state())->rv_vector();
  // One element operation for all widths; the element type is deduced from the
  // arguments and is also the result type.
  auto average = [rv_vector](auto vs2, auto vs1) -> decltype(vs2) {
    return VaaddHelper(rv_vector, vs2, vs1);
  };
  switch (rv_vector->selected_element_width()) {
    case 1:
      return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>(rv_vector, inst,
                                                            average);
    case 2:
      return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>(
          rv_vector, inst, average);
    case 4:
      return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>(
          rv_vector, inst, average);
    case 8:
      return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>(
          rv_vector, inst, average);
    default:
      rv_vector->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      return;
  }
}

// Vaadd: averaging signed add. For every active element, vs2 and vs1 are summed
// in a wider signed type, then shifted right by one with rounding (see
// VaaddHelper). An unsupported SEW raises a vector exception.
void Vaadd(const Instruction* inst) {
  auto* rv_vector = static_cast<CheriotState*>(inst->state())->rv_vector();
  // One element operation for all widths; the element type is deduced from the
  // arguments and is also the result type.
  auto average = [rv_vector](auto vs2, auto vs1) -> decltype(vs2) {
    return VaaddHelper(rv_vector, vs2, vs1);
  };
  switch (rv_vector->selected_element_width()) {
    case 1:
      return RiscVBinaryVectorOp<int8_t, int8_t, int8_t>(rv_vector, inst,
                                                         average);
    case 2:
      return RiscVBinaryVectorOp<int16_t, int16_t, int16_t>(rv_vector, inst,
                                                            average);
    case 4:
      return RiscVBinaryVectorOp<int32_t, int32_t, int32_t>(rv_vector, inst,
                                                            average);
    case 8:
      return RiscVBinaryVectorOp<int64_t, int64_t, int64_t>(rv_vector, inst,
                                                            average);
    default:
      rv_vector->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      return;
  }
}

// Element operation shared by the averaging-subtract instructions
// (Vasub/Vasubu). The operands are widened to WideType<T> before subtracting.
// The difference is passed to RoundOff with a shift of 1 and the result is
// narrowed back to T.
template <typename T>
inline T VasubHelper(CheriotVectorState* rv_vector, T vs2, T vs1) {
  using WT = typename WideType<T>::type;
  return static_cast<T>(
      RoundOff(rv_vector, static_cast<WT>(vs2) - static_cast<WT>(vs1), 1));
}

// Vasubu: averaging unsigned subtract. For every active element, vs1 is
// subtracted from vs2 in a wider type and the difference is shifted right by
// one with rounding (see VasubHelper). An unsupported SEW raises a vector
// exception.
void Vasubu(const Instruction* inst) {
  auto* rv_vector = static_cast<CheriotState*>(inst->state())->rv_vector();
  // One element operation for all widths; the element type is deduced from the
  // arguments and is also the result type.
  auto average_diff = [rv_vector](auto vs2, auto vs1) -> decltype(vs2) {
    return VasubHelper(rv_vector, vs2, vs1);
  };
  switch (rv_vector->selected_element_width()) {
    case 1:
      return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>(rv_vector, inst,
                                                            average_diff);
    case 2:
      return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>(
          rv_vector, inst, average_diff);
    case 4:
      return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>(
          rv_vector, inst, average_diff);
    case 8:
      return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>(
          rv_vector, inst, average_diff);
    default:
      rv_vector->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      return;
  }
}

// Vasub: averaging signed subtract. For every active element, vs1 is subtracted
// from vs2 in a wider signed type and the difference is shifted right by one
// with rounding (see VasubHelper). An unsupported SEW raises a vector
// exception.
void Vasub(const Instruction* inst) {
  auto* rv_vector = static_cast<CheriotState*>(inst->state())->rv_vector();
  // One element operation for all widths; the element type is deduced from the
  // arguments and is also the result type.
  auto average_diff = [rv_vector](auto vs2, auto vs1) -> decltype(vs2) {
    return VasubHelper(rv_vector, vs2, vs1);
  };
  switch (rv_vector->selected_element_width()) {
    case 1:
      return RiscVBinaryVectorOp<int8_t, int8_t, int8_t>(rv_vector, inst,
                                                         average_diff);
    case 2:
      return RiscVBinaryVectorOp<int16_t, int16_t, int16_t>(rv_vector, inst,
                                                            average_diff);
    case 4:
      return RiscVBinaryVectorOp<int32_t, int32_t, int32_t>(rv_vector, inst,
                                                            average_diff);
    case 8:
      return RiscVBinaryVectorOp<int64_t, int64_t, int64_t>(rv_vector, inst,
                                                            average_diff);
    default:
      rv_vector->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      return;
  }
}

// Mask operands only occupy a single vector register, so the mask-logical
// instructions below share this helper. It applies `op` byte-wise to the mask
// bits held in register 0 of vs2 and vs1 and writes the result to vd. Only bit
// positions in the active range [vstart, vl) are updated; every other bit of
// vd keeps its previous value. vstart is cleared on completion. Nothing is done
// if a vector exception is already pending.
static inline void BitwiseMaskBinaryOp(
    CheriotVectorState* rv_vector, const Instruction* inst,
    std::function<uint8_t(uint8_t, uint8_t)> op) {
  if (rv_vector->vector_exception()) return;
  int vstart = rv_vector->vstart();
  int vlen = rv_vector->vector_length();
  // An empty active range (vl == 0 or vstart >= vl) updates no mask bits.
  // Without this check the byte arithmetic below would still write byte
  // vstart / 8 (and byte 0 when vl == 0).
  if (vstart >= vlen) {
    rv_vector->clear_vstart();
    return;
  }
  // Get spans for vector source and destination registers.
  auto* vs2_op = static_cast<RV32VectorSourceOperand*>(inst->Source(0));
  auto vs2_span = vs2_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
  auto* vs1_op = static_cast<RV32VectorSourceOperand*>(inst->Source(1));
  auto vs1_span = vs1_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
  auto* vd_op =
      static_cast<RV32VectorDestinationOperand*>(inst->Destination(0));
  auto* vd_db = vd_op->CopyDataBuffer();
  auto vd_span = vd_db->Get<uint8_t>();
  // First/last bytes touched, and which of their bits are inside [vstart, vl).
  int start_byte = vstart / 8;
  uint8_t start_mask = 0b1111'1111 << (vstart % 8);
  int end_byte = (vlen - 1) / 8;
  uint8_t end_mask = 0b1111'1111 >> (7 - (vlen - 1) % 8);
  // Writes op(vs2, vs1) into byte i of vd, but only for the bits set in mask.
  auto merge_byte = [&](int i, uint8_t mask) {
    vd_span[i] = (op(vs2_span[i], vs1_span[i]) & mask) | (vd_span[i] & ~mask);
  };
  if (start_byte == end_byte) {
    // The whole active range lies in one byte, so both masks must be applied
    // together. Applying them one after the other would leave the bits at or
    // above vl overwritten.
    merge_byte(start_byte, start_mask & end_mask);
  } else {
    merge_byte(start_byte, start_mask);
    // Bytes strictly between start and end are fully active.
    for (int i = start_byte + 1; i < end_byte; i++) {
      vd_span[i] = op(vs2_span[i], vs1_span[i]);
    }
    merge_byte(end_byte, end_mask);
  }
  vd_db->Submit();
  rv_vector->clear_vstart();
}

// Mask-register logical instructions. Each one computes a bitwise function of
// the vs2 and vs1 masks via BitwiseMaskBinaryOp.
// Vmandnot: vd = vs2 AND (NOT vs1).
void Vmandnot(const Instruction* inst) {
  auto* state = static_cast<CheriotState*>(inst->state());
  BitwiseMaskBinaryOp(state->rv_vector(), inst,
                      [](uint8_t vs2, uint8_t vs1) -> uint8_t {
                        return vs2 & static_cast<uint8_t>(~vs1);
                      });
}

// Vmand: vd = vs2 AND vs1 over the active mask bits.
void Vmand(const Instruction* inst) {
  auto* state = static_cast<CheriotState*>(inst->state());
  BitwiseMaskBinaryOp(state->rv_vector(), inst,
                      [](uint8_t vs2, uint8_t vs1) -> uint8_t {
                        return vs1 & vs2;
                      });
}
// Vmor: vd = vs2 OR vs1 over the active mask bits.
void Vmor(const Instruction* inst) {
  auto* state = static_cast<CheriotState*>(inst->state());
  BitwiseMaskBinaryOp(state->rv_vector(), inst,
                      [](uint8_t vs2, uint8_t vs1) -> uint8_t {
                        return vs1 | vs2;
                      });
}
// Vmxor: vd = vs2 XOR vs1 over the active mask bits.
void Vmxor(const Instruction* inst) {
  auto* state = static_cast<CheriotState*>(inst->state());
  BitwiseMaskBinaryOp(state->rv_vector(), inst,
                      [](uint8_t vs2, uint8_t vs1) -> uint8_t {
                        return vs1 ^ vs2;
                      });
}
// Vmornot: vd = vs2 OR (NOT vs1) over the active mask bits.
void Vmornot(const Instruction* inst) {
  auto* state = static_cast<CheriotState*>(inst->state());
  BitwiseMaskBinaryOp(state->rv_vector(), inst,
                      [](uint8_t vs2, uint8_t vs1) -> uint8_t {
                        return vs2 | static_cast<uint8_t>(~vs1);
                      });
}
// Vmnand: vd = NOT (vs2 AND vs1) over the active mask bits.
void Vmnand(const Instruction* inst) {
  auto* state = static_cast<CheriotState*>(inst->state());
  BitwiseMaskBinaryOp(state->rv_vector(), inst,
                      [](uint8_t vs2, uint8_t vs1) -> uint8_t {
                        const uint8_t conj = vs2 & vs1;
                        return static_cast<uint8_t>(~conj);
                      });
}
// Vmnor: vd = NOT (vs2 OR vs1) over the active mask bits.
void Vmnor(const Instruction* inst) {
  auto* state = static_cast<CheriotState*>(inst->state());
  BitwiseMaskBinaryOp(state->rv_vector(), inst,
                      [](uint8_t vs2, uint8_t vs1) -> uint8_t {
                        const uint8_t disj = vs2 | vs1;
                        return static_cast<uint8_t>(~disj);
                      });
}
// Vmxnor: vd = NOT (vs2 XOR vs1) over the active mask bits.
void Vmxnor(const Instruction* inst) {
  auto* state = static_cast<CheriotState*>(inst->state());
  BitwiseMaskBinaryOp(state->rv_vector(), inst,
                      [](uint8_t vs2, uint8_t vs1) -> uint8_t {
                        const uint8_t diff = vs2 ^ vs1;
                        return static_cast<uint8_t>(~diff);
                      });
}

// Vdivu: unsigned element-wise division vs2 / vs1. As with the scalar divide
// instruction, dividing by zero does not raise an exception. It produces an
// all-ones element instead. An unsupported SEW raises a vector exception.
void Vdivu(const Instruction* inst) {
  auto* rv_vector = static_cast<CheriotState*>(inst->state())->rv_vector();
  // One element operation for all widths; the element type is deduced from the
  // arguments and is also the result type.
  auto divide = [](auto vs2, auto vs1) -> decltype(vs2) {
    // ~0, narrowed to the element type, is the all-ones pattern.
    if (vs1 == 0) return ~vs1;
    return vs2 / vs1;
  };
  switch (rv_vector->selected_element_width()) {
    case 1:
      return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>(rv_vector, inst,
                                                            divide);
    case 2:
      return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>(
          rv_vector, inst, divide);
    case 4:
      return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>(
          rv_vector, inst, divide);
    case 8:
      return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>(
          rv_vector, inst, divide);
    default:
      rv_vector->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      return;
  }
}

// Vdiv: signed element-wise division vs2 / vs1. There are two special cases.
// Division by zero yields -1 (all ones). Dividing the most negative value by -1
// overflows, so it yields the most negative value unchanged. An unsupported SEW
// raises a vector exception.
void Vdiv(const Instruction* inst) {
  auto* rv_vector = static_cast<CheriotState*>(inst->state())->rv_vector();
  // One element operation for all widths; the element type is deduced from the
  // arguments and is also the result type.
  auto divide = [](auto vs2, auto vs1) -> decltype(vs2) {
    using T = decltype(vs2);
    constexpr T kMin = std::numeric_limits<T>::min();
    if (vs1 == 0) return static_cast<T>(-1);
    if (vs1 == -1 && vs2 == kMin) return kMin;
    return vs2 / vs1;
  };
  switch (rv_vector->selected_element_width()) {
    case 1:
      return RiscVBinaryVectorOp<int8_t, int8_t, int8_t>(rv_vector, inst,
                                                         divide);
    case 2:
      return RiscVBinaryVectorOp<int16_t, int16_t, int16_t>(rv_vector, inst,
                                                            divide);
    case 4:
      return RiscVBinaryVectorOp<int32_t, int32_t, int32_t>(rv_vector, inst,
                                                            divide);
    case 8:
      return RiscVBinaryVectorOp<int64_t, int64_t, int64_t>(rv_vector, inst,
                                                            divide);
    default:
      rv_vector->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      return;
  }
}

// Vremu: unsigned element-wise remainder vs2 % vs1. A zero divisor yields the
// dividend (vs2) unchanged. An unsupported SEW raises a vector exception.
void Vremu(const Instruction* inst) {
  auto* rv_vector = static_cast<CheriotState*>(inst->state())->rv_vector();
  // One element operation for all widths; the element type is deduced from the
  // arguments and is also the result type.
  auto remainder = [](auto vs2, auto vs1) -> decltype(vs2) {
    if (vs1 == 0) return vs2;
    return vs2 % vs1;
  };
  switch (rv_vector->selected_element_width()) {
    case 1:
      return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>(rv_vector, inst,
                                                            remainder);
    case 2:
      return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>(
          rv_vector, inst, remainder);
    case 4:
      return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>(
          rv_vector, inst, remainder);
    case 8:
      return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>(
          rv_vector, inst, remainder);
    default:
      rv_vector->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      return;
  }
}

// Vrem: signed element-wise remainder vs2 % vs1. There are two special cases:
//  - A zero divisor yields the dividend (vs2) unchanged.
//  - A divisor of -1 always yields 0. This also covers the overflow case
//    min() % -1. In C++ that expression is undefined behavior for
//    int32_t/int64_t, because the quotient is not representable, and it traps
//    on common hosts. So it must not reach the % operator.
// An unsupported SEW raises a vector exception.
void Vrem(const Instruction* inst) {
  auto* rv_vector = static_cast<CheriotState*>(inst->state())->rv_vector();
  // One element operation for all widths; the element type is deduced from the
  // arguments and is also the result type.
  auto remainder = [](auto vs2, auto vs1) -> decltype(vs2) {
    if (vs1 == 0) return vs2;
    if (vs1 == -1) return 0;
    return vs2 % vs1;
  };
  switch (rv_vector->selected_element_width()) {
    case 1:
      return RiscVBinaryVectorOp<int8_t, int8_t, int8_t>(rv_vector, inst,
                                                         remainder);
    case 2:
      return RiscVBinaryVectorOp<int16_t, int16_t, int16_t>(rv_vector, inst,
                                                            remainder);
    case 4:
      return RiscVBinaryVectorOp<int32_t, int32_t, int32_t>(rv_vector, inst,
                                                            remainder);
    case 8:
      return RiscVBinaryVectorOp<int64_t, int64_t, int64_t>(rv_vector, inst,
                                                            remainder);
    default:
      rv_vector->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      return;
  }
}

// Helper function for multiply high. It promotes the to arguments to wider
// types, performs the multiplication, returns the high half of the result.
template <typename T>
inline T VmulHighHelper(T vs2, T vs1) {
  using WT = typename WideType<T>::type;
  WT vs2_w = static_cast<WT>(vs2);
  WT vs1_w = static_cast<WT>(vs1);
  WT prod = vs2_w * vs1_w;
  prod >>= sizeof(T) * 8;
  return static_cast<T>(prod);
}

// Vmulhu: unsigned multiply, keeping the high SEW bits of each 2*SEW-bit
// product. An unsupported SEW raises a vector exception.
void Vmulhu(const Instruction* inst) {
  auto* rv_vector = static_cast<CheriotState*>(inst->state())->rv_vector();
  // One element operation for all widths; the element type is deduced from the
  // arguments and is also the result type.
  auto mul_high = [](auto vs2, auto vs1) -> decltype(vs2) {
    return VmulHighHelper(vs2, vs1);
  };
  switch (rv_vector->selected_element_width()) {
    case 1:
      return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>(rv_vector, inst,
                                                            mul_high);
    case 2:
      return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>(
          rv_vector, inst, mul_high);
    case 4:
      return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>(
          rv_vector, inst, mul_high);
    case 8:
      return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>(
          rv_vector, inst, mul_high);
    default:
      rv_vector->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      return;
  }
}

// Vmul: element-wise multiply, keeping the low SEW bits of each product. The
// low half of a product is the same for signed and unsigned operands. So the
// multiplication is done in uint64_t, where wrap-around is well defined and
// narrow operands are not promoted to (overflowable) int. The result is then
// truncated to the element type. SEW=8 uses int64_t element types (as in the
// original implementation), presumably so that a vs1 scalar taken from a
// narrower register is sign extended first — confirm in RiscVBinaryVectorOp.
// An unsupported SEW raises a vector exception.
void Vmul(const Instruction* inst) {
  auto* rv_vector = static_cast<CheriotState*>(inst->state())->rv_vector();
  auto mul_low = [](auto vs2, auto vs1) -> decltype(vs2) {
    const uint64_t product =
        static_cast<uint64_t>(vs2) * static_cast<uint64_t>(vs1);
    return static_cast<decltype(vs2)>(product);
  };
  switch (rv_vector->selected_element_width()) {
    case 1:
      return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>(rv_vector, inst,
                                                            mul_low);
    case 2:
      return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>(
          rv_vector, inst, mul_low);
    case 4:
      return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>(
          rv_vector, inst, mul_low);
    case 8:
      return RiscVBinaryVectorOp<int64_t, int64_t, int64_t>(rv_vector, inst,
                                                            mul_low);
    default:
      rv_vector->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      return;
  }
}

// Helper for signed-unsigned multiplication return high half.
template <typename T>
inline typename std::make_signed<T>::type VmulHighSUHelper(
    typename std::make_signed<T>::type vs2,
    typename std::make_unsigned<T>::type vs1) {
  using WT = typename WideType<T>::type;
  using WST = typename WideType<typename std::make_signed<T>::type>::type;
  WST vs2_w = static_cast<WST>(vs2);
  WT vs1_w = static_cast<WT>(vs1);
  WST prod = vs2_w * vs1_w;
  prod >>= sizeof(T) * 8;
  return static_cast<typename std::make_signed<T>::type>(prod);
}

// Multiply signed unsigned and return the high half.
void Vmulhsu(const Instruction* inst) {
  auto* rv_vector = static_cast<CheriotState*>(inst->state())->rv_vector();
  int sew = rv_vector->selected_element_width();
  switch (sew) {
    case 1:
      return RiscVBinaryVectorOp<int8_t, int8_t, uint8_t>(
          rv_vector, inst, [](int8_t vs2, uint8_t vs1) -> int8_t {
            return VmulHighSUHelper<int8_t>(vs2, vs1);
          });
    case 2:
      return RiscVBinaryVectorOp<int16_t, int16_t, uint16_t>(
          rv_vector, inst, [](int16_t vs2, uint16_t vs1) -> int16_t {
            return VmulHighSUHelper<int16_t>(vs2, vs1);
          });
    case 4:
      return RiscVBinaryVectorOp<int32_t, int32_t, uint32_t>(
          rv_vector, inst, [](int32_t vs2, uint32_t vs1) -> int32_t {
            return VmulHighSUHelper<int32_t>(vs2, vs1);
          });
    case 8:
      return RiscVBinaryVectorOp<int64_t, int64_t, uint64_t>(
          rv_vector, inst, [](int64_t vs2, uint64_t vs1) -> int64_t {
            return VmulHighSUHelper<int64_t>(vs2, vs1);
          });
    default:
      rv_vector->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      return;
  }
}

// Vmulh: signed multiply, keeping the high SEW bits of each 2*SEW-bit product.
// An unsupported SEW raises a vector exception.
void Vmulh(const Instruction* inst) {
  auto* rv_vector = static_cast<CheriotState*>(inst->state())->rv_vector();
  // One element operation for all widths; the element type is deduced from the
  // arguments and is also the result type.
  auto mul_high = [](auto vs2, auto vs1) -> decltype(vs2) {
    return VmulHighHelper(vs2, vs1);
  };
  switch (rv_vector->selected_element_width()) {
    case 1:
      return RiscVBinaryVectorOp<int8_t, int8_t, int8_t>(rv_vector, inst,
                                                         mul_high);
    case 2:
      return RiscVBinaryVectorOp<int16_t, int16_t, int16_t>(rv_vector, inst,
                                                            mul_high);
    case 4:
      return RiscVBinaryVectorOp<int32_t, int32_t, int32_t>(rv_vector, inst,
                                                            mul_high);
    case 8:
      return RiscVBinaryVectorOp<int64_t, int64_t, int64_t>(rv_vector, inst,
                                                            mul_high);
    default:
      rv_vector->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      return;
  }
}

// Vmadd: integer multiply-add that overwrites the multiplicand,
// vd[i] = (vs1[i] * vd[i]) + vs2[i], modulo 2^SEW. An unsupported SEW raises a
// vector exception.
void Vmadd(const Instruction* inst) {
  auto* rv_vector = static_cast<CheriotState*>(inst->state())->rv_vector();
  // Arithmetic is done in uint64_t, which wraps modulo 2^64, so narrow
  // operands are never promoted to (overflowable) int. Truncating to the
  // element type keeps the low SEW bits, which equals the SEW-bit modular
  // result.
  auto multiply_add = [](auto vs2, auto vs1, auto vd) -> decltype(vd) {
    const uint64_t product =
        static_cast<uint64_t>(vs1) * static_cast<uint64_t>(vd);
    return static_cast<decltype(vd)>(product + static_cast<uint64_t>(vs2));
  };
  switch (rv_vector->selected_element_width()) {
    case 1:
      return RiscVTernaryVectorOp<uint8_t, uint8_t, uint8_t>(rv_vector, inst,
                                                             multiply_add);
    case 2:
      return RiscVTernaryVectorOp<uint16_t, uint16_t, uint16_t>(
          rv_vector, inst, multiply_add);
    case 4:
      return RiscVTernaryVectorOp<uint32_t, uint32_t, uint32_t>(
          rv_vector, inst, multiply_add);
    case 8:
      return RiscVTernaryVectorOp<uint64_t, uint64_t, uint64_t>(
          rv_vector, inst, multiply_add);
    default:
      rv_vector->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      return;
  }
}

// Vnmsub: negated multiply-add that overwrites the multiplicand,
// vd[i] = -(vs1[i] * vd[i]) + vs2[i], modulo 2^SEW. An unsupported SEW raises
// a vector exception.
void Vnmsub(const Instruction* inst) {
  auto* rv_vector = static_cast<CheriotState*>(inst->state())->rv_vector();
  // Arithmetic is done in uint64_t, which wraps modulo 2^64. Truncating to the
  // element type keeps the low SEW bits, which equals the SEW-bit modular
  // result of vs2 - vs1 * vd.
  auto neg_multiply_add = [](auto vs2, auto vs1, auto vd) -> decltype(vd) {
    const uint64_t product =
        static_cast<uint64_t>(vs1) * static_cast<uint64_t>(vd);
    return static_cast<decltype(vd)>(static_cast<uint64_t>(vs2) - product);
  };
  switch (rv_vector->selected_element_width()) {
    case 1:
      return RiscVTernaryVectorOp<uint8_t, uint8_t, uint8_t>(rv_vector, inst,
                                                             neg_multiply_add);
    case 2:
      return RiscVTernaryVectorOp<uint16_t, uint16_t, uint16_t>(
          rv_vector, inst, neg_multiply_add);
    case 4:
      return RiscVTernaryVectorOp<uint32_t, uint32_t, uint32_t>(
          rv_vector, inst, neg_multiply_add);
    case 8:
      return RiscVTernaryVectorOp<uint64_t, uint64_t, uint64_t>(
          rv_vector, inst, neg_multiply_add);
    default:
      rv_vector->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      return;
  }
}

// Vmacc: multiply-accumulate that overwrites the addend,
// vd[i] = (vs1[i] * vs2[i]) + vd[i], modulo 2^SEW. An unsupported SEW raises a
// vector exception.
void Vmacc(const Instruction* inst) {
  auto* rv_vector = static_cast<CheriotState*>(inst->state())->rv_vector();
  // Arithmetic is done in uint64_t, which wraps modulo 2^64. Truncating to the
  // element type keeps the low SEW bits, which equals the SEW-bit modular
  // result.
  auto multiply_acc = [](auto vs2, auto vs1, auto vd) -> decltype(vd) {
    const uint64_t product =
        static_cast<uint64_t>(vs1) * static_cast<uint64_t>(vs2);
    return static_cast<decltype(vd)>(product + static_cast<uint64_t>(vd));
  };
  switch (rv_vector->selected_element_width()) {
    case 1:
      return RiscVTernaryVectorOp<uint8_t, uint8_t, uint8_t>(rv_vector, inst,
                                                             multiply_acc);
    case 2:
      return RiscVTernaryVectorOp<uint16_t, uint16_t, uint16_t>(
          rv_vector, inst, multiply_acc);
    case 4:
      return RiscVTernaryVectorOp<uint32_t, uint32_t, uint32_t>(
          rv_vector, inst, multiply_acc);
    case 8:
      return RiscVTernaryVectorOp<uint64_t, uint64_t, uint64_t>(
          rv_vector, inst, multiply_acc);
    default:
      rv_vector->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      return;
  }
}

// Vnmsac: negated multiply-accumulate that overwrites the addend,
// vd[i] = -(vs1[i] * vs2[i]) + vd[i], modulo 2^SEW. An unsupported SEW raises
// a vector exception.
void Vnmsac(const Instruction* inst) {
  auto* rv_vector = static_cast<CheriotState*>(inst->state())->rv_vector();
  // Arithmetic is done in uint64_t, which wraps modulo 2^64. Truncating to the
  // element type keeps the low SEW bits, which equals the SEW-bit modular
  // result of vd - vs1 * vs2.
  auto neg_multiply_acc = [](auto vs2, auto vs1, auto vd) -> decltype(vd) {
    const uint64_t product =
        static_cast<uint64_t>(vs1) * static_cast<uint64_t>(vs2);
    return static_cast<decltype(vd)>(static_cast<uint64_t>(vd) - product);
  };
  switch (rv_vector->selected_element_width()) {
    case 1:
      return RiscVTernaryVectorOp<uint8_t, uint8_t, uint8_t>(rv_vector, inst,
                                                             neg_multiply_acc);
    case 2:
      return RiscVTernaryVectorOp<uint16_t, uint16_t, uint16_t>(
          rv_vector, inst, neg_multiply_acc);
    case 4:
      return RiscVTernaryVectorOp<uint32_t, uint32_t, uint32_t>(
          rv_vector, inst, neg_multiply_acc);
    case 8:
      return RiscVTernaryVectorOp<uint64_t, uint64_t, uint64_t>(
          rv_vector, inst, neg_multiply_acc);
    default:
      rv_vector->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      return;
  }
}

// Widening unsigned add.
void Vwaddu(const Instruction* inst) {
  auto* rv_vector = static_cast<CheriotState*>(inst->state())->rv_vector();
  int sew = rv_vector->selected_element_width();
  switch (sew) {
    case 1:
      return RiscVBinaryVectorOp<uint16_t, uint8_t, uint8_t>(
          rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint16_t {
            return static_cast<uint16_t>(vs2) + static_cast<uint16_t>(vs1);
          });
    case 2:
      return RiscVBinaryVectorOp<uint32_t, uint16_t, uint16_t>(
          rv_vector, inst, [](uint16_t vs2, uint16_t vs1) -> uint32_t {
            return static_cast<uint32_t>(vs2) + static_cast<uint32_t>(vs1);
          });
    case 4:
      return RiscVBinaryVectorOp<uint64_t, uint32_t, uint32_t>(
          rv_vector, inst, [](uint32_t vs2, uint32_t vs1) -> uint64_t {
            return static_cast<uint64_t>(vs2) + static_cast<uint64_t>(vs1);
          });
    default:
      rv_vector->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      return;
  }
}

// Widening unsigned subtract.
void Vwsubu(const Instruction* inst) {
  auto* rv_vector = static_cast<CheriotState*>(inst->state())->rv_vector();
  int sew = rv_vector->selected_element_width();
  switch (sew) {
    case 1:
      return RiscVBinaryVectorOp<uint16_t, uint8_t, uint8_t>(
          rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint16_t {
            return static_cast<uint16_t>(vs2) - static_cast<uint16_t>(vs1);
          });
    case 2:
      return RiscVBinaryVectorOp<uint32_t, uint16_t, uint16_t>(
          rv_vector, inst, [](uint16_t vs2, uint16_t vs1) -> uint32_t {
            return static_cast<uint32_t>(vs2) - static_cast<uint32_t>(vs1);
          });
    case 4:
      return RiscVBinaryVectorOp<uint64_t, uint32_t, uint32_t>(
          rv_vector, inst, [](uint32_t vs2, uint32_t vs1) -> uint64_t {
            return static_cast<uint64_t>(vs2) - static_cast<uint64_t>(vs1);
          });
    default:
      rv_vector->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      return;
  }
}

// Widening signed add: each SEW-bit source element is sign extended to 2*SEW
// bits and the pair is summed into a 2*SEW-bit destination element. Any
// unsupported SEW, or an out-of-range length multiplier, raises a vector
// exception.
void Vwadd(const Instruction* inst) {
  auto* rv_vector = static_cast<CheriotState*>(inst->state())->rv_vector();
  int sew = rv_vector->selected_element_width();
  // The destination group is twice as wide as the sources, so LMUL8 (a
  // multiplier value of 64) is not allowed.
  if (rv_vector->vector_length_multiplier() > 32) {
    rv_vector->set_vector_exception();
    LOG(ERROR)
        << "Vector length multiplier out of range for widening operation";
    return;
  }
  // Each operand is sign extended through the wide signed type, then
  // reinterpreted as the wide unsigned type. The addition happens there, where
  // wrap-around is well defined (unlike signed overflow).
  switch (sew) {
    case 1:
      return RiscVBinaryVectorOp<uint16_t, int8_t, int8_t>(
          rv_vector, inst, [](int8_t vs2, int8_t vs1) -> uint16_t {
            const auto lhs = static_cast<uint16_t>(static_cast<int16_t>(vs2));
            const auto rhs = static_cast<uint16_t>(static_cast<int16_t>(vs1));
            return lhs + rhs;
          });
    case 2:
      return RiscVBinaryVectorOp<uint32_t, int16_t, int16_t>(
          rv_vector, inst, [](int16_t vs2, int16_t vs1) -> uint32_t {
            const auto lhs = static_cast<uint32_t>(static_cast<int32_t>(vs2));
            const auto rhs = static_cast<uint32_t>(static_cast<int32_t>(vs1));
            return lhs + rhs;
          });
    case 4:
      return RiscVBinaryVectorOp<uint64_t, int32_t, int32_t>(
          rv_vector, inst, [](int32_t vs2, int32_t vs1) -> uint64_t {
            const auto lhs = static_cast<uint64_t>(static_cast<int64_t>(vs2));
            const auto rhs = static_cast<uint64_t>(static_cast<int64_t>(vs1));
            return lhs + rhs;
          });
    default:
      rv_vector->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      return;
  }
}

// Widening signed subtract: each SEW-bit source element is sign extended to
// 2*SEW bits and vs1 is subtracted from vs2 into a 2*SEW-bit destination
// element. Any unsupported SEW, or an out-of-range length multiplier, raises a
// vector exception.
void Vwsub(const Instruction* inst) {
  auto* rv_vector = static_cast<CheriotState*>(inst->state())->rv_vector();
  int sew = rv_vector->selected_element_width();
  // The destination group is twice as wide as the sources, so LMUL8 (a
  // multiplier value of 64) is not allowed.
  if (rv_vector->vector_length_multiplier() > 32) {
    rv_vector->set_vector_exception();
    LOG(ERROR)
        << "Vector length multiplier out of range for widening operation";
    return;
  }
  // Each operand is sign extended through the wide signed type, then
  // reinterpreted as the wide unsigned type. The subtraction happens there,
  // where wrap-around is well defined (unlike signed overflow).
  switch (sew) {
    case 1:
      return RiscVBinaryVectorOp<uint16_t, int8_t, int8_t>(
          rv_vector, inst, [](int8_t vs2, int8_t vs1) -> uint16_t {
            const auto lhs = static_cast<uint16_t>(static_cast<int16_t>(vs2));
            const auto rhs = static_cast<uint16_t>(static_cast<int16_t>(vs1));
            return lhs - rhs;
          });
    case 2:
      return RiscVBinaryVectorOp<uint32_t, int16_t, int16_t>(
          rv_vector, inst, [](int16_t vs2, int16_t vs1) -> uint32_t {
            const auto lhs = static_cast<uint32_t>(static_cast<int32_t>(vs2));
            const auto rhs = static_cast<uint32_t>(static_cast<int32_t>(vs1));
            return lhs - rhs;
          });
    case 4:
      return RiscVBinaryVectorOp<uint64_t, int32_t, int32_t>(
          rv_vector, inst, [](int32_t vs2, int32_t vs1) -> uint64_t {
            const auto lhs = static_cast<uint64_t>(static_cast<int64_t>(vs2));
            const auto rhs = static_cast<uint64_t>(static_cast<int64_t>(vs1));
            return lhs - rhs;
          });
    default:
      rv_vector->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      return;
  }
}

// Widening unsigned add with a wide vs2 source (vwaddu.w): each 2*SEW-bit
// element of vs2 is added to the zero-extended SEW-bit element of vs1, and the
// result is truncated to 2*SEW bits.
void Vwadduw(const Instruction* inst) {
  auto* vector_state = static_cast<CheriotState*>(inst->state())->rv_vector();
  const int sew_bytes = vector_state->selected_element_width();
  // LMUL8 cannot be 64 (multiplier presumably reported as LMUL * 8).
  if (vector_state->vector_length_multiplier() > 32) {
    vector_state->set_vector_exception();
    LOG(ERROR)
        << "Vector length multiplier out of range for widening operation";
    return;
  }
  switch (sew_bytes) {
    case 1:
      return RiscVBinaryVectorOp<uint16_t, uint16_t, uint8_t>(
          vector_state, inst, [](uint16_t wide, uint8_t narrow) -> uint16_t {
            return static_cast<uint16_t>(wide + static_cast<uint16_t>(narrow));
          });
    case 2:
      return RiscVBinaryVectorOp<uint32_t, uint32_t, uint16_t>(
          vector_state, inst, [](uint32_t wide, uint16_t narrow) -> uint32_t {
            return static_cast<uint32_t>(wide + static_cast<uint32_t>(narrow));
          });
    case 4:
      return RiscVBinaryVectorOp<uint64_t, uint64_t, uint32_t>(
          vector_state, inst, [](uint64_t wide, uint32_t narrow) -> uint64_t {
            return static_cast<uint64_t>(wide + static_cast<uint64_t>(narrow));
          });
    default:
      break;
  }
  vector_state->set_vector_exception();
  LOG(ERROR) << "Illegal SEW value";
}

// Widening unsigned subtract with a wide vs2 source (vwsubu.w): the
// zero-extended SEW-bit element of vs1 is subtracted from the 2*SEW-bit element
// of vs2, and the result wraps modulo 2^(2*SEW).
void Vwsubuw(const Instruction* inst) {
  auto* vector_state = static_cast<CheriotState*>(inst->state())->rv_vector();
  const int sew_bytes = vector_state->selected_element_width();
  // LMUL8 cannot be 64 (multiplier presumably reported as LMUL * 8).
  if (vector_state->vector_length_multiplier() > 32) {
    vector_state->set_vector_exception();
    LOG(ERROR)
        << "Vector length multiplier out of range for widening operation";
    return;
  }
  switch (sew_bytes) {
    case 1:
      return RiscVBinaryVectorOp<uint16_t, uint16_t, uint8_t>(
          vector_state, inst, [](uint16_t wide, uint8_t narrow) -> uint16_t {
            return static_cast<uint16_t>(wide - static_cast<uint16_t>(narrow));
          });
    case 2:
      return RiscVBinaryVectorOp<uint32_t, uint32_t, uint16_t>(
          vector_state, inst, [](uint32_t wide, uint16_t narrow) -> uint32_t {
            return static_cast<uint32_t>(wide - static_cast<uint32_t>(narrow));
          });
    case 4:
      return RiscVBinaryVectorOp<uint64_t, uint64_t, uint32_t>(
          vector_state, inst, [](uint64_t wide, uint32_t narrow) -> uint64_t {
            return static_cast<uint64_t>(wide - static_cast<uint64_t>(narrow));
          });
    default:
      break;
  }
  vector_state->set_vector_exception();
  LOG(ERROR) << "Illegal SEW value";
}

// Widening signed add with a wide vs2 source (vwadd.w): each 2*SEW-bit element
// of vs2 is added to the sign-extended SEW-bit element of vs1.
//
// As in the other widening add/sub routines here, the destination is handled
// as the unsigned wide type so that wrap-around is well defined and no
// unsigned-to-signed conversion of the result is required.
void Vwaddw(const Instruction* inst) {
  auto* rv_vector = static_cast<CheriotState*>(inst->state())->rv_vector();
  int sew = rv_vector->selected_element_width();
  // LMUL8 cannot be 64.
  if (rv_vector->vector_length_multiplier() > 32) {
    rv_vector->set_vector_exception();
    LOG(ERROR)
        << "Vector length multiplier out of range for widening operation";
    return;
  }
  switch (sew) {
    case 1:
      // Destination type is uint16_t, matching the lambda's return type and
      // the SEW=2/4 cases (previously int16_t, which forced a conversion of
      // the unsigned sum to a signed type).
      return RiscVBinaryVectorOp<uint16_t, uint16_t, int8_t>(
          rv_vector, inst, [](uint16_t vs2, int8_t vs1) -> uint16_t {
            return vs2 + static_cast<uint16_t>(static_cast<int16_t>(vs1));
          });
    case 2:
      return RiscVBinaryVectorOp<uint32_t, uint32_t, int16_t>(
          rv_vector, inst, [](uint32_t vs2, int16_t vs1) -> uint32_t {
            return vs2 + static_cast<uint32_t>(static_cast<int32_t>(vs1));
          });
    case 4:
      return RiscVBinaryVectorOp<uint64_t, uint64_t, int32_t>(
          rv_vector, inst, [](uint64_t vs2, int32_t vs1) -> uint64_t {
            return vs2 + static_cast<uint64_t>(static_cast<int64_t>(vs1));
          });
    default:
      rv_vector->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      return;
  }
}

// Widening signed subtract with a wide vs2 source (vwsub.w): the sign-extended
// SEW-bit element of vs1 is subtracted from the 2*SEW-bit element of vs2. The
// arithmetic is done on the unsigned wide type so wrap-around is well defined.
void Vwsubw(const Instruction* inst) {
  auto* vector_state = static_cast<CheriotState*>(inst->state())->rv_vector();
  const int sew_bytes = vector_state->selected_element_width();
  // LMUL8 cannot be 64 (multiplier presumably reported as LMUL * 8).
  if (vector_state->vector_length_multiplier() > 32) {
    vector_state->set_vector_exception();
    LOG(ERROR)
        << "Vector length multiplier out of range for widening operation";
    return;
  }
  switch (sew_bytes) {
    case 1:
      RiscVBinaryVectorOp<uint16_t, uint16_t, int8_t>(
          vector_state, inst, [](uint16_t wide, int8_t narrow) -> uint16_t {
            const auto extended =
                static_cast<uint16_t>(static_cast<int16_t>(narrow));
            return wide - extended;
          });
      break;
    case 2:
      RiscVBinaryVectorOp<uint32_t, uint32_t, int16_t>(
          vector_state, inst, [](uint32_t wide, int16_t narrow) -> uint32_t {
            const auto extended =
                static_cast<uint32_t>(static_cast<int32_t>(narrow));
            return wide - extended;
          });
      break;
    case 4:
      RiscVBinaryVectorOp<uint64_t, uint64_t, int32_t>(
          vector_state, inst, [](uint64_t wide, int32_t narrow) -> uint64_t {
            const auto extended =
                static_cast<uint64_t>(static_cast<int64_t>(narrow));
            return wide - extended;
          });
      break;
    default:
      vector_state->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      break;
  }
}

// Widening multiply helper: converts both operands to WideType<T>::type and
// returns their product in that wide type.
template <typename T>
inline typename WideType<T>::type VwmulHelper(T vs2, T vs1) {
  using Wide = typename WideType<T>::type;
  const Wide lhs = static_cast<Wide>(vs2);
  const Wide rhs = static_cast<Wide>(vs1);
  // The multiplication may happen in a promoted type (e.g. int for 16-bit
  // Wide); the result is converted back to Wide.
  return static_cast<Wide>(lhs * rhs);
}

// Unsigned widening multiply (vwmulu): each pair of unsigned SEW-bit elements
// from vs2 and vs1 is multiplied into an unsigned 2*SEW-bit product.
void Vwmulu(const Instruction* inst) {
  auto* rv_vector = static_cast<CheriotState*>(inst->state())->rv_vector();
  int sew = rv_vector->selected_element_width();
  // LMUL8 cannot be 64.
  if (rv_vector->vector_length_multiplier() > 32) {
    rv_vector->set_vector_exception();
    LOG(ERROR)
        << "Vector length multiplier out of range for widening operation";
    return;
  }
  switch (sew) {
    case 1:
      // Both operands are uint8_t, matching the op's Vs2/Vs1 types (vs1 was
      // previously declared int8_t, needlessly round-tripping through a
      // signed type).
      return RiscVBinaryVectorOp<uint16_t, uint8_t, uint8_t>(
          rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint16_t {
            return VwmulHelper<uint8_t>(vs2, vs1);
          });
    case 2:
      return RiscVBinaryVectorOp<uint32_t, uint16_t, uint16_t>(
          rv_vector, inst, [](uint16_t vs2, uint16_t vs1) -> uint32_t {
            return VwmulHelper<uint16_t>(vs2, vs1);
          });
    case 4:
      return RiscVBinaryVectorOp<uint64_t, uint32_t, uint32_t>(
          rv_vector, inst, [](uint32_t vs2, uint32_t vs1) -> uint64_t {
            return VwmulHelper<uint32_t>(vs2, vs1);
          });
    default:
      rv_vector->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      return;
  }
}

// Widening signed-unsigned multiply helper: returns the product of a signed
// SEW-bit vs2 and an unsigned SEW-bit vs1 as a signed 2*SEW-bit value.
//
// Every value of the narrow unsigned operand fits in the wide signed type, so
// both operands are widened to that signed type and multiplied there. The
// product's magnitude is at most 2^(SEW-1) * (2^SEW - 1) < 2^(2*SEW-1), so the
// signed multiplication cannot overflow. This avoids mixing signed and
// unsigned wide operands, which previously made the multiply unsigned and then
// converted an out-of-range unsigned result to signed (implementation-defined
// before C++20).
template <typename T>
inline typename WideType<typename std::make_signed<T>::type>::type
VwmulSuHelper(typename std::make_signed<T>::type vs2,
              typename std::make_unsigned<T>::type vs1) {
  using WST = typename WideType<typename std::make_signed<T>::type>::type;
  WST vs2_w = static_cast<WST>(vs2);
  WST vs1_w = static_cast<WST>(vs1);
  WST prod = static_cast<WST>(vs2_w * vs1_w);
  return prod;
}

// Widening multiply signed-unsigned (vwmulsu): each signed SEW-bit element of
// vs2 is multiplied by the unsigned SEW-bit element of vs1, giving a signed
// 2*SEW-bit product.
void Vwmulsu(const Instruction* inst) {
  auto* rv_vector = static_cast<CheriotState*>(inst->state())->rv_vector();
  int sew = rv_vector->selected_element_width();
  // LMUL8 cannot be 64.
  if (rv_vector->vector_length_multiplier() > 32) {
    rv_vector->set_vector_exception();
    LOG(ERROR)
        << "Vector length multiplier out of range for widening operation";
    return;
  }
  // The lambdas take vs1 as the unsigned type, matching the op's Vs1 template
  // argument and VwmulSuHelper's vs1 parameter (it was previously declared
  // signed, needlessly round-tripping the value through a signed type).
  switch (sew) {
    case 1:
      return RiscVBinaryVectorOp<int16_t, int8_t, uint8_t>(
          rv_vector, inst, [](int8_t vs2, uint8_t vs1) -> int16_t {
            return VwmulSuHelper<int8_t>(vs2, vs1);
          });
    case 2:
      return RiscVBinaryVectorOp<int32_t, int16_t, uint16_t>(
          rv_vector, inst, [](int16_t vs2, uint16_t vs1) -> int32_t {
            return VwmulSuHelper<int16_t>(vs2, vs1);
          });
    case 4:
      return RiscVBinaryVectorOp<int64_t, int32_t, uint32_t>(
          rv_vector, inst, [](int32_t vs2, uint32_t vs1) -> int64_t {
            return VwmulSuHelper<int32_t>(vs2, vs1);
          });
    default:
      rv_vector->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      return;
  }
}

// Widening signed multiply (vwmul): each pair of signed SEW-bit elements from
// vs2 and vs1 is multiplied into a signed 2*SEW-bit product.
void Vwmul(const Instruction* inst) {
  auto* vector_state = static_cast<CheriotState*>(inst->state())->rv_vector();
  const int sew_bytes = vector_state->selected_element_width();
  // LMUL8 cannot be 64 (multiplier presumably reported as LMUL * 8).
  if (vector_state->vector_length_multiplier() > 32) {
    vector_state->set_vector_exception();
    LOG(ERROR)
        << "Vector length multiplier out of range for widening operation";
    return;
  }
  // VwmulHelper's T is deduced from the (identical) lambda parameter types.
  switch (sew_bytes) {
    case 1:
      RiscVBinaryVectorOp<int16_t, int8_t, int8_t>(
          vector_state, inst, [](int8_t lhs, int8_t rhs) -> int16_t {
            return VwmulHelper(lhs, rhs);
          });
      break;
    case 2:
      RiscVBinaryVectorOp<int32_t, int16_t, int16_t>(
          vector_state, inst, [](int16_t lhs, int16_t rhs) -> int32_t {
            return VwmulHelper(lhs, rhs);
          });
      break;
    case 4:
      RiscVBinaryVectorOp<int64_t, int32_t, int32_t>(
          vector_state, inst, [](int32_t lhs, int32_t rhs) -> int64_t {
            return VwmulHelper(lhs, rhs);
          });
      break;
    default:
      vector_state->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      break;
  }
}

// Widening multiply accumulate helper: returns vd + vs2 * vs1, with both
// sources first converted to the destination type Vd.
//
// The product is formed in Vd. For the callers in this file Vd is twice as
// wide as the sources, so the product always fits. The accumulation is done
// in the unsigned counterpart of Vd, where wrap-around is well defined. The
// resulting bits are then reinterpreted as Vd with absl::bit_cast. Previously
// the unsigned sum was implicitly converted back to a possibly signed Vd,
// which is implementation-defined for out-of-range values before C++20.
template <typename Vd, typename Vs2, typename Vs1>
Vd VwmaccHelper(Vs2 vs2, Vs1 vs1, Vd vd) {
  using UVd = typename std::make_unsigned<Vd>::type;
  Vd vs2_w = static_cast<Vd>(vs2);
  Vd vs1_w = static_cast<Vd>(vs1);
  Vd prod = static_cast<Vd>(vs2_w * vs1_w);
  // For 8/16-bit UVd the operands are promoted to int before adding, so the
  // sum is explicitly truncated back to UVd (modular, well defined).
  UVd sum = static_cast<UVd>(absl::bit_cast<UVd>(prod) +
                             absl::bit_cast<UVd>(vd));
  return absl::bit_cast<Vd>(sum);
}

// Unsigned widening multiply-add (vwmaccu): vd = vd + vs2 * vs1, where vs2 and
// vs1 are unsigned SEW-bit elements and vd is an unsigned 2*SEW-bit element.
void Vwmaccu(const Instruction* inst) {
  auto* vector_state = static_cast<CheriotState*>(inst->state())->rv_vector();
  const int sew_bytes = vector_state->selected_element_width();
  // LMUL8 cannot be 64 (multiplier presumably reported as LMUL * 8).
  if (vector_state->vector_length_multiplier() > 32) {
    vector_state->set_vector_exception();
    LOG(ERROR)
        << "Vector length multiplier out of range for widening operation";
    return;
  }
  // VwmaccHelper template arguments are <Vd, Vs2, Vs1>.
  switch (sew_bytes) {
    case 1:
      RiscVTernaryVectorOp<uint16_t, uint8_t, uint8_t>(
          vector_state, inst,
          [](uint8_t vs2, uint8_t vs1, uint16_t vd) -> uint16_t {
            return VwmaccHelper<uint16_t, uint8_t, uint8_t>(vs2, vs1, vd);
          });
      break;
    case 2:
      RiscVTernaryVectorOp<uint32_t, uint16_t, uint16_t>(
          vector_state, inst,
          [](uint16_t vs2, uint16_t vs1, uint32_t vd) -> uint32_t {
            return VwmaccHelper<uint32_t, uint16_t, uint16_t>(vs2, vs1, vd);
          });
      break;
    case 4:
      RiscVTernaryVectorOp<uint64_t, uint32_t, uint32_t>(
          vector_state, inst,
          [](uint32_t vs2, uint32_t vs1, uint64_t vd) -> uint64_t {
            return VwmaccHelper<uint64_t, uint32_t, uint32_t>(vs2, vs1, vd);
          });
      break;
    default:
      vector_state->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      break;
  }
}

// Signed widening multiply-add (vwmacc): vd = vd + vs2 * vs1, where vs2 and vs1
// are signed SEW-bit elements and vd is a signed 2*SEW-bit element.
void Vwmacc(const Instruction* inst) {
  auto* vector_state = static_cast<CheriotState*>(inst->state())->rv_vector();
  const int sew_bytes = vector_state->selected_element_width();
  // LMUL8 cannot be 64 (multiplier presumably reported as LMUL * 8).
  if (vector_state->vector_length_multiplier() > 32) {
    vector_state->set_vector_exception();
    LOG(ERROR)
        << "Vector length multiplier out of range for widening operation";
    return;
  }
  // VwmaccHelper template arguments are <Vd, Vs2, Vs1>.
  switch (sew_bytes) {
    case 1:
      RiscVTernaryVectorOp<int16_t, int8_t, int8_t>(
          vector_state, inst,
          [](int8_t vs2, int8_t vs1, int16_t vd) -> int16_t {
            return VwmaccHelper<int16_t, int8_t, int8_t>(vs2, vs1, vd);
          });
      break;
    case 2:
      RiscVTernaryVectorOp<int32_t, int16_t, int16_t>(
          vector_state, inst,
          [](int16_t vs2, int16_t vs1, int32_t vd) -> int32_t {
            return VwmaccHelper<int32_t, int16_t, int16_t>(vs2, vs1, vd);
          });
      break;
    case 4:
      RiscVTernaryVectorOp<int64_t, int32_t, int32_t>(
          vector_state, inst,
          [](int32_t vs2, int32_t vs1, int64_t vd) -> int64_t {
            return VwmaccHelper<int64_t, int32_t, int32_t>(vs2, vs1, vd);
          });
      break;
    default:
      vector_state->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      break;
  }
}

// Widening unsigned-signed multiply-add (vwmaccus): vd = vd + vs2 * vs1, where
// vs2 is a signed SEW-bit element, vs1 is an unsigned SEW-bit element and vd is
// a signed 2*SEW-bit element.
void Vwmaccus(const Instruction* inst) {
  auto* vector_state = static_cast<CheriotState*>(inst->state())->rv_vector();
  const int sew_bytes = vector_state->selected_element_width();
  // LMUL8 cannot be 64 (multiplier presumably reported as LMUL * 8).
  if (vector_state->vector_length_multiplier() > 32) {
    vector_state->set_vector_exception();
    LOG(ERROR)
        << "Vector length multiplier out of range for widening operation";
    return;
  }
  // VwmaccHelper template arguments are <Vd, Vs2, Vs1>.
  switch (sew_bytes) {
    case 1:
      RiscVTernaryVectorOp<int16_t, int8_t, uint8_t>(
          vector_state, inst,
          [](int8_t vs2, uint8_t vs1, int16_t vd) -> int16_t {
            return VwmaccHelper<int16_t, int8_t, uint8_t>(vs2, vs1, vd);
          });
      break;
    case 2:
      RiscVTernaryVectorOp<int32_t, int16_t, uint16_t>(
          vector_state, inst,
          [](int16_t vs2, uint16_t vs1, int32_t vd) -> int32_t {
            return VwmaccHelper<int32_t, int16_t, uint16_t>(vs2, vs1, vd);
          });
      break;
    case 4:
      RiscVTernaryVectorOp<int64_t, int32_t, uint32_t>(
          vector_state, inst,
          [](int32_t vs2, uint32_t vs1, int64_t vd) -> int64_t {
            return VwmaccHelper<int64_t, int32_t, uint32_t>(vs2, vs1, vd);
          });
      break;
    default:
      vector_state->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      break;
  }
}

// Widening signed-unsigned multiply-add (vwmaccsu): vd = vd + vs2 * vs1, where
// vs2 is an unsigned SEW-bit element, vs1 is a signed SEW-bit element and vd is
// a signed 2*SEW-bit element.
void Vwmaccsu(const Instruction* inst) {
  auto* vector_state = static_cast<CheriotState*>(inst->state())->rv_vector();
  const int sew_bytes = vector_state->selected_element_width();
  // LMUL8 cannot be 64 (multiplier presumably reported as LMUL * 8).
  if (vector_state->vector_length_multiplier() > 32) {
    vector_state->set_vector_exception();
    LOG(ERROR)
        << "Vector length multiplier out of range for widening operation";
    return;
  }
  // VwmaccHelper template arguments are <Vd, Vs2, Vs1>.
  switch (sew_bytes) {
    case 1:
      RiscVTernaryVectorOp<int16_t, uint8_t, int8_t>(
          vector_state, inst,
          [](uint8_t vs2, int8_t vs1, int16_t vd) -> int16_t {
            return VwmaccHelper<int16_t, uint8_t, int8_t>(vs2, vs1, vd);
          });
      break;
    case 2:
      RiscVTernaryVectorOp<int32_t, uint16_t, int16_t>(
          vector_state, inst,
          [](uint16_t vs2, int16_t vs1, int32_t vd) -> int32_t {
            return VwmaccHelper<int32_t, uint16_t, int16_t>(vs2, vs1, vd);
          });
      break;
    case 4:
      RiscVTernaryVectorOp<int64_t, uint32_t, int32_t>(
          vector_state, inst,
          [](uint32_t vs2, int32_t vs1, int64_t vd) -> int64_t {
            return VwmaccHelper<int64_t, uint32_t, int32_t>(vs2, vs1, vd);
          });
      break;
    default:
      vector_state->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      break;
  }
}

}  // namespace cheriot
}  // namespace sim
}  // namespace mpact
