diff --git a/src/data/__init__.py b/src/data/__init__.py
index c24b0b03..aa39bcbb 100644
--- a/src/data/__init__.py
+++ b/src/data/__init__.py
@@ -46,7 +46,7 @@ def get_datasets(dataset_cfgs: Union[Dict, DictConfig], **kwargs):
     return dataset
 
 
-def get_data(data_cfg: DictConfig, mode="train", **kwargs):
+def get_data(data_cfg: DictConfig, mode="train", seed=0, **kwargs):
     data = {}
     data_cfg = dict(data_cfg)
     anchor = data_cfg.pop("anchor", "forget")
@@ -56,7 +56,9 @@ def get_data(data_cfg: DictConfig, mode="train", **kwargs):
         return data
     elif mode == "unlearn":
         unlearn_splits = {k: v for k, v in data.items() if k not in ("eval", "test")}
-        unlearn_dataset = ForgetRetainDataset(**unlearn_splits, anchor=anchor)
+        unlearn_dataset = ForgetRetainDataset(
+            **unlearn_splits, anchor=anchor, seed=seed
+        )
         data["train"] = unlearn_dataset
         for split in unlearn_splits:
             data.pop(split)
diff --git a/src/data/unlearn.py b/src/data/unlearn.py
index 0cb0bada..dff81be7 100644
--- a/src/data/unlearn.py
+++ b/src/data/unlearn.py
@@ -4,17 +4,19 @@
 
 class ForgetRetainDataset(Dataset):
     # https://github.com/OPTML-Group/SOUL/blob/main/src/dataset/Base.py
-    def __init__(self, forget, retain, anchor="forget"):
+    def __init__(self, forget, retain, anchor="forget", seed=0):
         """Wraps the forget retain dataset into unlearning dataset.
 
         Args:
             forget (Dataset): Forget Dataset
             retain (Dataset): Retain Dataset
             anchor (str, optional): Specifies which dataset to anchor while randomly sampling from the other dataset. Defaults to 'forget'.
+            seed (int, optional): Random seed for reproducibility. Defaults to 0.
         """
         self.forget = forget
         self.retain = retain
         self.anchor = anchor
+        self.seed = seed
 
     def __len__(self):
         """Ensures the sampled dataset matches the anchor dataset's length."""
@@ -33,14 +35,22 @@ def __len__(self):
 
     def __getitem__(self, idx):
         item = {}
+        g = torch.Generator()
+        rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
+        rank_seed = self.seed + rank + idx
+        g.manual_seed(rank_seed)
         if self.anchor == "forget":
             item["forget"] = self.forget[idx]
             if self.retain:
-                retain_idx = torch.randint(0, len(self.retain), (1,)).item()
+                retain_idx = torch.randint(
+                    0, len(self.retain), (1,), generator=g
+                ).item()
                 item["retain"] = self.retain[retain_idx]
         elif self.anchor == "retain":
             item["retain"] = self.retain[idx]
             if self.forget:
-                forget_idx = torch.randint(0, len(self.forget), (1,)).item()
+                forget_idx = torch.randint(
+                    0, len(self.forget), (1,), generator=g
+                ).item()
                 item["forget"] = self.forget[forget_idx]
         return item
diff --git a/src/train.py b/src/train.py
index a2f81c8d..5e8f6db5 100644
--- a/src/train.py
+++ b/src/train.py
@@ -23,7 +23,11 @@ def main(cfg: DictConfig):
     # Load Dataset
     data_cfg = cfg.data
     data = get_data(
-        data_cfg, mode=mode, tokenizer=tokenizer, template_args=template_args
+        data_cfg,
+        mode=mode,
+        tokenizer=tokenizer,
+        template_args=template_args,
+        seed=cfg.trainer.args.seed,
     )
 
     # Load collator
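
For reference, a minimal sketch (not part of the diff) of what the seeding buys: with the same `seed`, each anchor index is deterministically paired with the same randomly sampled counterpart across runs, since the per-item generator is seeded with `seed + rank + idx`. The toy list-backed datasets and the import path `src.data.unlearn` are assumptions for illustration; in a single-process run `torch.distributed.is_initialized()` is False, so the rank term resolves to 0.

```python
# Illustrative sketch only (assumes a single-process run, so rank == 0,
# and that ForgetRetainDataset is importable from src.data.unlearn).
from src.data.unlearn import ForgetRetainDataset

forget = [{"id": f"forget_{i}"} for i in range(4)]    # stand-in forget split
retain = [{"id": f"retain_{i}"} for i in range(100)]  # stand-in retain split

ds_a = ForgetRetainDataset(forget, retain, anchor="forget", seed=42)
ds_b = ForgetRetainDataset(forget, retain, anchor="forget", seed=42)

# Same seed: every forget example is paired with the same sampled retain example.
assert all(ds_a[i] == ds_b[i] for i in range(len(ds_a)))

# Different seed: the sampled retain pairings change.
ds_c = ForgetRetainDataset(forget, retain, anchor="forget", seed=7)
print(ds_a[0]["retain"]["id"], ds_c[0]["retain"]["id"])
```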