训练自己的keyPointsDetectionMethodWithTorch

具体自己可以百度

代码 keyPointsDetectionMethodWithTorch: 物体关键点检测算法。 关键点的数量可以随意指定,但训练集需要与指定的关键点数量一致。 本算法是基于Resnet+dsntnn组合网络构成的关键点坐标回归算法--KPDEM_model. 仓库的文件源码很简单。一个模型文件,两个数据处理文件(训练集和测试集),一个训练文件和一个测试文件。以及readme文件。 (gitee.com)

或者看这个博主的介绍

关键点模型算法---服装(Keypoints Detection)_柏常青的博客-CSDN博客_关键点检测模型

我训练100批次  ,   标注了自己8个点10张图片。 时间仓促 效果

 作者效果

百度网盘链接

链接:https://pan.baidu.com/s/1e59ifLHPZmxGr0MkEAeGuA?pwd=3003 
提取码:3003 
--来自百度网盘超级会员V3的分享
 

先等等吧 改的东西也比较多,我梳理梳理。

打自己的标签   box 的name建议英文,没用过中文 避免保错    图片只能用彩色图,灰色图在训练的时候会报错

 

 

 

 

每个点要用不同的name

 然后就获得了json文件,下一步就是要把json  边成csv

这是我给自己写的脚本(不一定都能用)

import gc
import json
import csv
import time
from urllib import request
l=[]

# with open("1234.csv", "a+") as csvfile:
#     writer = csv.writer(csvfile)
#
#     # 先写入columns_name
#     writer.writerow(["image_id", "image_category", "1", "2", "3", "4", "5", "6", "7", "8"])

def baoc(m, null=None):
    with open(m,encoding="UTF-8") as file:
        reader = json.load(file)
        print(reader)
        print(reader["imagePath"])
        n=reader["imagePath"]
        name=[n+",yifu,"]
        for i in reader["shapes"]:
            if i["label"]=="yifu":
                print(1)
            else:
                print (i['points'][0][0])
                print(i['points'][0][1])
                a=i['points'][0][0]
                b=i['points'][0][1]
                a=int(a)
                b=int(b)
                c=[a,"_",b,"_2"]

                l.append(c)

    print(l)
    str(l)


    #f.writelines(内容)






    num=0
    with open(r'1.txt', mode='w', newline='', encoding='utf8') as cfa:

         wf = csv.writer(cfa)
         for i in l:
             print(i)
             if num>0:
                wf.writerow(",")
                wf.writerow(i)

             else:
                 wf.writerow(i)
                 num+=1

    path = r'1.txt'#文本存放的路径

    with open(path) as file:
        lines = file.readlines()#读取每一行

    a = ''#空字符(中间不加空格)
    d = ''#空字符(中间不加空格)
    for line in lines:
        a += line.strip()#strip()是去掉每行末尾的换行符\n 1

    c = a.split()#将a分割成每个字符串 2
    b = ''.join(c)#将c的每个字符不以任何符号直接连接 3
    print(a)
    print(b)
    print(c)
    print(d)
    transformed_string=a.replace(",","")
    print("Transformed String is:")
    print(transformed_string)
    y=transformed_string.replace('""', ",")
    print("********************")
    print(y)
    print(type(y))
    al=[y]
    print(al)


    # with open("1234.csv", "a+") as csvfile:
    #     writer = csv.writer(csvfile)
    #
    #     # 先写入columns_name
    #     writer.writerow(a)

    path  = "1234.csv"
    with open(path,'a+') as f:
        csv_write = csv.writer(f)
        print(f)
        csv_write.writerow(al)

        print(type(al))

    print("********************")
    y=""
    c.clear()
    l.clear()
    for i in range(len(y)):
        print("********************")
        print(i)

    for i in range(len(al)):
        del al[i]
        print("********************")
        print(al)
        print(y)

    with open(r'1.txt', 'a+', encoding='utf-8') as test:
        test.truncate(0)
        time.sleep(0.5)




if __name__ == '__main__':
    for i in range(1,2000):

        if (i<=9):
            i=str(i)
            m="C:\\Users\\pc\\Desktop\\脚本\\1\\000" +i+".json"
        elif i>=10 and i<=99:
            i = str(i)
            m = "C:\\Users\\pc\\Desktop\\脚本\\1\\00" + i + ".json"
        elif i >= 100 and i <= 999:
            i = str(i)
            m = "C:\\Users\\pc\\Desktop\\脚本\\1\\0" + i + ".json"
        elif i >= 1000 and i <= 9999:
            i = str(i)
            m = "C:\\Users\\pc\\Desktop\\脚本\\1\\" + i + ".json"
        baoc(m)

图片格式  需要自己创建  1.TXT  和1234.csv 两个文件,我的类别是yifu  脚本里面要改

 

脚本跑出来是这个样子的(水平有限)

 知识点(如果脚本跑不出来,要么自己写,要么安照下面的格式手动粘贴)

