Răsfoiți Sursa

Use the canonical instructions to get SpecificIds in GetCallee (#6726)

GetCallee returns a structure with SpecificIds in it, and then those
specifics are used to later get constant values. This is fine when those
specifics are canonical, but it's problematic when they are not, because
non-canonical specifics (from a generic eval block) do not ever have any
resolved decl/defn blocks.

Formatting in particular works with non-canonical instructions when it
formats a generic eval block. We want to be able to format the block,
but those specifics are not useful for constant value mapping/lookup.
GetCallee grabs (non-canonical) instruction ids out of other
instructions. When getting a SpecificId out of an instruction, it should
map that instruction to the canonical value first. This means the
specific will be resolved and can be used for constant value mapping
later.

Fixes #6677
Dana Jansens 2 luni în urmă
părinte
comite
e991657e1d

+ 19 - 0
toolchain/check/testdata/basics/dump_prelude.carbon

@@ -0,0 +1,19 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// INCLUDE-FILE: toolchain/testing/testdata/min_prelude/full.carbon
+//
+// ARGS: compile --phase=check --dump-sem-ir %s
+//
+// NOAUTOUPDATE
+// SET-CHECK-SUBSET
+//
+// TIP: To test this file alone, run:
+// TIP:   bazel test //toolchain/testing:file_test --test_arg=--file_tests=toolchain/check/testdata/basics/dump_prelude.carbon
+// TIP: To dump output, run:
+// TIP:   bazel run //toolchain/testing:file_test -- --dump_output --file_tests=toolchain/check/testdata/basics/dump_prelude.carbon
+
+// CHECK:STDOUT: interface @Copy {
+// CHECK:STDOUT: interface @Destroy {
+// CHECK:STDOUT: fn @MakeInt(%size.param: Core.IntLiteral) -> out %return.param: type = "int.make_type_signed";

+ 6 - 3
toolchain/check/testdata/facet/fail_deduction_uses_runtime_type_conversion.carbon

@@ -86,6 +86,7 @@ fn G(holds_to: HoldsType((RuntimeConvertTo, )), from:! RuntimeConvertFrom) {
 // CHECK:STDOUT:   %G: %G.type = struct_value () [concrete]
 // CHECK:STDOUT:   %.8ab: type = fn_type_with_self_type %ImplicitAs.WithSelf.Convert.type.3a2, %ImplicitAs.facet [concrete]
 // CHECK:STDOUT:   %RuntimeConvertFrom.as.ImplicitAs.impl.Convert.bound: <bound method> = bound_method %from, %RuntimeConvertFrom.as.ImplicitAs.impl.Convert [symbolic]
+// CHECK:STDOUT:   %empty_tuple: %empty_tuple.type = tuple_value () [concrete]
 // CHECK:STDOUT:   %Destroy.type: type = facet_type <@Destroy> [concrete]
 // CHECK:STDOUT:   %DestroyOp.type: type = fn_type @DestroyOp [concrete]
 // CHECK:STDOUT:   %DestroyOp: %DestroyOp.type = struct_value () [concrete]
@@ -168,8 +169,8 @@ fn G(holds_to: HoldsType((RuntimeConvertTo, )), from:! RuntimeConvertFrom) {
 // CHECK:STDOUT:       %HoldsType.ref: %HoldsType.type = name_ref HoldsType, file.%HoldsType.decl [concrete = constants.%HoldsType.generic]
 // CHECK:STDOUT:       %RuntimeConvertTo.ref: type = name_ref RuntimeConvertTo, file.%RuntimeConvertTo.decl [concrete = constants.%RuntimeConvertTo]
 // CHECK:STDOUT:       %.loc29_45: %tuple.type = tuple_literal (%RuntimeConvertTo.ref) [concrete = constants.%tuple.b95]
-// CHECK:STDOUT:       %tuple: %tuple.type = tuple_value (%RuntimeConvertTo.ref) [concrete = constants.%tuple.b95]
-// CHECK:STDOUT:       %.loc29_46.2: %tuple.type = converted %.loc29_45, %tuple [concrete = constants.%tuple.b95]
+// CHECK:STDOUT:       %tuple.loc29: %tuple.type = tuple_value (%RuntimeConvertTo.ref) [concrete = constants.%tuple.b95]
+// CHECK:STDOUT:       %.loc29_46.2: %tuple.type = converted %.loc29_45, %tuple.loc29 [concrete = constants.%tuple.b95]
 // CHECK:STDOUT:       %HoldsType: type = class_type @HoldsType, @HoldsType(constants.%tuple.b95) [concrete = constants.%HoldsType.0ca]
 // CHECK:STDOUT:     }
 // CHECK:STDOUT:     %holds_to: %HoldsType.0ca = value_binding holds_to, %holds_to.param
@@ -279,7 +280,9 @@ fn G(holds_to: HoldsType((RuntimeConvertTo, )), from:! RuntimeConvertFrom) {
 // CHECK:STDOUT:     %.loc40_19.3: ref %RuntimeConvertTo = temporary %.loc40_19.1, %.loc40_19.2
 // CHECK:STDOUT:     %.loc40_19.4: %RuntimeConvertTo = acquire_value %.loc40_19.3
 // CHECK:STDOUT:     %F.specific_fn: <specific function> = specific_function %F.ref, @F(constants.%tuple.b95, <error>) [concrete = <error>]
-// CHECK:STDOUT:     %F.call: init %empty_tuple.type = call %F.specific_fn(%holds_to.ref)
+// CHECK:STDOUT:     %.loc40_19.5: %empty_tuple.type = call %F.specific_fn(%holds_to.ref)
+// CHECK:STDOUT:     %tuple.loc40: %empty_tuple.type = tuple_value () [concrete = constants.%empty_tuple]
+// CHECK:STDOUT:     %.loc40_19.6: %empty_tuple.type = converted %.loc40_19.5, %tuple.loc40 [concrete = constants.%empty_tuple]
 // CHECK:STDOUT:     %DestroyOp.bound: <bound method> = bound_method %.loc40_19.3, constants.%DestroyOp
 // CHECK:STDOUT:     %DestroyOp.call: init %empty_tuple.type = call %DestroyOp.bound(%.loc40_19.3)
 // CHECK:STDOUT:     return

+ 15 - 13
toolchain/sem_ir/formatter.cpp

@@ -717,8 +717,8 @@ auto Formatter::FormatGenericEnd() -> void {
   out_ << '\n';
 }
 
-auto Formatter::FormatParamList(InstBlockId params_id,
-                                SemIR::InstId return_form_id) -> void {
+auto Formatter::FormatParamList(InstBlockId params_id, InstId return_form_id)
+    -> void {
   if (!params_id.has_value()) {
     // TODO: This happens for imported functions, for which we don't currently
     // import the call parameters list.
@@ -1313,18 +1313,20 @@ auto Formatter::FormatCallRhs(Call inst) -> void {
 
   // If there's a return argument, don't print it here, because it's printed on
   // the LHS.
-  auto callee_function = SemIR::GetCalleeAsFunction(*sem_ir_, inst.callee_id);
-  auto function = sem_ir_->functions().Get(callee_function.function_id);
-  auto return_form_id = function.GetDeclaredReturnForm(
-      *sem_ir_, callee_function.resolved_specific_id);
   int return_arg_index = -1;
-  if (return_form_id.has_value()) {
-    if (auto init_form =
-            sem_ir_->insts().TryGetAs<SemIR::InitForm>(return_form_id)) {
-      auto type_id = sem_ir_->types().GetTypeIdForTypeInstId(
-          init_form->type_component_inst_id);
-      if (SemIR::InitRepr::ForType(*sem_ir_, type_id).MightBeInPlace()) {
-        return_arg_index = init_form->index.index;
+  auto callee = GetCallee(*sem_ir_, inst.callee_id);
+  if (auto* callee_function = std::get_if<CalleeFunction>(&callee)) {
+    auto function = sem_ir_->functions().Get(callee_function->function_id);
+    auto return_form_id = function.GetDeclaredReturnForm(
+        *sem_ir_, callee_function->resolved_specific_id);
+    if (return_form_id.has_value()) {
+      if (auto init_form =
+              sem_ir_->insts().TryGetAs<InitForm>(return_form_id)) {
+        auto type_id = sem_ir_->types().GetTypeIdForTypeInstId(
+            init_form->type_component_inst_id);
+        if (InitRepr::ForType(*sem_ir_, type_id).MightBeInPlace()) {
+          return_arg_index = init_form->index.index;
+        }
       }
     }
   }

+ 30 - 18
toolchain/sem_ir/function.cpp

@@ -5,6 +5,7 @@
 #include "toolchain/sem_ir/function.h"
 
 #include <optional>
+#include <variant>
 
 #include "toolchain/base/kind_switch.h"
 #include "toolchain/sem_ir/file.h"
@@ -33,25 +34,32 @@ auto GetCallee(const File& sem_ir, InstId callee_id,
                  "Invalid callee id in a specific context");
   }
 
+  auto val_id = sem_ir.constant_values().GetConstantInstId(callee_id);
+  if (!val_id.has_value()) {
+    return CalleeNonFunction();
+  }
+
   if (auto specific_function =
-          sem_ir.insts().TryGetAs<SpecificFunction>(callee_id)) {
+          sem_ir.insts().TryGetAs<SpecificFunction>(val_id)) {
     fn.resolved_specific_id = specific_function->specific_id;
-    callee_id = specific_function->callee_id;
+    val_id = sem_ir.constant_values().GetConstantInstId(
+        specific_function->callee_id);
   } else if (auto specific_impl_function =
-                 sem_ir.insts().TryGetAs<SpecificImplFunction>(callee_id)) {
+                 sem_ir.insts().TryGetAs<SpecificImplFunction>(val_id)) {
     fn.resolved_specific_id = specific_impl_function->specific_id;
-    callee_id = specific_impl_function->callee_id;
+    val_id = sem_ir.constant_values().GetConstantInstId(
+        specific_impl_function->callee_id);
   }
-
-  // Identify the function we're calling by its type.
-  auto val_id = sem_ir.constant_values().GetConstantInstId(callee_id);
   if (!val_id.has_value()) {
     return CalleeNonFunction();
   }
-  auto fn_type_inst =
-      sem_ir.types().GetAsInst(sem_ir.insts().Get(val_id).type_id());
 
-  if (auto cpp_overload_set_type = fn_type_inst.TryAs<CppOverloadSetType>()) {
+  // Identify the function we're calling by its type.
+  auto fn_type_inst_id =
+      sem_ir.types().GetTypeInstId(sem_ir.insts().Get(val_id).type_id());
+
+  if (auto cpp_overload_set_type =
+          sem_ir.insts().TryGetAs<CppOverloadSetType>(fn_type_inst_id)) {
     CARBON_CHECK(!fn.resolved_specific_id.has_value(),
                  "Only `SpecificFunction` will be resolved, not C++ overloads");
     return CalleeCppOverloadSet{
@@ -59,18 +67,21 @@ auto GetCallee(const File& sem_ir, InstId callee_id,
         .self_id = fn.self_id};
   }
 
-  if (auto impl_fn_type = fn_type_inst.TryAs<FunctionTypeWithSelfType>()) {
+  if (auto impl_fn_type =
+          sem_ir.insts().TryGetAs<FunctionTypeWithSelfType>(fn_type_inst_id)) {
     // Combine the associated function's `Self` with the interface function
     // data.
     fn.self_type_id = impl_fn_type->self_id;
-    fn_type_inst = sem_ir.insts().Get(impl_fn_type->interface_function_type_id);
+    fn_type_inst_id = impl_fn_type->interface_function_type_id;
   }
 
-  auto fn_type = fn_type_inst.TryAs<FunctionType>();
+  auto fn_type_val_id =
+      sem_ir.constant_values().GetConstantInstId(fn_type_inst_id);
+  if (fn_type_val_id == ErrorInst::InstId) {
+    return CalleeError();
+  }
+  auto fn_type = sem_ir.insts().TryGetAs<FunctionType>(fn_type_val_id);
   if (!fn_type) {
-    if (fn_type_inst.Is<ErrorInst>()) {
-      return CalleeError();
-    }
     return CalleeNonFunction();
   }
 
@@ -81,8 +92,9 @@ auto GetCallee(const File& sem_ir, InstId callee_id,
 
 auto GetCalleeAsFunction(const File& sem_ir, InstId callee_id,
                          SpecificId caller_specific_id) -> CalleeFunction {
-  return std::get<CalleeFunction>(
-      GetCallee(sem_ir, callee_id, caller_specific_id));
+  auto callee = GetCallee(sem_ir, callee_id, caller_specific_id);
+  CARBON_CHECK(std::holds_alternative<CalleeFunction>(callee));
+  return std::get<CalleeFunction>(callee);
 }
 
 auto DecomposeVirtualFunction(const File& sem_ir, InstId fn_decl_id,

+ 2 - 0
toolchain/sem_ir/function.h

@@ -253,6 +253,8 @@ auto GetCallee(const File& sem_ir, InstId callee_id,
                SpecificId caller_specific_id = SpecificId::None) -> Callee;
 
 // Like `GetCallee`, but restricts to the `Function` callee kind.
+//
+// It is invalid to call this with a callee that has an error inside it.
 auto GetCalleeAsFunction(const File& sem_ir, InstId callee_id,
                          SpecificId caller_specific_id = SpecificId::None)
     -> CalleeFunction;