Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 53 additions & 13 deletions socketcomms/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from enum import Enum
import ipaddress
import struct
import time
import datetime
import socket,pickle,select
Expand All @@ -21,11 +22,11 @@
#
# -----------------------------------------------------------------------------


class BaseCommPoint:
"""
Communication point.
"""
_LEN_STRUCT = struct.Struct("!Q") # 8-byte unsigned length, network byte order

class Kind(Enum):
"""
Expand Down Expand Up @@ -87,6 +88,28 @@ def setDebug(self,st:bool = True):
Enable or disable debug messages.
"""
self._debug = st

def _recv_exact(self, nbytes: int) -> bytes:
"""Receive exactly nbytes from the socket.

Args:
nbytes (int): Number of bytes to read.

Returns:
bytes: The received bytes.

Raises:
RuntimeError: If the connection is closed before receiving enough data.
"""
chunks = []
remaining = nbytes
while remaining > 0:
chunk = self._sock.recv(min(self._datachunkmaxsize, remaining))
if chunk == b"":
raise RuntimeError("Connection closed while receiving")
chunks.append(chunk)
remaining -= len(chunk)
return b"".join(chunks)

def sendData(self, data: Dict) -> str:
"""
Expand All @@ -95,11 +118,19 @@ def sendData(self, data: Dict) -> str:
"""
if not self._begun:
raise RuntimeError("Cannot send data in not-begun commpoint")
mydictser = pickle.dumps(data)

try:
payload = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)

# Prefix with payload length to delimit messages over TCP stream.
header = BaseCommPoint._LEN_STRUCT.pack(len(payload))

if self._debug:
self._printInfo("Sending " + str(len(mydictser)) + " bytes...")
self._sock.send(mydictser)
self._printInfo(f"Sending framed msg: header={len(header)} bytes, payload={len(payload)} bytes...")

# sendall guarantees full transmission.
self._sock.sendall(header)
self._sock.sendall(payload)
if self._debug:
self._printInfo("\tSent ok.")
return ""
Expand All @@ -108,28 +139,35 @@ def sendData(self, data: Dict) -> str:

def readData(self, timeout: float = 2.0) -> Tuple[str, Dict]:
    """
    Read one length-prefixed message from the other side.

    The wire format is the one written by sendData(): a fixed-size header
    packed with ``_LEN_STRUCT`` that holds the payload length, followed by
    the pickled payload itself.

    Args:
        timeout: Socket timeout in seconds. A value <= 0.0 makes the read
            block with no time limit. The timeout is set on the socket, so
            it applies to each receive operation rather than to the
            message as a whole.

    Returns:
        A tuple ``(error, data)``. ``error`` is "" on success, otherwise
        the text of the exception that interrupted the read (connection
        closed, timeout, unpickling failure, ...). ``data`` is the
        unpickled object, or None when ``error`` is non-empty.

    Raises:
        RuntimeError: If the commpoint has not been begun.
    """
    if not self._begun:
        raise RuntimeError("Cannot read data in not-begun commpoint")

    if timeout <= 0.0:
        timeout = None  # None puts the socket in blocking mode
    self._sock.settimeout(timeout)

    try:
        if self._debug:
            self._printInfo("Receiving...")

        # Read the framed header first, then exactly as many payload bytes
        # as it announces.
        header = self._recv_exact(self._LEN_STRUCT.size)
        (msg_len,) = self._LEN_STRUCT.unpack(header)
        payload = self._recv_exact(msg_len)

        # NOTE(security): pickle.loads can execute arbitrary code; this is
        # only acceptable when the peer is trusted.
        result = pickle.loads(payload)

        if self._debug:
            self._printInfo(f"\tReceived framed msg: payload={msg_len} bytes.")
        res = ""
    except Exception as e:
        # NOTE: a failure in the middle of a frame leaves the unread rest
        # of that frame in the stream, so a later read may be out of sync.
        result = None
        res = str(e)
    finally:
        # Deactivate the timeout for other socket operations, also when a
        # non-Exception (e.g. KeyboardInterrupt) propagates.
        self._sock.settimeout(None)

    return res, result

