draw.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
import matplotlib.pyplot as plt
import numpy as np
from rtc_env import GymEnv
from deep_rl.storage import Storage
from deep_rl.actor_critic import ActorCritic


def draw_state(record_action, record_state, path):
    """Plot the recorded actions and state components and save the figure to `path`."""
    length = len(record_action)

    # Top panel: the action taken at each step.
    plt.subplot(411)
    plt.plot(range(length), record_action)
    plt.xlabel('step')
    plt.ylabel('action')

    # Remaining panels: the three recorded state components.
    ylabel = ['receiving rate', 'delay', 'packet loss']
    record_state = [t.numpy() for t in record_state]
    record_state = np.array(record_state)
    for i in range(3):
        plt.subplot(412 + i)
        plt.plot(range(length), record_state[:, i])
        plt.xlabel('step')
        plt.ylabel(ylabel[i])

    plt.tight_layout()
    plt.savefig("{}test_result.jpg".format(path))


def draw_module(model, data_path, max_num_steps=1000):
    """Roll out the policy in the gym environment and plot the collected trace."""
    env = GymEnv()
    record_reward = []
    record_state = []
    record_action = []
    time_step = 0

    # Temporarily disable random exploration so the policy is evaluated greedily,
    # then restore the original setting afterwards.
    tmp = model.random_action
    model.random_action = False

    while time_step < max_num_steps:
        done = False
        state = torch.Tensor(env.reset())
        while not done:
            # No gradients are needed when evaluating the policy.
            with torch.no_grad():
                action, _, _ = model.forward(state)
            state, reward, done, _ = env.step(action)
            state = torch.Tensor(state)
            record_state.append(state)
            record_reward.append(reward)
            record_action.append(action)
            time_step += 1

    model.random_action = tmp
    draw_state(record_action, record_state, data_path)