diff --git a/src/formats.jl b/src/formats.jl index c3fa437..df6e4e4 100644 --- a/src/formats.jl +++ b/src/formats.jl @@ -71,6 +71,16 @@ struct Float64Format <: AbstractMsgPackFormat end Base.@pure magic_byte(::Type{Float32Format}) = 0xca Base.@pure magic_byte(::Type{Float64Format}) = 0xcb +##### +##### `Cfloat` family (unsupported by the conventional MsgPack spec) +##### + +struct ComplexF32Format <: AbstractMsgPackFormat end +struct ComplexF64Format <: AbstractMsgPackFormat end + +Base.@pure magic_byte(::Type{ComplexF32Format}) = 0x28 +Base.@pure magic_byte(::Type{ComplexF64Format}) = 0x29 + ##### ##### `str` family ##### diff --git a/src/pack.jl b/src/pack.jl index f11dce1..3920c55 100644 --- a/src/pack.jl +++ b/src/pack.jl @@ -159,6 +159,33 @@ function pack_format(io, ::Float64Format, x) write(io, hton(y)) end +##### +##### `CFloatType` +##### + +function pack_type(io, t::ComplexFType, x) + x = to_msgpack(t, x) + x isa ComplexF32 && return pack_format(io, ComplexF32Format(), x) + x isa ComplexF64 && return pack_format(io, ComplexF64Format(), x) + invalid_pack(io, t, x) +end + +function pack_format(io, ::ComplexF32Format, x) + y = Vector{Float32}(undef,2*length(x)) + y[1:2:end] .= Float32(real(x)) + y[2:2:end] .= Float32(imag(x)) + write(io, magic_byte(ComplexF32Format)) + write(io, hton(y)) +end + +function pack_format(io, ::ComplexF64Format, x) + y = Vector{Float64}(undef,2) + y[1] = Float64(real(x)) + y[2] = Float64(imag(x)) + write(io, magic_byte(ComplexF64Format)) + write(io, hton.(y)) +end + ##### ##### `StringType` ##### diff --git a/src/types.jl b/src/types.jl index cac8193..4345437 100644 --- a/src/types.jl +++ b/src/types.jl @@ -13,6 +13,7 @@ The subtypes of `AbstractMsgPackType` are: - [`NilType`](@ref) - [`BooleanType`](@ref) - [`FloatType`](@ref) +- [`ComplexFType`](@ref) - [`StringType`](@ref) - [`BinaryType`](@ref) - [`ArrayType`](@ref) @@ -88,6 +89,24 @@ where `S` may be one of the following types: """ struct FloatType <: AbstractMsgPackType end +""" + CFloatType <: AbstractMsgPackType + +A Julia type corresponding to the MessagePack Complex Float (extension) type. + +If `msgpack_type(T)` is defined to return `CFloatType()`, then `T` must support: + +- `to_msgpack(::CFloatType, ::T)::S` +- `from_msgpack(::Type{T}, ::S)::T` +- standard numeric comparators (`>`, `<`, `==`, etc.) against values of type `S` + +where `S` may be one of the following types: + +- `CFloat32` +- `CFloat64` +""" +struct ComplexFType <: AbstractMsgPackType end + """ StringType <: AbstractMsgPackType @@ -288,6 +307,10 @@ msgpack_type(::Type{Bool}) = BooleanType() msgpack_type(::Type{<:AbstractFloat}) = FloatType() +# Cfloat-y things + +msgpack_type(::Type{<:Complex{T}}) where T = ComplexFType() + # string-y things msgpack_type(::Type{<:AbstractString}) = StringType() diff --git a/src/unpack.jl b/src/unpack.jl index 5aff60e..fc7b695 100644 --- a/src/unpack.jl +++ b/src/unpack.jl @@ -75,6 +75,10 @@ function _unpack_any(io, byte, ::Type{T}; strict) where {T} return unpack_format(io, Float32Format(), T) elseif byte === magic_byte(Float64Format) return unpack_format(io, Float64Format(), T) + elseif byte === magic_byte(ComplexF32Format) + return unpack_format(io, ComplexF32Format(), T) + elseif byte === magic_byte(ComplexF64Format) + return unpack_format(io, ComplexF64Format(), T) elseif byte === magic_byte(Str8Format) return unpack_format(io, Str8Format(), T) elseif byte === magic_byte(Str16Format) @@ -318,6 +322,26 @@ unpack_format(io, ::Float32Format, ::Type{T}) where {T<:Skip} = (skip(io, 4); T( unpack_format(io, ::Float64Format, ::Type{T}) where {T} = from_msgpack(T, ntoh(read(io, Float64))) unpack_format(io, ::Float64Format, ::Type{T}) where {T<:Skip} = (skip(io, 8); T()) +##### +##### `CFloatType` +##### + +function unpack_type(io, byte, t::ComplexFType, ::Type{T}; strict) where {T} + if byte === magic_byte(ComplexF32Format) + return unpack_format(io, ComplexF32Format(), T) + elseif byte === magic_byte(ComplexF64Format) + return unpack_format(io, ComplexF64Format(), T) + else + invalid_unpack(io, byte, t, T) + end +end + +unpack_format(io, ::ComplexF32Format, ::Type{T}) where {T} = from_msgpack(T, ntoh(read(io, ComplexF32))) +unpack_format(io, ::ComplexF32Format, ::Type{T}) where {T<:Skip} = (skip(io, 4); T()) + +unpack_format(io, ::ComplexF64Format, ::Type{T}) where {T} = from_msgpack(T, ntoh(read(io, ComplexF64))) +unpack_format(io, ::ComplexF64Format, ::Type{T}) where {T<:Skip} = (skip(io, 8); T()) + ##### ##### `StringType` ##### @@ -404,7 +428,7 @@ unpack_format(io, f::ArrayFixFormat, ::Type{T}, strict) where {T} = _unpack_arra _eltype(T) = eltype(T) function _unpack_array(io, n, ::Type{T}, strict) where {T} - E = _eltype(T) +E = _eltype(T) e = msgpack_type(E) result = Vector{E}(undef, n) for i in 1:n