Преглед на файлове

Improve access checking code (#4317)

This change accomplishes the TODOs for access checking. More
specifically it,
- makes `SemIR::AccessKind` formattable using `llvm::formatv`.
- makes use of `LookupUnqualifiedName` to find `Self`.
Brymer Meneses преди 1 година
родител
ревизия
da40c8b076
променени са 5 файла, в които са добавени 70 реда и са изтрити 87 реда
  1. 11 11
      toolchain/check/context.cpp
  2. 2 2
      toolchain/check/context.h
  3. 3 31
      toolchain/check/member_access.cpp
  4. 31 43
      toolchain/check/testdata/class/access_modifers.carbon
  5. 23 0
      toolchain/sem_ir/name_scope.h

+ 11 - 11
toolchain/check/context.cpp

@@ -299,7 +299,8 @@ auto Context::LookupNameInDecl(SemIR::LocId loc_id, SemIR::NameId name_id,
 }
 
 auto Context::LookupUnqualifiedName(Parse::NodeId node_id,
-                                    SemIR::NameId name_id) -> LookupResult {
+                                    SemIR::NameId name_id, bool required)
+    -> LookupResult {
   // TODO: Check for shadowed lookup results.
 
   // Find the results from ancestor lexical scopes. These will be combined with
@@ -328,7 +329,10 @@ auto Context::LookupUnqualifiedName(Parse::NodeId node_id,
   }
 
   // We didn't find anything at all.
-  DiagnoseNameNotFound(node_id, name_id);
+  if (required) {
+    DiagnoseNameNotFound(node_id, name_id);
+  }
+
   return {.specific_id = SemIR::SpecificId::Invalid,
           .inst_id = SemIR::InstId::BuiltinError};
 }
@@ -368,18 +372,14 @@ static auto DiagnoseInvalidQualifiedNameAccess(Context& context, SemIRLoc loc,
   // TODO: Support scoped entities other than just classes.
   auto class_info = context.classes().Get(class_type->class_id);
 
-  // TODO: Support passing AccessKind to diagnostics.
   CARBON_DIAGNOSTIC(ClassInvalidMemberAccess, Error,
                     "Cannot access {0} member `{1}` of type `{2}`.",
-                    llvm::StringLiteral, SemIR::NameId, SemIR::TypeId);
+                    SemIR::AccessKind, SemIR::NameId, SemIR::TypeId);
   CARBON_DIAGNOSTIC(ClassMemberDefinition, Note,
-                    "The {0} member `{1}` is defined here.",
-                    llvm::StringLiteral, SemIR::NameId);
+                    "The {0} member `{1}` is defined here.", SemIR::AccessKind,
+                    SemIR::NameId);
 
   auto parent_type_id = class_info.self_type_id;
