
Commit a871b2b

add materials for lab 3
1 parent 28cfb16 commit a871b2b

File tree

4 files changed: +1024 additions, 0 deletions


ROS_Core/src/Labs/Lab3/mdp.py

Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
import numpy as np
import math
# MDP class that takes the sets of states and actions and lets users build up
# the transition function by gradually adding individual routes

class MDP():
    """
    The following class defines an MDP.
    Running example MDP with 3 state variables a, b, c,
    with a = [1, 2, 3]
         b = [4, 5, 6, 7]
         c = [10, 12, 13, 15]
    and three different actions ["forward", "left", "right"].
    """

    def __init__(self, states=None, actions=None, r=-1.0, method="replace"):
        """
        Define the value lists of the states and actions.
        states should be a list of lists, where each sub-list holds the
        possible values of one state variable.

            input: states = [a, b, c]

        actions should be a flat list.

            input: actions = ["forward", "left", "right"]

        method: "replace" or "add". When the targeted entry of the P matrix is
        already > 0, "replace" overwrites it with the new value while "add"
        accumulates. Defaults to "replace".
        """
        self.s = states
        self.a = actions
        self.num_s_vars = len(self.s)
        self.num_s = 1
        self.s_range = []
        for i in range(self.num_s_vars):
            self.s_range.append(len(self.s[i]))
            self.num_s = self.num_s * len(self.s[i])
        self.num_a = len(self.a)
        # transition function
        # P[new_state, current_state, action] = probability
        self.P = np.zeros((self.num_s, self.num_s, self.num_a))

        # reward function
        self.R = r * np.ones((self.num_s, self.num_a))

        self.method = method

    def add_route(self, current_state, action, new_state, p=1.0):
        """
        Add a new transition route to the MDP.
        Default probability is 1.0.
        current_state (list): [a_i, b_i, c_i]. E.g.: [1, 4, 15]
        new_state (list): [a_ii, b_ii, c_ii]. E.g.: [1, 5, 10]
        action: "forward"
        """
        # get the correct index of the MDP action from the input action
        action_index = self.a.index(action)

        if self.P[self.get_index(new_state), self.get_index(current_state), action_index] > 0.0:
            # print("Warning: Already have prob value of {}. Use {} to deal with".format(self.P[self.get_index(new_state), self.get_index(current_state), action_index], self.method))
            if self.method == "replace":
                self.P[self.get_index(new_state), self.get_index(current_state), action_index] = p
            else:
                self.P[self.get_index(new_state), self.get_index(current_state), action_index] += p
        else:
            self.P[self.get_index(new_state), self.get_index(current_state), action_index] = p

    def add_reward(self, state, action, reward):
        self.R[self.get_index(state), self.a.index(action)] = reward

    def get_state(self, index):
        """
        Get the state index tuple from a flat index.
        E.g.
        Input: 0
        Output: [0, 0, 0]
        """
        state_index = []
        divisor = self.num_s
        dividend = index
        # peel off the most significant (last) state variable first
        for i in range(self.num_s_vars - 1, -1, -1):
            divisor = divisor / self.s_range[i]
            quotient = math.floor(dividend / divisor)
            remainder = dividend % divisor
            state_index.append(quotient)
            dividend = remainder
        # reverse so the first state variable comes first
        return np.flip(state_index)

    def get_real_state_value(self, index):
        """
        Get the state value tuple from a flat index.
        E.g.
        Input: 0
        Output: [1, 4, 10]
        """
        index = self.get_state(index)
        return [self.s[i][v] for i, v in enumerate(index)]

    def get_index(self, state):
        """
        Get the flat index from a state value tuple.
        E.g.
        Input: [1, 4, 10]
        Output: 0
        """
        # get state index tuple from state value tuple
        state_index = [self.s[i].index(v) for i, v in enumerate(state)]
        cur_mul = self.num_s
        index = 0
        # the first state variable is the least significant digit
        for i in range(self.num_s_vars - 1, -1, -1):
            cur_mul = cur_mul / self.s_range[i]
            index = index + state_index[i] * cur_mul
        return int(index)

    def get_mdp(self):
        return self.num_a, self.num_s, self.R, self.P
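
For reference, below is a minimal usage sketch of this class on the running example from its docstring. The import line and the specific transition and reward are illustrative assumptions, not something defined in the lab code.

from mdp import MDP  # assuming ROS_Core/src/Labs/Lab3 is on the Python path

# Build the running-example MDP from the class docstring.
a = [1, 2, 3]
b = [4, 5, 6, 7]
c = [10, 12, 13, 15]
mdp = MDP(states=[a, b, c], actions=["forward", "left", "right"])

# Illustrative transition: "forward" in state [1, 4, 10] reaches [2, 4, 10]
# with probability 1.0; illustrative reward of 10.0 for "forward" in [2, 4, 10].
mdp.add_route([1, 4, 10], "forward", [2, 4, 10])
mdp.add_reward([2, 4, 10], "forward", 10.0)

num_a, num_s, R, P = mdp.get_mdp()
print(num_s, num_a)  # 48 states (3 * 4 * 4), 3 actions
print(P[mdp.get_index([2, 4, 10]), mdp.get_index([1, 4, 10]), 0])  # 1.0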

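Continuing the sketch above: the flat index produced by get_index is a mixed-radix encoding of the per-variable indices, with the first state variable as the least significant digit, and get_state / get_real_state_value invert it. The concrete numbers below follow from the variable ranges [3, 4, 4] of the running example.

# State [2, 5, 12] has per-variable indices [1, 1, 1], so its flat index is
#     1 * 1 + 1 * 3 + 1 * (3 * 4) = 16
print(mdp.get_index([2, 5, 12]))       # 16
print(mdp.get_state(16))               # [1 1 1]
print(mdp.get_real_state_value(16))    # [2, 5, 12]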