一、资源
下面这个链接已经介绍得很详细了。
ICLR 2023 | 扩散生成模型新方法:极度简化,一步生成 - 知乎
二、原理的代码demo
import torch
import numpy as np
import torch.nn as nn
from torch.distributions import Normal, Categorical
from torch.distributions.multivariate_normal import MultivariateNormal
from sklearn.datasets import make_circles, make_checkerboard
from torch.distributions.mixture_same_family import MixtureSameFamily
from torch.distributions.log_normal import LogNormal
from torch.distributions.independent import Independent
from torch.distributions.pareto import Pareto
from torch.distributions.studentT import StudentT
import matplotlib.pyplot as plt
import torch.nn.functional as F
from tqdm.notebook import tqdm
# def get_batch(num_samples):
# points, _ = make_circles(n_samples=num_samples, noise=0.06, factor=0.5)
# x = torch.tensor(points).type(torch.float32)
# logp_diff_t1 = torch.zeros(num_samples, 1).type(torch.float32)
#
# return(x, logp_diff_t1)
class MLP(nn.Module):
    """Two-hidden-layer perceptron mapping a state and a time to a same-sized output.

    The time ``t`` is concatenated to the state as one extra input feature, so
    the first layer consumes ``input_dim + 1`` features and the last layer
    emits ``input_dim`` features.

    Args:
        input_dim: Number of features of the state vector.
        hidden_num: Width of both hidden layers.
    """

    def __init__(self, input_dim=2, hidden_num=100):
        super().__init__()
        self.fc1 = nn.Linear(input_dim + 1, hidden_num, bias=True)
        self.fc2 = nn.Linear(hidden_num, hidden_num, bias=True)
        self.fc3 = nn.Linear(hidden_num, input_dim, bias=True)
        # A module instead of a local lambda: a lambda attribute makes the
        # whole model unpicklable. nn.Tanh has no parameters, so the
        # state_dict keys are unchanged.
        self.act = nn.Tanh()

    def forward(self, x_input, t):
        """Run the network on a batch.

        Args:
            x_input: Tensor of shape ``(batch, input_dim)``.
            t: Tensor of shape ``(batch, 1)``; concatenated to ``x_input``
                along ``dim=1``.

        Returns:
            Tensor of shape ``(batch, input_dim)``.
        """
        inputs = torch.cat([x_input, t], dim=1)
        hidden = self.act(self.fc1(inputs))
        hidden = self.act(self.fc2(hidden))
        return self.fc3(hidden)
class RectifiedFlow():
    """Wrapper around a velocity model for straight-line flow training and Euler sampling.

    Args:
        model: Callable invoked as ``model(z, t)`` with ``z`` of shape
            ``(batch, ...)`` and ``t`` of shape ``(batch, 1)``; must return a
            tensor shaped like ``z``.
        num_steps: Default number of Euler steps used by ``sample_ode``.
    """

    def __init__(self, model=None, num_steps=1000):
        self.model = model
        self.N = num_steps

    def get_train_tuple(self, z0=None, z1=None):
        """Build one regression example per pair.

        Draws ``t ~ U[0, 1)`` per sample, interpolates linearly between the
        paired points and returns the constant direction ``z1 - z0`` as target.

        Args:
            z0: Start points, shape ``(batch, ...)``.
            z1: End points, same shape as ``z0``.

        Returns:
            ``(z_t, t, target)`` where ``z_t = t * z1 + (1 - t) * z0``,
            ``t`` has shape ``(batch, 1)`` and ``target = z1 - z0``.
        """
        batch = z1.shape[0]
        # Create t next to the data so CUDA / float64 inputs are not mixed
        # with a default-device float32 tensor.
        dtype = z1.dtype if z1.is_floating_point() else torch.get_default_dtype()
        t = torch.rand((batch, 1), device=z1.device, dtype=dtype)
        # Broadcast over every non-batch axis; for 2-D data this is t itself.
        t_b = t.view(batch, *([1] * (z1.dim() - 1)))
        z_t = t_b * z1 + (1. - t_b) * z0
        target = z1 - z0
        return z_t, t, target

    @torch.no_grad()
    def sample_ode(self, z0=None, N=None):
        """Integrate ``dz/dt = model(z, t)`` from t=0 to t=1 with the Euler method.

        Args:
            z0: Initial points, shape ``(batch, ...)``. Not modified.
            N: Number of Euler steps; defaults to ``self.N``.

        Returns:
            List of ``N + 1`` tensors: the initial points followed by the
            state after each step.
        """
        if N is None:
            N = self.N
        dt = 1. / N
        z = z0.detach().clone()
        batchsize = z.shape[0]
        dtype = z.dtype if z.is_floating_point() else torch.get_default_dtype()
        traj = [z.clone()]  # trajectory, starting with the initial points
        for i in range(N):
            # Time tensor lives on the same device/dtype as the state.
            t = torch.ones((batchsize, 1), device=z.device, dtype=dtype) * i / N
            pred = self.model(z, t)
            z = z + pred * dt  # Euler step: z <- z + v * dt
            traj.append(z.clone())
        return traj
