
Commit efaa650

Rename wrap_model() -> init() (speechbrain#29)
1 parent 261c458 commit efaa650

30 files changed, +98 -80 lines
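
The rename also changes the call convention: `wrap_model()` returned the adapted model, while `init()` converts an existing model in place. A minimal before/after sketch (the checkpoint name is illustrative):

```python
import adapters
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")

# Before this commit: the wrapper returned the adapted model.
#   model = wrap_model(model)
# After this commit: init() adds adapter support to the model in place.
adapters.init(model)
```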

README.md

Lines changed: 2 additions & 2 deletions

````diff
@@ -27,11 +27,11 @@ This branch disentangles `adapter-transformers` from HF Transformers and adds Tr
 ```
 - Built-in HF model classes can be adapted for usage with adapters via a wrapper method, e.g.:
 ```python
+import adapters
 from transformers import BertModel
-from adapters import wrap_model
 
 model = BertModel.from_pretrained("bert-base-uncased")
-model = wrap_model(model)
+adapters.init(model)
 ```
 
 ### Model support
````
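
After `init()`, the wrapped model exposes the library's usual adapter methods. A short continuation of the README snippet (a sketch: `add_adapter`/`set_active_adapters` are standard adapters-library calls not shown in this diff, and the adapter name is made up):

```python
# Continues the README example above.
model.add_adapter("demo_adapter")          # register a new adapter module
model.set_active_adapters("demo_adapter")  # route forward passes through it
```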

examples/pytorch/translation/run_translation.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -28,10 +28,10 @@
 import numpy as np
 from datasets import load_dataset
 
+import adapters
 import evaluate
 import transformers
 from adapters import AdapterArguments, Seq2SeqAdapterTrainer, setup_adapter_training
-from adapters.wrappers.model import wrap_model
 from transformers import (
     AutoConfig,
     AutoModelForSeq2SeqLM,
@@ -389,7 +389,7 @@ def main():
         revision=model_args.model_revision,
         use_auth_token=True if model_args.use_auth_token else None,
     )
-    model = wrap_model(model)
+    adapters.init(model)
 
     # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
     # on a small vocab and want a smaller embedding size, remove this test.
```
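
The script change mirrors the README: the model is initialized in place right after loading, before the adapter setup helpers run. A condensed sketch of the surrounding flow (the checkpoint and adapter name are illustrative, and the sketch assumes `setup_adapter_training(model, adapter_args, adapter_name)` as the call shape):

```python
import adapters
from adapters import AdapterArguments, setup_adapter_training
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")  # illustrative checkpoint
adapters.init(model)  # replaces the old `model = wrap_model(model)`

# Configure adapter training the way the script does further down.
adapter_args = AdapterArguments(train_adapter=True)
setup_adapter_training(model, adapter_args, "translation")  # adapter name is made up
```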

src/adapters/__init__.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -146,9 +146,9 @@
         "list_adapters",
     ],
     "wrappers": [
+        "init",
         "load_model",
         "wrap_config",
-        "wrap_model",
     ],
 }
 
@@ -240,7 +240,7 @@
         get_adapter_info,
         list_adapters,
     )
-    from .wrappers import load_model, wrap_config, wrap_model
+    from .wrappers import init, load_model, wrap_config
 
 else:
     import sys
```
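
Net effect on the package surface: `init` joins `load_model` and `wrap_config` in the public `wrappers` exports, and `wrap_model` is gone. A quick sanity check (a sketch; assumes the package is importable in a Python session):

```python
import adapters

# New top-level entry point, alongside the remaining wrapper utilities.
assert hasattr(adapters, "init")
assert hasattr(adapters, "load_model")
assert hasattr(adapters, "wrap_config")

# The old name no longer exists.
assert not hasattr(adapters, "wrap_model")
```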

src/adapters/models/albert/adapter_model.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -17,7 +17,7 @@
     TaggingHead,
 )
 from ...model_mixin import EmbeddingAdaptersWrapperMixin
-from ...wrappers import wrap_model
+from ...wrappers import init
 
 
 @add_start_docstrings(
@@ -28,7 +28,8 @@ class AlbertAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAd
     def __init__(self, config):
         super().__init__(config)
 
-        self.albert = wrap_model(AlbertModel(config))
+        self.albert = AlbertModel(config)
+        init(self.albert)
 
         self._init_head_modules()
 
```
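
This and each of the following `*AdapterModel` classes apply the same mechanical rewrite: because `init()` mutates the module instead of returning a wrapped copy, the old one-liner splits into construct-then-init. One consequence worth noting (a sketch, not part of the diff): in-place initialization keeps existing references to the model valid.

```python
import adapters
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
ref = model           # a second reference, e.g. held by a trainer
adapters.init(model)  # in-place: no reassignment needed
assert ref is model   # both names point to the same, now adapter-enabled module
```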

src/adapters/models/bart/adapter_model.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -21,7 +21,7 @@
     Seq2SeqLMHead,
 )
 from ...model_mixin import EmbeddingAdaptersWrapperMixin
-from ...wrappers import wrap_model
+from ...wrappers import init
 
 
 @add_start_docstrings(
@@ -32,7 +32,8 @@ class BartAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdap
 
     def __init__(self, config: BartConfig, **kwargs):
         super().__init__(config, **kwargs)
-        self.model = wrap_model(BartModel(config))
+        self.model = BartModel(config)
+        init(self.model)
 
         self._init_head_modules()
 
```

src/adapters/models/beit/adapter_model.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -12,7 +12,7 @@
 
 from ...context import AdapterSetup
 from ...heads import ImageClassificationHead, ModelWithFlexibleHeadsAdaptersMixin
-from ...wrappers import wrap_model
+from ...wrappers import init
 
 
 @add_start_docstrings(
@@ -23,7 +23,8 @@ class BeitAdapterModel(ModelWithFlexibleHeadsAdaptersMixin, BeitPreTrainedModel)
     def __init__(self, config):
         super().__init__(config)
 
-        self.beit = wrap_model(BeitModel(config))
+        self.beit = BeitModel(config)
+        init(self.beit)
 
         self._init_head_modules()
 
```

src/adapters/models/bert/adapter_model.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -21,7 +21,7 @@
     TaggingHead,
 )
 from ...model_mixin import EmbeddingAdaptersWrapperMixin
-from ...wrappers import wrap_model
+from ...wrappers import init
 
 
 @add_start_docstrings(
@@ -32,7 +32,8 @@ class BertAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdap
     def __init__(self, config):
         super().__init__(config)
 
-        self.bert = wrap_model(BertModel(config))
+        self.bert = BertModel(config)
+        init(self.bert)
 
         self._init_head_modules()
 
```

src/adapters/models/bert_generation/adapter_model.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -9,7 +9,7 @@
 from ...context import AdapterSetup
 from ...heads import BertStyleMaskedLMHead, CausalLMHead, ModelWithFlexibleHeadsAdaptersMixin
 from ...model_mixin import EmbeddingAdaptersWrapperMixin
-from ...wrappers import wrap_model
+from ...wrappers import init
 
 
 @add_start_docstrings(
@@ -24,7 +24,8 @@ class BertGenerationAdapterModel(
     def __init__(self, config):
         super().__init__(config)
 
-        self.bert = wrap_model(BertGenerationEncoder(config))
+        self.bert = BertGenerationEncoder(config)
+        init(self.bert)
 
         self._init_head_modules()
 
```

src/adapters/models/clip/adapter_model.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -13,15 +13,16 @@
 from ...context import AdapterSetup
 from ...heads import ModelWithFlexibleHeadsAdaptersMixin
 from ...model_mixin import EmbeddingAdaptersWrapperMixin
-from ...wrappers import wrap_model
+from ...wrappers import init
 
 
 @add_start_docstrings(CLIP_START_DOCSTRING)
 class CLIPAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, CLIPPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
 
-        self.clip = wrap_model(CLIPModel(config))
+        self.clip = CLIPModel(config)
+        init(self.clip)
 
         self._init_head_modules()
 
```

src/adapters/models/deberta/adapter_model.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -11,7 +11,7 @@
     TaggingHead,
 )
 from ...model_mixin import EmbeddingAdaptersWrapperMixin
-from ...wrappers import wrap_model
+from ...wrappers import init
 
 
 @add_start_docstrings(
@@ -23,7 +23,8 @@ class DebertaAdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsA
     def __init__(self, config):
         super().__init__(config)
 
-        self.deberta = wrap_model(DebertaModel(config))
+        self.deberta = DebertaModel(config)
+        init(self.deberta)
 
         self._init_head_modules()
 
```
