Просмотр исходного кода

Support lowering of functions with variable binding parameters. (#6012)

Fix a crash when attempting to lower a function with a variable binding
as a parameter. This is a narrowly-targeted fix, and not the right
longer-term approach; more complex patterns as function parameters will
still fail and likely crash.
Richard Smith 8 месяцев назад
Родитель
Сommit
51cb078da4

+ 10 - 0
toolchain/lower/file_context.cpp

@@ -327,11 +327,21 @@ auto FileContext::BuildFunctionTypeInfo(const SemIR::Function& function,
   }
   for (auto param_pattern_id : llvm::concat<const SemIR::InstId>(
            implicit_param_patterns, param_patterns)) {
+    // TODO: Handle a general pattern here, rather than assuming that each
+    // parameter pattern contains at most one binding.
     auto param_pattern_info = SemIR::Function::GetParamPatternInfoFromPatternId(
         sem_ir(), param_pattern_id);
     if (!param_pattern_info) {
       continue;
     }
+    // TODO: Use a more general mechanism to determine if the binding is a
+    // reference binding.
+    if (param_pattern_info->var_pattern_id.has_value()) {
+      param_types.push_back(
+          llvm::PointerType::get(llvm_context(), /*AddressSpace=*/0));
+      param_inst_ids.push_back(param_pattern_id);
+      continue;
+    }
     auto param_type_id = ExtractScrutineeType(
         sem_ir(), SemIR::GetTypeOfInstInSpecific(sem_ir(), specific_id,
                                                  param_pattern_info->inst_id));

+ 137 - 0
toolchain/lower/testdata/function/definition/var_param.carbon

@@ -0,0 +1,137 @@
+// Part of the Carbon Language project, under the Apache License v2.0 with LLVM
+// Exceptions. See /LICENSE for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// INCLUDE-FILE: toolchain/testing/testdata/min_prelude/int.carbon
+//
+// AUTOUPDATE
+// TIP: To test this file alone, run:
+// TIP:   bazel test //toolchain/testing:file_test --test_arg=--file_tests=toolchain/lower/testdata/function/definition/var_param.carbon
+// TIP: To dump output, run:
+// TIP:   bazel run //toolchain/testing:file_test -- --dump_output --file_tests=toolchain/lower/testdata/function/definition/var_param.carbon
+
+class X {}
+
+fn OneVar_i32(var n: i32) {}
+fn OneVar_X(var x: X) {}
+
+fn TwoVars(var a: i32, var b: X) {}
+
+fn VarThenLet(var a: i32, b: X) {}
+fn LetThenVar(a: i32, var b: X) {}
+
+fn Call() {
+  OneVar_i32(1);
+  OneVar_X({});
+  TwoVars(1, {});
+  VarThenLet(1, {});
+  LetThenVar(1, {});
+}
+
+// CHECK:STDOUT: ; ModuleID = 'var_param.carbon'
+// CHECK:STDOUT: source_filename = "var_param.carbon"
+// CHECK:STDOUT:
+// CHECK:STDOUT: @X.val.loc16_13.2 = internal constant {} zeroinitializer
+// CHECK:STDOUT:
+// CHECK:STDOUT: define void @_COneVar_i32.Main(ptr %n) !dbg !4 {
+// CHECK:STDOUT: entry:
+// CHECK:STDOUT:   ret void, !dbg !7
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
+// CHECK:STDOUT: define void @_COneVar_X.Main(ptr %x) !dbg !8 {
+// CHECK:STDOUT: entry:
+// CHECK:STDOUT:   ret void, !dbg !9
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
+// CHECK:STDOUT: define void @_CTwoVars.Main(ptr %a, ptr %b) !dbg !10 {
+// CHECK:STDOUT: entry:
+// CHECK:STDOUT:   ret void, !dbg !11
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
+// CHECK:STDOUT: define void @_CVarThenLet.Main(ptr %a, ptr %b) !dbg !12 {
+// CHECK:STDOUT: entry:
+// CHECK:STDOUT:   ret void, !dbg !13
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
+// CHECK:STDOUT: define void @_CLetThenVar.Main(i32 %a, ptr %b) !dbg !14 {
+// CHECK:STDOUT: entry:
+// CHECK:STDOUT:   ret void, !dbg !15
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
+// CHECK:STDOUT: define void @_CCall.Main() !dbg !16 {
+// CHECK:STDOUT: entry:
+// CHECK:STDOUT:   %.loc15_15.1.temp = alloca i32, align 4, !dbg !17
+// CHECK:STDOUT:   %.loc16_13.1.temp = alloca {}, align 8, !dbg !18
+// CHECK:STDOUT:   %.loc18_12.1.temp = alloca i32, align 4, !dbg !19
+// CHECK:STDOUT:   %.loc18_24.1.temp = alloca {}, align 8, !dbg !20
+// CHECK:STDOUT:   %.loc20_15.1.temp = alloca i32, align 4, !dbg !21
+// CHECK:STDOUT:   %.loc27_18.2.temp = alloca {}, align 8, !dbg !22
+// CHECK:STDOUT:   %.loc21_23.1.temp = alloca {}, align 8, !dbg !23
+// CHECK:STDOUT:   call void @llvm.lifetime.start.p0(ptr %.loc15_15.1.temp), !dbg !17
+// CHECK:STDOUT:   store i32 1, ptr %.loc15_15.1.temp, align 4, !dbg !17
+// CHECK:STDOUT:   call void @_COneVar_i32.Main(ptr %.loc15_15.1.temp), !dbg !24
+// CHECK:STDOUT:   call void @llvm.lifetime.start.p0(ptr %.loc16_13.1.temp), !dbg !18
+// CHECK:STDOUT:   call void @llvm.memcpy.p0.p0.i64(ptr align 1 %.loc16_13.1.temp, ptr align 1 @X.val.loc16_13.2, i64 0, i1 false), !dbg !18
+// CHECK:STDOUT:   call void @_COneVar_X.Main(ptr %.loc16_13.1.temp), !dbg !25
+// CHECK:STDOUT:   call void @llvm.lifetime.start.p0(ptr %.loc18_12.1.temp), !dbg !19
+// CHECK:STDOUT:   store i32 1, ptr %.loc18_12.1.temp, align 4, !dbg !19
+// CHECK:STDOUT:   call void @llvm.lifetime.start.p0(ptr %.loc18_24.1.temp), !dbg !20
+// CHECK:STDOUT:   call void @llvm.memcpy.p0.p0.i64(ptr align 1 %.loc18_24.1.temp, ptr align 1 @X.val.loc16_13.2, i64 0, i1 false), !dbg !20
+// CHECK:STDOUT:   call void @_CTwoVars.Main(ptr %.loc18_12.1.temp, ptr %.loc18_24.1.temp), !dbg !26
+// CHECK:STDOUT:   call void @llvm.lifetime.start.p0(ptr %.loc20_15.1.temp), !dbg !21
+// CHECK:STDOUT:   store i32 1, ptr %.loc20_15.1.temp, align 4, !dbg !21
+// CHECK:STDOUT:   call void @llvm.lifetime.start.p0(ptr %.loc27_18.2.temp), !dbg !22
+// CHECK:STDOUT:   call void @llvm.memcpy.p0.p0.i64(ptr align 1 %.loc27_18.2.temp, ptr align 1 @X.val.loc16_13.2, i64 0, i1 false), !dbg !22
+// CHECK:STDOUT:   call void @_CVarThenLet.Main(ptr %.loc20_15.1.temp, ptr %.loc27_18.2.temp), !dbg !27
+// CHECK:STDOUT:   call void @llvm.lifetime.start.p0(ptr %.loc21_23.1.temp), !dbg !23
+// CHECK:STDOUT:   call void @llvm.memcpy.p0.p0.i64(ptr align 1 %.loc21_23.1.temp, ptr align 1 @X.val.loc16_13.2, i64 0, i1 false), !dbg !23
+// CHECK:STDOUT:   call void @_CLetThenVar.Main(i32 1, ptr %.loc21_23.1.temp), !dbg !28
+// CHECK:STDOUT:   ret void, !dbg !29
+// CHECK:STDOUT: }
+// CHECK:STDOUT:
+// CHECK:STDOUT: ; Function Attrs: nocallback nofree nosync nounwind willreturn memory(argmem: readwrite)
+// CHECK:STDOUT: declare void @llvm.lifetime.start.p0(ptr captures(none)) #0
+// CHECK:STDOUT:
+// CHECK:STDOUT: ; Function Attrs: nocallback nofree nounwind willreturn memory(argmem: readwrite)
+// CHECK:STDOUT: declare void @llvm.memcpy.p0.p0.i64(ptr noalias writeonly captures(none), ptr noalias readonly captures(none), i64, i1 immarg) #1
+// CHECK:STDOUT:
+// CHECK:STDOUT: ; uselistorder directives
+// CHECK:STDOUT: uselistorder ptr @llvm.lifetime.start.p0, { 6, 5, 4, 3, 2, 1, 0 }
+// CHECK:STDOUT: uselistorder ptr @llvm.memcpy.p0.p0.i64, { 3, 2, 1, 0 }
+// CHECK:STDOUT:
+// CHECK:STDOUT: attributes #0 = { nocallback nofree nosync nounwind willreturn memory(argmem: readwrite) }
+// CHECK:STDOUT: attributes #1 = { nocallback nofree nounwind willreturn memory(argmem: readwrite) }
+// CHECK:STDOUT:
+// CHECK:STDOUT: !llvm.module.flags = !{!0, !1}
+// CHECK:STDOUT: !llvm.dbg.cu = !{!2}
+// CHECK:STDOUT:
+// CHECK:STDOUT: !0 = !{i32 7, !"Dwarf Version", i32 5}
+// CHECK:STDOUT: !1 = !{i32 2, !"Debug Info Version", i32 3}
+// CHECK:STDOUT: !2 = distinct !DICompileUnit(language: DW_LANG_C, file: !3, producer: "carbon", isOptimized: false, runtimeVersion: 0, emissionKind: FullDebug)
+// CHECK:STDOUT: !3 = !DIFile(filename: "var_param.carbon", directory: "")
+// CHECK:STDOUT: !4 = distinct !DISubprogram(name: "OneVar_i32", linkageName: "_COneVar_i32.Main", scope: null, file: !3, line: 15, type: !5, spFlags: DISPFlagDefinition, unit: !2)
+// CHECK:STDOUT: !5 = !DISubroutineType(types: !6)
+// CHECK:STDOUT: !6 = !{}
+// CHECK:STDOUT: !7 = !DILocation(line: 15, column: 1, scope: !4)
+// CHECK:STDOUT: !8 = distinct !DISubprogram(name: "OneVar_X", linkageName: "_COneVar_X.Main", scope: null, file: !3, line: 16, type: !5, spFlags: DISPFlagDefinition, unit: !2)
+// CHECK:STDOUT: !9 = !DILocation(line: 16, column: 1, scope: !8)
+// CHECK:STDOUT: !10 = distinct !DISubprogram(name: "TwoVars", linkageName: "_CTwoVars.Main", scope: null, file: !3, line: 18, type: !5, spFlags: DISPFlagDefinition, unit: !2)
+// CHECK:STDOUT: !11 = !DILocation(line: 18, column: 1, scope: !10)
+// CHECK:STDOUT: !12 = distinct !DISubprogram(name: "VarThenLet", linkageName: "_CVarThenLet.Main", scope: null, file: !3, line: 20, type: !5, spFlags: DISPFlagDefinition, unit: !2)
+// CHECK:STDOUT: !13 = !DILocation(line: 20, column: 1, scope: !12)
+// CHECK:STDOUT: !14 = distinct !DISubprogram(name: "LetThenVar", linkageName: "_CLetThenVar.Main", scope: null, file: !3, line: 21, type: !5, spFlags: DISPFlagDefinition, unit: !2)
+// CHECK:STDOUT: !15 = !DILocation(line: 21, column: 1, scope: !14)
+// CHECK:STDOUT: !16 = distinct !DISubprogram(name: "Call", linkageName: "_CCall.Main", scope: null, file: !3, line: 23, type: !5, spFlags: DISPFlagDefinition, unit: !2)
+// CHECK:STDOUT: !17 = !DILocation(line: 15, column: 15, scope: !16)
+// CHECK:STDOUT: !18 = !DILocation(line: 16, column: 13, scope: !16)
+// CHECK:STDOUT: !19 = !DILocation(line: 18, column: 12, scope: !16)
+// CHECK:STDOUT: !20 = !DILocation(line: 18, column: 24, scope: !16)
+// CHECK:STDOUT: !21 = !DILocation(line: 20, column: 15, scope: !16)
+// CHECK:STDOUT: !22 = !DILocation(line: 27, column: 17, scope: !16)
+// CHECK:STDOUT: !23 = !DILocation(line: 21, column: 23, scope: !16)
+// CHECK:STDOUT: !24 = !DILocation(line: 24, column: 3, scope: !16)
+// CHECK:STDOUT: !25 = !DILocation(line: 25, column: 3, scope: !16)
+// CHECK:STDOUT: !26 = !DILocation(line: 26, column: 3, scope: !16)
+// CHECK:STDOUT: !27 = !DILocation(line: 27, column: 3, scope: !16)
+// CHECK:STDOUT: !28 = !DILocation(line: 28, column: 3, scope: !16)
+// CHECK:STDOUT: !29 = !DILocation(line: 23, column: 1, scope: !16)

+ 4 - 1
toolchain/sem_ir/function.cpp

@@ -98,6 +98,8 @@ auto Function::GetParamPatternInfoFromPatternId(const File& sem_ir,
   auto inst = sem_ir.insts().Get(inst_id);
 
   sem_ir.insts().TryUnwrap(inst, inst_id, &AddrPattern::inner_id);
+  auto [var_pattern, var_pattern_id] =
+      sem_ir.insts().TryUnwrap(inst, inst_id, &VarPattern::subpattern_id);
   auto [param_pattern, param_pattern_id] =
       sem_ir.insts().TryUnwrap(inst, inst_id, &AnyParamPattern::subpattern_id);
   if (!param_pattern) {
@@ -107,7 +109,8 @@ auto Function::GetParamPatternInfoFromPatternId(const File& sem_ir,
   auto binding_pattern = inst.As<AnyBindingPattern>();
   return {{.inst_id = param_pattern_id,
            .inst = *param_pattern,
-           .entity_name_id = binding_pattern.entity_name_id}};
+           .entity_name_id = binding_pattern.entity_name_id,
+           .var_pattern_id = var_pattern_id}};
 }
 
 auto Function::GetDeclaredReturnType(const File& file,

+ 1 - 0
toolchain/sem_ir/function.h

@@ -95,6 +95,7 @@ struct Function : public EntityWithParamsBase,
     InstId inst_id;
     AnyParamPattern inst;
     EntityNameId entity_name_id;
+    KnownInstId<VarPattern> var_pattern_id;
   };
 
   auto Print(llvm::raw_ostream& out) const -> void {

+ 7 - 8
toolchain/sem_ir/inst.h

@@ -524,7 +524,7 @@ class InstStore {
 
   template <class InstT>
   struct GetAsWithIdResult {
-    SemIR::KnownInstId<InstT> inst_id;
+    KnownInstId<InstT> inst_id;
     InstT inst;
   };
 
@@ -534,8 +534,7 @@ class InstStore {
   template <typename InstT>
   auto GetAsWithId(InstId inst_id) const -> GetAsWithIdResult<InstT> {
     auto inst = GetAs<InstT>(inst_id);
-    return {.inst_id = SemIR::KnownInstId<InstT>::UnsafeMake(inst_id),
-            .inst = inst};
+    return {.inst_id = KnownInstId<InstT>::UnsafeMake(inst_id), .inst = inst};
   }
 
   // Returns the requested instruction, if it is of that type, along with the
@@ -548,8 +547,8 @@ class InstStore {
     if (!inst) {
       return std::nullopt;
     }
-    return {{.inst_id = SemIR::KnownInstId<InstT>::UnsafeMake(inst_id),
-             .inst = *inst}};
+    return {
+        {.inst_id = KnownInstId<InstT>::UnsafeMake(inst_id), .inst = *inst}};
   }
 
   // Attempts to convert the given instruction to the type that contains
@@ -559,14 +558,14 @@ class InstStore {
   template <typename InstT, typename InstIdT>
     requires std::derived_from<InstIdT, InstId>
   auto TryUnwrap(Inst& inst, InstId& inst_id, InstIdT InstT::* member) const
-      -> std::pair<std::optional<InstT>, InstId> {
+      -> std::pair<std::optional<InstT>, KnownInstId<InstT>> {
     if (auto wrapped_inst = inst.TryAs<InstT>()) {
-      auto wrapped_inst_id = inst_id;
+      auto wrapped_inst_id = KnownInstId<InstT>::UnsafeMake(inst_id);
       inst_id = (*wrapped_inst).*member;
       inst = Get(inst_id);
       return {wrapped_inst, wrapped_inst_id};
     }
-    return {std::nullopt, InstId::None};
+    return {std::nullopt, KnownInstId<InstT>::None};
   }
 
   // Returns a resolved LocId, which will point to a parse node, an import, or