在web实现鱼类识别,pytorch,flask,opencv

基于pytorch的预训练模型,结合opencv处理图像,用flask搭建本地服务器,在web给用户提供鱼类识别的项目。

目录

pytorch代码

flask本地服务器代码

以下是我的html文件的代码

index.html 

2.html

1.html


pytorch代码

import os
import cv2
import torch
from random import uniform
from torch import optim, nn, device, cuda, save
import torchvision.transforms as T
import torchvision
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

# opencv图像处理
def pre_handing_train(img_path):
    # 读取图片,灰度化
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    # 除背景噪声
    img = cv2.medianBlur(img, 3)
    img = cv2.GaussianBlur(img, (3, 3), 0)
    # 锐化边缘
    img = cv2.equalizeHist(img)
    return img


def pre_handing_test(img_path):
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    return img


# 读取图像处理
class PicDataset(Dataset):
    def __init__(self, path, transform=None):
        if path == "pic/":
            pre_handing = pre_handing_train
        else:
            pre_handing = pre_handing_test
        join = os.path.join
        listdir = os.listdir
        self.transform = transform
        self.images = []  # 存储图像数据
        self.labels = []  # 存储标签
        appendI = self.images.append
        appendL = self.labels.append
        # 遍历每个分类的文件夹
        for folder in listdir(path):
            folder_path = join(path, folder)
            # 遍历每张图片
            for filename in listdir(folder_path):
                # 每一张图片所在的文件路径
                img_path = join(folder_path, filename)
                # opencv处理
                img = pre_handing(img_path)
                # 存储图像信息
                appendI(img)
                appendL(int(folder))

    # 每一张图片的存储长度
    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        # 转存储格式为tensor格式
        if self.transform is not None:
            image = self.transform(image)

        return image, label


train_dir = "fishP/train"
# 构建训练集的各种操作集合,创建训练集
transform_train = T.Compose([
    T.ToPILImage(),  # 将图片存储格式转为PIL类型
    T.Resize((32, 32)),  # 将图片按比例重塑成64*64像素大小
    T.Grayscale(num_output_channels=3),  # 灰度化图片,输出三个通道
    T.ToTensor(),  # 将图片存储格式转为tensor类型
    # 数据标准化,即均值为0,标准差为1。
    # 简单来说就是将数据按通道进行计算,将每一个通道的数据先计算出其方差与均值,然后再将其每一个通道内的每一个数据减去均值,再除以方差,得到归一化后的结果。
    # 在深度学习图像处理中,标准化处理之后,可以使数据更好的响应激活函数,提高数据的表现力,减少梯度爆炸和梯度消失的出现。
    T.Normalize((0.5,), (0.5,)),
    T.RandomHorizontalFlip(p=0.3),  # 随机水平翻转
    T.RandomRotation(10)  # 随机将图片旋转10°
])
trainset = PicDataset(train_dir, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=2, shuffle=True)


# 加载预训练模型vgg16并更改分类器的输出
model =torchvision.models.vgg16(weights="VGG16_Weights.DEFAULT")
for p in model.features.parameters():
    p.requires_grad=False
model.classifier[-1].out_features=23

# 优先选择显卡训练
device = device("cuda:0" if cuda.is_available() else "cpu")
model.to(device)

# 损失函数计算
criterion = nn.CrossEntropyLoss()
# 优化器
optimizer = optim.SGD(model.parameters(), lr=uniform(0.001, 0.01), momentum=0.9)

print('start train')
for epoch in range(10):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        # 对于分类任务,output是一个概率分布,每个元素表示相应的类别的置信度得分,可使用softmax将其转化为概率值
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # 取模运算,每200个小批量打印一次
        if i % 200 == 199:
            print('[%d, %5d] loss: %.3f' %
                (epoch + 1, i + 1, running_loss / 200))
            running_loss = 0.0
print('Finished Train')

# 保存模型的参数
torch.save(model.state_dict(), 'model.pth')

