|
30 | 30 |
|
31 | 31 | from ... import initialization as init |
32 | 32 | from ...activations import ACT2FN |
33 | | -from ...cache_utils import Cache, DynamicCache |
| 33 | +from ...cache_utils import Cache |
34 | 34 | from ...generation import GenerationMixin |
35 | 35 | from ...image_processing_utils import select_best_resolution |
36 | 36 | from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func |
|
42 | 42 | from ...processing_utils import Unpack |
43 | 43 | from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check |
44 | 44 | from ...utils.generic import maybe_autocast, merge_with_config_defaults |
| 45 | +from ...utils.output_capturing import capture_outputs |
45 | 46 | from ..auto import AutoModel |
46 | 47 | from .configuration_granite4_vision import Granite4VisionConfig, Granite4VisionTextConfig |
47 | 48 |
|
@@ -80,7 +81,7 @@ class Granite4VisionCausalLMOutputWithPast(ModelOutput): |
80 | 81 |
|
81 | 82 |
|
82 | 83 | @dataclass |
83 | | -class Granite4VisionImageFeaturesOutput(ModelOutput): |
| 84 | +class Granite4VisionImageFeaturesOutput(BaseModelOutputWithPooling): |
84 | 85 | """ |
85 | 86 | Output of `Granite4VisionModel.get_image_features`. |
86 | 87 |
|
@@ -590,6 +591,7 @@ def __init__(self, config: Granite4VisionTextConfig): |
590 | 591 | # Initialize weights and apply final processing |
591 | 592 | self.post_init() |
592 | 593 |
|
| 594 | + @capture_outputs |
593 | 595 | @auto_docstring |
594 | 596 | def forward( |
595 | 597 | self, |
@@ -618,9 +620,6 @@ def forward( |
618 | 620 |
|
619 | 621 | inputs_embeds = inputs_embeds * self.embedding_multiplier |
620 | 622 |
|
621 | | - if use_cache and past_key_values is None: |
622 | | - past_key_values = DynamicCache(config=self.config) |
623 | | - |
624 | 623 | if position_ids is None: |
625 | 624 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 |
626 | 625 | position_ids = ( |
@@ -916,6 +915,9 @@ def get_image_features( |
916 | 915 | elif pixel_values.dim() != 4: |
917 | 916 | raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions") |
918 | 917 |
|
| 918 | + output_hidden_states = kwargs.pop("output_hidden_states", None) |
| 919 | + if output_hidden_states is None: |
| 920 | + output_hidden_states = getattr(self.config, "output_hidden_states", False) |
919 | 921 | vision_outputs = self.vision_tower(pixel_values, output_hidden_states=True, **kwargs) |
920 | 922 |
|
921 | 923 | # Deepstack features: extract from multiple vision layers, downsample via interpolation |
@@ -958,7 +960,10 @@ def get_image_features( |
958 | 960 |
|
959 | 961 | all_features.append((llm_layer, packed_group)) |
960 | 962 |
|
961 | | - return Granite4VisionImageFeaturesOutput(deepstack_features=all_features) |
| 963 | + return Granite4VisionImageFeaturesOutput( |
| 964 | + deepstack_features=all_features, |
| 965 | + hidden_states=vision_outputs.hidden_states if output_hidden_states else None, |
| 966 | + ) |
962 | 967 |
|
963 | 968 | def get_placeholder_mask( |
964 | 969 | self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor |
|
0 commit comments