A high-performance, distributed dataset loading library for PyTorch with advanced shuffling algorithms and seamless scaling from single-machine to multi-machine distributed training.
CorgiPile Dataset API is designed for large-scale machine learning scenarios where traditional PyTorch DataLoader falls short. Our solution addresses critical challenges in distributed training environments.
- Enterprise ML Training: Large datasets (TB-scale) that don't fit in memory
- Multi-Machine Distributed Training: Seamless scaling from 1 to 100+ machines
- Cloud & Cluster Environments: HDFS, distributed storage systems
- High-Performance Computing: Maximize GPU utilization with efficient data loading
Tuple id denotes the tuple position after shuffling; #tuple refers to the number of negative/positive tuples in every 20 shuffled tuples.
Left: Advanced Dual-Layer Shuffle | Center: Single-Machine Parallelism | Right: Multi-Machine Distribution
- Block+Tuple dual-layer algorithm provides optimal balance of randomness and efficiency
- Significantly better than sequential loading, comparable to full randomization
- Memory-efficient: No need to load entire dataset for shuffling
- Single-machine multi-threading: Automatic load balancing across workers
- Multi-machine distributed: File-level partitioning with zero data overlap
- Universal storage: Local filesystem, HDFS, extensible to cloud storage
- Complete traceability: Every sample tracked with `(file_id, inner_index)`
- Robust error handling: Graceful failure recovery in distributed environments
- Performance monitoring: Built-in logging and debugging capabilities
View All Shuffle Methods Comparison
- Advanced Shuffle Algorithm: Block+Tuple dual-layer provides optimal randomness vs. efficiency trade-off
- Production Debugging: Full sample traceability with `(file_id, inner_index)` tracking
- Storage Agnostic: Local files, HDFS, easily extensible to S3/GCS
CorgiPile has garnered positive feedback and adoption in various communities and systems beyond our initial implementation. For instance:
- CorgiPile-PostgreSQL: Integrated into PostgreSQL for efficient data shuffling in database-driven ML pipelines, improving query and training performance on large stored datasets (https://github.com/DS3Lab/CorgiPile-PostgreSQL).
- CorgiPile-openGauss (GaussML): Adopted in the openGauss database system, enhancing shuffled data access for distributed ML workloads with reduced I/O latency (https://ieeexplore.ieee.org/document/10597842).
- Mobileye's Corg²: An improved variant used by Mobileye, which applies CorgiPile twice—once offline for initial data preparation and once online during training—to further optimize for real-time autonomous driving data processing (https://arxiv.org/pdf/2309.01640).
- LLM Training Systems: Enhanced versions of CorgiPile have been employed in MLSys-inspired frameworks for LLM pretraining, where handling terabyte-scale corpora benefits from the method's efficiency, as evidenced by community discussions and adaptations in open-source LLM tools (https://openreview.net/forum?id=I2LF8QHaua).
```bash
git clone https://github.com/yourusername/corgipile-dataset-api.git
cd corgipile-dataset-api

# Install dependencies
pip install -r requirements.txt
```

The key to using CorgiPile is implementing your custom `load_data_fn`. This function defines how to read your specific data format:
```python
def my_data_loader(file_path: str, **kwargs):
    """
    Custom data loading function - adapt this to your data format.

    Args:
        file_path: Path to the data file
        **kwargs: Additional info (file_id, etc.)

    Yields:
        tuple: (data, label, trace_info)
            - data: Your actual data (tensor, text, etc.)
            - label: Ground truth label
            - trace_info: (file_id, inner_index) for debugging
    """
    file_id = kwargs.get('file_id', 0)

    # Example 1: Text data (TSV format)
    with open(file_path, 'r') as f:
        for inner_idx, line in enumerate(f):
            text, label = line.strip().split('\t')
            yield (text, int(label), (file_id, inner_idx))

    # Example 2: JSON data
    # import json
    # with open(file_path, 'r') as f:
    #     for inner_idx, line in enumerate(f):
    #         data = json.loads(line)
    #         yield (data['features'], data['label'], (file_id, inner_idx))

    # Example 3: Binary data
    # import pickle
    # with open(file_path, 'rb') as f:
    #     data = pickle.load(f)
    # for inner_idx, (features, label) in enumerate(data):
    #     yield (features, label, (file_id, inner_idx))
```

Perfect for single GPU/multi-core training with large datasets:
```python
import torch
from torch.utils.data import DataLoader
from corgipile_dataset_api import CorgiPileLocalDataset

# Create dataset with your custom loader
dataset = CorgiPileLocalDataset(
    data_dir="/path/to/your/data",   # Directory with your data files
    block_size=100,                  # Samples per block (tune for memory)
    load_data_fn=my_data_loader,     # Your custom loading function
    shuffle=True,                    # Enable dual-layer shuffle
    log_dir="./logs"                 # Track data loading (optional)
)

# Standard PyTorch DataLoader - works seamlessly!
dataloader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=4,    # Multi-threading for performance
    pin_memory=True   # GPU optimization
)

# Train as usual
for batch_idx, (data, labels, trace_info) in enumerate(dataloader):
    # trace_info contains (file_id, inner_index) for each sample
    # Your training code here
    outputs = model(data)
    loss = criterion(outputs, labels)
    # ...
```

Scale to multiple machines with automatic data partitioning:
```python
from corgipile_dataset_api import CorgiPileDistributedLocalDataset

# Each machine gets different data files automatically
dataset = CorgiPileDistributedLocalDataset(
    data_dir="/shared/training/data",   # Shared storage (NFS, etc.)
    block_size=100,
    load_data_fn=my_data_loader,
    rank=rank,               # 0, 1, 2, ... (current machine)
    world_size=world_size    # Total number of machines
)

# Same DataLoader code - CorgiPile handles the distribution!
dataloader = DataLoader(dataset, batch_size=32, num_workers=4)
```

Perfect for Hadoop clusters and cloud environments:
```python
from corgipile_dataset_api import CorgiPileHDFSDataset, CorgiPileDistributedHDFSDataset

# Single-machine HDFS
dataset = CorgiPileHDFSDataset(
    hdfs_root="/user/data/training",
    hdfs_host="namenode-host",
    hdfs_port=9000,
    hdfs_user="hadoop-user",
    block_size=100,
    load_data_fn=my_data_loader,   # Same function works!
    shuffle=True
)

# Multi-machine HDFS (for large clusters)
distributed_dataset = CorgiPileDistributedHDFSDataset(
    hdfs_root="/user/data/training",
    hdfs_host="namenode-host",
    hdfs_port=9000,
    hdfs_user="hadoop-user",
    block_size=100,
    load_data_fn=my_data_loader,
    rank=rank,               # Machine rank in cluster
    world_size=world_size,   # Total machines
    shuffle=True
)
```

`CorgiPileLocalDataset` Parameters:
- `data_dir` (str): Root directory containing data files
- `block_size` (int): Number of samples per block
- `load_data_fn` (Callable): Function to load data from a file path
- `shuffle` (bool): Enable dual-layer shuffle. Default: True
- `log_dir` (Optional[str]): Directory for logging. If None, no logging
- `file_filter_fn` (Optional[Callable]): Function to filter valid files
`CorgiPileDistributedLocalDataset` Additional Parameters:
- `rank` (int): Current machine rank. Default: 0
- `world_size` (int): Total number of machines. Default: 1
`CorgiPileHDFSDataset` Additional Parameters:
- `hdfs_root` (str): HDFS root directory path
- `hdfs_host` (str): HDFS namenode hostname
- `hdfs_port` (int): HDFS namenode port
- `hdfs_user` (str): HDFS username
`CorgiPileDistributedHDFSDataset` Additional Parameters:
- `hdfs_root` (str): HDFS root directory path
- `hdfs_host` (str): HDFS namenode hostname
- `hdfs_port` (int): HDFS namenode port
- `hdfs_user` (str): HDFS username
- `rank` (int): Current machine rank. Default: 0
- `world_size` (int): Total number of machines. Default: 1
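The `file_filter_fn` option is the only parameter not exercised in the examples above. The sketch below shows one way it could be used; the predicate signature (a file path in, a bool out) is an assumption based on the description "function to filter valid files", not a confirmed part of the API.

```python
from corgipile_dataset_api import CorgiPileLocalDataset

# Assumed signature: file_filter_fn receives a file path and returns True to keep it.
def only_tsv(file_path: str) -> bool:
    return file_path.endswith(".tsv")

dataset = CorgiPileLocalDataset(
    data_dir="/path/to/your/data",
    block_size=100,
    load_data_fn=my_data_loader,   # loader defined in the custom-loader section
    shuffle=True,
    file_filter_fn=only_tsv        # skip non-TSV files found under data_dir
)
```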
- Memory-efficient streaming of large datasets
- Configurable block sizes for optimal performance
- Better CPU cache utilization through locality
- Inter-block shuffle: Randomize the order of data blocks
- Intra-block shuffle: Shuffle samples within each block
- Result: Superior training randomness with controlled memory usage
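To make the block-based processing and the dual-layer shuffle concrete, here is a minimal, self-contained sketch of the idea. It illustrates the algorithm described above; it is not the library's internal implementation, which streams blocks from storage rather than materializing all samples in memory.

```python
import random
from typing import Iterable, Iterator

def dual_layer_shuffle(samples: Iterable, block_size: int, seed: int = 0) -> Iterator:
    """Illustrative only: group samples into fixed-size blocks, shuffle the
    block order (inter-block), then shuffle samples within each block (intra-block)."""
    rng = random.Random(seed)
    samples = list(samples)  # the real dataset streams blocks instead of materializing everything
    blocks = [samples[i:i + block_size] for i in range(0, len(samples), block_size)]
    rng.shuffle(blocks)          # inter-block shuffle: randomize block order
    for block in blocks:
        rng.shuffle(block)       # intra-block shuffle: randomize samples inside the block
        yield from block

# Example: 12 samples with block_size=4 -> I/O stays block-local,
# but the emitted order is well mixed globally.
print(list(dual_layer_shuffle(range(12), block_size=4)))
```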
Every sample includes source information (file_id, inner_index) enabling:
- Debugging data loading issues
- Reproducible training
- Data lineage tracking
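As one hedged illustration of using this trace information, the sketch below flags samples whose label is out of range and reports exactly where they came from. It assumes the `dataloader` built in the quick-start example, PyTorch's default collation (under which the per-sample `(file_id, inner_index)` tuples arrive as a pair of batched tensors), and an illustrative `NUM_CLASSES` constant.

```python
NUM_CLASSES = 10  # assumption for illustration only

# Scan batches and report the origin of any sample with a suspicious label.
for data, labels, trace_info in dataloader:
    file_ids, inner_indices = trace_info            # batched (file_id, inner_index)
    bad = (labels < 0) | (labels >= NUM_CLASSES)    # placeholder sanity check
    for file_id, inner_index in zip(file_ids[bad].tolist(), inner_indices[bad].tolist()):
        print(f"Suspicious sample: file_id={file_id}, inner_index={inner_index}")
```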
- Choose appropriate block_size: Balance memory usage and shuffle quality
- Smaller blocks: Better shuffle, more memory overhead
- Larger blocks: Better performance, less randomness
- Worker Configuration: Set `num_workers > 0` in DataLoader for parallel processing
- HDFS Optimization: Use `multiprocessing_context='spawn'` for HDFS datasets
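Putting these tips together, here is one possible configuration sketch. The specific numbers are illustrative rather than recommendations, and it reuses the `my_data_loader` function and the HDFS settings from the sections above.

```python
from torch.utils.data import DataLoader
from corgipile_dataset_api import CorgiPileHDFSDataset

dataset = CorgiPileHDFSDataset(
    hdfs_root="/user/data/training",
    hdfs_host="namenode-host",
    hdfs_port=9000,
    hdfs_user="hadoop-user",
    block_size=200,                # larger blocks: better throughput, less randomness
    load_data_fn=my_data_loader,
    shuffle=True
)

dataloader = DataLoader(
    dataset,
    batch_size=64,
    num_workers=4,                       # parallel data loading
    pin_memory=True,
    multiprocessing_context="spawn"      # recommended above for HDFS datasets
)
```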
| Example | Description | Use Case | Code Link |
|---|---|---|---|
| CIFAR-10 Demo | Complete training pipeline with CIFAR-10 dataset | Learning CorgiPile basics | examples/cifar_example.py |
| Distributed Training | Multi-machine setup with automatic data partitioning | Production distributed training | examples/distributed_example.py |
```bash
# Single-machine local training
python examples/cifar_example.py --mode local

# Single-machine with HDFS
python examples/cifar_example.py --mode hdfs --hdfs_host your-namenode
```

```bash
# On Machine 0 (rank=0)
python examples/distributed_example.py --mode local --rank 0 --world_size 4

# On Machine 1 (rank=1)
python examples/distributed_example.py --mode local --rank 1 --world_size 4

# On Machine 2 (rank=2)
python examples/distributed_example.py --mode local --rank 2 --world_size 4

# On Machine 3 (rank=3)
python examples/distributed_example.py --mode local --rank 3 --world_size 4
```

```bash
# Each machine in your cluster
python examples/distributed_example.py --mode hdfs --rank $RANK --world_size $WORLD_SIZE \
    --hdfs_host namenode-host --hdfs_port 9000 --hdfs_user hadoop-user
```

Each example includes detailed comments explaining the CorgiPile-specific parts!
```bibtex
@article{DBLP:journals/vldb/XuQYJRGKLLWYZ24,
author = {Lijie Xu and
Shuang Qiu and
Binhang Yuan and
Jiawei Jiang and
C{\'{e}}dric Renggli and
Shaoduo Gan and
Kaan Kara and
Guoliang Li and
Ji Liu and
Wentao Wu and
Jieping Ye and
Ce Zhang},
title = {Stochastic gradient descent without full data shuffle: with applications
to in-database machine learning and deep learning systems},
journal = {{VLDB} J.},
volume = {33},
number = {5},
pages = {1231--1255},
year = {2024},
url = {https://doi.org/10.1007/s00778-024-00845-0},
doi = {10.1007/S00778-024-00845-0},
timestamp = {Sat, 06 Sep 2025 20:29:54 +0200},
biburl = {https://dblp.org/rec/journals/vldb/XuQYJRGKLLWYZ24.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}
```