agent.py
import numpy as np
import torch as T
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from distances import mahalanobis_dist_from_vectors


class SimilarityMatrix:
    """Tracks the state (similarity) matrix and the running max-positive /
    min-negative Mahalanobis distances seen in the current batch."""

    def __init__(self, N):
        self.state_matrix = T.zeros((N + 1, N + 1))
        # self.kreciprocal_matrix = T.zeros((N+1, N+1))
        self.max_positive_distance = T.zeros((1,))
        self.min_negative_distance = T.full((1,), float('inf'))  # +inf so the first distance always replaces it

    # def k_reciprocal(self, k=15):

    def find_max_min_distances_in_batch(self, query_feature, gk_feature):
        dist = mahalanobis_dist_from_vectors(query_feature, gk_feature.reshape(1, -1))
        self.max_positive_distance = T.maximum(self.max_positive_distance, dist).reshape(1)
        self.min_negative_distance = T.minimum(self.min_negative_distance, dist).reshape(1)

    def reset_distances(self):
        self.max_positive_distance = T.zeros((1,))
        self.min_negative_distance = T.full((1,), float('inf'))


class Agent(nn.Module):
    # Architecture note: Ns = 30 -> flattened state matrix of size 961;
    # the paper uses 3 fully connected layers with 256 units each.
    def __init__(self, ALPHA, input_dims, fc1_dims, fc2_dims, fc3_dims, n_actions):
        super(Agent, self).__init__()
        self.input_dims = input_dims
        self.fc1_dims = fc1_dims
        self.fc2_dims = fc2_dims
        self.fc3_dims = fc3_dims
        self.n_actions = n_actions

        self.fc1 = nn.Linear(self.input_dims, self.fc1_dims)
        self.fc2 = nn.Linear(self.fc1_dims, self.fc2_dims)
        self.fc3 = nn.Linear(self.fc2_dims, self.fc3_dims)
        self.output = nn.Linear(self.fc3_dims, self.n_actions)

        self.optimizer = optim.Adam(self.parameters(), lr=ALPHA)
        self.device = T.device('cuda:0' if T.cuda.is_available() else 'cpu')
        self.to(self.device)
        self.sim_mat = SimilarityMatrix(self.n_actions)

    def forward(self, state):
        state = T.Tensor(state).to(self.device)  # state shape: torch.Size([961])
        x = F.relu(self.fc1(state))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.output(x))  # output activation; shape: torch.Size([30])
        return x

    def compute_reward(self, label, margin=0.2):
        # reward = margin + label * (max positive distance - min negative distance)
        reward = margin + label * (self.sim_mat.max_positive_distance - self.sim_mat.min_negative_distance)
        return reward

    def update_state(self, state, label, g_k, threshold=0.4):
        if label:
            # Positive feedback: merge the query column with gallery column g_k.
            z = (state[:, 0] + state[:, g_k]) / 2
            state[:, 0] = z
            state[0, :] = z
            state[0, g_k] = 1
            state[g_k, 0] = 1
        else:
            # Negative feedback: suppress similarities shared with gallery g_k.
            z = state[:, g_k].detach().clone()
            z[z < threshold] = 0
            state[:, 0] = T.clamp(state[:, 0] - z, min=0)
            state[0, g_k] = 0
            state[g_k, 0] = 0
            z = state[:, 0]
            state[0, :] = z
        state.fill_diagonal_(0)  # self-similarity entries stay at zero
        # assert not np.diagonal(state).any()  # sanity check: diagonal should be zero
        return state

    def take_unique_action(self, logits, action_buffer):
        # Greedy argmax over logits, skipping actions already taken this episode.
        max_so_far = [float('-inf'), 0]  # [best logit, best action]
        for i in range(logits.shape[0]):
            if i not in action_buffer and logits[i] > max_so_far[0]:
                max_so_far[0], max_so_far[1] = logits[i], i
        return max_so_far[1]
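

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original training code; added for
# illustration only). Dimensions follow the architecture note above: Ns = 30
# gallery images gives a (31 x 31) state matrix flattened to 961 inputs,
# three 256-unit FC layers, and 30 actions. The learning rate and the
# action -> gallery-column mapping below are assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    agent = Agent(ALPHA=1e-3, input_dims=961,
                  fc1_dims=256, fc2_dims=256, fc3_dims=256, n_actions=30)

    state = agent.sim_mat.state_matrix              # (31, 31) similarity matrix
    logits = agent.forward(state.flatten())         # (30,) action scores
    action = agent.take_unique_action(logits, action_buffer=[])

    # Assumed mapping: action i selects gallery column i + 1 (column 0 is the query).
    g_k = action + 1
    state = agent.update_state(state, label=True, g_k=g_k)
    print("chosen gallery column:", g_k)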