Skip to content

Apply optimizer to model weights without data copy #222

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jul 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions src/nf/nf_conv1d_layer.f90
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ module nf_conv1d_layer

procedure :: forward
procedure :: backward
procedure :: get_gradients
procedure :: get_gradients_ptr
procedure :: get_num_params
procedure :: get_params
procedure :: get_params_ptr
procedure :: init
procedure :: set_params

Expand Down Expand Up @@ -97,14 +98,25 @@ module function get_params(self) result(params)
!! Parameters to get
end function get_params

module function get_gradients(self) result(gradients)
!! Return the gradients of this layer.
!! The gradients are ordered as weights first, biases second.
module subroutine get_params_ptr(self, w_ptr, b_ptr)
!! Return pointers to the parameters (weights and biases) of this layer.
class(conv1d_layer), intent(in), target :: self
!! A `conv1d_layer` instance
real, allocatable :: gradients(:)
!! Gradients to get
end function get_gradients
real, pointer, intent(out) :: w_ptr(:)
!! Pointer to the kernel weights (flattened)
real, pointer, intent(out) :: b_ptr(:)
!! Pointer to the biases
end subroutine get_params_ptr

module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
!! Return pointers to the gradients of this layer.
!! The submodule implementation associates the pointers with the
!! layer's own `dw` and `db` components, so no gradient data is copied.
class(conv1d_layer), intent(in), target :: self
!! A `conv1d_layer` instance
real, pointer, intent(out) :: dw_ptr(:)
!! Pointer to the kernel weight gradients (flattened),
!! remapped to bounds 1:size(dw)
real, pointer, intent(out) :: db_ptr(:)
!! Pointer to the bias gradients
end subroutine get_gradients_ptr

module subroutine set_params(self, params)
!! Set the parameters of the layer.
Expand Down
20 changes: 14 additions & 6 deletions src/nf/nf_conv1d_layer_submodule.f90
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,21 @@ module function get_params(self) result(params)
params = [ w_, self % biases]
end function get_params

module function get_gradients(self) result(gradients)
module subroutine get_params_ptr(self, w_ptr, b_ptr)
class(conv1d_layer), intent(in), target :: self
real, allocatable :: gradients(:)
real, pointer :: dw_(:) => null()
dw_(1:size(self % dw)) => self % dw
gradients = [ dw_, self % db ]
end function get_gradients
real, pointer, intent(out) :: w_ptr(:)
real, pointer, intent(out) :: b_ptr(:)
w_ptr(1:size(self % kernel)) => self % kernel
b_ptr => self % biases
end subroutine get_params_ptr

module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
  !! Associate the caller's pointers with this layer's gradient storage.
  !! Nothing is copied: `dw_ptr` is a rank-1 view of `dw` with bounds
  !! 1:size(dw), and `db_ptr` points at `db`.
  class(conv1d_layer), intent(in), target :: self
  real, pointer, intent(out) :: dw_ptr(:), db_ptr(:)
  integer :: n_dw

  ! The bias gradients need no remapping.
  db_ptr => self % db

  ! Bounds remapping presents the kernel gradients as a flat array.
  n_dw = size(self % dw)
  dw_ptr(1:n_dw) => self % dw
end subroutine get_gradients_ptr

module subroutine set_params(self, params)
class(conv1d_layer), intent(in out) :: self
Expand Down
26 changes: 19 additions & 7 deletions src/nf/nf_conv2d_layer.f90
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@ module nf_conv2d_layer

procedure :: forward
procedure :: backward
procedure :: get_gradients
procedure :: get_gradients_ptr
procedure :: get_num_params
procedure :: get_params
procedure :: get_params_ptr
procedure :: init
procedure :: set_params

Expand Down Expand Up @@ -98,14 +99,25 @@ module function get_params(self) result(params)
!! Parameters to get
end function get_params

module function get_gradients(self) result(gradients)
!! Return the gradients of this layer.
!! The gradients are ordered as weights first, biases second.
module subroutine get_params_ptr(self, w_ptr, b_ptr)
!! Return pointers to the parameters (weights and biases) of this layer.
class(conv2d_layer), intent(in), target :: self
!! A `conv2d_layer` instance
real, allocatable :: gradients(:)
!! Gradients to get
end function get_gradients
real, pointer, intent(out) :: w_ptr(:)
!! Pointer to the kernel weights (flattened)
real, pointer, intent(out) :: b_ptr(:)
!! Pointer to the biases
end subroutine get_params_ptr

