API Reference

Module: ramjetio

ramjetio.init()

Initialize RAMJET. Calling init() once at startup is the recommended entry point.

ramjetio.init(
    api_key: str = None,        # API key (default: from RAMJET_API_KEY env var)
    cache_path: str = None,     # Cache directory (default: /tmp/ramjet_cache)
    cache_size: str = None,     # Max cache size (default: 100GB)
    port: int = None,           # Cache server port (default: 9000)
    auto_start_server: bool = True,  # Start cache server automatically
)

What it does:
1. Reads configuration from environment variables and arguments
2. Connects to the RAMJET dashboard
3. Registers this node
4. Starts the local cache server (if auto_start_server=True)
5. Starts a background heartbeat/metrics thread

Example:

import ramjetio

# Simple (uses env vars)
ramjetio.init()

# With explicit config
ramjetio.init(
    api_key="ramjet_abc123",
    cache_size="500GB",
    port=9001,
)


ramjetio.CachedDataset

PyTorch Dataset wrapper with distributed caching.

ramjetio.CachedDataset(
    dataset: Dataset,                    # Original PyTorch dataset
    cache: DistributedCache = None,      # Cache instance (default: global cache)
    key_fn: Callable[[int], str] = None, # Key generation function
    cache_on_miss: bool = True,          # Cache items on miss
    ttl: int = None,                     # Time-to-live in seconds
    transform_before_cache: Callable = None,  # Transform to apply before caching
)

Example:

from torch.utils.data import DataLoader
import ramjetio

ramjetio.init()

# Basic usage
dataset = ramjetio.CachedDataset(YourDataset())

# With custom key function
dataset = ramjetio.CachedDataset(
    YourDataset(),
    key_fn=lambda idx: f"sample_{idx}_v2",
    ttl=86400,  # 24 hours
)

loader = DataLoader(dataset, batch_size=32)

Methods:
- get_cache_stats() -> dict — Get hit/miss statistics
- reset_stats() — Reset statistics counters
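The counters from get_cache_stats() can be turned into a hit rate for monitoring between epochs. Note the key names below ("hits", "misses") are an assumption; this reference does not specify the dict's exact layout.

```python
# Hypothetical layout of the dict returned by get_cache_stats();
# the actual key names are not specified in this reference.
stats = {"hits": 950, "misses": 50}

def hit_rate(stats: dict) -> float:
    """Fraction of lookups served from the cache."""
    total = stats["hits"] + stats["misses"]
    return stats["hits"] / total if total else 0.0

print(hit_rate(stats))  # 0.95
```

A rate well below 1.0 after the first epoch usually means keys are changing between epochs (check your key_fn) or items are being evicted before reuse.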


ramjetio.DistributedCache

Low-level cache client.

ramjetio.DistributedCache(
    nodes: List[str] = None,  # Node addresses (default: from dashboard)
    timeout: float = 30.0,    # Request timeout
    retry_attempts: int = 3,  # Retry failed requests
)

Methods:

cache = ramjetio.DistributedCache()

# Store data
cache.set(key: str, value: Any, ttl: int = None) -> bool

# Retrieve data
cache.get(key: str) -> Optional[Any]

# Delete data
cache.delete(key: str) -> bool

# Check existence
cache.exists(key: str) -> bool

# Get statistics
cache.stats() -> Dict[str, Dict]

# Clear all data
cache.clear() -> bool

Example:

import torch
from ramjetio import DistributedCache

cache = DistributedCache()

# Store tensor
tensor = torch.randn(100, 100)
cache.set("my_tensor", tensor)

# Retrieve
loaded = cache.get("my_tensor")
assert torch.allclose(tensor, loaded)

# With TTL (expires in 1 hour)
temp_data = {"step": 0, "note": "scratch"}
cache.set("temp_data", temp_data, ttl=3600)


ramjetio.get_data_source()

Get data source configuration from dashboard.

ds = ramjetio.get_data_source() -> Optional[DataSource]

Returns DataSource with:

ds.type        # "s3", "http", "local"
ds.endpoint    # S3 endpoint URL
ds.bucket      # S3 bucket name
ds.access_key  # S3 access key
ds.secret_key  # S3 secret key
ds.region      # S3 region (optional)
ds.use_ssl     # Use SSL for S3
ds.path_style  # Use path-style URLs (for MinIO)
ds.url         # HTTP URL (if type="http")
ds.auth_header # HTTP auth header
ds.prefix      # Path prefix

Example:

import ramjetio
import s3fs

ramjetio.init()
ds = ramjetio.get_data_source()

if ds.type == "s3":
    fs = s3fs.S3FileSystem(
        endpoint_url=ds.endpoint,
        key=ds.access_key,
        secret=ds.secret_key,
    )
    files = fs.ls(f"{ds.bucket}/{ds.prefix}")
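For completeness, a sketch of the type == "http" branch. How ds.auth_header is meant to be sent is not pinned down above; the code assumes it holds the value of the Authorization header. A stand-in object is used here instead of a live dashboard connection, with invented values.

```python
from types import SimpleNamespace
from urllib.parse import urljoin

# Stand-in DataSource with type="http"; field names follow the list
# above, the values are invented for illustration.
ds = SimpleNamespace(
    type="http",
    url="https://data.example.com/",
    auth_header="Bearer abc123",  # assumed: value for the Authorization header
    prefix="train/",
)

if ds.type == "http":
    headers = {"Authorization": ds.auth_header} if ds.auth_header else {}
    sample_url = urljoin(ds.url, ds.prefix + "sample_0000.bin")
    # fetch with e.g. requests.get(sample_url, headers=headers)
    print(sample_url)  # https://data.example.com/train/sample_0000.bin
```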


ramjetio.get_cluster_nodes()

Get list of all nodes in the cluster.

nodes = ramjetio.get_cluster_nodes(
    online_only: bool = True
) -> List[Dict]

Returns:

[
    {"hostname": "node0", "port": 9000, "status": "online"},
    {"hostname": "node1", "port": 9000, "status": "online"},
]
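The node dicts can be turned into the address strings that DistributedCache(nodes=...) accepts. The "host:port" format is an assumption; the reference above only says nodes is a List[str].

```python
# Sample return value of get_cluster_nodes(), as documented above.
nodes = [
    {"hostname": "node0", "port": 9000, "status": "online"},
    {"hostname": "node1", "port": 9000, "status": "online"},
]

# Assumed "host:port" address format for DistributedCache(nodes=...).
addresses = [f"{n['hostname']}:{n['port']}" for n in nodes]
print(addresses)  # ['node0:9000', 'node1:9000']

# cache = ramjetio.DistributedCache(nodes=addresses)
```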


ramjetio.log_metrics()

Log training metrics to dashboard.

ramjetio.log_metrics(
    loss: float = None,
    epoch: int = None,
    step: int = None,
    learning_rate: float = None,
    throughput: float = None,  # samples/sec
    **kwargs,  # Any additional metrics
)

Example:

for epoch in range(100):
    for step, batch in enumerate(loader):
        loss = train_step(batch)

        ramjetio.log_metrics(
            loss=loss.item(),
            epoch=epoch,
            step=step,
            learning_rate=optimizer.param_groups[0]['lr'],
        )


Environment Variables

Variable             Description                  Default
RAMJET_API_KEY       API key for authentication   Required
RAMJET_CACHE_PATH    Local cache directory        /tmp/ramjet_cache
RAMJET_CACHE_SIZE    Maximum cache size           100GB
RAMJET_PORT          Cache server port            9000
RAMJET_NODE_NAME     Node identifier              hostname
RAMJET_BACKEND_HOST  Backend host                 api.ramjet.io
RAMJET_BACKEND_PORT  Backend port                 443
RAMJET_LOG_LEVEL     Logging level                INFO
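Since init() reads its configuration from the environment, these variables can also be set from Python before calling it, which is convenient in notebooks. The values below are placeholders.

```python
import os

# Placeholder values; set before ramjetio.init() reads them.
# setdefault leaves any value already exported in the shell untouched.
os.environ.setdefault("RAMJET_API_KEY", "ramjet_abc123")
os.environ.setdefault("RAMJET_CACHE_SIZE", "250GB")
os.environ.setdefault("RAMJET_LOG_LEVEL", "DEBUG")

# import ramjetio
# ramjetio.init()  # picks up the variables above
```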

CLI Commands

ramjet-server

Start cache server daemon.

ramjet-server [OPTIONS]

Options:
  --host TEXT        Host to bind to [default: 0.0.0.0]
  --port INTEGER     Port to bind to [default: 9000]
  --storage-path     Path for cache storage [default: /tmp/ramjet_cache]
  --capacity TEXT    Maximum cache size [default: 100GB]
  --log-level TEXT   Logging level [default: INFO]

ramjet-client

CLI client for cache operations.

ramjet-client [OPTIONS] COMMAND [ARGS]

Commands:
  stats    Show cache statistics
  get      Get value by key
  set      Set key-value pair
  delete   Delete key
  clear    Clear all cache data
  nodes    List cluster nodes

Options:
  --nodes TEXT  Comma-separated node addresses