From 457862bbc30af0939fe5621ec4408b88574be645 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Mon, 21 Oct 2024 21:17:52 -0400 Subject: [PATCH 01/30] added send thread, merged 2 classes --- src/utils/communication/mpi.py | 147 ++++++++++----------------------- 1 file changed, 43 insertions(+), 104 deletions(-) diff --git a/src/utils/communication/mpi.py b/src/utils/communication/mpi.py index e9b2004..f631647 100644 --- a/src/utils/communication/mpi.py +++ b/src/utils/communication/mpi.py @@ -5,21 +5,13 @@ from utils.communication.interface import CommunicationInterface import threading import time -import random -import numpy as np - -if TYPE_CHECKING: - from algos.base_class import BaseNode class MPICommUtils(CommunicationInterface): - def __init__(self, config: Dict[str, Dict[str, Any]]): + def __init__(self, config: Dict[str, Dict[str, Any]], data: Any): self.comm = MPI.COMM_WORLD self.rank = self.comm.Get_rank() self.size = self.comm.Get_size() - self.num_users: int = int(config["num_users"]) # type: ignore - self.finished = False - # Ensure that we are using thread safe threading level self.required_threading_level = MPI.THREAD_MULTIPLE self.threading_level = MPI.Query_thread() @@ -28,124 +20,71 @@ def __init__(self, config: Dict[str, Dict[str, Any]]): if self.required_threading_level > self.threading_level: raise RuntimeError(f"Insufficient thread support. Required: {self.required_threading_level}, Current: {self.threading_level}") + listener_thread = threading.Thread(target=self.listener, daemon=True) + listener_thread.start() + send_thread = threading.Thread(target=self.send, args=(data)) + send_thread.start() + self.send_event = threading.Event() # Ensures that the listener thread and send thread are not using self.request_source at the same time - self.lock = threading.Lock() + self.source_node_lock = threading.Lock() self.request_source: int | None = None - self.is_working = True - self.communication_cost_received: int = 0 - self.communication_cost_sent: int = 0 - - self.base_node: BaseNode | None = None - - self.listener_thread = threading.Thread(target=self.listener) - self.listener_thread.start() - def initialize(self): pass - def send_quorum(self) -> Any: - # return super().send_quorum(node_ids) - pass - - def register_self(self, obj: "BaseNode"): - self.base_node = obj - - def get_comm_cost(self): - with self.lock: - return self.communication_cost_received, self.communication_cost_sent - def listener(self): """ Runs on listener thread on each node to receive a send request - Once send request is received, the listener thread informs the send + Once send request is received, the listener thread informs the main thread to send the data to the requesting node. """ - while not self.finished: + while True: status = MPI.Status() # look for message with tag 1 (represents send request) if self.comm.Iprobe(source=MPI.ANY_SOURCE, tag=1, status=status): - with self.lock: - # self.request_source = status.Get_source() - dest = status.Get_source() + with self.source_node_lock: + self.request_source = status.Get_source() - print(f"Node {self.rank} received request from {self.request_source}") - # receive_request = self.comm.irecv(source=self.request_source, tag=1) - # receive_request.wait() - self.comm.recv(source=dest, tag=1) - self.send(dest) - print(f"Node {self.rank} listener thread ended") + self.comm.irecv(source=self.request_source, tag=1) + self.send_event.set() + time.sleep(1) # Simulate waiting time - def get_model(self) -> List[OrderedDict[str, Tensor]] | None: - print(f"getting model from {self.rank}, {self.base_node}") - if not self.base_node: - raise Exception("Base node not registered") - with self.lock: - if self.is_working: - model = self.base_node.get_model_weights() - model = [model] - print(f"Model from {self.rank} acquired") - else: - assert self.base_node.dropout.dropout_enabled, "Empty models are only supported when Dropout is enabled." - model = None - return model - - def send(self, dest: int): + def send(self, data: Any): """ - Node will wait for a request to send data and then send the + Node will wait until request is received and then send data to requesting node. """ - if self.finished: - return - - data = self.get_model() - print(f"Node {self.rank} is sending data to {dest}") - # req = self.comm.Isend(data, dest=int(dest)) - # req.wait() - self.comm.send(data, dest=int(dest)) - - def receive(self, node_ids: List[int]) -> Any: + while True: + # Wait until the listener thread detects a request + self.send_event.wait() + with self.source_node_lock: + dest = self.request_source + + if dest is not None: + req = self.comm.isend(data, dest=int(dest)) + req.wait() + + with self.source_node_lock: + self.request_source = None + + self.send_event.clear() + + def receive(self, node_ids: str | int) -> Any: """ - Node will send a request for data and wait to receive data. + Node will send a request and wait to receive data. """ - max_tries = 10 - assert len(node_ids) == 1, "Too many node_ids to unpack" - node = node_ids[0] - while max_tries > 0: - try: - print(f"Node {self.rank} receiving from {node}") - self.comm.send("", dest=node, tag=1) - # recv_req = self.comm.Irecv([], source=node) - # received_data = recv_req.wait() - received_data = self.comm.recv(source=node) - print(f"Node {self.rank} received data from {node}: {bool(received_data)}") - if not received_data: - raise Exception("Received empty data") - return received_data - except MPI.Exception as e: - print(f"MPI failed {10 - max_tries} times: MPI ERROR: {e}", "Retrying...") - import traceback - print(f"Traceback: {traceback.print_exc()}") - # sleep for a random time between 1 and 10 seconds - random_time = random.randint(1, 10) - time.sleep(random_time) - max_tries -= 1 - except Exception as e: - print(f"MPI failed {10 - max_tries} times: {e}", "Retrying...") - import traceback - print(f"Traceback: {traceback.print_exc()}") - # sleep for a random time between 1 and 10 seconds - random_time = random.randint(1, 10) - time.sleep(random_time) - max_tries -= 1 - print(f"Node {self.rank} received") + node_ids = int(node_ids) + send_req = self.comm.isend("", dest=node_ids, tag=1) + send_req.wait() + recv_req = self.comm.irecv(source=node_ids) + return recv_req.wait() - # deprecated broadcast function - def broadcast(self, data: Any): - for i in range(1, self.size): - if i != self.rank: - self.comm.send(data, dest=i) + # depreciated broadcast function + # def broadcast(self, data: Any): + # for i in range(1, self.size): + # if i != self.rank: + # self.send(i, data) def all_gather(self): """ From cf8dd0e415b10cad0512e443296ed7c02763b7ae Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Tue, 22 Oct 2024 10:30:21 -0400 Subject: [PATCH 02/30] improved comments --- src/utils/communication/mpi.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/utils/communication/mpi.py b/src/utils/communication/mpi.py index f631647..409daf3 100644 --- a/src/utils/communication/mpi.py +++ b/src/utils/communication/mpi.py @@ -36,7 +36,7 @@ def initialize(self): def listener(self): """ Runs on listener thread on each node to receive a send request - Once send request is received, the listener thread informs the main + Once send request is received, the listener thread informs the send thread to send the data to the requesting node. """ while True: @@ -52,7 +52,7 @@ def listener(self): def send(self, data: Any): """ - Node will wait until request is received and then send + Node will wait for a request to send data and then send the data to requesting node. """ while True: @@ -72,7 +72,7 @@ def send(self, data: Any): def receive(self, node_ids: str | int) -> Any: """ - Node will send a request and wait to receive data. + Node will send a request for data and wait to receive data. """ node_ids = int(node_ids) send_req = self.comm.isend("", dest=node_ids, tag=1) @@ -80,7 +80,7 @@ def receive(self, node_ids: str | int) -> Any: recv_req = self.comm.irecv(source=node_ids) return recv_req.wait() - # depreciated broadcast function + # deprecated broadcast function # def broadcast(self, data: Any): # for i in range(1, self.size): # if i != self.rank: From ee49b220f3f236a4d824d021ee3a9fc2bd267ae6 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Mon, 28 Oct 2024 17:16:10 -0400 Subject: [PATCH 03/30] testing mpi, model weights not acquired --- src/configs/sys_config.py | 2 +- src/utils/communication/mpi.py | 93 ++++++++++++++++++++++++++-------- 2 files changed, 72 insertions(+), 23 deletions(-) diff --git a/src/configs/sys_config.py b/src/configs/sys_config.py index 2e7e043..e3a354c 100644 --- a/src/configs/sys_config.py +++ b/src/configs/sys_config.py @@ -175,7 +175,7 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): # The device_ids dictionary depicts the GPUs on which the nodes reside. # For a single-GPU environment, the config will look as follows (as it follows a 0-based indexing): # "device_ids": {"node_0": [0], "node_1": [0], "node_2": [0], "node_3": [0]}, - "device_ids": get_device_ids(num_users=3, gpus_available=[1, 2]), + "device_ids": get_device_ids(num_users=4, gpus_available=[1, 2]), # use this when the list needs to be imported from the algo_config # "algo": get_algo_configs(num_users=3, algo_configs=algo_configs_list), "algos": get_algo_configs( diff --git a/src/utils/communication/mpi.py b/src/utils/communication/mpi.py index 409daf3..ee32f4e 100644 --- a/src/utils/communication/mpi.py +++ b/src/utils/communication/mpi.py @@ -1,13 +1,17 @@ -from collections import OrderedDict from typing import Dict, Any, List, TYPE_CHECKING from mpi4py import MPI from torch import Tensor from utils.communication.interface import CommunicationInterface import threading import time +from utils.communication.grpc.grpc_utils import deserialize_model, serialize_model +import random + +if TYPE_CHECKING: + from algos.base_class import BaseNode class MPICommUtils(CommunicationInterface): - def __init__(self, config: Dict[str, Dict[str, Any]], data: Any): + def __init__(self, config: Dict[str, Dict[str, Any]]): self.comm = MPI.COMM_WORLD self.rank = self.comm.Get_rank() self.size = self.comm.Get_size() @@ -20,19 +24,32 @@ def __init__(self, config: Dict[str, Dict[str, Any]], data: Any): if self.required_threading_level > self.threading_level: raise RuntimeError(f"Insufficient thread support. Required: {self.required_threading_level}, Current: {self.threading_level}") - listener_thread = threading.Thread(target=self.listener, daemon=True) - listener_thread.start() - send_thread = threading.Thread(target=self.send, args=(data)) - send_thread.start() - self.send_event = threading.Event() # Ensures that the listener thread and send thread are not using self.request_source at the same time - self.source_node_lock = threading.Lock() + self.lock = threading.Lock() self.request_source: int | None = None + self.is_working = True + self.communication_cost_received: int = 0 + self.communication_cost_sent: int = 0 + + self.base_node: BaseNode | None = None + + listener_thread = threading.Thread(target=self.listener, daemon=True) + listener_thread.start() + def initialize(self): pass + def register_self(self, obj: "BaseNode"): + self.base_node = obj + send_thread = threading.Thread(target=self.send) + send_thread.start() + + def get_comm_cost(self): + with self.lock: + return self.communication_cost_received, self.communication_cost_sent + def listener(self): """ Runs on listener thread on each node to receive a send request @@ -43,14 +60,28 @@ def listener(self): status = MPI.Status() # look for message with tag 1 (represents send request) if self.comm.Iprobe(source=MPI.ANY_SOURCE, tag=1, status=status): - with self.source_node_lock: + with self.lock: self.request_source = status.Get_source() self.comm.irecv(source=self.request_source, tag=1) self.send_event.set() time.sleep(1) # Simulate waiting time - def send(self, data: Any): + def get_model(self) -> bytes | None: + print(f"getting model from {self.rank}, {self.base_node}") + if not self.base_node: + raise Exception("Base node not registered") + with self.lock: + if self.is_working: + print("model is working") + model = serialize_model(self.base_node.get_model_weights()) + print(f"model data to be sent: {model}") + else: + assert self.base_node.dropout.dropout_enabled, "Empty models are only supported when Dropout is enabled." + model = None + return model + + def send(self): """ Node will wait for a request to send data and then send the data to requesting node. @@ -58,33 +89,46 @@ def send(self, data: Any): while True: # Wait until the listener thread detects a request self.send_event.wait() - with self.source_node_lock: + with self.lock: dest = self.request_source if dest is not None: + data = self.get_model() req = self.comm.isend(data, dest=int(dest)) req.wait() - with self.source_node_lock: + with self.lock: self.request_source = None self.send_event.clear() - def receive(self, node_ids: str | int) -> Any: + def receive(self, node_ids: List[int]) -> Any: """ Node will send a request for data and wait to receive data. """ - node_ids = int(node_ids) - send_req = self.comm.isend("", dest=node_ids, tag=1) - send_req.wait() - recv_req = self.comm.irecv(source=node_ids) - return recv_req.wait() + max_tries = 10 + for node in node_ids: + while max_tries > 0: + try: + self.comm.send("", dest=node, tag=1) + recv_req = self.comm.irecv(source=node) + received_data = recv_req.wait() + print(f"received data: {received_data}") + return deserialize_model(received_data) + except Exception as e: + print(f"MPI failed {10 - max_tries} times: {e}", "Retrying...") + import traceback + print(traceback.print_exc()) + # sleep for a random time between 1 and 10 seconds + random_time = random.randint(1, 10) + time.sleep(random_time) + max_tries -= 1 # deprecated broadcast function - # def broadcast(self, data: Any): - # for i in range(1, self.size): - # if i != self.rank: - # self.send(i, data) + def broadcast(self, data: Any): + for i in range(1, self.size): + if i != self.rank: + self.comm.send(data, dest=i) def all_gather(self): """ @@ -92,6 +136,7 @@ def all_gather(self): """ items: List[Any] = [] for i in range(1, self.size): + print(f"receiving this data: {self.receive(i)}") print(f"receiving this data: {self.receive(i)}") items.append(self.receive(i)) return items @@ -141,3 +186,7 @@ def set_is_working(self, is_working: bool): with self.lock: self.is_working = is_working + def set_is_working(self, is_working: bool): + with self.lock: + self.is_working = is_working + From 4d1929f043be59a5216a6cb429ff05e4379ff03c Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Sun, 3 Nov 2024 14:12:09 -0500 Subject: [PATCH 04/30] mpi works, occassional deadlock issue --- src/utils/communication/mpi.py | 131 +++++++++++++++------------------ 1 file changed, 60 insertions(+), 71 deletions(-) diff --git a/src/utils/communication/mpi.py b/src/utils/communication/mpi.py index ee32f4e..3cf1ff5 100644 --- a/src/utils/communication/mpi.py +++ b/src/utils/communication/mpi.py @@ -1,11 +1,12 @@ +from collections import OrderedDict from typing import Dict, Any, List, TYPE_CHECKING from mpi4py import MPI from torch import Tensor from utils.communication.interface import CommunicationInterface import threading import time -from utils.communication.grpc.grpc_utils import deserialize_model, serialize_model import random +import numpy as np if TYPE_CHECKING: from algos.base_class import BaseNode @@ -16,6 +17,9 @@ def __init__(self, config: Dict[str, Dict[str, Any]]): self.rank = self.comm.Get_rank() self.size = self.comm.Get_size() + self.num_users: int = int(config["num_users"]) # type: ignore + self.finished = False + # Ensure that we are using thread safe threading level self.required_threading_level = MPI.THREAD_MULTIPLE self.threading_level = MPI.Query_thread() @@ -35,16 +39,17 @@ def __init__(self, config: Dict[str, Dict[str, Any]]): self.base_node: BaseNode | None = None - listener_thread = threading.Thread(target=self.listener, daemon=True) - listener_thread.start() + self.listener_thread = threading.Thread(target=self.listener) + self.listener_thread.start() + + self.send_thread = threading.Thread(target=self.send) def initialize(self): pass def register_self(self, obj: "BaseNode"): self.base_node = obj - send_thread = threading.Thread(target=self.send) - send_thread.start() + self.send_thread.start() def get_comm_cost(self): with self.lock: @@ -56,26 +61,30 @@ def listener(self): Once send request is received, the listener thread informs the send thread to send the data to the requesting node. """ - while True: + while not self.finished: status = MPI.Status() # look for message with tag 1 (represents send request) if self.comm.Iprobe(source=MPI.ANY_SOURCE, tag=1, status=status): with self.lock: self.request_source = status.Get_source() - self.comm.irecv(source=self.request_source, tag=1) + print(f"Node {self.rank} received request from {self.request_source}") + # receive_request = self.comm.irecv(source=self.request_source, tag=1) + # receive_request.wait() + self.comm.recv(source=self.request_source, tag=1) self.send_event.set() - time.sleep(1) # Simulate waiting time + # time.sleep(1) + print(f"Node {self.rank} listener thread ended") - def get_model(self) -> bytes | None: + def get_model(self) -> List[OrderedDict[str, Tensor]] | None: print(f"getting model from {self.rank}, {self.base_node}") if not self.base_node: raise Exception("Base node not registered") with self.lock: if self.is_working: - print("model is working") - model = serialize_model(self.base_node.get_model_weights()) - print(f"model data to be sent: {model}") + model = self.base_node.get_model_weights() + model = [model] + print(f"Model from {self.rank} acquired") else: assert self.base_node.dropout.dropout_enabled, "Empty models are only supported when Dropout is enabled." model = None @@ -86,43 +95,62 @@ def send(self): Node will wait for a request to send data and then send the data to requesting node. """ - while True: + while not self.finished: # Wait until the listener thread detects a request self.send_event.wait() + if self.finished: + break with self.lock: dest = self.request_source if dest is not None: data = self.get_model() - req = self.comm.isend(data, dest=int(dest)) - req.wait() + print(f"Node {self.rank} is sending data to {dest}") + # req = self.comm.Isend(data, dest=int(dest)) + # req.wait() + self.comm.send(data, dest=int(dest)) with self.lock: self.request_source = None self.send_event.clear() + print(f"Node {self.rank} send thread ended") def receive(self, node_ids: List[int]) -> Any: """ Node will send a request for data and wait to receive data. """ max_tries = 10 - for node in node_ids: - while max_tries > 0: - try: - self.comm.send("", dest=node, tag=1) - recv_req = self.comm.irecv(source=node) - received_data = recv_req.wait() - print(f"received data: {received_data}") - return deserialize_model(received_data) - except Exception as e: - print(f"MPI failed {10 - max_tries} times: {e}", "Retrying...") - import traceback - print(traceback.print_exc()) - # sleep for a random time between 1 and 10 seconds - random_time = random.randint(1, 10) - time.sleep(random_time) - max_tries -= 1 + assert len(node_ids) == 1, "Too many node_ids to unpack" + node = node_ids[0] + while max_tries > 0: + try: + print(f"Node {self.rank} receiving from {node}") + self.comm.send("", dest=node, tag=1) + # recv_req = self.comm.Irecv([], source=node) + # received_data = recv_req.wait() + received_data = self.comm.recv(source=node) + print(f"Node {self.rank} received data from {node}: {bool(received_data)}") + if not received_data: + raise Exception("Received empty data") + return received_data + except MPI.Exception as e: + print(f"MPI failed {10 - max_tries} times: MPI ERROR: {e}", "Retrying...") + import traceback + print(f"Traceback: {traceback.print_exc()}") + # sleep for a random time between 1 and 10 seconds + random_time = random.randint(1, 10) + time.sleep(random_time) + max_tries -= 1 + except Exception as e: + print(f"MPI failed {10 - max_tries} times: {e}", "Retrying...") + import traceback + print(f"Traceback: {traceback.print_exc()}") + # sleep for a random time between 1 and 10 seconds + random_time = random.randint(1, 10) + time.sleep(random_time) + max_tries -= 1 + print(f"Node {self.rank} received") # deprecated broadcast function def broadcast(self, data: Any): @@ -145,46 +173,7 @@ def send_finished(self): self.comm.send("Finished", dest=0, tag=2) def finalize(self): - # 1. All nodes send finished to the super node - # 2. super node will wait for all nodes to send finished - # 3. super node will then send bye to all nodes - # 4. all nodes will wait for the bye and then exit - # this is to ensure that all nodes have finished - # and no one leaves early - if self.rank == 0: - quorum_threshold = self.num_users - 1 # No +1 for the super node because it doesn't send finished - num_finished: set[int] = set() - status = MPI.Status() - while len(num_finished) < quorum_threshold: - print( - f"Waiting for {quorum_threshold} users to finish, {num_finished} have finished so far" - ) - # get finished nodes - self.comm.recv(source=MPI.ANY_SOURCE, tag=2, status=status) - print(f"received finish message from {status.Get_source()}") - num_finished.add(status.Get_source()) - - else: - # send finished to the super node - print(f"Node {self.rank} sent finish message") - self.send_finished() - - message = self.comm.bcast("Done", root=0) - self.finished = True - self.send_event.set() - print(f"Node {self.rank} received {message}, finished") - self.comm.Barrier() - self.listener_thread.join() - print(f"Node {self.rank} listener thread done") - print(f"Node {self.rank} listener thread is {self.listener_thread.is_alive()}") - print(f"Node {self.rank} {threading.enumerate()}") - self.comm.Barrier() - print(f"Node {self.rank}: all nodes synchronized") - MPI.Finalize() - - def set_is_working(self, is_working: bool): - with self.lock: - self.is_working = is_working + pass def set_is_working(self, is_working: bool): with self.lock: From 1b5a4223cafbe1abf620f25f7249797c4988cefc Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Tue, 5 Nov 2024 22:46:26 -0500 Subject: [PATCH 05/30] merged send and listener threads --- src/utils/communication/mpi.py | 80 +++++++++++++++++++++------------- 1 file changed, 50 insertions(+), 30 deletions(-) diff --git a/src/utils/communication/mpi.py b/src/utils/communication/mpi.py index 3cf1ff5..022f69e 100644 --- a/src/utils/communication/mpi.py +++ b/src/utils/communication/mpi.py @@ -42,14 +42,11 @@ def __init__(self, config: Dict[str, Dict[str, Any]]): self.listener_thread = threading.Thread(target=self.listener) self.listener_thread.start() - self.send_thread = threading.Thread(target=self.send) - def initialize(self): pass def register_self(self, obj: "BaseNode"): self.base_node = obj - self.send_thread.start() def get_comm_cost(self): with self.lock: @@ -66,14 +63,14 @@ def listener(self): # look for message with tag 1 (represents send request) if self.comm.Iprobe(source=MPI.ANY_SOURCE, tag=1, status=status): with self.lock: - self.request_source = status.Get_source() + # self.request_source = status.Get_source() + dest = status.Get_source() print(f"Node {self.rank} received request from {self.request_source}") # receive_request = self.comm.irecv(source=self.request_source, tag=1) # receive_request.wait() - self.comm.recv(source=self.request_source, tag=1) - self.send_event.set() - # time.sleep(1) + self.comm.recv(source=dest, tag=1) + self.send(dest) print(f"Node {self.rank} listener thread ended") def get_model(self) -> List[OrderedDict[str, Tensor]] | None: @@ -90,31 +87,19 @@ def get_model(self) -> List[OrderedDict[str, Tensor]] | None: model = None return model - def send(self): + def send(self, dest: int): """ Node will wait for a request to send data and then send the data to requesting node. """ - while not self.finished: - # Wait until the listener thread detects a request - self.send_event.wait() - if self.finished: - break - with self.lock: - dest = self.request_source - - if dest is not None: - data = self.get_model() - print(f"Node {self.rank} is sending data to {dest}") - # req = self.comm.Isend(data, dest=int(dest)) - # req.wait() - self.comm.send(data, dest=int(dest)) - - with self.lock: - self.request_source = None - - self.send_event.clear() - print(f"Node {self.rank} send thread ended") + if self.finished: + return + + data = self.get_model() + print(f"Node {self.rank} is sending data to {dest}") + # req = self.comm.Isend(data, dest=int(dest)) + # req.wait() + self.comm.send(data, dest=int(dest)) def receive(self, node_ids: List[int]) -> Any: """ @@ -173,8 +158,43 @@ def send_finished(self): self.comm.send("Finished", dest=0, tag=2) def finalize(self): - pass - + # 1. All nodes send finished to the super node + # 2. super node will wait for all nodes to send finished + # 3. super node will then send bye to all nodes + # 4. all nodes will wait for the bye and then exit + # this is to ensure that all nodes have finished + # and no one leaves early + if self.rank == 0: + quorum_threshold = self.num_users - 1 # No +1 for the super node because it doesn't send finished + num_finished: set[int] = set() + status = MPI.Status() + while len(num_finished) < quorum_threshold: + print( + f"Waiting for {quorum_threshold} users to finish, {num_finished} have finished so far" + ) + # get finished nodes + self.comm.recv(source=MPI.ANY_SOURCE, tag=2, status=status) + print(f"received finish message from {status.Get_source()}") + num_finished.add(status.Get_source()) + + else: + # send finished to the super node + print(f"Node {self.rank} sent finish message") + self.send_finished() + + message = self.comm.bcast("Done", root=0) + self.finished = True + self.send_event.set() + print(f"Node {self.rank} received {message}, finished") + self.comm.Barrier() + self.listener_thread.join() + print(f"Node {self.rank} listener thread done") + print(f"Node {self.rank} listener thread is {self.listener_thread.is_alive()}") + print(f"Node {self.rank} {threading.enumerate()}") + self.comm.Barrier() + print(f"Node {self.rank}: all nodes synchronized") + MPI.Finalize() + def set_is_working(self, is_working: bool): with self.lock: self.is_working = is_working From 778dabc53cd0666e508cbda1a5089dea0cb9d7df Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Fri, 15 Nov 2024 13:12:34 -0500 Subject: [PATCH 06/30] first draft of test --- src/configs/sys_config.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/configs/sys_config.py b/src/configs/sys_config.py index e3a354c..91c6962 100644 --- a/src/configs/sys_config.py +++ b/src/configs/sys_config.py @@ -170,12 +170,13 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): "dset": CIFAR10_DSET, "dump_dir": DUMP_DIR, "dpath": CIAR10_DPATH, - "seed": 32, + # "seed": 32, + "seed": 2, # node_0 is a server currently # The device_ids dictionary depicts the GPUs on which the nodes reside. # For a single-GPU environment, the config will look as follows (as it follows a 0-based indexing): # "device_ids": {"node_0": [0], "node_1": [0], "node_2": [0], "node_3": [0]}, - "device_ids": get_device_ids(num_users=4, gpus_available=[1, 2]), + "device_ids": get_device_ids(num_users=3, gpus_available=[1, 2]), # use this when the list needs to be imported from the algo_config # "algo": get_algo_configs(num_users=3, algo_configs=algo_configs_list), "algos": get_algo_configs( From b1407f56cf4c703b95ffc82e8eea225bb9c09ae6 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Mon, 18 Nov 2024 17:44:25 -0500 Subject: [PATCH 07/30] testing workflow --- src/configs/algo_config_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/configs/algo_config_test.py b/src/configs/algo_config_test.py index 2f2c7fc..264b904 100644 --- a/src/configs/algo_config_test.py +++ b/src/configs/algo_config_test.py @@ -1,3 +1,5 @@ +#!/usr/bin/env python3 + from utils.types import ConfigType # fedstatic: ConfigType = { From d99076d47b744e7bb1f7803a053b5f99c2682bd5 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Sun, 24 Nov 2024 20:55:06 -0500 Subject: [PATCH 08/30] predict next move ish --- src/scheduler.py | 1 - src/utils/communication/grpc/comm.proto | 5 +++ src/utils/communication/grpc/comm_pb2.py | 44 ++++++++++--------- src/utils/communication/grpc/comm_pb2_grpc.py | 43 ++++++++++++++++++ src/utils/communication/grpc/main.py | 3 ++ 5 files changed, 74 insertions(+), 22 deletions(-) diff --git a/src/scheduler.py b/src/scheduler.py index 55da449..23cc327 100644 --- a/src/scheduler.py +++ b/src/scheduler.py @@ -129,7 +129,6 @@ def initialize(self, copy_souce_code: bool = True) -> None: rank=self.communication.get_rank(), comm_utils=self.communication, ) - self.communication.send_quorum() def run_job(self) -> None: self.node.run_protocol() diff --git a/src/utils/communication/grpc/comm.proto b/src/utils/communication/grpc/comm.proto index 8f689c3..c69ade5 100644 --- a/src/utils/communication/grpc/comm.proto +++ b/src/utils/communication/grpc/comm.proto @@ -3,6 +3,7 @@ syntax = "proto3"; service CommunicationServer { + rpc send_status(Empty) returns (Status) {} rpc send_data (Data) returns (Empty) {} rpc send_model (Model) returns (Empty) {} rpc get_rank (Empty) returns (Rank) {} @@ -16,6 +17,10 @@ service CommunicationServer { message Empty {} +message Status{ + string message = 1; +} + message Model { bytes buffer = 1; } diff --git a/src/utils/communication/grpc/comm_pb2.py b/src/utils/communication/grpc/comm_pb2.py index a9b03ce..4393839 100644 --- a/src/utils/communication/grpc/comm_pb2.py +++ b/src/utils/communication/grpc/comm_pb2.py @@ -14,7 +14,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ncomm.proto\"\x07\n\x05\x45mpty\"\x17\n\x05Model\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\")\n\x04\x44\x61ta\x12\n\n\x02id\x18\x01 \x01(\t\x12\x15\n\x05model\x18\x02 \x01(\x0b\x32\x06.Model\"\x14\n\x04Rank\x12\x0c\n\x04rank\x18\x01 \x01(\x05\"\x16\n\x05Round\x12\r\n\x05round\x18\x01 \x01(\x05\"\x14\n\x04Port\x12\x0c\n\x04port\x18\x01 \x01(\x05\">\n\x06PeerId\x12\x13\n\x04rank\x18\x01 \x01(\x0b\x32\x05.Rank\x12\x13\n\x04port\x18\x02 \x01(\x0b\x32\x05.Port\x12\n\n\x02ip\x18\x03 \x01(\t\"k\n\x07PeerIds\x12\'\n\x08peer_ids\x18\x01 \x03(\x0b\x32\x15.PeerIds.PeerIdsEntry\x1a\x37\n\x0cPeerIdsEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\x16\n\x05value\x18\x02 \x01(\x0b\x32\x07.PeerId:\x02\x38\x01\"\x18\n\x06Quorum\x12\x0e\n\x06quorum\x18\x01 \x01(\x08\x32\xc1\x02\n\x13\x43ommunicationServer\x12\x1c\n\tsend_data\x12\x05.Data\x1a\x06.Empty\"\x00\x12\x1e\n\nsend_model\x12\x06.Model\x1a\x06.Empty\"\x00\x12\x1b\n\x08get_rank\x12\x06.Empty\x1a\x05.Rank\"\x00\x12\x1d\n\tget_model\x12\x06.Empty\x1a\x06.Model\"\x00\x12%\n\x11get_current_round\x12\x06.Empty\x1a\x06.Round\"\x00\x12 \n\x0bupdate_port\x12\x07.PeerId\x1a\x06.Empty\"\x00\x12#\n\rsend_peer_ids\x12\x08.PeerIds\x1a\x06.Empty\"\x00\x12 \n\x0bsend_quorum\x12\x07.Quorum\x1a\x06.Empty\"\x00\x12 \n\rsend_finished\x12\x05.Rank\x1a\x06.Empty\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ncomm.proto\"\x07\n\x05\x45mpty\"\x19\n\x06Status\x12\x0f\n\x07message\x18\x01 \x01(\t\"\x17\n\x05Model\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\")\n\x04\x44\x61ta\x12\n\n\x02id\x18\x01 \x01(\t\x12\x15\n\x05model\x18\x02 \x01(\x0b\x32\x06.Model\"\x14\n\x04Rank\x12\x0c\n\x04rank\x18\x01 \x01(\x05\"\x16\n\x05Round\x12\r\n\x05round\x18\x01 \x01(\x05\"\x14\n\x04Port\x12\x0c\n\x04port\x18\x01 \x01(\x05\">\n\x06PeerId\x12\x13\n\x04rank\x18\x01 \x01(\x0b\x32\x05.Rank\x12\x13\n\x04port\x18\x02 \x01(\x0b\x32\x05.Port\x12\n\n\x02ip\x18\x03 \x01(\t\"k\n\x07PeerIds\x12\'\n\x08peer_ids\x18\x01 \x03(\x0b\x32\x15.PeerIds.PeerIdsEntry\x1a\x37\n\x0cPeerIdsEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\x16\n\x05value\x18\x02 \x01(\x0b\x32\x07.PeerId:\x02\x38\x01\"\x18\n\x06Quorum\x12\x0e\n\x06quorum\x18\x01 \x01(\x08\x32\xe3\x02\n\x13\x43ommunicationServer\x12 \n\x0bsend_status\x12\x06.Empty\x1a\x07.Status\"\x00\x12\x1c\n\tsend_data\x12\x05.Data\x1a\x06.Empty\"\x00\x12\x1e\n\nsend_model\x12\x06.Model\x1a\x06.Empty\"\x00\x12\x1b\n\x08get_rank\x12\x06.Empty\x1a\x05.Rank\"\x00\x12\x1d\n\tget_model\x12\x06.Empty\x1a\x06.Model\"\x00\x12%\n\x11get_current_round\x12\x06.Empty\x1a\x06.Round\"\x00\x12 \n\x0bupdate_port\x12\x07.PeerId\x1a\x06.Empty\"\x00\x12#\n\rsend_peer_ids\x12\x08.PeerIds\x1a\x06.Empty\"\x00\x12 \n\x0bsend_quorum\x12\x07.Quorum\x1a\x06.Empty\"\x00\x12 \n\rsend_finished\x12\x05.Rank\x1a\x06.Empty\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -25,24 +25,26 @@ _globals['_PEERIDS_PEERIDSENTRY']._serialized_options = b'8\001' _globals['_EMPTY']._serialized_start=14 _globals['_EMPTY']._serialized_end=21 - _globals['_MODEL']._serialized_start=23 - _globals['_MODEL']._serialized_end=46 - _globals['_DATA']._serialized_start=48 - _globals['_DATA']._serialized_end=89 - _globals['_RANK']._serialized_start=91 - _globals['_RANK']._serialized_end=111 - _globals['_ROUND']._serialized_start=113 - _globals['_ROUND']._serialized_end=135 - _globals['_PORT']._serialized_start=137 - _globals['_PORT']._serialized_end=157 - _globals['_PEERID']._serialized_start=159 - _globals['_PEERID']._serialized_end=221 - _globals['_PEERIDS']._serialized_start=223 - _globals['_PEERIDS']._serialized_end=330 - _globals['_PEERIDS_PEERIDSENTRY']._serialized_start=275 - _globals['_PEERIDS_PEERIDSENTRY']._serialized_end=330 - _globals['_QUORUM']._serialized_start=332 - _globals['_QUORUM']._serialized_end=356 - _globals['_COMMUNICATIONSERVER']._serialized_start=359 - _globals['_COMMUNICATIONSERVER']._serialized_end=680 + _globals['_STATUS']._serialized_start=23 + _globals['_STATUS']._serialized_end=48 + _globals['_MODEL']._serialized_start=50 + _globals['_MODEL']._serialized_end=73 + _globals['_DATA']._serialized_start=75 + _globals['_DATA']._serialized_end=116 + _globals['_RANK']._serialized_start=118 + _globals['_RANK']._serialized_end=138 + _globals['_ROUND']._serialized_start=140 + _globals['_ROUND']._serialized_end=162 + _globals['_PORT']._serialized_start=164 + _globals['_PORT']._serialized_end=184 + _globals['_PEERID']._serialized_start=186 + _globals['_PEERID']._serialized_end=248 + _globals['_PEERIDS']._serialized_start=250 + _globals['_PEERIDS']._serialized_end=357 + _globals['_PEERIDS_PEERIDSENTRY']._serialized_start=302 + _globals['_PEERIDS_PEERIDSENTRY']._serialized_end=357 + _globals['_QUORUM']._serialized_start=359 + _globals['_QUORUM']._serialized_end=383 + _globals['_COMMUNICATIONSERVER']._serialized_start=386 + _globals['_COMMUNICATIONSERVER']._serialized_end=741 # @@protoc_insertion_point(module_scope) diff --git a/src/utils/communication/grpc/comm_pb2_grpc.py b/src/utils/communication/grpc/comm_pb2_grpc.py index e0258f7..ea45534 100644 --- a/src/utils/communication/grpc/comm_pb2_grpc.py +++ b/src/utils/communication/grpc/comm_pb2_grpc.py @@ -39,6 +39,11 @@ def __init__(self, channel): Args: channel: A grpc.Channel. """ + self.send_status = channel.unary_unary( + '/CommunicationServer/send_status', + request_serializer=comm__pb2.Empty.SerializeToString, + response_deserializer=comm__pb2.Status.FromString, + _registered_method=True) self.send_data = channel.unary_unary( '/CommunicationServer/send_data', request_serializer=comm__pb2.Data.SerializeToString, @@ -89,6 +94,12 @@ def __init__(self, channel): class CommunicationServerServicer(object): """Missing associated documentation comment in .proto file.""" + def send_status(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def send_data(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -146,6 +157,11 @@ def send_finished(self, request, context): def add_CommunicationServerServicer_to_server(servicer, server): rpc_method_handlers = { + 'send_status': grpc.unary_unary_rpc_method_handler( + servicer.send_status, + request_deserializer=comm__pb2.Empty.FromString, + response_serializer=comm__pb2.Status.SerializeToString, + ), 'send_data': grpc.unary_unary_rpc_method_handler( servicer.send_data, request_deserializer=comm__pb2.Data.FromString, @@ -202,6 +218,33 @@ def add_CommunicationServerServicer_to_server(servicer, server): class CommunicationServer(object): """Missing associated documentation comment in .proto file.""" + @staticmethod + def send_status(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/CommunicationServer/send_status', + comm__pb2.Empty.SerializeToString, + comm__pb2.Status.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + @staticmethod def send_data(request, target, diff --git a/src/utils/communication/grpc/main.py b/src/utils/communication/grpc/main.py index b850f0e..730728e 100644 --- a/src/utils/communication/grpc/main.py +++ b/src/utils/communication/grpc/main.py @@ -175,6 +175,9 @@ def update_port( self.peer_ids[request.rank.rank]["ip"] = request.ip # type: ignore self.peer_ids[request.rank.rank]["port"] = request.port.port # type: ignore return comm_pb2.Empty() # type: ignore + + def send_status(self, request, context) -> comm_pb2.Status: + return comm_pb2.Status(message="Ready") # type: ignore def send_peer_ids(self, request: comm_pb2.PeerIds, context) -> comm_pb2.Empty: # type: ignore """ From 1a9df18378982de2bcd53fa20cbe8719ddc9567b Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Tue, 26 Nov 2024 14:15:53 -0500 Subject: [PATCH 09/30] moved quorum send --- .github/workflows/train.yml | 28 +++++++++++++++++++++++----- src/configs/algo_config_test.py | 2 -- src/scheduler.py | 1 + src/utils/communication/grpc/main.py | 3 --- src/utils/communication/mpi.py | 4 ++++ 5 files changed, 28 insertions(+), 10 deletions(-) diff --git a/.github/workflows/train.yml b/.github/workflows/train.yml index 3d7e86f..37e16f8 100644 --- a/.github/workflows/train.yml +++ b/.github/workflows/train.yml @@ -22,6 +22,20 @@ jobs: - name: Checkout repository uses: actions/checkout@v3 + # - name: check directories + # run: | + # cd src + # DIR="../../../../../../../home" + # if [ -d "$DIR" ]; then + # ### Take action if $DIR exists ### + # echo "Installing config files in ${DIR}" + # exit 1 + # else + # ### Control will jump here if $DIR does NOT exists ### + # echo "Error: ${DIR} not found. Can not continue." + # exit 1 + # fi + # Step 2: Set up Python - name: Set up Python uses: actions/setup-python@v4 @@ -49,8 +63,12 @@ jobs: python main.py -super true -s "./configs/sys_config_test.py" echo "done" - # further checks: - # only 5 rounds - # gRPC only? or also MPI? - # num of samples - # num users and nodes + - name: Clean up + run: | + rm -rf ./sonar_experiments/ + + # further checks: + # only 5 rounds + # gRPC only? or also MPI? + # num of samples + # num users and nodes diff --git a/src/configs/algo_config_test.py b/src/configs/algo_config_test.py index 264b904..2f2c7fc 100644 --- a/src/configs/algo_config_test.py +++ b/src/configs/algo_config_test.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python3 - from utils.types import ConfigType # fedstatic: ConfigType = { diff --git a/src/scheduler.py b/src/scheduler.py index 23cc327..6094921 100644 --- a/src/scheduler.py +++ b/src/scheduler.py @@ -109,6 +109,7 @@ def initialize(self, copy_souce_code: bool = True) -> None: numpy.random.seed(seed) self.merge_configs() if self.communication.get_rank() == 0: + print("initializing super node") if copy_souce_code: copy_source_code(self.config) else: diff --git a/src/utils/communication/grpc/main.py b/src/utils/communication/grpc/main.py index 730728e..b850f0e 100644 --- a/src/utils/communication/grpc/main.py +++ b/src/utils/communication/grpc/main.py @@ -175,9 +175,6 @@ def update_port( self.peer_ids[request.rank.rank]["ip"] = request.ip # type: ignore self.peer_ids[request.rank.rank]["port"] = request.port.port # type: ignore return comm_pb2.Empty() # type: ignore - - def send_status(self, request, context) -> comm_pb2.Status: - return comm_pb2.Status(message="Ready") # type: ignore def send_peer_ids(self, request: comm_pb2.PeerIds, context) -> comm_pb2.Empty: # type: ignore """ diff --git a/src/utils/communication/mpi.py b/src/utils/communication/mpi.py index 022f69e..bd565f8 100644 --- a/src/utils/communication/mpi.py +++ b/src/utils/communication/mpi.py @@ -45,6 +45,10 @@ def __init__(self, config: Dict[str, Dict[str, Any]]): def initialize(self): pass + def send_quorum(self) -> Any: + # return super().send_quorum(node_ids) + pass + def register_self(self, obj: "BaseNode"): self.base_node = obj From aefb94f57f9385db704eb54f0517982752065e7a Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Wed, 27 Nov 2024 13:23:03 -0500 Subject: [PATCH 10/30] using traditional fl algo --- src/configs/algo_config.py | 3 ++- src/scheduler.py | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/configs/algo_config.py b/src/configs/algo_config.py index f0f4976..a2242c6 100644 --- a/src/configs/algo_config.py +++ b/src/configs/algo_config.py @@ -31,7 +31,8 @@ def get_malicious_types(malicious_config_list: List[ConfigType]) -> Dict[str, st traditional_fl: ConfigType = { # Collaboration setup "algo": "fedavg", - "rounds": 5, + "rounds": 2, + # Model parameters "model": "resnet10", "model_lr": 3e-4, diff --git a/src/scheduler.py b/src/scheduler.py index 6094921..23cc327 100644 --- a/src/scheduler.py +++ b/src/scheduler.py @@ -109,7 +109,6 @@ def initialize(self, copy_souce_code: bool = True) -> None: numpy.random.seed(seed) self.merge_configs() if self.communication.get_rank() == 0: - print("initializing super node") if copy_souce_code: copy_source_code(self.config) else: From af0c113949c83ce7f511e996d972d047e69ecbd7 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Wed, 27 Nov 2024 13:25:41 -0500 Subject: [PATCH 11/30] run test only during push to main --- .github/workflows/train.yml | 32 +++++++------------------------- 1 file changed, 7 insertions(+), 25 deletions(-) diff --git a/.github/workflows/train.yml b/.github/workflows/train.yml index 37e16f8..a9e67ec 100644 --- a/.github/workflows/train.yml +++ b/.github/workflows/train.yml @@ -4,8 +4,8 @@ on: workflow_dispatch: push: branches: - # - main - - "*" + - main + # - "*" pull_request: branches: - main @@ -22,20 +22,6 @@ jobs: - name: Checkout repository uses: actions/checkout@v3 - # - name: check directories - # run: | - # cd src - # DIR="../../../../../../../home" - # if [ -d "$DIR" ]; then - # ### Take action if $DIR exists ### - # echo "Installing config files in ${DIR}" - # exit 1 - # else - # ### Control will jump here if $DIR does NOT exists ### - # echo "Error: ${DIR} not found. Can not continue." - # exit 1 - # fi - # Step 2: Set up Python - name: Set up Python uses: actions/setup-python@v4 @@ -63,12 +49,8 @@ jobs: python main.py -super true -s "./configs/sys_config_test.py" echo "done" - - name: Clean up - run: | - rm -rf ./sonar_experiments/ - - # further checks: - # only 5 rounds - # gRPC only? or also MPI? - # num of samples - # num users and nodes + # further checks: + # only 5 rounds + # gRPC only? or also MPI? + # num of samples + # num users and nodes From fa1cd06ed7fdeebc9b6507faefa8387fd6ba9e7c Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Wed, 27 Nov 2024 13:30:31 -0500 Subject: [PATCH 12/30] new dump_dir --- .github/workflows/train.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/train.yml b/.github/workflows/train.yml index a9e67ec..3d7e86f 100644 --- a/.github/workflows/train.yml +++ b/.github/workflows/train.yml @@ -4,8 +4,8 @@ on: workflow_dispatch: push: branches: - - main - # - "*" + # - main + - "*" pull_request: branches: - main From 6963564ba5e153afc87e4ce7ac6e4d1a6718dad4 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Fri, 29 Nov 2024 20:15:49 -0500 Subject: [PATCH 13/30] remove send_status from proto --- src/utils/communication/grpc/comm.proto | 5 --- src/utils/communication/grpc/comm_pb2.py | 44 +++++++++---------- src/utils/communication/grpc/comm_pb2_grpc.py | 43 ------------------ 3 files changed, 21 insertions(+), 71 deletions(-) diff --git a/src/utils/communication/grpc/comm.proto b/src/utils/communication/grpc/comm.proto index c69ade5..8f689c3 100644 --- a/src/utils/communication/grpc/comm.proto +++ b/src/utils/communication/grpc/comm.proto @@ -3,7 +3,6 @@ syntax = "proto3"; service CommunicationServer { - rpc send_status(Empty) returns (Status) {} rpc send_data (Data) returns (Empty) {} rpc send_model (Model) returns (Empty) {} rpc get_rank (Empty) returns (Rank) {} @@ -17,10 +16,6 @@ service CommunicationServer { message Empty {} -message Status{ - string message = 1; -} - message Model { bytes buffer = 1; } diff --git a/src/utils/communication/grpc/comm_pb2.py b/src/utils/communication/grpc/comm_pb2.py index 4393839..a9b03ce 100644 --- a/src/utils/communication/grpc/comm_pb2.py +++ b/src/utils/communication/grpc/comm_pb2.py @@ -14,7 +14,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ncomm.proto\"\x07\n\x05\x45mpty\"\x19\n\x06Status\x12\x0f\n\x07message\x18\x01 \x01(\t\"\x17\n\x05Model\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\")\n\x04\x44\x61ta\x12\n\n\x02id\x18\x01 \x01(\t\x12\x15\n\x05model\x18\x02 \x01(\x0b\x32\x06.Model\"\x14\n\x04Rank\x12\x0c\n\x04rank\x18\x01 \x01(\x05\"\x16\n\x05Round\x12\r\n\x05round\x18\x01 \x01(\x05\"\x14\n\x04Port\x12\x0c\n\x04port\x18\x01 \x01(\x05\">\n\x06PeerId\x12\x13\n\x04rank\x18\x01 \x01(\x0b\x32\x05.Rank\x12\x13\n\x04port\x18\x02 \x01(\x0b\x32\x05.Port\x12\n\n\x02ip\x18\x03 \x01(\t\"k\n\x07PeerIds\x12\'\n\x08peer_ids\x18\x01 \x03(\x0b\x32\x15.PeerIds.PeerIdsEntry\x1a\x37\n\x0cPeerIdsEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\x16\n\x05value\x18\x02 \x01(\x0b\x32\x07.PeerId:\x02\x38\x01\"\x18\n\x06Quorum\x12\x0e\n\x06quorum\x18\x01 \x01(\x08\x32\xe3\x02\n\x13\x43ommunicationServer\x12 \n\x0bsend_status\x12\x06.Empty\x1a\x07.Status\"\x00\x12\x1c\n\tsend_data\x12\x05.Data\x1a\x06.Empty\"\x00\x12\x1e\n\nsend_model\x12\x06.Model\x1a\x06.Empty\"\x00\x12\x1b\n\x08get_rank\x12\x06.Empty\x1a\x05.Rank\"\x00\x12\x1d\n\tget_model\x12\x06.Empty\x1a\x06.Model\"\x00\x12%\n\x11get_current_round\x12\x06.Empty\x1a\x06.Round\"\x00\x12 \n\x0bupdate_port\x12\x07.PeerId\x1a\x06.Empty\"\x00\x12#\n\rsend_peer_ids\x12\x08.PeerIds\x1a\x06.Empty\"\x00\x12 \n\x0bsend_quorum\x12\x07.Quorum\x1a\x06.Empty\"\x00\x12 \n\rsend_finished\x12\x05.Rank\x1a\x06.Empty\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\ncomm.proto\"\x07\n\x05\x45mpty\"\x17\n\x05Model\x12\x0e\n\x06\x62uffer\x18\x01 \x01(\x0c\")\n\x04\x44\x61ta\x12\n\n\x02id\x18\x01 \x01(\t\x12\x15\n\x05model\x18\x02 \x01(\x0b\x32\x06.Model\"\x14\n\x04Rank\x12\x0c\n\x04rank\x18\x01 \x01(\x05\"\x16\n\x05Round\x12\r\n\x05round\x18\x01 \x01(\x05\"\x14\n\x04Port\x12\x0c\n\x04port\x18\x01 \x01(\x05\">\n\x06PeerId\x12\x13\n\x04rank\x18\x01 \x01(\x0b\x32\x05.Rank\x12\x13\n\x04port\x18\x02 \x01(\x0b\x32\x05.Port\x12\n\n\x02ip\x18\x03 \x01(\t\"k\n\x07PeerIds\x12\'\n\x08peer_ids\x18\x01 \x03(\x0b\x32\x15.PeerIds.PeerIdsEntry\x1a\x37\n\x0cPeerIdsEntry\x12\x0b\n\x03key\x18\x01 \x01(\x05\x12\x16\n\x05value\x18\x02 \x01(\x0b\x32\x07.PeerId:\x02\x38\x01\"\x18\n\x06Quorum\x12\x0e\n\x06quorum\x18\x01 \x01(\x08\x32\xc1\x02\n\x13\x43ommunicationServer\x12\x1c\n\tsend_data\x12\x05.Data\x1a\x06.Empty\"\x00\x12\x1e\n\nsend_model\x12\x06.Model\x1a\x06.Empty\"\x00\x12\x1b\n\x08get_rank\x12\x06.Empty\x1a\x05.Rank\"\x00\x12\x1d\n\tget_model\x12\x06.Empty\x1a\x06.Model\"\x00\x12%\n\x11get_current_round\x12\x06.Empty\x1a\x06.Round\"\x00\x12 \n\x0bupdate_port\x12\x07.PeerId\x1a\x06.Empty\"\x00\x12#\n\rsend_peer_ids\x12\x08.PeerIds\x1a\x06.Empty\"\x00\x12 \n\x0bsend_quorum\x12\x07.Quorum\x1a\x06.Empty\"\x00\x12 \n\rsend_finished\x12\x05.Rank\x1a\x06.Empty\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -25,26 +25,24 @@ _globals['_PEERIDS_PEERIDSENTRY']._serialized_options = b'8\001' _globals['_EMPTY']._serialized_start=14 _globals['_EMPTY']._serialized_end=21 - _globals['_STATUS']._serialized_start=23 - _globals['_STATUS']._serialized_end=48 - _globals['_MODEL']._serialized_start=50 - _globals['_MODEL']._serialized_end=73 - _globals['_DATA']._serialized_start=75 - _globals['_DATA']._serialized_end=116 - _globals['_RANK']._serialized_start=118 - _globals['_RANK']._serialized_end=138 - _globals['_ROUND']._serialized_start=140 - _globals['_ROUND']._serialized_end=162 - _globals['_PORT']._serialized_start=164 - _globals['_PORT']._serialized_end=184 - _globals['_PEERID']._serialized_start=186 - _globals['_PEERID']._serialized_end=248 - _globals['_PEERIDS']._serialized_start=250 - _globals['_PEERIDS']._serialized_end=357 - _globals['_PEERIDS_PEERIDSENTRY']._serialized_start=302 - _globals['_PEERIDS_PEERIDSENTRY']._serialized_end=357 - _globals['_QUORUM']._serialized_start=359 - _globals['_QUORUM']._serialized_end=383 - _globals['_COMMUNICATIONSERVER']._serialized_start=386 - _globals['_COMMUNICATIONSERVER']._serialized_end=741 + _globals['_MODEL']._serialized_start=23 + _globals['_MODEL']._serialized_end=46 + _globals['_DATA']._serialized_start=48 + _globals['_DATA']._serialized_end=89 + _globals['_RANK']._serialized_start=91 + _globals['_RANK']._serialized_end=111 + _globals['_ROUND']._serialized_start=113 + _globals['_ROUND']._serialized_end=135 + _globals['_PORT']._serialized_start=137 + _globals['_PORT']._serialized_end=157 + _globals['_PEERID']._serialized_start=159 + _globals['_PEERID']._serialized_end=221 + _globals['_PEERIDS']._serialized_start=223 + _globals['_PEERIDS']._serialized_end=330 + _globals['_PEERIDS_PEERIDSENTRY']._serialized_start=275 + _globals['_PEERIDS_PEERIDSENTRY']._serialized_end=330 + _globals['_QUORUM']._serialized_start=332 + _globals['_QUORUM']._serialized_end=356 + _globals['_COMMUNICATIONSERVER']._serialized_start=359 + _globals['_COMMUNICATIONSERVER']._serialized_end=680 # @@protoc_insertion_point(module_scope) diff --git a/src/utils/communication/grpc/comm_pb2_grpc.py b/src/utils/communication/grpc/comm_pb2_grpc.py index ea45534..e0258f7 100644 --- a/src/utils/communication/grpc/comm_pb2_grpc.py +++ b/src/utils/communication/grpc/comm_pb2_grpc.py @@ -39,11 +39,6 @@ def __init__(self, channel): Args: channel: A grpc.Channel. """ - self.send_status = channel.unary_unary( - '/CommunicationServer/send_status', - request_serializer=comm__pb2.Empty.SerializeToString, - response_deserializer=comm__pb2.Status.FromString, - _registered_method=True) self.send_data = channel.unary_unary( '/CommunicationServer/send_data', request_serializer=comm__pb2.Data.SerializeToString, @@ -94,12 +89,6 @@ def __init__(self, channel): class CommunicationServerServicer(object): """Missing associated documentation comment in .proto file.""" - def send_status(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details('Method not implemented!') - raise NotImplementedError('Method not implemented!') - def send_data(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -157,11 +146,6 @@ def send_finished(self, request, context): def add_CommunicationServerServicer_to_server(servicer, server): rpc_method_handlers = { - 'send_status': grpc.unary_unary_rpc_method_handler( - servicer.send_status, - request_deserializer=comm__pb2.Empty.FromString, - response_serializer=comm__pb2.Status.SerializeToString, - ), 'send_data': grpc.unary_unary_rpc_method_handler( servicer.send_data, request_deserializer=comm__pb2.Data.FromString, @@ -218,33 +202,6 @@ def add_CommunicationServerServicer_to_server(servicer, server): class CommunicationServer(object): """Missing associated documentation comment in .proto file.""" - @staticmethod - def send_status(request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None): - return grpc.experimental.unary_unary( - request, - target, - '/CommunicationServer/send_status', - comm__pb2.Empty.SerializeToString, - comm__pb2.Status.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - _registered_method=True) - @staticmethod def send_data(request, target, From 2377619cdb3b6218f69e4080f33fcbc3b85d6b85 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Sun, 1 Dec 2024 11:03:33 -0500 Subject: [PATCH 14/30] changed dump_dir --- src/configs/sys_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/configs/sys_config.py b/src/configs/sys_config.py index 91c6962..c0b4031 100644 --- a/src/configs/sys_config.py +++ b/src/configs/sys_config.py @@ -170,7 +170,7 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): "dset": CIFAR10_DSET, "dump_dir": DUMP_DIR, "dpath": CIAR10_DPATH, - # "seed": 32, + "seed": 32, "seed": 2, # node_0 is a server currently # The device_ids dictionary depicts the GPUs on which the nodes reside. From c776735712852cfdba24ea0a874cc8adeb9fe8bb Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Sun, 1 Dec 2024 11:06:40 -0500 Subject: [PATCH 15/30] small changes --- src/configs/sys_config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/configs/sys_config.py b/src/configs/sys_config.py index c0b4031..2e7e043 100644 --- a/src/configs/sys_config.py +++ b/src/configs/sys_config.py @@ -171,7 +171,6 @@ def get_digit_five_support(num_users: int, domains: List[str] = DIGIT_FIVE): "dump_dir": DUMP_DIR, "dpath": CIAR10_DPATH, "seed": 32, - "seed": 2, # node_0 is a server currently # The device_ids dictionary depicts the GPUs on which the nodes reside. # For a single-GPU environment, the config will look as follows (as it follows a 0-based indexing): From 1dc26100e50cc3009212488f09f03f400e0287fc Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Sun, 1 Dec 2024 13:33:19 -0500 Subject: [PATCH 16/30] inversefed data unhidden --- .gitignore | 1 + src/inversefed/data/README.md | 3 ++ src/inversefed/data/__init__.py | 6 +++ src/inversefed/data/data.py | 96 +++++++++++++++++++++++++++++++++ src/inversefed/data/datasets.py | 62 +++++++++++++++++++++ 5 files changed, 168 insertions(+) create mode 100644 src/inversefed/data/README.md diff --git a/.gitignore b/.gitignore index e03029f..5124753 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ di_test/ imgs/ pascal/ data/ +!src/inversefed/data/ notes.txt removeme*.png diff --git a/src/inversefed/data/README.md b/src/inversefed/data/README.md new file mode 100644 index 0000000..0cf811e --- /dev/null +++ b/src/inversefed/data/README.md @@ -0,0 +1,3 @@ +# Data Processing + +This module implements ```construct_dataloaders```. \ No newline at end of file diff --git a/src/inversefed/data/__init__.py b/src/inversefed/data/__init__.py index e69de29..be87c57 100644 --- a/src/inversefed/data/__init__.py +++ b/src/inversefed/data/__init__.py @@ -0,0 +1,6 @@ +"""Data stuff that I usually don't want to see.""" + +from .data_processing import construct_dataloaders + + +__all__ = ['construct_dataloaders'] diff --git a/src/inversefed/data/data.py b/src/inversefed/data/data.py index e69de29..d2d7cd6 100644 --- a/src/inversefed/data/data.py +++ b/src/inversefed/data/data.py @@ -0,0 +1,96 @@ +"""This is data.py from pytorch-examples. + +Refer to +https://github.com/pytorch/examples/blob/master/super_resolution/data.py. +""" + +from os.path import exists, join, basename +from os import makedirs, remove +from six.moves import urllib +import tarfile +from torchvision.transforms import Compose, CenterCrop, ToTensor, Resize, RandomCrop + + +from .datasets import DatasetFromFolder + +def _build_bsds_sr(data_path, augmentations=True, normalize=True, upscale_factor=3, RGB=True): + root_dir = _download_bsd300(dest=data_path) + train_dir = join(root_dir, "train") + crop_size = _calculate_valid_crop_size(256, upscale_factor) + print(f'Crop size is {crop_size}. Upscaling factor is {upscale_factor} in mode {RGB}.') + + trainset = DatasetFromFolder(train_dir, replicate=200, + input_transform=_input_transform(crop_size, upscale_factor), + target_transform=_target_transform(crop_size), RGB=RGB) + + test_dir = join(root_dir, "test") + validset = DatasetFromFolder(test_dir, replicate=200, + input_transform=_input_transform(crop_size, upscale_factor), + target_transform=_target_transform(crop_size), RGB=RGB) + return trainset, validset + +def _build_bsds_dn(data_path, augmentations=True, normalize=True, upscale_factor=1, noise_level=25 / 255, RGB=True): + root_dir = _download_bsd300(dest=data_path) + train_dir = join(root_dir, "train") + + crop_size = _calculate_valid_crop_size(256, upscale_factor) + patch_size = 64 + print(f'Crop size is {crop_size} for patches of size {patch_size}. ' + f'Upscaling factor is {upscale_factor} in mode RGB={RGB}.') + + trainset = DatasetFromFolder(train_dir, replicate=200, + input_transform=_input_transform(crop_size, upscale_factor, patch_size=patch_size), + target_transform=_target_transform(crop_size, patch_size=patch_size), + noise_level=noise_level, RGB=RGB) + + test_dir = join(root_dir, "test") + validset = DatasetFromFolder(test_dir, replicate=200, + input_transform=_input_transform(crop_size, upscale_factor), + target_transform=_target_transform(crop_size), + noise_level=noise_level, RGB=RGB) + return trainset, validset + + +def _download_bsd300(dest="dataset"): + output_image_dir = join(dest, "BSDS300/images") + + if not exists(output_image_dir): + makedirs(dest, exist_ok=True) + url = "http://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/BSDS300-images.tgz" + print("downloading url ", url) + + data = urllib.request.urlopen(url) + + file_path = join(dest, basename(url)) + with open(file_path, 'wb') as f: + f.write(data.read()) + + print("Extracting data") + with tarfile.open(file_path) as tar: + for item in tar: + tar.extract(item, dest) + + remove(file_path) + + return output_image_dir + + +def _calculate_valid_crop_size(crop_size, upscale_factor): + return crop_size - (crop_size % upscale_factor) + + +def _input_transform(crop_size, upscale_factor, patch_size=None): + return Compose([ + CenterCrop(crop_size), + Resize(crop_size // upscale_factor), + RandomCrop(patch_size if patch_size is not None else crop_size // upscale_factor), + ToTensor(), + ]) + + +def _target_transform(crop_size, patch_size=None): + return Compose([ + CenterCrop(crop_size), + RandomCrop(patch_size if patch_size is not None else crop_size), + ToTensor(), + ]) diff --git a/src/inversefed/data/datasets.py b/src/inversefed/data/datasets.py index e69de29..dd28cbf 100644 --- a/src/inversefed/data/datasets.py +++ b/src/inversefed/data/datasets.py @@ -0,0 +1,62 @@ +"""This is dataset.py from pytorch-examples. + +Refer to + +https://github.com/pytorch/examples/blob/master/super_resolution/dataset.py. +""" +import torch +import torch.utils.data as data + +from os import listdir +from os.path import join +from PIL import Image + + +def _is_image_file(filename): + return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"]) + + +def _load_img(filepath, RGB=True): + img = Image.open(filepath) + if RGB: + pass + else: + img = img.convert('YCbCr') + img, _, _ = img.split() + return img + + +class DatasetFromFolder(data.Dataset): + """Generate an image-to-image dataset from images from the given folder.""" + + def __init__(self, image_dir, replicate=1, input_transform=None, target_transform=None, RGB=True, noise_level=0.0): + """Init with directory, transforms and RGB switch.""" + super(DatasetFromFolder, self).__init__() + self.image_filenames = [join(image_dir, x) for x in listdir(image_dir) if _is_image_file(x)] + + self.input_transform = input_transform + self.target_transform = target_transform + + self.replicate = replicate + self.classes = [None] + self.RGB = RGB + self.noise_level = noise_level + + def __getitem__(self, index): + """Index into dataset.""" + input = _load_img(self.image_filenames[index % len(self.image_filenames)], RGB=self.RGB) + target = input.copy() + if self.input_transform: + input = self.input_transform(input) + if self.target_transform: + target = self.target_transform(target) + + if self.noise_level > 0: + # Add noise + input += self.noise_level * torch.randn_like(input) + + return input, target + + def __len__(self): + """Length is amount of files found.""" + return len(self.image_filenames) * self.replicate From 859c8c4299ede4be2741634ae10ef7e719643e26 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Sun, 1 Dec 2024 15:43:40 -0500 Subject: [PATCH 17/30] cpu requirements --- .github/workflows/train.yml | 4 +- requirements_cpu.txt | 161 +++++++++++++++++++++++++++++++++ src/configs/sys_config_test.py | 10 +- 3 files changed, 173 insertions(+), 2 deletions(-) create mode 100644 requirements_cpu.txt diff --git a/.github/workflows/train.yml b/.github/workflows/train.yml index 3d7e86f..e176347 100644 --- a/.github/workflows/train.yml +++ b/.github/workflows/train.yml @@ -16,6 +16,8 @@ env: jobs: train-check: runs-on: ubuntu-latest + env: + DEVICE: cpu steps: # Step 1: Checkout the code @@ -35,7 +37,7 @@ jobs: sudo apt install -y libopenmpi-dev openmpi-bin sudo apt-get install -y libgl1 libglib2.0-0 - pip install -r requirements.txt + pip install -r requirements_cpu.txt # Step 4: Run gRPC server and client - name: Run test diff --git a/requirements_cpu.txt b/requirements_cpu.txt new file mode 100644 index 0000000..db98dd9 --- /dev/null +++ b/requirements_cpu.txt @@ -0,0 +1,161 @@ +anyio==4.3.0 +argon2-cffi==23.1.0 +argon2-cffi-bindings==21.2.0 +arrow==1.3.0 +asttokens==2.4.1 +async-lru==2.0.4 +attrs==23.2.0 +Babel==2.15.0 +beautifulsoup4==4.12.3 +bleach==6.1.0 +certifi==2024.2.2 +cffi==1.16.0 +charset-normalizer==3.3.2 +click==8.1.7 +colorama==0.4.6 +comm==0.2.2 +contourpy==1.2.1 +cycler==0.12.1 +debugpy==1.8.1 +decorator==5.1.1 +defusedxml==0.7.1 +exceptiongroup==1.2.1 +executing==2.0.1 +fastjsonschema==2.19.1 +filelock==3.14.0 +fire==0.6.0 +fonttools==4.52.1 +fqdn==1.5.1 +fsspec==2024.5.0 +ghp-import==2.1.0 +grpcio==1.64.0 +grpcio-tools==1.64.0 +h11==0.14.0 +httpcore==1.0.5 +httpx==0.27.0 +idna==3.7 +imageio==2.34.1 +ipykernel==6.29.4 +ipython==8.24.0 +isoduration==20.11.0 +jedi==0.19.1 +Jinja2==3.1.4 +jmespath==1.0.1 +joblib==1.4.2 +json5==0.9.25 +jsonpointer==2.4 +jsonschema==4.22.0 +jsonschema-specifications==2023.12.1 +jupyter-events==0.10.0 +jupyter-lsp==2.2.5 +jupyter_client==8.6.2 +jupyter_core==5.7.2 +jupyter_server==2.14.0 +jupyter_server_terminals==0.5.3 +jupyterlab==4.2.5 +jupyterlab_pygments==0.3.0 +jupyterlab_server==2.27.2 +kiwisolver==1.4.5 +lazy_loader==0.4 +littleutils==0.2.2 +Markdown==3.7 +MarkupSafe==2.1.5 +matplotlib==3.9.0 +matplotlib-inline==0.1.7 +medmnist==3.0.1 +mergedeep==1.3.4 +mistune==3.0.2 +mkdocs==1.6.0 +mkdocs-get-deps==0.2.0 +mkdocs-material==9.5.31 +mkdocs-material-extensions==1.3.1 +mpi4py==3.1.6 +mpmath==1.3.0 +nbclient==0.10.0 +nbconvert==7.16.4 +nbformat==5.10.4 +nest-asyncio==1.6.0 +networkx==3.3 +notebook_shim==0.2.4 +numpy +# nvidia-cublas-cu12==12.1.3.1 +# nvidia-cuda-cupti-cu12==12.1.105 +# nvidia-cuda-nvrtc-cu12==12.1.105 +# nvidia-cuda-runtime-cu12==12.1.105 +# nvidia-cudnn-cu12==8.9.2.26 +# nvidia-cufft-cu12==11.0.2.54 +# nvidia-curand-cu12==10.3.2.106 +# nvidia-cusolver-cu12==11.4.5.107 +# nvidia-cusparse-cu12==12.1.0.106 +# nvidia-nccl-cu12==2.20.5 +# nvidia-nvjitlink-cu12==12.5.40 +# nvidia-nvtx-cu12==12.1.105 +ogb==1.3.6 +opencv-python==4.10.0.84 +outdated==0.2.2 +overrides==7.7.0 +packaging==24.0 +paginate==0.5.6 +pandas==2.2.2 +pandocfilters==1.5.1 +parso==0.8.4 +pathspec==0.12.1 +pexpect==4.9.0 +pillow==10.3.0 +platformdirs==4.2.2 +prometheus_client==0.20.0 +prompt-toolkit==3.0.43 +protobuf==5.26.1 +psutil==5.9.8 +ptyprocess==0.7.0 +pure-eval==0.2.2 +pycparser==2.22 +Pygments==2.18.0 +pymdown-extensions==10.9 +pyparsing==3.1.2 +python-dateutil==2.9.0.post0 +python-json-logger==2.0.7 +pytz==2024.1 +PyYAML==6.0.1 +pyyaml_env_tag==0.1 +pyzmq==26.0.3 +referencing==0.35.1 +regex==2024.7.24 +requests==2.32.2 +rfc3339-validator==0.1.4 +rfc3986-validator==0.1.1 +rpds-py==0.18.1 +scikit-image==0.23.2 +scikit-learn==1.5.0 +scipy==1.13.1 +Send2Trash==1.8.3 +six==1.16.0 +sniffio==1.3.1 +soupsieve==2.5 +stack-data==0.6.3 +sympy==1.12 +tensorboardX==2.6.2.2 +termcolor==2.4.0 +terminado==0.18.1 +threadpoolctl==3.5.0 +tifffile==2024.5.22 +tinycss2==1.3.0 +tomli==2.0.1 +torch @ https://download.pytorch.org/whl/cpu-cxx11-abi/torch-2.3.0%2Bcpu.cxx11.abi-cp310-cp310-linux_x86_64.whl#sha256=896e8a82f441ff8ae5c8acd2fddd48d9ae1f2738fb8c083b912debbbe09c553b +torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.18.0%2Bcpu-cp310-cp310-linux_x86_64.whl#sha256=6f285024f8a93598a67fbd2c86b534cd01236e8f6bda8704cd3459cc63d665a3 +tornado==6.4 +tqdm==4.66.4 +traitlets==5.14.3 +#triton +types-python-dateutil==2.9.0.20240316 +typing_extensions==4.12.0 +tzdata==2024.1 +uri-template==1.3.0 +urllib3==2.2.1 +Wand==0.6.13 +watchdog==4.0.2 +wcwidth==0.2.13 +webcolors==1.13 +webencodings==0.5.1 +websocket-client==1.8.0 +wilds==2.0.0 diff --git a/src/configs/sys_config_test.py b/src/configs/sys_config_test.py index f357541..4737d16 100644 --- a/src/configs/sys_config_test.py +++ b/src/configs/sys_config_test.py @@ -121,6 +121,14 @@ def get_algo_configs( "exp_keys": [], "dropout_dicts": dropout_dicts, "test_samples_per_user": 200, + "log_memory": True, + # "streaming_aggregation": True, # Make it true for fedstatic + "assign_based_on_host": True, + "hostname_to_device_ids": { + "matlaber1": [2, 3, 4, 5, 6, 7], + "matlaber12": [0, 1, 2, 3], + "matlaber3": [0, 1, 2, 3], + "matlaber4": [0, 2, 3, 4, 5, 6, 7], + } } - current_config = grpc_system_config \ No newline at end of file From c8bc6c63199fd40a66f0c04a076d5b408731bf2b Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Sun, 1 Dec 2024 15:49:25 -0500 Subject: [PATCH 18/30] test cpu torch --- .github/workflows/train.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/train.yml b/.github/workflows/train.yml index e176347..3e95519 100644 --- a/.github/workflows/train.yml +++ b/.github/workflows/train.yml @@ -38,6 +38,8 @@ jobs: sudo apt-get install -y libgl1 libglib2.0-0 pip install -r requirements_cpu.txt + python -c "import torch" + python -c "import torchvision" # Step 4: Run gRPC server and client - name: Run test From 40e48e0b41137730a919db75ad93d3128b4b6ca3 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Sun, 1 Dec 2024 15:57:39 -0500 Subject: [PATCH 19/30] test cpu torch --- requirements_cpu.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements_cpu.txt b/requirements_cpu.txt index db98dd9..977d993 100644 --- a/requirements_cpu.txt +++ b/requirements_cpu.txt @@ -141,8 +141,8 @@ threadpoolctl==3.5.0 tifffile==2024.5.22 tinycss2==1.3.0 tomli==2.0.1 -torch @ https://download.pytorch.org/whl/cpu-cxx11-abi/torch-2.3.0%2Bcpu.cxx11.abi-cp310-cp310-linux_x86_64.whl#sha256=896e8a82f441ff8ae5c8acd2fddd48d9ae1f2738fb8c083b912debbbe09c553b -torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.18.0%2Bcpu-cp310-cp310-linux_x86_64.whl#sha256=6f285024f8a93598a67fbd2c86b534cd01236e8f6bda8704cd3459cc63d665a3 +torch == 2.3.0 --index-url https://download.pytorch.org/whl/cpu +torchvision == 0.18.0 --index-url https://download.pytorch.org/whl/cpu tornado==6.4 tqdm==4.66.4 traitlets==5.14.3 From 32f75fcdd51b8f4e00fed478e735b0e751126455 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Sun, 1 Dec 2024 16:01:41 -0500 Subject: [PATCH 20/30] test cpu torch --- requirements_cpu.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements_cpu.txt b/requirements_cpu.txt index 977d993..e289174 100644 --- a/requirements_cpu.txt +++ b/requirements_cpu.txt @@ -141,8 +141,8 @@ threadpoolctl==3.5.0 tifffile==2024.5.22 tinycss2==1.3.0 tomli==2.0.1 -torch == 2.3.0 --index-url https://download.pytorch.org/whl/cpu -torchvision == 0.18.0 --index-url https://download.pytorch.org/whl/cpu +torch == 2.3.0+cpu +torchvision == 0.18.0+cpu tornado==6.4 tqdm==4.66.4 traitlets==5.14.3 From ff268abb58edd579804cdb50f4a6a50d77e732d7 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Sun, 1 Dec 2024 16:02:53 -0500 Subject: [PATCH 21/30] test cpu torch --- requirements_cpu.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements_cpu.txt b/requirements_cpu.txt index e289174..faa0c5c 100644 --- a/requirements_cpu.txt +++ b/requirements_cpu.txt @@ -141,8 +141,8 @@ threadpoolctl==3.5.0 tifffile==2024.5.22 tinycss2==1.3.0 tomli==2.0.1 -torch == 2.3.0+cpu -torchvision == 0.18.0+cpu +torch @ https://download.pytorch.org/whl/cpu-cxx11-abi/torch-2.3.0%2Bcpu.cxx11.abi-cp310-cp310-linux_x86_64.whl +torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.18.0%2Bcpu-cp310-cp310-linux_x86_64.whl tornado==6.4 tqdm==4.66.4 traitlets==5.14.3 From 95105ae616982aa8b85bc9121395d35e5b3ed0e6 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Sun, 1 Dec 2024 16:07:23 -0500 Subject: [PATCH 22/30] test cpu torch --- .github/workflows/train.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/train.yml b/.github/workflows/train.yml index 3e95519..6d27dac 100644 --- a/.github/workflows/train.yml +++ b/.github/workflows/train.yml @@ -38,8 +38,8 @@ jobs: sudo apt-get install -y libgl1 libglib2.0-0 pip install -r requirements_cpu.txt - python -c "import torch" - python -c "import torchvision" + python -c "import torch; print(torch.__version__)" + python -c "import torchvision; print(torchvision.__version__)" # Step 4: Run gRPC server and client - name: Run test From 3f2a673c2e6b994cb7654260c201f30667c21e91 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Sun, 1 Dec 2024 16:42:14 -0500 Subject: [PATCH 23/30] test cpu torch --- requirements_cpu.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements_cpu.txt b/requirements_cpu.txt index faa0c5c..928c7de 100644 --- a/requirements_cpu.txt +++ b/requirements_cpu.txt @@ -141,7 +141,7 @@ threadpoolctl==3.5.0 tifffile==2024.5.22 tinycss2==1.3.0 tomli==2.0.1 -torch @ https://download.pytorch.org/whl/cpu-cxx11-abi/torch-2.3.0%2Bcpu.cxx11.abi-cp310-cp310-linux_x86_64.whl +torch @ https://download.pytorch.org/whl/cpu/torch-2.3.0%2Bcpu-cp310-cp310-linux_x86_64.whl torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.18.0%2Bcpu-cp310-cp310-linux_x86_64.whl tornado==6.4 tqdm==4.66.4 From 8e575aaa1298a989771bbcccf115719d381439ef Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Sun, 1 Dec 2024 16:46:32 -0500 Subject: [PATCH 24/30] test cpu torch --- .github/workflows/train.yml | 2 -- src/algos/fl.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/.github/workflows/train.yml b/.github/workflows/train.yml index 6d27dac..e176347 100644 --- a/.github/workflows/train.yml +++ b/.github/workflows/train.yml @@ -38,8 +38,6 @@ jobs: sudo apt-get install -y libgl1 libglib2.0-0 pip install -r requirements_cpu.txt - python -c "import torch; print(torch.__version__)" - python -c "import torchvision; print(torchvision.__version__)" # Step 4: Run gRPC server and client - name: Run test diff --git a/src/algos/fl.py b/src/algos/fl.py index 721912f..3cf86ba 100644 --- a/src/algos/fl.py +++ b/src/algos/fl.py @@ -243,7 +243,7 @@ def run_protocol(self): self.round_init() self.local_round_done() - self.single_round() + self.single_round(round) self.test() self.round_finalize() From d55dc125cf2c89ad8a86d7f2237bdedf72a29747 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Sun, 1 Dec 2024 16:51:40 -0500 Subject: [PATCH 25/30] test cpu torch --- src/algos/fl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/algos/fl.py b/src/algos/fl.py index 3cf86ba..18c0bab 100644 --- a/src/algos/fl.py +++ b/src/algos/fl.py @@ -216,7 +216,7 @@ def receive_and_aggregate(self): avg_wts = self.aggregate(reprs) self.set_representation(avg_wts) - def single_round(self, round: int, attack_start_round: int = 0, attack_end_round: int = 1): + def single_round(self, round: int, attack_start_round: int = 0, attack_end_round: int = 1, dump_file_name=""): """ Runs the whole training procedure. From ab2fcc6ec6d9564073d5d0e0e0d0d754e616f679 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Sun, 1 Dec 2024 16:57:33 -0500 Subject: [PATCH 26/30] requirements-cpu works --- src/algos/fl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/algos/fl.py b/src/algos/fl.py index 18c0bab..3cf86ba 100644 --- a/src/algos/fl.py +++ b/src/algos/fl.py @@ -216,7 +216,7 @@ def receive_and_aggregate(self): avg_wts = self.aggregate(reprs) self.set_representation(avg_wts) - def single_round(self, round: int, attack_start_round: int = 0, attack_end_round: int = 1, dump_file_name=""): + def single_round(self, round: int, attack_start_round: int = 0, attack_end_round: int = 1): """ Runs the whole training procedure. From c2ad00e37c148444e413d164c4e8569e43bb5075 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Mon, 2 Dec 2024 17:17:04 -0500 Subject: [PATCH 27/30] changed check for attacks --- src/algos/fl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/algos/fl.py b/src/algos/fl.py index 3cf86ba..08dafb2 100644 --- a/src/algos/fl.py +++ b/src/algos/fl.py @@ -227,7 +227,7 @@ def single_round(self, round: int, attack_start_round: int = 0, attack_end_round """ # Determine if the attack should be performed - attack_in_progress = self.gia_attacker and attack_start_round <= round <= attack_end_round + attack_in_progress = hasattr(FedAvgServer, "gia_attacker") and attack_start_round <= round <= attack_end_round if attack_in_progress: self.receive_attack_and_aggregate(round, attack_start_round, attack_end_round) From 8821d782d3efb49bcd617125aea482166c2e61e5 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Mon, 2 Dec 2024 17:38:40 -0500 Subject: [PATCH 28/30] quorum being sent --- src/scheduler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/scheduler.py b/src/scheduler.py index 23cc327..36aa1ae 100644 --- a/src/scheduler.py +++ b/src/scheduler.py @@ -130,6 +130,8 @@ def initialize(self, copy_souce_code: bool = True) -> None: comm_utils=self.communication, ) + self.communication.send_quorum() + def run_job(self) -> None: self.node.run_protocol() self.communication.finalize() From 2f0871a2299a2d05b1957782157540c006ed3909 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Mon, 2 Dec 2024 17:53:45 -0500 Subject: [PATCH 29/30] removed debugging in git workflow --- .github/workflows/train.yml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/train.yml b/.github/workflows/train.yml index e176347..706211e 100644 --- a/.github/workflows/train.yml +++ b/.github/workflows/train.yml @@ -1,11 +1,13 @@ name: Test Training Code with gRPC on: - workflow_dispatch: + # used for debugging purposes + # workflow_dispatch: push: branches: - # - main - - "*" + # run test on push to main only + - main + # - "*" pull_request: branches: - main From 4bc63e53541c384b1d89866761040e30d6550a61 Mon Sep 17 00:00:00 2001 From: Kathryn Le Date: Mon, 2 Dec 2024 17:57:25 -0500 Subject: [PATCH 30/30] removed extra print statements --- src/utils/communication/mpi.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/utils/communication/mpi.py b/src/utils/communication/mpi.py index bd565f8..0a9216c 100644 --- a/src/utils/communication/mpi.py +++ b/src/utils/communication/mpi.py @@ -153,8 +153,6 @@ def all_gather(self): """ items: List[Any] = [] for i in range(1, self.size): - print(f"receiving this data: {self.receive(i)}") - print(f"receiving this data: {self.receive(i)}") items.append(self.receive(i)) return items