3232
3333class PGBART (ArrayStepShared ):
3434 """
35- Particle Gibss BART sampling step
35+ Particle Gibss BART sampling step.
3636
3737 Parameters
3838 ----------
@@ -208,9 +208,7 @@ def astep(self, _):
208208 return self .sum_trees , [stats ]
209209
210210 def normalize (self , particles ):
211- """
212- Use logsumexp trick to get W_t and softmax to get normalized_weights
213- """
211+ """Use logsumexp trick to get W_t and softmax to get normalized_weights."""
214212 log_w = np .array ([p .log_weight for p in particles ])
215213 log_w_max = log_w .max ()
216214 log_w_ = log_w - log_w_max
@@ -224,9 +222,7 @@ def normalize(self, particles):
224222 return W_t , normalized_weights
225223
226224 def init_particles (self , tree_id : int ) -> np .ndarray :
227- """
228- Initialize particles
229- """
225+ """Initialize particles."""
230226 p = self .all_particles [tree_id ]
231227 particles = [p ]
232228 particles .append (copy (p ))
@@ -238,7 +234,7 @@ def init_particles(self, tree_id: int) -> np.ndarray:
238234
239235 def update_weight (self , particle , old = False ):
240236 """
241- Update the weight of a particle
237+ Update the weight of a particle.
242238
243239 Since the prior is used as the proposal,the weights are updated additively as the ratio of
244240 the new and old log-likelihoods.
@@ -253,19 +249,15 @@ def update_weight(self, particle, old=False):
253249
254250 @staticmethod
255251 def competence (var , has_grad ):
256- """
257- PGBART is only suitable for BART distributions
258- """
252+ """PGBART is only suitable for BART distributions."""
259253 dist = getattr (var .owner , "op" , None )
260254 if isinstance (dist , BARTRV ):
261255 return Competence .IDEAL
262256 return Competence .INCOMPATIBLE
263257
264258
265259class ParticleTree :
266- """
267- Particle tree
268- """
260+ """Particle tree."""
269261
270262 def __init__ (self , tree ):
271263 self .tree = tree .copy () # keeps the tree that we care at the moment
@@ -340,6 +332,7 @@ def rvs(self):
340332def compute_prior_probability (alpha ):
341333 """
342334 Calculate the probability of the node being a LeafNode (1 - p(being SplitNode)).
335+
343336 Taken from equation 19 in [Rockova2018].
344337
345338 Parameters
@@ -463,7 +456,7 @@ def get_new_idx_data_points(split_value, idx_data_points, selected_predictor, X)
463456
464457
465458def draw_leaf_value (Y_mu_pred , X_mu , mean , m , normal , mu_std ):
466- """Draw Gaussian distributed leaf values"""
459+ """Draw Gaussian distributed leaf values. """
467460 if Y_mu_pred .size == 0 :
468461 return 0
469462 else :
@@ -504,9 +497,7 @@ def discrete_uniform_sampler(upper_value):
504497
505498
506499class NormalSampler :
507- """
508- Cache samples from a standard normal distribution
509- """
500+ """Cache samples from a standard normal distribution."""
510501
511502 def __init__ (self ):
512503 self .size = 1000
0 commit comments