Skip to content

Commit d427277

Browse files
authored
Evaluation script initial commit
0 parents  commit d427277

File tree

1 file changed

+216
-0
lines changed

1 file changed

+216
-0
lines changed

evaluation.py

+216
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
# !pip install dython
2+
import logging
import os

import numpy as np
from dython.nominal import associations
from numpy import linalg as LA
from scipy.spatial import distance
from scipy.special import rel_entr
from scipy.stats import ks_2samp
from sklearn.metrics import pairwise_distances
from sklearn.metrics.pairwise import euclidean_distances
11+
12+
13+
class eval_metrics():
    """Measure how well a synthetic dataset preserves the original's characteristics.

    The goal of the evaluation script is to measure how well the generated
    synthetic dataset preserves the characteristics that exist between the
    attributes in the original dataset.

    Both datasets are indexed like pandas DataFrames. Attribute columns are
    selected *positionally* (``columns[11:-3]`` / ``columns[11:-1]``), so the
    two frames are expected to share the same column layout.
    """

    # Columns compared by the association-based metrics.
    _DEMOGRAPHIC_COLUMNS = [
        'CONTENT_ID', 'demographic_car_number_of_cars', 'demographic_age_of_the_eldest_child',
        'demographic_home_ownership', 'demographic_income',
        'demographic_education', 'demographic_household_composition',
        'demographic_number_of_people', 'demographic_age',
    ]

    # Rows lacking a value in any of these columns are excluded from the KS test.
    _KS_REQUIRED_COLUMNS = ['iab_category_Family and Relationships', 'iab_category_Travel']

    # The KS test is applied to at most this many leading target columns.
    _KS_MAX_COLUMNS = 10

    def __init__(self, origdst, synthdst):
        """Store the original (``origdst``) and synthetic (``synthdst``) datasets."""
        self.origdst = origdst
        self.synthdst = synthdst

    @staticmethod
    def to_cat(dtf):
        """Return a copy of ``dtf`` with string columns replaced by category codes.

        Only the columns in the positional window ``columns[11:-3]`` are
        considered, and a column is treated as a string column when its first
        value is a ``str``. The input frame is left untouched, so the datasets
        stored on the instance keep their raw values between metric calls.
        """
        encoded = dtf.copy()
        if len(encoded) == 0:
            # Nothing to inspect; avoids indexing the first row of an empty frame.
            return encoded

        for col in encoded.columns[11:-3]:
            # Positional access: the index is not guaranteed to contain label 0.
            if isinstance(encoded[col].iloc[0], str):
                encoded[col] = encoded[col].astype('category').cat.codes

        return encoded

    @staticmethod
    def get_demographics(df):
        """Return only the ``CONTENT_ID`` and demographic columns of ``df``."""
        return df[eval_metrics._DEMOGRAPHIC_COLUMNS]

    def _demographic_associations(self):
        """Return the ``'corr'`` association matrices of the real and synthetic demographics."""
        real_dem = self.get_demographics(self.to_cat(self.origdst))
        synth_dem = self.get_demographics(self.to_cat(self.synthdst))

        corr_real = associations(real_dem, theil_u=True, bias_correction=False, plot=False)['corr']
        corr_synth = associations(synth_dem, theil_u=True, bias_correction=False, plot=False)['corr']

        return corr_real, corr_synth

    def euclidean_dist(self):
        """Return the norm of the pairwise Euclidean distances between the association matrices.

        This metric measures the preservation of intrinsic patterns occurring
        between the attributes of the original dataset in the corresponding
        synthetic dataset. The lower the value is, the better the data
        generation tool preserves the patterns.
        The threshold limit for this metric is a value below 14.
        """
        corr_real, corr_synth = self._demographic_associations()
        eucl_matr = distance.cdist(corr_real, corr_synth, 'euclidean')
        return LA.norm(eucl_matr)

    def kolmogorov(self):
        """Return the target columns rejected by the two-sample Kolmogorov-Smirnov test.

        The two-sample Kolmogorov-Smirnov test is used to test whether two
        samples come from the same distribution. The level of significance a
        is set as a = 0.05. If the generated p-value from the test is lower
        than a then it is probable that the two distributions are different.
        The threshold limit for this function is a list containing less than
        10 elements.

        Only the first ``_KS_MAX_COLUMNS`` columns of ``columns[11:-1]`` are
        tested; fewer are tested when the frame has fewer target columns.
        """
        real_cat = self.to_cat(self.origdst)
        synth_cat = self.to_cat(self.synthdst)

        required = self._KS_REQUIRED_COLUMNS
        real_cat = real_cat[real_cat[required].notnull().all(axis=1)]
        synth_cat = synth_cat[synth_cat[required].notnull().all(axis=1)]

        # Column names come from the real dataset and are looked up by name in
        # the synthetic one.
        target_cols = list(real_cat.columns[11:-1])

        # NOTE(review): string columns are encoded separately per dataset, so
        # equal codes only denote equal categories when both datasets contain
        # the same set of values - confirm this holds for the input data.
        alpha = 0.05
        rejected = []
        for col in target_cols[:self._KS_MAX_COLUMNS]:
            p_value = ks_2samp(real_cat[col], synth_cat[col])[1]
            if p_value < alpha:
                rejected.append(col)

        return rejected

    def kl_divergence(self):
        """Return ``{column: KL divergence}`` for columns whose divergence is not below 2.

        This metric is also defined at the variable level and examines whether
        the distributions of the attributes are identical and measures the
        potential level of discrepancy between them.
        The threshold limit for this metric is a value below 2.

        For every column in ``columns[11:-3]`` the relative frequencies of the
        original (P) and synthetic (Q) values are aligned on the union of
        observed values before summing ``rel_entr(P, Q)``. A value that occurs
        in the original but never in the synthetic data therefore yields an
        infinite divergence.
        """
        kl_dict = {}

        for col in self.origdst.columns[11:-3]:
            probs_orig = self.origdst[col].value_counts(normalize=True)
            probs_synth = self.synthdst[col].value_counts(normalize=True)

            # Align both distributions on the same categories; a category
            # missing from one side has probability 0 there.
            categories = probs_orig.index.union(probs_synth.index)
            p = probs_orig.reindex(categories, fill_value=0.0).to_numpy(dtype=float)
            q = probs_synth.reindex(categories, fill_value=0.0).to_numpy(dtype=float)

            kl_dict[col] = float(rel_entr(p, q).sum())

        # Keep only the columns that fail the threshold (divergence not below 2).
        return {col: kl for col, kl in kl_dict.items() if not kl < 2}

    def pairwise_correlation_difference(self):
        """Return the Frobenius norm of the difference between the association matrices.

        PCD measures the difference in terms of Frobenius norm of the
        correlation matrices computed from real and synthetic datasets. The
        smaller the PCD, the closer the synthetic data is to the real data in
        terms of linear correlations across the variables.
        The threshold limit for this metric is a value below 2.4.
        """
        corr_real, corr_synth = self._demographic_associations()
        return LA.norm(np.subtract(corr_real, corr_synth))
