Skip to content
Merged
Show file tree
Hide file tree
Changes from 59 commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
85437f6
break up the model into an additive term and the overall model
ilia-kats Dec 12, 2025
31e7a2f
make the mofaflex term responsible for its setup, also refactor priors
ilia-kats Dec 15, 2025
1b05d9b
minor simplifications and fixes
ilia-kats Dec 15, 2025
fb892c3
remove preprocessing pipeline, transform predictions instead
ilia-kats Dec 16, 2025
88ab2be
make R2 calculation work again
ilia-kats Dec 16, 2025
a34106b
minor fixes
ilia-kats Dec 17, 2025
6772b9b
terms/mofaflex: do all the initialization immediately prior to training
ilia-kats Dec 17, 2025
9555046
implement save/load for the entire model hierarchy
ilia-kats Dec 17, 2025
1a678ea
move likelihood initialization to the model
ilia-kats Dec 18, 2025
33ff948
properly handle likelihoods for nonnegative views
ilia-kats Dec 18, 2025
a104acb
simplify all the base classes
ilia-kats Dec 18, 2025
ea53634
generate API wrappers for terms and likelihoods
ilia-kats Dec 18, 2025
47eda11
remove dead code and reorganize
ilia-kats Dec 18, 2025
ea33955
minor simplification
ilia-kats Dec 18, 2025
cfa856f
new API for creating and training models
ilia-kats Dec 19, 2025
882f652
move the dynamic prior API to the mofaflex term
ilia-kats Dec 19, 2025
1e226e5
allow specifying a likelihood for multiple views
ilia-kats Dec 19, 2025
c3cbb00
type hint and API fixes
ilia-kats Dec 19, 2025
004ad26
Likelihoods: make known_likelihoods() a normal class method
ilia-kats Jan 7, 2026
bf5332b
add wrapper class for terms that only exposes the public API
ilia-kats Jan 7, 2026
398a77e
MofaFlex term: fix dynamic API for multiple priors of the same type
ilia-kats Jan 7, 2026
07211e9
GP prior: make warping a boolean argument
ilia-kats Jan 7, 2026
e3ccc71
MOFAFLEX: forward getattr to the term if there is only one term
ilia-kats Jan 7, 2026
554db74
generate API docs for priors and terms
ilia-kats Jan 7, 2026
0f8b345
simplify
ilia-kats Jan 8, 2026
0ef577b
datasets: implement filtering by groups/views in get_covariats
ilia-kats Jan 8, 2026
8791652
make priors completely agnostic to their axis
ilia-kats Jan 8, 2026
5f0e407
simplify spike-and-slab prior
ilia-kats Jan 8, 2026
d323d5d
minor simplification
ilia-kats Jan 8, 2026
5c3e784
add docstrings and generate docs for likelihoods
ilia-kats Jan 8, 2026
bcc6c01
make preprocessing and likelihood tests work again
ilia-kats Jan 9, 2026
47efeee
reorganize likelihood tests, also test Bernoulli and NegativeBinomial
ilia-kats Jan 9, 2026
b86bb4f
fix R2 calculation and integration tests
ilia-kats Jan 9, 2026
951580a
fix negative binomial likelihood
ilia-kats Jan 13, 2026
ceb07f7
make imputation work again
ilia-kats Jan 13, 2026
abde94b
fix saving
ilia-kats Jan 13, 2026
51b7384
fix loading
ilia-kats Jan 13, 2026
dd37a7d
fix R2 saving and API
ilia-kats Jan 13, 2026
192d3bb
fix Normal likelihood
ilia-kats Jan 13, 2026
7a7169c
fold save_load test into integration test and fix Term dynamic API for
ilia-kats Jan 14, 2026
2264874
move 0 R2 warning from likelihood to model
ilia-kats Jan 14, 2026
57bd366
fix tests for prior and pcgse
ilia-kats Jan 14, 2026
fa82ba3
make remaining tests pass
ilia-kats Jan 14, 2026
06ae1f0
update "mofaflex for mofa users" tutorial
ilia-kats Jan 14, 2026
67b4cee
fix R2 for non-Gaussian likelihoods
ilia-kats Jan 14, 2026
689ffc4
finalize plotting API + docstrings
ilia-kats Jan 15, 2026
7aa9b62
Normal likelihood: initialize variational dispersion to data scale
ilia-kats Jan 15, 2026
e334945
fix InformedHorseshoe and re-run Kang tutorial
ilia-kats Jan 15, 2026
88e5513
fix informed horseshoe prior when only a subset of its views have
ilia-kats Jan 16, 2026
b1b628a
update test plots
ilia-kats Jan 16, 2026
fbb91ba
fix negative binomial likelihood
ilia-kats Jan 16, 2026
eceffc6
fix MofaFlex getters with ordered=True and re-run Xenium tutorial
ilia-kats Jan 16, 2026
274aa92
GP: don't use buffers anymore
ilia-kats Jan 16, 2026
bc8af3e
some likelihood fixes for datasets with a single obs
ilia-kats Jan 16, 2026
1dbd4dc
cleanup: only define plate dimensions in the main model
ilia-kats Jan 16, 2026
0f67321
raise informative exceptions if the user tries to invoke methods on an
ilia-kats Jan 16, 2026
e0a9b55
make get_dispersion work again
ilia-kats Jan 16, 2026
ae6c837
add missing docstring
ilia-kats Jan 16, 2026
a613e04
bound Sphinx to <9
ilia-kats Jan 16, 2026
0e0c401
more detailed docstrings
ilia-kats Jan 19, 2026
07d7061
add changelog
ilia-kats Jan 19, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/_templates/autosummary/class.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ Attributes

