MapReduce之使用马尔可夫模型的智能邮件营销(完)
接着上一篇博文《MapReduce之使用马尔可夫模型的智能邮件营销(四)》,在这个阶段中,通过一个简单的Java程序生成马尔可夫模型的状态转移矩阵,并据此进行预测。
整个过程分为三步:
- 1、读取MapReduce生成的状态转移实例数
- 2、根据状态转移实例数得到状态转移矩阵
- 3、使用状态转移矩阵对顾客交易信息进行预测
读取MapReduce生成的状态转移实例数
这里使用正则表达式匹配目标文件(文件名以五位数字结尾,例如 part-r-00000),读取其内容并写入指定的文件中。由于代码没有在集群上运行,因此直接操作本地目录;在集群上的做法大同小异。
代码如下:
/**
 * Gathers the transition counts written by the MapReduce job into a single local file,
 * then builds and prints the state-transition matrix.
 *
 * <p>Searches the working directory ({@code user.dir}) for a directory whose name contains
 * "MarkovState". Every file in it whose name ends in five digits (e.g. {@code part-r-00000})
 * is concatenated into {@code MarkovState/MarkovState.txt}. Finally
 * {@code generateStateTransitionTable()} is called.
 *
 * @throws FileNotFoundException if no "MarkovState" directory exists in the working directory
 * @throws IOException if the output file cannot be written or an input file cannot be read
 */
public static void stateModelBuilder() throws IOException {
    File curDir = new File(System.getProperty("user.dir"));
    FileFilter fileFilter = new FileFilter() {
        @Override
        public boolean accept(File pathname) {
            return pathname.isDirectory() || pathname.isFile();
        }
    };
    File[] files = curDir.listFiles(fileFilter);
    // listFiles returns null (not an empty array) when the path is not a readable directory.
    if (files == null || files.length == 0) {
        System.out.println("目录不存在或者不是一个目录");
    } else {
        for (File file : files) {
            System.out.println(file.toString());
            // Match on the entry's own name: the full path would match every entry
            // whenever user.dir itself happens to contain "MarkovState".
            if (file.isDirectory() && file.getName().contains("MarkovState")) {
                aimDir = file.toString();
            }
        }
    }
    if (aimDir == null) {
        throw new FileNotFoundException("No directory containing \"MarkovState\" under " + curDir);
    }
    aimFiles = new File(aimDir).listFiles(fileFilter);
    // MapReduce output files end in five digits, e.g. part-r-00000.
    Pattern partFile = Pattern.compile(".*\\d{5}$");
    try (BufferedWriter out = new BufferedWriter(new FileWriter("MarkovState/MarkovState.txt"))) {
        if (aimFiles == null || aimFiles.length == 0) {
            System.out.println("目录不存在或它不是一个目录");
        } else {
            for (File aimFile : aimFiles) {
                if (!partFile.matcher(aimFile.toString()).matches()) {
                    continue;
                }
                try (BufferedReader in = new BufferedReader(new FileReader(aimFile))) {
                    String str;
                    while ((str = in.readLine()) != null) {
                        out.write(str + "\n");
                        System.out.println(str);
                    }
                }
            }
        }
    }
    System.out.println("文件创建成功");
    generateStateTransitionTable();
}
生成状态转移矩阵
使用 HashMap 建立每个状态到矩阵行、列下标的映射,以便对矩阵的行和列进行操作。首先读取文件中的状态转移实例数并赋值到矩阵中。然后使用拉普拉斯平滑对矩阵进行修正:若某一行存在计数为0的元素,则该行所有值加1。接下来计算矩阵中每个值占所在行总和的比重,即可得到状态转移矩阵,最后将其逐行格式化输出。
/**
 * Rebuilds the state-code index: each of the nine two-letter state codes is mapped
 * to its row/column position in the transition table.
 */
