C++11 使用CRTP和组合模式模拟神经网络

本例来自《Design Pattern in Modern C++》一书的源码。
我做了两处修改。

  1. 在CRTP基类SomeNeurons中创建连接的for循环处,如果直接以Neuron类型作为循环变量的类型,需要前置声明Neuron,而此时Neuron还是不完整类型,会产生 incomplete type 警告。我将其改为了auto&。
  2. NeuronLayer类在原书附带的代码中直接继承自std::vector<Neuron>,这不太合适:std::vector的析构函数不是虚函数,如果通过std::vector指针delete一个NeuronLayer对象,子类的析构函数不会被调用(这是未定义行为),子类持有的堆资源可能因此泄漏。我将其改为组合std::vector的方式,并添加了两组begin/end函数:一组是const版本,一组是非const版本。

程序目录结构如下:
(图:程序结构图)

test/CMakeLists.txt

# CMAKE_CXX_STANDARD 20 (set below) is only recognised by CMake >= 3.12, and
# CMake 4.x refuses projects declaring compatibility with CMake < 3.5, so the
# minimum must be at least 3.12 for this file to be honoured as written.
cmake_minimum_required(VERSION 3.12)

# Tell find_package() where the vcpkg-installed package configs live.
# NOTE(review): these are machine-specific absolute paths.
if(APPLE)
   message(STATUS "This is Apple, do nothing.")
   set(CMAKE_MACOSX_RPATH 1)
   set(CMAKE_PREFIX_PATH /Users/aabjfzhu/software/vcpkg/ports/cppwork/vcpkg_installed/x64-osx/share )
elseif(UNIX)
   message(STATUS "This is linux, set CMAKE_PREFIX_PATH.")
   set(CMAKE_PREFIX_PATH /vcpkg/ports/cppwork/vcpkg_installed/x64-linux/share)
endif()

project(neural_network)

set(CMAKE_CXX_STANDARD 20)
# Narrowing-conversion diagnostics are suppressed for the whole project.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing")

# Always emit debug info. add_compile_options() is the command for compiler
# flags; add_definitions() is meant for -D preprocessor definitions.
add_compile_options(-g)

# Third-party dependencies, resolved through CMAKE_PREFIX_PATH (vcpkg).
# The original order is kept: package config scripts may rely on packages
# located earlier.

# ZLIB::ZLIB is linked unconditionally into ${PROJECT_NAME}_lib, so a missing
# ZLIB must stop configuration here instead of failing later as an unknown
# target or an unresolved library at link time.
find_package(ZLIB REQUIRED)

find_package(OpenCV REQUIRED)
find_package(Arrow CONFIG REQUIRED)

# Linked explicitly (as static brotli targets, utf8proc, thrift) into the
# library target.
find_package(unofficial-brotli REQUIRED)
find_package(unofficial-utf8proc CONFIG REQUIRED)
find_package(Thrift CONFIG REQUIRED)

find_package(glog REQUIRED)

find_package(OpenSSL REQUIRED)

find_package(Boost REQUIRED COMPONENTS
   system
   filesystem
   serialization
   program_options
   thread
   )

find_package(DataFrame REQUIRED)

# Per-platform include directories, library search paths and ODBC libraries,
# grouped so that each platform's settings sit together.
if(APPLE)
   message(STATUS "This is APPLE, set INCLUDE_DIRS")
   set(INCLUDE_DIRS ${Boost_INCLUDE_DIRS} /usr/local/include /usr/local/iODBC/include /opt/snowflake/snowflakeodbc/include/ ${CMAKE_CURRENT_SOURCE_DIR}/../include/ ${CMAKE_CURRENT_SOURCE_DIR}/../../../include)

   message(STATUS "This is APPLE, set LINK_DIRS")
   set(LINK_DIRS /usr/local/lib /usr/local/iODBC/lib /opt/snowflake/snowflakeodbc/lib/universal)

   message(STATUS "This is APPLE, set ODBC_LIBS")
   set(ODBC_LIBS iodbc iodbcinst)
elseif(UNIX)
   message(STATUS "This is linux, set INCLUDE_DIRS")
   set(INCLUDE_DIRS ${Boost_INCLUDE_DIRS} /usr/local/include ${CMAKE_CURRENT_SOURCE_DIR}/../include/   ${CMAKE_CURRENT_SOURCE_DIR}/../../../include/)

   message(STATUS "This is linux, set LINK_DIRS")
   # Library search paths only; header directories such as
   # Boost_INCLUDE_DIRS do not belong here.
   set(LINK_DIRS /usr/local/lib /vcpkg/ports/cppwork/vcpkg_installed/x64-linux/lib)

   message(STATUS "This is linux, set ODBC_LIBS")
   set(ODBC_LIBS odbc odbcinst ltdl)
