class.cpp 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. // Part of the Carbon Language project, under the Apache License v2.0 with LLVM
  2. // Exceptions. See /LICENSE for license information.
  3. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  4. #include "toolchain/check/class.h"
  5. #include "toolchain/check/context.h"
  6. #include "toolchain/check/convert.h"
  7. #include "toolchain/check/eval.h"
  8. #include "toolchain/check/function.h"
  9. #include "toolchain/check/generic.h"
  10. #include "toolchain/check/impl.h"
  11. #include "toolchain/check/import_ref.h"
  12. #include "toolchain/check/inst.h"
  13. #include "toolchain/check/name_lookup.h"
  14. #include "toolchain/check/name_ref.h"
  15. #include "toolchain/check/pattern.h"
  16. #include "toolchain/check/pattern_match.h"
  17. #include "toolchain/check/type.h"
  18. #include "toolchain/parse/node_ids.h"
  19. #include "toolchain/sem_ir/builtin_function_kind.h"
  20. #include "toolchain/sem_ir/function.h"
  21. #include "toolchain/sem_ir/ids.h"
  22. #include "toolchain/sem_ir/typed_insts.h"
  23. namespace Carbon::Check {
  24. auto SetClassSelfType(Context& context, SemIR::ClassId class_id) -> void {
  25. auto& class_info = context.classes().Get(class_id);
  26. auto specific_id = context.generics().GetSelfSpecific(class_info.generic_id);
  27. class_info.self_type_id = GetClassType(context, class_id, specific_id);
  28. }
  29. auto StartClassDefinition(Context& context, SemIR::Class& class_info,
  30. SemIR::InstId definition_id) -> void {
  31. // Track that this declaration is the definition.
  32. CARBON_CHECK(!class_info.has_definition_started());
  33. class_info.definition_id = definition_id;
  34. class_info.scope_id = context.name_scopes().Add(
  35. definition_id, SemIR::NameId::None, class_info.parent_scope_id);
  36. // Introduce `Self`.
  37. context.name_scopes().AddRequiredName(
  38. class_info.scope_id, SemIR::NameId::SelfType,
  39. context.types().GetInstId(class_info.self_type_id));
  40. }
  41. // Checks that the specified finished adapter definition is valid and builds and
  42. // returns a corresponding complete type witness instruction.
  43. static auto CheckCompleteAdapterClassType(
  44. Context& context, Parse::NodeId node_id, SemIR::ClassId class_id,
  45. llvm::ArrayRef<SemIR::InstId> field_decls,
  46. llvm::ArrayRef<SemIR::InstId> body) -> SemIR::InstId {
  47. const auto& class_info = context.classes().Get(class_id);
  48. if (class_info.base_id.has_value()) {
  49. CARBON_DIAGNOSTIC(AdaptWithBase, Error, "adapter with base class");
  50. CARBON_DIAGNOSTIC(AdaptWithBaseHere, Note, "`base` declaration is here");
  51. context.emitter()
  52. .Build(class_info.adapt_id, AdaptWithBase)
  53. .Note(class_info.base_id, AdaptWithBaseHere)
  54. .Emit();
  55. return SemIR::ErrorInst::InstId;
  56. }
  57. if (!field_decls.empty()) {
  58. CARBON_DIAGNOSTIC(AdaptWithFields, Error, "adapter with fields");
  59. CARBON_DIAGNOSTIC(AdaptWithFieldHere, Note,
  60. "first field declaration is here");
  61. context.emitter()
  62. .Build(class_info.adapt_id, AdaptWithFields)
  63. .Note(field_decls.front(), AdaptWithFieldHere)
  64. .Emit();
  65. return SemIR::ErrorInst::InstId;
  66. }
  67. for (auto inst_id : body) {
  68. if (auto function_decl =
  69. context.insts().TryGetAs<SemIR::FunctionDecl>(inst_id)) {
  70. auto& function = context.functions().Get(function_decl->function_id);
  71. if (function.virtual_modifier ==
  72. SemIR::Function::VirtualModifier::Virtual) {
  73. CARBON_DIAGNOSTIC(AdaptWithVirtual, Error,
  74. "adapter with virtual function");
  75. CARBON_DIAGNOSTIC(AdaptWithVirtualHere, Note,
  76. "first virtual function declaration is here");
  77. context.emitter()
  78. .Build(class_info.adapt_id, AdaptWithVirtual)
  79. .Note(inst_id, AdaptWithVirtualHere)
  80. .Emit();
  81. return SemIR::ErrorInst::InstId;
  82. }
  83. }
  84. }
  85. // The object representation of the adapter is the object representation
  86. // of the adapted type.
  87. auto adapted_type_id =
  88. class_info.GetAdaptedType(context.sem_ir(), SemIR::SpecificId::None);
  89. auto object_repr_id = context.types().GetObjectRepr(adapted_type_id);
  90. return AddInst<SemIR::CompleteTypeWitness>(
  91. context, node_id,
  92. {.type_id = GetSingletonType(context, SemIR::WitnessType::TypeInstId),
  93. // TODO: Use InstId from the adapt declaration.
  94. .object_repr_type_inst_id = context.types().GetInstId(object_repr_id)});
  95. }
  96. static auto AddStructTypeFields(
  97. Context& context,
  98. llvm::SmallVector<SemIR::StructTypeField>& struct_type_fields,
  99. llvm::ArrayRef<SemIR::InstId> field_decls) -> SemIR::StructTypeFieldsId {
  100. for (auto field_decl_id : field_decls) {
  101. auto field_decl = context.insts().GetAs<SemIR::FieldDecl>(field_decl_id);
  102. field_decl.index =
  103. SemIR::ElementIndex{static_cast<int>(struct_type_fields.size())};
  104. ReplaceInstPreservingConstantValue(context, field_decl_id, field_decl);
  105. if (field_decl.type_id == SemIR::ErrorInst::TypeId) {
  106. struct_type_fields.push_back(
  107. {.name_id = field_decl.name_id,
  108. .type_inst_id = SemIR::ErrorInst::TypeInstId});
  109. continue;
  110. }
  111. auto unbound_element_type =
  112. context.sem_ir().types().GetAs<SemIR::UnboundElementType>(
  113. field_decl.type_id);
  114. struct_type_fields.push_back(
  115. {.name_id = field_decl.name_id,
  116. .type_inst_id = unbound_element_type.element_type_inst_id});
  117. }
  118. auto fields_id =
  119. context.struct_type_fields().AddCanonical(struct_type_fields);
  120. return fields_id;
  121. }
// Builds and returns the vtable for the class `class_id`.
//
// `vtable_contents` holds the `FunctionDecl` instructions for the virtual
// functions declared in this class itself, including `impl` overrides.
// `base_class_type` is the class's base class type, if it has one.
//
// The vtable is laid out as follows:
//   - If the base class has a vtable pointer, there is one entry per entry of
//     the base vtable, in the same order. An entry is replaced by the first
//     `impl` function in `vtable_contents` with the same name. Otherwise the
//     entry obtained from `DecomposeVirtualFunction` is used. It is re-created
//     within this class when it is a `SpecificFunction` with a symbolic
//     constant value.
//   - Every function in `vtable_contents` that is not an `impl` is appended
//     afterwards.
// An `impl` function that matches no base vtable entry is diagnosed with
// `ImplWithoutVirtualInBase`.
//
// The `virtual_index` of each function from `vtable_contents` that is placed in
// the vtable is set to its position. When the class is generic, entries for
// this class's own functions are wrapped in a `SpecificFunction` using the
// function's self specific.
static auto BuildVtable(Context& context, Parse::ClassDefinitionId node_id,
                        SemIR::ClassId class_id,
                        std::optional<SemIR::ClassType> base_class_type,
                        llvm::ArrayRef<SemIR::InstId> vtable_contents)
    -> SemIR::VtableId {
  auto base_vtable_id = SemIR::VtableId::None;
  auto base_class_specific_id = SemIR::SpecificId::None;
  // If the base class has a vtable pointer, find the base vtable and the
  // specific the base class type uses; these stay `None` otherwise.
  if (base_class_type) {
    auto& base_class_info = context.classes().Get(base_class_type->class_id);
    auto base_vtable_ptr_inst_id = base_class_info.vtable_ptr_id;
    if (base_vtable_ptr_inst_id.has_value()) {
      // The base class may be imported, so make sure its vtable pointer is
      // loaded before reading its constant value.
      LoadImportRef(context, base_vtable_ptr_inst_id);
      auto canonical_base_vtable_inst_id =
          context.constant_values().GetConstantInstId(base_vtable_ptr_inst_id);
      const auto& base_vtable_ptr_inst =
          context.insts().GetAs<SemIR::VtablePtr>(
              canonical_base_vtable_inst_id);
      base_vtable_id = base_vtable_ptr_inst.vtable_id;
      base_class_specific_id = base_class_type->specific_id;
    }
  }
  const auto& class_info = context.classes().Get(class_id);
  auto class_generic_id = class_info.generic_id;
  // Returns the vtable entry for one of this class's own function
  // declarations: the declaration itself for a non-generic class, otherwise a
  // `SpecificFunction` for the function's self specific.
  auto build_specific_function =
      [&](SemIR::InstId fn_decl_id) -> SemIR::InstId {
    if (!class_generic_id.has_value()) {
      return fn_decl_id;
    }
    const auto& fn_decl =
        context.insts().GetAs<SemIR::FunctionDecl>(fn_decl_id);
    const auto& function = context.functions().Get(fn_decl.function_id);
    return GetOrAddInst<SemIR::SpecificFunction>(
        context, node_id,
        {.type_id =
             GetSingletonType(context, SemIR::SpecificFunctionType::TypeInstId),
         .callee_id = fn_decl_id,
         .specific_id =
             context.generics().GetSelfSpecific(function.generic_id)});
  };
  llvm::SmallVector<SemIR::InstId> vtable;
  // `impl` functions that matched a base vtable entry; any other `impl` is
  // diagnosed below.
  Set<SemIR::FunctionId> implemented_impls;
  if (base_vtable_id.has_value()) {
    auto base_vtable_inst_block = context.inst_blocks().Get(
        context.vtables().Get(base_vtable_id).virtual_functions_id);
    // TODO: Avoid quadratic search. Perhaps build a map from `NameId` to the
    // `impl` functions in `vtable_contents` (keeping the first match per name).
    for (auto base_vtable_entry_id : base_vtable_inst_block) {
      LoadImportRef(context, base_vtable_entry_id);
      auto [derived_vtable_entry_id, derived_vtable_entry_const_id, fn_id,
            specific_id] =
          DecomposeVirtualFunction(context.sem_ir(), base_vtable_entry_id,
                                   base_class_specific_id);
      const auto& fn = context.sem_ir().functions().Get(fn_id);
      // Look for an `impl` in this class that overrides the base function,
      // matching by name only.
      const auto* i = llvm::find_if(
          vtable_contents, [&](SemIR::InstId override_fn_decl_id) -> bool {
            const auto& override_fn = context.functions().Get(
                context.insts()
                    .GetAs<SemIR::FunctionDecl>(override_fn_decl_id)
                    .function_id);
            return override_fn.virtual_modifier ==
                       SemIR::FunctionFields::VirtualModifier::Impl &&
                   override_fn.name_id == fn.name_id;
          });
      if (i != vtable_contents.end()) {
        // Overridden: check the signature against the base function and use
        // the override in this slot.
        auto override_fn_id =
            context.insts().GetAs<SemIR::FunctionDecl>(*i).function_id;
        implemented_impls.Insert(override_fn_id);
        auto& override_fn = context.functions().Get(override_fn_id);
        CheckFunctionTypeMatches(context, override_fn, fn, specific_id,
                                 /*check_syntax=*/false,
                                 /*check_self=*/false);
        derived_vtable_entry_id = build_specific_function(*i);
        override_fn.virtual_index = vtable.size();
        // The override must occupy the same slot as the function it overrides.
        CARBON_CHECK(override_fn.virtual_index == fn.virtual_index);
      } else if (auto base_vtable_specific_function =
                     context.sem_ir().insts().TryGetAs<SemIR::SpecificFunction>(
                         derived_vtable_entry_id)) {
        if (derived_vtable_entry_const_id.is_symbolic()) {
          // Create a new instruction here that is otherwise identical to
          // `derived_vtable_entry_id` but is dependent within the derived
          // class. This ensures we can `GetConstantValueInSpecific` for it
          // with the derived class's specific (when forming further derived
          // classes, lowering the vtable, etc).
          derived_vtable_entry_id = GetOrAddInst<SemIR::SpecificFunction>(
              context, node_id,
              {.type_id = GetSingletonType(
                   context, SemIR::SpecificFunctionType::TypeInstId),
               .callee_id = base_vtable_specific_function->callee_id,
               .specific_id = base_vtable_specific_function->specific_id});
        }
      }
      vtable.push_back(derived_vtable_entry_id);
    }
  }
  // Append this class's newly introduced virtual functions. Diagnose each
  // `impl` that did not match any base vtable entry above.
  for (auto inst_id : vtable_contents) {
    auto fn_decl = context.insts().GetAs<SemIR::FunctionDecl>(inst_id);
    auto& fn = context.functions().Get(fn_decl.function_id);
    if (fn.virtual_modifier != SemIR::FunctionFields::VirtualModifier::Impl) {
      fn.virtual_index = vtable.size();
      vtable.push_back(build_specific_function(inst_id));
    } else if (!implemented_impls.Lookup(fn_decl.function_id)) {
      CARBON_DIAGNOSTIC(ImplWithoutVirtualInBase, Error,
                        "impl without compatible virtual in base class");
      context.emitter().Emit(SemIR::LocId(inst_id), ImplWithoutVirtualInBase);
    }
  }
  return context.vtables().Add(
      {{.class_id = class_id,
        .virtual_functions_id = context.inst_blocks().Add(vtable)}});
}
  236. // Checks that the specified finished class definition is valid and builds and
  237. // returns a corresponding complete type witness instruction.
  238. static auto CheckCompleteClassType(
  239. Context& context, Parse::ClassDefinitionId node_id, SemIR::ClassId class_id,
  240. llvm::ArrayRef<SemIR::InstId> field_decls,
  241. llvm::ArrayRef<SemIR::InstId> vtable_contents,
  242. llvm::ArrayRef<SemIR::InstId> body) -> SemIR::InstId {
  243. auto& class_info = context.classes().Get(class_id);
  244. if (class_info.adapt_id.has_value()) {
  245. return CheckCompleteAdapterClassType(context, node_id, class_id,
  246. field_decls, body);
  247. }
  248. bool defining_vptr = class_info.is_dynamic;
  249. auto base_type_id =
  250. class_info.GetBaseType(context.sem_ir(), SemIR::SpecificId::None);
  251. // TODO: Use InstId from base declaration.
  252. auto base_type_inst_id = context.types().GetInstId(base_type_id);
  253. std::optional<SemIR::ClassType> base_class_type;
  254. if (base_type_id.has_value()) {
  255. // TODO: If the base class is template dependent, we will need to decide
  256. // whether to add a vptr as part of instantiation.
  257. base_class_type = context.types().TryGetAs<SemIR::ClassType>(base_type_id);
  258. if (base_class_type &&
  259. context.classes().Get(base_class_type->class_id).is_dynamic) {
  260. defining_vptr = false;
  261. }
  262. }
  263. llvm::SmallVector<SemIR::StructTypeField> struct_type_fields;
  264. struct_type_fields.reserve(defining_vptr + class_info.base_id.has_value() +
  265. field_decls.size());
  266. if (defining_vptr) {
  267. struct_type_fields.push_back(
  268. {.name_id = SemIR::NameId::Vptr,
  269. .type_inst_id = context.types().GetInstId(
  270. GetPointerType(context, SemIR::VtableType::TypeInstId))});
  271. }
  272. if (base_type_id.has_value()) {
  273. auto base_decl = context.insts().GetAs<SemIR::BaseDecl>(class_info.base_id);
  274. base_decl.index =
  275. SemIR::ElementIndex{static_cast<int>(struct_type_fields.size())};
  276. ReplaceInstPreservingConstantValue(context, class_info.base_id, base_decl);
  277. struct_type_fields.push_back(
  278. {.name_id = SemIR::NameId::Base, .type_inst_id = base_type_inst_id});
  279. }
  280. if (class_info.is_dynamic) {
  281. auto vtable_id = BuildVtable(context, node_id, class_id, base_class_type,
  282. vtable_contents);
  283. auto vptr_type_id = GetPointerType(context, SemIR::VtableType::TypeInstId);
  284. auto generic_id = class_info.generic_id;
  285. auto self_specific_id = context.generics().GetSelfSpecific(generic_id);
  286. class_info.vtable_ptr_id =
  287. AddInst<SemIR::VtablePtr>(context, node_id,
  288. {.type_id = vptr_type_id,
  289. .vtable_id = vtable_id,
  290. .specific_id = self_specific_id});
  291. }
  292. auto struct_type_inst_id = AddTypeInst<SemIR::StructType>(
  293. context, node_id,
  294. {.type_id = SemIR::TypeType::TypeId,
  295. .fields_id =
  296. AddStructTypeFields(context, struct_type_fields, field_decls)});
  297. return AddInst<SemIR::CompleteTypeWitness>(
  298. context, node_id,
  299. {.type_id = GetSingletonType(context, SemIR::WitnessType::TypeInstId),
  300. .object_repr_type_inst_id = struct_type_inst_id});
  301. }
  302. auto ComputeClassObjectRepr(Context& context, Parse::ClassDefinitionId node_id,
  303. SemIR::ClassId class_id,
  304. llvm::ArrayRef<SemIR::InstId> field_decls,
  305. llvm::ArrayRef<SemIR::InstId> vtable_contents,
  306. llvm::ArrayRef<SemIR::InstId> body) -> void {
  307. auto complete_type_witness_id = CheckCompleteClassType(
  308. context, node_id, class_id, field_decls, vtable_contents, body);
  309. auto& class_info = context.classes().Get(class_id);
  310. class_info.complete_type_witness_id = complete_type_witness_id;
  311. }
  312. } // namespace Carbon::Check