import tensorflow as tf
import numpy as np
from flask import Flask, request, jsonify
import os
import tkinter as tk
from tkinter import filedialog
# Flask application object; the /train and /predict routes defined in this
# file are registered on it via @app.route.
app = Flask(__name__)
def build_model(vocab_size, max_length):
    """Build and compile the binary classifier used for encoded equations.

    The network stacks Embedding(vocab_size -> 64), LSTM(128),
    Dense(64, relu) and Dense(1, sigmoid). It is compiled with the Adam
    optimizer, binary cross-entropy loss and an accuracy metric.

    Args:
        vocab_size: Number of distinct token ids the embedding layer accepts.
        max_length: Sequence length handed to the embedding as input_length.

    Returns:
        The compiled tf.keras.Sequential model.
    """
    keras_layers = tf.keras.layers
    stack = [
        keras_layers.Embedding(vocab_size, 64, input_length=max_length),
        keras_layers.LSTM(128),
        keras_layers.Dense(64, activation='relu'),
        keras_layers.Dense(1, activation='sigmoid'),
    ]
    network = tf.keras.Sequential(stack)
    network.compile(
        optimizer='adam',
        loss='binary_crossentropy',
        metrics=['accuracy'],
    )
    return network
def preprocess_equation(equation, max_length):
    """Encode an equation string as a fixed-length vector of character ids.

    Each ASCII character is mapped to its code point (0-127). Characters
    outside that range are mapped to 0, the same id used for padding, so
    every id stays a valid index for a 128-entry embedding table.

    Args:
        equation: The equation text to encode.
        max_length: Length of the returned vector. Longer equations are
            truncated; shorter ones are right-padded with zeros.

    Returns:
        A 1-D integer numpy array of length ``max_length``.
    """
    vocab_size = 128
    limit = max(max_length, 0)
    # Pre-allocating an integer buffer keeps the dtype integral even for an
    # empty equation (np.pad on an empty list would yield floats).
    padded_equation = np.zeros(limit, dtype=int)
    codes = [
        ord(ch) if ord(ch) < vocab_size else 0
        for ch in equation[:limit]
    ]
    padded_equation[:len(codes)] = codes
    return padded_equation
def generate_train_data(num_samples=100):
    """Generate random mock equation strings with random binary labels.

    Every sample consists of 1-3 "reactant" terms and 1-2 "product" terms.
    A term is one or two element symbols followed by a randomly chosen
    operator. The terms on each side are joined with '.', and the two sides
    are joined with '='. The label of each sample is a random 0 or 1.

    Args:
        num_samples: Number of equation/label pairs to produce.

    Returns:
        A tuple ``(data, labels)`` where ``data`` is a list of strings and
        ``labels`` is a numpy array with one entry per string.
    """
    symbols = ['H', 'O', 'N', 'C', 'Na', 'Cl']
    separators = ['+', '=', '->']

    def random_terms(count):
        # Draw order (size, symbols, operator) is kept per term so that a
        # seeded generator yields the same sequence of samples.
        terms = []
        for _ in range(count):
            size = np.random.randint(1, 3)
            picked = np.random.choice(symbols, size)
            terms.append(''.join(picked) + np.random.choice(separators))
        return terms

    equations = []
    targets = []
    for _ in range(num_samples):
        left_side = random_terms(np.random.randint(1, 4))
        right_side = random_terms(np.random.randint(1, 3))
        equations.append('.'.join(left_side) + '=' + '.'.join(right_side))
        # Labels carry no meaning; they are random 0/1 placeholders.
        targets.append(np.random.randint(0, 2))
    return equations, np.array(targets)
def train_model(data, labels, max_length, epochs=10, batch_size=32, device='cpu'):
    """Train the equation classifier and save it to 'chemistry_model.h5'.

    Args:
        data: Sequence of equation strings.
        labels: Sequence of 0/1 labels, one per equation.
        max_length: Length every encoded equation is padded to.
        epochs: Number of training epochs.
        batch_size: Mini-batch size passed to ``fit``.
        device: 'gpu' to train on the first GPU when one is present; any
            other value (or no GPU being found) hides GPUs via
            CUDA_VISIBLE_DEVICES.

    Returns:
        The trained Keras model.

    Raises:
        ValueError: If ``data`` is empty or its length differs from
            ``labels``.
    """
    if len(data) == 0:
        raise ValueError("训练数据为空")
    if len(data) != len(labels):
        raise ValueError("训练数据与标签的数量不一致")

    gpus = []
    if device == 'gpu':
        gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        try:
            tf.config.experimental.set_memory_growth(gpus[0], True)
        except RuntimeError as err:
            # Memory growth can only be changed before the GPU has been
            # initialised; training can still proceed without it.
            print(f"无法设置 GPU 显存按需增长: {err}")
    else:
        # Either CPU was requested or no GPU was found: fall back to CPU
        # instead of indexing into an empty device list.
        os.environ['CUDA_VISIBLE_DEVICES'] = ''

    model = build_model(128, max_length)
    encoded_data = np.array([preprocess_equation(eq, max_length) for eq in data])
    labels = np.array(labels)
    # Keras cannot carve a validation subset out of a single sample, so the
    # split is only requested when at least two samples are available.
    validation_split = 0.1 if len(encoded_data) >= 2 else 0.0
    model.fit(
        encoded_data,
        labels,
        epochs=epochs,
        batch_size=batch_size,
        validation_split=validation_split,
    )
    model.save('chemistry_model.h5')
    return model
