Day2:Alexnet训练自己的数据集--多分类

首先说明,本文采用多分类40类自制数据集(数据集形式在下附图)

一、Model

二、dataset

三、train

四、遇到的问题

一、Model

框架搭建根据Alexnet网络构建

class Alexnet(nn.Module):
    def __init__(self):
        super(Alexnet,self).__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3,96,11,4),
            nn.ReLU(),
            nn.MaxPool2d(3,2),

            nn.Conv2d(96,256,5,1,2),
            nn.MaxPool2d(3,2),
            nn.ReLU(),

            nn.Conv2d(256,384,3,1,1),
            nn.ReLU(),
            nn.Conv2d(384,384,3,1,1),
            nn.ReLU(),
            nn.Conv2d(384, 128, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(3,2),
        )

        self.fc = nn.Sequential(
            nn.Linear(128*6*6,4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096,2000),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(2000,40),
        )

    def forward(self,x):
        f = self.backbone(x)
        output = torch.flatten(f,1)
        output = self.fc(output)
        return output

注意:最后fc层只是变成40,并没有加softmax,因为训练将使用crossentorpyloss里面自带softmax

二、dataset

class Train_Mydataset(Dataset):
    def __init__(self,root):
        self.dataset = []
        self.root = root
        label_list = os.listdir(root)#0 1 名
        for img_list in label_list:
            new_img_list = os.path.join(root,img_list)#'./0'  './1'地址
            new_img_list = new_img_list.replace("\\", "/")
            new_img_name = os.listdir(new_img_list)#每个图的名字
            for name in new_img_name:
                img_dir = os.path.join(new_img_list,name)
                img_dir = img_dir.replace("\\",'/')
                self.dataset.append((img_dir))


    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        data = self.dataset[index]
        img = cv2.imread(data)/255
        img = cv2.resize(img,(256,256))
        img = np.transpose(img,(2,0,1))
        label = int(data.split(('/'))[-2])
        # print(label)
        return np.float32(img),np.float32(label)

注意:\的使用可能会被误认为转义字符,做好使用/

三、train

def train(net,train_root):
    empoches = 100;
    batch_size = 8;
    net = net.to(device)
    print("training on",device)

    #定义一个优化器
    optimizer = torch.optim.Adam(net.parameters(),lr=0.003)

    #学习率调整
    scheduler = lr_scheduler.StepLR(optimizer,step_size=10,gamma=0.5)

    #定义一个损失函数
    loss = nn.CrossEntropyLoss()

    for epoch in range(empoches):
        loss_sum,acc_sum  = 0.0,0.0
        dataset_train = Train_Mydataset(train_root)
        # train_n = dataset_train.__len__()
        dataloader_train = DataLoader(dataset_train,batch_size = batch_size,shuffle=True)
        train_n = len(dataloader_train)
        # dataset_val = Train_Mydataset(val_root)
        # dataloader_val = DataLoader(dataset_val,batch_size = batch_size,shuffle=True)
        # val_n = dataset_val.__len__()

        scheduler.step()
        for img,label in tqdm(dataloader_train):

            label = label.long().to(device)
            # label = onehot(label,40).cuda()

            p = net(img.to(device).cuda())
            loss_c = loss(p,label)
            p = nn.softmax(p)
            p = torch.argmax(p,dim=1)


            #反向传播
            optimizer.zero_grad()
            loss_c.backward()
            optimizer.step()
            loss_sum += loss_c.cpu().item()
            acc_sum += (p == label).sum().cpu().item()
        print('train--epoch %d,lr%8f,loss%8f,acc%.3f'%(epoch,scheduler.get_lr()[0],loss_sum/train_n,acc_sum/train_n))

注意;nn.CrossEntopyLoss输入的预测值不需要softmax,输入的标签也不需要转成onehot

 

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
可以使用Java中的SimpleDateFormat类将String类型的时间转换为Date对象,然后再对Date对象进行处理。 首先,我们先将时间段转换为【2020-11-20 00:00:00,2021-10-09 23:59:59】,即将结束时间改为当天的最后一秒: ```java String startTime = "2020-11-20 09:09:09"; String endTime = "2021-10-09 10:10:10"; SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); Date startDate = sdf.parse(startTime); Date endDate = sdf.parse(endTime); Calendar endCalendar = Calendar.getInstance(); endCalendar.setTime(endDate); endCalendar.set(Calendar.HOUR_OF_DAY, 23); endCalendar.set(Calendar.MINUTE, 59); endCalendar.set(Calendar.SECOND, 59); endDate = endCalendar.getTime(); String newEndTime = sdf.format(endDate); System.out.println(newEndTime); ``` 输出结果为:2021-10-09 23:59:59 接下来,我们可以使用一个循环,每次增加一天,输出该天的开始时间和结束时间: ```java Calendar calendar = Calendar.getInstance(); calendar.setTime(startDate); while (calendar.getTime().before(endDate)) { Date startOfDay = calendar.getTime(); calendar.add(Calendar.DAY_OF_MONTH, 1); calendar.set(Calendar.HOUR_OF_DAY, 0); calendar.set(Calendar.MINUTE, 0); calendar.set(Calendar.SECOND, 0); Date endOfDay = calendar.getTime(); System.out.println(sdf.format(startOfDay) + " - " + sdf.format(endOfDay)); } // 输出最后一天的开始时间和结束时间 Date startOfDay = calendar.getTime(); System.out.println(sdf.format(startOfDay) + " - " + newEndTime); ``` 输出结果为: ``` 2020-11-20 09:09:09 - 2020-11-21 00:00:00 2020-11-21 00:00:00 - 2020-11-22 00:00:00 2020-11-22 00:00:00 - 2020-11-23 00:00:00 ... 2021-10-07 00:00:00 - 2021-10-08 00:00:00 2021-10-08 00:00:00 - 2021-10-09 00:00:00 2021-10-09 00:00:00 - 2021-10-09 23:59:59 ```

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值