From 27e0fe42bdca00c1cd3084bc0086f850d4844791 Mon Sep 17 00:00:00 2001 From: "P. L. Lim" <2090236+pllim@users.noreply.github.com> Date: Fri, 5 Apr 2024 13:23:36 -0400 Subject: [PATCH] export plugin: warning before overwriting files (#2783) * Overwrite prompt as overlay * Address review comment and fix tests * Add test * TST: TestExportPluginPlots get out of _jail free and collect 200 dollars * Throw exception by default from API * Fix PEP 8 * Clear warning when change filename --- .../configs/default/plugins/export/export.py | 69 +++++++++++++++++-- .../configs/default/plugins/export/export.vue | 31 +++++++++ .../plugins/export/tests/test_export.py | 41 ++++++++--- jdaviz/conftest.py | 13 ++++ 4 files changed, 137 insertions(+), 17 deletions(-) diff --git a/jdaviz/configs/default/plugins/export/export.py b/jdaviz/configs/default/plugins/export/export.py index 16dca09c47..075b3820c1 100644 --- a/jdaviz/configs/default/plugins/export/export.py +++ b/jdaviz/configs/default/plugins/export/export.py @@ -92,6 +92,8 @@ class Export(PluginTemplateMixin, ViewerSelectMixin, SubsetSelectMixin, movie_recording = Bool(False).tag(sync=True) movie_interrupt = Bool(False).tag(sync=True) + overwrite_warn = Bool(False).tag(sync=True) + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -231,6 +233,11 @@ def _disable_viewer_format_combo(self, event): msg = "" self.viewer_invalid_msg = msg + @observe('filename') + def _is_filename_changed(self, event): + # Clear overwrite warning when user changes filename + self.overwrite_warn = False + def _set_subset_not_supported_msg(self, msg=None): """ Check if selected subset is spectral or composite, and warn and @@ -265,7 +272,7 @@ def _set_dataset_not_supported_msg(self, msg=None): else: self.data_invalid_msg = '' - def _normalize_filename(self, filename=None, filetype=None): + def _normalize_filename(self, filename=None, filetype=None, overwrite=False): # Make sure filename is valid and file does not end up in weird places in standalone mode. if not filename: raise ValueError("Invalid filename") @@ -276,15 +283,21 @@ def _normalize_filename(self, filename=None, filetype=None): filename = Path(filename).expanduser() filepath = filename.parent - if filepath and not filepath.exists(): + if filepath and not filepath.is_dir(): raise ValueError(f"Invalid path={filepath}") elif ((not filepath or str(filepath).startswith(".")) and os.environ.get("JDAVIZ_START_DIR", "")): # noqa: E501 # pragma: no cover filename = os.environ["JDAVIZ_START_DIR"] / filename + if filename.exists() and not overwrite: + self.overwrite_warn = True + else: + self.overwrite_warn = False + return str(filename) @with_spinner() - def export(self, filename=None, show_dialog=None): + def export(self, filename=None, show_dialog=None, overwrite=False, + raise_error_for_overwrite=True): """ Export selected item(s) @@ -292,6 +305,18 @@ def export(self, filename=None, show_dialog=None): ---------- filename : str, optional If not provided, plugin value will be used. + + show_dialog : bool or `None` + If `True`, prompts dialog to save PNG/SVG from browser. + + overwrite : bool + If `True`, silently overwrite an existing file. + + raise_error_for_overwrite : bool + If `True`, raise exception when ``overwrite=False`` but + output file already exists. Otherwise, a message will be sent + to application snackbar instead. + """ if self.multiselect: @@ -309,7 +334,12 @@ def export(self, filename=None, show_dialog=None): viewer = self.viewer.selected_obj filetype = self.viewer_format.selected - filename = self._normalize_filename(filename, filetype) + filename = self._normalize_filename(filename, filetype, overwrite=overwrite) + + if self.overwrite_warn and not overwrite: + if raise_error_for_overwrite: + raise FileExistsError(f"{filename} exists but overwrite=False") + return if filetype == "mp4": self.save_movie(viewer, filename, filetype) @@ -336,21 +366,33 @@ def export(self, filename=None, show_dialog=None): elif len(self.plugin_table.selected): filetype = self.plugin_table_format.selected filename = self._normalize_filename(filename, filetype) + if self.overwrite_warn and not overwrite: + if raise_error_for_overwrite: + raise FileExistsError(f"{filename} exists but overwrite=False") + return self.plugin_table.selected_obj.export_table(filename, overwrite=True) elif len(self.subset.selected): selected_subset_label = self.subset.selected filetype = self.subset_format.selected - filename = self._normalize_filename(filename, filetype) + filename = self._normalize_filename(filename, filetype, overwrite=overwrite) if self.subset_invalid_msg != '': raise NotImplementedError(f'Subset can not be exported - {self.subset_invalid_msg}') + if self.overwrite_warn and not overwrite: + if raise_error_for_overwrite: + raise FileExistsError(f"{filename} exists but overwrite=False") + return self.save_subset_as_region(selected_subset_label, filename) elif len(self.dataset.selected): filetype = self.dataset_format.selected - filename = self._normalize_filename(filename, filetype) + filename = self._normalize_filename(filename, filetype, overwrite=overwrite) if self.data_invalid_msg != "": raise NotImplementedError(f"Data can not be exported - {self.data_invalid_msg}") + if self.overwrite_warn and not overwrite: + if raise_error_for_overwrite: + raise FileExistsError(f"{filename} exists but overwrite=False") + return self.dataset.selected_obj.write(Path(filename), overwrite=True) else: raise ValueError("nothing selected for export") @@ -359,7 +401,7 @@ def export(self, filename=None, show_dialog=None): def vue_export_from_ui(self, *args, **kwargs): try: - filename = self.export(show_dialog=True) + filename = self.export(show_dialog=True, raise_error_for_overwrite=False) except Exception as e: self.hub.broadcast(SnackbarMessage( f"Export failed with: {e}", sender=self, color="error")) @@ -368,6 +410,19 @@ def vue_export_from_ui(self, *args, **kwargs): self.hub.broadcast(SnackbarMessage( f"Exported to {filename}", sender=self, color="success")) + def vue_overwrite_from_ui(self, *args, **kwargs): + """Attempt to force writing the output if the user confirms the desire to overwrite.""" + try: + filename = self.export(show_dialog=True, overwrite=True, + raise_error_for_overwrite=False) + except Exception as e: + self.hub.broadcast(SnackbarMessage( + f"Export with overwrite failed with: {e}", sender=self, color="error")) + else: + if filename is not None: + self.hub.broadcast(SnackbarMessage( + f"Exported to {filename} (overwrite)", sender=self, color="success")) + def save_figure(self, viewer, filename=None, filetype="png", show_dialog=False): if filetype == "png": diff --git a/jdaviz/configs/default/plugins/export/export.vue b/jdaviz/configs/default/plugins/export/export.vue index 75637f210f..722b3ae600 100644 --- a/jdaviz/configs/default/plugins/export/export.vue +++ b/jdaviz/configs/default/plugins/export/export.vue @@ -220,6 +220,8 @@ +
+
Export +
+ + + + + +
+ A file with this name is already on disk. Overwrite? +
+
+ + + + Cancel + Overwrite + + +
+ +
+
+ diff --git a/jdaviz/configs/default/plugins/export/tests/test_export.py b/jdaviz/configs/default/plugins/export/tests/test_export.py index ffc217b1fb..a321befa7c 100644 --- a/jdaviz/configs/default/plugins/export/tests/test_export.py +++ b/jdaviz/configs/default/plugins/export/tests/test_export.py @@ -1,24 +1,25 @@ -import numpy as np import os -import pytest import re +import numpy as np +import pytest +from astropy import units as u from astropy.io import fits from astropy.nddata import NDData -import astropy.units as u from glue.core.edit_subset_mode import AndMode, NewMode from glue.core.roi import CircularROI, XRangeROI from regions import Regions, CircleSkyRegion from specutils import Spectrum1D -class TestExportSubsets(): +@pytest.mark.usefixtures('_jail') +class TestExportSubsets: """ Tests for exporting subsets. Currently limited to non-composite spatial subsets. """ - def test_basic_export_subsets_imviz(self, tmp_path, imviz_helper): + def test_basic_export_subsets_imviz(self, imviz_helper): data = NDData(np.ones((500, 500)) * u.nJy) @@ -94,7 +95,7 @@ def test_not_implemented(self, cubeviz_helper, spectral_cube_wcs): cubeviz_helper.app.get_viewer("spectrum-viewer").apply_roi(XRangeROI(5, 15.5)) assert 'Subset 2' not in export_plugin.subset.choices - def test_export_subsets_wcs(self, tmp_path, imviz_helper, spectral_cube_wcs): + def test_export_subsets_wcs(self, imviz_helper, spectral_cube_wcs): # using cube WCS instead of 2d imaging wcs for consistancy with # cubeviz test. accessing just the spatial part of this. @@ -125,7 +126,7 @@ def test_export_subsets_wcs(self, tmp_path, imviz_helper, spectral_cube_wcs): assert isinstance(Regions.read('sky_region.reg')[0], CircleSkyRegion) - def test_basic_export_subsets_cubeviz(self, tmp_path, cubeviz_helper, spectral_cube_wcs): + def test_basic_export_subsets_cubeviz(self, cubeviz_helper, spectral_cube_wcs): data = Spectrum1D(flux=np.ones((128, 128, 256)) * u.nJy, wcs=spectral_cube_wcs) @@ -168,6 +169,25 @@ def test_basic_export_subsets_cubeviz(self, tmp_path, cubeviz_helper, spectral_c export_plugin.export() assert os.path.isfile('test.reg') + # Overwrite not enable, so no-op with warning. + export_plugin.export(raise_error_for_overwrite=False) + assert export_plugin.overwrite_warn + + # Changing filename should clear warning. + old_filename = export_plugin.filename + export_plugin.filename = "foo" + assert not export_plugin.overwrite_warn + export_plugin.filename = old_filename + + # Overwrite not enable, but with exception from API by default. + with pytest.raises(FileExistsError, match=".* exists but overwrite=False"): + export_plugin.export() + assert export_plugin.overwrite_warn + + # User forces overwrite. + export_plugin.export(overwrite=True) + assert not export_plugin.overwrite_warn + # test that invalid file extension raises an error with pytest.raises(ValueError, match=re.escape("x not one of ['fits', 'reg'], reverting selection to reg")): # noqa @@ -183,6 +203,7 @@ def test_basic_export_subsets_cubeviz(self, tmp_path, cubeviz_helper, spectral_c export_plugin.export() +@pytest.mark.usefixtures('_jail') def test_export_data(cubeviz_helper, spectrum1d_cube): cubeviz_helper.load_data(spectrum1d_cube, data_label='test') mm = cubeviz_helper.plugins["Moment Maps"] @@ -224,9 +245,9 @@ def test_disable_export_for_unsupported_units(specviz2d_helper): assert ep.data_invalid_msg == "Export Disabled: The unit DN / s could not be saved in native FITS format." # noqa -class TestExportPluginPlots(): +class TestExportPluginPlots: - def test_basic_export_plugin_plots(tmp_path, imviz_helper): + def test_basic_export_plugin_plots(self, imviz_helper): """ Test basic funcionality of exporting plugin plots from the export plugin. Tests on the 'Plot Options: stretch_hist' @@ -259,7 +280,7 @@ def test_basic_export_plugin_plots(tmp_path, imviz_helper): assert len(available_plots) == 1 assert available_plots[0] == 'Plot Options: stretch_hist' - def test_ap_phot_plot_export(tmp_path, imviz_helper): + def test_ap_phot_plot_export(self, imviz_helper): """ Test export functionality for plot from the aperture photometry diff --git a/jdaviz/conftest.py b/jdaviz/conftest.py index 56f7a73d0a..d4be86ca84 100644 --- a/jdaviz/conftest.py +++ b/jdaviz/conftest.py @@ -3,6 +3,7 @@ # get picked up when running the tests inside an interpreter using # packagename.test +import os import warnings import numpy as np @@ -334,6 +335,18 @@ def roman_imagemodel(): return create_wfi_image_model((20, 10)) +# Copied over from https://github.com/spacetelescope/ci_watson +@pytest.fixture(scope='function') +def _jail(tmp_path): + """Perform test in a pristine temporary working directory.""" + old_dir = os.getcwd() + os.chdir(tmp_path) + try: + yield str(tmp_path) + finally: + os.chdir(old_dir) + + try: from pytest_astropy_header.display import PYTEST_HEADER_MODULES, TESTED_VERSIONS except ImportError: