20230918-pytorch训练模型,c++调用并使用模型,传入数据格式统一的问题

6 篇文章 0 订阅
4 篇文章 0 订阅

20230918——pytorch训练模型,c++调用模型,以及参数的问题

一、技术链方式

  • pytorch——>onnx——>c++

二、实验

2.1 实验代码

该部分的实验代码使用的是,参考博文

【ONNX】使用 C++ 调用 ONNX 格式的 PyTorch 深度学习模型进行预测(Windows, C++, PyTorch, ONNX, Visual Studio, OpenCV)

【ONNX】导出,载入PyTorch的ONNX模型并进行预测新手教程Windows+Python+Pycharm+PyTorch+ONNX)

2.1.1 pytorch

训练并保存模型

import torch
import torchvision

dummy_input = torch.randn(1, 3, 224, 224, device="cuda")
model = torchvision.models.alexnet(weights=torchvision.models.AlexNet_Weights.IMAGENET1K_V1).cuda()

input_names = [ "actual_input_1" ] + [ "learned_%d" % i for i in range(16) ]
output_names = [ "output1" ]

torch.onnx.export(model, dummy_input, "alexnet.onnx", verbose=True, input_names=input_names,
                  output_names=output_names)

2.2.2 C++

加载并使用模型

#include <iostream>
#include <string>
#include <onnxruntime_cxx_api.h>
#include<opencv2/opencv.hpp>
#include <opencv2/dnn.hpp>

using namespace std;

