Add Support for Guided Decoding to On Device Sampling #624
base: main
Conversation
Depends on #597
Ready for review
    top_ps: Optional[torch.Tensor] = None,
    min_ps: Optional[torch.Tensor] = None,
    random_numbers: Optional[torch.Tensor] = None,
    vision_embeds: Optional[torch.Tensor] = None,
Please keep the dtype of these two consistent, as per lines 27-28. Also update the function docstring for these newly added args.
    1, # spec_length
    False, # is_vlm
    ),
    # pytest.param(
Any reason for this test to be disabled?
    additional_configs["config"] = config
    additional_configs["kv_offload"] = True
    qeff_class = QEFFAutoModelForImageTextToText
    assert isinstance(prompts, tuple)
Please add an error message to this assertion for VLMs.
    prompts = prompts[1]
    else:
    qeff_class = QEFFAutoModelForCausalLM
    spec_length -= 1
Should num_hidden_layers be specified here, or does this test require all model layers?
| "temperatures": np.array(100.1, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), | ||
| "top_ks": np.array(54720, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1), | ||
| "temperatures": np.array(4.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), | ||
| "top_ks": np.array(512, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1), |
Why were these numbers changed?
| "top_ps": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), | ||
| "min_ps": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), | ||
| "random_numbers": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1), | ||
| "random_numbers": np.zeros((full_batch_size, 512), dtype=np.float32), |
For greedy sampling, shouldn't top_k = 1 and top_p = 0 be used here?
    else:
    additional_configs["num_hidden_layers"] = 2
    qeff_class = QEFFAutoModelForCausalLM
    spec_length -= 1
Can we create a function (with inputs is_vlm, model, prompts, spec_length, returning qeff_class, additional_params, additional_configs, spec_length) out of lines 565-580 and reuse it to avoid duplication? A possible shape is sketched below.
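One way the helper could look, pieced together from the reviewer's suggested signature and the branches visible in this diff (the function name, the `additional_params` contents, the assertion message, and how `config` is obtained are assumptions, not the final API):

```python
def get_test_setup(is_vlm, config, prompts, spec_length):
    """Hypothetical helper consolidating the VLM / CausalLM test setup branches."""
    additional_configs, additional_params = {}, {}
    if is_vlm:
        # VLM path, mirroring the diff above.
        additional_configs["config"] = config
        additional_configs["kv_offload"] = True
        qeff_class = QEFFAutoModelForImageTextToText
        assert isinstance(prompts, tuple), "VLM tests expect prompts as a tuple"
        prompts = prompts[1]
    else:
        additional_configs["num_hidden_layers"] = 2
        qeff_class = QEFFAutoModelForCausalLM
        spec_length -= 1
    return qeff_class, additional_params, additional_configs, prompts, spec_length
```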
    qaic_config={
    "include_sampler": True,
    "return_pdfs": False,
    "max_top_k_ids": 1024,
Can we define this test parameter in constants and then reuse it? A possible shape is sketched below.
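For example, assuming the values shown in this diff are lifted into a shared constants module (the module path and constant name here are hypothetical):

```python
# e.g. in QEfficient/utils/constants.py (hypothetical location and name)
SAMPLER_TEST_QAIC_CONFIG = {
    "include_sampler": True,
    "return_pdfs": False,
    "max_top_k_ids": 1024,
}

# In the test, reuse it and override only what a given case needs:
# qaic_config={**SAMPLER_TEST_QAIC_CONFIG, "include_guided_decoding": True}
```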
    include_sampler = None
    return_pdfs = None
    max_top_k_ids = None
    include_guided_decoding = None
Should these be boolean variables here?
    QEffMptForCausalLM,
    QEffPhi3ForCausalLM,
    QEffQwen2ForCausalLM,
    QEffQwen_2_5_vl_DecoderWrapper,
Is this supported for Intern and Qwen models only?
✨ Add Support for Guided Decoding to On Device Sampling
📌 Overview
This PR introduces guided decoding capabilities in On Device Sampling for `QEffForCausalLM` and `QEffCausalLMForTextImageToTextModel` models.
🚀 Motivation
As outlined in this blog post on structured decoding, structured decoding represents a fundamental shift in controlling LLM outputs. Instead of relying on post-processing, constraints are enforced during token generation via logits manipulation. This approach ensures that outputs conform to the constraints as they are generated.
The constraints are provided through `token_bitmasks`, which is a Boolean matrix of shape `(batch_size, vocab_size)`. Here, each element indicates whether a token should be kept (1) or masked (0). During sampling, this mask is applied to the logits before token selection, ensuring that only allowed tokens are considered. By performing this operation directly on the device, we eliminate host-device transfers, reduce latency, and improve throughput for structured decoding workloads.
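As a rough illustration of the masking semantics, here is a minimal host-side sketch in PyTorch (not the on-device kernel; the helper name is hypothetical):

```python
import torch

def apply_token_bitmasks(logits: torch.Tensor, token_bitmasks: torch.Tensor) -> torch.Tensor:
    # token_bitmasks: (batch_size, vocab_size) Boolean matrix; 1 keeps a token, 0 masks it.
    # Disallowed tokens get -inf so they can never be selected during sampling.
    return logits.masked_fill(~token_bitmasks.bool(), float("-inf"))

batch_size, vocab_size = 2, 8
logits = torch.randn(batch_size, vocab_size)
token_bitmasks = torch.zeros(batch_size, vocab_size, dtype=torch.bool)
token_bitmasks[:, :4] = True                 # only the first four tokens are allowed
masked_logits = apply_token_bitmasks(logits, token_bitmasks)
next_tokens = masked_logits.argmax(dim=-1)   # always falls within the allowed set
```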
🛠️ Implementation Details
The guided decoding logic is injected via
include_guided_decoding=Trueduring model loading. No changes to the model architecture are required.To disable guided decoding, simply set
include_guided_decoding=False.
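For reference, loading a model with sampling and guided decoding enabled might look like the following. This is a sketch assembled from the `qaic_config` keys shown in this PR; the exact `from_pretrained` arguments and the model card are assumptions.

```python
from QEfficient import QEFFAutoModelForCausalLM

model = QEFFAutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",  # hypothetical model card
    qaic_config={
        "include_sampler": True,
        "return_pdfs": False,
        "max_top_k_ids": 1024,
        "include_guided_decoding": True,  # set to False to disable guided decoding
    },
)
```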