diff --git a/socketcomms/comms.py b/socketcomms/comms.py index 24876f4..bc79bd0 100644 --- a/socketcomms/comms.py +++ b/socketcomms/comms.py @@ -9,6 +9,7 @@ from enum import Enum import ipaddress +import struct import time import datetime import socket,pickle,select @@ -21,11 +22,11 @@ # # ----------------------------------------------------------------------------- - class BaseCommPoint: """ Communication point. """ + _LEN_STRUCT = struct.Struct("!Q") # 8-byte unsigned length, network byte order class Kind(Enum): """ @@ -87,6 +88,28 @@ def setDebug(self,st:bool = True): Enable or disable debug messages. """ self._debug = st + + def _recv_exact(self, nbytes: int) -> bytes: + """Receive exactly nbytes from the socket. + + Args: + nbytes (int): Number of bytes to read. + + Returns: + bytes: The received bytes. + + Raises: + RuntimeError: If the connection is closed before receiving enough data. + """ + chunks = [] + remaining = nbytes + while remaining > 0: + chunk = self._sock.recv(min(self._datachunkmaxsize, remaining)) + if chunk == b"": + raise RuntimeError("Connection closed while receiving") + chunks.append(chunk) + remaining -= len(chunk) + return b"".join(chunks) def sendData(self, data: Dict) -> str: """ @@ -95,11 +118,19 @@ def sendData(self, data: Dict) -> str: """ if not self._begun: raise RuntimeError("Cannot send data in not-begun commpoint") - mydictser = pickle.dumps(data) + try: + payload = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL) + + # Prefix with payload length to delimit messages over TCP stream. + header = BaseCommPoint._LEN_STRUCT.pack(len(payload)) + if self._debug: - self._printInfo("Sending " + str(len(mydictser)) + " bytes...") - self._sock.send(mydictser) + self._printInfo(f"Sending framed msg: header={len(header)} bytes, payload={len(payload)} bytes...") + + # sendall guarantees full transmission. + self._sock.sendall(header) + self._sock.sendall(payload) if self._debug: self._printInfo("\tSent ok.") return "" @@ -108,28 +139,35 @@ def sendData(self, data: Dict) -> str: def readData(self, timeout: float = 2.0) -> Tuple[str, Dict]: """ - Read the data (blocking if timeout > 0.0) from the other side. - Return non-empty string if any error in the connection (connection closed, timeout in receiving, user interrupt, etc.) + Read one message (blocking if timeout > 0.0) from the other side. + Return non-empty string if any error in the connection. """ if not self._begun: raise RuntimeError("Cannot send data in not-begun commpoint") + if timeout <= 0.0: timeout = None - self._sock.settimeout(timeout) # after this, we assume the other side has shut down + self._sock.settimeout(timeout) + try: if self._debug: self._printInfo("Receiving...") - data = self._sock.recv(self._datachunkmaxsize) - if data == b'': - raise(RuntimeError("Connection closed while receiving")) - result = pickle.loads(data) + + # Read framed header (8 bytes) then payload. + header = self._recv_exact(BaseCommPoint._LEN_STRUCT.size) + (msg_len,) = BaseCommPoint._LEN_STRUCT.unpack(header) + + payload = self._recv_exact(msg_len) + result = pickle.loads(payload) + if self._debug: - self._printInfo("\tReceived " + str(len(data)) + " bytes.") + self._printInfo(f"\tReceived framed msg: payload={msg_len} bytes.") res = "" except Exception as e: result = None res = str(e) - self._sock.settimeout(None) # to deactivate timeout in other operations + + self._sock.settimeout(None) return res, result def checkDataToRead(self): @@ -200,6 +238,7 @@ def begin(self,timeoutaccept: float) -> str: self._basesock.settimeout(timeoutaccept) # after this, we assume the other side has shut down try: self._sock, _ = self._basesock.accept() # wait for calling us + self._sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) # disable Nagle's algorithm to reduce latency spikes self._begun = True self._basesock.settimeout(None) # to deactivate timeout in other operations return "" @@ -250,6 +289,7 @@ def begin(self) -> str: self.end() try: self._sock = socket.socket(socket.AF_INET,socket.SOCK_STREAM) # 1st arg: ip4, 2nd: TCP + self._sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) # disable Nagle's algorithm to reduce latency spikes self._sock.connect((self._ipv4,self._port)) # if bind-listen has been done on the other side but accept has not, ends immediately even when the server is not accpeting at the time (connection is kept pending), and data can be sent; if bind-listen has not been done on the other side, an error is raised self._begun = True return "" diff --git a/spindecoupler.py b/spindecoupler.py index f639872..f6547c6 100644 --- a/spindecoupler.py +++ b/spindecoupler.py @@ -66,12 +66,12 @@ def resetGetObs(self, timeout: float = 10.0): if len(res) > 0: raise RuntimeError("Error sending what to do to the agent. " + res) - res,obsato = self._rlcomm.readData(timeout) + res,obs = self._rlcomm.readData(timeout) if len(res) > 0: raise RuntimeError("Error reading after-reset observation from " "the agent. " + res) - return obsato["obs"], obsato["ato"] # return tuple + return obs["obs"], obs["ato"], obs["info"] # return observation + ato + extra info def stepSendActGetObs(self, action,timeout:float = 10.0): @@ -103,7 +103,7 @@ def stepSendActGetObs(self, action,timeout:float = 10.0): if len(res) > 0: raise RuntimeError("Error receiving step observation: " + res) - return lat["lat"], obsrewato["obs"], obsrewato["rew"], obsrewato["ato"] + return lat["lat_sim"], lat["lat_wall"], obsrewato["obs"], obsrewato["rew"], obsrewato["ato"] def stepExpFinished(self, timeout:float = 10.0): @@ -115,8 +115,8 @@ def stepExpFinished(self, timeout:float = 10.0): """ self._rlcomm.sendData(dict({"stepkind": "finish"})) - - + + #------------------------------------------------------------------------------- # # Base Class: AgentSide @@ -219,7 +219,7 @@ def readWhatToDo(self, timeout:float = 10.0): raise(ValueError("Unknown what-to-do indicator [" + ind["stepkind"] + "]")) - def stepSendLastActDur(self, lat:float): + def stepSendLastActDur(self, lat_sim:float, lat_wall:float): """ Call this method after receiving a REC_ACTION_SEND_OBS and starting the action, being LAT the actual time during which the action previous to @@ -228,7 +228,7 @@ def stepSendLastActDur(self, lat:float): This method can raise RuntimeError if any error occurs in comms. """ - res = self._rlcomm.sendData(dict({"lat": lat})) + res = self._rlcomm.sendData(dict({"lat_sim": lat_sim, "lat_wall":lat_wall})) if len(res) > 0: raise RuntimeError("Error sending lat to RL. " + res) @@ -249,7 +249,7 @@ def stepSendObs(self, obs, agenttime:float = 0.0, rew:float = 0.0): raise RuntimeError("Error sending observation/reward to RL. " + res) - def resetSendObs(self,obs,agenttime = 0.0): + def resetSendObs(self,obs,agenttime = 0.0, extra_info = {}): """ Call this method if readWhatToDo() returned RESET_SEND_OBS to send back the first observation (OBS, a dictionary) got after an episode reset, @@ -257,8 +257,175 @@ def resetSendObs(self,obs,agenttime = 0.0): This method can raise RuntimeError if any error occurs in comms. """ - res = self._rlcomm.sendData({"obs":obs,"ato":agenttime}) + res = self._rlcomm.sendData({"obs":obs,"ato":agenttime, "info": extra_info}) if len(res) > 0: raise RuntimeError("Error sending observation to RL. " + res) - - + + + +#------------------------------------------------------------------------------- +# +# Base Class: RLSideQuery +# +#------------------------------------------------------------------------------- + + +class RLSideQuery: + """ + Just answers queries from the agent. + + """ + def __init__(self, port: int, verbose: bool = False): + self._verbose = verbose + self._srv = ServerCommPoint(port) + if self._verbose: + print(f"[Query Server] Waiting for agent query connection on port {port}...") + res = self._srv.begin(timeoutaccept=60.0) + if len(res) > 0: + raise RuntimeError("No agent connection for query channel: " + res) + if self._verbose: + print("[Query Server] Agent connected for queries.") + + def __del__(self): + res = self._srv.end() + if len(res) > 0: + print("Error closing query channel (RL side): " + res) + if self._verbose: + print("Communications closed in the RL side.") + + def reconnect(self, timeoutaccept: float = 60.0): + """ + Closes the current connection and waits for a new agent connection. + """ + res = self._srv.end() + if len(res) > 0: + raise RuntimeError("Error closing query channel (RL side): " + res) + if self._verbose: + print("[Query Server] Waiting for agent query reconnection...") + res = self._srv.begin(timeoutaccept=timeoutaccept) + if len(res) > 0: + raise RuntimeError("No agent reconnection for query channel: " + res) + if self._verbose: + print("[Query Server] Agent reconnected for queries.") + + + def receive_query(self, timeout: float = -1.0): + """ + Blocks until receiving a new query or until timeout ends (if >=0). + Returns the received dict (e.g.: {"stepkind":"query","obs":{...}}). + """ + res, msg = self._srv.readData(timeout) + if len(res) > 0: + raise RuntimeError("Error reading query from agent: " + res) + return msg + + + def send_action(self, action_dict): + """ + Sends the predicted action to the agent as a dictionary. + """ + res = self._srv.sendData({"action": action_dict}) + if len(res) > 0: + raise RuntimeError("Error sending action to agent (query channel): " + res) + + + def wait_for_query(self, timeout: float = -1.0) -> bool: + """Block until a 'query' flag arrives (or timeout expires). + + Args: + timeout (float): Maximum time in seconds to wait for data. + A negative value means "block indefinitely". + + Returns: + bool: True if a valid query flag was received, False if no data + were available within the given timeout (when timeout >= 0). + + Raises: + RuntimeError: On communication errors. + ValueError: On unexpected message format or stepkind. + """ + # Non-blocking behavior when timeout >= 0 and no data pending: + if timeout >= 0.0 and not self._srv.checkDataToRead(): + return False + + res, msg = self._srv.readData(timeout) + if len(res) > 0: + raise RuntimeError("Error reading query from agent: " + res) + + if not isinstance(msg, dict): + raise ValueError(f"[Query Server] Unexpected message type: {type(msg)}") + + stepkind = msg.get("stepkind", None) + if stepkind != "query": + raise ValueError( + f"[Query Server] Unexpected stepkind while waiting for query: {stepkind}" + ) + + if self._verbose: + print("[Query Server] Query flag received.") + + return True + + +#------------------------------------------------------------------------------- +# +# Base Class: AgentSideQuery +# +#------------------------------------------------------------------------------- + + +class AgentSideQuery: + """ + Client for sending virtual observations to the RL side and received their + corresponding actions. + """ + + def __init__(self, ip_rl: str, port_rl: int, verbose: bool = False): + self._verbose = verbose + self._comm = ClientCommPoint(ip_rl, port_rl) + res = self._comm.begin() + if len(res) > 0: + raise RuntimeError("[Query client] Error connecting query channel to RL: " + res) + if self._verbose: + print("[Query client] Agent connected to RL query channel.") + + def __del__(self): + res = self._comm.end() + if len(res) > 0: + print("[Query client] Error closing query channel (agent side): " + res) + if self._verbose: + print("[Query client] Communications closed in the Agent side.") + + def query_action(self, obs_dict, timeout: float = 10.0): + """ + Sends the observations and receives an action in response. + """ + res = self._comm.sendData({"stepkind": "query", "obs": obs_dict}) + if len(res) > 0: + raise RuntimeError("[Query client] Error sending query obs to RL: " + res) + + res, msg = self._comm.readData(timeout) + if len(res) > 0: + raise RuntimeError("[Query client] Error receiving query action from RL: " + res) + + return msg["action"] + + + def send_query(self): + """Send a 'query' flag to the RL side. + + The message is a small dictionary with the single field:: + + {"stepkind": "query"} + + The RL side will interpret this as a request to perform exactly one + evaluation/test step using its own environment and observation. + """ + res = self._comm.sendData({"stepkind": "query"}) + if len(res) > 0: + raise RuntimeError( + "[Query client] Error sending query flag to RL: " + res + ) + + if self._verbose: + print("[Query client] Query flag sent to RL.")