图像生成的Rectified Flow方法之一:原理

一、资源

下面这个链接已经把原理介绍得很详细了:

ICLR 2023 | 扩散生成模型新方法:极度简化,一步生成 - 知乎

二、原理的代码demo

import torch
import numpy as np
import torch.nn as nn
from torch.distributions import Normal, Categorical
from torch.distributions.multivariate_normal import MultivariateNormal
from sklearn.datasets import make_circles, make_checkerboard
from torch.distributions.mixture_same_family import MixtureSameFamily
from torch.distributions.log_normal import LogNormal
from torch.distributions.independent import Independent
from torch.distributions.pareto import Pareto
from torch.distributions.studentT import StudentT
import matplotlib.pyplot as plt
import torch.nn.functional as F
from tqdm.notebook import tqdm

# def get_batch(num_samples):
#   points, _ = make_circles(n_samples=num_samples, noise=0.06, factor=0.5)
#   x = torch.tensor(points).type(torch.float32)
#   logp_diff_t1 = torch.zeros(num_samples, 1).type(torch.float32)
#
#   return(x, logp_diff_t1)


class MLP(nn.Module):
    """Small fully connected network mapping a point and a time to a vector.

    ``forward`` concatenates the point ``x_input`` with the time ``t`` along
    the feature axis and feeds the result through two ``tanh`` hidden layers
    followed by a linear output layer.

    Args:
        input_dim: Number of features of ``x_input``; the output has the same
            number of features.
        hidden_num: Width of both hidden layers.
    """

    def __init__(self, input_dim=2, hidden_num=100):
        super().__init__()
        # One extra input feature carries the time value.
        self.fc1 = nn.Linear(input_dim + 1, hidden_num, bias=True)
        self.fc2 = nn.Linear(hidden_num, hidden_num, bias=True)
        self.fc3 = nn.Linear(hidden_num, input_dim, bias=True)
        # A parameter-free module (rather than a lambda attribute) keeps the
        # whole model picklable; the state_dict keys are unaffected.
        self.act = nn.Tanh()

    def forward(self, x_input, t):
        """Evaluate the network.

        Args:
            x_input: Tensor of shape ``(batch, input_dim)``.
            t: Tensor of shape ``(batch, 1)`` holding one time value per row.

        Returns:
            Tensor of shape ``(batch, input_dim)``.
        """
        inputs = torch.cat([x_input, t], dim=1)
        hidden = self.act(self.fc1(inputs))
        hidden = self.act(self.fc2(hidden))
        return self.fc3(hidden)


class RectifiedFlow():
    """Training-pair construction and Euler sampling for a velocity model.

    Args:
        model: Callable invoked as ``model(z, t)`` with ``z`` of shape
            ``(batch, ...)`` and ``t`` of shape ``(batch, 1)``; its output is
            used as the velocity at ``z``.
        num_steps: Default number of Euler steps used by ``sample_ode``.
    """

    def __init__(self, model=None, num_steps=1000):
        self.model = model
        self.N = num_steps

    @staticmethod
    def _time_dtype(reference):
        """Return the dtype to use for time tensors that accompany ``reference``."""
        # Random/fractional times need a floating dtype; fall back to the
        # default one when the reference tensor is not floating point.
        if reference.is_floating_point():
            return reference.dtype
        return torch.get_default_dtype()

    def get_train_tuple(self, z0=None, z1=None):
        """Build one training batch from paired samples.

        Args:
            z0: Start points, one row per sample.
            z1: End points, same shape as ``z0``.

        Returns:
            Tuple ``(z_t, t, target)`` where ``t`` has shape ``(batch, 1)`` and
            is drawn uniformly from ``[0, 1)``, ``z_t = t * z1 + (1 - t) * z0``
            and ``target = z1 - z0``.
        """
        # Create t next to the data so non-CPU / non-float32 inputs work too.
        t = torch.rand(
            (z1.shape[0], 1), dtype=self._time_dtype(z1), device=z1.device
        )
        z_t = t * z1 + (1. - t) * z0
        target = z1 - z0

        return z_t, t, target

    @torch.no_grad()
    def sample_ode(self, z0=None, N=None):
        """Integrate the model's velocity field from ``z0`` with the Euler method.

        Args:
            z0: Start points; the tensor itself is not modified.
            N: Number of Euler steps; defaults to the ``num_steps`` given at
                construction.

        Returns:
            List of ``N + 1`` tensors: a copy of ``z0`` followed by the state
            after each step.
        """
        if N is None:
            N = self.N
        dt = 1. / N
        z = z0.detach().clone()
        batchsize = z.shape[0]
        time_dtype = self._time_dtype(z)

        traj = [z.detach().clone()]  # trajectory starts at the initial points
        for i in range(N):
            # Time of the current step, on the same device as the state.
            t = torch.ones((batchsize, 1), dtype=time_dtype, device=z.device) * i / N
            pred = self.model(z, t)
            z = z.detach().clone() + pred * dt  # Euler update: z <- z + v * dt

            traj.append(z.detach().clone())

        return traj


