diff --git a/Cargo.lock b/Cargo.lock index cb949045..d140e941 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1773,7 +1773,7 @@ version = "4.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab" dependencies = [ - "heck", + "heck 0.5.0", "proc-macro2 1.0.92", "quote 1.0.37", "syn", @@ -2412,7 +2412,7 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1e6a265c649f3f5979b601d26f1d05ada116434c87741c9493cb56218f76cbc" dependencies = [ - "heck", + "heck 0.5.0", "proc-macro2 1.0.92", "quote 1.0.37", "syn", @@ -2544,7 +2544,9 @@ dependencies = [ "ferritin-test-data", "itertools 0.13.0", "lazy_static", + "ndarray", "pdbtbx", + "strum 0.25.0", ] [[package]] @@ -2596,7 +2598,7 @@ dependencies = [ "pdbtbx", "rand", "serde", - "strum", + "strum 0.26.3", "tempfile", "tokenizers 0.21.0", ] @@ -3319,6 +3321,12 @@ dependencies = [ "stable_deref_trait", ] +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + [[package]] name = "heck" version = "0.5.0" @@ -5856,13 +5864,35 @@ version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" +[[package]] +name = "strum" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" +dependencies = [ + "strum_macros 0.25.3", +] + [[package]] name = "strum" version = "0.26.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" dependencies = [ - "strum_macros", + "strum_macros 0.26.4", +] + +[[package]] +name = "strum_macros" +version = "0.25.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "23dc1fa9ac9c169a78ba62f0b841814b7abae11bdd047b9c58f893439e309ea0" +dependencies = [ + "heck 0.4.1", + "proc-macro2 1.0.92", + "quote 1.0.37", + "rustversion", + "syn", ] [[package]] @@ -5871,7 +5901,7 @@ version = "0.26.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" dependencies = [ - "heck", + "heck 0.5.0", "proc-macro2 1.0.92", "quote 1.0.37", "rustversion", diff --git a/ferritin-core/Cargo.toml b/ferritin-core/Cargo.toml index 417e0478..463fca1a 100644 --- a/ferritin-core/Cargo.toml +++ b/ferritin-core/Cargo.toml @@ -9,7 +9,9 @@ description.workspace = true [dependencies] pdbtbx.workspace = true itertools.workspace = true +ndarray = { version = "0.16" } lazy_static = "1.5.0" +strum = { version = "0.25", features = ["derive"] } [dev-dependencies] ferritin-test-data = { path = "../ferritin-test-data" } diff --git a/ferritin-core/examples/ccd.rs b/ferritin-core/examples/ccd.rs deleted file mode 100644 index 345373a4..00000000 --- a/ferritin-core/examples/ccd.rs +++ /dev/null @@ -1,264 +0,0 @@ -// //! BinaryCIF implementation following the official specification -// //! 
Based on https://github.com/dsehnal/BinaryCIF -// use serde::{Deserialize, Serialize}; -// use std::fs::File; -// use std::io::BufReader; -// use std::io::Read; -// use std::path::Path; - -// pub const VERSION: &str = "0.3.0"; - -// #[derive(Debug, Serialize, Deserialize)] -// pub struct BinaryCifFile { -// #[serde(rename = "dataBlocks")] -// pub data_blocks: Vec, -// // Make these optional since they might not be in the file -// #[serde(skip_serializing_if = "Option::is_none")] -// pub version: Option, -// #[serde(skip_serializing_if = "Option::is_none")] -// pub encoder: Option, -// } - -// #[derive(Debug, Serialize, Deserialize)] -// pub struct DataBlock { -// pub categories: Vec, -// #[serde(skip_serializing_if = "Option::is_none")] -// pub header: Option, -// } - -// #[derive(Debug, Serialize, Deserialize)] -// pub struct Category { -// #[serde(rename = "rowCount")] -// pub row_count: u32, -// pub columns: Vec, -// #[serde(skip_serializing_if = "Option::is_none")] -// pub name: Option, -// } - -// #[derive(Debug, Serialize, Deserialize)] -// pub struct Column { -// pub name: String, -// pub data: EncodedData, -// /// The mask represents presence/absence of values: -// /// 0 = Value is present -// /// 1 = . = value not specified -// /// 2 = ? 
= value unknown -// pub mask: Option, -// } - -// #[derive(Debug, Serialize, Deserialize)] -// pub struct EncodedData { -// pub encoding: Vec, -// #[serde(with = "serde_bytes")] -// pub data: Vec, -// } - -// #[derive(Debug, Serialize, Deserialize)] -// #[serde(tag = "kind")] -// pub enum Encoding { -// #[serde(rename = "ByteArray")] -// ByteArray { -// #[serde(rename = "type")] -// data_type: DataType, -// }, -// #[serde(rename = "FixedPoint")] -// FixedPoint { -// factor: f64, -// #[serde(rename = "srcType")] -// src_type: FloatDataType, -// }, -// #[serde(rename = "RunLength")] -// RunLength { -// #[serde(rename = "srcType")] -// src_type: IntDataType, -// #[serde(rename = "srcSize")] -// src_size: u32, -// }, -// #[serde(rename = "Delta")] -// Delta { -// origin: i32, -// #[serde(rename = "srcType")] -// src_type: IntDataType, -// }, -// #[serde(rename = "IntervalQuantization")] -// IntervalQuantization { -// min: f64, -// max: f64, -// #[serde(rename = "numSteps")] -// num_steps: u32, -// #[serde(rename = "srcType")] -// src_type: FloatDataType, -// }, -// #[serde(rename = "IntegerPacking")] -// IntegerPacking { -// #[serde(rename = "byteCount")] -// byte_count: u32, -// #[serde(rename = "isUnsigned")] -// is_unsigned: bool, -// #[serde(rename = "srcSize")] -// src_size: u32, -// }, -// #[serde(rename = "StringArray")] -// StringArray { -// #[serde(rename = "dataEncoding")] -// data_encoding: Vec, -// #[serde(rename = "stringData")] -// string_data: String, -// #[serde(rename = "offsetEncoding")] -// offset_encoding: Vec, -// offsets: Vec, -// }, -// } - -// /// Integer data types used in encodings -// #[derive(Debug, Clone, Copy, Serialize, Deserialize)] -// #[repr(u8)] -// pub enum IntDataType { -// Int8 = 1, -// Int16 = 2, -// Int32 = 3, -// Uint8 = 4, -// Uint16 = 5, -// Uint32 = 6, -// } - -// /// Floating point data types used in encodings -// #[derive(Debug, Clone, Copy, Serialize, Deserialize)] -// #[repr(u8)] -// pub enum FloatDataType { -// Float32 = 
32, -// Float64 = 33, -// } - -// /// Combined data type enum that can be either integer or float -// #[derive(Debug, Clone, Copy, Serialize, Deserialize)] -// #[serde(untagged)] -// pub enum DataType { -// Int(IntDataType), -// Float(FloatDataType), -// } - -// /// CIF value types representing presence/absence of data -// #[derive(Debug, Clone, Copy, PartialEq, Eq)] -// pub enum CifValueStatus { -// Present = 0, // Value exists -// NotSpecified = 1, // "." in CIF -// Unknown = 2, // "?" in CIF -// } - -// /// Types of encoding methods available -// #[derive(Debug, Clone, Copy, PartialEq, Eq)] -// pub enum EncodingType { -// ByteArray, -// FixedPoint, -// RunLength, -// Delta, -// IntervalQuantization, -// IntegerPacking, -// StringArray, -// } - -// impl IntDataType { -// /// Get the size in bytes for this integer type -// pub fn size_in_bytes(&self) -> usize { -// match self { -// IntDataType::Int8 | IntDataType::Uint8 => 1, -// IntDataType::Int16 | IntDataType::Uint16 => 2, -// IntDataType::Int32 | IntDataType::Uint32 => 4, -// } -// } - -// /// Check if this is an unsigned type -// pub fn is_unsigned(&self) -> bool { -// match self { -// IntDataType::Uint8 | IntDataType::Uint16 | IntDataType::Uint32 => true, -// _ => false, -// } -// } -// } - -// impl FloatDataType { -// /// Get the size in bytes for this float type -// pub fn size_in_bytes(&self) -> usize { -// match self { -// FloatDataType::Float32 => 4, -// FloatDataType::Float64 => 8, -// } -// } -// } - -// impl DataType { -// /// Get the size in bytes for this data type -// pub fn size_in_bytes(&self) -> usize { -// match self { -// DataType::Int(i) => i.size_in_bytes(), -// DataType::Float(f) => f.size_in_bytes(), -// } -// } - -// /// Check if this is an integer type -// pub fn is_integer(&self) -> bool { -// matches!(self, DataType::Int(_)) -// } - -// /// Check if this is a float type -// pub fn is_float(&self) -> bool { -// matches!(self, DataType::Float(_)) -// } -// } - -// /// Errors that can 
occur during encoding/decoding -// #[derive(Debug, thiserror::Error)] -// pub enum BinaryCifError { -// #[error("Invalid data type")] -// InvalidDataType, -// #[error("Invalid encoding")] -// InvalidEncoding, -// #[error("Invalid mask value")] -// InvalidMask, -// #[error("IO error: {0}")] -// Io(#[from] std::io::Error), -// #[error("MessagePack error: {0}")] -// MessagePack(#[from] rmp_serde::decode::Error), -// } - -// type BinaryCifResult = std::result::Result; // Changed from Result - -// impl BinaryCifFile { -// pub fn from_reader(reader: R) -> BinaryCifResult { -// let file: BinaryCifFile = rmp_serde::from_read(reader)?; -// Ok(file) -// } - -// pub fn from_bytes(bytes: &[u8]) -> BinaryCifResult { -// let file: BinaryCifFile = rmp_serde::from_slice(bytes)?; -// Ok(file) -// } -// } -// fn main() -> Result<(), Box> { -// // Load and parse file -// // let bytes = std::fs::read("data/ccd/components.bcif")?; -// let bytes = std::fs::read("/Users/zcpowers/Documents/Projects/ferritin/data/biotite-1.0.1/src/biotite/structure/info/ccd/components.bcif")?; -// let bcif = BinaryCifFile::from_bytes(&bytes)?; -// // Print structure overview -// println!("File version: {:?}", bcif.version); -// println!("Number of data blocks: {:?}", bcif.data_blocks.len()); - -// for block in &bcif.data_blocks { -// println!("\nBlock: {:?}", block.header); -// for category in &block.categories { -// println!( -// " Category {:?} ({:?} rows, {:?} columns)", -// category.name, -// category.row_count, -// category.columns.len() -// ); -// } -// } - -// Ok(()) -// } -// // -// -// -fn main() {} diff --git a/ferritin-core/src/featurize/mod.rs b/ferritin-core/src/featurize/mod.rs new file mode 100644 index 00000000..f81b7574 --- /dev/null +++ b/ferritin-core/src/featurize/mod.rs @@ -0,0 +1,12 @@ +//! Protein Featurizer for ProteinMPNN/LignadMPNN +//! +//! Extract protein features for ligandmpnn +//! +//! Returns a set of features calculated from protein structure +//! including: +//! 
- Residue-level features like amino acid type, secondary structure +//! - Geometric features like distances, angles +//! - Chemical features like hydrophobicity, charge +//! - Evolutionary features from MSA profiles +mod ndarray_impl; +mod utilities; diff --git a/ferritin-core/src/featurize/ndarray_impl.rs b/ferritin-core/src/featurize/ndarray_impl.rs new file mode 100644 index 00000000..1cc1fe9f --- /dev/null +++ b/ferritin-core/src/featurize/ndarray_impl.rs @@ -0,0 +1,500 @@ +use super::utilities::{aa1to_int, aa3to1, AAAtom}; +use crate::AtomCollection; +use itertools::MultiUnzip; +use ndarray::{Array, Array1, Array2, Array4}; +use pdbtbx::Element; +use std::collections::{HashMap, HashSet}; +use strum::IntoEnumIterator; + +fn is_heavy_atom(element: &Element) -> bool { + !matches!(element, Element::H | Element::He) +} + +pub trait StructureFeatures { + type Error; + + /// Convert amino acid sequence to numeric representation + fn encode_amino_acids(&self) -> Result, Self::Error>; + + /// Convert structure into complete feature set + fn featurize(&self) -> Result; + + /// Get residue indices + fn get_res_index(&self) -> Array1; + + /// Extract backbone atom coordinates (N, CA, C, O) + fn to_numeric_backbone_atoms(&self) -> Result, Self::Error>; + + /// Extract all atom coordinates in standard ordering + fn to_numeric_atom37(&self) -> Result, Self::Error>; + + /// Extract ligand atom coordinates and properties + fn to_numeric_ligand_atoms( + &self, + ) -> Result<(Array2, Array1, Array2), Self::Error>; + + /// Convert to PDB format + fn to_pdb(&self); +} + +impl StructureFeatures for AtomCollection { + type Error = ndarray::ShapeError; + + fn to_pdb(&self) { + todo!() + } + fn featurize(&self) -> Result { + todo!() + } + fn encode_amino_acids(&self) -> Result, Self::Error> { + let n = self.iter_residues_aminoacid().count(); + let sequence: Vec = self + .iter_residues_aminoacid() + .map(|res| res.res_name) + .map(|res| aa3to1(&res)) + .map(|res| aa1to_int(res) as f32) 
+ .collect(); + + Array::from_shape_vec((1, n), sequence) + } + + fn to_numeric_backbone_atoms(&self) -> Result, Self::Error> { + let res_count = self.iter_residues_aminoacid().count(); + let mut backbone_data = vec![0f32; res_count * 4 * 3]; + + for residue in self.iter_residues_aminoacid() { + let resid = residue.res_id as usize; + let backbone_atoms = [ + residue.find_atom_by_name("N"), + residue.find_atom_by_name("CA"), + residue.find_atom_by_name("C"), + residue.find_atom_by_name("O"), + ]; + + for (atom_idx, maybe_atom) in backbone_atoms.iter().enumerate() { + if let Some(atom) = maybe_atom { + let [x, y, z] = atom.coords; + let base_idx = (resid * 4 + atom_idx) * 3; + backbone_data[base_idx] = *x; + backbone_data[base_idx + 1] = *y; + backbone_data[base_idx + 2] = *z; + } + } + } + + Array::from_shape_vec((1, res_count, 4, 3), backbone_data) + } + + fn to_numeric_atom37(&self) -> Result, Self::Error> { + let res_count = self.iter_residues_aminoacid().count(); + let mut atom37_data = vec![0f32; res_count * 37 * 3]; + + for (idx, residue) in self.iter_residues_aminoacid().enumerate() { + for atom_type in AAAtom::iter().filter(|&a| a != AAAtom::Unknown) { + if let Some(atom) = residue.find_atom_by_name(&atom_type.to_string()) { + let [x, y, z] = atom.coords; + let base_idx = (idx * 37 + atom_type as usize) * 3; + atom37_data[base_idx] = *x; + atom37_data[base_idx + 1] = *y; + atom37_data[base_idx + 2] = *z; + } + } + } + + Array::from_shape_vec((1, res_count, 37, 3), atom37_data) + } + + fn to_numeric_ligand_atoms( + &self, + ) -> Result<(Array2, Array1, Array2), Self::Error> { + let (coords, elements): (Vec<[f32; 3]>, Vec) = self + .iter_residues_all() + .filter(|residue| { + let res_name = &residue.res_name; + !residue.is_amino_acid() && res_name != "HOH" && res_name != "WAT" + }) + .flat_map(|residue| { + residue + .iter_atoms() + .filter(|atom| is_heavy_atom(&atom.element)) + .map(|atom| (*atom.coords, atom.element.clone())) + .collect::>() + }) + 
.multiunzip(); + + let n_atoms = coords.len(); + let coords_flat: Vec = coords.into_iter().flat_map(|[x, y, z]| [x, y, z]).collect(); + let coords_array = Array::from_shape_vec((n_atoms, 3), coords_flat)?; + + let elements_array = + Array1::from_vec(elements.iter().map(|e| e.atomic_number() as f32).collect()); + + let mask_array = Array::ones((n_atoms, 3)); + + Ok((coords_array, elements_array, mask_array)) + } + + fn get_res_index(&self) -> Array1 { + self.iter_residues_aminoacid() + .map(|res| res.res_id as u32) + .collect() + } +} + +pub struct ProteinFeatures { + /// protein amino acids sequences as 1D Array of u32 + pub sequence: Array2, + /// protein coords by residue [batch, seqlength, 37, 3] + pub coordinates: Array4, + /// protein mask by residue + pub mask: Option>, + /// ligand coords + pub ligand_coords: Array2, + /// encoded ligand atom names + pub ligand_types: Array1, + /// ligand mask + pub ligand_mask: Option>, + /// residue indices + pub residue_index: Array1, + /// chain labels + pub chain_labels: Option>, + /// chain letters + pub chain_letters: Vec, + /// chain mask + pub chain_mask: Option>, + /// list of chains + pub chain_list: Vec, +} + +impl ProteinFeatures { + pub fn get_coords(&self) -> &Array4 { + &self.coordinates + } + + pub fn get_sequence(&self) -> &Array2 { + &self.sequence + } + + pub fn get_sequence_mask(&self) -> Option<&Array4> { + self.mask.as_ref() + } + + pub fn get_residue_index(&self) -> &Array1 { + &self.residue_index + } + + pub fn get_encoded( + &self, + ) -> Result<(Vec, HashMap, HashMap), ndarray::ShapeError> + { + let r_idx_list = self + .residue_index + .iter() + .map(|&x| x as u32) + .collect::>(); + let chain_letters_list = &self.chain_letters; + + let encoded_residues: Vec = r_idx_list + .iter() + .enumerate() + .map(|(i, r_idx)| format!("{}{}", chain_letters_list[i], r_idx)) + .collect(); + + let encoded_residue_dict: HashMap = encoded_residues + .iter() + .enumerate() + .map(|(i, s)| (s.clone(), i)) + 
.collect(); + + let encoded_residue_dict_rev: HashMap = encoded_residues + .iter() + .enumerate() + .map(|(i, s)| (i, s.clone())) + .collect(); + + Ok(( + encoded_residues, + encoded_residue_dict, + encoded_residue_dict_rev, + )) + } + + pub fn get_encoded_array( + &self, + fixed_residues: &str, + ) -> Result, ndarray::ShapeError> { + let res_set: HashSet = fixed_residues.split(' ').map(String::from).collect(); + let (encoded_res, _, _) = self.get_encoded()?; + + Ok(Array1::from_vec( + encoded_res + .iter() + .map(|item| if res_set.contains(item) { 0.0 } else { 1.0 }) + .collect(), + )) + } + + pub fn get_chain_mask_array( + &self, + chains_to_design: &[String], + ) -> Result, ndarray::ShapeError> { + Ok(Array1::from_vec( + self.chain_letters + .iter() + .map(|chain| { + if chains_to_design.contains(chain) { + 1.0 + } else { + 0.0 + } + }) + .collect(), + )) + } + + pub fn update_mask(&mut self, array: Array4) -> Result<(), ndarray::ShapeError> { + if let Some(ref mask) = self.mask { + self.mask = Some(mask * &array); + } else { + self.mask = Some(array); + } + Ok(()) + } + + pub fn create_bias_array( + &self, + bias_aa: Option<&str>, + ) -> Result, ndarray::ShapeError> { + let mut bias_values = vec![0.0f32; 21]; + + if let Some(bias_str) = bias_aa { + for pair in bias_str.split(',') { + if let Some((aa, value_str)) = pair.split_once(':') { + if let Ok(value) = value_str.parse::() { + if let Some(aa_char) = aa.chars().next() { + let idx = aa1to_int(aa_char) as usize; + bias_values[idx] = value; + } + } + } + } + } + + Ok(Array1::from_vec(bias_values)) + } +} + +// pub struct ProteinFeatures { +// /// Protein amino acid sequence encoding +// pub sequence: Array2, +// /// Protein atom coordinates +// pub coordinates: Array3, +// /// Protein atom mask +// pub atom_mask: Option>, +// /// Ligand coordinates +// pub ligand_coords: Array2, +// /// Ligand atom types +// pub ligand_types: Array2, +// /// Ligand mask +// pub ligand_mask: Option>, +// /// Residue indices +// 
#[cfg(test)]
mod tests {
    use super::*;
    use ferritin_test_data::TestFile;
    use ndarray::{s, Array4};

