Skip to content

Commit afd04dc

Browse files
feat(server): update vllm version (#723)
1 parent f848dec commit afd04dc

File tree

3 files changed

+21
-22
lines changed

3 files changed

+21
-22
lines changed

router/src/main.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,10 @@ fn main() -> Result<(), RouterError> {
233233
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
234234
);
235235
}
236+
if max_total_tokens as u32 > max_supported_batch_total_tokens {
237+
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_supported_batch_total_tokens}")));
238+
}
239+
236240
max_supported_batch_total_tokens
237241
}
238242
};
@@ -270,7 +274,7 @@ fn main() -> Result<(), RouterError> {
270274
ngrok_authtoken,
271275
ngrok_edge,
272276
)
273-
.await?;
277+
.await?;
274278
Ok(())
275279
})
276280
}

server/Makefile-vllm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
vllm_commit := d284b831c17f42a8ea63369a06138325f73c4cf9
1+
vllm_commit := 084ca75d4271f8f67be731bc58e0d41d8e0afd3a
22

33
vllm:
44
# Clone vllm

server/text_generation_server/utils/layers.py

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -219,36 +219,31 @@ def load(config, prefix: str, weights):
219219
)
220220

221221
def forward(self, input: torch.Tensor) -> torch.Tensor:
222-
if not self.should_gather:
223-
return super().forward(input)
224-
225222
world_size = self.process_group.size()
226-
if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
223+
# Fast branch for single requests
224+
if (
225+
self.should_gather
226+
and len(input.shape) == 2
227+
and isinstance(self.linear, FastLinear)
228+
and input.shape[0] == 1
229+
):
227230
out_dim = self.linear.weight.shape[0]
228231

229-
if input.shape[0] == 1:
230-
world_out = input.new_empty(1, out_dim * world_size)
231-
local_out = input.new_empty(1, out_dim)
232-
gather_input = local_out
233-
else:
234-
world_out = input.new_empty(out_dim * world_size, input.shape[0])
235-
gather_input = input.new_empty(out_dim, input.shape[0])
236-
local_out = gather_input.T
232+
world_out = input.new_empty(1, out_dim * world_size)
233+
local_out = input.new_empty(1, out_dim)
237234

238235
torch.mm(input, self.linear.weight.T, out=local_out)
239236

240237
torch.distributed.all_gather_into_tensor(
241-
world_out, gather_input, group=self.process_group
238+
world_out, local_out, group=self.process_group
242239
)
243-
244-
if input.shape[0] == 1:
245-
return world_out
246-
return world_out.T
240+
return world_out
247241

248242
output = super().forward(input)
249-
world_output = [
250-
torch.empty_like(output) for _ in range(self.process_group.size())
251-
]
243+
if not self.should_gather:
244+
return output
245+
246+
world_output = [torch.empty_like(output) for _ in range(world_size)]
252247
torch.distributed.all_gather(world_output, output, group=self.process_group)
253248
world_output = torch.cat(world_output, dim=-1)
254249
return world_output

0 commit comments

Comments (0)