def train_rectified_flow(rectified_flow, optimizer, pairs, batchsize, inner_iters):
    """Fit ``rectified_flow.model`` to the straight-line direction of random pairs.

    Each step samples up to ``batchsize`` pairs without replacement, builds
    the training tuple with ``rectified_flow.get_train_tuple`` and minimises
    the per-sample squared error summed over features, averaged over the batch.

    Args:
        rectified_flow: Object exposing ``model(z_t, t)`` and
            ``get_train_tuple(z0=..., z1=...)``.
        optimizer: Optimizer over the model parameters.
        pairs: Tensor of shape ``(num_pairs, 2, ...)``; ``pairs[:, 0]`` are the
            start points and ``pairs[:, 1]`` the end points.
        batchsize: Maximum number of pairs drawn per step.
        inner_iters: The loop runs ``inner_iters + 1`` optimisation steps.

    Returns:
        ``(rectified_flow, loss_curve)`` where ``loss_curve`` holds the natural
        log of the loss of each of the ``inner_iters + 1`` steps.
    """
    # tqdm.auto falls back to a console bar when the notebook widgets are not
    # available (the "IProgress not found" ImportError of tqdm.notebook); the
    # bar is cosmetic, so training also proceeds without tqdm.
    try:
        from tqdm.auto import tqdm as progress
    except ImportError:
        def progress(iterable):
            return iterable

    loss_curve = []
    for _ in progress(range(inner_iters + 1)):
        optimizer.zero_grad()
        indices = torch.randperm(len(pairs))[:batchsize]
        batch = pairs[indices]
        z0 = batch[:, 0].detach().clone()
        z1 = batch[:, 1].detach().clone()
        z_t, t, target = rectified_flow.get_train_tuple(z0=z0, z1=z1)
        pred = rectified_flow.model(z_t, t)
        # reshape (not view) so non-contiguous model outputs are accepted.
        per_sample = (target - pred).reshape(pred.shape[0], -1).pow(2).sum(dim=1)
        loss = per_sample.mean()
        loss.backward()
        optimizer.step()
        loss_curve.append(np.log(loss.item()))  # store log-loss for plotting
    return rectified_flow, loss_curve
@torch.no_grad()
def draw_plot(rectified_flow, z0, z1, N=None):
    """Sample the flow from ``z0`` and draw two matplotlib figures.

    Figure 1 scatters the initial points and the final Euler iterate.
    Figure 2 draws the per-particle trajectories of at most the first
    100 particles. Axis limits use the module-level constant ``M``.

    Args:
        rectified_flow: Object exposing ``sample_ode(z0=..., N=...)``.
        z0: Initial points; the first two features are plotted.
        z1: Unused; kept so existing callers keep working.
        N: Number of Euler steps forwarded to ``sample_ode``.
    """
    traj = rectified_flow.sample_ode(z0=z0, N=N)
    start = traj[0].cpu().numpy()
    end = traj[-1].cpu().numpy()

    plt.figure(figsize=(4, 4))
    plt.xlim(-M, M)
    plt.ylim(-M, M)
    plt.scatter(start[:, 0], start[:, 1], label=r'$\pi_0$', alpha=0.15)
    plt.scatter(end[:, 0], end[:, 1], label='Generated', alpha=0.15)
    plt.legend()
    plt.title('Distribution')
    plt.tight_layout()

    # Convert once to NumPy so tensors on any device can be plotted.
    traj_particles = torch.stack(traj).cpu().numpy()
    # Never index past the batch: small batches used to raise IndexError.
    num_shown = min(100, traj_particles.shape[1])
    plt.figure(figsize=(4, 4))
    plt.xlim(-M, M)
    plt.ylim(-M, M)
    plt.axis('equal')
    for i in range(num_shown):
        plt.plot(traj_particles[:, i, 0], traj_particles[:, i, 1])
    plt.title('Transport Trajectory')
    plt.tight_layout()
