人工智能(AI)之朴素贝叶斯(NB)的基本实现

训练集测试集下载地址
具体的公式这里不再逐一推导,参考下图即可大致理解:算法基于条件概率(贝叶斯公式)实现,文末另附一个详细介绍的参考链接:
这里写图片描述

#include <iostream>
#include <fstream>
#include <cstring>
#include <cstdlib>
#include <sstream>
#include <string.h>
#include <set>
#include <cmath>
#include <iterator>
#include <queue>
#include <map>

using namespace std;
// Row indices into proba / newproba, one per emotion.
#define ANGER 0
#define DISGUST 1
#define FEAR 2
#define JOY 3
#define SAD 4
#define SURPRISE 5 
// Substitute factor used by compute_dis wherever a zero would otherwise be multiplied in
// (presumably meant as a Laplace-style smoothing constant -- confirm the spelling/intent).
const double lapace = 0.09;

// Shared line buffer used by the read* loaders.
char c[300];
// NOTE(review): q, map1, map2, Str1, Str2, dis_save, K and num1 are not referenced by any
// function visible in this listing.
priority_queue<double,vector<double>,greater<double> >q;
map<double,int>map1; // ascending key order
map<double,int, greater<double> >map2; // descending key order; original note: the space between the two '>' in "double> >" must not be omitted
const string Str1 = "train", Str2 = "test";
// Vocabulary: filled by get_word, pruned by clear_stopwords; a word's position in this
// sorted set is its column index in vector_old / vector2.
set<string> sets;
// vector_old[row][col] is set to true by vector_out when document `row` contains word `col`.
bool vector_old[2000][4000];
// Row-normalised copy of vector_old, computed by compute_dis.
double vector2[2000][4000];
// proba[emotion][row]: values loaded from the gold_train files by the read* loaders.
double proba[9][2000];
// newproba[emotion][row]: scores produced by compute_dis and written out by print.
double newproba[9][2000];
double dis_save[2000];
double K;
int num1=0;

// Loads the ANGER column of the gold training labels into proba[ANGER][0..].
//
// Each line of the input file is expected to hold a leading token (ignored) followed by a
// floating-point value. At most 246 rows are read. Reading stops as soon as no further line
// is available, so rows beyond the end of the file keep their previous value. A line whose
// value cannot be parsed stores 0.0, which keeps later rows aligned with their line number.
// If the file cannot be opened nothing is written.
void readanger()
{
    ifstream in("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/gold_train/anger_train.txt");
    string line;
    int i = 0;
    // Testing getline itself (rather than the stream state before the read) avoids storing
    // a bogus extra row after the last line.
    while (i < 246 && getline(in, line)){
        stringstream ss(line);
        string label; // first token is not used
        double d = 0.0; // defined value even when the line is blank or malformed
        ss >> label >> d;
        proba[ANGER][i++] = d;
    }
    // The ifstream closes itself when it goes out of scope.
}

// Loads the DISGUST column of the gold training labels into proba[DISGUST][0..].
//
// Each line of the input file is expected to hold a leading token (ignored) followed by a
// floating-point value. At most 246 rows are read. Reading stops as soon as no further line
// is available, so rows beyond the end of the file keep their previous value. A line whose
// value cannot be parsed stores 0.0, which keeps later rows aligned with their line number.
// If the file cannot be opened nothing is written.
void readdisgust()
{
    ifstream in("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/gold_train/disgust_train.txt");
    string line;
    int i = 0;
    // Testing getline itself (rather than the stream state before the read) avoids storing
    // a bogus extra row after the last line.
    while (i < 246 && getline(in, line)){
        stringstream ss(line);
        string label; // first token is not used
        double d = 0.0; // defined value even when the line is blank or malformed
        ss >> label >> d;
        proba[DISGUST][i++] = d;
    }
    // The ifstream closes itself when it goes out of scope.
}

