```python
def make_features(x, degree=3):
    """Build polynomial features [x, x**2, ..., x**degree] from sample points.

    Args:
        x: tensor of sample points. It is unsqueezed along dim 1 before the
            powers are taken, so a 1-D tensor of length N becomes an (N, 1)
            column.
        degree: highest power to include. Defaults to 3, which reproduces
            the original cubic feature set (three columns).

    Returns:
        For a 1-D input of length N, a tensor of shape (N, degree) whose
        column ``i`` holds ``x**(i + 1)``.

    Raises:
        ValueError: if ``degree`` is less than 1 (there would be no feature
            columns to concatenate).
    """
    if degree < 1:
        raise ValueError(f"degree must be >= 1, got {degree}")
    x = x.unsqueeze(1)
    # Concatenate the powers side by side along dim 1 (the feature axis).
    return torch.cat([x ** i for i in range(1, degree + 1)], 1)
def get_batch(batch_size=32):
    """Sample one random training batch.

    Draws ``batch_size`` points with ``torch.randn``, expands them into
    polynomial features via ``make_features`` and labels them with the
    target function ``f``.

    Returns:
        A ``(features, targets)`` tuple, where ``targets`` is ``f(features)``.
    """
    samples = torch.randn(batch_size)
    features = make_features(samples)
    targets = f(features)
    return features, targets
# Ground-truth parameters used by f() to generate training labels.
# Together with the features [x, x**2, x**3] the target is
#   y = 0.9 + 0.5*x + 3*x**2 + 2.4*x**3
# w is a (3, 1) column so an (N, 3) feature matrix can be multiplied by it
# with .mm(); b holds the scalar bias as a 1-element tensor.
# torch.tensor with an explicit dtype replaces the legacy torch.FloatTensor
# constructor and yields the same float32 values.
w = torch.tensor([0.5, 3, 2.4], dtype=torch.float32).unsqueeze(1)
b = torch.tensor([0.9], dtype=torch.float32)
def f(x):
    """Evaluate the ground-truth function on a feature matrix.

    Returns ``x.mm(w) + b[0]``, using the module-level coefficient column
    ``w`` and the scalar bias ``b[0]``. ``x`` must be 2-D because
    ``Tensor.mm`` only multiplies matrices.
    """
    product = x.mm(w)
    return product + b[0]
class Poly_model(nn.Module):
    """Polynomial regressor: one linear layer over 3 input features.

    The layer is stored as ``self.poly``, so its parameters appear in the
    state dict as ``poly.weight`` and ``poly.bias``. It maps inputs whose
    last dimension is 3 to a single output value.
    """

    def __init__(self):
        super().__init__()
        self.poly = nn.Linear(in_features=3, out_features=1)

    def forward(self, x):
        """Return the linear layer's output for ``x``."""
        return self.poly(x)
model = Poly_model()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)

# Keep drawing fresh random batches and taking one SGD step per batch until
# a batch's MSE (measured before its update) drops below 1e-3.
epochs_num = 0
while True:
    batch_x, batch_y = get_batch()
    output = model(batch_x)
    loss = criterion(output, batch_y)
    print_loss = loss.item()

    # Clear stale gradients, backpropagate, then update the parameters.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    epochs_num = epochs_num + 1

    if print_loss < 1e-3:
        print('训练完成')  # "Training complete"
        break
# Evaluate the trained model on the last training batch and plot the fit.
model.eval()
with torch.no_grad():  # inference only: no autograd graph is needed
    predict = model(batch_x).numpy()

# batch_x holds the polynomial features, one column per power; column 0 is
# the raw sample x, which is what belongs on the horizontal axis. Passing the
# whole feature matrix makes matplotlib draw a separate series per column.
# Sort by x so the fitted curve is drawn left-to-right instead of
# zig-zagging between randomly ordered samples.
order = torch.argsort(batch_x[:, 0]).numpy()
x_axis = batch_x[:, 0].numpy()[order]
plt.plot(x_axis, batch_y.numpy()[order], 'ro', label='Original curve')
plt.plot(x_axis, predict[order], label='Fitting curve')
plt.legend()  # the labels above are only displayed when a legend is drawn
plt.show()