def load_data_from_file(file_path):
    """Read chemical equations from a text file, one per line.

    Line terminators are stripped and blank lines are skipped, so the length
    of each returned string equals the length of the equation itself.

    Args:
        file_path: Path of the text file to read.

    Returns:
        A list of equation strings, or an empty list (after printing a
        notice) when the file does not exist.
    """
    try:
        with open(file_path, 'r') as file:
            lines = [line.rstrip('\r\n') for line in file]
    except FileNotFoundError:
        print(f"文件 {file_path} 不存在")
        return []
    return [line for line in lines if line]
@app.route('/train', methods=['POST'])
def train():
    """Train the model from a JSON body with 'data' and 'labels' lists.

    'data' must be a non-empty list of equation strings and 'labels' a list
    of the same length. After a successful run the equations are written to
    'training_data.txt', the file the prediction code reads to determine
    the padded input length.

    Returns:
        A JSON message; malformed requests are answered with HTTP 400.
    """
    # silent=True yields None instead of raising on a missing/invalid body.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict) or 'data' not in payload or 'labels' not in payload:
        return jsonify({"message": "数据格式错误,缺少 'data' 或 'labels' 字段"}), 400

    data = payload['data']
    raw_labels = payload['labels']
    well_formed = (
        isinstance(data, list)
        and isinstance(raw_labels, list)
        and len(data) > 0
        and len(data) == len(raw_labels)
        and all(isinstance(eq, str) for eq in data)
    )
    if not well_formed:
        return jsonify({"message": "数据格式错误,'data' 和 'labels' 必须是长度相同的非空列表"}), 400

    labels = np.array(raw_labels)
    max_length = max(len(eq) for eq in data)
    train_model(data, labels, max_length)
    # Persist the equations only once training succeeded, so the file always
    # matches the saved model.
    with open('training_data.txt', 'w') as file:
        file.writelines(eq + '\n' for eq in data)
    return jsonify({"message": "训练完成!"})
@app.route('/predict', methods=['POST'])
def predict():
    """Predict the label score for the equation in the JSON field 'equation'.

    The saved model is loaded from 'chemistry_model.h5' and the padded input
    length is taken from the longest line of 'training_data.txt'.

    Returns:
        JSON ``{"prediction": <float>}`` on success, otherwise a JSON message
        with an error status code.
    """
    # silent=True yields None instead of raising on a missing/invalid body.
    payload = request.get_json(silent=True)
    equation = payload.get('equation', '') if isinstance(payload, dict) else ''
    if not equation or not isinstance(equation, str):
        return jsonify({"message": "未提供 'equation' 字段"}), 400

    try:
        model = tf.keras.models.load_model('chemistry_model.h5')
    except OSError:
        # Covers FileNotFoundError as well as the plain OSError raised for a
        # missing model file.
        return jsonify({"message": "模型文件未找到"}), 404
    except Exception as e:
        return jsonify({"message": f"预测时出错: {str(e)}"}), 500

    # Strip line terminators so the length matches the one used in training.
    known_equations = [
        eq.rstrip('\r\n') for eq in load_data_from_file('training_data.txt')
    ]
    known_equations = [eq for eq in known_equations if eq]
    if not known_equations:
        return jsonify({"message": "训练数据文件 training_data.txt 不存在或为空,无法确定输入长度"}), 500

    try:
        max_length = max(len(eq) for eq in known_equations)
        processed_equation = preprocess_equation(equation, max_length)
        padded_equation = np.expand_dims(processed_equation, axis=0)
        prediction = model.predict(padded_equation)
    except Exception as e:
        return jsonify({"message": f"预测时出错: {str(e)}"}), 500
    # numpy float32 is not JSON serialisable; convert to a built-in float.
    return jsonify({"prediction": float(prediction[0][0])})
