背景
用Python的keras深度学习框架完成的Unet网络代码转为MATLAB去实现,不能用MATLAB去调用Python里的包,所以说网络的底层都需要自己去实现,比如一维卷积层,一维池化层以及上采样层,还要加载训练好的网络参数等等,在实现过程中遇到的一些问题,拿出来写一写。
问题1 - reshape函数
这个是遇到最苟的问题,一直没想到是reshape函数出了问题。
MATLAB中的reshape函数是按照列的顺序进行读取转换的,也就是第一列读完,读第二列,如下面例子即依次读取 a 的每一列数据,存的时候也是按列进行存放,按顺序依次存放到目标矩阵的每一列得到 b ,也就是所谓的按列读取按列存放。
a = [1:6]
b = reshape(a, 2, 3)
c = reshape(a, 3, 2)
反观Python中的reshape函数,这里以numpy的reshape示例,与MATLAB相反,它是按行读取按行存放的,这是很容易出错的地方,使用的时候一定要注意!
import numpy as np
a = np.linspace(1,6,6)
b = np.reshape(a, [2, 3])
c = np.reshape(a, [3, 2])
print('a = \n', a)
print('b = \n', b)
print('c = \n', c)
Tips:无论用哪种语言的reshape,在使用时都要注意数据维度问题奥,维度不匹配也会出现错误,如果不确定维度或者懒得去想可以用以下方式解决:
% matlab
b = reshape(a, 3, []) % 未知维度用中括号顶替,结果是一样的,会自动补齐维度
c = reshape(a, [], 2)
# Python
b = np.reshape(a, [3, -1]) # 未知维度参数设为-1,reshape结果也是一样的
print('b = \n', b)
c = np.reshape(a, [-1, 2])
print('c = \n', c)
如果想让俩种语言实现效果一致,比如用MATLAB实现的Python的reshape,可以通过先用reshape得到原目标矩阵的转置矩阵,然后再求转置:
a = [1:6]
b = reshape(a, 2, 3)
b = b' % 此时b为目标矩阵,维度为3*2,与上述Python结果一致
问题2 - 索引问题
这个算比较常见的,MATLAB的索引是从1开始的,而Python的索引是从0开始的:
% MATLAB
a = [1:6]
b = a(1)
# Python
import numpy as np
a = np.linspace(1,6,6)
b = a[1]
print('a = \n', a)
print('b = \n', b)
可以看出取出a[1]的数据,但MATLAB取的是数组里的第一个数,而python取的是第二个数,这便是索引起始位置不同引起的差异。
看一下两种代码获取相同数据的即视感:
seg_len = floor(test_flat(i)/sig_length);
for j = 1:seg_len
label = tmp_states((j-1)*sig_length+1:j*sig_length);
a_sig = tmp_features((j-1)*sig_length+1:j*sig_length, :);
end
seg_len = int(flat[0][i]/sig_length)
for j in range(seg_len):
label = tmp_states[j*sig_length:(j+1)*sig_length]
a_sig = tmp_features[j*sig_length:(j+1)*sig_length]
一定要注意截取区间的设置!这两部分代码输入相同数据,输出是一致的。
问题3 - 补零问题
毕竟要底层实现网络各部分结构,比如卷积部分,池化部分,选取SAME或full的方式进行操作的时候,就会涉及到补零的问题(按步长为1考虑)。
要想正确补零,在实现时要知道卷积核大小是奇数还是偶数:
- 当为奇数时,这时候若是一维卷积,上侧和下侧补零个数是相同的;若为二维卷积,左侧和右侧补零和数一致,都是(filter_size-1)/2。
- 当为偶数时,这时候若是一维卷积,下侧补零个数要比上侧多一;若为二维卷积,右侧补零个数要比左侧个数多一。上侧补零个数为filter_size/2-1,下侧为filter_size/2。
(不懂的小伙伴可以看一下三种padding方式,补零的时候一定要注意)
demo:
top = floor((filter_size-1) /2); % 上侧补零个数
bottom = filter_size - top ; % 下侧补零个数
问题4 - keras的concatenate和UpSampling1D函数
这两个函数都是keras里相对简单的函数,其中从concatenate函数即为简单的按列进行拼接,而UpSampling1D函数也是最简单简单重复采样函数。它们的实现的效果如下:
有疑问的小伙伴可以私聊我奥,冲啊一起变强!
留言:“半山腰太挤了,我想去山顶看看”