The SolverRegistry and SolverRegisterer classes in include/caffe/solver_factory.hpp
/**
* @brief A solver factory that allows one to register solvers, similar to
* layer factory. During runtime, registered solvers could be called by passing
* a SolverParameter protobuffer to the CreateSolver function:
*
* SolverRegistry<Dtype>::CreateSolver(param);
*
* There are two ways to register a solver. Assuming that we have a solver like:
*
* template <typename Dtype>
* class MyAwesomeSolver : public Solver<Dtype> {
* // your implementations
* };
*
* and its type is its C++ class name, but without the "Solver" at the end
* ("MyAwesomeSolver" -> "MyAwesome").
*
* If the solver is going to be created simply by its constructor, in your C++
* file, add the following line:
*
* REGISTER_SOLVER_CLASS(MyAwesome);
*
* Or, if the solver is going to be created by another creator function, in the
* format of:
*
* template <typename Dtype>
* Solver<Dtype>* GetMyAwesomeSolver(const SolverParameter& param) {
* // your implementation
* }
*
* then you can register the creator function instead, like
*
* REGISTER_SOLVER_CREATOR(MyAwesome, GetMyAwesomeSolver)
*
* Note that each solver type should only be registered once.
*/
#ifndef CAFFE_SOLVER_FACTORY_H_
#define CAFFE_SOLVER_FACTORY_H_
#include <map>
#include <string>
#include <vector>
#include "caffe/common.hpp"
#include "caffe/proto/caffe.pb.h"
namespace caffe {
template <typename Dtype>
class Solver;
template <typename Dtype>
class SolverRegistry {
public:
typedef Solver<Dtype>* (*Creator)(const SolverParameter&);
typedef std::map<string, Creator> CreatorRegistry;
// All member functions are static and are invoked through the class name.
static CreatorRegistry& Registry() {
  // g_registry_ is a pointer to a CreatorRegistry map, returned by
  // dereference. Because it is static, only one instance is ever created no
  // matter how many times this function is called, and modifications made to
  // the map elsewhere persist in it. Registering a solver is exactly the act
  // of inserting into this map an entry whose key is the solver's type and
  // whose value is the corresponding Creator function pointer.
  static CreatorRegistry* g_registry_ = new CreatorRegistry();
  return *g_registry_;
}
// Add a Creator function pointer to the registry.
static void AddCreator(const string& type, Creator creator) {
  CreatorRegistry& registry = Registry();
  CHECK_EQ(registry.count(type), 0)
      << "Solver type " << type << " already registered.";
  registry[type] = creator;  // if not yet registered, add it to the map the static pointer refers to
}
// Get a solver using a SolverParameter.
static Solver<Dtype>* CreateSolver(const SolverParameter& param) {
  const string& type = param.type();  // the solver's type, as a string
  // Registry() hands back the registry map, whose keys are strings and whose
  // values are Creator function pointers taking a SolverParameter.
  CreatorRegistry& registry = Registry();
  // For a solver type that has been registered, registry.count(type) is 1.
  CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type
      << " (known types: " << SolverTypeListString() << ")";
  // Look up the Creator registered under this type, call it, and return the
  // Solver<Dtype>* it constructs.
  return registry[type](param);
}
static vector<string> SolverTypeList() {
CreatorRegistry& registry = Registry();
vector<string> solver_types;
for (typename CreatorRegistry::iterator iter = registry.begin();
iter != registry.end(); ++iter) {
solver_types.push_back(iter->first);
}
return solver_types;
}
private:
// Solver registry should never be instantiated - everything is done with its
// static variables.
SolverRegistry() {}  // the constructor is private, so no instance of this class can be created; everything is accessed through the class name
static string SolverTypeListString() {
vector<string> solver_types = SolverTypeList();
string solver_types_str;
for (vector<string>::iterator iter = solver_types.begin();
iter != solver_types.end(); ++iter) {
if (iter != solver_types.begin()) {
solver_types_str += ", ";
}
solver_types_str += *iter;
}
return solver_types_str;
}
};
template <typename Dtype>
class SolverRegisterer {
public:
SolverRegisterer(const string& type,
Solver<Dtype>* (*creator)(const SolverParameter&)) {
// LOG(INFO) << "Registering solver type: " << type;
SolverRegistry<Dtype>::AddCreator(type, creator);
}
};
At the end of sgd_solver.cpp (the .cpp file for SGDSolver), the REGISTER_SOLVER_CLASS macro is invoked. It defines a function named Creator_SGDSolver, which is exactly the kind of function a Creator pointer points to: it calls the SGDSolver constructor and returns a pointer to the constructed object. That is the whole job of a Creator function: build a Solver of the corresponding type and return its pointer. The macro then invokes REGISTER_SOLVER_CREATOR, which defines static float and double instances of the SolverRegisterer template class. Constructing these statics runs the SolverRegisterer constructor, which calls the AddCreator function of the SolverRegistry class described above, storing the pointer to the just-defined Creator_SGDSolver function in the map that g_registry_ points to. Every solver's .cpp file ends with the same macro call to complete its registration. Once all solvers are registered, we can, as described earlier, look up the matching Creator function pointer through g_registry_ and call it to construct the corresponding Solver.
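For concreteness, based on the two macro definitions below, the single line REGISTER_SOLVER_CLASS(SGD); expands to roughly the following (a sketch of the preprocessor output, not extra code in Caffe):

template <typename Dtype>
Solver<Dtype>* Creator_SGDSolver(const SolverParameter& param) {
  return new SGDSolver<Dtype>(param);
}
static SolverRegisterer<float> g_creator_f_SGD("SGD", Creator_SGDSolver<float>);
static SolverRegisterer<double> g_creator_d_SGD("SGD", Creator_SGDSolver<double>);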
// REGISTER_SOLVER_CREATOR defines static float and double SolverRegisterer
// variables; constructing them runs the SolverRegisterer constructor, which
// calls SolverRegistry::AddCreator to store the given Creator function
// pointer (e.g. Creator_SGDSolver) in the map that g_registry_ points to.
#define REGISTER_SOLVER_CREATOR(type, creator) \
static SolverRegisterer<float> g_creator_f_##type(#type, creator<float>); \
static SolverRegisterer<double> g_creator_d_##type(#type, creator<double>) \
// REGISTER_SOLVER_CLASS defines a function named Creator_xxxSolver (the kind
// of function a Creator pointer points to) that calls the xxxSolver
// constructor and returns a pointer to the constructed object; it then
// invokes the REGISTER_SOLVER_CREATOR macro above on that function.
#define REGISTER_SOLVER_CLASS(type) \
template <typename Dtype> \
Solver<Dtype>* Creator_##type##Solver( \
const SolverParameter& param) \
{ \
return new type##Solver<Dtype>(param); \
} \
REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver)
} // namespace caffe
#endif // CAFFE_SOLVER_FACTORY_H_
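As a usage sketch (assuming the Caffe headers are on the include path and solver.prototxt sets, e.g., type: "SGD"), the whole mechanism is driven like this:

#include "caffe/solver.hpp"
#include "caffe/solver_factory.hpp"
#include "caffe/util/upgrade_proto.hpp"

caffe::SolverParameter solver_param;
// parse the solver settings, including the solver type string
caffe::ReadSolverParamsFromTextFileOrDie("solver.prototxt", &solver_param);
// look up the registered Creator for that type and call it
boost::shared_ptr<caffe::Solver<float> >
    solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));
solver->Solve();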
include/caffe/solver.hpp
#ifndef CAFFE_SOLVER_HPP_
#define CAFFE_SOLVER_HPP_
#include <boost/function.hpp>
#include <string>
#include <vector>
#include "caffe/net.hpp"
#include "caffe/solver_factory.hpp"
#include "caffe/util/benchmark.hpp"
/*
 * This file and the corresponding .cpp files cover:
 * (1) the solver_factory mechanism for registering and creating Solvers of
 *     different types,
 * (2) obtaining system signals through the signal_handler and handling them
 *     according to the user's (or the default) settings,
 * (3) the concrete implementation of Solver::Solve,
 * (4) the concrete implementation of SGDSolver::ApplyUpdate. The first three
 *     belong to the base class; the last is specific to the SGDSolver
 *     subclass. A user who wants a custom Solver should likewise inherit from
 *     the base class, implement their own ApplyUpdate, and complete
 *     registration with the register macro at the end of the code; the new
 *     solver can then be invoked successfully.
 */
namespace caffe {
/**
 * When Ctrl-C is pressed, the model at the current point in training is
 * saved, so if the terminal is accidentally closed mid-training, training
 * can later resume from that snapshot.
 */
namespace SolverAction {
enum Enum {
NONE = 0, // Take no special action.
STOP = 1, // Stop training. snapshot_after_train controls whether a
// snapshot is created.
SNAPSHOT = 2 // Take a snapshot, and keep training.
};
}
/**
* @brief Type of a function that returns a Solver Action enumeration.
*/
typedef boost::function<SolverAction::Enum()> ActionCallback;
/**
* @brief An interface for classes that perform optimization on Net%s.
*
* Requires implementation of ApplyUpdate to compute a parameter update
* given the current state of the Net parameters.
*/
template <typename Dtype>
class Solver {
public:
explicit Solver(const SolverParameter& param);
explicit Solver(const string& param_file);
void Init(const SolverParameter& param);
void InitTrainNet();
void InitTestNets();
// Client of the Solver optionally may call this in order to set the function
// that the solver uses to see what action it should take (e.g. snapshot or
// exit training early).
void SetActionFunction(ActionCallback func);
SolverAction::Enum GetRequestedAction();
// The main entry of the solver function. By default, iter will be zero. Pass
// in a non-zero iter number to resume training for a pre-trained net.
virtual void Solve(const char* resume_file = NULL);
inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
void Step(int iters);
// The Restore method simply dispatches to one of the
// RestoreSolverStateFrom___ protected methods. You should implement these
// methods to restore the state from the appropriate snapshot type.
void Restore(const char* resume_file);
// The Solver::Snapshot function implements the basic snapshotting utility
// that stores the learned net. You should implement the SnapshotSolverState()
// function that produces a SolverState protocol buffer that needs to be
// written to disk together with the learned net.
void Snapshot();
virtual ~Solver() {}
inline const SolverParameter& param() const { return param_; }
inline shared_ptr<Net<Dtype> > net() { return net_; }
inline const vector<shared_ptr<Net<Dtype> > >& test_nets() {
return test_nets_;
}
int iter() const { return iter_; }
// Invoked at specific points during an iteration
class Callback {
protected:
virtual void on_start() = 0;
virtual void on_gradients_ready() = 0;
template <typename T>
friend class Solver;
};
const vector<Callback*>& callbacks() const { return callbacks_; }
void add_callback(Callback* value) {
callbacks_.push_back(value);
}
void CheckSnapshotWritePermissions();
/**
* @brief Returns the solver type.
*/
virtual inline const char* type() const { return ""; }
protected:
// Make and apply the update value for the current iteration.
virtual void ApplyUpdate() = 0;
string SnapshotFilename(const string extension);
string SnapshotToBinaryProto();
string SnapshotToHDF5();
// The test routine
void TestAll();
void Test(const int test_net_id = 0);
virtual void SnapshotSolverState(const string& model_filename) = 0;
virtual void RestoreSolverStateFromHDF5(const string& state_file) = 0;
virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0;
void DisplayOutputBlobs(const int net_id);
void UpdateSmoothedLoss(Dtype loss, int start_iter, int average_loss);
SolverParameter param_;
int iter_;
int current_step_;
shared_ptr<Net<Dtype> > net_;
vector<shared_ptr<Net<Dtype> > > test_nets_;
vector<Callback*> callbacks_;
vector<Dtype> losses_;
Dtype smoothed_loss_;
// A function that can be set by a client of the Solver to provide indication
// that it wants a snapshot saved and/or to exit early.
ActionCallback action_request_function_;
// True iff a request to stop early was received.
bool requested_early_exit_;
// Timing information, handy to tune e.g. nbr of GPUs
Timer iteration_timer_;
float iterations_last_;
DISABLE_COPY_AND_ASSIGN(Solver);
};
} // namespace caffe
#endif // CAFFE_SOLVER_HPP_
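Echoing the file-header note above: to add a custom solver, inherit from the base class (or from SGDSolver, so the snapshot/restore virtuals are already implemented), override ApplyUpdate(), and register. A minimal sketch, with MyAwesomeSolver as a made-up name:

#include "caffe/sgd_solvers.hpp"

namespace caffe {

template <typename Dtype>
class MyAwesomeSolver : public SGDSolver<Dtype> {
 public:
  explicit MyAwesomeSolver(const SolverParameter& param)
      : SGDSolver<Dtype>(param) {}
  virtual inline const char* type() const { return "MyAwesome"; }

 protected:
  // custom update rule goes here; for illustration, fall back to plain SGD
  virtual void ApplyUpdate() { SGDSolver<Dtype>::ApplyUpdate(); }
};

INSTANTIATE_CLASS(MyAwesomeSolver);
REGISTER_SOLVER_CLASS(MyAwesome);

}  // namespace caffe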
src/caffe/solver.cpp
#include <cstdio>
#include <string>
#include <vector>
#include "caffe/solver.hpp"
#include "caffe/util/format.hpp"
#include "caffe/util/hdf5.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/upgrade_proto.hpp"
namespace caffe {
// Set the client-supplied function that tells the solver what action to take
template <typename Dtype>
void Solver<Dtype>::SetActionFunction(ActionCallback func) {
action_request_function_ = func;
}
template<typename Dtype>
SolverAction::Enum Solver<Dtype>::GetRequestedAction() {
if (action_request_function_) {
// If the external request function has been set, call it.
return action_request_function_();
}
return SolverAction::NONE;
}
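A sketch of how a client supplies the callback (Caffe's command-line tool does this through its SignalHandler class; the free function here is only an illustration):

// return the action the solver should take right now
caffe::SolverAction::Enum CheckForExitRequest() {
  // e.g. inspect a flag that a SIGINT/SIGHUP handler has set
  return caffe::SolverAction::NONE;
}

// ActionCallback is boost::function<SolverAction::Enum()>, so a plain
// function pointer converts implicitly:
solver->SetActionFunction(&CheckForExitRequest);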
// Set up the objective to optimize, the training net used for learning, and
// the test nets used for evaluation.
// Constructors: initialize net_ and call Init(). The parameters can be passed
// in either of two ways:
// 1. as a SolverParameter param
template <typename Dtype>
Solver<Dtype>::Solver(const SolverParameter& param)
: net_(), callbacks_(), requested_early_exit_(false) {
Init(param);
}
// 2. as a string param_file naming the file to parse
template <typename Dtype>
Solver<Dtype>::Solver(const string& param_file)
: net_(), callbacks_(), requested_early_exit_(false) {
SolverParameter param;
ReadSolverParamsFromTextFileOrDie(param_file, &param);
Init(param);
}
template <typename Dtype>
void Solver<Dtype>::Init(const SolverParameter& param) {
LOG_IF(INFO, Caffe::root_solver()) << "Initializing solver from parameters: "
<< std::endl << param.DebugString();
param_ = param;
CHECK_GE(param_.average_loss(), 1) << "average_loss should be non-negative.";
CheckSnapshotWritePermissions();
if (param_.random_seed() >= 0) {  // set the random seed
Caffe::set_random_seed(param_.random_seed() + Caffe::solver_rank());
}
// Scaffolding code
InitTrainNet();  // initialize the training net; net_ points to it
if (Caffe::root_solver()) {
InitTestNets();  // initialize the test nets; test_nets_ holds them
LOG(INFO) << "Solver scaffolding done.";
}
iter_ = 0;
current_step_ = 0;
}
// Initialize the training net
template <typename Dtype>
void Solver<Dtype>::InitTrainNet() {
  const int num_train_nets = param_.has_net() + param_.has_net_param() +
      param_.has_train_net() + param_.has_train_net_param();  // how many fields specify a train net
  const string& field_names = "net, net_param, train_net, train_net_param";  // the candidate field names
CHECK_GE(num_train_nets, 1) << "SolverParameter must specify a train net "
<< "using one of these fields: " << field_names;
CHECK_LE(num_train_nets, 1) << "SolverParameter must not contain more than "
<< "one of these fields specifying a train_net: " << field_names;/ /训练网络数量超过,报错
NetParameter net_param;// 网络参数
if (param_.has_train_net_param()) {
LOG_IF(INFO, Caffe::root_solver())
<< "Creating training net specified in train_net_param.";
net_param.CopyFrom(param_.train_net_param());  // copy the params from train_net_param
} else if (param_.has_train_net()) {
LOG_IF(INFO, Caffe::root_solver())
<< "Creating training net from train_net file: " << param_.train_net();
ReadNetParamsFromTextFileOrDie(param_.train_net(), &net_param);  // read the params from the train_net file
}
if (param_.has_net_param()) {
LOG_IF(INFO, Caffe::root_solver())
<< "Creating training net specified in net_param.";
net_param.CopyFrom(param_.net_param());  // copy the params from net_param
}
if (param_.has_net()) {
LOG_IF(INFO, Caffe::root_solver())
<< "Creating training net from net file: " << param_.net();
ReadNetParamsFromTextFileOrDie(param_.net(), &net_param);  // read the params from the net file
}
// Set the correct NetState. We start with the solver defaults (lowest
// precedence); then, merge in any NetState specified by the net_param itself;
// finally, merge in any NetState specified by the train_state (highest
// precedence).
NetState net_state;
net_state.set_phase(TRAIN);
net_state.MergeFrom(net_param.state());
net_state.MergeFrom(param_.train_state());
net_param.mutable_state()->CopyFrom(net_state);  // write the merged NetState back into the net's parameters
net_.reset(new Net<Dtype>(net_param));  // call the Net constructor to initialize net_
}
// Initialize the test nets
template <typename Dtype>
void Solver<Dtype>::InitTestNets() {
CHECK(Caffe::root_solver());
const bool has_net_param = param_.has_net_param();
const bool has_net_file = param_.has_net();
const int num_generic_nets = has_net_param + has_net_file;  // number of generic nets
CHECK_LE(num_generic_nets, 1)
<< "Both net_param and net_file may not be specified.";
const int num_test_net_params = param_.test_net_param_size();
const int num_test_net_files = param_.test_net_size();
const int num_test_nets = num_test_net_params + num_test_net_files;  // number of explicitly specified test nets
if (num_generic_nets) {
CHECK_GE(param_.test_iter_size(), num_test_nets)
<< "test_iter must be specified for each test network.";
} else {
CHECK_EQ(param_.test_iter_size(), num_test_nets)
<< "test_iter must be specified for each test network.";
}
// If we have a generic net (specified by net or net_param, rather than
// test_net or test_net_param), we may have an unlimited number of actual
// test networks -- the actual number is given by the number of remaining
// test_iters after any test nets specified by test_net_param and/or test_net
// are evaluated.
const int num_generic_net_instances = param_.test_iter_size() - num_test_nets;
const int num_test_net_instances = num_test_nets + num_generic_net_instances;
if (param_.test_state_size()) {
CHECK_EQ(param_.test_state_size(), num_test_net_instances)
<< "test_state must be unspecified or specified once per test net.";
}
if (num_test_net_instances) {
CHECK_GT(param_.test_interval(), 0);
}
int test_net_id = 0;
vector<string> sources(num_test_net_instances);
vector<NetParameter> net_params(num_test_net_instances);
for (int i = 0; i < num_test_net_params; ++i, ++test_net_id) {
sources[test_net_id] = "test_net_param"; //对网络参数进行标记
net_params[test_net_id].CopyFrom(param_.test_net_param(i)); //复制网络参数
}
for (int i = 0; i < num_test_net_files; ++i, ++test_net_id) {
sources[test_net_id] = "test_net file: " + param_.test_net(i);// 对网络参数进行标记
ReadNetParamsFromTextFileOrDie(param_.test_net(i) ,//复制网络参数
&net_params[test_net_id]);
}
const int remaining_test_nets = param_.test_iter_size() - test_net_id;
if (has_net_param) {
for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
sources[test_net_id] = "net_param";net_params[test_net_id].CopyFrom(param_.net_param());
}
}
if (has_net_file) {
for (int i = 0; i < remaining_test_nets; ++i, ++test_net_id) {
sources[test_net_id] = "net file: " + param_.net();
ReadNetParamsFromTextFileOrDie(param_.net(), &net_params[test_net_id]);
}
}
test_nets_.resize(num_test_net_instances);
for (int i = 0; i < num_test_net_instances; ++i) {
// Set the correct NetState. We start with the solver defaults (lowest
// precedence); then, merge in any NetState specified by the net_param
// itself; finally, merge in any NetState specified by the test_state
// (highest precedence).
NetState net_state;
net_state.set_phase(TEST);
net_state.MergeFrom(net_params[i].state());
if (param_.test_state_size()) {
net_state.MergeFrom(param_.test_state(i));
}
net_params[i].mutable_state()->CopyFrom(net_state);
LOG(INFO)
<< "Creating test net (#" << i << ") specified by " << sources[i];
test_nets_[i].reset(new Net<Dtype>(net_params[i]));
test_nets_[i]->set_debug_info(param_.debug_info());
}
}
// The Step() function
template <typename Dtype>
void Solver<Dtype>::Step(int iters) {
  const int start_iter = iter_;  // starting iteration (when resuming from a snapshot, this is the iteration at which the snapshot was taken)
  const int stop_iter = iter_ + iters;  // iteration at which to stop
  int average_loss = this->param_.average_loss();  // the displayed loss is the mean over the last average_loss iterations; set in solver.prototxt, default 1
losses_.clear();
smoothed_loss_ = 0;
iteration_timer_.Start();
// the iteration loop
while (iter_ < stop_iter) {
  // zero the diffs of all parameters from the previous iteration
  net_->ClearParamDiffs();
  if (param_.test_interval() && iter_ % param_.test_interval() == 0
      && (iter_ > 0 || param_.test_initialization())) {  // decide whether to run the test nets
if (Caffe::root_solver()) {
TestAll();
}
if (requested_early_exit_) {
// break out if early exit was requested while testing
break;
}
}
for (int i = 0; i < callbacks_.size(); ++i) {
callbacks_[i]->on_start();
}
const bool display = param_.display() && iter_ % param_.display() == 0;
net_->set_debug_info(display && param_.debug_info());  // enable debug output when displaying
// accumulate the loss and gradient
Dtype loss = 0;
for (int i = 0; i < param_.iter_size(); ++i) {
loss += net_->ForwardBackward();
}
loss /= param_.iter_size();  // average the loss over the iter_size forward/backward passes
// update the smoothed loss that will be displayed
UpdateSmoothedLoss(loss, start_iter, average_loss);
if (display) {
  float lapse = iteration_timer_.Seconds();
float per_s = (iter_ - iterations_last_) / (lapse ? lapse : 1);
LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << iter_
<< " (" << per_s << " iter/s, " << lapse << "s/"
<< param_.display() << " iters), loss = " << smoothed_loss_;
iteration_timer_.Start();
iterations_last_ = iter_;
const vector<Blob<Dtype>*>& result = net_->output_blobs();  // the net's output blobs
int score_index = 0;
for (int j = 0; j < result.size(); ++j) {
const Dtype* result_vec = result[j]->cpu_data();
const string& output_name =
    net_->blob_names()[net_->output_blob_indices()[j]];  // name of this output blob
const Dtype loss_weight =
    net_->blob_loss_weights()[net_->output_blob_indices()[j]];  // loss weight of this output blob
for (int k = 0; k < result[j]->count(); ++k) {
ostringstream loss_msg_stream;
if (loss_weight) {
loss_msg_stream << " (* " << loss_weight
<< " = " << loss_weight * result_vec[k] << " loss)";
}
LOG_IF(INFO, Caffe::root_solver()) << " Train net output #"
<< score_index++ << ": " << output_name << " = "
<< result_vec[k] << loss_msg_stream.str();
}
}
}
for (int i = 0; i < callbacks_.size(); ++i) {
  callbacks_[i]->on_gradients_ready();
}
ApplyUpdate();  // compute and apply the parameter update (pure virtual here; implemented by subclasses such as SGDSolver)
// Increment the internal iter_ counter -- its value should always indicate
// the number of times the weights have been updated.
++iter_;
SolverAction::Enum request = GetRequestedAction();
// Save a snapshot if needed.
if ((param_.snapshot()
&& iter_ % param_.snapshot() == 0
&& Caffe::root_solver()) ||
(request == SolverAction::SNAPSHOT)) {
Snapshot();  // save a snapshot
}
if (SolverAction::STOP == request) {
requested_early_exit_ = true;
// Break out of training loop.
break;
}
}
}
// The Solve() function
template <typename Dtype>
void Solver<Dtype>::Solve(const char* resume_file) {
  CHECK(Caffe::root_solver());
  // check that this is the root solver (in multi-GPU mode, only the root solver runs this part of the code)
  LOG(INFO) << "Solving " << net_->name();
  LOG(INFO) << "Learning Rate Policy: " << param_.lr_policy();
  // print the learning-rate update policy
// Initialize to false every time we start solving.
requested_early_exit_ = false;  // no request to exit before optimization finishes
if (resume_file) {
LOG(INFO) << "Restoring previous solver status from " << resume_file;
Restore(resume_file);
}
// if resume_file is non-null, restore the previous training state from that path
// For a network that is trained by the solver, no bottom or top vecs
// should be given, and we will just provide dummy vecs.
int start_iter = iter_;
Step(param_.max_iter() - iter_);
// call Step() to run the actual step-by-step iteration
// If we haven't already, save a snapshot after optimization, unless
// overridden by setting snapshot_after_train := false
if (param_.snapshot_after_train()
&& (!param_.snapshot() || iter_ % param_.snapshot() != 0)) {
Snapshot();
}
  // when iteration finishes, or a system signal ends it early, decide whether
  // to snapshot after training; configurable in solver.prototxt
  // If Step() hit an early-stop signal and the chosen action was STOP,
  // requested_early_exit_ was set to true there, so the run ends early here
  // with a message.
  if (requested_early_exit_) {
    LOG(INFO) << "Optimization stopped early.";
    return;
  }
  // After optimization is done, run an additional train pass to display the
  // training loss or outputs if required.
  if (param_.display() && iter_ % param_.display() == 0) {
int average_loss = this->param_.average_loss();
Dtype loss;
net_->Forward(&loss);
UpdateSmoothedLoss(loss, start_iter, average_loss);
// display the final loss if required
LOG(INFO) << "Iteration " << iter_ << ", loss = " << smoothed_loss_;
}
if (param_.test_interval() && iter_ % param_.test_interval() == 0) {
TestAll();
}
// run a final test pass if required
LOG(INFO) << "Optimization Done.";
}
template <typename Dtype>
void Solver<Dtype>::TestAll() {  // run every test net
for (int test_net_id = 0;
test_net_id < test_nets_.size() && !requested_early_exit_;
++test_net_id) {
Test(test_net_id);
}
}
template <typename Dtype>
void Solver<Dtype>::Test(const int test_net_id) {
CHECK(Caffe::root_solver());
LOG(INFO) << "Iteration " << iter_
<< ", Testing net (#" << test_net_id << ")";
CHECK_NOTNULL(test_nets_[test_net_id].get())->
ShareTrainedLayersWith(net_.get());
vector<Dtype> test_score;
vector<int> test_score_output_id;
const shared_ptr<Net<Dtype> >& test_net = test_nets_[test_net_id];
Dtype loss = 0;
for (int i = 0; i < param_.test_iter(test_net_id); ++i) {
SolverAction::Enum request = GetRequestedAction();
// Check to see if stoppage of testing/training has been requested.
// keep polling the requested action: if a snapshot request arrives while
// training or testing, save a snapshot right away
while (request != SolverAction::NONE) {
  if (SolverAction::SNAPSHOT == request) {
Snapshot();
} else if (SolverAction::STOP == request) {
requested_early_exit_ = true;
}
request = GetRequestedAction();
}
if (requested_early_exit_) {
// break out of test loop.
break;
}
Dtype iter_loss;
const vector<Blob<Dtype>*>& result =
test_net->Forward(&iter_loss);
if (param_.test_compute_loss()) {
loss += iter_loss;
}
if (i == 0) {
for (int j = 0; j < result.size(); ++j) {
const Dtype* result_vec = result[j]->cpu_data();
// on the first test iteration, take each output blob's data via cpu_data()
// and store it, flattened to one dimension, in the vector test_score
for (int k = 0; k < result[j]->count(); ++k) {
  test_score.push_back(result_vec[k]);
  test_score_output_id.push_back(j);
}
}
} else {
int idx = 0;
for (int j = 0; j < result.size(); ++j) {
const Dtype* result_vec = result[j]->cpu_data();
  for (int k = 0; k < result[j]->count(); ++k) {
    // on later test iterations, accumulate each output value at its
    // corresponding position
    test_score[idx++] += result_vec[k];
  }
}
}
}
if (requested_early_exit_) {
LOG(INFO) << "Test interrupted.";
return;
}
if (param_.test_compute_loss()) {
loss /= param_.test_iter(test_net_id);  // compute the mean loss and print it
LOG(INFO) << "Test loss: " << loss;
}
for (int i = 0; i < test_score.size(); ++i) {
const int output_blob_index =
test_net->output_blob_indices()[test_score_output_id[i]];
const string& output_name = test_net->blob_names()[output_blob_index];
const Dtype loss_weight = test_net->blob_loss_weights()[output_blob_index];
ostringstream loss_msg_stream;
const Dtype mean_score = test_score[i] / param_.test_iter(test_net_id);
if (loss_weight) {
loss_msg_stream << " (* " << loss_weight
<< " = " << loss_weight * mean_score << " loss)";
}
LOG(INFO) << " Test net output #" << i << ": " << output_name << " = "
<< mean_score << loss_msg_stream.str();
}
}
template <typename Dtype>
void Solver<Dtype>::Snapshot() {  // save a snapshot in the configured format
CHECK(Caffe::root_solver());
string model_filename;
switch (param_.snapshot_format()) {
case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:
model_filename = SnapshotToBinaryProto();
break;
case caffe::SolverParameter_SnapshotFormat_HDF5:
model_filename = SnapshotToHDF5();
break;
default:
LOG(FATAL) << "Unsupported snapshot format.";
}
SnapshotSolverState(model_filename);
}
template <typename Dtype>
void Solver<Dtype>::CheckSnapshotWritePermissions() {  // check that snapshots can be written
if (Caffe::root_solver() && param_.snapshot()) {
CHECK(param_.has_snapshot_prefix())
<< "In solver params, snapshot is specified but snapshot_prefix is not";
string probe_filename = SnapshotFilename(".tempfile");
std::ofstream probe_ofs(probe_filename.c_str());
if (probe_ofs.good()) {
probe_ofs.close();
std::remove(probe_filename.c_str());
} else {
LOG(FATAL) << "Cannot write to snapshot prefix '"
<< param_.snapshot_prefix() << "'. Make sure "
<< "that the directory exists and is writeable.";
}
}
}
template <typename Dtype>
string Solver<Dtype>::SnapshotFilename(const string extension) {  // build the snapshot file name
return param_.snapshot_prefix() + "_iter_" + caffe::format_int(iter_)
+ extension;
}
template <typename Dtype>
string Solver<Dtype>::SnapshotToBinaryProto() {
string model_filename = SnapshotFilename(".caffemodel");
LOG(INFO) << "Snapshotting to binary proto file " << model_filename;
NetParameter net_param;
net_->ToProto(&net_param, param_.snapshot_diff());
WriteProtoToBinaryFile(net_param, model_filename);
return model_filename;
}
template <typename Dtype>
string Solver<Dtype>::SnapshotToHDF5() {
string model_filename = SnapshotFilename(".caffemodel.h5");
LOG(INFO) << "Snapshotting to HDF5 file " << model_filename;
net_->ToHDF5(model_filename, param_.snapshot_diff());
return model_filename;
}
template <typename Dtype>
void Solver<Dtype>::Restore(const char* state_file) {
string state_filename(state_file);
if (state_filename.size() >= 3 &&
state_filename.compare(state_filename.size() - 3, 3, ".h5") == 0) {
RestoreSolverStateFromHDF5(state_filename);
} else {
RestoreSolverStateFromBinaryProto(state_filename);
}
}
template <typename Dtype>
void Solver<Dtype>::UpdateSmoothedLoss(Dtype loss, int start_iter,
int average_loss) {
if (losses_.size() < average_loss) {
losses_.push_back(loss);
int size = losses_.size();
smoothed_loss_ = (smoothed_loss_ * (size - 1) + loss) / size;
} else {
int idx = (iter_ - start_iter) % average_loss;
smoothed_loss_ += (loss - losses_[idx]) / average_loss;
losses_[idx] = loss;
}
}
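Numerically: with average_loss = 3 and per-iteration losses 2.0, 4.0, 6.0, 8.0, the first branch (buffer still filling) yields smoothed values 2.0, 3.0, 4.0; at the fourth iteration idx = 3 % 3 = 0, so the oldest entry (2.0) is replaced by 8.0 and smoothed_loss_ becomes 4.0 + (8.0 - 2.0)/3 = 6.0, which is exactly the mean of the last three losses, (4 + 6 + 8)/3.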
INSTANTIATE_CLASS(Solver);
} // namespace caffe
src/caffe/solvers/sgd_solver.cpp
#include <string>
#include <vector>
#include "caffe/sgd_solvers.hpp"
#include "caffe/util/hdf5.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/upgrade_proto.hpp"
namespace caffe {
// Return the current learning rate. The currently implemented learning rate
// policies are as follows:
// - fixed: always return base_lr.
// - step: return base_lr * gamma ^ (floor(iter / step))
// - exp: return base_lr * gamma ^ iter
// - inv: return base_lr * (1 + gamma * iter) ^ (- power)
// - multistep: similar to step but it allows non uniform steps defined by
// stepvalue
// - poly: the effective learning rate follows a polynomial decay, to be
// zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power)
// - sigmoid: the effective learning rate follows a sigmoid decay
// return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))
//
// where base_lr, max_iter, gamma, step, stepvalue and power are defined
// in the solver parameter protocol buffer, and iter is the current iteration.
template <typename Dtype>
Dtype SGDSolver<Dtype>::GetLearningRate() {
Dtype rate;
const string& lr_policy = this->param_.lr_policy();
if (lr_policy == "fixed") {
rate = this->param_.base_lr();
} else if (lr_policy == "step") {
this->current_step_ = this->iter_ / this->param_.stepsize();
rate = this->param_.base_lr() *
pow(this->param_.gamma(), this->current_step_);
} else if (lr_policy == "exp") {
rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_);
} else if (lr_policy == "inv") {
rate = this->param_.base_lr() *
pow(Dtype(1) + this->param_.gamma() * this->iter_,
- this->param_.power());
} else if (lr_policy == "multistep") {
if (this->current_step_ < this->param_.stepvalue_size() &&
this->iter_ >= this->param_.stepvalue(this->current_step_)) {
this->current_step_++;
LOG(INFO) << "MultiStep Status: Iteration " <<
this->iter_ << ", step = " << this->current_step_;
}
rate = this->param_.base_lr() *
pow(this->param_.gamma(), this->current_step_);
} else if (lr_policy == "poly") {
rate = this->param_.base_lr() * pow(Dtype(1.) -
(Dtype(this->iter_) / Dtype(this->param_.max_iter())),
this->param_.power());
} else if (lr_policy == "sigmoid") {
rate = this->param_.base_lr() * (Dtype(1.) /
(Dtype(1.) + exp(-this->param_.gamma() * (Dtype(this->iter_) -
Dtype(this->param_.stepsize())))));
} else {
LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
}
return rate;
}
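As a worked example of the "step" policy: with base_lr = 0.01, gamma = 0.1 and stepsize = 100000, iteration 250000 gives current_step = 250000 / 100000 = 2, so rate = 0.01 * 0.1^2 = 1e-4. The "multistep" branch computes the same kind of staircase, but steps at the user-listed stepvalue iterations instead of at uniform intervals.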
template <typename Dtype>
void SGDSolver<Dtype>::PreSolve() {
// Initialize the history
const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
history_.clear();
update_.clear();
temp_.clear();
for (int i = 0; i < net_params.size(); ++i) {
const vector<int>& shape = net_params[i]->shape();
history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
update_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
temp_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(shape)));
}
}
template <typename Dtype>
void SGDSolver<Dtype>::ClipGradients() {
const Dtype clip_gradients = this->param_.clip_gradients();
if (clip_gradients < 0) { return; }
const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();
Dtype sumsq_diff = 0;
for (int i = 0; i < net_params.size(); ++i) {
sumsq_diff += net_params[i]->sumsq_diff();
}
const Dtype l2norm_diff = std::sqrt(sumsq_diff);
if (l2norm_diff > clip_gradients) {
Dtype scale_factor = clip_gradients / l2norm_diff;
LOG(INFO) << "Gradient clipping: scaling down gradients (L2 norm "
<< l2norm_diff << " > " << clip_gradients << ") "
<< "by scale factor " << scale_factor;
for (int i = 0; i < net_params.size(); ++i) {
net_params[i]->scale_diff(scale_factor);
}
}
}
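For example, if clip_gradients = 10 and the accumulated sumsq_diff is 400, then l2norm_diff = 20 > 10, so every parameter's diff is scaled by 10/20 = 0.5, leaving the clipped gradient with an L2 norm of exactly clip_gradients; when the norm is already below the threshold, the gradients are left untouched.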
// ApplyUpdate()
template <typename Dtype>
void SGDSolver<Dtype>::ApplyUpdate() {
  Dtype rate = GetLearningRate();  // compute this iteration's learning rate according to the configured lr policy
  if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
    // decide whether to print the current learning rate
    LOG_IF(INFO, Caffe::root_solver()) << "Iteration " << this->iter_
        << ", lr = " << rate;
  }
  ClipGradients();  // guard against exploding gradients: if the L2 norm of the gradients exceeds a threshold, scale them all down
  // process every learnable network parameter
  for (int param_id = 0; param_id < this->net_->learnable_params().size();
       ++param_id) {
    Normalize(param_id);  // divide parameter param_id's diff by iter_size, making the effective batch size iter_size times the configured batch size
    Regularize(param_id);  // add the regularization term's gradient into this parameter's gradient
    ComputeUpdateValue(param_id, rate);  // compute the SGD update value
  }
  this->net_->Update();  // have the net apply the updates to all parameters
}
// Normalize
template <typename Dtype>
void SGDSolver<Dtype>::Normalize(int param_id) {
  if (this->param_.iter_size() == 1) { return; }  // nothing to do when iter_size is 1
  // Scale gradient to counterbalance accumulation.
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();  // all learnable parameters, as a vector returned by the net
  const Dtype accum_normalization = Dtype(1.) / this->param_.iter_size();  // the scaling factor is 1 / iter_size
switch (Caffe::mode()) {
case Caffe::CPU: {
caffe_scal(net_params[param_id]->count(), accum_normalization,
net_params[param_id]->mutable_cpu_diff());
break;
// caffe_scal, defined in src/caffe/util/math_functions.cpp, wraps the BLAS
// scal routine; its arguments are the element count, the scaling
// coefficient, and the data pointer
}
case Caffe::GPU: {
#ifndef CPU_ONLY
caffe_gpu_scal(net_params[param_id]->count(), accum_normalization,
net_params[param_id]->mutable_gpu_diff());
#else
NO_GPU;
#endif
break;
}
default:
LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
}
}
// Regularize
template <typename Dtype>
void SGDSolver<Dtype>::Regularize(int param_id) {
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();  // all learnable parameters
  const vector<float>& net_params_weight_decay =
      this->net_->params_weight_decay();  // each parameter's weight-decay multiplier
  Dtype weight_decay = this->param_.weight_decay();  // the model-wide weight decay
  string regularization_type = this->param_.regularization_type();  // the regularization type
  Dtype local_decay = weight_decay * net_params_weight_decay[param_id];  // the effective decay is the global decay times the per-parameter multiplier
switch (Caffe::mode()) {
case Caffe::CPU: {
if (local_decay) {
if (regularization_type == "L2") {
        // the L2 gradient is diff_ = local_decay * data_ + diff_
        caffe_axpy(net_params[param_id]->count(),
            local_decay,
            net_params[param_id]->cpu_data(),
            net_params[param_id]->mutable_cpu_diff());
        // caffe_axpy computes y = a*x + y, here diff_ = local_decay * data_ + diff_;
        // its arguments are the element count, a, the x pointer, and the y pointer
} else if (regularization_type == "L1") {
        // the L1 gradient is diff_ = diff_ + local_decay * sign(data_)
        caffe_cpu_sign(net_params[param_id]->count(),
            net_params[param_id]->cpu_data(),
            temp_[param_id]->mutable_cpu_data());  // temp_ = sign(data_)
        caffe_axpy(net_params[param_id]->count(),
            local_decay,
            temp_[param_id]->cpu_data(),
            net_params[param_id]->mutable_cpu_diff());  // add temp_ into diff_: diff_ = local_decay * temp_ + diff_
} else {
LOG(FATAL) << "Unknown regularization type: " << regularization_type;
}
}
break;
}
case Caffe::GPU: {
#ifndef CPU_ONLY
if (local_decay) {
if (regularization_type == "L2") {
// add weight decay
caffe_gpu_axpy(net_params[param_id]->count(),
local_decay,
net_params[param_id]->gpu_data(),
net_params[param_id]->mutable_gpu_diff());
} else if (regularization_type == "L1") {
caffe_gpu_sign(net_params[param_id]->count(),
net_params[param_id]->gpu_data(),
temp_[param_id]->mutable_gpu_data());
caffe_gpu_axpy(net_params[param_id]->count(),
local_decay,
temp_[param_id]->gpu_data(),
net_params[param_id]->mutable_gpu_diff());
} else {
LOG(FATAL) << "Unknown regularization type: " << regularization_type;
}
}
#else
NO_GPU;
#endif
break;
}
default:
LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
}
}
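Both branches are the gradient of a weight-decay term added to the loss: for L2 the objective is E(w) = L(w) + (λ/2)·||w||₂², so ∂E/∂w = ∂L/∂w + λ·w (a single caffe_axpy); for L1 it is E(w) = L(w) + λ·||w||₁, so ∂E/∂w = ∂L/∂w + λ·sign(w), which needs the extra caffe_cpu_sign/caffe_gpu_sign pass into temp_ first. Here λ is local_decay.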
#ifndef CPU_ONLY
template <typename Dtype>
void sgd_update_gpu(int N, Dtype* g, Dtype* h, Dtype momentum,
Dtype local_rate);
#endif
// ComputeUpdateValue
template <typename Dtype>
void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
  const vector<Blob<Dtype>*>& net_params = this->net_->learnable_params();  // all learnable parameters
  const vector<float>& net_params_lr = this->net_->params_lr();  // each parameter's learning-rate multiplier
  Dtype momentum = this->param_.momentum();  // the momentum value
  Dtype local_rate = rate * net_params_lr[param_id];  // the effective learning rate is the global rate times the per-parameter multiplier
// Compute the update to history, then copy it to the parameter diff.
switch (Caffe::mode()) {
case Caffe::CPU: {
    // history_ stores the update from the previous iteration
    caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
        net_params[param_id]->cpu_diff(), momentum,
        history_[param_id]->mutable_cpu_data());  // history_ = local_rate * diff_ + momentum * history_
    caffe_copy(net_params[param_id]->count(),
        history_[param_id]->cpu_data(),
        net_params[param_id]->mutable_cpu_diff());  // copy the computed update back into the parameter blob's diff_
break;
}
case Caffe::GPU: {
#ifndef CPU_ONLY
sgd_update_gpu(net_params[param_id]->count(),
net_params[param_id]->mutable_gpu_diff(),
history_[param_id]->mutable_gpu_data(),
momentum, local_rate);
#else
NO_GPU;
#endif
break;
}
default:
LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode();
}
}
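Written out, ComputeUpdateValue implements classical momentum SGD: v(t) = momentum * v(t-1) + local_rate * grad, kept in history_; the parameter's diff_ is then overwritten with v(t), and the net_->Update() call at the end of ApplyUpdate() performs w = w - v(t).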
template <typename Dtype>
void SGDSolver<Dtype>::SnapshotSolverState(const string& model_filename) {
switch (this->param_.snapshot_format()) {
case caffe::SolverParameter_SnapshotFormat_BINARYPROTO:
SnapshotSolverStateToBinaryProto(model_filename);
break;
case caffe::SolverParameter_SnapshotFormat_HDF5:
SnapshotSolverStateToHDF5(model_filename);
break;
default:
LOG(FATAL) << "Unsupported snapshot format.";
}
}
template <typename Dtype>
void SGDSolver<Dtype>::SnapshotSolverStateToBinaryProto(
const string& model_filename) {
SolverState state;
state.set_iter(this->iter_);
state.set_learned_net(model_filename);
state.set_current_step(this->current_step_);
state.clear_history();
for (int i = 0; i < history_.size(); ++i) {
// Add history
BlobProto* history_blob = state.add_history();
history_[i]->ToProto(history_blob);
}
string snapshot_filename = Solver<Dtype>::SnapshotFilename(".solverstate");
LOG(INFO)
<< "Snapshotting solver state to binary proto file " << snapshot_filename;
WriteProtoToBinaryFile(state, snapshot_filename.c_str());
}
template <typename Dtype>
void SGDSolver<Dtype>::SnapshotSolverStateToHDF5(
const string& model_filename) {
string snapshot_filename =
Solver<Dtype>::SnapshotFilename(".solverstate.h5");
LOG(INFO) << "Snapshotting solver state to HDF5 file " << snapshot_filename;
hid_t file_hid = H5Fcreate(snapshot_filename.c_str(), H5F_ACC_TRUNC,
H5P_DEFAULT, H5P_DEFAULT);
CHECK_GE(file_hid, 0)
<< "Couldn't open " << snapshot_filename << " to save solver state.";
hdf5_save_int(file_hid, "iter", this->iter_);
hdf5_save_string(file_hid, "learned_net", model_filename);
hdf5_save_int(file_hid, "current_step", this->current_step_);
hid_t history_hid = H5Gcreate2(file_hid, "history", H5P_DEFAULT, H5P_DEFAULT,
H5P_DEFAULT);
CHECK_GE(history_hid, 0)
<< "Error saving solver state to " << snapshot_filename << ".";
for (int i = 0; i < history_.size(); ++i) {
ostringstream oss;
oss << i;
hdf5_save_nd_dataset<Dtype>(history_hid, oss.str(), *history_[i]);
}
H5Gclose(history_hid);
H5Fclose(file_hid);
}
template <typename Dtype>
void SGDSolver<Dtype>::RestoreSolverStateFromBinaryProto(
const string& state_file) {
SolverState state;
ReadProtoFromBinaryFile(state_file, &state);
this->iter_ = state.iter();
if (state.has_learned_net()) {
NetParameter net_param;
ReadNetParamsFromBinaryFileOrDie(state.learned_net().c_str(), &net_param);
this->net_->CopyTrainedLayersFrom(net_param);
}
this->current_step_ = state.current_step();
CHECK_EQ(state.history_size(), history_.size())
<< "Incorrect length of history blobs.";
LOG(INFO) << "SGDSolver: restoring history";
for (int i = 0; i < history_.size(); ++i) {
history_[i]->FromProto(state.history(i));
}
}
template <typename Dtype>
void SGDSolver<Dtype>::RestoreSolverStateFromHDF5(const string& state_file) {
hid_t file_hid = H5Fopen(state_file.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT);
CHECK_GE(file_hid, 0) << "Couldn't open solver state file " << state_file;
this->iter_ = hdf5_load_int(file_hid, "iter");
if (H5LTfind_dataset(file_hid, "learned_net")) {
string learned_net = hdf5_load_string(file_hid, "learned_net");
this->net_->CopyTrainedLayersFrom(learned_net);
}
this->current_step_ = hdf5_load_int(file_hid, "current_step");
hid_t history_hid = H5Gopen2(file_hid, "history", H5P_DEFAULT);
CHECK_GE(history_hid, 0) << "Error reading history from " << state_file;
int state_history_size = hdf5_get_num_links(history_hid);
CHECK_EQ(state_history_size, history_.size())
<< "Incorrect length of history blobs.";
for (int i = 0; i < history_.size(); ++i) {
ostringstream oss;
oss << i;
hdf5_load_nd_dataset<Dtype>(history_hid, oss.str().c_str(), 0,
kMaxBlobAxes, history_[i].get());
}
H5Gclose(history_hid);
H5Fclose(file_hid);
}
INSTANTIATE_CLASS(SGDSolver);
REGISTER_SOLVER_CLASS(SGD);  // at the end of the file, the macro call completes registration
} // namespace caffe