Commit 660494a

Add Ascend NPU support for generate and chat
1 parent 3c7e839 commit 660494a

15 files changed: +441 -68 lines

Diff for: README.md (+5 -3)

@@ -582,11 +582,13 @@ We provide

 ## Community Contributions

-We really value our community and the contributions made by our wonderful users. We'll use this section to call out some of these contributions! If you'd like to help out as well, please see the [CONTRIBUTING](CONTRIBUTING.md) guide.
+We really value our community and the contributions made by our wonderful users!

-To connect with us and other community members, we invite you to join our Slack community by filling out this [form](https://docs.google.com/forms/d/e/1FAIpQLSeADnUNW36fjKjYzyHDOzEB_abKQE9b6gqqW9NXse6O0MWh0A/viewform). Once you've joined, you can:
+If you'd like to help out, connect with us and other community members by joining our [Discord](https://discord.gg/hm2Keduk3v). Once you've joined, you can:
 * Head to the `#torchchat-general` channel for general questions, discussion, and community support.
-* Join the `#torchchat-contributors` channel if you're interested in contributing directly to project development.
+* Hop in the `#torchchat-contributors` channel if you're interested in contributing directly to project development.
+
+Also give our [CONTRIBUTING](CONTRIBUTING.md) guide a read.

 Looking forward to discussing with you about torchchat future!

Diff for: install/.pins/et-pin.txt (+1 -1)

@@ -1 +1 @@
-791472d6706b027552f39f11b28d034e4839c9af
+73740e9268a4a47baeaedc58a1f75597038d2377

Diff for: install/install_requirements.sh (+19 -9)

@@ -51,26 +51,29 @@ echo "Using pip executable: $PIP_EXECUTABLE"
 # NOTE: If a newly-fetched version of the executorch repo changes the value of
 # PYTORCH_NIGHTLY_VERSION, you should re-run this script to install the necessary
 # package versions.
-PYTORCH_NIGHTLY_VERSION=dev20250131
+PYTORCH_NIGHTLY_VERSION=dev20250327

 # Nightly version for torchvision
-VISION_NIGHTLY_VERSION=dev20250131
+VISION_NIGHTLY_VERSION=dev20250327

 # Nightly version for torchtune
-TUNE_NIGHTLY_VERSION=dev20250131
+TUNE_NIGHTLY_VERSION=dev20250327

 # The pip repository that hosts nightly torch packages. cpu by default.
 # If cuda is available, based on presence of nvidia-smi, install the pytorch nightly
 # with cuda for faster execution on cuda GPUs.
 if [[ -x "$(command -v nvidia-smi)" ]];
 then
-  TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/cu124"
+  TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/cu126"
 elif [[ -x "$(command -v rocminfo)" ]];
 then
   TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/rocm6.2"
 elif [[ -x "$(command -v xpu-smi)" ]];
 then
   TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/xpu"
+elif [[ -x "$(command -v npu-smi)" ]]
+then
+  TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/test/cpu"
 else
   TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/cpu"
 fi

@@ -79,15 +82,22 @@ fi
 if [[ -x "$(command -v xpu-smi)" ]];
 then
   REQUIREMENTS_TO_INSTALL=(
-    torch=="2.7.0.${PYTORCH_NIGHTLY_VERSION}"
+    torch=="2.8.0.${PYTORCH_NIGHTLY_VERSION}"
     torchvision=="0.22.0.${VISION_NIGHTLY_VERSION}"
-    #torchtune=="0.6.0" # no 0.6.0 on xpu nightly
+    #torchtune=="0.7.0" # no 0.6.0 on xpu nightly
+  )
+elif [[ -x "$(command -v npu-smi)" ]];
+then
+  REQUIREMENTS_TO_INSTALL=(
+    torch=="2.7.0"
+    torchvision=="0.22.0"
+    torchtune=="0.6.0"
   )
 else
   REQUIREMENTS_TO_INSTALL=(
-    torch=="2.7.0.${PYTORCH_NIGHTLY_VERSION}"
+    torch=="2.8.0.${PYTORCH_NIGHTLY_VERSION}"
     torchvision=="0.22.0.${VISION_NIGHTLY_VERSION}"
-    torchtune=="0.6.0.${TUNE_NIGHTLY_VERSION}"
+    torchtune=="0.7.0.${TUNE_NIGHTLY_VERSION}"
   )
 fi

@@ -136,5 +146,5 @@ if [[ -x "$(command -v nvidia-smi)" ]]; then
 fi
 (
   set -x
-  $PIP_EXECUTABLE install evaluate=="0.4.3" lm-eval=="0.4.2" psutil=="6.0.0"
+  $PIP_EXECUTABLE install evaluate=="0.4.3" lm-eval=="0.4.7" psutil=="6.0.0"
 )
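
Worth noting for the `npu-smi` branch above: it pulls wheels from the `test/cpu` index and pins stable versions rather than nightlies. Below is a minimal post-install sanity check; it is my own sketch, not part of the script, and it assumes Ascend's separate `torch_npu` plugin (not installed here) is what registers the `torch.npu` namespace.

```python
# Minimal post-install sanity check (assumption: the Ascend `torch_npu` plugin,
# installed separately from this script, registers `torch.npu`).
import torch

try:
    import torch_npu  # noqa: F401  # optional Ascend plugin
except ImportError:
    torch_npu = None

print("torch:", torch.__version__)
has_npu = hasattr(torch, "npu") and torch.npu.is_available()
print("NPU available:", has_npu)
```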

Diff for: install/requirements.txt (+2 -2)

@@ -23,7 +23,7 @@ openai

 # Build tools
 wheel
-cmake>=3.24
+cmake>=3.24, < 4.0.0 # 4.0 is BC breaking
 ninja
 zstd

@@ -34,4 +34,4 @@ streamlit
 flask

 # eval
-lm_eval==0.4.2
+lm_eval==0.4.7

Diff for: runner/run.cpp (+23 -1)

@@ -53,6 +53,9 @@ using executorch::extension::TensorPtr;
 using torch::executor::EValue;
 using torch::executor::Module;
 using torch::executor::Result;
+using executorch::runtime::MemoryManager;
+using executorch::runtime::MemoryAllocator;
+using executorch::runtime::Error;
 #endif

 using tokenizers::SPTokenizer;

@@ -867,7 +870,26 @@ int main(int argc, char *argv[]) {
       : torch::Device(torch::kCUDA);
   ModelType model_type = get_model_type(std::stoi(aoti_metadata["tokenizer_type"]));
 #else // __ET_MODEL__
-  ModelType model_type = get_model_type(llama_ver);
+  Error load_status = transformer.runner->load();
+  ET_CHECK_MSG(
+      load_status == torch::executor::Error::Ok,
+      "program::load() failed with status 0x%" PRIx32,
+      static_cast<uint32_t>(load_status));
+
+  static std::array<uint8_t, 4 * 1024U * 1024U> method_allocator_pool; // 4MB
+  MemoryAllocator method_allocator{MemoryAllocator(
+      sizeof(method_allocator_pool), method_allocator_pool.data())};
+  MemoryManager memory_manager(&method_allocator, nullptr);
+  auto tokenizer_method = transformer.runner->program()->load_method("tokenizer_type", &memory_manager);
+
+  Error execute_status = tokenizer_method->execute();
+  ET_CHECK_MSG(
+      execute_status == torch::executor::Error::Ok,
+      "method::execute() failed with status 0x%" PRIx32,
+      static_cast<uint32_t>(execute_status));
+
+  auto tokenizer_type = tokenizer_method->get_output(0).toInt();
+  ModelType model_type = get_model_type(tokenizer_type);
 #endif

   if (model_type == UNKNOWN_MODEL) {

Diff for: torchchat/cli/builder.py (+5 -3)

@@ -29,7 +29,7 @@
 from torchchat.utils.build_utils import (
     device_sync,
     is_cpu_device,
-    is_cuda_or_cpu_or_xpu_device,
+    is_supported_device,
     name_to_dtype,
 )
 from torchchat.utils.measure_time import measure_time

@@ -78,6 +78,8 @@ def __post_init__(self):
                 self.device = "cuda"
             elif torch.xpu.is_available():
                 self.device = "xpu"
+            elif hasattr(torch, "npu") and torch.npu.is_available():
+                self.device = "npu"
             else:
                 self.device = "cpu"

@@ -539,7 +541,7 @@ def _initialize_model(
         _set_gguf_kwargs(builder_args, is_et=is_pte, context="generate")

     if builder_args.dso_path:
-        if not is_cuda_or_cpu_or_xpu_device(builder_args.device):
+        if not is_supported_device(builder_args.device):
             print(
                 f"Cannot load specified DSO to {builder_args.device}. Attempting to load model to CPU instead"
             )

@@ -573,7 +575,7 @@ def do_nothing(max_batch_size, max_seq_length):
             raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}")

     elif builder_args.aoti_package_path:
-        if not is_cuda_or_cpu_or_xpu_device(builder_args.device):
+        if not is_supported_device(builder_args.device):
             print(
                 f"Cannot load specified PT2 to {builder_args.device}. Attempting to load model to CPU instead"
             )
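
The hunk in `__post_init__` extends automatic device resolution. Here is a standalone sketch of that resolution order; the helper is hypothetical and covers only the portion visible in the hunk. The `hasattr` guard matters because `torch.npu` only exists once Ascend's `torch_npu` plugin has been imported.

```python
# Sketch of the device resolution order the diff establishes: CUDA, then XPU,
# then Ascend NPU, falling back to CPU. Hypothetical helper, not torchchat API.
import torch


def resolve_default_device() -> str:
    if torch.cuda.is_available():
        return "cuda"
    if torch.xpu.is_available():
        return "xpu"
    # torch.npu appears only after the Ascend torch_npu plugin is imported,
    # hence the hasattr guard before touching it.
    if hasattr(torch, "npu") and torch.npu.is_available():
        return "npu"
    return "cpu"
```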

Diff for: torchchat/cli/cli.py (+10 -2)

@@ -176,8 +176,8 @@ def _add_model_config_args(parser, verb: str) -> None:
         "--device",
         type=str,
         default=None,
-        choices=["fast", "cpu", "cuda", "mps", "xpu"],
-        help="Hardware device to use. Options: fast, cpu, cuda, mps, xpu",
+        choices=["fast", "cpu", "cuda", "mps", "xpu", "npu"],
+        help="Hardware device to use. Options: fast, cpu, cuda, mps, xpu, npu",
     )
     model_config_parser.add_argument(
         "--attention-backend",

@@ -432,6 +432,14 @@ def _add_evaluation_args(parser) -> None:
         help="Maximum length sequence to evaluate",
     )

+    eval_parser.add_argument(
+        "--modality",
+        type=str,
+        default="text",
+        choices=["text", "text-image"],
+        help="Modality of the model. Options: text, text-image",
+    )
+

 # Add CLI Args related to distributed inference
 # This feature is currently a [WIP] and hidden from --help
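
To see the effect of the widened `--device` choices and the new `--modality` evaluation flag, here is a self-contained argparse sketch that mirrors the added options (a hypothetical parser, not torchchat's actual CLI code):

```python
# Illustrative parser mirroring the options this diff adds.
import argparse

parser = argparse.ArgumentParser("torchchat-sketch")
parser.add_argument(
    "--device",
    type=str,
    default=None,
    choices=["fast", "cpu", "cuda", "mps", "xpu", "npu"],
    help="Hardware device to use.",
)
parser.add_argument(
    "--modality",
    type=str,
    default="text",
    choices=["text", "text-image"],
    help="Modality of the model.",
)

args = parser.parse_args(["--device", "npu", "--modality", "text"])
print(args.device, args.modality)  # -> npu text
```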

Diff for: torchchat/export.py (+24 -12)

@@ -313,7 +313,7 @@ def export_to_edge(
         core_aten_ep, edge_constant_methods, edge_compile_config, verbose=verbose
     )

-def export_for_et(model, device, output_path) -> str:
+def export_for_et(model, device, output_path, edge_constant_methods) -> str:

     input = (
         torch.tensor([[1]], dtype=torch.long, device=device),

@@ -344,12 +344,15 @@ def export_for_et(model, device, output_path) -> str:
     with torch.nn.attention.sdpa_kernel(
         [torch.nn.attention.SDPBackend.MATH]
     ), torch.no_grad():
-        m = export_for_training(model, input, dynamic_shapes=dynamic_shapes).module()
+        m = export_for_training(
+            model, input, dynamic_shapes=dynamic_shapes
+        ).module()

         edge_manager = export_to_edge(
             m,
             input,
             dynamic_shapes=dynamic_shapes,
+            edge_constant_methods=edge_constant_methods,
             edge_compile_config=edge_config,
         )
     edge_manager = edge_manager.to_backend(XnnpackDynamicallyQuantizedPartitioner())

@@ -365,6 +368,7 @@ def export_for_et(model, device, output_path) -> str:
     )

     print("The methods are: ", export_program.methods)
+    print("The config methods are: ", export_program.config_methods)
     with open(output_path, "wb") as f:
         export_program.write_to_file(f)

@@ -407,7 +411,9 @@ def main(args):
             f"Warning! ExecuTorch export target is controlled by export recipe, not device setting. Ignoring device={builder_args.device} setting."
         )
         builder_args.device = "cpu"
-    elif (output_pte_path or output_dso_path or output_aoti_package_path) and "mps" in builder_args.device:
+    elif (
+        output_pte_path or output_dso_path or output_aoti_package_path
+    ) and "mps" in builder_args.device:
         print("Warning! Device MPS not supported for export. Exporting for device CPU.")
         builder_args.device = "cpu"

@@ -473,13 +479,26 @@ def main(args):
             support_tensor_subclass=False,
         )
         _unset_gguf_kwargs(builder_args)
-
+
+    if tokenizer_args is None:
+        tokenizer_type = "0"
+    elif tokenizer_args.is_sentencepiece:
+        tokenizer_type = "2" # Corresponding to llama2
+    else:
+        tokenizer_type = "3" # Corresponding to llama3
+
     with torch.no_grad():
         if output_pte_path:
             output_pte_path = str(os.path.abspath(output_pte_path))
             if executorch_export_available:
                 print(f"Exporting model using ExecuTorch to {output_pte_path}")
-                export_for_et(model_to_pte, builder_args.device, args.output_pte_path)
+                print(f"Tokenizer type is {tokenizer_type}")
+                export_for_et(
+                    model_to_pte,
+                    builder_args.device,
+                    args.output_pte_path,
+                    {"tokenizer_type": int(tokenizer_type)},
+                )
             else:
                 print(
                     "Export with executorch requested but ExecuTorch could not be loaded"

@@ -503,13 +522,6 @@ def main(args):
         if output_aoti_package_path:
             output_aoti_package_path = str(os.path.abspath(output_aoti_package_path))

-            if tokenizer_args is None:
-                tokenizer_type = "0"
-            elif tokenizer_args.is_sentencepiece:
-                tokenizer_type = "2" # Corresponding to llama2
-            else:
-                tokenizer_type = "3" # Corresponding to llama3
-
             metadata = {"tokenizer_type": tokenizer_type}
             print(
                 "Exporting model using AOT Inductor to " f"{output_aoti_package_path}."

Diff for: torchchat/generate.py (+8 -3)

@@ -1213,6 +1213,8 @@ def callback(x, *, done_generating=False):
                 print(prof.key_averages().table(sort_by="self_cpu_time_total"))
             elif self.builder_args.device == "cuda":
                 print(prof.key_averages().table(sort_by="self_cuda_time_total"))
+            elif self.builder_args.device == "npu":
+                print(prof.key_averages().table(sort_by="self_npu_time_total"))
             else:
                 print(prof.key_averages().table(sort_by="self_xpu_time_total"))
             prof.export_chrome_trace(f"{self.profile}.json")

@@ -1299,8 +1301,10 @@ def callback(x, *, done_generating=False):
         )
         if torch.cuda.is_available():
             print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
-        if torch.xpu.is_available():
+        elif torch.xpu.is_available():
            print(f"Memory used: {torch.xpu.max_memory_reserved() / 1e9:.02f} GB")
+        elif hasattr(torch, "npu") and torch.npu.is_available():
+            print(f"Memory used: {torch.npu.max_memory_reserved() / 1e9:.02f} GB")

@@ -1595,7 +1599,6 @@ def sample(

     return idx_next, probs

-
 def run_generator(
     args,
     rank: Optional[int] =None

@@ -1628,8 +1631,10 @@ def run_generator(
     )
     if torch.cuda.is_available():
         torch.cuda.reset_peak_memory_stats()
-    if torch.xpu.is_available():
+    elif torch.xpu.is_available():
         torch.xpu.reset_peak_memory_stats()
+    elif hasattr(torch, "npu") and torch.npu.is_available():
+        torch.npu.reset_peak_memory_stats()

     for _ in gen.chat(generator_args):
         pass
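
The memory-stats and profiler changes all follow one pattern: probe at most one accelerator backend, preferring CUDA, then XPU, then NPU. A small sketch of that pattern as reusable helpers (hypothetical, not torchchat code):

```python
# Hypothetical helpers mirroring the accelerator-memory pattern in the diff.
import torch


def _npu_ready() -> bool:
    # torch.npu exists only after Ascend's torch_npu plugin has been imported.
    return hasattr(torch, "npu") and torch.npu.is_available()


def reset_peak_memory() -> None:
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    elif torch.xpu.is_available():
        torch.xpu.reset_peak_memory_stats()
    elif _npu_ready():
        torch.npu.reset_peak_memory_stats()


def report_peak_memory() -> None:
    if torch.cuda.is_available():
        print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
    elif torch.xpu.is_available():
        print(f"Memory used: {torch.xpu.max_memory_reserved() / 1e9:.02f} GB")
    elif _npu_ready():
        print(f"Memory used: {torch.npu.max_memory_reserved() / 1e9:.02f} GB")
```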

Diff for: torchchat/model.py (+6)

@@ -608,6 +608,12 @@ def setup_caches(self, batch_size, dtype, encoder_max_seq_len, decoder_max_seq_l
             decoder_max_seq_len=decoder_max_seq_len,
         )

+    def caches_are_setup(self) -> bool:
+        return self.model.caches_are_setup()
+
+    def caches_are_enabled(self) -> bool:
+        return self.model.caches_are_enabled()
+
     def reset_caches(self):
         self.model.reset_caches()
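
These two additions are thin pass-throughs, presumably so that code probing the torchchat wrapper for KV-cache state sees the inner torchtune model's answer instead of hitting an AttributeError. A minimal sketch of the delegation pattern, with a hypothetical wrapper name:

```python
# Minimal sketch of the delegation pattern; DecoderWrapper is hypothetical,
# standing in for torchchat's torchtune-backed model wrapper.
class DecoderWrapper:
    def __init__(self, model):
        self.model = model  # e.g. a torchtune-style decoder

    def caches_are_setup(self) -> bool:
        return self.model.caches_are_setup()

    def caches_are_enabled(self) -> bool:
        return self.model.caches_are_enabled()

    def reset_caches(self):
        self.model.reset_caches()
```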
