
I'm attempting to implement a very simple 1-dimensional gradient descent algorithm. The code I have does not work at all. Basically, depending on my alpha value, the final parameters either end up wildly huge (around 70 digits) or essentially zero (~0.000). I feel like gradient descent should not be nearly this sensitive to alpha (I generate small data in [0.0, 1.0], but I would think the gradient itself should account for the scale of the data, no?).

Here is the code:

#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <vector>

using namespace std;

double a, b;
double theta0 = 0.0, theta1 = 0.0;

double myrand() {
  return double(rand()) / RAND_MAX;
}

double f(double x) {
  double y = a * x + b;
  y *= 0.1 * (myrand() - 0.5);  // +/- 5% noise

  return y;
}

double h(double x) {
  return theta1 * x + theta0;
}

int main() {
  srand(time(NULL));
  a = myrand();
  b = myrand();

  printf("set parameters: a = %lf, b = %lf\n", a, b);

  int N = 100;

  vector<double> xs(N);
  vector<double> ys(N);
  for (int i = 0; i < N; ++i) {
    xs[i] = myrand();
    ys[i] = f(xs[i]);
  }

  double sensitivity = 0.008;
  double d0, d1;

  for (int n = 0; n < 100; ++n) {
    d0 = d1 = 0.0;
    for (int i = 0; i < N; ++i) {
      d0 += h(xs[i]) - ys[i];
      d1 += (h(xs[i]) - ys[i]) * xs[i];
    }

    theta0 -= sensitivity * d0;
    theta1 -= sensitivity * d1;

    printf("theta0: %lf, theta1: %lf\n", theta0, theta1);
  }

  return 0;
}

2 Answers

I had a quick look at your implementation and it looks fine to me.

The code I have does not work at all.

I wouldn't say that. It seems to behave correctly for small enough values of sensitivity, which is a value you essentially have to "guess", and that is how gradient descent is supposed to work.
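
One detail that makes sensitivity easier to reason about in your particular code: d0 and d1 are sums over all N = 100 samples, so the effective step is really sensitivity * N. A common variant (just a suggestion, reusing the variables from your posted main()) is to step by the mean gradient instead of the sum, for example:

  for (int n = 0; n < 100; ++n) {
    d0 = d1 = 0.0;
    for (int i = 0; i < N; ++i) {
      d0 += h(xs[i]) - ys[i];            // sum of residuals
      d1 += (h(xs[i]) - ys[i]) * xs[i];  // residuals weighted by x
    }
    theta0 -= sensitivity * d0 / N;  // divide by N: step uses the mean gradient,
    theta1 -= sensitivity * d1 / N;  // so it no longer scales with the sample count
  }

With that change, a given sensitivity behaves the same whether you fit 100 points or 10,000.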

I feel like gradient descent should not be nearly this sensitive to alpha

If you struggle to visualize that, remember that you are using gradient descent to minimize the cost function of linear regression, which is quadratic in the parameters. If you plot the cost function you will see why the learning rate is so sensitive in these cases: intuitively, if the parabola is narrow (high curvature), the algorithm converges more quickly, which is good; but the same curvature means a slightly larger learning rate makes each step overshoot the minimum by more than it corrects, so the range of "safe" learning rates is narrow and the algorithm can easily diverge if you are not careful.
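
If it helps, here is a tiny standalone example (not your code, just an illustration of the point above) that runs gradient descent on the one-dimensional quadratic J(t) = c * t^2. Its gradient is 2 * c * t, so each update multiplies t by (1 - 2 * alpha * c), and the iteration converges only while alpha < 1 / c. With c = 50 the threshold is 0.02, so alpha = 0.008 converges while alpha = 0.03 diverges:

#include <cstdio>

// Gradient descent on J(t) = c * t * t, whose gradient is 2 * c * t.
// Each update is t <- t * (1 - 2 * alpha * c), so the iteration converges
// only while |1 - 2 * alpha * c| < 1, i.e. alpha < 1 / c.
double descend(double alpha, double c, int steps) {
  double t = 1.0;                  // arbitrary starting point
  for (int n = 0; n < steps; ++n)
    t -= alpha * 2.0 * c * t;      // one gradient step
  return t;
}

int main() {
  const double c = 50.0;           // a fairly narrow parabola
  printf("alpha = 0.008 -> t = %g\n", descend(0.008, c, 100));  // shrinks toward 0
  printf("alpha = 0.030 -> t = %g\n", descend(0.030, c, 100));  // overshoots and blows up
  return 0;
}

After 100 iterations the first value has collapsed to roughly 1e-70 while the second has exploded to roughly 1e+30, even though the two learning rates differ by less than a factor of four.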

answered 2013-05-30T10:39:01.190