Fix vlm and vsm to respect vl and vstart This change ensures that the vector load-mask (vlm.v) and vector store-mask (vsm.v) instructions correctly adhere to the vector length (vl) and start element (vstart) registers. Previously, these instructions used the full byte length of a vector register, regardless of the vl setting. This could lead to incorrect data being loaded or stored. Now matches Spike vlm.v vsm.v behavior. PiperOrigin-RevId: 881974860 Change-Id: Ia0fee4ae809ec95b56f837ffc42d39c148a7e2a0
diff --git a/riscv/riscv_vector_memory_instructions.cc b/riscv/riscv_vector_memory_instructions.cc index 5144796..6dacaef 100644 --- a/riscv/riscv_vector_memory_instructions.cc +++ b/riscv/riscv_vector_memory_instructions.cc
@@ -279,16 +279,27 @@ auto* rv_vector = static_cast<RiscVState*>(inst->state())->rv_vector(); // Compute base address. int start = rv_vector->vstart(); - uint64_t base = GetInstructionSource<uint64_t>(inst, 0) + start; - // Compute the number of bytes to be loaded. - int num_bytes = rv_vector->vector_register_byte_length() - start; + uint64_t base = GetInstructionSource<uint64_t>(inst, 0); + // According to Spec Section 7.4, vlm.v/vsm.v are unique: + // 1. They transfer ceil(vl/8) bytes. + // 2. vstart is interpreted in units of bytes (not elements). + int num_bytes = (rv_vector->vector_length() + 7) / 8; + int num_bytes_to_load = num_bytes - start; + if (start >= rv_vector->vector_register_byte_length()) { + rv_vector->clear_vstart(); + return; + } + if (num_bytes_to_load <= 0) { + rv_vector->clear_vstart(); + return; + } // Allocate address data buffer. auto* db_factory = inst->state()->db_factory(); - auto* address_db = db_factory->Allocate<uint64_t>(num_bytes); + auto* address_db = db_factory->Allocate<uint64_t>(num_bytes_to_load); // Allocate the value data buffer that the loaded data is returned in. - auto* value_db = db_factory->Allocate<uint8_t>(num_bytes); + auto* value_db = db_factory->Allocate<uint8_t>(num_bytes_to_load); // Allocate a byte mask data buffer. - auto* mask_db = db_factory->Allocate<bool>(num_bytes); + auto* mask_db = db_factory->Allocate<bool>(num_bytes_to_load); // Get the spans for addresses and masks. auto masks = mask_db->Get<bool>(); auto addresses = address_db->Get<uint64_t>(); @@ -298,9 +309,8 @@ masks[i - start] = true; } // Set up the context, and submit the load. - auto* context = - new VectorLoadContext(value_db, mask_db, sizeof(uint8_t), start, - rv_vector->vector_register_byte_length()); + auto* context = new VectorLoadContext(value_db, mask_db, sizeof(uint8_t), + start, num_bytes); auto* rv32_state = static_cast<RiscVState*>(inst->state()); value_db->set_latency(0); rv32_state->LoadMemory(inst, address_db, mask_db, sizeof(uint8_t), value_db, @@ -803,14 +813,24 @@ // Compute base address. int start = rv_vector->vstart(); uint64_t base = GetInstructionSource<uint64_t>(inst, 1); - // Compute the number of bytes and elements to be stored. - int num_bytes = rv_vector->vector_register_byte_length(); - int num_bytes_stored = num_bytes - start; + // According to Spec Section 7.4, vlm.v/vsm.v are unique: + // 1. They transfer ceil(vl/8) bytes. + // 2. vstart is interpreted in units of bytes (not elements). + int num_bytes = (rv_vector->vector_length() + 7) / 8; + int num_bytes_to_store = num_bytes - start; + if (start >= rv_vector->vector_register_byte_length()) { + rv_vector->clear_vstart(); + return; + } + if (num_bytes_to_store <= 0) { + rv_vector->clear_vstart(); + return; + } // Allocate address data buffer. auto* db_factory = inst->state()->db_factory(); - auto* address_db = db_factory->Allocate<uint64_t>(num_bytes_stored); - auto* store_data_db = db_factory->Allocate(num_bytes_stored); - auto* mask_db = db_factory->Allocate<uint8_t>(num_bytes_stored); + auto* address_db = db_factory->Allocate<uint64_t>(num_bytes_to_store); + auto* store_data_db = db_factory->Allocate(num_bytes_to_store); + auto* mask_db = db_factory->Allocate<bool>(num_bytes_to_store); // Get the spans for addresses, masks, and store data. auto addresses = address_db->Get<uint64_t>(); auto masks = mask_db->Get<bool>();
diff --git a/riscv/test/riscv_vector_memory_instructions_test.cc b/riscv/test/riscv_vector_memory_instructions_test.cc index 180d39c..c906d3c 100644 --- a/riscv/test/riscv_vector_memory_instructions_test.cc +++ b/riscv/test/riscv_vector_memory_instructions_test.cc
@@ -289,10 +289,12 @@ child_instruction_->set_semantic_function(fcn); } - // Configure the vector unit according to the vtype and vlen values. - void ConfigureVectorUnit(uint32_t vtype, uint32_t vlen) { + // Configure the vector unit according to the vtype and avl values. + // Note: 'avl' (Application Vector Length) is the requested 'vl'. + // This is distinct from the hardware register width (VLEN). + void ConfigureVectorUnit(uint32_t vtype, uint32_t avl) { Instruction* inst = new Instruction(state_); - AppendImmediateOperands<uint32_t>(inst, {vlen, vtype}); + AppendImmediateOperands<uint32_t>(inst, {avl, vtype}); SetSemanticFunction(inst, absl::bind_front(&Vsetvl, true, false)); inst->Execute(nullptr); inst->DecRef(); @@ -1546,18 +1548,100 @@ // Test of vector load mask. TEST_F(RV32VInstructionsTest, Vlm) { // Set up operands and register values. + uint32_t avl = kVectorLengthInBytes * 8; + ConfigureVectorUnit(0, avl); + int vl = rv_vector_->vector_length(); + int num_bytes = (vl + 7) / 8; + AppendRegisterOperands({kRs1Name}, {}); SetSemanticFunction(&Vlm); SetChildInstruction(); AppendVectorRegisterOperands(child_instruction_, {}, {kVd}); SetChildSemanticFunction(&VlChild); SetRegisterValues<uint32_t>({{kRs1Name, kDataLoadAddress}}); + + // Initialize destination register with 0xff. + for (int i = 0; i < kVectorLengthInBytes; i++) { + vreg_[kVd]->data_buffer()->Set<uint8_t>(i, 0xff); + } + // Execute instruction. instruction_->Execute(nullptr); EXPECT_FALSE(rv_vector_->vector_exception()); auto span = vreg_[kVd]->data_buffer()->Get<uint8_t>(); for (int i = 0; i < kVectorLengthInBytes; i++) { - EXPECT_EQ(i & 0xff, span[i]) << "element: " << i; + if (i < num_bytes) { + EXPECT_EQ(i & 0xff, span[i]) << "element: " << i; + } else { + EXPECT_EQ(0xff, span[i]) << "element: " << i; + } + } +} + +// Test of vector load mask with small vl. +TEST_F(RV32VInstructionsTest, Vlm_RespectsVl) { + // We use vl=11 specifically to test the ceil(vl/8) logic. + // ceil(11/8) = 2 bytes. If the implementation used floor or + // standard integer division, it would incorrectly result in 1 byte. + ConfigureVectorUnit(0, /*avl=*/11); + ASSERT_EQ(rv_vector_->vector_length(), 11); + + AppendRegisterOperands({kRs1Name}, {}); + SetSemanticFunction(&Vlm); + SetChildInstruction(); + AppendVectorRegisterOperands(child_instruction_, {}, {kVd}); + SetChildSemanticFunction(&VlChild); + SetRegisterValues<uint32_t>({{kRs1Name, kDataLoadAddress}}); + + // Initialize destination register with 0xff. + for (int i = 0; i < kVectorLengthInBytes; i++) { + vreg_[kVd]->data_buffer()->Set<uint8_t>(i, 0xff); + } + + // Execute instruction. + instruction_->Execute(nullptr); + EXPECT_FALSE(rv_vector_->vector_exception()); + + auto span = vreg_[kVd]->data_buffer()->Get<uint8_t>(); + // vl=11, ceil(11/8) = 2 bytes. + EXPECT_EQ(span[0], 0); + EXPECT_EQ(span[1], 1); + // Byte 2 should remain 0xff. + EXPECT_EQ(span[2], 0xff); +} + +// Test of vector load mask with vstart >= vl/8. +TEST_F(RV32VInstructionsTest, Vlm_RespectsVstart) { + // vl=11, which means ceil(11/8) = 2 bytes are processed. + // Set vstart=2, which is the first byte beyond the vl reach. + // For vlm.v/vsm.v, vstart is interpreted in units of bytes (Spec + // Section 7.4). + ConfigureVectorUnit(0, /*avl=*/11); + ASSERT_EQ(rv_vector_->vector_length(), 11); + rv_vector_->set_vstart(2); + + AppendRegisterOperands({kRs1Name}, {}); + SetSemanticFunction(&Vlm); + SetChildInstruction(); + AppendVectorRegisterOperands(child_instruction_, {}, {kVd}); + SetChildSemanticFunction(&VlChild); + SetRegisterValues<uint32_t>({{kRs1Name, kDataLoadAddress}}); + + // Initialize destination register with 0xff. + for (int i = 0; i < kVectorLengthInBytes; i++) { + vreg_[kVd]->data_buffer()->Set<uint8_t>(i, 0xff); + } + + // Execute instruction. + instruction_->Execute(nullptr); + EXPECT_FALSE(rv_vector_->vector_exception()); + // vstart should be cleared. + EXPECT_EQ(rv_vector_->vstart(), 0); + + auto span = vreg_[kVd]->data_buffer()->Get<uint8_t>(); + // No bytes should be loaded. + for (int i = 0; i < kVectorLengthInBytes; i++) { + EXPECT_EQ(0xff, span[i]) << "element: " << i; } } @@ -1755,7 +1839,9 @@ TEST_F(RV32VInstructionsTest, Vsse64) { VectorStoreStridedHelper<uint64_t>(); } TEST_F(RV32VInstructionsTest, Vsm) { - ConfigureVectorUnit(0b0'0'000'000, /*vlen*/ 1024); + ConfigureVectorUnit(0b0'0'000'000, /*avl*/ 1024); + int vl = rv_vector_->vector_length(); + int num_bytes = (vl + 7) / 8; // Set up operands and register values. AppendVectorRegisterOperands({kVs1}, {}); AppendRegisterOperands({kRs1Name}, {}); @@ -1764,6 +1850,13 @@ for (int i = 0; i < kVectorLengthInBytes; i++) { vreg_[kVs1]->data_buffer()->Set<uint8_t>(i, i); } + + // Zero out memory. + auto* zero_db = state_->db_factory()->Allocate<uint8_t>(kVectorLengthInBytes); + std::memset(zero_db->raw_ptr(), 0, kVectorLengthInBytes); + state_->StoreMemory(instruction_, kDataStoreAddress, zero_db); + zero_db->DecRef(); + // Execute instruction. instruction_->Execute(nullptr); @@ -1774,7 +1867,100 @@ nullptr); auto span = data_db->Get<uint8_t>(); for (int i = 0; i < kVectorLengthInBytes; i++) { - EXPECT_EQ(static_cast<int>(span[i]), i); + if (i < num_bytes) { + EXPECT_EQ(static_cast<int>(span[i]), i); + } else { + EXPECT_EQ(static_cast<int>(span[i]), 0); + } + } + data_db->DecRef(); +} + +// Test of vector store mask with small vl. +TEST_F(RV32VInstructionsTest, Vsm_RespectsVl) { + // We use vl=11 specifically to test the ceil(vl/8) logic. + // ceil(11/8) = 2 bytes. If the implementation used floor or + // standard integer division, it would incorrectly result in 1 byte. + ConfigureVectorUnit(0, /*avl=*/11); + ASSERT_EQ(rv_vector_->vector_length(), 11); + + // Set up operands and register values. + AppendVectorRegisterOperands({kVs1}, {}); + AppendRegisterOperands({kRs1Name}, {}); + SetSemanticFunction(&Vsm); + SetRegisterValues<uint32_t>({{kRs1Name, kDataStoreAddress}}); + + // Initialize v1 with 0x55, 0xAA, 0x33, 0xCC, ... + for (int i = 0; i < kVectorLengthInBytes; i++) { + vreg_[kVs1]->data_buffer()->Set<uint8_t>(i, (i % 2 == 0) ? 0x55 : 0xaa); + } + + // Zero out memory. + auto* zero_db = state_->db_factory()->Allocate<uint8_t>(kVectorLengthInBytes); + std::memset(zero_db->raw_ptr(), 0, kVectorLengthInBytes); + state_->StoreMemory(instruction_, kDataStoreAddress, zero_db); + zero_db->DecRef(); + + // Execute instruction. + instruction_->Execute(nullptr); + + // Verify result. + EXPECT_FALSE(rv_vector_->vector_exception()); + auto* data_db = state_->db_factory()->Allocate<uint8_t>(kVectorLengthInBytes); + state_->LoadMemory(instruction_, kDataStoreAddress, data_db, nullptr, + nullptr); + auto span = data_db->Get<uint8_t>(); + // vl=11, ceil(11/8) = 2 bytes. + EXPECT_EQ(span[0], 0x55); + EXPECT_EQ(span[1], 0xaa); + // Byte 2 should remain 0. + EXPECT_EQ(span[2], 0); + + data_db->DecRef(); +} + +// Test of vector store mask with vstart >= vl/8. +TEST_F(RV32VInstructionsTest, Vsm_RespectsVstart) { + // vl=11, which means ceil(11/8) = 2 bytes are processed. + // Set vstart=2, which is the first byte beyond the vl reach. + // For vlm.v/vsm.v, vstart is interpreted in units of bytes (Spec + // Section 7.4). + ConfigureVectorUnit(0, /*avl=*/11); + ASSERT_EQ(rv_vector_->vector_length(), 11); + rv_vector_->set_vstart(2); + + // Set up operands and register values. + AppendVectorRegisterOperands({kVs1}, {}); + AppendRegisterOperands({kRs1Name}, {}); + SetSemanticFunction(&Vsm); + SetRegisterValues<uint32_t>({{kRs1Name, kDataStoreAddress}}); + + // Initialize v1 with 0x55, 0xAA, ... + for (int i = 0; i < kVectorLengthInBytes; i++) { + vreg_[kVs1]->data_buffer()->Set<uint8_t>(i, (i % 2 == 0) ? 0x55 : 0xaa); + } + + // Zero out memory. + auto* zero_db = state_->db_factory()->Allocate<uint8_t>(kVectorLengthInBytes); + std::memset(zero_db->raw_ptr(), 0, kVectorLengthInBytes); + state_->StoreMemory(instruction_, kDataStoreAddress, zero_db); + zero_db->DecRef(); + + // Execute instruction. + instruction_->Execute(nullptr); + + // Verify result. + EXPECT_FALSE(rv_vector_->vector_exception()); + // vstart should be cleared. + EXPECT_EQ(rv_vector_->vstart(), 0); + + auto* data_db = state_->db_factory()->Allocate<uint8_t>(kVectorLengthInBytes); + state_->LoadMemory(instruction_, kDataStoreAddress, data_db, nullptr, + nullptr); + auto span = data_db->Get<uint8_t>(); + // Memory should remain zero. + for (int i = 0; i < kVectorLengthInBytes; i++) { + EXPECT_EQ(0, span[i]) << "element: " << i; } data_db->DecRef(); }