You might have heard the news of AI companies acquiring ever-larger clusters of GPUs to train models: Meta announced that it has 600K GPUs, Sam Altman is trying to raise $7T to build an NVIDIA competitor, and NVIDIA's market value keeps overtaking the GDP of entire countries.
But how are these companies using all these GPUs? How are large models with billions of parameters trained in a distributed fashion? There are two families of parallelism techniques that most modern training setups use to train these models at scale: split your data across machines (data parallelism) or split your model across machines (model parallelism).
Today, we dive deep into how data parallelism works.
First, let’s all agree on how single-GPU training works.
Step 1: Input data. This is some vector of integers $x$.
Step 2: Forward pass. $x$ is passed through the model $f$, producing an output $\hat{y} = f(x)$.
Step 3: Loss calculation. Given a ground-truth label $y$, we compute the loss between the label and the prediction: $L(y, \hat{y})$.
Step 4: Backward pass. We calculate the derivative of the loss with respect to every model weight: $g(w) = \frac{\partial L}{\partial w}\ \forall w$.
Step 5: Optimizer step. Finally, we update each weight by subtracting its gradient scaled by a constant factor $c$ (the learning rate): $w_{\text{new}} = w - c \cdot g(w)$. These are the new (and hopefully better!) weights of our model.
Repeat until GPT. That’s it!
```mermaid
graph TD
A[Data]
B[Forward Pass]
C[Backward Pass]
D[Optimizer Step]
A --> B
B --> C
C --> D
```
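To make the five steps concrete, here is a minimal sketch of the single-GPU loop in plain Python. It assumes a hypothetical toy model (a linear function $\hat{y} = w x + b$ fit to $y = 2x + 1$), mean squared error as the loss, and hand-derived gradients instead of autograd; the names `forward`, `loss_fn`, `backward`, and `train` are illustrative, not from any library.

```python
# Toy single-device training loop: fit y = 2*x + 1 with y_hat = w*x + b.
# Model, data, loss, and learning rate are hypothetical, chosen for illustration.

def forward(w, b, x):
    # Step 2: forward pass, y_hat = f(x)
    return [w * xi + b for xi in x]

def loss_fn(y, y_hat):
    # Step 3: mean squared error L(y, y_hat)
    return sum((yh - yi) ** 2 for yi, yh in zip(y, y_hat)) / len(y)

def backward(x, y, y_hat):
    # Step 4: analytic dL/dw and dL/db for the mean squared error above
    n = len(y)
    gw = sum(2 * (yh - yi) * xi for xi, yi, yh in zip(x, y, y_hat)) / n
    gb = sum(2 * (yh - yi) for yi, yh in zip(y, y_hat)) / n
    return gw, gb

def train(steps=500, c=0.05):
    x = [0.0, 1.0, 2.0, 3.0]          # Step 1: input data
    y = [2 * xi + 1 for xi in x]      # ground-truth labels
    w, b = 0.0, 0.0
    for _ in range(steps):
        y_hat = forward(w, b, x)
        loss = loss_fn(y, y_hat)
        gw, gb = backward(x, y, y_hat)
        # Step 5: optimizer step, w_new = w - c * g(w)
        w, b = w - c * gw, b - c * gb
    return w, b, loss

w, b, loss = train()
print(f"w={w:.3f}, b={b:.3f}, final loss={loss:.2e}")
```

After enough iterations `w` and `b` approach 2 and 1. A real framework replaces `backward` with automatic differentiation, but the loop structure (data, forward, loss, backward, update, repeat) is the same.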
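As far as possible, we want to scale vertically, making our single-GPU setup chunkier by increasing the GPU’s memory, compute, and so on. But vertical scaling eventually hits a wall: a single GPU has only so much memory and compute, and the largest models no longer fit within those limits.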
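A back-of-the-envelope estimate shows how quickly training outgrows one GPU's memory. This sketch assumes a commonly used accounting for mixed-precision training with the Adam optimizer (about 16 bytes per parameter) and ignores activations entirely, so real usage is higher; the constant `BYTES_PER_PARAM` and the 80 GB comparison point are assumptions for illustration.

```python
# Rough memory needed to train a model with Adam in mixed precision.
# Assumed per-parameter cost (a common accounting, not universal):
#   2 bytes fp16 weights + 2 bytes fp16 gradients
#   + 12 bytes fp32 optimizer state (master weights, momentum, variance)
BYTES_PER_PARAM = 2 + 2 + 12

def training_memory_gb(num_params: float) -> float:
    """Weights + gradients + optimizer state only; activations are ignored."""
    return num_params * BYTES_PER_PARAM / 1e9

for billions in (1, 7, 70):
    gb = training_memory_gb(billions * 1e9)
    print(f"{billions:>3}B params -> ~{gb:,.0f} GB (vs. a single 80 GB GPU)")
```

Under these assumptions even a 7B-parameter model needs more than 100 GB before counting activations, which is already beyond a single 80 GB card.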
So when we can no longer scale vertically, we need to figure out how to scale horizontally. In every parallelism strategy we will investigate, the goal is to shard something across multiple machines: either the data sequences or the model weights. We need to do this in a way that preserves ML semantics; there is no point in spending many more resources only to end up with a much worse model.
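Here is a toy sketch, in plain Python, of the two things we can shard across $N$ workers. The helpers `shard_data`, `shard_weights`, and `matvec` are hypothetical names for this illustration. The strided data split is one simple scheme among many, and the row split of a weight matrix is just one simple flavor of model parallelism.

```python
# Toy illustration of sharding data vs. sharding weights across N workers.
# All helper names here are hypothetical, written only for this sketch.

def shard_data(samples, rank, world_size):
    # Data parallelism: worker `rank` takes every world_size-th sample.
    return samples[rank::world_size]

def shard_weights(weight_rows, rank, world_size):
    # Model parallelism (one simple flavor): worker `rank` holds a
    # contiguous block of the weight matrix's rows.
    assert len(weight_rows) % world_size == 0, "sketch assumes an even split"
    per_worker = len(weight_rows) // world_size
    return weight_rows[rank * per_worker:(rank + 1) * per_worker]

def matvec(rows, x):
    # Plain matrix-vector product: one dot product per row.
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in rows]

N = 2
data = list(range(10))
print([shard_data(data, r, N) for r in range(N)])  # [[0, 2, 4, 6, 8], [1, 3, 5, 7, 9]]

W = [[1, 2], [3, 4], [5, 6], [7, 8]]
x = [1, 1]
partials = [matvec(shard_weights(W, r, N), x) for r in range(N)]
# Concatenating each worker's partial output reproduces the full W @ x.
full = [y for part in partials for y in part]
assert full == matvec(W, x)
```

Either way, no single worker holds everything, yet the combined result matches what one machine would compute.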
<aside>
💡 The idea behind data parallelism is this: if we copy our neural network onto $N$ GPUs, assign each of them a portion of the data to work on, and then coalesce all $N$ models together somehow… then we would be processing data $N$ times faster than before!
</aside>
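One common way to do the "coalesce somehow" part is to average the gradients from all replicas before the optimizer step, which is roughly what PyTorch's `DistributedDataParallel` does with an all-reduce during the backward pass. The single-process sketch below simulates that under stated assumptions: $N$ equal-sized shards, a loss that is a mean over samples, and the same hypothetical linear model $\hat{y} = w x + b$ as a stand-in network. `grad` and `data_parallel_step` are illustrative names, not a library API.

```python
# Single-process simulation of one data-parallel training step.
# Assumptions: N equal-sized shards, a mean-over-samples loss (MSE), and
# gradient averaging as the way the N replicas are "coalesced".

def grad(w, b, xs, ys):
    # Gradient of mean squared error for y_hat = w * x + b on one shard.
    n = len(xs)
    gw = sum(2 * (w * x + b - y) * x for x, y in zip(xs, ys)) / n
    gb = sum(2 * (w * x + b - y) for x, y in zip(xs, ys)) / n
    return gw, gb

def data_parallel_step(w, b, xs, ys, num_workers, c=0.05):
    assert len(xs) % num_workers == 0, "sketch assumes equal-sized shards"
    # 1. Every worker holds an identical copy of (w, b) plus its own data shard.
    shards = [(xs[r::num_workers], ys[r::num_workers]) for r in range(num_workers)]
    # 2. Each worker runs forward + backward on its shard only.
    local = [grad(w, b, sx, sy) for sx, sy in shards]
    # 3. Simulated all-reduce: average the gradients across workers.
    gw = sum(g[0] for g in local) / num_workers
    gb = sum(g[1] for g in local) / num_workers
    # 4. Every worker applies the same update, so the replicas stay in sync.
    return w - c * gw, b - c * gb

# Compare against one single-device step on the full batch.
xs = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
ys = [2 * x + 1 for x in xs]
dp = data_parallel_step(0.0, 0.0, xs, ys, num_workers=4)
gw, gb = grad(0.0, 0.0, xs, ys)
single = (0.0 - 0.05 * gw, 0.0 - 0.05 * gb)
assert all(abs(a - s) < 1e-9 for a, s in zip(dp, single))
```

Because the mean of equal-sized shard means equals the full-batch mean, the data-parallel update matches the single-GPU update (up to floating-point rounding), which is exactly the "preserve ML semantics" property we want. With unequal shard sizes, a plain average would no longer match and each shard's gradient would need weighting by its size.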
Here’s how one loop of this process happens: