Skip to content

Commit 7402f35

Browse files
committed
Added bootstrapping code
1 parent 1138917 commit 7402f35

File tree

2 files changed

+136
-0
lines changed

2 files changed

+136
-0
lines changed
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#!/bin/bash
#SBATCH --job-name=wme-bootstrap # create a short name for your job
#SBATCH --nodes=1 # node count
#SBATCH --ntasks=5 # total number of tasks across all nodes
#SBATCH --cpus-per-task=1
#SBATCH --time=24:00:00 # total run time limit (HH:MM:SS)
#SBATCH --mem=20G
#SBATCH --partition=highmem1
#SBATCH --qos=highmem1
#SBATCH --account=iacc_gbuzzell
#SBATCH --output=%x-%j.out
#SBATCH --mail-type=end # send email when job ends

# Runs para_bootstr_curve.py inside the wme-env conda environment, prints
# some job diagnostics, then scans this job's own output file for "Error".

# `conda activate` is a shell function that only exists once conda's shell
# hook has been loaded; a batch shell may not have it, so load it explicitly.
eval "$(conda shell.bash hook)"
conda activate /home/fzaki001/.conda/envs/wme-env || {
    echo "Error: could not activate conda environment wme-env"
    exit 1
}

# The environment is active, so `python` resolves to the environment's
# interpreter and its output streams straight into the job log.
python para_bootstr_curve.py
status=$?

pwd; hostname; date
echo "slurm cpus per task: $SLURM_CPUS_PER_TASK"
printenv

# The file name mirrors the --output=%x-%j.out pattern requested above.
errors=$(grep "Error" "${SLURM_JOB_NAME}-${SLURM_JOB_ID}.out")
if [[ ${status} -eq 0 && -z ${errors} ]]; then
    echo "Behavior processing complete."
else
    echo "Behavior processing exited with errors: ${errors}"
    echo "python exit status: ${status}"
fi

# Propagate the Python exit status so the scheduler records failed runs.
exit "${status}"
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import numpy as np
2+
import pickle
3+
import re
4+
import os
5+
from tqdm import tqdm
6+
from scipy.optimize import curve_fit
7+
from sklearn.metrics import r2_score
8+
from math import pi
9+
import multiprocessing as mp
10+
from multiprocessing import Pool
11+
12+
def calculate_rel(r2, preds):
    """Return ``r2`` scaled by the mean squared deviation of ``preds``.

    The squared deviations of ``preds`` from their own mean are summed,
    multiplied by ``r2`` and divided by the number of predictions.
    """
    centered = preds - np.mean(preds)
    sum_of_squares = np.sum(centered ** 2)
    return r2 * sum_of_squares / len(preds)
15+
16+
def run_optimization(x_data, y_data, bounds, n_iterations=10):
    """Fit ``model`` from several random starting points and keep the best fit.

    Each iteration draws a starting point uniformly inside ``bounds``, runs
    ``curve_fit`` on ``model`` and scores the fit with ``r2_score``.

    Args:
        x_data: sample positions passed to ``model`` and ``curve_fit``.
        y_data: observed values to fit.
        bounds: ``(lower, upper)`` pair of per-parameter bound sequences.
        n_iterations: number of random restarts.

    Returns:
        ``(best_params, best_r2, best_rel)`` for the fit with the highest R².
        A fit whose R² is NaN is reported only if no fit with a usable R²
        was found. If no fit converged at all the result is
        ``(None, -inf, nan)``.
    """
    best_params = None
    best_r2 = -np.inf  # any real R² beats this
    best_rel = np.nan  # defined up front so the return never hits an unbound name

    lower, upper = bounds[0], bounds[1]
    for _ in range(n_iterations):
        # Random initial parameters within the bounds
        initial_params = [np.random.uniform(lo, hi) for lo, hi in zip(lower, upper)]

        try:
            fit_params, _ = curve_fit(model, x_data, y_data, p0=initial_params,
                                      bounds=bounds, nan_policy='omit')
        except RuntimeError:
            # The fit did not converge from this starting point; try another.
            continue

        preds = model(x_data, *fit_params)
        try:
            r2 = r2_score(y_data, preds)
        except (ValueError, TypeError):
            # NOTE(review): presumably raised when the inputs contain NaN,
            # which nan_policy='omit' tolerates in the fit itself - confirm.
            r2 = np.nan

        if np.isnan(r2):
            # An unscorable fit must never displace a scored one, and must
            # not poison later ``r2 > best_r2`` comparisons; keep it only
            # as a fallback while nothing better exists.
            if best_params is None or np.isnan(best_r2):
                best_params, best_r2, best_rel = fit_params, r2, np.nan
            continue

        if np.isnan(best_r2) or r2 > best_r2:
            best_params = fit_params
            best_r2 = r2
            best_rel = calculate_rel(r2, preds)

    return best_params, best_r2, best_rel