147+
148+
149+
if __name__ == "__main__":
    # Script entry point: run the four evaluation metrics, print each result and
    # write a pass/fail report to evaluation.log (overwritten on every run).

    logging.basicConfig(filename='evaluation.log',
                        format='%(asctime)s %(message)s',
                        filemode='w')

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # NOTE(review): `real` and `random` are not defined anywhere in the visible
    # file; they must be provided (the original and the synthetic dataset) before
    # this line runs, otherwise it raises NameError - confirm where they come from.
    ob = eval_metrics(real, random)

    # euclidean distance
    flag_eucl = False
    eucl = ob.euclidean_dist()
    print(eucl)
    logger.info('Euclidean distance calculated')
    if eucl > 14:
        # Adjacent literals instead of a backslash continuation, so the source
        # indentation does not leak into the logged message.
        logger.error('The calculated Euclidean distance value between the two correlation matrices '
                     f'is too high it should be less than 14. The current value is {eucl}')
    else:
        logger.info('The dataaset satisfies the criteria for the euclidean distance.')
        flag_eucl = True
    logger.info('---------------------------------------------------------')

    # 2 sample Kolmogorov-Smirnov test
    kst = ob.kolmogorov()
    flag_klg = False
    print(kst)
    logger.info('Kolmogorov-Smirnov test performed')
    if kst:
        # A non-empty list means at least one column was rejected.
        logger.info('The dataset did not pass the Kolmogorov-Smirnov test')
        logger.info(f'The columns that did not pass the test are {kst}')
    else:
        logger.info('The dataset passed the Kolmogorov-Smirnov test')
        flag_klg = True
    logger.info('---------------------------------------------------------')

    # KL divergence
    dict_kl = ob.kl_divergence()
    flag_kl = False
    print(dict_kl)
    logger.info('KL divergence calculated')
    if dict_kl:
        # The dict only holds the columns that exceeded the threshold.
        logger.info('The dataset did not pass the KL divergence evaluation test')
        for key, value in dict_kl.items():
            logger.info(f'The KL divergence value for the column {key} was {value}')
    else:
        logger.info('The dataset passed the KL divergence evaluation test')
        flag_kl = True
    logger.info('---------------------------------------------------------')

    # pairwise correlation difference
    pair_corr_diff = ob.pairwise_correlation_difference()
    flag_pcd = False
    print(pair_corr_diff)
    logger.info('Pairwise correlation difference calculated')
    if pair_corr_diff > 2.4:
        # Message names the metric and threshold actually being checked here.
        logger.error('The calculated Pairwise Correlation Difference between the two correlation '
                     f'matrices is too high it should be less than 2.4. The current value is {pair_corr_diff}')
    else:
        logger.info('The dataaset satisfies the criteria for the Pairwise Correlation Difference.')
        flag_pcd = True

    # Overall verdict: every individual check must have passed.
    if flag_eucl and flag_klg and flag_kl and flag_pcd:
        logger.info('The dataaset satisfies the minimum evaluation criteria.')
    else:
        logger.info('The dataaset does not satisfy the minimum evaluation criteria.')
        logger.info('Plese check the previous log messages.')

0 commit comments

Comments
 (0)