How BRIA AI used distributed training in Amazon SageMaker to train latent diffusion foundation models for commercial use


This post is co-written with Bar Fingerman from BRIA AI.

This post explains how BRIA AI trained BRIA AI 2.0, a high-resolution (1024×1024) text-to-image diffusion model, on a dataset comprising petabytes of licensed images quickly and economically. Amazon SageMaker training jobs and Amazon SageMaker distributed training libraries took on the undifferentiated heavy lifting associated with infrastructure management. SageMaker helps you build, train, and deploy machine learning (ML) models for your use cases with fully managed infrastructure, tools, and workflows.

BRIA AI is a pioneering platform specializing in responsible and open generative artificial intelligence (AI) for developers, offering advanced models exclusively trained on licensed data from partners such as Getty Images, DepositPhotos, and Alamy. BRIA AI caters to major brands, animation and gaming studios, and marketing agencies with its multimodal suite of generative models. Emphasizing ethical sourcing and commercial readiness, BRIA AI’s models are source-available, secure, and optimized for integration with various tech stacks. By addressing foundational challenges in data procurement, continuous model training, and seamless technology integration, BRIA AI aims to be the go-to platform for creative AI application developers.

You can also find the BRIA AI 2.0 model for image generation on AWS Marketplace.

This blog post discusses how BRIA AI worked with AWS to address the following key challenges:

  • Achieving out-of-the-box operational excellence for large model training
  • Reducing time-to-train by using data parallelism
  • Maximizing GPU utilization with efficient data loading
  • Reducing model training cost (by paying only for net training time)

Importantly, BRIA AI was able to use SageMaker while keeping the initially used HuggingFace Accelerate (Accelerate) software stack intact. Thus, transitioning to SageMaker training didn’t require changes to BRIA AI’s model implementation or training code. Later, BRIA AI was able to seamlessly evolve their software stack on SageMaker along with their model training.

Training pipeline architecture

BRIA AI’s training pipeline consists of two main components:

Data preprocessing:

  • Data contributors upload licensed raw image files to BRIA AI’s Amazon Simple Storage Service (Amazon S3) bucket.
  • An image preprocessing pipeline using Amazon Simple Queue Service (Amazon SQS) and AWS Lambda functions generates missing image metadata and packages the training data into large WebDataset files. These files can later be streamed efficiently directly from an S3 bucket and sharded across GPUs. See the Challenge 3 section. WebDataset is a PyTorch-based implementation, so it fits well with Accelerate.
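The packaging step can be pictured with a small sketch. It uses only the Python standard library rather than the WebDataset package itself, and it is not BRIA AI's pipeline: the helper names `write_shard` and `read_shard`, the sample keys, and the metadata fields are hypothetical. It relies on the convention that files in a TAR shard sharing a basename (for example `000001.jpg` and `000001.json`) belong to one sample.

```python
import io
import json
import tarfile


def write_shard(samples, fileobj):
    """Pack (key, image_bytes, metadata) samples into one TAR shard.

    Files that share a basename ("000001.jpg", "000001.json") form one sample.
    """
    with tarfile.open(fileobj=fileobj, mode="w") as tar:
        for key, image_bytes, metadata in samples:
            members = {
                f"{key}.jpg": image_bytes,
                f"{key}.json": json.dumps(metadata).encode("utf-8"),
            }
            for name, payload in members.items():
                info = tarfile.TarInfo(name=name)
                info.size = len(payload)
                tar.addfile(info, io.BytesIO(payload))


def read_shard(fileobj):
    """Read a shard sequentially and regroup its members into samples by key."""
    current_key, sample = None, {}
    with tarfile.open(fileobj=fileobj, mode="r|") as tar:  # "r|" = streaming read
        for member in tar:
            key, ext = member.name.rsplit(".", 1)
            if current_key is not None and key != current_key:
                yield current_key, sample
                sample = {}
            current_key = key
            sample[ext] = tar.extractfile(member).read()
    if current_key is not None:
        yield current_key, sample


# Hypothetical in-memory example: two tiny "images" with captions.
shard = io.BytesIO()
write_shard(
    [
        ("000001", b"fake-jpeg-bytes-1", {"caption": "a panda pouring milk"}),
        ("000002", b"fake-jpeg-bytes-2", {"caption": "a tray of pastries"}),
    ],
    shard,
)
shard.seek(0)
for key, sample in read_shard(shard):
    print(key, json.loads(sample["json"])["caption"], len(sample["jpg"]))
```

Because the reader only moves forward through the archive, the same access pattern suits data that is streamed from object storage rather than downloaded in full.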

Model training:

  • SageMaker distributed training jobs manage the training cluster and run the training itself.
  • SageMaker’s FastFile mode streams data from Amazon S3 to the training instances.

Pre-training challenges and solutions

Pre-training foundation models is a challenging task. Challenges include cost, performance, orchestration, monitoring, and the engineering expertise needed throughout the weeks-long training process.

The four challenges we faced were:

Challenge 1: Achieving out-of-the-box operational excellence for large model training

To orchestrate the training cluster and recover from failures, BRIA AI relies on SageMaker Training Jobs’ resiliency features. These include cluster health checks, built-in retries, and job resiliency. Before your job starts, SageMaker runs GPU health checks and verifies NVIDIA Collective Communications Library (NCCL) communication on GPU instances, replacing faulty instances (if necessary) to make sure your training script starts running on a healthy cluster of instances. You can also configure SageMaker to automatically retry training jobs that fail with a SageMaker internal server error (ISE). As part of retrying a job, SageMaker will replace instances that encountered unrecoverable GPU errors with fresh instances, reboot the healthy instances, and start the job again. This results in faster restarts and workload completion. By using AWS Deep Learning Containers, the BRIA AI workload benefited from the SageMaker SDK automatically setting the necessary environment variables to tune NVIDIA NCCL AWS Elastic Fabric Adapter (EFA) networking based on well-known best practices. This helps maximize the workload throughput.
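As an illustration of the automatic retry setting, the fragment below builds the part of a SageMaker CreateTrainingJob request that controls retries and cluster size. `RetryStrategy.MaximumRetryAttempts` is the field in that API; the helper name `resilient_job_settings`, the retry count, the volume size, and the runtime limit are hypothetical values, not BRIA AI's configuration. Nothing is sent to AWS.

```python
def resilient_job_settings(instance_count, max_retry_attempts, max_runtime_hours):
    """Return a fragment of a CreateTrainingJob request (illustrative values)."""
    # The API accepts between 1 and 30 retry attempts.
    if not 1 <= max_retry_attempts <= 30:
        raise ValueError("MaximumRetryAttempts must be between 1 and 30")
    return {
        "ResourceConfig": {
            "InstanceType": "ml.p4de.24xlarge",
            "InstanceCount": instance_count,
            "VolumeSizeInGB": 500,  # hypothetical volume size
        },
        # Retries apply to jobs that fail with a SageMaker internal server error.
        "RetryStrategy": {"MaximumRetryAttempts": max_retry_attempts},
        "StoppingCondition": {"MaxRuntimeInSeconds": max_runtime_hours * 3600},
    }


settings = resilient_job_settings(
    instance_count=16, max_retry_attempts=5, max_runtime_hours=24
)
print(settings["RetryStrategy"])
```

With the SageMaker Python SDK, the equivalent setting is the estimator's `max_retry_attempts` argument.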

To monitor the training cluster, BRIA AI used the built-in SageMaker integration with Amazon CloudWatch Logs (application logs) and CloudWatch metrics (CPU, GPU, and networking metrics).

Challenge 2: Reducing time-to-train by using data parallelism

BRIA AI needed to train a Stable Diffusion 2.0 model from scratch on a petabyte-scale licensed image dataset. Training on a single GPU could take a few months to complete. To meet deadline requirements, BRIA AI used data parallelism with a SageMaker training job on 16 p4de.24xlarge instances, reducing the total training time to under two weeks. Distributed data parallel training allows for much faster training of large models by splitting data across many devices that train in parallel, while syncing gradients regularly to keep a consistent shared model, so it uses the combined computing power of many devices. BRIA AI used a cluster of four p4de.24xlarge instances (8xA100 80GB NVIDIA GPUs) to achieve a throughput of 1.8 iterations per second for an effective batch size of 2048 (batch=8, bf16, accumulate=2).
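The effective batch size quoted above can be checked with simple arithmetic. The sketch below assumes that batch=8 is the per-GPU batch size and that the run uses 128 GPUs (16 p4de.24xlarge instances with 8 GPUs each), which is the GPU count at which those settings multiply out to 2048. Both are assumptions made for illustration, not statements from BRIA AI.

```python
GPUS_PER_INSTANCE = 8  # p4de.24xlarge has 8 A100 GPUs


def effective_batch_size(per_gpu_batch, grad_accum_steps, num_gpus):
    """Samples that contribute to one optimizer step across the whole cluster."""
    return per_gpu_batch * grad_accum_steps * num_gpus


num_gpus = 16 * GPUS_PER_INSTANCE  # assumed: 16 instances
batch = effective_batch_size(per_gpu_batch=8, grad_accum_steps=2, num_gpus=num_gpus)
images_per_second = 1.8 * batch  # at 1.8 iterations per second

print(num_gpus, batch, round(images_per_second, 1))
```

Under these assumptions, the cluster processes roughly 3,700 images per second.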

p4de.24xlarge instances include 600 GB per second peer-to-peer GPU communication with NVIDIA NVSwitch and 400 gigabits per second (Gbps) instance networking with support for EFA and NVIDIA GPUDirect RDMA (remote direct memory access).

Note: Currently, you can use p5.48xlarge instances (8xH100 80GB GPUs) with 3200 Gbps networking between instances using EFA 2.0 (not used in this pre-training by BRIA AI).

Accelerate is a library that enables the same PyTorch code to be run across a distributed configuration with minimal code adjustments.

BRIA AI used Accelerate for small-scale training outside the cloud. When it was time to scale out training in the cloud, BRIA AI was able to continue using Accelerate, thanks to its built-in integration with SageMaker and the Amazon SageMaker distributed data parallel library (SMDDP). SMDDP is purpose-built for the AWS infrastructure and reduces communication overhead in two ways:

  • The library optimizes AllReduce, a key operation during distributed training that’s responsible for a large portion of communication overhead, by overlapping it efficiently with the backward pass for optimal GPU usage.
  • The library performs optimized node-to-node communication by fully utilizing the AWS network infrastructure and Amazon Elastic Compute Cloud (Amazon EC2) instance topology (optimal bandwidth use with balanced fusion buffer).
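To make the role of AllReduce concrete, here is a toy, single-process sketch of one data-parallel step. It is not SMDDP or Accelerate code: the helper names, the 1-D linear model, and the data are hypothetical, and `all_reduce_mean` merely stands in for the collective operation. Each simulated worker computes a gradient on its own shard, the gradients are averaged, and every replica applies the same update.

```python
def grad_mse(w, xs, ys):
    """Gradient of mean squared error for the 1-D model y = w * x."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def all_reduce_mean(values):
    """Stand-in for AllReduce: every worker receives the mean of all values."""
    mean = sum(values) / len(values)
    return [mean] * len(values)


def data_parallel_step(w, xs, ys, num_workers, lr):
    """One step with the batch split into equal shards, one per worker.

    Assumes len(xs) is divisible by num_workers so that the mean of the
    shard gradients equals the full-batch gradient.
    """
    shard = len(xs) // num_workers
    local_grads = [
        grad_mse(w, xs[i * shard:(i + 1) * shard], ys[i * shard:(i + 1) * shard])
        for i in range(num_workers)
    ]
    synced = all_reduce_mean(local_grads)
    return [w - lr * g for g in synced]  # one updated weight per replica


xs, ys = [1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0]
replicas = data_parallel_step(0.0, xs, ys, num_workers=2, lr=0.01)
single = 0.0 - 0.01 * grad_mse(0.0, xs, ys)
print(replicas, single)
```

With equal shard sizes, the synchronized replicas match a single-device step on the full batch, which is why data parallelism can speed up training without changing the optimization itself.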

Note that SageMaker training supports many open source distributed training libraries, for example Fully Sharded Data Parallel (FSDP) and DeepSpeed. BRIA AI used FSDP in SageMaker in other training workloads. In those workloads, by using the ShardingStrategy.SHARD_GRAD_OP sharding strategy, BRIA AI was able to achieve an optimal batch size and accelerate their training process.

Challenge 3: Achieving efficient data loading

The BRIA AI dataset included hundreds of millions of images that needed to be delivered from storage onto GPUs for processing. Efficiently accessing this large amount of data across a training cluster presents several challenges:

  • The data might not fit into the storage of a single instance.
  • Downloading the multi-terabyte dataset to each training instance is time-consuming, and the GPUs sit idle while it downloads.
  • Copying millions of small image files from Amazon S3 can become a bottleneck because of accumulated roundtrip time of fetching objects from S3.
  • The data needs to be split correctly between instances.

BRIA AI addressed these challenges by using SageMaker fast file input mode, which provided the following out-of-the-box features:

  • Streaming: Instead of copying data when training starts, or using an additional distributed file system, we chose to stream data directly from Amazon S3 to the training instances using SageMaker fast file mode. This allows training to start immediately without waiting for downloads. Streaming also reduces the need to fit datasets into instance storage.
  • Data distribution: Fast file mode was configured to shard the dataset files between multiple instances using S3DataDistributionType=ShardedByS3Key.
  • Local file access: Fast file mode provides a local POSIX filesystem interface to data in Amazon S3. This allowed BRIA AI’s data loader to access remote data as if it was local.
  • Packaging files to large containers: Using millions of small image and metadata files is an overhead when streaming data from object storage like Amazon S3. To reduce this overhead, BRIA AI compacted multiple files into large TAR file containers (2–5 GB), which can be efficiently streamed from S3 using fast file mode to the instances. Specifically, BRIA AI used WebDataset for efficient local data loading and used a policy wherein there is no data loading synchronization between instances and each GPU loads random batches through a fixed seed. This policy helps eliminate bottlenecks and maintains fast and deterministic data loading performance.
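The settings named in this list map onto a training job's input channel, and the no-synchronization policy can be sketched with a seeded shuffle. In the fragment below, `InputMode` and `S3DataDistributionType` are fields of the CreateTrainingJob API. The bucket URI, shard file names, helper names, and the way the seed is derived from rank and epoch are hypothetical, and the shuffle is only one plausible reading of "each GPU loads random batches through a fixed seed".

```python
import random


def fast_file_channel(s3_uri, channel_name="train"):
    """One entry for the InputDataConfig list of a CreateTrainingJob request."""
    return {
        "ChannelName": channel_name,
        "InputMode": "FastFile",  # stream objects on demand instead of copying
        "DataSource": {
            "S3DataSource": {
                "S3DataType": "S3Prefix",
                "S3Uri": s3_uri,
                # Each instance receives a subset of the objects under the prefix.
                "S3DataDistributionType": "ShardedByS3Key",
            }
        },
    }


def shard_order(shards, base_seed, global_rank, epoch):
    """Deterministic per-GPU shard order that needs no cross-worker coordination."""
    rng = random.Random(f"{base_seed}-{global_rank}-{epoch}")
    order = list(shards)
    rng.shuffle(order)
    return order


channel = fast_file_channel("s3://example-bucket/webdataset-shards/")  # hypothetical URI
local_shards = [f"shard-{i:05d}.tar" for i in range(6)]  # hypothetical names
print(channel["InputMode"], shard_order(local_shards, base_seed=42, global_rank=3, epoch=0))
```

Because the order depends only on the seed, rank, and epoch, a restarted job reproduces the same sequence on each GPU without exchanging any state between workers.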

For more on data loading considerations, see the blog post Choose the best data source for your Amazon SageMaker training job.

Challenge 4: Paying only for net training time

Pre-training large foundation models is not continuous. The model training often requires intermittent stops for evaluation and adjustments. For instance, the model might stop converging and need adjustments, or you might want to pause training to test the model, refine data, or troubleshoot issues. These pauses result in extended periods where the GPU cluster is idle. With SageMaker training jobs, BRIA AI paid only for the duration of their active training time. This allowed BRIA AI to train models at a lower cost and with greater efficiency.

BRIA AI’s training strategy consists of three steps of increasing resolution for optimal model convergence:

  1. Initial training at 256×256 resolution on a 32-GPU cluster
  2. Progressive refinement at 512×512 resolution on a 64-GPU cluster
  3. Final training at 1024×1024 resolution on a 128-GPU cluster

In each step, the computing required was different due to applied tradeoffs, such as the batch size per resolution and the upper limit of the GPU and gradient accumulation. The tradeoff is between cost-saving and model coverage.

BRIA AI’s cost calculations were facilitated by maintaining a consistent iteration per second rate, which allowed for accurate estimation of training time. This enabled precise determination of the required number of iterations and calculation of the training compute cost per hour.
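A steady iteration rate turns the cost estimate into straightforward arithmetic, as the sketch below shows. All inputs here are hypothetical placeholders chosen to give round numbers: the iteration count, the instance count for the step, and especially the hourly price, which is not an AWS list price.

```python
def training_hours(total_iterations, iterations_per_second):
    """Wall-clock hours of active training at a steady iteration rate."""
    return total_iterations / (iterations_per_second * 3600)


def training_cost(total_iterations, iterations_per_second,
                  num_instances, price_per_instance_hour):
    """Estimated compute cost when billed only for active training time."""
    hours = training_hours(total_iterations, iterations_per_second)
    return hours * num_instances * price_per_instance_hour


# Hypothetical step: 648,000 iterations at 1.8 iterations per second.
hours = training_hours(total_iterations=648_000, iterations_per_second=1.8)
cost = training_cost(
    total_iterations=648_000,
    iterations_per_second=1.8,
    num_instances=16,
    price_per_instance_hour=40.0,  # placeholder price, not a real quote
)
print(round(hours, 2), round(cost, 2))
```

Repeating the calculation for each resolution step, with that step's cluster size and iteration rate, gives a budget for the whole run.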

BRIA AI training GPU utilization and average batch size time:

  • GPU utilization: The average is over 98 percent, which indicates that the GPUs were kept busy for the whole training cycle and that the data loader streamed data efficiently at a high rate.
  • Iterations per second: The training strategy is composed of three steps (initial training at 256×256, progressive refinement to 512×512, and final training at 1024×1024 resolution) for optimal model convergence. The amount of computing varies by step, because different batch sizes can be applied per resolution within the upper limit of the GPU and gradient accumulation, trading cost savings against model coverage.

Result examples


Prompts used for generating the images
Prompt 1, upper left image: A stylish man sitting casually on outdoor steps, wearing a green hoodie, matching green pants, black shoes, and sunglasses. He is smiling and has neatly groomed hair and a short beard. A brown leather bag is placed beside him. The background features a brick wall and a window with white frames.

Prompt 2, upper right image: A vibrant Indian wedding ceremony. The smiling bride in a magenta saree with gold embroidery and henna-adorned hands sits adorned in traditional gold jewelry. The groom, sitting in front of her, in a golden sherwani and white dhoti, pours water into a ceremonial vessel. They are surrounded by flowers, candles, and leaves in a colorful, festive atmosphere filled with traditional objects.

Prompt 3, lower left image: A wooden tray filled with a variety of delicious pastries. The tray includes a croissant dusted with powdered sugar, a chocolate-filled croissant, a partially eaten croissant, a Danish pastry and a muffin next to a small jar of chocolate sauce, and a bowl of coffee beans, all arranged on a beige cloth.

Prompt 4, lower right image: A panda pouring milk into a white cup on a table with coffee beans, flowers, and a coffee press. The background features a black-and-white picture and a decorative wall piece.

Conclusion

In this post, we saw how Amazon SageMaker enabled BRIA AI to train a diffusion model efficiently, without needing to manually provision and configure infrastructure. By using SageMaker training, BRIA AI was able to reduce costs and accelerate iteration speed, cutting training time with distributed training while maintaining 98 percent GPU utilization and maximizing value for cost. By taking on the undifferentiated heavy lifting, SageMaker empowered BRIA AI’s team to be more productive and deliver innovations faster. The ease of use and automation offered by SageMaker training jobs make them an attractive option for any team looking to efficiently train large, state-of-the-art models.

To learn more about how SageMaker can help you train large AI models efficiently and cost-effectively, explore the Amazon SageMaker page. You can also reach out to your AWS account team to discover how to unlock the full potential of your large-scale AI initiatives.


About the Authors

Bar Fingerman, Head of Engineering AI/ML at BRIA AI.

Doron Bleiberg, Senior Startup Solutions Architect.

Gili Nachum, Principal Gen AI/ML Specialist Solutions Architect.

Erez Zarum, Startup Solutions Architect.