int main()
{
    string labels[] = { "tench", "goldfish", "great white shark", "tiger shark", "hammerhead", "electric ray", "stingray", "cock", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", "indigo bunting", "robin", "bulbul", "jay", "magpie", "chickadee", "water ouzel", "kite", "bald eagle", "vulture", "great grey owl", "European fire salamander", "common newt", "eft", "spotted salamander", "axolotl", "bullfrog", "tree frog", "tailed frog", "loggerhead", "leatherback turtle", "mud turtle", "terrapin", "box turtle", "banded gecko", "common iguana", "American chameleon", "whiptail", "agama", "frilled lizard", "alligator lizard", "Gila monster", "green lizard", "African chameleon", "Komodo dragon", "African crocodile", "American alligator", "triceratops", "thunder snake", "ringneck snake", "hognose snake", "green snake", "king snake", "garter snake", "water snake", "vine snake", "night snake", "boa constrictor", "rock python", "Indian cobra", "green mamba", "sea snake", "horned viper", "diamondback", "sidewinder", "trilobite", "harvestman", "scorpion", "black and gold garden spider", "barn spider", "garden spider", "black widow", "tarantula", "wolf spider", "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie chicken", "peacock", "quail", "partridge", "African grey", "macaw", "sulphur-crested cockatoo", "lorikeet", "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "drake", "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", "fiddler crab", "king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", "American egret", "bittern", "crane bird", "limpkin", "European gallinule", "American coot", "bustard", "ruddy turnstone", "red-backed sandpiper", "redshank", "dowitcher", "oystercatcher", "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", "Chihuahua", "Japanese spaniel", "Maltese dog", "Pekinese", "Shih-Tzu", "Blenheim spaniel", "papillon", "toy terrier", "Rhodesian ridgeback", "Afghan hound", "basset", "beagle", "bloodhound", "bluetick", "black-and-tan coonhound", "Walker hound", "English foxhound", "redbone", "borzoi", "Irish wolfhound", "Italian greyhound", "whippet", "Ibizan hound", "Norwegian elkhound", "otterhound", "Saluki", "Scottish deerhound", "Weimaraner", "Staffordshire bullterrier", "American Staffordshire terrier", "Bedlington terrier", "Border terrier", "Kerry blue terrier", "Irish terrier", "Norfolk terrier", "Norwich terrier", "Yorkshire terrier", "wire-haired fox terrier", "Lakeland terrier", "Sealyham terrier", "Airedale", "cairn", "Australian terrier", "Dandie Dinmont", "Boston bull", "miniature schnauzer", "giant schnauzer", "standard schnauzer", "Scotch terrier", "Tibetan terrier", "silky terrier", "soft-coated wheaten terrier", "West Highland white terrier", "Lhasa", "flat-coated retriever", "curly-coated retriever", "golden retriever", "Labrador retriever", "Chesapeake Bay retriever", "German short-haired pointer", "vizsla", "English setter", "Irish setter", "Gordon setter", "Brittany spaniel", "clumber", "English springer", "Welsh springer spaniel", "cocker spaniel", "Sussex spaniel", "Irish water spaniel", "kuvasz", "schipperke", "groenendael", "malinois", "briard", "kelpie", "komondor", "Old English sheepdog", "Shetland sheepdog", "collie", "Border collie", "Bouvier des Flandres", "Rottweiler", "German shepherd", "Doberman", "miniature pinscher", "Greater Swiss Mountain dog", "Bernese mountain dog", "Appenzeller", "EntleBucher", "boxer", "bull mastiff", "Tibetan mastiff", "French bulldog", "Great Dane", "Saint Bernard", "Eskimo dog", "malamute", "Siberian husky", "dalmatian", "affenpinscher", "basenji", "pug", "Leonberg", "Newfoundland", "Great Pyrenees", "Samoyed", "Pomeranian", "chow", "keeshond", "Brabancon griffon", "Pembroke", "Cardigan", "toy poodle", "miniature poodle", "standard poodle", "Mexican hairless", "timber wolf", "white wolf", "red wolf", "coyote", "dingo", "dhole", "African hunting dog", "hyena", "red fox", "kit fox", "Arctic fox", "grey fox", "tabby", "tiger cat", "Persian cat", "Siamese cat", "Egyptian cat", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", "cheetah", "brown bear", "American black bear", "ice bear", "sloth bear", "mongoose", "meerkat", "tiger beetle", "ladybug", "ground beetle", "long-horned beetle", "leaf beetle", "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", "cricket", "walking stick", "cockroach", "mantis", "cicada", "leafhopper", "lacewing", "dragonfly", "damselfly", "admiral", "ringlet", "monarch", "cabbage butterfly", "sulphur butterfly", "lycaenid", "starfish", "sea urchin", "sea cucumber", "wood rabbit", "hare", "Angora", "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "sorrel", "zebra", "hog", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", "ram", "bighorn", "ibex", "hartebeest", "impala", "gazelle", "Arabian camel", "llama", "weasel", "mink", "polecat", "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas", "baboon", "macaque", "langur", "colobus", "proboscis monkey", "marmoset", "capuchin", "howler monkey", "titi", "spider monkey", "squirrel monkey", "Madagascar cat", "indri", "Indian elephant", "African elephant", "lesser panda", "giant panda", "barracouta", "eel", "coho", "rock beauty", "anemone fish", "sturgeon", "gar", "lionfish", "puffer", "abacus", "abaya", "academic gown", "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", "amphibian", "analog clock", "apiary", "apron", "ashcan", "assault rifle", "backpack", "bakery", "balance beam", "balloon", "ballpoint", "Band Aid", "banjo", "bannister", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", "barrow", "baseball", "basketball", "bassinet", "bassoon", "bathing cap", "bath towel", "bathtub", "beach wagon", "beacon", "beaker", "bearskin", "beer bottle", "beer glass", "bell cote", "bib", "bicycle-built-for-two", "bikini", "binder", "binoculars", "birdhouse", "boathouse", "bobsled", "bolo tie", "bonnet", "bookcase", "bookshop", "bottlecap", "bow", "bow tie", "brass", "brassiere", "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", "bullet train", "butcher shop", "cab", "caldron", "candle", "cannon", "canoe", "can opener", "cardigan", "car mirror", "carousel", "carpenter's kit", "carton", "car wheel", "cash machine", "cassette", "cassette player", "castle", "catamaran", "CD player", "cello", "cellular telephone", "chain", "chainlink fence", "chain mail", "chain saw", "chest", "chiffonier", "chime", "china cabinet", "Christmas stocking", "church", "cinema", "cleaver", "cliff dwelling", "cloak", "clog", "cocktail shaker", "coffee mug", "coffeepot", "coil", "combination lock", "computer keyboard", "confectionery", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", "cowboy hat", "cradle", "crane", "crash helmet", "crate", "crib", "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", "dial telephone", "diaper", "digital clock", "digital watch", "dining table", "dishrag", "dishwasher", "disk brake", "dock", "dogsled", "dome", "doormat", "drilling platform", "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", "electric locomotive", "entertainment center", "envelope", "espresso maker", "face powder", "feather boa", "file", "fireboat", "fire engine", "fire screen", "flagpole", "flute", "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four - poster", "freight car", "French horn", "frying pan", "fur coat", "garbage truck", "gasmask", "gas pump", "goblet", "go - kart", "golf ball", "golfcart", "gondola", "gong", "gown", "grand piano", "greenhouse", "grille", "grocery store", "guillotine", "hair slide", "hair spray", "half track", "hammer", "hamper", "hand blower", "hand - held computer", "handkerchief", "hard disc", "harmonica", "harp", "harvester", "hatchet", "holster", "home theater", "honeycomb", "hook", "hoopskirt", "horizontal bar", "horse cart", "hourglass", "iPod", "iron", "jack - o'-lantern", "jean", "jeep", "jersey", "jigsaw puzzle", "jinrikisha", "joystick", "kimono", "knee pad", "knot", "lab coat", "ladle", "lampshade", "laptop", "lawn mower", "lens cap", "letter opener", "library", "lifeboat", "lighter", "limousine", "liner", "lipstick", "Loafer", "lotion", "loudspeaker", "loupe", "lumbermill", "magnetic compass", "mailbag", "mailbox", "maillot", "maillot tank suit", "manhole cover", "maraca", "marimba", "mask", "matchstick", "maypole", "maze", "measuring cup", "medicine chest", "megalith", "microphone", "microwave", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", "mitten", "mixing bowl", "mobile home", "Model T", "modem", "monastery", "monitor", "moped", "mortar", "mortarboard", "mosque", "mosquito net", "motor scooter", "mountain bike", "mountain tent", "mouse", "mousetrap", "moving van", "muzzle", "nail", "neck brace", "necklace", "nipple", "notebook", "obelisk", "oboe", "ocarina", "odometer", "oil filter", "organ", "oscilloscope", "overskirt", "oxcart", "oxygen mask", "packet", "paddle", "paddlewheel", "padlock", "paintbrush", "pajama", "palace", "panpipe", "paper towel", "parachute", "parallel bars", "park bench", "parking meter", "passenger car", "patio", "pay-phone", "pedestal", "pencil box", "pencil sharpener", "perfume", "Petri dish", "photocopier", "pick", "pickelhaube", "picket fence", "pickup", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", "pinwheel", "pirate", "pitcher", "plane", "planetarium", "plastic bag", "plate rack", "plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", "pool table", "pop bottle", "pot", "potter's wheel", "power drill", "prayer rug", "printer", "prison", "projectile", "projector", "puck", "punching bag", "purse", "quill", "quilt", "racer", "racket", "radiator", "radio", "radio telescope", "rain barrel", "recreational vehicle", "reel", "reflex camera", "refrigerator", "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "rubber eraser", "rugby ball", "rule", "running shoe", "safe", "safety pin", "saltshaker", "sandal", "sarong", "sax", "scabbard", "scale", "school bus", "schooner", "scoreboard", "screen", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe shop", "shoji", "shopping basket", "shopping cart", "shovel", "shower cap", "shower curtain", "ski", "ski mask", "sleeping bag", "slide rule", "sliding door", "slot", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", "solar dish", "sombrero", "soup bowl", "space bar", "space heater", "space shuttle", "spatula", "speedboat", "spider web", "spindle", "sports car", "spotlight", "stage", "steam locomotive", "steel arch bridge", "steel drum", "stethoscope", "stole", "stone wall", "stopwatch", "stove", "strainer", "streetcar", "stretcher", "studio couch", "stupa", "submarine", "suit", "sundial", "sunglass", "sunglasses", "sunscreen", "suspension bridge", "swab", "sweatshirt", "swimming trunks", "swing", "switch", "syringe", "table lamp", "tank", "tape player", "teapot", "teddy", "television", "tennis ball", "thatch", "theater curtain", "thimble", "thresher", "throne", "tile roof", "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toyshop", "tractor", "trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", "triumphal arch", "trolleybus", "trombone", "tub", "turnstile", "typewriter keyboard", "umbrella", "unicycle", "upright", "vacuum", "vase", "vault", "velvet", "vending machine", "vestment", "viaduct", "violin", "volleyball", "waffle iron", "wall clock", "wallet", "wardrobe", "warplane", "washbasin", "washer", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", "wig", "window screen", "window shade", "Windsor tie", "wine bottle", "wing", "wok", "wooden spoon", "wool", "worm fence", "wreck", "yawl", "yurt", "web site", "comic book", "crossword puzzle", "street sign", "traffic light", "book jacket", "menu", "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "ice lolly", "French loaf", "bagel", "pretzel", "cheeseburger", "hotdog", "mashed potato", "head cabbage", "broccoli", "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith", "strawberry", "orange", "lemon", "fig", "pineapple", "banana", "jackfruit", "custard apple", "pomegranate", "hay", "carbonara", "chocolate sauce", "dough", "meat loaf", "pizza", "potpie", "burrito", "red wine", "espresso", "cup", "eggnog", "alp", "bubble", "cliff", "coral reef", "geyser", "lakeside", "promontory", "sandbar", "seashore", "valley", "volcano", "ballplayer", "groom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", "hip", "buckeye", "coral fungus", "agaric", "gyromitra", "stinkhorn", "earthstar", "hen-of-the-woods", "bolete", "ear", "toilet tissue"
    };
    cv::dnn::Net net = cv::dnn::readNetFromONNX("./alexnet.onnx");  // 加载模型
    cv::Mat image = cv::imread("./cat_224x224.jpg", 1);  // 读取图片
    cv::Mat blob = cv::dnn::blobFromImage(image, (double)(1.0/225.0), cv::Size(), cv::Scalar(225,225,225), true, false, CV_32F );  // 由图片加载数据 还可以进行缩放、归一化等预处理操作
    net.setInput(blob);  // 设置模型输入
    cv::Mat predict = net.forward(); // 推理结果

    double minValue, maxValue;
    cv::Point minIdx, maxIdx;
    cv::minMaxLoc(predict, &minValue, &maxValue, &minIdx, &maxIdx);

    string res = labels[maxIdx.x];

    return 0;
}

