Skip to content

Commit

Permalink
Fixed communication and partially added a ZMQ-to-UDP bridge
Browse files Browse the repository at this point in the history
  • Loading branch information
flimdejong committed Jan 29, 2025
1 parent 20c9817 commit c40c92a
Show file tree
Hide file tree
Showing 10 changed files with 407 additions and 143 deletions.
20 changes: 20 additions & 0 deletions Dockerfile.zmq_server
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
FROM python:3.9-slim

# Install system dependencies.
# --no-install-recommends keeps optional "recommended" packages out of the
# image; the apt lists are removed in the same layer so they never persist.
RUN apt-get update && \
    apt-get install -y --no-install-recommends libzmq3-dev build-essential && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*

# Install Python dependencies.
# --no-cache-dir stops pip from storing its download cache in the image layer.
RUN pip install --no-cache-dir pyzmq protobuf

# Create working directory
WORKDIR /app

# Copy the Python script and proto files.
# COPY sources are always resolved relative to the build context root, so both
# paths are written without a leading slash for consistency.
COPY roboteam_ai/src/rl/src/zmq_server.py /app/
COPY roboteam_networking/ /app/roboteam_networking/

# Run the script (relative to WORKDIR, i.e. /app/zmq_server.py)
CMD ["python", "zmq_server.py"]
24 changes: 24 additions & 0 deletions docker/runner/ray-cluster-combined.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,27 @@ spec:
cpu: 100m
memory: 90Mi

# ZMQ UDP Bridge Server
# Container that runs /app/zmq_server.py from the zmq-server image.
- name: zmq-udp-bridge
image: roboteamtwente/roboteam:zmq-server
# command + args together execute: python /app/zmq_server.py
command: ["python"]
args: ["/app/zmq_server.py"]
ports:
# TCP 5557 is the same port number that is exposed in the ports list further down this manifest.
- containerPort: 5557
protocol: TCP
# NOTE(review): presumably the UDP side of the bridge -- confirm against zmq_server.py.
- containerPort: 10300
protocol: UDP
env:
# NOTE(review): Dockerfile.zmq_server builds from python:3.9-slim and only copies files
# into /app, so this library path presumably does not exist in the image -- confirm it is needed.
- name: LD_LIBRARY_PATH
value: /home/roboteam/build/release/lib
resources:
# Scheduler reservation for this container.
requests:
cpu: 60m
memory: 60Mi
# Hard ceiling for this container.
limits:
cpu: 100m
memory: 90Mi

# # Multicast Receiver 1
# - name: multicast-receiver1
# image: python:3.9-slim
Expand Down Expand Up @@ -394,6 +415,9 @@ spec:
- name: zmq
port: 5559
targetPort: 5559
# Port names must be DNS labels (lowercase alphanumerics and '-');
# an underscore is rejected by Kubernetes API validation.
- name: zmq-server
  port: 5557
  targetPort: 5557
- name: ref
port: 10003
targetPort: 10003
71 changes: 28 additions & 43 deletions roboteam_ai/src/rl/env_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ class RoboTeamEnv(gymnasium.Env):

def __init__(self, config=None):

print(subprocess.check_output(['ip', 'addr', 'show']).decode())
self.config = config or {} # Config placeholder
self.MAX_ROBOTS_US = 10

Expand Down Expand Up @@ -72,15 +71,14 @@ def __init__(self, config=None):
self.is_yellow_dribbling = False
self.is_blue_dribbling = False

# Set first_step flag
self.is_first_step = True

# Add previous dribbling state
self.previous_yellow_dribbling = False
self.step_taken = False

# Add previous ref command state
self.previous_ref_command = None

self.last_step_time = 0
self.min_step_interval = 1.5 # Minimum time between steps in seconds
self.min_step_interval = 2 # Minimum time between steps in seconds

