Przeglądaj źródła

Convert Pattern and Expression to Ptr (#787)

Sorry about the big change, this is hard to split. ParenContents is used by both, templated, and expects the same pointer type. While I could duplicate ParenContents with some ExpressionParenContents or PatternParenContents, that seems a little kludgy versus a single large change handling both. The worst of it is that Expression is already pretty sweeping, Pattern is really just incrementally adding.

That said, I believe this includes a couple fixes I found with incorrect use of dyn_cast in typecheck.cpp (checked nullptr at the wrong step in 2 code locations). There's also a missing `*` in member.cpp this caught. I adjust passing of expressions for Return due to nullness (I felt adding another constructor was the best solution).

I add a `.Release()` to BisonWrap due to things like `$3.first` needing some way to work through BIsonWrap. I felt this was better than `operator->`, but feel free to comment if you prefer the other path (`.Release()` conveniently lets me do pair unwrapping, so it felt a better solution).

I do add a TODO to think about better Ptr-to-Ptr cast<> support too, though, as that doesn't work cleanly with LLVM's infra. But so far it seems to only come up in one spot, so I'm not prioritizing it.
Jon Meow 4 lat temu
rodzic
commit
fd89bcb4aa

+ 9 - 9
executable_semantics/ast/declaration.h

@@ -97,7 +97,7 @@ class ChoiceDeclaration : public Declaration {
  public:
   ChoiceDeclaration(
       SourceLocation loc, std::string name,
-      std::list<std::pair<std::string, const Expression*>> alternatives)
+      std::list<std::pair<std::string, Ptr<const Expression>>> alternatives)
       : Declaration(Kind::ChoiceDeclaration, loc),
         name(std::move(name)),
         alternatives(std::move(alternatives)) {}
@@ -108,20 +108,20 @@ class ChoiceDeclaration : public Declaration {
 
   auto Name() const -> const std::string& { return name; }
   auto Alternatives() const
-      -> const std::list<std::pair<std::string, const Expression*>>& {
+      -> const std::list<std::pair<std::string, Ptr<const Expression>>>& {
     return alternatives;
   }
 
  private:
   std::string name;
-  std::list<std::pair<std::string, const Expression*>> alternatives;
+  std::list<std::pair<std::string, Ptr<const Expression>>> alternatives;
 };
 
 // Global variable definition implements the Declaration concept.
 class VariableDeclaration : public Declaration {
  public:
-  VariableDeclaration(SourceLocation loc, const BindingPattern* binding,
-                      const Expression* initializer)
+  VariableDeclaration(SourceLocation loc, Ptr<const BindingPattern> binding,
+                      Ptr<const Expression> initializer)
       : Declaration(Kind::VariableDeclaration, loc),
         binding(binding),
         initializer(initializer) {}
@@ -130,15 +130,15 @@ class VariableDeclaration : public Declaration {
     return decl->Tag() == Kind::VariableDeclaration;
   }
 
-  auto Binding() const -> const BindingPattern* { return binding; }
-  auto Initializer() const -> const Expression* { return initializer; }
+  auto Binding() const -> Ptr<const BindingPattern> { return binding; }
+  auto Initializer() const -> Ptr<const Expression> { return initializer; }
 
  private:
   // TODO: split this into a non-optional name and a type, initialized by
   // a constructor that takes a BindingPattern and handles errors like a
   // missing name.
-  const BindingPattern* binding;
-  const Expression* initializer;
+  Ptr<const BindingPattern> binding;
+  Ptr<const Expression> initializer;
 };
 
 }  // namespace Carbon

+ 5 - 4
executable_semantics/ast/expression.cpp

@@ -18,8 +18,9 @@ using llvm::cast;
 
 auto ExpressionFromParenContents(
     SourceLocation loc, const ParenContents<Expression>& paren_contents)
-    -> const Expression* {
-  std::optional<const Expression*> single_term = paren_contents.SingleTerm();
+    -> Ptr<const Expression> {
+  std::optional<Ptr<const Expression>> single_term =
+      paren_contents.SingleTerm();
   if (single_term.has_value()) {
     return *single_term;
   } else {
@@ -29,8 +30,8 @@ auto ExpressionFromParenContents(
 
 auto TupleExpressionFromParenContents(
     SourceLocation loc, const ParenContents<Expression>& paren_contents)
-    -> const Expression* {
-  return global_arena->RawNew<TupleLiteral>(
+    -> Ptr<const Expression> {
+  return global_arena->New<TupleLiteral>(
       loc, paren_contents.TupleElements<FieldInitializer>(loc));
 }
 

+ 32 - 29
executable_semantics/ast/expression.h

@@ -63,24 +63,24 @@ class Expression {
 // tuple otherwise.
 auto ExpressionFromParenContents(
     SourceLocation loc, const ParenContents<Expression>& paren_contents)
-    -> const Expression*;
+    -> Ptr<const Expression>;
 
 // Converts paren_contents to an Expression, interpreting the parentheses as
 // forming a tuple.
 auto TupleExpressionFromParenContents(
     SourceLocation loc, const ParenContents<Expression>& paren_contents)
-    -> const Expression*;
+    -> Ptr<const Expression>;
 
 // A FieldInitializer represents the initialization of a single tuple field.
 struct FieldInitializer {
-  FieldInitializer(std::string name, const Expression* expression)
+  FieldInitializer(std::string name, Ptr<const Expression> expression)
       : name(std::move(name)), expression(expression) {}
 
   // The field name. Cannot be empty.
   std::string name;
 
   // The expression that initializes the field.
-  const Expression* expression;
+  Ptr<const Expression> expression;
 };
 
 enum class Operator {
@@ -114,7 +114,8 @@ class IdentifierExpression : public Expression {
 class FieldAccessExpression : public Expression {
  public:
   explicit FieldAccessExpression(SourceLocation loc,
-                                 const Expression* aggregate, std::string field)
+                                 Ptr<const Expression> aggregate,
+                                 std::string field)
       : Expression(Kind::FieldAccessExpression, loc),
         aggregate(aggregate),
         field(std::move(field)) {}
@@ -123,18 +124,18 @@ class FieldAccessExpression : public Expression {
     return exp->Tag() == Kind::FieldAccessExpression;
   }
 
-  auto Aggregate() const -> const Expression* { return aggregate; }
+  auto Aggregate() const -> Ptr<const Expression> { return aggregate; }
   auto Field() const -> const std::string& { return field; }
 
  private:
-  const Expression* aggregate;
+  Ptr<const Expression> aggregate;
   std::string field;
 };
 
 class IndexExpression : public Expression {
  public:
-  explicit IndexExpression(SourceLocation loc, const Expression* aggregate,
-                           const Expression* offset)
+  explicit IndexExpression(SourceLocation loc, Ptr<const Expression> aggregate,
+                           Ptr<const Expression> offset)
       : Expression(Kind::IndexExpression, loc),
         aggregate(aggregate),
         offset(offset) {}
@@ -143,12 +144,12 @@ class IndexExpression : public Expression {
     return exp->Tag() == Kind::IndexExpression;
   }
 
-  auto Aggregate() const -> const Expression* { return aggregate; }
-  auto Offset() const -> const Expression* { return offset; }
+  auto Aggregate() const -> Ptr<const Expression> { return aggregate; }
+  auto Offset() const -> Ptr<const Expression> { return offset; }
 
  private:
-  const Expression* aggregate;
-  const Expression* offset;
+  Ptr<const Expression> aggregate;
+  Ptr<const Expression> offset;
 };
 
 class IntLiteral : public Expression {
@@ -226,8 +227,9 @@ class TupleLiteral : public Expression {
 
 class PrimitiveOperatorExpression : public Expression {
  public:
-  explicit PrimitiveOperatorExpression(SourceLocation loc, Operator op,
-                                       std::vector<const Expression*> arguments)
+  explicit PrimitiveOperatorExpression(
+      SourceLocation loc, Operator op,
+      std::vector<Ptr<const Expression>> arguments)
       : Expression(Kind::PrimitiveOperatorExpression, loc),
         op(op),
         arguments(std::move(arguments)) {}
@@ -237,19 +239,19 @@ class PrimitiveOperatorExpression : public Expression {
   }
 
   auto Op() const -> Operator { return op; }
-  auto Arguments() const -> const std::vector<const Expression*>& {
+  auto Arguments() const -> const std::vector<Ptr<const Expression>>& {
     return arguments;
   }
 
  private:
   Operator op;
-  std::vector<const Expression*> arguments;
+  std::vector<Ptr<const Expression>> arguments;
 };
 
 class CallExpression : public Expression {
  public:
-  explicit CallExpression(SourceLocation loc, const Expression* function,
-                          const Expression* argument)
+  explicit CallExpression(SourceLocation loc, Ptr<const Expression> function,
+                          Ptr<const Expression> argument)
       : Expression(Kind::CallExpression, loc),
         function(function),
         argument(argument) {}
@@ -258,18 +260,19 @@ class CallExpression : public Expression {
     return exp->Tag() == Kind::CallExpression;
   }
 
-  auto Function() const -> const Expression* { return function; }
-  auto Argument() const -> const Expression* { return argument; }
+  auto Function() const -> Ptr<const Expression> { return function; }
+  auto Argument() const -> Ptr<const Expression> { return argument; }
 
  private:
-  const Expression* function;
-  const Expression* argument;
+  Ptr<const Expression> function;
+  Ptr<const Expression> argument;
 };
 
 class FunctionTypeLiteral : public Expression {
  public:
-  explicit FunctionTypeLiteral(SourceLocation loc, const Expression* parameter,
-                               const Expression* return_type,
+  explicit FunctionTypeLiteral(SourceLocation loc,
+                               Ptr<const Expression> parameter,
+                               Ptr<const Expression> return_type,
                                bool is_omitted_return_type)
       : Expression(Kind::FunctionTypeLiteral, loc),
         parameter(parameter),
@@ -280,13 +283,13 @@ class FunctionTypeLiteral : public Expression {
     return exp->Tag() == Kind::FunctionTypeLiteral;
   }
 
-  auto Parameter() const -> const Expression* { return parameter; }
-  auto ReturnType() const -> const Expression* { return return_type; }
+  auto Parameter() const -> Ptr<const Expression> { return parameter; }
+  auto ReturnType() const -> Ptr<const Expression> { return return_type; }
   auto IsOmittedReturnType() const -> bool { return is_omitted_return_type; }
 
  private:
-  const Expression* parameter;
-  const Expression* return_type;
+  Ptr<const Expression> parameter;
+  Ptr<const Expression> return_type;
   bool is_omitted_return_type;
 };
 

+ 22 - 22
executable_semantics/ast/expression_test.cpp

@@ -33,7 +33,7 @@ static auto FakeSourceLoc(int line_num) -> SourceLocation {
 TEST(ExpressionTest, EmptyAsExpression) {
   ParenContents<Expression> contents = {.elements = {},
                                         .has_trailing_comma = false};
-  const Expression* expression =
+  Ptr<const Expression> expression =
       ExpressionFromParenContents(FakeSourceLoc(1), contents);
   EXPECT_EQ(expression->SourceLoc(), FakeSourceLoc(1));
   ASSERT_EQ(expression->Tag(), Expression::Kind::TupleLiteral);
@@ -43,7 +43,7 @@ TEST(ExpressionTest, EmptyAsExpression) {
 TEST(ExpressionTest, EmptyAsTuple) {
   ParenContents<Expression> contents = {.elements = {},
                                         .has_trailing_comma = false};
-  const Expression* tuple =
+  Ptr<const Expression> tuple =
       TupleExpressionFromParenContents(FakeSourceLoc(1), contents);
   EXPECT_EQ(tuple->SourceLoc(), FakeSourceLoc(1));
   ASSERT_EQ(tuple->Tag(), Expression::Kind::TupleLiteral);
@@ -59,11 +59,11 @@ TEST(ExpressionTest, UnaryNoCommaAsExpression) {
   // ```
   ParenContents<Expression> contents = {
       .elements = {{.name = std::nullopt,
-                    .term = global_arena->RawNew<IntLiteral>(FakeSourceLoc(2),
-                                                             42)}},
+                    .term =
+                        global_arena->New<IntLiteral>(FakeSourceLoc(2), 42)}},
       .has_trailing_comma = false};
 
-  const Expression* expression =
+  Ptr<const Expression> expression =
       ExpressionFromParenContents(FakeSourceLoc(1), contents);
   EXPECT_EQ(expression->SourceLoc(), FakeSourceLoc(2));
   ASSERT_EQ(expression->Tag(), Expression::Kind::IntLiteral);
@@ -72,11 +72,11 @@ TEST(ExpressionTest, UnaryNoCommaAsExpression) {
 TEST(ExpressionTest, UnaryNoCommaAsTuple) {
   ParenContents<Expression> contents = {
       .elements = {{.name = std::nullopt,
-                    .term = global_arena->RawNew<IntLiteral>(FakeSourceLoc(2),
-                                                             42)}},
+                    .term =
+                        global_arena->New<IntLiteral>(FakeSourceLoc(2), 42)}},
       .has_trailing_comma = false};
 
-  const Expression* tuple =
+  Ptr<const Expression> tuple =
       TupleExpressionFromParenContents(FakeSourceLoc(1), contents);
   EXPECT_EQ(tuple->SourceLoc(), FakeSourceLoc(1));
   ASSERT_EQ(tuple->Tag(), Expression::Kind::TupleLiteral);
@@ -87,11 +87,11 @@ TEST(ExpressionTest, UnaryNoCommaAsTuple) {
 TEST(ExpressionTest, UnaryWithCommaAsExpression) {
   ParenContents<Expression> contents = {
       .elements = {{.name = std::nullopt,
-                    .term = global_arena->RawNew<IntLiteral>(FakeSourceLoc(2),
-                                                             42)}},
+                    .term =
+                        global_arena->New<IntLiteral>(FakeSourceLoc(2), 42)}},
       .has_trailing_comma = true};
 
-  const Expression* expression =
+  Ptr<const Expression> expression =
       ExpressionFromParenContents(FakeSourceLoc(1), contents);
   EXPECT_EQ(expression->SourceLoc(), FakeSourceLoc(1));
   ASSERT_EQ(expression->Tag(), Expression::Kind::TupleLiteral);
@@ -102,11 +102,11 @@ TEST(ExpressionTest, UnaryWithCommaAsExpression) {
 TEST(ExpressionTest, UnaryWithCommaAsTuple) {
   ParenContents<Expression> contents = {
       .elements = {{.name = std::nullopt,
-                    .term = global_arena->RawNew<IntLiteral>(FakeSourceLoc(2),
-                                                             42)}},
+                    .term =
+                        global_arena->New<IntLiteral>(FakeSourceLoc(2), 42)}},
       .has_trailing_comma = true};
 
-  const Expression* tuple =
+  Ptr<const Expression> tuple =
       TupleExpressionFromParenContents(FakeSourceLoc(1), contents);
   EXPECT_EQ(tuple->SourceLoc(), FakeSourceLoc(1));
   ASSERT_EQ(tuple->Tag(), Expression::Kind::TupleLiteral);
@@ -118,13 +118,13 @@ TEST(ExpressionTest, BinaryAsExpression) {
   ParenContents<Expression> contents = {
       .elements = {{.name = std::nullopt,
                     .term =
-                        global_arena->RawNew<IntLiteral>(FakeSourceLoc(2), 42)},
+                        global_arena->New<IntLiteral>(FakeSourceLoc(2), 42)},
                    {.name = std::nullopt,
-                    .term = global_arena->RawNew<IntLiteral>(FakeSourceLoc(3),
-                                                             42)}},
+                    .term =
+                        global_arena->New<IntLiteral>(FakeSourceLoc(3), 42)}},
       .has_trailing_comma = true};
 
-  const Expression* expression =
+  Ptr<const Expression> expression =
       ExpressionFromParenContents(FakeSourceLoc(1), contents);
   EXPECT_EQ(expression->SourceLoc(), FakeSourceLoc(1));
   ASSERT_EQ(expression->Tag(), Expression::Kind::TupleLiteral);
@@ -136,13 +136,13 @@ TEST(ExpressionTest, BinaryAsTuple) {
   ParenContents<Expression> contents = {
       .elements = {{.name = std::nullopt,
                     .term =
-                        global_arena->RawNew<IntLiteral>(FakeSourceLoc(2), 42)},
+                        global_arena->New<IntLiteral>(FakeSourceLoc(2), 42)},
                    {.name = std::nullopt,
-                    .term = global_arena->RawNew<IntLiteral>(FakeSourceLoc(3),
-                                                             42)}},
+                    .term =
+                        global_arena->New<IntLiteral>(FakeSourceLoc(3), 42)}},
       .has_trailing_comma = true};
 
-  const Expression* tuple =
+  Ptr<const Expression> tuple =
       TupleExpressionFromParenContents(FakeSourceLoc(1), contents);
   EXPECT_EQ(tuple->SourceLoc(), FakeSourceLoc(1));
   ASSERT_EQ(tuple->Tag(), Expression::Kind::TupleLiteral);

+ 6 - 6
executable_semantics/ast/function_definition.h

@@ -18,15 +18,15 @@ namespace Carbon {
 //   For now, only generic parameters are supported.
 struct GenericBinding {
   std::string name;
-  const Expression* type;
+  Ptr<const Expression> type;
 };
 
 struct FunctionDefinition {
   FunctionDefinition(SourceLocation source_location, std::string name,
                      std::vector<GenericBinding> deduced_params,
-                     const TuplePattern* param_pattern,
-                     const Pattern* return_type, bool is_omitted_return_type,
-                     const Statement* body)
+                     Ptr<const TuplePattern> param_pattern,
+                     Ptr<const Pattern> return_type,
+                     bool is_omitted_return_type, const Statement* body)
       : source_location(source_location),
         name(std::move(name)),
         deduced_parameters(deduced_params),
@@ -42,8 +42,8 @@ struct FunctionDefinition {
   SourceLocation source_location;
   std::string name;
   std::vector<GenericBinding> deduced_parameters;
-  const TuplePattern* param_pattern;
-  const Pattern* return_type;
+  Ptr<const TuplePattern> param_pattern;
+  Ptr<const Pattern> return_type;
   bool is_omitted_return_type;
   const Statement* body;
 };

+ 1 - 1
executable_semantics/ast/member.cpp

@@ -15,7 +15,7 @@ void Member::Print(llvm::raw_ostream& out) const {
   switch (Tag()) {
     case Kind::FieldMember:
       const auto& field = cast<FieldMember>(*this);
-      out << "var " << field.Binding() << ";\n";
+      out << "var " << *field.Binding() << ";\n";
       break;
   }
 }

+ 3 - 3
executable_semantics/ast/member.h

@@ -51,20 +51,20 @@ class Member {
 
 class FieldMember : public Member {
  public:
-  FieldMember(SourceLocation loc, const BindingPattern* binding)
+  FieldMember(SourceLocation loc, Ptr<const BindingPattern> binding)
       : Member(Kind::FieldMember, loc), binding(binding) {}
 
   static auto classof(const Member* member) -> bool {
     return member->Tag() == Kind::FieldMember;
   }
 
-  auto Binding() const -> const BindingPattern* { return binding; }
+  auto Binding() const -> Ptr<const BindingPattern> { return binding; }
 
  private:
   // TODO: split this into a non-optional name and a type, initialized by
   // a constructor that takes a BindingPattern and handles errors like a
   // missing name.
-  const BindingPattern* binding;
+  Ptr<const BindingPattern> binding;
 };
 
 }  // namespace Carbon

+ 4 - 4
executable_semantics/ast/paren_contents.h

@@ -28,16 +28,16 @@ template <typename Term>
 struct ParenContents {
   struct Element {
     std::optional<std::string> name;
-    const Term* term;
+    Ptr<const Term> term;
   };
 
   // If this object represents a single term, with no name and no trailing
   // comma, this method returns that term. This typically means the parentheses
   // can be interpreted as grouping.
-  auto SingleTerm() const -> std::optional<const Term*>;
+  auto SingleTerm() const -> std::optional<Ptr<const Term>>;
 
   // Converts `elements` to std::vector<TupleElement>. TupleElement must
-  // have a constructor that takes a std::string and a const Term*.
+  // have a constructor that takes a std::string and a Ptr<const Term>.
   //
   // TODO: Find a way to deduce TupleElement from Term.
   template <typename TupleElement>
@@ -50,7 +50,7 @@ struct ParenContents {
 // Implementation details only below here.
 
 template <typename Term>
-auto ParenContents<Term>::SingleTerm() const -> std::optional<const Term*> {
+auto ParenContents<Term>::SingleTerm() const -> std::optional<Ptr<const Term>> {
   if (elements.size() == 1 && !elements.front().name.has_value() &&
       !has_trailing_comma) {
     return elements.front().term;

+ 21 - 14
executable_semantics/ast/pattern.cpp

@@ -54,19 +54,19 @@ void Pattern::Print(llvm::raw_ostream& out) const {
   }
 }
 
-TuplePattern::TuplePattern(const Expression* tuple_literal)
+TuplePattern::TuplePattern(Ptr<const Expression> tuple_literal)
     : Pattern(Kind::TuplePattern, tuple_literal->SourceLoc()) {
   const auto& tuple = cast<TupleLiteral>(*tuple_literal);
   for (const FieldInitializer& init : tuple.Fields()) {
     fields.push_back(Field(
-        init.name, global_arena->RawNew<ExpressionPattern>(init.expression)));
+        init.name, global_arena->New<ExpressionPattern>(init.expression)));
   }
 }
 
 auto PatternFromParenContents(SourceLocation loc,
                               const ParenContents<Pattern>& paren_contents)
-    -> const Pattern* {
-  std::optional<const Pattern*> single_term = paren_contents.SingleTerm();
+    -> Ptr<const Pattern> {
+  std::optional<Ptr<const Pattern>> single_term = paren_contents.SingleTerm();
   if (single_term.has_value()) {
     return *single_term;
   } else {
@@ -76,24 +76,31 @@ auto PatternFromParenContents(SourceLocation loc,
 
 auto TuplePatternFromParenContents(SourceLocation loc,
                                    const ParenContents<Pattern>& paren_contents)
-    -> const TuplePattern* {
-  return global_arena->RawNew<TuplePattern>(
+    -> Ptr<const TuplePattern> {
+  return global_arena->New<TuplePattern>(
       loc, paren_contents.TupleElements<TuplePattern::Field>(loc));
 }
 
-AlternativePattern::AlternativePattern(SourceLocation loc,
-                                       const Expression* alternative,
-                                       const TuplePattern* arguments)
-    : Pattern(Kind::AlternativePattern, loc), arguments(arguments) {
+// Used by AlternativePattern for constructor initialization. Produces a helpful
+// error for incorrect expressions, rather than letting a default cast error
+// apply.
+static const FieldAccessExpression& RequireFieldAccess(
+    Ptr<const Expression> alternative) {
   if (alternative->Tag() != Expression::Kind::FieldAccessExpression) {
     FATAL_PROGRAM_ERROR(alternative->SourceLoc())
         << "Alternative pattern must have the form of a field access.";
   }
-  const auto& field_access = cast<FieldAccessExpression>(*alternative);
-  choice_type = field_access.Aggregate();
-  alternative_name = field_access.Field();
+  return cast<FieldAccessExpression>(*alternative);
 }
 
+AlternativePattern::AlternativePattern(SourceLocation loc,
+                                       Ptr<const Expression> alternative,
+                                       Ptr<const TuplePattern> arguments)
+    : Pattern(Kind::AlternativePattern, loc),
+      choice_type(RequireFieldAccess(alternative).Aggregate()),
+      alternative_name(RequireFieldAccess(alternative).Field()),
+      arguments(arguments) {}
+
 auto ParenExpressionToParenPattern(const ParenContents<Expression>& contents)
     -> ParenContents<Pattern> {
   ParenContents<Pattern> result = {
@@ -101,7 +108,7 @@ auto ParenExpressionToParenPattern(const ParenContents<Expression>& contents)
   for (const auto& element : contents.elements) {
     result.elements.push_back(
         {.name = element.name,
-         .term = global_arena->RawNew<ExpressionPattern>(element.term)});
+         .term = global_arena->New<ExpressionPattern>(element.term)});
   }
   return result;
 }

+ 19 - 19
executable_semantics/ast/pattern.h

@@ -71,7 +71,7 @@ class AutoPattern : public Pattern {
 class BindingPattern : public Pattern {
  public:
   BindingPattern(SourceLocation loc, std::optional<std::string> name,
-                 const Pattern* type)
+                 Ptr<const Pattern> type)
       : Pattern(Kind::BindingPattern, loc), name(std::move(name)), type(type) {}
 
   static auto classof(const Pattern* pattern) -> bool {
@@ -82,11 +82,11 @@ class BindingPattern : public Pattern {
   auto Name() const -> const std::optional<std::string>& { return name; }
 
   // The pattern specifying the type of values that this pattern matches.
-  auto Type() const -> const Pattern* { return type; }
+  auto Type() const -> Ptr<const Pattern> { return type; }
 
  private:
   std::optional<std::string> name;
-  const Pattern* type;
+  Ptr<const Pattern> type;
 };
 
 // A pattern that matches a tuple value field-wise.
@@ -94,14 +94,14 @@ class TuplePattern : public Pattern {
  public:
   // Represents a portion of a tuple pattern corresponding to a single field.
   struct Field {
-    Field(std::string name, const Pattern* pattern)
+    Field(std::string name, Ptr<const Pattern> pattern)
         : name(std::move(name)), pattern(pattern) {}
 
     // The field name. Cannot be empty
     std::string name;
 
     // The pattern the field must match.
-    const Pattern* pattern;
+    Ptr<const Pattern> pattern;
   };
 
   TuplePattern(SourceLocation loc, std::vector<Field> fields)
@@ -111,7 +111,7 @@ class TuplePattern : public Pattern {
   // ExpressionPattern.
   //
   // REQUIRES: tuple_literal->Tag() == Expression::Kind::TupleLiteral
-  explicit TuplePattern(const Expression* tuple_literal);
+  explicit TuplePattern(Ptr<const Expression> tuple_literal);
 
   static auto classof(const Pattern* pattern) -> bool {
     return pattern->Tag() == Kind::TuplePattern;
@@ -128,13 +128,13 @@ class TuplePattern : public Pattern {
 // tuple otherwise.
 auto PatternFromParenContents(SourceLocation loc,
                               const ParenContents<Pattern>& paren_contents)
-    -> const Pattern*;
+    -> Ptr<const Pattern>;
 
 // Converts paren_contents to a TuplePattern, interpreting the parentheses as
 // forming a tuple.
 auto TuplePatternFromParenContents(SourceLocation loc,
                                    const ParenContents<Pattern>& paren_contents)
-    -> const TuplePattern*;
+    -> Ptr<const TuplePattern>;
 
 // Converts `contents` to ParenContents<Pattern> by replacing each Expression
 // with an ExpressionPattern.
@@ -147,9 +147,9 @@ class AlternativePattern : public Pattern {
   // Constructs an AlternativePattern that matches a value of the type
   // specified by choice_type if it represents an alternative named
   // alternative_name, and its arguments match `arguments`.
-  AlternativePattern(SourceLocation loc, const Expression* choice_type,
+  AlternativePattern(SourceLocation loc, Ptr<const Expression> choice_type,
                      std::string alternative_name,
-                     const TuplePattern* arguments)
+                     Ptr<const TuplePattern> arguments)
       : Pattern(Kind::AlternativePattern, loc),
         choice_type(choice_type),
         alternative_name(std::move(alternative_name)),
@@ -157,30 +157,30 @@ class AlternativePattern : public Pattern {
 
   // Constructs an AlternativePattern that matches the alternative specified
   // by `alternative`, if its arguments match `arguments`.
-  AlternativePattern(SourceLocation loc, const Expression* alternative,
-                     const TuplePattern* arguments);
+  AlternativePattern(SourceLocation loc, Ptr<const Expression> alternative,
+                     Ptr<const TuplePattern> arguments);
 
   static auto classof(const Pattern* pattern) -> bool {
     return pattern->Tag() == Kind::AlternativePattern;
   }
 
-  auto ChoiceType() const -> const Expression* { return choice_type; }
+  auto ChoiceType() const -> Ptr<const Expression> { return choice_type; }
   auto AlternativeName() const -> const std::string& {
     return alternative_name;
   }
-  auto Arguments() const -> const TuplePattern* { return arguments; }
+  auto Arguments() const -> Ptr<const TuplePattern> { return arguments; }
 
  private:
-  const Expression* choice_type;
+  Ptr<const Expression> choice_type;
   std::string alternative_name;
-  const TuplePattern* arguments;
+  Ptr<const TuplePattern> arguments;
 };
 
 // A pattern that matches a value if it is equal to the value of a given
 // expression.
 class ExpressionPattern : public Pattern {
  public:
-  ExpressionPattern(const Expression* expression)
+  ExpressionPattern(Ptr<const Expression> expression)
       : Pattern(Kind::ExpressionPattern, expression->SourceLoc()),
         expression(expression) {}
 
@@ -188,10 +188,10 @@ class ExpressionPattern : public Pattern {
     return pattern->Tag() == Kind::ExpressionPattern;
   }
 
-  auto Expression() const -> const Expression* { return expression; }
+  auto Expression() const -> Ptr<const Expression> { return expression; }
 
  private:
-  const Carbon::Expression* expression;
+  Ptr<const Carbon::Expression> expression;
 };
 
 }  // namespace Carbon

+ 28 - 32
executable_semantics/ast/pattern_test.cpp

@@ -22,7 +22,7 @@ using testing::IsEmpty;
 // Matches a TuplePattern::Field named `name` whose `pattern` is an
 // `AutoPattern`.
 MATCHER_P(AutoFieldNamed, name, "") {
-  return arg.name == std::string(name) && isa<AutoPattern>(arg.pattern);
+  return arg.name == std::string(name) && isa<AutoPattern>(*arg.pattern);
 }
 
 static auto FakeSourceLoc(int line_num) -> SourceLocation {
@@ -32,16 +32,17 @@ static auto FakeSourceLoc(int line_num) -> SourceLocation {
 TEST(PatternTest, EmptyAsPattern) {
   ParenContents<Pattern> contents = {.elements = {},
                                      .has_trailing_comma = false};
-  const Pattern* pattern = PatternFromParenContents(FakeSourceLoc(1), contents);
+  Ptr<const Pattern> pattern =
+      PatternFromParenContents(FakeSourceLoc(1), contents);
   EXPECT_EQ(pattern->SourceLoc(), FakeSourceLoc(1));
-  ASSERT_TRUE(isa<TuplePattern>(pattern));
-  EXPECT_THAT(cast<TuplePattern>(pattern)->Fields(), IsEmpty());
+  ASSERT_TRUE(isa<TuplePattern>(*pattern));
+  EXPECT_THAT(cast<TuplePattern>(*pattern).Fields(), IsEmpty());
 }
 
 TEST(PatternTest, EmptyAsTuplePattern) {
   ParenContents<Pattern> contents = {.elements = {},
                                      .has_trailing_comma = false};
-  const TuplePattern* tuple =
+  Ptr<const TuplePattern> tuple =
       TuplePatternFromParenContents(FakeSourceLoc(1), contents);
   EXPECT_EQ(tuple->SourceLoc(), FakeSourceLoc(1));
   EXPECT_THAT(tuple->Fields(), IsEmpty());
@@ -56,23 +57,22 @@ TEST(PatternTest, UnaryNoCommaAsPattern) {
   // ```
   ParenContents<Pattern> contents = {
       .elements = {{.name = std::nullopt,
-                    .term =
-                        global_arena->RawNew<AutoPattern>(FakeSourceLoc(2))}},
+                    .term = global_arena->New<AutoPattern>(FakeSourceLoc(2))}},
       .has_trailing_comma = false};
 
-  const Pattern* pattern = PatternFromParenContents(FakeSourceLoc(1), contents);
+  Ptr<const Pattern> pattern =
+      PatternFromParenContents(FakeSourceLoc(1), contents);
   EXPECT_EQ(pattern->SourceLoc(), FakeSourceLoc(2));
-  ASSERT_TRUE(isa<AutoPattern>(pattern));
+  ASSERT_TRUE(isa<AutoPattern>(*pattern));
 }
 
 TEST(PatternTest, UnaryNoCommaAsTuplePattern) {
   ParenContents<Pattern> contents = {
       .elements = {{.name = std::nullopt,
-                    .term =
-                        global_arena->RawNew<AutoPattern>(FakeSourceLoc(2))}},
+                    .term = global_arena->New<AutoPattern>(FakeSourceLoc(2))}},
       .has_trailing_comma = false};
 
-  const TuplePattern* tuple =
+  Ptr<const TuplePattern> tuple =
       TuplePatternFromParenContents(FakeSourceLoc(1), contents);
   EXPECT_EQ(tuple->SourceLoc(), FakeSourceLoc(1));
   EXPECT_THAT(tuple->Fields(), ElementsAre(AutoFieldNamed("0")));
@@ -81,25 +81,24 @@ TEST(PatternTest, UnaryNoCommaAsTuplePattern) {
 TEST(PatternTest, UnaryWithCommaAsPattern) {
   ParenContents<Pattern> contents = {
       .elements = {{.name = std::nullopt,
-                    .term =
-                        global_arena->RawNew<AutoPattern>(FakeSourceLoc(2))}},
+                    .term = global_arena->New<AutoPattern>(FakeSourceLoc(2))}},
       .has_trailing_comma = true};
 
-  const Pattern* pattern = PatternFromParenContents(FakeSourceLoc(1), contents);
+  Ptr<const Pattern> pattern =
+      PatternFromParenContents(FakeSourceLoc(1), contents);
   EXPECT_EQ(pattern->SourceLoc(), FakeSourceLoc(1));
-  ASSERT_TRUE(isa<TuplePattern>(pattern));
-  EXPECT_THAT(cast<TuplePattern>(pattern)->Fields(),
+  ASSERT_TRUE(isa<TuplePattern>(*pattern));
+  EXPECT_THAT(cast<TuplePattern>(*pattern).Fields(),
               ElementsAre(AutoFieldNamed("0")));
 }
 
 TEST(PatternTest, UnaryWithCommaAsTuplePattern) {
   ParenContents<Pattern> contents = {
       .elements = {{.name = std::nullopt,
-                    .term =
-                        global_arena->RawNew<AutoPattern>(FakeSourceLoc(2))}},
+                    .term = global_arena->New<AutoPattern>(FakeSourceLoc(2))}},
       .has_trailing_comma = true};
 
-  const TuplePattern* tuple =
+  Ptr<const TuplePattern> tuple =
       TuplePatternFromParenContents(FakeSourceLoc(1), contents);
   EXPECT_EQ(tuple->SourceLoc(), FakeSourceLoc(1));
   EXPECT_THAT(tuple->Fields(), ElementsAre(AutoFieldNamed("0")));
@@ -108,31 +107,28 @@ TEST(PatternTest, UnaryWithCommaAsTuplePattern) {
 TEST(PatternTest, BinaryAsPattern) {
   ParenContents<Pattern> contents = {
       .elements = {{.name = std::nullopt,
-                    .term =
-                        global_arena->RawNew<AutoPattern>(FakeSourceLoc(2))},
+                    .term = global_arena->New<AutoPattern>(FakeSourceLoc(2))},
                    {.name = std::nullopt,
-                    .term =
-                        global_arena->RawNew<AutoPattern>(FakeSourceLoc(2))}},
+                    .term = global_arena->New<AutoPattern>(FakeSourceLoc(2))}},
       .has_trailing_comma = true};
 
-  const Pattern* pattern = PatternFromParenContents(FakeSourceLoc(1), contents);
+  Ptr<const Pattern> pattern =
+      PatternFromParenContents(FakeSourceLoc(1), contents);
   EXPECT_EQ(pattern->SourceLoc(), FakeSourceLoc(1));
-  ASSERT_TRUE(isa<TuplePattern>(pattern));
-  EXPECT_THAT(cast<TuplePattern>(pattern)->Fields(),
+  ASSERT_TRUE(isa<TuplePattern>(*pattern));
+  EXPECT_THAT(cast<TuplePattern>(*pattern).Fields(),
               ElementsAre(AutoFieldNamed("0"), AutoFieldNamed("1")));
 }
 
 TEST(PatternTest, BinaryAsTuplePattern) {
   ParenContents<Pattern> contents = {
       .elements = {{.name = std::nullopt,
-                    .term =
-                        global_arena->RawNew<AutoPattern>(FakeSourceLoc(2))},
+                    .term = global_arena->New<AutoPattern>(FakeSourceLoc(2))},
                    {.name = std::nullopt,
-                    .term =
-                        global_arena->RawNew<AutoPattern>(FakeSourceLoc(2))}},
+                    .term = global_arena->New<AutoPattern>(FakeSourceLoc(2))}},
       .has_trailing_comma = true};
 
-  const TuplePattern* tuple =
+  Ptr<const TuplePattern> tuple =
       TuplePatternFromParenContents(FakeSourceLoc(1), contents);
   EXPECT_EQ(tuple->SourceLoc(), FakeSourceLoc(1));
   EXPECT_THAT(tuple->Fields(),

+ 0 - 7
executable_semantics/ast/statement.cpp

@@ -130,11 +130,4 @@ void Statement::PrintDepth(int depth, llvm::raw_ostream& out) const {
   }
 }
 
-Return::Return(SourceLocation loc, const Expression* exp, bool is_omitted_exp)
-    : Statement(Kind::Return, loc),
-      exp(exp != nullptr ? exp : global_arena->RawNew<TupleLiteral>(loc)),
-      is_omitted_exp(is_omitted_exp) {
-  CHECK(exp != nullptr || is_omitted_exp);
-}
-
 }  // namespace Carbon

+ 39 - 32
executable_semantics/ast/statement.h

@@ -11,6 +11,7 @@
 #include "executable_semantics/ast/expression.h"
 #include "executable_semantics/ast/pattern.h"
 #include "executable_semantics/ast/source_location.h"
+#include "executable_semantics/common/arena.h"
 #include "llvm/Support/Compiler.h"
 
 namespace Carbon {
@@ -57,57 +58,58 @@ class Statement {
 
 class ExpressionStatement : public Statement {
  public:
-  ExpressionStatement(SourceLocation loc, const Expression* exp)
+  ExpressionStatement(SourceLocation loc, Ptr<const Expression> exp)
       : Statement(Kind::ExpressionStatement, loc), exp(exp) {}
 
   static auto classof(const Statement* stmt) -> bool {
     return stmt->Tag() == Kind::ExpressionStatement;
   }
 
-  auto Exp() const -> const Expression* { return exp; }
+  auto Exp() const -> Ptr<const Expression> { return exp; }
 
  private:
-  const Expression* exp;
+  Ptr<const Expression> exp;
 };
 
 class Assign : public Statement {
  public:
-  Assign(SourceLocation loc, const Expression* lhs, const Expression* rhs)
+  Assign(SourceLocation loc, Ptr<const Expression> lhs,
+         Ptr<const Expression> rhs)
       : Statement(Kind::Assign, loc), lhs(lhs), rhs(rhs) {}
 
   static auto classof(const Statement* stmt) -> bool {
     return stmt->Tag() == Kind::Assign;
   }
 
-  auto Lhs() const -> const Expression* { return lhs; }
-  auto Rhs() const -> const Expression* { return rhs; }
+  auto Lhs() const -> Ptr<const Expression> { return lhs; }
+  auto Rhs() const -> Ptr<const Expression> { return rhs; }
 
  private:
-  const Expression* lhs;
-  const Expression* rhs;
+  Ptr<const Expression> lhs;
+  Ptr<const Expression> rhs;
 };
 
 class VariableDefinition : public Statement {
  public:
-  VariableDefinition(SourceLocation loc, const Pattern* pat,
-                     const Expression* init)
+  VariableDefinition(SourceLocation loc, Ptr<const Pattern> pat,
+                     Ptr<const Expression> init)
       : Statement(Kind::VariableDefinition, loc), pat(pat), init(init) {}
 
   static auto classof(const Statement* stmt) -> bool {
     return stmt->Tag() == Kind::VariableDefinition;
   }
 
-  auto Pat() const -> const Pattern* { return pat; }
-  auto Init() const -> const Expression* { return init; }
+  auto Pat() const -> Ptr<const Pattern> { return pat; }
+  auto Init() const -> Ptr<const Expression> { return init; }
 
  private:
-  const Pattern* pat;
-  const Expression* init;
+  Ptr<const Pattern> pat;
+  Ptr<const Expression> init;
 };
 
 class If : public Statement {
  public:
-  If(SourceLocation loc, const Expression* cond, const Statement* then_stmt,
+  If(SourceLocation loc, Ptr<const Expression> cond, const Statement* then_stmt,
      const Statement* else_stmt)
       : Statement(Kind::If, loc),
         cond(cond),
@@ -118,29 +120,34 @@ class If : public Statement {
     return stmt->Tag() == Kind::If;
   }
 
-  auto Cond() const -> const Expression* { return cond; }
+  auto Cond() const -> Ptr<const Expression> { return cond; }
   auto ThenStmt() const -> const Statement* { return then_stmt; }
   auto ElseStmt() const -> const Statement* { return else_stmt; }
 
  private:
-  const Expression* cond;
+  Ptr<const Expression> cond;
   const Statement* then_stmt;
   const Statement* else_stmt;
 };
 
 class Return : public Statement {
  public:
-  Return(SourceLocation loc, const Expression* exp, bool is_omitted_exp);
+  explicit Return(SourceLocation loc)
+      : Return(loc, global_arena->New<TupleLiteral>(loc), true) {}
+  Return(SourceLocation loc, Ptr<const Expression> exp, bool is_omitted_exp)
+      : Statement(Kind::Return, loc),
+        exp(exp),
+        is_omitted_exp(is_omitted_exp) {}
 
   static auto classof(const Statement* stmt) -> bool {
     return stmt->Tag() == Kind::Return;
   }
 
-  auto Exp() const -> const Expression* { return exp; }
+  auto Exp() const -> Ptr<const Expression> { return exp; }
   auto IsOmittedExp() const -> bool { return is_omitted_exp; }
 
  private:
-  const Expression* exp;
+  Ptr<const Expression> exp;
   bool is_omitted_exp;
 };
 
@@ -178,18 +185,18 @@ class Block : public Statement {
 
 class While : public Statement {
  public:
-  While(SourceLocation loc, const Expression* cond, const Statement* body)
+  While(SourceLocation loc, Ptr<const Expression> cond, const Statement* body)
       : Statement(Kind::While, loc), cond(cond), body(body) {}
 
   static auto classof(const Statement* stmt) -> bool {
     return stmt->Tag() == Kind::While;
   }
 
-  auto Cond() const -> const Expression* { return cond; }
+  auto Cond() const -> Ptr<const Expression> { return cond; }
   auto Body() const -> const Statement* { return body; }
 
  private:
-  const Expression* cond;
+  Ptr<const Expression> cond;
   const Statement* body;
 };
 
@@ -213,23 +220,23 @@ class Continue : public Statement {
 
 class Match : public Statement {
  public:
-  Match(SourceLocation loc, const Expression* exp,
-        std::list<std::pair<const Pattern*, const Statement*>>* clauses)
+  Match(SourceLocation loc, Ptr<const Expression> exp,
+        std::list<std::pair<Ptr<const Pattern>, const Statement*>>* clauses)
       : Statement(Kind::Match, loc), exp(exp), clauses(clauses) {}
 
   static auto classof(const Statement* stmt) -> bool {
     return stmt->Tag() == Kind::Match;
   }
 
-  auto Exp() const -> const Expression* { return exp; }
+  auto Exp() const -> Ptr<const Expression> { return exp; }
   auto Clauses() const
-      -> const std::list<std::pair<const Pattern*, const Statement*>>* {
+      -> const std::list<std::pair<Ptr<const Pattern>, const Statement*>>* {
     return clauses;
   }
 
  private:
-  const Expression* exp;
-  std::list<std::pair<const Pattern*, const Statement*>>* clauses;
+  Ptr<const Expression> exp;
+  std::list<std::pair<Ptr<const Pattern>, const Statement*>>* clauses;
 };
 
 // A continuation statement.
@@ -264,17 +271,17 @@ class Continuation : public Statement {
 //     __run <argument>;
 class Run : public Statement {
  public:
-  Run(SourceLocation loc, const Expression* argument)
+  Run(SourceLocation loc, Ptr<const Expression> argument)
       : Statement(Kind::Run, loc), argument(argument) {}
 
   static auto classof(const Statement* stmt) -> bool {
     return stmt->Tag() == Kind::Run;
   }
 
-  auto Argument() const -> const Expression* { return argument; }
+  auto Argument() const -> Ptr<const Expression> { return argument; }
 
  private:
-  const Expression* argument;
+  Ptr<const Expression> argument;
 };
 
 // An await statement.

+ 9 - 9
executable_semantics/interpreter/action.h

@@ -72,47 +72,47 @@ class Action {
 
 class LValAction : public Action {
  public:
-  explicit LValAction(const Expression* exp)
+  explicit LValAction(Ptr<const Expression> exp)
       : Action(Kind::LValAction), exp(exp) {}
 
   static auto classof(const Action* action) -> bool {
     return action->Tag() == Kind::LValAction;
   }
 
-  auto Exp() const -> const Expression* { return exp; }
+  auto Exp() const -> Ptr<const Expression> { return exp; }
 
  private:
-  const Expression* exp;
+  Ptr<const Expression> exp;
 };
 
 class ExpressionAction : public Action {
  public:
-  explicit ExpressionAction(const Expression* exp)
+  explicit ExpressionAction(Ptr<const Expression> exp)
       : Action(Kind::ExpressionAction), exp(exp) {}
 
   static auto classof(const Action* action) -> bool {
     return action->Tag() == Kind::ExpressionAction;
   }
 
-  auto Exp() const -> const Expression* { return exp; }
+  auto Exp() const -> Ptr<const Expression> { return exp; }
 
  private:
-  const Expression* exp;
+  Ptr<const Expression> exp;
 };
 
 class PatternAction : public Action {
  public:
-  explicit PatternAction(const Pattern* pat)
+  explicit PatternAction(Ptr<const Pattern> pat)
       : Action(Kind::PatternAction), pat(pat) {}
 
   static auto classof(const Action* action) -> bool {
     return action->Tag() == Kind::PatternAction;
   }
 
-  auto Pat() const -> const Pattern* { return pat; }
+  auto Pat() const -> Ptr<const Pattern> { return pat; }
 
  private:
-  const Pattern* pat;
+  Ptr<const Pattern> pat;
 };
 
 class StatementAction : public Action {

+ 23 - 23
executable_semantics/interpreter/interpreter.cpp

@@ -142,9 +142,9 @@ void InitEnv(const Declaration& d, Env* env) {
       for (Ptr<const Member> m : class_def.members) {
         switch (m->Tag()) {
           case Member::Kind::FieldMember: {
-            const BindingPattern* binding = cast<FieldMember>(*m).Binding();
-            const Expression* type_expression =
-                cast<ExpressionPattern>(binding->Type())->Expression();
+            Ptr<const BindingPattern> binding = cast<FieldMember>(*m).Binding();
+            Ptr<const Expression> type_expression =
+                cast<ExpressionPattern>(*binding->Type()).Expression();
             auto type = InterpExp(Env(), type_expression);
             fields.push_back(make_pair(*binding->Name(), type));
             break;
@@ -205,7 +205,7 @@ void DeallocateLocals(Ptr<Frame> frame) {
   }
 }
 
-const Value* CreateTuple(Ptr<Action> act, const Expression* exp) {
+const Value* CreateTuple(Ptr<Action> act, Ptr<const Expression> exp) {
   //    { { (v1,...,vn) :: C, E, F} :: S, H}
   // -> { { `(v1,...,vn) :: C, E, F} :: S, H}
   const auto& tup_lit = cast<TupleLiteral>(*exp);
@@ -431,7 +431,7 @@ using Transition =
 // State transitions for lvalues.
 Transition StepLvalue() {
   Ptr<Action> act = state->stack.Top()->todo.Top();
-  const Expression* exp = cast<LValAction>(*act).Exp();
+  Ptr<const Expression> exp = cast<LValAction>(*act).Exp();
   if (tracing_output) {
     llvm::outs() << "--- step lvalue " << *exp << " --->\n";
   }
@@ -483,7 +483,8 @@ Transition StepLvalue() {
       if (act->Pos() == 0) {
         //    { {(f1=e1,...) :: C, E, F} :: S, H}
         // -> { {e1 :: (f1=[],...) :: C, E, F} :: S, H}
-        const Expression* e1 = cast<TupleLiteral>(*exp).Fields()[0].expression;
+        Ptr<const Expression> e1 =
+            cast<TupleLiteral>(*exp).Fields()[0].expression;
         return Spawn{global_arena->New<LValAction>(e1)};
       } else if (act->Pos() !=
                  static_cast<int>(cast<TupleLiteral>(*exp).Fields().size())) {
@@ -491,7 +492,7 @@ Transition StepLvalue() {
         //    H}
         // -> { { ek+1 :: (f1=v1,..., fk=vk, fk+1=[],...) :: C, E, F} :: S,
         // H}
-        const Expression* elt =
+        Ptr<const Expression> elt =
             cast<TupleLiteral>(*exp).Fields()[act->Pos()].expression;
         return Spawn{global_arena->New<LValAction>(elt)};
       } else {
@@ -518,7 +519,7 @@ Transition StepLvalue() {
 // State transitions for expressions.
 Transition StepExp() {
   Ptr<Action> act = state->stack.Top()->todo.Top();
-  const Expression* exp = cast<ExpressionAction>(*act).Exp();
+  Ptr<const Expression> exp = cast<ExpressionAction>(*act).Exp();
   if (tracing_output) {
     llvm::outs() << "--- step exp " << *exp << " --->\n";
   }
@@ -555,7 +556,7 @@ Transition StepExp() {
         if (cast<TupleLiteral>(*exp).Fields().size() > 0) {
           //    { {(f1=e1,...) :: C, E, F} :: S, H}
           // -> { {e1 :: (f1=[],...) :: C, E, F} :: S, H}
-          const Expression* e1 =
+          Ptr<const Expression> e1 =
               cast<TupleLiteral>(*exp).Fields()[0].expression;
           return Spawn{global_arena->New<ExpressionAction>(e1)};
         } else {
@@ -567,7 +568,7 @@ Transition StepExp() {
         //    H}
         // -> { { ek+1 :: (f1=v1,..., fk=vk, fk+1=[],...) :: C, E, F} :: S,
         // H}
-        const Expression* elt =
+        Ptr<const Expression> elt =
             cast<TupleLiteral>(*exp).Fields()[act->Pos()].expression;
         return Spawn{global_arena->New<ExpressionAction>(elt)};
       } else {
@@ -608,7 +609,7 @@ Transition StepExp() {
       if (act->Pos() != static_cast<int>(op.Arguments().size())) {
         //    { {v :: op(vs,[],e,es) :: C, E, F} :: S, H}
         // -> { {e :: op(vs,v,[],es) :: C, E, F} :: S, H}
-        const Expression* arg = op.Arguments()[act->Pos()];
+        Ptr<const Expression> arg = op.Arguments()[act->Pos()];
         return Spawn{global_arena->New<ExpressionAction>(arg)};
       } else {
         //    { {v :: op(vs,[]) :: C, E, F} :: S, H}
@@ -715,7 +716,7 @@ Transition StepExp() {
 
 Transition StepPattern() {
   Ptr<Action> act = state->stack.Top()->todo.Top();
-  const Pattern* pattern = cast<PatternAction>(*act).Pat();
+  Ptr<const Pattern> pattern = cast<PatternAction>(*act).Pat();
   if (tracing_output) {
     llvm::outs() << "--- step pattern " << *pattern << " --->\n";
   }
@@ -739,7 +740,7 @@ Transition StepPattern() {
         if (tuple.Fields().empty()) {
           return Done{&TupleValue::Empty()};
         } else {
-          const Pattern* p1 = tuple.Fields()[0].pattern;
+          Ptr<const Pattern> p1 = tuple.Fields()[0].pattern;
           return Spawn{(global_arena->New<PatternAction>(p1))};
         }
       } else if (act->Pos() != static_cast<int>(tuple.Fields().size())) {
@@ -747,7 +748,7 @@ Transition StepPattern() {
         //    H}
         // -> { { ek+1 :: (f1=v1,..., fk=vk, fk+1=[],...) :: C, E, F} :: S,
         // H}
-        const Pattern* elt = tuple.Fields()[act->Pos()].pattern;
+        Ptr<const Pattern> elt = tuple.Fields()[act->Pos()].pattern;
         return Spawn{global_arena->New<PatternAction>(elt)};
       } else {
         std::vector<TupleElement> elements;
@@ -775,7 +776,7 @@ Transition StepPattern() {
     }
     case Pattern::Kind::ExpressionPattern:
       return Delegate{global_arena->New<ExpressionAction>(
-          cast<ExpressionPattern>(pattern)->Expression())};
+          cast<ExpressionPattern>(*pattern).Expression())};
   }
 }
 
@@ -1045,8 +1046,7 @@ Transition StepStmt() {
           Stack<Ptr<Scope>>(global_arena->New<Scope>(CurrentEnv(state)));
       Stack<Ptr<Action>> todo;
       todo.Push(global_arena->New<StatementAction>(
-          global_arena->RawNew<Return>(stmt->SourceLoc(), nullptr,
-                                       /*is_omitted_exp=*/true)));
+          global_arena->RawNew<Return>(stmt->SourceLoc())));
       todo.Push(
           global_arena->New<StatementAction>(cast<Continuation>(*stmt).Body()));
       auto continuation_frame =
@@ -1076,7 +1076,7 @@ Transition StepStmt() {
         auto ignore_result = global_arena->New<StatementAction>(
             global_arena->RawNew<ExpressionStatement>(
                 stmt->SourceLoc(),
-                global_arena->RawNew<TupleLiteral>(stmt->SourceLoc())));
+                global_arena->New<TupleLiteral>(stmt->SourceLoc())));
         frame->todo.Push(ignore_result);
         // Push the continuation onto the current stack.
         const std::vector<Ptr<Frame>>& continuation_vector =
@@ -1217,9 +1217,9 @@ auto InterpProgram(const std::list<Ptr<const Declaration>>& fs) -> int {
 
   SourceLocation loc("<InterpProgram()>", 0);
 
-  const Expression* arg = global_arena->RawNew<TupleLiteral>(loc);
-  const Expression* call_main = global_arena->RawNew<CallExpression>(
-      loc, global_arena->RawNew<IdentifierExpression>(loc, "main"), arg);
+  Ptr<const Expression> arg = global_arena->New<TupleLiteral>(loc);
+  Ptr<const Expression> call_main = global_arena->New<CallExpression>(
+      loc, global_arena->New<IdentifierExpression>(loc, "main"), arg);
   auto todo =
       Stack<Ptr<Action>>(global_arena->New<ExpressionAction>(call_main));
   auto scopes = Stack<Ptr<Scope>>(global_arena->New<Scope>(globals));
@@ -1241,7 +1241,7 @@ auto InterpProgram(const std::list<Ptr<const Declaration>>& fs) -> int {
 }
 
 // Interpret an expression at compile-time.
-auto InterpExp(Env values, const Expression* e) -> const Value* {
+auto InterpExp(Env values, Ptr<const Expression> e) -> const Value* {
   CHECK(state->program_value == std::nullopt);
   auto program_value_guard =
       llvm::make_scope_exit([] { state->program_value = std::nullopt; });
@@ -1258,7 +1258,7 @@ auto InterpExp(Env values, const Expression* e) -> const Value* {
 }
 
 // Interpret a pattern at compile-time.
-auto InterpPattern(Env values, const Pattern* p) -> const Value* {
+auto InterpPattern(Env values, Ptr<const Pattern> p) -> const Value* {
   CHECK(state->program_value == std::nullopt);
   auto program_value_guard =
       llvm::make_scope_exit([] { state->program_value = std::nullopt; });

+ 2 - 2
executable_semantics/interpreter/interpreter.h

@@ -43,8 +43,8 @@ auto PatternMatch(const Value* p, const Value* v, SourceLocation loc)
     -> std::optional<Env>;
 
 auto InterpProgram(const std::list<Ptr<const Declaration>>& fs) -> int;
-auto InterpExp(Env values, const Expression* e) -> const Value*;
-auto InterpPattern(Env values, const Pattern* p) -> const Value*;
+auto InterpExp(Env values, Ptr<const Expression> e) -> const Value*;
+auto InterpPattern(Env values, Ptr<const Pattern> p) -> const Value*;
 
 }  // namespace Carbon
 

+ 62 - 62
executable_semantics/interpreter/typecheck.cpp

@@ -55,20 +55,20 @@ static SourceLocation ReifyFakeSourceLoc() {
 }
 
 // Reify type to type expression.
-static auto ReifyType(const Value* t, SourceLocation loc) -> const Expression* {
+static auto ReifyType(const Value* t, SourceLocation loc)
+    -> Ptr<const Expression> {
   switch (t->Tag()) {
     case Value::Kind::IntType:
-      return global_arena->RawNew<IntTypeLiteral>(ReifyFakeSourceLoc());
+      return global_arena->New<IntTypeLiteral>(ReifyFakeSourceLoc());
     case Value::Kind::BoolType:
-      return global_arena->RawNew<BoolTypeLiteral>(ReifyFakeSourceLoc());
+      return global_arena->New<BoolTypeLiteral>(ReifyFakeSourceLoc());
     case Value::Kind::TypeType:
-      return global_arena->RawNew<TypeTypeLiteral>(ReifyFakeSourceLoc());
+      return global_arena->New<TypeTypeLiteral>(ReifyFakeSourceLoc());
     case Value::Kind::ContinuationType:
-      return global_arena->RawNew<ContinuationTypeLiteral>(
-          ReifyFakeSourceLoc());
+      return global_arena->New<ContinuationTypeLiteral>(ReifyFakeSourceLoc());
     case Value::Kind::FunctionType: {
       const auto& fn_type = cast<FunctionType>(*t);
-      return global_arena->RawNew<FunctionTypeLiteral>(
+      return global_arena->New<FunctionTypeLiteral>(
           ReifyFakeSourceLoc(), ReifyType(fn_type.Param(), loc),
           ReifyType(fn_type.Ret(), loc),
           /*is_omitted_return_type=*/false);
@@ -79,24 +79,24 @@ static auto ReifyType(const Value* t, SourceLocation loc) -> const Expression* {
         args.push_back(
             FieldInitializer(field.name, ReifyType(field.value, loc)));
       }
-      return global_arena->RawNew<TupleLiteral>(ReifyFakeSourceLoc(), args);
+      return global_arena->New<TupleLiteral>(ReifyFakeSourceLoc(), args);
     }
     case Value::Kind::ClassType:
-      return global_arena->RawNew<IdentifierExpression>(
+      return global_arena->New<IdentifierExpression>(
           ReifyFakeSourceLoc(), cast<ClassType>(*t).Name());
     case Value::Kind::ChoiceType:
-      return global_arena->RawNew<IdentifierExpression>(
+      return global_arena->New<IdentifierExpression>(
           ReifyFakeSourceLoc(), cast<ChoiceType>(*t).Name());
     case Value::Kind::PointerType:
-      return global_arena->RawNew<PrimitiveOperatorExpression>(
+      return global_arena->New<PrimitiveOperatorExpression>(
           ReifyFakeSourceLoc(), Operator::Ptr,
-          std::vector<const Expression*>(
+          std::vector<Ptr<const Expression>>(
               {ReifyType(cast<PointerType>(*t).Type(), loc)}));
     case Value::Kind::VariableType:
-      return global_arena->RawNew<IdentifierExpression>(
+      return global_arena->New<IdentifierExpression>(
           ReifyFakeSourceLoc(), cast<VariableType>(*t).Name());
     case Value::Kind::StringType:
-      return global_arena->RawNew<StringTypeLiteral>(ReifyFakeSourceLoc());
+      return global_arena->New<StringTypeLiteral>(ReifyFakeSourceLoc());
     case Value::Kind::AlternativeConstructorValue:
     case Value::Kind::AlternativeValue:
     case Value::Kind::AutoType:
@@ -266,7 +266,7 @@ static auto Substitute(TypeEnv dict, const Value* type) -> const Value* {
 // types maps variable names to the type of their run-time value.
 // values maps variable names to their compile-time values. It is not
 //    directly used in this function but is passed to InterExp.
-auto TypeCheckExp(const Expression* e, TypeEnv types, Env values)
+auto TypeCheckExp(Ptr<const Expression> e, TypeEnv types, Env values)
     -> TCExpression {
   if (tracing_output) {
     llvm::outs() << "checking expression " << *e << "\ntypes: ";
@@ -289,9 +289,9 @@ auto TypeCheckExp(const Expression* e, TypeEnv types, Env values)
             FATAL_COMPILATION_ERROR(e->SourceLoc())
                 << "field " << f << " is not in the tuple " << *t;
           }
-          auto new_e = global_arena->RawNew<IndexExpression>(
+          auto new_e = global_arena->New<IndexExpression>(
               e->SourceLoc(), res.exp,
-              global_arena->RawNew<IntLiteral>(e->SourceLoc(), i));
+              global_arena->New<IntLiteral>(e->SourceLoc(), i));
           return TCExpression(new_e, field_t, res.types);
         }
         default:
@@ -308,8 +308,7 @@ auto TypeCheckExp(const Expression* e, TypeEnv types, Env values)
         new_args.push_back(FieldInitializer(arg.name, arg_res.exp));
         arg_types.push_back({.name = arg.name, .value = arg_res.type});
       }
-      auto tuple_e =
-          global_arena->RawNew<TupleLiteral>(e->SourceLoc(), new_args);
+      auto tuple_e = global_arena->New<TupleLiteral>(e->SourceLoc(), new_args);
       auto tuple_t = global_arena->RawNew<TupleValue>(std::move(arg_types));
       return TCExpression(tuple_e, tuple_t, new_types);
     }
@@ -323,8 +322,8 @@ auto TypeCheckExp(const Expression* e, TypeEnv types, Env values)
           // Search for a field
           for (auto& field : t_class.Fields()) {
             if (access.Field() == field.first) {
-              const Expression* new_e =
-                  global_arena->RawNew<FieldAccessExpression>(
+              Ptr<const Expression> new_e =
+                  global_arena->New<FieldAccessExpression>(
                       e->SourceLoc(), res.exp, access.Field());
               return TCExpression(new_e, field.second, res.types);
             }
@@ -332,8 +331,8 @@ auto TypeCheckExp(const Expression* e, TypeEnv types, Env values)
           // Search for a method
           for (auto& method : t_class.Methods()) {
             if (access.Field() == method.first) {
-              const Expression* new_e =
-                  global_arena->RawNew<FieldAccessExpression>(
+              Ptr<const Expression> new_e =
+                  global_arena->New<FieldAccessExpression>(
                       e->SourceLoc(), res.exp, access.Field());
               return TCExpression(new_e, method.second, res.types);
             }
@@ -346,7 +345,7 @@ auto TypeCheckExp(const Expression* e, TypeEnv types, Env values)
           const auto& tup = cast<TupleValue>(*t);
           for (const TupleElement& field : tup.Elements()) {
             if (access.Field() == field.name) {
-              auto new_e = global_arena->RawNew<FieldAccessExpression>(
+              auto new_e = global_arena->New<FieldAccessExpression>(
                   e->SourceLoc(), res.exp, access.Field());
               return TCExpression(new_e, field.value, res.types);
             }
@@ -359,8 +358,8 @@ auto TypeCheckExp(const Expression* e, TypeEnv types, Env values)
           const auto& choice = cast<ChoiceType>(*t);
           for (const auto& vt : choice.Alternatives()) {
             if (access.Field() == vt.first) {
-              const Expression* new_e =
-                  global_arena->RawNew<FieldAccessExpression>(
+              Ptr<const Expression> new_e =
+                  global_arena->New<FieldAccessExpression>(
                       e->SourceLoc(), res.exp, access.Field());
               auto fun_ty = global_arena->RawNew<FunctionType>(
                   std::vector<GenericBinding>(), vt.second, t);
@@ -393,16 +392,16 @@ auto TypeCheckExp(const Expression* e, TypeEnv types, Env values)
       return TCExpression(e, global_arena->RawNew<BoolType>(), types);
     case Expression::Kind::PrimitiveOperatorExpression: {
       const auto& op = cast<PrimitiveOperatorExpression>(*e);
-      std::vector<const Expression*> es;
+      std::vector<Ptr<const Expression>> es;
       std::vector<const Value*> ts;
       auto new_types = types;
-      for (const Expression* argument : op.Arguments()) {
+      for (Ptr<const Expression> argument : op.Arguments()) {
         auto res = TypeCheckExp(argument, types, values);
         new_types = res.types;
         es.push_back(res.exp);
         ts.push_back(res.type);
       }
-      auto new_e = global_arena->RawNew<PrimitiveOperatorExpression>(
+      auto new_e = global_arena->New<PrimitiveOperatorExpression>(
           e->SourceLoc(), op.Op(), es);
       switch (op.Op()) {
         case Operator::Neg:
@@ -492,7 +491,7 @@ auto TypeCheckExp(const Expression* e, TypeEnv types, Env values)
           } else {
             ExpectType(e->SourceLoc(), "call", parameter_type, arg_res.type);
           }
-          auto new_e = global_arena->RawNew<CallExpression>(
+          auto new_e = global_arena->New<CallExpression>(
               e->SourceLoc(), fun_res.exp, arg_res.exp);
           return TCExpression(new_e, return_type, arg_res.types);
         }
@@ -508,7 +507,7 @@ auto TypeCheckExp(const Expression* e, TypeEnv types, Env values)
       const auto& fn = cast<FunctionTypeLiteral>(*e);
       auto pt = InterpExp(values, fn.Parameter());
       auto rt = InterpExp(values, fn.ReturnType());
-      auto new_e = global_arena->RawNew<FunctionTypeLiteral>(
+      auto new_e = global_arena->New<FunctionTypeLiteral>(
           e->SourceLoc(), ReifyType(pt, e->SourceLoc()),
           ReifyType(rt, e->SourceLoc()),
           /*is_omitted_return_type=*/false);
@@ -533,7 +532,7 @@ auto TypeCheckExp(const Expression* e, TypeEnv types, Env values)
 // Equivalent to TypeCheckExp, but operates on Patterns instead of Expressions.
 // `expected` is the type that this pattern is expected to have, if the
 // surrounding context gives us that information. Otherwise, it is null.
-auto TypeCheckPattern(const Pattern* p, TypeEnv types, Env values,
+auto TypeCheckPattern(Ptr<const Pattern> p, TypeEnv types, Env values,
                       const Value* expected) -> TCPattern {
   if (tracing_output) {
     llvm::outs() << "checking pattern " << *p;
@@ -569,9 +568,9 @@ auto TypeCheckPattern(const Pattern* p, TypeEnv types, Env values,
             << "Name bindings within type patterns are unsupported";
         type = expected;
       }
-      auto new_p = global_arena->RawNew<BindingPattern>(
+      auto new_p = global_arena->New<BindingPattern>(
           binding.SourceLoc(), binding.Name(),
-          global_arena->RawNew<ExpressionPattern>(
+          global_arena->New<ExpressionPattern>(
               ReifyType(type, binding.SourceLoc())));
       if (binding.Name().has_value()) {
         types.Set(*binding.Name(), type);
@@ -612,7 +611,7 @@ auto TypeCheckPattern(const Pattern* p, TypeEnv types, Env values,
         field_types.push_back({.name = field.name, .value = field_result.type});
       }
       auto new_tuple =
-          global_arena->RawNew<TuplePattern>(tuple.SourceLoc(), new_fields);
+          global_arena->New<TuplePattern>(tuple.SourceLoc(), new_fields);
       auto tuple_t = global_arena->RawNew<TupleValue>(std::move(field_types));
       return {.pattern = new_tuple, .type = tuple_t, .types = new_types};
     }
@@ -637,28 +636,30 @@ auto TypeCheckPattern(const Pattern* p, TypeEnv types, Env values,
       }
       TCPattern arg_results = TypeCheckPattern(alternative.Arguments(), types,
                                                values, parameter_types);
-      return {.pattern = global_arena->RawNew<AlternativePattern>(
+      // TODO: Think about a cleaner way to cast between Ptr types.
+      auto arguments = Ptr<const TuplePattern>(
+          cast<const TuplePattern>(arg_results.pattern.Get()));
+      return {.pattern = global_arena->New<AlternativePattern>(
                   alternative.SourceLoc(),
                   ReifyType(choice_type, alternative.SourceLoc()),
-                  alternative.AlternativeName(),
-                  cast<TuplePattern>(arg_results.pattern)),
+                  alternative.AlternativeName(), arguments),
               .type = choice_type,
               .types = arg_results.types};
     }
     case Pattern::Kind::ExpressionPattern: {
       TCExpression result =
-          TypeCheckExp(cast<ExpressionPattern>(p)->Expression(), types, values);
-      return {.pattern = global_arena->RawNew<ExpressionPattern>(result.exp),
+          TypeCheckExp(cast<ExpressionPattern>(*p).Expression(), types, values);
+      return {.pattern = global_arena->New<ExpressionPattern>(result.exp),
               .type = result.type,
               .types = result.types};
     }
   }
 }
 
-static auto TypecheckCase(const Value* expected, const Pattern* pat,
+static auto TypecheckCase(const Value* expected, Ptr<const Pattern> pat,
                           const Statement* body, TypeEnv types, Env values,
                           const Value*& ret_type, bool is_omitted_ret_type)
-    -> std::pair<const Pattern*, const Statement*> {
+    -> std::pair<Ptr<const Pattern>, const Statement*> {
   auto pat_res = TypeCheckPattern(pat, types, values, expected);
   auto res =
       TypeCheckStmt(body, pat_res.types, values, ret_type, is_omitted_ret_type);
@@ -684,7 +685,7 @@ auto TypeCheckStmt(const Statement* s, TypeEnv types, Env values,
       auto res = TypeCheckExp(match.Exp(), types, values);
       auto res_type = res.type;
       auto new_clauses = global_arena->RawNew<
-          std::list<std::pair<const Pattern*, const Statement*>>>();
+          std::list<std::pair<Ptr<const Pattern>, const Statement*>>>();
       for (auto& clause : *match.Clauses()) {
         new_clauses->push_back(TypecheckCase(res_type, clause.first,
                                              clause.second, types, values,
@@ -817,8 +818,7 @@ static auto CheckOrEnsureReturn(const Statement* stmt, bool omitted_ret_type,
                                 SourceLocation loc) -> const Statement* {
   if (!stmt) {
     if (omitted_ret_type) {
-      return global_arena->RawNew<Return>(loc, nullptr,
-                                          /*is_omitted_exp=*/true);
+      return global_arena->RawNew<Return>(loc);
     } else {
       FATAL_COMPILATION_ERROR(loc)
           << "control-flow reaches end of function that provides a `->` return "
@@ -829,7 +829,7 @@ static auto CheckOrEnsureReturn(const Statement* stmt, bool omitted_ret_type,
     case Statement::Kind::Match: {
       const auto& match = cast<Match>(*stmt);
       auto new_clauses = global_arena->RawNew<
-          std::list<std::pair<const Pattern*, const Statement*>>>();
+          std::list<std::pair<Ptr<const Pattern>, const Statement*>>>();
       for (const auto& clause : *match.Clauses()) {
         auto s = CheckOrEnsureReturn(clause.second, omitted_ret_type,
                                      stmt->SourceLoc());
@@ -878,9 +878,7 @@ static auto CheckOrEnsureReturn(const Statement* stmt, bool omitted_ret_type,
     case Statement::Kind::VariableDefinition:
       if (omitted_ret_type) {
         return global_arena->RawNew<Sequence>(
-            stmt->SourceLoc(), stmt,
-            global_arena->RawNew<Return>(loc, nullptr,
-                                         /*is_omitted_exp=*/true));
+            stmt->SourceLoc(), stmt, global_arena->RawNew<Return>(loc));
       } else {
         FATAL_COMPILATION_ERROR(stmt->SourceLoc())
             << "control-flow reaches end of function that provides a `->` "
@@ -917,7 +915,7 @@ static auto TypeCheckFunDef(const FunctionDefinition* f, TypeEnv types,
                                   f->source_location);
   return global_arena->New<FunctionDefinition>(
       f->source_location, f->name, f->deduced_parameters, f->param_pattern,
-      global_arena->RawNew<ExpressionPattern>(
+      global_arena->New<ExpressionPattern>(
           ReifyType(return_type, f->source_location)),
       /*is_omitted_return_type=*/false, body);
 }
@@ -951,18 +949,18 @@ static auto TypeOfClassDef(const ClassDefinition* sd, TypeEnv /*types*/,
   for (Ptr<const Member> m : sd->members) {
     switch (m->Tag()) {
       case Member::Kind::FieldMember: {
-        const BindingPattern* binding = cast<FieldMember>(*m).Binding();
+        Ptr<const BindingPattern> binding = cast<FieldMember>(*m).Binding();
         if (!binding->Name().has_value()) {
           FATAL_COMPILATION_ERROR(binding->SourceLoc())
               << "Struct members must have names";
         }
-        const Expression* type_expression =
-            dyn_cast<ExpressionPattern>(binding->Type())->Expression();
-        if (type_expression == nullptr) {
+        const auto* binding_type =
+            dyn_cast<ExpressionPattern>(binding->Type().Get());
+        if (binding_type == nullptr) {
           FATAL_COMPILATION_ERROR(binding->SourceLoc())
               << "Struct members must have explicit types";
         }
-        auto type = InterpExp(ct_top, type_expression);
+        auto type = InterpExp(ct_top, binding_type->Expression());
         fields.push_back(std::make_pair(*binding->Name(), type));
         break;
       }
@@ -981,7 +979,8 @@ static auto GetName(const Declaration& d) -> const std::string& {
     case Declaration::Kind::ChoiceDeclaration:
       return cast<ChoiceDeclaration>(d).Name();
     case Declaration::Kind::VariableDeclaration: {
-      const BindingPattern* binding = cast<VariableDeclaration>(d).Binding();
+      Ptr<const BindingPattern> binding =
+          cast<VariableDeclaration>(d).Binding();
       if (!binding->Name().has_value()) {
         FATAL_COMPILATION_ERROR(binding->SourceLoc())
             << "Top-level variable declarations must have names";
@@ -1025,14 +1024,15 @@ auto MakeTypeChecked(const Ptr<const Declaration> d, const TypeEnv& types,
       // declaration with annotated types.
       TCExpression type_checked_initializer =
           TypeCheckExp(var.Initializer(), types, values);
-      const Expression* type =
-          dyn_cast<ExpressionPattern>(var.Binding()->Type())->Expression();
-      if (type == nullptr) {
+      const auto* binding_type =
+          dyn_cast<ExpressionPattern>(var.Binding()->Type().Get());
+      if (binding_type == nullptr) {
         // TODO: consider adding support for `auto`
         FATAL_COMPILATION_ERROR(var.SourceLoc())
             << "Type of a top-level variable must be an expression.";
       }
-      const Value* declared_type = InterpExp(values, type);
+      const Value* declared_type =
+          InterpExp(values, binding_type->Expression());
       ExpectType(var.SourceLoc(), "initializer of variable", declared_type,
                  type_checked_initializer.type);
       return d;
@@ -1087,8 +1087,8 @@ static void TopLevel(const Declaration& d, TypeCheckContext* tops) {
       const auto& var = cast<VariableDeclaration>(d);
       // Associate the variable name with it's declared type in the
       // compile-time symbol table.
-      const Expression* type =
-          cast<ExpressionPattern>(var.Binding()->Type())->Expression();
+      Ptr<const Expression> type =
+          cast<ExpressionPattern>(*var.Binding()->Type()).Expression();
       const Value* declared_type = InterpExp(tops->values, type);
       tops->types.Set(*var.Binding()->Name(), declared_type);
       break;

+ 5 - 5
executable_semantics/interpreter/typecheck.h

@@ -19,16 +19,16 @@ namespace Carbon {
 using TypeEnv = Dictionary<std::string, const Value*>;
 
 struct TCExpression {
-  TCExpression(const Expression* e, const Value* t, TypeEnv types)
+  TCExpression(Ptr<const Expression> e, const Value* t, TypeEnv types)
       : exp(e), type(t), types(types) {}
 
-  const Expression* exp;
+  Ptr<const Expression> exp;
   const Value* type;
   TypeEnv types;
 };
 
 struct TCPattern {
-  const Pattern* pattern;
+  Ptr<const Pattern> pattern;
   const Value* type;
   TypeEnv types;
 };
@@ -47,9 +47,9 @@ struct TypeCheckContext {
   Env values;
 };
 
-auto TypeCheckExp(const Expression* e, TypeEnv types, Env values)
+auto TypeCheckExp(Ptr<const Expression> e, TypeEnv types, Env values)
     -> TCExpression;
-auto TypeCheckPattern(const Pattern* p, TypeEnv types, Env values,
+auto TypeCheckPattern(Ptr<const Pattern> p, TypeEnv types, Env values,
                       const Value* expected) -> TCPattern;
 
 auto TypeCheckStmt(const Statement* s, TypeEnv types, Env values,

+ 6 - 3
executable_semantics/syntax/bison_wrap.h

@@ -23,9 +23,12 @@ class BisonWrap {
     return *this;
   }
 
-  // Support transparent conversion to the wrapped type, erroring if not
-  // initialized.
-  operator T() {
+  // Support transparent conversion to the wrapped type.
+  operator T() { return Release(); }
+
+  // Deliberately releases the contained value. Errors if not initialized.
+  // Called directly in parser.ypp when releasing pairs.
+  auto Release() -> T {
     CHECK(val.has_value());
     T ret = std::move(*val);
     val.reset();

+ 91 - 82
executable_semantics/syntax/parser.ypp

@@ -99,35 +99,35 @@ void Carbon::Parser::error(const location_type&, const std::string& message) {
 %type <const Statement*> statement
 %type <const Statement*> if_statement
 %type <const Statement*> optional_else
-%type <std::pair<const Expression*, bool>> return_expression
+%type <BisonWrap<std::pair<Ptr<const Expression>, bool>>> return_expression
 %type <const Statement*> block
 %type <const Statement*> statement_list
-%type <const Expression*> expression
-%type <GenericBinding> generic_binding
+%type <BisonWrap<Ptr<const Expression>>> expression
+%type <BisonWrap<GenericBinding>> generic_binding
 %type <std::vector<GenericBinding>> deduced_params
 %type <std::vector<GenericBinding>> deduced_param_list
-%type <const Pattern*> pattern
-%type <const Pattern*> non_expression_pattern
-%type <std::pair<const Expression*, bool>> return_type
-%type <const Expression*> paren_expression
-%type <const Expression*> tuple
+%type <BisonWrap<Ptr<const Pattern>>> pattern
+%type <BisonWrap<Ptr<const Pattern>>> non_expression_pattern
+%type <BisonWrap<std::pair<Ptr<const Expression>, bool>>> return_type
+%type <BisonWrap<Ptr<const Expression>>> paren_expression
+%type <BisonWrap<Ptr<const Expression>>> tuple
 %type <std::optional<std::string>> binding_lhs
-%type <const BindingPattern*> variable_declaration
+%type <BisonWrap<Ptr<const BindingPattern>>> variable_declaration
 %type <BisonWrap<Ptr<Member>>> member
 %type <std::list<Ptr<Member>>> member_list
-%type <ParenContents<Expression>::Element> paren_expression_element
+%type <BisonWrap<ParenContents<Expression>::Element>> paren_expression_element
 %type <ParenContents<Expression>> paren_expression_base
 %type <ParenContents<Expression>> paren_expression_contents
-%type <const Pattern*> paren_pattern
-%type <const TuplePattern*> tuple_pattern
-%type <const TuplePattern*> maybe_empty_tuple_pattern
+%type <BisonWrap<Ptr<const Pattern>>> paren_pattern
+%type <BisonWrap<Ptr<const TuplePattern>>> tuple_pattern
+%type <BisonWrap<Ptr<const TuplePattern>>> maybe_empty_tuple_pattern
 %type <ParenContents<Pattern>> paren_pattern_base
-%type <ParenContents<Pattern>::Element> paren_pattern_element
+%type <BisonWrap<ParenContents<Pattern>::Element>> paren_pattern_element
 %type <ParenContents<Pattern>> paren_pattern_contents
-%type <std::pair<std::string, const Expression*>> alternative
-%type <std::list<std::pair<std::string, const Expression*>>> alternative_list
-%type <std::pair<const Pattern*, const Statement*>*> clause
-%type <std::list<std::pair<const Pattern*, const Statement*>>*> clause_list
+%type <BisonWrap<std::pair<std::string, Ptr<const Expression>>>> alternative
+%type <std::list<std::pair<std::string, Ptr<const Expression>>>> alternative_list
+%type <std::pair<Ptr<const Pattern>, const Statement*>*> clause
+%type <std::list<std::pair<Ptr<const Pattern>, const Statement*>>*> clause_list
 %token END_OF_FILE 0
 %token AND
 %token OR
@@ -213,76 +213,78 @@ input: declaration_list
 ;
 expression:
   identifier
-    { $$ = global_arena->RawNew<IdentifierExpression>(context.SourceLoc(), $1); }
+    { $$ = global_arena->New<IdentifierExpression>(context.SourceLoc(), $1); }
 | expression designator
-    { $$ = global_arena->RawNew<FieldAccessExpression>(context.SourceLoc(), $1, $2); }
+    { $$ = global_arena->New<FieldAccessExpression>(context.SourceLoc(), $1, $2); }
 | expression "[" expression "]"
-    { $$ = global_arena->RawNew<IndexExpression>(context.SourceLoc(), $1, $3); }
+    { $$ = global_arena->New<IndexExpression>(context.SourceLoc(), $1, $3); }
 | integer_literal
-    { $$ = global_arena->RawNew<IntLiteral>(context.SourceLoc(), $1); }
+    { $$ = global_arena->New<IntLiteral>(context.SourceLoc(), $1); }
 | string_literal
-    { $$ = global_arena->RawNew<StringLiteral>(context.SourceLoc(), $1); }
+    { $$ = global_arena->New<StringLiteral>(context.SourceLoc(), $1); }
 | TRUE
-    { $$ = global_arena->RawNew<BoolLiteral>(context.SourceLoc(), true); }
+    { $$ = global_arena->New<BoolLiteral>(context.SourceLoc(), true); }
 | FALSE
-    { $$ = global_arena->RawNew<BoolLiteral>(context.SourceLoc(), false); }
+    { $$ = global_arena->New<BoolLiteral>(context.SourceLoc(), false); }
 | sized_type_literal
     {
       int val;
       CHECK(llvm::to_integer(llvm::StringRef($1).substr(1), val));
       CHECK($1[0] == 'i' && val == 32)  << "Only i32 is supported for now: " << $1;
-      $$ = global_arena->RawNew<IntTypeLiteral>(context.SourceLoc());
+      $$ = global_arena->New<IntTypeLiteral>(context.SourceLoc());
     }
 | STRING
-    { $$ = global_arena->RawNew<StringTypeLiteral>(context.SourceLoc()); }
+    { $$ = global_arena->New<StringTypeLiteral>(context.SourceLoc()); }
 | BOOL
-    { $$ = global_arena->RawNew<BoolTypeLiteral>(context.SourceLoc()); }
+    { $$ = global_arena->New<BoolTypeLiteral>(context.SourceLoc()); }
 | TYPE
-    { $$ = global_arena->RawNew<TypeTypeLiteral>(context.SourceLoc()); }
+    { $$ = global_arena->New<TypeTypeLiteral>(context.SourceLoc()); }
 | CONTINUATION_TYPE
-    { $$ = global_arena->RawNew<ContinuationTypeLiteral>(context.SourceLoc()); }
+    { $$ = global_arena->New<ContinuationTypeLiteral>(context.SourceLoc()); }
 | paren_expression { $$ = $1; }
 | expression EQUAL_EQUAL expression
-    { $$ = global_arena->RawNew<PrimitiveOperatorExpression>(
-        context.SourceLoc(), Operator::Eq, std::vector<const Expression*>({$1, $3})); }
+    { $$ = global_arena->New<PrimitiveOperatorExpression>(
+        context.SourceLoc(), Operator::Eq, std::vector<Ptr<const Expression>>({$1, $3})); }
 | expression "+" expression
-    { $$ = global_arena->RawNew<PrimitiveOperatorExpression>(
-        context.SourceLoc(), Operator::Add, std::vector<const Expression*>({$1, $3})); }
+    { $$ = global_arena->New<PrimitiveOperatorExpression>(
+        context.SourceLoc(), Operator::Add, std::vector<Ptr<const Expression>>({$1, $3})); }
 | expression "-" expression
-    { $$ = global_arena->RawNew<PrimitiveOperatorExpression>(
-        context.SourceLoc(), Operator::Sub, std::vector<const Expression*>({$1, $3})); }
+    { $$ = global_arena->New<PrimitiveOperatorExpression>(
+        context.SourceLoc(), Operator::Sub, std::vector<Ptr<const Expression>>({$1, $3})); }
 | expression BINARY_STAR expression
-    { $$ = global_arena->RawNew<PrimitiveOperatorExpression>(
-        context.SourceLoc(), Operator::Mul, std::vector<const Expression*>({$1, $3})); }
+    { $$ = global_arena->New<PrimitiveOperatorExpression>(
+        context.SourceLoc(), Operator::Mul, std::vector<Ptr<const Expression>>({$1, $3})); }
 | expression AND expression
-    { $$ = global_arena->RawNew<PrimitiveOperatorExpression>(
-        context.SourceLoc(), Operator::And, std::vector<const Expression*>({$1, $3})); }
+    { $$ = global_arena->New<PrimitiveOperatorExpression>(
+        context.SourceLoc(), Operator::And, std::vector<Ptr<const Expression>>({$1, $3})); }
 | expression OR expression
-    { $$ = global_arena->RawNew<PrimitiveOperatorExpression>(
-        context.SourceLoc(), Operator::Or, std::vector<const Expression*>({$1, $3})); }
+    { $$ = global_arena->New<PrimitiveOperatorExpression>(
+        context.SourceLoc(), Operator::Or, std::vector<Ptr<const Expression>>({$1, $3})); }
 | NOT expression
-    { $$ = global_arena->RawNew<PrimitiveOperatorExpression>(
-        context.SourceLoc(), Operator::Not, std::vector<const Expression*>({$2})); }
+    { $$ = global_arena->New<PrimitiveOperatorExpression>(
+        context.SourceLoc(), Operator::Not, std::vector<Ptr<const Expression>>({$2})); }
 | "-" expression %prec UNARY_MINUS
-    { $$ = global_arena->RawNew<PrimitiveOperatorExpression>(
-        context.SourceLoc(), Operator::Neg, std::vector<const Expression*>({$2})); }
+    { $$ = global_arena->New<PrimitiveOperatorExpression>(
+        context.SourceLoc(), Operator::Neg, std::vector<Ptr<const Expression>>({$2})); }
 | PREFIX_STAR expression
-    { $$ = global_arena->RawNew<PrimitiveOperatorExpression>(
-        context.SourceLoc(), Operator::Deref, std::vector<const Expression*>({$2})); }
+    { $$ = global_arena->New<PrimitiveOperatorExpression>(
+        context.SourceLoc(), Operator::Deref, std::vector<Ptr<const Expression>>({$2})); }
 | UNARY_STAR expression %prec PREFIX_STAR
-    { $$ = global_arena->RawNew<PrimitiveOperatorExpression>(
-        context.SourceLoc(), Operator::Deref, std::vector<const Expression*>({$2})); }
+    { $$ = global_arena->New<PrimitiveOperatorExpression>(
+        context.SourceLoc(), Operator::Deref, std::vector<Ptr<const Expression>>({$2})); }
 | expression tuple
-    { $$ = global_arena->RawNew<CallExpression>(context.SourceLoc(), $1, $2); }
+    { $$ = global_arena->New<CallExpression>(context.SourceLoc(), $1, $2); }
 | expression POSTFIX_STAR
-    { $$ = global_arena->RawNew<PrimitiveOperatorExpression>(
-        context.SourceLoc(), Operator::Ptr, std::vector<const Expression*>({$1})); }
+    { $$ = global_arena->New<PrimitiveOperatorExpression>(
+        context.SourceLoc(), Operator::Ptr, std::vector<Ptr<const Expression>>({$1})); }
 | expression UNARY_STAR
-    { $$ = global_arena->RawNew<PrimitiveOperatorExpression>(
-        context.SourceLoc(), Operator::Ptr, std::vector<const Expression*>({$1})); }
+    { $$ = global_arena->New<PrimitiveOperatorExpression>(
+        context.SourceLoc(), Operator::Ptr, std::vector<Ptr<const Expression>>({$1})); }
 | FNTY tuple return_type
-    { $$ = global_arena->RawNew<FunctionTypeLiteral>(
-        context.SourceLoc(), $2, $3.first, $3.second); }
+    {
+      auto [return_exp, is_omitted_exp] = $3.Release();
+      $$ = global_arena->New<FunctionTypeLiteral>(
+        context.SourceLoc(), $2, return_exp, is_omitted_exp); }
 ;
 designator: "." identifier { $$ = $2; }
 ;
@@ -329,17 +331,17 @@ pattern:
   non_expression_pattern
     { $$ = $1; }
 | expression
-    { $$ = global_arena->RawNew<ExpressionPattern>($1); }
+    { $$ = global_arena->New<ExpressionPattern>($1); }
 ;
 non_expression_pattern:
   AUTO
-    { $$ = global_arena->RawNew<AutoPattern>(context.SourceLoc()); }
+    { $$ = global_arena->New<AutoPattern>(context.SourceLoc()); }
 | binding_lhs ":" pattern
-    { $$ = global_arena->RawNew<BindingPattern>(context.SourceLoc(), $1, $3); }
+    { $$ = global_arena->New<BindingPattern>(context.SourceLoc(), $1, $3); }
 | paren_pattern
     { $$ = $1; }
 | expression tuple_pattern
-    { $$ = global_arena->RawNew<AlternativePattern>(context.SourceLoc(), $1, $2); }
+    { $$ = global_arena->New<AlternativePattern>(context.SourceLoc(), $1, $2); }
 ;
 binding_lhs:
   identifier { $$ = $1; }
@@ -373,7 +375,8 @@ paren_pattern_contents:
 | paren_pattern_contents "," paren_expression_element
     {
       $$ = $1;
-      $$.elements.push_back({.name = $3.name, .term = global_arena->RawNew<ExpressionPattern>($3.term)});
+      auto el = $3.Release();
+      $$.elements.push_back({.name = el.name, .term = global_arena->New<ExpressionPattern>(el.term)});
     }
 | paren_pattern_contents "," paren_pattern_element
     {
@@ -395,25 +398,25 @@ tuple_pattern: paren_pattern_base
 // rules out the possibility of an `expression` at this point.
 maybe_empty_tuple_pattern:
   "(" ")"
-    { $$ = global_arena->RawNew<TuplePattern>(context.SourceLoc(), std::vector<TuplePattern::Field>()); }
+    { $$ = global_arena->New<TuplePattern>(context.SourceLoc(), std::vector<TuplePattern::Field>()); }
 | tuple_pattern
     { $$ = $1; }
 ;
 clause:
   CASE pattern DBLARROW statement
-    { $$ = global_arena->RawNew<std::pair<const Pattern*, const Statement*>>($2, $4); }
+    { $$ = global_arena->RawNew<std::pair<Ptr<const Pattern>, const Statement*>>($2, $4); }
 | DEFAULT DBLARROW statement
     {
-      auto vp = global_arena->RawNew<BindingPattern>(
-          context.SourceLoc(), std::nullopt, global_arena->RawNew<AutoPattern>(context.SourceLoc()));
-      $$ = global_arena->RawNew<std::pair<const Pattern*, const Statement*>>(vp, $3);
+      auto vp = global_arena->New<BindingPattern>(
+          context.SourceLoc(), std::nullopt, global_arena->New<AutoPattern>(context.SourceLoc()));
+      $$ = global_arena->RawNew<std::pair<Ptr<const Pattern>, const Statement*>>(vp, $3);
     }
 ;
 clause_list:
   // Empty
     {
       $$ = global_arena->RawNew<std::list<
-          std::pair<const Pattern*, const Statement*>>>();
+          std::pair<Ptr<const Pattern>, const Statement*>>>();
     }
 | clause clause_list
     { $$ = $2; $$->push_front(*$1); }
@@ -434,7 +437,10 @@ statement:
 | CONTINUE ";"
     { $$ = global_arena->RawNew<Continue>(context.SourceLoc()); }
 | RETURN return_expression ";"
-    { $$ = global_arena->RawNew<Return>(context.SourceLoc(), $2.first, $2.second); }
+    {
+      auto [return_exp, is_omitted_exp] = $2.Release();
+      $$ = global_arena->RawNew<Return>(context.SourceLoc(), return_exp, is_omitted_exp);
+    }
 | block
     { $$ = $1; }
 | MATCH "(" expression ")" "{" clause_list "}"
@@ -460,7 +466,7 @@ optional_else:
 ;
 return_expression:
   // Empty
-    { $$ = {global_arena->RawNew<TupleLiteral>(context.SourceLoc()), true}; }
+    { $$ = {global_arena->New<TupleLiteral>(context.SourceLoc()), true}; }
 | expression
     { $$ = {$1, false}; }
 ;
@@ -476,7 +482,7 @@ block:
 ;
 return_type:
   // Empty
-    { $$ = {global_arena->RawNew<TupleLiteral>(context.SourceLoc()), true}; }
+    { $$ = {global_arena->New<TupleLiteral>(context.SourceLoc()), true}; }
 | ARROW expression %prec FNARROW
     { $$ = {$2, false}; }
 ;
@@ -509,10 +515,11 @@ deduced_params:
 function_definition:
   FN identifier deduced_params maybe_empty_tuple_pattern return_type block
     {
+      auto [return_exp, is_omitted_exp] = $5.Release();
       $$ = global_arena->New<FunctionDefinition>(
           context.SourceLoc(), $2, $3, $4,
-          global_arena->RawNew<ExpressionPattern>($5.first),
-          $5.second, $6);
+          global_arena->New<ExpressionPattern>(return_exp),
+          is_omitted_exp, $6);
     }
 | FN identifier deduced_params maybe_empty_tuple_pattern DBLARROW expression ";"
     {
@@ -520,20 +527,22 @@ function_definition:
       // the expression.
       $$ = global_arena->New<FunctionDefinition>(
           context.SourceLoc(), $2, $3, $4,
-          global_arena->RawNew<AutoPattern>(context.SourceLoc()), true,
+          global_arena->New<AutoPattern>(context.SourceLoc()), true,
           global_arena->RawNew<Return>(context.SourceLoc(), $6, true));
     }
 ;
 function_declaration:
   FN identifier deduced_params maybe_empty_tuple_pattern return_type ";"
     {
+      auto [return_exp, is_omitted_exp] = $5.Release();
       $$ = global_arena->New<FunctionDefinition>(
           context.SourceLoc(), $2, $3, $4,
-          global_arena->RawNew<ExpressionPattern>($5.first),
-          $5.second, nullptr); }
+          global_arena->New<ExpressionPattern>(return_exp),
+          is_omitted_exp, nullptr);
+    }
 ;
 variable_declaration: identifier ":" pattern
-    { $$ = global_arena->RawNew<BindingPattern>(context.SourceLoc(), $1, $3); }
+    { $$ = global_arena->New<BindingPattern>(context.SourceLoc(), $1, $3); }
 ;
 member: VAR variable_declaration ";"
     { $$ = global_arena->New<FieldMember>(context.SourceLoc(), $2); }
@@ -546,19 +555,19 @@ member_list:
 ;
 alternative:
   identifier tuple
-    { $$ = std::pair<std::string, const Expression*>($1, $2); }
+    { $$ = std::pair<std::string, Ptr<const Expression>>($1, $2); }
 | identifier
     {
-      $$ = std::pair<std::string, const Expression*>(
-          $1, global_arena->RawNew<TupleLiteral>(context.SourceLoc()));
+      $$ = std::pair<std::string, Ptr<const Expression>>(
+          $1, global_arena->New<TupleLiteral>(context.SourceLoc()));
     }
 ;
 alternative_list:
   // Empty
-    { $$ = std::list<std::pair<std::string, const Expression*>>(); }
+    { $$ = std::list<std::pair<std::string, Ptr<const Expression>>>(); }
 | alternative
     {
-      $$ = std::list<std::pair<std::string, const Expression*>>();
+      $$ = std::list<std::pair<std::string, Ptr<const Expression>>>();
       $$.push_front($1);
     }
 | alternative "," alternative_list

+ 7 - 7
executable_semantics/syntax/syntax_helpers.cpp

@@ -18,21 +18,21 @@ namespace Carbon {
 static void AddIntrinsics(std::list<Ptr<const Declaration>>* fs) {
   SourceLocation loc("<intrinsic>", 0);
   std::vector<TuplePattern::Field> print_fields = {TuplePattern::Field(
-      "0", global_arena->RawNew<BindingPattern>(
+      "0", global_arena->New<BindingPattern>(
                loc, "format_str",
-               global_arena->RawNew<ExpressionPattern>(
-                   global_arena->RawNew<StringTypeLiteral>(loc))))};
+               global_arena->New<ExpressionPattern>(
+                   global_arena->New<StringTypeLiteral>(loc))))};
   auto* print_return = global_arena->RawNew<Return>(
       loc,
-      global_arena->RawNew<IntrinsicExpression>(
+      global_arena->New<IntrinsicExpression>(
           IntrinsicExpression::IntrinsicKind::Print),
       false);
   auto print = global_arena->New<FunctionDeclaration>(
       global_arena->New<FunctionDefinition>(
           loc, "Print", std::vector<GenericBinding>(),
-          global_arena->RawNew<TuplePattern>(loc, print_fields),
-          global_arena->RawNew<ExpressionPattern>(
-              global_arena->RawNew<TupleLiteral>(loc)),
+          global_arena->New<TuplePattern>(loc, print_fields),
+          global_arena->New<ExpressionPattern>(
+              global_arena->New<TupleLiteral>(loc)),
           /*is_omitted_return_type=*/false, print_return));
   fs->insert(fs->begin(), print);
 }