【365 Plan - 4】PyTorch Monkeypox Recognition

🍨 This post is a learning-record entry for the 🔗 365天深度学习训练营 (365-day deep learning training camp)
🍦 Reference article: 365天深度学习训练营 - Week P4: Monkeypox Recognition
🍖 Author: K同学啊

### This project follows K同学's online guidance ###
Dataset download: https://pan.baidu.com/s/11r_uOUV0ToMNXQtxahb0yg?pwd=7qtp
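
Note: torchvision.datasets.ImageFolder (used in localDataset below) takes the class names from the sub-folder names, so the archive is assumed to be unpacked into a layout like the one sketched here. The Monkeypox folder name comes from the prediction example at the end of the script; the name of the second (non-monkeypox) folder is only a placeholder:

./data/
├── Monkeypox/
│   ├── M01_01_06.jpg
│   └── ...
└── Others/          <- placeholder name for the non-monkeypox class
    └── ...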

import torch
import os,PIL,random,pathlib,warnings
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import torchvision,time
import torch.nn as nn
from torchsummary import summary
from torchvision.models import vgg16

num_class=2
def localDataset(data_dir):
    data_dir=pathlib.Path(data_dir)
    data_paths=list(data_dir.glob('*'))
    classNames=[path.name for path in data_paths]  # class names are the sub-folder names
    print("className:",classNames,'\n')
    train_transforms=torchvision.transforms.Compose([
        torchvision.transforms.Resize([224,224]),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485,0.456,0.406],
            std=[0.229,0.224,0.225])
    ])
    total_data=torchvision.datasets.ImageFolder(data_dir,transform=train_transforms)
    print(total_data,'\n')

    train_size=int(0.8*len(total_data))
    test_size=len(total_data)-train_size
    print("Train_size:",train_size,"Test_size",test_size,'\n')
    train_data,test_data=torch.utils.data.random_split(total_data,[train_size,test_size])
    return classNames,train_data,test_data

def displayData(imgs,root,show_flag):
    plt.figure(figsize=(20,5))
    for i,img in enumerate(imgs[:20]):
        npimg=img.numpy().transpose(1,2,0)   # CHW -> HWC for matplotlib
        plt.subplot(2,10,i+1)                # show 20 samples on a 2x10 grid
        plt.imshow(npimg,cmap=plt.cm.binary)
        plt.axis('off')
    plt.savefig(os.path.join(root,'DatasetDisplay.png'))
    if show_flag:
        plt.show()
    else:
        plt.close('all')

def loadData(train_ds,test_ds,batch_size,root='output',show_flag=False):
    train_dl=torch.utils.data.DataLoader(train_ds,batch_size=batch_size,shuffle=True)
    test_dl=torch.utils.data.DataLoader(test_ds,batch_size=batch_size)
    for x,y in train_dl:
        print("shape of x [N,C,H,W]:",x.shape)
        print("shape of y:",y.shape)
        break
    imgs,labels=next(iter(train_dl))
    print("Image shape:",imgs.shape,'\n')
    if not os.path.isdir(root):
        os.mkdir(root)
    displayData(imgs,root,show_flag=show_flag)
    return train_dl,test_dl

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1=nn.Sequential(
            nn.Conv2d(3,12,kernel_size=5,padding=0),
            nn.BatchNorm2d(12),
            nn.ReLU()
        )
        self.conv2=nn.Sequential(
            nn.Conv2d(12,12,kernel_size=5,padding=0),
            nn.BatchNorm2d(12),
            nn.ReLU()
        )
        self.pool3=nn.Sequential(
            nn.MaxPool2d(2),
            nn.Dropout(p=0.2)
        )
        self.conv4=nn.Sequential(
            nn.Conv2d(12,24,kernel_size=5,padding=0),
            nn.BatchNorm2d(24),
            nn.ReLU()
        )
        self.conv5=nn.Sequential(
            nn.Conv2d(24,24,kernel_size=5,padding=0),
            nn.BatchNorm2d(24),
            nn.ReLU()
        )
        self.pool6=nn.Sequential(
            nn.MaxPool2d(2),
            nn.Dropout(p=0.2)
        )

        self.conv7=nn.Sequential(
            nn.Conv2d(24,48,kernel_size=5,padding=0),
            nn.BatchNorm2d(48),
            nn.ReLU()
        )
        self.conv8=nn.Sequential(
            nn.Conv2d(48,48,kernel_size=5,padding=0),
            nn.BatchNorm2d(48),
            nn.ReLU()
        )
        self.pool9=nn.Sequential(
            nn.MaxPool2d(2),
            nn.Dropout(p=0.2)
        )
        self.fc=nn.Sequential(
            nn.Linear(48*21*21,num_class)  # 21x21 feature map after three conv+pool stages on a 224x224 input
        )

    def forward(self,x):
        batch_size=x.size(0)
        x=self.conv1(x)
        x=self.conv2(x)
        x=self.pool3(x)
        x=self.conv4(x)
        x=self.conv5(x)
        x=self.pool6(x)
        x=self.conv7(x)
        x=self.conv8(x)
        x=self.pool9(x)
        x=x.view(batch_size,-1)
        x=self.fc(x)
        return x


class Model_vgg16(nn.Module):
    def __init__(self):
        super(Model_vgg16, self).__init__()
        self.sequ1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),  # 64*224*224
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),  # 64*224*224
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),  # 64*112*112
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),  # 128*112*112
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),  # 128*112*112
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),  # 128*56*56
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),  # 256*56*56
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),  # 256*56*56
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),  # 256*56*56
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),  # 256*28*28
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),  # 512*28*28
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),  # 512*28*28
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),  # 512*28*28
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),  # 512*14*14
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),  # 512*14*14
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),  # 512*14*14
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),  # 512*14*14
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)  # 512*7*7
        )
        self.pool2 = nn.AdaptiveAvgPool2d(output_size=(7, 7))  # 512*7*7
        self.sequ3 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=25088, out_features=4096, bias=True),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5, inplace=False),
            nn.Linear(in_features=4096, out_features=4096, bias=True),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5, inplace=False),
            nn.Linear(in_features=4096, out_features=17,bias=True),
            nn.Linear(17,2)
        )

    def forward(self, x):
        x = self.sequ1(x)
        x = self.pool2(x)
        x=self.sequ3(x)
        return x


def autopad(k, p=None):  # kernel, padding
    # Pad to 'same'
    if p is None:
        p = k // 2 if isinstance(k, int) else [x // 2 for x in k]  # auto-pad
    return p


class Conv(nn.Module):
    # Standard convolution
    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):  # ch_in, ch_out, kernel, stride, padding, groups
        super().__init__()
        self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        return self.act(self.conv(x))


class Bottleneck(nn.Module):
    # Standard bottleneck
    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c_, c2, 3, 1, g=g)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        #print("Bottleneck")
        return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))


class C3(nn.Module):
    # CSP Bottleneck with 3 convolutions
    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        c_ = int(c2 * e)  # hidden channels
        self.cv1 = Conv(c1, c_, 1, 1)
        self.cv2 = Conv(c1, c_, 1, 1)
        self.cv3 = Conv(2 * c_, c2, 1)  # act=FReLU(c2)
        self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))

    def forward(self, x):
        return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))


class Model_K(nn.Module):
    def __init__(self):
        super(Model_K, self).__init__()

        # Convolution block
        self.Conv = Conv(3, 32, 3, 2)

        # C3 block 1 (n=1 bottleneck; the trailing positional argument maps to shortcut)
        self.C3_1 = C3(32, 64, 1, 2)

        # Fully connected layers for classification
        self.classifier = nn.Sequential(
            nn.Linear(in_features=802816, out_features=100),
            nn.ReLU(),
            nn.Linear(in_features=100, out_features=2)
        )

    def forward(self, x):
        x = self.Conv(x)
        x = self.C3_1(x)
        x = torch.flatten(x, start_dim=1)
        x = self.classifier(x)

        return x


def train(train_dl,model,loss_fn,opt):
    size=len(train_dl.dataset)
    num_batches=len(train_dl)
    train_acc,train_loss=0,0
    for x,y in train_dl:
        x,y=x.to(device),y.to(device)
        pre=model(x)
        loss=loss_fn(pre,y)
        opt.zero_grad()
        loss.backward()
        opt.step()

        train_acc +=(pre.argmax(1)==y).type(torch.float).sum().item()
        train_loss +=loss.item()
    train_acc/=size
    train_loss/=num_batches
    return train_acc,train_loss

def test(test_dl,model,loss_fn):
    size=len(test_dl.dataset)
    num_batches=len(test_dl)
    test_acc,test_loss=0,0
    with torch.no_grad():
        for x, y in test_dl:
            x, y = x.to(device), y.to(device)
            pre = model(x)
            loss = loss_fn(pre, y)
            test_acc += (pre.argmax(1) == y).type(torch.float).sum().item()
            test_loss += loss.item()
        test_acc /= size
        test_loss /= num_batches
        return test_acc, test_loss

def displayResult(train_acc,test_acc,train_loss,test_loss,start_epoch,epochs,output):
    epochs_range=range(start_epoch,epochs)
    plt.figure(figsize=(20,5))
    plt.subplot(1,2,1)
    plt.plot(epochs_range,train_acc,label="Train_acc")
    plt.plot(epochs_range,test_acc,label="Test_acc")
    plt.legend(loc="lower right")
    plt.title("Train and Test Acc")
    plt.subplot(1,2,2)
    plt.plot(epochs_range,train_loss,label="Train_loss")
    plt.plot(epochs_range,test_loss,label="Test_loss")
    plt.legend(loc="upper right")
    plt.title("Train and Test Loss")
    if not os.path.exists(output) or not os.path.isdir(output):
        os.mkdir(output)
    plt.savefig(os.path.join(output,"Accuracyloss.png"))
    plt.show()

def predict(model, img_path, classeNames):
    img = Image.open(img_path)
    train_transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize([224, 224]),  # resize the input image to a uniform size
        torchvision.transforms.ToTensor(),  # convert PIL Image / numpy.ndarray to a tensor scaled to [0, 1]
        torchvision.transforms.Normalize(  # standardize the channels so the model converges more easily
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225])  # mean/std estimated from randomly sampled images
    ])
    img = train_transforms(img)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    img = img.to(device).unsqueeze(0)
    model.eval()
    with torch.no_grad():
        output = model(img)

    _, indices = torch.max(output, 1)
    percentage = torch.nn.functional.softmax(output, dim=1)[0] * 100
    perc = percentage[int(indices)].item()
    result = classeNames[int(indices)]
    print('predicted:', result, perc)

def save_file(output,model,epoche='best'):
    saveFile=os.path.join(output,'epoch'+str(epoche)+'.pkl')
    torch.save(model.state_dict(),saveFile)

if __name__=="__main__":
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Using Device is {}".format(device))
    data_dir="./data"
    output='./output_vgg'
    classNames,train_ds,test_ds=localDataset(data_dir)  # localDataset returns the class-name list
    batch_size=64
    train_dl,test_dl=loadData(train_ds,test_ds,batch_size=batch_size,root=output,show_flag=True)
    epoches=100
    start_epoch=0
    train_acc=[]
    train_loss=[]
    test_acc=[]
    test_loss=[]
    best_acc=0.0
    #model=Model().to(device)
    model=Model_K().to(device)
    summary(model,(3,224,224))
    loss_fn=nn.CrossEntropyLoss()
    learn_rate=1e-3
    opt=torch.optim.SGD(model.parameters(),lr=learn_rate)
    print("---Starting train---")
    for epoche in range(start_epoch,epoches):
        model.train()
        epoch_train_acc,epoch_train_loss=train(train_dl,model,loss_fn,opt)
        model.eval()
        epoch_test_acc,epoch_test_loss=test(test_dl,model,loss_fn)
        train_acc.append(epoch_train_acc)
        train_loss.append(epoch_train_loss)
        test_acc.append(epoch_test_acc)
        test_loss.append(epoch_test_loss)
        if(epoch_test_acc>best_acc):
            best_acc=epoch_test_acc
            save_file(output, model)
        template=('Epoch:{:2d}/{:2d},train_acc:{:.1f}%,train_loss:{:.3f},test_acc{:.1f}%,test_loss{:.3f},best_acc:{:.1f}%')
        print(time.strftime('[%Y-%m-%d %H:%M:%S)]'),template.format(epoche+1,epoches,epoch_train_acc*100,epoch_train_loss,epoch_test_acc*100,epoch_test_loss,best_acc*100))
    print("Done")
    displayResult(train_acc,test_acc,train_loss,test_loss,start_epoch,epoches,output)
    imgs_path='./data/Monkeypox/M01_01_06.jpg'
    predict(model,imgs_path,classeNames=classNames)
    # save_file(output,model,epoches)
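    # Optional sketch: predict() above runs with the final-epoch weights. To use the best
    # checkpoint written by save_file() instead (saved as 'epochbest.pkl' with the default
    # epoche='best'), the state_dict could be reloaded first, e.g.:
    #   model.load_state_dict(torch.load(os.path.join(output,'epochbest.pkl'),map_location=device))
    #   model.eval()
    #   predict(model,imgs_path,classeNames=classNames)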

Network parameter summary (torchsummary output):

        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 32, 112, 112]             864
       BatchNorm2d-2         [-1, 32, 112, 112]              64
              SiLU-3         [-1, 32, 112, 112]               0
              Conv-4         [-1, 32, 112, 112]               0
            Conv2d-5         [-1, 32, 112, 112]           1,024
       BatchNorm2d-6         [-1, 32, 112, 112]              64
              SiLU-7         [-1, 32, 112, 112]               0
              Conv-8         [-1, 32, 112, 112]               0
            Conv2d-9         [-1, 32, 112, 112]           1,024
      BatchNorm2d-10         [-1, 32, 112, 112]              64
             SiLU-11         [-1, 32, 112, 112]               0
             Conv-12         [-1, 32, 112, 112]               0
           Conv2d-13         [-1, 32, 112, 112]           9,216
      BatchNorm2d-14         [-1, 32, 112, 112]              64
             SiLU-15         [-1, 32, 112, 112]               0
             Conv-16         [-1, 32, 112, 112]               0
       Bottleneck-17         [-1, 32, 112, 112]               0
           Conv2d-18         [-1, 32, 112, 112]           1,024
      BatchNorm2d-19         [-1, 32, 112, 112]              64
             SiLU-20         [-1, 32, 112, 112]               0
             Conv-21         [-1, 32, 112, 112]               0
           Conv2d-22         [-1, 64, 112, 112]           4,096
      BatchNorm2d-23         [-1, 64, 112, 112]             128
             SiLU-24         [-1, 64, 112, 112]               0
             Conv-25         [-1, 64, 112, 112]               0
               C3-26         [-1, 64, 112, 112]               0
           Linear-27                  [-1, 100]      80,281,700
             ReLU-28                  [-1, 100]               0
           Linear-29                    [-1, 2]             202

Training log:
---Starting train---
[2023-04-01 10:23:57)] Epoch: 1/100,train_acc:66.4%,train_loss:0.611,test_acc72.7%,test_loss0.551,best_acc:72.7%
[2023-04-01 10:24:02)] Epoch: 2/100,train_acc:79.8%,train_loss:0.446,test_acc77.6%,test_loss0.485,best_acc:77.6%
[2023-04-01 10:24:08)] Epoch: 3/100,train_acc:87.2%,train_loss:0.330,test_acc81.8%,test_loss0.415,best_acc:81.8%
[2023-04-01 10:24:13)] Epoch: 4/100,train_acc:90.3%,train_loss:0.282,test_acc85.5%,test_loss0.352,best_acc:85.5%
[2023-04-01 10:24:19)] Epoch: 5/100,train_acc:92.8%,train_loss:0.228,test_acc86.9%,test_loss0.348,best_acc:86.9%
[2023-04-01 10:24:24)] Epoch: 6/100,train_acc:95.3%,train_loss:0.192,test_acc88.1%,test_loss0.335,best_acc:88.1%
[2023-04-01 10:24:29)] Epoch: 7/100,train_acc:95.9%,train_loss:0.163,test_acc88.1%,test_loss0.301,best_acc:88.1%
[2023-04-01 10:24:34)] Epoch: 8/100,train_acc:96.7%,train_loss:0.151,test_acc87.4%,test_loss0.326,best_acc:88.1%
[2023-04-01 10:24:39)] Epoch: 9/100,train_acc:97.1%,train_loss:0.135,test_acc88.1%,test_loss0.297,best_acc:88.1%
[2023-04-01 10:24:44)] Epoch:10/100,train_acc:98.0%,train_loss:0.114,test_acc88.1%,test_loss0.323,best_acc:88.1%
[2023-04-01 10:24:50)] Epoch:11/100,train_acc:98.0%,train_loss:0.110,test_acc89.5%,test_loss0.297,best_acc:89.5%
[2023-04-01 10:24:55)] Epoch:12/100,train_acc:98.4%,train_loss:0.096,test_acc87.9%,test_loss0.300,best_acc:89.5%
[2023-04-01 10:25:00)] Epoch:13/100,train_acc:99.1%,train_loss:0.082,test_acc88.6%,test_loss0.291,best_acc:89.5%
[2023-04-01 10:25:05)] Epoch:14/100,train_acc:98.7%,train_loss:0.083,test_acc88.3%,test_loss0.300,best_acc:89.5%
[2023-04-01 10:25:09)] Epoch:15/100,train_acc:98.9%,train_loss:0.076,test_acc88.6%,test_loss0.311,best_acc:89.5%
[2023-04-01 10:25:14)] Epoch:16/100,train_acc:99.2%,train_loss:0.070,test_acc88.3%,test_loss0.278,best_acc:89.5%
[2023-04-01 10:25:19)] Epoch:17/100,train_acc:99.4%,train_loss:0.063,test_acc88.3%,test_loss0.274,best_acc:89.5%
[2023-04-01 10:25:24)] Epoch:18/100,train_acc:99.5%,train_loss:0.058,test_acc89.3%,test_loss0.275,best_acc:89.5%
[2023-04-01 10:25:29)] Epoch:19/100,train_acc:99.5%,train_loss:0.054,test_acc88.8%,test_loss0.281,best_acc:89.5%
[2023-04-01 10:25:34)] Epoch:20/100,train_acc:99.6%,train_loss:0.050,test_acc89.0%,test_loss0.283,best_acc:89.5%
[2023-04-01 10:25:39)] Epoch:21/100,train_acc:99.7%,train_loss:0.046,test_acc88.3%,test_loss0.279,best_acc:89.5%
[2023-04-01 10:25:45)] Epoch:22/100,train_acc:99.6%,train_loss:0.045,test_acc89.0%,test_loss0.281,best_acc:89.5%
[2023-04-01 10:25:50)] Epoch:23/100,train_acc:99.6%,train_loss:0.044,test_acc89.5%,test_loss0.280,best_acc:89.5%
[2023-04-01 10:25:55)] Epoch:24/100,train_acc:99.9%,train_loss:0.039,test_acc89.0%,test_loss0.282,best_acc:89.5%
[2023-04-01 10:26:00)] Epoch:25/100,train_acc:99.7%,train_loss:0.039,test_acc88.8%,test_loss0.283,best_acc:89.5%
[2023-04-01 10:26:05)] Epoch:26/100,train_acc:99.8%,train_loss:0.036,test_acc88.3%,test_loss0.284,best_acc:89.5%
[2023-04-01 10:26:10)] Epoch:27/100,train_acc:99.9%,train_loss:0.034,test_acc89.0%,test_loss0.285,best_acc:89.5%
[2023-04-01 10:26:16)] Epoch:28/100,train_acc:99.8%,train_loss:0.032,test_acc89.7%,test_loss0.289,best_acc:89.7%
[2023-04-01 10:26:21)] Epoch:29/100,train_acc:99.9%,train_loss:0.031,test_acc89.5%,test_loss0.289,best_acc:89.7%
[2023-04-01 10:26:27)] Epoch:30/100,train_acc:99.9%,train_loss:0.031,test_acc90.0%,test_loss0.289,best_acc:90.0%
[2023-04-01 10:26:32)] Epoch:31/100,train_acc:99.9%,train_loss:0.028,test_acc90.0%,test_loss0.286,best_acc:90.0%
[2023-04-01 10:26:37)] Epoch:32/100,train_acc:99.9%,train_loss:0.027,test_acc89.3%,test_loss0.293,best_acc:90.0%
[2023-04-01 10:26:43)] Epoch:33/100,train_acc:100.0%,train_loss:0.025,test_acc89.7%,test_loss0.292,best_acc:90.0%
[2023-04-01 10:26:48)] Epoch:34/100,train_acc:100.0%,train_loss:0.024,test_acc90.0%,test_loss0.296,best_acc:90.0%
[2023-04-01 10:26:53)] Epoch:35/100,train_acc:100.0%,train_loss:0.023,test_acc88.8%,test_loss0.293,best_acc:90.0%
[2023-04-01 10:26:59)] Epoch:36/100,train_acc:100.0%,train_loss:0.023,test_acc88.6%,test_loss0.295,best_acc:90.0%
[2023-04-01 10:27:04)] Epoch:37/100,train_acc:100.0%,train_loss:0.021,test_acc89.3%,test_loss0.296,best_acc:90.0%
[2023-04-01 10:27:09)] Epoch:38/100,train_acc:99.9%,train_loss:0.022,test_acc89.0%,test_loss0.295,best_acc:90.0%
[2023-04-01 10:27:14)] Epoch:39/100,train_acc:100.0%,train_loss:0.021,test_acc88.6%,test_loss0.297,best_acc:90.0%
[2023-04-01 10:27:20)] Epoch:40/100,train_acc:100.0%,train_loss:0.019,test_acc88.6%,test_loss0.295,best_acc:90.0%
[2023-04-01 10:27:25)] Epoch:41/100,train_acc:100.0%,train_loss:0.019,test_acc89.0%,test_loss0.293,best_acc:90.0%
[2023-04-01 10:27:30)] Epoch:42/100,train_acc:100.0%,train_loss:0.019,test_acc88.8%,test_loss0.296,best_acc:90.0%
[2023-04-01 10:27:36)] Epoch:43/100,train_acc:100.0%,train_loss:0.018,test_acc89.0%,test_loss0.299,best_acc:90.0%
[2023-04-01 10:27:41)] Epoch:44/100,train_acc:100.0%,train_loss:0.018,test_acc88.6%,test_loss0.298,best_acc:90.0%
[2023-04-01 10:27:46)] Epoch:45/100,train_acc:100.0%,train_loss:0.015,test_acc88.8%,test_loss0.301,best_acc:90.0%
[2023-04-01 10:27:51)] Epoch:46/100,train_acc:100.0%,train_loss:0.016,test_acc88.8%,test_loss0.302,best_acc:90.0%
[2023-04-01 10:27:57)] Epoch:47/100,train_acc:100.0%,train_loss:0.015,test_acc89.0%,test_loss0.304,best_acc:90.0%
[2023-04-01 10:28:02)] Epoch:48/100,train_acc:100.0%,train_loss:0.015,test_acc88.8%,test_loss0.305,best_acc:90.0%
[2023-04-01 10:28:07)] Epoch:49/100,train_acc:100.0%,train_loss:0.014,test_acc89.5%,test_loss0.303,best_acc:90.0%
[2023-04-01 10:28:13)] Epoch:50/100,train_acc:100.0%,train_loss:0.014,test_acc88.6%,test_loss0.304,best_acc:90.0%
[2023-04-01 10:28:18)] Epoch:51/100,train_acc:100.0%,train_loss:0.013,test_acc88.3%,test_loss0.310,best_acc:90.0%
[2023-04-01 10:28:23)] Epoch:52/100,train_acc:100.0%,train_loss:0.014,test_acc89.0%,test_loss0.308,best_acc:90.0%
[2023-04-01 10:28:29)] Epoch:53/100,train_acc:100.0%,train_loss:0.014,test_acc88.3%,test_loss0.306,best_acc:90.0%
[2023-04-01 10:28:34)] Epoch:54/100,train_acc:100.0%,train_loss:0.013,test_acc88.3%,test_loss0.308,best_acc:90.0%
[2023-04-01 10:28:40)] Epoch:55/100,train_acc:100.0%,train_loss:0.013,test_acc88.3%,test_loss0.308,best_acc:90.0%
[2023-04-01 10:28:45)] Epoch:56/100,train_acc:100.0%,train_loss:0.012,test_acc88.3%,test_loss0.311,best_acc:90.0%
[2023-04-01 10:28:51)] Epoch:57/100,train_acc:100.0%,train_loss:0.012,test_acc89.3%,test_loss0.311,best_acc:90.0%
[2023-04-01 10:28:56)] Epoch:58/100,train_acc:100.0%,train_loss:0.012,test_acc89.3%,test_loss0.311,best_acc:90.0%
[2023-04-01 10:29:01)] Epoch:59/100,train_acc:100.0%,train_loss:0.012,test_acc89.3%,test_loss0.314,best_acc:90.0%
[2023-04-01 10:29:07)] Epoch:60/100,train_acc:100.0%,train_loss:0.010,test_acc88.8%,test_loss0.315,best_acc:90.0%
[2023-04-01 10:29:12)] Epoch:61/100,train_acc:100.0%,train_loss:0.011,test_acc88.6%,test_loss0.318,best_acc:90.0%
[2023-04-01 10:29:18)] Epoch:62/100,train_acc:100.0%,train_loss:0.011,test_acc89.5%,test_loss0.316,best_acc:90.0%
[2023-04-01 10:29:23)] Epoch:63/100,train_acc:100.0%,train_loss:0.011,test_acc88.8%,test_loss0.317,best_acc:90.0%
[2023-04-01 10:29:29)] Epoch:64/100,train_acc:100.0%,train_loss:0.010,test_acc88.3%,test_loss0.316,best_acc:90.0%
[2023-04-01 10:29:34)] Epoch:65/100,train_acc:100.0%,train_loss:0.010,test_acc89.0%,test_loss0.322,best_acc:90.0%
[2023-04-01 10:29:40)] Epoch:66/100,train_acc:100.0%,train_loss:0.010,test_acc88.6%,test_loss0.323,best_acc:90.0%
[2023-04-01 10:29:45)] Epoch:67/100,train_acc:99.9%,train_loss:0.010,test_acc88.8%,test_loss0.322,best_acc:90.0%
[2023-04-01 10:29:51)] Epoch:68/100,train_acc:100.0%,train_loss:0.009,test_acc89.3%,test_loss0.323,best_acc:90.0%
[2023-04-01 10:29:56)] Epoch:69/100,train_acc:100.0%,train_loss:0.009,test_acc89.0%,test_loss0.321,best_acc:90.0%
[2023-04-01 10:30:02)] Epoch:70/100,train_acc:100.0%,train_loss:0.009,test_acc89.3%,test_loss0.321,best_acc:90.0%
[2023-04-01 10:30:07)] Epoch:71/100,train_acc:100.0%,train_loss:0.009,test_acc88.3%,test_loss0.321,best_acc:90.0%
[2023-04-01 10:30:13)] Epoch:72/100,train_acc:100.0%,train_loss:0.008,test_acc88.6%,test_loss0.323,best_acc:90.0%
[2023-04-01 10:30:18)] Epoch:73/100,train_acc:100.0%,train_loss:0.008,test_acc88.8%,test_loss0.325,best_acc:90.0%
[2023-04-01 10:30:24)] Epoch:74/100,train_acc:100.0%,train_loss:0.009,test_acc88.6%,test_loss0.327,best_acc:90.0%
[2023-04-01 10:30:29)] Epoch:75/100,train_acc:100.0%,train_loss:0.008,test_acc88.6%,test_loss0.325,best_acc:90.0%
[2023-04-01 10:30:35)] Epoch:76/100,train_acc:100.0%,train_loss:0.008,test_acc88.3%,test_loss0.326,best_acc:90.0%
[2023-04-01 10:30:40)] Epoch:77/100,train_acc:100.0%,train_loss:0.008,test_acc88.6%,test_loss0.330,best_acc:90.0%
[2023-04-01 10:30:46)] Epoch:78/100,train_acc:100.0%,train_loss:0.007,test_acc88.8%,test_loss0.326,best_acc:90.0%
[2023-04-01 10:30:52)] Epoch:79/100,train_acc:100.0%,train_loss:0.007,test_acc88.3%,test_loss0.329,best_acc:90.0%
[2023-04-01 10:30:57)] Epoch:80/100,train_acc:100.0%,train_loss:0.008,test_acc88.6%,test_loss0.327,best_acc:90.0%
[2023-04-01 10:31:03)] Epoch:81/100,train_acc:100.0%,train_loss:0.007,test_acc88.3%,test_loss0.332,best_acc:90.0%
[2023-04-01 10:31:08)] Epoch:82/100,train_acc:99.9%,train_loss:0.008,test_acc88.3%,test_loss0.332,best_acc:90.0%
[2023-04-01 10:31:14)] Epoch:83/100,train_acc:100.0%,train_loss:0.007,test_acc88.8%,test_loss0.337,best_acc:90.0%
[2023-04-01 10:31:19)] Epoch:84/100,train_acc:100.0%,train_loss:0.007,test_acc88.8%,test_loss0.331,best_acc:90.0%
[2023-04-01 10:31:25)] Epoch:85/100,train_acc:100.0%,train_loss:0.007,test_acc88.3%,test_loss0.333,best_acc:90.0%
[2023-04-01 10:31:31)] Epoch:86/100,train_acc:100.0%,train_loss:0.007,test_acc89.3%,test_loss0.328,best_acc:90.0%
[2023-04-01 10:31:36)] Epoch:87/100,train_acc:100.0%,train_loss:0.006,test_acc88.8%,test_loss0.331,best_acc:90.0%
[2023-04-01 10:31:42)] Epoch:88/100,train_acc:100.0%,train_loss:0.006,test_acc89.0%,test_loss0.331,best_acc:90.0%
[2023-04-01 10:31:47)] Epoch:89/100,train_acc:100.0%,train_loss:0.006,test_acc88.8%,test_loss0.335,best_acc:90.0%
[2023-04-01 10:31:53)] Epoch:90/100,train_acc:100.0%,train_loss:0.006,test_acc88.1%,test_loss0.334,best_acc:90.0%
[2023-04-01 10:31:58)] Epoch:91/100,train_acc:100.0%,train_loss:0.006,test_acc88.6%,test_loss0.336,best_acc:90.0%
[2023-04-01 10:32:04)] Epoch:92/100,train_acc:100.0%,train_loss:0.006,test_acc87.9%,test_loss0.340,best_acc:90.0%
[2023-04-01 10:32:10)] Epoch:93/100,train_acc:100.0%,train_loss:0.006,test_acc88.6%,test_loss0.337,best_acc:90.0%
[2023-04-01 10:32:15)] Epoch:94/100,train_acc:100.0%,train_loss:0.006,test_acc88.3%,test_loss0.336,best_acc:90.0%
[2023-04-01 10:32:21)] Epoch:95/100,train_acc:100.0%,train_loss:0.006,test_acc88.8%,test_loss0.339,best_acc:90.0%
[2023-04-01 10:32:26)] Epoch:96/100,train_acc:100.0%,train_loss:0.006,test_acc89.3%,test_loss0.337,best_acc:90.0%
[2023-04-01 10:32:32)] Epoch:97/100,train_acc:100.0%,train_loss:0.005,test_acc88.1%,test_loss0.339,best_acc:90.0%
[2023-04-01 10:32:37)] Epoch:98/100,train_acc:100.0%,train_loss:0.005,test_acc88.1%,test_loss0.340,best_acc:90.0%
[2023-04-01 10:32:43)] Epoch:99/100,train_acc:100.0%,train_loss:0.005,test_acc88.8%,test_loss0.342,best_acc:90.0%
[2023-04-01 10:32:49)] Epoch:100/100,train_acc:100.0%,train_loss:0.006,test_acc88.8%,test_loss0.338,best_acc:90.0%

(Figure: train/test accuracy and loss curves, saved as Accuracyloss.png)
Summary:
(1) There is some overfitting: training accuracy reaches 100% while test accuracy plateaus around 90%. Adding dropout, BatchNorm, or L2 regularization should be considered (see the sketch below).
(2) A standalone predict function can run inference with the trained weights, which helps in getting familiar with the full training and prediction workflow.
(3) The code is more modular than before, with each step wrapped in its own function, so the main routine is easier to control.
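
A minimal sketch for point (1), reusing the Model_K class and the device variable from the script above: SGD's weight_decay argument adds L2 regularization, and an extra Dropout layer is inserted into the classifier head. The values weight_decay=1e-4 and p=0.4 are illustrative choices, not tuned results:

import torch
import torch.nn as nn

model = Model_K().to(device)
# insert dropout into the classifier head to reduce overfitting
model.classifier = nn.Sequential(
    nn.Linear(in_features=802816, out_features=100),
    nn.ReLU(),
    nn.Dropout(p=0.4),          # illustrative rate, not tuned
    nn.Linear(in_features=100, out_features=2)
).to(device)
# L2 regularization via the optimizer's weight_decay term
opt = torch.optim.SGD(model.parameters(), lr=1e-3, weight_decay=1e-4)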
