Skip to content

Commit 6469898

Browse files
author
Borovits
committed
Added channel and program title to the columns to be encoded as categorical.
Added channel and program title to be calculated in Jensen Shannon and KL divergence. Removed the calculation of CONTENT_ID for KL divergence. For the selected columns in Jensen Shannon and KL divergence, the same unique values as the original dataset need to be generated (try/except code).
1 parent fb3a55e commit 6469898

File tree

1 file changed

+156
-128
lines changed

1 file changed

+156
-128
lines changed

evaluation.py

Lines changed: 156 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ def __init__(self, origdst, synthdst):
2424
def to_cat(dtr, dts):
2525

2626
target_cols = list(dtr.columns[11:-3])
27+
target_cols.insert(0, dtr.columns[1]) # channel
28+
target_cols.insert(0, dtr.columns[2]) # program_title
2729
target_cols.insert(0, dtr.columns[3]) # genre
2830

2931
# flag_same_demographic_column_values = True
@@ -118,17 +120,28 @@ def jensen_shannon(self):
118120
real_cat, synth_cat = self.to_cat(self.origdst, self.synthdst)
119121

120122
target_columns = list(self.origdst.columns[11:-3])
123+
target_columns.append(self.origdst.columns[1]) # channel
124+
target_columns.append(self.origdst.columns[2]) # program_title
121125
target_columns.append(self.origdst.columns[3]) # genre
122126

123127
js_dict = {}
124128

125129
for col in target_columns:
126-
col_counts_orig = real_cat[col].value_counts(normalize=True).sort_index(ascending=True)
127-
col_counts_synth = synth_cat[col].value_counts(normalize=True).sort_index(ascending=True)
128130

129-
js = distance.jensenshannon(asarray(col_counts_orig.tolist()), asarray(col_counts_synth.tolist()), base=2)
131+
try:
132+
col_counts_orig = real_cat[col].value_counts(normalize=True).sort_index(ascending=True)
133+
col_counts_synth = synth_cat[col].value_counts(normalize=True).sort_index(ascending=True)
130134

131-
js_dict[col] = js
135+
js = distance.jensenshannon(asarray(col_counts_orig.tolist()), asarray(col_counts_synth.tolist()),
136+
base=2)
137+
138+
js_dict[col] = js
139+
140+
except:
141+
142+
print('For the column ', col, ' you must generate the same unique values as the real dataset.')
143+
print('The number of unique values than you should generate for column ', col, 'is ',
144+
len(self.origdst[col].unique()))
132145

133146
return js_dict
134147

