PyTorch: CNN Flatten Operation Visualized - Tensor Batch Processing for Deep Learning

Flatten operation for a batch of image inputs to a CNN

Flattening an entire tensor

To flatten a tensor, it needs to have at least two axes, so that we start with something that is not already flat.
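As a minimal sketch of this idea, the snippet below flattens a small rank-2 tensor into a single axis. The tensor `m` and its values are made up purely for illustration.

```python
import torch

# A small rank-2 tensor (2 rows, 3 columns), used only for illustration.
m = torch.tensor([
    [1, 2, 3],
    [4, 5, 6]
])

# Collapse both axes into a single axis.
flat = m.flatten()

print(m.shape)     # torch.Size([2, 3])
print(flat.shape)  # torch.Size([6])
print(flat)        # tensor([1, 2, 3, 4, 5, 6])
```

The elements keep their row-major order: the first row comes first, followed by the second row.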
But what if we want to flatten only specific axes within the tensor? This is typically required when working with CNNs.

Flattening specific axes of a tensor

Tensor inputs to a convolutional neural network typically have 4 axes, one for batch size, one for color channels, and one each for height and width.

  • (Batch Size, Channels, Height, Width)
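To make the axis order concrete, here is a small sketch that builds a placeholder batch and unpacks its shape. The sizes (8 images, 3 channels, 28 x 28 pixels) are assumptions chosen only for the example.

```python
import torch

# Hypothetical batch: 8 RGB images, each 28 x 28 pixels (all zeros).
batch = torch.zeros(8, 3, 28, 28)

# The shape unpacks in (Batch Size, Channels, Height, Width) order.
batch_size, channels, height, width = batch.shape

print(batch.dim())                           # 4
print(batch_size, channels, height, width)   # 8 3 28 28
```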

Suppose we have the following three tensors.

import torch

t1 = torch.tensor([
    [1,1,1,1],
    [1,1,1,1],
    [1,1,1,1],
    [1,1,1,1]
])

t2 = torch.tensor([
    [2,2,2,2],
    [2,2,2,2],
    [2,2,2,2],
    [2,2,2,2]
])

t3 = torch.tensor([
    [3,3,3,3],
    [3,3,3,3],
    [3,3,3,3],
    [3,3,3,3]
])

Each of these has a shape of 4 x 4, so we have three rank-2 tensors. For our purposes here, we’ll consider these to be three 4 x 4 images that we’ll use to create a batch that can be passed to a CNN.
Remember, batches are represented using a single tensor, so we’ll need to combine these three tensors into a single larger tensor that has three axes instead of two.

> t = torch.stack((t1, t2, t3))
> t.shape

torch.Size([3, 4, 4])
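The reason `torch.stack()` is used here rather than `torch.cat()` is that stacking creates a new axis, while concatenating joins along an existing one. The sketch below shows the difference on two small stand-in tensors, `a` and `b`, which are not part of the original example.

```python
import torch

a = torch.ones(4, 4)
b = torch.ones(4, 4) * 2

# stack() inserts a new axis (the batch axis) at position 0.
stacked = torch.stack((a, b))

# cat() joins along the existing axis 0, so no batch axis is created.
joined = torch.cat((a, b))

print(stacked.shape)  # torch.Size([2, 4, 4])
print(joined.shape)   # torch.Size([8, 4])
```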

The axis with a length of 3 represents the batch size, while the axes of length 4 represent the height and width respectively. This is what the output for this tensor representation of the batch looks like.

> t
tensor([[[1, 1, 1, 1],
         [1, 1, 1, 1],
         [1, 1, 1, 1],
         [1, 1, 1, 1]],

        [[2, 2, 2, 2],
         [2, 2, 2, 2],
         [2, 2, 2, 2],
         [2, 2, 2, 2]],

        [[3, 3, 3, 3],
         [3, 3, 3, 3],
         [3, 3, 3, 3],
         [3, 3, 3, 3]]])

At this point, we have a rank-3 tensor that contains a batch of three 4 x 4 images. All we need to do now to get this tensor into a form that a CNN expects is add an axis for the color channels. We basically have an implicit single color channel for each of these image tensors, so in practice, these would be grayscale images.
A CNN will expect to see an explicit color channel axis, so let’s add one by reshaping this tensor.

> t = t.reshape(3,1,4,4)
> t
tensor(
[
    [
        [
            [1, 1, 1, 1],
            [1, 1, 1, 1],
            [1, 1, 1, 1],
            [1, 1, 1, 1]
        ]
    ],
    [
        [
            [2, 2, 2, 2],
            [2, 2, 2, 2],
            [2, 2, 2, 2],
            [2, 2, 2, 2]
        ]
    ],
    [
        [
            [3, 3, 3, 3],
            [3, 3, 3, 3],
            [3, 3, 3, 3],
            [3, 3, 3, 3]
        ]
    ]
])
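Reshaping is one way to add the channel axis. Another option, shown here only as an alternative sketch, is `unsqueeze()`, which inserts a length-1 axis at a given index without having to spell out the other sizes. The tensor `batch` below is a stand-in with the same shape as the stacked batch.

```python
import torch

# Stand-in for the stacked batch: 3 images of 4 x 4.
batch = torch.arange(48).reshape(3, 4, 4)

via_reshape = batch.reshape(3, 1, 4, 4)
via_unsqueeze = batch.unsqueeze(1)  # insert a length-1 axis at index 1

print(via_unsqueeze.shape)                       # torch.Size([3, 1, 4, 4])
print(torch.equal(via_reshape, via_unsqueeze))   # True
```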

The first axis has 3 elements. Each element of the first axis represents an image. For each image, we have a single color channel on the channel axis. Each of these channels contains 4 arrays, each holding 4 numbers, or scalar components.
Here is the first image.

> t[0]
tensor([[[1, 1, 1, 1],
         [1, 1, 1, 1],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]])

We have the first color channel in the first image.

> t[0][0]
tensor([[1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1],
        [1, 1, 1, 1]])

We have the first row of pixels in the first color channel of the first image.

> t[0][0][0]
tensor([1, 1, 1, 1])

We have the first pixel value in the first row of the first color channel of the first image.

> t[0][0][0][0]
tensor(1)
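The chained indexing above can also be written with a single comma-separated index, which selects the same element. The sketch below uses a stand-in tensor filled with distinct values so that each position is easy to identify.

```python
import torch

# Stand-in batch with distinct values 0..47, shaped (3, 1, 4, 4).
batch = torch.arange(48).reshape(3, 1, 4, 4)

chained = batch[0][0][0][0]   # one axis at a time
direct = batch[0, 0, 0, 0]    # all four axes in one index

print(chained.item(), direct.item())  # 0 0

# Third row of the only channel of the second image.
row = batch[1, 0, 2]
print(row)  # tensor([24, 25, 26, 27])
```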

Flattening the tensor batch

Let’s flatten the whole thing first just to see what it will look like.

> t.reshape(1,-1)[0] # Method 1
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
    2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])

> t.reshape(-1) # Method 2
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
    2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])

> t.view(t.numel()) # Method 3
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
    2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])

> t.flatten() # Method 4
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
    2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3])
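As a quick check, the sketch below rebuilds the same batch and confirms that the four methods produce identical rank-1 tensors. The dictionary `results` and its labels are only illustrative names.

```python
import torch

# Rebuild the batch from the text: three 4 x 4 images of 1s, 2s and 3s.
base = torch.ones(4, 4, dtype=torch.int64)
t = torch.stack((base, base * 2, base * 3)).reshape(3, 1, 4, 4)

results = {
    "reshape(1,-1)[0]": t.reshape(1, -1)[0],
    "reshape(-1)": t.reshape(-1),
    "view(numel)": t.view(t.numel()),
    "flatten()": t.flatten(),
}

reference = results["flatten()"]
for name, flat in results.items():
    # Each result is compared against the flatten() result.
    print(name, tuple(flat.shape), torch.equal(flat, reference))
```

Each line printed should report a shape of `(48,)` and `True`.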

For our flattened batch to work well inside a CNN, we need to flatten each image while still maintaining the batch axis. That is, we want to flatten the color channel axis together with the height and width axes. These are the axes that need to be flattened: (C,H,W)

Flattening specific axes of a tensor

Starting from dimension 1, all subsequent dimensions are flattened, while dimension 0 is kept unchanged.

> t.flatten(start_dim=1).shape
torch.Size([3, 16])

> t.flatten(start_dim=1)
tensor(
[
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
    [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
    [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3]
]
)

The start_dim parameter tells the flatten() method which axis it should start the flatten operation at. The 1 here is an index, so it refers to the second axis, which is the color channel axis. The batch axis at index 0 is left intact.
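For comparison, here is a sketch of two other ways to get the same per-image flattening: a `reshape()` that keeps the batch size and infers the rest, and the `nn.Flatten` module, whose defaults are `start_dim=1` and `end_dim=-1`. The stand-in tensor `t` has the same shape as the batch above but holds distinct values.

```python
import torch
from torch import nn

# Stand-in batch shaped (3, 1, 4, 4) with distinct values.
t = torch.arange(48).reshape(3, 1, 4, 4)

a = t.flatten(start_dim=1)
b = t.reshape(t.shape[0], -1)   # keep the batch axis, infer the rest
c = nn.Flatten()(t)             # module form, defaults to start_dim=1

print(a.shape)             # torch.Size([3, 16])
print(torch.equal(a, b))   # True
print(torch.equal(a, c))   # True
```

The module form is convenient when the flatten step sits between convolutional and linear layers inside an `nn.Sequential`.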

Flattening an RGB Image

If we flatten an RGB image, each color channel will be flattened first. Then, the flattened channels will be lined up side by side on a single axis of the tensor. Let’s look at an example in code.

# Note on torch.ones()
> torch.ones(2,2)
tensor([[1., 1.],
        [1., 1.]])


> torch.ones(1,2,2)
tensor([[[1., 1.],
         [1., 1.]]])

> torch.ones(2, 2, 2)
tensor([[[1., 1.],
         [1., 1.]],

        [[1., 1.],
         [1., 1.]]])

We’ll build an example RGB image tensor with a height of two and a width of two.

r = torch.ones(1,2,2)
g = torch.ones(1,2,2) + 1
b = torch.ones(1,2,2) + 2

img = torch.cat((r, g, b), dim=0)

> img.shape
torch.Size([3, 2, 2])

> img
tensor([
    [
        [1., 1.],
        [1., 1.]
    ],
    [
        [2., 2.],
        [2., 2.]
    ],
    [
        [3., 3.],
        [3., 3.]
    ]
])

> img.flatten(start_dim=0)
tensor([1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.])

> img.flatten(start_dim=1)
tensor([
    [1., 1., 1., 1.],
    [2., 2., 2., 2.],
    [3., 3., 3., 3.]
])
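Putting the two ideas together, the sketch below flattens a small batch of RGB images while keeping the batch axis. The second image (`img * 10`) is a made-up variation added only so that the two rows of the result are easy to tell apart.

```python
import torch

# Rebuild the 2 x 2 RGB image from the text.
r = torch.ones(1, 2, 2)
g = torch.ones(1, 2, 2) + 1
b = torch.ones(1, 2, 2) + 2
img = torch.cat((r, g, b), dim=0)        # shape (3, 2, 2)

# Hypothetical batch of two images: the original and a scaled copy.
batch = torch.stack((img, img * 10))     # shape (2, 3, 2, 2)

# Flatten (C, H, W) for each image, keeping the batch axis.
flat = batch.flatten(start_dim=1)

print(flat.shape)  # torch.Size([2, 12])
print(flat[0])     # tensor([1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.])
```

Each row holds one image, with the flattened red, green, and blue channels lined up side by side.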