{% for item in attributes %}

.. autoattribute:: {{ [objname, item] | join(".") }}
.. autoattribute:: {{ [fullname, item] | join(".") }}
{%- endfor %}

{% endif %}
Expand All @@ -52,7 +52,7 @@ Methods

{% for item in methods %}

.. automethod:: {{ [objname, item] | join(".") }}
.. automethod:: {{ [fullname, item] | join(".") }}
{%- endfor %}

{% endif %}
Expand Down
22 changes: 18 additions & 4 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,20 @@
.. autosummary::
:toctree: generated

DataOptions
ModelOptions
TrainingOptions
SmoothOptions
MOFAFLEX
FeatureSet
FeatureSets
```

### Terms
```{eval-rst}
.. autosummary::
:toctree: generated
:recursive:

terms
```

### Priors
```{eval-rst}
.. autosummary::
Expand All @@ -28,6 +33,15 @@
priors
```

### Likelihoods
```{eval-rst}
.. autosummary::
:toctree: generated
:recursive:

likelihoods
```

### Settings

An instance of the [](#_core.settings.Settings) class is available as `mofaflex.settings` and allows configuring MOFA-FLEX.
Expand Down
5 changes: 4 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,13 @@

# -- Path setup --------------------------------------------------------------
import sys
import os
from datetime import datetime
from importlib.metadata import metadata
from pathlib import Path

os.environ["MOFAFLEX_DOCS"] = "1"

HERE = Path(__file__).parent
sys.path.insert(0, str(HERE / "extensions"))

Expand Down Expand Up @@ -90,7 +93,7 @@
"numpy": ("https://numpy.org/doc/stable/", None),
"pandas": ("https://pandas.pydata.org/docs/", None),
"plotnine": ("https://plotnine.org/", None),
"pytorch": ("https://pytorch.org/docs/stable/", None),
"pytorch": ("https://docs.pytorch.org/docs/stable/", None),
"muon-tutorials": ("https://muon-tutorials.readthedocs.io/en/latest", None),
}

Expand Down
152 changes: 82 additions & 70 deletions docs/notebooks/kang_analysis.ipynb

Large diffs are not rendered by default.

137 changes: 67 additions & 70 deletions docs/notebooks/mofaflex_for_mofa_users_cll.ipynb

Large diffs are not rendered by default.

244 changes: 134 additions & 110 deletions docs/notebooks/mouse_citeseq_informed.ipynb

Large diffs are not rendered by default.

115 changes: 59 additions & 56 deletions docs/notebooks/xenium_chromium_informed.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ optional-dependencies.doc = [
"myst-nb>=1.1",
"pandas",
"setuptools", # Until pybtex >0.24.0 releases: https://bitbucket.org/pybtex-devs/pybtex/issues/169/
"sphinx>=8.1",
"sphinx>=8.1,<9", # https://github.com/executablebooks/sphinx-tabs/issues/209
"sphinx-autodoc-typehints",
"sphinx-automodapi",
"sphinx-book-theme>=1",
Expand Down
4 changes: 2 additions & 2 deletions src/mofaflex/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import logging

from . import pl, tl
from ._core import MOFAFLEX, DataOptions, FeatureSet, FeatureSets, ModelOptions, TrainingOptions, settings
from ._core.api import priors
from ._core import MOFAFLEX, FeatureSet, FeatureSets, settings
from ._core.api import likelihoods, priors, terms
from ._version import __version__, __version_tuple__

_logger = logging.getLogger(__name__)
Expand Down
2 changes: 1 addition & 1 deletion src/mofaflex/_core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .datasets import MofaFlexDataset
from .feature_sets import FeatureSet, FeatureSets
from .mofaflex import MOFAFLEX, DataOptions, ModelOptions, TrainingOptions
from .mofaflex import MOFAFLEX
from .pcgse import pcgse_test
from .settings import settings
2 changes: 1 addition & 1 deletion src/mofaflex/_core/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from . import priors
from . import likelihoods, priors, terms
77 changes: 77 additions & 0 deletions src/mofaflex/_core/api/_generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import sys
from abc import ABC, abstractmethod
from collections.abc import Mapping
from inspect import Parameter, signature


class APIWrapper(ABC):
@abstractmethod
def __call__(self, axis, names):
pass

def __init__(self, *args, **kwargs):
self._args = args
self._kwargs = kwargs

def __eq__(self, other):
if not isinstance(other, __class__):
return NotImplemented
elif other.__class__ != self.__class__:
return False
else:
return self._args == other._args and self._kwargs == other._kwargs

def __hash__(self):
return hash((self.__class__, self._args, tuple(sorted(self._kwargs.items()))))


def init_api(module: str, basecls: type, subclss: Mapping[str, type]):
mod = sys.modules[module]
coreinit = basecls.__dict__.get("__init__", None)
if coreinit is not None:
coresig = signature(coreinit)

all_ = []

basewrapper = type(basecls.__name__, (APIWrapper,), {"__module__": module})

for subname, subcls in subclss.items():
sig = signature(subcls.__init__)
annots = subcls.__init__.__annotations__
if coreinit is not None:
params = [
param
for i, param in enumerate(sig.parameters.values())
if i == 0 or param.name not in coresig.parameters
]
sig = sig.replace(parameters=params)
annots = {param.name: param.annotation for param in params if param.annotation is not Parameter.empty}

def init(self, *args, **kwargs):
self.__init__.__signature__.bind(self, *args, **kwargs) # check for argument compatibility
super(self.__class__, self).__init__(*args, **kwargs)

def call(self, *args, **kwargs):
return self._cls(*args, *self._args, **kwargs, **self._kwargs)

if coreinit is not None:
call.__signature__ = coresig
call.__annotations__ = coreinit.__annotations__

init.__signature__ = sig
init.__annotations__ = annots
init.__name__ = "__init__"
init.__qualname__ = f"{subname}.__init__"
call.__name__ = "__call__"
call.__qualname__ = f"{subname}.__call__"
apicls = type(
subname, (basewrapper,), {"_cls": subcls, "__init__": init, "__call__": call, "__module__": module}
)
apicls.__doc__ = subcls.__doc__

setattr(mod, subname, apicls)
all_.append(subname)

setattr(mod, basewrapper.__name__, basewrapper)
mod.__all__ = all_
mod.__dir__ = lambda: all_
4 changes: 4 additions & 0 deletions src/mofaflex/_core/api/likelihoods.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from ..likelihoods import Likelihood
from ._generate import init_api

init_api(__name__, Likelihood, Likelihood.known_likelihoods())
78 changes: 3 additions & 75 deletions src/mofaflex/_core/api/priors.py
Original file line number Diff line number Diff line change
@@ -1,76 +1,4 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from inspect import Parameter, signature
from types import MappingProxyType
from typing import Literal
from ..priors import Prior
from ._generate import init_api

from ..priors import Prior as PriorCore


class Prior(ABC):
@abstractmethod
def __call__(self, axis, names):
pass

def __init__(self, *args, **kwargs):
self._args = args
self._kwargs = MappingProxyType(kwargs)

def __eq__(self, other):
if not isinstance(other, __class__):
return NotImplemented
elif other.__class__ != self.__class__:
return False
else:
return self._args == other._args and self._kwargs == other._kwargs

def __hash__(self):
return hash((self.__class__, self._args, tuple(sorted(self._kwargs.items()))))


__all__ = []


def _init_priors():
for priorname in PriorCore.known_priors():
priorcls = PriorCore.class_(priorname)
sig = signature(priorcls.__init__)
params = [param for param in sig.parameters.values() if param.name not in ("axis", "names")]
sig = sig.replace(parameters=params)

def init(self, *args, **kwargs):
self.__init__.__signature__.bind(self, *args, **kwargs) # check for argument compatibility
super(self.__class__, self).__init__(*args, **kwargs)

if priorcls is not PriorCore:

def call(self, axis: Literal[0, 1, "samples", "features"], names: str | Sequence[str]):
return self._cls(axis, names, *self._args, **self._kwargs)
else:

def call(self, axis: Literal[0, 1, "samples", "features"], names: str | Sequence[str]):
return PriorCore(self.__class__.__name__, axis, names, *self._args, **self._kwargs)

init.__signature__ = sig
init.__annotations__ = {
param.name: param.annotation for param in params if param.annotation is not Parameter.empty
}
init.__name__ = "__init__"
init.__qualname__ = f"{priorname}.__init__"
call.__name__ = "__call__"
call.__qualname__ = f"{priorname}.__call__"
apicls = type(
priorname, (Prior,), {"_cls": priorcls, "__init__": init, "__call__": call, "__module__": __name__}
)
if priorcls is not PriorCore:
apicls.__doc__ = priorcls.__doc__

globals()[priorname] = apicls
__all__.append(priorname)


_init_priors()


def __dir__():
return __all__
init_api(__name__, Prior, Prior.known_priors())
47 changes: 47 additions & 0 deletions src/mofaflex/_core/api/terms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from inspect import Parameter, Signature, signature
from typing import TYPE_CHECKING

from ..terms import Term
from ..utils import building_docs

__all__ = []

if TYPE_CHECKING:
pass


def _init_api():
def make_wrapper(term: Term): # required due to Python's late-binding closures
def wrapper(name="_", /, **kwargs):
from ..mofaflex import MOFAFLEX

return MOFAFLEX(**{name: term(**kwargs)})

return wrapper

for termname, term in Term.known_terms().items():
if not building_docs():
wrapper = make_wrapper(term)
sig = signature(term.__init__)
params = [signature(wrapper).parameters["name"]] + [
Parameter(param.name, Parameter.KEYWORD_ONLY, default=param.default, annotation=param.annotation)
for param in sig.parameters.values()
]
wrapper.__signature__ = Signature(params)
wrapper.__annotations__ = term.__init__.__annotations__ | {"name": str, "return": "MOFAFLEX"}
wrapper.__doc__ = term.__doc__
else:
wrapper = type(termname, (), {"__module__": __name__, "__doc__": term.__doc__})
wrapper.__init__ = term.__init__
for api in term.api():
setattr(wrapper, api, getattr(term, api))

globals()[termname] = wrapper
__all__.append(termname)


def __dir__():
return __all__


_init_api()
10 changes: 10 additions & 0 deletions src/mofaflex/_core/api/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from ..terms import TermWrapper
from ..utils import building_docs
from . import terms

if building_docs():
for term in dir(terms):
globals()[term] = getattr(terms, term)
else:
for term in dir(terms):
globals()[term] = TermWrapper
2 changes: 1 addition & 1 deletion src/mofaflex/_core/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .anndatadictdataset import AnnDataDictDataset
from .base import MofaFlexDataset, Preprocessor
from .misc import CovariatesDataset, GuidingVarsDataset, MofaFlexBatchSampler, StackDataset
from .misc import CovariatesDataset, MofaFlexBatchSampler, StackDataset, merge_covariates
from .mudatadataset import MuDataDataset
Loading
Loading