Skip to content

Commit 195dc6f

Browse files
committed
Add standalone RMSD/Rg/RMSF scripts and fix analysis interval/output handling
1 parent 078af5e commit 195dc6f

5 files changed

Lines changed: 521 additions & 8 deletions

File tree

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
#!/usr/bin/env python3
2+
"""
3+
openmm_rg.py
4+
5+
Compute radius of gyration and save:
6+
- rg.csv
7+
- rg_vs_time.png
8+
"""
9+
10+
import argparse
11+
import datetime
12+
import glob
13+
import os
14+
import re
15+
import sys
16+
17+
import matplotlib.pyplot as plt
18+
import numpy as np
19+
import MDAnalysis as mda
20+
21+
22+
def parse_args():
    """Parse the command line of the radius-of-gyration script.

    Returns:
        argparse.Namespace with the attributes ``simdir``, ``topology``,
        ``trajectory``, ``interval``, ``smooth_ps``, ``plot_stride`` and
        ``outdir``.
    """
    # (flags, keyword settings) in the order they should appear in --help.
    option_table = (
        (("simdir",), {"help": "Simulation directory (contains solvated.pdb, prod_full.dcd)"}),
        (("-t", "--topology"), {"default": None}),
        (("-x", "--trajectory"), {"default": None}),
        (("-i", "--interval"), {"type": float, "default": None, "help": "Frame interval in ps"}),
        (("--smooth-ps",), {"type": float, "default": 25.0, "help": "Moving-average window in ps"}),
        (("--plot-stride",), {"type": int, "default": 5, "help": "Stride for raw line plotting"}),
        (("-o", "--outdir"), {"default": None}),
    )
    parser = argparse.ArgumentParser(description="Compute radius of gyration")
    for flags, settings in option_table:
        parser.add_argument(*flags, **settings)
    return parser.parse_args()
32+
33+
34+
def _data_lines(path):
    """Yield the stripped lines of *path* that begin with a digit.

    Blank lines and lines whose first non-blank character is not a digit
    (for example a quoted or ``#``-prefixed header row) are skipped.
    Undecodable bytes are ignored rather than raising.
    """
    with open(path, "r", encoding="utf-8", errors="ignore") as handle:
        stripped = (raw.strip() for raw in handle)
        # text[:1] is "" for an empty line, and "".isdigit() is False.
        yield from (text for text in stripped if text[:1].isdigit())
40+
41+
42+
def _second_field_as_float(row):
    """Return the second field of a log row as a float, or None if unusable.

    Whitespace-delimited rows are tried first (the historical format), then
    comma-delimited rows, so both kinds of log are understood.
    """
    for fields in (row.split(), row.split(",")):
        if len(fields) < 2:
            continue
        try:
            return float(fields[1])
        except ValueError:
            continue
    return None


def detect_interval_from_logs(simdir, n_frames):
    """Estimate the frame interval (in ps) from production logs in *simdir*.

    Two strategies are tried in order:

    1. ``prod_full.log`` (only when ``n_frames > 1``): the second column of
       every data line is read as a time value and the mean spacing between
       the first and the last value is used.
    2. Chunk logs named ``prod_<start>to<end>ps.log``: the time spans encoded
       in the file names are summed and divided by the total number of data
       lines found in those logs.

    Args:
        simdir: Directory that is searched for the log files.
        n_frames: Number of trajectory frames; the merged log is only
            consulted when this is greater than one.

    Returns:
        ``(interval, source)`` where *source* names the strategy that
        succeeded, or ``(None, None)`` when no positive interval was found.
    """
    if n_frames > 1:
        merged = os.path.join(simdir, "prod_full.log")
        if os.path.isfile(merged):
            times = [
                value
                for value in map(_second_field_as_float, _data_lines(merged))
                if value is not None
            ]
            if len(times) > 1:
                dt = (times[-1] - times[0]) / (len(times) - 1)
                if dt > 0:
                    return dt, "prod_full.log"

    chunk_name = re.compile(r"prod_(\d+)to(\d+)ps\.log$")
    total_ps, total_rows = 0.0, 0
    for log_path in glob.glob(os.path.join(simdir, "prod_*to*ps.log")):
        match = chunk_name.search(os.path.basename(log_path))
        if not match:
            continue
        # A reversed range in a file name contributes nothing instead of
        # subtracting time from the total.
        total_ps += max(0.0, float(match.group(2)) - float(match.group(1)))
        total_rows += sum(1 for _ in _data_lines(log_path))

    if total_ps > 0 and total_rows > 0:
        return total_ps / total_rows, "chunk logs/chunk names"
    return None, None
