Skip to content

Commit c6c5742

Browse files
author
Borovits
committed
Adopted changes from evaluation.py
1 parent 5ac547c commit c6c5742

File tree

1 file changed

+40
-13
lines changed

1 file changed

+40
-13
lines changed

Diff for: evaluation_in_prod.py

+40 -13
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,8 @@ def __init__(self, origdst, synthdst):
2727
def to_cat(dtr, dts):
2828

2929
target_cols = list(dtr.columns[11:-3])
30+
target_cols.insert(0, dtr.columns[1]) # channel
31+
target_cols.insert(0, dtr.columns[2]) # program_title
3032
target_cols.insert(0, dtr.columns[3]) # genre
3133

3234
# flag_same_demographic_column_values = True
@@ -121,17 +123,28 @@ def jensen_shannon(self):
121123
real_cat, synth_cat = self.to_cat(self.origdst, self.synthdst)
122124

123125
target_columns = list(self.origdst.columns[11:-3])
124-
target_columns.append(self.origdst.columns[3]) # content_id
126+
target_columns.append(self.origdst.columns[1]) # channel
127+
target_columns.append(self.origdst.columns[2]) # program_title
128+
target_columns.append(self.origdst.columns[3]) # genre
125129

126130
js_dict = {}
127131

128132
for col in target_columns:
129-
col_counts_orig = real_cat[col].value_counts(normalize=True).sort_index(ascending=True)
130-
col_counts_synth = synth_cat[col].value_counts(normalize=True).sort_index(ascending=True)
131133

132-
js = distance.jensenshannon(asarray(col_counts_orig.tolist()), asarray(col_counts_synth.tolist()), base=2)
134+
try:
135+
col_counts_orig = real_cat[col].value_counts(normalize=True).sort_index(ascending=True)
136+
col_counts_synth = synth_cat[col].value_counts(normalize=True).sort_index(ascending=True)
133137

134-
js_dict[col] = js
138+
js = distance.jensenshannon(asarray(col_counts_orig.tolist()), asarray(col_counts_synth.tolist()),
139+
base=2)
140+
141+
js_dict[col] = js
142+
143+
except:
144+
145+
print('For the column ', col, ' you must generate the same unique values as the real dataset.')
146+
print('The number of unique values than you should generate for column ', col, 'is ',
147+
len(self.origdst[col].unique()))
135148

136149
return js_dict
137150

@@ -142,17 +155,28 @@ def kl_divergence(self):
142155
The threshold limit for this metric is a value below 2"""
143156

144157
target_columns = list(self.origdst.columns[11:-3])
145-
target_columns.append(self.origdst.columns[4]) # content_id
158+
target_columns.append(self.origdst.columns[1]) # channel
159+
target_columns.append(self.origdst.columns[2]) # program_title
160+
target_columns.append(self.origdst.columns[3]) # genre
146161

147162
kl_dict = {}
148163

149164
for col in target_columns:
150-
col_counts_orig = self.origdst[col].value_counts(normalize=True).sort_index(ascending=True)
151-
col_counts_synth = self.synthdst[col].value_counts(normalize=True).sort_index(ascending=True)
152165

153-
kl = sum(rel_entr(col_counts_orig.tolist(), col_counts_synth.tolist()))
166+
try:
167+
168+
col_counts_orig = self.origdst[col].value_counts(normalize=True).sort_index(ascending=True)
169+
col_counts_synth = self.synthdst[col].value_counts(normalize=True).sort_index(ascending=True)
154170

155-
kl_dict[col] = kl
171+
kl = sum(rel_entr(col_counts_orig.tolist(), col_counts_synth.tolist()))
172+
173+
kl_dict[col] = kl
174+
175+
except:
176+
177+
print('For the column ', col, ' you must generate the same unique values as the real dataset.')
178+
print('The number of unique values than you should generate for column ', col, 'is ',
179+
len(self.origdst[col].unique()))
156180

157181
return kl_dict
158182

@@ -275,10 +299,13 @@ def pairwise_correlation_difference(self):
275299
jsd = copy.deepcopy(dict_js)
276300

277301
for key in list(dict_js):
278-
if (dict_js[key] < 0.50) & (key != 'CONTENT_ID'):
302+
if (dict_js[key] < 0.50) & (key not in ['GENRE', 'PROGRAM_TITLE']):
279303
del dict_js[key]
280-
if key == 'CONTENT_ID':
281-
if (dict_js[key] < 0.75):
304+
if key == 'GENRE':
305+
if (dict_js[key] < 0.59):
306+
del dict_js[key]
307+
if key == 'PROGRAM_TITLE':
308+
if (dict_js[key] < 0.69):
282309
del dict_js[key]
283310

284311
if dict_js:

0 commit comments

Comments
 (0)