Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Esmfold #114

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion .appveyor.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ install:
- conda config --set always_yes yes --set changeps1 no
- conda update -q conda
- conda info -a
- conda create -q -n test-environment --channel=conda-forge mmtf-python numpy scipy pandas nose python=%PYTHON_VERSION%
- conda create -q -n test-environment --channel=conda-forge mmtf-python numpy scipy requests pandas nose python=%PYTHON_VERSION%
- activate test-environment

test_script:
Expand Down
163 changes: 111 additions & 52 deletions biopandas/mmcif/pandas_mmcif.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import numpy as np
import pandas as pd
import requests

from ..pdb.engines import amino3to1dict
from ..pdb.pandas_pdb import PandasPdb
Expand Down Expand Up @@ -63,66 +64,89 @@ def read_mmcif(self, path):
self

"""
self.mmcif_path, self.pdb_text = self._read_mmcif(path=path)
self._df = self._construct_df(text=self.pdb_text)
self.mmcif_path, self.mmcif_text = self._read_mmcif(path=path)
self._df = self._construct_df(text=self.mmcif_text)
# self.header, self.code = self._parse_header_code() #TODO: implement
self.code = self.data["entry"]["id"][0].lower()
return self

def fetch_mmcif(self, pdb_code: Optional[str] = None, uniprot_id: Optional[str] = None, source: str = "pdb"):
"""Fetches mmCIF file contents from the Protein Databank at rcsb.org or AlphaFold database at https://alphafold.ebi.ac.uk/.
.
def fetch_mmcif(
self,
pdb_code: Optional[str] = None,
uniprot_id: Optional[str] = None,
sequence: Optional[str] = None,
source: str = "pdb",
):
"""
Fetches mmCIF file contents from the Protein Databank at rcsb.org,
AlphaFold database at https://alphafold.ebi.ac.uk/ or ESMFold database.
.

Parameters
----------
pdb_code : str, optional
A 4-letter PDB code, e.g., `"3eiy"` to retrieve structures from the PDB. Defaults to `None`.
Parameters
----------
pdb_code : str, optional
A 4-letter PDB code, e.g., `"3eiy"` to retrieve structures from the
PDB. Defaults to `None`.

uniprot_id : str, optional
A UniProt Identifier, e.g., `"Q5VSL9"` to retrieve structures from the AF2 database. Defaults to `None`.
uniprot_id : str, optional
A UniProt Identifier, e.g., `"Q5VSL9"` to retrieve structures from
the AF2 database. Defaults to `None`.

source : str
The source to retrieve the structure from
(`"pdb"`, `"alphafold2-v1"`, `"alphafold2-v2"` or `"alphafold2-v3"`). Defaults to `"pdb"`.
sequence : str, optional
A protein sequence to retrieve a structure from the ESMFold
database. Defaults to `None`.

Returns
---------
self
source : str
The source to retrieve the structure from. Can be one of
(`"pdb"`, `"alphafold2-v1"`, `"alphafold2-v2"` or
`"alphafold2-v3"`). Defaults to `"pdb"`.

Returns
---------
self

"""
if sequence is not None:
self.esmfold(sequence)
return self
# Sanitize input
invalid_input_identifier_1 = pdb_code is None and uniprot_id is None
invalid_input_identifier_2 = pdb_code is not None and uniprot_id is not None
invalid_input_combination_1 = uniprot_id is not None and source == "pdb"
invalid_input_combination_2 = pdb_code is not None and source in {
"alphafold2-v1", "alphafold2-v2", "alphafold2-v3"}
"alphafold2-v1",
"alphafold2-v2",
"alphafold2-v3",
}

if invalid_input_identifier_1 or invalid_input_identifier_2:
raise ValueError(
"Please provide either a PDB code or a UniProt ID.")
raise ValueError("Please provide either a PDB code or a UniProt ID.")

