| // Copyright 2023 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // https://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "cheriot/riscv_cheriot_vector_opi_instructions.h" |
| |
| #include <algorithm> |
| #include <cstdint> |
| #include <cstring> |
| #include <limits> |
| #include <type_traits> |
| |
| #include "absl/log/log.h" |
| #include "cheriot/cheriot_state.h" |
| #include "cheriot/cheriot_vector_state.h" |
| #include "cheriot/riscv_cheriot_vector_instruction_helpers.h" |
| #include "mpact/sim/generic/type_helpers.h" |
| #include "riscv//riscv_register.h" |
| |
// This file contains the instruction semantic functions for most of the
// vector instructions in the OPIVV, OPIVX, and OPIVI encoding spaces. The
// exceptions are the vector element permute instructions and a couple of
// reduction instructions.
| |
| namespace mpact { |
| namespace sim { |
| namespace cheriot { |
| |
| using ::mpact::sim::generic::MakeUnsigned; |
| using ::mpact::sim::generic::WideType; |
| using riscv::RV32VectorSourceOperand; |
| using std::numeric_limits; |
| |
| // Vector arithmetic operations. |
| |
| // Vector add. |
| void Vadd(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>( |
| rv_vector, inst, |
| [](uint8_t vs2, uint8_t vs1) -> uint8_t { return vs2 + vs1; }); |
| case 2: |
| return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>( |
| rv_vector, inst, |
| [](uint16_t vs2, uint16_t vs1) -> uint16_t { return vs2 + vs1; }); |
| case 4: |
| return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>( |
| rv_vector, inst, |
| [](uint32_t vs2, uint32_t vs1) -> uint32_t { return vs2 + vs1; }); |
| case 8: |
| return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>( |
| rv_vector, inst, |
| [](uint64_t vs2, uint64_t vs1) -> uint64_t { return vs2 + vs1; }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Vector subtract. |
| void Vsub(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>( |
| rv_vector, inst, |
| [](uint8_t vs2, uint8_t vs1) -> uint8_t { return vs2 - vs1; }); |
| case 2: |
| return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>( |
| rv_vector, inst, |
| [](uint16_t vs2, uint16_t vs1) -> uint16_t { return vs2 - vs1; }); |
| case 4: |
| return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>( |
| rv_vector, inst, |
| [](uint32_t vs2, uint32_t vs1) -> uint32_t { return vs2 - vs1; }); |
| case 8: |
| return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>( |
| rv_vector, inst, |
| [](uint64_t vs2, uint64_t vs1) -> uint64_t { return vs2 - vs1; }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Vector reverse subtract. |
| void Vrsub(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>( |
| rv_vector, inst, |
| [](uint8_t vs2, uint8_t vs1) -> uint8_t { return vs1 - vs2; }); |
| case 2: |
| return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>( |
| rv_vector, inst, |
| [](uint16_t vs2, uint16_t vs1) -> uint16_t { return vs1 - vs2; }); |
| case 4: |
| return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>( |
| rv_vector, inst, |
| [](uint32_t vs2, uint32_t vs1) -> uint32_t { return vs1 - vs2; }); |
| case 8: |
| return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>( |
| rv_vector, inst, |
| [](uint64_t vs2, uint64_t vs1) -> uint64_t { return vs1 - vs2; }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Vector logical operations. |
| |
| // Vector and. |
| void Vand(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>( |
| rv_vector, inst, |
| [](uint8_t vs2, uint8_t vs1) -> uint8_t { return vs2 & vs1; }); |
| case 2: |
| return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>( |
| rv_vector, inst, |
| [](uint16_t vs2, uint16_t vs1) -> uint16_t { return vs2 & vs1; }); |
| case 4: |
| return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>( |
| rv_vector, inst, |
| [](uint32_t vs2, uint32_t vs1) -> uint32_t { return vs2 & vs1; }); |
| case 8: |
| return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>( |
| rv_vector, inst, |
| [](uint64_t vs2, uint64_t vs1) -> uint64_t { return vs2 & vs1; }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Vector or. |
| void Vor(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>( |
| rv_vector, inst, |
| [](uint8_t vs2, uint8_t vs1) -> uint8_t { return vs2 | vs1; }); |
| case 2: |
| return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>( |
| rv_vector, inst, |
| [](uint16_t vs2, uint16_t vs1) -> uint16_t { return vs2 | vs1; }); |
| case 4: |
| return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>( |
| rv_vector, inst, |
| [](uint32_t vs2, uint32_t vs1) -> uint32_t { return vs2 | vs1; }); |
| case 8: |
| return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>( |
| rv_vector, inst, |
| [](uint64_t vs2, uint64_t vs1) -> uint64_t { return vs2 | vs1; }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Vector xor. |
| void Vxor(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>( |
| rv_vector, inst, |
| [](uint8_t vs2, uint8_t vs1) -> uint8_t { return vs2 ^ vs1; }); |
| case 2: |
| return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>( |
| rv_vector, inst, |
| [](uint16_t vs2, uint16_t vs1) -> uint16_t { return vs2 ^ vs1; }); |
| case 4: |
| return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>( |
| rv_vector, inst, |
| [](uint32_t vs2, uint32_t vs1) -> uint32_t { return vs2 ^ vs1; }); |
| case 8: |
| return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>( |
| rv_vector, inst, |
| [](uint64_t vs2, uint64_t vs1) -> uint64_t { return vs2 ^ vs1; }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Vector shift operations. |
| |
| // Vector shift left logical. |
| void Vsll(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>( |
| rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint8_t { |
| return vs2 << (vs1 & 0b111); |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>( |
| rv_vector, inst, [](uint16_t vs2, uint16_t vs1) -> uint16_t { |
| return vs2 << (vs1 & 0b1111); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>( |
| rv_vector, inst, [](uint32_t vs2, uint32_t vs1) -> uint32_t { |
| return vs2 << (vs1 & 0b1'1111); |
| }); |
| case 8: |
| return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>( |
| rv_vector, inst, [](uint64_t vs2, uint64_t vs1) -> uint64_t { |
| return vs2 << (vs1 & 0b11'1111); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Vector shift right logical. |
| void Vsrl(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>( |
| rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint8_t { |
| return vs2 >> (vs1 & 0b111); |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>( |
| rv_vector, inst, [](uint16_t vs2, uint16_t vs1) -> uint16_t { |
| return vs2 >> (vs1 & 0b1111); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>( |
| rv_vector, inst, [](uint32_t vs2, uint32_t vs1) -> uint32_t { |
| return vs2 >> (vs1 & 0b1'1111); |
| }); |
| case 8: |
| return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>( |
| rv_vector, inst, [](uint64_t vs2, uint64_t vs1) -> uint64_t { |
| return vs2 >> (vs1 & 0b11'1111); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Vector shift right arithmetic. |
| void Vsra(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<int8_t, int8_t, int8_t>( |
| rv_vector, inst, [](int8_t vs2, int8_t vs1) -> int8_t { |
| return vs2 >> (vs1 & 0b111); |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<int16_t, int16_t, int16_t>( |
| rv_vector, inst, [](int16_t vs2, int16_t vs1) -> int16_t { |
| return vs2 >> (vs1 & 0b1111); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<int32_t, int32_t, int32_t>( |
| rv_vector, inst, [](int32_t vs2, int32_t vs1) -> int32_t { |
| return vs2 >> (vs1 & 0b1'1111); |
| }); |
| case 8: |
| return RiscVBinaryVectorOp<int64_t, int64_t, int64_t>( |
| rv_vector, inst, [](int64_t vs2, int64_t vs1) -> int64_t { |
| return vs2 >> (vs1 & 0b11'1111); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Vector narrowing shift operations. Narrow from sew * 2 to sew. |
| |
| // Vector narrowing shift right logical. Source op 0 is shifted right |
| // by source op 1 and the result is 1/2 the size of source op 0. |
| void Vnsrl(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| // LMUL8 cannot be 64. |
| if (rv_vector->vector_length_multiplier() > 32) { |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Vector length multiplier out of range for narrowing shift"; |
| return; |
| } |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<uint8_t, uint16_t, uint8_t>( |
| rv_vector, inst, [](uint16_t vs2, uint8_t vs1) -> uint8_t { |
| return static_cast<uint8_t>(vs2 >> (vs1 & 0b1111)); |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<uint16_t, uint32_t, uint16_t>( |
| rv_vector, inst, [](uint32_t vs2, uint16_t vs1) -> uint16_t { |
| return static_cast<uint16_t>(vs2 >> (vs1 & 0b1'1111)); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<uint32_t, uint64_t, uint32_t>( |
| rv_vector, inst, [](uint64_t vs2, uint32_t vs1) -> uint32_t { |
| return static_cast<uint32_t>(vs2 >> (vs1 & 0b11'1111)); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value for narrowing shift right: " << sew; |
| return; |
| } |
| } |
| |
| // Vector narrowing shift right arithmetic. Source op 0 is shifted right |
| // by source op 1 and the result is 1/2 the size of source op 0. |
| void Vnsra(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| // If the vector length multiplier (x8) is greater than 32, that means that |
| // the source values (sew * 2) would exceed the available register group. |
| if (rv_vector->vector_length_multiplier() > 32) { |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Vector length multiplier out of range for narrowing shift"; |
| return; |
| } |
| // Note, sew cannot be 64 bits, as there is no support for operations on |
| // 128 bit quantities. |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<int8_t, int16_t, int8_t>( |
| rv_vector, inst, [](int16_t vs2, int8_t vs1) -> int8_t { |
| return vs2 >> (vs1 & 0b1111); |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<int16_t, int32_t, int16_t>( |
| rv_vector, inst, [](int32_t vs2, int16_t vs1) -> int16_t { |
| return vs2 >> (vs1 & 0b1'1111); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<int32_t, int64_t, int32_t>( |
| rv_vector, inst, [](int64_t vs2, int32_t vs1) -> int32_t { |
| return vs2 >> (vs1 & 0b11'1111); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value for narrowing shift right: " << sew; |
| return; |
| } |
| } |
| |
| // Vector unsigned min. |
| void Vminu(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>( |
| rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint8_t { |
| return std::min(vs2, vs1); |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>( |
| rv_vector, inst, [](uint16_t vs2, uint16_t vs1) -> uint16_t { |
| return std::min(vs2, vs1); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>( |
| rv_vector, inst, [](uint32_t vs2, uint32_t vs1) -> uint32_t { |
| return std::min(vs2, vs1); |
| }); |
| case 8: |
| return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>( |
| rv_vector, inst, [](uint64_t vs2, uint64_t vs1) -> uint64_t { |
| return std::min(vs2, vs1); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Vector signed min. |
| void Vmin(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<int8_t, int8_t, int8_t>( |
| rv_vector, inst, |
| [](int8_t vs2, int8_t vs1) -> int8_t { return std::min(vs2, vs1); }); |
| case 2: |
| return RiscVBinaryVectorOp<int16_t, int16_t, int16_t>( |
| rv_vector, inst, [](int16_t vs2, int16_t vs1) -> int16_t { |
| return std::min(vs2, vs1); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<int32_t, int32_t, int32_t>( |
| rv_vector, inst, [](int32_t vs2, int32_t vs1) -> int32_t { |
| return std::min(vs2, vs1); |
| }); |
| case 8: |
| return RiscVBinaryVectorOp<int64_t, int64_t, int64_t>( |
| rv_vector, inst, [](int64_t vs2, int64_t vs1) -> int64_t { |
| return std::min(vs2, vs1); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Vector unsigned max. |
| void Vmaxu(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>( |
| rv_vector, inst, [](uint8_t vs2, uint8_t vs1) -> uint8_t { |
| return std::max(vs2, vs1); |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>( |
| rv_vector, inst, [](uint16_t vs2, uint16_t vs1) -> uint16_t { |
| return std::max(vs2, vs1); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>( |
| rv_vector, inst, [](uint32_t vs2, uint32_t vs1) -> uint32_t { |
| return std::max(vs2, vs1); |
| }); |
| case 8: |
| return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>( |
| rv_vector, inst, [](uint64_t vs2, uint64_t vs1) -> uint64_t { |
| return std::max(vs2, vs1); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Vector signed max. |
| void Vmax(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<int8_t, int8_t, int8_t>( |
| rv_vector, inst, |
| [](int8_t vs2, int8_t vs1) -> int8_t { return std::max(vs2, vs1); }); |
| case 2: |
| return RiscVBinaryVectorOp<int16_t, int16_t, int16_t>( |
| rv_vector, inst, [](int16_t vs2, int16_t vs1) -> int16_t { |
| return std::max(vs2, vs1); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<int32_t, int32_t, int32_t>( |
| rv_vector, inst, [](int32_t vs2, int32_t vs1) -> int32_t { |
| return std::max(vs2, vs1); |
| }); |
| case 8: |
| return RiscVBinaryVectorOp<int64_t, int64_t, int64_t>( |
| rv_vector, inst, [](int64_t vs2, int64_t vs1) -> int64_t { |
| return std::max(vs2, vs1); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Set equal. |
| void Vmseq(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorMaskOp<uint8_t, uint8_t>( |
| rv_vector, inst, |
| [](uint8_t vs2, uint8_t vs1) -> bool { return (vs2 == vs1); }); |
| case 2: |
| return RiscVBinaryVectorMaskOp<uint16_t, uint16_t>( |
| rv_vector, inst, |
| [](uint16_t vs2, uint16_t vs1) -> bool { return (vs2 == vs1); }); |
| case 4: |
| return RiscVBinaryVectorMaskOp<uint32_t, uint32_t>( |
| rv_vector, inst, |
| [](uint32_t vs2, uint32_t vs1) -> bool { return (vs2 == vs1); }); |
| case 8: |
| return RiscVBinaryVectorMaskOp<uint64_t, uint64_t>( |
| rv_vector, inst, |
| [](uint64_t vs2, uint64_t vs1) -> bool { return (vs2 == vs1); }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Vector compare instructions. |
| |
| // Set not equal. |
| void Vmsne(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorMaskOp<uint8_t, uint8_t>( |
| rv_vector, inst, |
| [](uint8_t vs2, uint8_t vs1) -> bool { return (vs2 != vs1); }); |
| case 2: |
| return RiscVBinaryVectorMaskOp<uint16_t, uint16_t>( |
| rv_vector, inst, |
| [](uint16_t vs2, uint16_t vs1) -> bool { return (vs2 != vs1); }); |
| case 4: |
| return RiscVBinaryVectorMaskOp<uint32_t, uint32_t>( |
| rv_vector, inst, |
| [](uint32_t vs2, uint32_t vs1) -> bool { return (vs2 != vs1); }); |
| case 8: |
| return RiscVBinaryVectorMaskOp<uint64_t, uint64_t>( |
| rv_vector, inst, |
| [](uint64_t vs2, uint64_t vs1) -> bool { return (vs2 != vs1); }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Set less than unsigned. |
| void Vmsltu(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorMaskOp<uint8_t, uint8_t>( |
| rv_vector, inst, |
| [](uint8_t vs2, uint8_t vs1) -> bool { return (vs2 < vs1); }); |
| case 2: |
| return RiscVBinaryVectorMaskOp<uint16_t, uint16_t>( |
| rv_vector, inst, |
| [](uint16_t vs2, uint16_t vs1) -> bool { return (vs2 < vs1); }); |
| case 4: |
| return RiscVBinaryVectorMaskOp<uint32_t, uint32_t>( |
| rv_vector, inst, |
| [](uint32_t vs2, uint32_t vs1) -> bool { return (vs2 < vs1); }); |
| case 8: |
| return RiscVBinaryVectorMaskOp<uint64_t, uint64_t>( |
| rv_vector, inst, |
| [](uint64_t vs2, uint64_t vs1) -> bool { return (vs2 < vs1); }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Set less than. |
| void Vmslt(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorMaskOp<int8_t, int8_t>( |
| rv_vector, inst, |
| [](int8_t vs2, int8_t vs1) -> bool { return (vs2 < vs1); }); |
| case 2: |
| return RiscVBinaryVectorMaskOp<int16_t, int16_t>( |
| rv_vector, inst, |
| [](int16_t vs2, int16_t vs1) -> bool { return (vs2 < vs1); }); |
| case 4: |
| return RiscVBinaryVectorMaskOp<int32_t, int32_t>( |
| rv_vector, inst, |
| [](int32_t vs2, int32_t vs1) -> bool { return (vs2 < vs1); }); |
| case 8: |
| return RiscVBinaryVectorMaskOp<int64_t, int64_t>( |
| rv_vector, inst, |
| [](int64_t vs2, int64_t vs1) -> bool { return (vs2 < vs1); }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Set less than or equal unsigned. |
| void Vmsleu(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorMaskOp<uint8_t, uint8_t>( |
| rv_vector, inst, |
| [](uint8_t vs2, uint8_t vs1) -> bool { return (vs2 <= vs1); }); |
| case 2: |
| return RiscVBinaryVectorMaskOp<uint16_t, uint16_t>( |
| rv_vector, inst, |
| [](uint16_t vs2, uint16_t vs1) -> bool { return (vs2 <= vs1); }); |
| case 4: |
| return RiscVBinaryVectorMaskOp<uint32_t, uint32_t>( |
| rv_vector, inst, |
| [](uint32_t vs2, uint32_t vs1) -> bool { return (vs2 <= vs1); }); |
| case 8: |
| return RiscVBinaryVectorMaskOp<uint64_t, uint64_t>( |
| rv_vector, inst, |
| [](uint64_t vs2, uint64_t vs1) -> bool { return (vs2 <= vs1); }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Set less than or equal. |
| void Vmsle(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorMaskOp<int8_t, int8_t>( |
| rv_vector, inst, |
| [](int8_t vs2, int8_t vs1) -> bool { return (vs2 <= vs1); }); |
| case 2: |
| return RiscVBinaryVectorMaskOp<int16_t, int16_t>( |
| rv_vector, inst, |
| [](int16_t vs2, int16_t vs1) -> bool { return (vs2 <= vs1); }); |
| case 4: |
| return RiscVBinaryVectorMaskOp<int32_t, int32_t>( |
| rv_vector, inst, |
| [](int32_t vs2, int32_t vs1) -> bool { return (vs2 <= vs1); }); |
| case 8: |
| return RiscVBinaryVectorMaskOp<int64_t, int64_t>( |
| rv_vector, inst, |
| [](int64_t vs2, int64_t vs1) -> bool { return (vs2 <= vs1); }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Set greater than unsigned. |
| void Vmsgtu(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorMaskOp<uint8_t, uint8_t>( |
| rv_vector, inst, |
| [](uint8_t vs2, uint8_t vs1) -> bool { return (vs2 > vs1); }); |
| case 2: |
| return RiscVBinaryVectorMaskOp<uint16_t, uint16_t>( |
| rv_vector, inst, |
| [](uint16_t vs2, uint16_t vs1) -> bool { return (vs2 > vs1); }); |
| case 4: |
| return RiscVBinaryVectorMaskOp<uint32_t, uint32_t>( |
| rv_vector, inst, |
| [](uint32_t vs2, uint32_t vs1) -> bool { return (vs2 > vs1); }); |
| case 8: |
| return RiscVBinaryVectorMaskOp<uint64_t, uint64_t>( |
| rv_vector, inst, |
| [](uint64_t vs2, uint64_t vs1) -> bool { return (vs2 > vs1); }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Set greater than. |
| void Vmsgt(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorMaskOp<int8_t, int8_t>( |
| rv_vector, inst, |
| [](int8_t vs2, int8_t vs1) -> bool { return (vs2 > vs1); }); |
| case 2: |
| return RiscVBinaryVectorMaskOp<int16_t, int16_t>( |
| rv_vector, inst, |
| [](int16_t vs2, int16_t vs1) -> bool { return (vs2 > vs1); }); |
| case 4: |
| return RiscVBinaryVectorMaskOp<int32_t, int32_t>( |
| rv_vector, inst, |
| [](int32_t vs2, int32_t vs1) -> bool { return (vs2 > vs1); }); |
| case 8: |
| return RiscVBinaryVectorMaskOp<int64_t, int64_t>( |
| rv_vector, inst, |
| [](int64_t vs2, int64_t vs1) -> bool { return (vs2 > vs1); }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Saturated unsigned addition. |
| void Vsaddu(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>( |
| rv_vector, inst, [rv_vector](uint8_t vs2, uint8_t vs1) -> uint8_t { |
| uint8_t sum = vs2 + vs1; |
| if (sum < vs1) { |
| sum = numeric_limits<uint8_t>::max(); |
| rv_vector->set_vxsat(true); |
| } |
| return sum; |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>( |
| rv_vector, inst, [rv_vector](uint16_t vs2, uint16_t vs1) -> uint16_t { |
| uint16_t sum = vs2 + vs1; |
| if (sum < vs1) { |
| sum = numeric_limits<uint16_t>::max(); |
| rv_vector->set_vxsat(true); |
| } |
| return sum; |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>( |
| rv_vector, inst, [rv_vector](uint32_t vs2, uint32_t vs1) -> uint32_t { |
| uint32_t sum = vs2 + vs1; |
| if (sum < vs1) { |
| sum = numeric_limits<uint32_t>::max(); |
| rv_vector->set_vxsat(true); |
| } |
| return sum; |
| }); |
| case 8: |
| return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>( |
| rv_vector, inst, [rv_vector](uint64_t vs2, uint64_t vs1) -> uint64_t { |
| uint64_t sum = vs2 + vs1; |
| if (sum < vs1) { |
| sum = numeric_limits<uint64_t>::max(); |
| rv_vector->set_vxsat(true); |
| } |
| return sum; |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Helper function for Vsadd. |
| // Uses unsigned arithmetic for the addition to avoid signed overflow, which, |
| // when compiled with --config=asan, will trigger an exception. |
| template <typename T> |
| inline T VsaddHelper(T vs2, T vs1, CheriotVectorState *rv_vector) { |
| using UT = typename std::make_unsigned<T>::type; |
| UT uvs2 = static_cast<UT>(vs2); |
| UT uvs1 = static_cast<UT>(vs1); |
| UT usum = uvs2 + uvs1; |
| T sum = static_cast<T>(usum); |
| if (((vs2 ^ vs1) >= 0) && ((sum ^ vs2) < 0)) { |
| rv_vector->set_vxsat(true); |
| return vs2 > 0 ? numeric_limits<T>::max() : numeric_limits<T>::min(); |
| } |
| return sum; |
| } |
| |
| // Saturated signed addition. |
| void Vsadd(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<int8_t, int8_t, int8_t>( |
| rv_vector, inst, [rv_vector](int8_t vs2, int8_t vs1) -> int8_t { |
| return VsaddHelper(vs2, vs1, rv_vector); |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<int16_t, int16_t, int16_t>( |
| rv_vector, inst, [rv_vector](int16_t vs2, int16_t vs1) -> int16_t { |
| return VsaddHelper(vs2, vs1, rv_vector); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<int32_t, int32_t, int32_t>( |
| rv_vector, inst, [rv_vector](int32_t vs2, int32_t vs1) -> int32_t { |
| return VsaddHelper(vs2, vs1, rv_vector); |
| }); |
| case 8: |
| return RiscVBinaryVectorOp<int64_t, int64_t, int64_t>( |
| rv_vector, inst, [rv_vector](int64_t vs2, int64_t vs1) -> int64_t { |
| return VsaddHelper(vs2, vs1, rv_vector); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Saturated unsigned subtract. |
| void Vssubu(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>( |
| rv_vector, inst, [rv_vector](uint8_t vs2, uint8_t vs1) -> uint8_t { |
| uint8_t diff = vs2 - vs1; |
| if (vs2 < vs1) { |
| diff = 0; |
| rv_vector->set_vxsat(true); |
| } |
| return diff; |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>( |
| rv_vector, inst, [rv_vector](uint16_t vs2, uint16_t vs1) -> uint16_t { |
| uint16_t diff = vs2 - vs1; |
| if (vs2 < vs1) { |
| diff = 0; |
| rv_vector->set_vxsat(true); |
| } |
| return diff; |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>( |
| rv_vector, inst, [rv_vector](uint32_t vs2, uint32_t vs1) -> uint32_t { |
| uint32_t diff = vs2 - vs1; |
| if (vs2 < vs1) { |
| diff = 0; |
| rv_vector->set_vxsat(true); |
| } |
| return diff; |
| }); |
| case 8: |
| return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>( |
| rv_vector, inst, [rv_vector](uint64_t vs2, uint64_t vs1) -> uint64_t { |
| uint64_t diff = vs2 - vs1; |
| if (vs2 < vs1) { |
| diff = 0; |
| rv_vector->set_vxsat(true); |
| } |
| return diff; |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| template <typename T> |
| T VssubHelper(T vs2, T vs1, CheriotVectorState *rv_vector) { |
| using UT = typename std::make_unsigned<T>::type; |
| UT uvs2 = static_cast<UT>(vs2); |
| UT uvs1 = static_cast<UT>(vs1); |
| UT udiff = uvs2 - uvs1; |
| T diff = static_cast<T>(udiff); |
| if (((vs2 ^ vs1) < 0) && ((diff ^ vs1) >= 0)) { |
| rv_vector->set_vxsat(true); |
| return vs1 < 0 ? numeric_limits<T>::max() : numeric_limits<T>::min(); |
| } |
| return diff; |
| } |
| |
| // Saturated signed subtract. |
| void Vssub(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<int8_t, int8_t, int8_t>( |
| rv_vector, inst, [rv_vector](int8_t vs2, int8_t vs1) -> int8_t { |
| return VssubHelper(vs2, vs1, rv_vector); |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<int16_t, int16_t, int16_t>( |
| rv_vector, inst, [rv_vector](int16_t vs2, int16_t vs1) -> int16_t { |
| return VssubHelper(vs2, vs1, rv_vector); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<int32_t, int32_t, int32_t>( |
| rv_vector, inst, [rv_vector](int32_t vs2, int32_t vs1) -> int32_t { |
| return VssubHelper(vs2, vs1, rv_vector); |
| }); |
| case 8: |
| return RiscVBinaryVectorOp<int64_t, int64_t, int64_t>( |
| rv_vector, inst, [rv_vector](int64_t vs2, int64_t vs1) -> int64_t { |
| return VssubHelper(vs2, vs1, rv_vector); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Add/Subtract with carry, carry generation. |
| void Vadc(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVMaskBinaryVectorOp<uint8_t, uint8_t, uint8_t>( |
| rv_vector, inst, [](uint8_t vs2, uint8_t vs1, bool mask) -> uint8_t { |
| return vs2 + vs1 + static_cast<uint8_t>(mask); |
| }); |
| case 2: |
| return RiscVMaskBinaryVectorOp<uint16_t, uint16_t, uint16_t>( |
| rv_vector, inst, |
| [](uint16_t vs2, uint16_t vs1, bool mask) -> uint16_t { |
| return vs2 + vs1 + static_cast<uint16_t>(mask); |
| }); |
| case 4: |
| return RiscVMaskBinaryVectorOp<uint32_t, uint32_t, uint32_t>( |
| rv_vector, inst, |
| [](uint32_t vs2, uint32_t vs1, bool mask) -> uint32_t { |
| return vs2 + vs1 + static_cast<uint32_t>(mask); |
| }); |
| case 8: |
| return RiscVMaskBinaryVectorOp<uint64_t, uint64_t, uint64_t>( |
| rv_vector, inst, |
| [](uint64_t vs2, uint64_t vs1, bool mask) -> uint64_t { |
| return vs2 + vs1 + static_cast<uint64_t>(mask); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Add with carry - carry generation. |
| void Vmadc(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVSetMaskBinaryVectorMaskOp<uint8_t, uint8_t>( |
| rv_vector, inst, [](uint8_t vs2, uint8_t vs1, bool mask) -> bool { |
| uint16_t sum = static_cast<uint16_t>(vs2) + |
| static_cast<uint16_t>(vs1) + |
| static_cast<uint16_t>(mask); |
| sum >>= 8; |
| return sum; |
| }); |
| case 2: |
| return RiscVSetMaskBinaryVectorMaskOp<uint16_t, uint16_t>( |
| rv_vector, inst, [](uint16_t vs2, uint16_t vs1, bool mask) -> bool { |
| uint32_t sum = static_cast<uint32_t>(vs2) + |
| static_cast<uint32_t>(vs1) + |
| static_cast<uint32_t>(mask); |
| sum >>= 16; |
| return sum != 0; |
| }); |
| case 4: |
| return RiscVSetMaskBinaryVectorMaskOp<uint32_t, uint32_t>( |
| rv_vector, inst, [](uint32_t vs2, uint32_t vs1, bool mask) -> bool { |
| uint64_t sum = static_cast<uint64_t>(vs2) + |
| static_cast<uint64_t>(vs1) + |
| static_cast<uint64_t>(mask); |
| sum >>= 32; |
| return sum != 0; |
| }); |
| case 8: |
| return RiscVSetMaskBinaryVectorMaskOp<uint64_t, uint64_t>( |
| rv_vector, inst, [](uint64_t vs2, uint64_t vs1, bool mask) -> bool { |
| // Compute carry by doing two additions. First get the carry out |
| // from adding the low byte. |
| uint64_t carry = |
| (vs1 & 0xff + vs2 & 0xff + static_cast<uint64_t>(mask)) >> 8; |
| // Now add the high 7 bytes together with the carry from the low |
| // byte addition. |
| uint64_t sum = (vs1 >> 8) + (vs2 >> 8) + carry; |
| // The carry out is in the high byte. |
| sum >>= 56; |
| return sum != 0; |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Subtract with borrow. |
| void Vsbc(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVMaskBinaryVectorOp<uint8_t, uint8_t, uint8_t>( |
| rv_vector, inst, [](uint8_t vs2, uint8_t vs1, bool mask) -> uint8_t { |
| return vs2 - vs1 - static_cast<uint8_t>(mask); |
| }); |
| case 2: |
| return RiscVMaskBinaryVectorOp<uint16_t, uint16_t, uint16_t>( |
| rv_vector, inst, |
| [](uint16_t vs2, uint16_t vs1, bool mask) -> uint16_t { |
| return vs2 - vs1 - static_cast<uint16_t>(mask); |
| }); |
| case 4: |
| return RiscVMaskBinaryVectorOp<uint32_t, uint32_t, uint32_t>( |
| rv_vector, inst, |
| [](uint32_t vs2, uint32_t vs1, bool mask) -> uint32_t { |
| return vs2 - vs1 - static_cast<uint32_t>(mask); |
| }); |
| case 8: |
| return RiscVMaskBinaryVectorOp<uint64_t, uint64_t, uint64_t>( |
| rv_vector, inst, |
| [](uint64_t vs2, uint64_t vs1, bool mask) -> uint64_t { |
| return vs2 - vs1 - static_cast<uint64_t>(mask); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Subtract with borrow - borrow generation. |
| void Vmsbc(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVSetMaskBinaryVectorMaskOp<uint8_t, uint8_t>( |
| rv_vector, inst, [](uint8_t vs2, uint8_t vs1, bool mask) -> bool { |
| return static_cast<uint16_t>(vs2) < |
| static_cast<uint16_t>(mask) + static_cast<uint16_t>(vs1); |
| }); |
| case 2: |
| return RiscVSetMaskBinaryVectorMaskOp<uint16_t, uint16_t>( |
| rv_vector, inst, [](uint16_t vs2, uint16_t vs1, bool mask) -> bool { |
| return static_cast<uint32_t>(vs2) < |
| static_cast<uint32_t>(mask) + static_cast<uint32_t>(vs1); |
| }); |
| case 4: |
| return RiscVSetMaskBinaryVectorMaskOp<uint32_t, uint32_t>( |
| rv_vector, inst, [](uint32_t vs2, uint32_t vs1, bool mask) -> bool { |
| return static_cast<uint64_t>(vs2) < |
| static_cast<uint64_t>(mask) + static_cast<uint64_t>(vs1); |
| }); |
| case 8: |
| return RiscVSetMaskBinaryVectorMaskOp<uint64_t, uint64_t>( |
| rv_vector, inst, [](uint64_t vs2, uint64_t vs1, bool mask) -> bool { |
| if (vs2 < vs1) return true; |
| if (vs2 == vs1) return mask; |
| return false; |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Vector merge. |
| void Vmerge(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVMaskBinaryVectorOp<uint8_t, uint8_t, uint8_t>( |
| rv_vector, inst, [](uint8_t vs2, uint8_t vs1, bool mask) -> uint8_t { |
| return mask ? vs1 : vs2; |
| }); |
| case 2: |
| return RiscVMaskBinaryVectorOp<uint16_t, uint16_t, uint16_t>( |
| rv_vector, inst, |
| [](uint16_t vs2, uint16_t vs1, bool mask) -> uint16_t { |
| return mask ? vs1 : vs2; |
| }); |
| case 4: |
| return RiscVMaskBinaryVectorOp<uint32_t, uint32_t, uint32_t>( |
| rv_vector, inst, |
| [](uint32_t vs2, uint32_t vs1, bool mask) -> uint32_t { |
| return mask ? vs1 : vs2; |
| }); |
| case 8: |
| return RiscVMaskBinaryVectorOp<uint64_t, uint64_t, uint64_t>( |
| rv_vector, inst, |
| [](uint64_t vs2, uint64_t vs1, bool mask) -> uint64_t { |
| return mask ? vs1 : vs2; |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Vector move register(s). |
| void Vmvr(int num_regs, Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| if (rv_vector->vector_exception()) return; |
| |
| auto *src_op = static_cast<RV32VectorSourceOperand *>(inst->Source(0)); |
| auto *dest_op = |
| static_cast<RV32VectorDestinationOperand *>(inst->Destination(0)); |
| if (src_op->size() < num_regs) { |
| LOG(ERROR) << "Vmvr: source operand has fewer registers than requested"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| if (dest_op->size() < num_regs) { |
| LOG(ERROR) |
| << "Vmvr: destination operand has fewer registers than requested"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| int sew = rv_vector->selected_element_width(); |
| int num_elements_per_vector = rv_vector->vector_register_byte_length() / sew; |
| int vstart = rv_vector->vstart(); |
| int start_reg = vstart / num_elements_per_vector; |
| for (int i = start_reg; i < num_regs; i++) { |
| auto *src_db = src_op->GetRegister(i)->data_buffer(); |
| auto *dest_db = dest_op->AllocateDataBuffer(i); |
| std::memcpy(dest_db->raw_ptr(), src_db->raw_ptr(), |
| dest_db->size<uint8_t>()); |
| dest_db->Submit(); |
| } |
| rv_vector->clear_vstart(); |
| } |
| |
| // Templated helper function for shift right with rounding. |
| template <typename T> |
| T VssrHelper(CheriotVectorState *rv_vector, T vs2, T vs1) { |
| using UT = typename MakeUnsigned<T>::type; |
| int rm = rv_vector->vxrm(); |
| int max_shift = (sizeof(T) << 3) - 1; |
| int shift_amount = static_cast<int>(vs1 & max_shift); |
| // Create mask for the bits that will be shifted out + 1. |
| UT round_bits = vs2; |
| if (shift_amount < max_shift) { |
| UT mask = numeric_limits<UT>::max(); |
| mask = ~(numeric_limits<UT>::max() << shift_amount + 1); |
| round_bits = vs2 & mask; |
| } |
| vs2 >>= shift_amount; |
| vs2 += static_cast<T>(GetRoundingBit(rm, round_bits, shift_amount + 1)); |
| return vs2; |
| } |
| |
| // Logical shift right with rounding. |
| void Vssrl(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<uint8_t, uint8_t, uint8_t>( |
| rv_vector, inst, [rv_vector](uint8_t vs2, uint8_t vs1) -> uint8_t { |
| return VssrHelper(rv_vector, vs2, vs1); |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<uint16_t, uint16_t, uint16_t>( |
| rv_vector, inst, [rv_vector](uint16_t vs2, uint16_t vs1) -> uint16_t { |
| return VssrHelper(rv_vector, vs2, vs1); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<uint32_t, uint32_t, uint32_t>( |
| rv_vector, inst, [rv_vector](uint32_t vs2, uint32_t vs1) -> uint32_t { |
| return VssrHelper(rv_vector, vs2, vs1); |
| }); |
| case 8: |
| return RiscVBinaryVectorOp<uint64_t, uint64_t, uint64_t>( |
| rv_vector, inst, [rv_vector](uint64_t vs2, uint64_t vs1) -> uint64_t { |
| return VssrHelper(rv_vector, vs2, vs1); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Arithmetic shift right with rounding. |
| void Vssra(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<int8_t, int8_t, int8_t>( |
| rv_vector, inst, [rv_vector](int8_t vs2, int8_t vs1) -> int8_t { |
| return VssrHelper(rv_vector, vs2, vs1); |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<int16_t, int16_t, int16_t>( |
| rv_vector, inst, [rv_vector](int16_t vs2, int16_t vs1) -> int16_t { |
| return VssrHelper(rv_vector, vs2, vs1); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<int32_t, int32_t, int32_t>( |
| rv_vector, inst, [rv_vector](int32_t vs2, int32_t vs1) -> int32_t { |
| return VssrHelper(rv_vector, vs2, vs1); |
| }); |
| case 8: |
| return RiscVBinaryVectorOp<int64_t, int64_t, int64_t>( |
| rv_vector, inst, [rv_vector](int64_t vs2, int64_t vs1) -> int64_t { |
| return VssrHelper(rv_vector, vs2, vs1); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Templated helper function for shift right with rounding and saturation. |
| template <typename DT, typename WT, typename T> |
| T VnclipHelper(CheriotVectorState *rv_vector, WT vs2, T vs1) { |
| using WUT = typename std::make_unsigned<WT>::type; |
| int rm = rv_vector->vxrm(); |
| int max_shift = (sizeof(WT) << 3) - 1; |
| int shift_amount = vs1 & ((sizeof(WT) << 3) - 1); |
| // Create mask for the bits that will be shifted out + 1. |
| WUT mask = vs2; |
| if (shift_amount < max_shift) { |
| mask = ~(numeric_limits<WUT>::max() << (shift_amount + 1)); |
| } |
| WUT round_bits = vs2 & mask; |
| // Perform the rounded shift. |
| vs2 = |
| (vs2 >> shift_amount) + GetRoundingBit(rm, round_bits, shift_amount + 1); |
| // Saturate if needed. |
| if (vs2 > numeric_limits<DT>::max()) { |
| rv_vector->set_vxsat(true); |
| return numeric_limits<DT>::max(); |
| } |
| if (vs2 < numeric_limits<DT>::min()) { |
| rv_vector->set_vxsat(true); |
| return numeric_limits<DT>::min(); |
| } |
| return static_cast<DT>(vs2); |
| } |
| |
| // Arithmetic shift right and narrowing from 2*sew to sew with rounding and |
| // signed saturation. |
| void Vnclip(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int lmul8 = rv_vector->vector_length_multiplier(); |
| // This is a narrowing operation and sew is that of the narrow data type. |
| // Thus if lmul > 32, then emul for the wider data type is illegal. |
| if (lmul8 > 32) { |
| LOG(ERROR) << "Illegal lmul value"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<int8_t, int16_t, int8_t>( |
| rv_vector, inst, [rv_vector](int16_t vs2, int8_t vs1) -> int8_t { |
| return VnclipHelper<int8_t, int16_t, int8_t>(rv_vector, vs2, vs1); |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<int16_t, int32_t, int16_t>( |
| rv_vector, inst, [rv_vector](int32_t vs2, int16_t vs1) -> int16_t { |
| return VnclipHelper<int16_t, int32_t, int16_t>(rv_vector, vs2, vs1); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<int32_t, int64_t, int32_t>( |
| rv_vector, inst, [rv_vector](int64_t vs2, int32_t vs1) -> int32_t { |
| return VnclipHelper<int32_t, int64_t, int32_t>(rv_vector, vs2, vs1); |
| }); |
| case 8: |
| // There is no valid sew * 2 = 16. |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Logical shift right and narrowing from 2*sew to sew with rounding and |
| // unsigned saturation. |
| void Vnclipu(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int lmul8 = rv_vector->vector_length_multiplier(); |
| // This is a narrowing operation and sew is that of the narrow data type. |
| // Thus if lmul > 32, then emul for the wider data type is illegal. |
| if (lmul8 > 32) { |
| LOG(ERROR) << "Illegal lmul value"; |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<uint8_t, uint16_t, uint8_t>( |
| rv_vector, inst, [rv_vector](uint16_t vs2, uint8_t vs1) -> uint8_t { |
| return VnclipHelper<uint8_t, uint16_t, uint8_t>(rv_vector, vs2, |
| vs1); |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<uint16_t, uint32_t, uint16_t>( |
| rv_vector, inst, [rv_vector](uint32_t vs2, uint16_t vs1) -> uint16_t { |
| return VnclipHelper<uint16_t, uint32_t, uint16_t>(rv_vector, vs2, |
| vs1); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<uint32_t, uint64_t, uint32_t>( |
| rv_vector, inst, [rv_vector](uint64_t vs2, uint32_t vs1) -> uint32_t { |
| return VnclipHelper<uint32_t, uint64_t, uint32_t>(rv_vector, vs2, |
| vs1); |
| }); |
| case 8: |
| // There is no valid sew * 2 = 16. |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Perform a signed multiply from T to wider int type. Shift that result |
| // right by sizeof(T) * 8 - 1 and round. Saturate if needed to fit into T. |
| template <typename T> |
| T VsmulHelper(CheriotVectorState *rv_vector, T vs2, T vs1) { |
| using WT = typename WideType<T>::type; |
| WT vd_w; |
| WT vs2_w = static_cast<WT>(vs2); |
| WT vs1_w = static_cast<WT>(vs1); |
| vd_w = vs2_w * vs1_w; |
| vd_w = VssrHelper<WT>(rv_vector, vd_w, sizeof(T) * 8 - 1); |
| if (vd_w < numeric_limits<T>::min()) { |
| rv_vector->set_vxsat(true); |
| return numeric_limits<T>::min(); |
| } |
| if (vd_w > numeric_limits<T>::max()) { |
| rv_vector->set_vxsat(true); |
| return numeric_limits<T>::max(); |
| } |
| return static_cast<T>(vd_w); |
| } |
| |
| // Vector fractional multiply with rounding and saturation. |
| void Vsmul(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return RiscVBinaryVectorOp<int8_t, int8_t, int8_t>( |
| rv_vector, inst, [rv_vector](int8_t vs2, int8_t vs1) -> int8_t { |
| return VsmulHelper<int8_t>(rv_vector, vs2, vs1); |
| }); |
| case 2: |
| return RiscVBinaryVectorOp<int16_t, int16_t, int16_t>( |
| rv_vector, inst, [rv_vector](int16_t vs2, int16_t vs1) -> int16_t { |
| return VsmulHelper<int16_t>(rv_vector, vs2, vs1); |
| }); |
| case 4: |
| return RiscVBinaryVectorOp<int32_t, int32_t, int32_t>( |
| rv_vector, inst, [rv_vector](int32_t vs2, int32_t vs1) -> int32_t { |
| return VsmulHelper<int32_t>(rv_vector, vs2, vs1); |
| }); |
| case 8: |
| return RiscVBinaryVectorOp<int64_t, int64_t, int64_t>( |
| rv_vector, inst, [rv_vector](int64_t vs2, int64_t vs1) -> int64_t { |
| return VsmulHelper<int64_t>(rv_vector, vs2, vs1); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| } // namespace cheriot |
| } // namespace sim |
| } // namespace mpact |