时间序列分类算法代码分析之模式袋(Bag Of Patterns, BOP)

01:package timeseriesweka.classifiers;
02:
03:import utilities.ClassifierTools;
04:import weka.classifiers.Classifier;
05:import weka.classifiers.lazy.kNN;
06:import weka.core.Capabilities;
07:import weka.core.FastVector;
08:import weka.core.Instance;
09:import weka.core.Instances;
10:import weka.core.SparseInstance;
11:import weka.core.TechnicalInformation;
12:import timeseriesweka.filters.BagOfPatternsFilter;
13:import timeseriesweka.filters.SAX;
14:
15:/**
16: * Converts instances into Bag Of Patterns form, then gives to a 1NN 
17: * 
18: * Params: wordLength, alphabetSize, windowLength
19: * 
20: * @author James
21: */
22:public class BagOfPatterns extends AbstractClassifierWithTrainingData{
23:
24:    public Instances matrix;
25:    public kNN knn;
26:    
27:    private BagOfPatternsFilter bop;//模式袋,用于符号化时间序列
28:    private int PAA_intervalsPerWindow;//PAA窗口
29:    private int SAX_alphabetSize;//字母表大小
30:    private int windowSize;//滑动窗口
31:    
32:    private FastVector alphabet;//字母表
33:    
34:    private final boolean useParamSearch; //是否选择最优参数    
35:    /**
36:     * No params given, do parameter search
37:     */
38:    public BagOfPatterns() {
39:        this.PAA_intervalsPerWindow = -1;
40:        this.SAX_alphabetSize = -1;
41:        this.windowSize = -1;
42:
43:        knn = new kNN(); //defaults to 1NN, Euclidean distance
44:
45:        useParamSearch=true;
46:    }
47:    
48:    /**
49:     * Params given, use those only
50:     */
51:    public BagOfPatterns(int PAA_intervalsPerWindow, int SAX_alphabetSize, int windowSize) {
52:        this.PAA_intervalsPerWindow = PAA_intervalsPerWindow;
53:        this.SAX_alphabetSize = SAX_alphabetSize;
54:        this.windowSize = windowSize;
55:        
56:        bop = new BagOfPatternsFilter(PAA_intervalsPerWindow, SAX_alphabetSize, windowSize);       
57:        knn = new kNN(); //default to 1NN, Euclidean distance
58:        alphabet = SAX.getAlphabet(SAX_alphabetSize);
59:        useParamSearch=false;
60:    }
61:    
62:    public int getPAA_intervalsPerWindow() {
63:        return PAA_intervalsPerWindow;
64:    }
65:
66:    public int getSAX_alphabetSize() {
67:        return SAX_alphabetSize;
68:    }
69:
70:    public int getWindowSize() {
71:        return windowSize;
72:    }
73:    
74:    //搜索最优参数
75:    public static int[] parameterSearch(Instances data)throws Exception {
76:
77:        int minWinSize = (int)((data.numAttributes()-1) * (15.0/100.0));
78:        int maxWinSize = (int)((data.numAttributes()-1) * (36.0/100.0));
79:        //      int winInc = 1; //默认窗口增量为1
80:        //自适应窗口增量,最多10个窗口
81:        int winInc = (int)((maxWinSize - minWinSize) / 10.0); 
82:        if (winInc < 1) winInc = 1;
83:        double bestAcc = 0.0;
84:        
85:        //default to min of each para range
86:        // train set consists of one instance from each class,making it
87:        //impossible to correctly classify using nearest neighbour
88:        int bestAlpha = 2, bestWord = 2, bestWindowSize = minWinSize;
89:        //遍历所有可能的参数组合
90:        for (int alphaSize = 2; alphaSize <= 8; alphaSize++) {
91:            for (int winSize = minWinSize; winSize <= maxWinSize; winSize+=winInc) {
92:                for (int wordSize = 2; wordSize <= winSize/2; wordSize*=2) { //lin BoP suggestion
93:                    BagOfPatterns bop = new BagOfPatterns(wordSize, alphaSize, winSize);
94:                    double acc = bop.crossValidate(data); //leave-one-out without rebuiding every fold
95:                    
96:                    if (acc > bestAcc) {
97:                        bestAcc = acc;
98:                        bestAlpha = alphaSize;
99:                        bestWord = wordSize;
100:                        bestWindowSize = winSize;
101:                    }
102:                }
103:            }
104:        }
105:        //返回最优参数组合
106:        return new int[] { bestWord, bestAlpha, bestWindowSize};
107:    }
108:    
109:    
110:    @Override
111:    public void buildClassifier(final Instances data) throws Exception {
112:        trainResults.buildTime=System.currentTimeMillis();
113:        if (data.classIndex() != data.numAttributes()-1)
114:          throw new Exception("LinBoP_BuildClassifier: Class attribute);
115:        if (useParamSearch) {//若选择搜索最优参数,则调用参数搜索方法
116:            int[] params = parameterSearch(data);
117:            
118:            this.PAA_intervalsPerWindow = params[0];
119:            this.SAX_alphabetSize = params[1];
120:            this.windowSize = params[2];
121:            
122:            bop = new BagOfPatternsFilter(PAA_intervalsPerWindow, SAX_alphabetSize, windowSize);
123:            alphabet = SAX.getAlphabet(SAX_alphabetSize);
124:        }
125:        
126:        //验证参数的合法性
127:        if (PAA_intervalsPerWindow<0)
128:         throw new Exception("LinBoP_BuildClassifier: Invalid PAA word);
129:        if (PAA_intervalsPerWindow>windowSize)
130:         throw new Exception("LinBoP_BuildClassifier: Invalid PAA word);
131:        if (SAX_alphabetSize<0 || SAX_alphabetSize>10)
132:            throw new Exception("LinBoP_BuildClassifier: Invalid SAX);
133:        if (windowSize<0 || windowSize>data.numAttributes()-1)
134:          throw new Exception("LinBoP_BuildClassifier:Invalid sliding );
135:        
136:        //real work
137:        matrix = bop.process(data); //调用方法将时间序列转化为模式袋矩阵
138:        knn.buildClassifier(matrix); //训练最近邻分类器
139:        trainResults.buildTime=System.currentTimeMillis()-trainResults.buildTime;
140:        
141:    }
142:
143:    @Override
144:    public double classifyInstance(Instance instance) throws Exception {
145:        //convert to BOP form
146:        double[] hist = bop.bagToArray(bop.buildBag(instance));
147:        
148:        //stuff into Instance
149:        Instances newInsts = new Instances(matrix, 1); //copy attribute data
150:        newInsts.add(new SparseInstance(1.0, hist));
151:        
152:        return knn.classifyInstance(newInsts.firstInstance());
153:    }
154:  }
上述代码中，BagOfPatternsFilter 的作用是对时间序列进行符号化，即下文将要介绍的基于 SAX 的时间序列符号化。下面详细介绍其实现代码。
01:public class BagOfPatternsFilter extends SimpleBatchFilter {
02:
03:    public TreeSet<String> dictionary;//用于存储时间序列符号化之后的字典
04:    
05:    private final int windowSize;
06:    private final int numIntervals;
07:    private final int alphabetSize;
08:    private boolean useRealAttributes = true;
09:    //是否使用维度削减,即相邻单词如果相同则不进行统计
10:    private boolean numerosityReduction = false;      
11:    private FastVector alphabet = null;//字母表
12:    
13:    private static final long serialVersionUID = 1L;
14:
15:    public BagOfPatternsFilter(int PAA_intervalsPerWindow, int SAX_alphabetSize, int windowSize) {
16:        this.numIntervals = PAA_intervalsPerWindow;
17:        this.alphabetSize = SAX_alphabetSize;
18:        this.windowSize = windowSize;
19:        
20:        alphabet = SAX.getAlphabet(SAX_alphabetSize);
21:    }
22:    
23:    public int getWindowSize() {
24:        return numIntervals;
25:    }
26:    
27:    public int getNumIntervals() {
28:        return numIntervals;
29:    }
30:
31:    public int getAlphabetSize() {
32:        return alphabetSize;
33:    }
34:    
35:    public void useRealValuedAttributes(boolean b){
36:        useRealAttributes = b;
37:    }
38:    
39:    public void performNumerosityReduction(boolean b){
40:        numerosityReduction = b;
41:    }
42:    //建立直方图表示
43:    private HashMap<String, Integer> buildHistogram(LinkedList<double[]> patterns) {
44:        
45:        HashMap<String, Integer> hist = new HashMap<>();
46:        //将子序列转化为单词
47:        for (double[] pattern : patterns) {   
48:            //convert to string                
49:            String word = "";
50:            //将每个数值进行离散化处理,转化为字符
51:            for (int j = 0; j < pattern.length; ++j)
52:                word += (String) alphabet.get((int)pattern[j]);
53:            //统计直方图中该单词出现的次数
54:            Integer val = hist.get(word);
55:            if (val == null)
56:                val = 0;
57:            //更新直方图
58:            hist.put(word, val+1);
59:        }
60:        
61:        return hist;
62:    }
63:    //将时间序列转化为直方图表示
64:    public HashMap<String, Integer> buildBag(Instance series) throws Exception {
65:       
66:        LinkedList<double[]> patterns = new LinkedList<>();
67:        
68:        double[] prevPattern = new double[windowSize];
69:        for (int i = 0; i < windowSize; ++i) 
70:            prevPattern[i] = -1;
71:        //利用滑动窗口提取子序列
72:        for (int windowStart = 0; windowStart+windowSize-1 < series.numAttributes()-1; ++windowStart) { 
73:            double[] pattern = slidingWindow(series, windowStart);
74:            
75:            try {
76:                NormalizeCase.standardNorm(pattern);
77:            } catch(Exception e) {
78:                //throws exception if zero variance
79:                //if zero variance, all values in window the same 
80:                for (int j = 0; j < pattern.length; ++j)
81:                    pattern[j] = 0;
82:            }
83:            //利用PAA技术对时间序列子序列进行转化
84:            pattern = SAX.convertSequence(pattern, alphabetSize, numIntervals);
85:            //如果使用维度削减则进行判断
86:            if (!(numerosityReduction && identicalPattern(pattern, prevPattern)))
87:                patterns.add(pattern);
88:        }
89:        
90:        return buildHistogram(patterns);
91:    }
92:    //滑动窗口提取子序列
93:    private double[] slidingWindow(Instance series, int windowStart) {
94:        double[] window = new double[windowSize];
95:
96:        for (int i = 0; i < windowSize; ++i)
97:            window[i] = series.value(i + windowStart);
98:        
99:        return window;
100:    }
101:    //判断两个数组数值是否相同,用于维度削减
102:    private boolean identicalPattern(double[] a, double[] b) {
103:        for (int i = 0; i < a.length; ++i)
104:            if (a[i] != b[i])
105:                return false;
106:        
107:        return true;
108:    }
109:  //格式化输出格式,即转化后的实例格式
110:    @Override
111:    protected Instances determineOutputFormat(Instances inputFormat)
112:            throws Exception {
113:        
114:       
115:        for (int i = 0; i < inputFormat.numAttributes(); i++) {
116:            if (inputFormat.classIndex() != i) {
117:                if (!inputFormat.attribute(i).isNumeric()) {
118:                    throw new Exception("Non numeric attribute not allowed for BoP conversion");
119:                }
120:            }
121:        }
122:
123:        FastVector attributes = new FastVector();
124:        for (String word : dictionary) //按照字典来格式化属性
125:            attributes.add(new Attribute(word));
126:        
127:        Instances result = new Instances("BagOfPatterns_" + inputFormat.relationName(), attributes, inputFormat.numInstances());
128:        
129:        if (inputFormat.classIndex() >= 0) {	//Classification set, set class 
130:            //类属性			
131:            Attribute target = inputFormat.attribute(inputFormat.classIndex());
132:
133:            FastVector vals = new FastVector(target.numValues());
134:            for (int i = 0; i < target.numValues(); i++) {
135:                vals.addElement(target.value(i));
136:            }
137:            
138:            result.insertAttributeAt(new Attribute(inputFormat.attribute(inputFormat.classIndex()).name(), vals), result.numAttributes());
139:            result.setClassIndex(result.numAttributes() - 1);
140:        }
141: 
142:        return result;
143:    }
144:
145:    @Override
146:    public String globalInfo() {
147:        return null;
148:    }
149:    //处理的主函数
150:    @Override
151:    public Instances process(final Instances input) 
152:            throws Exception {
153:        
154:        ArrayList< HashMap<String, Integer> > bags = new ArrayList<>(input.numInstances());
155:        dictionary = new TreeSet<>();
156:        
157:        for (int i = 0; i < input.numInstances(); i++) {
158:            bags.add(buildBag(input.get(i)));//将时间序列转化为直方图表示
159:            dictionary.addAll(bags.get(i).keySet());//添加单词进字典
160:        }
161:        
162:        Instances output = determineOutputFormat(input); 
163:        //遍历词袋
164:        Iterator<HashMap<String, Integer> > it = bags.iterator();
165:        int i = 0;
166:        while (it.hasNext()) {
167:            double[] bag = bagToArray(it.next());
168:            it.remove(); 
169:            output.add(new SparseInstance(1.0, bag));
170:            output.get(i).setClassValue(input.get(i).classValue());
171:            ++i;
172:        }
173:        
174:        return output;
175:    }
176:   //将直方图按照字典顺序转化为数组
177:    public double[] bagToArray(HashMap<String, Integer> bag) {
178:        double[] res = new double[dictionary.size()];
179:            
180:        int j = 0;
181:        for (String word : dictionary) {
182:            Integer val = bag.get(word);
183:            if (val != null)
184:                res[j] += val;
185:            ++j;
186:        }
187:
188:        return res;
189:    }
190:
191:    public String getRevision() {
192:        // TODO Auto-generated method stub
193:        return null;
194:    }
195:}

 

评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值