from glob import glob
import paddle
from matplotlib import pyplot as plt
import numpy as np
from tqdm import tqdm
import cv2
def newton_method_arbitraryroot(a, n, initial_guess, tolerance=1e-8, max_iterations=1000, verbose=False):
    """
    Iteratively approximate the nth root of ``a``, recording every step.

    NOTE: the update used here is ``x <- x - (x**n - a) / n``, i.e. the
    Newton step with the derivative ``n * x**(n - 1)`` replaced by the
    constant ``n``. Steps therefore stay bounded (no division by a small
    ``x``), but a root ``r`` is an attracting fixed point only when
    ``|1 - r**(n - 1)| < 1``. For example, ``a=8, n=3`` (``r=2``) does not
    converge.

    All arithmetic is element-wise, so ``a`` and ``initial_guess`` may be
    scalars or broadcast-compatible NumPy arrays.

    :param a: The number (or array) whose nth root is approximated.
    :param n: The degree of the root to be calculated.
    :param initial_guess: The starting iterate.
    :param tolerance: Stop once the largest absolute element-wise change
        between two consecutive iterates is below this value.
    :param max_iterations: Maximum number of update steps performed.
    :param verbose: If True, print the largest change after each step.
    :return: ``(input_list, label, t_list)`` where ``input_list[i]`` is the
        iterate before step ``i + 1``, ``label[i]`` is the iterate after it,
        and ``t_list[i]`` is the 1-based step number.
    """
    x_old = initial_guess
    input_list = []
    label = []
    t_list = []
    for t in range(1, max_iterations + 1):
        x_new = x_old - (x_old ** n - a) / n
        input_list.append(x_old)
        # Each (input, label) pair is one step of the trajectory: the label
        # is the iterate that follows the input.
        label.append(x_new)
        t_list.append(t)
        # The largest absolute change is meaningful for scalars as well as
        # arrays. A variance would be 0 for a scalar, and also 0 when every
        # element moves by the same amount.
        step = np.max(np.abs(x_new - x_old))
        if verbose:
            print(step)
        if step < tolerance:
            break
        x_old = x_new
    return input_list, label, t_list
# Load the target image as a 64x64 grayscale array (imread flag 0 = grayscale).
image = cv2.imread('1.JPG', 0)
if image is None:
    # cv2.imread reports a missing or unreadable file by returning None rather
    # than raising; fail here with a clear message instead of inside cv2.resize.
    raise FileNotFoundError("could not read image '1.JPG'")
image = cv2.resize(image, (64, 64))
plt.imshow(image, cmap="gray")
plt.show()

# Use an odd root degree, because the noise could otherwise lead to complex values.
# The iteration starts from uniform noise and moves towards
# cube_root((image / 255) ** 3) == image / 255; every step is recorded and used
# as a training (input, label) pair below.
output, label, t_list = newton_method_arbitraryroot(
    (image / 255) ** 3,
    3,
    np.random.random(image.shape),
    max_iterations=20,
)

# Train the coordinate-diffusion model.
from u_net_one import UNet as SD
# SD(1, 1): presumably 1 input channel and 1 output channel. TODO confirm
# against u_net_one.UNet.
net = SD(1, 1)
opt = paddle.optimizer.AdamW(parameters=net.parameters(), learning_rate=0.0001)
loss_func = paddle.nn.MSELoss()

# The training pairs never change, so convert them to tensors once rather than
# on every epoch. Each recorded iterate becomes one 1-channel 64x64 sample, and
# the whole trajectory is trained as a single full batch.
train_inputs = paddle.to_tensor(output).astype("float32").reshape([-1, 1, 64, 64])
train_targets = paddle.to_tensor(label).astype("float32").reshape([1, -1])

bar = tqdm(range(100))
for epoch in bar:
    out = net(train_inputs)
    # Compare prediction and target as flat vectors.
    loss = loss_func(out.reshape([1, -1]), train_targets)
    bar.set_description("loss__{}_".format(loss.item()))
    opt.clear_grad()
    loss.backward()
    opt.step()
# Run the trained model autoregressively. Start from its prediction for the
# first recorded training input, then feed each prediction back in for 30 steps.
# This is inference only, so no autograd graph is built through the chained passes.
# NOTE(review): net is left in whatever mode training used. If the UNet has
# dropout or batch-norm layers, net.eval() may be wanted here; confirm.
with paddle.no_grad():
    out = net(paddle.to_tensor(output).astype("float32").reshape([-1, 1, 64, 64]))[0, 0, :, :]
    output = []
    for _ in range(30):
        output.append(out.reshape([64, 64]).numpy())
        out = net(out.astype("float32").reshape([-1, 1, 64, 64]))

# Animate the rollout frame by frame.
plt.ion()
for frame in output:
    # Min-max scale each frame to [0, 1] so the full uint8 range is used.
    # A constant frame would otherwise cause a division by zero.
    lo, hi = frame.min(), frame.max()
    span = hi - lo
    frame = (frame - lo) / span if span > 0 else np.zeros_like(frame)
    plt.imshow((frame * 255).astype(np.uint8), cmap="gray")
    plt.pause(0.1)
    plt.clf()
plt.ioff()