blob: f311316274436245509996fefedc5bb6f98d9033 [file]
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mpact/sim/decoder/proto_encoding_group.h"
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <string>
#include <vector>
#include "absl/container/btree_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/string_view.h"
#include "mpact/sim/decoder/decoder_error_listener.h"
#include "mpact/sim/decoder/format_name.h"
#include "mpact/sim/decoder/proto_constraint_expression.h"
#include "mpact/sim/decoder/proto_constraint_value_set.h"
#include "mpact/sim/decoder/proto_encoding_info.h"
#include "mpact/sim/decoder/proto_format_contexts.h"
#include "mpact/sim/decoder/proto_instruction_encoding.h"
#include "mpact/sim/decoder/proto_instruction_group.h"
#include "src/google/protobuf/descriptor.h"
namespace mpact {
namespace sim {
namespace decoder {
namespace proto_fmt {
using ::mpact::sim::machine_description::instruction_set::ToPascalCase;
using ::mpact::sim::machine_description::instruction_set::ToSnakeCase;
// Bookkeeping record for one field (or oneof) that is referenced by the
// equality constraints of the encodings in an encoding group. It collects the
// constraint values seen for that field so the group can decide which field
// best differentiates its encodings.
//
// The first two members are initialized positionally with
// 'FieldInfo{field, oneof}', so their order must not change.
struct FieldInfo {
  // Descriptor of the constrained field. Null when the record was created
  // from a kHas constraint (in which case 'oneof' is set instead).
  const google::protobuf::FieldDescriptor* field = nullptr;
  // Descriptor of the constrained oneof. Null when the record was created
  // from a kEq constraint.
  const google::protobuf::OneofDescriptor* oneof = nullptr;
  // Qualified identifier parse context of the constraint that created this
  // record. Its text is used when emitting the accessor call in the decoder.
  QualifiedIdentCtx* ctx = nullptr;
  // Maps each constraint value to the encoding(s) that require that value.
  absl::btree_multimap<int64_t, const ProtoInstructionEncoding*> value_map;
  // Smallest and largest constraint values recorded in 'value_map'.
  int64_t min_value = 0;
  int64_t max_value = 0;
  // Number of distinct keys in 'value_map'.
  size_t unique_values = 0;
  // Not written or read by the code visible in this file.
  double density = 0.0;
};
// Shorthand for the sub-range type nested in ProtoConstraintValueSet.
using ConstraintValueRange = ProtoConstraintValueSet::SubRange;
// Constructs an encoding group that has no parent group by delegating to the
// four-argument constructor with a null parent.
//   inst_group:     instruction group this encoding group belongs to.
//   level:          depth of this group in the encoding group hierarchy.
//   error_listener: sink used to report semantic errors.
ProtoEncodingGroup::ProtoEncodingGroup(ProtoInstructionGroup* inst_group,
int level,
DecoderErrorListener* error_listener)
: ProtoEncodingGroup(nullptr, inst_group, level, error_listener) {}
// Constructs an encoding group and stores the given pointers and level.
//   parent:         enclosing encoding group, or null for a group without a
//                   parent.
//   inst_group:     instruction group this encoding group belongs to.
//   level:          depth of this group in the encoding group hierarchy
//                   (sub groups are created with 'level_ + 1').
//   error_listener: sink used to report semantic errors.
ProtoEncodingGroup::ProtoEncodingGroup(ProtoEncodingGroup* parent,
ProtoInstructionGroup* inst_group,
int level,
DecoderErrorListener* error_listener)
: inst_group_(inst_group),
parent_(parent),
error_listener_(error_listener),
level_(level) {}
// Frees everything this group holds by raw pointer: the per-field bookkeeping
// records, the encodings added to the group, and the child encoding groups
// (whose destructors in turn free their own contents).
ProtoEncodingGroup::~ProtoEncodingGroup() {
  // Field bookkeeping records.
  for (auto const& entry : field_map_) {
    delete entry.second;
  }
  field_map_.clear();

  // Encodings added through AddEncoding().
  for (auto const* encoding : encoding_vec_) {
    delete encoding;
  }
  encoding_vec_.clear();

  // The instruction group is only referenced, never deleted here.
  inst_group_ = nullptr;

  // Child groups created by AddSubGroups().
  for (auto const* child_group : encoding_group_vec_) {
    delete child_group;
  }
  encoding_group_vec_.clear();
}
// Adds the encoding 'enc' to this group.
//
// Every equality constraint of the encoding is reduced to an int64_t value
// that is stored back into the constraint and recorded in the FieldInfo of
// the constrained field (kEq) or oneof (kHas). These records are used later
// to pick the field that differentiates the most encodings. The fields and
// oneofs used by the remaining ("other") constraints are remembered so they
// are not chosen as differentiators.
//
// If a constraint has an unsupported operator, an unsupported value type, or
// a uint64 value that does not fit in int64_t, a semantic error is reported
// and the function returns without adding the encoding to the group.
void ProtoEncodingGroup::AddEncoding(ProtoInstructionEncoding* enc) {
  for (auto* constraint : enc->equal_constraints()) {
    auto const* field = constraint->field_descriptor;
    const google::protobuf::OneofDescriptor* oneof = nullptr;
    auto const* expr = constraint->expr;
    auto op = constraint->op;
    auto* ident_ctx = constraint->ctx->qualified_ident();

    // Only kEq and kHas are legal in the equality constraint list.
    if ((op != ConstraintType::kEq) && (op != ConstraintType::kHas)) {
      error_listener()->semanticError(
          constraint->ctx->start,
          absl::StrCat("Illegal constraint op for field '", field->name(),
                       "' in equality constraints."));
      return;
    }

    int64_t value;  // All constraint values are widened to int64_t.
    if (op == ConstraintType::kHas) {
      // A kHas constraint is keyed on the oneof; the value is the index of
      // the selected member field as reported by its descriptor.
      value = field->index();
      oneof = field->containing_oneof();
      field = nullptr;
    } else {
      // A kEq constraint is keyed on the field itself ('oneof' stays null).
      switch (expr->variant_type()) {
        case *ProtoValueIndex::kInt32:
          value = expr->GetValueAs<int32_t>();
          break;
        case *ProtoValueIndex::kInt64:
          value = expr->GetValueAs<int64_t>();
          break;
        case *ProtoValueIndex::kUint32:
          value = expr->GetValueAs<uint32_t>();
          break;
        case *ProtoValueIndex::kUint64: {
          // Values that do not fit in int64_t are simply rejected.
          uint64_t unsigned_value = expr->GetValueAs<uint64_t>();
          if (unsigned_value > std::numeric_limits<int64_t>::max()) {
            error_listener()->semanticError(
                constraint->ctx->start,
                absl::StrCat("Expression value for field '", field->name(),
                             "' overflows int64_t."));
            return;
          }
          value = static_cast<int64_t>(unsigned_value);
          break;
        }
        default:
          error_listener()->semanticError(
              constraint->ctx->start,
              absl::StrCat(
                  "Illegal type in expression in constraint for field '",
                  field->name(), "'."));
          return;
      }
    }
    constraint->value = value;

    // Look up the bookkeeping record for the field/oneof, creating it on
    // first use. A new record starts with an inverted (empty) value range.
    auto key = field != nullptr ? field->name() : oneof->name();
    FieldInfo* info = nullptr;
    if (auto found = field_map_.find(key); found != field_map_.end()) {
      info = found->second;
    } else {
      info = new FieldInfo{field, oneof};
      info->min_value = std::numeric_limits<int64_t>::max();
      info->max_value = std::numeric_limits<int64_t>::min();
      info->ctx = ident_ctx;
      field_map_.emplace(key, info);
    }

    // Record the value for this encoding and update the statistics.
    if (!info->value_map.contains(value)) {
      ++info->unique_values;
    }
    if (value < info->min_value) info->min_value = value;
    if (value > info->max_value) info->max_value = value;
    info->value_map.insert({value, enc});
  }

  encoding_vec_.push_back(enc);

  // Remember which oneofs/fields are used by the non-equality constraints. A
  // field that is a member of a oneof is tracked through its oneof only.
  for (auto* constraint : enc->other_constraints()) {
    auto const* field = constraint->field_descriptor;
    if (auto const* oneof = field->containing_oneof(); oneof != nullptr) {
      other_oneof_set_.insert(oneof);
    } else {
      other_field_set_.insert(field);
    }
  }
}
void ProtoEncodingGroup::AddSubGroups() {
// If there is only one encoding, return.
if (encoding_vec_.size() == 1) return;
// First determine which field is the most productive to use to split up the
// group. This is determined by how many values it has, with some thought to
// the number of total values in its interval.
// To start with just pick the one with the largest number of unique values,
// as that should create a shallower decoding tree.
FieldInfo* best_field = nullptr;
absl::flat_hash_set<ProtoInstructionEncoding*> encodings;
for (auto enc : encoding_vec_) {
encodings.insert(enc);
}
for (auto& [unused, field_info] : field_map_) {
// First check if the field is used in any other constraints, e.g., '>' or
// '!='. If so, it is not a candidate for a direct lookup of the value.
if (other_field_set_.contains(field_info->field)) continue;
if (other_oneof_set_.contains(field_info->oneof)) continue;
if (best_field == nullptr) {
best_field = field_info;
} else {
if (field_info->unique_values > best_field->unique_values) {
best_field = field_info;
}
}
}
// If there is no best field, or it doesn't differentiate we're done, but
// first check the encodings to make sure there are no ambiguities or
// duplicate encodings.
if ((best_field == nullptr) || (best_field->unique_values == 1)) {
CheckEncodings();
return;
}
// Save the differentiating field info in this group.
differentiator_ = best_field;
// Next, create an encoding group for each value of the field, adding the
// encodings that match the value to the corresponding groups.
for (auto iter = best_field->value_map.begin();
iter != best_field->value_map.end();
/*empty*/) {
auto* enc_group =
new ProtoEncodingGroup(this, inst_group_, level_ + 1, error_listener_);
int64_t value = iter->first;
enc_group->set_value(value);
while ((iter != best_field->value_map.end()) && (value == iter->first)) {
// First create a copy of the encoding and remove the constraint
// that corresponds with the field info, so it will not be considered
// below.
ProtoInstructionEncoding* enc =
new ProtoInstructionEncoding(*(iter->second));
// Remove the best_field constraint from the new encoding object.
auto v_iter = enc->equal_constraints().begin();
ProtoConstraint* constraint = nullptr;
while (v_iter != enc->equal_constraints().end()) {
constraint = *v_iter;
if (constraint->op == ConstraintType::kEq) {
if ((best_field->field != nullptr) &&
(best_field->field == constraint->field_descriptor)) {
break;
}
} else { // This constraint is a kHas.
if ((best_field->oneof != nullptr) &&
(best_field->oneof ==
constraint->field_descriptor->containing_oneof())) {
break;
}
}
++v_iter;
}
if (v_iter != enc->equal_constraints().end()) {
delete constraint->expr;
delete constraint;
enc->equal_constraints().erase(v_iter);
}
enc_group->AddEncoding(enc);
// Remove the encoding from the map.
encodings.erase(iter->second);
++iter;
}
encoding_group_vec_.push_back(enc_group);
}
// Any encodings remaining in the map have to be added to each of the sub
// groups, as they weren't selected by value.
for (auto enc : encodings) {
for (auto* enc_group : encoding_group_vec_) {
auto enc_copy = new ProtoInstructionEncoding(*enc);
enc_group->AddEncoding(enc_copy);
}
}
encodings.clear();
// Recursively try to split the child encoding groups.
for (auto* enc_group : encoding_group_vec_) {
enc_group->AddSubGroups();
}
}
// Check the encodings to make sure there aren't ambiguities.
void ProtoEncodingGroup::CheckEncodings() {
// If there is only one encoding, there is no ambiguity.
if (encoding_vec_.size() <= 1) return;
// Encodings have to have additional constraints to differentiate between each
// other, so check to see if any of them have none, and if so, signal an
// error.
for (auto* enc : encoding_vec_) {
if (enc->equal_constraints().empty() && enc->other_constraints().empty()) {
std::string msg =
absl::StrCat("Decoding ambiguity between '", enc->name(), "' and :");
for (auto* other_enc : encoding_vec_) {
if (enc == other_enc) continue;
absl::StrAppend(&msg, " '", other_enc->name(), "'");
}
error_listener()->semanticError(nullptr, msg);
return;
}
}
// Check for identical or overlapping constraints.
// First sort the constraints in each vector.
std::vector<std::vector<ProtoConstraint*>> constraints;
constraints.reserve(encoding_vec_.size());
for (auto* enc : encoding_vec_) {
constraints.push_back({});
for (auto* constraint : enc->equal_constraints()) {
constraints.back().push_back(constraint);
}
for (auto* constraint : enc->other_constraints()) {
constraints.back().push_back(constraint);
}
}
for (int i = 0; i < constraints.size(); ++i) {
std::sort(
constraints[i].begin(), constraints[i].end(),
[](const ProtoConstraint* lhs, const ProtoConstraint* rhs) -> bool {
return lhs->field_descriptor->full_name() <
rhs->field_descriptor->full_name();
});
}
// Now create value value sets for each field descriptor, combining multiple
// constraints on the same field descriptor into a single set of values.
std::vector<std::vector<ProtoConstraintValueSet*>> value_sets;
value_sets.reserve(encoding_vec_.size());
for (auto const& constraint_vec : constraints) {
const google::protobuf::FieldDescriptor* previous = nullptr;
value_sets.push_back({});
for (auto const* constraint : constraint_vec) {
// If it's the first occurance of a field descriptor, create a new
// range on this constraint.
ProtoConstraintValueSet* value_set = nullptr;
if (previous != constraint->field_descriptor) {
previous = constraint->field_descriptor;
value_set = new ProtoConstraintValueSet(constraint);
value_sets.back().push_back(value_set);
continue;
}
// This is not the first occurance of a field descriptor. Intersect
// with the current range.
auto status =
value_set->IntersectWith(ProtoConstraintValueSet(constraint));
// Check for error.
if (!status.ok()) {
// Clean up.
for (auto& value_set_list : value_sets) {
for (auto* value_set : value_set_list) delete value_set;
}
value_sets.clear();
// Signal error.
error_listener()->semanticError(nullptr, status.message());
return;
}
}
}
for (int i = 0; i < value_sets.size(); ++i) {
for (int j = i + 1; j < value_sets.size(); ++j) {
if (DoConstraintsOverlap(value_sets[i], value_sets[j])) {
error_listener()->semanticError(
nullptr, absl::StrCat("Encoding group '", inst_group_->name(),
"': encoding ambiguity between '",
encoding_vec_[i]->name(), " and ",
encoding_vec_[j]->name(), "'"));
}
}
}
// Clean up.
for (auto& value_set_list : value_sets) {
for (auto* value_set : value_set_list) delete value_set;
}
}
// Returns true if the constraint value sets of two encodings are considered
// to overlap. Both vectors are expected to be sorted by the full name of the
// constrained field descriptor.
//
// The sets are compared position by position:
//   - different fields at the same position: no overlap,
//   - an error while intersecting two sets: reported as an overlap, even if
//     there isn't one,
//   - an empty intersection: no overlap.
// If all compared positions intersect, the encodings overlap only when
// neither of them has additional value sets left over.
bool ProtoEncodingGroup::DoConstraintsOverlap(
    const std::vector<ProtoConstraintValueSet*>& lhs,
    const std::vector<ProtoConstraintValueSet*>& rhs) {
  const size_t num_common = std::min(lhs.size(), rhs.size());
  for (size_t pos = 0; pos < num_common; ++pos) {
    ProtoConstraintValueSet* left = lhs[pos];
    ProtoConstraintValueSet* right = rhs[pos];
    // Value sets on different fields cannot be compared.
    if (left->field_descriptor()->full_name() !=
        right->field_descriptor()->full_name()) {
      return false;
    }
    // Intersect a copy so that the caller's value set is left untouched.
    ProtoConstraintValueSet intersection(*left);
    auto status = intersection.IntersectWith(*right);
    if (!status.ok()) return true;
    // Disjoint value sets on one field are enough to tell the two apart.
    if (intersection.IsEmpty()) return false;
  }
  // Leftover value sets on either side mean that the encodings don't overlap.
  return lhs.size() == rhs.size();
}
// Name given to the proto message parameter of the emitted decoder functions.
constexpr char kDecodeMsgName[] = "inst_proto";
// Emits the source text of a decoder function for a group that is not split
// into sub groups.
//   fcn_name:          name of the emitted function.
//   opcode_enum:       name of the opcode enum type the function returns.
//   message_type_name: type name of the proto message parameter.
//   indent_width:      number of spaces the function is indented by.
//
// If the group holds a single encoding without constraints, the function
// unconditionally returns that opcode. Otherwise the body is an if/else-if
// chain with one branch per encoding, conditioned on all of the encoding's
// constraints, that falls through to returning 'kNone'. A group without any
// encodings emits a function that only returns 'kNone'.
std::string ProtoEncodingGroup::EmitLeafDecoder(
    absl::string_view fcn_name, absl::string_view opcode_enum,
    absl::string_view message_type_name, int indent_width) const {
  std::string output;
  std::string if_sep;
  std::string decoder_class =
      ToPascalCase(inst_group_->encoding_info()->decoder()->name()) + "Decoder";
  // Function signature.
  absl::StrAppend(&output, std::string(indent_width, ' '), opcode_enum, " ",
                  fcn_name, "(", message_type_name, " ", kDecodeMsgName, ", ",
                  decoder_class, " *decoder) {\n");
  indent_width += 2;
  std::string indent(indent_width, ' ');
  // Check for the case when there is only a single encoding with no
  // constraints.
  if ((encoding_vec_.size() == 1) &&
      encoding_vec_[0]->equal_constraints().empty() &&
      encoding_vec_[0]->other_constraints().empty()) {
    absl::StrAppend(
        &output, encoding_vec_[0]->GetSetterCode(kDecodeMsgName, indent_width),
        "return ", opcode_enum, "::k", ToPascalCase(encoding_vec_[0]->name()),
        ";\n");
    indent_width -= 2;
    absl::StrAppend(&output, std::string(indent_width, ' '), "}\n\n");
    return output;
  }
  // Helper lambda that produces the condition text for a single constraint.
  auto generate_condition =
      [](const ProtoConstraint* constraint) -> std::string {
    if (constraint->op == ConstraintType::kHas) {
      // A kHas constraint compares the oneof case against the case
      // enumerator of the constrained field.
      std::string ident = constraint->ctx->qualified_ident()->getText();
      auto pos = ident.find_last_of('.');
      std::string prefix;
      if (pos != std::string::npos) {
        // NOTE(review): 'prefix' keeps the trailing '.' of the qualifier and
        // another '.' is appended before the oneof name below - confirm the
        // emitted accessor path for qualified identifiers.
        prefix = absl::StrCat(".", ident.substr(0, pos + 1));
      }
      auto oneof_desc = constraint->field_descriptor->containing_oneof();
      auto oneof_name = oneof_desc->name();
      // Scope of the case enum: the containing message types.
      std::string parent_name;
      for (auto parent = oneof_desc->containing_type(); parent != nullptr;
           parent = parent->containing_type()) {
        absl::StrAppend(&parent_name, ToPascalCase(parent->name()), "::");
      }
      auto package = absl::StrReplaceAll(
          constraint->field_descriptor->file()->package(), {{".", "::"}});
      return absl::StrCat(
          "(", kDecodeMsgName, prefix, ".", oneof_name, "_case() == ", package,
          "::", parent_name, ToPascalCase(oneof_name), "Case::k",
          ToPascalCase(constraint->field_descriptor->name()), ")");
    }
    // Any other constraint compares the field against the expression text.
    return absl::StrCat("(", kDecodeMsgName, ".",
                        constraint->ctx->field->getText(), " ",
                        GetOpText(constraint->op), " ",
                        constraint->ctx->constraint_expr()->getText(), ")");
  };
  // Generate a chained if-else if-else-statement for the encodings in the
  // encoding vector.
  std::string indent_body(indent_width + 2, ' ');
  for (auto* enc : encoding_vec_) {
    // Generate the if statement conditions.
    absl::StrAppend(&output, indent, if_sep, "if (");
    std::string cond_sep;
    for (auto const* constraint : enc->equal_constraints()) {
      absl::StrAppend(&output, cond_sep, generate_condition(constraint));
      cond_sep = " && ";
    }
    for (auto const* constraint : enc->other_constraints()) {
      absl::StrAppend(&output, cond_sep, generate_condition(constraint));
      cond_sep = " && ";
    }
    absl::StrAppend(&output, ") {\n");
    // Generate if statement body.
    absl::StrAppend(&output, indent_body,
                    enc->GetSetterCode(kDecodeMsgName, indent_width + 2),
                    "return ", opcode_enum, "::k", ToPascalCase(enc->name()),
                    ";\n");
    if_sep = "} else ";
  }
  // Close the if chain. With no encodings no chain was opened, so emitting
  // the brace would unbalance the generated function.
  if (!encoding_vec_.empty()) {
    absl::StrAppend(&output, indent, "}\n");
  }
  // Generate the fall through.
  absl::StrAppend(&output, indent, "return ", opcode_enum, "::kNone;\n");
  indent_width -= 2;
  absl::StrAppend(&output, std::string(indent_width, ' '), "}\n\n");
  return output;
}
namespace {
// Ordering predicate for encoding groups: orders by ascending value().
bool LessThan(ProtoEncodingGroup* lhs, ProtoEncodingGroup* rhs) {
  const auto lhs_value = lhs->value();
  const auto rhs_value = rhs->value();
  return lhs_value < rhs_value;
}
}  // namespace
// Emits the source text of the decoder function for this group.
//   fcn_name:          name of the emitted function.
//   opcode_enum:       name of the opcode enum type the function returns.
//   message_type_name: type name of the proto message parameter.
//
// A group without sub groups is emitted as a leaf decoder. Otherwise the
// emitted function dispatches on the value of the differentiator field to the
// decoder function of the matching sub group ('<fcn_name>_<value>'), using
// either a hash map (sparse values) or a function table indexed by
// 'value - min_value' (dense values). Sorts encoding_group_vec_ by value as a
// side effect.
std::string ProtoEncodingGroup::EmitComplexDecoder(
    absl::string_view fcn_name, absl::string_view opcode_enum,
    absl::string_view message_type_name) {
  if (encoding_group_vec_.empty()) {
    return EmitLeafDecoder(fcn_name, opcode_enum, message_type_name, 0);
  }
  std::string output;
  std::string decoder_class =
      ToPascalCase(inst_group_->encoding_info()->decoder()->name()) + "Decoder";
  // Sort the encoding_group_vec according to differentiator value.
  std::sort(encoding_group_vec_.begin(), encoding_group_vec_.end(), LessThan);
  // Accessor call chain for the differentiator, e.g., 'a.b' becomes 'a().b'
  // (the final '()' is added where the call is emitted).
  const std::string call =
      absl::StrReplaceAll(differentiator_->ctx->getText(), {{".", "()."}});
  const auto min = differentiator_->min_value;
  const auto max = differentiator_->max_value;
  // Choose between a map and a table based on how densely the values fill
  // their interval.
  // NOTE(review): the interval [min, max] holds 'max - min + 1' values, but
  // the divisor omits the '+ 1' - confirm whether that is intended.
  const double density = static_cast<double>(differentiator_->unique_values) /
                         static_cast<double>(max - min);
  if (density < 0.75) {
    // Sparse values: emit a map from value to sub group decoder function.
    // NOTE(review): the emitted map key type is int32_t while the values are
    // tracked as int64_t here - confirm that the values always fit.
    std::string map_name = absl::StrCat(ToSnakeCase(fcn_name), "_map");
    absl::StrAppend(
        &output,
        "absl::NoDestructor<absl::flat_hash_map<int32_t, std::function<",
        opcode_enum, "(", message_type_name, ", ", decoder_class, "*)>>> ",
        map_name, "({\n");
    for (auto* enc_group : encoding_group_vec_) {
      auto enc_value = enc_group->value();
      absl::StrAppend(&output, " {", enc_value, ", ", fcn_name, "_", enc_value,
                      "},\n");
    }
    absl::StrAppend(&output, "});\n\n");
    // Emit the function body.
    absl::StrAppend(&output, opcode_enum, " ", fcn_name, "(", message_type_name,
                    " ", kDecodeMsgName, ", ", decoder_class, " *decoder) {\n",
                    " auto iter = ", map_name, "->find(", kDecodeMsgName, ".",
                    call, "());\n", " if (iter == ", map_name,
                    "->end()) return ", opcode_enum, "::kNone;\n",
                    " return iter->second(", kDecodeMsgName,
                    ", decoder);\n}\n\n");
  } else {
    // Dense values: emit a function table with one entry for every value in
    // [min, max].
    const auto num_values = max - min + 1;
    const std::string table_name =
        absl::StrCat(ToSnakeCase(fcn_name), "_table");
    absl::StrAppend(&output, "std::function<", opcode_enum, "(",
                    message_type_name, ", ", decoder_class, "*)> ", table_name,
                    "[", num_values, "] = {\n");
    // Fill in the entries in the function table. The sub groups are sorted by
    // value, so a single forward pass over them is sufficient. Values without
    // a sub group get the 'None' decoder.
    auto iter = encoding_group_vec_.begin();
    for (int64_t i = 0; i < num_values; ++i) {
      const bool has_group = (iter != encoding_group_vec_.end()) &&
                             (((*iter)->value() - min) == i);
      if (!has_group) {
        absl::StrAppend(&output, " Decode", ToPascalCase(inst_group_->name()),
                        "_None,\n");
        continue;
      }
      absl::StrAppend(&output, " ", fcn_name, "_", (*iter)->value(), ",\n");
      ++iter;
    }
    // Emit the function body.
    // NOTE(review): the emitted table lookup is not range checked against
    // [min, max] - confirm that callers guarantee the value is in range.
    absl::StrAppend(&output, "};\n\n", opcode_enum, " ", fcn_name, "(",
                    message_type_name, " ", kDecodeMsgName, ", ", decoder_class,
                    " *decoder) {\n", " return ", table_name, "[",
                    kDecodeMsgName, ".", call, "() - ", min, "](",
                    kDecodeMsgName, ", decoder);\n}\n\n");
  }
  return output;
}
// Emits the decoder functions for this group and all the groups below it in
// the hierarchy. The decoders of the sub groups, named '<fcn_name>_<value>',
// are emitted first, followed by the decoder of this group.
std::string ProtoEncodingGroup::EmitDecoders(
    absl::string_view fcn_name, absl::string_view opcode_enum,
    absl::string_view message_type_name) {
  std::string result;
  // Subordinate groups first.
  for (auto* sub_group : encoding_group_vec_) {
    const std::string sub_fcn_name =
        absl::StrCat(fcn_name, "_", sub_group->value());
    absl::StrAppend(&result, sub_group->EmitDecoders(sub_fcn_name, opcode_enum,
                                                     message_type_name));
  }
  // Then the decoder for this group.
  absl::StrAppend(&result,
                  EmitComplexDecoder(fcn_name, opcode_enum, message_type_name));
  return result;
}
} // namespace proto_fmt
} // namespace decoder
} // namespace sim
} // namespace mpact