Skip to content

Commit 8fcf54b

Browse files
committed
Put molecule plot in separate function
1 parent 41dd7a3 commit 8fcf54b

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

chebai/result/molplot.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@
2020
from chebai.result.base import ResultProcessor
2121

2222

23-
class AttentionMolPlot(abc.ABC):
24-
def plot_attentions(self, smiles, attention, threshold, labels):
23+
class AttentionMolPlot:
24+
25+
def draw_attention_molecule(self, smiles, attention):
2526
pmol = self.read_smiles_with_index(smiles)
2627
rdmol = Chem.MolFromSmiles(smiles)
2728
if not rdmol:
@@ -34,26 +35,33 @@ def plot_attentions(self, smiles, attention, threshold, labels):
3435
}
3536
d = rdMolDraw2D.MolDraw2DCairo(500, 500)
3637
cmap = cm.ScalarMappable(cmap=cm.Greens)
37-
attention_colors = cmap.to_rgba(attention, norm=False)
38+
3839
aggr_attention_colors = cmap.to_rgba(
3940
np.max(attention[2:, :], axis=0), norm=False
4041
)
4142
cols = {
4243
token_to_node_map[token_index]: tuple(
4344
aggr_attention_colors[token_index].tolist()
4445
)
45-
for node, token_index in nx.get_node_attributes(pmol, "token_index").items()
46+
for node, token_index in
47+
nx.get_node_attributes(pmol, "token_index").items()
4648
}
4749
highlight_atoms = [
4850
token_to_node_map[token_index]
49-
for node, token_index in nx.get_node_attributes(pmol, "token_index").items()
51+
for node, token_index in
52+
nx.get_node_attributes(pmol, "token_index").items()
5053
]
5154
rdMolDraw2D.PrepareAndDrawMolecule(
5255
d, rdmol, highlightAtoms=highlight_atoms, highlightAtomColors=cols
5356
)
5457

5558
d.FinishDrawing()
59+
return d
5660

61+
def plot_attentions(self, smiles, attention, threshold, labels):
62+
d = self.draw_attention_molecule(smiles, attention)
63+
cmap = cm.ScalarMappable(cmap=cm.Greens)
64+
attention_colors = cmap.to_rgba(attention, norm=False)
5765
num_tokens = sum(1 for _ in _tokenize(smiles))
5866

5967
fig = plt.figure(figsize=(15, 15), facecolor="w")

0 commit comments

Comments
 (0)