DeepSpeed refactoring (#2313)
* DeepSpeed refactoring

Co-Authored-By: Stas Bekman <[email protected]>

* add tests

* Update test_deepspeed.py

* Update test_deepspeed.py

---------

Co-authored-by: Stas Bekman <[email protected]>
pacman100 and stas00 authored Jan 9, 2024
1 parent 4420ec6 commit 411aa58
Showing 3 changed files with 93 additions and 14 deletions.
17 changes: 5 additions & 12 deletions src/accelerate/accelerator.py
@@ -1419,7 +1419,7 @@ def _prepare_deepspeed(self, *args):
for obj in args
]

if deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"] == "auto":
if deepspeed_plugin.is_auto("train_micro_batch_size_per_gpu"):
if is_dataloader_present:
batch_sizes = [obj.batch_size for obj in args if hasattr(obj, "batch_size")]
if any(bs is None for bs in batch_sizes):
@@ -1445,7 +1445,7 @@ def _prepare_deepspeed(self, *args):
"or assign integer value to `AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu']`."
)
else:
batch_size_per_device = deepspeed_plugin.deepspeed_config["train_micro_batch_size_per_gpu"]
batch_size_per_device = deepspeed_plugin.get_value("train_micro_batch_size_per_gpu")

# handle `gradient_accumulation_steps` when the value is `auto`
deepspeed_plugin.fill_match(
@@ -1457,7 +1457,7 @@ def _prepare_deepspeed(self, *args):
config_kwargs = {
"train_micro_batch_size_per_gpu": batch_size_per_device,
"train_batch_size": batch_size_per_device
* deepspeed_plugin.deepspeed_config["gradient_accumulation_steps"]
* deepspeed_plugin.get_value("gradient_accumulation_steps")
* self.num_processes,
"gradient_clipping": 1.0,
"zero_optimization.stage3_gather_16bit_weights_on_model_save": False,
@@ -1516,20 +1516,13 @@ def _prepare_deepspeed(self, *args):
)

if model is not None:
ds_config = deepspeed_plugin.deepspeed_config
# deal with config keys that use `auto` value and rely on model's hidden_size
hidden_size_based_keys = [
"zero_optimization.reduce_bucket_size",
"zero_optimization.stage3_prefetch_bucket_size",
"zero_optimization.stage3_param_persistence_threshold",
]

def is_auto(ds_config, ds_key_long):
nodes = ds_key_long.split(".")
val = ds_config.get(nodes[0], {}).get(nodes[1], None)
return False if None else val == "auto"

hidden_size_auto_keys = [x for x in hidden_size_based_keys if is_auto(ds_config, x)]
hidden_size_auto_keys = [x for x in hidden_size_based_keys if deepspeed_plugin.is_auto(x)]
if len(hidden_size_auto_keys) > 0:
reasoning = (
"therefore it's not possible to automatically fill out the following `auto` entries "
@@ -1546,7 +1539,7 @@ def is_auto(ds_config, ds_key_long):
hidden_size = max(model.config.hidden_sizes)
else:
raise ValueError(
"Can't find neither `model.config.hidden_size` nor `model.config.hidden_sizes`, " + reasoning
"Can find neither `model.config.hidden_size` nor `model.config.hidden_sizes`, " + reasoning
)

config_kwargs.update(
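
For reference, the `auto` entries based on the model's hidden size are filled with the relationships asserted by the new test later in this commit. The snippet below is a standalone restatement of those assertions, not the actual body of `_prepare_deepspeed`:

    # Standalone sketch (assumption: mirrors the assertions in
    # test_autofill_comm_buffers_dsconfig below) of how the hidden_size-based
    # `auto` entries are expected to be filled once hidden_size is known.
    def expected_comm_buffer_sizes(hidden_size: int) -> dict:
        return {
            "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
            "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
            "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
        }


    # Example: for a model with hidden_size=64 this yields 4096, 3686.4 and 640.
    print(expected_comm_buffer_sizes(64))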
10 changes: 10 additions & 0 deletions src/accelerate/utils/dataclasses.py
@@ -742,6 +742,16 @@ def fill_match(self, ds_key_long, mismatches=None, must_match=True, **kwargs):
if ds_val != kwargs[ds_key_long]:
mismatches.append(f"- ds {ds_key_long}={ds_val} vs arg {ds_key_long}={kwargs[ds_key_long]}")

def is_auto(self, ds_key_long):
val = self.hf_ds_config.get_value(ds_key_long)
if val is None:
return False
else:
return val == "auto"

def get_value(self, ds_key_long, default=None):
return self.hf_ds_config.get_value(ds_key_long, default)

def deepspeed_config_process(self, prefix="", mismatches=None, config=None, must_match=True, **kwargs):
"""Process the DeepSpeed config with the values from the kwargs."""
mismatches = [] if mismatches is None else mismatches
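
The two new helpers above are thin wrappers around `hf_ds_config.get_value`. The following is a hypothetical, self-contained stand-in (not accelerate's actual `DeepSpeedPlugin`) illustrating the dotted-key lookup and the calling pattern now used in `accelerator.py`:

    # Hypothetical stand-in for the new DeepSpeedPlugin helpers; the class name
    # and dict traversal are illustrative, only the is_auto/get_value semantics
    # follow the diff above.
    class MiniPlugin:
        def __init__(self, config):
            self.config = config

        def get_value(self, ds_key_long, default=None):
            node = self.config
            for key in ds_key_long.split("."):
                if not isinstance(node, dict) or key not in node:
                    return default
                node = node[key]
            return node

        def is_auto(self, ds_key_long):
            return self.get_value(ds_key_long) == "auto"


    plugin = MiniPlugin(
        {
            "train_micro_batch_size_per_gpu": "auto",
            "zero_optimization": {"reduce_bucket_size": 5e8},
        }
    )
    assert plugin.is_auto("train_micro_batch_size_per_gpu")
    assert plugin.get_value("zero_optimization.reduce_bucket_size") == 5e8
    assert plugin.get_value("zero_optimization.stage3_prefetch_bucket_size", "auto") == "auto"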
80 changes: 78 additions & 2 deletions tests/deepspeed/test_deepspeed.py
@@ -24,7 +24,7 @@
import torch
from parameterized import parameterized
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import AutoModel, AutoModelForCausalLM, get_scheduler
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, get_scheduler
from transformers.testing_utils import mockenv_context
from transformers.trainer_utils import set_seed
from transformers.utils import is_torch_bf16_available
@@ -41,7 +41,7 @@
require_non_cpu,
slow,
)
from accelerate.test_utils.training import RegressionDataset
from accelerate.test_utils.training import RegressionDataset, RegressionModel
from accelerate.utils.dataclasses import DeepSpeedPlugin
from accelerate.utils.deepspeed import (
DeepSpeedEngineWrapper,
@@ -56,6 +56,7 @@
set_seed(42)

GPT2_TINY = "sshleifer/tiny-gpt2"
MOBILEVIT = "apple/mobilevit-xx-small"

ZERO2 = "zero2"
ZERO3 = "zero3"
@@ -68,9 +69,15 @@
DS_OPTIMIZER = "deepspeed_optimizer"
DS_SCHEDULER = "deepspeed_scheduler"

NO_CONFIG = "no_config"
CONFIG_WITH_NO_HIDDEN_SIZE = "config_with_no_hidden_size"
CONFIG_WITH_HIDDEN_SIZE = "config_with_hidden_size"
CONFIG_WITH_HIDDEN_SIZES = "config_with_hidden_sizes"

stages = [ZERO2, ZERO3]
optims = [CUSTOM_OPTIMIZER, DS_OPTIMIZER]
schedulers = [CUSTOM_SCHEDULER, DS_SCHEDULER]
model_types = [NO_CONFIG, CONFIG_WITH_NO_HIDDEN_SIZE, CONFIG_WITH_HIDDEN_SIZE, CONFIG_WITH_HIDDEN_SIZES]
if is_torch_bf16_available():
dtypes = [FP16, BF16]
else:
@@ -89,6 +96,11 @@ def parameterized_custom_name_func(func, param_num, param):
optim_scheduler_params = list(itertools.product(optims, schedulers))


class DummyConfig:
def __init__(self):
self._name_or_path = "dummy"


@require_deepspeed
@require_non_cpu
class DeepSpeedConfigIntegration(AccelerateTestCase):
@@ -646,6 +658,70 @@ def test_autofill_dsconfig(self):
accelerator.deepspeed_config["zero_optimization"]["stage3_gather_16bit_weights_on_model_save"]
)

@parameterized.expand(model_types, name_func=parameterized_custom_name_func)
def test_autofill_comm_buffers_dsconfig(self, model_type):
deepspeed_plugin = DeepSpeedPlugin(
hf_ds_config=self.ds_config_file[ZERO3],
zero3_init_flag=True,
)
del deepspeed_plugin.deepspeed_config["bf16"]
del deepspeed_plugin.deepspeed_config["fp16"]
del deepspeed_plugin.deepspeed_config["optimizer"]
del deepspeed_plugin.deepspeed_config["scheduler"]
with mockenv_context(**self.dist_env):
accelerator = Accelerator(mixed_precision="fp16", deepspeed_plugin=deepspeed_plugin)

train_set = RegressionDataset(length=80)
eval_set = RegressionDataset(length=20)
train_dataloader = DataLoader(train_set, batch_size=16, shuffle=True)
eval_dataloader = DataLoader(eval_set, batch_size=32, shuffle=False)
model = RegressionModel()
if model_type == CONFIG_WITH_NO_HIDDEN_SIZE:
model.config = DummyConfig()
elif model_type == CONFIG_WITH_HIDDEN_SIZE:
model.config = AutoConfig.from_pretrained(GPT2_TINY)
hidden_size = model.config.hidden_size
elif model_type == CONFIG_WITH_HIDDEN_SIZES:
model.config = AutoConfig.from_pretrained(MOBILEVIT)
hidden_size = max(model.config.hidden_sizes)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
lr_scheduler = get_scheduler(
name="linear",
optimizer=optimizer,
num_warmup_steps=0,
num_training_steps=1000,
)

if model_type == NO_CONFIG:
with self.assertRaises(ValueError) as cm:
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)
msg = "Can't find `model.config` entry"
self.assertTrue(msg in str(cm.exception))
elif model_type == CONFIG_WITH_NO_HIDDEN_SIZE:
with self.assertRaises(ValueError) as cm:
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)
msg = "Can find neither `model.config.hidden_size` nor `model.config.hidden_sizes`"
self.assertTrue(msg in str(cm.exception))
else:
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)
self.assertEqual(
accelerator.deepspeed_config["zero_optimization"]["reduce_bucket_size"], hidden_size * hidden_size
)
self.assertEqual(
accelerator.deepspeed_config["zero_optimization"]["stage3_prefetch_bucket_size"],
0.9 * hidden_size * hidden_size,
)
self.assertEqual(
accelerator.deepspeed_config["zero_optimization"]["stage3_param_persistence_threshold"],
10 * hidden_size,
)

@parameterized.expand([FP16, BF16], name_func=parameterized_custom_name_func)
def test_autofill_dsconfig_from_ds_plugin(self, dtype):
ds_config = self.ds_config_dict["zero3"]
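
For quick reference, the four `model_type` cases exercised by the new `test_autofill_comm_buffers_dsconfig` map to the following expected outcomes (a summary sketch; the dictionary and its name are illustrative, not part of the test):

    # Summary of the parameterized cases added in test_deepspeed.py.
    EXPECTED_OUTCOME = {
        "no_config": "ValueError: Can't find `model.config` entry",
        "config_with_no_hidden_size": (
            "ValueError: Can find neither `model.config.hidden_size` "
            "nor `model.config.hidden_sizes`"
        ),
        "config_with_hidden_size": "comm buffers autofilled from model.config.hidden_size (tiny GPT-2)",
        "config_with_hidden_sizes": "comm buffers autofilled from max(model.config.hidden_sizes) (MobileViT)",
    }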
