1 change: 1 addition & 0 deletions .gitignore
@@ -57,6 +57,7 @@ nosetests.xml
coverage.xml
*,cover
.hypothesis/
*.swp

# Translations
*.mo
65 changes: 65 additions & 0 deletions examples/feature_weighting/plot_tfigm_text.py
@@ -0,0 +1,65 @@
# License: BSD 3 clause
#
# Authors: Roman Yurchak <[email protected]>
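"""
Compare TF, TF-IDF and TF-IGM term weighting on the 20 newsgroups
dataset, evaluating each scheme with a cross-validated LinearSVC and
printing F1-macro and balanced accuracy for each preprocessing option.
"""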

import os

import pandas as pd

from sklearn.svm import LinearSVC
from sklearn.preprocessing import Normalizer, FunctionTransformer
from sklearn.pipeline import make_pipeline
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
from sklearn.datasets import fetch_20newsgroups
from sklearn.model_selection import cross_validate
from sklearn.metrics import f1_score

from sklearn_extra.feature_weighting import TfigmTransformer

if "CI" in os.environ:
# make this example run faster in CI
categories = ["sci.crypt", "comp.graphics", "comp.sys.mac.hardware"]
else:
categories = None

docs, y = fetch_20newsgroups(return_X_y=True, categories=categories)


vect = CountVectorizer(min_df=5, stop_words="english", ngram_range=(1, 1))
X = vect.fit_transform(docs)

res = []

for scaler_label, scaler in [
("TF", FunctionTransformer(lambda x: x)),
("TF-IDF(sublinear_tf=False)", TfidfTransformer()),
("TF-IDF(sublinear_tf=True)", TfidfTransformer(sublinear_tf=True)),
("TF-IGM(tf_scale=None)", TfigmTransformer()),
("TF-IGM(tf_scale='sqrt')", TfigmTransformer(tf_scale="sqrt"),),
("TF-IGM(tf_scale='log1p')", TfigmTransformer(tf_scale="log1p"),),
]:
pipe = make_pipeline(scaler, Normalizer())
X_tr = pipe.fit_transform(X, y)
est = LinearSVC()
scoring = {
"F1-macro": lambda est, X, y: f1_score(
y, est.predict(X), average="macro"
),
"balanced_accuracy": "balanced_accuracy",
}
scores = cross_validate(est, X_tr, y, scoring=scoring)
for key, val in scores.items():
if not key.endswith("_time"):
res.append(
{
"metric": "_".join(key.split("_")[1:]),
"subset": key.split("_")[0],
"preprocessing": scaler_label,
"score": f"{val.mean():.3f}±{val.std():.3f}",
}
)
scores = (
pd.DataFrame(res)
.set_index(["preprocessing", "metric", "subset"])["score"]
.unstack(-1)
)
scores = scores["test"].unstack(-1).sort_values("F1-macro", ascending=False)
print(scores)
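
# The table above ranks the weighting schemes by cross-validated
# F1-macro; exact scores vary with the scikit-learn version and the
# newsgroups categories selected.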
5 changes: 5 additions & 0 deletions sklearn_extra/feature_weighting/__init__.py
@@ -0,0 +1,5 @@
# License: BSD 3 clause

from ._text import TfigmTransformer

__all__ = ["TfigmTransformer"]
190 changes: 190 additions & 0 deletions sklearn_extra/feature_weighting/_text.py
@@ -0,0 +1,190 @@
# License: BSD 3 clause
#
# Authors: Roman Yurchak <[email protected]>

import numpy as np
import scipy.sparse as sp

from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_array, check_X_y
from sklearn.preprocessing import LabelEncoder


class TfigmTransformer(BaseEstimator, TransformerMixin):
"""TF-IGM feature weighting.

TF-IGM (Inverse Gravity Moment) is a supervised
feature weighting scheme for classification tasks that measures
the class distinguishing power of each term.

See the User Guide for more details.
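
For a term with per-class document frequencies sorted in decreasing
order, ``f_1 >= f_2 >= ... >= f_m`` over ``m`` classes, the raw
weight is ``igm = f_1 / sum_r(r * f_r)``; it is then linearly
rescaled to [0, 1] and regularized as ``alpha + igm_``.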

Parameters
----------
alpha : float, default=0.15
Regularization parameter. Known good values are in the range
0.14 to 0.20.
tf_scale : {"sqrt", "log1p"}, default=None
Scaling applied to the term frequency when not None. Possible values are:
- "sqrt": square root scaling.
- "log1p": ``log(1 + tf)`` scaling. This option corresponds to the
``sublinear_tf=True`` parameter in
:class:`~sklearn.feature_extraction.text.TfidfTransformer`.

Attributes
----------
igm_ : array of shape (n_features,)
The Inverse Gravity Moment (IGM) weight.
coef_ : array of shape (n_features,)
Regularized IGM weight, equal to ``alpha + igm_``.

Examples
--------
>>> from sklearn.feature_extraction.text import CountVectorizer
>>> from sklearn.pipeline import Pipeline
>>> from sklearn_extra.feature_weighting import TfigmTransformer
>>> corpus = ['this is the first document',
... 'this document is the second document',
... 'and this is the third one',
... 'is this the first document']
>>> y = ['1', '2', '1', '2']
>>> pipe = Pipeline([('count', CountVectorizer()),
... ('tfigm', TfigmTransformer())]).fit(corpus, y)
>>> pipe['count'].transform(corpus).toarray()
array([[0, 1, 1, 1, 0, 0, 1, 0, 1],
[0, 2, 0, 1, 0, 1, 1, 0, 1],
[1, 0, 0, 1, 1, 0, 1, 1, 1],
[0, 1, 1, 1, 0, 0, 1, 0, 1]])
>>> pipe['tfigm'].igm_
array([1. , 0.25, 0. , 0. , 1. , 1. , 0. , 1. , 0. ])
>>> pipe['tfigm'].coef_
array([1.15, 0.4 , 0.15, 0.15, 1.15, 1.15, 0.15, 1.15, 0.15])
>>> pipe.transform(corpus).shape
(4, 9)

References
----------
Chen, Kewen, et al. "Turning from TF-IDF to TF-IGM for term weighting
in text classification." Expert Systems with Applications 66 (2016):
245-260.
"""

def __init__(self, alpha=0.15, tf_scale=None):
self.alpha = alpha
self.tf_scale = tf_scale

def _fit(self, X, y):
"""Learn the igm vector (global term weights)

Parameters
----------
X : {array-like, sparse matrix} of shape (n_samples, n_features)
a matrix of term/token counts
y : array-like of shape (n_samples,)
target classes
"""
self._le = LabelEncoder().fit(y)
n_class = len(self._le.classes_)
class_freq = np.zeros((n_class, X.shape[1]))

X_nz = X != 0
if sp.issparse(X_nz):
X_nz = X_nz.asformat("csr", copy=False)

for idx, class_label in enumerate(self._le.classes_):
y_mask = y == class_label
n_samples = y_mask.sum()
class_freq[idx, :] = X_nz[y_mask].sum(axis=0) / n_samples

self._class_freq = class_freq
class_freq_sort = np.sort(self._class_freq, axis=0)
f1 = class_freq_sort[-1, :]

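# rank-weighted sum of class frequencies: class_freq_sort is in
# increasing order, so the weights n_class..1 give rank 1 to the
# largest frequency, i.e. fk = sum_r r * f_r with f_r sorted in
# decreasing order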
fk = (class_freq_sort * np.arange(n_class, 0, -1)[:, None]).sum(axis=0)
# avoid division by zero
igm = np.divide(f1, fk, out=np.zeros_like(f1), where=(fk != 0))
if n_class > 1:
# linearly rescale weights to [0, 1]: the raw igm ranges from
# 2 / (n_class * (n_class + 1)) for a term spread evenly across
# classes to 1 for a term confined to a single class
self.igm_ = ((1 + n_class) * n_class * igm - 2) / (
(1 + n_class) * n_class - 2
)
else:
self.igm_ = igm
self.coef_ = self.alpha + self.igm_
return self

def fit(self, X, y):
"""Learn the igm vector (global term weights)

Parameters
----------
X : {array-like, sparse matrix} of shape (n_samples, n_features)
a matrix of term/token counts
y : array-like of shape (n_samples,)
target classes
"""
X, y = check_X_y(X, y, accept_sparse=["csr", "csc"])
self._fit(X, y)
return self

def _transform(self, X):
"""Transform a count matrix to a TF-IGM representation

Parameters
----------
X : {array-like, sparse matrix} of shape (n_samples, n_features)
a matrix of term/token counts

Returns
-------
vectors : {ndarray, sparse matrix} of shape (n_samples, n_features)
transformed matrix
"""
if self.tf_scale is None:
pass
elif self.tf_scale == "sqrt":
X = np.sqrt(X)
elif self.tf_scale == "log1p":
X = np.log1p(X)
else:
raise ValueError(
"tf_scale={!r} is not supported, it must be one of "
"None, 'sqrt' or 'log1p'.".format(self.tf_scale)
)

Review comment (Member): A friendlier message here?

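# scale each feature (column) by its regularized IGM weight; the
# sparse diagonal product keeps X sparse instead of densifying it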
if sp.issparse(X):
X_tr = X @ sp.diags(self.coef_)
else:
X_tr = X * self.coef_[None, :]
return X_tr

def transform(self, X):
"""Transform a count matrix to a TF-IGM representation

Parameters
----------
X : {array-like, sparse matrix} of shape (n_samples, n_features)
a matrix of term/token counts

Returns
-------
vectors : {ndarray, sparse matrix} of shape (n_samples, n_features)
transformed matrix
"""
X = check_array(X, accept_sparse=["csr", "csc"])
X_tr = self._transform(X)
return X_tr

def fit_transform(self, X, y):
"""Transform a count matrix to a TF-IGM representation

Parameters
----------
X : {array-like, sparse matrix} of shape (n_samples, n_features)
a matrix of term/token counts

Returns
-------
vectors : {ndarray, sparse matrix} of shape (n_samples, n_features)
transformed matrix
"""
X, y = check_X_y(X, y, accept_sparse=["csr", "csc"])
self._fit(X, y)
X_tr = self._transform(X)
return X_tr
83 changes: 83 additions & 0 deletions sklearn_extra/feature_weighting/tests/test_text.py
@@ -0,0 +1,83 @@
# License: BSD 3 clause

import numpy as np
from numpy.testing import assert_allclose, assert_array_less
import scipy.sparse as sp

import pytest

from sklearn_extra.feature_weighting import TfigmTransformer
from sklearn.datasets import make_classification


@pytest.mark.parametrize("array_format", ["dense", "csr", "csc", "coo"])
def test_tfigm_transform(array_format):
X = np.array([[0, 1, 1], [1, 0, 1], [0, 0, 1], [1, 1, 1]])
if array_format != "dense":
X = sp.csr_matrix(X).asformat(array_format)
y = np.array(["a", "b", "a", "c"])

alpha = 0.2
est = TfigmTransformer(alpha=alpha)
X_tr = est.fit_transform(X, y)

assert_allclose(est.igm_, [0.20, 0.40, 0.0])
assert_allclose(est.igm_ + alpha, est.coef_)

assert X_tr.shape == X.shape
assert sp.issparse(X_tr) is (array_format != "dense")

if array_format == "dense":
assert_allclose(X * est.coef_[None, :], X_tr)
else:
assert_allclose(X.A * est.coef_[None, :], X_tr.A)


def test_tfigm_synthetic():
X, y = make_classification(
n_samples=100,
n_features=10,
n_informative=5,
n_redundant=0,
random_state=0,
n_classes=5,
shuffle=False,
)
X = (X > 0).astype(np.float)

est = TfigmTransformer()
est.fit(X, y)
# informative features have higher IGM weights than noisy ones
# (although here we lose a lot of information due to thresholding of X).
assert est.igm_[:5].mean() / est.igm_[5:].mean() > 3


@pytest.mark.parametrize("n_class", [2, 5])
def test_tfigm_random_distribution(n_class):
rng = np.random.RandomState(0)
n_samples, n_features = 500, 4
X = rng.randint(2, size=(n_samples, n_features))
y = rng.randint(n_class, size=(n_samples,))

est = TfigmTransformer()
X_tr = est.fit_transform(X, y)

# all weights are strictly positive
assert_array_less(0, est.igm_)
# and close to zero, since none of the features are discriminative
assert_array_less(est.igm_, 0.05)


def test_tfigm_valid_target():
X = np.array([[0, 1, 1], [1, 0, 1], [0, 0, 1], [1, 1, 1]])
y = None

est = TfigmTransformer()
with pytest.raises(ValueError, match="y cannot be None"):
est.fit(X, y)

# check asymptotic behaviour for 1 class
y = [1, 1, 1, 1]
est = TfigmTransformer()
est.fit(X, y)
assert_allclose(est.igm_, np.ones(3))
9 changes: 8 additions & 1 deletion sklearn_extra/tests/test_common.py
@@ -4,8 +4,15 @@
from sklearn_extra.kernel_approximation import Fastfood
from sklearn_extra.kernel_methods import EigenProClassifier, EigenProRegressor
from sklearn_extra.cluster import KMedoids
from sklearn_extra.feature_weighting import TfigmTransformer

ALL_ESTIMATORS = [Fastfood, KMedoids, EigenProClassifier, EigenProRegressor]
ALL_ESTIMATORS = [
Fastfood,
KMedoids,
EigenProClassifier,
EigenProRegressor,
TfigmTransformer,
]

if hasattr(estimator_checks, "parametrize_with_checks"):
# Common tests are only run on scikit-learn 0.22+