torch

require 'paths';
require 'nn';

---Load TrainSet
paths.filep("/home/xuhang/torch/myfiles/mydata/cifar10torchsmall.zip"); 

trainset = torch.load('/home/xuhang/torch/myfiles/mydata/cifar10-train.t7');
testset = torch.load('/home/xuhang/torch/myfiles/mydata/cifar10-test.t7');
classes = {'airplane', 'automobile', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck'};

---Add size() function and Tensor index operator 
setmetatable(trainset, 
    {__index = function(t, i) 
                    return {t.data[i], t.label[i]} 
                end}
);
trainset.data = trainset.data:double() 

function trainset:size() 
    return self.data:size(1) 
end

---Normalize data
mean = {}
stdv = {}
for i=1,3 do
    mean[i] = trainset.data[{ {}, {i}, {}, {}  }]:mean()
    print('Channel ' .. i .. ', Mean: ' .. mean[i])
    trainset.data[{ {}, {i}, {}, {}  }]:add(-mean[i])

    stdv[i] = trainset.data[{ {}, {i}, {}, {}  }]:std()
    print('Channel ' .. i .. ', Standard Deviation:' .. stdv[i])
    trainset.data[{ {}, {i}, {}, {}  }]:div(stdv[i])
end
net = nn.Sequential()

--change 1 channel to 3 channels
--net:add(nn.SpatialConvolution(1, 6, 5, 5))
net:add(nn.SpatialConvolution(3, 6, 5, 5)) 

net:add(nn.ReLU())                       
net:add(nn.SpatialMaxPooling(2,2,2,2))     
net:add(nn.SpatialConvolution(6, 16, 5, 5))
net:add(nn.ReLU())                       
net:add(nn.SpatialMaxPooling(2,2,2,2))
net:add(nn.View(16*5*5))                    
net:add(nn.Linear(16*5*5, 120))         
net:add(nn.ReLU())                       
net:add(nn.Linear(120, 84))
net:add(nn.ReLU())                       
net:add(nn.Linear(84, 10))                  
net:add(nn.LogSoftMax()) 

criterion = nn.ClassNLLCriterion();

trainer = nn.StochasticGradient(net, criterion)
trainer.learningRate = 0.001
trainer.maxIteration = 5

trainer:train(trainset)

//test

testset.data=testset.data:double();
for i=1,3 do
    testset.data[{ {},{i},{},{} }]:add(-mean[i])
    testset.data[{ {},{i},{},{} }]:div(stdv[i])
end

print(classes[testset.label[100]])
itorch.image(testset.data[100])

predicted=net:forward(testset.data[100])
print(predicted:exp())
--
gailv,label=torch.sort(predicted,true)
print (gailv[1])
print (label[1])
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值