1、可以前往Mnist数据集 简单介绍_mnist百度百科-CSDN博客进行简单的了解,并将.idx3-ubyte文件下载下来。将将要生成的训练图片放在train文件夹下面,测试图片放在test文件夹下面。
2、工具类。
public class FileUtils {
/**
* 删除指定文件夹下所有文件
*
* @param path 文件夹完整绝对路径
*/
public static void delAllFile(String path) {
File file = new File(path);
if (!file.exists() || !file.isDirectory()) {
return;
}
String[] fileStrList = file.list();
File fileItem;
for (int i = 0; i < fileStrList.length; i++) {
if (path.endsWith(File.separator)) {
fileItem = new File(path + fileStrList[i]);
} else {
fileItem = new File(path + File.separator + fileStrList[i]);
}
if (fileItem.isFile()) {
fileItem.delete();
}
if (fileItem.isDirectory()) {
delAllFile(path + "//" + fileStrList[i]);// 先删除文件夹里面的文件
delFolder(path + "//" + fileStrList[i]);// 再删除空文件夹
}
}
return;
}
/**
* 删除文件夹
*
* @param folderPath folderPath文件夹完整绝对路径
*/
public static void delFolder(String folderPath) {
try {
delAllFile(folderPath); // 删除完里面所有内容
File folder = new File(folderPath);
folder.delete(); // 删除空文件夹
} catch (Exception e) {
e.printStackTrace();
}
}
}
public class ByteUtils {
public static int getTenHex(byte[] bytes) {
int result = 0;
for (int i = 0; i < bytes.length; i++) {
int move = bytes.length - i - 1;
int value = (bytes[i] & 255) << (8 * move);
result += value;
}
return result;
}
}
3、读取图片的方法,返回一个byte[][]数组,第一个[]代表的是第几个图片,第二个[]代表的是该图片的内容。记得用BufferedInputStream来读取,因为文件比较大。如果直接使用FileInputStream来读取的话,程序会卡住的。
//获取图片
private static byte[][] getMnistImg(String filePath) {
InputStream inputStream = null;//输入流
byte[][] imgArray = null;
try {
inputStream = new BufferedInputStream(new FileInputStream(filePath));
//读取基本信息
byte[] readBytes = new byte[4];
inputStream.read(readBytes);
System.out.println("读取到的幻数:" + ByteUtils.getTenHex(readBytes));
inputStream.read(readBytes);
int imgCount = ByteUtils.getTenHex(readBytes);
inputStream.read(readBytes);
imgWidth = ByteUtils.getTenHex(readBytes);
inputStream.read(readBytes);
imgHeight = ByteUtils.getTenHex(readBytes);
int imgSize = imgWidth * imgHeight;
System.out.println(String.format("一共有%d张图片。\n" + "每张图片每行%d个像素点,每列%d个像素点。" +
"\n每张图片一共有%d个像素点", imgCount, imgWidth, imgHeight, imgSize));
//读取每张图片
imgArray = new byte[imgCount][imgSize];
for (int i = 0; i < imgCount; i++) {
byte[] imgBytes = new byte[imgSize];
for (int j = 0; j < imgSize; j++) {
imgBytes[j] = (byte) inputStream.read();
}
imgArray[i] = imgBytes;
}
} catch (Exception e) {
e.printStackTrace();
} finally {
try {
if (inputStream != null) {
inputStream.close();
}
} catch (Exception e) {
e.printStackTrace();
}
}
return imgArray;
}
4、读取图片所代表数字的方法,返回int[],其顺序是和步骤3中的图片顺序是一一对应的。
//获取图片所代表的数字
private static byte[] getMnistLable(String filePath) {
InputStream inputStream = null;//输入流
byte[] labelArray = null;
try {
inputStream = new BufferedInputStream(new FileInputStream(filePath));
//读取基本信息
byte[] readBytes = new byte[4];
inputStream.read(readBytes);
System.out.println("读取到的幻数:" + ByteUtils.getTenHex(readBytes));
inputStream.read(readBytes);
int labelCount = ByteUtils.getTenHex(readBytes);
System.out.println(String.format("一共有%d个标签", labelCount));
//读取每个标签
labelArray = new byte[labelCount];
for (int i = 0; i < labelCount; i++) {
labelArray[i] = (byte) inputStream.read();
}
} catch (Exception e) {
e.printStackTrace();
} finally {
try {
if (inputStream != null) {
inputStream.close();
}
} catch (Exception e) {
e.printStackTrace();
}
}
return labelArray;
}
5、主程序和保存图片的方法。因为我们从Mnist数据集中获取到的图片是8位灰度图像,每个像素存放在一个byte空间(8位,0-255:0表示最暗色,255表示最亮色)。8位灰度图像可以看成是一系列1位“位平面”的叠加。所以初始化BufferedImage的时候要用TYPE_INT_GRAY,只有8bit大小的存储空间,图片的存储空间也比其他如TYPE_INT_RGB类型生成的小。这样子在使用bufferedImage.setRGB的时候,按28*28大小的形状,将灰度值一个个的放进去就可以了。
private static final String Pre_Path = "G:\\xiaojie-java-test\\mnist\\";
//下载的测试集(二进制文件)。
private static final String Train_Img_Path = Pre_Path + "download\\train-images.idx3-ubyte";//训练集图像
private static final String Train_Label_Path = Pre_Path + "download\\train-labels.idx1-ubyte";//训练集标签(标签指明图像代表的意思)
//下载的测试集(二进制文件)。测试集的前5000个示例来自原始的NIST训练集。最后的5000个来自原始的NIST测试集。前5000个比后5000个更干净,更容易识别。
private static final String Test_Images_Path = Pre_Path + "download\\t10k-images.idx3-ubyte";//测试集图像
private static final String Test_Lable_Path = Pre_Path + "download\\t10k-labels.idx1-ubyte";//测试集标签(标签指明图像代表的意思)
//图片文件的保存地址
private static final String Train_Save_Path = Pre_Path + "train\\";
private static final String Test_Save_Path = Pre_Path + "test\\";
private static final int Img_Page_Count = 1000;
private static int imgWidth;
private static int imgHeight;
public static void main(String[] args) {
byte[][] mnistImg = getMnistImg(Test_Images_Path);
byte[] label = getMnistLable(Test_Lable_Path);
saveImg(mnistImg, label, Test_Save_Path);
}
//保存图片
private static void saveImg(byte[][] mnistImg, byte[] label, String savePath) {
BufferedImage bufferedImage;//输出流
try {
for (int i = 0; i < mnistImg.length; i++) {
byte[] imgArray = mnistImg[i];
//生成BufferedImage对象
bufferedImage = new BufferedImage(imgWidth, imgHeight, BufferedImage.TYPE_BYTE_GRAY);
for (int j = 0; j < imgHeight; j++) {
for (int k = 0; k < imgWidth; k++) {
//System.out.print(String.format("%4d", imgArray[j * imgWidth + k]));//可以在控制台打出图片的样子
bufferedImage.setRGB(k, j, imgArray[j * imgWidth + k]);
}
// System.out.println();//可以在控制台打出图片的样子
}
//生成文件夹 由于文件太多了,让其分成60个文件夹,每个文件夹里面有1000张图片
int pageIndex = i / Img_Page_Count + 1;
String filePath = savePath + (pageIndex < 10 ? "0" + pageIndex : pageIndex);
if (i % Img_Page_Count == 0) {
FileUtils.delAllFile(filePath);
System.out.println(String.format("准备生成第%d个有%d个图像的文件夹,", pageIndex, Img_Page_Count));
}
File file = new File(filePath);
if (!file.exists()) {
file.mkdir();
}
//生成图片
String fileName = filePath + "//" + label[i] + "_" + System.currentTimeMillis() + ".png";
ImageIO.write(bufferedImage, "PNG", new File(fileName));
Thread.sleep(1);//要休眠一下,不然有时图像会生成失败
}
} catch (Exception e) {
e.printStackTrace();
}
}
6、结果。演示的是生成测试集的图片,一共有10000个图片,生成10个文件夹,每个文件夹中有1000个图片。如果要生成训练集的图片,将main()方法中的,源文件和目标文件的路径改成相应的就可以了。
Connected to the target VM, address: '127.0.0.1:53725', transport: 'socket'
读取到的幻数:2051
一共有10000张图片。
每张图片每行28个像素点,每列28个像素点。
每张图片一共有784个像素点
读取到的幻数:2049
一共有10000个标签
准备生成第1个有1000个图像的文件夹,
准备生成第2个有1000个图像的文件夹,
准备生成第3个有1000个图像的文件夹,
准备生成第4个有1000个图像的文件夹,
准备生成第5个有1000个图像的文件夹,
准备生成第6个有1000个图像的文件夹,
准备生成第7个有1000个图像的文件夹,
准备生成第8个有1000个图像的文件夹,
准备生成第9个有1000个图像的文件夹,
准备生成第10个有1000个图像的文件夹,
Disconnected from the target VM, address: '127.0.0.1:53725', transport: 'socket'
Process finished with exit code 0