Fix vlm and vsm to respect vl and vstart

This change ensures that the vector load-mask (vlm.v) and vector store-mask (vsm.v) instructions honor the vector length (vl) and vector start (vstart) registers: each now transfers ceil(vl/8) bytes and interprets vstart as a byte offset.

Previously, these instructions always transferred the full byte length of a vector register (minus vstart), regardless of the vl setting. As a result, bytes beyond ceil(vl/8) were loaded or stored when they should have been left untouched.

The behavior now matches Spike's implementation of vlm.v and vsm.v.

PiperOrigin-RevId: 881974860
Change-Id: Ia0fee4ae809ec95b56f837ffc42d39c148a7e2a0
diff --git a/riscv/riscv_vector_memory_instructions.cc b/riscv/riscv_vector_memory_instructions.cc
index 5144796..6dacaef 100644
--- a/riscv/riscv_vector_memory_instructions.cc
+++ b/riscv/riscv_vector_memory_instructions.cc
@@ -279,16 +279,27 @@
   auto* rv_vector = static_cast<RiscVState*>(inst->state())->rv_vector();
   // Compute base address.
   int start = rv_vector->vstart();
-  uint64_t base = GetInstructionSource<uint64_t>(inst, 0) + start;
-  // Compute the number of bytes to be loaded.
-  int num_bytes = rv_vector->vector_register_byte_length() - start;
+  uint64_t base = GetInstructionSource<uint64_t>(inst, 0);
+  // According to Spec Section 7.4, vlm.v/vsm.v are unique:
+  // 1. They transfer ceil(vl/8) bytes.
+  // 2. vstart is interpreted in units of bytes (not elements).
+  int num_bytes = (rv_vector->vector_length() + 7) / 8;
+  int num_bytes_to_load = num_bytes - start;
+  if (start >= rv_vector->vector_register_byte_length()) {
+    rv_vector->clear_vstart();
+    return;
+  }
+  if (num_bytes_to_load <= 0) {
+    rv_vector->clear_vstart();
+    return;
+  }
   // Allocate address data buffer.
   auto* db_factory = inst->state()->db_factory();
-  auto* address_db = db_factory->Allocate<uint64_t>(num_bytes);
+  auto* address_db = db_factory->Allocate<uint64_t>(num_bytes_to_load);
   // Allocate the value data buffer that the loaded data is returned in.
-  auto* value_db = db_factory->Allocate<uint8_t>(num_bytes);
+  auto* value_db = db_factory->Allocate<uint8_t>(num_bytes_to_load);
   // Allocate a byte mask data buffer.
-  auto* mask_db = db_factory->Allocate<bool>(num_bytes);
+  auto* mask_db = db_factory->Allocate<bool>(num_bytes_to_load);
   // Get the spans for addresses and masks.
   auto masks = mask_db->Get<bool>();
   auto addresses = address_db->Get<uint64_t>();
@@ -298,9 +309,8 @@
     masks[i - start] = true;
   }
   // Set up the context, and submit the load.
-  auto* context =
-      new VectorLoadContext(value_db, mask_db, sizeof(uint8_t), start,
-                            rv_vector->vector_register_byte_length());
+  auto* context = new VectorLoadContext(value_db, mask_db, sizeof(uint8_t),
+                                        start, num_bytes);
   auto* rv32_state = static_cast<RiscVState*>(inst->state());
   value_db->set_latency(0);
   rv32_state->LoadMemory(inst, address_db, mask_db, sizeof(uint8_t), value_db,
@@ -803,14 +813,24 @@
   // Compute base address.
   int start = rv_vector->vstart();
   uint64_t base = GetInstructionSource<uint64_t>(inst, 1);
-  // Compute the number of bytes and elements to be stored.
-  int num_bytes = rv_vector->vector_register_byte_length();
-  int num_bytes_stored = num_bytes - start;
+  // According to Spec Section 7.4, vlm.v/vsm.v are unique:
+  // 1. They transfer ceil(vl/8) bytes.
+  // 2. vstart is interpreted in units of bytes (not elements).
+  int num_bytes = (rv_vector->vector_length() + 7) / 8;
+  int num_bytes_to_store = num_bytes - start;
+  if (start >= rv_vector->vector_register_byte_length()) {
+    rv_vector->clear_vstart();
+    return;
+  }
+  if (num_bytes_to_store <= 0) {
+    rv_vector->clear_vstart();
+    return;
+  }
   // Allocate address data buffer.
   auto* db_factory = inst->state()->db_factory();
-  auto* address_db = db_factory->Allocate<uint64_t>(num_bytes_stored);
-  auto* store_data_db = db_factory->Allocate(num_bytes_stored);
-  auto* mask_db = db_factory->Allocate<uint8_t>(num_bytes_stored);
+  auto* address_db = db_factory->Allocate<uint64_t>(num_bytes_to_store);
+  auto* store_data_db = db_factory->Allocate(num_bytes_to_store);
+  auto* mask_db = db_factory->Allocate<bool>(num_bytes_to_store);
   // Get the spans for addresses, masks, and store data.
   auto addresses = address_db->Get<uint64_t>();
   auto masks = mask_db->Get<bool>();
diff --git a/riscv/test/riscv_vector_memory_instructions_test.cc b/riscv/test/riscv_vector_memory_instructions_test.cc
index 180d39c..c906d3c 100644
--- a/riscv/test/riscv_vector_memory_instructions_test.cc
+++ b/riscv/test/riscv_vector_memory_instructions_test.cc
@@ -289,10 +289,12 @@
     child_instruction_->set_semantic_function(fcn);
   }
 