该技术链条:pytorch——>onnx——>c++是正确的,但是上面,使用的是torchvision提供的模型,直接放入参数,数据训练而成,其内部实现是屏蔽了的。

如何训练自己的模型,如何在c++调用时传入数据呢?

2.2 存在的问题

2.2.1 最主要存在的问题

python端:

  • 输入的数据格式

c++端:

  • 调用onnx模型时,传入的参数的数据格式

图像的训练,一般采用卷积神经网络,其传入数据的格式为NCWH的格式,其中N表示训练的批量大小,C表示数据(图像)的通道数(channels),W表示图像的宽度(width),H表示图像的高度(height)。

一般

  • NWH可根据自己需求调整。

  • C一般为3通道,分别为R,G,B。如果是1通道,一般转为灰度处理。

2.2.2 两种方案

因此python端和c++端的输入数据格式应统一。最好不要在python端做训练时,将数据提前做好预处理,并转为一维数据,因为c++端将四维转为一维比较麻烦。

2.2.2.1 使用一维数据做训练

比如我最开始做HOG+SVM的行人检测时.

2.2.2.1.1 加载数据集

加载自己的数据集,并在加载数据集时,对数据做预处理。

import os
import cv2
import numpy as np
import torch
from skimage import io
from torch.utils.data import DataLoader, Dataset

