Multi-node/Multi-GPU Training with PyTorch Lightning
No training implementation is complete until it allows training on a cluster where each machine has multiple GPUs.
SageMaker does a great job of enabling this in Script Mode: all we have to do is write code that supports SageMaker’s SMDDP implementation of the DDP distributed training protocol.
PyTorch Lightning is also an obvious choice for abstracting the training loop, since it supports everything we need and wraps it up nicely, so you don’t have to gather validation results yourself or make sure logging is activated on the appropriate process.
This PyTorch Lightning introduction to its distributed API is a great starting point.
Training Script Modifications
Just follow the instructions from AWS to enable your training scripts.
I like to wrap it all into a set of functions:
import logging
import os
import sys

from pytorch_lightning.plugins.environments import LightningEnvironment
from pytorch_lightning.strategies import DDPStrategy

is_win = sys.platform.startswith("win")

def get_trainer_env():
    # Tell Lightning to read the world size and global rank from the
    # environment variables SageMaker exports to each process.
    env = LightningEnvironment()
    env.world_size = lambda: int(os.environ.get("WORLD_SIZE", 1))
    env.global_rank = lambda: int(os.environ.get("RANK", 0))
    return env

def get_initialization_info():
    '''
    Initialize the distributed training environment and return the data relevant to
    Lightning Trainer initialization.
    '''
    world_size = num_nodes = 1
    ddp = None
    num_gpus = int(os.environ.get("SM_NUM_GPUS", 1))
    if not is_win and num_gpus > 1:
        # Only available inside the SageMaker Script Mode container
        import smdistributed.dataparallel.torch.torch_smddp
        # For DDP with SageMaker
        env = get_trainer_env()
        ddp = DDPStrategy(
            cluster_environment=env,
            process_group_backend="smddp",
        )
        world_size = int(os.environ["WORLD_SIZE"])
        num_nodes = world_size // num_gpus
    logging.info(f"Training with {num_gpus} GPUs/node on {num_nodes} nodes")
    return ddp, num_gpus, num_nodes

def get_global_rank():
    return int(os.environ.get("RANK", 0))

def get_local_rank():
    return int(os.environ.get("LOCAL_RANK", 0))
The get_initialization_info function can be called from the training script to return all the data needed to initialize either a distributed (SMDDP) run or a non-distributed one.
Since DDP is not supported on Windows, we make doubly sure not to enable it when running in that environment.
The import of smdistributed.dataparallel.torch.torch_smddp will only work inside the SageMaker Script Mode container, so we tuck it safely under the if statement to prevent it from executing in a non-SageMaker environment.
The purpose of the DDPStrategy instance is to hook PyTorch Lightning up to the protocol SageMaker uses to pass the world size and rank designations to the participating processes. Rank can be local or global: the global rank is an integer in the [0, WORLD_SIZE) range (upper bound excluded) that uniquely designates each process across the cluster, while the local rank is in [0, NUM_LOCAL_GPUS) and is assigned to a process within its node.
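The get_global_rank and get_local_rank helpers come in handy for gating work that should run only once. As a quick sketch (the function names below are illustrative, not part of the training script), once-per-job work can be gated on the global rank and once-per-node work on the local rank:
def log_hyperparameters(params):
    # Only one process in the entire cluster needs to log this
    if get_global_rank() == 0:
        logging.info(f"Hyperparameters: {params}")

def warm_local_cache(path):
    # One process per machine is enough for node-local work
    if get_local_rank() == 0:
        logging.info(f"Pre-populating the cache at {path} on this node")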
Initializing the PyTorch Lightning trainer is then straightforward:
# dist is the module holding the helper functions defined above
ddp, num_gpus, num_nodes = dist.get_initialization_info()

trainer = pl.Trainer(
    accelerator="cuda",
    strategy=ddp,
    devices=num_gpus,
    num_nodes=num_nodes,
    max_epochs=args.epochs,
    val_check_interval=args.val_check_interval,
    check_val_every_n_epoch=args.check_val_every_n_epoch,
    gradient_clip_val=1.0,
    precision=16,  # we'll use mixed precision
    num_sanity_val_steps=0,
    logger=logger,
    callbacks=[checkpoint],
)
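From there, training is launched as usual; a minimal sketch, where model and data_module stand in for whatever LightningModule and LightningDataModule the script builds (they are not names from the code above):
# Lightning takes care of spawning and coordinating the distributed processes
trainer.fit(model, datamodule=data_module)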
Handling I/O Conflicts
With multiple processes per node (each assigned to its own GPU), we may have a situation where the same data is being written/downloaded multiple times on the same node, which may cause conflicts and possibly crashes.
To avoid that, we can use the PyTorch Lightning DataModule facility with its prepare_data hook, which takes care of the initial download safely by running it on a single process per node. Set self.prepare_data_per_node = True during module initialization and put the downloading code in a prepare_data override. See the example of this in the prepare_data documentation for PyTorch Lightning.
For instance, if downloading or creating a HuggingFace dataset:
def prepare_data(self):
    # Download/cache only; Lightning calls this on a single process per node
    load_dataset(self.dataset_name_or_path)
Here we are just downloading and caching the dataset locally; the actual Dataset creation will happen in the setup override of the Lightning DataModule.
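Putting the two hooks together, here is a minimal sketch of what such a DataModule could look like. The class name, constructor arguments, and dataloader details are assumptions for illustration, not the actual module from this project:
import pytorch_lightning as pl
from datasets import load_dataset
from torch.utils.data import DataLoader

class DocDataModule(pl.LightningDataModule):  # hypothetical name
    def __init__(self, dataset_name_or_path, batch_size=4):
        super().__init__()
        self.dataset_name_or_path = dataset_name_or_path
        self.batch_size = batch_size
        self.prepare_data_per_node = True  # download once per node, not once per GPU

    def prepare_data(self):
        # Runs on a single process per node: just populate the local cache
        load_dataset(self.dataset_name_or_path)

    def setup(self, stage=None):
        # Runs on every process: build the actual datasets from the cached files
        ds = load_dataset(self.dataset_name_or_path)
        self.train_ds, self.val_ds = ds["train"], ds["validation"]

    def train_dataloader(self):
        # A task-specific collate_fn would normally be supplied here
        return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.batch_size)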
SageMaker Notebook Modifications
from sagemaker.pytorch import PyTorch

# if running on a single instance with a single GPU
# instance_type = 'ml.p3.2xlarge'

# Recommended instances for multi-GPU cluster training
# instance_type = 'ml.p4d.24xlarge'
instance_type = 'ml.p3dn.24xlarge'

# base job name for easy identification in the SageMaker console
base_job_name = 'donut-ddp-mult-instance-smddp'

# distribution = None                                  # no distributed training
# distribution = {"pytorchddp": {"enabled": "true"}}   # native PyTorch DDP
distribution = {"smdistributed": {"dataparallel": {"enabled": True}}}  # SMDDP

pytorch_estimator = PyTorch(
    source_dir='finetuning',
    entry_point='train.py',
    base_job_name=base_job_name,
    hyperparameters=hyperparameters,
    framework_version="1.12.1",
    py_version='py38',
    role=role,
    instance_type=instance_type,
    instance_count=4,
    volume_size=200,
    use_spot_instances=False,
    max_run=48 * 60 * 60,
    security_group_ids=["My-SecurityGroup"],
    distribution=distribution,
)
distribution is set to the SMDDP backend; the AWS docs indicate that the native PyTorch DDP backend is also fully supported, but I haven’t tried it.
It’s a good idea to let the estimator figure out the appropriate PyTorch image, since a lot is riding on the different frameworks being compatible: don’t specify image_uri, but request the PyTorch and Python versions instead.
Don’t forget the requirements.txt file in the source_dir, which should at a minimum contain the line:
pytorch-lightning==1.7.7
I found this version of PyTorch Lightning to play well with SageMaker.
One more thing not to forget: the effective batch size grows with the world size, so the learning rate should be scaled accordingly (the chosen base lr * world_size) in the hyperparameters:
hyperparameters = {
    "epochs": 30,
    "batch": 4,
    "lr": 1e-7 * 8 * 4,  # base lr scaled by world size: 8 GPUs/node * 4 nodes
}
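Script Mode passes each hyperparameter to train.py as a command-line flag of the same name, so the script can pick them up with argparse. A minimal sketch, with argument names matching the dictionary above (the model-dir and channel defaults assume the standard SageMaker environment variables):
import argparse
import os

def parse_args():
    parser = argparse.ArgumentParser()
    # Each hyperparameter arrives as --<name> <value>
    parser.add_argument("--epochs", type=int, default=30)
    parser.add_argument("--batch", type=int, default=4)
    parser.add_argument("--lr", type=float, default=1e-7)
    # Standard Script Mode locations for the model artifacts and the input channel
    parser.add_argument("--model-dir", type=str, default=os.environ.get("SM_MODEL_DIR", "."))
    parser.add_argument("--train", type=str, default=os.environ.get("SM_CHANNEL_TRAINING", "."))
    return parser.parse_args()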
A Gotcha
For dessert we have this doozy of a gotcha. If you have specified a security group to the estimator through security_group_ids, make sure the group has the appropriate permissions for inbound and outbound communication, as described in this article. This one got me stumped until an AWS Support specialist pointed out the solution.
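In short, the group has to allow all traffic between the training instances that share it. A hedged sketch of one way to add the self-referencing rules with boto3 (the group ID is a placeholder; verify the exact requirements against the AWS guidance):
import boto3

ec2 = boto3.client("ec2")
sg_id = "sg-0123456789abcdef0"  # placeholder for the group passed to security_group_ids

# Allow all traffic between instances that belong to this same security group
self_ref = [{"IpProtocol": "-1", "UserIdGroupPairs": [{"GroupId": sg_id}]}]
ec2.authorize_security_group_ingress(GroupId=sg_id, IpPermissions=self_ref)
ec2.authorize_security_group_egress(GroupId=sg_id, IpPermissions=self_ref)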
Happy (distributed) training!