-  // Configure the vector unit according to the vtype and vlen values.
-  void ConfigureVectorUnit(uint32_t vtype, uint32_t vlen) {
+  // Configure the vector unit according to the vtype and avl values.
+  // Note: 'avl' (Application Vector Length) is the requested 'vl'.
+  // This is distinct from the hardware register width (VLEN).
+  void ConfigureVectorUnit(uint32_t vtype, uint32_t avl) {
     Instruction* inst = new Instruction(state_);
-    AppendImmediateOperands<uint32_t>(inst, {vlen, vtype});
+    AppendImmediateOperands<uint32_t>(inst, {avl, vtype});
     SetSemanticFunction(inst, absl::bind_front(&Vsetvl, true, false));
     inst->Execute(nullptr);
     inst->DecRef();
@@ -1546,18 +1548,100 @@
 // Test of vector load mask.
 TEST_F(RV32VInstructionsTest, Vlm) {
   // Set up operands and register values.
+  uint32_t avl = kVectorLengthInBytes * 8;
+  ConfigureVectorUnit(0, avl);
+  int vl = rv_vector_->vector_length();
+  int num_bytes = (vl + 7) / 8;
+
   AppendRegisterOperands({kRs1Name}, {});
   SetSemanticFunction(&Vlm);
   SetChildInstruction();
   AppendVectorRegisterOperands(child_instruction_, {}, {kVd});
   SetChildSemanticFunction(&VlChild);
   SetRegisterValues<uint32_t>({{kRs1Name, kDataLoadAddress}});
+
+  // Initialize destination register with 0xff.
+  for (int i = 0; i < kVectorLengthInBytes; i++) {
+    vreg_[kVd]->data_buffer()->Set<uint8_t>(i, 0xff);
+  }
+
   // Execute instruction.
   instruction_->Execute(nullptr);
   EXPECT_FALSE(rv_vector_->vector_exception());
   auto span = vreg_[kVd]->data_buffer()->Get<uint8_t>();
   for (int i = 0; i < kVectorLengthInBytes; i++) {
-    EXPECT_EQ(i & 0xff, span[i]) << "element: " << i;
+    if (i < num_bytes) {
+      EXPECT_EQ(i & 0xff, span[i]) << "element: " << i;
+    } else {
+      EXPECT_EQ(0xff, span[i]) << "element: " << i;
+    }
+  }
+}
+
+// Test of vector load mask with small vl.
+TEST_F(RV32VInstructionsTest, Vlm_RespectsVl) {
+  // We use vl=11 specifically to test the ceil(vl/8) logic.
+  // ceil(11/8) = 2 bytes. If the implementation used floor or
+  // standard integer division, it would incorrectly result in 1 byte.
+  ConfigureVectorUnit(0, /*avl=*/11);
+  ASSERT_EQ(rv_vector_->vector_length(), 11);
+
+  AppendRegisterOperands({kRs1Name}, {});
+  SetSemanticFunction(&Vlm);
+  SetChildInstruction();
+  AppendVectorRegisterOperands(child_instruction_, {}, {kVd});
+  SetChildSemanticFunction(&VlChild);
+  SetRegisterValues<uint32_t>({{kRs1Name, kDataLoadAddress}});
+
+  // Initialize destination register with 0xff.
+  for (int i = 0; i < kVectorLengthInBytes; i++) {
+    vreg_[kVd]->data_buffer()->Set<uint8_t>(i, 0xff);
+  }
+
+  // Execute instruction.
+  instruction_->Execute(nullptr);
+  EXPECT_FALSE(rv_vector_->vector_exception());
+
+  auto span = vreg_[kVd]->data_buffer()->Get<uint8_t>();
+  // vl=11, ceil(11/8) = 2 bytes.
+  EXPECT_EQ(span[0], 0);
+  EXPECT_EQ(span[1], 1);
+  // Byte 2 should remain 0xff.
+  EXPECT_EQ(span[2], 0xff);
+}
+
+// Test of vector load mask with vstart >= ceil(vl/8).
+TEST_F(RV32VInstructionsTest, Vlm_RespectsVstart) {
+  // vl=11, which means ceil(11/8) = 2 bytes are processed.
+  // Set vstart=2, which is the first byte beyond the vl reach.
+  // For vlm.v/vsm.v, vstart is interpreted in units of bytes (Spec
+  // Section 7.4).
+  ConfigureVectorUnit(0, /*avl=*/11);
+  ASSERT_EQ(rv_vector_->vector_length(), 11);
+  rv_vector_->set_vstart(2);
+
+  AppendRegisterOperands({kRs1Name}, {});
+  SetSemanticFunction(&Vlm);
+  SetChildInstruction();
+  AppendVectorRegisterOperands(child_instruction_, {}, {kVd});
+  SetChildSemanticFunction(&VlChild);
+  SetRegisterValues<uint32_t>({{kRs1Name, kDataLoadAddress}});
+
+  // Initialize destination register with 0xff.
+  for (int i = 0; i < kVectorLengthInBytes; i++) {
+    vreg_[kVd]->data_buffer()->Set<uint8_t>(i, 0xff);
+  }
+
+  // Execute instruction.
+  instruction_->Execute(nullptr);
+  EXPECT_FALSE(rv_vector_->vector_exception());
+  // vstart should be cleared.
+  EXPECT_EQ(rv_vector_->vstart(), 0);
+
+  auto span = vreg_[kVd]->data_buffer()->Get<uint8_t>();
+  // No bytes should be loaded.
+  for (int i = 0; i < kVectorLengthInBytes; i++) {
+    EXPECT_EQ(0xff, span[i]) << "element: " << i;
   }
 }
 
@@ -1755,7 +1839,9 @@
 TEST_F(RV32VInstructionsTest, Vsse64) { VectorStoreStridedHelper<uint64_t>(); }
 
 TEST_F(RV32VInstructionsTest, Vsm) {
-  ConfigureVectorUnit(0b0'0'000'000, /*vlen*/ 1024);
+  ConfigureVectorUnit(0b0'0'000'000, /*avl*/ 1024);
+  int vl = rv_vector_->vector_length();
+  int num_bytes = (vl + 7) / 8;
   // Set up operands and register values.
   AppendVectorRegisterOperands({kVs1}, {});
   AppendRegisterOperands({kRs1Name}, {});
@@ -1764,6 +1850,13 @@
   for (int i = 0; i < kVectorLengthInBytes; i++) {
     vreg_[kVs1]->data_buffer()->Set<uint8_t>(i, i);
   }
+
+  // Zero out memory.
+  auto* zero_db = state_->db_factory()->Allocate<uint8_t>(kVectorLengthInBytes);
+  std::memset(zero_db->raw_ptr(), 0, kVectorLengthInBytes);
+  state_->StoreMemory(instruction_, kDataStoreAddress, zero_db);
+  zero_db->DecRef();
+
   // Execute instruction.
   instruction_->Execute(nullptr);
 
@@ -1774,7 +1867,100 @@
                      nullptr);
   auto span = data_db->Get<uint8_t>();
   for (int i = 0; i < kVectorLengthInBytes; i++) {
-    EXPECT_EQ(static_cast<int>(span[i]), i);
+    if (i < num_bytes) {
+      EXPECT_EQ(static_cast<int>(span[i]), i);
+    } else {
+      EXPECT_EQ(static_cast<int>(span[i]), 0);
+    }
+  }
+  data_db->DecRef();
+}
+
+// Test of vector store mask with small vl.
+TEST_F(RV32VInstructionsTest, Vsm_RespectsVl) {
+  // We use vl=11 specifically to test the ceil(vl/8) logic.
+  // ceil(11/8) = 2 bytes. If the implementation used floor or
+  // standard integer division, it would incorrectly result in 1 byte.
+  ConfigureVectorUnit(0, /*avl=*/11);
+  ASSERT_EQ(rv_vector_->vector_length(), 11);
+
+  // Set up operands and register values.
+  AppendVectorRegisterOperands({kVs1}, {});
+  AppendRegisterOperands({kRs1Name}, {});
+  SetSemanticFunction(&Vsm);
+  SetRegisterValues<uint32_t>({{kRs1Name, kDataStoreAddress}});
+
+  // Initialize vs1 with alternating bytes: 0x55, 0xaa, 0x55, 0xaa, ...
+  for (int i = 0; i < kVectorLengthInBytes; i++) {
+    vreg_[kVs1]->data_buffer()->Set<uint8_t>(i, (i % 2 == 0) ? 0x55 : 0xaa);
+  }
+
+  // Zero out memory.
+  auto* zero_db = state_->db_factory()->Allocate<uint8_t>(kVectorLengthInBytes);
+  std::memset(zero_db->raw_ptr(), 0, kVectorLengthInBytes);
+  state_->StoreMemory(instruction_, kDataStoreAddress, zero_db);
+  zero_db->DecRef();
+
+  // Execute instruction.
+  instruction_->Execute(nullptr);
+
+  // Verify result.
+  EXPECT_FALSE(rv_vector_->vector_exception());
+  auto* data_db = state_->db_factory()->Allocate<uint8_t>(kVectorLengthInBytes);
+  state_->LoadMemory(instruction_, kDataStoreAddress, data_db, nullptr,
+                     nullptr);
+  auto span = data_db->Get<uint8_t>();
+  // vl=11, ceil(11/8) = 2 bytes.
+  EXPECT_EQ(span[0], 0x55);
+  EXPECT_EQ(span[1], 0xaa);
+  // Byte 2 should remain 0.
+  EXPECT_EQ(span[2], 0);
+
+  data_db->DecRef();
+}
+
+// Test of vector store mask with vstart >= ceil(vl/8).
+TEST_F(RV32VInstructionsTest, Vsm_RespectsVstart) {
+  // vl=11, which means ceil(11/8) = 2 bytes are processed.
+  // Set vstart=2, which is the first byte beyond the vl reach.
+  // For vlm.v/vsm.v, vstart is interpreted in units of bytes (Spec
+  // Section 7.4).
+  ConfigureVectorUnit(0, /*avl=*/11);
+  ASSERT_EQ(rv_vector_->vector_length(), 11);
+  rv_vector_->set_vstart(2);
+
+  // Set up operands and register values.
+  AppendVectorRegisterOperands({kVs1}, {});
+  AppendRegisterOperands({kRs1Name}, {});
+  SetSemanticFunction(&Vsm);
+  SetRegisterValues<uint32_t>({{kRs1Name, kDataStoreAddress}});
+
+  // Initialize v1 with 0x55, 0xAA, ...
+  for (int i = 0; i < kVectorLengthInBytes; i++) {
+    vreg_[kVs1]->data_buffer()->Set<uint8_t>(i, (i % 2 == 0) ? 0x55 : 0xaa);
+  }
+
+  // Zero out memory.
+  auto* zero_db = state_->db_factory()->Allocate<uint8_t>(kVectorLengthInBytes);
+  std::memset(zero_db->raw_ptr(), 0, kVectorLengthInBytes);
+  state_->StoreMemory(instruction_, kDataStoreAddress, zero_db);
+  zero_db->DecRef();
+
+  // Execute instruction.
+  instruction_->Execute(nullptr);
+
+  // Verify result.
+  EXPECT_FALSE(rv_vector_->vector_exception());
+  // vstart should be cleared.
+  EXPECT_EQ(rv_vector_->vstart(), 0);
+
+  auto* data_db = state_->db_factory()->Allocate<uint8_t>(kVectorLengthInBytes);
+  state_->LoadMemory(instruction_, kDataStoreAddress, data_db, nullptr,
+                     nullptr);
+  auto span = data_db->Get<uint8_t>();
+  // Memory should remain zero.
+  for (int i = 0; i < kVectorLengthInBytes; i++) {
+    EXPECT_EQ(0, span[i]) << "element: " << i;
   }
   data_db->DecRef();
 }