Tensor must be 4-D with last dim 1, 3, or 4: bug notes

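After torchvision.utils.make_grid processes a batch of images, the result is a 3-D tensor. To pass it to tf.summary.image you have to add the batch size dimension back, and the function for that is tf.expand_dims.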
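The shape bookkeeping can be checked without TensorFlow. Below is a minimal sketch that uses NumPy as a stand-in: np.expand_dims plays the role of tf.expand_dims, and the zero array is a hypothetical placeholder for a make_grid result (its size is an assumption, not taken from a real run).

```python
import numpy as np

# Hypothetical stand-in for a make_grid result: channels first, (C, H, W).
grid_chw = np.zeros((3, 70, 274), dtype=np.float32)

# tf.summary.image expects channels last, so move C to the end: (H, W, C).
grid_hwc = np.transpose(grid_chw, (1, 2, 0))

# Add the missing batch dimension at the front: (1, H, W, C).
grid_nhwc = np.expand_dims(grid_hwc, 0)

print(grid_chw.shape, grid_hwc.shape, grid_nhwc.shape)
```

Only the last shape is 4-D with the channel count in the last position, which is what the error message asks for.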


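Input: Tensor of shape (B x C x H x W)
Output: one large image made by tiling the batch into a grid. It is 3-D (C x H x W), with no batch size dimension.
Example (original figure not preserved): a grid built from a batch of 16 images.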
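To get a feel for how big the tiled image is, here is a rough sketch of the shape arithmetic. It reflects my reading of how make_grid lays out the tiles (one padding strip around every image), and it assumes 3-channel images and a batch larger than one. expected_grid_shape is a hypothetical helper, not a torchvision function, so verify the numbers against your installed version.

```python
import math

def expected_grid_shape(batch, channels, height, width, nrow=8, padding=2):
    """Estimate the (C, H, W) shape of a make_grid result (assumes batch > 1)."""
    cols = min(nrow, batch)           # images per row
    rows = math.ceil(batch / cols)    # rows needed to fit the whole batch
    grid_h = rows * (height + padding) + padding
    grid_w = cols * (width + padding) + padding
    return (channels, grid_h, grid_w)

# A batch of 16 CIFAR-sized images with the default nrow and padding.
print(expected_grid_shape(16, 3, 32, 32))
```

Whatever the batch size, the result has three dimensions, which is why the batch dimension has to be added back by hand.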

torchvision.utils.make_grid(tensor, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0)

Purpose: Make a grid of images.

    tensor (Tensor or list) – 4D mini-batch Tensor of shape (B x C x H x W) or a list of images all of the same size.
    nrow (int, optional) – Number of images displayed in each row of the grid. The Final grid size is (B / nrow, nrow). Default is 8.
    padding (int, optional) – amount of padding. Default is 2.
    normalize (bool, optional) – If True, shift the image to the range (0, 1), by subtracting the minimum and dividing by the maximum pixel value.
    range (tuple, optional) – tuple (min, max) where min and max are numbers, then these numbers are used to normalize the image. By default, min and max are computed from the tensor.
    scale_each (bool, optional) – If True, scale each image in the batch of images separately rather than the (min, max) over all images.
    pad_value (float, optional) – Value for the padded pixels.
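The normalize and range options amount to min-max scaling. The NumPy sketch below shows the idea only and is not torchvision's code: min_max_normalize and bounds are hypothetical names, and it divides by (max - min), which is what maps the result into [0, 1] when the minimum is not zero.

```python
import numpy as np

def min_max_normalize(img, bounds=None):
    """Shift and scale img into [0, 1]; bounds is an optional (min, max) pair."""
    img = np.asarray(img, dtype=np.float32)
    if bounds is None:
        low, high = float(img.min()), float(img.max())
    else:
        low, high = bounds
    img = np.clip(img, low, high)
    # The small floor avoids division by zero for a constant image.
    return (img - low) / max(high - low, 1e-5)

pixels = np.array([-1.0, 0.0, 1.0])
print(min_max_normalize(pixels))               # min and max taken from the data
print(min_max_normalize(pixels, (-1.0, 1.0)))  # min and max given explicitly
```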

2、 While I'm at it, a note on another function, torchvision.utils.save_image:
torchvision.utils.save_image(tensor, filename, nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0)
Save a given Tensor into an image file.

tensor (Tensor or list) – Image to be saved. If given a mini-batch tensor, saves the tensor as a grid of images by calling make_grid.
   **kwargs – Other arguments are documented in make_grid.


3、 tf.image_summary (the older name of tf.summary.image, which the code below uses):
tf.image_summary(tag, tensor, max_images=None, collections=None, name=None)

Outputs a Summary protocol buffer with images.

The summary has up to max_images summary values containing images. The images are built from tensor which must be 4-D with shape [batch_size, height, width, channels] and where channels can be:

    1: tensor is interpreted as Grayscale.
    3: tensor is interpreted as RGB.
    4: tensor is interpreted as RGBA.
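The rule above is exactly what the error in the title complains about. As a rough illustration, here is a pure-Python check that mimics the condition. check_image_summary_shape is a hypothetical helper of mine, not a TensorFlow function, and the message text is only modeled on the error.

```python
def check_image_summary_shape(shape):
    """Return True if shape is [batch, height, width, channels] with 1, 3 or 4 channels."""
    if len(shape) != 4 or shape[-1] not in (1, 3, 4):
        raise ValueError(
            "Tensor must be 4-D with last dim 1, 3, or 4, not %s" % (list(shape),))
    return True

# A raw grid (3-D), a channels-first batch, and a channels-last batch.
for shape in [(70, 274, 3), (1, 3, 70, 274), (1, 70, 274, 3)]:
    try:
        check_image_summary_shape(shape)
        print(shape, "ok")
    except ValueError as err:
        print(shape, "rejected:", err)
```

Note that adding the batch dimension alone is not enough: a channels-first tensor is still rejected, so the transpose to (H, W, C) matters as well.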

4、 tf.expand_dims adds a dimension:

# 't' is a tensor of shape [2]
shape(expand_dims(t, 0)) ==> [1, 2]
shape(expand_dims(t, 1)) ==> [2, 1]
shape(expand_dims(t, -1)) ==> [2, 1]

# 't2' is a tensor of shape [2, 3, 5]
shape(expand_dims(t2, 0)) ==> [1, 2, 3, 5]
shape(expand_dims(t2, 2)) ==> [2, 3, 1, 5]
shape(expand_dims(t2, 3)) ==> [2, 3, 5, 1]

input: A Tensor.
dim: A Tensor. Must be one of the following types: int32, int64. 0-D (scalar). Specifies the dimension index at which to expand the shape of input.
name: A name for the operation (optional).

Returns: A Tensor. Has the same type as input. Contains the same data as input, but its shape has an additional dimension of size 1 added.
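The shape examples above can be tried out quickly with NumPy, whose np.expand_dims follows the same indexing convention. This is a NumPy analog used for convenience, on the assumption that the two functions agree for these cases. It is not TensorFlow itself.

```python
import numpy as np

t = np.zeros((2,))
t2 = np.zeros((2, 3, 5))

# Each entry pairs an array with the axis at which the new dimension is added.
cases = [(t, 0), (t, 1), (t, -1), (t2, 0), (t2, 2), (t2, 3)]
shapes = [np.expand_dims(arr, axis).shape for arr, axis in cases]

for (arr, axis), shape in zip(cases, shapes):
    print(arr.shape, "axis", axis, "->", shape)
```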


# coding=utf-8

from __future__ import print_function
from six.moves import range

import torch.backends.cudnn as cudnn
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torchvision.utils as vutils
import torchvision.datasets as dset
import torchvision.transforms as transforms
import numpy as np
import os
import time
import torchvision
from PIL import Image, ImageFont, ImageDraw
from copy import deepcopy
import tensorflow as tf
from torch.utils.data import DataLoader, Dataset
# from miscc.config import cfg
# from miscc.utils import mkdir_p
# from PIL import Image
import matplotlib
import matplotlib.pyplot as plt

sess = tf.InteractiveSession()
# def test():
# torchvision datasets return PIL images; ToTensor converts them to tensors in [0, 1],
# and Normalize then maps them to [-1, 1].
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

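# Training set: load all the data in the cifar-10-batches-py folder under root (./test here; 50000 training images). With download=True the data would be downloaded and extracted automatically; download=False here, so it must already be on disk.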
trainset = torchvision.datasets.CIFAR10(root='./test', train=True, download=False, transform=transform)

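# With batch_size=3 the 50000 training images are split into 16667 mini-batches (the last one holds 2 images). shuffle=True would reshuffle the data every epoch; it is False here, so the order is fixed. num_workers=2 loads the data with two worker subprocesses.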
trainloader = torch.utils.data.DataLoader(trainset, batch_size=3, shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
print ("len(trainset)",len(trainset))
print ("len(trainloader)",len(trainloader))

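for i, data in enumerate(trainloader, 0):
    # data is [images, labels]; data[0] holds the images, shape (B x C x H x W).
    # print(data[0][0])
    # img = transforms.ToPILImage()(data[0][0])
    # img.show()
    # break
    # The code above is from Jianshu: https://www.jianshu.com/p/8da9b24b2fb6

    # Pass the whole image batch (data[0]) to make_grid. normalize=True rescales
    # the [-1, 1] pixels to [0, 1] so the * 255 and uint8 cast below are valid.
    real_img_set = vutils.make_grid(data[0], normalize=True).numpy()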
    # print("real_img_set1",real_img_set)
    real_img_set = np.transpose(real_img_set, (1, 2, 0))
    # print("real_img_set_transpose",real_img_set)
    real_img_set = real_img_set * 255
    # print("real_img_set255",real_img_set)
    real_img_set = real_img_set.astype(np.uint8)

    super_real_img_set = tf.expand_dims(real_img_set, 0)
    print ("super_real_img_old", super_real_img_set)
    print ("super_real_img_old shape", super_real_img_set.shape)
    print ("super_real_img_old [0]", super_real_img_set[0])
    sup_real_img = tf.summary.image('real_img', super_real_img_set)
    print("sup_real_img", sup_real_img)
    print("sup_real_img shape", sup_real_img.shape)
    sup_real_img_new = sess.run(sup_real_img)
    # summary_writer.add_summary(sup_real_img_new, count)
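One thing to watch in the script above: the transform normalizes pixels to [-1, 1], so they have to be mapped back into [0, 1] before scaling to 0-255, because casting negative floats to uint8 does not give meaningful pixel values. make_grid's normalize option is one way to do that. Doing the mapping by hand is another, sketched here with NumPy. to_uint8 is a hypothetical helper, and the [-1, 1] input range is an assumption that matches the Normalize call above.

```python
import numpy as np

def to_uint8(img, low=-1.0, high=1.0):
    """Map floats in [low, high] to uint8 values in [0, 255]."""
    img = np.asarray(img, dtype=np.float32)
    scaled = (np.clip(img, low, high) - low) / (high - low)  # now in [0, 1]
    return np.round(scaled * 255.0).astype(np.uint8)

pixels = np.array([-1.0, -0.5, 0.5, 1.0], dtype=np.float32)
print(to_uint8(pixels))
```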