# 测试集路径
test_dir = "fishP/valid/"
# 构建测试集的各种操作集合,创建测试级
transform_test = T.Compose([
    T.ToPILImage(),  # 将图片存储格式转为PIL类型
    T.Resize((64, 64)),  # 将图片按比例重塑成64*64像素大小
    T.Grayscale(num_output_channels=3),  # 灰度化图片,输出三个通道
    T.Normalize((0.5,), (0.5,)),
    T.ToTensor(),  # 将图片存储格式转为tensor类型
])
testset = PicDataset(test_dir, transform=transform_test)
testloader = DataLoader(testset, batch_size=64, shuffle=True)

# 测试模型
with torch.no_grad():
    correct = 0
    total = 0
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the test images: %d %%' % (
            100 * correct / total))

flask本地服务器代码

import os
from flask import Flask, request, render_template
from werkzeug.utils import secure_filename
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import torchvision.models as models

app = Flask(__name__)
# 用户上传的文件的保存路径
UPLOAD_FOLDER = 'static'
ALLOWED_EXTENSIONS = {'jpg', 'jpeg', 'png', 'gif'}


def allowed_file(filename):
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS


@app.route('/')
def index():
    return render_template('index.html')

@app.route('/1.html', methods=["GET"])
def loop1():
    return render_template('1.html')

@app.route('/2.html', methods=["GET"])
def loop2():
    return render_template('2.html')


@app.route('/3.html', methods=["GET"])
def loop3():
    return render_template('3.html')


@app.route('/4.html', methods=["GET"])
def loop4():
    return render_template('4.html')


@app.route('/5.html', methods=["GET"])
def loop5():
    return render_template('5.html')


@app.route('/upload', methods=["POST"])
def upload():
    file = request.files["image"]
    if file and allowed_file(file.filename):
        filename = secure_filename(file.filename)
        file.save(os.path.join(UPLOAD_FOLDER, filename))
        # 参数=路径+文件名
        result = classify_fish(os.path.join(UPLOAD_FOLDER, filename))
        return render_template('2.html', result=result, statement="当前文件:" + filename)
    else:
        return render_template('2.html', statement="未选择文件或文件有误")


def classify_fish(image_file):
    # 加载模型
    model = models.vgg16()
    model.load_state_dict(torch.load('model.pth'))
    model.eval()
    # 图像预处理和归一化
    preprocess = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # 将图像转换为 PyTorch 张量
    image = Image.open(image_file).convert('RGB')
    image = preprocess(image)
    image = image.unsqueeze(0)

    with torch.no_grad():
        output = model(image)
        # torch.max()表示在张量中取最大值,参数:操作对象,操作维度
        # 操作维度是1,表示类别维度上取最大值
        # torch.max(output.data, 1)返回的是一个元组,(最大值,最大值对应的索引)
        # predict是一个张一维量,表示分类物品的编号
        _, predict = torch.max(output.data, 1)

    return predict.item()


if __name__ == '__main__':
    app.secret_key = 'supersecretkey'
    app.config['SESSION_TYPE'] = 'filesystem'
    # 只允许用户上传图片文件
    app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
    app.run(debug=True)

以下是我的文件目录,fishR文件夹是我的项目文件目录

fishP文件夹存放训练集和测试集,static文件夹存放一些用户上传的文件和js文件,templates文件夹存放html文件,test文件夹是存放一些无关紧要的测试文件,不需要建立,app.py存放基于flask建立的本地服务器,p3.py存放着pytorch训练代码。文件夹(static,templates)的命名一定要和我的一样,不然会出问题。

 我的templates文件夹下的目录

fishP下的文件目录

train (训练集)文件夹下放着每一种鱼的文件夹

valid (测试集或者叫验证集)文件夹下放着每一种鱼的文件夹

这里用数字表示鱼的种类

以下是我的html文件的代码

有部分html文件我没贴出来,是因为那些文件暂时是空的,作为我以后的功能拓展。可以根据个人需求建立相应的html文件

