diff --git a/Makefile b/Makefile index d2e5e5720ed..69dd8ddd687 100644 --- a/Makefile +++ b/Makefile @@ -300,6 +300,14 @@ ifeq ($(CPU_ONLY), 1) COMMON_FLAGS += -DCPU_ONLY endif +# Benchmarks +ifeq ($(BENCHMARK_DATA), 1) + COMMON_FLAGS += -DBENCHMARK_DATA +endif +ifeq ($(BENCHMARK_SOLVER), 1) + COMMON_FLAGS += -DBENCHMARK_SOLVER +endif + # Python layer support ifeq ($(WITH_PYTHON_LAYER), 1) COMMON_FLAGS += -DWITH_PYTHON_LAYER diff --git a/Makefile.config.example b/Makefile.config.example index 7a8aafd7c9f..f7f86967153 100644 --- a/Makefile.config.example +++ b/Makefile.config.example @@ -79,3 +79,7 @@ TEST_GPUID := 0 # enable pretty build (comment to see full commands) Q ?= @ + +# Adds timing info in logs +# BENCHMARK_DATA := 1 +# BENCHMARK_SOLVER := 1 diff --git a/include/caffe/caffe.hpp b/include/caffe/caffe.hpp index 3c829f2f9b0..68a5e1d1d1a 100644 --- a/include/caffe/caffe.hpp +++ b/include/caffe/caffe.hpp @@ -10,6 +10,7 @@ #include "caffe/layer.hpp" #include "caffe/layer_factory.hpp" #include "caffe/net.hpp" +#include "caffe/parallel.hpp" #include "caffe/proto/caffe.pb.h" #include "caffe/solver.hpp" #include "caffe/util/benchmark.hpp" diff --git a/include/caffe/common.hpp b/include/caffe/common.hpp index 5f86bc2625b..1df6b9a14fb 100644 --- a/include/caffe/common.hpp +++ b/include/caffe/common.hpp @@ -98,12 +98,12 @@ void GlobalInit(int* pargc, char*** pargv); class Caffe { public: ~Caffe(); - inline static Caffe& Get() { - if (!singleton_.get()) { - singleton_.reset(new Caffe()); - } - return *singleton_; - } + + // Thread local context for Caffe. Moved to common.cpp instead of + // including boost/thread.hpp to avoid a boost/NVCC issues (#1009, #1010) + // on OSX. Also fails on Linux with CUDA 7.0.18. + static Caffe& Get(); + enum Brew { CPU, GPU }; // This random number generator facade hides boost and CUDA rng @@ -149,6 +149,11 @@ class Caffe { static void SetDevice(const int device_id); // Prints the current GPU status. static void DeviceQuery(); + // Parallel training info + inline static int solver_count() { return Get().solver_count_; } + inline static void set_solver_count(int val) { Get().solver_count_ = val; } + inline static bool root_solver() { return Get().root_solver_; } + inline static void set_root_solver(bool val) { Get().root_solver_ = val; } protected: #ifndef CPU_ONLY @@ -158,7 +163,8 @@ class Caffe { shared_ptr random_generator_; Brew mode_; - static shared_ptr singleton_; + int solver_count_; + bool root_solver_; private: // The private constructor to avoid duplicate instantiation. diff --git a/include/caffe/data_layers.hpp b/include/caffe/data_layers.hpp index 3958cb7ecb0..12e6c366620 100644 --- a/include/caffe/data_layers.hpp +++ b/include/caffe/data_layers.hpp @@ -5,16 +5,17 @@ #include #include -#include "boost/scoped_ptr.hpp" #include "hdf5.h" #include "caffe/blob.hpp" #include "caffe/common.hpp" +#include "caffe/data_reader.hpp" #include "caffe/data_transformer.hpp" #include "caffe/filler.hpp" #include "caffe/internal_thread.hpp" #include "caffe/layer.hpp" #include "caffe/proto/caffe.pb.h" +#include "caffe/util/blocking_queue.hpp" #include "caffe/util/db.hpp" namespace caffe { @@ -50,12 +51,17 @@ class BaseDataLayer : public Layer { bool output_labels_; }; +template +class Batch { + public: + Blob data_, label_; +}; + template class BasePrefetchingDataLayer : public BaseDataLayer, public InternalThread { public: - explicit BasePrefetchingDataLayer(const LayerParameter& param) - : BaseDataLayer(param) {} + explicit BasePrefetchingDataLayer(const LayerParameter& param); // LayerSetUp: implements common data layer setup functionality, and calls // DataLayerSetUp to do special data layer setup for individual layer types. // This method may not be overridden. @@ -67,22 +73,24 @@ class BasePrefetchingDataLayer : virtual void Forward_gpu(const vector*>& bottom, const vector*>& top); - virtual void CreatePrefetchThread(); - virtual void JoinPrefetchThread(); - // The thread's function - virtual void InternalThreadEntry() {} + // Prefetches batches (asynchronously if to GPU memory) + static const int PREFETCH_COUNT = 3; protected: - Blob prefetch_data_; - Blob prefetch_label_; + virtual void InternalThreadEntry(); + virtual void load_batch(Batch* batch) = 0; + + Batch prefetch_[PREFETCH_COUNT]; + BlockingQueue*> prefetch_free_; + BlockingQueue*> prefetch_full_; + Blob transformed_data_; }; template class DataLayer : public BasePrefetchingDataLayer { public: - explicit DataLayer(const LayerParameter& param) - : BasePrefetchingDataLayer(param) {} + explicit DataLayer(const LayerParameter& param); virtual ~DataLayer(); virtual void DataLayerSetUp(const vector*>& bottom, const vector*>& top); @@ -93,10 +101,9 @@ class DataLayer : public BasePrefetchingDataLayer { virtual inline int MaxTopBlobs() const { return 2; } protected: - virtual void InternalThreadEntry(); + virtual void load_batch(Batch* batch); - shared_ptr db_; - shared_ptr cursor_; + DataReader reader_; }; /** @@ -235,7 +242,7 @@ class ImageDataLayer : public BasePrefetchingDataLayer { protected: shared_ptr prefetch_rng_; virtual void ShuffleImages(); - virtual void InternalThreadEntry(); + virtual void load_batch(Batch* batch); vector > lines_; int lines_id_; @@ -307,7 +314,7 @@ class WindowDataLayer : public BasePrefetchingDataLayer { protected: virtual unsigned int PrefetchRand(); - virtual void InternalThreadEntry(); + virtual void load_batch(Batch* batch); shared_ptr prefetch_rng_; vector > > image_database_; diff --git a/include/caffe/data_reader.hpp b/include/caffe/data_reader.hpp new file mode 100644 index 00000000000..8ed5542cb8d --- /dev/null +++ b/include/caffe/data_reader.hpp @@ -0,0 +1,82 @@ +#ifndef CAFFE_DATA_READER_HPP_ +#define CAFFE_DATA_READER_HPP_ + +#include +#include +#include + +#include "caffe/common.hpp" +#include "caffe/internal_thread.hpp" +#include "caffe/util/blocking_queue.hpp" +#include "caffe/util/db.hpp" + +namespace caffe { + +/** + * @brief Reads data from a source to queues available to data layers. + * A single reading thread is created per source, even if multiple solvers + * are running in parallel, e.g. for multi-GPU training. This makes sure + * databases are read sequentially, and that each solver accesses a different + * subset of the database. Data is distributed to solvers in a round-robin + * way to keep parallel training deterministic. + */ +class DataReader { + public: + explicit DataReader(const LayerParameter& param); + ~DataReader(); + + inline BlockingQueue& free() const { + return queue_pair_->free_; + } + inline BlockingQueue& full() const { + return queue_pair_->full_; + } + + protected: + // Queue pairs are shared between a body and its readers + class QueuePair { + public: + explicit QueuePair(int size); + ~QueuePair(); + + BlockingQueue free_; + BlockingQueue full_; + + DISABLE_COPY_AND_ASSIGN(QueuePair); + }; + + // A single body is created per source + class Body : public InternalThread { + public: + explicit Body(const LayerParameter& param); + virtual ~Body(); + + protected: + void InternalThreadEntry(); + void read_one(db::Cursor* cursor, QueuePair* qp); + + const LayerParameter param_; + BlockingQueue > new_queue_pairs_; + + friend class DataReader; + + DISABLE_COPY_AND_ASSIGN(Body); + }; + + // A source is uniquely identified by its layer name + path, in case + // the same database is read from two different locations in the net. + static inline string source_key(const LayerParameter& param) { + return param.name() + ":" + param.data_param().source(); + } + + const shared_ptr queue_pair_; + shared_ptr body_; + + static map > bodies_; + +DISABLE_COPY_AND_ASSIGN(DataReader); +}; + +} // namespace caffe + +#endif // CAFFE_DATA_READER_HPP_ diff --git a/include/caffe/internal_thread.hpp b/include/caffe/internal_thread.hpp index 815ca54605e..3c32a1d13b3 100644 --- a/include/caffe/internal_thread.hpp +++ b/include/caffe/internal_thread.hpp @@ -14,18 +14,22 @@ namespace caffe { /** * Virtual class encapsulate boost::thread for use in base class * The child class will acquire the ability to run a single thread, - * by reimplementing the virutal function InternalThreadEntry. + * by reimplementing the virtual function InternalThreadEntry. */ class InternalThread { public: - InternalThread() : thread_() {} + InternalThread(); virtual ~InternalThread(); - /** Returns true if the thread was successfully started. **/ - bool StartInternalThread(); + /** + * Caffe's thread local state will be initialized using the current + * thread values, e.g. device id, solver index etc. The random seed + * is initialized using caffe_rng_rand. + */ + void StartInternalThread(); /** Will not return until the internal thread has exited. */ - bool WaitForInternalThreadToExit(); + void StopInternalThread(); bool is_started() const; @@ -34,7 +38,18 @@ class InternalThread { with the code you want your thread to run. */ virtual void InternalThreadEntry() {} + /* Should be tested when running loops to exit when requested. */ + bool must_stop(); + + private: + void entry(); + shared_ptr thread_; + int device_; + Caffe::Brew mode_; + int rand_seed_; + int solver_count_; + bool root_solver_; }; } // namespace caffe diff --git a/include/caffe/layer_factory.hpp b/include/caffe/layer_factory.hpp index 2fcd93869a0..32e849de0d2 100644 --- a/include/caffe/layer_factory.hpp +++ b/include/caffe/layer_factory.hpp @@ -71,7 +71,9 @@ class LayerRegistry { // Get a layer using a LayerParameter. static shared_ptr > CreateLayer(const LayerParameter& param) { - LOG(INFO) << "Creating layer " << param.name(); + if (Caffe::root_solver()) { + LOG(INFO) << "Creating layer " << param.name(); + } const string& type = param.type(); CreatorRegistry& registry = Registry(); CHECK_EQ(registry.count(type), 1) << "Unknown layer type: " << type diff --git a/include/caffe/parallel.hpp b/include/caffe/parallel.hpp new file mode 100644 index 00000000000..2b99069dc2d --- /dev/null +++ b/include/caffe/parallel.hpp @@ -0,0 +1,121 @@ +#ifndef CAFFE_PARALLEL_HPP_ +#define CAFFE_PARALLEL_HPP_ + +#include + +#include + +#include "caffe/blob.hpp" +#include "caffe/common.hpp" +#include "caffe/internal_thread.hpp" +#include "caffe/layer.hpp" +#include "caffe/proto/caffe.pb.h" +#include "caffe/solver.hpp" +#include "caffe/syncedmem.hpp" +#include "caffe/util/blocking_queue.hpp" + +namespace caffe { + +// Represents a net parameters. Once a net is created, its parameter buffers can +// be replaced by ones from Params, to allow parallelization. Params ensures +// parameters are allocated in one consecutive array. +template +class Params { + public: + explicit Params(shared_ptr > root_solver); + virtual ~Params() { + } + + inline size_t size() const { + return size_; + } + inline Dtype* data() const { + return data_; + } + inline Dtype* diff() const { + return diff_; + } + + protected: + const size_t size_; // Size of buffers + Dtype* data_; // Network parameters + Dtype* diff_; // Gradient + +DISABLE_COPY_AND_ASSIGN(Params); +}; + +// Params stored in GPU memory. +template +class GPUParams : public Params { + public: + GPUParams(shared_ptr > root_solver, int device); + virtual ~GPUParams(); + + void configure(Solver* solver) const; + + protected: + using Params::size_; + using Params::data_; + using Params::diff_; +}; + +class DevicePair { + public: + DevicePair(int parent, int device) + : parent_(parent), + device_(device) { + } + inline int parent() { + return parent_; + } + inline int device() { + return device_; + } + + // Group GPUs in pairs, by proximity depending on machine's topology + static void compute(const vector devices, vector* pairs); + + protected: + int parent_; + int device_; +}; + +// Synchronous data parallelism using map-reduce between local GPUs. +template +class P2PSync : public GPUParams, public Solver::Callback, + public InternalThread { + public: + explicit P2PSync(shared_ptr > root_solver, + P2PSync* parent, const SolverParameter& param); + virtual ~P2PSync(); + + inline const shared_ptr >& solver() const { + return solver_; + } + + static void run(shared_ptr > root, const vector& gpus); + + // Divide the batch size by the number of solvers + static void divide_batch_size(NetParameter* net); + + protected: + void on_start(Timer* timer, ostringstream* timing); + void on_gradients_ready(Timer* timer, ostringstream* timing); + + void InternalThreadEntry(); + + P2PSync* parent_; + vector*> children_; + BlockingQueue*> queue_; + const int initial_iter_; + Dtype* parent_grads_; + shared_ptr > solver_; + + using Params::size_; + using Params::data_; + using Params::diff_; +}; + +} // namespace caffe + +#endif diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index 4dcdc3dc20b..b292ba25ae6 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -5,6 +5,7 @@ #include #include "caffe/net.hpp" +#include "caffe/util/benchmark.hpp" namespace caffe { @@ -32,15 +33,30 @@ class Solver { // function that restores the state from a SolverState protocol buffer. void Restore(const char* resume_file); virtual ~Solver() {} + inline const SolverParameter& param() const { return param_; } inline shared_ptr > net() { return net_; } inline const vector > >& test_nets() { return test_nets_; } int iter() { return iter_; } + // Invoked at specific points during an iteration + class Callback { + protected: + virtual void on_start(Timer* timer, ostringstream* timing) = 0; + virtual void on_gradients_ready(Timer* timer, ostringstream* timing) = 0; + + template + friend class Solver; + }; + const vector& callbacks() const { return callbacks_; } + void add_callback(Callback* value) { + callbacks_.push_back(value); + } + protected: - // Get the update value for the current iteration. - virtual void ComputeUpdateValue() = 0; + // Get and apply the update value for the current iteration. + virtual void Iteration() {} // The Solver::Snapshot function implements the basic snapshotting utility // that stores the learned net. You should implement the SnapshotSolverState() // function that produces a SolverState protocol buffer that needs to be @@ -49,8 +65,12 @@ class Solver { // The test routine void TestAll(); void Test(const int test_net_id = 0); - virtual void SnapshotSolverState(SolverState* state) = 0; - virtual void RestoreSolverState(const SolverState& state) = 0; + virtual void SnapshotSolverState(SolverState* state) { + CHECK(false) << "Should be overriden"; + } + virtual void RestoreSolverState(const SolverState& state) { + CHECK(false) << "Should be overriden"; + } void DisplayOutputBlobs(const int net_id); SolverParameter param_; @@ -58,6 +78,10 @@ class Solver { int current_step_; shared_ptr > net_; vector > > test_nets_; + vector callbacks_; + + Timer iteration_timer_; + float iterations_last_; DISABLE_COPY_AND_ASSIGN(Solver); }; @@ -80,7 +104,9 @@ class SGDSolver : public Solver { protected: void PreSolve(); Dtype GetLearningRate(); - virtual void ComputeUpdateValue(); + virtual void Iteration(); + virtual void Regularize(int param_id); + virtual void ComputeUpdateValue(int param_id, Dtype rate); virtual void ClipGradients(); virtual void SnapshotSolverState(SolverState * state); virtual void RestoreSolverState(const SolverState& state); @@ -90,6 +116,9 @@ class SGDSolver : public Solver { // of gradients/updates and is not needed in snapshots vector > > history_, update_, temp_; + using Solver::iteration_timer_; + using Solver::iterations_last_; + DISABLE_COPY_AND_ASSIGN(SGDSolver); }; @@ -102,7 +131,7 @@ class NesterovSolver : public SGDSolver { : SGDSolver(param_file) {} protected: - virtual void ComputeUpdateValue(); + virtual void ComputeUpdateValue(int param_id, Dtype rate); DISABLE_COPY_AND_ASSIGN(NesterovSolver); }; @@ -116,7 +145,7 @@ class AdaGradSolver : public SGDSolver { : SGDSolver(param_file) { constructor_sanity_check(); } protected: - virtual void ComputeUpdateValue(); + virtual void ComputeUpdateValue(int param_id, Dtype rate); void constructor_sanity_check() { CHECK_EQ(0, this->param_.momentum()) << "Momentum cannot be used with AdaGrad."; diff --git a/include/caffe/syncedmem.hpp b/include/caffe/syncedmem.hpp index 1b726de9564..6bdf84ab660 100644 --- a/include/caffe/syncedmem.hpp +++ b/include/caffe/syncedmem.hpp @@ -8,26 +8,29 @@ namespace caffe { -// Theoretically, CaffeMallocHost and CaffeFreeHost should simply call the -// cudaMallocHost and cudaFree functions in order to create pinned memory. -// However, those codes rely on the existence of a cuda GPU (I don't know -// why that is a must since allocating memory should not be accessing the -// GPU resource, but it just creates an error as of Cuda 5.0) and will cause -// problem when running on a machine without GPU. Thus, we simply define -// these two functions for safety and possible future change if the problem -// of calling cuda functions disappears in a future version. -// -// In practice, although we are creating unpinned memory here, as long as we -// are constantly accessing them the memory pages almost always stays in -// the physical memory (assuming we have large enough memory installed), and -// does not seem to create a memory bottleneck here. - +// If CUDA is available and in GPU mode, host memory will be allocated pinned, +// using cudaMallocHost. It avoids dynamic pinning for transfers (DMA). +// The improvement in performance seems negligible in the single GPU case, +// but might be more significant for parallel training. Most importantly, +// it improved stability for large models on many GPUs. inline void CaffeMallocHost(void** ptr, size_t size) { +#ifndef CPU_ONLY + if (Caffe::mode() == Caffe::GPU) { + CUDA_CHECK(cudaMallocHost(ptr, size)); + return; + } +#endif *ptr = malloc(size); CHECK(*ptr) << "host allocation of size " << size << " failed"; } inline void CaffeFreeHost(void* ptr) { +#ifndef CPU_ONLY + if (Caffe::mode() == Caffe::GPU) { + CUDA_CHECK(cudaFreeHost(ptr)); + return; + } +#endif free(ptr); } @@ -42,20 +45,25 @@ class SyncedMemory { public: SyncedMemory() : cpu_ptr_(NULL), gpu_ptr_(NULL), size_(0), head_(UNINITIALIZED), - own_cpu_data_(false) {} + own_cpu_data_(false), own_gpu_data_(false) {} explicit SyncedMemory(size_t size) : cpu_ptr_(NULL), gpu_ptr_(NULL), size_(size), head_(UNINITIALIZED), - own_cpu_data_(false) {} + own_cpu_data_(false), own_gpu_data_(false) {} ~SyncedMemory(); const void* cpu_data(); void set_cpu_data(void* data); const void* gpu_data(); + void set_gpu_data(void* data); void* mutable_cpu_data(); void* mutable_gpu_data(); enum SyncedHead { UNINITIALIZED, HEAD_AT_CPU, HEAD_AT_GPU, SYNCED }; SyncedHead head() { return head_; } size_t size() { return size_; } +#ifndef CPU_ONLY + void async_gpu_push(const cudaStream_t& stream); +#endif + private: void to_cpu(); void to_gpu(); @@ -64,6 +72,7 @@ class SyncedMemory { size_t size_; SyncedHead head_; bool own_cpu_data_; + bool own_gpu_data_; DISABLE_COPY_AND_ASSIGN(SyncedMemory); }; // class SyncedMemory diff --git a/include/caffe/util/blocking_queue.hpp b/include/caffe/util/blocking_queue.hpp new file mode 100644 index 00000000000..955e12cc567 --- /dev/null +++ b/include/caffe/util/blocking_queue.hpp @@ -0,0 +1,47 @@ +#ifndef CAFFE_UTIL_BLOCKING_QUEUE_HPP_ +#define CAFFE_UTIL_BLOCKING_QUEUE_HPP_ + +#include +#include + +#include "caffe/common.hpp" + +namespace caffe { + +template +class BlockingQueue { + public: + explicit BlockingQueue(); + + void push(const T& t); + + bool try_pop(T* t); + + // This logs a message if the threads needs to be blocked + // useful for detecting e.g. when data feeding is too slow + T pop(const string& log_on_wait = ""); + + bool try_peek(T* t); + + // Return element without removing it + T peek(); + + size_t size() const; + + protected: + /** + Move synchronization fields out instead of including boost/thread.hpp + to avoid a boost/NVCC issues (#1009, #1010) on OSX. Also fails on + Linux CUDA 7.0.18. + */ + class sync; + + std::queue queue_; + shared_ptr sync_; + +DISABLE_COPY_AND_ASSIGN(BlockingQueue); +}; + +} // namespace caffe + +#endif diff --git a/src/caffe/common.cpp b/src/caffe/common.cpp index af96cac40aa..7077f3789d7 100644 --- a/src/caffe/common.cpp +++ b/src/caffe/common.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -7,7 +8,15 @@ namespace caffe { -shared_ptr Caffe::singleton_; +// Make sure each thread can have different values. +static boost::thread_specific_ptr thread_instance_; + +Caffe& Caffe::Get() { + if (!thread_instance_.get()) { + thread_instance_.reset(new Caffe()); + } + return *(thread_instance_.get()); +} // random seeding int64_t cluster_seedgen(void) { @@ -42,7 +51,8 @@ void GlobalInit(int* pargc, char*** pargv) { #ifdef CPU_ONLY // CPU-only Caffe. Caffe::Caffe() - : random_generator_(), mode_(Caffe::CPU) { } + : random_generator_(), mode_(Caffe::CPU), + solver_count_(1), root_solver_(true) { } Caffe::~Caffe() { } @@ -86,7 +96,7 @@ void* Caffe::RNG::generator() { Caffe::Caffe() : cublas_handle_(NULL), curand_generator_(NULL), random_generator_(), - mode_(Caffe::CPU) { + mode_(Caffe::CPU), solver_count_(1), root_solver_(true) { // Try to create a cublas handler, and report an error if failed (but we will // keep the program running as one might just want to run CPU code). if (cublasCreate(&cublas_handle_) != CUBLAS_STATUS_SUCCESS) { diff --git a/src/caffe/data_reader.cpp b/src/caffe/data_reader.cpp new file mode 100644 index 00000000000..16378203a88 --- /dev/null +++ b/src/caffe/data_reader.cpp @@ -0,0 +1,119 @@ +#include +#include +#include +#include + +#include "caffe/common.hpp" +#include "caffe/data_layers.hpp" +#include "caffe/data_reader.hpp" +#include "caffe/proto/caffe.pb.h" + +namespace caffe { + +using boost::weak_ptr; + +map > DataReader::bodies_; +static boost::mutex bodies_mutex_; + +DataReader::DataReader(const LayerParameter& param) + : queue_pair_(new QueuePair( // + param.data_param().prefetch() * param.data_param().batch_size())) { + // Get or create a body + boost::mutex::scoped_lock lock(bodies_mutex_); + string key = source_key(param); + weak_ptr& weak = bodies_[key]; + body_ = weak.lock(); + if (!body_) { + body_.reset(new Body(param)); + bodies_[key] = weak_ptr(body_); + } + body_->new_queue_pairs_.push(queue_pair_); +} + +DataReader::~DataReader() { + string key = source_key(body_->param_); + body_.reset(); + boost::mutex::scoped_lock lock(bodies_mutex_); + if (bodies_[key].expired()) { + bodies_.erase(key); + } +} + +// + +DataReader::QueuePair::QueuePair(int size) { + // Initialize the free queue with requested number of datums + for (int i = 0; i < size; ++i) { + free_.push(new Datum()); + } +} + +DataReader::QueuePair::~QueuePair() { + Datum* datum; + while (free_.try_pop(&datum)) { + delete datum; + } + while (full_.try_pop(&datum)) { + delete datum; + } +} + +// + +DataReader::Body::Body(const LayerParameter& param) + : param_(param), + new_queue_pairs_() { + StartInternalThread(); +} + +DataReader::Body::~Body() { + StopInternalThread(); +} + +void DataReader::Body::InternalThreadEntry() { + shared_ptr db(db::GetDB(param_.data_param().backend())); + db->Open(param_.data_param().source(), db::READ); + shared_ptr cursor(db->NewCursor()); + vector > qps; + try { + int solver_count = param_.phase() == TRAIN ? Caffe::solver_count() : 1; + + // To ensure deterministic runs, only start running once all solvers + // are ready. But solvers need to peek on one item during initialization, + // so read one item, then wait for the next solver. + for (int i = 0; i < solver_count; ++i) { + shared_ptr qp(new_queue_pairs_.pop()); + read_one(cursor.get(), qp.get()); + qps.push_back(qp); + } + // Main loop + while (!must_stop()) { + for (int i = 0; i < solver_count; ++i) { + read_one(cursor.get(), qps[i].get()); + } + // Check no additional readers have been created. This can happen if + // more than one net is trained at a time per process, whether single + // or multi solver. It might also happen if two data layers have same + // name and same source. + CHECK_EQ(new_queue_pairs_.size(), 0); + } + } catch (boost::thread_interrupted&) { + // Interrupted exception is expected on shutdown + } +} + +void DataReader::Body::read_one(db::Cursor* cursor, QueuePair* qp) { + Datum* datum = qp->free_.pop(); + // TODO deserialize in-place instead of copy? + datum->ParseFromString(cursor->value()); + qp->full_.push(datum); + + // go to the next iter + cursor->Next(); + if (!cursor->valid()) { + DLOG(INFO) << "Restarting data prefetching from start."; + cursor->SeekToFirst(); + } +} + +} // namespace caffe diff --git a/src/caffe/data_transformer.cpp b/src/caffe/data_transformer.cpp index b0b98e478c1..482b8c09d24 100644 --- a/src/caffe/data_transformer.cpp +++ b/src/caffe/data_transformer.cpp @@ -19,7 +19,9 @@ DataTransformer::DataTransformer(const TransformationParameter& param, CHECK_EQ(param_.mean_value_size(), 0) << "Cannot specify mean_file and mean_value at the same time"; const string& mean_file = param.mean_file(); - LOG(INFO) << "Loading mean file from: " << mean_file; + if (Caffe::root_solver()) { + LOG(INFO) << "Loading mean file from: " << mean_file; + } BlobProto blob_proto; ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto); data_mean_.FromProto(blob_proto); diff --git a/src/caffe/internal_thread.cpp b/src/caffe/internal_thread.cpp index c2d19d433b4..2402a192e7e 100644 --- a/src/caffe/internal_thread.cpp +++ b/src/caffe/internal_thread.cpp @@ -1,40 +1,75 @@ #include +#include + #include "caffe/internal_thread.hpp" +#include "caffe/util/math_functions.hpp" namespace caffe { +InternalThread::InternalThread() + : thread_(), + device_(), + mode_(), + rand_seed_(), + solver_count_(), + root_solver_() { +} + InternalThread::~InternalThread() { - WaitForInternalThreadToExit(); + StopInternalThread(); } bool InternalThread::is_started() const { - return thread_.get() != NULL && thread_->joinable(); + return thread_ && thread_->joinable(); } +bool InternalThread::must_stop() { + return thread_ && thread_->interruption_requested(); +} + +void InternalThread::StartInternalThread() { + // TODO switch to failing once Caffe prefetch thread is persistent. + // Threads should not be started and stopped repeatedly. + // CHECK(!is_started()); + StopInternalThread(); + +#ifndef CPU_ONLY + CUDA_CHECK(cudaGetDevice(&device_)); +#endif + mode_ = Caffe::mode(); + rand_seed_ = caffe_rng_rand(); + solver_count_ = Caffe::solver_count(); + root_solver_ = Caffe::root_solver(); -bool InternalThread::StartInternalThread() { - if (!WaitForInternalThreadToExit()) { - return false; - } try { - thread_.reset( - new boost::thread(&InternalThread::InternalThreadEntry, this)); - } catch (...) { - return false; + thread_.reset(new boost::thread(&InternalThread::entry, this)); + } catch (std::exception& e) { + CHECK(false) << e.what(); } - return true; } -/** Will not return until the internal thread has exited. */ -bool InternalThread::WaitForInternalThreadToExit() { +void InternalThread::entry() { +#ifndef CPU_ONLY + CUDA_CHECK(cudaSetDevice(device_)); +#endif + Caffe::set_mode(mode_); + Caffe::set_random_seed(rand_seed_); + Caffe::set_solver_count(solver_count_); + Caffe::set_root_solver(root_solver_); + + InternalThreadEntry(); +} + +void InternalThread::StopInternalThread() { if (is_started()) { + thread_->interrupt(); try { thread_->join(); - } catch (...) { - return false; + } catch (boost::thread_interrupted&) { + } catch (std::exception& e) { + CHECK(false) << e.what(); } } - return true; } } // namespace caffe diff --git a/src/caffe/layers/base_data_layer.cpp b/src/caffe/layers/base_data_layer.cpp index 931e4a9c0ab..71504c11bfd 100644 --- a/src/caffe/layers/base_data_layer.cpp +++ b/src/caffe/layers/base_data_layer.cpp @@ -1,7 +1,9 @@ +#include #include #include #include "caffe/data_layers.hpp" +#include "caffe/net.hpp" #include "caffe/util/io.hpp" namespace caffe { @@ -27,54 +29,91 @@ void BaseDataLayer::LayerSetUp(const vector*>& bottom, data_transformer_->InitRand(); } +template +BasePrefetchingDataLayer::BasePrefetchingDataLayer( + const LayerParameter& param) + : BaseDataLayer(param), + prefetch_free_(), prefetch_full_() { + for (int i = 0; i < PREFETCH_COUNT; ++i) { + prefetch_free_.push(&prefetch_[i]); + } +} + template void BasePrefetchingDataLayer::LayerSetUp( const vector*>& bottom, const vector*>& top) { BaseDataLayer::LayerSetUp(bottom, top); - // Now, start the prefetch thread. Before calling prefetch, we make two - // cpu_data calls so that the prefetch thread does not accidentally make - // simultaneous cudaMalloc calls when the main thread is running. In some - // GPUs this seems to cause failures if we do not so. - this->prefetch_data_.mutable_cpu_data(); - if (this->output_labels_) { - this->prefetch_label_.mutable_cpu_data(); + + // Before starting the prefetch thread, we make cpu_data and gpu_data + // calls so that the prefetch thread does not accidentally make simultaneous + // cudaMalloc calls when the main thread is running. In some GPUs this + // seems to cause failures if we do not so. + for (int i = 0; i < PREFETCH_COUNT; ++i) { + prefetch_[i].data_.mutable_cpu_data(); + if (this->output_labels_) { + prefetch_[i].label_.mutable_cpu_data(); + } + } +#ifndef CPU_ONLY + if (Caffe::mode() == Caffe::GPU) { + for (int i = 0; i < PREFETCH_COUNT; ++i) { + prefetch_[i].data_.mutable_gpu_data(); + if (this->output_labels_) { + prefetch_[i].label_.mutable_gpu_data(); + } + } } +#endif + DLOG(INFO) << "Initializing prefetch"; - this->CreatePrefetchThread(); + this->data_transformer_->InitRand(); + StartInternalThread(); DLOG(INFO) << "Prefetch initialized."; } template -void BasePrefetchingDataLayer::CreatePrefetchThread() { - this->data_transformer_->InitRand(); - CHECK(StartInternalThread()) << "Thread execution failed"; -} +void BasePrefetchingDataLayer::InternalThreadEntry() { +#ifndef CPU_ONLY + cudaStream_t stream; + if (Caffe::mode() == Caffe::GPU) { + cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking); + } +#endif -template -void BasePrefetchingDataLayer::JoinPrefetchThread() { - CHECK(WaitForInternalThreadToExit()) << "Thread joining failed"; + try { + while (!must_stop()) { + Batch* batch = prefetch_free_.pop(); + load_batch(batch); +#ifndef CPU_ONLY + if (Caffe::mode() == Caffe::GPU) { + batch->data_.data().get()->async_gpu_push(stream); + cudaStreamSynchronize(stream); + } +#endif + prefetch_full_.push(batch); + } + } catch (boost::thread_interrupted&) { + // Interrupted exception is expected on shutdown + } } template void BasePrefetchingDataLayer::Forward_cpu( const vector*>& bottom, const vector*>& top) { - // First, join the thread - JoinPrefetchThread(); - DLOG(INFO) << "Thread joined"; + Batch* batch = prefetch_full_.pop("Data layer prefetch queue empty"); // Reshape to loaded data. - top[0]->Reshape(this->prefetch_data_.num(), this->prefetch_data_.channels(), - this->prefetch_data_.height(), this->prefetch_data_.width()); + top[0]->Reshape(batch->data_.num(), batch->data_.channels(), + batch->data_.height(), batch->data_.width()); // Copy the data - caffe_copy(prefetch_data_.count(), prefetch_data_.cpu_data(), + caffe_copy(batch->data_.count(), batch->data_.cpu_data(), top[0]->mutable_cpu_data()); DLOG(INFO) << "Prefetch copied"; if (this->output_labels_) { - caffe_copy(prefetch_label_.count(), prefetch_label_.cpu_data(), - top[1]->mutable_cpu_data()); + caffe_copy(batch->label_.count(), batch->label_.cpu_data(), + top[1]->mutable_cpu_data()); } - // Start a new prefetch thread - DLOG(INFO) << "CreatePrefetchThread"; - CreatePrefetchThread(); + + prefetch_free_.push(batch); } #ifdef CPU_ONLY diff --git a/src/caffe/layers/base_data_layer.cu b/src/caffe/layers/base_data_layer.cu index 775f6c47f7e..52085d007a7 100644 --- a/src/caffe/layers/base_data_layer.cu +++ b/src/caffe/layers/base_data_layer.cu @@ -7,20 +7,19 @@ namespace caffe { template void BasePrefetchingDataLayer::Forward_gpu( const vector*>& bottom, const vector*>& top) { - // First, join the thread - JoinPrefetchThread(); + Batch* batch = prefetch_full_.pop("Data layer prefetch queue empty"); // Reshape to loaded data. - top[0]->Reshape(this->prefetch_data_.num(), this->prefetch_data_.channels(), - this->prefetch_data_.height(), this->prefetch_data_.width()); + top[0]->Reshape(batch->data_.num(), batch->data_.channels(), + batch->data_.height(), batch->data_.width()); // Copy the data - caffe_copy(prefetch_data_.count(), prefetch_data_.cpu_data(), + caffe_copy(batch->data_.count(), batch->data_.gpu_data(), top[0]->mutable_gpu_data()); if (this->output_labels_) { - caffe_copy(prefetch_label_.count(), prefetch_label_.cpu_data(), + caffe_copy(batch->label_.count(), batch->label_.gpu_data(), top[1]->mutable_gpu_data()); } - // Start a new prefetch thread - CreatePrefetchThread(); + + prefetch_free_.push(batch); } INSTANTIATE_LAYER_GPU_FORWARD(BasePrefetchingDataLayer); diff --git a/src/caffe/layers/data_layer.cpp b/src/caffe/layers/data_layer.cpp index 0f2d66776a9..9c23ba0cd29 100644 --- a/src/caffe/layers/data_layer.cpp +++ b/src/caffe/layers/data_layer.cpp @@ -11,36 +11,25 @@ #include "caffe/proto/caffe.pb.h" #include "caffe/util/benchmark.hpp" #include "caffe/util/io.hpp" -#include "caffe/util/math_functions.hpp" -#include "caffe/util/rng.hpp" namespace caffe { template -DataLayer::~DataLayer() { - this->JoinPrefetchThread(); +DataLayer::DataLayer(const LayerParameter& param) + : BasePrefetchingDataLayer(param), + reader_(param) { +} + +template +DataLayer::~DataLayer() { + this->StopInternalThread(); } template void DataLayer::DataLayerSetUp(const vector*>& bottom, const vector*>& top) { - // Initialize DB - db_.reset(db::GetDB(this->layer_param_.data_param().backend())); - db_->Open(this->layer_param_.data_param().source(), db::READ); - cursor_.reset(db_->NewCursor()); - - // Check if we should randomly skip a few data points - if (this->layer_param_.data_param().rand_skip()) { - unsigned int skip = caffe_rng_rand() % - this->layer_param_.data_param().rand_skip(); - LOG(INFO) << "Skipping first " << skip << " data points."; - while (skip-- > 0) { - cursor_->Next(); - } - } // Read a data point, and use it to initialize the top blob. - Datum datum; - datum.ParseFromString(cursor_->value()); + Datum& datum = *(reader_.full().peek()); bool force_color = this->layer_param_.data_param().force_encoded_color(); if ((force_color && DecodeDatum(&datum, true)) || @@ -48,42 +37,49 @@ void DataLayer::DataLayerSetUp(const vector*>& bottom, LOG(INFO) << "Decoding Datum"; } // image - int crop_size = this->layer_param_.transform_param().crop_size(); + const int crop_size = this->layer_param_.transform_param().crop_size(); + const int batch_size = this->layer_param_.data_param().batch_size(); if (crop_size > 0) { - top[0]->Reshape(this->layer_param_.data_param().batch_size(), - datum.channels(), crop_size, crop_size); - this->prefetch_data_.Reshape(this->layer_param_.data_param().batch_size(), - datum.channels(), crop_size, crop_size); - this->transformed_data_.Reshape(1, datum.channels(), crop_size, crop_size); + top[0]->Reshape(batch_size, datum.channels(), crop_size, crop_size); + for (int i = 0; i < this->PREFETCH_COUNT; ++i) { + this->prefetch_[i].data_.Reshape(batch_size, datum.channels(), + crop_size, crop_size); + } + this->transformed_data_.Reshape(1, datum.channels(), + crop_size, crop_size); } else { - top[0]->Reshape( - this->layer_param_.data_param().batch_size(), datum.channels(), + top[0]->Reshape(batch_size, datum.channels(), datum.height(), datum.width()); - this->prefetch_data_.Reshape(this->layer_param_.data_param().batch_size(), - datum.channels(), datum.height(), datum.width()); + for (int i = 0; i < this->PREFETCH_COUNT; ++i) { + this->prefetch_[i].data_.Reshape(batch_size, datum.channels(), + datum.height(), datum.width()); + } this->transformed_data_.Reshape(1, datum.channels(), - datum.height(), datum.width()); + datum.height(), datum.width()); } LOG(INFO) << "output data size: " << top[0]->num() << "," << top[0]->channels() << "," << top[0]->height() << "," << top[0]->width(); // label if (this->output_labels_) { - vector label_shape(1, this->layer_param_.data_param().batch_size()); + vector label_shape(1, batch_size); top[1]->Reshape(label_shape); - this->prefetch_label_.Reshape(label_shape); + for (int i = 0; i < this->PREFETCH_COUNT; ++i) { + this->prefetch_[i].label_.Reshape(label_shape); + } } } -// This function is used to create a thread that prefetches the data. -template -void DataLayer::InternalThreadEntry() { +// This function is called on prefetch thread +template +void DataLayer::load_batch(Batch* batch) { CPUTimer batch_timer; batch_timer.Start(); - double read_time = 0; + double deque_time = 0; + double decod_time = 0; double trans_time = 0; CPUTimer timer; - CHECK(this->prefetch_data_.count()); + CHECK(batch->data_.count()); CHECK(this->transformed_data_.count()); // Reshape on single input batches for inputs of varying dimension. @@ -91,8 +87,7 @@ void DataLayer::InternalThreadEntry() { const int crop_size = this->layer_param_.transform_param().crop_size(); bool force_color = this->layer_param_.data_param().force_encoded_color(); if (batch_size == 1 && crop_size == 0) { - Datum datum; - datum.ParseFromString(cursor_->value()); + Datum& datum = *(reader_.full().peek()); if (datum.encoded()) { if (force_color) { DecodeDatum(&datum, true); @@ -100,24 +95,25 @@ void DataLayer::InternalThreadEntry() { DecodeDatumNative(&datum); } } - this->prefetch_data_.Reshape(1, datum.channels(), + batch->data_.Reshape(1, datum.channels(), datum.height(), datum.width()); this->transformed_data_.Reshape(1, datum.channels(), datum.height(), datum.width()); } - Dtype* top_data = this->prefetch_data_.mutable_cpu_data(); + Dtype* top_data = batch->data_.mutable_cpu_data(); Dtype* top_label = NULL; // suppress warnings about uninitialized variables if (this->output_labels_) { - top_label = this->prefetch_label_.mutable_cpu_data(); + top_label = batch->label_.mutable_cpu_data(); } for (int item_id = 0; item_id < batch_size; ++item_id) { - timer.Start(); // get a blob - Datum datum; - datum.ParseFromString(cursor_->value()); + timer.Start(); + Datum& datum = *(reader_.full().pop("Waiting for data")); + deque_time += timer.MicroSeconds(); + timer.Start(); cv::Mat cv_img; if (datum.encoded()) { if (force_color) { @@ -132,11 +128,11 @@ void DataLayer::InternalThreadEntry() { << "convert_imageset."; } } - read_time += timer.MicroSeconds(); - timer.Start(); + decod_time += timer.MicroSeconds(); // Apply data transformations (mirror, scale, crop...) - int offset = this->prefetch_data_.offset(item_id); + timer.Start(); + int offset = batch->data_.offset(item_id); this->transformed_data_.set_cpu_data(top_data + offset); if (datum.encoded()) { this->data_transformer_->Transform(cv_img, &(this->transformed_data_)); @@ -147,17 +143,17 @@ void DataLayer::InternalThreadEntry() { top_label[item_id] = datum.label(); } trans_time += timer.MicroSeconds(); - // go to the next iter - cursor_->Next(); - if (!cursor_->valid()) { - DLOG(INFO) << "Restarting data prefetching from start."; - cursor_->SeekToFirst(); - } + + reader_.free().push(const_cast(&datum)); } batch_timer.Stop(); - DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms."; - DLOG(INFO) << " Read time: " << read_time / 1000 << " ms."; - DLOG(INFO) << "Transform time: " << trans_time / 1000 << " ms."; + +#ifdef BENCHMARK_DATA + LOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms."; + LOG(INFO) << " Dequeue time: " << deque_time / 1000 << " ms."; + LOG(INFO) << " Decode time: " << decod_time / 1000 << " ms."; + LOG(INFO) << "Transform time: " << trans_time / 1000 << " ms."; +#endif } INSTANTIATE_CLASS(DataLayer); diff --git a/src/caffe/layers/image_data_layer.cpp b/src/caffe/layers/image_data_layer.cpp index 38ebbd5ec14..50187bbe5ce 100644 --- a/src/caffe/layers/image_data_layer.cpp +++ b/src/caffe/layers/image_data_layer.cpp @@ -17,7 +17,7 @@ namespace caffe { template ImageDataLayer::~ImageDataLayer() { - this->JoinPrefetchThread(); + this->StopInternalThread(); } template @@ -70,11 +70,14 @@ void ImageDataLayer::DataLayerSetUp(const vector*>& bottom, const int batch_size = this->layer_param_.image_data_param().batch_size(); if (crop_size > 0) { top[0]->Reshape(batch_size, channels, crop_size, crop_size); - this->prefetch_data_.Reshape(batch_size, channels, crop_size, crop_size); + for (int i = 0; i < this->PREFETCH_COUNT; ++i) + this->prefetch_[i].data_.Reshape(batch_size, channels, + crop_size, crop_size); this->transformed_data_.Reshape(1, channels, crop_size, crop_size); } else { top[0]->Reshape(batch_size, channels, height, width); - this->prefetch_data_.Reshape(batch_size, channels, height, width); + for (int i = 0; i < this->PREFETCH_COUNT; ++i) + this->prefetch_[i].data_.Reshape(batch_size, channels, height, width); this->transformed_data_.Reshape(1, channels, height, width); } LOG(INFO) << "output data size: " << top[0]->num() << "," @@ -83,7 +86,9 @@ void ImageDataLayer::DataLayerSetUp(const vector*>& bottom, // label vector label_shape(1, batch_size); top[1]->Reshape(label_shape); - this->prefetch_label_.Reshape(label_shape); + for (int i = 0; i < this->PREFETCH_COUNT; ++i) { + this->prefetch_[i].label_.Reshape(label_shape); + } } template @@ -93,15 +98,15 @@ void ImageDataLayer::ShuffleImages() { shuffle(lines_.begin(), lines_.end(), prefetch_rng); } -// This function is used to create a thread that prefetches the data. +// This function is called on prefetch thread template -void ImageDataLayer::InternalThreadEntry() { +void ImageDataLayer::load_batch(Batch* batch) { CPUTimer batch_timer; batch_timer.Start(); double read_time = 0; double trans_time = 0; CPUTimer timer; - CHECK(this->prefetch_data_.count()); + CHECK(batch->data_.count()); CHECK(this->transformed_data_.count()); ImageDataParameter image_data_param = this->layer_param_.image_data_param(); const int batch_size = image_data_param.batch_size(); @@ -115,14 +120,14 @@ void ImageDataLayer::InternalThreadEntry() { if (batch_size == 1 && crop_size == 0 && new_height == 0 && new_width == 0) { cv::Mat cv_img = ReadImageToCVMat(root_folder + lines_[lines_id_].first, 0, 0, is_color); - this->prefetch_data_.Reshape(1, cv_img.channels(), + batch->data_.Reshape(1, cv_img.channels(), cv_img.rows, cv_img.cols); this->transformed_data_.Reshape(1, cv_img.channels(), cv_img.rows, cv_img.cols); } - Dtype* prefetch_data = this->prefetch_data_.mutable_cpu_data(); - Dtype* prefetch_label = this->prefetch_label_.mutable_cpu_data(); + Dtype* prefetch_data = batch->data_.mutable_cpu_data(); + Dtype* prefetch_label = batch->label_.mutable_cpu_data(); // datum scales const int lines_size = lines_.size(); @@ -136,7 +141,7 @@ void ImageDataLayer::InternalThreadEntry() { read_time += timer.MicroSeconds(); timer.Start(); // Apply transformations (mirror, crop...) to the image - int offset = this->prefetch_data_.offset(item_id); + int offset = batch->data_.offset(item_id); this->transformed_data_.set_cpu_data(prefetch_data + offset); this->data_transformer_->Transform(cv_img, &(this->transformed_data_)); trans_time += timer.MicroSeconds(); diff --git a/src/caffe/layers/window_data_layer.cpp b/src/caffe/layers/window_data_layer.cpp index c127d56bc46..f637f2ec6d4 100644 --- a/src/caffe/layers/window_data_layer.cpp +++ b/src/caffe/layers/window_data_layer.cpp @@ -27,7 +27,7 @@ namespace caffe { template WindowDataLayer::~WindowDataLayer() { - this->JoinPrefetchThread(); + this->StopInternalThread(); } template @@ -171,7 +171,9 @@ void WindowDataLayer::DataLayerSetUp(const vector*>& bottom, CHECK_GT(crop_size, 0); const int batch_size = this->layer_param_.window_data_param().batch_size(); top[0]->Reshape(batch_size, channels, crop_size, crop_size); - this->prefetch_data_.Reshape(batch_size, channels, crop_size, crop_size); + for (int i = 0; i < this->PREFETCH_COUNT; ++i) + this->prefetch_[i].data_.Reshape( + batch_size, channels, crop_size, crop_size); LOG(INFO) << "output data size: " << top[0]->num() << "," << top[0]->channels() << "," << top[0]->height() << "," @@ -179,7 +181,9 @@ void WindowDataLayer::DataLayerSetUp(const vector*>& bottom, // label vector label_shape(1, batch_size); top[1]->Reshape(label_shape); - this->prefetch_label_.Reshape(label_shape); + for (int i = 0; i < this->PREFETCH_COUNT; ++i) { + this->prefetch_[i].label_.Reshape(label_shape); + } // data mean has_mean_file_ = this->transform_param_.has_mean_file(); @@ -217,9 +221,9 @@ unsigned int WindowDataLayer::PrefetchRand() { return (*prefetch_rng)(); } -// Thread fetching the data +// This function is called on prefetch thread template -void WindowDataLayer::InternalThreadEntry() { +void WindowDataLayer::load_batch(Batch* batch) { // At each iteration, sample N windows where N*p are foreground (object) // windows and N*(1-p) are background (non-object) windows CPUTimer batch_timer; @@ -227,8 +231,8 @@ void WindowDataLayer::InternalThreadEntry() { double read_time = 0; double trans_time = 0; CPUTimer timer; - Dtype* top_data = this->prefetch_data_.mutable_cpu_data(); - Dtype* top_label = this->prefetch_label_.mutable_cpu_data(); + Dtype* top_data = batch->data_.mutable_cpu_data(); + Dtype* top_label = batch->label_.mutable_cpu_data(); const Dtype scale = this->layer_param_.window_data_param().scale(); const int batch_size = this->layer_param_.window_data_param().batch_size(); const int context_pad = this->layer_param_.window_data_param().context_pad(); @@ -252,7 +256,7 @@ void WindowDataLayer::InternalThreadEntry() { bool use_square = (crop_mode == "square") ? true : false; // zero out batch - caffe_set(this->prefetch_data_.count(), Dtype(0), top_data); + caffe_set(batch->data_.count(), Dtype(0), top_data); const int num_fg = static_cast(static_cast(batch_size) * fg_fraction); diff --git a/src/caffe/net.cpp b/src/caffe/net.cpp index a18ee63818e..f1579b85a27 100644 --- a/src/caffe/net.cpp +++ b/src/caffe/net.cpp @@ -8,6 +8,7 @@ #include "caffe/common.hpp" #include "caffe/layer.hpp" #include "caffe/net.hpp" +#include "caffe/parallel.hpp" #include "caffe/proto/caffe.pb.h" #include "caffe/util/insert_splits.hpp" #include "caffe/util/io.hpp" @@ -39,8 +40,13 @@ void Net::Init(const NetParameter& in_param) { // the current NetState. NetParameter filtered_param; FilterNet(in_param, &filtered_param); - LOG(INFO) << "Initializing net from parameters: " << std::endl - << filtered_param.DebugString(); + if (phase_ == TRAIN) { + caffe::P2PSync::divide_batch_size(&filtered_param); + } + if (Caffe::root_solver()) { + LOG(INFO) << "Initializing net from parameters: " << std::endl + << filtered_param.DebugString(); + } // Create a copy of filtered_param with splits added where necessary. NetParameter param; InsertSplits(filtered_param, ¶m); @@ -64,7 +70,9 @@ void Net::Init(const NetParameter& in_param) { const int layer_id = -1; // inputs have fake layer ID -1 AppendTop(param, layer_id, input_id, &available_blobs, &blob_name_to_idx); } - DLOG(INFO) << "Memory required for data: " << memory_used_ * sizeof(Dtype); + if (Caffe::root_solver()) { + DLOG(INFO) << "Memory required for data: " << memory_used_ * sizeof(Dtype); + } // For each layer, set up its input and output bottom_vecs_.resize(param.layer_size()); top_vecs_.resize(param.layer_size()); @@ -87,7 +95,9 @@ void Net::Init(const NetParameter& in_param) { } layers_.push_back(LayerRegistry::CreateLayer(layer_param)); layer_names_.push_back(layer_param.name()); - LOG(INFO) << "Creating Layer " << layer_param.name(); + if (Caffe::root_solver()) { + LOG(INFO) << "Creating Layer " << layer_param.name(); + } bool need_backward = false; // Figure out this layer's input and output @@ -117,20 +127,30 @@ void Net::Init(const NetParameter& in_param) { } } // After this layer is connected, set it up. - LOG(INFO) << "Setting up " << layer_names_[layer_id]; + if (Caffe::root_solver()) { + LOG(INFO) << "Setting up " << layer_names_[layer_id]; + } layers_[layer_id]->SetUp(bottom_vecs_[layer_id], top_vecs_[layer_id]); for (int top_id = 0; top_id < top_vecs_[layer_id].size(); ++top_id) { if (blob_loss_weights_.size() <= top_id_vecs_[layer_id][top_id]) { blob_loss_weights_.resize(top_id_vecs_[layer_id][top_id] + 1, Dtype(0)); } blob_loss_weights_[top_id_vecs_[layer_id][top_id]] = layer->loss(top_id); - LOG(INFO) << "Top shape: " << top_vecs_[layer_id][top_id]->shape_string(); + if (Caffe::root_solver()) { + LOG(INFO) << "Top shape: " + << top_vecs_[layer_id][top_id]->shape_string(); + } if (layer->loss(top_id)) { - LOG(INFO) << " with loss weight " << layer->loss(top_id); + if (Caffe::root_solver()) { + LOG(INFO) << " with loss weight " << layer->loss(top_id); + } } memory_used_ += top_vecs_[layer_id][top_id]->count(); } - DLOG(INFO) << "Memory required for data: " << memory_used_ * sizeof(Dtype); + if (Caffe::root_solver()) { + DLOG(INFO) << "Memory required for data: " + << memory_used_ * sizeof(Dtype); + } const int param_size = layer_param.param_size(); const int num_param_blobs = layers_[layer_id]->blobs().size(); CHECK_LE(param_size, num_param_blobs) @@ -189,10 +209,14 @@ void Net::Init(const NetParameter& in_param) { } if (!layer_contributes_loss) { layer_need_backward_[layer_id] = false; } if (layer_need_backward_[layer_id]) { - LOG(INFO) << layer_names_[layer_id] << " needs backward computation."; + if (Caffe::root_solver()) { + LOG(INFO) << layer_names_[layer_id] << " needs backward computation."; + } } else { - LOG(INFO) << layer_names_[layer_id] - << " does not need backward computation."; + if (Caffe::root_solver()) { + LOG(INFO) << layer_names_[layer_id] + << " does not need backward computation."; + } } for (int bottom_id = 0; bottom_id < bottom_vecs_[layer_id].size(); ++bottom_id) { @@ -232,7 +256,9 @@ void Net::Init(const NetParameter& in_param) { // In the end, all remaining blobs are considered output blobs. for (set::iterator it = available_blobs.begin(); it != available_blobs.end(); ++it) { - LOG(INFO) << "This network produces output " << *it; + if (Caffe::root_solver()) { + LOG(INFO) << "This network produces output " << *it; + } net_output_blobs_.push_back(blobs_[blob_name_to_idx[*it]].get()); net_output_blob_indices_.push_back(blob_name_to_idx[*it]); } @@ -244,8 +270,10 @@ void Net::Init(const NetParameter& in_param) { } GetLearningRateAndWeightDecay(); debug_info_ = param.debug_info(); - LOG(INFO) << "Network initialization done."; - LOG(INFO) << "Memory required for data: " << memory_used_ * sizeof(Dtype); + if (Caffe::root_solver()) { + LOG(INFO) << "Network initialization done."; + LOG(INFO) << "Memory required for data: " << memory_used_ * sizeof(Dtype); + } } template @@ -284,27 +312,33 @@ bool Net::StateMeetsRule(const NetState& state, // Check whether the rule is broken due to phase. if (rule.has_phase()) { if (rule.phase() != state.phase()) { - LOG(INFO) << "The NetState phase (" << state.phase() - << ") differed from the phase (" << rule.phase() - << ") specified by a rule in layer " << layer_name; + if (Caffe::root_solver()) { + LOG(INFO) << "The NetState phase (" << state.phase() + << ") differed from the phase (" << rule.phase() + << ") specified by a rule in layer " << layer_name; + } return false; } } // Check whether the rule is broken due to min level. if (rule.has_min_level()) { if (state.level() < rule.min_level()) { - LOG(INFO) << "The NetState level (" << state.level() - << ") is above the min_level (" << rule.min_level() - << ") specified by a rule in layer " << layer_name; + if (Caffe::root_solver()) { + LOG(INFO) << "The NetState level (" << state.level() + << ") is above the min_level (" << rule.min_level() + << ") specified by a rule in layer " << layer_name; + } return false; } } // Check whether the rule is broken due to max level. if (rule.has_max_level()) { if (state.level() > rule.max_level()) { - LOG(INFO) << "The NetState level (" << state.level() - << ") is above the max_level (" << rule.max_level() - << ") specified by a rule in layer " << layer_name; + if (Caffe::root_solver()) { + LOG(INFO) << "The NetState level (" << state.level() + << ") is above the max_level (" << rule.max_level() + << ") specified by a rule in layer " << layer_name; + } return false; } } @@ -317,8 +351,10 @@ bool Net::StateMeetsRule(const NetState& state, if (rule.stage(i) == state.stage(j)) { has_stage = true; } } if (!has_stage) { - LOG(INFO) << "The NetState did not contain stage '" << rule.stage(i) - << "' specified by a rule in layer " << layer_name; + if (Caffe::root_solver()) { + LOG(INFO) << "The NetState did not contain stage '" << rule.stage(i) + << "' specified by a rule in layer " << layer_name; + } return false; } } @@ -331,8 +367,10 @@ bool Net::StateMeetsRule(const NetState& state, if (rule.not_stage(i) == state.stage(j)) { has_stage = true; } } if (has_stage) { - LOG(INFO) << "The NetState contained a not_stage '" << rule.not_stage(i) - << "' specified by a rule in layer " << layer_name; + if (Caffe::root_solver()) { + LOG(INFO) << "The NetState contained a not_stage '" << rule.not_stage(i) + << "' specified by a rule in layer " << layer_name; + } return false; } } @@ -354,7 +392,9 @@ void Net::AppendTop(const NetParameter& param, const int layer_id, if (blob_name_to_idx && layer_param && layer_param->bottom_size() > top_id && blob_name == layer_param->bottom(top_id)) { // In-place computation - LOG(INFO) << layer_param->name() << " -> " << blob_name << " (in-place)"; + if (Caffe::root_solver()) { + LOG(INFO) << layer_param->name() << " -> " << blob_name << " (in-place)"; + } top_vecs_[layer_id].push_back(blobs_[(*blob_name_to_idx)[blob_name]].get()); top_id_vecs_[layer_id].push_back((*blob_name_to_idx)[blob_name]); } else if (blob_name_to_idx && @@ -364,10 +404,12 @@ void Net::AppendTop(const NetParameter& param, const int layer_id, LOG(FATAL) << "Duplicate blobs produced by multiple sources."; } else { // Normal output. - if (layer_param) { - LOG(INFO) << layer_param->name() << " -> " << blob_name; - } else { - LOG(INFO) << "Input " << top_id << " -> " << blob_name; + if (Caffe::root_solver()) { + if (layer_param) { + LOG(INFO) << layer_param->name() << " -> " << blob_name; + } else { + LOG(INFO) << "Input " << top_id << " -> " << blob_name; + } } shared_ptr > blob_pointer(new Blob()); const int blob_id = blobs_.size(); @@ -407,7 +449,9 @@ int Net::AppendBottom(const NetParameter& param, const int layer_id, << " (at index " << bottom_id << ") to layer " << layer_id; } const int blob_id = (*blob_name_to_idx)[blob_name]; - LOG(INFO) << layer_names_[layer_id] << " <- " << blob_name; + if (Caffe::root_solver()) { + LOG(INFO) << layer_names_[layer_id] << " <- " << blob_name; + } bottom_vecs_[layer_id].push_back(blobs_[blob_id].get()); bottom_id_vecs_[layer_id].push_back(blob_id); available_blobs->erase(blob_name); @@ -456,9 +500,11 @@ void Net::AppendParam(const NetParameter& param, const int layer_id, param_layer_indices_[owner_net_param_id]; const int owner_layer_id = owner_index.first; const int owner_param_id = owner_index.second; - LOG(INFO) << "Sharing parameters '" << param_name << "' owned by " - << "layer '" << layer_names_[owner_layer_id] << "', param " - << "index " << owner_param_id; + if (Caffe::root_solver()) { + LOG(INFO) << "Sharing parameters '" << param_name << "' owned by " + << "layer '" << layer_names_[owner_layer_id] << "', param " + << "index " << owner_param_id; + } Blob* this_blob = layers_[layer_id]->blobs()[param_id].get(); Blob* owner_blob = layers_[owner_layer_id]->blobs()[owner_param_id].get(); @@ -479,7 +525,9 @@ void Net::AppendParam(const NetParameter& param, const int layer_id, template void Net::GetLearningRateAndWeightDecay() { - LOG(INFO) << "Collecting Learning Rate and Weight Decay."; + if (Caffe::root_solver()) { + LOG(INFO) << "Collecting Learning Rate and Weight Decay."; + } ParamSpec default_param_spec; for (int i = 0; i < layers_.size(); ++i) { vector > >& layer_blobs = layers_[i]->blobs(); @@ -581,8 +629,10 @@ void Net::InputDebugInfo(const int input_id) { const Blob& blob = *net_input_blobs_[input_id]; const string& blob_name = blob_names_[net_input_blob_indices_[input_id]]; const Dtype data_abs_val_mean = blob.asum_data() / blob.count(); - LOG(INFO) << " [Forward] " - << "Input " << blob_name << " data: " << data_abs_val_mean; + if (Caffe::root_solver()) { + LOG(INFO) << " [Forward] " + << "Input " << blob_name << " data: " << data_abs_val_mean; + } } template @@ -591,9 +641,12 @@ void Net::ForwardDebugInfo(const int layer_id) { const Blob& blob = *top_vecs_[layer_id][top_id]; const string& blob_name = blob_names_[top_id_vecs_[layer_id][top_id]]; const Dtype data_abs_val_mean = blob.asum_data() / blob.count(); - LOG(INFO) << " [Forward] " - << "Layer " << layer_names_[layer_id] << ", top blob " << blob_name - << " data: " << data_abs_val_mean; + if (Caffe::root_solver()) { + LOG(INFO) << " [Forward] " + << "Layer " << layer_names_[layer_id] + << ", top blob " << blob_name + << " data: " << data_abs_val_mean; + } } for (int param_id = 0; param_id < layers_[layer_id]->blobs().size(); ++param_id) { @@ -601,9 +654,12 @@ void Net::ForwardDebugInfo(const int layer_id) { const int net_param_id = param_id_vecs_[layer_id][param_id]; const string& blob_name = param_display_names_[net_param_id]; const Dtype data_abs_val_mean = blob.asum_data() / blob.count(); - LOG(INFO) << " [Forward] " - << "Layer " << layer_names_[layer_id] << ", param blob " << blob_name - << " data: " << data_abs_val_mean; + if (Caffe::root_solver()) { + LOG(INFO) << " [Forward] " + << "Layer " << layer_names_[layer_id] + << ", param blob " << blob_name + << " data: " << data_abs_val_mean; + } } } @@ -615,18 +671,24 @@ void Net::BackwardDebugInfo(const int layer_id) { const Blob& blob = *bottom_vec[bottom_id]; const string& blob_name = blob_names_[bottom_id_vecs_[layer_id][bottom_id]]; const Dtype diff_abs_val_mean = blob.asum_diff() / blob.count(); - LOG(INFO) << " [Backward] " - << "Layer " << layer_names_[layer_id] << ", bottom blob " << blob_name - << " diff: " << diff_abs_val_mean; + if (Caffe::root_solver()) { + LOG(INFO) << " [Backward] " + << "Layer " << layer_names_[layer_id] + << ", bottom blob " << blob_name + << " diff: " << diff_abs_val_mean; + } } for (int param_id = 0; param_id < layers_[layer_id]->blobs().size(); ++param_id) { if (!layers_[layer_id]->param_propagate_down(param_id)) { continue; } const Blob& blob = *layers_[layer_id]->blobs()[param_id]; const Dtype diff_abs_val_mean = blob.asum_diff() / blob.count(); - LOG(INFO) << " [Backward] " - << "Layer " << layer_names_[layer_id] << ", param blob " << param_id - << " diff: " << diff_abs_val_mean; + if (Caffe::root_solver()) { + LOG(INFO) << " [Backward] " + << "Layer " << layer_names_[layer_id] + << ", param blob " << param_id + << " diff: " << diff_abs_val_mean; + } } } @@ -639,17 +701,22 @@ void Net::UpdateDebugInfo(const int param_id) { const Dtype diff_abs_val_mean = blob.asum_diff() / blob.count(); if (param_owner < 0) { const Dtype data_abs_val_mean = blob.asum_data() / blob.count(); - LOG(INFO) << " [Update] Layer " << layer_name - << ", param " << param_display_name - << " data: " << data_abs_val_mean << "; diff: " << diff_abs_val_mean; + if (Caffe::root_solver()) { + LOG(INFO) << " [Update] Layer " << layer_name + << ", param " << param_display_name + << " data: " << data_abs_val_mean + << "; diff: " << diff_abs_val_mean; + } } else { const string& owner_layer_name = layer_names_[param_layer_indices_[param_owner].first]; - LOG(INFO) << " [Update] Layer " << layer_name - << ", param blob " << param_display_name - << " (owned by layer " << owner_layer_name << ", " - << "param " << param_display_names_[param_owners_[param_id]] << ")" - << " diff: " << diff_abs_val_mean; + if (Caffe::root_solver()) { + LOG(INFO) << " [Update] Layer " << layer_name + << ", param blob " << param_display_name + << " (owned by layer " << owner_layer_name << ", " << "param " + << param_display_names_[param_owners_[param_id]] << ")" + << " diff: " << diff_abs_val_mean; + } } } @@ -706,8 +773,8 @@ void Net::Backward() { const Dtype l2norm_data = std::sqrt(sumsq_data); const Dtype l2norm_diff = std::sqrt(sumsq_diff); LOG(ERROR) << " [Backward] All net params (data, diff): " - << "L1 norm = (" << asum_data << ", " << asum_diff << "); " - << "L2 norm = (" << l2norm_data << ", " << l2norm_diff << ")"; + << "L1 norm = (" << asum_data << ", " << asum_diff << "); " + << "L2 norm = (" << l2norm_data << ", " << l2norm_diff << ")"; } } diff --git a/src/caffe/parallel.cpp b/src/caffe/parallel.cpp new file mode 100644 index 00000000000..a2c1ca32dfb --- /dev/null +++ b/src/caffe/parallel.cpp @@ -0,0 +1,510 @@ +#ifndef CPU_ONLY +#include +#endif +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "boost/thread.hpp" +#include "caffe/caffe.hpp" +#include "caffe/parallel.hpp" + +namespace caffe { + +enum Op { + copy, + replace_cpu, + replace_gpu, + replace_cpu_diff, + replace_gpu_diff +}; + +template +static void apply_buffers(const vector > >& blobs, + Dtype* buffer, size_t total_size, Op op) { + Dtype* ptr = buffer; + for (int i = 0; i < blobs.size(); ++i) { + int size = blobs[i]->count(); + switch (op) { + case copy: { + // Init buffer to current values of blobs + caffe_copy(size, + reinterpret_cast(blobs[i]->data()->cpu_data()), + ptr); + break; + } + case replace_cpu: + blobs[i]->data()->set_cpu_data(ptr); + break; + case replace_gpu: + blobs[i]->data()->set_gpu_data(ptr); + break; + case replace_cpu_diff: + blobs[i]->diff()->set_cpu_data(ptr); + break; + case replace_gpu_diff: + blobs[i]->diff()->set_gpu_data(ptr); + break; + } + ptr += size; + } + CHECK_EQ(total_size, ptr - buffer); +} + +// Buffer size necessary to store given blobs +template +static size_t total_size(const vector > >& params) { + size_t size = 0; + for (int i = 0; i < params.size(); ++i) + size += params[i]->count(); + return size; +} + +template +Params::Params(shared_ptr > root_solver) + : size_(total_size(root_solver->net()->params())), + data_(), + diff_() { +} + +template +GPUParams::GPUParams(shared_ptr > root_solver, int device) + : Params(root_solver) { +#ifndef CPU_ONLY + int initial_device; + CUDA_CHECK(cudaGetDevice(&initial_device)); + + // Allocate device buffers + CUDA_CHECK(cudaSetDevice(device)); + CUDA_CHECK(cudaMalloc(&data_, size_ * sizeof(Dtype))); + + // Copy blob values + const vector > >& net = root_solver->net()->params(); + apply_buffers(net, data_, size_, copy); + + CUDA_CHECK(cudaMalloc(&diff_, size_ * sizeof(Dtype))); + caffe_gpu_set(size_, Dtype(0), diff_); + + CUDA_CHECK(cudaSetDevice(initial_device)); +#else + NO_GPU; +#endif +} + +template +GPUParams::~GPUParams() { +#ifndef CPU_ONLY + CUDA_CHECK(cudaFree(data_)); + CUDA_CHECK(cudaFree(diff_)); +#endif +} + +template +void GPUParams::configure(Solver* solver) const { + const vector > >& net = solver->net()->params(); + apply_buffers(net, data_, size_, replace_gpu); + apply_buffers(net, diff_, size_, replace_gpu_diff); +} + +// + +void DevicePair::compute(const vector devices, vector* pairs) { +#ifndef CPU_ONLY + vector remaining(devices); + + // Group GPUs by board - some boards can have more than 2 ASICs + for (int d = 0; d < remaining.size(); ++d) { + for (int i = 0; i < remaining.size(); ++i) { + for (int j = i + 1; j < remaining.size(); ++j) { + cudaDeviceProp a, b; + CUDA_CHECK(cudaGetDeviceProperties(&a, remaining[i])); + CUDA_CHECK(cudaGetDeviceProperties(&b, remaining[j])); + if (a.isMultiGpuBoard && b.isMultiGpuBoard) { + if (a.multiGpuBoardGroupID == b.multiGpuBoardGroupID) { + pairs->push_back(DevicePair(remaining[i], remaining[j])); + DLOG(INFO) << "GPU board: " << remaining[i] + << ":" << remaining[j]; + remaining.erase(remaining.begin() + j); + break; + } + } + } + } + } + ostringstream s; + for (int i = 0; i < remaining.size(); ++i) { + s << (i ? ", " : "") << remaining[i]; + } + DLOG(INFO) << "GPUs paired by boards, remaining: " << s.str(); + + // Group by P2P accessibility - P2P group can be larger than 4 boards + for (int d = 0; d < remaining.size(); ++d) { + for (int i = 0; i < remaining.size(); ++i) { + for (int j = i + 1; j < remaining.size(); ++j) { + int access; + CUDA_CHECK(cudaDeviceCanAccessPeer(&access, + remaining[i], + remaining[j])); + if (access) { + pairs->push_back(DevicePair(remaining[i], remaining[j])); + DLOG(INFO) << "P2P pair: " << remaining[i] + << ":" << remaining[j]; + remaining.erase(remaining.begin() + j); + break; + } + } + } + } + s.str(""); + for (int i = 0; i < remaining.size(); ++i) { + s << (i ? ", " : "") << remaining[i]; + } + DLOG(INFO) << "GPUs paired by P2P access, remaining: " << s.str(); + + // Group remaining + for (int d = 0; d < remaining.size(); ++d) { // try to pair everyone + for (int i = 0; i < remaining.size(); ++i) { + for (int j = i + 1; j < remaining.size(); ++j) { + pairs->push_back(DevicePair(remaining[i], remaining[j])); + DLOG(INFO) << "Remaining pair: " << remaining[i] + << ":" << remaining[j]; + remaining.erase(remaining.begin() + j); + break; + } + } + } + CHECK_EQ(remaining.size(), 1); + pairs->insert(pairs->begin(), DevicePair(-1, remaining[0])); + + CHECK(pairs->size() == devices.size()); + for (int i = 0; i < pairs->size(); ++i) { + CHECK((*pairs)[i].parent() != (*pairs)[i].device()); + for (int j = i + 1; j < pairs->size(); ++j) { + CHECK((*pairs)[i].device() != (*pairs)[j].device()); + } + } +#else + NO_GPU; +#endif +} + +// + +template +P2PSync::P2PSync(shared_ptr > root_solver, + P2PSync* parent, const SolverParameter& param) + : GPUParams(root_solver, param.device_id()), + parent_(parent), + children_(), + queue_(), + initial_iter_(root_solver->iter()), + solver_() { +#ifndef CPU_ONLY + int initial_device; + CUDA_CHECK(cudaGetDevice(&initial_device)); + const int self = param.device_id(); + CUDA_CHECK(cudaSetDevice(self)); + + if (parent == NULL) { + solver_ = root_solver; + } else { + Caffe::set_root_solver(false); + solver_.reset(new Solver(param)); + Caffe::set_root_solver(true); + } + this->configure(solver_.get()); + solver_->add_callback(this); + + if (parent) { + // Enable p2p access between devices + const int peer = parent->solver_->param().device_id(); + int access; + CUDA_CHECK(cudaDeviceCanAccessPeer(&access, self, peer)); + if (access) { + CUDA_CHECK(cudaDeviceEnablePeerAccess(peer, 0)); + } else { + LOG(INFO)<< "GPU " << self << " does not have p2p access to GPU " << peer; + } + // Allocate receiving buffer on parent + CUDA_CHECK(cudaSetDevice(peer)); + CUDA_CHECK(cudaMalloc(&parent_grads_, size_ * sizeof(Dtype))); + CUDA_CHECK(cudaSetDevice(self)); + } + + CUDA_CHECK(cudaSetDevice(initial_device)); +#else + NO_GPU; +#endif +} + +template +P2PSync::~P2PSync() { +#ifndef CPU_ONLY + int initial_device; + CUDA_CHECK(cudaGetDevice(&initial_device)); + const int self = solver_->param().device_id(); + CUDA_CHECK(cudaSetDevice(self)); + + if (parent_) { + CUDA_CHECK(cudaFree(parent_grads_)); + const int peer = parent_->solver_->param().device_id(); + int access; + CUDA_CHECK(cudaDeviceCanAccessPeer(&access, self, peer)); + if (access) { + CUDA_CHECK(cudaDeviceDisablePeerAccess(peer)); + } + } + + CUDA_CHECK(cudaSetDevice(initial_device)); +#endif +} + +template +void P2PSync::InternalThreadEntry() { + Caffe::SetDevice(solver_->param().device_id()); + CHECK(Caffe::root_solver()); + Caffe::set_root_solver(false); + // See if there is a defined seed and reset random state if so + if (solver_->param().random_seed() >= 0) { + // Fetch random seed and modulate by device ID to make sure + // everyone doesn't have the same seed. We seem to have some + // solver instability if we have everyone with the same seed + Caffe::set_random_seed( + solver_->param().random_seed() + solver_->param().device_id()); + } + solver_->Step(solver_->param().max_iter() - initial_iter_); +} + +template +void P2PSync::on_start(Timer* timer, ostringstream* timing) { +#ifndef CPU_ONLY +#ifdef DEBUG + int device; + CUDA_CHECK(cudaGetDevice(&device)); + CHECK(device == solver_->param().device_id()); +#else +// CHECK(false); +#endif + + // Wait for update from parent + if (parent_) { + timer->Start(); + P2PSync *parent = queue_.pop(); + CHECK(parent == parent_); + *timing << " recv_param: " << timer->MilliSeconds(); + } + + // Update children + if (children_.size()) { + timer->Start(); + } + for (int i = 0; i < children_.size(); ++i) { + Dtype* src = data_; + Dtype* dst = children_[i]->data_; + +#ifdef DEBUG + cudaPointerAttributes attributes; + CUDA_CHECK(cudaPointerGetAttributes(&attributes, src)); + CHECK(attributes.device == device); + CUDA_CHECK(cudaPointerGetAttributes(&attributes, dst)); + CHECK(attributes.device == children_[i]->solver_->param().device_id()); +#endif + + CUDA_CHECK(cudaMemcpyAsync(dst, src, size_ * sizeof(Dtype), // + cudaMemcpyDeviceToDevice, cudaStreamDefault)); + } + if (children_.size()) { + CUDA_CHECK(cudaStreamSynchronize(cudaStreamDefault)); + } + for (int i = 0; i < children_.size(); ++i) { + children_[i]->queue_.push(this); + } + if (children_.size()) { + *timing << " send_param: " << timer->MilliSeconds(); + } +#endif +} + +template +void P2PSync::on_gradients_ready(Timer* timer, ostringstream* timing) { +#ifndef CPU_ONLY +#ifdef DEBUG + int device; + CUDA_CHECK(cudaGetDevice(&device)); + CHECK(device == solver_->param().device_id()); +#endif + + // Sum children gradients as they appear in the queue + for (int i = 0; i < children_.size(); ++i) { + timer->Start(); + P2PSync *child = queue_.pop(); + Dtype* src = child->parent_grads_; + Dtype* dst = diff_; + +#ifdef DEBUG + bool ok = false; + for (int j = 0; j < children_.size(); ++j) { + if (child == children_[j]) { + ok = true; + } + } + CHECK(ok); + cudaPointerAttributes attributes; + CUDA_CHECK(cudaPointerGetAttributes(&attributes, src)); + CHECK(attributes.device == device); + CUDA_CHECK(cudaPointerGetAttributes(&attributes, dst)); + CHECK(attributes.device == device); +#endif + + caffe_gpu_add(size_, src, dst, dst); + *timing << " add_grad: " << timer->MilliSeconds(); + } + + // Send gradients to parent + if (parent_) { + timer->Start(); + Dtype* src = diff_; + Dtype* dst = parent_grads_; + +#ifdef DEBUG + cudaPointerAttributes attributes; + CUDA_CHECK(cudaPointerGetAttributes(&attributes, src)); + CHECK(attributes.device == device); + CUDA_CHECK(cudaPointerGetAttributes(&attributes, dst)); + CHECK(attributes.device == parent_->solver_->param().device_id()); +#endif + + CUDA_CHECK(cudaMemcpyAsync(dst, src, size_ * sizeof(Dtype), // + cudaMemcpyDeviceToDevice, cudaStreamDefault)); + CUDA_CHECK(cudaStreamSynchronize(cudaStreamDefault)); + parent_->queue_.push(this); + *timing << " send_grad: " << timer->MilliSeconds(); + } else { + // Loss functions divide gradients by the batch size, so to compensate + // for split batch, the root solver divides by number of solvers. + caffe_gpu_scal(size_, Dtype(1.0 / Caffe::solver_count()), diff_); + } +#endif +} + +template +void P2PSync::run(shared_ptr > root, + const vector& gpus) { + // Pair devices for map-reduce synchronization + vector pairs; + DevicePair::compute(gpus, &pairs); + ostringstream s; + for (int i = 1; i < pairs.size(); ++i) { + s << (i == 1 ? "" : ", ") << pairs[i].parent() << ":" << pairs[i].device(); + } + LOG(INFO)<< "GPUs pairs " << s.str(); + + SolverParameter param(root->param()); + vector > > syncs(gpus.size()); + syncs[0].reset(new P2PSync(root, NULL, param)); + + // Build the GPU tree by finding the parent for each solver + for (int attempts = 0; attempts < pairs.size(); ++attempts) { + for (int i = 1; i < pairs.size(); ++i) { + if (!syncs[i].get()) { + P2PSync* parent = NULL; + for (int j = 0; j < syncs.size(); ++j) { + if (syncs[j]) { + const SolverParameter& p = syncs[j]->solver()->param(); + if (p.device_id() == pairs[i].parent()) { + parent = (P2PSync*) syncs[j].get(); + } + } + } + if (parent) { + param.set_device_id(pairs[i].device()); + syncs[i].reset(new P2PSync(root, parent, param)); + parent->children_.push_back((P2PSync*) syncs[i].get()); + } + } + } + } + + LOG(INFO)<< "Starting Optimization"; + + for (int i = 1; i < syncs.size(); ++i) { + syncs[i]->StartInternalThread(); + } + + // Run root solver on current thread + syncs[0]->solver_->Solve(); + + for (int i = 1; i < syncs.size(); ++i) { + syncs[i]->StopInternalThread(); + } +} + +template +void P2PSync::divide_batch_size(NetParameter* net) { + int solver_count = Caffe::solver_count(); + for (int i = 0; i < net->layer_size(); ++i) { + string m = "Batch size must be divisible by the number of solvers (GPUs)"; + if (net->layer(i).has_data_param()) { + if (net->layer(i).data_param().has_batch_size()) { + uint32_t total = net->layer(i).data_param().batch_size(); + uint32_t batch = total / solver_count; + CHECK(batch * solver_count == total) << m; + net->mutable_layer(i)->mutable_data_param()->set_batch_size(batch); + + // Also adjust the prefetch count, as it is shared by all solvers + uint32_t prefetch = net->layer(i).data_param().prefetch(); + net->mutable_layer(i)->mutable_data_param()->set_prefetch( + prefetch * solver_count); + } + } + if (net->layer(i).has_hdf5_data_param()) { + if (net->layer(i).hdf5_data_param().has_batch_size()) { + uint32_t total = net->layer(i).hdf5_data_param().batch_size(); + uint32_t batch = total / solver_count; + CHECK(batch * solver_count == total) << m; + net->mutable_layer(i)->mutable_hdf5_data_param()->set_batch_size(batch); + } + } + if (net->layer(i).has_image_data_param()) { + if (net->layer(i).image_data_param().has_batch_size()) { + uint32_t total = net->layer(i).image_data_param().batch_size(); + uint32_t batch = total / solver_count; + CHECK(batch * solver_count == total) << m; + net->mutable_layer(i)->mutable_image_data_param()->set_batch_size( + batch); + } + } + if (net->layer(i).has_memory_data_param()) { + if (net->layer(i).memory_data_param().has_batch_size()) { + uint32_t total = net->layer(i).memory_data_param().batch_size(); + uint32_t batch = total / solver_count; + CHECK(batch * solver_count == total) << m; + net->mutable_layer(i)->mutable_memory_data_param()->set_batch_size( + batch); + } + } + if (net->layer(i).has_window_data_param()) { + if (net->layer(i).window_data_param().has_batch_size()) { + uint32_t total = net->layer(i).window_data_param().batch_size(); + uint32_t batch = total / solver_count; + CHECK(batch * solver_count == total) << m; + net->mutable_layer(i)->mutable_window_data_param()->set_batch_size( + batch); + } + } + } +} + +INSTANTIATE_CLASS(Params); +INSTANTIATE_CLASS(GPUParams); +INSTANTIATE_CLASS(P2PSync); + +} // namespace caffe + diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index 307015f42c9..ec5c6947095 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -455,6 +455,7 @@ message DataParameter { // to avoid all asynchronous sgd clients to start at the same point. The skip // point would be set as rand_skip * rand(0,1). Note that rand_skip should not // be larger than the number of keys in the database. + // DEPRECATED. Each solver accesses a different subset of the database. optional uint32 rand_skip = 7 [default = 0]; optional DB backend = 8 [default = LEVELDB]; // DEPRECATED. See TransformationParameter. For data pre-processing, we can do @@ -470,6 +471,9 @@ message DataParameter { optional bool mirror = 6 [default = false]; // Force the encoded image to have 3 color channels optional bool force_encoded_color = 9 [default = false]; + // Prefetch queue (Number of batches to prefetch to host memory, increase if + // data access bandwidth varies). + optional uint32 prefetch = 10 [default = 4]; } message DropoutParameter { diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index 877b19b86f8..926490365ff 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -15,13 +15,13 @@ namespace caffe { template Solver::Solver(const SolverParameter& param) - : net_() { + : net_(), callbacks_(), iteration_timer_(), iterations_last_() { Init(param); } template Solver::Solver(const string& param_file) - : net_() { + : net_(), callbacks_(), iteration_timer_(), iterations_last_() { SolverParameter param; ReadProtoFromTextFileOrDie(param_file, ¶m); Init(param); @@ -29,17 +29,21 @@ Solver::Solver(const string& param_file) template void Solver::Init(const SolverParameter& param) { - LOG(INFO) << "Initializing solver from parameters: " << std::endl - << param.DebugString(); + if (Caffe::root_solver()) { + LOG(INFO) << "Initializing solver from parameters: " << std::endl + << param.DebugString(); + } param_ = param; CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative."; - if (param_.random_seed() >= 0) { + if (Caffe::root_solver() && param_.random_seed() >= 0) { Caffe::set_random_seed(param_.random_seed()); } // Scaffolding code InitTrainNet(); - InitTestNets(); - LOG(INFO) << "Solver scaffolding done."; + if (Caffe::root_solver()) { + InitTestNets(); + LOG(INFO) << "Solver scaffolding done."; + } iter_ = 0; current_step_ = 0; } @@ -55,19 +59,27 @@ void Solver::InitTrainNet() { << "one of these fields specifying a train_net: " << field_names; NetParameter net_param; if (param_.has_train_net_param()) { - LOG(INFO) << "Creating training net specified in train_net_param."; + if (Caffe::root_solver()) { + LOG(INFO) << "Creating training net specified in train_net_param."; + } net_param.CopyFrom(param_.train_net_param()); } else if (param_.has_train_net()) { - LOG(INFO) << "Creating training net from train_net file: " - << param_.train_net(); + if (Caffe::root_solver()) { + LOG(INFO) << "Creating training net from train_net file: " + << param_.train_net(); + } ReadNetParamsFromTextFileOrDie(param_.train_net(), &net_param); } if (param_.has_net_param()) { - LOG(INFO) << "Creating training net specified in net_param."; + if (Caffe::root_solver()) { + LOG(INFO) << "Creating training net specified in net_param."; + } net_param.CopyFrom(param_.net_param()); } if (param_.has_net()) { - LOG(INFO) << "Creating training net from net file: " << param_.net(); + if (Caffe::root_solver()) { + LOG(INFO) << "Creating training net from net file: " << param_.net(); + } ReadNetParamsFromTextFileOrDie(param_.net(), &net_param); } // Set the correct NetState. We start with the solver defaults (lowest @@ -84,6 +96,7 @@ void Solver::InitTrainNet() { template void Solver::InitTestNets() { + CHECK(Caffe::root_solver()); const bool has_net_param = param_.has_net_param(); const bool has_net_file = param_.has_net(); const int num_generic_nets = has_net_param + has_net_file; @@ -167,12 +180,26 @@ void Solver::Step(int iters) { vector losses; Dtype smoothed_loss = 0; + iteration_timer_.Start(); + Timer timer; + ostringstream timing; + while (iter_ < stop_iter) { if (param_.test_interval() && iter_ % param_.test_interval() == 0 - && (iter_ > 0 || param_.test_initialization())) { + && (iter_ > 0 || param_.test_initialization()) + && Caffe::root_solver()) { TestAll(); } + timer.Start(); + timing.str(""); + timing << "Timing "; + if (param().solver_mode() == SolverParameter_SolverMode_GPU) { + timing << "(device " << param().device_id() << ") "; + } + for (int i = 0; i < callbacks_.size(); ++i) { + callbacks_[i]->on_start(&timer, &timing); + } const bool display = param_.display() && iter_ % param_.display() == 0; net_->set_debug_info(display && param_.debug_info()); Dtype loss = net_->ForwardBackward(bottom_vec); @@ -186,7 +213,9 @@ void Solver::Step(int iters) { losses[idx] = loss; } if (display) { - LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss; + if (Caffe::root_solver()) { + LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss; + } const vector*>& result = net_->output_blobs(); int score_index = 0; for (int j = 0; j < result.size(); ++j) { @@ -201,21 +230,34 @@ void Solver::Step(int iters) { loss_msg_stream << " (* " << loss_weight << " = " << loss_weight * result_vec[k] << " loss)"; } - LOG(INFO) << " Train net output #" - << score_index++ << ": " << output_name << " = " - << result_vec[k] << loss_msg_stream.str(); + if (Caffe::root_solver()) { + LOG(INFO) << " Train net output #" + << score_index++ << ": " << output_name << " = " + << result_vec[k] << loss_msg_stream.str(); + } } } } - ComputeUpdateValue(); - net_->Update(); + timing << " grads: " << timer.MilliSeconds(); + for (int i = 0; i < callbacks_.size(); ++i) { + callbacks_[i]->on_gradients_ready(&timer, &timing); + } + timer.Start(); + Iteration(); + timing << " apply: " << timer.MilliSeconds(); + +#ifdef BENCHMARK_SOLVER + LOG(INFO)<< timing.str(); +#endif // Increment the internal iter_ counter -- its value should always indicate // the number of times the weights have been updated. ++iter_; // Save a snapshot if needed. - if (param_.snapshot() && iter_ % param_.snapshot() == 0) { + if (param_.snapshot() + && iter_ % param_.snapshot() == 0 + && Caffe::root_solver()) { Snapshot(); } } @@ -223,6 +265,7 @@ void Solver::Step(int iters) { template void Solver::Solve(const char* resume_file) { + CHECK(Caffe::root_solver()); LOG(INFO) << "Solving " << net_->name(); LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy(); @@ -267,6 +310,7 @@ void Solver::TestAll() { template void Solver::Test(const int test_net_id) { + CHECK(Caffe::root_solver()); LOG(INFO) << "Iteration " << iter_ << ", Testing net (#" << test_net_id << ")"; CHECK_NOTNULL(test_nets_[test_net_id].get())-> @@ -317,13 +361,14 @@ void Solver::Test(const int test_net_id) { << " = " << loss_weight * mean_score << " loss)"; } LOG(INFO) << " Test net output #" << i << ": " << output_name << " = " - << mean_score << loss_msg_stream.str(); + << mean_score << loss_msg_stream.str(); } } template void Solver::Snapshot() { + CHECK(Caffe::root_solver()); NetParameter net_param; // For intermediate results, we will also dump the gradient values. net_->ToProto(&net_param, param_.snapshot_diff()); @@ -348,6 +393,7 @@ void Solver::Snapshot() { template void Solver::Restore(const char* state_file) { + CHECK(Caffe::root_solver()); SolverState state; NetParameter net_param; ReadProtoFromBinaryFile(state_file, &state); @@ -456,95 +502,124 @@ void SGDSolver::ClipGradients() { } template -void SGDSolver::ComputeUpdateValue() { - const vector > >& net_params = this->net_->params(); - const vector& net_params_lr = this->net_->params_lr(); - const vector& net_params_weight_decay = - this->net_->params_weight_decay(); - // get the learning rate +void SGDSolver::Iteration() { + CHECK(Caffe::root_solver()); Dtype rate = GetLearningRate(); if (this->param_.display() && this->iter_ % this->param_.display() == 0) { - LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate; + float lapse = iteration_timer_.Seconds(); + float per_s = (this->iter_ - iterations_last_) / (lapse ? lapse : 1); + LOG(INFO) << "Iteration " << this->iter_ << " (" << per_s << "/s), " + << "lr = " << rate; + iteration_timer_.Start(); + iterations_last_ = this->iter_; } ClipGradients(); - Dtype momentum = this->param_.momentum(); + for (int param_id = 0; param_id < this->net_->params().size(); ++param_id) { + Regularize(param_id); + ComputeUpdateValue(param_id, rate); + } + this->net_->Update(); +} + +template +void SGDSolver::Regularize(int param_id) { + const vector > >& net_params = this->net_->params(); + const vector& net_params_weight_decay = + this->net_->params_weight_decay(); Dtype weight_decay = this->param_.weight_decay(); string regularization_type = this->param_.regularization_type(); switch (Caffe::mode()) { - case Caffe::CPU: - for (int param_id = 0; param_id < net_params.size(); ++param_id) { - // Compute the value to history, and then copy them to the blob's diff. - Dtype local_rate = rate * net_params_lr[param_id]; - Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; - - if (local_decay) { - if (regularization_type == "L2") { - // add weight decay - caffe_axpy(net_params[param_id]->count(), - local_decay, - net_params[param_id]->cpu_data(), - net_params[param_id]->mutable_cpu_diff()); - } else if (regularization_type == "L1") { - caffe_cpu_sign(net_params[param_id]->count(), - net_params[param_id]->cpu_data(), - temp_[param_id]->mutable_cpu_data()); - caffe_axpy(net_params[param_id]->count(), - local_decay, - temp_[param_id]->cpu_data(), - net_params[param_id]->mutable_cpu_diff()); - } else { - LOG(FATAL) << "Unknown regularization type: " << regularization_type; - } + case Caffe::CPU: { + Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; + if (local_decay) { + if (regularization_type == "L2") { + // add weight decay + caffe_axpy(net_params[param_id]->count(), + local_decay, + net_params[param_id]->cpu_data(), + net_params[param_id]->mutable_cpu_diff()); + } else if (regularization_type == "L1") { + caffe_cpu_sign(net_params[param_id]->count(), + net_params[param_id]->cpu_data(), + temp_[param_id]->mutable_cpu_data()); + caffe_axpy(net_params[param_id]->count(), + local_decay, + temp_[param_id]->cpu_data(), + net_params[param_id]->mutable_cpu_diff()); + } else { + LOG(FATAL) << "Unknown regularization type: " << regularization_type; } - - caffe_cpu_axpby(net_params[param_id]->count(), local_rate, - net_params[param_id]->cpu_diff(), momentum, - history_[param_id]->mutable_cpu_data()); - // copy - caffe_copy(net_params[param_id]->count(), - history_[param_id]->cpu_data(), - net_params[param_id]->mutable_cpu_diff()); } break; - case Caffe::GPU: + } + case Caffe::GPU: { #ifndef CPU_ONLY - for (int param_id = 0; param_id < net_params.size(); ++param_id) { - // Compute the value to history, and then copy them to the blob's diff. - Dtype local_rate = rate * net_params_lr[param_id]; - Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; - - if (local_decay) { - if (regularization_type == "L2") { - // add weight decay - caffe_gpu_axpy(net_params[param_id]->count(), - local_decay, - net_params[param_id]->gpu_data(), - net_params[param_id]->mutable_gpu_diff()); - } else if (regularization_type == "L1") { - caffe_gpu_sign(net_params[param_id]->count(), - net_params[param_id]->gpu_data(), - temp_[param_id]->mutable_gpu_data()); - caffe_gpu_axpy(net_params[param_id]->count(), - local_decay, - temp_[param_id]->gpu_data(), - net_params[param_id]->mutable_gpu_diff()); - } else { - LOG(FATAL) << "Unknown regularization type: " << regularization_type; - } + Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; + if (local_decay) { + if (regularization_type == "L2") { + // add weight decay + caffe_gpu_axpy(net_params[param_id]->count(), + local_decay, + net_params[param_id]->gpu_data(), + net_params[param_id]->mutable_gpu_diff()); + } else if (regularization_type == "L1") { + caffe_gpu_sign(net_params[param_id]->count(), + net_params[param_id]->gpu_data(), + temp_[param_id]->mutable_gpu_data()); + caffe_gpu_axpy(net_params[param_id]->count(), + local_decay, + temp_[param_id]->gpu_data(), + net_params[param_id]->mutable_gpu_diff()); + } else { + LOG(FATAL) << "Unknown regularization type: " << regularization_type; } - - caffe_gpu_axpby(net_params[param_id]->count(), local_rate, - net_params[param_id]->gpu_diff(), momentum, - history_[param_id]->mutable_gpu_data()); - // copy - caffe_copy(net_params[param_id]->count(), - history_[param_id]->gpu_data(), - net_params[param_id]->mutable_gpu_diff()); } #else NO_GPU; #endif break; + } + default: + LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); + } +} + +template +void SGDSolver::ComputeUpdateValue(int param_id, Dtype rate) { + const vector > >& net_params = this->net_->params(); + const vector& net_params_lr = this->net_->params_lr(); + Dtype momentum = this->param_.momentum(); + switch (Caffe::mode()) { + case Caffe::CPU: { + // Compute the value to history, and then copy them to the blob's diff. + Dtype local_rate = rate * net_params_lr[param_id]; + + caffe_cpu_axpby(net_params[param_id]->count(), local_rate, + net_params[param_id]->cpu_diff(), momentum, + history_[param_id]->mutable_cpu_data()); + // copy + caffe_copy(net_params[param_id]->count(), + history_[param_id]->cpu_data(), + net_params[param_id]->mutable_cpu_diff()); + break; + } + case Caffe::GPU: { +#ifndef CPU_ONLY + // Compute the value to history, and then copy them to the blob's diff. + Dtype local_rate = rate * net_params_lr[param_id]; + + caffe_gpu_axpby(net_params[param_id]->count(), local_rate, + net_params[param_id]->gpu_diff(), momentum, + history_[param_id]->mutable_gpu_data()); + // copy + caffe_copy(net_params[param_id]->count(), + history_[param_id]->gpu_data(), + net_params[param_id]->mutable_gpu_diff()); +#else + NO_GPU; +#endif + break; + } default: LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); } @@ -562,6 +637,7 @@ void SGDSolver::SnapshotSolverState(SolverState* state) { template void SGDSolver::RestoreSolverState(const SolverState& state) { + CHECK(Caffe::root_solver()); CHECK_EQ(state.history_size(), history_.size()) << "Incorrect length of history blobs."; LOG(INFO) << "SGDSolver: restoring history"; @@ -571,252 +647,146 @@ void SGDSolver::RestoreSolverState(const SolverState& state) { } template -void NesterovSolver::ComputeUpdateValue() { +void NesterovSolver::ComputeUpdateValue(int param_id, Dtype rate) { + CHECK(Caffe::root_solver()); const vector > >& net_params = this->net_->params(); const vector& net_params_lr = this->net_->params_lr(); - const vector& net_params_weight_decay = - this->net_->params_weight_decay(); - // get the learning rate - Dtype rate = this->GetLearningRate(); - if (this->param_.display() && this->iter_ % this->param_.display() == 0) { - LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate; - } - SGDSolver::ClipGradients(); Dtype momentum = this->param_.momentum(); - Dtype weight_decay = this->param_.weight_decay(); - string regularization_type = this->param_.regularization_type(); switch (Caffe::mode()) { - case Caffe::CPU: - for (int param_id = 0; param_id < net_params.size(); ++param_id) { - // save history momentum for stepping back - caffe_copy(net_params[param_id]->count(), - this->history_[param_id]->cpu_data(), - this->update_[param_id]->mutable_cpu_data()); - - Dtype local_rate = rate * net_params_lr[param_id]; - Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; - - if (local_decay) { - if (regularization_type == "L2") { - // add weight decay - caffe_axpy(net_params[param_id]->count(), - local_decay, - net_params[param_id]->cpu_data(), - net_params[param_id]->mutable_cpu_diff()); - } else if (regularization_type == "L1") { - caffe_cpu_sign(net_params[param_id]->count(), - net_params[param_id]->cpu_data(), - this->temp_[param_id]->mutable_cpu_data()); - caffe_axpy(net_params[param_id]->count(), - local_decay, - this->temp_[param_id]->cpu_data(), - net_params[param_id]->mutable_cpu_diff()); - } else { - LOG(FATAL) << "Unknown regularization type: " << regularization_type; - } - } - - // update history - caffe_cpu_axpby(net_params[param_id]->count(), local_rate, - net_params[param_id]->cpu_diff(), momentum, - this->history_[param_id]->mutable_cpu_data()); - - // compute udpate: step back then over step - caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum, - this->history_[param_id]->cpu_data(), -momentum, - this->update_[param_id]->mutable_cpu_data()); - - // copy - caffe_copy(net_params[param_id]->count(), - this->update_[param_id]->cpu_data(), - net_params[param_id]->mutable_cpu_diff()); - } + case Caffe::CPU: { + // save history momentum for stepping back + caffe_copy(net_params[param_id]->count(), + this->history_[param_id]->cpu_data(), + this->update_[param_id]->mutable_cpu_data()); + + Dtype local_rate = rate * net_params_lr[param_id]; + + // update history + caffe_cpu_axpby(net_params[param_id]->count(), local_rate, + net_params[param_id]->cpu_diff(), momentum, + this->history_[param_id]->mutable_cpu_data()); + + // compute update: step back then over step + caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum, + this->history_[param_id]->cpu_data(), -momentum, + this->update_[param_id]->mutable_cpu_data()); + + // copy + caffe_copy(net_params[param_id]->count(), + this->update_[param_id]->cpu_data(), + net_params[param_id]->mutable_cpu_diff()); break; - case Caffe::GPU: + } + case Caffe::GPU: { #ifndef CPU_ONLY - for (int param_id = 0; param_id < net_params.size(); ++param_id) { - // save history momentum for stepping back - caffe_copy(net_params[param_id]->count(), - this->history_[param_id]->gpu_data(), - this->update_[param_id]->mutable_gpu_data()); - - Dtype local_rate = rate * net_params_lr[param_id]; - Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; - - if (local_decay) { - if (regularization_type == "L2") { - // add weight decay - caffe_gpu_axpy(net_params[param_id]->count(), - local_decay, - net_params[param_id]->gpu_data(), - net_params[param_id]->mutable_gpu_diff()); - } else if (regularization_type == "L1") { - caffe_gpu_sign(net_params[param_id]->count(), - net_params[param_id]->gpu_data(), - this->temp_[param_id]->mutable_gpu_data()); - caffe_gpu_axpy(net_params[param_id]->count(), - local_decay, - this->temp_[param_id]->gpu_data(), - net_params[param_id]->mutable_gpu_diff()); - } else { - LOG(FATAL) << "Unknown regularization type: " << regularization_type; - } - } - - // update history - caffe_gpu_axpby(net_params[param_id]->count(), local_rate, - net_params[param_id]->gpu_diff(), momentum, - this->history_[param_id]->mutable_gpu_data()); - - // compute udpate: step back then over step - caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum, - this->history_[param_id]->gpu_data(), -momentum, - this->update_[param_id]->mutable_gpu_data()); - - // copy - caffe_copy(net_params[param_id]->count(), - this->update_[param_id]->gpu_data(), - net_params[param_id]->mutable_gpu_diff()); - } + // save history momentum for stepping back + caffe_copy(net_params[param_id]->count(), + this->history_[param_id]->gpu_data(), + this->update_[param_id]->mutable_gpu_data()); + + Dtype local_rate = rate * net_params_lr[param_id]; + + // update history + caffe_gpu_axpby(net_params[param_id]->count(), local_rate, + net_params[param_id]->gpu_diff(), momentum, + this->history_[param_id]->mutable_gpu_data()); + + // compute update: step back then over step + caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum, + this->history_[param_id]->gpu_data(), -momentum, + this->update_[param_id]->mutable_gpu_data()); + + // copy + caffe_copy(net_params[param_id]->count(), + this->update_[param_id]->gpu_data(), + net_params[param_id]->mutable_gpu_diff()); #else NO_GPU; #endif break; + } default: LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); } } template -void AdaGradSolver::ComputeUpdateValue() { +void AdaGradSolver::ComputeUpdateValue(int param_id, Dtype rate) { + CHECK(Caffe::root_solver()); const vector > >& net_params = this->net_->params(); const vector& net_params_lr = this->net_->params_lr(); - const vector& net_params_weight_decay = - this->net_->params_weight_decay(); - // get the learning rate - Dtype rate = this->GetLearningRate(); Dtype delta = this->param_.delta(); - if (this->param_.display() && this->iter_ % this->param_.display() == 0) { - LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate; - } - SGDSolver::ClipGradients(); - Dtype weight_decay = this->param_.weight_decay(); - string regularization_type = this->param_.regularization_type(); switch (Caffe::mode()) { - case Caffe::CPU: - for (int param_id = 0; param_id < net_params.size(); ++param_id) { - Dtype local_rate = rate * net_params_lr[param_id]; - Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; - - if (local_decay) { - if (regularization_type == "L2") { - // add weight decay - caffe_axpy(net_params[param_id]->count(), - local_decay, - net_params[param_id]->cpu_data(), - net_params[param_id]->mutable_cpu_diff()); - } else if (regularization_type == "L1") { - caffe_cpu_sign(net_params[param_id]->count(), - net_params[param_id]->cpu_data(), - this->temp_[param_id]->mutable_cpu_data()); - caffe_axpy(net_params[param_id]->count(), - local_decay, - this->temp_[param_id]->cpu_data(), - net_params[param_id]->mutable_cpu_diff()); - } else { - LOG(FATAL) << "Unknown regularization type: " << regularization_type; - } - } - - // compute square of gradient in update - caffe_powx(net_params[param_id]->count(), - net_params[param_id]->cpu_diff(), Dtype(2), - this->update_[param_id]->mutable_cpu_data()); - - // update history - caffe_add(net_params[param_id]->count(), - this->update_[param_id]->cpu_data(), - this->history_[param_id]->cpu_data(), - this->history_[param_id]->mutable_cpu_data()); - - // prepare update - caffe_powx(net_params[param_id]->count(), - this->history_[param_id]->cpu_data(), Dtype(0.5), - this->update_[param_id]->mutable_cpu_data()); - - caffe_add_scalar(net_params[param_id]->count(), - delta, this->update_[param_id]->mutable_cpu_data()); - - caffe_div(net_params[param_id]->count(), - net_params[param_id]->cpu_diff(), - this->update_[param_id]->cpu_data(), - this->update_[param_id]->mutable_cpu_data()); - - // scale and copy - caffe_cpu_axpby(net_params[param_id]->count(), local_rate, - this->update_[param_id]->cpu_data(), Dtype(0), - net_params[param_id]->mutable_cpu_diff()); - } + case Caffe::CPU: { + Dtype local_rate = rate * net_params_lr[param_id]; + + // compute square of gradient in update + caffe_powx(net_params[param_id]->count(), + net_params[param_id]->cpu_diff(), Dtype(2), + this->update_[param_id]->mutable_cpu_data()); + + // update history + caffe_add(net_params[param_id]->count(), + this->update_[param_id]->cpu_data(), + this->history_[param_id]->cpu_data(), + this->history_[param_id]->mutable_cpu_data()); + + // prepare update + caffe_powx(net_params[param_id]->count(), + this->history_[param_id]->cpu_data(), Dtype(0.5), + this->update_[param_id]->mutable_cpu_data()); + + caffe_add_scalar(net_params[param_id]->count(), + delta, this->update_[param_id]->mutable_cpu_data()); + + caffe_div(net_params[param_id]->count(), + net_params[param_id]->cpu_diff(), + this->update_[param_id]->cpu_data(), + this->update_[param_id]->mutable_cpu_data()); + + // scale and copy + caffe_cpu_axpby(net_params[param_id]->count(), local_rate, + this->update_[param_id]->cpu_data(), Dtype(0), + net_params[param_id]->mutable_cpu_diff()); break; - case Caffe::GPU: + } + case Caffe::GPU: { #ifndef CPU_ONLY - for (int param_id = 0; param_id < net_params.size(); ++param_id) { - Dtype local_rate = rate * net_params_lr[param_id]; - Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; - - if (local_decay) { - if (regularization_type == "L2") { - // add weight decay - caffe_gpu_axpy(net_params[param_id]->count(), - local_decay, - net_params[param_id]->gpu_data(), - net_params[param_id]->mutable_gpu_diff()); - } else if (regularization_type == "L1") { - caffe_gpu_sign(net_params[param_id]->count(), - net_params[param_id]->gpu_data(), - this->temp_[param_id]->mutable_gpu_data()); - caffe_gpu_axpy(net_params[param_id]->count(), - local_decay, - this->temp_[param_id]->gpu_data(), - net_params[param_id]->mutable_gpu_diff()); - } else { - LOG(FATAL) << "Unknown regularization type: " << regularization_type; - } - } - - // compute square of gradient in update - caffe_gpu_powx(net_params[param_id]->count(), - net_params[param_id]->gpu_diff(), Dtype(2), - this->update_[param_id]->mutable_gpu_data()); - - // update history - caffe_gpu_add(net_params[param_id]->count(), - this->update_[param_id]->gpu_data(), - this->history_[param_id]->gpu_data(), - this->history_[param_id]->mutable_gpu_data()); - - // prepare update - caffe_gpu_powx(net_params[param_id]->count(), - this->history_[param_id]->gpu_data(), Dtype(0.5), - this->update_[param_id]->mutable_gpu_data()); - - caffe_gpu_add_scalar(net_params[param_id]->count(), - delta, this->update_[param_id]->mutable_gpu_data()); - - caffe_gpu_div(net_params[param_id]->count(), - net_params[param_id]->gpu_diff(), - this->update_[param_id]->gpu_data(), - this->update_[param_id]->mutable_gpu_data()); - - // scale and copy - caffe_gpu_axpby(net_params[param_id]->count(), local_rate, - this->update_[param_id]->gpu_data(), Dtype(0), - net_params[param_id]->mutable_gpu_diff()); - } + Dtype local_rate = rate * net_params_lr[param_id]; + + // compute square of gradient in update + caffe_gpu_powx(net_params[param_id]->count(), + net_params[param_id]->gpu_diff(), Dtype(2), + this->update_[param_id]->mutable_gpu_data()); + + // update history + caffe_gpu_add(net_params[param_id]->count(), + this->update_[param_id]->gpu_data(), + this->history_[param_id]->gpu_data(), + this->history_[param_id]->mutable_gpu_data()); + + // prepare update + caffe_gpu_powx(net_params[param_id]->count(), + this->history_[param_id]->gpu_data(), Dtype(0.5), + this->update_[param_id]->mutable_gpu_data()); + + caffe_gpu_add_scalar(net_params[param_id]->count(), + delta, this->update_[param_id]->mutable_gpu_data()); + + caffe_gpu_div(net_params[param_id]->count(), + net_params[param_id]->gpu_diff(), + this->update_[param_id]->gpu_data(), + this->update_[param_id]->mutable_gpu_data()); + + // scale and copy + caffe_gpu_axpby(net_params[param_id]->count(), local_rate, + this->update_[param_id]->gpu_data(), Dtype(0), + net_params[param_id]->mutable_gpu_diff()); #else NO_GPU; #endif break; + } default: LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); } diff --git a/src/caffe/syncedmem.cpp b/src/caffe/syncedmem.cpp index 7617ccfb27f..029d53a5cb6 100644 --- a/src/caffe/syncedmem.cpp +++ b/src/caffe/syncedmem.cpp @@ -12,7 +12,7 @@ SyncedMemory::~SyncedMemory() { } #ifndef CPU_ONLY - if (gpu_ptr_) { + if (gpu_ptr_ && own_gpu_data_) { CUDA_CHECK(cudaFree(gpu_ptr_)); } #endif // CPU_ONLY @@ -51,10 +51,12 @@ inline void SyncedMemory::to_gpu() { CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_)); caffe_gpu_memset(size_, 0, gpu_ptr_); head_ = HEAD_AT_GPU; + own_gpu_data_ = true; break; case HEAD_AT_CPU: if (gpu_ptr_ == NULL) { CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_)); + own_gpu_data_ = true; } caffe_gpu_memcpy(size_, cpu_ptr_, gpu_ptr_); head_ = SYNCED; @@ -92,6 +94,20 @@ const void* SyncedMemory::gpu_data() { #endif } +void SyncedMemory::set_gpu_data(void* data) { +#ifndef CPU_ONLY + CHECK(data); + if (own_gpu_data_) { + CUDA_CHECK(cudaFree(gpu_ptr_)); + } + gpu_ptr_ = data; + head_ = HEAD_AT_GPU; + own_gpu_data_ = false; +#else + NO_GPU; +#endif +} + void* SyncedMemory::mutable_cpu_data() { to_cpu(); head_ = HEAD_AT_CPU; @@ -108,6 +124,19 @@ void* SyncedMemory::mutable_gpu_data() { #endif } +#ifndef CPU_ONLY +void SyncedMemory::async_gpu_push(const cudaStream_t& stream) { + CHECK(head_ == HEAD_AT_CPU); + if (gpu_ptr_ == NULL) { + CUDA_CHECK(cudaMalloc(&gpu_ptr_, size_)); + own_gpu_data_ = true; + } + const cudaMemcpyKind put = cudaMemcpyHostToDevice; + CUDA_CHECK(cudaMemcpyAsync(gpu_ptr_, cpu_ptr_, size_, put, stream)); + // Assume caller will synchronize on the stream before use + head_ = SYNCED; +} +#endif } // namespace caffe diff --git a/src/caffe/test/test_internal_thread.cpp b/src/caffe/test/test_internal_thread.cpp index 31882b6db1d..93f1cc541cd 100644 --- a/src/caffe/test/test_internal_thread.cpp +++ b/src/caffe/test/test_internal_thread.cpp @@ -2,6 +2,7 @@ #include "gtest/gtest.h" #include "caffe/internal_thread.hpp" +#include "caffe/util/math_functions.hpp" #include "caffe/test/test_caffe_main.hpp" @@ -13,11 +14,40 @@ class InternalThreadTest : public ::testing::Test {}; TEST_F(InternalThreadTest, TestStartAndExit) { InternalThread thread; EXPECT_FALSE(thread.is_started()); - EXPECT_TRUE(thread.StartInternalThread()); + thread.StartInternalThread(); EXPECT_TRUE(thread.is_started()); - EXPECT_TRUE(thread.WaitForInternalThreadToExit()); + thread.StopInternalThread(); EXPECT_FALSE(thread.is_started()); } +class TestThreadA : public InternalThread { + void InternalThreadEntry() { + EXPECT_EQ(4244559767, caffe_rng_rand()); + } +}; + +class TestThreadB : public InternalThread { + void InternalThreadEntry() { + EXPECT_EQ(1726478280, caffe_rng_rand()); + } +}; + +TEST_F(InternalThreadTest, TestRandomSeed) { + TestThreadA t1; + Caffe::set_random_seed(9658361); + t1.StartInternalThread(); + t1.StopInternalThread(); + + TestThreadA t2; + Caffe::set_random_seed(9658361); + t2.StartInternalThread(); + t2.StopInternalThread(); + + TestThreadB t3; + Caffe::set_random_seed(3435563); + t3.StartInternalThread(); + t3.StopInternalThread(); +} + } // namespace caffe diff --git a/src/caffe/test/test_layer_factory.cpp b/src/caffe/test/test_layer_factory.cpp index efb1b37ac42..c86fafd000c 100644 --- a/src/caffe/test/test_layer_factory.cpp +++ b/src/caffe/test/test_layer_factory.cpp @@ -1,11 +1,14 @@ #include #include +#include "boost/scoped_ptr.hpp" #include "gtest/gtest.h" #include "caffe/common.hpp" #include "caffe/layer.hpp" #include "caffe/layer_factory.hpp" +#include "caffe/util/db.hpp" +#include "caffe/util/io.hpp" #include "caffe/test/test_caffe_main.hpp" @@ -21,11 +24,20 @@ TYPED_TEST(LayerFactoryTest, TestCreateLayer) { typename LayerRegistry::CreatorRegistry& registry = LayerRegistry::Registry(); shared_ptr > layer; - LayerParameter layer_param; for (typename LayerRegistry::CreatorRegistry::iterator iter = registry.begin(); iter != registry.end(); ++iter) { // Special case: PythonLayer is checked by pytest if (iter->first == "Python") { continue; } + LayerParameter layer_param; + // Data layers expect a DB + if (iter->first == "Data") { + string tmp; + MakeTempDir(&tmp); + boost::scoped_ptr db(db::GetDB(DataParameter_DB_LEVELDB)); + db->Open(tmp, db::NEW); + db->Close(); + layer_param.mutable_data_param()->set_source(tmp); + } layer_param.set_type(iter->first); layer = LayerRegistry::CreateLayer(layer_param); EXPECT_EQ(iter->first, layer->type()); diff --git a/src/caffe/test/test_upgrade_proto.cpp b/src/caffe/test/test_upgrade_proto.cpp index eec627656ef..006720231a5 100644 --- a/src/caffe/test/test_upgrade_proto.cpp +++ b/src/caffe/test/test_upgrade_proto.cpp @@ -2,12 +2,15 @@ #include #include +#include "boost/scoped_ptr.hpp" #include "google/protobuf/text_format.h" #include "gtest/gtest.h" #include "caffe/blob.hpp" #include "caffe/common.hpp" #include "caffe/layer.hpp" +#include "caffe/util/db.hpp" +#include "caffe/util/io.hpp" #include "caffe/util/upgrade_proto.hpp" #include "caffe/test/test_caffe_main.hpp" @@ -2901,6 +2904,15 @@ TEST_F(NetUpgradeTest, TestUpgradeV1LayerType) { continue; // Empty string isn't actually a valid layer type. } layer_param.set_type(v2_layer_type); + // Data layers expect a DB + if (v2_layer_type == "Data") { + string tmp; + MakeTempDir(&tmp); + boost::scoped_ptr db(db::GetDB(DataParameter_DB_LEVELDB)); + db->Open(tmp, db::NEW); + db->Close(); + layer_param.mutable_data_param()->set_source(tmp); + } layer = LayerRegistry::CreateLayer(layer_param); EXPECT_EQ(v2_layer_type, layer->type()); } diff --git a/src/caffe/util/blocking_queue.cpp b/src/caffe/util/blocking_queue.cpp new file mode 100644 index 00000000000..8a0e9306f18 --- /dev/null +++ b/src/caffe/util/blocking_queue.cpp @@ -0,0 +1,96 @@ +#include +#include + +#include "caffe/data_layers.hpp" +#include "caffe/data_reader.hpp" +#include "caffe/parallel.hpp" +#include "caffe/util/blocking_queue.hpp" + +namespace caffe { + +template +class BlockingQueue::sync { + public: + mutable boost::mutex mutex_; + boost::condition_variable condition_; +}; + +template +BlockingQueue::BlockingQueue() + : sync_(new sync()) { +} + +template +void BlockingQueue::push(const T& t) { + boost::mutex::scoped_lock lock(sync_->mutex_); + queue_.push(t); + lock.unlock(); + sync_->condition_.notify_one(); +} + +template +bool BlockingQueue::try_pop(T* t) { + boost::mutex::scoped_lock lock(sync_->mutex_); + + if (queue_.empty()) { + return false; + } + + *t = queue_.front(); + queue_.pop(); + return true; +} + +template +T BlockingQueue::pop(const string& log_on_wait) { + boost::mutex::scoped_lock lock(sync_->mutex_); + + while (queue_.empty()) { + if (!log_on_wait.empty()) { + LOG(INFO)<< log_on_wait; + } + sync_->condition_.wait(lock); + } + + T t = queue_.front(); + queue_.pop(); + return t; +} + +template +bool BlockingQueue::try_peek(T* t) { + boost::mutex::scoped_lock lock(sync_->mutex_); + + if (queue_.empty()) { + return false; + } + + *t = queue_.front(); + return true; +} + +template +T BlockingQueue::peek() { + boost::mutex::scoped_lock lock(sync_->mutex_); + + while (queue_.empty()) { + sync_->condition_.wait(lock); + } + + return queue_.front(); +} + +template +size_t BlockingQueue::size() const { + boost::mutex::scoped_lock lock(sync_->mutex_); + return queue_.size(); +} + +template class BlockingQueue*>; +template class BlockingQueue*>; +template class BlockingQueue; +template class BlockingQueue >; +template class BlockingQueue*>; +template class BlockingQueue*>; + +} // namespace caffe diff --git a/tools/caffe.cpp b/tools/caffe.cpp index 70b15f890f7..3a55236e65b 100644 --- a/tools/caffe.cpp +++ b/tools/caffe.cpp @@ -12,13 +12,18 @@ using caffe::Blob; using caffe::Caffe; using caffe::Net; using caffe::Layer; +using caffe::Solver; using caffe::shared_ptr; +using caffe::string; using caffe::Timer; using caffe::vector; - +using std::ostringstream; DEFINE_int32(gpu, -1, - "Run in GPU mode on given device ID."); + "Run in GPU mode on given device ID (Legacy switch, use -gpus)."); +DEFINE_string(gpus, "", + "Run in GPU mode on given device IDs separated by ','." + "Use '-gpus all' to run on all available GPUs."); DEFINE_string(solver, "", "The solver definition protocol buffer text file."); DEFINE_string(model, "", @@ -26,8 +31,8 @@ DEFINE_string(model, "", DEFINE_string(snapshot, "", "Optional; the snapshot solver state to resume training."); DEFINE_string(weights, "", - "Optional; the pretrained weights to initialize finetuning. " - "Cannot be set simultaneously with snapshot."); + "Optional; the pretrained weights to initialize finetuning, " + "separated by ','. Cannot be set simultaneously with snapshot."); DEFINE_int32(iterations, 50, "The number of iterations to run."); @@ -61,6 +66,32 @@ static BrewFunction GetBrewFunction(const caffe::string& name) { } } +// Parse GPU ids or use all available devices +static void get_gpus(vector* gpus) { + if (FLAGS_gpu >= 0) { + FLAGS_gpus = "" + boost::lexical_cast(FLAGS_gpu); + } + if (FLAGS_gpus == "all") { + int count = 0; +#ifndef CPU_ONLY + CUDA_CHECK(cudaGetDeviceCount(&count)); +#else + NO_GPU; +#endif + for (int i = 0; i < count; ++i) { + gpus->push_back(i); + } + } else if (FLAGS_gpus.size()) { + vector strings; + boost::split(strings, FLAGS_gpus, boost::is_any_of(",")); + for (int i = 0; i < strings.size(); ++i) { + gpus->push_back(boost::lexical_cast(strings[i])); + } + } else { + CHECK_EQ(gpus->size(), 0); + } +} + // caffe commands to call by // caffe // @@ -69,10 +100,13 @@ static BrewFunction GetBrewFunction(const caffe::string& name) { // Device Query: show diagnostic information for a GPU device. int device_query() { - CHECK_GT(FLAGS_gpu, -1) << "Need a device ID to query."; - LOG(INFO) << "Querying device ID = " << FLAGS_gpu; - caffe::Caffe::SetDevice(FLAGS_gpu); - caffe::Caffe::DeviceQuery(); + LOG(INFO) << "Querying GPUs " << FLAGS_gpus; + vector gpus; + get_gpus(&gpus); + for (int i = 0; i < gpus.size(); ++i) { + caffe::Caffe::SetDevice(gpus[i]); + caffe::Caffe::DeviceQuery(); + } return 0; } RegisterBrewFunction(device_query); @@ -101,34 +135,44 @@ int train() { caffe::SolverParameter solver_param; caffe::ReadProtoFromTextFileOrDie(FLAGS_solver, &solver_param); - // If the gpu flag is not provided, allow the mode and device to be set + // If the gpus flag is not provided, allow the mode and device to be set // in the solver prototxt. - if (FLAGS_gpu < 0 - && solver_param.solver_mode() == caffe::SolverParameter_SolverMode_GPU) { - FLAGS_gpu = solver_param.device_id(); + if (FLAGS_gpu < 0 && FLAGS_gpus.size() == 0 + && solver_param.solver_mode() == caffe::SolverParameter_SolverMode_GPU + && solver_param.has_device_id()) { + FLAGS_gpus = "" + boost::lexical_cast(solver_param.device_id()); } - // Set device id and mode - if (FLAGS_gpu >= 0) { - LOG(INFO) << "Use GPU with device ID " << FLAGS_gpu; - Caffe::SetDevice(FLAGS_gpu); - Caffe::set_mode(Caffe::GPU); - } else { - LOG(INFO) << "Use CPU."; + vector gpus; + get_gpus(&gpus); + if (gpus.size() == 0) { Caffe::set_mode(Caffe::CPU); + } else { + ostringstream s; + for (int i = 0; i < gpus.size(); ++i) { + s << (i ? ", " : "") << gpus[i]; + } + LOG(INFO) << "Using GPUs " << s.str(); + + solver_param.set_device_id(gpus[0]); + Caffe::SetDevice(gpus[0]); + Caffe::set_mode(Caffe::GPU); + Caffe::set_solver_count(gpus.size()); } - LOG(INFO) << "Starting Optimization"; - shared_ptr > - solver(caffe::GetSolver(solver_param)); + shared_ptr > solver(caffe::GetSolver(solver_param)); if (FLAGS_snapshot.size()) { LOG(INFO) << "Resuming from " << FLAGS_snapshot; - solver->Solve(FLAGS_snapshot); + solver->Restore(FLAGS_snapshot.c_str()); } else if (FLAGS_weights.size()) { - CopyLayers(&*solver, FLAGS_weights); - solver->Solve(); + CopyLayers(solver.get(), FLAGS_weights); + } + + if (gpus.size() > 1) { + caffe::P2PSync::run(solver, gpus); } else { + LOG(INFO) << "Starting Optimization"; solver->Solve(); } LOG(INFO) << "Optimization Done."; @@ -143,9 +187,11 @@ int test() { CHECK_GT(FLAGS_weights.size(), 0) << "Need model weights to score."; // Set device id and mode - if (FLAGS_gpu >= 0) { - LOG(INFO) << "Use GPU with device ID " << FLAGS_gpu; - Caffe::SetDevice(FLAGS_gpu); + vector gpus; + get_gpus(&gpus); + if (gpus.size() != 0) { + LOG(INFO) << "Use GPU with device ID " << gpus[0]; + Caffe::SetDevice(gpus[0]); Caffe::set_mode(Caffe::GPU); } else { LOG(INFO) << "Use CPU."; @@ -208,9 +254,11 @@ int time() { CHECK_GT(FLAGS_model.size(), 0) << "Need a model definition to time."; // Set device id and mode - if (FLAGS_gpu >= 0) { - LOG(INFO) << "Use GPU with device ID " << FLAGS_gpu; - Caffe::SetDevice(FLAGS_gpu); + vector gpus; + get_gpus(&gpus); + if (gpus.size() != 0) { + LOG(INFO) << "Use GPU with device ID " << gpus[0]; + Caffe::SetDevice(gpus[0]); Caffe::set_mode(Caffe::GPU); } else { LOG(INFO) << "Use CPU.";