Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit 63f0c58

Browse files
authored
Merge pull request #63 from yuehhua/chebyshev
Add Chebyshev transform
2 parents 8ad7a36 + 98192ab commit 63f0c58

File tree

6 files changed

+81
-7
lines changed

6 files changed

+81
-7
lines changed

src/Transform/Transform.jl

+1
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@ export
1717
abstract type AbstractTransform end
1818

1919
include("fourier_transform.jl")
20+
include("chebyshev_transform.jl")

src/Transform/chebyshev_transform.jl

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
export ChebyshevTransform
2+
3+
struct ChebyshevTransform{N, S} <: AbstractTransform
4+
modes::NTuple{N, S} # N == ndims(x)
5+
end
6+
7+
Base.ndims(::ChebyshevTransform{N}) where {N} = N
8+
9+
function transform(t::ChebyshevTransform{N}, 𝐱::AbstractArray) where {N}
10+
return FFTW.r2r(𝐱, FFTW.REDFT10, 1:N) # [size(x)..., in_chs, batch]
11+
end
12+
13+
function truncate_modes(t::ChebyshevTransform, 𝐱̂::AbstractArray)
14+
return view(𝐱̂, map(d -> 1:d, t.modes)..., :, :) # [t.modes..., in_chs, batch]
15+
end
16+
17+
function inverse(t::ChebyshevTransform{N}, 𝐱̂::AbstractArray) where {N}
18+
normalized_𝐱̂ = 𝐱̂ ./ (prod(2 .* (size(𝐱̂)[1:N] .- 1)))
19+
return FFTW.r2r(normalized_𝐱̂, FFTW.REDFT01, 1:N) # [size(x)..., in_chs, batch]
20+
end
21+
22+
function ChainRulesCore.rrule(::typeof(FFTW.r2r), x::AbstractArray, kind, dims)
23+
y = FFTW.r2r(x, kind, dims)
24+
r2r_pullback(Δ) = (NoTangent(), ∇r2r(unthunk(Δ), kind, dims), NoTangent(), NoTangent())
25+
return y, r2r_pullback
26+
end
27+
28+
function ∇r2r::AbstractArray{T}, kind, dims) where {T}
29+
# derivative of r2r turns out to be r2r
30+
Δx = FFTW.r2r(Δ, kind, dims)
31+
32+
# rank 4 correction: needs @bischtob to elaborate the reason using this.
33+
# (M,) = size(Δ)[dims]
34+
# a1 = fill!(similar(Δ, M), one(T))
35+
# CUDA.@allowscalar a1[1] = a1[end] = zero(T)
36+
37+
# a2 = fill!(similar(Δ, M), one(T))
38+
# a2[1:2:end] .= -one(T)
39+
# CUDA.@allowscalar a2[1] = a2[end] = zero(T)
40+
41+
# e1 = fill!(similar(Δ, M), zero(T))
42+
# CUDA.@allowscalar e1[1] = one(T)
43+
44+
# eN = fill!(similar(Δ, M), zero(T))
45+
# CUDA.@allowscalar eN[end] = one(T)
46+
47+
# Δx .+= @. a1' * sum(e1' .* Δ, dims=2) - a2' * sum(eN' .* Δ, dims=2)
48+
# Δx .+= @. eN' * sum(a2' .* Δ, dims=2) - e1' * sum(a1' .* Δ, dims=2)
49+
return Δx
50+
end

src/Transform/fourier_transform.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ function low_pass(ft::FourierTransform, 𝐱_fft::AbstractArray)
1414
return view(𝐱_fft, map(d -> 1:d, ft.modes)..., :, :) # [ft.modes..., in_chs, batch]
1515
end
1616

17-
truncate_modes(args...) = low_pass(args...)
17+
truncate_modes(ft::FourierTransform, 𝐱_fft::AbstractArray) = low_pass(ft, 𝐱_fft)
1818

1919
function inverse(ft::FourierTransform, 𝐱_fft::AbstractArray)
2020
return real(ifft(𝐱_fft, 1:ndims(ft))) # [size(x_fft)..., out_chs, batch]

test/Transform/Transform.jl

+4-1
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
1-
@testset "Transform" begin include("fourier_transform.jl") end
1+
@testset "Transform" begin
2+
include("fourier_transform.jl")
3+
include("chebyshev_transform.jl")
4+
end

test/Transform/chebyshev_transform.jl

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
@testset "Chebyshev transform" begin
2+
ch = 6
3+
batch = 7
4+
𝐱 = rand(30, 40, 50, ch, batch)
5+
6+
t = ChebyshevTransform((3, 4, 5))
7+
8+
@test ndims(t) == 3
9+
@test size(transform(t, 𝐱)) == (30, 40, 50, ch, batch)
10+
@test size(truncate_modes(t, transform(t, 𝐱))) == (3, 4, 5, ch, batch)
11+
@test size(inverse(t, truncate_modes(t, transform(t, 𝐱)))) == (3, 4, 5, ch, batch)
12+
13+
g = gradient(x -> sum(inverse(t, truncate_modes(t, transform(t, x)))), 𝐱)
14+
@test size(g[1]) == (30, 40, 50, ch, batch)
15+
end

test/Transform/fourier_transform.jl

+10-5
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
1-
@testset "fourier transform" begin
2-
𝐱 = rand(30, 40, 50, 6, 7) # where ch == 6 and batch == 7
1+
@testset "Fourier transform" begin
2+
ch = 6
3+
batch = 7
4+
𝐱 = rand(30, 40, 50, ch, batch)
35

46
ft = FourierTransform((3, 4, 5))
57

6-
@test size(transform(ft, 𝐱)) == (30, 40, 50, 6, 7)
7-
@test size(truncate_modes(ft, transform(ft, 𝐱))) == (3, 4, 5, 6, 7)
8-
@test size(inverse(ft, truncate_modes(ft, transform(ft, 𝐱)))) == (3, 4, 5, 6, 7)
8+
@test size(transform(ft, 𝐱)) == (30, 40, 50, ch, batch)
9+
@test size(truncate_modes(ft, transform(ft, 𝐱))) == (3, 4, 5, ch, batch)
10+
@test size(inverse(ft, truncate_modes(ft, transform(ft, 𝐱)))) == (3, 4, 5, ch, batch)
11+
12+
g = Zygote.gradient(x -> sum(inverse(ft, truncate_modes(ft, transform(ft, x)))), 𝐱)
13+
@test size(g[1]) == (30, 40, 50, ch, batch)
914
end

0 commit comments

Comments
 (0)