// Loads the FEAR column of the gold training labels into proba[FEAR][0..].
//
// Each line of the input file is expected to hold a leading token (ignored) followed by a
// floating-point value. At most 246 rows are read. Reading stops as soon as no further line
// is available, so rows beyond the end of the file keep their previous value. A line whose
// value cannot be parsed stores 0.0, which keeps later rows aligned with their line number.
// If the file cannot be opened nothing is written.
void readfear()
{
    ifstream in("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/gold_train/fear_train.txt");
    string line;
    int i = 0;
    // Testing getline itself (rather than the stream state before the read) avoids storing
    // a bogus extra row after the last line.
    while (i < 246 && getline(in, line)){
        stringstream ss(line);
        string label; // first token is not used
        double d = 0.0; // defined value even when the line is blank or malformed
        ss >> label >> d;
        proba[FEAR][i++] = d;
    }
    // The ifstream closes itself when it goes out of scope.
}

// Loads the JOY column of the gold training labels into proba[JOY][0..].
//
// Each line of the input file is expected to hold a leading token (ignored) followed by a
// floating-point value. At most 246 rows are read. Reading stops as soon as no further line
// is available, so rows beyond the end of the file keep their previous value. A line whose
// value cannot be parsed stores 0.0, which keeps later rows aligned with their line number.
// If the file cannot be opened nothing is written.
void readjoy()
{
    ifstream in("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/gold_train/joy_train.txt");
    string line;
    int i = 0;
    // Testing getline itself (rather than the stream state before the read) avoids storing
    // a bogus extra row after the last line.
    while (i < 246 && getline(in, line)){
        stringstream ss(line);
        string label; // first token is not used
        double d = 0.0; // defined value even when the line is blank or malformed
        ss >> label >> d;
        proba[JOY][i++] = d;
    }
    // The ifstream closes itself when it goes out of scope.
}

// Loads the SAD column of the gold training labels into proba[SAD][0..].
//
// Each line of the input file is expected to hold a leading token (ignored) followed by a
// floating-point value. At most 246 rows are read. Reading stops as soon as no further line
// is available, so rows beyond the end of the file keep their previous value. A line whose
// value cannot be parsed stores 0.0, which keeps later rows aligned with their line number.
// If the file cannot be opened nothing is written.
void readsad()
{
    ifstream in("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/gold_train/sad_train.txt");
    string line;
    int i = 0;
    // Testing getline itself (rather than the stream state before the read) avoids storing
    // a bogus extra row after the last line.
    while (i < 246 && getline(in, line)){
        stringstream ss(line);
        string label; // first token is not used
        double d = 0.0; // defined value even when the line is blank or malformed
        ss >> label >> d;
        proba[SAD][i++] = d;
    }
    // The ifstream closes itself when it goes out of scope.
}

// Loads the SURPRISE column of the gold training labels into proba[SURPRISE][0..].
//
// Each line of the input file is expected to hold a leading token (ignored) followed by a
// floating-point value. At most 246 rows are read. Reading stops as soon as no further line
// is available, so rows beyond the end of the file keep their previous value. A line whose
// value cannot be parsed stores 0.0, which keeps later rows aligned with their line number.
// If the file cannot be opened nothing is written.
void readsurprise()
{
    ifstream in("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/gold_train/surprise_train.txt");
    string line;
    int i = 0;
    // Testing getline itself (rather than the stream state before the read) avoids storing
    // a bogus extra row after the last line.
    while (i < 246 && getline(in, line)){
        stringstream ss(line);
        string label; // first token is not used
        double d = 0.0; // defined value even when the line is blank or malformed
        ss >> label >> d;
        proba[SURPRISE][i++] = d;
    }
    // The ifstream closes itself when it goes out of scope.
}

// Fills every emotion row of proba by running the six per-emotion loaders one after
// another (anger, disgust, fear, sad, joy, surprise).
void get_proba()
{
    typedef void (*Loader)();
    const Loader loaders[] = {
        readanger, readdisgust, readfear, readsad, readjoy, readsurprise
    };
    const int loader_count = sizeof(loaders) / sizeof(loaders[0]);

    for (int idx = 0; idx < loader_count; ++idx)
    {
        loaders[idx]();
    }
}

