// 不废话,直接写代码 (no preamble — straight to the code)
#define _CRT_SECURE_NO_WARNINGS
#pragma once
#include<stdio.h>
#include<stdlib.h>
#include<graphics.h>
#include<math.h>
#include<memory.h>
#include<memory>
#include<string.h>
#include<string>
#include<process.h>
#include<cstdio>
#include<cstdlib>
#include<graphics.h>
#include<time.h>
#include<Windows.h>
#include<mmsystem.h>
#include<wchar.h>
#include<locale>
#include<locale.h>
#include<iostream>
#pragma comment(lib,"winmm.lib")
using namespace std;
// Global state shared between the GUI (main) thread and the worker threads started
// with _beginthread. These are plain ints, read and written from several threads
// without locks or atomics.
// UI language: 0 = Chinese, 1 = English.
int global_language = 0;
// Start-menu control. 0 = keep showing the menu; 1 = the mouse_message_init thread
// ended and asked init() to leave its menu loop ("start the main program" was clicked);
// 2 = exit the program (documented meaning; not set in the code shown here).
int global_thread_stop_1 = 0;
// Documented meaning: 0 = continue, 1 = go on to the next step, 2 = exit the program.
// NOTE(review): not read or written in the code shown here.
int global_thread_stop_2 = 0;
// 1 while the mouse_message_init thread is running, 0 after it has ended; init()
// restarts that thread whenever it sees 0.
int global_judge_thread_mouse_message_init = 0;
// Progress of check_file(): 0 = not started / first file in progress; N = the first N
// required files have been checked, in the order model.dll, train.dll, connect.dll,
// Test_model.dll, utils.py, readme.txt, Data_Pre.py. GUI_interface_1 moves 7 -> 8 once
// it has shown the "finished" message. -1 is documented as "unknown error, the program
// exits" (not set in the code shown here).
int global_check_files_0 = 0;
// Progress of check_python_environment(): -1 = python/pip not found; 0 = not started;
// 1..4 = the pip install of torch, matplotlib, visdom, torchvision has run.
// GUI_interface_1 moves 4 -> 5 once it has shown the "finished" message.
int global_check_python_environment = 0;
// Intended to record which required file is missing (1 = missing, 0 = present), indexed
// 0..6 as model.dll, train.dll, connect.dll, Test_model.dll, utils.py, readme.txt,
// Data_Pre.py.
int global_files[7] = { 0 };
// Documented meaning: 1 while the mouse_message_after_1 thread is running, 0 after it ended.
int global_judge_thread_mouse_message_after_1 = 0;
/********************************************\
* this program is written by lry
* function : this program provides a GUI for the user
* mouse : this program is operated with the mouse
* thread : this program runs several threads at the same time
* what's the function means:
* 1 :check_python_environment :checks whether the python environment is complete (pip, torch, matplotlib, visdom, torchvision)
* 2 :self_classification :classifies the photograph entirely by this program itself
* 3 :read_file :reads the result from connect.dll and shows it on the screen.
* 4 :create_operatble_file :creates the working files this program needs
* 5 :show :shows the classification made by this program or by the user
* 6 :check_file :checks that the required files exist and restores the missing ones.
* 7 :fix_file_1 :this function can fix the important file this program must use == model.dll
* 8 :fix_file_2 :this function can fix the important file this program must use == train.dll >>in order to train the model.
* 9 :fix_file_3 :this function can fix the important file this program must use == connect.dll >>in order to connect with other program.
* 10 :fix_file_4 :this function can fix the important file this program must use == Test_model.dll >>in order to test the model and make sure it can guess which kind of garbage it is.
* 11 :fix_file_5 :restores the important file this program must use == utils.py >>helper code used to improve the accuracy of the program
* 12 :fix_file_6 :restores the important file this program must use == readme.txt >>so that the user can make an informed choice about using this program.
* 13 :fix_file_7 :restores the important file this program must use == Data_Pre.py >>used to prepare the photographs for analysis.
* 14 :create_update_data :shows the update log to the user.
* 15 :mouse_message_init :worker thread that handles mouse clicks on the start menu.
* 16 :mouse_message_after :worker thread that passes mouse messages to this program while it is running.
* 17 :GUI_interface_1 :start-up screen that shows the progress of the file and python-environment checks.
* 18 :GUI_interface_2 :gives the user an operable interface.
* 19 :GUI_interface_3 :this function is a new thread that gives the user an operable interface.
* 20 :language_setting :set the language you use. //can't use in this version.
* 21 :get_the_time_now :get the time now.
* 22 :show_the_settings :show the settings of program
* start create time :2021.3.12
* 2021.3.12 version 1.0.1 runs only in a DOS console, using python scripts.
* 2021.7.27 version 2.0.1 repairs broken or missing files automatically, but has no GUI.
* 2021.8.6 version 3.0.1 adds a GUI, but still has some bugs.
\***********************************************/
// ---- forward declarations ----
// Functions taking a void* match the start-routine signature of _beginthread.
void check_python_environment(void*);
void check_file(void*);
void mouse_message_init(void*);
void mouse_message_after_1(void*);
void get_the_time_now(void*);
void show_the_settings(void*);
void clear(void*);

// Screens and menus.
void init();
void GUI_interface_1();
void GUI_interface_2();
void GUI_interface_3();
void language_setting();
void create_update_data();

// Classification and working files.
void self_classification();
void read_file(int a);
void create_operatble_file(int a);
void show(int a);