private void initStates(){
    String[] codes={"SL","SE","SG","ML","ME","MG","LL","LE","LG"};
    states=new HashMap<String, Integer>();
    for(int position=0;position<codes.length;position++){
        states.put(codes[position],position);
    }
}
public StateBuilder(int numberOfStates){
this.numberOfStates=numberOfStates;
table=new double[numberOfStates][numberOfStates];
initStates();
}
public StateBuilder(int numberOfStates,int scale){
this(numberOfStates);
this.scale=scale;
}
/**
 * Records the number of observed transitions from {@code fromState} to {@code toState}.
 * Any previous value in that cell is overwritten, not accumulated.
 *
 * @param fromState state code of the row (e.g. "SL")
 * @param toState state code of the column
 * @param count number of observed transitions
 * @throws IllegalArgumentException if either code is not a known state
 */
public void add(String fromState,String toState,int count){
    Integer row=states.get(fromState);
    Integer column=states.get(toState);
    // Without this check an unknown code unboxes null and throws an uninformative NPE.
    if(row==null||column==null){
        throw new IllegalArgumentException("Unknown state transition: "+fromState+" -> "+toState);
    }
    table[row][column]=count;
}
/**
 * Turns the count table into a row-stochastic transition matrix in place.
 *
 * <p>First applies Laplace smoothing: any row containing a zero count gets 1 added to
 * every cell, so no transition ends up with probability 0. Then every cell is divided by
 * its row sum.
 */
public void normalizeRows(){
    // Pass 1: Laplace smoothing for rows that contain at least one zero.
    for(int r=0;r<numberOfStates;r++){
        double[] row=table[r];
        boolean hasZero=false;
        for(int c=0;c<numberOfStates&&!hasZero;c++){
            hasZero=(row[c]==0);
        }
        if(!hasZero){
            continue;
        }
        for(int c=0;c<numberOfStates;c++){
            row[c]+=1;
        }
    }
    // Pass 2: convert counts into probabilities.
    for(int r=0;r<numberOfStates;r++){
        double total=getRowSum(r);
        double[] row=table[r];
        for(int c=0;c<numberOfStates;c++){
            row[c]/=total;
        }
    }
}
/**
 * Returns the sum of the first {@code numberOfStates} cells of the given table row,
 * added left to right.
 *
 * @param rowNumber index of the row to sum
 * @return the row total
 */
public double getRowSum(int rowNumber){
    double total=0.0;
    int column=0;
    while(column<numberOfStates){
        total+=table[rowNumber][column];
        column++;
    }
    return total;
}
/**
 * Formats one table row as comma-separated values, each rendered with {@code %.4g}.
 *
 * @param rowNumber index of the row to format
 * @return the row as text, e.g. {@code "0.1111,0.1111,..."}
 */
public String serializeRow(int rowNumber){
    StringBuilder sb=new StringBuilder();
    String separator="";
    for(int column=0;column<numberOfStates;column++){
        sb.append(separator).append(String.format("%.4g",table[rowNumber][column]));
        separator=",";
    }
    return sb.toString();
}
/**
 * Prints every table row to standard output, one line per row, using
 * {@link #serializeRow(int)}. Despite the name, nothing is written to disk.
 */
public void persistTable(){
    int row=0;
    while(row<numberOfStates){
        System.out.println(serializeRow(row));
        row++;
    }
}
/**
 * Reads {@code MarkovState/MarkovState.txt}, loads every transition count into a fresh
 * 9x9 {@code StateBuilder}, normalizes the rows into probabilities and prints the matrix.
 *
 * <p>Each line is parsed by {@code TableItem}; its expected format is defined there
 * (not visible here). Failures are reported with {@code printStackTrace()} and not
 * rethrown, so callers cannot assume the table was built.
 */
