deduce.cpp 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458
  1. // Part of the Carbon Language project, under the Apache License v2.0 with LLVM
  2. // Exceptions. See /LICENSE for license information.
  3. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  4. #include "toolchain/check/deduce.h"
  5. #include "toolchain/base/kind_switch.h"
  6. #include "toolchain/check/context.h"
  7. #include "toolchain/check/convert.h"
  8. #include "toolchain/check/generic.h"
  9. #include "toolchain/check/subst.h"
  10. #include "toolchain/sem_ir/ids.h"
  11. #include "toolchain/sem_ir/impl.h"
  12. #include "toolchain/sem_ir/typed_insts.h"
  13. namespace Carbon::Check {
  14. namespace {
  15. // A list of pairs of (instruction from generic, corresponding instruction from
  16. // call to of generic) for which we still need to perform deduction, along with
  17. // methods to add and pop pending deductions from the list. Deductions are
  18. // popped in order from most- to least-recently pushed, with the intent that
  19. // they are visited in depth-first order, although the order is not expected to
  20. // matter except when it influences which error is diagnosed.
  21. class DeductionWorklist {
  22. public:
  23. explicit DeductionWorklist(Context& context) : context_(context) {}
  24. struct PendingDeduction {
  25. SemIR::InstId param;
  26. SemIR::InstId arg;
  27. bool needs_substitution;
  28. };
  29. // Adds a single (param, arg) deduction.
  30. auto Add(SemIR::InstId param, SemIR::InstId arg, bool needs_substitution)
  31. -> void {
  32. deductions_.push_back(
  33. {.param = param, .arg = arg, .needs_substitution = needs_substitution});
  34. }
  35. // Adds a single (param, arg) type deduction.
  36. auto Add(SemIR::TypeId param, SemIR::TypeId arg, bool needs_substitution)
  37. -> void {
  38. Add(context_.types().GetInstId(param), context_.types().GetInstId(arg),
  39. needs_substitution);
  40. }
  41. // Adds a single (param, arg) deduction of a specific.
  42. auto Add(SemIR::SpecificId param, SemIR::SpecificId arg,
  43. bool needs_substitution) -> void {
  44. auto& param_specific = context_.specifics().Get(param);
  45. auto& arg_specific = context_.specifics().Get(arg);
  46. if (param_specific.generic_id != arg_specific.generic_id) {
  47. // TODO: Decide whether to error on this or just treat the specific as
  48. // non-deduced. For now we treat it as non-deduced.
  49. return;
  50. }
  51. AddAll(param_specific.args_id, arg_specific.args_id, needs_substitution);
  52. }
  53. // Adds a list of (param, arg) deductions. These are added in reverse order so
  54. // they are popped in forward order.
  55. template <typename ElementId>
  56. auto AddAll(llvm::ArrayRef<ElementId> params, llvm::ArrayRef<ElementId> args,
  57. bool needs_substitution) -> void {
  58. if (params.size() != args.size()) {
  59. // TODO: Decide whether to error on this or just treat the parameter list
  60. // as non-deduced. For now we treat it as non-deduced.
  61. return;
  62. }
  63. for (auto [param, arg] : llvm::reverse(llvm::zip_equal(params, args))) {
  64. Add(param, arg, needs_substitution);
  65. }
  66. }
  67. auto AddAll(SemIR::InstBlockId params, llvm::ArrayRef<SemIR::InstId> args,
  68. bool needs_substitution) -> void {
  69. AddAll(context_.inst_blocks().Get(params), args, needs_substitution);
  70. }
  71. auto AddAll(SemIR::InstBlockId params, SemIR::InstBlockId args,
  72. bool needs_substitution) -> void {
  73. AddAll(context_.inst_blocks().Get(params), context_.inst_blocks().Get(args),
  74. needs_substitution);
  75. }
  76. auto AddAll(SemIR::TypeBlockId params, SemIR::TypeBlockId args,
  77. bool needs_substitution) -> void {
  78. AddAll(context_.type_blocks().Get(params), context_.type_blocks().Get(args),
  79. needs_substitution);
  80. }
  81. // Adds a (param, arg) pair for an instruction argument, given its kind.
  82. auto AddInstArg(SemIR::IdKind kind, int32_t param, int32_t arg,
  83. bool needs_substitution) -> void {
  84. switch (kind) {
  85. case SemIR::IdKind::None:
  86. case SemIR::IdKind::For<SemIR::ClassId>:
  87. case SemIR::IdKind::For<SemIR::InterfaceId>:
  88. case SemIR::IdKind::For<SemIR::IntKind>:
  89. break;
  90. case SemIR::IdKind::For<SemIR::InstId>:
  91. Add(SemIR::InstId(param), SemIR::InstId(arg), needs_substitution);
  92. break;
  93. case SemIR::IdKind::For<SemIR::TypeId>:
  94. Add(SemIR::TypeId(param), SemIR::TypeId(arg), needs_substitution);
  95. break;
  96. case SemIR::IdKind::For<SemIR::InstBlockId>:
  97. AddAll(SemIR::InstBlockId(param), SemIR::InstBlockId(arg),
  98. needs_substitution);
  99. break;
  100. case SemIR::IdKind::For<SemIR::TypeBlockId>:
  101. AddAll(SemIR::TypeBlockId(param), SemIR::TypeBlockId(arg),
  102. needs_substitution);
  103. break;
  104. case SemIR::IdKind::For<SemIR::SpecificId>:
  105. Add(SemIR::SpecificId(param), SemIR::SpecificId(arg),
  106. needs_substitution);
  107. break;
  108. default:
  109. CARBON_FATAL("unexpected argument kind");
  110. }
  111. }
  112. // Returns whether we have completed all deductions.
  113. auto Done() -> bool { return deductions_.empty(); }
  114. // Pops the next deduction. Requires `!Done()`.
  115. auto PopNext() -> PendingDeduction { return deductions_.pop_back_val(); }
  116. private:
  117. Context& context_;
  118. llvm::SmallVector<PendingDeduction> deductions_;
  119. };
  120. // State that is tracked throughout the deduction process.
  121. class DeductionContext {
  122. public:
  123. // Preparse to perform deduction. If an enclosing specific is provided, adds
  124. // the arguments from the given specific as known arguments that will not be
  125. // deduced.
  126. DeductionContext(Context& context, SemIR::LocId loc_id,
  127. SemIR::GenericId generic_id,
  128. SemIR::SpecificId enclosing_specific_id, bool diagnose);
  129. auto context() const -> Context& { return *context_; }
  130. // Adds a pending deduction of `param` from `arg`. `needs_substitution`
  131. // indicates whether we need to substitute known generic parameters into
  132. // `param`.
  133. template <typename ParamT, typename ArgT>
  134. auto Add(ParamT param, ArgT arg, bool needs_substitution) -> void {
  135. worklist_.Add(param, arg, needs_substitution);
  136. }
  137. // Same as `Add` but for an array or block of operands.
  138. template <typename ParamT, typename ArgT>
  139. auto AddAll(ParamT param, ArgT arg, bool needs_substitution) -> void {
  140. worklist_.AddAll(param, arg, needs_substitution);
  141. }
  142. // Performs all deductions in the deduction worklist. Returns whether
  143. // deduction succeeded.
  144. auto Deduce() -> bool;
  145. // Returns whether every generic parameter has a corresponding deduced generic
  146. // argument. If not, issues a suitable diagnostic.
  147. auto CheckDeductionIsComplete() -> bool;
  148. // Forms a specific corresponding to the deduced generic with the deduced
  149. // argument list. Must not be called before deduction is complete.
  150. auto MakeSpecific() -> SemIR::SpecificId;
  151. private:
  152. Context* context_;
  153. SemIR::LocId loc_id_;
  154. SemIR::GenericId generic_id_;
  155. bool diagnose_;
  156. DeductionWorklist worklist_;
  157. llvm::SmallVector<SemIR::InstId> result_arg_ids_;
  158. llvm::SmallVector<Substitution> substitutions_;
  159. SemIR::CompileTimeBindIndex first_deduced_index_;
  160. };
  161. } // namespace
  162. static auto NoteGenericHere(Context& context, SemIR::GenericId generic_id,
  163. Context::DiagnosticBuilder& diag) -> void {
  164. CARBON_DIAGNOSTIC(DeductionGenericHere, Note,
  165. "while deducing parameters of generic declared here");
  166. diag.Note(context.generics().Get(generic_id).decl_id, DeductionGenericHere);
  167. }
  168. DeductionContext::DeductionContext(Context& context, SemIR::LocId loc_id,
  169. SemIR::GenericId generic_id,
  170. SemIR::SpecificId enclosing_specific_id,
  171. bool diagnose)
  172. : context_(&context),
  173. loc_id_(loc_id),
  174. generic_id_(generic_id),
  175. diagnose_(diagnose),
  176. worklist_(context),
  177. first_deduced_index_(0) {
  178. CARBON_CHECK(generic_id.is_valid(),
  179. "Performing deduction for non-generic entity");
  180. // Initialize the deduced arguments to Invalid.
  181. result_arg_ids_.resize(
  182. context.inst_blocks()
  183. .Get(context.generics().Get(generic_id_).bindings_id)
  184. .size(),
  185. SemIR::InstId::Invalid);
  186. if (enclosing_specific_id.is_valid()) {
  187. // Copy any outer generic arguments from the specified instance and prepare
  188. // to substitute them into the function declaration.
  189. auto args = context.inst_blocks().Get(
  190. context.specifics().Get(enclosing_specific_id).args_id);
  191. std::copy(args.begin(), args.end(), result_arg_ids_.begin());
  192. // TODO: Subst is linear in the length of the substitutions list. Change
  193. // it so we can pass in an array mapping indexes to substitutions instead.
  194. substitutions_.reserve(args.size());
  195. for (auto [i, subst_inst_id] : llvm::enumerate(args)) {
  196. substitutions_.push_back(
  197. {.bind_id = SemIR::CompileTimeBindIndex(i),
  198. .replacement_id = context.constant_values().Get(subst_inst_id)});
  199. }
  200. first_deduced_index_ = SemIR::CompileTimeBindIndex(args.size());
  201. }
  202. }
  203. auto DeductionContext::Deduce() -> bool {
  204. while (!worklist_.Done()) {
  205. auto [param_id, arg_id, needs_substitution] = worklist_.PopNext();
  206. // If the parameter has a symbolic type, deduce against that.
  207. auto param_type_id = context().insts().Get(param_id).type_id();
  208. if (param_type_id.AsConstantId().is_symbolic()) {
  209. Add(context().types().GetInstId(param_type_id),
  210. context().types().GetInstId(context().insts().Get(arg_id).type_id()),
  211. needs_substitution);
  212. } else {
  213. // The argument needs to have the same type as the parameter.
  214. // TODO: Suppress diagnostics here if diagnose_ is false.
  215. DiagnosticAnnotationScope annotate_diagnostics(
  216. &context().emitter(), [&](auto& builder) {
  217. if (auto param =
  218. context().insts().TryGetAs<SemIR::BindSymbolicName>(
  219. param_id)) {
  220. CARBON_DIAGNOSTIC(
  221. InitializingGenericParam, Note,
  222. "initializing generic parameter `{0}` declared here",
  223. SemIR::NameId);
  224. builder.Note(
  225. param_id, InitializingGenericParam,
  226. context().entity_names().Get(param->entity_name_id).name_id);
  227. }
  228. });
  229. arg_id = ConvertToValueOfType(context(), loc_id_, arg_id, param_type_id);
  230. if (arg_id == SemIR::InstId::BuiltinError) {
  231. return false;
  232. }
  233. }
  234. // If the parameter is a symbolic constant, deduce against it. Otherwise, we
  235. // assume there is nothing to deduce.
  236. // TODO: This won't do the right thing in a template deduction.
  237. auto param_const_id = context().constant_values().Get(param_id);
  238. if (!param_const_id.is_valid() || !param_const_id.is_symbolic()) {
  239. continue;
  240. }
  241. // Attempt to match `param_inst` against `arg_id`. If the match succeeds,
  242. // this should `continue` the outer loop. On `break`, we will try to desugar
  243. // the parameter to continue looking for a match.
  244. auto param_inst = context().insts().Get(
  245. context().constant_values().GetInstId(param_const_id));
  246. CARBON_KIND_SWITCH(param_inst) {
  247. // Deducing a symbolic binding from an argument with a constant value
  248. // deduces the binding as having that constant value.
  249. case SemIR::InstKind::SymbolicBindingPattern:
  250. case SemIR::InstKind::BindSymbolicName: {
  251. auto entity_name_id = SemIR::EntityNameId::Invalid;
  252. if (auto bind = param_inst.TryAs<SemIR::SymbolicBindingPattern>()) {
  253. entity_name_id = bind->entity_name_id;
  254. } else {
  255. entity_name_id =
  256. param_inst.As<SemIR::BindSymbolicName>().entity_name_id;
  257. }
  258. auto& entity_name = context().entity_names().Get(entity_name_id);
  259. auto index = entity_name.bind_index;
  260. if (!index.is_valid() || index < first_deduced_index_) {
  261. break;
  262. }
  263. CARBON_CHECK(static_cast<size_t>(index.index) < result_arg_ids_.size(),
  264. "Deduced value for unexpected index {0}; expected to "
  265. "deduce {1} arguments.",
  266. index, result_arg_ids_.size());
  267. auto arg_const_inst_id =
  268. context().constant_values().GetConstantInstId(arg_id);
  269. if (arg_const_inst_id.is_valid()) {
  270. if (result_arg_ids_[index.index].is_valid() &&
  271. result_arg_ids_[index.index] != arg_const_inst_id) {
  272. if (diagnose_) {
  273. // TODO: Include the two different deduced values.
  274. CARBON_DIAGNOSTIC(DeductionInconsistent, Error,
  275. "inconsistent deductions for value of generic "
  276. "parameter `{0}`",
  277. SemIR::NameId);
  278. auto diag = context().emitter().Build(
  279. loc_id_, DeductionInconsistent, entity_name.name_id);
  280. NoteGenericHere(context(), generic_id_, diag);
  281. diag.Emit();
  282. }
  283. return false;
  284. }
  285. result_arg_ids_[index.index] = arg_const_inst_id;
  286. }
  287. continue;
  288. }
  289. // Various kinds of parameter should match an argument of the same form,
  290. // if the operands all match.
  291. case SemIR::ArrayType::Kind:
  292. case SemIR::ClassType::Kind:
  293. case SemIR::ConstType::Kind:
  294. case SemIR::FloatType::Kind:
  295. case SemIR::InterfaceType::Kind:
  296. case SemIR::IntType::Kind:
  297. case SemIR::PointerType::Kind:
  298. case SemIR::TupleType::Kind:
  299. case SemIR::TupleValue::Kind: {
  300. auto arg_inst = context().insts().Get(arg_id);
  301. if (arg_inst.kind() != param_inst.kind()) {
  302. break;
  303. }
  304. auto [kind0, kind1] = param_inst.ArgKinds();
  305. worklist_.AddInstArg(kind0, param_inst.arg0(), arg_inst.arg0(),
  306. needs_substitution);
  307. worklist_.AddInstArg(kind1, param_inst.arg1(), arg_inst.arg1(),
  308. needs_substitution);
  309. continue;
  310. }
  311. case SemIR::StructType::Kind:
  312. case SemIR::StructValue::Kind:
  313. // TODO: Match field name order between param and arg.
  314. break;
  315. // TODO: Handle more cases.
  316. default:
  317. break;
  318. }
  319. // If we've not yet substituted into the parameter, do so now and try again.
  320. if (needs_substitution) {
  321. param_const_id = SubstConstant(context(), param_const_id, substitutions_);
  322. if (!param_const_id.is_valid() || !param_const_id.is_symbolic()) {
  323. continue;
  324. }
  325. Add(context().constant_values().GetInstId(param_const_id), arg_id,
  326. /*needs_substitution=*/false);
  327. }
  328. }
  329. return true;
  330. }
  331. auto DeductionContext::CheckDeductionIsComplete() -> bool {
  332. // Check we deduced an argument value for every parameter.
  333. for (auto [i, deduced_arg_id] :
  334. llvm::enumerate(llvm::ArrayRef(result_arg_ids_)
  335. .drop_front(first_deduced_index_.index))) {
  336. if (!deduced_arg_id.is_valid()) {
  337. if (diagnose_) {
  338. auto binding_index = first_deduced_index_.index + i;
  339. auto binding_id = context().inst_blocks().Get(
  340. context().generics().Get(generic_id_).bindings_id)[binding_index];
  341. auto entity_name_id = context()
  342. .insts()
  343. .GetAs<SemIR::AnyBindName>(binding_id)
  344. .entity_name_id;
  345. CARBON_DIAGNOSTIC(DeductionIncomplete, Error,
  346. "cannot deduce value for generic parameter `{0}`",
  347. SemIR::NameId);
  348. auto diag = context().emitter().Build(
  349. loc_id_, DeductionIncomplete,
  350. context().entity_names().Get(entity_name_id).name_id);
  351. NoteGenericHere(context(), generic_id_, diag);
  352. diag.Emit();
  353. }
  354. return false;
  355. }
  356. }
  357. return true;
  358. }
  359. auto DeductionContext::MakeSpecific() -> SemIR::SpecificId {
  360. // TODO: Convert the deduced values to the types of the bindings.
  361. return Check::MakeSpecific(
  362. context(), generic_id_,
  363. context().inst_blocks().AddCanonical(result_arg_ids_));
  364. }
  365. auto DeduceGenericCallArguments(
  366. Context& context, SemIR::LocId loc_id, SemIR::GenericId generic_id,
  367. SemIR::SpecificId enclosing_specific_id,
  368. [[maybe_unused]] SemIR::InstBlockId implicit_params_id,
  369. SemIR::InstBlockId params_id, [[maybe_unused]] SemIR::InstId self_id,
  370. llvm::ArrayRef<SemIR::InstId> arg_ids) -> SemIR::SpecificId {
  371. DeductionContext deduction(context, loc_id, generic_id, enclosing_specific_id,
  372. /*diagnose=*/true);
  373. // Prepare to perform deduction of the explicit parameters against their
  374. // arguments.
  375. // TODO: Also perform deduction for type of self.
  376. deduction.AddAll(params_id, arg_ids, /*needs_substitution=*/true);
  377. if (!deduction.Deduce() || !deduction.CheckDeductionIsComplete()) {
  378. return SemIR::SpecificId::Invalid;
  379. }
  380. return deduction.MakeSpecific();
  381. }
  382. // Deduces the impl arguments to use in a use of a parameterized impl. Returns
  383. // `Invalid` if deduction fails.
  384. auto DeduceImplArguments(Context& context, SemIR::LocId loc_id,
  385. const SemIR::Impl& impl, SemIR::ConstantId self_id,
  386. SemIR::ConstantId constraint_id) -> SemIR::SpecificId {
  387. DeductionContext deduction(
  388. context, loc_id, impl.generic_id,
  389. /*enclosing_specific_id=*/SemIR::SpecificId::Invalid,
  390. /*diagnose=*/false);
  391. // Prepare to perform deduction of the type and interface.
  392. deduction.Add(impl.self_id, context.constant_values().GetInstId(self_id),
  393. /*needs_substitution=*/false);
  394. deduction.Add(impl.constraint_id,
  395. context.constant_values().GetInstId(constraint_id),
  396. /*needs_substitution=*/false);
  397. if (!deduction.Deduce() || !deduction.CheckDeductionIsComplete()) {
  398. return SemIR::SpecificId::Invalid;
  399. }
  400. return deduction.MakeSpecific();
  401. }
  402. } // namespace Carbon::Check