@@ -27,6 +27,8 @@ def __init__(self, origdst, synthdst):
27
27
def to_cat (dtr , dts ):
28
28
29
29
target_cols = list (dtr .columns [11 :- 3 ])
30
+ target_cols .insert (0 , dtr .columns [1 ]) # channel
31
+ target_cols .insert (0 , dtr .columns [2 ]) # program_title
30
32
target_cols .insert (0 , dtr .columns [3 ]) # genre
31
33
32
34
# flag_same_demographic_column_values = True
@@ -121,17 +123,28 @@ def jensen_shannon(self):
121
123
real_cat , synth_cat = self .to_cat (self .origdst , self .synthdst )
122
124
123
125
target_columns = list (self .origdst .columns [11 :- 3 ])
124
- target_columns .append (self .origdst .columns [3 ]) # content_id
126
+ target_columns .append (self .origdst .columns [1 ]) # channel
127
+ target_columns .append (self .origdst .columns [2 ]) # program_title
128
+ target_columns .append (self .origdst .columns [3 ]) # genre
125
129
126
130
js_dict = {}
127
131
128
132
for col in target_columns :
129
- col_counts_orig = real_cat [col ].value_counts (normalize = True ).sort_index (ascending = True )
130
- col_counts_synth = synth_cat [col ].value_counts (normalize = True ).sort_index (ascending = True )
131
133
132
- js = distance .jensenshannon (asarray (col_counts_orig .tolist ()), asarray (col_counts_synth .tolist ()), base = 2 )
134
+ try :
135
+ col_counts_orig = real_cat [col ].value_counts (normalize = True ).sort_index (ascending = True )
136
+ col_counts_synth = synth_cat [col ].value_counts (normalize = True ).sort_index (ascending = True )
133
137
134
- js_dict [col ] = js
138
+ js = distance .jensenshannon (asarray (col_counts_orig .tolist ()), asarray (col_counts_synth .tolist ()),
139
+ base = 2 )
140
+
141
+ js_dict [col ] = js
142
+
143
+ except :
144
+
145
+ print ('For the column ' , col , ' you must generate the same unique values as the real dataset.' )
146
+ print ('The number of unique values than you should generate for column ' , col , 'is ' ,
147
+ len (self .origdst [col ].unique ()))
135
148
136
149
return js_dict
137
150
@@ -142,17 +155,28 @@ def kl_divergence(self):
142
155
The threshold limit for this metric is a value below 2"""
143
156
144
157
target_columns = list (self .origdst .columns [11 :- 3 ])
145
- target_columns .append (self .origdst .columns [4 ]) # content_id
158
+ target_columns .append (self .origdst .columns [1 ]) # channel
159
+ target_columns .append (self .origdst .columns [2 ]) # program_title
160
+ target_columns .append (self .origdst .columns [3 ]) # genre
146
161
147
162
kl_dict = {}
148
163
149
164
for col in target_columns :
150
- col_counts_orig = self .origdst [col ].value_counts (normalize = True ).sort_index (ascending = True )
151
- col_counts_synth = self .synthdst [col ].value_counts (normalize = True ).sort_index (ascending = True )
152
165
153
- kl = sum (rel_entr (col_counts_orig .tolist (), col_counts_synth .tolist ()))
166
+ try :
167
+
168
+ col_counts_orig = self .origdst [col ].value_counts (normalize = True ).sort_index (ascending = True )
169
+ col_counts_synth = self .synthdst [col ].value_counts (normalize = True ).sort_index (ascending = True )
154
170
155
- kl_dict [col ] = kl
171
+ kl = sum (rel_entr (col_counts_orig .tolist (), col_counts_synth .tolist ()))
172
+
173
+ kl_dict [col ] = kl
174
+
175
+ except :
176
+
177
+ print ('For the column ' , col , ' you must generate the same unique values as the real dataset.' )
178
+ print ('The number of unique values than you should generate for column ' , col , 'is ' ,
179
+ len (self .origdst [col ].unique ()))
156
180
157
181
return kl_dict
158
182
@@ -275,10 +299,13 @@ def pairwise_correlation_difference(self):
275
299
jsd = copy .deepcopy (dict_js )
276
300
277
301
for key in list (dict_js ):
278
- if (dict_js [key ] < 0.50 ) & (key != 'CONTENT_ID' ):
302
+ if (dict_js [key ] < 0.50 ) & (key not in [ 'GENRE' , 'PROGRAM_TITLE' ] ):
279
303
del dict_js [key ]
280
- if key == 'CONTENT_ID' :
281
- if (dict_js [key ] < 0.75 ):
304
+ if key == 'GENRE' :
305
+ if (dict_js [key ] < 0.59 ):
306
+ del dict_js [key ]
307
+ if key == 'PROGRAM_TITLE' :
308
+ if (dict_js [key ] < 0.69 ):
282
309
del dict_js [key ]
283
310
284
311
if dict_js :
0 commit comments