/**
 * Perceptron for binary classification with labels {@code 1} and {@code -1}.
 *
 * <p>Data is read from a delimited text file. In each row, every column except the last is a
 * numeric feature and the last column is the label; a label of {@code 0} is mapped to {@code -1}.
 * The model is a weight vector holding one weight per feature, followed by a bias term.
 *
 * @author ysh 1208706282
 */
public class Perceptron {
    /** Learned model: one weight per feature, then the bias in the last slot. Set by {@link #train}. */
    double[] weight;
    /** Samples read by {@link #loadData}. */
    List<Sample> samples;

    /** One labelled example. */
    static class Sample {
        /** Label from the last column, with 0 replaced by -1; training expects +1 / -1. */
        Double label;
        /** Feature values in column order. */
        List<Double> feature;
    }

    /**
     * Reads samples from a delimited text file, replacing any previously loaded samples.
     *
     * <p>Blank lines are skipped. Every other line must hold the same number of feature columns
     * followed by a label column. {@link #samples} is only replaced if the whole file loads.
     *
     * @param path  file to read (decoded with the platform default charset)
     * @param regex column delimiter, passed to {@link String#split(String)}
     * @throws Exception if the file cannot be read, a value is not a number
     *                   ({@link NumberFormatException}), or a line has a different number of
     *                   features than the first one ({@link IllegalArgumentException})
     */
    public void loadData(String path, String regex) throws Exception {
        List<Sample> loaded = new ArrayList<Sample>();
        // try-with-resources closes the file even if parsing fails part-way through.
        try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
            String line;
            int lineNo = 0;
            while ((line = reader.readLine()) != null) {
                lineNo++;
                if (line.trim().isEmpty()) {
                    continue; // e.g. an empty line at the end of the file
                }
                String[] splits = line.split(regex);
                Sample sample = new Sample();
                sample.feature = new ArrayList<Double>(splits.length - 1);
                for (int i = 0; i < splits.length - 1; i++) {
                    sample.feature.add(Double.valueOf(splits[i]));
                }
                // All rows must agree on dimensionality, otherwise classify() would mix up
                // feature weights and the bias.
                if (!loaded.isEmpty() && sample.feature.size() != loaded.get(0).feature.size()) {
                    throw new IllegalArgumentException("line " + lineNo + ": expected "
                            + loaded.get(0).feature.size() + " features but found "
                            + sample.feature.size());
                }
                double label = Double.parseDouble(splits[splits.length - 1]);
                sample.label = (label == 0) ? -1.0 : label; // label 0 becomes -1
                loaded.add(sample);
            }
        }
        samples = loaded;
    }

    /**
     * Computes the raw perceptron score {@code w·x + b} for a sample.
     *
     * @param sample sample to score
     * @param weight feature weights followed by the bias in the last slot
     * @return the score; {@link #test} treats a score {@code >= 0} as predicting {@code +1}
     * @throws IllegalArgumentException if the sample does not have {@code weight.length - 1} features
     */
    public double classify(Sample sample, double[] weight) {
        checkDimensions(sample, weight);
        double ret = 0;
        for (int i = 0; i < sample.feature.size(); i++) {
            ret += sample.feature.get(i) * weight[i];
        }
        ret += weight[weight.length - 1]; // bias
        return ret;
    }

    /**
     * Applies one perceptron update in place: {@code w += eta*y*x} and {@code b += eta*y}.
     *
     * @param sample misclassified sample supplying {@code x} and {@code y}
     * @param weight feature weights followed by the bias; modified in place
     * @param eta    learning rate
     * @throws IllegalArgumentException if the sample does not have {@code weight.length - 1} features
     */
    public void updateWeight(Sample sample, double[] weight, double eta) {
        checkDimensions(sample, weight);
        for (int i = 0; i < sample.feature.size(); i++) {
            weight[i] += eta * sample.label * sample.feature.get(i);
        }
        weight[weight.length - 1] += eta * sample.label;
    }

    /** Throws if {@code sample} does not have exactly one feature per non-bias weight. */
    private static void checkDimensions(Sample sample, double[] weight) {
        if (sample.feature.size() != weight.length - 1) {
            throw new IllegalArgumentException("sample has " + sample.feature.size()
                    + " features but the model expects " + (weight.length - 1));
        }
    }

    /**
     * Trains {@link #weight} from zero with the perceptron rule.
     *
     * <p>Runs at most {@code iters} passes over the samples and stops early after the first pass
     * with no mistakes. Progress is printed to standard output.
     *
     * @param iters maximum number of passes over the data
     * @param eta   learning rate
     * @throws IllegalStateException if no samples have been loaded
     */
    public void train(int iters, double eta) {
        if (samples == null || samples.isEmpty()) {
            throw new IllegalStateException("no samples loaded; call loadData first");
        }
        int len = samples.get(0).feature.size();
        weight = new double[len + 1]; // Java zero-initialises arrays: weights and bias start at 0
        for (int iter = 0; iter < iters; iter++) {
            int count = 0;
            for (Sample sample : samples) {
                // A score of exactly 0 counts as a mistake, so the all-zero start always updates.
                if (sample.label * classify(sample, weight) <= 0) {
                    updateWeight(sample, weight, eta);
                    count++;
                }
            }
            if (count == 0) {
                System.out.println("already complete");
                break;
            }
            System.out.println("iter " + iter + " count " + count);
        }
    }

    /**
     * Prints each sample's score and label, then the fraction classified correctly.
     *
     * <p>A score {@code >= 0} counts as predicting {@code +1}; anything else counts as {@code -1}.
     *
     * @throws IllegalStateException if no samples are loaded or the model has not been trained
     */
    public void test() {
        if (samples == null || samples.isEmpty()) {
            throw new IllegalStateException("no samples loaded; call loadData first");
        }
        if (weight == null) {
            throw new IllegalStateException("model not trained; call train first");
        }
        int count = 0;
        for (Sample sample : samples) {
            double value = classify(sample, weight);
            System.out.println(value + "," + sample.label);
            boolean predictedPositive = value >= 0;
            if (predictedPositive == (sample.label > 0)) {
                count++;
            }
        }
        System.out.println("right rate: " + count * 1.0 / samples.size());
    }

    /**
     * Loads a data set, trains for up to 100 passes with learning rate 0.1, and reports accuracy
     * on the same data.
     *
     * @param args optional: {@code args[0]} is the data file (default {@code F:/contest/iris.csv}),
     *             {@code args[1]} is the column delimiter regex (default {@code ,})
     * @throws Exception if the data cannot be loaded
     */
    public static void main(String[] args) throws Exception {
        String path = args.length > 0 ? args[0] : "F:/contest/iris.csv";
        String regex = args.length > 1 ? args[1] : ",";
        Perceptron per = new Perceptron();
        per.loadData(path, regex);
        per.train(100, 0.1);
        per.test();
    }
}
/**
 * Perceptron for binary classification with labels {@code 1} and {@code -1}.
 *
 * <p>Data is read from a delimited text file. In each row, every column except the last is a
 * numeric feature and the last column is the label; a label of {@code 0} is mapped to {@code -1}.
 * The model is a weight vector holding one weight per feature, followed by a bias term.
 *
 * @author ysh 1208706282
 */
