Przeglądaj źródła

Make `addr me: Self*` work for interface methods. (#2374)

Fixes #2223.
Richard Smith 3 lat temu
rodzic
commit
e557d4af3b

+ 8 - 11
explorer/ast/expression.h

@@ -263,6 +263,13 @@ class MemberAccessExpression : public Expression {
   // Can only be called once, during typechecking.
   void set_is_type_access(bool type_access) { is_type_access_ = type_access; }
 
+  // Returns true if the member is a method that has a "me" declaration in an
+  // AddrPattern.
+  auto is_addr_me_method() const -> bool { return is_addr_me_method_; }
+
+  // Can only be called once, during typechecking.
+  void set_is_addr_me_method() { is_addr_me_method_ = true; }
+
   // If `object` has a generic type, returns the witness value, which might be
   // either concrete or symbolic. Otherwise, returns `std::nullopt`. Should not
   // be called before typechecking.
@@ -292,6 +299,7 @@ class MemberAccessExpression : public Expression {
  private:
   Nonnull<Expression*> object_;
   bool is_type_access_ = false;
+  bool is_addr_me_method_ = false;
   std::optional<Nonnull<const Witness*>> impl_;
   std::optional<Nonnull<const Value*>> constant_value_;
 };
@@ -324,16 +332,6 @@ class SimpleMemberAccessExpression : public MemberAccessExpression {
     member_ = member;
   }
 
-  // Returns true if the field is a method that has a "me" declaration in an
-  // AddrPattern.
-  // TODO: Should be in MemberAccessExpression.
-  auto is_field_addr_me_method() const -> bool {
-    return is_field_addr_me_method_;
-  }
-
-  // Can only be called once, during typechecking.
-  void set_is_field_addr_me_method() { is_field_addr_me_method_ = true; }
-
   // If `object` is a constrained type parameter and `member` was found in an
   // interface, returns that interface. Should not be called before
   // typechecking.
@@ -351,7 +349,6 @@ class SimpleMemberAccessExpression : public MemberAccessExpression {
  private:
   std::string member_name_;
   std::optional<Member> member_;
-  bool is_field_addr_me_method_ = false;
   std::optional<Nonnull<const InterfaceType*>> found_in_interface_;
 };
 

+ 7 - 3
explorer/interpreter/interpreter.cpp

@@ -1043,7 +1043,7 @@ auto Interpreter::StepExp() -> ErrorOr<Success> {
       bool forming_member_name = isa<TypeOfMemberName>(&access.static_type());
       if (act.pos() == 0) {
         // First, evaluate the first operand.
-        if (access.is_field_addr_me_method()) {
+        if (access.is_addr_me_method()) {
           return todo_.Spawn(std::make_unique<LValAction>(&access.object()));
         } else {
           return todo_.Spawn(
@@ -1119,8 +1119,12 @@ auto Interpreter::StepExp() -> ErrorOr<Success> {
       bool forming_member_name = isa<TypeOfMemberName>(&access.static_type());
       if (act.pos() == 0) {
         // First, evaluate the first operand.
-        return todo_.Spawn(
-            std::make_unique<ExpressionAction>(&access.object()));
+        if (access.is_addr_me_method()) {
+          return todo_.Spawn(std::make_unique<LValAction>(&access.object()));
+        } else {
+          return todo_.Spawn(
+              std::make_unique<ExpressionAction>(&access.object()));
+        }
       } else if (act.pos() == 1 && access.impl().has_value() &&
                  !forming_member_name) {
         // Next, if we're accessing an interface member, evaluate the `impl`

+ 40 - 20
explorer/interpreter/type_checker.cpp

@@ -2430,6 +2430,27 @@ static auto IsInstanceMember(Member member) {
   }
 }
 
+auto TypeChecker::CheckAddrMeAccess(
+    Nonnull<MemberAccessExpression*> access,
+    Nonnull<const FunctionDeclaration*> func_decl, const Bindings& bindings,
+    const ImplScope& impl_scope) -> ErrorOr<Success> {
+  if (func_decl->is_method() &&
+      func_decl->me_pattern().kind() == PatternKind::AddrPattern) {
+    access->set_is_addr_me_method();
+    Nonnull<const Value*> me_type =
+        Substitute(bindings, &func_decl->me_pattern().static_type());
+    CARBON_RETURN_IF_ERROR(
+        ExpectExactType(access->source_loc(), "method access, receiver type",
+                        me_type, &access->object().static_type(), impl_scope));
+    if (access->object().value_category() != ValueCategory::Var) {
+      return ProgramError(access->source_loc())
+             << "method " << *access
+             << " requires its receiver to be an lvalue";
+    }
+  }
+  return Success();
+}
+
 auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e,
                                const ImplScope& impl_scope)
     -> ErrorOr<Success> {
@@ -2568,21 +2589,8 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e,
                 break;
               case DeclarationKind::FunctionDeclaration: {
                 const auto* func_decl = cast<FunctionDeclaration>(member);
-                if (func_decl->is_method() && func_decl->me_pattern().kind() ==
-                                                  PatternKind::AddrPattern) {
-                  access.set_is_field_addr_me_method();
-                  Nonnull<const Value*> me_type =
-                      Substitute(t_class.bindings(),
-                                 &func_decl->me_pattern().static_type());
-                  CARBON_RETURN_IF_ERROR(ExpectType(
-                      e->source_loc(), "method access, receiver type", me_type,
-                      &access.object().static_type(), impl_scope));
-                  if (access.object().value_category() != ValueCategory::Var) {
-                    return ProgramError(e->source_loc())
-                           << "method " << access.member_name()
-                           << " requires its receiver to be an lvalue";
-                  }
-                }
+                CARBON_RETURN_IF_ERROR(CheckAddrMeAccess(
+                    &access, func_decl, t_class.bindings(), impl_scope));
                 access.set_value_category(ValueCategory::Let);
                 break;
               }
@@ -2650,6 +2658,11 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e,
           access.set_is_type_access(!IsInstanceMember(access.member()));
           access.set_static_type(inst_member_type);
 
+          if (auto* func_decl = dyn_cast<FunctionDeclaration>(result.member)) {
+            CARBON_RETURN_IF_ERROR(
+                CheckAddrMeAccess(&access, func_decl, bindings, impl_scope));
+          }
+
           // TODO: This is just a ConstraintImplWitness into the
           // iface_constraint. If we can compute the right index, we can avoid
           // re-resolving it.
@@ -2904,19 +2917,23 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e,
         access.set_impl(impl);
       }
 
-      auto substitute_into_member_type = [&]() {
-        Nonnull<const Value*> member_type = &member_name.member().type();
+      auto bindings_for_member = [&]() -> Bindings {
         if (member_name.interface()) {
           Nonnull<const InterfaceType*> iface_type = *member_name.interface();
           Bindings bindings = iface_type->bindings();
           bindings.Add(iface_type->declaration().self(), *base_type, witness);
-          return Substitute(bindings, member_type);
+          return bindings;
         }
         if (const auto* class_type =
                 dyn_cast<NominalClassType>(base_type.value())) {
-          return Substitute(class_type->bindings(), member_type);
+          return class_type->bindings();
         }
-        return member_type;
+        return Bindings();
+      };
+
+      auto substitute_into_member_type = [&]() {
+        Nonnull<const Value*> member_type = &member_name.member().type();
+        return Substitute(bindings_for_member(), member_type);
       };
 
       switch (std::optional<Nonnull<const Declaration*>> decl =
@@ -2939,6 +2956,9 @@ auto TypeChecker::TypeCheckExp(Nonnull<Expression*> e,
                 << "vacuous compound member access";
             access.set_static_type(substitute_into_member_type());
             access.set_value_category(ValueCategory::Let);
+            CARBON_RETURN_IF_ERROR(
+                CheckAddrMeAccess(&access, cast<FunctionDeclaration>(*decl),
+                                  bindings_for_member(), impl_scope));
             return Success();
           }
           break;

+ 8 - 0
explorer/interpreter/type_checker.h

@@ -129,6 +129,14 @@ class TypeChecker {
     Nonnull<const Declaration*> member;
   };
 
+  // Checks a member access that might be accessing a function taking `addr me:
+  // Self*`. If it does, this function marks the member access accordingly and
+  // ensures the object argument is an lvalue.
+  auto CheckAddrMeAccess(Nonnull<MemberAccessExpression*> access,
+                         Nonnull<const FunctionDeclaration*> func_decl,
+                         const Bindings& bindings, const ImplScope& impl_scope)
+      -> ErrorOr<Success>;
+
   // Traverses the AST rooted at `e`, populating the static_type() of all nodes
   // and ensuring they follow Carbon's typing rules.
   //

+ 1 - 1
explorer/interpreter/value.cpp

@@ -59,7 +59,7 @@ static auto GetMember(Nonnull<Arena*> arena, Nonnull<const Value*> v,
           mem_decl.has_value()) {
         const auto& fun_decl = cast<FunctionDeclaration>(**mem_decl);
         if (fun_decl.is_method()) {
-          return arena->New<BoundMethodValue>(&fun_decl, v,
+          return arena->New<BoundMethodValue>(&fun_decl, me_value,
                                               &impl_witness->bindings());
         } else {
           // Class function.

+ 1 - 1
explorer/testdata/addr/fail_method_let.carbon

@@ -26,7 +26,7 @@ class Point {
 
 fn Main() -> i32 {
   let p: Point = Point.Origin();
-  // CHECK:STDERR: COMPILATION ERROR: {{.*}}/explorer/testdata/addr/fail_method_let.carbon:[[@LINE+1]]: method GetSetX requires its receiver to be an lvalue
+  // CHECK:STDERR: COMPILATION ERROR: {{.*}}/explorer/testdata/addr/fail_method_let.carbon:[[@LINE+1]]: method p.GetSetX requires its receiver to be an lvalue
   var x: auto = p.GetSetX(42);
   if (p.x == 42) {
     return x;

+ 3 - 1
explorer/testdata/addr/fail_method_me_type.carbon

@@ -30,7 +30,9 @@ class Point {
 
 fn Main() -> i32 {
   var p: Point = Point.Origin();
-  // CHECK:STDERR: COMPILATION ERROR: {{.*}}/explorer/testdata/addr/fail_method_me_type.carbon:[[@LINE+1]]: type error in method access, receiver type: 'class Point' is not implicitly convertible to 'class Shape'
+  // CHECK:STDERR: COMPILATION ERROR: {{.*}}/explorer/testdata/addr/fail_method_me_type.carbon:[[@LINE+3]]: type error in method access, receiver type
+  // CHECK:STDERR: expected: class Shape
+  // CHECK:STDERR: actual: class Point
   var x: auto = p.GetSetX(42);
   if (p.x == 42) {
     return x;

+ 0 - 1
explorer/testdata/destructor/dont_call_in_method.carbon

@@ -5,7 +5,6 @@
 // AUTOUPDATE
 // RUN: %{explorer-run}
 // RUN: %{explorer-run-trace}
-// AUTOUPDATE: %{explorer} %s
 // CHECK:STDOUT: TEST
 // CHECK:STDOUT: TEST 2
 // CHECK:STDOUT: DESTRUCTOR A 1

+ 50 - 0
explorer/testdata/interface/addr_me.carbon

@@ -0,0 +1,50 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// AUTOUPDATE
+// RUN: %{explorer-run}
+// RUN: %{explorer-run-trace}
+// CHECK:STDOUT: 75
+// CHECK:STDOUT: 100
+// CHECK:STDOUT: result: 0
+
+package ExplorerTest api;
+
+interface Vector {
+  fn Zero() -> Self;
+  fn Add[addr me: Self*](b: Self);
+  fn Scale[addr me: Self*](v: i32);
+}
+
+class Point {
+  var x: i32;
+  var y: i32;
+  impl as Vector {
+    fn Zero() -> Self {
+      return {.x = 1, .y = 1};
+    }
+    fn Add[addr me: Self*](b: Self) {
+      (*me).x = (*me).x + b.x;
+      (*me).y = (*me).y + b.y;
+    }
+    fn Scale[addr me: Self*](v: i32) {
+      (*me).x = (*me).x * v;
+      (*me).y = (*me).y * v;
+    }
+  }
+}
+
+fn AddAndScaleGeneric[T:! Vector](p: T*, s: i32) {
+  (*p).Add(T.Zero());
+  (*p).(Vector.Scale)(s);
+  (*p).(T.(Vector.Scale))(s);
+}
+
+fn Main() -> i32 {
+  var a: Point = {.x = 2, .y = 3};
+  AddAndScaleGeneric(&a, 5);
+  Print("{0}", a.x);
+  Print("{0}", a.y);
+  return 0;
+}