Amazon SageMaker: Distributed Training

No training implementation is complete until it allows training on a cluster where each machine has multiple GPUs.

Multi-node/Multi-GPU Training with PyTorch Lightning

SageMaker does a great job of enabling this in Script Mode: all we have to do is write code that supports SMDDP, SageMaker’s implementation of the distributed data parallel (DDP) protocol.

PyTorch Lightning is also an obvious choice for abstracting our training loop: it supports everything we need and wraps it up nicely, so you don’t need to gather validation results yourself or make sure logging is activated on the appropriate process.

This PyTorch Lightning introduction to its distributed API is a great starting point.

Training Script Modifications

Just follow the instructions from AWS to adapt your training scripts.

I like to wrap it all into a set of functions:

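import os
import sys

from pytorch_lightning.plugins.environments.lightning_environment import LightningEnvironment
from pytorch_lightning.strategies import DDPStrategy

# DDP is not supported on Windows, so remember where we are running
is_win = sys.platform.startswith("win")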

def get_trainer_env():
  env = LightningEnvironment()
  env.world_size = lambda: int(os.environ.get("WORLD_SIZE", 1))
  env.global_rank = lambda: int(os.environ.get("RANK", 0))
  return env

def get_initialization_info():
  """
  Initialize the distributed training environment and return the data relevant to
  Lightning Trainer initialization.
  """
  world_size = num_nodes = 1
  ddp = None
  num_gpus = int(os.environ.get("SM_NUM_GPUS", 1))

  if not is_win and num_gpus > 1:
    import smdistributed.dataparallel.torch.torch_smddp
    # For DDP with sagemaker
    env = get_trainer_env()
    ddp = DDPStrategy(
      cluster_environment=env,
      process_group_backend="smddp",
    )

    world_size = int(os.environ["WORLD_SIZE"])
    num_nodes = world_size // num_gpus
    print(f"Training with {num_gpus} GPUs/node on {num_nodes} nodes")
  return ddp, num_gpus, num_nodes

def get_global_rank():
  return int(os.environ.get("RANK", 0))

def get_local_rank():
  return int(os.environ.get("LOCAL_RANK", 0))

The get_initialization_info function can be called from the training script to return all the data needed for distributed or non-distributed training initialization. So, this is either an SMDDP run or a non-distributed training run.

Since DDP is not supported on Windows, we are making doubly sure to not enable it if we are running in that environment.
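
The decision logic on its own, stripped of the Lightning and SMDDP objects, can be sketched as below. This is a simplified illustration, not the function above: resolve_topology and its platform argument are hypothetical names, and it assumes every node exposes the same number of GPUs.

```python
import sys


def resolve_topology(env, platform=sys.platform):
    """Return (use_ddp, num_gpus, num_nodes) from SageMaker-style variables.

    `env` is any mapping of variable names to strings, e.g. os.environ.
    Hypothetical helper for illustration only.
    """
    num_gpus = int(env.get("SM_NUM_GPUS", 1))
    is_win = platform.startswith("win")
    # DDP is only enabled off Windows and with more than one GPU per node.
    if is_win or num_gpus <= 1:
        return False, num_gpus, 1
    world_size = int(env["WORLD_SIZE"])
    # Integer division: every node is assumed to have the same GPU count.
    return True, num_gpus, world_size // num_gpus


# Four nodes with eight GPUs each: 32 processes in total.
print(resolve_topology({"SM_NUM_GPUS": "8", "WORLD_SIZE": "32"}, platform="linux"))
# No SageMaker variables set: falls back to a single-process run.
print(resolve_topology({}, platform="linux"))
```

Passing the environment in as a mapping keeps the logic easy to check locally without a SageMaker container.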

The smdistributed import only works inside the SageMaker script mode container, so we tuck it safely under the if statement to keep it from executing in a non-SageMaker environment.
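
If you prefer an explicit check over relying on the GPU count, one alternative is to ask Python whether the package is installed before importing it. This is a sketch of a different technique than the one used in this post; module_available is a hypothetical helper name.

```python
import importlib.util


def module_available(name):
    """Return True if the top-level module `name` can be found here."""
    return importlib.util.find_spec(name) is not None


# Only attempt the SMDDP import where the package actually exists,
# i.e. inside the SageMaker training container.
if module_available("smdistributed"):
    import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401
```

Outside SageMaker the check is simply False and the import is skipped.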

The purpose of the DDPStrategy instance is to hook up PyTorch Lightning with the protocol SageMaker uses to pass the world size and rank designation to the participating processes. Rank can be local or global. Global rank is an integer in the range [0, WORLD_SIZE) that uniquely designates each process across the cluster, while local rank is in [0, NUM_LOCAL_GPUS) and identifies a process within its node.
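
To make the two kinds of rank concrete, here is a small sketch that enumerates them for a toy cluster. It assumes ranks are numbered node by node (global rank = node index * GPUs per node + local rank), which is a common convention but not something this post verifies for SageMaker; rank_layout is a hypothetical helper.

```python
def rank_layout(num_nodes, gpus_per_node):
    """List (global_rank, node_index, local_rank) for every process."""
    layout = []
    for node in range(num_nodes):
        for local_rank in range(gpus_per_node):
            # Assumed node-major numbering; used here for illustration only.
            global_rank = node * gpus_per_node + local_rank
            layout.append((global_rank, node, local_rank))
    return layout


for global_rank, node, local_rank in rank_layout(num_nodes=2, gpus_per_node=2):
    print(f"global rank {global_rank}: node {node}, local rank {local_rank}")
```

Global ranks run over the whole cluster without repeats, while local ranks restart at 0 on every node.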

Initializing the PyTorch Lightning trainer is then straightforward:

ddp, num_gpus, num_nodes = dist.get_initialization_info()

trainer = pl.Trainer(
        precision=16, # we'll use mixed precision

Handling I/O Conflicts

With multiple processes per node (each assigned to its own GPU), we may have a situation where the same data is being written/downloaded multiple times on the same node, which may cause conflicts and possibly crashes.

To avoid that, we can use the PyTorch Lightning DataModule with its prepare_data hook, which takes care of the initial download safely by running it on a single process per node. Set self.prepare_data_per_node = True during module initialization and put the downloading code in a prepare_data override. See the example in the prepare_data documentation for PyTorch Lightning.

For instance, if downloading or creating a HuggingFace dataset:

def prepare_data(self, stage=None):

This only downloads or builds the dataset files locally; the actual Dataset object is created in the setup override of the Lightning DataModule.
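
The per-node gating can be pictured roughly as follows. This is a plain-Python sketch of the rule, not Lightning's actual code: runs_prepare_data is a hypothetical name, and the process list assumes ranks are numbered node by node.

```python
def runs_prepare_data(global_rank, local_rank, prepare_data_per_node):
    """Decide whether this process should execute the download step."""
    if prepare_data_per_node:
        # One download per node: the process with local rank 0 on each machine.
        return local_rank == 0
    # One download for the whole cluster, e.g. when nodes share a filesystem.
    return global_rank == 0


# 2 nodes x 2 GPUs as (global_rank, local_rank) pairs.
processes = [(0, 0), (1, 1), (2, 0), (3, 1)]
downloaders = [g for g, l in processes if runs_prepare_data(g, l, True)]
print(downloaders)
```

With prepare_data_per_node enabled, exactly one process on each node performs the download and the others wait for it to finish.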

SageMaker Notebook Modifications

from sagemaker.pytorch import PyTorch

# if running on a single instance with a single GPU
instance_type = 'ml.p3.2xlarge'

# Recommend one of these instances for multi-GPU cluster training
#instance_type = 'ml.p4d.24xlarge' 
instance_type = 'ml.p3dn.24xlarge'

# base job for easy identification in SageMaker console
base_job_name = 'donut-ddp-mult-instance-smddp'

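# No distribution: single-GPU, non-distributed run
distribution = None
# PyTorch native DDP backend
distribution = {"pytorchddp": {"enabled": True}}
# SMDDP backend (the last assignment wins, so this is the one used)
distribution = {"smdistributed": {"dataparallel": {"enabled": True}}}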

pytorch_estimator = PyTorch(
    max_run=48 * 60 * 60,

distribution is set to the SMDDP backend. AWS docs indicate that the PyTorch native DDP backend is also fully supported, but I haven’t tried it.

It’s a good idea to let the estimator pick the appropriate PyTorch image, since a lot rides on the different frameworks being compatible. So instead of specifying image_uri, request the PyTorch and Python versions you need.

Don’t forget the requirements.txt file in the source_dir, which should at the minimum contain the line:


I found this version of PyTorch Lightning to play well with SageMaker.

One more thing not to forget: each optimizer step now sees batch * world_size samples, so a common practice is to scale the learning rate by the same factor. Set the lr hyperparameter to the base lr * world_size:

hyperparameters = {
    "epochs": 30,
    "batch": 4,
    "lr": 1e-7 * 8 * 4,

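As a worked version of that arithmetic, the sketch below assumes the 8 and 4 in the lr expression stand for 8 GPUs per node and 4 nodes (the post does not say so explicitly). scaled_lr and effective_batch are hypothetical helper names.

```python
def scaled_lr(base_lr, gpus_per_node, num_nodes):
    """Scale a base learning rate linearly with the world size."""
    world_size = gpus_per_node * num_nodes
    return base_lr * world_size


def effective_batch(per_gpu_batch, gpus_per_node, num_nodes):
    """Number of samples contributing to one optimizer step across the cluster."""
    return per_gpu_batch * gpus_per_node * num_nodes


# Assumed topology: 8 GPUs per node on 4 nodes, i.e. a world size of 32.
print(scaled_lr(1e-7, gpus_per_node=8, num_nodes=4))
print(effective_batch(4, gpus_per_node=8, num_nodes=4))
```

Linear scaling is a rule of thumb rather than a guarantee, so it is worth validating the resulting lr on a short run.
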
A Gotcha

For dessert we have this doozy of a gotcha. If you have passed a security group to the estimator through security_group_ids, make sure the group has the appropriate permissions for inbound and outbound communications, as described in this article. This one had me stumped until an AWS Support specialist pointed out the solution.

Happy (distributed) training!