index.html 

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>鱼型识别系统</title>
    <style>
        h1 {text-align: center;padding: 30px;color: #ff4800}
        body {
            background-image: url("https://img2.baidu.com/it/u=1029016406,1880474304&fm=253&fmt=auto&app=138&f=JPEG?w=800&h=500");
            background-size: cover;
            background-attachment: fixed;
            opacity: 0.9;
            background-color: cornflowerblue;
        }
        input[type="submit"] {
            background-color: #ff4800;
            border: none;
            padding: 10px 20px;
            text-decoration: none;
            margin: 20px 0;
            border-radius: 5px;
            color: white;
        }
        #d0 {
            text-align: center;
            display: flex;
            justify-content: center;
            align-items: center;
        }

        #d1 {
            width: 30%;
            height: 75%;
            background-blend-mode: normal,soft-light;
            -webkit-backdrop-filter: blur(50px);
            border-radius: 20px;
            background-color:rgba(255,255,255,0.3);
            text-align: center;
            margin: auto;

        }
        #d1:hover {
            background-color:rgba(255,255,255,0.6);
        }
        #d2 {
            width: 30%;
            height: 75%;
            background-blend-mode: normal,soft-light;
            -webkit-backdrop-filter: blur(50px);
            border-radius: 20px;
            background-color:rgba(255,255,255,0.3);
            text-align: center;
            margin: auto;
        }
        #d2:hover {
            background-color:rgba(255,255,255,0.6);
        }
        #d2 > p {
            padding: 70px;
            color: red;
            font-size: 15px;
        }
        #d3 {
            width: 30%;
            height: 75%;
            background-blend-mode: normal,soft-light;
            -webkit-backdrop-filter: blur(50px);
            border-radius: 20px;
            background-color:rgba(255,255,255,0.3);
            text-align: center;
            margin: auto;
        }
        #d3:hover {
                    background-color:rgba(255,255,255,0.6);
                }
        #d4 {
            text-align: center;
            display: flex;
            justify-content: center;
            align-items: center;
        }
        #d5 {
            width: 30%;
            height: 75%;
            background-blend-mode: normal,soft-light;
            -webkit-backdrop-filter: blur(50px);
            border-radius: 20px;
            background-color:rgba(255,255,255,0.3);
            text-align: center;
            margin: auto;
        }
        #d5:hover {
                    background-color:rgba(255,255,255,0.6);
                }
        #d6 {
            width: 30%;
            height: 75%;
            background-blend-mode: normal,soft-light;
            -webkit-backdrop-filter: blur(50px);
            border-radius: 20px;
            background-color:rgba(255,255,255,0.3);
            text-align: center;
            margin: auto;
        }
        #d6:hover {
                    background-color:rgba(255,255,255,0.6);
                }
    </style>
</head>
<body>

<h1>辅助鱼类研究系统</h1>
<div id="d0">
    <div id="d1" onclick="window.location.href='1.html'">
            <h1>3维追踪</h1>
    </div>

    <div id="d2"
         onclick="window.location.href='2.html'">
            <h1>快速识别</h1>
    </div>

    <div id="d3"
         onclick="window.location.href='3.html'">
            <h1>使用指南</h1>
    </div>
</div>

<br>

<div id="d4">
    <div id="d5"
         onclick="window.location.href='4.html'">
            <h1>鱼类种质库</h1>
    </div>
    <div id="d6"
         onclick="window.location.href='5.html'">
            <h1>信息记录</h1>
    </div>
</div>
</body>
</html>

2.html

<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>Title</title>
    <style>
        body {
            /*渐变色*/
            background: radial-gradient(circle at center,cornflowerblue,cadetblue);
        }
        h1 {text-align: center}
        form{
            margin: auto;
            width: 50%;
            padding: 10px;
            text-align: center;
        }
        .result p{
                font-weight: bold;
            }
        button {
            background-color: #ff4800;
            border: none;
            padding: 10px 20px;
            text-decoration: none;
            margin: 20px 0;
            border-radius: 5px;
            color: white;
        }
        button:hover {
            box-shadow: 0 0 25px orangered;
            box-reflect: below 1px linear-gradient(transparent,rgba(0,0,0,0.3));
        }
        input[type="button"] {
            background-color: #ff4800;
            border: none;
            padding: 10px 20px;
            text-decoration: none;
            margin: 20px 0;
            border-radius: 5px;
            color: white;
        }
        input[type="button"]:hover {
            box-shadow: 0 0 25px orangered;
            box-reflect: below 1px linear-gradient(transparent,rgba(0,0,0,0.3));
        }
        main {
            width: 50%;
            height: 50%;
            background-blend-mode: normal,soft-light;
            -webkit-backdrop-filter: blur(50px);
            border-radius: 20px;
            background-color:rgba(255,255,255,0.3);
            text-align: center;
            margin: auto;
            padding: 1%;
        }
        div {
            display: flex;
        }
        #s1 {
            width: 33%;
            height: 30%;
            background-blend-mode: normal,soft-light;
            -webkit-backdrop-filter: blur(50px);
            border-radius: 20px;
            background-color:rgba(255,255,255,0.3);
            text-align: center;
            margin: auto;
            padding: 10%;
        }
        #uploadFile {
            display: none;
        }
    </style>

    <script>
        function clickFile(){
            const input=document.querySelector('#uploadFile')
            input.click()
            // let f = document.getElementById('uploadFile').files;
            // let [fileName, fileSize, fileType] = [f[0].name,f[0].size, f[0].type];
            //                           /*分别获取   文件名   文件大小   文件类型*/
            // document.getElementById("stmend").innerHTML="当前选择的文件是"+fileName
        }
    </script>

