1.卷积操作
输入的 n、m 分别代表图片的行数和列数,a、b 分别代表卷积核(滑动窗口)的行数和列数(n >= a, m >= b)。这里没有使用权重,而是以步长 1 滑动窗口,在每个滑动窗口内求最大值。
import java.util.Scanner;
/**
 * Sliding-window maximum over a 2-D integer grid.
 *
 * <p>The window moves one cell at a time in both directions and no kernel
 * weights are applied: each output cell is simply the largest value found
 * inside the corresponding window.
 */
public class Main {
    /**
     * Computes the maximum of every {@code a x b} window of {@code nums},
     * sliding with a step of one cell.
     *
     * @param nums input grid; indexed as {@code nums[row][col]} for
     *             {@code 0 <= row < n} and {@code 0 <= col < m}
     * @param n    number of rows of {@code nums} to use
     * @param m    number of columns of {@code nums} to use
     * @param a    window height in rows, {@code 1 <= a <= n}
     * @param b    window width in columns, {@code 1 <= b <= m}
     * @return a grid of size {@code (n - a + 1) x (m - b + 1)} whose cell
     *         {@code [row][col]} is the maximum of the window whose top-left
     *         corner is at {@code nums[row][col]}
     * @throws IllegalArgumentException if the window is smaller than 1x1 or
     *                                  does not fit inside the n x m grid
     */
    public int[][] convolve(int[][] nums, int n, int m, int a, int b) {
        if (a < 1 || b < 1 || a > n || b > m) {
            throw new IllegalArgumentException(
                "window " + a + "x" + b + " must be at least 1x1 and fit inside the "
                    + n + "x" + m + " grid");
        }

        // Number of distinct window positions along each axis.
        int outRows = n - a + 1;
        int outCols = m - b + 1;
        int[][] maxVal = new int[outRows][outCols];

        for (int row = 0; row < outRows; row++) {
            for (int col = 0; col < outCols; col++) {
                // Seed with the smallest int (not 0) so that windows containing
                // only negative values still report their true maximum.
                int maxData = Integer.MIN_VALUE;
                // Scan the cells covered by the window anchored at (row, col).
                for (int r = row; r < row + a; r++) {
                    for (int c = col; c < col + b; c++) {
                        maxData = Math.max(maxData, nums[r][c]);
                    }
                }
                maxVal[row][col] = maxData;
            }
        }
        return maxVal;
    }

    /**
     * Reads {@code n m a b} followed by {@code n * m} grid values from standard
     * input, then prints every window maximum in row-major order, each followed
     * by a single space, all on one line without a trailing newline.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        try (Scanner in = new Scanner(System.in)) {
            int n = in.nextInt();
            int m = in.nextInt();
            int a = in.nextInt();
            int b = in.nextInt();

            int[][] nums = new int[n][m];
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < m; j++) {
                    nums[i][j] = in.nextInt();
                }
            }

            int[][] maxVal = new Main().convolve(nums, n, m, a, b);

            // Build the whole line once instead of concatenating a string per cell.
            StringBuilder out = new StringBuilder();
            for (int[] resultRow : maxVal) {
                for (int value : resultRow) {
                    out.append(value).append(' ');
                }
            }
            System.out.print(out);
        }
    }
}
2.池化操作
import java.util.Scanner;
/**
 * Max pooling over a 2-D integer grid.
 *
 * <p>Unlike the sliding-window version, pooling windows do not overlap: the
 * grid is cut into tiles and each tile contributes exactly one maximum.
 */
public class Main {
    /**
     * Splits {@code nums} into non-overlapping {@code a x b} tiles and returns
     * the maximum of each tile. Tiles on the bottom and right edges that are
     * cut short by the grid boundary are pooled as well, using only the cells
     * that exist.
     *
     * @param nums input grid; indexed as {@code nums[row][col]} for
     *             {@code 0 <= row < n} and {@code 0 <= col < m}
     * @param n    number of rows of {@code nums} to use, {@code n >= 0}
     * @param m    number of columns of {@code nums} to use, {@code m >= 0}
     * @param a    tile height in rows, {@code a >= 1}
     * @param b    tile width in columns, {@code b >= 1}
     * @return a grid of size {@code ceil(n / a) x ceil(m / b)} holding the
     *         maximum of each tile; empty along any axis whose input size is 0
     * @throws IllegalArgumentException if {@code a} or {@code b} is less than 1,
     *                                  or {@code n} or {@code m} is negative
     */
    public int[][] max_pool(int[][] nums, int n, int m, int a, int b) {
        if (a < 1 || b < 1) {
            throw new IllegalArgumentException(
                "pool size must be at least 1x1 but was " + a + "x" + b);
        }
        if (n < 0 || m < 0) {
            throw new IllegalArgumentException(
                "grid size must be non-negative but was " + n + "x" + m);
        }

        // Ceiling division so that a trailing partial tile still gets a cell.
        // The zero case is handled explicitly because (0 - 1) / a truncates
        // toward zero and would otherwise produce one spurious tile.
        int outRows = (n == 0) ? 0 : (n - 1) / a + 1;
        int outCols = (m == 0) ? 0 : (m - 1) / b + 1;
        int[][] maxVal = new int[outRows][outCols];

        // Jump from tile to tile; outRow/outCol track the tile's output cell.
        for (int row = 0, outRow = 0; row < n; row += a, outRow++) {
            int rowEnd = Math.min(row + a, n);
            for (int col = 0, outCol = 0; col < m; col += b, outCol++) {
                int colEnd = Math.min(col + b, m);
                // Seed with the smallest int (not 0) so that tiles containing
                // only negative values still report their true maximum.
                int maxData = Integer.MIN_VALUE;
                for (int r = row; r < rowEnd; r++) {
                    for (int c = col; c < colEnd; c++) {
                        maxData = Math.max(maxData, nums[r][c]);
                    }
                }
                maxVal[outRow][outCol] = maxData;
            }
        }
        return maxVal;
    }

    /**
     * Reads {@code n m a b} followed by {@code n * m} grid values from standard
     * input, then prints every tile maximum in row-major order, each followed
     * by a single space, all on one line without a trailing newline.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        try (Scanner in = new Scanner(System.in)) {
            int n = in.nextInt();
            int m = in.nextInt();
            int a = in.nextInt();
            int b = in.nextInt();

            int[][] nums = new int[n][m];
            for (int i = 0; i < n; i++) {
                for (int j = 0; j < m; j++) {
                    nums[i][j] = in.nextInt();
                }
            }

            int[][] maxVal = new Main().max_pool(nums, n, m, a, b);

            // Build the whole line once instead of concatenating a string per cell.
            StringBuilder out = new StringBuilder();
            for (int[] resultRow : maxVal) {
                for (int value : resultRow) {
                    out.append(value).append(' ');
                }
            }
            System.out.print(out);
        }
    }
}