
Generic UNet implementation written in pure Julia, based on Flux.jl



This package provides a generic UNet implemented in Julia.

The package is built on top of Flux.jl, and can therefore be extended as needed.

julia> u = Unet()
  ConvDown(64, 64)
  ConvDown(128, 128)
  ConvDown(256, 256)
  ConvDown(512, 512)

  UNetConvBlock(1, 3)
  UNetConvBlock(3, 64)
  UNetConvBlock(64, 128)
  UNetConvBlock(128, 256)
  UNetConvBlock(256, 512)
  UNetConvBlock(512, 1024)
  UNetConvBlock(1024, 1024)

  UNetUpBlock(1024, 512)
  UNetUpBlock(1024, 256)
  UNetUpBlock(512, 128)
  UNetUpBlock(256, 64)

The default input channel dimension is 1, i.e. grayscale. To support images with a different number of channels, pass the channel count to Unet.

julia> u = Unet(3) # for RGB images

The spatial dimensions of the input should be powers of two. The input is a batch of size (width, height, channels, batch_size), for example (256, 256, channels, batch_size).
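To see why power-of-two sizes are convenient, the sketch below traces a spatial dimension through the downsampling path. It assumes that each of the four ConvDown stages shown above halves the width and height exactly; that is an assumption about the layers, not something this README states. `downsampled_sizes` is a hypothetical helper and is not part of UNet.jl.

```julia
# Hypothetical helper, not part of UNet.jl.
# Traces one spatial dimension through `stages` downsampling steps,
# ASSUMING each step halves the size exactly (stride-2 downsampling).
function downsampled_sizes(n::Integer; stages::Integer = 4)
    sizes = [n]
    for _ in 1:stages
        # An odd size cannot be halved and later doubled back to itself.
        iseven(sizes[end]) || error("size $(sizes[end]) cannot be halved evenly")
        push!(sizes, sizes[end] ÷ 2)
    end
    return sizes
end

# `ispow2` is from Base and checks the power-of-two condition directly.
ispow2(256) && println(downsampled_sizes(256))
```

Under that assumption, any size divisible by 16 would pass through four halvings cleanly; powers of two of at least 16 are the simplest sizes that do.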

The default output channel dimension is the input channel dimension. So, 1 for a Unet() and e.g. 3 for a Unet(3). The output channel dimension can be set by supplying a second argument:

julia> u = Unet(3, 5) # 3 input channels, 5 output channels.
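The channel defaults described above can be summarized in a small shape helper. This is only a sketch: `expected_output_size` is a hypothetical function, not part of UNet.jl, and it assumes the network preserves the spatial and batch dimensions, as the `size(u(r))` call in the GPU example below suggests.

```julia
# Hypothetical helper, not part of UNet.jl.
# Expected output size for an input of size (W, H, C, N).
# The output channel count defaults to the input channel count, mirroring
# how Unet(3) is described as producing 3 output channels.
function expected_output_size(input_size::NTuple{4,Int};
                              out_channels::Int = input_size[3])
    w, h, _, n = input_size
    return (w, h, out_channels, n)
end

expected_output_size((256, 256, 1, 1))                    # analogous to Unet()
expected_output_size((256, 256, 3, 2); out_channels = 5)  # analogous to Unet(3, 5)
```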

GPU Support

To run or train the model on the GPU, it is as simple as calling gpu on the model.

julia> u = Unet();

julia> u = gpu(u);

julia> r = gpu(rand(Float32, 256, 256, 1, 1));

julia> size(u(r))
(256, 256, 1, 1)


Training UNet is a breeze too.

You can define your own loss function, or use Flux's binary cross entropy implementation.

using UNet, Flux, Base.Iterators
import Flux.Losses.binarycrossentropy

device = gpu #cpu

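function loss(x, y)
    # Clamp predictions away from 0 so the logarithm in the loss stays finite.
    op = clamp.(u(x), 0.001f0, 1.0f0)
    # Binary cross entropy between the clamped predictions and the targets.
    binarycrossentropy(op, y)
end
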
u = Unet() |> device
w = rand(Float32, 256, 256, 1, 1) |> device
w′ = rand(Float32, 256, 256, 1, 1) |> device
rep = Iterators.repeated((w, w′), 10)

opt = ADAM()

Flux.train!(loss, Flux.params(u), rep, opt, cb = () -> @show(loss(w, w′)))
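For readers who want to define their own loss, here is a dependency-free sketch of what a clamped binary cross entropy computes. It is an illustration, not Flux's implementation: the name `clamped_bce` and the keyword arguments `lo`, `hi` and `ϵ` are assumptions made for this example, with `lo` and `hi` taken from the clamp bounds used above.

```julia
using Statistics: mean

# Hypothetical stand-alone loss, written without Flux for illustration.
# ŷ: predicted probabilities, y: targets in [0, 1], same shape as ŷ.
function clamped_bce(ŷ, y; lo = 0.001f0, hi = 1.0f0, ϵ = eps(Float32))
    # Keep predictions inside [lo, hi], as in the loss above.
    p = clamp.(ŷ, lo, hi)
    # ϵ guards against log(0) when p reaches the upper bound of 1.
    return mean(@. -y * log(p + ϵ) - (1 - y) * log(1 - p + ϵ))
end

ŷ = fill(0.5f0, 4, 4, 1, 1)
y = ones(Float32, 4, 4, 1, 1)
clamped_bce(ŷ, y)  # close to log(2), the cost of a 50/50 guess on a positive target
```

The lower clamp bounds the penalty for a confidently wrong prediction on a positive target at roughly `-log(0.001)`, which keeps the loss finite during early training.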

Further Reading

The package is an implementation of the paper, and all credits of the model itself go to the respective authors.


