Quellcode durchsuchen

Factor out AST node for function return types. (#912)

This enables us to stop treating the return type as a Pattern (which is really isn't), treat return types more consistently with other static types in the typechecker, and drop ReturnTypeContext.

Additional changes:
- Merge TypeCheckFunDef with TypeOfFunDef.
- Handle implicit conversions in `return` statements.
- Require function type literals to have an explicit `->`.
- Move consistency check for omitted returns from TypeChecker to ResolveControlFlow.
Geoff Romer vor 4 Jahren
Ursprung
Commit
d854fb93cb

+ 14 - 4
executable_semantics/ast/declaration.cpp

@@ -45,6 +45,19 @@ void Declaration::Print(llvm::raw_ostream& out) const {
   }
 }
 
+void ReturnTerm::Print(llvm::raw_ostream& out) const {
+  switch (kind_) {
+    case ReturnKind::Omitted:
+      return;
+    case ReturnKind::Auto:
+      out << "-> auto";
+      return;
+    case ReturnKind::Expression:
+      out << "-> " << **type_expression_;
+      return;
+  }
+}
+
 void FunctionDeclaration::PrintDepth(int depth, llvm::raw_ostream& out) const {
   out << "fn " << name_ << " ";
   if (!deduced_parameters_.empty()) {
@@ -60,10 +73,7 @@ void FunctionDeclaration::PrintDepth(int depth, llvm::raw_ostream& out) const {
     }
     out << "]";
   }
-  out << *param_pattern_;
-  if (!is_omitted_return_type_) {
-    out << " -> " << *return_type_;
-  }
+  out << *param_pattern_ << return_term_;
   if (body_) {
     out << " {\n";
     (*body_)->PrintDepth(depth, out);

+ 84 - 11
executable_semantics/ast/declaration.h

@@ -104,20 +104,97 @@ struct GenericBinding : public NamedEntityInterface {
   Nonnull<Expression*> type_;
 };
 
+// The syntactic representation of a function declaration's return type.
+// This syntax can take one of three forms:
+// - An _explicit_ term consists of `->` followed by a type expression.
+// - An _auto_ term consists of `-> auto`.
+// - An _omitted_ term consists of no tokens at all.
+// Each of these forms has a corresponding factory function.
+class ReturnTerm {
+ public:
+  ReturnTerm(const ReturnTerm&) = default;
+  ReturnTerm& operator=(const ReturnTerm&) = default;
+
+  // Represents an omitted return term at `source_loc`.
+  static auto Omitted(SourceLocation source_loc) -> ReturnTerm {
+    return ReturnTerm(ReturnKind::Omitted, source_loc);
+  }
+
+  // Represents an auto return term at `source_loc`.
+  static auto Auto(SourceLocation source_loc) -> ReturnTerm {
+    return ReturnTerm(ReturnKind::Auto, source_loc);
+  }
+
+  // Represents an explicit return term with the given type expression.
+  static auto Explicit(Nonnull<Expression*> type_expression) -> ReturnTerm {
+    return ReturnTerm(type_expression);
+  }
+
+  // Returns true if this represents an omitted return term.
+  auto is_omitted() const -> bool { return kind_ == ReturnKind::Omitted; }
+
+  // Returns true if this represents an auto return term.
+  auto is_auto() const -> bool { return kind_ == ReturnKind::Auto; }
+
+  // If this represents an explicit return term, returns the type expression.
+  // Otherwise, returns nullopt.
+  auto type_expression() const -> std::optional<Nonnull<const Expression*>> {
+    return type_expression_;
+  }
+  auto type_expression() -> std::optional<Nonnull<Expression*>> {
+    return type_expression_;
+  }
+
+  // The static return type this term resolves to. Cannot be called before
+  // typechecking.
+  auto static_type() const -> const Value& { return **static_type_; }
+
+  // Sets the value of static_type(). Can only be called once, during
+  // typechecking.
+  void set_static_type(Nonnull<const Value*> type) { static_type_ = type; }
+
+  // Returns whether static_type() has been set. Should only be called
+  // during typechecking: before typechecking it's guaranteed to be false,
+  // and after typechecking it's guaranteed to be true.
+  auto has_static_type() const -> bool { return static_type_.has_value(); }
+
+  auto source_loc() const -> SourceLocation { return source_loc_; }
+
+  void Print(llvm::raw_ostream& out) const;
+  LLVM_DUMP_METHOD void Dump() const { Print(llvm::errs()); }
+
+ private:
+  enum class ReturnKind { Omitted, Auto, Expression };
+
+  explicit ReturnTerm(ReturnKind kind, SourceLocation source_loc)
+      : kind_(kind), source_loc_(source_loc) {
+    CHECK(kind != ReturnKind::Expression);
+  }
+
+  explicit ReturnTerm(Nonnull<Expression*> type_expression)
+      : kind_(ReturnKind::Expression),
+        type_expression_(type_expression),
+        source_loc_(type_expression->source_loc()) {}
+
+  ReturnKind kind_;
+  std::optional<Nonnull<Expression*>> type_expression_;
+  std::optional<Nonnull<const Value*>> static_type_;
+
+  SourceLocation source_loc_;
+};
+
 class FunctionDeclaration : public Declaration {
  public:
   FunctionDeclaration(SourceLocation source_loc, std::string name,
                       std::vector<Nonnull<GenericBinding*>> deduced_params,
                       Nonnull<TuplePattern*> param_pattern,
-                      Nonnull<Pattern*> return_type,
-                      bool is_omitted_return_type,
+                      ReturnTerm return_term,
                       std::optional<Nonnull<Block*>> body)
       : Declaration(Kind::FunctionDeclaration, source_loc),
         name_(std::move(name)),
         deduced_parameters_(std::move(deduced_params)),
         param_pattern_(param_pattern),
-        return_type_(return_type),
-        is_omitted_return_type_(is_omitted_return_type),
+        return_term_(return_term),
         body_(body) {}
 
   static auto classof(const Declaration* decl) -> bool {
@@ -133,11 +210,8 @@ class FunctionDeclaration : public Declaration {
   }
   auto param_pattern() const -> const TuplePattern& { return *param_pattern_; }
   auto param_pattern() -> TuplePattern& { return *param_pattern_; }
-  auto return_type() const -> const Pattern& { return *return_type_; }
-  auto return_type() -> Pattern& { return *return_type_; }
-  auto is_omitted_return_type() const -> bool {
-    return is_omitted_return_type_;
-  }
+  auto return_term() const -> const ReturnTerm& { return return_term_; }
+  auto return_term() -> ReturnTerm& { return return_term_; }
   auto body() const -> std::optional<Nonnull<const Block*>> { return body_; }
   auto body() -> std::optional<Nonnull<Block*>> { return body_; }
 
@@ -149,8 +223,7 @@ class FunctionDeclaration : public Declaration {
   std::string name_;
   std::vector<Nonnull<GenericBinding*>> deduced_parameters_;
   Nonnull<TuplePattern*> param_pattern_;
-  Nonnull<Pattern*> return_type_;
-  bool is_omitted_return_type_;
+  ReturnTerm return_term_;
   std::optional<Nonnull<Block*>> body_;
   StaticScope static_scope_;
 };

+ 2 - 8
executable_semantics/ast/expression.h

@@ -352,12 +352,10 @@ class FunctionTypeLiteral : public Expression {
  public:
   explicit FunctionTypeLiteral(SourceLocation source_loc,
                                Nonnull<Expression*> parameter,
-                               Nonnull<Expression*> return_type,
-                               bool is_omitted_return_type)
+                               Nonnull<Expression*> return_type)
       : Expression(Kind::FunctionTypeLiteral, source_loc),
         parameter_(parameter),
-        return_type_(return_type),
-        is_omitted_return_type_(is_omitted_return_type) {}
+        return_type_(return_type) {}
 
   static auto classof(const Expression* exp) -> bool {
     return exp->kind() == Kind::FunctionTypeLiteral;
@@ -367,14 +365,10 @@ class FunctionTypeLiteral : public Expression {
   auto parameter() -> Expression& { return *parameter_; }
   auto return_type() const -> const Expression& { return *return_type_; }
   auto return_type() -> Expression& { return *return_type_; }
-  auto is_omitted_return_type() const -> bool {
-    return is_omitted_return_type_;
-  }
 
  private:
   Nonnull<Expression*> parameter_;
   Nonnull<Expression*> return_type_;
-  bool is_omitted_return_type_;
 };
 
 class BoolTypeLiteral : public Expression {

+ 3 - 2
executable_semantics/ast/statement.h

@@ -198,9 +198,10 @@ class Return : public Statement {
   // structure of the AST: the return value is not a child of this node,
   // but an ancestor.
   auto function() const -> const FunctionDeclaration& { return **function_; }
+  auto function() -> FunctionDeclaration& { return **function_; }
 
   // Can only be called once, by ResolveControlFlow.
-  void set_function(Nonnull<const FunctionDeclaration*> function) {
+  void set_function(Nonnull<FunctionDeclaration*> function) {
     CHECK(!function_.has_value());
     function_ = function;
   }
@@ -208,7 +209,7 @@ class Return : public Statement {
  private:
   Nonnull<Expression*> expression_;
   bool is_omitted_expression_;
-  std::optional<Nonnull<const FunctionDeclaration*>> function_;
+  std::optional<Nonnull<FunctionDeclaration*>> function_;
 };
 
 class While : public Statement {

+ 1 - 2
executable_semantics/interpreter/exec_program.cpp

@@ -32,8 +32,7 @@ static void AddIntrinsics(Nonnull<Arena*> arena,
   auto print = arena->New<FunctionDeclaration>(
       source_loc, "Print", std::vector<Nonnull<GenericBinding*>>(),
       arena->New<TuplePattern>(source_loc, print_params),
-      arena->New<ExpressionPattern>(arena->New<TupleLiteral>(source_loc)),
-      /*is_omitted_return_type=*/false, print_return);
+      ReturnTerm::Explicit(arena->New<TupleLiteral>(source_loc)), print_return);
   declarations->insert(declarations->begin(), print);
 }
 

+ 4 - 4
executable_semantics/interpreter/interpreter.cpp

@@ -944,11 +944,11 @@ auto Interpreter::StepStmt() -> Transition {
       } else {
         //    { {v :: return [] :: C, E, F} :: {C', E', F'} :: S, H}
         // -> { {v :: C', E', F'} :: S, H}
-        // TODO(geoffromer): convert the result to the function's return type,
-        // once #880 gives us a way to find that type.
         const FunctionDeclaration& function = cast<Return>(stmt).function();
-        return UnwindPast{.ast_node = *function.body(),
-                          .result = act.results()[0]};
+        return UnwindPast{
+            .ast_node = *function.body(),
+            .result = Convert(act.results()[0],
+                              &function.return_term().static_type())};
       }
     case Statement::Kind::Continuation: {
       CHECK(act.pos() == 0);

+ 46 - 16
executable_semantics/interpreter/resolve_control_flow.cpp

@@ -13,22 +13,51 @@ using llvm::cast;
 
 namespace Carbon {
 
-// Resolves control-flow edges in the AST rooted at `statement`. `return`
-// statements will resolve to `*function`, and `break` and `continue`
-// statements will resolve to `*loop`. If either parameter is nullopt, that
-// indicates a context where the corresponding statements are not permitted.
-static void ResolveControlFlow(
-    Nonnull<Statement*> statement,
-    std::optional<Nonnull<const FunctionDeclaration*>> function,
-    std::optional<Nonnull<const Statement*>> loop) {
+// Aggregate information about a function being analyzed.
+struct FunctionData {
+  // The function declaration.
+  Nonnull<FunctionDeclaration*> declaration;
+
+  // True if the function has a deduced return type, and we've already seen
+  // a `return` statement in its body.
+  bool saw_return_in_auto = false;
+};
+
+// Resolves control-flow edges such as `Return::function()` and `Break::loop()`
+// in the AST rooted at `statement`. `loop` is the innermost loop that
+// statically encloses `statement`, or nullopt if there is no such loop.
+// `function` carries information about the function body that `statement`
+// belongs to, and that information may be updated by this call. `function`
+// can be nullopt if `statement` does not belong to a function body, for
+// example if it is part of a continuation body instead.
+static void ResolveControlFlow(Nonnull<Statement*> statement,
+                               std::optional<Nonnull<const Statement*>> loop,
+                               std::optional<Nonnull<FunctionData*>> function) {
   switch (statement->kind()) {
-    case Statement::Kind::Return:
+    case Statement::Kind::Return: {
       if (!function.has_value()) {
         FATAL_COMPILATION_ERROR(statement->source_loc())
             << "return is not within a function body";
       }
-      cast<Return>(*statement).set_function(*function);
+      const ReturnTerm& function_return =
+          (*function)->declaration->return_term();
+      if (function_return.is_auto()) {
+        if ((*function)->saw_return_in_auto) {
+          FATAL_COMPILATION_ERROR(statement->source_loc())
+              << "Only one return is allowed in a function with an `auto` "
+                 "return type.";
+        }
+        (*function)->saw_return_in_auto = true;
+      }
+      auto& ret = cast<Return>(*statement);
+      ret.set_function((*function)->declaration);
+      if (ret.is_omitted_expression() != function_return.is_omitted()) {
+        FATAL_COMPILATION_ERROR(ret.source_loc())
+            << ret << " should" << (function_return.is_omitted() ? " not" : "")
+            << " provide a return value, to match the function's signature.";
+      }
       return;
+    }
     case Statement::Kind::Break:
       if (!loop.has_value()) {
         FATAL_COMPILATION_ERROR(statement->source_loc())
@@ -45,26 +74,26 @@ static void ResolveControlFlow(
       return;
     case Statement::Kind::If: {
       auto& if_stmt = cast<If>(*statement);
-      ResolveControlFlow(&if_stmt.then_block(), function, loop);
+      ResolveControlFlow(&if_stmt.then_block(), loop, function);
       if (if_stmt.else_block().has_value()) {
-        ResolveControlFlow(*if_stmt.else_block(), function, loop);
+        ResolveControlFlow(*if_stmt.else_block(), loop, function);
       }
       return;
     }
     case Statement::Kind::Block: {
       auto& block = cast<Block>(*statement);
       for (auto* block_statement : block.statements()) {
-        ResolveControlFlow(block_statement, function, loop);
+        ResolveControlFlow(block_statement, loop, function);
       }
       return;
     }
     case Statement::Kind::While:
-      ResolveControlFlow(&cast<While>(*statement).body(), function, statement);
+      ResolveControlFlow(&cast<While>(*statement).body(), statement, function);
       return;
     case Statement::Kind::Match: {
       auto& match = cast<Match>(*statement);
       for (Match::Clause& clause : match.clauses()) {
-        ResolveControlFlow(&clause.statement(), function, loop);
+        ResolveControlFlow(&clause.statement(), loop, function);
       }
       return;
     }
@@ -88,7 +117,8 @@ void ResolveControlFlow(AST& ast) {
     }
     auto& function = cast<FunctionDeclaration>(*declaration);
     if (function.body().has_value()) {
-      ResolveControlFlow(*function.body(), &function, std::nullopt);
+      FunctionData data = {.declaration = &function};
+      ResolveControlFlow(*function.body(), std::nullopt, &data);
     }
   }
 }

+ 68 - 94
executable_semantics/interpreter/type_checker.cpp

@@ -58,6 +58,15 @@ static void SetStaticType(Nonnull<Declaration*> definition,
   }
 }
 
+static void SetStaticType(Nonnull<ReturnTerm*> return_term,
+                          Nonnull<const Value*> type) {
+  if (return_term->has_static_type()) {
+    CHECK(TypeEqual(&return_term->static_type(), type));
+  } else {
+    return_term->set_static_type(type);
+  }
+}
+
 static void SetValue(Nonnull<Pattern*> pattern, Nonnull<const Value*> value) {
   // TODO: find some way to CHECK that `value` is identical to pattern->value(),
   // if it's already set. Unclear if `ValueEqual` is suitable, because it
@@ -68,13 +77,6 @@ static void SetValue(Nonnull<Pattern*> pattern, Nonnull<const Value*> value) {
   }
 }
 
-TypeChecker::ReturnTypeContext::ReturnTypeContext(
-    Nonnull<const Value*> orig_return_type, bool is_omitted)
-    : is_auto_(isa<AutoType>(orig_return_type)),
-      deduced_return_type_(is_auto_ ? std::nullopt
-                                    : std::optional(orig_return_type)),
-      is_omitted_(is_omitted) {}
-
 void TypeChecker::PrintTypeEnv(TypeEnv types, llvm::raw_ostream& out) {
   llvm::ListSeparator sep;
   for (const auto& [name, type] : types) {
@@ -826,27 +828,23 @@ auto TypeChecker::TypeCheckPattern(
 
 auto TypeChecker::TypeCheckCase(Nonnull<const Value*> expected,
                                 Nonnull<Pattern*> pat, Nonnull<Statement*> body,
-                                TypeEnv types, Env values,
-                                Nonnull<ReturnTypeContext*> return_type_context)
-    -> Match::Clause {
+                                TypeEnv types, Env values) -> Match::Clause {
   auto pat_res = TypeCheckPattern(pat, types, values, expected);
-  TypeCheckStmt(body, pat_res.types, values, return_type_context);
+  TypeCheckStmt(body, pat_res.types, values);
   return Match::Clause(pat, body);
 }
 
 auto TypeChecker::TypeCheckStmt(Nonnull<Statement*> s, TypeEnv types,
-                                Env values,
-                                Nonnull<ReturnTypeContext*> return_type_context)
-    -> TCResult {
+                                Env values) -> TCResult {
   switch (s->kind()) {
     case Statement::Kind::Match: {
       auto& match = cast<Match>(*s);
       TypeCheckExp(&match.expression(), types, values);
       std::vector<Match::Clause> new_clauses;
       for (auto& clause : match.clauses()) {
-        new_clauses.push_back(TypeCheckCase(
-            &match.expression().static_type(), &clause.pattern(),
-            &clause.statement(), types, values, return_type_context));
+        new_clauses.push_back(
+            TypeCheckCase(&match.expression().static_type(), &clause.pattern(),
+                          &clause.statement(), types, values));
       }
       return TCResult(types);
     }
@@ -856,7 +854,7 @@ auto TypeChecker::TypeCheckStmt(Nonnull<Statement*> s, TypeEnv types,
       ExpectType(s->source_loc(), "condition of `while`",
                  arena_->New<BoolType>(),
                  &while_stmt.condition().static_type());
-      TypeCheckStmt(&while_stmt.body(), types, values, return_type_context);
+      TypeCheckStmt(&while_stmt.body(), types, values);
       return TCResult(types);
     }
     case Statement::Kind::Break:
@@ -865,8 +863,7 @@ auto TypeChecker::TypeCheckStmt(Nonnull<Statement*> s, TypeEnv types,
     case Statement::Kind::Block: {
       auto& block = cast<Block>(*s);
       for (auto* block_statement : block.statements()) {
-        auto result =
-            TypeCheckStmt(block_statement, types, values, return_type_context);
+        auto result = TypeCheckStmt(block_statement, types, values);
         types = result.types;
       }
       return TCResult(types);
@@ -895,43 +892,27 @@ auto TypeChecker::TypeCheckStmt(Nonnull<Statement*> s, TypeEnv types,
       TypeCheckExp(&if_stmt.condition(), types, values);
       ExpectType(s->source_loc(), "condition of `if`", arena_->New<BoolType>(),
                  &if_stmt.condition().static_type());
-      TypeCheckStmt(&if_stmt.then_block(), types, values, return_type_context);
+      TypeCheckStmt(&if_stmt.then_block(), types, values);
       if (if_stmt.else_block()) {
-        TypeCheckStmt(*if_stmt.else_block(), types, values,
-                      return_type_context);
+        TypeCheckStmt(*if_stmt.else_block(), types, values);
       }
       return TCResult(types);
     }
     case Statement::Kind::Return: {
       auto& ret = cast<Return>(*s);
       TypeCheckExp(&ret.expression(), types, values);
-      if (return_type_context->is_auto()) {
-        if (return_type_context->deduced_return_type()) {
-          // Only one return is allowed when the return type is `auto`.
-          FATAL_COMPILATION_ERROR(s->source_loc())
-              << "Only one return is allowed in a function with an `auto` "
-                 "return type.";
-        } else {
-          // Infer the auto return from the first `return` statement.
-          return_type_context->set_deduced_return_type(
-              &ret.expression().static_type());
-        }
+      ReturnTerm& return_term = ret.function().return_term();
+      if (return_term.is_auto()) {
+        SetStaticType(&return_term, &ret.expression().static_type());
       } else {
-        ExpectType(s->source_loc(), "return",
-                   *return_type_context->deduced_return_type(),
+        ExpectType(s->source_loc(), "return", &return_term.static_type(),
                    &ret.expression().static_type());
       }
-      if (ret.is_omitted_expression() != return_type_context->is_omitted()) {
-        FATAL_COMPILATION_ERROR(s->source_loc())
-            << *s << " should"
-            << (return_type_context->is_omitted() ? " not" : "")
-            << " provide a return value, to match the function's signature.";
-      }
       return TCResult(types);
     }
     case Statement::Kind::Continuation: {
       auto& cont = cast<Continuation>(*s);
-      TypeCheckStmt(&cont.body(), types, values, return_type_context);
+      TypeCheckStmt(&cont.body(), types, values);
       types.Set(cont.continuation_variable(), arena_->New<ContinuationType>());
       return TCResult(types);
     }
@@ -1022,12 +1003,11 @@ void TypeChecker::ExpectReturnOnAllPaths(
   }
 }
 
-// TODO: factor common parts of TypeCheckFunDef and TypeOfFunDef into
-// a function.
 // TODO: Add checking to function definitions to ensure that
 //   all deduced type parameters will be deduced.
-auto TypeChecker::TypeCheckFunDef(FunctionDeclaration* f, TypeEnv types,
-                                  Env values) -> TCResult {
+auto TypeChecker::TypeCheckFunctionDeclaration(Nonnull<FunctionDeclaration*> f,
+                                               TypeEnv types, Env values,
+                                               bool check_body) -> TCResult {
   // Bring the deduced parameters into scope
   for (Nonnull<const GenericBinding*> deduced : f->deduced_parameters()) {
     // auto t = interpreter_.InterpExp(values, deduced.type);
@@ -1038,56 +1018,48 @@ auto TypeChecker::TypeCheckFunDef(FunctionDeclaration* f, TypeEnv types,
   // Type check the parameter pattern
   auto param_res =
       TypeCheckPattern(&f->param_pattern(), types, values, std::nullopt);
-  // Evaluate the return type expression
-  auto return_type = interpreter_.InterpPattern(values, &f->return_type());
-  if (f->name() == "Main") {
-    ExpectExactType(f->source_loc(), "return type of `Main`",
-                    arena_->New<IntType>(), return_type);
-    // TODO: Check that main doesn't have any parameters.
-  }
-  std::optional<Nonnull<Statement*>> body_stmt;
-  if (f->body()) {
-    ReturnTypeContext return_type_context(return_type,
-                                          f->is_omitted_return_type());
-    TypeCheckStmt(*f->body(), param_res.types, values, &return_type_context);
-    body_stmt = *f->body();
-    // Save the return type in case it changed.
-    if (return_type_context.deduced_return_type().has_value()) {
-      return_type = *return_type_context.deduced_return_type();
+
+  // Evaluate the return type, if we can do so without examining the body.
+  if (std::optional<Nonnull<Expression*>> return_expression =
+          f->return_term().type_expression();
+      return_expression.has_value()) {
+    // We ignore the return value because return type expressions can't bring
+    // new types into scope.
+    TypeCheckExp(*return_expression, param_res.types, values);
+    SetStaticType(&f->return_term(),
+                  interpreter_.InterpExp(values, *return_expression));
+  } else if (f->return_term().is_omitted()) {
+    SetStaticType(&f->return_term(), TupleValue::Empty());
+  } else {
+    // We have to type-check the body in order to determine the return type.
+    check_body = true;
+    if (!f->body().has_value()) {
+      FATAL_COMPILATION_ERROR(f->return_term().source_loc())
+          << "Function declaration has deduced return type but no body";
     }
   }
-  if (!f->is_omitted_return_type()) {
-    ExpectReturnOnAllPaths(body_stmt, f->source_loc());
+
+  if (f->body().has_value() && check_body) {
+    TypeCheckStmt(*f->body(), param_res.types, values);
+    if (!f->return_term().is_omitted()) {
+      ExpectReturnOnAllPaths(f->body(), f->source_loc());
+    }
   }
-  ExpectIsConcreteType(f->return_type().source_loc(), return_type);
+
+  ExpectIsConcreteType(f->source_loc(), &f->return_term().static_type());
   SetStaticType(f, arena_->New<FunctionType>(f->deduced_parameters(),
                                              &f->param_pattern().static_type(),
-                                             return_type));
-  return TCResult(types);
-}
-
-auto TypeChecker::TypeOfFunDef(TypeEnv types, Env values,
-                               FunctionDeclaration* fun_def)
-    -> Nonnull<const Value*> {
-  // Bring the deduced parameters into scope
-  for (Nonnull<const GenericBinding*> deduced : fun_def->deduced_parameters()) {
-    // auto t = interpreter_.InterpExp(values, deduced.type);
-    types.Set(deduced->name(), arena_->New<VariableType>(deduced->name()));
-    AllocationId a = interpreter_.AllocateValue(*types.Get(deduced->name()));
-    values.Set(deduced->name(), a);
-  }
-  // Type check the parameter pattern
-  TypeCheckPattern(&fun_def->param_pattern(), types, values, std::nullopt);
-  // Evaluate the return type expression
-  auto ret = interpreter_.InterpPattern(values, &fun_def->return_type());
-  if (ret->kind() == Value::Kind::AutoType) {
-    // FIXME do this unconditionally?
-    TypeCheckFunDef(fun_def, types, values);
-    return &fun_def->static_type();
+                                             &f->return_term().static_type()));
+  if (f->name() == "Main") {
+    if (!f->return_term().type_expression().has_value()) {
+      FATAL_COMPILATION_ERROR(f->return_term().source_loc())
+          << "`Main` must have an explicit return type";
+    }
+    ExpectExactType(f->return_term().source_loc(), "return type of `Main`",
+                    arena_->New<IntType>(), &f->return_term().static_type());
+    // TODO: Check that main doesn't have any parameters.
   }
-  return arena_->New<FunctionType>(fun_def->deduced_parameters(),
-                                   &fun_def->param_pattern().static_type(),
-                                   ret);
+  return TCResult(types);
 }
 
 auto TypeChecker::TypeOfClassDecl(const ClassDeclaration& class_decl,
@@ -1151,7 +1123,8 @@ void TypeChecker::TypeCheckDeclaration(Nonnull<Declaration*> d,
                                        const Env& values) {
   switch (d->kind()) {
     case Declaration::Kind::FunctionDeclaration:
-      TypeCheckFunDef(&cast<FunctionDeclaration>(*d), types, values);
+      TypeCheckFunctionDeclaration(&cast<FunctionDeclaration>(*d), types,
+                                   values, /*check_body=*/true);
       return;
     case Declaration::Kind::ClassDeclaration:
       // TODO
@@ -1188,8 +1161,9 @@ void TypeChecker::TopLevel(Nonnull<Declaration*> d, TypeCheckContext* tops) {
   switch (d->kind()) {
     case Declaration::Kind::FunctionDeclaration: {
       auto& func_def = cast<FunctionDeclaration>(*d);
-      auto t = TypeOfFunDef(tops->types, tops->values, &func_def);
-      tops->types.Set(func_def.name(), t);
+      TypeCheckFunctionDeclaration(&func_def, tops->types, tops->values,
+                                   /*check_body=*/false);
+      tops->types.Set(func_def.name(), &func_def.static_type());
       interpreter_.InitEnv(*d, &tops->values);
       break;
     }

+ 18 - 56
executable_semantics/interpreter/type_checker.h

@@ -37,37 +37,6 @@ class TypeChecker {
     Env values;
   };
 
-  // Context about the return type, which may be updated during type checking.
-  class ReturnTypeContext {
-   public:
-    // If orig_return_type is auto, deduced_return_type_ will be nullopt;
-    // otherwise, it's orig_return_type. is_auto_ is set accordingly.
-    ReturnTypeContext(Nonnull<const Value*> orig_return_type, bool is_omitted);
-
-    auto is_auto() const -> bool { return is_auto_; }
-
-    auto deduced_return_type() const -> std::optional<Nonnull<const Value*>> {
-      return deduced_return_type_;
-    }
-    void set_deduced_return_type(Nonnull<const Value*> type) {
-      deduced_return_type_ = type;
-    }
-
-    auto is_omitted() const -> bool { return is_omitted_; }
-
-   private:
-    // Indicates an `auto` return type, as in `fn Foo() -> auto { return 0; }`.
-    const bool is_auto_;
-
-    // The actual return type. May be nullopt for an `auto` return type that has
-    // yet to be determined.
-    std::optional<Nonnull<const Value*>> deduced_return_type_;
-
-    // Indicates the return type was omitted and is implicitly the empty tuple,
-    // as in `fn Foo() {}`.
-    const bool is_omitted_;
-  };
-
   struct TCResult {
     explicit TCResult(TypeEnv types) : types(types) {}
 
@@ -86,52 +55,45 @@ class TypeChecker {
                                 Nonnull<const Value*> param,
                                 Nonnull<const Value*> arg) -> TypeEnv;
 
-  // TypeCheckExp performs semantic analysis on an expression.  It returns a new
-  // version of the expression, its type, and an updated environment which are
-  // bundled into a TCResult object.  The purpose of the updated environment is
-  // to bring pattern variables into scope, for example, in a match case.  The
-  // new version of the expression may include more information, for example,
-  // the type arguments deduced for the type parameters of a generic.
+  // Traverses the AST rooted at `e`, populating the static_type() of all nodes
+  // and ensuring they follow Carbon's typing rules.
   //
-  // e is the expression to be analyzed.
-  // types maps variable names to the type of their run-time value.
-  // values maps variable names to their compile-time values. It is not
+  // `types` maps variable names to the type of their run-time value.
+  // `values` maps variable names to their compile-time values. It is not
   //    directly used in this function but is passed to InterExp.
   auto TypeCheckExp(Nonnull<Expression*> e, TypeEnv types, Env values)
       -> TCResult;
 
-  // Equivalent to TypeCheckExp, but operates on Patterns instead of
-  // Expressions. `expected` is the type that this pattern is expected to have,
-  // if the surrounding context gives us that information. Otherwise, it is
+  // Equivalent to TypeCheckExp, but operates on the AST rooted at `p`.
+  //
+  // `expected` is the type that this pattern is expected to have, if the
+  // surrounding context gives us that information. Otherwise, it is
   // nullopt.
   auto TypeCheckPattern(Nonnull<Pattern*> p, TypeEnv types, Env values,
                         std::optional<Nonnull<const Value*>> expected)
       -> TCResult;
 
+  // Equivalent to TypeCheckExp, but operates on the AST rooted at `d`.
   void TypeCheckDeclaration(Nonnull<Declaration*> d, const TypeEnv& types,
                             const Env& values);
 
-  // TypeCheckStmt performs semantic analysis on a statement.  It returns a new
-  // version of the statement and a new type environment.
+  // Equivalent to TypeCheckExp, but operates on the AST rooted at `s`.
   //
-  // The ret_type parameter is used for analyzing return statements.  It is the
-  // declared return type of the enclosing function definition.  If the return
-  // type is "auto", then the return type is inferred from the first return
-  // statement.
-  auto TypeCheckStmt(Nonnull<Statement*> s, TypeEnv types, Env values,
-                     Nonnull<ReturnTypeContext*> return_type_context)
+  // REQUIRES: f.return_term().has_static_type() || f.return_term().is_auto(),
+  // where `f` is nearest enclosing FunctionDeclaration of `s`.
+  auto TypeCheckStmt(Nonnull<Statement*> s, TypeEnv types, Env values)
       -> TCResult;
 
-  auto TypeCheckFunDef(FunctionDeclaration* f, TypeEnv types, Env values)
+  // Equivalent to TypeCheckExp, but operates on the AST rooted at `f`,
+  // and may not traverse f->body() if `check_body` is false.
+  auto TypeCheckFunctionDeclaration(Nonnull<FunctionDeclaration*> f,
+                                    TypeEnv types, Env values, bool check_body)
       -> TCResult;
 
   auto TypeCheckCase(Nonnull<const Value*> expected, Nonnull<Pattern*> pat,
-                     Nonnull<Statement*> body, TypeEnv types, Env values,
-                     Nonnull<ReturnTypeContext*> return_type_context)
+                     Nonnull<Statement*> body, TypeEnv types, Env values)
       -> Match::Clause;
 
-  auto TypeOfFunDef(TypeEnv types, Env values, FunctionDeclaration* fun_def)
-      -> Nonnull<const Value*>;
   auto TypeOfClassDecl(const ClassDeclaration& class_decl, TypeEnv /*types*/,
                        Env ct_top) -> Nonnull<const Value*>;
 

+ 14 - 29
executable_semantics/syntax/parser.ypp

@@ -113,7 +113,7 @@
 %type <std::vector<Nonnull<GenericBinding*>>> deduced_param_list
 %type <Nonnull<Pattern*>> pattern
 %type <Nonnull<Pattern*>> non_expression_pattern
-%type <std::pair<Nonnull<Expression*>, bool>> return_type
+%type <BisonWrap<ReturnTerm>> return_term
 %type <Nonnull<Expression*>> paren_expression
 %type <Nonnull<StructLiteral*>> struct_literal
 %type <std::vector<FieldInitializer>> struct_literal_contents
@@ -376,12 +376,8 @@ expression:
           context.source_loc(), Operator::Ptr,
           std::vector<Nonnull<Expression*>>({$1}));
     }
-| FN_TYPE tuple return_type
-    {
-      auto [return_exp, is_omitted_exp] = $3;
-      $$ = arena->New<FunctionTypeLiteral>(context.source_loc(), $2, return_exp,
-                                           is_omitted_exp);
-    }
+| FN_TYPE tuple ARROW expression
+    { $$ = arena->New<FunctionTypeLiteral>(context.source_loc(), $2, $4); }
 ;
 designator: PERIOD identifier { $$ = $2; }
 ;
@@ -624,11 +620,13 @@ nonempty_block:
       $$ = arena->New<Block>(context.source_loc(), std::move($2));
     }
 ;
-return_type:
+return_term:
   // Empty
-    { $$ = {arena->New<TupleLiteral>(context.source_loc()), true}; }
+    { $$ = ReturnTerm::Omitted(context.source_loc()); }
+| ARROW AUTO %prec FNARROW
+    { $$ = ReturnTerm::Auto(context.source_loc()); }
 | ARROW expression %prec FNARROW
-    { $$ = {$2, false}; }
+    { $$ = ReturnTerm::Explicit($2); }
 ;
 generic_binding:
   identifier COLON_BANG expression
@@ -657,28 +655,15 @@ deduced_params:
     { $$ = $2; }
 ;
 function_declaration:
-  FN identifier deduced_params maybe_empty_tuple_pattern return_type block
-    {
-      auto [return_exp, is_omitted_exp] = $5;
-      $$ = arena->New<FunctionDeclaration>(
-          context.source_loc(), $2, $3, $4,
-          arena->New<ExpressionPattern>(return_exp), is_omitted_exp, $6);
-    }
-| FN identifier deduced_params maybe_empty_tuple_pattern ARROW AUTO block
+  FN identifier deduced_params maybe_empty_tuple_pattern return_term block
     {
-      // The return type is not considered "omitted" because it's `auto`.
-      $$ = arena->New<FunctionDeclaration>(
-          context.source_loc(), $2, $3, $4,
-          arena->New<AutoPattern>(context.source_loc()),
-          /*is_omitted_exp=*/false, $7);
+      $$ = arena->New<FunctionDeclaration>(context.source_loc(), $2, $3, $4, $5,
+                                           $6);
     }
-| FN identifier deduced_params maybe_empty_tuple_pattern return_type SEMICOLON
+| FN identifier deduced_params maybe_empty_tuple_pattern return_term SEMICOLON
     {
-      auto [return_exp, is_omitted_exp] = $5;
-      $$ = arena->New<FunctionDeclaration>(
-          context.source_loc(), $2, $3, $4,
-          arena->New<ExpressionPattern>(return_exp), is_omitted_exp,
-          std::nullopt);
+      $$ = arena->New<FunctionDeclaration>(context.source_loc(), $2, $3, $4, $5,
+                                           std::nullopt);
     }
 ;
 variable_declaration: identifier COLON pattern

+ 1 - 3
executable_semantics/testdata/function/auto_return/fail_separate_decl.carbon

@@ -7,15 +7,13 @@
 // RUN: not executable_semantics --trace %s 2>&1 | \
 // RUN:   FileCheck --match-full-lines --allow-unused-prefixes %s
 // AUTOUPDATE: executable_semantics %s
-// CHECK: COMPILATION ERROR: {{.*}}/executable_semantics/testdata/function/auto_return/fail_separate_decl.carbon:15: syntax error, unexpected SEMICOLON, expecting LEFT_CURLY_BRACE
+// CHECK: COMPILATION ERROR: {{.*}}/executable_semantics/testdata/function/auto_return/fail_separate_decl.carbon:15: Function declaration has deduced return type but no body
 
 package ExecutableSemanticsTest api;
 
 // This declaration is not allowed.
 fn Add(x: i32, y: i32) -> auto;
 
-fn Add(x: i32, y: i32) -> auto { return x + y; }
-
 fn Main() -> i32 {
   return Add(1, 2) - 3;
 }