Просмотр исходного кода

Refactor `Expression` to use `std::variant` instead of a union (#586)

Geoff Romer 4 лет назад
Родитель
Сommit
1916258e9b

+ 44 - 76
executable_semantics/ast/expression.cpp

@@ -9,107 +9,94 @@
 
 namespace Carbon {
 
-Variable Expression::GetVariable() const {
-  assert(tag == ExpressionKind::Variable);
-  return u.variable;
+auto Expression::GetVariable() const -> const Variable& {
+  return std::get<Variable>(value);
 }
 
-FieldAccess Expression::GetFieldAccess() const {
-  assert(tag == ExpressionKind::GetField);
-  return u.get_field;
+auto Expression::GetFieldAccess() const -> const FieldAccess& {
+  return std::get<FieldAccess>(value);
 }
 
-Index Expression::GetIndex() const {
-  assert(tag == ExpressionKind::Index);
-  return u.index;
+auto Expression::GetIndex() const -> const Index& {
+  return std::get<Index>(value);
 }
 
-PatternVariable Expression::GetPatternVariable() const {
-  assert(tag == ExpressionKind::PatternVariable);
-  return u.pattern_variable;
+auto Expression::GetPatternVariable() const -> const PatternVariable& {
+  return std::get<PatternVariable>(value);
 }
 
-int Expression::GetInteger() const {
-  assert(tag == ExpressionKind::Integer);
-  return u.integer;
+auto Expression::GetInteger() const -> int {
+  return std::get<IntLiteral>(value).value;
 }
 
-bool Expression::GetBoolean() const {
-  assert(tag == ExpressionKind::Boolean);
-  return u.boolean;
+auto Expression::GetBoolean() const -> bool {
+  return std::get<BoolLiteral>(value).value;
 }
 
-Tuple Expression::GetTuple() const {
-  assert(tag == ExpressionKind::Tuple);
-  return u.tuple;
+auto Expression::GetTuple() const -> const Tuple& {
+  return std::get<Tuple>(value);
 }
 
-PrimitiveOperator Expression::GetPrimitiveOperator() const {
-  assert(tag == ExpressionKind::PrimitiveOp);
-  return u.primitive_op;
+auto Expression::GetPrimitiveOperator() const -> const PrimitiveOperator& {
+  return std::get<PrimitiveOperator>(value);
 }
 
-Call Expression::GetCall() const {
-  assert(tag == ExpressionKind::Call);
-  return u.call;
+auto Expression::GetCall() const -> const Call& {
+  return std::get<Call>(value);
 }
 
-FunctionType Expression::GetFunctionType() const {
-  assert(tag == ExpressionKind::FunctionT);
-  return u.function_type;
+auto Expression::GetFunctionType() const -> const FunctionType& {
+  return std::get<FunctionType>(value);
 }
 
 auto Expression::MakeTypeType(int line_num) -> const Expression* {
   auto* t = new Expression();
-  t->tag = ExpressionKind::TypeT;
   t->line_num = line_num;
+  t->value = TypeT();
   return t;
 }
 
 auto Expression::MakeIntType(int line_num) -> const Expression* {
   auto* t = new Expression();
-  t->tag = ExpressionKind::IntT;
   t->line_num = line_num;
+  t->value = IntT();
   return t;
 }
 
 auto Expression::MakeBoolType(int line_num) -> const Expression* {
   auto* t = new Expression();
-  t->tag = ExpressionKind::BoolT;
   t->line_num = line_num;
+  t->value = BoolT();
   return t;
 }
 
 auto Expression::MakeAutoType(int line_num) -> const Expression* {
   auto* t = new Expression();
-  t->tag = ExpressionKind::AutoT;
   t->line_num = line_num;
+  t->value = AutoT();
   return t;
 }
 
 // Returns a Continuation type AST node at the given source location.
 auto Expression::MakeContinuationType(int line_num) -> const Expression* {
   auto* type = new Expression();
-  type->tag = ExpressionKind::ContinuationT;
   type->line_num = line_num;
+  type->value = ContinuationT();
   return type;
 }
 
 auto Expression::MakeFunType(int line_num, const Expression* param,
                              const Expression* ret) -> const Expression* {
   auto* t = new Expression();
-  t->tag = ExpressionKind::FunctionT;
   t->line_num = line_num;
-  t->u.function_type.parameter = param;
-  t->u.function_type.return_type = ret;
+  t->value = FunctionType({.parameter = param, .return_type = ret});
   return t;
 }
 
 auto Expression::MakeVar(int line_num, std::string var) -> const Expression* {
   auto* v = new Expression();
   v->line_num = line_num;
-  v->tag = ExpressionKind::Variable;
-  v->u.variable.name = new std::string(std::move(var));
+  v->value = Variable({.name = new std::string(std::move(var))});
   return v;
 }
 
@@ -117,25 +104,22 @@ auto Expression::MakeVarPat(int line_num, std::string var,
                             const Expression* type) -> const Expression* {
   auto* v = new Expression();
   v->line_num = line_num;
-  v->tag = ExpressionKind::PatternVariable;
-  v->u.pattern_variable.name = new std::string(std::move(var));
-  v->u.pattern_variable.type = type;
+  v->value =
+      PatternVariable({.name = new std::string(std::move(var)), .type = type});
   return v;
 }
 
 auto Expression::MakeInt(int line_num, int i) -> const Expression* {
   auto* e = new Expression();
   e->line_num = line_num;
-  e->tag = ExpressionKind::Integer;
-  e->u.integer = i;
+  e->value = IntLiteral({.value = i});
   return e;
 }
 
 auto Expression::MakeBool(int line_num, bool b) -> const Expression* {
   auto* e = new Expression();
   e->line_num = line_num;
-  e->tag = ExpressionKind::Boolean;
-  e->u.boolean = b;
+  e->value = BoolLiteral({.value = b});
   return e;
 }
 
@@ -144,9 +128,7 @@ auto Expression::MakeOp(int line_num, enum Operator op,
     -> const Expression* {
   auto* e = new Expression();
   e->line_num = line_num;
-  e->tag = ExpressionKind::PrimitiveOp;
-  e->u.primitive_op.op = op;
-  e->u.primitive_op.arguments = args;
+  e->value = PrimitiveOperator({.op = op, .arguments = args});
   return e;
 }
 
@@ -154,11 +136,8 @@ auto Expression::MakeUnOp(int line_num, enum Operator op, const Expression* arg)
     -> const Expression* {
   auto* e = new Expression();
   e->line_num = line_num;
-  e->tag = ExpressionKind::PrimitiveOp;
-  e->u.primitive_op.op = op;
-  auto* args = new std::vector<const Expression*>();
-  args->push_back(arg);
-  e->u.primitive_op.arguments = args;
+  e->value = PrimitiveOperator(
+      {.op = op, .arguments = new std::vector<const Expression*>{arg}});
   return e;
 }
 
@@ -167,12 +146,8 @@ auto Expression::MakeBinOp(int line_num, enum Operator op,
     -> const Expression* {
   auto* e = new Expression();
   e->line_num = line_num;
-  e->tag = ExpressionKind::PrimitiveOp;
-  e->u.primitive_op.op = op;
-  auto* args = new std::vector<const Expression*>();
-  args->push_back(arg1);
-  args->push_back(arg2);
-  e->u.primitive_op.arguments = args;
+  e->value = PrimitiveOperator(
+      {.op = op, .arguments = new std::vector<const Expression*>{arg1, arg2}});
   return e;
 }
 
@@ -180,9 +155,7 @@ auto Expression::MakeCall(int line_num, const Expression* fun,
                           const Expression* arg) -> const Expression* {
   auto* e = new Expression();
   e->line_num = line_num;
-  e->tag = ExpressionKind::Call;
-  e->u.call.function = fun;
-  e->u.call.argument = arg;
+  e->value = Call({.function = fun, .argument = arg});
   return e;
 }
 
@@ -190,9 +163,8 @@ auto Expression::MakeGetField(int line_num, const Expression* exp,
                               std::string field) -> const Expression* {
   auto* e = new Expression();
   e->line_num = line_num;
-  e->tag = ExpressionKind::GetField;
-  e->u.get_field.aggregate = exp;
-  e->u.get_field.field = new std::string(std::move(field));
+  e->value = FieldAccess(
+      {.aggregate = exp, .field = new std::string(std::move(field))});
   return e;
 }
 
@@ -200,7 +172,6 @@ auto Expression::MakeTuple(int line_num, std::vector<FieldInitializer>* args)
     -> const Expression* {
   auto* e = new Expression();
   e->line_num = line_num;
-  e->tag = ExpressionKind::Tuple;
   int i = 0;
   bool seen_named_member = false;
   for (auto& arg : *args) {
@@ -217,7 +188,7 @@ auto Expression::MakeTuple(int line_num, std::vector<FieldInitializer>* args)
       seen_named_member = true;
     }
   }
-  e->u.tuple.fields = args;
+  e->value = Tuple({.fields = args});
   return e;
 }
 
@@ -227,9 +198,8 @@ auto Expression::MakeTuple(int line_num, std::vector<FieldInitializer>* args)
 auto Expression::MakeUnit(int line_num) -> const Expression* {
   auto* unit = new Expression();
   unit->line_num = line_num;
-  unit->tag = ExpressionKind::Tuple;
   auto* args = new std::vector<FieldInitializer>();
-  unit->u.tuple.fields = args;
+  unit->value = Tuple({.fields = args});
   return unit;
 }
 
@@ -237,9 +207,7 @@ auto Expression::MakeIndex(int line_num, const Expression* exp,
                            const Expression* i) -> const Expression* {
   auto* e = new Expression();
   e->line_num = line_num;
-  e->tag = ExpressionKind::Index;
-  e->u.index.aggregate = exp;
-  e->u.index.offset = i;
+  e->value = Index({.aggregate = exp, .offset = i});
   return e;
 }
 
@@ -284,7 +252,7 @@ static void PrintFields(std::vector<FieldInitializer>* fields) {
 }
 
 void PrintExp(const Expression* e) {
-  switch (e->tag) {
+  switch (e->tag()) {
     case ExpressionKind::Index:
       PrintExp(e->GetIndex().aggregate);
       std::cout << "[";
@@ -340,7 +308,7 @@ void PrintExp(const Expression* e) {
       break;
     case ExpressionKind::Call:
       PrintExp(e->GetCall().function);
-      if (e->GetCall().argument->tag == ExpressionKind::Tuple) {
+      if (e->GetCall().argument->tag() == ExpressionKind::Tuple) {
         PrintExp(e->GetCall().argument);
       } else {
         std::cout << "(";

+ 67 - 23
executable_semantics/ast/expression.h

@@ -6,6 +6,7 @@
 #define EXECUTABLE_SEMANTICS_AST_EXPRESSION_H_
 
 #include <string>
+#include <variant>
 #include <vector>
 
 namespace Carbon {
@@ -55,46 +56,84 @@ enum class Operator {
 struct Expression;
 
 struct Variable {
+  static constexpr ExpressionKind Kind = ExpressionKind::Variable;
   std::string* name;
 };
 
 struct FieldAccess {
+  static constexpr ExpressionKind Kind = ExpressionKind::GetField;
   const Expression* aggregate;
   std::string* field;
 };
 
 struct Index {
+  static constexpr ExpressionKind Kind = ExpressionKind::Index;
   const Expression* aggregate;
   const Expression* offset;
 };
 
 struct PatternVariable {
+  static constexpr ExpressionKind Kind = ExpressionKind::PatternVariable;
   std::string* name;
   const Expression* type;
 };
 
+struct IntLiteral {
+  static constexpr ExpressionKind Kind = ExpressionKind::Integer;
+  int value;
+};
+
+struct BoolLiteral {
+  static constexpr ExpressionKind Kind = ExpressionKind::Boolean;
+  bool value;
+};
+
 struct Tuple {
+  static constexpr ExpressionKind Kind = ExpressionKind::Tuple;
   std::vector<FieldInitializer>* fields;
 };
 
 struct PrimitiveOperator {
+  static constexpr ExpressionKind Kind = ExpressionKind::PrimitiveOp;
   Operator op;
   std::vector<const Expression*>* arguments;
 };
 
 struct Call {
+  static constexpr ExpressionKind Kind = ExpressionKind::Call;
   const Expression* function;
   const Expression* argument;
 };
 
 struct FunctionType {
+  static constexpr ExpressionKind Kind = ExpressionKind::FunctionT;
   const Expression* parameter;
   const Expression* return_type;
 };
 
+struct AutoT {
+  static constexpr ExpressionKind Kind = ExpressionKind::AutoT;
+};
+
+struct BoolT {
+  static constexpr ExpressionKind Kind = ExpressionKind::BoolT;
+};
+
+struct IntT {
+  static constexpr ExpressionKind Kind = ExpressionKind::IntT;
+};
+
+struct ContinuationT {
+  static constexpr ExpressionKind Kind = ExpressionKind::ContinuationT;
+};
+
+struct TypeT {
+  static constexpr ExpressionKind Kind = ExpressionKind::TypeT;
+};
+
 struct Expression {
   int line_num;
-  ExpressionKind tag;
+  inline auto tag() const -> ExpressionKind;
 
   static auto MakeVar(int line_num, std::string var) -> const Expression*;
   static auto MakeVarPat(int line_num, std::string var, const Expression* type)
@@ -124,34 +163,39 @@ struct Expression {
   static auto MakeAutoType(int line_num) -> const Expression*;
   static auto MakeContinuationType(int line_num) -> const Expression*;
 
-  Variable GetVariable() const;
-  FieldAccess GetFieldAccess() const;
-  Index GetIndex() const;
-  PatternVariable GetPatternVariable() const;
-  int GetInteger() const;
-  bool GetBoolean() const;
-  Tuple GetTuple() const;
-  PrimitiveOperator GetPrimitiveOperator() const;
-  Call GetCall() const;
-  FunctionType GetFunctionType() const;
+  auto GetVariable() const -> const Variable&;
+  auto GetFieldAccess() const -> const FieldAccess&;
+  auto GetIndex() const -> const Index&;
+  auto GetPatternVariable() const -> const PatternVariable&;
+  auto GetInteger() const -> int;
+  auto GetBoolean() const -> bool;
+  auto GetTuple() const -> const Tuple&;
+  auto GetPrimitiveOperator() const -> const PrimitiveOperator&;
+  auto GetCall() const -> const Call&;
+  auto GetFunctionType() const -> const FunctionType&;
 
  private:
-  union {
-    Variable variable;
-    FieldAccess get_field;
-    Index index;
-    PatternVariable pattern_variable;
-    int integer;
-    bool boolean;
-    Tuple tuple;
-    PrimitiveOperator primitive_op;
-    Call call;
-    FunctionType function_type;
-  } u;
+  std::variant<Variable, FieldAccess, Index, PatternVariable, IntLiteral,
+               BoolLiteral, Tuple, PrimitiveOperator, Call, FunctionType, AutoT,
+               BoolT, IntT, ContinuationT, TypeT>
+      value;
 };
 
 void PrintExp(const Expression* exp);
 
+// Implementation details only beyond this point
+
+struct TagVisitor {
+  template <typename Alternative>
+  auto operator()(const Alternative&) -> ExpressionKind {
+    return Alternative::Kind;
+  }
+};
+
+auto Expression::tag() const -> ExpressionKind {
+  return std::visit(TagVisitor(), value);
+}
+
 }  // namespace Carbon
 
 #endif  // EXECUTABLE_SEMANTICS_AST_EXPRESSION_H_

+ 4 - 4
executable_semantics/interpreter/interpreter.cpp

@@ -619,7 +619,7 @@ void StepLvalue() {
     PrintExp(exp);
     std::cout << " --->" << std::endl;
   }
-  switch (exp->tag) {
+  switch (exp->tag()) {
     case ExpressionKind::Variable: {
       //    { {x :: C, E, F} :: S, H}
       // -> { {E(x) :: C, E, F} :: S, H}
@@ -686,7 +686,7 @@ void StepExp() {
     PrintExp(exp);
     std::cout << " --->" << std::endl;
   }
-  switch (exp->tag) {
+  switch (exp->tag()) {
     case ExpressionKind::PatternVariable: {
       frame->todo.Push(MakeExpAct(exp->GetPatternVariable().type));
       act->pos++;
@@ -1076,7 +1076,7 @@ void HandleValue() {
     }
     case ActionKind::LValAction: {
       const Expression* exp = act->u.exp;
-      switch (exp->tag) {
+      switch (exp->tag()) {
         case ExpressionKind::GetField: {
           //    { v :: [].f :: C, E, F} :: S, H}
           // -> { { &v.f :: C, E, F} :: S, H }
@@ -1133,7 +1133,7 @@ void HandleValue() {
     }
     case ActionKind::ExpressionAction: {
       const Expression* exp = act->u.exp;
-      switch (exp->tag) {
+      switch (exp->tag()) {
         case ExpressionKind::PatternVariable: {
           auto v = Value::MakeVarPatVal(*exp->GetPatternVariable().name,
                                         act->results[0]);

+ 1 - 1
executable_semantics/interpreter/typecheck.cpp

@@ -134,7 +134,7 @@ auto TypeCheckExp(const Expression* e, TypeEnv types, Env values,
     PrintExp(e);
     std::cout << std::endl;
   }
-  switch (e->tag) {
+  switch (e->tag()) {
     case ExpressionKind::PatternVariable: {
       if (context != TCContext::PatternContext) {
         std::cerr

+ 15 - 15
executable_semantics/syntax/paren_contents_test.cpp

@@ -13,7 +13,7 @@ TEST(ParenContentsTest, EmptyAsExpression) {
   ParenContents contents;
   const Expression* expression = contents.AsExpression(/*line_num=*/1);
   EXPECT_EQ(expression->line_num, 1);
-  ASSERT_EQ(expression->tag, ExpressionKind::Tuple);
+  ASSERT_EQ(expression->tag(), ExpressionKind::Tuple);
   EXPECT_EQ(expression->GetTuple().fields->size(), 0);
 }
 
@@ -21,7 +21,7 @@ TEST(ParenContentsTest, EmptyAsTuple) {
   ParenContents contents;
   const Expression* tuple = contents.AsTuple(/*line_num=*/1);
   EXPECT_EQ(tuple->line_num, 1);
-  ASSERT_EQ(tuple->tag, ExpressionKind::Tuple);
+  ASSERT_EQ(tuple->tag(), ExpressionKind::Tuple);
   EXPECT_EQ(tuple->GetTuple().fields->size(), 0);
 }
 
@@ -38,7 +38,7 @@ TEST(ParenContentsTest, UnaryNoCommaAsExpression) {
 
   const Expression* expression = contents.AsExpression(/*line_num=*/1);
   EXPECT_EQ(expression->line_num, 2);
-  ASSERT_EQ(expression->tag, ExpressionKind::Integer);
+  ASSERT_EQ(expression->tag(), ExpressionKind::Integer);
 }
 
 TEST(ParenContentsTest, UnaryNoCommaAsTuple) {
@@ -48,10 +48,10 @@ TEST(ParenContentsTest, UnaryNoCommaAsTuple) {
 
   const Expression* tuple = contents.AsTuple(/*line_num=*/1);
   EXPECT_EQ(tuple->line_num, 1);
-  ASSERT_EQ(tuple->tag, ExpressionKind::Tuple);
+  ASSERT_EQ(tuple->tag(), ExpressionKind::Tuple);
   std::vector<FieldInitializer> fields = *tuple->GetTuple().fields;
   ASSERT_EQ(fields.size(), 1);
-  EXPECT_EQ(fields[0].expression->tag, ExpressionKind::Integer);
+  EXPECT_EQ(fields[0].expression->tag(), ExpressionKind::Integer);
 }
 
 TEST(ParenContentsTest, UnaryWithCommaAsExpression) {
@@ -61,10 +61,10 @@ TEST(ParenContentsTest, UnaryWithCommaAsExpression) {
 
   const Expression* expression = contents.AsExpression(/*line_num=*/1);
   EXPECT_EQ(expression->line_num, 1);
-  ASSERT_EQ(expression->tag, ExpressionKind::Tuple);
+  ASSERT_EQ(expression->tag(), ExpressionKind::Tuple);
   std::vector<FieldInitializer> fields = *expression->GetTuple().fields;
   ASSERT_EQ(fields.size(), 1);
-  EXPECT_EQ(fields[0].expression->tag, ExpressionKind::Integer);
+  EXPECT_EQ(fields[0].expression->tag(), ExpressionKind::Integer);
 }
 
 TEST(ParenContentsTest, UnaryWithCommaAsTuple) {
@@ -74,10 +74,10 @@ TEST(ParenContentsTest, UnaryWithCommaAsTuple) {
 
   const Expression* tuple = contents.AsTuple(/*line_num=*/1);
   EXPECT_EQ(tuple->line_num, 1);
-  ASSERT_EQ(tuple->tag, ExpressionKind::Tuple);
+  ASSERT_EQ(tuple->tag(), ExpressionKind::Tuple);
   std::vector<FieldInitializer> fields = *tuple->GetTuple().fields;
   ASSERT_EQ(fields.size(), 1);
-  EXPECT_EQ(fields[0].expression->tag, ExpressionKind::Integer);
+  EXPECT_EQ(fields[0].expression->tag(), ExpressionKind::Integer);
 }
 
 TEST(ParenContentsTest, BinaryAsExpression) {
@@ -88,11 +88,11 @@ TEST(ParenContentsTest, BinaryAsExpression) {
 
   const Expression* expression = contents.AsExpression(/*line_num=*/1);
   EXPECT_EQ(expression->line_num, 1);
-  ASSERT_EQ(expression->tag, ExpressionKind::Tuple);
+  ASSERT_EQ(expression->tag(), ExpressionKind::Tuple);
   std::vector<FieldInitializer> fields = *expression->GetTuple().fields;
   ASSERT_EQ(fields.size(), 2);
-  EXPECT_EQ(fields[0].expression->tag, ExpressionKind::Integer);
-  EXPECT_EQ(fields[1].expression->tag, ExpressionKind::Integer);
+  EXPECT_EQ(fields[0].expression->tag(), ExpressionKind::Integer);
+  EXPECT_EQ(fields[1].expression->tag(), ExpressionKind::Integer);
 }
 
 TEST(ParenContentsTest, BinaryAsTuple) {
@@ -103,11 +103,11 @@ TEST(ParenContentsTest, BinaryAsTuple) {
 
   const Expression* tuple = contents.AsTuple(/*line_num=*/1);
   EXPECT_EQ(tuple->line_num, 1);
-  ASSERT_EQ(tuple->tag, ExpressionKind::Tuple);
+  ASSERT_EQ(tuple->tag(), ExpressionKind::Tuple);
   std::vector<FieldInitializer> fields = *tuple->GetTuple().fields;
   ASSERT_EQ(fields.size(), 2);
-  EXPECT_EQ(fields[0].expression->tag, ExpressionKind::Integer);
-  EXPECT_EQ(fields[1].expression->tag, ExpressionKind::Integer);
+  EXPECT_EQ(fields[0].expression->tag(), ExpressionKind::Integer);
+  EXPECT_EQ(fields[1].expression->tag(), ExpressionKind::Integer);
 }
 
 }  // namespace