Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 79 additions & 10 deletions src/megatron/energon/epathlib/epath.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import os
import re
import shutil
from pathlib import Path as PathlibPath
from pathlib import PurePosixPath
from typing import BinaryIO, Generator, Literal, Optional, TextIO, Tuple, Union, overload
from typing import (
BinaryIO,
Generator,
Literal,
Optional,
TextIO,
Tuple,
Union,
overload,
)

import multistorageclient as msc

Expand All @@ -31,7 +41,9 @@ class EPath:
for more information.
"""

# The path without the protocol. Can also be in S3 for example
__slots__ = ("internal_path", "profile", "fs")

# The path without the protocol/profile. Can also be in S3 for example
internal_path: PurePosixPath
# The profile used to access the file system
profile: str
Expand All @@ -52,11 +64,29 @@ def __init__(
profile = DEFAULT_PROFILE_NAME
else:
protocol, profile, path = self._split_protocol(initial_path)
if protocol is None or protocol == "file":
if protocol is None:
# Just a local absolute/relative path
assert profile is None
profile = DEFAULT_PROFILE_NAME
path = str(PathlibPath(path).absolute())
elif protocol == "file":
# A file:// path, e.g. file:///home/user/file.txt (absolute) or file://file.txt (relative)
assert profile is not None
path = profile + "/" + path
profile = DEFAULT_PROFILE_NAME
path = str(PathlibPath(path).absolute())
elif protocol == "rclone":
warn_deprecated("rclone:// protocol is deprecated. Use msc:// instead.")
elif protocol == "dss":
# Profile corresponds to the dataset name and version
assert profile is not None
assert NVDATASET_CACHE_DIR is not None, (
"Environment variable NVDATASET_CACHE_DIR is not set"
)
self.fs = NVDATASET_CACHE_DIR.fs
self.profile = "dss"
self.internal_path = self._resolve(f"/{profile}/{path}")
return
else:
assert protocol == "msc", f"Unknown protocol: {protocol}"
if not path.startswith("/"):
Expand All @@ -77,7 +107,13 @@ def __getstate__(self) -> dict:
def __setstate__(self, state: dict) -> None:
self.internal_path = state["internal_path"]
self.profile = state["profile"]
self.fs, _ = msc.resolve_storage_client(f"msc://{self.profile}")
if self.profile == "dss":
assert NVDATASET_CACHE_DIR is not None, (
"Environment variable NVDATASET_CACHE_DIR is not set"
)
self.fs = NVDATASET_CACHE_DIR.fs
else:
self.fs, _ = msc.resolve_storage_client(f"msc://{self.profile}")

@staticmethod
def _resolve(path: Union[str, PurePosixPath]) -> PurePosixPath:
Expand All @@ -103,16 +139,28 @@ def _resolve(path: Union[str, PurePosixPath]) -> PurePosixPath:

@staticmethod
def _split_protocol(path: str) -> Tuple[Optional[str], Optional[str], str]:
regex = re.compile(r"^(?P<protocol>[a-z]+)://(?P<profile>[^/]+?)/(?P<path>.+)$")
regex = re.compile(r"^(?P<protocol>[a-z]+)://(?P<profile>[^/]+?)(?:/(?P<path>.*))?$")
m = regex.match(path)
if m is None:
return None, None, path
return m.group("protocol"), m.group("profile"), m.group("path")
inner_path = m.group("path")
if not inner_path:
inner_path = ""
return m.group("protocol"), m.group("profile"), inner_path

@property
def _internal_str_path(self) -> str:
"""Return the path as used inside the file system, without the protocol and fs part."""
return str(self.internal_path)
"""Return the path as used inside the file system, without the protocol and fs part.
This is for usage with `self.fs` functions."""
if self.profile == "dss":
assert NVDATASET_CACHE_DIR is not None, (
"Environment variable NVDATASET_CACHE_DIR is not set"
)
# Applying a trick for efficiency:
# The internal path is relative to the NVDATASET_CACHE_DIR (i.e. strip the leading /, then concat with /)
return NVDATASET_CACHE_DIR._internal_str_path + str(self.internal_path)
else:
return str(self.internal_path)

@overload
def open(
Expand Down Expand Up @@ -189,14 +237,21 @@ def parent(self) -> "EPath":

@property
def url(self) -> str:
if self.is_local():
if self.is_pure_local():
return self._internal_str_path
int_path_str = str(self.internal_path)
if self.profile == "dss":
if int_path_str.startswith("/"):
int_path_str = int_path_str[1:]
return f"dss://{int_path_str}"
return f"msc://{self.profile}{int_path_str}"

def is_local(self) -> bool:
def is_pure_local(self) -> bool:
return self.profile == DEFAULT_PROFILE_NAME

def is_local(self) -> bool:
return self.profile == DEFAULT_PROFILE_NAME or self.profile == "dss"

def is_dir(self) -> bool:
try:
return self.fs.info(self._internal_str_path).type == "directory"
Expand Down Expand Up @@ -240,6 +295,14 @@ def relative_to(self, other: "EPath") -> str:

return str(self.internal_path.relative_to(other.internal_path))

@property
def display_name(self) -> str:
    """Return a short, human-readable name for this path.

    For paths with the ``dss`` profile this is the first component of the
    internal path below the root, so the dataset root and every path
    underneath it report the same name. For all other profiles (and for a
    ``dss`` path that has no component below the root) it is ``self.name``.

    Returns:
        The display name as a string.
    """
    if self.profile == "dss":
        components = self.internal_path.parts
        # Drop the root entry ("/") so that index 0 is the first real component.
        if self.internal_path.anchor:
            components = components[1:]
        # Indexing `parts` instead of `parents[-2]` avoids an IndexError when the
        # path *is* the first component (only one parent, the root), and does not
        # depend on negative indexing of `parents` (Python 3.10+ only).
        if components:
            return components[0]
    # Use the name for other paths
    return self.name

def __truediv__(self, other: Union[str, "EPath"]) -> "EPath":
if isinstance(other, EPath):
# Always absolute
Expand Down Expand Up @@ -271,3 +334,9 @@ def __eq__(self, other: object) -> bool:
and self.internal_path == other.internal_path
and self.profile == other.profile
)


# Root of the local dataset cache that `dss://` paths are resolved against.
# Read once at import time; stays None when the environment variable is absent.
NVDATASET_CACHE_DIR = (
    EPath(os.environ["NVDATASET_CACHE_DIR"]) if "NVDATASET_CACHE_DIR" in os.environ else None
)
1 change: 1 addition & 0 deletions src/megatron/energon/flavors/jsonl/crude_jsonl_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def __init__(
assert self.__sample_type__ is not None, f"Class {type(self)} must define __sample_type__"
self.path = path
self.paths = [path]
self.name = path.display_name
self.training = training
self.worker_config = worker_config
self.shuffle_over_epochs = shuffle_over_epochs
Expand Down
1 change: 1 addition & 0 deletions src/megatron/energon/flavors/webdataset/base_webdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def __init__(
)
self.path = path
self.paths = [path]
self.name = path.display_name
self.shards = wds_meta.shards
self.sample_excludes = wds_meta.sample_excludes
self.split_part_files = wds_meta.split_part_files
Expand Down
6 changes: 5 additions & 1 deletion src/megatron/energon/metadataset/metadataset_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,11 @@ def post_initialize(self, mds_path: Optional[EPath] = None) -> None:
if m.group("protocol") == "filesystem":
new_aux[k] = AuxFilesystemReference(fs_path=m.group("path"))
else:
raise ValueError(f"Unsupported protocol: {m.group('protocol')}")
epath = EPath(v)
if epath.is_local():
new_aux[k] = AuxFilesystemReference(fs_path=epath)
else:
raise ValueError(f"Unsupported remote aux dataset: {v}")
else:
new_aux[k] = AuxDatasetReference(path=v)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_dataset_det.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def test_split_parts(self):
"000004", # Shard 2
"000005",
"000006",
]
], f"All keys: {all_keys}"

def test_text_dataset(self):
worker_config = WorkerConfig(rank=0, world_size=1, num_workers=0)
Expand Down
80 changes: 80 additions & 0 deletions tests/test_epathlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
import pickle
import struct
import sys
import tempfile
import unittest
from pathlib import Path

from megatron.energon.epathlib import DEFAULT_PROFILE_NAME, EPath
from tests.epath_s3_emulator import setup_s3_emulator
Expand Down Expand Up @@ -279,6 +281,84 @@ def test_msc_s3(self):
# assert not EPath("msc://s3test_msc/test").is_dir()
assert not EPath("msc://s3test_msc/test/dir").is_dir()

def test_metadataset_v2_dss_path_parsing_str(self):
    """Load a MetadatasetV2 YAML that references ``dss://`` datasets and check the URLs.

    A temporary dataset cache is created on disk and ``NVDATASET_CACHE_DIR`` is pointed
    at it, both as environment variable and as module global of the epath module.
    Both values are restored when the test finishes.
    """

    # The epath module reads NVDATASET_CACHE_DIR at import time, so the module global is
    # patched below as well; this lets the test run without the variable being set at
    # process start.
    import megatron.energon.epathlib.epath as epath_mod
    from megatron.energon.metadataset.loader import load_dataset

    # NOTE(review): the YAML nesting depends on the leading spaces inside these
    # literals - verify them against the repository before relying on this text.
    yaml_lines = (
        "__module__: megatron.energon",
        "__class__: MetadatasetV2",
        "splits:",
        " train:",
        " blend:",
        " - path: dss://charts1234_zh@v0",
        " weight: 1",
        " aux:",
        " media_source: dss://charts1234@v0",
    )
    yaml_text = "\n".join(yaml_lines)

    # Remember the current state so it can be restored in the `finally` clause.
    saved_env_value = os.environ.get("NVDATASET_CACHE_DIR")
    saved_module_value = epath_mod.NVDATASET_CACHE_DIR

    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_root = Path(tmp_dir)
        cache_dir = tmp_root / "nvds_cache"
        cache_dir.mkdir(parents=True, exist_ok=True)

        # Populate the cache with two dummy datasets so that `load_dataset()` finds
        # the referenced paths:
        # - charts1234_zh@v0 gets a `.nv-meta/.info.json` file containing "{}"
        # - charts1234@v0 gets an `images` folder with two tiny files (aux media source)
        meta_dir = cache_dir / "charts1234_zh@v0" / ".nv-meta"
        meta_dir.mkdir(parents=True, exist_ok=True)
        (meta_dir / ".info.json").write_text("{}", encoding="utf-8")

        images_dir = cache_dir / "charts1234@v0" / "images"
        images_dir.mkdir(parents=True, exist_ok=True)
        for image_name in ("000.jpg", "001.jpg"):
            (images_dir / image_name).write_bytes(b"\xff\xd8\xff\xd9")

        mds_yaml_path = tmp_root / "metadataset_v2_dss.yaml"
        mds_yaml_path.write_text(yaml_text, encoding="utf-8")

        try:
            os.environ["NVDATASET_CACHE_DIR"] = str(cache_dir)
            epath_mod.NVDATASET_CACHE_DIR = EPath(cache_dir)

            mds = load_dataset(EPath(mds_yaml_path))

            from megatron.energon.metadataset.metadataset_v2 import AuxFilesystemReference

            first_entry = mds.splits["train"].blend[0]

            # The blended dataset path must have been parsed into an EPath.
            ds0 = first_entry.path
            assert isinstance(ds0, EPath)

            # The aux entry must have become a filesystem reference holding an EPath.
            aux_entries = first_entry.aux
            assert aux_entries is not None
            aux_ref0 = aux_entries["media_source"]
            assert isinstance(aux_ref0, AuxFilesystemReference)
            aux0 = aux_ref0.fs_path
            assert isinstance(aux0, EPath)

            for p in (ds0, aux0):
                print(f"Dataset: {str(p)}, url: {p.url}")

            # The URLs must be identical to what was written in the YAML.
            assert ds0.url == "dss://charts1234_zh@v0"
            assert aux0.url == "dss://charts1234@v0"
        finally:
            # Undo the environment and module patching, whatever happened above.
            if saved_env_value is not None:
                os.environ["NVDATASET_CACHE_DIR"] = saved_env_value
            else:
                os.environ.pop("NVDATASET_CACHE_DIR", None)
            epath_mod.NVDATASET_CACHE_DIR = saved_module_value


def _multiproc_test_func(p: EPath, test_function: bool):
"""Helper function for multiprocessing test"""
Expand Down
5 changes: 4 additions & 1 deletion tests/test_file_cache_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,10 @@ def worker(filename):
lazy_ref = pool.get_lazy(mock_raw_file_store, filename)
sample_for_source_info = {"__sources__": ()}
result = lazy_ref.get(sample_for_source_info)
assert sample_for_source_info["__sources__"][0].dataset_path == mock_raw_file_store.get_path()
assert (
sample_for_source_info["__sources__"][0].dataset_path
== mock_raw_file_store.get_path()
)
assert sample_for_source_info["__sources__"][0].index is None
assert sample_for_source_info["__sources__"][0].shard_name is None
assert sample_for_source_info["__sources__"][0].file_names == (filename,)
Expand Down
Loading