import numpy as np

# MDP class that takes lists of state values and actions and lets the user
# build up the transition function incrementally by adding routes.


class MDP:
    """
    This class defines an MDP (Markov decision process).
    Running example MDP with 3 state variables: a, b, c
        with a = [1, 2, 3]
             b = [4, 5, 6, 7]
             c = [10, 12, 13, 15]
    and three different actions ["forward", "left", "right"].
    """

    def __init__(self, states=None, actions=None, r=-1.0, method="replace"):
        """
        Define the value lists of states and actions.

        states should be a list of lists, where each sub-list contains the
        possible values of one state variable.

            input: states = [a, b, c]

        actions should be a flat list.

            input: actions = ["forward", "left", "right"]

        r is the default reward assigned to every (state, action) pair.

        method: "replace" or "add". When a transition probability has already
        been set (> 0), "replace" overwrites it with the new value and "add"
        accumulates the new value onto it. Defaults to "replace".
        """
        self.s = states
        self.a = actions
        self.num_s_vars = len(self.s)   # number of state variables
        self.num_s = 1                  # total number of joint states
        self.s_range = []               # number of values of each state variable
        for i in range(self.num_s_vars):
            self.s_range.append(len(self.s[i]))
            self.num_s = self.num_s * len(self.s[i])
        self.num_a = len(self.a)

        # transition function
        # P[new_state, current_state, action] = probability
        self.P = np.zeros((self.num_s, self.num_s, self.num_a))

        # reward function, initialised to the default reward r
        self.R = r * np.ones((self.num_s, self.num_a))

        self.method = method

    def add_route(self, current_state, action, new_state, p=1.0):
        """
        Add a new transition route to the MDP.
        Default probability is 1.0.
        current_state (list): [a_i, b_i, c_i]. E.g. [1, 4, 15]
        new_state (list): [a_ii, b_ii, c_ii]. E.g. [1, 5, 10]
        action: "forward"
        """
        # get the MDP action index from the input action
        action_index = self.a.index(action)
        new_index = self.get_index(new_state)
        current_index = self.get_index(current_state)

        if self.P[new_index, current_index, action_index] > 0.0:
            # a probability is already set for this transition;
            # resolve the clash according to self.method
            if self.method == "replace":
                self.P[new_index, current_index, action_index] = p
            else:
                self.P[new_index, current_index, action_index] += p
        else:
            self.P[new_index, current_index, action_index] = p

    def add_reward(self, state, action, reward):
        # set the reward for taking `action` in `state`
        self.R[self.get_index(state), self.a.index(action)] = reward

    def get_state(self, index):
        """
        Get state index tuple from index
        E.g.
            Input: 0
            Output: [0, 0, 0]
        """
        state_index = []
        divisor = self.num_s
        dividend = index
        # decode the flat index as a mixed-radix number, most significant
        # variable first, then flip back into variable order
        for i in range(self.num_s_vars - 1, -1, -1):
            divisor = divisor // self.s_range[i]
            quotient = dividend // divisor
            remainder = dividend % divisor
            state_index.append(quotient)
            dividend = remainder
        return np.flip(state_index)

    def get_real_state_value(self, index):
        """
        Get state tuple from index
        E.g.
            Input: 0
            Output: [1, 4, 10]
        """
        index = self.get_state(index)
        return [self.s[i][v] for i, v in enumerate(index)]

    def get_index(self, state):
        """
        Get index from state tuple
        E.g.
            Input: [1, 4, 10]
            Output: 0
        """
        # get state index tuple from state value tuple
        state_index = [self.s[i].index(v) for i, v in enumerate(state)]
        # encode as a mixed-radix number with the first variable least
        # significant; for the running example (ranges 3, 4, 4) this gives
        # index = i_a + 3 * i_b + 12 * i_c
        cur_mul = self.num_s
        index = 0
        for i in range(self.num_s_vars - 1, -1, -1):
            cur_mul = cur_mul // self.s_range[i]
            index = index + state_index[i] * cur_mul
        return int(index)

    def get_mdp(self):
        # return the number of actions, number of states,
        # reward matrix R, and transition matrix P
        return self.num_a, self.num_s, self.R, self.P
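

# A minimal usage sketch based on the running example in the class docstring.
# The specific route, reward, and the __main__ guard below are illustrative
# assumptions added for demonstration; they are not part of the original module.
if __name__ == "__main__":
    a = [1, 2, 3]
    b = [4, 5, 6, 7]
    c = [10, 12, 13, 15]
    mdp = MDP(states=[a, b, c], actions=["forward", "left", "right"])

    # moving "forward" from state [1, 4, 10] reaches [1, 5, 10] with probability 1.0
    mdp.add_route([1, 4, 10], "forward", [1, 5, 10], p=1.0)
    # override the default reward of -1.0 for this state-action pair
    mdp.add_reward([1, 4, 10], "forward", 10.0)

    print(mdp.get_index([1, 4, 10]))       # 0
    print(mdp.get_state(0))                # [0 0 0]
    print(mdp.get_real_state_value(0))     # [1, 4, 10]

    num_a, num_s, R, P = mdp.get_mdp()
    print(num_a, num_s, R.shape, P.shape)  # 3 48 (48, 3) (48, 48, 3)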