Skip to content

Commit

Permalink
Merge pull request #73 from megha-narayanan/prune
Browse files Browse the repository at this point in the history
Added method to prune bad data
  • Loading branch information
mrakitin authored Sep 26, 2024
2 parents 76ddd5a + 0b42d87 commit f8efcac
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 14 deletions.
18 changes: 9 additions & 9 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,15 @@ jobs:
# mkdir -v -p ~/.config/databroker/
# wget https://raw.githubusercontent.com/NSLS-II/sirepo-bluesky/main/examples/local.yml -O ~/.config/databroker/local.yml

- name: Set up Python ${{ matrix.python-version }} with conda
uses: conda-incubator/setup-miniconda@v2
with:
activate-environment: ${{ env.REPOSITORY_NAME }}-py${{ matrix.python-version }}
auto-update-conda: true
miniconda-version: "latest"
python-version: ${{ matrix.python-version }}
mamba-version: "*"
channels: conda-forge
# - name: Set up Python ${{ matrix.python-version }} with conda
# uses: conda-incubator/setup-miniconda@v2
# with:
# activate-environment: ${{ env.REPOSITORY_NAME }}-py${{ matrix.python-version }}
# auto-update-conda: true
# miniconda-version: "latest"
# python-version: ${{ matrix.python-version }}
# mamba-version: "*"
# channels: conda-forge

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
strategy:
matrix:
host-os: ["ubuntu-latest"]
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.9", "3.10", "3.11", "3.12"]
fail-fast: false

defaults:
Expand Down
6 changes: 4 additions & 2 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ blop
:target: https://github.com/NSLS-II/blop/actions/workflows/testing.yml


.. image:: https://img.shields.io/pypi/v/blop.svg
:target: https://pypi.python.org/pypi/blop
.. image:: https://img.shields.io/pypi/v/bloptools.svg
:target: https://pypi.python.org/pypi/bloptools

.. image:: https://img.shields.io/conda/vn/conda-forge/bloptools.svg
:target: https://anaconda.org/conda-forge/bloptools

Beamline Optimization Tools

Expand Down
34 changes: 33 additions & 1 deletion src/blop/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(
digestion_kwargs: dict = {},
verbose: bool = False,
enforce_all_objectives_valid: bool = True,
exclude_pruned: bool = True,
model_inactive_objectives: bool = False,
tolerate_acquisition_errors: bool = False,
sample_center_on_init: bool = False,
Expand Down Expand Up @@ -142,6 +143,7 @@ def __init__(
self.model_inactive_objectives = model_inactive_objectives
self.tolerate_acquisition_errors = tolerate_acquisition_errors
self.enforce_all_objectives_valid = enforce_all_objectives_valid
self.exclude_pruned = exclude_pruned

self.train_every = train_every
self.trigger_delay = trigger_delay
Expand Down Expand Up @@ -687,7 +689,7 @@ def _construct_model(self, obj, skew_dims=None):
inputs_are_trusted = ~torch.isnan(train_inputs).any(axis=1)
targets_are_trusted = ~torch.isnan(train_targets).any(axis=1)

trusted = inputs_are_trusted & targets_are_trusted
trusted = inputs_are_trusted & targets_are_trusted & ~self.pruned_mask()

obj._model = construct_single_task_model(
X=train_inputs[trusted],
Expand Down Expand Up @@ -1055,3 +1057,33 @@ def latent_transforms(self):
def plot_pareto_front(self, **kwargs):
    """Plot the Pareto front for this agent.

    Thin wrapper: passes this agent and any keyword arguments straight
    through to ``plotting._plot_pareto_front``. Nothing is returned here.

    NOTE(review): which keyword arguments are accepted is decided by
    ``plotting._plot_pareto_front``, which is not visible in this view.
    """
    plotting._plot_pareto_front(self, **kwargs)

def prune(self, pruning_objs=None, thresholds=None):
    """Flag low-fidelity datapoints so they can be left out of model fitting.

    Every row of the agent's table gets a boolean ``prune`` column. The
    column is reset to ``False`` on each call, then a row is flagged if,
    for any of the given objectives, its shifted log-likelihood score falls
    below ``threshold * (25th percentile of the shifted scores)``.

    Parameters
    ----------
    pruning_objs : sequence, optional
        Objectives to prune over. Each must expose ``model`` and ``name``.
        Defaults to no objectives, which just clears the ``prune`` column.
    thresholds : sequence of float, optional
        One multiplier per objective, in the same order as ``pruning_objs``.

    Raises
    ------
    ValueError
        If an objective has no ``model`` attribute, or if the number of
        objectives and thresholds differ. The table is left untouched in
        both cases.
    """
    # Avoid mutable default arguments; also accept any iterable.
    pruning_objs = [] if pruning_objs is None else list(pruning_objs)
    thresholds = [] if thresholds is None else list(thresholds)

    # Validate *before* mutating the table, so a bad call has no side effects.
    # make sure there are models trained for all the objectives we are pruning over
    if not all(hasattr(obj, "model") for obj in pruning_objs):
        raise ValueError("Not all pruning objectives have models.")
    # make sure we have the same number of thresholds and objectives to prune over
    if len(pruning_objs) != len(thresholds):
        raise ValueError("Number of pruning objectives and thresholds should be the same")

    # Start from a clean slate: nothing is pruned.
    self._table = self._table.assign(prune=False)
    pruned = np.zeros(len(self._table), dtype=bool)

    for obj, threshold in zip(pruning_objs, thresholds):
        mll = gpytorch.mlls.ExactMarginalLogLikelihood(obj.model.likelihood, obj.model)
        # assumes this yields one score per table row — TODO confirm
        mlls = mll(obj.model(self.train_inputs()), self.train_targets()[obj.name].unsqueeze(-1)).detach()

        valid = ~torch.isnan(mlls)
        if not valid.any():
            # Nothing usable for this objective; leave the flags as they are.
            continue

        # Shift so the best score is 0. The max is taken over the non-NaN
        # scores only: Tensor.max() propagates NaN, which would otherwise
        # turn every score into NaN and silently disable pruning.
        mlls = mlls - mlls[valid].max()
        cutoff = float(threshold * np.quantile(mlls[valid].cpu().numpy(), q=0.25))

        # NOTE: open question from the original author — should there be an
        # option to use ">" here, in case the scores are not negated?
        # NaN scores compare False and are therefore never flagged here.
        pruned |= (mlls < cutoff).cpu().numpy()

    # Store a plain boolean numpy array rather than a torch tensor.
    self._table["prune"] = pruned
    self.refresh()

# @property
def pruned_mask(self):
    """Return a boolean tensor with one entry per table row; True means pruned.

    If ``exclude_pruned`` is off, or the table has no ``prune`` column, the
    mask is all ``False``.
    """
    table = self._table
    if not self.exclude_pruned or "prune" not in table.columns:
        # Nothing to exclude: all-False mask of the same length as the table.
        return torch.zeros(len(table), dtype=torch.bool)
    flags = table["prune"].values.astype(bool)
    return torch.tensor(flags)
2 changes: 1 addition & 1 deletion src/blop/bayesian/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(
if skew_dims:
self.skew_dims = [torch.arange(self.num_inputs)]
else:
self.skew_dims = [torch.arange(0)]
self.skew_dims = [(i) for i in torch.arange(num_inputs)]
elif hasattr(skew_dims, "__iter__"):
self.skew_dims = [torch.tensor(np.atleast_1d(skew_group)) for skew_group in skew_dims]
else:
Expand Down

0 comments on commit f8efcac

Please sign in to comment.