public static void generateStateTransitionTable(){
    // try-with-resources: the original never closed this reader.
    try(BufferedReader in=new BufferedReader(new FileReader("MarkovState/MarkovState.txt"))){
        List<TableItem> list=new ArrayList<TableItem>();
        StateBuilder tableBuilder=new StateBuilder(9);
        String str;
        while((str=in.readLine())!=null){
            System.out.println(str);
            list.add(new TableItem(str));
        }
        for(TableItem item:list){
            tableBuilder.add(item.getFromState(),item.getToState(),item.getCount());
        }
        tableBuilder.normalizeRows();
        tableBuilder.persistTable();
    }catch(Exception e){
        // Best effort: report (including a missing file) and let the caller continue.
        e.printStackTrace();
    }
}
预测
生成状态转移矩阵后,还需要做一点处理才能实现预测功能。首先将顾客的交易数据(日期和货物数量)按照日期先后进行排序。然后生成交易序列,其过程与MapReduce第二阶段相同。最后取交易序列的最后一个状态,在状态转移矩阵中找到该状态所在行中概率最大的下一状态,并据此计算出下一次的预测日期。
这里直接使用MapReduce第一阶段已经排好序的数据(output/part-r-00000)即可,预测代码如下:
/**
 * Builds the transition matrix, then predicts the next e-mail date for at most the first
 * five customers in {@code output/part-r-00000}.
 *
 * <p>Each line is read as {@code customerID,date1,amount1,date2,amount2,...}. Every
 * consecutive pair of transactions becomes a two-letter state: the first letter buckets the
 * day gap (S &lt; 30, M &lt; 60, else L) and the second compares the prior amount with the
 * current one (L, E, G). The last state selects a matrix row. The most probable next state
 * decides how far ahead (15/30/90 days) the next date is.
 *
 * @param args unused
 * @throws IOException if the model files or the input file cannot be read
 */
public static void main(String[] args) throws IOException {
    stateModelBuilder();
    try (BufferedReader br = new BufferedReader(new FileReader("output/part-r-00000"))) {
        String str;
        int t = 0;
        while ((str = br.readLine()) != null && t < 5) {
            t++;
            String[] tokens = str.split(",");
            if (tokens.length < 5) {
                // Fewer than two transactions means there is no transition to predict from:
                // skip this customer rather than ending the whole run.
                continue;
            }
            customerID = tokens[0];
            // Only the final transition state is needed for the prediction.
            String last = null;
            for (int i = 4; i < tokens.length; i += 2) {
                amount = Integer.parseInt(tokens[i]);
                priorAmount = Integer.parseInt(tokens[i - 2]);
                try {
                    date = DateUtil.getDateAsMilliSeconds(tokens[i - 1]);
                    lastDate = date;
                    priorDate = DateUtil.getDateAsMilliSeconds(tokens[i - 3]);
                } catch (Exception e) {
                    e.printStackTrace();
                }
                daysDiff = (date - priorDate) / aDay;
                String dd;
                if (daysDiff < 30) {
                    dd = "S";
                } else if (daysDiff < 60) {
                    dd = "M";
                } else {
                    dd = "L";
                }
                String ad;
                if (priorAmount < 0.9 * amount) {
                    ad = "L";
                } else if (priorAmount < 1.1 * amount) {
                    ad = "E";
                } else {
                    ad = "G";
                }
                last = dd + ad;
            }
            // Predict: take the most probable successor of the last state.
            int rowIndex = states.get(last);
            double maxProbability = Arrays.stream(table[rowIndex]).max().getAsDouble();
            int colIndex = indexOfMax(table[rowIndex], maxProbability);
            // Named to avoid shadowing the static `states` map used above.
            String[] stateNames = {"SL", "SE", "SG", "ML", "ME", "MG", "LL", "LE", "LG"};
            String nextState = stateNames[colIndex];
            // 15/30/90 days were derived from the sample data; adjust for real data.
            long nextDate;
            if (nextState.charAt(0) == 'S') {
                nextDate = lastDate + 15 * aDay;
            } else if (nextState.charAt(0) == 'M') {
                nextDate = lastDate + 30 * aDay;
            } else {
                nextDate = lastDate + 90 * aDay;
            }
            System.out.println("row_index : " + rowIndex + " max_col : " + maxProbability + " col_index : " + colIndex);
            System.out.println("last Date " + DateUtil.getDateAsString(lastDate) + " next Date" + DateUtil.getDateAsString(nextDate));
            System.out.println("the customerID is " + customerID + ", the next date when send e-mail is " + SIMPLE_DATE_FORMAT.format(nextDate));
        }
    }
}
/**
 * Returns the index of the first element of {@code a} that lies within 0.00005 of
 * {@code num}, or {@code a.length} if no element is that close.
 *
 * @param a values to search
 * @param num target value (normally the maximum of {@code a})
 * @return index of the first matching element, or {@code a.length}
 */
private static int indexOfMax(double[] a,double num){
    int idx=0;
    while(idx<a.length&&!(Math.abs(a[idx]-num)<=0.00005)){
        idx++;
    }
    return idx;
}
运行结果如下
完整代码如下
package com.deng.MarkovState;
import com.deng.util.DateUtil;
import org.junit.Test;
import java.io.*;
import java.text.SimpleDateFormat;
import java.util.*;
import java.util.regex.Pattern;
/**
 * Builds a Markov-chain state-transition matrix from the transition counts produced by the
 * MapReduce job, and uses it to predict the next e-mail date for each customer.
 *
 * <p>States are two-letter codes. The first letter buckets the day gap since the previous
 * transaction (S &lt; 30, M &lt; 60, else L). The second compares the previous amount with
 * the current one (L, E, G). The state index and matrix live in static fields so that the
 * static {@link #main(String[])} can use them after {@link #stateModelBuilder()} has run.
 */
public class StateBuilder {
    /** State codes in row/column order of {@link #table}. */
    private static final String[] STATE_NAMES={"SL","SE","SG","ML","ME","MG","LL","LE","LG"};
    /** MapReduce output files end in five digits, e.g. part-r-00000. */
    private static final Pattern PART_FILE_PATTERN=Pattern.compile(".*\\d{5}$");
    /** Maps each state code to its row/column index in {@link #table}. */
    private static Map<String,Integer> states=null;
    /** Holds counts after {@link #add}, and probabilities after {@link #normalizeRows()}. */
    private static double[][] table=null;
    private int numberOfStates;
    /** Stored by the two-argument constructor; not read anywhere in this class. */
    private int scale=100;
    static final String DATE_FORMAT="yyyy-MM-dd";
    static final SimpleDateFormat SIMPLE_DATE_FORMAT=new SimpleDateFormat(DATE_FORMAT);
    private static String aimDir;
    private static File[] aimFiles;
    private static String customerID;
    private static int amount,priorAmount;
    private static long date,priorDate,lastDate,daysDiff;
    /** Milliseconds in one day. */
    private static final long aDay=24*60*60*1000;

    /** Rebuilds the state-code index from {@link #STATE_NAMES}. */
    private void initStates(){
        states=new HashMap<String, Integer>();
        for(int i=0;i<STATE_NAMES.length;i++){
            states.put(STATE_NAMES[i],i);
        }
    }

    /**
     * Creates a builder with a fresh all-zero {@code numberOfStates x numberOfStates} table
     * (replacing the shared static one) and re-initializes the state index.
     *
     * @param numberOfStates number of rows and columns of the transition table
     */
    public StateBuilder(int numberOfStates){
        this.numberOfStates=numberOfStates;
        table=new double[numberOfStates][numberOfStates];
        initStates();
    }

    /**
     * Same as {@link #StateBuilder(int)} but also records {@code scale}.
     *
     * @param numberOfStates number of rows and columns of the transition table
     * @param scale value stored in the {@code scale} field
     */
    public StateBuilder(int numberOfStates,int scale){
        this(numberOfStates);
        this.scale=scale;
    }

    /**
     * Returns the table index of a state code.
     *
     * @throws IllegalArgumentException if the code is unknown (instead of an NPE from unboxing null)
     */
    private static int stateIndex(String state){
        Integer index=states.get(state);
        if(index==null){
            throw new IllegalArgumentException("Unknown state: "+state);
        }
        return index;
    }

    /**
     * Records the number of observed transitions from {@code fromState} to {@code toState}.
     * Any previous value in that cell is overwritten.
     *
     * @throws IllegalArgumentException if either code is not a known state
     */
    public void add(String fromState,String toState,int count){
        table[stateIndex(fromState)][stateIndex(toState)]=count;
    }