module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
!! Return pointers to the gradients of this layer.
!! The submodule implementation associates the pointers with the
!! layer's own `dw` and `db` components, so no gradient data is copied.
class(conv2d_layer), intent(in), target :: self
!! A `conv2d_layer` instance
real, pointer, intent(out) :: dw_ptr(:)
!! Pointer to the kernel weight gradients (flattened),
!! remapped to bounds 1:size(dw)
real, pointer, intent(out) :: db_ptr(:)
!! Pointer to the bias gradients
end subroutine get_gradients_ptr

module subroutine set_params(self, params)
!! Set the parameters of the layer.
Expand Down
26 changes: 14 additions & 12 deletions src/nf/nf_conv2d_layer_submodule.f90
Original file line number Diff line number Diff line change
Expand Up @@ -204,21 +204,23 @@ module function get_params(self) result(params)

end function get_params


module function get_gradients(self) result(gradients)
module subroutine get_params_ptr(self, w_ptr, b_ptr)
class(conv2d_layer), intent(in), target :: self
real, allocatable :: gradients(:)

real, pointer :: dw_(:) => null()
real, pointer, intent(out) :: w_ptr(:)
real, pointer, intent(out) :: b_ptr(:)
w_ptr(1:size(self % kernel)) => self % kernel
b_ptr => self % biases
end subroutine get_params_ptr

dw_(1:size(self % dw)) => self % dw

gradients = [ &
dw_, &
self % db &
]

end function get_gradients
module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
  !! Associate the caller's pointers with this layer's gradient storage.
  !! Nothing is copied: `dw_ptr` is a rank-1 view of `dw` with bounds
  !! 1:size(dw), and `db_ptr` points at `db`.
  class(conv2d_layer), intent(in), target :: self
  real, pointer, intent(out) :: dw_ptr(:), db_ptr(:)
  integer :: n_dw

  ! The bias gradients need no remapping.
  db_ptr => self % db

  ! Bounds remapping presents the kernel gradients as a flat array.
  n_dw = size(self % dw)
  dw_ptr(1:n_dw) => self % dw
end subroutine get_gradients_ptr


module subroutine set_params(self, params)
Expand Down
20 changes: 12 additions & 8 deletions src/nf/nf_dense_layer.f90
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@ module nf_dense_layer

procedure :: backward
procedure :: forward
procedure :: get_gradients
procedure :: get_gradients_ptr
procedure :: get_num_params
procedure :: get_params
procedure :: get_params_ptr
procedure :: init
procedure :: set_params

Expand Down Expand Up @@ -96,14 +97,17 @@ module function get_params(self) result(params)
!! Parameters of this layer
end function get_params

module function get_gradients(self) result(gradients)
!! Return the gradients of this layer.
!! The gradients are ordered as weights first, biases second.
module subroutine get_params_ptr(self, w_ptr, b_ptr)
class(dense_layer), intent(in), target :: self
!! Dense layer instance
real, allocatable :: gradients(:)
!! Gradients of this layer
end function get_gradients
real, pointer, intent(out) :: w_ptr(:)
real, pointer, intent(out) :: b_ptr(:)
end subroutine get_params_ptr

module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
!! Return pointers to the gradients of this layer.
!! The submodule implementation associates the pointers with the
!! layer's own `dw` and `db` components, so no gradient data is copied.
class(dense_layer), intent(in), target :: self
!! Dense layer instance
real, pointer, intent(out) :: dw_ptr(:)
!! Pointer to the weight gradients `dw`, remapped to bounds 1:size(dw)
real, pointer, intent(out) :: db_ptr(:)
!! Pointer to the bias gradients `db`
end subroutine get_gradients_ptr

module subroutine set_params(self, params)
!! Set the parameters of this layer.
Expand Down
24 changes: 13 additions & 11 deletions src/nf/nf_dense_layer_submodule.f90
Original file line number Diff line number Diff line change
Expand Up @@ -77,20 +77,22 @@ module function get_params(self) result(params)
end function get_params


