Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ForwardDiff for lattice strain DFPT response #1054

Merged
merged 13 commits into from
Mar 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/terms/Hamiltonian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,11 @@ function LinearAlgebra.mul!(Hψ, H::Hamiltonian, ψ)
end
end
# need `deepcopy` here to copy the elements of the array of arrays ψ (not just pointers)
Base.:*(H::Hamiltonian, ψ) = mul!(deepcopy(ψ), H, ψ)
function Base.:*(H::Hamiltonian, ψ)
# This allocates new memory for the result of promoted eltype
result = one(eltype(H.basis)) * ψ
mul!(result, H, ψ)
end

# Loop through bands, IFFT to get ψ in real space, loop through terms, FFT and accumulate into Hψ
# For the common DftHamiltonianBlock there is an optimized version below
Expand Down
27 changes: 14 additions & 13 deletions src/workarounds/forwarddiff_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,35 +157,32 @@
basis_primal = construct_value(basis_dual)
scfres = self_consistent_field(basis_primal; kwargs...)

## Compute external perturbation (contained in ham_dual) and from matvec with bands
# Compute explicit density perturbation (including strain) due to normalization
ρ_basis = compute_density(basis_dual, scfres.ψ, scfres.occupation)

Check warning on line 161 in src/workarounds/forwarddiff_rules.jl

View check run for this annotation

Codecov / codecov/patch

src/workarounds/forwarddiff_rules.jl#L161

Added line #L161 was not covered by tests

# Compute external perturbation (contained in ham_dual)
Hψ_dual = let
occupation_dual = [T.(occk) for occk in scfres.occupation]
ψ_dual = [Complex.(T.(real(ψk)), T.(imag(ψk))) for ψk in scfres.ψ]
ρ_dual = compute_density(basis_dual, ψ_dual, occupation_dual)
εF_dual = T(scfres.εF) # Only needed for entropy term
eigenvalues_dual = [T.(εk) for εk in scfres.eigenvalues]
ham_dual = energy_hamiltonian(basis_dual, ψ_dual, occupation_dual;
ρ=ρ_dual, eigenvalues=eigenvalues_dual,
εF=εF_dual).ham
ham_dual * ψ_dual
ham_dual = energy_hamiltonian(basis_dual, scfres.ψ, scfres.occupation;

Check warning on line 165 in src/workarounds/forwarddiff_rules.jl

View check run for this annotation

Codecov / codecov/patch

src/workarounds/forwarddiff_rules.jl#L165

Added line #L165 was not covered by tests
ρ=ρ_basis, scfres.eigenvalues,
scfres.εF).ham
ham_dual * scfres.ψ

Check warning on line 168 in src/workarounds/forwarddiff_rules.jl

View check run for this annotation

Codecov / codecov/patch

src/workarounds/forwarddiff_rules.jl#L168

Added line #L168 was not covered by tests
end

## Implicit differentiation
# Implicit differentiation
response.verbose && println("Solving response problem")
δresults = ntuple(ForwardDiff.npartials(T)) do α
δHextψ = [ForwardDiff.partials.(δHextψk, α) for δHextψk in Hψ_dual]
solve_ΩplusK_split(scfres, -δHextψ; tol=last(scfres.history_Δρ), response.verbose)
end

## Convert and combine
# Convert and combine
DT = Dual{ForwardDiff.tagtype(T)}
ψ = map(scfres.ψ, getfield.(δresults, :δψ)...) do ψk, δψk...
map(ψk, δψk...) do ψnk, δψnk...
Complex(DT(real(ψnk), real.(δψnk)),
DT(imag(ψnk), imag.(δψnk)))
end
end
ρ = map((ρi, δρi...) -> DT(ρi, δρi), scfres.ρ, getfield.(δresults, :δρ)...)
eigenvalues = map(scfres.eigenvalues, getfield.(δresults, :δeigenvalues)...) do εk, δεk...
map((εnk, δεnk...) -> DT(εnk, δεnk), εk, δεk...)
end
Expand All @@ -194,6 +191,10 @@
end
εF = DT(scfres.εF, getfield.(δresults, :δεF)...)

# For strain, basis_dual contributes an explicit lattice contribution which
# is not contained in δresults, so we need to recompute ρ here
ρ = compute_density(basis_dual, ψ, occupation)

Check warning on line 196 in src/workarounds/forwarddiff_rules.jl

View check run for this annotation

Codecov / codecov/patch

src/workarounds/forwarddiff_rules.jl#L196

Added line #L196 was not covered by tests

# TODO Could add δresults[α].δVind the dual part of the total local potential in ham_dual
# and in this way return a ham that represents also the total change in Hamiltonian

Expand Down
48 changes: 48 additions & 0 deletions test/forwarddiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,54 @@
end
end

@testitem "Anisotropic strain sensitivity using ForwardDiff" #=
=# tags=[:dont_test_mpi] setup=[TestCases] begin
using DFTK
using ForwardDiff
using LinearAlgebra
using ComponentArrays
using PseudoPotentialData
aluminium = TestCases.aluminium
Ecut = 5
kgrid = [2, 2, 2]
model = model_DFT(aluminium.lattice, aluminium.atoms, aluminium.positions;
functionals=LDA(), temperature=1e-2, smearing=Smearing.Gaussian(),
kinetic_blowup=BlowupCHV())
basis = PlaneWaveBasis(model; Ecut, kgrid)
nbandsalg = FixedBands(; n_bands_converge=10)

function compute_properties(lattice)
model_strained = Model(model; lattice)
basis = PlaneWaveBasis(model_strained; Ecut, kgrid)
scfres = self_consistent_field(basis; tol=1e-10, nbandsalg)
ComponentArray(
eigenvalues=stack([ev[1:10] for ev in scfres.eigenvalues]),
ρ=scfres.ρ,
energies=collect(values(scfres.energies)),
εF=scfres.εF,
occupation=reduce(vcat, scfres.occupation),
)
end

strain_isotropic(ε) = (1 + ε) * model.lattice
strain_anisotropic(ε) = DFTK.voigt_strain_to_full([ε, 0., 0., 0., 0., 0.]) * model.lattice

@testset "$strain_fn" for strain_fn in [strain_isotropic, strain_anisotropic]
f(ε) = compute_properties(strain_fn(ε))
dx = ForwardDiff.derivative(f, 0.)

h = 1e-4
x1 = f(-h)
x2 = f(+h)
dx_findiff = (x2 - x1) / 2h
@test norm(dx.ρ - dx_findiff.ρ) * sqrt(basis.dvol) < 1e-6
@test maximum(abs, dx.eigenvalues - dx_findiff.eigenvalues) < 1e-6
@test maximum(abs, dx.energies - dx_findiff.energies) < 1e-5
@test dx.εF - dx_findiff.εF < 1e-6
@test maximum(abs, dx.occupation - dx_findiff.occupation) < 1e-4
end
end

@testitem "scfres PSP sensitivity using ForwardDiff" #=
=# tags=[:dont_test_mpi] setup=[TestCases] begin
using DFTK
Expand Down
Loading