Skip to content

Commit 04d27d4

Browse files
committed
added classifierbackbone
1 parent f565675 commit 04d27d4

File tree

1 file changed

+42
-36
lines changed

1 file changed

+42
-36
lines changed

Diff for: src/models/tabularmodel.jl

+42-36
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,25 @@
11
function emb_sz_rule(n_cat)
2-
min(600, round(1.6 * n_cat^0.56))
2+
min(600, round(1.6 * n_cat^0.56))
33
end
44

55
function _one_emb_sz(catdict, catcol::Symbol, sz_dict=nothing)
6-
sz_dict = isnothing(sz_dict) ? Dict() : sz_dict
7-
n_cat = length(catdict[catcol])
8-
sz = catcol in keys(sz_dict) ? sz_dict[catcol] : emb_sz_rule(n_cat)
9-
Int64(n_cat), Int64(sz)
6+
sz_dict = isnothing(sz_dict) ? Dict() : sz_dict
7+
n_cat = length(catdict[catcol])
8+
sz = catcol in keys(sz_dict) ? sz_dict[catcol] : emb_sz_rule(n_cat)
9+
Int64(n_cat)+1, Int64(sz)
1010
end
1111

1212
function get_emb_sz(catdict, cols; sz_dict=nothing)
13-
[_one_emb_sz(catdict, catcol, sz_dict) for catcol in cols]
13+
[_one_emb_sz(catdict, catcol, sz_dict) for catcol in cols]
1414
end
1515

16-
# function get_emb_sz(td::TableDataset, sz_dict=nothing)
17-
# cols = Tables.columnaccess(td.table) ? Tables.columnnames(td.table) : Tables.columnnames(Tables.rows(td.table)[1])
18-
# [_one_emb_sz(catdict, catcol, sz_dict) for catcol in cols]
19-
# end
20-
2116
function sigmoidrange(x, low, high)
22-
@. Flux.sigmoid(x) * (high - low) + low
17+
@. Flux.sigmoid(x) * (high - low) + low
2318
end
2419

2520
function embeddingbackbone(embedding_sizes, dropoutprob=0.)
2621
embedslist = [Embedding(ni => nf) for (ni, nf) in embedding_sizes]
27-
emb_drop = Dropout(dropoutprob)
22+
emb_drop = dropoutprob==0. ? identity : Dropout(dropoutprob)
2823
Chain(
2924
x -> tuple(eachrow(x)...),
3025
Parallel(vcat, embedslist),
@@ -36,40 +31,51 @@ function continuousbackbone(n_cont)
3631
n_cont > 0 ? BatchNorm(n_cont) : identity
3732
end
3833

39-
function TabularModel(
40-
catbackbone,
41-
contbackbone,
42-
layers=[200, 100];
43-
n_cat,
44-
n_cont,
45-
out_sz,
34+
function classifierbackbone(
35+
layers;
4636
ps=0,
4737
use_bn=true,
4838
bn_final=false,
4939
act_cls=Flux.relu,
50-
lin_first=true,
51-
final_activation=identity
52-
)
53-
54-
tabularbackbone = Parallel(vcat, catbackbone, contbackbone)
55-
56-
catoutsize = first(Flux.outputsize(catbackbone, (n_cat, 1)))
40+
lin_first=true)
5741
ps = Iterators.cycle(ps)
5842
classifiers = []
5943

60-
first_ps, ps = Iterators.peel(ps)
61-
push!(classifiers, linbndrop(catoutsize+n_cont, first(layers); use_bn=use_bn, p=first_ps, lin_first=lin_first, act=act_cls))
62-
63-
for (isize, osize, p) in zip(layers[1:(end-1)], layers[2:(end)], ps)
44+
for (isize, osize, p) in zip(layers[1:(end-1)], layers[2:end], ps)
6445
layer = linbndrop(isize, osize; use_bn=use_bn, p=p, act=act_cls, lin_first=lin_first)
6546
push!(classifiers, layer)
6647
end
67-
68-
push!(classifiers, linbndrop(last(layers), out_sz; use_bn=bn_final, lin_first=lin_first))
69-
70-
layers = Chain(
48+
Chain(classifiers...)
49+
end
50+
51+
function TabularModel(
52+
catbackbone,
53+
contbackbone,
54+
classifierbackbone;
55+
final_activation=identity)
56+
tabularbackbone = Parallel(vcat, catbackbone, contbackbone)
57+
Chain(
7158
tabularbackbone,
72-
classifiers...,
59+
classifierbackbone,
7360
final_activation
7461
)
7562
end
63+
64+
function TabularModel(
65+
catcols,
66+
n_cont::Number,
67+
out_sz::Number,
68+
layers=[200, 100];
69+
catdict,
70+
embszs=nothing,
71+
ps=0.)
72+
embedszs = get_emb_sz(catdict, catcols, sz_dict=embszs)
73+
catback = embeddingbackbone(embedszs)
74+
contback = continuousbackbone(n_cont)
75+
76+
classifierin = mapreduce(layer -> size(layer.weight)[1], +, catback[2].layers, init = n_cont)
77+
layers = append!([classifierin], layers, [out_sz])
78+
classback = classifierbackbone(layers, ps=ps)
79+
80+
TabularModel(catback, contback, classback)
81+
end

0 commit comments

Comments
 (0)