diff --git a/orbit/estimators/stan_estimator.py b/orbit/estimators/stan_estimator.py index c6b2d210..3d7c8d28 100644 --- a/orbit/estimators/stan_estimator.py +++ b/orbit/estimators/stan_estimator.py @@ -2,13 +2,13 @@ import multiprocessing from abc import abstractmethod from copy import copy -from sys import platform, version_info from ..exceptions import EstimatorException from ..utils.general import update_dict from ..utils.logger import get_logger from ..utils.set_cmdstan_path import set_cmdstan_path from ..utils.stan import get_compiled_stan_model +from ..utils.cmdstanpy_compat import patch_tqdm_progress_hook from .base_estimator import BaseEstimator logger = get_logger("orbit") @@ -16,6 +16,9 @@ # Make sure models are using the right cmdstan folder set_cmdstan_path() +# Apply cmdstanpy compatibility patches +patch_tqdm_progress_hook() + class StanEstimator(BaseEstimator): """Abstract StanEstimator with shared args for all StanEstimator child classes diff --git a/orbit/utils/cmdstanpy_compat.py b/orbit/utils/cmdstanpy_compat.py new file mode 100644 index 00000000..daeed783 --- /dev/null +++ b/orbit/utils/cmdstanpy_compat.py @@ -0,0 +1,116 @@ +""" +Compatibility utilities for cmdstanpy integration. + +This module contains patches and workarounds for cmdstanpy compatibility issues. +""" + +import os +from typing import Dict, List, Optional, Callable + +from .logger import get_logger + +logger = get_logger("orbit") + + +def patch_tqdm_progress_hook(): + """ + Patch cmdstanpy progress hook to handle TQDM_DISABLE safely. + + When TQDM_DISABLE=1 is set, tqdm creates disabled progress bar objects + that don't have the 'postfix' attribute. cmdstanpy assumes this attribute + exists and tries to access it, causing AttributeError. + + This patch adds safe access checks to prevent the crash. + + See: https://github.com/uber/orbit/issues/887 + """ + # Only patch if TQDM_DISABLE is set + if os.environ.get("TQDM_DISABLE") != "1": + return + + try: + import cmdstanpy.model + import re + + # Store reference to original method to avoid patching multiple times + if hasattr(cmdstanpy.model.CmdStanModel, "_orbit_tqdm_patched"): + return + + original_hook = getattr( + cmdstanpy.model.CmdStanModel, "_wrap_sampler_progress_hook", None + ) + if original_hook is None: + return + + @staticmethod + def safe_wrap_sampler_progress_hook( + chain_ids: List[int], + total: int, + ) -> Optional[Callable[[str, int], None]]: + """Safe version that handles disabled tqdm progress bars.""" + try: + from tqdm import tqdm + + pat = re.compile(r"Chain \[(\d*)\] (Iteration.*)") + pbars: Dict[int, tqdm] = { + chain_id: tqdm( + total=total, + bar_format="{desc} |{bar}| {elapsed} {postfix[0][value]}", + postfix=[{"value": "Status"}], + desc=f"chain {chain_id}", + colour="yellow", + ) + for chain_id in chain_ids + } + + def progress_hook(line: str, idx: int) -> None: + if line == "Done": + for pbar in pbars.values(): + # safe postfix access + if hasattr(pbar, "postfix") and pbar.postfix: + try: + pbar.postfix[0]["value"] = "Sampling completed" + except (AttributeError, KeyError, IndexError): + pass + pbar.update(total - pbar.n) + pbar.close() + else: + match = pat.match(line) + if match: + idx = int(match.group(1)) + mline = match.group(2).strip() + elif line.startswith("Iteration"): + mline = line + idx = chain_ids[idx] + else: + return + + if idx in pbars: + if "Sampling" in mline and hasattr(pbars[idx], "colour"): + pbars[idx].colour = "blue" + pbars[idx].update(1) + + # safe postfix access + if hasattr(pbars[idx], "postfix") and pbars[idx].postfix: + try: + pbars[idx].postfix[0]["value"] = mline + except (AttributeError, KeyError, IndexError): + pass + + return progress_hook + + except Exception as e: + logger.warning( + f"Progress bar setup failed: {e}. Disabling progress bars." + ) + return None + + # apply the patch + cmdstanpy.model.CmdStanModel._wrap_sampler_progress_hook = ( + safe_wrap_sampler_progress_hook + ) + cmdstanpy.model.CmdStanModel._orbit_tqdm_patched = True + logger.debug("cmdstanpy progress hook patched for TQDM_DISABLE compatibility") + + except Exception as e: + logger.warning(f"Failed to patch cmdstanpy progress hook: {e}") diff --git a/tests/orbit/utils/test_cmdstanpy_compat.py b/tests/orbit/utils/test_cmdstanpy_compat.py new file mode 100644 index 00000000..081bbb7c --- /dev/null +++ b/tests/orbit/utils/test_cmdstanpy_compat.py @@ -0,0 +1,103 @@ +import os +import pytest +from unittest.mock import patch, MagicMock + +from orbit.utils.cmdstanpy_compat import patch_tqdm_progress_hook + + +@pytest.mark.parametrize( + "env_value", + [None, "0", "false", "true", ""], +) +def test_patch_tqdm_progress_hook_no_patch_scenarios(env_value): + """Test that patch is not applied when TQDM_DISABLE is not '1'.""" + env_dict = {"TQDM_DISABLE": env_value} if env_value is not None else {} + + with patch.dict(os.environ, env_dict, clear=False): + if env_value is None and "TQDM_DISABLE" in os.environ: + del os.environ["TQDM_DISABLE"] + + # Should return early without doing anything + patch_tqdm_progress_hook() + # Test passes if no exception raised + + +def test_patch_tqdm_progress_hook_applies_patch(): + """Test that patch is applied when TQDM_DISABLE=1.""" + with patch.dict(os.environ, {"TQDM_DISABLE": "1"}): + with patch("cmdstanpy.model") as mock_cmdstanpy_model: + mock_model = MagicMock() + mock_model._wrap_sampler_progress_hook = MagicMock() + # Ensure not already patched + del mock_model._orbit_tqdm_patched + mock_cmdstanpy_model.CmdStanModel = mock_model + + patch_tqdm_progress_hook() + + assert mock_model._orbit_tqdm_patched is True + + +def test_patch_tqdm_progress_hook_no_double_patch(): + """Test that patch is not applied multiple times.""" + with patch.dict(os.environ, {"TQDM_DISABLE": "1"}): + with patch("cmdstanpy.model") as mock_cmdstanpy_model: + mock_model = MagicMock() + mock_model._orbit_tqdm_patched = True # Already patched + original_hook = MagicMock() + mock_model._wrap_sampler_progress_hook = original_hook + mock_cmdstanpy_model.CmdStanModel = mock_model + + patch_tqdm_progress_hook() + + # Original hook should remain unchanged + assert mock_model._wrap_sampler_progress_hook is original_hook + + +def test_patch_tqdm_progress_hook_handles_missing_method(): + """Test graceful handling when original method doesn't exist.""" + with patch.dict(os.environ, {"TQDM_DISABLE": "1"}): + with patch("cmdstanpy.model") as mock_cmdstanpy_model: + # Simple object without the method + mock_model = type("MockModel", (), {})() + mock_cmdstanpy_model.CmdStanModel = mock_model + + # Should not raise exception + patch_tqdm_progress_hook() + + # Should not set patched flag + assert not hasattr(mock_model, "_orbit_tqdm_patched") + + +def test_patch_tqdm_progress_hook_handles_import_error(): + """Test graceful handling of import errors.""" + with patch.dict(os.environ, {"TQDM_DISABLE": "1"}): + with patch("cmdstanpy.model", side_effect=ImportError("No module")): + # Should not raise exception + patch_tqdm_progress_hook() + + +def test_integration_with_actual_cmdstanpy(): + """Integration test with actual cmdstanpy if available.""" + cmdstanpy = pytest.importorskip("cmdstanpy") + + with patch.dict(os.environ, {"TQDM_DISABLE": "1"}): + # Store original state + original_method = getattr( + cmdstanpy.model.CmdStanModel, "_wrap_sampler_progress_hook", None + ) + + try: + patch_tqdm_progress_hook() + + # Verify patch was applied + assert hasattr(cmdstanpy.model.CmdStanModel, "_orbit_tqdm_patched") + assert cmdstanpy.model.CmdStanModel._orbit_tqdm_patched is True + + finally: + # Clean up + if hasattr(cmdstanpy.model.CmdStanModel, "_orbit_tqdm_patched"): + delattr(cmdstanpy.model.CmdStanModel, "_orbit_tqdm_patched") + if original_method is not None: + cmdstanpy.model.CmdStanModel._wrap_sampler_progress_hook = ( + original_method + )