
Commit acb3700

classifier_wrapper -> classifier + readme improvements

1 parent: d6466aa

File tree: 8 files changed, +91, -224 lines changed
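At its core, the commit renames the public attribute through which `torchTextClassifiers` exposes its wrapped implementation: `classifier_wrapper` becomes `classifier`, with every call site in the examples, notebook, and tests updated to match. A minimal before/after sketch of the user-facing change (the `create_fasttext` helper is documented in the repository's README; the variable names and argument values here are illustrative):

```python
from torchTextClassifiers import create_fasttext

# Hypothetical setup; any of the documented create_fasttext kwargs would do.
model = create_fasttext(embedding_dim=100, num_classes=3)

# Before this commit, the wrapped implementation was reached as:
#   model.classifier_wrapper
# After this commit, the same object is reached as:
impl = model.classifier
print(type(impl).__name__)  # e.g. FastTextWrapper
```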

README.md

Lines changed: 3 additions & 136 deletions

````diff
@@ -2,8 +2,6 @@
 
 A unified, extensible framework for text classification built on [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/).
 
-
-
 ## 🚀 Features
 
 - **Unified API**: Consistent interface for different classifier wrappers
@@ -114,52 +112,6 @@ classifier.build(X_train, y_train)
 ```
 
 
-## 🔧 Advanced Usage
-
-### Custom Configuration
-
-```python
-from torchTextClassifiers import torchTextClassifiers
-from torchTextClassifiers.classifiers.fasttext.config import FastTextConfig
-from torchTextClassifiers.classifiers.fasttext.wrapper import FastTextWrapper
-
-# Create custom configuration
-config = FastTextConfig(
-    embedding_dim=200,
-    sparse=True,
-    num_tokens=20000,
-    min_count=3,
-    min_n=2,
-    max_n=8,
-    len_word_ngrams=3,
-    num_classes=5,
-    direct_bagging=False,  # Custom FastText parameter
-)
-
-# Create classifier with custom config
-wrapper = FastTextWrapper(config)
-classifier = torchTextClassifiers(wrapper)
-```
-
-### Using Pre-trained Tokenizers
-
-```python
-from torchTextClassifiers import build_fasttext_from_tokenizer
-
-# Assume you have a pre-trained tokenizer
-# my_tokenizer = ... (previously trained NGramTokenizer)
-
-classifier = build_fasttext_from_tokenizer(
-    tokenizer=my_tokenizer,
-    embedding_dim=100,
-    num_classes=3,
-    sparse=False
-)
-
-# Model and tokenizer are already built, ready for training
-classifier.train(X_train, y_train, X_val, y_val, ...)
-```
-
 ### Training Customization
 
 ```python
@@ -181,67 +133,6 @@ classifier.train(
 )
 ```
 
-## 📊 API Reference
-
-### Main Classes
-
-#### `torchTextClassifiers`
-The main classifier class providing a unified interface.
-
-**Key Methods:**
-- `build(X_train, y_train)`: Build text preprocessing and model
-- `train(X_train, y_train, X_val, y_val, ...)`: Train the model
-- `predict(X)`: Make predictions
-- `validate(X, Y)`: Evaluate on test data
-- `to_json(filepath)`: Save configuration
-- `from_json(filepath)`: Load configuration
-
-#### `BaseClassifierWrapper`
-Base class for all classifier wrappers. Each classifier implementation extends this class.
-
-#### `FastTextWrapper`
-Wrapper for FastText classifier implementation with tokenization-based preprocessing.
-
-### FastText Specific
-
-#### `create_fasttext(**kwargs)`
-Convenience function to create FastText classifiers.
-
-**Parameters:**
-- `embedding_dim`: Embedding dimension
-- `sparse`: Use sparse embeddings
-- `num_tokens`: Vocabulary size
-- `min_count`: Minimum token frequency
-- `min_n`, `max_n`: Character n-gram range
-- `len_word_ngrams`: Word n-gram length
-- `num_classes`: Number of output classes
-
-#### `build_fasttext_from_tokenizer(tokenizer, **kwargs)`
-Create FastText classifier from existing tokenizer.
-
-## 🏗️ Architecture
-
-The framework follows a wrapper-based architecture:
-
-```
-torchTextClassifiers/
-├── torchTextClassifiers.py        # Main classifier interface
-├── classifiers/
-│   ├── base.py                    # Abstract base wrapper classes
-│   ├── fasttext/                  # FastText implementation
-│   │   ├── config.py              # Configuration
-│   │   ├── wrapper.py             # FastText wrapper (tokenization)
-│   │   ├── factory.py             # Convenience methods
-│   │   ├── tokenizer.py           # N-gram tokenizer
-│   │   ├── pytorch_model.py       # PyTorch model
-│   │   ├── lightning_module.py    # Lightning module
-│   │   └── dataset.py             # Dataset implementation
-│   └── simple_text_classifier.py  # Example TF-IDF wrapper
-├── utilities/
-│   └── checkers.py                # Input validation utilities
-└── factories.py                   # Convenience factory functions
-```
-
 ## 🔬 Testing
 
 Run the test suite:
@@ -257,24 +148,6 @@ uv run pytest --cov=torchTextClassifiers
 uv run pytest tests/test_torchTextClassifiers.py -v
 ```
 
-## 🤝 Contributing
-
-We welcome contributions! See our [Developer Guide](docs/developer_guide.md) for information on:
-
-- Adding new classifier types
-- Code organization and patterns
-- Testing requirements
-- Documentation standards
-
-## 📄 License
-
-This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
-
-## 🙏 Acknowledgments
-
-- Built with [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://lightning.ai/)
-- Inspired by [FastText](https://fasttext.cc/) for efficient text classification
-- Uses [uv](https://github.com/astral-sh/uv) for dependency management
 
 ## 📚 Examples
 
@@ -285,14 +158,8 @@ See the [examples/](examples/) directory for:
 - Custom classifier implementation
 - Advanced training configurations
 
-## 🐛 Support
 
-If you encounter any issues:
 
-1. Check the [examples](examples/) for similar use cases
-2. Review the API documentation above
-3. Open an issue on GitHub with:
-   - Python version
-   - Package versions (`uv tree` or `pip list`)
-   - Minimal reproduction code
-   - Error messages/stack traces
+## 📄 License
+
+This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
````
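Although the Advanced Usage and API Reference sections are deleted above, only the documentation moved; the methods they described remain in the library. A hedged sketch of the surviving unified workflow, reconstructed from the removed API Reference (the data, parameter values, and filename are invented):

```python
import numpy as np
from torchTextClassifiers import create_fasttext

X_train = np.array(["great product", "terrible service", "works fine"])
y_train = np.array([1, 0, 1])

clf = create_fasttext(embedding_dim=50, num_classes=2)
clf.build(X_train, y_train)    # build text preprocessing and model
preds = clf.predict(X_train)   # make predictions
clf.to_json("config.json")     # save configuration, reloadable via from_json()
```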

examples/using_additional_features.py

Lines changed: 4 additions & 4 deletions

```diff
@@ -107,15 +107,15 @@ def train_and_evaluate_model(X, y, model_name, use_categorical=False, use_simple
     )
     wrapper = SimpleTextWrapper(simple_text_config)
     classifier = torchTextClassifiers(wrapper)
-    print(f"Classifier type: {type(classifier.classifier_wrapper).__name__}")
-    print(f"Uses tokenizer: {hasattr(classifier.classifier_wrapper, 'tokenizer')}")
-    print(f"Uses vectorizer: {hasattr(classifier.classifier_wrapper, 'vectorizer')}")
+    print(f"Classifier type: {type(classifier.classifier).__name__}")
+    print(f"Uses tokenizer: {hasattr(classifier.classifier, 'tokenizer')}")
+    print(f"Uses vectorizer: {hasattr(classifier.classifier, 'vectorizer')}")
 
     # Build the model (this will use TF-IDF vectorization instead of tokenization)
     print("\n🔨 Building model with TF-IDF preprocessing...")
     classifier.build(X_train, y_train)
     print("✅ Model built successfully!")
-    print(f"TF-IDF features: {len(classifier.classifier_wrapper.vectorizer.get_feature_names_out())}")
+    print(f"TF-IDF features: {len(classifier.classifier.vectorizer.get_feature_names_out())}")
 
     # Train the model
     print("\n🎯 Training model...")
```

notebooks/example.ipynb

Lines changed: 13 additions & 13 deletions

```diff
@@ -900,7 +900,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": null,
    "id": "ebf5608b",
    "metadata": {},
    "outputs": [
@@ -916,7 +916,7 @@
     }
    ],
    "source": [
-    "type(model.classifier_wrapper)"
+    "type(model.classifier)"
    ]
   },
   {
@@ -1002,7 +1002,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": null,
    "id": "091024e6",
    "metadata": {},
    "outputs": [
@@ -1027,12 +1027,12 @@
     }
    ],
    "source": [
-    "model.classifier_wrapper.pytorch_model"
+    "model.classifier.pytorch_model"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": null,
    "id": "d983b113",
    "metadata": {},
    "outputs": [
@@ -1048,12 +1048,12 @@
     }
    ],
    "source": [
-    "model.classifier_wrapper.tokenizer"
+    "model.classifier.tokenizer"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": null,
    "id": "9b23f1ba",
    "metadata": {},
    "outputs": [
@@ -1082,7 +1082,7 @@
     }
    ],
    "source": [
-    "model.classifier_wrapper.lightning_module"
+    "model.classifier.lightning_module"
    ]
   },
   {
@@ -1097,7 +1097,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": null,
    "id": "00c077b0",
    "metadata": {},
    "outputs": [
@@ -1172,7 +1172,7 @@
    "source": [
     "from pprint import pprint \n",
     "sentence = [\"lorem ipsum dolor sit amet\"]\n",
-    "pprint(model.classifier_wrapper.tokenizer.tokenize(sentence)[2][0])"
+    "pprint(model.classifier.tokenizer.tokenize(sentence)[2][0])"
    ]
   },
   {
@@ -1208,7 +1208,7 @@
    "loaded_model = torchTextClassifiers.from_json('torchTextClassifiers_config.json')\n",
    "\n",
    "print(\"✅ Model loaded from JSON successfully!\")\n",
-   "print(f\"Loaded wrapper type: {type(loaded_model.classifier_wrapper).__name__}\")\n",
+   "print(f\"Loaded wrapper type: {type(loaded_model.classifier).__name__}\")\n",
    "print(f\"Config parameters: embedding_dim={loaded_model.config.embedding_dim}, sparse={loaded_model.config.sparse}\")\n",
    "\n",
    "# The loaded model needs to be built before use\n",
@@ -1296,7 +1296,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 17,
+   "execution_count": null,
    "id": "g0rmedya9eb",
    "metadata": {},
    "outputs": [
@@ -1350,7 +1350,7 @@
    "direct_model.build(X_train, y_train, lightning=True, lr=parameters_train.get(\"lr\"))\n",
    "\n",
    "print(\"✅ Direct wrapper model created successfully!\")\n",
-   "print(f\"Model type: {type(direct_model.classifier_wrapper).__name__}\")\n",
+   "print(f\"Model type: {type(direct_model.classifier).__name__}\")\n",
    "print(f\"Config type: {type(direct_model.config).__name__}\")"
    ]
   },
```
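Alongside the rename, every touched code cell has its `execution_count` reset to `null`, the usual way of scrubbing stale run metadata from a committed notebook. A minimal sketch of automating that scrub with the `nbformat` library (not used in this repo; shown only as an assumed alternative to editing the JSON by hand):

```python
import nbformat

# Load the notebook, clear execution counts on code cells, and write it back.
nb = nbformat.read("notebooks/example.ipynb", as_version=4)
for cell in nb.cells:
    if cell.cell_type == "code":
        cell.execution_count = None  # serializes as "execution_count": null
nbformat.write(nb, "notebooks/example.ipynb")
```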

tests/test_core_functionality.py

Lines changed: 5 additions & 5 deletions

```diff
@@ -79,7 +79,7 @@ def test_torchTextClassifiers_initialization_pattern():
     classifier = torchTextClassifiers(mock_wrapper)
 
     # Verify initialization
-    assert classifier.classifier_wrapper == mock_wrapper
+    assert classifier.classifier == mock_wrapper
     assert classifier.config == mock_config
 
 
@@ -123,7 +123,7 @@ def test_create_fasttext_classmethod():
 
     # Verify the result is a proper torchTextClassifiers instance
     assert isinstance(result, torchTextClassifiers)
-    assert isinstance(result.classifier_wrapper, FastTextWrapper)
+    assert isinstance(result.classifier, FastTextWrapper)
     assert result.config.embedding_dim == 50
     assert result.config.sparse == True
     assert result.config.num_tokens == 5000
@@ -135,17 +135,17 @@ def test_method_delegation_pattern():
 
     # Create a mock instance
     classifier = Mock(spec=torchTextClassifiers)
-    classifier.classifier_wrapper = Mock()
+    classifier.classifier = Mock()
 
     # Test predict delegation
     expected_result = np.array([1, 0, 1])
-    classifier.classifier_wrapper.predict.return_value = expected_result
+    classifier.classifier.predict.return_value = expected_result
 
     # Apply the real predict method to our mock
     sample_X = np.array(["test1", "test2", "test3"])
     result = torchTextClassifiers.predict(classifier, sample_X)
 
-    classifier.classifier_wrapper.predict.assert_called_once_with(sample_X)
+    classifier.classifier.predict.assert_called_once_with(sample_X)
     assert result is expected_result
 
 
```