Forráskód Böngészése

Define a base class for all AST nodes. (#947)

Also implement code-generation to manage the resulting boilerplate.
Geoff Romer 4 éve
szülő
commit
7a5b8434c8

+ 5 - 0
executable_semantics/BUILD

@@ -24,3 +24,8 @@ lit_test(
     test_dir = "testdata",
     tools = [":executable_semantics"],
 )
+
+py_binary(
+    name = "gen_rtti",
+    srcs = ["gen_rtti.py"],
+)

+ 26 - 0
executable_semantics/ast/BUILD

@@ -14,6 +14,27 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "ast_node",
+    srcs = ["ast_node.cpp"],
+    hdrs = [
+        "ast_node.h",
+        "ast_rtti.h",
+    ],
+    deps = [
+        ":source_location",
+    ],
+)
+
+genrule(
+    name = "ast_rtti",
+    srcs = ["ast_rtti.txt"],
+    outs = ["ast_rtti.h"],
+    cmd = "./$(location //executable_semantics:gen_rtti)" +
+          " $(location ast_rtti.txt) > \"$@\"",
+    tools = ["//executable_semantics:gen_rtti"],
+)
+
 cc_library(
     name = "declaration",
     srcs = ["declaration.cpp"],
@@ -21,6 +42,7 @@ cc_library(
         "declaration.h",
     ],
     deps = [
+        ":ast_node",
         ":member",
         ":pattern",
         ":source_location",
@@ -37,6 +59,7 @@ cc_library(
     srcs = ["expression.cpp"],
     hdrs = ["expression.h"],
     deps = [
+        ":ast_node",
         ":paren_contents",
         "//common:indirect_value",
         "//common:ostream",
@@ -86,6 +109,7 @@ cc_library(
     srcs = ["pattern.cpp"],
     hdrs = ["pattern.h"],
     deps = [
+        ":ast_node",
         ":expression",
         ":source_location",
         ":static_scope",
@@ -112,6 +136,7 @@ cc_library(
     srcs = ["static_scope.cpp"],
     hdrs = ["static_scope.h"],
     deps = [
+        ":ast_node",
         ":source_location",
         "//executable_semantics/common:arena",
         "//executable_semantics/common:error",
@@ -132,6 +157,7 @@ cc_library(
     srcs = ["statement.cpp"],
     hdrs = ["statement.h"],
     deps = [
+        ":ast_node",
         ":expression",
         ":pattern",
         ":source_location",

+ 11 - 0
executable_semantics/ast/ast_node.cpp

@@ -0,0 +1,11 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "executable_semantics/ast/ast_node.h"
+
+namespace Carbon {
+
+AstNode::~AstNode() = default;
+
+}  // namespace Carbon

+ 74 - 0
executable_semantics/ast/ast_node.h

@@ -0,0 +1,74 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#ifndef EXECUTABLE_SEMANTICS_AST_AST_NODE_H_
+#define EXECUTABLE_SEMANTICS_AST_AST_NODE_H_
+
+#include "executable_semantics/ast/ast_rtti.h"
+#include "executable_semantics/ast/source_location.h"
+
+namespace Carbon {
+
+// Base class for all nodes in the AST.
+//
+// Every class derived from this class must be listed in ast_rtti.txt. See
+// the documentation of gen_rtti.py for details about the format. As a result,
+// every abstract class `Foo` will have a `FooKind` enumerated type, whose
+// enumerators correspond to the subclasses of `Foo`.
+//
+// AstNode and its derived classes support LLVM-style RTTI, including
+// llvm::isa, llvm::cast, and llvm::dyn_cast. To support this, every
+// class derived from Declaration must provide a `classof` operation, with
+// the following form, where `Foo` is the name of the derived class:
+//
+// static auto classof(const AstNode* node) -> bool {
+//   return InheritsFromFoo(node->kind());
+// }
+//
+// Furthermore, if the class is abstract, it must provide a `kind()` operation,
+// with the following form:
+//
+// auto kind() const -> FooKind { return static_cast<FooKind>(root_kind()); }
+//
+// The definitions of `InheritsFromFoo` and `FooKind` are generated from
+// ast_rtti.txt, and are implicitly provided by this header.
+//
+// When inheriting from this class, the inheritance must me marked `virtual`.
+//
+// TODO: To support generic traversal, add children() method, and ensure that
+//   all AstNodes are reachable from a root AstNode.
+class AstNode {
+ public:
+  AstNode(AstNode&&) = delete;
+  auto operator=(AstNode&&) -> AstNode& = delete;
+  virtual ~AstNode() = 0;
+
+  // Returns an enumerator specifying the concrete type of this node.
+  //
+  // Abstract subclasses of AstNode will provide their own `kind()` method
+  // which hides this one, and provides a narrower return type.
+  auto kind() const -> AstNodeKind { return kind_; }
+
+  // The location of the code described by this node.
+  auto source_loc() const -> SourceLocation { return source_loc_; }
+
+ protected:
+  // Constructs an AstNode representing code at the given location. `kind`
+  // must be the enumerator that exactly matches the concrete type being
+  // constructed.
+  explicit AstNode(AstNodeKind kind, SourceLocation source_loc)
+      : kind_(kind), source_loc_(source_loc) {}
+
+  // Equivalent to kind(), but will not be hidden by `kind()` methods of
+  // derived classes.
+  auto root_kind() const -> AstNodeKind { return kind_; }
+
+ private:
+  AstNodeKind kind_;
+  SourceLocation source_loc_;
+};
+
+}  // namespace Carbon
+
+#endif  // EXECUTABLE_SEMANTICS_AST_AST_NODE_H_

+ 54 - 0
executable_semantics/ast/ast_rtti.txt

@@ -0,0 +1,54 @@
+# Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+# Exceptions. See /LICENSE for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+root class AstNode;
+interface class NamedEntity : AstNode;
+abstract class Pattern : AstNode;
+  class AutoPattern : Pattern;
+  class BindingPattern : Pattern, NamedEntity;
+  class TuplePattern : Pattern;
+  class AlternativePattern : Pattern;
+  class ExpressionPattern : Pattern;
+abstract class Declaration : AstNode, NamedEntity;
+  class FunctionDeclaration : Declaration;
+  class ClassDeclaration : Declaration;
+  class ChoiceDeclaration : Declaration;
+  class VariableDeclaration : Declaration;
+class GenericBinding : AstNode, NamedEntity;
+class AlternativeSignature : AstNode, NamedEntity;
+abstract class Statement : AstNode;
+  class ExpressionStatement : Statement;
+  class Assign : Statement;
+  class VariableDefinition : Statement;
+  class If : Statement;
+  class Return : Statement;
+  class Block : Statement;
+  class While : Statement;
+  class Break : Statement;
+  class Continue : Statement;
+  class Match : Statement;
+  class Continuation : Statement, NamedEntity;
+  class Run : Statement;
+  class Await : Statement;
+abstract class Expression : AstNode;
+  class BoolTypeLiteral : Expression;
+  class BoolLiteral : Expression;
+  class CallExpression : Expression;
+  class FunctionTypeLiteral : Expression;
+  class FieldAccessExpression : Expression;
+  class IndexExpression : Expression;
+  class IntTypeLiteral : Expression;
+  class ContinuationTypeLiteral : Expression;
+  class IntLiteral : Expression;
+  class PrimitiveOperatorExpression : Expression;
+  class StringLiteral : Expression;
+  class StringTypeLiteral : Expression;
+  class TupleLiteral : Expression;
+  class StructLiteral : Expression;
+  class StructTypeLiteral : Expression;
+  class TypeTypeLiteral : Expression;
+  class IdentifierExpression : Expression;
+  class IntrinsicExpression : Expression;
+abstract class Member : AstNode;
+  class FieldMember : Member, NamedEntity;

+ 7 - 6
executable_semantics/ast/declaration.cpp

@@ -10,13 +10,15 @@ namespace Carbon {
 
 using llvm::cast;
 
+Declaration::~Declaration() = default;
+
 void Declaration::Print(llvm::raw_ostream& out) const {
   switch (kind()) {
-    case Kind::FunctionDeclaration:
+    case DeclarationKind::FunctionDeclaration:
       cast<FunctionDeclaration>(*this).PrintDepth(-1, out);
       break;
 
-    case Kind::ClassDeclaration: {
+    case DeclarationKind::ClassDeclaration: {
       const auto& class_decl = cast<ClassDeclaration>(*this);
       out << "class " << class_decl.name() << " {\n";
       for (Nonnull<Member*> m : class_decl.members()) {
@@ -26,18 +28,17 @@ void Declaration::Print(llvm::raw_ostream& out) const {
       break;
     }
 
-    case Kind::ChoiceDeclaration: {
+    case DeclarationKind::ChoiceDeclaration: {
       const auto& choice = cast<ChoiceDeclaration>(*this);
       out << "choice " << choice.name() << " {\n";
-      for (Nonnull<const ChoiceDeclaration::Alternative*> alt :
-           choice.alternatives()) {
+      for (Nonnull<const AlternativeSignature*> alt : choice.alternatives()) {
         out << "alt " << alt->name() << " " << alt->signature() << ";\n";
       }
       out << "}\n";
       break;
     }
 
-    case Kind::VariableDeclaration: {
+    case DeclarationKind::VariableDeclaration: {
       const auto& var = cast<VariableDeclaration>(*this);
       out << "var " << var.binding() << " = " << var.initializer() << "\n";
       break;

+ 51 - 60
executable_semantics/ast/declaration.h

@@ -31,14 +31,9 @@ class StaticScope;
 // every concrete derived class must have a corresponding enumerator
 // in `Kind`; see https://llvm.org/docs/HowToSetUpLLVMStyleRTTI.html for
 // details.
-class Declaration : public NamedEntityInterface {
+class Declaration : public virtual AstNode, public NamedEntity {
  public:
-  enum class Kind {
-    FunctionDeclaration,
-    ClassDeclaration,
-    ChoiceDeclaration,
-    VariableDeclaration,
-  };
+  ~Declaration() override = 0;
 
   Declaration(const Member&) = delete;
   auto operator=(const Member&) -> Declaration& = delete;
@@ -46,16 +41,16 @@ class Declaration : public NamedEntityInterface {
   void Print(llvm::raw_ostream& out) const;
   LLVM_DUMP_METHOD void Dump() const { Print(llvm::errs()); }
 
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromDeclaration(node->kind());
+  }
+
   // Returns the enumerator corresponding to the most-derived type of this
   // object.
-  auto kind() const -> Kind { return kind_; }
-
-  auto named_entity_kind() const -> NamedEntityKind override {
-    return NamedEntityKind::Declaration;
+  auto kind() const -> DeclarationKind {
+    return static_cast<DeclarationKind>(root_kind());
   }
 
-  auto source_loc() const -> SourceLocation override { return source_loc_; }
-
   // The static type of the declared entity. Cannot be called before
   // typechecking.
   auto static_type() const -> const Value& { return **static_type_; }
@@ -73,33 +68,30 @@ class Declaration : public NamedEntityInterface {
   // Constructs a Declaration representing syntax at the given line number.
   // `kind` must be the enumerator corresponding to the most-derived type being
   // constructed.
-  Declaration(Kind kind, SourceLocation source_loc)
-      : kind_(kind), source_loc_(source_loc) {}
+  Declaration() = default;
 
  private:
-  const Kind kind_;
-  SourceLocation source_loc_;
   std::optional<Nonnull<const Value*>> static_type_;
 };
 
 // TODO: expand the kinds of things that can be deduced parameters.
 //   For now, only generic parameters are supported.
-struct GenericBinding : public NamedEntityInterface {
+struct GenericBinding : public virtual AstNode, public NamedEntity {
  public:
   GenericBinding(SourceLocation source_loc, std::string name,
                  Nonnull<Expression*> type)
-      : source_loc_(source_loc), name_(std::move(name)), type_(type) {}
+      : AstNode(AstNodeKind::GenericBinding, source_loc),
+        name_(std::move(name)),
+        type_(type) {}
 
-  auto named_entity_kind() const -> NamedEntityKind override {
-    return NamedEntityKind::GenericBinding;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromGenericBinding(node->kind());
   }
 
-  auto source_loc() const -> SourceLocation override { return source_loc_; }
   auto name() const -> const std::string& { return name_; }
   auto type() const -> const Expression& { return *type_; }
 
  private:
-  SourceLocation source_loc_;
   std::string name_;
   Nonnull<Expression*> type_;
 };
@@ -190,15 +182,15 @@ class FunctionDeclaration : public Declaration {
                       Nonnull<TuplePattern*> param_pattern,
                       ReturnTerm return_term,
                       std::optional<Nonnull<Block*>> body)
-      : Declaration(Kind::FunctionDeclaration, source_loc),
+      : AstNode(AstNodeKind::FunctionDeclaration, source_loc),
         name_(std::move(name)),
         deduced_parameters_(std::move(deduced_params)),
         param_pattern_(param_pattern),
         return_term_(return_term),
         body_(body) {}
 
-  static auto classof(const Declaration* decl) -> bool {
-    return decl->kind() == Kind::FunctionDeclaration;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromFunctionDeclaration(node->kind());
   }
 
   void PrintDepth(int depth, llvm::raw_ostream& out) const;
@@ -232,12 +224,12 @@ class ClassDeclaration : public Declaration {
  public:
   ClassDeclaration(SourceLocation source_loc, std::string name,
                    std::vector<Nonnull<Member*>> members)
-      : Declaration(Kind::ClassDeclaration, source_loc),
+      : AstNode(AstNodeKind::ClassDeclaration, source_loc),
         name_(std::move(name)),
         members_(std::move(members)) {}
 
-  static auto classof(const Declaration* decl) -> bool {
-    return decl->kind() == Kind::ClassDeclaration;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromClassDeclaration(node->kind());
   }
 
   auto name() const -> const std::string& { return name_; }
@@ -253,42 +245,41 @@ class ClassDeclaration : public Declaration {
   StaticScope static_scope_;
 };
 
-class ChoiceDeclaration : public Declaration {
+class AlternativeSignature : public virtual AstNode, public NamedEntity {
  public:
-  class Alternative : public NamedEntityInterface {
-   public:
-    Alternative(SourceLocation source_loc, std::string name,
-                Nonnull<Expression*> signature)
-        : source_loc_(source_loc),
-          name_(std::move(name)),
-          signature_(signature) {}
-
-    auto named_entity_kind() const -> NamedEntityKind override {
-      return NamedEntityKind::ChoiceDeclarationAlternative;
-    }
-
-    auto source_loc() const -> SourceLocation override { return source_loc_; }
-    auto name() const -> const std::string& { return name_; }
-    auto signature() const -> const Expression& { return *signature_; }
-
-   private:
-    SourceLocation source_loc_;
-    std::string name_;
-    Nonnull<Expression*> signature_;
-  };
+  AlternativeSignature(SourceLocation source_loc, std::string name,
+                       Nonnull<Expression*> signature)
+      : AstNode(AstNodeKind::AlternativeSignature, source_loc),
+        name_(std::move(name)),
+        signature_(signature) {}
 
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromAlternativeSignature(node->kind());
+  }
+
+  auto name() const -> const std::string& { return name_; }
+  auto signature() const -> const Expression& { return *signature_; }
+
+ private:
+  std::string name_;
+  Nonnull<Expression*> signature_;
+};
+
+class ChoiceDeclaration : public Declaration {
+ public:
   ChoiceDeclaration(SourceLocation source_loc, std::string name,
-                    std::vector<Nonnull<Alternative*>> alternatives)
-      : Declaration(Kind::ChoiceDeclaration, source_loc),
+                    std::vector<Nonnull<AlternativeSignature*>> alternatives)
+      : AstNode(AstNodeKind::ChoiceDeclaration, source_loc),
         name_(std::move(name)),
         alternatives_(std::move(alternatives)) {}
 
-  static auto classof(const Declaration* decl) -> bool {
-    return decl->kind() == Kind::ChoiceDeclaration;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromChoiceDeclaration(node->kind());
   }
 
   auto name() const -> const std::string& { return name_; }
-  auto alternatives() const -> llvm::ArrayRef<Nonnull<const Alternative*>> {
+  auto alternatives() const
+      -> llvm::ArrayRef<Nonnull<const AlternativeSignature*>> {
     return alternatives_;
   }
 
@@ -298,7 +289,7 @@ class ChoiceDeclaration : public Declaration {
 
  private:
   std::string name_;
-  std::vector<Nonnull<Alternative*>> alternatives_;
+  std::vector<Nonnull<AlternativeSignature*>> alternatives_;
   StaticScope static_scope_;
 };
 
@@ -308,12 +299,12 @@ class VariableDeclaration : public Declaration {
   VariableDeclaration(SourceLocation source_loc,
                       Nonnull<BindingPattern*> binding,
                       Nonnull<Expression*> initializer)
-      : Declaration(Kind::VariableDeclaration, source_loc),
+      : AstNode(AstNodeKind::VariableDeclaration, source_loc),
         binding_(binding),
         initializer_(initializer) {}
 
-  static auto classof(const Declaration* decl) -> bool {
-    return decl->kind() == Kind::VariableDeclaration;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromVariableDeclaration(node->kind());
   }
 
   auto binding() const -> const BindingPattern& { return *binding_; }

+ 21 - 19
executable_semantics/ast/expression.cpp

@@ -34,6 +34,8 @@ auto TupleExpressionFromParenContents(
   return arena->New<TupleLiteral>(source_loc, paren_contents.elements);
 }
 
+Expression::~Expression() = default;
+
 static void PrintOp(llvm::raw_ostream& out, Operator op) {
   switch (op) {
     case Operator::Add:
@@ -74,17 +76,17 @@ static void PrintFields(llvm::raw_ostream& out,
 
 void Expression::Print(llvm::raw_ostream& out) const {
   switch (kind()) {
-    case Expression::Kind::IndexExpression: {
+    case ExpressionKind::IndexExpression: {
       const auto& index = cast<IndexExpression>(*this);
       out << index.aggregate() << "[" << index.offset() << "]";
       break;
     }
-    case Expression::Kind::FieldAccessExpression: {
+    case ExpressionKind::FieldAccessExpression: {
       const auto& access = cast<FieldAccessExpression>(*this);
       out << access.aggregate() << "." << access.field();
       break;
     }
-    case Expression::Kind::TupleLiteral: {
+    case ExpressionKind::TupleLiteral: {
       out << "(";
       llvm::ListSeparator sep;
       for (Nonnull<const Expression*> field :
@@ -94,25 +96,25 @@ void Expression::Print(llvm::raw_ostream& out) const {
       out << ")";
       break;
     }
-    case Expression::Kind::StructLiteral:
+    case ExpressionKind::StructLiteral:
       out << "{";
       PrintFields(out, cast<StructLiteral>(*this).fields(), " = ");
       out << "}";
       break;
-    case Expression::Kind::StructTypeLiteral:
+    case ExpressionKind::StructTypeLiteral:
       out << "{";
       PrintFields(out, cast<StructTypeLiteral>(*this).fields(), ": ");
       out << "}";
       break;
-    case Expression::Kind::IntLiteral:
+    case ExpressionKind::IntLiteral:
       out << cast<IntLiteral>(*this).value();
       break;
-    case Expression::Kind::BoolLiteral:
+    case ExpressionKind::BoolLiteral:
       out << (cast<BoolLiteral>(*this).value() ? "true" : "false");
       break;
-    case Expression::Kind::PrimitiveOperatorExpression: {
+    case ExpressionKind::PrimitiveOperatorExpression: {
       out << "(";
-      PrimitiveOperatorExpression op = cast<PrimitiveOperatorExpression>(*this);
+      const auto& op = cast<PrimitiveOperatorExpression>(*this);
       switch (op.arguments().size()) {
         case 0:
           PrintOp(out, op.op());
@@ -132,10 +134,10 @@ void Expression::Print(llvm::raw_ostream& out) const {
       out << ")";
       break;
     }
-    case Expression::Kind::IdentifierExpression:
+    case ExpressionKind::IdentifierExpression:
       out << cast<IdentifierExpression>(*this).name();
       break;
-    case Expression::Kind::CallExpression: {
+    case ExpressionKind::CallExpression: {
       const auto& call = cast<CallExpression>(*this);
       out << call.function();
       if (isa<TupleLiteral>(call.argument())) {
@@ -145,32 +147,32 @@ void Expression::Print(llvm::raw_ostream& out) const {
       }
       break;
     }
-    case Expression::Kind::BoolTypeLiteral:
+    case ExpressionKind::BoolTypeLiteral:
       out << "Bool";
       break;
-    case Expression::Kind::IntTypeLiteral:
+    case ExpressionKind::IntTypeLiteral:
       out << "i32";
       break;
-    case Expression::Kind::StringLiteral:
+    case ExpressionKind::StringLiteral:
       out << "\"";
       out.write_escaped(cast<StringLiteral>(*this).value());
       out << "\"";
       break;
-    case Expression::Kind::StringTypeLiteral:
+    case ExpressionKind::StringTypeLiteral:
       out << "String";
       break;
-    case Expression::Kind::TypeTypeLiteral:
+    case ExpressionKind::TypeTypeLiteral:
       out << "Type";
       break;
-    case Expression::Kind::ContinuationTypeLiteral:
+    case ExpressionKind::ContinuationTypeLiteral:
       out << "Continuation";
       break;
-    case Expression::Kind::FunctionTypeLiteral: {
+    case ExpressionKind::FunctionTypeLiteral: {
       const auto& fn = cast<FunctionTypeLiteral>(*this);
       out << "fn " << fn.parameter() << " -> " << fn.return_type();
       break;
     }
-    case Expression::Kind::IntrinsicExpression:
+    case ExpressionKind::IntrinsicExpression:
       out << "intrinsic_expression(";
       switch (cast<IntrinsicExpression>(*this).intrinsic()) {
         case IntrinsicExpression::Intrinsic::Print:

+ 68 - 83
executable_semantics/ast/expression.h

@@ -11,6 +11,7 @@
 #include <vector>
 
 #include "common/ostream.h"
+#include "executable_semantics/ast/ast_node.h"
 #include "executable_semantics/ast/paren_contents.h"
 #include "executable_semantics/ast/source_location.h"
 #include "executable_semantics/common/arena.h"
@@ -21,37 +22,22 @@ namespace Carbon {
 
 class Value;
 
-class Expression {
+class Expression : public virtual AstNode {
  public:
-  enum class Kind {
-    BoolTypeLiteral,
-    BoolLiteral,
-    CallExpression,
-    FunctionTypeLiteral,
-    FieldAccessExpression,
-    IndexExpression,
-    IntTypeLiteral,
-    ContinuationTypeLiteral,  // The type of a continuation value.
-    IntLiteral,
-    PrimitiveOperatorExpression,
-    StringLiteral,
-    StringTypeLiteral,
-    TupleLiteral,
-    StructLiteral,
-    StructTypeLiteral,
-    TypeTypeLiteral,
-    IdentifierExpression,
-    IntrinsicExpression,
-  };
+  ~Expression() override = 0;
 
   void Print(llvm::raw_ostream& out) const;
   LLVM_DUMP_METHOD void Dump() const { Print(llvm::errs()); }
 
+  static auto classof(const AstNode* node) {
+    return InheritsFromExpression(node->kind());
+  }
+
   // Returns the enumerator corresponding to the most-derived type of this
   // object.
-  auto kind() const -> Kind { return kind_; }
-
-  auto source_loc() const -> SourceLocation { return source_loc_; }
+  auto kind() const -> ExpressionKind {
+    return static_cast<ExpressionKind>(root_kind());
+  }
 
   // The static type of this expression. Cannot be called before typechecking.
   auto static_type() const -> const Value& { return **static_type_; }
@@ -69,13 +55,9 @@ class Expression {
   // Constructs an Expression representing syntax at the given line number.
   // `kind` must be the enumerator corresponding to the most-derived type being
   // constructed.
-  Expression(Kind kind, SourceLocation source_loc)
-      : kind_(kind), source_loc_(source_loc) {}
+  Expression() = default;
 
  private:
-  const Kind kind_;
-  SourceLocation source_loc_;
-
   std::optional<Nonnull<const Value*>> static_type_;
 };
 
@@ -114,11 +96,11 @@ enum class Operator {
 class IdentifierExpression : public Expression {
  public:
   explicit IdentifierExpression(SourceLocation source_loc, std::string name)
-      : Expression(Kind::IdentifierExpression, source_loc),
+      : AstNode(AstNodeKind::IdentifierExpression, source_loc),
         name_(std::move(name)) {}
 
-  static auto classof(const Expression* exp) -> bool {
-    return exp->kind() == Kind::IdentifierExpression;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromIdentifierExpression(node->kind());
   }
 
   auto name() const -> const std::string& { return name_; }
@@ -132,12 +114,12 @@ class FieldAccessExpression : public Expression {
   explicit FieldAccessExpression(SourceLocation source_loc,
                                  Nonnull<Expression*> aggregate,
                                  std::string field)
-      : Expression(Kind::FieldAccessExpression, source_loc),
+      : AstNode(AstNodeKind::FieldAccessExpression, source_loc),
         aggregate_(aggregate),
         field_(std::move(field)) {}
 
-  static auto classof(const Expression* exp) -> bool {
-    return exp->kind() == Kind::FieldAccessExpression;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromFieldAccessExpression(node->kind());
   }
 
   auto aggregate() const -> const Expression& { return *aggregate_; }
@@ -154,12 +136,12 @@ class IndexExpression : public Expression {
   explicit IndexExpression(SourceLocation source_loc,
                            Nonnull<Expression*> aggregate,
                            Nonnull<Expression*> offset)
-      : Expression(Kind::IndexExpression, source_loc),
+      : AstNode(AstNodeKind::IndexExpression, source_loc),
         aggregate_(aggregate),
         offset_(offset) {}
 
-  static auto classof(const Expression* exp) -> bool {
-    return exp->kind() == Kind::IndexExpression;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromIndexExpression(node->kind());
   }
 
   auto aggregate() const -> const Expression& { return *aggregate_; }
@@ -175,10 +157,10 @@ class IndexExpression : public Expression {
 class IntLiteral : public Expression {
  public:
   explicit IntLiteral(SourceLocation source_loc, int value)
-      : Expression(Kind::IntLiteral, source_loc), value_(value) {}
+      : AstNode(AstNodeKind::IntLiteral, source_loc), value_(value) {}
 
-  static auto classof(const Expression* exp) -> bool {
-    return exp->kind() == Kind::IntLiteral;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromIntLiteral(node->kind());
   }
 
   auto value() const -> int { return value_; }
@@ -190,10 +172,10 @@ class IntLiteral : public Expression {
 class BoolLiteral : public Expression {
  public:
   explicit BoolLiteral(SourceLocation source_loc, bool value)
-      : Expression(Kind::BoolLiteral, source_loc), value_(value) {}
+      : AstNode(AstNodeKind::BoolLiteral, source_loc), value_(value) {}
 
-  static auto classof(const Expression* exp) -> bool {
-    return exp->kind() == Kind::BoolLiteral;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromBoolLiteral(node->kind());
   }
 
   auto value() const -> bool { return value_; }
@@ -205,10 +187,11 @@ class BoolLiteral : public Expression {
 class StringLiteral : public Expression {
  public:
   explicit StringLiteral(SourceLocation source_loc, std::string value)
-      : Expression(Kind::StringLiteral, source_loc), value_(std::move(value)) {}
+      : AstNode(AstNodeKind::StringLiteral, source_loc),
+        value_(std::move(value)) {}
 
-  static auto classof(const Expression* exp) -> bool {
-    return exp->kind() == Kind::StringLiteral;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromStringLiteral(node->kind());
   }
 
   auto value() const -> const std::string& { return value_; }
@@ -220,10 +203,10 @@ class StringLiteral : public Expression {
 class StringTypeLiteral : public Expression {
  public:
   explicit StringTypeLiteral(SourceLocation source_loc)
-      : Expression(Kind::StringTypeLiteral, source_loc) {}
+      : AstNode(AstNodeKind::StringTypeLiteral, source_loc) {}
 
-  static auto classof(const Expression* exp) -> bool {
-    return exp->kind() == Kind::StringTypeLiteral;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromStringTypeLiteral(node->kind());
   }
 };
 
@@ -234,11 +217,11 @@ class TupleLiteral : public Expression {
 
   explicit TupleLiteral(SourceLocation source_loc,
                         std::vector<Nonnull<Expression*>> fields)
-      : Expression(Kind::TupleLiteral, source_loc),
+      : AstNode(AstNodeKind::TupleLiteral, source_loc),
         fields_(std::move(fields)) {}
 
-  static auto classof(const Expression* exp) -> bool {
-    return exp->kind() == Kind::TupleLiteral;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromTupleLiteral(node->kind());
   }
 
   auto fields() const -> llvm::ArrayRef<Nonnull<const Expression*>> {
@@ -260,13 +243,13 @@ class StructLiteral : public Expression {
  public:
   explicit StructLiteral(SourceLocation loc,
                          std::vector<FieldInitializer> fields)
-      : Expression(Kind::StructLiteral, loc), fields_(std::move(fields)) {
+      : AstNode(AstNodeKind::StructLiteral, loc), fields_(std::move(fields)) {
     CHECK(!fields_.empty())
         << "`{}` is represented as a StructTypeLiteral, not a StructLiteral.";
   }
 
-  static auto classof(const Expression* exp) -> bool {
-    return exp->kind() == Kind::StructLiteral;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromStructLiteral(node->kind());
   }
 
   auto fields() const -> llvm::ArrayRef<FieldInitializer> { return fields_; }
@@ -286,10 +269,11 @@ class StructTypeLiteral : public Expression {
 
   explicit StructTypeLiteral(SourceLocation loc,
                              std::vector<FieldInitializer> fields)
-      : Expression(Kind::StructTypeLiteral, loc), fields_(std::move(fields)) {}
+      : AstNode(AstNodeKind::StructTypeLiteral, loc),
+        fields_(std::move(fields)) {}
 
-  static auto classof(const Expression* exp) -> bool {
-    return exp->kind() == Kind::StructTypeLiteral;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromStructTypeLiteral(node->kind());
   }
 
   auto fields() const -> llvm::ArrayRef<FieldInitializer> { return fields_; }
@@ -304,12 +288,12 @@ class PrimitiveOperatorExpression : public Expression {
   explicit PrimitiveOperatorExpression(
       SourceLocation source_loc, Operator op,
       std::vector<Nonnull<Expression*>> arguments)
-      : Expression(Kind::PrimitiveOperatorExpression, source_loc),
+      : AstNode(AstNodeKind::PrimitiveOperatorExpression, source_loc),
         op_(op),
         arguments_(std::move(arguments)) {}
 
-  static auto classof(const Expression* exp) -> bool {
-    return exp->kind() == Kind::PrimitiveOperatorExpression;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromPrimitiveOperatorExpression(node->kind());
   }
 
   auto op() const -> Operator { return op_; }
@@ -330,12 +314,12 @@ class CallExpression : public Expression {
   explicit CallExpression(SourceLocation source_loc,
                           Nonnull<Expression*> function,
                           Nonnull<Expression*> argument)
-      : Expression(Kind::CallExpression, source_loc),
+      : AstNode(AstNodeKind::CallExpression, source_loc),
         function_(function),
         argument_(argument) {}
 
-  static auto classof(const Expression* exp) -> bool {
-    return exp->kind() == Kind::CallExpression;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromCallExpression(node->kind());
   }
 
   auto function() const -> const Expression& { return *function_; }
@@ -353,12 +337,12 @@ class FunctionTypeLiteral : public Expression {
   explicit FunctionTypeLiteral(SourceLocation source_loc,
                                Nonnull<Expression*> parameter,
                                Nonnull<Expression*> return_type)
-      : Expression(Kind::FunctionTypeLiteral, source_loc),
+      : AstNode(AstNodeKind::FunctionTypeLiteral, source_loc),
         parameter_(parameter),
         return_type_(return_type) {}
 
-  static auto classof(const Expression* exp) -> bool {
-    return exp->kind() == Kind::FunctionTypeLiteral;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromFunctionTypeLiteral(node->kind());
   }
 
   auto parameter() const -> const Expression& { return *parameter_; }
@@ -374,40 +358,40 @@ class FunctionTypeLiteral : public Expression {
 class BoolTypeLiteral : public Expression {
  public:
   explicit BoolTypeLiteral(SourceLocation source_loc)
-      : Expression(Kind::BoolTypeLiteral, source_loc) {}
+      : AstNode(AstNodeKind::BoolTypeLiteral, source_loc) {}
 
-  static auto classof(const Expression* exp) -> bool {
-    return exp->kind() == Kind::BoolTypeLiteral;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromBoolTypeLiteral(node->kind());
   }
 };
 
 class IntTypeLiteral : public Expression {
  public:
   explicit IntTypeLiteral(SourceLocation source_loc)
-      : Expression(Kind::IntTypeLiteral, source_loc) {}
+      : AstNode(AstNodeKind::IntTypeLiteral, source_loc) {}
 
-  static auto classof(const Expression* exp) -> bool {
-    return exp->kind() == Kind::IntTypeLiteral;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromIntTypeLiteral(node->kind());
   }
 };
 
 class ContinuationTypeLiteral : public Expression {
  public:
   explicit ContinuationTypeLiteral(SourceLocation source_loc)
-      : Expression(Kind::ContinuationTypeLiteral, source_loc) {}
+      : AstNode(AstNodeKind::ContinuationTypeLiteral, source_loc) {}
 
-  static auto classof(const Expression* exp) -> bool {
-    return exp->kind() == Kind::ContinuationTypeLiteral;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromContinuationTypeLiteral(node->kind());
   }
 };
 
 class TypeTypeLiteral : public Expression {
  public:
   explicit TypeTypeLiteral(SourceLocation source_loc)
-      : Expression(Kind::TypeTypeLiteral, source_loc) {}
+      : AstNode(AstNodeKind::TypeTypeLiteral, source_loc) {}
 
-  static auto classof(const Expression* exp) -> bool {
-    return exp->kind() == Kind::TypeTypeLiteral;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromTypeTypeLiteral(node->kind());
   }
 };
 
@@ -418,11 +402,12 @@ class IntrinsicExpression : public Expression {
   };
 
   explicit IntrinsicExpression(Intrinsic intrinsic)
-      : Expression(Kind::IntrinsicExpression, SourceLocation("<intrinsic>", 0)),
+      : AstNode(AstNodeKind::IntrinsicExpression,
+                SourceLocation("<intrinsic>", 0)),
         intrinsic_(intrinsic) {}
 
-  static auto classof(const Expression* exp) -> bool {
-    return exp->kind() == Kind::IntrinsicExpression;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromIntrinsicExpression(node->kind());
   }
 
   auto intrinsic() const -> Intrinsic { return intrinsic_; }

+ 9 - 9
executable_semantics/ast/expression_test.cpp

@@ -21,7 +21,7 @@ using testing::ElementsAre;
 using testing::IsEmpty;
 
 // Matches any `IntLiteral`.
-MATCHER(IntField, "") { return arg->kind() == Expression::Kind::IntLiteral; }
+MATCHER(IntField, "") { return arg->kind() == ExpressionKind::IntLiteral; }
 
 static auto FakeSourceLoc(int line_num) -> SourceLocation {
   return SourceLocation("<test>", line_num);
@@ -38,7 +38,7 @@ TEST_F(ExpressionTest, EmptyAsExpression) {
   Nonnull<const Expression*> expression =
       ExpressionFromParenContents(&arena, FakeSourceLoc(1), contents);
   EXPECT_EQ(expression->source_loc(), FakeSourceLoc(1));
-  ASSERT_EQ(expression->kind(), Expression::Kind::TupleLiteral);
+  ASSERT_EQ(expression->kind(), ExpressionKind::TupleLiteral);
   EXPECT_THAT(cast<TupleLiteral>(*expression).fields(), IsEmpty());
 }
 
@@ -48,7 +48,7 @@ TEST_F(ExpressionTest, EmptyAsTuple) {
   Nonnull<const Expression*> tuple =
       TupleExpressionFromParenContents(&arena, FakeSourceLoc(1), contents);
   EXPECT_EQ(tuple->source_loc(), FakeSourceLoc(1));
-  ASSERT_EQ(tuple->kind(), Expression::Kind::TupleLiteral);
+  ASSERT_EQ(tuple->kind(), ExpressionKind::TupleLiteral);
   EXPECT_THAT(cast<TupleLiteral>(*tuple).fields(), IsEmpty());
 }
 
@@ -66,7 +66,7 @@ TEST_F(ExpressionTest, UnaryNoCommaAsExpression) {
   Nonnull<const Expression*> expression =
       ExpressionFromParenContents(&arena, FakeSourceLoc(1), contents);
   EXPECT_EQ(expression->source_loc(), FakeSourceLoc(2));
-  ASSERT_EQ(expression->kind(), Expression::Kind::IntLiteral);
+  ASSERT_EQ(expression->kind(), ExpressionKind::IntLiteral);
 }
 
 TEST_F(ExpressionTest, UnaryNoCommaAsTuple) {
@@ -77,7 +77,7 @@ TEST_F(ExpressionTest, UnaryNoCommaAsTuple) {
   Nonnull<const Expression*> tuple =
       TupleExpressionFromParenContents(&arena, FakeSourceLoc(1), contents);
   EXPECT_EQ(tuple->source_loc(), FakeSourceLoc(1));
-  ASSERT_EQ(tuple->kind(), Expression::Kind::TupleLiteral);
+  ASSERT_EQ(tuple->kind(), ExpressionKind::TupleLiteral);
   EXPECT_THAT(cast<TupleLiteral>(*tuple).fields(), ElementsAre(IntField()));
 }
 
@@ -89,7 +89,7 @@ TEST_F(ExpressionTest, UnaryWithCommaAsExpression) {
   Nonnull<const Expression*> expression =
       ExpressionFromParenContents(&arena, FakeSourceLoc(1), contents);
   EXPECT_EQ(expression->source_loc(), FakeSourceLoc(1));
-  ASSERT_EQ(expression->kind(), Expression::Kind::TupleLiteral);
+  ASSERT_EQ(expression->kind(), ExpressionKind::TupleLiteral);
   EXPECT_THAT(cast<TupleLiteral>(*expression).fields(),
               ElementsAre(IntField()));
 }
@@ -102,7 +102,7 @@ TEST_F(ExpressionTest, UnaryWithCommaAsTuple) {
   Nonnull<const Expression*> tuple =
       TupleExpressionFromParenContents(&arena, FakeSourceLoc(1), contents);
   EXPECT_EQ(tuple->source_loc(), FakeSourceLoc(1));
-  ASSERT_EQ(tuple->kind(), Expression::Kind::TupleLiteral);
+  ASSERT_EQ(tuple->kind(), ExpressionKind::TupleLiteral);
   EXPECT_THAT(cast<TupleLiteral>(*tuple).fields(), ElementsAre(IntField()));
 }
 
@@ -115,7 +115,7 @@ TEST_F(ExpressionTest, BinaryAsExpression) {
   Nonnull<const Expression*> expression =
       ExpressionFromParenContents(&arena, FakeSourceLoc(1), contents);
   EXPECT_EQ(expression->source_loc(), FakeSourceLoc(1));
-  ASSERT_EQ(expression->kind(), Expression::Kind::TupleLiteral);
+  ASSERT_EQ(expression->kind(), ExpressionKind::TupleLiteral);
   EXPECT_THAT(cast<TupleLiteral>(*expression).fields(),
               ElementsAre(IntField(), IntField()));
 }
@@ -129,7 +129,7 @@ TEST_F(ExpressionTest, BinaryAsTuple) {
   Nonnull<const Expression*> tuple =
       TupleExpressionFromParenContents(&arena, FakeSourceLoc(1), contents);
   EXPECT_EQ(tuple->source_loc(), FakeSourceLoc(1));
-  ASSERT_EQ(tuple->kind(), Expression::Kind::TupleLiteral);
+  ASSERT_EQ(tuple->kind(), ExpressionKind::TupleLiteral);
   EXPECT_THAT(cast<TupleLiteral>(*tuple).fields(),
               ElementsAre(IntField(), IntField()));
 }

+ 3 - 1
executable_semantics/ast/member.cpp

@@ -11,9 +11,11 @@ namespace Carbon {
 
 using llvm::cast;
 
+Member::~Member() = default;
+
 void Member::Print(llvm::raw_ostream& out) const {
   switch (kind()) {
-    case Kind::FieldMember:
+    case MemberKind::FieldMember:
       const auto& field = cast<FieldMember>(*this);
       out << "var " << field.binding() << ";\n";
       break;

+ 12 - 20
executable_semantics/ast/member.h

@@ -23,9 +23,9 @@ namespace Carbon {
 // every concrete derived class must have a corresponding enumerator
 // in `Kind`; see https://llvm.org/docs/HowToSetUpLLVMStyleRTTI.html for
 // details.
-class Member : public NamedEntityInterface {
+class Member : public virtual AstNode, public NamedEntity {
  public:
-  enum class Kind { FieldMember };
+  ~Member() override = 0;
 
   Member(const Member&) = delete;
   auto operator=(const Member&) -> Member& = delete;
@@ -33,35 +33,27 @@ class Member : public NamedEntityInterface {
   void Print(llvm::raw_ostream& out) const;
   LLVM_DUMP_METHOD void Dump() const { Print(llvm::errs()); }
 
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromMember(node->kind());
+  }
+
   // Returns the enumerator corresponding to the most-derived type of this
   // object.
-  auto kind() const -> Kind { return kind_; }
-
-  auto named_entity_kind() const -> NamedEntityKind override {
-    return NamedEntityKind::Member;
+  auto kind() const -> MemberKind {
+    return static_cast<MemberKind>(root_kind());
   }
 
-  auto source_loc() const -> SourceLocation override { return source_loc_; }
-
  protected:
-  // Constructs a Member representing syntax at the given line number.
-  // `kind` must be the enumerator corresponding to the most-derived type being
-  // constructed.
-  Member(Kind kind, SourceLocation source_loc)
-      : kind_(kind), source_loc_(source_loc) {}
-
- private:
-  const Kind kind_;
-  SourceLocation source_loc_;
+  Member() = default;
 };
 
 class FieldMember : public Member {
  public:
   FieldMember(SourceLocation source_loc, Nonnull<const BindingPattern*> binding)
-      : Member(Kind::FieldMember, source_loc), binding_(binding) {}
+      : AstNode(AstNodeKind::FieldMember, source_loc), binding_(binding) {}
 
-  static auto classof(const Member* member) -> bool {
-    return member->kind() == Kind::FieldMember;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromFieldMember(node->kind());
   }
 
   auto binding() const -> const BindingPattern& { return *binding_; }

+ 9 - 7
executable_semantics/ast/pattern.cpp

@@ -17,12 +17,14 @@ namespace Carbon {
 
 using llvm::cast;
 
+Pattern::~Pattern() = default;
+
 void Pattern::Print(llvm::raw_ostream& out) const {
   switch (kind()) {
-    case Kind::AutoPattern:
+    case PatternKind::AutoPattern:
       out << "auto";
       break;
-    case Kind::BindingPattern: {
+    case PatternKind::BindingPattern: {
       const auto& binding = cast<BindingPattern>(*this);
       if (binding.name().has_value()) {
         out << *binding.name();
@@ -32,7 +34,7 @@ void Pattern::Print(llvm::raw_ostream& out) const {
       out << ": " << binding.type();
       break;
     }
-    case Kind::TuplePattern: {
+    case PatternKind::TuplePattern: {
       const auto& tuple = cast<TuplePattern>(*this);
       out << "(";
       llvm::ListSeparator sep;
@@ -42,13 +44,13 @@ void Pattern::Print(llvm::raw_ostream& out) const {
       out << ")";
       break;
     }
-    case Kind::AlternativePattern: {
+    case PatternKind::AlternativePattern: {
       const auto& alternative = cast<AlternativePattern>(*this);
       out << alternative.choice_type() << "." << alternative.alternative_name()
           << alternative.arguments();
       break;
     }
-    case Kind::ExpressionPattern:
+    case PatternKind::ExpressionPattern:
       out << cast<ExpressionPattern>(*this).expression();
       break;
   }
@@ -77,7 +79,7 @@ auto TuplePatternFromParenContents(Nonnull<Arena*> arena,
 // apply.
 static auto RequireFieldAccess(Nonnull<Expression*> alternative)
     -> FieldAccessExpression& {
-  if (alternative->kind() != Expression::Kind::FieldAccessExpression) {
+  if (alternative->kind() != ExpressionKind::FieldAccessExpression) {
     FATAL_PROGRAM_ERROR(alternative->source_loc())
         << "Alternative pattern must have the form of a field access.";
   }
@@ -87,7 +89,7 @@ static auto RequireFieldAccess(Nonnull<Expression*> alternative)
 AlternativePattern::AlternativePattern(SourceLocation source_loc,
                                        Nonnull<Expression*> alternative,
                                        Nonnull<TuplePattern*> arguments)
-    : Pattern(Kind::AlternativePattern, source_loc),
+    : AstNode(AstNodeKind::AlternativePattern, source_loc),
       choice_type_(&RequireFieldAccess(alternative).aggregate()),
       alternative_name_(RequireFieldAccess(alternative).field()),
       arguments_(arguments) {}

+ 30 - 41
executable_semantics/ast/pattern.h

@@ -10,6 +10,8 @@
 #include <vector>
 
 #include "common/ostream.h"
+#include "executable_semantics/ast/ast_node.h"
+#include "executable_semantics/ast/ast_rtti.h"
 #include "executable_semantics/ast/expression.h"
 #include "executable_semantics/ast/source_location.h"
 #include "executable_semantics/ast/static_scope.h"
@@ -27,27 +29,25 @@ class Value;
 // every concrete derived class must have a corresponding enumerator
 // in `Kind`; see https://llvm.org/docs/HowToSetUpLLVMStyleRTTI.html for
 // details.
-class Pattern {
+class Pattern : public virtual AstNode {
  public:
-  enum class Kind {
-    AutoPattern,
-    BindingPattern,
-    TuplePattern,
-    AlternativePattern,
-    ExpressionPattern,
-  };
-
   Pattern(const Pattern&) = delete;
   auto operator=(const Pattern&) -> Pattern& = delete;
 
+  ~Pattern() override = 0;
+
   void Print(llvm::raw_ostream& out) const;
   LLVM_DUMP_METHOD void Dump() const { Print(llvm::errs()); }
 
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromPattern(node->kind());
+  }
+
   // Returns the enumerator corresponding to the most-derived type of this
   // object.
-  auto kind() const -> Kind { return kind_; }
-
-  auto source_loc() const -> SourceLocation { return source_loc_; }
+  auto kind() const -> PatternKind {
+    return static_cast<PatternKind>(root_kind());
+  }
 
   // The static type of this pattern. Cannot be called before typechecking.
   auto static_type() const -> const Value& { return **static_type_; }
@@ -77,13 +77,9 @@ class Pattern {
   // Constructs a Pattern representing syntax at the given line number.
   // `kind` must be the enumerator corresponding to the most-derived type being
   // constructed.
-  Pattern(Kind kind, SourceLocation source_loc)
-      : kind_(kind), source_loc_(source_loc) {}
+  Pattern() = default;
 
  private:
-  const Kind kind_;
-  SourceLocation source_loc_;
-
   std::optional<Nonnull<const Value*>> static_type_;
   std::optional<Nonnull<const Value*>> value_;
 };
@@ -92,33 +88,25 @@ class Pattern {
 class AutoPattern : public Pattern {
  public:
   explicit AutoPattern(SourceLocation source_loc)
-      : Pattern(Kind::AutoPattern, source_loc) {}
+      : AstNode(AstNodeKind::AutoPattern, source_loc) {}
 
-  static auto classof(const Pattern* pattern) -> bool {
-    return pattern->kind() == Kind::AutoPattern;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromAutoPattern(node->kind());
   }
 };
 
 // A pattern that matches a value of a specified type, and optionally binds
 // a name to it.
-class BindingPattern : public Pattern, public NamedEntityInterface {
+class BindingPattern : public Pattern, public NamedEntity {
  public:
   BindingPattern(SourceLocation source_loc, std::optional<std::string> name,
                  Nonnull<Pattern*> type)
-      : Pattern(Kind::BindingPattern, source_loc),
+      : AstNode(AstNodeKind::BindingPattern, source_loc),
         name_(std::move(name)),
         type_(type) {}
 
-  auto named_entity_kind() const -> NamedEntityKind override {
-    return NamedEntityKind::BindingPattern;
-  }
-
-  auto source_loc() const -> SourceLocation override {
-    return Pattern::source_loc();
-  }
-
-  static auto classof(const Pattern* pattern) -> bool {
-    return pattern->kind() == Kind::BindingPattern;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromBindingPattern(node->kind());
   }
 
   // The name this pattern binds, if any.
@@ -137,10 +125,11 @@ class BindingPattern : public Pattern, public NamedEntityInterface {
 class TuplePattern : public Pattern {
  public:
   TuplePattern(SourceLocation source_loc, std::vector<Nonnull<Pattern*>> fields)
-      : Pattern(Kind::TuplePattern, source_loc), fields_(std::move(fields)) {}
+      : AstNode(AstNodeKind::TuplePattern, source_loc),
+        fields_(std::move(fields)) {}
 
-  static auto classof(const Pattern* pattern) -> bool {
-    return pattern->kind() == Kind::TuplePattern;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromTuplePattern(node->kind());
   }
 
   auto fields() const -> llvm::ArrayRef<Nonnull<const Pattern*>> {
@@ -182,7 +171,7 @@ class AlternativePattern : public Pattern {
                      Nonnull<Expression*> choice_type,
                      std::string alternative_name,
                      Nonnull<TuplePattern*> arguments)
-      : Pattern(Kind::AlternativePattern, source_loc),
+      : AstNode(AstNodeKind::AlternativePattern, source_loc),
         choice_type_(choice_type),
         alternative_name_(std::move(alternative_name)),
         arguments_(arguments) {}
@@ -193,8 +182,8 @@ class AlternativePattern : public Pattern {
                      Nonnull<Expression*> alternative,
                      Nonnull<TuplePattern*> arguments);
 
-  static auto classof(const Pattern* pattern) -> bool {
-    return pattern->kind() == Kind::AlternativePattern;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromAlternativePattern(node->kind());
   }
 
   auto choice_type() const -> const Expression& { return *choice_type_; }
@@ -216,11 +205,11 @@ class AlternativePattern : public Pattern {
 class ExpressionPattern : public Pattern {
  public:
   explicit ExpressionPattern(Nonnull<Expression*> expression)
-      : Pattern(Kind::ExpressionPattern, expression->source_loc()),
+      : AstNode(AstNodeKind::ExpressionPattern, expression->source_loc()),
         expression_(expression) {}
 
-  static auto classof(const Pattern* pattern) -> bool {
-    return pattern->kind() == Kind::ExpressionPattern;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromExpressionPattern(node->kind());
   }
 
   auto expression() const -> const Expression& { return *expression_; }

+ 15 - 13
executable_semantics/ast/statement.cpp

@@ -12,13 +12,15 @@ namespace Carbon {
 
 using llvm::cast;
 
+Statement::~Statement() = default;
+
 void Statement::PrintDepth(int depth, llvm::raw_ostream& out) const {
   if (depth == 0) {
     out << " ... ";
     return;
   }
   switch (kind()) {
-    case Kind::Match: {
+    case StatementKind::Match: {
       const auto& match = cast<Match>(*this);
       out << "match (" << match.expression() << ") {";
       if (depth < 0 || depth > 1) {
@@ -34,32 +36,32 @@ void Statement::PrintDepth(int depth, llvm::raw_ostream& out) const {
       out << "}";
       break;
     }
-    case Kind::While: {
+    case StatementKind::While: {
       const auto& while_stmt = cast<While>(*this);
       out << "while (" << while_stmt.condition() << ")\n";
       while_stmt.body().PrintDepth(depth - 1, out);
       break;
     }
-    case Kind::Break:
+    case StatementKind::Break:
       out << "break;";
       break;
-    case Kind::Continue:
+    case StatementKind::Continue:
       out << "continue;";
       break;
-    case Kind::VariableDefinition: {
+    case StatementKind::VariableDefinition: {
       const auto& var = cast<VariableDefinition>(*this);
       out << "var " << var.pattern() << " = " << var.init() << ";";
       break;
     }
-    case Kind::ExpressionStatement:
+    case StatementKind::ExpressionStatement:
       out << cast<ExpressionStatement>(*this).expression() << ";";
       break;
-    case Kind::Assign: {
+    case StatementKind::Assign: {
       const auto& assign = cast<Assign>(*this);
       out << assign.lhs() << " = " << assign.rhs() << ";";
       break;
     }
-    case Kind::If: {
+    case StatementKind::If: {
       const auto& if_stmt = cast<If>(*this);
       out << "if (" << if_stmt.condition() << ")\n";
       if_stmt.then_block().PrintDepth(depth - 1, out);
@@ -69,7 +71,7 @@ void Statement::PrintDepth(int depth, llvm::raw_ostream& out) const {
       }
       break;
     }
-    case Kind::Return: {
+    case StatementKind::Return: {
       const auto& ret = cast<Return>(*this);
       if (ret.is_omitted_expression()) {
         out << "return;";
@@ -78,7 +80,7 @@ void Statement::PrintDepth(int depth, llvm::raw_ostream& out) const {
       }
       break;
     }
-    case Kind::Block: {
+    case StatementKind::Block: {
       const auto& block = cast<Block>(*this);
       out << "{";
       if (depth < 0 || depth > 1) {
@@ -96,7 +98,7 @@ void Statement::PrintDepth(int depth, llvm::raw_ostream& out) const {
       }
       break;
     }
-    case Kind::Continuation: {
+    case StatementKind::Continuation: {
       const auto& cont = cast<Continuation>(*this);
       out << "continuation " << cont.continuation_variable() << " ";
       if (depth < 0 || depth > 1) {
@@ -108,10 +110,10 @@ void Statement::PrintDepth(int depth, llvm::raw_ostream& out) const {
       }
       break;
     }
-    case Kind::Run:
+    case StatementKind::Run:
       out << "run " << cast<Run>(*this).argument() << ";";
       break;
-    case Kind::Await:
+    case StatementKind::Await:
       out << "await;";
       break;
   }

+ 50 - 76
executable_semantics/ast/statement.h

@@ -21,53 +21,35 @@ namespace Carbon {
 class FunctionDeclaration;
 class StaticScope;
 
-class Statement {
+class Statement : public virtual AstNode {
  public:
-  enum class Kind {
-    ExpressionStatement,
-    Assign,
-    VariableDefinition,
-    If,
-    Return,
-    Block,
-    While,
-    Break,
-    Continue,
-    Match,
-    Continuation,  // Create a first-class continuation.
-    Run,           // Run a continuation to the next await or until it finishes.
-    Await,         // Pause execution of the continuation.
-  };
+  ~Statement() override = 0;
 
   void Print(llvm::raw_ostream& out) const { PrintDepth(-1, out); }
   void PrintDepth(int depth, llvm::raw_ostream& out) const;
   LLVM_DUMP_METHOD void Dump() const { Print(llvm::errs()); }
 
+  static auto classof(const AstNode* node) {
+    return InheritsFromStatement(node->kind());
+  }
+
   // Returns the enumerator corresponding to the most-derived type of this
   // object.
-  auto kind() const -> Kind { return kind_; }
-
-  auto source_loc() const -> SourceLocation { return source_loc_; }
+  auto kind() const -> StatementKind {
+    return static_cast<StatementKind>(root_kind());
+  }
 
  protected:
-  // Constructs an Statement representing syntax at the given line number.
-  // `kind` must be the enumerator corresponding to the most-derived type being
-  // constructed.
-  Statement(Kind kind, SourceLocation source_loc)
-      : kind_(kind), source_loc_(source_loc) {}
-
- private:
-  const Kind kind_;
-  SourceLocation source_loc_;
+  Statement() = default;
 };
 
 class Block : public Statement {
  public:
   Block(SourceLocation source_loc, std::vector<Nonnull<Statement*>> statements)
-      : Statement(Kind::Block, source_loc), statements_(statements) {}
+      : AstNode(AstNodeKind::Block, source_loc), statements_(statements) {}
 
-  static auto classof(const Statement* stmt) -> bool {
-    return stmt->kind() == Kind::Block;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromBlock(node->kind());
   }
 
   auto statements() const -> llvm::ArrayRef<Nonnull<const Statement*>> {
@@ -89,11 +71,11 @@ class ExpressionStatement : public Statement {
  public:
   ExpressionStatement(SourceLocation source_loc,
                       Nonnull<Expression*> expression)
-      : Statement(Kind::ExpressionStatement, source_loc),
+      : AstNode(AstNodeKind::ExpressionStatement, source_loc),
         expression_(expression) {}
 
-  static auto classof(const Statement* stmt) -> bool {
-    return stmt->kind() == Kind::ExpressionStatement;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromExpressionStatement(node->kind());
   }
 
   auto expression() const -> const Expression& { return *expression_; }
@@ -107,10 +89,10 @@ class Assign : public Statement {
  public:
   Assign(SourceLocation source_loc, Nonnull<Expression*> lhs,
          Nonnull<Expression*> rhs)
-      : Statement(Kind::Assign, source_loc), lhs_(lhs), rhs_(rhs) {}
+      : AstNode(AstNodeKind::Assign, source_loc), lhs_(lhs), rhs_(rhs) {}
 
-  static auto classof(const Statement* stmt) -> bool {
-    return stmt->kind() == Kind::Assign;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromAssign(node->kind());
   }
 
   auto lhs() const -> const Expression& { return *lhs_; }
@@ -127,12 +109,12 @@ class VariableDefinition : public Statement {
  public:
   VariableDefinition(SourceLocation source_loc, Nonnull<Pattern*> pattern,
                      Nonnull<Expression*> init)
-      : Statement(Kind::VariableDefinition, source_loc),
+      : AstNode(AstNodeKind::VariableDefinition, source_loc),
         pattern_(pattern),
         init_(init) {}
 
-  static auto classof(const Statement* stmt) -> bool {
-    return stmt->kind() == Kind::VariableDefinition;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromVariableDefinition(node->kind());
   }
 
   auto pattern() const -> const Pattern& { return *pattern_; }
@@ -149,13 +131,13 @@ class If : public Statement {
  public:
   If(SourceLocation source_loc, Nonnull<Expression*> condition,
      Nonnull<Block*> then_block, std::optional<Nonnull<Block*>> else_block)
-      : Statement(Kind::If, source_loc),
+      : AstNode(AstNodeKind::If, source_loc),
         condition_(condition),
         then_block_(then_block),
         else_block_(else_block) {}
 
-  static auto classof(const Statement* stmt) -> bool {
-    return stmt->kind() == Kind::If;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromIf(node->kind());
   }
 
   auto condition() const -> const Expression& { return *condition_; }
@@ -179,12 +161,12 @@ class Return : public Statement {
       : Return(source_loc, arena->New<TupleLiteral>(source_loc), true) {}
   Return(SourceLocation source_loc, Nonnull<Expression*> expression,
          bool is_omitted_expression)
-      : Statement(Kind::Return, source_loc),
+      : AstNode(AstNodeKind::Return, source_loc),
         expression_(expression),
         is_omitted_expression_(is_omitted_expression) {}
 
-  static auto classof(const Statement* stmt) -> bool {
-    return stmt->kind() == Kind::Return;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromReturn(node->kind());
   }
 
   auto expression() const -> const Expression& { return *expression_; }
@@ -216,12 +198,12 @@ class While : public Statement {
  public:
   While(SourceLocation source_loc, Nonnull<Expression*> condition,
         Nonnull<Block*> body)
-      : Statement(Kind::While, source_loc),
+      : AstNode(AstNodeKind::While, source_loc),
         condition_(condition),
         body_(body) {}
 
-  static auto classof(const Statement* stmt) -> bool {
-    return stmt->kind() == Kind::While;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromWhile(node->kind());
   }
 
   auto condition() const -> const Expression& { return *condition_; }
@@ -237,10 +219,10 @@ class While : public Statement {
 class Break : public Statement {
  public:
   explicit Break(SourceLocation source_loc)
-      : Statement(Kind::Break, source_loc) {}
+      : AstNode(AstNodeKind::Break, source_loc) {}
 
-  static auto classof(const Statement* stmt) -> bool {
-    return stmt->kind() == Kind::Break;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromBreak(node->kind());
   }
 
   // The AST node representing the loop this statement breaks out of.
@@ -264,10 +246,10 @@ class Break : public Statement {
 class Continue : public Statement {
  public:
   explicit Continue(SourceLocation source_loc)
-      : Statement(Kind::Continue, source_loc) {}
+      : AstNode(AstNodeKind::Continue, source_loc) {}
 
-  static auto classof(const Statement* stmt) -> bool {
-    return stmt->kind() == Kind::Continue;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromContinue(node->kind());
   }
 
   // The AST node representing the loop this statement continues.
@@ -313,12 +295,12 @@ class Match : public Statement {
 
   Match(SourceLocation source_loc, Nonnull<Expression*> expression,
         std::vector<Clause> clauses)
-      : Statement(Kind::Match, source_loc),
+      : AstNode(AstNodeKind::Match, source_loc),
         expression_(expression),
         clauses_(std::move(clauses)) {}
 
-  static auto classof(const Statement* stmt) -> bool {
-    return stmt->kind() == Kind::Match;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromMatch(node->kind());
   }
 
   auto expression() const -> const Expression& { return *expression_; }
@@ -336,24 +318,16 @@ class Match : public Statement {
 //     __continuation <continuation_variable> {
 //       <body>
 //     }
-class Continuation : public Statement, public NamedEntityInterface {
+class Continuation : public Statement, public NamedEntity {
  public:
   Continuation(SourceLocation source_loc, std::string continuation_variable,
                Nonnull<Block*> body)
-      : Statement(Kind::Continuation, source_loc),
+      : AstNode(AstNodeKind::Continuation, source_loc),
         continuation_variable_(std::move(continuation_variable)),
         body_(body) {}
 
-  auto named_entity_kind() const -> NamedEntityKind override {
-    return NamedEntityKind::Continuation;
-  }
-
-  auto source_loc() const -> SourceLocation override {
-    return Statement::source_loc();
-  }
-
-  static auto classof(const Statement* stmt) -> bool {
-    return stmt->kind() == Kind::Continuation;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromContinuation(node->kind());
   }
 
   auto continuation_variable() const -> const std::string& {
@@ -373,10 +347,10 @@ class Continuation : public Statement, public NamedEntityInterface {
 class Run : public Statement {
  public:
   Run(SourceLocation source_loc, Nonnull<Expression*> argument)
-      : Statement(Kind::Run, source_loc), argument_(argument) {}
+      : AstNode(AstNodeKind::Run, source_loc), argument_(argument) {}
 
-  static auto classof(const Statement* stmt) -> bool {
-    return stmt->kind() == Kind::Run;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromRun(node->kind());
   }
 
   auto argument() const -> const Expression& { return *argument_; }
@@ -392,10 +366,10 @@ class Run : public Statement {
 class Await : public Statement {
  public:
   explicit Await(SourceLocation source_loc)
-      : Statement(Kind::Await, source_loc) {}
+      : AstNode(AstNodeKind::Await, source_loc) {}
 
-  static auto classof(const Statement* stmt) -> bool {
-    return stmt->kind() == Kind::Await;
+  static auto classof(const AstNode* node) -> bool {
+    return InheritsFromAwait(node->kind());
   }
 };
 

+ 3 - 2
executable_semantics/ast/static_scope.cpp

@@ -8,8 +8,9 @@
 
 namespace Carbon {
 
-void StaticScope::Add(std::string name,
-                      Nonnull<const NamedEntityInterface*> entity) {
+NamedEntity::~NamedEntity() = default;
+
+void StaticScope::Add(std::string name, Nonnull<const NamedEntity*> entity) {
   if (!declared_names_.insert({name, entity}).second) {
     FATAL_COMPILATION_ERROR(entity->source_loc())
         << "Duplicate name `" << name << "` also found at "

+ 9 - 26
executable_semantics/ast/static_scope.h

@@ -10,50 +10,33 @@
 #include <variant>
 #include <vector>
 
+#include "executable_semantics/ast/ast_node.h"
 #include "executable_semantics/ast/source_location.h"
 #include "executable_semantics/common/nonnull.h"
 
 namespace Carbon {
 
-class NamedEntityInterface {
+class NamedEntity : public virtual AstNode {
  public:
-  enum class NamedEntityKind {
-    // Includes variable definitions and matching contexts.
-    BindingPattern,
-    // Used by entries in choices.
-    ChoiceDeclarationAlternative,
-    // Used by continuations.
-    Continuation,
-    // Includes choices, classes, and functions. Variables are handled through
-    // BindingPattern.
-    Declaration,
-    // Used by functions.
-    GenericBinding,
-    // Used by entries in classes.
-    Member,
-  };
+  virtual ~NamedEntity() = 0;
 
-  NamedEntityInterface() = default;
-  virtual ~NamedEntityInterface() = default;
-
-  NamedEntityInterface(NamedEntityInterface&&) = delete;
-  auto operator=(NamedEntityInterface&&) -> NamedEntityInterface& = delete;
+  NamedEntity() = default;
 
   // TODO: This is unused, but is intended for casts after lookup.
-  virtual auto named_entity_kind() const -> NamedEntityKind = 0;
-  virtual auto source_loc() const -> SourceLocation = 0;
+  auto kind() const -> NamedEntityKind {
+    return static_cast<NamedEntityKind>(root_kind());
+  }
 };
 
 // The set of declared names in a scope. This is not aware of child scopes, but
 // does include directions to parent or related scopes for lookup purposes.
 class StaticScope {
  public:
-  void Add(std::string name, Nonnull<const NamedEntityInterface*> entity);
+  void Add(std::string name, Nonnull<const NamedEntity*> entity);
 
  private:
   // Maps locally declared names to their entities.
-  std::unordered_map<std::string, Nonnull<const NamedEntityInterface*>>
-      declared_names_;
+  std::unordered_map<std::string, Nonnull<const NamedEntity*>> declared_names_;
 
   // A list of scopes used for name lookup within this scope.
   // TODO: This is unused, but is intended for name lookup cross-scope.

+ 332 - 0
executable_semantics/gen_rtti.py

@@ -0,0 +1,332 @@
+#!/usr/bin/env python3
+
+"""Generates C++ header to support LLVM-style RTTI for a class hierarchy.
+
+Takes as input a file describing the class hierarchy which can consist of
+four different kinds of classes: a *root* class is the base of a class
+hierarchy, meaning that it doesn't inherit from any other class. *Abstract* and
+*interface* classes are non-root classes that cannot be instantiated, and
+*concrete* classes are classes that can be instantiated.
+
+A non-root class C must inherit from exactly one parent, which can be a root or
+abstract class, and can also inherit from any number of interfaces, but each
+interface's parent must be an ancestor of C.
+
+The input file consists of comment lines starting with `#`, whitespace lines,
+and one `;`-terminated line for each class. The core of a line is `class`
+followed by the class name. `class` can be prefixed with `root`, `abstract`,
+or `interface` to specify the corresponding kind of class; if there is no
+prefix, the class is concrete. If the class is not a root class, the name is
+followed by `:` and then a comma-separated list of the names of the classes
+it inherits from. The first entry in the list is the parent, and the others
+are interfaces. A class cannot inherit from classes defined later in the file.
+For example:
+
+root class R;
+abstract class A : R;
+interface class I : R;
+abstract class B : R, I;
+class C : A;
+class D : B;
+class E : A, I;
+
+For each non-concrete class `Foo`, the generated header file will contain
+`enum class FooKind`, which has an enumerator for each concrete class derived
+from `Foo`, with a name that matches the concrete class name.
+
+For each non-root class `Foo` whose root class is `Root`, the generated header
+file will also contain a function `bool InheritsFromFoo(RootKind kind)`,
+which returns true if the value of `kind` corresponds to a class that is
+derived from `Foo`. This function can be used to implement `Foo::classof`.
+
+All enumerators that represent the same concrete class will have the same
+numeric value, so you can use `static_cast` to convert between the enum types
+for different classes that have a common root, so long as the enumerator value
+is present in both types. As a result, `InheritsFromFoo` can be used to
+determine whether casting to `FooKind` is safe.
+"""
+
+__copyright__ = """
+Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+Exceptions. See /LICENSE for license information.
+SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+"""
+
+import enum
+import re
+import sys
+
+
+class Class:
+    """Metadata about a class from the input file.
+
+    This consists of information
+
+    Attributes set at construction:
+      name: The class name.
+      kind: The class kind (root, abstract, interface, or concrete)
+      ancestors: A list of Class objects representing the class's ancestors,
+        starting with the root and ending with the current class's parent.
+      interfaces: A list of Class objects representing the interfaces the class
+        inherits from.
+      _children: A list of Class objects representing the classes that are
+        derived directly from this one.
+
+    Attributes set by Finalize():
+      id (CONCRETE only): The class's numeric ID, which will become its
+        enumerator value in the generated C++ code.
+      id_range (ROOT and ABSTRACT only): A pair such that a Class
+        object `c` represents a concrete class derived from `self` if and only
+        if c.id >= self.id_range[0] and c.id < self.id_range[1].
+      leaf_ids (INTERFACE only): A set containing the IDs of all concrete
+        classes derived from this interface.
+      leaves (ROOT only): A list of all concrete classes derived from this one,
+        indexed by their IDs.
+    """
+
+    Kind = enum.Enum("Kind", "ROOT ABSTRACT INTERFACE CONCRETE")
+
+    def __init__(self, name, kind, parent, interfaces):
+        self.name = name
+        self.kind = kind
+        self.interfaces = interfaces
+
+        assert (parent is None) == (kind == Class.Kind.ROOT)
+        if parent is None:
+            self.ancestors = []
+        else:
+            self.ancestors = parent.ancestors + [parent]
+
+        if self.kind == Class.Kind.ROOT:
+            self.leaves = []
+            self.id_range = None
+        elif self.kind == Class.Kind.ABSTRACT:
+            self.id_range = None
+        elif self.kind == Class.Kind.INTERFACE:
+            self.leaf_ids = set()
+        else:
+            self.id = None
+
+        if self.kind != Class.Kind.CONCRETE:
+            self._children = []
+
+        if parent:
+            parent._children.append(self)
+
+        for interface in self.interfaces:
+            interface._children.append(self)
+
+    def Parent(self):
+        """Returns this Class's parent."""
+        return self.ancestors[-1]
+
+    def Root(self):
+        """Returns the root Class of this hierarchy."""
+        if self.kind == Class.Kind.ROOT:
+            return self
+        else:
+            return self.ancestors[0]
+
+    def _RegisterLeaf(self, leaf):
+        """Records that `leaf` is derived from self.
+
+        Also recursively updates the parent and interfaces of self. leaf.id must
+        already be populated, and leaves must be registered in order of ID. This
+        operation is idempotent."""
+        already_visited = False
+        if self.kind == Class.Kind.ROOT:
+            if leaf.id == len(self.leaves):
+                self.leaves.append(leaf)
+            else:
+                assert leaf.id + 1 == len(self.leaves)
+                assert self.leaves[leaf.id] == leaf
+                already_visited = True
+        if self.kind in [Class.Kind.ROOT, Class.Kind.ABSTRACT]:
+            if self not in leaf.ancestors:
+                sys.exit(
+                    f"{leaf.name} derived from {self.name}, but has a"
+                    + " different root"
+                )
+            if not self.id_range:
+                self.id_range = (leaf.id, leaf.id + 1)
+            elif self.id_range[1] == leaf.id:
+                self.id_range = (self.id_range[0], self.id_range[1] + 1)
+            else:
+                assert self.id_range[1] == leaf.id + 1
+                already_visited = True
+
+        elif self.kind == Class.Kind.INTERFACE:
+            if leaf.id in self.leaf_ids:
+                already_visited = True
+            else:
+                self.leaf_ids.add(leaf.id)
+
+        if not already_visited:
+            if self.kind != Class.Kind.ROOT:
+                self.Parent()._RegisterLeaf(leaf)
+            for interface in self.interfaces:
+                interface._RegisterLeaf(leaf)
+
+    def Finalize(self):
+        """Populates additional attributes for `self` and derived Classes.
+
+        Each Class can only be finalized once, after which no additional Classes
+        can be derived from it.
+        """
+        if self.kind == Class.Kind.CONCRETE:
+            self.id = len(self.Root().leaves)
+            self._RegisterLeaf(self)
+        elif self.kind in [Class.Kind.ROOT, Class.Kind.ABSTRACT]:
+            for child in self._children:
+                child.Finalize()
+
+
+_LINE_PATTERN = r"""(?P<prefix> \w*) \s*
+                 class \s+
+                 (?P<name> \w+)
+                 (?: \s*:\s* (?P<parent> \w+)
+                   (?: , (?P<interfaces> .*) )?
+                 )?
+                 ;$"""
+
+
+def main():
+    input_filename = sys.argv[1]
+    with open(input_filename) as file:
+        lines = file.readlines()
+
+    classes = dict()
+    for line_num, line in enumerate(lines, 1):
+        if line.startswith("#") or line.strip() == "":
+            continue
+        match_result = re.match(_LINE_PATTERN, line.strip(), re.VERBOSE)
+        if not match_result:
+            sys.exit(f"Invalid format on line {line_num}")
+
+        prefix = match_result.group("prefix")
+        if prefix == "":
+            kind = Class.Kind.CONCRETE
+        elif prefix == "root":
+            kind = Class.Kind.ROOT
+        elif prefix == "abstract":
+            kind = Class.Kind.ABSTRACT
+        elif prefix == "interface":
+            kind = Class.Kind.INTERFACE
+        else:
+            sys.exit(f"Unrecognized class prefix '{prefix}' on line {line_num}")
+
+        parent = None
+        if match_result.group("parent"):
+            if kind == Class.Kind.ROOT:
+                sys.exit(f"Root class cannot have parent on line {line_num}")
+            parent_name = match_result.group("parent")
+            parent = classes[parent_name]
+            if not parent:
+                sys.exit(f"Unknown class '{parent_name}' on line {line_num}")
+            if parent.kind == Class.Kind.CONCRETE:
+                sys.exit(f"{parent.name} cannot be a parent on line {line_num}")
+            elif parent.kind == Class.Kind.INTERFACE:
+                if kind != Class.Kind.INTERFACE:
+                    sys.exit(
+                        "Interface cannot be parent of non-interface on"
+                        + f" line {line_num}"
+                    )
+        else:
+            if kind != Class.Kind.ROOT:
+                sys.exit(
+                    f"Non-root class must have a parent on line {line_num}"
+                )
+
+        interfaces = []
+        if match_result.group("interfaces"):
+            for unstripped_name in match_result.group("interfaces").split(","):
+                interface_name = unstripped_name.strip()
+                interface = classes[interface_name]
+                if not interface:
+                    sys.exit(
+                        f"Unknown class '{interface_name}' on line {line_num}"
+                    )
+                if interface.kind != Class.Kind.INTERFACE:
+                    sys.exit(
+                        f"'{interface_name}' used as interface on"
+                        + f" line {line_num}"
+                    )
+                interfaces.append(interface)
+
+        classes[match_result.group("name")] = Class(
+            match_result.group("name"), kind, parent, interfaces
+        )
+
+    for node in classes.values():
+        if node.kind == Class.Kind.ROOT:
+            node.Finalize()
+
+    print(
+        f"// Generated from {input_filename} by"
+        + " executable_semantics/gen_rtti.py\n"
+    )
+    guard_macro = (
+        input_filename.upper().translate(str.maketrans({"/": "_", ".": "_"}))
+        + "_"
+    )
+    print(f"#ifndef {guard_macro}")
+    print(f"#define {guard_macro}")
+    print("\nnamespace Carbon {\n")
+
+    for node in classes.values():
+        if node.kind != Class.Kind.CONCRETE:
+            if node.kind == Class.Kind.INTERFACE:
+                ids = sorted(node.leaf_ids)
+            else:
+                ids = range(node.id_range[0], node.id_range[1])
+            print(f"enum class {node.name}Kind {{")
+            for id in ids:
+                print(f"  {node.Root().leaves[id].name} = {id},")
+            print("};\n")
+
+        if node.kind != Class.Kind.ROOT:
+            print(
+                f"inline bool InheritsFrom{node.name}({node.Root().name}Kind"
+                + " kind) {"
+            )
+            if node.kind == Class.Kind.ABSTRACT:
+                if node.id_range[0] == node.id_range[1]:
+                    print("  return false;")
+                else:
+                    range_begin = node.Root().leaves[node.id_range[0]].name
+                    print(
+                        f"  return kind >= {node.Root().name}Kind"
+                        + f"::{range_begin}"
+                    )
+                    if node.id_range[1] < len(node.Root().leaves):
+                        range_end = node.Root().leaves[node.id_range[1]].name
+                        print(
+                            f"      && kind < {node.Root().name}Kind"
+                            + f"::{range_end}"
+                        )
+                    print("      ;")
+            elif node.kind == Class.Kind.INTERFACE:
+                print("  switch(kind) {")
+                is_empty = True
+                for id in sorted(node.leaf_ids):
+                    print(
+                        f"    case {node.Root().name}Kind::"
+                        + f"{node.Root().leaves[id].name}:"
+                    )
+                    is_empty = False
+                if not is_empty:
+                    print("      return true;")
+                print("    default:")
+                print("      return false;\n  }")
+            elif node.kind == Class.Kind.CONCRETE:
+                print(
+                    f"    return kind == {node.Root().name}Kind::{node.name};"
+                )
+            print("}\n")
+
+    print("}  // namespace Carbon\n")
+    print(f"#endif  // {guard_macro}")
+
+
+if __name__ == "__main__":
+    main()

+ 60 - 60
executable_semantics/interpreter/interpreter.cpp

@@ -111,7 +111,7 @@ auto Interpreter::EvalPrim(Operator op,
 
 void Interpreter::InitEnv(const Declaration& d, Env* env) {
   switch (d.kind()) {
-    case Declaration::Kind::FunctionDeclaration: {
+    case DeclarationKind::FunctionDeclaration: {
       const auto& func_def = cast<FunctionDeclaration>(d);
       Env new_env = *env;
       // Bring the deduced parameters into scope.
@@ -127,13 +127,13 @@ void Interpreter::InitEnv(const Declaration& d, Env* env) {
       break;
     }
 
-    case Declaration::Kind::ClassDeclaration: {
+    case DeclarationKind::ClassDeclaration: {
       const auto& class_decl = cast<ClassDeclaration>(d);
       std::vector<NamedValue> fields;
       std::vector<NamedValue> methods;
       for (Nonnull<const Member*> m : class_decl.members()) {
         switch (m->kind()) {
-          case Member::Kind::FieldMember: {
+          case MemberKind::FieldMember: {
             const BindingPattern& binding = cast<FieldMember>(*m).binding();
             const Expression& type_expression =
                 cast<ExpressionPattern>(binding.type()).expression();
@@ -150,10 +150,10 @@ void Interpreter::InitEnv(const Declaration& d, Env* env) {
       break;
     }
 
-    case Declaration::Kind::ChoiceDeclaration: {
+    case DeclarationKind::ChoiceDeclaration: {
       const auto& choice = cast<ChoiceDeclaration>(d);
       std::vector<NamedValue> alts;
-      for (Nonnull<const ChoiceDeclaration::Alternative*> alternative :
+      for (Nonnull<const AlternativeSignature*> alternative :
            choice.alternatives()) {
         auto t = InterpExp(Env(arena_), &alternative->signature());
         alts.push_back({.name = alternative->name(), .value = t});
@@ -164,7 +164,7 @@ void Interpreter::InitEnv(const Declaration& d, Env* env) {
       break;
     }
 
-    case Declaration::Kind::VariableDeclaration: {
+    case DeclarationKind::VariableDeclaration: {
       const auto& var = cast<VariableDeclaration>(d);
       // Adds an entry in `globals` mapping the variable's name to the
       // result of evaluating the initializer.
@@ -363,7 +363,7 @@ auto Interpreter::StepLvalue() -> Transition {
                  << ") --->\n";
   }
   switch (exp.kind()) {
-    case Expression::Kind::IdentifierExpression: {
+    case ExpressionKind::IdentifierExpression: {
       //    { {x :: C, E, F} :: S, H}
       // -> { {E(x) :: C, E, F} :: S, H}
       Address pointer =
@@ -371,7 +371,7 @@ auto Interpreter::StepLvalue() -> Transition {
       Nonnull<const Value*> v = arena_->New<PointerValue>(pointer);
       return Done{v};
     }
-    case Expression::Kind::FieldAccessExpression: {
+    case ExpressionKind::FieldAccessExpression: {
       if (act.pos() == 0) {
         //    { {e.f :: C, E, F} :: S, H}
         // -> { e :: [].f :: C, E, F} :: S, H}
@@ -386,7 +386,7 @@ auto Interpreter::StepLvalue() -> Transition {
         return Done{arena_->New<PointerValue>(field)};
       }
     }
-    case Expression::Kind::IndexExpression: {
+    case ExpressionKind::IndexExpression: {
       if (act.pos() == 0) {
         //    { {e[i] :: C, E, F} :: S, H}
         // -> { e :: [][i] :: C, E, F} :: S, H}
@@ -406,7 +406,7 @@ auto Interpreter::StepLvalue() -> Transition {
         return Done{arena_->New<PointerValue>(field)};
       }
     }
-    case Expression::Kind::TupleLiteral: {
+    case ExpressionKind::TupleLiteral: {
       if (act.pos() <
           static_cast<int>(cast<TupleLiteral>(exp).fields().size())) {
         //    { { vk :: (f1=v1,..., fk=[],fk+1=ek+1,...) :: C, E, F} :: S,
@@ -419,20 +419,20 @@ auto Interpreter::StepLvalue() -> Transition {
         return Done{arena_->New<TupleValue>(act.results())};
       }
     }
-    case Expression::Kind::StructLiteral:
-    case Expression::Kind::StructTypeLiteral:
-    case Expression::Kind::IntLiteral:
-    case Expression::Kind::BoolLiteral:
-    case Expression::Kind::CallExpression:
-    case Expression::Kind::PrimitiveOperatorExpression:
-    case Expression::Kind::IntTypeLiteral:
-    case Expression::Kind::BoolTypeLiteral:
-    case Expression::Kind::TypeTypeLiteral:
-    case Expression::Kind::FunctionTypeLiteral:
-    case Expression::Kind::ContinuationTypeLiteral:
-    case Expression::Kind::StringLiteral:
-    case Expression::Kind::StringTypeLiteral:
-    case Expression::Kind::IntrinsicExpression:
+    case ExpressionKind::StructLiteral:
+    case ExpressionKind::StructTypeLiteral:
+    case ExpressionKind::IntLiteral:
+    case ExpressionKind::BoolLiteral:
+    case ExpressionKind::CallExpression:
+    case ExpressionKind::PrimitiveOperatorExpression:
+    case ExpressionKind::IntTypeLiteral:
+    case ExpressionKind::BoolTypeLiteral:
+    case ExpressionKind::TypeTypeLiteral:
+    case ExpressionKind::FunctionTypeLiteral:
+    case ExpressionKind::ContinuationTypeLiteral:
+    case ExpressionKind::StringLiteral:
+    case ExpressionKind::StringTypeLiteral:
+    case ExpressionKind::IntrinsicExpression:
       FATAL_RUNTIME_ERROR_NO_LINE()
           << "Can't treat expression as lvalue: " << exp;
   }
@@ -513,7 +513,7 @@ auto Interpreter::StepExp() -> Transition {
                  << ") --->\n";
   }
   switch (exp.kind()) {
-    case Expression::Kind::IndexExpression: {
+    case ExpressionKind::IndexExpression: {
       if (act.pos() == 0) {
         //    { { e[i] :: C, E, F} :: S, H}
         // -> { { e :: [][i] :: C, E, F} :: S, H}
@@ -534,7 +534,7 @@ auto Interpreter::StepExp() -> Transition {
         return Done{tuple.elements()[i]};
       }
     }
-    case Expression::Kind::TupleLiteral: {
+    case ExpressionKind::TupleLiteral: {
       if (act.pos() <
           static_cast<int>(cast<TupleLiteral>(exp).fields().size())) {
         //    { { vk :: (f1=v1,..., fk=[],fk+1=ek+1,...) :: C, E, F} :: S,
@@ -547,7 +547,7 @@ auto Interpreter::StepExp() -> Transition {
         return Done{arena_->New<TupleValue>(act.results())};
       }
     }
-    case Expression::Kind::StructLiteral: {
+    case ExpressionKind::StructLiteral: {
       const auto& literal = cast<StructLiteral>(exp);
       if (act.pos() < static_cast<int>(literal.fields().size())) {
         return Spawn{std::make_unique<ExpressionAction>(
@@ -556,7 +556,7 @@ auto Interpreter::StepExp() -> Transition {
         return Done{CreateStruct(literal.fields(), act.results())};
       }
     }
-    case Expression::Kind::StructTypeLiteral: {
+    case ExpressionKind::StructTypeLiteral: {
       const auto& struct_type = cast<StructTypeLiteral>(exp);
       if (act.pos() < static_cast<int>(struct_type.fields().size())) {
         return Spawn{std::make_unique<ExpressionAction>(
@@ -569,7 +569,7 @@ auto Interpreter::StepExp() -> Transition {
         return Done{arena_->New<StructType>(std::move(fields))};
       }
     }
-    case Expression::Kind::FieldAccessExpression: {
+    case ExpressionKind::FieldAccessExpression: {
       const auto& access = cast<FieldAccessExpression>(exp);
       if (act.pos() == 0) {
         //    { { e.f :: C, E, F} :: S, H}
@@ -582,22 +582,22 @@ auto Interpreter::StepExp() -> Transition {
             arena_, FieldPath(access.field()), exp.source_loc())};
       }
     }
-    case Expression::Kind::IdentifierExpression: {
+    case ExpressionKind::IdentifierExpression: {
       CHECK(act.pos() == 0);
       const auto& ident = cast<IdentifierExpression>(exp);
       // { {x :: C, E, F} :: S, H} -> { {H(E(x)) :: C, E, F} :: S, H}
       Address pointer = GetFromEnv(exp.source_loc(), ident.name());
       return Done{heap_.Read(pointer, exp.source_loc())};
     }
-    case Expression::Kind::IntLiteral:
+    case ExpressionKind::IntLiteral:
       CHECK(act.pos() == 0);
       // { {n :: C, E, F} :: S, H} -> { {n' :: C, E, F} :: S, H}
       return Done{arena_->New<IntValue>(cast<IntLiteral>(exp).value())};
-    case Expression::Kind::BoolLiteral:
+    case ExpressionKind::BoolLiteral:
       CHECK(act.pos() == 0);
       // { {n :: C, E, F} :: S, H} -> { {n' :: C, E, F} :: S, H}
       return Done{arena_->New<BoolValue>(cast<BoolLiteral>(exp).value())};
-    case Expression::Kind::PrimitiveOperatorExpression: {
+    case ExpressionKind::PrimitiveOperatorExpression: {
       const auto& op = cast<PrimitiveOperatorExpression>(exp);
       if (act.pos() != static_cast<int>(op.arguments().size())) {
         //    { {v :: op(vs,[],e,es) :: C, E, F} :: S, H}
@@ -610,7 +610,7 @@ auto Interpreter::StepExp() -> Transition {
         return Done{EvalPrim(op.op(), act.results(), exp.source_loc())};
       }
     }
-    case Expression::Kind::CallExpression:
+    case ExpressionKind::CallExpression:
       if (act.pos() == 0) {
         //    { {e1(e2) :: C, E, F} :: S, H}
         // -> { {e1 :: [](e2) :: C, E, F} :: S, H}
@@ -651,7 +651,7 @@ auto Interpreter::StepExp() -> Transition {
       } else {
         FATAL() << "in handle_value with Call pos " << act.pos();
       }
-    case Expression::Kind::IntrinsicExpression:
+    case ExpressionKind::IntrinsicExpression:
       CHECK(act.pos() == 0);
       // { {n :: C, E, F} :: S, H} -> { {n' :: C, E, F} :: S, H}
       switch (cast<IntrinsicExpression>(exp).intrinsic()) {
@@ -664,19 +664,19 @@ auto Interpreter::StepExp() -> Transition {
           return Done{TupleValue::Empty()};
       }
 
-    case Expression::Kind::IntTypeLiteral: {
+    case ExpressionKind::IntTypeLiteral: {
       CHECK(act.pos() == 0);
       return Done{arena_->New<IntType>()};
     }
-    case Expression::Kind::BoolTypeLiteral: {
+    case ExpressionKind::BoolTypeLiteral: {
       CHECK(act.pos() == 0);
       return Done{arena_->New<BoolType>()};
     }
-    case Expression::Kind::TypeTypeLiteral: {
+    case ExpressionKind::TypeTypeLiteral: {
       CHECK(act.pos() == 0);
       return Done{arena_->New<TypeType>()};
     }
-    case Expression::Kind::FunctionTypeLiteral: {
+    case ExpressionKind::FunctionTypeLiteral: {
       if (act.pos() == 0) {
         return Spawn{std::make_unique<ExpressionAction>(
             &cast<FunctionTypeLiteral>(exp).parameter())};
@@ -693,15 +693,15 @@ auto Interpreter::StepExp() -> Transition {
             act.results()[1])};
       }
     }
-    case Expression::Kind::ContinuationTypeLiteral: {
+    case ExpressionKind::ContinuationTypeLiteral: {
       CHECK(act.pos() == 0);
       return Done{arena_->New<ContinuationType>()};
     }
-    case Expression::Kind::StringLiteral:
+    case ExpressionKind::StringLiteral:
       CHECK(act.pos() == 0);
       // { {n :: C, E, F} :: S, H} -> { {n' :: C, E, F} :: S, H}
       return Done{arena_->New<StringValue>(cast<StringLiteral>(exp).value())};
-    case Expression::Kind::StringTypeLiteral: {
+    case ExpressionKind::StringTypeLiteral: {
       CHECK(act.pos() == 0);
       return Done{arena_->New<StringType>()};
     }
@@ -716,11 +716,11 @@ auto Interpreter::StepPattern() -> Transition {
                  << pattern.source_loc() << ") --->\n";
   }
   switch (pattern.kind()) {
-    case Pattern::Kind::AutoPattern: {
+    case PatternKind::AutoPattern: {
       CHECK(act.pos() == 0);
       return Done{arena_->New<AutoType>()};
     }
-    case Pattern::Kind::BindingPattern: {
+    case PatternKind::BindingPattern: {
       const auto& binding = cast<BindingPattern>(pattern);
       if (act.pos() == 0) {
         return Spawn{std::make_unique<PatternAction>(&binding.type())};
@@ -729,7 +729,7 @@ auto Interpreter::StepPattern() -> Transition {
                                                          act.results()[0])};
       }
     }
-    case Pattern::Kind::TuplePattern: {
+    case PatternKind::TuplePattern: {
       const auto& tuple = cast<TuplePattern>(pattern);
       if (act.pos() < static_cast<int>(tuple.fields().size())) {
         //    { { vk :: (f1=v1,..., fk=[],fk+1=ek+1,...) :: C, E, F} :: S,
@@ -742,7 +742,7 @@ auto Interpreter::StepPattern() -> Transition {
         return Done{arena_->New<TupleValue>(act.results())};
       }
     }
-    case Pattern::Kind::AlternativePattern: {
+    case PatternKind::AlternativePattern: {
       const auto& alternative = cast<AlternativePattern>(pattern);
       if (act.pos() == 0) {
         return Spawn{
@@ -757,7 +757,7 @@ auto Interpreter::StepPattern() -> Transition {
             act.results()[1])};
       }
     }
-    case Pattern::Kind::ExpressionPattern:
+    case PatternKind::ExpressionPattern:
       return Delegate{std::make_unique<ExpressionAction>(
           &cast<ExpressionPattern>(pattern).expression())};
   }
@@ -777,7 +777,7 @@ auto Interpreter::StepStmt() -> Transition {
     llvm::outs() << " (" << stmt.source_loc() << ") --->\n";
   }
   switch (stmt.kind()) {
-    case Statement::Kind::Match: {
+    case StatementKind::Match: {
       const auto& match_stmt = cast<Match>(stmt);
       if (act.pos() == 0) {
         //    { { (match (e) ...) :: C, E, F} :: S, H}
@@ -808,7 +808,7 @@ auto Interpreter::StepStmt() -> Transition {
         }
       }
     }
-    case Statement::Kind::While:
+    case StatementKind::While:
       if (act.pos() % 2 == 0) {
         //    { { (while (e) s) :: C, E, F} :: S, H}
         // -> { { e :: (while ([]) s) :: C, E, F} :: S, H}
@@ -829,19 +829,19 @@ auto Interpreter::StepStmt() -> Transition {
           return Done{};
         }
       }
-    case Statement::Kind::Break: {
+    case StatementKind::Break: {
       CHECK(act.pos() == 0);
       //    { { break; :: ... :: (while (e) s) :: C, E, F} :: S, H}
       // -> { { C, E', F} :: S, H}
       return UnwindPast{.ast_node = &cast<Break>(stmt).loop()};
     }
-    case Statement::Kind::Continue: {
+    case StatementKind::Continue: {
       CHECK(act.pos() == 0);
       //    { { continue; :: ... :: (while (e) s) :: C, E, F} :: S, H}
       // -> { { (while (e) s) :: C, E', F} :: S, H}
       return UnwindTo{.ast_node = &cast<Continue>(stmt).loop()};
     }
-    case Statement::Kind::Block: {
+    case StatementKind::Block: {
       const auto& block = cast<Block>(stmt);
       if (act.pos() >= static_cast<int>(block.statements().size())) {
         // If the position is past the end of the block, end processing. Note
@@ -857,7 +857,7 @@ auto Interpreter::StepStmt() -> Transition {
       return Spawn{
           std::make_unique<StatementAction>(block.statements()[act.pos()])};
     }
-    case Statement::Kind::VariableDefinition: {
+    case StatementKind::VariableDefinition: {
       const auto& definition = cast<VariableDefinition>(stmt);
       if (act.pos() == 0) {
         //    { {(var x = e) :: C, E, F} :: S, H}
@@ -882,7 +882,7 @@ auto Interpreter::StepStmt() -> Transition {
         return Done{};
       }
     }
-    case Statement::Kind::ExpressionStatement:
+    case StatementKind::ExpressionStatement:
       if (act.pos() == 0) {
         //    { {e :: C, E, F} :: S, H}
         // -> { {e :: C, E, F} :: S, H}
@@ -891,7 +891,7 @@ auto Interpreter::StepStmt() -> Transition {
       } else {
         return Done{};
       }
-    case Statement::Kind::Assign: {
+    case StatementKind::Assign: {
       const auto& assign = cast<Assign>(stmt);
       if (act.pos() == 0) {
         //    { {(lv = e) :: C, E, F} :: S, H}
@@ -910,7 +910,7 @@ auto Interpreter::StepStmt() -> Transition {
         return Done{};
       }
     }
-    case Statement::Kind::If:
+    case StatementKind::If:
       if (act.pos() == 0) {
         //    { {(if (e) then_stmt else else_stmt) :: C, E, F} :: S, H}
         // -> { { e :: (if ([]) then_stmt else else_stmt) :: C, E, F} :: S, H}
@@ -935,7 +935,7 @@ auto Interpreter::StepStmt() -> Transition {
           return Done{};
         }
       }
-    case Statement::Kind::Return:
+    case StatementKind::Return:
       if (act.pos() == 0) {
         //    { {return e :: C, E, F} :: S, H}
         // -> { {e :: return [] :: C, E, F} :: S, H}
@@ -950,7 +950,7 @@ auto Interpreter::StepStmt() -> Transition {
             .result = Convert(act.results()[0],
                               &function.return_term().static_type())};
       }
-    case Statement::Kind::Continuation: {
+    case StatementKind::Continuation: {
       CHECK(act.pos() == 0);
       // Create a continuation object by creating a frame similar the
       // way one is created in a function call.
@@ -969,7 +969,7 @@ auto Interpreter::StepStmt() -> Transition {
                               continuation_address);
       return Done{};
     }
-    case Statement::Kind::Run: {
+    case StatementKind::Run: {
       auto& run = cast<Run>(stmt);
       if (act.pos() == 0) {
         // Evaluate the argument of the run statement.
@@ -985,7 +985,7 @@ auto Interpreter::StepStmt() -> Transition {
         return Done{};
       }
     }
-    case Statement::Kind::Await:
+    case StatementKind::Await:
       CHECK(act.pos() == 0);
       // Pause the current continuation
       todo_.Pop();

+ 14 - 14
executable_semantics/interpreter/resolve_control_flow.cpp

@@ -34,7 +34,7 @@ static void ResolveControlFlow(Nonnull<Statement*> statement,
                                std::optional<Nonnull<const Statement*>> loop,
                                std::optional<Nonnull<FunctionData*>> function) {
   switch (statement->kind()) {
-    case Statement::Kind::Return: {
+    case StatementKind::Return: {
       if (!function.has_value()) {
         FATAL_COMPILATION_ERROR(statement->source_loc())
             << "return is not within a function body";
@@ -58,21 +58,21 @@ static void ResolveControlFlow(Nonnull<Statement*> statement,
       }
       return;
     }
-    case Statement::Kind::Break:
+    case StatementKind::Break:
       if (!loop.has_value()) {
         FATAL_COMPILATION_ERROR(statement->source_loc())
             << "break is not within a loop body";
       }
       cast<Break>(*statement).set_loop(*loop);
       return;
-    case Statement::Kind::Continue:
+    case StatementKind::Continue:
       if (!loop.has_value()) {
         FATAL_COMPILATION_ERROR(statement->source_loc())
             << "continue is not within a loop body";
       }
       cast<Continue>(*statement).set_loop(*loop);
       return;
-    case Statement::Kind::If: {
+    case StatementKind::If: {
       auto& if_stmt = cast<If>(*statement);
       ResolveControlFlow(&if_stmt.then_block(), loop, function);
       if (if_stmt.else_block().has_value()) {
@@ -80,39 +80,39 @@ static void ResolveControlFlow(Nonnull<Statement*> statement,
       }
       return;
     }
-    case Statement::Kind::Block: {
+    case StatementKind::Block: {
       auto& block = cast<Block>(*statement);
       for (auto* block_statement : block.statements()) {
         ResolveControlFlow(block_statement, loop, function);
       }
       return;
     }
-    case Statement::Kind::While:
+    case StatementKind::While:
       ResolveControlFlow(&cast<While>(*statement).body(), statement, function);
       return;
-    case Statement::Kind::Match: {
+    case StatementKind::Match: {
       auto& match = cast<Match>(*statement);
       for (Match::Clause& clause : match.clauses()) {
         ResolveControlFlow(&clause.statement(), loop, function);
       }
       return;
     }
-    case Statement::Kind::Continuation:
+    case StatementKind::Continuation:
       ResolveControlFlow(&cast<Continuation>(*statement).body(), std::nullopt,
                          std::nullopt);
       return;
-    case Statement::Kind::ExpressionStatement:
-    case Statement::Kind::Assign:
-    case Statement::Kind::VariableDefinition:
-    case Statement::Kind::Run:
-    case Statement::Kind::Await:
+    case StatementKind::ExpressionStatement:
+    case StatementKind::Assign:
+    case StatementKind::VariableDefinition:
+    case StatementKind::Run:
+    case StatementKind::Await:
       return;
   }
 }
 
 void ResolveControlFlow(AST& ast) {
   for (auto declaration : ast.declarations) {
-    if (declaration->kind() != Declaration::Kind::FunctionDeclaration) {
+    if (declaration->kind() != DeclarationKind::FunctionDeclaration) {
       continue;
     }
     auto& function = cast<FunctionDeclaration>(*declaration);

+ 28 - 29
executable_semantics/interpreter/resolve_names.cpp

@@ -17,27 +17,27 @@ namespace {
 // flow.
 void PopulateNamesInPattern(const Pattern& pattern, StaticScope& static_scope) {
   switch (pattern.kind()) {
-    case Pattern::Kind::AlternativePattern: {
+    case PatternKind::AlternativePattern: {
       const auto& alt = cast<AlternativePattern>(pattern);
       PopulateNamesInPattern(alt.arguments(), static_scope);
       break;
     }
-    case Pattern::Kind::BindingPattern: {
+    case PatternKind::BindingPattern: {
       const auto& binding = cast<BindingPattern>(pattern);
       if (binding.name().has_value()) {
         static_scope.Add(*binding.name(), &binding);
       }
       break;
     }
-    case Pattern::Kind::TuplePattern: {
+    case PatternKind::TuplePattern: {
       const auto& tuple = cast<TuplePattern>(pattern);
       for (auto* field : tuple.fields()) {
         PopulateNamesInPattern(*field, static_scope);
       }
       break;
     }
-    case Pattern::Kind::AutoPattern:
-    case Pattern::Kind::ExpressionPattern:
+    case PatternKind::AutoPattern:
+    case PatternKind::ExpressionPattern:
       // These don't add names.
       break;
   }
@@ -53,7 +53,7 @@ void PopulateNamesInStatement(Arena* arena,
   }
   Statement& statement = **opt_statement;
   switch (statement.kind()) {
-    case Statement::Kind::Block: {
+    case StatementKind::Block: {
       // Defines a new scope for names.
       auto& block = cast<Block>(statement);
       for (const auto& statement : block.statements()) {
@@ -61,33 +61,33 @@ void PopulateNamesInStatement(Arena* arena,
       }
       break;
     }
-    case Statement::Kind::Continuation: {
+    case StatementKind::Continuation: {
       // Defines a new name and contains a block.
       auto& cont = cast<Continuation>(statement);
       static_scope.Add(cont.continuation_variable(), &cont);
       PopulateNamesInStatement(arena, &cont.body(), static_scope);
       break;
     }
-    case Statement::Kind::VariableDefinition: {
+    case StatementKind::VariableDefinition: {
       // Defines a new name.
       const auto& var = cast<VariableDefinition>(statement);
       PopulateNamesInPattern(var.pattern(), static_scope);
       break;
     }
-    case Statement::Kind::If: {
+    case StatementKind::If: {
       // Contains blocks.
       auto& if_stmt = cast<If>(statement);
       PopulateNamesInStatement(arena, &if_stmt.then_block(), static_scope);
       PopulateNamesInStatement(arena, if_stmt.else_block(), static_scope);
       break;
     }
-    case Statement::Kind::While: {
+    case StatementKind::While: {
       // Contains a block.
       auto& while_stmt = cast<While>(statement);
       PopulateNamesInStatement(arena, &while_stmt.body(), static_scope);
       break;
     }
-    case Statement::Kind::Match: {
+    case StatementKind::Match: {
       // Contains blocks.
       auto& match = cast<Match>(statement);
       for (auto& clause : match.clauses()) {
@@ -97,13 +97,13 @@ void PopulateNamesInStatement(Arena* arena,
       }
       break;
     }
-    case Statement::Kind::Assign:
-    case Statement::Kind::Await:
-    case Statement::Kind::Break:
-    case Statement::Kind::Continue:
-    case Statement::Kind::ExpressionStatement:
-    case Statement::Kind::Return:
-    case Statement::Kind::Run:
+    case StatementKind::Assign:
+    case StatementKind::Await:
+    case StatementKind::Break:
+    case StatementKind::Continue:
+    case StatementKind::ExpressionStatement:
+    case StatementKind::Return:
+    case StatementKind::Run:
       // Neither contains names nor a scope.
       break;
   }
@@ -114,7 +114,7 @@ void PopulateNamesInStatement(Arena* arena,
 void PopulateNamesInMember(Arena* arena, const Member& member,
                            StaticScope& static_scope) {
   switch (member.kind()) {
-    case Member::Kind::FieldMember: {
+    case MemberKind::FieldMember: {
       const auto& field = cast<FieldMember>(member);
       if (field.binding().name().has_value()) {
         static_scope.Add(*field.binding().name(), &member);
@@ -130,7 +130,7 @@ void PopulateNamesInMember(Arena* arena, const Member& member,
 void PopulateNamesInDeclaration(Arena* arena, Declaration& declaration,
                                 StaticScope& static_scope) {
   switch (declaration.kind()) {
-    case Declaration::Kind::FunctionDeclaration: {
+    case DeclarationKind::FunctionDeclaration: {
       auto& func = cast<FunctionDeclaration>(declaration);
       static_scope.Add(func.name(), &declaration);
       for (Nonnull<const GenericBinding*> param : func.deduced_parameters()) {
@@ -140,7 +140,7 @@ void PopulateNamesInDeclaration(Arena* arena, Declaration& declaration,
       PopulateNamesInStatement(arena, func.body(), static_scope);
       break;
     }
-    case Declaration::Kind::ClassDeclaration: {
+    case DeclarationKind::ClassDeclaration: {
       auto& class_decl = cast<ClassDeclaration>(declaration);
       static_scope.Add(class_decl.name(), &declaration);
       for (auto* member : class_decl.members()) {
@@ -148,11 +148,10 @@ void PopulateNamesInDeclaration(Arena* arena, Declaration& declaration,
       }
       break;
     }
-    case Declaration::Kind::ChoiceDeclaration: {
+    case DeclarationKind::ChoiceDeclaration: {
       auto& choice = cast<ChoiceDeclaration>(declaration);
       static_scope.Add(choice.name(), &declaration);
-      for (Nonnull<const ChoiceDeclaration::Alternative*> alt :
-           choice.alternatives()) {
+      for (Nonnull<const AlternativeSignature*> alt : choice.alternatives()) {
         choice.static_scope().Add(alt->name(), alt);
       }
       // Populate name into declared_names.
@@ -160,7 +159,7 @@ void PopulateNamesInDeclaration(Arena* arena, Declaration& declaration,
       // alternatives.
       break;
     }
-    case Declaration::Kind::VariableDeclaration:
+    case DeclarationKind::VariableDeclaration:
       auto& var = cast<VariableDeclaration>(declaration);
       if (var.binding().name().has_value()) {
         static_scope.Add(*(var.binding().name()), &var.binding());
@@ -177,10 +176,10 @@ void PopulateNamesInDeclaration(Arena* arena, Declaration& declaration,
 void ResolveNamesInDeclaration(Declaration& declaration,
                                const StaticScope& static_scope) {
   switch (declaration.kind()) {
-    case Declaration::Kind::FunctionDeclaration:
-    case Declaration::Kind::ClassDeclaration:
-    case Declaration::Kind::ChoiceDeclaration:
-    case Declaration::Kind::VariableDeclaration:
+    case DeclarationKind::FunctionDeclaration:
+    case DeclarationKind::ClassDeclaration:
+    case DeclarationKind::ChoiceDeclaration:
+    case DeclarationKind::VariableDeclaration:
       break;
   }
 }

+ 64 - 64
executable_semantics/interpreter/type_checker.cpp

@@ -434,7 +434,7 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e, TypeEnv types,
     llvm::outs() << "\n";
   }
   switch (e->kind()) {
-    case Expression::Kind::IndexExpression: {
+    case ExpressionKind::IndexExpression: {
       auto& index = cast<IndexExpression>(*e);
       auto res = TypeCheckExp(&index.aggregate(), types, values);
       const Value& aggregate_type = index.aggregate().static_type();
@@ -455,7 +455,7 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e, TypeEnv types,
           FATAL_COMPILATION_ERROR(e->source_loc()) << "expected a tuple";
       }
     }
-    case Expression::Kind::TupleLiteral: {
+    case ExpressionKind::TupleLiteral: {
       std::vector<Nonnull<const Value*>> arg_types;
       auto new_types = types;
       for (auto& arg : cast<TupleLiteral>(*e).fields()) {
@@ -466,7 +466,7 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e, TypeEnv types,
       SetStaticType(e, arena_->New<TupleValue>(std::move(arg_types)));
       return TCResult(new_types);
     }
-    case Expression::Kind::StructLiteral: {
+    case ExpressionKind::StructLiteral: {
       std::vector<FieldInitializer> new_args;
       std::vector<NamedValue> arg_types;
       auto new_types = types;
@@ -479,7 +479,7 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e, TypeEnv types,
       SetStaticType(e, arena_->New<StructType>(std::move(arg_types)));
       return TCResult(new_types);
     }
-    case Expression::Kind::StructTypeLiteral: {
+    case ExpressionKind::StructTypeLiteral: {
       auto& struct_type = cast<StructTypeLiteral>(*e);
       std::vector<FieldInitializer> new_args;
       auto new_types = types;
@@ -501,7 +501,7 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e, TypeEnv types,
       }
       return TCResult(new_types);
     }
-    case Expression::Kind::FieldAccessExpression: {
+    case ExpressionKind::FieldAccessExpression: {
       auto& access = cast<FieldAccessExpression>(*e);
       auto res = TypeCheckExp(&access.aggregate(), types, values);
       const Value& aggregate_type = access.aggregate().static_type();
@@ -559,7 +559,7 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e, TypeEnv types,
               << *e;
       }
     }
-    case Expression::Kind::IdentifierExpression: {
+    case ExpressionKind::IdentifierExpression: {
       auto& ident = cast<IdentifierExpression>(*e);
       std::optional<Nonnull<const Value*>> type = types.Get(ident.name());
       if (type) {
@@ -570,13 +570,13 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e, TypeEnv types,
             << "could not find `" << ident.name() << "`";
       }
     }
-    case Expression::Kind::IntLiteral:
+    case ExpressionKind::IntLiteral:
       SetStaticType(e, arena_->New<IntType>());
       return TCResult(types);
-    case Expression::Kind::BoolLiteral:
+    case ExpressionKind::BoolLiteral:
       SetStaticType(e, arena_->New<BoolType>());
       return TCResult(types);
-    case Expression::Kind::PrimitiveOperatorExpression: {
+    case ExpressionKind::PrimitiveOperatorExpression: {
       auto& op = cast<PrimitiveOperatorExpression>(*e);
       std::vector<Nonnull<Expression*>> es;
       std::vector<Nonnull<const Value*>> ts;
@@ -647,7 +647,7 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e, TypeEnv types,
       }
       break;
     }
-    case Expression::Kind::CallExpression: {
+    case ExpressionKind::CallExpression: {
       auto& call = cast<CallExpression>(*e);
       auto fun_res = TypeCheckExp(&call.function(), types, values);
       switch (call.function().static_type().kind()) {
@@ -687,7 +687,7 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e, TypeEnv types,
       }
       break;
     }
-    case Expression::Kind::FunctionTypeLiteral: {
+    case ExpressionKind::FunctionTypeLiteral: {
       auto& fn = cast<FunctionTypeLiteral>(*e);
       ExpectIsConcreteType(fn.parameter().source_loc(),
                            interpreter_.InterpExp(values, &fn.parameter()));
@@ -696,20 +696,20 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e, TypeEnv types,
       SetStaticType(&fn, arena_->New<TypeType>());
       return TCResult(types);
     }
-    case Expression::Kind::StringLiteral:
+    case ExpressionKind::StringLiteral:
       SetStaticType(e, arena_->New<StringType>());
       return TCResult(types);
-    case Expression::Kind::IntrinsicExpression:
+    case ExpressionKind::IntrinsicExpression:
       switch (cast<IntrinsicExpression>(*e).intrinsic()) {
         case IntrinsicExpression::Intrinsic::Print:
           SetStaticType(e, TupleValue::Empty());
           return TCResult(types);
       }
-    case Expression::Kind::IntTypeLiteral:
-    case Expression::Kind::BoolTypeLiteral:
-    case Expression::Kind::StringTypeLiteral:
-    case Expression::Kind::TypeTypeLiteral:
-    case Expression::Kind::ContinuationTypeLiteral:
+    case ExpressionKind::IntTypeLiteral:
+    case ExpressionKind::BoolTypeLiteral:
+    case ExpressionKind::StringTypeLiteral:
+    case ExpressionKind::TypeTypeLiteral:
+    case ExpressionKind::ContinuationTypeLiteral:
       SetStaticType(e, arena_->New<TypeType>());
       return TCResult(types);
   }
@@ -730,11 +730,11 @@ auto TypeChecker::TypeCheckPattern(
     llvm::outs() << "\n";
   }
   switch (p->kind()) {
-    case Pattern::Kind::AutoPattern: {
+    case PatternKind::AutoPattern: {
       SetStaticType(p, arena_->New<TypeType>());
       return TCResult(types);
     }
-    case Pattern::Kind::BindingPattern: {
+    case PatternKind::BindingPattern: {
       auto& binding = cast<BindingPattern>(*p);
       TypeCheckPattern(&binding.type(), types, values, std::nullopt);
       Nonnull<const Value*> type =
@@ -763,7 +763,7 @@ auto TypeChecker::TypeCheckPattern(
       SetValue(&binding, interpreter_.InterpPattern(values, &binding));
       return TCResult(types);
     }
-    case Pattern::Kind::TuplePattern: {
+    case PatternKind::TuplePattern: {
       auto& tuple = cast<TuplePattern>(*p);
       std::vector<Nonnull<const Value*>> field_types;
       auto new_types = types;
@@ -790,7 +790,7 @@ auto TypeChecker::TypeCheckPattern(
       SetValue(&tuple, interpreter_.InterpPattern(values, &tuple));
       return TCResult(new_types);
     }
-    case Pattern::Kind::AlternativePattern: {
+    case PatternKind::AlternativePattern: {
       auto& alternative = cast<AlternativePattern>(*p);
       Nonnull<const Value*> choice_type =
           interpreter_.InterpExp(values, &alternative.choice_type());
@@ -816,7 +816,7 @@ auto TypeChecker::TypeCheckPattern(
       SetValue(&alternative, interpreter_.InterpPattern(values, &alternative));
       return TCResult(arg_results.types);
     }
-    case Pattern::Kind::ExpressionPattern: {
+    case PatternKind::ExpressionPattern: {
       auto& expression = cast<ExpressionPattern>(*p).expression();
       TCResult result = TypeCheckExp(&expression, types, values);
       SetStaticType(p, &expression.static_type());
@@ -837,7 +837,7 @@ auto TypeChecker::TypeCheckCase(Nonnull<const Value*> expected,
 auto TypeChecker::TypeCheckStmt(Nonnull<Statement*> s, TypeEnv types,
                                 Env values) -> TCResult {
   switch (s->kind()) {
-    case Statement::Kind::Match: {
+    case StatementKind::Match: {
       auto& match = cast<Match>(*s);
       TypeCheckExp(&match.expression(), types, values);
       std::vector<Match::Clause> new_clauses;
@@ -848,7 +848,7 @@ auto TypeChecker::TypeCheckStmt(Nonnull<Statement*> s, TypeEnv types,
       }
       return TCResult(types);
     }
-    case Statement::Kind::While: {
+    case StatementKind::While: {
       auto& while_stmt = cast<While>(*s);
       TypeCheckExp(&while_stmt.condition(), types, values);
       ExpectType(s->source_loc(), "condition of `while`",
@@ -857,10 +857,10 @@ auto TypeChecker::TypeCheckStmt(Nonnull<Statement*> s, TypeEnv types,
       TypeCheckStmt(&while_stmt.body(), types, values);
       return TCResult(types);
     }
-    case Statement::Kind::Break:
-    case Statement::Kind::Continue:
+    case StatementKind::Break:
+    case StatementKind::Continue:
       return TCResult(types);
-    case Statement::Kind::Block: {
+    case StatementKind::Block: {
       auto& block = cast<Block>(*s);
       for (auto* block_statement : block.statements()) {
         auto result = TypeCheckStmt(block_statement, types, values);
@@ -868,14 +868,14 @@ auto TypeChecker::TypeCheckStmt(Nonnull<Statement*> s, TypeEnv types,
       }
       return TCResult(types);
     }
-    case Statement::Kind::VariableDefinition: {
+    case StatementKind::VariableDefinition: {
       auto& var = cast<VariableDefinition>(*s);
       TypeCheckExp(&var.init(), types, values);
       const Value& rhs_ty = var.init().static_type();
       auto lhs_res = TypeCheckPattern(&var.pattern(), types, values, &rhs_ty);
       return TCResult(lhs_res.types);
     }
-    case Statement::Kind::Assign: {
+    case StatementKind::Assign: {
       auto& assign = cast<Assign>(*s);
       TypeCheckExp(&assign.rhs(), types, values);
       auto lhs_res = TypeCheckExp(&assign.lhs(), types, values);
@@ -883,11 +883,11 @@ auto TypeChecker::TypeCheckStmt(Nonnull<Statement*> s, TypeEnv types,
                  &assign.rhs().static_type());
       return TCResult(lhs_res.types);
     }
-    case Statement::Kind::ExpressionStatement: {
+    case StatementKind::ExpressionStatement: {
       TypeCheckExp(&cast<ExpressionStatement>(*s).expression(), types, values);
       return TCResult(types);
     }
-    case Statement::Kind::If: {
+    case StatementKind::If: {
       auto& if_stmt = cast<If>(*s);
       TypeCheckExp(&if_stmt.condition(), types, values);
       ExpectType(s->source_loc(), "condition of `if`", arena_->New<BoolType>(),
@@ -898,7 +898,7 @@ auto TypeChecker::TypeCheckStmt(Nonnull<Statement*> s, TypeEnv types,
       }
       return TCResult(types);
     }
-    case Statement::Kind::Return: {
+    case StatementKind::Return: {
       auto& ret = cast<Return>(*s);
       TypeCheckExp(&ret.expression(), types, values);
       ReturnTerm& return_term = ret.function().return_term();
@@ -910,13 +910,13 @@ auto TypeChecker::TypeCheckStmt(Nonnull<Statement*> s, TypeEnv types,
       }
       return TCResult(types);
     }
-    case Statement::Kind::Continuation: {
+    case StatementKind::Continuation: {
       auto& cont = cast<Continuation>(*s);
       TypeCheckStmt(&cont.body(), types, values);
       types.Set(cont.continuation_variable(), arena_->New<ContinuationType>());
       return TCResult(types);
     }
-    case Statement::Kind::Run: {
+    case StatementKind::Run: {
       auto& run = cast<Run>(*s);
       TypeCheckExp(&run.argument(), types, values);
       ExpectType(s->source_loc(), "argument of `run`",
@@ -924,7 +924,7 @@ auto TypeChecker::TypeCheckStmt(Nonnull<Statement*> s, TypeEnv types,
                  &run.argument().static_type());
       return TCResult(types);
     }
-    case Statement::Kind::Await: {
+    case StatementKind::Await: {
       // nothing to do here
       return TCResult(types);
     }
@@ -939,7 +939,7 @@ auto TypeChecker::TypeCheckStmt(Nonnull<Statement*> s, TypeEnv types,
 static auto IsExhaustive(const Match& match) -> bool {
   for (const Match::Clause& clause : match.clauses()) {
     // A pattern consisting of a single variable binding is guaranteed to match.
-    if (clause.pattern().kind() == Pattern::Kind::BindingPattern) {
+    if (clause.pattern().kind() == PatternKind::BindingPattern) {
       return true;
     }
   }
@@ -955,7 +955,7 @@ void TypeChecker::ExpectReturnOnAllPaths(
   }
   Nonnull<Statement*> stmt = *opt_stmt;
   switch (stmt->kind()) {
-    case Statement::Kind::Match: {
+    case StatementKind::Match: {
       auto& match = cast<Match>(*stmt);
       if (!IsExhaustive(match)) {
         FATAL_COMPILATION_ERROR(source_loc)
@@ -968,7 +968,7 @@ void TypeChecker::ExpectReturnOnAllPaths(
       }
       return;
     }
-    case Statement::Kind::Block: {
+    case StatementKind::Block: {
       auto& block = cast<Block>(*stmt);
       if (block.statements().empty()) {
         FATAL_COMPILATION_ERROR(stmt->source_loc())
@@ -979,24 +979,24 @@ void TypeChecker::ExpectReturnOnAllPaths(
                              block.source_loc());
       return;
     }
-    case Statement::Kind::If: {
+    case StatementKind::If: {
       auto& if_stmt = cast<If>(*stmt);
       ExpectReturnOnAllPaths(&if_stmt.then_block(), stmt->source_loc());
       ExpectReturnOnAllPaths(if_stmt.else_block(), stmt->source_loc());
       return;
     }
-    case Statement::Kind::Return:
+    case StatementKind::Return:
       return;
-    case Statement::Kind::Continuation:
-    case Statement::Kind::Run:
-    case Statement::Kind::Await:
+    case StatementKind::Continuation:
+    case StatementKind::Run:
+    case StatementKind::Await:
       return;
-    case Statement::Kind::Assign:
-    case Statement::Kind::ExpressionStatement:
-    case Statement::Kind::While:
-    case Statement::Kind::Break:
-    case Statement::Kind::Continue:
-    case Statement::Kind::VariableDefinition:
+    case StatementKind::Assign:
+    case StatementKind::ExpressionStatement:
+    case StatementKind::While:
+    case StatementKind::Break:
+    case StatementKind::Continue:
+    case StatementKind::VariableDefinition:
       FATAL_COMPILATION_ERROR(stmt->source_loc())
           << "control-flow reaches end of function that provides a `->` "
              "return type without reaching a return statement";
@@ -1069,7 +1069,7 @@ auto TypeChecker::TypeOfClassDecl(const ClassDeclaration& class_decl,
   std::vector<NamedValue> methods;
   for (Nonnull<const Member*> m : class_decl.members()) {
     switch (m->kind()) {
-      case Member::Kind::FieldMember: {
+      case MemberKind::FieldMember: {
         const BindingPattern& binding = cast<FieldMember>(*m).binding();
         if (!binding.name().has_value()) {
           FATAL_COMPILATION_ERROR(binding.source_loc())
@@ -1092,13 +1092,13 @@ auto TypeChecker::TypeOfClassDecl(const ClassDeclaration& class_decl,
 
 static auto GetName(const Declaration& d) -> const std::string& {
   switch (d.kind()) {
-    case Declaration::Kind::FunctionDeclaration:
+    case DeclarationKind::FunctionDeclaration:
       return cast<FunctionDeclaration>(d).name();
-    case Declaration::Kind::ClassDeclaration:
+    case DeclarationKind::ClassDeclaration:
       return cast<ClassDeclaration>(d).name();
-    case Declaration::Kind::ChoiceDeclaration:
+    case DeclarationKind::ChoiceDeclaration:
       return cast<ChoiceDeclaration>(d).name();
-    case Declaration::Kind::VariableDeclaration: {
+    case DeclarationKind::VariableDeclaration: {
       const BindingPattern& binding = cast<VariableDeclaration>(d).binding();
       if (!binding.name().has_value()) {
         FATAL_COMPILATION_ERROR(binding.source_loc())
@@ -1122,19 +1122,19 @@ void TypeChecker::TypeCheckDeclaration(Nonnull<Declaration*> d,
                                        const TypeEnv& types,
                                        const Env& values) {
   switch (d->kind()) {
-    case Declaration::Kind::FunctionDeclaration:
+    case DeclarationKind::FunctionDeclaration:
       TypeCheckFunctionDeclaration(&cast<FunctionDeclaration>(*d), types,
                                    values, /*check_body=*/true);
       return;
-    case Declaration::Kind::ClassDeclaration:
+    case DeclarationKind::ClassDeclaration:
       // TODO
       return;
 
-    case Declaration::Kind::ChoiceDeclaration:
+    case DeclarationKind::ChoiceDeclaration:
       // TODO
       return;
 
-    case Declaration::Kind::VariableDeclaration: {
+    case DeclarationKind::VariableDeclaration: {
       auto& var = cast<VariableDeclaration>(*d);
       // Signals a type error if the initializing expression does not have
       // the declared type of the variable, otherwise returns this
@@ -1159,7 +1159,7 @@ void TypeChecker::TypeCheckDeclaration(Nonnull<Declaration*> d,
 
 void TypeChecker::TopLevel(Nonnull<Declaration*> d, TypeCheckContext* tops) {
   switch (d->kind()) {
-    case Declaration::Kind::FunctionDeclaration: {
+    case DeclarationKind::FunctionDeclaration: {
       auto& func_def = cast<FunctionDeclaration>(*d);
       TypeCheckFunctionDeclaration(&func_def, tops->types, tops->values,
                                    /*check_body=*/false);
@@ -1168,7 +1168,7 @@ void TypeChecker::TopLevel(Nonnull<Declaration*> d, TypeCheckContext* tops) {
       break;
     }
 
-    case Declaration::Kind::ClassDeclaration: {
+    case DeclarationKind::ClassDeclaration: {
       const auto& class_decl = cast<ClassDeclaration>(*d);
       auto st = TypeOfClassDecl(class_decl, tops->types, tops->values);
       AllocationId a = interpreter_.AllocateValue(st);
@@ -1177,10 +1177,10 @@ void TypeChecker::TopLevel(Nonnull<Declaration*> d, TypeCheckContext* tops) {
       break;
     }
 
-    case Declaration::Kind::ChoiceDeclaration: {
+    case DeclarationKind::ChoiceDeclaration: {
       const auto& choice = cast<ChoiceDeclaration>(*d);
       std::vector<NamedValue> alts;
-      for (Nonnull<const ChoiceDeclaration::Alternative*> alternative :
+      for (Nonnull<const AlternativeSignature*> alternative :
            choice.alternatives()) {
         auto t =
             interpreter_.InterpExp(tops->values, &alternative->signature());
@@ -1193,7 +1193,7 @@ void TypeChecker::TopLevel(Nonnull<Declaration*> d, TypeCheckContext* tops) {
       break;
     }
 
-    case Declaration::Kind::VariableDeclaration: {
+    case DeclarationKind::VariableDeclaration: {
       auto& var = cast<VariableDeclaration>(*d);
       // Associate the variable name with it's declared type in the
       // compile-time symbol table.

+ 5 - 8
executable_semantics/syntax/parser.ypp

@@ -131,9 +131,9 @@
 %type <Nonnull<TuplePattern*>> maybe_empty_tuple_pattern
 %type <ParenContents<Pattern>> paren_pattern_base
 %type <ParenContents<Pattern>> paren_pattern_contents
-%type <Nonnull<ChoiceDeclaration::Alternative*>> alternative
-%type <std::vector<Nonnull<ChoiceDeclaration::Alternative*>>> alternative_list
-%type <std::vector<Nonnull<ChoiceDeclaration::Alternative*>>> alternative_list_contents
+%type <Nonnull<AlternativeSignature*>> alternative
+%type <std::vector<Nonnull<AlternativeSignature*>>> alternative_list
+%type <std::vector<Nonnull<AlternativeSignature*>>> alternative_list_contents
 %type <BisonWrap<Match::Clause>> clause
 %type <std::vector<Match::Clause>> clause_list
 
@@ -683,13 +683,10 @@ member_list:
 ;
 alternative:
   identifier tuple
-    {
-      $$ = arena->New<ChoiceDeclaration::Alternative>(context.source_loc(), $1,
-                                                      $2);
-    }
+    { $$ = arena->New<AlternativeSignature>(context.source_loc(), $1, $2); }
 | identifier
     {
-      $$ = arena->New<ChoiceDeclaration::Alternative>(
+      $$ = arena->New<AlternativeSignature>(
           context.source_loc(), $1,
           arena->New<TupleLiteral>(context.source_loc()));
     }