This repository has been archived by the owner on May 27, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2 from lalalilo/master
📦 Package the TrustScore and make it compatible with python 3
- Loading branch information
Showing
6 changed files
with
171 additions
and
138 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
build/ | ||
dist/ | ||
*.egg-info/ | ||
|
||
**/__pycache__ | ||
|
||
.vscode/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import setuptools | ||
|
||
with open("README.rst", "r", encoding="utf-8") as fh: | ||
long_description = fh.read() | ||
|
||
setuptools.setup( | ||
name="trustscore", # Replace with your own username | ||
version="0.0.1", | ||
long_description=long_description, | ||
long_description_content_type="text/markdown", | ||
url="https://github.com/google/TrustScore", | ||
packages=setuptools.find_packages(), | ||
install_requires=[ | ||
"numpy", | ||
"sklearn", | ||
], | ||
extras_require={ | ||
"evaluation": ["matplotlib", "tensorflow"] | ||
}, | ||
python_requires='>=3.6', | ||
) |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from trustscore.trustscore import TrustScore | ||
from trustscore.trustscore import KNNConfidence |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
# Copyright 2018 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import numpy as np | ||
from sklearn.neighbors import KDTree, KNeighborsClassifier | ||
|
||
|
||
class TrustScore: | ||
""" | ||
Trust Score: a measure of classifier uncertainty based on nearest neighbors. | ||
""" | ||
|
||
def __init__(self, k=10, alpha=0.0, filtering="none", min_dist=1e-12): | ||
""" | ||
k and alpha are the tuning parameters for the filtering, | ||
filtering: method of filtering. option are "none", "density", | ||
"uncertainty" | ||
min_dist: some small number to mitigate possible division by 0. | ||
""" | ||
self.k = k | ||
self.filtering = filtering | ||
self.alpha = alpha | ||
self.min_dist = min_dist | ||
|
||
def filter_by_density(self, X: np.array): | ||
"""Filter out points with low kNN density. | ||
Args: | ||
X: an array of sample points. | ||
Returns: | ||
A subset of the array without points in the bottom alpha-fraction of | ||
original points of kNN density. | ||
""" | ||
kdtree = KDTree(X) | ||
knn_radii = kdtree.query(X, k=self.k)[0][:, -1] | ||
eps = np.percentile(knn_radii, (1 - self.alpha) * 100) | ||
return X[np.where(knn_radii <= eps)[0], :] | ||
|
||
def filter_by_uncertainty(self, X: np.array, y: np.array): | ||
"""Filter out points with high label disagreement amongst its kNN neighbors. | ||
Args: | ||
X: an array of sample points. | ||
Returns: | ||
A subset of the array without points in the bottom alpha-fraction of | ||
samples with highest disagreement amongst its k nearest neighbors. | ||
""" | ||
neigh = KNeighborsClassifier(n_neighbors=self.k) | ||
neigh.fit(X, y) | ||
confidence = neigh.predict_proba(X) | ||
cutoff = np.percentile(confidence, self.alpha * 100) | ||
unfiltered_idxs = np.where(confidence >= cutoff)[0] | ||
return X[unfiltered_idxs, :], y[unfiltered_idxs] | ||
|
||
def fit(self, X: np.array, y: np.array): | ||
"""Initialize trust score precomputations with training data. | ||
WARNING: assumes that the labels are 0-indexed (i.e. | ||
0, 1,..., n_labels-1). | ||
Args: | ||
X: an array of sample points. | ||
y: corresponding labels. | ||
""" | ||
|
||
self.n_labels = np.max(y) + 1 | ||
self.kdtrees = [None] * self.n_labels | ||
if self.filtering == "uncertainty": | ||
X_filtered, y_filtered = self.filter_by_uncertainty(X, y) | ||
for label in range(self.n_labels): | ||
if self.filtering == "none": | ||
X_to_use = X[np.where(y == label)[0]] | ||
self.kdtrees[label] = KDTree(X_to_use) | ||
elif self.filtering == "density": | ||
X_to_use = self.filter_by_density(X[np.where(y == label)[0]]) | ||
self.kdtrees[label] = KDTree(X_to_use) | ||
elif self.filtering == "uncertainty": | ||
X_to_use = X_filtered[np.where(y_filtered == label)[0]] | ||
self.kdtrees[label] = KDTree(X_to_use) | ||
|
||
if len(X_to_use) == 0: | ||
print( | ||
"Filtered too much or missing examples from a label! Please lower " | ||
"alpha or check data." | ||
) | ||
|
||
def get_score(self, X: np.array, y_pred: np.array): | ||
"""Compute the trust scores. | ||
Given a set of points, determines the distance to each class. | ||
Args: | ||
X: an array of sample points. | ||
y_pred: The predicted labels for these points. | ||
Returns: | ||
The trust score, which is ratio of distance to closest class that was not | ||
the predicted class to the distance to the predicted class. | ||
""" | ||
d = np.tile(None, (X.shape[0], self.n_labels)) | ||
for label_idx in range(self.n_labels): | ||
d[:, label_idx] = self.kdtrees[label_idx].query(X, k=2)[0][:, -1] | ||
|
||
sorted_d = np.sort(d, axis=1) | ||
d_to_pred = d[range(d.shape[0]), y_pred] | ||
d_to_closest_not_pred = np.where( | ||
sorted_d[:, 0] != d_to_pred, sorted_d[:, 0], sorted_d[:, 1] | ||
) | ||
return d_to_closest_not_pred / (d_to_pred + self.min_dist) | ||
|
||
|
||
class KNNConfidence: | ||
"""Baseline which uses disagreement to kNN classifier. | ||
""" | ||
|
||
def __init__(self, k=10): | ||
self.k = k | ||
|
||
def fit(self, X, y): | ||
self.kdtree = KDTree(X) | ||
self.y = y | ||
|
||
def get_score(self, X, y_pred): | ||
knn_idxs = self.kdtree.query(X, k=self.k)[1] | ||
knn_outputs = self.y[knn_idxs] | ||
return np.mean( | ||
knn_outputs == np.transpose(np.tile(y_pred, (self.k, 1))), axis=1 | ||
) |
File renamed without changes.