Conversation
The additive term is responsible for predicting; the overall model handles the likelihoods. Training works. TODO:
- refactor priors
- have the terms and the model save their state to disk
- port R2 calculation
- API + docs
- plotting
- .tl namespace
There is now only one class per prior, without the split into Pyro priors and wrapper classes. This counteracts the proliferation of wrapper classes and allows us to use the same architecture for priors and terms.
This makes the likelihoods stateful and responsible for calculating whatever statistics they need and for transforming the model prediction. This fully decouples the likelihoods from the preprocessing and is a prerequisite for a unified R2 calculation that is decoupled from how predictions are computed: the NB and Bernoulli likelihoods now also shift the prediction, so a zero factor is transformed to the null model and no longer produces negative R2 values. The R2 calculation can thus operate directly on the full prediction without any knowledge of how that prediction was obtained.
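As a rough illustration of the shift (class and method names below are made up, not the actual MofaFlex API): a stateful likelihood records per-feature statistics once and uses them to transform predictions so that a prediction of zero corresponds to the null model.

```python
import numpy as np

# Hypothetical sketch, not the actual MofaFlex API.
class NegativeBinomialLikelihood:
    def fit(self, y: np.ndarray):
        # Per-feature log means: the statistic this likelihood needs.
        self._log_mean = np.log(y.mean(axis=0) + 1e-8)

    def transform_prediction(self, prediction: np.ndarray) -> np.ndarray:
        # Shift so that a zero prediction maps to the per-feature mean,
        # i.e. the null model; a zero factor then cannot yield negative R2.
        return np.exp(prediction + self._log_mean)
```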
R2 is now being calculated for the entire model, for each additive term, and for each component (e.g. factor) of each additive term.
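Since every level — full model, single term, or single component — yields a prediction of the same shape as the data, one function can score all of them. A minimal sketch (the `r2` helper is illustrative, not the actual implementation):

```python
import numpy as np

def r2(y, prediction):
    # R2 computed directly from a prediction, whatever produced it.
    ss_res = np.sum((y - prediction) ** 2)
    ss_tot = np.sum((y - y.mean(axis=0)) ** 2)
    return 1 - ss_res / ss_tot

rng = np.random.default_rng(0)
z, w = rng.normal(size=(100, 3)), rng.normal(size=(3, 20))
y = z @ w + 0.1 * rng.normal(size=(100, 20))

print(r2(y, z @ w))                    # the entire (single-term) model
print(r2(y, np.outer(z[:, 0], w[0])))  # one component (factor) of the term
```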
this will make it possible to construct a term outside of the main MOFAFLEX class
move the validation of subclasses and subclass registry handling to a class decorator
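A sketch of the pattern (the `register` decorator and its validation check are hypothetical): validation runs once at class definition time, and the decorator records the subclass in a module-level registry.

```python
_REGISTRY: dict[str, type] = {}

def register(cls):
    # Validate the subclass once, at definition time.
    if not hasattr(cls, "name"):
        raise TypeError(f"{cls.__name__} must define a 'name' attribute")
    _REGISTRY[cls.name] = cls
    return cls

@register
class HorseshoePrior:
    name = "horseshoe"
```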
return read-only mappings from the public API
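One standard way to do this in Python (a sketch, not necessarily what is used here) is `types.MappingProxyType`, which gives callers a live but immutable view of the internal dict:

```python
from types import MappingProxyType

class Model:
    def __init__(self):
        self._terms = {"mofaflex": object()}

    @property
    def terms(self):
        # Read-only view: item assignment on it raises TypeError.
        return MappingProxyType(self._terms)
```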
apparently the ability to use class method properties was removed in Python 3.13
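For reference, the affected pattern: chaining the two decorators worked from Python 3.9, was deprecated in 3.11, and was removed in 3.13 (the attribute name below is made up).

```python
class Prior:
    @classmethod
    @property
    def registry_key(cls):  # no longer a class-level property on 3.13
        return cls.__name__.lower()
```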
With the new API, it is now possible to use different configurations for the same prior in different groups/views. The dynamic API now handles that case by merging the results of all priors of the same class.
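Roughly, the merging could look like this (a sketch with hypothetical names): the same call is dispatched to every instance of the prior class, and the per-group/per-view results are combined into a single mapping.

```python
def dispatch_merged(priors, method, *args, **kwargs):
    # priors: all instances of the same prior class, e.g. one per group/view.
    merged = {}
    for prior in priors:
        merged.update(getattr(prior, method)(*args, **kwargs))
    return merged
```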
Warping is now applied to all groups of the GP prior. If warping is not required for some groups, a separate GaussianProcess instance with warping turned off can be used.
All covariates are now handled through get_datasets in the priors; the MofaFlex term class takes care of constructing CovariateDatasets for factor priors and of storing the covariates for weight priors. All covariates are now passed as kwargs to model and guide.
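This relies on `pyro.infer.SVI.step` forwarding its arguments to both model and guide; a toy sketch (the model itself is made up):

```python
import pyro
import pyro.distributions as dist
import torch

def model(data, *, covariates):
    loc = pyro.param("loc", torch.zeros(()))
    with pyro.plate("obs", data.shape[0]):
        pyro.sample("y", dist.Normal(loc + covariates.sum(-1), 1.0), obs=data)

def guide(data, *, covariates):
    pass  # no latent sample sites in this toy model

svi = pyro.infer.SVI(model, guide, pyro.optim.Adam({"lr": 0.01}),
                     loss=pyro.infer.Trace_ELBO())
svi.step(torch.randn(8), covariates=torch.randn(8, 2))  # kwargs reach both
```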
This improves sparsity with the spike-and-slab prior, as well as the quality of the results.
annotations, re-run citeseq tutorial
initialization now happens in a device context
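Assuming this refers to PyTorch >= 2.0, where `torch.device` works as a context manager, a sketch of the benefit: tensors created inside the block land directly on the target device instead of being created on the CPU and copied over.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
with torch.device(device):
    loc = torch.zeros(10, 5)   # created directly on `device`
    scale = torch.ones(10, 5)
```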
fewer warnings
pass them to term hooks
handle methods called on an untrained model that only work on a trained model, and vice versa
Codecov Report ❌

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main     #163      +/-   ##
==========================================
- Coverage   90.05%   89.01%   -1.05%
==========================================
  Files          53       53
  Lines        4758     5080     +322
==========================================
+ Hits         4285     4522     +237
- Misses        473      558      +85
```
sphinx-tabs is not yet compatible with Sphinx 9, see executablebooks/sphinx-tabs#209
```python
for view_name, view in group.items():
    if view.numel() == 0:  # can occur in the last batch of an epoch if the batch is small
        continue
    prediction = None
```
I think this might be simplified a bit; I could not wrap my head around the multi-nested structure. Maybe something like:
```python
prediction = 0
has_prediction = False
for term in self._terms.values():
    try:
        term_prediction = term[group_name][view_name]
    except KeyError:
        continue
    prediction += term_prediction
    has_prediction = True
if has_prediction:
    ...
```
I'm sorry, I don't really see how that's different from the existing code, except you're using an additional boolean instead of the `None` sentinel. Starting with `prediction = 0` will result in an additional element-wise addition, which may be expensive.
I just assumed the 0 instead of None would be safer wrt arithmetic errors.
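For what it's worth, a quick way to see the difference (assuming PyTorch tensors): the `0` start pays for one extra element-wise addition that the `None` sentinel avoids.

```python
import torch

t = torch.randn(1000, 1000)

prediction = 0
prediction = prediction + t  # scalar-add kernel: allocates a new tensor

prediction = None
prediction = t if prediction is None else prediction + t  # no extra op
```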
```python
        return lr_func

    def on_train_start(self, data: MofaFlexDataset):
```
I wonder if performing some sort of validation / error handling here would save some time, compared to errors being raised later or propagated during training. Maybe something like:

```python
def _validate_terms(self, data):
    for term in self._terms.values():
        term.validate(data, self._likelihoods)
```

before `on_train_start`, and then call `self._validate_terms(data)` inside `on_train_start`. Not sure what could go wrong, but I remember the issue we had where all the processing ran for a while, and then the model config (now term config) caused problems during training.
Validation of arguments should be performed in the constructor as much as possible. But I don't think it makes sense to implement validation ourselves; it's probably better to use something like Pydantic for this. That is out of scope for this PR, though, IMHO.
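For illustration, constructor-time validation with something like Pydantic (the `TermConfig` model and its fields are made up):

```python
from pydantic import BaseModel, PositiveInt

class TermConfig(BaseModel):
    n_factors: PositiveInt
    prior: str = "normal"

TermConfig(n_factors=10)  # ok
TermConfig(n_factors=-1)  # raises pydantic.ValidationError at construction
```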
This is a major refactor that completely modularizes the code base and adds support for multiple additive terms, i.e. models of the form $Y = Z_1 W_1 + Z_2 W_2 + X$. Currently, only one type of term is implemented: the MofaFlex term, which takes the form $Y = Z W$, but additional term types can easily be added by subclassing the `Term` class.

This builds on previous work modularizing the priors and introducing a dynamic API. Each term provides its own API, which can be accessed from the user-facing model object as e.g. `model.terms[term_name].get_factors`. To simplify the common special case of only a single additive term, in that situation the user-facing model object forwards requests for any unknown attributes to its single term.

There are still some more unit tests needed, and the Getting started tutorial needs a major overhaul, but that can all be done incrementally after this is merged: since this touches every single part of the code base, it's blocking all other work, so it's time to merge and get on with it.