Skip to content

Commit

Permalink
refactor: deduplicate code
Browse files Browse the repository at this point in the history
  • Loading branch information
YaoYinYing committed Aug 19, 2024
1 parent d2feac5 commit b3153de
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 31 deletions.
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
"""Functions for building the input features (reference ccd features) for the HelixFold model."""

import collections
from typing import Optional
import gzip
import os
import pickle
from typing import Any, Optional

from absl import logging
from immutabledict import immutabledict
from helixfold.common import residue_constants
import numpy as np

from helixfold.data.tools import utils


ALLOWED_LIGAND_BONDS_TYPE = {
"SING": 1,
"DOUB": 2,
Expand All @@ -13,6 +22,25 @@
"AROM": 12,
}

def load_ccd_dict(ccd_preprocessed_path: str) -> immutabledict[str, Any]:
if not os.path.exists(ccd_preprocessed_path):
raise FileNotFoundError(f'[CCD] ccd_preprocessed_path: {ccd_preprocessed_path} not exist.')

if not ccd_preprocessed_path.endswith('.pkl.gz') and not ccd_preprocessed_path.endswith('.pkl'):
raise ValueError(f'[CCD] ccd_preprocessed_path: {ccd_preprocessed_path} not endswith .pkl.gz and .pkl')

with utils.timing(f'Loading CCD dataset from {ccd_preprocessed_path}'):
if ccd_preprocessed_path.endswith('.pkl.gz'):
with gzip.open(ccd_preprocessed_path, "rb") as fp:
ccd_preprocessed_dict = immutabledict(pickle.load(fp))
else:
with open(ccd_preprocessed_path, "rb") as fp:
ccd_preprocessed_dict = immutabledict(pickle.load(fp))

logging.info(f'CCD dataset contains {len(ccd_preprocessed_dict)} entries.')

return ccd_preprocessed_dict

def element_map_with_x(atom_symbol):
# ## one-hot max shape == 128
return residue_constants.ATOM_ELEMENT.get(atom_symbol, 127)
Expand Down Expand Up @@ -107,8 +135,13 @@ def make_ccd_conf_features(all_chain_info, ccd_preprocessed_dict,
features[k] = np.concatenate(v, axis=0)
features['ref_atom_count'] = np.bincount(features['ref_token2atom_idx'])

assert np.max(features['ref_element']) < 128 # WTF?
assert np.max(features['ref_atom_name_chars']) < 64 # WTF?
if (_ref_element:=np.max(features['ref_element'])) >= 128:
raise ValueError(f'ref_element= {_ref_element}, which is larger then 128.\n{features["ref_element"]}\n{"-"*79}')


if (_ref_atom_name_chars:=np.max(features['ref_atom_name_chars'])) >= 64:
raise ValueError(f'ref_atom_name_chars= {_ref_atom_name_chars}, which is larger then 64.\n{features["ref_atom_name_chars"]}\n{"-"*79}')

assert len(set([len(v) for k, v in features.items() if k != 'ref_atom_count'])) == 1 ## To check same Atom-level features. # WTF?
return features

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@
from absl import logging
from helixfold.common import residue_constants
from helixfold.data import parsers
from helixfold.data.tools import utils
from helixfold.data.pipeline_conf_bonds import load_ccd_dict
import numpy as np
import json
import gzip
import pickle

from rdkit import Chem

FeatureDict = MutableMapping[str, np.ndarray]
Expand Down Expand Up @@ -212,14 +211,7 @@ def process(self,
assembly_dict = unit_dict

if ccd_preprocessed_dict is None:
ccd_preprocessed_dict = {}

if not 'pkl.gz' in self.ccd_preprocessed_path:
raise ValueError(f'Invalid ccd_preprocessed_path: {self.ccd_preprocessed_path}')

with utils.timing('Loading CCD dataset'):
with gzip.open(self.ccd_preprocessed_path, "rb") as fp:
ccd_preprocessed_dict = pickle.load(fp)
ccd_preprocessed_dict=load_ccd_dict(self.ccd_preprocessed_path)

if select_mmcif_chainID is not None:
select_mmcif_chainID = set(select_mmcif_chainID)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
import copy
import os
from pathlib import Path
import time, gzip, pickle
from typing import Optional, Tuple
import pickle
from typing import Optional, Tuple, Any

import numpy as np
from absl import logging

Expand All @@ -17,27 +18,15 @@
from helixfold.data import pipeline_conf_bonds, pipeline_token_feature, pipeline_hybrid
from helixfold.data import label_utils

from helixfold.data.tools import utils

from .preprocess import digit2alphabet


POLYMER_STANDARD_RESI_ATOMS = residue_constants.residue_atoms
STRING_FEATURES = ['all_chain_ids', 'all_ccd_ids','all_atom_ids',
'release_date','label_ccd_ids','label_atom_ids']

def load_ccd_dict(ccd_preprocessed_path):
assert os.path.exists(ccd_preprocessed_path),\
(f'[CCD] ccd_preprocessed_path: {ccd_preprocessed_path} not exist.')
st_1 = time.time()
if 'pkl.gz' in ccd_preprocessed_path:
with gzip.open(ccd_preprocessed_path, "rb") as fp:
ccd_preprocessed_dict = pickle.load(fp)
elif '.pkl' in ccd_preprocessed_path:
with open(ccd_preprocessed_path, "rb") as fp:
ccd_preprocessed_dict = pickle.load(fp)
print(f'[CCD] load ccd dataset done. use {time.time()-st_1}s;'\
f'Has length of {len(ccd_preprocessed_dict)}')

return ccd_preprocessed_dict


def crop_msa(feat, max_msa_depth=16384):
Expand Down Expand Up @@ -385,7 +374,7 @@ def process_input_json(all_entitys, ccd_preprocessed_path,
no_msa_templ_feats=False):

## load ccd dict.
ccd_preprocessed_dict = load_ccd_dict(ccd_preprocessed_path)
ccd_preprocessed_dict = pipeline_conf_bonds.load_ccd_dict(ccd_preprocessed_path)
all_chain_features = {}
sequence_features = {}
num_chains = 0
Expand Down

0 comments on commit b3153de

Please sign in to comment.