Previously, in the AI Supercomputing and AI Supercomputing (part 2) posts, we summarized existing Machine Learning (ML) parallelism techniques. Later, in Building a GPT model from scratch, we built GPT-lite, the small variant of the GPT-2 model. In this post, we will perform large-scale parallel training of a GPT model and a large DNN on a network of 8 GPUs, using DeepSpeed and ZeRO (Zero Redundancy Optimizer). The DeepSpeed API is a lightweight wrapper on PyTorch, and can be installed with the deepspeed Python package (e.g. pip install deepspeed).

3D Parallelism

An ML model allows for three types of parallelism, which can be combined into what we call 3D parallelism: data, pipeline and tensor (model) parallelism.

3D parallelism aims at partitioning (color-coded) compute resources across the 3D space of data, pipeline and tensor (model) dimensions. In this post we will cover data parallelism. Source: Microsoft Research Blog

Data parallelism divides the samples of a batch across processors. It is the focus of this post, and refers mainly to either of the following approaches:

  • Distributed Data Parallel keeps a full copy of the model (weights, optimizer parameters and gradients) in all processors. All models are initialized equally. Each processor takes a different mini-batch as input to its model and performs a forward pass to compute the loss. On the backward pass, at every layer of the model, each processor computes the gradients for its own mini-batch, and these are mean-reduced across all processors. As a result, all processors apply the same weight updates, keeping the models in sync throughout execution.
    • communication and computation overlap: in the PyTorch DDP implementation, the backward pass iteratively computes gradients (from the last to the first layer) and groups them into buckets to be communicated. These buckets are mean-reduced asynchronously while the computation of the backward pass proceeds, thereby overlapping backward computation with gradient communication. At the end of the backward pass, all GPUs wait for all gradient all-reduces to finish, and then trigger the parameter updates.
  • ZeRO-DP (Zero-Redundancy Optimizer, Data Parallel) implements sharding, sometimes referred to as Fully-Sharded Data Parallelism (FSDP). Here, processors don't hold a full copy of the model, but only distinct shards of the parameters, optimizer states and gradients. Different processors take different mini-batches as input, and activations are not sharded, i.e. each processor keeps the full activations corresponding to its own input. ZeRO has three alternative execution modes, called stages. Each stage removes a different level of memory redundancy by partitioning (and thus communicating) different variables, and the stages are enabled cumulatively. In practice, by increasing the stage we move along the trade-off between memory usage and communication:
    • ZeRO stage 0 means no sharding of model states, and is equivalent to Distributed Data Parallelism;
    • ZeRO stage 1 (ZeRO-1): the optimizer states (e.g., for the Adam optimizer, the 32-bit weights and the first and second moment estimates) are partitioned across the processes, so that each process updates only its partition. Affects the backward pass runtime.
    • ZeRO stage 2 (ZeRO-2): the reduced 32-bit gradients for updating the model weights are also partitioned, such that each process retains only the gradients corresponding to its portion of the optimizer states. Also relevant only for the backward pass.
    • ZeRO stage 3 (ZeRO-3): the 16-bit model parameters are partitioned across the processes. Includes extra communication on both the forward and backward passes.
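To make the difference between the two approaches concrete, the following pure-Python sketch simulates one update step over N ranks. It uses no GPUs and no DeepSpeed; the helper names are ours, and plain SGD stands in for Adam so that the optimizer holds no extra state. In the DDP step every rank averages the full gradient and updates a full copy of the weights; in the ZeRO-1/2-style step each rank only receives the averaged gradient of its own shard (a reduce-scatter), updates that shard, and the shards are then all-gathered. Both end with the same weights: sharding changes where the state lives, not the result.

```python
from typing import List, Tuple

def mean_allreduce(per_rank: List[List[float]]) -> List[float]:
    """Element-wise mean across ranks (what an averaging all-reduce returns)."""
    n = len(per_rank)
    return [sum(vals) / n for vals in zip(*per_rank)]

def shard_bounds(size: int, world: int, rank: int) -> Tuple[int, int]:
    """Contiguous [start, end) slice of a flat parameter vector owned by `rank`."""
    per = (size + world - 1) // world
    return min(rank * per, size), min((rank + 1) * per, size)

def ddp_step(weights: List[float], grads_per_rank: List[List[float]], lr: float) -> List[float]:
    """Every rank holds all weights and the full averaged gradient."""
    g = mean_allreduce(grads_per_rank)
    replicas = [[w - lr * gi for w, gi in zip(weights, g)] for _ in grads_per_rank]
    return replicas[0]  # all replicas are identical

def zero_step(weights: List[float], grads_per_rank: List[List[float]], lr: float) -> List[float]:
    """Each rank reduces and updates only its shard; shards are then all-gathered."""
    world, size = len(grads_per_rank), len(weights)
    updated_shards = []
    for rank in range(world):
        start, end = shard_bounds(size, world, rank)
        # reduce-scatter: this rank only receives the mean of its own slice
        g_shard = mean_allreduce([g[start:end] for g in grads_per_rank])
        # optimizer update restricted to the owned slice
        updated_shards.append([w - lr * gi for w, gi in zip(weights[start:end], g_shard)])
    # all-gather: concatenate shards in rank order to rebuild the full weights
    return [w for shard in updated_shards for w in shard]

weights = [0.5, -1.0, 2.0, 0.25, 3.0]
grads = [[0.1, 0.2, -0.3, 0.4, 0.0],   # gradients computed by rank 0 on its mini-batch
         [0.3, -0.2, 0.1, 0.0, 0.6]]   # gradients computed by rank 1 on its mini-batch
print(ddp_step(weights, grads, lr=0.1))
print(zero_step(weights, grads, lr=0.1))
```

In a real run, the reduce-scatter and all-gather are collective operations (e.g. torch.distributed.reduce_scatter and torch.distributed.all_gather); here they are emulated with list slicing in a single process.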

Memory consumption of the three different stages of ZeRO FSDP. Residual memory (activations, normalization layers, etc) is not included as FSDP does not shard them. Source: Microsoft Research blog
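The per-GPU memory of model states shown in this kind of figure follows the accounting of the ZeRO paper: with Ψ parameters trained with mixed-precision Adam, each parameter costs 2 bytes of fp16 weights, 2 bytes of fp16 gradients and K=12 bytes of optimizer state (fp32 weights, momentum and variance), and each stage divides one more of these terms by the number of data-parallel GPUs N_d. A small calculator (the function name is ours; activations and buffers are excluded), evaluated on the example sizes used in the ZeRO paper (7.5B parameters, 64 GPUs):

```python
def zero_model_state_bytes(num_params: float, num_gpus: int, stage: int, k: int = 12) -> float:
    """Per-GPU bytes for weights + gradients + optimizer states (activations excluded).

    Mixed-precision Adam accounting from the ZeRO paper:
    2 bytes fp16 weights, 2 bytes fp16 gradients, k bytes optimizer state per parameter.
    """
    weights, grads, optim = 2 * num_params, 2 * num_params, k * num_params
    if stage >= 1:
        optim /= num_gpus    # ZeRO-1: partition optimizer states
    if stage >= 2:
        grads /= num_gpus    # ZeRO-2: also partition gradients
    if stage >= 3:
        weights /= num_gpus  # ZeRO-3: also partition fp16 parameters
    return weights + grads + optim

for stage in range(4):
    gb = zero_model_state_bytes(7.5e9, 64, stage) / 1e9
    print(f"ZeRO stage {stage}: {gb:.1f} GB per GPU")
# prints 120.0, 31.4, 16.6 and 1.9 GB for stages 0 to 3
```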

Additionally, on top of stages 1 and 2, we can enable ZeRO-Offload, a system for offloading optimizer and gradient states to CPU memory. On top of stage 3, we can enable ZeRO-Infinity, also an offloading engine, which extends ZeRO-Offload with support for NVMe storage. According to the ZeRO-3 documentation, “ZeRO-Infinity has all of the savings of ZeRO-Offload, plus is able to offload more the model weights and has more effective bandwidth utilization and overlapping of computation and communication”.

Pipeline parallelism delegates different layers (or blocks of layers) of the model to different processors, as a pipeline. Tensor parallelism (also called vertical parallelism, intra-layer parallelism, or sometimes simply model parallelism) also partitions the computation of activations in the forward and backward passes. These are the remaining two dimensions of parallelism and will be covered in part 2 of this post.

We will see in this post that finding the optimal parallelism hyperparameters is a hard problem. This is a resource allocation problem on the 3D ML parallelism space, aiming at an optimal balance of compute time, memory usage or throughput across resources. In practice, balanced computation yields a low overall runtime, and balanced memory allows for an increase of the maximum model size.

Model and dataset

The code that follows is applicable to any model of type torch.nn.Module and any dataset of type torch.utils.data.Dataset. We will detail three use cases: an advanced one, specific to a large language model (GPTlite); an out-of-the-box pre-defined model from torchvision; and a simple DNN of arbitrary width and depth, used to simulate different ML workload conditions (we will call it our benchmark model).

GPTlite

We start by taking our previous GPT-lite implementation and matching its architecture to the GPT-3 Small model description in Language Models are Few-Shot Learners (Table 2.1):

## gptlite.py

# depth of the network as number of decoder blocks.
n_layer = 12

# size of the embeddings (d_model)
n_embd = 768

# number of attention heads in the Multi-Head Attention mechanism
n_head = 12

# block size, i.e. the maximum context length of a training sequence (n_ctx in the paper)
block_size = 2048

# dropout rate (variable p) for dropout units
dropout = 0.1
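As a sanity check on these hyperparameters, a back-of-the-envelope parameter count can use the common approximation of about 12·n_embd² weights per decoder block (4·n_embd² for the query, key, value and output projections, plus 8·n_embd² for a feed-forward layer 4x wider than the embedding), plus the token and position embedding tables. This sketch ignores biases, layer norms and the output head, and the vocabulary size of 65 is an assumption (the usual character-level Tiny Shakespeare vocabulary), so the exact GPTlite count will differ slightly:

```python
def approx_gpt_params(n_layer: int, n_embd: int, block_size: int, vocab_size: int) -> int:
    """Rough GPT parameter count: ~12*d^2 per block plus embeddings (biases, LayerNorm, head ignored)."""
    per_block = 4 * n_embd**2 + 8 * n_embd**2                # attention projections + 4x-wide FFN
    embeddings = vocab_size * n_embd + block_size * n_embd   # token + positional tables
    return n_layer * per_block + embeddings

n_params = approx_gpt_params(n_layer=12, n_embd=768, block_size=2048, vocab_size=65)
print(f"{n_params / 1e6:.1f}M parameters, {n_params * 4 / 1024**3:.3f} GB in fp32")
```

With 65 characters this gives about 86.6M parameters, or roughly 0.32GB in fp32, close to the 0.323GB reported by DeepSpeed's estimator further below. With the 50257-token GPT-2/GPT-3 BPE vocabulary the same estimate gives about 125.1M parameters, in line with the 125M listed for GPT-3 Small.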

We then define the methods get_model() and get_dataset() that return our model and the tiny shakespeare dataset:

## gptlite.py

import torch

def get_dataset():

  class GPTliteDataset(torch.utils.data.Dataset):

    def __init__(self, train_data, block_size):
      self.train_data = train_data
      self.block_size = block_size

    def __len__(self):
      return len(self.train_data)

    def __getitem__(self, idx):
      # generate 1 random offset on the data (idx is ignored)
      ix = torch.randint(len(self.train_data) - self.block_size, size=())
      # input is a random subset of tokens
      x = self.train_data[ix   : ix+self.block_size]
      # target is x shifted by one position (ie the next token to predict)
      y = self.train_data[ix+1 : ix+1+self.block_size]
      return x, y

  # load_tiny_shakespeare_data() is assumed to be defined elsewhere (e.g. in the previous GPT-lite post)
  train_data, valid_data, vocab_size = load_tiny_shakespeare_data()
  train_dataset = GPTliteDataset(train_data, block_size)
  valid_dataset = GPTliteDataset(valid_data, block_size)
  return train_dataset, valid_dataset, vocab_size


def get_model(vocab_size):
  return GPTlite(vocab_size)
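The sampling logic of __getitem__ can be checked without PyTorch. The helper below is a hypothetical pure-Python equivalent: it draws a random offset and returns a window of block_size tokens as input and the same window shifted by one token as target, so that y[t] is the token that follows x[t]. Like torch.randint, random.Random.randrange excludes its upper bound, which keeps the target window inside the data.

```python
import random
from typing import List, Tuple

def sample_window(data: List[int], block_size: int, rng: random.Random) -> Tuple[List[int], List[int]]:
    """Pick a random offset and return (input, target) windows of length block_size."""
    # offsets are drawn from [0, len(data) - block_size - 1] so the target window fits
    ix = rng.randrange(len(data) - block_size)
    x = data[ix : ix + block_size]
    y = data[ix + 1 : ix + 1 + block_size]
    return x, y

tokens = list(range(100))  # stand-in for the encoded Tiny Shakespeare characters
x, y = sample_window(tokens, block_size=8, rng=random.Random(0))
print(x)
print(y)
```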

Using a torchvision model

If you wanted to perform multi-class classification using the ResNet network on the CIFAR10 dataset available in torchvision, you'd define the previous 2 methods as:

import torchvision

def get_dataset():
  import torchvision.transforms as transforms
  transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
  ])
  dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform)
  return dataset

def get_model(num_classes):
  return torchvision.models.resnet18(num_classes=num_classes)

As a relevant remark, pre-existing models do not define the activation checkpointing and pipelining layers that are required to activate these two features (discussed later).

Benchmark model

If we instead wanted to test how DeepSpeed scales on a very simple model of varying width and depth, we could create a benchmark model, which is simply a DNN of L layers of width W for multi-class classification, whose objective is to compute the modulo of the sum of squares of a random input vector:

The benchmark model, a DNN with L layers of dimensionality W (right)

The implementation of the benchmark model in benchmark.py is straightforward:

## benchmark.py

import torch
import torch.nn as nn

class BenchmarkModel(nn.Module):
  """ DNN with L layers and W neurons per layer """

  def __init__(self, W, L, in_size, out_size):
    super(BenchmarkModel, self).__init__()
    self.layers = [nn.Linear(in_size, W), nn.ReLU()]
    for _ in range(L-2):
      self.layers += [nn.Linear(W, W), nn.ReLU()]
    self.layers += [nn.Linear(W, out_size), nn.ReLU()]
    self.layers = nn.Sequential(*self.layers)

  def forward(self, x):
    return self.layers(x)


class BenchmarkDataset(torch.utils.data.Dataset):
  def __init__(self, in_size, out_size, len=2**16):
    self.in_size = in_size
    self.len = len
    self.out_size = out_size

  def __len__(self):
    return self.len

  def __getitem__(self, _):
    # random input vector; the label is its sum of squares modulo the number of classes
    x = torch.empty(self.in_size).uniform_(-10, 10)
    y = int(x @ x % self.out_size)
    return x, torch.tensor(y, dtype=torch.long)


def get_dataset(in_size, out_size):
  # returns (train, validation) datasets
  return BenchmarkDataset(in_size, out_size), BenchmarkDataset(in_size, out_size)

def get_model(W, L, in_size, out_size):
  return BenchmarkModel(W, L, in_size, out_size)

We will call this the Benchmark Model, and we will use it later in our benchmark section to test DeepSpeed's response to models of varying width and depth.
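To relate the wide and deep configurations used in the results section, the hypothetical helper below counts the weights and biases of the nn.Linear layers in BenchmarkModel (one input layer, L-2 hidden layers and one output layer; ReLUs have no parameters):

```python
def benchmark_params(W: int, L: int, in_size: int, out_size: int) -> int:
    """Weights + biases of BenchmarkModel: in->W, (L-2) x W->W, W->out."""
    first = in_size * W + W
    hidden = (L - 2) * (W * W + W)
    last = W * out_size + out_size
    return first + hidden + last

wide = benchmark_params(W=8192, L=3, in_size=8192, out_size=8192)
deep = benchmark_params(W=256, L=2048, in_size=256, out_size=256)
print(f"wide: {wide / 1e6:.1f}M params, deep: {deep / 1e6:.1f}M params")
```

So the wide model (about 201.4M parameters in 3 layers) and the deep model (about 134.7M parameters in 2048 layers) have parameter counts of the same order, while stressing very different layer shapes.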

Main code

We start integrating DeepSpeed in our code by creating the ArgumentParser object that is required by the initialize() method in DeepSpeed. The ArgumentParser object must contain:

  • the --local_rank parameter that is the local rank of each process in the network, and will be populated automatically by the deepspeed launcher when launching a script;
  • optionally, we add the --deepspeed_config where we specify the path to the DeepSpeed config file. If you choose not to add it to the command line arguments, then it must be specified as the parameter config in the call to deepspeed.initialize().

The recommended way to do this is to call deepspeed.add_config_arguments(), which adds --deepspeed_config and other DeepSpeed-specific arguments:

## train.py

import deepspeed

def get_cmd_line_args(description='GPT-lite on DeepSpeed'):
  import argparse
  parser = argparse.ArgumentParser(description=description)
  # mandatory argument for calls with deepspeed
  parser.add_argument('--local_rank', type=int, default=0,
                        help='local rank passed from distributed launcher')
  # Include DeepSpeed configuration arguments (--deepspeed, --deepspeed_config, ...)
  parser = deepspeed.add_config_arguments(parser)
  return parser.parse_args()

Note: --local_rank exists for legacy support, and newer versions of DeepSpeed compare it with os.environ["LOCAL_RANK"] and use the latter instead. So, for simplicity, you can pass --no_local_rank to skip --local_rank altogether. However, if you launch the run with the deepspeed launcher, --local_rank is added automatically, and it can be removed from sys.argv before calling deepspeed.initialize() (below).
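A minimal sketch of removing the launcher-provided flag before parsing: the helper name is hypothetical (it is not a DeepSpeed API), and it drops both the --local_rank=N and --local_rank N forms from an argument list, while the rank itself is read from the LOCAL_RANK environment variable that torchrun and the deepspeed launcher set.

```python
import os
from typing import List

def strip_local_rank(argv: List[str]) -> List[str]:
    """Return a copy of argv without '--local_rank=N' or '--local_rank N' entries."""
    out, skip_next = [], False
    for arg in argv:
        if skip_next:              # value that followed a bare '--local_rank'
            skip_next = False
            continue
        if arg == "--local_rank":
            skip_next = True
            continue
        if arg.startswith("--local_rank="):
            continue
        out.append(arg)
    return out

# e.g. sys.argv = strip_local_rank(sys.argv) before building the ArgumentParser
argv = ["train.py", "--local_rank=3", "--deepspeed", "--deepspeed_config", "ds_config.json"]
print(strip_local_rank(argv))
local_rank = int(os.environ.get("LOCAL_RANK", 0))  # set by torchrun and the deepspeed launcher
```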

The bulk of the code is pretty simple. In practice, all the boilerplate code that PyTorch requires for optimizers, learning rates, parallelism, data loaders etc. is managed by DeepSpeed and defined in its config file. So the initialization of a DeepSpeed run is pretty straightforward:

## train.py

import torch
import gptlite

def main_deepspeed(n_epochs=100, random_seed=42):

  torch.manual_seed(random_seed)  # set random seed (used by DataLoader)
  deepspeed.runtime.utils.set_random_seed(random_seed)  # set DeepSpeed seed
  deepspeed.init_distributed()  # initialize distributed DeepSpeed
  args = get_cmd_line_args()  # initialize command line arguments parser
  criterion = torch.nn.CrossEntropyLoss()  # initialize loss function
  train_dataset, _, vocab_size = gptlite.get_dataset()  # initialize dataset
  model = gptlite.get_model(vocab_size)  # initialize model

  # initialize DeepSpeed: optimizer, scheduler and data loader are built from the config file
  engine, optimizer, train_dataloader, _ = deepspeed.initialize(
    args=args, model=model, model_parameters=model.parameters(),
    training_data=train_dataset)

We then write the training loop, with a structure very similar to a PyTorch implementation. The only exception is that we don't zero the gradients, as this is managed internally by DeepSpeed. Also, train_dataloader is created automatically by initialize() and uses a torch.utils.data.distributed.DistributedSampler, so that in multi-process runs each process is automatically delegated a different subset of the data.

## train.py

def main_deepspeed(n_epochs=100, random_seed=42):
  # ...
  for epoch in range(n_epochs):
    for step, data in enumerate(train_dataloader):
      inputs = data[0].to(engine.device)
      labels = data[1].to(engine.device)

      outputs = engine(inputs)  # fwd pass
      # flatten (B, T, C) GPT logits to (B*T, C), as expected by CrossEntropyLoss;
      # this is a no-op for (B, C) classifier outputs
      loss = criterion(outputs.reshape(-1, outputs.size(-1)), labels.reshape(-1))
      engine.backward(loss)  # backprop
      engine.step()  # update weights, no need for zero-ing

    # print loss for epoch
    if engine.local_rank == 0: print(f"Epoch: {epoch}, Loss: {loss.item()}")
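To illustrate what "a different subset of the data" means, here is a pure-Python sketch of the index assignment performed by a distributed sampler. The function is ours, modeled on how torch.utils.data.distributed.DistributedSampler pads and strides indices as we understand it (not a copy of its code): all ranks use the same permutation (seeded by seed + epoch), the index list is padded by repeating indices until it divides evenly by the world size, and rank r takes every world-th index starting at r.

```python
import math
import random
from typing import List

def shard_indices(n: int, world: int, rank: int, shuffle: bool = False,
                  seed: int = 0, epoch: int = 0) -> List[int]:
    """Indices of an n-sample dataset assigned to `rank` out of `world` processes."""
    indices = list(range(n))
    if shuffle:
        # identical permutation on every rank, changing with the epoch
        random.Random(seed + epoch).shuffle(indices)
    total = math.ceil(n / world) * world                # round up to a multiple of world
    indices = (indices * math.ceil(total / n))[:total]  # pad by repeating from the start
    return indices[rank:total:world]                    # strided split, disjoint except for padding

for rank in range(4):
    print(rank, shard_indices(10, world=4, rank=rank))
```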

Config file

The real nuance and complexity in using DeepSpeed is the .json config file. The number of possible optimizations is large, as it defines parallelism, floating point precision, logging, communication parameters, etc. These fields are detailed in the DeepSpeed config documentation. Here we start with a simple config, where we configure the DeepSpeed logger to output memory and throughput info every 10 steps (steps_per_print), and define the settings of the optimizer (optimizer) and learning rate scheduler (scheduler):

{
  "train_batch_size": 256,
  "steps_per_print": 10,
  "optimizer": {
    "type": "AdamW",
    "params": {
      "lr": 0.001,
      "betas": [
        0.8,
        0.999
      ],
      "eps": 1e-8,
      "weight_decay": 3e-7
    }
  },
  "scheduler": {
    "type": "WarmupLR",
    "params": {
      "warmup_min_lr": 0,
      "warmup_max_lr": 0.001,
      "warmup_num_steps": 1000
    }
  }
}

Gradient accumulation based on micro-batching is a technique that simulates a large mini-batch as an iteration across several micro-batches. This is particularly relevant when the whole mini-batch does not fit into memory: accumulating gradients over micro-batches overcomes that limitation. This method is enabled by setting train_micro_batch_size_per_gpu (defaulted to train_batch_size) or gradient_accumulation_steps (defaulted to 1). At runtime, the micro-batch size can be retrieved with engine.train_micro_batch_size_per_gpu(), and the number of accumulation steps with engine.gradient_accumulation_steps(). In our case, we will start with a micro-batch of 1 single input per GPU, which accumulates up to a batch size of 256 across all 8 GPUs, therefore resulting in 256 / (1 × 8) = 32 gradient accumulation steps:

{
  "train_batch_size": 256,
  "train_micro_batch_size_per_gpu": 1
}
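Gradient accumulation is equivalent to a larger batch because the gradient of a mean loss over the batch equals the mean of the gradients over equal-sized micro-batches, as long as each micro-step's contribution is scaled by 1/gradient_accumulation_steps. A pure-Python sketch on a toy one-parameter linear model with squared-error loss (a stand-in, not DeepSpeed code), using the numbers of our config:

```python
import random

def grad_mse(w: float, xs, ys) -> float:
    """d/dw of mean((w*x - y)^2) over a batch of scalar samples."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

rng = random.Random(0)
train_batch_size, micro_batch_per_gpu, num_gpus = 256, 1, 8
xs = [rng.uniform(-1, 1) for _ in range(train_batch_size)]
ys = [3.0 * x + rng.gauss(0, 0.1) for x in xs]
w = 0.5

per_step = micro_batch_per_gpu * num_gpus   # samples seen by all GPUs in one micro-step
accum_steps = train_batch_size // per_step  # 256 // 8 = 32

full_grad = grad_mse(w, xs, ys)             # gradient of a single 256-sample batch
accumulated = 0.0
for step in range(accum_steps):
    lo, hi = step * per_step, (step + 1) * per_step
    # all-reduced mean over the 8 micro-batches of this step, scaled by 1/accum_steps
    accumulated += grad_mse(w, xs[lo:hi], ys[lo:hi]) / accum_steps

print(accum_steps, full_grad, accumulated)
```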

ZeRO Fully-Sharded Data Parallel can be activated by specifying the relevant stage in the config file. If omitted, or when passing stage 0, ZeRO is disabled and the execution follows a regular distributed data parallel workflow:

{
  "zero_optimization": {
    "stage": 3
  }
}

Limiting the size of communication buffers is important when activating ZeRO. Enabling ZeRO distributes parameters across all processors, which adds communication overhead and requires memory to be allocated for the buffers holding the data to be sent or received. These buffers may be large. To overcome this issue, we can decrease the maximum size of the communication buffers, so that communication is performed in several smaller buckets. We can also enable communication overlap, which attempts to overlap the reduction of the gradients with the backward computation. To enable these 2 optimizations, we add to the config:

{
  "zero_optimization": {
    "reduce_bucket_size": 4e5,
    "allgather_bucket_size": 4e5,
    "stage3_prefetch_bucket_size": 4e5,
    "overlap_comm": true
  }
}
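As far as we understand the DeepSpeed config documentation, reduce_bucket_size and allgather_bucket_size count elements, not bytes. So the buffer size per bucket is roughly the bucket size times the bytes per element, and a smaller bucket means more (but smaller) collective calls. The arithmetic below rests on illustrative assumptions: an 86.7M-parameter model (roughly GPT-lite) and 2-byte fp16 elements, comparing the documented default of 5e8 elements with the 4e5 used above. It is a rough model that ignores how DeepSpeed actually allocates, fills and reuses its buffers.

```python
import math

def bucket_stats(num_params: int, bucket_elems: float, bytes_per_elem: int = 2):
    """Rough (MB per bucket buffer, collective calls per pass) for a bucket size given in elements."""
    buffer_mb = bucket_elems * bytes_per_elem / 1024**2
    num_calls = math.ceil(num_params / bucket_elems)
    return buffer_mb, num_calls

for bucket in (5e8, 4e5):
    mb, calls = bucket_stats(86_700_000, bucket)
    print(f"bucket={bucket:.0e} elements: up to {mb:7.1f} MB per buffer, {calls} collective calls")
```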

ZeRO-Infinity offloads several variables from GPU memory to CPU and NVMe for huge memory savings. It is only compatible with ZeRO-3 and can be enabled with:

{
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "offload_param": {
      "device": "cpu",
      "pin_memory": true
    }
  }
}

Mixed precision allows computations with values (parameters, activations, accumulators) stored in different numerical representations, leading to a reduction of memory and compute time. It can be enabled by adding the fp16 entry in the config. As a side note, the amp config entry also enables mixed precision training, following the NVIDIA Apex implementation, i.e. with the O0 to O3 optimization levels. However, it is not compatible with ZeRO, therefore we won't use it. The fp16 entry is equivalent to Apex optimization level O2, and according to the documentation, “if you want to use ZeRO (currently) you must use this mode”. We can enable it with the entry "fp16": { "enabled": true }, which is equivalent to the following default values:

{
  "fp16": {
    "enabled": true,
    "auto_cast": false,
    "loss_scale": 0,
    "initial_scale_power": 16,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "consecutive_hysteresis": false,
    "min_loss_scale": 1
  }
}

However, if your hardware supports bfloat16 (brain floating point), it should be used in lieu of float16, as it has a longer exponent representation: 8 bits instead of the 5 in float16, i.e. the same 8 bits as in float32. This gives it the same dynamic range as float32, which makes it more numerically stable and removes the need for loss scaling. bfloat16 can be activated by adding "bf16": { "enabled": true } to the config.
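The range difference follows from the bit layouts alone: the largest finite value of an IEEE-style binary format with e exponent bits and m mantissa bits is (2 - 2^-m) · 2^(2^(e-1) - 1). The pure-Python sketch below evaluates it for the three formats, and shows that a bfloat16 value is the upper 16 bits of a float32 (here obtained by truncation, whereas real conversions typically round to nearest even). Python's struct 'e' format is IEEE half precision, so packing 1e5 with it overflows.

```python
import struct

def max_finite(exp_bits: int, mantissa_bits: int) -> float:
    """Largest finite value of a binary float with the given exponent/mantissa widths."""
    max_exp = 2 ** (exp_bits - 1) - 1  # equals the bias: the all-ones exponent is reserved for inf/NaN
    return (2 - 2.0 ** -mantissa_bits) * 2.0 ** max_exp

formats = {"float16": (5, 10), "bfloat16": (8, 7), "float32": (8, 23)}
for name, (e, m) in formats.items():
    print(f"{name:9s} exponent={e} mantissa={m} max={max_finite(e, m):.4g}")

def to_bfloat16_truncated(x: float) -> float:
    """Keep the top 16 bits of the float32 encoding (sign, 8 exponent bits, 7 mantissa bits)."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    return struct.unpack(">f", struct.pack(">I", bits & 0xFFFF0000))[0]

print(to_bfloat16_truncated(1e5))  # representable in bfloat16, with reduced precision
try:
    struct.pack(">e", 1e5)         # 1e5 > 65504, the float16 maximum
except OverflowError as err:
    print("float16 overflow:", err)
```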

As a final note, the configuration file can also be extended with custom fields, e.g. specific to an application or hardware, but for brevity we'll omit those details here.
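JSON forbids comments and trailing commas. One way to avoid hand-editing mistakes is to assemble the config in Python and serialize it with the standard json module. The sketch below merges the fragments discussed in this section (the particular combination of values is illustrative). The resulting dict could also be passed directly as the config argument of deepspeed.initialize(), which accepts either a path or a dictionary.

```python
import json

ds_config = {
    "train_batch_size": 256,
    "train_micro_batch_size_per_gpu": 1,
    "steps_per_print": 10,
    "optimizer": {
        "type": "AdamW",
        "params": {"lr": 0.001, "betas": [0.8, 0.999], "eps": 1e-8, "weight_decay": 3e-7},
    },
    "scheduler": {
        "type": "WarmupLR",
        "params": {"warmup_min_lr": 0, "warmup_max_lr": 0.001, "warmup_num_steps": 1000},
    },
    "zero_optimization": {
        "stage": 3,
        "reduce_bucket_size": 4e5,
        "allgather_bucket_size": 4e5,
        "stage3_prefetch_bucket_size": 4e5,
        "overlap_comm": True,
    },
    "fp16": {"enabled": True},
}

text = json.dumps(ds_config, indent=2)  # Python True -> JSON true, never a trailing comma
print(text)
# to write the file used by the launcher: open("ds_config.json", "w").write(text)
```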

Activation Checkpointing

Activation checkpointing allows for a large reduction in memory requirements by not storing all the forward pass activations that are required for backpropagation. The rationale is simple: instead of storing the output of every layer after the forward pass, only a small subset of (checkpoint) layer outputs is kept in memory, and the remaining ones are recomputed on the fly during the backward pass, with a forward pass from the closest preceding checkpoint. Activation checkpointing is extremely relevant for DeepSpeed, as activations are not sharded, so not storing all layer activations in memory substantially reduces the memory footprint.
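The recompute-from-checkpoint rationale can be illustrated in pure Python on a toy chain of scalar layers tanh(w·x). This is a sketch of the generic segment-based technique under our own assumptions (scalar layers, hand-written derivatives), not DeepSpeed's implementation: grad_full keeps every activation, while grad_checkpointed keeps only the input of every k-th layer and recomputes each segment during the backward pass. Both return the same gradient, but the checkpointed version holds fewer activations at any time.

```python
import math
from typing import List, Tuple

def layer(w: float, x: float) -> float:
    """A toy scalar layer: tanh(w * x)."""
    return math.tanh(w * x)

def grad_full(ws: List[float], x0: float) -> Tuple[float, int]:
    """d(output)/d(x0) storing every activation; returns (gradient, activations held)."""
    acts = [x0]
    for w in ws:
        acts.append(layer(w, acts[-1]))
    grad = 1.0
    for i in reversed(range(len(ws))):
        grad *= ws[i] * (1 - acts[i + 1] ** 2)  # d tanh(w*a)/da = w * (1 - tanh(w*a)^2)
    return grad, len(acts)

def grad_checkpointed(ws: List[float], x0: float, k: int) -> Tuple[float, int]:
    """Same gradient, keeping only the input of every k-th layer during the forward pass."""
    n, x, checkpoints = len(ws), x0, {}
    for i, w in enumerate(ws):
        if i % k == 0:
            checkpoints[i] = x  # checkpoint: input of the segment starting at layer i
        x = layer(w, x)         # all other activations are discarded
    grad, peak = 1.0, 0
    for start in sorted(checkpoints, reverse=True):
        end = min(start + k, n)
        seg = [checkpoints[start]]  # recompute this segment's activations on the fly
        for i in range(start, end):
            seg.append(layer(ws[i], seg[-1]))
        peak = max(peak, len(checkpoints) + (end - start))
        for i in reversed(range(start, end)):
            grad *= ws[i] * (1 - seg[i - start + 1] ** 2)
        # seg is replaced by the next segment, so only one segment is alive at a time
    return grad, peak

ws = [0.9, 1.1, -0.7, 1.3, 0.8, -1.2, 1.0, 0.6, 1.4, -0.9, 0.95, 1.05, 0.7, -1.1, 1.2, 0.85]
g_full, held_full = grad_full(ws, 0.3)
g_ckpt, held_ckpt = grad_checkpointed(ws, 0.3, k=4)
print(g_full, g_ckpt, held_full, held_ckpt)
```

With 16 layers and k=4, the checkpointed version holds at most 8 values instead of 17. More generally, choosing k close to the square root of the number of layers balances the number of stored checkpoints against the size of the recomputed segment.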

In our use case, and for simplicity, we will store layer activations at a user-specified interval. For that, we create the command line argument --activation_checkpoint_interval that specifies how often to store layer checkpoints:

## train.py

def get_cmd_line_args(description='GPT lite on DeepSpeed'):
  # ...
  parser.add_argument('--activation_checkpoint_interval', type=int, default=0,
                      help='activation checkpoint interval (0 means disabled)')
  # ...

We have to manually specify which layers to checkpoint, by calling deepspeed.checkpointing.checkpoint at the checkpoint layers. We will use the previous to_layers() method to iterate over the layers of the model, and assign the relevant checkpointing in the forward() pass of GPTlite as:

## gptlite.py

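import deepspeed
import torch.nn as nn

class GPTlite(nn.Module):
  #...

  def forward(self, idx, targets=None):

    if self.activation_checkpoint_interval > 0:
      x=idx
      for l, layer in enumerate(self.to_layers()):
        is_checkpoint = l % self.activation_checkpoint_interval == 0
        # checkpointed layers do not keep their inner activations: they are recomputed in the backward pass
        x = deepspeed.checkpointing.checkpoint(layer, x) if is_checkpoint else layer(x)
      return x
    # ... otherwise, the regular (non-checkpointed) forward pass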

where self.activation_checkpoint_interval is a value set during initialization of the class. Finally, when doing model parallelism, we can reduce memory substantially by partitioning activations and offloading those checkpoints to the CPU instead of keeping them in GPU memory. DeepSpeed does not support model/tensor parallelism natively, so we will skip this, but check the activation checkpointing section of the JSON config documentation if you are interested.

Pitfalls of activation checkpointing in distributed executions

Combining activation checkpointing with distributed model parameters (ZeRO stage 3) is very tricky, and I really recommend against using both. The problem is that recomputing the activations from the closest checkpoint layer requires the weights of every layer in between, and if those weights are distributed (stage 3), then there has to be an extra collective communication step at every layer (from the checkpoint layer to the current back-prop layer) to gather them. This incurs a heavy communication burden and, in my experience, led to wrong results (NaN loss).

Launching a distributed execution

The installation of DeepSpeed includes the deepspeed launcher, a network bootstrapper that spawns a Python script across compute nodes and GPUs, with a different --local_rank argument and different environment variables for the communication world. In our example, to launch the script train.py on a compute node with 8 GPUs, with the DeepSpeed config file ds_config.json, we run on the shell:

$ deepspeed --num_gpus=8 train.py --deepspeed --deepspeed_config ds_config.json

Run deepspeed --help for a brief summary of the launcher options. With torchrun, it can be launched with:

$ torchrun --standalone --nproc_per_node=8 train.py --deepspeed --deepspeed_config ds_config.json --no_local_rank

and on a Slurm cluster, with:

slurm-torchrun --torch-script-path="train.py"  \
  --torch-script-extra-args="--deepspeed --deepspeed_config ds_config.json --no_local_rank"

A few notes about distributed executions:

  • --num_gpus is optional: if not provided, it will default to the available GPUs returned by the cuda toolkit;
  • launching with python instead of deepspeed will perform a single-node single-GPU run;
  • if we needed to run this on multiple compute nodes, we'd pass an extra parameter --hostfile hostfile, where hostfile is an MPI-style descriptor file listing the nodes and the GPUs per node;
  • the batch size should take into account the number of compute nodes, the number of GPUs, and the number of gradient accumulation steps or the micro-batch size (when applicable). In brief, each process needs at least 1 input sample and:
batch_size = micro_batch_size_per_gpu * num_gpus * num_nodes * gradient_accumulation_steps
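This formula can be checked before launching a run. The helper below is a hypothetical sketch (not part of DeepSpeed, which performs its own consistency check on these config values) that solves it for gradient_accumulation_steps and fails if the batch size is not evenly divisible:

```python
def grad_accum_steps(train_batch_size: int, micro_batch_per_gpu: int,
                     num_gpus: int, num_nodes: int = 1) -> int:
    """Solve batch_size = micro_batch * num_gpus * num_nodes * accum_steps for accum_steps."""
    per_step = micro_batch_per_gpu * num_gpus * num_nodes  # samples per micro-step, all processes
    if per_step <= 0 or train_batch_size % per_step != 0:
        raise ValueError(f"train_batch_size={train_batch_size} is not a multiple of {per_step}")
    return train_batch_size // per_step

print(grad_accum_steps(256, 1, 8))       # GPT-lite config: 32 accumulation steps
print(grad_accum_steps(16384, 2048, 8))  # benchmark models: 1, i.e. no accumulation
```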

Detour: measuring memory allocated to parameters

We can use the DeepSpeed API to estimate the memory requirements of model parameters for different ZeRO implementations, by calling the following method at the onset of execution:

## train.py

def measure_parameters_memory(model):
  param_size_GB = sum([p.nelement() * p.element_size() for p in model.parameters()])/1024**3
  print(f"Native model parameters size: {round(param_size_GB, 2)}GB.")

  from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live
  estimate_zero2_model_states_mem_needs_all_live(model, num_gpus_per_node=8, num_nodes=1)

  from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_live
  estimate_zero3_model_states_mem_needs_all_live(model, num_gpus_per_node=8, num_nodes=1)

The output tells us that:

  • the base model requires about 0.323GB for the storage of parameters, per GPU;
  • DeepSpeed ZeRO-2 requires 0.161GB and 0.484GB with and without offload, respectively;
  • DeepSpeed ZeRO-3 requires 0.009GB and 0.190GB with and without offload, respectively.

This metric is very useful, as it gives a quick overview of scaling and is very fast to compute. However, it has many limitations: it only measures the parameters overhead; it does not take activations or other residual buffers (e.g. normalization variables) into account; it ignores the batch size and numerical precision (or any other field in the config file); it does not consider temporary (e.g. communication) buffers, etc.

Results

To measure performance, we used the DeepSpeed logger to extract the following metrics from different runs every 10 steps: model throughput as the average number of samples per second, the average allocated memory, and the maximum allocated memory. We used pytorch==2.0.1, CUDA 11.7 and deepspeed==0.10.3.

All implementations use the same mixed-precision representation, communication bucket sizes, and other config parameters. We benchmarked the following implementations (and configs):

  1. The distributed data parallel (DDP) implementation, i.e. no DeepSpeed (ds_config.json with 'stage': 0);
  2. The fully-sharded data parallel implementation with ZeRO 1, 2 and 3 (ds_config.json with 'stage' :1, 2 or 3);
  3. The ZeRO-3 implementation with ZeRO-Infinity for CPU offloading (ds_config_offload.json);
  4. The ZeRO-3 implementation without activation checkpointing and with activation checkpointing at every block.

We tested three models. The first is a wide version of our benchmark model, with a high parametric space and a small layer count (W=8192, L=3), and input and output of size 8192. We used a batch size of \(2^{14}\), and a micro-batch size of \(2^{11}\) inputs per GPU, ie 'train_batch_size': 16384 and 'train_micro_batch_size_per_gpu': 2048. The benchmark results are:

Then we tested a deep benchmark model with a small parameter space (W=256), a high layer count (L=2048), an input and output size of 256, and the same batch sizes as the previous wide model:

And finally our GPT-lite model, with a micro-batch size of 1 sample per GPU:

Memory overhead from communication buffers. Looking at the max vs average memory, note that in theory the max memory should be much higher at high ZeRO stages than at low ZeRO stages and DDP, as communicating more parameters requires more communication buffers. However, setting the communication bucket sizes to a low value in the config file overcomes this effect. In fact, we also benchmarked several runs with the default communication bucket sizes (5e8), and, as expected, they led to a higher memory usage (approximately double in stages 2 and 3), which became prohibitive for some runs.

Parameter vs residual memory. Note the difference between average memory and maximum memory. That gap in memory consumption is due to temporary memory dedicated to activations, residual buffers, communication buffers, etc.

Communication vs computation trade-off from different stages in ZeRO. In ideal scenarios, as you move from DDP to ZeRO-1, ZeRO-2, ZeRO-3 and ZeRO-Infinity, both memory consumption and throughput are reduced: as expected, we trade data locality for communication of parameters, and pay a price in performance for the communication/offloading of parameters. This is the pattern observed in the deep benchmark and GPT-lite models. However, the wide benchmark model does not respond similarly, as throughput increases from stage 2 to stage 3. I believe this is due to ZeRO-3 distributing the fp16 model parameters, leading to a distributed parallelism of the very large sums of squares per layer.

Offloaded vs in-memory parameters. Offloading proved to be a consistent way to reduce memory usage with the drawback of a small reduction of throughput, as expected.

Activation checkpointing, tested on GPTlite, trades runtime for memory usage. It yielded a 4x memory reduction at the cost of a 20% increase in execution time.

Lack of superlinear speedup: We observed a small improvement in memory efficiency, but still far from the claims of the original DeepSpeed team. One explanation is that we used a small network of 8 GPUs, compared to the 64 to 800 GPUs used by the authors in their benchmarks, and therefore achieved a much smaller memory reduction. A large network of GPUs also underlies their claim of superlinear speedup, which we did not observe.

Memory as main bottleneck: In most runs, the maximum memory is much higher than the average memory. This is due to temporary buffers that drive the maximum memory up. On the GPTlite use case, most non-parameter memory is due to activations, and could have been improved with Flash Attention or KV cache, that we did not include in this implementation. Maximum memory seems to be the main scaling bottleneck.

There is a lot of food for thought here, and I will be updating this post as I find new insights.

Resources and code

In this post we explored only the dimension of data parallelism. For pipeline parallelism and tensor/model parallelism, see part 2 of this post.

For general documentation, I recommend the DeepSpeed API documentation, the training features page, the tutorials page, the HuggingFace page for DeepSpeed, and the examples at DeepSpeedExamples.

And finally, if you want to try this on your own, see the GPT-lite-DeepSpeed repo.