73+
74+
75+
def moving_average(y, win):
    """Return the boxcar moving average of *y* over *win* samples.

    Only fully covered windows are kept, so the result has
    ``len(y) - win + 1`` entries. ``None`` is returned when the window is
    at most one sample or longer than the series.
    """
    n_samples = len(y)
    if not 1 < win <= n_samples:
        return None
    weights = np.full(win, 1.0 / win)
    return np.convolve(y, weights, mode="valid")
80+
81+
82+
def main():
    """Compute the radius of gyration per frame and write a CSV and a plot.

    Steps:
      1. Resolve topology/trajectory paths (defaults: ``solvated.pdb`` and
         ``prod_full.dcd`` inside the simulation directory).
      2. Determine the frame interval: ``--interval``, then the production
         logs, then the trajectory header.
      3. Compute the unweighted (geometric) radius of gyration of the
         protein heavy atoms for every frame, in nm.
      4. Write ``rg.csv`` and ``rg_vs_time.png`` into the output directory.

    Exits with a message when an input file is missing, when no usable atoms
    are selected, or when no positive frame interval can be determined.
    """
    args = parse_args()
    # Fail fast: a zero/negative interval would otherwise divide by zero
    # when the smoothing window is converted from ps to frames.
    if args.interval is not None and not args.interval > 0:
        sys.exit("--interval must be a positive number of picoseconds.")

    simdir = os.path.abspath(args.simdir)
    label = os.path.basename(simdir.rstrip(os.sep)) or "sim"

    topo = os.path.abspath(args.topology or os.path.join(simdir, "solvated.pdb"))
    traj = os.path.abspath(args.trajectory or os.path.join(simdir, "prod_full.dcd"))
    if not os.path.isfile(topo):
        sys.exit(f"Topology not found: {topo}")
    if not os.path.isfile(traj):
        sys.exit(f"Trajectory not found: {traj}")

    outdir = args.outdir or f"analysis_{label}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
    os.makedirs(outdir, exist_ok=True)

    u = mda.Universe(topo, traj)
    n_frames = len(u.trajectory)

    interval = args.interval
    source = "user"
    if interval is None:
        interval, source = detect_interval_from_logs(simdir, n_frames)
    if interval is None:
        try:
            interval = float(u.trajectory.ts.dt)
            source = "DCD header"
        except Exception:
            # Best effort only: any reader problem means "not available".
            interval = None
    if interval is None or not interval > 0:
        sys.exit("Could not infer interval; pass --interval.")

    print(f"Frames={n_frames}, interval={interval:.3f} ps ({source})")

    # The default topology is a solvated system; restrict to the protein so
    # that water oxygens and ions do not dominate the result.
    heavy = u.select_atoms("protein and not name H*")
    if heavy.n_atoms == 0:
        print(
            "Warning: no protein atoms found; using all non-hydrogen atoms.",
            file=sys.stderr,
        )
        heavy = u.select_atoms("not name H*")
    if heavy.n_atoms == 0:
        sys.exit("No non-hydrogen atoms found in the topology.")

    rg_vals = []
    for _ts in u.trajectory:
        coords = heavy.positions
        cog = coords.mean(axis=0)
        # MDAnalysis coordinates are in Angstrom; divide by 10 for nm.
        rg = np.sqrt(((coords - cog) ** 2).sum(axis=1).mean()) / 10.0
        rg_vals.append(rg)
    rg_vals = np.asarray(rg_vals)
    times = np.arange(n_frames) * interval

    csv_path = os.path.join(outdir, "rg.csv")
    np.savetxt(csv_path, np.column_stack([times, rg_vals]), delimiter=",", header="time_ps,rg_nm", comments="")

    stride = max(1, args.plot_stride)
    plt.figure(figsize=(10, 6))
    plt.plot(times[::stride], rg_vals[::stride], lw=0.9, alpha=0.4, label=f"Raw (stride={stride})")

    win = max(1, int(round(args.smooth_ps / interval)))
    y_s = moving_average(rg_vals, win)
    if y_s is not None:
        # Each averaged value is plotted at the time of the last frame in
        # its window.
        t_s = times[win - 1 :]
        plt.plot(t_s, y_s, lw=2.0, label=f"Moving avg ({args.smooth_ps:g} ps)")

    plt.xlabel("Time (ps)")
    plt.ylabel("Radius of Gyration (nm)")
    plt.title("Rg vs. Time")
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.tight_layout()
    fig_path = os.path.join(outdir, "rg_vs_time.png")
    plt.savefig(fig_path, dpi=300)
    plt.close()

    print(f"Saved: {csv_path}")
    print(f"Saved: {fig_path}")


