import matplotlib.pyplot as plt
import torch
from torch import nn
x = torch.linspace(-6, 6, 10)
sigmoid = nn.Sigmoid() # sigmoid激活函数
ysigmoid = sigmoid(x)
tanh = nn.Tanh() # tanh激活函数
ytanh = tanh(x)
relu = nn.ReLU() # Relu激活函数
yrelu = relu(x)
softplus = nn.Softplus()
ysoftplus = softplus(x)
plt.figure(figsize=(14,3)) #可视化激活函数
plt.subplot(1,4,1)
plt.plot(x.data.numpy(), ysigmoid.data.numpy(), "r-")
plt.title("Sigmoid")
plt.grid()
plt.subplot(1,4,2)
plt.plot(x.data.numpy(), ytanh.data.numpy(), "r-")
plt.title("Tanh")
plt.grid()
plt.subplot(1,4,3)
plt.plot(x.data.numpy(), yrelu.data.numpy(), "r-")
plt.title("Relu")
plt.grid()
plt.subplot(1,4,4)
plt.plot(x.data.numpy(), ysoftplus.data.numpy(), "r-")
plt.title("softplus")
plt.show()