-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmodels.py
More file actions
122 lines (112 loc) · 4.58 KB
/
models.py
File metadata and controls
122 lines (112 loc) · 4.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.manifold import TSNE
class VCGAN_generator(nn.Module):
    """
    The generator for a (conditional) vanilla GAN.

    Takes a Gaussian noise vector ``z`` and an integer condition ``c``
    (embedded through ``nn.Embedding``) and produces a "fake" sample ``x``
    via a stack of fully-connected layers with batch normalization and
    LeakyReLU activations.  The raw linear output is then post-processed
    column-group by column-group according to ``col_types``.

    Parameters
    ----------
    z_dim : int
        Dimension of the input noise vector.
    hidden_dim : int
        Width of each hidden fully-connected layer.
    x_dim : int
        Dimension of the generated data vector.
    num_layer : int
        Number of hidden (fc + bn + LeakyReLU) layers.
    col_types : list of str
        Per-column-group type, one of ``'binary'``, ``'normalize'``,
        ``'one-hot'``, ``'gmm'``; anything else is treated as ordinal.
    col_idxes : list of (start, end)
        Inclusive column index ranges into ``x`` for each group.
    c_dim : int, optional
        Number of condition classes.  NOTE(review): with the default
        ``c_dim=0`` the embedding is empty and ``forward`` would fail for
        any condition index — callers appear to always pass ``c_dim > 0``;
        TODO confirm.
    """

    def __init__(self, z_dim, hidden_dim, x_dim, num_layer, col_types, col_idxes, c_dim=0,
                 ):
        super(VCGAN_generator, self).__init__()
        self.emb = nn.Embedding(c_dim, c_dim)
        self.input = nn.Linear(z_dim + c_dim, hidden_dim)
        self.inputbn = nn.BatchNorm1d(hidden_dim)
        self.col_types = col_types
        self.col_idxes = col_idxes
        self.x_dim = x_dim  # dimension of the generated data
        self.num_layer = num_layer
        # Register hidden layers via setattr so state-dict keys stay
        # "fc0", "bn0", ... (checkpoint-compatible with earlier versions).
        for i in range(num_layer):
            setattr(self, "fc%d" % i, nn.Linear(hidden_dim, hidden_dim))
            setattr(self, "bn%d" % i, nn.BatchNorm1d(hidden_dim))
            setattr(self, 'relu{}'.format(i), nn.LeakyReLU(0.2, inplace=True))
        self.output = nn.Linear(hidden_dim, x_dim)
        self.outputbn = nn.BatchNorm1d(x_dim)

    def forward(self, z, c, generate=False):
        """
        Run the generator.

        Parameters
        ----------
        z : Tensor, shape (batch, z_dim)
            Gaussian noise.
        c : LongTensor, shape (batch,)
            Condition class indices for the embedding.
        generate : bool
            When True, discretize binary columns to hard 0/1 and one-hot
            columns to an argmax one-hot vector (inference mode).

        Returns
        -------
        Tensor, shape (batch, x_dim)
            Generated sample with per-column activations applied.
        """
        z = torch.cat((z, self.emb(c)), -1)
        z = self.input(z)
        z = self.inputbn(z)
        z = F.leaky_relu(z, negative_slope=0.2)
        for i in range(self.num_layer):
            z = getattr(self, 'fc{}'.format(i))(z)
            z = getattr(self, 'bn{}'.format(i))(z)
            z = getattr(self, 'relu{}'.format(i))(z)
        x = self.output(z)
        output = []
        for i in range(len(self.col_types)):
            start = self.col_idxes[i][0]
            end = self.col_idxes[i][1]
            if self.col_types[i] == 'binary':
                # torch.sigmoid: F.sigmoid is a deprecated alias.
                temp = torch.sigmoid(x[:, start:end + 1])
                if generate:
                    # Cast to float (not int) so the final torch.cat does
                    # not mix integer and floating dtypes.
                    temp = (temp > 0.5).float()
            elif self.col_types[i] == 'normalize':
                # Data was normalized into [-1, 1]; tanh matches that range.
                temp = torch.tanh(x[:, start:end + 1])
            elif self.col_types[i] == 'one-hot':
                temp = torch.softmax(x[:, start:end + 1], dim=1)
                if generate:
                    max_idxes = temp.max(1)[1].unsqueeze(1)
                    # zeros_like keeps the device/dtype of temp (the old
                    # torch.zeros(temp.shape) always allocated on CPU).
                    temp = torch.zeros_like(temp).scatter_(1, max_idxes, 1)
            elif self.col_types[i] == 'gmm':
                # GMM encoding: first column is the normalized value (tanh),
                # remaining columns are soft cluster memberships (softmax).
                temp1 = torch.tanh(x[:, start:start + 1])
                temp2 = torch.softmax(x[:, start + 1:end + 1], dim=1)
                temp = torch.cat((temp1, temp2), dim=1)
            else:
                # Fallback (e.g. 'ordinal'): squash into [-1, 1].
                temp = torch.tanh(x[:, start:end + 1])
            output.append(temp)
        output = torch.cat(output, dim=1)
        return output

    def generate(self, z, c):
        """Convenience wrapper: forward pass with hard/discretized outputs."""
        return self.forward(z, c, generate=True)
