Skip to content

Commit fa5c563

Browse files
committed
updated tabular model
1 parent 4852dcf commit fa5c563

File tree

1 file changed

+17
-43
lines changed

1 file changed

+17
-43
lines changed

src/models/tabularmodel.jl

Lines changed: 17 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@ function _one_emb_sz(catdict, catcol::Symbol, sz_dict=nothing)
66
sz_dict = isnothing(sz_dict) ? Dict() : sz_dict
77
n_cat = length(catdict[catcol])
88
sz = catcol in keys(sz_dict) ? sz_dict[catcol] : emb_sz_rule(n_cat)
9-
n_cat, sz
9+
Int64(n_cat), Int64(sz)
1010
end
1111

12-
function get_emb_sz(catdict, cols, sz_dict=nothing)
12+
function get_emb_sz(catdict, cols; sz_dict=nothing)
1313
[_one_emb_sz(catdict, catcol, sz_dict) for catcol in cols]
1414
end
1515

@@ -18,21 +18,12 @@ end
1818
# [_one_emb_sz(catdict, catcol, sz_dict) for catcol in cols]
1919
# end
2020

21-
struct TabularModel
22-
embeds
23-
emb_drop
24-
bn_cont
25-
n_emb
26-
n_cont
27-
layers
28-
end
29-
3021
function TabularModel(
3122
layers;
3223
emb_szs,
33-
n_cont,
24+
n_cont::Int64,
3425
out_sz,
35-
ps::Union{Tuple, Vector, Number, Nothing}=nothing,
26+
ps::Union{Tuple, Vector, Number}=0,
3627
embed_p::Float64=0.,
3728
y_range=nothing,
3829
use_bn::Bool=true,
@@ -41,39 +32,22 @@ function TabularModel(
4132
act_cls=Flux.relu,
4233
lin_first::Bool=true)
4334

44-
n_cont = Int64(n_cont)
45-
if isnothing(ps)
46-
ps = zeros(length(layers))
47-
end
48-
if ps isa Number
49-
ps = fill(ps, length(layers))
50-
end
5135
embedslist = [Embedding(ni, nf) for (ni, nf) in emb_szs]
5236
emb_drop = Dropout(embed_p)
53-
bn_cont = bn_cont ? BatchNorm(n_cont) : false
37+
embeds = Chain(x -> ntuple(i -> x[i, :], length(emb_szs)), Parallel(vcat, embedslist...), emb_drop)
38+
39+
bn_cont = bn_cont ? BatchNorm(n_cont) : identity
40+
5441
n_emb = sum(size(embedlayer.weight)[1] for embedlayer in embedslist)
55-
sizes = append!(zeros(0), [n_emb+n_cont], layers, [out_sz])
56-
actns = append!([], [act_cls for i in 1:(length(sizes)-1)], [nothing])
57-
_layers = [linbndrop(Int64(sizes[i]), Int64(sizes[i+1]), use_bn=(use_bn && ((i!=(length(actns)-1)) || bn_final)), p=p, act=a, lin_first=lin_first) for (i, (p, a)) in enumerate(zip(push!(ps, 0.), actns))]
58-
if !isnothing(y_range)
59-
push!(_layers, Chain(@. x->Flux.sigmoid(x) * (y_range[2] - y_range[1]) + y_range[1]))
60-
end
61-
layers = Chain(_layers...)
62-
TabularModel(embedslist, emb_drop, bn_cont, n_emb, n_cont, layers)
63-
end
42+
sizes = append!(zeros(0), [n_emb+n_cont], layers)
43+
actns = append!([], [act_cls for i in 1:(length(sizes)-1)])
6444

65-
function (tm::TabularModel)(x)
66-
x_cat, x_cont = x
67-
if tm.n_emb != 0
68-
x = [e(x_cat[i, :]) for (i, e) in enumerate(tm.embeds)]
69-
x = vcat(x...)
70-
x = tm.emb_drop(x)
45+
_layers = []
46+
for (i, (p, a)) in enumerate(zip(Iterators.cycle(ps), actns))
47+
layer = linbndrop(Int64(sizes[i]), Int64(sizes[i+1]), use_bn=use_bn, p=p, act=a, lin_first=lin_first)
48+
push!(_layers, layer)
7149
end
72-
if tm.n_cont != 0
73-
if (tm.bn_cont != false)
74-
x_cont = tm.bn_cont(x_cont)
75-
end
76-
x = tm.n_emb!=0 ? vcat(x, x_cont) : x_cont
77-
end
78-
tm.layers(x)
50+
push!(_layers, linbndrop(Int64(last(sizes)), Int64(out_sz), use_bn=bn_final, lin_first=lin_first))
51+
layers = isnothing(y_range) ? Chain(Parallel(vcat, embeds, bn_cont), _layers...) : Chain(Parallel(vcat, embeds, bn_cont), _layers..., @. x->Flux.sigmoid(x) * (y_range[2] - y_range[1]) + y_range[1])
52+
layers
7953
end

0 commit comments

Comments
 (0)