Skip to content

Conversation

RubensZimbres
Copy link
Contributor

This PR introduces a JAX implementation of the U-Net architecture, designed for image segmentation tasks. The implementation uses the modern flax.nnx API and follows the original U-Net paper's design. It includes the core model definition, a factory function for easy instantiation, and a complete example notebook that demonstrates its usage on a segmentation task.

@chapman20j
Copy link
Collaborator

Hi Rubens. Apologies for the late reply. I wanted to follow up on this PR. We're excited to add a Unet to the repo. Before we approve this PR:

  1. Could you please remove the other models from this PR (efficientnet and vae), since they're now in other PRs that have been accepted?
  2. Could you also remove the checkpoints for the trained Unet?
  3. What reference implementation you used for the Unet? We'd love to load from a checkpoint and create test cases for this implementation.

@RubensZimbres
Copy link
Contributor Author

Checkpoints, Efficientnet and VAE removed from PR. The implementation follows: https://arxiv.org/abs/1505.04597 (U-Net: Convolutional Networks for Biomedical Image Segmentation) with padded convolutions in its DoubleConv module, a modern implementation.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants