Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmdstanpy/cmdstan_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ def __init__(
tol_param: float | None = None,
history_size: int | None = None,
num_psis_draws: int | None = None,
num_paths: int | None = None,
num_paths: int = 4,
max_lbfgs_iters: int | None = None,
num_draws: int | None = None,
num_elbo_draws: int | None = None,
Expand Down
10 changes: 9 additions & 1 deletion cmdstanpy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1328,7 +1328,7 @@ def pathfinder(
tol_rel_grad: float | None = None,
tol_param: float | None = None,
history_size: int | None = None,
num_paths: int | None = None,
num_paths: int = 4,
max_lbfgs_iters: int | None = None,
draws: int | None = None,
num_single_draws: int | None = None,
Expand All @@ -1352,6 +1352,7 @@ def pathfinder(
time_fmt: str = "%Y%m%d%H%M%S",
timeout: float | None = None,
num_threads: int | None = None,
save_single_paths: bool = False,
) -> CmdStanPathfinder:
"""
Run CmdStan's Pathfinder variational inference algorithm.
Expand Down Expand Up @@ -1458,6 +1459,12 @@ def pathfinder(
A number other than ``1`` requires the model to have been compiled
with STAN_THREADS=True.

:param save_single_paths: Save draws and ELBO evaluations from
individual Pathfinder runs. Draws are saved to CSV files and ELBO
evaluations are saved to JSON files. If ``True``, file paths can be
accessed via ``CmdStanPathfinder.runset.single_path_csv_files`` and
``CmdStanPathfinder.runset.single_path_json_files``.

:return: A :class:`CmdStanPathfinder` object

References
Expand Down Expand Up @@ -1506,6 +1513,7 @@ def pathfinder(
num_elbo_draws=num_elbo_draws,
psis_resample=psis_resample,
calculate_lp=calculate_lp,
save_single_paths=save_single_paths,
)

with temp_single_json(data) as _data, temp_inits(inits) as _inits:
Expand Down
39 changes: 38 additions & 1 deletion cmdstanpy/stanfit/runset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from time import time

from cmdstanpy import _TMPDIR
from cmdstanpy.cmdstan_args import CmdStanArgs, Method
from cmdstanpy.cmdstan_args import CmdStanArgs, Method, PathfinderArgs
from cmdstanpy.utils import get_logger


Expand Down Expand Up @@ -57,6 +57,8 @@ def __init__(
self._stdout_files, self._profile_files = [], []
self._csv_files, self._diagnostic_files = [], []
self._config_files = []
self._single_path_csv_files: list[str] = []
self._single_path_json_files: list[str] = []

# per-process output files
if one_process_per_chain and chains > 1:
Expand Down Expand Up @@ -101,6 +103,9 @@ def __init__(
for id in self._chain_ids
]

if args.method == Method.PATHFINDER:
self.populate_pathfinder_single_path_files()

def __repr__(self) -> str:
lines = [
f"RunSet: chains={self._chains}, chain_ids={self._chain_ids}, "
Expand Down Expand Up @@ -222,6 +227,18 @@ def profile_files(self) -> list[str]:
"""List of paths to CmdStan profiler files."""
return self._profile_files

@property
def single_path_csv_files(self) -> list[str]:
    """
    Paths of the CSV files written for the individual Pathfinder runs.

    The list stays empty unless the run used the Pathfinder method with
    ``save_single_paths=True``.
    """
    return self._single_path_csv_files

@property
def single_path_json_files(self) -> list[str]:
    """
    Paths of the ELBO JSON files written for the individual Pathfinder runs.

    The list stays empty unless the run used the Pathfinder method with
    ``save_single_paths=True``.
    """
    return self._single_path_json_files

def gen_file_name(
self, suffix: str, *, extra: str = "", id: int | None = None
) -> str:
Expand Down Expand Up @@ -317,3 +334,23 @@ def raise_for_timeouts(self) -> None:
f"{sum(self._timeout_flags)} of {self.num_procs} "
"processes timed out"
)

def populate_pathfinder_single_path_files(self) -> None:
    """
    Fill in the single-path CSV and JSON output file names for Pathfinder.

    Does nothing unless the method arguments are :class:`PathfinderArgs`
    with ``save_single_paths`` set. With more than one path, one CSV and
    one JSON name is generated per path via ``gen_file_name`` using
    ``extra="path"`` and ids ``1..num_paths``. Otherwise a single CSV and
    a single JSON name is generated without the extra tag or id.
    """
    method_args = self._args.method_args
    if not isinstance(method_args, PathfinderArgs):
        return
    if not method_args.save_single_paths:
        return

    num_paths = method_args.num_paths
    if num_paths > 1:
        # One file of each kind per path, numbered from 1.
        path_ids = range(1, num_paths + 1)
        self._single_path_csv_files = [
            self.gen_file_name(".csv", extra="path", id=path_id)
            for path_id in path_ids
        ]
        self._single_path_json_files = [
            self.gen_file_name(".json", extra="path", id=path_id)
            for path_id in path_ids
        ]
        return

    # Single path: the file names carry no "path" tag or id.
    self._single_path_csv_files = [self.gen_file_name(".csv")]
    self._single_path_json_files = [self.gen_file_name(".json")]
22 changes: 22 additions & 0 deletions test/test_pathfinder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import contextlib
import os
from io import StringIO
from pathlib import Path

Expand Down Expand Up @@ -193,3 +194,24 @@ def test_pathfinder_threads() -> None:
)
pathfinder = bern_model.pathfinder(data=jdata, num_threads=4)
assert pathfinder.draws().shape == (1000, 4)


def test_pathfinder_single_path_output() -> None:
    """
    Run Pathfinder with ``save_single_paths=True`` for 4 paths and for
    1 path, and check that the runset lists one existing CSV and one
    existing JSON file per path.
    """
    stan = DATAFILES_PATH / 'bernoulli.stan'
    bern_model = cmdstanpy.CmdStanModel(stan_file=stan)
    jdata = str(DATAFILES_PATH / 'bernoulli.data.json')

    for expected_count in (4, 1):
        fit = bern_model.pathfinder(
            data=jdata, num_paths=expected_count, save_single_paths=True
        )
        csv_files = fit.runset.single_path_csv_files
        json_files = fit.runset.single_path_json_files

        assert len(csv_files) == expected_count
        assert len(json_files) == expected_count

        # Collect the missing files so a failure names them.
        missing_csv = [f for f in csv_files if not os.path.exists(f)]
        assert not missing_csv
        missing_json = [f for f in json_files if not os.path.exists(f)]
        assert not missing_json
57 changes: 56 additions & 1 deletion test/test_runset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os

from cmdstanpy import _TMPDIR
from cmdstanpy.cmdstan_args import CmdStanArgs, SamplerArgs
from cmdstanpy.cmdstan_args import CmdStanArgs, PathfinderArgs, SamplerArgs
from cmdstanpy.stanfit import RunSet
from cmdstanpy.utils import EXTENSION

Expand Down Expand Up @@ -299,3 +299,58 @@ def test_chain_ids() -> None:
assert '_11.csv' in runset._csv_files[0]
assert 'id=14' in runset.cmd(3)
assert '_14.csv' in runset._csv_files[3]


def test_output_filenames_pathfinder_single_paths() -> None:
    """
    Check the single-path file names a RunSet derives from PathfinderArgs:
    four ``_path_<id>`` names for 4 paths, one plain name for 1 path, and
    none when ``save_single_paths`` is off.
    """
    exe = os.path.join(DATAFILES_PATH, 'bernoulli' + EXTENSION)
    jdata = os.path.join(DATAFILES_PATH, 'bernoulli.data.json')

    def build_runset(num_paths: int, save_single_paths: bool) -> RunSet:
        # Only the Pathfinder arguments vary between the scenarios below.
        pathfinder_args = PathfinderArgs(
            num_paths=num_paths, save_single_paths=save_single_paths
        )
        cmdstan_args = CmdStanArgs(
            model_name='bernoulli',
            model_exe=exe,
            chain_ids=[1],
            data=jdata,
            method_args=pathfinder_args,
        )
        return RunSet(args=cmdstan_args)

    # Multiple paths: one numbered file of each kind per path.
    runset = build_runset(num_paths=4, save_single_paths=True)
    assert len(runset.single_path_csv_files) == 4
    assert len(runset.single_path_json_files) == 4
    for path_id, csv_file in enumerate(runset.single_path_csv_files, start=1):
        assert csv_file.endswith(f"_path_{path_id}.csv")
    for path_id, json_file in enumerate(
        runset.single_path_json_files, start=1
    ):
        assert json_file.endswith(f"_path_{path_id}.json")

    # Single path: exactly one file of each kind.
    runset = build_runset(num_paths=1, save_single_paths=True)
    assert len(runset.single_path_csv_files) == 1
    assert len(runset.single_path_json_files) == 1
    assert runset.single_path_csv_files[0].endswith(".csv")
    assert runset.single_path_json_files[0].endswith(".json")

    # Option disabled: nothing is recorded.
    runset = build_runset(num_paths=1, save_single_paths=False)
    assert len(runset.single_path_csv_files) == 0
    assert len(runset.single_path_json_files) == 0