def checkDataToRead(self):
Expand Down Expand Up @@ -200,6 +238,7 @@ def begin(self,timeoutaccept: float) -> str:
self._basesock.settimeout(timeoutaccept) # after this, we assume the other side has shut down
try:
self._sock, _ = self._basesock.accept() # wait for calling us
self._sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) # disable Nagle's algorithm to reduce latency spikes
self._begun = True
self._basesock.settimeout(None) # to deactivate timeout in other operations
return ""
Expand Down Expand Up @@ -250,6 +289,7 @@ def begin(self) -> str:
self.end()
try:
self._sock = socket.socket(socket.AF_INET,socket.SOCK_STREAM) # 1st arg: ip4, 2nd: TCP
self._sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) # disable Nagle's algorithm to reduce latency spikes
self._sock.connect((self._ipv4,self._port)) # if bind-listen has been done on the other side but accept has not, ends immediately even when the server is not accpeting at the time (connection is kept pending), and data can be sent; if bind-listen has not been done on the other side, an error is raised
self._begun = True
return ""
Expand Down
189 changes: 178 additions & 11 deletions spindecoupler.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,12 @@ def resetGetObs(self, timeout: float = 10.0):
if len(res) > 0:
raise RuntimeError("Error sending what to do to the agent. " + res)

res,obsato = self._rlcomm.readData(timeout)
res,obs = self._rlcomm.readData(timeout)
if len(res) > 0:
raise RuntimeError("Error reading after-reset observation from "
"the agent. " + res)

return obsato["obs"], obsato["ato"] # return tuple
return obs["obs"], obs["ato"], obs["info"] # return observation + ato + extra info


def stepSendActGetObs(self, action,timeout:float = 10.0):
Expand Down Expand Up @@ -103,7 +103,7 @@ def stepSendActGetObs(self, action,timeout:float = 10.0):
if len(res) > 0:
raise RuntimeError("Error receiving step observation: " + res)

return lat["lat"], obsrewato["obs"], obsrewato["rew"], obsrewato["ato"]
return lat["lat_sim"], lat["lat_wall"], obsrewato["obs"], obsrewato["rew"], obsrewato["ato"]


