소스 검색

Change StringLiteral to less frequently allocate a new string. (#3314)

Building on #3311, which started moving the result string into a
`unique_ptr`, instead have `StringLiteral` use a `BumpPtrAllocator` to
manage memory. But also, detect when a string is really trivial during
`Lex` and, if so, return `contents_` directly.
Jon Ross-Perkins 2 년 전
부모
커밋
b2cfd5a8a8

+ 84 - 46
toolchain/lex/string_literal.cpp

@@ -114,11 +114,13 @@ auto StringLiteral::Lex(llvm::StringRef source_text)
   terminator.resize(terminator.size() + hash_level, '#');
   escape.resize(escape.size() + hash_level, '#');
 
+  bool content_needs_validation = false;
+
   // TODO: Detect indent / dedent for multi-line string literals in order to
   // stop parsing on dedent before a terminator is found.
   for (; cursor < source_text_size; ++cursor) {
     // Use a lookup table to allow us to quickly skip uninteresting characters.
-    static constexpr CharSet InterestingChars = {'\\', '\n', '"', '\''};
+    static constexpr CharSet InterestingChars = {'\\', '\n', '"', '\'', '\t'};
     if (!InterestingChars[source_text[cursor]]) {
       continue;
     }
@@ -127,9 +129,14 @@ auto StringLiteral::Lex(llvm::StringRef source_text)
     // escape sequences starting with a predictable character and not containing
     // embedded and unescaped terminators or newlines.
     switch (source_text[cursor]) {
+      case '\t':
+        // Tabs have extra validation.
+        content_needs_validation = true;
+        break;
       case '\\':
         if (escape.size() == 1 ||
             source_text.substr(cursor + 1).startswith(escape.substr(1))) {
+          content_needs_validation = true;
           cursor += escape.size();
           // If there's either not a character following the escape, or it's a
           // single-line string and the escaped character is a newline, we
@@ -137,7 +144,8 @@ auto StringLiteral::Lex(llvm::StringRef source_text)
           if (cursor >= source_text_size || (introducer->kind == NotMultiLine &&
                                              source_text[cursor] == '\n')) {
             llvm::StringRef text = source_text.take_front(cursor);
-            return StringLiteral(text, text.drop_front(prefix_len), hash_level,
+            return StringLiteral(text, text.drop_front(prefix_len),
+                                 content_needs_validation, hash_level,
                                  introducer->kind,
                                  /*is_terminated=*/false);
           }
@@ -146,7 +154,8 @@ auto StringLiteral::Lex(llvm::StringRef source_text)
       case '\n':
         if (introducer->kind == NotMultiLine) {
           llvm::StringRef text = source_text.take_front(cursor);
-          return StringLiteral(text, text.drop_front(prefix_len), hash_level,
+          return StringLiteral(text, text.drop_front(prefix_len),
+                               content_needs_validation, hash_level,
                                introducer->kind,
                                /*is_terminated=*/false);
         }
@@ -158,7 +167,8 @@ auto StringLiteral::Lex(llvm::StringRef source_text)
               source_text.substr(0, cursor + terminator.size());
           llvm::StringRef content =
               source_text.substr(prefix_len, cursor - prefix_len);
-          return StringLiteral(text, content, hash_level, introducer->kind,
+          return StringLiteral(text, content, content_needs_validation,
+                               hash_level, introducer->kind,
                                /*is_terminated=*/true);
         }
         break;
@@ -169,7 +179,7 @@ auto StringLiteral::Lex(llvm::StringRef source_text)
   }
   // No terminator was found.
   return StringLiteral(source_text, source_text.drop_front(prefix_len),
-                       hash_level, introducer->kind,
+                       content_needs_validation, hash_level, introducer->kind,
                        /*is_terminated=*/false);
 }
 
@@ -214,7 +224,7 @@ static auto CheckIndent(LexerDiagnosticEmitter& emitter, llvm::StringRef text,
 // Expand a `\u{HHHHHH}` escape sequence into a sequence of UTF-8 code units.
 static auto ExpandUnicodeEscapeSequence(LexerDiagnosticEmitter& emitter,
                                         llvm::StringRef digits,
-                                        std::string& result) -> bool {
+                                        char*& buffer_cursor) -> bool {
   unsigned code_point;
   if (!CanLexInteger(emitter, digits)) {
     return false;
@@ -238,51 +248,65 @@ static auto ExpandUnicodeEscapeSequence(LexerDiagnosticEmitter& emitter,
   // Convert the code point to a sequence of UTF-8 code units.
   // Every code point fits in 6 UTF-8 code units.
   const llvm::UTF32 utf32_code_units[1] = {code_point};
-  llvm::UTF8 utf8_code_units[6];
   const llvm::UTF32* src_pos = utf32_code_units;
-  llvm::UTF8* dest_pos = utf8_code_units;
+  auto*& buffer_cursor_as_utf8 = reinterpret_cast<llvm::UTF8*&>(buffer_cursor);
   llvm::ConversionResult conv_result = llvm::ConvertUTF32toUTF8(
-      &src_pos, src_pos + 1, &dest_pos, dest_pos + 6, llvm::strictConversion);
+      &src_pos, src_pos + 1, &buffer_cursor_as_utf8, buffer_cursor_as_utf8 + 6,
+      llvm::strictConversion);
   if (conv_result != llvm::conversionOK) {
     llvm_unreachable("conversion of valid code point to UTF-8 cannot fail");
   }
-  result.insert(result.end(), reinterpret_cast<char*>(utf8_code_units),
-                reinterpret_cast<char*>(dest_pos));
   return true;
 }
 
+// Appends a character to the buffer and advances the cursor.
+static auto AppendChar(char*& buffer_cursor, char append_char) -> void {
+  buffer_cursor[0] = append_char;
+  ++buffer_cursor;
+}
+
+// Appends the front of contents to the buffer and advances the cursor.
+static auto AppendFrontOfContents(char*& buffer_cursor,
+                                  llvm::StringRef contents, size_t len_or_npos)
+    -> void {
+  auto len =
+      len_or_npos == llvm::StringRef::npos ? contents.size() : len_or_npos;
+  memcpy(buffer_cursor, contents.data(), len);
+  buffer_cursor += len;
+}
+
 // Expand an escape sequence, writing the expanded value at `buffer_cursor`
 // and advancing the cursor past it. `content` is the string content, starting
 // from the first character after the escape sequence introducer (for example,
 // the `n` in `\n`), and will be updated to remove the leading escape sequence.
 static auto ExpandAndConsumeEscapeSequence(LexerDiagnosticEmitter& emitter,
                                            llvm::StringRef& content,
-                                           std::string& result) -> void {
+                                           char*& buffer_cursor) -> void {
   CARBON_CHECK(!content.empty()) << "should have escaped closing delimiter";
   char first = content.front();
   content = content.drop_front(1);
 
   switch (first) {
     case 't':
-      result += '\t';
+      AppendChar(buffer_cursor, '\t');
       return;
     case 'n':
-      result += '\n';
+      AppendChar(buffer_cursor, '\n');
       return;
     case 'r':
-      result += '\r';
+      AppendChar(buffer_cursor, '\r');
       return;
     case '"':
-      result += '"';
+      AppendChar(buffer_cursor, '"');
       return;
     case '\'':
-      result += '\'';
+      AppendChar(buffer_cursor, '\'');
       return;
     case '\\':
-      result += '\\';
+      AppendChar(buffer_cursor, '\\');
       return;
     case '0':
-      result += '\0';
+      AppendChar(buffer_cursor, '\0');
       if (!content.empty() && IsDecimalDigit(content.front())) {
         CARBON_DIAGNOSTIC(
             DecimalEscapeSequence, Error,
@@ -295,8 +319,8 @@ static auto ExpandAndConsumeEscapeSequence(LexerDiagnosticEmitter& emitter,
     case 'x':
       if (content.size() >= 2 && IsUpperHexDigit(content[0]) &&
           IsUpperHexDigit(content[1])) {
-        result +=
-            static_cast<char>(llvm::hexFromNibbles(content[0], content[1]));
+        AppendChar(buffer_cursor, static_cast<char>(llvm::hexFromNibbles(
+                                      content[0], content[1])));
         content = content.drop_front(2);
         return;
       }
@@ -311,7 +335,7 @@ static auto ExpandAndConsumeEscapeSequence(LexerDiagnosticEmitter& emitter,
         llvm::StringRef digits = remaining.take_while(IsUpperHexDigit);
         remaining = remaining.drop_front(digits.size());
         if (!digits.empty() && remaining.consume_front("}")) {
-          if (!ExpandUnicodeEscapeSequence(emitter, digits, result)) {
+          if (!ExpandUnicodeEscapeSequence(emitter, digits, buffer_cursor)) {
             break;
           }
           content = remaining;
@@ -335,15 +359,14 @@ static auto ExpandAndConsumeEscapeSequence(LexerDiagnosticEmitter& emitter,
   // If we get here, we didn't recognize this escape sequence and have already
   // issued a diagnostic. For error recovery purposes, expand this escape
   // sequence to itself, dropping the introducer (for example, `\q` -> `q`).
-  result += first;
+  AppendChar(buffer_cursor, first);
 }
 
 // Expand any escape sequences in the given string literal.
 static auto ExpandEscapeSequencesAndRemoveIndent(
     LexerDiagnosticEmitter& emitter, llvm::StringRef contents, int hash_level,
-    llvm::StringRef indent) -> std::string {
-  std::string result;
-  result.reserve(contents.size());
+    llvm::StringRef indent, char* buffer) -> llvm::StringRef {
+  char* buffer_cursor = buffer;
 
   llvm::SmallString<16> escape("\\");
   escape.resize(1 + hash_level, '#');
@@ -365,9 +388,9 @@ static auto ExpandEscapeSequencesAndRemoveIndent(
       }
     }
 
-    // Tracks the length of the result at the last time we expanded an escape
-    // to ensure we don't misinterpret it as unescaped when backtracking.
-    size_t last_escape_length = 0;
+    // Tracks the position at the last time we expanded an escape to ensure we
+    // don't misinterpret it as unescaped when backtracking.
+    char* buffer_last_escape = buffer_cursor;
 
     // Process the contents of the line.
     while (true) {
@@ -376,22 +399,24 @@ static auto ExpandEscapeSequencesAndRemoveIndent(
         return c == '\n' || c == '\\' ||
                (IsHorizontalWhitespace(c) && c != ' ');
       });
-      result += contents.substr(0, end_of_regular_text);
-      contents = contents.substr(end_of_regular_text);
-
-      if (contents.empty()) {
-        return result;
+      AppendFrontOfContents(buffer_cursor, contents, end_of_regular_text);
+      if (end_of_regular_text == llvm::StringRef::npos) {
+        return llvm::StringRef(buffer, buffer_cursor - buffer);
       }
+      contents = contents.drop_front(end_of_regular_text);
 
       if (contents.consume_front("\n")) {
         // Trailing whitespace in the source before a newline doesn't contribute
         // to the string literal value. However, escaped whitespace (like `\t`)
         // and any whitespace just before that does contribute.
-        while (!result.empty() && result.back() != '\n' &&
-               IsSpace(result.back()) && result.length() > last_escape_length) {
-          result.pop_back();
+        while (buffer_cursor > buffer_last_escape) {
+          char back = *(buffer_cursor - 1);
+          if (back == '\n' || !IsSpace(back)) {
+            break;
+          }
+          --buffer_cursor;
         }
-        result += '\n';
+        AppendChar(buffer_cursor, '\n');
         // Move on to the next line.
         break;
       }
@@ -412,7 +437,7 @@ static auto ExpandEscapeSequencesAndRemoveIndent(
               "escape sequence in a string literal.");
           emitter.Emit(contents.begin(), InvalidHorizontalWhitespaceInString);
           // Include the whitespace in the string contents for error recovery.
-          result += contents.substr(0, after_space);
+          AppendFrontOfContents(buffer_cursor, contents, after_space);
         }
         contents = contents.substr(after_space);
         continue;
@@ -420,7 +445,7 @@ static auto ExpandEscapeSequencesAndRemoveIndent(
 
       if (!contents.consume_front(escape)) {
         // This is not an escape sequence, just a raw `\`.
-        result += contents.front();
+        AppendChar(buffer_cursor, contents.front());
         contents = contents.drop_front(1);
         continue;
       }
@@ -432,14 +457,15 @@ static auto ExpandEscapeSequencesAndRemoveIndent(
       }
 
       // Handle this escape sequence.
-      ExpandAndConsumeEscapeSequence(emitter, contents, result);
-      last_escape_length = result.length();
+      ExpandAndConsumeEscapeSequence(emitter, contents, buffer_cursor);
+      buffer_last_escape = buffer_cursor;
     }
   }
 }
 
-auto StringLiteral::ComputeValue(LexerDiagnosticEmitter& emitter) const
-    -> std::string {
+auto StringLiteral::ComputeValue(llvm::BumpPtrAllocator& allocator,
+                                 LexerDiagnosticEmitter& emitter) const
+    -> llvm::StringRef {
   if (!is_terminated_) {
     return "";
   }
@@ -451,8 +477,20 @@ auto StringLiteral::ComputeValue(LexerDiagnosticEmitter& emitter) const
   }
   llvm::StringRef indent =
       multi_line_ ? CheckIndent(emitter, text_, content_) : llvm::StringRef();
-  return ExpandEscapeSequencesAndRemoveIndent(emitter, content_, hash_level_,
-                                              indent);
+  if (!content_needs_validation_ && (!multi_line_ || indent.empty())) {
+    return content_;
+  }
+
+  // "Expanding" escape sequences should only ever shorten content. As a
+  // consequence, the output string should always fit within this allocation.
+  // Although this may waste some space, it avoids a reallocation.
+  auto result = ExpandEscapeSequencesAndRemoveIndent(
+      emitter, content_, hash_level_, indent,
+      allocator.Allocate<char>(content_.size()));
+  CARBON_CHECK(result.size() <= content_.size())
+      << "Content grew from " << content_.size() << " to " << result.size()
+      << ": `" << content_ << "`";
+  return result;
 }
 
 }  // namespace Carbon::Lex

+ 14 - 5
toolchain/lex/string_literal.h

@@ -6,9 +6,9 @@
 #define CARBON_TOOLCHAIN_LEX_STRING_LITERAL_H_
 
 #include <optional>
-#include <string>
 
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Support/Allocator.h"
 #include "toolchain/diagnostics/diagnostic_emitter.h"
 
 namespace Carbon::Lex {
@@ -24,8 +24,13 @@ class StringLiteral {
 
   // Expand any escape sequences in the given string literal and compute the
   // resulting value. This handles error recovery internally and cannot fail.
-  auto ComputeValue(DiagnosticEmitter<const char*>& emitter) const
-      -> std::string;
+  //
+  // When content_needs_validation_ is false and the string has no indent to
+  // deal with, this can return the content directly. Otherwise, the allocator
+  // will be used for the StringRef.
+  auto ComputeValue(llvm::BumpPtrAllocator& allocator,
+                    DiagnosticEmitter<const char*>& emitter) const
+      -> llvm::StringRef;
 
   // Get the text corresponding to this literal.
   [[nodiscard]] auto text() const -> llvm::StringRef { return text_; }
@@ -46,10 +51,11 @@ class StringLiteral {
   struct Introducer;
 
   explicit StringLiteral(llvm::StringRef text, llvm::StringRef content,
-                         int hash_level, MultiLineKind multi_line,
-                         bool is_terminated)
+                         bool content_needs_validation, int hash_level,
+                         MultiLineKind multi_line, bool is_terminated)
       : text_(text),
         content_(content),
+        content_needs_validation_(content_needs_validation),
         hash_level_(hash_level),
         multi_line_(multi_line),
         is_terminated_(is_terminated) {}
@@ -61,6 +67,9 @@ class StringLiteral {
   // at the start of the closing `"""`. Leading whitespace is not removed from
   // either end.
   llvm::StringRef content_;
+  // Whether content needs validation, in particular due to either an escape
+  // (which needs modifications) or a tab character (which may cause a warning).
+  bool content_needs_validation_;
   // The number of `#`s preceding the opening `"` or `"""`.
   int hash_level_;
   // Whether this was a multi-line string literal.

+ 40 - 16
toolchain/lex/string_literal_benchmark.cpp

@@ -82,37 +82,61 @@ BENCHMARK(BM_IncompleteWithEscapes_Multiline);
 BENCHMARK(BM_IncompleteWithEscapes_MultilineDoubleQuote);
 BENCHMARK(BM_IncompleteWithEscapes_Raw);
 
-static void BM_SimpleStringValue(benchmark::State& state,
-                                 std::string_view introducer,
+static void BM_SimpleStringValue(benchmark::State& state, int size,
+                                 std::string_view introducer, bool add_escape,
                                  std::string_view terminator) {
+  llvm::BumpPtrAllocator allocator;
   std::string x(introducer);
-  x.append(100000, 'a');
+  x.append(size, 'a');
+  if (add_escape) {
+    // Adds a basic escape that forces ComputeValue to generate a new string.
+    x.append("\\\\");
+  }
   x.append(terminator);
   for (auto _ : state) {
-    StringLiteral::Lex(x)->ComputeValue(NullDiagnosticEmitter<const char*>());
+    StringLiteral::Lex(x)->ComputeValue(allocator,
+                                        NullDiagnosticEmitter<const char*>());
   }
 }
 
-static void BM_SimpleStringValue_Simple(benchmark::State& state) {
-  BM_SimpleStringValue(state, "\"", "\"");
+static void BM_ComputeValue_NoGenerate_Short(benchmark::State& state) {
+  BM_SimpleStringValue(state, 10, "\"", /*add_escape=*/false, "\"");
+}
+
+static void BM_ComputeValue_NoGenerate_Long(benchmark::State& state) {
+  BM_SimpleStringValue(state, 10000, "\"", /*add_escape=*/false, "\"");
+}
+
+static void BM_ComputeValue_WillGenerate_Short(benchmark::State& state) {
+  BM_SimpleStringValue(state, 10, "\"", /*add_escape=*/true, "\"");
 }
 
-static void BM_SimpleStringValue_Multiline(benchmark::State& state) {
-  BM_SimpleStringValue(state, "'''\n", "\n'''");
+static void BM_ComputeValue_WillGenerate_Long(benchmark::State& state) {
+  BM_SimpleStringValue(state, 10000, "\"", /*add_escape=*/true, "\"");
 }
 
-static void BM_SimpleStringValue_MultilineDoubleQuote(benchmark::State& state) {
-  BM_SimpleStringValue(state, "\"\"\"\n", "\n\"\"\"");
+static void BM_ComputeValue_WillGenerate_Multiline(benchmark::State& state) {
+  BM_SimpleStringValue(state, 10000, "'''\n", /*add_escape=*/true, "\n'''");
 }
 
-static void BM_SimpleStringValue_Raw(benchmark::State& state) {
-  BM_SimpleStringValue(state, "#\"", "\"#");
+static void BM_ComputeValue_WillGenerate_MultilineDoubleQuote(
+    benchmark::State& state) {
+  BM_SimpleStringValue(state, 10000, "\"\"\"\n", /*add_escape=*/true,
+                       "\n\"\"\"");
 }
 
-BENCHMARK(BM_SimpleStringValue_Simple);
-BENCHMARK(BM_SimpleStringValue_Multiline);
-BENCHMARK(BM_SimpleStringValue_MultilineDoubleQuote);
-BENCHMARK(BM_SimpleStringValue_Raw);
+static void BM_ComputeValue_WillGenerate_Raw(benchmark::State& state) {
+  BM_SimpleStringValue(state, 10000, "#\"", /*add_escape=*/true, "\"#");
+}
+
+BENCHMARK(BM_ComputeValue_NoGenerate_Short);
+BENCHMARK(BM_ComputeValue_NoGenerate_Long);
+
+BENCHMARK(BM_ComputeValue_WillGenerate_Short);
+BENCHMARK(BM_ComputeValue_WillGenerate_Long);
+BENCHMARK(BM_ComputeValue_WillGenerate_Multiline);
+BENCHMARK(BM_ComputeValue_WillGenerate_MultilineDoubleQuote);
+BENCHMARK(BM_ComputeValue_WillGenerate_Raw);
 
 }  // namespace
 }  // namespace Carbon::Lex

+ 2 - 1
toolchain/lex/string_literal_fuzzer.cpp

@@ -33,8 +33,9 @@ extern "C" int LLVMFuzzerTestOneInput(const unsigned char* data,
   // Check multiline flag was computed correctly.
   CARBON_CHECK(token->is_multi_line() == token->text().contains('\n'));
 
+  llvm::BumpPtrAllocator allocator;
   volatile auto value =
-      token->ComputeValue(NullDiagnosticEmitter<const char*>());
+      token->ComputeValue(allocator, NullDiagnosticEmitter<const char*>());
   (void)value;
 
   return 0;

+ 4 - 3
toolchain/lex/string_literal_test.cpp

@@ -25,13 +25,14 @@ class StringLiteralTest : public ::testing::Test {
     return *result;
   }
 
-  auto Parse(llvm::StringRef text) -> std::string {
+  auto Parse(llvm::StringRef text) -> llvm::StringRef {
     StringLiteral token = Lex(text);
     Testing::SingleTokenDiagnosticTranslator translator(text);
     DiagnosticEmitter<const char*> emitter(translator, error_tracker);
-    return token.ComputeValue(emitter);
+    return token.ComputeValue(allocator, emitter);
   }
 
+  llvm::BumpPtrAllocator allocator;
   ErrorTrackingDiagnosticConsumer error_tracker;
 };
 
@@ -311,7 +312,7 @@ TEST_F(StringLiteralTest, StringLiteralBadEscapeSequence) {
 
   for (llvm::StringLiteral test : testcases) {
     error_tracker.Reset();
-    auto value = Parse(test);
+    Parse(test);
     EXPECT_TRUE(error_tracker.seen_error()) << "`" << test << "`";
     // TODO: Test value produced by error recovery.
   }

+ 1 - 5
toolchain/lex/tokenized_buffer.cpp

@@ -652,12 +652,8 @@ class [[clang::internal_linkage]] TokenizedBuffer::Lexer {
     }
 
     if (literal->is_terminated()) {
-      // TODO: Refactor to reduce copies.
-      // https://github.com/carbon-language/carbon-lang/pull/3311#discussion_r1366048360
-      buffer_.computed_strings_.push_back(
-          std::make_unique<std::string>(literal->ComputeValue(emitter_)));
       auto string_id = buffer_.value_stores_->strings().Add(
-          *buffer_.computed_strings_.back());
+          literal->ComputeValue(buffer_.allocator_, emitter_));
       auto token = buffer_.AddToken({.kind = TokenKind::StringLiteral,
                                      .token_line = string_line,
                                      .column = string_column,

+ 4 - 0
toolchain/lex/tokenized_buffer.h

@@ -14,6 +14,7 @@
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/iterator.h"
 #include "llvm/ADT/iterator_range.h"
+#include "llvm/Support/Allocator.h"
 #include "llvm/Support/raw_ostream.h"
 #include "toolchain/base/index_base.h"
 #include "toolchain/base/value_store.h"
@@ -345,6 +346,9 @@ class TokenizedBuffer : public Printable<TokenizedBuffer> {
   auto PrintToken(llvm::raw_ostream& output_stream, Token token,
                   PrintWidths widths) const -> void;
 
+  // Used to allocate computed string literals.
+  llvm::BumpPtrAllocator allocator_;
+
   SharedValueStores* value_stores_;
   SourceBuffer* source_;