Bläddra i källkod

Refactor `SemIR::InstNamer` (#3690)

- Rename `ScopeIndex` to `ScopeId`, because the order is not meaningful.
This avoids collisions with `SemIR::ScopeIndex`.
- Change `GetScopeFor` to use an `if constexpr` chain rather than
duplicating the numbering logic across multiple functions.
- Simplify `GetNameFor`, avoiding multiple identical overloads.

As requested in review of #3683.
Richard Smith 2 år sedan
förälder
incheckning
ead4dbbbe3
1 ändrade filer med 59 tillägg och 75 borttagningar
  1. 59 75
      toolchain/sem_ir/formatter.cpp

+ 59 - 75
toolchain/sem_ir/formatter.cpp

@@ -29,13 +29,15 @@ class InstNamer {
  public:
   // int32_t matches the input value size.
   // NOLINTNEXTLINE(performance-enum-size)
-  enum class ScopeIndex : int32_t {
+  enum class ScopeId : int32_t {
     None = -1,
     File = 0,
     Constants = 1,
     FirstFunction = 2,
   };
-  static_assert(sizeof(ScopeIndex) == sizeof(FunctionId));
+  static_assert(sizeof(ScopeId) == sizeof(FunctionId));
+
+  struct NumberOfScopesTag {};
 
   InstNamer(const Lex::TokenizedBuffer& tokenized_buffer,
             const Parse::Tree& parse_tree, const File& sem_ir)
@@ -44,19 +46,16 @@ class InstNamer {
         sem_ir_(sem_ir) {
     insts.resize(sem_ir.insts().size());
     labels.resize(sem_ir.inst_blocks().size());
-    scopes.resize(static_cast<int32_t>(ScopeIndex::FirstFunction) +
-                  sem_ir.functions().size() + sem_ir.classes().size() +
-                  sem_ir.interfaces().size());
+    scopes.resize(static_cast<size_t>(GetScopeFor(NumberOfScopesTag())));
 
     // Build the constants scope.
-    GetScopeInfo(ScopeIndex::Constants).name =
+    GetScopeInfo(ScopeId::Constants).name =
         globals.AddNameUnchecked("constants");
-    CollectNamesInBlock(ScopeIndex::Constants,
-                        sem_ir.constants().GetAsVector());
+    CollectNamesInBlock(ScopeId::Constants, sem_ir.constants().GetAsVector());
 
     // Build the file scope.
-    GetScopeInfo(ScopeIndex::File).name = globals.AddNameUnchecked("file");
-    CollectNamesInBlock(ScopeIndex::File, sem_ir.top_inst_block_id());
+    GetScopeInfo(ScopeId::File).name = globals.AddNameUnchecked("file");
+    CollectNamesInBlock(ScopeId::File, sem_ir.top_inst_block_id());
 
     // Build each function scope.
     for (auto [i, fn] : llvm::enumerate(sem_ir.functions().array_ref())) {
@@ -117,54 +116,41 @@ class InstNamer {
     }
   }
 
-  // Returns the scope index corresponding to a function.
-  auto GetScopeFor(FunctionId fn_id) -> ScopeIndex {
-    return static_cast<ScopeIndex>(
-        static_cast<int32_t>(ScopeIndex::FirstFunction) + fn_id.index);
-  }
-
-  // Returns the scope index corresponding to a class.
-  auto GetScopeFor(ClassId class_id) -> ScopeIndex {
-    return static_cast<ScopeIndex>(
-        static_cast<int32_t>(ScopeIndex::FirstFunction) +
-        sem_ir_.functions().size() + class_id.index);
-  }
-
-  // Returns the scope index corresponding to an interface.
-  auto GetScopeFor(InterfaceId interface_id) -> ScopeIndex {
-    return static_cast<ScopeIndex>(
-        static_cast<int32_t>(ScopeIndex::FirstFunction) +
-        sem_ir_.functions().size() + sem_ir_.classes().size() +
-        interface_id.index);
-  }
-
-  // Returns the IR name to use for a function.
-  auto GetNameFor(FunctionId fn_id) -> llvm::StringRef {
-    if (!fn_id.is_valid()) {
-      return "invalid";
+  // Returns the scope index corresponding to an ID of a function, class, or
+  // interface.
+  template <typename IdT>
+  auto GetScopeFor(IdT id) -> ScopeId {
+    auto index = static_cast<int32_t>(ScopeId::FirstFunction);
+
+    if constexpr (!std::same_as<FunctionId, IdT>) {
+      index += sem_ir_.functions().size();
+      if constexpr (!std::same_as<ClassId, IdT>) {
+        index += sem_ir_.classes().size();
+        if constexpr (!std::same_as<InterfaceId, IdT>) {
+          index += sem_ir_.interfaces().size();
+          static_assert(std::same_as<NumberOfScopesTag, IdT>,
+                        "Unknown ID kind for scope");
+        }
+      }
     }
-    return GetScopeInfo(GetScopeFor(fn_id)).name.str();
-  }
-
-  // Returns the IR name to use for a class.
-  auto GetNameFor(ClassId class_id) -> llvm::StringRef {
-    if (!class_id.is_valid()) {
-      return "invalid";
+    if constexpr (!std::same_as<NumberOfScopesTag, IdT>) {
+      index += id.index;
     }
-    return GetScopeInfo(GetScopeFor(class_id)).name.str();
+    return static_cast<ScopeId>(index);
   }
 
-  // Returns the IR name to use for an interface.
-  auto GetNameFor(InterfaceId interface_id) -> llvm::StringRef {
-    if (!interface_id.is_valid()) {
+  // Returns the IR name to use for a function, class, or interface.
+  template <typename IdT>
+  auto GetNameFor(IdT id) -> llvm::StringRef {
+    if (!id.is_valid()) {
       return "invalid";
     }
-    return GetScopeInfo(GetScopeFor(interface_id)).name.str();
+    return GetScopeInfo(GetScopeFor(id)).name.str();
   }
 
   // Returns the IR name to use for an instruction, when referenced from a given
   // scope.
-  auto GetNameFor(ScopeIndex scope_idx, InstId inst_id) -> std::string {
+  auto GetNameFor(ScopeId scope_id, InstId inst_id) -> std::string {
     if (!inst_id.is_valid()) {
       return "invalid";
     }
@@ -185,14 +171,14 @@ class InstNamer {
       llvm::raw_string_ostream(str) << "<unexpected instref " << inst_id << ">";
       return str;
     }
-    if (inst_scope == scope_idx) {
+    if (inst_scope == scope_id) {
       return inst_name.str().str();
     }
     return (GetScopeInfo(inst_scope).name.str() + "." + inst_name.str()).str();
   }
 
   // Returns the IR name to use for a label, when referenced from a given scope.
-  auto GetLabelFor(ScopeIndex scope_idx, InstBlockId block_id) -> std::string {
+  auto GetLabelFor(ScopeId scope_id, InstBlockId block_id) -> std::string {
     if (!block_id.is_valid()) {
       return "!invalid";
     }
@@ -205,7 +191,7 @@ class InstNamer {
           << "<unexpected instblockref " << block_id << ">";
       return str;
     }
-    if (label_scope == scope_idx) {
+    if (label_scope == scope_id) {
       return label_name.str().str();
     }
     return (GetScopeInfo(label_scope).name.str() + "." + label_name.str())
@@ -327,11 +313,11 @@ class InstNamer {
     Namespace labels = {.prefix = "!"};
   };
 
-  auto GetScopeInfo(ScopeIndex scope_idx) -> Scope& {
-    return scopes[static_cast<int>(scope_idx)];
+  auto GetScopeInfo(ScopeId scope_id) -> Scope& {
+    return scopes[static_cast<int>(scope_id)];
   }
 
-  auto AddBlockLabel(ScopeIndex scope_idx, InstBlockId block_id,
+  auto AddBlockLabel(ScopeId scope_id, InstBlockId block_id,
                      std::string name = "",
                      Parse::NodeId parse_node = Parse::NodeId::Invalid)
       -> void {
@@ -346,14 +332,14 @@ class InstNamer {
       }
     }
 
-    labels[block_id.index] = {scope_idx,
-                              GetScopeInfo(scope_idx).labels.AllocateName(
-                                  *this, parse_node, std::move(name))};
+    labels[block_id.index] = {
+        scope_id, GetScopeInfo(scope_id).labels.AllocateName(*this, parse_node,
+                                                             std::move(name))};
   }
 
   // Finds and adds a suitable block label for the given SemIR instruction that
   // represents some kind of branch.
-  auto AddBlockLabel(ScopeIndex scope_idx, Parse::NodeId parse_node,
+  auto AddBlockLabel(ScopeId scope_id, Parse::NodeId parse_node,
                      AnyBranch branch) -> void {
     llvm::StringRef name;
     switch (parse_tree_.node_kind(parse_node)) {
@@ -418,18 +404,18 @@ class InstNamer {
         break;
     }
 
-    AddBlockLabel(scope_idx, branch.target_id, name.str(), parse_node);
+    AddBlockLabel(scope_id, branch.target_id, name.str(), parse_node);
   }
 
-  auto CollectNamesInBlock(ScopeIndex scope_idx, InstBlockId block_id) -> void {
+  auto CollectNamesInBlock(ScopeId scope_id, InstBlockId block_id) -> void {
     if (block_id.is_valid()) {
-      CollectNamesInBlock(scope_idx, sem_ir_.inst_blocks().Get(block_id));
+      CollectNamesInBlock(scope_id, sem_ir_.inst_blocks().Get(block_id));
     }
   }
 
-  auto CollectNamesInBlock(ScopeIndex scope_idx, llvm::ArrayRef<InstId> block)
+  auto CollectNamesInBlock(ScopeId scope_id, llvm::ArrayRef<InstId> block)
       -> void {
-    Scope& scope = GetScopeInfo(scope_idx);
+    Scope& scope = GetScopeInfo(scope_id);
 
     // Use bound names where available. Otherwise, assign a backup name.
     for (auto inst_id : block) {
@@ -440,8 +426,8 @@ class InstNamer {
       auto inst = sem_ir_.insts().Get(inst_id);
       auto add_inst_name = [&](std::string name) {
         insts[inst_id.index] = {
-            scope_idx, scope.insts.AllocateName(
-                           *this, sem_ir_.insts().GetParseNode(inst_id), name)};
+            scope_id, scope.insts.AllocateName(
+                          *this, sem_ir_.insts().GetParseNode(inst_id), name)};
       };
       auto add_inst_name_id = [&](NameId name_id, llvm::StringRef suffix = "") {
         add_inst_name(
@@ -449,8 +435,7 @@ class InstNamer {
       };
 
       if (auto branch = inst.TryAs<AnyBranch>()) {
-        AddBlockLabel(scope_idx, sem_ir_.insts().GetParseNode(inst_id),
-                      *branch);
+        AddBlockLabel(scope_id, sem_ir_.insts().GetParseNode(inst_id), *branch);
       }
 
       switch (inst.kind()) {
@@ -459,11 +444,11 @@ class InstNamer {
           // function declarations, which may be nested within a pattern. For
           // now, just look through `addr`, but we should find a better way to
           // visit parameters.
-          CollectNamesInBlock(scope_idx, inst.As<AddrPattern>().inner_id);
+          CollectNamesInBlock(scope_id, inst.As<AddrPattern>().inner_id);
           break;
         }
         case SpliceBlock::Kind: {
-          CollectNamesInBlock(scope_idx, inst.As<SpliceBlock>().block_id);
+          CollectNamesInBlock(scope_id, inst.As<SpliceBlock>().block_id);
           break;
         }
         case BindName::Kind:
@@ -535,8 +520,8 @@ class InstNamer {
   const File& sem_ir_;
 
   Namespace globals = {.prefix = "@"};
-  std::vector<std::pair<ScopeIndex, Namespace::Name>> insts;
-  std::vector<std::pair<ScopeIndex, Namespace::Name>> labels;
+  std::vector<std::pair<ScopeId, Namespace::Name>> insts;
+  std::vector<std::pair<ScopeId, Namespace::Name>> labels;
   std::vector<Scope> scopes;
 };
 }  // namespace
@@ -566,7 +551,7 @@ class Formatter {
     // blocks. For example, there may be branching in the initializer of a
     // global or a type expression.
     if (auto block_id = sem_ir_.top_inst_block_id(); block_id.is_valid()) {
-      llvm::SaveAndRestore file_scope(scope_, InstNamer::ScopeIndex::File);
+      llvm::SaveAndRestore file_scope(scope_, InstNamer::ScopeId::File);
       FormatCodeBlock(block_id);
     }
     out_ << "}\n";
@@ -592,8 +577,7 @@ class Formatter {
       return;
     }
 
-    llvm::SaveAndRestore constants_scope(scope_,
-                                         InstNamer::ScopeIndex::Constants);
+    llvm::SaveAndRestore constants_scope(scope_, InstNamer::ScopeId::Constants);
     out_ << "constants {\n";
     FormatCodeBlock(sem_ir_.constants().GetAsVector());
     out_ << "}\n\n";
@@ -1115,7 +1099,7 @@ class Formatter {
   const File& sem_ir_;
   llvm::raw_ostream& out_;
   InstNamer inst_namer_;
-  InstNamer::ScopeIndex scope_ = InstNamer::ScopeIndex::None;
+  InstNamer::ScopeId scope_ = InstNamer::ScopeId::None;
   bool in_terminator_sequence_ = false;
   int indent_ = 2;
 };