每张 MNIST 手写数字图片包含 28×28 个灰度像素。使用下面的例子，可以加载训练数据集或测试数据集，并可选择查看某张图片对应的像素数据或图像。
代码如下（依赖一个未在此列出的自定义模块 MyMnistData，运行前请按实际情况修改代码中的数据路径）：
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Created on Thu Jan 4 16:24:07 2018
@author: wangd
"""
from tkinter import Tk, BOTH, RIGHT, RAISED,Text,X,SW,LEFT,INSERT,BOTTOM
from tkinter.ttk import Frame, Button, Style, Label, Entry
from tkinter import Listbox, StringVar, END
from tkinter import messagebox as msgbox
from MyMnistData import MyMnistData
import matplotlib.pyplot as plt
# Tk creates the root window; a Frame is a container for other widgets.
class MnistView(Frame):  # the application itself is a ttk Frame
    """Tkinter tool for exploring the MNIST training or test set.

    Layout, top to bottom:
      - a banner;
      - a control row (data-set selector, the Load / View / Next / By Label
        buttons, and image-number and label-number entries);
      - a summary text box;
      - a data area with two text views of the current image's pixels and a
        "ViewImg" button that plots it with matplotlib.
    """

    # Directory passed to MyMnistData.setPath(); override on the class or on
    # an instance to read the data from elsewhere. It is a raw string because
    # the former "d:\data\MNIST" relied on invalid escape sequences, which
    # warn on Python 3.6+.
    dataPath = r"d:\data\MNIST"

    # Image side length assumed when the data set reports no image size.
    _defaultSide = 28

    # "By Label" shows at most this grid of example images.
    _gridRows = 5
    _gridCols = 5

    def __init__(self):
        super().__init__()  # no master given: attach to the default Tk root
        self.summary = ''  # summary text returned by MyMnistData
        self.imageNum = 0  # total number of images
        self.imgRows = 0  # number of rows in each image
        self.imgCols = 0  # number of columns in each image
        self.imgPoint = 0  # number of pixels in each image
        # Presumably NumPy arrays: one flattened image per entry of imgs, and
        # a label vector (compared element-wise in displayImgByLabel).
        # NOTE(review): confirm against MyMnistData.
        self.imgs = None
        self.labels = None
        # StringVar('') passed '' as the *master* widget rather than as the
        # initial value. It only worked where tkinter treats a falsy master
        # as "use the default root", so the value is now passed by keyword.
        self.imageNumInput = StringVar(value='')
        self.labelNumInput = StringVar(value='')
        self.dataSetType = StringVar(value='')  # 'Test'/'Train' from the listbox
        self.currentImgNum = 0
        self.currentLb = 0
        self.imgArea = None
        self.ctrFrame = None
        self.infoFrame = None
        self.infoTxt = None
        self.initUI()

    def initUI(self):
        """Apply the ttk theme, set the window title and build all widgets."""
        self.style = Style()
        self.style.theme_use("default")  # ttk widgets support theming
        # self.master is the root window that hosts this frame.
        self.master.title("Zhouwa MNIST Data View")
        self.pack(fill=BOTH, expand=1)
        self.addFrame()
        self.addAbutton()
        self.addControlInfo()

    def addFrame(self):
        """Create the banner, control, summary and data areas."""
        # Banner with the copyright information. Implicit concatenation keeps
        # source indentation out of the text; a backslash-newline inside the
        # literal would leak that indentation into the label.
        self.mainFrame = Frame(self, relief=RAISED, borderwidth=1)
        self.mainFrame.pack(fill=X, expand=False)
        banner = Label(self.mainFrame,
                       text="This is a tool used to explore MNIST "
                            "data, by Daoyi Wang, 2018-10-18",
                       width=600)
        banner.pack(side=LEFT, padx=5, pady=5)
        # Control row; addAbutton() and addControlInfo() add its widgets.
        ctrFrame = Frame(self, height=10)
        ctrFrame.pack(fill=X)
        Label(ctrFrame, text="Control Area", width=10).pack(side=LEFT, padx=5, pady=5)
        self.ctrFrame = ctrFrame
        # Summary of the loaded data set.
        infoFrame = Frame(self, height=10)
        infoFrame.pack(fill=BOTH, expand=True)
        Label(infoFrame, text="Summary", width=10).pack(side=LEFT, padx=5, pady=5)
        self.infoFrame = infoFrame
        infoTxt = Text(infoFrame, height=5)
        infoTxt.pack(side=LEFT, fill=X, pady=5, padx=5, expand=True)
        self.infoTxt = infoTxt
        # Data area: full hex view (left), simplified view (right), and a
        # button that plots the current image.
        self.dataFrame = Frame(self, relief=RAISED, borderwidth=1, height=80)
        self.dataFrame.pack(fill=BOTH, expand=True)
        Label(self.dataFrame, text="Data Area", width=10).pack(side=LEFT, padx=5, pady=5)
        viewImgBT = Button(self.dataFrame, text="ViewImg", command=self.onViewSingleImg)
        viewImgBT.pack(side=BOTTOM, padx=5, pady=5, anchor=SW)
        self.imgArea = Text(self.dataFrame, width=50, height=70)
        self.imgArea.pack(side=LEFT, padx=5, pady=5)
        self.imgArea.insert(END, "The image will be displayed in this area!")
        self.imgArea2 = Text(self.dataFrame, width=50, height=70)
        self.imgArea2.pack(side=RIGHT, padx=5, pady=5)
        self.imgArea2.insert(END, "The simple image will be displayed in this area!")

    def addListbox(self):
        """Add the Test/Train data-set selector to the control row."""
        lb = Listbox(self.ctrFrame, width=8, height=2)
        for name in ['Test', 'Train']:
            lb.insert(END, name)
        lb.bind("<<ListboxSelect>>", self.onSelect)
        lb.pack(side=LEFT, padx=5, pady=5)

    def onSelect(self, val):
        """Store the chosen data-set name ('Test' or 'Train') in dataSetType.

        val is the <<ListboxSelect>> event; val.widget is the listbox.
        """
        sender = val.widget
        idx = sender.curselection()  # tuple of selected indices
        # The event also fires when the listbox loses its selection, leaving
        # curselection() empty; keep the previous choice in that case.
        if not idx:
            return
        self.dataSetType.set(sender.get(idx[0]))

    def addAbutton(self):
        """Add the data-set selector and the Load/View/Next/By Label buttons."""
        self.addListbox()
        loadButton = Button(self.ctrFrame, text="Load", command=self.onLoad)
        loadButton.pack(side=LEFT, padx=10, pady=5)
        viewImgBT = Button(self.ctrFrame, text="View", command=self.onViewImgBT)
        viewImgBT.pack(side=LEFT, padx=5, pady=5)
        nextImgBT = Button(self.ctrFrame, text="Next", command=self.onNextImgBT)
        nextImgBT.pack(side=LEFT, padx=5, pady=5)
        byLabelImgBT = Button(self.ctrFrame, text="By Label", command=self.onByLabelImgBT)
        byLabelImgBT.pack(side=LEFT, padx=5, pady=5)

    def addControlInfo(self):
        """Add the image-number and label-number entry fields."""
        lbl1 = Label(self.ctrFrame, text="ImgNum", width=8)
        lbl1.pack(side=LEFT, padx=10, pady=2)
        entry1 = Entry(self.ctrFrame, textvariable=self.imageNumInput, width=8)
        entry1.pack(side=LEFT)
        lbl2 = Label(self.ctrFrame, text="LabelNum", width=8)
        lbl2.pack(side=LEFT, padx=2, pady=2)
        entry2 = Entry(self.ctrFrame, textvariable=self.labelNumInput, width=4)
        entry2.pack(side=LEFT)

    def setimgPoints(self):
        """Recompute imgPoint (pixels per image) from imgRows and imgCols."""
        self.imgPoint = self.imgRows * self.imgCols

    def updateSummary(self):
        """Replace the summary box with self.summary plus the pixel count."""
        self.infoTxt.delete('1.0', END)
        self.infoTxt.insert(INSERT, self.summary)
        self.infoTxt.insert(INSERT, "\nthe number of points in each image is: "
                            + str(self.imgPoint))

    def initDataSet(self):
        """Forget any loaded data set and clear the summary."""
        self.imageNum = 0
        self.imgRows = 0
        self.imgCols = 0
        self.imgs = None
        self.labels = None
        self.setimgPoints()
        self.summary = ''
        self.updateSummary()

    def onLoad(self):
        """Load the data set chosen in the listbox and show its summary.

        Anything other than 'Train' (including no selection) loads the test
        set.
        """
        myMnistData = MyMnistData()
        myMnistData.setPath(self.dataPath)
        print(self.dataSetType.get())
        self.initDataSet()
        if self.dataSetType.get() == 'Train':
            myMnistData.loadTrainData()
            summary = myMnistData.displaySumInfoOfTrainData()
            loaded = (myMnistData.trainImgNum, myMnistData.trainImgRows,
                      myMnistData.trainImgCols, myMnistData.trainImages,
                      myMnistData.trainLabels)
        else:
            myMnistData.loadTestData()
            summary = myMnistData.displaySumInfoOfTestData()
            loaded = (myMnistData.testImgNum, myMnistData.testImgRows,
                      myMnistData.testImgCols, myMnistData.testImages,
                      myMnistData.testLabels)
        (self.imageNum, self.imgRows, self.imgCols,
         self.imgs, self.labels) = loaded
        self.setimgPoints()
        self.summary = summary
        self.updateSummary()

    def onViewImgBT(self):
        """Show the image whose number is typed in the ImgNum entry."""
        try:
            num = int(self.imageNumInput.get().strip())
        except ValueError:
            msgbox.showinfo('', "image number is not a integer!")
            return
        if self._checkImgNum(num):
            self.setCurrentImgNum(num)
            self.displayImage()

    def onNextImgBT(self):
        """Advance to the next image and show it."""
        num = self.getCurrentImgNum() + 1
        if self._checkImgNum(num):
            self.setCurrentImgNum(num)
            self.imageNumInput.set(str(num))
            self.displayImage()

    def onByLabelImgBT(self):
        """Plot example images carrying the label typed in the LabelNum entry."""
        try:
            num = int(self.labelNumInput.get().strip())
        except ValueError:
            msgbox.showinfo('', "label number is not a integer!")
            return
        self.setCurrentLb(num)
        self.displayImgByLabel()

    def _hasData(self):
        """Return True once a data set has been loaded."""
        return self.imgs is not None and self.labels is not None

    def _isValidImgNum(self, num):
        """Return True if num indexes an image of the loaded data set.

        Negative numbers are rejected so they do not silently wrap around.
        """
        if not self._hasData():
            return False
        return 0 <= num < min(len(self.imgs), len(self.labels))

    def _checkImgNum(self, num):
        """Like _isValidImgNum, but tell the user why num is unusable."""
        if not self._hasData():
            msgbox.showinfo('', "please load a data set first!")
            return False
        if not self._isValidImgNum(num):
            msgbox.showinfo('', "image number must be between 0 and %d!"
                            % (min(len(self.imgs), len(self.labels)) - 1))
            return False
        return True

    def _imageShape(self):
        """Return (rows, cols) of one image, defaulting to 28 x 28 if unknown."""
        if self.imgRows > 0 and self.imgCols > 0:
            return (self.imgRows, self.imgCols)
        return (self._defaultSide, self._defaultSide)

    def _currentImageMatrix(self):
        """Return the current image reshaped to its 2-D pixel grid."""
        return self.imgs[self.getCurrentImgNum()].reshape(self._imageShape())

    def _renderHex(self, area, blankZeros):
        """Write the current image into text widget area as hex pixel values.

        Each pixel row becomes one text line. When blankZeros is true, '00'
        pixels are drawn as two spaces so the columns stay aligned.
        """
        area.delete('1.0', END)
        area.tag_config('link', foreground='red', font=('Courier', 7, 'bold'))
        for row in self._currentImageMatrix():
            cells = [self.getColorStr(pixel) for pixel in row]
            if blankZeros:
                cells = ['  ' if cell == '00' else cell for cell in cells]
            area.insert(INSERT, '\n ')
            area.insert(INSERT, ''.join(cells), 'link')

    def displayImage(self):
        """Show the current image in hex in both text areas."""
        self._renderHex(self.imgArea, blankZeros=False)
        self.displaySimpleImage()

    def displaySimpleImage(self):
        """Show the current image in hex with zero pixels left blank."""
        self._renderHex(self.imgArea2, blankZeros=True)

    def onViewSingleImg(self):
        """Plot the current image with matplotlib."""
        imgIndex = self.getCurrentImgNum()
        if not self._checkImgNum(imgIndex):
            return
        # Title with the image's own label: currentLb may have been set by
        # "By Label" to a label unrelated to this image.
        plt.figure("the image for lable %s" % (self.labels[imgIndex]))
        ax = plt.subplot(111)
        ax.imshow(self._currentImageMatrix(), cmap='Greys', interpolation='nearest')
        plt.show()  # without this the figure may never appear

    def displayImgByLabel(self):
        """Plot up to a 5 x 5 grid of images whose label is currentLb."""
        labelNum = self.getCurrentLb()
        if not self._hasData():
            msgbox.showinfo('', "please load a data set first!")
            return
        # Build the boolean mask once rather than once per subplot.
        matches = self.imgs[self.labels == labelNum]
        if len(matches) == 0:
            msgbox.showinfo('', "no image has label %d!" % labelNum)
            return
        fig = plt.figure("All the image with label %d" % (labelNum))
        axes = fig.subplots(nrows=self._gridRows, ncols=self._gridCols,
                            sharex=True, sharey=True).flatten()
        shape = self._imageShape()
        # zip stops at the shorter sequence, so fewer than 25 matches is fine.
        for axis, img in zip(axes, matches):
            axis.imshow(img.reshape(shape), cmap='Greys', interpolation='nearest')
        # All axes share x/y, so clearing the first one's ticks clears them all.
        axes[0].set_xticks([])
        axes[0].set_yticks([])
        plt.tight_layout()
        plt.show()

    def getColorStr(self, color):
        """Return a grey level (expected 0-255) as at least two lowercase hex digits."""
        return '%02x' % int(color)

    def setCurrentImgNum(self, num):
        """Make num the current image and adopt its label as currentLb."""
        self.currentImgNum = num
        self.setCurrentLb(self.labels[num])

    def getCurrentImgNum(self):
        return self.currentImgNum

    def setCurrentLb(self, label):
        self.currentLb = label

    def getCurrentLb(self):
        return self.currentLb
def centerWindow(root, w=1140, h=850):
    """Size root to w x h pixels and centre it on the screen.

    When the screen is smaller than the requested size, the offset is
    clamped to 0 so the window (and its title bar) stays on-screen instead
    of being placed at a negative position.

    Returns the Tk geometry string that was applied, e.g. '1140x850+390+115'.
    """
    x = max(0, (root.winfo_screenwidth() - w) // 2)
    y = max(0, (root.winfo_screenheight() - h) // 2)
    geometry = '%dx%d+%d+%d' % (w, h, x, y)
    root.geometry(geometry)
    return geometry
def main():
    """Create the root window, centre it, attach the MNIST viewer and run it."""
    window = Tk()
    centerWindow(window)
    # MnistView is created without a master, so it attaches to the Tk root
    # made just above.
    MnistView()
    # Block here, receiving events and dispatching them to the widgets.
    window.mainloop()


if __name__ == '__main__':
    main()
运行效果如下（显示第 7 个测试样本）：
点击 ViewImg 按钮之后的效果如下：