Skip to content

Commit

Permalink
update rnn
Browse files Browse the repository at this point in the history
  • Loading branch information
polarker committed Nov 3, 2015
1 parent 83ccce9 commit 2158dbd
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 19 deletions.
8 changes: 4 additions & 4 deletions data/get_mnist.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/usr/bin/env sh

wget --no-check-certificate http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
wget --no-check-certificate http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
wget --no-check-certificate http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
wget --no-check-certificate http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
11 changes: 9 additions & 2 deletions include/galois/models/rnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@ namespace gs
class RNN : protected Model<T>
{
protected:
int seq_length;
int max_len; // length of rnn
int input_size;
int output_size;
vector<int> hidden_sizes;

int seq_len = 0; // length of dataset
SP_NArray<T> X = nullptr;
SP_NArray<T> Y = nullptr;
public:
RNN(int seq_length,
RNN(int max_len,
int input_size,
int output_size,
initializer_list<int> hidden_sizes,
Expand All @@ -30,6 +33,10 @@ namespace gs
string optimizer_name);
RNN(const RNN& other) = delete;
RNN& operator=(const RNN&) = delete;

void add_train_dataset(const SP_NArray<T> data, const SP_NArray<T> target);
T fit_one_batch(const int start_from, const bool update=true);
void fit();
};

}
Expand Down
3 changes: 2 additions & 1 deletion include/galois/narray.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ namespace gs
void reopaque() { data_opaque = true; }
void setclear() { data_opaque = false; }

void copy_from(const vector<int> &, const T*);
// void copy_from(const vector<int> &, const T*);
void copy_from(const vector<int> &, const SP_NArray<T>);
void copy_from(const int start_from, const int copy_size, const SP_NArray<T>);
void uniform(T lower, T upper) {
// future : move random generator to a single file
uniform_real_distribution<T> distribution(lower, upper);
Expand Down
65 changes: 61 additions & 4 deletions src/models/rnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace gs
}

template<typename T>
RNN<T>::RNN(int _seq_length,
RNN<T>::RNN(int _max_len,
int _input_size,
int _output_size,
initializer_list<int> _hidden_sizes,
Expand All @@ -39,7 +39,7 @@ namespace gs
T _learning_rate,
string optimizer_name)
: Model<T>(_batch_size, _num_epoch, _learning_rate, optimizer_name)
, seq_length(_seq_length)
, max_len(_max_len)
, input_size(_input_size)
, output_size(_output_size)
, hidden_sizes(_hidden_sizes) {
Expand All @@ -56,7 +56,7 @@ namespace gs
}
}
auto h2y = linear_entropy<T>(hidden_sizes[hidden_sizes.size()-1], output_size);
for (int i = 0; i < seq_length; i++) {
for (int i = 0; i < max_len; i++) {
for (int j = 0; j < hidden_sizes.size(); j++) {
string h = generate_id("h", i, j);
string left_h = generate_id("h", i-1, j);
Expand All @@ -78,7 +78,7 @@ namespace gs

auto x_ids = vector<string>();
auto y_ids = vector<string>();
for (int i = 0; i < seq_length; i++) {
for (int i = 0; i < max_len; i++) {
x_ids.push_back(generate_id("x", i));
y_ids.push_back(generate_id("y", i));
}
Expand All @@ -100,6 +100,63 @@ namespace gs
cout << out_id << " -> " << in_id << endl;
}
}

/**
 * Registers the training sequence used by fit() / fit_one_batch().
 *
 * @param data   2-D array; dimension 0 is the position in the sequence and
 *               dimension 1 must equal input_size.
 * @param target 1-D array holding one entry per row of data.
 *
 * Every precondition is enforced with CHECK. A dataset can be registered only
 * once: the call is rejected if X or Y has already been set.
 */
template<typename T>
void RNN<T>::add_train_dataset(const SP_NArray<T> data, const SP_NArray<T> target) {
    CHECK(data != nullptr && target != nullptr, "data and target must not be null");

    const auto& data_dims = data->get_dims();
    const auto& target_dims = target->get_dims();
    // Verify the ranks first: indexing [0] / [1] below is only valid once we
    // know how many dimensions each array actually has.
    CHECK(data_dims.size() == 2 && target_dims.size() == 1 && data_dims[1] == input_size, "sizes must match");
    CHECK(data_dims[0] == target_dims[0], "length of data and target must match");

    CHECK(X==nullptr && Y==nullptr, "dataset should not be set before");
    seq_len = data_dims[0];
    X = data;
    Y = target;
}

/**
 * Runs one forward/backward pass over the batch that starts at row start_from
 * of the registered dataset.
 *
 * Input signal i is filled with batch_size consecutive rows of X beginning at
 * row start_from + i; the target of output signal i is filled the same way
 * from Y.
 *
 * @param start_from index of the first dataset row used by signal 0.
 * @param update     when true, the optimizer is asked to update after the
 *                   backward pass; when false only the loss is computed.
 * @return the sum of the losses reported by all output signals.
 */
template<typename T>
T RNN<T>::fit_one_batch(const int start_from, bool update) {
    // Without a dataset, copy_from below would dereference a null pointer.
    CHECK(X != nullptr && Y != nullptr, "dataset must be set before fitting");

    this->net.reopaque();

    const int num_inputs = static_cast<int>(this->input_signals.size());
    for (int step = 0; step < num_inputs; step++) {
        this->input_signals[step]->get_data()->copy_from(start_from+step, this->batch_size, X);
    }
    const int num_outputs = static_cast<int>(this->output_signals.size());
    for (int step = 0; step < num_outputs; step++) {
        this->output_signals[step]->get_target()->copy_from(start_from+step, this->batch_size, Y);
    }

    this->net.forward();
    this->net.backward();
    if (update) {
        this->optimizer->update();
    }

    T loss = 0;
    // Iterate by reference so the signal handles are not copied per element.
    for (const auto& output_signal : this->output_signals) {
        loss += *output_signal->get_loss();
    }
    return loss;
}

/**
 * Trains on the registered dataset for num_epoch epochs.
 *
 * In each epoch the batch start offset slides over the sequence one row at a
 * time, and fit_one_batch() is called for every offset at which all max_len
 * signals can still read batch_size rows. After each epoch the epoch number,
 * the elapsed time and the mean per-batch loss are printed to stdout.
 *
 * The dataset must have been registered, and it must be long enough for at
 * least one batch; both conditions are enforced with CHECK.
 */
template<typename T>
void RNN<T>::fit() {
    CHECK(X != nullptr && Y != nullptr, "dataset must be set before fitting");

    // The last row read by a batch starting at i is i + (max_len-1) + (batch_size-1),
    // so valid start offsets are 0 .. seq_len - max_len - batch_size + 1.
    const int num_batches = seq_len - max_len + 1 - this->batch_size + 1;
    CHECK(num_batches > 0, "dataset is too short for the given max_len and batch_size");

    for (int epoch = 1; epoch <= this->num_epoch; epoch++) {
        printf("Epoch: %2d", epoch);
        // steady_clock is monotonic, so the measured interval cannot be
        // distorted by wall-clock adjustments.
        const auto start = chrono::steady_clock::now();

        T loss = 0;
        for (int i = 0; i < num_batches; i++) {
            loss += fit_one_batch(i);
        }
        // Average over the number of batches that were actually summed.
        loss /= T(num_batches);

        const auto end = chrono::steady_clock::now();
        const chrono::duration<double> elapsed_time = end - start;
        printf(", time: %.2fs", elapsed_time.count());
        printf(", loss: %.6f", static_cast<double>(loss));
        printf("\n");
    }
}

template class RNN<float>;
template class RNN<double>;
Expand Down
38 changes: 30 additions & 8 deletions src/narray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@ namespace gs
}
}

