
fix: update all models forwards to include adapter_data
drbh committed Jun 7, 2024
1 parent 460ee55 commit 5b54835
Showing 9 changed files with 9 additions and 0 deletions.
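
Each of the nine files receives the same one-line change: an optional adapter_data parameter appended to the model's forward signature (the per-file paths are not shown in this view). The point is that a single shared caller can then pass adapter data to every model uniformly. Below is a minimal sketch of that pattern; FlashModelStub and run_model are illustrative names, not code from this repository:

from typing import Optional, Tuple

import torch
import torch.nn as nn


class FlashModelStub(nn.Module):
    """Stands in for any of the nine patched model classes."""

    def forward(
        self,
        input_ids: torch.Tensor,
        lm_head_indices: Optional[torch.Tensor] = None,
        # Newly accepted keyword; models without adapter support ignore it.
        adapter_data: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        hidden_states = input_ids.float().unsqueeze(-1)  # placeholder compute
        return hidden_states, None


def run_model(model: nn.Module, input_ids: torch.Tensor, adapter_data=None):
    # The shared call site passes adapter_data unconditionally, so every
    # forward signature must accept it or this call raises a TypeError.
    return model(input_ids, adapter_data=adapter_data)


logits, _ = run_model(FlashModelStub(), torch.tensor([1, 2, 3]))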

@@ -514,6 +514,7 @@ def forward(
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         hidden_states = self.model(
             input_ids,

@@ -841,6 +841,7 @@ def forward(
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         hidden_states = self.model(
             input_ids,

@@ -460,6 +460,7 @@ def forward(
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         input_embeds = self.embed_tokens(input_ids)
         hidden_states = self.model(

@@ -440,6 +440,7 @@ def forward(
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         token_embeds = self.embed_tokens(input_ids)
         position_embeds = self.embed_positions(position_ids)

@@ -638,6 +638,7 @@ def forward(
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         true_max_s = max_s
         if prefill_cache_indices is not None:

@@ -74,6 +74,7 @@ def forward(
         # Unused here
         pixel_attention_mask: Optional[torch.BoolTensor] = None,
         image_sizes: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         inputs_embeds = self.text_model.embed_tokens(input_ids)
         # TODO This is odd but apparently pali gemma position ids start at 1.

@@ -378,6 +378,7 @@ def forward(
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         true_max_s = max_s
         if prefill_cache_indices is not None:

@@ -741,6 +741,7 @@ def forward(
         pixel_attention_mask: Optional[torch.BoolTensor] = None,
         # Unused here
         image_sizes: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
     ):
         inputs_embeds = self.text_model.embed_tokens(input_ids)
         if pixel_values is not None:

@@ -178,6 +178,7 @@ def forward(
         # Unused for this model
         pixel_attention_mask=None,
         image_sizes: Optional[torch.LongTensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
     ):
         inputs_embeds = self.language_model.embed_tokens(input_ids)
         if pixel_values is not None and len(pixel_values) > 0:
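
All nine additions default to None, so existing callers keep working; models with no use for adapters (the vision-language wrappers above, for example) simply accept and ignore the argument, while adapter-aware models can thread it down to their layers. A rough sketch of one plausible consumer, assuming a LoRA-style low-rank update (LoRALinear and the scaling semantics here are hypothetical, not taken from this commit):

from typing import Optional

import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Base linear layer plus an optional low-rank adapter update."""

    def __init__(self, in_features: int, out_features: int, rank: int = 8):
        super().__init__()
        self.base = nn.Linear(in_features, out_features, bias=False)
        self.lora_a = nn.Linear(in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_features, bias=False)

    def forward(
        self, x: torch.Tensor, adapter_data: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        out = self.base(x)
        if adapter_data is not None:
            # Here adapter_data is just a broadcastable scale on the
            # low-rank update; real serving code carries richer
            # per-request adapter state than a single tensor.
            out = out + self.lora_b(self.lora_a(x)) * adapter_data
        return out


layer = LoRALinear(16, 16)
y = layer(torch.randn(2, 16), adapter_data=torch.tensor(1.0))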
