1
1
function emb_sz_rule (n_cat)
2
- min (600 , round (1.6 * n_cat^ 0.56 ))
2
+ min (600 , round (1.6 * n_cat^ 0.56 ))
3
3
end
4
4
5
5
function _one_emb_sz (catdict, catcol:: Symbol , sz_dict= nothing )
6
- sz_dict = isnothing (sz_dict) ? Dict () : sz_dict
7
- n_cat = length (catdict[catcol])
8
- sz = catcol in keys (sz_dict) ? sz_dict[catcol] : emb_sz_rule (n_cat)
9
- Int64 (n_cat), Int64 (sz)
6
+ sz_dict = isnothing (sz_dict) ? Dict () : sz_dict
7
+ n_cat = length (catdict[catcol])
8
+ sz = catcol in keys (sz_dict) ? sz_dict[catcol] : emb_sz_rule (n_cat)
9
+ Int64 (n_cat)+ 1 , Int64 (sz)
10
10
end
11
11
12
12
function get_emb_sz (catdict, cols; sz_dict= nothing )
13
- [_one_emb_sz (catdict, catcol, sz_dict) for catcol in cols]
13
+ [_one_emb_sz (catdict, catcol, sz_dict) for catcol in cols]
14
14
end
15
15
16
- # function get_emb_sz(td::TableDataset, sz_dict=nothing)
17
- # cols = Tables.columnaccess(td.table) ? Tables.columnnames(td.table) : Tables.columnnames(Tables.rows(td.table)[1])
18
- # [_one_emb_sz(catdict, catcol, sz_dict) for catcol in cols]
19
- # end
20
-
21
16
function sigmoidrange (x, low, high)
22
- @. Flux. sigmoid (x) * (high - low) + low
17
+ @. Flux. sigmoid (x) * (high - low) + low
23
18
end
24
19
25
20
function embeddingbackbone (embedding_sizes, dropoutprob= 0. )
26
21
embedslist = [Embedding (ni => nf) for (ni, nf) in embedding_sizes]
27
- emb_drop = Dropout (dropoutprob)
22
+ emb_drop = dropoutprob == 0. ? identity : Dropout (dropoutprob)
28
23
Chain (
29
24
x -> tuple (eachrow (x)... ),
30
25
Parallel (vcat, embedslist),
@@ -36,40 +31,51 @@ function continuousbackbone(n_cont)
36
31
n_cont > 0 ? BatchNorm (n_cont) : identity
37
32
end
38
33
39
- function TabularModel (
40
- catbackbone,
41
- contbackbone,
42
- layers= [200 , 100 ];
43
- n_cat,
44
- n_cont,
45
- out_sz,
34
+ function classifierbackbone (
35
+ layers;
46
36
ps= 0 ,
47
37
use_bn= true ,
48
38
bn_final= false ,
49
39
act_cls= Flux. relu,
50
- lin_first= true ,
51
- final_activation= identity
52
- )
53
-
54
- tabularbackbone = Parallel (vcat, catbackbone, contbackbone)
55
-
56
- catoutsize = first (Flux. outputsize (catbackbone, (n_cat, 1 )))
40
+ lin_first= true )
57
41
ps = Iterators. cycle (ps)
58
42
classifiers = []
59
43
60
- first_ps, ps = Iterators. peel (ps)
61
- push! (classifiers, linbndrop (catoutsize+ n_cont, first (layers); use_bn= use_bn, p= first_ps, lin_first= lin_first, act= act_cls))
62
-
63
- for (isize, osize, p) in zip (layers[1 : (end - 1 )], layers[2 : (end )], ps)
44
+ for (isize, osize, p) in zip (layers[1 : (end - 1 )], layers[2 : end ], ps)
64
45
layer = linbndrop (isize, osize; use_bn= use_bn, p= p, act= act_cls, lin_first= lin_first)
65
46
push! (classifiers, layer)
66
47
end
67
-
68
- push! (classifiers, linbndrop (last (layers), out_sz; use_bn= bn_final, lin_first= lin_first))
69
-
70
- layers = Chain (
48
+ Chain (classifiers... )
49
+ end
50
+
51
+ function TabularModel (
52
+ catbackbone,
53
+ contbackbone,
54
+ classifierbackbone;
55
+ final_activation= identity)
56
+ tabularbackbone = Parallel (vcat, catbackbone, contbackbone)
57
+ Chain (
71
58
tabularbackbone,
72
- classifiers ... ,
59
+ classifierbackbone ,
73
60
final_activation
74
61
)
75
62
end
63
+
64
+ function TabularModel (
65
+ catcols,
66
+ n_cont:: Number ,
67
+ out_sz:: Number ,
68
+ layers= [200 , 100 ];
69
+ catdict,
70
+ embszs= nothing ,
71
+ ps= 0. )
72
+ embedszs = get_emb_sz (catdict, catcols, sz_dict= embszs)
73
+ catback = embeddingbackbone (embedszs)
74
+ contback = continuousbackbone (n_cont)
75
+
76
+ classifierin = mapreduce (layer -> size (layer. weight)[1 ], + , catback[2 ]. layers, init = n_cont)
77
+ layers = append! ([classifierin], layers, [out_sz])
78
+ classback = classifierbackbone (layers, ps= ps)
79
+
80
+ TabularModel (catback, contback, classback)
81
+ end
0 commit comments