
Commit 211b54a

Narsil and Vinno97 authored

Rebased #617 (#868)

# What does this PR do?

Fixes # (issue)

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a GitHub issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

Co-authored-by: Vincent Brouwers <[email protected]>

1 parent 4486f78 commit 211b54a

File tree

23 files changed: +529 −34 lines

benchmark/src/generation.rs

Lines changed: 7 additions & 1 deletion

@@ -37,6 +37,7 @@ pub(crate) async fn generation_task(
     batch_size: Vec<u32>,
     sequence_length: u32,
     decode_length: u32,
+    top_n_tokens: Option<u32>,
     n_runs: usize,
     warmups: usize,
     parameters: NextTokenChooserParameters,
@@ -48,7 +49,7 @@ pub(crate) async fn generation_task(
     // End task if a message is received on shutdown_receiver
     // _shutdown_guard_sender will be dropped once the task is finished
     tokio::select! {
-        res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, n_runs, warmups, parameters, client, run_sender.clone()) => {
+        res = generate_runs(tokenizer, batch_size, sequence_length, decode_length, top_n_tokens, n_runs, warmups, parameters, client, run_sender.clone()) => {
             if let Err(err) = res {
                 run_sender.send(Err(err)).await.unwrap_or(());
             }
@@ -64,6 +65,7 @@ async fn generate_runs(
     batch_size: Vec<u32>,
     sequence_length: u32,
     decode_length: u32,
+    top_n_tokens: Option<u32>,
     n_runs: usize,
     warmups: usize,
     parameters: NextTokenChooserParameters,
@@ -82,6 +84,7 @@ async fn generate_runs(
                 b,
                 decode_length,
                 parameters.clone(),
+                top_n_tokens,
                 &mut client,
             )
             .await?;
@@ -97,6 +100,7 @@ async fn generate_runs(
                 b,
                 decode_length,
                 parameters.clone(),
+                top_n_tokens,
                 &mut client,
             )
             .await?;
@@ -130,6 +134,7 @@ async fn prefill(
     batch_size: u32,
     decode_length: u32,
     parameters: NextTokenChooserParameters,
+    top_n_tokens: Option<u32>,
     client: &mut ShardedClient,
 ) -> Result<(Prefill, CachedBatch), ClientError> {
     // Create requests
@@ -145,6 +150,7 @@ async fn prefill(
             stop_sequences: vec![],
             ignore_eos_token: true, // Will not stop even if a eos token is generated
         }),
+        top_n_tokens: top_n_tokens.unwrap_or(0),
     })
     .collect();

benchmark/src/lib.rs

Lines changed: 3 additions & 0 deletions

@@ -22,6 +22,7 @@ pub async fn run(
     batch_size: Vec<u32>,
     sequence_length: u32,
     decode_length: u32,
+    top_n_tokens: Option<u32>,
     n_runs: usize,
     warmups: usize,
     temperature: Option<f32>,
@@ -70,6 +71,7 @@ pub async fn run(
         batch_size.clone(),
         sequence_length,
         decode_length,
+        top_n_tokens,
         n_runs,
         warmups,
         parameters,
@@ -130,6 +132,7 @@ pub async fn run(
         tokenizer_name,
         sequence_length,
         decode_length,
+        top_n_tokens,
         n_runs,
         warmups,
         temperature,

benchmark/src/main.rs

Lines changed: 7 additions & 0 deletions

@@ -93,6 +93,11 @@ struct Args {
     /// decoding strategies, for full doc refer to the `text-generation-server`
     #[clap(long, env)]
     do_sample: bool,
+
+    /// Generation parameter in case you want to specifically test/debug particular
+    /// decoding strategies, for full doc refer to the `text-generation-server`
+    #[clap(long, env)]
+    top_n_tokens: Option<u32>,
 }
 
 fn main() -> Result<(), Box<dyn std::error::Error>> {
@@ -117,6 +122,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         watermark,
         do_sample,
         master_shard_uds_path,
+        top_n_tokens,
     } = args;
 
     let batch_size = batch_size.unwrap_or(vec![1, 2, 4, 8, 16, 32]);
@@ -173,6 +179,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
         batch_size,
         sequence_length,
         decode_length,
+        top_n_tokens,
         runs,
         warmups,
         temperature,

benchmark/src/table.rs

Lines changed: 2 additions & 0 deletions

@@ -7,6 +7,7 @@ pub(crate) fn parameters_table(
     tokenizer_name: String,
     sequence_length: u32,
     decode_length: u32,
+    top_n_tokens: Option<u32>,
     n_runs: usize,
     warmups: usize,
     temperature: Option<f32>,
@@ -24,6 +25,7 @@ pub(crate) fn parameters_table(
     builder.push_record(["Model", &tokenizer_name]);
     builder.push_record(["Sequence Length", &sequence_length.to_string()]);
     builder.push_record(["Decode Length", &decode_length.to_string()]);
+    builder.push_record(["Top N Tokens", &format!("{top_n_tokens:?}")]);
     builder.push_record(["N Runs", &n_runs.to_string()]);
     builder.push_record(["Warmups", &warmups.to_string()]);
     builder.push_record(["Temperature", &format!("{temperature:?}")]);

clients/python/text_generation/client.py

Lines changed: 16 additions & 0 deletions

@@ -75,6 +75,7 @@ def generate(
         typical_p: Optional[float] = None,
         watermark: bool = False,
         decoder_input_details: bool = False,
+        top_n_tokens: Optional[int] = None,
     ) -> Response:
         """
         Given a prompt, generate the following text
@@ -113,6 +114,8 @@ def generate(
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
             decoder_input_details (`bool`):
                 Return the decoder input token logprobs and ids
+            top_n_tokens (`int`):
+                Return the `n` most likely tokens at each step
 
         Returns:
             Response: generated response
@@ -134,6 +137,7 @@ def generate(
             typical_p=typical_p,
             watermark=watermark,
             decoder_input_details=decoder_input_details,
+            top_n_tokens=top_n_tokens
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)
 
@@ -164,6 +168,7 @@ def generate_stream(
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        top_n_tokens: Optional[int] = None,
     ) -> Iterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens
@@ -198,6 +203,8 @@ def generate_stream(
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            top_n_tokens (`int`):
+                Return the `n` most likely tokens at each step
 
         Returns:
             Iterator[StreamResponse]: stream of generated tokens
@@ -219,6 +226,7 @@ def generate_stream(
             truncate=truncate,
             typical_p=typical_p,
             watermark=watermark,
+            top_n_tokens=top_n_tokens,
         )
         request = Request(inputs=prompt, stream=True, parameters=parameters)
 
@@ -317,6 +325,7 @@ async def generate(
         typical_p: Optional[float] = None,
         watermark: bool = False,
         decoder_input_details: bool = False,
+        top_n_tokens: Optional[int] = None,
     ) -> Response:
         """
         Given a prompt, generate the following text asynchronously
@@ -355,6 +364,8 @@ async def generate(
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
             decoder_input_details (`bool`):
                 Return the decoder input token logprobs and ids
+            top_n_tokens (`int`):
+                Return the `n` most likely tokens at each step
 
         Returns:
             Response: generated response
@@ -376,6 +387,7 @@ async def generate(
             truncate=truncate,
             typical_p=typical_p,
             watermark=watermark,
+            top_n_tokens=top_n_tokens,
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)
 
@@ -404,6 +416,7 @@ async def generate_stream(
         truncate: Optional[int] = None,
         typical_p: Optional[float] = None,
         watermark: bool = False,
+        top_n_tokens: Optional[int] = None,
     ) -> AsyncIterator[StreamResponse]:
         """
         Given a prompt, generate the following stream of tokens asynchronously
@@ -438,6 +451,8 @@ async def generate_stream(
                 See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information
             watermark (`bool`):
                 Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)
+            top_n_tokens (`int`):
+                Return the `n` most likely tokens at each step
 
         Returns:
             AsyncIterator[StreamResponse]: stream of generated tokens
@@ -459,6 +474,7 @@ async def generate_stream(
             truncate=truncate,
             typical_p=typical_p,
             watermark=watermark,
+            top_n_tokens=top_n_tokens,
         )
         request = Request(inputs=prompt, stream=True, parameters=parameters)

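To make the new parameter concrete, here is a minimal usage sketch of the synchronous Python client after this change. The endpoint URL and prompt are hypothetical; `top_n_tokens` and `details.top_tokens` follow the signatures and fields added in this commit.

```python
from text_generation import Client

# Hypothetical local endpoint; point this at a running text-generation-inference server.
client = Client("http://127.0.0.1:8080")

# Ask for the 5 most likely tokens at each generation step, in addition to the sampled one.
response = client.generate(
    "The capital of France is",
    max_new_tokens=10,
    top_n_tokens=5,
)

print(response.generated_text)

# `details.top_tokens` (added in types.py below) holds one list of candidate tokens per step.
if response.details.top_tokens:
    for step, candidates in enumerate(response.details.top_tokens):
        print(step, [(t.text, t.logprob) for t in candidates])
```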
clients/python/text_generation/types.py

Lines changed: 15 additions & 3 deletions

@@ -39,6 +39,8 @@ class Parameters(BaseModel):
     details: bool = False
     # Get decoder input token logprobs and ids
     decoder_input_details: bool = False
+    # Return the N most likely tokens at each step
+    top_n_tokens: Optional[int]
 
     @validator("best_of")
     def valid_best_of(cls, field_value, values):
@@ -101,6 +103,12 @@ def valid_typical_p(cls, v):
             raise ValidationError("`typical_p` must be > 0.0 and < 1.0")
         return v
 
+    @validator("top_n_tokens")
+    def valid_top_n_tokens(cls, v):
+        if v is not None and v <= 0:
+            raise ValidationError("`top_n_tokens` must be strictly positive")
+        return v
+
 
 class Request(BaseModel):
     # Prompt
@@ -125,9 +133,7 @@ def valid_best_of_stream(cls, field_value, values):
             and parameters.best_of > 1
             and field_value
         ):
-            raise ValidationError(
-                "`best_of` != 1 is not supported when `stream` == True"
-            )
+            raise ValidationError("`best_of` != 1 is not supported when `stream` == True")
         return field_value
 
 
@@ -179,6 +185,8 @@ class BestOfSequence(BaseModel):
     prefill: List[InputToken]
     # Generated tokens
     tokens: List[Token]
+    # Most likely tokens
+    top_tokens: Optional[List[List[Token]]]
 
 
 # `generate` details
@@ -193,6 +201,8 @@ class Details(BaseModel):
     prefill: List[InputToken]
     # Generated tokens
     tokens: List[Token]
+    # Most likely tokens
+    top_tokens: Optional[List[List[Token]]]
     # Additional sequences when using the `best_of` parameter
     best_of_sequences: Optional[List[BestOfSequence]]
 
@@ -219,6 +229,8 @@ class StreamDetails(BaseModel):
 class StreamResponse(BaseModel):
     # Generated token
     token: Token
+    # Most likely tokens
+    top_tokens: Optional[List[Token]]
     # Complete generated text
     # Only available when the generation is finished
     generated_text: Optional[str]

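And a streaming counterpart, again only a sketch against the same hypothetical endpoint: per the model changes above, each `StreamResponse` now carries a `top_tokens` list for the token it streams.

```python
from text_generation import Client

client = Client("http://127.0.0.1:8080")  # hypothetical endpoint

for chunk in client.generate_stream(
    "The capital of France is",
    max_new_tokens=10,
    top_n_tokens=3,
):
    # `top_tokens` holds the top candidates for this step; it is None when not requested.
    candidates = chunk.top_tokens or []
    print(chunk.token.text, [(t.text, t.logprob) for t in candidates])
```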
launcher/src/main.rs

Lines changed: 10 additions & 0 deletions

@@ -159,6 +159,14 @@ struct Args {
     #[clap(default_value = "4", long, env)]
     max_stop_sequences: usize,
 
+    /// This is the maximum allowed value for clients to set `top_n_tokens`.
+    /// `top_n_tokens` is used to return information about the `n` most likely
+    /// tokens at each generation step, instead of just the sampled token. This
+    /// information can be used for downstream tasks like classification or
+    /// ranking.
+    #[clap(default_value = "5", long, env)]
+    max_top_n_tokens: u32,
+
     /// This is the maximum allowed input length (expressed in number of tokens)
     /// for users. The larger this value, the longer prompt users can send which
     /// can impact the overall memory required to handle the load.
@@ -929,6 +937,8 @@ fn spawn_webserver(
         args.max_best_of.to_string(),
         "--max-stop-sequences".to_string(),
         args.max_stop_sequences.to_string(),
+        "--max-top-n-tokens".to_string(),
+        args.max_top_n_tokens.to_string(),
         "--max-input-length".to_string(),
         args.max_input_length.to_string(),
         "--max-total-tokens".to_string(),

proto/generate.proto

Lines changed: 15 additions & 0 deletions

@@ -91,6 +91,8 @@ message Request {
   StoppingCriteriaParameters stopping_parameters = 5;
   /// Return prefill logprobs
   bool prefill_logprobs = 6;
+  /// Return most likely n tokens
+  uint32 top_n_tokens = 7;
 }
 
 message Batch {
@@ -141,6 +143,17 @@ message PrefillTokens {
   repeated string texts = 3;
 }
 
+message TopTokens {
+  /// Top Token IDs
+  repeated uint32 ids = 1;
+  /// Top Logprobs
+  repeated float logprobs = 2;
+  /// Top Token Texts
+  repeated string texts = 3;
+  /// If the tokens are special
+  repeated bool is_special = 6;
+}
+
 message Generation {
   /// Request ID
   uint64 request_id = 1;
@@ -156,6 +169,8 @@ message Generation {
   bool token_is_special = 6;
   /// Complete generated text
   optional GeneratedText generated_text = 7;
+  /// Top tokens
+  TopTokens top_tokens = 8;
 }
 
 message FilterBatchRequest {

router/client/src/client.rs

Lines changed: 1 addition & 0 deletions

@@ -131,6 +131,7 @@ impl Client {
                     ignore_eos_token: false,
                 }),
                 prefill_logprobs: true,
+                top_n_tokens: 20,
             });
             n_tokens += max_input_length;
         }

router/src/health.rs

Lines changed: 1 addition & 0 deletions

@@ -50,6 +50,7 @@ impl Health {
                 stop_sequences: vec![],
                 ignore_eos_token: false,
             }),
+            top_n_tokens: 0,
         };
         let batch = Batch {
             id: BATCH_ID,

0 commit comments