// Builds the global vocabulary `sets` from the data-set file and writes it, one word per
// line in sorted order, to the output file.
//
// The first line of the input is skipped. On every following line the first
// whitespace-separated token is ignored and each remaining token is inserted into `sets`.
// Only real tokens are inserted: no placeholder entry is added for the ignored first token,
// so every element of `sets` corresponds to a word that occurs in the data.
// If either file cannot be opened an error is printed to cerr and `sets` is left unchanged.
void get_word()
{
    ifstream in("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/Dataset_words.txt");
    ofstream out("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/anger.txt");
    string line;
    if(in&&out)
    {
        bool is_header = true;
        while(getline(in,line))
        {
            if(is_header)
            {
                is_header = false;
                continue;
            }
            stringstream ss(line);
            string word;
            if(!(ss >> word))
            {
                continue; // blank line: nothing to collect
            }
            // `word` currently holds the first token, which is not a vocabulary word.
            // Driving the loop with the extraction itself means a failed read never
            // re-processes a stale token (the problem with looping on !eof()).
            while(ss >> word)
            {
                sets.insert(word);
            }
        }
    }else{
        cerr<<"open in or out file error"<<endl;
    }

    for(set<string>::iterator it = sets.begin();it != sets.end();++it)
    {
        // `sets` is a global; a bare-space entry is never a real word, so it is not written
        // even if other code has put one there.
        if(*it != " ")
        {
            out << *it << '\n'; // '\n' instead of endl: no flush per word
        }
    }
    // Both streams flush and close when they go out of scope.
}


// Removes every stop word listed in the stop-list file from the global vocabulary `sets`
// and echoes the stop words, one per line, to the output file.
//
// The stop list is read token by token (whitespace-separated), so several words per line
// are accepted. Blank lines and trailing whitespace produce no output line and no lookup.
// If the stop list cannot be opened, `sets` is left unchanged.
void clear_stopwords()
{
    // Opened read-only: the stop list is an input and does not need write permission.
    ifstream in("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/Foxstoplist (1).txt");
    ofstream out("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/Foxstoplistout.txt");
    string line;
    while(getline(in,line))
    {
        stringstream ss(line);
        string word;
        // Looping on the extraction itself guarantees `word` is a freshly read token.
        while(ss >> word)
        {
            out << word << '\n';
            sets.erase(word); // erase-by-key: no-op when the word is not in the vocabulary
        }
    }
    // Both streams flush and close when they go out of scope.
}


// Builds the document/word presence matrix `vector_old` from the data-set file and writes
// the header line (a label followed by every vocabulary word) to the output file.
//
// The first input line is skipped. Every following line is one document and takes the next
// row of `vector_old`, blank lines included. The first token of a line is ignored; for each
// remaining token that is in `sets`, the cell at that word's position in the sorted set is
// set to true. Documents or words that would fall outside the dimensions of `vector_old`
// are ignored instead of being written out of bounds. The matrix is only filled when both
// files opened successfully.
void vector_out()
{
    ifstream in("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/Dataset_words.txt");
    ofstream out("G:/桌面文档/学/大三上学期/第二学期/人工智能/实验/Lab 2 实验材料/vector.txt");

    // Capacity is derived from the array itself so it follows any change to its declaration.
    const int max_rows = static_cast<int>(sizeof(vector_old) / sizeof(vector_old[0]));
    const int max_cols = static_cast<int>(sizeof(vector_old[0]) / sizeof(vector_old[0][0]));

    // Word -> column lookup built once, replacing a full scan of the vocabulary per token.
    map<string,int> column_of;
    int next_col = 0;
    for(set<string>::iterator it = sets.begin(); it != sets.end(); ++it)
    {
        column_of[*it] = next_col++;
    }

    if(in&&out)
    {
        string line;
        bool is_header = true;
        int row_num = 0;
        while(getline(in,line))
        {
            if(is_header)
            {
                is_header = false;
                continue;
            }
            if(row_num >= max_rows)
            {
                cerr << "vector_out: more than " << max_rows
                     << " documents, remaining lines ignored" << endl;
                break;
            }
            stringstream ss(line);
            string word;
            ss >> word; // first token of the line is not a feature
            while(ss >> word)
            {
                map<string,int>::const_iterator hit = column_of.find(word);
                if(hit != column_of.end() && hit->second < max_cols)
                {
                    vector_old[row_num][hit->second] = true;
                }
            }
            row_num++;
        }
    }

    out << "文本编号 ";
    for(set<string>::iterator it = sets.begin(); it != sets.end(); ++it)
    {
        out << *it << " ";
    }
    // Both streams flush and close when they go out of scope.
}


