Commit d6466aa

refactor: remove classifier registry and implement wrapper-based architecture
1 parent 8123b70 commit d6466aa

26 files changed: +1085 -4953 lines changed

README.md

Lines changed: 49 additions & 49 deletions
@@ -1,22 +1,24 @@
 # torchTextClassifiers
 
-A unified, extensible framework for text classification using PyTorch and PyTorch Lightning.
+A unified, extensible framework for text classification built on [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/).
+
+
 
 ## 🚀 Features
 
-- **Unified API**: Consistent interface for different classifier types
-- **FastText Support**: Built-in FastText classifier implementation
+- **Unified API**: Consistent interface for different classifier wrappers
+- **Extensible**: Easy to add new classifier implementations through wrapper pattern
+- **FastText Support**: Built-in FastText classifier with n-gram tokenization
+- **Flexible Preprocessing**: Each classifier can implement its own text preprocessing approach
 - **PyTorch Lightning**: Automated training with callbacks, early stopping, and logging
-- **Mixed Features**: Support for both text and categorical features
-- **Extensible**: Easy to add new classifier types
-- **Production Ready**: Model serialization, validation, and inference
+
 
 ## 📦 Installation
 
 ```bash
 # Clone the repository
-git clone https://github.com/your-repo/torch-fastText.git
-cd torch-fastText
+git clone https://github.com/InseeFrLab/torchTextClassifiers.git
+cd torchtextClassifiers
 
 # Install with uv (recommended)
 uv sync
@@ -82,47 +84,44 @@ accuracy = classifier.validate(X_test, np.array([1]))
 print(f"Accuracy: {accuracy:.3f}")
 ```
 
-### Working with Mixed Features (Text + Categorical)
+### Custom Classifier Implementation
 
 ```python
 import numpy as np
-from torchTextClassifiers import create_fasttext
+from torchTextClassifiers import torchTextClassifiers
+from torchTextClassifiers.classifiers.simple_text_classifier import SimpleTextWrapper, SimpleTextConfig
 
-# Text data with categorical features
-X_train = np.column_stack([
-    np.array(["Great product!", "Terrible service", "Love it!"]), # Text
-    np.array([[1, 2], [2, 1], [1, 3]]) # Categorical features
-])
-y_train = np.array([1, 0, 1])
-
-# Create classifier with categorical support
-classifier = create_fasttext(
-    embedding_dim=50,
-    sparse=False,
-    num_tokens=5000,
-    min_count=1,
-    min_n=3,
-    max_n=6,
-    len_word_ngrams=2,
+# Example: TF-IDF based classifier (alternative to tokenization)
+config = SimpleTextConfig(
+    hidden_dim=128,
     num_classes=2,
-    categorical_vocabulary_sizes=[3, 4], # Vocab sizes for categorical features
-    categorical_embedding_dims=[10, 10] # Embedding dims for categorical features
+    max_features=5000,
+    learning_rate=1e-3,
+    dropout_rate=0.2
 )
 
-# Build and train as usual
+# Create classifier with TF-IDF preprocessing
+wrapper = SimpleTextWrapper(config)
+classifier = torchTextClassifiers(wrapper)
+
+# Text data
+X_train = np.array(["Great product!", "Terrible service", "Love it!"])
+y_train = np.array([1, 0, 1])
+
+# Build and train
 classifier.build(X_train, y_train)
 # ... continue with training
 ```
 
 
-
 ## 🔧 Advanced Usage
 
 ### Custom Configuration
 
 ```python
-from torchTextClassifiers import torchTextClassifiers, ClassifierType
+from torchTextClassifiers import torchTextClassifiers
 from torchTextClassifiers.classifiers.fasttext.config import FastTextConfig
+from torchTextClassifiers.classifiers.fasttext.wrapper import FastTextWrapper
 
 # Create custom configuration
 config = FastTextConfig(
@@ -138,7 +137,8 @@ config = FastTextConfig(
 )
 
 # Create classifier with custom config
-classifier = torchTextClassifiers(ClassifierType.FASTTEXT, config)
+wrapper = FastTextWrapper(config)
+classifier = torchTextClassifiers(wrapper)
 ```
 
 ### Using Pre-trained Tokenizers
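
Note on the hunk above: the constructor call changes from `torchTextClassifiers(ClassifierType.FASTTEXT, config)` to `torchTextClassifiers(wrapper)`. The sketch below illustrates the general difference between the two styles. Every name in it (`SketchClassifierType`, `SKETCH_REGISTRY`, `RegistryClassifier`, `InjectedClassifier`, `EchoWrapper`, `ThirdPartyWrapper`) is hypothetical and chosen for illustration; the diff does not show the removed registry code, so this is an assumption about its shape, not a copy of it.

```python
from enum import Enum


class EchoWrapper:
    """Hypothetical wrapper used only for this illustration."""

    def __init__(self, config):
        self.config = config


# Registry style (assumed shape of the "before"): a central enum and lookup
# table must list every classifier that can be constructed.
class SketchClassifierType(Enum):
    ECHO = "echo"


SKETCH_REGISTRY = {SketchClassifierType.ECHO: EchoWrapper}


class RegistryClassifier:
    def __init__(self, classifier_type, config):
        # The facade resolves the wrapper class itself.
        wrapper_cls = SKETCH_REGISTRY[classifier_type]
        self.wrapper = wrapper_cls(config)


# Wrapper style (the "after"): the caller builds the wrapper and passes it in.
class InjectedClassifier:
    def __init__(self, wrapper):
        self.wrapper = wrapper


class ThirdPartyWrapper:
    """A wrapper defined elsewhere; it needs no registry entry to be used."""

    def __init__(self, config):
        self.config = config


old_style = RegistryClassifier(SketchClassifierType.ECHO, {"num_classes": 2})
new_style = InjectedClassifier(ThirdPartyWrapper({"num_classes": 2}))
```

With injection, adding a new classifier does not require editing a shared enum or lookup table, which is consistent with the commit message's "remove classifier registry".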
@@ -189,19 +189,18 @@ classifier.train(
 The main classifier class providing a unified interface.
 
 **Key Methods:**
-- `build(X_train, y_train)`: Build tokenizer and model
+- `build(X_train, y_train)`: Build text preprocessing and model
 - `train(X_train, y_train, X_val, y_val, ...)`: Train the model
 - `predict(X)`: Make predictions
 - `validate(X, Y)`: Evaluate on test data
 - `to_json(filepath)`: Save configuration
 - `from_json(filepath)`: Load configuration
 
-#### `ClassifierType`
-Enumeration of supported classifier types.
-- `FASTTEXT`: FastText classifier
+#### `BaseClassifierWrapper`
+Base class for all classifier wrappers. Each classifier implementation extends this class.
 
-#### `ClassifierFactory`
-Factory for creating classifier instances.
+#### `FastTextWrapper`
+Wrapper for FastText classifier implementation with tokenization-based preprocessing.
 
 ### FastText Specific
 
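
Note on the hunk above: the new API reference says each classifier implementation extends `BaseClassifierWrapper`, but the diff does not show that class. As a minimal sketch of the pattern, assuming an abstract base with preprocessing, build, and predict hooks, it might look like the following. `SketchBaseWrapper`, `MajorityClassWrapper`, `SketchClassifier`, and the method `prepare_text_features` are hypothetical names, not the library's actual interface.

```python
from abc import ABC, abstractmethod
from collections import Counter


class SketchBaseWrapper(ABC):
    """Hypothetical stand-in for a wrapper base class (not the library's code)."""

    @abstractmethod
    def prepare_text_features(self, texts):
        """Turn raw strings into whatever the model consumes."""

    @abstractmethod
    def build(self, texts, labels):
        """Fit preprocessing and set up the model."""

    @abstractmethod
    def predict(self, texts):
        """Return one predicted label per input text."""


class MajorityClassWrapper(SketchBaseWrapper):
    """Toy wrapper: always predicts the most common training label."""

    def __init__(self):
        self.majority_label = None

    def prepare_text_features(self, texts):
        # Each wrapper owns its preprocessing; here, lowercase word splitting.
        return [text.lower().split() for text in texts]

    def build(self, texts, labels):
        self.majority_label = Counter(labels).most_common(1)[0][0]

    def predict(self, texts):
        if self.majority_label is None:
            raise RuntimeError("call build() before predict()")
        return [self.majority_label for _ in self.prepare_text_features(texts)]


class SketchClassifier:
    """Facade that delegates every call to the wrapper it was given."""

    def __init__(self, wrapper):
        self.wrapper = wrapper

    def build(self, X, y):
        self.wrapper.build(X, y)

    def predict(self, X):
        return self.wrapper.predict(X)


clf = SketchClassifier(MajorityClassWrapper())
clf.build(["Great product!", "Terrible service", "Love it!"], [1, 0, 1])
predictions = clf.predict(["Nice", "Awful"])
```

The facade stays unaware of how text is preprocessed, which matches the README's change from "Build tokenizer and model" to "Build text preprocessing and model".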
@@ -222,24 +221,25 @@ Create FastText classifier from existing tokenizer.
 
 ## 🏗️ Architecture
 
-The framework follows a modular architecture:
+The framework follows a wrapper-based architecture:
 
 ```
 torchTextClassifiers/
 ├── torchTextClassifiers.py         # Main classifier interface
 ├── classifiers/
-│   ├── base.py                     # Abstract base classes
-│   └── fasttext/                   # FastText implementation
-│       ├── config.py               # Configuration
-│       ├── wrapper.py              # Classifier wrapper
-│       ├── factory.py              # Convenience methods
-│       ├── tokenizer.py            # N-gram tokenizer
-│       ├── pytorch_model.py        # PyTorch model
-│       ├── lightning_module.py     # Lightning module
-│       └── dataset.py              # Dataset implementation
+│   ├── base.py                     # Abstract base wrapper classes
+│   ├── fasttext/                   # FastText implementation
+│   │   ├── config.py               # Configuration
+│   │   ├── wrapper.py              # FastText wrapper (tokenization)
+│   │   ├── factory.py              # Convenience methods
+│   │   ├── tokenizer.py            # N-gram tokenizer
+│   │   ├── pytorch_model.py        # PyTorch model
+│   │   ├── lightning_module.py     # Lightning module
+│   │   └── dataset.py              # Dataset implementation
+│   └── simple_text_classifier.py   # Example TF-IDF wrapper
 ├── utilities/
 │   └── checkers.py                 # Input validation utilities
-└── factories.py                    # Generic factory system
+└── factories.py                    # Convenience factory functions
 ```
 
 ## 🔬 Testing
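
Note on the README.md diff above: the tree lists `simple_text_classifier.py` as an "Example TF-IDF wrapper", and the usage hunk calls TF-IDF an alternative to tokenization. The file itself is not part of this excerpt, so the following is only a rough pure-Python sketch of what TF-IDF preprocessing computes. The helper names `fit_idf` and `transform` are hypothetical, the smoothed IDF formula is one common variant chosen here as an assumption, and the real wrapper may well rely on an existing library instead.

```python
import math
from collections import Counter


def fit_idf(documents):
    """Compute a smoothed inverse document frequency for every term seen."""
    tokenized = [doc.lower().split() for doc in documents]
    n_docs = len(tokenized)
    # Count in how many documents each term appears (not how many times).
    doc_freq = Counter(term for tokens in tokenized for term in set(tokens))
    # Smoothed variant: log((1 + n) / (1 + df)) + 1 keeps every weight positive.
    return {
        term: math.log((1 + n_docs) / (1 + df)) + 1.0
        for term, df in doc_freq.items()
    }


def transform(document, idf):
    """Map one document to a {term: tf * idf} dict, skipping unseen terms."""
    counts = Counter(document.lower().split())
    return {
        term: count * idf[term]
        for term, count in counts.items()
        if term in idf
    }


docs = ["Great product!", "Terrible service", "Love it!"]
idf = fit_idf(docs)
vector = transform("great great service", idf)
```

Unlike the n-gram tokenizer path, this representation is a weighted bag of terms with no ordering, which is why a wrapper using it needs its own preprocessing step rather than a shared tokenizer.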

docs/README.md

Lines changed: 0 additions & 86 deletions
This file was deleted.

docs/source/_static/custom.css

Lines changed: 0 additions & 101 deletions
This file was deleted.