endif()

include_directories(${INCLUDE_DIRS})
link_directories(${LINK_DIRS})

# Every *.cpp directly in this directory becomes its own test executable.
file(GLOB test_file_list ${CMAKE_CURRENT_SOURCE_DIR}/*.cpp)

# Application sources plus the shared helper code from the repository-level
# include/ tree.
file(GLOB APP_SOURCES
   ${CMAKE_CURRENT_SOURCE_DIR}/../impl/*.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/../include/*.h
   ${CMAKE_CURRENT_SOURCE_DIR}/../include/*.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/../../../include/arr_/impl/*.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/../../../include/http/impl/*.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/../../../include/yaml/impl/*.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/../../../include/df/impl/*.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/../../../include/death_handler/impl/*.cpp)

# The library holds only APP_SOURCES. Test sources (which define main(), as
# neural_network_test.cpp does) are built as separate executables below.
add_library(${PROJECT_NAME}_lib SHARED ${APP_SOURCES})
target_link_libraries(${PROJECT_NAME}_lib ${Boost_LIBRARIES} ZLIB::ZLIB glog::glog DataFrame::DataFrame ${OpenCV_LIBS})
target_link_libraries(${PROJECT_NAME}_lib OpenSSL::SSL OpenSSL::Crypto libgtest.a pystring libyaml-cpp.a libgmock.a ${ODBC_LIBS} libnanodbc.a pthread dl backtrace libzstd.a libbz2.a libsnappy.a re2::re2 parquet lz4 unofficial::brotli::brotlidec-static unofficial::brotli::brotlienc-static unofficial::brotli::brotlicommon-static utf8proc thrift::thrift  arrow arrow_dataset)

# One executable per test source, named after the file without ".cpp".
foreach(test_file ${test_file_list})
   file(RELATIVE_PATH filename ${CMAKE_CURRENT_SOURCE_DIR} ${test_file})
   string(REPLACE ".cpp" "" test_name ${filename})
   add_executable(${test_name} ${test_file})
   target_link_libraries(${test_name} ${PROJECT_NAME}_lib)
endforeach()

test/neural_network_test.cpp

#include "death_handler/death_handler.h"
#include <glog/logging.h>
#include "neural.hpp"

#include <utility>
#include <gtest/gtest.h>



int main(int argc, char** argv) {
    FLAGS_log_dir = "./";
    FLAGS_alsologtostderr = true;
    // Log levels INFO, WARNING, ERROR, FATAL are 0, 1, 2, 3; 0 keeps all of them.
    FLAGS_minloglevel = 0;

    // NOTE(review): presumably installs a crash/signal handler for the
    // lifetime of main() — see death_handler.h.
    Debug::DeathHandler dh;

    // The argument is the program name glog uses to build log file names,
    // not a log file path; argv[0] is the conventional value. Files still
    // go to FLAGS_log_dir.
    google::InitGoogleLogging(argv[0]);
    testing::InitGoogleTest(&argc, argv);
    const int ret = RUN_ALL_TESTS();
    google::ShutdownGoogleLogging();
    return ret;
}

GTEST_TEST(NeuralNetworkTests, NeuralNetwork) {
    // Neuron -> Neuron: exactly one directed link n1 -> n2.
    Neuron n1, n2;
    n1.connect_to(n2);
    std::cout << n1 << n2 << std::endl;
    ASSERT_EQ(n1.out.size(), 1u);
    EXPECT_EQ(n1.out[0], &n2);
    ASSERT_EQ(n2.in.size(), 1u);
    EXPECT_EQ(n2.in[0], &n1);
    EXPECT_TRUE(n1.in.empty());
    EXPECT_TRUE(n2.out.empty());

    // Layer -> Neuron: each of the 5 layer neurons feeds n3.
    NeuronLayer l1 {5};
    Neuron n3;
    l1.connect_to(n3);

    std::cout << "Neuron " << n3.id << std::endl << n3 << std::endl;

    std::cout << "Layer " << std::endl << l1 << std::endl;
    EXPECT_EQ(n3.in.size(), 5u);
    for (auto& n : l1) {
        ASSERT_EQ(n.out.size(), 1u);
        EXPECT_EQ(n.out[0], &n3);
    }

    // Layer -> Layer: full bipartite wiring, 2 x 3 links.
    NeuronLayer l2 {2}, l3 {3};
    l2.connect_to(l3);
    std::cout << "Layer l2" << std::endl << l2;
    std::cout << "Layer l3" << std::endl << l3;
    for (auto& n : l2) EXPECT_EQ(n.out.size(), 3u);
    for (auto& n : l3) EXPECT_EQ(n.in.size(), 2u);
}

include/neural.hpp

#ifndef _FREDRIC_NEURAL_HPP_
#define _FREDRIC_NEURAL_HPP_
#include <ostream>
#include <vector>

/// CRTP mixin that adds connect_to() to any "range of neurons".
///
/// Self must derive from SomeNeurons<Self>. Both Self and the argument of
/// connect_to() must be usable in a range-based for loop whose elements
/// expose `in` and `out` containers that accept element pointers.
template <typename Self>
struct SomeNeurons {
    /// Links every element of *this to every element of `other`. For each
    /// (source, target) pair, &target is appended to source.out and &source
    /// is appended to target.in. Only raw addresses are stored, so the linked
    /// elements must stay alive and in place while the links are in use.
    template <typename T>
    void connect_to(T& other) {
        Self& self = static_cast<Self&>(*this);
        for (auto& source : self) {
            for (auto& target : other) {
                source.out.push_back(&target);
                target.in.push_back(&source);
            }
        }
    }
};

/// A single neuron. It is also a one-element range over itself, so
/// SomeNeurons::connect_to() can treat a lone Neuron and a layer uniformly.
struct Neuron: SomeNeurons<Neuron> {
    // Non-owning links to neighbouring neurons (raw addresses). The pointees
    // must stay alive and in place for as long as these links are used.
    // Note: copying a Neuron copies both its id and these links.
    std::vector<Neuron*> in, out;
    unsigned int id;

    /// Assigns the next id from a counter shared by all neurons, starting at 1.
    /// The counter is a plain static, so concurrent construction is not
    /// synchronised.
    Neuron() {
        // Same type as `id`. An unsigned counter wraps around on exhaustion
        // instead of overflowing a signed int (undefined behaviour).
        static unsigned int g_id = 1;
        this->id = g_id++;
    }

    Neuron* begin() {
        return this;
    }

    Neuron* end() {
        return this + 1;
    }

    // const overloads, mirroring NeuronLayer, so a const Neuron can also be
    // iterated as a one-element range.
    const Neuron* begin() const {
        return this;
    }

    const Neuron* end() const {
        return this + 1;
    }

    /// Prints one line per incoming link ("src --> [id]"), then one line per
    /// outgoing link ("id --> [dst]").
    friend std::ostream& operator<<(std::ostream& os, Neuron const& obj) {
        for (const Neuron* n : obj.in) {
            os << n->id << "\t--> \t[" << obj.id << "]\n";
        }

        for (const Neuron* n : obj.out) {
            os << obj.id << "\t--> \t[" << n->id << "]\n";
        }
        return os;
    }
};

/// A layer of neurons held by value. It composes a std::vector<Neuron> rather
/// than inheriting from it, and exposes begin()/end() so that a layer can be
/// the source or the target of SomeNeurons::connect_to().
struct NeuronLayer: SomeNeurons<NeuronLayer> {
    // Links store raw addresses of these elements, so growing this vector
    // after connecting (which may reallocate) would leave dangling pointers.
    std::vector<Neuron> neurons;

    /// Creates `count` freshly numbered neurons. A non-positive count gives
    /// an empty layer.
    NeuronLayer(int count) {
        if (count > 0) {
            neurons.reserve(static_cast<std::vector<Neuron>::size_type>(count));
        }
        // Counting up avoids post-decrementing INT_MIN (signed overflow) that
        // `while (count-- > 0)` would hit for the most negative input.
        for (int i = 0; i < count; ++i) {
            neurons.emplace_back();  // construct in place, no temporary Neuron
        }
    }

    /// Prints every neuron of the layer in order.
    friend std::ostream& operator<<(std::ostream& os, NeuronLayer const& obj) {
        for (auto& n : obj) os << n;
        return os;
    }

    std::vector<Neuron>::const_iterator begin() const { return neurons.begin(); }
    std::vector<Neuron>::const_iterator end() const { return neurons.end(); }

    std::vector<Neuron>::iterator begin() { return neurons.begin(); }
    std::vector<Neuron>::iterator end() { return neurons.end(); }
};
#endif

程序输出如下:
(图:组合模式输出)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
Design Patterns in Modern C++: Reusable Approaches for Object-Oriented Software Design | English | PDF | 2018 | 312 Pages | ISBN: 1484236025
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值