void compute_dis()
{
    for (int i = 0; i < 1246; i++){
    double sum = 0;
    for (int j = 0; j < sets.size(); j++){
        if (vector_old[i][j]) 
        {
            sum++;
        }
    }       
    for (int j = 0; j < sets.size(); j++){
        vector2[i][j] = vector_old[i][j]*1.0/sum;
        //out << vector2[i][j] << " ";
    }
        //out <<endl;
    }
    double newpro_sum[1009] = {0};
    for(int mood_n = 0 ; mood_n < 6 ; mood_n++)
    {

            for(int i = 0 ; i < 1000 ; i++)
            {
                double dis_sum = 0;
                double pro_sum = 0;
                double dis;
                int pos;
                for(int j = 0 ; j < 246 ; j++)
                {
                    double same_words = 1;
                    for(int k = 0 ; k < sets.size(); k++)
                    {
                        if(vector2[i+246][k] > 0)
                        {
                            if(vector2[j][k] == 0)
                            {
                                same_words*=lapace;
                            }
                            else
                            {
                                same_words*=vector2[j][k];
                            }
                        }
                    }
                    if(proba[j] > 0)
                    {
                        pro_sum+=same_words*proba[mood_n][j];
                    }
                    else
                    {
                        pro_sum+=same_words*lapace;
                    }
                }
                newproba[mood_n][i] = pro_sum;
            }
    }

    for(int i = 0 ; i < 1000; i++){
        for(int mood_n = 0 ; mood_n < 6 ; mood_n++){
            newpro_sum[i]+=newproba[mood_n][i];
        }
    }

    for(int mood_n = 0 ; mood_n < 6 ; mood_n++)
    {
        for(int i = 0 ; i < 1000 ; i++)
        {
            //cout << newpro_sum <<endl;
            newproba[mood_n][i] = newproba[mood_n][i] / newpro_sum[i];
        }
    }

    cout << "happy" <<endl;
}

// Writes the first 1000 scores of each emotion row of newproba to that emotion's
// prediction file, one value per line.
void print()
{
    // Destination file for each emotion, indexed by the emotion constant.
    const char* target[6];
    target[ANGER]    = "C:/Users/windowos 7/Desktop/AILab/predict_test/anger_predict.txt";
    target[DISGUST]  = "C:/Users/windowos 7/Desktop/AILab/predict_test/disgust_predict.txt";
    target[FEAR]     = "C:/Users/windowos 7/Desktop/AILab/predict_test/fear_predict.txt";
    target[JOY]      = "C:/Users/windowos 7/Desktop/AILab/predict_test/joy_predict.txt";
    target[SAD]      = "C:/Users/windowos 7/Desktop/AILab/predict_test/sad_predict.txt";
    target[SURPRISE] = "C:/Users/windowos 7/Desktop/AILab/predict_test/surprise_predict.txt";

    for(int mood = 0 ; mood < 6 ; ++mood)
    {
        ofstream f;
        f.open(target[mood]);

        int row = 0;
        while(row < 1000)
        {
            f << newproba[mood][row] << endl;
            ++row;
        }
        f.close();
    }
}


// Runs the pipeline stages in order, printing the stage index (0..5) after each one
// finishes, then prints the final vocabulary size.
int main()
{
    typedef void (*Stage)();
    const Stage pipeline[] = {
        get_word, clear_stopwords, get_proba, vector_out, compute_dis, print
    };
    const int stage_count = sizeof(pipeline) / sizeof(pipeline[0]);

    for (int step = 0; step < stage_count; ++step)
    {
        pipeline[step]();
        cout << step << endl;
    }

    cout << sets.size() << endl;
    return 0;
}

NB参考网址

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值