# A simple diffusion model implemented from scratch (从零开始简单实现扩散模型)
import torch
import torchvision
from torch import nn
from torch. nn import functional as F
from torch. utils. data import DataLoader
import matplotlib. pyplot as plt
# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device  # notebook cell echo
# Load MNIST as tensors and preview one shuffled batch.
dataset = torchvision.datasets.MNIST(
    root='./data',
    download=True,
    train=True,
    transform=torchvision.transforms.ToTensor(),
)
train_dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

x, y = next(iter(train_dataloader))
print(x.shape, len(train_dataloader))
print(y)
plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys')
def corrupt(ori, amount):
    """Blend clean images with standard Gaussian noise.

    Linear interpolation between the input and fresh noise:
    ``amount == 0`` returns ``ori`` unchanged, ``amount == 1`` returns
    pure noise.

    Args:
        ori: Image batch of shape (B, C, H, W).
        amount: Noise level(s) in [0, 1]; a 1-D tensor of length B
            (one level per sample) or a plain Python scalar / 0-dim
            tensor applied to the whole batch.

    Returns:
        Tensor with the same shape as ``ori``.
    """
    noise = torch.randn_like(ori)
    # Generalization: accept plain scalars too (the original required a
    # tensor with a .view method). reshape also tolerates non-contiguous
    # inputs, unlike view.
    amount = torch.as_tensor(amount, dtype=ori.dtype, device=ori.device)
    amount = amount.reshape(-1, 1, 1, 1)  # broadcast over C, H, W
    return (1 - amount) * ori + amount * noise
# Visualize corruption: the noise level ramps from 0 to 1 across the batch.
fig, axs = plt.subplots(2, 1, figsize=(8, 3))

axs[0].set_title('input data')
axs[0].imshow(torchvision.utils.make_grid(x)[0], cmap='Greys')

amount = torch.linspace(0, 1, x.shape[0])  # one level per sample
noised_x = corrupt(x, amount)

axs[1].set_title('Corrupted data (-- amount increases -->)')
axs[1].imshow(torchvision.utils.make_grid(noised_x)[0], cmap='Greys')
# notebook output: <matplotlib.image.AxesImage at 0x7f543c107f40>
class BasicUNet ( nn. Module) :
def __init__ ( self, in_channel= 1 , out_channel= 1 ) :
super ( ) . __init__( )
self. down_layers= torch. nn. ModuleList( [
nn. Conv2d( in_channel, 32 , kernel_size= 5 , padding= 2 ) ,
nn. Conv2d( 32 , 64 , kernel_size= 5 , padding= 2 ) ,
nn. Conv2d( 64 , 64 , kernel_size= 5 , padding= 2 )
] )
self. up_layers= torch. nn. ModuleList( [
nn. Conv2d( 64 , 64 , kernel_size= 5 , padding= 2 ) ,
nn. Conv2d( 64 , 32 , kernel_size= 5 , padding= 2 ) ,
nn. Conv2d( 32 , out_channel, kernel_size= 5 , padding= 2 )
] )
self. act = nn. ReLU( )
self. downscale = nn. MaxPool2d( 2 )
self. upscale = nn. Upsample( scale_factor= 2 )
def forward ( self, x) :
h = [ ]
for i, l in enumerate ( self. down_layers) :
x = self. act( l( x) )
if i < 2 :
h. append( x)
x = self. downscale( x)
for i, l in enumerate ( self. up_layers) :
if i> 0 :
x= self. upscale( x)
x += h. pop( )
x = self. act( l( x) )
return x
# Sanity check: the UNet must map (B, 1, 28, 28) back to (B, 1, 28, 28).
net = BasicUNet().to(device)
x = torch.rand(8, 1, 28, 28).to(device)
net(x).shape  # notebook cell echo
# notebook output: torch.Size([8, 1, 28, 28])
# Train the denoiser: given a randomly corrupted image, predict the clean one.
n_epochs = 3
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
losses = []

for epoch in range(n_epochs):
    for x, y in train_dataloader:
        x = x.to(device)
        # One random corruption level per sample, drawn from [0, 1).
        noise_amount = torch.rand(x.shape[0]).to(device)
        noisy_x = corrupt(x, noise_amount)

        pred = net(noisy_x)
        loss = loss_fn(pred, x)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    # Average over just this epoch's batches (the tail of `losses`).
    avg_loss = sum(losses[-len(train_dataloader):]) / len(train_dataloader)
    print(f'finished epoch {epoch}, average loss is {avg_loss:.5f}')

plt.plot(losses)
plt.ylim(0, 0.1)
# notebook output:
# finished epoch 0, average loss is 0.03528
# finished epoch 1, average loss is 0.03257
# finished epoch 2, average loss is 0.03201
# (0.0, 0.1)
# Qualitative check: denoise 8 samples corrupted with increasing noise levels.
x, y = next(iter(train_dataloader))
x = x[:8]

amount = torch.linspace(0, 1, x.shape[0])
noise = torch.randn_like(x)
amount = amount.view(-1, 1, 1, 1)  # broadcast the per-sample level
noisy_x = (1 - amount) * x + amount * noise

with torch.no_grad():
    preds = net(noisy_x.to(device)).detach().cpu()

fig, axs = plt.subplots(3, 1, figsize=(8, 5))
axs[0].set_title('input data')
axs[0].imshow(torchvision.utils.make_grid(x)[0].clip(0, 1), cmap='Greys')
axs[1].set_title('add noise data')
axs[1].imshow(torchvision.utils.make_grid(noisy_x)[0].clip(0, 1), cmap='Greys')
axs[2].set_title('network data')
axs[2].imshow(torchvision.utils.make_grid(preds)[0], cmap='Greys')
torchvision.utils.make_grid(preds).shape  # notebook cell echo
# notebook output: torch.Size([3, 32, 242])