Fixes stride-scaling issues in vector loads/stores: unit-stride loads (vle*, vle*ff) now route through VlStrided with an explicit constant stride operand equal to the element width (const1/2/4/8), replacing the removed VlUnitStrided; unit-stride stores (vse16/32/64, vse*ff) previously passed const1 regardless of element width and now pass the correct const2/const4/const8. PiperOrigin-RevId: 663032671 Change-Id: Ie9a31cf5fdd338e23c1967e8d326fece7dbdee17
diff --git a/cheriot/cheriot_rvv_getters.h b/cheriot/cheriot_rvv_getters.h index 303a2cb..190d897 100644 --- a/cheriot/cheriot_rvv_getters.h +++ b/cheriot/cheriot_rvv_getters.h
@@ -52,6 +52,15 @@ Insert(getter_map, *Enum::kConst1, [common]() -> SourceOperandInterface * { return new IntLiteralOperand<1>(); }); + Insert(getter_map, *Enum::kConst2, [common]() -> SourceOperandInterface * { + return new IntLiteralOperand<2>(); + }); + Insert(getter_map, *Enum::kConst4, [common]() -> SourceOperandInterface * { + return new IntLiteralOperand<4>(); + }); + Insert(getter_map, *Enum::kConst8, [common]() -> SourceOperandInterface * { + return new IntLiteralOperand<8>(); + }); Insert(getter_map, *Enum::kNf, [common]() -> SourceOperandInterface * { auto imm = Extractors::VMem::ExtractNf(common->inst_word()); return new ImmediateOperand<uint32_t>(imm);
diff --git a/cheriot/riscv_cheriot_vector.isa b/cheriot/riscv_cheriot_vector.isa index a4c00d8..d9078c5 100644 --- a/cheriot/riscv_cheriot_vector.isa +++ b/cheriot/riscv_cheriot_vector.isa
@@ -59,32 +59,32 @@ // VECTOR LOADS // Unit stride loads, masked (vm=0) - vle8{(: rs1, vmask :), (: : vd )}, + vle8{(: rs1, const1, vmask :), (: : vd )}, disasm: "vle8.v", "%vd, (%rs1), %vmask", - semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 1)", "&VlChild"; - vle16{(: rs1, vmask :), (: : vd )}, + semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 1)", "&VlChild"; + vle16{(: rs1, const2, vmask :), (: : vd )}, disasm: "vle16.v", "%vd, (%rs1), %vmask", - semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 2)", "&VlChild"; - vle32{(: rs1, vmask :), ( : : vd) }, + semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 2)", "&VlChild"; + vle32{(: rs1, const4, vmask :), ( : : vd) }, disasm: "vle32.v", "%vd, (%rs1), %vmask", - semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 4)", "&VlChild"; - vle64{(: rs1, vmask :), ( : : vd) }, + semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 4)", "&VlChild"; + vle64{(: rs1, const8, vmask :), ( : : vd) }, disasm: "vle64.v", "%vd, (%rs1), %vmask", - semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 8)", "&VlChild"; + semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 8)", "&VlChild"; // Unit stride loads, unmasked (vm=1) - vle8_vm1{(: rs1, vmask_true :), (: : vd )}, + vle8_vm1{(: rs1, const1, vmask_true :), (: : vd )}, disasm: "vle8.v", "%vd, (%rs1)", - semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 1)", "&VlChild"; - vle16_vm1{(: rs1, vmask_true :), (: : vd )}, + semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 1)", "&VlChild"; + vle16_vm1{(: rs1, const2, vmask_true :), (: : vd )}, disasm: "vle16.v", "%vd, (%rs1)", - semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 2)", "&VlChild"; - vle32_vm1{(: rs1, vmask_true :), ( : : vd) }, + semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 2)", "&VlChild"; + vle32_vm1{(: rs1, const4, vmask_true :), ( : : vd) }, disasm: "vle32.v", "%vd, (%rs1)", - semfunc: 
"absl::bind_front(&VlUnitStrided, /*element_width*/ 4)", "&VlChild"; - vle64_vm1{(: rs1, vmask_true :), ( : : vd) }, + semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 4)", "&VlChild"; + vle64_vm1{(: rs1, const8, vmask_true :), ( : : vd) }, disasm: "vle64.v", "%vd, (%rs1)", - semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 8)", "&VlChild"; + semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 8)", "&VlChild"; // Vector strided loads vlse8{(: rs1, rs2, vmask :), (: : vd)}, @@ -106,18 +106,18 @@ semfunc: "&Vlm", "&VlChild"; // Unit stride vector load, fault first - vle8ff{(: rs1, vmask:), (: : vd)}, + vle8ff{(: rs1, const1, vmask:), (: : vd)}, disasm: "vle8ff.v", "%vd, (%rs1), %vmask", - semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 1)", "&VlChild"; - vle16ff{(: rs1, vmask:), (: : vd)}, + semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 1)", "&VlChild"; + vle16ff{(: rs1, const2, vmask:), (: : vd)}, disasm: "vle16ff.v", "%vd, (%rs1), %vmask", - semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 2)", "&VlChild"; - vle32ff{(: rs1, vmask:), (: : vd)}, + semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 2)", "&VlChild"; + vle32ff{(: rs1, const4, vmask:), (: : vd)}, disasm: "vle32ff.v", "%vd, (%rs1), %vmask", - semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 4)", "&VlChild"; - vle64ff{(: rs1, vmask:), (: : vd)}, + semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 4)", "&VlChild"; + vle64ff{(: rs1, const8, vmask:), (: : vd)}, disasm: "vle64ff.v", "%vd, (%rs1), %vmask", - semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 8)", "&VlChild"; + semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 8)", "&VlChild"; // Vector register load vl1re8{(: rs1 :), (: : vd)}, @@ -276,13 +276,13 @@ vse8{: vs3, rs1, const1, vmask : }, disasm: "vse8.v", "%vs3, (%rs1), %vmask", semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 1)"; - vse16{: vs3, rs1, const1, vmask : }, + vse16{: vs3, 
rs1, const2, vmask : }, disasm: "vse16.v", "%vs3, (%rs1), %vmask", semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 2)"; - vse32{: vs3, rs1, const1, vmask : }, + vse32{: vs3, rs1, const4, vmask : }, disasm: "vse32.v", "%vs3, (%rs1), %vmask", semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 4)"; - vse64{: vs3, rs1, const1, vmask : }, + vse64{: vs3, rs1, const8, vmask : }, disasm: "vse64.v", "%vs3, (%rs1), %vmask", semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 8)"; @@ -295,13 +295,13 @@ vse8ff{: vs3, rs1, const1, vmask:}, disasm: "vse8ff.v", "%vs3, (%rs1), %vmask", semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 1)"; - vse16ff{: vs3, rs1, const1, vmask:}, + vse16ff{: vs3, rs1, const2, vmask:}, disasm: "vse16ff.v", "%vs3, (%rs1), %vmask", semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 2)"; - vse32ff{: vs3, rs1, const1, vmask:}, + vse32ff{: vs3, rs1, const4, vmask:}, disasm: "vse32ff.v", "%vs3, (%rs1), %vmask", semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 4)"; - vse64ff{: vs3, rs1, const1, vmask:}, + vse64ff{: vs3, rs1, const8, vmask:}, disasm: "vse64ff.v", "%vs3, (%rs1), %vmask", semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 8)";
diff --git a/cheriot/riscv_cheriot_vector_memory_instructions.cc b/cheriot/riscv_cheriot_vector_memory_instructions.cc index 66f3b41..f56bab5 100644 --- a/cheriot/riscv_cheriot_vector_memory_instructions.cc +++ b/cheriot/riscv_cheriot_vector_memory_instructions.cc
@@ -256,77 +256,6 @@ // zero, or negative. // Source(0): base address. -// Source(1): vector mask register, vector constant {1..} if not masked. -// Destination(0): vector destination register. -void VlUnitStrided(int element_width, const Instruction *inst) { - auto *state = static_cast<CheriotState *>(inst->state()); - auto *rv_vector = state->rv_vector(); - int start = rv_vector->vstart(); - auto cap_reg = GetCapSource(inst, 0); - if (!CheckCapForMemoryAccess(inst, cap_reg, state)) return; - uint64_t base = cap_reg->address(); - int emul = element_width * rv_vector->vector_length_multiplier() / - rv_vector->selected_element_width(); - if ((emul > 64) || (emul == 0)) { - // TODO: signal vector error. - LOG(WARNING) << "EMUL (" << emul << ") out of range"; - return; - } - - // Compute total number of elements to be loaded. - int num_elements = rv_vector->vector_length(); - int num_elements_loaded = num_elements - start; - - // Allocate address data buffer. - auto *db_factory = inst->state()->db_factory(); - auto *address_db = db_factory->Allocate<uint64_t>(num_elements_loaded); - - // Allocate the value data buffer that the loaded data is returned in. - auto *value_db = db_factory->Allocate(num_elements_loaded * element_width); - - // Get the source mask (stored in a single vector register). - auto *src_mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(1)); - auto src_masks = src_mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>(); - - // Allocate a byte mask data buffer for the load. - auto *mask_db = db_factory->Allocate<bool>(num_elements_loaded); - - // Get the spans for addresses and masks. - auto addresses = address_db->Get<uint64_t>(); - auto masks = mask_db->Get<bool>(); - - // The vector mask in the vector register is a bit mask. The mask used in - // the LoadMemory call is a bool mask so convert the bit masks to bool masks - // and compute the element addresses. 
- for (int i = start; i < num_elements; i++) { - int index = i >> 3; - int offset = i & 0b111; - addresses[i - start] = base + i * element_width; - masks[i - start] = ((src_masks[index] >> offset) & 0b1) != 0; - if (masks[i - start]) { - if (!CheckCapBounds(inst, addresses[i - start], element_width, cap_reg, - state)) { - address_db->DecRef(); - mask_db->DecRef(); - value_db->DecRef(); - return; - } - } - } - - // Set up the context, and submit the load. - auto *context = new VectorLoadContext(value_db, mask_db, element_width, start, - rv_vector->vector_length()); - value_db->set_latency(0); - state->LoadMemory(inst, address_db, mask_db, element_width, value_db, - inst->child(), context); - // Release the context and address_db. The others will be released elsewhere. - context->DecRef(); - address_db->DecRef(); - rv_vector->clear_vstart(); -} - -// Source(0): base address. // Source(1): stride size bytes. // Source(2): vector mask register, vector constant {1..} if not masked. // Destination(0): vector destination register.
diff --git a/cheriot/test/riscv_cheriot_vector_memory_instructions_test.cc b/cheriot/test/riscv_cheriot_vector_memory_instructions_test.cc index 17770bb..78b7f03 100644 --- a/cheriot/test/riscv_cheriot_vector_memory_instructions_test.cc +++ b/cheriot/test/riscv_cheriot_vector_memory_instructions_test.cc
@@ -68,7 +68,6 @@ using ::mpact::sim::cheriot::VlSegmentIndexed; using ::mpact::sim::cheriot::VlSegmentStrided; using ::mpact::sim::cheriot::VlStrided; -using ::mpact::sim::cheriot::VlUnitStrided; using ::mpact::sim::cheriot::Vsetvl; using ::mpact::sim::cheriot::VsIndexed; using ::mpact::sim::cheriot::Vsm; @@ -314,16 +313,17 @@ template <typename T> void VectorLoadUnitStridedHelper() { // Set up instructions. - AppendRegisterOperands({kRs1Name}, {}); + AppendRegisterOperands({kRs1Name, kRs2Name}, {}); AppendVectorRegisterOperands({kVmask}, {}); - SetSemanticFunction(absl::bind_front(&VlUnitStrided, + SetSemanticFunction(absl::bind_front(&VlStrided, /*element_width*/ sizeof(T))); // Add the child instruction that performs the register write-back. SetChildInstruction(); SetChildSemanticFunction(&VlChild); AppendVectorRegisterOperands(child_instruction_, {}, {kVd}); // Set up register values. - SetRegisterValues<uint32_t>({{kRs1Name, kDataLoadAddress}}); + SetRegisterValues<uint32_t>( + {{kRs1Name, kDataLoadAddress}, {kRs2Name, sizeof(T)}}); SetVectorRegisterValues<uint8_t>( {{kVmaskName, Span<const uint8_t>(kA5Mask)}}); // Iterate over different lmul values.