| // Copyright 2023 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // https://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "cheriot/riscv_cheriot_vector_permute_instructions.h" |
| |
| #include <algorithm> |
| #include <cstdint> |
| |
| #include "absl/log/log.h" |
| #include "absl/strings/str_cat.h" |
| #include "absl/types/span.h" |
| #include "cheriot/cheriot_register.h" |
| #include "cheriot/cheriot_state.h" |
| #include "cheriot/cheriot_vector_state.h" |
| #include "mpact/sim/generic/data_buffer.h" |
| #include "mpact/sim/generic/instruction.h" |
| #include "riscv//riscv_register.h" |
| |
| namespace mpact { |
| namespace sim { |
| namespace cheriot { |
| |
| using ::mpact::sim::riscv::RV32VectorDestinationOperand; |
| using ::mpact::sim::riscv::RV32VectorSourceOperand; |
| |
| // This helper function handles the vector gather operations. |
| template <typename Vd, typename Vs2, typename Vs1> |
| void VrgatherHelper(CheriotVectorState *rv_vector, Instruction *inst) { |
| if (rv_vector->vector_exception()) return; |
| int num_elements = rv_vector->vector_length(); |
| int elements_per_vector = |
| rv_vector->vector_register_byte_length() / sizeof(Vd); |
| // Verify that the lmul is compatible with index size. |
| int index_emul = |
| rv_vector->vector_length_multiplier() * sizeof(Vs1) / sizeof(Vd); |
| if (index_emul > 64) { |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| int max_regs = std::max( |
| 1, (num_elements + elements_per_vector - 1) / elements_per_vector); |
| auto *dest_op = |
| static_cast<RV32VectorDestinationOperand *>(inst->Destination(0)); |
| // Verify that there are enough registers in the destination operand. |
| if (dest_op->size() < max_regs) { |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << absl::StrCat( |
| "Vector destination '", dest_op->AsString(), "' has fewer registers (", |
| dest_op->size(), ") than required by the operation (", max_regs, ")"); |
| return; |
| } |
| // Get the vector mask. |
| auto *mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(2)); |
| auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); |
| // Get the vector start element index and compute the where to start |
| // the operation. |
| int vector_index = rv_vector->vstart(); |
| int start_reg = vector_index / elements_per_vector; |
| int item_index = vector_index % elements_per_vector; |
| // Determine if it's vector-vector or vector-scalar. |
| bool vector_scalar = inst->Source(1)->shape()[0] == 1; |
| auto src0_op = static_cast<RV32VectorSourceOperand *>(inst->Source(0)); |
| int max_index = src0_op->size() * elements_per_vector; |
| // Iterate over the number of registers to write. |
| for (int reg = start_reg; (reg < max_regs) && (vector_index < num_elements); |
| reg++) { |
| // Allocate data buffer for the new register data. |
| auto *dest_db = dest_op->CopyDataBuffer(reg); |
| auto dest_span = dest_db->Get<Vd>(); |
| // Write data into register subject to masking. |
| int element_count = std::min(elements_per_vector, num_elements); |
| for (int i = item_index; |
| (i < element_count) && (vector_index < num_elements); i++) { |
| // Get the mask value. |
| int mask_index = i >> 3; |
| int mask_offset = i & 0b111; |
| bool mask_value = ((mask_span[mask_index] >> mask_offset) & 0b1) != 0; |
| if (mask_value) { |
| // Compute result. |
| CheriotRegister::ValueType vs1; |
| if (vector_scalar) { |
| vs1 = generic::GetInstructionSource<CheriotRegister::ValueType>(inst, |
| 1, 0); |
| } else { |
| vs1 = generic::GetInstructionSource<Vs1>(inst, 1, vector_index); |
| } |
| Vs2 vs2 = 0; |
| if (vs1 < max_index) { |
| vs2 = generic::GetInstructionSource<Vs2>(inst, 0, vs1); |
| } |
| dest_span[i] = vs2; |
| } |
| vector_index++; |
| } |
| // Submit the destination db . |
| dest_db->Submit(); |
| item_index = 0; |
| } |
| rv_vector->clear_vstart(); |
| } |
| |
| // Vector register gather. |
| void Vrgather(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return VrgatherHelper<uint8_t, uint8_t, uint8_t>(rv_vector, inst); |
| case 2: |
| return VrgatherHelper<uint16_t, uint16_t, uint16_t>(rv_vector, inst); |
| case 4: |
| return VrgatherHelper<uint32_t, uint32_t, uint32_t>(rv_vector, inst); |
| case 8: |
| return VrgatherHelper<uint64_t, uint64_t, uint64_t>(rv_vector, inst); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Vector register gather with 16 bit indices. |
| void Vrgatherei16(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return VrgatherHelper<uint8_t, uint8_t, uint16_t>(rv_vector, inst); |
| case 2: |
| return VrgatherHelper<uint16_t, uint16_t, uint16_t>(rv_vector, inst); |
| case 4: |
| return VrgatherHelper<uint32_t, uint32_t, uint16_t>(rv_vector, inst); |
| case 8: |
| return VrgatherHelper<uint64_t, uint64_t, uint16_t>(rv_vector, inst); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // This helper function handles the vector slide up/down instructions. |
| template <typename Vd> |
| void VSlideHelper(CheriotVectorState *rv_vector, Instruction *inst, |
| int offset) { |
| if (rv_vector->vector_exception()) return; |
| int num_elements = rv_vector->vector_length(); |
| int elements_per_vector = |
| rv_vector->vector_register_byte_length() / sizeof(Vd); |
| int max_regs = std::max( |
| 1, (num_elements + elements_per_vector - 1) / elements_per_vector); |
| auto *dest_op = |
| static_cast<RV32VectorDestinationOperand *>(inst->Destination(0)); |
| // Verify that there are enough registers in the destination operand. |
| if (dest_op->size() < max_regs) { |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << absl::StrCat( |
| "Vector destination '", dest_op->AsString(), "' has fewer registers (", |
| dest_op->size(), ") than required by the operation (", max_regs, ")"); |
| return; |
| } |
| // Get the vector mask. |
| auto *mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(2)); |
| auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); |
| // Get the vector start element index and compute the where to start |
| // the operation. |
| int vector_index = rv_vector->vstart(); |
| int start_reg = vector_index / elements_per_vector; |
| int item_index = vector_index % elements_per_vector; |
| // Iterate over the number of registers to write. |
| for (int reg = start_reg; (reg < max_regs) && (vector_index < num_elements); |
| reg++) { |
| // Allocate data buffer for the new register data. |
| auto *dest_db = dest_op->CopyDataBuffer(reg); |
| auto dest_span = dest_db->Get<Vd>(); |
| // Write data into register subject to masking. |
| int element_count = std::min(elements_per_vector, num_elements); |
| for (int i = item_index; |
| (i < element_count) && (vector_index < num_elements); i++) { |
| // Get the mask value. |
| int mask_index = i >> 3; |
| int mask_offset = i & 0b111; |
| bool mask_value = ((mask_span[mask_index] >> mask_offset) & 0b1); |
| int src_index = vector_index - offset; |
| if ((src_index >= 0) && (mask_value)) { |
| // Compute result. |
| Vd src_value = 0; |
| if (src_index < rv_vector->max_vector_length()) { |
| src_value = generic::GetInstructionSource<Vd>(inst, 0, src_index); |
| } |
| dest_span[i] = src_value; |
| } |
| vector_index++; |
| } |
| // Submit the destination db . |
| dest_db->Submit(); |
| item_index = 0; |
| } |
| rv_vector->clear_vstart(); |
| } |
| |
| void Vslideup(Instruction *inst) { |
| using ValueType = CheriotRegister::ValueType; |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| auto offset = generic::GetInstructionSource<ValueType>(inst, 1, 0); |
| int int_offset = static_cast<int>(offset); |
| if (offset > rv_vector->max_vector_length()) return; |
| // Slide up amount is positive. |
| switch (sew) { |
| case 1: |
| return VSlideHelper<uint8_t>(rv_vector, inst, int_offset); |
| case 2: |
| return VSlideHelper<uint16_t>(rv_vector, inst, int_offset); |
| case 4: |
| return VSlideHelper<uint32_t>(rv_vector, inst, int_offset); |
| case 8: |
| return VSlideHelper<uint64_t>(rv_vector, inst, int_offset); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| void Vslidedown(Instruction *inst) { |
| using ValueType = CheriotRegister::ValueType; |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| auto offset = generic::GetInstructionSource<ValueType>(inst, 1, 0); |
| // Slide down amount is negative. |
| int int_offset = -static_cast<int>(offset); |
| switch (sew) { |
| case 1: |
| return VSlideHelper<uint8_t>(rv_vector, inst, int_offset); |
| case 2: |
| return VSlideHelper<uint16_t>(rv_vector, inst, int_offset); |
| case 4: |
| return VSlideHelper<uint32_t>(rv_vector, inst, int_offset); |
| case 8: |
| return VSlideHelper<uint64_t>(rv_vector, inst, int_offset); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // This helper function handles the vector slide up/down 1 instructions. |
| template <typename Vd> |
| void VSlide1Helper(CheriotVectorState *rv_vector, Instruction *inst, |
| int offset) { |
| if (rv_vector->vector_exception()) return; |
| int num_elements = rv_vector->vector_length(); |
| int elements_per_vector = |
| rv_vector->vector_register_byte_length() / sizeof(Vd); |
| int max_regs = std::max( |
| 1, (num_elements + elements_per_vector - 1) / elements_per_vector); |
| auto *dest_op = |
| static_cast<RV32VectorDestinationOperand *>(inst->Destination(0)); |
| // Verify that there are enough registers in the destination operand. |
| if (dest_op->size() < max_regs) { |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << absl::StrCat( |
| "Vector destination '", dest_op->AsString(), "' has fewer registers (", |
| dest_op->size(), ") than required by the operation (", max_regs, ")"); |
| return; |
| } |
| // Get the vector mask. |
| auto *mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(2)); |
| auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); |
| // Get the vector start element index and compute the where to start |
| // the operation. |
| int vector_index = rv_vector->vstart(); |
| int start_reg = vector_index / elements_per_vector; |
| int item_index = vector_index % elements_per_vector; |
| auto slide_value = generic::GetInstructionSource<Vd>(inst, 1, 0); |
| // Iterate over the number of registers to write. |
| for (int reg = start_reg; (reg < max_regs) && (vector_index < num_elements); |
| reg++) { |
| // Allocate data buffer for the new register data. |
| auto *dest_db = dest_op->CopyDataBuffer(reg); |
| auto dest_span = dest_db->Get<Vd>(); |
| // Write data into register subject to masking. |
| int element_count = std::min(elements_per_vector, num_elements); |
| for (int i = item_index; |
| (i < element_count) && (vector_index < num_elements); i++) { |
| // Get the mask value. |
| int mask_index = i >> 3; |
| int mask_offset = i & 0b111; |
| bool mask_value = ((mask_span[mask_index] >> mask_offset) & 0b1) != 0; |
| if (mask_value) { |
| // Compute result. |
| Vd src_value = slide_value; |
| int src_index = vector_index - offset; |
| if ((src_index > 0) && (src_index < rv_vector->max_vector_length())) { |
| src_value = generic::GetInstructionSource<Vd>(inst, 0, src_index); |
| } |
| dest_span[i] = src_value; |
| } |
| vector_index++; |
| } |
| // Submit the destination db . |
| dest_db->Submit(); |
| item_index = 0; |
| } |
| rv_vector->clear_vstart(); |
| } |
| |
| void Vslide1up(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return VSlide1Helper<uint8_t>(rv_vector, inst, 1); |
| case 2: |
| return VSlide1Helper<uint16_t>(rv_vector, inst, 1); |
| case 4: |
| return VSlide1Helper<uint32_t>(rv_vector, inst, 1); |
| case 8: |
| return VSlide1Helper<uint64_t>(rv_vector, inst, 1); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| void Vslide1down(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return VSlide1Helper<uint8_t>(rv_vector, inst, -1); |
| case 2: |
| return VSlide1Helper<uint16_t>(rv_vector, inst, -1); |
| case 4: |
| return VSlide1Helper<uint32_t>(rv_vector, inst, -1); |
| case 8: |
| return VSlide1Helper<uint64_t>(rv_vector, inst, -1); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| void Vfslide1up(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 4: |
| return VSlide1Helper<uint32_t>(rv_vector, inst, 1); |
| case 8: |
| return VSlide1Helper<uint64_t>(rv_vector, inst, 1); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| void Vfslide1down(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 4: |
| return VSlide1Helper<uint32_t>(rv_vector, inst, -1); |
| case 8: |
| return VSlide1Helper<uint64_t>(rv_vector, inst, -1); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| template <typename Vd> |
| void VCompressHelper(CheriotVectorState *rv_vector, Instruction *inst) { |
| if (rv_vector->vector_exception()) return; |
| int num_elements = rv_vector->vector_length(); |
| int elements_per_vector = |
| rv_vector->vector_register_byte_length() / sizeof(Vd); |
| int max_regs = std::max( |
| 1, (num_elements + elements_per_vector - 1) / elements_per_vector); |
| auto *dest_op = |
| static_cast<RV32VectorDestinationOperand *>(inst->Destination(0)); |
| // Verify that there are enough registers in the destination operand. |
| if (dest_op->size() < max_regs) { |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << absl::StrCat( |
| "Vector destination '", dest_op->AsString(), "' has fewer registers (", |
| dest_op->size(), ") than required by the operation (", max_regs, ")"); |
| return; |
| } |
| // Get the vector mask. |
| auto *mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(1)); |
| auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); |
| // Get the vector start element index and compute the where to start |
| // the operation. |
| int vector_index = rv_vector->vstart(); |
| int dest_index = 0; |
| int prev_reg = -1; |
| absl::Span<Vd> dest_span; |
| generic::DataBuffer *dest_db = nullptr; |
| // Iterate over the input elements. |
| for (int i = vector_index; i < num_elements; i++) { |
| // Get mask value. |
| int mask_index = i >> 3; |
| int mask_offset = i & 0b111; |
| bool mask_value = (mask_span[mask_index] >> mask_offset) & 0b1; |
| if (mask_value) { |
| // Compute destination register. |
| int reg = dest_index / elements_per_vector; |
| if (prev_reg != reg) { |
| // Submit previous data buffer if needed. |
| if (dest_db != nullptr) dest_db->Submit(); |
| // Allocate a data buffer. |
| dest_db = dest_op->CopyDataBuffer(reg); |
| dest_span = dest_db->Get<Vd>(); |
| prev_reg = reg; |
| } |
| // Copy the source value to the dest_index. |
| Vd src_value = generic::GetInstructionSource<Vd>(inst, 0, i); |
| dest_span[dest_index % elements_per_vector] = src_value; |
| ++dest_index; |
| } |
| } |
| if (dest_db != nullptr) dest_db->Submit(); |
| rv_vector->clear_vstart(); |
| } |
| |
| void Vcompress(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 1: |
| return VCompressHelper<uint8_t>(rv_vector, inst); |
| case 2: |
| return VCompressHelper<uint16_t>(rv_vector, inst); |
| case 4: |
| return VCompressHelper<uint32_t>(rv_vector, inst); |
| case 8: |
| return VCompressHelper<uint64_t>(rv_vector, inst); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| } // namespace cheriot |
| } // namespace sim |
| } // namespace mpact |