diff --git a/contextualized/regression/__init__.py b/contextualized/regression/__init__.py
index e626eaa6..e2a0d26d 100644
--- a/contextualized/regression/__init__.py
+++ b/contextualized/regression/__init__.py
@@ -1,14 +1,13 @@
 """
 Contextualized Regression models.
 """
-from contextualized.regression.datasets import (
-    MultivariateDataset,
-    UnivariateDataset,
-    MultitaskMultivariateDataset,
-    MultitaskUnivariateDataset,
-)
+from contextualized.regression.datasets import DATASETS
 from contextualized.regression.losses import MSE, BCELoss
 from contextualized.regression.regularizers import REGULARIZERS
+from contextualized.regression.trainers import RegressionTrainer, TRAINERS
+from contextualized.regression.losses import LOSSES
+from contextualized.regression.metamodels import METAMODELS
+
 from contextualized.regression.lightning_modules import (
     NaiveContextualizedRegression,
     ContextualizedRegression,
@@ -16,16 +15,12 @@
     TasksplitContextualizedRegression,
     ContextualizedUnivariateRegression,
     TasksplitContextualizedUnivariateRegression,
+    MODELS,
 )
-from contextualized.regression.trainers import RegressionTrainer
 
-DATASETS = {
-    "multivariate": MultivariateDataset,
-    "univariate": UnivariateDataset,
-    "multitask_multivariate": MultitaskMultivariateDataset,
-    "multitask_univariate": MultitaskUnivariateDataset,
-}
-LOSSES = {"mse": MSE, "bceloss": BCELoss}
-MODELS = ["multivariate", "univariate"]
-METAMODELS = ["simple", "subtype", "multitask", "tasksplit"]
-TRAINERS = {"regression_trainer": RegressionTrainer}
+from contextualized.regression.datasets import (
+    MultivariateDataset,
+    UnivariateDataset,
+    MultitaskMultivariateDataset,
+    MultitaskUnivariateDataset,
+)
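With the registries defined next to the classes they index, callers can resolve components by name instead of importing each class. A minimal sketch of the lookup this enables (synthetic shapes; assumes the re-exports in `__init__.py` above):

```python
import numpy as np
from contextualized.regression import DATASETS, LOSSES, MODELS, TRAINERS

# toy data: 100 samples, 4 context features, 3 predictors, 2 outcomes
C, X, Y = np.random.randn(100, 4), np.random.randn(100, 3), np.random.randn(100, 2)

dataset_cls = DATASETS["multivariate"]        # -> MultivariateDataset
loss_fn = LOSSES["mse"]                       # -> MSE
model_cls = MODELS["multivariate"]            # -> ContextualizedRegression
trainer_cls = TRAINERS["regression_trainer"]  # -> RegressionTrainer
```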
diff --git a/contextualized/regression/datamodules.py b/contextualized/regression/datamodules.py
new file mode 100644
index 00000000..d518ec49
--- /dev/null
+++ b/contextualized/regression/datamodules.py
@@ -0,0 +1,152 @@
+"""
+PyTorch Lightning DataModules used to train contextualized.regression modules.
+"""
+import numpy as np
+import torch
+import pytorch_lightning as pl
+from torch.utils.data import DataLoader
+
+from contextualized.regression.datasets import DATASETS
+
+
+class RegressionDataModule(pl.LightningDataModule):
+    """
+    Torch DataModule used for contextualized.regression modules
+    """
+
+    def __init__(
+        self,
+        c,
+        x,
+        y,
+        dataset="multivariate",
+        num_workers=0,
+        batch_size=32,
+        correlation=False,
+        markov=False,
+        pct_test=0.2,
+        pct_val=0.2,
+        **kwargs,
+    ):
+        """Initialize the Regression DataModule
+
+        Args:
+            c (ndarray): 2D array containing contextual features for each sample.
+            x (ndarray): 2D array containing predictors for each sample.
+            y (ndarray): 2D array containing outcomes for each sample.
+            dataset (str): Which dataset to use. Choose from ["multivariate", "univariate", "multitask_multivariate", "multitask_univariate"].
+            num_workers (int): Number of workers used in dataloaders.
+            batch_size (int): Size of batches used in dataloaders.
+            correlation (bool): Whether the datamodule will be used for a correlation regression module (replaces y with x).
+            markov (bool): Whether the datamodule will be used for a markov regression module (also replaces y with x).
+            pct_test (float): Pct of the full dataset to be used as the test dataset.
+            pct_val (float): Pct of the remaining (non-test) data to be used as the val dataset.
+        """
+        super().__init__()
+
+        self.dataset = DATASETS[dataset]
+        self.num_workers = num_workers
+        self.batch_size = batch_size
+
+        self.C = torch.tensor(c)
+        self.X = torch.tensor(x)
+        self.Y = torch.tensor(y)
+
+        self.n_samples = self.C.shape[0]
+
+        if correlation or markov:
+            self.Y = self.X
+
+        # partition data
+        train_idx, test_idx, val_idx = self._create_idx(
+            pct_test=pct_test, pct_val=pct_val
+        )
+
+        self.full_dataset = self.dataset(self.C, self.X, self.Y)
+        self.train_dataset = self.dataset(
+            self.C[train_idx], self.X[train_idx], self.Y[train_idx]
+        )
+        self.test_dataset = self.dataset(
+            self.C[test_idx], self.X[test_idx], self.Y[test_idx]
+        )
+        self.val_dataset = self.dataset(
+            self.C[val_idx], self.X[val_idx], self.Y[val_idx]
+        )
+        self.pred_dataset = self.test_dataset  # default to test dataset
+
+    def setup(self, stage: str, pred_dl_type=None):
+        # Assign full/train/test/val datasets for use in dataloaders
+        if stage == "predict":
+            pred_dl_to_dataset = {
+                "full": self.full_dataset,
+                "train": self.train_dataset,
+                "test": self.test_dataset,
+                "val": self.val_dataset,
+            }
+
+            assert pred_dl_type in [None] + list(
+                pred_dl_to_dataset.keys()
+            ), "Invalid dataset type for the predict dataloader. Choose from 'full', 'train', 'test', 'val', or None."
+
+            if pred_dl_type:
+                self.pred_dataset = pred_dl_to_dataset[pred_dl_type]
+
+    def full_dataloader(self):
+        return DataLoader(
+            self.full_dataset,
+            batch_size=self.batch_size,
+            shuffle=True,
+            num_workers=self.num_workers,
+        )
+
+    def train_dataloader(self):
+        return DataLoader(
+            self.train_dataset,
+            batch_size=self.batch_size,
+            shuffle=True,
+            num_workers=self.num_workers,
+        )
+
+    def test_dataloader(self):
+        return DataLoader(
+            self.test_dataset,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=self.num_workers,
+        )
+
+    def val_dataloader(self):
+        return DataLoader(
+            self.val_dataset,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=self.num_workers,
+        )
+
+    def predict_dataloader(self):
+        return DataLoader(
+            self.pred_dataset,
+            batch_size=self.batch_size,
+            shuffle=False,
+            num_workers=self.num_workers,
+        )
+
+    def _create_idx(self, pct_test=0.2, pct_val=0.2):
+        # create disjoint test/val/train indices
+        test_idx = np.random.choice(
+            range(self.n_samples), int(pct_test * self.n_samples), replace=False
+        )
+        non_test_idx = list(set(range(self.n_samples)) - set(test_idx))
+
+        val_idx = np.random.choice(
+            non_test_idx, int(pct_val * len(non_test_idx)), replace=False
+        )
+        train_idx = list(set(non_test_idx) - set(val_idx))
+        np.random.shuffle(train_idx)
+
+        return train_idx, test_idx, val_idx
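A short usage sketch for the new `RegressionDataModule` (synthetic data; the `pred_dl_type` values follow the `setup` method above):

```python
import numpy as np
from contextualized.regression.datamodules import RegressionDataModule

C = np.random.randn(100, 4).astype(np.float32)  # contexts
X = np.random.randn(100, 3).astype(np.float32)  # predictors
Y = np.random.randn(100, 2).astype(np.float32)  # outcomes

dm = RegressionDataModule(C, X, Y, dataset="multivariate", batch_size=16)
dm.setup("predict", pred_dl_type="val")  # point predict_dataloader() at the val split

for c, x, y, idx in dm.predict_dataloader():
    # MultivariateDataset items collate to (C, X, Y, sample index) batches
    print(c.shape, x.shape, y.shape, idx.shape)
    break
```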
""" from abc import abstractmethod import torch -from torch.utils.data import IterableDataset -class Dataset: - """Superclass for datastreams (iterators) used to train contextualized.regression models""" +class RegressionDatasetBase: + """ + Superclass for map-based datasets used to train contextualized.regression models + """ def __init__(self, C, X, Y, dtype=torch.float): + + self.dtype = dtype self.C = torch.tensor(C, dtype=dtype) self.X = torch.tensor(X, dtype=dtype) self.Y = torch.tensor(Y, dtype=dtype) - self.n_i = 0 - self.x_i = 0 - self.y_i = 0 - self.n = C.shape[0] + self.c_dim = C.shape[-1] self.x_dim = X.shape[-1] self.y_dim = Y.shape[-1] - self.dtype = dtype + self.n = self.C.shape[0] + + assert len(set([self.C.shape[0], self.X.shape[0], self.Y.shape[0]])) == 1 - def __iter__(self): self.n_i = 0 - self.x_i = 0 self.y_i = 0 - return self + self.x_i = 0 + self.sample_ids = [] + + for i in range(len(self)): + self.sample_ids.append(self._get_next_id(i)) + + @abstractmethod + def _total_len(self): + pass @abstractmethod - def __next__(self): + def _get_next_id(self): pass @abstractmethod def __len__(self): pass + @abstractmethod + def __getitem__(self): + pass -class MultivariateDataset(Dataset): + +class UnivariateDataset(RegressionDatasetBase): """ - Simple multivariate dataset with context, predictors, and outcomes. + Simple univariate dataset with context, predictors, and one outcome. """ - def __next__(self): + def _get_next_id(self, i): if self.n_i >= self.n: self.n_i = 0 raise StopIteration - ret = ( - self.C[self.n_i], - self.X[self.n_i].expand(self.y_dim, -1), - self.Y[self.n_i].unsqueeze(-1), - self.n_i, - ) + ret = self.n_i self.n_i += 1 return ret def __len__(self): return self.n + def __getitem__(self, idx): + n_i = self.sample_ids[idx] + ret = ( + self.C[n_i], + self.X[n_i].expand(self.y_dim, -1).unsqueeze(-1), + self.Y[n_i].expand(self.x_dim, -1).T.unsqueeze(-1), + n_i, + ) + return ret + -class UnivariateDataset(Dataset): +class MultivariateDataset(RegressionDatasetBase): """ - Simple univariate dataset with context, predictors, and one outcome. + Simple multivariate dataset with context, predictors, and outcomes. """ - def __next__(self): + def _get_next_id(self, i): if self.n_i >= self.n: self.n_i = 0 raise StopIteration - ret = ( - self.C[self.n_i], - self.X[self.n_i].expand(self.y_dim, -1).unsqueeze(-1), - self.Y[self.n_i].expand(self.x_dim, -1).T.unsqueeze(-1), - self.n_i, - ) + ret = self.n_i self.n_i += 1 return ret def __len__(self): return self.n + def __getitem__(self, idx): + # one vs multiple + n_i = self.sample_ids[idx] + ret = ( + self.C[n_i], + self.X[n_i].expand(self.y_dim, -1), + self.Y[n_i].unsqueeze(-1), + n_i, + ) + return ret + -class MultitaskMultivariateDataset(Dataset): +class MultitaskUnivariateDataset(RegressionDatasetBase): """ - Multi-task Multivariate Dataset. 
+ Multitask Univariate Dataset """ - def __next__(self): + def _get_next_id(self, i): if self.y_i >= self.y_dim: - self.n_i += 1 + self.x_i += 1 self.y_i = 0 + if self.x_i >= self.x_dim: + self.n_i += 1 + self.x_i = 0 if self.n_i >= self.n: self.n_i = 0 raise StopIteration - t = torch.zeros(self.y_dim) - t[self.y_i] = 1 + + t = [0] * (self.x_dim + self.y_dim) + t[self.x_i] = 1 + t[self.x_dim + self.y_i] = 1 + ret = ( - self.C[self.n_i], t, - self.X[self.n_i], - self.Y[self.n_i, self.y_i].unsqueeze(0), self.n_i, + self.x_i, self.y_i, ) + self.y_i += 1 return ret def __len__(self): - return self.n * self.y_dim + return self.n * self.y_dim * self.x_dim + + def __getitem__(self, idx): + t, n_i, x_i, y_i = self.sample_ids[idx] + t = torch.zeros(self.x_dim + self.y_dim) + ret = ( + self.C[n_i], + torch.tensor(t, dtype=torch.float), + self.X[n_i, x_i].unsqueeze(0), + self.Y[n_i, y_i].unsqueeze(0), + n_i, + x_i, + y_i, + ) + return ret -class MultitaskUnivariateDataset(Dataset): +class MultitaskMultivariateDataset(RegressionDatasetBase): """ - Multitask Univariate Dataset + Multi-task Multivariate Dataset. """ - def __next__(self): + def _get_next_id(self, i): if self.y_i >= self.y_dim: - self.x_i += 1 - self.y_i = 0 - if self.x_i >= self.x_dim: self.n_i += 1 - self.x_i = 0 + self.y_i = 0 if self.n_i >= self.n: self.n_i = 0 raise StopIteration - t = torch.zeros(self.x_dim + self.y_dim) - t[self.x_i] = 1 - t[self.x_dim + self.y_i] = 1 - ret = ( - self.C[self.n_i], - t, - self.X[self.n_i, self.x_i].unsqueeze(0), - self.Y[self.n_i, self.y_i].unsqueeze(0), - self.n_i, - self.x_i, - self.y_i, - ) + t = [0] * (self.y_dim) + t[self.y_i] = 1 + + ret = (t, self.n_i, self.y_i) self.y_i += 1 return ret def __len__(self): - return self.n * self.x_dim * self.y_dim - + return self.n * self.y_dim -class DataIterable(IterableDataset): - """Dataset wrapper, required by PyTorch""" + def __getitem__(self, idx): + t, n_i, y_i = self.sample_ids[idx] + ret = ( + self.C[n_i], + torch.tensor(t, dtype=torch.float), + self.X[n_i], + self.Y[n_i, y_i].unsqueeze(0), + n_i, + y_i, + ) + return ret - def __init__(self, dataset): - self.dataset = dataset - def __iter__(self): - return iter(self.dataset) +DATASETS = { + "multivariate": MultivariateDataset, + "univariate": UnivariateDataset, + "multitask_multivariate": MultitaskMultivariateDataset, + "multitask_univariate": MultitaskUnivariateDataset, +} diff --git a/contextualized/regression/lightning_modules.py b/contextualized/regression/lightning_modules.py index 4aa420b2..38906850 100644 --- a/contextualized/regression/lightning_modules.py +++ b/contextualized/regression/lightning_modules.py @@ -11,13 +11,13 @@ Implemented with PyTorch Lightning """ - from abc import abstractmethod import numpy as np import torch from torch.utils.data import DataLoader import pytorch_lightning as pl +from contextualized.regression.datamodules import RegressionDataModule from contextualized.regression.regularizers import REGULARIZERS from contextualized.regression.losses import MSE from contextualized.functions import LINK_FUNCTIONS @@ -28,8 +28,8 @@ MultitaskMetamodel, TasksplitMetamodel, ) + from contextualized.regression.datasets import ( - DataIterable, MultivariateDataset, UnivariateDataset, MultitaskMultivariateDataset, @@ -72,6 +72,18 @@ def _build_metamodel(self, *args, **kwargs): """ # builds the metamodel + @abstractmethod + def datamodule(self, C, X, Y, batch_size=32): + """ + + :param C: + :param X: + :param Y: + :param batch_size: (Default value = 32) + + """ + # returns the 
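Because the datasets are now map-style, items can be indexed directly; a quick sketch (synthetic data) of the per-item layout implied by the `__getitem__` implementations above:

```python
import numpy as np
from contextualized.regression.datasets import DATASETS

n, c_dim, x_dim, y_dim = 10, 4, 3, 2
C = np.random.randn(n, c_dim)
X = np.random.randn(n, x_dim)
Y = np.random.randn(n, y_dim)

mv = DATASETS["multivariate"](C, X, Y)
c_i, x_i, y_i, idx = mv[0]
assert x_i.shape == (y_dim, x_dim)  # predictors tiled once per outcome
assert y_i.shape == (y_dim, 1)

mtu = DATASETS["multitask_univariate"](C, X, Y)
assert len(mtu) == n * x_dim * y_dim        # one item per (sample, predictor, outcome)
c_0, t_0, x_0, y_0, n_0, x_0_i, y_0_i = mtu[0]
assert t_0.shape == (x_dim + y_dim,)        # task indicator marking the active pair
```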
diff --git a/contextualized/regression/lightning_modules.py b/contextualized/regression/lightning_modules.py
index 4aa420b2..38906850 100644
--- a/contextualized/regression/lightning_modules.py
+++ b/contextualized/regression/lightning_modules.py
@@ -11,13 +11,13 @@
 Implemented with PyTorch Lightning
 """
-
 from abc import abstractmethod
 
 import numpy as np
 import torch
 from torch.utils.data import DataLoader
 import pytorch_lightning as pl
 
+from contextualized.regression.datamodules import RegressionDataModule
 from contextualized.regression.regularizers import REGULARIZERS
 from contextualized.regression.losses import MSE
 from contextualized.functions import LINK_FUNCTIONS
@@ -28,8 +28,8 @@
     MultitaskMetamodel,
     TasksplitMetamodel,
 )
+
 from contextualized.regression.datasets import (
-    DataIterable,
     MultivariateDataset,
     UnivariateDataset,
     MultitaskMultivariateDataset,
@@ -72,6 +72,18 @@ def _build_metamodel(self, *args, **kwargs):
         """
         # builds the metamodel
 
+    @abstractmethod
+    def datamodule(self, C, X, Y, batch_size=32):
+        """
+
+        :param C:
+        :param X:
+        :param Y:
+        :param batch_size: (Default value = 32)
+
+        """
+        # returns the datamodule for this class
+
     @abstractmethod
     def dataloader(self, C, X, Y, batch_size=32):
         """
@@ -106,7 +118,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
         # returns predicted params on the given batch
 
     @abstractmethod
-    def _params_reshape(self, beta_preds, mu_preds, dataloader):
+    def _params_reshape(self, beta_preds, mu_preds, dataclass):
         """
 
         :param beta_preds:
@@ -117,7 +129,7 @@
         # reshapes the batch parameter predictions into beta (y_dim, x_dim)
 
     @abstractmethod
-    def _y_reshape(self, y_preds, dataloader):
+    def _y_reshape(self, y_preds, dataclass):
         """
 
         :param y_preds:
@@ -132,7 +144,9 @@ def forward(self, *args, **kwargs):
 
         :param *args:
 
         """
+
         beta, mu = self.metamodel(*args)
+
         if self.base_param_predictor is not None:
             base_beta, base_mu = self.base_param_predictor.predict_params(*args)
             beta = beta + base_beta
@@ -187,6 +201,7 @@ def _predict_from_models(self, X, beta_hat, mu_hat):
 
         :param mu_hat:
 
         """
+
         return self.link_fn((beta_hat * X).sum(axis=-1).unsqueeze(-1) + mu_hat)
 
     def _predict_y(self, C, X, beta_hat, mu_hat):
@@ -203,19 +218,60 @@
             Y = Y + self.base_y_predictor.predict_y(C, X)
         return Y
 
-    def _dataloader(self, C, X, Y, dataset_constructor, **kwargs):
+    def _dataclass_to_dataloader(self, dataclass, dataloader_type="pred"):
+        """
+        Take an ambiguous dataloader/datamodule and return a dataloader (type specifiable).
+
+        :param dataclass: the dataloader or datamodule to resolve to a dataloader
+        :param dataloader_type: if a datamodule is given, which of its dataloaders to use
+        """
+        if isinstance(dataclass, DataLoader):
+            return dataclass
+        if isinstance(dataclass, pl.LightningDataModule):
+            dl_type_to_dl = {
+                "full": lambda dm: dm.full_dataloader(),
+                "train": lambda dm: dm.train_dataloader(),
+                "test": lambda dm: dm.test_dataloader(),
+                "val": lambda dm: dm.val_dataloader(),
+                "pred": lambda dm: dm.predict_dataloader(),
+            }
+            return dl_type_to_dl[dataloader_type](dataclass)
+        raise ValueError(
+            f"Expected a DataLoader or LightningDataModule, got {type(dataclass)}."
+        )
+
+    def _datamodule(self, C, X, Y, dataset, **kwargs):
         """
 
         :param C:
         :param X:
         :param Y:
-        :param dataset_constructor:
+        :param dataset: name of the dataset type to build, e.g. "multivariate"
         :param **kwargs:
 
         """
         kwargs["num_workers"] = kwargs.get("num_workers", 0)
         kwargs["batch_size"] = kwargs.get("batch_size", 32)
-        return DataLoader(dataset=DataIterable(dataset_constructor(C, X, Y)), **kwargs)
+        kwargs["correlation"] = kwargs.get("correlation", False)
+        kwargs["markov"] = kwargs.get("markov", False)
+
+        return RegressionDataModule(C, X, Y, dataset=dataset, **kwargs)
+
+    def _dataloader(self, C, X, Y, dataset, **kwargs):
+        """
+
+        :param C:
+        :param X:
+        :param Y:
+        :param dataset: the dataset constructor to wrap
+        :param **kwargs:
+
+        """
+        kwargs["num_workers"] = kwargs.get("num_workers", 0)
+        kwargs["batch_size"] = kwargs.get("batch_size", 32)
+        return DataLoader(dataset=dataset(C, X, Y), **kwargs)
 
 
 class NaiveContextualizedRegression(ContextualizedRegressionBase):
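`_dataclass_to_dataloader` uses `isinstance` rather than exact `type` comparison so that any `LightningDataModule` subclass, including `RegressionDataModule`, is accepted; a two-assertion illustration:

```python
import pytorch_lightning as pl

class MyDataModule(pl.LightningDataModule):
    """A trivial subclass standing in for RegressionDataModule."""

dm = MyDataModule()
assert type(dm) is not pl.LightningDataModule  # an exact type check misses subclasses
assert isinstance(dm, pl.LightningDataModule)  # isinstance accepts them
```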
@@ -255,14 +311,17 @@ def predict_step(self, batch, batch_idx):
         beta_hat, mu_hat = self(C)
         return beta_hat, mu_hat
 
-    def _params_reshape(self, preds, dataloader):
+    def _params_reshape(self, preds, dataclass):
         """
 
         :param preds:
-        :param dataloader:
+        :param dataclass:
 
         """
-        ds = dataloader.dataset.dataset
+        dataloader = self._dataclass_to_dataloader(dataclass, dataloader_type="pred")
+
+        ds = dataloader.dataset
         betas = np.zeros((ds.n, ds.y_dim, ds.x_dim))
         mus = np.zeros((ds.n, ds.y_dim))
         for (beta_hats, mu_hats), data in zip(preds, dataloader):
@@ -271,21 +330,24 @@
             C, X, _, n_idx = data
             mus[n_idx] = mu_hats.squeeze(-1)
         return betas, mus
 
-    def _y_reshape(self, preds, dataloader):
+    def _y_reshape(self, preds, dataclass):
         """
 
         :param preds:
-        :param dataloader:
+        :param dataclass:
 
         """
-        ds = dataloader.dataset.dataset
+        dataloader = self._dataclass_to_dataloader(dataclass, dataloader_type="pred")
+
+        ds = dataloader.dataset
         ys = np.zeros((ds.n, ds.y_dim))
         for (beta_hats, mu_hats), data in zip(preds, dataloader):
             C, X, _, n_idx = data
             ys[n_idx] = self._predict_y(C, X, beta_hats, mu_hats).squeeze(-1)
         return ys
 
-    def dataloader(self, C, X, Y, **kwargs):
+    def datamodule(self, C, X, Y, **kwargs):
         """
 
         :param C:
@@ -293,6 +355,15 @@
         :param X:
         :param Y:
         :param **kwargs:
 
+        """
+        return self._datamodule(C, X, Y, "multivariate", **kwargs)
+
+    def dataloader(self, C, X, Y, **kwargs):
+        """
+
+        :param C:
+        :param X:
+        :param Y:
+        :param **kwargs:
+
         """
         return self._dataloader(C, X, Y, MultivariateDataset, **kwargs)
 
@@ -317,12 +388,8 @@ def _batch_loss(self, batch, batch_idx):
 
         :param batch:
         :param batch_idx:
 
         """
-        (
-            C,
-            X,
-            Y,
-            _,
-        ) = batch
+        C, X, Y, _ = batch
+
         beta_hat, mu_hat = self.predict_step(batch, batch_idx)
         pred_loss = self.loss_fn(Y, self._predict_y(C, X, beta_hat, mu_hat))
         reg_loss = self.model_regularizer(beta_hat, mu_hat)
@@ -339,14 +406,16 @@ def predict_step(self, batch, batch_idx):
         beta_hat, mu_hat = self(C)
         return beta_hat, mu_hat
 
-    def _params_reshape(self, preds, dataloader):
+    def _params_reshape(self, preds, dataclass):
         """
 
         :param preds:
-        :param dataloader:
+        :param dataclass:
 
         """
-        ds = dataloader.dataset.dataset
+        dataloader = self._dataclass_to_dataloader(dataclass, dataloader_type="pred")
+
+        ds = dataloader.dataset
         betas = np.zeros((ds.n, ds.y_dim, ds.x_dim))
         mus = np.zeros((ds.n, ds.y_dim))
         for (beta_hats, mu_hats), data in zip(preds, dataloader):
@@ -355,21 +424,22 @@
             C, X, _, n_idx = data
             mus[n_idx] = mu_hats.squeeze(-1)
         return betas, mus
 
-    def _y_reshape(self, preds, dataloader):
+    def _y_reshape(self, preds, dataclass):
         """
 
         :param preds:
-        :param dataloader:
+        :param dataclass:
 
         """
-        ds = dataloader.dataset.dataset
+        dataloader = self._dataclass_to_dataloader(dataclass, dataloader_type="pred")
+
+        ds = dataloader.dataset
         ys = np.zeros((ds.n, ds.y_dim))
         for (beta_hats, mu_hats), data in zip(preds, dataloader):
             C, X, _, n_idx = data
             ys[n_idx] = self._predict_y(C, X, beta_hats, mu_hats).squeeze(-1)
         return ys
 
-    def dataloader(self, C, X, Y, **kwargs):
+    def datamodule(self, C, X, Y, **kwargs):
         """
 
         :param C:
@@ -377,6 +447,15 @@
         :param X:
         :param Y:
         :param **kwargs:
 
+        """
+        return self._datamodule(C, X, Y, "multivariate", **kwargs)
+
+    def dataloader(self, C, X, Y, **kwargs):
+        """
+
+        :param C:
+        :param X:
+        :param Y:
+        :param **kwargs:
+
         """
         return self._dataloader(C, X, Y, MultivariateDataset, **kwargs)
 
@@ -418,14 +497,16 @@ def predict_step(self, batch, batch_idx):
         beta_hat, mu_hat = self(C, T)
         return beta_hat, mu_hat
 
-    def _params_reshape(self, preds, dataloader):
+    def _params_reshape(self, preds, dataclass):
         """
 
         :param preds:
-        :param dataloader:
+        :param dataclass:
 
         """
-        ds = dataloader.dataset.dataset
+        dataloader = self._dataclass_to_dataloader(dataclass, dataloader_type="pred")
+
+        ds = dataloader.dataset
         betas = np.zeros((ds.n, ds.y_dim, ds.x_dim))
         mus = np.zeros((ds.n, ds.y_dim))
         for (beta_hats, mu_hats), data in zip(preds, dataloader):
@@ -434,21 +515,24 @@
             C, _, X, _, n_idx, y_idx = data
             mus[n_idx, y_idx] = mu_hats.squeeze(-1)
         return betas, mus
 
-    def _y_reshape(self, preds, dataloader):
+    def _y_reshape(self, preds, dataclass):
         """
 
         :param preds:
-        :param dataloader:
+        :param dataclass:
 
         """
-        ds = dataloader.dataset.dataset
+        dataloader = self._dataclass_to_dataloader(dataclass, dataloader_type="pred")
+
+        ds = dataloader.dataset
         ys = np.zeros((ds.n, ds.y_dim))
         for (beta_hats, mu_hats), data in zip(preds, dataloader):
             C, _, X, _, n_idx, y_idx = data
             ys[n_idx, y_idx] = self._predict_y(C, X, beta_hats, mu_hats).squeeze(-1)
         return ys
 
-    def dataloader(self, C, X, Y, **kwargs):
+    def datamodule(self, C, X, Y, **kwargs):
         """
 
         :param C:
@@ -456,6 +540,15 @@
         :param X:
         :param Y:
         :param **kwargs:
 
+        """
+        return self._datamodule(C, X, Y, "multitask_multivariate", **kwargs)
+
+    def dataloader(self, C, X, Y, **kwargs):
+        """
+
+        :param C:
+        :param X:
+        :param Y:
+        :param **kwargs:
+
         """
         return self._dataloader(C, X, Y, MultitaskMultivariateDataset, **kwargs)
 
@@ -497,14 +590,17 @@ def predict_step(self, batch, batch_idx):
         beta_hat, mu_hat = self(C, T)
         return beta_hat, mu_hat
 
-    def _params_reshape(self, preds, dataloader):
+    def _params_reshape(self, preds, dataclass):
         """
 
         :param preds:
-        :param dataloader:
+        :param dataclass:
 
         """
-        ds = dataloader.dataset.dataset
+        dataloader = self._dataclass_to_dataloader(dataclass, dataloader_type="pred")
+
+        ds = dataloader.dataset
         betas = np.zeros((ds.n, ds.y_dim, ds.x_dim))
         mus = np.zeros((ds.n, ds.y_dim))
         for (beta_hats, mu_hats), data in zip(preds, dataloader):
@@ -513,21 +609,23 @@
             C, _, X, _, n_idx, y_idx = data
             mus[n_idx, y_idx] = mu_hats.squeeze(-1)
         return betas, mus
 
-    def _y_reshape(self, preds, dataloader):
+    def _y_reshape(self, preds, dataclass):
         """
 
         :param preds:
-        :param dataloader:
+        :param dataclass:
 
         """
-        ds = dataloader.dataset.dataset
+        dataloader = self._dataclass_to_dataloader(dataclass, dataloader_type="pred")
+
+        ds = dataloader.dataset
         ys = np.zeros((ds.n, ds.y_dim))
         for (beta_hats, mu_hats), data in zip(preds, dataloader):
             C, _, X, _, n_idx, y_idx = data
             ys[n_idx, y_idx] = self._predict_y(C, X, beta_hats, mu_hats).squeeze(-1)
         return ys
 
-    def dataloader(self, C, X, Y, **kwargs):
+    def datamodule(self, C, X, Y, **kwargs):
         """
 
         :param C:
@@ -535,6 +633,15 @@
         :param X:
         :param Y:
         :param **kwargs:
 
+        """
+        return self._datamodule(C, X, Y, "multitask_multivariate", **kwargs)
+
+    def dataloader(self, C, X, Y, **kwargs):
+        """
+
+        :param C:
+        :param X:
+        :param Y:
+        :param **kwargs:
+
         """
         return self._dataloader(C, X, Y, MultitaskMultivariateDataset, **kwargs)
 
@@ -552,14 +659,16 @@ def _build_metamodel(self, *args, **kwargs):
         kwargs["univariate"] = True
         self.metamodel = SubtypeMetamodel(*args, **kwargs)
 
-    def _params_reshape(self, preds, dataloader):
+    def _params_reshape(self, preds, dataclass):
         """
 
         :param preds:
-        :param dataloader:
+        :param dataclass:
 
         """
-        ds = dataloader.dataset.dataset
+        dataloader = self._dataclass_to_dataloader(dataclass, dataloader_type="pred")
+
+        ds = dataloader.dataset
         betas = np.zeros((ds.n, ds.y_dim, ds.x_dim))
         mus = np.zeros((ds.n, ds.y_dim, ds.x_dim))
         for (beta_hats, mu_hats), data in zip(preds, dataloader):
@@ -568,21 +677,23 @@
             C, X, _, n_idx = data
             mus[n_idx] = mu_hats.squeeze(-1)
         return betas, mus
 
-    def _y_reshape(self, preds, dataloader):
+    def _y_reshape(self, preds, dataclass):
         """
 
         :param preds:
-        :param dataloader:
+        :param dataclass:
 
         """
-        ds = dataloader.dataset.dataset
+        dataloader = self._dataclass_to_dataloader(dataclass, dataloader_type="pred")
+
+        ds = dataloader.dataset
         ys = np.zeros((ds.n, ds.y_dim, ds.x_dim))
         for (beta_hats, mu_hats), data in zip(preds, dataloader):
             C, X, _, n_idx = data
             ys[n_idx] = self._predict_y(C, X, beta_hats, mu_hats).squeeze(-1)
         return ys
 
-    def dataloader(self, C, X, Y, **kwargs):
+    def datamodule(self, C, X, Y, **kwargs):
         """
 
         :param C:
@@ -590,6 +701,15 @@
         :param X:
         :param Y:
         :param **kwargs:
 
+        """
+        return self._datamodule(C, X, Y, "univariate", **kwargs)
+
+    def dataloader(self, C, X, Y, **kwargs):
+        """
+
+        :param C:
+        :param X:
+        :param Y:
+        :param **kwargs:
+
         """
         return self._dataloader(C, X, Y, UnivariateDataset, **kwargs)
 
@@ -631,14 +751,16 @@ def predict_step(self, batch, batch_idx):
         beta_hat, mu_hat = self(C, T)
         return beta_hat, mu_hat
 
-    def _params_reshape(self, preds, dataloader):
+    def _params_reshape(self, preds, dataclass):
         """
 
         :param preds:
-        :param dataloader:
+        :param dataclass:
 
         """
-        ds = dataloader.dataset.dataset
+        dataloader = self._dataclass_to_dataloader(dataclass, dataloader_type="pred")
+
+        ds = dataloader.dataset
         betas = np.zeros((ds.n, ds.y_dim, ds.x_dim))
         mus = betas.copy()
         for (beta_hats, mu_hats), data in zip(preds, dataloader):
@@ -647,14 +769,16 @@
             C, _, X, _, n_idx, x_idx, y_idx = data
             mus[n_idx, y_idx, x_idx] = mu_hats.squeeze(-1)
         return betas, mus
 
-    def _y_reshape(self, preds, dataloader):
+    def _y_reshape(self, preds, dataclass):
         """
 
         :param preds:
-        :param dataloader:
+        :param dataclass:
 
         """
-        ds = dataloader.dataset.dataset
+        dataloader = self._dataclass_to_dataloader(dataclass, dataloader_type="pred")
+
+        ds = dataloader.dataset
         ys = np.zeros((ds.n, ds.y_dim, ds.x_dim))
         for (beta_hats, mu_hats), data in zip(preds, dataloader):
             C, _, X, _, n_idx, x_idx, y_idx = data
@@ -663,7 +787,7 @@
             ys[n_idx, y_idx, x_idx] = self._predict_y(C, X, beta_hats, mu_hats).squeeze(
                 -1
             )
         return ys
 
-    def dataloader(self, C, X, Y, **kwargs):
+    def datamodule(self, C, X, Y, correlation=False, **kwargs):
         """
 
         :param C:
@@ -671,6 +795,17 @@
         :param X:
         :param Y:
         :param **kwargs:
 
+        """
+        return self._datamodule(
+            C, X, Y, "multitask_univariate", correlation=correlation, **kwargs
+        )
+
+    def dataloader(self, C, X, Y, **kwargs):
+        """
+
+        :param C:
+        :param X:
+        :param Y:
+        :param **kwargs:
+
         """
         return self._dataloader(C, X, Y, MultitaskUnivariateDataset, **kwargs)
 
@@ -687,7 +822,7 @@ def __init__(self, context_dim, x_dim, **kwargs):
             del kwargs["y_dim"]
         super().__init__(context_dim, x_dim, x_dim, **kwargs)
 
-    def dataloader(self, C, X, Y=None, **kwargs):
+    def datamodule(self, C, X, Y=None, **kwargs):
         """
 
         :param C:
@@ -695,6 +830,19 @@
         :param X:
         :param Y:
         :param **kwargs:
 
+        """
+        if Y is not None:
+            print(
+                "Passed a Y, but this is self-correlation between X features. Ignoring Y."
+            )
+        return self._datamodule(C, X, X, "univariate", correlation=True, **kwargs)
+
+    def dataloader(self, C, X, Y=None, **kwargs):
+        """
+
+        :param C:
+        :param X:
+        :param Y:
+        :param **kwargs:
+
         """
         if Y is not None:
             print(
@@ -715,7 +863,7 @@ def __init__(self, context_dim, x_dim, **kwargs):
             del kwargs["y_dim"]
         super().__init__(context_dim, x_dim, x_dim, **kwargs)
 
-    def dataloader(self, C, X, Y=None, **kwargs):
+    def datamodule(self, C, X, Y=None, **kwargs):
         """
 
         :param C:
@@ -726,7 +874,20 @@
 
         """
         if Y is not None:
             print(
-                "Passed a Y, but this is self-correlation between X featuers. Ignoring Y."
+                "Passed a Y, but this is self-correlation between X features. Ignoring Y."
             )
+        return super().datamodule(C, X, X, correlation=True, **kwargs)
+
+    def dataloader(self, C, X, Y=None, **kwargs):
+        """
+
+        :param C:
+        :param X:
+        :param Y:
+        :param **kwargs:
+
+        """
+        if Y is not None:
+            print(
+                "Passed a Y, but this is self-correlation between X features. Ignoring Y."
+            )
         return super().dataloader(C, X, X, **kwargs)
 
@@ -760,7 +921,7 @@ def predict_step(self, batch, batch_idx):
         beta_hat = beta_hat * self.diag_mask.expand(beta_hat.shape[0], -1, -1)
         return beta_hat, mu_hat
 
-    def dataloader(self, C, X, Y=None, **kwargs):
+    def datamodule(self, C, X, Y=None, **kwargs):
         """
 
         :param C:
@@ -772,6 +933,33 @@
 
         """
         if Y is not None:
             print(
-                "Passed a Y, but this is a Markov Graph between X featuers. Ignoring Y."
+                "Passed a Y, but this is a Markov Graph between X features. Ignoring Y."
+            )
+        return self._datamodule(C, X, X, "multivariate", **kwargs)
+
+    def dataloader(self, C, X, Y=None, **kwargs):
+        """
+
+        :param C:
+        :param X:
+        :param Y:
+        :param **kwargs:
+
+        """
+        if Y is not None:
+            print(
+                "Passed a Y, but this is a Markov Graph between X features. Ignoring Y."
             )
         return super().dataloader(C, X, X, **kwargs)
+
+
+MODELS = {
+    "naive": NaiveContextualizedRegression,
+    "multivariate": ContextualizedRegression,
+    "multitask": MultitaskContextualizedRegression,
+    "tasksplit": TasksplitContextualizedRegression,
+    "univariate": ContextualizedUnivariateRegression,
+    "tasksplit_univariate": TasksplitContextualizedUnivariateRegression,
+    "correlation": ContextualizedCorrelation,
+    "tasksplit_correlation": TasksplitContextualizedCorrelation,
+    "markov": ContextualizedMarkovGraph,
+}
diff --git a/contextualized/regression/losses.py b/contextualized/regression/losses.py
index 9fb67c89..5ec31370 100644
--- a/contextualized/regression/losses.py
+++ b/contextualized/regression/losses.py
@@ -29,3 +29,6 @@ def BCELoss(Y_true, Y_pred):
         Y_true * torch.log(Y_pred + 1e-8) + (1 - Y_true) * torch.log(1 - Y_pred + 1e-8)
     )
     return loss.mean()
+
+
+LOSSES = {"mse": MSE, "bceloss": BCELoss}
diff --git a/contextualized/regression/metamodels.py b/contextualized/regression/metamodels.py
index d9ba6d90..9df90d63 100644
--- a/contextualized/regression/metamodels.py
+++ b/contextualized/regression/metamodels.py
@@ -7,12 +7,12 @@
 from contextualized.modules import ENCODERS, Explainer, SoftSelect
 from contextualized.functions import LINK_FUNCTIONS
 
+METAMODELS = ["simple", "subtype", "multitask", "tasksplit"]
+
 
 class NaiveMetamodel(nn.Module):
     """Probabilistic assumptions as a graphical model (observed) {unobserved}:
     (C) --> {beta, mu} --> (X, Y)
-
-
     """
 
     def __init__(
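A small sketch of the `LOSSES` registry added in `losses.py` above; both losses are called as `(Y_true, Y_pred)`, matching how `loss_fn` is invoked in the lightning modules:

```python
import torch
from contextualized.regression.losses import LOSSES

y_true = torch.tensor([[0.0], [1.0]])
y_pred = torch.tensor([[0.1], [0.9]])

mse = LOSSES["mse"](y_true, y_pred)      # mean squared error
bce = LOSSES["bceloss"](y_true, y_pred)  # the 1e-8 terms above keep log() finite
print(mse.item(), bce.item())
```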
""" + import unittest import numpy as np import torch + # from contextualized.modules import NGAM, MLP, SoftSelect, Explainer from contextualized.regression.lightning_modules import * from contextualized.regression.trainers import * +from contextualized.regression.datamodules import RegressionDataModule from contextualized.functions import LINK_FUNCTIONS from contextualized.utils import DummyParamPredictor, DummyYPredictor @@ -47,46 +50,114 @@ def setUp(self): self.c_dim, self.x_dim, self.y_dim = c_dim, x_dim, y_dim self.C, self.X, self.Y = C.numpy(), X.numpy(), Y.numpy() - def _quicktest(self, model, univariate=False, correlation=False, markov=False): + def _quicktest( + self, + model, + univariate=False, + correlation=False, + markov=False, + dataclass_type="dataloader", + ): """ - :param model: :param univariate: (Default value = False) :param correlation: (Default value = False) :param markov: (Default value = False) + :param dm_pred_dl_type: Dataset to use for predict (Default value = 'Full', choose from 'test', 'train', 'val', 'full') """ + print(f"\n{type(model)} quicktest") + + get_dataclass = { + "datamodule": lambda x, **kwargs: x.datamodule(**kwargs), + "dataloader": lambda x, **kwargs: x.dataloader(**kwargs), + } + + # get dataclass & trainer if correlation: - dataloader = model.dataloader(self.C, self.X, batch_size=self.batch_size) + dataclass = get_dataclass[dataclass_type]( + model, + C=self.C, + X=np.hstack((self.X, self.Y)), + batch_size=self.batch_size, + ) trainer = CorrelationTrainer(max_epochs=self.epochs) - y_true = np.tile(self.X[:, :, np.newaxis], (1, 1, self.X.shape[-1])) + elif markov: - dataloader = model.dataloader(self.C, self.X, batch_size=self.batch_size) + # use concatenated X,Y for X -> Y prediction + dataclass = get_dataclass[dataclass_type]( + model, + C=self.C, + X=np.hstack((self.X, self.Y)), + batch_size=self.batch_size, + ) trainer = MarkovTrainer(max_epochs=self.epochs) - y_true = self.X else: - dataloader = model.dataloader( - self.C, self.X, self.Y, batch_size=self.batch_size + dataclass = get_dataclass[dataclass_type]( + model, C=self.C, X=self.X, Y=self.Y, batch_size=self.batch_size ) - trainer = RegressionTrainer(max_epochs=self.epochs) - if univariate: - y_true = np.tile(self.Y[:, :, np.newaxis], (1, 1, self.X.shape[-1])) - else: - y_true = self.Y - y_preds = trainer.predict_y(model, dataloader) - err_init = ((y_true - y_preds) ** 2).mean() - trainer.fit(model, dataloader) - trainer.validate(model, dataloader) - trainer.test(model, dataloader) - beta_preds, mu_preds = trainer.predict_params(model, dataloader) - if correlation: - rhos = trainer.predict_correlation(model, dataloader) - if markov: - omegas = trainer.predict_precision(model, dataloader) - y_preds = trainer.predict_y(model, dataloader) - err_trained = ((y_true - y_preds) ** 2).mean() - assert err_trained < err_init, "Model failed to converge" + trainer = RegressionTrainer(max_epochs=self.epochs, univariate=univariate) + + # train / eval models + if type(dataclass_type) in ( + pl.LightningDataModule, + RegressionDataModule, + ): # datamodule + err_init = {} + err_trained = {} + # pre-train mse/preds + for p_type in ["train", "test", "val", "full"]: + y_preds = trainer.predict_y(model, dataclass, dm_pred_type=p_type) + + # train + trainer.fit(model, dataclass) + trainer.validate(model, dataclass) + trainer.test(model, dataclass) + + # post-train predictions + for p_type in ["train", "test", "val", "full"]: + y_preds = trainer.predict_y(model, dataclass, p_type) + + beta_preds, 
mu_preds = trainer.predict_params(model, dataclass, p_type) + + if correlation: + rhos = trainer.predict_correlation(model, dataclass, p_type) + if markov: + omegas = trainer.predict_precision(model, dataclass, p_type) + + err_trained[p_type] = ((y_true - y_preds) ** 2).mean() + + assert ( + err_trained[p_type] < err_init[p_type] + ), "Model failed to converge" + + else: # dataloader + # pre-train mse/preds + y_preds = trainer.predict_y(model=model, dataclass=dataclass) + mse_pre = trainer.measure_mses( + model, dataclass, dm_pred_type="test", individual_preds=False + ) + + # train + trainer.fit(model, dataclass) + trainer.validate(model, dataclass) + trainer.test(model, dataclass) + + # post-train mse/preds + y_preds = trainer.predict_y(model, dataclass) + mse_post = trainer.measure_mses( + model, dataclass, dm_pred_type="test", individual_preds=False + ) + + beta_preds, mu_preds = trainer.predict_params(model, dataclass) + + if correlation: + rhos = trainer.predict_correlation(model, dataclass) + if markov: + omegas = trainer.predict_precision(model, dataclass) + + assert mse_post < mse_pre, "Model failed to converge" def test_naive(self): """ @@ -104,7 +175,8 @@ def test_naive(self): }, link_fn=LINK_FUNCTIONS["identity"], ) - self._quicktest(model) + self._quicktest(model, dataclass_type="dataloader") + self._quicktest(model, dataclass_type="datamodule") model = NaiveContextualizedRegression( self.c_dim, @@ -118,7 +190,8 @@ def test_naive(self): }, link_fn=LINK_FUNCTIONS["identity"], ) - self._quicktest(model) + self._quicktest(model, dataclass_type="dataloader") + self._quicktest(model, dataclass_type="datamodule") model = NaiveContextualizedRegression( self.c_dim, @@ -131,7 +204,8 @@ def test_naive(self): }, link_fn=LINK_FUNCTIONS["identity"], ) - self._quicktest(model) + self._quicktest(model, dataclass_type="dataloader") + self._quicktest(model, dataclass_type="datamodule") model = NaiveContextualizedRegression( self.c_dim, @@ -144,7 +218,8 @@ def test_naive(self): }, link_fn=LINK_FUNCTIONS["logistic"], ) - self._quicktest(model) + self._quicktest(model, dataclass_type="dataloader") + self._quicktest(model, dataclass_type="datamodule") model = NaiveContextualizedRegression( self.c_dim, @@ -157,7 +232,8 @@ def test_naive(self): }, link_fn=LINK_FUNCTIONS["logistic"], ) - self._quicktest(model) + self._quicktest(model, dataclass_type="dataloader") + self._quicktest(model, dataclass_type="datamodule") parambase = DummyParamPredictor((self.y_dim, self.x_dim), (self.y_dim, 1)) model = NaiveContextualizedRegression( @@ -172,7 +248,8 @@ def test_naive(self): link_fn=LINK_FUNCTIONS["logistic"], base_param_predictor=parambase, ) - self._quicktest(model) + self._quicktest(model, dataclass_type="dataloader") + self._quicktest(model, dataclass_type="datamodule") ybase = DummyYPredictor((self.y_dim, 1)) model = NaiveContextualizedRegression( @@ -187,7 +264,8 @@ def test_naive(self): link_fn=LINK_FUNCTIONS["logistic"], base_y_predictor=ybase, ) - self._quicktest(model) + self._quicktest(model, dataclass_type="dataloader") + self._quicktest(model, dataclass_type="datamodule") def test_subtype(self): """ @@ -203,7 +281,10 @@ def test_subtype(self): base_param_predictor=parambase, base_y_predictor=ybase, ) - self._quicktest(model) + # self._quicktest(model, dataclass="dataloader") + + self._quicktest(model, dataclass_type="dataloader") + self._quicktest(model, dataclass_type="datamodule") def test_multitask(self): """ @@ -219,7 +300,8 @@ def test_multitask(self): base_param_predictor=parambase, 
base_y_predictor=ybase, ) - self._quicktest(model) + # self._quicktest(model, dataclass_type="dataloader") + self._quicktest(model, dataclass_type="datamodule") def test_tasksplit(self): """ @@ -235,7 +317,8 @@ def test_tasksplit(self): base_param_predictor=parambase, base_y_predictor=ybase, ) - self._quicktest(model) + self._quicktest(model, dataclass_type="dataloader") + self._quicktest(model, dataclass_type="datamodule") def test_univariate_subtype(self): """ @@ -253,7 +336,8 @@ def test_univariate_subtype(self): base_param_predictor=parambase, base_y_predictor=ybase, ) - self._quicktest(model, univariate=True) + self._quicktest(model, univariate=True, dataclass_type="dataloader") + self._quicktest(model, univariate=True, dataclass_type="datamodule") def test_univariate_tasksplit(self): """ @@ -269,7 +353,8 @@ def test_univariate_tasksplit(self): base_param_predictor=parambase, base_y_predictor=ybase, ) - self._quicktest(model, univariate=True) + self._quicktest(model, univariate=True, dataclass_type="dataloader") + self._quicktest(model, univariate=True, dataclass_type="datamodule") def test_correlation_subtype(self): """ @@ -277,16 +362,18 @@ def test_correlation_subtype(self): """ # Correlation parambase = DummyParamPredictor( - (self.x_dim, self.x_dim, 1), (self.x_dim, self.x_dim, 1) + (self.x_dim + self.y_dim, self.x_dim + self.y_dim, 1), + (self.x_dim + self.y_dim, self.x_dim + self.y_dim, 1), ) - ybase = DummyYPredictor((self.x_dim, self.x_dim, 1)) + ybase = DummyYPredictor((self.x_dim + self.y_dim, self.x_dim + self.y_dim, 1)) model = ContextualizedCorrelation( self.c_dim, - self.x_dim, + self.x_dim + self.y_dim, base_param_predictor=parambase, base_y_predictor=ybase, ) - self._quicktest(model, correlation=True) + self._quicktest(model, correlation=True, dataclass_type="dataloader") + self._quicktest(model, correlation=True, dataclass_type="datamodule") def test_correlation_tasksplit(self): """ @@ -297,26 +384,30 @@ def test_correlation_tasksplit(self): ybase = DummyYPredictor((1,)) model = TasksplitContextualizedCorrelation( self.c_dim, - self.x_dim, + self.x_dim + self.y_dim, base_param_predictor=parambase, base_y_predictor=ybase, ) - self._quicktest(model, correlation=True) + self._quicktest(model, correlation=True, dataclass_type="dataloader") + self._quicktest(model, correlation=True, dataclass_type="datamodule") def test_markov_subtype(self): """ Test Markov Graphs. 
""" # Markov Graph - parambase = DummyParamPredictor((self.x_dim, self.x_dim), (self.x_dim, 1)) - ybase = DummyYPredictor((self.x_dim, 1)) + parambase = DummyParamPredictor( + (self.y_dim + self.x_dim, 1), (self.y_dim + self.x_dim, 1) + ) + ybase = DummyYPredictor((self.y_dim + self.x_dim, 1)) model = ContextualizedMarkovGraph( self.c_dim, - self.x_dim, + self.x_dim + self.y_dim, base_param_predictor=parambase, base_y_predictor=ybase, ) - self._quicktest(model, markov=True) + self._quicktest(model, markov=True, dataclass_type="dataloader") + self._quicktest(model, markov=True, dataclass_type="datamodule") if __name__ == "__main__": diff --git a/contextualized/regression/trainers.py b/contextualized/regression/trainers.py index 3e29bb7f..df41a3d5 100644 --- a/contextualized/regression/trainers.py +++ b/contextualized/regression/trainers.py @@ -3,6 +3,7 @@ """ import numpy as np import pytorch_lightning as pl +from contextualized.regression.datamodules import RegressionDataModule class RegressionTrainer(pl.Trainer): @@ -10,22 +11,88 @@ class RegressionTrainer(pl.Trainer): Trains the contextualized.regression lightning_modules """ - def predict_params(self, model, dataloader): + def __init__(self, univariate=False, **kwargs): + super().__init__(**kwargs) + self.univariate = univariate + + def predict_params(self, model, dataclass, dm_pred_type="test"): """ + :param model: Model to use for predicting + :param dataclass: Dataloader or datamodule class to predict on + :param dm_pred_type: If dataclass is a datamodule, choose what dataset the predict_dataloader will use: 'test', 'train', 'val', 'full' + Returns context-specific regression models - beta (numpy.ndarray): (n, y_dim, x_dim) - mu (numpy.ndarray): (n, y_dim, [1 if normal regression, x_dim if univariate]) """ - preds = super().predict(model, dataloader) - return model._params_reshape(preds, dataloader) - def predict_y(self, model, dataloader): + if type(dataclass) in (pl.LightningDataModule, RegressionDataModule): + dataclass.setup("predict", dm_pred_type) + + preds = super().predict(model, dataclass) + return model._params_reshape(preds, dataclass) + + def predict_y(self, model, dataclass, dm_pred_type="test"): """ + :param model: Model to use for predicting + :param dataclass: Dataloader or datamodule class to predict on + :param dm_pred_type: If dataclass is a datamodule, choose what dataset the predict_dataloader will use: 'test', 'train', 'val', 'full' + Returns context-specific predictions of the response Y - y_hat (numpy.ndarray): (n, y_dim, [1 if normal regression, x_dim if univariate]) """ - preds = super().predict(model, dataloader) - return model._y_reshape(preds, dataloader) + + if type(dataclass) in (pl.LightningDataModule, RegressionDataModule): + dataclass.setup("predict", dm_pred_type) + + preds = super().predict(model, dataclass) + return model._y_reshape(preds, dataclass) + + def measure_mses( + self, model, dataclass, dm_pred_type="test", individual_preds=False + ): + """ + Measure mean-squared errors. 
+ """ + + if type(dataclass) in (pl.LightningDataModule, RegressionDataModule): + # datamodule + ds = { + "full": dataclass.full_dataset, + "test": dataclass.test_dataset, + "train": dataclass.train_dataset, + "val": dataclass.val_dataset, + }[dm_pred_type] + + else: # dataloader + ds = dataclass.dataset + + X = ds.X.numpy() + Y = ds.Y.numpy() + + betas, mus = self.predict_params(model, dataclass, dm_pred_type=dm_pred_type) + mses = np.zeros((len(X))) # n_samples + + if self.univariate: + for i in range(Y.shape[-1]): + for j in range(X.shape[-1]): + tiled_xi = X[:, i] + tiled_xj = X[:, j] + residuals = tiled_xi - betas[:, i, j] * tiled_xj + mus[:, i, j] + mses += residuals ** 2 / (X.shape[-1] ** 2) + + if not individual_preds: + mses = np.mean(mses) + else: + for i in range(Y.shape[-1]): + for j in range(X.shape[-1]): + tiled_xi = X[:, i] + tiled_xj = X[:, j] + residuals = tiled_xi - betas[:, i, j] * tiled_xj + mus[:, i] + mses += residuals ** 2 / (X.shape[-1] ** 2) + if not individual_preds: + mses = np.mean(mses) + return mses class CorrelationTrainer(RegressionTrainer): @@ -33,12 +100,20 @@ class CorrelationTrainer(RegressionTrainer): Trains the contextualized.regression correlation lightning_modules """ - def predict_correlation(self, model, dataloader): + def predict_correlation(self, model, dataclass, dm_pred_type="test"): """ + :param model: Model to use for predicting + :param dataclass: Dataloader or datamodule class to predict on + :param dm_pred_type: If dataclass is a datamodule, choose what dataset the predict_dataloader will use: 'test', 'train', 'val', 'full' + Returns context-specific correlation networks containing Pearson's correlation coefficient - - correlation (numpy.ndarray): (n, x_dim, x_dim) + - correlation (numpy.ndarray): (n, x_dim, x_dim """ - betas, _ = super().predict_params(model, dataloader) + + if type(dataclass) in (pl.LightningDataModule, RegressionDataModule): + dataclass.setup("predict", dm_pred_type) + + betas, _ = super().predict_params(model, dataclass) signs = np.sign(betas) signs[ signs != np.transpose(signs, (0, 2, 1)) @@ -46,14 +121,52 @@ def predict_correlation(self, model, dataloader): correlations = signs * np.sqrt(np.abs(betas * np.transpose(betas, (0, 2, 1)))) return correlations + def measure_mses( + self, model, dataclass, dm_pred_type="test", individual_preds=False + ): + """ + Measure mean-squared errors. 
+ """ + + if type(dataclass) in (pl.LightningDataModule, RegressionDataModule): + # datamodule + ds = { + "full": dataclass.full_dataset, + "test": dataclass.test_dataset, + "train": dataclass.train_dataset, + "val": dataclass.val_dataset, + }[dm_pred_type] + + else: # dataloader + ds = dataclass.dataset + + X = ds.X.numpy() + + betas, mus = self.predict_params(model, dataclass, dm_pred_type=dm_pred_type) + mses = np.zeros((len(X))) # n_samples + + for i in range(X.shape[-1]): + for j in range(X.shape[-1]): + tiled_xi = X[:, i] + tiled_xj = X[:, j] + residuals = tiled_xi - betas[:, i, j] * tiled_xj + mus[:, i, j] + mses += residuals ** 2 / (X.shape[-1] ** 2) + if not individual_preds: + mses = np.mean(mses) + return mses + class MarkovTrainer(CorrelationTrainer): """ Trains the contextualized.regression markov graph lightning_modules """ - def predict_precision(self, model, dataloader): + def predict_precision(self, model, dataclass, dm_pred_type="test"): """ + :param model: Model to use for predicting + :param dataclass: Dataloader or datamodule class to predict on + :param dm_pred_type: If dataclass is a datamodule, choose what dataset the predict_dataloader will use: 'test', 'train', 'val', 'full' + Returns context-specific precision matrix under a Gaussian graphical model Assuming all diagonal precisions are equal and constant over context, this is equivalent to the negative of the multivariate regression coefficient. @@ -61,4 +174,48 @@ def predict_precision(self, model, dataloader): """ # A trick in the markov lightning_module predict_step makes makes the predict_correlation # output equivalent to negative precision values here. - return -super().predict_correlation(model, dataloader) + if type(dataclass) in (pl.LightningDataModule, RegressionDataModule): + dataclass.setup("predict", dm_pred_type) + + return -super().predict_correlation(model, dataclass) + + def measure_mses( + self, model, dataclass, dm_pred_type="test", individual_preds=False + ): + """ + Measure mean-squared errors. + """ + + if type(dataclass) in (pl.LightningDataModule, RegressionDataModule): + # datamodule + ds = { + "full": dataclass.full_dataset, + "test": dataclass.test_dataset, + "train": dataclass.train_dataset, + "val": dataclass.val_dataset, + }[dm_pred_type] + + else: # dataloader + ds = dataclass.dataset + + X = ds.X.numpy() + + betas, mus = self.predict_params(model, dataclass, dm_pred_type=dm_pred_type) + mses = np.zeros((len(X))) + for i in range(X.shape[-1]): + preds = np.array( + [X[j].dot(betas[j, i, :]) + mus[j, i] for j in range(len(X))] + ) + residuals = X[:, i] - preds + mses += residuals ** 2 / (X.shape[-1]) + + if not individual_preds: + mses = np.mean(mses) + return mses + + +TRAINERS = { + "regression_trainer": RegressionTrainer, + "correlation_trainer": CorrelationTrainer, + "markov_trainer": MarkovTrainer, +}