
Commit 937c408

Update onnx.py
Fixed the swapped token_type_ids and attention_mask arguments and made use of keyword arguments so the code is less error prone if the argument order ever changes.
1 parent 4ebee43 commit 937c408

File tree

1 file changed: +3 −3 lines


src/setfit/exporters/onnx.py

Lines changed: 3 additions & 3 deletions
@@ -46,10 +46,10 @@ def __init__(
         self.pooler = pooler
         self.model_head = model_head

-    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor):
-        hidden_states = self.model_body(input_ids, attention_mask, token_type_ids)
+    def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor, attention_mask: torch.Tensor):
+        hidden_states = self.model_body(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
         hidden_states = {"token_embeddings": hidden_states[0], "attention_mask": attention_mask}
-
+
         embeddings = self.pooler(hidden_states)

         # If the model_head is none we are using a sklearn head and only output
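For readers skimming the diff, here is a minimal sketch of how the corrected forward pass sits inside the ONNX wrapper module. The forward signature, the model_body/pooler/model_head attributes, and the keyword-argument call are taken from the diff above; the class name, the constructor, and the final head handling are assumptions made to keep the sketch self-contained, not the exporter's exact code.

# Minimal sketch, assuming a wrapper class named OnnxSetFitModel; the constructor and
# the head handling at the end are illustrative, not the exporter's exact code.
import torch


class OnnxSetFitModel(torch.nn.Module):
    def __init__(self, model_body, pooler, model_head=None):
        super().__init__()
        self.model_body = model_body
        self.pooler = pooler
        self.model_head = model_head

    def forward(self, input_ids: torch.Tensor, token_type_ids: torch.Tensor, attention_mask: torch.Tensor):
        # Keyword arguments keep the call correct even if the body model's
        # positional parameter order differs (the bug this commit fixes).
        hidden_states = self.model_body(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        hidden_states = {"token_embeddings": hidden_states[0], "attention_mask": attention_mask}
        embeddings = self.pooler(hidden_states)
        # Assumed continuation of the truncated comment in the diff: with no model_head
        # (a sklearn head used outside the graph), output only the embeddings.
        if self.model_head is None:
            return embeddings
        return self.model_head(embeddings)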