@torch.no_grad()
def draw_plot_initial(rectified_flow, z0, z1, N=None):
    """Plot paired points and the straight segments between them, without sampling.

    Figure 1 scatters ``z0`` and ``z1``. Figure 2 draws, for at most the first
    30 pairs, the linear interpolation ``s * z1 + (1 - s) * z0`` evaluated at
    ``s = 0.0, 0.1, ..., 0.7``. Axis limits use the module-level constant ``M``.

    Args:
        rectified_flow: Unused; kept so existing callers keep working.
        z0: Start points; the first two features are plotted.
        z1: End points paired index-wise with ``z0``.
        N: Unused; kept so existing callers keep working.
    """
    # Move to CPU once so tensors on any device can be plotted.
    z0_cpu = z0.cpu()
    z1_cpu = z1.cpu()
    start = z0_cpu.numpy()
    end = z1_cpu.numpy()

    plt.figure(figsize=(4, 4))
    plt.xlim(-M, M)
    plt.ylim(-M, M)
    plt.scatter(start[:, 0], start[:, 1], label=r'$\pi_0$', alpha=0.15)
    plt.scatter(end[:, 0], end[:, 1], label='Generated', alpha=0.15)
    plt.legend()
    plt.title('Distribution')
    plt.tight_layout()

    # Never index past the available pairs: small inputs used to raise IndexError.
    num_shown = min(30, z0_cpu.shape[0], z1_cpu.shape[0])
    plt.figure(figsize=(4, 4))
    plt.xlim(-M, M)
    plt.ylim(-M, M)
    plt.axis('equal')
    for i in range(num_shown):
        z_t = torch.stack(
            [0.1 * t * z1_cpu[i, :] + (1. - 0.1 * t) * z0_cpu[i, :] for t in range(8)]
        ).numpy()
        plt.plot(z_t[:, 0], z_t[:, 1])
    plt.title('Transport Trajectory')
    plt.tight_layout()
# ---- Toy problem configuration ------------------------------------------
D = 10.            # scale of the component means
M = 15             # half-width of the plotting window
VAR = 0.3          # scalar multiplying the identity covariance
DOT_SIZE = 4
COMP = 3           # number of Gaussian components in each mixture
sampleCount = 10000

# Quantities shared by both mixtures.
_far = D * np.sqrt(3) / 2.
_near = D / 2.
_uniform_weights = torch.full((COMP,), 1 / COMP)
_covariances = VAR * torch.eye(2).repeat(COMP, 1, 1)

# Source distribution pi_0: equally weighted 2-D Gaussian mixture.
initial_mix = Categorical(_uniform_weights)
initial_comp = MultivariateNormal(
    torch.tensor([
        [_far, _near],
        [-_far, _near],
        [0.0, -_far],
    ]).float(),
    _covariances,
)
initial_model = MixtureSameFamily(initial_mix, initial_comp)
samples_0 = initial_model.sample([sampleCount])

# Target distribution pi_1: same construction with different means.
target_mix = Categorical(_uniform_weights)
target_comp = MultivariateNormal(
    torch.tensor([
        [_far, -_near],
        [-_far, -_near],
        [0.0, _far],
    ]).float(),
    _covariances,
)
target_model = MixtureSameFamily(target_mix, target_comp)
samples_1 = target_model.sample([sampleCount])
print('Shape of the samples:', samples_0.shape, samples_1.shape)

# Alternative toy setup (disabled): a small isotropic Gaussian as pi_0 and
# the make_circles dataset as pi_1.

# Scatter both sample clouds in one figure.
plt.figure(figsize=(4, 4))
plt.xlim(-M, M)
plt.ylim(-M, M)
plt.title(r'Samples from $\pi_0$ and $\pi_1$')
for _cloud, _tag in ((samples_0, r'$\pi_0$'), (samples_1, r'$\pi_1$')):
    _xy = _cloud.cpu().numpy()
    plt.scatter(_xy[:, 0], _xy[:, 1], alpha=0.1, label=_tag)
plt.legend()
plt.tight_layout()
plt.show()
# Pair every pi_0 sample with an independently shuffled pi_1 sample.
x_0 = samples_0.detach().clone()[torch.randperm(len(samples_0))]
x_1 = samples_1.detach().clone()[torch.randperm(len(samples_1))]
x_pairs = torch.stack([x_0, x_1], dim=1)
print(x_pairs.shape)

# Training configuration.
iterations = 10
batchsize = 4096
input_dim = 2

rectified_flow_1 = RectifiedFlow(model=MLP(input_dim, hidden_num=100), num_steps=100)
optimizer = torch.optim.Adam(rectified_flow_1.model.parameters(), lr=5e-3)
rectified_flow_1, loss_curve = train_rectified_flow(rectified_flow_1, optimizer, x_pairs, batchsize, iterations)

# train_rectified_flow records iterations + 1 values, each the log of the loss.
plt.plot(np.linspace(0, iterations, iterations + 1), loss_curve[:(iterations + 1)])
plt.xlabel('iteration')
plt.ylabel('log(loss)')
plt.title('Training Loss Curve')
plt.show()
运行起来,可能提示错误:
ImportError: IProgress not found. Please update jupyter and ipywidgets.
怎么办:
pip install --upgrade jupyter
安装了一堆东西,就好了。