Backend for creating Flow Matching (FM) models, with Schedulers, Paths and Integrators available to make creating new Flow Matching models easier.
Flow matching models are a type of generative model that models densities. At its core, an FM model is a physics-based model that transports points from t=0 to t=1: points at t=0 are sampled from some known distribution, while points at t=1 follow the target distribution that the model learns.
The points are moved from t=0 to t=1 along paths. Since both the origin and the endpoint of such a path are distributions, at each time t along the path there is another distribution, hence the name probability path. Outside of a few specific cases, though, the probabilities along this path aren't tangible unless a nontrivial integral is solved.
The actual neural network part of the model is the vector field that pushes points from t=0 to t=1 along the path at each timestep.
This section explains the building blocks of the FM model. All of it basically follows the Flow Matching Guide and Code book/paper from Meta.
A probability path is the name of the path a point takes while travelling from t=0 to t=1. This path is ideally affine (and this is how it is implemented), because straight, simple paths are easy to integrate with low error. The affine path is defined by a function that decides where a point should be at time t:

$$X_t = \alpha_t X_1 + \sigma_t X_0$$

where $X_1$ is usually sampled from the actual dataset, while $X_0$ is usually sampled from Gaussian noise.
Since we're learning vector fields, we're also interested in velocities along these probability paths. And since velocity is just the time derivative of position, differentiating the path gives:

$$\dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t X_0$$

The coefficients $\alpha_t$ and $\sigma_t$ (and their derivatives) are supplied by a Scheduler.
The simplest scheduler comes from the optimal transport theorem and is therefore named OTScheduler, in which $\alpha_t = t$ and $\sigma_t = 1 - t$. The Polynomial Scheduler is just the OT one except t is exponentiated by a parameter $n$, meaning $\alpha_t = t^n$ and $\sigma_t = 1 - t^n$.
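To make the path/scheduler relationship concrete, here is a minimal PyTorch sketch. The class and method names are illustrative assumptions, not necessarily the backend's actual API; only the OT and polynomial schedulers themselves are named above.

```python
import torch

class OTScheduler:
    """Optimal-transport scheduler: alpha_t = t, sigma_t = 1 - t."""
    def alpha(self, t): return t
    def sigma(self, t): return 1 - t
    def d_alpha(self, t): return torch.ones_like(t)
    def d_sigma(self, t): return -torch.ones_like(t)

class PolynomialScheduler(OTScheduler):
    """Same as OT, but t is exponentiated by a parameter n."""
    def __init__(self, n: float):
        self.n = n
    def alpha(self, t): return t ** self.n
    def sigma(self, t): return 1 - t ** self.n
    def d_alpha(self, t): return self.n * t ** (self.n - 1)
    def d_sigma(self, t): return -self.n * t ** (self.n - 1)

def sample_path(scheduler, x0, x1, t):
    """Point on the affine path and its velocity at time t."""
    x_t = scheduler.alpha(t) * x1 + scheduler.sigma(t) * x0
    dx_t = scheduler.d_alpha(t) * x1 + scheduler.d_sigma(t) * x0
    return x_t, dx_t

# Example: batch of 2-D points on the OT path
x0, x1 = torch.randn(8, 2), torch.randn(8, 2)
t = torch.rand(8, 1)
x_t, dx_t = sample_path(OTScheduler(), x0, x1, t)
```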
The actual neural network part of the FM model is the velocity vector field. It learns to assign each point, at some time, a velocity along all of its dimensions. Therefore the input to this type of network is some point together with a time value, and the output has the same dimensionality as the point.
There are two constraints the model has to meet: one is simple enough, and required for sampling and likelihood calculation later on; the other is a bit more troubling.
The first constraint is that the overall network must in the end be differentiable w.r.t. its input, which more or less any neural network satisfies.
The other constraint is that the network needs to be aware of time as a concept in some way. Usually time is added as another dimension to the input vector, but this is not the only option (e.g. time-dependent CNNs). So keep this in mind when making networks.
Keep in mind that network outputs are just speeds, not actual points in space. To get actual points you need to integrate!
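As a concrete illustration of the time constraint, here is a minimal sketch of a velocity network that simply concatenates t to the input vector. The architecture and the name VelocityMLP are illustrative assumptions, not the backend's actual classes.

```python
import torch
import torch.nn as nn

class VelocityMLP(nn.Module):
    """Maps (point, time) -> a speed with the same dimensionality as the point."""
    def __init__(self, dim: int, hidden: int = 128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim + 1, hidden),  # +1 input feature for the time value
            nn.SiLU(),
            nn.Linear(hidden, hidden),
            nn.SiLU(),
            nn.Linear(hidden, dim),      # one speed per input component
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # x: (batch, dim), t: (batch, 1); the output is a velocity, not a position
        return self.net(torch.cat([x, t], dim=-1))
```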
The loss is super simple: the network needs to learn the vector field by assigning a speed to all the components of the input tensor:

$$\mathcal{L}(\theta) = \mathbb{E}_{t,\,X_0,\,X_1}\left\|\, u_\theta(X_t, t) - \dot{X}_t \,\right\|^2$$

which is basically just MSE on the predicted speed.
This is a super simple convex loss objective, so the network converges nicely. Don't be too scared when the loss isn't going down for a long time: the loss asks the network to match the speeds almost exactly, which is near impossible with too many points and dimensions. The nice thing is that the good ol' rule applies here: more iterations means a better representation.
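Put together, one training step might look roughly like the sketch below. It reuses the hypothetical OTScheduler, sample_path and VelocityMLP from the earlier sketches, and the random batch is just a stand-in for real data.

```python
import torch

def fm_loss(model, scheduler, x1):
    """Flow Matching loss: MSE between predicted and target speeds."""
    x0 = torch.randn_like(x1)                          # noise sample at t=0
    t = torch.rand(x1.shape[0], 1, device=x1.device)   # uniform time in [0, 1]
    x_t, dx_t = sample_path(scheduler, x0, x1, t)      # point on the path and its true speed
    return ((model(x_t, t) - dx_t) ** 2).mean()

model = VelocityMLP(dim=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
for step in range(1000):
    x1 = torch.randn(256, 2)       # stand-in for a real data batch
    optimizer.zero_grad()
    loss = fm_loss(model, OTScheduler(), x1)
    loss.backward()
    optimizer.step()
```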
Since FM is a generative model at its core, sampling points is obviously very important. Like we said in the velocity fields section, the network produces speeds, so integrating the speed from t=0 to t=1 gives the actual position of the point in space. An intuitive way to look at the integral is through Euler's method for function approximation, taking small steps of size h:

$$X_{t+h} = X_t + h\, u_\theta(X_t, t)$$
So for small step sizes the error stays low; in practice the midpoint method is really good.
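For reference, one midpoint step first takes a half step with the current speed and then a full step with the speed evaluated at that halfway point:

$$X_{t+h} = X_t + h\, u_\theta\!\left(X_t + \tfrac{h}{2}\, u_\theta(X_t, t),\; t + \tfrac{h}{2}\right)$$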
The framework has an Integrator class that takes as a parameter the network that produces the velocity field. The integrator solves the integral with given parameters and produces points sampled along the path at anchors given through the t parameter.
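The real Integrator's signature isn't reproduced here; the sketch below just shows what the integration itself boils down to, using the midpoint rule and the hypothetical VelocityMLP from above.

```python
import torch

def integrate(velocity_net, x0, t_grid):
    """Push points along the learned field with the midpoint rule,
    returning their positions at every anchor in t_grid."""
    xs, x = [x0], x0
    for t0, t1 in zip(t_grid[:-1], t_grid[1:]):
        h = t1 - t0
        t_col = torch.full((x.shape[0], 1), float(t0))
        x_half = x + 0.5 * h * velocity_net(x, t_col)       # half step with the current speed
        x = x + h * velocity_net(x_half, t_col + 0.5 * h)   # full step with the midpoint speed
        xs.append(x)
    return torch.stack(xs)

# Sampling: start from Gaussian noise at t=0 and integrate to t=1
with torch.no_grad():
    path_points = integrate(model, torch.randn(64, 2), torch.linspace(0.0, 1.0, 11))
    samples = path_points[-1]   # generated data at t=1
```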
But FM doesn't only generate new data; it can also calculate likelihoods along the path. In reality, the only time at which the likelihood can be calculated is t=0, since we know the exact likelihood function only at that time. So the likelihood computation consists of taking a point at t=1, integrating in reverse from t=1 to t=0, taking the produced point's position at t=0, and calculating the likelihood for it there (together with a correction term accumulated along the way, described next).
In the backend, the likelihood is governed by a dynamics function that follows both the divergence and the position of the point in time. The divergence is super important because of the Mass Conservation Law (MCL) from physics, which states that the change of mass should always be zero: mass is neither created nor destroyed.
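The divergence itself can be computed with autograd. Here is a brute-force sketch (one backward pass per dimension, so only reasonable for the low-dimensional toy setting of these sketches; high-dimensional implementations typically use a Hutchinson trace estimator instead):

```python
import torch

def divergence(velocity_net, x, t):
    """Exact divergence (trace of the Jacobian) of the velocity field at (x, t)."""
    x = x.detach().requires_grad_(True)
    v = velocity_net(x, t)
    div = torch.zeros(x.shape[0], device=x.device)
    for i in range(x.shape[1]):
        grad_i = torch.autograd.grad(v[:, i].sum(), x, retain_graph=True)[0]
        div = div + grad_i[:, i]
    return div.detach()
```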
In our context the mass part of the MCL is replaced with probability, and the law becomes the continuity equation, which says probability is neither created nor destroyed over time.
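In symbols, these are the standard continuity-equation and change-of-variables results that the reverse integration relies on:

$$\frac{\partial p_t(x)}{\partial t} + \nabla \cdot \big(p_t(x)\, u_\theta(x, t)\big) = 0$$

Following a point along its trajectory, this becomes the instantaneous change of variables:

$$\frac{d}{dt} \log p_t(X_t) = -\nabla \cdot u_\theta(X_t, t)$$

so the log-likelihood of a data point is the known log-density at t=0 minus the divergence accumulated along the path:

$$\log p_1(X_1) = \log p_0(X_0) - \int_0^1 \nabla \cdot u_\theta(X_t, t)\, dt$$

As a rough sketch, using the hypothetical divergence helper above, standard-Gaussian noise at t=0, and plain Euler steps in reverse:

```python
import math
import torch

def log_likelihood(velocity_net, x1, n_steps=100):
    """Reverse-integrate from t=1 to t=0 while accumulating the divergence,
    then evaluate a standard-normal log-density at the resulting t=0 point."""
    x, h = x1, 1.0 / n_steps
    div_integral = torch.zeros(x1.shape[0], device=x1.device)
    for k in range(n_steps, 0, -1):
        t = torch.full((x.shape[0], 1), k * h)
        div_integral = div_integral + divergence(velocity_net, x, t) * h
        with torch.no_grad():
            x = x - h * velocity_net(x, t)   # Euler step backwards in time
    log_p0 = -0.5 * (x ** 2).sum(dim=1) - 0.5 * x.shape[1] * math.log(2 * math.pi)
    return log_p0 - div_integral
```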
Since