ここから線形回帰コードを使用して、勾配降下法を使用した単純なロジスティック回帰を作成しました: Javaでの勾配降下線形回帰
ここで、ロジット変換を追加して仮説を変更することにより、ロジスティック回帰に適応させています。1 /(1 + e ^(-z))、ここでzは元のTheta ^ T*Xです。人口によるスケーリングではありません。サイズ。
結果をテストしようとすると、紛らわしい動作が発生します。独立変数(X)を乱数*期待される重みに設定し、従属(Y)をそれらの合計のロジットに設定します。したがって、y = logit(w0 * 1 + w1 * x1 + w2 * x2)。
この場合、正解に収束し、予想される重みを回復できます。しかし、明らかにYは0または1にする必要があります。しかし、切り上げまたは切り下げを行うと、収束しなくなります。
ここで私はトレーニングデータを生成しています:
@Test
public void testLogisticDescentMultiple() {
//...
//initialize Independent Xi
//Going to create test data y= 10 + .5(x1) + .33(x2)
for( int x=0;x<NUM_EXAMPLES;x++) {
independent.set(x, 0, 1); //x0 We always set this to 1 for the intercept
independent.set(x, 1, random.nextGaussian()); //x1
independent.set(x, 2, random.nextGaussian() ); //x2
}
//initialize dependent Yi
for( int x=0;x<NUM_EXAMPLES;x++) {
double val = w0 + (w1*independent.get(x,1)) + (w2*independent.get(x,2));
double logitVal = logit( val );
//Converges without this code block
if( logitVal < 0.5 ) {
logitVal = 0;
}else {
logitVal = 1;
}
//
dependent.set(x, logitVal );
}
//...
}
public static double logit( double val ) {
return( 1.0 / (1.0 + Math.exp(-val)));
}
//updated Logistic Regression
public DoubleMatrix1D logisticDescent(double alpha,
DoubleMatrix1D thetas,
DoubleMatrix2D independent,
DoubleMatrix1D dependent ) {
Algebra algebra = new Algebra();
//hypothesis is 1/( 1+ e ^ -(theta(Transposed) * X))
//start with theata(Transposed)*X
DoubleMatrix1D hypothesies = algebra.mult( independent, thetas );
//h = 1/(1+ e^-h)
hypothesies.assign(new DoubleFunction() {
@Override
public double apply (double val) {
return( logit( val ) );
}
});
//hypothesis - Y
//Now we have for each Xi, the difference between predicted by the hypothesis and the actual Yi
hypothesies.assign(dependent, Functions.minus);
//Transpose Examples(MxN) to NxM so we can matrix multiply by hypothesis Nx1
DoubleMatrix2D transposed = algebra.transpose(independent);
DoubleMatrix1D deltas = algebra.mult(transposed, hypothesies );
// thetas = thetas - (deltas*alpha) in one step
thetas.assign(deltas, Functions.minusMult(alpha));
return( thetas );
}