1
1
# torchTextClassifiers
2
2
3
- A unified, extensible framework for text classification using PyTorch and PyTorch Lightning.
3
+ A unified, extensible framework for text classification built on [ PyTorch] ( https://pytorch.org/ ) and [ PyTorch Lightning] ( https://lightning.ai/docs/pytorch/stable/ ) .
4
+
5
+
4
6
5
7
## 🚀 Features
6
8
7
- - ** Unified API** : Consistent interface for different classifier types
8
- - ** FastText Support** : Built-in FastText classifier implementation
9
+ - ** Unified API** : Consistent interface for different classifier wrappers
10
+ - ** Extensible** : Easy to add new classifier implementations through wrapper pattern
11
+ - ** FastText Support** : Built-in FastText classifier with n-gram tokenization
12
+ - ** Flexible Preprocessing** : Each classifier can implement its own text preprocessing approach
9
13
- ** PyTorch Lightning** : Automated training with callbacks, early stopping, and logging
10
- - ** Mixed Features** : Support for both text and categorical features
11
- - ** Extensible** : Easy to add new classifier types
12
- - ** Production Ready** : Model serialization, validation, and inference
14
+
13
15
14
16
## 📦 Installation
15
17
16
18
``` bash
17
19
# Clone the repository
18
- git clone https://github.com/your-repo/torch-fastText .git
19
- cd torch-fastText
20
+ git clone https://github.com/InseeFrLab/torchTextClassifiers .git
21
+ cd torchtextClassifiers
20
22
21
23
# Install with uv (recommended)
22
24
uv sync
@@ -82,47 +84,44 @@ accuracy = classifier.validate(X_test, np.array([1]))
82
84
print (f " Accuracy: { accuracy:.3f } " )
83
85
```
84
86
85
- ### Working with Mixed Features (Text + Categorical)
87
+ ### Custom Classifier Implementation
86
88
87
89
``` python
88
90
import numpy as np
89
- from torchTextClassifiers import create_fasttext
91
+ from torchTextClassifiers import torchTextClassifiers
92
+ from torchTextClassifiers.classifiers.simple_text_classifier import SimpleTextWrapper, SimpleTextConfig
90
93
91
- # Text data with categorical features
92
- X_train = np.column_stack([
93
- np.array([" Great product!" , " Terrible service" , " Love it!" ]), # Text
94
- np.array([[1 , 2 ], [2 , 1 ], [1 , 3 ]]) # Categorical features
95
- ])
96
- y_train = np.array([1 , 0 , 1 ])
97
-
98
- # Create classifier with categorical support
99
- classifier = create_fasttext(
100
- embedding_dim = 50 ,
101
- sparse = False ,
102
- num_tokens = 5000 ,
103
- min_count = 1 ,
104
- min_n = 3 ,
105
- max_n = 6 ,
106
- len_word_ngrams = 2 ,
94
+ # Example: TF-IDF based classifier (alternative to tokenization)
95
+ config = SimpleTextConfig(
96
+ hidden_dim = 128 ,
107
97
num_classes = 2 ,
108
- categorical_vocabulary_sizes = [3 , 4 ], # Vocab sizes for categorical features
109
- categorical_embedding_dims = [10 , 10 ] # Embedding dims for categorical features
98
+ max_features = 5000 ,
99
+ learning_rate = 1e-3 ,
100
+ dropout_rate = 0.2
110
101
)
111
102
112
- # Build and train as usual
103
+ # Create classifier with TF-IDF preprocessing
104
+ wrapper = SimpleTextWrapper(config)
105
+ classifier = torchTextClassifiers(wrapper)
106
+
107
+ # Text data
108
+ X_train = np.array([" Great product!" , " Terrible service" , " Love it!" ])
109
+ y_train = np.array([1 , 0 , 1 ])
110
+
111
+ # Build and train
113
112
classifier.build(X_train, y_train)
114
113
# ... continue with training
115
114
```
116
115
117
116
118
-
119
117
## 🔧 Advanced Usage
120
118
121
119
### Custom Configuration
122
120
123
121
``` python
124
- from torchTextClassifiers import torchTextClassifiers, ClassifierType
122
+ from torchTextClassifiers import torchTextClassifiers
125
123
from torchTextClassifiers.classifiers.fasttext.config import FastTextConfig
124
+ from torchTextClassifiers.classifiers.fasttext.wrapper import FastTextWrapper
126
125
127
126
# Create custom configuration
128
127
config = FastTextConfig(
@@ -138,7 +137,8 @@ config = FastTextConfig(
138
137
)
139
138
140
139
# Create classifier with custom config
141
- classifier = torchTextClassifiers(ClassifierType.FASTTEXT , config)
140
+ wrapper = FastTextWrapper(config)
141
+ classifier = torchTextClassifiers(wrapper)
142
142
```
143
143
144
144
### Using Pre-trained Tokenizers
@@ -189,19 +189,18 @@ classifier.train(
189
189
The main classifier class providing a unified interface.
190
190
191
191
** Key Methods:**
192
- - ` build(X_train, y_train) ` : Build tokenizer and model
192
+ - ` build(X_train, y_train) ` : Build text preprocessing and model
193
193
- ` train(X_train, y_train, X_val, y_val, ...) ` : Train the model
194
194
- ` predict(X) ` : Make predictions
195
195
- ` validate(X, Y) ` : Evaluate on test data
196
196
- ` to_json(filepath) ` : Save configuration
197
197
- ` from_json(filepath) ` : Load configuration
198
198
199
- #### ` ClassifierType `
200
- Enumeration of supported classifier types.
201
- - ` FASTTEXT ` : FastText classifier
199
+ #### ` BaseClassifierWrapper `
200
+ Base class for all classifier wrappers. Each classifier implementation extends this class.
202
201
203
- #### ` ClassifierFactory `
204
- Factory for creating classifier instances .
202
+ #### ` FastTextWrapper `
203
+ Wrapper for FastText classifier implementation with tokenization-based preprocessing .
205
204
206
205
### FastText Specific
207
206
@@ -222,24 +221,25 @@ Create FastText classifier from existing tokenizer.
222
221
223
222
## 🏗️ Architecture
224
223
225
- The framework follows a modular architecture:
224
+ The framework follows a wrapper-based architecture:
226
225
227
226
```
228
227
torchTextClassifiers/
229
228
├── torchTextClassifiers.py # Main classifier interface
230
229
├── classifiers/
231
- │ ├── base.py # Abstract base classes
232
- │ └── fasttext/ # FastText implementation
233
- │ ├── config.py # Configuration
234
- │ ├── wrapper.py # Classifier wrapper
235
- │ ├── factory.py # Convenience methods
236
- │ ├── tokenizer.py # N-gram tokenizer
237
- │ ├── pytorch_model.py # PyTorch model
238
- │ ├── lightning_module.py # Lightning module
239
- │ └── dataset.py # Dataset implementation
230
+ │ ├── base.py # Abstract base wrapper classes
231
+ │ ├── fasttext/ # FastText implementation
232
+ │ │ ├── config.py # Configuration
233
+ │ │ ├── wrapper.py # FastText wrapper (tokenization)
234
+ │ │ ├── factory.py # Convenience methods
235
+ │ │ ├── tokenizer.py # N-gram tokenizer
236
+ │ │ ├── pytorch_model.py # PyTorch model
237
+ │ │ ├── lightning_module.py # Lightning module
238
+ │ │ └── dataset.py # Dataset implementation
239
+ │ └── simple_text_classifier.py # Example TF-IDF wrapper
240
240
├── utilities/
241
241
│ └── checkers.py # Input validation utilities
242
- └── factories.py # Generic factory system
242
+ └── factories.py # Convenience factory functions
243
243
```
244
244
245
245
## 🔬 Testing
0 commit comments