nodes.h 简介
这个类保存了一系列node的指针,提供前向和后向传播的操作。node是tiny-dnn的计算单元,目前有两种实现sequential 和 graph。
nodes 接受 lvalue 、 rvalue 、share_ptr 类型的node.
如果给定类型是rvalue 或者 share_ptr,nodes 会创建shared_ptr< node> 来保持给定的node alive。
sequential
void backward(const std::vector<tensor_t> &first) override {
std::vector<std::vector<const vec_t *>> reordered_grad;
reorder_for_layerwise_processing(first, reordered_grad);
assert(reordered_grad.size() == 1);
nodes_.back()->set_out_grads(&reordered_grad[0], 1);
for (auto l = nodes_.rbegin(); l != nodes_.rend(); l++) {
(*l)->backward();
}
}
std::vector<tensor_t> forward(const std::vector<tensor_t> &first) override {
std::vector<std::vector<const vec_t *>> reordered_data;
reorder_for_layerwise_processing(first, reordered_data);
assert(reordered_data.size() == 1);
nodes_.front()->set_in_data(&reordered_data[0], 1);
for (auto l : nodes_) {
l->forward();
}
std::vector<const tensor_t *> out;
nodes_.back()->output(out);
return normalize_out(out);
}