Init on the xla device (#521)
michaelbenayoun authored Apr 3, 2024
1 parent bb66802 commit 12b06a3
Showing 3 changed files with 31 additions and 12 deletions.
optimum/neuron/accelerate/accelerator.py (1 addition, 2 deletions)
@@ -422,8 +422,7 @@ def _prepare_model_for_mp(
         cpu_ids = {name: id(param) for name, param in model.named_parameters()}
         tied_parameters_dict = get_tied_parameters_dict(model)
         model_main_input_name = getattr(model, "main_input_name", None)
-        # TODO: use self.device.
-        model = self.state.mp_plugin.parallelize_model(model, device=None)
+        model = self.state.mp_plugin.parallelize_model(model, device=self.device)

         if model_main_input_name is not None:
             setattr(model, "main_input_name", model_main_input_name)
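This one-line accelerator.py change is what gives the commit its title: instead of parallelizing with device=None and resolving the device later, the accelerator now forwards its own device, which on Neuron resolves to the XLA device. Below is a rough sketch of the call-site behaviour, not code from this repository; MpPluginLike and prepare are placeholder names, and torch_xla is assumed to be installed.

import torch
import torch_xla.core.xla_model as xm  # assumed available in a Neuron/XLA environment


class MpPluginLike:
    """Placeholder for the model-parallel plugin; here it only moves the model."""

    def parallelize_model(self, model: torch.nn.Module, device=None):
        return model if device is None else model.to(device)


def prepare(model: torch.nn.Module, plugin: MpPluginLike) -> torch.nn.Module:
    device = xm.xla_device()  # the accelerator's device on an XLA backend
    # Forwarding the device means the parallelized layers are materialized directly
    # on the XLA device instead of being created elsewhere and moved afterwards.
    return plugin.parallelize_model(model, device=device)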
optimum/neuron/distributed/base.py (11 additions, 5 deletions)
@@ -394,10 +394,10 @@ def _initialize_or_load_weights(
                     continue
                 else:
                     slices = None
-
-                new_parameter = torch.nn.Parameter(
-                    load_tensor_for_weight(weight_info, tensor_slices=slices).to(parameter.dtype)
-                )
+                weight_data = load_tensor_for_weight(weight_info, tensor_slices=slices).to(parameter.dtype)
+                if device is not None:
+                    weight_data = weight_data.to(device)
+                new_parameter = torch.nn.Parameter(weight_data)
             elif parameter.device != torch.device("meta") and (
                 was_already_initialized_during_parallelization(parameter)
                 or not parameter_can_be_initialized(model, module, attribute_name)
@@ -407,7 +407,8 @@ def _initialize_or_load_weights(
                 continue
             else:
                 # This means that there is no information about where to find the weights for this parameter.
-                device = torch.device("cpu") if device is None else device
+                # We first create the module on CPU, initialize it and then move it on device if needed.
+                device = torch.device("cpu")
                 new_parameter = torch.nn.Parameter(torch.empty_like(parameter, device=device))
                 modules_to_initialize[module].append(attribute_name)

@@ -418,6 +419,7 @@ def _initialize_or_load_weights(
                 )
             tied_weights[parameter] = new_parameter
             new_parameters.add(new_parameter)
+            gc.collect()

         for mod, parameter_names in modules_to_initialize.items():
             if isinstance(mod, torch.nn.Embedding):
@@ -500,6 +502,10 @@ def initialize(mod: GQAQKVColumnParallelLinear, proj_name: str, output_size: int
                 if left_uninitialized and hasattr(mod, "reset_parameters"):
                     initialize_torch_nn_module(mod, parameter_names)

+            if device is not None:
+                mod.to(device)
+            gc.collect()
+
     @classmethod
     @requires_neuronx_distributed
     def _initialize_for_precompilation(
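Taken together, the base.py hunks implement one pattern: checkpoint weights are loaded on CPU, cast, and moved to the target device when one is given; parameters without checkpoint data are allocated and initialized on CPU, the module is moved to the device afterwards, and gc.collect() releases the CPU copies. The following is a minimal sketch of that pattern under simplifying assumptions: a flat module, a plain weights dict, and the name materialize are all illustrative, not this repository's API.

import gc
import torch


def materialize(module: torch.nn.Module, weights: dict, device=None) -> torch.nn.Module:
    # 1) Allocate any parameters still on the meta device on CPU and run the module's own
    #    initializer there, so the RNG behaviour matches a plain CPU initialization.
    for name, param in module.named_parameters(recurse=False):
        if param.device == torch.device("meta"):
            setattr(module, name, torch.nn.Parameter(torch.empty_like(param, device="cpu")))
    if hasattr(module, "reset_parameters"):
        module.reset_parameters()

    # 2) Overwrite with checkpoint weights where they exist, moving them straight to `device`.
    for name, param in module.named_parameters(recurse=False):
        if name in weights:
            data = weights[name].to(param.dtype)
            if device is not None:
                data = data.to(device)
            setattr(module, name, torch.nn.Parameter(data))

    # 3) Move whatever was initialized on CPU to the target (e.g. XLA) device and free
    #    the CPU copies that are no longer referenced.
    if device is not None:
        module.to(device)
    gc.collect()
    return module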
optimum/neuron/distributed/utils.py (19 additions, 5 deletions)
@@ -1100,17 +1100,28 @@ def try_to_hf_initialize(
     `model._init_weights` method. It returns the names of the parameters that were left uninitialized.
     """
-    cached_params_data = {name: param.data.detach().clone().to("cpu") for name, param in mod.named_parameters()}
+    device = torch.device("cpu")
+    for name in parameter_names:
+        param_device = getattr(mod, name).device
+        if param_device != torch.device("meta"):
+            device = param_device
+
+    mod.to("cpu")
+
+    cached_params_data = {name: param.data.detach().clone() for name, param in mod.named_parameters()}
+
+    # We initialize on cpu to have the same RNG state (mostly useful for tests).
     model._init_weights(mod)

     if parameter_names_mapping is None:
         parameter_names_mapping = {}

     reverse_parameter_names_mapping = {v: k for k, v in parameter_names_mapping.items()}

     def name_in_mod(name: str):
         return parameter_names_mapping.get(name, name)

-    dummy_mod = copy.deepcopy(mod).to("cpu")
+    dummy_mod = copy.deepcopy(mod)
     for name in parameter_names:
         getattr(dummy_mod, name_in_mod(name)).random_()
     model._init_weights(dummy_mod)
@@ -1120,15 +1131,15 @@ def name_in_mod(name: str):
     for param_name in parameter_names:
         name = name_in_mod(param_name)
         # The parameter was left unchanged.
-        param_on_cpu = getattr(mod, name).data.to("cpu")
-        if torch.all(param_on_cpu == cached_params_data[name]):
+        param = getattr(mod, name).data
+        if torch.all(param == cached_params_data[name]):
             # There are two possible reasons:
             # 1. The model cannot initialize the module that owns the parameter.
             # 2. The parameter already had the proper value.

             # We check if a dummy copy of the module, filled with random values is modified to know if the model
             # can initialize the module.
-            dummy_param_was_changed = torch.all(getattr(dummy_mod, name).data == param_on_cpu)
+            dummy_param_was_changed = torch.all(getattr(dummy_mod, name).data == param)
             if not dummy_param_was_changed:
                 left_uninitialized.append(param_name)

@@
             param = getattr(mod, name)
             param.data = cached_data

+    # We restore the module back to its original device.
+    mod.to(device)
+
     return left_uninitialized


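The utils.py change makes try_to_hf_initialize behave the same whether the module lives on CPU or on the XLA device: the module's real (non-meta) device is recorded, the module is moved to CPU so that _init_weights runs with the CPU RNG state and the comparisons against the cached values need no per-tensor copies, and the module is moved back at the end. A hedged sketch of that wrapper pattern, with run_init_on_cpu and init_fn as illustrative names rather than the library's API:

import torch


def run_init_on_cpu(mod: torch.nn.Module, parameter_names, init_fn) -> torch.nn.Module:
    # Remember where the module really lives, ignoring parameters still on the meta device.
    device = torch.device("cpu")
    for name in parameter_names:
        param_device = getattr(mod, name).device
        if param_device != torch.device("meta"):
            device = param_device

    mod.to("cpu")  # initialize on CPU so the RNG state is reproducible across devices
    init_fn(mod)
    mod.to(device)  # restore the module to its original (e.g. XLA) device
    return mod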