Fixes element-width scaling issues in vector loads/stores: removes VlUnitStrided in favor of VlStrided with an explicit constant stride operand (const1/const2/const4/const8, with new kConst2/kConst4/kConst8 operand getters), and corrects vse16/vse32/vse64 and their fault-first variants, which previously passed const1 as the stride regardless of element width.

PiperOrigin-RevId: 663032671
Change-Id: Ie9a31cf5fdd338e23c1967e8d326fece7dbdee17
diff --git a/cheriot/cheriot_rvv_getters.h b/cheriot/cheriot_rvv_getters.h
index 303a2cb..190d897 100644
--- a/cheriot/cheriot_rvv_getters.h
+++ b/cheriot/cheriot_rvv_getters.h
@@ -52,6 +52,15 @@
   Insert(getter_map, *Enum::kConst1, [common]() -> SourceOperandInterface * {
     return new IntLiteralOperand<1>();
   });
+  Insert(getter_map, *Enum::kConst2, [common]() -> SourceOperandInterface * {
+    return new IntLiteralOperand<2>();
+  });
+  Insert(getter_map, *Enum::kConst4, [common]() -> SourceOperandInterface * {
+    return new IntLiteralOperand<4>();
+  });
+  Insert(getter_map, *Enum::kConst8, [common]() -> SourceOperandInterface * {
+    return new IntLiteralOperand<8>();
+  });
   Insert(getter_map, *Enum::kNf, [common]() -> SourceOperandInterface * {
     auto imm = Extractors::VMem::ExtractNf(common->inst_word());
     return new ImmediateOperand<uint32_t>(imm);
diff --git a/cheriot/riscv_cheriot_vector.isa b/cheriot/riscv_cheriot_vector.isa
index a4c00d8..d9078c5 100644
--- a/cheriot/riscv_cheriot_vector.isa
+++ b/cheriot/riscv_cheriot_vector.isa
@@ -59,32 +59,32 @@
     // VECTOR LOADS
 
     // Unit stride loads, masked (vm=0)
-    vle8{(: rs1, vmask :), (: : vd )},
+    vle8{(: rs1, const1, vmask :), (: : vd )},
       disasm: "vle8.v", "%vd, (%rs1), %vmask",
-      semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 1)", "&VlChild";
-    vle16{(: rs1, vmask :), (: : vd )},
+      semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 1)", "&VlChild";
+    vle16{(: rs1, const2, vmask :), (: : vd )},
       disasm: "vle16.v", "%vd, (%rs1), %vmask",
-      semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 2)", "&VlChild";
-    vle32{(: rs1, vmask :), ( : : vd) },
+      semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 2)", "&VlChild";
+    vle32{(: rs1, const4, vmask :), ( : : vd) },
       disasm: "vle32.v", "%vd, (%rs1), %vmask",
-      semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 4)", "&VlChild";
-    vle64{(: rs1, vmask :), ( : : vd) },
+      semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 4)", "&VlChild";
+    vle64{(: rs1, const8, vmask :), ( : : vd) },
       disasm: "vle64.v", "%vd, (%rs1), %vmask",
-      semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 8)", "&VlChild";
+      semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 8)", "&VlChild";
 
     // Unit stride loads, unmasked (vm=1)
-    vle8_vm1{(: rs1, vmask_true :), (: : vd )},
+    vle8_vm1{(: rs1, const1, vmask_true :), (: : vd )},
       disasm: "vle8.v", "%vd, (%rs1)",
-      semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 1)", "&VlChild";
-    vle16_vm1{(: rs1, vmask_true :), (: : vd )},
+      semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 1)", "&VlChild";
+    vle16_vm1{(: rs1, const2, vmask_true :), (: : vd )},
       disasm: "vle16.v", "%vd, (%rs1)",
-      semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 2)", "&VlChild";
-    vle32_vm1{(: rs1, vmask_true :), ( : : vd) },
+      semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 2)", "&VlChild";
+    vle32_vm1{(: rs1, const4, vmask_true :), ( : : vd) },
       disasm: "vle32.v", "%vd, (%rs1)",
-      semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 4)", "&VlChild";
-    vle64_vm1{(: rs1, vmask_true :), ( : : vd) },
+      semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 4)", "&VlChild";
+    vle64_vm1{(: rs1, const8, vmask_true :), ( : : vd) },
       disasm: "vle64.v", "%vd, (%rs1)",
-      semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 8)", "&VlChild";
+      semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 8)", "&VlChild";
 
     // Vector strided loads
     vlse8{(: rs1, rs2, vmask :), (: : vd)},