if __name__ == "__main__":
    main()
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
#!/usr/bin/env python3
2+
"""
3+
openmm_rmsd.py
4+
5+
Compute backbone RMSD and save:
6+
- rmsd.csv
7+
- rmsd_vs_time.png
8+
"""
9+
10+
import argparse
11+
import datetime
12+
import glob
13+
import os
14+
import re
15+
import sys
16+
17+
import matplotlib.pyplot as plt
18+
import numpy as np
19+
import MDAnalysis as mda
20+
from MDAnalysis.analysis import rms
21+
22+
23+
def parse_args():
    """Parse the command line of the backbone-RMSD script.

    Returns:
        argparse.Namespace with the attributes ``simdir``, ``topology``,
        ``trajectory``, ``interval``, ``smooth_ps``, ``plot_stride`` and
        ``outdir``.
    """
    # (flags, keyword settings) in the order they should appear in --help.
    option_table = (
        (("simdir",), {"help": "Simulation directory (contains solvated.pdb, prod_full.dcd)"}),
        (("-t", "--topology"), {"default": None}),
        (("-x", "--trajectory"), {"default": None}),
        (("-i", "--interval"), {"type": float, "default": None, "help": "Frame interval in ps"}),
        (("--smooth-ps",), {"type": float, "default": 25.0, "help": "Moving-average window in ps"}),
        (("--plot-stride",), {"type": int, "default": 5, "help": "Stride for raw line plotting"}),
        (("-o", "--outdir"), {"default": None}),
    )
    parser = argparse.ArgumentParser(description="Compute backbone RMSD")
    for flags, settings in option_table:
        parser.add_argument(*flags, **settings)
    return parser.parse_args()
33+
34+
35+
def _data_lines(path):
    """Yield the stripped lines of *path* that begin with a digit.

    Blank lines and lines whose first non-blank character is not a digit
    (for example a quoted or ``#``-prefixed header row) are skipped.
    Undecodable bytes are ignored rather than raising.
    """
    with open(path, "r", encoding="utf-8", errors="ignore") as handle:
        for raw in handle:
            text = raw.strip()
            if not text:
                continue
            if text[0].isdigit():
                yield text
41+
42+
43+
def _second_field_as_float(row):
    """Return the second field of a log row as a float, or None if unusable.

    Whitespace-delimited rows are tried first (the historical format), then
    comma-delimited rows, so both kinds of log are understood.
    """
    for fields in (row.split(), row.split(",")):
        if len(fields) < 2:
            continue
        try:
            return float(fields[1])
        except ValueError:
            continue
    return None


def detect_interval_from_logs(simdir, n_frames):
    """Estimate the frame interval (in ps) from production logs in *simdir*.

    Two strategies are tried in order:

    1. ``prod_full.log`` (only when ``n_frames > 1``): the second column of
       every data line is read as a time value and the mean spacing between
       the first and the last value is used.
    2. Chunk logs named ``prod_<start>to<end>ps.log``: the time spans encoded
       in the file names are summed and divided by the total number of data
       lines found in those logs.

    Args:
        simdir: Directory that is searched for the log files.
        n_frames: Number of trajectory frames; the merged log is only
            consulted when this is greater than one.

    Returns:
        ``(interval, source)`` where *source* names the strategy that
        succeeded, or ``(None, None)`` when no positive interval was found.
    """
    if n_frames > 1:
        merged = os.path.join(simdir, "prod_full.log")
        if os.path.isfile(merged):
            times = [
                value
                for value in map(_second_field_as_float, _data_lines(merged))
                if value is not None
            ]
            if len(times) > 1:
                dt = (times[-1] - times[0]) / (len(times) - 1)
                if dt > 0:
                    return dt, "prod_full.log"

    chunk_name = re.compile(r"prod_(\d+)to(\d+)ps\.log$")
    total_ps, total_rows = 0.0, 0
    for log_path in glob.glob(os.path.join(simdir, "prod_*to*ps.log")):
        match = chunk_name.search(os.path.basename(log_path))
        if not match:
            continue
        # A reversed range in a file name contributes nothing instead of
        # subtracting time from the total.
        total_ps += max(0.0, float(match.group(2)) - float(match.group(1)))
        total_rows += sum(1 for _ in _data_lines(log_path))

    if total_ps > 0 and total_rows > 0:
        return total_ps / total_rows, "chunk logs/chunk names"
    return None, None
