Skip to content

Commit bfa14fd

Browse files
advaitathreyaBlair Lyons
andauthored
Feature/draw md bonds (#183)
* Update md_converter.py to include option to draw bonds from MD trajectory * add draw_bonds option to MdConverter * formatting * formatting and move draw_bonds option to MdData * test for MD draw bonds --------- Co-authored-by: Blair Lyons <[email protected]>
1 parent 1494d22 commit bfa14fd

File tree

3 files changed

+154
-11
lines changed

3 files changed

+154
-11
lines changed

simulariumio/md/md_converter.py

Lines changed: 67 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
from ..trajectory_converter import TrajectoryConverter
1414
from ..data_objects import TrajectoryData, AgentData, DimensionData, DisplayData
15-
from ..constants import DISPLAY_TYPE, JMOL_COLORS
15+
from ..constants import DISPLAY_TYPE, JMOL_COLORS, VIZ_TYPE, SUBPOINT_VALUES_PER_ITEM
1616
from .md_data import MdData
1717

1818
###############################################################################
@@ -28,6 +28,7 @@ def __init__(
2828
input_data: MdData,
2929
progress_callback: Callable[[float], None] = None,
3030
callback_interval: float = 10,
31+
draw_bonds: bool = False,
3132
):
3233
"""
3334
This object reads simulation trajectory outputs
@@ -49,8 +50,11 @@ def __init__(
4950
If a progress_callback was provided, the period between updates
5051
to be sent to the callback, in seconds
5152
Default: 10
53+
draw_bonds: bool (optional)
54+
Default: False
5255
"""
5356
super().__init__(input_data, progress_callback, callback_interval)
57+
self.draw_bonds = draw_bonds
5458
self._data = self._read(input_data)
5559

5660
@staticmethod
@@ -166,6 +170,19 @@ def _read_universe(self, input_data: MdData) -> Tuple[AgentData, float]:
166170
Use a MD Universe to get AgentData
167171
"""
168172
dimensions = MdConverter._read_universe_dimensions(input_data)
173+
if input_data.draw_bonds:
174+
bond_indices = input_data.md_universe.bonds.indices
175+
n_bonds = bond_indices.shape[0]
176+
n_max_subpoints = 2 * SUBPOINT_VALUES_PER_ITEM(DISPLAY_TYPE.FIBER)
177+
else:
178+
bond_indices = np.array([])
179+
n_bonds = 0
180+
n_max_subpoints = 0
181+
dimensions = DimensionData(
182+
dimensions.total_steps,
183+
dimensions.max_agents + n_bonds,
184+
n_max_subpoints,
185+
)
169186
result = AgentData.from_dimensions(dimensions)
170187
get_type_name_func = np.frompyfunc(MdConverter._get_type_name, 2, 1)
171188
unique_raw_type_names = set([])
@@ -176,29 +193,68 @@ def _read_universe(self, input_data: MdData) -> Tuple[AgentData, float]:
176193
]:
177194
result.times[time_index] = input_data.md_universe.trajectory.time
178195
atom_positions = input_data.md_universe.atoms.positions
179-
result.n_agents[time_index] = atom_positions.shape[0]
180-
result.unique_ids[time_index] = np.arange(atom_positions.shape[0])
181-
unique_raw_type_names.update(list(input_data.md_universe.atoms.names))
182-
result.types[time_index] = get_type_name_func(
196+
if input_data.draw_bonds:
197+
bond_subpoints = np.array(
198+
[
199+
np.concatenate(
200+
[
201+
atom_positions[bond_indices[i][0]],
202+
atom_positions[bond_indices[i][1]],
203+
]
204+
)
205+
for i in range(bond_indices.shape[0])
206+
]
207+
)
208+
else:
209+
bond_subpoints = np.array([])
210+
n_agents = atom_positions.shape[0] + n_bonds
211+
result.n_agents[time_index] = n_agents
212+
result.unique_ids[time_index] = np.arange(n_agents)
213+
type_name_list = list(input_data.md_universe.atoms.names)
214+
if input_data.draw_bonds:
215+
type_name_list += ['bond']
216+
unique_raw_type_names.update(type_name_list)
217+
result.types[time_index][:atom_positions.shape[0]] = get_type_name_func(
183218
input_data.md_universe.atoms.names, input_data
184219
)
185-
result.positions[time_index] = atom_positions
186-
result.radii[time_index] = np.array(
187-
[
220+
if input_data.draw_bonds:
221+
result.types[time_index][atom_positions.shape[0]:] = ['bond'] * n_bonds
222+
result.positions[time_index][:atom_positions.shape[0]] = atom_positions
223+
radii_list = [
224+
MdConverter._get_radius(type_name, input_data)
225+
for type_name in input_data.md_universe.atoms.names
226+
]
227+
if input_data.draw_bonds:
228+
radii_list += [
188229
MdConverter._get_radius(type_name, input_data)
189-
for type_name in input_data.md_universe.atoms.names
230+
for type_name in ['bond'] * n_bonds
190231
]
191-
)
232+
result.radii[time_index] = np.array(radii_list)
233+
234+
if input_data.draw_bonds:
235+
result.n_subpoints[time_index] = np.array(
236+
[0] * atom_positions.shape[0]
237+
+ [2 * SUBPOINT_VALUES_PER_ITEM(DISPLAY_TYPE.FIBER)] * n_bonds
238+
)
239+
240+
result.subpoints[time_index][
241+
atom_positions.shape[0]:] = bond_subpoints
242+
result.viz_types[time_index][
243+
atom_positions.shape[0]:] = [VIZ_TYPE.FIBER] * n_bonds
244+
192245
time_index += 1
193246
self.check_report_progress(time_index / dimensions.total_steps)
194247

195248
result.n_timesteps = dimensions.total_steps
196249
result.display_data = MdConverter._get_display_data_mapping(
197250
unique_raw_type_names, input_data
198251
)
199-
return TrajectoryConverter.scale_agent_data(
252+
253+
result, scale_factor = TrajectoryConverter.scale_agent_data(
200254
result, input_data.meta_data.scale_factor
201255
)
256+
result = TrajectoryConverter.center_fiber_positions(result)
257+
return result, scale_factor
202258

203259
def _read(self, input_data: MdData) -> TrajectoryData:
204260
"""

simulariumio/md/md_data.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class MdData:
2222
display_data: Dict[str, DisplayData]
2323
time_units: UnitData
2424
spatial_units: UnitData
25+
draw_bonds: bool
2526
plots: List[Dict[str, Any]]
2627

2728
def __init__(
@@ -32,6 +33,7 @@ def __init__(
3233
display_data: Dict[str, DisplayData] = None,
3334
time_units: UnitData = None,
3435
spatial_units: UnitData = None,
36+
draw_bonds: bool = False,
3537
plots: List[Dict[str, Any]] = None,
3638
):
3739
"""
@@ -66,6 +68,9 @@ def __init__(
6668
multiplier and unit name for spatial values
6769
(including positions, radii, and box size)
6870
Default: 1.0 meter
71+
draw_bonds: bool (optional)
72+
Draw bonds between atoms?
73+
Default: False
6974
plots : List[Dict[str, Any]] (optional)
7075
An object containing plot data already
7176
in Simularium format
@@ -78,4 +83,5 @@ def __init__(
7883
self.spatial_units = (
7984
spatial_units if spatial_units is not None else UnitData("m")
8085
)
86+
self.draw_bonds = draw_bonds
8187
self.plots = plots if plots is not None else []

simulariumio/tests/converters/test_md_converter.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,3 +398,84 @@ def test_callback_fn():
398398
assert call_value > last_call_val
399399
assert call_value <= 1.0
400400
last_call_val = call_value
401+
402+
403+
results_with_bonds = JsonWriter.format_trajectory_data(
404+
MdConverter(
405+
MdData(
406+
md_universe=Universe(
407+
"simulariumio/tests/data/md/example.xyz",
408+
guess_bonds=True,
409+
vdwradii={"T": 90.0},
410+
),
411+
draw_bonds=True,
412+
)
413+
)._data
414+
)
415+
416+
first_frame_data_with_bond = [
417+
VIZ_TYPE.DEFAULT, # first atom
418+
0.0,
419+
0.0,
420+
17.455860176109592, # x
421+
-9.282322881315764, # y
422+
25.164461444822095, # z
423+
0.0,
424+
0.0,
425+
0.0,
426+
0.410577680081599,
427+
0.0,
428+
VIZ_TYPE.DEFAULT, # second atom
429+
1.0,
430+
0.0,
431+
0.0, # x
432+
0.0, # y
433+
0.0, # z
434+
0.0,
435+
0.0,
436+
0.0,
437+
0.410577680081599,
438+
0.0,
439+
VIZ_TYPE.DEFAULT, # third atom
440+
2.0,
441+
1.0,
442+
-20.86782234438291, # x
443+
31.832387751607193, # y
444+
6.484507904672457, # z
445+
0.0,
446+
0.0,
447+
0.0,
448+
0.4516354480897589,
449+
0.0,
450+
VIZ_TYPE.FIBER, # bond agent
451+
3.0,
452+
2.0,
453+
8.727930088054796, # x
454+
-4.641161440657882, # y
455+
12.582230722411047, # z
456+
0.0,
457+
0.0,
458+
0.0,
459+
0.7883091457566701,
460+
6.0,
461+
8.727930088054796, # start x (first atom)
462+
-4.641161440657882, # start y
463+
12.582230722411047, # start z
464+
-8.727930088054796, # end x (second atom)
465+
4.641161440657882, # end y
466+
-12.582230722411047, # end z
467+
]
468+
469+
470+
# test draw bonds
471+
@pytest.mark.parametrize(
472+
"bundleData, expected_bundleData",
473+
[
474+
(
475+
results_with_bonds["spatialData"]["bundleData"][0],
476+
first_frame_data_with_bond,
477+
)
478+
],
479+
)
480+
def test_draw_bonds(bundleData, expected_bundleData):
481+
assert np.isclose(expected_bundleData, bundleData["data"]).all()

0 commit comments

Comments
 (0)