11#ifndef CAFFE_DATA_LAYERS_HPP_
22#define CAFFE_DATA_LAYERS_HPP_
33
4+ #include < map>
45#include < string>
56#include < utility>
67#include < vector>
78
8- #include " boost/scoped_ptr.hpp"
9+ #include " boost/random/mersenne_twister.hpp"
10+ #include " boost/random/uniform_real.hpp"
11+ #include " boost/random/variate_generator.hpp"
12+ #include " boost/weak_ptr.hpp"
913#include " hdf5.h"
1014
1115#include " caffe/blob.hpp"
1620#include " caffe/layer.hpp"
1721#include " caffe/net.hpp"
1822#include " caffe/proto/caffe.pb.h"
23+ #include " caffe/util/blocking_queue.hpp"
1924#include " caffe/util/db.hpp"
2025
2126namespace caffe {
2227
28+ using boost::weak_ptr;
29+ using boost::mt19937;
30+ using boost::uniform_real;
31+ using boost::variate_generator;
32+
2333/* *
2434 * @brief Provides base for data layers that feed blobs to the Net.
2535 *
@@ -52,12 +62,17 @@ class BaseDataLayer : public Layer<Dtype> {
5262 bool output_labels_;
5363};
5464
65+ template <typename Dtype>
66+ class Batch {
67+ public:
68+ Blob<Dtype> data_, label_;
69+ };
70+
5571template <typename Dtype>
5672class BasePrefetchingDataLayer :
5773 public BaseDataLayer<Dtype>, public InternalThread {
5874 public:
59- explicit BasePrefetchingDataLayer (const LayerParameter& param)
60- : BaseDataLayer<Dtype>(param) {}
75+ explicit BasePrefetchingDataLayer (const LayerParameter& param);
6176 virtual ~BasePrefetchingDataLayer () {}
6277 // LayerSetUp: implements common data layer setup functionality, and calls
6378 // DataLayerSetUp to do special data layer setup for individual layer types.
@@ -70,22 +85,63 @@ class BasePrefetchingDataLayer :
7085 virtual void Forward_gpu (const vector<Blob<Dtype>*>& bottom,
7186 const vector<Blob<Dtype>*>& top);
7287
73- virtual void CreatePrefetchThread ();
74- virtual void JoinPrefetchThread ();
75- // The thread's function
76- virtual void InternalThreadEntry () {}
88+ // Prefetches batches (asynchronously if to GPU memory)
89+ static const int PREFETCH_COUNT = 3 ;
7790
7891 protected:
79- Blob<Dtype> prefetch_data_;
80- Blob<Dtype> prefetch_label_;
92+ virtual void InternalThreadEntry ();
93+ virtual void load_batch (Batch<Dtype>* batch) = 0;
94+
95+ Batch<Dtype> prefetch_[PREFETCH_COUNT];
96+ blocking_queue<Batch<Dtype>*> prefetch_free_;
97+ blocking_queue<Batch<Dtype>*> prefetch_full_;
98+ int device_;
99+
81100 Blob<Dtype> transformed_data_;
82101};
83102
103+ // Prefetches datums to host memory that can be read by multiple data layers.
104+ class DataLoader {
105+ public:
106+ DataLoader (const DataParameter& param, int index);
107+ ~DataLoader ();
108+
109+ inline blocking_queue<Datum*>& free () {
110+ return body_.get ()->free_ ;
111+ }
112+ inline blocking_queue<Datum*>& full () {
113+ return body_.get ()->full_ ;
114+ }
115+
116+ protected:
117+ class Body : public InternalThread {
118+ public:
119+ Body (const DataParameter& param, int index);
120+ ~Body ();
121+
122+ void InternalThreadEntry ();
123+
124+ shared_ptr<db::DB> db_;
125+ shared_ptr<db::Cursor> cursor_;
126+
127+ blocking_queue<Datum*> free_;
128+ blocking_queue<Datum*> full_;
129+
130+ DISABLE_COPY_AND_ASSIGN (Body);
131+ };
132+
133+ static map<string, weak_ptr<Body> > instances_;
134+
135+ const string source_;
136+ shared_ptr<Body> body_;
137+
138+ DISABLE_COPY_AND_ASSIGN (DataLoader);
139+ };
140+
84141template <typename Dtype>
85- class DataLayer : public BasePrefetchingDataLayer <Dtype> {
142+ class DataLayer : public BasePrefetchingDataLayer <Dtype> {
86143 public:
87- explicit DataLayer (const LayerParameter& param)
88- : BasePrefetchingDataLayer<Dtype>(param) {}
144+ explicit DataLayer (const LayerParameter& param);
89145 virtual ~DataLayer ();
90146 virtual void DataLayerSetUp (const vector<Blob<Dtype>*>& bottom,
91147 const vector<Blob<Dtype>*>& top);
@@ -96,10 +152,12 @@ class DataLayer : public BasePrefetchingDataLayer<Dtype> {
96152 virtual inline int MaxTopBlobs () const { return 2 ; }
97153
98154 protected:
99- virtual void InternalThreadEntry ();
155+ virtual void load_batch (Batch<Dtype>* batch);
156+ DataLoader* next_loader ();
100157
101- shared_ptr<db::DB> db_;
102- shared_ptr<db::Cursor> cursor_;
158+ vector<shared_ptr<DataLoader> > loaders_;
159+ mt19937 rand_engine_;
160+ uniform_real<float > rand_;
103161};
104162
105163/* *
@@ -236,7 +294,7 @@ class ImageDataLayer : public BasePrefetchingDataLayer<Dtype> {
236294 protected:
237295 shared_ptr<Caffe::RNG> prefetch_rng_;
238296 virtual void ShuffleImages ();
239- virtual void InternalThreadEntry ( );
297+ virtual void load_batch (Batch<Dtype>* batch );
240298
241299 vector<std::pair<std::string, int > > lines_;
242300 int lines_id_;
@@ -308,7 +366,7 @@ class WindowDataLayer : public BasePrefetchingDataLayer<Dtype> {
308366
309367 protected:
310368 virtual unsigned int PrefetchRand ();
311- virtual void InternalThreadEntry ( );
369+ virtual void load_batch (Batch<Dtype>* batch );
312370
313371 shared_ptr<Caffe::RNG> prefetch_rng_;
314372 vector<std::pair<std::string, vector<int > > > image_database_;
0 commit comments