4

ポイントを2D空間の 2つのパーティションに分類できる単純なモデルを構築しようとしています:

  1. いくつかのポイントとそれらが属するパーティションを指定して、モデルをトレーニングします。
  2. モデルを使用して、テスト ポイントが分類されるグループを予測 (分類)します。

残念ながら、期待どおりの答えが得られません。コードに何かが欠けているのでしょうか、それとも何か間違っていますか?

public class SimpleClassifier {

    public static class Point{
        public int x;
        public int y;

        public Point(int x,int y){
            this.x = x;
            this.y = y;
        }

        @Override
        public boolean equals(Object arg0) {
            Point p = (Point)  arg0;
            return( (this.x == p.x) &&(this.y== p.y));
        }

        @Override
        public String toString() {
            // TODO Auto-generated method stub
            return  this.x + " , " + this.y ; 
        }
    }

    public static void main(String[] args) {

        Map<Point,Integer> points = new HashMap<SimpleClassifier.Point, Integer>();

        points.put(new Point(0,0), 0);
        points.put(new Point(1,1), 0);
        points.put(new Point(1,0), 0);
        points.put(new Point(0,1), 0);
        points.put(new Point(2,2), 0);


        points.put(new Point(8,8), 1);
        points.put(new Point(8,9), 1);
        points.put(new Point(9,8), 1);
        points.put(new Point(9,9), 1);


        OnlineLogisticRegression learningAlgo = new OnlineLogisticRegression();
        learningAlgo =  new OnlineLogisticRegression(2, 2, new L1());
        learningAlgo.learningRate(50);

        //learningAlgo.alpha(1).stepOffset(1000);

        System.out.println("training model  \n" );
        for(Point point : points.keySet()){
            Vector v = getVector(point);
            System.out.println(point  + " belongs to " + points.get(point));
            learningAlgo.train(points.get(point), v);
        }

        learningAlgo.close();


        //now classify real data
        Vector v = new RandomAccessSparseVector(2);
        v.set(0, 0.5);
        v.set(1, 0.5);

        Vector r = learningAlgo.classifyFull(v);
        System.out.println(r);

        System.out.println("ans = " );
        System.out.println("no of categories = " + learningAlgo.numCategories());
        System.out.println("no of features = " + learningAlgo.numFeatures());
        System.out.println("Probability of cluster 0 = " + r.get(0));
        System.out.println("Probability of cluster 1 = " + r.get(1));

    }

    public static Vector getVector(Point point){
        Vector v = new DenseVector(2);
        v.set(0, point.x);
        v.set(1, point.y);

        return v;
    }
}

出力:

ans = 
no of categories = 2
no of features = 2
Probability of cluster 0 = 3.9580985042775296E-4
Probability of cluster 1 = 0.9996041901495722

99% の確率で、出力はより多くの確率を示しますcluster 1なんで?

4

2 に答える 2

5

問題は、常に 1 であるバイアス (インターセプト) 項を含めなかったことです。バイアス項 (1) をポイント クラスに追加する必要があります。

これは、機械学習の経験者の多くが犯す非常に基本的な間違いです。理論の学習に時間を割くのは良い考えかもしれません。Andrew Ng の講義は、学ぶのに最適な場所の 1 つです。

コードで期待どおりの出力が得られるようにするには、次の点を変更する必要があります。

  1. バイアス項を追加。
  2. 学習パラメータが高すぎました。10に変更しました

これで、クラス 0 の P(0)=0.9999 が得られます。

正しい結果が得られる完全な動作例を次に示します。

import java.util.HashMap;
import java.util.Map;

import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;


class Point{
    public int x;
    public int y;

    public Point(int x,int y){
        this.x = x;
        this.y = y;
    }

    @Override
    public boolean equals(Object arg0) {
        Point p = (Point)  arg0;
        return( (this.x == p.x) &&(this.y== p.y));
    }

    @Override
    public String toString() {
        return  this.x + " , " + this.y ; 
    }
}

public class SimpleClassifier {



    public static void main(String[] args) {

            Map<Point,Integer> points = new HashMap<Point, Integer>();

            points.put(new Point(0,0), 0);
            points.put(new Point(1,1), 0);
            points.put(new Point(1,0), 0);
            points.put(new Point(0,1), 0);
            points.put(new Point(2,2), 0);

            points.put(new Point(8,8), 1);
            points.put(new Point(8,9), 1);
            points.put(new Point(9,8), 1);
            points.put(new Point(9,9), 1);


            OnlineLogisticRegression learningAlgo = new OnlineLogisticRegression();
            learningAlgo =  new OnlineLogisticRegression(2, 3, new L1());
            learningAlgo.lambda(0.1);
            learningAlgo.learningRate(10);

            System.out.println("training model  \n" );

            for(Point point : points.keySet()){

                Vector v = getVector(point);
                System.out.println(point  + " belongs to " + points.get(point));
                learningAlgo.train(points.get(point), v);
            }

            learningAlgo.close();

            Vector v = new RandomAccessSparseVector(3);
            v.set(0, 0.5);
            v.set(1, 0.5);
            v.set(2, 1);

            Vector r = learningAlgo.classifyFull(v);
            System.out.println(r);

            System.out.println("ans = " );
            System.out.println("no of categories = " + learningAlgo.numCategories());
            System.out.println("no of features = " + learningAlgo.numFeatures());
            System.out.println("Probability of cluster 0 = " + r.get(0));
            System.out.println("Probability of cluster 1 = " + r.get(1));

    }

    public static Vector getVector(Point point){
        Vector v = new DenseVector(3);
        v.set(0, point.x);
        v.set(1, point.y);
        v.set(2, 1);
        return v;
    }
}

出力:

2 , 2 belongs to 0
1 , 0 belongs to 0
9 , 8 belongs to 1
8 , 8 belongs to 1
0 , 1 belongs to 0
0 , 0 belongs to 0
1 , 1 belongs to 0
9 , 9 belongs to 1
8 , 9 belongs to 1
{0:2.470723149516907E-6,1:0.9999975292768505}
ans = 
no of categories = 2
no of features = 3
Probability of cluster 0 = 2.470723149516907E-6
Probability of cluster 1 = 0.9999975292768505

クラス Point を SimpleClassifier クラスの外側に定義したことに注意してください。ただし、これはコードを読みやすくするためだけのものであり、必須ではありません。

学習率を変更するとどうなるか見てみましょう。学習率の選択方法を理解するには、交差検証に関する注意事項をお読みください。

Learning Rate => Probability of cluster 0
0.001 => 0.4991116089
0.01 => 0.492481585
0.1 => 0.469961472
1 => 0.5327745322
10 => 0.9745740393
100 => 0
1000 => 0

学習率の選択:

  1. 確率的勾配降下法を実行するのが一般的です。固定の学習率 α から開始して、アルゴリズムの実行時に学習率 α をゼロまでゆっくりと減少させることで、パラメーターがグローバルに収束することを確認することもできます。ではなく、単に最小値の周りで振動します。
  2. この場合のように、定数 α を使用すると、初期選択を行い、勾配降下を実行してコスト関数を観察し、それに応じて学習率を調整できます。ここで説明されています
于 2014-09-16T15:24:10.890 に答える