ソースを参照

Make more use of llvm::zip and enumerate in explorer (#3135)

I was initially looking at llvm::seq, but then looking through uses it
seemed like enumerate/zip would be best for these, allowing less direct
indexing and more range-based looping.
Jon Ross-Perkins 2 年 前
コミット
c3da338ce4

+ 16 - 19
explorer/interpreter/interpreter.cpp

@@ -11,7 +11,6 @@
 #include <optional>
 #include <random>
 #include <utility>
-#include <variant>
 #include <vector>
 
 #include "common/check.h"
@@ -28,13 +27,14 @@
 #include "explorer/base/trace_stream.h"
 #include "explorer/interpreter/action.h"
 #include "explorer/interpreter/action_stack.h"
+#include "explorer/interpreter/heap.h"
 #include "explorer/interpreter/pattern_match.h"
-#include "explorer/interpreter/stack.h"
 #include "explorer/interpreter/type_utils.h"
 #include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/Support/Casting.h"
-#include "llvm/Support/Error.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/raw_ostream.h"
 
@@ -290,10 +290,9 @@ auto Interpreter::EvalPrim(Operator op, Nonnull<const Value*> /*static_type*/,
 auto Interpreter::CreateStruct(const std::vector<FieldInitializer>& fields,
                                const std::vector<Nonnull<const Value*>>& values)
     -> Nonnull<const Value*> {
-  CARBON_CHECK(fields.size() == values.size());
   std::vector<NamedValue> elements;
-  for (size_t i = 0; i < fields.size(); ++i) {
-    elements.push_back({fields[i].name(), values[i]});
+  for (const auto [field, value] : llvm::zip_equal(fields, values)) {
+    elements.push_back({field.name(), value});
   }
 
   return arena_->New<StructValue>(std::move(elements));
@@ -741,14 +740,11 @@ auto Interpreter::Convert(Nonnull<const Value*> value,
           return value;
         }
       }
-      CARBON_CHECK(tuple->elements().size() ==
-                   destination_element_types.size());
       std::vector<Nonnull<const Value*>> new_elements;
-      for (size_t i = 0; i < tuple->elements().size(); ++i) {
-        CARBON_ASSIGN_OR_RETURN(
-            Nonnull<const Value*> val,
-            Convert(tuple->elements()[i], destination_element_types[i],
-                    source_loc));
+      for (const auto [element, dest_type] :
+           llvm::zip_equal(tuple->elements(), destination_element_types)) {
+        CARBON_ASSIGN_OR_RETURN(Nonnull<const Value*> val,
+                                Convert(element, dest_type, source_loc));
         new_elements.push_back(val);
       }
       return arena_->New<TupleValue>(std::move(new_elements));
@@ -1534,17 +1530,18 @@ auto Interpreter::StepExp() -> ErrorOr<Success> {
         // -> { {C',E',F'} :: {C, E, F} :: S, H}
         // Prepare parameters tuple.
         std::vector<Nonnull<const Value*>> param_values;
-        for (int i = 1; i <= num_args; ++i) {
-          param_values.push_back(act.results()[i]);
+        for (const auto& arg_result :
+             llvm::ArrayRef(act.results()).slice(1, num_args)) {
+          param_values.push_back(arg_result);
         }
         const auto* param_tuple = arena_->New<TupleValue>(param_values);
         // Prepare witnesses.
         ImplWitnessMap witnesses;
         if (num_witnesses > 0) {
-          int i = 1 + num_args;
-          for (const auto& [impl_bind, impl_exp] : call.witnesses()) {
-            witnesses[impl_bind] = act.results()[i];
-            ++i;
+          for (const auto [witness, result] : llvm::zip(
+                   call.witnesses(),
+                   llvm::ArrayRef(act.results()).drop_front(1 + num_args))) {
+            witnesses[witness.first] = result;
           }
         }
         return CallFunction(call, act.results()[0], param_tuple,

+ 0 - 9
explorer/interpreter/interpreter.h

@@ -5,20 +5,11 @@
 #ifndef CARBON_EXPLORER_INTERPRETER_INTERPRETER_H_
 #define CARBON_EXPLORER_INTERPRETER_INTERPRETER_H_
 
-#include <optional>
-#include <utility>
-#include <vector>
-
 #include "common/ostream.h"
 #include "explorer/ast/ast.h"
-#include "explorer/ast/declaration.h"
 #include "explorer/ast/expression.h"
-#include "explorer/ast/pattern.h"
 #include "explorer/ast/value.h"
 #include "explorer/base/trace_stream.h"
-#include "explorer/interpreter/action.h"
-#include "explorer/interpreter/heap.h"
-#include "llvm/ADT/ArrayRef.h"
 
 namespace Carbon {
 

+ 25 - 33
explorer/interpreter/type_checker.cpp

@@ -5,7 +5,6 @@
 #include "explorer/interpreter/type_checker.h"
 
 #include <deque>
-#include <iterator>
 #include <map>
 #include <optional>
 #include <set>
@@ -32,11 +31,11 @@
 #include "explorer/interpreter/pattern_match.h"
 #include "explorer/interpreter/type_structure.h"
 #include "explorer/interpreter/type_utils.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/TinyPtrVector.h"
 #include "llvm/Support/Casting.h"
-#include "llvm/Support/Error.h"
 #include "llvm/Support/SaveAndRestore.h"
 
 using llvm::cast;
@@ -469,11 +468,11 @@ auto TypeChecker::IsBuiltinConversion(SourceLocation source_loc,
             break;
           }
           bool all_ok = true;
-          for (size_t i = 0; i < source_tuple.elements().size(); ++i) {
+          for (const auto [source_elem, dest_elem] : llvm::zip_equal(
+                   source_tuple.elements(), destination_tuple.elements())) {
             CARBON_ASSIGN_OR_RETURN(
                 bool convertible,
-                IsImplicitlyConvertible(source_loc, source_tuple.elements()[i],
-                                        destination_tuple.elements()[i],
+                IsImplicitlyConvertible(source_loc, source_elem, dest_elem,
                                         impl_scope,
                                         allow_user_defined_conversions));
             if (!convertible) {
@@ -680,7 +679,8 @@ auto TypeChecker::BuildBuiltinConversion(Nonnull<Expression*> source,
             return conversion_failed();
           }
           std::vector<Nonnull<Expression*>> converted_elements;
-          for (size_t i = 0; i < source_tuple.elements().size(); ++i) {
+          for (const auto [i, dest_elem] :
+               llvm::enumerate(destination_tuple.elements())) {
             auto* elem = arena_->New<IndexExpression>(
                 source->source_loc(), source,
                 arena_->New<IntLiteral>(source->source_loc(), i));
@@ -688,7 +688,7 @@ auto TypeChecker::BuildBuiltinConversion(Nonnull<Expression*> source,
             CARBON_ASSIGN_OR_RETURN(
                 Nonnull<Expression*> converted,
                 ImplicitlyConvert("implicit conversion", impl_scope, elem,
-                                  destination_tuple.elements()[i]));
+                                  dest_elem));
             converted_elements.push_back(converted);
           }
           auto* result = arena_->New<TupleLiteral>(
@@ -853,12 +853,12 @@ auto TypeChecker::ImplicitlyConvert(std::string_view context,
                << "`" << *destination << "` of different length";
       }
       std::vector<Nonnull<Expression*>> converted_elements;
-      for (size_t i = 0; i < source_tuple->fields().size(); ++i) {
+      for (const auto [source_field, dest_elem] : llvm::zip_equal(
+               source_tuple->fields(), destination_tuple->elements())) {
         CARBON_ASSIGN_OR_RETURN(
             Nonnull<Expression*> converted,
-            ImplicitlyConvert("implicit conversion", impl_scope,
-                              source_tuple->fields()[i],
-                              destination_tuple->elements()[i]));
+            ImplicitlyConvert("implicit conversion", impl_scope, source_field,
+                              dest_elem));
         converted_elements.push_back(converted);
       }
       auto* result = arena_->New<TupleLiteral>(source->source_loc(),
@@ -1163,10 +1163,10 @@ auto TypeChecker::ArgumentDeduction::Deduce(Nonnull<const Value*> param,
                << param_tup.elements().size() << " but got "
                << arg_tup.elements().size();
       }
-      for (size_t i = 0; i < param_tup.elements().size(); ++i) {
-        CARBON_RETURN_IF_ERROR(Deduce(param_tup.elements()[i],
-                                      arg_tup.elements()[i],
-                                      allow_implicit_conversion));
+      for (const auto [param_elem, arg_elem] :
+           llvm::zip_equal(param_tup.elements(), arg_tup.elements())) {
+        CARBON_RETURN_IF_ERROR(
+            Deduce(param_elem, arg_elem, allow_implicit_conversion));
       }
       return Success();
     }
@@ -1207,10 +1207,8 @@ auto TypeChecker::ArgumentDeduction::Deduce(Nonnull<const Value*> param,
                          << "duplicate field name?";
         }
       } else {
-        size_t smaller_size = std::min(param_fields.size(), arg_fields.size());
-        for (size_t i = 0; i < smaller_size; ++i) {
-          NamedValue param_field = param_fields[i];
-          NamedValue arg_field = arg_fields[i];
+        for (const auto [param_field, arg_field] :
+             llvm::zip(param_fields, arg_fields)) {
           if (param_field.name != arg_field.name) {
             return ProgramError(source_loc_)
                    << "mismatch in field names, `" << param_field.name
@@ -1592,8 +1590,7 @@ class TypeChecker::ConstraintTypeBuilder {
   // Adds an `impls` constraint -- `T impls C` if not already present.
   // Returns the index of the impls constraint within the self witness.
   auto AddImplsConstraint(ImplsConstraint impls) -> int {
-    for (int i = 0; i != static_cast<int>(impls_constraints_.size()); ++i) {
-      ImplsConstraint& existing = impls_constraints_[i];
+    for (const auto [i, existing] : llvm::enumerate(impls_constraints_)) {
       if (TypeEqual(existing.type, impls.type, std::nullopt) &&
           TypeEqual(existing.interface, impls.interface, std::nullopt)) {
         return i;
@@ -2508,9 +2505,7 @@ auto TypeChecker::DeduceCallBindings(
 
   // Deduce and/or convert each argument to the corresponding
   // parameter.
-  for (size_t i = 0; i < params.size(); ++i) {
-    const Value* param = params[i];
-    Expression* arg = args[i];
+  for (const auto [i, param, arg] : llvm::enumerate(params, args)) {
     if (!generic_params.empty() && generic_params.front().index == i) {
       // The parameter is a `:!` binding. Collect its argument so we can
       // evaluate it when we're done with deduction.
@@ -3658,11 +3653,10 @@ auto TypeChecker::TypeCheckExpImpl(Nonnull<Expression*> e,
 
           // Collect the top-level generic parameters and their constraints.
           std::vector<FunctionType::GenericParameter> generic_parameters;
-          llvm::ArrayRef<Nonnull<const Pattern*>> params =
-              param_name.params().fields();
-          for (size_t i = 0; i != params.size(); ++i) {
+          for (const auto [i, param] :
+               llvm::enumerate(param_name.params().fields())) {
             // TODO: Should we disallow all other kinds of top-level params?
-            if (const auto* binding = dyn_cast<GenericBinding>(params[i])) {
+            if (const auto* binding = dyn_cast<GenericBinding>(param)) {
               generic_parameters.push_back({{}, i, binding});
             }
           }
@@ -4390,8 +4384,7 @@ auto TypeChecker::TypeCheckPattern(
                           cast<TupleType>(**expected).elements().size()) {
         return ProgramError(tuple.source_loc()) << "tuples of different length";
       }
-      for (size_t i = 0; i < tuple.fields().size(); ++i) {
-        Nonnull<Pattern*> field = tuple.fields()[i];
+      for (const auto [i, field] : llvm::enumerate(tuple.fields())) {
         std::optional<Nonnull<const Value*>> expected_field_type;
         if (expected) {
           expected_field_type = cast<TupleType>(**expected).elements()[i];
@@ -4946,9 +4939,8 @@ auto TypeChecker::DeclareCallableDeclaration(Nonnull<CallableDeclaration*> f,
   // Keep track of any generic parameters and nested generic bindings in the
   // parameter pattern.
   std::vector<FunctionType::GenericParameter> generic_parameters;
-  for (size_t i = 0; i != f->param_pattern().fields().size(); ++i) {
-    Pattern* param_pattern = f->param_pattern().fields()[i];
-
+  for (const auto [i, param_pattern] :
+       llvm::enumerate(f->param_pattern().fields())) {
     size_t old_size = all_bindings.size();
     CollectAndNumberGenericBindingsInPattern(param_pattern, all_bindings);