diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 82b5fef..2abe06b 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -27,7 +27,7 @@ on: - '.github/**' env: - PYTHON_VERSION: 3.11 + PYTHON_VERSION: 3.12 jobs: lint: @@ -45,10 +45,10 @@ jobs: fetch-depth: 0 submodules: true - - name: Setup Python 3.11 + - name: Setup Python 3.12 uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # v5.1.0 with: - python-version: 3.11 + python-version: 3.12 cache: pip cache-dependency-path: | **/pyproject.toml diff --git a/src/instructlab/dolomite/hf_models/__init__.py b/src/instructlab/dolomite/hf_models/__init__.py index 86cc443..ef3ffb8 100644 --- a/src/instructlab/dolomite/hf_models/__init__.py +++ b/src/instructlab/dolomite/hf_models/__init__.py @@ -2,9 +2,9 @@ # Extracted from https://github.com/ibm-granite/dolomite-engine # ---------------------------------------------------------------- # Local -from .models.gpt_dolomite.config import GPTDolomiteConfig from .model_conversion import export_to_huggingface, import_from_huggingface from .models import GPTDolomiteForCausalLM, GPTDolomiteModel +from .models.gpt_dolomite.config import GPTDolomiteConfig from .register_hf import register_model_classes register_model_classes() diff --git a/src/instructlab/dolomite/hf_models/config.py b/src/instructlab/dolomite/hf_models/config.py index 538dc34..49a2e9b 100644 --- a/src/instructlab/dolomite/hf_models/config.py +++ b/src/instructlab/dolomite/hf_models/config.py @@ -1,5 +1,7 @@ +# Third Party from transformers import PretrainedConfig +# Local from .enums import AttentionHeadType, InitMethod, PositionEmbeddingType @@ -98,7 +100,9 @@ def __init__( if self.num_key_value_heads is None: self.num_key_value_heads = 1 - assert self.num_key_value_heads == 1, "MultiQueryAttention should have 1 head for keys and values" + assert ( + self.num_key_value_heads == 1 + ), "MultiQueryAttention should have 1 head for keys and values" elif attention_head_type == AttentionHeadType.gqa: assert ( self.num_key_value_heads is not None @@ -108,4 +112,9 @@ def __init__( self.n_head % self.num_key_value_heads == 0 ), "GroupedQueryAttention should have more than 1 head for keys and values" - super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, pad_token_id=pad_token_id, **kwargs) + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + **kwargs, + ) diff --git a/src/instructlab/dolomite/hf_models/enums.py b/src/instructlab/dolomite/hf_models/enums.py index 5055bcf..0ac10fa 100644 --- a/src/instructlab/dolomite/hf_models/enums.py +++ b/src/instructlab/dolomite/hf_models/enums.py @@ -1,3 +1,4 @@ +# Standard from enum import Enum diff --git a/src/instructlab/dolomite/hf_models/mixins/__init__.py b/src/instructlab/dolomite/hf_models/mixins/__init__.py index c4f9102..e899ea2 100644 --- a/src/instructlab/dolomite/hf_models/mixins/__init__.py +++ b/src/instructlab/dolomite/hf_models/mixins/__init__.py @@ -1,4 +1,7 @@ +# Local from .dense import BaseModelMixin, CausalLMModelMixin, PreTrainedModelMixin -#from .dense_TP import BaseModelMixin_TP, CausalLMModelMixin_TP, PreTrainedModelMixin_TP + +# from .dense_TP import BaseModelMixin_TP, CausalLMModelMixin_TP, PreTrainedModelMixin_TP from .moe import BaseMoEModelMixin, CausalLMMoEModelMixin, PreTrainedMoEModelMixin -#from .moe_TP import BaseMoEModelMixin_TP, CausalLMMoEModelMixin_TP, PreTrainedMoEModelMixin_TP + +# from .moe_TP import BaseMoEModelMixin_TP, 
CausalLMMoEModelMixin_TP, PreTrainedMoEModelMixin_TP diff --git a/src/instructlab/dolomite/hf_models/mixins/dense/__init__.py b/src/instructlab/dolomite/hf_models/mixins/dense/__init__.py index 0ee5d10..b29b99f 100644 --- a/src/instructlab/dolomite/hf_models/mixins/dense/__init__.py +++ b/src/instructlab/dolomite/hf_models/mixins/dense/__init__.py @@ -1,2 +1,3 @@ +# Local from .base import BaseModelMixin, PreTrainedModelMixin from .main import CausalLMModelMixin diff --git a/src/instructlab/dolomite/hf_models/mixins/dense/base.py b/src/instructlab/dolomite/hf_models/mixins/dense/base.py index 3298682..e133727 100644 --- a/src/instructlab/dolomite/hf_models/mixins/dense/base.py +++ b/src/instructlab/dolomite/hf_models/mixins/dense/base.py @@ -1,14 +1,23 @@ +# Standard import warnings -import torch -import torch.nn as nn +# Third Party from transformers import DynamicCache, PreTrainedModel from transformers.modeling_outputs import BaseModelOutputWithPast +import torch +import torch.nn as nn +# Local from ...config import CommonConfig from ...defaults import DEFAULT_NORMALIZATION_IMPLEMENTATION from ...enums import AttentionHeadType, PositionEmbeddingType -from ...modeling_utils import Alibi, ParameterizedEmbedding, RoPE, YaRNScaledRoPE, get_normalization_function +from ...modeling_utils import ( + Alibi, + ParameterizedEmbedding, + RoPE, + YaRNScaledRoPE, + get_normalization_function, +) from ...utils import convert_padding_free_lists_to_tensors, divide_if_divisible @@ -39,13 +48,19 @@ def __init__(self, config: CommonConfig, *args, **kwargs) -> None: self.attention_implementation = self.config._attn_implementation self._use_eager_attention = self.attention_implementation == "eager" self._use_sdpa = self.attention_implementation == "sdpa" - self._use_flash_attention_2 = self.attention_implementation == "flash_attention_2" - self._use_padding_free_transformer = kwargs.get("use_padding_free_transformer", False) + self._use_flash_attention_2 = ( + self.attention_implementation == "flash_attention_2" + ) + self._use_padding_free_transformer = kwargs.get( + "use_padding_free_transformer", False + ) self._tied_word_embeddings = config.tie_word_embeddings if self._use_padding_free_transformer: - assert self._use_flash_attention_2, "padding free transformer only works with flash attention" + assert ( + self._use_flash_attention_2 + ), "padding free transformer only works with flash attention" def _init_weights(self, module: nn.Module) -> None: if hasattr(module, "reset_parameters"): @@ -74,28 +89,43 @@ def prepare_inputs_for_model( ) assert cu_seqlens is None, error_message.format(variable="cu_seqlens") assert max_seqlen is None, error_message.format(variable="max_seqlen") - assert attention_mask is None, error_message.format(variable="attention_mask") - - input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen = ( - convert_padding_free_lists_to_tensors( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - position_ids=position_ids, - token_type_ids=token_type_ids, - labels=labels, - device=torch.cuda.current_device(), - ) + assert attention_mask is None, error_message.format( + variable="attention_mask" + ) + + ( + input_ids, + position_ids, + token_type_ids, + labels, + cu_seqlens, + max_seqlen, + ) = convert_padding_free_lists_to_tensors( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + token_type_ids=token_type_ids, + labels=labels, + device=torch.cuda.current_device(), ) else: assert ( cu_seqlens is not None ), "cu_seqlens needs to be 
specified when using tensor inputs with padding_free transformer" - assert position_ids is not None, "max_seqlen needs to be specified when specifying cu_seqlens" - assert max_seqlen is not None, "max_seqlen needs to be specified when specifying cu_seqlens" - assert attention_mask is None, "attention_mask should not be passed when specifying cu_seqlens" + assert ( + position_ids is not None + ), "max_seqlen needs to be specified when specifying cu_seqlens" + assert ( + max_seqlen is not None + ), "max_seqlen needs to be specified when specifying cu_seqlens" + assert ( + attention_mask is None + ), "attention_mask should not be passed when specifying cu_seqlens" if use_cache or past_key_values is not None: - raise NotImplementedError("KV caching is not supported with padding_free transformer") + raise NotImplementedError( + "KV caching is not supported with padding_free transformer" + ) assert not output_attentions @@ -128,9 +158,13 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: f"`embed_dim` ({self.embed_dim}) must be divisible by `num_heads` ({self.num_heads})", ) - self.wte = ParameterizedEmbedding(config.vocab_size, self.embed_dim, std=self.initializer_range) + self.wte = ParameterizedEmbedding( + config.vocab_size, self.embed_dim, std=self.initializer_range + ) - self.drop = nn.Identity() if config.embd_pdrop == 0 else nn.Dropout(config.embd_pdrop) + self.drop = ( + nn.Identity() if config.embd_pdrop == 0 else nn.Dropout(config.embd_pdrop) + ) self.h = nn.ModuleList( [ self.layer_class( @@ -150,7 +184,9 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: normalization_implementation=self.normalization_implementation, ) - self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type) + self.position_embedding_type = PositionEmbeddingType( + config.position_embedding_type + ) self._setup_positional_encoding() # Initialize weights and apply final processing @@ -206,7 +242,9 @@ def forward( # attention_mask -> (batch_size, 1, query_length, key_length) # ========================================================================================== - past_key_values = DynamicCache() if use_cache and past_key_values is None else past_key_values + past_key_values = ( + DynamicCache() if use_cache and past_key_values is None else past_key_values + ) all_hidden_states = () if output_hidden_states else None for block in self.h: if output_hidden_states: @@ -234,7 +272,12 @@ def forward( ) def _get_position_ids( - self, attention_mask: torch.Tensor, past_length: int, query_length: int, key_length: int, device: torch.device + self, + attention_mask: torch.Tensor, + past_length: int, + query_length: int, + key_length: int, + device: torch.device, ) -> torch.Tensor: if attention_mask is not None and len(attention_mask.shape) == 2: # create position_ids on the fly for batch generation @@ -243,7 +286,9 @@ def _get_position_ids( if past_length > 0: position_ids = position_ids[:, past_length:key_length:] else: - position_ids = torch.arange(past_length, key_length, dtype=torch.long, device=device) + position_ids = torch.arange( + past_length, key_length, dtype=torch.long, device=device + ) position_ids = position_ids.unsqueeze(0).view(-1, query_length) return position_ids @@ -277,7 +322,11 @@ def _get_alibi_bias( return alibi_bias def _get_rope_cos_sin( - self, key_length: int, position_ids: torch.Tensor, dtype: torch.dtype, device: torch.device + self, + key_length: int, + position_ids: torch.Tensor, + dtype: torch.dtype, + device: torch.device, ) -> 
torch.Tensor: if self.position_embedding_type == PositionEmbeddingType.rope: cos, sin = self.rope(key_length, dtype=dtype, device=device) @@ -301,7 +350,9 @@ def _prepare_causal_attention_mask( if query_length > 1: # (query_length, key_length) - causal_mask = torch.empty((query_length, key_length), dtype=torch.bool, device=device) + causal_mask = torch.empty( + (query_length, key_length), dtype=torch.bool, device=device + ) causal_mask[:, past_length:] = torch.tril( torch.ones(query_length, query_length, dtype=torch.bool, device=device) ) @@ -321,10 +372,18 @@ def _prepare_causal_attention_mask( else: if attention_mask is None: # (batch_size, query_length, key_length) - causal_mask = torch.ones(batch_size, query_length, key_length, dtype=torch.bool, device=device) + causal_mask = torch.ones( + batch_size, + query_length, + key_length, + dtype=torch.bool, + device=device, + ) else: # (batch_size, query_length, key_length) - causal_mask = attention_mask.unsqueeze(1).to(dtype=torch.bool, device=device) + causal_mask = attention_mask.unsqueeze(1).to( + dtype=torch.bool, device=device + ) # ========================================================================================== # attention_mask -> (batch_size, query_length, key_length) @@ -387,14 +446,20 @@ def _prepare_a_bunch_of_stuff( tuple[torch.Tensor], ]: output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states ) if use_cache is None: - use_cache = False if self._use_padding_free_transformer else self.config.use_cache + use_cache = ( + False if self._use_padding_free_transformer else self.config.use_cache + ) if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) elif input_ids is not None: input_shape = input_ids.size() @@ -425,7 +490,10 @@ def _prepare_a_bunch_of_stuff( else: if self.position_embedding_type == PositionEmbeddingType.alibi: if position_ids is not None: - warnings.warn("`position_ids` have no functionality with Alibi.", FutureWarning) + warnings.warn( + "`position_ids` have no functionality with Alibi.", + FutureWarning, + ) if token_type_ids is not None: token_type_ids = token_type_ids.view(-1, input_shape[-1]) @@ -447,12 +515,16 @@ def _prepare_a_bunch_of_stuff( if self._use_padding_free_transformer: key_length = max_seqlen.item() else: - past_length = 0 if past_key_values is None else past_key_values.get_seq_length() + past_length = ( + 0 if past_key_values is None else past_key_values.get_seq_length() + ) query_length = input_shape[-1] key_length = past_length + query_length if position_ids is None: - position_ids = self._get_position_ids(attention_mask, past_length, query_length, key_length, device) + position_ids = self._get_position_ids( + attention_mask, past_length, query_length, key_length, device + ) # ========================================================================================== # padding_free: @@ -465,7 +537,9 @@ def _prepare_a_bunch_of_stuff( # position_ids -> (batch_size, query_length) # ========================================================================================== - hidden_states = self._get_initial_hidden_state(input_ids, inputs_embeds, position_ids, token_type_ids) + hidden_states = self._get_initial_hidden_state( + input_ids, 
inputs_embeds, position_ids, token_type_ids + ) # ========================================================================================== # padding_free: @@ -475,7 +549,12 @@ def _prepare_a_bunch_of_stuff( # ========================================================================================== alibi_bias = self._get_alibi_bias( - attention_mask, batch_size, query_length, key_length, device, hidden_states.dtype + attention_mask, + batch_size, + query_length, + key_length, + device, + hidden_states.dtype, ) # ========================================================================================== @@ -483,7 +562,10 @@ def _prepare_a_bunch_of_stuff( # ========================================================================================== rope_cos_sin = self._get_rope_cos_sin( - key_length, position_ids, dtype=hidden_states.dtype, device=hidden_states.device + key_length, + position_ids, + dtype=hidden_states.dtype, + device=hidden_states.device, ) # ========================================================================================== @@ -494,7 +576,13 @@ def _prepare_a_bunch_of_stuff( # ========================================================================================== attention_mask = self._get_maybe_causal_mask( - attention_mask, alibi_bias, batch_size, query_length, key_length, hidden_states.dtype, device + attention_mask, + alibi_bias, + batch_size, + query_length, + key_length, + hidden_states.dtype, + device, ) return ( @@ -511,9 +599,13 @@ def _setup_positional_encoding(self) -> None: max_position_embeddings = self.config.max_position_embeddings if self.position_embedding_type == PositionEmbeddingType.learned_absolute: - self.wpe = ParameterizedEmbedding(max_position_embeddings, self.embed_dim, std=self.initializer_range) + self.wpe = ParameterizedEmbedding( + max_position_embeddings, self.embed_dim, std=self.initializer_range + ) elif self.position_embedding_type == PositionEmbeddingType.alibi: - assert not self._use_flash_attention_2, "alibi is not implemented with FlashAttention" + assert ( + not self._use_flash_attention_2 + ), "alibi is not implemented with FlashAttention" self.alibi = Alibi(self.num_heads) elif self.position_embedding_type == PositionEmbeddingType.rope: @@ -529,7 +621,9 @@ def _setup_positional_encoding(self) -> None: max_position_embeddings=max_position_embeddings, base=self.config.rope_theta, scale=self.config.rope_scaling["factor"], - original_max_position_embeddings=self.config.rope_scaling["original_max_position_embeddings"], + original_max_position_embeddings=self.config.rope_scaling[ + "original_max_position_embeddings" + ], ) elif self.position_embedding_type == PositionEmbeddingType.nope: pass @@ -538,8 +632,14 @@ def _setup_positional_encoding(self) -> None: def _get_mask_value(self, device: torch.device, dtype: torch.dtype) -> torch.Tensor: # torch.where expects a tensor. We use a cache to avoid recreating it every time. 
- if self.mask_value is None or self.mask_value.dtype != dtype or self.mask_value.device != device: - self.mask_value = torch.full([], torch.finfo(dtype).min, dtype=dtype, device=device) + if ( + self.mask_value is None + or self.mask_value.dtype != dtype + or self.mask_value.device != device + ): + self.mask_value = torch.full( + [], torch.finfo(dtype).min, dtype=dtype, device=device + ) return self.mask_value def _get_maybe_causal_mask( @@ -568,7 +668,10 @@ def _get_maybe_causal_mask( # this is needed to prevent NaN since SDPA # see issue: https://github.com/pytorch/pytorch/issues/110213 attention_mask = attention_mask * ~torch.all( - attention_mask == self._get_mask_value(attention_mask.device, dtype), dim=-1, keepdim=True + attention_mask + == self._get_mask_value(attention_mask.device, dtype), + dim=-1, + keepdim=True, ) elif self._use_eager_attention: attention_mask = self._prepare_causal_attention_mask( diff --git a/src/instructlab/dolomite/hf_models/mixins/dense/main.py b/src/instructlab/dolomite/hf_models/mixins/dense/main.py index b03b9ed..603ca53 100644 --- a/src/instructlab/dolomite/hf_models/mixins/dense/main.py +++ b/src/instructlab/dolomite/hf_models/mixins/dense/main.py @@ -1,8 +1,13 @@ +# Third Party +from transformers import DynamicCache +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) import torch import torch.nn.functional as F -from transformers import DynamicCache -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +# Local from ...config import CommonConfig from ...modeling_utils import ParameterizedEmbedding, ParameterizedLinear from .base import PreTrainedModelMixin @@ -21,7 +26,10 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: if not self._tied_word_embeddings: self.lm_head = ParameterizedLinear( - config.n_embd, config.vocab_size, bias=False, std=config.initializer_range + config.n_embd, + config.vocab_size, + bias=False, + std=config.initializer_range, ) self.m_width = config.m_width @@ -112,18 +120,20 @@ def forward( cu_seqlens: torch.Tensor | None = None, max_seqlen: torch.Tensor | None = None, ) -> tuple | CausalLMOutputWithPast: - input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen = self.prepare_inputs_for_model( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - position_ids=position_ids, - token_type_ids=token_type_ids, - labels=labels, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - past_key_values=past_key_values, - attention_mask=attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, + input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen = ( + self.prepare_inputs_for_model( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + token_type_ids=token_type_ids, + labels=labels, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + past_key_values=past_key_values, + attention_mask=attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) ) # ========================================================================================== @@ -155,7 +165,9 @@ def forward( if self.m_width is not None: lm_logits = lm_logits / self.m_width - loss = self.get_autoregressive_language_modeling_loss(lm_logits, labels, cu_seqlens) + loss = self.get_autoregressive_language_modeling_loss( + lm_logits, labels, cu_seqlens + ) return CausalLMOutputWithPast( loss=loss, @@ -173,7 +185,10 @@ def get_lm_logits(self, hidden_states: torch.Tensor) -> 
torch.Tensor: ) def get_autoregressive_language_modeling_loss( - self, lm_logits: torch.Tensor, labels: torch.Tensor | None, cu_seqlens: torch.Tensor + self, + lm_logits: torch.Tensor, + labels: torch.Tensor | None, + cu_seqlens: torch.Tensor, ) -> torch.Tensor: if labels is None: return None @@ -193,6 +208,8 @@ def get_autoregressive_language_modeling_loss( if self.upcast_logits_for_loss: shift_logits = shift_logits.float() - loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + loss = F.cross_entropy( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) + ) return loss diff --git a/src/instructlab/dolomite/hf_models/mixins/dense_TP/__init__.py b/src/instructlab/dolomite/hf_models/mixins/dense_TP/__init__.py index cbbb640..3adca7c 100644 --- a/src/instructlab/dolomite/hf_models/mixins/dense_TP/__init__.py +++ b/src/instructlab/dolomite/hf_models/mixins/dense_TP/__init__.py @@ -1,2 +1,3 @@ +# Local from .base import BaseModelMixin_TP, PreTrainedModelMixin_TP from .main import CausalLMModelMixin_TP diff --git a/src/instructlab/dolomite/hf_models/mixins/dense_TP/base.py b/src/instructlab/dolomite/hf_models/mixins/dense_TP/base.py index 801bd72..ca41143 100644 --- a/src/instructlab/dolomite/hf_models/mixins/dense_TP/base.py +++ b/src/instructlab/dolomite/hf_models/mixins/dense_TP/base.py @@ -1,16 +1,25 @@ +# Third Party import torch.nn as nn +# Local from ....utils import ProcessGroupManager from ...config import CommonConfig from ...enums import AttentionHeadType, PositionEmbeddingType from ...modeling_utils import RoPE, YaRNScaledRoPE -from ...modeling_utils_TP import Alibi_TP, Dropout_TP, Embedding_TP, get_normalization_function_TP +from ...modeling_utils_TP import ( + Alibi_TP, + Dropout_TP, + Embedding_TP, + get_normalization_function_TP, +) from ..dense import BaseModelMixin, PreTrainedModelMixin class PreTrainedModelMixin_TP(PreTrainedModelMixin): def __init__(self, config: CommonConfig, *args, **kwargs): - self.tensor_parallel_word_embeddings = kwargs.get("tensor_parallel_word_embeddings", False) + self.tensor_parallel_word_embeddings = kwargs.get( + "tensor_parallel_word_embeddings", False + ) self.sequence_parallel = kwargs.get("sequence_parallel", False) super().__init__(config, *args, **kwargs) @@ -67,7 +76,9 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: sequence_parallel=self.sequence_parallel, ) - self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type) + self.position_embedding_type = PositionEmbeddingType( + config.position_embedding_type + ) self._setup_positional_encoding() # Initialize weights and apply final processing @@ -90,7 +101,9 @@ def _setup_positional_encoding(self) -> None: elif self.position_embedding_type == PositionEmbeddingType.rope: if self.config.rope_scaling is None: self.rope = RoPE( - self.head_dim, max_position_embeddings=max_position_embeddings, base=self.config.rope_theta + self.head_dim, + max_position_embeddings=max_position_embeddings, + base=self.config.rope_theta, ) else: self.rope = YaRNScaledRoPE( @@ -98,7 +111,9 @@ def _setup_positional_encoding(self) -> None: max_position_embeddings=max_position_embeddings, base=self.config.rope_theta, scale=self.config.rope_scaling["factor"], - original_max_position_embeddings=self.config.rope_scaling["original_max_position_embeddings"], + original_max_position_embeddings=self.config.rope_scaling[ + "original_max_position_embeddings" + ], ) else: raise NotImplementedError() diff --git 
a/src/instructlab/dolomite/hf_models/mixins/dense_TP/main.py b/src/instructlab/dolomite/hf_models/mixins/dense_TP/main.py index 4505921..cc8d019 100644 --- a/src/instructlab/dolomite/hf_models/mixins/dense_TP/main.py +++ b/src/instructlab/dolomite/hf_models/mixins/dense_TP/main.py @@ -1,14 +1,21 @@ +# Future from __future__ import annotations +# Standard from contextlib import nullcontext -import torch -import torch.nn.functional as F +# Third Party from torch.distributed._tensor.placement_types import Replicate, Shard from torch.distributed.tensor.parallel import loss_parallel from transformers import DynamicCache -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +import torch +import torch.nn.functional as F +# Local from ....utils import ProcessGroupManager, SafeTensorsWeightsManager from ...config import CommonConfig from ...enums import PositionEmbeddingType @@ -58,18 +65,20 @@ def forward( cu_seqlens: torch.Tensor | None = None, max_seqlen: torch.Tensor | None = None, ) -> tuple | CausalLMOutputWithPast: - input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen = self.prepare_inputs_for_model( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - position_ids=position_ids, - token_type_ids=token_type_ids, - labels=labels, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - past_key_values=past_key_values, - attention_mask=attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, + input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen = ( + self.prepare_inputs_for_model( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + token_type_ids=token_type_ids, + labels=labels, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + past_key_values=past_key_values, + attention_mask=attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) ) transformer_outputs: BaseModelOutputWithPast = self.transformer( @@ -90,15 +99,21 @@ def forward( if self.m_width is not None: lm_logits = lm_logits / self.m_width - loss = self.get_autoregressive_language_modeling_loss(lm_logits, labels, cu_seqlens) + loss = self.get_autoregressive_language_modeling_loss( + lm_logits, labels, cu_seqlens + ) if output_parallel_lm_logits: assert self.tensor_parallel_word_embeddings else: if self.tensor_parallel_word_embeddings: # all gather - lm_logits = tensor_to_dtensor(lm_logits, device_mesh=self.tp_mesh, current_placement=Shard(-1)) - lm_logits = dtensor_to_tensor(lm_logits, device_mesh=self.tp_mesh, desired_placement=Replicate()) + lm_logits = tensor_to_dtensor( + lm_logits, device_mesh=self.tp_mesh, current_placement=Shard(-1) + ) + lm_logits = dtensor_to_tensor( + lm_logits, device_mesh=self.tp_mesh, desired_placement=Replicate() + ) return CausalLMOutputWithPast( loss=loss, @@ -123,7 +138,10 @@ def get_lm_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: ) def get_autoregressive_language_modeling_loss( - self, lm_logits: torch.Tensor, labels: torch.Tensor | None, cu_seqlens: torch.Tensor + self, + lm_logits: torch.Tensor, + labels: torch.Tensor | None, + cu_seqlens: torch.Tensor, ) -> torch.Tensor: if labels is None: return None @@ -143,16 +161,24 @@ def get_autoregressive_language_modeling_loss( shift_logits = tensor_to_dtensor( shift_logits, device_mesh=self.tp_mesh, - current_placement=Shard(-1) if self.tensor_parallel_word_embeddings else Replicate(), + 
current_placement=Shard(-1) + if self.tensor_parallel_word_embeddings + else Replicate(), + ) + shift_labels = tensor_to_dtensor( + shift_labels, device_mesh=self.tp_mesh, current_placement=Replicate() ) - shift_labels = tensor_to_dtensor(shift_labels, device_mesh=self.tp_mesh, current_placement=Replicate()) if self.upcast_logits_for_loss: shift_logits = shift_logits.float() - loss_context = loss_parallel if self.tensor_parallel_word_embeddings else nullcontext + loss_context = ( + loss_parallel if self.tensor_parallel_word_embeddings else nullcontext + ) with loss_context(): - loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) + loss = F.cross_entropy( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) + ) return loss @@ -164,23 +190,35 @@ def from_pretrained( tensor_parallel_word_embeddings: bool = False, **kwargs, ) -> CausalLMModelMixin_TP: - config: CommonConfig = cls.config_class.from_pretrained(pretrained_model_name_or_path) + config: CommonConfig = cls.config_class.from_pretrained( + pretrained_model_name_or_path + ) # use dummy tensors to avoid initializing model here with torch.device("meta"): # try sharding vocab matrices if really struggling for memory - model = cls._from_config(config, tensor_parallel_word_embeddings=tensor_parallel_word_embeddings, **kwargs) + model = cls._from_config( + config, + tensor_parallel_word_embeddings=tensor_parallel_word_embeddings, + **kwargs, + ) model = model.to(dtype=torch_dtype) # copy to device without copying storage model = model.to_empty(device=torch.cuda.current_device()) - model.load_from_safetensors_weights_manager(SafeTensorsWeightsManager(pretrained_model_name_or_path)) + model.load_from_safetensors_weights_manager( + SafeTensorsWeightsManager(pretrained_model_name_or_path) + ) return model - def load_from_safetensors_weights_manager(self, safetensors_weights_manager: SafeTensorsWeightsManager) -> None: + def load_from_safetensors_weights_manager( + self, safetensors_weights_manager: SafeTensorsWeightsManager + ) -> None: with torch.device(torch.cuda.current_device()): - position_embedding_type = PositionEmbeddingType(self.config.position_embedding_type) + position_embedding_type = PositionEmbeddingType( + self.config.position_embedding_type + ) if position_embedding_type == PositionEmbeddingType.alibi: self.transformer.alibi.reset_parameters() diff --git a/src/instructlab/dolomite/hf_models/mixins/moe/__init__.py b/src/instructlab/dolomite/hf_models/mixins/moe/__init__.py index 12b6465..c247564 100644 --- a/src/instructlab/dolomite/hf_models/mixins/moe/__init__.py +++ b/src/instructlab/dolomite/hf_models/mixins/moe/__init__.py @@ -1,2 +1,7 @@ -from .base import BaseMoEModelMixin, MoeModelOutputWithPastAndAuxLoss, PreTrainedMoEModelMixin +# Local +from .base import ( + BaseMoEModelMixin, + MoeModelOutputWithPastAndAuxLoss, + PreTrainedMoEModelMixin, +) from .main import CausalLMMoEModelMixin diff --git a/src/instructlab/dolomite/hf_models/mixins/moe/base.py b/src/instructlab/dolomite/hf_models/mixins/moe/base.py index 54ed982..6f7166c 100644 --- a/src/instructlab/dolomite/hf_models/mixins/moe/base.py +++ b/src/instructlab/dolomite/hf_models/mixins/moe/base.py @@ -1,10 +1,13 @@ +# Standard from dataclasses import dataclass -import torch -import torch.nn as nn +# Third Party from transformers import DynamicCache from transformers.modeling_outputs import MoeModelOutputWithPast +import torch +import torch.nn as nn +# Local from ...config import CommonConfig from ...enums import 
AttentionHeadType, PositionEmbeddingType from ...modeling_utils import ParameterizedEmbedding, get_normalization_function @@ -39,9 +42,13 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: self.head_dim = self.embed_dim // self.num_heads - self.wte = ParameterizedEmbedding(config.vocab_size, self.embed_dim, std=self.initializer_range) + self.wte = ParameterizedEmbedding( + config.vocab_size, self.embed_dim, std=self.initializer_range + ) - self.drop = nn.Identity() if config.embd_pdrop == 0 else nn.Dropout(config.embd_pdrop) + self.drop = ( + nn.Identity() if config.embd_pdrop == 0 else nn.Dropout(config.embd_pdrop) + ) self.h = nn.ModuleList( [ self.layer_class( @@ -62,7 +69,9 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: normalization_implementation=self.normalization_implementation, ) - self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type) + self.position_embedding_type = PositionEmbeddingType( + config.position_embedding_type + ) self._setup_positional_encoding() # Initialize weights and apply final processing @@ -116,7 +125,9 @@ def forward( # attention_mask -> (batch_size, 1, query_length, key_length) # ========================================================================================== - past_key_values = DynamicCache() if use_cache and past_key_values is None else past_key_values + past_key_values = ( + DynamicCache() if use_cache and past_key_values is None else past_key_values + ) all_hidden_states = () if output_hidden_states else None all_router_logits = () if output_router_logits else None total_aux_loss = 0 @@ -188,7 +199,9 @@ def _prepare_a_bunch_of_stuff( tuple[torch.Tensor], ]: output_router_logits = ( - output_router_logits if output_router_logits is not None else self.config.output_router_logits + output_router_logits + if output_router_logits is not None + else self.config.output_router_logits ) return super()._prepare_a_bunch_of_stuff( diff --git a/src/instructlab/dolomite/hf_models/mixins/moe/main.py b/src/instructlab/dolomite/hf_models/mixins/moe/main.py index 89e9632..1138711 100644 --- a/src/instructlab/dolomite/hf_models/mixins/moe/main.py +++ b/src/instructlab/dolomite/hf_models/mixins/moe/main.py @@ -1,7 +1,9 @@ -import torch +# Third Party from transformers import DynamicCache from transformers.modeling_outputs import MoeCausalLMOutputWithPast +import torch +# Local from ...config import CommonConfig from ..dense import CausalLMModelMixin from .base import MoeModelOutputWithPastAndAuxLoss @@ -32,18 +34,20 @@ def forward( max_seqlen: torch.Tensor | None = None, output_router_logits: bool | None = None, ) -> tuple | MoeCausalLMOutputWithPast: - input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen = self.prepare_inputs_for_model( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - position_ids=position_ids, - token_type_ids=token_type_ids, - labels=labels, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - past_key_values=past_key_values, - attention_mask=attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, + input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen = ( + self.prepare_inputs_for_model( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + token_type_ids=token_type_ids, + labels=labels, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + past_key_values=past_key_values, + attention_mask=attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) ) # 
========================================================================================== @@ -76,7 +80,9 @@ def forward( if self.m_width is not None: lm_logits = lm_logits / self.m_width - lm_loss = self.get_autoregressive_language_modeling_loss(lm_logits, labels, cu_seqlens) + lm_loss = self.get_autoregressive_language_modeling_loss( + lm_logits, labels, cu_seqlens + ) aux_loss = transformer_outputs.aux_loss if lm_loss is None: diff --git a/src/instructlab/dolomite/hf_models/mixins/moe_TP/__init__.py b/src/instructlab/dolomite/hf_models/mixins/moe_TP/__init__.py index e4e90ab..1250111 100644 --- a/src/instructlab/dolomite/hf_models/mixins/moe_TP/__init__.py +++ b/src/instructlab/dolomite/hf_models/mixins/moe_TP/__init__.py @@ -1,2 +1,3 @@ +# Local from .base import BaseMoEModelMixin_TP, PreTrainedMoEModelMixin_TP from .main import CausalLMMoEModelMixin_TP diff --git a/src/instructlab/dolomite/hf_models/mixins/moe_TP/base.py b/src/instructlab/dolomite/hf_models/mixins/moe_TP/base.py index 55b09de..55749ff 100644 --- a/src/instructlab/dolomite/hf_models/mixins/moe_TP/base.py +++ b/src/instructlab/dolomite/hf_models/mixins/moe_TP/base.py @@ -1,5 +1,7 @@ +# Third Party import torch.nn as nn +# Local from ....utils import ProcessGroupManager from ...config import CommonConfig from ...enums import AttentionHeadType, PositionEmbeddingType @@ -10,7 +12,9 @@ class PreTrainedMoEModelMixin_TP(PreTrainedMoEModelMixin, PreTrainedModelMixin_TP): def __init__(self, config: CommonConfig, *args, **kwargs): - self.tensor_parallel_word_embeddings = kwargs.get("tensor_parallel_word_embeddings", False) + self.tensor_parallel_word_embeddings = kwargs.get( + "tensor_parallel_word_embeddings", False + ) self.sequence_parallel = kwargs.get("sequence_parallel", False) super().__init__(config, *args, **kwargs) @@ -68,7 +72,9 @@ def _init_model(self, config: CommonConfig, **kwargs) -> None: sequence_parallel=self.sequence_parallel, ) - self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type) + self.position_embedding_type = PositionEmbeddingType( + config.position_embedding_type + ) self._setup_positional_encoding() # Initialize weights and apply final processing diff --git a/src/instructlab/dolomite/hf_models/mixins/moe_TP/main.py b/src/instructlab/dolomite/hf_models/mixins/moe_TP/main.py index 8f5de69..f6fdb75 100644 --- a/src/instructlab/dolomite/hf_models/mixins/moe_TP/main.py +++ b/src/instructlab/dolomite/hf_models/mixins/moe_TP/main.py @@ -1,8 +1,10 @@ -import torch +# Third Party from torch.distributed._tensor.placement_types import Replicate, Shard from transformers import DynamicCache from transformers.modeling_outputs import MoeCausalLMOutputWithPast +import torch +# Local from ...modeling_utils_TP import dtensor_to_tensor, tensor_to_dtensor from ..dense_TP import CausalLMModelMixin_TP from ..moe import CausalLMMoEModelMixin, MoeModelOutputWithPastAndAuxLoss @@ -27,18 +29,20 @@ def forward( max_seqlen: torch.Tensor | None = None, output_router_logits: bool | None = None, ) -> tuple | MoeCausalLMOutputWithPast: - input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen = self.prepare_inputs_for_model( - input_ids=input_ids, - inputs_embeds=inputs_embeds, - position_ids=position_ids, - token_type_ids=token_type_ids, - labels=labels, - cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, - past_key_values=past_key_values, - attention_mask=attention_mask, - use_cache=use_cache, - output_attentions=output_attentions, + input_ids, position_ids, token_type_ids, labels, 
cu_seqlens, max_seqlen = ( + self.prepare_inputs_for_model( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + position_ids=position_ids, + token_type_ids=token_type_ids, + labels=labels, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + past_key_values=past_key_values, + attention_mask=attention_mask, + use_cache=use_cache, + output_attentions=output_attentions, + ) ) transformer_outputs: MoeModelOutputWithPastAndAuxLoss = self.transformer( @@ -60,9 +64,13 @@ def forward( if self.m_width is not None: lm_logits = lm_logits / self.m_width - lm_loss = self.get_autoregressive_language_modeling_loss(lm_logits, labels, cu_seqlens) + lm_loss = self.get_autoregressive_language_modeling_loss( + lm_logits, labels, cu_seqlens + ) aux_loss = tensor_to_dtensor( - transformer_outputs.aux_loss, device_mesh=self.tp_mesh, current_placement=Replicate() + transformer_outputs.aux_loss, + device_mesh=self.tp_mesh, + current_placement=Replicate(), ) if lm_loss is None: @@ -75,8 +83,12 @@ def forward( else: if self.tensor_parallel_word_embeddings: # all gather - lm_logits = tensor_to_dtensor(lm_logits, device_mesh=self.tp_mesh, current_placement=Shard(-1)) - lm_logits = dtensor_to_tensor(lm_logits, device_mesh=self.tp_mesh, desired_placement=Replicate()) + lm_logits = tensor_to_dtensor( + lm_logits, device_mesh=self.tp_mesh, current_placement=Shard(-1) + ) + lm_logits = dtensor_to_tensor( + lm_logits, device_mesh=self.tp_mesh, desired_placement=Replicate() + ) return MoeCausalLMOutputWithPast( loss=loss, diff --git a/src/instructlab/dolomite/hf_models/model_conversion/__init__.py b/src/instructlab/dolomite/hf_models/model_conversion/__init__.py index 0ddd148..bade858 100644 --- a/src/instructlab/dolomite/hf_models/model_conversion/__init__.py +++ b/src/instructlab/dolomite/hf_models/model_conversion/__init__.py @@ -1,10 +1,11 @@ +# Third Party from transformers import AutoConfig +# Local from .bigcode import export_to_huggingface_bigcode, import_from_huggingface_bigcode from .granite import export_to_huggingface_granite, import_from_huggingface_granite from .llama import export_to_huggingface_llama, import_from_huggingface_llama - _MODEL_IMPORT_FUNCTIONS = { "gpt_bigcode": import_from_huggingface_bigcode, "granite": import_from_huggingface_granite, @@ -17,7 +18,9 @@ def import_from_huggingface(pretrained_model_name_or_path: str, save_path: str) model_type = config.model_type if model_type not in _MODEL_IMPORT_FUNCTIONS: - raise NotImplementedError(f"the current model_type ({model_type}) is not yet supported") + raise NotImplementedError( + f"the current model_type ({model_type}) is not yet supported" + ) import_function = _MODEL_IMPORT_FUNCTIONS[model_type] import_function(pretrained_model_name_or_path, save_path) @@ -30,9 +33,13 @@ def import_from_huggingface(pretrained_model_name_or_path: str, save_path: str) } -def export_to_huggingface(pretrained_model_name_or_path: str, save_path: str, model_type: str) -> None: +def export_to_huggingface( + pretrained_model_name_or_path: str, save_path: str, model_type: str +) -> None: if model_type not in _MODEL_EXPORT_FUNCTIONS: - raise NotImplementedError(f"the current model_type ({model_type}) is not yet supported") + raise NotImplementedError( + f"the current model_type ({model_type}) is not yet supported" + ) export_function = _MODEL_EXPORT_FUNCTIONS[model_type] export_function(pretrained_model_name_or_path, save_path) diff --git a/src/instructlab/dolomite/hf_models/model_conversion/bigcode.py b/src/instructlab/dolomite/hf_models/model_conversion/bigcode.py 
index 9ee9339..5906aa4 100644 --- a/src/instructlab/dolomite/hf_models/model_conversion/bigcode.py +++ b/src/instructlab/dolomite/hf_models/model_conversion/bigcode.py @@ -1,12 +1,23 @@ +# Standard import shutil -from transformers import AutoConfig, AutoTokenizer, GenerationConfig, GPTBigCodeConfig, GPTBigCodeForCausalLM +# Third Party +from transformers import ( + AutoConfig, + AutoTokenizer, + GenerationConfig, + GPTBigCodeConfig, + GPTBigCodeForCausalLM, +) +# Local from ..enums import AttentionHeadType, PositionEmbeddingType from ..models import GPTDolomiteConfig -def import_from_huggingface_bigcode(pretrained_model_name_or_path: str, save_path: str) -> None: +def import_from_huggingface_bigcode( + pretrained_model_name_or_path: str, save_path: str +) -> None: shutil.copytree(pretrained_model_name_or_path, save_path) original_config: GPTBigCodeConfig = AutoConfig.from_pretrained(save_path) @@ -23,7 +34,9 @@ def import_from_huggingface_bigcode(pretrained_model_name_or_path: str, save_pat pass -def _import_config_from_huggingface(original_config: GPTBigCodeConfig) -> GPTDolomiteConfig: +def _import_config_from_huggingface( + original_config: GPTBigCodeConfig, +) -> GPTDolomiteConfig: assert original_config.activation_function in ["gelu_pytorch_tanh", "gelu"] config = GPTDolomiteConfig( @@ -52,7 +65,9 @@ def _import_config_from_huggingface(original_config: GPTBigCodeConfig) -> GPTDol return config -def export_to_huggingface_bigcode(pretrained_model_name_or_path: str, save_path: str) -> None: +def export_to_huggingface_bigcode( + pretrained_model_name_or_path: str, save_path: str +) -> None: shutil.copytree(pretrained_model_name_or_path, save_path) config: GPTDolomiteConfig = AutoConfig.from_pretrained(save_path) @@ -72,8 +87,14 @@ def export_to_huggingface_bigcode(pretrained_model_name_or_path: str, save_path: def _export_config_to_huggingface(config: GPTDolomiteConfig) -> GPTBigCodeConfig: assert config.activation_function == "gelu_pytorch_tanh" assert config.normalization_function == "layernorm" - assert AttentionHeadType(config.attention_head_type) in [AttentionHeadType.mha, AttentionHeadType.mqa] - assert PositionEmbeddingType(config.position_embedding_type) == PositionEmbeddingType.learned_absolute + assert AttentionHeadType(config.attention_head_type) in [ + AttentionHeadType.mha, + AttentionHeadType.mqa, + ] + assert ( + PositionEmbeddingType(config.position_embedding_type) + == PositionEmbeddingType.learned_absolute + ) assert config.m_emb is None assert config.m_residual is None assert config.m_width is None diff --git a/src/instructlab/dolomite/hf_models/model_conversion/granite.py b/src/instructlab/dolomite/hf_models/model_conversion/granite.py index c9af0d6..0d9fd63 100644 --- a/src/instructlab/dolomite/hf_models/model_conversion/granite.py +++ b/src/instructlab/dolomite/hf_models/model_conversion/granite.py @@ -1,19 +1,28 @@ +# Third Party from transformers import AutoConfig, AutoTokenizer, GenerationConfig +# Local from ...utils import SafeTensorsWeightsManager, download_repo from ..enums import AttentionHeadType from ..models import GPTDolomiteConfig -from .llama import _export_state_dict_to_huggingface, _import_state_dict_from_huggingface - +from .llama import ( + _export_state_dict_to_huggingface, + _import_state_dict_from_huggingface, +) try: + # Third Party from transformers import GraniteConfig, GraniteForCausalLM except: GraniteConfig = None -def import_from_huggingface_granite(pretrained_model_name_or_path: str, save_path: str) -> None: - original_config, tokenizer, 
downloaded_model_path = download_repo(pretrained_model_name_or_path) +def import_from_huggingface_granite( + pretrained_model_name_or_path: str, save_path: str +) -> None: + original_config, tokenizer, downloaded_model_path = download_repo( + pretrained_model_name_or_path + ) config = _import_config_from_huggingface(original_config) safetensors_weights_manager = SafeTensorsWeightsManager(downloaded_model_path) @@ -36,7 +45,9 @@ def import_from_huggingface_granite(pretrained_model_name_or_path: str, save_pat tokenizer.save_pretrained(save_path, legacy_format=False) -def _import_config_from_huggingface(original_config: GraniteConfig) -> GPTDolomiteConfig: +def _import_config_from_huggingface( + original_config: GraniteConfig, +) -> GPTDolomiteConfig: assert original_config.hidden_act == "silu" if original_config.num_attention_heads == original_config.num_key_value_heads: @@ -71,20 +82,32 @@ def _import_config_from_huggingface(original_config: GraniteConfig) -> GPTDolomi bos_token_id=original_config.bos_token_id, eos_token_id=original_config.eos_token_id, pad_token_id=original_config.pad_token_id, - m_emb=None if original_config.embedding_multiplier == 1 else original_config.embedding_multiplier, - m_residual=None if original_config.residual_multiplier == 1 else original_config.residual_multiplier, - m_width=None if original_config.logits_scaling == 1 else original_config.logits_scaling, + m_emb=None + if original_config.embedding_multiplier == 1 + else original_config.embedding_multiplier, + m_residual=None + if original_config.residual_multiplier == 1 + else original_config.residual_multiplier, + m_width=None + if original_config.logits_scaling == 1 + else original_config.logits_scaling, attention_multiplier=original_config.attention_multiplier, ) return config -def export_to_huggingface_granite(pretrained_model_name_or_path: str, save_path: str) -> None: - config: GPTDolomiteConfig = AutoConfig.from_pretrained(pretrained_model_name_or_path) +def export_to_huggingface_granite( + pretrained_model_name_or_path: str, save_path: str +) -> None: + config: GPTDolomiteConfig = AutoConfig.from_pretrained( + pretrained_model_name_or_path + ) original_config = _export_config_to_huggingface(config) - safetensors_weights_manager = SafeTensorsWeightsManager(pretrained_model_name_or_path) + safetensors_weights_manager = SafeTensorsWeightsManager( + pretrained_model_name_or_path + ) state_dict = _export_state_dict_to_huggingface( safetensors_weights_manager, config.n_layer, @@ -119,7 +142,9 @@ def _export_config_to_huggingface(config: GPTDolomiteConfig) -> GraniteConfig: num_hidden_layers=config.n_layer, num_attention_heads=config.n_head, num_key_value_heads=config.num_key_value_heads, - intermediate_size=4 * config.n_embd if config.n_inner is None else config.n_inner, + intermediate_size=4 * config.n_embd + if config.n_inner is None + else config.n_inner, hidden_act="silu", rms_norm_eps=config.layer_norm_epsilon, use_cache=config.use_cache, diff --git a/src/instructlab/dolomite/hf_models/model_conversion/granitemoe.py b/src/instructlab/dolomite/hf_models/model_conversion/granitemoe.py index 478abac..5be1796 100644 --- a/src/instructlab/dolomite/hf_models/model_conversion/granitemoe.py +++ b/src/instructlab/dolomite/hf_models/model_conversion/granitemoe.py @@ -1,6 +1,8 @@ -import torch +# Third Party from transformers import AutoConfig, AutoTokenizer, GenerationConfig +import torch +# Local from ...utils import SafeTensorsWeightsManager, download_repo from ..enums import AttentionHeadType from 
..modeling_utils import ( @@ -9,15 +11,19 @@ ) from ..models import MoEDolomiteConfig - try: + # Third Party from transformers import GraniteMoeConfig, GraniteMoeForCausalLM except: GraniteMoeConfig = None -def import_from_huggingface_granitemoe(pretrained_model_name_or_path: str, save_path: str) -> None: - original_config, tokenizer, downloaded_model_path = download_repo(pretrained_model_name_or_path) +def import_from_huggingface_granitemoe( + pretrained_model_name_or_path: str, save_path: str +) -> None: + original_config, tokenizer, downloaded_model_path = download_repo( + pretrained_model_name_or_path + ) config = _import_config_from_huggingface(original_config) safetensors_weights_manager = SafeTensorsWeightsManager(downloaded_model_path) @@ -41,7 +47,9 @@ def import_from_huggingface_granitemoe(pretrained_model_name_or_path: str, save_ tokenizer.save_pretrained(save_path, legacy_format=False) -def _import_config_from_huggingface(original_config: GraniteMoeConfig) -> MoEDolomiteConfig: +def _import_config_from_huggingface( + original_config: GraniteMoeConfig, +) -> MoEDolomiteConfig: assert original_config.hidden_act == "silu" if original_config.num_attention_heads == original_config.num_key_value_heads: @@ -80,9 +88,15 @@ def _import_config_from_huggingface(original_config: GraniteMoeConfig) -> MoEDol bos_token_id=original_config.bos_token_id, eos_token_id=original_config.eos_token_id, pad_token_id=original_config.pad_token_id, - m_emb=None if original_config.embedding_multiplier == 1 else original_config.embedding_multiplier, - m_residual=None if original_config.residual_multiplier == 1 else original_config.residual_multiplier, - m_width=None if original_config.logits_scaling == 1 else original_config.logits_scaling, + m_emb=None + if original_config.embedding_multiplier == 1 + else original_config.embedding_multiplier, + m_residual=None + if original_config.residual_multiplier == 1 + else original_config.residual_multiplier, + m_width=None + if original_config.logits_scaling == 1 + else original_config.logits_scaling, attention_multiplier=original_config.attention_multiplier, ) @@ -99,24 +113,36 @@ def _import_state_dict_from_huggingface( attention_head_type: AttentionHeadType, ) -> None: state_dict = { - "transformer.wte.weight": safetensors_weights_manager.get_tensor("model.embed_tokens.weight"), - "transformer.ln_f.weight": safetensors_weights_manager.get_tensor("model.norm.weight"), + "transformer.wte.weight": safetensors_weights_manager.get_tensor( + "model.embed_tokens.weight" + ), + "transformer.ln_f.weight": safetensors_weights_manager.get_tensor( + "model.norm.weight" + ), } if safetensors_weights_manager.has_tensor("lm_head.weight"): - state_dict["lm_head.weight"] = safetensors_weights_manager.get_tensor("lm_head.weight") + state_dict["lm_head.weight"] = safetensors_weights_manager.get_tensor( + "lm_head.weight" + ) for layer_idx in range(num_layers): - state_dict[f"transformer.h.{layer_idx}.ln_1.weight"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.input_layernorm.weight" + state_dict[f"transformer.h.{layer_idx}.ln_1.weight"] = ( + safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.input_layernorm.weight" + ) ) - state_dict[f"transformer.h.{layer_idx}.ln_2.weight"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.post_attention_layernorm.weight" + state_dict[f"transformer.h.{layer_idx}.ln_2.weight"] = ( + safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.post_attention_layernorm.weight" + 
) ) - state_dict[f"transformer.h.{layer_idx}.moe.gate.weight"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.block_sparse_moe.router.layer.weight" - ).T.contiguous() + state_dict[f"transformer.h.{layer_idx}.moe.gate.weight"] = ( + safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.block_sparse_moe.router.layer.weight" + ).T.contiguous() + ) state_dict[f"transformer.h.{layer_idx}.moe.c_fc.weight"] = ( _split_and_reorder_for_glu( @@ -128,32 +154,50 @@ def _import_state_dict_from_huggingface( .contiguous() ) state_dict[f"transformer.h.{layer_idx}.moe.c_proj.weight"] = ( - safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.block_sparse_moe.output_linear.weight") + safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.block_sparse_moe.output_linear.weight" + ) .transpose(0, 1) .contiguous() ) - state_dict[f"transformer.h.{layer_idx}.attn.c_attn.weight"] = interleave_query_key_value_tensor_for_attention( - safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.q_proj.weight"), - safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.k_proj.weight"), - safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.v_proj.weight"), - num_heads, - num_key_value_heads, - head_dim, - attention_head_type, + state_dict[f"transformer.h.{layer_idx}.attn.c_attn.weight"] = ( + interleave_query_key_value_tensor_for_attention( + safetensors_weights_manager.get_slice( + f"model.layers.{layer_idx}.self_attn.q_proj.weight" + ), + safetensors_weights_manager.get_slice( + f"model.layers.{layer_idx}.self_attn.k_proj.weight" + ), + safetensors_weights_manager.get_slice( + f"model.layers.{layer_idx}.self_attn.v_proj.weight" + ), + num_heads, + num_key_value_heads, + head_dim, + attention_head_type, + ) ) - state_dict[f"transformer.h.{layer_idx}.attn.c_proj.weight"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.self_attn.o_proj.weight" + state_dict[f"transformer.h.{layer_idx}.attn.c_proj.weight"] = ( + safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.self_attn.o_proj.weight" + ) ) return state_dict -def export_to_huggingface_granitemoe(pretrained_model_name_or_path: str, save_path: str) -> None: - config: MoEDolomiteConfig = AutoConfig.from_pretrained(pretrained_model_name_or_path) +def export_to_huggingface_granitemoe( + pretrained_model_name_or_path: str, save_path: str +) -> None: + config: MoEDolomiteConfig = AutoConfig.from_pretrained( + pretrained_model_name_or_path + ) original_config = _export_config_to_huggingface(config) - safetensors_weights_manager = SafeTensorsWeightsManager(pretrained_model_name_or_path) + safetensors_weights_manager = SafeTensorsWeightsManager( + pretrained_model_name_or_path + ) state_dict = _export_state_dict_to_huggingface( safetensors_weights_manager, config.n_layer, @@ -190,7 +234,9 @@ def _export_config_to_huggingface(config: MoEDolomiteConfig) -> GraniteMoeConfig num_hidden_layers=config.n_layer, num_attention_heads=config.n_head, num_key_value_heads=config.num_key_value_heads, - intermediate_size=4 * config.n_embd if config.n_inner is None else config.n_inner, + intermediate_size=4 * config.n_embd + if config.n_inner is None + else config.n_inner, hidden_act="silu", rms_norm_eps=config.layer_norm_epsilon, use_cache=config.use_cache, @@ -227,45 +273,71 @@ def _export_state_dict_to_huggingface( attention_head_type: AttentionHeadType, ) -> None: state_dict = { - "model.embed_tokens.weight": 
safetensors_weights_manager.get_tensor("transformer.wte.weight"), - "model.norm.weight": safetensors_weights_manager.get_tensor("transformer.ln_f.weight"), + "model.embed_tokens.weight": safetensors_weights_manager.get_tensor( + "transformer.wte.weight" + ), + "model.norm.weight": safetensors_weights_manager.get_tensor( + "transformer.ln_f.weight" + ), } if safetensors_weights_manager.has_tensor("lm_head.weight"): - state_dict["lm_head.weight"] = safetensors_weights_manager.get_tensor("lm_head.weight") + state_dict["lm_head.weight"] = safetensors_weights_manager.get_tensor( + "lm_head.weight" + ) for layer_idx in range(num_layers): - state_dict[f"model.layers.{layer_idx}.input_layernorm.weight"] = safetensors_weights_manager.get_tensor( - f"transformer.h.{layer_idx}.ln_1.weight" + state_dict[f"model.layers.{layer_idx}.input_layernorm.weight"] = ( + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.ln_1.weight" + ) ) state_dict[f"model.layers.{layer_idx}.post_attention_layernorm.weight"] = ( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.ln_2.weight") + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.ln_2.weight" + ) ) state_dict[f"model.layers.{layer_idx}.block_sparse_moe.router.layer.weight"] = ( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.moe.gate.weight") + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.moe.gate.weight" + ) ).T.contiguous() - state_dict[f"model.layers.{layer_idx}.block_sparse_moe.input_linear.weight"] = _split_and_reorder_for_glu( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.moe.c_fc.weight").transpose(0, 1) - ).contiguous() - state_dict[f"model.layers.{layer_idx}.block_sparse_moe.output_linear.weight"] = ( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.moe.c_proj.weight").transpose(0, 1) + state_dict[f"model.layers.{layer_idx}.block_sparse_moe.input_linear.weight"] = ( + _split_and_reorder_for_glu( + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.moe.c_fc.weight" + ).transpose(0, 1) + ).contiguous() + ) + state_dict[ + f"model.layers.{layer_idx}.block_sparse_moe.output_linear.weight" + ] = ( + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.moe.c_proj.weight" + ).transpose(0, 1) ).contiguous() - query_weight, key_weight, value_weight = split_query_key_value_tensor_for_attention( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.attn.c_attn.weight"), - num_heads, - num_key_value_heads, - head_dim, - attention_head_type, + query_weight, key_weight, value_weight = ( + split_query_key_value_tensor_for_attention( + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.attn.c_attn.weight" + ), + num_heads, + num_key_value_heads, + head_dim, + attention_head_type, + ) ) state_dict[f"model.layers.{layer_idx}.self_attn.q_proj.weight"] = query_weight state_dict[f"model.layers.{layer_idx}.self_attn.k_proj.weight"] = key_weight state_dict[f"model.layers.{layer_idx}.self_attn.v_proj.weight"] = value_weight - state_dict[f"model.layers.{layer_idx}.self_attn.o_proj.weight"] = safetensors_weights_manager.get_tensor( - f"transformer.h.{layer_idx}.attn.c_proj.weight" + state_dict[f"model.layers.{layer_idx}.self_attn.o_proj.weight"] = ( + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.attn.c_proj.weight" + ) ) return state_dict diff --git a/src/instructlab/dolomite/hf_models/model_conversion/llama.py 
b/src/instructlab/dolomite/hf_models/model_conversion/llama.py index dee5dd4..c94df33 100644 --- a/src/instructlab/dolomite/hf_models/model_conversion/llama.py +++ b/src/instructlab/dolomite/hf_models/model_conversion/llama.py @@ -1,5 +1,13 @@ -from transformers import AutoConfig, AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM +# Third Party +from transformers import ( + AutoConfig, + AutoTokenizer, + GenerationConfig, + LlamaConfig, + LlamaForCausalLM, +) +# Local from ...utils import SafeTensorsWeightsManager, download_repo from ..enums import AttentionHeadType from ..modeling_utils import ( @@ -7,11 +15,18 @@ split_query_key_value_tensor_for_attention, ) from ..models import GPTDolomiteConfig -from ..models.gpt_dolomite import interleave_up_gate_tensor_for_mlp, split_up_gate_tensor_for_mlp +from ..models.gpt_dolomite import ( + interleave_up_gate_tensor_for_mlp, + split_up_gate_tensor_for_mlp, +) -def import_from_huggingface_llama(pretrained_model_name_or_path: str, save_path: str) -> None: - original_config, tokenizer, downloaded_model_path = download_repo(pretrained_model_name_or_path) +def import_from_huggingface_llama( + pretrained_model_name_or_path: str, save_path: str +) -> None: + original_config, tokenizer, downloaded_model_path = download_repo( + pretrained_model_name_or_path + ) config = _import_config_from_huggingface(original_config) safetensors_weights_manager = SafeTensorsWeightsManager(downloaded_model_path) @@ -83,54 +98,100 @@ def _import_state_dict_from_huggingface( attention_head_type: AttentionHeadType, ) -> None: state_dict = { - "transformer.wte.weight": safetensors_weights_manager.get_tensor("model.embed_tokens.weight"), - "transformer.ln_f.weight": safetensors_weights_manager.get_tensor("model.norm.weight"), + "transformer.wte.weight": safetensors_weights_manager.get_tensor( + "model.embed_tokens.weight" + ), + "transformer.ln_f.weight": safetensors_weights_manager.get_tensor( + "model.norm.weight" + ), } if safetensors_weights_manager.has_tensor("lm_head.weight"): - state_dict["lm_head.weight"] = safetensors_weights_manager.get_tensor("lm_head.weight") + state_dict["lm_head.weight"] = safetensors_weights_manager.get_tensor( + "lm_head.weight" + ) for layer_idx in range(num_layers): - state_dict[f"transformer.h.{layer_idx}.ln_1.weight"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.input_layernorm.weight" + state_dict[f"transformer.h.{layer_idx}.ln_1.weight"] = ( + safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.input_layernorm.weight" + ) ) - state_dict[f"transformer.h.{layer_idx}.ln_2.weight"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.post_attention_layernorm.weight" + state_dict[f"transformer.h.{layer_idx}.ln_2.weight"] = ( + safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.post_attention_layernorm.weight" + ) ) - state_dict[f"transformer.h.{layer_idx}.mlp.c_fc.weight"] = interleave_up_gate_tensor_for_mlp( - safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mlp.up_proj.weight"), - safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mlp.gate_proj.weight"), + state_dict[f"transformer.h.{layer_idx}.mlp.c_fc.weight"] = ( + interleave_up_gate_tensor_for_mlp( + safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.mlp.up_proj.weight" + ), + safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.mlp.gate_proj.weight" + ), + ) ) if f"model.layers.{layer_idx}.mlp.up_proj.bias" in 
safetensors_weights_manager: - state_dict[f"transformer.h.{layer_idx}.mlp.c_fc.bias"] = interleave_up_gate_tensor_for_mlp( - safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mlp.up_proj.bias"), - safetensors_weights_manager.get_tensor(f"model.layers.{layer_idx}.mlp.gate_proj.bias"), + state_dict[f"transformer.h.{layer_idx}.mlp.c_fc.bias"] = ( + interleave_up_gate_tensor_for_mlp( + safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.mlp.up_proj.bias" + ), + safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.mlp.gate_proj.bias" + ), + ) ) - state_dict[f"transformer.h.{layer_idx}.mlp.c_proj.weight"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.mlp.down_proj.weight" + state_dict[f"transformer.h.{layer_idx}.mlp.c_proj.weight"] = ( + safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.mlp.down_proj.weight" + ) ) - if f"model.layers.{layer_idx}.mlp.down_proj.bias" in safetensors_weights_manager: - state_dict[f"transformer.h.{layer_idx}.mlp.c_proj.bias"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.mlp.down_proj.bias" + if ( + f"model.layers.{layer_idx}.mlp.down_proj.bias" + in safetensors_weights_manager + ): + state_dict[f"transformer.h.{layer_idx}.mlp.c_proj.bias"] = ( + safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.mlp.down_proj.bias" + ) ) - state_dict[f"transformer.h.{layer_idx}.attn.c_attn.weight"] = interleave_query_key_value_tensor_for_attention( - safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.q_proj.weight"), - safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.k_proj.weight"), - safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.v_proj.weight"), - num_heads, - num_key_value_heads, - head_dim, - attention_head_type, + state_dict[f"transformer.h.{layer_idx}.attn.c_attn.weight"] = ( + interleave_query_key_value_tensor_for_attention( + safetensors_weights_manager.get_slice( + f"model.layers.{layer_idx}.self_attn.q_proj.weight" + ), + safetensors_weights_manager.get_slice( + f"model.layers.{layer_idx}.self_attn.k_proj.weight" + ), + safetensors_weights_manager.get_slice( + f"model.layers.{layer_idx}.self_attn.v_proj.weight" + ), + num_heads, + num_key_value_heads, + head_dim, + attention_head_type, + ) ) - if f"model.layers.{layer_idx}.self_attn.q_proj.bias" in safetensors_weights_manager: + if ( + f"model.layers.{layer_idx}.self_attn.q_proj.bias" + in safetensors_weights_manager + ): state_dict[f"transformer.h.{layer_idx}.attn.c_attn.bias"] = ( interleave_query_key_value_tensor_for_attention( - safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.q_proj.bias"), - safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.k_proj.bias"), - safetensors_weights_manager.get_slice(f"model.layers.{layer_idx}.self_attn.v_proj.bias"), + safetensors_weights_manager.get_slice( + f"model.layers.{layer_idx}.self_attn.q_proj.bias" + ), + safetensors_weights_manager.get_slice( + f"model.layers.{layer_idx}.self_attn.k_proj.bias" + ), + safetensors_weights_manager.get_slice( + f"model.layers.{layer_idx}.self_attn.v_proj.bias" + ), num_heads, num_key_value_heads, head_dim, @@ -138,22 +199,35 @@ def _import_state_dict_from_huggingface( ) ) - state_dict[f"transformer.h.{layer_idx}.attn.c_proj.weight"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.self_attn.o_proj.weight" + state_dict[f"transformer.h.{layer_idx}.attn.c_proj.weight"] 
= ( + safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.self_attn.o_proj.weight" + ) ) - if f"model.layers.{layer_idx}.self_attn.o_proj.bias" in safetensors_weights_manager: - state_dict[f"transformer.h.{layer_idx}.attn.c_proj.bias"] = safetensors_weights_manager.get_tensor( - f"model.layers.{layer_idx}.self_attn.o_proj.bias" + if ( + f"model.layers.{layer_idx}.self_attn.o_proj.bias" + in safetensors_weights_manager + ): + state_dict[f"transformer.h.{layer_idx}.attn.c_proj.bias"] = ( + safetensors_weights_manager.get_tensor( + f"model.layers.{layer_idx}.self_attn.o_proj.bias" + ) ) return state_dict -def export_to_huggingface_llama(pretrained_model_name_or_path: str, save_path: str) -> None: - config: GPTDolomiteConfig = AutoConfig.from_pretrained(pretrained_model_name_or_path) +def export_to_huggingface_llama( + pretrained_model_name_or_path: str, save_path: str +) -> None: + config: GPTDolomiteConfig = AutoConfig.from_pretrained( + pretrained_model_name_or_path + ) original_config = _export_config_to_huggingface(config) - safetensors_weights_manager = SafeTensorsWeightsManager(pretrained_model_name_or_path) + safetensors_weights_manager = SafeTensorsWeightsManager( + pretrained_model_name_or_path + ) state_dict = _export_state_dict_to_huggingface( safetensors_weights_manager, config.n_layer, @@ -192,7 +266,9 @@ def _export_config_to_huggingface(config: GPTDolomiteConfig) -> LlamaConfig: num_hidden_layers=config.n_layer, num_attention_heads=config.n_head, num_key_value_heads=config.num_key_value_heads, - intermediate_size=4 * config.n_embd if config.n_inner is None else config.n_inner, + intermediate_size=4 * config.n_embd + if config.n_inner is None + else config.n_inner, hidden_act="silu", rms_norm_eps=config.layer_norm_epsilon, use_cache=config.use_cache, @@ -221,71 +297,101 @@ def _export_state_dict_to_huggingface( attention_head_type: AttentionHeadType, ) -> None: state_dict = { - "model.embed_tokens.weight": safetensors_weights_manager.get_tensor("transformer.wte.weight"), - "model.norm.weight": safetensors_weights_manager.get_tensor("transformer.ln_f.weight"), + "model.embed_tokens.weight": safetensors_weights_manager.get_tensor( + "transformer.wte.weight" + ), + "model.norm.weight": safetensors_weights_manager.get_tensor( + "transformer.ln_f.weight" + ), } if safetensors_weights_manager.has_tensor("lm_head.weight"): - state_dict["lm_head.weight"] = safetensors_weights_manager.get_tensor("lm_head.weight") + state_dict["lm_head.weight"] = safetensors_weights_manager.get_tensor( + "lm_head.weight" + ) for layer_idx in range(num_layers): - state_dict[f"model.layers.{layer_idx}.input_layernorm.weight"] = safetensors_weights_manager.get_tensor( - f"transformer.h.{layer_idx}.ln_1.weight" + state_dict[f"model.layers.{layer_idx}.input_layernorm.weight"] = ( + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.ln_1.weight" + ) ) state_dict[f"model.layers.{layer_idx}.post_attention_layernorm.weight"] = ( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.ln_2.weight") + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.ln_2.weight" + ) ) up_weight, gate_weight = split_up_gate_tensor_for_mlp( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.mlp.c_fc.weight") + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.mlp.c_fc.weight" + ) ) state_dict[f"model.layers.{layer_idx}.mlp.up_proj.weight"] = up_weight state_dict[f"model.layers.{layer_idx}.mlp.gate_proj.weight"] = gate_weight if 
f"transformer.h.{layer_idx}.mlp.c_fc.bias" in safetensors_weights_manager: up_bias, gate_bias = split_up_gate_tensor_for_mlp( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.mlp.c_fc.bias") + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.mlp.c_fc.bias" + ) ) state_dict[f"model.layers.{layer_idx}.mlp.up_proj.bias"] = up_bias state_dict[f"model.layers.{layer_idx}.mlp.gate_proj.bias"] = gate_bias - state_dict[f"model.layers.{layer_idx}.mlp.down_proj.weight"] = safetensors_weights_manager.get_tensor( - f"transformer.h.{layer_idx}.mlp.c_proj.weight" + state_dict[f"model.layers.{layer_idx}.mlp.down_proj.weight"] = ( + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.mlp.c_proj.weight" + ) ) if f"transformer.h.{layer_idx}.mlp.c_proj.bias" in safetensors_weights_manager: - state_dict[f"model.layers.{layer_idx}.mlp.down_proj.bias"] = safetensors_weights_manager.get_tensor( - f"transformer.h.{layer_idx}.mlp.c_proj.bias" + state_dict[f"model.layers.{layer_idx}.mlp.down_proj.bias"] = ( + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.mlp.c_proj.bias" + ) ) - query_weight, key_weight, value_weight = split_query_key_value_tensor_for_attention( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.attn.c_attn.weight"), - num_heads, - num_key_value_heads, - head_dim, - attention_head_type, + query_weight, key_weight, value_weight = ( + split_query_key_value_tensor_for_attention( + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.attn.c_attn.weight" + ), + num_heads, + num_key_value_heads, + head_dim, + attention_head_type, + ) ) state_dict[f"model.layers.{layer_idx}.self_attn.q_proj.weight"] = query_weight state_dict[f"model.layers.{layer_idx}.self_attn.k_proj.weight"] = key_weight state_dict[f"model.layers.{layer_idx}.self_attn.v_proj.weight"] = value_weight if f"transformer.h.{layer_idx}.attn.c_attn.bias" in safetensors_weights_manager: - query_bias, key_bias, value_bias = split_query_key_value_tensor_for_attention( - safetensors_weights_manager.get_tensor(f"transformer.h.{layer_idx}.attn.c_attn.bias"), - num_heads, - num_key_value_heads, - head_dim, - attention_head_type, + query_bias, key_bias, value_bias = ( + split_query_key_value_tensor_for_attention( + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.attn.c_attn.bias" + ), + num_heads, + num_key_value_heads, + head_dim, + attention_head_type, + ) ) state_dict[f"model.layers.{layer_idx}.self_attn.q_proj.bias"] = query_bias state_dict[f"model.layers.{layer_idx}.self_attn.k_proj.bias"] = key_bias state_dict[f"model.layers.{layer_idx}.self_attn.v_proj.bias"] = value_bias - state_dict[f"model.layers.{layer_idx}.self_attn.o_proj.weight"] = safetensors_weights_manager.get_tensor( - f"transformer.h.{layer_idx}.attn.c_proj.weight" + state_dict[f"model.layers.{layer_idx}.self_attn.o_proj.weight"] = ( + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.attn.c_proj.weight" + ) ) if f"transformer.h.{layer_idx}.attn.c_proj.bias" in safetensors_weights_manager: - state_dict[f"model.layers.{layer_idx}.self_attn.o_proj.bias"] = safetensors_weights_manager.get_tensor( - f"transformer.h.{layer_idx}.attn.c_proj.bias" + state_dict[f"model.layers.{layer_idx}.self_attn.o_proj.bias"] = ( + safetensors_weights_manager.get_tensor( + f"transformer.h.{layer_idx}.attn.c_proj.bias" + ) ) return state_dict diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/__init__.py 
b/src/instructlab/dolomite/hf_models/modeling_utils/__init__.py index 92aea83..f29f9cd 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/__init__.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/__init__.py @@ -1,3 +1,4 @@ +# Local from .activations import get_activation_function, is_glu from .attention import ( SDPA, diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/activations/__init__.py b/src/instructlab/dolomite/hf_models/modeling_utils/activations/__init__.py index 478c5dd..8cf8873 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/activations/__init__.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/activations/__init__.py @@ -1,5 +1,7 @@ +# Third Party import torch.nn as nn +# Local from .base import get_base_activation from .glu import get_glu_activation, is_glu diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/activations/base.py b/src/instructlab/dolomite/hf_models/modeling_utils/activations/base.py index 3a8d155..f58cd32 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/activations/base.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/activations/base.py @@ -1,6 +1,6 @@ -import torch.nn as nn +# Third Party from transformers.activations import ACT2CLS, ClassInstantier - +import torch.nn as nn _BASE_ACTIVATIONS = { "celu": nn.modules.CELU, diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/activations/glu.py b/src/instructlab/dolomite/hf_models/modeling_utils/activations/glu.py index 1419488..a59cd0e 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/activations/glu.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/activations/glu.py @@ -1,9 +1,10 @@ +# Third Party import torch import torch.nn as nn +# Local from .base import get_base_activation - _GLU_BASE_MAPPING = { "ceglu": "celu", "eglu": "elu", diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/attention/__init__.py b/src/instructlab/dolomite/hf_models/modeling_utils/attention/__init__.py index c743985..46cf7b1 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/attention/__init__.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/attention/__init__.py @@ -1,7 +1,10 @@ +# Standard import inspect +# Third Party import torch +# Local from ...config import CommonConfig from ...enums import AttentionHeadType from .base import Attention @@ -18,7 +21,6 @@ split_query_key_value_tensor_for_mqa, ) - _ATTENTION_MODULES = { "eager": Attention, "sdpa": SDPA, @@ -69,7 +71,9 @@ def interleave_query_key_value_tensor_for_attention( ) -> torch.Tensor: if attention_head_type.value in _INTERLEAVE_FUNCTIONS: interleave_function = _INTERLEAVE_FUNCTIONS[attention_head_type.value] - interleave_function_parameters = inspect.signature(interleave_function).parameters.keys() + interleave_function_parameters = inspect.signature( + interleave_function + ).parameters.keys() parameters_to_pass = {} this_function_parameters = locals() diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/attention/base.py b/src/instructlab/dolomite/hf_models/modeling_utils/attention/base.py index 51903f1..7edf15e 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/attention/base.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/attention/base.py @@ -1,10 +1,13 @@ +# Standard import math +# Third Party +from transformers import DynamicCache import torch import torch.nn as nn import torch.nn.functional as F -from transformers import DynamicCache +# Local from ...config import CommonConfig 
from ...enums import AttentionHeadType, InitMethod, PositionEmbeddingType from ...utils import divide_if_divisible @@ -14,7 +17,9 @@ class Attention(nn.Module): - def __init__(self, config: CommonConfig, causal: bool, layer_idx: int | None = None) -> None: + def __init__( + self, config: CommonConfig, causal: bool, layer_idx: int | None = None + ) -> None: super().__init__() self.causal = causal @@ -36,7 +41,9 @@ def __init__(self, config: CommonConfig, causal: bool, layer_idx: int | None = N self.attention_head_type = AttentionHeadType(config.attention_head_type) - self.position_embedding_type = PositionEmbeddingType(config.position_embedding_type) + self.position_embedding_type = PositionEmbeddingType( + config.position_embedding_type + ) self.scale_attn_weights = config.scale_attn_weights self.attention_multiplier = config.attention_multiplier @@ -64,9 +71,13 @@ def __init__(self, config: CommonConfig, causal: bool, layer_idx: int | None = N if self.num_key_value_heads is None: self.num_key_value_heads = 1 - assert self.num_key_value_heads == 1, f"{self.__class__.__name__} should have 1 head for keys and values" + assert ( + self.num_key_value_heads == 1 + ), f"{self.__class__.__name__} should have 1 head for keys and values" else: - raise ValueError(f"unexpected attention_head_type ({self.attention_head_type})") + raise ValueError( + f"unexpected attention_head_type ({self.attention_head_type})" + ) # note that the actual layout is different for the output and depends on whether we are using MHA, MQA or GQA # (self.hidden_size + 2 * self.num_key_value_heads * self.head_dim) is just the actual number output features @@ -83,15 +94,23 @@ def __init__(self, config: CommonConfig, causal: bool, layer_idx: int | None = N std = initializer_range / math.sqrt(2 * n_layer) if init_method == InitMethod.mup: std /= math.sqrt(m_width) - self.c_proj = ParameterizedLinear(self.hidden_size, self.hidden_size, bias=self.add_bias, std=std) + self.c_proj = ParameterizedLinear( + self.hidden_size, self.hidden_size, bias=self.add_bias, std=std + ) self.attn_pdrop = config.attn_pdrop self.resid_pdrop = config.resid_pdrop - self.attn_dropout = nn.Identity() if self.attn_pdrop == 0 else nn.Dropout(self.attn_pdrop) - self.resid_dropout = nn.Identity() if self.resid_pdrop == 0 else nn.Dropout(self.resid_pdrop) + self.attn_dropout = ( + nn.Identity() if self.attn_pdrop == 0 else nn.Dropout(self.attn_pdrop) + ) + self.resid_dropout = ( + nn.Identity() if self.resid_pdrop == 0 else nn.Dropout(self.resid_pdrop) + ) - def _prepare_qkv_for_forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def _prepare_qkv_for_forward( + self, hidden_states: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # ========================================================================================== # hidden_states -> (batch_size, query_length, num_heads * head_dim) # ========================================================================================== @@ -111,7 +130,9 @@ def _prepare_qkv_for_forward(self, hidden_states: torch.Tensor) -> tuple[torch.T elif self.attention_head_type == AttentionHeadType.mqa: query, key, value = self._prepare_qkv_for_forward_mqa(hidden_states) else: - raise ValueError(f"unexpected attention_head_type ({self.attention_head_type})") + raise ValueError( + f"unexpected attention_head_type ({self.attention_head_type})" + ) # ========================================================================================== # query -> (batch_size, 
num_heads, query_length, head_dim) @@ -138,10 +159,17 @@ def _prepare_qkv_for_forward_gqa( ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: batch_size, query_length = hidden_states.shape[:-1] - hidden_states = hidden_states.view(batch_size, query_length, self.num_key_value_heads, -1) + hidden_states = hidden_states.view( + batch_size, query_length, self.num_key_value_heads, -1 + ) query, key, value = hidden_states.split( - ((self.num_heads // self.num_key_value_heads) * self.head_dim, self.head_dim, self.head_dim), dim=-1 + ( + (self.num_heads // self.num_key_value_heads) * self.head_dim, + self.head_dim, + self.head_dim, + ), + dim=-1, ) # this needs to be a reshape instead of view sadly @@ -158,7 +186,9 @@ def _prepare_qkv_for_forward_mqa( ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: batch_size, query_length = hidden_states.shape[:-1] - query, key, value = hidden_states.split((self.hidden_size, self.head_dim, self.head_dim), dim=-1) + query, key, value = hidden_states.split( + (self.hidden_size, self.head_dim, self.head_dim), dim=-1 + ) query = query.view(batch_size, query_length, self.num_heads, -1) @@ -233,16 +263,20 @@ def forward( if attention_mask is None: attn_weights = torch.empty( - (batch_size * self.num_heads, query_length, key_length), device=query.device, dtype=query.dtype + (batch_size * self.num_heads, query_length, key_length), + device=query.device, + dtype=query.dtype, ) beta = 0 else: - attn_weights = attention_mask.expand(-1, self.num_heads, -1, -1).reshape(-1, query_length, key_length) + attn_weights = attention_mask.expand(-1, self.num_heads, -1, -1).reshape( + -1, query_length, key_length + ) beta = 1 - attn_weights = torch.baddbmm(attn_weights, query, key, beta=beta, alpha=self._get_softmax_scale(False)).view( - batch_size, self.num_heads, query_length, key_length - ) + attn_weights = torch.baddbmm( + attn_weights, query, key, beta=beta, alpha=self._get_softmax_scale(False) + ).view(batch_size, self.num_heads, query_length, key_length) # ========================================================================================== # attn_weights -> (batch_size, num_heads, query_length, key_length) @@ -263,7 +297,9 @@ def forward( # ========================================================================================== attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim) + attn_output = attn_output.reshape( + batch_size, -1, self.num_heads * self.head_dim + ) # ========================================================================================== # attn_output -> (batch_size, query_length, num_heads * head_dim) diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/attention/flash.py b/src/instructlab/dolomite/hf_models/modeling_utils/attention/flash.py index 26bac53..9c44393 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/attention/flash.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/attention/flash.py @@ -1,7 +1,9 @@ -import torch +# Third Party from transformers import DynamicCache from transformers.modeling_flash_attention_utils import _flash_attention_forward +import torch +# Local from ...enums import AttentionHeadType, PositionEmbeddingType from ..position_embedding import apply_rotary_pos_emb from .base import Attention diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/attention/padding_free.py b/src/instructlab/dolomite/hf_models/modeling_utils/attention/padding_free.py index 9b07a51..6338afc 100644 --- 
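For reference, a small sketch of the MQA split performed by _prepare_qkv_for_forward_mqa above: the fused projection output is cut into per-head queries plus a single key and value head. Shapes are toy values, and the transposes done by the real forward pass are omitted; this is an illustration, not the module's code.

import torch

batch, seq, num_heads, head_dim = 2, 5, 8, 4
hidden_size = num_heads * head_dim

fused = torch.randn(batch, seq, hidden_size + 2 * head_dim)
query, key, value = fused.split((hidden_size, head_dim, head_dim), dim=-1)
query = query.view(batch, seq, num_heads, head_dim)   # (2, 5, 8, 4)
key = key.unsqueeze(2)                                # single KV head
value = value.unsqueeze(2)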
a/src/instructlab/dolomite/hf_models/modeling_utils/attention/padding_free.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/attention/padding_free.py @@ -1,13 +1,15 @@ -import torch +# Third Party from transformers import DynamicCache +import torch +# Local from ....utils import is_flash_attention_available from ...enums import PositionEmbeddingType from ..position_embedding import apply_rotary_pos_emb from .base import Attention - if is_flash_attention_available(): + # Third Party from flash_attn.flash_attn_interface import flash_attn_varlen_func @@ -94,7 +96,12 @@ def _prepare_qkv_for_forward_gqa( hidden_states = hidden_states.view(total_q, self.num_key_value_heads, -1) query, key, value = hidden_states.split( - ((self.num_heads // self.num_key_value_heads) * self.head_dim, self.head_dim, self.head_dim), dim=-1 + ( + (self.num_heads // self.num_key_value_heads) * self.head_dim, + self.head_dim, + self.head_dim, + ), + dim=-1, ) # this needs to be a reshape instead of view sadly @@ -107,7 +114,9 @@ def _prepare_qkv_for_forward_mqa( ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: total_q = hidden_states.shape[0] - query, key, value = hidden_states.split((self.hidden_size, self.head_dim, self.head_dim), dim=-1) + query, key, value = hidden_states.split( + (self.hidden_size, self.head_dim, self.head_dim), dim=-1 + ) query = query.view(total_q, self.num_heads, -1) key = key.unsqueeze(1) diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/attention/sdpa.py b/src/instructlab/dolomite/hf_models/modeling_utils/attention/sdpa.py index ad3290e..8188ee1 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/attention/sdpa.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/attention/sdpa.py @@ -1,7 +1,9 @@ +# Third Party +from transformers import DynamicCache import torch import torch.nn.functional as F -from transformers import DynamicCache +# Local from ...enums import PositionEmbeddingType from ..position_embedding import apply_rotary_pos_emb from .base import Attention @@ -71,7 +73,9 @@ def forward( batch_size = attn_output.shape[0] attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(batch_size, -1, self.num_heads * self.head_dim) + attn_output = attn_output.reshape( + batch_size, -1, self.num_heads * self.head_dim + ) # ========================================================================================== # attn_output -> (batch_size, query_length, num_heads * head_dim) diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/attention/utils.py b/src/instructlab/dolomite/hf_models/modeling_utils/attention/utils.py index ca60ca9..275fe56 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/attention/utils.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/attention/utils.py @@ -1,3 +1,4 @@ +# Third Party import torch @@ -61,14 +62,21 @@ def interleave_query_key_value_tensor_for_gqa( def split_query_key_value_tensor_for_gqa( - query_key_value_weight: torch.Tensor, num_heads: int, num_key_value_heads: int, head_dim: int + query_key_value_weight: torch.Tensor, + num_heads: int, + num_key_value_heads: int, + head_dim: int, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: query_heads_per_group = num_heads // num_key_value_heads original_shape = query_key_value_weight.shape - query_key_value_weight = query_key_value_weight.view(num_key_value_heads, (query_heads_per_group + 2), -1) + query_key_value_weight = query_key_value_weight.view( + num_key_value_heads, (query_heads_per_group + 2), -1 + ) - 
query_weight, key_weight, value_weight = query_key_value_weight.split((query_heads_per_group, 1, 1), 1) + query_weight, key_weight, value_weight = query_key_value_weight.split( + (query_heads_per_group, 1, 1), 1 + ) query_weight = query_weight.reshape(-1, *original_shape[1:]) key_weight = key_weight.reshape(-1, *original_shape[1:]) @@ -92,7 +100,9 @@ def split_query_key_value_tensor_for_mqa( return query_key_value_weight.split((num_heads * head_dim, head_dim, head_dim)) -def repeat_key_value(x: torch.Tensor, num_heads: int, num_key_value_heads: int) -> torch.Tensor: +def repeat_key_value( + x: torch.Tensor, num_heads: int, num_key_value_heads: int +) -> torch.Tensor: num_groups = num_heads // num_key_value_heads if num_groups == 1: diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/embedding.py b/src/instructlab/dolomite/hf_models/modeling_utils/embedding.py index 3cff32e..806a588 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/embedding.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/embedding.py @@ -1,3 +1,4 @@ +# Third Party import torch import torch.nn as nn diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/linear.py b/src/instructlab/dolomite/hf_models/modeling_utils/linear.py index 560e100..524b9c7 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/linear.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/linear.py @@ -1,3 +1,4 @@ +# Third Party import torch import torch.nn as nn diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/__init__.py b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/__init__.py index b4bf746..edb0856 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/__init__.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/__init__.py @@ -1,9 +1,10 @@ +# Third Party import torch.nn as nn +# Local from .layernorm import get_layernorm from .rmsnorm import get_rmsnorm - _NORMALIZATION_FUNCTIONS = { "layernorm": get_layernorm, "rmsnorm": get_rmsnorm, @@ -18,7 +19,11 @@ def get_normalization_function( ) -> nn.LayerNorm: if name in _NORMALIZATION_FUNCTIONS: return _NORMALIZATION_FUNCTIONS[name]( - normalized_shape, eps=eps, normalization_implementation=normalization_implementation + normalized_shape, + eps=eps, + normalization_implementation=normalization_implementation, ) - raise ValueError(f"unexpected `normalization_implementation` {normalization_implementation}") + raise ValueError( + f"unexpected `normalization_implementation` {normalization_implementation}" + ) diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/__init__.py b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/__init__.py index 915c7ca..95ef207 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/__init__.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/__init__.py @@ -1,9 +1,10 @@ +# Third Party import torch.nn as nn +# Local from .apex import ApexLayerNorm from .apex_persistent import ApexPersistentLayerNorm - _LAYERNORM_MODULES = { "torch": nn.LayerNorm, "apex": ApexLayerNorm, @@ -17,6 +18,10 @@ def get_layernorm( normalization_implementation: str = "torch", ) -> nn.LayerNorm: if normalization_implementation in _LAYERNORM_MODULES: - return _LAYERNORM_MODULES[normalization_implementation](normalized_shape=normalized_shape, eps=eps) + return _LAYERNORM_MODULES[normalization_implementation]( + 
normalized_shape=normalized_shape, eps=eps + ) - raise ValueError(f"unexpected `normalization_implementation` {normalization_implementation}") + raise ValueError( + f"unexpected `normalization_implementation` {normalization_implementation}" + ) diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/apex.py b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/apex.py index 763ad7f..5d60023 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/apex.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/apex.py @@ -1,9 +1,11 @@ +# Third Party import torch import torch.nn as nn def is_apex_layernorm_available() -> bool: try: + # Third Party from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction return True @@ -12,14 +14,21 @@ def is_apex_layernorm_available() -> bool: if is_apex_layernorm_available(): + # Third Party from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction def apex_layernorm( - input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float, memory_efficient: bool + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + memory_efficient: bool, ) -> torch.Tensor: normalized_shape = (input.shape[-1],) - return FusedLayerNormAffineFunction.apply(input, weight, bias, normalized_shape, eps, memory_efficient) + return FusedLayerNormAffineFunction.apply( + input, weight, bias, normalized_shape, eps, memory_efficient + ) class ApexLayerNorm(nn.LayerNorm): diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/apex_persistent.py b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/apex_persistent.py index e3ac497..40ec27c 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/apex_persistent.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/layernorm/apex_persistent.py @@ -1,9 +1,11 @@ +# Third Party import torch import torch.nn as nn def is_apex_persistent_layernorm_available() -> bool: try: + # Third Party from apex.contrib.layer_norm.layer_norm import FastLayerNormFN return True @@ -12,6 +14,7 @@ def is_apex_persistent_layernorm_available() -> bool: if is_apex_persistent_layernorm_available(): + # Third Party from apex.contrib.layer_norm.layer_norm import FastLayerNormFN @@ -44,7 +47,11 @@ def is_apex_persistent_layernorm_available() -> bool: def apex_persistent_layernorm( - input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float, memory_efficient + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + memory_efficient, ) -> torch.Tensor: return FastLayerNormFN.apply(input, weight, bias, eps, memory_efficient) diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/__init__.py b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/__init__.py index 42a64c3..a7c7dc1 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/__init__.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/__init__.py @@ -1,11 +1,17 @@ +# Third Party import torch.nn as nn +# Local from .apex import ApexRMSNorm from .base import RMSNorm -#from .torchtitan import TorchTitanRMSNorm + +# from .torchtitan import TorchTitanRMSNorm # Removing TorchTitanRMSNorm to avoid unecessary imports and checks -_RMSNORM_MODULES = {"torch": RMSNorm, "apex": ApexRMSNorm}#, 
"torchtitan": TorchTitanRMSNorm} +_RMSNORM_MODULES = { + "torch": RMSNorm, + "apex": ApexRMSNorm, +} # , "torchtitan": TorchTitanRMSNorm} def get_rmsnorm( @@ -14,6 +20,10 @@ def get_rmsnorm( normalization_implementation: str = "torch", ) -> nn.LayerNorm: if normalization_implementation in _RMSNORM_MODULES: - return _RMSNORM_MODULES[normalization_implementation](normalized_shape=normalized_shape, eps=eps) + return _RMSNORM_MODULES[normalization_implementation]( + normalized_shape=normalized_shape, eps=eps + ) - raise ValueError(f"unexpected `normalization_implementation` {normalization_implementation}") + raise ValueError( + f"unexpected `normalization_implementation` {normalization_implementation}" + ) diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/apex.py b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/apex.py index c91f4e7..2c2d646 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/apex.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/apex.py @@ -1,10 +1,14 @@ +# Third Party import torch import torch.nn as nn def is_apex_rmsnorm_available() -> bool: try: - from apex.normalization.fused_layer_norm import FusedRMSNormAffineMixedDtypesFunction + # Third Party + from apex.normalization.fused_layer_norm import ( + FusedRMSNormAffineMixedDtypesFunction, + ) return True except ImportError: @@ -12,12 +16,19 @@ def is_apex_rmsnorm_available() -> bool: if is_apex_rmsnorm_available(): - from apex.normalization.fused_layer_norm import FusedRMSNormAffineMixedDtypesFunction + # Third Party + from apex.normalization.fused_layer_norm import ( + FusedRMSNormAffineMixedDtypesFunction, + ) -def apex_rmsnorm(input: torch.Tensor, weight: torch.Tensor, eps: float, memory_efficient: bool) -> torch.Tensor: +def apex_rmsnorm( + input: torch.Tensor, weight: torch.Tensor, eps: float, memory_efficient: bool +) -> torch.Tensor: normalized_shape = (input.shape[-1],) - return FusedRMSNormAffineMixedDtypesFunction.apply(input, weight, normalized_shape, eps, memory_efficient) + return FusedRMSNormAffineMixedDtypesFunction.apply( + input, weight, normalized_shape, eps, memory_efficient + ) class ApexRMSNorm(nn.RMSNorm): diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/base.py b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/base.py index 82dd4a2..8f45676 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/base.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/base.py @@ -1,3 +1,4 @@ +# Third Party import torch import torch.nn as nn diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/torchtitan.py b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/torchtitan.py index c5fd754..38bd5af 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/torchtitan.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/normalization/rmsnorm/torchtitan.py @@ -10,16 +10,18 @@ """Code taken from torchtitan: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/norms.py""" - +# Standard import math +# Third Party import torch import torch.nn as nn +# Local from .....utils import is_triton_available - if is_triton_available(): + # Third Party import triton import triton.language as tl @@ -113,7 +115,9 @@ def _rms_norm_bwd_kernel_sm( for row in range(row_start, row_end): # Load input, output 
gradient, and reciprocal standard deviation x = tl.load(X + row * stride_x + cols, mask=mask, other=0.0).to(tl.float32) - dy = tl.load(DY + row * stride_dy + cols, mask=mask, other=0.0).to(tl.float32) + dy = tl.load(DY + row * stride_dy + cols, mask=mask, other=0.0).to( + tl.float32 + ) rstd = tl.load(Rstd + row) # Compute normalized input and gradients @@ -153,7 +157,9 @@ def forward(ctx, x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Ten raise ValueError(f"N {N} must be <= {block_N=}") grid = lambda meta: (M,) - _rms_norm_fwd_kernel[grid](x, x.stride(0), y, y.stride(0), weight, rstd, eps, M, N, block_N) + _rms_norm_fwd_kernel[grid]( + x, x.stride(0), y, y.stride(0), weight, rstd, eps, M, N, block_N + ) ctx.eps = eps ctx.save_for_backward(x, weight, rstd) @@ -189,14 +195,29 @@ def backward(ctx, dy: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, None]: grid = lambda meta: (sm_count,) _rms_norm_bwd_kernel_sm[grid]( - x, x.stride(0), weight, dy, dy.stride(0), dx, dx.stride(0), rstd, _dw, eps, M, N, rows_per_sm, block_N + x, + x.stride(0), + weight, + dy, + dy.stride(0), + dx, + dx.stride(0), + rstd, + _dw, + eps, + M, + N, + rows_per_sm, + block_N, ) dw = _dw.sum(0).to(weight.dtype) dx = dx.view(x_shape_start) return dx, dw, None -def torchtitan_rmsnorm(input: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor: +def torchtitan_rmsnorm( + input: torch.Tensor, weight: torch.Tensor, eps: float +) -> torch.Tensor: return _TorchTitanRMSNorm.apply(input, weight, eps) diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/__init__.py b/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/__init__.py index e82f7cf..1e6fb4f 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/__init__.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/__init__.py @@ -1,2 +1,3 @@ +# Local from .alibi import Alibi from .rope import RoPE, YaRNScaledRoPE, apply_rotary_pos_emb diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/alibi.py b/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/alibi.py index 3f49177..585bf98 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/alibi.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/alibi.py @@ -1,5 +1,7 @@ +# Standard import math +# Third Party import torch import torch.nn as nn @@ -21,24 +23,40 @@ def forward( ) -> torch.Tensor: if attention_mask is None: arange_tensor = ( - torch.arange(key_length, device=device).unsqueeze(0).unsqueeze(0).expand(batch_size, -1, -1) + torch.arange(key_length, device=device) + .unsqueeze(0) + .unsqueeze(0) + .expand(batch_size, -1, -1) ) else: - arange_tensor = (attention_mask.cumsum(dim=-1) - 1).masked_fill_(attention_mask == 0, 0).unsqueeze(1) + arange_tensor = ( + (attention_mask.cumsum(dim=-1) - 1) + .masked_fill_(attention_mask == 0, 0) + .unsqueeze(1) + ) alibi = self.slopes.unsqueeze(1) * arange_tensor return alibi.to(dtype) def reset_parameters(self) -> None: closest_power_of_2 = 2 ** math.floor(math.log2(self.num_heads)) - base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32) + base = torch.tensor( + 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32 + ) powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) slopes = torch.pow(base, powers) if closest_power_of_2 != self.num_heads: - extra_base = torch.tensor(2 ** (-(2 ** 
-(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32) - num_remaining_heads = min(closest_power_of_2, self.num_heads - closest_power_of_2) - extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32) + extra_base = torch.tensor( + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), + dtype=torch.float32, + ) + num_remaining_heads = min( + closest_power_of_2, self.num_heads - closest_power_of_2 + ) + extra_powers = torch.arange( + 1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32 + ) slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) self.register_buffer("slopes", slopes, persistent=False) diff --git a/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/rope.py b/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/rope.py index 71c5916..5a3d0d3 100644 --- a/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/rope.py +++ b/src/instructlab/dolomite/hf_models/modeling_utils/position_embedding/rope.py @@ -1,7 +1,9 @@ """Logic is copied from transformers.models.llama.modeling_utils with slight modifications""" +# Standard import math +# Third Party import torch import torch.nn as nn @@ -22,7 +24,9 @@ def __init__( self.reset_parameters() - def forward(self, seq_len: int, dtype: torch.dtype, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]: + def forward( + self, seq_len: int, dtype: torch.dtype, device: torch.device + ) -> tuple[torch.Tensor, torch.Tensor]: if seq_len > self.max_seq_len_cached: self._set_cos_sin_cache(seq_len=seq_len, device=device, dtype=dtype) @@ -32,10 +36,14 @@ def forward(self, seq_len: int, dtype: torch.dtype, device: torch.device) -> tup return cos, sin def reset_parameters(self) -> None: - self._set_cos_sin_cache(seq_len=self.max_position_embeddings, device=None, dtype=torch.float32) + self._set_cos_sin_cache( + seq_len=self.max_position_embeddings, device=None, dtype=torch.float32 + ) @torch.no_grad() - def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> None: + def _set_cos_sin_cache( + self, seq_len: int, device: torch.device, dtype: torch.dtype + ) -> None: self.max_seq_len_cached = seq_len inv_freq = self._get_inv_freq(device) @@ -46,12 +54,20 @@ def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dt # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", (emb.cos() * self.mscale).to(dtype), persistent=False) - self.register_buffer("sin_cached", (emb.sin() * self.mscale).to(dtype), persistent=False) + self.register_buffer( + "cos_cached", (emb.cos() * self.mscale).to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", (emb.sin() * self.mscale).to(dtype), persistent=False + ) def _get_inv_freq(self, device: torch.device) -> torch.Tensor: return 1.0 / ( - self.base ** (torch.arange(0, self.head_dim, 2, dtype=torch.float32, device=device) / self.head_dim) + self.base + ** ( + torch.arange(0, self.head_dim, 2, dtype=torch.float32, device=device) + / self.head_dim + ) ) @@ -86,17 +102,27 @@ def __init__( self.reset_parameters() def _get_inv_freq(self, device: torch.device) -> torch.Tensor: - pos_freqs = self.base ** (torch.arange(0, self.head_dim, 2).float() / self.head_dim) + pos_freqs = self.base ** ( + torch.arange(0, self.head_dim, 2).float() / self.head_dim + ) inv_freq_extrapolation = 1.0 / pos_freqs inv_freq_interpolation = 1.0 / (self.scale * 
pos_freqs) low, high = _yarn_find_correction_range( - self.beta_fast, self.beta_slow, self.head_dim, self.base, self.original_max_position_embeddings + self.beta_fast, + self.beta_slow, + self.head_dim, + self.base, + self.original_max_position_embeddings, ) inv_freq_mask = ( - 1 - _yarn_linear_ramp_mask(low, high, self.head_dim // 2).float() - ) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation - inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + (1 - _yarn_linear_ramp_mask(low, high, self.head_dim // 2).float()) + * self.extrapolation_factor + ) # Get n-d rotational scaling corrected for extrapolation + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_mask) + + inv_freq_extrapolation * inv_freq_mask + ) return inv_freq @@ -118,15 +144,25 @@ def _rotate_half(x: torch.Tensor) -> torch.Tensor: def _yarn_find_correction_dim( num_rotations: int, dim: int, base: int = 10000, max_position_embeddings: int = 2048 ) -> float: - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) # Find dim range bounds based on rotations def _yarn_find_correction_range( - low_rot: int, high_rot: int, dim: int, base: int = 10000, max_position_embeddings: int = 2048 + low_rot: int, + high_rot: int, + dim: int, + base: int = 10000, + max_position_embeddings: int = 2048, ) -> int: - low = math.floor(_yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) - high = math.ceil(_yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)) + low = math.floor( + _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + _yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) return max(low, 0), min(high, dim - 1) # Clamp values just in case diff --git a/src/instructlab/dolomite/hf_models/models/__init__.py b/src/instructlab/dolomite/hf_models/models/__init__.py index 871910e..8eb2025 100644 --- a/src/instructlab/dolomite/hf_models/models/__init__.py +++ b/src/instructlab/dolomite/hf_models/models/__init__.py @@ -2,4 +2,4 @@ # Extracted from https://github.com/ibm-granite/dolomite-engine # ---------------------------------------------------------------- # Local -from .gpt_dolomite import GPTDolomiteForCausalLM, GPTDolomiteModel, GPTDolomiteConfig +from .gpt_dolomite import GPTDolomiteConfig, GPTDolomiteForCausalLM, GPTDolomiteModel diff --git a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/__init__.py b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/__init__.py index 347102e..07f56fb 100644 --- a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/__init__.py +++ b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/__init__.py @@ -1,3 +1,4 @@ +# Local from .base import GPTDolomiteModel, GPTDolomitePreTrainedModel from .config import GPTDolomiteConfig from .main import GPTDolomiteForCausalLM diff --git a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/base.py b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/base.py index c9bee9d..5c9a4b6 100644 --- a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/base.py +++ b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/base.py @@ -1,3 +1,4 @@ +# Local from ...mixins import BaseModelMixin, PreTrainedModelMixin from .config import GPTDolomiteConfig from .layer import GPTDolomiteBlock diff --git 
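A short sketch of the standard RoPE cache that _get_inv_freq and _set_cos_sin_cache above compute: per-dimension inverse frequencies, an outer product with the position index, then duplicated halves for cos/sin. The outer-product step and mscale = 1.0 are assumptions for illustration; YaRN adjusts inv_freq and mscale as in the code above.

import torch

head_dim, base, seq_len = 8, 10000.0, 6
inv_freq = 1.0 / (
    base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim)
)
positions = torch.arange(seq_len, dtype=torch.float32)
freqs = torch.outer(positions, inv_freq)     # (seq_len, head_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)      # (seq_len, head_dim)
cos_cached, sin_cached = emb.cos(), emb.sin()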
a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/config.py b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/config.py index 8b83592..8015ae2 100644 --- a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/config.py +++ b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/config.py @@ -1,3 +1,4 @@ +# Local from ...config import CommonConfig diff --git a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/layer.py b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/layer.py index 5fc15a5..7f9c954 100644 --- a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/layer.py +++ b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/layer.py @@ -1,7 +1,9 @@ +# Third Party +from transformers import DynamicCache import torch import torch.nn as nn -from transformers import DynamicCache +# Local from ...enums import AttentionHeadType from ...modeling_utils import get_attention_module, get_normalization_function from .config import GPTDolomiteConfig @@ -36,7 +38,11 @@ def __init__( normalization_implementation=normalization_implementation, ) self.attn = get_attention_module( - config, True, attention_implementation, use_padding_free_transformer, layer_idx + config, + True, + attention_implementation, + use_padding_free_transformer, + layer_idx, ) self.ln_2 = get_normalization_function( config.normalization_function, diff --git a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/main.py b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/main.py index cba1599..1e7527f 100644 --- a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/main.py +++ b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/main.py @@ -1,3 +1,4 @@ +# Local from ...mixins import CausalLMModelMixin from .base import GPTDolomiteModel, GPTDolomitePreTrainedModel diff --git a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/mlp.py b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/mlp.py index b94e41a..e08cbcc 100644 --- a/src/instructlab/dolomite/hf_models/models/gpt_dolomite/mlp.py +++ b/src/instructlab/dolomite/hf_models/models/gpt_dolomite/mlp.py @@ -1,8 +1,11 @@ +# Standard import math +# Third Party import torch import torch.nn as nn +# Local from ...enums import InitMethod from ...modeling_utils import ParameterizedLinear, get_activation_function, is_glu from .config import GPTDolomiteConfig @@ -38,9 +41,13 @@ def __init__(self, config: GPTDolomiteConfig) -> None: std = initializer_range / math.sqrt(2 * n_layer) if init_method == InitMethod.mup: std /= math.sqrt(m_width) - self.c_proj = ParameterizedLinear(intermediate_size, hidden_size, bias=add_bias, std=std) + self.c_proj = ParameterizedLinear( + intermediate_size, hidden_size, bias=add_bias, std=std + ) - self.dropout = nn.Identity() if residual_dropout == 0 else nn.Dropout(residual_dropout) + self.dropout = ( + nn.Identity() if residual_dropout == 0 else nn.Dropout(residual_dropout) + ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = self.c_fc(hidden_states) @@ -50,9 +57,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return hidden_states -def interleave_up_gate_tensor_for_mlp(up_weight: torch.Tensor, gate_weight: torch.Tensor) -> torch.Tensor: +def interleave_up_gate_tensor_for_mlp( + up_weight: torch.Tensor, gate_weight: torch.Tensor +) -> torch.Tensor: return torch.cat([up_weight, gate_weight]) -def split_up_gate_tensor_for_mlp(c_fc_weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: +def split_up_gate_tensor_for_mlp( + c_fc_weight: torch.Tensor, 
+) -> tuple[torch.Tensor, torch.Tensor]: return c_fc_weight.chunk(2) diff --git a/src/instructlab/dolomite/hf_models/register_hf.py b/src/instructlab/dolomite/hf_models/register_hf.py index e92e456..426cd3b 100644 --- a/src/instructlab/dolomite/hf_models/register_hf.py +++ b/src/instructlab/dolomite/hf_models/register_hf.py @@ -1,11 +1,13 @@ -from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM - -from .models import ( - GPTDolomiteConfig, - GPTDolomiteForCausalLM, - GPTDolomiteModel, +# Third Party +from transformers import ( + AutoConfig, + AutoModel, + AutoModelForCausalLM, + AutoModelForSeq2SeqLM, ) +# Local +from .models import GPTDolomiteConfig, GPTDolomiteForCausalLM, GPTDolomiteModel # (AutoConfig, AutoModel, AutoModelForCausalLM) _CUSTOM_MODEL_REGISTRY = [ @@ -16,7 +18,11 @@ def register_model_classes() -> None: - for config_class, auto_model_class, auto_model_for_causal_lm_class in _CUSTOM_MODEL_REGISTRY: + for ( + config_class, + auto_model_class, + auto_model_for_causal_lm_class, + ) in _CUSTOM_MODEL_REGISTRY: model_type = config_class.model_type AutoConfig.register(model_type, config_class) @@ -27,5 +33,11 @@ def register_model_classes() -> None: _CUSTOM_MODEL_CLASSES.append(auto_model_for_causal_lm_class) -def is_custom_model(model_class: type[AutoModelForCausalLM] | type[AutoModelForSeq2SeqLM], model_type: str) -> bool: - return model_class.__name__ in _CUSTOM_MODEL_CLASSES or model_type in _CUSTOM_MODEL_TYPES +def is_custom_model( + model_class: type[AutoModelForCausalLM] | type[AutoModelForSeq2SeqLM], + model_type: str, +) -> bool: + return ( + model_class.__name__ in _CUSTOM_MODEL_CLASSES + or model_type in _CUSTOM_MODEL_TYPES + ) diff --git a/src/instructlab/dolomite/hf_models/utils.py b/src/instructlab/dolomite/hf_models/utils.py index d6ae749..e66cea0 100644 --- a/src/instructlab/dolomite/hf_models/utils.py +++ b/src/instructlab/dolomite/hf_models/utils.py @@ -1,3 +1,4 @@ +# Third Party import torch @@ -25,13 +26,18 @@ def convert_padding_free_lists_to_tensors( labels: list[list[int]] | None = None, device: torch.device = None, ) -> tuple[torch.Tensor]: - # check input types are correct error_message = "{variable} should be of type List[List[{dtype}]]" _check_list_type(input_ids, error_message.format(variable="input_ids", dtype="int")) - _check_list_type(inputs_embeds, error_message.format(variable="inputs_embeds", dtype="float")) - _check_list_type(position_ids, error_message.format(variable="position_ids", dtype="int")) - _check_list_type(token_type_ids, error_message.format(variable="token_type_ids", dtype="int")) + _check_list_type( + inputs_embeds, error_message.format(variable="inputs_embeds", dtype="float") + ) + _check_list_type( + position_ids, error_message.format(variable="position_ids", dtype="int") + ) + _check_list_type( + token_type_ids, error_message.format(variable="token_type_ids", dtype="int") + ) _check_list_type(labels, error_message.format(variable="labels", dtype="int")) # prepare inputs for the model @@ -57,7 +63,9 @@ def convert_padding_free_lists_to_tensors( return input_ids, position_ids, token_type_ids, labels, cu_seqlens, max_seqlen -def _check_list_type(list_of_list: list[list[int | float]] | None, error_message: str) -> None: +def _check_list_type( + list_of_list: list[list[int | float]] | None, error_message: str +) -> None: if list_of_list is None: return diff --git a/src/instructlab/dolomite/utils/hf_hub.py b/src/instructlab/dolomite/utils/hf_hub.py index 82d3431..d85ebd9 100644 --- 
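One common way to produce the cu_seqlens/max_seqlen pair that the padding-free utilities above return is a cumulative sum over per-sequence lengths; the sketch below is illustrative, not the package's exact implementation.

import torch
import torch.nn.functional as F

input_ids = [[3, 4, 5], [7, 8], [9, 10, 11, 12]]   # unpadded sequences
lengths = torch.tensor([len(seq) for seq in input_ids], dtype=torch.int32)

flat_ids = torch.tensor([tok for seq in input_ids for tok in seq])
position_ids = torch.cat([torch.arange(len(seq)) for seq in input_ids])
cu_seqlens = F.pad(torch.cumsum(lengths, dim=0, dtype=torch.int32), (1, 0))
max_seqlen = int(lengths.max())
# flat_ids holds 9 tokens, cu_seqlens == tensor([0, 3, 5, 9]), max_seqlen == 4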
a/src/instructlab/dolomite/utils/hf_hub.py +++ b/src/instructlab/dolomite/utils/hf_hub.py @@ -1,11 +1,15 @@ +# Standard import os +# Third Party from transformers import AutoConfig, AutoTokenizer from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, cached_file from transformers.utils.hub import get_checkpoint_shard_files -def download_repo(repo_name_or_path: str) -> tuple[AutoConfig | None, AutoTokenizer | None, str]: +def download_repo( + repo_name_or_path: str, +) -> tuple[AutoConfig | None, AutoTokenizer | None, str]: config = _download_config(repo_name_or_path) tokenizer = _download_tokenizer(repo_name_or_path) model_path = None @@ -20,7 +24,9 @@ def download_repo(repo_name_or_path: str) -> tuple[AutoConfig | None, AutoTokeni except: # try downloading model weights if they are sharded try: - sharded_filename = cached_file(repo_name_or_path, SAFE_WEIGHTS_INDEX_NAME) + sharded_filename = cached_file( + repo_name_or_path, SAFE_WEIGHTS_INDEX_NAME + ) get_checkpoint_shard_files(repo_name_or_path, sharded_filename) model_path = os.path.dirname(sharded_filename) except: diff --git a/src/instructlab/dolomite/utils/safetensors.py b/src/instructlab/dolomite/utils/safetensors.py index a9ffd0b..65a0a75 100644 --- a/src/instructlab/dolomite/utils/safetensors.py +++ b/src/instructlab/dolomite/utils/safetensors.py @@ -1,11 +1,13 @@ +# Standard import json import os -import torch +# Third Party from huggingface_hub import split_torch_state_dict_into_shards from safetensors import safe_open from safetensors.torch import save_file from transformers.modeling_utils import SAFE_WEIGHTS_INDEX_NAME +import torch class SafeTensorsWeightsManager: @@ -33,7 +35,10 @@ def get_slice(self, tensor_name: str): return f.get_slice(tensor_name) def get_tensor( - self, tensor_name: str, dtype: torch.dtype | None = None, device: torch.device | None = None + self, + tensor_name: str, + dtype: torch.dtype | None = None, + device: torch.device | None = None, ) -> torch.Tensor: filename = self.tensor_filenames[tensor_name] f = self.file_handles[filename] diff --git a/tox.ini b/tox.ini index 70c88a1..57ea095 100644 --- a/tox.ini +++ b/tox.ini @@ -2,7 +2,7 @@ [tox] # py3-unit runs unit tests with 'python3' -# py311-unit runs the same tests with 'python3.11' +# py312-unit runs the same tests with 'python3.12' envlist = ruff, lint, mypy, spellcheck minversion = 4.4
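To close, a sketch of the Auto-class registration pattern that register_model_classes in this diff loops over, written against the public transformers registration APIs; the class names are the ones exported from the models package above, and the direct calls here stand in for the registry loop.

from transformers import AutoConfig, AutoModel, AutoModelForCausalLM

from instructlab.dolomite.hf_models.models import (
    GPTDolomiteConfig,
    GPTDolomiteForCausalLM,
    GPTDolomiteModel,
)

AutoConfig.register(GPTDolomiteConfig.model_type, GPTDolomiteConfig)
AutoModel.register(GPTDolomiteConfig, GPTDolomiteModel)
AutoModelForCausalLM.register(GPTDolomiteConfig, GPTDolomiteForCausalLM)

# After registration, the Auto classes can resolve this config's model_type
# from a checkpoint without importing the concrete classes by hand.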