    /// Loads the `protein_01` fixture into an `AtomCollection`.
    fn load_protein_01() -> AtomCollection {
        let (pdb_file, _temp) = TestFile::protein_01().create_temp().unwrap();
        let (pdb, _) = pdbtbx::open(pdb_file).unwrap();
        AtomCollection::from(&pdb)
    }

    #[test]
    fn test_atom_backbone_tensor() {
        let ac = load_protein_01();
        let backbone: Array4<f32> = ac.to_numeric_backbone_atoms().expect("REASON");

        // Reference records for the first residue of the fixture:
        // ATOM 1 N N  . MET A 1 1 ? 24.277 8.374  -9.854  1.00 38.41 ? 0 MET A N  1
        // ATOM 2 C CA . MET A 1 1 ? 24.404 9.859  -9.939  1.00 37.90 ? 0 MET A CA 1
        // ATOM 3 C C  . MET A 1 1 ? 25.814 10.249 -10.359 1.00 36.65 ? 0 MET A C  1
        // ATOM 4 O O  . MET A 1 1 ? 26.748 9.469  -10.197 1.00 37.13 ? 0 MET A O  1
        //
        // Columns: atom name, residue row, backbone slot, expected xyz.
        let expected_backbone: [(&str, usize, usize, [f32; 3]); 12] = [
            // Methionine - first residue
            ("N", 0, 0, [24.277, 8.374, -9.854]),
            ("CA", 0, 1, [24.404, 9.859, -9.939]),
            ("C", 0, 2, [25.814, 10.249, -10.359]),
            ("O", 0, 3, [26.748, 9.469, -10.197]),
            // Valine - second residue
            ("N", 1, 0, [25.964, 11.453, -10.903]),
            ("CA", 1, 1, [27.263, 11.924, -11.359]),
            ("C", 1, 2, [27.392, 13.428, -11.115]),
            ("O", 1, 3, [26.443, 14.184, -11.327]),
            // Glycine - last residue
            ("N", 153, 0, [23.474, -3.227, 5.994]),
            ("CA", 153, 1, [22.818, -2.798, 7.211]),
            ("C", 153, 2, [22.695, -1.282, 7.219]),
            ("O", 153, 3, [21.870, -0.745, 7.992]),
        ];

        for (atom_name, res_idx, slot, xyz) in expected_backbone {
            let actual: Vec<f32> = backbone.slice(s![0, res_idx, slot, ..]).to_vec();
            assert_eq!(actual, xyz.to_vec(), "Mismatch for atom {}", atom_name);
        }
    }

