Caffe-Python接口常用API参考

本文整理了pycaffe中常用的API,转载自:
http://wentaoma.com/2016/08/10/caffe-python-common-api-reference/

Packages导入

import caffe
from caffe import layers as L
from caffe import params as P

Layers定义 ,Data层定义

lmdb/leveldb Data层定义

L.Data( 
        source=lmdb,
        backend=P.Data.LMDB,
        batch_size=batch_size, ntop=2,
        transform_param=dict(
                              crop_size=227,
                              mean_value=[104, 117, 123],
                              mirror=True
                              )
        )

HDF5 Data层定义

L.HDF5Data(
            hdf5_data_param={
                            'source': './training_data_paths.txt',  
                            'batch_size': 64
                            },
            include={
                    'phase': caffe.TRAIN
                    }
            )

ImageData Data层定义(适用于txt文件一行记录一张图片的数据源)

L.ImageData(
                source=list_path,
                batch_size=batch_size,
                new_width=48,
                new_height=48,
                ntop=2,

                ransform_param=dict(crop_size=40,mirror=True)
                )

Convloution层定义

L.Convolution(  
                bottom, 
                kernel_size=ks, 
                stride=stride,
                num_output=nout, 
                pad=pad, 
                group=group,
                param=[dict(lr_mult=1,decay_mult=1),dict(lr_mult=2,decay_mult=0)])
                )

LRN层定义

L.LRN(
        bottom, 
        local_size=5, 
        alpha=1e-4, 
        beta=0.75
        )

Activation层定义

ReLU层定义

L.ReLU(
        bottom, 
        in_place=True
        )

Pooling层定义

L.Pooling(
            bottom,
            pool=P.Pooling.MAX, 
            kernel_size=ks, 
            stride=stride
            )

FullConnect层定义

L.Dropout(
            bottom, 
            in_place=True
            )

Dropout层定义

L.Dropout(
            bottom, 
            in_place=True
            )

Loss层定义

L.SoftmaxWithLoss(
                    bottom, 
                    label
                    )

Accuracy层定义

L.Accuracy(
            bottom,
            label
            )

转换为proto文本

caffe.to_proto(
                loss, 
                acc     #训练阶段可以删去Accuracy层
                )

Solver定义

from caffe.proto import caffe_pb2
s = caffe_pb2.SolverParameter()
path='/home/xxx/data/'
solver_file=path+'solver.prototxt'     #solver文件保存位置
s.train_net = path+'train.prototxt'     # 训练配置文件
s.test_net.append(path+'val.prototxt')  # 测试配置文件
s.test_interval = 782                   # 测试间隔
s.test_iter.append(313)                 # 测试迭代次数
s.max_iter = 78200                      # 最大迭代次数
s.base_lr = 0.001                       # 基础学习率
s.momentum = 0.9                        # momentum系数
s.weight_decay = 5e-4                   # 权值衰减系数
s.lr_policy = 'step'                    # 学习率衰减方法
s.stepsize=26067                        # 此值仅对step方法有效
s.gamma = 0.1                           # 学习率衰减指数
s.display = 782                         # 屏幕日志显示间隔
s.snapshot = 7820
s.snapshot_prefix = 'shapshot'
s.type = “SGD”                          # 优化算法
s.solver_mode = caffe_pb2.SolverParameter.GPU
with open(solver_file, 'w') as f:
    f.write(str(s))

Model训练

# 训练设置
# 使用GPU
caffe.set_device(gpu_id) # 若不设置,默认为0
caffe.set_mode_gpu()
# 使用CPU
caffe.set_mode_cpu()
# 加载Solver,有两种常用方法
# 1. 无论模型中Slover类型是什么统一设置为SGD
solver = caffe.SGDSolver('/home/xxx/data/solver.prototxt') 
# 2. 根据solver的prototxt中solver_type读取,默认为SGD
solver = caffe.get_solver('/home/xxx/data/solver.prototxt')
# 训练模型
# 1.1 前向传播
solver.net.forward()  # train net
solver.test_nets[0].forward()  # test net (there can be more than one)
# 1.2 反向传播,计算梯度
solver.net.backward()
# 2. 进行一次前向传播一次反向传播并根据梯度更新参数
solver.step(1)
# 3. 根据solver文件中设置进行完整model训练
solver.solve()

如果想在训练过程中保存模型参数,调用

solver.net.save('mymodel.caffemodel')
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值