public class Perceptron {
    /** Learned model: one weight per feature, then the bias in the last slot. Set by {@link #train}. */
    double[] weight;
    /** Samples read by {@link #loadData}. */
    List<Sample> samples;

    /** One labelled example. */
    static class Sample {
        /** Label from the last column, with 0 replaced by -1; training expects +1 / -1. */
        Double label;
        /** Feature values in column order. */
        List<Double> feature;
    }

    /**
     * Reads samples from a delimited text file, replacing any previously loaded samples.
     *
     * <p>Blank lines are skipped. Every other line must hold the same number of feature columns
     * followed by a label column. {@link #samples} is only replaced if the whole file loads.
     *
     * @param path  file to read (decoded with the platform default charset)
     * @param regex column delimiter, passed to {@link String#split(String)}
     * @throws Exception if the file cannot be read, a value is not a number
     *                   ({@link NumberFormatException}), or a line has a different number of
     *                   features than the first one ({@link IllegalArgumentException})
     */
    public void loadData(String path, String regex) throws Exception {
        List<Sample> loaded = new ArrayList<Sample>();
        // try-with-resources closes the file even if parsing fails part-way through.
        try (BufferedReader reader = new BufferedReader(new FileReader(path))) {
            String line;
            int lineNo = 0;
            while ((line = reader.readLine()) != null) {
                lineNo++;
                if (line.trim().isEmpty()) {
                    continue; // e.g. an empty line at the end of the file
                }
                String[] splits = line.split(regex);
                Sample sample = new Sample();
                sample.feature = new ArrayList<Double>(splits.length - 1);
                for (int i = 0; i < splits.length - 1; i++) {
                    sample.feature.add(Double.valueOf(splits[i]));
                }
                // All rows must agree on dimensionality, otherwise classify() would mix up
                // feature weights and the bias.
                if (!loaded.isEmpty() && sample.feature.size() != loaded.get(0).feature.size()) {
                    throw new IllegalArgumentException("line " + lineNo + ": expected "
                            + loaded.get(0).feature.size() + " features but found "
                            + sample.feature.size());
                }
                double label = Double.parseDouble(splits[splits.length - 1]);
                sample.label = (label == 0) ? -1.0 : label; // label 0 becomes -1
                loaded.add(sample);
            }
        }
        samples = loaded;
    }

