1
+ import numpy as np
2
+ import pickle
3
+ import re
4
+ import os
5
+ from tqdm import tqdm
6
+ from scipy .optimize import curve_fit
7
+ from sklearn .metrics import r2_score
8
+ from math import pi
9
+ import multiprocessing as mp
10
+ from multiprocessing import Pool
11
+
12
+ def calculate_rel (r2 , preds ):
13
+ rel = r2 * np .sum ((preds - np .mean (preds ))** 2 ) / len (preds )
14
+ return rel
15
+
16
+ def run_optimization (x_data , y_data , bounds , n_iterations = 10 ):
17
+ best_params = None
18
+ best_r2 = - np .inf # Initialize to a very low value
19
+
20
+ for _ in range (n_iterations ):
21
+ # Random initial parameters within the bounds
22
+ initial_params = [np .random .uniform (low , high ) for low , high in zip (bounds [0 ], bounds [1 ])]
23
+
24
+ try :
25
+ # Fit the model
26
+ fit_params , _ = curve_fit (model , x_data , y_data , p0 = initial_params , bounds = bounds , nan_policy = 'omit' )
27
+ preds = model (x_data , * fit_params )
28
+ # Calculate R² value
29
+ try :
30
+ r2 = r2_score (y_data , preds )
31
+ except :
32
+ r2 = np .nan
33
+ # print(np.round(r2, 2))
34
+ # Calculate Rel value
35
+ try :
36
+ rel = calculate_rel (r2 , preds )
37
+ except :
38
+ rel = np .nan
39
+
40
+ # Keep the best fit
41
+ if not np .isnan (r2 ):
42
+ if r2 > best_r2 :
43
+ best_r2 = r2
44
+ best_params = fit_params
45
+ best_rel = calculate_rel (r2 , preds )
46
+ else :
47
+ best_r2 = r2
48
+ best_params = fit_params
49
+ best_rel = rel
50
+
51
+ except RuntimeError :
52
+ # Handle the case where the fit doesn't converge
53
+ continue
54
+
55
+ return best_params , best_r2 , best_rel # Return the best R² and corresponding parameters
56
+
57
+
58
+ # Define the model function based on the equation
59
+ def model (x , # vector of time samples
60
+ a_lf , # lf cosine wave amplitude in perfornance units
61
+ b_lf , # lf sine wave amplitude in perfornance units
62
+ omega_lf , # lf wave frequency
63
+ a_hf , # hf cosine wave amplitude in perfornance units
64
+ b_hf , # hf sine wave amplitude in perfornance units
65
+ omega_hf , # hf wave frequency
66
+ c , # constant in performance units
67
+ ):
68
+ return (a_lf * np .cos (2 * pi * omega_lf * x ) + b_lf * np .sin (2 * pi * omega_lf * x ) +
69
+ a_hf * np .cos (2 * pi * omega_hf * x ) + b_hf * np .sin (2 * pi * omega_hf * x ) + c )
70
+
71
+ def fit_curve_bootstr (subjects , n_iterations = 10 , output_var = "similarity" ):
72
+ if output_var == "similarity" :
73
+ bounds = ([- 1 , - 1 , - 2 , - 1 , - 1 , 3 , - 2 ], # Lower bounds
74
+ [1 , 1 , 2 , 1 , 1 , 13 , 2 ])
75
+ elif output_var == "hr" :
76
+ bounds = ([- 0.5 , - 0.5 , - 2 , - 0.5 , - 0.5 , 3 , - 1 ], # Lower bounds
77
+ [0.5 , 0.5 , 2 , 0.5 , 0.5 , 13 , 1 ])
78
+
79
+ for df_var in ["acc" , "congruent" ]:
80
+ for condition in [0 , 1 ]:
81
+ participants_results = {"sub" : [], "results" : []}
82
+ for sub in subjects :
83
+ print (f"Fitting curves to bootstrapped data of sub-{ sub } ..." )
84
+ individual_results = []
85
+ p_matrix = np .load (f"{ npy_save_path } raw/sub-{ sub } _{ output_var } _{ df_var } _{ condition } .npy" )
86
+ for i in tqdm (range (p_matrix .shape [0 ])):
87
+ x_data = p_matrix [i , :, 0 ]
88
+ y_data = p_matrix [i , :, 1 ]
89
+ best_fit_params , best_r2 , best_rel = run_optimization (x_data , y_data , bounds , n_iterations = n_iterations )
90
+ individual_results .append ((best_fit_params , best_r2 , best_rel ))
91
+ participants_results ["sub" ].append (sub )
92
+ participants_results ["results" ].append (individual_results )
93
+
94
+ with open (f"{ npy_save_path } res_curve/{ output_var } _{ df_var } _{ condition } .pkl" , "wb" ) as tf :
95
+ pickle .dump (participants_results , tf )
96
+
97
+ npy_save_path = "/Users/fzaki001/Documents/working-memory-error-dataset/derivatives/face-jitter/behavior/bootstrap/"
98
+ pattern = re .compile (r'sub-(\d+)' )
99
+ subjects = sorted (list (set ([pattern .search (file ).group (1 ) for file in os .listdir (npy_save_path + "raw/" )])))
100
+
101
+ PROCESSES = mp .cpu_count ()
102
+
103
+ if __name__ == "__main__" :
104
+ with Pool (processes = PROCESSES ) as pool :
105
+ pool .map (fit_curve_bootstr , [subjects ])
0 commit comments