module function get_gradients(self) result(gradients)
module subroutine get_params_ptr(self, w_ptr, b_ptr)
class(dense_layer), intent(in), target :: self
real, allocatable :: gradients(:)
real, pointer, intent(out) :: w_ptr(:)
real, pointer, intent(out) :: b_ptr(:)
w_ptr(1:size(self % weights)) => self % weights
b_ptr => self % biases
end subroutine get_params_ptr

real, pointer :: dw_(:) => null()

dw_(1:size(self % dw)) => self % dw

gradients = [ &
dw_, &
self % db &
]

end function get_gradients
module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
  !! Associate the caller's pointers with this layer's gradient storage.
  !! Nothing is copied: `dw_ptr` is a rank-1 view of `dw` with bounds
  !! 1:size(dw), and `db_ptr` points at `db`.
  class(dense_layer), intent(in), target :: self
  real, pointer, intent(out) :: dw_ptr(:), db_ptr(:)
  integer :: n_dw

  ! The bias gradients need no remapping.
  db_ptr => self % db

  ! Bounds remapping presents the weight gradients as a flat array.
  n_dw = size(self % dw)
  dw_ptr(1:n_dw) => self % dw
end subroutine get_gradients_ptr


module subroutine set_params(self, params)
Expand Down
10 changes: 1 addition & 9 deletions src/nf/nf_layer.f90
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ module nf_layer
integer, allocatable :: layer_shape(:)
integer, allocatable :: input_layer_shape(:)
logical :: initialized = .false.
class(optimizer_base_type), allocatable :: optimizer

contains

procedure :: forward
procedure :: get_num_params
procedure :: get_params
procedure :: get_gradients
procedure :: set_params
procedure :: init
procedure :: print_info
Expand Down Expand Up @@ -160,14 +160,6 @@ module function get_params(self) result(params)
!! Parameters of this layer
end function get_params

module function get_gradients(self) result(gradients)
!! Returns the gradients of this layer.
!! The implementation dispatches on the concrete type of the wrapped
!! layer and forwards to that type's own `get_gradients`.
!! NOTE(review): for layer types without gradients the implementation
!! assigns nothing to the result; confirm callers never query those.
class(layer), intent(in) :: self
!! Layer instance
real, allocatable :: gradients(:)
!! Gradients of this layer
end function get_gradients

module subroutine set_params(self, params)
!! Sets the parameters of this layer.
class(layer), intent(in out) :: self
Expand Down
44 changes: 0 additions & 44 deletions src/nf/nf_layer_submodule.f90
Original file line number Diff line number Diff line change
Expand Up @@ -682,50 +682,6 @@ module function get_params(self) result(params)

end function get_params

module function get_gradients(self) result(gradients)
  !! Fetch the gradients of the wrapped layer by dispatching on its
  !! concrete type. Layer types that carry no gradients fall through
  !! without assigning the result; an unrecognised type stops the program.
  class(layer), intent(in) :: self
  real, allocatable :: gradients(:)

  select type (wrapped => self % p)

    type is (input1d_layer)
      ! No gradients to get.
    type is (input2d_layer)
      ! No gradients to get.
    type is (input3d_layer)
      ! No gradients to get.

    type is (dense_layer)
      gradients = wrapped % get_gradients()

    type is (dropout_layer)
      ! No gradients to get.

    type is (conv1d_layer)
      gradients = wrapped % get_gradients()
    type is (conv2d_layer)
      gradients = wrapped % get_gradients()
    type is (locally_connected1d_layer)
      gradients = wrapped % get_gradients()

    type is (maxpool1d_layer)
      ! No gradients to get.
    type is (maxpool2d_layer)
      ! No gradients to get.
    type is (flatten_layer)
      ! No gradients to get.
    type is (reshape2d_layer)
      ! No gradients to get.
    type is (reshape3d_layer)
      ! No gradients to get.

    type is (linear2d_layer)
      gradients = wrapped % get_gradients()
    type is (self_attention_layer)
      gradients = wrapped % get_gradients()
    type is (embedding_layer)
      gradients = wrapped % get_gradients()
    type is (layernorm_layer)
      gradients = wrapped % get_gradients()

    class default
      error stop 'Unknown layer type.'

  end select