def train_rectified_flow(rectified_flow, optimizer, pairs, batchsize, inner_iters):
    """Fit ``rectified_flow.model`` to the given point pairs.

    Every iteration draws a random mini-batch of pairs, asks
    ``rectified_flow.get_train_tuple`` for ``(z_t, t, target)`` and minimises
    the per-sample squared error between ``model(z_t, t)`` and ``target``,
    averaged over the batch.

    Args:
        rectified_flow: Object exposing ``model`` and ``get_train_tuple``.
        optimizer: Optimizer that updates the model's parameters.
        pairs: Tensor indexed as ``pairs[:, 0]`` (start points) and
            ``pairs[:, 1]`` (end points).
        batchsize: Maximum number of pairs drawn per iteration.
        inner_iters: The loop runs ``inner_iters + 1`` optimisation steps.

    Returns:
        Tuple ``(rectified_flow, loss_curve)`` where ``loss_curve`` holds the
        natural log of the loss of each of the ``inner_iters + 1`` steps.
    """
    # tqdm.auto picks the notebook widget only when it is usable and otherwise
    # the console bar, avoiding the "IProgress not found" ImportError raised by
    # the notebook-only bar. The bar is cosmetic, so run without it if tqdm is
    # not installed at all.
    try:
        from tqdm.auto import tqdm as progress
    except ImportError:
        def progress(iterable):
            return iterable

    loss_curve = []
    for _ in progress(range(inner_iters + 1)):
        optimizer.zero_grad()

        # Random mini-batch without replacement.
        indices = torch.randperm(len(pairs))[:batchsize]
        batch = pairs[indices]
        z0 = batch[:, 0].detach().clone()
        z1 = batch[:, 1].detach().clone()
        z_t, t, target = rectified_flow.get_train_tuple(z0=z0, z1=z1)

        pred = rectified_flow.model(z_t, t)
        # Squared error summed over features, then averaged over the batch.
        loss = (target - pred).view(pred.shape[0], -1).abs().pow(2).sum(dim=1)
        loss = loss.mean()
        loss.backward()

        optimizer.step()
        loss_curve.append(np.log(loss.item()))  # log-loss for the curve plot

    return rectified_flow, loss_curve


@torch.no_grad()
def draw_plot(rectified_flow, z0, z1, N=None):
    """Plot sampled end points and the transport trajectories of a flow.

    Runs ``rectified_flow.sample_ode`` from ``z0`` and draws two figures: a
    scatter of the start points against the final points, and line plots of
    the first (at most 100) particle trajectories. Axis limits are taken from
    the module-level constant ``M``.

    Args:
        rectified_flow: Object exposing ``sample_ode(z0=..., N=...)``.
        z0: Start points; columns 0 and 1 are plotted.
        z1: Not used by this function; kept for call compatibility.
        N: Number of sampling steps forwarded to ``sample_ode``.
    """
    traj = rectified_flow.sample_ode(z0=z0, N=N)
    start = traj[0].cpu().numpy()
    end = traj[-1].cpu().numpy()

    plt.figure(figsize=(4, 4))
    plt.xlim(-M, M)
    plt.ylim(-M, M)

    plt.scatter(start[:, 0], start[:, 1], label=r'$\pi_0$', alpha=0.15)
    plt.scatter(end[:, 0], end[:, 1], label='Generated', alpha=0.15)
    plt.legend()
    plt.title('Distribution')
    plt.tight_layout()

    # Convert once to NumPy so plotting also works for non-CPU tensors,
    # matching the scatter plots above.
    traj_particles = torch.stack(traj).cpu().numpy()
    plt.figure(figsize=(4, 4))
    plt.xlim(-M, M)
    plt.ylim(-M, M)
    plt.axis('equal')
    # Draw at most 100 trajectories; smaller batches no longer raise IndexError.
    for i in range(min(100, traj_particles.shape[1])):
        plt.plot(traj_particles[:, i, 0], traj_particles[:, i, 1])
    plt.title('Transport Trajectory')
    plt.tight_layout()