@@ -106,18 +106,18 @@
       semfunc: "&Vlm", "&VlChild";
 
     // Unit stride vector load, fault first
-    vle8ff{(: rs1, vmask:), (: : vd)},
+    vle8ff{(: rs1, const1, vmask:), (: : vd)},
       disasm: "vle8ff.v", "%vd, (%rs1), %vmask",
-      semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 1)", "&VlChild";
-    vle16ff{(: rs1, vmask:), (: : vd)},
+      semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 1)", "&VlChild";
+    vle16ff{(: rs1, const2, vmask:), (: : vd)},
       disasm: "vle16ff.v", "%vd, (%rs1), %vmask",
-      semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 2)", "&VlChild";
-    vle32ff{(: rs1, vmask:), (: : vd)},
+      semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 2)", "&VlChild";
+    vle32ff{(: rs1, const4, vmask:), (: : vd)},
       disasm: "vle32ff.v", "%vd, (%rs1), %vmask",
-      semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 4)", "&VlChild";
-    vle64ff{(: rs1, vmask:), (: : vd)},
+      semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 4)", "&VlChild";
+    vle64ff{(: rs1, const8, vmask:), (: : vd)},
       disasm: "vle64ff.v", "%vd, (%rs1), %vmask",
-      semfunc: "absl::bind_front(&VlUnitStrided, /*element_width*/ 8)", "&VlChild";
+      semfunc: "absl::bind_front(&VlStrided, /*element_width*/ 8)", "&VlChild";
 
     // Vector register load
     vl1re8{(: rs1 :), (: : vd)},
@@ -276,13 +276,13 @@
     vse8{: vs3, rs1, const1, vmask : },
       disasm: "vse8.v", "%vs3, (%rs1), %vmask",
       semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 1)";
-    vse16{: vs3, rs1, const1, vmask : },
+    vse16{: vs3, rs1, const2, vmask : },
       disasm: "vse16.v", "%vs3, (%rs1), %vmask",
       semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 2)";
-    vse32{: vs3, rs1, const1, vmask : },
+    vse32{: vs3, rs1, const4, vmask : },
       disasm: "vse32.v", "%vs3, (%rs1), %vmask",
       semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 4)";
-    vse64{: vs3, rs1, const1, vmask : },
+    vse64{: vs3, rs1, const8, vmask : },
       disasm: "vse64.v", "%vs3, (%rs1), %vmask",
       semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 8)";
 
@@ -295,13 +295,13 @@
     vse8ff{: vs3, rs1, const1, vmask:},
       disasm: "vse8ff.v", "%vs3, (%rs1), %vmask",
       semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 1)";
-    vse16ff{: vs3, rs1, const1, vmask:},
+    vse16ff{: vs3, rs1, const2, vmask:},
       disasm: "vse16ff.v", "%vs3, (%rs1), %vmask",
       semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 2)";
-    vse32ff{: vs3, rs1, const1, vmask:},
+    vse32ff{: vs3, rs1, const4, vmask:},
       disasm: "vse32ff.v", "%vs3, (%rs1), %vmask",
       semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 4)";
-    vse64ff{: vs3, rs1, const1, vmask:},
+    vse64ff{: vs3, rs1, const8, vmask:},
       disasm: "vse64ff.v", "%vs3, (%rs1), %vmask",
       semfunc: "absl::bind_front(&VsStrided, /*element_width*/ 8)";
 
diff --git a/cheriot/riscv_cheriot_vector_memory_instructions.cc b/cheriot/riscv_cheriot_vector_memory_instructions.cc
index 66f3b41..f56bab5 100644
--- a/cheriot/riscv_cheriot_vector_memory_instructions.cc
+++ b/cheriot/riscv_cheriot_vector_memory_instructions.cc
@@ -256,77 +256,6 @@
 // zero, or negative.
 
 // Source(0): base address.
