@@ -33,10 +33,10 @@ class DeFoGMolecularGenerator(BaseMolecularGenerator):
3333 num_layer : int, default=6
3434 Number of transformer layers
3535 hidden_mlp_dims : Dict[str, int], default={'X': 256, 'E': 128, 'y': 128}
36- Hidden dimensions for MLP layers in X, E, and y components
36+ Hidden dimensions for MLP layers in X (node dim), E (edge dim), and y (property dim) components
3737 hidden_dims : Dict[str, Any], default={'dx': 256, 'de': 64, 'dy': 64, 'n_head': 8, 'dim_ffX': 256, 'dim_ffE': 128, 'dim_ffy': 128}
3838 Hidden dimensions for transformer components including attention heads and feed-forward layers
39- Keys: 'dx', 'de', 'dy', 'n_head', 'dim_ffX', 'dim_ffE', 'dim_ffy'
39+ Keys: 'dx' (node dim), 'de' (edge dim), 'dy' (property dim), 'n_head' (number of attention heads), 'dim_ffX' (feed-forward dim for node features), 'dim_ffE' (feed-forward dim for edge features), 'dim_ffy' (feed-forward dim for property features)
4040 transition : str, default='marginal'
4141 Transition type for flow matching.
4242 Options: 'marginal', 'absorbing', 'uniform', 'absorbfirst', 'argmax', 'edge_marginal', 'node_marginal'
@@ -239,8 +239,6 @@ def _setup_optimizers(self) -> Tuple[torch.optim.Optimizer, Optional[Any]]:
239239
240240 return optimizer , scheduler
241241
242-
243-
244242 def _convert_to_pytorch_data (self , X , y = None ):
245243 """Convert numpy arrays to PyTorch Geometric data format."""
246244 if self .verbose :
@@ -256,7 +254,9 @@ def _convert_to_pytorch_data(self, X, y=None):
256254 g = Data ()
257255
258256 node_type = torch .from_numpy (graph ['node_feat' ][:, 0 ] - 1 )
259-
257+ if node_type .numel () <= 1 :
258+ continue
259+
260260 valid_mask = node_type >= 0
261261 if not valid_mask .all ():
262262 # Get valid nodes and adjust edge indices
@@ -398,24 +398,30 @@ def fit(self, X_train: List[str], y_train: Optional[Union[List, np.ndarray]] = N
398398 train_dataset = self ._convert_to_pytorch_data (X_train , y_train )
399399 train_loader = DataLoader (train_dataset , batch_size = self .batch_size , shuffle = True )
400400
401+ # Calculate total steps for global progress bar
402+ total_steps = self .epochs * len (train_loader )
403+ global_progress = tqdm (total = total_steps , desc = "Training Progress" , leave = True ) if self .verbose else None
404+
401405 self .fitting_loss = []
402406 for epoch in range (self .epochs ):
403- train_losses = self ._train_epoch (train_loader , optimizer , epoch )
407+ train_losses = self ._train_epoch (train_loader , optimizer , epoch , global_progress )
404408 avg_loss = np .mean (train_losses )
405409 self .fitting_loss .append (avg_loss )
406410 if scheduler :
407411 scheduler .step (avg_loss )
408412
413+ if global_progress :
414+ global_progress .close ()
415+
409416 self .is_fitted_ = True
410417 return self
411418
412- def _train_epoch (self , train_loader , optimizer , epoch ):
419+ def _train_epoch (self , train_loader , optimizer , epoch , global_progress = None ):
413420 self .model .train ()
414421 losses = []
415- iterator = tqdm (train_loader , desc = f"Epoch { epoch } " , leave = False ) if self .verbose else train_loader
416422
417423 active_index = self .dataset_info ["active_index" ]
418- for batched_data in iterator :
424+ for batched_data in train_loader :
419425 batched_data = batched_data .to (self .device )
420426 optimizer .zero_grad ()
421427
@@ -467,8 +473,17 @@ def _train_epoch(self, train_loader, optimizer, epoch):
467473 optimizer .step ()
468474
469475 losses .append (loss .item ())
470- if self .verbose :
471- iterator .set_postfix ({"Loss" : f"{ loss .item ():.4f} " })
476+
477+ # Update global progress bar
478+ if global_progress :
479+ global_progress .set_postfix ({
480+ "Epoch" : f"{ epoch + 1 } " ,
481+ "Loss" : f"{ loss .item ():.4f} " ,
482+ "Loss_X" : f"{ masked_loss_X .item ():.4f} " ,
483+ "Loss_E" : f"{ masked_loss_E .item ():.4f} " ,
484+ "Loss_y" : f"{ loss_y .item ():.4f} "
485+ })
486+ global_progress .update (1 )
472487
473488 return losses
474489
0 commit comments