# posPath = r'.\DataSet\INRIA Predestrian Dataset\image_data\96X160H96\Train\pos'
# negPath = r'.\DataSet\INRIA Predestrian Dataset\image_data\negphoto'
# modelPath = r'.\model'
# annPath = r'.\DataSet\INRIA Predestrian Dataset\image_data\annotations'
# resPath = r'.\res'
# TestNegPath = r'.\DataSet\INRIA Predestrian Dataset\image_data\Test\neg'

def gamma(img):
    gamma_val = np.power(img / 255.0, 2.2)
    return gamma_val

def hog_descriptor(image):
    if (image.max() - image.min()) != 0:
        image = (image - image.min()) / (image.max() - image.min())
        image *= 255
        image = image.astype(np.uint8)
    hog = cv2.HOGDescriptor((64, 128), (16, 16), (8, 8), (8, 8), 9)
    hog_feature = hog.compute(image)
    # print("hog_feature.type() == ", type(hog_feature))
    # print("hog_feature.size() == ", len(hog_feature))
    # print("hog_feature == ", hog_feature)

    return hog_feature


class load_my_data(Dataset):
    def __init__(self, path):
        # print("path == ",path )
        pos_list = os.listdir(path)
        hog_list = []
        label_list = []
        tup_list = []
        for i in range(len(pos_list)):
            pos_img = io.imread(os.path.join(path, pos_list[i]))
            # print("type(pos_img) == ",type(pos_img))
            pos_img = cv2.cvtColor(pos_img, cv2.COLOR_RGBA2GRAY)
            # 所用图像已经经过标准化
            pos_img = cv2.resize(pos_img, (64, 128), interpolation=cv2.INTER_NEAREST)
            pos_img = gamma(pos_img)
            pos_hog = hog_descriptor(pos_img)
            hog_list.append(pos_hog)
            tup_temp = (pos_img,1)
            tup_list.append(tup_temp)
            label_list.append(1)
        self.hog_list = hog_list
        self.label_list = label_list
        self.tup_list = tup_list
        self.length = len(hog_list) + len(label_list)+len(tup_list)

    def __len__(self):
        return self.length  # 直接返回长度

    def __getitem__(self, index):
        pass

    def get_features(self):
        temp = np.array(self.hog_list)
        temp = temp.flatten()
        temp = torch.as_tensor(temp)
        temp = temp.reshape(-1,3780)
        return temp

    def get_labels(self):
        temp = np.array(self.label_list)
        temp = temp.flatten()
        temp = torch.as_tensor(temp)
        temp = temp.reshape(-1)
        return temp

    def get_features_nums(self):
        if len(self.hog_list) == len(self.label_list):
            return len(self.tup_list)
        else :
            return -1