if invalid_input_combination_1:
raise ValueError(
"Please use a 'pdb_code' instead of 'uniprot_id' for source='pdb'.")
"Please use a 'pdb_code' instead of 'uniprot_id' for source='pdb'."
)
elif invalid_input_combination_2:
raise ValueError(
f"Please use a 'uniprot_id' instead of 'pdb_code' for source={source}.")
f"Please use a 'uniprot_id' instead of 'pdb_code' for source={source}."
)

if source == "pdb":
self.mmcif_path, self.mmcif_text = self._fetch_mmcif(pdb_code)
elif source == "alphafold2-v1":
af2_version = 1
self.mmcif_path, self.mmcif_text = self._fetch_af2(
uniprot_id, af2_version)
self.mmcif_path, self.mmcif_text = self._fetch_af2(uniprot_id, af2_version)
elif source == "alphafold2-v2":
af2_version = 2
self.mmcif_path, self.mmcif_text = self._fetch_af2(uniprot_id, af2_version)
elif source == "alphafold2-v3":
af2_version = 3
self.mmcif_path, self.mmcif_text = self._fetch_af2(uniprot_id, af2_version)
else:
raise ValueError(f"Invalid source: {source}."
" Please use one of 'pdb', 'alphafold2-v1', 'alphafold2-v2' or 'alphafold2-v3.")
raise ValueError(
f"Invalid source: {source}."
" Please use one of 'pdb', 'alphafold2-v1', 'alphafold2-v2' or 'alphafold2-v3."
)

self._df = self._construct_df(text=self.mmcif_text)
return self
Expand All @@ -132,9 +156,9 @@ def _construct_df(self, text: str):
data = data[list(data.keys())[0]]
self.data = data
df: Dict[str, pd.DataFrame] = {}
full_df = pd.DataFrame.from_dict(
data["atom_site"], orient="index").transpose()
full_df = full_df.astype(mmcif_col_types, errors="ignore")
full_df = pd.DataFrame.from_dict(data["atom_site"], orient="index").transpose()
types = {k: v for k, v in mmcif_col_types.items() if k in full_df.columns}
full_df = full_df.astype(types, errors="ignore")
df["ATOM"] = pd.DataFrame(full_df[full_df.group_PDB == "ATOM"])
df["HETATM"] = pd.DataFrame(full_df[full_df.group_PDB == "HETATM"])
try:
Expand All @@ -152,8 +176,7 @@ def _fetch_mmcif(pdb_code):
response = urlopen(url)
txt = response.read()
txt = (
txt.decode(
"utf-8") if sys.version_info[0] >= 3 else txt.encode("ascii")
txt.decode("utf-8") if sys.version_info[0] >= 3 else txt.encode("ascii")
)
except HTTPError as e:
print(f"HTTP Error {e.code}")
Expand All @@ -170,14 +193,13 @@ def _fetch_af2(uniprot_id: str, af2_version: int = 3):
try:
response = urlopen(url)
txt = response.read()
if sys.version_info[0] >= 3:
txt = txt.decode('utf-8')
else:
txt = txt.encode('ascii')
txt = (
txt.decode("utf-8") if sys.version_info[0] >= 3 else txt.encode("ascii")
)
except HTTPError as e:
print('HTTP Error %s' % e.code)
print(f"HTTP Error {e.code}")
except URLError as e:
print('URL Error %s' % e.args)
print(f"URL Error {e.args}")
return url, txt

@staticmethod
Expand All @@ -190,8 +212,7 @@ def _read_mmcif(path):
r_mode = "rb"
openf = gzip.open
else:
allowed_formats = ", ".join(
(".cif", ".cif.gz", ".mmcif", ".mmcif.gz"))
allowed_formats = ", ".join((".cif", ".cif.gz", ".mmcif", ".mmcif.gz"))
raise ValueError(
f"Wrong file format; allowed file formats are {allowed_formats}"
)
Expand All @@ -201,11 +222,48 @@ def _read_mmcif(path):

