Commit e0bc399: Added main logic
Parent: 440dcde

2 files changed: +43 -10 lines

python/openai/openai_frontend/engine/triton_engine.py (16 additions, 5 deletions)
@@ -1,4 +1,4 @@
-# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions
@@ -78,11 +78,22 @@ class TritonModelMetadata:
 
 class TritonLLMEngine(LLMEngine):
     def __init__(
-        self, server: tritonserver.Server, tokenizer: str, backend: Optional[str] = None
+        self,
+        server: tritonserver.Server,
+        tokenizer_map: Dict[str, str] = None,
+        backend: Optional[str] = None,
     ):
         # Assume an already configured and started server
         self.server = server
-        self.tokenizer = self._get_tokenizer(tokenizer)
+        self.tokenizer_map = {}
+        if tokenizer_map:
+            for model_name, tokenizer_path in tokenizer_map.items():
+                try:
+                    self.tokenizer_map[model_name] = get_tokenizer(tokenizer_path)
+                except Exception as e:
+                    print(
+                        f"Warning: Failed to load tokenizer for {model_name} from {tokenizer_path}: {e}"
+                    )
         # TODO: Reconsider name of "backend" vs. something like "request_format"
         self.backend = backend
 
@@ -253,12 +264,12 @@ def _get_model_metadata(self) -> Dict[str, TritonModelMetadata]:
             if not backend and model.config()["platform"] == "ensemble":
                 backend = "ensemble"
             print(f"Found model: {name=}, {backend=}")
-
+            default_tokenizer = self.tokenizer_map.get("default", None)
             metadata = TritonModelMetadata(
                 name=name,
                 backend=backend,
                 model=model,
-                tokenizer=self.tokenizer,
+                tokenizer=self.tokenizer_map.get(name, default_tokenizer),
                 create_time=self.create_time,
                 request_converter=self._determine_request_converter(backend),
             )
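
Taken together, the two hunks above make tokenizer selection per-model: __init__ eagerly loads every entry in tokenizer_map, and _get_model_metadata resolves each model's tokenizer by name, falling back to whatever was registered under the "default" key. A minimal, runnable sketch of that resolution order (plain strings stand in for the objects get_tokenizer() would return; the model names and paths are made up):

    # Sketch only, not part of the commit: strings stand in for loaded tokenizers.
    tokenizer_map = {
        "default": "meta-llama/Llama-3.1-8B-Instruct",    # hypothetical path
        "mistral": "mistralai/Mistral-7B-Instruct-v0.3",  # hypothetical path
    }

    default_tokenizer = tokenizer_map.get("default", None)
    for name in ("mistral", "unmapped-model"):
        print(name, "->", tokenizer_map.get(name, default_tokenizer))
    # mistral -> mistralai/Mistral-7B-Instruct-v0.3
    # unmapped-model -> meta-llama/Llama-3.1-8B-Instruct (the "default" entry)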

python/openai/openai_frontend/main.py (27 additions, 5 deletions)
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 
-# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright 2024-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions
@@ -95,12 +95,20 @@ def parse_args():
         required=True,
         help="Path to the Triton model repository holding the models to be served",
     )
+    # TODO: determine what to do with single tokenizer flag
     triton_group.add_argument(
         "--tokenizer",
         type=str,
         default=None,
         help="HuggingFace ID or local folder path of the Tokenizer to use for chat templates",
     )
+    triton_group.add_argument(
+        "--tokenizers",
+        type=str,
+        nargs="+",  # Accept multiple arguments
+        default=[],
+        help="List of HuggingFace IDs or local folder paths of Tokenizers to use. Format: model_name:tokenizer_path",
+    )
     triton_group.add_argument(
         "--backend",
         type=str,
@@ -160,8 +168,22 @@ def parse_args():
 def main():
     args = parse_args()
 
-    # Initialize a Triton Inference Server pointing at LLM models
-    server: tritonserver.Server = tritonserver.Server(
+    # Parse tokenizer mappings
+    tokenizer_map = {}
+    for tokenizer_spec in args.tokenizers:
+        try:
+            model_name, tokenizer_path = tokenizer_spec.split(":")
+            tokenizer_map[model_name] = tokenizer_path
+        except ValueError:
+            print(
+                f"Warning: Skipping invalid tokenizer specification: {tokenizer_spec}. Format should be 'model_name:tokenizer_path'"
+            )
+
+    if args.tokenizer:
+        tokenizer_map["default"] = args.tokenizer
+
+    # Initialize Triton server
+    server = tritonserver.Server(
         model_repository=args.model_repository,
         log_verbose=args.tritonserver_log_verbose_level,
         log_info=True,
@@ -170,8 +192,8 @@ def main():
     ).start(wait_until_ready=True)
 
     # Wrap Triton Inference Server in an interface-conforming "LLMEngine"
-    engine: TritonLLMEngine = TritonLLMEngine(
-        server=server, tokenizer=args.tokenizer, backend=args.backend
+    engine = TritonLLMEngine(
+        server=server, tokenizer_map=tokenizer_map, backend=args.backend
     )
 
     # Attach TritonLLMEngine as the backbone for inference and model management
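
The parsing loop in main() leans on tuple unpacking: tokenizer_spec.split(":") must yield exactly two parts, so a spec with zero or more than one colon raises ValueError and is skipped with a warning, while --tokenizer (when given) is folded in under the "default" key so the old single-tokenizer behavior keeps working. A standalone sketch of the same loop, with made-up specs:

    # Standalone sketch of the spec parsing added in main(); specs are made up.
    def parse_tokenizer_specs(specs):
        tokenizer_map = {}
        for tokenizer_spec in specs:
            try:
                # Unpacking fails unless split() yields exactly two parts.
                model_name, tokenizer_path = tokenizer_spec.split(":")
                tokenizer_map[model_name] = tokenizer_path
            except ValueError:
                print(f"Warning: Skipping invalid tokenizer specification: {tokenizer_spec}")
        return tokenizer_map

    print(parse_tokenizer_specs(["llama:meta-llama/Llama-2-7b-hf", "no-colon-spec"]))
    # Warning for "no-colon-spec", then: {'llama': 'meta-llama/Llama-2-7b-hf'}

With both flags, a launch would pass something like --tokenizer <default-id> --tokenizers modelA:tokenizerA modelB:tokenizerB (the names here are placeholders, not values from the commit).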
