Skip to content

Commit d59a88d

Browse files
authored
Push install deeper in generate and undo import order change for etmodel (#1542)
1 parent 98a5ac7 commit d59a88d

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

torchchat/generate.py

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -869,13 +869,6 @@ def _gen_model_input(
869869
max_new_tokens: Optional[int] = None,
870870
max_seq_len: Optional[int] = 2048,
871871
) -> Tuple[torch.Tensor, Optional[Dict[str, Any]]]:
872-
# torchtune model definition dependencies
873-
from torchtune.data import Message, padded_collate_tiled_images_and_mask
874-
from torchtune.models.llama3_2_vision._model_builders import (
875-
llama3_2_vision_transform,
876-
)
877-
from torchtune.training import set_default_dtype
878-
879872
"""
880873
Convert prompt and image prompts into consumable model input args.
881874
@@ -911,6 +904,14 @@ def _gen_model_input(
911904
return encoded, None
912905

913906
# Llama 3.2 11B
907+
908+
# torchtune model definition dependencies
909+
from torchtune.data import Message, padded_collate_tiled_images_and_mask
910+
from torchtune.models.llama3_2_vision._model_builders import (
911+
llama3_2_vision_transform,
912+
)
913+
from torchtune.training import set_default_dtype
914+
914915
assert (
915916
image_prompts is None or len(image_prompts) == 1
916917
), "At most one image is supported at the moment"

torchchat/model.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1054,13 +1054,13 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
10541054
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
10551055

10561056
try:
1057-
# For llama::sdpa_with_kv_cache.out, preprocess ops
1058-
from executorch.extension.llm.custom_ops import custom_ops # no-qa
10591057
from executorch.extension.pybindings import portable_lib as exec_lib
10601058

10611059
# ET changed the way it's loading the custom ops so it's not included in portable_lib but has to be loaded separately.
10621060
# For quantized_decomposed ops
10631061
from executorch.kernels import quantized # no-qa
1062+
# For llama::sdpa_with_kv_cache.out, preprocess ops
1063+
from executorch.extension.llm.custom_ops import custom_ops # no-qa
10641064

10651065
class PTEModel(nn.Module):
10661066
def __init__(self, config, path) -> None:

0 commit comments

Comments (0)