diff --git a/pytest_pyvista/pytest_pyvista.py b/pytest_pyvista/pytest_pyvista.py index 9b71bf66..eb1af2f6 100644 --- a/pytest_pyvista/pytest_pyvista.py +++ b/pytest_pyvista/pytest_pyvista.py @@ -491,7 +491,15 @@ def func_show(*args, **kwargs) -> None: # noqa: ANN002, ANN003 if user_callback is None: # special case encountered when using the `plot` property of pyvista objects user_callback = lambda *a: ... # noqa: ARG005, E731 - kwargs[key] = _ChainedCallbacks(user_callback, verify_image_cache) + # Set kwargs to None in order to get the callback from the + # global theme one which is patched by the current callback. + # This is done to make sure that the weak ref `_before_close_callback` is not dead + # when using `auto_close=False` on the plotter + # See https://github.com/pyvista/pytest-pyvista/issues/172 + callback = _ChainedCallbacks(user_callback, verify_image_cache) + kwargs[key] = None + + monkeypatch.setattr(pyvista.global_theme, "before_close_callback", callback) return old_show(*args, **kwargs) diff --git a/tests/test_pyvista.py b/tests/test_pyvista.py index 63070c85..1e31e08d 100644 --- a/tests/test_pyvista.py +++ b/tests/test_pyvista.py @@ -855,3 +855,29 @@ def test_imcache(verify_image_cache, tmp_path: Path, pytestconfig: pytest.Config args = ["--image_cache_dir", new_dir, "--failed_image_dir", "failed"] result = pytester.runpytest(*args) result.assert_outcomes(passed=1) + + +def test_auto_close(pytester: pytest.Pytester) -> None: + """Test when using auto_close=False.""" + pytester.makepyfile( + """ + import pyvista as pv + pv.OFF_SCREEN = True + def test_auto_close(verify_image_cache): + verify_image_cache.skip = True + pl = pv.Plotter() + pl.show(auto_close=False) + assert pl._before_close_callback is not None + pl.close() + + def test_auto_close_2(verify_image_cache, mocker): + verify_image_cache.skip = True + plotter = pv.Plotter() + plotter.show(auto_close=False, before_close_callback=(m:=mocker.MagicMock())) + plotter.close() + m.assert_called_once_with(plotter) + """ + ) + + result = pytester.runpytest() + result.assert_outcomes(passed=2)