blob: becf09a734d34c753c3c1d36fde5da5b99ce0e45 [file] [log] [blame]
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "cheriot/riscv_cheriot_vector_unary_instructions.h"
#include <cstdint>
#include <cstring>
#include <functional>
#include "absl/log/log.h"
#include "absl/strings/str_cat.h"
#include "cheriot/cheriot_register.h"
#include "cheriot/cheriot_state.h"
#include "cheriot/cheriot_vector_state.h"
#include "cheriot/riscv_cheriot_instruction_helpers.h"
#include "cheriot/riscv_cheriot_vector_instruction_helpers.h"
#include "mpact/sim/generic/instruction.h"
#include "mpact/sim/generic/type_helpers.h"
namespace mpact {
namespace sim {
namespace cheriot {
using SignedXregType =
::mpact::sim::generic::SameSignedType<CheriotRegister::ValueType,
int64_t>::type;
// Move scalar to vector register.
void VmvToScalar(Instruction *inst) {
auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
if (rv_vector->vstart()) return;
if (rv_vector->vector_length() == 0) return;
int sew = rv_vector->selected_element_width();
SignedXregType value;
switch (sew) {
case 1:
value = static_cast<SignedXregType>(
generic::GetInstructionSource<int8_t>(inst, 0));
break;
case 2:
value = static_cast<SignedXregType>(
generic::GetInstructionSource<int16_t>(inst, 0));
break;
case 4:
value = static_cast<SignedXregType>(
generic::GetInstructionSource<int32_t>(inst, 0));
break;
case 8:
value = static_cast<SignedXregType>(
generic::GetInstructionSource<int64_t>(inst, 0));
break;
default:
LOG(ERROR) << absl::StrCat("Illegal SEW value (", sew, ") for Vmvxs");
rv_vector->set_vector_exception();
return;
}
WriteCapIntResult<SignedXregType>(inst, 0, value);
}
void VmvFromScalar(Instruction *inst) {
auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
if (rv_vector->vstart()) return;
if (rv_vector->vector_length() == 0) return;
int sew = rv_vector->selected_element_width();
auto *dest_db = inst->Destination(0)->AllocateDataBuffer();
std::memset(dest_db->raw_ptr(), 0, dest_db->size<uint8_t>());
switch (sew) {
case 1:
dest_db->Set<int8_t>(0, generic::GetInstructionSource<int8_t>(inst, 0));
break;
case 2:
dest_db->Set<int16_t>(0, generic::GetInstructionSource<int16_t>(inst, 0));
break;
case 4:
dest_db->Set<int32_t>(0, generic::GetInstructionSource<int32_t>(inst, 0));
break;
case 8:
dest_db->Set<int64_t>(0, generic::GetInstructionSource<int64_t>(inst, 0));
break;
default:
LOG(ERROR) << absl::StrCat("Illegal SEW value (", sew, ") for Vmvxs");
rv_vector->set_vector_exception();
return;
}
dest_db->Submit();
}
// Population count of vector mask register. The value is written to a scalar
// register.
void Vcpop(Instruction *inst) {
auto *rv_state = static_cast<CheriotState *>(inst->state());
auto *rv_vector = rv_state->rv_vector();
if (rv_vector->vstart()) {
rv_vector->set_vector_exception();
return;
}
int vlen = rv_vector->vector_length();
auto src_op = static_cast<RV32VectorSourceOperand *>(inst->Source(0));
auto src_span = src_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
auto mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(1));
auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
uint64_t count = 0;
for (int i = 0; i < vlen; i++) {
int index = i >> 3;
int offset = i & 0b111;
int mask_value = (mask_span[index] >> offset);
int src_value = (src_span[index] >> offset);
count += mask_value & src_value & 0b1;
}
WriteCapIntResult<uint32_t>(inst, 0, count);
}
// Find first set of vector mask register. The value is written to a scalar
// register.
void Vfirst(Instruction *inst) {
auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
if (rv_vector->vstart()) {
rv_vector->set_vector_exception();
return;
}
auto src_op = static_cast<RV32VectorSourceOperand *>(inst->Source(0));
auto src_span = src_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
auto mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(1));
auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
// Initialize the element index to -1.
uint64_t element_index = -1LL;
int vlen = rv_vector->vector_length();
for (int i = 0; i < vlen; i++) {
int index = i >> 3;
int offset = i & 0b111;
int mask_value = (mask_span[index] >> offset);
int src_value = (src_span[index] >> offset);
if (mask_value & src_value & 0b1) {
element_index = i;
break;
}
}
WriteCapIntResult<uint32_t>(inst, 0, element_index);
}
// Vector integer sign and zero extension instructions.
void Vzext2(Instruction *inst) {
auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
int sew = rv_vector->selected_element_width();
switch (sew) {
case 2:
return RiscVUnaryVectorOp<uint16_t, uint8_t>(
rv_vector, inst,
[](uint8_t vs2) -> uint16_t { return static_cast<uint16_t>(vs2); });
case 4:
return RiscVUnaryVectorOp<uint32_t, uint16_t>(
rv_vector, inst,
[](uint16_t vs2) -> uint32_t { return static_cast<uint32_t>(vs2); });
case 8:
return RiscVUnaryVectorOp<uint64_t, uint32_t>(
rv_vector, inst,
[](uint32_t vs2) -> uint64_t { return static_cast<uint64_t>(vs2); });
default:
LOG(ERROR) << absl::StrCat("Illegal SEW value (", sew, ") for Vzext2");
rv_vector->set_vector_exception();
return;
}
}
void Vsext2(Instruction *inst) {
auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
int sew = rv_vector->selected_element_width();
switch (sew) {
case 2:
return RiscVUnaryVectorOp<int16_t, int8_t>(
rv_vector, inst,
[](int8_t vs2) -> int16_t { return static_cast<int16_t>(vs2); });
case 4:
return RiscVUnaryVectorOp<uint32_t, uint16_t>(
rv_vector, inst,
[](int16_t vs2) -> int32_t { return static_cast<int32_t>(vs2); });
case 8:
return RiscVUnaryVectorOp<int64_t, int32_t>(
rv_vector, inst,
[](int32_t vs2) -> int64_t { return static_cast<int64_t>(vs2); });
default:
LOG(ERROR) << absl::StrCat("Illegal SEW value (", sew, ") for Vsext2");
rv_vector->set_vector_exception();
return;
}
}
void Vzext4(Instruction *inst) {
auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
int sew = rv_vector->selected_element_width();
switch (sew) {
case 4:
return RiscVUnaryVectorOp<uint32_t, uint8_t>(
rv_vector, inst,
[](uint8_t vs2) -> uint32_t { return static_cast<uint32_t>(vs2); });
case 8:
return RiscVUnaryVectorOp<uint64_t, uint16_t>(
rv_vector, inst,
[](uint16_t vs2) -> uint64_t { return static_cast<uint64_t>(vs2); });
default:
LOG(ERROR) << absl::StrCat("Illegal SEW value (", sew, ") for Vzext4");
rv_vector->set_vector_exception();
return;
}
}
void Vsext4(Instruction *inst) {
auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
int sew = rv_vector->selected_element_width();
switch (sew) {
case 4:
return RiscVUnaryVectorOp<uint32_t, uint8_t>(
rv_vector, inst,
[](int8_t vs2) -> int32_t { return static_cast<int32_t>(vs2); });
case 8:
return RiscVUnaryVectorOp<int64_t, int16_t>(
rv_vector, inst,
[](int16_t vs2) -> int64_t { return static_cast<int64_t>(vs2); });
default:
LOG(ERROR) << absl::StrCat("Illegal SEW value (", sew, ") for Vzext4");
rv_vector->set_vector_exception();
return;
}
}
void Vzext8(Instruction *inst) {
auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
int sew = rv_vector->selected_element_width();
switch (sew) {
case 8:
return RiscVUnaryVectorOp<uint64_t, uint8_t>(
rv_vector, inst,
[](uint8_t vs2) -> uint64_t { return static_cast<uint64_t>(vs2); });
default:
LOG(ERROR) << absl::StrCat("Illegal SEW value (", sew, ") for Vzext8");
rv_vector->set_vector_exception();
return;
}
}
void Vsext8(Instruction *inst) {
auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
int sew = rv_vector->selected_element_width();
switch (sew) {
case 8:
return RiscVUnaryVectorOp<int64_t, int8_t>(
rv_vector, inst,
[](int8_t vs2) -> int64_t { return static_cast<int64_t>(vs2); });
default:
LOG(ERROR) << absl::StrCat("Illegal SEW value (", sew, ") for Vzext8");
rv_vector->set_vector_exception();
return;
}
}
// Vector mask set-before-first mask bit.
void Vmsbf(Instruction *inst) {
auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
if (rv_vector->vstart()) {
rv_vector->set_vector_exception();
return;
}
int vlen = rv_vector->vector_length();
auto src_op = static_cast<RV32VectorSourceOperand *>(inst->Source(0));
auto src_span = src_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
auto mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(1));
auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
auto dest_op =
static_cast<RV32VectorDestinationOperand *>(inst->Destination(0));
auto *dest_db = dest_op->CopyDataBuffer(0);
auto dest_span = dest_db->Get<uint8_t>();
bool before_first = true;
int last = 0;
// Set the bits before the first active 1.
for (int i = 0; i < vlen; i++) {
last = i;
int index = i >> 3;
int offset = i & 0b111;
int mask_value = (mask_span[index] >> offset) & 0b1;
int src_value = (src_span[index] >> offset) & 0b1;
if (mask_value) {
before_first = before_first && (src_value == 0);
if (!before_first) break;
dest_span[index] |= 1 << offset;
}
}
// Clear the remaining bits.
for (int i = last; !before_first && (i < vlen); i++) {
int index = i >> 3;
int offset = i & 0b111;
dest_span[index] &= ~(1 << offset);
}
dest_db->Submit();
rv_vector->clear_vstart();
}
// Vector mask set-including-first mask bit.
void Vmsif(Instruction *inst) {
auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
if (rv_vector->vstart()) {
rv_vector->set_vector_exception();
return;
}
int vlen = rv_vector->vector_length();
auto src_op = static_cast<RV32VectorSourceOperand *>(inst->Source(0));
auto src_span = src_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
auto mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(1));
auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
auto dest_op =
static_cast<RV32VectorDestinationOperand *>(inst->Destination(0));
auto *dest_db = dest_op->CopyDataBuffer(0);
auto dest_span = dest_db->Get<uint8_t>();
uint8_t value = 1;
for (int i = 0; i < vlen; i++) {
int index = i >> 3;
int offset = i & 0b111;
int mask_value = (mask_span[index] >> offset) & 0b1;
int src_value = (src_span[index] >> offset) & 0b1;
if (mask_value) {
if (value) {
dest_span[index] |= 1 << offset;
} else {
dest_span[index] &= ~(1 << offset);
}
if (src_value) {
value = 0;
}
}
}
dest_db->Submit();
rv_vector->clear_vstart();
}
// Vector maks set-only-first mask bit.
void Vmsof(Instruction *inst) {
auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
if (rv_vector->vstart()) {
rv_vector->set_vector_exception();
return;
}
int vlen = rv_vector->vector_length();
auto src_op = static_cast<RV32VectorSourceOperand *>(inst->Source(0));
auto src_span = src_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
auto mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(1));
auto mask_span = mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
auto dest_op =
static_cast<RV32VectorDestinationOperand *>(inst->Destination(0));
auto *dest_db = dest_op->CopyDataBuffer(0);
auto dest_span = dest_db->Get<uint8_t>();
bool first = true;
for (int i = 0; i < vlen; i++) {
int index = i >> 3;
int offset = i & 0b111;
int mask_value = (mask_span[index] >> offset) & 0b1;
int src_value = (src_span[index] >> offset) & 0b1;
if (mask_value) {
if (first & src_value) {
dest_span[index] |= (1 << offset);
first = false;
} else {
dest_span[index] &= ~(1 << offset);
}
}
}
dest_db->Submit();
rv_vector->clear_vstart();
}
// Vector iota. This instruction reads a source vector mask register and
// writes to each element of the destination vector register group the sum
// of all bits of elements in the mask register whose index is less than the
// element. This is subject to masking.
void Viota(Instruction *inst) {
auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
int sew = rv_vector->selected_element_width();
int count = 0;
switch (sew) {
case 1:
return RiscVMaskNullaryVectorOp<uint8_t>(
rv_vector, inst, [&count](bool mask) -> uint8_t {
return mask ? static_cast<uint8_t>(count++)
: static_cast<uint8_t>(count);
});
case 2:
return RiscVMaskNullaryVectorOp<uint16_t>(
rv_vector, inst, [&count](bool mask) -> uint16_t {
return mask ? static_cast<uint16_t>(count++)
: static_cast<uint16_t>(count);
});
case 4:
return RiscVMaskNullaryVectorOp<uint32_t>(
rv_vector, inst, [&count](bool mask) -> uint32_t {
return mask ? static_cast<uint32_t>(count++)
: static_cast<uint32_t>(count);
});
case 8:
return RiscVMaskNullaryVectorOp<uint64_t>(
rv_vector, inst, [&count](bool mask) -> uint64_t {
return mask ? static_cast<uint64_t>(count++)
: static_cast<uint64_t>(count);
});
default:
rv_vector->set_vector_exception();
LOG(ERROR) << "Illegal SEW value";
return;
}
}
// Writes the index of each active (mask true) element to the destination
// vector elements.
void Vid(Instruction *inst) {
auto *rv_vector = static_cast<CheriotState *>(inst->state())->rv_vector();
int sew = rv_vector->selected_element_width();
int index = 0;
switch (sew) {
case 1:
return RiscVMaskNullaryVectorOp<uint8_t>(
rv_vector, inst, [&index](bool mask) -> uint8_t {
uint64_t ret = index++;
return static_cast<uint8_t>(ret);
});
case 2:
return RiscVMaskNullaryVectorOp<uint16_t>(
rv_vector, inst, [&index](bool mask) -> uint16_t {
uint64_t ret = index++;
return static_cast<uint16_t>(ret);
});
case 4:
return RiscVMaskNullaryVectorOp<uint32_t>(
rv_vector, inst, [&index](bool mask) -> uint32_t {
uint64_t ret = index++;
return static_cast<uint32_t>(ret);
});
case 8:
return RiscVMaskNullaryVectorOp<uint64_t>(
rv_vector, inst, [&index](bool mask) -> uint64_t {
uint64_t ret = index++;
return static_cast<uint64_t>(ret);
});
default:
rv_vector->set_vector_exception();
LOG(ERROR) << "Illegal SEW value";
return;
}
}
} // namespace cheriot
} // namespace sim
} // namespace mpact