Topic 5: What are FSDP and YaFSDP?

We discuss breakthroughs in GPU optimization techniques and explore how YaFSDP surpasses FSDP in efficiency and scalability for large language models.

When discussing the future of AI, skeptics often point to the immense resource demands of training foundation models and the limitations in current optimization techniques. The costs of training these behemoths can be prohibitive, even for well-funded institutions.

GPU optimization techniques like Fully Sharded Data Parallel (FSDP) and its enhanced version, YaFSDP, present a promising solution. FSDP allows models to be divided across multiple GPUs, reducing memory overhead and speeding up training. YaFSDP builds on this, offering even greater efficiency and scalability. Let’s dive into the mechanics of FSDP and YaFSDP and discover how they are transforming AI optimization!

In today’s episode, we will cover:

  • What is Fully Sharded Data Parallel (FSDP)?

  • What are the limitations of FSDP?

  • Here comes YaFSDP

  • Here’s how YaFSDP works

  • What is YaFSDP especially good at? Performance Gains

  • Original resources

What is Fully Sharded Data Parallel (FSDP)?

Traditional data parallelism, as exemplified by Distributed Data Parallel (DDP) in PyTorch, replicates the entire model across multiple GPUs and synchronizes gradients to ensure model consistency. However, this approach is limited by the memory capacity of individual GPUs, making it challenging to train increasingly large models.
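
To make the replicate-and-synchronize idea concrete, here is a toy simulation in plain Python. It is only a sketch: the "ranks" are list entries rather than GPUs, the model is a two-parameter linear fit, and the helper names (local_gradient, all_reduce_mean, ddp_step) are hypothetical stand-ins, not PyTorch APIs.

```python
# Toy simulation of DDP-style data parallelism (no GPUs, no PyTorch).
# Every "rank" keeps a full copy of the weights and sees its own slice of the data.

def local_gradient(weights, batch):
    """Gradient of mean squared error for the model y = w0 * x + w1."""
    g0 = g1 = 0.0
    for x, y in batch:
        err = weights[0] * x + weights[1] - y
        g0 += 2 * err * x
        g1 += 2 * err
    n = len(batch)
    return [g0 / n, g1 / n]

def all_reduce_mean(per_rank_grads):
    """Stand-in for an all-reduce: every rank receives the same averaged gradient."""
    world_size = len(per_rank_grads)
    return [sum(g[i] for g in per_rank_grads) / world_size
            for i in range(len(per_rank_grads[0]))]

def ddp_step(replicas, shards, lr=0.01):
    grads = [local_gradient(w, shard) for w, shard in zip(replicas, shards)]
    avg = all_reduce_mean(grads)
    # Each replica applies the identical update, so the copies stay in sync.
    return [[w_i - lr * g_i for w_i, g_i in zip(w, avg)] for w in replicas]

data = [(float(x), 3.0 * x + 1.0) for x in range(8)]
world_size = 4
shards = [data[r::world_size] for r in range(world_size)]   # each rank's slice of the data
replicas = [[0.0, 0.0] for _ in range(world_size)]          # full model copy per rank
for _ in range(5):
    replicas = ddp_step(replicas, shards)
```

Note that every rank stores the whole model, which is exactly the memory ceiling that sharding removes.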

In September 2023, researchers from Meta AI presented their work on PyTorch Fully Sharded Data Parallel (FSDP), which solves the problem of training large neural network models that exceed the memory capacity of individual GPUs. This problem is critical as large models can deliver superior performance across various domains but remain accessible only to a few advanced users and industry leaders due to technical barriers.

To address this, the team designed FSDP by integrating it closely with PyTorch's core components like Tensor implementation, dispatcher system, and CUDA memory caching allocator. This integration ensures non-intrusive user experiences and high training efficiency.

Image Credit: The original FSDP paper

FSDP works by sharding model parameters, thus optimizing resource utilization across different hardware configurations. It incorporates techniques like deferred initialization, where models are created on a dummy device and then sharded unit by unit on a real GPU. It also uses configurable sharding strategies to match cluster topologies and various communication optimizations to overlap communication with computation. These methods collectively enable FSDP to handle significantly larger models than Distributed Data Parallel while maintaining comparable performance and near-linear scalability in TFLOPS. But FSDP has a few important limitations.
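
The core sharding idea can be sketched in a few lines of plain Python. This is an illustration under simplifying assumptions, not FSDP's implementation: parameters are a flat list of floats, the functions shard and all_gather are hypothetical helpers, and "ranks" are list indices.

```python
import math

def shard(flat_params, world_size):
    """Split a flat parameter list into equal-sized shards, zero-padding the tail."""
    shard_size = math.ceil(len(flat_params) / world_size)
    padded = flat_params + [0.0] * (shard_size * world_size - len(flat_params))
    return [padded[r * shard_size:(r + 1) * shard_size] for r in range(world_size)]

def all_gather(shards, original_len):
    """Stand-in for all-gather: rebuild the full parameters from every rank's shard."""
    full = [v for s in shards for v in s]
    return full[:original_len]   # drop the padding

unit_params = [0.1 * i for i in range(10)]   # one sharded unit (e.g. a layer), flattened
world_size = 4
shards = shard(unit_params, world_size)

# Between uses, each rank keeps only its own shard resident...
resident = len(shards[0])
# ...and the full unit is rebuilt just before compute, then discarded again.
gathered = all_gather(shards, len(unit_params))
```

With 10 parameters and 4 ranks, each rank holds 3 values instead of 10; the gap widens as the model and the number of ranks grow.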

What are the limitations of FSDP?

  • It might not always match local training results exactly, especially with optimizers that rely on unsharded values or specific tensor structures.

  • Handling shared parameters can be tricky, and if not done right, it can cause errors and use up more memory than necessary.

  • Even with optimizations, there can still be a lot of communication overhead, particularly with larger models or more GPUs, which requires careful strategy adjustments.

  • Setting up FSDP can be complicated, especially for large models, sometimes needing other methods that come with their own challenges.

  • Mixing FSDP with other parallelism methods like pipeline or tensor parallelism can be tough and needs careful setup to avoid too much overhead.

  • Additionally, because it integrates deeply with PyTorch, FSDP can be harder to debug and troubleshoot compared to simpler data parallelism methods.

As the authors of the YaFSDP approach wrote: “Despite all these advantages, there are also issues that we faced:

  1. FSDP dynamically allocates memory for layers and sometimes requires much more memory than is actually necessary.

  2. During backward passes, we came across a phenomenon that we called the “give-way effect”.

The first line here is the computation stream, and the other lines represent communication streams.

So what’s happening in the profile? Before the reduce_scatter operation (blue), there are many preparatory computations (small operations under the communications). The small computations run in parallel with the main computation stream, severely slowing down communications. This results in large gaps between communications, and consequently, the same gaps occur in the computation stream.”

Here comes YaFSDP

Yet Another Fully Sharded Data Parallel (YaFSDP) was developed and open-sourced by Yandex in May 2024. It improves upon FSDP by offering more efficient memory management, reducing redundant computations, and optimizing communication and synchronization during the training of large language models.

Here’s how YaFSDP works

  1. Layer Sharding: Instead of sharding individual parameters, YaFSDP shards entire layers. This approach maintains efficient communication and reduces redundancy. Each GPU handles a different shard of the model, minimizing memory usage.

Image Credit: Yandex at Medium

  2. Buffer Pre-allocation: YaFSDP pre-allocates buffers for all necessary data, ensuring that memory management by the Torch allocator does not introduce inefficiencies. This method uses two buffers for intermediate weights and gradients, alternating between odd and even layers.

Image Credit: Yandex at Medium
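
The odd/even buffer alternation described above can be sketched as follows. This is a toy model in plain Python, not YaFSDP's code: the class PreallocatedBuffers and its load method are hypothetical names, and Python lists stand in for GPU memory.

```python
class PreallocatedBuffers:
    """Two fixed buffers shared by all layers: even layers use one, odd layers the other."""

    def __init__(self, max_layer_numel):
        # Allocated once up front; nothing new is allocated inside the training loop.
        self.buffers = [[0.0] * max_layer_numel, [0.0] * max_layer_numel]

    def load(self, layer_idx, layer_weights):
        buf = self.buffers[layer_idx % 2]
        assert len(layer_weights) <= len(buf), "buffer must fit the largest layer"
        # Same-length slice assignment overwrites in place and keeps the buffer size fixed.
        buf[:len(layer_weights)] = layer_weights
        return buf

layers = [[float(i)] * 4 for i in range(6)]      # six toy layers of four weights each
pool = PreallocatedBuffers(max_layer_numel=4)

buffer_ids = set()
used = []
for i, weights in enumerate(layers):
    buf = pool.load(i, weights)    # layer i lands in buffer i % 2
    buffer_ids.add(id(buf))
    used.append(i % 2)
```

Because consecutive layers never share a buffer, the weights for the next layer can be gathered into one buffer while the current layer still computes from the other, and total buffer memory stays constant regardless of model depth.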

Memory Consumption Optimization

YaFSDP significantly reduces memory consumption by:

  1. Efficient Buffer Use: Buffers store intermediate values and consume a constant amount of memory.

  2. Activation Checkpointing: Only essential activations are stored during the forward pass; the rest are recomputed during the backward pass, which significantly reduces memory usage. For instance, when training a Llama 2 70B model with a batch size of 8192 tokens, checkpointing can reduce activation storage from over 110 GB to 5 GB. However, this technique adds computational overhead, which YaFSDP can mitigate by optimizing memory usage and avoiding activation checkpointing for some layers.

  3. Sharding Weights, Gradients, and Optimizer States: These components' memory consumption tends to approach zero as the number of processes increases, minimizing duplication.
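
A back-of-the-envelope calculation shows why sharded state shrinks toward zero per GPU. The figure of 16 bytes per parameter is an assumption (a commonly quoted estimate for mixed-precision training with Adam, covering weights, gradients, and optimizer states); it is not taken from the YaFSDP authors, and the function name is hypothetical.

```python
def sharded_state_gb(num_params, world_size, bytes_per_param=16):
    """Per-GPU memory (GB) for weights + gradients + optimizer states when fully sharded.

    bytes_per_param=16 is an assumed estimate for mixed-precision Adam training.
    """
    return num_params * bytes_per_param / world_size / 1e9

num_params = 70e9   # a 70B-parameter model
per_gpu = {n: sharded_state_gb(num_params, n) for n in (8, 64, 256, 1024)}

for n, gb in per_gpu.items():
    print(f"{n:>5} GPUs: {gb:8.2f} GB per GPU")
```

Under this assumption the sharded state drops from about 140 GB per GPU on 8 GPUs to roughly 1 GB on 1024 GPUs, while the pre-allocated buffers and activations do not shrink with the number of processes.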

Communication Optimization

YaFSDP enhances communication efficiency by:

  1. Overlapping Communication with Computation: Using CUDA streams, YaFSDP manages concurrent computations and communications effectively. Two streams are used: one for computation and one for communication, synchronized by events to ensure correct operation order.

  2. Reducing Communication Overhead: By ensuring that data transfers occur only when necessary and using techniques to minimize redundant operations, YaFSDP improves overall efficiency.
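
The benefit of running communication on a separate stream can be illustrated with a simple timeline model. This is a sketch with made-up per-layer timings, not a measurement and not CUDA code: it assumes one communication stream, one computation stream, and two weight buffers, so the gather for layer i cannot start until layer i-2 has finished computing.

```python
def step_time_serial(compute, comm):
    """No overlap: each layer waits for its own gather, then computes."""
    return sum(comm) + sum(compute)

def step_time_overlapped(compute, comm):
    """Two streams, two buffers: gathers are prefetched while earlier layers compute."""
    compute_end = []     # finish time of each layer on the computation stream
    comm_free = 0.0      # time at which the communication stream becomes free
    for i, (c, g) in enumerate(zip(compute, comm)):
        start = comm_free
        if i >= 2:
            # The target buffer is released only after layer i-2 has finished computing.
            start = max(start, compute_end[i - 2])
        gather_end = start + g
        comm_free = gather_end
        prev_compute_end = compute_end[-1] if compute_end else 0.0
        # "Event" synchronization: compute waits for its weights and for the previous layer.
        compute_end.append(max(prev_compute_end, gather_end) + c)
    return compute_end[-1]

compute = [4.0, 4.0, 4.0, 4.0]   # hypothetical per-layer compute times
comm = [1.0, 1.0, 1.0, 1.0]      # hypothetical per-layer gather times

serial = step_time_serial(compute, comm)
overlapped = step_time_overlapped(compute, comm)
```

In this toy setting only the first gather is exposed; every later gather hides behind computation, so the step takes 17 time units instead of 20.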

What is YaFSDP especially good at? Performance Gains

YaFSDP demonstrates significant performance improvements. For a model with 70 billion parameters, it can save the resources of approximately 150 GPUs, translating to monthly cost savings of $0.5 to $1.5 million. Training time is reduced by up to 26% compared to existing methods like FSDP.
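
As a rough sanity check, the savings figures above can be turned into an implied price per GPU-hour. The assumptions here are ours, not the source's: a month of about 730 hours and GPUs running around the clock. The function name is hypothetical.

```python
HOURS_PER_MONTH = 730   # assumption: average month length in hours, full utilization

def implied_gpu_hour_price(monthly_savings_usd, gpus_saved, hours=HOURS_PER_MONTH):
    """Price per GPU-hour implied by a monthly saving spread over the GPUs saved."""
    return monthly_savings_usd / (gpus_saved * hours)

low = implied_gpu_hour_price(0.5e6, 150)
high = implied_gpu_hour_price(1.5e6, 150)
print(f"Implied price: ${low:.2f} to ${high:.2f} per GPU-hour")
```

Under these assumptions, the quoted range corresponds to roughly $4.6 to $13.7 per GPU-hour, which suggests the estimate spans different providers or pricing tiers.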

When compared to FSDP, the final speedup shown by YaFSDP on Llama 2 and Llama 3 demonstrates significant improvements in training efficiency.

YaFSDP can be used in conjunction with Hugging Face workflows and is up to 25% faster compared to FSDP.

Original resources
