Skip to content

Commit 569769e

Browse files
committed
Add QNN backend end to end script for Roberta model
- Add end to end script for Roberta - Handle extreme values in replace_inf function for masked_fill
1 parent e02ca41 commit 569769e

File tree

9 files changed

+312
-14
lines changed

9 files changed

+312
-14
lines changed

backends/qualcomm/_passes/i64_to_i32.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,10 @@ class I64toI32(ExportPass):
2828
I64_OPS = {
2929
exir_ops.edge.aten.argmin.default,
3030
exir_ops.edge.aten.arange.start_step,
31+
exir_ops.edge.aten.cumsum.default,
3132
exir_ops.edge.aten.full.default,
3233
exir_ops.edge.aten.scalar_tensor.default,
34+
exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
3335
}
3436
# This dict is to ensure that the input of the OPs are int64 due to Pytorch restrictions.
3537
# For example, scatter op can only accept args[2], the index, as int64.

backends/qualcomm/_passes/lift_constant_scalar_operands.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ def _build_tensor_constant(
8686
dtype=(
8787
node.args[0].meta["val"].dtype
8888
if not is_float_tensor(node)
89-
and not SCALAR_OPS.get(node.target).use_self_dtype
89+
and (info := SCALAR_OPS.get(node.target))
90+
and not info.use_self_dtype
9091
else node.meta["val"].dtype
9192
),
9293
device=node.meta["val"].device,

backends/qualcomm/_passes/replace_inf_values.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ def call(self, graph_module: torch.fx.GraphModule):
3030
arg_list[index] = torch.finfo(torch.float32).min
3131
elif arg == float("inf"):
3232
arg_list[index] = torch.finfo(torch.float32).max
33+
34+
if node.target == torch.ops.aten.masked_fill.Scalar:
35+
if arg_list[2] == torch.finfo(torch.float32).max:
36+
arg_list[2] = 255
37+
elif arg_list[2] == torch.finfo(torch.float32).min:
38+
arg_list[2] = -255
3339
node.args = tuple(arg_list)
3440

3541
graph_module.recompile()

backends/qualcomm/builders/op_cum_sum.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ def define_node(
5050
dim = self.get_param(node, input_tensor)
5151

5252
output_tensor = self.get_tensor(node, node)
53+
if output_tensor.dtype == torch.int64:
54+
output_tensor = output_tensor.to(torch.int32)
5355
output_tensor_wrapper = self.define_tensor(
5456
node,
5557
node,

backends/qualcomm/tests/models.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1101,6 +1101,16 @@ def forward(self, x):
11011101
return torch.mean(x, (-1, -2))
11021102

11031103

1104+
class MaskedFill(torch.nn.Module):
1105+
def __init__(self):
1106+
super().__init__()
1107+
1108+
def forward(self, attn_mask):
1109+
return attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
1110+
attn_mask == 0, float(0.0)
1111+
)
1112+
1113+
11041114
class Maximum(torch.nn.Module):
11051115
def __init__(self):
11061116
super().__init__()
@@ -1751,16 +1761,6 @@ def forward(self, x):
17511761
)
17521762

17531763

1754-
class MaskedFill(torch.nn.Module):
1755-
def __init__(self):
1756-
super().__init__()
1757-
1758-
def forward(self, attn_mask):
1759-
return attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
1760-
attn_mask == 0, float(0.0)
1761-
)
1762-
1763-
17641764
# Mimi Decoder has 0D tensor which QNN cannot handle.
17651765
class ZeroDimTensor(torch.nn.Module):
17661766
def __init__(self):

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -277,9 +277,24 @@ def test_qnn_backend_cos(self):
277277
self.lower_module_and_test_output(module, sample_input)
278278

279279
def test_qnn_backend_cumsum(self):
280-
module = CumSum() # noqa: F405
281-
sample_input = (torch.randn(4),)
282-
self.lower_module_and_test_output(module, sample_input)
280+
sample_input = ()
281+
test_comb = [
282+
{
283+
QCOM_MODULE: [CumSum()], # noqa: F405
284+
QCOM_SAMPLE_INPUTS: [
285+
(torch.randn(4),),
286+
(torch.randint(0, 10, size=(4,)),),
287+
],
288+
}
289+
]
290+
291+
index = 0
292+
for comb in test_comb:
293+
for module in comb[QCOM_MODULE]:
294+
for sample_input in comb[QCOM_SAMPLE_INPUTS]:
295+
with self.subTest(i=index):
296+
self.lower_module_and_test_output(module, sample_input)
297+
index += 1
283298

