Mosaic is a comprehensive memory profiling and analysis tool designed for PyTorch workloads. It offers a suite of utilities for analyzing PyTorch memory snapshots:
- Visualize memory usage patterns
- Identify memory peaks and bottlenecks
With Mosaic, users can better understand how their PyTorch models use memory, making it easier to identify problems and improve efficiency during both training and inference. Key features include:
- Memory Peak Analysis: Identify when and where peak memory usage occurs during model execution
- Memory Usage Tracking: Inspect memory usage at a specific memory allocation or free event
- Call Stack Analysis: Analyze memory usage by call stack to identify memory-intensive operations
- Custom Profiling: Create custom categorization rules to profile memory by specific code patterns
- Memory Comparison: Compare memory usage between different snapshots or code versions
- Annotation-based Analysis: Track memory usage across custom annotations and training stages
- Visualization: Generate interactive HTML visualizations of memory usage over time
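Mosaic operates on memory snapshots captured with PyTorch's built-in allocator history recorder. As a point of reference (these are PyTorch APIs, not part of Mosaic), a minimal sketch of capturing a snapshot looks like:

```python
import torch

# Start recording allocator events with PyTorch's built-in snapshot recorder
# (available in recent PyTorch releases; not part of Mosaic itself).
torch.cuda.memory._record_memory_history(max_entries=100_000)

# ... run the training or inference workload to be profiled ...

# Dump the recorded history to a pickle file that Mosaic can analyze,
# then stop recording.
torch.cuda.memory._dump_snapshot("snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)
```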
From the mosaic directory:

```bash
pip install -e .
```

Mosaic provides several command-line utilities:

```bash
mosaic-get-memory --snapshot <path_to_snapshot.pickle> --out-path <output.html> --profile <profile_type>
```

Profile types:

- `annotations`: Profile by external annotations
- `categories`: Profile by PyTorch memory categories
- `compile_context`: Profile by torch.compile context
- `custom`: Profile using custom regex patterns
Additional utilities:

```bash
mosaic-get-json-snapshot --snapshot <path_to_snapshot.pickle> --output-file <output.json>
mosaic_get_memory_usage --snapshot <path_to_snapshot.pickle> --allocation <address> --action <alloc|free>
mosaic_get_memory_usage_peak --snapshot <path_to_snapshot.pickle>
mosaic_get_memory_usage_diff --snapshot-base <base.pickle> --snapshot-diff <diff.pickle>
mosaic_usage_by_annotation_stage --snapshot <path_to_snapshot.pickle> --annotation <annotation_name>
```

The same analyses are available programmatically through the Python API. For example, to find peak memory usage:

```python
from mosaic.libmosaic.analyzer.memory_abstract import MemoryAbstract

# Load and analyze a memory snapshot
memory_abstract = MemoryAbstract(memory_snapshot_file="snapshot.pickle")
memory_abstract.load_memory_snapshot()

# Analyze peak memory usage
memory_abstract.memory_snapshot.analyze_memory_snapshot(opt="memory_peak")
peak_memory = memory_abstract.memory_snapshot.memory_peak
print(f"Peak memory: {peak_memory / 1024**3:.2f} GiB")
```
To profile memory with custom categories, define rules that map category names to regular-expression patterns for the code paths of interest:

```python
import json

from mosaic.cmd.entry_point import get_memory_profile

# Define custom profiling rules
custom_rules = {
    "Model Forward": ".*forward.*",
    "Optimizer": ".*optimizer.*",
    "Data Loading": ".*DataLoader.*"
}

# Generate profile with custom categories
get_memory_profile(
    snapshot="snapshot.pickle",
    out_path="profile.html",
    profile="custom",
    custom_profile=json.dumps(custom_rules)
)
```
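As an illustration of how such regex rules classify code paths (this is not Mosaic's internal matching logic, which may differ), a hypothetical categorizer might look like:

```python
import re

# Hypothetical illustration: assign a category to a call-stack or code-path string
# using the first rule whose regex matches; otherwise fall back to "Uncategorized".
rules = {
    "Model Forward": ".*forward.*",
    "Optimizer": ".*optimizer.*",
    "Data Loading": ".*DataLoader.*",
}

def categorize(frame: str) -> str:
    for category, pattern in rules.items():
        if re.search(pattern, frame):
            return category
    return "Uncategorized"

print(categorize("model.forward(x)"))             # -> Model Forward
print(categorize("torch/utils/data/DataLoader"))  # -> Data Loading
```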
Memory usage can also be broken down by annotation stage, for example across training phases:

```python
from mosaic.cmd.entry_point import get_memory_usage_by_annotation_stage

# Get memory usage at each training stage
memory_by_stage = get_memory_usage_by_annotation_stage(
    snapshot="snapshot.pickle",
    annotation=("forward", "backward", "optimizer"),
    paste=False
)

for stage, (annotation, memory_bytes) in memory_by_stage.items():
    print(f"{stage}: {memory_bytes / 1024**3:.2f} GiB")
```

Core modules:

- `MemoryAbstract`: High-level interface for memory analysis
- `MemorySnapshot`: Core snapshot analysis and processing
- `gpu_trace`: GPU trace event analysis
- `snapshot_loader`: Load PyTorch memory snapshots
- `snapshot_utils`: Utilities for snapshot manipulation
- `plotting`: Generate interactive visualizations
- `data_utils`: Data structures for memory events
- `cmd.entry_point`: Command-line entry points for the various memory analysis tasks

Typical use cases:
- Identify Memory Bottlenecks: Find which operations consume the most memory during training
- Optimize Model Memory: Analyze memory usage patterns to reduce peak memory consumption
- Debug OOM Errors: Understand what causes out-of-memory errors in your models (see the capture sketch after this list)
- Compare Memory Impact: Compare memory usage before and after code changes
- Profile Large-Scale Training: Analyze memory patterns in distributed training workloads
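For the OOM-debugging use case, a minimal sketch (using PyTorch's own snapshot APIs, not Mosaic's; the deliberate over-allocation is just a stand-in for a real failing workload) of dumping a snapshot at the moment of failure:

```python
import torch

# Record allocator history so the failing state can be inspected with Mosaic.
torch.cuda.memory._record_memory_history(max_entries=100_000)

try:
    # Stand-in for a real workload: deliberately over-allocate to force an OOM.
    huge = torch.empty(1 << 40, device="cuda")
except torch.cuda.OutOfMemoryError:
    # Capture the allocator state at the point of failure; the resulting file can
    # then be analyzed with, e.g., mosaic_get_memory_usage_peak or mosaic-get-memory.
    torch.cuda.memory._dump_snapshot("oom_snapshot.pickle")
```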
Requirements:

- Valid memory snapshot(s) generated from PyTorch
- Python 3.8+
- Additional dependencies specified in setup.py
See the CONTRIBUTING file for how to help out.
Mosaic has a BSD-style license, as found in the LICENSE file.