Pytorch搭建GoogLeNet网络(奥特曼分类)

1 爬取奥特曼

get_data.py

import requests
import urllib.parse as up
import json
import time
import os

major_url = 'https://image.baidu.com/search/index?'
headers = {
   'User-Agent' : 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/84.0.4147.135 Safari/537.36'}
def pic_spider(kw, path, page = 10):
    """Download thumbnail images for keyword *kw* from Baidu image search.

    Parameters
    ----------
    kw : str
        Search keyword; also used as the sub-directory name under *path*.
    path : str
        Base directory; images are written to ``path/kw``.
    page : int, optional
        Number of result pages to fetch (30 images per page).
    """
    path = os.path.join(path, kw)
    # makedirs(exist_ok=True) also creates a missing parent directory,
    # which the original os.mkdir would fail on.
    os.makedirs(path, exist_ok=True)
    if kw == '':
        return
    for num in range(page):
        data = {
            "tn": "resultjson_com",
            "logid": "11587207680030063767",
            "ipn": "rj",
            "ct": "201326592",
            "is": "",
            "fp": "result",
            "queryWord": kw,
            "cl": "2",
            "lm": "-1",
            "ie": "utf-8",
            "oe": "utf-8",
            "adpicid": "",
            "st": "-1",
            "z": "",
            "ic": "0",
            "hd": "",
            "latest": "",
            "copyright": "",
            "word": kw,
            "s": "",
            "se": "",
            "tab": "",
            "width": "",
            "height": "",
            "face": "0",
            "istype": "2",
            "qc": "",
            "nc": "1",
            "fr": "",
            "expermode": "",
            "force": "",
            "pn": num * 30,
            "rn": "30",
            "gsm": oct(num * 30),
            "1602481599433": ""
        }
        page_url = major_url + up.urlencode(data)
        pic_list = []
        # Retry a flaky request up to 5 times before giving up on this page.
        # Catch only network errors and invalid-JSON responses, not everything.
        for _ in range(5):
            try:
                pic_list = requests.get(url=page_url, headers=headers).json().get('data') or []
                break
            except (requests.RequestException, ValueError):
                print('网络不好,正在重试...')
                time.sleep(1.3)
        for pic in pic_list:
            pic_url = pic.get('thumbURL', '')  # some entries carry no image link
            if pic_url == '':
                continue
            name = pic.get('fromPageTitleEnc') or ''
            if name == '':
                continue  # no usable title -> no usable file name
            # Strip characters that are illegal in (Windows) file names.
            for char in ['?', '\\', '/', '*', '"', '|', ':', '<', '>']:
                name = name.replace(char, '')
            ext = pic.get('type', 'jpg')  # image type; default to jpg (do not shadow builtin `type`)
            pic_path = os.path.join(path, '%s.%s' % (name, ext))
            if not os.path.exists(pic_path):
                with open(pic_path, 'wb') as f:
                    f.write(requests.get(url=pic_url, headers=headers).content)
                # Report success only after the file is actually written
                # (the original printed before downloading, even for skips).
                print(name, '已完成下载')
cwd = os.getcwd()  # current working directory
file1 = 'flower_data/flower_photos'
file2 = '数据/下载数据'
save_path = os.path.join(cwd, file2)
# Create the save directory once, up front: makedirs handles missing parents
# and exist_ok makes it a no-op when the directory already exists. (The
# original re-checked existence inside the loop on every iteration.)
os.makedirs(save_path, exist_ok=True)

# Ultraman character names; each becomes one search keyword / class folder.
lists = ['佐菲','初代','赛文','杰克','艾斯','泰罗','奥特之父','奥特之母','爱迪','尤莉安','雷欧','阿斯特拉','奥特之王','葛雷','帕瓦特','奈克斯特','奈克瑟斯','哉阿斯','迪加','戴拿','盖亚(大地)','阿古茹(海洋)','高斯(慈爱)','杰斯提斯(正义)','雷杰多(高斯与杰斯提斯的合体)','诺亚(奈克斯特的最终形态)','撒加','奈欧斯','赛文21','麦克斯','杰诺','梦比优斯','希卡利','赛罗','赛文X']

# Loop variable renamed: the original used `list`, shadowing the builtin.
for ultraman in lists:
    pic_spider(ultraman + '奥特曼', save_path, page=10)

print("lists_len: ", len(lists))

2 划分数据集

训练集 train :80%
验证集 val :10%
测试集 predict :10%

spile_data.py

import os
from shutil import copy
import random
import cv2


def mkfile(file):
    """Create directory *file* (including missing parents) unless it already exists."""
    if os.path.exists(file):
        return
    os.makedirs(file)


file = '数据/下载数据'
# One sub-directory per class; ignore stray .txt files in the download folder.
flower_class = [cla for cla in os.listdir(file) if ".txt" not in cla]

# Mirror the class layout under each dataset split. The original repeated
# the same mkfile loop three times (train / val / predict) with dead
# commented-out variants; one loop over the split names is equivalent.
for split in ('train', 'val', 'predict'):
    mkfile('数据/' + split)
    for cla in flower_class:
        mkfile('数据/' + split + '/' + cla)

split_rate = 0.1

for cla in flower_class:
    images = []
    cla_path = file + '/' + cla + '/'

    # 过滤jpg和png
    images1 = [cla1 for cla1 in os.listdir(cla_path) if ".jpg" in cla1]
    images2 = [cla1 for cla1 in os.listdir(cla_path) if ".png" in cla1]+images1

    # 去掉小于256的图
    for image in images2:
        img = cv2.imread<
  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 5
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值