if path.endswith(".gz"):
txt = (
txt.decode(
"utf-8") if sys.version_info[0] >= 3 else txt.encode("ascii")
txt.decode("utf-8") if sys.version_info[0] >= 3 else txt.encode("ascii")
)
return path, txt

def esmfold(
    self,
    sequence: str,
    out_path: Optional[str] = None,
    version: int = 1,
    timeout: Optional[float] = None,
):
    """Fold a protein sequence with the ESMFold server and load the result.

    The sequence is POSTed to
    ``https://api.esmatlas.com/foldSequence/v{version}/cif/`` and the
    returned mmCIF text is parsed into the DataFrames of this object.

    Parameters
    ----------
    sequence : str
        A protein sequence in one-letter code.

    out_path : str, optional
        Path to save the mmCIF text to. If `None`, the file is not saved.
        Defaults to `None`.

    version : int, optional
        The version of the ESMFold API endpoint to use. Defaults to `1`.

    timeout : float, optional
        Number of seconds to wait for the server before giving up. If
        `None`, the request waits indefinitely. Defaults to `None`.

    Returns
    --------
    self

    Raises
    ------
    ValueError
        If `sequence` is empty.
    requests.HTTPError
        If the server answers with an error status code.

    """
    if not sequence:
        raise ValueError("Please provide a non-empty protein sequence.")

    url = f"https://api.esmatlas.com/foldSequence/v{version}/cif/"
    headers: Dict[str, str] = {
        "Content-Type": "application/x-www-form-urlencoded",
    }

    response = requests.post(url, data=sequence, headers=headers, timeout=timeout)
    # Fail loudly on 4xx/5xx instead of handing an error body to the parser.
    response.raise_for_status()

    # Prepend a data block name and an `_entry.id` item to the server text.
    # NOTE(review): presumably the raw response lacks both -- confirm
    # against the ESMFold API output.
    header = "\n".join([f"data_{sequence}", "#", f"_entry.id\t{sequence}", "#\n"])
    cif = header + response.text

    if out_path is not None:
        with open(out_path, "w", encoding="utf-8") as f:
            f.write(cif)

    # Keep the text/path attributes in sync with the parsed DataFrames, as
    # the other fetch routes do (they store the URL as the path).
    self.mmcif_path, self.mmcif_text = url, cif
    self._df = self._construct_df(text=cif)
    return self

def get(self, s, df=None, invert=False, records=("ATOM", "HETATM")):
"""Filter PDB DataFrames by properties

Expand Down Expand Up @@ -278,8 +336,7 @@ def _get_mainchain(
def _get_hydrogen(df, invert):
"""Return only hydrogen atom entries from a DataFrame"""
return (
df[(df["type_symbol"] != "H")] if invert else df[(
df["type_symbol"] == "H")]
df[(df["type_symbol"] != "H")] if invert else df[(df["type_symbol"] == "H")]
)

@staticmethod
Expand Down Expand Up @@ -346,8 +403,7 @@ def amino3to1(
indices.append(ind)
cmp = num

transl = tmp.iloc[indices][residue_col].map(
amino3to1dict).fillna(fillna)
transl = tmp.iloc[indices][residue_col].map(amino3to1dict).fillna(fillna)

return pd.concat((tmp.iloc[indices][chain_col], transl), axis=1)

Expand Down Expand Up @@ -473,7 +529,7 @@ def _init_get_dict():
"heavy": PandasMmcif._get_heavy,
}

def read_mmcif_from_list(self, mmcif_lines):
def read_mmcif_from_list(self, mmcif_lines: List[str]):
"""Reads mmCIF file from a list into DataFrames

Attributes
Expand All @@ -486,8 +542,8 @@ def read_mmcif_from_list(self, mmcif_lines):
self

