diff --git a/jiant/proj/main/modeling/primary.py b/jiant/proj/main/modeling/primary.py
index 7d714546f..9b85e8406 100644
--- a/jiant/proj/main/modeling/primary.py
+++ b/jiant/proj/main/modeling/primary.py
@@ -263,6 +263,64 @@ def get_mlm_weights_dict(self, weights_dict):
         mlm_weights_dict = {new_k: weights_dict[old_k] for new_k, old_k in mlm_weights_map.items()}
         return mlm_weights_dict
 
 
+@JiantTransformersModelFactory.register(ModelArchitectures.DISTILBERT)
+class JiantDistilBertModel(JiantTransformersModel):
+    def __init__(self, baseObject):
+        super().__init__(baseObject)
+
+    @classmethod
+    def normalize_tokenizations(cls, tokenizer, space_tokenization, target_tokenization):
+        """See tokenization_normalization.py for details"""
+        if tokenizer.init_kwargs.get("do_lower_case", False):
+            space_tokenization = [token.lower() for token in space_tokenization]
+        modified_space_tokenization = bow_tag_tokens(space_tokenization)
+        modified_target_tokenization = process_wordpiece_tokens(target_tokenization)
+
+        return modified_space_tokenization, modified_target_tokenization
+
+    def get_feat_spec(self, max_seq_length):
+        return FeaturizationSpec(
+            max_seq_length=max_seq_length,
+            cls_token_at_end=False,
+            pad_on_left=False,
+            cls_token_segment_id=0,
+            pad_token_segment_id=0,
+            pad_token_id=0,
+            pad_token_mask_id=0,
+            sequence_a_segment_id=0,
+            sequence_b_segment_id=1,
+            sep_token_extra=False,
+        )
+
+    def get_mlm_weights_dict(self, weights_dict):
+        # DistilBertForMaskedLM names its MLM head modules vocab_transform,
+        # vocab_layer_norm, and vocab_projector (not cls.predictions.* as in BERT).
+        mlm_weights_map = {
+            "bias": "vocab_projector.bias",
+            "dense.weight": "vocab_transform.weight",
+            "dense.bias": "vocab_transform.bias",
+            "LayerNorm.weight": "vocab_layer_norm.weight",
+            "LayerNorm.bias": "vocab_layer_norm.bias",
+            "decoder.weight": "vocab_projector.weight",
+            "decoder.bias": "vocab_projector.bias",  # decoder bias is the projector bias
+        }
+        mlm_weights_dict = {new_k: weights_dict[old_k] for new_k, old_k in mlm_weights_map.items()}
+        return mlm_weights_dict
+
+    def get_hidden_dropout_prob(self):
+        # DistilBertConfig exposes `dropout` rather than `hidden_dropout_prob`
+        return self.config.dropout
+
+    def encode(self, input_ids, segment_ids, input_mask, output_hidden_states=True):
+        output = self.forward(
+            input_ids=input_ids,
+            # segment_ids is unused: DistilBERT has no token type embeddings
+            attention_mask=input_mask,
+            output_hidden_states=output_hidden_states,
+        )
+        return JiantModelOutput(
+            pooled=output.last_hidden_state[:, 0, :],
+            unpooled=output.last_hidden_state,
+            other=output.hidden_states,
+        )
+
 
 @JiantTransformersModelFactory.register(ModelArchitectures.ROBERTA)
 class JiantRobertaModel(JiantTransformersModel):
diff --git a/jiant/shared/model_resolution.py b/jiant/shared/model_resolution.py
index ef9e2ed0e..3b63d0650 100644
--- a/jiant/shared/model_resolution.py
+++ b/jiant/shared/model_resolution.py
@@ -15,6 +15,7 @@ class ModelArchitectures(Enum):
     MBART = "mbart"
     ELECTRA = "electra"
     DEBERTAV2 = "deberta-v2"
+    DISTILBERT = "distilbert"
 
     @classmethod
     def from_model_type(cls, model_type: str):
@@ -38,6 +39,7 @@ def get_encoder_prefix(self):
         ModelArchitectures.MBART: transformers.MBartTokenizer,
         ModelArchitectures.ELECTRA: transformers.ElectraTokenizer,
         ModelArchitectures.DEBERTAV2: transformers.DebertaV2Tokenizer,
+        ModelArchitectures.DISTILBERT: transformers.DistilBertTokenizer,
     }
 )
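
Note on the MLM weight remapping: a minimal sanity-check sketch, assuming the Hugging Face DistilBertForMaskedLM head names (vocab_transform, vocab_layer_norm, vocab_projector) and the "distilbert-base-uncased" checkpoint. It only verifies that every source key in the map exists in a real checkpoint; it does not run a task.

import transformers

# Sketch: confirm the map's source keys exist in an actual MLM checkpoint.
mlm_model = transformers.DistilBertForMaskedLM.from_pretrained("distilbert-base-uncased")
weights_dict = mlm_model.state_dict()

mlm_weights_map = {
    "bias": "vocab_projector.bias",
    "dense.weight": "vocab_transform.weight",
    "dense.bias": "vocab_transform.bias",
    "LayerNorm.weight": "vocab_layer_norm.weight",
    "LayerNorm.bias": "vocab_layer_norm.bias",
    "decoder.weight": "vocab_projector.weight",
    "decoder.bias": "vocab_projector.bias",
}
missing = [old_k for old_k in mlm_weights_map.values() if old_k not in weights_dict]
assert not missing, f"keys absent from checkpoint: {missing}"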
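
Note on encode(): a sketch of the same call pattern against a bare transformers.DistilBertModel, illustrating why the patch drops segment_ids: DistilBertModel.forward has no token_type_ids argument. The checkpoint name is illustrative.

import torch
import transformers

tokenizer = transformers.DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = transformers.DistilBertModel.from_pretrained("distilbert-base-uncased")

batch = tokenizer("jiant works with DistilBERT", return_tensors="pt")
with torch.no_grad():
    output = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],  # no token_type_ids parameter exists
        output_hidden_states=True,
    )

pooled = output.last_hidden_state[:, 0, :]   # [CLS]-position vector, as in encode()
unpooled = output.last_hidden_state          # full sequence of hidden states
assert pooled.shape[-1] == model.config.dim  # 768 for the base checkpoint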
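
Note on model_resolution.py: with the new enum value, the string "distilbert" resolves to the architecture. The by-value lookup below is plain Python Enum behavior; from_model_type is assumed to do the equivalent.

from jiant.shared.model_resolution import ModelArchitectures

arch = ModelArchitectures("distilbert")  # standard Enum lookup by value
assert arch is ModelArchitectures.DISTILBERT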