1212
1313from ..trajectory_converter import TrajectoryConverter
1414from ..data_objects import TrajectoryData , AgentData , DimensionData , DisplayData
15- from ..constants import DISPLAY_TYPE , JMOL_COLORS
15+ from ..constants import DISPLAY_TYPE , JMOL_COLORS , VIZ_TYPE , SUBPOINT_VALUES_PER_ITEM
1616from .md_data import MdData
1717
1818###############################################################################
@@ -28,6 +28,7 @@ def __init__(
2828 input_data : MdData ,
2929 progress_callback : Callable [[float ], None ] = None ,
3030 callback_interval : float = 10 ,
31+ draw_bonds : bool = False ,
3132 ):
3233 """
3334 This object reads simulation trajectory outputs
@@ -49,8 +50,11 @@ def __init__(
4950 If a progress_callback was provided, the period between updates
5051 to be sent to the callback, in seconds
5152 Default: 10
53+ draw_bonds: bool (optional)
54+ Default: False
5255 """
5356 super ().__init__ (input_data , progress_callback , callback_interval )
57+ self .draw_bonds = draw_bonds
5458 self ._data = self ._read (input_data )
5559
5660 @staticmethod
@@ -166,6 +170,19 @@ def _read_universe(self, input_data: MdData) -> Tuple[AgentData, float]:
166170 Use a MD Universe to get AgentData
167171 """
168172 dimensions = MdConverter ._read_universe_dimensions (input_data )
173+ if input_data .draw_bonds :
174+ bond_indices = input_data .md_universe .bonds .indices
175+ n_bonds = bond_indices .shape [0 ]
176+ n_max_subpoints = 2 * SUBPOINT_VALUES_PER_ITEM (DISPLAY_TYPE .FIBER )
177+ else :
178+ bond_indices = np .array ([])
179+ n_bonds = 0
180+ n_max_subpoints = 0
181+ dimensions = DimensionData (
182+ dimensions .total_steps ,
183+ dimensions .max_agents + n_bonds ,
184+ n_max_subpoints ,
185+ )
169186 result = AgentData .from_dimensions (dimensions )
170187 get_type_name_func = np .frompyfunc (MdConverter ._get_type_name , 2 , 1 )
171188 unique_raw_type_names = set ([])
@@ -176,29 +193,68 @@ def _read_universe(self, input_data: MdData) -> Tuple[AgentData, float]:
176193 ]:
177194 result .times [time_index ] = input_data .md_universe .trajectory .time
178195 atom_positions = input_data .md_universe .atoms .positions
179- result .n_agents [time_index ] = atom_positions .shape [0 ]
180- result .unique_ids [time_index ] = np .arange (atom_positions .shape [0 ])
181- unique_raw_type_names .update (list (input_data .md_universe .atoms .names ))
182- result .types [time_index ] = get_type_name_func (
196+ if input_data .draw_bonds :
197+ bond_subpoints = np .array (
198+ [
199+ np .concatenate (
200+ [
201+ atom_positions [bond_indices [i ][0 ]],
202+ atom_positions [bond_indices [i ][1 ]],
203+ ]
204+ )
205+ for i in range (bond_indices .shape [0 ])
206+ ]
207+ )
208+ else :
209+ bond_subpoints = np .array ([])
210+ n_agents = atom_positions .shape [0 ] + n_bonds
211+ result .n_agents [time_index ] = n_agents
212+ result .unique_ids [time_index ] = np .arange (n_agents )
213+ type_name_list = list (input_data .md_universe .atoms .names )
214+ if input_data .draw_bonds :
215+ type_name_list += ['bond' ]
216+ unique_raw_type_names .update (type_name_list )
217+ result .types [time_index ][:atom_positions .shape [0 ]] = get_type_name_func (
183218 input_data .md_universe .atoms .names , input_data
184219 )
185- result .positions [time_index ] = atom_positions
186- result .radii [time_index ] = np .array (
187- [
220+ if input_data .draw_bonds :
221+ result .types [time_index ][atom_positions .shape [0 ]:] = ['bond' ] * n_bonds
222+ result .positions [time_index ][:atom_positions .shape [0 ]] = atom_positions
223+ radii_list = [
224+ MdConverter ._get_radius (type_name , input_data )
225+ for type_name in input_data .md_universe .atoms .names
226+ ]
227+ if input_data .draw_bonds :
228+ radii_list += [
188229 MdConverter ._get_radius (type_name , input_data )
189- for type_name in input_data . md_universe . atoms . names
230+ for type_name in [ 'bond' ] * n_bonds
190231 ]
191- )
232+ result .radii [time_index ] = np .array (radii_list )
233+
234+ if input_data .draw_bonds :
235+ result .n_subpoints [time_index ] = np .array (
236+ [0 ] * atom_positions .shape [0 ]
237+ + [2 * SUBPOINT_VALUES_PER_ITEM (DISPLAY_TYPE .FIBER )] * n_bonds
238+ )
239+
240+ result .subpoints [time_index ][
241+ atom_positions .shape [0 ]:] = bond_subpoints
242+ result .viz_types [time_index ][
243+ atom_positions .shape [0 ]:] = [VIZ_TYPE .FIBER ] * n_bonds
244+
192245 time_index += 1
193246 self .check_report_progress (time_index / dimensions .total_steps )
194247
195248 result .n_timesteps = dimensions .total_steps
196249 result .display_data = MdConverter ._get_display_data_mapping (
197250 unique_raw_type_names , input_data
198251 )
199- return TrajectoryConverter .scale_agent_data (
252+
253+ result , scale_factor = TrajectoryConverter .scale_agent_data (
200254 result , input_data .meta_data .scale_factor
201255 )
256+ result = TrajectoryConverter .center_fiber_positions (result )
257+ return result , scale_factor
202258
203259 def _read (self , input_data : MdData ) -> TrajectoryData :
204260 """
0 commit comments