Qualcomm AI Engine Direct - GA Model Enablement (Roberta) #11354

Merged 2 commits on Jun 16, 2025
2 changes: 2 additions & 0 deletions backends/qualcomm/_passes/i64_to_i32.py
@@ -28,8 +28,10 @@ class I64toI32(ExportPass):
I64_OPS = {
exir_ops.edge.aten.argmin.default,
exir_ops.edge.aten.arange.start_step,
exir_ops.edge.aten.cumsum.default,
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.scalar_tensor.default,
exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
}
# This dict ensures that the inputs of these ops are int64, due to PyTorch restrictions.
# For example, the scatter op can only accept args[2], the index, as int64.
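For context on the cumsum addition: PyTorch promotes integer cumsum results to int64, which this backend then has to retag as int32. A minimal sketch of the promotion the pass works around:

import torch

x = torch.randint(0, 10, (4,), dtype=torch.int32)
y = torch.cumsum(x, dim=0)
print(y.dtype)  # torch.int64: integral inputs are promoted by default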
@@ -86,7 +86,8 @@ def _build_tensor_constant(
dtype=(
node.args[0].meta["val"].dtype
if not is_float_tensor(node)
and not SCALAR_OPS.get(node.target).use_self_dtype
and (info := SCALAR_OPS.get(node.target))
and not info.use_self_dtype
else node.meta["val"].dtype
),
device=node.meta["val"].device,
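The _build_tensor_constant change guards against SCALAR_OPS.get(node.target) returning None: the old expression dereferenced .use_self_dtype unconditionally and raised AttributeError for any op missing from the table. A standalone sketch of the walrus-guard pattern (the table and names here are illustrative, not the backend's real ones):

class OpInfo:
    def __init__(self, use_self_dtype):
        self.use_self_dtype = use_self_dtype

SCALAR_OPS = {"full": OpInfo(use_self_dtype=False)}

def wants_arg_dtype(target):
    # Old form: not SCALAR_OPS.get(target).use_self_dtype
    # raises AttributeError when get() returns None.
    # Binding the lookup with := short-circuits on missing ops instead.
    return bool((info := SCALAR_OPS.get(target)) and not info.use_self_dtype)

print(wants_arg_dtype("full"))     # True
print(wants_arg_dtype("unknown"))  # False, instead of AttributeError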
6 changes: 6 additions & 0 deletions backends/qualcomm/_passes/replace_inf_values.py
@@ -30,6 +30,12 @@ def call(self, graph_module: torch.fx.GraphModule):
arg_list[index] = torch.finfo(torch.float32).min
elif arg == float("inf"):
arg_list[index] = torch.finfo(torch.float32).max

if node.target == torch.ops.aten.masked_fill.Scalar:
if arg_list[2] == torch.finfo(torch.float32).max:
arg_list[2] = 255
elif arg_list[2] == torch.finfo(torch.float32).min:
arg_list[2] = -255
node.args = tuple(arg_list)

graph_module.recompile()
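The masked_fill.Scalar special case runs after the generic inf replacement above: once ±inf has been mapped to the float32 extremes, those extremes are remapped to ±255 for this op, presumably to keep the fill constant in a range that quantization handles well. A toy before/after (the ±255 values are taken from the pass itself):

import torch

mask = torch.tensor([[False, True], [True, False]])
scores = torch.zeros(2, 2)

# Filling with -inf (or the float32 minimum) makes min/max-based
# quantization ranges degenerate.
print(scores.masked_fill(mask, float("-inf")))
# A finite stand-in still dominates downstream softmax logits.
print(scores.masked_fill(mask, -255.0))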
2 changes: 2 additions & 0 deletions backends/qualcomm/builders/op_cum_sum.py
@@ -50,6 +50,8 @@ def define_node(
dim = self.get_param(node, input_tensor)

output_tensor = self.get_tensor(node, node)
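# QNN does not consume int64 tensors; downcast the cumsum output to int32 (this mirrors the I64toI32 pass change above).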
if output_tensor.dtype == torch.int64:
output_tensor = output_tensor.to(torch.int32)
output_tensor_wrapper = self.define_tensor(
node,
node,
20 changes: 10 additions & 10 deletions backends/qualcomm/tests/models.py
@@ -1101,6 +1101,16 @@ def forward(self, x):
return torch.mean(x, (-1, -2))


class MaskedFill(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, attn_mask):
return attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
attn_mask == 0, float(0.0)
)


class Maximum(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -1751,16 +1761,6 @@ def forward(self, x):
)


class MaskedFill(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, attn_mask):
return attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
attn_mask == 0, float(0.0)
)


# Mimi Decoder has a 0D tensor, which QNN cannot handle.
class ZeroDimTensor(torch.nn.Module):
def __init__(self):
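A quick check of what the relocated MaskedFill test module computes: nonzero mask positions are filled with -100.0 and zero positions with 0.0, an additive-mask pattern. A toy run (shape and values are illustrative):

import torch

attn_mask = torch.tensor([[1.0, 1.0, 0.0]])
out = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
print(out)  # tensor([[-100., -100., 0.]])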
67 changes: 64 additions & 3 deletions backends/qualcomm/tests/test_qnn_delegate.py
@@ -277,9 +277,24 @@ def test_qnn_backend_cos(self):
self.lower_module_and_test_output(module, sample_input)

def test_qnn_backend_cumsum(self):
module = CumSum() # noqa: F405
sample_input = (torch.randn(4),)
self.lower_module_and_test_output(module, sample_input)
sample_input = ()
test_comb = [
{
QCOM_MODULE: [CumSum()], # noqa: F405
QCOM_SAMPLE_INPUTS: [
(torch.randn(4),),
(torch.randint(0, 10, size=(4,)),),
],
}
]

index = 0
for comb in test_comb:
for module in comb[QCOM_MODULE]:
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
with self.subTest(i=index):
self.lower_module_and_test_output(module, sample_input)
index += 1

def test_qnn_backend_einsum_outer_product(self):
module = EinsumOuterProduct() # noqa: F405
@@ -316,6 +331,12 @@ def test_qnn_backend_element_wise_add(self):
QCOM_MODULE: [AddConstantFloat()], # noqa: F405
QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
},
{
QCOM_MODULE: [
AddConstantLong(), # noqa: F405
],
QCOM_SAMPLE_INPUTS: [(torch.randint(0, 10, size=(2, 3)),)],
},
]

index = 0
@@ -4562,6 +4583,40 @@ def test_retinanet(self):
else:
self.assertGreaterEqual(msg["mAP"], 0.6)

def test_roberta(self):
if not self.required_envs([self.sentence_dataset]):
self.skipTest("missing required envs")
cmds = [
"python",
f"{self.executorch_root}/examples/qualcomm/oss_scripts/roberta.py",
"--dataset",
self.sentence_dataset,
"--artifact",
self.artifact_dir,
"--build_folder",
self.build_folder,
"--device",
self.device,
"--model",
self.model,
"--ip",
self.ip,
"--port",
str(self.port),
]
if self.host:
cmds.extend(["--host", self.host])

p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
with Listener((self.ip, self.port)) as listener:
conn = listener.accept()
p.communicate()
msg = json.loads(conn.recv())
if "Error" in msg:
self.fail(msg["Error"])
else:
self.assertGreaterEqual(msg["accuracy"], 0.5)

def test_squeezenet(self):
if not self.required_envs([self.image_dataset]):
self.skipTest("missing required envs")
@@ -5366,6 +5421,11 @@ def setup_environment():
help="Location for imagenet dataset",
type=str,
)
parser.add_argument(
"--sentence_dataset",
help="Location for sentence dataset",
type=str,
)
parser.add_argument(
"-p",
"--pretrained_weight",
@@ -5417,6 +5477,7 @@ def setup_environment():
TestQNN.executorch_root = args.executorch_root
TestQNN.artifact_dir = args.artifact_dir
TestQNN.image_dataset = args.image_dataset
TestQNN.sentence_dataset = args.sentence_dataset
TestQNN.pretrained_weight = args.pretrained_weight
TestQNN.model_name = args.model_name
TestQNN.online_prepare = args.online_prepare
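The new test_roberta follows the same pattern as the other end-to-end tests: the example script runs as a subprocess and reports results back over a multiprocessing.connection socket. A minimal, self-contained sketch of that handshake (the address and accuracy payload are illustrative):

import json
import threading
from multiprocessing.connection import Client, Listener

ADDRESS = ("127.0.0.1", 6000)

def script_side():
    # In the real flow this runs inside roberta.py after evaluation.
    with Client(ADDRESS) as conn:
        conn.send(json.dumps({"accuracy": 0.87}))

with Listener(ADDRESS) as listener:
    threading.Thread(target=script_side).start()
    conn = listener.accept()        # blocks until the script connects back
    print(json.loads(conn.recv()))  # {'accuracy': 0.87}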
1 change: 1 addition & 0 deletions backends/qualcomm/tests/utils.py
@@ -183,6 +183,7 @@ class TestQNN(unittest.TestCase):
executorch_root: str = ""
artifact_dir: str = ""
image_dataset: str = ""
sentence_dataset: str = ""
pretrained_weight: str = ""
enable_profile: bool = False
online_prepare: bool = False
163 changes: 163 additions & 0 deletions examples/qualcomm/oss_scripts/roberta.py
@@ -0,0 +1,163 @@
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import getpass
import json
import os
from multiprocessing.connection import Client

import evaluate
import numpy as np
import torch

from executorch.backends.qualcomm._passes.qnn_pass_manager import (
get_capture_program_passes,
)
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype

from executorch.examples.qualcomm.utils import (
build_executorch_binary,
get_masked_language_model_dataset,
make_output_dir,
parse_skip_delegation_node,
setup_common_args_and_variables,
SimpleADB,
)
from transformers import AutoModelForMaskedLM, AutoTokenizer


def get_instance(args):
module = AutoModelForMaskedLM.from_pretrained("xlm-roberta-base").eval()
return module


def main(args):
skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)

os.makedirs(args.artifact, exist_ok=True)
data_size = 100

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
inputs, targets, input_list = get_masked_language_model_dataset(
args.dataset, tokenizer, data_size
)

# Get the Roberta model.
model = get_instance(args)
pte_filename = "roberta_qnn"

# lower to QNN
passes_job = get_capture_program_passes()
build_executorch_binary(
model,
inputs[0],
args.model,
f"{args.artifact}/{pte_filename}",
dataset=inputs,
skip_node_id_set=skip_node_id_set,
skip_node_op_set=skip_node_op_set,
quant_dtype=QuantDtype.use_16a8w,
passes_job=passes_job,
shared_buffer=args.shared_buffer,
)

if args.compile_only:
return

workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/{pte_filename}"
pte_path = f"{args.artifact}/{pte_filename}.pte"

adb = SimpleADB(
qnn_sdk=os.getenv("QNN_SDK_ROOT"),
build_path=f"{args.build_folder}",
pte_path=pte_path,
workspace=workspace,
device_id=args.device,
host_id=args.host,
soc_model=args.model,
)
output_data_folder = f"{args.artifact}/outputs"
make_output_dir(output_data_folder)

# demo
mask_token = tokenizer.mask_token
text = f"Hello I'm a {mask_token} model."
sample_input = tokenizer(
text,
return_tensors="pt",
padding="max_length",
max_length=inputs[0][0].shape[1],
)
sample_input["input_ids"] = sample_input["input_ids"].to(torch.int32)
sample_input["attention_mask"] = sample_input["attention_mask"].to(torch.float32)
sample_input = tuple(sample_input.values())
golden = model(*sample_input)[0]
adb.push(inputs=[sample_input], input_list="input_0_0.raw input_0_1.raw\n")
adb.execute()
adb.pull(output_path=args.artifact)

print(f"input: {tokenizer.batch_decode(sample_input[0])}")
print(f"golden output: {tokenizer.batch_decode(golden.argmax(axis=2))}")
predictions = np.fromfile(
os.path.join(output_data_folder, "output_0_0.raw"), dtype=np.float32
).reshape([1, inputs[0][0].shape[1], -1])
print(f"QNN output: {tokenizer.batch_decode(predictions.argmax(axis=2))}")

# accuracy analysis
adb.push(inputs=inputs, input_list=input_list)
adb.execute()
adb.pull(output_path=args.artifact)
goldens, predictions = [], []
for i in range(len(inputs)):
indices = [idx for idx, x in enumerate(targets[i]) if x != -100]
goldens.extend(targets[i][indices].tolist())
prediction = (
np.fromfile(
os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32
)
.reshape([1, inputs[0][0].shape[1], -1])
.argmax(axis=-1)
)
predictions.extend(prediction[0, indices].tolist())
metric = evaluate.load("accuracy")
results = metric.compute(predictions=predictions, references=goldens)
if args.ip and args.port != -1:
with Client((args.ip, args.port)) as conn:
conn.send(json.dumps({"accuracy": results["accuracy"]}))
else:
print(f"accuracy: {results['accuracy']}")


if __name__ == "__main__":
parser = setup_common_args_and_variables()
parser.add_argument(
"-a",
"--artifact",
help="path for storing generated artifacts and output by this example. Default ./Roberta_qnn",
default="./Roberta_qnn",
type=str,
)
parser.add_argument(
"-d",
"--dataset",
help=(
"path to the validation text. "
"e.g. --dataset wikisent2.txt "
"for https://www.kaggle.com/datasets/mikeortman/wikipedia-sentences"
),
type=str,
required=True,
)

args = parser.parse_args()
try:
main(args)
except Exception as e:
if args.ip and args.port != -1:
with Client((args.ip, args.port)) as conn:
conn.send(json.dumps({"Error": str(e)}))
else:
raise
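A note on the scoring loop above: Hugging Face masked-LM datasets mark positions that should not be scored with the label -100, so only the masked positions are compared against predictions. A toy version of the selection logic (vocabulary size and values are made up):

import numpy as np
import torch

targets = torch.tensor([-100, 42, -100, 7])            # only positions 1 and 3 are scored
logits = np.random.rand(1, 4, 100).astype(np.float32)  # [batch, seq_len, vocab]

indices = [idx for idx, t in enumerate(targets) if t != -100]
goldens = targets[indices].tolist()
predicted = logits.argmax(axis=-1)[0, indices].tolist()
print(goldens, predicted)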