@@ -6,10 +6,10 @@ function _one_emb_sz(catdict, catcol::Symbol, sz_dict=nothing)
6
6
sz_dict = isnothing (sz_dict) ? Dict () : sz_dict
7
7
n_cat = length (catdict[catcol])
8
8
sz = catcol in keys (sz_dict) ? sz_dict[catcol] : emb_sz_rule (n_cat)
9
- n_cat, sz
9
+ Int64 ( n_cat), Int64 (sz)
10
10
end
11
11
12
- function get_emb_sz (catdict, cols, sz_dict= nothing )
12
+ function get_emb_sz (catdict, cols; sz_dict= nothing )
13
13
[_one_emb_sz (catdict, catcol, sz_dict) for catcol in cols]
14
14
end
15
15
18
18
# [_one_emb_sz(catdict, catcol, sz_dict) for catcol in cols]
19
19
# end
20
20
21
- struct TabularModel
22
- embeds
23
- emb_drop
24
- bn_cont
25
- n_emb
26
- n_cont
27
- layers
28
- end
29
-
30
21
function TabularModel (
31
22
layers;
32
23
emb_szs,
33
- n_cont,
24
+ n_cont:: Int64 ,
34
25
out_sz,
35
- ps:: Union{Tuple, Vector, Number, Nothing} = nothing ,
26
+ ps:: Union{Tuple, Vector, Number} = 0 ,
36
27
embed_p:: Float64 = 0. ,
37
28
y_range= nothing ,
38
29
use_bn:: Bool = true ,
@@ -41,39 +32,22 @@ function TabularModel(
41
32
act_cls= Flux. relu,
42
33
lin_first:: Bool = true )
43
34
44
- n_cont = Int64 (n_cont)
45
- if isnothing (ps)
46
- ps = zeros (length (layers))
47
- end
48
- if ps isa Number
49
- ps = fill (ps, length (layers))
50
- end
51
35
embedslist = [Embedding (ni, nf) for (ni, nf) in emb_szs]
52
36
emb_drop = Dropout (embed_p)
53
- bn_cont = bn_cont ? BatchNorm (n_cont) : false
37
+ embeds = Chain (x -> ntuple (i -> x[i, :], length (emb_szs)), Parallel (vcat, embedslist... ), emb_drop)
38
+
39
+ bn_cont = bn_cont ? BatchNorm (n_cont) : identity
40
+
54
41
n_emb = sum (size (embedlayer. weight)[1 ] for embedlayer in embedslist)
55
- sizes = append! (zeros (0 ), [n_emb+ n_cont], layers, [out_sz])
56
- actns = append! ([], [act_cls for i in 1 : (length (sizes)- 1 )], [nothing ])
57
- _layers = [linbndrop (Int64 (sizes[i]), Int64 (sizes[i+ 1 ]), use_bn= (use_bn && ((i!= (length (actns)- 1 )) || bn_final)), p= p, act= a, lin_first= lin_first) for (i, (p, a)) in enumerate (zip (push! (ps, 0. ), actns))]
58
- if ! isnothing (y_range)
59
- push! (_layers, Chain (@. x-> Flux. sigmoid (x) * (y_range[2 ] - y_range[1 ]) + y_range[1 ]))
60
- end
61
- layers = Chain (_layers... )
62
- TabularModel (embedslist, emb_drop, bn_cont, n_emb, n_cont, layers)
63
- end
42
+ sizes = append! (zeros (0 ), [n_emb+ n_cont], layers)
43
+ actns = append! ([], [act_cls for i in 1 : (length (sizes)- 1 )])
64
44
65
- function (tm:: TabularModel )(x)
66
- x_cat, x_cont = x
67
- if tm. n_emb != 0
68
- x = [e (x_cat[i, :]) for (i, e) in enumerate (tm. embeds)]
69
- x = vcat (x... )
70
- x = tm. emb_drop (x)
45
+ _layers = []
46
+ for (i, (p, a)) in enumerate (zip (Iterators. cycle (ps), actns))
47
+ layer = linbndrop (Int64 (sizes[i]), Int64 (sizes[i+ 1 ]), use_bn= use_bn, p= p, act= a, lin_first= lin_first)
48
+ push! (_layers, layer)
71
49
end
72
- if tm. n_cont != 0
73
- if (tm. bn_cont != false )
74
- x_cont = tm. bn_cont (x_cont)
75
- end
76
- x = tm. n_emb!= 0 ? vcat (x, x_cont) : x_cont
77
- end
78
- tm. layers (x)
50
+ push! (_layers, linbndrop (Int64 (last (sizes)), Int64 (out_sz), use_bn= bn_final, lin_first= lin_first))
51
+ layers = isnothing (y_range) ? Chain (Parallel (vcat, embeds, bn_cont), _layers... ) : Chain (Parallel (vcat, embeds, bn_cont), _layers... , @. x-> Flux. sigmoid (x) * (y_range[2 ] - y_range[1 ]) + y_range[1 ])
52
+ layers
79
53
end
0 commit comments