Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
JingyaHuang committed Oct 25, 2024
1 parent 6a6bb03 commit de9e582
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions optimum/neuron/modeling_traced.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
NEURON_COMPILER_VERSION = get_neuroncc_version()

if is_neuronx_available():
import torch_neuronx
from torch_neuronx import move_trace_to_device

NEURON_COMPILER_TYPE = "neuronx-cc"
Expand Down Expand Up @@ -127,8 +128,12 @@ def load_model(
if path.is_file():
model = torch.jit.load(path)
# For non-inlined models, send the module to the device manually. This is important for modules whose weights/neff are not inlined: when such a module is loaded, the neff is automatically moved to Neuron but the weights are not. We need to move the weights to Neuron manually as well, to avoid a large host-to-device I/O penalty.
if is_neuronx_available() and to_neuron:
move_trace_to_device(model, device_id)
if is_neuronx_available():
torch_neuronx.experimental.set_neuron_cores(
model, start_nc=0, nc_count=1
) # The inputs are allocated to nc:0 by default; this line ensures that both the input tensors and the model are on the same core.
if to_neuron:
move_trace_to_device(model, device_id)
return model

def replace_weights(self, weights: Optional[Union[Dict[str, torch.Tensor], torch.nn.Module]] = None):
Expand Down

0 comments on commit de9e582

Please sign in to comment.