</head>
<body>
    <header>
        <h1>鱼种属快速识别</h1>
    </header>

<div>
    <main>
        <form enctype="multipart/form-data" method="POST" action="/upload">
            <input type="hidden" name="MAX_FILE_SIZE" value="100000">
            <input id="uploadFile" type="file" name="image">
            <section id="s2">
                {% if statement %}
                <div class="statement">
                    <p>{{ statement }}</p>
                </div>
                {% endif %}
            </section>
            <input type="button" value="上传图片" class="btn" onclick="clickFile()">
            <button type="submit">查询</button>
        </form>
    </main>
</div>
<br>
<div>
    <section id="s1">
        {% if result %}
        <div class="result">
            <p>这张图片上的鱼是 {{ result }}</p>
        </div>
        {% endif %}
    </section>
</div>



</body>
</html>

1.html

<!DOCTYPE html>
<html lang="en">
<head>
    <meta name="viewport" content="initial-scale=1.0, user-scalable=no" />
    <meta http-equiv="Content-Type" content="text/html" charset="UTF-8">
    <title>3d建模</title>
    <style>
        html {height: 100%}
        body {height: 100%;margin: 0;padding: 0
        }
        #container {
            height: 100%;
        }
        #d1 {
            height: 100%;
            width: 100%;
            display: grid;
            grid-template-columns: 3fr 1fr;
        }
        main {
            background-color: rgba(195,150,255,0.5);
            border: 1px solid rgba(195,191,255,0.5);
            text-align: center;
            background-blend-mode: normal,soft-light;
            -webkit-backdrop-filter: blur(50px);
            border-radius: 20px;
        }
        button {
            background-color: orange;
        }
        #d2:hover {
            background-blend-mode: normal,soft-light;
            -webkit-backdrop-filter: blur(50px);
            border-radius: 20px;
            background-color:rgba(255,255,255,0.3);
            text-align: center;
        }
        p:hover {
            background-blend-mode: normal,soft-light;
            -webkit-backdrop-filter: blur(50px);
            border-radius: 20px;
            background-color:rgba(255,255,255,0.3);
            text-align: center;
        }
        p {
            border: 2px solid rgba(195,150,255,1);
            border-radius: 20px;
        }
        #d2 {
            background-color: rgba(195,150,255,0.5);
            font-size: 20px;
            border: 2px solid rgba(195,150,255,1);
            border-radius: 20px;
        }
        #d3 {
            background-color: rgba(195,150,255,0.5);
            font-size: 20px;
            border: 2px solid rgba(195,150,255,1);
            border-radius: 20px;
        }
        #d3:hover {
            background-blend-mode: normal,soft-light;
            -webkit-backdrop-filter: blur(50px);
            border-radius: 20px;
            background-color:rgba(195,150,255,0.9);
            text-align: center;
        }


    </style>
    <script type="text/javascript" src="https://api.map.baidu.com/api?v=3.0&ak=你的百度地图密钥"></script>
