From 9f80553714177b5e2cb677a5f453070ebb93c39d Mon Sep 17 00:00:00 2001
From: Cheng Wang
Date: Tue, 10 Nov 2015 03:02:13 +0100
Subject: [PATCH] rnn could run now, need check later

---
Reviewer notes (this section is ignored when the patch is applied):
 - chartxt.h: target_sequence was never filled because the store was
   written as a comparison (`==`); it is now an assignment.
 - chartxt.h: input shorter than two characters is rejected, copy
   assignment is deleted for const sources, getters are const, and
   get_sequence_length() is restored with its missing parameter list.
 - Template arguments and the standard includes were reconstructed from
   usage; verify the context lines of src/models/rnn.cc and src/path.cc
   against the tree before applying.

 data/get_tinyshakespeare.sh      |   4 ++
 example/rnn.cc                   |  18 ++++++
 include/galois/dataset/chartxt.h | 120 +++++++++++++++++++++++++++++++
 src/models/rnn.cc                |   6 +-
 src/path.cc                      |   2 +-
 5 files changed, 147 insertions(+), 3 deletions(-)
 create mode 100644 data/get_tinyshakespeare.sh
 create mode 100644 example/rnn.cc
 create mode 100644 include/galois/dataset/chartxt.h

diff --git a/data/get_tinyshakespeare.sh b/data/get_tinyshakespeare.sh
new file mode 100644
index 0000000..a14d2c1
--- /dev/null
+++ b/data/get_tinyshakespeare.sh
@@ -0,0 +1,4 @@
+#!/usr/bin/env sh
+
+wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
+mv input.txt tinyshakespeare.txt
diff --git a/example/rnn.cc b/example/rnn.cc
new file mode 100644
index 0000000..13a2627
--- /dev/null
+++ b/example/rnn.cc
@@ -0,0 +1,18 @@
+#include "galois/models.h"
+
+using namespace std;
+using namespace gs;
+
+int main()
+{
+    using T = double;
+
+    int seq_length = 3;
+    int input_size = 100;
+    int output_size = 100;
+    auto hidden_sizes = {1000, 1000};
+    int batch_size = 100;
+    int num_epoch = 10;
+    T learning_rate = 0.01;
+    RNN<T> model(seq_length, input_size, output_size, hidden_sizes, batch_size, num_epoch, learning_rate, "sgd");
+}
\ No newline at end of file
diff --git a/include/galois/dataset/chartxt.h b/include/galois/dataset/chartxt.h
new file mode 100644
index 0000000..74e462e
--- /dev/null
+++ b/include/galois/dataset/chartxt.h
@@ -0,0 +1,120 @@
+#ifndef _GALOIS_CHARNN_READER_H_
+#define _GALOIS_CHARNN_READER_H_
+
+#include "galois/narray.h"
+#include "galois/utils.h"
+#include <fstream>
+#include <iostream>
+#include <map>
+
+using namespace std;
+
+namespace chartxt
+{
+
+    // Character-level view of a text file, built once by the constructor:
+    //   int_sequence         id of every character, in file order
+    //   vectorized_sequence  one-hot row for every character but the last
+    //   target_sequence      id of the character that follows each row
+    // Ids are handed out in order of first appearance in the file.
+    template<typename T>
+    class Article
+    {
+    private:
+        int num_diff_chars = 0;
+        map<char, int> char2int = {};
+        map<int, char> int2char = {};
+
+        int sequence_length = 0;
+        gs::SP_NArray<T> int_sequence = nullptr;
+        gs::SP_NArray<T> vectorized_sequence = nullptr;
+        gs::SP_NArray<T> target_sequence = nullptr;
+
+    public:
+        // Reads `file_name` twice: the first pass counts characters and
+        // builds the vocabulary, the second pass fills the three arrays.
+        // noskipws keeps whitespace, so spaces and newlines get ids too.
+        explicit Article(const string &file_name) {
+            ifstream fin;
+            char ch;
+            int num_chars;
+
+            fin.open(file_name);
+            CHECK(fin.is_open(), "can not open file");
+            num_chars = 0;
+            while (fin >> noskipws >> ch) {
+                num_chars += 1;
+                if (char2int.count(ch) == 0) {
+                    int idx = static_cast<int>(char2int.size());
+                    char2int[ch] = idx;
+                    int2char[idx] = ch;
+                }
+            }
+            fin.close();
+            // With fewer than two characters there is no (input, target)
+            // pair, and num_chars - 1 below would be zero or negative.
+            CHECK(num_chars > 1, "need at least two characters");
+
+            num_diff_chars = static_cast<int>(char2int.size());
+            sequence_length = num_chars - 1;
+            int_sequence = make_shared<gs::NArray<T>>(num_chars);
+            vectorized_sequence = make_shared<gs::NArray<T>>(num_chars-1, num_diff_chars);
+            target_sequence = make_shared<gs::NArray<T>>(num_chars-1);
+
+            cout << "size of chars: " << num_chars << endl;
+            cout << "size of different chars: " << char2int.size() << ", " << int2char.size() << endl;
+            fin.open(file_name);
+            CHECK(fin.is_open(), "can not open file");
+            auto int_sequence_ptr = int_sequence->get_data();
+            auto vectorized_sequence_ptr = vectorized_sequence->get_data();
+            auto target_sequence_ptr = target_sequence->get_data();
+            for (int i = 0; i < num_chars; i++) {
+                fin >> noskipws >> ch;
+                int idx = char2int[ch];
+
+                int_sequence_ptr[i] = idx;
+                // The last character is only ever a target, never an input.
+                if (i < num_chars - 1) {
+                    for (int j = 0; j < num_diff_chars; j++) {
+                        if (j == idx) {
+                            vectorized_sequence_ptr[i*num_diff_chars + j] = 1;
+                        } else {
+                            vectorized_sequence_ptr[i*num_diff_chars + j] = 0;
+                        }
+                    }
+                }
+                // Character i is the target for input row i-1.
+                if (i > 0) {
+                    target_sequence_ptr[i-1] = idx;
+                }
+            }
+            fin.close();
+        }
+        Article() = delete;
+        Article(const Article&) = delete;
+        Article& operator=(const Article&) = delete;
+
+        // Number of distinct characters found in the file.
+        int get_num_diff_chars() const {
+            return num_diff_chars;
+        }
+
+        // Number of (input, target) pairs: one less than the character count.
+        int get_sequence_length() const {
+            return sequence_length;
+        }
+
+        // One-hot inputs, allocated as (sequence length) x (distinct chars).
+        gs::SP_NArray<T> get_vectorized_sequence() const {
+            return vectorized_sequence;
+        }
+
+        // Character id of the successor of each input row.
+        gs::SP_NArray<T> get_target_sequence() const {
+            return target_sequence;
+        }
+    };
+
+}
+
+#endif
diff --git a/src/models/rnn.cc b/src/models/rnn.cc
index 77ff5ed..15b83a5 100644
--- a/src/models/rnn.cc
+++ b/src/models/rnn.cc
@@ -97,7 +97,7 @@ namespace gs
             auto t = this->net.links[idx];
             auto in_id = get<0>(t)[0];
             auto out_id = get<1>(t)[0];