    /**
     * Turns the count table into a row-stochastic matrix in place.
     *
     * <p>Laplace smoothing comes first: a row containing any zero count gets 1 added to
     * every cell. After that, each cell is divided by its row sum.
     */
    public void normalizeRows(){
        for(int r=0; r<numberOfStates;r++){
            boolean gotZeroCount=false;
            for(int c=0;c<numberOfStates;c++){
                if(table[r][c]==0){
                    gotZeroCount=true;
                    break;
                }
            }
            if(gotZeroCount){
                for(int c=0;c<numberOfStates;c++){
                    table[r][c]+=1;
                }
            }
        }
        // Convert counts into transition probabilities.
        for(int r=0;r<numberOfStates;r++){
            double rowSum=getRowSum(r);
            for(int c=0;c<numberOfStates;c++){
                table[r][c]=table[r][c]/rowSum;
            }
        }
    }

    /** Returns the sum of the first {@code numberOfStates} cells of the given row. */
    public double getRowSum(int rowNumber){
        double sum=0.0;
        for(int column=0;column<numberOfStates;column++){
            sum+=table[rowNumber][column];
        }
        return sum;
    }

    /** Formats one row as comma-separated {@code %.4g} values. */
    public String serializeRow(int rowNumber){
        StringBuilder builder=new StringBuilder();
        for(int column=0;column<numberOfStates;column++){
            builder.append(String.format("%.4g",table[rowNumber][column]));
            if(column<(numberOfStates-1)){
                builder.append(",");
            }
        }
        return builder.toString();
    }

    /** Prints every row to standard output; despite the name, nothing is written to disk. */
    public void persistTable(){
        for(int row=0;row<numberOfStates;row++){
            System.out.println(serializeRow(row));
        }
    }

    /**
     * Reads {@code MarkovState/MarkovState.txt}, loads every transition count into a fresh
     * table, normalizes it and prints it. Each line is parsed by {@code TableItem}.
     * Failures are reported with {@code printStackTrace()} and not rethrown.
     */
    public static void generateStateTransitionTable(){
        try(BufferedReader in=new BufferedReader(new FileReader("MarkovState/MarkovState.txt"))){
            List<TableItem> list=new ArrayList<TableItem>();
            StateBuilder tableBuilder=new StateBuilder(STATE_NAMES.length);
            String str;
            while((str=in.readLine())!=null){
                System.out.println(str);
                list.add(new TableItem(str));
            }
            for(TableItem item:list){
                tableBuilder.add(item.getFromState(),item.getToState(),item.getCount());
            }
            tableBuilder.normalizeRows();
            tableBuilder.persistTable();
        }catch(Exception e){
            // Best effort: report (including a missing file) and let the caller continue.
            e.printStackTrace();
        }
    }

    /**
     * Concatenates the MapReduce part files from the "MarkovState" directory under
     * {@code user.dir} into {@code MarkovState/MarkovState.txt}, then builds the table.
     *
     * <p>Not annotated with {@code @Test}: JUnit 4 rejects static test methods.
     *
     * @throws FileNotFoundException if no "MarkovState" directory exists in the working directory
     * @throws IOException if a file cannot be read or written
     */
    public static void stateModelBuilder() throws IOException {
        File curDir=new File(System.getProperty("user.dir"));
        FileFilter fileFilter=new FileFilter() {
            @Override
            public boolean accept(File pathname) {
                return pathname.isDirectory()||pathname.isFile();
            }
        };
        File[] files=curDir.listFiles(fileFilter);
        // listFiles returns null (not an empty array) when the path is not a readable directory.
        if(files==null||files.length==0){
            System.out.println("目录不存在或者不是一个目录");
        }else{
            for(File file:files){
                System.out.println(file.toString());
                // Match on the entry's own name; the full path may contain "MarkovState" by accident.
                if(file.isDirectory()&&file.getName().contains("MarkovState")){
                    aimDir=file.toString();
                }
            }
        }
        if(aimDir==null){
            throw new FileNotFoundException("No directory containing \"MarkovState\" under "+curDir);
        }
        aimFiles=new File(aimDir).listFiles(fileFilter);
        try(BufferedWriter out=new BufferedWriter(new FileWriter("MarkovState/MarkovState.txt"))){
            if(aimFiles==null||aimFiles.length==0){
                System.out.println("目录不存在或它不是一个目录");
            }else{
                for(File aimFile:aimFiles){
                    if(!PART_FILE_PATTERN.matcher(aimFile.toString()).matches()){
                        continue;
                    }
                    try(BufferedReader in=new BufferedReader(new FileReader(aimFile))){
                        String str;
                        while((str=in.readLine())!=null){
                            out.write(str+"\n");
                            System.out.println(str);
                        }
                    }
                }
            }
        }
        System.out.println("文件创建成功");
        generateStateTransitionTable();
    }

