pattern_match.cpp 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374
  1. // Part of the Carbon Language project, under the Apache License v2.0 with LLVM
  2. // Exceptions. See /LICENSE for license information.
  3. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  4. #include "toolchain/check/pattern_match.h"
  5. #include <functional>
  6. #include <vector>
  7. #include "llvm/ADT/STLExtras.h"
  8. #include "llvm/ADT/SmallVector.h"
  9. #include "toolchain/base/kind_switch.h"
  10. #include "toolchain/check/context.h"
  11. #include "toolchain/check/convert.h"
  12. namespace Carbon::Check {
  13. // Returns a best-effort name for the given ParamPattern, suitable for use in
  14. // IR pretty-printing.
  15. // TODO: Resolve overlap with SemIR::Function::ParamPatternInfo::GetNameId
  16. template <typename ParamPattern>
  17. static auto GetPrettyName(Context& context, ParamPattern param_pattern)
  18. -> SemIR::NameId {
  19. if (context.insts().Is<SemIR::ReturnSlotPattern>(
  20. param_pattern.subpattern_id)) {
  21. return SemIR::NameId::ReturnSlot;
  22. }
  23. if (auto binding_pattern = context.insts().TryGetAs<SemIR::AnyBindingPattern>(
  24. param_pattern.subpattern_id)) {
  25. return context.entity_names().Get(binding_pattern->entity_name_id).name_id;
  26. }
  27. return SemIR::NameId::Invalid;
  28. }
  29. namespace {
// Selects between the different kinds of pattern matching. A full
// function-call pattern match is split across the two sides: the caller
// matches the pattern down to (and including) the ParamPattern insts, and the
// callee matches everything below them.
enum class MatchKind {
  // Caller pattern matching occurs on the caller side of a function call, and
  // is responsible for matching the argument expression against the portion
  // of the pattern above the ParamPattern insts.
  Caller,

  // Callee pattern matching occurs in the function decl block, and is
  // responsible for matching the function's calling-convention parameters
  // against the portion of the pattern below the ParamPattern insts.
  Callee,

  // TODO: Add enumerator for non-function-call pattern match.
};
  42. // The collected state of a pattern-matching operation.
  43. class MatchContext {
  44. public:
  45. struct WorkItem {
  46. SemIR::InstId pattern_id;
  47. // Invalid when processing the callee side.
  48. SemIR::InstId scrutinee_id;
  49. };
  50. // Constructs a MatchContext. If `callee_specific_id` is valid, this pattern
  51. // match operation is part of implementing the signature of the given
  52. // specific.
  53. explicit MatchContext(MatchKind kind, SemIR::SpecificId callee_specific_id =
  54. SemIR::SpecificId::Invalid)
  55. : next_index_(0),
  56. kind_(kind),
  57. callee_specific_id_(callee_specific_id),
  58. return_slot_id_(SemIR::InstId::Invalid) {}
  59. // Adds a work item to the stack.
  60. auto AddWork(WorkItem work_item) -> void { stack_.push_back(work_item); }
  61. // Processes all work items on the stack. When performing caller pattern
  62. // matching, returns an inst block with one inst reference for each
  63. // calling-convention argument. When performing callee pattern matching,
  64. // returns an inst block with references to all the emitted BindName insts.
  65. auto DoWork(Context& context) -> SemIR::InstBlockId;
  66. auto return_slot_id() const -> SemIR::InstId { return return_slot_id_; }
  67. private:
  68. // Allocates the next unallocated RuntimeParamIndex, starting from 0.
  69. auto NextRuntimeIndex() -> SemIR::RuntimeParamIndex {
  70. auto result = next_index_;
  71. ++next_index_.index;
  72. return result;
  73. }
  74. // Emits the pattern-match insts necessary to match the pattern inst
  75. // `entry.pattern_id` against the scrutinee value `entry.scrutinee_id`, and
  76. // adds to `stack_` any work necessary to traverse into its subpatterns. This
  77. // behavior is contingent on the kind of match being performed, as indicated
  78. // by kind_`. For example, when performing a callee pattern match, this does
  79. // not emit insts for patterns on the caller side. However, it still traverses
  80. // into subpatterns if any of their descendants might emit insts.
  81. // TODO: Require that `entry.scrutinee_id` is valid if and only if insts
  82. // should be emitted, once we start emitting `Param` insts in the
  83. // `ParamPattern` case.
  84. auto EmitPatternMatch(Context& context, MatchContext::WorkItem entry) -> void;
  85. // The stack of work to be processed.
  86. llvm::SmallVector<WorkItem> stack_;
  87. // The next index to be allocated by `NextRuntimeIndex`.
  88. SemIR::RuntimeParamIndex next_index_;
  89. // The pending results that will be returned by the current `DoWork` call.
  90. llvm::SmallVector<SemIR::InstId> results_;
  91. // The kind of pattern match being performed.
  92. MatchKind kind_;
  93. // The SpecificId of the function being called (if any).
  94. SemIR::SpecificId callee_specific_id_;
  95. // The return slot inst emitted by `DoWork`, if any.
  96. // TODO: Can this be added to the block returned by `DoWork`, instead?
  97. SemIR::InstId return_slot_id_;
  98. };
  99. } // namespace
  100. auto MatchContext::DoWork(Context& context) -> SemIR::InstBlockId {
  101. results_.reserve(stack_.size());
  102. while (!stack_.empty()) {
  103. EmitPatternMatch(context, stack_.pop_back_val());
  104. }
  105. auto block_id = context.inst_blocks().Add(results_);
  106. results_.clear();
  107. return block_id;
  108. }
// Matches one pattern inst against its scrutinee; see the declaration in
// MatchContext for the full contract.
auto MatchContext::EmitPatternMatch(Context& context,
                                    MatchContext::WorkItem entry) -> void {
  // Propagate errors without emitting any diagnostics or further insts.
  if (entry.pattern_id == SemIR::InstId::BuiltinErrorInst) {
    results_.push_back(SemIR::InstId::BuiltinErrorInst);
    return;
  }
  // On the caller side, annotate any diagnostic emitted while matching this
  // entry with a note pointing at the parameter being initialized.
  DiagnosticAnnotationScope annotate_diagnostics(
      &context.emitter(), [&](auto& builder) {
        if (kind_ == MatchKind::Caller) {
          CARBON_DIAGNOSTIC(InCallToFunctionParam, Note,
                            "initializing function parameter");
          builder.Note(entry.pattern_id, InCallToFunctionParam);
        }
      });
  auto pattern = context.insts().GetWithLocId(entry.pattern_id);
  CARBON_KIND_SWITCH(pattern.inst) {
    case SemIR::BindingPattern::Kind:
    case SemIR::SymbolicBindingPattern::Kind: {
      // Bindings sit at the bottom of the pattern, below the ParamPattern
      // insts, so they are only reached during callee matching.
      CARBON_CHECK(kind_ == MatchKind::Callee);
      auto binding_pattern = pattern.inst.As<SemIR::AnyBindingPattern>();
      auto cache_entry =
          context.bind_name_cache().Lookup(binding_pattern.entity_name_id);
      // The cached bind_name should only be used once.
      auto bind_name_id =
          std::exchange(cache_entry.value(), SemIR::InstId::Invalid);
      auto bind_name = context.insts().GetAs<SemIR::AnyBindName>(bind_name_id);
      CARBON_CHECK(!bind_name.value_id.is_valid());
      // Now that the scrutinee is known, fill it in as the bound value.
      bind_name.value_id = entry.scrutinee_id;
      context.ReplaceInstBeforeConstantUse(bind_name_id, bind_name);
      context.inst_block_stack().AddInstId(bind_name_id);
      // Only parameters with a runtime index contribute to the result block.
      if (context.insts()
              .GetAs<SemIR::AnyParam>(entry.scrutinee_id)
              .runtime_index.is_valid()) {
        results_.push_back(entry.scrutinee_id);
      }
      break;
    }
    case CARBON_KIND(SemIR::AddrPattern addr_pattern): {
      if (kind_ == MatchKind::Callee) {
        // We're emitting pattern-match IR for the callee, but we're still on
        // the caller side of the pattern, so we traverse without emitting any
        // insts.
        AddWork({.pattern_id = addr_pattern.inner_id,
                 .scrutinee_id = SemIR::InstId::Invalid});
        break;
      }
      CARBON_CHECK(entry.scrutinee_id.is_valid());
      auto scrutinee_ref_id =
          ConvertToValueOrRefExpr(context, entry.scrutinee_id);
      // `addr` takes the scrutinee's address below, so the scrutinee must be
      // a reference expression (or an error, which we propagate).
      switch (SemIR::GetExprCategory(context.sem_ir(), scrutinee_ref_id)) {
        case SemIR::ExprCategory::Error:
        case SemIR::ExprCategory::DurableRef:
        case SemIR::ExprCategory::EphemeralRef:
          break;
        default:
          CARBON_DIAGNOSTIC(AddrSelfIsNonRef, Error,
                            "`addr self` method cannot be invoked on a value");
          context.emitter().Emit(
              TokenOnly(context.insts().GetLocId(entry.scrutinee_id)),
              AddrSelfIsNonRef);
          results_.push_back(SemIR::InstId::BuiltinErrorInst);
          return;
      }
      auto scrutinee_ref = context.insts().Get(scrutinee_ref_id);
      // Match the inner pattern against the address of the scrutinee.
      auto new_scrutinee = context.AddInst<SemIR::AddrOf>(
          context.insts().GetLocId(scrutinee_ref_id),
          {.type_id = context.GetPointerType(scrutinee_ref.type_id()),
           .lvalue_id = scrutinee_ref_id});
      AddWork(
          {.pattern_id = addr_pattern.inner_id, .scrutinee_id = new_scrutinee});
      break;
    }
    case CARBON_KIND(SemIR::ValueParamPattern param_pattern): {
      // Runtime parameters are processed in order, so an already-assigned
      // index must equal the number of results emitted so far.
      CARBON_CHECK(param_pattern.runtime_index.index < 0 ||
                       static_cast<size_t>(param_pattern.runtime_index.index) ==
                           results_.size(),
                   "Parameters out of order; expecting {0} but got {1}",
                   results_.size(), param_pattern.runtime_index.index);
      switch (kind_) {
        case MatchKind::Caller: {
          CARBON_CHECK(entry.scrutinee_id.is_valid());
          if (entry.scrutinee_id == SemIR::InstId::BuiltinErrorInst) {
            results_.push_back(SemIR::InstId::BuiltinErrorInst);
          } else {
            // Convert the argument to the parameter's type, resolved within
            // the callee's specific (if any).
            results_.push_back(ConvertToValueOfType(
                context, context.insts().GetLocId(entry.scrutinee_id),
                entry.scrutinee_id,
                SemIR::GetTypeInSpecific(context.sem_ir(), callee_specific_id_,
                                         param_pattern.type_id)));
          }
          // Do not traverse farther, because the caller side of the pattern
          // ends here.
          break;
        }
        case MatchKind::Callee: {
          // Allocate the runtime index lazily, on first callee-side use.
          if (param_pattern.runtime_index ==
              SemIR::RuntimeParamIndex::Unknown) {
            param_pattern.runtime_index = NextRuntimeIndex();
            context.ReplaceInstBeforeConstantUse(entry.pattern_id,
                                                 param_pattern);
          }
          // Emit the calling-convention parameter inst and match the
          // subpattern against it.
          AddWork(
              {.pattern_id = param_pattern.subpattern_id,
               .scrutinee_id = context.AddInst<SemIR::ValueParam>(
                   pattern.loc_id,
                   {.type_id = param_pattern.type_id,
                    .runtime_index = param_pattern.runtime_index,
                    .pretty_name_id = GetPrettyName(context, param_pattern)})});
          break;
        }
      }
      break;
    }
    case CARBON_KIND(SemIR::OutParamPattern param_pattern): {
      switch (kind_) {
        case MatchKind::Caller: {
          CARBON_CHECK(entry.scrutinee_id.is_valid());
          // Output arguments are materialized by the caller with the exact
          // parameter type, so no conversion is performed here.
          CARBON_CHECK(context.insts().Get(entry.scrutinee_id).type_id() ==
                       SemIR::GetTypeInSpecific(context.sem_ir(),
                                                callee_specific_id_,
                                                param_pattern.type_id));
          results_.push_back(entry.scrutinee_id);
          // Do not traverse farther, because the caller side of the pattern
          // ends here.
          break;
        }
        case MatchKind::Callee: {
          // TODO: Consider ways to address near-duplication with the
          // ValueParamPattern case.
          if (param_pattern.runtime_index ==
              SemIR::RuntimeParamIndex::Unknown) {
            param_pattern.runtime_index = NextRuntimeIndex();
            context.ReplaceInstBeforeConstantUse(entry.pattern_id,
                                                 param_pattern);
          }
          AddWork(
              {.pattern_id = param_pattern.subpattern_id,
               .scrutinee_id = context.AddInst<SemIR::OutParam>(
                   pattern.loc_id,
                   {.type_id = param_pattern.type_id,
                    .runtime_index = param_pattern.runtime_index,
                    .pretty_name_id = GetPrettyName(context, param_pattern)})});
          break;
        }
      }
      break;
    }
    case CARBON_KIND(SemIR::ReturnSlotPattern return_slot_pattern): {
      CARBON_CHECK(kind_ == MatchKind::Callee);
      // Record the emitted return slot so the caller of `DoWork` can retrieve
      // it via `return_slot_id()`.
      return_slot_id_ = context.AddInst<SemIR::ReturnSlot>(
          pattern.loc_id, {.type_id = return_slot_pattern.type_id,
                           .type_inst_id = return_slot_pattern.type_inst_id,
                           .storage_id = entry.scrutinee_id});
      results_.push_back(entry.scrutinee_id);
      break;
    }
    default: {
      CARBON_FATAL("Inst kind not handled: {0}", pattern.inst.kind());
    }
  }
}
  270. auto CalleePatternMatch(Context& context,
  271. SemIR::InstBlockId implicit_param_patterns_id,
  272. SemIR::InstBlockId param_patterns_id,
  273. SemIR::InstId return_slot_pattern_id)
  274. -> ParameterBlocks {
  275. if (!return_slot_pattern_id.is_valid() && !param_patterns_id.is_valid() &&
  276. !implicit_param_patterns_id.is_valid()) {
  277. return {.call_params_id = SemIR::InstBlockId::Invalid,
  278. .return_slot_id = SemIR::InstId::Invalid};
  279. }
  280. MatchContext match(MatchKind::Callee);
  281. // We add work to the stack in reverse so that the results will be produced
  282. // in the original order.
  283. if (return_slot_pattern_id.is_valid()) {
  284. match.AddWork({.pattern_id = return_slot_pattern_id,
  285. .scrutinee_id = SemIR::InstId::Invalid});
  286. }
  287. if (param_patterns_id.is_valid()) {
  288. for (SemIR::InstId inst_id :
  289. llvm::reverse(context.inst_blocks().Get(param_patterns_id))) {
  290. match.AddWork(
  291. {.pattern_id = inst_id, .scrutinee_id = SemIR::InstId::Invalid});
  292. }
  293. }
  294. if (implicit_param_patterns_id.is_valid()) {
  295. for (SemIR::InstId inst_id :
  296. llvm::reverse(context.inst_blocks().Get(implicit_param_patterns_id))) {
  297. match.AddWork(
  298. {.pattern_id = inst_id, .scrutinee_id = SemIR::InstId::Invalid});
  299. }
  300. }
  301. return {.call_params_id = match.DoWork(context),
  302. .return_slot_id = match.return_slot_id()};
  303. }
  304. auto CallerPatternMatch(Context& context, SemIR::SpecificId specific_id,
  305. SemIR::InstId self_pattern_id,
  306. SemIR::InstBlockId param_patterns_id,
  307. SemIR::InstId return_slot_pattern_id,
  308. SemIR::InstId self_arg_id,
  309. llvm::ArrayRef<SemIR::InstId> arg_refs,
  310. SemIR::InstId return_slot_arg_id)
  311. -> SemIR::InstBlockId {
  312. MatchContext match(MatchKind::Caller, specific_id);
  313. // Track the return storage, if present.
  314. if (return_slot_arg_id.is_valid()) {
  315. CARBON_CHECK(return_slot_pattern_id.is_valid());
  316. match.AddWork({.pattern_id = return_slot_pattern_id,
  317. .scrutinee_id = return_slot_arg_id});
  318. }
  319. // Check type conversions per-element.
  320. for (auto [arg_id, param_pattern_id] : llvm::reverse(llvm::zip_equal(
  321. arg_refs, context.inst_blocks().GetOrEmpty(param_patterns_id)))) {
  322. auto runtime_index = SemIR::Function::GetParamPatternInfoFromPatternId(
  323. context.sem_ir(), param_pattern_id)
  324. .inst.runtime_index;
  325. if (!runtime_index.is_valid()) {
  326. // Not a runtime parameter: we don't pass an argument.
  327. continue;
  328. }
  329. match.AddWork({.pattern_id = param_pattern_id, .scrutinee_id = arg_id});
  330. }
  331. if (self_pattern_id.is_valid()) {
  332. match.AddWork({.pattern_id = self_pattern_id, .scrutinee_id = self_arg_id});
  333. }
  334. return match.DoWork(context);
  335. }
  336. } // namespace Carbon::Check