blob: b5bad6e8c38adf7506273a10ff26ca506f8f33ef [file] [log] [blame]
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "cheriot/riscv_cheriot_vector_permute_instructions.h"
#include <algorithm>
#include <cstdint>
#include "absl/log/log.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "cheriot/cheriot_register.h"
#include "cheriot/cheriot_state.h"
#include "cheriot/cheriot_vector_state.h"
#include "mpact/sim/generic/data_buffer.h"
#include "mpact/sim/generic/instruction.h"
#include "riscv//riscv_register.h"
namespace mpact {
namespace sim {
namespace cheriot {
using ::mpact::sim::riscv::RV32VectorDestinationOperand;
using ::mpact::sim::riscv::RV32VectorSourceOperand;
// This helper function handles the vector gather operations.
// This helper function handles the vector gather operations. For each active
// element i in [vstart, vl):
//   vd[i] = (index < max_index) ? vs2[index] : 0
// where index is element i of source operand 1 (vector-vector form) or the
// scalar value of source operand 1 (vector-scalar form, detected by a shape
// of 1), and max_index is the number of elements held by source operand 0.
// Masked-off elements are left unchanged.
// Template parameters:
//   Vd  - destination element type.
//   Vs2 - data (source operand 0) element type.
//   Vs1 - index (source operand 1) element type; may differ from Vd.
template <typename Vd, typename Vs2, typename Vs1>
void VrgatherHelper(CheriotVectorState *rv_vector, Instruction *inst) {
  if (rv_vector->vector_exception()) return;
  int num_elements = rv_vector->vector_length();
  int elements_per_vector =
      rv_vector->vector_register_byte_length() / sizeof(Vd);
  // Verify that the lmul is compatible with index size.
  // NOTE(review): assumes vector_length_multiplier() is LMUL scaled by 8, so
  // 64 corresponds to an index EMUL of 8 — confirm against CheriotVectorState.
  int index_emul =
      rv_vector->vector_length_multiplier() * sizeof(Vs1) / sizeof(Vd);
  if (index_emul > 64) {
    rv_vector->set_vector_exception();
    return;
  }
  int max_regs = std::max(
      1, (num_elements + elements_per_vector - 1) / elements_per_vector);
  auto *dest_op =
      static_cast<RV32VectorDestinationOperand *>(inst->Destination(0));
  // Verify that there are enough registers in the destination operand.
  if (dest_op->size() < max_regs) {
    rv_vector->set_vector_exception();
    LOG(ERROR) << absl::StrCat(
        "Vector destination '", dest_op->AsString(), "' has fewer registers (",
        dest_op->size(), ") than required by the operation (", max_regs, ")");
    return;
  }
  // Get the vector mask.
  auto *mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(2));
  auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
  // Get the vector start element index and compute where to start the
  // operation.
  int vector_index = rv_vector->vstart();
  int start_reg = vector_index / elements_per_vector;
  int item_index = vector_index % elements_per_vector;
  // Determine if it's vector-vector or vector-scalar.
  bool vector_scalar = inst->Source(1)->shape()[0] == 1;
  auto src0_op = static_cast<RV32VectorSourceOperand *>(inst->Source(0));
  int max_index = src0_op->size() * elements_per_vector;
  // Iterate over the number of registers to write.
  for (int reg = start_reg; (reg < max_regs) && (vector_index < num_elements);
       reg++) {
    // Allocate data buffer for the new register data.
    auto *dest_db = dest_op->CopyDataBuffer(reg);
    auto dest_span = dest_db->Get<Vd>();
    // Write data into register subject to masking.
    int element_count = std::min(elements_per_vector, num_elements);
    for (int i = item_index;
         (i < element_count) && (vector_index < num_elements); i++) {
      // The mask bit is selected by the element's position in the whole
      // register group (vector_index), not by its position within the
      // current register (i); the two differ for every register after the
      // first.
      int mask_index = vector_index >> 3;
      int mask_offset = vector_index & 0b111;
      bool mask_value = ((mask_span[mask_index] >> mask_offset) & 0b1) != 0;
      if (mask_value) {
        // Hold the index in 64 bits so that 64-bit index elements are not
        // truncated (and possibly aliased to a valid index) before the range
        // check below.
        uint64_t vs1;
        if (vector_scalar) {
          vs1 = generic::GetInstructionSource<CheriotRegister::ValueType>(inst,
                                                                          1, 0);
        } else {
          vs1 = generic::GetInstructionSource<Vs1>(inst, 1, vector_index);
        }
        // Out-of-range indices produce zero.
        Vs2 vs2 = 0;
        if (vs1 < static_cast<uint64_t>(max_index)) {
          vs2 = generic::GetInstructionSource<Vs2>(inst, 0,
                                                   static_cast<int>(vs1));
        }
        dest_span[i] = vs2;
      }
      vector_index++;
    }
    // Submit the destination db.
    dest_db->Submit();
    item_index = 0;
  }
  rv_vector->clear_vstart();
}
// Vector register gather.
// Vector register gather where the index elements have the same width as the
// data elements (both equal to SEW). Any SEW other than 1, 2, 4 or 8 bytes
// raises a vector exception.
void Vrgather(Instruction *inst) {
  auto *vstate = static_cast<CheriotState *>(inst->state())->rv_vector();
  switch (vstate->selected_element_width()) {
    case sizeof(uint8_t):
      VrgatherHelper<uint8_t, uint8_t, uint8_t>(vstate, inst);
      break;
    case sizeof(uint16_t):
      VrgatherHelper<uint16_t, uint16_t, uint16_t>(vstate, inst);
      break;
    case sizeof(uint32_t):
      VrgatherHelper<uint32_t, uint32_t, uint32_t>(vstate, inst);
      break;
    case sizeof(uint64_t):
      VrgatherHelper<uint64_t, uint64_t, uint64_t>(vstate, inst);
      break;
    default:
      vstate->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      break;
  }
}
// Vector register gather with 16 bit indices.
// Vector register gather with 16-bit index elements. The data element width
// follows SEW, while the index element type is always uint16_t. Any SEW other
// than 1, 2, 4 or 8 bytes raises a vector exception.
void Vrgatherei16(Instruction *inst) {
  auto *vstate = static_cast<CheriotState *>(inst->state())->rv_vector();
  switch (vstate->selected_element_width()) {
    case sizeof(uint8_t):
      VrgatherHelper<uint8_t, uint8_t, uint16_t>(vstate, inst);
      break;
    case sizeof(uint16_t):
      VrgatherHelper<uint16_t, uint16_t, uint16_t>(vstate, inst);
      break;
    case sizeof(uint32_t):
      VrgatherHelper<uint32_t, uint32_t, uint16_t>(vstate, inst);
      break;
    case sizeof(uint64_t):
      VrgatherHelper<uint64_t, uint64_t, uint16_t>(vstate, inst);
      break;
    default:
      vstate->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      break;
  }
}
// This helper function handles the vector slide up/down instructions.
// This helper function handles the vector slide up/down instructions.
// For each active element i in [vstart, vl), the source index is i - offset.
// A positive offset slides up and a negative offset slides down.
//  - source index < 0 (slide up, i < offset): vd[i] is left unchanged.
//  - source index >= max_vector_length(): vd[i] is written with zero.
//  - otherwise: vd[i] = vs2[source index].
// Masked-off elements are left unchanged.
template <typename Vd>
void VSlideHelper(CheriotVectorState *rv_vector, Instruction *inst,
                  int offset) {
  if (rv_vector->vector_exception()) return;
  int num_elements = rv_vector->vector_length();
  int elements_per_vector =
      rv_vector->vector_register_byte_length() / sizeof(Vd);
  int max_regs = std::max(
      1, (num_elements + elements_per_vector - 1) / elements_per_vector);
  auto *dest_op =
      static_cast<RV32VectorDestinationOperand *>(inst->Destination(0));
  // Verify that there are enough registers in the destination operand.
  if (dest_op->size() < max_regs) {
    rv_vector->set_vector_exception();
    LOG(ERROR) << absl::StrCat(
        "Vector destination '", dest_op->AsString(), "' has fewer registers (",
        dest_op->size(), ") than required by the operation (", max_regs, ")");
    return;
  }
  // Get the vector mask.
  auto *mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(2));
  auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
  // Get the vector start element index and compute where to start the
  // operation.
  int vector_index = rv_vector->vstart();
  int start_reg = vector_index / elements_per_vector;
  int item_index = vector_index % elements_per_vector;
  // Iterate over the number of registers to write.
  for (int reg = start_reg; (reg < max_regs) && (vector_index < num_elements);
       reg++) {
    // Allocate data buffer for the new register data.
    auto *dest_db = dest_op->CopyDataBuffer(reg);
    auto dest_span = dest_db->Get<Vd>();
    // Write data into register subject to masking.
    int element_count = std::min(elements_per_vector, num_elements);
    for (int i = item_index;
         (i < element_count) && (vector_index < num_elements); i++) {
      // The mask bit is selected by the element's position in the whole
      // register group (vector_index), not by its position within the
      // current register (i).
      int mask_index = vector_index >> 3;
      int mask_offset = vector_index & 0b111;
      bool mask_value = ((mask_span[mask_index] >> mask_offset) & 0b1) != 0;
      // Computed in 64 bits so that extreme offsets cannot overflow.
      int64_t src_index = static_cast<int64_t>(vector_index) - offset;
      if ((src_index >= 0) && mask_value) {
        // Source elements past the end read as zero.
        Vd src_value = 0;
        if (src_index < rv_vector->max_vector_length()) {
          src_value = generic::GetInstructionSource<Vd>(
              inst, 0, static_cast<int>(src_index));
        }
        dest_span[i] = src_value;
      }
      vector_index++;
    }
    // Submit the destination db.
    dest_db->Submit();
    item_index = 0;
  }
  rv_vector->clear_vstart();
}
// Vector slide up: each active element vd[i] with i >= OFFSET receives
// vs2[i - OFFSET], where OFFSET is the value of source operand 1. Destination
// elements below OFFSET are left unchanged.
void Vslideup(Instruction *inst) {
  using ValueType = CheriotRegister::ValueType;
  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
  int sew = rv_vector->selected_element_width();
  auto offset = generic::GetInstructionSource<ValueType>(inst, 1, 0);
  // Clamp the offset to max_vector_length(). Assuming vl <= max_vector_length(),
  // any offset at or above that bound leaves every destination element
  // unchanged, so clamping does not change the result. Clamping also keeps the
  // value representable as an int.
  //
  // The previous early return for large offsets skipped clear_vstart() and SEW
  // validation. Running the helper with the clamped offset performs both.
  ValueType max_offset =
      static_cast<ValueType>(rv_vector->max_vector_length());
  int int_offset = static_cast<int>(std::min<ValueType>(offset, max_offset));
  // Slide up amount is positive.
  switch (sew) {
    case 1:
      return VSlideHelper<uint8_t>(rv_vector, inst, int_offset);
    case 2:
      return VSlideHelper<uint16_t>(rv_vector, inst, int_offset);
    case 4:
      return VSlideHelper<uint32_t>(rv_vector, inst, int_offset);
    case 8:
      return VSlideHelper<uint64_t>(rv_vector, inst, int_offset);
    default:
      rv_vector->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      return;
  }
}
// Vector slide down: each active element vd[i] receives vs2[i + OFFSET], where
// OFFSET is the value of source operand 1. Source elements at or beyond
// max_vector_length() read as zero.
void Vslidedown(Instruction *inst) {
  using ValueType = CheriotRegister::ValueType;
  auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
  int sew = rv_vector->selected_element_width();
  auto offset = generic::GetInstructionSource<ValueType>(inst, 1, 0);
  // The slide amount is an unsigned register value, so it must not be
  // converted to int directly:
  //  - Values above INT_MAX would become negative, and negating them would
  //    turn this slide down into a slide up.
  //  - Negating INT_MIN is undefined behavior.
  // Every offset of max_vector_length() or more selects only out-of-range
  // source elements (all zeros), so clamping to that bound preserves the
  // result.
  ValueType max_offset =
      static_cast<ValueType>(rv_vector->max_vector_length());
  // Slide down amount is negative.
  int int_offset = -static_cast<int>(std::min<ValueType>(offset, max_offset));
  switch (sew) {
    case 1:
      return VSlideHelper<uint8_t>(rv_vector, inst, int_offset);
    case 2:
      return VSlideHelper<uint16_t>(rv_vector, inst, int_offset);
    case 4:
      return VSlideHelper<uint32_t>(rv_vector, inst, int_offset);
    case 8:
      return VSlideHelper<uint64_t>(rv_vector, inst, int_offset);
    default:
      rv_vector->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      return;
  }
}
// This helper function handles the vector slide up/down 1 instructions.
// This helper function handles the vector slide up/down 1 instructions.
// For each active element i in [vstart, vl), the source index is i - offset
// (the callers pass +1 for slide1up and -1 for slide1down).
//  - if 0 <= source index < vl: vd[i] = vs2[source index];
//  - otherwise: vd[i] = the scalar value of source operand 1. This puts the
//    scalar at element 0 for slide1up and at element vl - 1 for slide1down.
// Masked-off elements are left unchanged.
template <typename Vd>
void VSlide1Helper(CheriotVectorState *rv_vector, Instruction *inst,
                   int offset) {
  if (rv_vector->vector_exception()) return;
  int num_elements = rv_vector->vector_length();
  int elements_per_vector =
      rv_vector->vector_register_byte_length() / sizeof(Vd);
  int max_regs = std::max(
      1, (num_elements + elements_per_vector - 1) / elements_per_vector);
  auto *dest_op =
      static_cast<RV32VectorDestinationOperand *>(inst->Destination(0));
  // Verify that there are enough registers in the destination operand.
  if (dest_op->size() < max_regs) {
    rv_vector->set_vector_exception();
    LOG(ERROR) << absl::StrCat(
        "Vector destination '", dest_op->AsString(), "' has fewer registers (",
        dest_op->size(), ") than required by the operation (", max_regs, ")");
    return;
  }
  // Get the vector mask.
  auto *mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(2));
  auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
  // Get the vector start element index and compute where to start the
  // operation.
  int vector_index = rv_vector->vstart();
  int start_reg = vector_index / elements_per_vector;
  int item_index = vector_index % elements_per_vector;
  auto slide_value = generic::GetInstructionSource<Vd>(inst, 1, 0);
  // Iterate over the number of registers to write.
  for (int reg = start_reg; (reg < max_regs) && (vector_index < num_elements);
       reg++) {
    // Allocate data buffer for the new register data.
    auto *dest_db = dest_op->CopyDataBuffer(reg);
    auto dest_span = dest_db->Get<Vd>();
    // Write data into register subject to masking.
    int element_count = std::min(elements_per_vector, num_elements);
    for (int i = item_index;
         (i < element_count) && (vector_index < num_elements); i++) {
      // The mask bit is selected by the element's position in the whole
      // register group (vector_index), not by its position within the
      // current register (i).
      int mask_index = vector_index >> 3;
      int mask_offset = vector_index & 0b111;
      bool mask_value = ((mask_span[mask_index] >> mask_offset) & 0b1) != 0;
      if (mask_value) {
        // Compute result.
        Vd src_value = slide_value;
        int src_index = vector_index - offset;
        // The lower bound is inclusive so that slide1up copies vs2[0] into
        // element 1. The upper bound is vl (not VLMAX) so that slide1down puts
        // the scalar into element vl - 1 instead of reading vs2[vl].
        if ((src_index >= 0) && (src_index < num_elements)) {
          src_value = generic::GetInstructionSource<Vd>(inst, 0, src_index);
        }
        dest_span[i] = src_value;
      }
      vector_index++;
    }
    // Submit the destination db.
    dest_db->Submit();
    item_index = 0;
  }
  rv_vector->clear_vstart();
}
// Vector slide up by one element, inserting the scalar operand at element 0.
// Dispatches on SEW (1, 2, 4 or 8 bytes); other widths raise a vector
// exception.
void Vslide1up(Instruction *inst) {
  auto *vstate = static_cast<CheriotState *>(inst->state())->rv_vector();
  constexpr int kSlideUpByOne = 1;
  switch (vstate->selected_element_width()) {
    case sizeof(uint8_t):
      VSlide1Helper<uint8_t>(vstate, inst, kSlideUpByOne);
      break;
    case sizeof(uint16_t):
      VSlide1Helper<uint16_t>(vstate, inst, kSlideUpByOne);
      break;
    case sizeof(uint32_t):
      VSlide1Helper<uint32_t>(vstate, inst, kSlideUpByOne);
      break;
    case sizeof(uint64_t):
      VSlide1Helper<uint64_t>(vstate, inst, kSlideUpByOne);
      break;
    default:
      vstate->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      break;
  }
}
// Vector slide down by one element, inserting the scalar operand at the top.
// Dispatches on SEW (1, 2, 4 or 8 bytes); other widths raise a vector
// exception.
void Vslide1down(Instruction *inst) {
  auto *vstate = static_cast<CheriotState *>(inst->state())->rv_vector();
  constexpr int kSlideDownByOne = -1;
  switch (vstate->selected_element_width()) {
    case sizeof(uint8_t):
      VSlide1Helper<uint8_t>(vstate, inst, kSlideDownByOne);
      break;
    case sizeof(uint16_t):
      VSlide1Helper<uint16_t>(vstate, inst, kSlideDownByOne);
      break;
    case sizeof(uint32_t):
      VSlide1Helper<uint32_t>(vstate, inst, kSlideDownByOne);
      break;
    case sizeof(uint64_t):
      VSlide1Helper<uint64_t>(vstate, inst, kSlideDownByOne);
      break;
    default:
      vstate->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      break;
  }
}
// Floating-point flavor of slide up by one. Elements are moved as raw
// unsigned bit patterns, so only SEW of 4 or 8 bytes is accepted; any other
// width raises a vector exception.
void Vfslide1up(Instruction *inst) {
  auto *vstate = static_cast<CheriotState *>(inst->state())->rv_vector();
  constexpr int kSlideUpByOne = 1;
  switch (vstate->selected_element_width()) {
    case sizeof(uint32_t):
      VSlide1Helper<uint32_t>(vstate, inst, kSlideUpByOne);
      break;
    case sizeof(uint64_t):
      VSlide1Helper<uint64_t>(vstate, inst, kSlideUpByOne);
      break;
    default:
      vstate->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      break;
  }
}
// Floating-point flavor of slide down by one. Elements are moved as raw
// unsigned bit patterns, so only SEW of 4 or 8 bytes is accepted; any other
// width raises a vector exception.
void Vfslide1down(Instruction *inst) {
  auto *vstate = static_cast<CheriotState *>(inst->state())->rv_vector();
  constexpr int kSlideDownByOne = -1;
  switch (vstate->selected_element_width()) {
    case sizeof(uint32_t):
      VSlide1Helper<uint32_t>(vstate, inst, kSlideDownByOne);
      break;
    case sizeof(uint64_t):
      VSlide1Helper<uint64_t>(vstate, inst, kSlideDownByOne);
      break;
    default:
      vstate->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      break;
  }
}
// Packs the elements of source operand 0 whose bit is set in the mask held in
// source operand 1 into consecutive low elements of the destination. Only the
// destination registers that receive at least one packed element are
// rewritten; the elements beyond the packed ones keep their previous values.
template <typename Vd>
void VCompressHelper(CheriotVectorState *rv_vector, Instruction *inst) {
  if (rv_vector->vector_exception()) return;
  const int vl = rv_vector->vector_length();
  const int per_reg = rv_vector->vector_register_byte_length() / sizeof(Vd);
  const int regs_needed = std::max(1, (vl + per_reg - 1) / per_reg);
  auto *dest_op =
      static_cast<RV32VectorDestinationOperand *>(inst->Destination(0));
  // The destination group must be able to hold vl elements.
  if (dest_op->size() < regs_needed) {
    rv_vector->set_vector_exception();
    LOG(ERROR) << absl::StrCat(
        "Vector destination '", dest_op->AsString(), "' has fewer registers (",
        dest_op->size(), ") than required by the operation (", regs_needed,
        ")");
    return;
  }
  // The selection mask lives in the first register of source operand 1.
  const auto mask_bytes =
      static_cast<RV32VectorSourceOperand *>(inst->Source(1))
          ->GetRegister(0)
          ->data_buffer()
          ->Get<uint8_t>();
  generic::DataBuffer *out_db = nullptr;
  absl::Span<Vd> out;
  int packed = 0;  // Number of elements written to the destination so far.
  for (int src = rv_vector->vstart(); src < vl; ++src) {
    const bool selected = ((mask_bytes[src >> 3] >> (src & 0b111)) & 0b1) != 0;
    if (!selected) continue;
    const int slot = packed % per_reg;
    // Starting a new destination register: flush the one being filled (if
    // any) and take a copy of the next one.
    if (slot == 0) {
      if (out_db != nullptr) out_db->Submit();
      out_db = dest_op->CopyDataBuffer(packed / per_reg);
      out = out_db->Get<Vd>();
    }
    out[slot] = generic::GetInstructionSource<Vd>(inst, 0, src);
    ++packed;
  }
  if (out_db != nullptr) out_db->Submit();
  rv_vector->clear_vstart();
}
// Vector compress: dispatches to VCompressHelper for SEW of 1, 2, 4 or 8
// bytes. Any other width raises a vector exception.
void Vcompress(Instruction *inst) {
  auto *vstate = static_cast<CheriotState *>(inst->state())->rv_vector();
  switch (vstate->selected_element_width()) {
    case sizeof(uint8_t):
      VCompressHelper<uint8_t>(vstate, inst);
      break;
    case sizeof(uint16_t):
      VCompressHelper<uint16_t>(vstate, inst);
      break;
    case sizeof(uint32_t):
      VCompressHelper<uint32_t>(vstate, inst);
      break;
    case sizeof(uint64_t):
      VCompressHelper<uint64_t>(vstate, inst);
      break;
    default:
      vstate->set_vector_exception();
      LOG(ERROR) << "Illegal SEW value";
      break;
  }
}
} // namespace cheriot
} // namespace sim
} // namespace mpact