From 9e69444e4e583b6715adeb98b62ace048c23c171 Mon Sep 17 00:00:00 2001 From: tot Date: Thu, 10 Aug 2023 17:33:54 +0200 Subject: [PATCH] fix delayed frames --- .gitignore | 1 + __init__.py | 0 capture.py | 218 ++++++++---------------------- experiments/__init__.py | 0 experiments/delayed_frames.py | 84 ++++++++++++ experiments/list_webcams.py | 5 +- experiments/pygame_webcam.py | 37 +++++ main.py | 53 +++++--- modules/dataset.py | 20 +++ modules/eye_position_predictor.py | 10 +- modules/list_webcams.py | 32 +++++ modules/spiral.py | 41 ++++++ prepare.py | 14 +- train.py | 74 +++++----- 14 files changed, 358 insertions(+), 231 deletions(-) create mode 100644 __init__.py create mode 100644 experiments/__init__.py create mode 100644 experiments/delayed_frames.py create mode 100644 experiments/pygame_webcam.py create mode 100644 modules/dataset.py create mode 100644 modules/list_webcams.py create mode 100644 modules/spiral.py diff --git a/.gitignore b/.gitignore index 82f0c3a..e036a95 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ /data/ +/archive/ diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/capture.py b/capture.py index ed464b1..7b4bdba 100644 --- a/capture.py +++ b/capture.py @@ -9,87 +9,8 @@ import datetime import os import numpy as np - - -def list_webcams(): - from collections import defaultdict - import re - import subprocess - - # Command as a list of strings - - completed_process = subprocess.run( - 'v4l2-ctl --list-devices 2>/dev/null', - shell=True, stdout=subprocess.PIPE, text=True - ) - - stdout_output = completed_process.stdout - # print("Stdout Output:") - # print(stdout_output) - - device_info = defaultdict(list) - current_device = "" - - for line in stdout_output.splitlines(): - line = line.strip() - if line: - if re.match(r"^\w+.*:", line): - current_device = line - else: - device_info[current_device].append(line) - - parsed_dict = dict(device_info) - - # print(parsed_dict) - return parsed_dict - - -def track(): - # cam = cv2.VideoCapture(0) - face_mesh = mp.solutions.face_mesh.FaceMesh(refine_landmarks=True) - screen_w, screen_h = pyautogui.size() - - while True: - _, frame = cam2.read() - frame = cv2.flip(frame, 1) - rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) - output = face_mesh.process(rgb_frame) - landmark_points = output.multi_face_landmarks - frame_h, frame_w, _ = frame.shape - if landmark_points: - landmarks = landmark_points[0].landmark - for id, landmark in enumerate(landmarks[474:478]): - x = int(landmark.x * frame_w) - y = int(landmark.y * frame_h) - cv2.circle(frame, (x, y), 3, (0, 255, 0)) - if id == 1: - screen_x = screen_w * landmark.x - screen_y = screen_h * landmark.y - # pyautogui.moveTo(screen_x, screen_y) - left = [landmarks[145], landmarks[159]] - for landmark in left: - x = int(landmark.x * frame_w) - y = int(landmark.y * frame_h) - cv2.circle(frame, (x, y), 3, (0, 255, 255)) - # if (left[0].y - left[1].y) < 0.004: - # pyautogui.click() - - # pyautogui.sleep(1) - cv2.imshow('Eye Controlled Mouse', frame) - cv2.waitKey(1) - - -def on_click(x, y, button, pressed): - should_handle = pressed and y < 1439 - # if not should_handle: - # return - - print("Mouse clicked", x, y, button, pressed) - # _, frame1 = cam1.read() - # _, frame2 = cam2.read() - # print(frame1.empty()) - # cv2.imwrite(f'./data/{time.time()} re {x} {y}.jpg', frame1) - # cv2.imwrite(f'./data/{time.time()} ir {x} {y}.jpg', frame2) +from modules.list_webcams import list_webcams +from modules.spiral import spiral def cams_init(): @@ -100,10 +21,11 @@ def cams_init(): print('cam1') cam1 = cv2.VideoCapture(briocams[0]) - cam1.set(cv2.CAP_PROP_FRAME_WIDTH, 1920) - cam1.set(cv2.CAP_PROP_FRAME_HEIGHT, 1080) - cam1.set(cv2.CAP_PROP_FPS, 30) + cam1.set(cv2.CAP_PROP_FRAME_WIDTH, 1280) + cam1.set(cv2.CAP_PROP_FRAME_HEIGHT, 720) + cam1.set(cv2.CAP_PROP_FPS, 60) camsdict['brio'] = cam1 + # having multiple cams enabled slows down and delays capture # print('cam2') # cam2 = cv2.VideoCapture(briocams[2]) @@ -112,12 +34,12 @@ def cams_init(): # cam2.set(cv2.CAP_PROP_FPS, 30) # camsdict['brioBW'] = cam2 # regular BRIO cam hangs when BW cam is in use, same behavior in guvcview - print('cam3') - cam3 = cv2.VideoCapture(intcams[0]) - cam3.set(cv2.CAP_PROP_FRAME_WIDTH, 1280) - cam3.set(cv2.CAP_PROP_FRAME_HEIGHT, 720) - cam3.set(cv2.CAP_PROP_FPS, 30) - camsdict['integrated'] = cam3 + # print('cam3') + # cam3 = cv2.VideoCapture(intcams[0]) + # cam3.set(cv2.CAP_PROP_FRAME_WIDTH, 1280) + # cam3.set(cv2.CAP_PROP_FRAME_HEIGHT, 720) + # cam3.set(cv2.CAP_PROP_FPS, 30) + # camsdict['integrated'] = cam3 # print('cam4') # cam4 = cv2.VideoCapture(intcams[2]) @@ -129,14 +51,6 @@ def cams_init(): return camsdict -def cams_deinit(camsdict): - for name, cam in camsdict.items(): - cam.release() - - -frames = [] - - def cams_capture(cams, iso_date, pos): t0_iter = time.time() t = time.time() @@ -147,77 +61,53 @@ def cams_capture(cams, iso_date, pos): t0 = time.time() ret, frame = cam.read() dt = time.time() - t0 - filename = f'./data/{iso_date}/{camname} {(t * 1000):.0f}-{i+1} [{x} {y}] {dt*1000:.0f}.jpeg' + filename = f'./data/{iso_date}/{camname} {(t * 1000):.0f} {c}-{i+1} [{x} {y}].jpeg' cv2.imwrite(filename, frame) - if i == 0 and camname == 'brio': + if i == 1 and camname == 'brio': fr = frame dt_iter = time.time() - t0_iter print(dt_iter) return fr -def spiral(xmin, ymin, xmax, ymax, xsteps, ysteps): - points_list = [] - x, y = 0, 0 - num_points = (xsteps + 1) * (ysteps + 1) - dir = 'right' - x0, y0, x1, y1 = 0, 0, xsteps, ysteps - - while len(points_list) < num_points: - points_list.append([x, y]) - if dir == 'right': - x += 1 - if x == x1: - y0 += 1 - dir = 'down' - continue - if dir == 'down': - y += 1 - if y == y1: - x1 -= 1 - dir = 'left' - continue - if dir == 'left': - x -= 1 - if x == x0: - y1 -= 1 - dir = 'up' - continue - if dir == 'up': - y -= 1 - if y == y0: - x0 += 1 - dir = 'right' - continue - - dx = (xmax - xmin) / xsteps - dy = (ymax - ymin) / ysteps - points = np.array([xmin, ymin]) + np.array(points_list) * np.array([dx, dy]) - return points - - -def main(): - print('hello') - - iso_date = datetime.datetime.now().isoformat() - os.mkdir(f'./data/{iso_date}') - - cams = cams_init() - - edge_offset = 5 - points = spiral(edge_offset, edge_offset, 2560-edge_offset, 1440-edge_offset, 16, 10) - i = 0 - - while True: - step = points[i % len(points)] +def on_press(key): + global cams, iso_date, i, pos + if key == pynput.keyboard.Key.enter: + frames = {} + x, y = pyautogui.position() + t0 = time.time() + for camname, cam in cams.items(): + for j in range(3): + ret, frame = cam.read() + filename = f'./data/{iso_date}/{camname} {i}-{j} [{x} {y}].jpeg' + frames[filename] = frame + dt = time.time() - t0 + print(f'{dt*1000:.0f}') i += 1 - pyautogui.moveTo(*step) - input() - fr = cams_capture(cams, iso_date=iso_date, pos=pyautogui.position()) - cv2.imshow('cam', fr) - cv2.waitKey(1) - - cams_deinit(cams) - - -main() + pyautogui.moveTo(*points[i % len(points)]) + for filename, frame in frames.items(): + cv2.imwrite(filename, frame) + print('save', filename) + + +kb_listener = pynput.keyboard.Listener(on_press=on_press) +kb_listener.start() + +cams = cams_init() +iso_date = datetime.datetime.now().isoformat() +os.mkdir(f'./data/{iso_date}') + +edge_offset = 10 +points = spiral(edge_offset, edge_offset, 2560-edge_offset, 1440-edge_offset, 8, 5) +i = 0 +pyautogui.moveTo(*points[i % len(points)]) + +while True: + for camname, cam in cams.items(): + t0 = time.time() + ret, frame = cam.read() + dt = time.time() - t0 + if camname == 'brio': + cv2.imshow('cam', frame) + # print(dt) + cv2.waitKey(1) diff --git a/experiments/__init__.py b/experiments/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/experiments/delayed_frames.py b/experiments/delayed_frames.py new file mode 100644 index 0000000..d662ad2 --- /dev/null +++ b/experiments/delayed_frames.py @@ -0,0 +1,84 @@ +import cv2 +import mediapipe as mp +import pyautogui +from pynput.mouse import Listener +from pynput import keyboard, mouse +import pynput +import uuid +import time +import datetime +import os +import numpy as np +import sys + +sys.path.append("../modules") +from list_webcams import list_webcams # noqa + + +def cams_init(): + webcams = list_webcams() + intcams = webcams[[cam for cam in webcams.keys() if 'Integrated' in cam][0]] + briocams = webcams[[cam for cam in webcams.keys() if 'BRIO' in cam][0]] + camsdict = {} + + print('cam1') + cam1 = cv2.VideoCapture(briocams[0]) + cam1.set(cv2.CAP_PROP_FRAME_WIDTH, 1280) + cam1.set(cv2.CAP_PROP_FRAME_HEIGHT, 720) + cam1.set(cv2.CAP_PROP_FPS, 60) + camsdict['brio'] = cam1 + + # print('cam2') + # cam2 = cv2.VideoCapture(briocams[2]) + # cam2.set(cv2.CAP_PROP_FRAME_WIDTH, 340) + # cam2.set(cv2.CAP_PROP_FRAME_HEIGHT, 340) + # cam2.set(cv2.CAP_PROP_FPS, 30) + # camsdict['brioBW'] = cam2 # regular BRIO cam hangs when BW cam is in use, same behavior in guvcview + + # print('cam3') + # cam3 = cv2.VideoCapture(intcams[0]) + # cam3.set(cv2.CAP_PROP_FRAME_WIDTH, 1280) + # cam3.set(cv2.CAP_PROP_FRAME_HEIGHT, 720) + # cam3.set(cv2.CAP_PROP_FPS, 30) + # camsdict['integrated'] = cam3 + + return camsdict + + +def on_press(key): + if key == pynput.keyboard.Key.enter: + frames = {} + t0 = time.time() + for camname, cam in cams.items(): + for i in range(3): + ret, frame = cam.read() + filename = f'./data/{iso_date}/{camname} {t0*1000:.0f} {i}.jpeg' + frames[filename] = frame + dt = time.time() - t0 + print(dt) + for filename, frame in frames.items(): + cv2.imwrite(filename, frame) + print('save', filename) + + +kb_listener = pynput.keyboard.Listener(on_press=on_press) +kb_listener.start() + + +cams = cams_init() +cam = cams['brio'] +iso_date = datetime.datetime.now().isoformat() +os.mkdir(f'./data/{iso_date}') +i = 0 +while True: + if i == 1: + print('ready') + for camname, cam in cams.items(): + t0 = time.time() + ret, frame = cam.read() + dt = time.time() - t0 + if camname == 'brio': + cv2.imshow('cam', frame) + print(dt) + cv2.waitKey(1) + i += 1 diff --git a/experiments/list_webcams.py b/experiments/list_webcams.py index d2038ac..19ccab0 100644 --- a/experiments/list_webcams.py +++ b/experiments/list_webcams.py @@ -5,7 +5,10 @@ def list_webcams(): # Command as a list of strings - completed_process = subprocess.run('v4l2-ctl --list-devices', shell=True, stdout=subprocess.PIPE, text=True) + completed_process = subprocess.run( + 'v4l2-ctl --list-devices 2>/dev/null', + shell=True, stdout=subprocess.PIPE, text=True + ) stdout_output = completed_process.stdout # print("Stdout Output:") diff --git a/experiments/pygame_webcam.py b/experiments/pygame_webcam.py new file mode 100644 index 0000000..00e941c --- /dev/null +++ b/experiments/pygame_webcam.py @@ -0,0 +1,37 @@ +import time +import numpy as np +import pygame +import pygame.camera +import cv2 +import sys +sys.path.append("../modules") +from list_webcams import list_webcams # noqa + + +webcams = list_webcams() +intcams = webcams[[cam for cam in webcams.keys() if 'Integrated' in cam][0]] +briocams = webcams[[cam for cam in webcams.keys() if 'BRIO' in cam][0]] + +pygame.camera.init() +print(briocams) +cam1 = pygame.camera.Camera(briocams[0], (1280, 720)) +cam1.start() + +# cam2 = pygame.camera.Camera(intcams[0], (1280, 720)) +# cam2.start() + +while True: + # img = cam2.get_image() + t0 = time.time() + img = cam1.get_image() + pygame_image_string = pygame.image.tostring(img, 'RGB') + cv2_image_array = np.frombuffer(pygame_image_string, dtype=np.uint8) + cv2_image_array = cv2_image_array.reshape((img.get_height(), img.get_width(), 3)) + bgr = cv2.cvtColor(cv2_image_array, cv2.COLOR_RGB2BGR) + dt = time.time() - t0 + print(dt) + cv2.imshow('cam', bgr) + cv2.waitKey(1) + + +pygame.image.save(img, "filename.jpg") diff --git a/main.py b/main.py index 1c06f50..35022c3 100644 --- a/main.py +++ b/main.py @@ -23,6 +23,9 @@ import pickle from modules.eye_position_predictor import EyePositionPredictor +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +print(f'{device=}') + def list_webcams(): from collections import defaultdict @@ -107,16 +110,6 @@ def cams_capture(cams): return frames -face_mesh = mp.solutions.face_mesh.FaceMesh( - # static_image_mode=True, - max_num_faces=1, - refine_landmarks=True, - min_detection_confidence=0.5, - min_tracking_confidence=0.5, -) -model = EyePositionPredictor.load_from_file('./data/model.pickle') - - def predict(frame): rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) output = face_mesh.process(rgb) @@ -131,7 +124,6 @@ def predict(frame): y = model(X) monsize = np.array([2560, 1440]) - print(y) cursor = (y[0].numpy() + 1) / 2 * monsize cursor = cursor.clip([0, 0], [2560, 1440]) return cursor, faces @@ -147,29 +139,50 @@ def get_paths(globs): photo_globs = [ # '/home/anatoly/_tot/proj/ml/eye_controlled_mouse/data/2023-08-08T15:57:06.820873-continuous-ok/brio *.jpeg', # '/home/anatoly/_tot/proj/ml/eye_controlled_mouse/data/2023-08-08T16:33:38.163179-3-ok/brio *-1 *.jpeg', - '/home/anatoly/_tot/proj/ml/eye_controlled_mouse/data/2023-08-09T15:37:18.761700-first-spiral-ok/brio *.jpeg', + # '/home/anatoly/_tot/proj/ml/eye_controlled_mouse/data/2023-08-09T15:37:18.761700-first-spiral-ok/brio *-1 *.jpeg', + '/home/anatoly/_tot/proj/ml/eye_controlled_mouse/data/*/brio *.jpeg', ] photo_paths = get_paths(photo_globs) +pyautogui.FAILSAFE = False + +face_mesh = mp.solutions.face_mesh.FaceMesh( + # static_image_mode=True, + max_num_faces=1, + refine_landmarks=True, + min_detection_confidence=0.5, + min_tracking_confidence=0.5, +) +model = EyePositionPredictor.load_from_file(sys.argv[1]) + + +numavg = 3 +avgs = np.zeros(shape=(numavg, 2)) + + def main(): + global avgs, numavg print('hello') cams = cams_init() - # while True: - # frames = cams_capture(cams) - # frame = frames['brio'] + while True: + frames = cams_capture(cams) + frame = frames['brio'] - pyautogui.FAILSAFE = False - for filepath in photo_paths: - frame = cv2.imread(filepath) + # for filepath in photo_paths: + # frame = cv2.imread(filepath) cursor, faces = predict(frame) print(cursor) if cursor is not None: - pyautogui.moveTo(*cursor) - draw_landmarks(frame, faces) + avgs = np.roll(avgs, -1, axis=0) + avgs[-1] = cursor + avg = avgs.mean(axis=0) + print(avg) + pyautogui.moveTo(*avg) + # draw_landmarks(frame, faces) # input() cams_deinit(cams) diff --git a/modules/dataset.py b/modules/dataset.py new file mode 100644 index 0000000..7351ce8 --- /dev/null +++ b/modules/dataset.py @@ -0,0 +1,20 @@ +import pickle +import numpy as np + +dataset_filepath = './data/prepared.pickle' + + +class Dataset(): + def read_dataset(filepath=dataset_filepath): + with open(filepath, 'rb') as file: + dataset_list = pickle.load(file) + X = np.array([dp['landmarks'].ravel() for dp in dataset_list]) + y = np.array([dp['cursor'] for dp in dataset_list]) + monsize = np.array([2560, 1440]) + y = y / monsize * 2 - 1 + return X, y + + def save_dataset(dataset, filepath=dataset_filepath): + with open(dataset_filepath, 'wb') as file: + pickle.dump(dataset, file) + print('saved to file', dataset_filepath) diff --git a/modules/eye_position_predictor.py b/modules/eye_position_predictor.py index d1b64bb..0d532e5 100644 --- a/modules/eye_position_predictor.py +++ b/modules/eye_position_predictor.py @@ -5,15 +5,19 @@ class EyePositionPredictor(nn.Module): - def __init__(self, input_size, hidden_size, output_size): + def __init__(self, input_size, output_size): super(EyePositionPredictor, self).__init__() - self.fc1 = nn.Linear(input_size, hidden_size) + self.fc1 = nn.Linear(input_size, 128) self.relu = nn.ReLU() - self.fc2 = nn.Linear(hidden_size, output_size) + self.hidden1 = nn.Linear(128, 16) + self.relu2 = nn.ReLU() + self.fc2 = nn.Linear(16, output_size) def forward(self, x): x = self.fc1(x) x = self.relu(x) + x = self.hidden1(x) + x = self.relu2(x) x = self.fc2(x) return x diff --git a/modules/list_webcams.py b/modules/list_webcams.py new file mode 100644 index 0000000..b56bbc9 --- /dev/null +++ b/modules/list_webcams.py @@ -0,0 +1,32 @@ + +def list_webcams(): + from collections import defaultdict + import re + import subprocess + + # Command as a list of strings + + completed_process = subprocess.run( + 'v4l2-ctl --list-devices 2>/dev/null', + shell=True, stdout=subprocess.PIPE, text=True + ) + + stdout_output = completed_process.stdout + # print("Stdout Output:") + # print(stdout_output) + + device_info = defaultdict(list) + current_device = "" + + for line in stdout_output.splitlines(): + line = line.strip() + if line: + if re.match(r"^\w+.*:", line): + current_device = line + else: + device_info[current_device].append(line) + + parsed_dict = dict(device_info) + + # print(parsed_dict) + return parsed_dict diff --git a/modules/spiral.py b/modules/spiral.py new file mode 100644 index 0000000..3788e83 --- /dev/null +++ b/modules/spiral.py @@ -0,0 +1,41 @@ +import numpy as np + + +def spiral(xmin, ymin, xmax, ymax, xsteps, ysteps): + points_list = [] + x, y = 0, 0 + num_points = (xsteps + 1) * (ysteps + 1) + dir = 'right' + x0, y0, x1, y1 = 0, 0, xsteps, ysteps + + while len(points_list) < num_points: + points_list.append([x, y]) + if dir == 'right': + x += 1 + if x == x1: + y0 += 1 + dir = 'down' + continue + if dir == 'down': + y += 1 + if y == y1: + x1 -= 1 + dir = 'left' + continue + if dir == 'left': + x -= 1 + if x == x0: + y1 -= 1 + dir = 'up' + continue + if dir == 'up': + y -= 1 + if y == y0: + x0 += 1 + dir = 'right' + continue + + dx = (xmax - xmin) / xsteps + dy = (ymax - ymin) / ysteps + points = np.array([xmin, ymin]) + np.array(points_list) * np.array([dx, dy]) + return points diff --git a/prepare.py b/prepare.py index bccafcf..61b93ae 100644 --- a/prepare.py +++ b/prepare.py @@ -11,6 +11,7 @@ import pickle from modules.mp_landmarks_to_points import mp_landmarks_to_points from modules.draw_landmarks import draw_landmarks +from modules.dataset import Dataset def get_xy_from_filename(filename): @@ -24,12 +25,12 @@ def mp_detect_faces(rgb, num_warmup=3, num_avg=3): for i in range(num_warmup): output = face_mesh.process(rgb) output = face_mesh.process(rgb) - if output is None: + if output is None or output.multi_face_landmarks is None: return None faces = mp_landmarks_to_points(output.multi_face_landmarks) for i in range(num_avg - 1): output = face_mesh.process(rgb) - if output is None: + if output is None or output.multi_face_landmarks is None: return None faces += mp_landmarks_to_points(output.multi_face_landmarks) faces /= num_avg @@ -46,8 +47,8 @@ def get_paths(globs): photo_globs = [ # '/home/anatoly/_tot/proj/ml/eye_controlled_mouse/data/2023-08-08T15:57:06.820873-continuous-ok/brio *.jpeg', # '/home/anatoly/_tot/proj/ml/eye_controlled_mouse/data/2023-08-08T16:33:38.163179-3-ok/brio *-1 *.jpeg', - '/home/anatoly/_tot/proj/ml/eye_controlled_mouse/data/2023-08-09T15:37:18.761700-first-spiral-ok/brio *-1 *.jpeg', - '/home/anatoly/_tot/proj/ml/eye_controlled_mouse/data/2023-08-09T19:53:51.113869-1-spiral-ok/brio *-1 *.jpeg', + '/home/anatoly/_tot/proj/ml/eye_controlled_mouse/data/*/brio *-2 *.jpeg', + '/home/anatoly/_tot/proj/ml/eye_controlled_mouse/data/*/brio *-3 *.jpeg', ] photo_paths = get_paths(photo_globs) @@ -76,9 +77,6 @@ def get_paths(globs): dataset.append(datapoint) draw_landmarks(img, faces) -dataset_filepath = './data/prepared.pickle' -with open(dataset_filepath, 'wb') as file: - pickle.dump(dataset, file) -print('saved to file', dataset_filepath) +Dataset.save_dataset(dataset) cv2.destroyAllWindows() diff --git a/train.py b/train.py index 3c6aa8e..91458d8 100644 --- a/train.py +++ b/train.py @@ -7,27 +7,10 @@ import numpy as np import pickle from modules.eye_position_predictor import EyePositionPredictor +from modules.dataset import Dataset -dataset_filepath = './data/prepared.pickle' - - -def read_dataset(filepath): - with open(filepath, 'rb') as file: - dataset_list = pickle.load(file) - X = np.array([dp['landmarks'].ravel() for dp in dataset_list]) - y = np.array([dp['cursor'] for dp in dataset_list]) - monsize = np.array([2560, 1440]) - y = y / monsize * 2 - 1 - return X, y - - -# Generate random data for demonstration (replace with your data) -# num_samples = 1000 -# num_landmarks = 68 -# X = np.random.randn(num_samples, num_landmarks) -# y = np.random.rand(num_samples, 2) * 2 - 1 # Scaling to -1 to 1 range -X, y = read_dataset(dataset_filepath) +X, y = Dataset.read_dataset() num_landmarks = X.shape[1] print('shapes', X.shape, y.shape) print('mean xy', y.mean(axis=0)) @@ -48,34 +31,55 @@ def read_dataset(filepath): # Initialize the model input_size = num_landmarks -hidden_size = 128 -output_size = 2 -model = EyePositionPredictor(input_size, hidden_size, output_size) +output_size = y.shape[1] +model = EyePositionPredictor(input_size, output_size) +if len(sys.argv) >= 3: + model = EyePositionPredictor.load_from_file(sys.argv[2]) + + +class CustomLoss(nn.Module): + def __init__(self, penalty_factor): + super(CustomLoss, self).__init__() + self.penalty_factor = penalty_factor + + def forward(self, y_pred, y_true): + squared_errors = (y_pred - y_true) ** 2 + weighted_errors = squared_errors * (1 + y_true ** 2) * self.penalty_factor + loss = torch.mean(weighted_errors) + return loss + # Define loss function and optimizer -criterion = nn.MSELoss() -optimizer = optim.Adam(model.parameters(), lr=0.001) +# criterion = nn.MSELoss() +criterion = CustomLoss(1.5) +optimizer = optim.Adam(model.parameters(), lr=float(sys.argv[1])) + + +def evaluate(): + model.eval() + with torch.no_grad(): + test_outputs = model(X_test_tensor) + test_loss = criterion(test_outputs, y_test_tensor) + print(f'Test Loss: {test_loss.item():.4f}') + model.train() + # Training loop -num_epochs = 10000 +num_epochs = int(100e3) for epoch in range(num_epochs): optimizer.zero_grad() outputs = model(X_train_tensor) loss = criterion(outputs, y_train_tensor) loss.backward() optimizer.step() - if (epoch + 1) % 100 == 0: + if (epoch + 1) % 1000 == 0: print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}') + if (epoch + 1) % 5000 == 0: + evaluate() + model.save_to_file(f'./data/model-{epoch+1}.pickle') -# Evaluation -model.eval() -with torch.no_grad(): - test_outputs = model(X_test_tensor) - test_loss = criterion(test_outputs, y_test_tensor) - print(f'Test Loss: {test_loss.item():.4f}') +evaluate() model_filepath = './data/model.pickle' -with open(model_filepath, 'wb') as model_file: - pickle.dump(model, model_file) -print('saved model to file', model_filepath) +model.save_to_file(model_filepath)