diff --git a/README.md b/README.md
index 65964786..e35673e6 100644
--- a/README.md
+++ b/README.md
@@ -33,11 +33,9 @@ Read the paper [here](https://arxiv.org/abs/1902.06714).
 | Embedding | `embedding` | n/a | 2 | ✅ | ✅ |
 | Dense (fully-connected) | `dense` | `input1d`, `dense`, `dropout`, `flatten` | 1 | ✅ | ✅ |
 | Dropout | `dropout` | `dense`, `flatten`, `input1d` | 1 | ✅ | ✅ |
-| Locally connected (1-d) | `locally_connected1d` | `input2d`, `locally_connected1d`, `conv1d`, `maxpool1d`, `reshape2d` | 2 | ✅ | ✅ |
-| Convolutional (1-d) | `conv1d` | `input2d`, `conv1d`, `maxpool1d`, `reshape2d` | 2 | ✅ | ✅ |
-| Convolutional (2-d) | `conv2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 | ✅ | ✅ |
-| Max-pooling (1-d) | `maxpool1d` | `input2d`, `conv1d`, `maxpool1d`, `reshape2d` | 2 | ✅ | ✅ |
-| Max-pooling (2-d) | `maxpool2d` | `input3d`, `conv2d`, `maxpool2d`, `reshape` | 3 | ✅ | ✅ |
+| Locally connected (1-d) | `locally_connected` | `input`, `locally_connected`, `conv`, `maxpool`, `reshape` | 2 | ✅ | ✅ |
+| Convolutional (1-d and 2-d) | `conv` | `input`, `conv`, `maxpool`, `reshape` | 2, 3 | ✅ | ✅ |
+| Max-pooling (1-d and 2-d) | `maxpool` | `input`, `conv`, `maxpool`, `reshape` | 2, 3 | ✅ | ✅ |
 | Linear (2-d) | `linear2d` | `input2d`, `layernorm`, `linear2d`, `self_attention` | 2 | ✅ | ✅ |
 | Self-attention | `self_attention` | `input2d`, `layernorm`, `linear2d`, `self_attention` | 2 | ✅ | ✅ |
 | Layer Normalization | `layernorm` | `linear2d`, `self_attention` | 2 | ✅ | ✅ |
diff --git a/example/cnn_mnist.f90 b/example/cnn_mnist.f90
index d2f61723..1ebe081c 100644
--- a/example/cnn_mnist.f90
+++ b/example/cnn_mnist.f90
@@ -1,7 +1,7 @@
 program cnn_mnist
 
   use nf, only: network, sgd, &
-    input, conv2d, maxpool2d, flatten, dense, reshape, &
+    input, conv, maxpool, flatten, dense, reshape, &
     load_mnist, label_digits, softmax, relu
 
   implicit none
@@ -21,10 +21,10 @@ program cnn_mnist
   net = network([ &
     input(784), &
     reshape(1, 28, 28), &
-    conv2d(filters=8, kernel_size=3, activation=relu()), &
-    maxpool2d(pool_size=2), &
-    conv2d(filters=16, kernel_size=3, activation=relu()), &
-    maxpool2d(pool_size=2), &
+    conv(filters=8, kernel_width=3, kernel_height=3, activation=relu()), &
+    maxpool(pool_width=2, pool_height=2, stride=2), &
+    conv(filters=16, kernel_width=3, kernel_height=3, activation=relu()), &
+    maxpool(pool_width=2, pool_height=2, stride=2), &
     dense(10, activation=softmax()) &
   ])
diff --git a/example/cnn_mnist_1d.f90 b/example/cnn_mnist_1d.f90
index b350a2f0..059d09c5 100644
--- a/example/cnn_mnist_1d.f90
+++ b/example/cnn_mnist_1d.f90
@@ -1,7 +1,7 @@
 program cnn_mnist_1d
 
   use nf, only: network, sgd, &
-    input, conv1d, maxpool1d, flatten, dense, reshape, locally_connected1d, &
+    input, maxpool, flatten, dense, reshape, locally_connected, &
     load_mnist, label_digits, softmax, relu
 
   implicit none
@@ -21,10 +21,10 @@ program cnn_mnist_1d
   net = network([ &
     input(784), &
     reshape(28, 28), &
-    locally_connected1d(filters=8, kernel_size=3, activation=relu()), &
-    maxpool1d(pool_size=2), &
-    locally_connected1d(filters=16, kernel_size=3, activation=relu()), &
-    maxpool1d(pool_size=2), &
+    locally_connected(filters=8, kernel_size=3, activation=relu()), &
+    maxpool(pool_width=2, stride=2), &
+    locally_connected(filters=16, kernel_size=3, activation=relu()), &
+    maxpool(pool_width=2, stride=2), &
     dense(10, activation=softmax()) &
   ])
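
The two examples above exercise both resolutions of the new generic names. Below is a minimal sketch (not part of the patch) of how `conv` and `maxpool` disambiguate, assuming only the constructors introduced here; the program name is illustrative and no training is performed:

```fortran
program generic_resolution_sketch
  ! Presence of kernel_height selects the 2-d convolution; its absence
  ! selects the 1-d one. Likewise, pool_height selects maxpool2d.
  use nf, only: network, input, conv, maxpool, dense, reshape, relu, softmax
  implicit none
  type(network) :: net1d, net2d

  net1d = network([ &
    input(784), &
    reshape(28, 28), &                                      ! rank-2 data
    conv(filters=8, kernel_width=3, activation=relu()), &   ! -> conv1d
    maxpool(pool_width=2, stride=2), &                      ! -> maxpool1d
    dense(10, activation=softmax()) &
  ])

  net2d = network([ &
    input(784), &
    reshape(1, 28, 28), &                                   ! rank-3 data
    conv(filters=8, kernel_width=3, kernel_height=3, activation=relu()), & ! -> conv2d
    maxpool(pool_width=2, pool_height=2, stride=2), &       ! -> maxpool2d
    dense(10, activation=softmax()) &
  ])

end program generic_resolution_sketch
```
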
"0.22.0" license = "MIT" author = "Milan Curcic" maintainer = "mcurcic@miami.edu" diff --git a/src/nf.f90 b/src/nf.f90 index f644826d..c7b21656 100644 --- a/src/nf.f90 +++ b/src/nf.f90 @@ -3,8 +3,7 @@ module nf use nf_datasets_mnist, only: label_digits, load_mnist use nf_layer, only: layer use nf_layer_constructors, only: & - conv1d, & - conv2d, & + conv, & dense, & dropout, & embedding, & @@ -12,9 +11,8 @@ module nf input, & layernorm, & linear2d, & - locally_connected1d, & - maxpool1d, & - maxpool2d, & + locally_connected, & + maxpool, & reshape, & self_attention use nf_loss, only: mse, quadratic diff --git a/src/nf/nf_layer_constructors.f90 b/src/nf/nf_layer_constructors.f90 index d3f06ca3..80860bdf 100644 --- a/src/nf/nf_layer_constructors.f90 +++ b/src/nf/nf_layer_constructors.f90 @@ -9,16 +9,14 @@ module nf_layer_constructors private public :: & - conv1d, & - conv2d, & + conv, & dense, & dropout, & flatten, & input, & linear2d, & - locally_connected1d, & - maxpool1d, & - maxpool2d, & + locally_connected, & + maxpool, & reshape, & self_attention, & embedding, & @@ -94,111 +92,28 @@ end function input3d end interface input - interface reshape - - module function reshape2d(dim1, dim2) result(res) - !! Rank-1 to rank-2 reshape layer constructor. - integer, intent(in) :: dim1, dim2 - !! Shape of the output - type(layer) :: res - !! Resulting layer instance - end function reshape2d - - module function reshape3d(dim1, dim2, dim3) result(res) - !! Rank-1 to rank-3 reshape layer constructor. - integer, intent(in) :: dim1, dim2, dim3 - !! Shape of the output - type(layer) :: res - !! Resulting layer instance - end function reshape3d - - end interface reshape - - - interface - - module function dense(layer_size, activation) result(res) - !! Dense (fully-connected) layer constructor. - !! - !! This layer is a building block for dense, fully-connected networks, - !! or for an output layer of a convolutional network. - !! A dense layer must not be the first layer in the network. - !! - !! Example: - !! - !! ``` - !! use nf, only :: dense, layer, relu - !! type(layer) :: dense_layer - !! dense_layer = dense(10) - !! dense_layer = dense(10, activation=relu()) - !! ``` - integer, intent(in) :: layer_size - !! The number of neurons in a dense layer - class(activation_function), intent(in), optional :: activation - !! Activation function instance (default sigmoid) - type(layer) :: res - !! Resulting layer instance - end function dense - - module function dropout(rate) result(res) - !! Create a dropout layer with a given dropout rate. - !! - !! This layer is for randomly disabling neurons during training. - !! - !! Example: - !! - !! ``` - !! use nf, only :: dropout, layer - !! type(layer) :: dropout_layer - !! dropout_layer = dropout(rate=0.5) - !! ``` - real, intent(in) :: rate - !! Dropout rate - fraction of neurons to randomly disable during training - type(layer) :: res - !! Resulting layer instance - end function dropout - - module function flatten() result(res) - !! Flatten (3-d -> 1-d) layer constructor. - !! - !! Use this layer to chain layers with 3-d outputs to layers with 1-d - !! inputs. For example, to chain a `conv2d` or a `maxpool2d` layer - !! with a `dense` layer for a CNN for classification, place a `flatten` - !! layer between them. - !! - !! A flatten layer must not be the first layer in the network. - !! - !! Example: - !! - !! ``` - !! use nf, only :: flatten, layer - !! type(layer) :: flatten_layer - !! flatten_layer = flatten() - !! ``` - type(layer) :: res - !! 
-        !! Resulting layer instance
-    end function flatten
 
+  interface conv
 
-    module function conv1d(filters, kernel_size, activation) result(res)
+    module function conv1d(filters, kernel_width, activation) result(res)
       !! 1-d convolutional layer constructor.
       !!
       !! This layer is for building 1-d convolutional network.
       !! Although the established convention is to call these layers 1-d,
-      !! the shape of the data is actually 2-d: image width
-      !! and the number of channels.
+      !! the shape of the data is actually 2-d: image width and the number of channels.
       !! A conv1d layer must not be the first layer in the network.
       !!
+      !! This specific function is available under a generic name `conv`.
+      !!
       !! Example:
       !!
       !! ```
-      !! use nf, only :: conv1d, layer
+      !! use nf, only :: conv, layer
       !! type(layer) :: conv1d_layer
-      !! conv1d_layer = dense(filters=32, kernel_size=3)
-      !! conv1d_layer = dense(filters=32, kernel_size=3, activation='relu')
+      !! conv1d_layer = conv(filters=32, kernel_width=3)
       !! ```
       integer, intent(in) :: filters
         !! Number of filters in the output of the layer
-      integer, intent(in) :: kernel_size
+      integer, intent(in) :: kernel_width
         !! Width of the convolution window, commonly 3 or 5
       class(activation_function), intent(in), optional :: activation
         !! Activation function (default sigmoid)
@@ -206,39 +121,47 @@ module function conv1d(filters, kernel_size, activation) result(res)
         !! Resulting layer instance
     end function conv1d
 
-    module function conv2d(filters, kernel_size, activation) result(res)
+    module function conv2d(filters, kernel_width, kernel_height, activation) result(res)
       !! 2-d convolutional layer constructor.
       !!
       !! This layer is for building 2-d convolutional network.
       !! Although the established convention is to call these layers 2-d,
-      !! the shape of the data is actuall 3-d: image width, image height,
-      !! and the number of channels.
+      !! the shape of the data is actually 3-d: image width, image height,
+      !! and the number of channels.
       !! A conv2d layer must not be the first layer in the network.
       !!
+      !! This specific function is available under a generic name `conv`.
+      !!
       !! Example:
       !!
       !! ```
-      !! use nf, only :: conv2d, layer
-      !! type(layer) :: conv2d_layer
-      !! conv2d_layer = dense(filters=32, kernel_size=3)
-      !! conv2d_layer = dense(filters=32, kernel_size=3, activation='relu')
+      !! use nf, only :: conv, layer
+      !! type(layer) :: conv2d_layer
+      !! conv2d_layer = conv(filters=32, kernel_width=3, kernel_height=3)
       !! ```
       integer, intent(in) :: filters
         !! Number of filters in the output of the layer
-      integer, intent(in) :: kernel_size
+      integer, intent(in) :: kernel_width
         !! Width of the convolution window, commonly 3 or 5
+      integer, intent(in) :: kernel_height
+        !! Height of the convolution window, commonly 3 or 5
       class(activation_function), intent(in), optional :: activation
         !! Activation function (default sigmoid)
       type(layer) :: res
         !! Resulting layer instance
     end function conv2d
+
+  end interface conv
+
+
+  interface locally_connected
 
     module function locally_connected1d(filters, kernel_size, activation) result(res)
       !! 1-d locally connected network constructor
       !!
       !! This layer is for building 1-d locally connected network.
       !! Although the established convention is to call these layers 1-d,
-      !! the shape of the data is actuall 2-d: image width,
+      !! the shape of the data is actually 2-d: image width,
       !! and the number of channels.
       !! A locally connected 1d layer must not be the first layer in the network.
       !!
@@ -260,50 +183,145 @@ module function locally_connected1d(filters, kernel_size, activation) result(res)
         !! Resulting layer instance
     end function locally_connected1d
 
-    module function maxpool1d(pool_size, stride) result(res)
+  end interface locally_connected
+
+
+  interface maxpool
+
+    module function maxpool1d(pool_width, stride) result(res)
       !! 1-d maxpooling layer constructor.
       !!
       !! This layer is for downscaling other layers, typically `conv1d`.
       !!
+      !! This specific function is available under a generic name `maxpool`.
+      !!
       !! Example:
       !!
       !! ```
-      !! use nf, only :: maxpool1d, layer
+      !! use nf, only :: maxpool, layer
       !! type(layer) :: maxpool1d_layer
-      !! maxpool1d_layer = maxpool1d(pool_size=2)
-      !! maxpool1d_layer = maxpool1d(pool_size=2, stride=3)
+      !! maxpool1d_layer = maxpool(pool_width=2, stride=2)
       !! ```
-      integer, intent(in) :: pool_size
+      integer, intent(in) :: pool_width
         !! Width of the pooling window, commonly 2
-      integer, intent(in), optional :: stride
-        !! Stride of the pooling window, commonly equal to `pool_size`;
-        !! Defaults to `pool_size` if omitted.
+      integer, intent(in) :: stride
+        !! Stride of the pooling window, commonly equal to `pool_width`.
       type(layer) :: res
         !! Resulting layer instance
     end function maxpool1d
 
-    module function maxpool2d(pool_size, stride) result(res)
+    module function maxpool2d(pool_width, pool_height, stride) result(res)
       !! 2-d maxpooling layer constructor.
       !!
       !! This layer is for downscaling other layers, typically `conv2d`.
       !!
+      !! This specific function is available under a generic name `maxpool`.
+      !!
       !! Example:
       !!
       !! ```
-      !! use nf, only :: maxpool2d, layer
+      !! use nf, only :: maxpool, layer
       !! type(layer) :: maxpool2d_layer
-      !! maxpool2d_layer = maxpool2d(pool_size=2)
-      !! maxpool2d_layer = maxpool2d(pool_size=2, stride=3)
+      !! maxpool2d_layer = maxpool(pool_width=2, pool_height=2, stride=2)
       !! ```
-      integer, intent(in) :: pool_size
+      integer, intent(in) :: pool_width
         !! Width of the pooling window, commonly 2
-      integer, intent(in), optional :: stride
-        !! Stride of the pooling window, commonly equal to `pool_size`;
-        !! Defaults to `pool_size` if omitted.
+      integer, intent(in) :: pool_height
+        !! Height of the pooling window; currently must equal `pool_width`
+      integer, intent(in) :: stride
+        !! Stride of the pooling window, commonly equal to `pool_width`.
       type(layer) :: res
         !! Resulting layer instance
     end function maxpool2d
 
+  end interface maxpool
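
Two caller-visible changes in the interface above are worth noting: `stride` is no longer optional (the old default of `stride = pool_size` is gone), and the 2-d specific is selected by passing `pool_height`. A minimal sketch, not part of the patch, assuming only the constructors shown (variable names are illustrative):

```fortran
program maxpool_api_sketch
  use nf, only: maxpool, layer
  implicit none
  type(layer) :: pool1d, pool2d

  ! stride must now be given explicitly; it no longer defaults to the pool width
  pool1d = maxpool(pool_width=2, stride=2)                 ! -> maxpool1d
  pool2d = maxpool(pool_width=2, pool_height=2, stride=2)  ! -> maxpool2d

end program maxpool_api_sketch
```
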
+
+
+  interface reshape
+
+    module function reshape2d(dim1, dim2) result(res)
+      !! Rank-1 to rank-2 reshape layer constructor.
+      integer, intent(in) :: dim1, dim2
+        !! Shape of the output
+      type(layer) :: res
+        !! Resulting layer instance
+    end function reshape2d
+
+    module function reshape3d(dim1, dim2, dim3) result(res)
+      !! Rank-1 to rank-3 reshape layer constructor.
+      integer, intent(in) :: dim1, dim2, dim3
+        !! Shape of the output
+      type(layer) :: res
+        !! Resulting layer instance
+    end function reshape3d
+
+  end interface reshape
+
+
+  interface
+
+    module function dense(layer_size, activation) result(res)
+      !! Dense (fully-connected) layer constructor.
+      !!
+      !! This layer is a building block for dense, fully-connected networks,
+      !! or for an output layer of a convolutional network.
+      !! A dense layer must not be the first layer in the network.
+      !!
+      !! Example:
+      !!
+      !! ```
+      !! use nf, only :: dense, layer, relu
+      !! type(layer) :: dense_layer
+      !! dense_layer = dense(10)
+      !! dense_layer = dense(10, activation=relu())
+      !! ```
+      integer, intent(in) :: layer_size
+        !! The number of neurons in a dense layer
+      class(activation_function), intent(in), optional :: activation
+        !! Activation function instance (default sigmoid)
+      type(layer) :: res
+        !! Resulting layer instance
+    end function dense
+
+    module function dropout(rate) result(res)
+      !! Create a dropout layer with a given dropout rate.
+      !!
+      !! This layer is for randomly disabling neurons during training.
+      !!
+      !! Example:
+      !!
+      !! ```
+      !! use nf, only :: dropout, layer
+      !! type(layer) :: dropout_layer
+      !! dropout_layer = dropout(rate=0.5)
+      !! ```
+      real, intent(in) :: rate
+        !! Dropout rate - fraction of neurons to randomly disable during training
+      type(layer) :: res
+        !! Resulting layer instance
+    end function dropout
+
+    module function flatten() result(res)
+      !! Flatten (3-d -> 1-d) layer constructor.
+      !!
+      !! Use this layer to chain layers with 3-d outputs to layers with 1-d
+      !! inputs. For example, to chain a `conv` or a `maxpool` layer
+      !! with a `dense` layer for a CNN for classification, place a `flatten`
+      !! layer between them.
+      !!
+      !! A flatten layer must not be the first layer in the network.
+      !!
+      !! Example:
+      !!
+      !! ```
+      !! use nf, only :: flatten, layer
+      !! type(layer) :: flatten_layer
+      !! flatten_layer = flatten()
+      !! ```
+      type(layer) :: res
+        !! Resulting layer instance
+    end function flatten
+
     module function linear2d(out_features) result(res)
       !! Rank-2 (sequence_length, out_features) linear layer constructor.
       !! sequence_length is determined at layer initialization, based on the
diff --git a/src/nf/nf_layer_constructors_submodule.f90 b/src/nf/nf_layer_constructors_submodule.f90
index 1665d38a..7918ee1c 100644
--- a/src/nf/nf_layer_constructors_submodule.f90
+++ b/src/nf/nf_layer_constructors_submodule.f90
@@ -24,9 +24,9 @@
 
 contains
 
-  module function conv1d(filters, kernel_size, activation) result(res)
+  module function conv1d(filters, kernel_width, activation) result(res)
     integer, intent(in) :: filters
-    integer, intent(in) :: kernel_size
+    integer, intent(in) :: kernel_width
     class(activation_function), intent(in), optional :: activation
     type(layer) :: res
@@ -44,19 +44,26 @@ module function conv1d(filters, kernel_size, activation) result(res)
 
     allocate( &
       res % p, &
-      source=conv1d_layer(filters, kernel_size, activation_tmp) &
+      source=conv1d_layer(filters, kernel_width, activation_tmp) &
     )
 
   end function conv1d
 
-  module function conv2d(filters, kernel_size, activation) result(res)
+  module function conv2d(filters, kernel_width, kernel_height, activation) result(res)
     integer, intent(in) :: filters
-    integer, intent(in) :: kernel_size
+    integer, intent(in) :: kernel_width
+    integer, intent(in) :: kernel_height
     class(activation_function), intent(in), optional :: activation
     type(layer) :: res
 
     class(activation_function), allocatable :: activation_tmp
 
+    ! Enforce kernel_width == kernel_height for now;
+    ! if non-square kernels turn out to be needed, we'll relax this constraint
+    ! and refactor conv2d_layer to work with non-square kernels.
+    if (kernel_width /= kernel_height) &
+      error stop 'kernel_width must equal kernel_height in a conv2d layer'
+
     res % name = 'conv2d'
 
     if (present(activation)) then
@@ -69,7 +76,7 @@ module function conv2d(filters, kernel_size, activation) result(res)
 
     allocate( &
       res % p, &
-      source=conv2d_layer(filters, kernel_size, activation_tmp) &
+      source=conv2d_layer(filters, kernel_width, activation_tmp) &
     )
 
   end function conv2d
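
The `error stop` above turns the square-kernel assumption into an explicit construction-time failure instead of silently passing only `kernel_width` through to `conv2d_layer`. A hypothetical sketch of what callers see (names illustrative, not part of the patch):

```fortran
program square_kernel_sketch
  use nf, only: conv, layer
  implicit none
  type(layer) :: conv_layer

  ! Square kernels construct normally:
  conv_layer = conv(filters=8, kernel_width=3, kernel_height=3)

  ! A non-square kernel halts at construction time:
  ! conv_layer = conv(filters=8, kernel_width=3, kernel_height=5)
  ! -> error stop 'kernel_width must equal kernel_height in a conv2d layer'

end program square_kernel_sketch
```
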
@@ -172,58 +179,49 @@ module function input3d(dim1, dim2, dim3) result(res)
     res % initialized = .true.
   end function input3d
 
-  module function maxpool1d(pool_size, stride) result(res)
-    integer, intent(in) :: pool_size
-    integer, intent(in), optional :: stride
-    integer :: stride_
+  module function maxpool1d(pool_width, stride) result(res)
+    integer, intent(in) :: pool_width
+    integer, intent(in) :: stride
     type(layer) :: res
 
-    if (pool_size < 2) &
-      error stop 'pool_size must be >= 2 in a maxpool1d layer'
-
-    ! Stride defaults to pool_size if not provided
-    if (present(stride)) then
-      stride_ = stride
-    else
-      stride_ = pool_size
-    end if
+    if (pool_width < 2) &
+      error stop 'pool_width must be >= 2 in a maxpool1d layer'
 
-    if (stride_ < 1) &
+    if (stride < 1) &
       error stop 'stride must be >= 1 in a maxpool1d layer'
 
     res % name = 'maxpool1d'
 
     allocate( &
       res % p, &
-      source=maxpool1d_layer(pool_size, stride_) &
+      source=maxpool1d_layer(pool_width, stride) &
     )
 
   end function maxpool1d
 
-  module function maxpool2d(pool_size, stride) result(res)
-    integer, intent(in) :: pool_size
-    integer, intent(in), optional :: stride
-    integer :: stride_
+  module function maxpool2d(pool_width, pool_height, stride) result(res)
+    integer, intent(in) :: pool_width
+    integer, intent(in) :: pool_height
+    integer, intent(in) :: stride
     type(layer) :: res
 
-    if (pool_size < 2) &
-      error stop 'pool_size must be >= 2 in a maxpool2d layer'
+    if (pool_width < 2) &
+      error stop 'pool_width must be >= 2 in a maxpool2d layer'
 
-    ! Stride defaults to pool_size if not provided
-    if (present(stride)) then
-      stride_ = stride
-    else
-      stride_ = pool_size
-    end if
+    ! Enforce pool_width == pool_height for now;
+    ! if non-square pooling windows turn out to be needed, we'll relax this
+    ! constraint and refactor maxpool2d_layer to work with non-square windows.
+    if (pool_width /= pool_height) &
+      error stop 'pool_width must equal pool_height in a maxpool2d layer'
 
-    if (stride_ < 1) &
+    if (stride < 1) &
       error stop 'stride must be >= 1 in a maxpool2d layer'
 
     res % name = 'maxpool2d'
 
     allocate( &
       res % p, &
-      source=maxpool2d_layer(pool_size, stride_) &
+      source=maxpool2d_layer(pool_width, stride) &
     )
 
   end function maxpool2d
diff --git a/src/nf/nf_network_submodule.f90 b/src/nf/nf_network_submodule.f90
index 449b5a5b..d8f5ff50 100644
--- a/src/nf/nf_network_submodule.f90
+++ b/src/nf/nf_network_submodule.f90
@@ -18,7 +18,7 @@
   use nf_embedding_layer, only: embedding_layer
   use nf_layernorm_layer, only: layernorm_layer
   use nf_layer, only: layer
-  use nf_layer_constructors, only: conv1d, conv2d, dense, flatten, input, maxpool1d, maxpool2d, reshape
+  use nf_layer_constructors, only: flatten
   use nf_loss, only: quadratic
   use nf_optimizers, only: optimizer_base_type, sgd
   use nf_parallel, only: tile_indices
diff --git a/test/test_conv1d_layer.f90 b/test/test_conv1d_layer.f90
index 81d03c1f..b80b520b 100644
--- a/test/test_conv1d_layer.f90
+++ b/test/test_conv1d_layer.f90
@@ -1,7 +1,7 @@
 program test_conv1d_layer
 
   use iso_fortran_env, only: stderr => error_unit
-  use nf, only: conv1d, input, layer
+  use nf, only: conv, input, layer
   use nf_input2d_layer, only: input2d_layer
 
   implicit none
@@ -12,7 +12,7 @@ program test_conv1d_layer
   real, parameter :: tolerance = 1e-7
   logical :: ok = .true.
 
-  conv1d_layer = conv1d(filters, kernel_size)
+  conv1d_layer = conv(filters, kernel_size)
 
   if (.not. conv1d_layer % name == 'conv1d') then
     ok = .false.
@@ -52,7 +52,7 @@ program test_conv1d_layer
   sample_input = 0
 
   input_layer = input(1, 3)
-  conv1d_layer = conv1d(filters, kernel_size)
+  conv1d_layer = conv(filters, kernel_size)
   call conv1d_layer % init(input_layer)
 
   select type(this_layer => input_layer % p); type is(input2d_layer)
diff --git a/test/test_conv1d_network.f90 b/test/test_conv1d_network.f90
index 5a353cf9..88289ab4 100644
--- a/test/test_conv1d_network.f90
+++ b/test/test_conv1d_network.f90
@@ -1,7 +1,7 @@
 program test_conv1d_network
 
   use iso_fortran_env, only: stderr => error_unit
-  use nf, only: conv1d, input, network, dense, sgd, maxpool1d
+  use nf, only: conv, input, network, dense, sgd, maxpool
 
   implicit none
 
@@ -12,8 +12,8 @@ program test_conv1d_network
   ! 3-layer convolutional network
   net = network([ &
     input(3, 32), &
-    conv1d(filters=16, kernel_size=3), &
-    conv1d(filters=32, kernel_size=3) &
+    conv(filters=16, kernel_width=3), &
+    conv(filters=32, kernel_width=3) &
   ])
 
   if (.not. size(net % layers) == 3) then
@@ -49,8 +49,8 @@ program test_conv1d_network
 
   cnn = network([ &
     input(1, 5), &
-    conv1d(filters=1, kernel_size=3), &
-    conv1d(filters=1, kernel_size=3), &
+    conv(filters=1, kernel_width=3), &
+    conv(filters=1, kernel_width=3), &
     dense(1) &
   ])
 
@@ -86,9 +86,9 @@ program test_conv1d_network
 
   cnn = network([ &
     input(1, 8), &
-    conv1d(filters=1, kernel_size=3), &
-    maxpool1d(pool_size=2), &
-    conv1d(filters=1, kernel_size=3), &
+    conv(filters=1, kernel_width=3), &
+    maxpool(pool_width=2, stride=2), &
+    conv(filters=1, kernel_width=3), &
     dense(1) &
   ])
 
@@ -121,9 +121,9 @@ program test_conv1d_network
 
   cnn = network([ &
     input(1, 12), &
-    conv1d(filters=1, kernel_size=3), & ! 1x12x12 input, 1x10x10 output
-    maxpool1d(pool_size=2), & ! 1x10x10 input, 1x5x5 output
-    conv1d(filters=1, kernel_size=3), & ! 1x5x5 input, 1x3x3 output
+    conv(filters=1, kernel_width=3), & ! 1x12x12 input, 1x10x10 output
+    maxpool(pool_width=2, stride=2), & ! 1x10x10 input, 1x5x5 output
+    conv(filters=1, kernel_width=3), & ! 1x5x5 input, 1x3x3 output
     dense(9) & ! 9 outputs
   ])
diff --git a/test/test_conv2d_layer.f90 b/test/test_conv2d_layer.f90
index 10a14c5e..2d5868b9 100644
--- a/test/test_conv2d_layer.f90
+++ b/test/test_conv2d_layer.f90
@@ -1,7 +1,7 @@
 program test_conv2d_layer
 
   use iso_fortran_env, only: stderr => error_unit
-  use nf, only: conv2d, input, layer
+  use nf, only: conv, input, layer
   use nf_input3d_layer, only: input3d_layer
 
   implicit none
@@ -12,7 +12,7 @@ program test_conv2d_layer
   real, parameter :: tolerance = 1e-7
   logical :: ok = .true.
 
-  conv_layer = conv2d(filters, kernel_size)
+  conv_layer = conv(filters, kernel_size, kernel_size)
 
   if (.not. conv_layer % name == 'conv2d') then
     ok = .false.
@@ -52,7 +52,7 @@ program test_conv2d_layer
   sample_input = 0
 
   input_layer = input(1, 3, 3)
-  conv_layer = conv2d(filters, kernel_size)
+  conv_layer = conv(filters, kernel_size, kernel_size)
   call conv_layer % init(input_layer)
 
   select type(this_layer => input_layer % p); type is(input3d_layer)
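
The tests above call `conv` positionally; generic resolution still works because the specifics differ in the number and type of arguments (a third integer selects the 2-d variant, while a third `activation_function` argument would still select the 1-d one). A small sketch under that assumption (not part of the patch):

```fortran
program positional_resolution_sketch
  use nf, only: conv, layer
  implicit none
  integer, parameter :: filters = 32, kernel_size = 3
  type(layer) :: a, b

  a = conv(filters, kernel_size)               ! two integers   -> conv1d
  b = conv(filters, kernel_size, kernel_size)  ! three integers -> conv2d

end program positional_resolution_sketch
```
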
diff --git a/test/test_conv2d_network.f90 b/test/test_conv2d_network.f90
index 73c4595a..c293a1d2 100644
--- a/test/test_conv2d_network.f90
+++ b/test/test_conv2d_network.f90
@@ -1,7 +1,7 @@
 program test_conv2d_network
 
   use iso_fortran_env, only: stderr => error_unit
-  use nf, only: conv2d, input, network, dense, sgd, maxpool2d
+  use nf, only: conv, input, network, dense, sgd, maxpool
 
   implicit none
 
@@ -12,8 +12,8 @@ program test_conv2d_network
   ! 3-layer convolutional network
   net = network([ &
     input(3, 32, 32), &
-    conv2d(filters=16, kernel_size=3), &
-    conv2d(filters=32, kernel_size=3) &
+    conv(filters=16, kernel_width=3, kernel_height=3), &
+    conv(filters=32, kernel_width=3, kernel_height=3) &
   ])
 
   if (.not. size(net % layers) == 3) then
@@ -49,8 +49,8 @@ program test_conv2d_network
 
   cnn = network([ &
     input(1, 5, 5), &
-    conv2d(filters=1, kernel_size=3), &
-    conv2d(filters=1, kernel_size=3), &
+    conv(filters=1, kernel_width=3, kernel_height=3), &
+    conv(filters=1, kernel_width=3, kernel_height=3), &
     dense(1) &
   ])
 
@@ -86,9 +86,9 @@ program test_conv2d_network
 
   cnn = network([ &
     input(1, 8, 8), &
-    conv2d(filters=1, kernel_size=3), &
-    maxpool2d(pool_size=2), &
-    conv2d(filters=1, kernel_size=3), &
+    conv(filters=1, kernel_width=3, kernel_height=3), &
+    maxpool(pool_width=2, pool_height=2, stride=2), &
+    conv(filters=1, kernel_width=3, kernel_height=3), &
     dense(1) &
   ])
 
@@ -121,9 +121,9 @@ program test_conv2d_network
 
   cnn = network([ &
     input(1, 12, 12), &
-    conv2d(filters=1, kernel_size=3), & ! 1x12x12 input, 1x10x10 output
-    maxpool2d(pool_size=2), & ! 1x10x10 input, 1x5x5 output
-    conv2d(filters=1, kernel_size=3), & ! 1x5x5 input, 1x3x3 output
+    conv(filters=1, kernel_width=3, kernel_height=3), & ! 1x12x12 input, 1x10x10 output
+    maxpool(pool_width=2, pool_height=2, stride=2), & ! 1x10x10 input, 1x5x5 output
+    conv(filters=1, kernel_width=3, kernel_height=3), & ! 1x5x5 input, 1x3x3 output
     dense(9) & ! 9 outputs
   ])
diff --git a/test/test_get_set_network_params.f90 b/test/test_get_set_network_params.f90
index 71963a1c..f2a3b6a8 100644
--- a/test/test_get_set_network_params.f90
+++ b/test/test_get_set_network_params.f90
@@ -1,6 +1,6 @@
 program test_get_set_network_params
   use iso_fortran_env, only: stderr => error_unit
-  use nf, only: conv2d, dense, flatten, input, maxpool2d, network
+  use nf, only: conv, dense, flatten, input, network
   implicit none
   type(network) :: net
   logical :: ok = .true.
@@ -10,7 +10,7 @@
   ! First test get_num_params()
   net = network([ &
     input(3, 5, 5), & ! 5 x 5 image with 3 channels
-    conv2d(filters=2, kernel_size=3), & ! kernel shape [2, 3, 3, 3], output shape [2, 3, 3], 56 parameters total
+    conv(filters=2, kernel_width=3, kernel_height=3), & ! kernel shape [2, 3, 3, 3], output shape [2, 3, 3], 56 parameters total
     flatten(), &
     dense(4) & ! weights shape [72], biases shape [4], 76 parameters total
   ])
@@ -46,7 +46,7 @@ program test_get_set_network_params
   ! Finally, test set_params() and get_params() for a conv2d layer
   net = network([ &
     input(1, 3, 3), &
-    conv2d(filters=1, kernel_size=3) &
+    conv(filters=1, kernel_width=3, kernel_height=3) &
   ])
 
   call net % set_params(test_params_conv2d)
diff --git a/test/test_insert_flatten.f90 b/test/test_insert_flatten.f90
index 18e41b81..3437b746 100644
--- a/test/test_insert_flatten.f90
+++ b/test/test_insert_flatten.f90
@@ -1,7 +1,7 @@
 program test_insert_flatten
 
   use iso_fortran_env, only: stderr => error_unit
-  use nf, only: network, input, conv2d, maxpool2d, flatten, dense, reshape
+  use nf, only: network, input, conv, maxpool, flatten, dense, reshape
 
   implicit none
 
@@ -20,7 +20,7 @@ program test_insert_flatten
 
   net = network([ &
     input(3, 32, 32), &
-    conv2d(filters=1, kernel_size=3), &
+    conv(filters=1, kernel_width=3, kernel_height=3), &
     dense(10) &
   ])
 
@@ -33,14 +33,14 @@ program test_insert_flatten
 
   net = network([ &
     input(3, 32, 32), &
-    conv2d(filters=1, kernel_size=3), &
-    maxpool2d(pool_size=2, stride=2), &
+    conv(filters=1, kernel_width=3, kernel_height=3), &
+    maxpool(pool_width=2, pool_height=2, stride=2), &
     dense(10) &
   ])
 
   if (.not. net % layers(4) % name == 'flatten') then
     ok = .false.
-    write(stderr, '(a)') 'flatten layer inserted after maxpool2d.. failed'
+    write(stderr, '(a)') 'flatten layer inserted after maxpool.. failed'
   end if
 
   net = network([ &
diff --git a/test/test_locally_connected1d_layer.f90 b/test/test_locally_connected1d_layer.f90
index e8a30cfc..cde0a965 100644
--- a/test/test_locally_connected1d_layer.f90
+++ b/test/test_locally_connected1d_layer.f90
@@ -1,7 +1,7 @@
 program test_locally_connected1d_layer
 
   use iso_fortran_env, only: stderr => error_unit
-  use nf, only: locally_connected1d, input, layer
+  use nf, only: locally_connected, input, layer
   use nf_input2d_layer, only: input2d_layer
 
   implicit none
@@ -12,7 +12,7 @@ program test_locally_connected1d_layer
   real, parameter :: tolerance = 1e-7
   logical :: ok = .true.
 
-  locally_connected_1d_layer = locally_connected1d(filters, kernel_size)
+  locally_connected_1d_layer = locally_connected(filters, kernel_size)
 
   if (.not. locally_connected_1d_layer % name == 'locally_connected1d') then
     ok = .false.
@@ -52,7 +52,7 @@ program test_locally_connected1d_layer
   sample_input = 0
 
   input_layer = input(1, 3)
-  locally_connected_1d_layer = locally_connected1d(filters, kernel_size)
+  locally_connected_1d_layer = locally_connected(filters, kernel_size)
   call locally_connected_1d_layer % init(input_layer)
 
   select type(this_layer => input_layer % p); type is(input2d_layer)
@@ -62,7 +62,6 @@ program test_locally_connected1d_layer
 
   call locally_connected_1d_layer % forward(input_layer)
   call locally_connected_1d_layer % get_output(output)
-
   if (.not. all(abs(output) < tolerance)) then
     ok = .false.
     write(stderr, '(a)') 'locally_connected1d layer with zero input and sigmoid function must forward to all 0.5.. failed'
diff --git a/test/test_maxpool1d_layer.f90 b/test/test_maxpool1d_layer.f90
index 023a2c33..f3765686 100644
--- a/test/test_maxpool1d_layer.f90
+++ b/test/test_maxpool1d_layer.f90
@@ -1,7 +1,7 @@
 program test_maxpool1d_layer
 
   use iso_fortran_env, only: stderr => error_unit
-  use nf, only: maxpool1d, input, layer
+  use nf, only: maxpool, input, layer
   use nf_input2d_layer, only: input2d_layer
   use nf_maxpool1d_layer, only: maxpool1d_layer
 
@@ -16,7 +16,7 @@ program test_maxpool1d_layer
   integer :: i
   logical :: ok = .true., gradient_ok = .true.
 
-  maxpool_layer = maxpool1d(pool_size)
+  maxpool_layer = maxpool(pool_width=pool_size, stride=stride)
 
   if (.not. maxpool_layer % name == 'maxpool1d') then
     ok = .false.
diff --git a/test/test_maxpool2d_layer.f90 b/test/test_maxpool2d_layer.f90
index 5983a217..29a56b57 100644
--- a/test/test_maxpool2d_layer.f90
+++ b/test/test_maxpool2d_layer.f90
@@ -1,7 +1,7 @@
 program test_maxpool2d_layer
 
   use iso_fortran_env, only: stderr => error_unit
-  use nf, only: maxpool2d, input, layer
+  use nf, only: maxpool, input, layer
   use nf_input3d_layer, only: input3d_layer
   use nf_maxpool2d_layer, only: maxpool2d_layer
 
@@ -16,7 +16,7 @@ program test_maxpool2d_layer
   integer :: i, j
   logical :: ok = .true., gradient_ok = .true.
 
-  maxpool_layer = maxpool2d(pool_size)
+  maxpool_layer = maxpool(pool_width=pool_size, pool_height=pool_size, stride=stride)
 
   if (.not. maxpool_layer % name == 'maxpool2d') then
     ok = .false.