Add end_training/destroy_pg to everything and unpin numpy (#3030)
* Add end_training/destroy_pg to everything

* Carry over to AcceleratorState

* If forked, ignore

* More numpy fun

* Skip only init
muellerzr authored Aug 20, 2024
1 parent 7ffe766 · commit 52fae09
Showing 40 changed files with 72 additions and 19 deletions.
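
The user-facing upshot of this commit: every example script now ends with an unconditional `accelerator.end_training()`, which finishes any active trackers and then tears down the torch.distributed process group (skipped automatically when the processes were launched via fork). A minimal sketch of the resulting pattern; the model and optimizer setup below is illustrative, not part of this diff:

import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)

# ... usual training loop elided ...

# Call unconditionally at the end: finishes any active trackers, then
# destroys the default process group via AcceleratorState.
accelerator.end_training()
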
1 change: 1 addition & 0 deletions examples/by_feature/automatic_gradient_accumulation.py
@@ -217,6 +217,7 @@ def inner_training_loop(batch_size):
     # And call it at the end with no arguments
     # Note: You could also refactor this outside of your training loop function
     inner_training_loop()
+    accelerator.end_training()


 def main():
1 change: 1 addition & 0 deletions examples/by_feature/checkpointing.py
@@ -276,6 +276,7 @@ def training_function(config, args):
             if args.output_dir is not None:
                 output_dir = os.path.join(args.output_dir, output_dir)
             accelerator.save_state(output_dir)
+    accelerator.end_training()


 def main():
1 change: 1 addition & 0 deletions examples/by_feature/cross_validation.py
@@ -255,6 +255,7 @@ def training_function(config, args):
     preds = torch.stack(test_predictions, dim=0).sum(dim=0).div(int(args.num_folds)).argmax(dim=-1)
     test_metric = metric.compute(predictions=preds, references=test_references)
     accelerator.print("Average test metrics from all folds:", test_metric)
+    accelerator.end_training()


 def main():
1 change: 1 addition & 0 deletions examples/by_feature/ddp_comm_hook.py
@@ -192,6 +192,7 @@ def training_function(config, args):
         eval_metric = metric.compute()
         # Use accelerator.print to print only on the main process.
         accelerator.print(f"epoch {epoch}:", eval_metric)
+    accelerator.end_training()


 def main():
1 change: 1 addition & 0 deletions examples/by_feature/deepspeed_with_config_support.py
@@ -716,6 +716,7 @@ def group_texts(examples):

         with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
             json.dump({"perplexity": perplexity, "eval_loss": eval_loss.item()}, f)
+    accelerator.end_training()


 if __name__ == "__main__":
1 change: 1 addition & 0 deletions examples/by_feature/early_stopping.py
@@ -222,6 +222,7 @@ def training_function(config, args):

         # Use accelerator.print to print only on the main process.
         accelerator.print(f"epoch {epoch}:", eval_metric)
+    accelerator.end_training()


 def main():
3 changes: 1 addition & 2 deletions examples/by_feature/fsdp_with_peak_mem_tracking.py
@@ -399,8 +399,7 @@ def collate_fn(examples):
                 step=epoch,
             )

-    if args.with_tracking:
-        accelerator.end_training()
+    accelerator.end_training()


 def main():
1 change: 1 addition & 0 deletions examples/by_feature/gradient_accumulation.py
@@ -197,6 +197,7 @@ def training_function(config, args):
         eval_metric = metric.compute()
         # Use accelerator.print to print only on the main process.
         accelerator.print(f"epoch {epoch}:", eval_metric)
+    accelerator.end_training()


 def main():
1 change: 1 addition & 0 deletions examples/by_feature/local_sgd.py
@@ -202,6 +202,7 @@ def training_function(config, args):
         eval_metric = metric.compute()
         # Use accelerator.print to print only on the main process.
         accelerator.print(f"epoch {epoch}:", eval_metric)
+    accelerator.end_training()


 def main():
1 change: 1 addition & 0 deletions examples/by_feature/megatron_lm_gpt_pretraining.py
@@ -703,6 +703,7 @@ def group_texts(examples):

         with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
             json.dump({"perplexity": perplexity}, f)
+    accelerator.end_training()


 if __name__ == "__main__":
1 change: 1 addition & 0 deletions examples/by_feature/memory.py
@@ -210,6 +210,7 @@ def inner_training_loop(batch_size):
     # And call it at the end with no arguments
     # Note: You could also refactor this outside of your training loop function
     inner_training_loop()
+    accelerator.end_training()


 def main():
1 change: 1 addition & 0 deletions examples/by_feature/multi_process_metrics.py
@@ -214,6 +214,7 @@ def training_function(config, args):
         eval_metric = metric.compute()
         # Use accelerator.print to print only on the main process.
         accelerator.print(f"epoch {epoch}:", eval_metric)
+    accelerator.end_training()


 def main():
1 change: 1 addition & 0 deletions examples/by_feature/profiler.py
@@ -203,6 +203,7 @@ def training_function(config, args):
         eval_metric = metric.compute()
         # Use accelerator.print to print only on the main process.
         accelerator.print(f"epoch {epoch}:", eval_metric)
+    accelerator.end_training()


 def main():
1 change: 1 addition & 0 deletions examples/by_feature/schedule_free.py
@@ -202,6 +202,7 @@ def training_function(config, args):
         eval_metric = metric.compute()
         # Use accelerator.print to print only on the main process.
         accelerator.print(f"epoch {epoch}:", eval_metric)
+    accelerator.end_training()


 def main():
6 changes: 1 addition & 5 deletions examples/by_feature/tracking.py
@@ -236,11 +236,7 @@ def training_function(config, args):
                 step=epoch,
             )

-    # New Code #
-    # When a run is finished, you should call `accelerator.end_training()`
-    # to close all of the open trackers
-    if args.with_tracking:
-        accelerator.end_training()
+    accelerator.end_training()


 def main():
3 changes: 1 addition & 2 deletions examples/complete_cv_example.py
@@ -262,8 +262,7 @@ def training_function(config, args):
                 output_dir = os.path.join(args.output_dir, output_dir)
             accelerator.save_state(output_dir)

-    if args.with_tracking:
-        accelerator.end_training()
+    accelerator.end_training()


 def main():
3 changes: 1 addition & 2 deletions examples/complete_nlp_example.py
@@ -256,8 +256,7 @@ def collate_fn(examples):
                 output_dir = os.path.join(args.output_dir, output_dir)
             accelerator.save_state(output_dir)

-    if args.with_tracking:
-        accelerator.end_training()
+    accelerator.end_training()


 def main():
1 change: 1 addition & 0 deletions examples/cv_example.py
@@ -180,6 +180,7 @@ def training_function(config, args):
         eval_metric = accurate.item() / num_elems
         # Use accelerator.print to print only on the main process.
         accelerator.print(f"epoch {epoch}: {100 * eval_metric:.2f}")
+    accelerator.end_training()


 def main():
1 change: 1 addition & 0 deletions examples/inference/pippy/bert.py
@@ -76,3 +76,4 @@
 output = torch.stack(tuple(output[0]))
 print(f"Time of first pass: {first_batch}")
 print(f"Average time per batch: {(end_time - start_time) / 5}")
+PartialState().destroy_process_group()
1 change: 1 addition & 0 deletions examples/inference/pippy/gpt2.py
@@ -75,3 +75,4 @@
 output = torch.stack(tuple(output[0]))
 print(f"Time of first pass: {first_batch}")
 print(f"Average time per batch: {(end_time - start_time) / 5}")
+PartialState().destroy_process_group()
1 change: 1 addition & 0 deletions examples/inference/pippy/llama.py
@@ -52,3 +52,4 @@
 next_token_logits = output[0][:, -1, :]
 next_token = torch.argmax(next_token_logits, dim=-1)
 print(tokenizer.batch_decode(next_token))
+PartialState().destroy_process_group()
1 change: 1 addition & 0 deletions examples/inference/pippy/t5.py
@@ -87,3 +87,4 @@
 output = torch.stack(tuple(output[0]))
 print(f"Time of first pass: {first_batch}")
 print(f"Average time per batch: {(end_time - start_time) / 5}")
+PartialState().destroy_process_group()
1 change: 1 addition & 0 deletions examples/nlp_example.py
@@ -185,6 +185,7 @@ def training_function(config, args):
         eval_metric = metric.compute()
         # Use accelerator.print to print only on the main process.
         accelerator.print(f"epoch {epoch}:", eval_metric)
+    accelerator.end_training()


 def main():
2 changes: 1 addition & 1 deletion setup.py
@@ -70,7 +70,7 @@
     },
     python_requires=">=3.8.0",
    install_requires=[
-        "numpy>=1.17,<2.0.0",
+        "numpy>=1.17,<3.0.0",
         "packaging>=20.0",
         "psutil",
         "pyyaml",
4 changes: 1 addition & 3 deletions src/accelerate/accelerator.py
@@ -2727,9 +2727,7 @@ def end_training(self):
         for tracker in self.trackers:
             tracker.finish()

-        if torch.distributed.is_initialized():
-            # needed when using torch.distributed.init_process_group
-            torch.distributed.destroy_process_group()
+        self.state.destroy_process_group()

     def save(self, obj, f, safe_serialization=False):
         """
22 changes: 22 additions & 0 deletions src/accelerate/state.py
@@ -789,6 +789,16 @@ def set_device(self):
         self.device = torch.device(device, device_index)
         device_module.set_device(self.device)

+    def destroy_process_group(self, group=None):
+        """
+        Destroys the process group. If one is not specified, the default process group is destroyed.
+        """
+        if self.fork_launched and group is None:
+            return
+        # needed when using torch.distributed.init_process_group
+        if torch.distributed.is_initialized():
+            torch.distributed.destroy_process_group(group)
+
     def __getattr__(self, name: str):
         # By this point we know that no attributes of `self` contain `name`,
         # so we just modify the error message
@@ -983,6 +993,18 @@ def _reset_state(reset_partial_state: bool = False):
         if reset_partial_state:
             PartialState._reset_state()

+    def destroy_process_group(self, group=None):
+        """
+        Destroys the process group. If one is not specified, the default process group is destroyed.
+        If `self.fork_launched` is `True` and `group` is `None`, nothing happens.
+        """
+        PartialState().destroy_process_group(group)
+
+    @property
+    def fork_launched(self):
+        return PartialState().fork_launched
+
     @property
     def use_distributed(self):
         """
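
For scripts that use `PartialState` directly (for example the PiPPy inference examples above), the same teardown is now available on the state object itself; `AcceleratorState.destroy_process_group` simply forwards to `PartialState`. A short sketch, assuming a distributed launch; the surrounding inference work is illustrative:

from accelerate import PartialState

state = PartialState()
# ... distributed inference work ...

# No-op when the processes were forked and no explicit group is passed;
# otherwise destroys the process group if torch.distributed is initialized.
state.destroy_process_group()
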
@@ -223,6 +223,7 @@ def training_function(config, args):
     if accelerator.is_main_process:
         with open(os.path.join(args.output_dir, f"state_{epoch}.json"), "w") as f:
             json.dump(state, f)
+    accelerator.end_training()


 def main():
@@ -294,6 +294,7 @@ def main():
     if accelerator.is_local_main_process:
         print("**Test that `drop_last` is taken into account**")
     test_gather_for_metrics_drop_last()
+    accelerator.end_training()
     accelerator.state._reset_state()
@@ -240,6 +240,7 @@ def training_function(config, args):
     if accelerator.is_main_process:
         with open(os.path.join(args.output_dir, "peak_memory_utilization.json"), "w") as f:
             json.dump(train_total_peak_memory, f)
+    accelerator.end_training()


 def main():
@@ -205,6 +205,7 @@ def training_function(config, args):
     if accelerator.is_main_process:
         with open(os.path.join(args.output_dir, "all_results.json"), "w") as f:
             json.dump(performance_metric, f)
+    accelerator.end_training()


 def main():
@@ -125,5 +125,6 @@ def test_resnet(batch_size: int = 2):
     state.print("Testing CV model...")
     test_resnet()
     test_resnet(3)
+    state.destroy_process_group()
 else:
     print("Less than two GPUs found, not running tests!")
3 changes: 2 additions & 1 deletion src/accelerate/test_utils/scripts/test_ddp_comm_hook.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import torch

-from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs
+from accelerate import Accelerator, DDPCommunicationHookType, DistributedDataParallelKwargs, PartialState


 class MockModel(torch.nn.Module):
@@ -71,6 +71,7 @@ def main():
     ]:
         print(f"Test DDP comm hook: {comm_hook}, comm wrapper: {comm_wrapper}")
         test_ddp_comm_hook(comm_hook, comm_wrapper, comm_state_option)
+    PartialState().destroy_process_group()


 if __name__ == "__main__":
@@ -307,6 +307,8 @@ def main():
     loader = DataLoader(dataset, sampler=sampler, batch_size=None, collate_fn=default_collate, num_workers=NUM_WORKERS)
     test_data_loader(loader, accelerator)

+    accelerator.end_training()
+

 if __name__ == "__main__":
     main()
1 change: 1 addition & 0 deletions src/accelerate/test_utils/scripts/test_merge_weights.py
@@ -158,3 +158,4 @@ def test_merge_weights_command_pytorch(model, path):
     if accelerator.is_main_process:
         shutil.rmtree(out_path)
     accelerator.wait_for_everyone()
+    accelerator.end_training()
2 changes: 2 additions & 0 deletions src/accelerate/test_utils/scripts/test_notebook.py
@@ -110,6 +110,8 @@ def main():
     if is_bnb_available():
         print("Test problematic imports (bnb)")
         test_problematic_imports()
+    if NUM_PROCESSES > 1:
+        PartialState().destroy_process_group()


 if __name__ == "__main__":
1 change: 1 addition & 0 deletions src/accelerate/test_utils/scripts/test_ops.py
@@ -173,6 +173,7 @@ def main():
     test_op_checker(state)
     state.print("testing sending tensors across devices")
     test_copy_tensor_to_devices(state)
+    state.destroy_process_group()


 if __name__ == "__main__":
2 changes: 2 additions & 0 deletions src/accelerate/test_utils/scripts/test_script.py
@@ -822,6 +822,8 @@ def main():
     print("\n**Test reinstantiated state**")
     test_reinstantiated_state()

+    state.destroy_process_group()
+

 if __name__ == "__main__":
     main()
7 changes: 4 additions & 3 deletions src/accelerate/test_utils/scripts/test_sync.py
@@ -20,7 +20,7 @@
 from torch.optim.lr_scheduler import LambdaLR
 from torch.utils.data import DataLoader

-from accelerate.accelerator import Accelerator, GradientAccumulationPlugin
+from accelerate.accelerator import Accelerator, DataLoaderConfiguration, GradientAccumulationPlugin
 from accelerate.state import GradientState
 from accelerate.test_utils import RegressionDataset, RegressionModel
 from accelerate.utils import DistributedType, set_seed
@@ -249,9 +249,9 @@ def test_gradient_accumulation_with_opt_and_scheduler(
     split_batches=False, dispatch_batches=False, sync_each_batch=False
 ):
     gradient_accumulation_plugin = GradientAccumulationPlugin(num_steps=2, sync_each_batch=sync_each_batch)
+    dataloader_config = DataLoaderConfiguration(split_batches=split_batches, dispatch_batches=dispatch_batches)
     accelerator = Accelerator(
-        split_batches=split_batches,
-        dispatch_batches=dispatch_batches,
+        dataloader_config=dataloader_config,
         gradient_accumulation_plugin=gradient_accumulation_plugin,
     )
     # Test that context manager behaves properly
@@ -392,6 +392,7 @@ def main():
             f"`split_batches={split_batch}` and `dispatch_batches={dispatch_batches}` and `sync_each_batch={sync_each_batch}`**",
         )
         test_gradient_accumulation_with_opt_and_scheduler(split_batch, dispatch_batches, sync_each_batch)
+    state.destroy_process_group()


 def _mp_fn(index):
4 changes: 4 additions & 0 deletions tests/test_metrics.py
@@ -14,6 +14,9 @@

 import unittest

+import numpy as np
+from packaging import version
+
 from accelerate import debug_launcher
 from accelerate.test_utils import (
     DEFAULT_LAUNCH_COMMAND,
@@ -29,6 +32,7 @@


 @require_huggingface_suite
+@unittest.skipIf(version.parse(np.__version__) >= version.parse("2.0"), "Test requires numpy version < 2.0")
 class MetricTester(unittest.TestCase):
     def setUp(self):
         self.test_file_path = path_in_accelerate_package("test_utils", "scripts", "external_deps", "test_metrics.py")
2 changes: 2 additions & 0 deletions tests/test_tracking.py
@@ -27,6 +27,7 @@

 import numpy as np
 import torch
+from packaging import version

 # We use TF to parse the logs
 from accelerate import Accelerator
@@ -68,6 +69,7 @@

 @require_tensorboard
 class TensorBoardTrackingTest(unittest.TestCase):
+    @unittest.skipIf(version.parse(np.__version__) >= version.parse("2.0"), "TB doesn't support numpy 2.0")
     def test_init_trackers(self):
         project_name = "test_project_with_config"
         with tempfile.TemporaryDirectory() as dirpath:
