Skip to content

Commit

Permalink
Fixed extraction of missing samples
Browse files Browse the repository at this point in the history
  • Loading branch information
racinmat committed Feb 8, 2021
1 parent edf35be commit d9cc748
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 17 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "JsonGrinder"
uuid = "d201646e-a9c0-11e8-1063-23b139159713"
authors = ["pevnak <[email protected]>", "Matej Racinsky <[email protected]>"]
version = "2.1.3"
version = "2.1.4"

[deps]
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
Expand Down
29 changes: 16 additions & 13 deletions benchmarks/2_1_3/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,18 @@ using Flux, MLDataPattern, Mill, JsonGrinder, JSON, StatsBase, HierarchicalUtils
# downloaded from https://www.kaggle.com/allen-institute-for-ai/CORD-19-research-challenge
# document_parses/pdf_json

samples = [open(JSON.parse, x) for x in readdir("examples/documents", join=true)]
# samples = [open(JSON.parse, x) for x in readdir("examples/documents", join=true)]
# samples = [open(JSON.parse, x) for x in readdir("examples2/documents_100", join=true)]
# samples = [open(JSON.parse, x) for x in readdir("examples2/documents", join=true)]
samples = [open(JSON.parse, x) for x in readdir("examples2/documents", join=true)]
# samples = samples[1:800]
# sch = JsonGrinder.schema(samples)
sch = JsonGrinder.schema(samples)

# generate_html("documents_schema.html", sch)
# delete!(sch.childs,:paper_id)
delete!(sch.childs,:paper_id)
# extractor = suggestextractor(sch, (; key_as_field=13)) # for small dataset
# extractor = suggestextractor(sch, (; key_as_field=300))
JLD2.@load "model_etc.jld2" model sch extractor

extractor = suggestextractor(sch, (; key_as_field=300))
# JLD2.@load "model_etc.jld2" model sch extractor
# JLD2.@load "mb.jld2" mbx mby
s = samples[1]

function author_cite_themself(s)
Expand All @@ -32,11 +32,13 @@ targets = author_cite_themself.(samples)
countmap(targets)
labelnames = unique(targets)

# model = reflectinmodel(sch, extractor,
# k -> Dense(k,20,relu),
# d -> SegmentedMeanMax(d),
# fsm = Dict("" => k -> Dense(k, 2)),
# )
extractor(JsonGrinder.sample_synthetic(sch, empty_dict_vals=true))

model = reflectinmodel(sch, extractor,
k -> Dense(k,20,relu),
d -> SegmentedMeanMax(d),
fsm = Dict("" => k -> Dense(k, 2)),
)

opt = Flux.Optimise.ADAM()
loss(x,y) = Flux.logitcrossentropy(model(x).data, y)
Expand All @@ -63,6 +65,7 @@ ps = Flux.params(model)
# JLD2.@save "model_etc.jld2" model sch extractor
@info "testing the gradient and catobsing"
mbx, mby = minibatch()
# JLD2.@save "mb.jld2" mbx mby
@btime minibatch()
# 178.186 ms (87016 allocations: 4.67 MiB)
# 50.702 ms (48667 allocations: 20.73 MiB)
Expand All @@ -71,7 +74,7 @@ mbx = reduce(catobs, mbx_vec)
@btime reduce(catobs, mbx_vec)
# 5.835 ms (24124 allocations: 1.51 MiB)
loss(mbx, mby)

@info "testing gradient"
gs = gradient(() -> loss(mbx, mby), ps)
@btime gradient(() -> loss(mbx, mby), ps)
Flux.Optimise.update!(opt, ps, gs)
Expand Down
5 changes: 3 additions & 2 deletions src/extractors/extract_keyasfield.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@ struct ExtractKeyAsField{S,V} <: AbstractExtractor
end

# Extract a `missing`/`nothing` value as a single empty bag.
# When `Mill._emptyismissing[]` is set, the bag's data is `missing`; otherwise the
# data node is built from the `extractempty` sentinel so it carries correctly-typed
# (but zero-observation) key/item matrices.
# NOTE(review): this span comes from a stripped diff view — the first BagNode line
# below appears to be the *removed* pre-fix body and the following two lines the
# replacement; confirm against the actual file before relying on this block.
function (e::ExtractKeyAsField)(v::V) where {V<:Union{Missing,Nothing}}
BagNode(ProductNode((key = e.key(missing), item = e.item(missing)))[1:0], [0:-1])
Mill._emptyismissing[] && return BagNode(missing, [0:-1])
BagNode(ProductNode((key = e.key(extractempty), item = e.item(extractempty)))[1:0], [0:-1])
end