</head>
<body>
<div id="d1">
    <div id="container"></div>
    <div>
        <main>
            <h1 style="color: #ff4800">西湖鱼群信息</h1>
            <p>时间:<script>
                var d=new Date()
                document.write(d)
            </script></p>
            <p>气温:</p>
            <p>空气湿度:</p>
            <p>水温:</p>
            <p>溶氧量:</p>
            <p>水pH:</p>
            <p>气压:</p>
            <p>绿藻含量</p>
            <p>鱼群数量:</p>
            <p>检测到:</p>
            <p>罗非鱼</p>
            <p>草鱼</p>
            <div id="d2" onclick="window.location.href='2.html'">历史记录</div><br>
            <div id="d3" onclick="window.location.href='index.html'">回到主页</div>
        </main>

    </div>
</div>



<script type="text/javascript">
    // 定义一个自定义控件
    function resZoomControl(){
        //设置默认停靠位置和偏移量
        this.defaultAnchor=BMAP_ANCHOR_TOP_LEFT
        this.defaultOffset=new BMap.Size(10,10)
    }
    // 通过JavaScript的prototype属性继承于BMap.Control
    resZoomControl.prototype=new BMap.Control()
    resZoomControl.prototype.initialize=function(map){
        // 创建一个DOM元素
        var div=document.createElement("div")
        // 添加文字说明
        div.appendChild(document.createTextNode("回中"))
        // 设置样式
        div.style.cursor="pointer"
        div.style.border="1px solid gray"
        div.style.backgroundColor="red"
        // 绑定事件,点击回到中心处
        div.onclick=function(e){
            map.panTo(new BMap.Point(113.370,23.170))
        }

    }
</script>
<script type="text/javascript">
    // 创建地图实例
    var map =new BMap.Map("container");
    // 创建点坐标,华南农业大学西湖
    var point=new BMap.Point(113.370,23.170);
    // 初始化地图设置中心点坐标和地图级别
    map.centerAndZoom(point,15);
    // 或者,你可以设置城市名为中心点
    // map.centerAndZoom("广州",15);

    //创建自定义控件实例
    var myZoomCtrl = new resZoomControl();
    // 添加到地图当中
    map.addControl(myZoomCtrl);
    // 开启鼠标滚轮缩放
    map.enableScrollWheelZoom(true);
    // 添加控件
    map.addControl(new BMap.NavigationControl());// 平移缩放功能
    map.addControl(new BMap.ScaleControl());// 比例尺
    map.addControl(new BMap.OverviewMapControl());// 可折叠的缩略地图
    map.addControl(new BMap.GeolocationControl());// 定位
    map.addControl(new BMap.MapTypeControl());// 设置地图类型
    map.setCurrentCity("广州"); // 仅当设置城市信息时,MapTypeControl的切换功能才能可用


</script>

</body>
</html>















html效果

 

图片数据自取

  • 1
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
基于引用[1]和引用的描述,YOLOv5是一种基于深度学习的目标检测算法,可以用于鱼类识别。YOLOv5是YOLO系列算法的第五代,相比于传统方法,它在检测精度和速度方面表现更好。 要实现YOLOv5鱼类识别,你需要进行以下步骤: 1. 准备数据集:根据引用中的描述,你需要手动标注深海鱼这个类别的图片,并将其划分为训练集和验证集。确保数据集中包含不同旋转和光照条件下的鱼类图片。 2. 调整图片大小:由于YOLOv5对输入图片大小有限制,你需要将所有图片调整为相同的大小。根据引用中的描述,可以将图片调整为640x640的大小,并保持原有的宽高比例。 3. 数据增强:为了增强模型的泛化能力和鲁棒性,你可以使用数据增强技术,如随机旋转、缩放、裁剪和颜色变换等。这些技术可以扩充数据集并减少过拟合风险。 4. 训练模型:使用YOLOv5算法对准备好的数据集进行训练。你可以参考引用中提供的开源代码https://github.com/ultralytics/yolov5来实现模型训练。 5. 模型评估:在训练完成后,你可以使用验证集对训练好的模型进行评估,计算模型的准确性和性能。 6. 鱼类识别:使用训练好的模型对新的鱼类图片进行识别。根据引用中的描述,你可以使用PyTorch和Pyside6库来实现界面系统,完成目标检测识别页面的开发。 请注意,YOLOv5是一种高精度的目标检测算法,但它可能不是唯一的选择。根据引用中的描述,YOLO系列算法的最新进展已有YOLOv6、YOLOv7、YOLOv8等算法。你可以关注这些最新算法的发展,并根据需求选择适合的算法。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值