Skip to content

Commit 164b43c

Browse files
committed
Update defog with global pbar and data preprocessing
1 parent 28e2692 commit 164b43c

2 files changed

Lines changed: 72 additions & 24 deletions

File tree

tests/generator/defog.py

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from torch_molecule import DeFoGMolecularGenerator
66

7-
EPOCHS = 2
7+
EPOCHS = 10
88
BATCH_SIZE = 32
99

1010
def test_defog_generator():
@@ -20,32 +20,43 @@ def test_defog_generator():
2020
'CC1=CC=C(C=C1)C2=CC(=NN2C3=CC=C(C=C3)S(=O)(=O)N)C(F)(F)F'
2121
]
2222
smiles_list = smiles_list * 25 # Create 100 molecules for training
23-
properties = [0, 0, 1, 1] * 25 # Create 100 properties for training
23+
24+
# Multi-dimensional properties: each row is a molecule, each column is a property
25+
# Property values are sampled uniformly from [0, 1)
26+
np.random.seed(42) # For reproducible results
27+
properties = np.random.rand(100, 3) # 100 molecules, 3 properties each
2428

25-
# 1. Conditional Model Testing
26-
print("\n=== Testing Conditional DeFoG Model ===")
29+
# 1. Multi-Conditional Model Testing
30+
print("\n=== Testing Multi-Conditional DeFoG Model ===")
2731
conditional_model = DeFoGMolecularGenerator(
28-
task_type=['regression'],
32+
task_type=['regression', 'regression', 'regression'], # 3 regression tasks
2933
epochs=EPOCHS,
3034
batch_size=BATCH_SIZE,
3135
learning_rate=5e-4,
3236
sample_steps=10, # Fewer steps for faster testing
3337
guidance_weight=0.2,
3438
verbose=True,
3539
)
36-
print("Conditional DeFoG Model initialized successfully.")
40+
print("Multi-Conditional DeFoG Model initialized successfully.")
3741
print(f"Input dim y: {conditional_model.input_dim_y}")
3842

39-
print("\n--- Fitting conditional model ---")
43+
print("\n--- Fitting multi-conditional model ---")
4044
conditional_model.fit(smiles_list, properties)
41-
print("Conditional DeFoG Model fitting completed.")
42-
43-
print("\n--- Testing conditional generation ---")
44-
target_properties = [[0], [0], [1], [1]]
45+
print("Multi-Conditional DeFoG Model fitting completed.")
46+
47+
print("\n--- Testing multi-conditional generation ---")
48+
# Generate molecules with specific multi-dimensional properties
49+
target_properties = [
50+
[0.1, 0.2, 0.3], # Low values for all properties
51+
[0.4, 0.5, 0.6], # Medium values for all properties
52+
[0.7, 0.8, 0.9], # High values for all properties
53+
[0.9, 0.1, 0.5] # Mixed values
54+
]
4555
generated_smiles = conditional_model.generate(labels=target_properties)
46-
print(f"Conditionally generated {len(generated_smiles)} molecules.")
56+
print(f"Multi-conditionally generated {len(generated_smiles)} molecules.")
4757
assert len(generated_smiles) == len(target_properties)
4858
print("Example SMILES:", generated_smiles[:2])
59+
print("Target properties for first molecule:", target_properties[0])
4960

5061
print("\n--- Testing model saving and loading ---")
5162
save_path = "conditional_defog_test_model.pt"
@@ -70,7 +81,29 @@ def test_defog_generator():
7081
os.remove(save_path)
7182
print(f"Cleaned up {save_path}")
7283

