import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import tkinter as tk
from tkinter import ttk, filedialog, messagebox
from PIL import Image, ImageDraw
import cv2
import os
import csv
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import StandardScaler
# ---------------------------------------------------------------------------
# Section 1: matplotlib configuration
# ---------------------------------------------------------------------------
# Select Chinese-capable fonts for chart text and keep the minus sign
# rendering correctly when those fonts are active.
plt.rcParams["font.family"] = ["SimHei", "Microsoft YaHei"]
plt.rcParams["axes.unicode_minus"] = False

# ---------------------------------------------------------------------------
# Section 2: optional dependencies (XGBoost / LightGBM)
# ---------------------------------------------------------------------------
# Each flag flips to True only when the matching import succeeds; a missing
# library prints a warning instead of stopping the program.
XGB_INSTALLED = False
LGB_INSTALLED = False
try:
    import xgboost as xgb
    XGB_INSTALLED = True
except ImportError:
    print("警告: 未安装XGBoost库,无法使用XGBoost模型")
try:
    import lightgbm as lgb
    LGB_INSTALLED = True
except ImportError:
    print("警告: 未安装LightGBM库,无法使用LightGBM模型")

# ---------------------------------------------------------------------------
# Section 3: model registry
# ---------------------------------------------------------------------------
# Maps a short model key to a 4-tuple:
#   (display name, estimator class, scaler class or None, constructor kwargs)
# A scaler class of None means the features are passed to the model unscaled.
MODEL_METADATA = {
    'svm': ('支持向量机(SVM)', SVC, StandardScaler,
            {'probability': True, 'random_state': 42}),
    'dt': ('决策树(DT)', DecisionTreeClassifier, None,
           {'random_state': 42}),
    'rf': ('随机森林(RF)', RandomForestClassifier, None,
           {'n_estimators': 100, 'random_state': 42}),
    'mlp': ('多层感知机(MLP)', MLPClassifier, StandardScaler,
            {'hidden_layer_sizes': (100, 50), 'max_iter': 500, 'random_state': 42}),
    'knn': ('K最近邻(KNN)', KNeighborsClassifier, StandardScaler,
            {'n_neighbors': 5, 'weights': 'distance'}),
    'nb': ('高斯朴素贝叶斯(NB)', GaussianNB, None, {}),
}

# Register the optional models only when their library could be imported.
if XGB_INSTALLED:
    # 'multi:softprob' (not 'multi:softmax') because the application calls
    # predict_proba() on the selected model.
    # NOTE(review): some XGBoost releases raise on predict_proba() when the
    # objective is 'multi:softmax' -- confirm against the installed version.
    MODEL_METADATA['xgb'] = ('XGBoost(XGB)', xgb.XGBClassifier, None,
                             {'objective': 'multi:softprob', 'random_state': 42})
if LGB_INSTALLED:
    MODEL_METADATA['lgb'] = ('LightGBM(LGB)', lgb.LGBMClassifier, None, {
        'objective': 'multiclass',
        'random_state': 42,
        'num_class': 10,
        'max_depth': 5,
        'min_child_samples': 10,
        'learning_rate': 0.1,
        'force_col_wise': True,
    })
