# helper.py
# Utility functions for the SeqGAN-PyTorch demo (forked from ZiJianZhao/SeqGAN-PyTorch).
import collections
import itertools

import numpy as np
import torch
from sklearn.model_selection import train_test_split

# Number of sentences read from the raw corpus file.
NUMBER_OF_SENTENCES = 100

# Special tokens shared by the whole pipeline.
Token = collections.namedtuple("Token", ["index", "word"])
SOS = Token(0, "<sos>")
EOS = Token(1, "<eos>")
PAD = Token(2, "<pad>")

# Helper for the interactive demo: trim or pad a raw sentence to a fixed length.
def pad_sentences(sentence, sentence_len):
    words = sentence.split(" ")
    if len(words) > sentence_len:
        # keep only the first sentence_len words
        words = words[:sentence_len]
    else:
        # right-pad short sentences with the <pad> token
        for _ in range(sentence_len - len(words)):
            words.append(PAD.word)
    return words
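
# Illustrative call with a made-up sentence (not from the original repo's data):
#   pad_sentences("the cat sat on the mat", 8)
#   -> ['the', 'cat', 'sat', 'on', 'the', 'mat', '<pad>', '<pad>']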

# Convert a padded list of words to a list of word ids.
def get_ids(sentence, idx_to_word, word_to_idx, VOCAB_SIZE):
    sentence_ids = []
    # Check whether the sentence contains anything besides <pad> tokens.
    flag = 0
    for word in sentence:
        if word != PAD.word:
            flag = 1
            break
    if flag == 1:
        for word in sentence:
            if word.lower() not in word_to_idx:
                # fall back to the <sos> index for out-of-vocabulary words
                sentence_ids.append(SOS.index)
            else:
                sentence_ids.append(word_to_idx[word.lower()])
    else:
        # empty (all-<pad>) input: emit <sos> followed by <pad> ids
        sentence_ids.append(SOS.index)
        for word in sentence[1:]:
            sentence_ids.append(PAD.index)
    return sentence_ids
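
# Illustrative call, assuming a vocabulary built by fetch_vocab below
# (the words are made up):
#   idx_to_word = ['<sos>', '<eos>', '<pad>', 'cat', 'sat', 'the']
#   get_ids(['the', 'cat', 'sat', '<pad>'], idx_to_word, word_to_idx, 6)
#   -> [5, 3, 4, 2]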

def load_from_big_file(file):
    """Read up to NUMBER_OF_SENTENCES lines, normalise each one to 10 words plus
    a trailing '.' (truncating long lines and <pad>-padding short ones), and
    split the result into train/test sentences."""
    s = []
    with open(file) as f:
        lines = f.readlines()
    for line in lines[:NUMBER_OF_SENTENCES]:
        line = line.strip()
        line = line.rstrip(".")
        words = line.split()
        for i in range(len(words)):
            words[i] = words[i].strip(',"')
        if len(words) >= 10:
            sent = " ".join(words[:10])
            sent += " ."
        else:
            sent = " ".join(words)
            sent += " ."
            sent += (" " + PAD.word) * (10 - len(words))
        s.append(sent)
    s_train, s_test = train_test_split(s, shuffle=True, test_size=0.1, random_state=42)
    return s_train, s_test[:2]
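
# Illustrative result for a hypothetical corpus line (file contents are assumed):
#   'The dog barked.'  ->  'The dog barked . <pad> <pad> <pad> <pad> <pad> <pad> <pad>'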

def fetch_vocab(DATA_GERMAN, DATA_ENGLISH, DATA_GERMAN2):
    """Determines the vocabulary, and provides mappings from indices to words and vice versa.

    Returns:
        tuple: A pair of mappings, index-to-word and word-to-index.
    """
    # gather all (lower-cased) words that appear in the data
    all_words = set()
    for sentence in itertools.chain(DATA_GERMAN, DATA_ENGLISH, DATA_GERMAN2):
        all_words.update(word.lower() for word in sentence.split(" ") if word != PAD.word)
    # create mapping from index to word, with the special tokens first
    idx_to_word = [SOS.word, EOS.word, PAD.word] + list(sorted(all_words))
    # create mapping from word to index
    word_to_idx = {word: idx for idx, word in enumerate(idx_to_word)}
    return idx_to_word, word_to_idx
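
# Illustrative call with made-up data; the special tokens always occupy indices 0-2:
#   fetch_vocab(["der hund"], ["the dog"], ["die katze"])
#   -> (['<sos>', '<eos>', '<pad>', 'der', 'die', 'dog', 'hund', 'katze', 'the'],
#       {'<sos>': 0, '<eos>': 1, '<pad>': 2, 'der': 3, ...})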

def generate_sentence_from_id(idx_to_word, input_ids, file_name=None, header=''):
    """Map a list of word ids back to words, optionally appending them to a file."""
    sentence = []
    if file_name:
        out_file = open(file_name, 'a')
        out_file.write(header + ':')
    sep = ''
    for idx in input_ids:
        sentence.append(idx_to_word[idx])
        if file_name:
            out_file.write(sep + idx_to_word[idx])
            sep = ' '
    if file_name:
        out_file.write('\n')
        out_file.close()
    return sentence

def generate_file_from_sentence(sentences, out_file, word_to_idx, generated_num=0):
    """Write sentences to out_file as lines of space-separated word ids.

    If generated_num is non-zero, that many sentences are sampled (with
    replacement); otherwise every sentence is written in order.
    """
    if generated_num:
        generated_index = np.random.choice(len(sentences), generated_num)
    else:
        generated_index = np.arange(0, len(sentences))
    with open(out_file, "w") as f:
        for i in generated_index:
            sent = sentences[i].split(' ')
            sep = ''
            for word in sent:
                f.write(sep + str(word_to_idx[word.lower()]))
                sep = ' '
            f.write('\n')
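
# Each output line holds one sentence as space-separated word ids, e.g. a line
# might look like '8 4 7 2 2 2 2 2 2 2 2' (the indices depend on the vocabulary).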

def generate_real_data(input_file, batch_size, generated_num, idx_to_word, word_to_idx, train_file, test_file=None):
    """Build id-encoded train (and optionally test) files from the raw corpus."""
    train_sen, test_sen = load_from_big_file(input_file)
    generate_file_from_sentence(train_sen, train_file, word_to_idx, generated_num)
    if test_file:
        generate_file_from_sentence(test_sen, test_file, word_to_idx)

def save_vocab(checkpoint, idx_to_word, word_to_idx, vocab_size, g_emb_dim=None, g_hidden_dim=None, g_sequence_len=None):
    """Save the vocabulary mappings and generator hyperparameters to a single
    torch checkpoint. (An earlier version pickled idx_to_word, word_to_idx and
    vocab_size to separate .pkl files; everything is now stored in one file
    via torch.save.)"""
    metadata = {}
    metadata['idx_to_word'] = idx_to_word
    metadata['word_to_idx'] = word_to_idx
    metadata['vocab_size'] = vocab_size
    metadata['g_emb_dim'] = g_emb_dim
    metadata['g_hidden_dim'] = g_hidden_dim
    metadata['g_sequence_len'] = g_sequence_len
    torch.save(metadata, checkpoint)

def load_vocab(checkpoint):
    """Load the metadata dictionary written by save_vocab."""
    metadata = torch.load(checkpoint)
    return metadata
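

if __name__ == "__main__":
    # Minimal smoke-test sketch of the helpers above. The sentences and the
    # checkpoint path are made up for illustration; they are not part of the
    # original SeqGAN-PyTorch data or training pipeline.
    demo_sentences = ["the cat sat on the mat", "a dog ran"]
    idx_to_word, word_to_idx = fetch_vocab(demo_sentences, demo_sentences, demo_sentences)

    padded = pad_sentences("a dog ran", 6)
    ids = get_ids(padded, idx_to_word, word_to_idx, len(idx_to_word))
    print("padded :", padded)
    print("ids    :", ids)
    print("decoded:", generate_sentence_from_id(idx_to_word, ids))

    # Round-trip the vocabulary through a (hypothetical) checkpoint file.
    save_vocab("demo_vocab.pt", idx_to_word, word_to_idx, len(idx_to_word))
    print("vocab size from checkpoint:", load_vocab("demo_vocab.pt")['vocab_size'])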