deduce.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358
  1. // Part of the Carbon Language project, under the Apache License v2.0 with LLVM
  2. // Exceptions. See /LICENSE for license information.
  3. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  4. #include "toolchain/check/deduce.h"
  5. #include "toolchain/base/kind_switch.h"
  6. #include "toolchain/check/context.h"
  7. #include "toolchain/check/convert.h"
  8. #include "toolchain/check/generic.h"
  9. #include "toolchain/check/subst.h"
  10. #include "toolchain/sem_ir/ids.h"
  11. #include "toolchain/sem_ir/impl.h"
  12. #include "toolchain/sem_ir/typed_insts.h"
  13. namespace Carbon::Check {
  14. namespace {
  15. // A list of pairs of (instruction from generic, corresponding instruction from
  16. // call to of generic) for which we still need to perform deduction, along with
  17. // methods to add and pop pending deductions from the list. Deductions are
  18. // popped in order from most- to least-recently pushed, with the intent that
  19. // they are visited in depth-first order, although the order is not expected to
  20. // matter except when it influences which error is diagnosed.
  21. class DeductionWorklist {
  22. public:
  23. explicit DeductionWorklist(Context& context) : context_(context) {}
  24. struct PendingDeduction {
  25. SemIR::InstId param;
  26. SemIR::InstId arg;
  27. bool needs_substitution;
  28. };
  29. // Adds a single (param, arg) deduction.
  30. auto Add(SemIR::InstId param, SemIR::InstId arg, bool needs_substitution)
  31. -> void {
  32. deductions_.push_back(
  33. {.param = param, .arg = arg, .needs_substitution = needs_substitution});
  34. }
  35. // Adds a list of (param, arg) deductions. These are added in reverse order so
  36. // they are popped in forward order.
  37. auto AddAll(llvm::ArrayRef<SemIR::InstId> params,
  38. llvm::ArrayRef<SemIR::InstId> args, bool needs_substitution)
  39. -> void {
  40. if (params.size() != args.size()) {
  41. // TODO: Decide whether to error on this or just treat the parameter list
  42. // as non-deduced. For now we treat it as non-deduced.
  43. return;
  44. }
  45. for (auto [param, arg] : llvm::reverse(llvm::zip_equal(params, args))) {
  46. Add(param, arg, needs_substitution);
  47. }
  48. }
  49. auto AddAll(SemIR::InstBlockId params, llvm::ArrayRef<SemIR::InstId> args,
  50. bool needs_substitution) -> void {
  51. AddAll(context_.inst_blocks().Get(params), args, needs_substitution);
  52. }
  53. auto AddAll(SemIR::InstBlockId params, SemIR::InstBlockId args,
  54. bool needs_substitution) -> void {
  55. AddAll(context_.inst_blocks().Get(params), context_.inst_blocks().Get(args),
  56. needs_substitution);
  57. }
  58. // Returns whether we have completed all deductions.
  59. auto Done() -> bool { return deductions_.empty(); }
  60. // Pops the next deduction. Requires `!Done()`.
  61. auto PopNext() -> PendingDeduction { return deductions_.pop_back_val(); }
  62. private:
  63. Context& context_;
  64. llvm::SmallVector<PendingDeduction> deductions_;
  65. };
  66. // State that is tracked throughout the deduction process.
  67. class DeductionContext {
  68. public:
  69. // Preparse to perform deduction. If an enclosing specific is provided, adds
  70. // the arguments from the given specific as known arguments that will not be
  71. // deduced.
  72. DeductionContext(Context& context, SemIR::LocId loc_id,
  73. SemIR::GenericId generic_id,
  74. SemIR::SpecificId enclosing_specific_id, bool diagnose);
  75. auto context() const -> Context& { return *context_; }
  76. // Adds a pending deduction of `param` from `arg`. `needs_substitution`
  77. // indicates whether we need to substitute known generic parameters into
  78. // `param`.
  79. template <typename ParamT, typename ArgT>
  80. auto Add(ParamT param, ArgT arg, bool needs_substitution) -> void {
  81. worklist_.Add(param, arg, needs_substitution);
  82. }
  83. // Same as `Add` but for an array or block of operands.
  84. template <typename ParamT, typename ArgT>
  85. auto AddAll(ParamT param, ArgT arg, bool needs_substitution) -> void {
  86. worklist_.AddAll(param, arg, needs_substitution);
  87. }
  88. // Performs all deductions in the deduction worklist. Returns whether
  89. // deduction succeeded.
  90. auto Deduce() -> bool;
  91. // Returns whether every generic parameter has a corresponding deduced generic
  92. // argument. If not, issues a suitable diagnostic.
  93. auto CheckDeductionIsComplete() -> bool;
  94. // Forms a specific corresponding to the deduced generic with the deduced
  95. // argument list. Must not be called before deduction is complete.
  96. auto MakeSpecific() -> SemIR::SpecificId;
  97. private:
  98. Context* context_;
  99. SemIR::LocId loc_id_;
  100. SemIR::GenericId generic_id_;
  101. bool diagnose_;
  102. DeductionWorklist worklist_;
  103. llvm::SmallVector<SemIR::InstId> result_arg_ids_;
  104. llvm::SmallVector<Substitution> substitutions_;
  105. SemIR::CompileTimeBindIndex first_deduced_index_;
  106. };
  107. } // namespace
  108. static auto NoteGenericHere(Context& context, SemIR::GenericId generic_id,
  109. Context::DiagnosticBuilder& diag) -> void {
  110. CARBON_DIAGNOSTIC(DeductionGenericHere, Note,
  111. "while deducing parameters of generic declared here");
  112. diag.Note(context.generics().Get(generic_id).decl_id, DeductionGenericHere);
  113. }
  114. DeductionContext::DeductionContext(Context& context, SemIR::LocId loc_id,
  115. SemIR::GenericId generic_id,
  116. SemIR::SpecificId enclosing_specific_id,
  117. bool diagnose)
  118. : context_(&context),
  119. loc_id_(loc_id),
  120. generic_id_(generic_id),
  121. diagnose_(diagnose),
  122. worklist_(context),
  123. first_deduced_index_(0) {
  124. CARBON_CHECK(generic_id.is_valid(),
  125. "Performing deduction for non-generic entity");
  126. // Initialize the deduced arguments to Invalid.
  127. result_arg_ids_.resize(
  128. context.inst_blocks()
  129. .Get(context.generics().Get(generic_id_).bindings_id)
  130. .size(),
  131. SemIR::InstId::Invalid);
  132. if (enclosing_specific_id.is_valid()) {
  133. // Copy any outer generic arguments from the specified instance and prepare
  134. // to substitute them into the function declaration.
  135. auto args = context.inst_blocks().Get(
  136. context.specifics().Get(enclosing_specific_id).args_id);
  137. std::copy(args.begin(), args.end(), result_arg_ids_.begin());
  138. // TODO: Subst is linear in the length of the substitutions list. Change
  139. // it so we can pass in an array mapping indexes to substitutions instead.
  140. substitutions_.reserve(args.size());
  141. for (auto [i, subst_inst_id] : llvm::enumerate(args)) {
  142. substitutions_.push_back(
  143. {.bind_id = SemIR::CompileTimeBindIndex(i),
  144. .replacement_id = context.constant_values().Get(subst_inst_id)});
  145. }
  146. first_deduced_index_ = SemIR::CompileTimeBindIndex(args.size());
  147. }
  148. }
  149. auto DeductionContext::Deduce() -> bool {
  150. while (!worklist_.Done()) {
  151. auto [param_id, arg_id, needs_substitution] = worklist_.PopNext();
  152. // If the parameter has a symbolic type, deduce against that.
  153. auto param_type_id = context().insts().Get(param_id).type_id();
  154. if (param_type_id.AsConstantId().is_symbolic()) {
  155. Add(context().types().GetInstId(param_type_id),
  156. context().types().GetInstId(context().insts().Get(arg_id).type_id()),
  157. needs_substitution);
  158. } else {
  159. // The argument needs to have the same type as the parameter.
  160. // TODO: Suppress diagnostics here if diagnose_ is false.
  161. DiagnosticAnnotationScope annotate_diagnostics(
  162. &context().emitter(), [&](auto& builder) {
  163. if (auto param =
  164. context().insts().TryGetAs<SemIR::BindSymbolicName>(
  165. param_id)) {
  166. CARBON_DIAGNOSTIC(
  167. InitializingGenericParam, Note,
  168. "initializing generic parameter `{0}` declared here",
  169. SemIR::NameId);
  170. builder.Note(
  171. param_id, InitializingGenericParam,
  172. context().entity_names().Get(param->entity_name_id).name_id);
  173. }
  174. });
  175. arg_id = ConvertToValueOfType(context(), loc_id_, arg_id, param_type_id);
  176. if (arg_id == SemIR::InstId::BuiltinError) {
  177. return false;
  178. }
  179. }
  180. // If the parameter is a symbolic constant, deduce against it.
  181. auto param_const_id = context().constant_values().Get(param_id);
  182. if (!param_const_id.is_valid() || !param_const_id.is_symbolic()) {
  183. continue;
  184. }
  185. // If we've not yet substituted into the parameter, do so now.
  186. if (needs_substitution) {
  187. param_const_id = SubstConstant(context(), param_const_id, substitutions_);
  188. if (!param_const_id.is_valid() || !param_const_id.is_symbolic()) {
  189. continue;
  190. }
  191. needs_substitution = false;
  192. }
  193. CARBON_KIND_SWITCH(context().insts().Get(
  194. context().constant_values().GetInstId(
  195. param_const_id))) {
  196. // Deducing a symbolic binding from an argument with a constant value
  197. // deduces the binding as having that constant value.
  198. case CARBON_KIND(SemIR::BindSymbolicName bind): {
  199. auto& entity_name = context().entity_names().Get(bind.entity_name_id);
  200. auto index = entity_name.bind_index;
  201. if (index.is_valid() && index >= first_deduced_index_) {
  202. CARBON_CHECK(
  203. static_cast<size_t>(index.index) < result_arg_ids_.size(),
  204. "Deduced value for unexpected index {0}; expected to "
  205. "deduce {1} arguments.",
  206. index, result_arg_ids_.size());
  207. auto arg_const_inst_id =
  208. context().constant_values().GetConstantInstId(arg_id);
  209. if (arg_const_inst_id.is_valid()) {
  210. if (result_arg_ids_[index.index].is_valid() &&
  211. result_arg_ids_[index.index] != arg_const_inst_id) {
  212. if (diagnose_) {
  213. // TODO: Include the two different deduced values.
  214. CARBON_DIAGNOSTIC(
  215. DeductionInconsistent, Error,
  216. "inconsistent deductions for value of generic "
  217. "parameter `{0}`",
  218. SemIR::NameId);
  219. auto diag = context().emitter().Build(
  220. loc_id_, DeductionInconsistent, entity_name.name_id);
  221. NoteGenericHere(context(), generic_id_, diag);
  222. diag.Emit();
  223. }
  224. return false;
  225. }
  226. result_arg_ids_[index.index] = arg_const_inst_id;
  227. }
  228. }
  229. break;
  230. }
  231. // TODO: Handle more cases.
  232. default:
  233. break;
  234. }
  235. }
  236. return true;
  237. }
  238. auto DeductionContext::CheckDeductionIsComplete() -> bool {
  239. // Check we deduced an argument value for every parameter.
  240. for (auto [i, deduced_arg_id] :
  241. llvm::enumerate(llvm::ArrayRef(result_arg_ids_)
  242. .drop_front(first_deduced_index_.index))) {
  243. if (!deduced_arg_id.is_valid()) {
  244. if (diagnose_) {
  245. auto binding_index = first_deduced_index_.index + i;
  246. auto binding_id = context().inst_blocks().Get(
  247. context().generics().Get(generic_id_).bindings_id)[binding_index];
  248. auto entity_name_id = context()
  249. .insts()
  250. .GetAs<SemIR::AnyBindName>(binding_id)
  251. .entity_name_id;
  252. CARBON_DIAGNOSTIC(DeductionIncomplete, Error,
  253. "cannot deduce value for generic parameter `{0}`",
  254. SemIR::NameId);
  255. auto diag = context().emitter().Build(
  256. loc_id_, DeductionIncomplete,
  257. context().entity_names().Get(entity_name_id).name_id);
  258. NoteGenericHere(context(), generic_id_, diag);
  259. diag.Emit();
  260. }
  261. return false;
  262. }
  263. }
  264. return true;
  265. }
  266. auto DeductionContext::MakeSpecific() -> SemIR::SpecificId {
  267. // TODO: Convert the deduced values to the types of the bindings.
  268. return Check::MakeSpecific(
  269. context(), generic_id_,
  270. context().inst_blocks().AddCanonical(result_arg_ids_));
  271. }
  272. auto DeduceGenericCallArguments(
  273. Context& context, SemIR::LocId loc_id, SemIR::GenericId generic_id,
  274. SemIR::SpecificId enclosing_specific_id,
  275. [[maybe_unused]] SemIR::InstBlockId implicit_params_id,
  276. SemIR::InstBlockId params_id, [[maybe_unused]] SemIR::InstId self_id,
  277. llvm::ArrayRef<SemIR::InstId> arg_ids) -> SemIR::SpecificId {
  278. DeductionContext deduction(context, loc_id, generic_id, enclosing_specific_id,
  279. /*diagnose=*/true);
  280. // Prepare to perform deduction of the explicit parameters against their
  281. // arguments.
  282. // TODO: Also perform deduction for type of self.
  283. deduction.AddAll(params_id, arg_ids, /*needs_substitution=*/true);
  284. if (!deduction.Deduce() || !deduction.CheckDeductionIsComplete()) {
  285. return SemIR::SpecificId::Invalid;
  286. }
  287. return deduction.MakeSpecific();
  288. }
  289. // Deduces the impl arguments to use in a use of a parameterized impl. Returns
  290. // `Invalid` if deduction fails.
  291. auto DeduceImplArguments(Context& context, SemIR::LocId loc_id,
  292. const SemIR::Impl& impl, SemIR::ConstantId self_id,
  293. SemIR::ConstantId constraint_id) -> SemIR::SpecificId {
  294. DeductionContext deduction(
  295. context, loc_id, impl.generic_id,
  296. /*enclosing_specific_id=*/SemIR::SpecificId::Invalid,
  297. /*diagnose=*/false);
  298. // Prepare to perform deduction of the type and interface.
  299. deduction.Add(impl.self_id, context.constant_values().GetInstId(self_id),
  300. /*needs_substitution=*/false);
  301. deduction.Add(impl.constraint_id,
  302. context.constant_values().GetInstId(constraint_id),
  303. /*needs_substitution=*/false);
  304. if (!deduction.Deduce() || !deduction.CheckDeductionIsComplete()) {
  305. return SemIR::SpecificId::Invalid;
  306. }
  307. return deduction.MakeSpecific();
  308. }
  309. } // namespace Carbon::Check