PyTorch Flower Classification (102 Category Flower Dataset)

This article walks through the implementation step by step.

Experimental Environment

Python 3.6 + PyTorch 1.2 + CUDA 10.1
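If you are unsure about your setup, a quick check like the following (a minimal sketch, not part of the original script) prints the installed versions and whether a GPU is visible:

import torch
import torchvision

print(torch.__version__)           # expect something like 1.2.0
print(torchvision.__version__)
print(torch.version.cuda)          # CUDA version PyTorch was built against
print(torch.cuda.is_available())   # True if a usable GPU is detected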

Dataset

The 102 Category Flower Dataset consists of 102 categories of flowers from the United Kingdom, with 40 to 258 images per category.

Many readers have messaged me asking for the dataset used below, so I uploaded it to CSDN: 鲜花分类集(已划分)-深度学习文档类资源-CSDN下载
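The loading code below assumes the dataset has already been split into train/valid/test folders in the layout that ImageFolder expects, i.e. one subfolder per class. Roughly like this (the class folder and file names are only illustrative):

./data/
    train/
        1/
            image_06734.jpg
            ...
        2/
        ...
    valid/
        1/
        ...
    test/
        1/
        ...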

The code implementation and explanation are given below.

Data Loading

import os
import time

import torch
import torch.nn as nn
import torch.utils.data
from torchvision import datasets, transforms, models

# Select the device (fall back to CPU if no GPU is available)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Apply different preprocessing to each split; the training data is augmented
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomRotation(30),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ]),
    'valid': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])
}

# Data directory
data_dir = "./data"

# Build the three datasets with ImageFolder
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                 data_transforms[x]) for x in ['train', 'valid','test']}
traindataset = image_datasets['train']
validdataset = image_datasets['valid']
testdataset = image_datasets['test']

batch_size = 8
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size,
             shuffle=True, num_workers=4) for x in ['train', 'valid','test']}
print(dataloaders)
traindataloader = dataloaders['train']
validdataloader = dataloaders['valid']
testdataloader = dataloaders['test']

dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid','test']}
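As an optional sanity check (my addition, not part of the original script), the number of classes and the split sizes can be printed after loading; the class count should be 102:

# Optional sanity check on the loaded data
class_names = image_datasets['train'].classes
print(len(class_names))   # expected: 102
print(dataset_sizes)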

Defining the Network Architecture

Using ResNet152

# Use the resnet152 backbone and replace the final fully connected layer to output 102 classes
class Net(nn.Module):
    def __init__(self,model):
        super(Net,self).__init__()
        self.resnet = nn.Sequential(*list(model.children())[:-1])
        # Optionally, freeze the convolutional layers
        # for p in self.parameters():
        #     p.requires_grad = False
        self.fc = nn.Linear(in_features=2048,out_features=102)


    def forward(self,x):
        x = self.resnet(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x

resnet152 = models.resnet152(pretrained=True)
net = Net(resnet152)
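A quick way to confirm the rewired head is correct (a hedged sketch, not in the original post) is to push a dummy batch through the network and check that the output has 102 logits per image:

# Dummy forward pass: batch of 2 RGB images at 224x224
net.eval()
dummy = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    out = net(dummy)
print(out.shape)  # expected: torch.Size([2, 102])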

Using VGG19

class Net(nn.Module):
    def __init__(self,model):
        super(Net,self).__init__()
        self.features = model.features
        # for p in self.parameters():
        #     p.requires_grad = False
        self.classifier = nn.Sequential(
            nn.Linear(25088, 4096,bias=True),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5,inplace=False),
            nn.Linear(4096, 4096,bias=True),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5,inplace=False),
            nn.Linear(4096, 102,bias=True)
        )

    def forward(self,x):
        x = self.features(x)
        x = x.view(x.shape[0], -1)
        x = self.classifier(x)
        return x


vgg = models.vgg19(pretrained=True)
net = Net(vgg)
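If you choose to enable the commented-out freezing, only the newly added classifier should remain trainable. A small check like the following (illustrative only, and it does change what gets trained) makes that explicit:

# Freeze the pretrained feature extractor; keep only the classifier trainable
for p in net.features.parameters():
    p.requires_grad = False

trainable = sum(p.numel() for p in net.parameters() if p.requires_grad)
total = sum(p.numel() for p in net.parameters())
print(trainable, total)  # trainable should now cover only the classifier layers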

Parameter Setup

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, net.parameters()),lr=0.0001,momentum=0.9)
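The filter on requires_grad means that if any layers were frozen above, the optimizer only receives the remaining trainable parameters. A brief check (illustrative only):

# Number of parameter tensors actually handed to the optimizer
num_optimized = sum(len(g['params']) for g in optimizer.param_groups)
print(num_optimized)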

Validation Set Evaluation

def valid_model(model, criterion):
    print('-' * 10)

    running_loss = 0.0
    running_corrects = 0
    model = model.to(device)
    model.eval()
    for inputs, labels in validdataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        with torch.no_grad():
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        _, preds = torch.max(outputs, 1)
        # CrossEntropyLoss returns the batch mean, so weight it by the batch size
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels)
    epoch_loss = running_loss / dataset_sizes['valid']
    epoch_acc = running_corrects.double() / dataset_sizes['valid']
    print('{} Loss: {:.4f} Acc: {:.4f}'.format(
            'valid', epoch_loss, epoch_acc))
    print('-' * 10)
    print()

Test Set Evaluation

def test_model(model, criterion):
    print('-' * 10)

    running_loss = 0.0
    running_corrects = 0
    model = model.to(device)
    model.eval()
    for inputs, labels in testdataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        with torch.no_grad():
            outputs = model(inputs)
            loss = criterion(outputs, labels)

        _, preds = torch.max(outputs, 1)
        # CrossEntropyLoss returns the batch mean, so weight it by the batch size
        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels)
    epoch_loss = running_loss / dataset_sizes['test']
    epoch_acc = running_corrects.double() / dataset_sizes['test']
    print('{} Loss: {:.4f} Acc: {:.4f}'.format(
            'test', epoch_loss, epoch_acc))
    print('-' * 10)
    print()

Training the Model

def train_model(model, criterion, optimizer, num_epochs=5):
    since = time.time()
    best_acc = 0.0
    for epoch in range(num_epochs):
        # Evaluate on the test set every 5 epochs
        if (epoch + 1) % 5 == 0:
            test_model(model, criterion)
        print('-' * 10)
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))

        running_loss = 0.0
        running_corrects = 0
        model = model.to(device)
        model.train()
        for inputs, labels in traindataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            _, preds = torch.max(outputs, 1)
            # CrossEntropyLoss returns the batch mean, so weight it by the batch size
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels)
        epoch_loss = running_loss / dataset_sizes['train']
        epoch_acc = running_corrects.double() / dataset_sizes['train']
        best_acc = max(best_acc, epoch_acc)
        print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                'train', epoch_loss, epoch_acc))

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best train Acc: {:4f}'.format(best_acc))

    return model

Start Training

epochs = 48
model = train_model(net, criterion, optimizer, epochs)

valid_model(model,criterion)

torch.save(model, 'model.pkl')
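Since the whole model object is saved with torch.save, it can later be reloaded and used for inference on a single image roughly as follows (a hedged sketch; 'some_flower.jpg' is a placeholder path, and the Net class definition must be available when loading):

from PIL import Image

# Reload the saved model and switch to evaluation mode
model = torch.load('model.pkl', map_location=device)
model.eval()

# Preprocess one image with the same transforms as the test split
img = Image.open('some_flower.jpg').convert('RGB')   # placeholder path
x = data_transforms['test'](img).unsqueeze(0).to(device)

with torch.no_grad():
    logits = model(x)
pred = logits.argmax(dim=1).item()
print(image_datasets['train'].classes[pred])  # predicted class folder name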

Results

In the end, accuracy reached 97.01% on the training set and 97.45% on the test set.

 
