Skip to content

Commit 30e0ea5

Browse files
committed
Updated code
1 parent 9ba2608 commit 30e0ea5

File tree

6 files changed

+36544
-17
lines changed

6 files changed

+36544
-17
lines changed

code/bootstrap_func.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import os
2+
import pandas as pd
3+
import numpy as np
4+
import scipy.stats as stats
5+
import ast
6+
from scipy.stats import norm, uniform
7+
from scipy.optimize import curve_fit
8+
from sklearn.metrics import r2_score
9+
from math import pi
10+
from itertools import combinations
11+
import re
12+
import multiprocessing as mp
13+
from multiprocessing import Pool
14+
import time
15+
from tqdm import tqdm
16+
# Size of the worker pool created in the __main__ section: one process per
# CPU reported by multiprocessing.cpu_count().
PROCESSES = mp.cpu_count()
17+
18+
19+
def calculate_rel(r2, preds):
    """Scale the mean squared deviation of the predictions by ``r2``.

    Evaluates ``r2 * sum((preds - mean(preds)) ** 2) / len(preds)``.

    Args:
        r2: Goodness-of-fit score used as the multiplier.
        preds: Model predictions; must support ``len`` and NumPy arithmetic.

    Returns:
        The scaled mean squared deviation of ``preds`` around its mean.
    """
    deviations = preds - np.mean(preds)
    sum_of_squares = np.sum(deviations ** 2)
    # Multiply before dividing, matching the original evaluation order.
    return r2 * sum_of_squares / len(preds)
22+
23+
def run_optimization(x_data, y_data, bounds, n_iterations=10):
    """Fit ``model`` from several random starting points and keep the best fit.

    Each restart draws initial parameters uniformly inside ``bounds``, runs
    ``curve_fit`` and scores the result with R². Restarts that fail to
    converge are skipped. A restart whose R² cannot be computed never
    replaces a fit that was scored successfully.

    Args:
        x_data: Sample positions passed to ``model``.
        y_data: Observed values to fit.
        bounds: ``(lower, upper)`` pair of per-parameter sequences, forwarded
            to ``curve_fit``.
        n_iterations: Number of random restarts.

    Returns:
        Tuple ``(best_params, best_r2, best_rel)``.

        * If at least one restart was scored: the parameters, R² and Rel of
          the restart with the highest R².
        * If restarts converged but none could be scored: the parameters of
          the first such restart with ``best_r2`` and ``best_rel`` set to NaN.
        * If no restart converged: ``(None, -np.inf, np.nan)``.
    """
    best_params = None
    best_r2 = -np.inf  # Any real score beats this
    best_rel = np.nan  # Bound up front so the return can never hit an unset name

    for _ in range(n_iterations):
        # Random initial parameters within the bounds
        initial_params = [
            np.random.uniform(low, high)
            for low, high in zip(bounds[0], bounds[1])
        ]

        try:
            fit_params, _cov = curve_fit(
                model,
                x_data,
                y_data,
                p0=initial_params,
                bounds=bounds,
                nan_policy='omit',
            )
        except RuntimeError:
            # The fit did not converge from this starting point
            continue

        preds = model(x_data, *fit_params)

        try:
            r2 = r2_score(y_data, preds)
        except ValueError:
            # e.g. NaN/inf in the inputs: the fit exists but cannot be scored
            r2 = np.nan

        if np.isnan(r2):
            # Keep an unscorable fit only as a fallback when nothing else has
            # been recorded; it must never displace a scored fit.
            if best_params is None:
                best_params, best_r2, best_rel = fit_params, np.nan, np.nan
            continue

        if best_params is None or np.isnan(best_r2) or r2 > best_r2:
            best_params = fit_params
            best_r2 = r2
            best_rel = calculate_rel(r2, preds)

    return best_params, best_r2, best_rel
63+
64+
65+
# Define the model function based on the equation
66+
def model(x, a_lf, b_lf, omega_lf, a_hf, b_hf, omega_hf, c):
    """Evaluate a low-frequency plus high-frequency sinusoid with an offset.

    Args:
        x: Vector of time samples.
        a_lf: LF cosine wave amplitude, in performance units.
        b_lf: LF sine wave amplitude, in performance units.
        omega_lf: LF wave frequency.
        a_hf: HF cosine wave amplitude, in performance units.
        b_hf: HF sine wave amplitude, in performance units.
        omega_hf: HF wave frequency.
        c: Constant offset, in performance units.

    Returns:
        ``a_lf*cos(2*pi*omega_lf*x) + b_lf*sin(2*pi*omega_lf*x)
        + a_hf*cos(2*pi*omega_hf*x) + b_hf*sin(2*pi*omega_hf*x) + c``.
    """
    lf_phase = 2*pi*omega_lf * x
    hf_phase = 2*pi*omega_hf * x

    # Accumulate term by term, left to right, exactly as the one-line sum did.
    total = a_lf * np.cos(lf_phase)
    total = total + b_lf * np.sin(lf_phase)
    total = total + a_hf * np.cos(hf_phase)
    total = total + b_hf * np.sin(hf_phase)
    return total + c
