2020from chebai .result .base import ResultProcessor
2121
2222
23- class AttentionMolPlot (abc .ABC ):
24- def plot_attentions (self , smiles , attention , threshold , labels ):
23+ class AttentionMolPlot :
24+
25+ def draw_attention_molecule (self , smiles , attention ):
2526 pmol = self .read_smiles_with_index (smiles )
2627 rdmol = Chem .MolFromSmiles (smiles )
2728 if not rdmol :
@@ -34,26 +35,33 @@ def plot_attentions(self, smiles, attention, threshold, labels):
3435 }
3536 d = rdMolDraw2D .MolDraw2DCairo (500 , 500 )
3637 cmap = cm .ScalarMappable (cmap = cm .Greens )
37- attention_colors = cmap . to_rgba ( attention , norm = False )
38+
3839 aggr_attention_colors = cmap .to_rgba (
3940 np .max (attention [2 :, :], axis = 0 ), norm = False
4041 )
4142 cols = {
4243 token_to_node_map [token_index ]: tuple (
4344 aggr_attention_colors [token_index ].tolist ()
4445 )
45- for node , token_index in nx .get_node_attributes (pmol , "token_index" ).items ()
46+ for node , token_index in
47+ nx .get_node_attributes (pmol , "token_index" ).items ()
4648 }
4749 highlight_atoms = [
4850 token_to_node_map [token_index ]
49- for node , token_index in nx .get_node_attributes (pmol , "token_index" ).items ()
51+ for node , token_index in
52+ nx .get_node_attributes (pmol , "token_index" ).items ()
5053 ]
5154 rdMolDraw2D .PrepareAndDrawMolecule (
5255 d , rdmol , highlightAtoms = highlight_atoms , highlightAtomColors = cols
5356 )
5457
5558 d .FinishDrawing ()
59+ return d
5660
61+ def plot_attentions (self , smiles , attention , threshold , labels ):
62+ d = self .draw_attention_molecule (smiles , attention )
63+ cmap = cm .ScalarMappable (cmap = cm .Greens )
64+ attention_colors = cmap .to_rgba (attention , norm = False )
5765 num_tokens = sum (1 for _ in _tokenize (smiles ))
5866
5967 fig = plt .figure (figsize = (15 , 15 ), facecolor = "w" )
0 commit comments