diff --git a/horizont/lda.py b/horizont/lda.py index 0c688f5..5ec04c6 100644 --- a/horizont/lda.py +++ b/horizont/lda.py @@ -104,7 +104,7 @@ def __init__(self, n_topics=None, n_iter=1000, alpha=0.1, eta=0.01, random_state self.n_iter = n_iter self.alpha = alpha self.eta = eta - self.random_state = random_state + self.random_state = sklearn.utils.check_random_state(random_state) rng = sklearn.utils.check_random_state(random_state) # random numbers that are reused self._rands = rng.rand(1000) @@ -138,6 +138,11 @@ def _fit(self, X): random_state = sklearn.utils.check_random_state(self.random_state) X = np.atleast_2d(sklearn.utils.as_float_array(X)) self._initialize(X, random_state) + self._run_fitting_iterations() + return self + + def _run_fitting_iterations(self): + random_state = sklearn.utils.check_random_state(self.random_state) for it in range(self.n_iter): if it % 10 == 0: self._print_status(it) @@ -151,11 +156,27 @@ def _fit(self, X): self.theta_ = self.ndz_ + self.alpha self.theta_ /= np.sum(self.theta_, axis=1, keepdims=True) - # delete attributes no longer needed after fitting - del self.WS - del self.DS - del self.ZS + def continue_fitting(self, n_iter): + """Continues fitting the model after calling `fit` or `fit_transform. + + Parameters + ---------- + n_iter: int, number of iterations to perform + """ + if not hasattr(self, 'WS'): + raise Exception('Must call `fit` or `fit_transform` first.') + self.n_iter = n_iter + self._run_fitting_iterations() return self + + def delete_temp_vars(self): + """Delete attributes no longer needed after fitting""" + if hasattr(self, 'WS'): + del self.WS + if hasattr(self, 'DS'): + del self.DS + if hasattr(self, 'ZS'): + del self.ZS def _print_status(self, iter): ll = self.loglikelihood()