Caffe 之 Python 接口写网络

转载:https://blog.csdn.net/tiewadhd/article/details/54343349

# Root of the local Caffe checkout (hard-coded to the author's machine).
caffe_root='/home/cheam/caffe-master/'
import subprocess
import sys
import os
# NOTE(review): np, h5py and pylab are not used by the code visible here.
import numpy as np
import h5py
# Put Caffe's Python bindings on the import path; must run before `import caffe`.
sys.path.insert(0,caffe_root+'python')
from pylab import *
import caffe
from caffe.proto import caffe_pb2
# Caffe NetSpec helpers: L constructs layers, P exposes layer parameter constants.
from caffe import layers as L,params as P

# Define the network structure
def dsn_net(lmdb, batch_size, train=True):
    """Build the DSN classification network and return it as a NetParameter.

    Args:
        lmdb: value passed as the HDF5Data layer's ``source``. Despite the
            name this is not an LMDB path; Caffe's HDF5Data layer reads it as
            a text file listing HDF5 files.
        batch_size: HDF5Data batch size.
        train: when exactly ``False``, an Accuracy layer is appended so the
            generated net can serve as the test net.

    Returns:
        The ``caffe_pb2.NetParameter`` produced by ``NetSpec.to_proto()``.
    """
    def fc(bottom):
        # Every fully connected layer here is 50 wide with Xavier init.
        return L.InnerProduct(bottom, num_output=50,
                              weight_filler=dict(type='xavier'))

    def prelu(bottom):
        # Activations overwrite their input blob (in-place).
        return L.PReLU(bottom, in_place=True)

    net = caffe.NetSpec()
    net.data, net.label = L.HDF5Data(batch_size=batch_size, source=lmdb,
                                     shuffle=True, ntop=2)

    # Stage 1: a three-layer chain on the raw input.
    net.W1 = fc(net.data)
    net.relu1 = prelu(net.W1)
    net.U1 = fc(net.relu1)
    net.relu2 = prelu(net.U1)
    net.Wrand = fc(net.relu2)
    net.relu3 = prelu(net.Wrand)

    # Stage 2: fresh projection of the input, concatenated with stage 1.
    net.W2 = fc(net.data)
    net.relu4 = prelu(net.W2)
    net.out1 = L.Concat(net.Wrand, net.W2, concat_param=dict(axis=1))
    net.U2 = fc(net.out1)
    net.relu5 = prelu(net.U2)
    net.Wrand1 = fc(net.relu5)
    net.relu6 = prelu(net.Wrand1)
    net.Wrand2 = fc(net.relu6)
    net.relu7 = prelu(net.Wrand2)

    # Stage 3: another input projection joined with both stage-2 outputs.
    net.W3 = fc(net.data)
    net.relu8 = prelu(net.W3)
    net.out2 = L.Concat(net.Wrand1, net.Wrand2, net.W3,
                        concat_param=dict(axis=1))
    net.U3 = fc(net.out2)
    net.relu9 = prelu(net.U3)
    net.drop3 = L.Dropout(net.relu9, dropout_ratio=0.1, in_place=True)

    # Classifier head. relu10 runs in place on Y, so the loss (and accuracy)
    # consume the PReLU-activated scores.
    # NOTE(review): 50 outputs act as the class count; confirm labels are < 50.
    net.Y = fc(net.drop3)
    net.relu10 = prelu(net.Y)
    net.loss = L.SoftmaxWithLoss(net.Y, net.label)

    is_test_net = train is False
    if is_test_net:
        net.acc = L.Accuracy(net.Y, net.label)
    return net.to_proto()

