Commit 9b82627

Merge branch 'main' into support-dynamically-quantized-convolutions
2 parents d82e080 + 7c150d4

File tree: 22 files changed, +238 -401 lines

.github/workflows/doc-build.yml (+2, -2)

@@ -21,12 +21,12 @@ jobs:
       - name: Check URLs
         run: bash ./scripts/check_urls.sh

-  check-links:
+  check-xrefs:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v3
       - name: Check Links
-        run: bash ./scripts/check_links.sh
+        run: bash ./scripts/check_xrefs.sh

   build:
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main

backends/apple/mps/mps_preprocess.py (+15, -1)

@@ -6,6 +6,7 @@
 from typing import ClassVar, Dict, final, List, Tuple

 import torch
+from executorch import exir

 from executorch.backends.apple.mps.operators.node_visitor import (
     get_node_visitors,
@@ -35,6 +36,7 @@
 from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass
 from executorch.exir.program._program import _transform
+from executorch.exir.verification.verifier import EXIREdgeDialectVerifier
 from torch.export.exported_program import ExportedProgram

 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
@@ -87,7 +89,19 @@ def preprocess(
     # the `output_ids` array in the schema.

     # TODO: Remove this once we have a better support for the dim-order ops.
-    edge_program = _transform(edge_program, DimOrderOpsRevertPass())
+    # Need to override the verifier to skip the non dim-order ops from tripping the default verifier.
+    edge_program = _transform(
+        edge_program,
+        DimOrderOpsRevertPass(),
+        override_verifiers=[
+            EXIREdgeDialectVerifier(
+                edge_compile_config=exir.EdgeCompileConfig(
+                    _check_ir_validity=False,  # Disable the edge dialect verifier, since we are in the MPS backend.
+                ),
+                class_only=True,
+            )
+        ],
+    )

     mps_graph = MPSGraph(
         version="0",

examples/demo-apps/android/LlamaDemo/README.md (+2, -2)

@@ -135,8 +135,8 @@ Ensure you have the following functions in your callback class that you provided
     }

     @Override
-    public void onStats(float tps) {
-        //...tps (tokens per second) stats is provided by framework
+    public void onStats(String stats) {
+        //... will be a json. See extension/llm/stats.h for the field definitions
     }
 ```

examples/models/llama/export_llama_lib.py (+13, -1)

@@ -1227,10 +1227,22 @@ def _get_source_transforms(  # noqa
     if args.expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)

+    use_attention_mask_for_custom_sdpa = False
+    if isinstance(args, argparse.Namespace):
+        if getattr(args, "use_custom_sdpa_with_attention_mask", None):
+            use_attention_mask_for_custom_sdpa = True
+
     if args.use_sdpa_with_kv_cache:
         transforms.append(replace_kv_cache_with_custom_kv_cache)
         # todo: do this optionally
-        transforms.append(replace_sdpa_with_custom_op)
+        # if use attention mask instead of causal attention
+        # then create partial function that sets use_attention_mask=True
+        if use_attention_mask_for_custom_sdpa:
+            transforms.append(
+                partial(replace_sdpa_with_custom_op, use_attention_mask=True)
+            )
+        else:
+            transforms.append(replace_sdpa_with_custom_op)

     if args.quantize_kv_cache:
         assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
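For context, a minimal sketch of how the partial-bound transform keeps the pipeline uniform: functools.partial pre-binds use_attention_mask=True, so the entry appended to `transforms` still has the same one-argument, module-in/module-out shape as the other source transforms. The import path and the driver loop below are assumptions for illustration, not part of this commit.

from functools import partial

import torch

# Import path assumed from the file layout shown in this commit.
from executorch.examples.models.llama.source_transformation.sdpa import (
    replace_sdpa_with_custom_op,
)

# partial(...) pre-binds the keyword; the result is still a plain callable that
# takes a module and returns a module.
mask_aware_transform = partial(replace_sdpa_with_custom_op, use_attention_mask=True)


def apply_source_transforms(model: torch.nn.Module, transforms) -> torch.nn.Module:
    # Hypothetical driver loop: each transform takes and returns an nn.Module,
    # which is why the conditional above appends callables of identical shape.
    for transform in transforms:
        model = transform(model)
    return model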

examples/models/llama/source_transformation/sdpa.py (+50, -14)

@@ -22,9 +22,15 @@ class SDPACustom(torch.nn.Module):
     def __init__(
         self,
         dim: int,
+        max_context_len,
+        enable_dynamic_shape,
+        use_attention_mask: bool = False,
     ):
         super().__init__()
         self.dim = dim
+        self.max_context_len = max_context_len
+        self.use_attention_mask = use_attention_mask
+        self.enable_dynamic_shape = enable_dynamic_shape

     def forward(
         self,
@@ -36,6 +42,16 @@ def forward(
         seqlen,
         mask,
     ):
+        if self.use_attention_mask:
+            if self.enable_dynamic_shape:
+                start_pos = input_pos[-1].item()
+                torch._check_is_size(start_pos)
+                torch._check(start_pos < self.max_context_len)
+                seq_length = q.size(2)
+                mask = mask.narrow(0, start_pos, seq_length)
+            else:
+                mask = mask[input_pos]
+
         q = q.transpose(1, 2)  # (bs, seqlen, n_local_heads, head_dim)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
@@ -47,34 +63,54 @@ def forward(
         k = k.to(dtype=torch.float)
         v = v.to(dtype=torch.float)

-        output = torch.ops.llama.custom_sdpa(
-            q,
-            k,
-            v,
-            input_pos[0].item(),
-            None,  # Attention mask
-            0,  # dropout probability. Ignored by the code
-            True,  # is_causal
-        )
+        if self.use_attention_mask:
+            output = torch.ops.llama.custom_sdpa(
+                q,
+                k,
+                v,
+                input_pos[0].item(),
+                mask,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                False,  # is_causal
+            )
+        else:
+            output = torch.ops.llama.custom_sdpa(
+                q,
+                k,
+                v,
+                input_pos[0].item(),
+                None,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                True,  # is_causal
+            )
         return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)


-def _replace_sdpa_with_custom_op(module: torch.nn.Module):
+def _replace_sdpa_with_custom_op(
+    module: torch.nn.Module, use_attention_mask: bool = False
+):
     for name, child in module.named_children():
         if isinstance(child, SDPA):
             setattr(
                 module,
                 name,
-                SDPACustom(child.dim),
+                SDPACustom(
+                    child.dim,
+                    child.max_context_len,
+                    child.enable_dynamic_shape,
+                    use_attention_mask=use_attention_mask,
+                ),
             )
         else:
-            _replace_sdpa_with_custom_op(child)
+            _replace_sdpa_with_custom_op(child, use_attention_mask=use_attention_mask)


-def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
+def replace_sdpa_with_custom_op(
+    module: torch.nn.Module, use_attention_mask: bool = False
+) -> torch.nn.Module:
     from executorch.extension.llm.custom_ops import custom_ops  # noqa

-    _replace_sdpa_with_custom_op(module)
+    _replace_sdpa_with_custom_op(module, use_attention_mask=use_attention_mask)
     return module
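As a quick aside on the two new mask-selection branches in SDPACustom.forward: with dynamic shapes enabled the module narrows a contiguous block of rows starting at start_pos, otherwise it indexes the rows at input_pos; for contiguous positions both pick the same slice, which the module then hands to torch.ops.llama.custom_sdpa with is_causal=False. A standalone sketch, with the mask construction and sizes assumed purely for illustration:

import torch

# Illustrative causal mask and positions; the real mask and sizes come from the model.
max_context_len, seq_len, start_pos = 128, 4, 10
mask = torch.triu(
    torch.full((max_context_len, max_context_len), float("-inf")), diagonal=1
)
input_pos = torch.arange(start_pos, start_pos + seq_len)

# Dynamic-shape branch: take seq_len contiguous rows starting at start_pos.
dynamic_mask = mask.narrow(0, start_pos, seq_len)

# Static branch: index the rows for the current positions directly.
static_mask = mask[input_pos]

# For contiguous positions the two branches select the same rows.
assert torch.equal(dynamic_mask, static_mask)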

examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py (+2, -2)

@@ -71,8 +71,8 @@ def test_simple(self, is_dynamic_shape=False):
         self.seq_len = 3
         self._init_cache()
         q, k_val, v_val = self._init_kv()
-        self.float_sdpa = SDPACustom(self.dim)
-        self.quantized_sdpa = SDPACustom(self.dim)
+        self.float_sdpa = SDPACustom(self.dim, self.max_context_len, True)
+        self.quantized_sdpa = SDPACustom(self.dim, self.max_context_len, True)
         k, v = self.custom_kv_cache.update(input_pos, k_val, v_val)
         float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
         k, v = self.quantized_kv_cache.update(input_pos, k_val, v_val)

examples/qualcomm/scripts/mobilebert_fine_tune.py (+1, -3)

@@ -102,9 +102,7 @@ def get_fine_tuned_mobilebert(artifacts_dir, pretrained_weight, batch_size):
     from transformers import get_linear_schedule_with_warmup

     # grab dataset
-    url = (
-        "https://raw.githubusercontent.com/susanli2016/NLP-with-Python/master/data/title_conference.csv"
-    )
+    url = "https://raw.githubusercontent.com/susanli2016/NLP-with-Python/master/data/title_conference.csv"
     content = requests.get(url, allow_redirects=True).content
     data = pd.read_csv(BytesIO(content))

exir/program/_program.py (+27, -2)

@@ -212,7 +212,30 @@ def _get_updated_graph_signature(
     return new_signature


-def _transform(self, *passes: PassType) -> "ExportedProgram":
+def _transform(
+    self,
+    *passes: PassType,
+    override_verifiers: None | list[Type[Verifier]] = None,
+) -> "ExportedProgram":
+    """
+    Transforms the program according to the provided passes.
+
+    Args:
+        self: The ExportedProgram instance to transform
+        *passes: A sequence of passes to apply to the program
+        override_verifiers: Optional list of verifier classes to use instead of the default verifiers.
+            This is needed if the transforms yields illegal graph that the default verifier cannot handle.
+
+    Returns:
+        ExportedProgram: A new ExportedProgram with the transformations applied, or self if no changes were made
+    """
+    # A user friendly check to avoid vararg surprises, PEP 3102
+    assert not any(
+        isinstance(p, (list, Verifier)) for p in passes
+    ), f"Expected all passes to be of PassType, not list or Verifier. Use override_verifiers kwarg instead. Got: {list(passes)}"
+
+    for p in list(passes):
+        print(type(p))
     pm = PassManager(list(passes))
     res = pm(self.graph_module)
     transformed_gm = res.graph_module if res is not None else self.graph_module
@@ -221,7 +244,9 @@ def _transform(self, *passes: PassType) -> "ExportedProgram":
     if transformed_gm is self.graph_module and not res.modified:
         return self

-    return _update_exported_program_graph_module(self, transformed_gm)
+    return _update_exported_program_graph_module(
+        self, transformed_gm, override_verifiers
+    )


 def _update_exported_program_graph_module(
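Two call styles for the new keyword-only argument appear in this commit: the MPS preprocess above passes the class produced by EXIREdgeDialectVerifier(..., class_only=True), while the unit test below passes a plain Verifier subclass. A condensed sketch tying the two together; the pass and program names in the commented calls are placeholders, not part of the commit.

from executorch import exir
from executorch.exir.program._program import _transform
from executorch.exir.verification.verifier import EXIREdgeDialectVerifier
from torch._export.verifier import Verifier

# Style 1 (MPS backend): an edge-dialect verifier with IR-validity checks disabled,
# returned as a class via class_only=True.
lenient_verifier = EXIREdgeDialectVerifier(
    edge_compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
    class_only=True,
)


# Style 2 (unit test): any Verifier subclass.
class MyVerifier(Verifier):
    dialect: str = "MY_DIALECT"


# override_verifiers is keyword-only; passing a verifier positionally would trip the
# new assert in _transform.
# program = _transform(program, SomePass(), override_verifiers=[lenient_verifier])
# program = _transform(program, SomePass(), override_verifiers=[MyVerifier])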

exir/program/test/test_program.py (+22, -1)

@@ -22,6 +22,7 @@
 from executorch.exir.pass_base import ExportPass
 from executorch.exir.passes import MemoryPlanningPass
 from executorch.exir.program._program import (
+    _transform,
     EdgeProgramManager,
     ExecutorchProgramManager,
     to_edge,
@@ -34,6 +35,7 @@
 from executorch.extension.pybindings.portable_lib import (
     _load_for_executorch_from_buffer,
 )
+from torch._export.verifier import Verifier
 from torch.export import Dim, export, ExportedProgram
 from torch.export._trace import _export

@@ -273,7 +275,6 @@ def get_executorch_memory_planning_passes() -> Dict[str, MemoryPlanningPass]:
                 for output_val in method.outputs:
                     evalue = method.values[output_val]
                     self.assertNotEqual(evalue.val.allocation_info, None)
-            else:
                 for input_val in method.inputs:
                     evalue = method.values[input_val]
                     self.assertEqual(evalue.val.allocation_info, None)
@@ -847,3 +848,23 @@ def test_save_fails(self):
         et = edge.to_executorch()
         with self.assertRaises(ValueError):
             _ = et.save("/tmp/test_save.pt")
+
+    def test__transform_override_verifiers(self):
+        """Test that _transform can override verifiers in the exported program."""
+
+        class MyVerifier(Verifier):
+            dialect: str = "MY_DIALECT"
+
+            def __init__(self):
+                super().__init__()
+
+        model = TestLinear()
+        program = torch.export.export(model, model._get_random_inputs(), strict=True)
+        self.assertFalse(issubclass(program.verifiers[0], MyVerifier))
+
+        # Apply transformation with custom verifier
+        transformed = _transform(
+            program, AddToMulPassEdge(), override_verifiers=[MyVerifier]
+        )
+        self.assertTrue(issubclass(transformed.verifiers[0], MyVerifier))
+        self.assertFalse(issubclass(program.verifiers[0], MyVerifier))

extension/android/executorch_android/src/androidTest/java/org/pytorch/executorch/LlmModuleInstrumentationTest.java (+12, -12)

@@ -97,16 +97,7 @@ public void onResult(String result) {

           @Override
           public void onStats(String stats) {
-            float tps = 0;
-            try {
-              JSONObject jsonObject = new JSONObject(stats);
-              int numGeneratedTokens = jsonObject.getInt("generated_tokens");
-              int inferenceEndMs = jsonObject.getInt("inference_end_ms");
-              int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms");
-              tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000;
-              LlmModuleInstrumentationTest.this.onStats(tps);
-            } catch (JSONException e) {
-            }
+            LlmModuleInstrumentationTest.this.onStats(stats);
           }
         });

@@ -120,7 +111,16 @@ public void onResult(String result) {
   }

   @Override
-  public void onStats(float tps) {
-    tokensPerSecond.add(tps);
+  public void onStats(String stats) {
+    float tps = 0;
+    try {
+      JSONObject jsonObject = new JSONObject(stats);
+      int numGeneratedTokens = jsonObject.getInt("generated_tokens");
+      int inferenceEndMs = jsonObject.getInt("inference_end_ms");
+      int promptEvalEndMs = jsonObject.getInt("prompt_eval_end_ms");
+      tps = (float) numGeneratedTokens / (inferenceEndMs - promptEvalEndMs) * 1000;
+      tokensPerSecond.add(tps);
+    } catch (JSONException e) {
+    }
   }
 }

extension/android/executorch_android/src/main/java/org/pytorch/executorch/extension/llm/LlmCallback.java (+1, -12)

@@ -28,17 +28,6 @@ public interface LlmCallback {
   @DoNotStrip
   public void onResult(String result);

-  /**
-   * Called when the statistics for the generate() is available.
-   *
-   * Note: This is a deprecated API and will be removed in the future. Please use onStats(String stats)
-   *
-   * @param tps Tokens/second for generated tokens.
-   */
-  @Deprecated
-  @DoNotStrip
-  default public void onStats(float tps) {}
-
   /**
    * Called when the statistics for the generate() is available.
    *
@@ -48,5 +37,5 @@ default public void onStats(float tps) {}
    * @param stats JSON string containing the statistics for the generate()
    */
   @DoNotStrip
-  default public void onStats(String stats) {}
+  default void onStats(String stats) {}
 }

extension/android/jni/jni_layer_llama.cpp (-8)

@@ -100,14 +100,6 @@ class ExecuTorchLlmCallbackJni

   void onStats(const llm::Stats& result) const {
     static auto cls = ExecuTorchLlmCallbackJni::javaClassStatic();
-    static const auto tps_method = cls->getMethod<void(jfloat)>("onStats");
-    double eval_time =
-        (double)(result.inference_end_ms - result.prompt_eval_end_ms);
-
-    float tps = result.num_generated_tokens / eval_time *
-        result.SCALING_FACTOR_UNITS_PER_SECOND;
-    tps_method(self(), tps);
-
     static const auto on_stats_method =
         cls->getMethod<void(facebook::jni::local_ref<jstring>)>("onStats");
     on_stats_method(

extension/llm/custom_ops/op_sdpa.cpp (+2, -1)

@@ -400,7 +400,8 @@ Tensor& custom_sdpa_out_impl(

   ET_CHECK_MSG(q.dim() == 4, "query must be a 4D tensor");

-  const int64_t num_keys_for_causal_attention = start_pos + seq_len;
+  const int64_t num_keys_for_causal_attention =
+      attn_mask.has_value() ? -1 : start_pos + seq_len;

   ET_KERNEL_CHECK(
       ctx,