def tkinter_predict():
    """Open the Tkinter window for loading/training the model and predicting.

    The "确定" button loads 'chemistry_model.h5' into the module-level
    ``model``; when the file cannot be opened it generates mock data, writes
    it to 'training_data.txt' and trains a new model on the chosen device.
    The two prediction buttons score an equation typed into the entry field
    or read from a chosen text file. This call blocks in the Tk main loop
    until the window is closed.
    """
    def show(text):
        # Single place that updates the status line at the bottom.
        result_label.config(text=text)

    def select_device():
        global model
        selected_device = device_var.get()
        if selected_device != 'gpu':
            # 'auto' (run directly) and 'cpu' both mean CPU training.
            selected_device = 'cpu'
        notice = ""
        if selected_device == 'gpu' and not tf.config.experimental.list_physical_devices('GPU'):
            # Update the local choice too, so training really runs on CPU.
            selected_device = 'cpu'
            device_var.set('cpu')
            notice = "未检测到 GPU,将使用 CPU 进行训练"
            show(notice)
        try:
            try:
                model = tf.keras.models.load_model('chemistry_model.h5')
            except OSError:
                # No usable saved model: train one on mock data.
                data, labels = generate_train_data()
                # Saved so prediction can recompute the padded length.
                with open('training_data.txt', 'w') as file:
                    for equation in data:
                        file.write(equation + '\n')
                max_length = max(len(eq) for eq in data)
                model = train_model(data, labels, max_length, device=selected_device)
        except Exception as e:
            show(f"模型加载或训练失败: {str(e)}")
            return
        show((notice + "\n" if notice else "") + "模型已就绪")

    def run_prediction(equation, error_prefix):
        # Shared by both prediction buttons.
        current_model = globals().get('model')
        if current_model is None:
            show("请先点击“确定”按钮加载或训练模型")
            return
        try:
            # Strip line terminators so the length matches training.
            known = [
                eq.rstrip('\r\n')
                for eq in load_data_from_file('training_data.txt')
            ]
            known = [eq for eq in known if eq]
            if not known:
                show("训练数据文件 training_data.txt 不存在或为空")
                return
            max_length = max(len(eq) for eq in known)
            processed_equation = preprocess_equation(equation, max_length)
            padded_equation = np.expand_dims(processed_equation, axis=0)
            prediction = current_model.predict(padded_equation)
            show(f"预测结果: {prediction[0][0]}")
        except Exception as e:
            show(f"{error_prefix}: {str(e)}")

    def predict_from_input():
        equation = input_entry.get()
        if not equation:
            show("请输入化学方程式")
            return
        run_prediction(equation, "预测出错")

    def predict_from_file():
        file_path = filedialog.askopenfilename()
        if not file_path:
            return
        try:
            with open(file_path, 'r') as file:
                equation = file.read().strip()
        except Exception as e:
            # A failure here concerns the chosen file, not the model file.
            show(f"处理文件出错: {str(e)}")
            return
        if not equation:
            show("文件内容为空")
            return
        run_prediction(equation, "处理文件出错")

    # Build the window.
    root = tk.Tk()
    input_label = tk.Label(root, text="输入化学方程式:")
    input_label.pack()
    input_entry = tk.Entry(root)
    input_entry.pack()
    input_button = tk.Button(root, text="预测", command=predict_from_input)
    input_button.pack()
    file_label = tk.Label(root, text="或选择包含化学方程式的 txt 文件:")
    file_label.pack()
    file_button = tk.Button(root, text="选择文件", command=predict_from_file)
    file_button.pack()
    device_label = tk.Label(root, text="选择运行设备:")
    device_label.pack()
    device_var = tk.StringVar()
    # Each radio button needs its own value; buttons sharing one value are
    # highlighted together. 'auto' is treated as CPU in select_device.
    device_var.set('auto')
    device_radiobutton1 = tk.Radiobutton(root, text="直接运行(默认 CPU)", variable=device_var, value='auto')
    device_radiobutton1.pack()
    device_radiobutton2 = tk.Radiobutton(root, text="使用 CPU", variable=device_var, value='cpu')
    device_radiobutton2.pack()
    device_radiobutton3 = tk.Radiobutton(root, text="使用 GPU(如果可用)", variable=device_var, value='gpu')
    device_radiobutton3.pack()
    device_button = tk.Button(root, text="确定", command=select_device)
    device_button.pack()
    result_label = tk.Label(root, text="")
    result_label.pack()
    root.mainloop()
if __name__ == '__main__':
    import threading

    # app.run() blocks until the server stops, so calling it directly would
    # keep the GUI from ever opening. The server therefore runs in a daemon
    # thread (it ends with the process) while Tkinter owns the main thread.
    # The reloader is disabled because it re-executes the script and needs
    # the main thread.
    # NOTE(review): debug=True keeps the original setting; it enables the
    # interactive debugger and should not be exposed beyond local use.
    server_thread = threading.Thread(
        target=app.run,
        kwargs={'debug': True, 'use_reloader': False},
        daemon=True,
    )
    server_thread.start()
    tkinter_predict()
# 说明:这个程序在入口处首先启动 Flask 应用,然后运行 GUI。
# 用户通过 GUI 输入化学方程式或选择包含方程式的文件后,程序会使用模型进行预测。
# 请注意:必须先训练模型(例如向 /train 路由发送 POST 请求),然后才能进行预测。