Przeglądaj źródła

Parameterized impl declarations (#1189)

* parameterized impls, first step

* bug fixes, comments, etc.

* added another test case, fix a bug in impl lookup

* simplify tests, removing tuple stuff

* don't create impl bindings for -bound implicit parameters

* remove a redundant 'private'

* CamelCase

* add missing backtick

* add a comment
Jeremy G. Siek 4 lat temu
rodzic
commit
b784458aef

+ 1 - 0
executable_semantics/ast/ast_rtti.txt

@@ -56,3 +56,4 @@ abstract class Expression : AstNode;
   class IfExpression : Expression;
   class UnimplementedExpression : Expression;
   class ArrayTypeLiteral : Expression;
+  class InstantiateImpl : Expression;

+ 21 - 0
executable_semantics/ast/declaration.cpp

@@ -211,6 +211,27 @@ void FunctionDeclaration::PrintDepth(int depth, llvm::raw_ostream& out) const {
   }
 }
 
+auto ImplDeclaration::Create(Nonnull<Arena*> arena, SourceLocation source_loc,
+                             ImplKind kind, Nonnull<Expression*> impl_type,
+                             Nonnull<Expression*> interface,
+                             std::vector<Nonnull<AstNode*>> deduced_params,
+                             std::vector<Nonnull<Declaration*>> members)
+    -> ErrorOr<Nonnull<ImplDeclaration*>> {
+  std::vector<Nonnull<GenericBinding*>> resolved_params;
+  for (Nonnull<AstNode*> param : deduced_params) {
+    switch (param->kind()) {
+      case AstNodeKind::GenericBinding:
+        resolved_params.push_back(&cast<GenericBinding>(*param));
+        break;
+      default:
+        return CompilationError(source_loc)
+               << "illegal AST node in implicit parameter list of impl";
+    }
+  }
+  return arena->New<ImplDeclaration>(source_loc, kind, impl_type, interface,
+                                     resolved_params, members);
+}
+
 void AlternativeSignature::Print(llvm::raw_ostream& out) const {
   out << "alt " << name() << " " << signature();
 }

+ 25 - 0
executable_semantics/ast/declaration.h

