| // Copyright 2023 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // https://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "cheriot/riscv_cheriot_vector_unary_instructions.h" |
| |
| #include <cstdint> |
| #include <cstring> |
| #include <functional> |
| |
| #include "absl/log/log.h" |
| #include "absl/strings/str_cat.h" |
| #include "cheriot/cheriot_register.h" |
| #include "cheriot/cheriot_state.h" |
| #include "cheriot/cheriot_vector_state.h" |
| #include "cheriot/riscv_cheriot_instruction_helpers.h" |
| #include "cheriot/riscv_cheriot_vector_instruction_helpers.h" |
| #include "mpact/sim/generic/instruction.h" |
| #include "mpact/sim/generic/type_helpers.h" |
| |
| namespace mpact { |
| namespace sim { |
| namespace cheriot { |
| |
| using SignedXregType = |
| ::mpact::sim::generic::SameSignedType<CheriotRegister::ValueType, |
| int64_t>::type; |
| |
| // Move scalar to vector register. |
| void VmvToScalar(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| if (rv_vector->vstart()) return; |
| if (rv_vector->vector_length() == 0) return; |
| int sew = rv_vector->selected_element_width(); |
| SignedXregType value; |
| switch (sew) { |
| case 1: |
| value = static_cast<SignedXregType>( |
| generic::GetInstructionSource<int8_t>(inst, 0)); |
| break; |
| case 2: |
| value = static_cast<SignedXregType>( |
| generic::GetInstructionSource<int16_t>(inst, 0)); |
| break; |
| case 4: |
| value = static_cast<SignedXregType>( |
| generic::GetInstructionSource<int32_t>(inst, 0)); |
| break; |
| case 8: |
| value = static_cast<SignedXregType>( |
| generic::GetInstructionSource<int64_t>(inst, 0)); |
| break; |
| default: |
| LOG(ERROR) << absl::StrCat("Illegal SEW value (", sew, ") for Vmvxs"); |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| WriteCapIntResult<SignedXregType>(inst, 0, value); |
| } |
| |
| void VmvFromScalar(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| if (rv_vector->vstart()) return; |
| if (rv_vector->vector_length() == 0) return; |
| int sew = rv_vector->selected_element_width(); |
| auto *dest_db = inst->Destination(0)->AllocateDataBuffer(); |
| std::memset(dest_db->raw_ptr(), 0, dest_db->size<uint8_t>()); |
| switch (sew) { |
| case 1: |
| dest_db->Set<int8_t>(0, generic::GetInstructionSource<int8_t>(inst, 0)); |
| break; |
| case 2: |
| dest_db->Set<int16_t>(0, generic::GetInstructionSource<int16_t>(inst, 0)); |
| break; |
| case 4: |
| dest_db->Set<int32_t>(0, generic::GetInstructionSource<int32_t>(inst, 0)); |
| break; |
| case 8: |
| dest_db->Set<int64_t>(0, generic::GetInstructionSource<int64_t>(inst, 0)); |
| break; |
| default: |
| LOG(ERROR) << absl::StrCat("Illegal SEW value (", sew, ") for Vmvxs"); |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| dest_db->Submit(); |
| } |
| |
| // Population count of vector mask register. The value is written to a scalar |
| // register. |
| void Vcpop(Instruction *inst) { |
| auto *rv_state = static_cast<CheriotState *>(inst->state()); |
| auto *rv_vector = rv_state->rv_vector(); |
| if (rv_vector->vstart()) { |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| int vlen = rv_vector->vector_length(); |
| auto src_op = static_cast<RV32VectorSourceOperand *>(inst->Source(0)); |
| auto src_span = src_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); |
| auto mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(1)); |
| auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); |
| uint64_t count = 0; |
| for (int i = 0; i < vlen; i++) { |
| int index = i >> 3; |
| int offset = i & 0b111; |
| int mask_value = (mask_span[index] >> offset); |
| int src_value = (src_span[index] >> offset); |
| count += mask_value & src_value & 0b1; |
| } |
| WriteCapIntResult<uint32_t>(inst, 0, count); |
| } |
| |
| // Find first set of vector mask register. The value is written to a scalar |
| // register. |
| void Vfirst(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| if (rv_vector->vstart()) { |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| auto src_op = static_cast<RV32VectorSourceOperand *>(inst->Source(0)); |
| auto src_span = src_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); |
| auto mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(1)); |
| auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); |
| // Initialize the element index to -1. |
| uint64_t element_index = -1LL; |
| int vlen = rv_vector->vector_length(); |
| for (int i = 0; i < vlen; i++) { |
| int index = i >> 3; |
| int offset = i & 0b111; |
| int mask_value = (mask_span[index] >> offset); |
| int src_value = (src_span[index] >> offset); |
| if (mask_value & src_value & 0b1) { |
| element_index = i; |
| break; |
| } |
| } |
| WriteCapIntResult<uint32_t>(inst, 0, element_index); |
| } |
| |
| // Vector integer sign and zero extension instructions. |
| void Vzext2(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 2: |
| return RiscVUnaryVectorOp<uint16_t, uint8_t>( |
| rv_vector, inst, |
| [](uint8_t vs2) -> uint16_t { return static_cast<uint16_t>(vs2); }); |
| case 4: |
| return RiscVUnaryVectorOp<uint32_t, uint16_t>( |
| rv_vector, inst, |
| [](uint16_t vs2) -> uint32_t { return static_cast<uint32_t>(vs2); }); |
| case 8: |
| return RiscVUnaryVectorOp<uint64_t, uint32_t>( |
| rv_vector, inst, |
| [](uint32_t vs2) -> uint64_t { return static_cast<uint64_t>(vs2); }); |
| default: |
| LOG(ERROR) << absl::StrCat("Illegal SEW value (", sew, ") for Vzext2"); |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| } |
| |
| void Vsext2(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 2: |
| return RiscVUnaryVectorOp<int16_t, int8_t>( |
| rv_vector, inst, |
| [](int8_t vs2) -> int16_t { return static_cast<int16_t>(vs2); }); |
| case 4: |
| return RiscVUnaryVectorOp<uint32_t, uint16_t>( |
| rv_vector, inst, |
| [](int16_t vs2) -> int32_t { return static_cast<int32_t>(vs2); }); |
| case 8: |
| return RiscVUnaryVectorOp<int64_t, int32_t>( |
| rv_vector, inst, |
| [](int32_t vs2) -> int64_t { return static_cast<int64_t>(vs2); }); |
| default: |
| LOG(ERROR) << absl::StrCat("Illegal SEW value (", sew, ") for Vsext2"); |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| } |
| |
| void Vzext4(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 4: |
| return RiscVUnaryVectorOp<uint32_t, uint8_t>( |
| rv_vector, inst, |
| [](uint8_t vs2) -> uint32_t { return static_cast<uint32_t>(vs2); }); |
| case 8: |
| return RiscVUnaryVectorOp<uint64_t, uint16_t>( |
| rv_vector, inst, |
| [](uint16_t vs2) -> uint64_t { return static_cast<uint64_t>(vs2); }); |
| default: |
| LOG(ERROR) << absl::StrCat("Illegal SEW value (", sew, ") for Vzext4"); |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| } |
| |
| void Vsext4(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 4: |
| return RiscVUnaryVectorOp<uint32_t, uint8_t>( |
| rv_vector, inst, |
| [](int8_t vs2) -> int32_t { return static_cast<int32_t>(vs2); }); |
| case 8: |
| return RiscVUnaryVectorOp<int64_t, int16_t>( |
| rv_vector, inst, |
| [](int16_t vs2) -> int64_t { return static_cast<int64_t>(vs2); }); |
| default: |
| LOG(ERROR) << absl::StrCat("Illegal SEW value (", sew, ") for Vzext4"); |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| } |
| |
| void Vzext8(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 8: |
| return RiscVUnaryVectorOp<uint64_t, uint8_t>( |
| rv_vector, inst, |
| [](uint8_t vs2) -> uint64_t { return static_cast<uint64_t>(vs2); }); |
| default: |
| LOG(ERROR) << absl::StrCat("Illegal SEW value (", sew, ") for Vzext8"); |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| } |
| |
| void Vsext8(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| switch (sew) { |
| case 8: |
| return RiscVUnaryVectorOp<int64_t, int8_t>( |
| rv_vector, inst, |
| [](int8_t vs2) -> int64_t { return static_cast<int64_t>(vs2); }); |
| default: |
| LOG(ERROR) << absl::StrCat("Illegal SEW value (", sew, ") for Vzext8"); |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| } |
| |
| // Vector mask set-before-first mask bit. |
| void Vmsbf(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| if (rv_vector->vstart()) { |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| int vlen = rv_vector->vector_length(); |
| auto src_op = static_cast<RV32VectorSourceOperand *>(inst->Source(0)); |
| auto src_span = src_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); |
| auto mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(1)); |
| auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); |
| auto dest_op = |
| static_cast<RV32VectorDestinationOperand *>(inst->Destination(0)); |
| auto *dest_db = dest_op->CopyDataBuffer(0); |
| auto dest_span = dest_db->Get<uint8_t>(); |
| bool before_first = true; |
| int last = 0; |
| // Set the bits before the first active 1. |
| for (int i = 0; i < vlen; i++) { |
| last = i; |
| int index = i >> 3; |
| int offset = i & 0b111; |
| int mask_value = (mask_span[index] >> offset) & 0b1; |
| int src_value = (src_span[index] >> offset) & 0b1; |
| if (mask_value) { |
| before_first = before_first && (src_value == 0); |
| if (!before_first) break; |
| |
| dest_span[index] |= 1 << offset; |
| } |
| } |
| // Clear the remaining bits. |
| for (int i = last; !before_first && (i < vlen); i++) { |
| int index = i >> 3; |
| int offset = i & 0b111; |
| dest_span[index] &= ~(1 << offset); |
| } |
| dest_db->Submit(); |
| rv_vector->clear_vstart(); |
| } |
| |
| // Vector mask set-including-first mask bit. |
| void Vmsif(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| if (rv_vector->vstart()) { |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| int vlen = rv_vector->vector_length(); |
| auto src_op = static_cast<RV32VectorSourceOperand *>(inst->Source(0)); |
| auto src_span = src_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); |
| auto mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(1)); |
| auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); |
| auto dest_op = |
| static_cast<RV32VectorDestinationOperand *>(inst->Destination(0)); |
| auto *dest_db = dest_op->CopyDataBuffer(0); |
| auto dest_span = dest_db->Get<uint8_t>(); |
| uint8_t value = 1; |
| for (int i = 0; i < vlen; i++) { |
| int index = i >> 3; |
| int offset = i & 0b111; |
| int mask_value = (mask_span[index] >> offset) & 0b1; |
| int src_value = (src_span[index] >> offset) & 0b1; |
| if (mask_value) { |
| if (value) { |
| dest_span[index] |= 1 << offset; |
| } else { |
| dest_span[index] &= ~(1 << offset); |
| } |
| if (src_value) { |
| value = 0; |
| } |
| } |
| } |
| dest_db->Submit(); |
| rv_vector->clear_vstart(); |
| } |
| |
| // Vector maks set-only-first mask bit. |
| void Vmsof(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| if (rv_vector->vstart()) { |
| rv_vector->set_vector_exception(); |
| return; |
| } |
| int vlen = rv_vector->vector_length(); |
| auto src_op = static_cast<RV32VectorSourceOperand *>(inst->Source(0)); |
| auto src_span = src_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); |
| auto mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(1)); |
| auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); |
| auto dest_op = |
| static_cast<RV32VectorDestinationOperand *>(inst->Destination(0)); |
| auto *dest_db = dest_op->CopyDataBuffer(0); |
| auto dest_span = dest_db->Get<uint8_t>(); |
| bool first = true; |
| for (int i = 0; i < vlen; i++) { |
| int index = i >> 3; |
| int offset = i & 0b111; |
| int mask_value = (mask_span[index] >> offset) & 0b1; |
| int src_value = (src_span[index] >> offset) & 0b1; |
| if (mask_value) { |
| if (first & src_value) { |
| dest_span[index] |= (1 << offset); |
| first = false; |
| } else { |
| dest_span[index] &= ~(1 << offset); |
| } |
| } |
| } |
| dest_db->Submit(); |
| rv_vector->clear_vstart(); |
| } |
| |
| // Vector iota. This instruction reads a source vector mask register and |
| // writes to each element of the destination vector register group the sum |
| // of all bits of elements in the mask register whose index is less than the |
| // element. This is subject to masking. |
| void Viota(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| int count = 0; |
| switch (sew) { |
| case 1: |
| return RiscVMaskNullaryVectorOp<uint8_t>( |
| rv_vector, inst, [&count](bool mask) -> uint8_t { |
| return mask ? static_cast<uint8_t>(count++) |
| : static_cast<uint8_t>(count); |
| }); |
| case 2: |
| return RiscVMaskNullaryVectorOp<uint16_t>( |
| rv_vector, inst, [&count](bool mask) -> uint16_t { |
| return mask ? static_cast<uint16_t>(count++) |
| : static_cast<uint16_t>(count); |
| }); |
| case 4: |
| return RiscVMaskNullaryVectorOp<uint32_t>( |
| rv_vector, inst, [&count](bool mask) -> uint32_t { |
| return mask ? static_cast<uint32_t>(count++) |
| : static_cast<uint32_t>(count); |
| }); |
| case 8: |
| return RiscVMaskNullaryVectorOp<uint64_t>( |
| rv_vector, inst, [&count](bool mask) -> uint64_t { |
| return mask ? static_cast<uint64_t>(count++) |
| : static_cast<uint64_t>(count); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| // Writes the index of each active (mask true) element to the destination |
| // vector elements. |
| void Vid(Instruction *inst) { |
| auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector(); |
| int sew = rv_vector->selected_element_width(); |
| int index = 0; |
| switch (sew) { |
| case 1: |
| return RiscVMaskNullaryVectorOp<uint8_t>( |
| rv_vector, inst, [&index](bool mask) -> uint8_t { |
| uint64_t ret = index++; |
| return static_cast<uint8_t>(ret); |
| }); |
| case 2: |
| return RiscVMaskNullaryVectorOp<uint16_t>( |
| rv_vector, inst, [&index](bool mask) -> uint16_t { |
| uint64_t ret = index++; |
| return static_cast<uint16_t>(ret); |
| }); |
| case 4: |
| return RiscVMaskNullaryVectorOp<uint32_t>( |
| rv_vector, inst, [&index](bool mask) -> uint32_t { |
| uint64_t ret = index++; |
| return static_cast<uint32_t>(ret); |
| }); |
| case 8: |
| return RiscVMaskNullaryVectorOp<uint64_t>( |
| rv_vector, inst, [&index](bool mask) -> uint64_t { |
| uint64_t ret = index++; |
| return static_cast<uint64_t>(ret); |
| }); |
| default: |
| rv_vector->set_vector_exception(); |
| LOG(ERROR) << "Illegal SEW value"; |
| return; |
| } |
| } |
| |
| } // namespace cheriot |
| } // namespace sim |
| } // namespace mpact |