Merge branch 'main' into fix-broken-compilation

JingyaHuang committed Nov 18, 2024
2 parents 24eb138 + 78e5560 commit dc983df
Showing 4 changed files with 25 additions and 11 deletions.
5 changes: 4 additions & 1 deletion infrastructure/ami/scripts/install-huggingface-libraries.sh
@@ -14,11 +14,14 @@ pip install --upgrade --no-cache-dir \
     "notebook==7.0.6" \
     "markupsafe==2.1.1" \
     "jinja2==3.1.2" \
-    "attrs==23.1.0"
+    "attrs==23.1.0" \
+    "hf_transfer>=0.1.4"
 
 # Temporary fix for the issue: https://github.com/huggingface/optimum-neuron/issues/142
 pip install -U optimum
 echo 'export PATH="${HOME}/.local/bin:$PATH"' >> "${HOME}/.bashrc"
+# Add HF_TRANSFER by default
+echo 'export HF_HUB_ENABLE_HF_TRANSFER=1' >> "${HOME}/.bashrc"
 
 echo "Step: install-and-copy-optimum-neuron-examples"
 git clone -b $OPTIMUM_VERSION https://github.com/huggingface/optimum-neuron.git
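The new hf_transfer pin and the HF_HUB_ENABLE_HF_TRANSFER export work together: huggingface_hub reads that environment variable at import time and, when the hf_transfer package is installed, routes downloads through its Rust backend. A minimal sketch of how a user on the AMI benefits (the repo id is only an illustration, not part of this commit):

import os

# Mirror the export this script appends to ~/.bashrc; it must be set before
# huggingface_hub is imported, since the flag is read at import time.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

from huggingface_hub import snapshot_download

# Any Hub repo works the same way; hf_transfer accelerates the file downloads.
local_dir = snapshot_download("bert-base-uncased")
print(local_dir)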
20 changes: 13 additions & 7 deletions optimum/neuron/distributed/checkpointing.py
@@ -145,15 +145,21 @@ def consolidate_tensor_parallel_checkpoints(
         # This might not be the case anymore when `ParameterMetadata` uses slices.
         sharded_metadata = sharded_metadatas[name]
         if sharded_metadata.is_tied:
-            consolidated_state_dict[original_name] = state_dicts[0][name].to("cpu")
+            consolidated_state_dict[original_name] = state_dicts[0][name].to("cpu").contiguous()
         else:
-            weights = [state_dict[name] for state_dict in state_dicts]
+            # Ensure that all tensors are contiguous before concatenating or further processing
+            weights = [state_dict[name].contiguous() for state_dict in state_dicts]
             tp_size = len(weights)
-            full_weight = torch.cat(
-                weights,
-                dim=sharded_metadata.partition_dim,
-            )
-            full_weight = full_weight.to("cpu")
+
+            full_weight = (
+                torch.cat(
+                    weights,
+                    dim=sharded_metadata.partition_dim,
+                )
+                .to("cpu")
+                .contiguous()
+            )  # Ensure the result is also contiguous
+
             if weight_name in ["weight_k", "weight_v", "bias_k", "bias_v"]:
                 full_weight = (
                     torch.chunk(full_weight, gqa_qkv_metadata["kv_size_multiplier"], dim=0)[0].detach().clone()
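A short sketch (plain CPU tensors, independent of this repository) of why the added .contiguous() calls matter: tensor-parallel shards that arrive as transposed or sliced views are not laid out contiguously in memory, and serializers such as safetensors refuse to write non-contiguous tensors, so forcing contiguous copies before and after concatenation keeps the consolidated checkpoint saveable.

import torch

# A shard that is a transposed view shares storage with its source and is
# not contiguous.
shard = torch.randn(4, 2).t()
print(shard.is_contiguous())  # False

# Making each shard and the concatenated result explicitly contiguous, as the
# change above does, yields densely packed tensors regardless of how the
# shards were produced.
full = torch.cat([shard.contiguous(), shard.contiguous()], dim=0).to("cpu").contiguous()
print(full.is_contiguous())  # True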
9 changes: 7 additions & 2 deletions optimum/neuron/modeling_traced.py
@@ -53,6 +53,7 @@
 NEURON_COMPILER_VERSION = get_neuroncc_version()
 
 if is_neuronx_available():
+    import torch_neuronx
     from torch_neuronx import move_trace_to_device
 
     NEURON_COMPILER_TYPE = "neuronx-cc"
@@ -127,8 +128,12 @@ def load_model(
         if path.is_file():
             model = torch.jit.load(path)
             # For non-inlined models, send the module manually to device. This is important for weights/neff non-inlined module since when loading the module, the neff is automatically moved to Neuron but not the weights. We need to move the weights to Neuron as well manually to avoid great host to device IO penalty.
-            if is_neuronx_available() and to_neuron:
-                move_trace_to_device(model, device_id)
+            if is_neuronx_available():
+                torch_neuronx.experimental.set_neuron_cores(
+                    model, start_nc=0, nc_count=1
+                )  # The inputs are allocated to nc:0 by default, this line ensures both input tensors and the model are on the same core.
+                if to_neuron:
+                    move_trace_to_device(model, device_id)
         return model
 
     def replace_weights(self, weights: Optional[Union[Dict[str, torch.Tensor], torch.nn.Module]] = None):
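A hedged usage sketch of the new placement logic (it assumes an Inferentia/Trainium host with torch-neuronx installed and a previously traced model at a hypothetical path model.neuron):

import torch
import torch_neuronx
from torch_neuronx import move_trace_to_device

# Load a traced TorchScript module produced by neuronx tracing.
model = torch.jit.load("model.neuron")

# Pin the module to NeuronCore 0: inputs are allocated to nc:0 by default,
# so this keeps the weights and the input tensors on the same core.
torch_neuronx.experimental.set_neuron_cores(model, start_nc=0, nc_count=1)

# For weights/NEFF non-inlined modules, also move the weights to the device
# to avoid a large host-to-device transfer on the first inference call.
move_trace_to_device(model, 0)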
2 changes: 1 addition & 1 deletion optimum/neuron/utils/peft_utils.py
@@ -176,7 +176,7 @@ def state_dict(self):
 
         adapter_shards_dir_model = os.path.join(output_dir, "adapter_shards", "model")
         if not os.path.isdir(adapter_shards_dir_model):
-            os.makedirs(adapter_shards_dir_model)
+            os.makedirs(adapter_shards_dir_model, exist_ok=True)
 
         dummy_mod = DummyModule()
         neuronx_distributed.trainer.save_checkpoint(
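A tiny sketch (plain Python, no Neuron dependencies, illustrative path) of the race this one-line change guards against: when several distributed workers pass the isdir check at the same moment, every worker except the first would crash in os.makedirs unless exist_ok=True makes the call idempotent.

import os

path = os.path.join("output", "adapter_shards", "model")  # illustrative path

# A worker that loses the race between the isdir check and makedirs would
# raise FileExistsError without exist_ok=True.
if not os.path.isdir(path):
    os.makedirs(path, exist_ok=True)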
