| // Copyright 2023 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // https://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "cheriot/riscv_cheriot_vector_opm_instructions.h" |
| |
| #include <cstdint> |
| #include <functional> |
| #include <limits> |
| #include <type_traits> |
| |
| #include "absl/log/log.h" |
| #include "cheriot/cheriot_state.h" |
| #include "cheriot/cheriot_vector_state.h" |
| #include "cheriot/riscv_cheriot_vector_instruction_helpers.h" |
| #include "mpact/sim/generic/type_helpers.h" |
| #include "riscv//riscv_register.h" |
| |
| namespace mpact { |
| namespace sim { |
| namespace cheriot { |
| |
| using ::mpact::sim::generic::WideType; |
| |
| // Helper function used to factor out some code from Vaadd* instructions. |
| template <typename T> |
| inline T VaaddHelper(CheriotVectorState *rv_vector, T vs2, T vs1) { |
| // Perform the addition using a wider type, then shift and round. |
| using WT = typename WideType<T>::type; |
| WT vs2_w = static_cast<WT>(vs2); |
| WT vs1_w = static_cast<WT>(vs1); |
| auto res = RoundOff(rv_vector, vs2_w + vs1_w, 1); |
| return static_cast<T>(res); |
| } |
| |
| // Average unsigned add. The two sources are added, then shifted right by one |
| // and rounded. |
| void Vaaddu(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>( |
| rv_vector, inst, [rv_vector](uint8_t vs2, uint8_t vs1) -> uint8_t { |
| return VaaddHelper(rv_vector, vs2, vs1); |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>( |
| rv_vector, inst, [rv_vector](uint16_t vs2, uint16_t vs1) -> uint16_t { |
| return VaaddHelper(rv_vector, vs2, vs1); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>( |
| rv_vector, inst, [rv_vector](uint32_t vs2, uint32_t vs1) -> uint32_t { |
| return VaaddHelper(rv_vector, vs2, vs1); |
| }); |
| case 8: |
| return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>( |
| rv_vector, inst, [rv_vector](uint64_t vs2, uint64_t vs1) -> uint64_t { |
| return VaaddHelper(rv_vector, vs2, vs1); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Average signed add. The two sources are added, then shifted right by one and |
| // rounded. |
| void Vaadd(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<int8_t, int8_t, int8_t>( |
| rv_vector, inst, [rv_vector](int8_t vs2, int8_t vs1) -> int8_t { |
| return VaaddHelper(rv_vector, vs2, vs1); |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<int16_t, int16_t, int16_t>( |
| rv_vector, inst, [rv_vector](int16_t vs2, int16_t vs1) -> int16_t { |
| return VaaddHelper(rv_vector, vs2, vs1); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<int32_t, int32_t, int32_t>( |
| rv_vector, inst, [rv_vector](int32_t vs2, int32_t vs1) -> int32_t { |
| return VaaddHelper(rv_vector, vs2, vs1); |
| }); |
| case 8: |
| return RiscVBinaryVectorOp<int64_t, int64_t, int64_t>( |
| rv_vector, inst, [rv_vector](int64_t vs2, int64_t vs1) -> int64_t { |
| return VaaddHelper(rv_vector, vs2, vs1); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Helper function for Vasub* instructions. Subract using a wider type, then |
| // round. |
| template <typename T> |
| inline T VasubHelper(CheriotVectorState *rv_vector, T vs2, T vs1) { |
| using WT = typename WideType<T>::type; |
| WT vs2_w = static_cast<WT>(vs2); |
| WT vs1_w = static_cast<WT>(vs1); |
| auto res = RoundOff(rv_vector, vs2_w - vs1_w, 1); |
| return static_cast<T>(res); |
| } |
| |
| // Averaging unsigned subtract - subtract then shift right by 1 and round. |
| void Vasubu(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>( |
| rv_vector, inst, [rv_vector](uint8_t vs2, uint8_t vs1) -> uint8_t { |
| return VasubHelper(rv_vector, vs2, vs1); |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>( |
| rv_vector, inst, [rv_vector](uint16_t vs2, uint16_t vs1) -> uint16_t { |
| return VasubHelper(rv_vector, vs2, vs1); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>( |
| rv_vector, inst, [rv_vector](uint32_t vs2, uint32_t vs1) -> uint32_t { |
| return VasubHelper(rv_vector, vs2, vs1); |
| }); |
| case 8: |
| return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>( |
| rv_vector, inst, [rv_vector](uint64_t vs2, uint64_t vs1) -> uint64_t { |
| return VasubHelper(rv_vector, vs2, vs1); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Averaging signed subtract. Subtract then shift right by 1 and round. |
| void Vasub(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<int8_t, int8_t, int8_t>( |
| rv_vector, inst, [rv_vector](int8_t vs2, int8_t vs1) -> int8_t { |
| return VasubHelper(rv_vector, vs2, vs1); |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<int16_t, int16_t, int16_t>( |
| rv_vector, inst, [rv_vector](int16_t vs2, int16_t vs1) -> int16_t { |
| return VasubHelper(rv_vector, vs2, vs1); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<int32_t, int32_t, int32_t>( |
| rv_vector, inst, [rv_vector](int32_t vs2, int32_t vs1) -> int32_t { |
| return VasubHelper(rv_vector, vs2, vs1); |
| }); |
| case 8: |
| return RiscVBinaryVectorOp<int64_t, int64_t, int64_t>( |
| rv_vector, inst, [rv_vector](int64_t vs2, int64_t vs1) -> int64_t { |
| return VasubHelper(rv_vector, vs2, vs1); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Mask operands only operate on a single vector register. This helper function |
| // is used by the following bitwise mask manipulation instruction semantic |
| // functions. |
| static inline void BitwiseMaskBinaryOp( |
| CheriotVectorState *rv_vector, const Instruction *inst, |
| std::function<uint8_t(uint8_t, uint8_t)> op) { |
| if (rv_vector->vector_exception()) return; |
| int vstart = rv_vector->vstart(); |
| int vlen = rv_vector->vector_length(); |
| // Get spans for vector source and destination registers. |
| auto *vs2_op = static_cast<RV32VectorSourceOperand *>(inst->Source(0)); |
| auto vs2_span = vs2_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); |
| auto *vs1_op = static_cast<RV32VectorSourceOperand *>(inst->Source(1)); |
| auto vs1_span = vs1_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); |
| auto *vd_op = |
| static_cast<RV32VectorDestinationOperand *>(inst->Destination(0)); |
| auto *vd_db = vd_op->CopyDataBuffer(); |
| auto vd_span = vd_db->Get<uint8_t>(); |
| // Compute start and end locations. |
| int start_byte = vstart / 8; |
| int start_offset = vstart % 8; |
| uint8_t start_mask = 0b1111'1111 << start_offset; |
| int end_byte = (vlen - 1) / 8; |
| int end_offset = (vlen - 1) % 8; |
| uint8_t end_mask = 0b1111'1111 >> (7 - end_offset); |
| // The start byte is computed first, applying a mask to mask out any preceding |
| // bits. |
| vd_span[start_byte] = |
| (op(vs2_span[start_byte], vs1_span[start_byte]) & start_mask) | |
| (vd_span[start_byte] & ~start_mask); |
| // Perform the bitwise operation on each byte between start and end. |
| for (int i = start_byte + 1; i < end_byte; i++) { |
| vd_span[i] = op(vs2_span[i], vs1_span[i]); |
| } |
| // Perform the bitwise operation with a mask on the end byte. |
| vd_span[end_byte] = (op(vs2_span[end_byte], vs1_span[end_byte]) & end_mask) | |
| (vd_span[end_byte] & ~end_mask); |
| vd_db->Submit(); |
| rv_vector->clear_vstart(); |
| } |
| |
// Bitwise vector mask instructions. Each function's name states the operation
// it performs.
| void Vmandnot(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| BitwiseMaskBinaryOp(rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint8_t { |
| return vs2 & ~vs1; |
| }); |
| } |
| |
| void Vmand(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| BitwiseMaskBinaryOp(rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint8_t { |
| return vs2 & vs1; |
| }); |
| } |
| void Vmor(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| BitwiseMaskBinaryOp(rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint8_t { |
| return vs2 | vs1; |
| }); |
| } |
| void Vmxor(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| BitwiseMaskBinaryOp(rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint8_t { |
| return vs2 ^ vs1; |
| }); |
| } |
| void Vmornot(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| BitwiseMaskBinaryOp(rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint8_t { |
| return vs2 | ~vs1; |
| }); |
| } |
| void Vmnand(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| BitwiseMaskBinaryOp(rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint8_t { |
| return ~(vs2 & vs1); |
| }); |
| } |
| void Vmnor(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| BitwiseMaskBinaryOp(rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint8_t { |
| return ~(vs2 | vs1); |
| }); |
| } |
| void Vmxnor(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| BitwiseMaskBinaryOp(rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint8_t { |
| return ~(vs2 ^ vs1); |
| }); |
| } |
| |
| // Vector unsigned divide. Note, just like the scalar divide instruction, a |
| // divide by zero does not cause an exception, instead it returns all 1s. |
| void Vdivu(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>( |
| rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint8_t { |
| if (vs1 == 0) return ~vs1; |
| return vs2 / vs1; |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>( |
| rv_vector, inst, [](uint16_t vs2, uint16_t vs1) -> uint16_t { |
| if (vs1 == 0) return ~vs1; |
| return vs2 / vs1; |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>( |
| rv_vector, inst, [](uint32_t vs2, uint32_t vs1) -> uint32_t { |
| if (vs1 == 0) return ~vs1; |
| return vs2 / vs1; |
| }); |
| case 8: |
| return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>( |
| rv_vector, inst, [](uint64_t vs2, uint64_t vs1) -> uint64_t { |
| if (vs1 == 0) return ~vs1; |
| return vs2 / vs1; |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Signed divide. Divide by 0 returns all 1s. If -1 is divided by the largest |
| // magnitude negative number, it returns that negative number. |
| void Vdiv(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<int8_t, int8_t, int8_t>( |
| rv_vector, inst, [](int8_t vs2, int8_t vs1) -> int8_t { |
| if (vs1 == 0) return static_cast<int8_t>(-1); |
| if ((vs1 == -1) && (vs2 == std::numeric_limits<int8_t>::min())) { |
| return std::numeric_limits<int8_t>::min(); |
| } |
| return vs2 / vs1; |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<int16_t, int16_t, int16_t>( |
| rv_vector, inst, [](int16_t vs2, int16_t vs1) -> int16_t { |
| if (vs1 == 0) return static_cast<int16_t>(-1); |
| if ((vs1 == -1) && (vs2 == std::numeric_limits<int16_t>::min())) { |
| return std::numeric_limits<int16_t>::min(); |
| } |
| return vs2 / vs1; |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<int32_t, int32_t, int32_t>( |
| rv_vector, inst, [](int32_t vs2, int32_t vs1) -> int32_t { |
| if (vs1 == 0) return static_cast<int32_t>(-1); |
| if ((vs1 == -1) && (vs2 == std::numeric_limits<int32_t>::min())) { |
| return std::numeric_limits<int32_t>::min(); |
| } |
| return vs2 / vs1; |
| }); |
| case 8: |
| return RiscVBinaryVectorOp<int64_t, int64_t, int64_t>( |
| rv_vector, inst, [](int64_t vs2, int64_t vs1) -> int64_t { |
| if (vs1 == 0) return static_cast<int64_t>(-1); |
| if ((vs1 == -1) && (vs2 == std::numeric_limits<int64_t>::min())) { |
| return std::numeric_limits<int64_t>::min(); |
| } |
| return vs2 / vs1; |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Unsigned remainder. If the denominator is 0, it returns the enumerator. |
| void Vremu(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>( |
| rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint8_t { |
| if (vs1 == 0) return vs2; |
| return vs2 % vs1; |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>( |
| rv_vector, inst, [](uint16_t vs2, uint16_t vs1) -> uint16_t { |
| if (vs1 == 0) return vs2; |
| return vs2 % vs1; |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>( |
| rv_vector, inst, [](uint32_t vs2, uint32_t vs1) -> uint32_t { |
| if (vs1 == 0) return vs2; |
| return vs2 % vs1; |
| }); |
| case 8: |
| return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>( |
| rv_vector, inst, [](uint64_t vs2, uint64_t vs1) -> uint64_t { |
| if (vs1 == 0) return vs2; |
| return vs2 % vs1; |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Signed remainder. If the denominator is 0, it returns the enumerator. |
| void Vrem(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<int8_t, int8_t, int8_t>( |
| rv_vector, inst, [](int8_t vs2, int8_t vs1) -> int8_t { |
| if (vs1 == 0) return vs2; |
| return vs2 % vs1; |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<int16_t, int16_t, int16_t>( |
| rv_vector, inst, [](int16_t vs2, int16_t vs1) -> int16_t { |
| if (vs1 == 0) return vs2; |
| return vs2 % vs1; |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<int32_t, int32_t, int32_t>( |
| rv_vector, inst, [](int32_t vs2, int32_t vs1) -> int32_t { |
| if (vs1 == 0) return vs2; |
| return vs2 % vs1; |
| }); |
| case 8: |
| return RiscVBinaryVectorOp<int64_t, int64_t, int64_t>( |
| rv_vector, inst, [](int64_t vs2, int64_t vs1) -> int64_t { |
| if (vs1 == 0) return vs2; |
| return vs2 % vs1; |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Helper function for multiply high. It promotes the to arguments to wider |
| // types, performs the multiplication, returns the high half of the result. |
| template <typename T> |
| inline T VmulHighHelper(T vs2, T vs1) { |
| using WT = typename WideType<T>::type; |
| WT vs2_w = static_cast<WT>(vs2); |
| WT vs1_w = static_cast<WT>(vs1); |
| WT prod = vs2_w * vs1_w; |
| prod >>= sizeof(T) * 8; |
| return static_cast<T>(prod); |
| } |
| |
| // Multiply high, unsigned. |
| void Vmulhu(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>( |
| rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint8_t { |
| return VmulHighHelper(vs2, vs1); |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>( |
| rv_vector, inst, [](uint16_t vs2, uint16_t vs1) -> uint16_t { |
| return VmulHighHelper(vs2, vs1); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>( |
| rv_vector, inst, [](uint32_t vs2, uint32_t vs1) -> uint32_t { |
| return VmulHighHelper(vs2, vs1); |
| }); |
| case 8: |
| return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>( |
| rv_vector, inst, [](uint64_t vs2, uint64_t vs1) -> uint64_t { |
| return VmulHighHelper(vs2, vs1); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Signed multiply. Note, that signed and unsigned multiply operations have the |
| // same result for the low half of the product. |
| void Vmul(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>( |
| rv_vector, inst, |
| [](uint8_t vs2, uint8_t vs1) -> uint8_t { return vs2 * vs1; }); |
| case 2: |
| return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>( |
| rv_vector, inst, [](uint16_t vs2, uint16_t vs1) -> uint16_t { |
| uint32_t vs2_32 = vs2; |
| uint32_t vs1_32 = vs1; |
| return static_cast<uint16_t>(vs2_32 * vs1_32); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>( |
| rv_vector, inst, |
| [](uint32_t vs2, uint32_t vs1) -> uint32_t { return vs2 * vs1; }); |
| case 8: |
| // The 64 bit version is treated a little differently. Because the vs1 |
| // operand may come from a register which may be 32 bits wide, it's first |
| // converted to int64_t. Then the product is done on unsigned numbers to |
| // avoid a signed multiply overflow, and returned as a signed number. |
| return RiscVBinaryVectorOp<int64_t, int64_t, int64_t>( |
| rv_vector, inst, [](int64_t vs2, int64_t vs1) -> int64_t { |
| uint64_t vs2_u = vs2; |
| uint64_t vs1_u = vs1; |
| uint64_t prod = vs2_u * vs1_u; |
| return prod; |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Helper for signed-unsigned multiplication return high half. |
| template <typename T> |
| inline typename std::make_signed<T>::type VmulHighSUHelper( |
| typename std::make_signed<T>::type vs2, |
| typename std::make_unsigned<T>::type vs1) { |
| using WT = typename WideType<T>::type; |
| using WST = typename WideType<typename std::make_signed<T>::type>::type; |
| WST vs2_w = static_cast<WST>(vs2); |
| WT vs1_w = static_cast<WT>(vs1); |
| WST prod = vs2_w * vs1_w; |
| prod >>= sizeof(T) * 8; |
| return static_cast<typename std::make_signed<T>::type>(prod); |
| } |
| |
| // Multiply signed unsigned and return the high half. |
| void Vmulhsu(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<int8_t, int8_t, uint8_t>( |
| rv_vector, inst, [](int8_t vs2, uint8_t vs1) -> int8_t { |
| return VmulHighSUHelper<int8_t>(vs2, vs1); |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<int16_t, int16_t, uint16_t>( |
| rv_vector, inst, [](int16_t vs2, uint16_t vs1) -> int16_t { |
| return VmulHighSUHelper<int16_t>(vs2, vs1); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<int32_t, int32_t, uint32_t>( |
| rv_vector, inst, [](int32_t vs2, uint32_t vs1) -> int32_t { |
| return VmulHighSUHelper<int32_t>(vs2, vs1); |
| }); |
| case 8: |
| return RiscVBinaryVectorOp<int64_t, int64_t, uint64_t>( |
| rv_vector, inst, [](int64_t vs2, uint64_t vs1) -> int64_t { |
| return VmulHighSUHelper<int64_t>(vs2, vs1); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Signed multiply, return high half. |
| void Vmulh(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<int8_t, int8_t, int8_t>( |
| rv_vector, inst, [](int8_t vs2, int8_t vs1) -> int8_t { |
| return VmulHighHelper(vs2, vs1); |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<int16_t, int16_t, int16_t>( |
| rv_vector, inst, [](int16_t vs2, int16_t vs1) -> int16_t { |
| return VmulHighHelper(vs2, vs1); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<int32_t, int32_t, int32_t>( |
| rv_vector, inst, [](int32_t vs2, int32_t vs1) -> int32_t { |
| return VmulHighHelper(vs2, vs1); |
| }); |
| case 8: |
| return RiscVBinaryVectorOp<int64_t, int64_t, int64_t>( |
| rv_vector, inst, [](int64_t vs2, int64_t vs1) -> int64_t { |
| return VmulHighHelper(vs2, vs1); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Multiply-add. |
| void Vmadd(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVTernaryVectorOp<uint8_t, uint8_t, uint8_t>( |
| rv_vector, inst, [](uint8_t vs2, uint8_t vs1, uint8_t vd) -> uint8_t { |
| uint8_t prod = vs1 * vd; |
| return prod + vs2; |
| }); |
| case 2: |
| return RiscVTernaryVectorOp<uint16_t, uint16_t, uint16_t>( |
| rv_vector, inst, |
| [](uint16_t vs2, uint16_t vs1, uint16_t vd) -> uint16_t { |
| uint32_t vs2_32 = vs2; |
| uint32_t vs1_32 = vs1; |
| uint32_t vd_32 = vd; |
| return static_cast<uint16_t>(vs1_32 * vd_32 + vs2_32); |
| }); |
| case 4: |
| return RiscVTernaryVectorOp<uint32_t, uint32_t, uint32_t>( |
| rv_vector, inst, |
| [](uint32_t vs2, uint32_t vs1, uint32_t vd) -> uint32_t { |
| return vs1 * vd + vs2; |
| }); |
| case 8: |
| return RiscVTernaryVectorOp<uint64_t, uint64_t, uint64_t>( |
| rv_vector, inst, |
| [](uint64_t vs2, uint64_t vs1, uint64_t vd) -> uint64_t { |
| return vs1 * vd + vs2; |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Negated multiply and add. |
| void Vnmsub(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVTernaryVectorOp<uint8_t, uint8_t, uint8_t>( |
| rv_vector, inst, [](uint8_t vs2, uint8_t vs1, uint8_t vd) -> uint8_t { |
| return -(vs1 * vd) + vs2; |
| }); |
| case 2: |
| return RiscVTernaryVectorOp<uint16_t, uint16_t, uint16_t>( |
| rv_vector, inst, |
| [](uint16_t vs2, uint16_t vs1, uint16_t vd) -> uint16_t { |
| uint32_t vs2_32 = vs2; |
| uint32_t vs1_32 = vs1; |
| uint32_t vd_32 = vd; |
| return static_cast<uint16_t>(-(vs1_32 * vd_32) + vs2_32); |
| }); |
| case 4: |
| return RiscVTernaryVectorOp<uint32_t, uint32_t, uint32_t>( |
| rv_vector, inst, |
| [](uint32_t vs2, uint32_t vs1, uint32_t vd) -> uint32_t { |
| return -(vs1 * vd) + vs2; |
| }); |
| case 8: |
| return RiscVTernaryVectorOp<uint64_t, uint64_t, uint64_t>( |
| rv_vector, inst, |
| [](uint64_t vs2, uint64_t vs1, uint64_t vd) -> uint64_t { |
| return -(vs1 * vd) + vs2; |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Multiply add overwriting the sum. |
| void Vmacc(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVTernaryVectorOp<uint8_t, uint8_t, uint8_t>( |
| rv_vector, inst, [](uint8_t vs2, uint8_t vs1, uint8_t vd) -> uint8_t { |
| return vs1 * vs2 + vd; |
| }); |
| case 2: |
| return RiscVTernaryVectorOp<uint16_t, uint16_t, uint16_t>( |
| rv_vector, inst, |
| [](uint16_t vs2, uint16_t vs1, uint16_t vd) -> uint16_t { |
| uint32_t vs2_32 = vs2; |
| uint32_t vs1_32 = vs1; |
| uint32_t vd_32 = vd; |
| return static_cast<uint16_t>(vs1_32 * vs2_32 + vd_32); |
| }); |
| case 4: |
| return RiscVTernaryVectorOp<uint32_t, uint32_t, uint32_t>( |
| rv_vector, inst, |
| [](uint32_t vs2, uint32_t vs1, uint32_t vd) -> uint32_t { |
| return vs1 * vs2 + vd; |
| }); |
| case 8: |
| return RiscVTernaryVectorOp<uint64_t, uint64_t, uint64_t>( |
| rv_vector, inst, |
| [](uint64_t vs2, uint64_t vs1, uint64_t vd) -> uint64_t { |
| return vs1 * vs2 + vd; |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Negated multiply add, overwriting sum. |
| void Vnmsac(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVTernaryVectorOp<uint8_t, uint8_t, uint8_t>( |
| rv_vector, inst, [](uint8_t vs2, uint8_t vs1, uint8_t vd) -> uint8_t { |
| return -(vs1 * vs2) + vd; |
| }); |
| case 2: |
| return RiscVTernaryVectorOp<uint16_t, uint16_t, uint16_t>( |
| rv_vector, inst, |
| [](uint16_t vs2, uint16_t vs1, uint16_t vd) -> uint16_t { |
| uint32_t vs2_32 = vs2; |
| uint32_t vs1_32 = vs1; |
| uint32_t vd_32 = vd; |
| return static_cast<uint16_t>(-(vs1_32 * vs2_32) + vd_32); |
| }); |
| case 4: |
| return RiscVTernaryVectorOp<uint32_t, uint32_t, uint32_t>( |
| rv_vector, inst, |
| [](uint32_t vs2, uint32_t vs1, uint32_t vd) -> uint32_t { |
| return -(vs1 * vs2) + vd; |
| }); |
| case 8: |
| return RiscVTernaryVectorOp<uint64_t, uint64_t, uint64_t>( |
| rv_vector, inst, |
| [](uint64_t vs2, uint64_t vs1, uint64_t vd) -> uint64_t { |
| return -(vs1 * vs2) + vd; |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Widening unsigned add. |
| void Vwaddu(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<uint16_t, uint8_t, uint8_t>( |
| rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint16_t { |
| return static_cast<uint16_t>(vs2) + static_cast<uint16_t>(vs1); |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<uint32_t, uint16_t, uint16_t>( |
| rv_vector, inst, [](uint16_t vs2, uint16_t vs1) -> uint32_t { |
| return static_cast<uint32_t>(vs2) + static_cast<uint32_t>(vs1); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<uint64_t, uint32_t, uint32_t>( |
| rv_vector, inst, [](uint32_t vs2, uint32_t vs1) -> uint64_t { |
| return static_cast<uint64_t>(vs2) + static_cast<uint64_t>(vs1); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Widening unsigned subtract. |
| void Vwsubu(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<uint16_t, uint8_t, uint8_t>( |
| rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint16_t { |
| return static_cast<uint16_t>(vs2) - static_cast<uint16_t>(vs1); |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<uint32_t, uint16_t, uint16_t>( |
| rv_vector, inst, [](uint16_t vs2, uint16_t vs1) -> uint32_t { |
| return static_cast<uint32_t>(vs2) - static_cast<uint32_t>(vs1); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<uint64_t, uint32_t, uint32_t>( |
| rv_vector, inst, [](uint32_t vs2, uint32_t vs1) -> uint64_t { |
| return static_cast<uint64_t>(vs2) - static_cast<uint64_t>(vs1); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Widening signed add. |
| void Vwadd(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| // LMUL8 cannot be 64. |
| if (rv_vector->vector_length_multiplier() > 32) { |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) |
| << "Vector length multiplier out of range for widening operation"; |
| return; |
| } |
| // The values are first sign extended to the wide signed value, then |
| // an unsigned addition is performed, for which overflow is not undefined, |
| // as opposed to signed additions. |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<uint16_t, int8_t, int8_t>( |
| rv_vector, inst, [](int8_t vs2, int8_t vs1) -> uint16_t { |
| return static_cast<uint16_t>(static_cast<int16_t>(vs2)) + |
| static_cast<uint16_t>(static_cast<int16_t>(vs1)); |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<uint32_t, int16_t, int16_t>( |
| rv_vector, inst, [](int16_t vs2, int16_t vs1) -> uint32_t { |
| return static_cast<uint32_t>(static_cast<int32_t>(vs2)) + |
| static_cast<uint32_t>(static_cast<int32_t>(vs1)); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<uint64_t, int32_t, int32_t>( |
| rv_vector, inst, [](int32_t vs2, int32_t vs1) -> uint64_t { |
| return static_cast<uint64_t>(static_cast<int64_t>(vs2)) + |
| static_cast<uint64_t>(static_cast<int64_t>(vs1)); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Widening signed subtract. |
| void Vwsub(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| // LMUL8 cannot be 64. |
| if (rv_vector->vector_length_multiplier() > 32) { |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) |
| << "Vector length multiplier out of range for widening operation"; |
| return; |
| } // The values are first sign extended to the wide signed value, then |
| // an unsigned subtraction is performed, for which overflow is not undefined, |
| // as opposed to signed subtraction. |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<uint16_t, int8_t, int8_t>( |
| rv_vector, inst, [](int8_t vs2, int8_t vs1) -> uint16_t { |
| return static_cast<uint16_t>(static_cast<int16_t>(vs2)) - |
| static_cast<uint16_t>(static_cast<int16_t>(vs1)); |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<uint32_t, int16_t, int16_t>( |
| rv_vector, inst, [](int16_t vs2, int16_t vs1) -> uint32_t { |
| return static_cast<uint32_t>(static_cast<int32_t>(vs2)) - |
| static_cast<uint32_t>(static_cast<int32_t>(vs1)); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<uint64_t, int32_t, int32_t>( |
| rv_vector, inst, [](int32_t vs2, int32_t vs1) -> uint64_t { |
| return static_cast<uint64_t>(static_cast<int64_t>(vs2)) - |
| static_cast<uint64_t>(static_cast<int64_t>(vs1)); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Widening unsigned add with wide source. |
| void Vwadduw(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| // LMUL8 cannot be 64. |
| if (rv_vector->vector_length_multiplier() > 32) { |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) |
| << "Vector length multiplier out of range for widening operation"; |
| return; |
| } |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<uint16_t, uint16_t, uint8_t>( |
| rv_vector, inst, [](uint16_t vs2, uint8_t vs1) -> uint16_t { |
| return vs2 + static_cast<uint16_t>(vs1); |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<uint32_t, uint32_t, uint16_t>( |
| rv_vector, inst, [](uint32_t vs2, uint16_t vs1) -> uint32_t { |
| return vs2 + static_cast<uint32_t>(vs1); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<uint64_t, uint64_t, uint32_t>( |
| rv_vector, inst, [](uint64_t vs2, uint32_t vs1) -> uint64_t { |
| return vs2 + static_cast<uint64_t>(vs1); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Widening unsigned subtract with wide source. |
| void Vwsubuw(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| // LMUL8 cannot be 64. |
| if (rv_vector->vector_length_multiplier() > 32) { |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) |
| << "Vector length multiplier out of range for widening operation"; |
| return; |
| } |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<uint16_t, uint16_t, uint8_t>( |
| rv_vector, inst, [](uint16_t vs2, uint8_t vs1) -> uint16_t { |
| return vs2 - static_cast<uint16_t>(vs1); |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<uint32_t, uint32_t, uint16_t>( |
| rv_vector, inst, [](uint32_t vs2, uint16_t vs1) -> uint32_t { |
| return vs2 - static_cast<uint32_t>(vs1); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<uint64_t, uint64_t, uint32_t>( |
| rv_vector, inst, [](uint64_t vs2, uint32_t vs1) -> uint64_t { |
| return vs2 - static_cast<uint64_t>(vs1); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Widening signed add with wide source. |
| void Vwaddw(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| // LMUL8 cannot be 64. |
| if (rv_vector->vector_length_multiplier() > 32) { |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) |
| << "Vector length multiplier out of range for widening operation"; |
| return; |
| } |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<int16_t, uint16_t, int8_t>( |
| rv_vector, inst, [](uint16_t vs2, int8_t vs1) -> uint16_t { |
| return vs2 + static_cast<uint16_t>(static_cast<int16_t>(vs1)); |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<uint32_t, uint32_t, int16_t>( |
| rv_vector, inst, [](uint32_t vs2, int16_t vs1) -> uint32_t { |
| return vs2 + static_cast<uint32_t>(static_cast<int32_t>(vs1)); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<uint64_t, uint64_t, int32_t>( |
| rv_vector, inst, [](uint64_t vs2, int32_t vs1) -> uint64_t { |
| return vs2 + static_cast<uint64_t>(static_cast<int64_t>(vs1)); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Widening signed subtract with wide source. |
| void Vwsubw(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| // LMUL8 cannot be 64. |
| if (rv_vector->vector_length_multiplier() > 32) { |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) |
| << "Vector length multiplier out of range for widening operation"; |
| return; |
| } |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<uint16_t, uint16_t, int8_t>( |
| rv_vector, inst, [](uint16_t vs2, int8_t vs1) -> uint16_t { |
| return vs2 - static_cast<uint16_t>(static_cast<int16_t>(vs1)); |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<uint32_t, uint32_t, int16_t>( |
| rv_vector, inst, [](uint32_t vs2, int16_t vs1) -> uint32_t { |
| return vs2 - static_cast<uint32_t>(static_cast<int32_t>(vs1)); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<uint64_t, uint64_t, int32_t>( |
| rv_vector, inst, [](uint64_t vs2, int32_t vs1) -> uint64_t { |
| return vs2 - static_cast<uint64_t>(static_cast<int64_t>(vs1)); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Widening multiply helper function. Factors out some code. |
| template <typename T> |
| inline typename WideType<T>::type VwmulHelper(T vs2, T vs1) { |
| using WT = typename WideType<T>::type; |
| WT vs2_w = static_cast<WT>(vs2); |
| WT vs1_w = static_cast<WT>(vs1); |
| WT prod = vs2_w * vs1_w; |
| return prod; |
| } |
| |
| // Unsigned widening multiply. |
| void Vwmulu(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| // LMUL8 cannot be 64. |
| if (rv_vector->vector_length_multiplier() > 32) { |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) |
| << "Vector length multiplier out of range for widening operation"; |
| return; |
| } |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<uint16_t, uint8_t, uint8_t>( |
| rv_vector, inst, [](uint8_t vs2, int8_t vs1) -> uint16_t { |
| return VwmulHelper<uint8_t>(vs2, vs1); |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<uint32_t, uint16_t, uint16_t>( |
| rv_vector, inst, [](uint16_t vs2, uint16_t vs1) -> uint32_t { |
| return VwmulHelper<uint16_t>(vs2, vs1); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<uint64_t, uint32_t, uint32_t>( |
| rv_vector, inst, [](uint32_t vs2, uint32_t vs1) -> uint64_t { |
| return VwmulHelper<uint32_t>(vs2, vs1); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Widening signed-unsigned multiply helper function. |
| template <typename T> |
| inline typename WideType<typename std::make_signed<T>::type>::type |
| VwmulSuHelper(typename std::make_signed<T>::type vs2, |
| typename std::make_unsigned<T>::type vs1) { |
| using WST = typename WideType<typename std::make_signed<T>::type>::type; |
| using WT = typename WideType<typename std::make_unsigned<T>::type>::type; |
| WST vs2_w = static_cast<WST>(vs2); |
| WT vs1_w = static_cast<WT>(vs1); |
| WST prod = vs2_w * vs1_w; |
| return prod; |
| } |
| |
| // Widening multiply signed-unsigned. |
| void Vwmulsu(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| // LMUL8 cannot be 64. |
| if (rv_vector->vector_length_multiplier() > 32) { |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) |
| << "Vector length multiplier out of range for widening operation"; |
| return; |
| } |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<int16_t, int8_t, uint8_t>( |
| rv_vector, inst, [](int8_t vs2, int8_t vs1) -> int16_t { |
| return VwmulSuHelper<int8_t>(vs2, vs1); |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<int32_t, int16_t, uint16_t>( |
| rv_vector, inst, [](int16_t vs2, int16_t vs1) -> int32_t { |
| return VwmulSuHelper<int16_t>(vs2, vs1); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<int64_t, int32_t, uint32_t>( |
| rv_vector, inst, [](int32_t vs2, int32_t vs1) -> int64_t { |
| return VwmulSuHelper<int32_t>(vs2, vs1); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Widening signed multiply. |
| void Vwmul(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| // LMUL8 cannot be 64. |
| if (rv_vector->vector_length_multiplier() > 32) { |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) |
| << "Vector length multiplier out of range for widening operation"; |
| return; |
| } |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<int16_t, int8_t, int8_t>( |
| rv_vector, inst, [](int8_t vs2, int8_t vs1) -> int16_t { |
| return VwmulHelper<int8_t>(vs2, vs1); |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<int32_t, int16_t, int16_t>( |
| rv_vector, inst, [](int16_t vs2, int16_t vs1) -> int32_t { |
| return VwmulHelper<int16_t>(vs2, vs1); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<int64_t, int32_t, int32_t>( |
| rv_vector, inst, [](int32_t vs2, int32_t vs1) -> int64_t { |
| return VwmulHelper<int32_t>(vs2, vs1); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
// Widening multiply accumulate helper function.
//
// Converts vs2 and vs1 to the destination type Vd, multiplies them, and adds
// the accumulator vd to the product. The accumulation is performed on the
// unsigned counterpart of Vd so that it wraps around instead of overflowing a
// signed type; the wrapped sum is then converted back to Vd.
//
// Only standard conversions are used: converting a signed integer to the
// unsigned type of the same width is well defined (modular) and yields the
// same value a bit-level reinterpretation would.
template <typename Vd, typename Vs2, typename Vs1>
Vd VwmaccHelper(Vs2 vs2, Vs1 vs1, Vd vd) {
  using UVd = typename std::make_unsigned<Vd>::type;
  const Vd vs2_w = static_cast<Vd>(vs2);
  const Vd vs1_w = static_cast<Vd>(vs1);
  const Vd prod = static_cast<Vd>(vs2_w * vs1_w);
  // For types narrower than int the addition is promoted to int; the cast to
  // UVd truncates it back to the destination width.
  const UVd sum =
      static_cast<UVd>(static_cast<UVd>(prod) + static_cast<UVd>(vd));
  return static_cast<Vd>(sum);
}
| |
| // Unsigned widening multiply and add. |
| void Vwmaccu(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| // LMUL8 cannot be 64. |
| if (rv_vector->vector_length_multiplier() > 32) { |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) |
| << "Vector length multiplier out of range for widening operation"; |
| return; |
| } |
| switch (sew) { |
| case 1: |
| return RiscVTernaryVectorOp<uint16_t, uint8_t, uint8_t>( |
| rv_vector, inst, |
| [](uint8_t vs2, uint8_t vs1, uint16_t vd) -> uint16_t { |
| return VwmaccHelper(vs2, vs1, vd); |
| }); |
| case 2: |
| return RiscVTernaryVectorOp<uint32_t, uint16_t, uint16_t>( |
| rv_vector, inst, |
| [](uint16_t vs2, uint16_t vs1, uint32_t vd) -> uint32_t { |
| return VwmaccHelper(vs2, vs1, vd); |
| }); |
| case 4: |
| return RiscVTernaryVectorOp<uint64_t, uint32_t, uint32_t>( |
| rv_vector, inst, |
| [](uint32_t vs2, uint32_t vs1, uint64_t vd) -> uint64_t { |
| return VwmaccHelper(vs2, vs1, vd); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Widening signed multiply and add. |
| void Vwmacc(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| // LMUL8 cannot be 64. |
| if (rv_vector->vector_length_multiplier() > 32) { |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) |
| << "Vector length multiplier out of range for widening operation"; |
| return; |
| } |
| switch (sew) { |
| case 1: |
| return RiscVTernaryVectorOp<int16_t, int8_t, int8_t>( |
| rv_vector, inst, [](int8_t vs2, int8_t vs1, int16_t vd) -> int16_t { |
| return VwmaccHelper(vs2, vs1, vd); |
| }); |
| case 2: |
| return RiscVTernaryVectorOp<int32_t, int16_t, int16_t>( |
| rv_vector, inst, [](int16_t vs2, int16_t vs1, int32_t vd) -> int32_t { |
| return VwmaccHelper(vs2, vs1, vd); |
| }); |
| case 4: |
| return RiscVTernaryVectorOp<int64_t, int32_t, int32_t>( |
| rv_vector, inst, [](int32_t vs2, int32_t vs1, int64_t vd) -> int64_t { |
| return VwmaccHelper(vs2, vs1, vd); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Widening unsigned-signed multiply and add. |
| void Vwmaccus(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| // LMUL8 cannot be 64. |
| if (rv_vector->vector_length_multiplier() > 32) { |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) |
| << "Vector length multiplier out of range for widening operation"; |
| return; |
| } |
| switch (sew) { |
| case 1: |
| return RiscVTernaryVectorOp<int16_t, int8_t, uint8_t>( |
| rv_vector, inst, [](int8_t vs2, uint8_t vs1, int16_t vd) -> int16_t { |
| return VwmaccHelper(vs2, vs1, vd); |
| }); |
| case 2: |
| return RiscVTernaryVectorOp<int32_t, int16_t, uint16_t>( |
| rv_vector, inst, |
| [](int16_t vs2, uint16_t vs1, int32_t vd) -> int32_t { |
| return VwmaccHelper(vs2, vs1, vd); |
| }); |
| case 4: |
| return RiscVTernaryVectorOp<int64_t, int32_t, uint32_t>( |
| rv_vector, inst, |
| [](int32_t vs2, uint32_t vs1, int64_t vd) -> int64_t { |
| return VwmaccHelper(vs2, vs1, vd); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Widening signed-unsigned multiply and add. |
| void Vwmaccsu(const Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| // LMUL8 cannot be 64. |
| if (rv_vector->vector_length_multiplier() > 32) { |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) |
| << "Vector length multiplier out of range for widening operation"; |
| return; |
| } |
| switch (sew) { |
| case 1: |
| return RiscVTernaryVectorOp<int16_t, uint8_t, int8_t>( |
| rv_vector, inst, [](uint8_t vs2, int8_t vs1, int16_t vd) -> int16_t { |
| return VwmaccHelper(vs2, vs1, vd); |
| }); |
| case 2: |
| return RiscVTernaryVectorOp<int32_t, uint16_t, int16_t>( |
| rv_vector, inst, |
| [](uint16_t vs2, int16_t vs1, int32_t vd) -> int32_t { |
| return VwmaccHelper(vs2, vs1, vd); |
| }); |
| case 4: |
| return RiscVTernaryVectorOp<int64_t, uint32_t, int32_t>( |
| rv_vector, inst, |
| [](uint32_t vs2, int32_t vs1, int64_t vd) -> int64_t { |
| return VwmaccHelper(vs2, vs1, vd); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| } // namespace cheriot |
| } // namespace sim |
| } // namespace mpact |