56+
57+
58+
# Define the model function based on the equation
59+
def model(x, a_lf, b_lf, omega_lf, a_hf, b_hf, omega_hf, c):
    """Evaluate the sum of a low- and a high-frequency sinusoid plus a constant.

    Args:
        x: vector of time samples.
        a_lf: lf cosine wave amplitude in performance units.
        b_lf: lf sine wave amplitude in performance units.
        omega_lf: lf wave frequency.
        a_hf: hf cosine wave amplitude in performance units.
        b_hf: hf sine wave amplitude in performance units.
        omega_hf: hf wave frequency.
        c: constant in performance units.

    Returns:
        The model value at every sample of ``x``.
    """
    lf_phase = 2*pi*omega_lf * x
    hf_phase = 2*pi*omega_hf * x

    lf_cos = a_lf * np.cos(lf_phase)
    lf_sin = b_lf * np.sin(lf_phase)
    hf_cos = a_hf * np.cos(hf_phase)
    hf_sin = b_hf * np.sin(hf_phase)

    return lf_cos + lf_sin + hf_cos + hf_sin + c
70+
71+
def fit_curve_bootstr(subjects, n_iterations=10, output_var="similarity"):
    """Fit the two-sinusoid model to every bootstrap sample of every subject.

    For each combination of ``df_var`` ("acc", "congruent") and ``condition``
    (0, 1), the function loads
    ``{npy_save_path}raw/sub-{sub}_{output_var}_{df_var}_{condition}.npy``
    for each subject. Each entry along the first axis is one fit: column 0
    of the last axis is used as x and column 1 as y. The per-subject results
    are pickled to
    ``{npy_save_path}res_curve/{output_var}_{df_var}_{condition}.pkl``.

    Args:
        subjects: subject identifiers used to build the input file names.
        n_iterations: random restarts per fit, forwarded to
            ``run_optimization``.
        output_var: "similarity" or "hr"; selects the parameter bounds and
            is part of the input and output file names.

    Raises:
        ValueError: if ``output_var`` is not one of the supported values.
    """
    # (lower bounds, upper bounds) for the seven model parameters.
    bounds_by_output = {
        "similarity": ([-1, -1, -2, -1, -1, 3, -2],
                       [1, 1, 2, 1, 1, 13, 2]),
        "hr": ([-0.5, -0.5, -2, -0.5, -0.5, 3, -1],
               [0.5, 0.5, 2, 0.5, 0.5, 13, 1]),
    }
    # Fail before any file is read instead of hitting an unbound ``bounds``
    # halfway through the first subject.
    if output_var not in bounds_by_output:
        raise ValueError(
            f"Unsupported output_var {output_var!r}; "
            f"expected one of {sorted(bounds_by_output)}"
        )
    bounds = bounds_by_output[output_var]

    for df_var in ["acc", "congruent"]:
        for condition in [0, 1]:
            participants_results = {"sub": [], "results": []}
            for sub in subjects:
                print(f"Fitting curves to bootstrapped data of sub-{sub}...")
                p_matrix = np.load(f"{npy_save_path}raw/sub-{sub}_{output_var}_{df_var}_{condition}.npy")
                individual_results = []
                for i in tqdm(range(p_matrix.shape[0])):
                    x_data = p_matrix[i, :, 0]
                    y_data = p_matrix[i, :, 1]
                    # run_optimization returns (best_params, best_r2, best_rel)
                    individual_results.append(
                        run_optimization(x_data, y_data, bounds, n_iterations=n_iterations)
                    )
                participants_results["sub"].append(sub)
                participants_results["results"].append(individual_results)

            with open(f"{npy_save_path}res_curve/{output_var}_{df_var}_{condition}.pkl", "wb") as tf:
                pickle.dump(participants_results, tf)
96+
97+
# Root directory of the bootstrap data; fit_curve_bootstr reads "raw/" and
# writes "res_curve/" below it.
npy_save_path = "/Users/fzaki001/Documents/working-memory-error-dataset/derivatives/face-jitter/behavior/bootstrap/"
pattern = re.compile(r'sub-(\d+)')

# Collect the unique subject ids found in the raw directory. Entries whose
# name carries no "sub-<digits>" tag are skipped instead of crashing on a
# failed match.
_subject_matches = (pattern.search(file) for file in os.listdir(npy_save_path + "raw/"))
subjects = sorted({match.group(1) for match in _subject_matches if match})

PROCESSES = mp.cpu_count()

if __name__ == "__main__":
    with Pool(processes=PROCESSES) as pool:
        # NOTE(review): the iterable holds a single item (the whole subject
        # list), so exactly one task is submitted and the fits run in one
        # worker. Splitting the list here is not enough on its own, because
        # fit_curve_bootstr writes one pickle per (df_var, condition) and
        # concurrent calls would overwrite each other's files.
        pool.map(fit_curve_bootstr, [subjects])

0 commit comments

Comments
 (0)