# Extract the `ExtractEmpty` sentinel: return a BagNode containing zero bags
# (`Mill.AlignedBags` built from an empty `UnitRange{Int64}` vector), whose data
# ProductNode is produced by forwarding the sentinel to the key/item sub-extractors,
# so the result has the same node structure/types as a non-empty extraction.
function (e::ExtractKeyAsField)(v::ExtractEmpty)
BagNode(ProductNode((key = e.key(v), item = e.item(v))), Mill.AlignedBags(Array{UnitRange{Int64},1}()))
end

function (e::ExtractKeyAsField)(vs::Dict)
isempty(vs) && return(e(missing))
isempty(vs) && return e(missing)
items = map(collect(vs)) do (k,v)
ProductNode((key = e.key(k), item = e.item(v)))
end
Expand Down
2 changes: 1 addition & 1 deletion src/extractors/extractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ struct ExtractArray{T} <: AbstractExtractor
end

# Extract a `missing`/`nothing` value as a single empty bag ([0:-1] = one bag with
# zero observations). With `Mill._emptyismissing[]` enabled the bag data is `missing`;
# otherwise it is a typed-but-empty node built from the `extractempty` sentinel.
function (s::ExtractArray)(v::V) where {V<:Union{Missing, Nothing}}
# NOTE(review): stripped-diff artifact — the next two lines are the pre-/post-change
# versions of the same guard (`return(...)` vs `return ...`); only one exists in the
# real file. Confirm against the repository.
Mill._emptyismissing[] && return(BagNode(missing, [0:-1]))
Mill._emptyismissing[] && return BagNode(missing, [0:-1])
BagNode(s.item(extractempty), [0:-1])
end

Expand Down
46 changes: 46 additions & 0 deletions test/extractors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -320,18 +320,34 @@ end
js = [Dict(randstring(5) => rand()) for _ in 1:1000]
sch = JsonGrinder.schema(js)
ext = JsonGrinder.suggestextractor(sch, (;key_as_field = 500))
@test ext isa JsonGrinder.ExtractKeyAsField

b = ext(js[1])
k = only(keys(js[1]))
@test b.data[:item].data[1] (js[1][k] - ext.item.c) * ext.item.s
@test b.data[:key].data.s[1] == k

orig_emptyismissing = Mill.emptyismissing()

Mill.emptyismissing!(true)
b = ext(nothing)
@test nobs(b) == 1
@test ismissing(b.data)
b = ext(Dict())
@test nobs(b) == 1
@test ismissing(b.data)

Mill.emptyismissing!(false)
b = ext(nothing)
@test nobs(b) == 1
@test nobs(b.data) == 0
@test b.data[:key].data isa Mill.NGramMatrix{String,Int64}
b = ext(Dict())
@test nobs(b) == 1
@test nobs(b.data) == 0
@test b.data[:key].data isa Mill.NGramMatrix{String,Int64}

Mill.emptyismissing!(orig_emptyismissing)

b = ext(extractempty)
@test nobs(b) == 0
Expand Down Expand Up @@ -830,3 +846,33 @@ end
@testset "default_scalar_extractor" begin

end

# @testset "type stability of extraction missings" begin
# j1 = JSON.parse("""{"a": 1}""")
# j2 = JSON.parse("""{"b": "a"}""")
# j3 = JSON.parse("""{"a": 3, "b": "b"}""")
# j4 = JSON.parse("""{"a": 4, "b": "c"}""")
# j5 = JSON.parse("""{"a": 5, "b": "d"}""")
#
# sch = schema([j1,j2,j3,j4,j5])
# ext = suggestextractor(sch, testing_settings)
# m = reflectinmodel(sch, ext)
#
# e1 = ext(j1)
# e2 = ext(j2)
# e3 = ext(j3)
# e12 = catobs(e1, e2)
# e23 = catobs(e2, e3)
# e13 = catobs(e1, e3)
# typeof(e1)
# typeof(e2)
# typeof(e3)
# typeof(e12)
# typeof(e23)
# typeof(e13)
# @test typeof(e1) == typeof(e2)
# @test typeof(e1) == typeof(e3)
# @test typeof(e1) == typeof(e12)
# @test typeof(e1) == typeof(e23)
# @test typeof(e1) == typeof(e13)
# end

2 comments on commit d9cc748

@racinmat
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/29634

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v2.1.4 -m "<description of version>" d9cc74800a40b1e7e9fd19ec37820fc16fe2d31a
git push origin v2.1.4

Please sign in to comment.