01:package timeseriesweka.classifiers;
02:
03:import utilities.ClassifierTools;
04:import weka.classifiers.Classifier;
05:import weka.classifiers.lazy.kNN;
06:import weka.core.Capabilities;
07:import weka.core.FastVector;
08:import weka.core.Instance;
09:import weka.core.Instances;
10:import weka.core.SparseInstance;
11:import weka.core.TechnicalInformation;
12:import timeseriesweka.filters.BagOfPatternsFilter;
13:import timeseriesweka.filters.SAX;
14:
15:/**
16: * Converts instances into Bag Of Patterns form, then gives to a 1NN
17: *
18: * Params: wordLength, alphabetSize, windowLength
19: *
20: * @author James
21: */
22:public class BagOfPatterns extends AbstractClassifierWithTrainingData{
23:
24: public Instances matrix;
25: public kNN knn;
26:
27: private BagOfPatternsFilter bop;//模式袋,用于符号化时间序列
28: private int PAA_intervalsPerWindow;//PAA窗口
29: private int SAX_alphabetSize;//字母表大小
30: private int windowSize;//滑动窗口
31:
32: private FastVector alphabet;//字母表
33:
34: private final boolean useParamSearch; //是否选择最优参数
35: /**
36: * No params given, do parameter search
37: */
38: public BagOfPatterns() {
39: this.PAA_intervalsPerWindow = -1;
40: this.SAX_alphabetSize = -1;
41: this.windowSize = -1;
42:
43: knn = new kNN(); //defaults to 1NN, Euclidean distance
44:
45: useParamSearch=true;
46: }
47:
48: /**
49: * Params given, use those only
50: */
51: public BagOfPatterns(int PAA_intervalsPerWindow, int SAX_alphabetSize, int windowSize) {
52: this.PAA_intervalsPerWindow = PAA_intervalsPerWindow;
53: this.SAX_alphabetSize = SAX_alphabetSize;
54: this.windowSize = windowSize;
55:
56: bop = new BagOfPatternsFilter(PAA_intervalsPerWindow, SAX_alphabetSize, windowSize);
57: knn = new kNN(); //default to 1NN, Euclidean distance
58: alphabet = SAX.getAlphabet(SAX_alphabetSize);
59: useParamSearch=false;
60: }
61:
62: public int getPAA_intervalsPerWindow() {
63: return PAA_intervalsPerWindow;
64: }
65:
66: public int getSAX_alphabetSize() {
67: return SAX_alphabetSize;
68: }
69:
70: public int getWindowSize() {
71: return windowSize;
72: }
73:
74: //搜索最优参数
75: public static int[] parameterSearch(Instances data)throws Exception {
76:
77: int minWinSize = (int)((data.numAttributes()-1) * (15.0/100.0));
78: int maxWinSize = (int)((data.numAttributes()-1) * (36.0/100.0));
79: // int winInc = 1; //默认窗口增量为1
80: //自适应窗口增量,最多10个窗口
81: int winInc = (int)((maxWinSize - minWinSize) / 10.0);
82: if (winInc < 1) winInc = 1;
83: double bestAcc = 0.0;
84:
85: //default to min of each para range
86: // train set consists of one instance from each class,making it
87: //impossible to correctly classify using nearest neighbour
88: int bestAlpha = 2, bestWord = 2, bestWindowSize = minWinSize;
89: //遍历所有可能的参数组合
90: for (int alphaSize = 2; alphaSize <= 8; alphaSize++) {
91: for (int winSize = minWinSize; winSize <= maxWinSize; winSize+=winInc) {
92: for (int wordSize = 2; wordSize <= winSize/2; wordSize*=2) { //lin BoP suggestion
93: BagOfPatterns bop = new BagOfPatterns(wordSize, alphaSize, winSize);
94: double acc = bop.crossValidate(data); //leave-one-out without rebuiding every fold
95:
96: if (acc > bestAcc) {
97: bestAcc = acc;
98: bestAlpha = alphaSize;
99: bestWord = wordSize;
100: bestWindowSize = winSize;
101: }
102: }
103: }
104: }
105: //返回最优参数组合
106: return new int[] { bestWord, bestAlpha, bestWindowSize};
107: }
108:
109:
110: @Override
111: public void buildClassifier(final Instances data) throws Exception {
112: trainResults.buildTime=System.currentTimeMillis();
113: if (data.classIndex() != data.numAttributes()-1)
114: throw new Exception("LinBoP_BuildClassifier: Class attribute);
115: if (useParamSearch) {//若选择搜索最优参数,则调用参数搜索方法
116: int[] params = parameterSearch(data);
117:
118: this.PAA_intervalsPerWindow = params[0];
119: this.SAX_alphabetSize = params[1];
120: this.windowSize = params[2];
121:
122: bop = new BagOfPatternsFilter(PAA_intervalsPerWindow, SAX_alphabetSize, windowSize);
123: alphabet = SAX.getAlphabet(SAX_alphabetSize);
124: }
125:
126: //验证参数的合法性
127: if (PAA_intervalsPerWindow<0)
128: throw new Exception("LinBoP_BuildClassifier: Invalid PAA word);
129: if (PAA_intervalsPerWindow>windowSize)
130: throw new Exception("LinBoP_BuildClassifier: Invalid PAA word);
131: if (SAX_alphabetSize<0 || SAX_alphabetSize>10)
132: throw new Exception("LinBoP_BuildClassifier: Invalid SAX);
133: if (windowSize<0 || windowSize>data.numAttributes()-1)
134: throw new Exception("LinBoP_BuildClassifier:Invalid sliding );
135:
136: //real work
137: matrix = bop.process(data); //调用方法将时间序列转化为模式袋矩阵
138: knn.buildClassifier(matrix); //训练最近邻分类器
139: trainResults.buildTime=System.currentTimeMillis()-trainResults.buildTime;
140:
141: }
142:
143: @Override
144: public double classifyInstance(Instance instance) throws Exception {
145: //convert to BOP form
146: double[] hist = bop.bagToArray(bop.buildBag(instance));
147:
148: //stuff into Instance
149: Instances newInsts = new Instances(matrix, 1); //copy attribute data
150: newInsts.add(new SparseInstance(1.0, hist));
151:
152: return knn.classifyInstance(newInsts.firstInstance());
153: }
154: }
上述代码中的 BagOfPatternsFilter 负责将时间序列符号化,即采用基于 SAX 的时间序列符号化方法。下面详细介绍其实现代码。
01:public class BagOfPatternsFilter extends SimpleBatchFilter {
02:
03: public TreeSet<String> dictionary;//用于存储时间序列符号化之后的字典
04:
05: private final int windowSize;
06: private final int numIntervals;
07: private final int alphabetSize;
08: private boolean useRealAttributes = true;
09: //是否使用维度削减,即相邻单词如果相同则不进行统计
10: private boolean numerosityReduction = false;
11: private FastVector alphabet = null;//字母表
12:
13: private static final long serialVersionUID = 1L;
14:
15: public BagOfPatternsFilter(int PAA_intervalsPerWindow, int SAX_alphabetSize, int windowSize) {
16: this.numIntervals = PAA_intervalsPerWindow;
17: this.alphabetSize = SAX_alphabetSize;
18: this.windowSize = windowSize;
19:
20: alphabet = SAX.getAlphabet(SAX_alphabetSize);
21: }
22:
23: public int getWindowSize() {
24: return numIntervals;
25: }
26:
27: public int getNumIntervals() {
28: return numIntervals;
29: }
30:
31: public int getAlphabetSize() {
32: return alphabetSize;
33: }
34:
35: public void useRealValuedAttributes(boolean b){
36: useRealAttributes = b;
37: }
38:
39: public void performNumerosityReduction(boolean b){
40: numerosityReduction = b;
41: }
42: //建立直方图表示
43: private HashMap<String, Integer> buildHistogram(LinkedList<double[]> patterns) {
44:
45: HashMap<String, Integer> hist = new HashMap<>();
46: //将子序列转化为单词
47: for (double[] pattern : patterns) {
48: //convert to string
49: String word = "";
50: //将每个数值进行离散化处理,转化为字符
51: for (int j = 0; j < pattern.length; ++j)
52: word += (String) alphabet.get((int)pattern[j]);
53: //统计直方图中该单词出现的次数
54: Integer val = hist.get(word);
55: if (val == null)
56: val = 0;
57: //更新直方图
58: hist.put(word, val+1);
59: }
60:
61: return hist;
62: }
63: //将时间序列转化为直方图表示
64: public HashMap<String, Integer> buildBag(Instance series) throws Exception {
65:
66: LinkedList<double[]> patterns = new LinkedList<>();
67:
68: double[] prevPattern = new double[windowSize];
69: for (int i = 0; i < windowSize; ++i)
70: prevPattern[i] = -1;
71: //利用滑动窗口提取子序列
72: for (int windowStart = 0; windowStart+windowSize-1 < series.numAttributes()-1; ++windowStart) {
73: double[] pattern = slidingWindow(series, windowStart);
74:
75: try {
76: NormalizeCase.standardNorm(pattern);
77: } catch(Exception e) {
78: //throws exception if zero variance
79: //if zero variance, all values in window the same
80: for (int j = 0; j < pattern.length; ++j)
81: pattern[j] = 0;
82: }
83: //利用PAA技术对时间序列子序列进行转化
84: pattern = SAX.convertSequence(pattern, alphabetSize, numIntervals);
85: //如果使用维度削减则进行判断
86: if (!(numerosityReduction && identicalPattern(pattern, prevPattern)))
87: patterns.add(pattern);
88: }
89:
90: return buildHistogram(patterns);
91: }
92: //滑动窗口提取子序列
93: private double[] slidingWindow(Instance series, int windowStart) {
94: double[] window = new double[windowSize];
95:
96: for (int i = 0; i < windowSize; ++i)
97: window[i] = series.value(i + windowStart);
98:
99: return window;
100: }
101: //判断两个数组数值是否相同,用于维度削减
102: private boolean identicalPattern(double[] a, double[] b) {
103: for (int i = 0; i < a.length; ++i)
104: if (a[i] != b[i])
105: return false;
106:
107: return true;
108: }
109: //格式化输出格式,即转化后的实例格式
110: @Override
111: protected Instances determineOutputFormat(Instances inputFormat)
112: throws Exception {
113:
114:
115: for (int i = 0; i < inputFormat.numAttributes(); i++) {
116: if (inputFormat.classIndex() != i) {
117: if (!inputFormat.attribute(i).isNumeric()) {
118: throw new Exception("Non numeric attribute not allowed for BoP conversion");
119: }
120: }
121: }
122:
123: FastVector attributes = new FastVector();
124: for (String word : dictionary) //按照字典来格式化属性
125: attributes.add(new Attribute(word));
126:
127: Instances result = new Instances("BagOfPatterns_" + inputFormat.relationName(), attributes, inputFormat.numInstances());
128:
129: if (inputFormat.classIndex() >= 0) { //Classification set, set class
130: //类属性
131: Attribute target = inputFormat.attribute(inputFormat.classIndex());
132:
133: FastVector vals = new FastVector(target.numValues());
134: for (int i = 0; i < target.numValues(); i++) {
135: vals.addElement(target.value(i));
136: }
137:
138: result.insertAttributeAt(new Attribute(inputFormat.attribute(inputFormat.classIndex()).name(), vals), result.numAttributes());
139: result.setClassIndex(result.numAttributes() - 1);
140: }
141:
142: return result;
143: }
144:
145: @Override
146: public String globalInfo() {
147: return null;
148: }
149: //处理的主函数
150: @Override
151: public Instances process(final Instances input)
152: throws Exception {
153:
154: ArrayList< HashMap<String, Integer> > bags = new ArrayList<>(input.numInstances());
155: dictionary = new TreeSet<>();
156:
157: for (int i = 0; i < input.numInstances(); i++) {
158: bags.add(buildBag(input.get(i)));//将时间序列转化为直方图表示
159: dictionary.addAll(bags.get(i).keySet());//添加单词进字典
160: }
161:
162: Instances output = determineOutputFormat(input);
163: //遍历词袋
164: Iterator<HashMap<String, Integer> > it = bags.iterator();
165: int i = 0;
166: while (it.hasNext()) {
167: double[] bag = bagToArray(it.next());
168: it.remove();
169: output.add(new SparseInstance(1.0, bag));
170: output.get(i).setClassValue(input.get(i).classValue());
171: ++i;
172: }
173:
174: return output;
175: }
176: //将直方图按照字典顺序转化为数组
177: public double[] bagToArray(HashMap<String, Integer> bag) {
178: double[] res = new double[dictionary.size()];
179:
180: int j = 0;
181: for (String word : dictionary) {
182: Integer val = bag.get(word);
183: if (val != null)
184: res[j] += val;
185: ++j;
186: }
187:
188: return res;
189: }
190:
191: public String getRevision() {
192: // TODO Auto-generated method stub
193: return null;
194: }
195:}