# mydata = load_my_data(posPath)
# print("type(mydata.tup_list) == ",type(mydata.tup_list))
# print("type(mydata.tup_list[0][0]) == ",type(mydata.tup_list[0][0]))
# features = mydata.get_features()
# print("type(features) == ",type(features))
# print("mydata.get_features_nums() == ",mydata.get_features_nums())



在这一步中,我从本地读取图片,通过hog处理,提取出特征,然后将每张图片转为3780长度的一维数据。

2.2.2.1.2 pytorch训练

然后使用pytorch训练,由于只做验证,使用的时网络模型仅有一层。

import numpy as np

from loadData import *
import torch
import torch.nn as nn
import torch.optim as optim

class SVM(nn.Module):
    def __init__(self, input_size, num_classes):
        super(SVM, self).__init__()
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, x):
        x = self.linear(x)
        return x


posPath = r'.\DataSet\INRIA Predestrian Dataset\image_data\96X160H96\Train\pos'
negPath = r'.\DataSet\INRIA Predestrian Dataset\image_data\negphoto'
modelPath = r'.\model'
annPath = r'.\DataSet\INRIA Predestrian Dataset\image_data\annotations'
resPath = r'.\res'
TestNegPath = r'.\DataSet\INRIA Predestrian Dataset\image_data\Test\neg'


#定义训练数据
mydata = load_my_data(posPath)
x_train = mydata.get_features()
y_train = mydata.get_labels()

# 定义SVM模型
svm = SVM(input_size=3780, num_classes=1)
criterion = nn.HingeEmbeddingLoss()
optimizer = optim.SGD(svm.parameters(), lr=0.01)

#训练模型
num_epochs = 1
for epoch in range(num_epochs):
    optimizer.zero_grad()
    outputs = svm(x_train)
    loss = criterion(outputs.squeeze(), y_train)
    loss.backward()
    optimizer.step()

#测试模型
x_test = mydata.get_features()[500:600]
outputs = svm(x_test)

predicted = outputs.detach().numpy().squeeze()
predicted = np.abs(predicted)
res = [predicted >0.7]
print(" x_test[0].size() == ", x_test.shape)
torch.onnx.export(svm, x_test, "./model/hogsvm.onnx")

保存模型后,在C++端进行调用。

2.2.2.1.3 c++调用模型

以下是c++的调用代码。

include <iostream>
#include <string>

#include <onnxruntime_cxx_api.h>

#include<opencv2/opencv.hpp>
#include <opencv2/dnn.hpp>


#include <torch/script.h>
#include <torch/torch.h>

using namespace std;

int main() {
    cv::dnn::Net net = cv::dnn::readNetFromONNX("./model/hogsvm.onnx");  // 加载模型
    std::array<float,3780> image;
    image.fill(0.99);
    //cv::Mat image = cv::imread("./images/2.jpg", 1);  // 读取图片
    //cv::imshow("image", image);
    //std::cout << "image.size == " << image.size() << std::endl;
    cv::Mat blob = cv::dnn::blobFromImage(image, (double)(1.0 / 225.0), cv::Size(), cv::Scalar(), true, false, CV_32F);  // 由图片加载数据 还可以进行缩放、归一化等预处理操作
    //std::cout << "blob == " << blob.size() << std::endl;
    net.setInput(blob);  // 设置模型输入
    cv::Mat predict = net.forward(); // 推理结果
    std::cout << "pridect == " << predict << std::endl;
    cv::waitKey(0);
}

最开始,我在这里会产生各种各样的报错,最主要的原因是,在c++环境中opencv提供的imread方式读取的图片是3通道数据的二维数据,也就是这里的输入数据的shape是【1,3,W,H】,每张图片的大小是其本身的尺寸大小。而我们python环境训练的时候,使用的数据是一维的,为【L】(L是我自己定义的,表示一维数组的长度)。所以在C++环境中输入的数据不匹配,肯定会报错。因此,需要转换通过imread获取的数据到一维。

上述代码中直接做了测试,没有完成Mat到array一维的转换。

2.2.2.1.4 构建数据测试

上述代码构建了一个一维的长度为3780的array。进行测试。测试结果如下

在这里插入图片描述

这里有一些注意的点,构建的image类型需要为float类型,double会报错"OpenCV(3.4.16) Error: Assertion failed (image.depth() == blob_.depth()) in cv::dnn::experimental_dnn_34_v23::blobFromImages, file C:\build\3_4_winpack-build-win64-vc15\opencv\modules\dnn\src\dnn.cpp, line 370"

如下
在这里插入图片描述

2.2.2.15加载图像数据测试

待补充

2.2.2.2 使用4维数据

也可以带入4维度,数据进行测试,

2.2.2.2.1 pytorch训练

参考的博文链接:PyTorch深度学习实践 第十讲 卷积神经网络(基础篇)

在其代码上,加入出模型的命令。完整代码如下

from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
import torch
import torch.nn.functional as F

# 利用卷积神经网络解决MNIST手写数字识别
# 1、准备数据集
# 处理数据
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
batch_size = 64
# 训练集
mnist_train = MNIST(root='../dataset/mnist', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=mnist_train, shuffle=True, batch_size=batch_size)
# 测试集
mnist_test = MNIST(root='../dataset/mnist', train=False, transform=transform, download=True)
test_loader = DataLoader(dataset=mnist_test, shuffle=True, batch_size=batch_size)


# 2.设计模型类
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = torch.nn.Conv2d(1, 10, kernel_size=5)
        self.pooling = torch.nn.MaxPool2d(2)
        self.conv2 = torch.nn.Conv2d(10, 20, kernel_size=5)
        self.fc = torch.nn.Linear(320, 10)  # 最后使用全连接层,分类的类别为10个

    def forward(self, x):
        x = self.pooling(F.relu(self.conv1(x)))  # 先卷积,再激活,再池化
        x = self.pooling(F.relu(self.conv2(x)))
        # 全连接层,将x[batch_size,20,4,4]->x[batch,20*4*4]  全连接层只能接受一维的数据
        x = x.view(-1, 320)  # 或者写成 x = x.view(batch_size,-1)
        x = self.fc(x)
        return x


model = Net()
# 3、构造损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

input = torch.randn(64,1,28,28)


# 4、训练和测试
# 定义训练方法,一个训练周期
def train(epoch):
    running_loss = 0.0
    for idx, (inputs, target) in enumerate(train_loader, 0):
        # 这里的代码与之前没有区别
        # 正向
        y_pred = model(inputs)
        loss = criterion(y_pred, target)
        # 反向
        optimizer.zero_grad()
        loss.backward()
        # 更新
        optimizer.step()

        running_loss += loss.item()
        if idx % 300 == 299:  # 每300次打印一次平均损失,因为idx是从0开始的,所以%299,而不是300
            print(f'epoch={epoch + 1},batch_idx={idx + 1},loss={running_loss / 300}')
            running_loss = 0.0

    torch.onnx.export(model, input, "./model/torch_con.onnx")

# 定义测试方法,一个测试周期
def test():
    # 所有预测正确的样本数
    correct_num = 0
    # 所有样本的数量
    total = 0
    # 测试时,我们不需要计算梯度,因此可以加上这一句,不需要梯度追踪
    with torch.no_grad():
        for images, labels in test_loader:
            # 获得预测值
            outputs = model(images)
            # 获取dim=1的最大值的位置,该位置就代表所预测的标签值
            _, predicted = torch.max(outputs.data, dim=1)
            # 累加每批次的样本数,以获得一个测试周期所有的样本数
            total += labels.size(0)
            # 累加每批次的预测正确的样本数,以获得一个测试周期的所有预测正确的样本数
            correct_num += (predicted == labels).sum().item()
        print(f'Accuracy on test set:{100 * correct_num / total}%')  # 打印一个测试周期的正确率


if __name__ == '__main__':
    # 训练周期为10次,每次训练所有的训练集样本数,并测试
    for epoch in range(1):
        train(epoch)
        test()
2.2.2.2.2 c++调用模型

和2.2.2.1.3节一样,构建数据进行测试。只要读取图片,处理成相同格式,就可以运行。

int main() {
    cv::dnn::Net net = cv::dnn::readNetFromONNX("./model/torch_con.onnx");  // 加载模型
    //cv::Mat image = cv::imread("./images/2.jpg", 1);  // 读取图片
    //torch::Tensor a = torch::randn({ 64,1,28,28 });
    cv::Mat image(28, 28, CV_8UC3, cv::Scalar(255));

    cv::Mat gray;
    cvtColor(image, gray, cv::COLOR_RGB2GRAY);
    cv::imshow("image", gray);
    std::cout << "image.size == " << gray.size() << std::endl;
    cv::Mat blob = cv::dnn::blobFromImage(gray, (double)(1.0 / 225.0), cv::Size(), cv::Scalar(), true, false, CV_32F);  // 由图片加载数据 还可以进行缩放、归一化等预处理操作
    //std::cout << "blob == " << blob.size() << std::endl;
    net.setInput(blob);  // 设置模型输入
    cv::Mat predict = net.forward(); // 推理结果
    std::cout << "pridect == " << predict << std::endl;
    cv::waitKey(0);
}

上面使用cv::Mat提供的构造函数,构建了一个三通道的28*28大小的图片(矩阵)。利用cvtColor将其转为灰度图像(1通道)。然后给模型放入该数据,可以看到测试结果如下图所示。
在这里插入图片描述

三、思考

3.1 建议

因为对于onnxruntime库,给net.setInput()方法放数据的时候,数据必须时protobuf格式,它提供了 cv::dnn::blobFromImage(gray, (double)(1.0 / 225.0), cv::Size(), cv::Scalar(), true, false, CV_32F),进行转换。

为了简单,

  • 对于一维数据,c++需要想办法把数据转为array类型
  • 对于四维数据
    • 三通道(图像),c++需要可以直接利用opencv的imread方法读取图像。
    • 一通道(灰度图像),c++通过imread读取图像后可以通过cvtColor转为1通道图像。

3.2 参数说明

blob = cv2.dnn.blobFromImage(image, scalefactor=1.0, size, mean, swapRB=True,crop=False,ddepth = CV_32F )
  • **image:**这个就是我们将要输入神经网络进行处理或者分类的图片。

  • scalefactor:当我们将图片减去平均值之后,还可以对剩下的像素值进行一定的尺度缩放,它的默认值是1,如果希望减去平均像素之后的值,全部缩小一半,那么可以将scalefactor设为1/2。

  • **size:**这个参数是我们神经网络在训练的时候要求输入的图片尺寸。

  • mean:需要将图片整体减去的平均值,如果我们需要对RGB图片的三个通道分别减去不同的值,那么可以使用3组平均值,如果只使用一组,那么就默认对三个通道减去一样的值。减去平均值**(mean):为了消除同一场景下不同光照的图片,对我们最终的分类或者神经网络的影响,我们常常对图片的R、G、B**通道的像素求一个平均值,然后将每个像素值减去我们的平均值,这样就可以得到像素之间的相对值,就可以排除光照的影响。

  • swapRB:OpenCV中认为我们的图片通道顺序是BGR,但是我平均值假设的顺序是RGB,所以如果需要交换R和G,那么就要使swapRB=true

  • crop,如果crop裁剪为真,则调整输入图像的大小,使调整大小后的一侧等于相应的尺寸,另一侧等于或大于。然后,从中心进行裁剪。如果“裁剪”为“假”,则直接调整大小而不进行裁剪并保留纵横比。

  • ddepth, 输出blob的深度,选则CV_32F or CV_8U。

    cv2.dnn.blobFromImage函数返回的blob是我们输入图像进行随意从中心裁剪,减均值、缩放和通道交换的结果。cv2.dnn.blobFromImages和cv2.dnn.blobFromImage不同在于,前者接受多张图像,后者接受一张图像。多张图像使用cv2.dnn.blobFromImages有更少的函数调用开销,你将能够更快批处理图像或帧。

类或者神经网络的影响,我们常常对图片的R、G、B**通道的像素求一个平均值,然后将每个像素值减去我们的平均值,这样就可以得到像素之间的相对值,就可以排除光照的影响。

  • swapRB:OpenCV中认为我们的图片通道顺序是BGR,但是我平均值假设的顺序是RGB,所以如果需要交换R和G,那么就要使swapRB=true

  • crop,如果crop裁剪为真,则调整输入图像的大小,使调整大小后的一侧等于相应的尺寸,另一侧等于或大于。然后,从中心进行裁剪。如果“裁剪”为“假”,则直接调整大小而不进行裁剪并保留纵横比。

  • ddepth, 输出blob的深度,选则CV_32F or CV_8U。

    cv2.dnn.blobFromImage函数返回的blob是我们输入图像进行随意从中心裁剪,减均值、缩放和通道交换的结果。cv2.dnn.blobFromImages和cv2.dnn.blobFromImage不同在于,前者接受多张图像,后者接受一张图像。多张图像使用cv2.dnn.blobFromImages有更少的函数调用开销,你将能够更快批处理图像或帧。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
要配置deeplabv3-plus-pytorch训练环境,您需要进行以下步骤: 1. 安装Python:确保已经安装了Python,并建议使用Python 3.6或更高版本。 2. 创建虚拟环境(可选):为了隔离不同项目的依赖,建议在项目中使用虚拟环境。您可以使用`venv`模块或第三方工具(如`conda`)创建和管理虚拟环境。 3. 安装PyTorch和TorchVision:PyTorch是进行深度学习的基础库,而TorchVision提供了处理图像数据集的工具。您可以使用以下命令安装PyTorch和TorchVision: ``` pip install torch torchvision ``` 如果您需要特定的PyTorch版本,可以在安装命令中指定版本号。 4. 克隆deeplabv3-plus-pytorch仓库:将deeplabv3-plus-pytorch的代码库克隆到本地: ``` git clone https://github.com/VainF/DeepLabV3Plus-Pytorch.git ``` 5. 安装依赖项:进入克隆的代码库目录,并使用以下命令安装所需的Python依赖项: ``` pip install -r requirements.txt ``` 6. 下载预训练模型权重(可选):如果您想从预训练模型开始训练,您可以下载已经预训练好的权重。可以在代码库的README文件中找到下载链接,并将权重文件保存到适当的位置。 7. 准备数据集:根据您的任务和数据集,将图像和标签数据组织到相应的文件夹中。确保数据集的文件路径与代码库中的配置文件相对应。 8. 开始训练:运行相应的训练脚本,例如`train.py`,并根据需要配置训练参数。您可以通过命令行参数或修改配置文件来设置训练参数。 以上是一个基本的环境配置过程,具体的步骤可能会因为您的特定环境和需求而有所不同。请参考deeplabv3-plus-pytorch代码库中的文档和说明,以获取更详细的配置指导。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

豆得儿不是猫

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值