-            cout << out_id << " -> " << in_id << endl;
+            cout << in_id << " <- " << out_id << endl;
         }
     }
 
@@ -118,9 +118,11 @@ namespace gs
     template<typename T> T RNN<T>::fit_one_batch(const int start_from, bool update) {
         this->net.reopaque();
         for (int i = 0; i < this->input_signals.size(); i++) {
+            this->input_signals[i]->reopaque();
             this->input_signals[i]->get_data()->copy_from(start_from+i, this->batch_size, X);
         }
         for (int i = 0; i < this->output_signals.size(); i++) {
+            this->output_signals[i]->reopaque();
             this->output_signals[i]->get_target()->copy_from(start_from+i, this->batch_size, Y);
         }
 
@@ -148,7 +150,7 @@ namespace gs
             for (int i = 0; i < len; i++) {
                 loss += fit_one_batch(i);
             }
-            loss /= T(seq_len - max_len + 1);
+            loss /= T(len);
 
             auto end = chrono::system_clock::now();
             chrono::duration<double> eplased_time = end - start;
diff --git a/src/path.cc b/src/path.cc
index 07175e5..6b2cfeb 100644
--- a/src/path.cc
+++ b/src/path.cc
@@ -24,7 +24,7 @@ namespace gs
 
     template<typename T> SP_Filter<T> Path<T>::share() {
         auto res = make_shared<Path<T>>();
-        for (auto const& filter : pfilters) {
+        for (auto const& filter : links) {
             auto copy_of_filter = filter->share();
             res->add_filter(copy_of_filter);
         }