|
|
@@ -58,6 +58,15 @@ static void SetStaticType(Nonnull<Declaration*> definition,
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+static void SetStaticType(Nonnull<ReturnTerm*> return_term,
|
|
|
+ Nonnull<const Value*> type) {
|
|
|
+ if (return_term->has_static_type()) {
|
|
|
+ CHECK(TypeEqual(&return_term->static_type(), type));
|
|
|
+ } else {
|
|
|
+ return_term->set_static_type(type);
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
static void SetValue(Nonnull<Pattern*> pattern, Nonnull<const Value*> value) {
|
|
|
// TODO: find some way to CHECK that `value` is identical to pattern->value(),
|
|
|
// if it's already set. Unclear if `ValueEqual` is suitable, because it
|
|
|
@@ -68,13 +77,6 @@ static void SetValue(Nonnull<Pattern*> pattern, Nonnull<const Value*> value) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-TypeChecker::ReturnTypeContext::ReturnTypeContext(
|
|
|
- Nonnull<const Value*> orig_return_type, bool is_omitted)
|
|
|
- : is_auto_(isa<AutoType>(orig_return_type)),
|
|
|
- deduced_return_type_(is_auto_ ? std::nullopt
|
|
|
- : std::optional(orig_return_type)),
|
|
|
- is_omitted_(is_omitted) {}
|
|
|
-
|
|
|
void TypeChecker::PrintTypeEnv(TypeEnv types, llvm::raw_ostream& out) {
|
|
|
llvm::ListSeparator sep;
|
|
|
for (const auto& [name, type] : types) {
|
|
|
@@ -826,27 +828,23 @@ auto TypeChecker::TypeCheckPattern(
|
|
|
|
|
|
auto TypeChecker::TypeCheckCase(Nonnull<const Value*> expected,
|
|
|
Nonnull<Pattern*> pat, Nonnull<Statement*> body,
|
|
|
- TypeEnv types, Env values,
|
|
|
- Nonnull<ReturnTypeContext*> return_type_context)
|
|
|
- -> Match::Clause {
|
|
|
+ TypeEnv types, Env values) -> Match::Clause {
|
|
|
auto pat_res = TypeCheckPattern(pat, types, values, expected);
|
|
|
- TypeCheckStmt(body, pat_res.types, values, return_type_context);
|
|
|
+ TypeCheckStmt(body, pat_res.types, values);
|
|
|
return Match::Clause(pat, body);
|
|
|
}
|
|
|
|
|
|
auto TypeChecker::TypeCheckStmt(Nonnull<Statement*> s, TypeEnv types,
|
|
|
- Env values,
|
|
|
- Nonnull<ReturnTypeContext*> return_type_context)
|
|
|
- -> TCResult {
|
|
|
+ Env values) -> TCResult {
|
|
|
switch (s->kind()) {
|
|
|
case Statement::Kind::Match: {
|
|
|
auto& match = cast<Match>(*s);
|
|
|
TypeCheckExp(&match.expression(), types, values);
|
|
|
std::vector<Match::Clause> new_clauses;
|
|
|
for (auto& clause : match.clauses()) {
|
|
|
- new_clauses.push_back(TypeCheckCase(
|
|
|
- &match.expression().static_type(), &clause.pattern(),
|
|
|
- &clause.statement(), types, values, return_type_context));
|
|
|
+ new_clauses.push_back(
|
|
|
+ TypeCheckCase(&match.expression().static_type(), &clause.pattern(),
|
|
|
+ &clause.statement(), types, values));
|
|
|
}
|
|
|
return TCResult(types);
|
|
|
}
|
|
|
@@ -856,7 +854,7 @@ auto TypeChecker::TypeCheckStmt(Nonnull<Statement*> s, TypeEnv types,
|
|
|
ExpectType(s->source_loc(), "condition of `while`",
|
|
|
arena_->New<BoolType>(),
|
|
|
&while_stmt.condition().static_type());
|
|
|
- TypeCheckStmt(&while_stmt.body(), types, values, return_type_context);
|
|
|
+ TypeCheckStmt(&while_stmt.body(), types, values);
|
|
|
return TCResult(types);
|
|
|
}
|
|
|
case Statement::Kind::Break:
|
|
|
@@ -865,8 +863,7 @@ auto TypeChecker::TypeCheckStmt(Nonnull<Statement*> s, TypeEnv types,
|
|
|
case Statement::Kind::Block: {
|
|
|
auto& block = cast<Block>(*s);
|
|
|
for (auto* block_statement : block.statements()) {
|
|
|
- auto result =
|
|
|
- TypeCheckStmt(block_statement, types, values, return_type_context);
|
|
|
+ auto result = TypeCheckStmt(block_statement, types, values);
|
|
|
types = result.types;
|
|
|
}
|
|
|
return TCResult(types);
|
|
|
@@ -895,43 +892,27 @@ auto TypeChecker::TypeCheckStmt(Nonnull<Statement*> s, TypeEnv types,
|
|
|
TypeCheckExp(&if_stmt.condition(), types, values);
|
|
|
ExpectType(s->source_loc(), "condition of `if`", arena_->New<BoolType>(),
|
|
|
&if_stmt.condition().static_type());
|
|
|
- TypeCheckStmt(&if_stmt.then_block(), types, values, return_type_context);
|
|
|
+ TypeCheckStmt(&if_stmt.then_block(), types, values);
|
|
|
if (if_stmt.else_block()) {
|
|
|
- TypeCheckStmt(*if_stmt.else_block(), types, values,
|
|
|
- return_type_context);
|
|
|
+ TypeCheckStmt(*if_stmt.else_block(), types, values);
|
|
|
}
|
|
|
return TCResult(types);
|
|
|
}
|
|
|
case Statement::Kind::Return: {
|
|
|
auto& ret = cast<Return>(*s);
|
|
|
TypeCheckExp(&ret.expression(), types, values);
|
|
|
- if (return_type_context->is_auto()) {
|
|
|
- if (return_type_context->deduced_return_type()) {
|
|
|
- // Only one return is allowed when the return type is `auto`.
|
|
|
- FATAL_COMPILATION_ERROR(s->source_loc())
|
|
|
- << "Only one return is allowed in a function with an `auto` "
|
|
|
- "return type.";
|
|
|
- } else {
|
|
|
- // Infer the auto return from the first `return` statement.
|
|
|
- return_type_context->set_deduced_return_type(
|
|
|
- &ret.expression().static_type());
|
|
|
- }
|
|
|
+ ReturnTerm& return_term = ret.function().return_term();
|
|
|
+ if (return_term.is_auto()) {
|
|
|
+ SetStaticType(&return_term, &ret.expression().static_type());
|
|
|
} else {
|
|
|
- ExpectType(s->source_loc(), "return",
|
|
|
- *return_type_context->deduced_return_type(),
|
|
|
+ ExpectType(s->source_loc(), "return", &return_term.static_type(),
|
|
|
&ret.expression().static_type());
|
|
|
}
|
|
|
- if (ret.is_omitted_expression() != return_type_context->is_omitted()) {
|
|
|
- FATAL_COMPILATION_ERROR(s->source_loc())
|
|
|
- << *s << " should"
|
|
|
- << (return_type_context->is_omitted() ? " not" : "")
|
|
|
- << " provide a return value, to match the function's signature.";
|
|
|
- }
|
|
|
return TCResult(types);
|
|
|
}
|
|
|
case Statement::Kind::Continuation: {
|
|
|
auto& cont = cast<Continuation>(*s);
|
|
|
- TypeCheckStmt(&cont.body(), types, values, return_type_context);
|
|
|
+ TypeCheckStmt(&cont.body(), types, values);
|
|
|
types.Set(cont.continuation_variable(), arena_->New<ContinuationType>());
|
|
|
return TCResult(types);
|
|
|
}
|
|
|
@@ -1022,12 +1003,11 @@ void TypeChecker::ExpectReturnOnAllPaths(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-// TODO: factor common parts of TypeCheckFunDef and TypeOfFunDef into
|
|
|
-// a function.
|
|
|
// TODO: Add checking to function definitions to ensure that
|
|
|
// all deduced type parameters will be deduced.
|
|
|
-auto TypeChecker::TypeCheckFunDef(FunctionDeclaration* f, TypeEnv types,
|
|
|
- Env values) -> TCResult {
|
|
|
+auto TypeChecker::TypeCheckFunctionDeclaration(Nonnull<FunctionDeclaration*> f,
|
|
|
+ TypeEnv types, Env values,
|
|
|
+ bool check_body) -> TCResult {
|
|
|
// Bring the deduced parameters into scope
|
|
|
for (Nonnull<const GenericBinding*> deduced : f->deduced_parameters()) {
|
|
|
// auto t = interpreter_.InterpExp(values, deduced.type);
|
|
|
@@ -1038,56 +1018,48 @@ auto TypeChecker::TypeCheckFunDef(FunctionDeclaration* f, TypeEnv types,
|
|
|
// Type check the parameter pattern
|
|
|
auto param_res =
|
|
|
TypeCheckPattern(&f->param_pattern(), types, values, std::nullopt);
|
|
|
- // Evaluate the return type expression
|
|
|
- auto return_type = interpreter_.InterpPattern(values, &f->return_type());
|
|
|
- if (f->name() == "Main") {
|
|
|
- ExpectExactType(f->source_loc(), "return type of `Main`",
|
|
|
- arena_->New<IntType>(), return_type);
|
|
|
- // TODO: Check that main doesn't have any parameters.
|
|
|
- }
|
|
|
- std::optional<Nonnull<Statement*>> body_stmt;
|
|
|
- if (f->body()) {
|
|
|
- ReturnTypeContext return_type_context(return_type,
|
|
|
- f->is_omitted_return_type());
|
|
|
- TypeCheckStmt(*f->body(), param_res.types, values, &return_type_context);
|
|
|
- body_stmt = *f->body();
|
|
|
- // Save the return type in case it changed.
|
|
|
- if (return_type_context.deduced_return_type().has_value()) {
|
|
|
- return_type = *return_type_context.deduced_return_type();
|
|
|
+
|
|
|
+ // Evaluate the return type, if we can do so without examining the body.
|
|
|
+ if (std::optional<Nonnull<Expression*>> return_expression =
|
|
|
+ f->return_term().type_expression();
|
|
|
+ return_expression.has_value()) {
|
|
|
+ // We ignore the return value because return type expressions can't bring
|
|
|
+ // new types into scope.
|
|
|
+ TypeCheckExp(*return_expression, param_res.types, values);
|
|
|
+ SetStaticType(&f->return_term(),
|
|
|
+ interpreter_.InterpExp(values, *return_expression));
|
|
|
+ } else if (f->return_term().is_omitted()) {
|
|
|
+ SetStaticType(&f->return_term(), TupleValue::Empty());
|
|
|
+ } else {
|
|
|
+ // We have to type-check the body in order to determine the return type.
|
|
|
+ check_body = true;
|
|
|
+ if (!f->body().has_value()) {
|
|
|
+ FATAL_COMPILATION_ERROR(f->return_term().source_loc())
|
|
|
+ << "Function declaration has deduced return type but no body";
|
|
|
}
|
|
|
}
|
|
|
- if (!f->is_omitted_return_type()) {
|
|
|
- ExpectReturnOnAllPaths(body_stmt, f->source_loc());
|
|
|
+
|
|
|
+ if (f->body().has_value() && check_body) {
|
|
|
+ TypeCheckStmt(*f->body(), param_res.types, values);
|
|
|
+ if (!f->return_term().is_omitted()) {
|
|
|
+ ExpectReturnOnAllPaths(f->body(), f->source_loc());
|
|
|
+ }
|
|
|
}
|
|
|
- ExpectIsConcreteType(f->return_type().source_loc(), return_type);
|
|
|
+
|
|
|
+ ExpectIsConcreteType(f->source_loc(), &f->return_term().static_type());
|
|
|
SetStaticType(f, arena_->New<FunctionType>(f->deduced_parameters(),
|
|
|
&f->param_pattern().static_type(),
|
|
|
- return_type));
|
|
|
- return TCResult(types);
|
|
|
-}
|
|
|
-
|
|
|
-auto TypeChecker::TypeOfFunDef(TypeEnv types, Env values,
|
|
|
- FunctionDeclaration* fun_def)
|
|
|
- -> Nonnull<const Value*> {
|
|
|
- // Bring the deduced parameters into scope
|
|
|
- for (Nonnull<const GenericBinding*> deduced : fun_def->deduced_parameters()) {
|
|
|
- // auto t = interpreter_.InterpExp(values, deduced.type);
|
|
|
- types.Set(deduced->name(), arena_->New<VariableType>(deduced->name()));
|
|
|
- AllocationId a = interpreter_.AllocateValue(*types.Get(deduced->name()));
|
|
|
- values.Set(deduced->name(), a);
|
|
|
- }
|
|
|
- // Type check the parameter pattern
|
|
|
- TypeCheckPattern(&fun_def->param_pattern(), types, values, std::nullopt);
|
|
|
- // Evaluate the return type expression
|
|
|
- auto ret = interpreter_.InterpPattern(values, &fun_def->return_type());
|
|
|
- if (ret->kind() == Value::Kind::AutoType) {
|
|
|
- // FIXME do this unconditionally?
|
|
|
- TypeCheckFunDef(fun_def, types, values);
|
|
|
- return &fun_def->static_type();
|
|
|
+ &f->return_term().static_type()));
|
|
|
+ if (f->name() == "Main") {
|
|
|
+ if (!f->return_term().type_expression().has_value()) {
|
|
|
+ FATAL_COMPILATION_ERROR(f->return_term().source_loc())
|
|
|
+ << "`Main` must have an explicit return type";
|
|
|
+ }
|
|
|
+ ExpectExactType(f->return_term().source_loc(), "return type of `Main`",
|
|
|
+ arena_->New<IntType>(), &f->return_term().static_type());
|
|
|
+ // TODO: Check that main doesn't have any parameters.
|
|
|
}
|
|
|
- return arena_->New<FunctionType>(fun_def->deduced_parameters(),
|
|
|
- &fun_def->param_pattern().static_type(),
|
|
|
- ret);
|
|
|
+ return TCResult(types);
|
|
|
}
|
|
|
|
|
|
auto TypeChecker::TypeOfClassDecl(const ClassDeclaration& class_decl,
|
|
|
@@ -1151,7 +1123,8 @@ void TypeChecker::TypeCheckDeclaration(Nonnull<Declaration*> d,
|
|
|
const Env& values) {
|
|
|
switch (d->kind()) {
|
|
|
case Declaration::Kind::FunctionDeclaration:
|
|
|
- TypeCheckFunDef(&cast<FunctionDeclaration>(*d), types, values);
|
|
|
+ TypeCheckFunctionDeclaration(&cast<FunctionDeclaration>(*d), types,
|
|
|
+ values, /*check_body=*/true);
|
|
|
return;
|
|
|
case Declaration::Kind::ClassDeclaration:
|
|
|
// TODO
|
|
|
@@ -1188,8 +1161,9 @@ void TypeChecker::TopLevel(Nonnull<Declaration*> d, TypeCheckContext* tops) {
|
|
|
switch (d->kind()) {
|
|
|
case Declaration::Kind::FunctionDeclaration: {
|
|
|
auto& func_def = cast<FunctionDeclaration>(*d);
|
|
|
- auto t = TypeOfFunDef(tops->types, tops->values, &func_def);
|
|
|
- tops->types.Set(func_def.name(), t);
|
|
|
+ TypeCheckFunctionDeclaration(&func_def, tops->types, tops->values,
|
|
|
+ /*check_body=*/false);
|
|
|
+ tops->types.Set(func_def.name(), &func_def.static_type());
|
|
|
interpreter_.InitEnv(*d, &tops->values);
|
|
|
break;
|
|
|
}
|