def stepExpFinished(self, timeout:float = 10.0):
Expand All @@ -115,8 +115,8 @@ def stepExpFinished(self, timeout:float = 10.0):
"""

self._rlcomm.sendData(dict({"stepkind": "finish"}))


#-------------------------------------------------------------------------------
#
# Base Class: AgentSide
Expand Down Expand Up @@ -219,7 +219,7 @@ def readWhatToDo(self, timeout:float = 10.0):
raise(ValueError("Unknown what-to-do indicator [" +
ind["stepkind"] + "]"))

def stepSendLastActDur(self, lat:float):
def stepSendLastActDur(self, lat_sim:float, lat_wall:float):
"""
Call this method after receiving a REC_ACTION_SEND_OBS and starting the
action, being LAT the actual time during which the action previous to
Expand All @@ -228,7 +228,7 @@ def stepSendLastActDur(self, lat:float):
This method can raise RuntimeError if any error occurs in comms.
"""

res = self._rlcomm.sendData(dict({"lat": lat}))
res = self._rlcomm.sendData(dict({"lat_sim": lat_sim, "lat_wall":lat_wall}))
if len(res) > 0:
raise RuntimeError("Error sending lat to RL. " + res)

Expand All @@ -249,16 +249,183 @@ def stepSendObs(self, obs, agenttime:float = 0.0, rew:float = 0.0):
raise RuntimeError("Error sending observation/reward to RL. " + res)


def resetSendObs(self, obs, agenttime=0.0, extra_info=None):
    """
    Call this method if readWhatToDo() returned RESET_SEND_OBS to send back
    the first observation (OBS, a dictionary) got after an episode reset,
    along with the time (of the agent) when that observation was gathered
    (AGENTTIME) and, optionally, additional data (EXTRA_INFO). The three
    values are sent under the keys "obs", "ato" and "info" respectively;
    when EXTRA_INFO is omitted or None an empty dictionary is sent.
    This method can raise RuntimeError if any error occurs in comms.
    """
    # Build a fresh dict per call: a literal `{}` default would be a single
    # object shared (and possibly mutated) across all calls.
    info = {} if extra_info is None else extra_info

    res = self._rlcomm.sendData({"obs": obs, "ato": agenttime, "info": info})
    if len(res) > 0:
        raise RuntimeError("Error sending observation to RL. " + res)





#-------------------------------------------------------------------------------
#
# Base Class: RLSideQuery
#
#-------------------------------------------------------------------------------


class RLSideQuery:
    """
    RL-side endpoint of the query channel.

    Accepts one agent connection on a server comm point and then exchanges
    dictionaries with it: queries are read from the agent and actions are
    sent back.
    """

    def __init__(self, port: int, verbose: bool = False):
        """
        Create the server comm point on PORT and wait up to 60 seconds for
        the agent to connect. If VERBOSE, progress messages are printed.
        Raises RuntimeError if no agent connects.
        """
        self._verbose = verbose
        self._srv = ServerCommPoint(port)
        if self._verbose:
            print(f"[Query Server] Waiting for agent query connection on port {port}...")
        res = self._srv.begin(timeoutaccept=60.0)
        if len(res) > 0:
            raise RuntimeError("No agent connection for query channel: " + res)
        if self._verbose:
            print("[Query Server] Agent connected for queries.")

    def __del__(self):
        """
        Close the query channel. Errors are printed, not raised.
        """
        # __del__ also runs on objects whose __init__ raised before the comm
        # point was assigned; there is nothing to close in that case.
        srv = getattr(self, "_srv", None)
        if srv is None:
            return
        res = srv.end()
        if len(res) > 0:
            print("Error closing query channel (RL side): " + res)
        if getattr(self, "_verbose", False):
            print("Communications closed in the RL side.")

    def reconnect(self, timeoutaccept: float = 60.0):
        """
        Closes the current connection and waits up to TIMEOUTACCEPT seconds
        for a new agent connection.
        Raises RuntimeError if closing fails or no agent reconnects.
        """
        res = self._srv.end()
        if len(res) > 0:
            raise RuntimeError("Error closing query channel (RL side): " + res)
        if self._verbose:
            print("[Query Server] Waiting for agent query reconnection...")
        res = self._srv.begin(timeoutaccept=timeoutaccept)
        if len(res) > 0:
            raise RuntimeError("No agent reconnection for query channel: " + res)
        if self._verbose:
            print("[Query Server] Agent reconnected for queries.")

    def receive_query(self, timeout: float = -1.0):
        """
        Reads one message from the agent and returns it
        (e.g.: {"stepkind":"query","obs":{...}}).

        TIMEOUT is forwarded unchanged to the comm point's readData(); as
        implemented by BaseCommPoint.readData, a value <= 0 blocks with no
        time limit and a positive value is a timeout in seconds.
        Raises RuntimeError if readData() reports an error.
        """
        res, msg = self._srv.readData(timeout)
        if len(res) > 0:
            raise RuntimeError("Error reading query from agent: " + res)
        return msg

    def send_action(self, action_dict):
        """
        Sends the predicted action to the agent as {"action": ACTION_DICT}.
        Raises RuntimeError if the send reports an error.
        """
        res = self._srv.sendData({"action": action_dict})
        if len(res) > 0:
            raise RuntimeError("Error sending action to agent (query channel): " + res)

    def wait_for_query(self, timeout: float = -1.0) -> bool:
        """Check for, and consume, a 'query' flag sent by the agent.

        Args:
            timeout (float): If negative, the message is read without first
                checking whether data are pending. If >= 0 and the comm
                point reports no pending data, the call returns False
                immediately (it does not wait); otherwise the value is
                forwarded to readData().

        Returns:
            bool: True if a message with stepkind "query" was read, False if
                timeout >= 0 and no data were pending.

        Raises:
            RuntimeError: On communication errors.
            ValueError: On unexpected message format or stepkind.
        """
        # Poll first so a caller with timeout >= 0 is never blocked when
        # there is nothing to read.
        if timeout >= 0.0 and not self._srv.checkDataToRead():
            return False

        res, msg = self._srv.readData(timeout)
        if len(res) > 0:
            raise RuntimeError("Error reading query from agent: " + res)

        if not isinstance(msg, dict):
            raise ValueError(f"[Query Server] Unexpected message type: {type(msg)}")

        stepkind = msg.get("stepkind", None)
        if stepkind != "query":
            raise ValueError(
                f"[Query Server] Unexpected stepkind while waiting for query: {stepkind}"
            )

        if self._verbose:
            print("[Query Server] Query flag received.")

        return True


#-------------------------------------------------------------------------------
#
# Base Class: AgentSideQuery
#
#-------------------------------------------------------------------------------


class AgentSideQuery:
    """
    Agent-side client of the query channel: sends observations (or a bare
    query flag) to the RL side and receives the corresponding actions.
    """

    def __init__(self, ip_rl: str, port_rl: int, verbose: bool = False):
        """
        Connect the query channel to the RL side at IP_RL:PORT_RL.
        If VERBOSE, progress messages are printed.
        Raises RuntimeError if the connection cannot be established.
        """
        self._verbose = verbose
        self._comm = ClientCommPoint(ip_rl, port_rl)
        res = self._comm.begin()
        if len(res) > 0:
            raise RuntimeError("[Query client] Error connecting query channel to RL: " + res)
        if self._verbose:
            print("[Query client] Agent connected to RL query channel.")

    def __del__(self):
        """
        Close the query channel. Errors are printed, not raised.
        """
        # __del__ also runs on objects whose __init__ raised before the comm
        # point was assigned; there is nothing to close in that case.
        comm = getattr(self, "_comm", None)
        if comm is None:
            return
        res = comm.end()
        if len(res) > 0:
            print("[Query client] Error closing query channel (agent side): " + res)
        if getattr(self, "_verbose", False):
            print("[Query client] Communications closed in the Agent side.")

    def query_action(self, obs_dict, timeout: float = 10.0):
        """
        Sends {"stepkind": "query", "obs": OBS_DICT} to the RL side, reads
        the reply (TIMEOUT is forwarded to readData()) and returns the value
        stored under its "action" key.
        Raises RuntimeError if sending or receiving reports an error.
        """
        res = self._comm.sendData({"stepkind": "query", "obs": obs_dict})
        if len(res) > 0:
            raise RuntimeError("[Query client] Error sending query obs to RL: " + res)

        res, msg = self._comm.readData(timeout)
        if len(res) > 0:
            raise RuntimeError("[Query client] Error receiving query action from RL: " + res)

        return msg["action"]

    def send_query(self):
        """Send a 'query' flag to the RL side.

        The message is a small dictionary with the single field::

            {"stepkind": "query"}

        The RL side is expected to interpret this as a request to perform
        exactly one evaluation/test step using its own environment and
        observation. No reply is read by this method.

        Raises:
            RuntimeError: If the send reports an error.
        """
        res = self._comm.sendData({"stepkind": "query"})
        if len(res) > 0:
            raise RuntimeError(
                "[Query client] Error sending query flag to RL: " + res
            )

        if self._verbose:
            print("[Query client] Query flag sent to RL.")