@torch.no_grad()
def draw_plot_initial(rectified_flow, z0, z1, N=None):
    """Plot two point sets and straight-line interpolations between their pairs.

    Draws two figures without running the model: a scatter of ``z0`` against
    ``z1``, and, for the first (at most 30) pairs, the points
    ``0.1 * k * z1[i] + (1 - 0.1 * k) * z0[i]`` for ``k = 0..7`` joined by a
    line. Axis limits are taken from the module-level constant ``M``.

    Args:
        rectified_flow: Not used by this function; kept for call compatibility.
        z0: Start points; columns 0 and 1 are plotted.
        z1: End points paired row-by-row with ``z0``.
        N: Not used by this function; kept for call compatibility.
    """
    plt.figure(figsize=(4, 4))
    plt.xlim(-M, M)
    plt.ylim(-M, M)

    plt.scatter(z0[:, 0].cpu().numpy(), z0[:, 1].cpu().numpy(), label=r'$\pi_0$', alpha=0.15)
    plt.scatter(z1[:, 0].cpu().numpy(), z1[:, 1].cpu().numpy(), label='Generated', alpha=0.15)
    plt.legend()
    plt.title('Distribution')
    plt.tight_layout()

    plt.figure(figsize=(4, 4))
    plt.xlim(-M, M)
    plt.ylim(-M, M)
    plt.axis('equal')
    # Draw at most 30 pairs; smaller batches no longer raise IndexError.
    for i in range(min(30, z0.shape[0], z1.shape[0])):
        # NOTE(review): the interpolation stops at 0.7 rather than reaching
        # z1 (k runs over range(8)) -- kept as in the original; confirm intent.
        z_t = torch.stack([0.1 * k * z1[i, :] + (1. - 0.1 * k) * z0[i, :] for k in range(8)])
        # Convert to NumPy so plotting also works for non-CPU tensors.
        z_t = z_t.cpu().numpy()
        plt.plot(z_t[:, 0], z_t[:, 1])
    plt.title('Transport Trajectory')
    plt.tight_layout()


# ---------------------------------------------------------------------------
# Demo script: build two 2-D Gaussian mixtures, pair their samples at random
# and train a first rectified flow on the pairs.
# ---------------------------------------------------------------------------

D = 10.  # scale of the mixture-component means
M = 15  # half-width of the plotted region; also read by the draw_* helpers
VAR = 0.3  # multiplier of the identity covariance of every component
DOT_SIZE = 4
COMP = 3  # number of Gaussian components mixed together
sampleCount = 10000

# Source distribution pi_0: equally weighted mixture of COMP Gaussians.
initial_mix = Categorical(torch.tensor([1 / COMP for i in range(COMP)]))
initial_comp = MultivariateNormal(
    torch.tensor([
        [D * np.sqrt(3) / 2., D / 2.],
        [-D * np.sqrt(3) / 2., D / 2.],
        [0.0, - D * np.sqrt(3) / 2.],
    ]).float(),  # component means
    VAR * torch.stack([torch.eye(2) for i in range(COMP)]),  # covariances
)
initial_model = MixtureSameFamily(initial_mix, initial_comp)
samples_0 = initial_model.sample([sampleCount])

# Target distribution pi_1: same construction with different means.
target_mix = Categorical(torch.tensor([1 / COMP for i in range(COMP)]))
target_comp = MultivariateNormal(
    torch.tensor([
        [D * np.sqrt(3) / 2., - D / 2.],
        [-D * np.sqrt(3) / 2., - D / 2.],
        [0.0, D * np.sqrt(3) / 2.],
    ]).float(),  # component means
    VAR * torch.stack([torch.eye(2) for i in range(COMP)]),  # covariances
)
target_model = MixtureSameFamily(target_mix, target_comp)
samples_1 = target_model.sample([sampleCount])
print('Shape of the samples:', samples_0.shape, samples_1.shape)

# Show both sample sets in one figure.
plt.figure(figsize=(4, 4))
plt.xlim(-M, M)
plt.ylim(-M, M)
plt.title(r'Samples from $\pi_0$ and $\pi_1$')
plt.scatter(samples_0[:, 0].cpu().numpy(), samples_0[:, 1].cpu().numpy(), alpha=0.1, label=r'$\pi_0$')
plt.scatter(samples_1[:, 0].cpu().numpy(), samples_1[:, 1].cpu().numpy(), alpha=0.1, label=r'$\pi_1$')
plt.legend()

plt.tight_layout()
plt.show()


# Shuffle both sample sets independently and stack them into random pairs;
# x_pairs[:, 0] holds the start points and x_pairs[:, 1] the end points.
x_0 = samples_0.detach().clone()[torch.randperm(len(samples_0))]
x_1 = samples_1.detach().clone()[torch.randperm(len(samples_1))]
x_pairs = torch.stack([x_0, x_1], dim=1)
print(x_pairs.shape)


iterations = 10
batchsize = 4096
input_dim = 2

# Train the first rectified flow and plot its log-loss curve.
rectified_flow_1 = RectifiedFlow(model=MLP(input_dim, hidden_num=100), num_steps=100)
optimizer = torch.optim.Adam(rectified_flow_1.model.parameters(), lr=5e-3)
rectified_flow_1, loss_curve = train_rectified_flow(rectified_flow_1, optimizer, x_pairs, batchsize, iterations)
# train_rectified_flow returns iterations + 1 loss values, one per step.
plt.plot(np.linspace(0, iterations, iterations + 1), loss_curve[:(iterations + 1)])
plt.title('Training Loss Curve')
plt.show()


运行时可能会提示如下错误:

ImportError: IProgress not found. Please update jupyter and ipywidgets.

怎么办:

pip install --upgrade jupyter

这条命令会连带安装或升级 ipywidgets 等一批依赖,装完之后上面的错误就消失了。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值