Diffusion Code Walkthrough

There are still a few places I haven't fully figured out.

import numpy as np
import math
import torch.nn as nn
import torch
import torch.nn.functional as F
import time


class SinusoidalPosEmb(nn.Module):  # sinusoidal position-embedding network for the diffusion timestep
    def __init__(self, dim):
        super(SinusoidalPosEmb, self).__init__()
        self.dim = dim

    def forward(self,x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :]
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
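
# The block above is the Transformer-style sinusoidal embedding: for input
# timesteps t it builds half_dim = dim // 2 frequencies 10000^(-k / (half_dim - 1)),
# k = 0..half_dim-1, then concatenates sin(t * freq) and cos(t * freq) along the
# last dimension. A quick shape check (illustrative values I chose):
# >>> SinusoidalPosEmb(16)(torch.arange(4)).shape
# torch.Size([4, 16])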


class MLP(nn.Module):  # denoising network
    def __init__(self, state_dim, action_dim, hidden_dim, device, time_dim=16):
        """
        state_dim: dimension of the state/observation vector
        action_dim: dimension of the action vector
        hidden_dim: width of the hidden layers
        device: torch device
        time_dim: embedding dimension for the diffusion timestep; every step's timestep must be encoded
        """
        super(MLP, self).__init__()
        self.time_dim = time_dim
        self.device = device
        self.action_dim = action_dim
        self.state_dim = state_dim
        self.hidden_dim = hidden_dim

        self.time_mlp = nn.Sequential(  # timestep (position) encoding
            SinusoidalPosEmb(self.time_dim),
            nn.Linear(self.time_dim, self.time_dim*2),
            nn.Mish(),
            nn.Linear(self.time_dim*2, self.time_dim),
        )

        self.input_dim = self.state_dim + self.action_dim + self.time_dim
        self.mid_layer = nn.Sequential(
            nn.Linear(self.input_dim, self.hidden_dim),
            nn.Mish(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.Mish(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.Mish(),
        )
        self.final_layer = nn.Linear(self.hidden_dim, self.action_dim)

        self.init_weights()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x, time, state):
        time_emb = self.time_mlp(time)
        x = torch.cat((x, state, time_emb), dim=1)
        x = self.mid_layer(x)
        return self.final_layer(x)
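
# A quick shape check (illustrative sizes, not from the original post):
# >>> net = MLP(state_dim=11, action_dim=2, hidden_dim=256, device='cpu')
# >>> net(torch.randn(4, 2), torch.randint(0, 100, (4,)), torch.randn(4, 11)).shape
# torch.Size([4, 2])
# The concatenated input has width action_dim + state_dim + time_dim = 2 + 11 + 16 = 29.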


class WeightedLoss(nn.Module):  # weighted loss; concrete losses inherit from this base class
    def __init__(self):
        super(WeightedLoss, self).__init__()

    def forward(self, pred, target, weighted=1.0):
        loss = self._loss(pred, target)
        weighted_loss = (loss * weighted).mean()  # renamed to avoid shadowing the class name
        return weighted_loss

class WeightedL1(WeightedLoss):
    def _loss(self, pred, target):
        return torch.abs(pred - target)

class WeightedL2(WeightedLoss):
    def _loss(self, pred, target):
        return F.mse_loss(pred, target, reduction='none')


Losses = {
    'L1': WeightedL1,
    'L2': WeightedL2,
}
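
# Usage sketch (my illustration): the dict maps the loss_type string to a class, e.g.
# >>> loss_fn = Losses['L2']()
# >>> loss_fn(torch.zeros(4, 2), torch.ones(4, 2))        # plain mean loss
# >>> loss_fn(torch.zeros(4, 2), torch.ones(4, 2), 0.5)   # uniformly weighted
# `weighted` may also be a tensor broadcastable against the element-wise loss,
# which is how per-sample weights would enter.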

def extract(a, t, x_shape):  # gather the cumulative-product alpha values for timestep t
    """ Extract values from a 1-D alpha tensor and reshape them to the desired shape """
    b, *_ = x_shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
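
# Illustrative example (values chosen for this note): with a = [0.9, 0.8, 0.7]
# and per-sample timesteps t = [2, 0], extract(a, t, (2, 3)) gathers [0.7, 0.9]
# and reshapes it to [2, 1], so it broadcasts against an x of shape [2, 3]:
# >>> extract(torch.tensor([0.9, 0.8, 0.7]), torch.tensor([2, 0]), (2, 3))
# tensor([[0.7000],
#         [0.9000]])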


class Diffusion(nn.Module):  # diffusion model
    def __init__(self, loss_type, beta_schedule='linear', clip_denoised=True, predict_epsilon=True, **kwargs):
        super(Diffusion, self).__init__()
        self.state_dim = kwargs["obs_dim"]
        self.action_dim = kwargs['act_dim']
        self.hidden_dim = kwargs['hidden_dim']
        self.T = kwargs['T']  # number of diffusion steps
        self.loss_type = loss_type
        self.clip_denoised = clip_denoised
        self.predict_epsilon = predict_epsilon
        self.device = torch.device(kwargs['device'])
        self.model = MLP(self.state_dim, self.action_dim, self.hidden_dim, self.device).to(kwargs['device'])

        if beta_schedule == 'linear':  # the beta spacing from the DDPM formulas
            betas = torch.linspace(1e-4, 0.02, self.T, dtype=torch.float32)
        else:
            raise ValueError(f"unsupported beta_schedule: {beta_schedule}")

        alphas = 1.0 - betas  # alpha = 1 - beta
        alphas_cumprod = torch.cumprod(alphas, dim=0)  # cumulative product of alpha_t: [1,2,3] -> [1,2,6]
        alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])  # cumulative product up to alpha_{t-1}: [1,2,6] -> [1,1,2]


        # register_buffer() is used when some tensors should not be updated by the optimizer
        # (they stay constant from start to finish) but should still be saved with the model
        self.register_buffer("betas", betas)
        self.register_buffer("alphas", alphas)
        self.register_buffer("alphas_cumprod", alphas_cumprod)
        self.register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)

        # forward process
        # sqrt of the cumulative product up to alpha_t
        self.register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        # sqrt of (1 - cumulative product up to alpha_t)
        self.register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod))

        # reverse process
        # variance of the posterior
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        self.register_buffer("posterior_variance", posterior_variance)
        self.register_buffer("posterior_log_variance_clipped", torch.log(posterior_variance.clamp(min=1e-20)))
        # mean of the posterior
        self.register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod))
        self.register_buffer("sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1))

        self.register_buffer("posterior_mean_coef1", betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))
        self.register_buffer("posterior_mean_coef2", (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod))

        self.loss_fn = Losses[loss_type]()  # pick the loss function


    def q_posterior(self, x_start, x, t):  # compute the posterior distribution
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x.shape) * x_start
            + extract(self.posterior_mean_coef2, t, x.shape) * x
        )
        posterior_variance = extract(self.posterior_variance, t, x.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def predict_start_from_noise(self, x, t, pred_noise):
        # when predict_epsilon is False the model output is already x_0
        if self.predict_epsilon:
            return (extract(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
                    - extract(self.sqrt_recipm1_alphas_cumprod, t, x.shape) * pred_noise)
        return pred_noise
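
    # This inverts the forward-process identity x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps:
    #   x_0 = sqrt(1 / abar_t) * x_t - sqrt(1 / abar_t - 1) * eps
    # which is exactly what sqrt_recip_alphas_cumprod and sqrt_recipm1_alphas_cumprod encode.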


    def p_mean_variance(self, x, t, s):
        pred_noise = self.model(x, t, s)
        x_recon = self.predict_start_from_noise(x, t, pred_noise)
        if self.clip_denoised:
            x_recon.clamp_(-1, 1)  # trailing underscore marks an in-place op, no reassignment needed
        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_recon, x, t)
        return model_mean, posterior_log_variance  # using the log-variance is numerically more stable


    def p_sample(self, x, t, s):
        b, *_, device = *x.shape, x.device
        model_mean, model_log_variance = self.p_mean_variance(x, t, s)  # model-predicted mean and variance
        noise = torch.randn_like(x)

        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))  # add no noise at t == 0
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
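
    # One reverse step implements the DDPM ancestral-sampling update
    #   x_{t-1} = mu_theta(x_t, t) + sigma_t * z,  z ~ N(0, I),
    # with sigma_t = exp(0.5 * posterior_log_variance); nonzero_mask zeroes the
    # noise on the final step t = 0 so the returned value is just the mean.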

    def p_sample_loop(self, state, shape, *args, **kwargs):  # reverse sampling process
        device = self.device
        batch_size = state.shape[0]
        x = torch.randn(shape, device=device, requires_grad=False)  # start from pure Gaussian noise

        for i in reversed(range(0,self.T)):
            t = torch.full((batch_size,),i,device=device,dtype=torch.long)
            x = self.p_sample(x,t,state)

        return x


    def sample(self, state, *args, **kwargs):
        """
        state: [batch_size, state_dim]
        """
        batch_size = state.shape[0]
        shape = [batch_size, self.action_dim]  # shape of the initial noise
        action = self.p_sample_loop(state, shape, *args, **kwargs)  # the iterative DDPM denoising process
        return action.clamp_(-1, 1)  # keep the action inside the bounds the environment can accept

    # --------------------------------------------training------------------------------------------#

    def q_sample(self, x_start, t, noise):
        sample = (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )
        return sample
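
    # q_sample draws from the closed-form marginal of the forward process,
    #   q(x_t | x_0) = N(sqrt(abar_t) * x_0, (1 - abar_t) * I),
    # i.e. x_t = sqrt(abar_t) * x_0 + sqrt(1 - abar_t) * eps, eps ~ N(0, I),
    # so training never has to simulate the t-step noising chain.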
    def p_losses(self, x_start, state, t, weights=1.0):
        noise = torch.randn_like(x_start)
        x_noisy = self.q_sample(x_start, t, noise)
        x_recon = self.model(x_noisy, t, state)

        # the regression target depends on what the network is trained to predict
        target = noise if self.predict_epsilon else x_start
        loss = self.loss_fn(x_recon, target, weights)
        return loss


    def loss(self, x, state, weights=1.0):
        batch_size = len(x)
        t = torch.randint(0, self.T, (batch_size,), device=self.device).long()  # one random timestep per sample
        return self.p_losses(x, state, t, weights)

    def forward(self, state, *args, **kwargs):  # sampling entry point
        return self.sample(state, *args, **kwargs)




if __name__ == '__main__':
    device = "cpu"
    x = torch.randn(256, 2).to(device)
    state = torch.randn(256, 11).to(device)
    model = Diffusion(loss_type="L2", obs_dim=11, act_dim=2, hidden_dim=256, T=100, device=device)
    action = model(state)

    loss = model.loss(x, state)
    print(f"action: {action}; loss: {loss.item()}")
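
    # A minimal training-loop sketch (my addition, not from the original post):
    # it reuses the random (x, state) batch above as a stand-in for a real
    # offline dataset; the lr and step count are arbitrary illustrative choices.
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    for step in range(100):
        optimizer.zero_grad()
        loss = model.loss(x, state)
        loss.backward()
        optimizer.step()
        if step % 20 == 0:
            print(f"step {step}: loss {loss.item():.4f}")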