73-
# 2. Unconditional Model Testing
84+
# 2. Single-property conditional testing (backwards compatibility)
85+
print("\n=== Testing Single-Property Conditional DeFoG Model ===")
86+
single_properties = properties[:, 0:1] # Use only first property
87+
single_conditional_model = DeFoGMolecularGenerator(
88+
task_type=['regression'], # Single regression task
89+
epochs=EPOCHS,
90+
batch_size=BATCH_SIZE,
91+
learning_rate=5e-4,
92+
sample_steps=10,
93+
guidance_weight=0.2,
94+
verbose=True,
95+
)
96+
print("Single-Property Conditional DeFoG Model initialized successfully.")
97+
98+
single_conditional_model.fit(smiles_list, single_properties)
99+
print("Single-Property DeFoG Model fitting completed.")
100+
101+
single_target_properties = [[0.2], [0.5], [0.8], [0.1]]
102+
single_generated_smiles = single_conditional_model.generate(labels=single_target_properties)
103+
print(f"Single-conditionally generated {len(single_generated_smiles)} molecules.")
104+
assert len(single_generated_smiles) == len(single_target_properties)
105+
106+
# 3. Unconditional Model Testing
74107
print("\n=== Testing Unconditional DeFoG Model ===")
75108
unconditional_model = DeFoGMolecularGenerator(
76109
task_type=[], # Empty task_type for unconditional generation

torch_molecule/generator/defog/modeling_defog.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,10 @@ class DeFoGMolecularGenerator(BaseMolecularGenerator):
3333
num_layer : int, default=6
3434
Number of transformer layers
3535
hidden_mlp_dims : Dict[str, int], default={'X': 256, 'E': 128, 'y': 128}
36-
Hidden dimensions for MLP layers in X, E, and y components
36+
Hidden dimensions of the MLP layers for the X (node), E (edge), and y (property) components
3737
hidden_dims : Dict[str, Any], default={'dx': 256, 'de': 64, 'dy': 64, 'n_head': 8, 'dim_ffX': 256, 'dim_ffE': 128, 'dim_ffy': 128}
3838
Hidden dimensions for transformer components including attention heads and feed-forward layers
39-
Keys: 'dx', 'de', 'dy', 'n_head', 'dim_ffX', 'dim_ffE', 'dim_ffy'
39+
Keys: 'dx' (node dim), 'de' (edge dim), 'dy' (property dim), 'n_head' (number of attention heads), 'dim_ffX' (feed-forward dim for node features), 'dim_ffE' (feed-forward dim for edge features), 'dim_ffy' (feed-forward dim for property features)
4040
transition : str, default='marginal'
4141
Transition type for flow matching.
4242
Options: 'marginal', 'absorbing', 'uniform', 'absorbfirst', 'argmax', 'edge_marginal', 'node_marginal'
@@ -239,8 +239,6 @@ def _setup_optimizers(self) -> Tuple[torch.optim.Optimizer, Optional[Any]]:
239239

240240
return optimizer, scheduler
241241

242-
243-
244242
def _convert_to_pytorch_data(self, X, y=None):
245243
"""Convert numpy arrays to PyTorch Geometric data format."""
246244
if self.verbose:
@@ -256,7 +254,9 @@ def _convert_to_pytorch_data(self, X, y=None):
256254
g = Data()
257255

258256
node_type = torch.from_numpy(graph['node_feat'][:, 0] - 1)
259-
257+
if node_type.numel() <= 1:
258+
continue
259+
260260
valid_mask = node_type >= 0
261261
if not valid_mask.all():
262262
# Get valid nodes and adjust edge indices
@@ -398,24 +398,30 @@ def fit(self, X_train: List[str], y_train: Optional[Union[List, np.ndarray]] = N
398398
train_dataset = self._convert_to_pytorch_data(X_train, y_train)
399399
train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
400400

401+
# Calculate total steps for global progress bar
402+
total_steps = self.epochs * len(train_loader)
403+
global_progress = tqdm(total=total_steps, desc="Training Progress", leave=True) if self.verbose else None
404+
401405
self.fitting_loss = []
402406
for epoch in range(self.epochs):
403-
train_losses = self._train_epoch(train_loader, optimizer, epoch)
407+
train_losses = self._train_epoch(train_loader, optimizer, epoch, global_progress)
404408
avg_loss = np.mean(train_losses)
405409
self.fitting_loss.append(avg_loss)
406410
if scheduler:
407411
scheduler.step(avg_loss)
408412

413+
if global_progress:
414+
global_progress.close()
415+
409416
self.is_fitted_ = True
410417
return self
411418

412-
def _train_epoch(self, train_loader, optimizer, epoch):
419+
def _train_epoch(self, train_loader, optimizer, epoch, global_progress=None):
413420
self.model.train()
414421
losses = []
415-
iterator = tqdm(train_loader, desc=f"Epoch {epoch}", leave=False) if self.verbose else train_loader
416422

417423
active_index = self.dataset_info["active_index"]
418-
for batched_data in iterator:
424+
for batched_data in train_loader:
419425
batched_data = batched_data.to(self.device)
420426
optimizer.zero_grad()
421427

@@ -467,8 +473,17 @@ def _train_epoch(self, train_loader, optimizer, epoch):
467473
optimizer.step()
468474

469475
losses.append(loss.item())
470-
if self.verbose:
471-
iterator.set_postfix({"Loss": f"{loss.item():.4f}"})
476+
477+
# Update global progress bar
478+
if global_progress:
479+
global_progress.set_postfix({
480+
"Epoch": f"{epoch+1}",
481+
"Loss": f"{loss.item():.4f}",
482+
"Loss_X": f"{masked_loss_X.item():.4f}",
483+
"Loss_E": f"{masked_loss_E.item():.4f}",
484+
"Loss_y": f"{loss_y.item():.4f}"
485+
})
486+
global_progress.update(1)
472487

473488
return losses
474489

0 commit comments

Comments
 (0)