template<typename T>
void NArray<T>::copy_from(const vector<int> &dims, const T* data) {
CHECK(dims == this->dims, "the dimension should be equal");
for (int i = 0; i < this->get_size(); i++) {
this->data[i] = data[i];
}
setclear();
}
// template<typename T>
// void NArray<T>::copy_from(const vector<int> &dims, const T* data) {
// CHECK(dims == this->dims, "the dimension should be equal");
// for (int i = 0; i < this->get_size(); i++) {
// this->data[i] = data[i];
// }
// setclear();
// }

template<typename T>
void NArray<T>::copy_from(const vector<int> &idxs, const SP_NArray<T> dataset) {
Expand All @@ -81,6 +81,28 @@ namespace gs
}
setclear();
}

/**
 * Copies a batch of copy_size consecutive rows from dataset into this array.
 *
 * Row i of this array receives row start_from + i of dataset, where a "row" is
 * one slice along dimension 0.
 *
 * @param start_from index of the first dataset row to copy (must be >= 0).
 * @param copy_size  number of rows to copy; must be positive and equal to this
 *                   array's first dimension.
 * @param dataset    source array; it must have the same number of dimensions
 *                   as this array and identical sizes in every dimension but
 *                   the first.
 *
 * All preconditions are enforced with CHECK. On success the opaque flag of
 * this array is cleared via setclear().
 */
template<typename T>
void NArray<T>::copy_from(const int start_from, const int copy_size, const SP_NArray<T> dataset) {
    CHECK(dataset != nullptr, "dataset should not be null");
    CHECK(!this->dims.empty(), "destination should have at least one dimension");
    // A zero copy_size would make the stride computation below divide by zero.
    CHECK(copy_size > 0, "the size of copy should be positive");
    CHECK(copy_size == this->dims[0], "the size of copy should be equal to batch size");

    const auto& dataset_dims = dataset->get_dims();
    CHECK(dataset_dims.size() == this->dims.size(), "number of dimensions should be equal");
    const int num_dims = static_cast<int>(this->dims.size());
    for (int d = 1; d < num_dims; d++) {
        CHECK(dataset_dims[d] == this->dims[d], "dimensions should be equal");
    }
    // Same condition as start_from+copy_size-1 < dataset_dims[0], written so
    // that the sum cannot overflow.
    CHECK(start_from >= 0 && start_from <= dataset_dims[0] - copy_size, "offset is not valid");

    // Number of elements in one row of the batch.
    const int stride = this->get_size() / copy_size;
    auto dataset_ptr = dataset->get_data();

    // The requested rows are indexed contiguously in the source, so the
    // row/column double loop collapses into one flat copy.
    const int offset = start_from * stride;
    const int total = copy_size * stride;
    for (int k = 0; k < total; k++) {
        this->data[k] = dataset_ptr[offset + k];
    }
    setclear();
}


template<typename T>
Expand Down

0 comments on commit 2158dbd

Please sign in to comment.