@@ -27,6 +27,8 @@ def __init__(self, origdst, synthdst):
2727 def to_cat (dtr , dts ):
2828
2929 target_cols = list (dtr .columns [11 :- 3 ])
30+ target_cols .insert (0 , dtr .columns [1 ]) # channel
31+ target_cols .insert (0 , dtr .columns [2 ]) # program_title
3032 target_cols .insert (0 , dtr .columns [3 ]) # genre
3133
3234 # flag_same_demographic_column_values = True
@@ -121,17 +123,28 @@ def jensen_shannon(self):
121123 real_cat , synth_cat = self .to_cat (self .origdst , self .synthdst )
122124
123125 target_columns = list (self .origdst .columns [11 :- 3 ])
124- target_columns .append (self .origdst .columns [3 ]) # content_id
126+ target_columns .append (self .origdst .columns [1 ]) # channel
127+ target_columns .append (self .origdst .columns [2 ]) # program_title
128+ target_columns .append (self .origdst .columns [3 ]) # genre
125129
126130 js_dict = {}
127131
128132 for col in target_columns :
129- col_counts_orig = real_cat [col ].value_counts (normalize = True ).sort_index (ascending = True )
130- col_counts_synth = synth_cat [col ].value_counts (normalize = True ).sort_index (ascending = True )
131133
132- js = distance .jensenshannon (asarray (col_counts_orig .tolist ()), asarray (col_counts_synth .tolist ()), base = 2 )
134+ try :
135+ col_counts_orig = real_cat [col ].value_counts (normalize = True ).sort_index (ascending = True )
136+ col_counts_synth = synth_cat [col ].value_counts (normalize = True ).sort_index (ascending = True )
133137
134- js_dict [col ] = js
138+ js = distance .jensenshannon (asarray (col_counts_orig .tolist ()), asarray (col_counts_synth .tolist ()),
139+ base = 2 )
140+
141+ js_dict [col ] = js
142+
143+ except :
144+
145+ print ('For the column ' , col , ' you must generate the same unique values as the real dataset.' )
146+ print ('The number of unique values than you should generate for column ' , col , 'is ' ,
147+ len (self .origdst [col ].unique ()))
135148
136149 return js_dict
137150
@@ -142,17 +155,28 @@ def kl_divergence(self):
142155 The threshold limit for this metric is a value below 2"""
143156
144157 target_columns = list (self .origdst .columns [11 :- 3 ])
145- target_columns .append (self .origdst .columns [4 ]) # content_id
158+ target_columns .append (self .origdst .columns [1 ]) # channel
159+ target_columns .append (self .origdst .columns [2 ]) # program_title
160+ target_columns .append (self .origdst .columns [3 ]) # genre
146161
147162 kl_dict = {}
148163
149164 for col in target_columns :
150- col_counts_orig = self .origdst [col ].value_counts (normalize = True ).sort_index (ascending = True )
151- col_counts_synth = self .synthdst [col ].value_counts (normalize = True ).sort_index (ascending = True )
152165
153- kl = sum (rel_entr (col_counts_orig .tolist (), col_counts_synth .tolist ()))
166+ try :
167+
168+ col_counts_orig = self .origdst [col ].value_counts (normalize = True ).sort_index (ascending = True )
169+ col_counts_synth = self .synthdst [col ].value_counts (normalize = True ).sort_index (ascending = True )
154170
155- kl_dict [col ] = kl
171+ kl = sum (rel_entr (col_counts_orig .tolist (), col_counts_synth .tolist ()))
172+
173+ kl_dict [col ] = kl
174+
175+ except :
176+
177+ print ('For the column ' , col , ' you must generate the same unique values as the real dataset.' )
178+ print ('The number of unique values than you should generate for column ' , col , 'is ' ,
179+ len (self .origdst [col ].unique ()))
156180
157181 return kl_dict
158182
@@ -275,10 +299,13 @@ def pairwise_correlation_difference(self):
275299 jsd = copy .deepcopy (dict_js )
276300
277301 for key in list (dict_js ):
278- if (dict_js [key ] < 0.50 ) & (key != 'CONTENT_ID' ):
302+ if (dict_js [key ] < 0.50 ) & (key not in [ 'GENRE' , 'PROGRAM_TITLE' ] ):
279303 del dict_js [key ]
280- if key == 'CONTENT_ID' :
281- if (dict_js [key ] < 0.75 ):
304+ if key == 'GENRE' :
305+ if (dict_js [key ] < 0.59 ):
306+ del dict_js [key ]
307+ if key == 'PROGRAM_TITLE' :
308+ if (dict_js [key ] < 0.69 ):
282309 del dict_js [key ]
283310
284311 if dict_js :
0 commit comments