Commit 5797f87

update readme, add disjoint axioms, make confidence optional
1 parent 8f61ff7 commit 5797f87

File tree

6 files changed: +549 −31 lines changed

README.md

Lines changed: 58 additions & 30 deletions
@@ -1,5 +1,5 @@
 # python-chebifier
-An AI ensemble model for predicting chemical classes.
+An AI ensemble model for predicting chemical classes in the ChEBI ontology.
 
 ## Installation
 
@@ -23,39 +23,18 @@ The package provides a command-line interface (CLI) for making predictions using
 python -m chebifier.cli --help
 
 # Make predictions using a configuration file
-python -m chebifier.cli predict example_config.yml --smiles "CC(=O)OC1=CC=CC=C1C(=O)O" "C1=CC=C(C=C1)C(=O)O"
+python -m chebifier.cli predict configs/example_config.yml --smiles "CC(=O)OC1=CC=CC=C1C(=O)O" "C1=CC=C(C=C1)C(=O)O"
 
 # Make predictions using SMILES from a file
-python -m chebifier.cli predict example_config.yml --smiles-file smiles.txt
+python -m chebifier.cli predict configs/example_config.yml --smiles-file smiles.txt
 ```
 
 ### Configuration File
 
-The CLI requires a YAML configuration file that defines the ensemble model. Here's an example:
-
-```yaml
-# Example configuration file for Chebifier ensemble model
-
-# Each key in the top-level dictionary is a model name
-model1:
-  # Required: type of model (must be one of the keys in MODEL_TYPES)
-  type: electra
-  # Required: name of the model
-  model_name: electra_model1
-  # Required: path to the checkpoint file
-  ckpt_path: /path/to/checkpoint1.ckpt
-  # Required: path to the target labels file
-  target_labels_path: /path/to/target_labels1.txt
-  # Optional: batch size for predictions (default is likely defined in the model)
-  batch_size: 32
-
-model2:
-  type: electra
-  model_name: electra_model2
-  ckpt_path: /path/to/checkpoint2.ckpt
-  target_labels_path: /path/to/target_labels2.txt
-  batch_size: 64
-```
+The CLI requires a YAML configuration file that defines the ensemble model. An example can be found in `configs/example_config.yml`.
+
+The models and other required files are trained / generated by our [chebai](https://github.com/ChEB-AI/python-chebai) package.
+Examples for models can be found on [kaggle](https://www.kaggle.com/datasets/sfluegel/chebai).
 
 ### Python API
 
@@ -77,10 +56,59 @@ smiles_list = ["CC(=O)OC1=CC=CC=C1C(=O)O", "C1=CC=C(C=C1)C(=O)O"]
 predictions = ensemble.predict_smiles_list(smiles_list)
 
 # Print results
-for smile, prediction in zip(smiles_list, predictions):
-    print(f"SMILES: {smile}")
+for smiles, prediction in zip(smiles_list, predictions):
+    print(f"SMILES: {smiles}")
     if prediction:
         print(f"Predicted classes: {prediction}")
     else:
         print("No predictions")
 ```
+
+### The ensemble
+
+Given a sample (i.e., a SMILES string) and models $m_1, m_2, \ldots, m_n$, the ensemble works as follows:
+1. Get predictions from each model $m_i$ for the sample.
+2. For each class $c$, aggregate the predictions $p_c^{m_i}$ from all models that made a prediction for that class.
+   The aggregation happens separately for all positive predictions (i.e., $p_c^{m_i} \geq 0.5$) and all negative predictions
+   ($p_c^{m_i} < 0.5$). If the aggregated value is larger for the positive predictions than for the negative predictions,
+   the ensemble makes a positive prediction for class $c$:
+
+$$
+\text{ensemble}(c) = \begin{cases}
+1 & \text{if } \sum_{i: p_c^{m_i} \geq 0.5} [\text{confidence}_c^{m_i} \cdot \text{model\_weight}_{m_i} \cdot \text{trust}_c^{m_i}] > \sum_{i: p_c^{m_i} < 0.5} [\text{confidence}_c^{m_i} \cdot \text{model\_weight}_{m_i} \cdot \text{trust}_c^{m_i}] \\
+0 & \text{otherwise}
+\end{cases}
+$$
+
+Here, confidence is the model's (self-reported) confidence in its prediction, calculated as
+$$
+\text{confidence}_c^{m_i} = 2|p_c^{m_i} - 0.5|
+$$
+For example, if a model makes a positive prediction with $p_c^{m_i} = 0.55$, the confidence is $2|0.55 - 0.5| = 0.1$:
+the model is not very confident in its prediction and is very close to switching to a negative one.
+If another model is very sure about its negative prediction with $p_c^{m_j} = 0.1$, the confidence is $2|0.1 - 0.5| = 0.8$.
+Therefore, if in doubt, we put more weight on the confident negative prediction.
+
+Confidence can be disabled via the `use_confidence` parameter of the predict method (default: True).
+
+The model weight can be set for each model in the configuration file (default: 1). It is used to favor a certain
+model independently of any given class.
+Trust is based on a model's performance on a validation set: after training, we evaluate the machine learning models
+on a validation set for each class. If the `ensemble_type` is set to `wmv-f1`, the trust is calculated as 1 + the F1 score.
+If it is set to `mv` (the default), the trust is set to 1 for all models.
+
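To make the voting rule concrete before moving on to step 3, here is a minimal Python sketch of the decision for a single class. It is an illustration, not the package's implementation (the real code in `consolidate_predictions` operates on torch tensors over all classes at once); probabilities, model weights, and trusts are assumed to be given as plain lists.

```python
def ensemble_decision(probs, model_weights, trusts, use_confidence=True):
    """Weighted vote for a single class c; probs[i] is model i's p_c."""
    pos, neg = 0.0, 0.0
    for p, weight, trust in zip(probs, model_weights, trusts):
        confidence = 2 * abs(p - 0.5) if use_confidence else 1.0
        vote = confidence * weight * trust
        if p >= 0.5:
            pos += vote  # positive prediction (p >= 0.5)
        else:
            neg += vote  # negative prediction (p < 0.5)
    return 1 if pos > neg else 0

# The worked example from above: a weakly positive model (0.55) is
# outvoted by a confidently negative one (0.1) at equal weight and trust.
print(ensemble_decision([0.55, 0.1], [1, 1], [1, 1]))  # -> 0
```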
+3. After a decision has been made for each class independently, the consistency of the predictions with regard to the
+   ChEBI hierarchy and disjointness axioms is checked (see the sketch after this list). This is done in 3 steps:
+   - (1) First, the hierarchy is corrected. For each pair of classes $A$ and $B$ where $A$ is a subclass of $B$ (following
+     the is-a relation in ChEBI), we set the ensemble prediction of $B$ to 1 if the prediction of $A$ is 1. Intuitively
+     speaking, if we have determined that a molecule belongs to a specific class (e.g., aromatic primary alcohol), it also
+     belongs to the direct and indirect superclasses (e.g., primary alcohol, aromatic alcohol, alcohol).
+   - (2) Next, we check for disjointness. This is not specified directly in ChEBI, but in an additional ChEBI module ([chebi-disjoints.owl](https://ftp.ebi.ac.uk/pub/databases/chebi/ontology/)).
+     We have extracted these disjointness axioms into a CSV file and added some more disjointness axioms ourselves (see
+     `data/disjoint_chebi.csv` and `data/disjoint_additional.csv`). If two classes $A$ and $B$ are disjoint and we predict
+     both, we select one of them randomly and set the other to 0.
+   - (3) Since the second step might have introduced new inconsistencies into the hierarchy, we repeat the first step, but
+     with a small change: for a pair of classes $A \subseteq B$ with predictions $1$ and $0$, instead of setting $B$ to $1$,
+     we now set $A$ to $0$. This has the advantage that we cannot introduce new disjointness inconsistencies and don't have
+     to repeat step 2.
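The three correction steps can be illustrated with a short Python sketch. This is a simplification, not the package's implementation: the subclass relation is assumed to be given as a child-to-parents mapping and the disjointness axioms as a list of ID pairs, which the real code derives from ChEBI and the CSV files named above.

```python
import random

def check_consistency(pred, parents, disjoint_pairs):
    """Make a {class_id: 0 or 1} prediction dict consistent (sketch)."""
    def superclasses(c):
        # All direct and indirect superclasses via the is-a relation.
        stack, seen = list(parents.get(c, [])), set()
        while stack:
            p = stack.pop()
            if p not in seen:
                seen.add(p)
                stack.extend(parents.get(p, []))
        return seen

    # (1) Upward correction: a positive class makes its superclasses positive.
    for c in [c for c, v in pred.items() if v == 1]:
        for p in superclasses(c):
            pred[p] = 1
    # (2) Disjointness: if both classes of a disjoint pair are positive,
    #     keep one at random and set the other to 0.
    for a, b in disjoint_pairs:
        if pred.get(a) == 1 and pred.get(b) == 1:
            pred[random.choice([a, b])] = 0
    # (3) Downward correction: a class whose superclass is now negative is
    #     set to negative itself, so step 2 cannot be violated again.
    for c in list(pred):
        if pred[c] == 1 and any(pred.get(p) == 0 for p in superclasses(c)):
            pred[c] = 0
    return pred
```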

chebifier/cli.py

Lines changed: 2 additions & 1 deletion

@@ -26,6 +26,7 @@ def cli():
 @click.option('--output', '-o', type=click.Path(), help='Output file to save predictions (optional)')
 @click.option('--ensemble-type', '-e', type=click.Choice(ENSEMBLES.keys()), default='mv', help='Type of ensemble to use (default: Majority Voting)')
 @click.option("--chebi-version", "-v", type=int, default=241, help="ChEBI version to use for checking consistency (default: 241)")
-def predict(config_file, smiles, smiles_file, output, ensemble_type, chebi_version):
+@click.option("--use-confidence/--no-use-confidence", "-c", default=True, help="Weight predictions based on how 'confident' a model is in its prediction (default: True)")
+def predict(config_file, smiles, smiles_file, output, ensemble_type, chebi_version, use_confidence):
     """Predict ChEBI classes for SMILES strings using an ensemble model.
 
chebifier/ensemble/base_ensemble.py

Lines changed: 4 additions & 1 deletion
@@ -78,7 +78,10 @@ def consolidate_predictions(self, predictions, classwise_weights, **kwargs):
         positive_mask = (predictions > self.positive_prediction_threshold) & valid_predictions
         negative_mask = (predictions < self.positive_prediction_threshold) & valid_predictions
 
-        confidence = 2 * torch.abs(predictions.nan_to_num() - self.positive_prediction_threshold)
+        if kwargs.get("use_confidence", True):
+            confidence = 2 * torch.abs(predictions.nan_to_num() - self.positive_prediction_threshold)
+        else:
+            confidence = torch.ones_like(predictions)
 
         # Extract positive and negative weights
         pos_weights = classwise_weights[0]  # Shape: (num_classes, num_models)
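With `use_confidence=False`, every vote collapses to `model_weight * trust` alone, i.e., plain (weighted) majority voting. A hedged usage sketch, assuming the keyword is forwarded from the predict method down to `consolidate_predictions` (only `predict_smiles_list` itself is confirmed by the README):

```python
# Confidence-weighted voting (the default):
predictions = ensemble.predict_smiles_list(smiles_list)

# Plain weighted majority voting, ignoring how close each p_c is to 0.5
# (assumes use_confidence is passed through to consolidate_predictions):
predictions = ensemble.predict_smiles_list(smiles_list, use_confidence=False)
```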

configs/example_config.yml

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+
+chemlog_peptides:
+  type: chemlog
+  model_weight: 100  # if chemlog is available, it always gets chosen
+my_resgated:
+  type: resgated
+  ckpt_path: my_resgated.ckpt  # checkpoint trained with chebai
+  target_labels_path: ../python-chebai/data/chebi_v241/ChEBI50/processed/classes.txt  # from the chebai dataset
+  molecular_properties:  # list of properties used during training
+    - chebai_graph.preprocessing.properties.AtomType
+    - chebai_graph.preprocessing.properties.NumAtomBonds
+    - chebai_graph.preprocessing.properties.AtomCharge
+    - chebai_graph.preprocessing.properties.AtomAromaticity
+    - chebai_graph.preprocessing.properties.AtomHybridization
+    - chebai_graph.preprocessing.properties.AtomNumHs
+    - chebai_graph.preprocessing.properties.BondType
+    - chebai_graph.preprocessing.properties.BondInRing
+    - chebai_graph.preprocessing.properties.BondAromaticity
+    - chebai_graph.preprocessing.properties.RDKit2DNormalized
+  #classwise_weights_path: my_resgated_metrics.json  # can be calculated with chebai.results.generate_class_properties
+
+my_electra:
+  type: electra
+  ckpt_path: my_electra.ckpt
+  target_labels_path: ../python-chebai/data/chebi_v241/ChEBI50/processed/classes.txt
+  #classwise_weights_path: my_electra_metrics.json  # can be calculated with chebai.results.generate_class_properties
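Each top-level key in this file names one ensemble member, `type` selects the model class, and the remaining keys are passed to that model. As a rough sketch of how such a file could be inspected (hypothetical code, not the package's actual loader; only `yaml` is assumed):

```python
import yaml

# Hypothetical sketch: list the ensemble members defined in the config.
with open("configs/example_config.yml") as f:
    config = yaml.safe_load(f)

for name, spec in config.items():
    model_type = spec.pop("type")         # e.g. "chemlog", "resgated", "electra"
    weight = spec.pop("model_weight", 1)  # default model_weight is 1
    print(f"{name}: type={model_type}, weight={weight}, options={sorted(spec)}")
```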

data/disjoint_additional.csv

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+16670,60466
+16670,60194
+16670,60334
+60194,60466
+60334,60466
+60194,60334
+15841,25676
+46761,47923
+46761,48030
+46761,48545
+47923,48030
+47923,48545
+48030,48545
+90799,155837
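Each row pairs two ChEBI class IDs that are declared disjoint (e.g., 16670 and 60466), so a molecule must never be predicted as an instance of both. A minimal reader for this format (plain two-column CSV, no header row):

```python
import csv

# Load disjointness axioms as (id, id) pairs of ChEBI classes.
with open("data/disjoint_additional.csv") as f:
    disjoint_pairs = [(int(a), int(b)) for a, b in csv.reader(f)]

print(disjoint_pairs[0])  # -> (16670, 60466)
```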
