Account for branches when setting epc when an interrupt occurs

PiperOrigin-RevId: 819222976
Change-Id: I49ee849cc1f520e48cf11243777ac5512701a683
diff --git a/riscv/riscv_top.cc b/riscv/riscv_top.cc
index 186748a..8472281 100644
--- a/riscv/riscv_top.cc
+++ b/riscv/riscv_top.cc
@@ -250,6 +250,14 @@
   // Re-enable the breakpoint.
   (void)rv_action_point_manager_->ap_memory_interface()
       ->WriteBreakpointInstruction(bpt_pc);
+  // Check for interrupt.
+  if (state_->is_interrupt_available()) {
+    uint64_t epc = pc;
+    if (executed) {
+      epc = state_->branch() ? state_->pc_operand()->AsUint64(0) : next_pc;
+    }
+    state_->TakeAvailableInterrupt(epc);  // Will set state_->branch().
+  }
   if (state_->branch()) {
     state_->set_branch(false);
     auto new_pc = state_->pc_operand()->AsUint64(0);
@@ -308,8 +316,11 @@
       state_->AdvanceDelayLines();
       // Check for interrupt.
       if (state_->is_interrupt_available()) {
-        uint64_t epc = (executed ? next_pc : state_->pc_operand()->AsUint64(0));
-        state_->TakeAvailableInterrupt(epc);
+        uint64_t epc = pc;
+        if (executed) {
+          epc = state_->branch() ? state_->pc_operand()->AsUint64(0) : next_pc;
+        }
+        state_->TakeAvailableInterrupt(epc);  // Will set state_->branch().
       }
     } while (!executed);
     count++;
@@ -404,9 +415,12 @@
         state_->AdvanceDelayLines();
         // Check for interrupt.
         if (state_->is_interrupt_available()) {
-          uint64_t epc =
-              (executed ? next_pc : state_->pc_operand()->AsUint64(0));
-          state_->TakeAvailableInterrupt(epc);
+          uint64_t epc = pc;
+          if (executed) {
+            epc =
+                state_->branch() ? state_->pc_operand()->AsUint64(0) : next_pc;
+          }
+          state_->TakeAvailableInterrupt(epc);  // Will set state_->branch().
         }
       } while (!executed);
       // Update counters.