// Restore the required files (called by check_file when a file is missing).
void fix_file_1();
void fix_file_2();
void fix_file_3();
void fix_file_4();
void fix_file_5();
void fix_file_6();
void fix_file_7();
// Text resource: licence notice, an excerpt of the train.dll training script and a
// link to the author's blog. Not referenced in the part of the file shown here.
// NOTE(review): the embedded Python excerpt is display text only (e.g. the
// torch.device(...) line lacks its closing parenthesis); do not copy it as code.
wchar_t code_license_and_sourse_code[] = L"open sourse license\n==========================\nthis program is powered by lry.\nthis program is all provided by lry\n all rights reserved 2021\n=======================\nsourse code:train.dll\nimport torch\nimport visdom\nfrom torch import optim,nn\nfrom utils import Flatten\nfrom Data_Pre import Data\nfrom torch.utils.data import DataLoader\nfrom torchvision.models import resnet18\nbatchsz=32\nlr=1e-4\nepochs=20\ndevice=torch.device('cuda' if torch.cuda.is_available() else 'cpu'\ntorch.manual_seed(1234)\ntrain_db=Data('train_data',224,mode='train')\nval_db=Data('train_data',224,mode='val')\ntest_db=Data('train_data',224,mode='test')\ntrain_loader=DataLoader(train_db,batch_size=batchsz,shuffle=True,num_workers=4)\nval_loader=DataLoader(val_db,batch_size=batchsz,num_worker=4)\ntest_loader=DataLoader(test_db,batchsz,num_worker=4)\nviz=visdom.Visdom()\ndef evalute():\n\tmodel.eval()\ncorrect=0\n=========================================\nif you want to see more information(sourse code),please goto this website\n:https://blog.csdn.net/liourenyu/article/details/119490901\n)";
// Program entry point (main thread): sets the locale, writes the update log,
// runs the start menu, then the start-up check screen, then the main interface.
int main()
{
    // Use the user's default locale for the C runtime.
    setlocale(LC_ALL, "");
    create_update_data();

    // Start menu; returns once "start the main program" has been chosen.
    init();

    // Start-up checks (required files and python packages), then a clean screen.
    GUI_interface_1();
    cleardevice();

    GUI_interface_2();
    system("pause");
    return 0;
}
void init()
{
initgraph(640, 480);
setbkcolor(WHITE);
setlinecolor(GREEN);
settextcolor(BROWN);
setfillcolor(GREEN);
cleardevice();
_beginthread(get_the_time_now, 0, NULL);
_beginthread(mouse_message_init, 0, NULL);
wchar_t ch_0_0[] = L"语言:中文,单击此处已更改 ";
wchar_t ch_1_0[] = L"启动主程序 ";
wchar_t ch_2_0[] = L"显示本程序设置以及版权信息 ";
wchar_t ch_3_0[] = L"退出本程序 ";
wchar_t ch_0_1[] = L"language : English,click here to change language";
wchar_t ch_1_1[] = L"start running the main program";
wchar_t ch_2_1[] = L"show the settings of this program and the copyright of the author";
wchar_t ch_3_1[] = L"exit this program";
while (1)
{
if (global_language == 0)
{
outtextxy(100, 100, ch_0_0);
outtextxy(100, 200, ch_1_0);
outtextxy(100, 300, ch_2_0);
outtextxy(100, 400, ch_3_0);
}
if (global_language == 1)
{
outtextxy(100, 100, ch_0_1);
outtextxy(100, 200, ch_1_1);
outtextxy(100, 300, ch_2_1);
outtextxy(100, 400, ch_3_1);
}
if (global_thread_stop_1==1)
{
break;
}
if (global_judge_thread_mouse_message_init == 0)
{
_beginthread(mouse_message_init, 0, NULL);
}
if (global_judge_thread_mouse_message_init == 1)
{
}
}
}
void mouse_message_init(void*)
{
global_judge_thread_mouse_message_init = 1;
MOUSEMSG m;
FlushMouseMsgBuffer();
int a = 0;
int temp = 0;
while (1)
{
m = GetMouseMsg();
while (m.mkLButton)
{
Sleep(100); //防止该线程抖动
if (m.x >= 100 && m.x <= 200 && m.y >= 70 && m.y <= 130)
{
if (global_language == 0)
{
temp = 1;
}
if (global_language == 1)
{
temp = 0;
}
cleardevice();
global_language = temp;
global_judge_thread_mouse_message_init = 0;
_endthread();
}
if (m.x >= 100 && m.x <= 200 && m.y >= 270 && m.y <= 330)
{
_beginthread(show_the_settings, 0, NULL);
global_judge_thread_mouse_message_init = 0;
_endthread();
}
if (m.x >= 100 && m.x <= 200 && m.y >= 370 && m.y <= 430)
{
global_judge_thread_mouse_message_init = 0;
exit(0);
}
if (m.x >= 100 && m.x <= 200 && m.y >= 170 && m.y <= 330)
{
global_thread_stop_1 = 1;
_endthread();
}
}
}
}
// Thread entry (started from mouse_message_init): shows the author / settings
// information in a modal message box in the current UI language.
void show_the_settings(void*)
{
    wchar_t text_cn[] = L"本程序由刘仁宇编写\n本程序可以通过内置的模型,进行自动化的垃圾分类操作。\n本程序也可以通过手工分类进行分类操作";
    wchar_t text_en[] = L"this program is powered by lry\n this program can use the model init to classify the garbage auto.\nthis program can also classify by the user";
    wchar_t caption_cn[] = L"程序设置以及版权信息";
    wchar_t caption_en[] = L"the settings and the copyright of this program";

    // Index 0 = Chinese, 1 = English. global_language is re-read for each language,
    // i.e. after the first box (if any) has been closed.
    const wchar_t* const texts[2] = { text_cn, text_en };
    const wchar_t* const captions[2] = { caption_cn, caption_en };
    for (int lang = 0; lang < 2; ++lang)
    {
        if (global_language == lang)
        {
            MessageBox(NULL, texts[lang], captions[lang], MB_OK);
        }
    }
}
// Start-up screen shown after the start menu (runs on the calling thread).
// 1. Shows a "please wait" message for one second.
// 2. Starts check_file and check_python_environment as worker threads.
// 3. Polls their progress counters and draws two progress bars:
//      upper bar (y 100..150) <- global_check_files_0 (steps 0..7)
//      lower bar (y 200..250) <- global_check_python_environment (steps 0..4)
//    When a counter reaches its last step the "finished" message is drawn once and
//    the counter is bumped (files 7 -> 8, python 4 -> 5) so it is not drawn again.
// 4. Returns two seconds after both checks have finished.
// If check_python_environment reports -1 (python/pip not found) an error box is shown
// in the current UI language and the program exits.
void GUI_interface_1()
{
    cleardevice();
    wchar_t ch_init_0[] = L"请稍后 ";
    wchar_t ch_init_1[] = L"please wait a while ";
    wchar_t ch_0_0[] = L"检查本程序所必须的文件,请稍后 ";
    wchar_t ch_0_1[] = L"checking the files this program must have,please wait ";
    wchar_t ch_1_0[] = L"正在检查python环境,请稍后... ";
    wchar_t ch_1_1[] = L"checking python environment now ,please wait a while ";
    wchar_t ch_2_0[] = L"正在检查python 环境配置:torch ";
    wchar_t ch_2_1[] = L"checking pyhton environment settings :torch ";
    wchar_t ch_3_0[] = L"正在检查python 环境配置:matplotlib ";
    wchar_t ch_3_1[] = L"checking python environment settings : matplotlib ";
    wchar_t ch_4_0[] = L"正在检查python 环境配置:visdom ";
    wchar_t ch_4_1[] = L"checking python environment settings :visdom ";
    wchar_t ch_5_0[] = L"正在检查python 环境配置:torchvision ";
    wchar_t ch_5_1[] = L"checking python environment settings :torchvision ";
    wchar_t ch_6_0[] = L"正在检查本程序所必需的文件,请稍后 ";
    wchar_t ch_6_1[] = L"checking the files this program must have to start this program ,please wait a while";
    wchar_t ch_7_0[] = L"正在检查本程序所必需的文件:model.dll ";
    wchar_t ch_7_1[] = L"checking the essential files this program must have used : model.dll ";
    wchar_t ch_8_0[] = L"正在检查本程序所必需的文件:train.dll ";
    wchar_t ch_8_1[] = L"checking the essential files this program must have used : train.dll ";
    wchar_t ch_9_0[] = L"正在检查本程序所必需的文件:connect.dll ";
    wchar_t ch_9_1[] = L"checking the essential files this program must have used :connect.dll ";
    wchar_t ch_10_0[] = L"正在检查本程序所必需的文件:Test_model.dll ";
    wchar_t ch_10_1[] = L"checking the essential files this program must have used : Test_model.dll ";
    wchar_t ch_11_0[] = L"正在检查本程序所必需的文件:utils.py ";
    wchar_t ch_11_1[] = L"checking the essential files this program must have used : utils.py ";
    wchar_t ch_12_0[] = L"正在检查本程序所必需的文件:readme.txt ";
    wchar_t ch_12_1[] = L"checking the essential files this program must have used : readme.txt ";
    wchar_t ch_13_0[] = L"正在检查本程序所必需的文件:Data_Pre.py ";
    wchar_t ch_13_1[] = L"checking the essential files this program must have used : Data_Pre.py ";
    wchar_t ch_finish_cn_0[] = L"检查本程序必须的文件已经完成,请等待后续操作";
    wchar_t ch_finish_en_0[] = L"checking the essential files this program must have is finished .please wait for other operation finish";
    wchar_t ch_finish_cn_1[] = L"python环境检查已经完成,即将进入下一个界面";
    wchar_t ch_finish_en_1[] = L"we have successfully checked the environment of python.we will goto next interface ,please wait";
    // Error shown when python/pip cannot be found, per language.
    wchar_t ch_err_0[] = L"请正确安装python,本程序无法识别python\n错误代码0x01";
    wchar_t ch_err_title_0[] = L"错误信息";
    wchar_t ch_err_1[] = L"python could not be found, please install python correctly\nerror code 0x01";
    wchar_t ch_err_title_1[] = L"error";

    // Lookup tables indexed by [language][step]; language 0 = Chinese, 1 = English.
    const wchar_t* const wait_text[2][2] = {
        { ch_init_0, ch_0_0 },
        { ch_init_1, ch_0_1 },
    };
    const wchar_t* const file_step_text[2][8] = {
        { ch_6_0, ch_7_0, ch_8_0, ch_9_0, ch_10_0, ch_11_0, ch_12_0, ch_13_0 },
        { ch_6_1, ch_7_1, ch_8_1, ch_9_1, ch_10_1, ch_11_1, ch_12_1, ch_13_1 },
    };
    const int file_step_right[8] = { 90, 140, 190, 240, 290, 340, 390, 440 };
    const wchar_t* const python_step_text[2][5] = {
        { ch_1_0, ch_2_0, ch_3_0, ch_4_0, ch_5_0 },
        { ch_1_1, ch_2_1, ch_3_1, ch_4_1, ch_5_1 },
    };
    const int python_step_right[5] = { 120, 200, 280, 360, 440 };
    const wchar_t* const files_done_text[2] = { ch_finish_cn_0, ch_finish_en_0 };
    const wchar_t* const python_done_text[2] = { ch_finish_cn_1, ch_finish_en_1 };
    const wchar_t* const error_text[2] = { ch_err_0, ch_err_1 };
    const wchar_t* const error_title[2] = { ch_err_title_0, ch_err_title_1 };

    const int start_lang = global_language;
    if (start_lang == 0 || start_lang == 1)
    {
        outtextxy(200, 200, wait_text[start_lang][0]);
        outtextxy(200, 250, wait_text[start_lang][1]);
    }
    Sleep(1000);
    cleardevice();
    _beginthread(check_file, 0, NULL);
    _beginthread(check_python_environment, 0, NULL);

    while (true)
    {
        // Outlines of the two progress bars.
        line(40, 100, 440, 100);
        line(40, 100, 40, 150);
        line(40, 150, 440, 150);
        line(440, 100, 440, 150);
        line(40, 200, 440, 200);
        line(40, 200, 40, 250);
        line(40, 250, 440, 250);
        line(440, 250, 440, 200);

        const int lang = global_language;
        if (lang == 0 || lang == 1)
        {
            // Snapshot the counters: the worker threads update them concurrently.
            const int files_step = global_check_files_0;
            if (files_step >= 0 && files_step <= 7)
            {
                outtextxy(40, 70, file_step_text[lang][files_step]);
                fillrectangle(40, 100, file_step_right[files_step], 150);
                if (files_step == 7)
                {
                    global_check_files_0 = 8; // mark "finished message shown"
                    outtextxy(40, 70, files_done_text[lang]);
                }
            }

            const int python_step = global_check_python_environment;
            if (python_step == -1)
            {
                MessageBox(NULL, error_text[lang], error_title[lang], MB_OK);
                exit(0);
            }
            if (python_step >= 0 && python_step <= 4)
            {
                outtextxy(40, 170, python_step_text[lang][python_step]);
                fillrectangle(40, 200, python_step_right[python_step], 250);
                if (python_step == 4)
                {
                    global_check_python_environment = 5; // mark "finished message shown"
                    outtextxy(40, 170, python_done_text[lang]);
                }
            }
        }

        if (global_check_files_0 == 8 && global_check_python_environment == 5)
        {
            Sleep(2000);
            break;
        }
        Sleep(20); // poll the worker threads without spinning a CPU core at 100%
    }
}
void check_file(void*)
{
FILE* f1 = fopen("model.dll", "rb");
FILE* f2 = fopen("train.dll", "rb");
FILE* f3 = fopen("connect.dll", "rb");
FILE* f4 = fopen("Test_model.dll", "rb");
FILE* f5 = fopen("utils.py", "rb");
FILE* f6 = fopen("readme.txt", "rb");
FILE* f7 = fopen("Data_Pre.py", "rb");
global_check_files_0 = 0;
if (f1 == NULL)
{
fix_file_1();
}
if (f1 != NULL)
{
fclose(f1);
}
global_check_files_0 = 1;
if (f2 == NULL)
{
fix_file_2();
}
if (f2 != NULL)
{
fclose(f2);
}
global_check_files_0 = 2;
if (f3 == NULL)
{
fix_file_3();
}
if (f3 != NULL)
{
fclose(f3);
}
global_check_files_0 = 3;
if (f4 == NULL)
{
fix_file_4();
}
if (f4 != NULL)
{
fclose(f4);
}
global_check_files_0 = 4;
if (f5 == NULL)
{
fix_file_5();
}
if (f5 != NULL)
{
fclose(f5);
}
global_check_files_0 = 5;
if (f6 == NULL)
{
fix_file_6();
}
if (f6 != NULL)
{
fclose(f6);
}
global_check_files_0 = 6;
if (f7 == NULL)
{
fix_file_7();
}
if (f7 != NULL)
{
fclose(f7);
}
global_check_files_0 = 7;
_endthread();
}
// Worker thread (started by GUI_interface_1). Checks that pip can be run and then
// installs the python packages the model scripts need, publishing progress in
// global_check_python_environment: 1 torch, 2 matplotlib, 3 visdom, 4 torchvision.
// If pip cannot be run, -1 is published and the thread stops so that the value stays
// visible to GUI_interface_1, which reports the error and exits.
void check_python_environment(void*)
{
    // Redirect this process's stderr to error.log so the shell's error output for
    // system("pip") can be inspected below.
    // NOTE(review): stderr stays redirected for the rest of the program.
    freopen("error.log", "w", stderr);
    system("pip");
    FILE* log = fopen("error.log", "rb");
    if (log != NULL)
    {
        char first[100] = { '\0' };
        fgets(first, 10, log);
        fclose(log);
        // Presumably matches cmd.exe's "'pip' is not recognized ..." message, which
        // starts with a single quote (ASCII 39).
        if (first[0] == '\'')
        {
            global_check_python_environment = -1;
            // Stop here: running the installs below would overwrite -1 with 1..4
            // before the GUI thread had a chance to see the error.
            _endthread();
        }
    }
    system("pip install torch");
    global_check_python_environment = 1;
    system("pip install matplotlib");
    global_check_python_environment = 2;
    system("pip install visdom");
    global_check_python_environment = 3;
    system("pip install torchvision");
    global_check_python_environment = 4;
    _endthread();
}
// model.dll (the trained network) cannot be regenerated here. Append the failure to
// error.txt and tell the user, in the current UI language, to train a new model.
void fix_file_1()
{
    // Logging is best-effort: if error.txt cannot be opened the user is still warned.
    FILE* fp = fopen("error.txt", "a+");
    if (fp != NULL)
    {
        fputs("can not create the file : model.dll\n", fp);
        fputs("fix file failed,please go into the choice and choose the choice 3 to remake the file.\n", fp);
        fclose(fp);
    }
    wchar_t ch_cn_0[] = L"修复模型model.dll失败,请进入主界面后按照提示训练模型";
    wchar_t ch_cn_1[] = L"警告";
    wchar_t ch_en_0[] = L"fix model::model.dll failed ,please goto the main process train a new model with the tips";
    wchar_t ch_en_1[] = L"warning";
    if (global_language == 0)
    {
        MessageBox(NULL, ch_cn_0, ch_cn_1, MB_OK);
    }
    if (global_language == 1)
    {
        MessageBox(NULL, ch_en_0, ch_en_1, MB_OK);
    }
}
// Recreates train.dll: a Python training script that fine-tunes torchvision's
// resnet18 on the data under 'train_data' (via Data from Data_Pre.py), saves the
// best weights to best_trans.mdl and the final model to model.dll.
// The file is opened in append mode, so an existing partial file is extended.
void fix_file_2()
{
    // Script lines are written with fputs, not fprintf: they are data, and the line
    // "if epoch%1==0:" contains a '%' that fprintf would parse as an (invalid)
    // conversion specification.
    static const char* const script[] = {
        "import torch\n",
        "import visdom\n",
        "from torch import optim, nn\n",
        "from utils import Flatten\n",
        "from Data_Pre import Data\n",
        "from torch.utils.data import DataLoader\n",
        "from torchvision.models import resnet18\n",
        "batchsz=32\n",
        "lr = 1e-4\n",
        "epochs =20\n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "torch.manual_seed(1234)\n",
        "train_db=Data('train_data',224,mode='train')\n",
        "val_db=Data('train_data',224,mode='val')\n",
        "test_db=Data('train_data',224,mode='test')\n",
        "train_loader=DataLoader(train_db,batch_size=batchsz,shuffle=True,num_workers=4)\n",
        "val_loader=DataLoader(val_db,batch_size=batchsz,num_workers=4)\n",
        "test_loader=DataLoader(test_db,batch_size=batchsz,num_workers=4)\n",
        "viz=visdom.Visdom()\n",
        "def evalute(model,loader):\n",
        "\tmodel.eval()\n",
        "\tcorrect=0\n",
        "\ttotal=len(loader.dataset)\n",
        "\tfor x,y in loader:\n",
        "\t\tx,y =x.to(device),y.to(device)\n",
        "\t\twith torch.no_grad():\n",
        "\t\t\tlogits=model(x)\n",
        "\t\t\tpred=logits.argmax(dim=1)\n",
        "\t\tcorrect+=torch.eq(pred,y).sum().float().item()\n",
        "\treturn correct / total\n",
        "def main():\n",
        "\ttrained_model=resnet18(pretrained=True)\n",
        "\tmodel = nn.Sequential(*list(trained_model.children())[:-1],Flatten(),nn.Linear(512,6)).to(device)\n",
        "\toptimizer=optim.Adam(model.parameters(),lr=lr)\n",
        "\tcriteon=nn.CrossEntropyLoss()\n",
        "\tbest_acc,best_epoch=0,0\n",
        "\tglobal_step=0\n",
        "\tviz.line([[0.0,0.0]],[0.],win='test',opts=dict(title='Loss on Training Data and Accuracy on Training Data',xlabel='Epochs',ylabel='Loss and Accuracy',legend=['loss','val_acc']))\n",
        "\tfor epoch in range(epochs):\n",
        "\t\tfor step,(x,y) in enumerate(train_loader):\n",
        "\t\t\tx,y = x.to(device),y.to(device)\n",
        "\t\t\tmodel.train()\n",
        "\t\t\tlogits=model(x)\n",
        "\t\t\tloss=criteon(logits,y)\n",
        "\t\t\toptimizer.zero_grad()\n",
        "\t\t\tloss.backward()\n",
        "\t\t\toptimizer.step()\n",
        "\t\t\tviz.line([[loss.item(),evalute(model,val_loader)]],[global_step],win='test',update='append')\n",
        "\t\t\tglobal_step+=1\n",
        "\t\tif epoch%1==0:\n",
        "\t\t\tprint('the '+str(epoch+1)+' epoch'+' training......')\n",
        // The helper defined above is named 'evalute'; calling 'evaluate' raised NameError.
        "\t\t\tval_acc=evalute(model,val_loader)\n",
        "\t\t\tif val_acc>best_acc:\n",
        "\t\t\t\tbest_epoch=epoch\n",
        "\t\t\t\tbest_acc=val_acc\n",
        "\t\t\t\ttorch.save(model.state_dict(),'best_trans.mdl')\n",
        "\tprint('best accuracy:',best_acc,'best epoch:',(best_epoch+1))\n",
        "\ttorch.save(model,'model.dll')\n",
        "\tprint('loading model......')\n",
        "\ttest_acc=evalute(model,test_loader)\n",
        "\tprint('test accuracy:',test_acc)\n",
        "\tprint('successfully save the best model ')\n",
        "if __name__=='__main__':\n",
        "\tmain()\n",
    };

    FILE* f1 = fopen("train.dll", "a+");
    if (f1 == NULL)
    {
        printf("\ntrain.dll could not be fixed: unable to open the file.\n");
        return;
    }
    for (const char* script_line : script)
    {
        fputs(script_line, f1);
    }
    fclose(f1);
    printf("\ntrain.dll has been fixed successfully.\n");
}
// Recreates connect.dll as an empty file; the Test_model.dll script later writes the
// predicted class name into it.
void fix_file_3()
{
    // Opening in "a+" mode creates the file; nothing needs to be written.
    FILE* fp = fopen("connect.dll", "a+");
    if (fp == NULL)
    {
        printf("\nconnect.dll could not be fixed: unable to open the file.\n");
        return;
    }
    printf("\nconnect.dll has been fixed successfully.\n");
    fclose(fp);
}
// Recreates Test_model.dll: a Python script that loads model.dll, classifies
// ./test/1.jpg into one of ('harmful','kitch','others','recyc') and writes the class
// name into connect.dll. The file is opened in append mode.
void fix_file_4()
{
    // Written with fputs: the lines are data, not printf formats.
    static const char* const script[] = {
        "import sys\n",
        "import torch\n",
        "from PIL import Image\n",
        "from torchvision import transforms\n",
        "import visdom\n",
        "from torch import optim , nn\n",
        "import os\n",
        "classes=('harmful','kitch','others','recyc')\n",
        "if torch.cuda.is_available():\n",
        "\tdevice = torch.device('cuda')\n",
        "\ttransform = transforms.Compose([\n",
        "\t\ttransforms.Resize(256),\n",
        // torchvision has no 'CenterCrops'; the CUDA branch raised AttributeError.
        "\t\ttransforms.CenterCrop(224),\n",
        "\t\ttransforms.ToTensor(),\n",
        "\t\ttransforms.Normalize(mean=[0.485,0.456,0.406],\n",
        "\t\t\t\tstd=[0.229,0.224,0.225])\n",
        "\t\t\t])\n",
        "else:\n",
        "\tdevice = torch.device('cpu')\n",
        "\ttransform=transforms.Compose([\n",
        "\t\ttransforms.Resize(256),\n",
        "\t\ttransforms.CenterCrop(224),\n",
        "\t\ttransforms.ToTensor(),\n",
        "\t\ttransforms.Normalize(mean=[0.485,0.456,0.406],\n",
        "\t\t\t\tstd=[0.229,0.224,0.225])\n",
        "\t\t\t])\n",
        "def predict(img_path):\n",
        "\tif torch.cuda.is_available():\n",
        "\t\tnet=torch.load('model.dll',map_location='cuda')\n",
        "\t\tnet=net.to(device)\n",
        // NOTE(review): a bare torch.no_grad() call has no effect; `with torch.no_grad():`
        // was probably intended.
        "\t\ttorch.no_grad()\n",
        "\t\timg=Image.open(img_path)\n",
        "\t\timg=transform(img).unsqueeze(0)\n",
        "\t\timg_=img.to(device)\n",
        "\t\toutputs=net(img_)\n",
        "\t\t_,predicted=torch.max(outputs,1)\n",
        "\telse:\n",
        "\t\tnet=torch.load('model.dll',map_location='cpu')\n",
        "\t\tnet=net.to(device)\n",
        "\t\ttorch.no_grad()\n",
        "\t\timg=Image.open(img_path)\n",
        "\t\timg=transform(img).unsqueeze(0)\n",
        "\t\timg_=img.to(device)\n",
        "\t\toutputs=net(img_)\n",
        "\t\t_,predicted=torch.max(outputs,1)\n",
        "\tprint(classes[predicted[0]])\n",
        "\tpath='connect.dll'\n",
        "\tif os.path.exists(path):\n",
        "\t\tos.remove(path)\n",
        "\telse:\n",
        "\t\tprint('successfully create the file:connect.dll')\n",
        "\tif classes[predicted[0]]=='harmful':\n",
        "\t\t#print('1')\n",
        "\t\tcreate_file(1)\n",
        "\tif classes[predicted[0]]=='kitch':\n",
        "\t\t#print('2')\n",
        "\t\tcreate_file(2)\n",
        "\tif classes[predicted[0]]=='others':\n",
        "\t\t#print('3')\n",
        "\t\tcreate_file(3)\n",
        "\tif classes[predicted[0]]=='recyc':\n",
        "\t\t#print('4')\n",
        "\t\tcreate_file(4)\n",
        "def create_file(a):\n",
        "\tif a==1:\n",
        "\t\ttry:\n",
        "\t\t\tfile=open('connect.dll','r+')\n",
        "\t\texcept FileNotFoundError:\n",
        "\t\t\tfile=open('connect.dll','a+')\n",
        "\tif a==2:\n",
        "\t\ttry:\n",
        "\t\t\tfile=open('connect.dll','r+')\n",
        "\t\texcept FileNotFoundError:\n",
        "\t\t\tfile=open('connect.dll','a+')\n",
        "\tif a==3:\n",
        "\t\ttry:\n",
        "\t\t\tfile=open('connect.dll','r+')\n",
        "\t\texcept FileNotFoundError:\n",
        "\t\t\tfile=open('connect.dll','a+')\n",
        "\tif a==4:\n",
        "\t\ttry:\n",
        "\t\t\tfile=open('connect.dll','r+')\n",
        "\t\texcept FileNotFoundError:\n",
        "\t\t\tfile=open('connect.dll','a+')\n",
        "\twrite_file(a)\n",
        "def write_file(a):\n",
        "\tif a==1:\n",
        "\t\twith open('connect.dll','a+',encoding='utf-8') as f:\n",
        "\t\t\ttext='harmful'\n",
        "\t\t\tf.write(text)\n",
        "\tif a==2:\n",
        "\t\twith open('connect.dll','a+',encoding='utf-8') as f:\n",
        "\t\t\ttext='kitch'\n",
        "\t\t\tf.write(text)\n",
        "\tif a==3:\n",
        "\t\twith open('connect.dll','a+',encoding='utf-8') as f:\n",
        "\t\t\ttext='others'\n",
        "\t\t\tf.write(text)\n",
        // 'recyc' is class 4; testing a==1 here wrote "harmfulrecyc" for class 1 and
        // nothing at all for class 4.
        "\tif a==4:\n",
        "\t\twith open('connect.dll','a+',encoding='utf-8') as f:\n",
        "\t\t\ttext='recyc'\n",
        "\t\t\tf.write(text)\n",
        "\nif __name__=='__main__':\n",
        "\tpredict('./test/1.jpg')\n",
    };

    FILE* f1 = fopen("Test_model.dll", "a+");
    if (f1 == NULL)
    {
        printf("\nTest_model.dll could not be fixed: unable to open the file.\n");
        return;
    }
    for (const char* script_line : script)
    {
        fputs(script_line, f1);
    }
    printf("\nTest_model.dll has been fixed successfully.\n");
    fclose(f1);
}
// Recreates utils.py: the Flatten module used by the training script and a
// plot_image helper based on matplotlib. The file is opened in append mode.
void fix_file_5()
{
    // Written with fputs: the lines are data, not printf formats.
    static const char* const script[] = {
        "import torch\n",
        "from torch import nn\n",
        "from matplotlib import pyplot as plt\n",
        "class Flatten(nn.Module):\n",
        "\tdef __init__(self):\n",
        "\t\tsuper(Flatten,self).__init__()\n",
        "\tdef forward(self,x):\n",
        "\t\tshape=torch.prod(torch.tensor(x.shape[1:])).item()\n",
        "\t\treturn x.view(-1,shape)\n",
        "def plot_image(img,label,name):\n",
        "\tfig=plt.figure()\n",
        "\tfor i in range(6):\n",
        "\t\tplt.subplot(2,3,i+1)\n",
        "\t\tplt.tight_layout()\n",
        "\t\tplt.imshow(img[i][0]*0.3081+0.1307,cmap='gray',interpolation='none')\n",
        "\t\tplt.title('{}:{}'.format(name,label[i].item()))\n",
        "\t\tplt.xticks([])\n",
        "\t\tplt.yticks([])\n",
        "\tplt.show()\n",
    };

    FILE* f1 = fopen("utils.py", "a+");
    if (f1 == NULL)
    {
        printf("\nutils.py could not be fixed: unable to open the file.\n");
        return;
    }
    for (const char* script_line : script)
    {
        fputs(script_line, f1);
    }
    printf("\nutils.py has been fixed successfully.\n");
    fclose(f1);
}
// Recreates readme.txt: author/contact information, the list of files the program
// needs, the python version requirement and the reported accuracy figures.
// The file is opened in append mode.
void fix_file_6()
{
    // Written with fputs, so the '%' signs below are plain text.
    static const char* const text[] = {
        "=================================================\n",
        "this program is powered by lry\n",
        "all rights reserved 2020~2021\n",
        "this file is released by the program garbage_classifation_main_progress.exe\n",
        "=================================================\n",
        "this program should have these files below:\n",
        "1 :grabage_classifation_main_progress.exe\n",
        "2 :model.dll\n",
        "3 :Test_model.dll\n",
        "4 :train.dll\n",
        "5 :utils.py\n",
        "6 :Data_Pre.py\n",
        "7 :readme.txt\n",
        "=================================================================\n",
        "if you find out the information are not match with this file,please connect with the program builder lry\n",
        "author :lry\n",
        "email address :1224137702@qq.com\n",
        "=================================================================\n",
        "if you want to use this program , please install python(>=3.8.5)\n",
        "=================================================================\n",
        "now ,the information below is very important.\n",
        "this program can have the accuracy 98.2%\n",
        "val_acc 96.4%\n",
        "=================================================================\n",
        "the way you use this program is that put the image(1.jpg)to test.model\n",
        "you can get the result in three seconds.\n",
        "and all the files have a signcode\n",
        "this versioon is better than the version before this program\n",
        "thank you for your usage\n",
    };

    FILE* fp = fopen("readme.txt", "a+");
    if (fp == NULL)
    {
        printf("\nreadme.txt could not be fixed: unable to open the file.\n");
        return;
    }
    for (const char* text_line : text)
    {
        fputs(text_line, fp);
    }
    fclose(fp);
    printf("\nreadme.txt has been fixed successfully.\n");
}
// Regenerates Data_Pre.py, the Python dataset module used for training.
// The file is opened with "a+", so any existing Data_Pre.py is appended to
// rather than replaced.
// Two defects in the generated Python were fixed:
//  * open(..., nemline='') -> newline='' (the typo raised TypeError when
//    images.csv had to be created);
//  * the test split sliced labels by the already-sliced len(self.images),
//    misaligning labels with images -> now uses len(self.labels).
// NOTE(review): if another part of the program checks Data_Pre.py against a
// stored signcode, that check must be updated to match this content.
void fix_file_7()
{
	static const char* const kLines[] = {
		"import torch\n",
		"import os,glob\n",
		"import random,csv\n",
		"from PIL import Image\n",
		"from torchvision import transforms\n",
		"from torch.utils.data import Dataset,DataLoader\n",
		"class Data(Dataset):\n",
		"\tdef __init__(self,root,resize,mode):\n",
		"\t\tsuper(Data,self).__init__()\n",
		"\t\tself.root=root\n",
		"\t\tself.resize=resize\n",
		"\t\tself.name2label={}\n",
		"\t\tfor name in sorted(os.listdir(os.path.join(root))):\n",
		"\t\t\tif not os.path.isdir(os.path.join(root,name)):\n",
		"\t\t\t\tcontinue\n",
		"\t\t\tself.name2label[name]=len(self.name2label.keys())\n",
		"\t\tself.images,self.labels=self.load_csv('images.csv')\n",
		"\t\tif mode=='train':\n",
		"\t\t\tself.images=self.images[:int(0.6*len(self.images))]\n",
		"\t\t\tself.labels=self.labels[:int(0.6*len(self.labels))]\n",
		"\t\telif mode=='val':\n",
		"\t\t\tself.images=self.images[int(0.6*len(self.images)):int(0.8*len(self.images))]\n",
		"\t\t\tself.labels=self.labels[int(0.6*len(self.labels)):int(0.8*len(self.labels))]\n",
		"\t\telse:\n",
		"\t\t\tself.images=self.images[int(0.8*len(self.images)):]\n",
		"\t\t\tself.labels=self.labels[int(0.8*len(self.labels)):]\n",
		"\tdef load_csv(self,filename):\n",
		"\t\tif not os.path.exists(os.path.join(self.root,filename)):\n",
		"\t\t\timages=[]\n",
		"\t\t\tfor name in self.name2label.keys():\n",
		"\t\t\t\timages+=glob.glob(os.path.join(self.root,name,'*.png'))\n",
		"\t\t\t\timages+=glob.glob(os.path.join(self.root,name,'*.jpg'))\n",
		"\t\t\t\timages+=glob.glob(os.path.join(self.root,name,'*.jpeg'))\n",
		"\t\t\tprint(len(images))\n",
		"\t\t\trandom.shuffle(images)\n",
		"\t\t\twith open(os.path.join(self.root,filename),mode='w',newline='') as f:\n",
		"\t\t\t\twriter=csv.writer(f)\n",
		"\t\t\t\tfor img in images:\n",
		"\t\t\t\t\tname=img.split(os.sep)[-2]\n",
		"\t\t\t\t\tlabel=self.name2label[name]\n",
		"\t\t\t\t\twriter.writerow([img,label])\n",
		"\t\t\t\tprint('write into csv into :',filename)\n",
		"\t\timages,labels=[],[]\n",
		"\t\twith open(os.path.join(self.root,filename)) as f:\n",
		"\t\t\treader=csv.reader(f)\n",
		"\t\t\tfor row in reader:\n",
		"\t\t\t\timg,label=row\n",
		"\t\t\t\tlabel=int(label)\n",
		"\t\t\t\timages.append(img)\n",
		"\t\t\t\tlabels.append(label)\n",
		"\t\tassert len(images)==len(labels)\n",
		"\t\treturn images,labels\n",
		"\tdef __len__(self):\n",
		"\t\treturn len(self.images)\n",
		"\tdef denormalize(self,x_hat):\n",
		"\t\tmean=[0.485,0.456,0.406]\n",
		"\t\tstd=[0.229,0.224,0.225]\n",
		"\t\tmean=torch.tensor(mean).unsqueeze(1).unsqueeze(1)\n",
		"\t\tstd=torch.tensor(std).unsqueeze(1).unsqueeze(1)\n",
		"\t\tx=x_hat*std+mean\n",
		"\t\treturn x\n",
		"\tdef __getitem__(self,idx):\n",
		"\t\timg,label=self.images[idx],self.labels[idx]\n",
		"\t\ttf=transforms.Compose([lambda x:Image.open(x).convert('RGB'),transforms.Resize((int(self.resize*1.25),int(self.resize*1.25))),transforms.RandomRotation(15),transforms.CenterCrop(self.resize),transforms.ToTensor(),transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])])\n",
		"\t\timg=tf(img)\n",
		"\t\tlabel=torch.tensor(label)\n",
		"\t\treturn img,label\n",
		"def main():\n",
		"\tdb=Data('train_data',64,'train')\n",
		"\tDataLoader(db,batch_size=32,shuffle=True,num_workers=8)\n",
		"if __name__=='__main__':\n",
		"\tmain()\n",
	};
	FILE* f1 = fopen("Data_Pre.py", "a+");
	// Writing through a NULL FILE* would crash; report and give up instead.
	if (f1 == NULL)
	{
		printf("\n[ERROR] cannot open Data_Pre.py for writing.\n");
		return;
	}
	for (const char* line : kLines)
	{
		fputs(line, f1);
	}
	fclose(f1);
	printf("\nData_Pre.py has been fixed successfully\n");
}
// Writes the version history to updatedata.txt, replacing any previous contents.
// The old code probed for the file and then opened it with "w+" (file present)
// or "a+" (file absent). Both paths produce exactly this content, so a single
// "w" open is equivalent. The last line deliberately has no trailing newline,
// as before.
void create_update_data()
{
	static const char* const kHistory[] = {
		"version 1.0.1\n",
		"this verson can make simple classify\n",
		"version 2.0.1\n",
		"this version can run in a mode(lost something it have,still can run)\n",
		"add a signcode to this program\n",
		"version 3.0.1\n",
		"this program add some gui interface,easier for you to use",
	};
	FILE* f2 = fopen("updatedata.txt", "w");
	// Writing through a NULL FILE* would crash.
	if (f2 == NULL)
	{
		printf("\n[ERROR] cannot open updatedata.txt for writing.\n");
		return;
	}
	for (const char* line : kHistory)
	{
		fputs(line, f2);
	}
	fclose(f2);
}
void get_the_time_now(void*)
{
int year, month, date, hour, min, sec;
wchar_t ch_year[20] = { ' ' };
wchar_t ch_month[20] = { ' ' };
wchar_t ch_date[20] = { ' ' };
wchar_t ch_hour[20] = { ' ' };
wchar_t ch_min[20] = { ' ' };
wchar_t ch_sec[20] = { ' ' };
wchar_t ch_[] = L"/";
wchar_t ch__[] = L":";
wchar_t ch_timegettime[64] = { ' ' };
wchar_t ch_mechine_time_cn[] = L"系统运行时间(ms)";
wchar_t ch_mechine_time_en[] = L"time begin from start the computer system(ms)";
_beginthread(clear, 0, NULL);
while (1)
{
DWORD t = timeGetTime();
time_t timep;
struct tm* p;
time(&timep);
p = gmtime(&timep);
year = 1900 + p->tm_year;
month =1 + p->tm_mon;
date = p->tm_mday;
hour = 8 + p->tm_hour;
min = p->tm_min;
sec = p->tm_sec;
int len0 = swprintf(ch_year, 8, L"%d", year);
int len1 = swprintf(ch_month, 8, L"%d", month);
int len2 = swprintf(ch_date, 8, L"%d", date);
int len3 = swprintf(ch_hour, 8, L"%d", hour);
int len4 = swprintf(ch_min, 8, L"%d", min);
int len5 = swprintf(ch_sec, 8, L"%d", sec);
int len6 = swprintf(ch_timegettime, 63, L"%d", t);
outtextxy(50, 450, ch_year);
outtextxy(95, 450, ch_month);
outtextxy(115, 450, ch_date);
outtextxy(135, 450, ch_hour);
outtextxy(155, 450, ch_min);
outtextxy(175, 450, ch_sec);
outtextxy(85, 450, ch_);
outtextxy(110, 450, ch_);
outtextxy(150, 450, ch__);
outtextxy(170, 450, ch__);
if (global_language == 0)
{
outtextxy(200, 450, ch_mechine_time_cn);
outtextxy(325, 450, ch_timegettime);
}
if (global_language == 1)
{
outtextxy(200, 450, ch_mechine_time_en);
outtextxy(505, 450, ch_timegettime);
}
}
}
void GUI_interface_2()
{
wchar_t ch_1_cn[] = L"退出本程序 ";
wchar_t ch_1_en[] = L"exit this program ";
wchar_t ch_2_cn[] = L"使用本程序自动分类 ";
wchar_t ch_2_en[] = L"use this program and classify auto. ";
wchar_t ch_3_cn[] = L"重新训练模型 并保存为model.dll ";
wchar_t ch_3_en[] = L"train the model again and save as model.dll ";
wchar_t ch_4_cn[] = L"开放源代码许可及部分源代码 ";
wchar_t ch_4_en[] = L"Open Sourse Lincese and part of sourse code ";
wchar_t ch_5_cn[] = L"语言设置:中文(单击此处以更改) ";
wchar_t ch_5_en[] = L"language setting:English(click here to change) ";
wchar_t ch_6_cn[] = L"自己分类 ";
wchar_t ch_6_en[] = L"classify by yourself ";
_beginthread(mouse_message_after_1, 0, NULL);
while (1)
{
if (global_language == 0)
{
outtextxy(100, 100, ch_1_cn);
outtextxy(100, 150, ch_2_cn);
outtextxy(100, 200, ch_3_cn);
outtextxy(100, 250, ch_4_cn);
outtextxy(100, 300, ch_5_cn);
outtextxy(100, 350, ch_6_cn);
}
if (global_language == 1)
{
outtextxy(100, 100, ch_1_en);
outtextxy(100, 150, ch_2_en);
outtextxy(100, 200, ch_3_en);
outtextxy(100, 250, ch_4_en);
outtextxy(100, 300, ch_5_en);
outtextxy(100, 350, ch_6_en);
}
if (global_judge_thread_mouse_message_after_1 == 0)
{
_beginthread(mouse_message_after_1, 0, NULL);
}
}
}
void mouse_message_after_1(void*)
{
global_judge_thread_mouse_message_after_1 = 1;
MOUSEMSG m;
FlushMouseMsgBuffer();
int a = 0;
int temp = 0;
while (1)
{
m = GetMouseMsg();
while (m.mkLButton)
{
Sleep(100); //防止该线程抖动
if (m.x >= 100 && m.x <= 300 && m.y >= 75 && m.y <= 115)
{
global_judge_thread_mouse_message_after_1 = 0;
exit(0);
}
if (m.x >= 100 && m.x <= 300 && m.y >= 125 && m.y <= 165)
{
FILE* f1 = fopen("model.dll", "rb");
if (f1 == NULL)
{
wchar_t ch_cn_0[] = L"本程序缺失model.dll组件,请到训练模型模块中重新生成后方可使用";
wchar_t ch_en_0[] = L"because of loss model.dll,this program cannot run this part.you can use it after remake in train part.";
wchar_t ch_cn_1[] = L"模块缺失警告";
wchar_t ch_en_1[] = L"the warning of loss of part of this program";
if (global_language == 0)
{
MessageBox(NULL, ch_cn_0, ch_cn_1, MB_OK);
}
if (global_language == 1)
{
MessageBox(NULL, ch_en_0, ch_en_1, MB_OK);
}
global_judge_thread_mouse_message_after_1 = 0;
_endthread();
}
else if (f1 != NULL)
{
wchar_t ch_cn_1_0[] = L"请确保所检测图片位于./test/1.jpg";
wchar_t ch_en_1_0[] = L"please make sure that the photograph you want to predict is in\n./test/1.jpg";
wchar_t ch_cn_1_1[] = L"提示信息";
wchar_t ch_en_1_1[] = L"the information this program provided for users";
if (global_language == 0)
{
MessageBox(NULL, ch_cn_1_0, ch_cn_1_1, MB_OK);
}
if (global_language == 1)
{
MessageBox(NULL, ch_en_1_0, ch_en_1_1, MB_OK);
}
global_judge_thread_mouse_message_after_1 = 0;
system("python model.dll");
read_file(1);
_endthread();
}
}
if (m.x >= 100 && m.x <= 300 && m.y >= 175 && m.y <= 215)
{
wchar_t ch_cn_0[] = L"程序中本模块不向普通人提供,仅开发人员可用\n开发人员请转到本程序目录下利用python生成模型";
wchar_t ch_en_0[] = L"this program is not provided for ordinary people.only available for developers.\nif you are developer please goto this dir to release model.";
wchar_t ch_cn_1[] = L"提示";
wchar_t ch_en_1[] = L"tips";
if (global_language == 0)
{
MessageBox(NULL, ch_cn_0, ch_cn_1, MB_OK);
}
if (global_language == 1)
{
MessageBox(NULL, ch_en_0, ch_en_1, MB_OK);
}
global_judge_thread_mouse_message_after_1 = 0;
_endthread();
}
if (m.x >= 100 && m.x <= 300 && m.y >= 225 && m.y <= 265)
{
wchar_t ch[] = L"code and open sourse license";
MessageBox(NULL, code_license_and_sourse_code, ch, MB_OK);
global_judge_thread_mouse_message_after_1 = 0;
_endthread();
}
if (m.x >= 100 && m.x <= 300 && m.y >= 275 && m.y <= 315)
{
if (global_language == 0)
{
temp = 1;
}
if (global_language == 1)
{
temp = 0;
}
cleardevice();
global_language = temp;
global_judge_thread_mouse_message_after_1 = 0;
_endthread();
}
if (m.x >= 100 && m.x <= 300 && m.y >= 325 && m.y <= 365)
{
wchar_t ch_cn_0[] = L"请输入所放置的垃圾种类:1.有害垃圾.2.厨余垃圾.3.可回收垃圾.4.其他垃圾";
wchar_t ch_en_0[] = L"please input which kind of garbage you put in .1.harmful 2.kitchen.3.recyclable.4.others";
wchar_t ch_cn_1[] = L"输入选项前的数字即可,否则会使本程序报错";
wchar_t ch_en_1[] = L"you can only input the number before all the choices,else it will go to some errors.";
wchar_t ch_cn_2[] = L"请输入数字";
wchar_t ch_en_2[] = L"please input the number";
wchar_t ch[10] = {'\0'};
if (global_language == 0)
{
InputBox(ch, 10, ch_cn_0, ch_cn_2, ch_cn_1, NULL, NULL, false);
}
if (global_language == 1)
{
InputBox(ch, 10, ch_en_0, ch_en_2, ch_en_1, NULL, NULL, false);
}
int b = _wtoi(ch);
FILE* f1 = fopen("connect.dll", "w+");
if (b == 1)
{
fprintf(f1, "harmful");
}
if (b == 2)
{
fprintf(f1,"kitchen");
}
if (b == 3)
{
fprintf(f1, "recycleable");
}
if (b == 4)
{
fprintf(f1, "others");
}
read_file(1);
global_judge_thread_mouse_message_after_1 = 0;
_endthread();
}
}
}
}
void read_file(int a)
{
int b = 0;
b = 0;
FILE* f1 = fopen("connect.dll", "rb");
char ch[10] = { '\0' };
fgets(ch, 10, f1);
if (ch[0] == 'h')
{
//printf("get the result :harmful\n");
wchar_t ch_cn_0[] = L"识别结果:有害垃圾";
wchar_t ch_en_0[] = L"get the result from the program:harmful garbage";
wchar_t ch_cn_1[] = L"识别结果";
wchar_t ch_en_1[] = L"the result this program predicted";
if (global_language == 0)
{
MessageBox(NULL, ch_cn_0, ch_cn_1, MB_OK);
}
if (global_language == 1)
{
MessageBox(NULL, ch_en_0, ch_en_1, MB_OK);
}
b = 1;
}
if (ch[0] == 'k')
{
//printf("get the result :kitchen\n");
wchar_t ch_cn_0[] = L"识别结果:厨余垃圾";
wchar_t ch_en_0[] = L"get the result from the program:kitchen garbage";
wchar_t ch_cn_1[] = L"识别结果";
wchar_t ch_en_1[] = L"the result this program predicted";
if (global_language == 0)
{
MessageBox(NULL, ch_cn_0, ch_cn_1, MB_OK);
}
if (global_language == 1)
{
MessageBox(NULL, ch_en_0, ch_en_1, MB_OK);
}
b = 2;
}
if (ch[0] == 'r')
{
//printf("get the result :recyclable\n");
wchar_t ch_cn_0[] = L"识别结果:可回收垃圾";
wchar_t ch_en_0[] = L"get the result from the program:recyclable garbage";
wchar_t ch_cn_1[] = L"识别结果";
wchar_t ch_en_1[] = L"the result this program predicted";
if (global_language == 0)
{
MessageBox(NULL, ch_cn_0, ch_cn_1, MB_OK);
}
if (global_language == 1)
{
MessageBox(NULL, ch_en_0, ch_en_1, MB_OK);
}
b = 3;
}
if (ch[0] == 'o')
{
//printf("get the result : others\n");
wchar_t ch_cn_0[] = L"识别结果:其他垃圾";
wchar_t ch_en_0[] = L"get the result from the program:other garbage";
wchar_t ch_cn_1[] = L"识别结果";
wchar_t ch_en_1[] = L"the result this program predicted";
if (global_language == 0)
{
MessageBox(NULL, ch_cn_0, ch_cn_1, MB_OK);
}
if (global_language == 1)
{
MessageBox(NULL, ch_en_0, ch_en_1, MB_OK);
}
b = 4;
}
//printf("%d", a);
fclose(f1);
if (b != 1 && b != 2 && b != 3 && b != 4)
{
//printf("[ERROR : read connect.dll failed] can't read at 0x0000000000000000\n");
wchar_t ch_cn_1_1[] = L"无法读取connect.dll,0x0000000000000000内存操作异常";
wchar_t ch_en_1_1[] = L"[ERROR : read connect.dll failed ] can't read at 0x0000000000000000";
wchar_t ch_cn_1_2[] = L"警告";
wchar_t ch_en_1_2[] = L"warning";
}
//create_operatble_file(a);
//show(a);
}
// Helper thread started by get_the_time_now. Once per second it blanks the strip
// x 50..195, y 430..465, where the date and time are drawn. Presumably this stops
// shorter redrawn strings leaving stale characters behind. The loop never exits,
// so the thread never returns on its own.
void clear(void*)
{
	for (;;)
	{
		Sleep(1000);
		clearrectangle(50, 430, 195, 465);
	}
}
这个程序执行完成后,会生成以下几个python脚本
1.train.dll(实际上是python文件)
import torch
import visdom
from torch import optim, nn
from utils import Flatten
from Data_Pre import Data
from torch.utils.data import DataLoader
from torchvision.models import resnet18
# Training script written out by the launcher as train.dll (it is plain Python).
# It fine-tunes an ImageNet-pretrained ResNet-18 on ./train_data and keeps the
# weights with the best validation accuracy in best_trans.mdl. The resulting model
# is stored whole in model.dll, which Test_model.dll loads.
batchsz = 32
lr = 1e-4
epochs = 20
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(1234)


