Adds capability to specify constraints between two fields/overlays.

In the binary decoder specification, you can now use constraints
between the value of two fields/overlays, not just a field or overlay,
and a constant.

PiperOrigin-RevId: 693798358
Change-Id: I64f38ddf7b4d3f6adb9813607940b09a5a7f6c3e
diff --git a/mpact/sim/decoder/BinFormat.g4 b/mpact/sim/decoder/BinFormat.g4
index 57872d4..912fe0a 100644
--- a/mpact/sim/decoder/BinFormat.g4
+++ b/mpact/sim/decoder/BinFormat.g4
@@ -169,7 +169,7 @@
 // be equal, not equal, greater/less, etc. to a number given by a numeric
 // literal.
 field_constraint
-  : field_name=IDENT constraint_op value=number
+  : field_name=IDENT constraint_op (value=number | rhs_field_name=IDENT)
   ;
 
 constraint_op
diff --git a/mpact/sim/decoder/bin_format_visitor.cc b/mpact/sim/decoder/bin_format_visitor.cc
index 04d944e..d278ca2 100644
--- a/mpact/sim/decoder/bin_format_visitor.cc
+++ b/mpact/sim/decoder/bin_format_visitor.cc
@@ -98,6 +98,15 @@
 
 )foo";
 
+BinFormatVisitor::BinFormatVisitor() {
+  constraint_string_to_type_.emplace("==", ConstraintType::kEq);
+  constraint_string_to_type_.emplace("!=", ConstraintType::kNe);
+  constraint_string_to_type_.emplace("<", ConstraintType::kLt);
+  constraint_string_to_type_.emplace("<=", ConstraintType::kLe);
+  constraint_string_to_type_.emplace(">", ConstraintType::kGt);
+  constraint_string_to_type_.emplace(">=", ConstraintType::kGe);
+}
+
 BinFormatVisitor::~BinFormatVisitor() {
   for (auto *wrapper : antlr_parser_wrappers_) {
     delete wrapper;
@@ -975,48 +984,43 @@
   // Constraints are based on field names ==/!=/>/>=/</<= to a value.
   std::string field_name = ctx->field_name->getText();
   std::string op = ctx->constraint_op()->getText();
-  // If the number is binary, let's get its length too and check against the
-  // field width.
-  if (ctx->number()->BIN_NUMBER() != nullptr) {
-    int length = ParseBinaryNum(ctx->number()->BIN_NUMBER()).width;
-    auto *field = format->GetField(field_name);
-    auto *overlay = format->GetOverlay(field_name);
-    if (field != nullptr) {
-      if (field->width != length) {
-        error_listener_->semanticWarning(
-            file_names_[context_file_map_.at(ctx)], ctx->start,
-            absl::StrCat("Field '", field_name, "' has width ", field->width,
-                         " but constraint value is ", length, " bits"));
-      }
-    } else if (overlay != nullptr) {
-      if (overlay->computed_width() != length) {
-        error_listener_->semanticWarning(
-            file_names_[context_file_map_.at(ctx)], ctx->start,
-            absl::StrCat("Overlay '", field_name, "' has width ",
-                         overlay->computed_width(), " but constraint value is ",
-                         length, " bits"));
+  absl::Status status;
+  ConstraintType constraint_type = constraint_string_to_type_.at(op);
+  if (ctx->rhs_field_name != nullptr) {
+    std::string rhs_name = ctx->rhs_field_name->getText();
+    status = inst_encoding->AddOtherConstraint(constraint_type, field_name,
+                                               rhs_name);
+  } else {
+    // If the number is binary, let's get its length too and check against the
+    // field width.
+    if (ctx->number()->BIN_NUMBER() != nullptr) {
+      int length = ParseBinaryNum(ctx->number()->BIN_NUMBER()).width;
+      auto *field = format->GetField(field_name);
+      auto *overlay = format->GetOverlay(field_name);
+      if (field != nullptr) {
+        if (field->width != length) {
+          error_listener_->semanticWarning(
+              file_names_[context_file_map_.at(ctx)], ctx->start,
+              absl::StrCat("Field '", field_name, "' has width ", field->width,
+                           " but constraint value is ", length, " bits"));
+        }
+      } else if (overlay != nullptr) {
+        if (overlay->computed_width() != length) {
+          error_listener_->semanticWarning(
+              file_names_[context_file_map_.at(ctx)], ctx->start,
+              absl::StrCat("Overlay '", field_name, "' has width ",
+                           overlay->computed_width(),
+                           " but constraint value is ", length, " bits"));
+        }
       }
     }
-  }
-  int value = ConvertToInt(ctx->number());
-  absl::Status status;
-  if (op == "==") {
-    status = inst_encoding->AddEqualConstraint(field_name, value);
-  } else if (op == "!=") {
-    status = inst_encoding->AddOtherConstraint(ConstraintType::kNe, field_name,
-                                               value);
-  } else if (op == ">") {
-    status = inst_encoding->AddOtherConstraint(ConstraintType::kGt, field_name,
-                                               value);
-  } else if (op == ">=") {
-    status = inst_encoding->AddOtherConstraint(ConstraintType::kGe, field_name,
-                                               value);
-  } else if (op == "<") {
-    status = inst_encoding->AddOtherConstraint(ConstraintType::kLt, field_name,
-                                               value);
-  } else if (op == "<=") {
-    status = inst_encoding->AddOtherConstraint(ConstraintType::kLe, field_name,
-                                               value);
+    int value = ConvertToInt(ctx->number());
+    if (constraint_type == ConstraintType::kEq) {
+      status = inst_encoding->AddEqualConstraint(field_name, value);
+    } else {
+      status =
+          inst_encoding->AddOtherConstraint(constraint_type, field_name, value);
+    }
   }
   if (!status.ok()) {
     error_listener_->semanticError(file_names_[context_file_map_.at(ctx)],
diff --git a/mpact/sim/decoder/bin_format_visitor.h b/mpact/sim/decoder/bin_format_visitor.h
index cd5ae1a..f6e9e63 100644
--- a/mpact/sim/decoder/bin_format_visitor.h
+++ b/mpact/sim/decoder/bin_format_visitor.h
@@ -40,6 +40,8 @@
 namespace decoder {
 namespace bin_format {
 
+enum class ConstraintType : int { kEq = 0, kNe, kLt, kLe, kGt, kGe };
+
 // This struct holds information about a range assignment in an instruction
 // generator.
 struct RangeAssignmentInfo {
@@ -75,7 +77,7 @@
     std::string cc_output;
   };
 
-  BinFormatVisitor() = default;
+  BinFormatVisitor();
   ~BinFormatVisitor();
 
   // Entry point for processing a source_stream input, generating any output
@@ -159,6 +161,8 @@
   absl::flat_hash_map<std::string, DecoderDefCtx *> decoder_decl_map_;
   // AntlrParserWrapper vector.
   std::vector<BinFmtAntlrParserWrapper *> antlr_parser_wrappers_;
+  // Map from comparator string to constraint type.
+  absl::flat_hash_map<std::string, ConstraintType> constraint_string_to_type_;
 };
 
 }  // namespace bin_format
diff --git a/mpact/sim/decoder/encoding_group.cc b/mpact/sim/decoder/encoding_group.cc
index 21a08e1..39da46b 100644
--- a/mpact/sim/decoder/encoding_group.cc
+++ b/mpact/sim/decoder/encoding_group.cc
@@ -513,8 +513,8 @@
                                     "==", connector, &condition);
   count += EmitConstraintConditions(encoding->equal_extracted_constraints(),
                                     "==", connector, &condition);
-  count += EmitConstraintConditions(encoding->other_constraints(),
-                                    "!=", connector, &condition);
+  count += EmitOtherConstraintConditions(encoding->other_constraints(),
+                                         connector, &condition);
 
   // Ensure the number of parentheses are appropriate to the number of
   // conjunctions in the if statement.
@@ -532,6 +532,66 @@
   return count != 0 ? 1 : 0;
 }
 
+void EncodingGroup::EmitFieldExtraction(
+    const Field *field, const std::string &indent_str,
+    absl::flat_hash_set<std::string> &extracted,
+    std::string *definitions_ptr) const {
+  std::string name = absl::StrCat(field->name, "_value");
+  if (!extracted.contains(name)) {
+    std::string data_type;
+    if (field->width > inst_group_->width()) {
+      auto shift = absl::bit_width(static_cast<unsigned>(field->width)) - 1;
+      if (absl::popcount(static_cast<unsigned>(field->width)) > 1) shift++;
+      shift = std::max(shift, 3);
+      if (shift > 6) {
+        LOG(ERROR) << "Field '" << field->name << "' width: " << field->width
+                   << " > 64 bits";
+        data_type =
+            absl::StrCat("#error field width ", field->width, " > 64 bits");
+      } else {
+        data_type = absl::StrCat("uint", 1 << shift, "_t");
+      }
+    } else {
+      data_type = inst_word_type_;
+    }
+    uint64_t mask = ((1ULL << field->width) - 1);
+    absl::StrAppend(definitions_ptr, indent_str, data_type, " ", name,
+                    " = (inst_word >> ", field->low, ") & 0x", absl::Hex(mask),
+                    ";\n");
+    extracted.insert(name);
+  }
+}
+
+void EncodingGroup::EmitOverlayExtraction(
+    const Overlay *overlay, const std::string &indent_str,
+    absl::flat_hash_set<std::string> &extracted,
+    std::string *definitions_ptr) const {
+  std::string name = absl::StrCat(overlay->name(), "_value");
+  if (!extracted.contains(name)) {
+    auto ovl_width = overlay->declared_width();
+    std::string data_type;
+    if (ovl_width > inst_group_->width()) {
+      auto shift = absl::bit_width(static_cast<unsigned>(ovl_width)) - 1;
+      if (absl::popcount(static_cast<unsigned>(ovl_width)) > 1) shift++;
+      shift = std::max(shift, 3);
+      if (shift > 6) {
+        LOG(ERROR) << "Field '" << overlay->name() << "' width: " << ovl_width
+                   << " > 64 bits";
+        data_type =
+            absl::StrCat("#error overlay width ", ovl_width, " > 64 bits");
+      } else {
+        data_type = absl::StrCat("uint", 1 << shift, "_t");
+      }
+    } else {
+      data_type = inst_word_type_;
+    }
+    absl::StrAppend(definitions_ptr, indent_str, data_type, " ", name, ";\n");
+    absl::StrAppend(definitions_ptr, indent_str,
+                    overlay->WriteSimpleValueExtractor("inst_word", name));
+    extracted.insert(name);
+  }
+}
+
 void EncodingGroup::EmitExtractions(
     int indent, const std::vector<Constraint *> &constraints,
     absl::flat_hash_set<std::string> &extracted,
@@ -544,62 +604,53 @@
   for (auto const *constraint : constraints) {
     if (constraint->can_ignore) continue;
     if (constraint->field != nullptr) {
-      Field *field = constraint->field;
-      std::string name = absl::StrCat(field->name, "_value");
-      if (!extracted.contains(name)) {
-        std::string data_type;
-        if (field->width > inst_group_->width()) {
-          auto shift = absl::bit_width(static_cast<unsigned>(field->width)) - 1;
-          if (absl::popcount(static_cast<unsigned>(field->width)) > 1) shift++;
-          shift = std::max(shift, 3);
-          if (shift > 6) {
-            LOG(ERROR) << "Field '" << field->name
-                       << "' width: " << field->width << " > 64 bits";
-            data_type =
-                absl::StrCat("#error field width ", field->width, " > 64 bits");
-          } else {
-            data_type = absl::StrCat("uint", 1 << shift, "_t");
-          }
-        } else {
-          data_type = inst_word_type_;
-        }
-        uint64_t mask = ((1ULL << field->width) - 1);
-        absl::StrAppend(definitions_ptr, indent_str, data_type, " ", name,
-                        " = (inst_word >> ", field->low, ") & 0x",
-                        absl::Hex(mask), ";\n");
-        extracted.insert(name);
-      }
+      EmitFieldExtraction(constraint->field, indent_str, extracted,
+                          definitions_ptr);
     } else {
-      Overlay *overlay = constraint->overlay;
-      std::string name = absl::StrCat(overlay->name(), "_value");
-      if (!extracted.contains(name)) {
-        auto ovl_width = overlay->declared_width();
-        std::string data_type;
-        if (ovl_width > inst_group_->width()) {
-          auto shift = absl::bit_width(static_cast<unsigned>(ovl_width)) - 1;
-          if (absl::popcount(static_cast<unsigned>(ovl_width)) > 1) shift++;
-          shift = std::max(shift, 3);
-          if (shift > 6) {
-            LOG(ERROR) << "Field '" << overlay->name()
-                       << "' width: " << ovl_width << " > 64 bits";
-            data_type =
-                absl::StrCat("#error overlay width ", ovl_width, " > 64 bits");
-          } else {
-            data_type = absl::StrCat("uint", 1 << shift, "_t");
-          }
-        } else {
-          data_type = inst_word_type_;
-        }
-        absl::StrAppend(definitions_ptr, indent_str, data_type, " ", name,
-                        ";\n");
-        absl::StrAppend(definitions_ptr, indent_str,
-                        overlay->WriteSimpleValueExtractor("inst_word", name));
-        extracted.insert(name);
-      }
+      EmitOverlayExtraction(constraint->overlay, indent_str, extracted,
+                            definitions_ptr);
+    }
+    if (constraint->rhs_field != nullptr) {
+      EmitFieldExtraction(constraint->rhs_field, indent_str, extracted,
+                          definitions_ptr);
+    } else if (constraint->rhs_overlay != nullptr) {
+      EmitOverlayExtraction(constraint->rhs_overlay, indent_str, extracted,
+                            definitions_ptr);
     }
   }
 }
 
+int EncodingGroup::EmitOtherConstraintConditions(
+    const std::vector<Constraint *> &constraints, std::string &connector,
+    std::string *condition) const {
+  int count = 0;
+  for (auto const *constraint : constraints) {
+    if (constraint->can_ignore) continue;
+
+    std::string comparison(kComparison[static_cast<int>(constraint->type)]);
+    std::string lhs_name = absl::StrCat((constraint->field != nullptr)
+                                            ? constraint->field->name
+                                            : constraint->overlay->name(),
+                                        "_value");
+    std::string rhs;
+    if ((constraint->rhs_field != nullptr) ||
+        (constraint->rhs_overlay != nullptr)) {
+      rhs = absl::StrCat((constraint->rhs_field != nullptr)
+                             ? constraint->rhs_field->name
+                             : constraint->rhs_overlay->name(),
+                         "_value");
+    } else {
+      rhs = absl::StrCat("0x", absl::Hex(constraint->value));
+    }
+
+    absl::StrAppend(condition, connector, "(", lhs_name, " ", comparison, " ",
+                    rhs, ")");
+    connector = " &&\n          ";
+    count++;
+  }
+  return count;
+}
+
 int EncodingGroup::EmitConstraintConditions(
     const std::vector<Constraint *> &constraints, absl::string_view comparison,
     std::string &connector, std::string *condition) const {
@@ -613,7 +664,7 @@
                                     "_value");
     absl::StrAppend(condition, connector, "(", name, " ", comparison, " 0x",
                     absl::Hex(constraint->value), ")");
-    connector = " &&\n      ";
+    connector = " &&\n          ";
     count++;
   }
   return count;
diff --git a/mpact/sim/decoder/encoding_group.h b/mpact/sim/decoder/encoding_group.h
index fc59a75..9204b91 100644
--- a/mpact/sim/decoder/encoding_group.h
+++ b/mpact/sim/decoder/encoding_group.h
@@ -23,6 +23,7 @@
 #include "absl/container/flat_hash_set.h"
 #include "absl/strings/string_view.h"
 #include "mpact/sim/decoder/extract.h"
+#include "mpact/sim/decoder/format.h"
 
 namespace mpact {
 namespace sim {
@@ -32,6 +33,8 @@
 class InstructionGroup;
 class InstructionEncoding;
 struct Constraint;
+class Field;
+class Overlay;
 
 // The encoding group is a class that allows instruction encodings to be grouped
 // together to facilitate breaking the instruction encodings into a tree like
@@ -117,6 +120,13 @@
   void ProcessConstraint(const absl::flat_hash_set<std::string> &extracted,
                          Constraint *constraint,
                          std::string *definitions_ptr) const;
+  void EmitFieldExtraction(const Field *field, const std::string &indent_str,
+                           absl::flat_hash_set<std::string> &extracted,
+                           std::string *definitions_ptr) const;
+  void EmitOverlayExtraction(const Overlay *overlay,
+                             const std::string &indent_str,
+                             absl::flat_hash_set<std::string> &extracted,
+                             std::string *definitions_ptr) const;
   void EmitExtractions(int indent, const std::vector<Constraint *> &constraints,
                        absl::flat_hash_set<std::string> &extracted,
                        std::string *definitions_ptr) const;
@@ -124,6 +134,10 @@
                                absl::string_view comparison,
                                std::string &connector,
                                std::string *condition) const;
+  int EmitOtherConstraintConditions(
+      const std::vector<Constraint *> &constraints, std::string &connector,
+      std::string *condition) const;
+
   InstructionGroup *inst_group_ = nullptr;
   EncodingGroup *parent_ = nullptr;
   uint64_t varying_ = 0;
diff --git a/mpact/sim/decoder/instruction_encoding.cc b/mpact/sim/decoder/instruction_encoding.cc
index de9c316..9b91ff2 100644
--- a/mpact/sim/decoder/instruction_encoding.cc
+++ b/mpact/sim/decoder/instruction_encoding.cc
@@ -20,6 +20,7 @@
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
 #include "absl/strings/str_cat.h"
+#include "mpact/sim/decoder/bin_format_visitor.h"
 #include "mpact/sim/decoder/format.h"
 
 namespace mpact {
@@ -55,6 +56,54 @@
 }
 
 absl::StatusOr<Constraint *> InstructionEncoding::CreateConstraint(
+    ConstraintType type, std::string lhs_name, std::string rhs_name) {
+  Constraint constraint;
+  constraint.type = type;
+  // Check if the field name is indeed a field.
+  auto *lhs_field = format_->GetField(lhs_name);
+  if (lhs_field != nullptr) {
+    if (lhs_field->width >= 64) {
+      return absl::OutOfRangeError(absl::StrCat(
+          "Field '", lhs_field->name,
+          "' is too wide to create constraint - ust be <= 64 bits"));
+    }
+    constraint.field = lhs_field;
+  } else {
+    // If not a field, is it an overlay?
+    auto *lhs_overlay = format_->GetOverlay(lhs_name);
+    if (lhs_overlay == nullptr) {
+      // If neither, it's an error.
+      return absl::NotFoundError(absl::StrCat(
+          "Format '", format_->name(),
+          "' does not contain a field or overlay named ", lhs_name));
+    }
+    constraint.overlay = lhs_overlay;
+  }
+  // Check if the field name is indeed a field.
+  auto *rhs_field = format_->GetField(rhs_name);
+  if (rhs_field != nullptr) {
+    if (rhs_field->width >= 64) {
+      return absl::OutOfRangeError(absl::StrCat(
+          "Field '", rhs_field->name,
+          "' is too wide to create constraint - ust be <= 64 bits"));
+    }
+    constraint.rhs_field = rhs_field;
+  } else {
+    // If not a field, is it an overlay?
+    auto *rhs_overlay = format_->GetOverlay(rhs_name);
+    if (rhs_overlay == nullptr) {
+      // If neither, it's an error.
+      return absl::NotFoundError(absl::StrCat(
+          "Format '", format_->name(),
+          "' does not contain a field or overlay named ", rhs_name));
+    }
+    constraint.rhs_overlay = rhs_overlay;
+  }
+  Constraint *result = new Constraint(constraint);
+  return result;
+}
+
+absl::StatusOr<Constraint *> InstructionEncoding::CreateConstraint(
     ConstraintType type, std::string field_name, int64_t value) {
   // Check if the field name is indeed a field.
   auto *field = format_->GetField(field_name);
@@ -192,6 +241,16 @@
   return absl::OkStatus();
 }
 
+absl::Status InstructionEncoding::AddOtherConstraint(
+    ConstraintType type, const std::string &lhs_name,
+    const std::string &rhs_name) {
+  auto res = CreateConstraint(type, lhs_name, rhs_name);
+  if (!res.ok()) return res.status();
+  auto *constraint = res.value();
+  other_constraints_.push_back(constraint);
+  return absl::OkStatus();
+}
+
 absl::Status InstructionEncoding::ComputeMaskAndValue() {
   // First consider equal constraints.
   mask_ = 0;
@@ -240,6 +299,17 @@
       mask = constraint->overlay->mask();
     }
     other_mask_ |= mask;
+    // If the rhs is a field or overlay, add to the mask.
+    if (constraint->rhs_field != nullptr) {
+      int width = constraint->rhs_field->width;
+      mask &= (1LLU << width) - 1;
+      int shift = constraint->rhs_field->low;
+      mask <<= shift;
+      other_mask_ |= mask;
+    } else if (constraint->rhs_overlay != nullptr) {
+      mask &= constraint->rhs_overlay->mask();
+      other_mask_ |= mask;
+    }
   }
   mask_set_ = true;
   return absl::OkStatus();
diff --git a/mpact/sim/decoder/instruction_encoding.h b/mpact/sim/decoder/instruction_encoding.h
index c8f58b6..730e68b 100644
--- a/mpact/sim/decoder/instruction_encoding.h
+++ b/mpact/sim/decoder/instruction_encoding.h
@@ -21,6 +21,7 @@
 
 #include "absl/status/status.h"
 #include "absl/status/statusor.h"
+#include "mpact/sim/decoder/bin_format_visitor.h"
 #include "mpact/sim/decoder/format.h"
 #include "mpact/sim/decoder/overlay.h"
 
@@ -29,13 +30,13 @@
 namespace decoder {
 namespace bin_format {
 
-enum class ConstraintType : int { kEq = 0, kNe, kLt, kLe, kGt, kGe };
-
 // Helper struct to group the information of a constraint (either == or !=).
 struct Constraint {
   ConstraintType type;
   Field *field = nullptr;
   Overlay *overlay = nullptr;
+  Field *rhs_field = nullptr;
+  Overlay *rhs_overlay = nullptr;
   bool can_ignore = false;
   uint64_t value;
 };
@@ -60,6 +61,11 @@
   // instruction) needing a different comparison (ne, lt, le, etc.).
   absl::Status AddOtherConstraint(ConstraintType type, std::string field_name,
                                   int64_t value);
+  // Add a constraint on a field/overlay (in the format associated with the
+  // instruction) that compares against another field/overlay.
+  absl::Status AddOtherConstraint(ConstraintType type,
+                                  const std::string &lhs_name,
+                                  const std::string &rhs_name);
 
   // Get the value of the constant bits in the instruction (as defined by the
   // equal constraints).
@@ -96,6 +102,10 @@
  private:
   // Internal helper to create and check a constraint.
   absl::StatusOr<Constraint *> CreateConstraint(ConstraintType type,
+                                                std::string lhs_name,
+                                                std::string rhs_name);
+
+  absl::StatusOr<Constraint *> CreateConstraint(ConstraintType type,
                                                 std::string field_name,
                                                 int64_t value);
   // Recomputes the masks and values.
diff --git a/mpact/sim/decoder/overlay.h b/mpact/sim/decoder/overlay.h
index 73ee2cc..7917988 100644
--- a/mpact/sim/decoder/overlay.h
+++ b/mpact/sim/decoder/overlay.h
@@ -117,7 +117,7 @@
   bool operator!=(const Overlay &rhs) const;
 
   // Accessors.
-  const std::string &name() { return name_; }
+  const std::string &name() const { return name_; }
   bool is_signed() const { return is_signed_; }
   int declared_width() const { return declared_width_; }
   int computed_width() const { return computed_width_; }
diff --git a/mpact/sim/decoder/test/testfiles/constraints.bin_fmt b/mpact/sim/decoder/test/testfiles/constraints.bin_fmt
index 2a15614..9a07f08 100644
--- a/mpact/sim/decoder/test/testfiles/constraints.bin_fmt
+++ b/mpact/sim/decoder/test/testfiles/constraints.bin_fmt
@@ -36,4 +36,5 @@
   greater_equal : Inst32Format : opcode == 0b011, field3 >= 3;
   less          : Inst32Format : opcode == 0b100, field4 < 4;
   less_equal    : Inst32Format : opcode == 0b101, field5 <= 5;
+  field_field   : Inst32Format : opcode == 0b111, field0 != field1;
 };
\ No newline at end of file