-// Source(1): vector mask register, vector constant {1..} if not masked.
-// Destination(0): vector destination register.
-void VlUnitStrided(int element_width, const Instruction *inst) {
-  auto *state = static_cast<CheriotState *>(inst->state());
-  auto *rv_vector = state->rv_vector();
-  int start = rv_vector->vstart();
-  auto cap_reg = GetCapSource(inst, 0);
-  if (!CheckCapForMemoryAccess(inst, cap_reg, state)) return;
-  uint64_t base = cap_reg->address();
-  int emul = element_width * rv_vector->vector_length_multiplier() /
-             rv_vector->selected_element_width();
-  if ((emul > 64) || (emul == 0)) {
-    // TODO: signal vector error.
-    LOG(WARNING) << "EMUL (" << emul << ") out of range";
-    return;
-  }
-
-  // Compute total number of elements to be loaded.
-  int num_elements = rv_vector->vector_length();
-  int num_elements_loaded = num_elements - start;
-
-  // Allocate address data buffer.
-  auto *db_factory = inst->state()->db_factory();
-  auto *address_db = db_factory->Allocate<uint64_t>(num_elements_loaded);
-
-  // Allocate the value data buffer that the loaded data is returned in.
-  auto *value_db = db_factory->Allocate(num_elements_loaded * element_width);
-
-  // Get the source mask (stored in a single vector register).
-  auto *src_mask_op = static_cast<RV32VectorSourceOperand *>(inst->Source(1));
-  auto src_masks = src_mask_op->GetRegister(0)->data_buffer()->Get<uint8_t>();
-
-  // Allocate a byte mask data buffer for the load.
-  auto *mask_db = db_factory->Allocate<bool>(num_elements_loaded);
-
-  // Get the spans for addresses and masks.
-  auto addresses = address_db->Get<uint64_t>();
-  auto masks = mask_db->Get<bool>();
-
-  // The vector mask in the vector register is a bit mask. The mask used in
-  // the LoadMemory call is a bool mask so convert the bit masks to bool masks
-  // and compute the element addresses.
-  for (int i = start; i < num_elements; i++) {
-    int index = i >> 3;
-    int offset = i & 0b111;
-    addresses[i - start] = base + i * element_width;
-    masks[i - start] = ((src_masks[index] >> offset) & 0b1) != 0;
-    if (masks[i - start]) {
-      if (!CheckCapBounds(inst, addresses[i - start], element_width, cap_reg,
-                          state)) {
-        address_db->DecRef();
-        mask_db->DecRef();
-        value_db->DecRef();
-        return;
-      }
-    }
-  }
-
-  // Set up the context, and submit the load.
-  auto *context = new VectorLoadContext(value_db, mask_db, element_width, start,
-                                        rv_vector->vector_length());
-  value_db->set_latency(0);
-  state->LoadMemory(inst, address_db, mask_db, element_width, value_db,
-                    inst->child(), context);
-  // Release the context and address_db. The others will be released elsewhere.
-  context->DecRef();
-  address_db->DecRef();
-  rv_vector->clear_vstart();
-}
-
-// Source(0): base address.
 // Source(1): stride size bytes.
 // Source(2): vector mask register, vector constant {1..} if not masked.
 // Destination(0): vector destination register.
diff --git a/cheriot/test/riscv_cheriot_vector_memory_instructions_test.cc b/cheriot/test/riscv_cheriot_vector_memory_instructions_test.cc
index 17770bb..78b7f03 100644
--- a/cheriot/test/riscv_cheriot_vector_memory_instructions_test.cc
+++ b/cheriot/test/riscv_cheriot_vector_memory_instructions_test.cc
@@ -68,7 +68,6 @@
 using ::mpact::sim::cheriot::VlSegmentIndexed;
 using ::mpact::sim::cheriot::VlSegmentStrided;
 using ::mpact::sim::cheriot::VlStrided;
-using ::mpact::sim::cheriot::VlUnitStrided;
 using ::mpact::sim::cheriot::Vsetvl;
 using ::mpact::sim::cheriot::VsIndexed;
 using ::mpact::sim::cheriot::Vsm;
@@ -314,16 +313,17 @@
   template <typename T>
   void VectorLoadUnitStridedHelper() {
     // Set up instructions.
-    AppendRegisterOperands({kRs1Name}, {});
+    AppendRegisterOperands({kRs1Name, kRs2Name}, {});
     AppendVectorRegisterOperands({kVmask}, {});
-    SetSemanticFunction(absl::bind_front(&VlUnitStrided,
+    SetSemanticFunction(absl::bind_front(&VlStrided,
                                          /*element_width*/ sizeof(T)));
     // Add the child instruction that performs the register write-back.
     SetChildInstruction();
     SetChildSemanticFunction(&VlChild);
     AppendVectorRegisterOperands(child_instruction_, {}, {kVd});
     // Set up register values.
-    SetRegisterValues<uint32_t>({{kRs1Name, kDataLoadAddress}});
+    SetRegisterValues<uint32_t>(
+        {{kRs1Name, kDataLoadAddress}, {kRs2Name, sizeof(T)}});
     SetVectorRegisterValues<uint8_t>(
         {{kVmaskName, Span<const uint8_t>(kA5Mask)}});
     // Iterate over different lmul values.