def evaluate(model, loader):
    """Return the fraction of samples in `loader` that `model` classifies correctly.

    This was `evalute()` without parameters, while every caller passed
    (model, loader) and spelled it `evaluate`. An empty loader yields 0.0.
    """
    model.eval()
    total = len(loader.dataset)
    if total == 0:
        return 0.0
    correct = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            pred = model(x).argmax(dim=1)
            correct += torch.eq(pred, y).sum().float().item()
    return correct / total


def main():
    # Datasets, loaders and the visdom client are built here rather than at
    # import time. On Windows, DataLoader worker processes re-import this module,
    # so module-level objects would be rebuilt in every worker.
    train_db = Data('train_data', 224, mode='train')
    val_db = Data('train_data', 224, mode='val')
    test_db = Data('train_data', 224, mode='test')
    train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=4)
    # The keyword is `num_workers`; `num_worker` raised TypeError.
    val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=4)
    test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=4)
    viz = visdom.Visdom()

    # Size the output layer from the class folders instead of a hard-coded 6, so
    # the network can never predict an index outside Test_model.dll's `classes`.
    num_classes = len(train_db.name2label)
    trained_model = resnet18(pretrained=True)
    model = nn.Sequential(*list(trained_model.children())[:-1], Flatten(),
                          nn.Linear(512, num_classes)).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criteon = nn.CrossEntropyLoss()
    # Start below any real accuracy so the first epoch always writes best_trans.mdl.
    best_acc, best_epoch = -1.0, 0
    global_step = 0
    last_loss = 0.0
    viz.line([[0.0, 0.0]], [0.], win='test',
             opts=dict(title='Loss on Training Data and Accuracy on Training Data',
                       xlabel='Epochs', ylabel='Loss and Accuracy',
                       legend=['loss', 'val_acc']))
    for epoch in range(epochs):
        print('the ' + str(epoch + 1) + ' epoch' + ' training......')
        model.train()
        for step, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = criteon(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            last_loss = loss.item()
            global_step += 1
        # One validation pass per epoch. Plotting after every step re-ran the
        # whole validation set for each batch.
        val_acc = evaluate(model, val_loader)
        viz.line([[last_loss, val_acc]], [global_step], win='test', update='append')
        if val_acc > best_acc:
            best_epoch = epoch
            best_acc = val_acc
            torch.save(model.state_dict(), 'best_trans.mdl')
    print('best accuracy:', best_acc, 'best epoch:', (best_epoch + 1))
    print('loading model......')
    # Restore the best epoch's weights before exporting and testing. Previously
    # the last epoch's weights were saved, whatever their accuracy.
    if best_acc >= 0:
        model.load_state_dict(torch.load('best_trans.mdl', map_location=device))
    torch.save(model, 'model.dll')
    test_acc = evaluate(model, test_loader)
    print('test accuracy:', test_acc)
    print('successfully save the best model ')


if __name__ == '__main__':
    main()
2.Test_model.dll(也是python文件)
import sys
import torch
from PIL import Image
from torchvision import transforms
import visdom
from torch import optim , nn
import os
# Prediction script written out by the launcher as Test_model.dll (plain Python).
# It loads the network that train.dll saved to model.dll, classifies one image
# and writes the class name to connect.dll. The C++ side reads the first letter
# of that file.
# NOTE(review): `classes` must follow the label order of Data_Pre.Data, which
# numbers the sub-directories of train_data in sorted order. Confirm that the
# folder names sort this way.
classes = ('harmful', 'kitch', 'others', 'recyc')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Same preprocessing on CPU and GPU. The CUDA branch used to call
# transforms.CenterCrops, which does not exist.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
# Text stored in connect.dll for each create_file()/write_file() code.
_RESULT_TEXT = {1: 'harmful', 2: 'kitch', 3: 'others', 4: 'recyc'}


def _load_model():
    """Load the whole pickled model from model.dll onto `device`."""
    # model.dll is a full nn.Module saved with torch.save(model, ...). torch>=2.6
    # defaults to weights_only=True, which rejects such files. Older releases do
    # not accept the keyword, hence the fallback.
    try:
        return torch.load('model.dll', map_location=device, weights_only=False)
    except TypeError:
        return torch.load('model.dll', map_location=device)


def predict(img_path):
    """Classify the image at img_path, print the class name and store it in connect.dll."""
    net = _load_model().to(device)
    net.eval()  # use BatchNorm running statistics, not those of this single image
    img = Image.open(img_path).convert('RGB')  # drop alpha/palette channels
    img_ = transform(img).unsqueeze(0).to(device)
    # torch.no_grad() only takes effect as a context manager; the bare call did nothing.
    with torch.no_grad():
        outputs = net(img_)
    _, predicted = torch.max(outputs, 1)
    name = classes[predicted[0].item()]
    print(name)
    path = 'connect.dll'
    if os.path.exists(path):
        os.remove(path)
    else:
        print('successfully create the file:connect.dll')
    create_file(classes.index(name) + 1)


def create_file(a):
    """Store the result for code `a` (1 harmful, 2 kitch, 3 others, 4 recyc) in connect.dll.

    The previous version opened connect.dll here and never closed those handles.
    write_file already creates the file, so this simply delegates.
    """
    write_file(a)


def write_file(a):
    """Replace connect.dll with the class name for code `a`; unknown codes write nothing."""
    text = _RESULT_TEXT.get(a)
    if text is None:
        return
    # The last branch used to test `a==1` a second time. As a result 'recyc' was
    # never written and 'harmful' came out as 'harmfulrecyc'.
    with open('connect.dll', 'w', encoding='utf-8') as f:
        f.write(text)


if __name__ == '__main__':
    predict('./test/1.jpg')
3.utils.py
import torch
from torch import nn
from matplotlib import pyplot as plt
class Flatten(nn.Module):
    """Reshape a (batch, ...) tensor to (batch, features).

    train.dll places it between the ResNet-18 feature extractor and the final
    nn.Linear layer.
    """

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        # Number of elements per sample: product of every non-batch dimension.
        features = torch.tensor(x.shape[1:]).prod().item()
        return x.view(-1, features)
def plot_image(img, label, name):
    """Show the first six images of a batch in a 2x3 grayscale grid.

    Each panel shows channel 0 of one image and is titled '<name>:<label>'.
    NOTE(review): 0.3081/0.1307 look like MNIST normalisation constants, not
    the ImageNet mean/std used by Data_Pre. Confirm before relying on the colours.
    """
    plt.figure()
    for index in range(6):
        plt.subplot(2, 3, index + 1)
        plt.tight_layout()
        denormalized = img[index][0] * 0.3081 + 0.1307
        plt.imshow(denormalized, cmap='gray', interpolation='none')
        plt.title('{}:{}'.format(name, label[index].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()
4.Data_Pre.py
import torch
import os,glob
import random,csv
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset,DataLoader
class Data(Dataset):
    """Image-folder dataset for the garbage classifier.

    Every sub-directory of `root` is one class. Labels are assigned in sorted
    directory-name order. The (image path, label) list is cached in
    root/images.csv, which is shuffled once when it is first created. The list
    is split 60/20/20 into 'train', 'val' and any other mode (test).
    """

    def __init__(self, root, resize, mode):
        super(Data, self).__init__()
        self.root = root
        self.resize = resize
        self.name2label = {}
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label.keys())
        self.images, self.labels = self.load_csv('images.csv')
        n = len(self.images)  # equals len(self.labels); asserted in load_csv
        if mode == 'train':
            start, stop = 0, int(0.6 * n)
        elif mode == 'val':
            start, stop = int(0.6 * n), int(0.8 * n)
        else:
            start, stop = int(0.8 * n), n
        # Slice both lists with the same bounds. The test branch used to compute
        # the label bounds from the already-sliced image list, which misaligned
        # the labels.
        self.images = self.images[start:stop]
        self.labels = self.labels[start:stop]

    def load_csv(self, filename):
        """Return (image paths, int labels), creating root/filename on first use."""
        path = os.path.join(self.root, filename)
        if not os.path.exists(path):
            images = []
            for name in self.name2label.keys():
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    images += glob.glob(os.path.join(self.root, name, pattern))
            print(len(images))
            random.shuffle(images)
            # The csv module requires newline=''. The listing had a broken
            # `with open ...` statement and the misspelt keyword `nemline`.
            with open(path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    name = img.split(os.sep)[-2]
                    label = self.name2label[name]
                    writer.writerow([img, label])
                print('write into csv into :', filename)
        images, labels = [], []
        with open(path, newline='') as f:
            reader = csv.reader(f)
            for row in reader:
                img, label = row
                label = int(label)
                images.append(img)
                labels.append(label)
        assert len(images) == len(labels)
        return images, labels

    def __len__(self):
        return len(self.images)

    def denormalize(self, x_hat):
        """Undo the ImageNet Normalize applied in __getitem__ (per-channel std/mean)."""
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
        std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
        x = x_hat * std + mean
        return x

    def __getitem__(self, idx):
        img, label = self.images[idx], self.labels[idx]
        # Fixed the typos `self,resize`, `Totensor` and `transforms,Normalize`.
        tf = transforms.Compose([
            lambda x: Image.open(x).convert('RGB'),
            transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
            transforms.RandomRotation(15),
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        img = tf(img)
        label = torch.tensor(label)
        return img, label
def main():
    """Build the 64x64 training split of ./train_data and wrap it in a DataLoader.

    The loader is created but never iterated.
    """
    train_set = Data('train_data', 64, 'train')
    DataLoader(train_set, batch_size=32, shuffle=True, num_workers=8)


if __name__ == '__main__':
    main()