处理 yolov5-seg 的实例分割,主要是计算mask信息,进行 output1和output2的矩阵相乘,然后对结果矩阵取sigmod以及二值化。可视化如下:
和目标检测一样,主要是模型输出有两个矩阵:
1 * 25200 * 117:这个就是25200个框的信息,每个框用117位的浮点数组。其中1是边框置信度、2~5是 xywh 也就是中心点坐标和宽高、5~85 就是80个类别的概率,85~117就是32位的 mask信息。
1 * 32 * 160 * 160:这就是在原始模型新加的proto小型CNN,和前面 32 mask信息做矩阵乘法就得到轮廓信息
和目标检测的区别在于,需要将数据1的后32位,乘数据2的32*160*160矩阵来得到目标的 mask 信息。参考 yolov5 源码下 general.py 中的 process_mask 函数,流程如下:
将 output1 的 32*160*160 也就是 float[][][] 转换成二维矩阵 32*25600 的 float[][],相当于把160*160展平,得到矩阵1,是 32行、25600列,
将 output2 的每个检测框的 117 取最后32位,转为一维矩阵 float[],得到矩阵2,是1行,32列。
然后将矩阵2和矩阵1乘法,1*32 乘 32*25600 ,得到矩阵3,是 1*25600。
将矩阵3再转换为二位矩阵 160*160,再进行插值扩展到原始图像大小得到矩阵4。
再对矩阵4每个元素进行sigmod以及二值化到0/2。根据目标检测框标注、就可以看到掩膜信息。
后续再对截取的掩膜信息在原图上面标注轮廓信息即可。下面是参考 yolov5 源码下 general.py 中的 process_mask 函数的 java 版本,实现掩膜可视化:
// 参考 general.py 中的 process_mask 函数
public static void generateMaskInfo(List<Detection> detections,float[][][][] proto,int width,int height){
// 32 * 160 * 160 这个是mask原型 c h w
float[][][] maskSrc = proto[0];
// 转为二维矩阵也就是 32 * 25600,也就是 32 行 25600 列,相当于把 160*160展平
float[][] flattenedData = floatArray2floatArray(maskSrc);
// 再转为矩阵
RealMatrix m1 = MatrixUtils.createRealMatrix(floatArray2doubleArray(flattenedData));
// 每个目标框
detections.stream().forEach(detection -> {
// 32 这个是mask 掩膜系数,也就是权重,转为矩阵
float[] maskWeight = detection.getMaskWeight();
// 作为一个行向量存储在m1中,也就是 1 行 32 列
RealMatrix m2 = MatrixUtils.createRowRealMatrix(floatArray2doubleArray(maskWeight));
// 矩阵乘法 1*32 乘 32*25600 得到 1*25600
RealMatrix m3 = m2.multiply(m1);
// 再将 1*25600 转回 160*160 也就是一个缩小的掩膜图
RealMatrix m4 = transfer_25600_To_160_160(m3);
// 对每个元素求sigmod限制到0~1,后续根据阈值进行二值化
RealMatrix m5 = getSigmod(m4);
// 将160*160上采样到图片原始尺寸
RealMatrix m6 = resizeRealMatrix(m5,height,width);
// 目标在原始图片上的xyxy
showMatrixWithBox(m6,detection.getSrcXYXY()[0],detection.getSrcXYXY()[1],detection.getSrcXYXY()[2],detection.getSrcXYXY()[3]);
});
}
public static float[][] floatArray2floatArray(float[][][] data){
float[][] flattenedData = new float[data.length][data[0].length * data[0][0].length];
for (int i = 0; i < data.length; i++) {
float[][] slice = data[i];
for (int j = 0; j < slice.length; j++) {
System.arraycopy(slice[j], 0, flattenedData[i], j * slice[j].length, slice[j].length);
}
}
return flattenedData;
}
public static double[] floatArray2doubleArray(float[] data){
double[] maskDouble = new double[data.length];
for (int j = 0; j < data.length; j++) {
maskDouble[j] = (double) data[j];
}
return maskDouble;
}
public static double[][] floatArray2doubleArray(float[][] data){
double[][] maskDouble = new double[data.length][data[0].length];
for (int i = 0; i < data.length; i++) {
for(int j=0; j<data[0].length;j++){
maskDouble[i][j] = data[i][j];
}
}
return maskDouble;
}
// 再将 1*25600 转回 160*160
public static RealMatrix transfer_25600_To_160_160(RealMatrix data){
RealMatrix res = new Array2DRowRealMatrix(160, 160);
for (int i = 0; i < 160; i++) {
for (int j = 0; j < 160; j++) {
int index = i * 160 + j;
double value = data.getEntry(0, index);
res.setEntry(i, j, value);
}
}
return res;
}
public static RealMatrix resizeRealMatrix(RealMatrix matrix, int newRows, int newCols) {
int rows = matrix.getRowDimension();
int cols = matrix.getColumnDimension();
RealMatrix resizedMatrix = MatrixUtils.createRealMatrix(newRows, newCols);
for (int i = 0; i < newRows; i++) {
for (int j = 0; j < newCols; j++) {
int origI = (int) Math.floor(i * rows / newRows);
int origJ = (int) Math.floor(j * cols / newCols);
double d = matrix.getEntry(origI, origJ);
// if(d>=maskThreshold){
// d = 1;
// }else{
// d = 0;
// }
resizedMatrix.setEntry(i, j, d);
}
}
return resizedMatrix;
}
// 弹窗显示一个 showMatrix 并画框
public static void showMatrixWithBox(RealMatrix matrix,float xmin,float ymin,float xmax,float ymax){
// 转换 RealMatrix to BufferedImage
int numRows = matrix.getRowDimension();
int numCols = matrix.getColumnDimension();
BufferedImage image = new BufferedImage(numCols, numRows, BufferedImage.TYPE_INT_RGB);
for (int i = 0; i < numRows; i++) {
for (int j = 0; j < numCols; j++) {
double value = matrix.getEntry(i, j);
int grayValue = (int) Math.round(value * 255.0);
grayValue = Math.min(grayValue, 255);
grayValue = Math.max(grayValue, 0);
int pixelValue = (grayValue << 16) | (grayValue << 8) | grayValue;
image.setRGB(j, i, pixelValue);
}
}
// image 上画框
Graphics2D graph = image.createGraphics();
graph.setStroke(new BasicStroke(3));// 线粗细
graph.setColor(Color.RED);
// 画矩形
graph.drawRect(
Float.valueOf(xmin).intValue(),
Float.valueOf(ymin).intValue(),
Float.valueOf(xmax-xmin).intValue(),
Float.valueOf(ymax-ymin).intValue());
// 提交画框
graph.dispose();
// 弹窗显示
JFrame frame = new JFrame("Image Dialog");
frame.setSize(image.getWidth(), image.getHeight());
JLabel label = new JLabel(new ImageIcon(image));
frame.getContentPane().add(label);
frame.setVisible(true);
frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE)
}