# Integration Guide
This guide shows how to integrate RAMJET into your PyTorch training workflow.
## Table of Contents

- Basic Integration
- torchrun / torch.distributed.launch
- DeepSpeed
- HuggingFace Accelerate
- PyTorch Lightning
- Custom Training Loop
- Multi-Node Setup
- Data Source Configuration
- Best Practices
- Troubleshooting
- Supported Launchers
## Basic Integration
The simplest way to use RAMJET:
```python
import ramjetio
from torch.utils.data import DataLoader

# Step 1: Initialize RAMJET
# This connects to the dashboard and starts the local cache server
ramjetio.init()

# Step 2: Wrap your dataset
dataset = ramjetio.CachedDataset(YourDataset())

# Step 3: Use DataLoader as usual
loader = DataLoader(dataset, batch_size=32, num_workers=4)

# Step 4: Train
for epoch in range(100):
    for batch in loader:
        # First epoch: cache miss → load from source → cache
        # Next epochs: cache hit → instant
        optimizer.zero_grad()
        loss = model(batch)
        loss.backward()
        optimizer.step()
```
Run with:
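```bash
# Assuming the snippet above is saved as train.py
python train.py
```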
## torchrun

RAMJET works seamlessly with `torchrun` for multi-GPU and multi-node training.
### Single Node, Multi-GPU
```python
# train.py
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

import ramjetio

def main():
    # Initialize distributed (torchrun sets env vars automatically)
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    # Initialize RAMJET (after DDP init)
    ramjetio.init()

    # Create model and optimizer
    model = YourModel().cuda()
    model = DDP(model, device_ids=[local_rank])
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    # Create cached dataset with distributed sampler
    dataset = ramjetio.CachedDataset(YourDataset())
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=32, sampler=sampler, num_workers=4)

    # Training loop
    for epoch in range(100):
        sampler.set_epoch(epoch)
        for batch in loader:
            batch = batch.cuda()
            optimizer.zero_grad()
            loss = model(batch)
            loss.backward()
            optimizer.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()
```
Run:
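```bash
# e.g. 4 GPUs on a single node
torchrun --nproc_per_node=4 train.py
```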
### Multi-Node
```bash
# On node 0 (master)
export RAMJET_API_KEY="your_key"
torchrun \
  --nnodes=2 \
  --nproc_per_node=4 \
  --node_rank=0 \
  --master_addr=node0.example.com \
  --master_port=29500 \
  train.py

# On node 1
export RAMJET_API_KEY="your_key"
torchrun \
  --nnodes=2 \
  --nproc_per_node=4 \
  --node_rank=1 \
  --master_addr=node0.example.com \
  --master_port=29500 \
  train.py
```
## DeepSpeed

Initialize RAMJET before calling `deepspeed.initialize()`, and pass the wrapped dataset as `training_data`:
```python
# train.py
import deepspeed

import ramjetio

# Initialize RAMJET first
ramjetio.init()

# Wrap dataset
dataset = ramjetio.CachedDataset(YourDataset())

# Initialize DeepSpeed; it builds the dataloader from training_data
model = YourModel()
model_engine, optimizer, loader, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    training_data=dataset,
    config="ds_config.json",
)

# Training loop (DeepSpeed handles gradient zeroing internally)
for batch in loader:
    loss = model_engine(batch)
    model_engine.backward(loss)
    model_engine.step()
```
Run:
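```bash
# e.g. 4 GPUs; the config path matches ds_config.json above
deepspeed --num_gpus=4 train.py
```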
## HuggingFace Accelerate
```python
# train.py
import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader

import ramjetio

# Initialize Accelerate
accelerator = Accelerator()

# Initialize RAMJET
ramjetio.init()

# Wrap dataset
dataset = ramjetio.CachedDataset(YourDataset())
loader = DataLoader(dataset, batch_size=32)

# Prepare with Accelerate
model = YourModel()
optimizer = torch.optim.AdamW(model.parameters())
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

# Training loop
for batch in loader:
    optimizer.zero_grad()
    loss = model(batch)
    accelerator.backward(loss)
    optimizer.step()
```
Run:
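```bash
# Run `accelerate config` once first to describe your hardware
accelerate launch train.py
```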
## PyTorch Lightning
```python
import pytorch_lightning as pl
from torch.utils.data import DataLoader

import ramjetio

class YourDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        ramjetio.init()  # Initialize RAMJET

    def train_dataloader(self):
        dataset = ramjetio.CachedDataset(YourDataset())
        return DataLoader(dataset, batch_size=32, num_workers=4)

# Train
trainer = pl.Trainer(accelerator="gpu", devices=4, strategy="ddp")
trainer.fit(model, datamodule=YourDataModule())
```
Run:
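```bash
# Lightning spawns the DDP processes itself
python train.py
```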
## Custom Training Loop
For maximum control:
```python
import torch

import ramjetio
from ramjetio import DistributedCache, CachedDataset

# Manual initialization
cache = DistributedCache()  # Connects to dashboard, gets node list
ramjetio.init(cache=cache)

# Custom key function for caching
def make_key(idx):
    return f"sample_{idx}_v2"  # Include version for cache invalidation

dataset = CachedDataset(
    dataset=YourDataset(),
    cache=cache,
    key_fn=make_key,
    cache_on_miss=True,  # Cache items on first access
    ttl=86400,           # Cache for 24 hours
)

# Access cache directly
cache.set("my_tensor", torch.randn(100, 100))
tensor = cache.get("my_tensor")
cache.delete("my_tensor")

# Get cache stats
print(dataset.get_cache_stats())
# {'cache_hits': 950, 'cache_misses': 50, 'hit_rate': 95.0}
```
## Multi-Node Setup
### How Node Discovery Works
- Each node starts with the same `RAMJET_API_KEY`
- Each node registers with the RAMJET dashboard
- The dashboard returns the list of all nodes in the cluster
- Nodes use consistent hashing to determine which node caches which data (see the sketch below)
You don't need to hardcode node addresses — discovery is automatic.
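To build intuition for the placement step, here is a minimal consistent-hash ring in plain Python. This is an illustration only, not RAMJET's implementation; the node names, virtual-node count, and MD5 hash are assumptions:

```python
import bisect
import hashlib

def _h(s: str) -> int:
    # Stable hash -> position on the ring
    return int(hashlib.md5(s.encode()).hexdigest(), 16)

class HashRing:
    """Toy consistent-hash ring; each node owns several virtual points."""

    def __init__(self, nodes, vnodes=64):
        self.ring = sorted((_h(f"{node}#{i}"), node)
                           for node in nodes for i in range(vnodes))
        self.points = [h for h, _ in self.ring]

    def node_for(self, key: str) -> str:
        # The first virtual point clockwise from the key's hash owns the key
        i = bisect.bisect(self.points, _h(key)) % len(self.ring)
        return self.ring[i][1]

ring = HashRing(["node0", "node1", "node2", "node3"])
print(ring.node_for("sample_42_v2"))  # every node computes the same owner
```

Because every node runs the same function over the same node list, all nodes agree on which peer caches each key without any extra coordination.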
### Network Requirements
Nodes need to communicate on port 9000 (configurable via `RAMJET_PORT`).
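A quick way to confirm the port is reachable between nodes (the hostname here is a placeholder):

```bash
# From any node, check that a peer's cache port is open
nc -zv node1.example.com 9000
```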
### Example: 4-Node Cluster
```bash
# Same command on all 4 nodes — RAMJET auto-discovers peers
export RAMJET_API_KEY="your_key"
torchrun \
  --nnodes=4 \
  --nproc_per_node=8 \
  --node_rank=$NODE_RANK \
  --master_addr=node0 \
  train.py
```
## Data Source Configuration
Configure your data source in the dashboard, not in code. This keeps credentials secure and allows changes without redeploying.
### S3 / MinIO
In dashboard settings:
- Type: S3
- Endpoint: `https://s3.amazonaws.com` or `http://minio.local:9000`
- Bucket: `my-training-data`
- Access Key: `AKIAIOSFODNN7EXAMPLE`
- Secret Key: `wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY`
- Path Style: Enable for MinIO
Then in your code, use the data source:
```python
import torch
import s3fs

import ramjetio

ramjetio.init()

# Get data source config from dashboard
ds = ramjetio.get_data_source()

# Use it (example with s3fs)
fs = s3fs.S3FileSystem(
    endpoint_url=ds.endpoint,
    key=ds.access_key,
    secret=ds.secret_key,
)

# Your dataset reads from the configured bucket
class MyDataset:
    def __getitem__(self, idx):
        with fs.open(f"{ds.bucket}/data/{idx}.pt") as f:
            return torch.load(f)
```
### HTTP URL
For simple HTTP/HTTPS data sources:
```python
ds = ramjetio.get_data_source()
# ds.type == "http"
# ds.url == "https://data.example.com/dataset"
# ds.auth_header == "Bearer xxx" (if set)
```
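A minimal sketch of reading samples from an HTTP source. The one-file-per-sample layout under `ds.url` and the `HTTPDataset` name are illustrative assumptions, not part of the RAMJET API:

```python
import io

import requests
import torch

import ramjetio

ramjetio.init()
ds = ramjetio.get_data_source()
headers = {"Authorization": ds.auth_header} if ds.auth_header else {}

class HTTPDataset:
    def __getitem__(self, idx):
        # One .pt file per sample under the configured base URL (assumed layout)
        r = requests.get(f"{ds.url}/{idx}.pt", headers=headers, timeout=30)
        r.raise_for_status()
        return torch.load(io.BytesIO(r.content))
```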
## Best Practices
### 1. Cache Key Design
```python
# Bad: key doesn't reflect preprocessing version
key_fn = lambda idx: f"sample_{idx}"

# Good: include a version in the key
key_fn = lambda idx: f"sample_{idx}_v3_aug2"

# Best: include a hash of the preprocessing config
import hashlib
config_hash = hashlib.md5(str(preprocess_config).encode()).hexdigest()[:8]
key_fn = lambda idx: f"sample_{idx}_{config_hash}"
```
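One caveat with `str(preprocess_config)`: the string form of a dict depends on key insertion order. Serializing with sorted keys gives a hash that is stable across runs (a suggestion, assuming the config is JSON-serializable):

```python
import hashlib
import json

# Sort keys so logically identical configs always hash the same way
config_hash = hashlib.md5(
    json.dumps(preprocess_config, sort_keys=True).encode()
).hexdigest()[:8]
key_fn = lambda idx: f"sample_{idx}_{config_hash}"
```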
### 2. Cache Warming
For large datasets, warm the cache before training:
```python
import ramjetio
from tqdm import tqdm

ramjetio.init()
dataset = ramjetio.CachedDataset(YourDataset())

# Warm cache (single process, no GPU needed)
print("Warming cache...")
for i in tqdm(range(len(dataset))):
    _ = dataset[i]

print(f"Cache ready: {dataset.get_cache_stats()}")
```
### 3. Memory-Efficient Loading
```python
dataset = ramjetio.CachedDataset(
    YourDataset(),
    cache_on_miss=True,
    # Only cache transformed data, not raw
    transform_before_cache=your_transform,
)
```
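For instance, if raw samples are JPEG bytes, caching the decoded, resized tensor keeps entries small. A sketch of such a transform (the 224×224 size and PIL-based decoding are illustrative choices, not requirements):

```python
import io

from PIL import Image
from torchvision.transforms import functional as TF

def your_transform(raw_bytes):
    # Decode and shrink before caching so each cached entry stays small
    img = Image.open(io.BytesIO(raw_bytes)).convert("RGB")
    img = TF.resize(img, [224, 224])
    return TF.to_tensor(img)
```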
## Troubleshooting
See TROUBLESHOOTING.md for common issues.
## Supported Launchers
RAMJET automatically detects and works with all major distributed training launchers:
| Launcher | Env Variables Used | Status |
|---|---|---|
| `torchrun` / `torch.distributed.launch` | `LOCAL_RANK`, `RANK`, `WORLD_SIZE` | ✅ Tested |
| DeepSpeed | `LOCAL_RANK`, `RANK`, `WORLD_SIZE` | ✅ Tested |
| HuggingFace Accelerate | `LOCAL_RANK`, `RANK`, `WORLD_SIZE` | ✅ Tested |
| PyTorch Lightning | `LOCAL_RANK`, `RANK`, `WORLD_SIZE` | ✅ Tested |
| SLURM (`srun`) | `SLURM_LOCALID`, `SLURM_PROCID`, `SLURM_NTASKS` | ✅ Supported |
| OpenMPI (`mpirun`) | `OMPI_COMM_WORLD_LOCAL_RANK`, etc. | ✅ Supported |
| MPICH | `MPI_LOCALRANKID`, `PMI_RANK`, `PMI_SIZE` | ✅ Supported |
| Horovod | `HOROVOD_LOCAL_RANK`, `HOROVOD_RANK` | ✅ Supported |
| AWS SageMaker | `SM_CURRENT_INSTANCE_LOCAL_RANK` | ✅ Supported |
### SLURM Example
```bash
#!/bin/bash
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=4
#SBATCH --gpus-per-node=4

export RAMJET_API_KEY="your_key"
srun python train.py
```

```python
# train.py - no changes needed!
import ramjetio

ramjetio.init()  # Automatically reads SLURM_LOCALID, SLURM_PROCID
```
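Submit with `sbatch` (the job script filename here is a placeholder):

```bash
sbatch train.slurm
```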