    /**
     * Builds the model, then predicts the next e-mail date for at most the first five
     * customers in {@code output/part-r-00000}. Each line is read as
     * {@code customerID,date1,amount1,date2,amount2,...}; the last transition state selects
     * a matrix row, and its most probable successor sets the 15/30/90-day offset.
     *
     * @param args unused
     * @throws IOException if the model files or the input file cannot be read
     */
    public static void main(String[] args) throws IOException {
        stateModelBuilder();
        try(BufferedReader br=new BufferedReader(new FileReader("output/part-r-00000"))){
            String str;
            int t=0;
            while((str=br.readLine())!=null&&t<5){
                t++;
                String[] tokens=str.split(",");
                if(tokens.length<5){
                    // Fewer than two transactions: nothing to predict from; skip this customer.
                    continue;
                }
                customerID=tokens[0];
                // Only the final transition state is needed for the prediction.
                String last=null;
                for(int i=4;i<tokens.length;i+=2){
                    amount=Integer.parseInt(tokens[i]);
                    priorAmount=Integer.parseInt(tokens[i-2]);
                    try {
                        date=DateUtil.getDateAsMilliSeconds(tokens[i-1]);
                        lastDate=date;
                        priorDate=DateUtil.getDateAsMilliSeconds(tokens[i-3]);
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                    daysDiff=(date-priorDate)/aDay;
                    String dd;
                    if(daysDiff<30){
                        dd="S";
                    }else if(daysDiff<60){
                        dd="M";
                    }else {
                        dd="L";
                    }
                    String ad;
                    if(priorAmount<0.9*amount){
                        ad="L";
                    }else if(priorAmount<1.1*amount){
                        ad="E";
                    }else{
                        ad="G";
                    }
                    last=dd+ad;
                }
                int rowIndex=stateIndex(last);
                double maxProbability=Arrays.stream(table[rowIndex]).max().getAsDouble();
                int colIndex=indexOfMax(table[rowIndex],maxProbability);
                String nextState=STATE_NAMES[colIndex];
                // 15/30/90 days were derived from the sample data; adjust for real data.
                long nextDate;
                if(nextState.charAt(0)=='S'){
                    nextDate=lastDate+15*aDay;
                }else if(nextState.charAt(0)=='M'){
                    nextDate=lastDate+30*aDay;
                }else{
                    nextDate=lastDate+90*aDay;
                }
                System.out.println("row_index : "+rowIndex+" max_col : "+maxProbability+" col_index : "+colIndex);
                System.out.println("last Date "+DateUtil.getDateAsString(lastDate)+" next Date"+DateUtil.getDateAsString(nextDate));
                System.out.println("the customerID is "+customerID+", the next date when send e-mail is "+SIMPLE_DATE_FORMAT.format(nextDate));
            }
        }
    }

    /**
     * Returns the index of the first element of {@code a} within 0.00005 of {@code num},
     * or {@code a.length} if none is that close.
     */
    private static int indexOfMax(double[] a,double num){
        int i;
        for(i=0;i<a.length;i++){
            if(Math.abs(a[i]-num)<=0.00005){
                break;
            }
        }
        return i;
    }
}