Skip to content

Commit

Permalink
RNN can run now; needs checking later
Browse files Browse the repository at this point in the history
  • Loading branch information
polarker committed Nov 10, 2015
1 parent 2158dbd commit 9f80553
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 3 deletions.
4 changes: 4 additions & 0 deletions data/get_tinyshakespeare.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#!/usr/bin/env sh
# Download the tiny-shakespeare corpus from the char-rnn repository and store
# it as tinyshakespeare.txt in the current working directory.

# Abort on the first failing command so a failed download is never followed
# by a rename of a missing or stale file.
set -e

url=https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
out=tinyshakespeare.txt

# Write to an explicit temporary name: without -O, wget saves to input.txt.1
# when input.txt already exists, and the old file would be renamed instead.
wget -O "$out.tmp" "$url"
mv "$out.tmp" "$out"
18 changes: 18 additions & 0 deletions example/rnn.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#include "galois/models.h"

using namespace std;
using namespace gs;

// Example entry point: builds an RNN<double> from a fixed set of
// hyper-parameters. The model is only constructed here; nothing else is
// called on it.
int main()
{
    using T = double;

    // Values handed to the RNN constructor, in the order it is called below.
    int seq_length{3};
    int input_size{100};
    int output_size{100};
    auto hidden_sizes = {1000, 1000};  // deduced as std::initializer_list<int>

    int batch_size{100};
    int num_epoch{10};
    T learning_rate{0.01};

    RNN<T> model(seq_length,
                 input_size,
                 output_size,
                 hidden_sizes,
                 batch_size,
                 num_epoch,
                 learning_rate,
                 "sgd");

    return 0;
}
103 changes: 103 additions & 0 deletions include/galois/dataset/chartxt.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#ifndef _GALOIS_CHARNN_READER_H_
#define _GALOIS_CHARNN_READER_H_

#include "galois/narray.h"
#include "galois/utils.h"
#include <fstream>
#include <string>
#include <map>

using namespace std;

namespace chartxt
{

template<typename T>
class Article
{
private:
int num_diff_chars = 0;
map<char, int> char2int = {};
map<int, char> int2char = {};

int sequence_length = 0;
gs::SP_NArray<T> int_sequence = nullptr;
gs::SP_NArray<T> vectorized_sequence = nullptr;
gs::SP_NArray<T> target_sequence = nullptr;

public:
explicit Article(const string &file_name) {
ifstream fin;
char ch;
int num_chars;

fin.open(file_name);
CHECK(fin.is_open(), "can not open file");
num_chars = 0;
while (fin >> noskipws >> ch) {
num_chars += 1;
if (char2int.count(ch) == 0) {
int idx = char2int.size();
char2int[ch] = idx;
int2char[idx] = ch;
}
}
fin.close();

num_diff_chars = char2int.size();
sequence_length = num_chars - 1;
int_sequence = make_shared<gs::NArray<T>>(num_chars);
vectorized_sequence = make_shared<gs::NArray<T>>(num_chars-1, num_diff_chars);
target_sequence = make_shared<gs::NArray<T>>(num_chars-1);

cout << "size of chars: " << num_chars << endl;
cout << "size of different chars: " << char2int.size() << ", " << int2char.size() << endl;
fin.open(file_name);
CHECK(fin.is_open(), "can not open file");
auto int_sequence_ptr = int_sequence->get_data();
auto vectorized_sequence_ptr = vectorized_sequence->get_data();
auto target_sequence_ptr = target_sequence->get_data();
for (int i = 0; i < num_chars; i++) {
fin >> noskipws >> ch;
int idx = char2int[ch];

int_sequence_ptr[i] = idx;
if (i < num_chars - 1) {
for (int j = 0; j < num_diff_chars; j++) {
if (j == idx) {
vectorized_sequence_ptr[i*num_diff_chars + j] = 1;
} else {
vectorized_sequence_ptr[i*num_diff_chars + j] = 0;
}
}
}
if (i > 0) {
target_sequence_ptr[i-1] == idx;
}
}
fin.close();
}
Article() = delete;
Article(const Article&) = delete;
Article& operator=(Article &) = delete;

int get_num_diff_chars() {
return num_diff_chars;
}

// int get_sequence_length {
// return sequence_length;
// }

gs::SP_NArray<T> get_vectorized_sequence() {
return vectorized_sequence;
}

gs::SP_NArray<T> get_target_sequence() {
return target_sequence;
}
};

}

#endif
6 changes: 4 additions & 2 deletions src/models/rnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ namespace gs
auto t = this->net.links[idx];
auto in_id = get<0>(t)[0];
auto out_id = get<1>(t)[0];
cout << out_id << " -> " << in_id << endl;
cout << in_id << " <- " << out_id << endl;
}
}

Expand All @@ -118,9 +118,11 @@ namespace gs
T RNN<T>::fit_one_batch(const int start_from, bool update) {
this->net.reopaque();
for (int i = 0; i < this->input_signals.size(); i++) {
this->input_signals[i]->reopaque();
this->input_signals[i]->get_data()->copy_from(start_from+i, this->batch_size, X);
}
for (int i = 0; i < this->output_signals.size(); i++) {
this->output_signals[i]->reopaque();
this->output_signals[i]->get_target()->copy_from(start_from+i, this->batch_size, Y);
}

Expand Down Expand Up @@ -148,7 +150,7 @@ namespace gs
for (int i = 0; i < len; i++) {
loss += fit_one_batch(i);
}
loss /= T(seq_len - max_len + 1);
loss /= T(len);

auto end = chrono::system_clock::now();
chrono::duration<double> eplased_time = end - start;
Expand Down
2 changes: 1 addition & 1 deletion src/path.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace gs
template<typename T>
SP_Filter<T> Path<T>::share() {
auto res = make_shared<Path<T>>();
for (auto const& filter : pfilters) {
for (auto const& filter : links) {
auto copy_of_filter = filter->share();
res->add_filter(copy_of_filter);
}
Expand Down

0 comments on commit 9f80553

Please sign in to comment.