Skip to content

Commit

Permalink
ICA for training IVE methods
Browse files Browse the repository at this point in the history
  • Loading branch information
PietroMelzi committed Jul 27, 2022
1 parent fb1ae18 commit 63996e4
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 11 deletions.
1 change: 1 addition & 0 deletions evaluate_IVE.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,4 @@ def execute_evaluation(db, classifiers, use_pca, seed, blocked_pca_features=0):


execute_evaluation('diveface', ['mlp'], True, 0, 3)
# ['mlp', 'svm_lin', 'svm_rbf', 'rf', 'gb', 'nb', 'et', 'log_reg']
28 changes: 17 additions & 11 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from xgboost import XGBClassifier
from sklearn.preprocessing import StandardScaler
import numpy as np
from sklearn.decomposition import PCA
from sklearn.decomposition import PCA, FastICA
import pickle as pk
import tqdm

Expand All @@ -20,16 +20,16 @@ def zeros_pca(x_total, mask):
return x_total


def execute_all(model_feat_importance, use_pca, seed, blocked_pca_features=0):
def execute_all(model_feat_importance, transform, seed, blocked_features=0):
# create
if not os.path.isfile('data/feret_df.csv'):
feret_df = manage_data.get_feret_df(r'data/myselection_embeddings', seed, save_files=True)
else:
feret_df = pd.read_csv('data/feret_df.csv')

folder = str(seed) + ('_pca' if use_pca else '_NO_pca')
if use_pca:
folder = folder + '_k=' + str(blocked_pca_features)
folder = str(seed) + ('_' + transform)
if transform == 'pca' or transform == 'ica':
folder = folder + '_k=' + str(blocked_features)
os.makedirs("results", exist_ok=True)
os.makedirs(os.path.join('results', folder), exist_ok=True)

Expand All @@ -41,7 +41,6 @@ def execute_all(model_feat_importance, use_pca, seed, blocked_pca_features=0):
num_steps = 1
num_eliminations = 3
num_epochs = 170
blocked_features = blocked_pca_features

# define classifier
n_estimators = 30
Expand All @@ -68,11 +67,16 @@ def execute_all(model_feat_importance, use_pca, seed, blocked_pca_features=0):
# get labels
y = manage_data.get_y_ready(feret_df, labels)

if use_pca:
if transform == 'pca':
# transform from original to PCA domain
pca = PCA(n_components=length_embeddings)
pca = PCA(n_components=length_embeddings, random_state=seed)
x_total = pca.fit_transform(x_total)
pk.dump(pca, open(os.path.join(os.path.join('results', folder), 'pca.pkl'), "wb"))
if transform == 'ica':
# transform from original to ICA domain
ica = FastICA(n_components=length_embeddings, random_state=seed, whiten='unit-variance', max_iter=1000)
x_total = ica.fit_transform(x_total)
pk.dump(ica, open(os.path.join(os.path.join('results', folder), 'ica.pkl'), "wb"))

x_first = np.copy(x_total)
x_second = np.copy(x_total)
Expand Down Expand Up @@ -108,7 +112,9 @@ def execute_all(model_feat_importance, use_pca, seed, blocked_pca_features=0):


for seed in range(10):
execute_all('rf', False, seed)
execute_all('rf', 'NO_pca', seed)
for k in [0, 3, 5]:
execute_all('rf', True, seed, k)
# execute_all('rf', ['mlp', 'svm_lin', 'svm_rbf', 'rf', 'gb', 'nb', 'et', 'log_reg'], True, 0, 3)
execute_all('rf', 'pca', seed, k)

# for seed in range(10):
# execute_all('rf', 'ica', seed)

0 comments on commit 63996e4

Please sign in to comment.