2 changes: 1 addition & 1 deletion .ci/docker/ci_commit_pins/optimum-executorch.txt
@@ -1 +1 @@
0123293118efb08ac4ffc4fefe9d330201465c93
de4f3c4978b4d36cc0bb8f87c6877a4a040d7ae7
34 changes: 32 additions & 2 deletions .ci/scripts/test_huggingface_optimum_model.py
@@ -170,6 +170,35 @@ def test_text_generation(model_id, model_dir, recipe, *, quantize=True, run_only
assert check_causal_lm_output_quality(model_id, generated_tokens) is True


def get_tokenizer_path(model_dir: str, saved_files: tuple) -> str:
"""
Determine the tokenizer path based on files saved by tokenizer.save_pretrained().

Args:
model_dir: The directory where tokenizer files were saved
saved_files: Tuple of file paths returned by tokenizer.save_pretrained()

Returns:
The path to use for loading the tokenizer (either a specific file or directory)

Raises:
ValueError: If no supported tokenizer file format is found
"""
saved_filenames = {Path(f).name for f in saved_files}

if "tokenizer.model" in saved_filenames:
return f"{model_dir}/tokenizer.model"

if "tokenizer.json" in saved_filenames:
return model_dir

# No supported tokenizer format found
raise ValueError(
f"Unsupported tokenizer format. Expected 'tokenizer.model' (SentencePiece) "
f"or 'tokenizer.json' (HuggingFace) but found: {saved_filenames}"
)


def test_llm_with_image_modality(
model_id, model_dir, recipe, *, quantize=True, run_only=False
):
@@ -196,7 +225,8 @@ def test_llm_with_image_modality(
cli_export(command, model_dir)

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.save_pretrained(model_dir)
saved_files = tokenizer.save_pretrained(model_dir)
tokenizer_path = get_tokenizer_path(model_dir, saved_files)

# input
processor = AutoProcessor.from_pretrained(model_id)
@@ -232,7 +262,7 @@ def test_llm_with_image_modality(

from executorch.extension.llm.runner import GenerationConfig, MultimodalRunner

runner = MultimodalRunner(f"{model_dir}/model.pte", f"{model_dir}/tokenizer.model")
runner = MultimodalRunner(f"{model_dir}/model.pte", tokenizer_path)
generated_text = runner.generate_text_hf(
inputs,
GenerationConfig(max_new_tokens=128, temperature=0, echo=False),
2 changes: 1 addition & 1 deletion .ci/scripts/test_model_e2e.sh
@@ -86,7 +86,7 @@ case "$HF_MODEL" in
MODEL_NAME="voxtral"
RUNNER_TARGET="voxtral_runner"
RUNNER_PATH="voxtral"
EXPECTED_OUTPUT="poem"
EXPECTED_OUTPUT="contemplating"
PREPROCESSOR="voxtral_preprocessor.pte"
TOKENIZER_URL="https://huggingface.co/mistralai/Voxtral-Mini-3B-2507/resolve/main" # @lint-ignore
TOKENIZER_FILE="tekken.json"
2 changes: 1 addition & 1 deletion .github/workflows/trunk.yml
@@ -839,7 +839,7 @@ jobs:
qwen3-0.6b|xnnpack|--quantize,
qwen3-1.7b|xnnpack|--quantize,
gemma3-1b|xnnpack|--quantize,
phi4-mini|xnnpack|--quantize,
# phi4-mini|xnnpack|--quantize, transformers v5.0.0rc0 introduces data-dependent branching in transformers/modeling_rope_utils.py:61
smollm2-135m|xnnpack|--quantize,
smollm3-3b|xnnpack|--quantize
]
2 changes: 2 additions & 0 deletions backends/aoti/aoti_delegate_handle.h
@@ -10,6 +10,7 @@

#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/evalue.h>
#include <string>

namespace executorch {
namespace backends {
@@ -85,6 +86,7 @@ struct AOTIDelegateHandle {
AOTInductorModelContainerHandle container_handle;
void* cuda_stream; // cudaStream_t stored as void* to avoid CUDA header
// dependency
std::string method_name;

// Function pointers specific to this handle's shared library
AOTInductorModelContainerCreateWithDeviceFunc create_with_device;
110 changes: 96 additions & 14 deletions backends/cuda/runtime/cuda_backend.cpp
@@ -8,13 +8,16 @@

#include <cuda_runtime.h>
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/backend/options.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
#include <cstdio>

#include <array>
#include <filesystem>
#include <fstream>
#include <mutex>
#include <string>
#include <vector>

@@ -35,20 +38,54 @@ using executorch::runtime::ArrayRef;
using executorch::runtime::Backend;
using executorch::runtime::BackendExecutionContext;
using executorch::runtime::BackendInitContext;
using executorch::runtime::BackendOption;
using executorch::runtime::BackendOptionContext;
using executorch::runtime::CompileSpec;
using executorch::runtime::DelegateHandle;
using executorch::runtime::Error;
using executorch::runtime::EValue;
using executorch::runtime::FreeableBuffer;
using executorch::runtime::kMaxOptionValueLength;
using executorch::runtime::MemoryAllocator;
using executorch::runtime::NamedDataMap;
using executorch::runtime::Result;
using executorch::runtime::Span;
using executorch::runtime::etensor::Tensor;

namespace {
constexpr char kSkipCopyOutputToCpuForMethod[] =
"skip_copy_output_to_cpu_for_method";
}

class ET_EXPERIMENTAL CudaBackend final
: public ::executorch::runtime::BackendInterface {
private:
void set_skip_copy_method(
const std::array<char, kMaxOptionValueLength>& raw) {
std::lock_guard<std::mutex> guard(skip_copy_method_mutex_);
skip_copy_method_ = std::string(raw.data());
}

std::array<char, kMaxOptionValueLength> get_skip_copy_method_as_option()
const {
std::array<char, kMaxOptionValueLength> out{};
std::string value;
{
std::lock_guard<std::mutex> guard(skip_copy_method_mutex_);
value = skip_copy_method_;
}
std::snprintf(out.data(), out.size(), "%s", value.c_str());
return out;
}

bool should_skip_copy_for_method(const std::string& method_name) const {
if (method_name.empty()) {
return false;
}
std::lock_guard<std::mutex> guard(skip_copy_method_mutex_);
return method_name == skip_copy_method_;
}

Error load_function_pointers_into_handle(
void* so_handle,
AOTIDelegateHandle* handle) const {
@@ -91,6 +128,38 @@ class ET_EXPERIMENTAL CudaBackend final
return 1;
}

Error set_option(
ET_UNUSED BackendOptionContext& context,
const executorch::runtime::Span<BackendOption>& backend_options)
override {
for (const auto& option : backend_options) {
if (std::strcmp(option.key, kSkipCopyOutputToCpuForMethod) == 0) {
if (auto* val = std::get_if<std::array<char, kMaxOptionValueLength>>(
&option.value)) {
set_skip_copy_method(*val);
} else {
ET_LOG(
Error,
"Option %s must be a method name string.",
kSkipCopyOutputToCpuForMethod);
return Error::InvalidArgument;
}
}
}
return Error::Ok;
}

Error get_option(
ET_UNUSED BackendOptionContext& context,
executorch::runtime::Span<BackendOption>& backend_options) override {
for (auto& option : backend_options) {
if (std::strcmp(option.key, kSkipCopyOutputToCpuForMethod) == 0) {
option.value = get_skip_copy_method_as_option();
}
}
return Error::Ok;
}

// Once per loaded binary blob
Result<DelegateHandle*> init(
BackendInitContext& context,
@@ -159,6 +228,7 @@ class ET_EXPERIMENTAL CudaBackend final
AOTIDelegateHandle* handle = new AOTIDelegateHandle();
handle->so_handle = lib_handle;
handle->so_path = so_path.string();
handle->method_name = method_name;

// Load function pointers specific to this handle's shared library
ET_CHECK_OK_OR_RETURN_ERROR(
@@ -224,7 +294,7 @@ class ET_EXPERIMENTAL CudaBackend final

// Process input tensors: ExecuTorch provides CPU tensors, create GPU
// copies
for (int i = 0; i < n_inputs; i++) {
for (size_t i = 0; i < n_inputs; i++) {
// Get tensor dimensions and properties from ExecuTorch CPU tensor
auto cpu_tensor = &(args[i]->toTensor());
auto sizes = cpu_tensor->sizes();
@@ -260,7 +330,7 @@ class ET_EXPERIMENTAL CudaBackend final
}
// Process output tensors: create GPU counterparts for ExecuTorch CPU
// tensors
for (int i = 0; i < n_outputs; i++) {
for (size_t i = 0; i < n_outputs; i++) {
// Get output tensor dimensions from ExecuTorch CPU tensor
auto cpu_output_tensor = &(args[i + n_inputs]->toTensor());
auto sizes = cpu_output_tensor->sizes();
@@ -303,18 +373,26 @@ class ET_EXPERIMENTAL CudaBackend final
"AOTInductorModelContainerRun failed with error code %d",
error);

// Copy GPU output results back to CPU output tensors
for (int i = 0; i < n_outputs; i++) {
auto cpu_output_tensor = &(args[i + n_inputs]->toTensor());
// For DYNAMIC_BOUND tensors we try to resize
ET_CHECK_OK_OR_RETURN_ERROR(
resize_tensor(*cpu_output_tensor, gpu_outputs[i]->sizes()),
"Error resizing tensor at output index %d",
i);
ET_CHECK_OK_OR_RETURN_ERROR(
aoti_torch_copy_(cpu_output_tensor, gpu_outputs[i], 0),
"Failed to copy GPU output %d back to CPU",
i);
const bool copy_outputs = !should_skip_copy_for_method(handle->method_name);

if (copy_outputs) {
// Copy GPU output results back to CPU output tensors
for (size_t i = 0; i < n_outputs; i++) {
auto cpu_output_tensor = &(args[i + n_inputs]->toTensor());
// For DYNAMIC_BOUND tensors we try to resize
ET_CHECK_OK_OR_RETURN_ERROR(
resize_tensor(*cpu_output_tensor, gpu_outputs[i]->sizes()),
"Error resizing tensor at output index %d",
i);
ET_CHECK_OK_OR_RETURN_ERROR(
aoti_torch_copy_(cpu_output_tensor, gpu_outputs[i], 0),
"Failed to copy GPU output %d back to CPU",
i);
}
} else {
for (size_t i = 0; i < n_outputs; i++) {
args[i + n_inputs]->toTensor() = *gpu_outputs[i];
}
}

return Error::Ok;
@@ -365,6 +443,10 @@ class ET_EXPERIMENTAL CudaBackend final
delete handle;
clear_all_tensors();
}

private:
mutable std::mutex skip_copy_method_mutex_;
Review thread:
Contributor: Do you need a mutex at all? We require callers of ExecuTorch to be aware of thread safety, and don't guarantee any thread safety within the internals of ET.
Contributor (Author): I see the same pattern in the XNNPACK backend (mutable std::mutex weights_cache_mutex_;), so I would like to keep it consistent.

std::string skip_copy_method_;
Review thread:
Contributor: Shouldn't this be an array of strings, since we may want to skip the copy for multiple methods? It's fine to update that in a follow-up PR.
Contributor (Author): Yeah, the next PR will support a comma-separated string.
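For illustration only (not part of this PR): a rough sketch of how a follow-up could split a comma-separated option value into method names. The option key and the single-string storage above are from this PR; the helper name and parsing below are hypothetical.

// Hypothetical follow-up sketch: split a comma-separated option value into method names.
#include <set>
#include <sstream>
#include <string>

std::set<std::string> parse_skip_copy_methods(const std::string& value) {
  std::set<std::string> methods;
  std::stringstream ss(value);
  std::string item;
  while (std::getline(ss, item, ',')) {
    if (!item.empty()) {
      methods.insert(item);
    }
  }
  return methods;
}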

};

} // namespace executorch::backends::cuda
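For reference, a minimal client-side sketch of the new backend option, mirroring the usage added to extension/asr/runner/runner.cpp later in this PR. The wrapper function name is hypothetical; the option key, BackendOptions, and set_option calls are as used in this PR.

// Sketch: ask CudaBackend to keep the named method's outputs on the GPU
// (i.e. skip the device-to-host copy in execute()).
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/backend/options.h>

using executorch::runtime::BackendOptions;
using executorch::runtime::Error;

Error configure_cuda_skip_copy(const char* method_name) {
  BackendOptions<1> opts;
  // Key must match kSkipCopyOutputToCpuForMethod in cuda_backend.cpp.
  Error err = opts.set_option("skip_copy_output_to_cpu_for_method", method_name);
  if (err != Error::Ok) {
    return err;
  }
  // Deliver the option to the registered CUDA backend.
  return executorch::runtime::set_option("CudaBackend", opts.view());
}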
8 changes: 3 additions & 5 deletions examples/qualcomm/oss_scripts/mobilevit_v1.py
@@ -24,7 +24,7 @@
)
from PIL import Image
from torchvision import datasets
from transformers import AutoModelForImageClassification, MobileViTFeatureExtractor
from transformers import AutoImageProcessor, AutoModelForImageClassification


def get_imagenet_dataset(dataset_path, data_size, shuffle=True):
@@ -39,15 +39,13 @@ def get_data_loader():
# prepare input data
inputs, targets = [], []
data_loader = get_data_loader()
feature_extractor = MobileViTFeatureExtractor.from_pretrained(
"apple/mobilevit-xx-small"
)
image_processor = AutoImageProcessor.from_pretrained("apple/mobilevit-xx-small")
for index, data in enumerate(data_loader.dataset.imgs):
if index >= data_size:
break
data_path, target = data
image = Image.open(data_path).convert("RGB")
feature = feature_extractor(images=image, return_tensors="pt")
feature = image_processor(images=image, return_tensors="pt")
inputs.append((feature["pixel_values"],))
targets.append(torch.tensor(target))

8 changes: 3 additions & 5 deletions examples/qualcomm/oss_scripts/mobilevit_v2.py
@@ -25,7 +25,7 @@
)
from PIL import Image
from torchvision import datasets
from transformers import AutoModelForImageClassification, MobileViTFeatureExtractor
from transformers import AutoImageProcessor, AutoModelForImageClassification


def get_imagenet_dataset(dataset_path, data_size, shuffle=True):
@@ -40,15 +40,13 @@ def get_data_loader():
# prepare input data
inputs, targets = [], []
data_loader = get_data_loader()
feature_extractor = MobileViTFeatureExtractor.from_pretrained(
"apple/mobilevit-xx-small"
)
image_processor = AutoImageProcessor.from_pretrained("apple/mobilevit-xx-small")
for index, data in enumerate(data_loader.dataset.imgs):
if index >= data_size:
break
data_path, target = data
image = Image.open(data_path).convert("RGB")
feature = feature_extractor(images=image, return_tensors="pt")
feature = image_processor(images=image, return_tensors="pt")
inputs.append((feature["pixel_values"],))
targets.append(torch.tensor(target))

16 changes: 16 additions & 0 deletions extension/asr/runner/CMakeLists.txt
@@ -35,6 +35,22 @@ set_target_properties(
extension_asr_runner PROPERTIES POSITION_INDEPENDENT_CODE ON
)

# If the project is configured to build with CUDA support, try to find a CUDA
# runtime (prefer the CUDAToolkit package). If found, expose a compile-time
# macro so sources can conditionally compile CUDA-aware code.
if(EXECUTORCH_BUILD_CUDA)
find_package(CUDAToolkit QUIET)
if(CUDAToolkit_FOUND)
target_compile_definitions(extension_asr_runner PUBLIC CUDA_AVAILABLE)
message(STATUS "CUDAToolkit found; defining CUDA_AVAILABLE for ASR runner")
else()
message(
STATUS
"CUDA requested (EXECUTORCH_BUILD_CUDA=ON) but no CUDA runtime found"
)
endif()
endif()

install(
TARGETS extension_asr_runner
EXPORT ExecuTorchTargets
19 changes: 18 additions & 1 deletion extension/asr/runner/runner.cpp
@@ -16,6 +16,8 @@
#include <executorch/extension/llm/runner/util.h>
#include <executorch/extension/llm/sampler/util.h>
#include <executorch/extension/tensor/tensor_ptr_maker.h>
#include <executorch/runtime/backend/interface.h>
#include <executorch/runtime/backend/options.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/platform/assert.h>
#include <executorch/runtime/platform/log.h>
@@ -107,7 +109,22 @@ Error AsrRunner::load() {

ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kDecoderMethodName));
decoder_method_loaded_ = true;

#ifdef CUDA_AVAILABLE
executorch::runtime::BackendOptions<1> backend_options;
// For the decoder, still copy outputs from GPU to CPU for sampling.
// TODO: change the sampler to use a CUDA kernel so that copying the decoder
// output can be skipped as well.
ET_CHECK_OK_OR_RETURN_ERROR(backend_options.set_option(
"skip_copy_output_to_cpu_for_method", kEncoderMethodName));
const auto opt_err =
executorch::runtime::set_option("CudaBackend", backend_options.view());
if (opt_err != ::executorch::runtime::Error::Ok) {
ET_LOG(
Warning,
"Failed to set CUDA backend options: %d",
static_cast<int>(opt_err));
}
#endif
ET_CHECK_OK_OR_RETURN_ERROR(load_tokenizer());
auto eos_ids = get_eos_ids(tokenizer_.get(), module_.get());
if (!eos_ids.empty()) {
2 changes: 1 addition & 1 deletion requirements-examples.txt
@@ -4,4 +4,4 @@ datasets == 3.6.0 # 4.0.0 deprecates trust_remote_code and load scripts. For now
timm == 1.0.7
torchsr == 1.0.4
torchtune >= 0.6.1
transformers == 4.56.1
transformers == 5.0.0rc1