    /**
     * Computes the raw perceptron score {@code w·x + b} for a sample.
     *
     * @param sample sample to score
     * @param weight feature weights followed by the bias in the last slot
     * @return the score; {@link #test} treats a score {@code >= 0} as predicting {@code +1}
     * @throws IllegalArgumentException if the sample does not have {@code weight.length - 1} features
     */
    public double classify(Sample sample, double[] weight) {
        checkDimensions(sample, weight);
        double ret = 0;
        for (int i = 0; i < sample.feature.size(); i++) {
            ret += sample.feature.get(i) * weight[i];
        }
        ret += weight[weight.length - 1]; // bias
        return ret;
    }

    /**
     * Applies one perceptron update in place: {@code w += eta*y*x} and {@code b += eta*y}.
     *
     * @param sample misclassified sample supplying {@code x} and {@code y}
     * @param weight feature weights followed by the bias; modified in place
     * @param eta    learning rate
     * @throws IllegalArgumentException if the sample does not have {@code weight.length - 1} features
     */
    public void updateWeight(Sample sample, double[] weight, double eta) {
        checkDimensions(sample, weight);
        for (int i = 0; i < sample.feature.size(); i++) {
            weight[i] += eta * sample.label * sample.feature.get(i);
        }
        weight[weight.length - 1] += eta * sample.label;
    }

    /** Throws if {@code sample} does not have exactly one feature per non-bias weight. */
    private static void checkDimensions(Sample sample, double[] weight) {
        if (sample.feature.size() != weight.length - 1) {
            throw new IllegalArgumentException("sample has " + sample.feature.size()
                    + " features but the model expects " + (weight.length - 1));
        }
    }

    /**
     * Trains {@link #weight} from zero with the perceptron rule.
     *
     * <p>Runs at most {@code iters} passes over the samples and stops early after the first pass
     * with no mistakes. Progress is printed to standard output.
     *
     * @param iters maximum number of passes over the data
     * @param eta   learning rate
     * @throws IllegalStateException if no samples have been loaded
     */
    public void train(int iters, double eta) {
        if (samples == null || samples.isEmpty()) {
            throw new IllegalStateException("no samples loaded; call loadData first");
        }
        int len = samples.get(0).feature.size();
        weight = new double[len + 1]; // Java zero-initialises arrays: weights and bias start at 0
        for (int iter = 0; iter < iters; iter++) {
            int count = 0;
            for (Sample sample : samples) {
                // A score of exactly 0 counts as a mistake, so the all-zero start always updates.
                if (sample.label * classify(sample, weight) <= 0) {
                    updateWeight(sample, weight, eta);
                    count++;
                }
            }
            if (count == 0) {
                System.out.println("already complete");
                break;
            }
            System.out.println("iter " + iter + " count " + count);
        }
    }

    /**
     * Prints each sample's score and label, then the fraction classified correctly.
     *
     * <p>A score {@code >= 0} counts as predicting {@code +1}; anything else counts as {@code -1}.
     *
     * @throws IllegalStateException if no samples are loaded or the model has not been trained
     */
    public void test() {
        if (samples == null || samples.isEmpty()) {
            throw new IllegalStateException("no samples loaded; call loadData first");
        }
        if (weight == null) {
            throw new IllegalStateException("model not trained; call train first");
        }
        int count = 0;
        for (Sample sample : samples) {
            double value = classify(sample, weight);
            System.out.println(value + "," + sample.label);
            boolean predictedPositive = value >= 0;
            if (predictedPositive == (sample.label > 0)) {
                count++;
            }
        }
        System.out.println("right rate: " + count * 1.0 / samples.size());
    }

    /**
     * Loads a data set, trains for up to 100 passes with learning rate 0.1, and reports accuracy
     * on the same data.
     *
     * @param args optional: {@code args[0]} is the data file (default {@code F:/contest/iris.csv}),
     *             {@code args[1]} is the column delimiter regex (default {@code ,})
     * @throws Exception if the data cannot be loaded
     */
    public static void main(String[] args) throws Exception {
        String path = args.length > 0 ? args[0] : "F:/contest/iris.csv";
        String regex = args.length > 1 ? args[1] : ",";
        Perceptron per = new Perceptron();
        per.loadData(path, regex);
        per.train(100, 0.1);
        per.test();
    }
}