def teleport_ball_with_check(self, x, y):
"""
Expand All @@ -99,7 +97,7 @@ def teleport_ball_with_check(self, x, y):
# Verify ball position
ball_pos, _ = get_ball_state()
if np.allclose(ball_pos, [x, y], atol=0.1): # Checks the real ball position from get_ball_state with the input of the function
print(f"Ball teleport successful on attempt {i+1}")
# print(f"Ball teleport successful on attempt {i+1}")
time.sleep(1)
return
except Exception as e:
Expand Down Expand Up @@ -129,10 +127,6 @@ def calculate_reward(self):
self.shaped_reward_given = True # Set it to true
shaped_reward = 0.1

# # If it gets a yellow card/ three times a foul, punish and reset
# if self.yellow_yellow_cards or self.blue_yellow_cards >= 1:
# yellow_card_punishment = 1

# Calculate final reward
reward = goal_scored_reward + shaped_reward

Expand Down Expand Up @@ -164,7 +158,7 @@ def step(self, action):
"""
The step function waits for either:
1. A true possession change (lost ball to opponent or gained from opponent)
2. Specific referee commands (8,9)
2. Specific referee commands (16,17)
With rate limiting to prevent too frequent steps
"""

Expand All @@ -175,13 +169,17 @@ def step(self, action):
# Get referee state
self.yellow_score, self.blue_score, self.stage, self.ref_command, self.x, self.y = get_referee_state()

# Check termination conditions first
truncated = self.is_truncated()
done = self.is_terminated()

if truncated or done:
reward = self.calculate_reward()
return observation_space, reward, done, truncated, {}

if self.ref_command in (16,17): # If there is ball placement
self.teleport_ball_with_check(self.x, self.y)

# Reset step_taken flag if ref_command is 2
if self.ref_command == 2:
self.step_taken = False

# Check for true possession change
possession_changed = (
(self.is_yellow_dribbling != self.previous_yellow_dribbling) and
Expand All @@ -196,35 +194,29 @@ def step(self, action):
should_take_step = (
can_take_step and (
possession_changed or
(self.ref_command in (8,9) and not self.step_taken)
self.ref_command in (16,17)
)
)

if should_take_step:
print(f"Taking step (time since last: {time_since_last_step:.2f}s)")
print(f"Taking step: {action} (time since last: {time_since_last_step:.2f}s)")
print("ref_command=", self.ref_command)
# Update last step time
self.last_step_time = current_time

# Execute action
send_num_attackers(action)

# Mark step as taken for this ref state
if self.ref_command in (8,9):
self.step_taken = True

# Update previous possession state
# Update previous states
self.previous_yellow_dribbling = self.is_yellow_dribbling
self.previous_ref_command = self.ref_command

reward = self.calculate_reward()

# Check for termination
truncated = self.is_truncated()
done = self.is_terminated()

return observation_space, reward, done, truncated, {}

# Update previous ref command even when not stepping
self.previous_ref_command = self.ref_command
time.sleep(0.1)


def is_terminated(self):
"""
Expand All @@ -245,6 +237,7 @@ def is_truncated(self):
is_truncated checks if game should end prematurely:
- On HALT command with no goals scored
- On STOP command
- On FORCE_START command
"""
if self.ref_command == 0: # HALT
if (self.yellow_score == 0 and self.blue_score == 0):
Expand All @@ -253,6 +246,9 @@ def is_truncated(self):
if self.ref_command == 1: # STOP
print("Game truncated, random STOP called")
return True
if self.ref_command == 4: # FORCE_START
print("Game truncated, FORCE_START called")
return True
return False

def reset(self, seed=None, options=None):
Expand All @@ -274,8 +270,8 @@ def reset(self, seed=None, options=None):
self.shaped_reward_given = False
self.is_yellow_dribbling = False
self.is_blue_dribbling = False
self.previous_yellow_dribbling = False # Add this if not already in init
self.step_taken = False # Reset step taken flag
self.previous_yellow_dribbling = False
self.previous_ref_command = None # Reset previous ref command

# Reset physical environment
print("Resetting physical environment...")
Expand Down Expand Up @@ -320,16 +316,5 @@ def reset(self, seed=None, options=None):
# Provide fallback observation if needed
observation = np.zeros(47, dtype=np.float64)

# Set first_step flag
self.is_first_step = True

return observation, {}









return observation, {}
2 changes: 1 addition & 1 deletion roboteam_ai/src/rl/src/changeGameState.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def start_game():
# stop() # Then stop to prepare for start
# time.sleep(10) # Regular sleep instead of asyncio.sleep
kickoff("BLUE") # Set up kickoff for Blue team
# normal_start() # Start the game normally
normal_start() # Start the game normally

if __name__ == "__main__":
print("Connecting to game controller...")
Expand Down
105 changes: 65 additions & 40 deletions roboteam_ai/src/rl/src/getState.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def get_zmq_address_2():
else:
host = "localhost"
#print("Running locally")
return f"tcp://{host}:5559"
return f"tcp://{host}:5551"

# Function to get the ball state
def get_ball_state():
Expand Down Expand Up @@ -75,7 +75,7 @@ def get_ball_state():
if world.HasField("ball"):
ball_position[0] = world.ball.pos.x
ball_position[1] = world.ball.pos.y

#print("x",ball_position[0])
#print("y",ball_position[1])

Expand Down Expand Up @@ -167,6 +167,11 @@ def get_referee_state_old():
if referee_state.HasField('designated_position'):
x = referee_state.designated_position.x
y = referee_state.designated_position.y

# print(f"Yellow score: {referee_state.yellow.score}")
# print(f"Blue score: {referee_state.blue.score}")
# print(f"Stage: {referee_state.stage}")
# print(f"Command: {referee_state.command}")

return (
referee_state.yellow.score,
Expand All @@ -187,45 +192,65 @@ def get_referee_state_old():
def get_referee_state():
"""
Returns tuple of (yellow_score, blue_score, stage, command, x, y)
Uses ZMQ-based implementation in Kubernetes, UDP multicast implementation otherwise
"""
context = zmq.Context()
socket_ref = context.socket(zmq.SUB)
socket_ref.setsockopt_string(zmq.SUBSCRIBE, "") # Match the working pattern

zmq_address = get_zmq_address_2()
print(f"Connecting to ZMQ address: {zmq_address}")
socket_ref.connect(zmq_address)

try:
# Just receive the message directly like in get_robot_state()
message = socket_ref.recv()
referee_state = Referee()
referee_state.ParseFromString(message)

x = y = 0
if referee_state.HasField('designated_position'):
x = referee_state.designated_position.x
y = referee_state.designated_position.y

print("get_referee_state no error")
return (
referee_state.yellow.score,
referee_state.blue.score,
referee_state.stage,
referee_state.command,
x/1000,
y/1000
)

except DecodeError as e:
print(f"Proto decode error: {e}")
return 0, 0, 0, 0, 0, 0
except zmq.ZMQError as e:
print(f"ZMQ Error: {e}")
return 0, 0, 0, 0, 0, 0
finally:
socket_ref.close()
context.term()
if is_kubernetes():
# Use the new ZMQ-based implementation for Kubernetes
context = zmq.Context()
socket_ref = context.socket(zmq.SUB)
socket_ref.setsockopt_string(zmq.SUBSCRIBE, "")

zmq_address = get_zmq_address_2()
# print(f"Connecting to ZMQ address: {zmq_address}")
socket_ref.connect(zmq_address)

try:
# print("Waiting to receive message...")
message = socket_ref.recv()
# print(f"Received message of length: {len(message)}")
# print(f"Message first 20 bytes: {message[:20]}")

referee_state = Referee()
# print("Created Referee object, attempting to parse...")
referee_state.ParseFromString(message)
# print("Successfully parsed protobuf message")

x = y = 0
if referee_state.HasField('designated_position'):
x = referee_state.designated_position.x
y = referee_state.designated_position.y
# print(f"Got designated position: ({x/1000}, {y/1000})")
else:
print("No designated position in message")

# print(f"Yellow score: {referee_state.yellow.score}")
# print(f"Blue score: {referee_state.blue.score}")
# print(f"Stage: {referee_state.stage}")
# print(f"Command: {referee_state.command}")

return (
referee_state.yellow.score,
referee_state.blue.score,
referee_state.stage,
referee_state.command,
x/1000,
y/1000
)

except DecodeError as e:
print(f"Proto decode error: {e}")
# print(f"Failed message content: {message.hex()}") # Print hex representation
return 0, 0, 0, 0, 0, 0
except zmq.ZMQError as e:
print(f"ZMQ Error: {e}")
print(f"Error details: {str(e)}")
return 0, 0, 0, 0, 0, 0
finally:
socket_ref.close()
context.term()
else:
# Use the old UDP multicast implementation for non-Kubernetes environments
return get_referee_state_old()

if __name__ == "__main__":
# Get robot state
Expand Down
4 changes: 2 additions & 2 deletions roboteam_ai/src/rl/src/multicast.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
# Setup ZMQ publisher
context = zmq.Context()
zmq_socket = context.socket(zmq.PUB)
zmq_socket.bind("tcp://*:5559")
print("ZMQ publisher started on port 5559", file=sys.stderr)
zmq_socket.bind("tcp://*:5551")
print("ZMQ publisher started on port 5551", file=sys.stderr)

# Setup multicast receiver
multicast_group = '224.5.23.1'
Expand Down
Loading

0 comments on commit c40c92a

Please sign in to comment.