Commit 44ce098

feat(server): pre-allocate max attention mask (#75)
1 parent 78063c0 commit 44ce098

File tree

7 files changed (+148, -114 lines)
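
The change, in a nutshell: instead of growing each batch's attention mask with a torch.cat at every decode step, the mask is allocated once at its maximum final width (longest prompt plus the largest max_new_tokens in the batch) and a single column is written per generated token. A minimal, self-contained sketch of the idea (all sizes and the left-padding layout are illustrative, not the repository's exact API):

import torch

# Illustrative sizes: longest prompt in the batch is 3 tokens, the largest
# generation budget (max_new_tokens) is 4 tokens.
max_sequence_length = 3
padding_right_offset = 4

# Allocate the mask once, at its maximum possible width.
attention_mask = torch.zeros(
    (2, max_sequence_length + padding_right_offset), dtype=torch.int64
)
# Prompts are left-padded in this sketch, so their mask fills the right end
# of the prompt region: a 2-token prompt and a 3-token prompt.
attention_mask[0, 1:max_sequence_length] = 1
attention_mask[1, :max_sequence_length] = 1

# One decode step: mark the new token's column in place, no torch.cat.
attention_mask[:, -padding_right_offset] = 1
padding_right_offset -= 1

# The model only ever sees the populated prefix.
print(attention_mask[:, :-padding_right_offset])
# tensor([[0, 1, 1, 1],
#         [1, 1, 1, 1]])

With this layout the mask tensor stays the same object across steps; the only per-step mutation is one in-place write and an offset decrement.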

server/tests/models/test_bloom.py

Lines changed: 13 additions & 12 deletions

@@ -65,8 +65,8 @@ def test_batch_from_pb(default_pb_batch, default_bloom_batch):
     assert batch.input_ids[0][-1] == 10264
     assert torch.all(batch.input_ids[0][:-1] == 3)

-    assert batch.attention_mask[0][-1] == 1
-    assert torch.all(batch.attention_mask[0][:-1] == 0)
+    assert batch.attention_mask[0][0] == 1
+    assert torch.all(batch.attention_mask[0][1:] == 0)

     assert batch.past_key_values is None

@@ -98,16 +98,13 @@ def test_causal_lm_generate_token(default_bloom, default_bloom_batch):
     assert not next_batch.keys_head_dim_last

     assert len(next_batch.all_input_ids) == next_batch.size
-    assert (
-        len(next_batch.all_input_ids[0])
-        == len(next_batch.attention_mask[0])
-        == sequence_length + 1
-    )
+    assert len(next_batch.all_input_ids[0]) == sequence_length + 1
+    assert len(next_batch.attention_mask[0]) == 11
     assert torch.all(next_batch.all_input_ids[0][-2:] == 10264)
     assert torch.all(next_batch.all_input_ids[0][:-2] == 3)

-    assert torch.all(next_batch.attention_mask[0][-2:] == 1)
-    assert torch.all(next_batch.attention_mask[0][:-2] == 0)
+    assert torch.all(next_batch.attention_mask[0][:2] == 1)
+    assert torch.all(next_batch.attention_mask[0][2:] == 0)

     assert next_batch.input_ids.shape == (next_batch.size, 1)
     assert next_batch.input_ids[0, 0] == 10264
@@ -213,9 +210,13 @@ def test_batch_concatenate(
     assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0])
     assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1])

-    assert torch.all(next_batch.attention_mask[0] == 1)
-    assert torch.all(next_batch.attention_mask[1:, -2:] == 1)
-    assert torch.all(next_batch.attention_mask[1:, :-2] == 0)
+    assert torch.all(
+        next_batch.attention_mask[0, : -next_batch.padding_right_offset] == 1
+    )
+    assert torch.all(
+        next_batch.attention_mask[1:, 1 : -next_batch.padding_right_offset] == 1
+    )
+    assert torch.all(next_batch.attention_mask[1:, 3:] == 0)

     assert next_batch.batch_id == 0
     assert torch.all(next_batch.input_ids == 10264)
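
For context on the new numbers in this test: with the pre-allocated mask, the width is the input length plus max_new_tokens rather than the tokenizer's padded length. A rough reconstruction, assuming the default fixture uses a 1-token input and max_new_tokens = 10 (values inferred from the asserted width of 11, not read from the fixture code):

import torch

# Assumed fixture parameters (inferred from the asserted mask width of 11).
input_length = 1
max_new_tokens = 10

# from_pb: the mask is allocated at input_length + max_new_tokens columns,
# with only the prompt token marked.
attention_mask = torch.zeros((1, input_length + max_new_tokens), dtype=torch.int64)
attention_mask[0, :input_length] = 1
assert attention_mask[0, 0] == 1 and torch.all(attention_mask[0, 1:] == 0)

# After one generate_token step, one more column is marked in place.
padding_right_offset = max_new_tokens
attention_mask[0, -padding_right_offset] = 1
assert len(attention_mask[0]) == 11
assert torch.all(attention_mask[0, :2] == 1) and torch.all(attention_mask[0, 2:] == 0)
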

server/tests/models/test_causal_lm.py

Lines changed: 13 additions & 12 deletions

@@ -62,8 +62,8 @@ def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
     assert batch.input_ids[0][-1] == 14402
     assert torch.all(batch.input_ids[0][:-1] == 50256)

-    assert batch.attention_mask[0][-1] == 1
-    assert torch.all(batch.attention_mask[0][:-1] == 0)
+    assert batch.attention_mask[0, 0] == 1
+    assert torch.all(batch.attention_mask[0, 1:] == 0)

     assert batch.past_key_values is None

@@ -94,17 +94,14 @@ def test_causal_lm_generate_token(default_causal_lm, default_causal_lm_batch):
     assert isinstance(next_batch, CausalLMBatch)

     assert len(next_batch.all_input_ids) == next_batch.size
-    assert (
-        len(next_batch.all_input_ids[0])
-        == len(next_batch.attention_mask[0])
-        == sequence_length + 1
-    )
+    assert len(next_batch.all_input_ids[0]) == sequence_length + 1
+    assert len(next_batch.attention_mask[0]) == 11
     assert next_batch.all_input_ids[0][-1] == 13
     assert next_batch.all_input_ids[0][-2] == 14402
     assert torch.all(next_batch.all_input_ids[0][:-2] == 50256)

-    assert torch.all(next_batch.attention_mask[0][-2:] == 1)
-    assert torch.all(next_batch.attention_mask[0][:-2] == 0)
+    assert torch.all(next_batch.attention_mask[0][0:2] == 1)
+    assert torch.all(next_batch.attention_mask[0][2:] == 0)

     assert next_batch.input_ids.shape == (next_batch.size, 1)
     assert next_batch.input_ids[0, 0] == 13
@@ -210,9 +207,13 @@ def test_batch_concatenate(
     assert torch.equal(next_batch.all_input_ids[1], next_batch_1.all_input_ids[0])
     assert torch.equal(next_batch.all_input_ids[2], next_batch_1.all_input_ids[1])

-    assert torch.all(next_batch.attention_mask[0] == 1)
-    assert torch.all(next_batch.attention_mask[1:, -2:] == 1)
-    assert torch.all(next_batch.attention_mask[1:, :-2] == 0)
+    assert torch.all(
+        next_batch.attention_mask[0, : -next_batch.padding_right_offset] == 1
+    )
+    assert torch.all(
+        next_batch.attention_mask[1:, 1 : -next_batch.padding_right_offset] == 1
+    )
+    assert torch.all(next_batch.attention_mask[1:, 3:] == 0)

     assert next_batch.batch_id == 0
     assert next_batch.input_ids[0, 0] == 12355

server/tests/models/test_seq2seq_lm.py

Lines changed: 4 additions & 8 deletions

@@ -106,7 +106,7 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
     assert len(generations) == len(next_batch)
     assert isinstance(next_batch, Seq2SeqLMBatch)

-    assert torch.equal(next_batch.input_ids, default_seq2seq_lm_batch.input_ids)
+    assert next_batch.input_ids is None
     assert torch.equal(
         next_batch.attention_mask, default_seq2seq_lm_batch.attention_mask
     )
@@ -220,11 +220,6 @@ def test_batch_concatenate(

     assert next_batch.batch_id == 0

-    assert torch.all(next_batch.input_ids[:, 0] == 4268)
-    assert torch.all(next_batch.input_ids[:, 1] == 1)
-
-    assert torch.all(next_batch.attention_mask == 1)
-
     assert torch.equal(
         next_batch.decoder_input_ids[0], next_batch_0.decoder_input_ids[0]
     )
@@ -233,9 +228,10 @@ def test_batch_concatenate(
         next_batch.decoder_input_ids[1:, -2:], next_batch_1.decoder_input_ids
     )

-    assert torch.all(next_batch.decoder_attention_mask[0] == 1)
+    assert torch.all(next_batch.decoder_attention_mask[0, :3] == 1)
+    assert torch.all(next_batch.decoder_attention_mask[0, 3:] == 0)
     assert torch.all(next_batch.decoder_attention_mask[1:, 0] == 0)
-    assert torch.all(next_batch.decoder_attention_mask[1:, -2:] == 1)
+    assert torch.all(next_batch.decoder_attention_mask[1:, 1:3] == 1)

     assert torch.equal(
         next_batch.encoder_last_hidden_state[0],

server/text_generation/models/causal_lm.py

Lines changed: 50 additions & 21 deletions

@@ -37,6 +37,7 @@ class CausalLMBatch(Batch):
     # Metadata used for padding
     size: int
     max_sequence_length: int
+    padding_right_offset: int

     # Past metadata
     keys_head_dim_last: bool = True
@@ -61,47 +62,67 @@ def from_pb(
         input_lengths = []

         # Parse batch
+        max_sequence_length = 0
+        padding_right_offset = 0
         for r in pb.requests:
             inputs.append(r.inputs)
             input_lengths.append(r.input_length)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
-            stopping_criterias.append(
-                StoppingCriteria.from_pb(r.stopping_parameters, tokenizer)
+            stopping_criteria = StoppingCriteria.from_pb(
+                r.stopping_parameters, tokenizer
+            )
+            stopping_criterias.append(stopping_criteria)
+            max_sequence_length = max(max_sequence_length, r.input_length)
+            padding_right_offset = max(
+                padding_right_offset, stopping_criteria.max_new_tokens
            )

-        pad_to_multiple_of = 8 if device.type == "cuda" else None
         tokenized_inputs = tokenizer(
             inputs,
             return_tensors="pt",
             padding=True,
-            pad_to_multiple_of=pad_to_multiple_of,
             return_token_type_ids=False,
         ).to(device)
+
+        input_ids = tokenized_inputs["input_ids"]
+        # Allocate maximum attention_mask
+        attention_mask = input_ids.new_zeros(
+            (pb.size, max_sequence_length + padding_right_offset)
+        )
+        # Copy tokenizer attention_mask into fully allocated attention_mask
+        attention_mask[:, :max_sequence_length] = tokenized_inputs["attention_mask"]
+
         position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
         position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
         all_input_ids = tokenized_inputs["input_ids"].unsqueeze(-1)

         return cls(
             batch_id=pb.id,
             requests=pb.requests,
-            input_ids=tokenized_inputs["input_ids"],
-            attention_mask=tokenized_inputs["attention_mask"],
+            input_ids=input_ids,
+            attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=None,
             all_input_ids=all_input_ids,
             input_lengths=input_lengths,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             size=pb.size,
-            max_sequence_length=max(input_lengths),
+            max_sequence_length=max_sequence_length,
+            padding_right_offset=padding_right_offset,
         )

     @classmethod
     @tracer.start_as_current_span("concatenate")
     def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
         # Used for padding
-        total_batch_size = sum(batch.size for batch in batches)
-        max_sequence_length = max(batch.max_sequence_length for batch in batches)
+        total_batch_size = 0
+        max_sequence_length = 0
+        padding_right_offset = 0
+        for batch in batches:
+            total_batch_size += batch.size
+            max_sequence_length = max(max_sequence_length, batch.max_sequence_length)
+            padding_right_offset = max(padding_right_offset, batch.padding_right_offset)

         # Batch attributes
         requests = []
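
One detail worth pulling out of the hunk above: the allocated width is governed by the longest prompt and the single largest max_new_tokens across all requests, so a batch mixing short and long generation budgets is sized for the worst case. A minimal sketch, with plain dicts standing in for the protobuf requests and stopping criteria (illustrative values only):

# Plain dicts stand in for pb.requests / StoppingCriteria in this sketch.
requests = [
    {"input_length": 5, "max_new_tokens": 20},
    {"input_length": 3, "max_new_tokens": 64},
]

max_sequence_length = 0
padding_right_offset = 0
for r in requests:
    max_sequence_length = max(max_sequence_length, r["input_length"])
    padding_right_offset = max(padding_right_offset, r["max_new_tokens"])

# The whole batch shares one mask width: longest prompt + largest budget.
print(max_sequence_length + padding_right_offset)  # 69
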
@@ -144,13 +165,22 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
             # Create padded tensor
             if attention_mask is None:
                 attention_mask = batch.attention_mask.new_zeros(
-                    (total_batch_size, max_sequence_length),
+                    (total_batch_size, max_sequence_length + padding_right_offset),
                 )

             # We need to slice the attention mask to remove padding from previous steps
+            # and to remove unused allocated space
+            left_offset = max_sequence_length - batch.max_sequence_length
+            batch_left_offset = (
+                batch.attention_mask.shape[1] - batch.max_sequence_length - batch.padding_right_offset
+            )
             attention_mask[
-                start_index:end_index, -batch.max_sequence_length :
-            ] = batch.attention_mask[:, -batch.max_sequence_length :]
+                start_index:end_index,
+                left_offset:-padding_right_offset,
+            ] = batch.attention_mask[
+                :,
+                batch_left_offset : -batch.padding_right_offset,
+            ]

             # Create empty tensor
             # position_ids is always of shape [batch_size, 1]
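
The offset arithmetic in this hunk is easiest to follow with concrete numbers. A minimal sketch (all sizes invented for illustration): the destination mask is allocated at the merged width, and each source batch's live window is copied so that its prompt region ends where the merged prompt region ends.

import torch

# Merged batch: widest prompt is 6 tokens, largest remaining budget is 4.
max_sequence_length = 6
padding_right_offset = 4
attention_mask = torch.zeros(
    (3, max_sequence_length + padding_right_offset), dtype=torch.int64
)

# First source batch: one request, prompt region of 4 columns, 3 budget
# columns left, so its allocation is 4 + 3 = 7 columns wide.
batch_max_sequence_length = 4
batch_padding_right_offset = 3
batch_attention_mask = torch.zeros((1, 7), dtype=torch.int64)
batch_attention_mask[0, :4] = 1
start_index, end_index = 0, 1  # rows this source batch occupies

# Pad on the left by the difference in prompt widths, and drop any unused
# allocated columns on the source side.
left_offset = max_sequence_length - batch_max_sequence_length  # 2
batch_left_offset = (
    batch_attention_mask.shape[1]
    - batch_max_sequence_length
    - batch_padding_right_offset
)  # 0 here; non-zero when the source allocation is wider than it needs

attention_mask[
    start_index:end_index, left_offset:-padding_right_offset
] = batch_attention_mask[:, batch_left_offset:-batch_padding_right_offset]

print(attention_mask[0])  # tensor([0, 0, 1, 1, 1, 1, 0, 0, 0, 0])
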
@@ -228,6 +258,7 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
             stopping_criterias=stopping_criterias,
             size=total_batch_size,
             max_sequence_length=max_sequence_length,
+            padding_right_offset=padding_right_offset,
             keys_head_dim_last=batches[0].keys_head_dim_last,
         )

@@ -294,9 +325,12 @@ def forward(
     def generate_token(
         self, batch: CausalLMBatch
     ) -> Tuple[List[Generation], Optional[CausalLMBatch]]:
+        # slice the attention mask to the correct shape
+        attention_mask = batch.attention_mask[:, : -batch.padding_right_offset]
+
         logits, past = self.forward(
             batch.input_ids,
-            batch.attention_mask,
+            attention_mask,
             batch.position_ids,
             batch.past_key_values,
         )
@@ -448,14 +482,8 @@ def generate_token(
         next_batch_next_token_choosers = batch.next_token_choosers
         next_batch_stopping_criterias = batch.stopping_criterias

-        # Update attention_mask with padding as we added a new token to input_ids
-        next_batch_attention_mask = torch.cat(
-            [
-                next_batch_attention_mask,
-                next_batch_attention_mask.new_ones(next_batch_size, 1),
-            ],
-            dim=1,
-        )
+        # Update attention_mask as we added a new token to input_ids
+        next_batch_attention_mask[:, -batch.padding_right_offset] = 1

         # Update position_ids
         next_batch_position_ids = next_batch_position_ids[:, -1:] + 1
@@ -473,6 +501,7 @@ def generate_token(
             stopping_criterias=next_batch_stopping_criterias,
             size=next_batch_size,
             max_sequence_length=next_batch_max_sequence_length,
+            padding_right_offset=batch.padding_right_offset - 1,
             keys_head_dim_last=batch.keys_head_dim_last,
         )
         return generations, next_batch
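
Taken together, the generate_token changes replace the per-step torch.cat with a single in-place write plus a bookkeeping decrement on padding_right_offset. A compressed sketch of the per-step mask handling (the forward pass and token selection are elided; sizes are illustrative):

import torch

# State right after from_pb: a 2-token prompt and a budget of 3 new tokens.
attention_mask = torch.tensor([[1, 1, 0, 0, 0]])
padding_right_offset = 3

for step in range(3):
    # The model only sees the populated prefix of the pre-allocated mask.
    visible = attention_mask[:, :-padding_right_offset]
    # ... forward pass and next-token selection would happen here ...

    # A new token was produced: mark its column in place, shrink the offset.
    attention_mask[:, -padding_right_offset] = 1
    padding_right_offset -= 1

print(attention_mask)        # tensor([[1, 1, 1, 1, 1]])
print(padding_right_offset)  # 0

The negative slice relies on the offset staying positive, which the server ensures by not carrying a batch over to the next step once every request has hit its stopping criterion.
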

server/text_generation/models/galactica.py

Lines changed: 0 additions & 2 deletions

@@ -106,12 +106,10 @@ def from_pb(
            )

        # Tokenize batch
-        pad_to_multiple_of = 8 if device.type == "cuda" else None
        tokenized_inputs = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
-            pad_to_multiple_of=pad_to_multiple_of,
            return_token_type_ids=False,
        ).to(device)
        position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