@@ -139,17 +152,28 @@ def kl_divergence(self):
139152
The threshold limit for this metric is a value below 2"""
140153

141154
target_columns = list(self.origdst.columns[11:-3])
142-
target_columns.append(self.origdst.columns[4]) # content_id
155+
target_columns.append(self.origdst.columns[1]) # channel
156+
target_columns.append(self.origdst.columns[2]) # program_title
157+
target_columns.append(self.origdst.columns[3]) # genre
143158

144159
kl_dict = {}
145160

146161
for col in target_columns:
147-
col_counts_orig = self.origdst[col].value_counts(normalize=True).sort_index(ascending=True)
148-
col_counts_synth = self.synthdst[col].value_counts(normalize=True).sort_index(ascending=True)
149162

150-
kl = sum(rel_entr(col_counts_orig.tolist(), col_counts_synth.tolist()))
163+
try:
164+
165+
col_counts_orig = self.origdst[col].value_counts(normalize=True).sort_index(ascending=True)
166+
col_counts_synth = self.synthdst[col].value_counts(normalize=True).sort_index(ascending=True)
167+
168+
kl = sum(rel_entr(col_counts_orig.tolist(), col_counts_synth.tolist()))
151169

152-
kl_dict[col] = kl
170+
kl_dict[col] = kl
171+
172+
except:
173+
174+
print('For the column ', col, ' you must generate the same unique values as the real dataset.')
175+
print('The number of unique values than you should generate for column ', col, 'is ',
176+
len(self.origdst[col].unique()))
153177

154178
return kl_dict
155179

@@ -176,123 +200,127 @@ def pairwise_correlation_difference(self):
176200

177201
return prwcrdst, substract_m
178202

179-
if __name__ == "__main__":
180-
181-
logging.basicConfig(filename='evaluation.log',
182-
format='%(asctime)s %(message)s',
183-
filemode='w')
184-
185-
logger = logging.getLogger()
186-
logger.setLevel(logging.INFO)
187-
188-
ob = eval_metrics(r, ra)
189-
190-
# euclidean distance
191-
flag_eucl = False
192-
eucl, eumatr = ob.euclidean_dist()
193-
logger.info('Euclidean distance was calculated')
194-
print('The calculated euclidean distance is: ', eucl)
195-
print('The calculated euclidean distance matrix is:', eumatr)
196-
if eucl > 14:
197-
logger.error(f'The calculated Euclidean distance value between the two correlation matrices is too high it should be \
198-
less than 14. The current value is {eucl}')
199-
logger.info(f'The Euclidean distance matrix is \n {eumatr}')
200-
else:
201-
logger.info('The dataset satisfies the criteria for the euclidean distance.')
202-
logger.info(f'The calculated Euclidean distance value is \n {eucl}')
203-
logger.info(f'The Euclidean distance matrix is \n {eumatr}')
204-
flag_eucl = True
205-
logger.info('---------------------------------------------------------')
206-
207-
# 2 sample Kolmogorov-Smirnov test
208-
kst = ob.kolmogorov()
209-
210-
p_value = 0.05
211-
flag_klg = False
212-
logger.info('Kolmogorov-Smirnov test was performed')
213-
print('The results of the Kolmogorov-Smirnov test is:', kst)
214-
rejected = {}
215-
for col in kst:
216-
if kst[col]['p-value'] < p_value:
217-
rejected[col] = kst[col]
218-
if rejected:
219-
logger.info('The dataset did not pass the Kolmogorov-Smirnov test')
220-
logger.info(f'The columns that did not pass the test are \n {rejected}')
221-
logger.info(f'The overall performance for the test is \n {kst}')
222-
else:
223-
logger.info('The dataset passed the Kolmogorov-Smirnov test')
224-
logger.info(f'The overall performance for the test is \n {kst}')
225-
flag_klg = True
226-
logger.info('---------------------------------------------------------')
227-
228-
# Jensen-Shannon Divergence
229-
dict_js = ob.jensen_shannon()
230-
logger.info('Jensen-Shannon Divergence was calculated')
231-
print('The result of the Jensen-Shannon Divergence is:', dict_js)
232-
flag_js = False
233-
234-
jsd = deepcopy(dict_js)
235-
236-
for key in list(dict_js):
237-
if (dict_js[key] < 0.50) & (key != 'CONTENT_ID'):
203+
204+
if __name__ == "__main__":
    # Driver script: runs every evaluation metric on a real vs. synthetic
    # dataset pair, logs pass/fail per metric, and emits an overall verdict.

    logging.basicConfig(filename='evaluation.log',
                        format='%(asctime)s %(message)s',
                        filemode='w')

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # NOTE(review): `r` (real dataset) and `ra` (synthetic dataset) are not
    # defined anywhere in this file as shown -- presumably they are loaded
    # elsewhere before this script runs; confirm and make the loading explicit.
    ob = eval_metrics(r, ra)

    # --- Euclidean distance between the two correlation matrices ---
    # Threshold: must be <= 14 to pass.
    flag_eucl = False
    eucl, eumatr = ob.euclidean_dist()
    logger.info('Euclidean distance was calculated')
    print('The calculated euclidean distance is: ', eucl)
    print('The calculated euclidean distance matrix is:', eumatr)
    if eucl > 14:
        logger.error(f'The calculated Euclidean distance value between the two correlation matrices is too high it should be \
less than 14. The current value is {eucl}')
        logger.info(f'The Euclidean distance matrix is \n {eumatr}')
    else:
        logger.info('The dataset satisfies the criteria for the euclidean distance.')
        logger.info(f'The calculated Euclidean distance value is \n {eucl}')
        logger.info(f'The Euclidean distance matrix is \n {eumatr}')
        flag_eucl = True
    logger.info('---------------------------------------------------------')

    # --- Two-sample Kolmogorov-Smirnov test ---
    # A column fails when its p-value is below 0.05 (null hypothesis of
    # identical distributions rejected).
    kst = ob.kolmogorov()

    p_value = 0.05
    flag_klg = False
    logger.info('Kolmogorov-Smirnov test was performed')
    print('The results of the Kolmogorov-Smirnov test is:', kst)
    rejected = {}
    for col in kst:
        if kst[col]['p-value'] < p_value:
            rejected[col] = kst[col]
    if rejected:
        logger.info('The dataset did not pass the Kolmogorov-Smirnov test')
        logger.info(f'The columns that did not pass the test are \n {rejected}')
        logger.info(f'The overall performance for the test is \n {kst}')
    else:
        logger.info('The dataset passed the Kolmogorov-Smirnov test')
        logger.info(f'The overall performance for the test is \n {kst}')
        flag_klg = True
    logger.info('---------------------------------------------------------')

    # --- Jensen-Shannon divergence ---
    # Per-column thresholds: GENRE < 0.59, PROGRAM_TITLE < 0.69, every other
    # column < 0.50. Columns that pass are removed from dict_js, so an empty
    # dict_js means the whole test passed; `jsd` keeps the full results.
    dict_js = ob.jensen_shannon()
    logger.info('Jensen-Shannon Divergence was calculated')
    print('The result of the Jensen-Shannon Divergence is:', dict_js)
    flag_js = False

    jsd = deepcopy(dict_js)

    for key in list(dict_js):
        # `and` (not bitwise `&`) -- these are boolean conditions.
        if (dict_js[key] < 0.50) and (key not in ['GENRE', 'PROGRAM_TITLE']):
            del dict_js[key]
        elif key == 'GENRE':
            if dict_js[key] < 0.59:
                del dict_js[key]
        elif key == 'PROGRAM_TITLE':
            if dict_js[key] < 0.69:
                del dict_js[key]

    if dict_js:
        logger.info('The dataset did not pass the Jensen-Shannon Divergence test')
        for key in dict_js.keys():
            logger.info(f'The Jensen-Shannon Divergence value for the column {key} was {dict_js[key]}')
        logger.info(f'The overall performance for each column is summarized below: \n {jsd}')
    else:
        logger.info('The dataset passed the Jensen-Shannon Divergence test')
        logger.info(f'The overall performance for each column is summarized below: \n {jsd}')
        flag_js = True
    logger.info('---------------------------------------------------------')

    # --- KL divergence ---
    # Threshold: every column must be below 2.20; passing columns are removed
    # from dict_kl, `kl` keeps the full results for the summary log line.
    dict_kl = ob.kl_divergence()
    logger.info('KL divergence was calculated')
    print('The result of the KL divergence is', dict_kl)
    flag_kl = False

    kl = deepcopy(dict_kl)

    for key in list(dict_kl):
        if dict_kl[key] < 2.20:
            del dict_kl[key]

    if dict_kl:
        logger.info('The dataset did not pass the KL divergence evaluation test')
        for key in dict_kl.keys():
            logger.info(f'The KL divergence value for the column {key} was {dict_kl[key]}')
        logger.info(f'The overall for the KL divergence performance for each column is summarized below: \n {kl}')
    else:
        logger.info('The dataset passed the KL divergence evaluation test')
        logger.info(f'The overall performance for the KL divergence for each column is summarized below: \n {kl}')
        flag_kl = True
    logger.info('---------------------------------------------------------')

    # --- Pairwise correlation difference ---
    # Threshold: must be <= 2.4 to pass.
    pair_corr_diff, pcd_matr = ob.pairwise_correlation_difference()
    logger.info('Pairwise correlation difference was calculated')
    print('The calculated Pairwise correlation difference was', pair_corr_diff)
    print('The calculated Pairwise correlation difference matrix was', pcd_matr)

    flag_pcd = False
    if pair_corr_diff > 2.4:
        # Fixed: this message previously was a copy-paste of the Euclidean
        # distance error (wrong metric name and wrong threshold of 14).
        logger.error(f'The calculated pairwise correlation difference between the two correlation matrices is too high it should be \
less than 2.4. The current value is {pair_corr_diff}')
        logger.info(f'The pairwise correlation difference matrix is \n {pcd_matr}')
    else:
        logger.info('The dataset satisfies the criteria for the Pairwise Correlation Difference.')
        logger.info(f'The pairwise correlation difference value is \n {pair_corr_diff}')
        logger.info(f'The pairwise correlation difference matrix is \n {pcd_matr}')
        flag_pcd = True

    # Overall verdict: every individual metric must have passed.
    if flag_eucl and flag_js and flag_klg and flag_kl and flag_pcd:
        logger.info('The dataset satisfies the minimum evaluation criteria.')
    else:
        logger.info('The dataset does not satisfy the minimum evaluation criteria.')
        logger.info('Please check the previous log messages.')

0 commit comments

Comments
 (0)