Fixes stride scaling in unit-stride vector loads/stores: the stride operand now equals the element width in bytes.
PiperOrigin-RevId: 663032671
Change-Id: Ie9a31cf5fdd338e23c1967e8d326fece7dbdee17
diff --git a/cheriot/cheriot_rvv_getters.h b/cheriot/cheriot_rvv_getters.h
index 303a2cb..190d897 100644
--- a/cheriot/cheriot_rvv_getters.h
+++ b/cheriot/cheriot_rvv_getters.h
@@ -52,6 +52,15 @@
Insert(getter_map, *Enum::kConst1, [common]() -> SourceOperandInterface * {
return new IntLiteralOperand<1>();
});
+ Insert(getter_map, *Enum::kConst2, [common]() -> SourceOperandInterface * {
+ return new IntLiteralOperand<2>();
+ });
+ Insert(getter_map, *Enum::kConst4, [common]() -> SourceOperandInterface * {
+ return new IntLiteralOperand<4>();
+ });
+ Insert(getter_map, *Enum::kConst8, [common]() -> SourceOperandInterface * {
+ return new IntLiteralOperand<8>();
+ });
Insert(getter_map, *Enum::kNf, [common]() -> SourceOperandInterface * {
auto imm = Extractors::VMem::ExtractNf(common->inst_word());
return new ImmediateOperand<uint32_t>(imm);
diff --git a/cheriot/riscv_cheriot_vector.isa b/cheriot/riscv_cheriot_vector.isa
index a4c00d8..d9078c5 100644
--- a/cheriot/riscv_cheriot_vector.isa
+++ b/cheriot/riscv_cheriot_vector.isa
@@ -59,32 +59,32 @@
// VECTOR LOADS
// Unit stride loads, masked (vm=0)
- vle8{(: rs1, vmask :), (: : vd )},
+ vle8{(: rs1, const1, vmask :), (: : vd )},
disasm: "vle8.v", "%vd, (%rs1), %vmask",
- semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 1)", "&VlChild";
- vle16{(: rs1, vmask :), (: : vd )},
+ semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 1)", "&VlChild";
+ vle16{(: rs1, const2, vmask :), (: : vd )},
disasm: "vle16.v", "%vd, (%rs1), %vmask",
- semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 2)", "&VlChild";
- vle32{(: rs1, vmask :), ( : : vd) },
+ semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 2)", "&VlChild";
+ vle32{(: rs1, const4, vmask :), ( : : vd) },
disasm: "vle32.v", "%vd, (%rs1), %vmask",
- semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 4)", "&VlChild";
- vle64{(: rs1, vmask :), ( : : vd) },
+ semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 4)", "&VlChild";
+ vle64{(: rs1, const8, vmask :), ( : : vd) },
disasm: "vle64.v", "%vd, (%rs1), %vmask",
- semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 8)", "&VlChild";
+ semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 8)", "&VlChild";
// Unit stride loads, unmasked (vm=1)
- vle8_vm1{(: rs1, vmask_true :), (: : vd )},
+ vle8_vm1{(: rs1, const1, vmask_true :), (: : vd )},
disasm: "vle8.v", "%vd, (%rs1)",
- semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 1)", "&VlChild";
- vle16_vm1{(: rs1, vmask_true :), (: : vd )},
+ semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 1)", "&VlChild";
+ vle16_vm1{(: rs1, const2, vmask_true :), (: : vd )},
disasm: "vle16.v", "%vd, (%rs1)",
- semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 2)", "&VlChild";
- vle32_vm1{(: rs1, vmask_true :), ( : : vd) },
+ semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 2)", "&VlChild";
+ vle32_vm1{(: rs1, const4, vmask_true :), ( : : vd) },
disasm: "vle32.v", "%vd, (%rs1)",
- semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 4)", "&VlChild";
- vle64_vm1{(: rs1, vmask_true :), ( : : vd) },
+ semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 4)", "&VlChild";
+ vle64_vm1{(: rs1, const8, vmask_true :), ( : : vd) },
disasm: "vle64.v", "%vd, (%rs1)",
- semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 8)", "&VlChild";
+ semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 8)", "&VlChild";
// Vector strided loads
vlse8{(: rs1, rs2, vmask :), (: : vd)},
@@ -106,18 +106,18 @@
semfunc: "&Vlm", "&VlChild";
// Unit stride vector load, fault first
- vle8ff{(: rs1, vmask:), (: : vd)},
+ vle8ff{(: rs1, const1, vmask:), (: : vd)},
disasm: "vle8ff.v", "%vd, (%rs1), %vmask",
- semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 1)", "&VlChild";
- vle16ff{(: rs1, vmask:), (: : vd)},
+ semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 1)", "&VlChild";
+ vle16ff{(: rs1, const2, vmask:), (: : vd)},
disasm: "vle16ff.v", "%vd, (%rs1), %vmask",
- semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 2)", "&VlChild";
- vle32ff{(: rs1, vmask:), (: : vd)},
+ semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 2)", "&VlChild";
+ vle32ff{(: rs1, const4, vmask:), (: : vd)},
disasm: "vle32ff.v", "%vd, (%rs1), %vmask",
- semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 4)", "&VlChild";
- vle64ff{(: rs1, vmask:), (: : vd)},
+ semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 4)", "&VlChild";
+ vle64ff{(: rs1, const8, vmask:), (: : vd)},
disasm: "vle64ff.v", "%vd, (%rs1), %vmask",
- semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 8)", "&VlChild";
+ semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 8)", "&VlChild";
// Vector register load
vl1re8{(: rs1 :), (: : vd)},
@@ -276,13 +276,13 @@
vse8{: vs3, rs1, const1, vmask : },
disasm: "vse8.v", "%vs3, (%rs1), %vmask",
semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 1)";
- vse16{: vs3, rs1, const1, vmask : },
+ vse16{: vs3, rs1, const2, vmask : },
disasm: "vse16.v", "%vs3, (%rs1), %vmask",
semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 2)";
- vse32{: vs3, rs1, const1, vmask : },
+ vse32{: vs3, rs1, const4, vmask : },
disasm: "vse32.v", "%vs3, (%rs1), %vmask",
semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 4)";
- vse64{: vs3, rs1, const1, vmask : },
+ vse64{: vs3, rs1, const8, vmask : },
disasm: "vse64.v", "%vs3, (%rs1), %vmask",
semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 8)";
@@ -295,13 +295,13 @@
vse8ff{: vs3, rs1, const1, vmask:},
disasm: "vse8ff.v", "%vs3, (%rs1), %vmask",
semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 1)";
- vse16ff{: vs3, rs1, const1, vmask:},
+ vse16ff{: vs3, rs1, const2, vmask:},
disasm: "vse16ff.v", "%vs3, (%rs1), %vmask",
semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 2)";
- vse32ff{: vs3, rs1, const1, vmask:},
+ vse32ff{: vs3, rs1, const4, vmask:},
disasm: "vse32ff.v", "%vs3, (%rs1), %vmask",
semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 4)";
- vse64ff{: vs3, rs1, const1, vmask:},
+ vse64ff{: vs3, rs1, const8, vmask:},
disasm: "vse64ff.v", "%vs3, (%rs1), %vmask",
semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 8)";
diff --git a/cheriot/riscv_cheriot_vector_memory_instructions.cc b/cheriot/riscv_cheriot_vector_memory_instructions.cc
index 66f3b41..f56bab5 100644
--- a/cheriot/riscv_cheriot_vector_memory_instructions.cc
+++ b/cheriot/riscv_cheriot_vector_memory_instructions.cc
@@ -256,77 +256,6 @@
// zero, or negative.
// Source(0): base address.
-// Source(1): vector mask register, vector constant {1..} if not masked.
-// Destination(0): vector destination register.
-void VlUnitStrided(int element_width, const Instruction *inst) {
- auto *state = static_cast<CheriotState *>(inst->state());
- auto *rv_vector = state->rv_vector();
- int start = rv_vector->vstart();
- auto cap_reg = GetCapSource(inst, 0);
- if (!CheckCapForMemoryAccess(inst, cap_reg, state)) return;
- uint64_t base = cap_reg->address();
- int emul = element_width * rv_vector->vector_length_multiplier() /
- rv_vector->selected_element_width();
- if ((emul > 64) || (emul == 0)) {
- // TODO: signal vector error.
- LOG(WARNING) << "EMUL (" << emul << ") out of range";
- return;
- }
-
- // Compute total number of elements to be loaded.
- int num_elements = rv_vector->vector_length();
- int num_elements_loaded = num_elements - start;
-
- // Allocate address data buffer.
- auto *db_factory = inst->state()->db_factory();
- auto *address_db = db_factory->Allocate<uint64_t>(num_elements_loaded);
-
- // Allocate the value data buffer that the loaded data is returned in.
- auto *value_db = db_factory->Allocate(num_elements_loaded * element_width);
-
- // Get the source mask (stored in a single vector register).
- auto *src_mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(1));
- auto src_masks = src_mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
-
- // Allocate a byte mask data buffer for the load.
- auto *mask_db = db_factory->Allocate<bool>(num_elements_loaded);
-
- // Get the spans for addresses and masks.
- auto addresses = address_db->Get<uint64_t>();
- auto masks = mask_db->Get<bool>();
-
- // The vector mask in the vector register is a bit mask. The mask used in
- // the LoadMemory call is a bool mask so convert the bit masks to bool masks
- // and compute the element addresses.
- for (int i = start; i < num_elements; i++) {
- int index = i >> 3;
- int offset = i & 0b111;
- addresses[i - start] = base + i * element_width;
- masks[i - start] = ((src_masks[index] >> offset) & 0b1) != 0;
- if (masks[i - start]) {
- if (!CheckCapBounds(inst, addresses[i - start], element_width, cap_reg,
- state)) {
- address_db->DecRef();
- mask_db->DecRef();
- value_db->DecRef();
- return;
- }
- }
- }
-
- // Set up the context, and submit the load.
- auto *context = new VectorLoadContext(value_db, mask_db, element_width, start,
- rv_vector->vector_length());
- value_db->set_latency(0);
- state->LoadMemory(inst, address_db, mask_db, element_width, value_db,
- inst->child(), context);
- // Release the context and address_db. The others will be released elsewhere.
- context->DecRef();
- address_db->DecRef();
- rv_vector->clear_vstart();
-}
-
-// Source(0): base address.
// Source(1): stride size bytes.
// Source(2): vector mask register, vector constant {1..} if not masked.
// Destination(0): vector destination register.
diff --git a/cheriot/test/riscv_cheriot_vector_memory_instructions_test.cc b/cheriot/test/riscv_cheriot_vector_memory_instructions_test.cc
index 17770bb..78b7f03 100644
--- a/cheriot/test/riscv_cheriot_vector_memory_instructions_test.cc
+++ b/cheriot/test/riscv_cheriot_vector_memory_instructions_test.cc
@@ -68,7 +68,6 @@
using ::mpact::sim::cheriot::VlSegmentIndexed;
using ::mpact::sim::cheriot::VlSegmentStrided;
using ::mpact::sim::cheriot::VlStrided;
-using ::mpact::sim::cheriot::VlUnitStrided;
using ::mpact::sim::cheriot::Vsetvl;
using ::mpact::sim::cheriot::VsIndexed;
using ::mpact::sim::cheriot::Vsm;
@@ -314,16 +313,17 @@
template <typename T>
void VectorLoadUnitStridedHelper() {
// Set up instructions.
- AppendRegisterOperands({kRs1Name}, {});
+ AppendRegisterOperands({kRs1Name, kRs2Name}, {});
AppendVectorRegisterOperands({kVmask}, {});
- SetSemanticFunction(absl::bind_front(&VlUnitStrided,
+ SetSemanticFunction(absl::bind_front(&VlStrided,
/*element_width*/ sizeof(T)));
// Add the child instruction that performs the register write-back.
SetChildInstruction();
SetChildSemanticFunction(&VlChild);
AppendVectorRegisterOperands(child_instruction_, {}, {kVd});
// Set up register values.
- SetRegisterValues<uint32_t>({{kRs1Name, kDataLoadAddress}});
+ SetRegisterValues<uint32_t>(
+ {{kRs1Name, kDataLoadAddress}, {kRs2Name, sizeof(T)}});
SetVectorRegisterValues<uint8_t>(
{{kVmaskName, Span<const uint8_t>(kA5Mask)}});
// Iterate over different lmul values.