feat: qol for page sizes - decrease defaults and raise error in PD when size is different #14474
base: main
Conversation
Signed-off-by: Raayan Dhar [email protected] <[email protected]>
Signed-off-by: Raayan Dhar [email protected] <[email protected]>
Signed-off-by: Raayan Dhar [email protected] <[email protected]>
Signed-off-by: Raayan Dhar [email protected] <[email protected]>
Also, I'm not sure what the expected behavior should be -- we throw a ValueError when they mismatch, but the server does not stop. Should we stop both servers?
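For context, here is a minimal sketch of the current behavior on the prefill side (the function and parameter names are illustrative, not the actual code in this PR): the mismatch raises a ValueError, but neither server is shut down.

```python
# Illustrative sketch only -- names are hypothetical, not this PR's actual code.
# The prefill side compares its own page size against the one reported by the
# decode side and raises on mismatch, but nothing tears either server down.
def check_page_size(prefill_page_size: int, dst_page_size: int) -> None:
    if prefill_page_size != dst_page_size:
        raise ValueError(
            f"Page size mismatch in PD disaggregation: prefill uses "
            f"{prefill_page_size} but decode reported {dst_page_size}. "
            f"Both servers must be launched with the same page size."
        )
```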
Signed-off-by: Raayan Dhar [email protected] <[email protected]>
```python
dst_tp_rank: int
dst_attn_tp_size: int
dst_kv_item_len: int
page_size: int
```
Maybe this would be better?
```diff
-page_size: int
+dst_page_size: int
```
sure, will make the change tomorrow
```python
dst_tp_rank=int(msg[7].decode("ascii")),
dst_attn_tp_size=int(msg[8].decode("ascii")),
dst_kv_item_len=int(msg[9].decode("ascii")),
page_size=int(msg[10].decode("ascii")),
```
ditto
ShangmingCai left a comment:
The motivation has some valid points, but I don't think we should crash the prefill server. That could cause a chain reaction where the entire cluster crashes just because a single misconfigured node was added, which is something we don't want to see. A warning should probably be enough.
```python
decode_tp_size: int
decode_tp_rank: int
dst_kv_item_len: int
page_size: int
```
ditto
```python
decode_tp_size=int(msg[8].decode("ascii")),
decode_tp_rank=int(msg[9].decode("ascii")),
dst_kv_item_len=int(msg[10].decode("ascii")),
page_size=int(msg[11].decode("ascii")),
```
ditto
```diff
     f"TensorRT-LLM MLA only supports page_size of 32 or 64, changing page_size from {self.page_size} to 32."
 )
-self.page_size = 64
+self.page_size = 32
```
Can you elaborate on this change?
Sure, this is from issue #14443 that was opened.
Hmm, ok, makes sense. I'll change it to log a warning instead, so that if there is some downstream issue caused by the page size, we have a better idea from the logs of what happened. That's more in line with the original motivation in the issue. Also, should we log the warning on both prefill and decode? Right now we have it only on the prefill side; I personally think that's fine since the error is related to both servers anyway, but I defer to your judgement. I'm also on Slack if you want to discuss further.
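A rough sketch of the warning variant being discussed, assuming a standard module-level logger (the function and parameter names are illustrative):

```python
import logging

logger = logging.getLogger(__name__)

# Illustrative sketch only: warn instead of raising, so a single misconfigured
# node cannot take the prefill server (and the rest of the cluster) down, while
# the logs still point at the page-size mismatch if KV transfer later fails.
def warn_on_page_size_mismatch(prefill_page_size: int, dst_page_size: int) -> None:
    if prefill_page_size != dst_page_size:
        logger.warning(
            "Page size mismatch in PD disaggregation: prefill uses %d but decode "
            "reported %d. KV transfer between these servers is likely to fail.",
            prefill_page_size,
            dst_page_size,
        )
```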
@raayandhar I think we don't need to notify the prefill server; we can just crash the decode server at the bootstrapping step. I think this is the best way to help users identify the misconfiguration, and it won't affect performance.
Sorry, I realize that the prefill page size wasn't registered to the bootstrap server, so the decode node doesn't know.
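Purely for illustration, here is a sketch of what that could look like if the prefill page size were registered with the bootstrap server and checked during decode bootstrap (the registration/lookup calls and key layout below are hypothetical placeholders, not the real bootstrap API):

```python
# Illustrative sketch only -- `bootstrap`, its put/get methods, and the key
# layout are hypothetical placeholders, not the actual bootstrap server API.

# Prefill side: publish its page size along with the rest of its metadata.
def register_prefill_page_size(bootstrap, rank: int, page_size: int) -> None:
    bootstrap.put(f"prefill/{rank}/page_size", str(page_size))

# Decode side: look the value up during bootstrap and fail fast on mismatch,
# so only the misconfigured decode node exits and the prefill server is untouched.
def validate_decode_page_size(bootstrap, rank: int, decode_page_size: int) -> None:
    prefill_page_size = int(bootstrap.get(f"prefill/{rank}/page_size"))
    if prefill_page_size != decode_page_size:
        raise RuntimeError(
            f"Decode page_size ({decode_page_size}) does not match prefill "
            f"page_size ({prefill_page_size}); aborting decode bootstrap."
        )
```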
Maybe this check covers this scenario:

```python
# Sanity check: The data sub-slice to be sent should fit into the dst buffer.
# This means heads_bytes_per_token_to_send <= (dst_kv_item_len // page_size)
if heads_bytes_per_token_to_send > (dst_kv_item_len // page_size):
    logger.error(
        f"[{mooncake_session_id}] slice size ({heads_bytes_per_token_to_send}) exceeds "
        f"target token slot size ({dst_kv_item_len // page_size})"
    )
    return -1
```
Hmm, I will go and try this tomorrow to see what happens and understand it better. But I think at least some better messaging pointing at the issue (mismatched page sizes) would be nice.
b8zhong left a comment:
When we are in the single-node scenario, we can use different attention backends for prefill and decode (which can have mismatched page sizes when incorrectly set). Can you make sure this also works in the single-node case?
Motivation
See fix #14439 and fix #14443
Modifications
Lowered the defaults as described here: #14443
To check that the page sizes are the same, we have to communicate them between the prefill and decode servers, so I did that, and on the prefill side it will throw an error. But I'm not entirely happy with this design, since the decode server is left hanging and the request will also hang; I think we should be throwing the error on both sides. Also, the check only runs once we send a request, not during warmup / server start. I would really appreciate some feedback here, as I'm sure there's a more intelligent way of doing this. I will continue working on that part.
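For illustration, here is a sketch of the communication step, modeled on the ASCII-encoded metadata frames visible in the review snippets above; the frame positions and helper names are made up for the example:

```python
# Illustrative sketch only: append the sender's page_size as one more
# ASCII-encoded frame in the handshake message (mirroring the existing
# msg[...].decode("ascii") pattern), then compare it on the receiving side.
# The frame layout and helper names here are hypothetical.
def encode_metadata(dst_kv_item_len: int, page_size: int) -> list[bytes]:
    return [
        str(dst_kv_item_len).encode("ascii"),
        str(page_size).encode("ascii"),
    ]

def decode_and_check_page_size(msg: list[bytes], local_page_size: int) -> None:
    dst_page_size = int(msg[1].decode("ascii"))
    if dst_page_size != local_page_size:
        raise ValueError(
            f"page_size mismatch: local={local_page_size}, remote={dst_page_size}"
        )
```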
Example (on prefill):
Accuracy Tests
Benchmarking and Profiling
Checklist