77+
78+
def bootstrap_individual_stats(agg_dfs, n_bootstraps=10, n_iterations=10, *,
                               response_col=None, param_bounds=None):
    """Bootstrap the model fit separately for every participant.

    For each participant table the ``binned_jitter`` column (x) and the
    response column (y) are resampled with replacement ``n_bootstraps``
    times, and each resample is fitted with ``run_optimization``.

    Args:
        agg_dfs: Iterable of per-participant tables; each must support
            ``table[column].to_numpy()`` and contain a ``binned_jitter``
            column plus the response column.
        n_bootstraps: Number of bootstrap resamples per participant.
        n_iterations: Random restarts per fit, forwarded to
            ``run_optimization``.
        response_col: Name of the response column. Defaults to the
            module-level ``output_var``.
        param_bounds: ``(lower, upper)`` parameter bounds forwarded to
            ``run_optimization``. Defaults to the module-level ``bounds``.

    Returns:
        A list with one entry per participant; each entry is a list of
        ``n_bootstraps`` tuples ``(best_fit_params, best_r2, best_rel)``.

    Raises:
        NameError: If an optional setting is omitted and the matching
            module-level global has not been defined.
    """
    # Fall back to the module-level settings so existing calls that rely on
    # the globals set by the script section keep working unchanged.
    if response_col is None:
        response_col = output_var
    if param_bounds is None:
        param_bounds = bounds

    bootstrap_results = []

    for p_data in tqdm(agg_dfs):
        y_data = p_data[response_col].to_numpy()
        x_data = p_data['binned_jitter'].to_numpy()
        n_samples = len(x_data)

        participant_results = []
        for _ in tqdm(range(n_bootstraps)):
            # Resample (x, y) pairs with replacement, keeping pairs aligned
            bootstrap_indices = np.random.choice(
                np.arange(n_samples), size=n_samples, replace=True
            )
            x_bootstrap = x_data[bootstrap_indices]
            y_bootstrap = y_data[bootstrap_indices]

            # Run the optimization on the bootstrapped data
            best_fit_params, best_r2, best_rel = run_optimization(
                x_bootstrap, y_bootstrap, param_bounds, n_iterations=n_iterations
            )
            participant_results.append((best_fit_params, best_r2, best_rel))

        bootstrap_results.append(participant_results)

    return bootstrap_results
102+
103+
104+
def _init_worker(response_name, param_bounds):
    """Pool initializer: publish the run settings inside a worker process.

    ``bootstrap_individual_stats`` falls back to the module-level
    ``output_var`` and ``bounds``. Workers started with the ``spawn`` method
    re-import this module without running the ``__main__`` section, so those
    globals have to be set explicitly in every worker.
    """
    global output_var, bounds
    output_var = response_name
    bounds = param_bounds
    # Forked workers inherit the parent's NumPy RNG state; reseed from OS
    # entropy so the processes do not draw identical bootstrap samples.
    np.random.seed()


def _bootstrap_participant(p_data):
    """Bootstrap one participant; defined at top level so Pool can pickle it."""
    return bootstrap_individual_stats([p_data])[0]


if __name__ == "__main__":
    # first way, using multiprocessing
    output_var = "similarity"
    # Define bounds for the parameters
    if output_var == "similarity":
        bounds = ([-1, -1, -2, -1, -1, 3, -2],  # Lower bounds
                  [1, 1, 2, 1, 1, 13, 2])       # Upper bounds
    elif output_var == "hr":
        bounds = ([-0.5, -0.5, -2, -0.5, -0.5, 3, -1],  # Lower bounds
                  [0.5, 0.5, 2, 0.5, 0.5, 13, 1])       # Upper bounds

    # NOTE: the input location is a hard-coded, machine-specific path.
    agg_dfs = [
        pd.read_csv(f"/Users/fzaki001/Documents/DA/wme-face-jitter/{i}.csv")
        for i in range(20)
    ]

    start_time = time.perf_counter()
    with Pool(processes=PROCESSES,
              initializer=_init_worker,
              initargs=(output_var, bounds)) as pool:
        # One task per participant, so the pool actually does the work;
        # the result has one list of bootstrap tuples per participant.
        result = pool.map(_bootstrap_participant, agg_dfs)
    finish_time = time.perf_counter()
    print("Program finished in {} seconds - using multiprocessing".format(finish_time-start_time))
    print("---")

    # second way, serial computation
    start_time = time.perf_counter()
    result = []
    for x in agg_dfs:
        result.append(bootstrap_individual_stats([x]))
    finish_time = time.perf_counter()
    print("Program finished in {} seconds".format(finish_time-start_time))

0 commit comments

Comments
 (0)