Integration Guide

This guide shows how to integrate RAMJET into your PyTorch training workflow.

Table of Contents

  1. Basic Integration
  2. torchrun
  3. DeepSpeed
  4. HuggingFace Accelerate
  5. PyTorch Lightning
  6. Custom Training Loop
  7. Multi-Node Setup
  8. Data Source Configuration
  9. Best Practices
  10. Troubleshooting
  11. Supported Launchers

Basic Integration

The simplest way to use RAMJET:

import torch
import ramjetio
from torch.utils.data import DataLoader

# Step 1: Initialize RAMJET
# This connects to the dashboard and starts the local cache server
ramjetio.init()

# Step 2: Wrap your dataset
dataset = ramjetio.CachedDataset(YourDataset())

# Step 3: Use DataLoader as usual
loader = DataLoader(dataset, batch_size=32, num_workers=4)

# Step 4: Train
model = YourModel()
optimizer = torch.optim.AdamW(model.parameters())

for epoch in range(100):
    for batch in loader:
        # First epoch: cache miss → load from source → populate the cache
        # Later epochs: cache hit → served straight from the distributed cache
        optimizer.zero_grad()
        loss = model(batch)
        loss.backward()
        optimizer.step()

Run with:

export RAMJET_API_KEY="your_key"
python train.py

torchrun

RAMJET works seamlessly with torchrun for multi-GPU and multi-node training.

Single Node, Multi-GPU

# train.py
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import ramjetio

def main():
    # Initialize distributed (torchrun sets env vars automatically)
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    local_rank = int(os.environ["LOCAL_RANK"])

    torch.cuda.set_device(local_rank)

    # Initialize RAMJET (after DDP init)
    ramjetio.init()

    # Create model and optimizer
    model = YourModel().cuda()
    model = DDP(model, device_ids=[local_rank])
    optimizer = torch.optim.AdamW(model.parameters())

    # Create cached dataset with distributed sampler
    dataset = ramjetio.CachedDataset(YourDataset())
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=32, sampler=sampler, num_workers=4)

    # Training loop
    for epoch in range(100):
        sampler.set_epoch(epoch)
        for batch in loader:
            batch = batch.cuda()
            optimizer.zero_grad()
            loss = model(batch)
            loss.backward()
            optimizer.step()

    dist.destroy_process_group()

if __name__ == "__main__":
    main()

Run:

export RAMJET_API_KEY="your_key"
torchrun --nproc_per_node=4 train.py

Multi-Node

# On node 0 (master)
export RAMJET_API_KEY="your_key"
torchrun \
    --nnodes=2 \
    --nproc_per_node=4 \
    --node_rank=0 \
    --master_addr=node0.example.com \
    --master_port=29500 \
    train.py

# On node 1
export RAMJET_API_KEY="your_key"
torchrun \
    --nnodes=2 \
    --nproc_per_node=4 \
    --node_rank=1 \
    --master_addr=node0.example.com \
    --master_port=29500 \
    train.py

DeepSpeed

# train.py
import deepspeed
import ramjetio

# Initialize RAMJET first
ramjetio.init()

# Wrap dataset
dataset = ramjetio.CachedDataset(YourDataset())

# Initialize DeepSpeed (it builds the training dataloader from training_data)
model = YourModel()
model_engine, optimizer, train_loader, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    training_data=dataset,
    config="ds_config.json"
)

# Training loop
for batch in train_loader:
    loss = model_engine(batch)
    model_engine.backward(loss)
    model_engine.step()

Run:

export RAMJET_API_KEY="your_key"
deepspeed --num_gpus=4 train.py

HuggingFace Accelerate

# train.py
import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader
import ramjetio

# Initialize Accelerate
accelerator = Accelerator()

# Initialize RAMJET
ramjetio.init()

# Wrap dataset
dataset = ramjetio.CachedDataset(YourDataset())
loader = DataLoader(dataset, batch_size=32)

# Create model and optimizer
model = YourModel()
optimizer = torch.optim.AdamW(model.parameters())

# Prepare with Accelerate
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

# Training loop
for batch in loader:
    optimizer.zero_grad()
    loss = model(batch)
    accelerator.backward(loss)
    optimizer.step()

Run:

export RAMJET_API_KEY="your_key"
accelerate launch --num_processes=4 train.py

PyTorch Lightning

import pytorch_lightning as pl
from torch.utils.data import DataLoader
import ramjetio

class YourDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        ramjetio.init()  # Initialize RAMJET

    def train_dataloader(self):
        dataset = ramjetio.CachedDataset(YourDataset())
        return DataLoader(dataset, batch_size=32, num_workers=4)

# Train
model = YourLightningModule()
trainer = pl.Trainer(accelerator="gpu", devices=4, strategy="ddp")
trainer.fit(model, datamodule=YourDataModule())

Run:

export RAMJET_API_KEY="your_key"
python train.py

Custom Training Loop

For maximum control:

import torch
import ramjetio
from ramjetio import DistributedCache, CachedDataset

# Manual initialization
cache = DistributedCache()  # Connects to dashboard, gets node list
ramjetio.init(cache=cache)

# Custom key function for caching
def make_key(idx):
    return f"sample_{idx}_v2"  # Include version for cache invalidation

dataset = CachedDataset(
    dataset=YourDataset(),
    cache=cache,
    key_fn=make_key,
    cache_on_miss=True,  # Cache items on first access
    ttl=86400,  # Cache for 24 hours
)

# Access cache directly
cache.set("my_tensor", torch.randn(100, 100))
tensor = cache.get("my_tensor")
cache.delete("my_tensor")

# Get cache stats
print(dataset.get_cache_stats())
# {'cache_hits': 950, 'cache_misses': 50, 'hit_rate': 95.0}
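
As a usage note, the hit rate from get_cache_stats() can be logged once per epoch to confirm the cache is doing its job. A small sketch reusing the dataset defined above (the training step itself is elided):

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=32, num_workers=4)

for epoch in range(3):
    for batch in loader:
        ...  # your training step here

    stats = dataset.get_cache_stats()  # dict with cache_hits, cache_misses, hit_rate
    print(f"epoch {epoch}: hit rate {stats['hit_rate']:.1f}%")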

Multi-Node Setup

How Node Discovery Works

  1. Each node starts with the same RAMJET_API_KEY
  2. Each node registers with the RAMJET dashboard
  3. Dashboard returns list of all nodes in the cluster
  4. Nodes use consistent hashing to determine which node caches which data

You don't need to hardcode node addresses — discovery is automatic.
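
The routing in step 4 is the standard consistent-hashing pattern. RAMJET handles this internally; the sketch below only illustrates the idea (node names and helper functions are made up for the example), showing how every node can independently compute the same owner for a given cache key:

import hashlib
from bisect import bisect

def build_ring(nodes, vnodes=64):
    """Place each node at several pseudo-random points on a hash ring."""
    return sorted(
        (int(hashlib.md5(f"{node}-{i}".encode()).hexdigest(), 16), node)
        for node in nodes
        for i in range(vnodes)
    )

def node_for_key(ring, key):
    """Walk clockwise from the key's hash to the next node on the ring."""
    h = int(hashlib.md5(key.encode()).hexdigest(), 16)
    idx = bisect(ring, (h,)) % len(ring)
    return ring[idx][1]

ring = build_ring(["node0:9000", "node1:9000", "node2:9000", "node3:9000"])
print(node_for_key(ring, "sample_42_v2"))  # every node computes the same owner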

Network Requirements

Nodes need to communicate on port 9000 (configurable via RAMJET_PORT).

# On each node, ensure port 9000 is open
sudo ufw allow 9000/tcp

Example: 4-Node Cluster

# Same command on all 4 nodes — RAMJET auto-discovers peers
export RAMJET_API_KEY="your_key"
torchrun \
    --nnodes=4 \
    --nproc_per_node=8 \
    --node_rank=$NODE_RANK \
    --master_addr=node0 \
    train.py

Data Source Configuration

Configure your data source in the dashboard, not in code. This keeps credentials secure and allows changes without redeploying.

S3 / MinIO

In dashboard settings:

  - Type: S3
  - Endpoint: https://s3.amazonaws.com or http://minio.local:9000
  - Bucket: my-training-data
  - Access Key: AKIAIOSFODNN7EXAMPLE
  - Secret Key: wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY
  - Path Style: Enable for MinIO

Then in your code, use the data source:

import torch
import ramjetio

ramjetio.init()

# Get data source config from dashboard
ds = ramjetio.get_data_source()

# Use it (example with s3fs)
import s3fs
fs = s3fs.S3FileSystem(
    endpoint_url=ds.endpoint,
    key=ds.access_key,
    secret=ds.secret_key,
)

# Your dataset uses this filesystem
class MyDataset:
    def __getitem__(self, idx):
        with fs.open(f"{ds.bucket}/data/{idx}.pt") as f:
            return torch.load(f)

HTTP URL

For simple HTTP/HTTPS data sources:

ds = ramjetio.get_data_source()
# ds.type == "http"
# ds.url == "https://data.example.com/dataset"
# ds.auth_header == "Bearer xxx" (if set)
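
A minimal way to consume an HTTP data source with these fields is sketched below using the requests library. The per-index ".pt" path layout is only an assumption for illustration, and the sketch assumes ds.auth_header is empty or None when no auth is configured:

import io
import requests
import torch
import ramjetio

ramjetio.init()
ds = ramjetio.get_data_source()

headers = {"Authorization": ds.auth_header} if ds.auth_header else {}

class MyHTTPDataset:
    def __getitem__(self, idx):
        # Assumes samples are served as <idx>.pt under the configured URL
        resp = requests.get(f"{ds.url}/{idx}.pt", headers=headers)
        resp.raise_for_status()
        return torch.load(io.BytesIO(resp.content))

dataset = ramjetio.CachedDataset(MyHTTPDataset())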

Best Practices

1. Cache Key Design

# Bad: Key doesn't reflect preprocessing version
key_fn = lambda idx: f"sample_{idx}"

# Good: Include version in key
key_fn = lambda idx: f"sample_{idx}_v3_aug2"

# Best: Include hash of preprocessing config
import hashlib
config_hash = hashlib.md5(str(preprocess_config).encode()).hexdigest()[:8]
key_fn = lambda idx: f"sample_{idx}_{config_hash}"

2. Cache Warming

For large datasets, warm the cache before training:

import ramjetio
from tqdm import tqdm

ramjetio.init()
dataset = ramjetio.CachedDataset(YourDataset())

# Warm cache (single process, no GPU needed)
print("Warming cache...")
for i in tqdm(range(len(dataset))):
    _ = dataset[i]

print(f"Cache ready: {dataset.get_cache_stats()}")

3. Memory-Efficient Loading

dataset = ramjetio.CachedDataset(
    YourDataset(),
    cache_on_miss=True,
    # Only cache transformed data, not raw
    transform_before_cache=your_transform,
)
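
For instance, a transform that decodes and downsizes images before they are cached keeps the cache footprint small. The torchvision pipeline below is purely illustrative; pass whatever callable matches your raw samples:

from torchvision import transforms

# Illustrative only: cache resized tensors instead of full-resolution raw images
your_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])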

Troubleshooting

See TROUBLESHOOTING.md for common issues.


Supported Launchers

RAMJET automatically detects and works with all major distributed training launchers:

| Launcher                            | Env Variables Used                        | Status       |
|-------------------------------------|-------------------------------------------|--------------|
| torchrun / torch.distributed.launch | LOCAL_RANK, RANK, WORLD_SIZE              | ✅ Tested    |
| DeepSpeed                           | LOCAL_RANK, RANK, WORLD_SIZE              | ✅ Tested    |
| HuggingFace Accelerate              | LOCAL_RANK, RANK, WORLD_SIZE              | ✅ Tested    |
| PyTorch Lightning                   | LOCAL_RANK, RANK, WORLD_SIZE              | ✅ Tested    |
| SLURM (srun)                        | SLURM_LOCALID, SLURM_PROCID, SLURM_NTASKS | ✅ Supported |
| OpenMPI (mpirun)                    | OMPI_COMM_WORLD_LOCAL_RANK, etc.          | ✅ Supported |
| MPICH                               | MPI_LOCALRANKID, PMI_RANK, PMI_SIZE       | ✅ Supported |
| Horovod                             | HOROVOD_LOCAL_RANK, HOROVOD_RANK          | ✅ Supported |
| AWS SageMaker                       | SM_CURRENT_INSTANCE_LOCAL_RANK            | ✅ Supported |
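
ramjetio.init() resolves the local rank from whichever of these variables is present. A simplified illustration of that kind of environment-variable fallback (not RAMJET's actual source) looks like this:

import os

def detect_local_rank():
    """Check launcher-specific environment variables in a fixed order."""
    candidates = [
        "LOCAL_RANK",                      # torchrun, DeepSpeed, Accelerate, Lightning
        "SLURM_LOCALID",                   # SLURM (srun)
        "OMPI_COMM_WORLD_LOCAL_RANK",      # OpenMPI (mpirun)
        "MPI_LOCALRANKID",                 # MPICH
        "HOROVOD_LOCAL_RANK",              # Horovod
        "SM_CURRENT_INSTANCE_LOCAL_RANK",  # AWS SageMaker
    ]
    for var in candidates:
        if var in os.environ:
            return int(os.environ[var])
    return 0  # single-process fallback

print(detect_local_rank())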

SLURM Example

#!/bin/bash
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=4
#SBATCH --gpus-per-node=4

export RAMJET_API_KEY="your_key"
srun python train.py

No changes are needed in train.py:

# train.py
import ramjetio
ramjetio.init()  # Automatically reads SLURM_LOCALID, SLURM_PROCID

MPI Example

export RAMJET_API_KEY="your_key"
mpirun -np 8 --hostfile hosts.txt python train.py

Horovod Example

import horovod.torch as hvd
import ramjetio

hvd.init()
ramjetio.init()  # Reads HOROVOD_LOCAL_RANK automatically

dataset = ramjetio.CachedDataset(YourDataset())
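
To shard the cached dataset across Horovod workers, the usual pattern is a DistributedSampler driven by hvd.size() and hvd.rank(). Continuing the snippet above (standard PyTorch, nothing RAMJET-specific):

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# Each Horovod worker sees a disjoint shard of the cached dataset
sampler = DistributedSampler(dataset, num_replicas=hvd.size(), rank=hvd.rank())
loader = DataLoader(dataset, batch_size=32, sampler=sampler, num_workers=4)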