class ModelFactory:
    """Create, train and score the classifiers registered in ``MODEL_METADATA``.

    Every method is a static or class method, so the class is used directly
    (``ModelFactory.create_model("svm")``) and never instantiated.
    """

    # ------------------------------------------------------------------
    # Data preparation
    # ------------------------------------------------------------------
    @staticmethod
    def get_split_data(digits_dataset):
        """Split a dataset into training and test parts.

        Args:
            digits_dataset: Object exposing ``data`` (features) and ``target``
                (labels) attributes.

        Returns:
            ``(X_train, X_test, y_train, y_test)`` with 30% of the samples in
            the test part. ``random_state=42`` makes the split repeatable.
        """
        X, y = digits_dataset.data, digits_dataset.target
        return train_test_split(X, y, test_size=0.3, random_state=42)

    # ------------------------------------------------------------------
    # Model construction
    # ------------------------------------------------------------------
    @classmethod
    def create_model(cls, model_type):
        """Instantiate the estimator (and scaler, if any) for ``model_type``.

        Args:
            model_type: Key into ``MODEL_METADATA`` such as ``'svm'``.

        Returns:
            ``(model, scaler)``; ``scaler`` is ``None`` when the registry entry
            has no scaler class.

        Raises:
            ValueError: ``model_type`` is not a registry key.
            ImportError: The registry entry has no estimator class.
        """
        if model_type not in MODEL_METADATA:
            raise ValueError(f"未知模型类型: {model_type}")
        name, model_cls, scaler_cls, params = MODEL_METADATA[model_type]
        if not model_cls:
            raise ImportError(f"{name}模型依赖库未安装")
        model = model_cls(**params)
        scaler = scaler_cls() if scaler_cls else None
        return model, scaler

    # ------------------------------------------------------------------
    # Training and evaluation
    # ------------------------------------------------------------------
    @staticmethod
    def train_model(model, X_train, y_train, scaler=None, model_type=None):
        """Fit ``model`` on the training data and return it.

        When a scaler is given it is fitted on ``X_train`` first and the scaled
        features are used for training. For ``model_type == 'lgb'`` a NumPy
        feature array is wrapped in a ``pandas.DataFrame`` before fitting.
        """
        if scaler:
            X_train = scaler.fit_transform(X_train)
        if model_type == 'lgb' and isinstance(X_train, np.ndarray):
            X_train = pd.DataFrame(X_train)
        model.fit(X_train, y_train)
        return model

    @staticmethod
    def evaluate_model(model, X_test, y_test, scaler=None, model_type=None):
        """Return the accuracy of ``model`` on the test data.

        The scaler (if any) is applied with ``transform`` only, i.e. it reuses
        whatever it learned during training. For ``'lgb'`` the test features
        are wrapped in a DataFrame whose columns are ``model.feature_name_``
        so they line up with the names recorded at fit time.
        """
        if scaler:
            X_test = scaler.transform(X_test)
        if (model_type == 'lgb' and isinstance(X_test, np.ndarray)
                and hasattr(model, 'feature_name_')):
            X_test = pd.DataFrame(X_test, columns=model.feature_name_)
        y_pred = model.predict(X_test)
        return accuracy_score(y_test, y_pred)

    @classmethod
    def train_and_evaluate(cls, model_type, X_train, y_train, X_test, y_test):
        """Create, train and score one model.

        Returns:
            ``(model, scaler, accuracy)``.

        Raises:
            Exception: Any error from the three steps is printed and then
                re-raised unchanged so the caller can decide what to do.
        """
        try:
            model, scaler = cls.create_model(model_type)
            model = cls.train_model(model, X_train, y_train, scaler, model_type)
            accuracy = cls.evaluate_model(model, X_test, y_test, scaler, model_type)
            return model, scaler, accuracy
        except Exception as e:
            print(f"模型 {model_type} 训练/评估错误: {e}")
            raise

    @classmethod
    def evaluate_all_models(cls, digits_dataset):
        """Train and score every registered model.

        Returns:
            A list of ``{"模型名称": name, "准确率": text}`` dictionaries, best
            accuracy first. ``text`` is the accuracy formatted with four
            decimals, ``"N/A"`` when the entry has no estimator class, or an
            error message when training failed; those two cases sort last.
        """
        print("\n=== 模型评估 ===")
        X_train, X_test, y_train, y_test = cls.get_split_data(digits_dataset)

        # Keep the numeric score next to each display row so sorting does not
        # have to parse the formatted text back into a number.
        scored_rows = []
        for model_type, (name, model_cls, _, _) in MODEL_METADATA.items():
            print(f"评估模型: {name} ({model_type})")
            if not model_cls:
                scored_rows.append((-1.0, {"模型名称": name, "准确率": "N/A"}))
                continue
            try:
                _, _, accuracy = cls.train_and_evaluate(
                    model_type, X_train, y_train, X_test, y_test
                )
            except Exception as e:
                # One failing model must not abort the whole comparison.
                scored_rows.append((-1.0, {"模型名称": name, "准确率": f"错误: {e}"}))
            else:
                scored_rows.append(
                    (accuracy, {"模型名称": name, "准确率": f"{accuracy:.4f}"})
                )

        # Highest accuracy first; rows without a score carry -1 and sink.
        scored_rows.sort(key=lambda pair: pair[0], reverse=True)
        results = [row for _, row in scored_rows]
        print(pd.DataFrame(results))
        return results