"""
self.pdb_text = "".join(mmcif_lines)
self._df = self._construct_df(mmcif_lines)
self.mmcif_text = "\n".join(mmcif_lines)
self._df = self._construct_df("\n".join(mmcif_lines))
# self.header, self.code = self._parse_header_code()
self.code = self.data["entry"]["id"][0].lower()
return self
Expand Down Expand Up @@ -532,10 +588,13 @@ def convert_to_pandas_pdb(self, offset_chains: bool = True, records: List[str] =

# Update atom numbers
if offset_chains:
offsets = pandaspdb.df["ATOM"]["chain_id"].astype(
"category").cat.codes
pandaspdb.df["ATOM"]["atom_number"] = pandaspdb.df["ATOM"]["atom_number"] + offsets
offsets = pandaspdb.df["ATOM"]["chain_id"].astype("category").cat.codes
pandaspdb.df["ATOM"]["atom_number"] = (
pandaspdb.df["ATOM"]["atom_number"] + offsets
)
hetatom_offset = offsets.max() + 1
pandaspdb.df["HETATM"]["atom_number"] = pandaspdb.df["HETATM"]["atom_number"] + hetatom_offset
pandaspdb.df["HETATM"]["atom_number"] = (
pandaspdb.df["HETATM"]["atom_number"] + hetatom_offset
)

return pandaspdb
26 changes: 20 additions & 6 deletions biopandas/mmcif/tests/test_read_mmcif.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@


import os
from typing import Set
from urllib.error import HTTPError
from urllib.request import urlopen

import numpy as np
import pandas as pd
from nose.tools import raises
from pandas.testing import assert_frame_equal

from biopandas.mmcif import PandasMmcif
from biopandas.pdb import PandasPdb
from biopandas.testutils import assert_raises
from nose.tools import raises
from pandas.testing import assert_frame_equal

TESTDATA_FILENAME = os.path.join(os.path.dirname(__file__), "data", "3eiy.cif")

Expand Down Expand Up @@ -91,7 +93,6 @@
af2_test_struct_v3 = f.read()



def test__read_pdb():
"""Test private _read_pdb"""
ppdb = PandasMmcif()
Expand Down Expand Up @@ -135,6 +136,19 @@ def test_fetch_pdb():
assert ppdb.mmcif_path == "https://files.rcsb.org/download/3eiy.cif"


def test_read_pdb_esmfold():
    """Fold a short peptide via ESMFold and check the residues that come back."""
    sequence = "MTYGLY"
    expected_ids: Set[str] = {
        "A:MET:1",
        "A:THR:2",
        "A:TYR:3",
        "A:GLY:4",
        "A:LEU:5",
        "A:TYR:6",
    }

    atoms = PandasMmcif().fetch_mmcif(sequence=sequence).df["ATOM"]

    # One "<chain>:<residue name>:<residue number>" label per atom row;
    # the set collapses the rows down to unique residues.
    labels = (
        atoms.label_asym_id
        + ":"
        + atoms.label_comp_id
        + ":"
        + atoms.label_seq_id.astype(str)
    )
    observed_ids = set(labels)

    assert observed_ids == expected_ids, "Residue IDs do not match"


def test_fetch_af2():
"""Test fetch_af2"""
# Test latest release
Expand Down Expand Up @@ -245,7 +259,7 @@ def test_read_pdb():
"""Test public read_pdb"""
ppdb = PandasMmcif()
ppdb.read_mmcif(TESTDATA_FILENAME)
assert ppdb.pdb_text == three_eiy
assert ppdb.mmcif_text == three_eiy
assert ppdb.code == "3eiy", ppdb.code
assert ppdb.mmcif_path == TESTDATA_FILENAME

Expand All @@ -255,8 +269,8 @@ def test_read_pdb_from_list():

for pdb_text, code in zip([three_eiy, four_eiy], ["3eiy", "4eiy"]):
ppdb = PandasMmcif()
ppdb.read_mmcif_from_list(pdb_text)
assert ppdb.pdb_text == pdb_text
ppdb.read_mmcif_from_list(pdb_text.split("\n"))
assert ppdb.mmcif_text == pdb_text
assert ppdb.code == code
assert ppdb.mmcif_path == ""

Expand Down
Loading