- Use a laptop to simulate 4-node training of an image classification model using DiLoCo
- Use a single node with 4x 4090 GPUs to simulate 16-node training of a language model using SPARTA
- Simulate distributed training without setting up a real cluster - no Kubernetes, Docker, or GPU hosting required
- Fast iteration: implementing a new distributed training algorithm from scratch takes as few as 5 lines of code
- Scale up the number of nodes by changing a single parameter
- Switch hardware from a laptop to a multi-GPU node - with no code changes
EXO Gym spins up multiple virtual PyTorch nodes on the hardware available. The virtual nodes train in parallel across the devices, and can communicate with PyTorch primitives such as all_reduce.
... and anything else you can imagine! Implementing new algorithms with EXO Gym is very simple - see Custom Algorithms.
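To give a flavour of the programming model, here is a minimal sketch of the simplest possible strategy - plain gradient averaging across the virtual nodes via all_reduce. The import paths below are assumptions, and the attributes used (self.model, self.num_nodes, self.optim) mirror the custom-algorithm example further down:

from torch.distributed import all_reduce  # PyTorch collective primitive; EXO Gym's virtual nodes can communicate with it

from exogym.strategy import Strategy  # assumed import path for the Strategy base class

class AllReduceStrategy(Strategy):
    def __init__(self, optim_spec):
        super().__init__()
        self.optim_spec = optim_spec  # the base Strategy is assumed to build self.optim from this spec

    def step(self):
        # Average gradients across every virtual node, DDP-style
        for param in self.model.parameters():
            if param.grad is not None:
                all_reduce(param.grad)        # sum this gradient over all virtual nodes
                param.grad /= self.num_nodes  # turn the sum into a mean
        self.optim.step()
        super().step()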
python>=3.10
To install:
git clone https://github.com/exo-explore/gym.git exogym
cd exogym
python3 -m venv .venv && source .venv/bin/activate
pip install -e .
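To sanity-check the install (the import path matches the example below):

python -c "from exogym import Trainer; print('EXO Gym imported successfully')"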
Strategies (e.g. DiLoCo, SPARTA) are portable across domains. A custom dataset and model can be trained with a distributed algorithm like so:

from exogym import Trainer
from exogym.strategy.diloco import DiLoCoStrategy
train_dataset, val_dataset = ...
model = ... # model.forward() expects a batch, and returns a scalar loss
trainer = Trainer(model, train_dataset, val_dataset)
# Strategy for optimization & communication
strategy = DiLoCoStrategy(
    inner_optim='adam',
    H=100
)
trainer.fit(
    strategy=strategy,
    num_nodes=4,
    device='mps'
)
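Scaling out or switching hardware needs no code changes beyond the arguments to fit(). For example, to simulate 16 nodes on a CUDA machine:

trainer.fit(
    strategy=strategy,
    num_nodes=16,   # simulate 16 virtual nodes instead of 4
    device='cuda'   # e.g. a multi-GPU node instead of a laptop
)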
example/playground.py is a minimal starting point for writing new algorithms. For example, to implement gradient quantization from scratch:

from typing import Literal

import torch
from torch.distributed import all_reduce  # PyTorch collective primitive; EXO Gym's virtual nodes can communicate with it

from exogym.strategy import Strategy  # assumed import path for the Strategy base class

class QuantizationStrategy(Strategy):
    def __init__(self, optim_spec, quantization_level: Literal['int8']):
        super().__init__()
        self.optim_spec = optim_spec  # the base Strategy is assumed to build self.optim from this spec
        # Fixed affine-quantization parameters for this toy example
        self.scale = 0.024
        self.zero_point = 0
        self.qdtype = torch.uint8

    def step(self):
        for param in self.model.parameters():
            if param.grad is not None:
                # Affine-quantize the gradient to uint8 (zero_point is 0 here, so no offset correction is needed when dequantizing)
                quantized = torch.round(param.grad / self.scale + self.zero_point).clamp(0, 255).to(self.qdtype)
                # Widen before the collective so the sum across nodes cannot overflow uint8
                q_wide = quantized.to(torch.int32)
                all_reduce(q_wide)
                # Dequantize and average the summed gradient over the virtual nodes
                param.grad = (q_wide.to(torch.float32) * self.scale) / self.num_nodes
        self.optim.step()
        super().step()

EXO Gym supports the following device backends:

- CPU
- CUDA
- MPS (CPU-bound for copy operations)
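A custom strategy such as the QuantizationStrategy above plugs into the same Trainer API on any of these backends. A minimal sketch (passing 'adam' as the optim_spec is an assumption, mirroring inner_optim='adam' in the DiLoCo example):

trainer = Trainer(model, train_dataset, val_dataset)
trainer.fit(
    strategy=QuantizationStrategy(optim_spec='adam', quantization_level='int8'),
    num_nodes=4,
    device='cuda'  # or 'cpu' / 'mps'
)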
For further details on how EXO Gym works under the hood, please see docs/.
If you use EXO Gym in your research, please cite:
@software{exogym2025,
  title={EXO Gym},
  author={Matt Beton and Mohamed Baioumy and Matt Reed and Seth Howes and Alex Cheema},
  year={2025},
  url={https://github.com/exo-explore/gym}
}