74+
75+
76+
def moving_average(y, win):
    """Smooth *y* with a uniform window of *win* samples.

    Returns the averages of all fully covered windows
    (``len(y) - win + 1`` values), or ``None`` when *win* is not in the
    range ``2 .. len(y)``.
    """
    window_fits = 1 < win <= len(y)
    if not window_fits:
        return None
    weights = np.full(win, 1.0 / win)
    return np.convolve(y, weights, mode="valid")
81+
82+
83+
def main():
    """Compute the backbone RMSD per frame and write a CSV and a plot.

    Steps:
      1. Resolve topology/trajectory paths (defaults: ``solvated.pdb`` and
         ``prod_full.dcd`` inside the simulation directory).
      2. Determine the frame interval: ``--interval``, then the production
         logs, then the trajectory header.
      3. Run MDAnalysis ``RMSD`` on the ``backbone`` selection with frame 0
         as the reference and convert the result to nm.
      4. Write ``rmsd.csv`` and ``rmsd_vs_time.png`` into the output
         directory.

    Exits with a message when an input file is missing or when no positive
    frame interval can be determined.
    """
    args = parse_args()
    # Fail fast: a zero/negative interval would otherwise divide by zero
    # when the smoothing window is converted from ps to frames.
    if args.interval is not None and not args.interval > 0:
        sys.exit("--interval must be a positive number of picoseconds.")

    simdir = os.path.abspath(args.simdir)
    label = os.path.basename(simdir.rstrip(os.sep)) or "sim"

    topo = os.path.abspath(args.topology or os.path.join(simdir, "solvated.pdb"))
    traj = os.path.abspath(args.trajectory or os.path.join(simdir, "prod_full.dcd"))
    if not os.path.isfile(topo):
        sys.exit(f"Topology not found: {topo}")
    if not os.path.isfile(traj):
        sys.exit(f"Trajectory not found: {traj}")

    outdir = args.outdir or f"analysis_{label}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
    os.makedirs(outdir, exist_ok=True)

    u = mda.Universe(topo, traj)
    n_frames = len(u.trajectory)

    interval = args.interval
    source = "user"
    if interval is None:
        interval, source = detect_interval_from_logs(simdir, n_frames)
    if interval is None:
        try:
            interval = float(u.trajectory.ts.dt)
            source = "DCD header"
        except Exception:
            # Best effort only: any reader problem means "not available".
            interval = None
    if interval is None or not interval > 0:
        sys.exit("Could not infer interval; pass --interval.")

    print(f"Frames={n_frames}, interval={interval:.3f} ps ({source})")

    calc = rms.RMSD(u, u, select="backbone", ref_frame=0)
    calc.run()
    # Column 2 of the result table is the RMSD in Angstrom; convert to nm.
    rmsd_nm = calc.results.rmsd[:, 2] / 10.0
    times = np.arange(n_frames) * interval

    csv_path = os.path.join(outdir, "rmsd.csv")
    np.savetxt(csv_path, np.column_stack([times, rmsd_nm]), delimiter=",", header="time_ps,rmsd_nm", comments="")

    stride = max(1, args.plot_stride)
    plt.figure(figsize=(10, 6))
    plt.plot(times[::stride], rmsd_nm[::stride], lw=0.9, alpha=0.4, label=f"Raw (stride={stride})")

    win = max(1, int(round(args.smooth_ps / interval)))
    y_s = moving_average(rmsd_nm, win)
    if y_s is not None:
        # Each averaged value is plotted at the time of the last frame in
        # its window.
        t_s = times[win - 1 :]
        plt.plot(t_s, y_s, lw=2.0, label=f"Moving avg ({args.smooth_ps:g} ps)")

    plt.xlabel("Time (ps)")
    plt.ylabel("RMSD (nm)")
    plt.title("Backbone RMSD vs. Time")
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.tight_layout()
    fig_path = os.path.join(outdir, "rmsd_vs_time.png")
    plt.savefig(fig_path, dpi=300)
    plt.close()

    print(f"Saved: {csv_path}")
    print(f"Saved: {fig_path}")


