ソースを参照

Unify StructElement and VarValues (#909)

Geoff Romer 4 年 前
コミット
0da907fd2a

+ 7 - 7
executable_semantics/interpreter/interpreter.cpp

@@ -128,8 +128,8 @@ void Interpreter::InitEnv(const Declaration& d, Env* env) {
 
     case Declaration::Kind::ClassDeclaration: {
       const ClassDefinition& class_def = cast<ClassDeclaration>(d).definition();
-      VarValues fields;
-      VarValues methods;
+      std::vector<NamedValue> fields;
+      std::vector<NamedValue> methods;
       for (Nonnull<const Member*> m : class_def.members()) {
         switch (m->kind()) {
           case Member::Kind::FieldMember: {
@@ -137,7 +137,7 @@ void Interpreter::InitEnv(const Declaration& d, Env* env) {
             const Expression& type_expression =
                 cast<ExpressionPattern>(binding.type()).expression();
             auto type = InterpExp(Env(arena_), &type_expression);
-            fields.push_back(make_pair(*binding.name(), type));
+            fields.push_back({.name = *binding.name(), .value = type});
             break;
           }
         }
@@ -151,10 +151,10 @@ void Interpreter::InitEnv(const Declaration& d, Env* env) {
 
     case Declaration::Kind::ChoiceDeclaration: {
       const auto& choice = cast<ChoiceDeclaration>(d);
-      VarValues alts;
+      std::vector<NamedValue> alts;
       for (const auto& alternative : choice.alternatives()) {
         auto t = InterpExp(Env(arena_), &alternative.signature());
-        alts.push_back(make_pair(alternative.name(), t));
+        alts.push_back({.name = alternative.name(), .value = t});
       }
       auto ct = arena_->New<ChoiceType>(choice.name(), std::move(alts));
       auto a = heap_.AllocateValue(ct);
@@ -205,7 +205,7 @@ auto Interpreter::CreateStruct(const std::vector<FieldInitializer>& fields,
                                const std::vector<Nonnull<const Value*>>& values)
     -> Nonnull<const Value*> {
   CHECK(fields.size() == values.size());
-  std::vector<StructElement> elements;
+  std::vector<NamedValue> elements;
   for (size_t i = 0; i < fields.size(); ++i) {
     elements.push_back({.name = fields[i].name(), .value = values[i]});
   }
@@ -580,7 +580,7 @@ auto Interpreter::StepExp() -> Transition {
         return Spawn{arena_->New<ExpressionAction>(
             &struct_type.fields()[act->pos()].expression())};
       } else {
-        VarValues fields;
+        std::vector<NamedValue> fields;
         for (size_t i = 0; i < struct_type.fields().size(); ++i) {
           fields.push_back({struct_type.fields()[i].name(), act->results()[i]});
         }

+ 37 - 34
executable_semantics/interpreter/type_checker.cpp

@@ -162,15 +162,18 @@ static auto IsImplicitlyConvertible(Nonnull<const Value*> source,
 // the corresponding value in destination_fields. All values in both arguments
 // must be types.
 static auto FieldTypesImplicitlyConvertible(
-    const VarValues& source_fields, const VarValues& destination_fields) {
+    llvm::ArrayRef<NamedValue> source_fields,
+    llvm::ArrayRef<NamedValue> destination_fields) {
   if (source_fields.size() != destination_fields.size()) {
     return false;
   }
-  for (const auto& [field_name, source_field_type] : source_fields) {
-    std::optional<Nonnull<const Value*>> destination_field_type =
-        FindInVarValues(field_name, destination_fields);
-    if (!destination_field_type.has_value() ||
-        !IsImplicitlyConvertible(source_field_type, *destination_field_type)) {
+  for (const auto& source_field : source_fields) {
+    auto it = std::find_if(destination_fields.begin(), destination_fields.end(),
+                           [&](const NamedValue& field) {
+                             return field.name == source_field.name;
+                           });
+    if (it == destination_fields.end() ||
+        !IsImplicitlyConvertible(source_field.value, it->value)) {
       return false;
     }
   }
@@ -294,14 +297,14 @@ static auto ArgumentDeduction(SourceLocation source_loc, TypeEnv deduced,
             << arg_struct.fields().size();
       }
       for (size_t i = 0; i < param_struct.fields().size(); ++i) {
-        if (param_struct.fields()[i].first != arg_struct.fields()[i].first) {
+        if (param_struct.fields()[i].name != arg_struct.fields()[i].name) {
           FATAL_COMPILATION_ERROR(source_loc)
-              << "mismatch in field names, " << param_struct.fields()[i].first
-              << " != " << arg_struct.fields()[i].first;
+              << "mismatch in field names, " << param_struct.fields()[i].name
+              << " != " << arg_struct.fields()[i].name;
         }
         deduced = ArgumentDeduction(source_loc, deduced,
-                                    param_struct.fields()[i].second,
-                                    arg_struct.fields()[i].second);
+                                    param_struct.fields()[i].value,
+                                    arg_struct.fields()[i].value);
       }
       return deduced;
     }
@@ -382,7 +385,7 @@ auto TypeChecker::Substitute(TypeEnv dict, Nonnull<const Value*> type)
       return arena_->New<TupleValue>(elts);
     }
     case Value::Kind::StructType: {
-      VarValues fields;
+      std::vector<NamedValue> fields;
       for (const auto& [name, value] : cast<StructType>(*type).fields()) {
         auto new_type = Substitute(dict, value);
         fields.push_back({name, new_type});
@@ -469,7 +472,7 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e, TypeEnv types,
     }
     case Expression::Kind::StructLiteral: {
       std::vector<FieldInitializer> new_args;
-      VarValues arg_types;
+      std::vector<NamedValue> arg_types;
       auto new_types = types;
       for (auto& arg : cast<StructLiteral>(*e).fields()) {
         auto arg_res = TypeCheckExp(&arg.expression(), new_types, values);
@@ -523,15 +526,15 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e, TypeEnv types,
           const auto& t_class = cast<NominalClassType>(aggregate_type);
           // Search for a field
           for (auto& field : t_class.fields()) {
-            if (access.field() == field.first) {
-              SetStaticType(&access, field.second);
+            if (access.field() == field.name) {
+              SetStaticType(&access, field.value);
               return TCResult(res.types);
             }
           }
           // Search for a method
           for (auto& method : t_class.methods()) {
-            if (access.field() == method.first) {
-              SetStaticType(&access, method.second);
+            if (access.field() == method.name) {
+              SetStaticType(&access, method.value);
               return TCResult(res.types);
             }
           }
@@ -541,17 +544,17 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e, TypeEnv types,
         }
         case Value::Kind::ChoiceType: {
           const auto& choice = cast<ChoiceType>(aggregate_type);
-          for (const auto& vt : choice.alternatives()) {
-            if (access.field() == vt.first) {
-              SetStaticType(&access, arena_->New<FunctionType>(
-                                         std::vector<GenericBinding>(),
-                                         vt.second, &aggregate_type));
-              return TCResult(res.types);
-            }
+          std::optional<Nonnull<const Value*>> parameter_types =
+              choice.FindAlternative(access.field());
+          if (!parameter_types.has_value()) {
+            FATAL_COMPILATION_ERROR(e->source_loc())
+                << "choice " << choice.name() << " does not have a field named "
+                << access.field();
           }
-          FATAL_COMPILATION_ERROR(e->source_loc())
-              << "choice " << choice.name() << " does not have a field named "
-              << access.field();
+          SetStaticType(&access, arena_->New<FunctionType>(
+                                     std::vector<GenericBinding>(),
+                                     *parameter_types, &aggregate_type));
+          return TCResult(res.types);
         }
         default:
           FATAL_COMPILATION_ERROR(e->source_loc())
@@ -802,8 +805,8 @@ auto TypeChecker::TypeCheckPattern(
                         *expected, choice_type);
       }
       std::optional<Nonnull<const Value*>> parameter_types =
-          FindInVarValues(alternative.alternative_name(),
-                          cast<ChoiceType>(*choice_type).alternatives());
+          cast<ChoiceType>(*choice_type)
+              .FindAlternative(alternative.alternative_name());
       if (parameter_types == std::nullopt) {
         FATAL_COMPILATION_ERROR(alternative.source_loc())
             << "'" << alternative.alternative_name()
@@ -1108,8 +1111,8 @@ auto TypeChecker::TypeOfFunDef(TypeEnv types, Env values,
 
 auto TypeChecker::TypeOfClassDef(const ClassDefinition* sd, TypeEnv /*types*/,
                                  Env ct_top) -> Nonnull<const Value*> {
-  VarValues fields;
-  VarValues methods;
+  std::vector<NamedValue> fields;
+  std::vector<NamedValue> methods;
   for (Nonnull<const Member*> m : sd->members()) {
     switch (m->kind()) {
       case Member::Kind::FieldMember: {
@@ -1124,7 +1127,7 @@ auto TypeChecker::TypeOfClassDef(const ClassDefinition* sd, TypeEnv /*types*/,
               << "Struct members must have explicit types";
         }
         auto type = interpreter_.InterpExp(ct_top, &binding_type->expression());
-        fields.push_back(std::make_pair(*binding.name(), type));
+        fields.push_back({.name = *binding.name(), .value = type});
         break;
       }
     }
@@ -1210,10 +1213,10 @@ void TypeChecker::TopLevel(Nonnull<Declaration*> d, TypeCheckContext* tops) {
 
     case Declaration::Kind::ChoiceDeclaration: {
       const auto& choice = cast<ChoiceDeclaration>(*d);
-      VarValues alts;
+      std::vector<NamedValue> alts;
       for (const auto& alternative : choice.alternatives()) {
         auto t = interpreter_.InterpExp(tops->values, &alternative.signature());
-        alts.push_back(std::make_pair(alternative.name(), t));
+        alts.push_back({.name = alternative.name(), .value = t});
       }
       auto ct = arena_->New<ChoiceType>(choice.name(), std::move(alts));
       Address a = interpreter_.AllocateValue(ct);

+ 30 - 61
executable_semantics/interpreter/value.cpp

@@ -17,36 +17,9 @@ namespace Carbon {
 
 using llvm::cast;
 
-auto FindInVarValues(const std::string& field, const VarValues& inits)
-    -> std::optional<Nonnull<const Value*>> {
-  for (auto& i : inits) {
-    if (i.first == field) {
-      return i.second;
-    }
-  }
-  return std::nullopt;
-}
-
-auto FieldsEqual(const VarValues& ts1, const VarValues& ts2) -> bool {
-  if (ts1.size() == ts2.size()) {
-    for (auto& iter1 : ts1) {
-      auto t2 = FindInVarValues(iter1.first, ts2);
-      if (!t2) {
-        return false;
-      }
-      if (!TypeEqual(iter1.second, *t2)) {
-        return false;
-      }
-    }
-    return true;
-  } else {
-    return false;
-  }
-}
-
 auto StructValue::FindField(const std::string& name) const
     -> std::optional<Nonnull<const Value*>> {
-  for (const StructElement& element : elements_) {
+  for (const NamedValue& element : elements_) {
     if (element.name == name) {
       return element.value;
     }
@@ -78,7 +51,7 @@ auto GetMember(Nonnull<Arena*> arena, Nonnull<const Value*> v,
     }
     case Value::Kind::ChoiceType: {
       const auto& choice = cast<ChoiceType>(*v);
-      if (!FindInVarValues(f, choice.alternatives())) {
+      if (!choice.FindAlternative(f)) {
         FATAL_RUNTIME_ERROR(source_loc)
             << "alternative " << f << " not in " << *v;
       }
@@ -112,10 +85,9 @@ auto SetFieldImpl(Nonnull<Arena*> arena, Nonnull<const Value*> value,
   }
   switch (value->kind()) {
     case Value::Kind::StructValue: {
-      std::vector<StructElement> elements =
-          cast<StructValue>(*value).elements();
+      std::vector<NamedValue> elements = cast<StructValue>(*value).elements();
       auto it = std::find_if(elements.begin(), elements.end(),
-                             [path_begin](const StructElement& element) {
+                             [path_begin](const NamedValue& element) {
                                return element.name == *path_begin;
                              });
       if (it == elements.end()) {
@@ -185,7 +157,7 @@ void Value::Print(llvm::raw_ostream& out) const {
       const auto& struct_val = cast<StructValue>(*this);
       out << "{";
       llvm::ListSeparator sep;
-      for (const StructElement& element : struct_val.elements()) {
+      for (const NamedValue& element : struct_val.elements()) {
         out << sep << "." << element.name << " = " << *element.value;
       }
       out << "}";
@@ -313,9 +285,8 @@ auto TypeEqual(Nonnull<const Value*> t1, Nonnull<const Value*> t2) -> bool {
         return false;
       }
       for (size_t i = 0; i < struct1.fields().size(); ++i) {
-        if (struct1.fields()[i].first != struct2.fields()[i].first ||
-            !TypeEqual(struct1.fields()[i].second,
-                       struct2.fields()[i].second)) {
+        if (struct1.fields()[i].name != struct2.fields()[i].name ||
+            !TypeEqual(struct1.fields()[i].value, struct2.fields()[i].value)) {
           return false;
         }
       }
@@ -354,28 +325,6 @@ auto TypeEqual(Nonnull<const Value*> t1, Nonnull<const Value*> t2) -> bool {
   }
 }
 
-// Returns true if all the fields of the two tuples contain equal values
-// and returns false otherwise.
-static auto FieldsValueEqual(const std::vector<StructElement>& ts1,
-                             const std::vector<StructElement>& ts2,
-                             SourceLocation source_loc) -> bool {
-  if (ts1.size() != ts2.size()) {
-    return false;
-  }
-  for (const StructElement& element : ts1) {
-    auto iter = std::find_if(
-        ts2.begin(), ts2.end(),
-        [&](const StructElement& e2) { return e2.name == element.name; });
-    if (iter == ts2.end()) {
-      return false;
-    }
-    if (!ValueEqual(element.value, iter->value, source_loc)) {
-      return false;
-    }
-  }
-  return true;
-}
-
 // Returns true if the two values are equal and returns false otherwise.
 //
 // This function implements the `==` operator of Carbon.
@@ -414,9 +363,19 @@ auto ValueEqual(Nonnull<const Value*> v1, Nonnull<const Value*> v2,
       }
       return true;
     }
-    case Value::Kind::StructValue:
-      return FieldsValueEqual(cast<StructValue>(*v1).elements(),
-                              cast<StructValue>(*v2).elements(), source_loc);
+    case Value::Kind::StructValue: {
+      const auto& struct_v1 = cast<StructValue>(*v1);
+      const auto& struct_v2 = cast<StructValue>(*v2);
+      CHECK(struct_v1.elements().size() == struct_v2.elements().size());
+      for (size_t i = 0; i < struct_v1.elements().size(); ++i) {
+        CHECK(struct_v1.elements()[i].name == struct_v2.elements()[i].name);
+        if (!ValueEqual(struct_v1.elements()[i].value,
+                        struct_v2.elements()[i].value, source_loc)) {
+          return false;
+        }
+      }
+      return true;
+    }
     case Value::Kind::StringValue:
       return cast<StringValue>(*v1).value() == cast<StringValue>(*v2).value();
     case Value::Kind::IntType:
@@ -441,4 +400,14 @@ auto ValueEqual(Nonnull<const Value*> v1, Nonnull<const Value*> v2,
   }
 }
 
+auto ChoiceType::FindAlternative(std::string_view name) const
+    -> std::optional<Nonnull<const Value*>> {
+  for (const NamedValue& alternative : alternatives_) {
+    if (alternative.name == name) {
+      return alternative.value;
+    }
+  }
+  return std::nullopt;
+}
+
 }  // namespace Carbon

+ 22 - 28
executable_semantics/interpreter/value.h

@@ -90,17 +90,8 @@ class Value {
   const Kind kind_;
 };
 
-using VarValues = std::vector<std::pair<std::string, Nonnull<const Value*>>>;
-
-auto FindInVarValues(const std::string& field, const VarValues& inits)
-    -> std::optional<Nonnull<const Value*>>;
-auto FieldsEqual(const VarValues& ts1, const VarValues& ts2) -> bool;
-
-// A StructElement represents the value of a single struct field.
-//
-// TODO(geoffromer): Look for ways to eliminate duplication among StructElement,
-// VarValues::value_type, FieldInitializer, and any similar types.
-struct StructElement {
+// A NamedValue represents a value with a name, such as a single struct field.
+struct NamedValue {
   // The field name.
   std::string name;
 
@@ -181,7 +172,7 @@ class BoolValue : public Value {
 // StructType instances.
 class StructValue : public Value {
  public:
-  explicit StructValue(std::vector<StructElement> elements)
+  explicit StructValue(std::vector<NamedValue> elements)
       : Value(Kind::StructValue), elements_(std::move(elements)) {
     CHECK(!elements_.empty())
         << "`{}` is represented as a StructType, not a StructValue.";
@@ -191,9 +182,7 @@ class StructValue : public Value {
     return value->kind() == Kind::StructValue;
   }
 
-  auto elements() const -> const std::vector<StructElement>& {
-    return elements_;
-  }
+  auto elements() const -> llvm::ArrayRef<NamedValue> { return elements_; }
 
   // Returns the value of the field named `name` in this struct, or
   // nullopt if there is no such field.
@@ -201,7 +190,7 @@ class StructValue : public Value {
       -> std::optional<Nonnull<const Value*>>;
 
  private:
-  std::vector<StructElement> elements_;
+  std::vector<NamedValue> elements_;
 };
 
 // A value of a nominal class type.
@@ -400,25 +389,26 @@ class AutoType : public Value {
 // for `{}`, which is a struct value in addition to being a struct type.
 class StructType : public Value {
  public:
-  StructType() : StructType(VarValues{}) {}
+  StructType() : StructType(std::vector<NamedValue>{}) {}
 
-  explicit StructType(VarValues fields)
+  explicit StructType(std::vector<NamedValue> fields)
       : Value(Kind::StructType), fields_(std::move(fields)) {}
 
   static auto classof(const Value* value) -> bool {
     return value->kind() == Kind::StructType;
   }
 
-  auto fields() const -> const VarValues& { return fields_; }
+  auto fields() const -> llvm::ArrayRef<NamedValue> { return fields_; }
 
  private:
-  VarValues fields_;
+  std::vector<NamedValue> fields_;
 };
 
 // A class type.
 class NominalClassType : public Value {
  public:
-  NominalClassType(std::string name, VarValues fields, VarValues methods)
+  NominalClassType(std::string name, std::vector<NamedValue> fields,
+                   std::vector<NamedValue> methods)
       : Value(Kind::NominalClassType),
         name_(std::move(name)),
         fields_(std::move(fields)),
@@ -429,19 +419,19 @@ class NominalClassType : public Value {
   }
 
   auto name() const -> const std::string& { return name_; }
-  auto fields() const -> const VarValues& { return fields_; }
-  auto methods() const -> const VarValues& { return methods_; }
+  auto fields() const -> llvm::ArrayRef<NamedValue> { return fields_; }
+  auto methods() const -> llvm::ArrayRef<NamedValue> { return methods_; }
 
  private:
   std::string name_;
-  VarValues fields_;
-  VarValues methods_;
+  std::vector<NamedValue> fields_;
+  std::vector<NamedValue> methods_;
 };
 
 // A choice type.
 class ChoiceType : public Value {
  public:
-  ChoiceType(std::string name, VarValues alternatives)
+  ChoiceType(std::string name, std::vector<NamedValue> alternatives)
       : Value(Kind::ChoiceType),
         name_(std::move(name)),
         alternatives_(std::move(alternatives)) {}
@@ -451,11 +441,15 @@ class ChoiceType : public Value {
   }
 
   auto name() const -> const std::string& { return name_; }
-  auto alternatives() const -> const VarValues& { return alternatives_; }
+
+  // Returns the parameter types of the alternative with the given name,
+  // or nullopt if no such alternative is present.
+  auto FindAlternative(std::string_view name) const
+      -> std::optional<Nonnull<const Value*>>;
 
  private:
   std::string name_;
-  VarValues alternatives_;
+  std::vector<NamedValue> alternatives_;
 };
 
 // A continuation type.