forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathscheduler.py
196 lines (158 loc) · 7.78 KB
/
scheduler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
from abc import ABC, abstractmethod
from collections import namedtuple
from typing import Optional
from tensorrt_llm.bindings import executor as tb_executor
from tensorrt_llm.bindings import internal as tb_internal
from .llm_request import LlmRequest, LlmRequestState
# Convenience alias for the batches of requests passed between schedulers.
RequestList = list[LlmRequest]
# Aggregate result of one scheduling pass (see SimpleScheduler.schedule_request):
# context/generation requests to run this iteration, requests paused for lack of
# resources, disaggregated-generation-init requests that fit, and the total
# number of requests that passed capacity scheduling.
SchedulerOutput = namedtuple("SchedulerOutput", [
    "context_requests", "generation_requests", "paused_requests",
    "fitting_disagg_gen_init_requests", "num_fitting_requests"
])
class ScheduledRequests:
    """Requests selected for the current iteration, grouped by phase.

    Layout mirrors ScheduledRequests in
    cpp/tensorrt_llm/batch_manager/common.h.
    """

    context_requests: RequestList
    generation_requests: RequestList
    paused_requests: RequestList

    @property
    def is_generation_only(self) -> bool:
        """True when there are no context requests and no generation request
        carries draft tokens."""
        if self.context_requests:
            return False
        return not any(req.draft_tokens for req in self.generation_requests)

    @property
    def can_run_cuda_graph(self) -> bool:
        """A batch is CUDA-graph compatible only when it has no context requests."""
        return len(self.context_requests) == 0

    @property
    def batch_size(self) -> int:
        """Total number of scheduled context and generation requests."""
        return sum(len(reqs) for reqs in (self.context_requests,
                                          self.generation_requests))
class RequestScheduler(ABC):
    """Top-level scheduling interface: picks which active requests run in the
    next iteration and how (context vs. generation vs. paused)."""

    @abstractmethod
    def schedule_request(self, active_requests: RequestList,
                         inflight_request_ids: set[int]) -> SchedulerOutput:
        """
        :param active_requests: list of active requests, up to maximum number of sequences
        :param inflight_request_ids: set of request ids that are inflight (of all micro batches)
        :return: SchedulerOutput
        """
        # to be aligned with RequestScheduler::scheduleRequests in cpp/tensorrt_llm/batch_manager/requestScheduler.h
        raise NotImplementedError
class CapacityScheduler(ABC):
    """Decides which active requests fit within resource capacity
    (e.g. KV-cache blocks) for this iteration."""

    @abstractmethod
    def schedule_request(
        self, active_requests: RequestList
    ) -> tuple[list[LlmRequest], list[LlmRequest], list[LlmRequest]]:
        """
        :param active_requests: list of active requests, up to maximum number of sequences
        :return: (fitting_requests, fitting_disagg_gen_init_requests, paused_requests),
            matching how SimpleScheduler.schedule_request unpacks the result
        """
        # to be aligned with CapacityScheduler::scheduleRequests in cpp/tensorrt_llm/batch_manager/capacityScheduler.h
        raise NotImplementedError
class BindCapacityScheduler(CapacityScheduler):
    """CapacityScheduler that delegates the scheduling decision to the C++
    CapacityScheduler binding."""

    def __init__(
        self,
        max_num_requests: int,
        kv_cache_manager,
        scheduler_policy: tb_executor.CapacitySchedulerPolicy = tb_executor.
        CapacitySchedulerPolicy.GUARANTEED_NO_EVICT,
    ):
        """
        :param max_num_requests: maximum number of requests scheduled per iteration
        :param kv_cache_manager: KV-cache manager handed to the binding on each
            call; may be None (its presence is passed to the binding as a flag)
        :param scheduler_policy: capacity policy, defaults to GUARANTEED_NO_EVICT
        """
        super().__init__()
        # Retained so it can be forwarded to the C++ impl on every call.
        self.kv_cache_manager = kv_cache_manager
        # Requests outside the [CONTEXT_INIT, GENERATION_COMPLETE) state
        # window are not scheduled by the binding.
        self.impl = tb_internal.algorithms.CapacityScheduler(
            max_num_requests,
            scheduler_policy,
            kv_cache_manager is not None,
            False,
            LlmRequestState.CONTEXT_INIT,
            LlmRequestState.GENERATION_COMPLETE,
        )

    def schedule_request(
        self, active_requests: RequestList
    ) -> tuple[list[LlmRequest], list[LlmRequest], list[LlmRequest]]:
        """Forward capacity scheduling to the bound C++ algorithm."""
        return self.impl(active_requests, self.kv_cache_manager)
