Преглед изворни кода

Switch Statement to use inheritance+cast (#718)

Jon Meow пре 4 година
родитељ
комит
d5564280ab

+ 63 - 212
executable_semantics/ast/statement.cpp

@@ -6,196 +6,24 @@
 
 #include "common/check.h"
 #include "executable_semantics/common/arena.h"
+#include "llvm/Support/Casting.h"
 
 namespace Carbon {
 
-auto Statement::GetExpressionStatement() const -> const ExpressionStatement& {
-  return std::get<ExpressionStatement>(value);
-}
-
-auto Statement::GetAssign() const -> const Assign& {
-  return std::get<Assign>(value);
-}
-
-auto Statement::GetVariableDefinition() const -> const VariableDefinition& {
-  return std::get<VariableDefinition>(value);
-}
-
-auto Statement::GetIf() const -> const If& { return std::get<If>(value); }
-
-auto Statement::GetReturn() const -> const Return& {
-  return std::get<Return>(value);
-}
-
-auto Statement::GetSequence() const -> const Sequence& {
-  return std::get<Sequence>(value);
-}
-
-auto Statement::GetBlock() const -> const Block& {
-  return std::get<Block>(value);
-}
-
-auto Statement::GetWhile() const -> const While& {
-  return std::get<While>(value);
-}
-
-auto Statement::GetBreak() const -> const Break& {
-  return std::get<Break>(value);
-}
-
-auto Statement::GetContinue() const -> const Continue& {
-  return std::get<Continue>(value);
-}
-
-auto Statement::GetMatch() const -> const Match& {
-  return std::get<Match>(value);
-}
-
-auto Statement::GetContinuation() const -> const Continuation& {
-  return std::get<Continuation>(value);
-}
-
-auto Statement::GetRun() const -> const Run& { return std::get<Run>(value); }
-
-auto Statement::GetAwait() const -> const Await& {
-  return std::get<Await>(value);
-}
-
-auto Statement::MakeExpressionStatement(int line_num, const Expression* exp)
-    -> const Statement* {
-  auto* s = global_arena->New<Statement>();
-  s->line_num = line_num;
-  s->value = ExpressionStatement({.exp = exp});
-  return s;
-}
-
-auto Statement::MakeAssign(int line_num, const Expression* lhs,
-                           const Expression* rhs) -> const Statement* {
-  auto* s = global_arena->New<Statement>();
-  s->line_num = line_num;
-  s->value = Assign({.lhs = lhs, .rhs = rhs});
-  return s;
-}
-
-auto Statement::MakeVariableDefinition(int line_num, const Pattern* pat,
-                                       const Expression* init)
-    -> const Statement* {
-  auto* s = global_arena->New<Statement>();
-  s->line_num = line_num;
-  s->value = VariableDefinition({.pat = pat, .init = init});
-  return s;
-}
-
-auto Statement::MakeIf(int line_num, const Expression* cond,
-                       const Statement* then_stmt, const Statement* else_stmt)
-    -> const Statement* {
-  auto* s = global_arena->New<Statement>();
-  s->line_num = line_num;
-  s->value = If({.cond = cond, .then_stmt = then_stmt, .else_stmt = else_stmt});
-  return s;
-}
-
-auto Statement::MakeWhile(int line_num, const Expression* cond,
-                          const Statement* body) -> const Statement* {
-  auto* s = global_arena->New<Statement>();
-  s->line_num = line_num;
-  s->value = While({.cond = cond, .body = body});
-  return s;
-}
-
-auto Statement::MakeBreak(int line_num) -> const Statement* {
-  auto* s = global_arena->New<Statement>();
-  s->line_num = line_num;
-  s->value = Break();
-  return s;
-}
-
-auto Statement::MakeContinue(int line_num) -> const Statement* {
-  auto* s = global_arena->New<Statement>();
-  s->line_num = line_num;
-  s->value = Continue();
-  return s;
-}
-
-auto Statement::MakeReturn(int line_num, const Expression* exp,
-                           bool is_omitted_exp) -> const Statement* {
-  auto* s = global_arena->New<Statement>();
-  s->line_num = line_num;
-  if (exp == nullptr) {
-    CHECK(is_omitted_exp);
-    exp = global_arena->New<TupleLiteral>(line_num);
-  }
-  s->value = Return({.exp = exp, .is_omitted_exp = is_omitted_exp});
-  return s;
-}
-
-auto Statement::MakeSequence(int line_num, const Statement* s1,
-                             const Statement* s2) -> const Statement* {
-  auto* s = global_arena->New<Statement>();
-  s->line_num = line_num;
-  s->value = Sequence({.stmt = s1, .next = s2});
-  return s;
-}
-
-auto Statement::MakeBlock(int line_num, const Statement* stmt)
-    -> const Statement* {
-  auto* s = global_arena->New<Statement>();
-  s->line_num = line_num;
-  s->value = Block({.stmt = stmt});
-  return s;
-}
-
-auto Statement::MakeMatch(
-    int line_num, const Expression* exp,
-    std::list<std::pair<const Pattern*, const Statement*>>* clauses)
-    -> const Statement* {
-  auto* s = global_arena->New<Statement>();
-  s->line_num = line_num;
-  s->value = Match({.exp = exp, .clauses = clauses});
-  return s;
-}
-
-// Returns an AST node for a continuation statement give its line number and
-// parts.
-auto Statement::MakeContinuation(int line_num,
-                                 std::string continuation_variable,
-                                 const Statement* body) -> const Statement* {
-  auto* s = global_arena->New<Statement>();
-  s->line_num = line_num;
-  s->value =
-      Continuation({.continuation_variable = std::move(continuation_variable),
-                    .body = body});
-  return s;
-}
-
-// Returns an AST node for a run statement give its line number and argument.
-auto Statement::MakeRun(int line_num, const Expression* argument)
-    -> const Statement* {
-  auto* s = global_arena->New<Statement>();
-  s->line_num = line_num;
-  s->value = Run({.argument = argument});
-  return s;
-}
-
-// Returns an AST node for an await statement give its line number.
-auto Statement::MakeAwait(int line_num) -> const Statement* {
-  auto* s = global_arena->New<Statement>();
-  s->line_num = line_num;
-  s->value = Await();
-  return s;
-}
+using llvm::cast;
 
 void Statement::PrintDepth(int depth, llvm::raw_ostream& out) const {
   if (depth == 0) {
     out << " ... ";
     return;
   }
-  switch (tag()) {
-    case StatementKind::Match:
-      out << "match (" << *GetMatch().exp << ") {";
+  switch (Tag()) {
+    case Kind::Match: {
+      const auto& match = cast<Match>(*this);
+      out << "match (" << *match.Exp() << ") {";
       if (depth < 0 || depth > 1) {
         out << "\n";
-        for (auto& clause : *GetMatch().clauses) {
+        for (auto& clause : *match.Clauses()) {
           out << "case " << *clause.first << " =>\n";
           clause.second->PrintDepth(depth - 1, out);
           out << "\n";
@@ -205,59 +33,72 @@ void Statement::PrintDepth(int depth, llvm::raw_ostream& out) const {
       }
       out << "}";
       break;
-    case StatementKind::While:
-      out << "while (" << *GetWhile().cond << ")\n";
-      GetWhile().body->PrintDepth(depth - 1, out);
+    }
+    case Kind::While: {
+      const auto& while_stmt = cast<While>(*this);
+      out << "while (" << *while_stmt.Cond() << ")\n";
+      while_stmt.Body()->PrintDepth(depth - 1, out);
       break;
-    case StatementKind::Break:
+    }
+    case Kind::Break:
       out << "break;";
       break;
-    case StatementKind::Continue:
+    case Kind::Continue:
       out << "continue;";
       break;
-    case StatementKind::VariableDefinition:
-      out << "var " << *GetVariableDefinition().pat << " = "
-          << *GetVariableDefinition().init << ";";
+    case Kind::VariableDefinition: {
+      const auto& var = cast<VariableDefinition>(*this);
+      out << "var " << *var.Pat() << " = " << *var.Init() << ";";
       break;
-    case StatementKind::ExpressionStatement:
-      out << *GetExpressionStatement().exp << ";";
+    }
+    case Kind::ExpressionStatement:
+      out << *cast<ExpressionStatement>(*this).Exp() << ";";
       break;
-    case StatementKind::Assign:
-      out << *GetAssign().lhs << " = " << *GetAssign().rhs << ";";
+    case Kind::Assign: {
+      const auto& assign = cast<Assign>(*this);
+      out << *assign.Lhs() << " = " << *assign.Rhs() << ";";
       break;
-    case StatementKind::If:
-      out << "if (" << *GetIf().cond << ")\n";
-      GetIf().then_stmt->PrintDepth(depth - 1, out);
-      if (GetIf().else_stmt) {
+    }
+    case Kind::If: {
+      const auto& if_stmt = cast<If>(*this);
+      out << "if (" << *if_stmt.Cond() << ")\n";
+      if_stmt.ThenStmt()->PrintDepth(depth - 1, out);
+      if (if_stmt.ElseStmt()) {
         out << "\nelse\n";
-        GetIf().else_stmt->PrintDepth(depth - 1, out);
+        if_stmt.ElseStmt()->PrintDepth(depth - 1, out);
       }
       break;
-    case StatementKind::Return:
-      if (GetReturn().is_omitted_exp) {
+    }
+    case Kind::Return: {
+      const auto& ret = cast<Return>(*this);
+      if (ret.IsOmittedExp()) {
         out << "return;";
       } else {
-        out << "return " << *GetReturn().exp << ";";
+        out << "return " << *ret.Exp() << ";";
       }
       break;
-    case StatementKind::Sequence:
-      GetSequence().stmt->PrintDepth(depth, out);
+    }
+    case Kind::Sequence: {
+      const auto& seq = cast<Sequence>(*this);
+      seq.Stmt()->PrintDepth(depth, out);
       if (depth < 0 || depth > 1) {
         out << "\n";
       } else {
         out << " ";
       }
-      if (GetSequence().next) {
-        GetSequence().next->PrintDepth(depth - 1, out);
+      if (seq.Next()) {
+        seq.Next()->PrintDepth(depth - 1, out);
       }
       break;
-    case StatementKind::Block:
+    }
+    case Kind::Block: {
+      const auto& block = cast<Block>(*this);
       out << "{";
       if (depth < 0 || depth > 1) {
         out << "\n";
       }
-      if (GetBlock().stmt) {
-        GetBlock().stmt->PrintDepth(depth, out);
+      if (block.Stmt()) {
+        block.Stmt()->PrintDepth(depth, out);
         if (depth < 0 || depth > 1) {
           out << "\n";
         }
@@ -267,23 +108,33 @@ void Statement::PrintDepth(int depth, llvm::raw_ostream& out) const {
         out << "\n";
       }
       break;
-    case StatementKind::Continuation:
-      out << "continuation " << GetContinuation().continuation_variable << " ";
+    }
+    case Kind::Continuation: {
+      const auto& cont = cast<Continuation>(*this);
+      out << "continuation " << cont.ContinuationVariable() << " ";
       if (depth < 0 || depth > 1) {
         out << "\n";
       }
-      GetContinuation().body->PrintDepth(depth - 1, out);
+      cont.Body()->PrintDepth(depth - 1, out);
       if (depth < 0 || depth > 1) {
         out << "\n";
       }
       break;
-    case StatementKind::Run:
-      out << "run " << *GetRun().argument << ";";
+    }
+    case Kind::Run:
+      out << "run " << *cast<Run>(*this).Argument() << ";";
       break;
-    case StatementKind::Await:
+    case Kind::Await:
       out << "await;";
       break;
   }
 }
 
+Return::Return(int line_num, const Expression* exp, bool is_omitted_exp)
+    : Statement(Kind::Return, line_num),
+      exp(exp != nullptr ? exp : global_arena->New<TupleLiteral>(line_num)),
+      is_omitted_exp(is_omitted_exp) {
+  CHECK(exp != nullptr || is_omitted_exp);
+}
+
 }  // namespace Carbon

+ 220 - 117
executable_semantics/ast/statement.h

@@ -14,174 +14,277 @@
 
 namespace Carbon {
 
-enum class StatementKind {
-  ExpressionStatement,
-  Assign,
-  VariableDefinition,
-  If,
-  Return,
-  Sequence,
-  Block,
-  While,
-  Break,
-  Continue,
-  Match,
-  Continuation,  // Create a first-class continuation.
-  Run,           // Run a continuation to the next await or until it finishes..
-  Await,         // Pause execution of the continuation.
+class Statement {
+ public:
+  enum class Kind {
+    ExpressionStatement,
+    Assign,
+    VariableDefinition,
+    If,
+    Return,
+    Sequence,
+    Block,
+    While,
+    Break,
+    Continue,
+    Match,
+    Continuation,  // Create a first-class continuation.
+    Run,           // Run a continuation to the next await or until it finishes.
+    Await,         // Pause execution of the continuation.
+  };
+
+  // Returns the enumerator corresponding to the most-derived type of this
+  // object.
+  auto Tag() const -> Kind { return tag; }
+
+  auto LineNumber() const -> int { return line_num; }
+
+  void Print(llvm::raw_ostream& out) const { PrintDepth(-1, out); }
+  void PrintDepth(int depth, llvm::raw_ostream& out) const;
+  LLVM_DUMP_METHOD void Dump() const { Print(llvm::errs()); }
+
+ protected:
+  // Constructs an Statement representing syntax at the given line number.
+  // `tag` must be the enumerator corresponding to the most-derived type being
+  // constructed.
+  Statement(Kind tag, int line_num) : tag(tag), line_num(line_num) {}
+
+ private:
+  const Kind tag;
+  int line_num;
 };
 
-struct Statement;
+class ExpressionStatement : public Statement {
+ public:
+  ExpressionStatement(int line_num, const Expression* exp)
+      : Statement(Kind::ExpressionStatement, line_num), exp(exp) {}
 
-struct ExpressionStatement {
-  static constexpr StatementKind Kind = StatementKind::ExpressionStatement;
+  static auto classof(const Statement* stmt) -> bool {
+    return stmt->Tag() == Kind::ExpressionStatement;
+  }
+
+  auto Exp() const -> const Expression* { return exp; }
+
+ private:
   const Expression* exp;
 };
 
-struct Assign {
-  static constexpr StatementKind Kind = StatementKind::Assign;
+class Assign : public Statement {
+ public:
+  Assign(int line_num, const Expression* lhs, const Expression* rhs)
+      : Statement(Kind::Assign, line_num), lhs(lhs), rhs(rhs) {}
+
+  static auto classof(const Statement* stmt) -> bool {
+    return stmt->Tag() == Kind::Assign;
+  }
+
+  auto Lhs() const -> const Expression* { return lhs; }
+  auto Rhs() const -> const Expression* { return rhs; }
+
+ private:
   const Expression* lhs;
   const Expression* rhs;
 };
 
-struct VariableDefinition {
-  static constexpr StatementKind Kind = StatementKind::VariableDefinition;
+class VariableDefinition : public Statement {
+ public:
+  VariableDefinition(int line_num, const Pattern* pat, const Expression* init)
+      : Statement(Kind::VariableDefinition, line_num), pat(pat), init(init) {}
+
+  static auto classof(const Statement* stmt) -> bool {
+    return stmt->Tag() == Kind::VariableDefinition;
+  }
+
+  auto Pat() const -> const Pattern* { return pat; }
+  auto Init() const -> const Expression* { return init; }
+
+ private:
   const Pattern* pat;
   const Expression* init;
 };
 
-struct If {
-  static constexpr StatementKind Kind = StatementKind::If;
+class If : public Statement {
+ public:
+  If(int line_num, const Expression* cond, const Statement* then_stmt,
+     const Statement* else_stmt)
+      : Statement(Kind::If, line_num),
+        cond(cond),
+        then_stmt(then_stmt),
+        else_stmt(else_stmt) {}
+
+  static auto classof(const Statement* stmt) -> bool {
+    return stmt->Tag() == Kind::If;
+  }
+
+  auto Cond() const -> const Expression* { return cond; }
+  auto ThenStmt() const -> const Statement* { return then_stmt; }
+  auto ElseStmt() const -> const Statement* { return else_stmt; }
+
+ private:
   const Expression* cond;
   const Statement* then_stmt;
   const Statement* else_stmt;
 };
 
-struct Return {
-  static constexpr StatementKind Kind = StatementKind::Return;
+class Return : public Statement {
+ public:
+  Return(int line_num, const Expression* exp, bool is_omitted_exp);
+
+  static auto classof(const Statement* stmt) -> bool {
+    return stmt->Tag() == Kind::Return;
+  }
+
+  auto Exp() const -> const Expression* { return exp; }
+  auto IsOmittedExp() const -> bool { return is_omitted_exp; }
+
+ private:
   const Expression* exp;
   bool is_omitted_exp;
 };
 
-struct Sequence {
-  static constexpr StatementKind Kind = StatementKind::Sequence;
+class Sequence : public Statement {
+ public:
+  Sequence(int line_num, const Statement* stmt, const Statement* next)
+      : Statement(Kind::Sequence, line_num), stmt(stmt), next(next) {}
+
+  static auto classof(const Statement* stmt) -> bool {
+    return stmt->Tag() == Kind::Sequence;
+  }
+
+  auto Stmt() const -> const Statement* { return stmt; }
+  auto Next() const -> const Statement* { return next; }
+
+ private:
   const Statement* stmt;
   const Statement* next;
 };
 
-struct Block {
-  static constexpr StatementKind Kind = StatementKind::Block;
+class Block : public Statement {
+ public:
+  Block(int line_num, const Statement* stmt)
+      : Statement(Kind::Block, line_num), stmt(stmt) {}
+
+  static auto classof(const Statement* stmt) -> bool {
+    return stmt->Tag() == Kind::Block;
+  }
+
+  auto Stmt() const -> const Statement* { return stmt; }
+
+ private:
   const Statement* stmt;
 };
 
-struct While {
-  static constexpr StatementKind Kind = StatementKind::While;
+class While : public Statement {
+ public:
+  While(int line_num, const Expression* cond, const Statement* body)
+      : Statement(Kind::While, line_num), cond(cond), body(body) {}
+
+  static auto classof(const Statement* stmt) -> bool {
+    return stmt->Tag() == Kind::While;
+  }
+
+  auto Cond() const -> const Expression* { return cond; }
+  auto Body() const -> const Statement* { return body; }
+
+ private:
   const Expression* cond;
   const Statement* body;
 };
 
-struct Break {
-  static constexpr StatementKind Kind = StatementKind::Break;
+class Break : public Statement {
+ public:
+  explicit Break(int line_num) : Statement(Kind::Break, line_num) {}
+
+  static auto classof(const Statement* stmt) -> bool {
+    return stmt->Tag() == Kind::Break;
+  }
 };
 
-struct Continue {
-  static constexpr StatementKind Kind = StatementKind::Continue;
+class Continue : public Statement {
+ public:
+  explicit Continue(int line_num) : Statement(Kind::Continue, line_num) {}
+
+  static auto classof(const Statement* stmt) -> bool {
+    return stmt->Tag() == Kind::Continue;
+  }
 };
 
-struct Match {
-  static constexpr StatementKind Kind = StatementKind::Match;
+class Match : public Statement {
+ public:
+  Match(int line_num, const Expression* exp,
+        std::list<std::pair<const Pattern*, const Statement*>>* clauses)
+      : Statement(Kind::Match, line_num), exp(exp), clauses(clauses) {}
+
+  static auto classof(const Statement* stmt) -> bool {
+    return stmt->Tag() == Kind::Match;
+  }
+
+  auto Exp() const -> const Expression* { return exp; }
+  auto Clauses() const
+      -> const std::list<std::pair<const Pattern*, const Statement*>>* {
+    return clauses;
+  }
+
+ private:
   const Expression* exp;
   std::list<std::pair<const Pattern*, const Statement*>>* clauses;
 };
 
-struct Continuation {
-  static constexpr StatementKind Kind = StatementKind::Continuation;
-  std::string continuation_variable;
-  const Statement* body;
-};
+// A continuation statement.
+//
+//     __continuation <continuation_variable> {
+//       <body>
+//     }
+class Continuation : public Statement {
+ public:
+  Continuation(int line_num, std::string continuation_variable,
+               const Statement* body)
+      : Statement(Kind::Continuation, line_num),
+        continuation_variable(std::move(continuation_variable)),
+        body(body) {}
 
-struct Run {
-  static constexpr StatementKind Kind = StatementKind::Run;
-  const Expression* argument;
-};
+  static auto classof(const Statement* stmt) -> bool {
+    return stmt->Tag() == Kind::Continuation;
+  }
 
-struct Await {
-  static constexpr StatementKind Kind = StatementKind::Await;
-};
+  auto ContinuationVariable() const -> const std::string& {
+    return continuation_variable;
+  }
+  auto Body() const -> const Statement* { return body; }
 
-struct Statement {
-  // Constructors
-  static auto MakeExpressionStatement(int line_num, const Expression* exp)
-      -> const Statement*;
-  static auto MakeAssign(int line_num, const Expression* lhs,
-                         const Expression* rhs) -> const Statement*;
-  static auto MakeVariableDefinition(int line_num, const Pattern* pat,
-                                     const Expression* init)
-      -> const Statement*;
-  static auto MakeIf(int line_num, const Expression* cond,
-                     const Statement* then_stmt, const Statement* else_stmt)
-      -> const Statement*;
-  static auto MakeReturn(int line_num, const Expression* exp,
-                         bool is_omitted_exp) -> const Statement*;
-  static auto MakeSequence(int line_num, const Statement* s1,
-                           const Statement* s2) -> const Statement*;
-  static auto MakeBlock(int line_num, const Statement* s) -> const Statement*;
-  static auto MakeWhile(int line_num, const Expression* cond,
-                        const Statement* body) -> const Statement*;
-  static auto MakeBreak(int line_num) -> const Statement*;
-  static auto MakeContinue(int line_num) -> const Statement*;
-  static auto MakeMatch(
-      int line_num, const Expression* exp,
-      std::list<std::pair<const Pattern*, const Statement*>>* clauses)
-      -> const Statement*;
-  // Returns an AST node for a continuation statement give its line number and
-  // contituent parts.
-  //
-  //     __continuation <continuation_variable> {
-  //       <body>
-  //     }
-  static auto MakeContinuation(int line_num, std::string continuation_variable,
-                               const Statement* body) -> const Statement*;
-  // Returns an AST node for a run statement give its line number and argument.
-  //
-  //     __run <argument>;
-  static auto MakeRun(int line_num, const Expression* argument)
-      -> const Statement*;
-  // Returns an AST node for an await statement give its line number.
-  //
-  //     __await;
-  static auto MakeAwait(int line_num) -> const Statement*;
-
-  auto GetExpressionStatement() const -> const ExpressionStatement&;
-  auto GetAssign() const -> const Assign&;
-  auto GetVariableDefinition() const -> const VariableDefinition&;
-  auto GetIf() const -> const If&;
-  auto GetReturn() const -> const Return&;
-  auto GetSequence() const -> const Sequence&;
-  auto GetBlock() const -> const Block&;
-  auto GetWhile() const -> const While&;
-  auto GetBreak() const -> const Break&;
-  auto GetContinue() const -> const Continue&;
-  auto GetMatch() const -> const Match&;
-  auto GetContinuation() const -> const Continuation&;
-  auto GetRun() const -> const Run&;
-  auto GetAwait() const -> const Await&;
+ private:
+  std::string continuation_variable;
+  const Statement* body;
+};
 
-  void Print(llvm::raw_ostream& out) const { PrintDepth(-1, out); }
-  void PrintDepth(int depth, llvm::raw_ostream& out) const;
-  LLVM_DUMP_METHOD void Dump() const { Print(llvm::errs()); }
+// A run statement.
+//
+//     __run <argument>;
+class Run : public Statement {
+ public:
+  Run(int line_num, const Expression* argument)
+      : Statement(Kind::Run, line_num), argument(argument) {}
 
-  inline auto tag() const -> StatementKind {
-    return std::visit([](const auto& t) { return t.Kind; }, value);
+  static auto classof(const Statement* stmt) -> bool {
+    return stmt->Tag() == Kind::Run;
   }
 
-  int line_num;
+  auto Argument() const -> const Expression* { return argument; }
 
  private:
-  std::variant<ExpressionStatement, Assign, VariableDefinition, If, Return,
-               Sequence, Block, While, Break, Continue, Match, Continuation,
-               Run, Await>
-      value;
+  const Expression* argument;
+};
+
+// An await statement.
+//
+//    __await;
+class Await : public Statement {
+ public:
+  explicit Await(int line_num) : Statement(Kind::Await, line_num) {}
+
+  static auto classof(const Statement* stmt) -> bool {
+    return stmt->Tag() == Kind::Await;
+  }
 };
 
 }  // namespace Carbon

+ 62 - 59
executable_semantics/interpreter/interpreter.cpp

@@ -805,8 +805,8 @@ void StepPattern() {
 auto IsWhileAct(Action* act) -> bool {
   switch (act->Tag()) {
     case Action::Kind::StatementAction:
-      switch (cast<StatementAction>(*act).Stmt()->tag()) {
-        case StatementKind::While:
+      switch (cast<StatementAction>(*act).Stmt()->Tag()) {
+        case Statement::Kind::While:
           return true;
         default:
           return false;
@@ -819,8 +819,8 @@ auto IsWhileAct(Action* act) -> bool {
 auto IsBlockAct(Action* act) -> bool {
   switch (act->Tag()) {
     case Action::Kind::StatementAction:
-      switch (cast<StatementAction>(*act).Stmt()->tag()) {
-        case StatementKind::Block:
+      switch (cast<StatementAction>(*act).Stmt()->Tag()) {
+        case Statement::Kind::Block:
           return true;
         default:
           return false;
@@ -842,13 +842,13 @@ void StepStmt() {
     stmt->PrintDepth(1, llvm::outs());
     llvm::outs() << " --->\n";
   }
-  switch (stmt->tag()) {
-    case StatementKind::Match:
+  switch (stmt->Tag()) {
+    case Statement::Kind::Match:
       if (act->Pos() == 0) {
         //    { { (match (e) ...) :: C, E, F} :: S, H}
         // -> { { e :: (match ([]) ...) :: C, E, F} :: S, H}
         frame->todo.Push(
-            global_arena->New<ExpressionAction>(stmt->GetMatch().exp));
+            global_arena->New<ExpressionAction>(cast<Match>(*stmt).Exp()));
         act->IncrementPos();
       } else {
         // Regarding act->Pos():
@@ -861,11 +861,12 @@ void StepStmt() {
         // * 2: the pattern for clause 1
         // * ...
         auto clause_num = (act->Pos() - 1) / 2;
-        if (clause_num >= static_cast<int>(stmt->GetMatch().clauses->size())) {
+        if (clause_num >=
+            static_cast<int>(cast<Match>(*stmt).Clauses()->size())) {
           frame->todo.Pop(1);
           break;
         }
-        auto c = stmt->GetMatch().clauses->begin();
+        auto c = cast<Match>(*stmt).Clauses()->begin();
         std::advance(c, clause_num);
 
         if (act->Pos() % 2 == 1) {
@@ -880,12 +881,12 @@ void StepStmt() {
           auto values = CurrentEnv(state);
           std::list<std::string> vars;
           std::optional<Env> matches =
-              PatternMatch(pat, v, values, &vars, stmt->line_num);
+              PatternMatch(pat, v, values, &vars, stmt->LineNumber());
           if (matches) {  // we have a match, start the body
             auto* new_scope = global_arena->New<Scope>(*matches, vars);
             frame->scopes.Push(new_scope);
             const Statement* body_block =
-                Statement::MakeBlock(stmt->line_num, c->second);
+                global_arena->New<Block>(stmt->LineNumber(), c->second);
             Action* body_act = global_arena->New<StatementAction>(body_block);
             body_act->IncrementPos();
             frame->todo.Pop(1);
@@ -896,26 +897,26 @@ void StepStmt() {
             act->IncrementPos();
             clause_num = (act->Pos() - 1) / 2;
             if (clause_num ==
-                static_cast<int>(stmt->GetMatch().clauses->size())) {
+                static_cast<int>(cast<Match>(*stmt).Clauses()->size())) {
               frame->todo.Pop(1);
             }
           }
         }
       }
       break;
-    case StatementKind::While:
+    case Statement::Kind::While:
       if (act->Pos() == 0) {
         //    { { (while (e) s) :: C, E, F} :: S, H}
         // -> { { e :: (while ([]) s) :: C, E, F} :: S, H}
         frame->todo.Push(
-            global_arena->New<ExpressionAction>(stmt->GetWhile().cond));
+            global_arena->New<ExpressionAction>(cast<While>(*stmt).Cond()));
         act->IncrementPos();
       } else if (cast<BoolValue>(*act->Results()[0]).Val()) {
         //    { {true :: (while ([]) s) :: C, E, F} :: S, H}
         // -> { { s :: (while (e) s) :: C, E, F } :: S, H}
         frame->todo.Top()->Clear();
         frame->todo.Push(
-            global_arena->New<StatementAction>(stmt->GetWhile().body));
+            global_arena->New<StatementAction>(cast<While>(*stmt).Body()));
       } else {
         //    { {false :: (while ([]) s) :: C, E, F} :: S, H}
         // -> { { C, E, F } :: S, H}
@@ -923,41 +924,41 @@ void StepStmt() {
         frame->todo.Pop(1);
       }
       break;
-    case StatementKind::Break:
+    case Statement::Kind::Break:
       CHECK(act->Pos() == 0);
       //    { { break; :: ... :: (while (e) s) :: C, E, F} :: S, H}
       // -> { { C, E', F} :: S, H}
       frame->todo.Pop(1);
       while (!frame->todo.IsEmpty() && !IsWhileAct(frame->todo.Top())) {
         if (IsBlockAct(frame->todo.Top())) {
-          DeallocateScope(stmt->line_num, frame->scopes.Top());
+          DeallocateScope(stmt->LineNumber(), frame->scopes.Top());
           frame->scopes.Pop(1);
         }
         frame->todo.Pop(1);
       }
       frame->todo.Pop(1);
       break;
-    case StatementKind::Continue:
+    case Statement::Kind::Continue:
       CHECK(act->Pos() == 0);
       //    { { continue; :: ... :: (while (e) s) :: C, E, F} :: S, H}
       // -> { { (while (e) s) :: C, E', F} :: S, H}
       frame->todo.Pop(1);
       while (!frame->todo.IsEmpty() && !IsWhileAct(frame->todo.Top())) {
         if (IsBlockAct(frame->todo.Top())) {
-          DeallocateScope(stmt->line_num, frame->scopes.Top());
+          DeallocateScope(stmt->LineNumber(), frame->scopes.Top());
           frame->scopes.Pop(1);
         }
         frame->todo.Pop(1);
       }
       break;
-    case StatementKind::Block: {
+    case Statement::Kind::Block: {
       if (act->Pos() == 0) {
-        if (stmt->GetBlock().stmt) {
+        if (cast<Block>(*stmt).Stmt()) {
           auto* scope = global_arena->New<Scope>(CurrentEnv(state),
                                                  std::list<std::string>());
           frame->scopes.Push(scope);
           frame->todo.Push(
-              global_arena->New<StatementAction>(stmt->GetBlock().stmt));
+              global_arena->New<StatementAction>(cast<Block>(*stmt).Stmt()));
           act->IncrementPos();
           act->IncrementPos();
         } else {
@@ -965,22 +966,22 @@ void StepStmt() {
         }
       } else {
         Scope* scope = frame->scopes.Top();
-        DeallocateScope(stmt->line_num, scope);
+        DeallocateScope(stmt->LineNumber(), scope);
         frame->scopes.Pop(1);
         frame->todo.Pop(1);
       }
       break;
     }
-    case StatementKind::VariableDefinition:
+    case Statement::Kind::VariableDefinition:
       if (act->Pos() == 0) {
         //    { {(var x = e) :: C, E, F} :: S, H}
         // -> { {e :: (var x = []) :: C, E, F} :: S, H}
         frame->todo.Push(global_arena->New<ExpressionAction>(
-            stmt->GetVariableDefinition().init));
+            cast<VariableDefinition>(*stmt).Init()));
         act->IncrementPos();
       } else if (act->Pos() == 1) {
         frame->todo.Push(global_arena->New<PatternAction>(
-            stmt->GetVariableDefinition().pat));
+            cast<VariableDefinition>(*stmt).Pat()));
         act->IncrementPos();
       } else if (act->Pos() == 2) {
         //    { { v :: (x = []) :: C, E, F} :: S, H}
@@ -990,52 +991,53 @@ void StepStmt() {
 
         std::optional<Env> matches =
             PatternMatch(p, v, frame->scopes.Top()->values,
-                         &frame->scopes.Top()->locals, stmt->line_num);
+                         &frame->scopes.Top()->locals, stmt->LineNumber());
         CHECK(matches)
-            << stmt->line_num
+            << stmt->LineNumber()
             << ": internal error in variable definition, match failed";
         frame->scopes.Top()->values = *matches;
         frame->todo.Pop(1);
       }
       break;
-    case StatementKind::ExpressionStatement:
+    case Statement::Kind::ExpressionStatement:
       if (act->Pos() == 0) {
         //    { {e :: C, E, F} :: S, H}
         // -> { {e :: C, E, F} :: S, H}
         frame->todo.Push(global_arena->New<ExpressionAction>(
-            stmt->GetExpressionStatement().exp));
+            cast<ExpressionStatement>(*stmt).Exp()));
         act->IncrementPos();
       } else {
         frame->todo.Pop(1);
       }
       break;
-    case StatementKind::Assign:
+    case Statement::Kind::Assign:
       if (act->Pos() == 0) {
         //    { {(lv = e) :: C, E, F} :: S, H}
         // -> { {lv :: ([] = e) :: C, E, F} :: S, H}
-        frame->todo.Push(global_arena->New<LValAction>(stmt->GetAssign().lhs));
+        frame->todo.Push(
+            global_arena->New<LValAction>(cast<Assign>(*stmt).Lhs()));
         act->IncrementPos();
       } else if (act->Pos() == 1) {
         //    { { a :: ([] = e) :: C, E, F} :: S, H}
         // -> { { e :: (a = []) :: C, E, F} :: S, H}
         frame->todo.Push(
-            global_arena->New<ExpressionAction>(stmt->GetAssign().rhs));
+            global_arena->New<ExpressionAction>(cast<Assign>(*stmt).Rhs()));
         act->IncrementPos();
       } else if (act->Pos() == 2) {
         //    { { v :: (a = []) :: C, E, F} :: S, H}
         // -> { { C, E, F} :: S, H(a := v)}
         auto pat = act->Results()[0];
         auto val = act->Results()[1];
-        PatternAssignment(pat, val, stmt->line_num);
+        PatternAssignment(pat, val, stmt->LineNumber());
         frame->todo.Pop(1);
       }
       break;
-    case StatementKind::If:
+    case Statement::Kind::If:
       if (act->Pos() == 0) {
         //    { {(if (e) then_stmt else else_stmt) :: C, E, F} :: S, H}
         // -> { { e :: (if ([]) then_stmt else else_stmt) :: C, E, F} :: S, H}
         frame->todo.Push(
-            global_arena->New<ExpressionAction>(stmt->GetIf().cond));
+            global_arena->New<ExpressionAction>(cast<If>(*stmt).Cond()));
         act->IncrementPos();
       } else if (cast<BoolValue>(*act->Results()[0]).Val()) {
         //    { {true :: if ([]) then_stmt else else_stmt :: C, E, F} ::
@@ -1043,48 +1045,48 @@ void StepStmt() {
         // -> { { then_stmt :: C, E, F } :: S, H}
         frame->todo.Pop(1);
         frame->todo.Push(
-            global_arena->New<StatementAction>(stmt->GetIf().then_stmt));
-      } else if (stmt->GetIf().else_stmt) {
+            global_arena->New<StatementAction>(cast<If>(*stmt).ThenStmt()));
+      } else if (cast<If>(*stmt).ElseStmt()) {
         //    { {false :: if ([]) then_stmt else else_stmt :: C, E, F} ::
         //      S, H}
         // -> { { else_stmt :: C, E, F } :: S, H}
         frame->todo.Pop(1);
         frame->todo.Push(
-            global_arena->New<StatementAction>(stmt->GetIf().else_stmt));
+            global_arena->New<StatementAction>(cast<If>(*stmt).ElseStmt()));
       } else {
         frame->todo.Pop(1);
       }
       break;
-    case StatementKind::Return:
+    case Statement::Kind::Return:
       if (act->Pos() == 0) {
         //    { {return e :: C, E, F} :: S, H}
         // -> { {e :: return [] :: C, E, F} :: S, H}
         frame->todo.Push(
-            global_arena->New<ExpressionAction>(stmt->GetReturn().exp));
+            global_arena->New<ExpressionAction>(cast<Return>(*stmt).Exp()));
         act->IncrementPos();
       } else {
         //    { {v :: return [] :: C, E, F} :: {C', E', F'} :: S, H}
         // -> { {v :: C', E', F'} :: S, H}
-        const Value* ret_val = CopyVal(act->Results()[0], stmt->line_num);
-        DeallocateLocals(stmt->line_num, frame);
+        const Value* ret_val = CopyVal(act->Results()[0], stmt->LineNumber());
+        DeallocateLocals(stmt->LineNumber(), frame);
         state->stack.Pop(1);
         frame = state->stack.Top();
         frame->todo.Push(global_arena->New<ValAction>(ret_val));
       }
       break;
-    case StatementKind::Sequence:
+    case Statement::Kind::Sequence:
       CHECK(act->Pos() == 0);
       //    { { (s1,s2) :: C, E, F} :: S, H}
       // -> { { s1 :: s2 :: C, E, F} :: S, H}
       frame->todo.Pop(1);
-      if (stmt->GetSequence().next) {
+      if (cast<Sequence>(*stmt).Next()) {
         frame->todo.Push(
-            global_arena->New<StatementAction>(stmt->GetSequence().next));
+            global_arena->New<StatementAction>(cast<Sequence>(*stmt).Next()));
       }
       frame->todo.Push(
-          global_arena->New<StatementAction>(stmt->GetSequence().stmt));
+          global_arena->New<StatementAction>(cast<Sequence>(*stmt).Stmt()));
       break;
-    case StatementKind::Continuation: {
+    case Statement::Kind::Continuation: {
       CHECK(act->Pos() == 0);
       // Create a continuation object by creating a frame similar the
       // way one is created in a function call.
@@ -1094,10 +1096,10 @@ void StepStmt() {
       scopes.Push(scope);
       Stack<Action*> todo;
       todo.Push(global_arena->New<StatementAction>(
-          Statement::MakeReturn(stmt->line_num, nullptr,
-                                /*is_omitted_exp=*/true)));
+          global_arena->New<Return>(stmt->LineNumber(), nullptr,
+                                    /*is_omitted_exp=*/true)));
       todo.Push(
-          global_arena->New<StatementAction>(stmt->GetContinuation().body));
+          global_arena->New<StatementAction>(cast<Continuation>(*stmt).Body()));
       Frame* continuation_frame =
           global_arena->New<Frame>("__continuation", scopes, todo);
       Address continuation_address =
@@ -1107,25 +1109,26 @@ void StepStmt() {
       continuation_frame->continuation = continuation_address;
       // Bind the continuation object to the continuation variable
       frame->scopes.Top()->values.Set(
-          stmt->GetContinuation().continuation_variable, continuation_address);
+          cast<Continuation>(*stmt).ContinuationVariable(),
+          continuation_address);
       // Pop the continuation statement.
       frame->todo.Pop();
       break;
     }
-    case StatementKind::Run:
+    case Statement::Kind::Run:
       if (act->Pos() == 0) {
         // Evaluate the argument of the run statement.
         frame->todo.Push(
-            global_arena->New<ExpressionAction>(stmt->GetRun().argument));
+            global_arena->New<ExpressionAction>(cast<Run>(*stmt).Argument()));
         act->IncrementPos();
       } else {
         frame->todo.Pop(1);
         // Push an expression statement action to ignore the result
         // value from the continuation.
         Action* ignore_result = global_arena->New<StatementAction>(
-            Statement::MakeExpressionStatement(
-                stmt->line_num,
-                global_arena->New<TupleLiteral>(stmt->line_num)));
+            global_arena->New<ExpressionStatement>(
+                stmt->LineNumber(),
+                global_arena->New<TupleLiteral>(stmt->LineNumber())));
         frame->todo.Push(ignore_result);
         // Push the continuation onto the current stack.
         const std::vector<Frame*>& continuation_vector =
@@ -1136,7 +1139,7 @@ void StepStmt() {
         }
       }
       break;
-    case StatementKind::Await:
+    case Statement::Kind::Await:
       CHECK(act->Pos() == 0);
       // Pause the current continuation
       frame->todo.Pop();
@@ -1147,7 +1150,7 @@ void StepStmt() {
       // Update the continuation with the paused stack.
       state->heap.Write(*paused.back()->continuation,
                         global_arena->New<ContinuationValue>(paused),
-                        stmt->line_num);
+                        stmt->LineNumber());
       break;
   }
 }

+ 127 - 116
executable_semantics/interpreter/typecheck.cpp

@@ -641,128 +641,135 @@ auto TypeCheckStmt(const Statement* s, TypeEnv types, Env values,
   if (!s) {
     return TCStatement(s, types);
   }
-  switch (s->tag()) {
-    case StatementKind::Match: {
-      auto res = TypeCheckExp(s->GetMatch().exp, types, values);
+  switch (s->Tag()) {
+    case Statement::Kind::Match: {
+      const auto& match = cast<Match>(*s);
+      auto res = TypeCheckExp(match.Exp(), types, values);
       auto res_type = res.type;
       auto new_clauses =
           global_arena
               ->New<std::list<std::pair<const Pattern*, const Statement*>>>();
-      for (auto& clause : *s->GetMatch().clauses) {
+      for (auto& clause : *match.Clauses()) {
         new_clauses->push_back(TypecheckCase(res_type, clause.first,
                                              clause.second, types, values,
                                              ret_type, is_omitted_ret_type));
       }
       const Statement* new_s =
-          Statement::MakeMatch(s->line_num, res.exp, new_clauses);
+          global_arena->New<Match>(s->LineNumber(), res.exp, new_clauses);
       return TCStatement(new_s, types);
     }
-    case StatementKind::While: {
-      auto cnd_res = TypeCheckExp(s->GetWhile().cond, types, values);
-      ExpectType(s->line_num, "condition of `while`",
+    case Statement::Kind::While: {
+      const auto& while_stmt = cast<While>(*s);
+      auto cnd_res = TypeCheckExp(while_stmt.Cond(), types, values);
+      ExpectType(s->LineNumber(), "condition of `while`",
                  global_arena->New<BoolType>(), cnd_res.type);
-      auto body_res = TypeCheckStmt(s->GetWhile().body, types, values, ret_type,
+      auto body_res = TypeCheckStmt(while_stmt.Body(), types, values, ret_type,
                                     is_omitted_ret_type);
       auto new_s =
-          Statement::MakeWhile(s->line_num, cnd_res.exp, body_res.stmt);
+          global_arena->New<While>(s->LineNumber(), cnd_res.exp, body_res.stmt);
       return TCStatement(new_s, types);
     }
-    case StatementKind::Break:
-    case StatementKind::Continue:
+    case Statement::Kind::Break:
+    case Statement::Kind::Continue:
       return TCStatement(s, types);
-    case StatementKind::Block: {
-      auto stmt_res = TypeCheckStmt(s->GetBlock().stmt, types, values, ret_type,
-                                    is_omitted_ret_type);
-      return TCStatement(Statement::MakeBlock(s->line_num, stmt_res.stmt),
-                         types);
+    case Statement::Kind::Block: {
+      auto stmt_res = TypeCheckStmt(cast<Block>(*s).Stmt(), types, values,
+                                    ret_type, is_omitted_ret_type);
+      return TCStatement(
+          global_arena->New<Block>(s->LineNumber(), stmt_res.stmt), types);
     }
-    case StatementKind::VariableDefinition: {
-      auto res = TypeCheckExp(s->GetVariableDefinition().init, types, values);
+    case Statement::Kind::VariableDefinition: {
+      const auto& var = cast<VariableDefinition>(*s);
+      auto res = TypeCheckExp(var.Init(), types, values);
       const Value* rhs_ty = res.type;
-      auto lhs_res = TypeCheckPattern(s->GetVariableDefinition().pat, types,
-                                      values, rhs_ty);
-      const Statement* new_s = Statement::MakeVariableDefinition(
-          s->line_num, s->GetVariableDefinition().pat, res.exp);
+      auto lhs_res = TypeCheckPattern(var.Pat(), types, values, rhs_ty);
+      const Statement* new_s = global_arena->New<VariableDefinition>(
+          s->LineNumber(), var.Pat(), res.exp);
       return TCStatement(new_s, lhs_res.types);
     }
-    case StatementKind::Sequence: {
-      auto stmt_res = TypeCheckStmt(s->GetSequence().stmt, types, values,
-                                    ret_type, is_omitted_ret_type);
+    case Statement::Kind::Sequence: {
+      const auto& seq = cast<Sequence>(*s);
+      auto stmt_res = TypeCheckStmt(seq.Stmt(), types, values, ret_type,
+                                    is_omitted_ret_type);
       auto types2 = stmt_res.types;
-      auto next_res = TypeCheckStmt(s->GetSequence().next, types2, values,
-                                    ret_type, is_omitted_ret_type);
+      auto next_res = TypeCheckStmt(seq.Next(), types2, values, ret_type,
+                                    is_omitted_ret_type);
       auto types3 = next_res.types;
-      return TCStatement(
-          Statement::MakeSequence(s->line_num, stmt_res.stmt, next_res.stmt),
-          types3);
+      return TCStatement(global_arena->New<Sequence>(
+                             s->LineNumber(), stmt_res.stmt, next_res.stmt),
+                         types3);
     }
-    case StatementKind::Assign: {
-      auto rhs_res = TypeCheckExp(s->GetAssign().rhs, types, values);
+    case Statement::Kind::Assign: {
+      const auto& assign = cast<Assign>(*s);
+      auto rhs_res = TypeCheckExp(assign.Rhs(), types, values);
       auto rhs_t = rhs_res.type;
-      auto lhs_res = TypeCheckExp(s->GetAssign().lhs, types, values);
+      auto lhs_res = TypeCheckExp(assign.Lhs(), types, values);
       auto lhs_t = lhs_res.type;
-      ExpectType(s->line_num, "assign", lhs_t, rhs_t);
-      auto new_s = Statement::MakeAssign(s->line_num, lhs_res.exp, rhs_res.exp);
+      ExpectType(s->LineNumber(), "assign", lhs_t, rhs_t);
+      auto new_s =
+          global_arena->New<Assign>(s->LineNumber(), lhs_res.exp, rhs_res.exp);
       return TCStatement(new_s, lhs_res.types);
     }
-    case StatementKind::ExpressionStatement: {
-      auto res = TypeCheckExp(s->GetExpressionStatement().exp, types, values);
-      auto new_s = Statement::MakeExpressionStatement(s->line_num, res.exp);
+    case Statement::Kind::ExpressionStatement: {
+      auto res =
+          TypeCheckExp(cast<ExpressionStatement>(*s).Exp(), types, values);
+      auto new_s =
+          global_arena->New<ExpressionStatement>(s->LineNumber(), res.exp);
       return TCStatement(new_s, types);
     }
-    case StatementKind::If: {
-      auto cnd_res = TypeCheckExp(s->GetIf().cond, types, values);
-      ExpectType(s->line_num, "condition of `if`",
+    case Statement::Kind::If: {
+      const auto& if_stmt = cast<If>(*s);
+      auto cnd_res = TypeCheckExp(if_stmt.Cond(), types, values);
+      ExpectType(s->LineNumber(), "condition of `if`",
                  global_arena->New<BoolType>(), cnd_res.type);
-      auto thn_res = TypeCheckStmt(s->GetIf().then_stmt, types, values,
-                                   ret_type, is_omitted_ret_type);
-      auto els_res = TypeCheckStmt(s->GetIf().else_stmt, types, values,
-                                   ret_type, is_omitted_ret_type);
-      auto new_s = Statement::MakeIf(s->line_num, cnd_res.exp, thn_res.stmt,
-                                     els_res.stmt);
+      auto then_res = TypeCheckStmt(if_stmt.ThenStmt(), types, values, ret_type,
+                                    is_omitted_ret_type);
+      auto else_res = TypeCheckStmt(if_stmt.ElseStmt(), types, values, ret_type,
+                                    is_omitted_ret_type);
+      auto new_s = global_arena->New<If>(s->LineNumber(), cnd_res.exp,
+                                         then_res.stmt, else_res.stmt);
       return TCStatement(new_s, types);
     }
-    case StatementKind::Return: {
-      const auto& ret = s->GetReturn();
-      auto res = TypeCheckExp(ret.exp, types, values);
+    case Statement::Kind::Return: {
+      const auto& ret = cast<Return>(*s);
+      auto res = TypeCheckExp(ret.Exp(), types, values);
       if (ret_type->Tag() == Value::Kind::AutoType) {
         // The following infers the return type from the first 'return'
         // statement. This will get more difficult with subtyping, when we
         // should infer the least-upper bound of all the 'return' statements.
         ret_type = res.type;
       } else {
-        ExpectType(s->line_num, "return", ret_type, res.type);
+        ExpectType(s->LineNumber(), "return", ret_type, res.type);
       }
-      if (ret.is_omitted_exp != is_omitted_ret_type) {
-        FATAL_COMPILATION_ERROR(s->line_num)
+      if (ret.IsOmittedExp() != is_omitted_ret_type) {
+        FATAL_COMPILATION_ERROR(s->LineNumber())
             << *s << " should" << (is_omitted_ret_type ? " not" : "")
             << " provide a return value, to match the function's signature.";
       }
-      return TCStatement(
-          Statement::MakeReturn(s->line_num, res.exp, ret.is_omitted_exp),
-          types);
+      return TCStatement(global_arena->New<Return>(s->LineNumber(), res.exp,
+                                                   ret.IsOmittedExp()),
+                         types);
     }
-    case StatementKind::Continuation: {
-      TCStatement body_result =
-          TypeCheckStmt(s->GetContinuation().body, types, values, ret_type,
-                        is_omitted_ret_type);
-      const Statement* new_continuation = Statement::MakeContinuation(
-          s->line_num, s->GetContinuation().continuation_variable,
-          body_result.stmt);
-      types.Set(s->GetContinuation().continuation_variable,
+    case Statement::Kind::Continuation: {
+      const auto& cont = cast<Continuation>(*s);
+      TCStatement body_result = TypeCheckStmt(cont.Body(), types, values,
+                                              ret_type, is_omitted_ret_type);
+      const Statement* new_continuation = global_arena->New<Continuation>(
+          s->LineNumber(), cont.ContinuationVariable(), body_result.stmt);
+      types.Set(cont.ContinuationVariable(),
                 global_arena->New<ContinuationType>());
       return TCStatement(new_continuation, types);
     }
-    case StatementKind::Run: {
+    case Statement::Kind::Run: {
       TCExpression argument_result =
-          TypeCheckExp(s->GetRun().argument, types, values);
-      ExpectType(s->line_num, "argument of `run`",
+          TypeCheckExp(cast<Run>(*s).Argument(), types, values);
+      ExpectType(s->LineNumber(), "argument of `run`",
                  global_arena->New<ContinuationType>(), argument_result.type);
       const Statement* new_run =
-          Statement::MakeRun(s->line_num, argument_result.exp);
+          global_arena->New<Run>(s->LineNumber(), argument_result.exp);
       return TCStatement(new_run, types);
     }
-    case StatementKind::Await: {
+    case Statement::Kind::Await: {
       // nothing to do here
       return TCStatement(s, types);
     }
@@ -773,69 +780,73 @@ static auto CheckOrEnsureReturn(const Statement* stmt, bool omitted_ret_type,
                                 int line_num) -> const Statement* {
   if (!stmt) {
     if (omitted_ret_type) {
-      return Statement::MakeReturn(line_num, nullptr,
-                                   /*is_omitted_exp=*/true);
+      return global_arena->New<Return>(line_num, nullptr,
+                                       /*is_omitted_exp=*/true);
     } else {
       FATAL_COMPILATION_ERROR(line_num)
           << "control-flow reaches end of function that provides a `->` return "
              "type without reaching a return statement";
     }
   }
-  switch (stmt->tag()) {
-    case StatementKind::Match: {
+  switch (stmt->Tag()) {
+    case Statement::Kind::Match: {
+      const auto& match = cast<Match>(*stmt);
       auto new_clauses =
           global_arena
               ->New<std::list<std::pair<const Pattern*, const Statement*>>>();
-      for (auto i = stmt->GetMatch().clauses->begin();
-           i != stmt->GetMatch().clauses->end(); ++i) {
-        auto s =
-            CheckOrEnsureReturn(i->second, omitted_ret_type, stmt->line_num);
-        new_clauses->push_back(std::make_pair(i->first, s));
+      for (const auto& clause : *match.Clauses()) {
+        auto s = CheckOrEnsureReturn(clause.second, omitted_ret_type,
+                                     stmt->LineNumber());
+        new_clauses->push_back(std::make_pair(clause.first, s));
       }
-      return Statement::MakeMatch(stmt->line_num, stmt->GetMatch().exp,
-                                  new_clauses);
+      return global_arena->New<Match>(stmt->LineNumber(), match.Exp(),
+                                      new_clauses);
     }
-    case StatementKind::Block:
-      return Statement::MakeBlock(
-          stmt->line_num,
-          CheckOrEnsureReturn(stmt->GetBlock().stmt, omitted_ret_type,
-                              stmt->line_num));
-    case StatementKind::If:
-      return Statement::MakeIf(
-          stmt->line_num, stmt->GetIf().cond,
-          CheckOrEnsureReturn(stmt->GetIf().then_stmt, omitted_ret_type,
-                              stmt->line_num),
-          CheckOrEnsureReturn(stmt->GetIf().else_stmt, omitted_ret_type,
-                              stmt->line_num));
-    case StatementKind::Return:
+    case Statement::Kind::Block:
+      return global_arena->New<Block>(
+          stmt->LineNumber(),
+          CheckOrEnsureReturn(cast<Block>(*stmt).Stmt(), omitted_ret_type,
+                              stmt->LineNumber()));
+    case Statement::Kind::If: {
+      const auto& if_stmt = cast<If>(*stmt);
+      return global_arena->New<If>(
+          stmt->LineNumber(), if_stmt.Cond(),
+          CheckOrEnsureReturn(if_stmt.ThenStmt(), omitted_ret_type,
+                              stmt->LineNumber()),
+          CheckOrEnsureReturn(if_stmt.ElseStmt(), omitted_ret_type,
+                              stmt->LineNumber()));
+    }
+    case Statement::Kind::Return:
       return stmt;
-    case StatementKind::Sequence:
-      if (stmt->GetSequence().next) {
-        return Statement::MakeSequence(
-            stmt->line_num, stmt->GetSequence().stmt,
-            CheckOrEnsureReturn(stmt->GetSequence().next, omitted_ret_type,
-                                stmt->line_num));
+    case Statement::Kind::Sequence: {
+      const auto& seq = cast<Sequence>(*stmt);
+      if (seq.Next()) {
+        return global_arena->New<Sequence>(
+            stmt->LineNumber(), seq.Stmt(),
+            CheckOrEnsureReturn(seq.Next(), omitted_ret_type,
+                                stmt->LineNumber()));
       } else {
-        return CheckOrEnsureReturn(stmt->GetSequence().stmt, omitted_ret_type,
-                                   stmt->line_num);
+        return CheckOrEnsureReturn(seq.Stmt(), omitted_ret_type,
+                                   stmt->LineNumber());
       }
-    case StatementKind::Continuation:
-    case StatementKind::Run:
-    case StatementKind::Await:
+    }
+    case Statement::Kind::Continuation:
+    case Statement::Kind::Run:
+    case Statement::Kind::Await:
       return stmt;
-    case StatementKind::Assign:
-    case StatementKind::ExpressionStatement:
-    case StatementKind::While:
-    case StatementKind::Break:
-    case StatementKind::Continue:
-    case StatementKind::VariableDefinition:
+    case Statement::Kind::Assign:
+    case Statement::Kind::ExpressionStatement:
+    case Statement::Kind::While:
+    case Statement::Kind::Break:
+    case Statement::Kind::Continue:
+    case Statement::Kind::VariableDefinition:
       if (omitted_ret_type) {
-        return Statement::MakeSequence(
-            stmt->line_num, stmt,
-            Statement::MakeReturn(line_num, nullptr,
-                                  /*is_omitted_exp=*/true));
+        return global_arena->New<Sequence>(
+            stmt->LineNumber(), stmt,
+            global_arena->New<Return>(line_num, nullptr,
+                                      /*is_omitted_exp=*/true));
       } else {
-        FATAL_COMPILATION_ERROR(stmt->line_num)
+        FATAL_COMPILATION_ERROR(stmt->LineNumber())
             << "control-flow reaches end of function that provides a `->` "
                "return type without reaching a return statement";
       }

+ 15 - 15
executable_semantics/syntax/parser.ypp

@@ -414,35 +414,35 @@ clause_list:
 ;
 statement:
   expression "=" expression ";"
-    { $$ = Statement::MakeAssign(yylineno, $1, $3); }
+    { $$ = global_arena->New<Assign>(yylineno, $1, $3); }
 | VAR pattern "=" expression ";"
-    { $$ = Statement::MakeVariableDefinition(yylineno, $2, $4); }
+    { $$ = global_arena->New<VariableDefinition>(yylineno, $2, $4); }
 | expression ";"
-    { $$ = Statement::MakeExpressionStatement(yylineno, $1); }
+    { $$ = global_arena->New<ExpressionStatement>(yylineno, $1); }
 | if_statement
     { $$ = $1; }
 | WHILE "(" expression ")" block
-    { $$ = Statement::MakeWhile(yylineno, $3, $5); }
+    { $$ = global_arena->New<While>(yylineno, $3, $5); }
 | BREAK ";"
-    { $$ = Statement::MakeBreak(yylineno); }
+    { $$ = global_arena->New<Break>(yylineno); }
 | CONTINUE ";"
-    { $$ = Statement::MakeContinue(yylineno); }
+    { $$ = global_arena->New<Continue>(yylineno); }
 | RETURN return_expression ";"
-    { $$ = Statement::MakeReturn(yylineno, $2.first, $2.second); }
+    { $$ = global_arena->New<Return>(yylineno, $2.first, $2.second); }
 | block
     { $$ = $1; }
 | MATCH "(" expression ")" "{" clause_list "}"
-    { $$ = Statement::MakeMatch(yylineno, $3, $6); }
+    { $$ = global_arena->New<Match>(yylineno, $3, $6); }
 | CONTINUATION identifier statement
-    { $$ = Statement::MakeContinuation(yylineno, $2, $3); }
+    { $$ = global_arena->New<Continuation>(yylineno, $2, $3); }
 | RUN expression ";"
-    { $$ = Statement::MakeRun(yylineno, $2); }
+    { $$ = global_arena->New<Run>(yylineno, $2); }
 | AWAIT ";"
-    { $$ = Statement::MakeAwait(yylineno); }
+    { $$ = global_arena->New<Await>(yylineno); }
 ;
 if_statement:
   IF "(" expression ")" block optional_else
-    { $$ = Statement::MakeIf(yylineno, $3, $5, $6); }
+    { $$ = global_arena->New<If>(yylineno, $3, $5, $6); }
 ;
 optional_else:
   // Empty
@@ -462,11 +462,11 @@ statement_list:
   // Empty
     { $$ = 0; }
 | statement statement_list
-    { $$ = Statement::MakeSequence(yylineno, $1, $2); }
+    { $$ = global_arena->New<Sequence>(yylineno, $1, $2); }
 ;
 block:
   "{" statement_list "}"
-    { $$ = Statement::MakeBlock(yylineno, $2); }
+    { $$ = global_arena->New<Block>(yylineno, $2); }
 ;
 return_type:
   // Empty
@@ -515,7 +515,7 @@ function_definition:
       $$ = FunctionDefinition(
           yylineno, $2, $3, $4,
           global_arena->New<AutoPattern>(yylineno), true,
-          Statement::MakeReturn(yylineno, $6, true));
+          global_arena->New<Return>(yylineno, $6, true));
     }
 ;
 function_declaration: