深度余弦度量学习(cosine-metric-learning)在VeRi数据集调试

深度余弦度量学习cosine-metric-learning在VeRi数据集调试

训练部分

在VeRi数据集上调试深度余弦度量学习时出现了很多bug,其中一个是由于输入图像的维度和placeholder的维度不一致而报错,导致我花费了很多时间和精力去定位问题,后来通过修改源代码才得以调通。
在train_app.py上
原来代码:

# Original (buggy) version quoted from train_app.py: each filename is
# decoded inside map_fn, but the resize happens afterwards, OUTSIDE map_fn.
# Raw JPEGs have varying sizes, so the per-element shapes do not agree with
# the placeholder/image shape expected downstream (the dimension-mismatch
# error described above).
# NOTE(review): tf.image.decode_jpeg also yields uint8 while map_fn declares
# dtype=tf.float32 here — presumably part of the same failure; confirm
# against the TF 1.x map_fn documentation.
filename_var = tf.placeholder(tf.string, (None, ))
image_var = tf.map_fn(
      lambda x:tf.image.decode_jpeg(
          tf.read_file(x), channels=num_channels),
      filename_var, back_prop=False, dtype=tf.float32)
image_var = tf.image.resize_images(image_var, image_shape[:2])

更改后的代码:

# Fixed version: the resize is moved INSIDE the map_fn lambda, so every
# mapped element already has the uniform shape image_shape[:2] when map_fn
# stacks the results into one tensor.  (resize_images also returns a float
# tensor, which is consistent with the declared dtype=tf.float32 — confirm
# against the TF docs.)
filename_var = tf.placeholder(tf.string, (None, ))
image_var = tf.map_fn(
    lambda x: tf.image.resize_images(tf.image.decode_jpeg(
        tf.read_file(x), channels=num_channels), image_shape[:2]),
    filename_var, back_prop=False, dtype=tf.float32)

**数据读取部分**
VeRi.py

# vim: expandtab:ts=4:sw=4
import os
import numpy as np
import cv2
import scipy.io as sio


# The maximum identity label in the VeRi training set; classifier heads in
# train_VeRi_dataset.py are built with MAX_LABEL + 1 output classes.
MAX_LABEL = 769   # VeRi max label

IMAGE_SHAPE = 128, 64, 3  # kept identical to Market1501 on purpose, so the model structure does not change

def _parse_filename(filename):
    filename_base, ext = os.path.splitext(filename)
    if '.' in filename_base:
        # Some images have double filename extensions.
        filename_base, ext = os.path.splitext(filename_base)
    if ext != ".jpg":
        return None
    person_id, cam_seq, frame_idx, detection_idx = filename_base.split('_')
    return int(person_id), int(cam_seq[1]), filename_base, ext


def read_train_split_to_str(dataset_dir):
    """Collect (filenames, ids, camera_indices) for the training split.

    Scans ``<dataset_dir>/image_train`` in sorted order; entries whose
    names do not parse as VeRi images (e.g. Thumbs.db) are skipped.
    """
    image_dir = os.path.join(dataset_dir, "image_train")

    filenames, ids, camera_indices = [], [], []
    for entry in sorted(os.listdir(image_dir)):
        meta = _parse_filename(entry)
        if meta is None:
            # Not a valid image filename; ignore it.
            continue
        person_id, camera_idx = meta[0], meta[1]
        filenames.append(os.path.join(image_dir, entry))
        ids.append(person_id)
        camera_indices.append(camera_idx)
    return filenames, ids, camera_indices


def read_train_split_to_image(dataset_dir):
    """Load the training split as an (N, 128, 64, 3) uint8 image array.

    Returns (images, ids, camera_indices) where ids and camera_indices
    are int64 numpy arrays aligned with the first image axis.
    """
    filenames, ids, camera_indices = read_train_split_to_str(dataset_dir)

    num_images = len(filenames)
    images = np.zeros((num_images, 128, 64, 3), np.uint8)
    for idx in range(num_images):
        # Resize to 64x128 (cv2 dsize is (width, height)) so the array
        # keeps the Market1501-compatible input shape.
        bgr = cv2.imread(filenames[idx], cv2.IMREAD_COLOR)
        images[idx] = cv2.resize(bgr, (64, 128))

    return (images,
            np.asarray(ids, np.int64),
            np.asarray(camera_indices, np.int64))


def read_test_split_to_str(dataset_dir):
    """Read the test split as filename lists plus an evaluation mask.

    Returns (gallery_filenames, gallery_ids, query_filenames, query_ids,
    good_mask).  good_mask is (num_queries, num_gallery) float32 where
    entry (i, j) is 0 if gallery image j is a "junk" match for query i
    and must be excluded from metric computation, 1 otherwise.

    NOTE(review): the directory names "bounding_box_test", "query" and
    "gt_query" (with per-query *_junk.mat files) are Market1501
    conventions; VeRi distributions typically use image_test/image_query
    with text index files — confirm this layout exists for your copy.
    """
    # Read gallery.
    gallery_filenames, gallery_ids = [], []

    image_dir = os.path.join(dataset_dir, "bounding_box_test")
    for filename in sorted(os.listdir(image_dir)):
        meta_data = _parse_filename(filename)
        if meta_data is None:
            # This is not a valid filename (e.g., Thumbs.db).
            continue

        gallery_filenames.append(os.path.join(image_dir, filename))
        gallery_ids.append(meta_data[0])

    # Read queries.
    query_filenames, query_ids, query_junk_indices = [], [], []

    image_dir = os.path.join(dataset_dir, "query")
    for filename in sorted(os.listdir(image_dir)):
        meta_data = _parse_filename(filename)
        if meta_data is None:
            # This is not a valid filename (e.g., Thumbs.db).
            continue

        filename_base = meta_data[2]
        junk_matfile = filename_base + "_junk.mat"
        mat = sio.loadmat(os.path.join(dataset_dir, "gt_query", junk_matfile))
        # Junk indices in the .mat file are 1-based MATLAB indices; a value
        # below 1 is treated as "no junk images for this query".
        if np.any(mat["junk_index"] < 1):
            indices = []
        else:
            # MATLAB to Python index.
            indices = list(mat["junk_index"].astype(np.int64).ravel() - 1)

        query_junk_indices.append(indices)
        query_filenames.append(os.path.join(image_dir, filename))
        query_ids.append(meta_data[0])

    # The following matrix maps from query (row) to gallery image (column) such
    # that element (i, j) evaluates to 0 if query i and gallery image j should
    # be excluded from computation of the evaluation metrics and 1 otherwise.
    good_mask = np.ones(
        (len(query_filenames), len(gallery_filenames)), np.float32)
    for i, junk_indices in enumerate(query_junk_indices):
        good_mask[i, junk_indices] = 0.

    return gallery_filenames, gallery_ids, query_filenames, query_ids, good_mask


def read_test_split_to_image(dataset_dir):
    """Load the test split as uint8 image arrays.

    Returns (gallery_images, gallery_ids, query_images, query_ids,
    good_mask) where both image arrays have shape (N, 128, 64, 3).
    """
    gallery_filenames, gallery_ids, query_filenames, query_ids, good_mask = (
        read_test_split_to_str(dataset_dir))

    gallery_images = np.zeros((len(gallery_filenames), 128, 64, 3), np.uint8)
    for i, filename in enumerate(gallery_filenames):
        # Resize to 64x128 (cv2 dsize is (width, height)) to keep the
        # Market1501-compatible input shape.
        gallery_images[i] = cv2.resize(
            cv2.imread(filename, cv2.IMREAD_COLOR), (64, 128))

    query_images = np.zeros((len(query_filenames), 128, 64, 3), np.uint8)
    for i, filename in enumerate(query_filenames):
        # BUG FIX: queries were stored without the resize applied to the
        # gallery above, so any image not already 64x128 raised a broadcast
        # error when assigned into the (128, 64, 3) buffer.
        query_images[i] = cv2.resize(
            cv2.imread(filename, cv2.IMREAD_COLOR), (64, 128))

    gallery_ids = np.asarray(gallery_ids, np.int64)
    query_ids = np.asarray(query_ids, np.int64)
    return gallery_images, gallery_ids, query_images, query_ids, good_mask

train_VeRi_dataset.py

# vim: expandtab:ts=4:sw=4
import functools
import os

import numpy as np
import scipy.io as sio

import train_app
from datasets import market1501
from datasets import util
from datasets import VeRi
import nets.deep_sort.network_definition as net


class VeRi_dataset(object):
    """Thin wrapper around the VeRi split readers.

    A fraction ``num_validation_y`` of the training images is held out as
    a validation set; the split is deterministic for a fixed ``seed``.
    """

    def __init__(self, dataset_dir, num_validation_y=0.1, seed=1):
        self._dataset_dir = dataset_dir
        self._num_validation_y = num_validation_y
        self._seed = seed

    def _read_train_portion(self, keep_validation):
        # Shared helper: read the full training split, then keep either the
        # training or the validation portion of it.
        filenames, ids, camera_indices = VeRi.read_train_split_to_str(
            self._dataset_dir)
        train_idx, valid_idx = util.create_validation_split(
            np.asarray(ids, np.int64), self._num_validation_y, self._seed)

        keep = valid_idx if keep_validation else train_idx
        return ([filenames[i] for i in keep],
                [ids[i] for i in keep],
                [camera_indices[i] for i in keep])

    def read_train(self):
        """Return (filenames, ids, camera_indices) of the training portion."""
        return self._read_train_portion(keep_validation=False)

    def read_validation(self):
        """Return (filenames, ids, camera_indices) of the held-out portion."""
        return self._read_train_portion(keep_validation=True)

    def read_test(self):
        """Return the test split (see VeRi.read_test_split_to_str)."""
        return VeRi.read_test_split_to_str(self._dataset_dir)


def main():
    """Command-line entry point for cosine metric learning on VeRi.

    Modes (selected via --mode):
      train    -- run the training loop on the training portion.
      eval     -- evaluate on the held-out validation portion.
      export   -- compute gallery/query features and save them as .mat files.
      finalize -- export a checkpoint for deployment.
      freeze   -- export a frozen inference graph (.pb).
    """
    arg_parser = train_app.create_default_argument_parser("VeRi")
    arg_parser.add_argument(
        # BUG FIX: help text said "Market1501" (copy-paste from the
        # original script); this flag points at the VeRi directory.
        "--dataset_dir", help="Path to VeRi dataset directory.",
        default="data/VeRi")
    args = arg_parser.parse_args()
    # BUG FIX: `VeRi` names the dataset *module*; the wrapper class defined
    # in this file is `VeRi_dataset`. Calling VeRi(...) raised a TypeError.
    dataset = VeRi_dataset(args.dataset_dir, num_validation_y=0.1, seed=1234)

    if args.mode == "train":
        train_x, train_y, _ = dataset.read_train()
        print("Train set size: %d images, %d identities" % (
            len(train_x), len(np.unique(train_y))))

        # num_classes must cover every label, hence MAX_LABEL + 1.
        network_factory = net.create_network_factory(
            is_training=True, num_classes=VeRi.MAX_LABEL + 1,
            add_logits=args.loss_mode == "cosine-softmax")
        train_kwargs = train_app.to_train_kwargs(args)
        train_app.train_loop(
            net.preprocess, network_factory, train_x, train_y,
            num_images_per_id=4, image_shape=VeRi.IMAGE_SHAPE,
            **train_kwargs)
    elif args.mode == "eval":
        valid_x, valid_y, camera_indices = dataset.read_validation()
        print("Validation set size: %d images, %d identities" % (
            len(valid_x), len(np.unique(valid_y))))

        network_factory = net.create_network_factory(
            is_training=False, num_classes=VeRi.MAX_LABEL + 1,
            add_logits=args.loss_mode == "cosine-softmax")
        eval_kwargs = train_app.to_eval_kwargs(args)
        train_app.eval_loop(
            net.preprocess, network_factory, valid_x, valid_y, camera_indices,
            image_shape=VeRi.IMAGE_SHAPE, **eval_kwargs)
    elif args.mode == "export":
        # Export one specific model: encode gallery and query sets and save
        # the features in MATLAB format for the evaluation SDK.
        gallery_filenames, _, query_filenames, _, _ = dataset.read_test()

        network_factory = net.create_network_factory(
            is_training=False, num_classes=VeRi.MAX_LABEL + 1,
            add_logits=False, reuse=None)
        gallery_features = train_app.encode(
            net.preprocess, network_factory, args.restore_path,
            gallery_filenames, image_shape=VeRi.IMAGE_SHAPE)
        sio.savemat(
            os.path.join(args.sdk_dir, "feat_test.mat"),
            {"features": gallery_features})

        # reuse=True: share variables with the gallery encoder built above.
        network_factory = net.create_network_factory(
            is_training=False, num_classes=VeRi.MAX_LABEL + 1,
            add_logits=False, reuse=True)
        query_features = train_app.encode(
            net.preprocess, network_factory, args.restore_path,
            query_filenames, image_shape=VeRi.IMAGE_SHAPE)
        sio.savemat(
            os.path.join(args.sdk_dir, "feat_query.mat"),
            {"features": query_features})
    elif args.mode == "finalize":
        network_factory = net.create_network_factory(
            is_training=False, num_classes=VeRi.MAX_LABEL + 1,
            add_logits=False, reuse=None)
        train_app.finalize(
            functools.partial(net.preprocess, input_is_bgr=True),
            network_factory, args.restore_path,
            image_shape=VeRi.IMAGE_SHAPE,
            output_filename="./VeRi.ckpt")
    elif args.mode == "freeze":
        network_factory = net.create_network_factory(
            is_training=False, num_classes=VeRi.MAX_LABEL + 1,
            add_logits=False, reuse=None)
        train_app.freeze(
            functools.partial(net.preprocess, input_is_bgr=True),
            network_factory, args.restore_path,
            image_shape=VeRi.IMAGE_SHAPE,
            output_filename="./VeRi.pb")
    else:
        raise ValueError("Invalid mode argument.")


if __name__ == "__main__":
    main()

版权归属本作者所用,转载需引用

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值