
feat: support if vlm models
drbh committed Jun 7, 2024
1 parent 7c67d80 commit 1ba461d
Showing 5 changed files with 31 additions and 6 deletions.
2 changes: 2 additions & 0 deletions server/text_generation_server/layers/lora.py
@@ -41,6 +41,8 @@ def forward_layer_type(
         start_idx: int,
         end_idx: int,
     ) -> torch.Tensor:
+        if adapter_data is None:
+            return result
         data = adapter_data.data.get(layer_type)
         data: Optional["BatchLoraWeights"] = (
             data.get(LORA) if data is not None else None
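For context, this early return makes the LoRA projection a pass-through whenever no adapter batch is attached, which is exactly what happens on a VLM's vision path. A minimal sketch of the pattern under assumed names (LoraLinear and the dict-shaped adapter_data are illustrative, not TGI's actual classes):

from typing import Optional

import torch


class LoraLinear(torch.nn.Module):
    # Illustrative only: applies low-rank updates when adapter data exists,
    # otherwise returns the base projection untouched.
    def __init__(self, base: torch.nn.Linear):
        super().__init__()
        self.base = base

    def forward(self, x: torch.Tensor, adapter_data: Optional[dict] = None) -> torch.Tensor:
        result = self.base(x)
        if adapter_data is None:
            # Mirrors the guard in the diff: no adapter batch, no LoRA math.
            return result
        for lora_a, lora_b, scale in adapter_data.get("weights", []):
            result = result + scale * (x @ lora_a @ lora_b)
        return result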
16 changes: 13 additions & 3 deletions server/text_generation_server/models/flash_llama.py
@@ -108,7 +108,17 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
         layer_weights = {}
 
         prefix = "model.layers"
-        for i, layer in enumerate(self.model.model.layers):
+
+        # This accounts for VLMs (e.g. LlavaNext, Idefics2)
+        # that have a language_model inside of the larger model.
+        if hasattr(self.model, "language_model"):
+            _model = self.model.language_model
+        elif hasattr(self.model, "text_model"):
+            _model = self.model.text_model
+        else:
+            _model = self.model
+
+        for i, layer in enumerate(_model.model.layers):
             layer_weights[(i, "q_proj")] = (
                 f"{prefix}.{i}.self_attn.q_proj",
                 layer.self_attn.query_key_value,
@@ -139,7 +149,7 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
                 layer.mlp.down_proj,
             )
 
-        layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head)
+        layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
         return layer_weights
 
     @property
@@ -151,7 +161,7 @@ def default_traced_adapter_layers(self) -> List[str]:
         return ["q_proj", "v_proj"]
 
     def get_num_layers_for_type(self, layer_type: str) -> int:
-        return 1 if layer_type == LM_HEAD else len(self.model.model.layers)
+        return 1 if layer_type == "lm_head" else len(self.model.model.layers)
 
     def is_row_parallel(self, layer_type: str) -> bool:
         return layer_type in ROW_PARALLEL
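The hasattr chain above is duplicated verbatim in flash_mistral.py below; a hedged sketch of how the same fallback could be factored into one helper (resolve_text_backbone is a hypothetical name, not part of TGI):

def resolve_text_backbone(model):
    # VLM wrappers nest the decoder: LlavaNext/Idefics2-style models expose
    # language_model, CLIP-style wrappers expose text_model; plain causal
    # LMs are used as-is.
    for attr in ("language_model", "text_model"):
        if hasattr(model, attr):
            return getattr(model, attr)
    return model

# Usage inside adapter_target_to_layer (sketch):
# _model = resolve_text_backbone(self.model)
# for i, layer in enumerate(_model.model.layers): ...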
14 changes: 12 additions & 2 deletions server/text_generation_server/models/flash_mistral.py
@@ -119,7 +119,17 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
         layer_weights = {}
 
         prefix = "model.layers"
-        for i, layer in enumerate(self.model.model.layers):
+
+        # This accounts for VLMs (e.g. LlavaNext, Idefics2)
+        # that have a language_model inside of the larger model.
+        if hasattr(self.model, "language_model"):
+            _model = self.model.language_model
+        elif hasattr(self.model, "text_model"):
+            _model = self.model.text_model
+        else:
+            _model = self.model
+
+        for i, layer in enumerate(_model.model.layers):
             layer_weights[(i, "q_proj")] = (
                 f"{prefix}.{i}.self_attn.q_proj",
                 layer.self_attn.query_key_value,
@@ -150,7 +160,7 @@ def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
                 layer.mlp.down_proj,
             )
 
-        layer_weights[(0, "lm_head")] = ("lm_head", self.model.lm_head)
+        layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
         return layer_weights
 
     @property
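Each entry in the returned mapping pairs a (layer index, layer type) key with the adapter weight name and the live module it targets. A hedged sketch of how such a mapping might be consumed when wiring adapter tensors onto modules (attach_adapters and the adapter_weights dict are assumed for illustration, not TGI's actual loader):

def attach_adapters(layer_weights, adapter_weights):
    # layer_weights: {(layer_index, layer_type): (weight_name, module)}
    # adapter_weights: {weight_name: (lora_a, lora_b)} from a checkpoint
    for (index, layer_type), (weight_name, module) in layer_weights.items():
        pair = adapter_weights.get(weight_name)
        if pair is None:
            continue  # this adapter does not target the projection
        module.lora_a, module.lora_b = pair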
1 change: 1 addition & 0 deletions server/text_generation_server/models/mamba.py
@@ -452,6 +452,7 @@ def __init__(
         model = MambaModel(config, weights)
         torch.distributed.barrier(group=self.process_group)
         super(Mamba, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             requires_padding=True,
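Passing model_id up to the base class presumably lets shared, adapter-aware code know which base model it is serving; the keyword implies a base __init__ shaped roughly like the following (an assumption from this call site, not the actual signature):

class Model:
    # Assumed, simplified shape implied by the super().__init__ call above.
    def __init__(self, model_id, model, tokenizer, requires_padding, **kwargs):
        self.model_id = model_id  # now available to adapter-handling code
        self.model = model
        self.tokenizer = tokenizer
        self.requires_padding = requires_padding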
4 changes: 3 additions & 1 deletion server/text_generation_server/models/vlm_causal_lm.py
@@ -257,7 +257,9 @@ def batch_type(self) -> Type[VlmCausalLMBatch]:
         return VlmCausalLMBatch
 
     def forward(
-        self, batch: VlmCausalLMBatch
+        self,
+        batch: VlmCausalLMBatch,
+        adapter_data: Optional[Dict[str, torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # Model Forward
         if batch.speculative_ids is not None:
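Defaulting adapter_data to None keeps existing call sites that pass only the batch working, while the LoRA-aware path can thread adapter batches through; combined with the guard added in layers/lora.py, the chain degrades gracefully. A minimal sketch under assumed names (generate_token here is illustrative, not TGI's real call chain):

def generate_token(model, batch, adapter_data=None):
    # Legacy callers omit adapter_data; the VLM forward accepts None and
    # the LoRA layers then return their base projections unchanged.
    logits, speculative_logits = model.forward(batch, adapter_data)
    return logits, speculative_logits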