end function get_gradients

module subroutine set_params(self, params)
class(layer), intent(in out) :: self
Expand Down
14 changes: 14 additions & 0 deletions src/nf/nf_layernorm.f90
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ module nf_layernorm_layer
procedure :: init
procedure :: get_num_params
procedure :: get_params
procedure :: get_params_ptr
procedure :: get_gradients
procedure :: get_gradients_ptr
procedure :: set_params
end type layernorm_layer

Expand Down Expand Up @@ -78,12 +80,24 @@ module function get_params(self) result(params)
end function get_params


module subroutine get_params_ptr(self, g_ptr, b_ptr)
!! Return pointers to the parameters of this layer.
!! The submodule implementation points `g_ptr` at `gamma` and
!! `b_ptr` at `beta`; no parameter data is copied.
class(layernorm_layer), intent(in), target :: self
!! A `layernorm_layer` instance
real, pointer, intent(out) :: g_ptr(:), b_ptr(:)
!! Pointers to `gamma` and `beta`, respectively
end subroutine get_params_ptr


module function get_gradients(self) result(gradients)
!! Return the gradients of this layer as a single rank-1 array.
!! The submodule implementation orders them as `d_gamma` first,
!! `d_beta` second.
class(layernorm_layer), intent(in), target :: self
!! A `layernorm_layer` instance
real, allocatable :: gradients(:)
!! Gradients of this layer
end function get_gradients


module subroutine get_gradients_ptr(self, dg_ptr, db_ptr)
!! Return pointers to the gradients of this layer.
!! The submodule implementation points `dg_ptr` at `d_gamma` and
!! `db_ptr` at `d_beta`; no gradient data is copied.
class(layernorm_layer), intent(in), target :: self
!! A `layernorm_layer` instance
real, pointer, intent(out) :: dg_ptr(:), db_ptr(:)
!! Pointers to `d_gamma` and `d_beta`, respectively
end subroutine get_gradients_ptr


module subroutine set_params(self, params)
class(layernorm_layer), intent(in out) :: self
real, intent(in), target :: params(:)
Expand Down
26 changes: 16 additions & 10 deletions src/nf/nf_layernorm_submodule.f90
Original file line number Diff line number Diff line change
Expand Up @@ -112,25 +112,31 @@ end function get_num_params
module function get_params(self) result(params)
  !! Return a copy of the layer parameters as one flat array:
  !! `gamma` first, `beta` second.
  class(layernorm_layer), intent(in), target :: self
  real, allocatable :: params(:)

  params = [ &
    self % gamma, &
    self % beta &
  ]
end function get_params

params = [ &
self % gamma, &
self % beta &
]

end function get_params
module subroutine get_params_ptr(self, g_ptr, b_ptr)
  !! Point `g_ptr` at `gamma` and `b_ptr` at `beta`.
  !! Nothing is copied; the pointers refer to the layer's own components.
  class(layernorm_layer), intent(in), target :: self
  real, pointer, intent(out) :: g_ptr(:)
  real, pointer, intent(out) :: b_ptr(:)

  b_ptr => self % beta
  g_ptr => self % gamma
end subroutine get_params_ptr


module function get_gradients(self) result(gradients)
  !! Return a copy of the layer gradients as one flat array:
  !! `d_gamma` first, `d_beta` second.
  class(layernorm_layer), intent(in), target :: self
  real, allocatable :: gradients(:)

  gradients = [ &
    self % d_gamma, &
    self % d_beta &
  ]
end function get_gradients

gradients = [ &
self % d_gamma, &
self % d_beta &
]

end function get_gradients
module subroutine get_gradients_ptr(self, dg_ptr, db_ptr)
  !! Point `dg_ptr` at `d_gamma` and `db_ptr` at `d_beta`.
  !! Nothing is copied; the pointers refer to the layer's own components.
  class(layernorm_layer), intent(in), target :: self
  real, pointer, intent(out) :: dg_ptr(:)
  real, pointer, intent(out) :: db_ptr(:)

  db_ptr => self % d_beta
  dg_ptr => self % d_gamma
end subroutine get_gradients_ptr


module subroutine set_params(self, params)
Expand Down
Loading