class GuaranteedNoEvictScheduler(CapacityScheduler):
    """Pure-Python GUARANTEED_NO_EVICT capacity scheduler.

    A context request is only admitted when enough KV-cache blocks can be
    reserved for it to run to completion, so no scheduled request ever needs
    to be evicted or paused later.
    """

    # Only schedule requests with
    # no_schedule_until_state <= state < no_schedule_after_state.
    no_schedule_until_state = LlmRequestState.CONTEXT_INIT
    no_schedule_after_state = LlmRequestState.GENERATION_COMPLETE

    def __init__(self, max_num_requests: int, kv_cache_manager):
        """
        :param max_num_requests: maximum number of requests scheduled per iteration
        :param kv_cache_manager: provides block accounting via
            get_max_resource_count() and get_needed_resource_to_completion()
        """
        super().__init__()
        self.max_num_requests = max_num_requests
        self.kv_cache_manager = kv_cache_manager

    def schedule_request(
        self, active_requests: RequestList
    ) -> tuple[list[LlmRequest], list[LlmRequest], list[LlmRequest]]:
        """
        :param active_requests: list of active requests, up to maximum number of sequences
        :return: (fitting_requests, fitting_disagg_gen_init_requests, paused_requests);
            this scheduler never pauses requests nor produces
            disagg-gen-init requests, so the last two lists are always empty.
        """
        scheduled_requests = []
        pending_requests = []
        reserved_blocks = 0
        max_blocks = self.kv_cache_manager.get_max_resource_count()

        # First pass: keep in-flight generation requests scheduled and reserve
        # the blocks they still need to finish; collect the rest as pending.
        for request in active_requests:
            req_state = request.state
            # Skip requests that cannot be scheduled yet or should no longer
            # be scheduled.
            if (req_state.value < self.no_schedule_until_state.value
                    or req_state.value >= self.no_schedule_after_state.value):
                continue
            if (len(scheduled_requests) >= self.max_num_requests
                    or reserved_blocks >= max_blocks):
                break
            elif req_state in (LlmRequestState.GENERATION_IN_PROGRESS,
                               LlmRequestState.GENERATION_TO_COMPLETE):
                scheduled_requests.append(request)
                reserved_blocks += self.kv_cache_manager.get_needed_resource_to_completion(
                    request)
            else:
                pending_requests.append(request)

        # Second pass: admit new context requests only while they can run to
        # completion with the blocks left over after the reservations above.
        available_blocks = max_blocks - reserved_blocks
        for request in pending_requests:
            if len(scheduled_requests) >= self.max_num_requests:
                break
            if request.state == LlmRequestState.CONTEXT_INIT:
                needed_blocks = self.kv_cache_manager.get_needed_resource_to_completion(
                    request)
                if needed_blocks <= available_blocks:
                    scheduled_requests.append(request)
                    available_blocks -= needed_blocks
                else:
                    # Stop at the first request that does not fit; later
                    # requests are not allowed to jump ahead of it.
                    break

        assert len(scheduled_requests) > 0, (
            "no pending request can get enough resource to complete, "
            "please increase KV cache pool size.")
        # BUGFIX: return a 3-tuple to match the CapacityScheduler interface
        # (and BindCapacityScheduler); SimpleScheduler.schedule_request unpacks
        # three values, so the original 2-tuple return crashed there.
        return scheduled_requests, [], []
class MicroBatchScheduler(ABC):
    """Splits capacity-approved requests into the context/generation sets
    that form one micro batch."""

    @abstractmethod
    def schedule(
        self, active_requests: RequestList, inflight_request_ids: set[int]
    ) -> tuple[list[LlmRequest], list[LlmRequest]]:
        """
        :param active_requests: list of active requests, up to maximum number of sequences
        :param inflight_request_ids: set of request ids that are inflight (of all micro batches)
        :return: (contextRequests, generationRequests)
        """
        # to be aligned with MicroBatchScheduler::scheduleRequests in cpp/tensorrt_llm/batch_manager/microBatchScheduler.h
        raise NotImplementedError
class BindMicroBatchScheduler(MicroBatchScheduler):
    """MicroBatchScheduler that delegates to the C++ MicroBatchScheduler
    binding."""

    def __init__(
        self,
        max_batch_size: int,
        # BUGFIX: was `int = None` (implicit Optional, invalid per PEP 484).
        max_num_tokens: Optional[int] = None,
        ctx_chunk_config: Optional[
            tb_internal.batch_manager.ContextChunkingConfig] = None,
    ) -> None:
        """
        :param max_batch_size: maximum number of requests in one micro batch
        :param max_num_tokens: optional cap on total tokens per micro batch,
            forwarded both to the binding and to each schedule() call
        :param ctx_chunk_config: optional context-chunking configuration
            forwarded to the binding
        """
        super().__init__()
        self.max_batch_size = max_batch_size
        self.max_num_tokens = max_num_tokens
        self.impl = tb_internal.algorithms.MicroBatchScheduler(
            ctx_chunk_config, max_num_tokens)

    def schedule(
        self, active_requests: RequestList, inflight_request_ids: set[int]
    ) -> tuple[list[LlmRequest], list[LlmRequest]]:
        """Delegate micro-batch selection to the bound C++ algorithm."""
        return self.impl(active_requests, inflight_request_ids,
                         self.max_batch_size, self.max_num_tokens)
class SimpleScheduler(RequestScheduler):
    """Two-stage RequestScheduler: capacity filtering followed by
    micro-batch selection."""

    def __init__(self, capacity_scheduler: CapacityScheduler,
                 micro_batch_scheduler: MicroBatchScheduler):
        """
        :param capacity_scheduler: stage 1 — decides which requests fit
        :param micro_batch_scheduler: stage 2 — splits fitting requests into
            context/generation sets
        """
        super().__init__()
        self.capacity_scheduler = capacity_scheduler
        self.micro_batch_scheduler = micro_batch_scheduler

    def schedule_request(self, active_requests: RequestList,
                         inflight_request_ids: set[int]) -> SchedulerOutput:
        """Run both scheduling stages and bundle the result."""
        # Stage 1: capacity — which requests fit within available resources.
        fitting, fitting_disagg_gen_init, paused = \
            self.capacity_scheduler.schedule_request(active_requests)
        # Stage 2: micro batch — split fitting requests by phase.
        context, generation = self.micro_batch_scheduler.schedule(
            fitting, inflight_request_ids)
        return SchedulerOutput(
            context_requests=context,
            generation_requests=generation,
            paused_requests=paused,
            fitting_disagg_gen_init_requests=fitting_disagg_gen_init,
            num_fitting_requests=len(fitting))