    #[test]
    fn test_all_atom37_tensor() {
        let ac = load_protein_01();
        let atom37: Array4<f32> = ac.to_numeric_atom37().expect("REASON");
        // batch of 1; 154 residues; 37 atom slots; xyz
        assert_eq!(atom37.dim(), (1, 154, 37, 3));

        // Atoms present in the first residue (MET) of the fixture:
        // ATOM 1 N N  . MET A 1 1 ? 24.277 8.374  -9.854  1.00 38.41 ? 0 MET A N  1
        // ATOM 2 C CA . MET A 1 1 ? 24.404 9.859  -9.939  1.00 37.90 ? 0 MET A CA 1
        // ATOM 3 C C  . MET A 1 1 ? 25.814 10.249 -10.359 1.00 36.65 ? 0 MET A C  1
        // ATOM 4 O O  . MET A 1 1 ? 26.748 9.469  -10.197 1.00 37.13 ? 0 MET A O  1
        // ATOM 5 C CB . MET A 1 1 ? 24.070 10.495 -8.596  1.00 39.58 ? 0 MET A CB 1
        // ATOM 6 C CG . MET A 1 1 ? 24.880 9.939  -7.442  1.00 41.49 ? 0 MET A CG 1
        // ATOM 7 S SD . MET A 1 1 ? 24.262 10.555 -5.873  1.00 44.70 ? 0 MET A SD 1
        // ATOM 8 C CE . MET A 1 1 ? 24.822 12.266 -5.967  1.00 41.59 ? 0 MET A CE 1
        let present: [(AAAtom, [f32; 3]); 8] = [
            (AAAtom::N, [24.277, 8.374, -9.854]),
            (AAAtom::CA, [24.404, 9.859, -9.939]),
            (AAAtom::C, [25.814, 10.249, -10.359]),
            (AAAtom::CB, [24.070, 10.495, -8.596]),
            (AAAtom::O, [26.748, 9.469, -10.197]),
            (AAAtom::CG, [24.880, 9.939, -7.442]),
            (AAAtom::SD, [24.262, 10.555, -5.873]),
            (AAAtom::CE, [24.822, 12.266, -5.967]),
        ];

        // Every one of the 37 slots is checked: listed atoms must match the
        // record above, every other slot must still be zero.
        for atom in AAAtom::iter().filter(|a| *a != AAAtom::Unknown) {
            let expected = present
                .iter()
                .find(|(kind, _)| *kind == atom)
                .map_or([0.0f32; 3], |(_, xyz)| *xyz);
            let actual: Vec<f32> = atom37.slice(s![0, 0, atom as usize, ..]).to_vec();
            assert_eq!(actual, expected.to_vec(), "Mismatch for atom {}", atom);
        }
    }

