Pm/view propagation #4306

Draft · Priya2698 wants to merge 2 commits into main from pm/view_propagation

Conversation

Priya2698 (Collaborator)

No description provided.

github-actions bot commented Apr 24, 2025

Review updated until commit 6744b00

Description

  • Added view operation sharding support (see the sketch below)

  • Implemented reshaped ID handling in propagate_shardings.cpp

  • Updated tests to include multiple transform reshapes

  • Modified reference functions for loop split MHA
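
To make the summary concrete, the sketch below walks through the split-reshape case the pass now handles. It is a minimal illustration with hypothetical extents; helper names such as makeContigConcreteTensor and the split(axis, factor, inner_split) overload come from nvFuser's C++ API and test utilities, and exact names or signatures may differ from the PR.

// (Fragment of a fusion definition, inside an active FusionGuard as in the C++ tests.)
// Producer tv0 with logical domain [h]; consumer tv1 = reshape([h] -> [a, h/a]).
// Hypothetical extents: h = 8, a = 4, d = 2 devices.
TensorView* tv0 = makeContigConcreteTensor({8});
TensorView* tv1 = reshape(tv0, {8}, {4, 2});

// Loop-split sharding on the producer: [h] -> [DIDx{d}, h/d].
tv0->split(0, /*factor=*/2, /*inner_split=*/false);
tv0->axis(0)->parallelize(ParallelType::DIDx);

// The pass maps the sharded producer logical id (h) to the consumer's root id,
// follows the outer path of the reshape transforms to reach a, and re-applies
// the outer split there, so tv1's loop domain becomes [DIDx{d}, a/d, h/a].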


Changes walkthrough 📝

Relevant files

Enhancement

  • utils.cpp (csrc/multidevice/utils.cpp): Namespace adjustment and function implementation
      • Moved namespace closing brace
      • Added getInputsInTargetDomain function implementation
      +2/-2

  • propagate_shardings.cpp (csrc/preseg_passes/propagate_shardings.cpp): View operation sharding implementation
      • Added getReshapedIds, splitLike, and shardViewOp functions
      • Integrated shardViewOp into PropagateShardingsPass::runPass
      +161/-0

  • utils.h (csrc/multidevice/utils.h): Function declaration addition
      • Added getInputsInTargetDomain function declaration
      +4/-0

Tests

  • test_multidevice_sharding.cpp (tests/cpp/test_multidevice_sharding.cpp): New test case for multiple transforms
      • Added MultipleTransformReshape test case
      +25/-0

  • test_multidevice_transformer.cpp (tests/cpp/test_multidevice_transformer.cpp): Update reference function and test case
      • Updated reference_loop_split_mha function
      • Modified LoopSplitMHAFwd test case
      +33/-15

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Complexity

    The shardViewOp function is complex, with multiple nested loops and conditionals. Breaking it into smaller, focused helpers would improve readability and maintainability.

    void shardViewOp(ViewOp* view_op, int64_t& did_pos) {
      // This implementation asserts that only one sharding is applied on the
      // reshaped ids. Inner split is not supported. The cases are:
      // 1. Split reshape: [h] -> [a, h/a]. Sharding on h is applied to a in
      // consumer.
      // 2. Merge reshape: [a, h/a] -> [h]. Sharding on a is applied to h in
      // consumer.
      // 3. Multiple splits or merge reshapes: [x, y, z] -> [xyz]. Sharding on x and
      // xyz. Similarly for the corresponding split reshape.
      // 4. Independent splits or merge reshapes: [w, x, y, z] -> [wx, yz]. Sharding
      // is on w and y. In the consumer, it is applied to wx and yz. An improvement
  // is to support multiple levels of sharding (not a real case yet) if they
  // are all outer splits. For example, for the reshape [h] -> [a, h/a] where
      // the h is sharded twice: [h] -> [cp, h/cp] -> [cp, tp, h/(cp*tp)]
    
  // A more general approach may be to "undo" the reshape (reverse transforms
      // from root to logical domain), followed by simplification of the consumer
      // loop domain to move DID upwards.
    
      TensorView* producer = view_op->in();
      TensorView* consumer = view_op->out();
    
      const std::unordered_map<IterDomain*, IterDomain*>& c2p =
          PairwiseLogicalDomainMap(producer, consumer).mapConsumerToProducer();
      const std::unordered_map<IterDomain*, IterDomain*>& p2c =
          PairwiseLogicalDomainMap(producer, consumer).mapProducerToConsumer();
      auto [p_logical_reshaped_ids, c_root_reshaped_ids] =
          getReshapedIds(view_op, c2p);
    
      auto p_loop_domain = producer->getLoopDomain();
      auto c_loop_domain = consumer->getLoopDomain();
      auto c_logical_domain = consumer->getLogicalDomain();
    
      // Track number of DID axis on reshaped ids that were propagated to the
      // consumer. These will not be included in TransformPropagator.
      int64_t num_reshape_shardings = 0;
    
      for (auto idx : c10::irange(did_pos)) {
        IterDomain* p_did = p_loop_domain.at(idx);
        NVF_ERROR(p_did->isDeviceDim());
    
        auto p_transforms = DependencyCheck::getAllExprsBetween(
            {p_logical_reshaped_ids.begin(), p_logical_reshaped_ids.end()},
            {p_loop_domain.at(idx)});
    
        if (p_transforms.empty()) {
          // This device axis is not on reshaped ids. We will use the
          // TransformPropagator.
          continue;
        }
    
        if (p_transforms.size() > 1) {
          // This reshape has been transformed.
          // This can happen, for example, when there is a consumer-to-producer
          // propagation before this pass.
          // We will attempt to use TransformPropagator for this DID axis.
          continue;
        }
    
        NVF_ERROR(
            p_transforms.front()->isA<Split>(),
            "Expected a split transform producing the did axis.");
        NVF_ERROR(
            TensorDomain::sameAs(c_logical_domain, c_loop_domain),
            "Sharding a previously transformed reshape is not supported.");
    
        num_reshape_shardings++;
    
        // Find the producer logical id that is sharded.
        // We expect the outermost reshaped id to be sharded and follow the
        // outermost path traversing the transforms
        auto* p_did_split = p_did->definition()->as<Split>();
        IterDomain* p_logical_did = p_did_split->in();
    
        // Find the mapping of the corresponding producer logical id in consumer
        // root.
        IterDomain* c_root_did = p2c.at(p_logical_did);
    
        // Get the reshape transforms corresponding to this root id.
        // We use the c_root_did to only find the reshape IDs related to this did.
        auto reshape_transforms = DependencyCheck::getAllExprsBetween(
            {c_root_did},
            {consumer->getLogicalDomain().begin(),
             consumer->getLogicalDomain().end()});
    
        // Obtain the logical axis sharded in the consumer.
        IterDomain* c_logical_did = c_root_did;
        for (auto transform : reshape_transforms) {
          if (transform->isA<Split>()) {
            c_logical_did = transform->as<Split>()->outer();
          }
          if (transform->isA<Merge>()) {
            NVF_ERROR(
                c_logical_did == transform->as<Merge>()->outer(),
                "Expected the sharding to be on the outer reshaped id.");
            c_logical_did = transform->as<Merge>()->out();
          }
        }
    
        int64_t sharded_axis = std::distance(
            c_loop_domain.begin(),
            std::find(c_loop_domain.begin(), c_loop_domain.end(), c_logical_did));
    
        // TODO: Check for divisibility of the consumer axis by the split factor.
        splitLike(consumer, sharded_axis, p_did_split);
        consumer->axis(sharded_axis)->parallelize(p_did->getParallelType());
    
    // Move this DID axis to the back of the DID block so that, once did_pos is
    // reduced below, TransformPropagator does not process it.
        producer->reorder({{idx, did_pos - 1}});
      }
    
  did_pos -= num_reshape_shardings;
}
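
For reference while reviewing the excerpt above, here is one plausible shape for the splitLike helper it calls. This is a hypothetical sketch, assuming the helper simply re-applies the reference split's factor and orientation onto the given consumer axis; the actual implementation in propagate_shardings.cpp may differ.

// Hypothetical sketch of splitLike (not the PR's code): mirror the producer's
// DID-producing split onto the consumer axis so the consumer gains a matching
// [device, remainder] pair at sharded_axis.
void splitLike(TensorView* tv, int64_t axis, Split* reference_split) {
  tv->split(axis, reference_split->factor(), reference_split->innerSplit());
}
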
    Error Handling

    The shardViewOp function includes several NVF_ERROR assertions. Confirm that they cover all unsupported configurations (for example, inner splits and multi-level shardings on a reshaped id), not just the cases exercised by the current tests; one possible additional check is sketched after the excerpt below.

    IterDomain* p_did = p_loop_domain.at(idx);
    NVF_ERROR(p_did->isDeviceDim());
    
    auto p_transforms = DependencyCheck::getAllExprsBetween(
        {p_logical_reshaped_ids.begin(), p_logical_reshaped_ids.end()},
        {p_loop_domain.at(idx)});
    
    if (p_transforms.empty()) {
      // This device axis is not on reshaped ids. We will use the
      // TransformPropagator.
      continue;
    }
    
    if (p_transforms.size() > 1) {
      // This reshape has been transformed.
      // This can happen, for example, when there is a consumer-to-producer
      // propagation before this pass.
      // We will attempt to use TransformPropagator for this DID axis.
      continue;
    }
    
    NVF_ERROR(
        p_transforms.front()->isA<Split>(),
        "Expected a split transform producing the did axis.");
    NVF_ERROR(
        TensorDomain::sameAs(c_logical_domain, c_loop_domain),
        "Sharding a previously transformed reshape is not supported.");
    
    Test Coverage

    The new MultipleTransformReshape test covers a specific case. Additional cases, such as merge reshapes and independent reshapes within one view op, would help ensure the robustness of the new functionality; a sketch of one such case follows the excerpt below.

    const auto options =
        at::TensorOptions().dtype(at_dtype).device(communicator_->device());
    auto x_ = at::randn({B * S, E}, options);
    auto ln0_w_ = at::randn(E, options).to(at::kFloat);
    auto ln0_b_ = at::randn(E, options).to(at::kFloat);
    auto mha_w0_ = at::randn({3 * E, E}, options) * kParamScale;
    auto mha_b0_ = at::randn({3 * E}, options) * kParamScale;
    auto mha_w1_ = at::randn({E, E}, options) * kParamScale;
    auto mha_b1_ = at::randn({E}, options) * kParamScale;
    auto ln1_w_ = at::randn(E, options).to(at::kFloat);
    auto ln1_b_ = at::randn(E, options).to(at::kFloat);
    auto mlp_w0_ = at::randn({4 * E, E}, options) * kParamScale;
    auto mlp_b0_ = at::randn({4 * E}, options) * kParamScale;
    auto grad_ = at::randn({B * S, E}, options) * kParamScale;
    auto mlp_w1_ = at::randn({E, 4 * E}, options) * kParamScale;
    auto mlp_b1_ = at::randn({E}, options) * kParamScale;
    
    at::manual_seed(getATenRandomSeed());
    // Run forward pass up to MLP to generate cached inputs
    auto [ln0_, ln0_mean_, ln0_rstd_] = at::native_layer_norm(
        x_.to(at::kFloat), norm_shape, ln0_w_, ln0_b_, kEps);
    auto mha_in_ = ln0_.to(at_dtype);
    auto mha_out_ = reference_mha(mha_in_, mha_w0_, mha_b0_, mha_w1_, mha_b1_);
    auto resid0_ = mha_out_[3] + x_.to(at::kFloat);
    auto [ln1_, ln1_mean_, ln1_rstd_] =
        at::native_layer_norm(resid0_, norm_shape, ln1_w_, ln1_b_, kEps);
    auto mlp_in_ = ln1_.to(at_dtype);
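
To broaden coverage beyond MultipleTransformReshape, a merge-reshape variant (case 2 in shardViewOp's comment) could be added along these lines. This is a sketch only: it assumes the MultiDeviceTest fixture and helpers such as makeContigConcreteTensor and DeviceMesh::createForNumDevices used elsewhere in tests/cpp, and it omits execution and validation.

// Hypothetical test sketch (not in the PR): the producer's outer axis carries a
// loop-split DIDx sharding, and the pass is expected to shard the merged
// output axis of the reshape the same way.
TEST_F(MultiDeviceTest, ShardedMergeReshape) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());

  const int64_t d = communicator_->size();
  const int64_t a = d * 2;
  const int64_t b = 3;

  TensorView* in = makeContigConcreteTensor({a, b});
  fusion->addInput(in);
  TensorView* out = reshape(in, {a, b}, {a * b});
  fusion->addOutput(out);

  const DeviceMesh mesh = DeviceMesh::createForNumDevices(d);
  in->setDeviceMesh(mesh);
  out->setDeviceMesh(mesh);

  // Loop-split sharding on the producer's a: [a, b] -> [DIDx{d}, a/d, b].
  in->split(0, d, /*inner_split=*/false);
  in->axis(0)->parallelize(ParallelType::DIDx);

  // Expectation: after preseg passes, out's merged a*b axis is outer-split by d
  // and parallelized with DIDx, matching the producer's sharding.
}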

    Priya2698 closed this Apr 24, 2025
    Priya2698 force-pushed the pm/view_propagation branch from f37e75d to 95c9bde on April 24, 2025 at 23:24
    Priya2698 reopened this Apr 25, 2025