# Define the network training (solver) parameters
def train_dsn_solver(base_dir=None):
    """Build the Caffe SolverParameter used to train the DSN net.

    Args:
        base_dir: directory prefix where the train/test net prototxt files
            live and under which snapshots are written. Paths are built by
            plain string concatenation, so it should end with a path
            separator. Defaults to the module-level ``file_name`` (set in the
            ``__main__`` block) to stay compatible with existing callers.

    Returns:
        A populated ``caffe_pb2.SolverParameter``.
    """
    # Only touch the global when no explicit directory is given, so the
    # function also works when the module is imported rather than run.
    prefix = file_name if base_dir is None else base_dir

    s = caffe_pb2.SolverParameter()
    # For reproducible runs, set s.random_seed (e.g. 0xCAFFE).
    s.train_net = prefix + 'dsn_train_file'
    s.test_net.append(prefix + 'dsn_test_file')
    s.test_interval = 10000
    # One test iteration; the test net's batch size is chosen by the caller.
    s.test_iter.append(1)
    s.max_iter = 3000000
    s.type = "AdaGrad"
    s.base_lr = 0.01
    s.weight_decay = 0.0005
    s.regularization_type = "L2"
    # multistep: the learning rate is multiplied by gamma at each stepvalue.
    s.lr_policy = 'multistep'
    s.gamma = 0.5
    s.stepvalue.append(2000000)
    s.display = 10000
    s.snapshot = 10000
    # Caffe refuses to start if this prefix's directory is not writable.
    s.snapshot_prefix = prefix + 'indian_pines/indian_pines'
    s.snapshot_format = caffe_pb2.SolverParameter.HDF5  # enum value 0
    s.debug_info = False
    s.solver_mode = caffe_pb2.SolverParameter.GPU
    return s

def run_command():
    """Write a bash job script that runs ``caffe train`` and execute it.

    Reads the module-level ``caffe_root`` and ``file_name`` (the latter is set
    in ``__main__`` and ends with '/'). The script trains with
    ``<file_name>solver_file`` on GPU 0 and tees Caffe's combined
    stdout/stderr into a log file under ``file_name``.

    Raises:
        subprocess.CalledProcessError: if the job script exits non-zero,
            including when ``caffe train`` itself fails.
    """
    job_file = 'dsn_job_0.sh'
    gpus = 0
    job_dir = file_name
    # file_name already ends with '/', so do not insert another separator.
    solver_file = file_name + 'solver_file'
    model_name = 'dsn_net'
    log_file = os.path.join(job_dir,
                            '{}_{}_{}.log'.format(model_name, 'dsn', 2))
    caffe_bin = caffe_root + 'build/tools/caffe'
    with open(job_file, 'w') as f:
        # Without pipefail the pipeline's exit status is tee's, so a failed
        # `caffe train` would still look like success to check_call.
        f.write('set -o pipefail\n')
        # Quote every path so directories containing spaces do not split
        # the command line.
        f.write('"{}" train \\\n'.format(caffe_bin))
        f.write('--solver="{}" \\\n'.format(solver_file))
        f.write('--gpu {} 2>&1 | tee "{}"\n'.format(gpus, log_file))
    # Invoke bash directly with an argument list; no extra shell is needed.
    subprocess.check_call(['bash', job_file])
if __name__ == '__main__':
    # Assigning at module scope makes this the global the helper functions
    # read; a `global` statement has no effect outside a function.
    file_name = os.getcwd() + '/'
    with open(file_name + 'dsn_train_file', 'w') as f:
        f.write(str(dsn_net(file_name + 'indian_pines_train.txt', 32)))
    with open(file_name + 'dsn_test_file', 'w') as f:
        f.write(str(dsn_net(file_name + 'indian_pines_test.txt', 6904, False)))
    solver = train_dsn_solver()
    # Caffe aborts at startup ("Cannot write to snapshot prefix") when the
    # snapshot directory is missing, so create it from the solver's prefix.
    snapshot_dir = os.path.dirname(solver.snapshot_prefix)
    if snapshot_dir and not os.path.isdir(snapshot_dir):
        os.makedirs(snapshot_dir)
    with open(file_name + 'solver_file', 'w') as f:
        f.write(str(solver))
    run_command()
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值