@@ -321,14 +321,24 @@ class ImplDeclaration : public Declaration {
  public:
   using ImplementsCarbonValueNode = void;
 
+  static auto Create(Nonnull<Arena*> arena, SourceLocation source_loc,
+                     ImplKind kind, Nonnull<Expression*> impl_type,
+                     Nonnull<Expression*> interface,
+                     std::vector<Nonnull<AstNode*>> deduced_params,
+                     std::vector<Nonnull<Declaration*>> members)
+      -> ErrorOr<Nonnull<ImplDeclaration*>>;
+
+  // Use `Create` instead.
   ImplDeclaration(SourceLocation source_loc, ImplKind kind,
                   Nonnull<Expression*> impl_type,
                   Nonnull<Expression*> interface,
+                  std::vector<Nonnull<GenericBinding*>> deduced_params,
                   std::vector<Nonnull<Declaration*>> members)
       : Declaration(AstNodeKind::ImplDeclaration, source_loc),
         kind_(kind),
         impl_type_(impl_type),
         interface_(interface),
+        deduced_parameters_(std::move(deduced_params)),
         members_(std::move(members)) {}
 
   static auto classof(const AstNode* node) -> bool {
@@ -347,17 +357,32 @@ class ImplDeclaration : public Declaration {
   auto interface_type() const -> Nonnull<const Value*> {
     return *interface_type_;
   }
+  auto deduced_parameters() const
+      -> llvm::ArrayRef<Nonnull<const GenericBinding*>> {
+    return deduced_parameters_;
+  }
+  auto deduced_parameters() -> llvm::ArrayRef<Nonnull<GenericBinding*>> {
+    return deduced_parameters_;
+  }
   auto members() const -> llvm::ArrayRef<Nonnull<Declaration*>> {
     return members_;
   }
   auto value_category() const -> ValueCategory { return ValueCategory::Let; }
+  void set_impl_bindings(llvm::ArrayRef<Nonnull<const ImplBinding*>> imps) {
+    impl_bindings_ = imps;
+  }
+  auto impl_bindings() const -> llvm::ArrayRef<Nonnull<const ImplBinding*>> {
+    return impl_bindings_;
+  }
 
  private:
   ImplKind kind_;
   Nonnull<Expression*> impl_type_;  // TODO: make this optional
   Nonnull<Expression*> interface_;
   std::optional<Nonnull<const Value*>> interface_type_;
+  std::vector<Nonnull<GenericBinding*>> deduced_parameters_;
   std::vector<Nonnull<Declaration*>> members_;
+  std::vector<Nonnull<const ImplBinding*>> impl_bindings_;
 };
 
 // Return the name of a declaration, if it has one.

+ 6 - 0
executable_semantics/ast/expression.cpp

@@ -164,6 +164,11 @@ void Expression::Print(llvm::raw_ostream& out) const {
           << if_expr.then_expression() << " else " << if_expr.else_expression();
       break;
     }
+    case ExpressionKind::InstantiateImpl: {
+      const auto& inst_impl = cast<InstantiateImpl>(*this);
+      out << "instantiate " << *inst_impl.generic_impl();
+      break;
+    }
     case ExpressionKind::UnimplementedExpression: {
       const auto& unimplemented = cast<UnimplementedExpression>(*this);
       out << "UnimplementedExpression<" << unimplemented.label() << ">(";
@@ -237,6 +242,7 @@ void Expression::PrintID(llvm::raw_ostream& out) const {
     case ExpressionKind::UnimplementedExpression:
     case ExpressionKind::FunctionTypeLiteral:
     case ExpressionKind::ArrayTypeLiteral:
+    case ExpressionKind::InstantiateImpl:
       out << "...";
       break;
   }

+ 36 - 9
executable_semantics/ast/expression.h

@@ -364,6 +364,8 @@ class GenericBinding;
 using BindingMap =
     std::map<Nonnull<const GenericBinding*>, Nonnull<const Value*>>;
 
+using ImplExpMap = std::map<Nonnull<const ImplBinding*>, Nonnull<Expression*>>;
+
 class CallExpression : public Expression {
  public:
   explicit CallExpression(SourceLocation source_loc,
@@ -382,18 +384,14 @@ class CallExpression : public Expression {
   auto argument() const -> const Expression& { return *argument_; }
   auto argument() -> Expression& { return *argument_; }
 
-  // Maps each of `function`'s generic parameters to the AST node
-  // that identifies the witness table for the corresponding argument.
+  // Maps each of `function`'s impl bindings to an expression
+  // that constructs a witness table.
   // Should not be called before typechecking, or if `function` is not
   // a generic function.
-  auto impls() const
-      -> const std::map<Nonnull<const ImplBinding*>, ValueNodeView>& {
-    return impls_;
-  }
+  auto impls() const -> const ImplExpMap& { return impls_; }
 
   // Can only be called once, during typechecking.
-  void set_impls(
-      const std::map<Nonnull<const ImplBinding*>, ValueNodeView>& impls) {
+  void set_impls(const ImplExpMap& impls) {
     CHECK(impls_.empty());
     impls_ = impls;
   }
@@ -407,7 +405,7 @@ class CallExpression : public Expression {
  private:
   Nonnull<Expression*> function_;
   Nonnull<Expression*> argument_;
-  std::map<Nonnull<const ImplBinding*>, ValueNodeView> impls_;
+  ImplExpMap impls_;
   BindingMap deduced_args_;
 };
 
@@ -538,6 +536,35 @@ class IfExpression : public Expression {
   Nonnull<Expression*> else_expression_;
 };
 
+// Instantiate a generic impl.
+class InstantiateImpl : public Expression {
+ public:
+  using ImplementsCarbonValueNode = void;
+
+  explicit InstantiateImpl(SourceLocation source_loc,
+                           Nonnull<Expression*> generic_impl,
+                           const BindingMap& type_args, const ImplExpMap& impls)
+      : Expression(AstNodeKind::InstantiateImpl, source_loc),
+        generic_impl_(generic_impl),
+        type_args_(type_args),
+        impls_(impls) {}
+
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromInstantiateImpl(node->kind());
+  }
+  auto generic_impl() const -> Nonnull<Expression*> { return generic_impl_; }
+  auto type_args() const -> const BindingMap& { return type_args_; }
+
+  // Maps each of the impl bindings to an expression that constructs
+  // the witness table for that impl.
+  auto impls() const -> const ImplExpMap& { return impls_; }
+
+ private:
+  Nonnull<Expression*> generic_impl_;
+  BindingMap type_args_;
+  ImplExpMap impls_;
+};
+
 // An expression whose semantics have not been implemented. This can be used
 // as a placeholder during development, in order to implement and test parsing
 // of a new expression syntax without having to implement its semantics.

+ 7 - 8
executable_semantics/ast/impl_binding.h

@@ -19,14 +19,13 @@ class Value;
 class Expression;
 class ImplBinding;
 
-// The run-time counterpart of a `GenericBinding`.
-//
-// Once a generic binding has been declared, it can be used
-// in two different ways: as a compile-time constant with a
-// symbolic value (such as a `VariableType`), or as a run-time
-// variable with a concrete value that is stored on the stack.
-// An `ImplBinding` is used in contexts where the second
-// interpretation is intended.
+// `ImplBinding` plays the role of the parameter for passing witness
+// tables to a generic. However, unlike regular parameters
+// (`BindingPattern`) there is no explicit syntax that corresponds to
+// an `ImplBinding`, so they are not created during parsing. Instances
+// of `ImplBinding` are created during type checking, when processing
+// a type parameter (a `GenericBinding`), or an `is` requirement in
+// a `where` clause.
 class ImplBinding : public AstNode {
  public:
   using ImplementsCarbonValueNode = void;

+ 4 - 0
executable_semantics/fuzzing/ast_to_proto.cpp

@@ -80,6 +80,10 @@ static auto ExpressionToProto(const Expression& expression)
     -> Fuzzing::Expression {
   Fuzzing::Expression expression_proto;
   switch (expression.kind()) {
+    case ExpressionKind::InstantiateImpl: {
+      // UNDER CONSTRUCTION
+      break;
+    }
     case ExpressionKind::CallExpression: {
       const auto& call = cast<CallExpression>(expression);
       auto* call_proto = expression_proto.mutable_call();

+ 58 - 36
executable_semantics/interpreter/impl_scope.cpp

@@ -4,7 +4,7 @@
 
 #include "executable_semantics/interpreter/impl_scope.h"
 
-#include "executable_semantics/common/error_builders.h"
+#include "executable_semantics/interpreter/type_checker.h"
 #include "executable_semantics/interpreter/value.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Casting.h"
@@ -14,8 +14,20 @@ using llvm::cast;
 namespace Carbon {
 
 void ImplScope::Add(Nonnull<const Value*> iface, Nonnull<const Value*> type,
-                    ValueNodeView impl) {
-  impls_.push_back({.interface = iface, .type = type, .impl = impl});
+                    Nonnull<Expression*> impl) {
+  Add(iface, {}, type, {}, impl);
+}
+
+void ImplScope::Add(Nonnull<const Value*> iface,
+                    llvm::ArrayRef<Nonnull<const GenericBinding*>> deduced,
+                    Nonnull<const Value*> type,
+                    llvm::ArrayRef<Nonnull<const ImplBinding*>> impl_bindings,
+                    Nonnull<Expression*> impl) {
+  impls_.push_back({.interface = iface,
+                    .deduced = deduced,
+                    .type = type,
+                    .impl_bindings = impl_bindings,
+                    .impl = impl});
 }
 
 void ImplScope::AddParent(Nonnull<const ImplScope*> parent) {
@@ -23,11 +35,12 @@ void ImplScope::AddParent(Nonnull<const ImplScope*> parent) {
 }
 
 auto ImplScope::Resolve(Nonnull<const Value*> iface_type,
-                        Nonnull<const Value*> type,
-                        SourceLocation source_loc) const
-    -> ErrorOr<ValueNodeView> {
-  ASSIGN_OR_RETURN(std::optional<ValueNodeView> result,
-                   TryResolve(iface_type, type, source_loc));
+                        Nonnull<const Value*> type, SourceLocation source_loc,
+                        const TypeChecker& type_checker) const
+    -> ErrorOr<Nonnull<Expression*>> {
+  ASSIGN_OR_RETURN(
+      std::optional<Nonnull<Expression*>> result,
+      TryResolve(iface_type, type, source_loc, *this, type_checker));
   if (!result.has_value()) {
     return CompilationError(source_loc) << "could not find implementation of "
                                         << *iface_type << " for " << *type;
@@ -37,45 +50,54 @@ auto ImplScope::Resolve(Nonnull<const Value*> iface_type,
 
 auto ImplScope::TryResolve(Nonnull<const Value*> iface_type,
                            Nonnull<const Value*> type,
-                           SourceLocation source_loc) const
-    -> ErrorOr<std::optional<ValueNodeView>> {
-  std::optional<ValueNodeView> result =
-      ResolveHere(iface_type, type, source_loc);
-  if (result.has_value()) {
-    return result;
-  }
+                           SourceLocation source_loc,
+                           const ImplScope& original_scope,
+                           const TypeChecker& type_checker) const
+    -> ErrorOr<std::optional<Nonnull<Expression*>>> {
+  ASSIGN_OR_RETURN(
+      std::optional<Nonnull<Expression*>> result,
+      ResolveHere(iface_type, type, source_loc, original_scope, type_checker));
   for (Nonnull<const ImplScope*> parent : parent_scopes_) {
-    ASSIGN_OR_RETURN(auto parent_result,
-                     parent->TryResolve(iface_type, type, source_loc));
-    if (parent_result.has_value() && result.has_value() &&
-        *parent_result != *result) {
-      return CompilationError(source_loc) << "ambiguous implementations of "
-                                          << *iface_type << " for " << *type;
+    ASSIGN_OR_RETURN(std::optional<Nonnull<Expression*>> parent_result,
+                     parent->TryResolve(iface_type, type, source_loc,
+                                        original_scope, type_checker));
+    if (parent_result.has_value()) {
+      if (result.has_value()) {
+        return CompilationError(source_loc) << "ambiguous implementations of "
+                                            << *iface_type << " for " << *type;
+      } else {
+        result = *parent_result;
+      }
     }
-    result = parent_result;
   }
   return result;
 }
 
 auto ImplScope::ResolveHere(Nonnull<const Value*> iface_type,
                             Nonnull<const Value*> impl_type,
-                            SourceLocation /*source_loc*/) const
-    -> std::optional<ValueNodeView> {
-  switch (iface_type->kind()) {
-    case Value::Kind::InterfaceType: {
-      const auto& iface = cast<InterfaceType>(*iface_type);
-      for (const Impl& impl : impls_) {
-        if (TypeEqual(&iface, impl.interface) &&
-            TypeEqual(impl_type, impl.type)) {
-          return impl.impl;
-        }
+                            SourceLocation source_loc,
+                            const ImplScope& original_scope,
+                            const TypeChecker& type_checker) const
+    -> ErrorOr<std::optional<Nonnull<Expression*>>> {
+  if (iface_type->kind() != Value::Kind::InterfaceType) {
+    FATAL() << "expected an interface, not " << *iface_type;
+  }
+  const auto& iface = cast<InterfaceType>(*iface_type);
+  std::optional<Nonnull<Expression*>> result = std::nullopt;
+  for (const Impl& impl : impls_) {
+    std::optional<Nonnull<Expression*>> m = type_checker.MatchImpl(
+        iface, impl_type, impl, original_scope, source_loc);
+    if (m.has_value()) {
+      if (result.has_value()) {
+        return CompilationError(source_loc)
+               << "ambiguous implementations of " << *iface_type << " for "
+               << *impl_type;
+      } else {
+        result = *m;
       }
-      return std::nullopt;
     }
-    default:
-      FATAL() << "expected an interface, not " << *iface_type;
-      break;
   }
+  return result;
 }
 
 // TODO: Add indentation when printing the parents.

+ 44 - 16
executable_semantics/interpreter/impl_scope.h

@@ -10,6 +10,8 @@
 namespace Carbon {
 
 class Value;
+class TypeChecker;
+class InterfaceType;
 
 // The `ImplScope` class is responsible for mapping a type and
 // interface to the location of the witness table for the `impl` for
@@ -40,7 +42,14 @@ class ImplScope {
  public:
   // Associates `iface` and `type` with the `impl` in this scope.
   void Add(Nonnull<const Value*> iface, Nonnull<const Value*> type,
-           ValueNodeView impl);
+           Nonnull<Expression*> impl);
+  // For a parameterized impl, associates `iface` and `type`
+  // with the `impl` in this scope.
+  void Add(Nonnull<const Value*> iface,
+           llvm::ArrayRef<Nonnull<const GenericBinding*>> deduced,
+           Nonnull<const Value*> type,
+           llvm::ArrayRef<Nonnull<const ImplBinding*>> impl_bindings,
+           Nonnull<Expression*> impl);
 
   // Make `parent` a parent of this scope.
   // REQUIRES: `parent` is not already a parent of this scope.
@@ -50,32 +59,51 @@ class ImplScope {
   // the ancestor graph of this scope, or reports a compilation error
   // at `source_loc` there isn't exactly one matching impl.
   auto Resolve(Nonnull<const Value*> iface, Nonnull<const Value*> type,
-               SourceLocation source_loc) const -> ErrorOr<ValueNodeView>;
+               SourceLocation source_loc, const TypeChecker& type_checker) const
+      -> ErrorOr<Nonnull<Expression*>>;
 
   void Print(llvm::raw_ostream& out) const;
 
- private:
-  auto TryResolve(Nonnull<const Value*> iface_type, Nonnull<const Value*> type,
-                  SourceLocation source_loc) const
-      -> ErrorOr<std::optional<ValueNodeView>>;
-  auto ResolveHere(Nonnull<const Value*> iface_type,
-                   Nonnull<const Value*> impl_type,
-                   SourceLocation source_loc) const
-      -> std::optional<ValueNodeView>;
-
   // The `Impl` struct is a key-value pair where the key is the
   // combination of a type and an interface, e.g., `List` and `Container`,
   // and the value is the result of statically resolving to the `impl`
-  // for `List` as `Container`, which is an `ValueNodeView`. The generality
-  // of `ValueNodeView` is needed (not just `ImplDeclaration`) because
-  // inside a generic, we need to map, e.g., from `T` and `Container` to the
-  // witness table that is passed into the generic.
+  // for `List` as `Container`, which is an `Expression` that produces
+  // the witness for that `impl`.
+  // When the `impl` is parameterized, `deduced` and `impl_bindings`
+  // are non-empty. The former contains the type parameters and the
+  // later are impl bindings, that is, parameters for witnesses.
   struct Impl {
     Nonnull<const Value*> interface;
+    std::vector<Nonnull<const GenericBinding*>> deduced;
     Nonnull<const Value*> type;
-    ValueNodeView impl;
+    std::vector<Nonnull<const ImplBinding*>> impl_bindings;
+    Nonnull<Expression*> impl;
   };
 
+ private:
+  // Returns the associated impl for the given `iface` and `type` in
+  // the ancestor graph of this scope, returns std::nullopt if there
+  // is none, or reports a compilation error is there is not a most
+  // specific impl for the given `iface` and `type`.
+  // Use `original_scope` to satisfy requirements of any generic impl
+  // that matches `iface` and `type`.
+  auto TryResolve(Nonnull<const Value*> iface, Nonnull<const Value*> type,
+                  SourceLocation source_loc, const ImplScope& original_scope,
+                  const TypeChecker& type_checker) const
+      -> ErrorOr<std::optional<Nonnull<Expression*>>>;
+
+  // Returns the associated impl for the given `iface` and `type` in
+  // this scope, returns std::nullopt if there is none, or reports
+  // a compilation error is there is not a most specific impl for the
+  // given `iface` and `type`.
+  // Use `original_scope` to satisfy requirements of any generic impl
+  // that matches `iface` and `type`.
+  auto ResolveHere(Nonnull<const Value*> iface_type,
+                   Nonnull<const Value*> impl_type, SourceLocation source_loc,
+                   const ImplScope& original_scope,
+                   const TypeChecker& type_checker) const
+      -> ErrorOr<std::optional<Nonnull<Expression*>>>;
+
   std::vector<Impl> impls_;
   std::vector<Nonnull<const ImplScope*>> parent_scopes_;
 };

+ 196 - 178
executable_semantics/interpreter/interpreter.cpp

@@ -93,6 +93,15 @@ class Interpreter {
                SourceLocation source_loc) const
       -> ErrorOr<Nonnull<const Value*>>;
 
+  // Evaluate an impl expression to produce a witness, or signal an
+  // error.
+  //
+  // An impl expression is either
+  // 1) an IdentifierExpression whose value_node is an impl declaration, or
+  // 2) an InstantiateImpl expression.
+  auto EvalImplExp(Nonnull<const Expression*> exp) const
+      -> ErrorOr<Nonnull<const Witness*>>;
+
   // Instantiate a type by replacing all type variables that occur inside the
   // type by the current values of those variables.
   //
@@ -104,6 +113,12 @@ class Interpreter {
                        SourceLocation source_loc) const
       -> ErrorOr<Nonnull<const Value*>>;
 
+  // Call the function `fun` with the given `arg` and the `witnesses`
+  // for the function's impl bindings.
+  auto CallFunction(const CallExpression& call, Nonnull<const Value*> fun,
+                    Nonnull<const Value*> arg, const ImplWitnessMap& witnesses)
+      -> ErrorOr<Success>;
+
   void PrintState(llvm::raw_ostream& out);
 
   Phase phase() const { return phase_; }
@@ -365,23 +380,49 @@ auto Interpreter::StepLvalue() -> ErrorOr<Success> {
     case ExpressionKind::IntrinsicExpression:
     case ExpressionKind::IfExpression:
     case ExpressionKind::ArrayTypeLiteral:
+    case ExpressionKind::InstantiateImpl:
       FATAL() << "Can't treat expression as lvalue: " << exp;
     case ExpressionKind::UnimplementedExpression:
       FATAL() << "Unimplemented: " << exp;
   }
 }
 
+auto Interpreter::EvalImplExp(Nonnull<const Expression*> exp) const
+    -> ErrorOr<Nonnull<const Witness*>> {
+  switch (exp->kind()) {
+    case ExpressionKind::InstantiateImpl: {
+      const InstantiateImpl& inst_impl = cast<InstantiateImpl>(*exp);
+      ASSIGN_OR_RETURN(Nonnull<const Witness*> gen_impl,
+                       EvalImplExp(inst_impl.generic_impl()));
+      ImplWitnessMap witnesses;
+      for (auto& [bind, impl_exp] : inst_impl.impls()) {
+        ASSIGN_OR_RETURN(witnesses[bind], EvalImplExp(impl_exp));
+      }
+      return arena_->New<Witness>(&gen_impl->declaration(),
+                                  inst_impl.type_args(), witnesses);
+    }
+    case ExpressionKind::IdentifierExpression: {
+      const auto& ident = cast<IdentifierExpression>(*exp);
+      ASSIGN_OR_RETURN(
+          Nonnull<const Value*> value,
+          todo_.ValueOfNode(ident.value_node(), ident.source_loc()));
+      if (const auto* lvalue = dyn_cast<LValue>(value)) {
+        ASSIGN_OR_RETURN(value,
+                         heap_.Read(lvalue->address(), exp->source_loc()));
+      }
+      return cast<Witness>(value);
+    }
+    default: {
+      FATAL() << "EvalImplExp, unexpected expression: " << *exp;
+    }
+  }
+}
+
 auto Interpreter::InstantiateType(Nonnull<const Value*> type,
                                   SourceLocation source_loc) const
     -> ErrorOr<Nonnull<const Value*>> {
-  if (trace_stream_) {
-    **trace_stream_ << "instantiating: " << *type << "\n";
-  }
   switch (type->kind()) {
     case Value::Kind::VariableType: {
-      if (trace_stream_) {
-        **trace_stream_ << "case VariableType\n";
-      }
       ASSIGN_OR_RETURN(
           Nonnull<const Value*> value,
           todo_.ValueOfNode(&cast<VariableType>(*type).binding(), source_loc));
@@ -391,46 +432,15 @@ auto Interpreter::InstantiateType(Nonnull<const Value*> type,
       return value;
     }
     case Value::Kind::NominalClassType: {
-      if (trace_stream_) {
-        **trace_stream_ << "case NominalClassType\n";
-      }
       const auto& class_type = cast<NominalClassType>(*type);
       BindingMap inst_type_args;
       for (const auto& [ty_var, ty_arg] : class_type.type_args()) {
         ASSIGN_OR_RETURN(inst_type_args[ty_var],
                          InstantiateType(ty_arg, source_loc));
       }
-      if (trace_stream_) {
-        **trace_stream_ << "finished instantiating ty_arg\n";
-      }
       std::map<Nonnull<const ImplBinding*>, Nonnull<const Witness*>> witnesses;
-      for (const auto& [bind, impl] : class_type.impls()) {
-        ASSIGN_OR_RETURN(Nonnull<const Value*> witness_addr,
-                         todo_.ValueOfNode(impl, source_loc));
-        if (trace_stream_) {
-          **trace_stream_ << "witness_addr: " << *witness_addr << "\n";
-        }
-        // If the witness came directly from an `impl` declaration (via
-        // `constant_value`), then it is a `Witness`. If the witness
-        // came from the runtime scope, then the `Witness` got wrapped
-        // in an `LValue` because that's what
-        // `RuntimeScope::Initialize` does.
-        Nonnull<const Witness*> witness;
-        if (llvm::isa<Witness>(witness_addr)) {
-          witness = cast<Witness>(witness_addr);
-        } else if (llvm::isa<LValue>(witness_addr)) {
-          ASSIGN_OR_RETURN(
-              Nonnull<const Value*> witness_value,
-              heap_.Read(llvm::cast<LValue>(witness_addr)->address(),
-                         source_loc));
-          witness = cast<Witness>(witness_value);
-        } else {
-          FATAL() << "expected a witness or LValue of a witness";
-        }
-        witnesses[bind] = witness;
-      }
-      if (trace_stream_) {
-        **trace_stream_ << "finished finding witnesses\n";
+      for (const auto& [bind, impl_exp] : class_type.impls()) {
+        ASSIGN_OR_RETURN(witnesses[bind], EvalImplExp(impl_exp));
       }
       return arena_->New<NominalClassType>(&class_type.declaration(),
                                            inst_type_args, witnesses);
@@ -538,6 +548,104 @@ auto Interpreter::Convert(Nonnull<const Value*> value,
   }
 }
 
+auto Interpreter::CallFunction(const CallExpression& call,
+                               Nonnull<const Value*> fun,
+                               Nonnull<const Value*> arg,
+                               const ImplWitnessMap& witnesses)
+    -> ErrorOr<Success> {
+  if (trace_stream_) {
+    **trace_stream_ << "calling function: " << *fun << "\n";
+  }
+  switch (fun->kind()) {
+    case Value::Kind::AlternativeConstructorValue: {
+      const auto& alt = cast<AlternativeConstructorValue>(*fun);
+      return todo_.FinishAction(arena_->New<AlternativeValue>(
+          alt.alt_name(), alt.choice_name(), arg));
+    }
+    case Value::Kind::FunctionValue: {
+      const FunctionValue& fun_val = cast<FunctionValue>(*fun);
+      const FunctionDeclaration& function = fun_val.declaration();
+      ASSIGN_OR_RETURN(Nonnull<const Value*> converted_args,
+                       Convert(arg, &function.param_pattern().static_type(),
+                               call.source_loc()));
+      RuntimeScope function_scope(&heap_);
+      // Bring the class type arguments into scope.
+      for (const auto& [bind, val] : fun_val.type_args()) {
+        function_scope.Initialize(bind, val);
+      }
+      // Bring the deduced type arguments into scope.
+      for (const auto& [bind, val] : call.deduced_args()) {
+        function_scope.Initialize(bind, val);
+      }
+      // Bring the impl witness tables into scope.
+      for (const auto& [impl_bind, witness] : witnesses) {
+        function_scope.Initialize(impl_bind, witness);
+      }
+      for (const auto& [impl_bind, witness] : fun_val.witnesses()) {
+        function_scope.Initialize(impl_bind, witness);
+      }
+      BindingMap generic_args;
+      CHECK(PatternMatch(&function.param_pattern().value(), converted_args,
+                         call.source_loc(), &function_scope, generic_args));
+      CHECK(function.body().has_value())
+          << "Calling a function that's missing a body";
+      return todo_.Spawn(std::make_unique<StatementAction>(*function.body()),
+                         std::move(function_scope));
+    }
+    case Value::Kind::BoundMethodValue: {
+      const auto& m = cast<BoundMethodValue>(*fun);
+      const FunctionDeclaration& method = m.declaration();
+      CHECK(method.is_method());
+      ASSIGN_OR_RETURN(Nonnull<const Value*> converted_args,
+                       Convert(arg, &method.param_pattern().static_type(),
+                               call.source_loc()));
+      RuntimeScope method_scope(&heap_);
+      BindingMap generic_args;
+      CHECK(PatternMatch(&method.me_pattern().value(), m.receiver(),
+                         call.source_loc(), &method_scope, generic_args));
+      CHECK(PatternMatch(&method.param_pattern().value(), converted_args,
+                         call.source_loc(), &method_scope, generic_args));
+      // Bring the class type arguments into scope.
+      for (const auto& [bind, val] : m.type_args()) {
+        method_scope.Initialize(bind, val);
+      }
+
+      // Bring the impl witness tables into scope.
+      for (const auto& [impl_bind, witness] : m.witnesses()) {
+        method_scope.Initialize(impl_bind, witness);
+      }
+      CHECK(method.body().has_value())
+          << "Calling a method that's missing a body";
+      return todo_.Spawn(std::make_unique<StatementAction>(*method.body()),
+                         std::move(method_scope));
+    }
+    case Value::Kind::NominalClassType: {
+      const NominalClassType& class_type = cast<NominalClassType>(*fun);
+      const ClassDeclaration& class_decl = class_type.declaration();
+      RuntimeScope type_params_scope(&heap_);
+      BindingMap generic_args;
+      if (class_decl.type_params().has_value()) {
+        CHECK(PatternMatch(&(*class_decl.type_params())->value(), arg,
+                           call.source_loc(), &type_params_scope,
+                           generic_args));
+        switch (phase()) {
+          case Phase::RunTime:
+            return todo_.FinishAction(arena_->New<NominalClassType>(
+                &class_type.declaration(), generic_args, witnesses));
+          case Phase::CompileTime:
+            return todo_.FinishAction(arena_->New<NominalClassType>(
+                &class_type.declaration(), generic_args, call.impls()));
+        }
+      } else {
+        FATAL() << "instantiation of non-generic class " << class_type;
+      }
+    }
+    default:
+      return RuntimeError(call.source_loc())
+             << "in call, expected a function, not " << *fun;
+  }
+}
+
 auto Interpreter::StepExp() -> ErrorOr<Success> {
   Action& act = todo_.CurrentAction();
   const Expression& exp = cast<ExpressionAction>(act).expression();
@@ -546,6 +654,28 @@ auto Interpreter::StepExp() -> ErrorOr<Success> {
                     << ") --->\n";
   }
   switch (exp.kind()) {
+    case ExpressionKind::InstantiateImpl: {
+      const InstantiateImpl& inst_impl = cast<InstantiateImpl>(exp);
+      if (act.pos() == 0) {
+        return todo_.Spawn(
+            std::make_unique<ExpressionAction>(inst_impl.generic_impl()));
+      } else if (act.pos() - 1 < int(inst_impl.impls().size())) {
+        auto iter = inst_impl.impls().begin();
+        std::advance(iter, act.pos() - 1);
+        return todo_.Spawn(std::make_unique<ExpressionAction>(iter->second));
+      } else {
+        Nonnull<const Witness*> generic_witness =
+            cast<Witness>(act.results()[0]);
+        ImplWitnessMap witnesses;
+        int i = 0;
+        for (const auto& [impl_bind, impl_exp] : inst_impl.impls()) {
+          witnesses[impl_bind] = cast<Witness>(act.results()[i + 1]);
+          ++i;
+        }
+        return todo_.FinishAction(arena_->New<Witness>(
+            &generic_witness->declaration(), inst_impl.type_args(), witnesses));
+      }
+    }
     case ExpressionKind::IndexExpression: {
       if (act.pos() == 0) {
         //    { { e[i] :: C, E, F} :: S, H}
@@ -673,161 +803,49 @@ auto Interpreter::StepExp() -> ErrorOr<Success> {
         return todo_.FinishAction(value);
       }
     }
-    case ExpressionKind::CallExpression:
+    case ExpressionKind::CallExpression: {
+      const CallExpression& call = cast<CallExpression>(exp);
+      // Don't evaluate the impls at compile time?
+      unsigned int num_impls =
+          phase() == Phase::CompileTime ? 0 : call.impls().size();
       if (act.pos() == 0) {
         //    { {e1(e2) :: C, E, F} :: S, H}
         // -> { {e1 :: [](e2) :: C, E, F} :: S, H}
-        return todo_.Spawn(std::make_unique<ExpressionAction>(
-            &cast<CallExpression>(exp).function()));
+        return todo_.Spawn(
+            std::make_unique<ExpressionAction>(&call.function()));
       } else if (act.pos() == 1) {
         //    { { v :: [](e) :: C, E, F} :: S, H}
         // -> { { e :: v([]) :: C, E, F} :: S, H}
-        return todo_.Spawn(std::make_unique<ExpressionAction>(
-            &cast<CallExpression>(exp).argument()));
-      } else if (act.pos() == 2) {
+        return todo_.Spawn(
+            std::make_unique<ExpressionAction>(&call.argument()));
+      } else if (num_impls > 0 && act.pos() < 2 + int(num_impls)) {
+        auto iter = call.impls().begin();
+        std::advance(iter, act.pos() - 2);
+        return todo_.Spawn(std::make_unique<ExpressionAction>(iter->second));
+      } else if (act.pos() == 2 + int(num_impls)) {
         //    { { v2 :: v1([]) :: C, E, F} :: S, H}
         // -> { {C',E',F'} :: {C, E, F} :: S, H}
-        switch (act.results()[0]->kind()) {
-          case Value::Kind::AlternativeConstructorValue: {
-            const auto& alt =
-                cast<AlternativeConstructorValue>(*act.results()[0]);
-            return todo_.FinishAction(arena_->New<AlternativeValue>(
-                alt.alt_name(), alt.choice_name(), act.results()[1]));
+        ImplWitnessMap witnesses;
+        if (num_impls > 0) {
+          int i = 2;
+          for (const auto& [impl_bind, impl_exp] : call.impls()) {
+            witnesses[impl_bind] = cast<Witness>(act.results()[i]);
+            ++i;
           }
-          case Value::Kind::FunctionValue: {
-            const FunctionValue& fun_val =
-                cast<FunctionValue>(*act.results()[0]);
-            const FunctionDeclaration& function = fun_val.declaration();
-            if (trace_stream_) {
-              **trace_stream_ << "*** call function " << function.name()
-                              << "\n";
-            }
-            ASSIGN_OR_RETURN(Nonnull<const Value*> converted_args,
-                             Convert(act.results()[1],
-                                     &function.param_pattern().static_type(),
-                                     exp.source_loc()));
-            RuntimeScope function_scope(&heap_);
-            // Bring the class type arguments into scope.
-            for (const auto& [bind, val] : fun_val.type_args()) {
-              function_scope.Initialize(bind, val);
-            }
-            // Bring the deduced type arguments into scope.
-            for (const auto& [bind, val] :
-                 cast<CallExpression>(exp).deduced_args()) {
-              function_scope.Initialize(bind, val);
-            }
-
-            // Bring the impl witness tables into scope.
-            for (const auto& [impl_bind, impl_node] :
-                 cast<CallExpression>(exp).impls()) {
-              ASSIGN_OR_RETURN(Nonnull<const Value*> witness,
-                               todo_.ValueOfNode(impl_node, exp.source_loc()));
-              if (witness->kind() == Value::Kind::LValue) {
-                const auto& lval = cast<LValue>(*witness);
-                ASSIGN_OR_RETURN(witness,
-                                 heap_.Read(lval.address(), exp.source_loc()));
-              }
-              function_scope.Initialize(impl_bind, witness);
-            }
-            for (const auto& [impl_bind, witness] : fun_val.witnesses()) {
-              function_scope.Initialize(impl_bind, witness);
-            }
-            BindingMap generic_args;
-            CHECK(PatternMatch(&function.param_pattern().value(),
-                               converted_args, exp.source_loc(),
-                               &function_scope, generic_args));
-            CHECK(function.body().has_value())
-                << "Calling a function that's missing a body";
-            return todo_.Spawn(
-                std::make_unique<StatementAction>(*function.body()),
-                std::move(function_scope));
-          }
-          case Value::Kind::BoundMethodValue: {
-            const auto& m = cast<BoundMethodValue>(*act.results()[0]);
-            const FunctionDeclaration& method = m.declaration();
-            CHECK(method.is_method());
-            ASSIGN_OR_RETURN(
-                Nonnull<const Value*> converted_args,
-                Convert(act.results()[1], &method.param_pattern().static_type(),
-                        exp.source_loc()));
-            RuntimeScope method_scope(&heap_);
-            BindingMap generic_args;
-            CHECK(PatternMatch(&method.me_pattern().value(), m.receiver(),
-                               exp.source_loc(), &method_scope, generic_args));
-            CHECK(PatternMatch(&method.param_pattern().value(), converted_args,
-                               exp.source_loc(), &method_scope, generic_args));
-            // Bring the class type arguments into scope.
-            for (const auto& [bind, val] : m.type_args()) {
-              method_scope.Initialize(bind, val);
-            }
-
-            // Bring the impl witness tables into scope.
-            for (const auto& [impl_bind, witness] : m.witnesses()) {
-              method_scope.Initialize(impl_bind, witness);
-            }
-            CHECK(method.body().has_value())
-                << "Calling a method that's missing a body";
-            return todo_.Spawn(
-                std::make_unique<StatementAction>(*method.body()),
-                std::move(method_scope));
-          }
-          case Value::Kind::NominalClassType: {
-            const NominalClassType& class_type =
-                cast<NominalClassType>(*act.results()[0]);
-            const ClassDeclaration& class_decl = class_type.declaration();
-            RuntimeScope type_params_scope(&heap_);
-            BindingMap generic_args;
-            if (class_decl.type_params().has_value()) {
-              CHECK(PatternMatch(&(*class_decl.type_params())->value(),
-                                 act.results()[1], exp.source_loc(),
-                                 &type_params_scope, generic_args));
-              switch (phase()) {
-                case Phase::RunTime: {
-                  std::map<Nonnull<const ImplBinding*>, const Witness*>
-                      witnesses;
-                  for (const auto& [impl_bind, impl_node] :
-                       cast<CallExpression>(exp).impls()) {
-                    ASSIGN_OR_RETURN(
-                        Nonnull<const Value*> witness,
-                        todo_.ValueOfNode(impl_node, exp.source_loc()));
-                    if (witness->kind() == Value::Kind::LValue) {
-                      const LValue& lval = cast<LValue>(*witness);
-                      ASSIGN_OR_RETURN(witness, heap_.Read(lval.address(),
-                                                           exp.source_loc()));
-                    }
-                    witnesses[impl_bind] = &cast<Witness>(*witness);
-                  }
-                  Nonnull<NominalClassType*> inst_class =
-                      arena_->New<NominalClassType>(&class_type.declaration(),
-                                                    generic_args, witnesses);
-                  return todo_.FinishAction(inst_class);
-                }
-                case Phase::CompileTime: {
-                  Nonnull<NominalClassType*> inst_class =
-                      arena_->New<NominalClassType>(
-                          &class_type.declaration(), generic_args,
-                          cast<CallExpression>(exp).impls());
-                  return todo_.FinishAction(inst_class);
-                }
-              }
-            } else {
-              FATAL() << "instantiation of non-generic class " << class_type;
-            }
-          }
-          default:
-            return RuntimeError(exp.source_loc())
-                   << "in call, expected a function, not " << *act.results()[0];
         }
-      } else if (act.pos() == 3) {
-        if (act.results().size() < 3) {
+        return CallFunction(call, act.results()[0], act.results()[1],
+                            witnesses);
+      } else if (act.pos() == 3 + int(num_impls)) {
+        if (act.results().size() < 3 + num_impls) {
           // Control fell through without explicit return.
           return todo_.FinishAction(TupleValue::Empty());
         } else {
-          return todo_.FinishAction(act.results()[2]);
+          return todo_.FinishAction(act.results()[2 + int(num_impls)]);
         }
       } else {
-        FATAL() << "in handle_value with Call pos " << act.pos();
+        FATAL() << "in StepExp with Call pos " << act.pos();
       }
+    }
     case ExpressionKind::IntrinsicExpression: {
       const auto& intrinsic = cast<IntrinsicExpression>(exp);
       if (act.pos() == 0) {

+ 10 - 3
executable_semantics/interpreter/resolve_names.cpp

@@ -166,6 +166,7 @@ static auto ResolveNames(Expression& expression,
     case ExpressionKind::StringTypeLiteral:
     case ExpressionKind::TypeTypeLiteral:
       break;
+    case ExpressionKind::InstantiateImpl:  // created after name resolution
     case ExpressionKind::UnimplementedExpression:
       return CompilationError(expression.source_loc()) << "Unimplemented";
   }
@@ -313,13 +314,19 @@ static auto ResolveNames(Declaration& declaration, StaticScope& enclosing_scope)
     }
     case DeclarationKind::ImplDeclaration: {
       auto& impl = cast<ImplDeclaration>(declaration);
+      StaticScope impl_scope;
+      impl_scope.AddParent(&enclosing_scope);
+      for (Nonnull<GenericBinding*> binding : impl.deduced_parameters()) {
+        RETURN_IF_ERROR(ResolveNames(binding->type(), impl_scope));
+        RETURN_IF_ERROR(impl_scope.Add(binding->name(), binding));
+      }
+      RETURN_IF_ERROR(ResolveNames(*impl.impl_type(), impl_scope));
       RETURN_IF_ERROR(ResolveNames(impl.interface(), enclosing_scope));
-      RETURN_IF_ERROR(ResolveNames(*impl.impl_type(), enclosing_scope));
       for (Nonnull<Declaration*> member : impl.members()) {
-        RETURN_IF_ERROR(AddExposedNames(*member, enclosing_scope));
+        RETURN_IF_ERROR(AddExposedNames(*member, impl_scope));
       }
       for (Nonnull<Declaration*> member : impl.members()) {
-        RETURN_IF_ERROR(ResolveNames(*member, enclosing_scope));
+        RETURN_IF_ERROR(ResolveNames(*member, impl_scope));
       }
       break;
     }

+ 214 - 86
executable_semantics/interpreter/type_checker.cpp

@@ -123,7 +123,7 @@ auto TypeChecker::ExpectIsConcreteType(SourceLocation source_loc,
 
 auto TypeChecker::FieldTypesImplicitlyConvertible(
     llvm::ArrayRef<NamedValue> source_fields,
-    llvm::ArrayRef<NamedValue> destination_fields) {
+    llvm::ArrayRef<NamedValue> destination_fields) const {
   if (source_fields.size() != destination_fields.size()) {
     return false;
   }
@@ -140,7 +140,7 @@ auto TypeChecker::FieldTypesImplicitlyConvertible(
   return true;
 }
 
-auto TypeChecker::FieldTypes(const NominalClassType& class_type)
+auto TypeChecker::FieldTypes(const NominalClassType& class_type) const
     -> std::vector<NamedValue> {
   std::vector<NamedValue> field_types;
   for (Nonnull<Declaration*> m : class_type.declaration().members()) {
@@ -160,8 +160,8 @@ auto TypeChecker::FieldTypes(const NominalClassType& class_type)
   return field_types;
 }
 
-auto TypeChecker::IsImplicitlyConvertible(Nonnull<const Value*> source,
-                                          Nonnull<const Value*> destination)
+auto TypeChecker::IsImplicitlyConvertible(
+    Nonnull<const Value*> source, Nonnull<const Value*> destination) const
     -> bool {
   CHECK(IsConcreteType(source));
   CHECK(IsConcreteType(destination));
@@ -226,7 +226,8 @@ auto TypeChecker::IsImplicitlyConvertible(Nonnull<const Value*> source,
 auto TypeChecker::ExpectType(SourceLocation source_loc,
                              const std::string& context,
                              Nonnull<const Value*> expected,
-                             Nonnull<const Value*> actual) -> ErrorOr<Success> {
+                             Nonnull<const Value*> actual) const
+    -> ErrorOr<Success> {
   if (!IsImplicitlyConvertible(actual, expected)) {
     return CompilationError(source_loc)
            << "type error in " << context << ": "
@@ -237,19 +238,26 @@ auto TypeChecker::ExpectType(SourceLocation source_loc,
   }
 }
 
-auto TypeChecker::ArgumentDeduction(SourceLocation source_loc,
-                                    BindingMap& deduced,
-                                    Nonnull<const Value*> param_type,
-                                    Nonnull<const Value*> arg_type)
-    -> ErrorOr<Success> {
+auto TypeChecker::ArgumentDeduction(
+    SourceLocation source_loc,
+    llvm::ArrayRef<Nonnull<const GenericBinding*>> type_params,
+    BindingMap& deduced, Nonnull<const Value*> param_type,
+    Nonnull<const Value*> arg_type) const -> ErrorOr<Success> {
   switch (param_type->kind()) {
     case Value::Kind::VariableType: {
       const auto& var_type = cast<VariableType>(*param_type);
-      auto [it, success] = deduced.insert({&var_type.binding(), arg_type});
-      if (!success) {
-        // TODO: can we allow implicit conversions here?
+      if (std::find(type_params.begin(), type_params.end(),
+                    &var_type.binding()) != type_params.end()) {
+        auto [it, success] = deduced.insert({&var_type.binding(), arg_type});
+        if (!success) {
+          // Variable already has a match.
+          // TODO: can we allow implicit conversions here?
+          RETURN_IF_ERROR(ExpectExactType(source_loc, "argument deduction",
+                                          it->second, arg_type));
+        }
+      } else {
         RETURN_IF_ERROR(ExpectExactType(source_loc, "argument deduction",
-                                        it->second, arg_type));
+                                        param_type, arg_type));
       }
       return Success();
     }
@@ -269,7 +277,7 @@ auto TypeChecker::ArgumentDeduction(SourceLocation source_loc,
                << arg_tup.elements().size();
       }
       for (size_t i = 0; i < param_tup.elements().size(); ++i) {
-        RETURN_IF_ERROR(ArgumentDeduction(source_loc, deduced,
+        RETURN_IF_ERROR(ArgumentDeduction(source_loc, type_params, deduced,
                                           param_tup.elements()[i],
                                           arg_tup.elements()[i]));
       }
@@ -296,7 +304,7 @@ auto TypeChecker::ArgumentDeduction(SourceLocation source_loc,
                  << "mismatch in field names, " << param_struct.fields()[i].name
                  << " != " << arg_struct.fields()[i].name;
         }
-        RETURN_IF_ERROR(ArgumentDeduction(source_loc, deduced,
+        RETURN_IF_ERROR(ArgumentDeduction(source_loc, type_params, deduced,
                                           param_struct.fields()[i].value,
                                           arg_struct.fields()[i].value));
       }
@@ -312,10 +320,12 @@ auto TypeChecker::ArgumentDeduction(SourceLocation source_loc,
       const auto& param_fn = cast<FunctionType>(*param_type);
       const auto& arg_fn = cast<FunctionType>(*arg_type);
       // TODO: handle situation when arg has deduced parameters.
-      RETURN_IF_ERROR(ArgumentDeduction(
-          source_loc, deduced, &param_fn.parameters(), &arg_fn.parameters()));
-      RETURN_IF_ERROR(ArgumentDeduction(
-          source_loc, deduced, &param_fn.return_type(), &arg_fn.return_type()));
+      RETURN_IF_ERROR(ArgumentDeduction(source_loc, type_params, deduced,
+                                        &param_fn.parameters(),
+                                        &arg_fn.parameters()));
+      RETURN_IF_ERROR(ArgumentDeduction(source_loc, type_params, deduced,
+                                        &param_fn.return_type(),
+                                        &arg_fn.return_type()));
       return Success();
     }
     case Value::Kind::PointerType: {
@@ -325,7 +335,7 @@ auto TypeChecker::ArgumentDeduction(SourceLocation source_loc,
                << "expected: " << *param_type << "\n"
                << "actual: " << *arg_type;
       }
-      return ArgumentDeduction(source_loc, deduced,
+      return ArgumentDeduction(source_loc, type_params, deduced,
                                &cast<PointerType>(*param_type).type(),
                                &cast<PointerType>(*arg_type).type());
     }
@@ -341,7 +351,7 @@ auto TypeChecker::ArgumentDeduction(SourceLocation source_loc,
             arg_class_type.declaration().name()) {
           for (const auto& [ty, param_ty] : param_class_type.type_args()) {
             RETURN_IF_ERROR(
-                ArgumentDeduction(source_loc, deduced, param_ty,
+                ArgumentDeduction(source_loc, type_params, deduced, param_ty,
                                   arg_class_type.type_args().at(ty)));
           }
           return Success();
@@ -387,7 +397,7 @@ auto TypeChecker::ArgumentDeduction(SourceLocation source_loc,
 
 auto TypeChecker::Substitute(
     const std::map<Nonnull<const GenericBinding*>, Nonnull<const Value*>>& dict,
-    Nonnull<const Value*> type) -> Nonnull<const Value*> {
+    Nonnull<const Value*> type) const -> Nonnull<const Value*> {
   switch (type->kind()) {
     case Value::Kind::VariableType: {
       auto it = dict.find(&cast<VariableType>(*type).binding());
@@ -470,6 +480,80 @@ auto TypeChecker::Substitute(
   }
 }
 
+auto TypeChecker::MatchImpl(const InterfaceType& iface,
+                            Nonnull<const Value*> impl_type,
+                            const ImplScope::Impl& impl,
+                            const ImplScope& impl_scope,
+                            SourceLocation source_loc) const
+    -> std::optional<Nonnull<Expression*>> {
+  if (trace_stream_) {
+    **trace_stream_ << "MatchImpl: looking for " << *impl_type << " as "
+                    << iface << "\n";
+    **trace_stream_ << "checking [";
+    llvm::ListSeparator sep;
+    for (Nonnull<const GenericBinding*> deduced_param : impl.deduced) {
+      **trace_stream_ << sep << *deduced_param;
+    }
+    **trace_stream_ << "] " << *impl.type << " as " << *impl.interface << "\n";
+  }
+  if (!TypeEqual(&iface, impl.interface)) {
+    return std::nullopt;
+  }
+  if (impl.deduced.empty() && impl.impl_bindings.empty()) {
+    // case: impl is a non-generic impl.
+    if (!TypeEqual(impl_type, impl.type)) {
+      return std::nullopt;
+    }
+    return impl.impl;
+  } else {
+    // case: impl is a generic impl.
+    BindingMap deduced_type_args;
+    ErrorOr<Success> e = ArgumentDeduction(
+        source_loc, impl.deduced, deduced_type_args, impl.type, impl_type);
+    if (trace_stream_) {
+      **trace_stream_ << "match results: {";
+      llvm::ListSeparator sep;
+      for (const auto& [binding, val] : deduced_type_args) {
+        **trace_stream_ << sep << *binding << " = " << *val;
+      }
+      **trace_stream_ << "}\n";
+    }
+    if (!e.ok()) {
+      return std::nullopt;
+    }
+    // Check that all the type parameters were deduced.
+    // Find impls for all the impls bindings.
+    ImplExpMap impls;
+    ErrorOr<Success> m = SatisfyImpls(impl.impl_bindings, impl_scope,
+                                      source_loc, deduced_type_args, impls);
+    if (!m.ok()) {
+      return std::nullopt;
+    }
+    if (trace_stream_) {
+      **trace_stream_ << "matched with " << *impl.type << " as "
+                      << *impl.interface << "\n\n";
+    }
+    return arena_->New<InstantiateImpl>(source_loc, impl.impl,
+                                        deduced_type_args, impls);
+  }
+}
+
+auto TypeChecker::SatisfyImpls(
+    llvm::ArrayRef<Nonnull<const ImplBinding*>> impl_bindings,
+    const ImplScope& impl_scope, SourceLocation source_loc,
+    BindingMap& deduced_type_args, ImplExpMap& impls) const
+    -> ErrorOr<Success> {
+  for (Nonnull<const ImplBinding*> impl_binding : impl_bindings) {
+    ASSIGN_OR_RETURN(
+        Nonnull<Expression*> impl,
+        impl_scope.Resolve(impl_binding->interface(),
+                           deduced_type_args[impl_binding->type_var()],
+                           source_loc, *this));
+    impls.emplace(impl_binding, impl);
+  }
+  return Success();
+}
+
 auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e,
                                const ImplScope& impl_scope)
     -> ErrorOr<Success> {
@@ -480,6 +564,9 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e,
     **trace_stream_ << "\n";
   }
   switch (e->kind()) {
+    case ExpressionKind::InstantiateImpl: {
+      FATAL() << "instantiate impl nodes are generated during type checking";
+    }
     case ExpressionKind::IndexExpression: {
       auto& index = cast<IndexExpression>(*e);
       RETURN_IF_ERROR(TypeCheckExp(&index.aggregate(), impl_scope));
@@ -842,7 +929,7 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e,
           Nonnull<const Value*> return_type = &fun_t.return_type();
           if (!fun_t.deduced().empty()) {
             BindingMap deduced_type_args;
-            RETURN_IF_ERROR(ArgumentDeduction(e->source_loc(),
+            RETURN_IF_ERROR(ArgumentDeduction(e->source_loc(), fun_t.deduced(),
                                               deduced_type_args, parameters,
                                               &call.argument().static_type()));
             call.set_deduced_args(deduced_type_args);
@@ -860,30 +947,11 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e,
             }
             parameters = Substitute(deduced_type_args, parameters);
             return_type = Substitute(deduced_type_args, return_type);
-
-            // Find impls for all the impl bindings of the function
-            std::map<Nonnull<const ImplBinding*>, ValueNodeView> impls;
-            for (Nonnull<const ImplBinding*> impl_binding :
-                 fun_t.impl_bindings()) {
-              switch (impl_binding->interface()->kind()) {
-                case Value::Kind::InterfaceType: {
-                  ASSIGN_OR_RETURN(
-                      ValueNodeView impl,
-                      impl_scope.Resolve(
-                          impl_binding->interface(),
-                          deduced_type_args[impl_binding->type_var()],
-                          e->source_loc()));
-                  impls.emplace(impl_binding, impl);
-                  break;
-                }
-                case Value::Kind::TypeType:
-                  break;
-                default:
-                  return CompilationError(e->source_loc())
-                         << "unexpected type of deduced parameter "
-                         << *impl_binding->interface();
-              }
-            }
+            // Find impls for all the impl bindings of the function.
+            ImplExpMap impls;
+            RETURN_IF_ERROR(SatisfyImpls(fun_t.impl_bindings(), impl_scope,
+                                         e->source_loc(), deduced_type_args,
+                                         impls));
             call.set_impls(impls);
           } else {
             // No deduced parameters. Check that the argument types
@@ -917,17 +985,18 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e,
                    << "attempt to instantiate a non-generic class: " << *e;
           }
           // Find impls for all the impl bindings of the class.
-          std::map<Nonnull<const ImplBinding*>, ValueNodeView> impls;
+          ImplExpMap impls;
           for (const auto& [binding, val] : generic_args) {
             if (binding->impl_binding().has_value()) {
               Nonnull<const ImplBinding*> impl_binding =
                   *binding->impl_binding();
               switch (impl_binding->interface()->kind()) {
                 case Value::Kind::InterfaceType: {
-                  ASSIGN_OR_RETURN(ValueNodeView impl,
-                                   impl_scope.Resolve(impl_binding->interface(),
-                                                      generic_args[binding],
-                                                      call.source_loc()));
+                  ASSIGN_OR_RETURN(
+                      Nonnull<Expression*> impl,
+                      impl_scope.Resolve(impl_binding->interface(),
+                                         generic_args[binding],
+                                         call.source_loc(), *this));
                   impls.emplace(impl_binding, impl);
                   break;
                 }
@@ -1056,9 +1125,10 @@ void TypeChecker::AddPatternImpls(Nonnull<Pattern*> p, ImplScope& impl_scope) {
       auto& binding = cast<GenericBinding>(*p);
       CHECK(binding.impl_binding().has_value());
       Nonnull<const ImplBinding*> impl_binding = *binding.impl_binding();
+      auto impl_id = arena_->New<IdentifierExpression>(p->source_loc(), "impl");
+      impl_id->set_value_node(impl_binding);
       impl_scope.Add(impl_binding->interface(),
-                     *impl_binding->type_var()->symbolic_identity(),
-                     impl_binding);
+                     *impl_binding->type_var()->symbolic_identity(), impl_id);
       return;
     }
     case PatternKind::TuplePattern: {
@@ -1350,10 +1420,10 @@ auto TypeChecker::TypeCheckStmt(Nonnull<Statement*> s,
       return Success();
     }
     case StatementKind::Await: {
-      // nothing to do here
+      // Nothing to do here.
       return Success();
     }
-  }  // switch
+  }
 }
 
 // Returns true if we can statically verify that `match` is exhaustive, meaning
@@ -1435,6 +1505,45 @@ auto TypeChecker::ExpectReturnOnAllPaths(
   }
 }
 
+auto TypeChecker::CreateImplBindings(
+    llvm::ArrayRef<Nonnull<GenericBinding*>> deduced_parameters,
+    SourceLocation source_loc,
+    std::vector<Nonnull<const ImplBinding*>>& impl_bindings)
+    -> ErrorOr<Success> {
+  for (Nonnull<GenericBinding*> deduced : deduced_parameters) {
+    switch (deduced->static_type().kind()) {
+      case Value::Kind::InterfaceType: {
+        Nonnull<ImplBinding*> impl_binding = arena_->New<ImplBinding>(
+            deduced->source_loc(), deduced, &deduced->static_type());
+        deduced->set_impl_binding(impl_binding);
+        impl_binding->set_static_type(&deduced->static_type());
+        impl_bindings.push_back(impl_binding);
+        break;
+      }
+      case Value::Kind::TypeType:
+        // No `impl` binding needed for type parameter with bound `Type`.
+        break;
+      default:
+        return CompilationError(source_loc)
+               << "unexpected type of deduced parameter "
+               << deduced->static_type();
+    }
+  }
+  return Success();
+}
+
+void TypeChecker::BringImplsIntoScope(
+    llvm::ArrayRef<Nonnull<const ImplBinding*>> impl_bindings, ImplScope& scope,
+    SourceLocation source_loc) {
+  for (Nonnull<const ImplBinding*> impl_binding : impl_bindings) {
+    CHECK(impl_binding->type_var()->symbolic_identity().has_value());
+    auto impl_id = arena_->New<IdentifierExpression>(source_loc, "impl");
+    impl_id->set_value_node(impl_binding);
+    scope.Add(impl_binding->interface(),
+              *impl_binding->type_var()->symbolic_identity(), impl_id);
+  }
+}
+
 // TODO: Add checking to function definitions to ensure that
 //   all deduced type parameters will be deduced.
 auto TypeChecker::DeclareFunctionDeclaration(Nonnull<FunctionDeclaration*> f,
@@ -1443,7 +1552,7 @@ auto TypeChecker::DeclareFunctionDeclaration(Nonnull<FunctionDeclaration*> f,
   if (trace_stream_) {
     **trace_stream_ << "** declaring function " << f->name() << "\n";
   }
-  // Bring the deduced parameters into scope
+  // Bring the deduced parameters into scope.
   for (Nonnull<GenericBinding*> deduced : f->deduced_parameters()) {
     RETURN_IF_ERROR(TypeCheckExp(&deduced->type(), enclosing_scope));
     deduced->set_symbolic_identity(arena_->New<VariableType>(deduced));
@@ -1451,24 +1560,14 @@ auto TypeChecker::DeclareFunctionDeclaration(Nonnull<FunctionDeclaration*> f,
                      InterpExp(&deduced->type(), arena_, trace_stream_));
     deduced->set_static_type(type_of_type);
   }
-  // Create the impl_bindings
+  // Create the impl_bindings.
   std::vector<Nonnull<const ImplBinding*>> impl_bindings;
-  for (Nonnull<GenericBinding*> deduced : f->deduced_parameters()) {
-    Nonnull<ImplBinding*> impl_binding = arena_->New<ImplBinding>(
-        deduced->source_loc(), deduced, &deduced->static_type());
-    deduced->set_impl_binding(impl_binding);
-    impl_binding->set_static_type(&deduced->static_type());
-    impl_bindings.push_back(impl_binding);
-  }
+  RETURN_IF_ERROR(CreateImplBindings(f->deduced_parameters(), f->source_loc(),
+                                     impl_bindings));
   // Bring the impl bindings into scope.
   ImplScope function_scope;
   function_scope.AddParent(&enclosing_scope);
-  for (Nonnull<const ImplBinding*> impl_binding : impl_bindings) {
-    CHECK(impl_binding->type_var()->symbolic_identity().has_value());
-    function_scope.Add(impl_binding->interface(),
-                       *impl_binding->type_var()->symbolic_identity(),
-                       impl_binding);
-  }
+  BringImplsIntoScope(impl_bindings, function_scope, f->source_loc());
   // Type check the receiver pattern.
   if (f->is_method()) {
     RETURN_IF_ERROR(TypeCheckPattern(&f->me_pattern(), std::nullopt,
@@ -1536,18 +1635,13 @@ auto TypeChecker::TypeCheckFunctionDeclaration(Nonnull<FunctionDeclaration*> f,
     **trace_stream_ << "** checking function " << f->name() << "\n";
   }
   // if f->return_term().is_auto(), the function body was already
-  // type checked in DeclareFunctionDeclaration
+  // type checked in DeclareFunctionDeclaration.
   if (f->body().has_value() && !f->return_term().is_auto()) {
-    // Bring the impl's into scope
+    // Bring the impl's into scope.
     ImplScope function_scope;
     function_scope.AddParent(&impl_scope);
-    for (Nonnull<const ImplBinding*> impl_binding :
-         cast<FunctionType>(f->static_type()).impl_bindings()) {
-      CHECK(impl_binding->type_var()->symbolic_identity().has_value());
-      function_scope.Add(impl_binding->interface(),
-                         *impl_binding->type_var()->symbolic_identity(),
-                         impl_binding);
-    }
+    BringImplsIntoScope(cast<FunctionType>(f->static_type()).impl_bindings(),
+                        function_scope, f->source_loc());
     if (trace_stream_)
       **trace_stream_ << function_scope;
     RETURN_IF_ERROR(TypeCheckStmt(*f->body(), function_scope));
@@ -1671,15 +1765,40 @@ auto TypeChecker::DeclareImplDeclaration(Nonnull<ImplDeclaration*> impl_decl,
   const auto& iface_decl = cast<InterfaceType>(*iface_type).declaration();
   impl_decl->set_interface_type(iface_type);
 
-  RETURN_IF_ERROR(TypeCheckExp(impl_decl->impl_type(), enclosing_scope));
+  // Bring the deduced parameters into scope.
+  for (Nonnull<GenericBinding*> deduced : impl_decl->deduced_parameters()) {
+    RETURN_IF_ERROR(TypeCheckExp(&deduced->type(), enclosing_scope));
+    deduced->set_symbolic_identity(arena_->New<VariableType>(deduced));
+    ASSIGN_OR_RETURN(Nonnull<const Value*> type_of_type,
+                     InterpExp(&deduced->type(), arena_, trace_stream_));
+    deduced->set_static_type(type_of_type);
+  }
+  // Create the impl_bindings.
+  std::vector<Nonnull<const ImplBinding*>> impl_bindings;
+  RETURN_IF_ERROR(CreateImplBindings(impl_decl->deduced_parameters(),
+                                     impl_decl->source_loc(), impl_bindings));
+  impl_decl->set_impl_bindings(impl_bindings);
+
+  // Bring the impl bindings into scope for the impl body.
+  ImplScope impl_scope;
+  impl_scope.AddParent(&enclosing_scope);
+  BringImplsIntoScope(impl_bindings, impl_scope, impl_decl->source_loc());
+  // Check and interpret the impl_type
+  RETURN_IF_ERROR(TypeCheckExp(impl_decl->impl_type(), impl_scope));
   ASSIGN_OR_RETURN(Nonnull<const Value*> impl_type_value,
                    InterpExp(impl_decl->impl_type(), arena_, trace_stream_));
-  enclosing_scope.Add(iface_type, impl_type_value, impl_decl);
+  // Bring this impl into the enclosing scope.
+  auto impl_id =
+      arena_->New<IdentifierExpression>(impl_decl->source_loc(), "impl");
+  impl_id->set_value_node(impl_decl);
+  enclosing_scope.Add(iface_type, impl_decl->deduced_parameters(),
+                      impl_type_value, impl_bindings, impl_id);
 
+  // Declare the impl members.
   for (Nonnull<Declaration*> m : impl_decl->members()) {
-    RETURN_IF_ERROR(DeclareDeclaration(m, enclosing_scope));
+    RETURN_IF_ERROR(DeclareDeclaration(m, impl_scope));
   }
-  // Check that the interface is satisfied by the impl members
+  // Check that the interface is satisfied by the impl members.
   for (Nonnull<Declaration*> m : iface_decl.members()) {
     if (std::optional<std::string> mem_name = GetName(*m);
         mem_name.has_value()) {
@@ -1701,15 +1820,24 @@ auto TypeChecker::DeclareImplDeclaration(Nonnull<ImplDeclaration*> impl_decl,
     }
   }
   impl_decl->set_constant_value(arena_->New<Witness>(impl_decl));
+  if (trace_stream_) {
+    **trace_stream_ << "** finished declaring impl " << *impl_decl->impl_type()
+                    << " as " << impl_decl->interface() << "\n";
+  }
   return Success();
 }
 
 auto TypeChecker::TypeCheckImplDeclaration(Nonnull<ImplDeclaration*> impl_decl,
-                                           const ImplScope& impl_scope)
+                                           const ImplScope& enclosing_scope)
     -> ErrorOr<Success> {
   if (trace_stream_) {
     **trace_stream_ << "checking " << *impl_decl << "\n";
   }
+  // Bring the impl's from the parameters into scope.
+  ImplScope impl_scope;
+  impl_scope.AddParent(&enclosing_scope);
+  BringImplsIntoScope(impl_decl->impl_bindings(), impl_scope,
+                      impl_decl->source_loc());
   for (Nonnull<Declaration*> m : impl_decl->members()) {
     RETURN_IF_ERROR(TypeCheckDeclaration(m, impl_scope));
   }

+ 47 - 10
executable_semantics/interpreter/type_checker.h

@@ -30,17 +30,27 @@ class TypeChecker {
   // processed.
   auto TypeCheck(AST& ast) -> ErrorOr<Success>;
 
- private:
   // Perform type argument deduction, matching the parameter type `param`
   // against the argument type `arg`. Whenever there is an VariableType
   // in the parameter type, it is deduced to be the corresponding type
   // inside the argument type.
   // The `deduced` parameter is an accumulator, that is, it holds the
   // results so-far.
-  auto ArgumentDeduction(SourceLocation source_loc, BindingMap& deduced,
-                         Nonnull<const Value*> param_type,
-                         Nonnull<const Value*> arg_type) -> ErrorOr<Success>;
+  auto ArgumentDeduction(
+      SourceLocation source_loc,
+      llvm::ArrayRef<Nonnull<const GenericBinding*>> type_params,
+      BindingMap& deduced, Nonnull<const Value*> param_type,
+      Nonnull<const Value*> arg_type) const -> ErrorOr<Success>;
+
+  // If `impl` can be an implementation of interface `iface` for the
+  // given `type`, then return an expression that will produce the witness
+  // for this `impl` (at runtime). Otherwise return std::nullopt.
+  auto MatchImpl(const InterfaceType& iface, Nonnull<const Value*> type,
+                 const ImplScope::Impl& impl, const ImplScope& impl_scope,
+                 SourceLocation source_loc) const
+      -> std::optional<Nonnull<Expression*>>;
 
+ private:
   // Traverses the AST rooted at `e`, populating the static_type() of all nodes
   // and ensuring they follow Carbon's typing rules.
   //
@@ -140,7 +150,7 @@ class TypeChecker {
                             Nonnull<const Value*> value) -> ErrorOr<Success>;
 
   // Returns the field names of the class together with their types.
-  auto FieldTypes(const NominalClassType& class_type)
+  auto FieldTypes(const NominalClassType& class_type) const
       -> std::vector<NamedValue>;
 
   // Returns true if source_fields and destination_fields contain the same set
@@ -149,22 +159,49 @@ class TypeChecker {
   // must be types.
   auto FieldTypesImplicitlyConvertible(
       llvm::ArrayRef<NamedValue> source_fields,
-      llvm::ArrayRef<NamedValue> destination_fields);
+      llvm::ArrayRef<NamedValue> destination_fields) const;
 
   // Returns true if *source is implicitly convertible to *destination. *source
   // and *destination must be concrete types.
   auto IsImplicitlyConvertible(Nonnull<const Value*> source,
-                               Nonnull<const Value*> destination) -> bool;
+                               Nonnull<const Value*> destination) const -> bool;
 
   // Check whether `actual` is implicitly convertible to `expected`
   // and halt with a fatal compilation error if it is not.
   auto ExpectType(SourceLocation source_loc, const std::string& context,
-                  Nonnull<const Value*> expected, Nonnull<const Value*> actual)
-      -> ErrorOr<Success>;
+                  Nonnull<const Value*> expected,
+                  Nonnull<const Value*> actual) const -> ErrorOr<Success>;
 
+  // Construct a type that is the same as `type` except that occurrences
+  // of type variables (aka. `GenericBinding`) are replaced by their
+  // corresponding type in `dict`.
   auto Substitute(const std::map<Nonnull<const GenericBinding*>,
                                  Nonnull<const Value*>>& dict,
-                  Nonnull<const Value*> type) -> Nonnull<const Value*>;
+                  Nonnull<const Value*> type) const -> Nonnull<const Value*>;
+
+  // For each deduced type parameter of a generic that has a
+  // non-trivial type (such as an interface), create an impl binding
+  // to serve as the parameter for passing a witness at runtime for
+  // the required impl.
+  auto CreateImplBindings(
+      llvm::ArrayRef<Nonnull<GenericBinding*>> deduced_parameters,
+      SourceLocation source_loc,
+      std::vector<Nonnull<const ImplBinding*>>& impl_bindings)
+      -> ErrorOr<Success>;
+
+  // Add all of the `impl_bindings` into the `scope`.
+  void BringImplsIntoScope(
+      llvm::ArrayRef<Nonnull<const ImplBinding*>> impl_bindings,
+      ImplScope& scope, SourceLocation source_loc);
+
+  // Find impls that satisfy all of the `impl_bindings`, but with the
+  // type variables in the `impl_bindings` replaced by the argument
+  // type in `deduced_type_args`.  The results are placed in the
+  // `impls` map.
+  auto SatisfyImpls(llvm::ArrayRef<Nonnull<const ImplBinding*>> impl_bindings,
+                    const ImplScope& impl_scope, SourceLocation source_loc,
+                    BindingMap& deduced_type_args, ImplExpMap& impls) const
+      -> ErrorOr<Success>;
 
   // Sets value_node.constant_value() to `value`. Can be called multiple
   // times on the same value_node, so long as it is always called with

+ 10 - 7
executable_semantics/interpreter/value.cpp

@@ -43,10 +43,14 @@ static auto GetMember(Nonnull<Arena*> arena, Nonnull<const Value*> v,
             mem_decl.has_value()) {
           const auto& fun_decl = cast<FunctionDeclaration>(**mem_decl);
           if (fun_decl.is_method()) {
-            return arena->New<BoundMethodValue>(&fun_decl, v);
+            return arena->New<BoundMethodValue>(
+                &fun_decl, v, witness->type_args(), witness->witnesses());
           } else {
             // Class function.
-            return *fun_decl.constant_value();
+            auto fun = cast<FunctionValue>(*fun_decl.constant_value());
+            return arena->New<FunctionValue>(&fun->declaration(),
+                                             witness->type_args(),
+                                             witness->witnesses());
           }
         } else {
           return CompilationError(source_loc)
@@ -89,10 +93,9 @@ static auto GetMember(Nonnull<Arena*> arena, Nonnull<const Value*> v,
                                               class_type.witnesses());
         } else {
           // Found a class function
-          Nonnull<const FunctionValue*> fun = arena->New<FunctionValue>(
-              &(*func)->declaration(), class_type.type_args(),
-              class_type.witnesses());
-          return fun;
+          return arena->New<FunctionValue>(&(*func)->declaration(),
+                                           class_type.type_args(),
+                                           class_type.witnesses());
         }
       }
     }
@@ -316,7 +319,7 @@ void Value::Print(llvm::raw_ostream& out) const {
         out << " impls ";
         llvm::ListSeparator sep;
         for (const auto& [impl_bind, impl] : class_type.impls()) {
-          out << sep << impl;
+          out << sep << *impl;
         }
       }
       if (!class_type.witnesses().empty()) {

+ 39 - 29
executable_semantics/interpreter/value.h

@@ -124,6 +124,9 @@ class IntValue : public Value {
   int value_;
 };
 
+using ImplWitnessMap =
+    std::map<Nonnull<const ImplBinding*>, Nonnull<const Witness*>>;
+
 // A function value.
 class FunctionValue : public Value {
  public:
@@ -132,8 +135,7 @@ class FunctionValue : public Value {
 
   explicit FunctionValue(Nonnull<const FunctionDeclaration*> declaration,
                          const BindingMap& type_args,
-                         const std::map<Nonnull<const ImplBinding*>,
-                                        Nonnull<const Witness*>>& wits)
+                         const ImplWitnessMap& wits)
       : Value(Kind::FunctionValue),
         declaration_(declaration),
         type_args_(type_args),
@@ -157,7 +159,7 @@ class FunctionValue : public Value {
  private:
   Nonnull<const FunctionDeclaration*> declaration_;
   BindingMap type_args_;
-  std::map<Nonnull<const ImplBinding*>, Nonnull<const Witness*>> witnesses_;
+  ImplWitnessMap witnesses_;
 };
 
 // A bound method value. It includes the receiver object.
@@ -192,16 +194,13 @@ class BoundMethodValue : public Value {
 
   auto type_args() const -> const BindingMap& { return type_args_; }
 
-  auto witnesses() const
-      -> const std::map<Nonnull<const ImplBinding*>, Nonnull<const Witness*>>& {
-    return witnesses_;
-  }
+  auto witnesses() const -> const ImplWitnessMap& { return witnesses_; }
 
  private:
   Nonnull<const FunctionDeclaration*> declaration_;
   Nonnull<const Value*> receiver_;
   BindingMap type_args_;
-  std::map<Nonnull<const ImplBinding*>, Nonnull<const Witness*>> witnesses_;
+  ImplWitnessMap witnesses_;
 };
 
 // The value of a location in memory.
@@ -343,7 +342,7 @@ class AlternativeValue : public Value {
   Nonnull<const Value*> argument_;
 };
 
-// A function value.
+// A tuple value.
 class TupleValue : public Value {
  public:
   // An empty tuple, also known as the unit type.
@@ -523,9 +522,9 @@ class NominalClassType : public Value {
   // Construct a class type that represents the result of applying the
   // given generic class to the `type_args` and that records the result of the
   // compile-time search for any required impls.
-  explicit NominalClassType(
-      Nonnull<const ClassDeclaration*> declaration, const BindingMap& type_args,
-      const std::map<Nonnull<const ImplBinding*>, ValueNodeView>& impls)
+  explicit NominalClassType(Nonnull<const ClassDeclaration*> declaration,
+                            const BindingMap& type_args,
+                            const ImplExpMap& impls)
       : Value(Kind::NominalClassType),
         declaration_(declaration),
         type_args_(type_args),
@@ -549,23 +548,17 @@ class NominalClassType : public Value {
   auto declaration() const -> const ClassDeclaration& { return *declaration_; }
   auto type_args() const -> const BindingMap& { return type_args_; }
 
-  // Maps each of the class's generic parameters to the AST node that
-  // identifies the witness table for the corresponding argument.
-  // Should not be called on 1) a non-generic class, 2) a generic-class
-  // that is not instantiated, or 3) a fully instantiated runtime type
-  // of a generic class.
-  auto impls() const
-      -> const std::map<Nonnull<const ImplBinding*>, ValueNodeView>& {
-    return impls_;
-  }
+  // Maps each of an instantiated generic class's impl bindings to an
+  // expression that constructs the witness table for the corresponding
+  // argument. Should not be called on 1) a non-generic class, 2) a
+  // generic-class that is not instantiated, or 3) a fully
+  // instantiated runtime type of a generic class.
+  auto impls() const -> const ImplExpMap& { return impls_; }
 
-  // Maps each of the class's generic parameters to the witness table
+  // Maps each of the class's impl bindings to the witness table
   // for the corresponding argument. Should only be called on a fully
   // instantiated runtime type of a generic class.
-  auto witnesses() const
-      -> const std::map<Nonnull<const ImplBinding*>, Nonnull<const Witness*>>& {
-    return witnesses_;
-  }
+  auto witnesses() const -> const ImplWitnessMap& { return witnesses_; }
 
   // Returns the value of the function named `name` in this class, or
   // nullopt if there is no such function.
@@ -575,8 +568,8 @@ class NominalClassType : public Value {
  private:
   Nonnull<const ClassDeclaration*> declaration_;
   BindingMap type_args_;
-  std::map<Nonnull<const ImplBinding*>, ValueNodeView> impls_;
-  std::map<Nonnull<const ImplBinding*>, Nonnull<const Witness*>> witnesses_;
+  ImplExpMap impls_;
+  ImplWitnessMap witnesses_;
 };
 
 // Return the declaration of the member with the given name.
@@ -605,17 +598,34 @@ class InterfaceType : public Value {
 // The witness table for an impl.
 class Witness : public Value {
  public:
+  // Construct a witness for
+  // 1) a non-generic impl, or
+  // 2) a generic impl that has not yet been applied to type arguments.
   explicit Witness(Nonnull<const ImplDeclaration*> declaration)
       : Value(Kind::Witness), declaration_(declaration) {}
 
+  // Construct an instantiated generic impl.
+  explicit Witness(Nonnull<const ImplDeclaration*> declaration,
+                   const BindingMap& type_args, const ImplWitnessMap& wits)
+      : Value(Kind::Witness),
+        declaration_(declaration),
+        type_args_(type_args),
+        witnesses_(wits) {}
+
   static auto classof(const Value* value) -> bool {
     return value->kind() == Kind::Witness;
   }
-
   auto declaration() const -> const ImplDeclaration& { return *declaration_; }
+  auto type_args() const -> const BindingMap& { return type_args_; }
+  // Maps each of the impl's impl bindings to the witness table
+  // for the corresponding argument. Should only be called on a fully
+  // instantiated runtime type of a generic class.
+  auto witnesses() const -> const ImplWitnessMap& { return witnesses_; }
 
  private:
   Nonnull<const ImplDeclaration*> declaration_;
+  BindingMap type_args_;
+  ImplWitnessMap witnesses_;
 };
 
 // A choice type.

+ 2 - 0
executable_semantics/syntax/lexer.lpp

@@ -60,6 +60,7 @@ EXTERNAL             "external"
 FALSE                "false"
 FN                   "fn"
 FN_TYPE              "__Fn"
+FORALL               "forall"
 IF                   "if"
 IMPL                 "impl"
 IMPORT               "import"
@@ -162,6 +163,7 @@ string_literal        \"([^\\\"\n\v\f\r]|\\.)*\"
 {FALSE}               { return SIMPLE_TOKEN(FALSE);               }
 {FN_TYPE}             { return SIMPLE_TOKEN(FN_TYPE);             }
 {FN}                  { return SIMPLE_TOKEN(FN);                  }
+{FORALL}              { return SIMPLE_TOKEN(FORALL);              }
 {IF}                  { return SIMPLE_TOKEN(IF);                  }
 {IMPL}                { return SIMPLE_TOKEN(IMPL);                }
 {IMPORT}              { return SIMPLE_TOKEN(IMPORT);              }

+ 19 - 2
executable_semantics/syntax/parser.ypp

@@ -141,6 +141,7 @@
 %type <Nonnull<Expression*>> expression
 %type <Nonnull<GenericBinding*>> generic_binding
 %type <std::vector<Nonnull<AstNode*>>> deduced_params
+%type <std::vector<Nonnull<AstNode*>>> impl_deduced_params
 %type <std::vector<Nonnull<AstNode*>>> deduced_param_list
 %type <Nonnull<Pattern*>> pattern
 %type <Nonnull<Pattern*>> non_expression_pattern
@@ -198,6 +199,7 @@
   FALSE
   FN
   FN_TYPE
+  FORALL
   IF
   IMPL
   IMPORT
@@ -810,6 +812,12 @@ deduced_params:
 | LEFT_SQUARE_BRACKET deduced_param_list RIGHT_SQUARE_BRACKET
     { $$ = $2; }
 ;
+impl_deduced_params:
+  // Empty
+    { $$ = std::vector<Nonnull<AstNode*>>(); }
+| FORALL LEFT_SQUARE_BRACKET deduced_param_list RIGHT_SQUARE_BRACKET
+    { $$ = $3; }
+;
 receiver:
   // Empty
     { $$ = std::nullopt; }
@@ -908,8 +916,17 @@ declaration:
           arena -> New<GenericBinding>(context.source_loc(), "Self", ty_ty);
       $$ = arena->New<InterfaceDeclaration>(context.source_loc(), $2, self, $4);
     }
-| impl_kind IMPL expression AS expression LEFT_CURLY_BRACE declaration_list RIGHT_CURLY_BRACE
-    { $$ = arena->New<ImplDeclaration>(context.source_loc(), $1, $3, $5, $7); }
+| impl_kind IMPL impl_deduced_params expression AS expression LEFT_CURLY_BRACE declaration_list RIGHT_CURLY_BRACE
+    {
+      ErrorOr<ImplDeclaration*> impl = ImplDeclaration::Create(
+          arena, context.source_loc(), $1, $4, $6, $3, $8);
+      if (impl.ok()) {
+        $$ = *impl;
+      } else {
+        context.RecordSyntaxError(impl.error().message());
+        YYERROR;
+      }
+    }
 ;
 impl_kind:
   // Internal

+ 52 - 0
executable_semantics/testdata/impl/fail_ambiguous_impl.carbon

@@ -0,0 +1,52 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// RUN: %{not} %{executable_semantics} %s 2>&1 | \
+// RUN:   %{FileCheck} --match-full-lines --allow-unused-prefixes=false %s
+// RUN: %{not} %{executable_semantics} --parser_debug --trace_file=- %s 2>&1 | \
+// RUN:   %{FileCheck} --match-full-lines --allow-unused-prefixes %s
+// AUTOUPDATE: %{executable_semantics} %s
+// CHECK: COMPILATION ERROR: {{.*}}/executable_semantics/testdata/impl/fail_ambiguous_impl.carbon:50: ambiguous implementations of interface Vector for class Point
+
+
+package ExecutableSemanticsTest api;
+
+interface Vector {
+  fn Add[me: Self](b: Self) -> Self;
+  fn Scale[me: Self](v: i32) -> Self;
+}
+
+class Point {
+  var x: i32;
+  var y: i32;
+}
+
+external impl Point as Vector {
+  fn Add[me: Point](b: Point) -> Point {
+      return {.x = me.x + b.x, .y = me.y + b.y};
+  }
+  fn Scale[me: Point](v: i32) -> Point {
+      return {.x = me.x * v, .y = me.y * v};
+  }
+}
+
+external impl Point as Vector {
+  fn Add[me: Point](b: Point) -> Point {
+      return {.x = me.x + b.x, .y = me.y + b.y};
+  }
+  fn Scale[me: Point](v: i32) -> Point {
+      return {.x = me.x * v, .y = me.y * v};
+  }
+}
+
+fn AddAndScaleGeneric[T:! Vector](a: T, b: T, s: i32) -> T {
+  return a.Add(b).Scale(s);
+}
+
+fn Main() -> i32 {
+  var a: Point = {.x = 1, .y = 1};
+  var b: Point = {.x = 2, .y = 3};
+  var p: Point = AddAndScaleGeneric(a, b, 5);
+  return p.x - 15;
+}

+ 48 - 0
executable_semantics/testdata/impl/param_impl.carbon

@@ -0,0 +1,48 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// RUN: %{executable_semantics} %s 2>&1 | \
+// RUN:   %{FileCheck} --match-full-lines --allow-unused-prefixes=false %s
+// RUN: %{executable_semantics} --parser_debug --trace_file=- %s 2>&1 | \
+// RUN:   %{FileCheck} --match-full-lines --allow-unused-prefixes %s
+// AUTOUPDATE: %{executable_semantics} %s
+// CHECK: result: 0
+
+package ExecutableSemanticsTest api;
+
+interface Number {
+  fn Zero() -> Self;
+  fn Add[me: Self](other: Self) -> Self;
+}
+
+class Point(T:! Number) {
+  var x: T;
+  var y: T;
+}
+
+external impl i32 as Number {
+  fn Zero() -> i32 { return 0; }
+  fn Add[me: i32](other: i32) -> i32 { return me + other; }
+}
+
+external impl forall [U:! Number] Point(U) as Number {
+  fn Zero() -> Point(U) { return {.x = U.Zero(), .y = U.Zero() }; }
+  fn Add[me: Point(U)](other: Point(U)) -> Point(U) {
+    return {.x = me.x.Add(other.x), .y = me.y.Add(other.y)};
+  }
+}
+
+fn Sum[E:! Number](x: E, y: E) -> E {
+  var total: E = E.Zero();
+  total = total.Add(x);
+  total = total.Add(y);
+  return total;
+}
+
+fn Main() -> i32 {
+  var p: Point(i32) = {.x = 1, .y = 2};
+  var q: Point(i32) = {.x = 4, .y = 3};
+  var r: Point(i32) = Sum(p, q);
+  return r.x - r.y;
+}

+ 52 - 0
executable_semantics/testdata/impl/param_impl2.carbon

@@ -0,0 +1,52 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// RUN: %{executable_semantics} %s 2>&1 | \
+// RUN:   %{FileCheck} --match-full-lines --allow-unused-prefixes=false %s
+// RUN: %{executable_semantics} --parser_debug --trace_file=- %s 2>&1 | \
+// RUN:   %{FileCheck} --match-full-lines --allow-unused-prefixes %s
+// AUTOUPDATE: %{executable_semantics} %s
+// CHECK: result: 0
+
+package ExecutableSemanticsTest api;
+
+interface Number {
+  fn Zero() -> Self;
+  fn Add[me: Self](other: Self) -> Self;
+}
+
+class Point(T:! Number) {
+  var x: T;
+  var y: T;
+}
+
+external impl i32 as Number {
+  fn Zero() -> i32 { return 0; }
+  fn Add[me: i32](other: i32) -> i32 { return me + other; }
+}
+
+external impl forall [U:! Number] Point(U) as Number {
+  fn Zero() -> Point(U) { return {.x = U.Zero(), .y = U.Zero() }; }
+  fn Add[me: Point(U)](other: Point(U)) -> Point(U) {
+    return {.x = me.x.Add(other.x), .y = me.y.Add(other.y)};
+  }
+}
+
+fn Sum[E:! Number](x: E, y: E) -> E {
+  var total: E = E.Zero();
+  total = total.Add(x);
+  total = total.Add(y);
+  return total;
+}
+
+fn SumPoints[E:! Number](p: Point(E), q: Point(E)) -> Point(E) {
+  return Sum(p, q);
+}
+
+fn Main() -> i32 {
+  var p: Point(i32) = {.x = 1, .y = 2};
+  var q: Point(i32) = {.x = 4, .y = 3};
+  var r: Point(i32) = SumPoints(p, q);
+  return r.x - r.y;
+}