class.cpp 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. // Part of the Carbon Language project, under the Apache License v2.0 with LLVM
  2. // Exceptions. See /LICENSE for license information.
  3. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
  4. #include "toolchain/check/class.h"
  5. #include "toolchain/check/context.h"
  6. #include "toolchain/check/eval.h"
  7. #include "toolchain/check/function.h"
  8. #include "toolchain/check/import_ref.h"
  9. #include "toolchain/check/inst.h"
  10. #include "toolchain/check/type.h"
  11. #include "toolchain/parse/node_ids.h"
  12. namespace Carbon::Check {
  13. auto SetClassSelfType(Context& context, SemIR::ClassId class_id) -> void {
  14. auto& class_info = context.classes().Get(class_id);
  15. auto specific_id = context.generics().GetSelfSpecific(class_info.generic_id);
  16. class_info.self_type_id = GetClassType(context, class_id, specific_id);
  17. }
  18. auto StartClassDefinition(Context& context, SemIR::Class& class_info,
  19. SemIR::InstId definition_id) -> void {
  20. // Track that this declaration is the definition.
  21. CARBON_CHECK(!class_info.has_definition_started());
  22. class_info.definition_id = definition_id;
  23. class_info.scope_id = context.name_scopes().Add(
  24. definition_id, SemIR::NameId::None, class_info.parent_scope_id);
  25. // Introduce `Self`.
  26. context.name_scopes().AddRequiredName(
  27. class_info.scope_id, SemIR::NameId::SelfType,
  28. context.types().GetInstId(class_info.self_type_id));
  29. }
  30. // Checks that the specified finished adapter definition is valid and builds and
  31. // returns a corresponding complete type witness instruction.
  32. static auto CheckCompleteAdapterClassType(
  33. Context& context, Parse::NodeId node_id, SemIR::ClassId class_id,
  34. llvm::ArrayRef<SemIR::InstId> field_decls,
  35. llvm::ArrayRef<SemIR::InstId> body) -> SemIR::InstId {
  36. const auto& class_info = context.classes().Get(class_id);
  37. if (class_info.base_id.has_value()) {
  38. CARBON_DIAGNOSTIC(AdaptWithBase, Error, "adapter with base class");
  39. CARBON_DIAGNOSTIC(AdaptWithBaseHere, Note, "`base` declaration is here");
  40. context.emitter()
  41. .Build(class_info.adapt_id, AdaptWithBase)
  42. .Note(class_info.base_id, AdaptWithBaseHere)
  43. .Emit();
  44. return SemIR::ErrorInst::InstId;
  45. }
  46. if (!field_decls.empty()) {
  47. CARBON_DIAGNOSTIC(AdaptWithFields, Error, "adapter with fields");
  48. CARBON_DIAGNOSTIC(AdaptWithFieldHere, Note,
  49. "first field declaration is here");
  50. context.emitter()
  51. .Build(class_info.adapt_id, AdaptWithFields)
  52. .Note(field_decls.front(), AdaptWithFieldHere)
  53. .Emit();
  54. return SemIR::ErrorInst::InstId;
  55. }
  56. for (auto inst_id : body) {
  57. if (auto function_decl =
  58. context.insts().TryGetAs<SemIR::FunctionDecl>(inst_id)) {
  59. auto& function = context.functions().Get(function_decl->function_id);
  60. if (function.virtual_modifier ==
  61. SemIR::Function::VirtualModifier::Virtual) {
  62. CARBON_DIAGNOSTIC(AdaptWithVirtual, Error,
  63. "adapter with virtual function");
  64. CARBON_DIAGNOSTIC(AdaptWithVirtualHere, Note,
  65. "first virtual function declaration is here");
  66. context.emitter()
  67. .Build(class_info.adapt_id, AdaptWithVirtual)
  68. .Note(inst_id, AdaptWithVirtualHere)
  69. .Emit();
  70. return SemIR::ErrorInst::InstId;
  71. }
  72. }
  73. }
  74. // The object representation of the adapter is the object representation
  75. // of the adapted type.
  76. auto adapted_type_id =
  77. class_info.GetAdaptedType(context.sem_ir(), SemIR::SpecificId::None);
  78. auto object_repr_id = context.types().GetObjectRepr(adapted_type_id);
  79. return AddInst<SemIR::CompleteTypeWitness>(
  80. context, node_id,
  81. {.type_id = GetSingletonType(context, SemIR::WitnessType::TypeInstId),
  82. // TODO: Use InstId from the adapt declaration.
  83. .object_repr_type_inst_id = context.types().GetInstId(object_repr_id)});
  84. }
  85. static auto AddStructTypeFields(
  86. Context& context,
  87. llvm::SmallVector<SemIR::StructTypeField>& struct_type_fields,
  88. llvm::ArrayRef<SemIR::InstId> field_decls) -> SemIR::StructTypeFieldsId {
  89. for (auto field_decl_id : field_decls) {
  90. auto field_decl = context.insts().GetAs<SemIR::FieldDecl>(field_decl_id);
  91. field_decl.index =
  92. SemIR::ElementIndex{static_cast<int>(struct_type_fields.size())};
  93. ReplaceInstPreservingConstantValue(context, field_decl_id, field_decl);
  94. if (field_decl.type_id == SemIR::ErrorInst::TypeId) {
  95. struct_type_fields.push_back(
  96. {.name_id = field_decl.name_id,
  97. .type_inst_id = SemIR::ErrorInst::TypeInstId});
  98. continue;
  99. }
  100. auto unbound_element_type =
  101. context.sem_ir().types().GetAs<SemIR::UnboundElementType>(
  102. field_decl.type_id);
  103. struct_type_fields.push_back(
  104. {.name_id = field_decl.name_id,
  105. .type_inst_id = unbound_element_type.element_type_inst_id});
  106. }
  107. auto fields_id =
  108. context.struct_type_fields().AddCanonical(struct_type_fields);
  109. return fields_id;
  110. }
  111. // Builds and returns a vtable for the current class. Assumes that the virtual
  112. // functions for the class are listed as the top element of the `vtable_stack`.
  113. static auto BuildVtable(Context& context, Parse::ClassDefinitionId node_id,
  114. SemIR::ClassId class_id,
  115. std::optional<SemIR::ClassType> base_class_type,
  116. llvm::ArrayRef<SemIR::InstId> vtable_contents)
  117. -> SemIR::VtableId {
  118. auto base_vtable_id = SemIR::VtableId::None;
  119. auto base_class_specific_id = SemIR::SpecificId::None;
  120. // Get some base class/type/specific info.
  121. if (base_class_type) {
  122. auto& base_class_info = context.classes().Get(base_class_type->class_id);
  123. auto base_vtable_ptr_inst_id = base_class_info.vtable_ptr_id;
  124. if (base_vtable_ptr_inst_id.has_value()) {
  125. LoadImportRef(context, base_vtable_ptr_inst_id);
  126. auto canonical_base_vtable_inst_id =
  127. context.constant_values().GetConstantInstId(base_vtable_ptr_inst_id);
  128. const auto& base_vtable_ptr_inst =
  129. context.insts().GetAs<SemIR::VtablePtr>(
  130. canonical_base_vtable_inst_id);
  131. base_vtable_id = base_vtable_ptr_inst.vtable_id;
  132. base_class_specific_id = base_class_type->specific_id;
  133. }
  134. }
  135. const auto& class_info = context.classes().Get(class_id);
  136. auto class_generic_id = class_info.generic_id;
  137. // Wrap vtable entries in SpecificFunctions as needed/in generic classes.
  138. auto build_specific_function =
  139. [&](SemIR::InstId fn_decl_id) -> SemIR::InstId {
  140. if (!class_generic_id.has_value()) {
  141. return fn_decl_id;
  142. }
  143. const auto& fn_decl =
  144. context.insts().GetAs<SemIR::FunctionDecl>(fn_decl_id);
  145. const auto& function = context.functions().Get(fn_decl.function_id);
  146. return GetOrAddInst<SemIR::SpecificFunction>(
  147. context, node_id,
  148. {.type_id =
  149. GetSingletonType(context, SemIR::SpecificFunctionType::TypeInstId),
  150. .callee_id = fn_decl_id,
  151. .specific_id =
  152. context.generics().GetSelfSpecific(function.generic_id)});
  153. };
  154. llvm::SmallVector<SemIR::InstId> vtable;
  155. if (base_vtable_id.has_value()) {
  156. auto base_vtable_inst_block = context.inst_blocks().Get(
  157. context.vtables().Get(base_vtable_id).virtual_functions_id);
  158. // TODO: Avoid quadratic search. Perhaps build a map from `NameId` to the
  159. // elements of the top of `vtable_stack`.
  160. for (auto fn_decl_id : base_vtable_inst_block) {
  161. auto fn_decl = GetCalleeFunction(context.sem_ir(), fn_decl_id);
  162. const auto& fn = context.functions().Get(fn_decl.function_id);
  163. const auto* i = llvm::find_if(
  164. vtable_contents, [&](SemIR::InstId override_fn_decl_id) -> bool {
  165. const auto& override_fn = context.functions().Get(
  166. context.insts()
  167. .GetAs<SemIR::FunctionDecl>(override_fn_decl_id)
  168. .function_id);
  169. return override_fn.virtual_modifier ==
  170. SemIR::FunctionFields::VirtualModifier::Impl &&
  171. override_fn.name_id == fn.name_id;
  172. });
  173. if (i != vtable_contents.end()) {
  174. auto& override_fn = context.functions().Get(
  175. context.insts().GetAs<SemIR::FunctionDecl>(*i).function_id);
  176. // TODO: Support generic base classes, rather than passing
  177. // `SpecificId::None`. This'll need to `GetConstantValueInSpecific` for
  178. // the base function, then extract the specific from that for use here.
  179. CheckFunctionTypeMatches(context, override_fn, fn,
  180. SemIR::SpecificId::None,
  181. /*check_syntax=*/false,
  182. /*check_self=*/false);
  183. fn_decl_id = build_specific_function(*i);
  184. override_fn.virtual_index = vtable.size();
  185. CARBON_CHECK(override_fn.virtual_index == fn.virtual_index);
  186. } else {
  187. // Remap the base's vtable entry to the appropriate constant usable in
  188. // the context of the derived class (for the specific for the base
  189. // class, for instance)..
  190. fn_decl_id = context.sem_ir().constant_values().GetInstId(
  191. GetConstantValueInSpecific(context.sem_ir(), base_class_specific_id,
  192. fn_decl_id));
  193. }
  194. vtable.push_back(fn_decl_id);
  195. }
  196. }
  197. for (auto inst_id : vtable_contents) {
  198. auto fn_decl = context.insts().GetAs<SemIR::FunctionDecl>(inst_id);
  199. auto& fn = context.functions().Get(fn_decl.function_id);
  200. if (fn.virtual_modifier != SemIR::FunctionFields::VirtualModifier::Impl) {
  201. fn.virtual_index = vtable.size();
  202. vtable.push_back(build_specific_function(inst_id));
  203. }
  204. }
  205. return context.vtables().Add(
  206. {{.class_id = class_id,
  207. .virtual_functions_id = context.inst_blocks().Add(vtable)}});
  208. }
  209. // Checks that the specified finished class definition is valid and builds and
  210. // returns a corresponding complete type witness instruction.
  211. static auto CheckCompleteClassType(
  212. Context& context, Parse::ClassDefinitionId node_id, SemIR::ClassId class_id,
  213. llvm::ArrayRef<SemIR::InstId> field_decls,
  214. llvm::ArrayRef<SemIR::InstId> vtable_contents,
  215. llvm::ArrayRef<SemIR::InstId> body) -> SemIR::InstId {
  216. auto& class_info = context.classes().Get(class_id);
  217. if (class_info.adapt_id.has_value()) {
  218. return CheckCompleteAdapterClassType(context, node_id, class_id,
  219. field_decls, body);
  220. }
  221. bool defining_vptr = class_info.is_dynamic;
  222. auto base_type_id =
  223. class_info.GetBaseType(context.sem_ir(), SemIR::SpecificId::None);
  224. // TODO: Use InstId from base declaration.
  225. auto base_type_inst_id = context.types().GetInstId(base_type_id);
  226. std::optional<SemIR::ClassType> base_class_type;
  227. if (base_type_id.has_value()) {
  228. // TODO: If the base class is template dependent, we will need to decide
  229. // whether to add a vptr as part of instantiation.
  230. base_class_type = context.types().TryGetAs<SemIR::ClassType>(base_type_id);
  231. if (base_class_type &&
  232. context.classes().Get(base_class_type->class_id).is_dynamic) {
  233. defining_vptr = false;
  234. }
  235. }
  236. llvm::SmallVector<SemIR::StructTypeField> struct_type_fields;
  237. struct_type_fields.reserve(defining_vptr + class_info.base_id.has_value() +
  238. field_decls.size());
  239. if (defining_vptr) {
  240. struct_type_fields.push_back(
  241. {.name_id = SemIR::NameId::Vptr,
  242. .type_inst_id = context.types().GetInstId(
  243. GetPointerType(context, SemIR::VtableType::TypeInstId))});
  244. }
  245. if (base_type_id.has_value()) {
  246. auto base_decl = context.insts().GetAs<SemIR::BaseDecl>(class_info.base_id);
  247. base_decl.index =
  248. SemIR::ElementIndex{static_cast<int>(struct_type_fields.size())};
  249. ReplaceInstPreservingConstantValue(context, class_info.base_id, base_decl);
  250. struct_type_fields.push_back(
  251. {.name_id = SemIR::NameId::Base, .type_inst_id = base_type_inst_id});
  252. }
  253. if (class_info.is_dynamic) {
  254. auto vtable_id = BuildVtable(context, node_id, class_id, base_class_type,
  255. vtable_contents);
  256. auto vptr_type_id = GetPointerType(context, SemIR::VtableType::TypeInstId);
  257. // TODO: Handle specifics here, probably passing
  258. // `context.generics().GetSelfSpecific(class_info.generic_id)` as the
  259. // specific_id here (but more work involved to get this all plumbed in and
  260. // tested).
  261. auto generic_id = class_info.generic_id;
  262. auto self_specific_id = context.generics().GetSelfSpecific(generic_id);
  263. class_info.vtable_ptr_id =
  264. AddInst<SemIR::VtablePtr>(context, node_id,
  265. {.type_id = vptr_type_id,
  266. .vtable_id = vtable_id,
  267. .specific_id = self_specific_id});
  268. }
  269. auto struct_type_inst_id = AddTypeInst<SemIR::StructType>(
  270. context, node_id,
  271. {.type_id = SemIR::TypeType::TypeId,
  272. .fields_id =
  273. AddStructTypeFields(context, struct_type_fields, field_decls)});
  274. return AddInst<SemIR::CompleteTypeWitness>(
  275. context, node_id,
  276. {.type_id = GetSingletonType(context, SemIR::WitnessType::TypeInstId),
  277. .object_repr_type_inst_id = struct_type_inst_id});
  278. }
  279. auto ComputeClassObjectRepr(Context& context, Parse::ClassDefinitionId node_id,
  280. SemIR::ClassId class_id,
  281. llvm::ArrayRef<SemIR::InstId> field_decls,
  282. llvm::ArrayRef<SemIR::InstId> vtable_contents,
  283. llvm::ArrayRef<SemIR::InstId> body) -> void {
  284. auto complete_type_witness_id = CheckCompleteClassType(
  285. context, node_id, class_id, field_decls, vtable_contents, body);
  286. auto& class_info = context.classes().Get(class_id);
  287. class_info.complete_type_witness_id = complete_type_witness_id;
  288. }
  289. } // namespace Carbon::Check