소스 검색

C++ interop: Support C++20 operator and overload resolution for expression rewriting (#6171)

This allows to find the spaceship `operator<=>` when a comparison
operator is not available, and `operator==` when `operator!=` is not
available.
Support added to both lookup and overload resolution, by adding
`OperatorRewriteInfo` and propagating it in `CppOverloadSet`.
In case overload resolution chooses to use an operator which requires
rewriting, we emit a `TODO` since rewriting is not yet supported.

Part of #6170.
Boaz Brickner 6 달 전
부모
커밋
7c13bddc92

+ 8 - 6
toolchain/check/cpp/import.cpp

@@ -2060,16 +2060,17 @@ static auto LookupBuiltinTypes(Context& context, SemIR::LocId loc_id,
   return inst_id;
 }
 
-auto ImportCppOverloadSet(Context& context, SemIR::NameScopeId scope_id,
-                          SemIR::NameId name_id,
-                          clang::CXXRecordDecl* naming_class,
-                          clang::UnresolvedSet<4>&& overload_set)
+auto ImportCppOverloadSet(
+    Context& context, SemIR::NameScopeId scope_id, SemIR::NameId name_id,
+    clang::CXXRecordDecl* naming_class, clang::UnresolvedSet<4>&& overload_set,
+    clang::OverloadCandidateSet::OperatorRewriteInfo operator_rewrite_info)
     -> SemIR::InstId {
   SemIR::CppOverloadSetId overload_set_id = context.cpp_overload_sets().Add(
       SemIR::CppOverloadSet{.name_id = name_id,
                             .parent_scope_id = scope_id,
                             .naming_class = naming_class,
-                            .candidate_functions = std::move(overload_set)});
+                            .candidate_functions = std::move(overload_set),
+                            .operator_rewrite_info = operator_rewrite_info});
 
   auto overload_set_inst_id =
       // TODO: Add a location.