    #[test]
    fn test_ligand_tensor() {
        let ac = load_protein_01();
        let (ligand_coords, ligand_elements, _) = ac.to_numeric_ligand_atoms().expect("REASON");
        // 54 ligand heavy atoms; xyz
        assert_eq!(ligand_coords.dim(), (54, 3));

        // First ligand records of the fixture:
        // HETATM 1222 S S  . SO4 B 2 . ? 30.746 18.706 28.896 1.00 47.98 ? 157 SO4 A S  1
        // HETATM 1223 O O1 . SO4 B 2 . ? 30.697 20.077 28.620 1.00 48.06 ? 157 SO4 A O1 1
        // HETATM 1224 O O2 . SO4 B 2 . ? 31.104 18.021 27.725 1.00 47.52 ? 157 SO4 A O2 1
        // HETATM 1225 O O3 . SO4 B 2 . ? 29.468 18.179 29.331 1.00 47.79 ? 157 SO4 A O3 1
        // HETATM 1226 O O4 . SO4 B 2 . ? 31.722 18.578 29.881 1.00 47.85 ? 157 SO4 A O4 1
        let expected_ligand: [(&str, usize, [f32; 3]); 5] = [
            ("S", 0, [30.746, 18.706, 28.896]),
            ("O1", 1, [30.697, 20.077, 28.620]),
            ("O2", 2, [31.104, 18.021, 27.725]),
            ("O3", 3, [29.468, 18.179, 29.331]),
            ("O4", 4, [31.722, 18.578, 29.881]),
        ];

        for (atom_name, row, xyz) in expected_ligand {
            let actual: Vec<f32> = ligand_coords.slice(s![row, ..]).to_vec();
            assert_eq!(actual, xyz.to_vec(), "Mismatch for atom {}", atom_name);
        }

        // Atomic numbers are stored as floats; map them back to symbols.
        let elements: Vec<&str> = ligand_elements
            .to_vec()
            .into_iter()
            .map(|elem| Element::new(elem as usize).unwrap().symbol())
            .collect();

        assert_eq!(elements[0], "S");
        assert_eq!(elements[1], "O");
        assert_eq!(elements[2], "O");
        assert_eq!(elements[3], "O");
    }
}
// todo: better utility library
/// Maps a three-letter residue code to its one-letter code.
///
/// The comparison is exact (upper-case, no trimming); any other input
/// returns `'X'`.
pub fn aa3to1(aa: &str) -> char {
    const CODES: [(&str, char); 20] = [
        ("ALA", 'A'), ("CYS", 'C'), ("ASP", 'D'), ("GLU", 'E'), ("PHE", 'F'),
        ("GLY", 'G'), ("HIS", 'H'), ("ILE", 'I'), ("LYS", 'K'), ("LEU", 'L'),
        ("MET", 'M'), ("ASN", 'N'), ("PRO", 'P'), ("GLN", 'Q'), ("ARG", 'R'),
        ("SER", 'S'), ("THR", 'T'), ("VAL", 'V'), ("TRP", 'W'), ("TYR", 'Y'),
    ];

    CODES
        .iter()
        .find(|(code3, _)| *code3 == aa)
        .map_or('X', |&(_, code1)| code1)
}