-  auto access_desc = access_kind == SemIR::AccessKind::Private
-                         ? llvm::StringLiteral("private")
-                         : llvm::StringLiteral("protected");
 
   if (access_kind == SemIR::AccessKind::Private && is_parent_access) {
     if (auto base_decl = context.insts().TryGetAsIfValid<SemIR::BaseDecl>(
@@ -395,9 +395,9 @@ static auto DiagnoseInvalidQualifiedNameAccess(Context& context, SemIRLoc loc,
   }
 
   context.emitter()
-      .Build(loc, ClassInvalidMemberAccess, access_desc, name_id,
+      .Build(loc, ClassInvalidMemberAccess, access_kind, name_id,
              parent_type_id)
-      .Note(scope_result_id, ClassMemberDefinition, access_desc, name_id)
+      .Note(scope_result_id, ClassMemberDefinition, access_kind, name_id)
       .Emit();
 }
 

+ 2 - 2
toolchain/check/context.h

@@ -176,8 +176,8 @@ class Context {
                         SemIR::NameScopeId scope_id) -> SemIR::InstId;
 
   // Performs an unqualified name lookup, returning the referenced instruction.
-  auto LookupUnqualifiedName(Parse::NodeId node_id, SemIR::NameId name_id)
-      -> LookupResult;
+  auto LookupUnqualifiedName(Parse::NodeId node_id, SemIR::NameId name_id,
+                             bool required = true) -> LookupResult;
 
   // Performs a name lookup in a specified scope, returning the referenced
   // instruction. Does not look into extended scopes. Returns an invalid

+ 3 - 31
toolchain/check/member_access.cpp

@@ -100,41 +100,13 @@ static auto IsInstanceMethod(const SemIR::File& sem_ir,
   return false;
 }
 
-// Returns the FunctionId of the current function if it exists.
-static auto GetCurrentFunction(Context& context)
-    -> std::optional<SemIR::FunctionId> {
-  if (context.return_scope_stack().empty()) {
-    return std::nullopt;
-  }
-
-  return context.insts()
-      .GetAs<SemIR::FunctionDecl>(context.return_scope_stack().back().decl_id)
-      .function_id;
-}
-
 // Returns the highest allowed access. For example, if this returns `Protected`
 // then only `Public` and `Protected` accesses are allowed--not `Private`.
-static auto GetHighestAllowedAccess(Context& context, SemIRLoc loc,
+static auto GetHighestAllowedAccess(Context& context, SemIR::LocId loc_id,
                                     SemIR::ConstantId name_scope_const_id)
     -> SemIR::AccessKind {
-  // TODO: Maybe use LookupUnqualifiedName for `Self` to support things like
-  // `var x: Self.ParentProtectedType`?
-  auto current_function = GetCurrentFunction(context);
-  // If `current_function` is a `nullopt` then we're accessing from a global
-  // variable.
-  if (!current_function) {
-    return SemIR::AccessKind::Public;
-  }
-
-  auto scope_id = context.functions().Get(*current_function).parent_scope_id;
-  if (!scope_id.is_valid()) {
-    return SemIR::AccessKind::Public;
-  }
-  auto scope = context.name_scopes().Get(scope_id);
-
-  // Lookup the inst for `Self` in the parent scope of the current function.
-  auto [self_type_inst_id, _] = context.LookupNameInExactScope(
-      loc, SemIR::NameId::SelfType, scope_id, scope);
+  auto [_, self_type_inst_id] = context.LookupUnqualifiedName(
+      loc_id.node_id(), SemIR::NameId::SelfType, /*required=*/false);
   if (!self_type_inst_id.is_valid()) {
     return SemIR::AccessKind::Public;
   }

+ 31 - 43
toolchain/check/testdata/class/access_modifers.carbon

@@ -128,35 +128,21 @@ class A {
 // CHECK:STDERR:                 ^
 // CHECK:STDERR:
 let x: i32 = A.x;
-// CHECK:STDERR: fail_global_access.carbon:[[@LINE+7]]:14: ERROR: Cannot access private member `y` of type `A`.
+// CHECK:STDERR: fail_global_access.carbon:[[@LINE+6]]:14: ERROR: Cannot access private member `y` of type `A`.
 // CHECK:STDERR: let y: i32 = A.y;
 // CHECK:STDERR:              ^~~
 // CHECK:STDERR: fail_global_access.carbon:[[@LINE-14]]:15: The private member `y` is defined here.
 // CHECK:STDERR:   private let y: i32 = 5;
 // CHECK:STDERR:               ^
-// CHECK:STDERR:
 let y: i32 = A.y;
 
-// --- fail_todo_global_self_access.carbon
+// --- self_access.carbon
 
 library "[[@TEST_NAME]]";
 
 class A {
-  private let internal: i32 = 10;
-  // CHECK:STDERR: fail_todo_global_self_access.carbon:[[@LINE+13]]:16: ERROR: Member access into incomplete class `A`.
-  // CHECK:STDERR:   let y: i32 = Self.internal;
-  // CHECK:STDERR:                ^~~~~~~~~~~~~
-  // CHECK:STDERR: fail_todo_global_self_access.carbon:[[@LINE-5]]:1: Class is incomplete within its definition.
-  // CHECK:STDERR: class A {
-  // CHECK:STDERR: ^~~~~~~~~
-  // CHECK:STDERR:
-  // CHECK:STDERR: fail_todo_global_self_access.carbon:[[@LINE+6]]:16: ERROR: Cannot access private member `internal` of type `A`.
-  // CHECK:STDERR:   let y: i32 = Self.internal;
-  // CHECK:STDERR:                ^~~~~~~~~~~~~
-  // CHECK:STDERR: fail_todo_global_self_access.carbon:[[@LINE-11]]:15: The private member `internal` is defined here.
-  // CHECK:STDERR:   private let internal: i32 = 10;
-  // CHECK:STDERR:               ^~~~~~~~
-  let y: i32 = Self.internal;
+  private fn F() {}
+  private fn G() { Self.F(); }
 }
 
 // CHECK:STDOUT: --- fail_private_field_access.carbon
@@ -554,9 +540,9 @@ class A {
 // CHECK:STDOUT:   %int.make_type_32.loc16: init type = call constants.%Int32() [template = i32]
 // CHECK:STDOUT:   %.loc16_8.1: type = value_of_initializer %int.make_type_32.loc16 [template = i32]
 // CHECK:STDOUT:   %.loc16_8.2: type = converted %int.make_type_32.loc16, %.loc16_8.1 [template = i32]
-// CHECK:STDOUT:   %int.make_type_32.loc24: init type = call constants.%Int32() [template = i32]
-// CHECK:STDOUT:   %.loc24_8.1: type = value_of_initializer %int.make_type_32.loc24 [template = i32]
-// CHECK:STDOUT:   %.loc24_8.2: type = converted %int.make_type_32.loc24, %.loc24_8.1 [template = i32]
+// CHECK:STDOUT:   %int.make_type_32.loc23: init type = call constants.%Int32() [template = i32]
+// CHECK:STDOUT:   %.loc23_8.1: type = value_of_initializer %int.make_type_32.loc23 [template = i32]
+// CHECK:STDOUT:   %.loc23_8.2: type = converted %int.make_type_32.loc23, %.loc23_8.1 [template = i32]
 // CHECK:STDOUT: }
 // CHECK:STDOUT:
 // CHECK:STDOUT: class @A {
@@ -584,26 +570,27 @@ class A {
 // CHECK:STDOUT:   %A.ref.loc16: type = name_ref A, file.%A.decl [template = constants.%A]
 // CHECK:STDOUT:   %x.ref: <error> = name_ref x, <error> [template = <error>]
 // CHECK:STDOUT:   %x: i32 = bind_name x, <error>
-// CHECK:STDOUT:   %A.ref.loc24: type = name_ref A, file.%A.decl [template = constants.%A]
+// CHECK:STDOUT:   %A.ref.loc23: type = name_ref A, file.%A.decl [template = constants.%A]
 // CHECK:STDOUT:   %y.ref: <error> = name_ref y, <error> [template = <error>]
 // CHECK:STDOUT:   %y: i32 = bind_name y, <error>
 // CHECK:STDOUT:   return
 // CHECK:STDOUT: }
 // CHECK:STDOUT:
-// CHECK:STDOUT: --- fail_todo_global_self_access.carbon
+// CHECK:STDOUT: --- self_access.carbon
 // CHECK:STDOUT:
 // CHECK:STDOUT: constants {
 // CHECK:STDOUT:   %A: type = class_type @A [template]
-// CHECK:STDOUT:   %Int32.type: type = fn_type @Int32 [template]
+// CHECK:STDOUT:   %F.type: type = fn_type @F [template]
 // CHECK:STDOUT:   %.1: type = tuple_type () [template]
-// CHECK:STDOUT:   %Int32: %Int32.type = struct_value () [template]
-// CHECK:STDOUT:   %.2: i32 = int_literal 10 [template]
-// CHECK:STDOUT:   %.3: type = struct_type {} [template]
+// CHECK:STDOUT:   %F: %F.type = struct_value () [template]
+// CHECK:STDOUT:   %G.type: type = fn_type @G [template]
+// CHECK:STDOUT:   %G: %G.type = struct_value () [template]
+// CHECK:STDOUT:   %.2: type = struct_type {} [template]
+// CHECK:STDOUT:   %.3: type = ptr_type %.2 [template]
 // CHECK:STDOUT: }
 // CHECK:STDOUT:
 // CHECK:STDOUT: imports {
 // CHECK:STDOUT:   %Core: <namespace> = namespace file.%Core.import, [template] {
-// CHECK:STDOUT:     .Int32 = %import_ref
 // CHECK:STDOUT:     import Core//prelude
 // CHECK:STDOUT:     import Core//prelude/operators
 // CHECK:STDOUT:     import Core//prelude/types
@@ -613,7 +600,6 @@ class A {
 // CHECK:STDOUT:     import Core//prelude/operators/comparison
 // CHECK:STDOUT:     import Core//prelude/types/bool
 // CHECK:STDOUT:   }
-// CHECK:STDOUT:   %import_ref: %Int32.type = import_ref Core//prelude/types, inst+4, loaded [template = constants.%Int32]
 // CHECK:STDOUT: }
 // CHECK:STDOUT:
 // CHECK:STDOUT: file {
@@ -626,23 +612,25 @@ class A {
 // CHECK:STDOUT: }
 // CHECK:STDOUT:
 // CHECK:STDOUT: class @A {
-// CHECK:STDOUT:   %int.make_type_32.loc5: init type = call constants.%Int32() [template = i32]
-// CHECK:STDOUT:   %.loc5_25.1: type = value_of_initializer %int.make_type_32.loc5 [template = i32]
-// CHECK:STDOUT:   %.loc5_25.2: type = converted %int.make_type_32.loc5, %.loc5_25.1 [template = i32]
-// CHECK:STDOUT:   %.loc5_31: i32 = int_literal 10 [template = constants.%.2]
-// CHECK:STDOUT:   %internal: i32 = bind_name internal, %.loc5_31
-// CHECK:STDOUT:   %int.make_type_32.loc19: init type = call constants.%Int32() [template = i32]
-// CHECK:STDOUT:   %.loc19_10.1: type = value_of_initializer %int.make_type_32.loc19 [template = i32]
-// CHECK:STDOUT:   %.loc19_10.2: type = converted %int.make_type_32.loc19, %.loc19_10.1 [template = i32]
-// CHECK:STDOUT:   %Self.ref: type = name_ref Self, constants.%A [template = constants.%A]
-// CHECK:STDOUT:   %internal.ref: <error> = name_ref internal, <error> [template = <error>]
-// CHECK:STDOUT:   %y: i32 = bind_name y, <error>
+// CHECK:STDOUT:   %F.decl: %F.type = fn_decl @F [template = constants.%F] {}
+// CHECK:STDOUT:   %G.decl: %G.type = fn_decl @G [template = constants.%G] {}
 // CHECK:STDOUT:
 // CHECK:STDOUT: !members:
 // CHECK:STDOUT:   .Self = constants.%A
-// CHECK:STDOUT:   .internal [private] = %internal
-// CHECK:STDOUT:   .y = %y
+// CHECK:STDOUT:   .F [private] = %F.decl
+// CHECK:STDOUT:   .G [private] = %G.decl
 // CHECK:STDOUT: }
 // CHECK:STDOUT:
-// CHECK:STDOUT: fn @Int32() -> type = "int.make_type_32";
+// CHECK:STDOUT: fn @F() {
+// CHECK:STDOUT: !entry:
+// CHECK:STDOUT:   return
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
+// CHECK:STDOUT: fn @G() {
+// CHECK:STDOUT: !entry:
+// CHECK:STDOUT:   %Self.ref: type = name_ref Self, constants.%A [template = constants.%A]
+// CHECK:STDOUT:   %F.ref: %F.type = name_ref F, @A.%F.decl [template = constants.%F]
+// CHECK:STDOUT:   %F.call: init %.1 = call %F.ref()
+// CHECK:STDOUT:   return
+// CHECK:STDOUT: }
 // CHECK:STDOUT:

+ 23 - 0
toolchain/sem_ir/name_scope.h

@@ -18,6 +18,29 @@ enum class AccessKind : int8_t {
   Private,
 };
 
+}  // namespace Carbon::SemIR
+
+template <>
+struct llvm::format_provider<Carbon::SemIR::AccessKind> {
+  using AccessKind = Carbon::SemIR::AccessKind;
+  static void format(const AccessKind& loc, raw_ostream& out,
+                     StringRef /*style*/) {
+    switch (loc) {
+      case AccessKind::Private:
+        out << "private";
+        break;
+      case AccessKind::Protected:
+        out << "protected";
+        break;
+      case AccessKind::Public:
+        out << "public";
+        break;
+    }
+  }
+};
+
+namespace Carbon::SemIR {
+
 struct NameScope : Printable<NameScope> {
   struct Entry {
     NameId name_id;