前言
案例代码：https://github.com/2012Netsky/pytorch_cnn/blob/main/4_time_series_bikes.ipynb
#!/usr/bin/env python
# coding: utf-8
# In[1]:
# `%matplotlib inline` is an IPython magic. When this exported notebook is run
# as a plain Python script there is no IPython shell, so `get_ipython` is not
# defined; skip the magic in that case instead of aborting the whole script.
try:
    get_ipython().run_line_magic('matplotlib', 'inline')
except NameError:
    pass
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import torch.nn as nn
# Compact tensor printing: show only 2 items at the start/end of each
# dimension when summarizing, and wrap printed lines at 75 characters.
torch.set_printoptions(edgeitems=2, linewidth=75)
# In[3]:
# 61 evenly spaced sample points covering [-3, 3] inclusive (spacing 0.1).
# torch.linspace fixes both the number of points and both endpoints exactly.
# torch.arange with a non-integer step compares against an exclusive end and
# can gain or lose the last point through floating-point rounding.
input_t = torch.linspace(-3.0, 3.0, 61)
print(input_t.shape)
print(input_t)
# In[3]:
# Activation functions to compare side by side. Other candidates worth
# trying here: nn.Tanhshrink(), nn.Softshrink(), nn.Hardshrink().
activation_list = [
    nn.Tanh(),
    nn.Hardtanh(),
    nn.Sigmoid(),
    nn.Softplus(),
    nn.ReLU(),
    nn.LeakyReLU(negative_slope=0.1),
]

# Grid with 3 columns and only as many rows as the panels need. Sizing the
# grid with one row per activation would leave most of the figure empty.
# Panels are kept square: 14 inches wide split over n_cols columns.
n_cols = 3
n_rows = (len(activation_list) + n_cols - 1) // n_cols
panel_size = 14 / n_cols
fig = plt.figure(figsize=(14, panel_size * n_rows), dpi=600)

x_np = input_t.numpy()
for i, activation_func in enumerate(activation_list):
    subplot = fig.add_subplot(n_rows, n_cols, i + 1)
    subplot.set_title(type(activation_func).__name__)
    output_t = activation_func(input_t)
    subplot.grid()
    # Reference lines: identity y = x, then the x and y axes.
    subplot.plot(x_np, x_np, 'k', linewidth=1)
    subplot.plot([-3, 3], [0, 0], 'k', linewidth=1)
    subplot.plot([0, 0], [-3, 3], 'k', linewidth=1)
    # The activation's response over the sampled input range.
    subplot.plot(x_np, output_t.numpy(), 'r', linewidth=3)
# Bare expression so the notebook displays the last activation's output.
output_t
# In[4]:
# Four single "neurons" of the form tanh(w * x + b) with hand-picked
# weights and biases; the panels below show how summing and composing
# them builds more complex response curves.
def a(x):
    return torch.tanh(-2 * x - 1.25)


def b(x):
    return torch.tanh(1 * x + 0.75)


def c(x):
    return torch.tanh(4 * x + 1.)


def d(x):
    return torch.tanh(-3 * x + 1.5)


# Output of the "first layer" (A + B), reused by the composed panels below.
ab_t = a(input_t) + b(input_t)
wb_list = [
    ('A: tanh(-2 * x - 1.25)', a(input_t)),
    ('B: tanh(1 * x + 0.75)', b(input_t)),
    ('A + B', ab_t),
    ('C: tanh(4 * x + 1.0)', c(input_t)),
    # Title must match d() above: the bias is +1.5.
    ('D: tanh(-3 * x + 1.5)', d(input_t)),
    ('C + D', c(input_t) + d(input_t)),
    ('C(A + B)', c(ab_t)),
    ('D(A + B)', d(ab_t)),
    ('C(A+B) + D(A+B)', c(ab_t) + d(ab_t)),
]

# Grid with 3 columns and only as many rows as the panels need
# (9 panels -> 3 rows), with square panels 14 / n_cols inches on a side.
n_cols = 3
n_rows = (len(wb_list) + n_cols - 1) // n_cols
panel_size = 14 / n_cols
fig = plt.figure(figsize=(14, panel_size * n_rows), dpi=600)

x_np = input_t.numpy()
for i, (title_str, output_t) in enumerate(wb_list):
    subplot = fig.add_subplot(n_rows, n_cols, i + 1)
    subplot.set_title(title_str)
    subplot.grid()
    # Reference lines: identity y = x, then the x and y axes.
    subplot.plot(x_np, x_np, 'k', linewidth=1)
    subplot.plot([-3, 3], [0, 0], 'k', linewidth=1)
    subplot.plot([0, 0], [-3, 3], 'k', linewidth=1)
    # The combined unit's response over the sampled input range.
    subplot.plot(x_np, output_t.numpy(), 'r', linewidth=3)
# Bare expression so the notebook displays the last panel's output.
output_t
# In[5]:
# One large, grid-less plot of tanh over the sampled inputs, drawn on top of
# the identity line and the coordinate axes.
fig = plt.figure(figsize=(14, 14), dpi=600)
subplot = fig.add_subplot(1, 1, 1)
output_t = torch.tanh(input_t)
x_np = input_t.numpy()
# (x values, y values, format, line width) in drawing order.
for xs, ys, fmt, width in (
    (x_np, x_np, 'k', 1),               # identity y = x
    ([-3, 3], [0, 0], 'k', 1),          # x axis
    ([0, 0], [-3, 3], 'k', 1),          # y axis
    (x_np, output_t.numpy(), 'r', 3),   # tanh response
):
    subplot.plot(xs, ys, fmt, linewidth=width)
# Bare expression so the notebook displays the tanh output tensor.
output_t
# In[ ]: