Skip to content

Commit f565675

Browse files
committed
updated tabular model, and added tests
1 parent a081616 commit f565675

File tree

5 files changed

+66
-4
lines changed

5 files changed

+66
-4
lines changed

src/models/Models.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ include("tabularmodel.jl")
1818

1919

2020
export xresnet18, xresnet50, UNetDynamic,
21-
TabularModel, get_emb_sz, embeddingbackbone, continuousbackbone
21+
TabularModel, get_emb_sz, embeddingbackbone, continuousbackbone, sigmoidrange
2222

2323

2424
end

src/models/tabularmodel.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,12 @@ end
1818
# [_one_emb_sz(catdict, catcol, sz_dict) for catcol in cols]
1919
# end
2020

21-
function embeddingbackbone(embedding_sizes, dropoutprob)
22-
embedslist = [Embedding(ni, nf) for (ni, nf) in embedding_sizes]
21+
function sigmoidrange(x, low, high)
22+
@. Flux.sigmoid(x) * (high - low) + low
23+
end
24+
25+
function embeddingbackbone(embedding_sizes, dropoutprob=0.)
26+
embedslist = [Embedding(ni => nf) for (ni, nf) in embedding_sizes]
2327
emb_drop = Dropout(dropoutprob)
2428
Chain(
2529
x -> tuple(eachrow(x)...),
@@ -35,7 +39,7 @@ end
3539
function TabularModel(
3640
catbackbone,
3741
contbackbone,
38-
layers;
42+
layers=[200, 100];
3943
n_cat,
4044
n_cont,
4145
out_sz,

test/imports.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using FastAI: Image, Keypoints, Mask, testencoding, Label, OneHot, ProjectiveTra
66
encodedblock, decodedblock, encode, decode, mockblock
77
using FilePathsBase
88
using FastAI.Datasets
9+
using FastAI.Models
910
using DLPipelines
1011
import DataAugmentation
1112
import DataAugmentation: getbounds

test/models/tabularmodel.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
include("../imports.jl")
2+
3+
@testset ExtendedTestSet "TabularModel Components" begin
4+
@testset ExtendedTestSet "embeddingbackbone" begin
5+
embed_szs = [(5, 10), (100, 30), (2, 30)]
6+
embeds = embeddingbackbone(embed_szs, 0.)
7+
x = [rand(1:n) for (n, _) in embed_szs]
8+
9+
@test size(embeds(x)) == (70, 1)
10+
end
11+
12+
@testset ExtendedTestSet "continuousbackbone" begin
13+
n = 5
14+
contback = continuousbackbone(n)
15+
x = rand(5, 1)
16+
@test size(contback(x)) == (5, 1)
17+
end
18+
19+
@testset ExtendedTestSet "TabularModel" begin
20+
n = 5
21+
embed_szs = [(5, 10), (100, 30), (2, 30)]
22+
23+
embeds = embeddingbackbone(embed_szs, 0.)
24+
contback = continuousbackbone(n)
25+
26+
tm = TabularModel(
27+
embeds,
28+
contback,
29+
[200, 100],
30+
n_cat=3,
31+
n_cont=5,
32+
out_sz=4
33+
)
34+
x = ([rand(1:n) for (n, _) in embed_szs], rand(5, 1))
35+
@test size(tm(x)) == (4, 1)
36+
37+
tm2 = TabularModel(
38+
embeds,
39+
contback,
40+
[200, 100],
41+
n_cat=3,
42+
n_cont=5,
43+
out_sz=4,
44+
final_activation=x->FastAI.sigmoidrange(x, 2, 5)
45+
)
46+
y2 = tm2(x)
47+
@test all(y2.> 2) && all(y2.<5)
48+
end
49+
end
50+
51+

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,4 +55,10 @@ include("imports.jl")
5555
end
5656
# TODO: test learning rate finder
5757
end
58+
59+
@testset ExtendedTestSet "models/" begin
60+
@testset ExtendedTestSet "tabularmodel.jl" begin
61+
include("models/tabularmodel.jl")
62+
end
63+
end
5864
end

0 commit comments

Comments
 (0)