diff --git a/src/compression.jl b/src/compression.jl index 2492c85..4ea160a 100644 --- a/src/compression.jl +++ b/src/compression.jl @@ -37,7 +37,9 @@ function GzipCompressor(;level::Integer=Z_DEFAULT_COMPRESSION, elseif !(9 ≤ windowbits ≤ 15) throw(ArgumentError("windowbits must be within 9..15")) end - return GzipCompressor(ZStream(), level, windowbits+16) + zstream = ZStream() + finalizer(compress_finalizer!, zstream) + return GzipCompressor(zstream, level, windowbits+16) end const GzipCompressorStream{S} = TranscodingStream{GzipCompressor,S} where S<:IO @@ -85,7 +87,9 @@ function ZlibCompressor(;level::Integer=Z_DEFAULT_COMPRESSION, elseif !(9 ≤ windowbits ≤ 15) throw(ArgumentError("windowbits must be within 9..15")) end - return ZlibCompressor(ZStream(), level, windowbits) + zstream = ZStream() + finalizer(compress_finalizer!, zstream) + return ZlibCompressor(zstream, level, windowbits) end const ZlibCompressorStream{S} = TranscodingStream{ZlibCompressor,S} where S<:IO @@ -133,7 +137,9 @@ function DeflateCompressor(;level::Integer=Z_DEFAULT_COMPRESSION, elseif !(9 ≤ windowbits ≤ 15) throw(ArgumentError("windowbits must be within 9..15")) end - return DeflateCompressor(ZStream(), level, -Int(windowbits)) + zstream = ZStream() + finalizer(compress_finalizer!, zstream) + return DeflateCompressor(zstream, level, -Int(windowbits)) end const DeflateCompressorStream{S} = TranscodingStream{DeflateCompressor,S} where S<:IO @@ -155,17 +161,6 @@ end # Methods # ------- -function TranscodingStreams.finalize(codec::CompressorCodec) - zstream = codec.zstream - if zstream.state != C_NULL - code = deflate_end!(zstream) - if code != Z_OK - zerror(zstream, code) - end - end - return -end - function TranscodingStreams.startproc(codec::CompressorCodec, state::Symbol, error_ref::Error) if codec.zstream.state == C_NULL code = deflate_init!(codec.zstream, codec.level, codec.windowbits) diff --git a/src/decompression.jl b/src/decompression.jl index 67bff25..65a0c9d 100644 --- a/src/decompression.jl +++ b/src/decompression.jl @@ -35,7 +35,9 @@ function GzipDecompressor(;windowbits::Integer=Z_DEFAULT_WINDOWBITS, gziponly::B if !(8 ≤ windowbits ≤ 15) throw(ArgumentError("windowbits must be within 8..15")) end - return GzipDecompressor(ZStream(), windowbits+(gziponly ? 16 : 32)) + zstream = ZStream() + finalizer(decompress_finalizer!, zstream) + return GzipDecompressor(zstream, windowbits+(gziponly ? 16 : 32)) end const GzipDecompressorStream{S} = TranscodingStream{GzipDecompressor,S} where S<:IO @@ -78,7 +80,9 @@ function ZlibDecompressor(;windowbits::Integer=Z_DEFAULT_WINDOWBITS) if !(8 ≤ windowbits ≤ 15) throw(ArgumentError("windowbits must be within 8..15")) end - return ZlibDecompressor(ZStream(), windowbits) + zstream = ZStream() + finalizer(decompress_finalizer!, zstream) + return ZlibDecompressor(zstream, windowbits) end const ZlibDecompressorStream{S} = TranscodingStream{ZlibDecompressor,S} where S<:IO @@ -121,7 +125,9 @@ function DeflateDecompressor(;windowbits::Integer=Z_DEFAULT_WINDOWBITS) if !(8 ≤ windowbits ≤ 15) throw(ArgumentError("windowbits must be within 8..15")) end - return DeflateDecompressor(ZStream(), -Int(windowbits)) + zstream = ZStream() + finalizer(decompress_finalizer!, zstream) + return DeflateDecompressor(zstream, -Int(windowbits)) end const DeflateDecompressorStream{S} = TranscodingStream{DeflateDecompressor,S} where S<:IO @@ -143,17 +149,6 @@ end # Methods # ------- -function TranscodingStreams.finalize(codec::DecompressorCodec) - zstream = codec.zstream - if zstream.state != C_NULL - code = inflate_end!(zstream) - if code != Z_OK - zerror(zstream, code) - end - end - return -end - function TranscodingStreams.startproc(codec::DecompressorCodec, ::Symbol, error_ref::Error) # indicate that no input data is being provided for future zlib compat codec.zstream.next_in = C_NULL diff --git a/src/libz.jl b/src/libz.jl index 507b7e1..ca62fb8 100644 --- a/src/libz.jl +++ b/src/libz.jl @@ -23,6 +23,18 @@ mutable struct ZStream reserved::Culong end +@assert typemax(Csize_t) ≥ typemax(Cuint) + +function zalloc(::Ptr{Cvoid}, items::Cuint, size::Cuint)::Ptr{Cvoid} + s, f = Base.Checked.mul_with_overflow(items, size) + if f + C_NULL + else + ccall(:jl_malloc, Ptr{Cvoid}, (Csize_t,), s%Csize_t) + end +end +zfree(::Ptr{Cvoid}, p::Ptr{Cvoid}) = ccall(:jl_free, Cvoid, (Ptr{Cvoid},), p) + function ZStream() ZStream( # input @@ -32,7 +44,9 @@ function ZStream() # message and state C_NULL, C_NULL, # memory allocation - C_NULL, C_NULL, C_NULL, + @cfunction(zalloc, Ptr{Cvoid}, (Ptr{Cvoid}, Cuint, Cuint)), + @cfunction(zfree, Cvoid, (Ptr{Cvoid}, Ptr{Cvoid})), + C_NULL, # data type, adler and reserved 0, 0, 0) end @@ -83,6 +97,11 @@ function deflate_end!(zstream::ZStream) return ccall((:deflateEnd, libz), Cint, (Ref{ZStream},), zstream) end +function compress_finalizer!(zstream::ZStream) + deflate_end!(zstream) + nothing +end + function deflate!(zstream::ZStream, flush::Integer) return ccall((:deflate, libz), Cint, (Ref{ZStream}, Cint), zstream, flush) end @@ -99,6 +118,11 @@ function inflate_end!(zstream::ZStream) return ccall((:inflateEnd, libz), Cint, (Ref{ZStream},), zstream) end +function decompress_finalizer!(zstream::ZStream) + inflate_end!(zstream) + nothing +end + function inflate!(zstream::ZStream, flush::Integer) return ccall((:inflate, libz), Cint, (Ref{ZStream}, Cint), zstream, flush) end diff --git a/test/big-mem-tests.jl b/test/big-mem-tests.jl index a8cd1b1..e728fa9 100644 --- a/test/big-mem-tests.jl +++ b/test/big-mem-tests.jl @@ -6,16 +6,21 @@ using Test using CodecZlib -# Enable this when https://github.com/JuliaIO/CodecZlib.jl/issues/88 is fixed. -# @testset "memory leak" begin -# function foo() -# for i in 1:1000000 -# c = transcode(GzipCompressor(), zeros(UInt8,16)) -# u = transcode(GzipDecompressor(), c) -# end -# end -# foo() -# end +@testset "memory leak" begin + function foo() + for (encode, decode) in [ + (GzipCompressor, GzipDecompressor), + (ZlibCompressor, ZlibDecompressor), + (DeflateCompressor, DeflateDecompressor), + ] + for i in 1:1000000 + c = transcode(encode(), zeros(UInt8,16)) + u = transcode(decode(), c) + end + end + end + foo() +end @testset "Big Memory Tests" begin Sys.WORD_SIZE == 64 || error("tests require 64 bit word size")