import numpy as np
import math
class Conv2D:
    """Naive 2-D convolution layer (cross-correlation) built on im2col.

    Tensors use NCHW layout: (batch, channels, height, width).

    Args:
        inputShape: Input shape as (batch, inChannels, height, width).
        outputChannel: Number of output channels (filters).
        kernelSize: Side length of the square kernel.
        stride: Step between neighbouring kernel positions on both axes.
        method: Padding mode, case-insensitive. "SAME" zero-pads so each
            spatial output size is ceil(size / stride) (TensorFlow
            convention). "VALID" -- and the default "" -- uses no padding,
            giving floor((size - kernelSize) / stride) + 1.

    Raises:
        ValueError: If ``method`` is not "SAME", "VALID" or "", or if a
            VALID convolution would produce an empty output.
    """

    def __init__(self, inputShape, outputChannel, kernelSize, stride=1, method=""):
        self.height = inputShape[2]
        self.width = inputShape[3]
        self.inputChannel = inputShape[1]
        self.outputChannel = outputChannel
        self.batchSize = inputShape[0]
        self.stride = stride
        self.kernelSize = kernelSize
        self.method = method
        # Random non-zero initialisation (all-zero kernels would never break
        # symmetry). Layout: (kernel_h, kernel_w, in_channels, out_channels).
        self.weights = np.random.random([kernelSize, kernelSize, self.inputChannel, self.outputChannel])
        self.bias = np.random.random(self.outputChannel)

        # An empty method historically meant "not specified"; treat it as VALID.
        mode = (method or "VALID").upper()
        if mode == "SAME":
            outH = math.ceil(self.height / stride)
            outW = math.ceil(self.width / stride)
            self.padding = (self._samePadding(self.height, outH),
                            self._samePadding(self.width, outW))
        elif mode == "VALID":
            outH = (self.height - kernelSize) // stride + 1
            outW = (self.width - kernelSize) // stride + 1
            if outH < 1 or outW < 1:
                raise ValueError("kernelSize is larger than the input for VALID convolution")
            self.padding = ((0, 0), (0, 0))
        else:
            raise ValueError(f"method must be 'SAME' or 'VALID', got {method!r}")
        # Output template; forward() reads its spatial dimensions.
        self.output = np.zeros((self.batchSize, self.outputChannel, outH, outW))

    def _samePadding(self, size, outSize):
        """Return (before, after) zero padding for one axis in SAME mode."""
        # Enough padding for outSize kernel placements; extra goes after (as TF does).
        total = max((outSize - 1) * self.stride + self.kernelSize - size, 0)
        return (total // 2, total - total // 2)

    def forward(self, x):
        """Convolve a batch ``x`` of shape (N, C_in, H, W).

        Returns an array of shape (N, outputChannel, H_out, W_out), where
        H_out and W_out are fixed at construction time by ``method``.
        """
        x = np.asarray(x)
        outH, outW = self.output.shape[2:]
        # Reorder weights to (C_in, kH, kW, C_out) so each flattened row matches
        # the (channel, row, col) order in which im2col flattens a patch.
        weights = self.weights.transpose(2, 0, 1, 3).reshape(-1, self.outputChannel)
        (padTop, padBottom), (padLeft, padRight) = self.padding
        padded = np.pad(x, ((0, 0), (0, 0), (padTop, padBottom), (padLeft, padRight)))
        batch = x.shape[0]
        convOut = np.zeros((batch, self.outputChannel, outH, outW))
        for i in range(batch):
            # colImage_i: (outH * outW, C_in * k * k), one row per kernel position.
            colImage_i = self.im2col(padded[i], self.kernelSize, self.stride)
            # (positions, C_out) -> (C_out, positions) before folding positions into the grid.
            convOut[i] = (colImage_i @ weights + self.bias).T.reshape(self.outputChannel, outH, outW)
        return convOut

    def im2col(self, image, kernelSize, stride):
        """Unfold a (C, H, W) image into rows of flattened kernel-sized patches.

        Returns an array of shape (numPositions, C * kernelSize * kernelSize).
        Each row is one patch flattened in (channel, row, col) order. Rows are
        ordered by patch position, row-major (top-left first).
        """
        rows = [
            image[:, i:i + kernelSize, j:j + kernelSize].reshape(-1)
            for i in range(0, image.shape[1] - kernelSize + 1, stride)
            for j in range(0, image.shape[2] - kernelSize + 1, stride)
        ]
        return np.array(rows)
# Demo: run a 3x3, 32-filter SAME convolution over a random batch of
# four 3-channel 5x5 images and print the resulting output shape.
inputData = np.random.random((4, 3, 5, 5))
kernel = [3, 3, 32]  # [kernel height, kernel width, number of output channels]
conv2d = Conv2D(
    inputShape=inputData.shape,
    kernelSize=kernel[0],
    outputChannel=kernel[2],
    method='SAME',
    stride=1,
)
outputData = conv2d.forward(inputData)
print("outputShape: ", outputData.shape)
# python卷积操作 (Python convolution operation)
# 最新推荐文章于 2024-05-08 16:26:37 发布 (latest recommended article published 2024-05-08 16:26:37)