Skip to content

Support multiple additive terms#163

Merged
arberqoku merged 61 commits intomainfrom
additive_terms
Jan 19, 2026
Merged

Support multiple additive terms#163
arberqoku merged 61 commits intomainfrom
additive_terms

Conversation

@ilia-kats
Copy link
Collaborator

This is a major refactor that completely modularizes the code base and adds support for multiple additive terms, i.e. models of the form $Y = Z_1 W_1 + Z_2 + W_2 + X$. Currently, only one type of term is implemented: The MofaFlex term, which takes the form $Y = Z W$, but additional term types can easily be added by subclassing the Term class.

This builds on prevous work modularizing the priors and introducing a dynamic API. Each term provides its own API, which can be accessed from the user-facing model object as e.g. model.terms[term_name].get_factors. To simplify the common special case of only a single additive term, in that situation the user-facing model objects forwards requests for any unknown attributes to its single term.

There are still some more unit tests needed and the Getting started tutorial needs a major overhaul, but that can all be done incrementally after this is merged: Since this touches every single part of the code base, it's blocking all other work, so it's time to merge and get on with it.

The additive term is responsible for predicting, the overall model
handles the likelihoods. Training works.

TODO:
- refactor priors
- have the terms and the model save their state to disk
- port R2 calculation
- API + docs
- plotting
- .tl namespace
There is now only one class per prior, without the split into Pyro
priors and wrapper classes. This counteracts the proliferation of
wrapper classes and allows us to use the same architecture for priors
and terms.
This makes the likelihoods stateful and responsible for calculating
whatever statistics they need and transforming the model prediction.
This fully decouples the likelihoods from the preprocessing and is a
prerequisite for a unified R2 calculation that is decoupled from
prediction calculation: Now also the NB and Bernoulli likelihoods
perform a shift of the prediction. Thus, a zero factor will now be
transformed to the null model and no longer produce negative R2 values.
The R2 calculation can thus operate directly on the full prediction
without any knowledge of how that prediction was obtained.
R2 is now being calculated for the entire model, for each additive term,
and for each component (e.g. factor) of each additive term.
this will make it possible to construct a term outside of the main
MOFAFLEX class
move the validation of subclasses and subclass registry handling to a
class decorator
return read-only mappings from the public API
apparently the ability to use class method properties was removed in
Python 3.13
With the new API, it is now possible to use different configurations for
the same prior in different groups/views. The dynamic API now handles
that case by merging the results of all priors of the same class.
Warping is now applied to all groups of the GP prior. If warping is not
required for some groups, a separate GaussianProcess instances with
warping turned off can be used.
all covariates are now handled through get_datasets in the priors, the
MofaFlex term class takes care of constructing CovariateDatasets for
factor priors and storing the covariates for weight priors. All
covariates are now passed as kwargs to model and guide.
@codecov
Copy link

codecov bot commented Jan 16, 2026

Codecov Report

❌ Patch coverage is 91.23638% with 185 lines in your changes missing coverage. Please review.
✅ Project coverage is 89.01%. Comparing base (7ff6adb) to head (07d7061).
⚠️ Report is 4 commits behind head on main.

Files with missing lines Patch % Lines
src/mofaflex/_core/terms/mofaflex.py 90.95% 50 Missing ⚠️
src/mofaflex/_core/terms/base.py 79.41% 21 Missing ⚠️
src/mofaflex/_core/model.py 93.33% 16 Missing ⚠️
src/mofaflex/_core/mofaflex.py 90.58% 16 Missing ⚠️
.../_core/priors/gaussian_process/gaussian_process.py 87.27% 14 Missing ⚠️
src/mofaflex/_core/utils.py 93.42% 10 Missing ⚠️
src/mofaflex/pl/_plotting.py 83.01% 9 Missing ⚠️
src/mofaflex/_core/priors/spike_slab.py 91.39% 8 Missing ⚠️
src/mofaflex/_core/api/_generate.py 88.88% 6 Missing ⚠️
src/mofaflex/_core/terms/api.py 73.68% 5 Missing ⚠️
... and 12 more
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #163      +/-   ##
==========================================
- Coverage   90.05%   89.01%   -1.05%     
==========================================
  Files          53       53              
  Lines        4758     5080     +322     
==========================================
+ Hits         4285     4522     +237     
- Misses        473      558      +85     
Files with missing lines Coverage Δ
src/mofaflex/__init__.py 100.00% <100.00%> (ø)
src/mofaflex/_core/__init__.py 100.00% <100.00%> (ø)
src/mofaflex/_core/api/__init__.py 100.00% <100.00%> (ø)
src/mofaflex/_core/api/likelihoods.py 100.00% <100.00%> (ø)
src/mofaflex/_core/api/priors.py 100.00% <100.00%> (+10.00%) ⬆️
src/mofaflex/_core/datasets/__init__.py 100.00% <100.00%> (ø)
src/mofaflex/_core/datasets/utils.py 81.60% <100.00%> (+0.65%) ⬆️
src/mofaflex/_core/dist.py 100.00% <ø> (ø)
src/mofaflex/_core/likelihoods/__init__.py 100.00% <100.00%> (ø)
src/mofaflex/_core/likelihoods/bernoulli.py 97.29% <100.00%> (+1.29%) ⬆️
... and 31 more

... and 2 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

sphinx-tabs is not yet compatible with Sphinx 9, see
executablebooks/sphinx-tabs#209
for view_name, view in group.items():
if view.numel() == 0: # can occur in the last batch of an epoch if the batch is small
continue
prediction = None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Think this might be simplified a bit, could not wrap my head around the multi-nested structure.

prediction = 0
has_prediction = False

for term in self._terms.values():
    try:
        term_prediction = term[group_name][view_name]
    except KeyError:
        continue

    prediction += term_prediction
    has_prediction = True

if has_prediction:
    ...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm sorry, I don't really see how that's different from the existing code, except you're using an addtional boolean instead of the None sentinel. Starting with prediction=0 will result in an additional element-wise addition, which may be expensive.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just assumed the 0 instead of None would be safer wrt arithmetic errors.


return lr_func

def on_train_start(self, data: MofaFlexDataset):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if performing some sort of validation / error handling here would save some time for errors being raised later or being propagated during training. Maybe something like:

def _validate_terms(self, data):
    for term in self._terms.values():
        term.validate(data, self._likelihoods)

before on_train_start and then call self._validate_terms(data) inside on_train_start. Not sure what could go wrong but I remember the issue we had when we do all the processing which takes a bit of time but had issues with the model config during training (now term config).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Validation of arguments should be performed in the constructor as much as possible. But I don't think it makes sense to implement validation ourselves, it's probably better to use something like Pydantic for this. But that is out of scope for this PR IMHO.

@ilia-kats ilia-kats requested a review from arberqoku January 19, 2026 08:34
@arberqoku arberqoku merged commit 46e77c3 into main Jan 19, 2026
9 checks passed
@ilia-kats ilia-kats deleted the additive_terms branch January 19, 2026 09:06
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants