Geoff Romer пре 4 година
родитељ
комит
652bee6d99

+ 16 - 0
executable_semantics/ast/expression.h

@@ -19,6 +19,8 @@
 
 namespace Carbon {
 
+class Value;
+
 class Expression {
  public:
   enum class Kind {
@@ -51,6 +53,18 @@ class Expression {
 
   auto source_loc() const -> SourceLocation { return source_loc_; }
 
+  // The static type of this expression. Cannot be called before typechecking.
+  auto static_type() const -> Nonnull<const Value*> { return *static_type_; }
+
+  // Sets the static type of this expression. Can only be called once, during
+  // typechecking.
+  void set_static_type(Nonnull<const Value*> type) { static_type_ = type; }
+
+  // Returns whether the static type has been set. Should only be called
+  // during typechecking: before typechecking it's guaranteed to be false,
+  // and after typechecking it's guaranteed to be true.
+  auto has_static_type() const -> bool { return static_type_.has_value(); }
+
  protected:
   // Constructs an Expression representing syntax at the given line number.
   // `kind` must be the enumerator corresponding to the most-derived type being
@@ -61,6 +75,8 @@ class Expression {
  private:
   const Kind kind_;
   SourceLocation source_loc_;
+
+  std::optional<Nonnull<const Value*>> static_type_;
 };
 
 // Converts paren_contents to an Expression, interpreting the parentheses as

+ 16 - 0
executable_semantics/ast/function_definition.h

@@ -15,6 +15,8 @@
 
 namespace Carbon {
 
+class Value;
+
 // TODO: expand the kinds of things that can be deduced parameters.
 //   For now, only generic parameters are supported.
 struct GenericBinding {
@@ -58,6 +60,18 @@ class FunctionDefinition {
   }
   auto body() -> std::optional<Nonnull<Statement*>> { return body_; }
 
+  // The static type of this function. Cannot be called before typechecking.
+  auto static_type() const -> Nonnull<const Value*> { return *static_type_; }
+
+  // Sets the static type of this expression. Can only be called once, during
+  // typechecking.
+  void set_static_type(Nonnull<const Value*> type) { static_type_ = type; }
+
+  // Returns whether the static type has been set. Should only be called
+  // during typechecking: before typechecking it's guaranteed to be false,
+  // and after typechecking it's guaranteed to be true.
+  auto has_static_type() const -> bool { return static_type_.has_value(); }
+
  private:
   SourceLocation source_loc_;
   std::string name_;
@@ -66,6 +80,8 @@ class FunctionDefinition {
   Nonnull<Pattern*> return_type_;
   bool is_omitted_return_type_;
   std::optional<Nonnull<Statement*>> body_;
+
+  std::optional<Nonnull<const Value*>> static_type_;
 };
 
 }  // namespace Carbon

+ 16 - 0
executable_semantics/ast/pattern.h

@@ -16,6 +16,8 @@
 
 namespace Carbon {
 
+class Value;
+
 // Abstract base class of all AST nodes representing patterns.
 //
 // Pattern and its derived classes support LLVM-style RTTI, including
@@ -46,6 +48,18 @@ class Pattern {
 
   auto source_loc() const -> SourceLocation { return source_loc_; }
 
+  // The static type of this pattern. Cannot be called before typechecking.
+  auto static_type() const -> Nonnull<const Value*> { return *static_type_; }
+
+  // Sets the static type of this expression. Can only be called once, during
+  // typechecking.
+  void set_static_type(Nonnull<const Value*> type) { static_type_ = type; }
+
+  // Returns whether the static type has been set. Should only be called
+  // during typechecking: before typechecking it's guaranteed to be false,
+  // and after typechecking it's guaranteed to be true.
+  auto has_static_type() const -> bool { return static_type_.has_value(); }
+
  protected:
   // Constructs a Pattern representing syntax at the given line number.
   // `kind` must be the enumerator corresponding to the most-derived type being
@@ -56,6 +70,8 @@ class Pattern {
  private:
   const Kind kind_;
   SourceLocation source_loc_;
+
+  std::optional<Nonnull<const Value*>> static_type_;
 };
 
 // A pattern consisting of the `auto` keyword.

+ 174 - 108
executable_semantics/interpreter/type_checker.cpp

@@ -26,6 +26,39 @@ using llvm::isa;
 
 namespace Carbon {
 
+// Sets the static type of `expression`. Can be called multiple times on
+// the same node, so long as the types are the same on each call.
+static void SetStaticType(Nonnull<Expression*> expression,
+                          Nonnull<const Value*> type) {
+  if (expression->has_static_type()) {
+    CHECK(TypeEqual(expression->static_type(), type));
+  } else {
+    expression->set_static_type(type);
+  }
+}
+
+// Sets the static type of `pattern`. Can be called multiple times on
+// the same node, so long as the types are the same on each call.
+static void SetStaticType(Nonnull<Pattern*> pattern,
+                          Nonnull<const Value*> type) {
+  if (pattern->has_static_type()) {
+    CHECK(TypeEqual(pattern->static_type(), type));
+  } else {
+    pattern->set_static_type(type);
+  }
+}
+
+// Sets the static type of `definition`. Can be called multiple times on
+// the same node, so long as the types are the same on each call.
+static void SetStaticType(Nonnull<FunctionDefinition*> definition,
+                          Nonnull<const Value*> type) {
+  if (definition->has_static_type()) {
+    CHECK(TypeEqual(definition->static_type(), type));
+  } else {
+    definition->set_static_type(type);
+  }
+}
+
 TypeChecker::ReturnTypeContext::ReturnTypeContext(
     Nonnull<const Value*> orig_return_type, bool is_omitted)
     : is_auto_(isa<AutoType>(orig_return_type)),
@@ -403,20 +436,21 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e, TypeEnv types,
     case Expression::Kind::IndexExpression: {
       auto& index = cast<IndexExpression>(*e);
       auto res = TypeCheckExp(index.Aggregate(), types, values);
-      auto t = res.type;
-      switch (t->kind()) {
+      Nonnull<const Value*> aggregate_type = index.Aggregate()->static_type();
+      switch (aggregate_type->kind()) {
         case Value::Kind::TupleValue: {
           auto i =
               cast<IntValue>(*interpreter.InterpExp(values, index.Offset()))
                   .Val();
           std::string f = std::to_string(i);
           std::optional<Nonnull<const Value*>> field_t =
-              cast<TupleValue>(*t).FindField(f);
+              cast<TupleValue>(*aggregate_type).FindField(f);
           if (!field_t) {
             FATAL_COMPILATION_ERROR(e->source_loc())
-                << "field " << f << " is not in the tuple " << *t;
+                << "field " << f << " is not in the tuple " << *aggregate_type;
           }
-          return TCResult(*field_t, res.types);
+          SetStaticType(&index, *field_t);
+          return TCResult(res.types);
         }
         default:
           FATAL_COMPILATION_ERROR(e->source_loc()) << "expected a tuple";
@@ -430,10 +464,11 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e, TypeEnv types,
         auto arg_res = TypeCheckExp(arg.expression(), new_types, values);
         new_types = arg_res.types;
         new_args.push_back(FieldInitializer(arg.name(), arg.expression()));
-        arg_types.push_back({.name = arg.name(), .value = arg_res.type});
+        arg_types.push_back(
+            {.name = arg.name(), .value = arg.expression()->static_type()});
       }
-      auto tuple_t = arena->New<TupleValue>(std::move(arg_types));
-      return TCResult(tuple_t, new_types);
+      SetStaticType(e, arena->New<TupleValue>(std::move(arg_types)));
+      return TCResult(new_types);
     }
     case Expression::Kind::StructLiteral: {
       std::vector<FieldInitializer> new_args;
@@ -443,10 +478,10 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e, TypeEnv types,
         auto arg_res = TypeCheckExp(arg.expression(), new_types, values);
         new_types = arg_res.types;
         new_args.push_back(FieldInitializer(arg.name(), arg.expression()));
-        arg_types.push_back({arg.name(), arg_res.type});
+        arg_types.push_back({arg.name(), arg.expression()->static_type()});
       }
-      auto type = arena->New<StructType>(std::move(arg_types));
-      return TCResult(type, new_types);
+      SetStaticType(e, arena->New<StructType>(std::move(arg_types)));
+      return TCResult(new_types);
     }
     case Expression::Kind::StructTypeLiteral: {
       auto& struct_type = cast<StructTypeLiteral>(*e);
@@ -459,28 +494,28 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e, TypeEnv types,
                              interpreter.InterpExp(values, arg.expression()));
         new_args.push_back(FieldInitializer(arg.name(), arg.expression()));
       }
-      Nonnull<const Value*> type;
       if (struct_type.fields().empty()) {
         // `{}` is the type of `{}`, just as `()` is the type of `()`.
         // This applies only if there are no fields, because (unlike with
         // tuples) non-empty struct types are syntactically disjoint
         // from non-empty struct values.
-        type = arena->New<StructType>();
+        SetStaticType(&struct_type, arena->New<StructType>());
       } else {
-        type = arena->New<TypeType>();
+        SetStaticType(&struct_type, arena->New<TypeType>());
       }
-      return TCResult(type, new_types);
+      return TCResult(new_types);
     }
     case Expression::Kind::FieldAccessExpression: {
       auto& access = cast<FieldAccessExpression>(*e);
       auto res = TypeCheckExp(access.Aggregate(), types, values);
-      auto t = res.type;
-      switch (t->kind()) {
+      Nonnull<const Value*> aggregate_type = access.Aggregate()->static_type();
+      switch (aggregate_type->kind()) {
         case Value::Kind::StructType: {
-          const auto& struct_type = cast<StructType>(*t);
+          const auto& struct_type = cast<StructType>(*aggregate_type);
           for (const auto& [field_name, field_type] : struct_type.fields()) {
             if (access.Field() == field_name) {
-              return TCResult(field_type, res.types);
+              SetStaticType(&access, field_type);
+              return TCResult(res.types);
             }
           }
           FATAL_COMPILATION_ERROR(access.source_loc())
@@ -488,17 +523,19 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e, TypeEnv types,
               << access.Field();
         }
         case Value::Kind::NominalClassType: {
-          const auto& t_class = cast<NominalClassType>(*t);
+          const auto& t_class = cast<NominalClassType>(*aggregate_type);
           // Search for a field
           for (auto& field : t_class.Fields()) {
             if (access.Field() == field.first) {
-              return TCResult(field.second, res.types);
+              SetStaticType(&access, field.second);
+              return TCResult(res.types);
             }
           }
           // Search for a method
           for (auto& method : t_class.Methods()) {
             if (access.Field() == method.first) {
-              return TCResult(method.second, res.types);
+              SetStaticType(&access, method.second);
+              return TCResult(res.types);
             }
           }
           FATAL_COMPILATION_ERROR(e->source_loc())
@@ -506,10 +543,11 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e, TypeEnv types,
               << access.Field();
         }
         case Value::Kind::TupleValue: {
-          const auto& tup = cast<TupleValue>(*t);
+          const auto& tup = cast<TupleValue>(*aggregate_type);
           for (const TupleElement& field : tup.Elements()) {
             if (access.Field() == field.name) {
-              return TCResult(field.value, res.types);
+              SetStaticType(&access, field.value);
+              return TCResult(res.types);
             }
           }
           FATAL_COMPILATION_ERROR(e->source_loc())
@@ -517,12 +555,13 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e, TypeEnv types,
               << access.Field();
         }
         case Value::Kind::ChoiceType: {
-          const auto& choice = cast<ChoiceType>(*t);
+          const auto& choice = cast<ChoiceType>(*aggregate_type);
           for (const auto& vt : choice.Alternatives()) {
             if (access.Field() == vt.first) {
-              auto fun_ty = arena->New<FunctionType>(
-                  std::vector<GenericBinding>(), vt.second, t);
-              return TCResult(fun_ty, res.types);
+              SetStaticType(&access, arena->New<FunctionType>(
+                                         std::vector<GenericBinding>(),
+                                         vt.second, aggregate_type));
+              return TCResult(res.types);
             }
           }
           FATAL_COMPILATION_ERROR(e->source_loc())
@@ -536,21 +575,24 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e, TypeEnv types,
       }
     }
     case Expression::Kind::IdentifierExpression: {
-      const auto& ident = cast<IdentifierExpression>(*e);
+      auto& ident = cast<IdentifierExpression>(*e);
       std::optional<Nonnull<const Value*>> type = types.Get(ident.Name());
       if (type) {
-        return TCResult(*type, types);
+        SetStaticType(&ident, *type);
+        return TCResult(types);
       } else {
         FATAL_COMPILATION_ERROR(e->source_loc())
             << "could not find `" << ident.Name() << "`";
       }
     }
     case Expression::Kind::IntLiteral:
-      return TCResult(arena->New<IntType>(), types);
+      SetStaticType(e, arena->New<IntType>());
+      return TCResult(types);
     case Expression::Kind::BoolLiteral:
-      return TCResult(arena->New<BoolType>(), types);
+      SetStaticType(e, arena->New<BoolType>());
+      return TCResult(types);
     case Expression::Kind::PrimitiveOperatorExpression: {
-      const auto& op = cast<PrimitiveOperatorExpression>(*e);
+      auto& op = cast<PrimitiveOperatorExpression>(*e);
       std::vector<Nonnull<Expression*>> es;
       std::vector<Nonnull<const Value*>> ts;
       auto new_types = types;
@@ -558,70 +600,82 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e, TypeEnv types,
         auto res = TypeCheckExp(argument, types, values);
         new_types = res.types;
         es.push_back(argument);
-        ts.push_back(res.type);
+        ts.push_back(argument->static_type());
       }
       switch (op.Op()) {
         case Operator::Neg:
           ExpectExactType(e->source_loc(), "negation", arena->New<IntType>(),
                           ts[0]);
-          return TCResult(arena->New<IntType>(), new_types);
+          SetStaticType(&op, arena->New<IntType>());
+          return TCResult(new_types);
         case Operator::Add:
           ExpectExactType(e->source_loc(), "addition(1)", arena->New<IntType>(),
                           ts[0]);
           ExpectExactType(e->source_loc(), "addition(2)", arena->New<IntType>(),
                           ts[1]);
-          return TCResult(arena->New<IntType>(), new_types);
+          SetStaticType(&op, arena->New<IntType>());
+          return TCResult(new_types);
         case Operator::Sub:
           ExpectExactType(e->source_loc(), "subtraction(1)",
                           arena->New<IntType>(), ts[0]);
           ExpectExactType(e->source_loc(), "subtraction(2)",
                           arena->New<IntType>(), ts[1]);
-          return TCResult(arena->New<IntType>(), new_types);
+          SetStaticType(&op, arena->New<IntType>());
+          return TCResult(new_types);
         case Operator::Mul:
           ExpectExactType(e->source_loc(), "multiplication(1)",
                           arena->New<IntType>(), ts[0]);
           ExpectExactType(e->source_loc(), "multiplication(2)",
                           arena->New<IntType>(), ts[1]);
-          return TCResult(arena->New<IntType>(), new_types);
+          SetStaticType(&op, arena->New<IntType>());
+          return TCResult(new_types);
         case Operator::And:
           ExpectExactType(e->source_loc(), "&&(1)", arena->New<BoolType>(),
                           ts[0]);
           ExpectExactType(e->source_loc(), "&&(2)", arena->New<BoolType>(),
                           ts[1]);
-          return TCResult(arena->New<BoolType>(), new_types);
+          SetStaticType(&op, arena->New<BoolType>());
+          return TCResult(new_types);
         case Operator::Or:
           ExpectExactType(e->source_loc(), "||(1)", arena->New<BoolType>(),
                           ts[0]);
           ExpectExactType(e->source_loc(), "||(2)", arena->New<BoolType>(),
                           ts[1]);
-          return TCResult(arena->New<BoolType>(), new_types);
+          SetStaticType(&op, arena->New<BoolType>());
+          return TCResult(new_types);
         case Operator::Not:
           ExpectExactType(e->source_loc(), "!", arena->New<BoolType>(), ts[0]);
-          return TCResult(arena->New<BoolType>(), new_types);
+          SetStaticType(&op, arena->New<BoolType>());
+          return TCResult(new_types);
         case Operator::Eq:
           ExpectExactType(e->source_loc(), "==", ts[0], ts[1]);
-          return TCResult(arena->New<BoolType>(), new_types);
+          SetStaticType(&op, arena->New<BoolType>());
+          return TCResult(new_types);
         case Operator::Deref:
           ExpectPointerType(e->source_loc(), "*", ts[0]);
-          return TCResult(cast<PointerType>(*ts[0]).Type(), new_types);
+          SetStaticType(&op, cast<PointerType>(*ts[0]).Type());
+          return TCResult(new_types);
         case Operator::Ptr:
           ExpectExactType(e->source_loc(), "*", arena->New<TypeType>(), ts[0]);
-          return TCResult(arena->New<TypeType>(), new_types);
+          SetStaticType(&op, arena->New<TypeType>());
+          return TCResult(new_types);
       }
       break;
     }
     case Expression::Kind::CallExpression: {
       auto& call = cast<CallExpression>(*e);
       auto fun_res = TypeCheckExp(call.Function(), types, values);
-      switch (fun_res.type->kind()) {
+      switch (call.Function()->static_type()->kind()) {
         case Value::Kind::FunctionType: {
-          const auto& fun_t = cast<FunctionType>(*fun_res.type);
+          const auto& fun_t =
+              cast<FunctionType>(*call.Function()->static_type());
           auto arg_res = TypeCheckExp(call.Argument(), fun_res.types, values);
           auto parameter_type = fun_t.Param();
           auto return_type = fun_t.Ret();
           if (!fun_t.Deduced().empty()) {
             auto deduced_args = ArgumentDeduction(
-                e->source_loc(), TypeEnv(arena), parameter_type, arg_res.type);
+                e->source_loc(), TypeEnv(arena), parameter_type,
+                call.Argument()->static_type());
             for (auto& deduced_param : fun_t.Deduced()) {
               // TODO: change the following to a CHECK once the real checking
               // has been added to the type checking of function signatures.
@@ -634,9 +688,11 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e, TypeEnv types,
             parameter_type = Substitute(deduced_args, parameter_type);
             return_type = Substitute(deduced_args, return_type);
           } else {
-            ExpectType(e->source_loc(), "call", parameter_type, arg_res.type);
+            ExpectType(e->source_loc(), "call", parameter_type,
+                       call.Argument()->static_type());
           }
-          return TCResult(return_type, arg_res.types);
+          SetStaticType(&call, return_type);
+          return TCResult(arg_res.types);
         }
         default: {
           FATAL_COMPILATION_ERROR(e->source_loc())
@@ -652,21 +708,25 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e, TypeEnv types,
                            interpreter.InterpExp(values, fn.Parameter()));
       ExpectIsConcreteType(fn.ReturnType()->source_loc(),
                            interpreter.InterpExp(values, fn.ReturnType()));
-      return TCResult(arena->New<TypeType>(), types);
+      SetStaticType(&fn, arena->New<TypeType>());
+      return TCResult(types);
     }
     case Expression::Kind::StringLiteral:
-      return TCResult(arena->New<StringType>(), types);
+      SetStaticType(e, arena->New<StringType>());
+      return TCResult(types);
     case Expression::Kind::IntrinsicExpression:
       switch (cast<IntrinsicExpression>(*e).Intrinsic()) {
         case IntrinsicExpression::IntrinsicKind::Print:
-          return TCResult(TupleValue::Empty(), types);
+          SetStaticType(e, TupleValue::Empty());
+          return TCResult(types);
       }
     case Expression::Kind::IntTypeLiteral:
     case Expression::Kind::BoolTypeLiteral:
     case Expression::Kind::StringTypeLiteral:
     case Expression::Kind::TypeTypeLiteral:
     case Expression::Kind::ContinuationTypeLiteral:
-      return TCResult(arena->New<TypeType>(), types);
+      SetStaticType(e, arena->New<TypeType>());
+      return TCResult(types);
   }
 }
 
@@ -686,7 +746,8 @@ auto TypeChecker::TypeCheckPattern(
   }
   switch (p->kind()) {
     case Pattern::Kind::AutoPattern: {
-      return TCResult(arena->New<TypeType>(), types);
+      SetStaticType(p, arena->New<TypeType>());
+      return TCResult(types);
     }
     case Pattern::Kind::BindingPattern: {
       auto& binding = cast<BindingPattern>(*p);
@@ -713,7 +774,8 @@ auto TypeChecker::TypeCheckPattern(
       if (binding.Name().has_value()) {
         types.Set(*binding.Name(), type);
       }
-      return TCResult(type, types);
+      SetStaticType(&binding, type);
+      return TCResult(types);
     }
     case Pattern::Kind::TuplePattern: {
       auto& tuple = cast<TuplePattern>(*p);
@@ -745,10 +807,11 @@ auto TypeChecker::TypeCheckPattern(
                                              expected_field_type);
         new_types = field_result.types;
         new_fields.push_back(TuplePattern::Field(field.name, field.pattern));
-        field_types.push_back({.name = field.name, .value = field_result.type});
+        field_types.push_back(
+            {.name = field.name, .value = field.pattern->static_type()});
       }
-      auto tuple_t = arena->New<TupleValue>(std::move(field_types));
-      return TCResult(tuple_t, new_types);
+      SetStaticType(&tuple, arena->New<TupleValue>(std::move(field_types)));
+      return TCResult(new_types);
     }
     case Pattern::Kind::AlternativePattern: {
       auto& alternative = cast<AlternativePattern>(*p);
@@ -772,12 +835,14 @@ auto TypeChecker::TypeCheckPattern(
       }
       TCResult arg_results = TypeCheckPattern(alternative.Arguments(), types,
                                               values, *parameter_types);
-      return TCResult(choice_type, arg_results.types);
+      SetStaticType(&alternative, choice_type);
+      return TCResult(arg_results.types);
     }
     case Pattern::Kind::ExpressionPattern: {
-      TCResult result =
-          TypeCheckExp(cast<ExpressionPattern>(*p).Expression(), types, values);
-      return TCResult(result.type, result.types);
+      const auto& expression = cast<ExpressionPattern>(*p).Expression();
+      TCResult result = TypeCheckExp(expression, types, values);
+      SetStaticType(p, expression->static_type());
+      return TCResult(result.types);
     }
   }
 }
@@ -799,42 +864,41 @@ auto TypeChecker::TypeCheckStmt(Nonnull<Statement*> s, TypeEnv types,
   switch (s->kind()) {
     case Statement::Kind::Match: {
       auto& match = cast<Match>(*s);
-      auto res = TypeCheckExp(&match.expression(), types, values);
-      auto res_type = res.type;
+      TypeCheckExp(&match.expression(), types, values);
       std::vector<Match::Clause> new_clauses;
       for (auto& clause : match.clauses()) {
-        new_clauses.push_back(TypeCheckCase(res_type, &clause.pattern(),
-                                            &clause.statement(), types, values,
-                                            return_type_context));
+        new_clauses.push_back(TypeCheckCase(
+            match.expression().static_type(), &clause.pattern(),
+            &clause.statement(), types, values, return_type_context));
       }
-      return TCResult(TupleValue::Empty(), types);
+      return TCResult(types);
     }
     case Statement::Kind::While: {
       auto& while_stmt = cast<While>(*s);
-      auto cnd_res = TypeCheckExp(while_stmt.Cond(), types, values);
+      TypeCheckExp(while_stmt.Cond(), types, values);
       ExpectType(s->source_loc(), "condition of `while`",
-                 arena->New<BoolType>(), cnd_res.type);
+                 arena->New<BoolType>(), while_stmt.Cond()->static_type());
       TypeCheckStmt(while_stmt.Body(), types, values, return_type_context);
-      return TCResult(TupleValue::Empty(), types);
+      return TCResult(types);
     }
     case Statement::Kind::Break:
     case Statement::Kind::Continue:
-      return TCResult(TupleValue::Empty(), types);
+      return TCResult(types);
     case Statement::Kind::Block: {
       auto& block = cast<Block>(*s);
       if (block.Stmt()) {
         TypeCheckStmt(*block.Stmt(), types, values, return_type_context);
-        return TCResult(TupleValue::Empty(), types);
+        return TCResult(types);
       } else {
-        return TCResult(TupleValue::Empty(), types);
+        return TCResult(types);
       }
     }
     case Statement::Kind::VariableDefinition: {
       auto& var = cast<VariableDefinition>(*s);
-      auto res = TypeCheckExp(var.Init(), types, values);
-      Nonnull<const Value*> rhs_ty = res.type;
+      TypeCheckExp(var.Init(), types, values);
+      Nonnull<const Value*> rhs_ty = var.Init()->static_type();
       auto lhs_res = TypeCheckPattern(var.Pat(), types, values, rhs_ty);
-      return TCResult(TupleValue::Empty(), lhs_res.types);
+      return TCResult(lhs_res.types);
     }
     case Statement::Kind::Sequence: {
       auto& seq = cast<Sequence>(*s);
@@ -846,35 +910,34 @@ auto TypeChecker::TypeCheckStmt(Nonnull<Statement*> s, TypeEnv types,
                                       return_type_context);
         checked_types = next_res.types;
       }
-      return TCResult(TupleValue::Empty(), checked_types);
+      return TCResult(checked_types);
     }
     case Statement::Kind::Assign: {
       auto& assign = cast<Assign>(*s);
-      auto rhs_res = TypeCheckExp(assign.Rhs(), types, values);
-      auto rhs_t = rhs_res.type;
+      TypeCheckExp(assign.Rhs(), types, values);
       auto lhs_res = TypeCheckExp(assign.Lhs(), types, values);
-      auto lhs_t = lhs_res.type;
-      ExpectType(s->source_loc(), "assign", lhs_t, rhs_t);
-      return TCResult(TupleValue::Empty(), lhs_res.types);
+      ExpectType(s->source_loc(), "assign", assign.Lhs()->static_type(),
+                 assign.Rhs()->static_type());
+      return TCResult(lhs_res.types);
     }
     case Statement::Kind::ExpressionStatement: {
       TypeCheckExp(cast<ExpressionStatement>(*s).Exp(), types, values);
-      return TCResult(TupleValue::Empty(), types);
+      return TCResult(types);
     }
     case Statement::Kind::If: {
       auto& if_stmt = cast<If>(*s);
-      auto cnd_res = TypeCheckExp(if_stmt.Cond(), types, values);
+      TypeCheckExp(if_stmt.Cond(), types, values);
       ExpectType(s->source_loc(), "condition of `if`", arena->New<BoolType>(),
-                 cnd_res.type);
+                 if_stmt.Cond()->static_type());
       TypeCheckStmt(if_stmt.ThenStmt(), types, values, return_type_context);
       if (if_stmt.ElseStmt()) {
         TypeCheckStmt(*if_stmt.ElseStmt(), types, values, return_type_context);
       }
-      return TCResult(TupleValue::Empty(), types);
+      return TCResult(types);
     }
     case Statement::Kind::Return: {
       auto& ret = cast<Return>(*s);
-      auto res = TypeCheckExp(ret.Exp(), types, values);
+      TypeCheckExp(ret.Exp(), types, values);
       if (return_type_context->is_auto()) {
         if (return_type_context->deduced_return_type()) {
           // Only one return is allowed when the return type is `auto`.
@@ -883,11 +946,13 @@ auto TypeChecker::TypeCheckStmt(Nonnull<Statement*> s, TypeEnv types,
                  "return type.";
         } else {
           // Infer the auto return from the first `return` statement.
-          return_type_context->set_deduced_return_type(res.type);
+          return_type_context->set_deduced_return_type(
+              ret.Exp()->static_type());
         }
       } else {
         ExpectType(s->source_loc(), "return",
-                   *return_type_context->deduced_return_type(), res.type);
+                   *return_type_context->deduced_return_type(),
+                   ret.Exp()->static_type());
       }
       if (ret.IsOmittedExp() != return_type_context->is_omitted()) {
         FATAL_COMPILATION_ERROR(s->source_loc())
@@ -895,24 +960,24 @@ auto TypeChecker::TypeCheckStmt(Nonnull<Statement*> s, TypeEnv types,
             << (return_type_context->is_omitted() ? " not" : "")
             << " provide a return value, to match the function's signature.";
       }
-      return TCResult(TupleValue::Empty(), types);
+      return TCResult(types);
     }
     case Statement::Kind::Continuation: {
       auto& cont = cast<Continuation>(*s);
       TypeCheckStmt(cont.Body(), types, values, return_type_context);
       types.Set(cont.ContinuationVariable(), arena->New<ContinuationType>());
-      return TCResult(TupleValue::Empty(), types);
+      return TCResult(types);
     }
     case Statement::Kind::Run: {
-      TCResult argument_result =
-          TypeCheckExp(cast<Run>(*s).Argument(), types, values);
+      auto& run = cast<Run>(*s);
+      TypeCheckExp(run.Argument(), types, values);
       ExpectType(s->source_loc(), "argument of `run`",
-                 arena->New<ContinuationType>(), argument_result.type);
-      return TCResult(TupleValue::Empty(), types);
+                 arena->New<ContinuationType>(), run.Argument()->static_type());
+      return TCResult(types);
     }
     case Statement::Kind::Await: {
       // nothing to do here
-      return TCResult(TupleValue::Empty(), types);
+      return TCResult(types);
     }
   }  // switch
 }
@@ -1008,9 +1073,10 @@ auto TypeChecker::TypeCheckFunDef(FunctionDefinition* f, TypeEnv types,
     ExpectReturnOnAllPaths(body_stmt, f->source_loc());
   }
   ExpectIsConcreteType(f->return_type().source_loc(), return_type);
-  return TCResult(arena->New<FunctionType>(f->deduced_parameters(),
-                                           param_res.type, return_type),
-                  types);
+  SetStaticType(f, arena->New<FunctionType>(f->deduced_parameters(),
+                                            f->param_pattern().static_type(),
+                                            return_type));
+  return TCResult(types);
 }
 
 auto TypeChecker::TypeOfFunDef(TypeEnv types, Env values,
@@ -1024,15 +1090,16 @@ auto TypeChecker::TypeOfFunDef(TypeEnv types, Env values,
     values.Set(deduced.name, a);
   }
   // Type check the parameter pattern
-  auto param_res =
-      TypeCheckPattern(&fun_def->param_pattern(), types, values, std::nullopt);
+  TypeCheckPattern(&fun_def->param_pattern(), types, values, std::nullopt);
   // Evaluate the return type expression
   auto ret = interpreter.InterpPattern(values, &fun_def->return_type());
   if (ret->kind() == Value::Kind::AutoType) {
-    return TypeCheckFunDef(fun_def, types, values).type;
+    // FIXME do this unconditionally?
+    TypeCheckFunDef(fun_def, types, values);
+    return fun_def->static_type();
   }
-  return arena->New<FunctionType>(fun_def->deduced_parameters(), param_res.type,
-                                  ret);
+  return arena->New<FunctionType>(fun_def->deduced_parameters(),
+                                  fun_def->param_pattern().static_type(), ret);
 }
 
 auto TypeChecker::TypeOfClassDef(const ClassDefinition* sd, TypeEnv /*types*/,
@@ -1102,8 +1169,7 @@ void TypeChecker::TypeCheck(Nonnull<Declaration*> d, const TypeEnv& types,
       // Signals a type error if the initializing expression does not have
       // the declared type of the variable, otherwise returns this
       // declaration with annotated types.
-      TCResult type_checked_initializer =
-          TypeCheckExp(&var.initializer(), types, values);
+      TypeCheckExp(&var.initializer(), types, values);
       const auto* binding_type =
           dyn_cast<ExpressionPattern>(var.binding().Type());
       if (binding_type == nullptr) {
@@ -1114,7 +1180,7 @@ void TypeChecker::TypeCheck(Nonnull<Declaration*> d, const TypeEnv& types,
       Nonnull<const Value*> declared_type =
           interpreter.InterpExp(values, binding_type->Expression());
       ExpectType(var.source_loc(), "initializer of variable", declared_type,
-                 type_checked_initializer.type);
+                 var.initializer().static_type());
       return;
     }
   }

+ 1 - 2
executable_semantics/interpreter/type_checker.h

@@ -70,9 +70,8 @@ class TypeChecker {
   };
 
   struct TCResult {
-    TCResult(Nonnull<const Value*> t, TypeEnv types) : type(t), types(types) {}
+    TCResult(TypeEnv types) : types(types) {}
 
-    Nonnull<const Value*> type;
     TypeEnv types;
   };