@@ -2110,7 +2111,8 @@ static auto ImportOverloadSetIntoScope(Context& context,
     -> SemIR::ScopeLookupResult {
   SemIR::AccessKind access_kind = GetOverloadSetAccess(overload_set);
   SemIR::InstId inst_id = ImportCppOverloadSet(
-      context, scope_id, name_id, naming_class, std::move(overload_set));
+      context, scope_id, name_id, naming_class, std::move(overload_set),
+      /*operator_rewrite_info=*/{});
   AddNameToScope(context, scope_id, name_id, access_kind, inst_id);
   return SemIR::ScopeLookupResult::MakeWrappedLookupResult(inst_id,
                                                            access_kind);

+ 4 - 4
toolchain/check/cpp/import.h

@@ -32,10 +32,10 @@ auto ImportCppFunctionDecl(Context& context, SemIR::LocId loc_id,
     -> SemIR::InstId;
 
 // Imports an overloaded function set from Clang to Carbon.
-auto ImportCppOverloadSet(Context& context, SemIR::NameScopeId scope_id,
-                          SemIR::NameId name_id,
-                          clang::CXXRecordDecl* naming_class,
-                          clang::UnresolvedSet<4>&& overload_set)
+auto ImportCppOverloadSet(
+    Context& context, SemIR::NameScopeId scope_id, SemIR::NameId name_id,
+    clang::CXXRecordDecl* naming_class, clang::UnresolvedSet<4>&& overload_set,
+    clang::OverloadCandidateSet::OperatorRewriteInfo operator_rewrite_info)
     -> SemIR::InstId;
 
 // Looks up the given name in the Clang AST generated when importing C++ code

+ 7 - 5
toolchain/check/cpp/operators.cpp

@@ -190,10 +190,12 @@ auto LookupCppOperator(Context& context, SemIR::LocId loc_id, Operator op,
     return SemIR::ErrorInst::InstId;
   }
 
+  clang::SourceLocation loc = GetCppLocation(context, loc_id);
+  clang::OverloadCandidateSet::OperatorRewriteInfo operator_rewrite_info(
+      *op_kind, loc, /*AllowRewritten=*/true);
   clang::UnresolvedSet<4> functions;
   clang::OverloadCandidateSet candidate_set(
-      GetCppLocation(context, loc_id),
-      clang::OverloadCandidateSet::CSK_Operator);
+      loc, clang::OverloadCandidateSet::CSK_Operator, operator_rewrite_info);
   // This works for both unary and binary operators.
   context.clang_sema().LookupOverloadedBinOp(candidate_set, *op_kind, functions,
                                              *arg_exprs);
@@ -205,9 +207,9 @@ auto LookupCppOperator(Context& context, SemIR::LocId loc_id, Operator op,
     functions.addDecl(it.Function, it.FoundDecl.getAccess());
   }
 
-  return ImportCppOverloadSet(context, SemIR::NameScopeId::None,
-                              SemIR::NameId::CppOperator,
-                              /*naming_class=*/nullptr, std::move(functions));
+  return ImportCppOverloadSet(
+      context, SemIR::NameScopeId::None, SemIR::NameId::CppOperator,
+      /*naming_class=*/nullptr, std::move(functions), operator_rewrite_info);
 }
 
 auto IsCppOperatorMethodDecl(clang::Decl* decl) -> bool {

+ 14 - 1
toolchain/check/cpp/overload_resolution.cpp

@@ -149,7 +149,11 @@ auto PerformCppOverloadResolution(Context& context, SemIR::LocId loc_id,
 
   // Add candidate functions from the name lookup.
   clang::OverloadCandidateSet candidate_set(
-      loc, clang::OverloadCandidateSet::CandidateSetKind::CSK_Normal);
+      loc,
+      overload_set.operator_rewrite_info.OriginalOperator
+          ? clang::OverloadCandidateSet::CandidateSetKind::CSK_Operator
+          : clang::OverloadCandidateSet::CandidateSetKind::CSK_Normal,
+      overload_set.operator_rewrite_info);
 
   clang::Sema& sema = context.clang_sema();
 
@@ -165,6 +169,15 @@ auto PerformCppOverloadResolution(Context& context, SemIR::LocId loc_id,
     case clang::OverloadingResult::OR_Success: {
       // TODO: Handle the cases when Function is null.
       CARBON_CHECK(best_viable_fn->Function);
+      if (best_viable_fn->RewriteKind) {
+        context.TODO(
+            loc_id,
+            llvm::formatv("Rewriting operator{0} using {1} is not supported",
+                          clang::getOperatorSpelling(
+                              candidate_set.getRewriteInfo().OriginalOperator),
+                          best_viable_fn->Function->getNameAsString()));
+        return SemIR::ErrorInst::InstId;
+      }
       sema.MarkFunctionReferenced(loc, best_viable_fn->Function);
       SemIR::InstId result_id = ImportCppFunctionDecl(
           context, loc_id, best_viable_fn->Function,

+ 380 - 1
toolchain/check/testdata/interop/cpp/function/operators.carbon

@@ -3,7 +3,7 @@
 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 // INCLUDE-FILE: toolchain/testing/testdata/min_prelude/full.carbon
-// EXTRA-ARGS: --target=x86_64-linux-gnu
+// EXTRA-ARGS: --target=x86_64-linux-gnu --clang-arg=-std=c++20
 //
 // AUTOUPDATE
 // TIP: To test this file alone, run:
@@ -252,6 +252,81 @@ fn F() {
   let c3: Cpp.C = 6 + c1;
 }
 
+// ============================================================================
+// Rewrite using the spaceship operator
+// ============================================================================
+
+// --- fail_todo_rewrite_spaceship.carbon
+
+library "[[@TEST_NAME]]";
+
+import Cpp inline '''
+class C {};
+
+namespace std { class strong_ordering {}; }
+auto operator<=>(C lhs, C rhs) -> std::strong_ordering;
+auto operator>(C lhs, C rhs) -> bool;
+auto operator<=(C lhs, C rhs) -> bool;
+''';
+
+fn F() {
+  //@dump-sem-ir-begin
+  var c1: Cpp.C = Cpp.C.C();
+  var c2: Cpp.C = Cpp.C.C();
+
+  // No rewrite.
+  let greater_than: bool = c1 > c2;
+
+  // Rewrite.
+  // CHECK:STDERR: fail_todo_rewrite_spaceship.carbon:[[@LINE+4]]:25: error: semantics TODO: `Rewriting operator< using operator<=> is not supported` [SemanticsTodo]
+  // CHECK:STDERR:   let less_than: bool = c1 < c2;
+  // CHECK:STDERR:                         ^~~~~~~
+  // CHECK:STDERR:
+  let less_than: bool = c1 < c2;
+
+  // Rewrite.
+  // CHECK:STDERR: fail_todo_rewrite_spaceship.carbon:[[@LINE+4]]:37: error: semantics TODO: `Rewriting operator>= using operator<=> is not supported` [SemanticsTodo]
+  // CHECK:STDERR:   let greater_than_or_equal: bool = c1 >= c2;
+  // CHECK:STDERR:                                     ^~~~~~~~
+  // CHECK:STDERR:
+  let greater_than_or_equal: bool = c1 >= c2;
+
+  // No rewrite.
+  let less_than_or_equal: bool = c1 <= c2;
+  //@dump-sem-ir-end
+}
+
+// ============================================================================
+// Rewrite using the equal operator
+// ============================================================================
+
+// --- fail_todo_rewrite_equal.carbon
+
+library "[[@TEST_NAME]]";
+
+import Cpp inline '''
+class C {};
+
+auto operator==(C lhs, C rhs) -> bool;
+''';
+
+fn F() {
+  //@dump-sem-ir-begin
+  var c1: Cpp.C = Cpp.C.C();
+  var c2: Cpp.C = Cpp.C.C();
+
+  // No rewrite.
+  let equal: bool = c1 == c2;
+
+  // Rewrite.
+  // CHECK:STDERR: fail_todo_rewrite_equal.carbon:[[@LINE+4]]:25: error: semantics TODO: `Rewriting operator!= using operator== is not supported` [SemanticsTodo]
+  // CHECK:STDERR:   let not_equal: bool = c1 != c2;
+  // CHECK:STDERR:                         ^~~~~~~~
+  // CHECK:STDERR:
+  let not_equal: bool = c1 != c2;
+  //@dump-sem-ir-end
+}
+
 // ============================================================================
 // One of two operands conversion
 // ============================================================================
@@ -663,6 +738,7 @@ fn F() {
   let c3: Cpp.N.O.C = c1 + c2;
   //@dump-sem-ir-end
 }
+
 // ============================================================================
 // Member operator
 // ============================================================================
@@ -2118,6 +2194,309 @@ fn F() {
 // CHECK:STDOUT:   <elided>
 // CHECK:STDOUT: }
 // CHECK:STDOUT:
+// CHECK:STDOUT: --- fail_todo_rewrite_spaceship.carbon
+// CHECK:STDOUT:
+// CHECK:STDOUT: constants {
+// CHECK:STDOUT:   %empty_tuple.type: type = tuple_type () [concrete]
+// CHECK:STDOUT:   %C: type = class_type @C [concrete]
+// CHECK:STDOUT:   %pattern_type.217: type = pattern_type %C [concrete]
+// CHECK:STDOUT:   %C.C.cpp_overload_set.type: type = cpp_overload_set_type @C.C.cpp_overload_set [concrete]
+// CHECK:STDOUT:   %C.C.cpp_overload_set.value: %C.C.cpp_overload_set.type = cpp_overload_set_value @C.C.cpp_overload_set [concrete]
+// CHECK:STDOUT:   %ptr.d9e: type = ptr_type %C [concrete]
+// CHECK:STDOUT:   %C__carbon_thunk.type: type = fn_type @C__carbon_thunk [concrete]
+// CHECK:STDOUT:   %C__carbon_thunk: %C__carbon_thunk.type = struct_value () [concrete]
+// CHECK:STDOUT:   %Bool.type: type = fn_type @Bool [concrete]
+// CHECK:STDOUT:   %Bool: %Bool.type = struct_value () [concrete]
+// CHECK:STDOUT:   %pattern_type.831: type = pattern_type bool [concrete]
+// CHECK:STDOUT:   %ptr.bb2: type = ptr_type bool [concrete]
+// CHECK:STDOUT:   %operator>__carbon_thunk.type: type = fn_type @operator>__carbon_thunk [concrete]
+// CHECK:STDOUT:   %operator>__carbon_thunk: %operator>__carbon_thunk.type = struct_value () [concrete]
+// CHECK:STDOUT:   %operator<=__carbon_thunk.type: type = fn_type @operator<=__carbon_thunk [concrete]
+// CHECK:STDOUT:   %operator<=__carbon_thunk: %operator<=__carbon_thunk.type = struct_value () [concrete]
+// CHECK:STDOUT:   %type_where: type = facet_type <type where .Self impls <CanDestroy>> [concrete]
+// CHECK:STDOUT:   %facet_value: %type_where = facet_value %C, () [concrete]
+// CHECK:STDOUT:   %DestroyT.binding.as_type.as.Destroy.impl.Op.type.b92: type = fn_type @DestroyT.binding.as_type.as.Destroy.impl.Op, @DestroyT.binding.as_type.as.Destroy.impl(%facet_value) [concrete]
+// CHECK:STDOUT:   %DestroyT.binding.as_type.as.Destroy.impl.Op.841: %DestroyT.binding.as_type.as.Destroy.impl.Op.type.b92 = struct_value () [concrete]
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
+// CHECK:STDOUT: imports {
+// CHECK:STDOUT:   %Cpp: <namespace> = namespace file.%Cpp.import_cpp, [concrete] {
+// CHECK:STDOUT:     .C = %C.decl
+// CHECK:STDOUT:     import Cpp//...
+// CHECK:STDOUT:   }
+// CHECK:STDOUT:   %C.decl: type = class_decl @C [concrete = constants.%C] {} {}
+// CHECK:STDOUT:   %C.C.cpp_overload_set.value: %C.C.cpp_overload_set.type = cpp_overload_set_value @C.C.cpp_overload_set [concrete = constants.%C.C.cpp_overload_set.value]
+// CHECK:STDOUT:   %C__carbon_thunk.decl: %C__carbon_thunk.type = fn_decl @C__carbon_thunk [concrete = constants.%C__carbon_thunk] {
+// CHECK:STDOUT:     <elided>
+// CHECK:STDOUT:   } {
+// CHECK:STDOUT:     <elided>
+// CHECK:STDOUT:   }
+// CHECK:STDOUT:   %operator>__carbon_thunk.decl: %operator>__carbon_thunk.type = fn_decl @operator>__carbon_thunk [concrete = constants.%operator>__carbon_thunk] {
+// CHECK:STDOUT:     <elided>
+// CHECK:STDOUT:   } {
+// CHECK:STDOUT:     <elided>
+// CHECK:STDOUT:   }
+// CHECK:STDOUT:   %operator<=__carbon_thunk.decl: %operator<=__carbon_thunk.type = fn_decl @operator<=__carbon_thunk [concrete = constants.%operator<=__carbon_thunk] {
+// CHECK:STDOUT:     <elided>
+// CHECK:STDOUT:   } {
+// CHECK:STDOUT:     <elided>
+// CHECK:STDOUT:   }
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
+// CHECK:STDOUT: fn @F() {
+// CHECK:STDOUT: !entry:
+// CHECK:STDOUT:   name_binding_decl {
+// CHECK:STDOUT:     %c1.patt: %pattern_type.217 = binding_pattern c1 [concrete]
+// CHECK:STDOUT:     %c1.var_patt: %pattern_type.217 = var_pattern %c1.patt [concrete]
+// CHECK:STDOUT:   }
+// CHECK:STDOUT:   %c1.var: ref %C = var %c1.var_patt
+// CHECK:STDOUT:   %Cpp.ref.loc15_19: <namespace> = name_ref Cpp, imports.%Cpp [concrete = imports.%Cpp]
+// CHECK:STDOUT:   %C.ref.loc15_22: type = name_ref C, imports.%C.decl [concrete = constants.%C]
+// CHECK:STDOUT:   %C.ref.loc15_24: %C.C.cpp_overload_set.type = name_ref C, imports.%C.C.cpp_overload_set.value [concrete = constants.%C.C.cpp_overload_set.value]
+// CHECK:STDOUT:   %.loc15_3.1: ref %C = splice_block %c1.var {}
+// CHECK:STDOUT:   %addr.loc15_27: %ptr.d9e = addr_of %.loc15_3.1
+// CHECK:STDOUT:   %C__carbon_thunk.call.loc15: init %empty_tuple.type = call imports.%C__carbon_thunk.decl(%addr.loc15_27)
+// CHECK:STDOUT:   %.loc15_27: init %C = in_place_init %C__carbon_thunk.call.loc15, %.loc15_3.1
+// CHECK:STDOUT:   assign %c1.var, %.loc15_27
+// CHECK:STDOUT:   %.loc15_14: type = splice_block %C.ref.loc15_14 [concrete = constants.%C] {
+// CHECK:STDOUT:     %Cpp.ref.loc15_11: <namespace> = name_ref Cpp, imports.%Cpp [concrete = imports.%Cpp]
+// CHECK:STDOUT:     %C.ref.loc15_14: type = name_ref C, imports.%C.decl [concrete = constants.%C]
+// CHECK:STDOUT:   }
+// CHECK:STDOUT:   %c1: ref %C = bind_name c1, %c1.var
+// CHECK:STDOUT:   name_binding_decl {
+// CHECK:STDOUT:     %c2.patt: %pattern_type.217 = binding_pattern c2 [concrete]
+// CHECK:STDOUT:     %c2.var_patt: %pattern_type.217 = var_pattern %c2.patt [concrete]
+// CHECK:STDOUT:   }
+// CHECK:STDOUT:   %c2.var: ref %C = var %c2.var_patt
+// CHECK:STDOUT:   %Cpp.ref.loc16_19: <namespace> = name_ref Cpp, imports.%Cpp [concrete = imports.%Cpp]
+// CHECK:STDOUT:   %C.ref.loc16_22: type = name_ref C, imports.%C.decl [concrete = constants.%C]
+// CHECK:STDOUT:   %C.ref.loc16_24: %C.C.cpp_overload_set.type = name_ref C, imports.%C.C.cpp_overload_set.value [concrete = constants.%C.C.cpp_overload_set.value]
+// CHECK:STDOUT:   %.loc16_3.1: ref %C = splice_block %c2.var {}
+// CHECK:STDOUT:   %addr.loc16_27: %ptr.d9e = addr_of %.loc16_3.1
+// CHECK:STDOUT:   %C__carbon_thunk.call.loc16: init %empty_tuple.type = call imports.%C__carbon_thunk.decl(%addr.loc16_27)
+// CHECK:STDOUT:   %.loc16_27: init %C = in_place_init %C__carbon_thunk.call.loc16, %.loc16_3.1
+// CHECK:STDOUT:   assign %c2.var, %.loc16_27
+// CHECK:STDOUT:   %.loc16_14: type = splice_block %C.ref.loc16_14 [concrete = constants.%C] {
+// CHECK:STDOUT:     %Cpp.ref.loc16_11: <namespace> = name_ref Cpp, imports.%Cpp [concrete = imports.%Cpp]
+// CHECK:STDOUT:     %C.ref.loc16_14: type = name_ref C, imports.%C.decl [concrete = constants.%C]
+// CHECK:STDOUT:   }
+// CHECK:STDOUT:   %c2: ref %C = bind_name c2, %c2.var
+// CHECK:STDOUT:   name_binding_decl {
+// CHECK:STDOUT:     %greater_than.patt: %pattern_type.831 = binding_pattern greater_than [concrete]
+// CHECK:STDOUT:   }
+// CHECK:STDOUT:   %c1.ref.loc19: ref %C = name_ref c1, %c1
+// CHECK:STDOUT:   %c2.ref.loc19: ref %C = name_ref c2, %c2
+// CHECK:STDOUT:   %.loc19_28.1: %C = bind_value %c1.ref.loc19
+// CHECK:STDOUT:   %.loc19_33.1: %C = bind_value %c2.ref.loc19
+// CHECK:STDOUT:   %.loc19_28.2: ref %C = value_as_ref %.loc19_28.1
+// CHECK:STDOUT:   %addr.loc19_31.1: %ptr.d9e = addr_of %.loc19_28.2
+// CHECK:STDOUT:   %.loc19_33.2: ref %C = value_as_ref %.loc19_33.1
+// CHECK:STDOUT:   %addr.loc19_31.2: %ptr.d9e = addr_of %.loc19_33.2
+// CHECK:STDOUT:   %.loc19_31.1: ref bool = temporary_storage
+// CHECK:STDOUT:   %addr.loc19_31.3: %ptr.bb2 = addr_of %.loc19_31.1
+// CHECK:STDOUT:   %operator>__carbon_thunk.call: init %empty_tuple.type = call imports.%operator>__carbon_thunk.decl(%addr.loc19_31.1, %addr.loc19_31.2, %addr.loc19_31.3)
+// CHECK:STDOUT:   %.loc19_31.2: init bool = in_place_init %operator>__carbon_thunk.call, %.loc19_31.1
+// CHECK:STDOUT:   %.loc19_21.1: type = splice_block %.loc19_21.3 [concrete = bool] {
+// CHECK:STDOUT:     %Bool.call.loc19: init type = call constants.%Bool() [concrete = bool]
+// CHECK:STDOUT:     %.loc19_21.2: type = value_of_initializer %Bool.call.loc19 [concrete = bool]
+// CHECK:STDOUT:     %.loc19_21.3: type = converted %Bool.call.loc19, %.loc19_21.2 [concrete = bool]
+// CHECK:STDOUT:   }
+// CHECK:STDOUT:   %.loc19_31.3: bool = value_of_initializer %.loc19_31.2
+// CHECK:STDOUT:   %.loc19_31.4: bool = converted %.loc19_31.2, %.loc19_31.3
+// CHECK:STDOUT:   %greater_than: bool = bind_name greater_than, %.loc19_31.4
+// CHECK:STDOUT:   name_binding_decl {
+// CHECK:STDOUT:     %less_than.patt: %pattern_type.831 = binding_pattern less_than [concrete]
+// CHECK:STDOUT:   }
+// CHECK:STDOUT:   %c1.ref.loc26: ref %C = name_ref c1, %c1
+// CHECK:STDOUT:   %c2.ref.loc26: ref %C = name_ref c2, %c2
+// CHECK:STDOUT:   %.loc26_18.1: type = splice_block %.loc26_18.3 [concrete = bool] {
+// CHECK:STDOUT:     %Bool.call.loc26: init type = call constants.%Bool() [concrete = bool]
+// CHECK:STDOUT:     %.loc26_18.2: type = value_of_initializer %Bool.call.loc26 [concrete = bool]
+// CHECK:STDOUT:     %.loc26_18.3: type = converted %Bool.call.loc26, %.loc26_18.2 [concrete = bool]
+// CHECK:STDOUT:   }
+// CHECK:STDOUT:   %less_than: bool = bind_name less_than, <error> [concrete = <error>]
+// CHECK:STDOUT:   name_binding_decl {
+// CHECK:STDOUT:     %greater_than_or_equal.patt: %pattern_type.831 = binding_pattern greater_than_or_equal [concrete]
+// CHECK:STDOUT:   }
+// CHECK:STDOUT:   %c1.ref.loc33: ref %C = name_ref c1, %c1
+// CHECK:STDOUT:   %c2.ref.loc33: ref %C = name_ref c2, %c2
+// CHECK:STDOUT:   %.loc33_30.1: type = splice_block %.loc33_30.3 [concrete = bool] {
+// CHECK:STDOUT:     %Bool.call.loc33: init type = call constants.%Bool() [concrete = bool]
+// CHECK:STDOUT:     %.loc33_30.2: type = value_of_initializer %Bool.call.loc33 [concrete = bool]
+// CHECK:STDOUT:     %.loc33_30.3: type = converted %Bool.call.loc33, %.loc33_30.2 [concrete = bool]
+// CHECK:STDOUT:   }
+// CHECK:STDOUT:   %greater_than_or_equal: bool = bind_name greater_than_or_equal, <error> [concrete = <error>]
+// CHECK:STDOUT:   name_binding_decl {
+// CHECK:STDOUT:     %less_than_or_equal.patt: %pattern_type.831 = binding_pattern less_than_or_equal [concrete]
+// CHECK:STDOUT:   }
+// CHECK:STDOUT:   %c1.ref.loc36: ref %C = name_ref c1, %c1
+// CHECK:STDOUT:   %c2.ref.loc36: ref %C = name_ref c2, %c2
+// CHECK:STDOUT:   %.loc36_34.1: %C = bind_value %c1.ref.loc36
+// CHECK:STDOUT:   %.loc36_40.1: %C = bind_value %c2.ref.loc36
+// CHECK:STDOUT:   %.loc36_34.2: ref %C = value_as_ref %.loc36_34.1
+// CHECK:STDOUT:   %addr.loc36_37.1: %ptr.d9e = addr_of %.loc36_34.2
+// CHECK:STDOUT:   %.loc36_40.2: ref %C = value_as_ref %.loc36_40.1
+// CHECK:STDOUT:   %addr.loc36_37.2: %ptr.d9e = addr_of %.loc36_40.2
+// CHECK:STDOUT:   %.loc36_37.1: ref bool = temporary_storage
+// CHECK:STDOUT:   %addr.loc36_37.3: %ptr.bb2 = addr_of %.loc36_37.1
+// CHECK:STDOUT:   %operator<=__carbon_thunk.call: init %empty_tuple.type = call imports.%operator<=__carbon_thunk.decl(%addr.loc36_37.1, %addr.loc36_37.2, %addr.loc36_37.3)
+// CHECK:STDOUT:   %.loc36_37.2: init bool = in_place_init %operator<=__carbon_thunk.call, %.loc36_37.1
+// CHECK:STDOUT:   %.loc36_27.1: type = splice_block %.loc36_27.3 [concrete = bool] {
+// CHECK:STDOUT:     %Bool.call.loc36: init type = call constants.%Bool() [concrete = bool]
+// CHECK:STDOUT:     %.loc36_27.2: type = value_of_initializer %Bool.call.loc36 [concrete = bool]
+// CHECK:STDOUT:     %.loc36_27.3: type = converted %Bool.call.loc36, %.loc36_27.2 [concrete = bool]
+// CHECK:STDOUT:   }
+// CHECK:STDOUT:   %.loc36_37.3: bool = value_of_initializer %.loc36_37.2
+// CHECK:STDOUT:   %.loc36_37.4: bool = converted %.loc36_37.2, %.loc36_37.3
+// CHECK:STDOUT:   %less_than_or_equal: bool = bind_name less_than_or_equal, %.loc36_37.4
+// CHECK:STDOUT:   %facet_value.loc16: %type_where = facet_value constants.%C, () [concrete = constants.%facet_value]
+// CHECK:STDOUT:   %.loc16_3.2: %type_where = converted constants.%C, %facet_value.loc16 [concrete = constants.%facet_value]
+// CHECK:STDOUT:   %DestroyT.binding.as_type.as.Destroy.impl.Op.bound.loc16: <bound method> = bound_method %c2.var, constants.%DestroyT.binding.as_type.as.Destroy.impl.Op.841
+// CHECK:STDOUT:   <elided>
+// CHECK:STDOUT:   %bound_method.loc16: <bound method> = bound_method %c2.var, %DestroyT.binding.as_type.as.Destroy.impl.Op.specific_fn.1
+// CHECK:STDOUT:   %addr.loc16_3: %ptr.d9e = addr_of %c2.var
+// CHECK:STDOUT:   %DestroyT.binding.as_type.as.Destroy.impl.Op.call.loc16: init %empty_tuple.type = call %bound_method.loc16(%addr.loc16_3)
+// CHECK:STDOUT:   %facet_value.loc15: %type_where = facet_value constants.%C, () [concrete = constants.%facet_value]
+// CHECK:STDOUT:   %.loc15_3.2: %type_where = converted constants.%C, %facet_value.loc15 [concrete = constants.%facet_value]
+// CHECK:STDOUT:   %DestroyT.binding.as_type.as.Destroy.impl.Op.bound.loc15: <bound method> = bound_method %c1.var, constants.%DestroyT.binding.as_type.as.Destroy.impl.Op.841
+// CHECK:STDOUT:   <elided>
+// CHECK:STDOUT:   %bound_method.loc15: <bound method> = bound_method %c1.var, %DestroyT.binding.as_type.as.Destroy.impl.Op.specific_fn.2
+// CHECK:STDOUT:   %addr.loc15_3: %ptr.d9e = addr_of %c1.var
+// CHECK:STDOUT:   %DestroyT.binding.as_type.as.Destroy.impl.Op.call.loc15: init %empty_tuple.type = call %bound_method.loc15(%addr.loc15_3)
+// CHECK:STDOUT:   <elided>
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
+// CHECK:STDOUT: --- fail_todo_rewrite_equal.carbon
+// CHECK:STDOUT:
+// CHECK:STDOUT: constants {
+// CHECK:STDOUT:   %empty_tuple.type: type = tuple_type () [concrete]
+// CHECK:STDOUT:   %C: type = class_type @C [concrete]
+// CHECK:STDOUT:   %pattern_type.217: type = pattern_type %C [concrete]
+// CHECK:STDOUT:   %C.C.cpp_overload_set.type: type = cpp_overload_set_type @C.C.cpp_overload_set [concrete]
+// CHECK:STDOUT:   %C.C.cpp_overload_set.value: %C.C.cpp_overload_set.type = cpp_overload_set_value @C.C.cpp_overload_set [concrete]
+// CHECK:STDOUT:   %ptr.d9e: type = ptr_type %C [concrete]
+// CHECK:STDOUT:   %C__carbon_thunk.type: type = fn_type @C__carbon_thunk [concrete]
+// CHECK:STDOUT:   %C__carbon_thunk: %C__carbon_thunk.type = struct_value () [concrete]
+// CHECK:STDOUT:   %Bool.type: type = fn_type @Bool [concrete]
+// CHECK:STDOUT:   %Bool: %Bool.type = struct_value () [concrete]
+// CHECK:STDOUT:   %pattern_type.831: type = pattern_type bool [concrete]
+// CHECK:STDOUT:   %ptr.bb2: type = ptr_type bool [concrete]
+// CHECK:STDOUT:   %operator==__carbon_thunk.type: type = fn_type @operator==__carbon_thunk [concrete]
+// CHECK:STDOUT:   %operator==__carbon_thunk: %operator==__carbon_thunk.type = struct_value () [concrete]
+// CHECK:STDOUT:   %type_where: type = facet_type <type where .Self impls <CanDestroy>> [concrete]
+// CHECK:STDOUT:   %facet_value: %type_where = facet_value %C, () [concrete]
+// CHECK:STDOUT:   %DestroyT.binding.as_type.as.Destroy.impl.Op.type.b92: type = fn_type @DestroyT.binding.as_type.as.Destroy.impl.Op, @DestroyT.binding.as_type.as.Destroy.impl(%facet_value) [concrete]
+// CHECK:STDOUT:   %DestroyT.binding.as_type.as.Destroy.impl.Op.841: %DestroyT.binding.as_type.as.Destroy.impl.Op.type.b92 = struct_value () [concrete]
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
+// CHECK:STDOUT: imports {
+// CHECK:STDOUT:   %Cpp: <namespace> = namespace file.%Cpp.import_cpp, [concrete] {
+// CHECK:STDOUT:     .C = %C.decl
+// CHECK:STDOUT:     import Cpp//...
+// CHECK:STDOUT:   }
+// CHECK:STDOUT:   %C.decl: type = class_decl @C [concrete = constants.%C] {} {}
+// CHECK:STDOUT:   %C.C.cpp_overload_set.value: %C.C.cpp_overload_set.type = cpp_overload_set_value @C.C.cpp_overload_set [concrete = constants.%C.C.cpp_overload_set.value]
+// CHECK:STDOUT:   %C__carbon_thunk.decl: %C__carbon_thunk.type = fn_decl @C__carbon_thunk [concrete = constants.%C__carbon_thunk] {
+// CHECK:STDOUT:     <elided>
+// CHECK:STDOUT:   } {
+// CHECK:STDOUT:     <elided>
+// CHECK:STDOUT:   }
+// CHECK:STDOUT:   %operator==__carbon_thunk.decl: %operator==__carbon_thunk.type = fn_decl @operator==__carbon_thunk [concrete = constants.%operator==__carbon_thunk] {
+// CHECK:STDOUT:     <elided>
+// CHECK:STDOUT:   } {
+// CHECK:STDOUT:     <elided>
+// CHECK:STDOUT:   }
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
+// CHECK:STDOUT: fn @F() {
+// CHECK:STDOUT: !entry:
+// CHECK:STDOUT:   name_binding_decl {
+// CHECK:STDOUT:     %c1.patt: %pattern_type.217 = binding_pattern c1 [concrete]
+// CHECK:STDOUT:     %c1.var_patt: %pattern_type.217 = var_pattern %c1.patt [concrete]
+// CHECK:STDOUT:   }
+// CHECK:STDOUT:   %c1.var: ref %C = var %c1.var_patt
+// CHECK:STDOUT:   %Cpp.ref.loc12_19: <namespace> = name_ref Cpp, imports.%Cpp [concrete = imports.%Cpp]
+// CHECK:STDOUT:   %C.ref.loc12_22: type = name_ref C, imports.%C.decl [concrete = constants.%C]
+// CHECK:STDOUT:   %C.ref.loc12_24: %C.C.cpp_overload_set.type = name_ref C, imports.%C.C.cpp_overload_set.value [concrete = constants.%C.C.cpp_overload_set.value]
+// CHECK:STDOUT:   %.loc12_3.1: ref %C = splice_block %c1.var {}
+// CHECK:STDOUT:   %addr.loc12_27: %ptr.d9e = addr_of %.loc12_3.1
+// CHECK:STDOUT:   %C__carbon_thunk.call.loc12: init %empty_tuple.type = call imports.%C__carbon_thunk.decl(%addr.loc12_27)
+// CHECK:STDOUT:   %.loc12_27: init %C = in_place_init %C__carbon_thunk.call.loc12, %.loc12_3.1
+// CHECK:STDOUT:   assign %c1.var, %.loc12_27
+// CHECK:STDOUT:   %.loc12_14: type = splice_block %C.ref.loc12_14 [concrete = constants.%C] {
+// CHECK:STDOUT:     %Cpp.ref.loc12_11: <namespace> = name_ref Cpp, imports.%Cpp [concrete = imports.%Cpp]
+// CHECK:STDOUT:     %C.ref.loc12_14: type = name_ref C, imports.%C.decl [concrete = constants.%C]
+// CHECK:STDOUT:   }
+// CHECK:STDOUT:   %c1: ref %C = bind_name c1, %c1.var
+// CHECK:STDOUT:   name_binding_decl {
+// CHECK:STDOUT:     %c2.patt: %pattern_type.217 = binding_pattern c2 [concrete]
+// CHECK:STDOUT:     %c2.var_patt: %pattern_type.217 = var_pattern %c2.patt [concrete]
+// CHECK:STDOUT:   }
+// CHECK:STDOUT:   %c2.var: ref %C = var %c2.var_patt
+// CHECK:STDOUT:   %Cpp.ref.loc13_19: <namespace> = name_ref Cpp, imports.%Cpp [concrete = imports.%Cpp]
+// CHECK:STDOUT:   %C.ref.loc13_22: type = name_ref C, imports.%C.decl [concrete = constants.%C]
+// CHECK:STDOUT:   %C.ref.loc13_24: %C.C.cpp_overload_set.type = name_ref C, imports.%C.C.cpp_overload_set.value [concrete = constants.%C.C.cpp_overload_set.value]
+// CHECK:STDOUT:   %.loc13_3.1: ref %C = splice_block %c2.var {}
+// CHECK:STDOUT:   %addr.loc13_27: %ptr.d9e = addr_of %.loc13_3.1
+// CHECK:STDOUT:   %C__carbon_thunk.call.loc13: init %empty_tuple.type = call imports.%C__carbon_thunk.decl(%addr.loc13_27)
+// CHECK:STDOUT:   %.loc13_27: init %C = in_place_init %C__carbon_thunk.call.loc13, %.loc13_3.1
+// CHECK:STDOUT:   assign %c2.var, %.loc13_27
+// CHECK:STDOUT:   %.loc13_14: type = splice_block %C.ref.loc13_14 [concrete = constants.%C] {
+// CHECK:STDOUT:     %Cpp.ref.loc13_11: <namespace> = name_ref Cpp, imports.%Cpp [concrete = imports.%Cpp]
+// CHECK:STDOUT:     %C.ref.loc13_14: type = name_ref C, imports.%C.decl [concrete = constants.%C]
+// CHECK:STDOUT:   }
+// CHECK:STDOUT:   %c2: ref %C = bind_name c2, %c2.var
+// CHECK:STDOUT:   name_binding_decl {
+// CHECK:STDOUT:     %equal.patt: %pattern_type.831 = binding_pattern equal [concrete]
+// CHECK:STDOUT:   }
+// CHECK:STDOUT:   %c1.ref.loc16: ref %C = name_ref c1, %c1
+// CHECK:STDOUT:   %c2.ref.loc16: ref %C = name_ref c2, %c2
+// CHECK:STDOUT:   %.loc16_21.1: %C = bind_value %c1.ref.loc16
+// CHECK:STDOUT:   %.loc16_27.1: %C = bind_value %c2.ref.loc16
+// CHECK:STDOUT:   %.loc16_21.2: ref %C = value_as_ref %.loc16_21.1
+// CHECK:STDOUT:   %addr.loc16_24.1: %ptr.d9e = addr_of %.loc16_21.2
+// CHECK:STDOUT:   %.loc16_27.2: ref %C = value_as_ref %.loc16_27.1
+// CHECK:STDOUT:   %addr.loc16_24.2: %ptr.d9e = addr_of %.loc16_27.2
+// CHECK:STDOUT:   %.loc16_24.1: ref bool = temporary_storage
+// CHECK:STDOUT:   %addr.loc16_24.3: %ptr.bb2 = addr_of %.loc16_24.1
+// CHECK:STDOUT:   %operator==__carbon_thunk.call: init %empty_tuple.type = call imports.%operator==__carbon_thunk.decl(%addr.loc16_24.1, %addr.loc16_24.2, %addr.loc16_24.3)
+// CHECK:STDOUT:   %.loc16_24.2: init bool = in_place_init %operator==__carbon_thunk.call, %.loc16_24.1
+// CHECK:STDOUT:   %.loc16_14.1: type = splice_block %.loc16_14.3 [concrete = bool] {
+// CHECK:STDOUT:     %Bool.call.loc16: init type = call constants.%Bool() [concrete = bool]
+// CHECK:STDOUT:     %.loc16_14.2: type = value_of_initializer %Bool.call.loc16 [concrete = bool]
+// CHECK:STDOUT:     %.loc16_14.3: type = converted %Bool.call.loc16, %.loc16_14.2 [concrete = bool]
+// CHECK:STDOUT:   }
+// CHECK:STDOUT:   %.loc16_24.3: bool = value_of_initializer %.loc16_24.2
+// CHECK:STDOUT:   %.loc16_24.4: bool = converted %.loc16_24.2, %.loc16_24.3
+// CHECK:STDOUT:   %equal: bool = bind_name equal, %.loc16_24.4
+// CHECK:STDOUT:   name_binding_decl {
+// CHECK:STDOUT:     %not_equal.patt: %pattern_type.831 = binding_pattern not_equal [concrete]
+// CHECK:STDOUT:   }
+// CHECK:STDOUT:   %c1.ref.loc23: ref %C = name_ref c1, %c1
+// CHECK:STDOUT:   %c2.ref.loc23: ref %C = name_ref c2, %c2
+// CHECK:STDOUT:   %.loc23_18.1: type = splice_block %.loc23_18.3 [concrete = bool] {
+// CHECK:STDOUT:     %Bool.call.loc23: init type = call constants.%Bool() [concrete = bool]
+// CHECK:STDOUT:     %.loc23_18.2: type = value_of_initializer %Bool.call.loc23 [concrete = bool]
+// CHECK:STDOUT:     %.loc23_18.3: type = converted %Bool.call.loc23, %.loc23_18.2 [concrete = bool]
+// CHECK:STDOUT:   }
+// CHECK:STDOUT:   %not_equal: bool = bind_name not_equal, <error> [concrete = <error>]
+// CHECK:STDOUT:   %facet_value.loc13: %type_where = facet_value constants.%C, () [concrete = constants.%facet_value]
+// CHECK:STDOUT:   %.loc13_3.2: %type_where = converted constants.%C, %facet_value.loc13 [concrete = constants.%facet_value]
+// CHECK:STDOUT:   %DestroyT.binding.as_type.as.Destroy.impl.Op.bound.loc13: <bound method> = bound_method %c2.var, constants.%DestroyT.binding.as_type.as.Destroy.impl.Op.841
+// CHECK:STDOUT:   <elided>
+// CHECK:STDOUT:   %bound_method.loc13: <bound method> = bound_method %c2.var, %DestroyT.binding.as_type.as.Destroy.impl.Op.specific_fn.1
+// CHECK:STDOUT:   %addr.loc13_3: %ptr.d9e = addr_of %c2.var
+// CHECK:STDOUT:   %DestroyT.binding.as_type.as.Destroy.impl.Op.call.loc13: init %empty_tuple.type = call %bound_method.loc13(%addr.loc13_3)
+// CHECK:STDOUT:   %facet_value.loc12: %type_where = facet_value constants.%C, () [concrete = constants.%facet_value]
+// CHECK:STDOUT:   %.loc12_3.2: %type_where = converted constants.%C, %facet_value.loc12 [concrete = constants.%facet_value]
+// CHECK:STDOUT:   %DestroyT.binding.as_type.as.Destroy.impl.Op.bound.loc12: <bound method> = bound_method %c1.var, constants.%DestroyT.binding.as_type.as.Destroy.impl.Op.841
+// CHECK:STDOUT:   <elided>
+// CHECK:STDOUT:   %bound_method.loc12: <bound method> = bound_method %c1.var, %DestroyT.binding.as_type.as.Destroy.impl.Op.specific_fn.2
+// CHECK:STDOUT:   %addr.loc12_3: %ptr.d9e = addr_of %c1.var
+// CHECK:STDOUT:   %DestroyT.binding.as_type.as.Destroy.impl.Op.call.loc12: init %empty_tuple.type = call %bound_method.loc12(%addr.loc12_3)
+// CHECK:STDOUT:   <elided>
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
 // CHECK:STDOUT: --- import_single_namespace.carbon
 // CHECK:STDOUT:
 // CHECK:STDOUT: constants {

+ 1 - 0
toolchain/sem_ir/BUILD

@@ -139,6 +139,7 @@ cc_library(
         "//toolchain/parse:tree",
         "@llvm-project//clang:ast",
         "@llvm-project//clang:frontend",
+        "@llvm-project//clang:sema",
         "@llvm-project//llvm:Support",
     ],
 )

+ 5 - 0
toolchain/sem_ir/cpp_overload_set.h

@@ -7,6 +7,7 @@
 
 #include "clang/AST/Decl.h"
 #include "clang/AST/UnresolvedSet.h"
+#include "clang/Sema/Overload.h"
 #include "common/ostream.h"
 #include "toolchain/base/value_store.h"
 #include "toolchain/sem_ir/ids.h"
@@ -29,6 +30,10 @@ struct CppOverloadSet : public Printable<CppOverloadSet> {
   // store the candidates.
   clang::UnresolvedSet<4> candidate_functions;
 
+  /// Information about operator rewrites to consider when adding operator
+  /// functions to a candidate set.
+  clang::OverloadCandidateSet::OperatorRewriteInfo operator_rewrite_info;
+
   auto Print(llvm::raw_ostream& out) const -> void {
     out << "name: " << name_id << ", parent_scope: " << parent_scope_id;
   }