diff --git a/src/alignment/data.py b/src/alignment/data.py
index 8d3bb314..c60be140 100644
--- a/src/alignment/data.py
+++ b/src/alignment/data.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-
+import os
 import datasets
 from datasets import DatasetDict, concatenate_datasets
 
@@ -23,30 +23,45 @@
 
 logger = logging.getLogger(__name__)
 
+
 def get_dataset(args: ScriptArguments) -> DatasetDict:
     """Load a dataset or a mixture of datasets based on the configuration.
-    
+
     Args:
         args (ScriptArguments): Script arguments containing dataset configuration.
-    
+
     Returns:
         DatasetDict: The loaded datasets.
     """
     if args.dataset_name and not args.dataset_mixture:
         logger.info(f"Loading dataset: {args.dataset_name}")
-        return datasets.load_dataset(args.dataset_name, args.dataset_config)
+        # Check if it's a local path
+        if os.path.exists(args.dataset_name):
+            logger.info(f"Loading local dataset from disk: {args.dataset_name}")
+            return datasets.load_from_disk(args.dataset_name)
+        else:
+            return datasets.load_dataset(args.dataset_name, args.dataset_config)
     elif args.dataset_mixture:
         logger.info(f"Creating dataset mixture with {len(args.dataset_mixture.datasets)} datasets")
         seed = args.dataset_mixture.seed
         datasets_list = []
-
         for dataset_config in args.dataset_mixture.datasets:
             logger.info(f"Loading dataset for mixture: {dataset_config.id} (config: {dataset_config.config})")
-            ds = datasets.load_dataset(
-                dataset_config.id,
-                dataset_config.config,
-                split=dataset_config.split,
-            )
+
+            # Check if it's a local path
+            if os.path.exists(dataset_config.id):
+                logger.info(f"Loading local dataset from disk: {dataset_config.id}")
+                ds = datasets.load_from_disk(dataset_config.id)
+                # Handle split if specified
+                if dataset_config.split and isinstance(ds, DatasetDict):
+                    ds = ds[dataset_config.split]
+            else:
+                ds = datasets.load_dataset(
+                    dataset_config.id,
+                    dataset_config.config,
+                    split=dataset_config.split,
+                )
+
             if dataset_config.columns is not None:
                 ds = ds.select_columns(dataset_config.columns)
             if dataset_config.weight is not None:
@@ -54,14 +69,11 @@ def get_dataset(args: ScriptArguments) -> DatasetDict:
                 logger.info(
                     f"Subsampled dataset '{dataset_config.id}' (config: {dataset_config.config}) with weight={dataset_config.weight} to {len(ds)} examples"
                 )
-
             datasets_list.append(ds)
-
         if datasets_list:
             combined_dataset = concatenate_datasets(datasets_list)
             combined_dataset = combined_dataset.shuffle(seed=seed)
             logger.info(f"Created dataset mixture with {len(combined_dataset)} examples")
-
             if args.dataset_mixture.test_split_size is not None:
                 combined_dataset = combined_dataset.train_test_split(
                     test_size=args.dataset_mixture.test_split_size, seed=seed
@@ -74,6 +86,5 @@ def get_dataset(args: ScriptArguments) -> DatasetDict:
                 return DatasetDict({"train": combined_dataset})
             else:
                 raise ValueError("No datasets were loaded from the mixture configuration")
-
     else:
         raise ValueError("Either `dataset_name` or `dataset_mixture` must be provided")