class HandwritingBoard:
    """Tkinter window for drawing a digit and classifying it.

    The window offers a model selector, a drawing canvas, a result panel
    (prediction, confidence bar, top candidates), a model comparison table and
    buttons for collecting the user's own labelled samples.

    Args:
        root: The ``tk.Tk`` window the widgets are placed in.
        model_factory: Object providing ``get_split_data``,
            ``train_and_evaluate`` and ``evaluate_all_models``.
        digits: Dataset object with ``data``, ``target`` and ``images``.
    """

    CANVAS_SIZE = 300  # side length of the square drawing area, in pixels
    BRUSH_SIZE = 12  # stroke width, in pixels
    # Tall enough for the bar (y 10-40), the ticks (y 40-45) and the tick
    # labels drawn around y=55.
    CONFIDENCE_CANVAS_HEIGHT = 65

    # ==================================================================
    # Section 1: construction
    # ==================================================================
    def __init__(self, root, model_factory, digits):
        self.root = root
        self.root.title("手写数字识别系统")
        self.root.geometry("1000x700")
        self.model_factory = model_factory
        self.digits = digits

        # model key -> (model, scaler, accuracy, model key); avoids retraining
        # a model the user already selected once.
        self.model_cache = {}
        # Result of evaluate_all_models(); filled on first use and then reused.
        self.performance_results = None
        self.current_model = None
        self.scaler = None
        self.current_model_type = None

        self.has_drawn = False  # True once a stroke has been finished
        self.custom_data = []  # list of (64 pixel values, label) samples
        self.drawing = False  # True while the left mouse button is held
        self.last_x = self.last_y = 0

        # Directory for user data; created up front if it does not exist.
        self.data_dir = "custom_digits_data"
        os.makedirs(self.data_dir, exist_ok=True)

        # Off-screen grayscale copy of the drawing (255 = white paper). The Tk
        # canvas only displays strokes; this image is what gets recognised.
        self.image = Image.new("L", (self.CANVAS_SIZE, self.CANVAS_SIZE), 255)
        self.draw_obj = ImageDraw.Draw(self.image)

        self.create_widgets()
        self.init_default_model()

    # ==================================================================
    # Section 2: user interface layout
    # ==================================================================
    def create_widgets(self):
        """Build every widget of the window.

        Layout of the main frame (grid):
            row 0: model selector (spans both columns)
            row 1: drawing area (left) | recognition results (right)
            row 2: model comparison table (left) | training buttons (right)
        A status bar is packed at the bottom of the window.
        """
        main_frame = tk.Frame(self.root)
        main_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        self._build_model_selector(main_frame)
        self._build_drawing_area(main_frame)
        self._build_result_area(main_frame)
        self._build_performance_table(main_frame)
        self._build_training_buttons(main_frame)
        self._build_status_bar()

        # Let both columns and the two lower rows grow with the window.
        main_frame.grid_columnconfigure(0, weight=1)
        main_frame.grid_columnconfigure(1, weight=1)
        main_frame.grid_rowconfigure(1, weight=1)
        main_frame.grid_rowconfigure(2, weight=1)

    def _build_model_selector(self, parent):
        """Row 0: drop-down with the available models and an info label."""
        model_frame = tk.LabelFrame(parent, text="模型选择", font=("Arial", 10, "bold"))
        model_frame.grid(row=0, column=0, columnspan=2, sticky="ew", padx=5, pady=5)
        model_frame.grid_columnconfigure(1, weight=1)  # the combobox stretches

        tk.Label(model_frame, text="选择模型:", font=("Arial", 10)).grid(
            row=0, column=0, padx=5, pady=5, sticky="w"
        )

        # (model key, display name) for every entry that has an estimator class.
        self.available_models = [
            (model_type, name)
            for model_type, (name, model_cls, _, _) in MODEL_METADATA.items()
            if model_cls
        ]

        self.model_var = tk.StringVar()
        self.model_combobox = ttk.Combobox(
            model_frame,
            textvariable=self.model_var,
            values=[name for _, name in self.available_models],
            state="readonly",
            width=25,
            font=("Arial", 10),
        )
        self.model_combobox.current(0)
        self.model_combobox.bind("<<ComboboxSelected>>", self.on_model_select)
        self.model_combobox.grid(row=0, column=1, padx=5, pady=5, sticky="ew")

        # Shows the active model's name and accuracy.
        self.model_label = tk.Label(
            model_frame,
            text="",
            font=("Arial", 10),
            relief=tk.SUNKEN,
            padx=5,
            pady=2,
        )
        self.model_label.grid(row=0, column=2, padx=5, pady=5, sticky="ew")

    def _build_drawing_area(self, parent):
        """Row 1, left: the drawing canvas and its three buttons."""
        left_frame = tk.LabelFrame(parent, text="绘制区域", font=("Arial", 10, "bold"))
        left_frame.grid(row=1, column=0, padx=5, pady=5, sticky="nsew")

        self.canvas = tk.Canvas(
            left_frame, bg="white", width=self.CANVAS_SIZE, height=self.CANVAS_SIZE
        )
        self.canvas.pack(padx=10, pady=10)
        # press -> start a stroke, drag -> extend it, release -> finish it
        self.canvas.bind("<Button-1>", self.start_draw)
        self.canvas.bind("<B1-Motion>", self.draw)
        self.canvas.bind("<ButtonRelease-1>", self.stop_draw)
        self._draw_canvas_hint()

        btn_frame = tk.Frame(left_frame)
        btn_frame.pack(fill=tk.X, pady=(0, 10))
        for caption, handler in (
            ("识别", self.recognize),
            ("清除", self.clear_canvas),
            ("样本", self.show_samples),
        ):
            tk.Button(btn_frame, text=caption, command=handler, width=8).pack(
                side=tk.LEFT, padx=5
            )

    def _build_result_area(self, parent):
        """Row 1, right: prediction text, confidence bar, candidate table."""
        right_frame = tk.Frame(parent)
        right_frame.grid(row=1, column=1, padx=5, pady=5, sticky="nsew")

        # --- predicted digit and its confidence as text ---
        result_frame = tk.LabelFrame(right_frame, text="识别结果", font=("Arial", 10, "bold"))
        result_frame.pack(fill=tk.X, padx=5, pady=5)
        self.result_label = tk.Label(
            result_frame, text="请绘制数字", font=("Arial", 24), pady=10
        )
        self.result_label.pack()
        self.prob_label = tk.Label(result_frame, text="", font=("Arial", 12))
        self.prob_label.pack()

        # --- confidence drawn as a coloured bar ---
        confidence_frame = tk.LabelFrame(
            right_frame, text="识别置信度", font=("Arial", 10, "bold")
        )
        confidence_frame.pack(fill=tk.X, padx=5, pady=5)
        self.confidence_canvas = tk.Canvas(
            confidence_frame, bg="white", height=self.CONFIDENCE_CANVAS_HEIGHT
        )
        self.confidence_canvas.pack(fill=tk.X, padx=10, pady=10)
        self._draw_confidence_placeholder()

        # --- most probable digits ---
        candidates_frame = tk.LabelFrame(
            right_frame, text="可能的数字", font=("Arial", 10, "bold")
        )
        candidates_frame.pack(fill=tk.X, padx=5, pady=5)
        self.candidates_tree = self._make_table(
            candidates_frame, ("数字", "概率"), height=4, column_width=80
        )

    def _build_performance_table(self, parent):
        """Row 2, left: table comparing the accuracy of all models."""
        performance_frame = tk.LabelFrame(
            parent, text="模型性能对比", font=("Arial", 10, "bold")
        )
        performance_frame.grid(row=2, column=0, padx=5, pady=5, sticky="nsew")
        self.performance_tree = self._make_table(
            performance_frame, ("模型名称", "准确率"), height=8, column_width=120
        )

    def _build_training_buttons(self, parent):
        """Row 2, right: 2x2 grid of buttons for sample management."""
        train_frame = tk.LabelFrame(parent, text="训练集管理", font=("Arial", 10, "bold"))
        train_frame.grid(row=2, column=1, padx=5, pady=5, sticky="nsew")
        buttons = (
            ("保存为训练样本", self.save_as_training_sample, 0, 0),
            ("保存全部训练集", self.save_all_training_data, 0, 1),
            ("加载训练集", self.load_training_data, 1, 0),
            ("性能图表", self.show_performance_chart, 1, 1),
        )
        for caption, handler, row, column in buttons:
            tk.Button(
                train_frame, text=caption, command=handler, width=18, height=2
            ).grid(row=row, column=column, padx=5, pady=5, sticky="ew")

    def _build_status_bar(self):
        """Bottom of the window: one line of status text."""
        self.status_var = tk.StringVar(value="就绪")
        status_bar = tk.Label(
            self.root,
            textvariable=self.status_var,
            bd=1,
            relief=tk.SUNKEN,
            anchor=tk.W,
            font=("Arial", 10),
        )
        status_bar.pack(side=tk.BOTTOM, fill=tk.X)

    @staticmethod
    def _make_table(parent, columns, height, column_width):
        """Create a headings-only Treeview with a vertical scrollbar."""
        tree = ttk.Treeview(parent, columns=columns, show="headings", height=height)
        for col in columns:
            tree.heading(col, text=col)
            tree.column(col, width=column_width, anchor=tk.CENTER)
        scrollbar = ttk.Scrollbar(parent, orient=tk.VERTICAL, command=tree.yview)
        tree.configure(yscrollcommand=scrollbar.set)
        tree.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=5, pady=5)
        scrollbar.pack(side=tk.RIGHT, fill=tk.Y, padx=5, pady=5)
        return tree

    def _draw_canvas_hint(self):
        """Show the grey hint text in the middle of the drawing canvas."""
        self.canvas.create_text(
            self.CANVAS_SIZE / 2, self.CANVAS_SIZE / 2,
            text="绘制数字", fill="gray", font=("Arial", 16)
        )

    def _draw_confidence_placeholder(self):
        """Show the grey placeholder text on the confidence canvas."""
        self.confidence_canvas.create_text(
            150, 25,
            text="识别后显示置信度",
            fill="gray",
            font=("Arial", 10)
        )

    # ==================================================================
    # Section 3: drawing with the mouse
    # ==================================================================
    def start_draw(self, event):
        """Mouse pressed: remember where the stroke starts."""
        self.drawing = True
        self.last_x, self.last_y = event.x, event.y

    def draw(self, event):
        """Mouse dragged: draw a segment from the previous point to here.

        The segment is drawn twice: on the Tk canvas (what the user sees) and
        on the off-screen PIL image (what the model receives).
        """
        if not self.drawing:
            return
        x, y = event.x, event.y
        self.canvas.create_line(
            self.last_x, self.last_y, x, y,
            fill="black",
            width=self.BRUSH_SIZE,
            capstyle=tk.ROUND,
            smooth=True
        )
        self.draw_obj.line(
            [self.last_x, self.last_y, x, y],
            fill=0,
            width=self.BRUSH_SIZE
        )
        self.last_x, self.last_y = x, y

    def stop_draw(self, event):
        """Mouse released: the stroke is finished."""
        self.drawing = False
        self.has_drawn = True
        self.status_var.set("已绘制数字,点击'识别'进行识别")

    def clear_canvas(self):
        """Erase the drawing and reset the result widgets."""
        self.canvas.delete("all")
        # Start again from a blank white image.
        self.image = Image.new("L", (self.CANVAS_SIZE, self.CANVAS_SIZE), 255)
        self.draw_obj = ImageDraw.Draw(self.image)
        self._draw_canvas_hint()
        self.result_label.config(text="请绘制数字")
        self.prob_label.config(text="")
        self.clear_confidence_display()
        self.has_drawn = False
        self.status_var.set("画布已清除")

    def clear_confidence_display(self):
        """Reset the confidence bar and empty the candidate table."""
        self.confidence_canvas.delete("all")
        self._draw_confidence_placeholder()
        for item in self.candidates_tree.get_children():
            self.candidates_tree.delete(item)

    # ==================================================================
    # Section 4: image preprocessing and recognition
    # ==================================================================
    @staticmethod
    def _pad_to_square(digit):
        """Centre a 2-D uint8 crop on a square canvas filled with 0.

        0 is the background value after the inverted threshold, so the added
        border counts as empty paper rather than ink.
        """
        h, w = digit.shape
        size = max(w, h)
        padded = np.zeros((size, size), dtype=np.uint8)
        offset_x = (size - w) // 2
        offset_y = (size - h) // 2
        padded[offset_y:offset_y + h, offset_x:offset_x + w] = digit
        return padded

    @staticmethod
    def _to_digits_scale(pixels):
        """Map 0..255 intensities to 0..16 integers (fraction truncated).

        NOTE(review): 0 = blank, 16 = full ink is the scikit-learn digits
        convention per its documentation -- confirm against ``self.digits``.
        """
        return (pixels / 255 * 16).astype(np.uint8)

    def preprocess_image(self):
        """Turn the drawing into a 64-value feature vector.

        Steps: blur, threshold to white ink on black, crop to the drawn
        strokes, pad to a square, shrink to 8x8, rescale to 0..16.

        Returns:
            A flat array of 64 values, or ``None`` (with a status message)
            when nothing was found on the canvas.
        """
        img_array = np.array(self.image)
        # Gaussian blur to smooth the strokes before thresholding.
        img_array = cv2.GaussianBlur(img_array, (5, 5), 0)
        # Inverted threshold: dark ink becomes 255, white paper becomes 0.
        _, binary = cv2.threshold(img_array, 127, 255, cv2.THRESH_BINARY_INV)
        contours, _ = cv2.findContours(
            binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        if not contours:
            self.status_var.set("未检测到有效数字,请重新绘制")
            return None

        # Bounding box around ALL strokes, so a digit drawn with several
        # separate strokes is not cut down to its largest piece.
        x, y, w, h = cv2.boundingRect(np.vstack(contours))
        digit = binary[y:y + h, x:x + w]

        padded = self._pad_to_square(digit)
        resized = cv2.resize(padded, (8, 8), interpolation=cv2.INTER_AREA)
        return self._to_digits_scale(resized).flatten()

    def recognize(self):
        """Classify the current drawing and display the outcome."""
        if not self.has_drawn:
            self.status_var.set("请先绘制数字再识别")
            return
        if self.current_model is None:
            self.status_var.set("模型未加载,请选择模型")
            return

        img_array = self.preprocess_image()
        if img_array is None:
            return
        # One row of 64 features: the estimators expect a 2-D input.
        img_input = img_array.reshape(1, -1)

        try:
            # Apply the same scaling the model was trained with.
            if self.scaler:
                img_input = self.scaler.transform(img_input)
            # LightGBM: pass a DataFrame using the column names seen at fit time.
            if (self.current_model_type == 'lgb'
                    and hasattr(self.current_model, 'feature_name_')):
                img_input = pd.DataFrame(
                    img_input, columns=self.current_model.feature_name_
                )

            pred = self.current_model.predict(img_input)[0]
            self.result_label.config(text=f"识别结果: {pred}")

            if hasattr(self.current_model, 'predict_proba'):
                probs = self.current_model.predict_proba(img_input)[0]
                # assumes class i sits at index i (labels 0..9) -- TODO confirm
                confidence = probs[pred]
                self.prob_label.config(text=f"置信度: {confidence:.2%}")
                self.update_confidence_display(confidence)
                # Three most probable digits, most probable first.
                top3 = sorted(enumerate(probs), key=lambda pair: -pair[1])[:3]
                self.update_candidates_display(top3)
            else:
                self.prob_label.config(text="该模型不支持概率输出")
                self.clear_confidence_display()
            self.status_var.set(f"识别完成: 数字 {pred}")
        except Exception as e:
            # UI boundary: report the problem instead of crashing the window.
            self.status_var.set(f"识别错误: {e}")
            self.clear_confidence_display()

    # ==================================================================
    # Section 5: displaying results
    # ==================================================================
    def update_confidence_display(self, confidence):
        """Draw the confidence (0..1) as a coloured bar with a percent scale."""
        self.confidence_canvas.delete("all")
        canvas_width = self.confidence_canvas.winfo_width()
        if canvas_width <= 1:
            # Fall back to a nominal width when Tk reports no usable size.
            canvas_width = 300

        # Grey track behind the bar.
        self.confidence_canvas.create_rectangle(
            10, 10, canvas_width - 10, 40,
            fill="#f0f0f0",
            outline="#cccccc"
        )
        # Coloured bar whose length is proportional to the confidence.
        bar_width = int((canvas_width - 20) * confidence)
        self.confidence_canvas.create_rectangle(
            10, 10, 10 + bar_width, 40,
            fill=self.get_confidence_color(confidence),
            outline=""
        )
        # Percentage in the middle of the bar.
        self.confidence_canvas.create_text(
            canvas_width / 2, 25,
            text=f"{confidence:.1%}",
            font=("Arial", 10, "bold")
        )
        # A tick every 10%, a label every 20%.
        for i in range(0, 11):
            x_pos = 10 + i * (canvas_width - 20) / 10
            self.confidence_canvas.create_line(x_pos, 40, x_pos, 45, width=1)
            if i % 2 == 0:
                self.confidence_canvas.create_text(
                    x_pos, 55, text=f"{i*10}%", font=("Arial", 8)
                )

    def get_confidence_color(self, confidence):
        """Return the bar colour: green >= 0.9, yellow >= 0.7, otherwise red."""
        if confidence >= 0.9:
            return "#4CAF50"  # green
        if confidence >= 0.7:
            return "#FFC107"  # yellow
        return "#F44336"  # red

    def update_candidates_display(self, candidates):
        """Replace the candidate table with ``(digit, probability)`` rows."""
        for item in self.candidates_tree.get_children():
            self.candidates_tree.delete(item)
        for digit, prob in candidates:
            self.candidates_tree.insert(
                "", tk.END,
                values=(digit, f"{prob:.2%}")
            )

    def show_samples(self):
        """Open a matplotlib figure with one dataset image per digit 0-9."""
        plt.figure(figsize=(10, 4))
        for i in range(10):
            plt.subplot(2, 5, i + 1)
            # Index of the first dataset entry labelled i.
            sample_idx = np.where(self.digits.target == i)[0][0]
            plt.imshow(self.digits.images[sample_idx], cmap="gray")
            plt.title(f"数字 {i}", fontsize=9)
            plt.axis("off")
        plt.tight_layout()
        plt.show()

    # ==================================================================
    # Section 6: model selection
    # ==================================================================
    def on_model_select(self, event):
        """Combobox callback: switch to the model whose name was picked."""
        selected_name = self.model_var.get()
        model_type = next(
            (key for key, name in self.available_models if name == selected_name),
            None
        )
        if model_type:
            self.change_model(model_type)

    def change_model(self, model_type):
        """Make ``model_type`` the active model, training it if necessary."""
        model_name = MODEL_METADATA[model_type][0]

        # Already trained during this session: reuse it.
        if model_type in self.model_cache:
            (self.current_model, self.scaler,
             accuracy, self.current_model_type) = self.model_cache[model_type]
            self.model_label.config(text=f"{model_name} (准确率:{accuracy:.4f})")
            self.status_var.set(f"已加载模型: {model_name}")
            return

        self.status_var.set(f"正在加载模型: {model_name}...")
        self.root.update()  # repaint now so the message shows during training
        try:
            X_train, X_test, y_train, y_test = self.model_factory.get_split_data(
                self.digits
            )
            self.current_model, self.scaler, accuracy = (
                self.model_factory.train_and_evaluate(
                    model_type, X_train, y_train, X_test, y_test
                )
            )
            self.current_model_type = model_type
            self.model_cache[model_type] = (
                self.current_model, self.scaler, accuracy, self.current_model_type
            )
            self.model_label.config(text=f"{model_name} (准确率:{accuracy:.4f})")
            self.status_var.set(f"模型加载完成: {model_name}, 准确率: {accuracy:.4f}")
            self.clear_canvas()
            self.load_performance_data()
        except Exception as e:
            self.status_var.set(f"模型加载失败: {e}")
            self.model_label.config(text="模型加载失败")

    def init_default_model(self):
        """Select and load the first available model at start-up."""
        self.model_var.set(self.available_models[0][1])
        self.change_model(self.available_models[0][0])

    # ==================================================================
    # Section 7: model comparison
    # ==================================================================
    def _get_performance_results(self):
        """Return the all-model evaluation, computing it only once.

        The split and every registered model use fixed seeds, so the first
        result is kept instead of retraining every model on each request.
        """
        if self.performance_results is None:
            self.performance_results = self.model_factory.evaluate_all_models(
                self.digits
            )
        return self.performance_results

    def load_performance_data(self):
        """Fill the comparison table; the first (best) row is highlighted."""
        results = self._get_performance_results()
        for item in self.performance_tree.get_children():
            self.performance_tree.delete(item)
        for i, result in enumerate(results):
            tag = "highlight" if i == 0 else ""
            self.performance_tree.insert(
                "", tk.END,
                values=(result["模型名称"], result["准确率"]),
                tags=(tag,)
            )
        self.performance_tree.tag_configure("highlight", background="#e6f7ff")

    def show_performance_chart(self):
        """Show a horizontal bar chart of the models' accuracies."""
        results = self._get_performance_results()

        # Keep only rows whose accuracy text is a number ("N/A" and error
        # messages raise ValueError and are skipped).
        valid_results = []
        for result in results:
            try:
                accuracy = float(result["准确率"])
            except ValueError:
                continue
            valid_results.append((result["模型名称"], accuracy))
        if not valid_results:
            messagebox.showinfo("提示", "没有可用的性能数据")
            return

        valid_results.sort(key=lambda pair: pair[1], reverse=True)
        models, accuracies = zip(*valid_results)

        plt.figure(figsize=(10, 5))
        bars = plt.barh(models, accuracies, color='#2196F3')
        plt.xlabel('准确率', fontsize=10)
        plt.ylabel('模型', fontsize=10)
        plt.title('模型性能对比', fontsize=12)
        plt.xlim(0, 1.05)
        # Write the exact value to the right of each bar.
        for bar in bars:
            width = bar.get_width()
            plt.text(
                width + 0.01,
                bar.get_y() + bar.get_height() / 2,
                f'{width:.4f}',
                ha='left',
                va='center',
                fontsize=8
            )
        plt.tight_layout()
        plt.show()

    # ==================================================================
    # Section 8: custom training samples
    # ==================================================================
    def save_as_training_sample(self):
        """Ask for the label of the current drawing and keep it in memory.

        The sample is appended to ``self.custom_data`` only; nothing is
        written to disk here.
        """
        if not self.has_drawn:
            self.status_var.set("请先绘制数字再保存")
            return
        img_array = self.preprocess_image()
        if img_array is None:
            return

        # Small modal dialog asking for the digit label.
        label_window = tk.Toplevel(self.root)
        label_window.title("输入标签")
        label_window.geometry("300x150")
        label_window.transient(self.root)
        label_window.grab_set()  # block the main window until this one closes
        tk.Label(
            label_window,
            text="请输入数字标签 (0-9):",
            font=("Arial", 10)
        ).pack(pady=10)
        entry = tk.Entry(label_window, font=("Arial", 12), width=5)
        entry.pack(pady=5)
        entry.focus_set()

        def save_with_label():
            """Validate the typed label, store the sample, close the dialog."""
            try:
                label = int(entry.get())
                if label < 0 or label > 9:
                    raise ValueError("标签必须是0-9的数字")
            except ValueError as e:
                # Non-numeric or out-of-range input: keep the dialog open.
                self.status_var.set(f"保存错误: {e}")
                return
            self.custom_data.append((img_array.tolist(), label))
            self.status_var.set(
                f"已保存数字 {label} (共 {len(self.custom_data)} 个样本)"
            )
            label_window.destroy()

        tk.Button(
            label_window,
            text="保存",
            command=save_with_label,
            width=10
        ).pack(pady=5)

    def save_all_training_data(self):
        """Write all collected samples to a CSV file chosen by the user.

        File layout: a header ``pixel0..pixel63,label`` followed by one row
        per sample.
        """
        if not self.custom_data:
            self.status_var.set("没有训练数据可保存")
            return
        file_path = filedialog.asksaveasfilename(
            defaultextension=".csv",
            filetypes=[("CSV文件", "*.csv")],
            initialfile="custom_digits.csv",
            title="保存训练集"
        )
        if not file_path:  # dialog cancelled
            return
        try:
            with open(file_path, 'w', newline='', encoding='utf-8') as f:
                writer = csv.writer(f)
                writer.writerow([f'pixel{i}' for i in range(64)] + ['label'])
                for img_data, label in self.custom_data:
                    writer.writerow(img_data + [label])
            self.status_var.set(
                f"已保存 {len(self.custom_data)} 个样本到 {os.path.basename(file_path)}"
            )
        except Exception as e:
            self.status_var.set(f"保存失败: {e}")

    @staticmethod
    def _read_samples(file_path):
        """Read ``(pixels, label)`` samples from a CSV file.

        The first line is treated as a header and skipped; rows that do not
        have exactly 65 columns are ignored.
        """
        samples = []
        with open(file_path, 'r', newline='', encoding='utf-8') as f:
            reader = csv.reader(f)
            next(reader, None)  # skip the header; tolerate an empty file
            for row in reader:
                if len(row) != 65:
                    continue
                pixels = [float(pixel) for pixel in row[:64]]
                samples.append((pixels, int(row[64])))
        return samples

    def load_training_data(self):
        """Replace the in-memory samples with those from a chosen CSV file.

        The current samples are kept untouched when reading fails.
        """
        file_path = filedialog.askopenfilename(
            filetypes=[("CSV文件", "*.csv")],
            title="加载训练集"
        )
        if not file_path:  # dialog cancelled
            return
        try:
            loaded = self._read_samples(file_path)
        except Exception as e:
            self.status_var.set(f"加载失败: {e}")
            return
        self.custom_data = loaded
        self.status_var.set(f"已加载 {len(self.custom_data)} 个样本")

    # ==================================================================
    # Section 9: start-up
    # ==================================================================
    def run(self):
        """Enter the Tk event loop; returns when the window is closed."""
        self.root.mainloop()
# ---------------------------------------------------------------------------
# Program entry point
# ---------------------------------------------------------------------------
def main():
    """Load the digits dataset, build the window and run the application."""
    digits = load_digits()
    root = tk.Tk()
    app = HandwritingBoard(root, ModelFactory, digits)
    app.run()


# Run only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
# 基于此代码,请在其中添加大量注释,并明确标出各代码分区的功能,表述要清晰,让刚学 Python 的同学也能看懂。