-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmymodels.py
More file actions
62 lines (47 loc) · 2.1 KB
/
mymodels.py
File metadata and controls
62 lines (47 loc) · 2.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
from tensorflow.keras import datasets, layers, models
# Kernel initializer shared by every Dense layer defined in this module.
INIT = "glorot_uniform"
class gatedattention(layers.Layer):
    """Gated attention pooling over the instances of one bag.

    Each instance ``h`` is scored as ``Wa(tanh(V(h)) * sigmoid(U(h)))``.
    The scores are softmax-normalised along axis 1 and used to take a
    weighted sum of the instances.

    ``V`` and ``U`` are weight-normalised, bias-free Dense projections of
    width ``channels``. ``Wa`` is a bias-free Dense(1) scorer with a small
    L2 kernel penalty.

    Args:
        channels: width of the two gating projections ``V`` and ``U``.
        **kwargs: forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, channels=64, **kwargs):
        super().__init__(**kwargs)
        self.channels = channels
        # NOTE: keep these attribute names and this creation order.
        # Keras uses them when tracking and naming weights, so changing
        # either may break loading of previously saved weights.
        self.V = tfa.layers.WeightNormalization(
            layers.Dense(channels, use_bias=False, kernel_initializer=INIT)
        )
        self.U = tfa.layers.WeightNormalization(
            layers.Dense(channels, use_bias=False, kernel_initializer=INIT)
        )
        self.Wa = layers.Dense(
            1,
            kernel_regularizer=tf.keras.regularizers.l2(1e-5),
            use_bias=False,
            kernel_initializer=INIT,
        )
        self.softmax = layers.Softmax(axis=1)
        self.dot = layers.Dot(axes=1)

    def call(self, x):
        """Attention-pool the first bag of ``x``.

        Only ``x[0]`` is used, so the layer expects a batch containing a
        single bag. Presumably ``x`` is ``(1, n_instances, features)``;
        confirm with the caller.

        Returns:
            ``(att, hs)``: the softmax-normalised attention scores (with a
            leading axis of size 1 added) and the pooled bag vector.
        """
        bag = x[0]
        # V, U and Wa are called in this order on purpose: each builds its
        # weights on first call.
        tanh_branch = tf.keras.activations.tanh(self.V(bag))
        sigmoid_gate = tf.keras.activations.sigmoid(self.U(bag))
        scores = self.Wa(tanh_branch * sigmoid_gate)

        att_weights = self.softmax(tf.expand_dims(scores, 0))
        pooled = self.dot([att_weights, tf.expand_dims(bag, 0)])  # 1,1,vector_size
        return att_weights, tf.squeeze(pooled, 1)

    def get_config(self):
        """Return the base layer config extended with ``channels``."""
        return {**super().get_config(), 'channels': self.channels}
class AttMILbinary(models.Model):
    """Binary classifier: gated-attention pooling, dropout, sigmoid Dense head.

    Args:
        inputshape: optional input shape. It is only stored here (``build``
            overwrites it with the real shape), so that
            ``AttMILbinary.from_config(model.get_config())`` round-trips.
        dropout_rate: dropout applied to the pooled bag vector. Defaults to
            0.1, the value previously hard-coded in ``call``.
        **kwargs: forwarded to ``tf.keras.Model`` (e.g. ``name``,
            ``trainable`` coming back from ``get_config``).
    """

    def __init__(self, inputshape=None, dropout_rate=0.1, **kwargs):
        super(AttMILbinary, self).__init__(**kwargs)
        # Initialised here so get_config() also works before build().
        self.inputshape = inputshape
        self.dropout_rate = dropout_rate
        # Create the dropout layer once. Previously a new, untracked
        # Dropout layer was instantiated on every forward pass.
        self.dropout = layers.Dropout(rate=dropout_rate)

    def build(self, inputshape):
        """Create the sub-layers once the input feature size is known."""
        self.inputshape = inputshape
        # Attention width is half of the last (feature) dimension.
        self.gatedattention = gatedattention(inputshape[-1] // 2, name='attention')
        self.WC = layers.Dense(
            1,
            activation='sigmoid',
            kernel_regularizer=tf.keras.regularizers.l2(0.00001),
            kernel_initializer=INIT,
        )
        super(AttMILbinary, self).build(inputshape)

    def call(self, x, training=None):
        """Return the sigmoid score for the bag in ``x``.

        ``training`` is passed explicitly to dropout. When it is ``None``,
        Keras falls back to the enclosing call context (e.g. ``fit``).
        """
        _, hs = self.gatedattention(x)  # attention weights are unused here
        hs = self.dropout(hs, training=training)
        return self.WC(hs)

    def get_config(self):
        """Return a config that ``__init__`` (and thus ``from_config``) accepts."""
        config = super(AttMILbinary, self).get_config()
        config.update({
            'inputshape': self.inputshape,
            'dropout_rate': self.dropout_rate,
        })
        return config