284299
def test_qnn_backend_einsum_outer_product(self):
285300
module = EinsumOuterProduct() # noqa: F405
@@ -316,6 +331,12 @@ def test_qnn_backend_element_wise_add(self):
316331
QCOM_MODULE: [AddConstantFloat()], # noqa: F405
317332
QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
318333
},
334+
{
335+
QCOM_MODULE: [
336+
AddConstantLong(),
337+
], # noqa: F405
338+
QCOM_SAMPLE_INPUTS: [(torch.randint(0, 10, size=(2, 3)),)],
339+
},
319340
]
320341

321342
index = 0
@@ -4562,6 +4583,40 @@ def test_retinanet(self):
45624583
else:
45634584
self.assertGreaterEqual(msg["mAP"], 0.6)
45644585

4586+
def test_roberta(self):
4587+
if not self.required_envs([self.sentence_dataset]):
4588+
self.skipTest("missing required envs")
4589+
cmds = [
4590+
"python",
4591+
f"{self.executorch_root}/examples/qualcomm/oss_scripts/roberta.py",
4592+
"--dataset",
4593+
self.sentence_dataset,
4594+
"--artifact",
4595+
self.artifact_dir,
4596+
"--build_folder",
4597+
self.build_folder,
4598+
"--device",
4599+
self.device,
4600+
"--model",
4601+
self.model,
4602+
"--ip",
4603+
self.ip,
4604+
"--port",
4605+
str(self.port),
4606+
]
4607+
if self.host:
4608+
cmds.extend(["--host", self.host])
4609+
4610+
p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
4611+
with Listener((self.ip, self.port)) as listener:
4612+
conn = listener.accept()
4613+
p.communicate()
4614+
msg = json.loads(conn.recv())
4615+
if "Error" in msg:
4616+
self.fail(msg["Error"])
4617+
else:
4618+
self.assertGreaterEqual(msg["accuracy"], 0.5)
4619+
45654620
def test_squeezenet(self):
45664621
if not self.required_envs([self.image_dataset]):
45674622
self.skipTest("missing required envs")
@@ -5366,6 +5421,11 @@ def setup_environment():
53665421
help="Location for imagenet dataset",
53675422
type=str,
53685423
)
5424+
parser.add_argument(
5425+
"--sentence_dataset",
5426+
help="Location for sentence dataset",
5427+
type=str,
5428+
)
53695429
parser.add_argument(
53705430
"-p",
53715431
"--pretrained_weight",
@@ -5417,6 +5477,7 @@ def setup_environment():
54175477
TestQNN.executorch_root = args.executorch_root
54185478
TestQNN.artifact_dir = args.artifact_dir
54195479
TestQNN.image_dataset = args.image_dataset
5480+
TestQNN.sentence_dataset = args.sentence_dataset
54205481
TestQNN.pretrained_weight = args.pretrained_weight
54215482
TestQNN.model_name = args.model_name
54225483
TestQNN.online_prepare = args.online_prepare

backends/qualcomm/tests/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ class TestQNN(unittest.TestCase):
183183
executorch_root: str = ""
184184
artifact_dir: str = ""
185185
image_dataset: str = ""
186+
sentence_dataset: str = ""
186187
pretrained_weight: str = ""
187188
enable_profile: bool = False
188189
online_prepare: bool = False
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
# Copyright (c) Qualcomm Innovation Center, Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import getpass
8+
import json
9+
import os
10+
from multiprocessing.connection import Client
11+
12+
import evaluate
13+
import numpy as np
14+
import torch
15+
16+
from executorch.backends.qualcomm._passes.qnn_pass_manager import (
17+
get_capture_program_passes,
18+
)
19+
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
20+
21+
from executorch.examples.qualcomm.utils import (
22+
build_executorch_binary,
23+
get_masked_language_model_dataset,
24+
make_output_dir,
25+
parse_skip_delegation_node,
26+
setup_common_args_and_variables,
27+
SimpleADB,
28+
)
29+
from transformers import AutoModelForMaskedLM, AutoTokenizer
30+
31+
32+
def get_instance(args):
33+
module = AutoModelForMaskedLM.from_pretrained("xlm-roberta-base").eval()
34+
return module
35+
36+
37+
def main(args):
38+
skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)
39+
40+
os.makedirs(args.artifact, exist_ok=True)
41+
data_size = 100
42+
43+
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
44+
inputs, targets, input_list = get_masked_language_model_dataset(
45+
args.dataset, tokenizer, data_size
46+
)
47+
48+
# Get the Roberta model.
49+
model = get_instance(args)
50+
pte_filename = "roberta_qnn"
51+
52+
# lower to QNN
53+
passes_job = get_capture_program_passes()
54+
build_executorch_binary(
55+
model,
56+
inputs[0],
57+
args.model,
58+
f"{args.artifact}/{pte_filename}",
59+
dataset=inputs,
60+
skip_node_id_set=skip_node_id_set,
61+
skip_node_op_set=skip_node_op_set,
62+
quant_dtype=QuantDtype.use_16a8w,
63+
passes_job=passes_job,
64+
shared_buffer=args.shared_buffer,
65+
)
66+
67+
if args.compile_only:
68+
return
69+
70+
workspace = f"/data/local/tmp/{getpass.getuser()}/executorch/{pte_filename}"
71+
pte_path = f"{args.artifact}/{pte_filename}.pte"
72+
73+
adb = SimpleADB(
74+
qnn_sdk=os.getenv("QNN_SDK_ROOT"),
75+
build_path=f"{args.build_folder}",
76+
pte_path=pte_path,
77+
workspace=workspace,
78+
device_id=args.device,
79+
host_id=args.host,
80+
soc_model=args.model,
81+
)
82+
output_data_folder = f"{args.artifact}/outputs"
83+
make_output_dir(output_data_folder)
84+
85+
# demo
86+
mask_token = tokenizer.mask_token
87+
text = f"Hello I'm a {mask_token} model."
88+
sample_input = tokenizer(
89+
text,
90+
return_tensors="pt",
91+
padding="max_length",
92+
max_length=inputs[0][0].shape[1],
93+
)
94+
sample_input["input_ids"] = sample_input["input_ids"].to(torch.int32)
95+
sample_input["attention_mask"] = sample_input["attention_mask"].to(torch.float32)
96+
sample_input = tuple(sample_input.values())
97+
golden = model(*sample_input)[0]
98+
adb.push(inputs=[sample_input], input_list="input_0_0.raw input_0_1.raw\n")
99+
adb.execute()
100+
adb.pull(output_path=args.artifact)
101+
102+
print(f"input: {tokenizer.batch_decode(sample_input[0])}")
103+
print(f"golden output: {tokenizer.batch_decode(golden.argmax(axis=2))}")
104+
predictions = np.fromfile(
105+
os.path.join(output_data_folder, "output_0_0.raw"), dtype=np.float32
106+
).reshape([1, inputs[0][0].shape[1], -1])
107+
print(f"QNN output: {tokenizer.batch_decode(predictions.argmax(axis=2))}")
108+
109+
# accuracy analysis
110+
adb.push(inputs=inputs, input_list=input_list)
111+
adb.execute()
112+
adb.pull(output_path=args.artifact)
113+
goldens, predictions = [], []
114+
for i in range(len(inputs)):
115+
indice = [i for i, x in enumerate(targets[i]) if x != -100]
116+
goldens.extend(targets[i][indice].tolist())
117+
prediction = (
118+
np.fromfile(
119+
os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32
120+
)
121+
.reshape([1, inputs[0][0].shape[1], -1])
122+
.argmax(axis=-1)
123+
)
124+
predictions.extend(prediction[0, indice].tolist())
125+
metric = evaluate.load("accuracy")
126+
results = metric.compute(predictions=predictions, references=goldens)
127+
if args.ip and args.port != -1:
128+
with Client((args.ip, args.port)) as conn:
129+
conn.send(json.dumps({"accuracy": results["accuracy"]}))
130+
else:
131+
print(f"accuracy: {results['accuracy']}")
132+
133+
134+
if __name__ == "__main__":
135+
parser = setup_common_args_and_variables()
136+
parser.add_argument(
137+
"-a",
138+
"--artifact",
139+
help="path for storing generated artifacts and output by this example. Default ./Roberta_qnn",
140+
default="./Roberta_qnn",
141+
type=str,
142+
)
143+
parser.add_argument(
144+
"-d",
145+
"--dataset",
146+
help=(
147+
"path to the validation text. "
148+
"e.g. --dataset wikisent2.txt "
149+
"for https://www.kaggle.com/datasets/mikeortman/wikipedia-sentences"
150+
),
151+
type=str,
152+
required=True,
153+
)
154+
155+
args = parser.parse_args()
156+
try:
157+
main(args)
158+
except Exception as e:
159+
if args.ip and args.port != -1:
160+
with Client((args.ip, args.port)) as conn:
161+
conn.send(json.dumps({"Error": str(e)}))
162+
else:
163+
raise Exception(e)

0 commit comments

Comments
 (0)