MNIST 数据集可视化示例

每张 MNIST 手写数字图片包含 28×28 个灰度像素。使用下面的例子,可以加载训练数据集或测试数据集,并可选择查看某张图片对应的像素数据或图像。

代码如下:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Created on Thu Jan  4 16:24:07 2018
@author: wangd
"""
from tkinter import Tk, BOTH, RIGHT, RAISED,Text,X,SW,LEFT,INSERT,BOTTOM
from tkinter.ttk import Frame, Button, Style, Label, Entry
from tkinter import Listbox, StringVar, END
from tkinter import messagebox as msgbox

from MyMnistData import MyMnistData
import matplotlib.pyplot as plt

#Tk class is used to create root windows.
#frame is a container for other widgets.

class MnistView(Frame):  # inherit from ttk Frame
    """Tk window for browsing MNIST images and their labels.

    The user picks 'Test' or 'Train' in the listbox and presses Load to read
    the data through MyMnistData. An image can then be shown as hex pixel
    values in two Text widgets, or as matplotlib figures: a single image, or a
    grid of images that share one label.
    """

    # Folder handed to MyMnistData.setPath(); override on the class or an
    # instance to read the data from elsewhere. Raw string: "\d" and "\M"
    # are invalid escape sequences in a normal string literal.
    dataPath = r"d:\data\MNIST"
    # Every image is reshaped to imgSide x imgSide before it is displayed.
    imgSide = 28
    # displayImgByLabel draws at most gridSide x gridSide images.
    gridSide = 5

    def __init__(self):
        # No master is passed, so tkinter attaches the frame to the default root.
        super().__init__()

        self.summary = ''  # summary text shown in the "Summary" box
        self.imageNum = 0  # the total number of images
        self.imgRows = 0  # the number of rows in each image
        self.imgCols = 0  # the number of columns in each image
        self.imgPoint = 0  # the number of points in each image
        self.imgs = None  # matrix used to store images
        self.labels = None  # vector used to store labels

        # StringVar's first positional parameter is the master widget, so the
        # initial value must be given by keyword. StringVar('') would use the
        # string '' as master, which fails on newer Python versions.
        self.imageNumInput = StringVar(value='')
        self.labelNumInput = StringVar(value='')
        self.dataSetType = StringVar(value='')
        self.currentImgNum = 0
        self.currentLb = 0

        self.imgArea = None
        self.imgArea2 = None
        self.ctrFrame = None
        self.infoFrame = None
        self.infoTxt = None

        self.initUI()

    def initUI(self):
        """Apply the ttk theme, set the window title and build every widget."""
        self.style = Style()
        self.style.theme_use("default")  # ttk widgets support theming

        # the master attribute gives access to the root window
        self.master.title("Zhouwa MNIST Data View")

        self.pack(fill=BOTH, expand=1)

        self.addFrame()
        self.addAbutton()
        self.addControlInfo()

    def addFrame(self):
        """Create the header, control, summary and data areas of the window."""
        # header frame with copyright information
        self.mainFrame = Frame(self, relief=RAISED, borderwidth=1)
        self.mainFrame.pack(fill=X, expand=False)
        lbl0 = Label(self.mainFrame,
                     text="This is a tool used to explore MNIST "
                          "data, by Daoyi Wang, 2018-10-18",
                     width=600)
        lbl0.pack(side=LEFT, padx=5, pady=5)

        # control area: buttons and inputs are added by addAbutton/addControlInfo
        ctrFrame = Frame(self, height=10)
        ctrFrame.pack(fill=X)
        lbl1 = Label(ctrFrame, text="Control Area", width=10)
        lbl1.pack(side=LEFT, padx=5, pady=5)
        self.ctrFrame = ctrFrame

        # summary area for general information about the loaded data set
        infoFrame = Frame(self, height=10)
        infoFrame.pack(fill=BOTH, expand=True)
        lbl2 = Label(infoFrame, text="Summary", width=10)
        lbl2.pack(side=LEFT, padx=5, pady=5)
        self.infoFrame = infoFrame

        infoTxt = Text(infoFrame, height=5)
        infoTxt.pack(side=LEFT, fill=X, pady=5, padx=5, expand=True)
        self.infoTxt = infoTxt

        # data area: the two text renderings of the current image
        self.dataFrame = Frame(self, relief=RAISED, borderwidth=1, height=80)
        self.dataFrame.pack(fill=BOTH, expand=True)

        lbl3 = Label(self.dataFrame, text="Data Area", width=10)
        lbl3.pack(side=LEFT, padx=5, pady=5)

        viewImgBT = Button(self.dataFrame, text="ViewImg",
                           command=self.onViewSingleImg)
        viewImgBT.pack(side=BOTTOM, padx=5, pady=5, anchor=SW)

        area = Text(self.dataFrame, width=50, height=70)
        area.pack(side=LEFT, padx=5, pady=5)
        self.imgArea = area
        self.imgArea.insert(END, "The image will be displayed in this area!")

        area2 = Text(self.dataFrame, width=50, height=70)
        area2.pack(side=RIGHT, padx=5, pady=5)
        self.imgArea2 = area2
        self.imgArea2.insert(END, "The simple image will be displayed in this area!")

    def addListbox(self):
        """Add the 'Test'/'Train' data set chooser to the control area."""
        dataSetType = ['Test', 'Train']

        # exportselection=False keeps the chosen entry highlighted when text is
        # selected elsewhere (e.g. in the Entry widgets).
        lb = Listbox(self.ctrFrame, width=8, height=2, exportselection=False)

        for name in dataSetType:
            lb.insert(END, name)
        lb.bind("<<ListboxSelect>>", self.onSelect)
        lb.pack(side=LEFT, padx=5, pady=5)

    def onSelect(self, val):
        """Store the data set name picked in the listbox in self.dataSetType.

        val is the <<ListboxSelect>> virtual event; val.widget is the listbox.
        """
        sender = val.widget
        idx = sender.curselection()  # tuple of selected indices
        if not idx:
            # The event also fires when the selection is cleared. Keep the
            # previous choice instead of calling get(()), which raises.
            return
        self.dataSetType.set(sender.get(idx[0]))

    def addAbutton(self):
        """Add the listbox and the Load/View/Next/By Label buttons."""
        self.addListbox()
        loadButton = Button(self.ctrFrame, text="Load", command=self.onLoad)
        loadButton.pack(side=LEFT, padx=10, pady=5)

        viewImgBT = Button(self.ctrFrame, text="View", command=self.onViewImgBT)
        viewImgBT.pack(side=LEFT, padx=5, pady=5)

        nextImgBT = Button(self.ctrFrame, text="Next", command=self.onNextImgBT)
        nextImgBT.pack(side=LEFT, padx=5, pady=5)

        byLabelImgBT = Button(self.ctrFrame, text="By Label",
                              command=self.onByLabelImgBT)
        byLabelImgBT.pack(side=LEFT, padx=5, pady=5)

    def addControlInfo(self):
        """Add the image-number and label-number entry fields."""
        lbl1 = Label(self.ctrFrame, text="ImgNum", width=8)
        lbl1.pack(side=LEFT, padx=10, pady=2)
        entry1 = Entry(self.ctrFrame, textvariable=self.imageNumInput, width=8)
        entry1.pack(side=LEFT)

        lbl2 = Label(self.ctrFrame, text="LabelNum", width=8)
        lbl2.pack(side=LEFT, padx=2, pady=2)
        entry2 = Entry(self.ctrFrame, textvariable=self.labelNumInput, width=4)
        entry2.pack(side=LEFT)

    def setimgPoints(self):
        """Recompute the number of pixels per image from rows and columns."""
        self.imgPoint = self.imgRows * self.imgCols

    def updateSummary(self):
        """Replace the summary box text with self.summary plus the pixel count."""
        self.infoTxt.delete('1.0', END)
        self.infoTxt.insert(INSERT, self.summary)
        tempStr = "\nthe number of points in each image is: " + str(self.imgPoint)
        self.infoTxt.insert(INSERT, tempStr)

    def initDataSet(self):
        """Forget any loaded data set and reset the summary and image cursor."""
        self.imageNum = 0
        self.imgRows = 0
        self.imgCols = 0
        self.imgs = None
        self.labels = None
        # An index valid for the old data set may be out of range for the new one.
        self.currentImgNum = 0
        self.currentLb = 0
        self.setimgPoints()
        self.summary = ''
        self.updateSummary()

    def onLoad(self):
        """Load the data set chosen in the listbox ('Train', otherwise test)."""
        myMnistData = MyMnistData()
        myMnistData.setPath(self.dataPath)
        print(self.dataSetType.get())
        self.initDataSet()

        if self.dataSetType.get() == 'Train':
            myMnistData.loadTrainData()
            summary = myMnistData.displaySumInfoOfTrainData()
            self.imageNum = myMnistData.trainImgNum
            self.imgRows = myMnistData.trainImgRows
            self.imgCols = myMnistData.trainImgCols
            self.imgs = myMnistData.trainImages
            self.labels = myMnistData.trainLabels
        else:
            # anything other than 'Train' (including no selection) loads test data
            myMnistData.loadTestData()
            summary = myMnistData.displaySumInfoOfTestData()
            self.imageNum = myMnistData.testImgNum
            self.imgRows = myMnistData.testImgRows
            self.imgCols = myMnistData.testImgCols
            self.imgs = myMnistData.testImages
            self.labels = myMnistData.testLabels

        self.setimgPoints()
        self.summary = summary
        self.updateSummary()

    def _hasData(self):
        """Return True if a data set is loaded; otherwise inform the user."""
        if self.imgs is None or self.labels is None:
            msgbox.showinfo('', "please load a data set first!")
            return False
        return True

    def _checkImgNum(self, num):
        """Return True if num indexes a loaded image; otherwise inform the user."""
        if not self._hasData():
            return False
        if not 0 <= num < len(self.imgs):
            msgbox.showinfo('', "image number must be between 0 and %d!"
                            % (len(self.imgs) - 1))
            return False
        return True

    def onViewImgBT(self):
        """Show the image whose number was typed into the ImgNum entry."""
        try:
            num = int(self.imageNumInput.get().strip())
        except ValueError:
            msgbox.showinfo('', "image number is not a integer!")
            return
        if self._checkImgNum(num):
            self.setCurrentImgNum(num)
            self.displayImage()

    def onNextImgBT(self):
        """Advance to the next image number and show that image."""
        num = self.getCurrentImgNum() + 1
        if not self._checkImgNum(num):
            return
        self.setCurrentImgNum(num)
        self.imageNumInput.set(str(num))
        self.displayImage()

    def onByLabelImgBT(self):
        """Plot a grid of images whose label was typed into LabelNum."""
        try:
            num = int(self.labelNumInput.get().strip())
        except ValueError:
            msgbox.showinfo('', "label number is not a integer!")
            return
        if self._hasData():
            self.setCurrentLb(num)
            self.displayImgByLabel()

    def _currentImageMatrix(self):
        """Return the current image reshaped to imgSide x imgSide."""
        return self.imgs[self.getCurrentImgNum()].reshape(self.imgSide, self.imgSide)

    def _renderImage(self, area, blankZeros):
        """Write the current image into Text widget area as hex pixel values.

        Each pixel becomes two hex digits; with blankZeros, '00' pixels are
        written as two spaces so the digit's shape stands out.
        """
        area.delete('1.0', END)
        area.tag_config('link', foreground='red', font=('Courier', 7, 'bold'))
        for row in self._currentImageMatrix():
            cells = [self.getColorStr(pixel) for pixel in row]
            if blankZeros:
                cells = ['  ' if cell == '00' else cell for cell in cells]
            area.insert(INSERT, '\n  ')
            area.insert(INSERT, ''.join(cells), 'link')

    def displayImage(self):
        """Render the current image in both text areas (full and simplified)."""
        self._renderImage(self.imgArea, blankZeros=False)
        self.displaySimpleImage()

    def displaySimpleImage(self):
        """Render the current image in the second area with zero pixels blank."""
        self._renderImage(self.imgArea2, blankZeros=True)

    def onViewSingleImg(self):
        """Plot the current image in its own matplotlib figure."""
        if not self._checkImgNum(self.getCurrentImgNum()):
            return
        imgData = self._currentImageMatrix()
        plt.figure("the image for lable %s" % (self.getCurrentLb()))
        ax = plt.subplot(111)
        ax.imshow(imgData, cmap='Greys', interpolation='nearest')
        # Without show() the figure may never appear outside an interactive
        # session; displayImgByLabel does the same.
        plt.show()

    def displayImgByLabel(self):
        """Plot up to gridSide x gridSide images carrying the current label."""
        labelNum = self.getCurrentLb()
        # Boolean-mask selection computed once, not once per subplot.
        # Assumes imgs and labels are numpy arrays of equal length; TODO confirm
        # against MyMnistData.
        matches = self.imgs[self.labels == labelNum]
        if len(matches) == 0:
            msgbox.showinfo('', "no image has label %d!" % labelNum)
            return

        fig = plt.figure("All the image with label %d" % (labelNum))
        # plt.figure returns an already-open figure with the same title; clear
        # it so the new axes are not drawn on top of the old ones.
        fig.clf()
        ax = fig.subplots(
                nrows=self.gridSide,
                ncols=self.gridSide,
                sharex=True,
                sharey=True,
                squeeze=False)
        ax = ax.flatten()

        for i, axis in enumerate(ax):
            if i < len(matches):
                img = matches[i].reshape(self.imgSide, self.imgSide)
                axis.imshow(img, cmap='Greys', interpolation='nearest')
            else:
                axis.axis('off')  # fewer matching images than grid cells
        ax[0].set_xticks([])
        ax[0].set_yticks([])
        plt.tight_layout()
        plt.show()

    def getColorStr(self, color):
        """Return a non-negative pixel value as at least two lowercase hex digits.

        Example: 10 -> '0a', 255 -> 'ff'.
        """
        return format(int(color), '02x')

    def setCurrentImgNum(self, num):
        """Make image num current and set the current label to its label."""
        self.currentImgNum = num
        self.setCurrentLb(self.labels[num])

    def getCurrentImgNum(self):
        """Return the index of the current image."""
        return self.currentImgNum

    def setCurrentLb(self, label):
        """Set the current label."""
        self.currentLb = label

    def getCurrentLb(self):
        """Return the current label."""
        return self.currentLb

def centerWindow(root, w=1140, h=850):
    """Size the root window to w x h pixels and centre it on the screen.

    The offsets are clamped at 0, so on a screen smaller than the window the
    top-left corner (and title bar) stays visible instead of going off-screen.
    The window title is also set to the geometry string; callers may
    overwrite it afterwards.
    """
    sw = root.winfo_screenwidth()
    sh = root.winfo_screenheight()
    x = max(0, (sw - w) // 2)
    y = max(0, (sh - h) // 2)

    s = '%dx%d+%d+%d' % (w, h, x, y)
    root.title(s)
    root.geometry(s)

def main():
    """Create the Tk root, centre it, attach the MNIST viewer and run Tk.

    The root window must exist before MnistView is built: the frame is created
    without an explicit master, so tkinter attaches it to the default root.
    """
    window = Tk()
    centerWindow(window)
    _viewer = MnistView()  # builds and packs all widgets into the root

    # Hand control to Tk: from here on it receives events and dispatches
    # them to the application's widgets until the window is closed.
    window.mainloop()


if __name__ == '__main__':
    main()
    


运行效果如下:(显示第7个测试样本)

点击 ViewImg之后的效果如下:

 

 

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

道道1972

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值