24 changes: 17 additions & 7 deletions config_src/drivers/timing_tests/time_MOM_ANN.F90
@@ -5,7 +5,7 @@ program time_MOM_ANN
use MOM_ANN, only : ANN_CS
use MOM_ANN, only : ANN_allocate, ANN_apply, ANN_end
use MOM_ANN, only : ANN_apply_vector_orig, ANN_apply_vector_oi
use MOM_ANN, only : ANN_apply_array_sio
use MOM_ANN, only : ANN_apply_array_sio, ANN_apply_array_sio_r4
use MOM_ANN, only : ANN_random

implicit none
@@ -72,7 +72,10 @@ program time_MOM_ANN
2, "MOM_ANN:ANN_apply_vector_oi(array)")
write(*,"(',')")
call time_ANN(nlayers, nin, layer_width, nout, nsamp, nits, nxy, &
12, "MOM_ANN:ANN_apply_array_sio(array)")
3, "MOM_ANN:ANN_apply_array_sio(array)")
write(*,"(',')")
call time_ANN(nlayers, nin, layer_width, nout, nsamp, nits, nxy, &
4, "MOM_ANN:ANN_apply_array_sio_r4(array)")
write(*,"()")

write(*,'(a)') "}"
@@ -101,9 +104,9 @@ subroutine time_ANN(nlayers, nin, width, nout, nsamp, nits, nxy, impl, label)
real :: x_s(nin) ! Inputs (just features) [nondim]
real :: y_s(nin) ! Outputs (just features) [nondim]
real :: x_fs(nin,nxy) ! Inputs (feature, space) [nondim]
real :: y_fs(nin,nxy) ! Outputs (feature, space) [nondim]
real :: x_sf(nin,nxy) ! Inputs (space, feature) [nondim]
real :: y_sf(nin,nxy) ! Outputs (space, feature) [nondim]
real :: y_fs(nout,nxy) ! Outputs (feature, space) [nondim]
real :: x_sf(nxy,nin) ! Inputs (space, feature) [nondim]
real :: y_sf(nxy,nout) ! Outputs (space, feature) [nondim]
integer :: iter, samp ! Loop counters
integer :: ij ! Horizontal loop index
real :: start, finish, timing ! CPU times [s]
@@ -117,6 +120,7 @@ subroutine time_ANN(nlayers, nin, width, nout, nsamp, nits, nxy, impl, label)
widths(nlayers) = nout

call ANN_random(ANN, nlayers, widths)
call random_number(x_s)
call random_number(x_fs)
call random_number(x_sf)

@@ -131,7 +135,6 @@ subroutine time_ANN(nlayers, nin, width, nout, nsamp, nits, nxy, impl, label)
do samp = 1, nsamp
select case (impl)
case (0)
aits = nits
call cpu_time(start)
do iter = 1, nits ! Make many passes to reduce sampling error
call ANN_apply(x_s, y_s, ANN)
@@ -153,13 +156,20 @@
enddo
enddo
call cpu_time(finish)
case (12)
case (3)
call cpu_time(start)
do iter = 1, aits ! Make many passes to reduce sampling error
call ANN_apply_array_sio(nxy, x_sf(:,:), y_sf(:,:), ANN)
enddo
call cpu_time(finish)
asamp = nsamp * aits ! Account for working on whole arrays
case (4)
call cpu_time(start)
do iter = 1, aits ! Make many passes to reduce sampling error
call ANN_apply_array_sio_r4(nxy, x_sf(:,:), y_sf(:,:), ANN)
enddo
call cpu_time(finish)
asamp = nsamp * aits ! Account for working on whole arrays
end select

timing = ( finish - start ) / real(nits) ! Average time per call
104 changes: 104 additions & 0 deletions src/framework/MOM_ANN.F90
@@ -13,6 +13,7 @@ module MOM_ANN

public ANN_init, ANN_allocate, ANN_apply, ANN_end, ANN_unit_tests
public ANN_apply_vector_orig, ANN_apply_vector_oi, ANN_apply_array_sio
public ANN_apply_array_sio_r4
public set_layer, set_input_normalization, set_output_normalization
public ANN_random, randomize_layer

@@ -34,6 +35,8 @@ module MOM_ANN
real, allocatable :: A(:,:) !< Matrix in column-major order
!! of size A(output_width, input_width) [nondim]
real, allocatable :: b(:) !< bias vector of size output_width [nondim]
real(4), allocatable :: A_r4(:,:) !< Same as A(:,:) but in real(4) [nondim]
real(4), allocatable :: b_r4(:) !< Same as b(:) but in real(4) [nondim]
end type layer_type

!> Control structure/type for ANN
@@ -117,10 +120,12 @@ subroutine ANN_init(CS, NNfile)
fieldname = trim('A') // trim(layer_num_str)
call MOM_read_data(NNfile, fieldname, CS%layers(i)%A, &
(/1,1,1,1/),(/CS%layers(i)%output_width,CS%layers(i)%input_width,1,1/))
CS%layers(i)%A_r4(:,:) = real(CS%layers(i)%A(:,:), kind=4)

! Reading bias b
fieldname = trim('b') // trim(layer_num_str)
call MOM_read_data(NNfile, fieldname, CS%layers(i)%b)
CS%layers(i)%b_r4(:) = real(CS%layers(i)%b(:), kind=4)
enddo

! No activation function for the last layer
@@ -170,6 +175,8 @@ subroutine ANN_allocate(CS, num_layers, layer_sizes)

allocate( CS%layers(l)%A(CS%layers(l)%output_width, CS%layers(l)%input_width) )
allocate( CS%layers(l)%b(CS%layers(l)%output_width) )
allocate( CS%layers(l)%A_r4(CS%layers(l)%output_width, CS%layers(l)%input_width) )
allocate( CS%layers(l)%b_r4(CS%layers(l)%output_width) )

CS%parameters = CS%parameters &
+ CS%layer_sizes(l) * CS%layer_sizes(l+1) & ! For weights
@@ -228,6 +235,8 @@ subroutine ANN_end(CS)
do i = 1, CS%num_layers-1
deallocate(CS%layers(i)%A)
deallocate(CS%layers(i)%b)
deallocate(CS%layers(i)%A_r4)
deallocate(CS%layers(i)%b_r4)
enddo
deallocate(CS%layers)

@@ -242,6 +251,15 @@ pure elemental function activation_fn(x) result (y)

end function activation_fn

!> The default activation function in real(4) precision
pure elemental function activation_fn_r4(x) result (y)
real(4), intent(in) :: x !< Scalar input value [nondim]
real(4) :: y !< Scalar output value [nondim]

y = max(x, 0.0_4) ! ReLU activation

end function activation_fn_r4

!> Single application of ANN inference using vector input and output
!!
!! This implementation is the simplest using allocation and de-allocation
@@ -440,6 +458,82 @@ subroutine layer_apply_sio(nij, x, y, layer)
end subroutine layer_apply_sio
end subroutine ANN_apply_array_sio

!> Same as ANN_apply_array_sio, but casts the input and output
!! vectors to real(4) internally and performs the ANN inference
!! in real(4) precision. On average, about twice as fast as the
!! original ANN_apply_array_sio.
subroutine ANN_apply_array_sio_r4(nij, x, y, CS)
type(ANN_CS), intent(in) :: CS !< ANN control structure
integer, intent(in) :: nij !< Size of spatial dimension
real, intent(in) :: x(nij, CS%layer_sizes(1)) !< input [arbitrary]
real, intent(inout) :: y(nij, CS%layer_sizes(CS%num_layers)) !< output [arbitrary]
! Local variables
real(4), allocatable :: x_1(:,:), x_2(:,:) ! intermediate states [nondim]
integer :: l, i, o ! Layer, input, output index

allocate( x_1( nij, maxval( CS%layer_sizes(:) ) ) )
allocate( x_2( nij, maxval( CS%layer_sizes(:) ) ) )

! Normalize input
do i = 1, CS%layer_sizes(1)
x_1(:,i) = real(( x(:,i) - CS%input_means(i) ) * CS%input_norms(i), kind=4)
enddo

! Apply Linear layers
do l = 1, CS%num_layers-2, 2
call layer_apply_sio(nij, x_1, x_2, CS%layers(l))
call layer_apply_sio(nij, x_2, x_1, CS%layers(l+1))
enddo
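  ! An even num_layers means an odd number of linear layers, so one application
  ! remains and the result lands in x_2; with an odd num_layers the pairwise
  ! loop above has already applied every layer, leaving the result in x_1.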
if (mod(CS%num_layers,2)==0) then
call layer_apply_sio(nij, x_1, x_2, CS%layers(CS%num_layers-1))
! Un-normalize output
do o = 1, CS%layer_sizes(CS%num_layers)
y(:,o) = real(x_2(:,o) * CS%output_norms(o) + CS%output_means(o), kind=8)
enddo
else
! Un-normalize output
do o = 1, CS%layer_sizes(CS%num_layers)
y(:,o) = real(x_1(:,o) * CS%output_norms(o) + CS%output_means(o), kind=8)
enddo
endif

deallocate(x_1, x_2)

contains

!> Applies linear layer to input data x and stores the result in y with
!! y = A*x + b with optional application of the activation function so the
!! overall operation is ReLU(A*x + b)
subroutine layer_apply_sio(nij, x, y, layer)
type(layer_type), intent(in) :: layer !< Linear layer
integer, intent(in) :: nij !< Size of spatial dimension
real(4), intent(in) :: x(nij, layer%input_width) !< Input vector [nondim]
real(4), intent(inout) :: y(nij, layer%output_width) !< Output vector [nondim]
! Local variables
integer :: i, o ! Input, output indices
! We introduce a rescaling that gives bitwise the same answers provided there is no underflow
! We assume overflow is unlikely because x is always of order one
real(4), parameter :: boost = 2.**33 ! Scales values up by ~8.6e9 (an exact power of two)
real(4), parameter :: inv_boost = 2.**(-33) ! Scales values back down

do o = 1, layer%output_width
! Add bias
y(:,o) = layer%b_r4(o) * boost
! Multiply by kernel
do i = 1, layer%input_width
y(:,o) = y(:,o) + (x(:,i) * boost) * layer%A_r4(o, i)
enddo
! Apply activation function
if (layer%activation) then
y(:,o) = activation_fn_r4(y(:,o) * inv_boost)
else
y(:,o) = y(:,o) * inv_boost
endif
enddo

end subroutine layer_apply_sio
end subroutine ANN_apply_array_sio_r4
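
A minimal standalone sketch (not part of this diff) of why the power-of-two
rescaling in layer_apply_sio above is bitwise-neutral: multiplying by an exact
power of two only shifts the floating-point exponent, so every rounding step
commutes with the scale, and removing the scale afterwards reproduces the
unscaled result exactly, as long as nothing overflows or underflows. The
program name and array size here are illustrative.

program demo_boost_rescaling
  implicit none
  real(4), parameter :: boost = 2.**33      ! An exact power of two: only the exponent changes
  real(4), parameter :: inv_boost = 2.**(-33)
  real(4) :: x(8), plain, boosted
  integer :: i

  call random_number(x)  ! Values of order one, like normalized ANN inputs
  plain = 0. ; boosted = 0.
  do i = 1, 8
    plain = plain + x(i)              ! Unscaled accumulation
    boosted = boosted + x(i) * boost  ! Boosted accumulation, as in layer_apply_sio
  enddo
  ! The two accumulations agree bitwise once the scale is removed
  print *, boosted * inv_boost == plain  ! Expected: T
end program demo_boost_rescaling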

!> Sets weights and bias for a single layer
subroutine set_layer(ANN, layer, weights, biases, activation)
type(ANN_CS), intent(inout) :: ANN !< ANN control structure
@@ -456,12 +550,14 @@ subroutine set_layer(ANN, layer, weights, biases, activation)
if ( size(biases) /= size(ANN%layers(layer)%b) ) &
call MOM_error(FATAL, "MOM_ANN, set_layer: mismatch in size of biases")
ANN%layers(layer)%b(:) = biases(:)
ANN%layers(layer)%b_r4(:) = real(biases(:), kind=4)

if ( size(weights,1) /= size(ANN%layers(layer)%A,1) ) &
call MOM_error(FATAL, "MOM_ANN, set_layer: mismatch in size of weights (first dim)")
if ( size(weights,2) /= size(ANN%layers(layer)%A,2) ) &
call MOM_error(FATAL, "MOM_ANN, set_layer: mismatch in size of weights (second dim)")
ANN%layers(layer)%A(:,:) = weights(:,:)
ANN%layers(layer)%A_r4(:,:) = real(weights(:,:), kind=4)

ANN%layers(layer)%activation = activation
end subroutine set_layer
@@ -669,6 +765,9 @@ logical function ANN_unit_tests(verbose)
! as above with v5 of ANN_apply applied to 2d inputs, x(space,feature)
call ANN_apply_array_sio(2, reshape([0.,1.,2.,3.,4.,5.,6.,7.],[2,4]), y2, ANN)
call test%real_arr(2, y2, [2.,5.], 'Rectifier+summation+bias+norms 4-layer array v2')

call ANN_apply_array_sio_r4(2, reshape([0.,1.,2.,3.,4.,5.,6.,7.],[2,4]), y2, ANN)
call test%real_arr(2, y2, [2.,5.], 'Rectifier+summation+bias+norms 4-layer array v2 real(4)')
deallocate( y2 )

call ANN_end(ANN)
@@ -685,6 +784,8 @@ logical function ANN_unit_tests(verbose)
deallocate( y )
call ANN_random(ANN, nlay, widths)
allocate( x(widths(1)), y(widths(nlay)), y_good(widths(nlay)) )
call random_number(x)
x(:) = 2. * x(:) - 1.
call ANN_apply_vector_orig(x, y_good, ANN)
call ANN_apply_vector_oi(x, y, ANN)
rand_res = rand_res .or. maxval( abs( y(:) - y_good(:) ) ) > 0. ! Check results from v2 = v1
Expand All @@ -695,6 +796,9 @@ logical function ANN_unit_tests(verbose)
call ANN_apply_array_sio(20, x2, y2, ANN)
rand_res = rand_res .or. maxval( abs( maxval(y2(:,:),1) - y_good(:) ) ) > 0. ! Check results from array v2 = v1
rand_res = rand_res .or. maxval( abs( minval(y2(:,:),1) - y_good(:) ) ) > 0. ! Check results from array v2 = v1
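! The real(4) variant below is checked against a loose 1.e-5 tolerance:
! real(4) machine epsilon is ~1.2e-7 and rounding error accumulates across layers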
call ANN_apply_array_sio_r4(20, x2, y2, ANN)
rand_res = rand_res .or. maxval( abs( maxval(y2(:,:),1) - y_good(:) ) ) > 1.e-5 ! Loose tolerance for real(4)
rand_res = rand_res .or. maxval( abs( minval(y2(:,:),1) - y_good(:) ) ) > 1.e-5 ! Loose tolerance for real(4)
deallocate( x, y, y_good, x2, y2 )
call ANN_end(ANN)
enddo