diff --git a/megatron/checkpointing.py b/megatron/checkpointing.py
index 1b6909c9f..97a79366a 100644
--- a/megatron/checkpointing.py
+++ b/megatron/checkpointing.py
@@ -96,7 +96,9 @@ def do_forward_pass(neox_args, model, inference=False):
         tokens, attention_mask, position_ids = get_batch(
             neox_args, context_tokens_tensor[:, : neox_args.seq_length]
         )
-        logits = model((tokens, position_ids, attention_mask))
+        output = model((tokens, position_ids, attention_mask))
+        logits = output[0] if isinstance(output, tuple) else output
+
 
     # reset to train mode, if model was in training before
     if model_was_in_train:
diff --git a/megatron/logging.py b/megatron/logging.py
index af8a41fe5..37c96e125 100644
--- a/megatron/logging.py
+++ b/megatron/logging.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import sys
+import os
 
 import torch
 
@@ -26,6 +27,7 @@
 import math
 
 
+'''
 class Tee:
     """Duplicate output to both stdout/err and file"""
 
@@ -61,6 +63,58 @@ def flush(self) -> None:
             self.file.flush()
         except OSError:
             pass
+'''
+
+class Tee:
+    """Duplicate output to both stdout/err and file"""
+
+    def __init__(self, file, err: bool = False) -> None:
+        self.err = err
+        self.std = sys.stderr if err else sys.stdout
+
+        if isinstance(file, str):
+            try:
+                # Ensure the parent directory exists if file is a path
+                os.makedirs(os.path.dirname(file) or ".", exist_ok=True)
+                self.file = open(file, "w")
+            except IOError as e:
+                print(f"Warning: Could not open file {file} for writing. {str(e)}", file=self.std)
+                self.file = None
+        elif hasattr(file, 'write') and hasattr(file, 'flush'):
+            # If it's a file-like object, use it directly
+            self.file = file
+        else:
+            raise ValueError("'file' must be either a file path or a file-like object")
+
+        if not err:
+            sys.stdout = self
+        else:
+            sys.stderr = self
+
+    def __del__(self) -> None:
+        if not self.err:
+            sys.stdout = self.std
+        else:
+            sys.stderr = self.std
+
+        if self.file and hasattr(self.file, 'close'):
+            self.file.close()
+
+    def write(self, data) -> None:
+        self.std.write(data)
+        if self.file:
+            try:
+                self.file.write(data)
+            except IOError as e:
+                print(f"Warning: Could not write to file. {str(e)}", file=self.std)
+
+    def flush(self) -> None:
+        self.std.flush()
+        if self.file:
+            try:
+                self.file.flush()
+            except IOError as e:
+                print(f"Warning: Could not flush file. {str(e)}", file=self.std)
{str(e)}", file=self.std) def human_readable_flops(num) -> str: diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py index fa475c057..5960ca232 100644 --- a/megatron/neox_arguments/arguments.py +++ b/megatron/neox_arguments/arguments.py @@ -59,6 +59,7 @@ OKAY = f"{GREEN}[OKAY]{END}" WARNING = f"{YELLOW}[WARNING]{END}" FAIL = f"{RED}[FAIL]{END}" +ERROR = f"{RED}[ERROR]{END}" INFO = "[INFO]" # ZERO defaults by deespeed @@ -875,7 +876,6 @@ def calculate_derived(self): """ Derives additional configuration values necessary for training from the current config """ - # number of gpus # Get number of GPUs param or hostfile to determine train_batch_size global_num_gpus = getattr(self, "global_num_gpus", None) @@ -896,6 +896,7 @@ def calculate_derived(self): else: global_num_gpus = torch.cuda.device_count() self.update_value("global_num_gpus", global_num_gpus) + logging.info( self.__class__.__name__ diff --git a/tests/common.py b/tests/common.py index c63ced0f7..893476a42 100644 --- a/tests/common.py +++ b/tests/common.py @@ -16,6 +16,8 @@ import time import shutil import itertools +import inspect +import subprocess from pathlib import Path from abc import ABC, abstractmethod from deepspeed.accelerator import get_accelerator @@ -48,6 +50,14 @@ DEEPSPEED_UNIT_WORKER_TIMEOUT = 120 DEEPSPEED_TEST_TIMEOUT = 600 +def is_rocm_pytorch(): + """ + Check if the current PyTorch installation is using ROCm. + + Returns: + bool: True if PyTorch is using ROCm, False otherwise. + """ + return hasattr(torch.version, 'hip') and torch.version.hip is not None def get_xdist_worker_id(): xdist_worker = os.environ.get("PYTEST_XDIST_WORKER", None) @@ -67,7 +77,6 @@ def get_master_port(): _num_gpus = None - def set_accelerator_visible(): cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", None) xdist_worker_id = get_xdist_worker_id() @@ -123,8 +132,6 @@ def set_accelerator_visible(): def count_gpus(): global _num_gpus if _num_gpus is None: - import subprocess - nvidia_smi = subprocess.check_output(["nvidia-smi", "--list-gpus"]) _num_gpus = len(nvidia_smi.decode("utf-8").strip().split("\n")) return _num_gpus @@ -137,8 +144,6 @@ def set_cuda_visibile(): xdist_worker_id = 0 if cuda_visible is None: # CUDA_VISIBLE_DEVICES is not set, discover it from nvidia-smi instead - import subprocess - nvidia_smi = subprocess.check_output(["nvidia-smi", "--list-gpus"]) num_gpus = len(nvidia_smi.decode("utf-8").strip().split("\n")) cuda_visible = ",".join(map(str, range(num_gpus))) @@ -428,9 +433,7 @@ def test_2(self, val1, val2, val3, val4): assert int(os.environ["WORLD_SIZE"]) == 1 assert all(val1, val2, val3, val4) """ - - def __init__(self): - self.is_dist_test = True + is_dist_test = True # Temporary directory that is shared among test methods in a class @pytest.fixture(autouse=True, scope="class") @@ -476,7 +479,7 @@ def get_test_path(filename): def model_setup(yaml_list=None, param_dict=None, clear_data=True): from megatron.neox_arguments import NeoXArgs from megatron.mpu import destroy_model_parallel - from megatron import initialize_megatron + from megatron.initialize import initialize_megatron from megatron.training import setup_model_and_optimizer destroy_model_parallel() # mpu model parallel contains remaining global vars @@ -509,10 +512,10 @@ def model_setup(yaml_list=None, param_dict=None, clear_data=True): args_loaded.build_tokenizer() initialize_megatron(neox_args=args_loaded) - model, optimizer, lr_scheduler = setup_model_and_optimizer( + model, optimizer, lr_scheduler, reference_model 
         neox_args=args_loaded, use_cache=True
     )
-    return model, optimizer, lr_scheduler, args_loaded
+    return model, optimizer, lr_scheduler, reference_model, args_loaded
 
 
 def simulate_deepy_env(monkeypatch, input_args):
diff --git a/tests/model/test_fused_kernels.py b/tests/model/test_fused_kernels.py
index 125eb6c52..ce48390bc 100644
--- a/tests/model/test_fused_kernels.py
+++ b/tests/model/test_fused_kernels.py
@@ -30,6 +30,7 @@
 )
 
 
+@pytest.mark.forked
 @pytest.mark.xfail(reason="SystemExit: None")
 def test_load_fused_kernels():
     load()
@@ -45,6 +46,7 @@ def test_load_fused_kernels():
         raise e
 
 
+@pytest.mark.forked
 @pytest.mark.xfail(reason="SystemExit: None")
 def test_fused_softmax():
     load()
@@ -148,6 +150,7 @@ def test_fused_softmax():
     )
 
 
+@pytest.mark.forked
 @pytest.mark.xfail(reason="SystemExit: None")
 def test_fused_upper_triangle_mask_softmax():
     load()
diff --git a/tests/model/test_model_checkpoint.py b/tests/model/test_model_checkpoint.py
index 96f51683b..7bd108d61 100644
--- a/tests/model/test_model_checkpoint.py
+++ b/tests/model/test_model_checkpoint.py
@@ -33,7 +33,8 @@
 import torch
 
 PARAMS_TO_TEST = {
-    "pipe_parallel_size,model_parallel_size": [[0, 1], [1, 2], [0, 2], [2, 1]],
+    "include": ["localhost:0,1"],
+    "pipe_parallel_size,model_parallel_size": [[1, 2], [0, 2], [2, 1]],
     "checkpoint_validation_with_forward_pass": [True],
     "fp16,fp32_allreduce": [
         [
@@ -61,31 +62,20 @@
 }
 
 parameters, names = parametrize(
-    PARAMS_TO_TEST, max_tests=int(os.getenv("MAX_TESTCASES", 50)), seed=None
+    PARAMS_TO_TEST, max_tests=int(os.getenv("MAX_TESTCASES", 50)), seed=42
 )
 
+class TestModelCheckpoint(DistributedTest):
+    world_size = 2
 
-@pytest.mark.skip
-@pytest.mark.parametrize("param_dict", parameters, ids=names)
-def test_train(param_dict):
-    import tempfile
-
-    d = tempfile.mkdtemp()
-    param_dict["save"] = d
-
-    t1 = test_run_checkpoint_test_class()
-    t1.run_checkpoint_test(param_dict=param_dict)
-
-
-class test_run_checkpoint_test_class(DistributedTest):
-    def run_checkpoint_test(yaml_list=None, param_dict=None):
-
+    @pytest.mark.parametrize("param_dict", parameters, ids=names)
+    def test_checkpoint(self, param_dict, tmpdir):
         from megatron.checkpointing import load_checkpoint
         from megatron.checkpointing import save_checkpoint
 
-        model, optimizer, lr_scheduler, args_loaded = model_setup(
-            yaml_list, param_dict, clear_data=True
+        model, optimizer, lr_scheduler, reference_model, args_loaded = model_setup(
+            yaml_list=None, param_dict=param_dict, clear_data=True
         )
 
         # save model checkpoint
         save_checkpoint(
@@ -101,8 +91,9 @@ def run_checkpoint_test(yaml_list=None, param_dict=None):
             reloaded_model,
             reloaded_optimizer,
             reloaded_lr_scheduler,
+            reloaded_reference_model,
             args_reloaded,
-        ) = model_setup(yaml_list, param_dict, clear_data=False)
+        ) = model_setup(yaml_list=None, param_dict=param_dict, clear_data=False)
         iteration = load_checkpoint(
             neox_args=args_reloaded,
             model=reloaded_model,
@@ -111,9 +102,7 @@ def run_checkpoint_test(yaml_list=None, param_dict=None):
         )
 
         # ensure same checkpoint is loaded
-        assert (
-            iteration == 42
-        ), "run_checkpoint_test() iteration loaded from checkpoint correct"
+        assert iteration == 42, "Iteration loaded from checkpoint is incorrect"
 
         # check all weight groups are the same
         for idx, ((n1, p1), (n2, p2)) in enumerate(
@@ -123,14 +112,8 @@ def run_checkpoint_test(yaml_list=None, param_dict=None):
             )
         ):
             assert n1 == n2
-            params_equal = (p1 == p2).all().item()
-            assert params_equal, "run_checkpoint_test() params equal: " + str(n1)
+            params_equal = torch.all(p1 == p2).item()
+            assert params_equal, f"Parameters not equal: {n1}"
 
-
-if __name__ == "__main__":
-    params = list(
-        parametrize(
-            PARAMS_TO_TEST, max_tests=int(os.getenv("MAX_TESTCASES", 50)), seed=None
-        )
-    )
-    test_train(params[0])
+        # Clean up
+        del model, reloaded_model
\ No newline at end of file
diff --git a/tests/model/test_model_generation.py b/tests/model/test_model_generation.py
index 6dd93f355..093c174c3 100644
--- a/tests/model/test_model_generation.py
+++ b/tests/model/test_model_generation.py
@@ -25,6 +25,7 @@
 from tests.common import DistributedTest, model_setup, parametrize
 
 PARAMS_TO_TEST = {
+    "include": ["localhost:0,1"],
     "pipe_parallel_size,model_parallel_size,world_size": [
         [0, 1, 1],
         [0, 1, 2],
@@ -63,18 +64,11 @@
     PARAMS_TO_TEST, max_tests=int(os.getenv("MAX_TESTCASES", 50)), seed=None
 )
 
-
-@pytest.mark.skip
-@pytest.mark.parametrize("param_dict", parameters, ids=names)
-def test_train(param_dict):
-    t1 = run_generate_test_class()
-    t1.run_generate_test(param_dict, param_dict.pop("prompt"))
-
-
-class run_generate_test_class(DistributedTest):
+class TestModelGeneration(DistributedTest):
     world_size = 2
 
-    def run_generate_test(param_dict, prompt):
+    @pytest.mark.parametrize("param_dict", parameters, ids=names)
+    def test_generate(self, param_dict, tmpdir):
         from megatron.text_generation_utils import generate_samples_from_prompt
         from megatron.utils import is_mp_rank_0
 
@@ -89,10 +83,10 @@ def run_generate_test(param_dict, prompt):
         }
         param_dict.update(fixed_params)
 
-        # TODO: we don't need to reinstantiate the model every time if we're only changing sampling settings - should be a workaround for this
-        model, _, _, args_loaded = model_setup(None, param_dict, clear_data=True)
+        model, _, _, _, args_loaded = model_setup(None, param_dict, clear_data=True)
         model.eval()
 
+        prompt = param_dict.pop("prompt")
         prompts = [prompt for _ in range(args_loaded.num_samples)]
         output = generate_samples_from_prompt(
             neox_args=args_loaded,
@@ -111,3 +105,6 @@ def run_generate_test(param_dict, prompt):
         for prompt, out in zip(prompts, output):
             assert prompt == out["context"]
             assert len(out["text"]) > 0
+
+        # Clean up
+        del model
diff --git a/tests/model/test_model_instantiation.py b/tests/model/test_model_instantiation.py
index 81c5cae4c..8adb70148 100644
--- a/tests/model/test_model_instantiation.py
+++ b/tests/model/test_model_instantiation.py
@@ -115,7 +115,7 @@ class test_instantiate_optimizers_class(DistributedTest):
     def run_test_model_instantiation(yaml_list=None, param_dict=None):
         from deepspeed.runtime.pipe.engine import PipelineEngine, DeepSpeedEngine
 
-        model, optimizer, lr_scheduler, args_loaded = model_setup(yaml_list, param_dict)
+        model, optimizer, lr_scheduler, reference_model, args_loaded = model_setup(yaml_list, param_dict)
         if args_loaded.pipe_parallel_size < 2:
             assert isinstance(
                 model, DeepSpeedEngine
diff --git a/tests/model/test_model_train.py b/tests/model/test_model_train.py
index 65adfcdee..d05b00650 100644
--- a/tests/model/test_model_train.py
+++ b/tests/model/test_model_train.py
@@ -48,10 +48,6 @@
 
 keys_to_test = PARAMS_TO_TEST.keys()
 
-# TODO: fix model training tests
-@pytest.mark.skip(
-    reason="All model tests are skipped until we fix the CUDA + torch multiprocessing issue."
-)
 @pytest.mark.parametrize(
     "key, value",
     [(key, value) for key in keys_to_test for value in PARAMS_TO_TEST[key]],
diff --git a/tests/unit/test_format_conversion_scripts.py b/tests/unit/test_format_conversion_scripts.py
index 6935e480a..93d0fc380 100644
--- a/tests/unit/test_format_conversion_scripts.py
+++ b/tests/unit/test_format_conversion_scripts.py
@@ -4,9 +4,6 @@
 from megatron.neox_arguments.neox_args import NeoXArgsTokenizer
 
 
-@pytest.mark.skip(
-    reason="Conversion test is skipped until we fix the CUDA + torch multiprocessing issue."
-)
 def test_gpt_neox_to_huggingface(monkeypatch, tmpdir, tmp_path):
     # Generate random GPT-NEOX model, check we can convert to hf format
diff --git a/tests/unit/test_launcher_scripts.py b/tests/unit/test_launcher_scripts.py
index bdc38f111..a7f96b21a 100644
--- a/tests/unit/test_launcher_scripts.py
+++ b/tests/unit/test_launcher_scripts.py
@@ -56,9 +56,6 @@ def test_preprocess_data(tokenizer_type):
     preprocess_data.main(input_args)
 
 
-@pytest.mark.skip(
-    reason="All model tests are skipped until we fix the CUDA + torch multiprocessing issue."
-)
 def test_generate(monkeypatch, tmpdir, tmp_path, sample_input_file):
     model_dir = str(tmpdir)
     sample_output_file = str(tmp_path) + ".txt"
@@ -75,9 +72,6 @@ def test_generate(monkeypatch, tmpdir, tmp_path, sample_input_file):
     generate.main(input_args=deepspeed_main_args, overwrite_values=generate_args)
 
 
-@pytest.mark.skip(
-    reason="All model tests are skipped until we fix the CUDA + torch multiprocessing issue."
-)
 def test_evaluate(monkeypatch, tmpdir, tmp_path):
     model_dir = str(tmpdir)
     sample_output_file = str(tmp_path)
@@ -94,9 +88,6 @@ def test_evaluate(monkeypatch, tmpdir, tmp_path):
     eval.main(input_args=deepspeed_main_args, overwrite_values=evaluate_args)
 
 
-@pytest.mark.skip(
-    reason="All model tests are skipped until we fix the CUDA + torch multiprocessing issue."
-)
 def test_finetuning(monkeypatch, tmpdir, tmp_path):
     # Save random model, load random model, keep training
     # TODO: add mocking to check that we're not ignoring the previously loaded model
@@ -111,9 +102,6 @@ def test_finetuning(monkeypatch, tmpdir, tmp_path):
     train.main(input_args=deepspeed_main_args, overwrite_values=finetune_args)
 
 
-@pytest.mark.skip(
-    reason="All model tests are skipped until we fix the CUDA + torch multiprocessing issue."
-)
 def test_train_launcher(monkeypatch):
     input_args = ["train.py", "tests/config/test_setup.yml"]
     deepspeed_main_args = simulate_deepy_env(monkeypatch, input_args)
diff --git a/tools/ckpts/convert_neox_to_hf.py b/tools/ckpts/convert_neox_to_hf.py
index 8dfe02d54..ae480dd2d 100644
--- a/tools/ckpts/convert_neox_to_hf.py
+++ b/tools/ckpts/convert_neox_to_hf.py
@@ -444,10 +444,9 @@ def reshard_and_split_qkv(
 
 def get_mlp_naming_convention(loaded_tp_ranks, layer_idx, sequential):
     """Determine whether the checkpoint uses the legacy or new MLP naming convention."""
-    print(list(loaded_tp_ranks[0]["module"].keys()))
     if any(
         [
-            ["mlp.linear1.weight" in key for key in list(state_dict["module"].keys())]
+            ["mlp.linear1.weight" in key for key in list(state_dict.keys())]
             for state_dict in loaded_tp_ranks
         ]
     ):
@@ -456,7 +455,7 @@ def get_mlp_naming_convention(loaded_tp_ranks, layer_idx, sequential):
         [
             [
                 "mlp.dense_h_to_4h.weight" in key
-                for key in list(state_dict["module"].keys())
+                for key in list(state_dict.keys())
             ]
             for state_dict in loaded_tp_ranks
         ]
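
Usage note (not part of the patch): a minimal sketch of how the reworked Tee in megatron/logging.py is expected to behave after this change. The log path is an arbitrary example, not a NeoX default, and the call site is illustrative only.

    import sys

    from megatron.logging import Tee  # module reworked in the patch above

    # Path variant: Tee creates the parent directory if needed, opens the file
    # for writing, and installs itself as sys.stdout so every write is mirrored
    # to both the terminal and the file.
    tee = Tee("logs/stdout.log", err=False)
    print("goes to the terminal and to logs/stdout.log")
    sys.stdout.flush()  # sys.stdout is now the Tee; flush() flushes both sinks

    # A file-like object (anything with write() and flush()) is also accepted
    # and used directly; any other argument raises ValueError.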