// todo: better utility library
/// Maps a one-letter residue code to its integer index (`'A'` = 0 … `'Y'` = 19).
///
/// Every other character, including `'X'`, maps to `20`.
pub fn aa1to_int(aa: char) -> u32 {
    // Same ordering as `int_to_aa1`, so the two functions stay inverses.
    const ORDER: &str = "ACDEFGHIKLMNPQRSTVWY";

    ORDER
        .chars()
        .position(|letter| letter == aa)
        .map_or(20, |idx| idx as u32)
}

/// Inverse of [`aa1to_int`]: index `0..=19` back to its letter.
///
/// `20` and every larger value return `'X'`.
pub fn int_to_aa1(aa_int: u32) -> char {
    // Same ordering as `aa1to_int`, so the two functions stay inverses.
    const ORDER: &[u8; 20] = b"ACDEFGHIKLMNPQRSTVWY";

    ORDER
        .get(aa_int as usize)
        .map_or('X', |&letter| letter as char)
}
'L', 10 => 'M', 11 => 'N', + 12 => 'P', 13 => 'Q', 14 => 'R', + 15 => 'S', 16 => 'T', 17 => 'V', + 18 => 'W', 19 => 'Y', 20 => 'X', + _ => 'X' + + } +} + +const ALPHABET: [char; 21] = [ + 'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', + 'Y', 'X', +]; + +const ELEMENT_LIST: [&str; 118] = [ + "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne", "Na", "Mg", "Al", "Si", "P", "S", "Cl", + "Ar", "K", "Ca", "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn", "Ga", "Ge", "As", + "Se", "Br", "Kr", "Rb", "Sr", "Y", "Zr", "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd", "In", + "Sn", "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd", "Pm", "Sm", "Eu", "Gd", "Tb", + "Dy", "Ho", "Er", "Tm", "Yb", "Lu", "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg", "Tl", + "Pb", "Bi", "Po", "At", "Rn", "Fr", "Ra", "Ac", "Th", "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk", + "Cf", "Es", "Fm", "Md", "No", "Lr", "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds", "Rg", "Cn", "Nh", + "Fl", "Mc", "Lv", "Ts", "Og", +]; + +#[rustfmt::skip] +#[derive(Debug, Clone, Copy, PartialEq, Display, EnumString, EnumIter)] +pub enum AAAtom { + N = 0, CA = 1, C = 2, CB = 3, O = 4, + CG = 5, CG1 = 6, CG2 = 7, OG = 8, OG1 = 9, + SG = 10, CD = 11, CD1 = 12, CD2 = 13, ND1 = 14, + ND2 = 15, OD1 = 16, OD2 = 17, SD = 18, CE = 19, + CE1 = 20, CE2 = 21, CE3 = 22, NE = 23, NE1 = 24, + NE2 = 25, OE1 = 26, OE2 = 27, CH2 = 28, NH1 = 29, + NH2 = 30, OH = 31, CZ = 32, CZ2 = 33, CZ3 = 34, + NZ = 35, OXT = 36, + Unknown = -1, +} +impl AAAtom { + // Get numeric value (might still be useful in some contexts) + pub fn to_index(&self) -> usize { + *self as usize + } +} + +macro_rules! define_residues { + ($($name:ident: $code3:expr, $code1:expr, $idx:expr, $features:expr, $atoms14:expr),* $(,)?) 
=> { + #[derive(Debug, Copy, Clone)] + pub enum Residue { + $($name),* + } + + impl Residue { + pub const fn code3(&self) -> &'static str { + match self { + $(Self::$name => $code3),* + } + } + pub const fn code1(&self) -> char { + match self { + $(Self::$name => $code1),* + } + } + pub const fn atoms14(&self) -> [AAAtom; 14] { + match self { + $(Self::$name => $atoms14),* + } + } + pub fn from_int(value: i32) -> Self { + match value { + $($idx => Self::$name,)* + _ => Self::UNK + } + } + pub fn to_int(&self) -> i32 { + match self { + $(Self::$name => $idx),* + } + } + } + } +} + +define_residues! { + ALA: "ALA", 'A', 0, [1.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown], + CYS: "CYS", 'C', 1, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::SG, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown], + ASP: "ASP", 'D', 2, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::OD1, AAAtom::OD2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown], + GLU: "GLU", 'E', 3, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD, AAAtom::OE1, AAAtom::OE2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown], + PHE: "PHE", 'F', 4, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD1, AAAtom::CD2, AAAtom::CE1, AAAtom::CE2, AAAtom::CZ, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown], + GLY: "GLY", 'G', 5, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, 
AAAtom::Unknown, AAAtom::Unknown], + HIS: "HIS", 'H', 6, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::ND1, AAAtom::CD2, AAAtom::CE1, AAAtom::NE2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown], + ILE: "ILE", 'I', 7, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG1, AAAtom::CG2, AAAtom::CD1, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown], + LYS: "LYS", 'K', 8, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD, AAAtom::CE, AAAtom::NZ, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown], + LEU: "LEU", 'L', 9, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD1, AAAtom::CD2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown], + MET: "MET", 'M', 10, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::SD, AAAtom::CE, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown], + ASN: "ASN", 'N', 11, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::OD1, AAAtom::ND2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown], + PRO: "PRO", 'P', 12, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown], + GLN: "GLN", 'Q', 13, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD, AAAtom::OE1, AAAtom::NE2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown], + ARG: "ARG", 'R', 14, [0.0, 1.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, 
AAAtom::CD, AAAtom::NE, AAAtom::CZ, AAAtom::NH1, AAAtom::NH2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown], + SER: "SER", 'S', 15, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::OG, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown], + THR: "THR", 'T', 16, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::OG1, AAAtom::CG2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown], + VAL: "VAL", 'V', 17, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG1, AAAtom::CG2, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown], + TRP: "TRP", 'W', 18, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD1, AAAtom::CD2, AAAtom::CE2, AAAtom::CE3, AAAtom::NE1, AAAtom::CZ2, AAAtom::CZ3, AAAtom::CH2], + TYR: "TYR", 'Y', 19, [0.0, 0.0], [AAAtom::N, AAAtom::CA, AAAtom::C, AAAtom::O, AAAtom::CB, AAAtom::CG, AAAtom::CD1, AAAtom::CD2, AAAtom::CE1, AAAtom::CE2, AAAtom::CZ, AAAtom::OH, AAAtom::Unknown, AAAtom::Unknown], + UNK: "UNK", 'X', 20, [0.0, 0.0], [AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown, AAAtom::Unknown], +} + +#[cfg(test)] +mod tests { + use super::*; + // use crate::ligandmpnn::proteinfeatures::LMPNNFeatures; + use crate::AtomCollection; + use ferritin_test_data::TestFile; + use pdbtbx; + use pdbtbx::Element; + + #[test] + fn test_residue_codes() { + let ala = Residue::ALA; + assert_eq!(ala.code3(), "ALA"); + assert_eq!(ala.code1(), 'A'); + assert_eq!(ala.to_int(), 0); + } + + #[test] + fn test_residue_from_int() { + assert!(matches!(Residue::from_int(0), 
Residue::ALA)); + assert!(matches!(Residue::from_int(1), Residue::CYS)); + assert!(matches!(Residue::from_int(999), Residue::UNK)); + } + + #[test] + fn test_residue_atoms() { + let trp = Residue::TRP; + let atoms = trp.atoms14(); + assert_eq!(atoms[0], AAAtom::N); + assert_eq!(atoms[13], AAAtom::CH2); + + let gly = Residue::GLY; + let atoms = gly.atoms14(); + assert_eq!(atoms[4], AAAtom::Unknown); + } +} diff --git a/ferritin-core/src/lib.rs b/ferritin-core/src/lib.rs index a351901e..b6991abf 100644 --- a/ferritin-core/src/lib.rs +++ b/ferritin-core/src/lib.rs @@ -14,6 +14,7 @@ mod atomcollection; mod bonds; mod conversions; +mod featurize; mod info; mod residue; mod selection; diff --git a/ferritin-plms/src/ligandmpnn/proteinfeatures.rs b/ferritin-plms/src/ligandmpnn/proteinfeatures.rs index 07033a5f..df6e3bf4 100644 --- a/ferritin-plms/src/ligandmpnn/proteinfeatures.rs +++ b/ferritin-plms/src/ligandmpnn/proteinfeatures.rs @@ -28,6 +28,7 @@ pub trait LMPNNFeatures { fn featurize(&self, device: &Device) -> Result; // need more control over this featurization process fn get_res_index(&self) -> Vec; fn to_numeric_backbone_atoms(&self, device: &Device) -> Result; // [residues, N/CA/C/O, xyz] + fn to_numeric_atom37(&self, device: &Device) -> Result; // [residues, N/CA/C/O....37, xyz] fn to_numeric_ligand_atoms(&self, device: &Device) -> Result<(Tensor, Tensor, Tensor)>; // ( positions , elements, mask ) fn to_pdb(&self); //