class VCGAN_discriminator(nn.Module):
    """
    The discriminator for a (conditional) vanilla GAN.

    Takes real/fake data plus an integer condition, and uses an MLP to
    produce a score (sigmoid probability 1: real / 0: fake, or the raw
    logit when ``wgan=True``).

    NOTE (from original author, translated): do NOT add batch norm to the
    discriminator — with it, the model fails to generate valid data.

    Parameters
    ----------
    x_dim : int
        Dimension of the (real or generated) input data.
    hidden_dim : int
        Width of each hidden fully-connected layer.
    num_layer : int
        Number of hidden layers.
    c_dim : int, optional
        Number of condition classes (embedding size).
    wgan : bool, optional
        When True, return the raw critic score instead of a sigmoid.
    dropout : float, optional
        Dropout probability applied after every layer.
    """

    def __init__(self, x_dim, hidden_dim, num_layer, c_dim=0, wgan=False, dropout=0.5):
        super(VCGAN_discriminator, self).__init__()
        self.num_layer = num_layer
        self.dropout = dropout
        self.emb = nn.Embedding(c_dim, c_dim)
        self.input = nn.Linear(x_dim + c_dim, hidden_dim)
        # Plain list of the hidden layers; parameters are registered via
        # setattr below, keeping state-dict keys "fc0", "fc1", ...
        self.hidden = []
        self.wgan = wgan
        for i in range(num_layer):
            fc = nn.Linear(hidden_dim, hidden_dim)
            setattr(self, "fc%d" % i, fc)
            self.hidden.append(fc)
        self.output = nn.Linear(hidden_dim, 1)

    def forward(self, z, c):
        """
        Score a batch of samples.

        Parameters
        ----------
        z : Tensor, shape (batch, x_dim)
            Real or generated data.
        c : LongTensor, shape (batch,)
            Condition class indices.

        Returns
        -------
        Tensor, shape (batch, 1)
            Sigmoid probability, or the raw score when ``self.wgan``.
        """
        z = torch.cat([z, self.emb(c)], dim=-1)
        z = self.input(z)
        z = F.leaky_relu(z)
        z = F.dropout(z, training=self.training, p=self.dropout)
        for fc in self.hidden:
            z = fc(z)
            # Bug fix: the hidden layers previously used F.dropout's
            # default p=0.5, ignoring the configured self.dropout.
            z = F.dropout(z, training=self.training, p=self.dropout)
            z = F.leaky_relu(z)
        z = self.output(z)
        if self.wgan:
            return z
        else:
            return torch.sigmoid(z)
def init_weights(m):
    """
    Xavier-initialize a module's weights; meant for ``model.apply(init_weights)``.

    Linear layers get Xavier-uniform weights and a constant 0.01 bias;
    every other module type is left untouched.
    """
    # isinstance (not type ==) also covers nn.Linear subclasses;
    # xavier_uniform_ is the current in-place API (xavier_uniform is a
    # deprecated alias that emits a warning).
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)