创建训练csv名字       train.csv

 第一行 这样写   1不变    2是   关键点的name

 后面就是    图片路径+框的类别+具体的点

 这就制作完成了

知识点  :里面是没有冒号  空格的

为了放置有些人看不懂

我这里用上面标注人头的那个写一下格式

这一块可以不看(新写的脚本和我一样懒的复制的可以学习一下)

先用代码跑一个文件路径

import csv
with open(r'C:\Users\pc\Desktop\脚本\test.csv', 'w', newline= '') as w :
    for i in range(1, 2000):

        if (i <= 9):
            i = str(i)
            m = "C:\\Users\\pc\\Desktop\\脚本\\1\\000" + i + ".jpg"
        elif i >= 10 and i <= 99:
            i = str(i)
            m = "C:\\Users\\pc\\Desktop\\脚本\\1\\00" + i + ".jpg"
        elif i >= 100 and i <= 999:
            i = str(i)
            m = "C:\\Users\\pc\\Desktop\\脚本\\1\\0" + i + ".jpg"
        elif i >= 1000 and i <= 9999:
            i = str(i)
            m = "C:\\Users\\pc\\Desktop\\脚本\\1\\" + i + ".jpg"
        csv_write = csv.writer(w)
        m=m+",yifu"
        a=[m]
        csv_write.writerow(a)


import csv

with open("test.csv", 'rt') as f:

    data = f.read()

    new_data = data.replace('"', '')

    for row in csv.reader(new_data.splitlines(), delimiter=' ', skipinitialspace=True):
        csv_write = csv.writer(w)
        print ('|'.join(row))
        print("")

效果

 不要管效果

把上面这个代码的打印的东西粘贴出来 把第一行和末尾删除

 

import csv
删除空行的代码

import pandas as pd
data = pd.read_csv("test.csv")
res = data.dropna(how="all")
res.to_csv("test.csv", index=False)

 这个代码去除一下空格

去冒号
import csv
with open("1234.csv", 'rt') as f:

    data = f.read()

    new_data = data.replace('"', '')

    for row in csv.reader(new_data.splitlines(), delimiter=' ', skipinitialspace=True):

        print ('|'.join(row))

 还是把打印的内容粘贴出来  并删除开头和末尾

在去一下空格   效果

拼接代码
import csv
with open('test.csv', 'r') as f:
     reader = csv.reader(f)
     result = list(reader)
with open('shuju.csv', 'r') as w:
    reader1 = csv.reader(w)
    result1 = list(reader1)


with open(r'C:\Users\pc\Desktop\脚本\test1.csv', 'w+', newline='') as w1:

     csv_write = csv.writer(w1)
     for i in range(1999):
         m = result[i]+result1[i]
         print(m)
         csv_write.writerow(m)


拼接完成以后就是下面这样

你制作的数据集也是和下面差不多的

 记得在开头加上

image_id,image_category,1,2,3,4,5,6,7,8,9,10,11,12,13,14

 

 

 

 

 闲了在更新出现这个错  找到拼接好的数据    他说122  17 个值

 发现多了个0-1-2     直接删除

 可以了

新家脚本

import json

for i in range(2115,4229):

    i=str(i)
    print (i)
    a="C:/Users/pc/Desktop/csxw/w/"+i+ ".json"


    with open(a, 'rb') as f:  # 使用只读模型,并定义名称为f
        params = json.load(f)  # 加载json文件
        params["imagePath"] = str(i)+".jpg"  # code字段对应的值修改为404

    f.close()  # 关闭json读模式
    with open(a, 'w') as r:
        # 将dict写入名称为r的文件中
        json.dump(params, r)
        # 关闭json写模式
    r.close()

功能改变 这个值

 删除标签   记得先备份

# !/usr/bin/env python
# -*- encoding: utf-8 -*-

import os
import json

# 这里写你自己的存放照片和json文件的路径
json_dir = r'C:\Users\pc\Desktop\zjk\xin1/'
json_files = os.listdir(json_dir)

# 这里写你要删除的标签名
delete_name = "16"

for json_file in json_files:
    json_file_ext = os.path.splitext(json_file)

    if json_file_ext[1] == '.json':
        # 判断是否为json文件
        jsonfile = json_dir + '\\' + json_file

        with open(jsonfile, 'r', encoding='utf-8') as jf:
            info = json.load(jf)

            # for i, label in enumerate(info['shapes']):
            for i in range(len(info['shapes'])-1,0,-1):
                if info['shapes'][i]['label'] == delete_name:
                    # 找到位置进行删除
                    del info['shapes'][i]
            # 使用新字典替换修改后的字典
            json_dict = info

        # 将替换后的内容写入原文件
        with open(jsonfile, 'w') as new_jf:
            json.dump(json_dict, new_jf)

print('delete label over!')

测试代码:

        修改csv

 

 其他就是路径的问题了

 

  • 1
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 10
    评论
评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

WangSaLe

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值