@register_predictor(name='euler_maruyama')
class EulerMaruyamaPredictor(Predictor):
    """Predictor that advances a sample with one Euler-Maruyama step of the reverse SDE."""

    def __init__(self, sde, score_fn, probability_flow=False):
        super().__init__(sde, score_fn, probability_flow)

    def update_fn(self, x, t):
        """Take a single Euler-Maruyama step of size ``-1 / self.rsde.N``.

        Args:
          x: current sample batch. The diffusion coefficient is expanded with
            three trailing singleton axes before being multiplied with noise
            shaped like `x`, so `x` is presumably 4-D -- confirm with callers.
          t: time value(s) forwarded unchanged to ``self.rsde.sde``.

        Returns:
          A tuple ``(x, x_mean)``: the updated sample with noise added, and
          the same update without the noise term.
        """
        # The step is negative: the reverse SDE is integrated backwards.
        step = -1. / self.rsde.N
        # Noise is drawn before the SDE coefficients are evaluated.
        noise = torch.randn_like(x)
        drift, diffusion = self.rsde.sde(x, t)
        x_mean = x + drift * step
        # Per-sample diffusion, scaled by sqrt(|step|).
        noise_scale = diffusion[:, None, None, None] * np.sqrt(-step)
        x = x_mean + noise_scale * noise
        return x, x_mean
@register_predictor(name='reverse_diffusion')
class ReverseDiffusionPredictor(Predictor):
    """Predictor that steps a sample using the discretized reverse SDE."""

    def __init__(self, sde, score_fn, probability_flow=False):
        super().__init__(sde, score_fn, probability_flow)

    def update_fn(self, x, t):
        """Apply one step built from ``self.rsde.discretize(x, t)``.

        Args:
          x: current sample batch. The second value returned by `discretize`
            is expanded with three trailing singleton axes before multiplying
            noise shaped like `x`, so `x` is presumably 4-D -- confirm with
            callers.
          t: time value(s) forwarded unchanged to ``self.rsde.discretize``.

        Returns:
          A tuple ``(x, x_mean)``: the updated sample with noise added, and
          the same update without the noise term.
        """
        drift_step, noise_scale = self.rsde.discretize(x, t)
        # Noise is drawn after the discretized coefficients are computed.
        noise = torch.randn_like(x)
        x_mean = x - drift_step
        x = x_mean + noise_scale[:, None, None, None] * noise
        return x, x_mean
@register_predictor(name='ancestral_sampling')
class AncestralSamplingPredictor(Predictor):
"""The ancestral sampling predictor. Currently only supports VE/VP SDEs."""
def __init__(self, sde, score_fn, probability_flow=False):
    """Initialize the predictor and reject unsupported configurations.

    Args:
      sde: the SDE object; must be an instance of `sde_lib.VPSDE` or
        `sde_lib.VESDE`.
      score_fn: forwarded to the parent constructor.
      probability_flow: forwarded to the parent constructor; must be falsy.

    Raises:
      NotImplementedError: if `sde` is neither a VPSDE nor a VESDE.
      AssertionError: if `probability_flow` is truthy.
    """
    super().__init__(sde, score_fn, probability_flow)
    supported_sdes = (sde_lib.VPSDE, sde_lib.VESDE)
    if not isinstance(sde, supported_sdes):
        raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")
    assert not probability_flow, "Probability flow not supported by ancestral sampling"
def vesde_update_fn(self, x, t):
    """Run one ancestral-sampling step using the discrete sigma schedule.

    Corresponds to Eq. (47) in the paper (per the original author's note).

    Args:
      x: current sample batch. Per-sample scalars are expanded with three
        trailing singleton axes before combining with `x`, so `x` is
        presumably 4-D -- confirm with callers.
      t: tensor of times, used both to look up the discrete noise level and
        as input to ``self.score_fn``; assumed to hold one entry per sample.

    Returns:
      A tuple ``(x, x_mean)``: the updated sample with noise added, and the
      same update without the noise term.
    """
    sde = self.sde
    # Move the schedule to t's device once and use it for BOTH lookups.
    # Previously only the adjacent-sigma lookup was moved, which mixes
    # devices when the schedule is on CPU and t is on an accelerator.
    discrete_sigmas = sde.discrete_sigmas.to(t.device)
    timestep = (t * (sde.N - 1) / sde.T).long()
    sigma = discrete_sigmas[timestep]
    # At timestep 0 the index `timestep - 1` wraps to the last entry; that
    # value is discarded by torch.where in favour of zero.
    adjacent_sigma = torch.where(
        timestep == 0,
        torch.zeros_like(t),
        discrete_sigmas[timestep - 1],
    )
    score = self.score_fn(x, t)
    x_mean = x + score * (sigma ** 2 - adjacent_sigma ** 2)[:, None, None, None]
    std = torch.sqrt((adjacent_sigma ** 2 * (sigma ** 2 - adjacent_sigma ** 2)) / (sigma ** 2))
    noise = torch.randn_like(x)
    x = x_mean + std[:, None, None, None] * noise
    return x, x_mean
def vpsde_update_fn(self, x, t):
sde = self.sde
timestep = (t * (sde.N - 1) / sde.T).long()
beta = sde.discrete_betas.to(t.device)[timestep]
score = self.score_fn(x, t)
x_mean = (x + beta[:, None, None, None] * score) / torch.sqrt(1. - beta)[:, None, None, None]
noise = torch.randn_like(x)
x = x_mean + torch.sqrt(beta)[:,