paper | code | video |
---|---|---|
https://arxiv.org/abs/2008.10546 | https://github.com/Lingkai-Kong/SDE-Net | https://www.youtube.com/watch?v=RylZA4Ioc3M |
离散化 (discretization, as in a ResNet step):

$$\Large x_{t+1} = x_t + f(x_t, t)$$
连续化 (continuous limit, an ODE):

$$\Large \mathrm{d}x_t = f(x_t, t)\,\mathrm{d}t$$
加入扩散项后的 SDE:

$$\Large \mathrm{d}x_t = f(x_t, t)\,\mathrm{d}t + g(x_t, t)\,\mathrm{d}w_t$$
Euler–Maruyama 离散化 (注意 $g$ 只在初始状态 $x_0$ 处取值一次):

$$\Large x_{k+1} = x_k + f(x_k, t)\,\delta t + g(x_0)\,\sqrt{\delta t}\; Z_k, \quad Z_k \sim \mathcal{N}(0, I)$$
- 简单图示过程:
python sdenet_mnist.py
Evaluation:
python test_detection.py --pre_trained_net save_sdenet_mnist/final_model --network sdenet --dataset
代码
- 好消息是环境异常简单,直接打开工程,然后添加了一个venv的环境
- 然后安装个torch就能运行了
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
SDENet_mnist(
(downsampling_layers): Sequential(
(0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1))
(1): GroupNorm(32, 64, eps=1e-05, affine=True)
(2): ReLU(inplace=True)
(3): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
(4): GroupNorm(32, 64, eps=1e-05, affine=True)
(5): ReLU(inplace=True)
(6): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
)
(drift): Drift(
(norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
(relu): ReLU(inplace=True)
(conv1): ConcatConv2d(
(_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(conv2): ConcatConv2d(
(_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(norm3): GroupNorm(32, 64, eps=1e-05, affine=True)
)
(diffusion): Diffusion(
(norm1): GroupNorm(32, 64, eps=1e-05, affine=True)
(relu): ReLU(inplace=True)
(conv1): ConcatConv2d(
(_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(norm2): GroupNorm(32, 64, eps=1e-05, affine=True)
(conv2): ConcatConv2d(
(_layer): Conv2d(65, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
(fc): Sequential( # Diffusion 和 Drift 相比就多了一个这个
(0): GroupNorm(32, 64, eps=1e-05, affine=True)
(1): ReLU(inplace=True)
(2): AdaptiveAvgPool2d(output_size=(1, 1))
(3): Flatten()
(4): Linear(in_features=64, out_features=1, bias=True) # 这个和下边的的就只有输出维度不同
(5): Sigmoid()
)
)
(fc_layers): Sequential(
(0): GroupNorm(32, 64, eps=1e-05, affine=True)
(1): ReLU(inplace=True)
(2): AdaptiveAvgPool2d(output_size=(1, 1))
(3): Flatten()
(4): Linear(in_features=64, out_features=10, bias=True)
)
)
torch.Size([36, 1, 28, 28]) =>(out = self.downsampling_layers(x)) torch.Size([36, 64, 6, 6])
diffusion_term = self.sigma*self.diffusion(t, out)即为 20 * torch.Size([36, 1])
diffusion_term = torch.unsqueeze(diffusion_term, 2)=> torch.Size([36, 1, 1])
diffusion_term = torch.unsqueeze(diffusion_term, 3)=> torch.Size([36, 1, 1, 1])
t为 0.0 输出的大小 torch.Size([36, 64, 6, 6])
t为 1.0 输出的大小 torch.Size([36, 64, 6, 6])
t为 2.0 输出的大小 torch.Size([36, 64, 6, 6])
t为 3.0 输出的大小 torch.Size([36, 64, 6, 6])
t为 4.0 输出的大小 torch.Size([36, 64, 6, 6])
t为 5.0 输出的大小 torch.Size([36, 64, 6, 6])
final_out torch.Size([36, 10])
最关键的一句out = out + self.drift(t, out)*self.deltat + diffusion_term*math.sqrt(self.deltat)*torch.randn_like(out).to(x)
其中 self.deltat = 1.0
- 注意代码中使用的GroupNorm和ConcatConv2d,其中ConcatConv2d为:
- 代码中的t是一种权重
class ConcatConv2d(nn.Module):
    # Conv2d (or ConvTranspose2d when transpose=True) that prepends one
    # constant channel filled with the scalar t to its input, making the
    # convolution time-conditioned; hence in_channels is dim_in + 1.
    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False):
        super(ConcatConv2d, self).__init__()
        module = nn.ConvTranspose2d if transpose else nn.Conv2d
        self._layer = module(
            dim_in + 1, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups,
            bias=bias
        )

    def forward(self, t, x):
        # Broadcast t into an (N, 1, H, W) channel and concatenate with x.
        tt = torch.ones_like(x[:, :1, :, :]) * t
        ttx = torch.cat([tt, x], 1)
        return self._layer(ttx)
快速运行
无须下载数据的sdenet_mnist.py版本
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Mar 11 16:34:10 2019
@author: lingkaikong
"""
from __future__ import print_function
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import random
import os
import argparse
import sed as models
#import data_loader
# Command-line configuration, device selection, model and optimizer setup.
parser = argparse.ArgumentParser(description='PyTorch SDE-Net Training')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate of drift net')
parser.add_argument('--lr2', default=0.01, type=float, help='learning rate of diffusion net')
parser.add_argument('--training_out', action='store_false', default=True, help='training_with_out')
parser.add_argument('--epochs', type=int, default=40, help='number of epochs to train')
parser.add_argument('--eva_iter', default=5, type=int, help='number of passes when evaluation')
parser.add_argument('--dataset_inDomain', default='mnist', help='training dataset')
parser.add_argument('--batch_size', type=int, default=36, help='input batch size for training')
parser.add_argument('--imageSize', type=int, default=28, help='the height / width of the input image to network')
parser.add_argument('--test_batch_size', type=int, default=1000)
parser.add_argument('--gpu', type=int, default=0)
parser.add_argument('--seed', type=float, default=0)
parser.add_argument('--droprate', type=float, default=0.1, help='learning rate decay')
# FIX: added type=int — with nargs='+' alone, command-line values arrive as
# strings and would never match integer epochs in `epoch in args.decreasing_lr`.
parser.add_argument('--decreasing_lr', default=[10, 20, 30], type=int, nargs='+', help='decreasing strategy')
parser.add_argument('--decreasing_lr2', default=[15, 30], type=int, nargs='+', help='decreasing strategy')
args = parser.parse_args()

device = torch.device('cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu')
torch.manual_seed(args.seed)
random.seed(args.seed)
# BUG FIX: comparing a torch.device such as cuda:0 against the string 'cuda'
# is always False, so this branch never executed on GPU; compare device.type.
if device.type == 'cuda':
    cudnn.benchmark = True
    torch.cuda.manual_seed(args.seed)

# print('load in-domain data: ', args.dataset_inDomain)
# train_loader_inDomain, test_loader_inDomain = data_loader.getDataSet(args.dataset_inDomain, args.batch_size, args.test_batch_size, args.imageSize)

# Model
print('==> Building model..')
net = models.SDENet_mnist(layer_depth=6, num_classes=10, dim=64)
net = net.to(device)

# Labels for the diffusion net's in/out-of-domain discrimination task.
real_label = 0
fake_label = 1

criterion = nn.CrossEntropyLoss()   # classification loss (drift/classifier)
criterion2 = nn.BCELoss()           # binary loss for the diffusion head

# Separate optimizers: the diffusion net trains with its own (smaller) LR.
optimizer_F = optim.SGD([{'params': net.downsampling_layers.parameters()}, {'params': net.drift.parameters()},
                         {'params': net.fc_layers.parameters()}], lr=args.lr, momentum=0.9, weight_decay=5e-4)
optimizer_G = optim.SGD([{'params': net.diffusion.parameters()}], lr=args.lr2, momentum=0.9, weight_decay=5e-4)

# use a smaller sigma during training for training stability
net.sigma = 20
def train(epoch):
    """Run one training epoch.

    Each step first updates the downsampling/drift/classifier parts on an
    in-domain batch, then updates the diffusion net to score in-domain
    inputs as ``real_label`` and noise-perturbed inputs as ``fake_label``.
    Uses the module-level net, optimizers, criteria, labels and args.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    running_loss = 0
    n_correct = 0
    n_seen = 0
    running_loss_out = 0
    running_loss_in = 0
    # The real MNIST loader is stubbed out; fixed synthetic batches stand in
    # so the script runs without downloading any data.
    for step in range(800):
        inputs = torch.randn([36, 1, 28, 28]).to(device)
        targets = torch.tensor([8, 1, 2, 7, 1, 2, 3, 0, 1, 2, 4, 5, 9, 6, 3, 9, 0, 3, 5, 7, 6, 9, 8, 1,
                                2, 5, 0, 2, 6, 9, 7, 3, 3, 4, 0, 8]).to(device)

        # --- classifier / drift update on in-domain data ---
        optimizer_F.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer_F.step()
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        n_seen += targets.size(0)
        n_correct += predicted.eq(targets).sum().item()

        # --- diffusion update: in-domain vs out-of-domain batches ---
        label = torch.full((args.batch_size, 1), real_label, device=device).float()
        optimizer_G.zero_grad()
        predict_in = net(inputs, training_diffusion=True)
        loss_in = criterion2(predict_in, label)
        loss_in.backward()
        label.fill_(fake_label)
        # Out-of-domain samples: the same batch plus strong Gaussian noise.
        inputs_out = 2 * torch.randn(args.batch_size, 1, args.imageSize, args.imageSize, device=device) + inputs
        predict_out = net(inputs_out, training_diffusion=True)
        loss_out = criterion2(predict_out, label)
        loss_out.backward()
        running_loss_out += loss_out.item()
        running_loss_in += loss_in.item()
        optimizer_G.step()
    # Per-epoch summary printing from the original repo is disabled here
    # because the data-loader (whose length it divides by) is stubbed out.
# def test(epoch):
# net.eval()
# correct = 0
# total = 0
# with torch.no_grad():
# for batch_idx, (inputs, targets) in enumerate(test_loader_inDomain):
# inputs, targets = inputs.to(device), targets.to(device)
# outputs = 0
# for j in range(args.eva_iter):
# current_batch = net(inputs)
# outputs = outputs + F.softmax(current_batch, dim = 1)
#
# outputs = outputs/args.eva_iter
# _, predicted = outputs.max(1)
# total += targets.size(0)
# correct += predicted.eq(targets).sum().item()
#
# print('Test epoch: {} | Acc: {:.6f} ({}/{})'
# .format(epoch, 100.*correct/total, correct, total))
# Main training loop. Evaluation, learning-rate decay and checkpointing from
# the original repo are commented out in this minimal runnable version.
for epoch in range(0, args.epochs):
    train(epoch)
    # test(epoch)
    # if epoch in args.decreasing_lr:
    #     for param_group in optimizer_F.param_groups:
    #         param_group['lr'] *= args.droprate
    # if epoch in args.decreasing_lr2:
    #     for param_group in optimizer_G.param_groups:
    #         param_group['lr'] *= args.droprate
# Saving the final model is likewise disabled:
# if not os.path.isdir('./save_sdenet_mnist'):
#     os.makedirs('./save_sdenet_mnist')
# torch.save(net.state_dict(), './save_sdenet_mnist/final_model')
最主要的部分(独立运行这个也行)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Mar 11 16:42:11 2019
@author: lingkaikong
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import torch.nn.init as init
import math
__all__ = ['SDENet_mnist']
def init_params(net):
    """Initialize the learnable layers of *net* in-place.

    Conv2d weights get Kaiming-normal (fan_out) init, BatchNorm2d becomes the
    identity affine transform, Linear weights get small Gaussian noise, and
    every bias present is zeroed.
    """
    for module in net.modules():
        if isinstance(module, nn.BatchNorm2d):
            init.constant_(module.weight, 1)
            init.constant_(module.bias, 0)
            continue
        if isinstance(module, nn.Conv2d):
            # fan_out mode suits the ReLU activations used throughout.
            init.kaiming_normal_(module.weight, mode='fan_out')
        elif isinstance(module, nn.Linear):
            init.normal_(module.weight, std=1e-3)
        else:
            continue
        if module.bias is not None:
            init.constant_(module.bias, 0)
# torch.manual_seed(0)
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1."""
    return nn.Conv2d(in_planes, out_planes,
                     kernel_size=3, padding=1,
                     stride=stride, bias=False)
def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 (pointwise) convolution."""
    return nn.Conv2d(in_planes, out_planes,
                     kernel_size=1,
                     stride=stride, bias=False)
def norm(dim):
    """Return a GroupNorm over *dim* channels with at most 32 groups."""
    groups = dim if dim < 32 else 32
    return nn.GroupNorm(groups, dim)
class ConcatConv2d(nn.Module):
    """Time-conditioned convolution.

    Prepends one constant channel filled with the scalar ``t`` to the input,
    then applies a Conv2d (or ConvTranspose2d if ``transpose``); hence the
    underlying layer takes ``dim_in + 1`` input channels.
    """

    def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0,
                 dilation=1, groups=1, bias=True, transpose=False):
        super(ConcatConv2d, self).__init__()
        conv_cls = nn.ConvTranspose2d if transpose else nn.Conv2d
        self._layer = conv_cls(dim_in + 1, dim_out, kernel_size=ksize,
                               stride=stride, padding=padding,
                               dilation=dilation, groups=groups, bias=bias)

    def forward(self, t, x):
        # Broadcast t into an (N, 1, H, W) channel and stack it before x.
        time_channel = torch.ones_like(x[:, :1, :, :]) * t
        return self._layer(torch.cat([time_channel, x], 1))
class Drift(nn.Module):
    """Drift network f(x, t): two norm -> ReLU -> time-conditioned conv
    stages followed by a final GroupNorm, preserving the channel count."""

    def __init__(self, dim):
        super(Drift, self).__init__()
        self.norm1 = norm(dim)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm2 = norm(dim)
        self.conv2 = ConcatConv2d(dim, dim, 3, 1, 1)
        self.norm3 = norm(dim)

    def forward(self, t, x):
        h = self.relu(self.norm1(x))
        h = self.conv1(t, h)
        h = self.relu(self.norm2(h))
        h = self.conv2(t, h)
        return self.norm3(h)
class Diffusion(nn.Module):
    """Diffusion network g(x, t): same conv trunk shape as Drift, but ending
    in a pooled scalar head with a sigmoid — one value in (0, 1) per sample."""

    def __init__(self, dim_in, dim_out):
        super(Diffusion, self).__init__()
        self.norm1 = norm(dim_in)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = ConcatConv2d(dim_in, dim_out, 3, 1, 1)
        self.norm2 = norm(dim_in)
        self.conv2 = ConcatConv2d(dim_in, dim_out, 3, 1, 1)
        # The scalar head is what distinguishes Diffusion from Drift.
        self.fc = nn.Sequential(
            norm(dim_out),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
            Flatten(),
            nn.Linear(dim_out, 1),
            nn.Sigmoid(),
        )

    def forward(self, t, x):
        h = self.relu(self.norm1(x))
        h = self.conv1(t, h)
        h = self.relu(self.norm2(h))
        h = self.conv2(t, h)
        return self.fc(h)
class Flatten(nn.Module):
    """Flatten every dimension after the batch dimension."""

    def forward(self, x):
        n_features = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, n_features)
class SDENet_mnist(nn.Module):
    """SDE-Net for MNIST.

    The forward pass integrates the SDE
        dx_t = f(x_t, t) dt + g(x_0) dW_t
    with the Euler-Maruyama scheme over ``layer_depth`` steps on t in [0, 6),
    then classifies the final state.

    Args:
        layer_depth: number of Euler-Maruyama integration steps.
        num_classes: size of the classifier output.
        dim: feature channel width.
    """

    def __init__(self, layer_depth, num_classes=10, dim=64):
        super(SDENet_mnist, self).__init__()
        self.layer_depth = layer_depth
        # 28x28x1 input -> dim x 6 x 6 feature maps.
        self.downsampling_layers = nn.Sequential(
            nn.Conv2d(1, dim, 3, 1),
            norm(dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim, dim, 4, 2, 1),
            norm(dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim, dim, 4, 2, 1),
        )
        self.drift = Drift(dim)
        self.diffusion = Diffusion(dim, dim)
        # BUG FIX: the classifier head previously hard-coded 10 outputs,
        # silently ignoring the num_classes argument.
        self.fc_layers = nn.Sequential(norm(dim), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)), Flatten(),
                                       nn.Linear(dim, num_classes))
        self.deltat = 6. / self.layer_depth
        self.apply(init_params)
        # Diffusion scale; the training script lowers this to 20 for stability.
        self.sigma = 500

    def forward(self, x, training_diffusion=False):
        out = self.downsampling_layers(x)
        if not training_diffusion:
            t = 0
            # g is evaluated once at t=0 and held fixed during integration,
            # giving one scalar per sample, broadcast to (N, 1, 1, 1).
            # (The debug prints here also re-ran self.diffusion a second
            # time per forward pass; removed.)
            diffusion_term = self.sigma * self.diffusion(t, out)
            diffusion_term = torch.unsqueeze(diffusion_term, 2)
            diffusion_term = torch.unsqueeze(diffusion_term, 3)
            for i in range(self.layer_depth):
                t = 6 * (float(i)) / self.layer_depth
                # Euler-Maruyama step; .to(x) matches the input's dtype/device.
                out = out + self.drift(t, out) * self.deltat \
                    + diffusion_term * math.sqrt(self.deltat) * torch.randn_like(out).to(x)
            final_out = self.fc_layers(out)
        else:
            # Train only the diffusion head; detach so no gradients flow
            # back into the downsampling layers.
            t = 0
            final_out = self.diffusion(t, out.detach())
        return final_out
# def test():
# model = SDENet_mnist(layer_depth=10, num_classes=10, dim=64)
# return model
def count_parameters(model):
    """Return the number of trainable (requires_grad) parameters in *model*."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
if __name__ == '__main__':
    # Instantiate directly rather than via a test() helper, since a function
    # named test may be picked up and run by pytest.
    model = SDENet_mnist(layer_depth=10, num_classes=10, dim=64)
    num_params = count_parameters(model)
    print(num_params)
CG
可以估计不确定性的神经网络——SDE-Net模型浅析
概率自回归预测——DeepAR模型浅析
https://github.com/Junghwan-brian/SDE-Net/blob/master/model/SDENet.py
https://github.com/Lingkai-Kong/SDE-Net/blob/master/MNIST/resnet_mnist.py