if __name__ == "__main__":
    main()
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
#!/usr/bin/env python3
2+
"""
3+
openmm_rmsf.py
4+
5+
Compute C-alpha RMSF and save:
6+
- rmsf.csv
7+
- rmsf_per_residue.png
8+
"""
9+
10+
import argparse
11+
import datetime
12+
import os
13+
import sys
14+
15+
import matplotlib.pyplot as plt
16+
import numpy as np
17+
import mdtraj as md
18+
19+
20+
def parse_args():
    """Parse the command line of the C-alpha RMSF script.

    Returns:
        argparse.Namespace with the attributes ``simdir``, ``topology``,
        ``trajectory``, ``outdir``, ``ylim``, ``resid_min`` and
        ``resid_max``.
    """
    # (flags, keyword settings) in the order they should appear in --help.
    option_table = (
        (("simdir",), {"help": "Simulation directory (contains solvated.pdb, prod_full.dcd)"}),
        (("-t", "--topology"), {"default": None}),
        (("-x", "--trajectory"), {"default": None}),
        (("-o", "--outdir"), {"default": None}),
        (("--ylim",), {"nargs": 2, "type": float, "default": None, "metavar": ("YMIN", "YMAX")}),
        (("--resid-min",), {"type": int, "default": None, "help": "Optional lower residue bound"}),
        (("--resid-max",), {"type": int, "default": None, "help": "Optional upper residue bound"}),
    )
    parser = argparse.ArgumentParser(description="Compute C-alpha RMSF")
    for flags, settings in option_table:
        parser.add_argument(*flags, **settings)
    return parser.parse_args()
30+
31+
32+
def main():
    """Compute the per-residue C-alpha RMSF and write a CSV and a plot.

    Steps:
      1. Resolve topology/trajectory paths (defaults: ``solvated.pdb`` and
         ``prod_full.dcd`` inside the simulation directory).
      2. Load the trajectory with mdtraj and superpose every frame onto
         frame 0 using the C-alpha atoms.
      3. Compute the RMSF of the C-alpha atoms (mdtraj works in nm, so no
         unit conversion is applied).
      4. Optionally restrict to ``--resid-min`` / ``--resid-max`` and write
         ``rmsf.csv`` and ``rmsf_per_residue.png`` into the output directory.

    Exits with a message when an input file is missing, when the topology
    has no C-alpha atoms, or when the residue window selects nothing.
    """
    args = parse_args()
    simdir = os.path.abspath(args.simdir)
    label = os.path.basename(simdir.rstrip(os.sep)) or "sim"

    topo = os.path.abspath(args.topology or os.path.join(simdir, "solvated.pdb"))
    traj = os.path.abspath(args.trajectory or os.path.join(simdir, "prod_full.dcd"))
    if not os.path.isfile(topo):
        sys.exit(f"Topology not found: {topo}")
    if not os.path.isfile(traj):
        sys.exit(f"Trajectory not found: {traj}")

    outdir = args.outdir or f"analysis_{label}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"
    os.makedirs(outdir, exist_ok=True)

    t = md.load(traj, top=topo)
    ca_idx = t.topology.select("name CA")
    if len(ca_idx) == 0:
        sys.exit("No C-alpha atoms (name CA) found in the topology.")

    # Align on the C-alpha atoms only: fitting on every atom of a solvated
    # system lets the diffusing solvent dominate the superposition.
    t.superpose(t, 0, atom_indices=ca_idx)
    # mdtraj coordinates are already in nm, so the result is used as is.
    rmsf_nm = md.rmsf(t, t, atom_indices=ca_idx)
    resids = np.array([t.topology.atom(i).residue.resSeq for i in ca_idx], dtype=int)

    mask = np.ones_like(resids, dtype=bool)
    if args.resid_min is not None:
        mask &= (resids >= args.resid_min)
    if args.resid_max is not None:
        mask &= (resids <= args.resid_max)
    if not mask.any():
        sys.exit("No residues fall inside the requested --resid-min/--resid-max window.")

    x = resids[mask]
    y = rmsf_nm[mask]

    csv_path = os.path.join(outdir, "rmsf.csv")
    np.savetxt(csv_path, np.column_stack([x, y]), delimiter=",", header="residue,rmsf_nm", comments="")

    plt.figure(figsize=(10, 6))
    plt.plot(x, y, lw=1.5)
    plt.xlabel("Residue")
    plt.ylabel("RMSF (nm)")
    plt.title("Backbone Cα RMSF per Residue")
    plt.grid(True, alpha=0.3)
    if args.ylim is not None:
        plt.ylim(args.ylim[0], args.ylim[1])
    plt.tight_layout()
    fig_path = os.path.join(outdir, "rmsf_per_residue.png")
    plt.savefig(fig_path, dpi=300)
    plt.close()

    print(f"Saved: {csv_path}")
    print(f"Saved: {fig_path}")


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)