thunk.cpp 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411
  1. // Part of the Carbon Language project, under the Apache License v2.0 with LLVM
  2. // Exceptions. See /LICENSE for license information.
  3. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  4. #include "toolchain/check/thunk.h"
  5. #include "toolchain/base/kind_switch.h"
  6. #include "toolchain/check/call.h"
  7. #include "toolchain/check/diagnostic_helpers.h"
  8. #include "toolchain/check/function.h"
  9. #include "toolchain/check/generic.h"
  10. #include "toolchain/check/inst.h"
  11. #include "toolchain/check/member_access.h"
  12. #include "toolchain/check/pattern.h"
  13. #include "toolchain/check/pattern_match.h"
  14. #include "toolchain/check/pointer_dereference.h"
  15. #include "toolchain/check/return.h"
  16. #include "toolchain/check/type.h"
  17. #include "toolchain/diagnostics/diagnostic.h"
  18. #include "toolchain/sem_ir/function.h"
  19. #include "toolchain/sem_ir/generic.h"
  20. #include "toolchain/sem_ir/ids.h"
  21. #include "toolchain/sem_ir/inst.h"
  22. #include "toolchain/sem_ir/pattern.h"
  23. #include "toolchain/sem_ir/typed_insts.h"
  24. namespace Carbon::Check {
  25. // Adds a pattern instruction for a thunk, copying the location from an existing
  26. // instruction.
  27. static auto RebuildPatternInst(Context& context, SemIR::InstId orig_inst_id,
  28. SemIR::Inst new_inst) -> SemIR::InstId {
  29. // Ensure we built the same kind of instruction. In particular, this ensures
  30. // that the location of the old instruction can be reused for the new one.
  31. CARBON_CHECK(context.insts().Get(orig_inst_id).kind() == new_inst.kind(),
  32. "Rebuilt pattern with the wrong kind: {0} -> {1}",
  33. context.insts().Get(orig_inst_id), new_inst);
  34. return AddPatternInst(context, SemIR::LocIdAndInst::UncheckedLoc(
  35. SemIR::LocId(orig_inst_id), new_inst));
  36. }
  37. // Wrapper to allow the type to be specified as a template argument for API
  38. // consistency with `AddInst`.
  39. template <typename InstT>
  40. static auto RebuildPatternInst(Context& context, SemIR::InstId orig_inst_id,
  41. InstT new_inst) -> SemIR::InstId {
  42. return RebuildPatternInst(context, orig_inst_id, SemIR::Inst(new_inst));
  43. }
  44. // Makes a copy of the given binding pattern, with its type adjusted to be
  45. // `new_pattern_type_id`.
  46. static auto CloneBindingPattern(Context& context, SemIR::InstId pattern_id,
  47. SemIR::AnyBindingPattern pattern,
  48. SemIR::TypeId new_pattern_type_id)
  49. -> SemIR::InstId {
  50. bool is_generic = pattern.kind == SemIR::SymbolicBindingPattern::Kind;
  51. auto entity_name = context.entity_names().Get(pattern.entity_name_id);
  52. CARBON_CHECK(is_generic == entity_name.bind_index().has_value());
  53. // Get the transformed type of the binding.
  54. if (new_pattern_type_id == SemIR::ErrorInst::TypeId) {
  55. return SemIR::ErrorInst::InstId;
  56. }
  57. auto type_inst_id = context.types()
  58. .GetAs<SemIR::PatternType>(new_pattern_type_id)
  59. .scrutinee_type_inst_id;
  60. auto type_id = context.types().GetTypeIdForTypeInstId(type_inst_id);
  61. auto type_expr_region_id = context.sem_ir().expr_regions().Add(
  62. {.block_ids = {SemIR::InstBlockId::Empty}, .result_id = type_inst_id});
  63. // Rebuild the binding pattern.
  64. return AddBindingPattern(context, SemIR::LocId(pattern_id),
  65. entity_name.name_id, type_id, type_expr_region_id,
  66. is_generic, entity_name.is_template)
  67. .pattern_id;
  68. }
  69. // Makes a copy of the given pattern instruction, substituting values from a
  70. // specific as needed. The resulting pattern behaves like a newly-created
  71. // pattern, so is suitable for running `CalleePatternMatch` against.
  72. static auto ClonePattern(Context& context, SemIR::SpecificId specific_id,
  73. SemIR::InstId pattern_id) -> SemIR::InstId {
  74. if (!pattern_id.has_value()) {
  75. return SemIR::InstId::None;
  76. }
  77. auto get_type = [&](SemIR::InstId inst_id) -> SemIR::TypeId {
  78. return SemIR::GetTypeOfInstInSpecific(context.sem_ir(), specific_id,
  79. inst_id);
  80. };
  81. auto pattern = context.insts().Get(pattern_id);
  82. // Decompose the pattern. The forms we allow for patterns in a function
  83. // parameter list are currently fairly restrictive.
  84. // Optional `addr`, only for `self`.
  85. auto [addr, addr_id] = context.insts().TryUnwrap(
  86. pattern, pattern_id, &SemIR::AddrPattern::inner_id);
  87. // Optional parameter pattern.
  88. auto [param, param_id] = context.insts().TryUnwrap(
  89. pattern, pattern_id, &SemIR::AnyParamPattern::subpattern_id);
  90. // Finally, either a binding pattern or a return slot pattern.
  91. auto new_pattern_id = SemIR::InstId::None;
  92. if (auto binding = pattern.TryAs<SemIR::AnyBindingPattern>()) {
  93. new_pattern_id = CloneBindingPattern(context, pattern_id, *binding,
  94. get_type(pattern_id));
  95. } else if (auto return_slot = pattern.TryAs<SemIR::ReturnSlotPattern>()) {
  96. new_pattern_id = RebuildPatternInst<SemIR::ReturnSlotPattern>(
  97. context, pattern_id,
  98. {.type_id = get_type(pattern_id),
  99. .type_inst_id = SemIR::TypeInstId::None});
  100. } else {
  101. CARBON_CHECK(pattern.Is<SemIR::ErrorInst>(),
  102. "Unexpected pattern {0} in function signature", pattern);
  103. return SemIR::ErrorInst::InstId;
  104. }
  105. // Rebuild parameter.
  106. if (param) {
  107. new_pattern_id = RebuildPatternInst<SemIR::AnyParamPattern>(
  108. context, param_id,
  109. {.kind = param->kind,
  110. .type_id = get_type(param_id),
  111. .subpattern_id = new_pattern_id,
  112. .index = SemIR::CallParamIndex::None});
  113. }
  114. // Rebuild `addr`.
  115. if (addr) {
  116. new_pattern_id = RebuildPatternInst<SemIR::AddrPattern>(
  117. context, addr_id,
  118. {.type_id = get_type(addr_id), .inner_id = new_pattern_id});
  119. }
  120. return new_pattern_id;
  121. }
  122. static auto ClonePatternBlock(Context& context, SemIR::SpecificId specific_id,
  123. SemIR::InstBlockId inst_block_id)
  124. -> SemIR::InstBlockId {
  125. if (!inst_block_id.has_value()) {
  126. return SemIR::InstBlockId::None;
  127. }
  128. return context.inst_blocks().Transform(
  129. inst_block_id, [&](SemIR::InstId inst_id) {
  130. return ClonePattern(context, specific_id, inst_id);
  131. });
  132. }
  133. static auto CloneFunctionDecl(Context& context, SemIR::LocId loc_id,
  134. SemIR::FunctionId signature_id,
  135. SemIR::SpecificId signature_specific_id,
  136. SemIR::FunctionId callee_id)
  137. -> std::pair<SemIR::FunctionId, SemIR::InstId> {
  138. StartGenericDecl(context);
  139. // Clone the signature. Note that we re-get the function after each of these,
  140. // because they might trigger imports that invalidate the function.
  141. context.pattern_block_stack().Push();
  142. auto implicit_param_patterns_id = ClonePatternBlock(
  143. context, signature_specific_id,
  144. context.functions().Get(signature_id).implicit_param_patterns_id);
  145. auto param_patterns_id = ClonePatternBlock(
  146. context, signature_specific_id,
  147. context.functions().Get(signature_id).param_patterns_id);
  148. auto return_slot_pattern_id = ClonePattern(
  149. context, signature_specific_id,
  150. context.functions().Get(signature_id).return_slot_pattern_id);
  151. auto self_param_id = FindSelfPattern(context, implicit_param_patterns_id);
  152. auto pattern_block_id = context.pattern_block_stack().Pop();
  153. // Perform callee-side pattern matching to rebuild the parameter list.
  154. context.inst_block_stack().Push();
  155. auto call_params_id =
  156. CalleePatternMatch(context, implicit_param_patterns_id, param_patterns_id,
  157. return_slot_pattern_id);
  158. auto decl_block_id = context.inst_block_stack().Pop();
  159. // Create the `FunctionDecl` instruction.
  160. SemIR::FunctionDecl function_decl = {SemIR::TypeId::None,
  161. SemIR::FunctionId::None, decl_block_id};
  162. auto decl_id = AddPlaceholderInst(
  163. context, SemIR::LocIdAndInst::UncheckedLoc(loc_id, function_decl));
  164. auto generic_id = BuildGenericDecl(context, decl_id);
  165. // Create the `Function` object.
  166. auto& signature = context.functions().Get(signature_id);
  167. auto& callee = context.functions().Get(callee_id);
  168. function_decl.function_id = context.functions().Add(SemIR::Function{
  169. {.name_id = signature.name_id,
  170. .parent_scope_id = callee.parent_scope_id,
  171. .generic_id = generic_id,
  172. .first_param_node_id = signature.first_param_node_id,
  173. .last_param_node_id = signature.last_param_node_id,
  174. .pattern_block_id = pattern_block_id,
  175. .implicit_param_patterns_id = implicit_param_patterns_id,
  176. .param_patterns_id = param_patterns_id,
  177. .is_extern = false,
  178. .extern_library_id = SemIR::LibraryNameId::None,
  179. .non_owning_decl_id = SemIR::InstId::None,
  180. .first_owning_decl_id = decl_id,
  181. .definition_id = decl_id},
  182. {.call_params_id = call_params_id,
  183. .return_slot_pattern_id = return_slot_pattern_id,
  184. .special_function_kind = SemIR::Function::SpecialFunctionKind::Thunk,
  185. .virtual_modifier = callee.virtual_modifier,
  186. .virtual_index = callee.virtual_index,
  187. .self_param_id = self_param_id}});
  188. function_decl.type_id =
  189. GetFunctionType(context, function_decl.function_id,
  190. context.scope_stack().PeekSpecificId());
  191. ReplaceInstBeforeConstantUse(context, decl_id, function_decl);
  192. return {function_decl.function_id, decl_id};
  193. }
  194. // Build an expression that names the value matched by a pattern.
  195. static auto BuildPatternRef(Context& context, SemIR::FunctionId function_id,
  196. SemIR::InstId pattern_id) -> SemIR::InstId {
  197. auto pattern = context.insts().Get(pattern_id);
  198. auto addr = context.insts()
  199. .TryUnwrap(pattern, pattern_id, &SemIR::AddrPattern::inner_id)
  200. .first;
  201. auto pattern_ref_id = SemIR::InstId::None;
  202. if (auto value_param = pattern.TryAs<SemIR::ValueParamPattern>()) {
  203. // Build a reference to this parameter.
  204. auto call_param_id = context.inst_blocks().Get(
  205. context.functions()
  206. .Get(function_id)
  207. .call_params_id)[value_param->index.index];
  208. // Use a pretty name for the `name_ref`. While it's suspicious to use a
  209. // pretty name in the IR like this, the only reason we include a name at
  210. // all here is to make the formatted SemIR more readable.
  211. pattern_ref_id = AddInst<SemIR::NameRef>(
  212. context, SemIR::LocId(pattern_id),
  213. {.type_id = context.insts().Get(call_param_id).type_id(),
  214. .name_id = SemIR::GetPrettyNameFromPatternId(
  215. context.sem_ir(), value_param->subpattern_id),
  216. .value_id = call_param_id});
  217. } else {
  218. if (pattern_id != SemIR::ErrorInst::InstId) {
  219. context.TODO(
  220. pattern_id,
  221. "don't know how to build reference to this pattern in thunk");
  222. }
  223. return SemIR::ErrorInst::InstId;
  224. }
  225. if (addr) {
  226. pattern_ref_id = PerformPointerDereference(
  227. context, SemIR::LocId(pattern_id), pattern_ref_id, [](SemIR::TypeId) {
  228. CARBON_FATAL("addr subpattern is not a pointer");
  229. });
  230. }
  231. return pattern_ref_id;
  232. }
  233. // Build a call to a function that forwards the arguments of the enclosing
  234. // function, for use when constructing a thunk.
  235. static auto BuildThunkCall(Context& context, SemIR::FunctionId function_id,
  236. SemIR::InstId callee_id) -> SemIR::InstId {
  237. auto loc_id = SemIR::LocId(callee_id);
  238. auto& function = context.functions().Get(function_id);
  239. // If we have a self parameter, form `self.<callee_id>`.
  240. if (function.self_param_id.has_value()) {
  241. callee_id = PerformCompoundMemberAccess(
  242. context, loc_id,
  243. BuildPatternRef(context, function_id, function.self_param_id),
  244. callee_id);
  245. }
  246. // Form an argument list.
  247. llvm::SmallVector<SemIR::InstId> args;
  248. for (auto pattern_id :
  249. context.inst_blocks().Get(function.param_patterns_id)) {
  250. args.push_back(BuildPatternRef(context, function_id, pattern_id));
  251. }
  252. return PerformCall(context, loc_id, callee_id, args);
  253. }
  254. static auto HasDeclaredReturnType(Context& context,
  255. SemIR::FunctionId function_id) -> bool {
  256. return context.functions()
  257. .Get(function_id)
  258. .return_slot_pattern_id.has_value();
  259. }
  260. // Given a declaration of a thunk and the function that it should call, build
  261. // the thunk body.
  262. static auto BuildThunkDefinition(Context& context,
  263. SemIR::FunctionId signature_id,
  264. SemIR::FunctionId function_id,
  265. SemIR::InstId thunk_id,
  266. SemIR::InstId callee_id) {
  267. // TODO: Improve the diagnostics produced here. Specifically, it would likely
  268. // be better for the primary error message to be that we tried to produce a
  269. // thunk because of a type mismatch, but couldn't, with notes explaining
  270. // why, rather than the primary error message being whatever went wrong
  271. // building the thunk.
  272. {
  273. // The check below produces diagnostics referring to the signature, so also
  274. // note the callee.
  275. Diagnostics::AnnotationScope annot_scope(
  276. &context.emitter(), [&](DiagnosticBuilder& builder) {
  277. CARBON_DIAGNOSTIC(ThunkCallee, Note,
  278. "while building thunk calling this function");
  279. builder.Note(callee_id, ThunkCallee);
  280. });
  281. CheckFunctionDefinitionSignature(context, function_id);
  282. }
  283. // TODO: This duplicates much of the handling for FunctionDefinitionStart and
  284. // FunctionDefinition parse nodes. Consider refactoring.
  285. context.scope_stack().PushForFunctionBody(thunk_id);
  286. context.inst_block_stack().Push();
  287. context.region_stack().PushRegion(context.inst_block_stack().PeekOrAdd());
  288. StartGenericDefinition(context,
  289. context.functions().Get(function_id).generic_id);
  290. // The checks below produce diagnostics pointing at the callee, so also note
  291. // the signature.
  292. Diagnostics::AnnotationScope annot_scope(
  293. &context.emitter(), [&](DiagnosticBuilder& builder) {
  294. CARBON_DIAGNOSTIC(
  295. ThunkSignature, Note,
  296. "while building thunk to match the signature of this function");
  297. builder.Note(context.functions().Get(signature_id).first_owning_decl_id,
  298. ThunkSignature);
  299. });
  300. auto call_id = BuildThunkCall(context, function_id, callee_id);
  301. if (HasDeclaredReturnType(context, function_id)) {
  302. BuildReturnWithExpr(context, SemIR::LocId(callee_id), call_id);
  303. } else {
  304. BuildReturnWithNoExpr(context, SemIR::LocId(callee_id));
  305. }
  306. context.inst_block_stack().Pop();
  307. context.scope_stack().Pop();
  308. auto& function = context.functions().Get(function_id);
  309. function.body_block_ids = context.region_stack().PopRegion();
  310. FinishGenericDefinition(context, function.generic_id);
  311. }
  312. auto BuildThunk(Context& context, SemIR::FunctionId signature_id,
  313. SemIR::SpecificId signature_specific_id,
  314. SemIR::InstId callee_id) -> SemIR::InstId {
  315. auto callee = SemIR::GetCalleeFunction(context.sem_ir(), callee_id);
  316. // Check whether we can use the given function without a thunk.
  317. // TODO: For virtual functions, we want different rules for checking `self`.
  318. // TODO: This is too strict; for example, we should not compare parameter
  319. // names here.
  320. if (CheckFunctionTypeMatches(
  321. context, context.functions().Get(callee.function_id),
  322. context.functions().Get(signature_id), signature_specific_id,
  323. /*check_syntax=*/false, /*check_self=*/true, /*diagnose=*/false)) {
  324. return callee_id;
  325. }
  326. // From P3763:
  327. // If the function in the interface does not have a return type, the
  328. // program is invalid if the function in the impl specifies a return type.
  329. //
  330. // Call into the redeclaration checking logic to produce a suitable error.
  331. //
  332. // TODO: Consider a different rule: always use an explicit return type for the
  333. // thunk, and always convert the result of the wrapped call to the return type
  334. // of the thunk.
  335. if (!HasDeclaredReturnType(context, signature_id) &&
  336. HasDeclaredReturnType(context, callee.function_id)) {
  337. bool success = CheckFunctionReturnTypeMatches(
  338. context, context.functions().Get(callee.function_id),
  339. context.functions().Get(signature_id), signature_specific_id);
  340. CARBON_CHECK(!success, "Return type unexpectedly matches");
  341. return SemIR::ErrorInst::InstId;
  342. }
  343. // Create a scope for the function's parameters and generic parameters.
  344. context.scope_stack().PushForDeclName();
  345. // We can't use the function directly. Build a thunk.
  346. // TODO: Check for and diagnose obvious reasons why this will fail, such as
  347. // arity mismatch, before trying to build the thunk.
  348. auto [function_id, thunk_id] =
  349. CloneFunctionDecl(context, SemIR::LocId(callee_id), signature_id,
  350. signature_specific_id, callee.function_id);
  351. // Define the thunk.
  352. // TODO: We should delay doing this until we get to the end of the enclosing
  353. // deferred definition scope, if there is one. For example, an `impl` inside a
  354. // `class` definition should have its thunks defined at the end of the class,
  355. // like they would be if they were defined inline.
  356. BuildThunkDefinition(context, signature_id, function_id, thunk_id, callee_id);
  357. context.scope_stack().Pop();
  358. return thunk_id;
  359. }
  360. } // namespace Carbon::Check