8

私は、人々がプレイしたチェス ゲームの結果を収集する Web サイトに取り組んでいます。プレーヤーのレーティングと、対戦相手とのレーティングの差を見て、勝ち (緑)、引き分け (青)、負け (赤) を表すドットでグラフをプロットします。

この情報を使用して、ロジスティック回帰アルゴリズムも実装して、勝利と勝利/引き分けのカットオフを分類しました。評価と差を 2 つの特徴として使用して、分類器を取得し、分類器が予測を変更する場所の境界をグラフに描画します。

勾配降下、コスト関数、およびシグモイド関数のコードは次のとおりです。

  def gradient_descent()
    oldJ = 0    
    newJ = J()
    alpha = 1.0     # Learning rate
    run = 0
    while (run < 100) do
      tmpTheta = Array.new
      for j in 0...numFeatures do
        sum = 0
        for i in 0...m do
          sum += ((h(training_data[:x][i]) - training_data[:y][i][0]) * training_data[:x][i][j])
        end
        tmpTheta[j] = Array.new
        tmpTheta[j][0] = theta[j, 0] - (alpha / m) * sum  # Alpha * partial derivative of J with respect to theta_j
      end
      self.theta = Matrix.rows(tmpTheta)
      oldJ = newJ
      newJ = J()
      run += 1
      if (run == 100 && (oldJ - newJ > 0.001)) then run -= 20 end   # Do 20 more if the error is still going down a fair amount.
      if (oldJ < newJ)
        alpha /= 10
      end
    end
  end

  def J()
    sum = 0
    for i in 0...m
      sum += ((training_data[:y][i][0] * Math.log(h(training_data[:x][i]))) 
          + ((1 - training_data[:y][i][0]) * Math.log(1 - h(training_data[:x][i]))))
    end
    return (-1.0 / m) * sum
  end

  def h(x)
    if (x.class != 'Matrix')    # In case it comes in as a row vector or an array
      x = Matrix.rows([x])      # [x] because if it's a row vector we want [[a, b]] to get an array whose first row is x.
    end
    x = x.transpose   # x is supposed to be a column vector, and theta^ a row vector, so theta^*x is a number.
    return g((theta.transpose * x)[0, 0])  # theta^ * x gives [[z]], so get [0, 0] of that for the number z.
  end

  def g(z)
    tmp = 1.0 / (1.0 + Math.exp(-z))   # Sigmoid function
    if (tmp == 1.0) then tmp = 0.99999 end    # These two things are here because ln(0) DNE, so we don't want to do ln(1 - 1.0) or ln(0.0)
    if (tmp == 0.0) then tmp = 0.00001 end
    return tmp
  end

私自身のチェス プロファイルを表すデータ セットでこれをテストすると、満足できる合理的な結果が得られます。

12842311: 正しい結果

しばらく、私は幸せでした。私が試したすべての例は、興味深いチャートを示しました。次に、250 以上のトーナメントに出場し、1,000 以上のゲームを経験したケビン・カオというプレイヤーを、非常に大規模なトレーニング セットで試してみました。結果は明らかに間違っていました:

12905349: 不正確な結果

まあ、それはダメでした。そこで、最初のアイデアとして、初期学習率を 1.0 から 100.0 に上げました。これにより、Kevin にとって適切な結果が得られました。

12905349: 正しい結果

残念ながら、自分自身と小さなデータセットで試してみたところ、予測の 1 つに対して 0 で平らな線が表示されるという奇妙な現象が発生しました。

12842311: 不正確な結果

シータを確認したところ、[[2.3707682771730836]、[21.22408286825226]、[-19081.906528679192]] でした。3 番目のトレーニング変数 (x_0 = 1 であるため、実際には 2 番目) は評価の差であるため、その差が正のわずかなビットである場合、ロジスティック回帰の式はかなり負になり、シグモイド関数は y = 0 を予測します。差はわずかに正であり、同様に、大きく跳ね上がり、y = 1 を予測します。

初期学習率を 100.0 から 1.0 に戻し、代わりによりゆっくりと減少させることにしました。そのため、コスト関数が増加したときに 10 分の 1 に減らす代わりに、2 分の 1 に減らしました。

残念ながら、これは私の結果をまったく変えませんでした。勾配降下のループ数を 100 から 1000 に増やしても、間違った結果を予測し続けました。

私はまだロジスティック回帰の初心者です (coursera の機械学習クラスを終えたばかりで、そこで学んだアルゴリズムを実装しようとするのはこれが初めてです) ので、自分の直感の範囲に到達しました。ここで何が間違っているのか、何が間違っているのか、どうすれば修正できるのかを誰かが教えてくれたら、とても感謝しています。

編集: 約 300 のデータ ポイントを持つ別のデータ セットでも試してみたところ、平らな緑色の線と通常の青色の線が得られました。アルゴリズムは基本的に両方で同じですが、マルチクラス分類を行っているため、y の結果が若干異なります。

編集: 人々がそれを求めたので、平坦化された線の勾配降下の各反復の J、アルファ、およびシータは次のとおりです。

J: 1.7679949412730092  Alpha: 1.0  Theta: Matrix[[-0.004477611940298508], [0.2835820895522388], [-123.63880597014925]]
J: 0.6873432218114784  Alpha: 0.1  Theta: Matrix[[-0.008057848266678727], [-8.033992854843122], [-118.62571350649955]]
J: 2.7493579020963597  Alpha: 0.1  Theta: Matrix[[0.0035837099422764904], [10.036108977992713], [-114.29679460799208]]
J: 2.5431564907845736  Alpha: 0.01  Theta: Matrix[[0.002061352330336195], [7.255061503962862], [-113.88091708799209]]
J: 2.268221136398013  Alpha: 0.01  Theta: Matrix[[0.0008076454646645536], [4.923257856798684], [-113.43169704202194]]
J: 2.02765281325063  Alpha: 0.01  Theta: Matrix[[-0.00014755931145485107], [3.0843409102315205], [-112.95644762679805]]
J: 1.821451342237053  Alpha: 0.01  Theta: Matrix[[-0.0008639634905593289], [1.6548476959031622], [-112.46627318829059]]
J: 1.8214513720879484  Alpha: 0.01  Theta: Matrix[[-0.0013117163263802246], [0.6758826956046561], [-111.9660989569473]]
J: 1.8214513720879484  Alpha: 0.001  Theta: Matrix[[-0.0013535066248876874], [0.5834935043210742], [-111.91600392423089]]
J: 1.7870844304014568  Alpha: 0.001  Theta: Matrix[[-0.0013952969233951501], [0.49110431303749225], [-111.86590889151448]]
J: 1.7870844304014568  Alpha: 0.001  Theta: Matrix[[-0.0014341021771264934], [0.40365238581361185], [-111.81578997843985]]
J: 1.7870844304014568  Alpha: 0.001  Theta: Matrix[[-0.0014729074308578367], [0.31620045858973145], [-111.76567106536523]]
J: 1.752717488714965  Alpha: 0.001  Theta: Matrix[[-0.0015115010626209136], [0.22904945780472585], [-111.71555130580272]]
J: 1.752717488714965  Alpha: 0.001  Theta: Matrix[[-0.001544336226800018], [0.15110191314800955], [-111.66540851236988]]
J: 1.770809597429665  Alpha: 0.001  Theta: Matrix[[-0.0015771713909791226], [0.07315436849129325], [-111.61526571893704]]
J: 1.7297985323807161  Alpha: 0.0001  Theta: Matrix[[-0.00158045491336022], [0.06535960382896211], [-111.61025143962061]]
J: 1.718350722631126  Alpha: 0.0001  Theta: Matrix[[-0.0015837319880072584], [0.05757622586497872], [-111.60523715385645]]
J: 1.7183505768797593  Alpha: 0.0001  Theta: Matrix[[-0.0015867170175074515], [0.05030859963032436], [-111.60022257604714]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0015897020324328638], [0.04304099913473299], [-111.59520799822326]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0015926870473582369], [0.03577339863921061], [-111.59019342039937]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.00159567206228361], [0.028505798143688237], [-111.58517884257549]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.001598657077208983], [0.02123819764816586], [-111.5801642647516]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.001601642092134356], [0.013970597152643486], [-111.57514968692772]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.001604627107059729], [0.006702996657121109], [-111.57013510910383]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016076121219851022], [-0.0005646038384012671], [-111.56512053127994]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016105971369104752], [-0.007832204333923645], [-111.56010595345606]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016135821518358483], [-0.01509980482944602], [-111.55509137563217]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016165671667612213], [-0.022367405324968396], [-111.55007679780829]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016195521816865944], [-0.02963500582049077], [-111.5450622199844]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016225371966119674], [-0.03690260631601315], [-111.54004764216052]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016255222115373405], [-0.04417020681153553], [-111.53503306433663]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016285072264627136], [-0.05143780730705791], [-111.53001848651274]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016314922443731613], [-0.05870541239661013], [-111.52500390868587]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016344772622834192], [-0.06597301748587016], [-111.519989330859]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016374622664495802], [-0.07324060142296517], [-111.51497475304588]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.001640217664533409], [-0.08015482159935092], [-111.50996040483884]]
J: 1.7183505768793688  Alpha: 0.0001  Theta: Matrix[[-0.0016455906875599943], [-0.0937712290880118], [-111.49993184619791]]
J: 1.994702022407994  Alpha: 0.0001  Theta: Matrix[[-0.0016482771980077554], [-0.10057943119248941], [-111.49491756687851]]
J: 1.9789198631246232  Alpha: 1.0e-05  Theta: Matrix[[-0.0016485458502465615], [-0.10126025363935508], [-111.49441613894419]]
J: 1.948354991984789  Alpha: 1.0e-05  Theta: Matrix[[-0.0016490831547241735], [-0.10262189853308641], [-111.49341328307554]]
J: 1.9331013621188657  Alpha: 1.0e-05  Theta: Matrix[[-0.0016493518069629796], [-0.10330272097995208], [-111.49291185514122]]
J: 1.9178620371528292  Alpha: 1.0e-05  Theta: Matrix[[-0.0016496204592017856], [-0.10398354342681772], [-111.49241042720689]]
J: 1.902623825636303  Alpha: 1.0e-05  Theta: Matrix[[-0.0016498891114405914], [-0.10466436587368326], [-111.49190899927257]]
J: 1.8873858680247269  Alpha: 1.0e-05  Theta: Matrix[[-0.0016501577636793972], [-0.10534518832054848], [-111.49140757133824]]
J: 1.8721478527437034  Alpha: 1.0e-05  Theta: Matrix[[-0.0016504264159182024], [-0.10602601076741257], [-111.49090614340392]]
J: 1.8569098083540256  Alpha: 1.0e-05  Theta: Matrix[[-0.0016506950681570054], [-0.10670683321427255], [-111.4904047154696]]
J: 1.8416717846532462  Alpha: 1.0e-05  Theta: Matrix[[-0.0016509637203958004], [-0.10738765566111781], [-111.48990328753527]]
J: 1.8264337702403803  Alpha: 1.0e-05  Theta: Matrix[[-0.0016512323726345674], [-0.10806847810791036], [-111.48940185960095]]
J: 1.8111957469624462  Alpha: 1.0e-05  Theta: Matrix[[-0.0016515010251717409], [-0.1087493010703349], [-111.48890043166602]]
J: 1.7959577228777213  Alpha: 1.0e-05  Theta: Matrix[[-0.001651769677708553], [-0.10943012403208266], [-111.4883990037311]]
J: 1.7807196990939538  Alpha: 1.0e-05  Theta: Matrix[[-0.0016520383302440706], [-0.11011094699140556], [-111.48789757579618]]
J: 1.7654816767669712  Alpha: 1.0e-05  Theta: Matrix[[-0.0016523069827749494], [-0.11079176994204029], [-111.48739614786128]]
J: 1.7197677244765115  Alpha: 1.0e-05  Theta: Matrix[[-0.0016531129399852717], [-0.11283423807786983], [-111.4858918640573]]
J: 1.7045300185036796  Alpha: 1.0e-05  Theta: Matrix[[-0.0016533815914621833], [-0.11351505905442376], [-111.48539043612449]]
J: 1.689293134633683  Alpha: 1.0e-05  Theta: Matrix[[-0.0016536502402002386], [-0.11419587490110002], [-111.48488900819716]]
J: 1.674059195452273  Alpha: 1.0e-05  Theta: Matrix[[-0.001653918879126327], [-0.1148766723699622], [-111.48438758028945]]
J: 1.6588357959146847  Alpha: 1.0e-05  Theta: Matrix[[-0.0016541874829120791], [-0.11555740402097447], [-111.48388615245203]]
J: 1.6436500186219352  Alpha: 1.0e-05  Theta: Matrix[[-0.0016544559609891405], [-0.1162379002196091], [-111.48338472486603]]
J: 1.6285972611659707  Alpha: 1.0e-05  Theta: Matrix[[-0.001654723991174496], [-0.11691755751707966], [-111.4828832981758]]
J: 1.6139994752963014  Alpha: 1.0e-05  Theta: Matrix[[-0.0016549904481917704], [-0.11759426827073645], [-111.48238187463193]]
J: 1.600799606845299  Alpha: 1.0e-05  Theta: Matrix[[-0.0016552516449943116], [-0.11826112664220582], [-111.48188046160847]]
J: 1.5908244528084288  Alpha: 1.0e-05  Theta: Matrix[[-0.0016554977759847996], [-0.1188997667477244], [-111.48137907871664]]
J: 1.5851960976828814  Alpha: 1.0e-05  Theta: Matrix[[-0.0016557144987826046], [-0.11948332530842007], [-111.4808777546412]]
J: 1.5826817076400923  Alpha: 1.0e-05  Theta: Matrix[[-0.0016558999497352893], [-0.12000831170339445], [-111.48037649310945]]
J: 1.5816354848004566  Alpha: 1.0e-05  Theta: Matrix[[-0.0016560658987327093], [-0.12049677093659837], [-111.4798752705816]]
J: 1.581199878569286  Alpha: 1.0e-05  Theta: Matrix[[-0.0016562224426970157], [-0.12096761454376066], [-111.47937406686383]]
J: 1.5810169018926878  Alpha: 1.0e-05  Theta: Matrix[[-0.0016563748211790893], [-0.12143065620486218], [-111.47887287147701]]
J: 1.5809396242131868  Alpha: 1.0e-05  Theta: Matrix[[-0.0016565254040880424], [-0.1218903347622732], [-111.47837167968135]]
J: 1.5809069017613124  Alpha: 1.0e-05  Theta: Matrix[[-0.0016566752202995195], [-0.12234857730581448], [-111.47787048941908]]
J: 1.5808930296490606  Alpha: 1.0e-05  Theta: Matrix[[-0.001656824710233385], [-0.12280620875454971], [-111.47736929980935]]
J: 1.580887145848097  Alpha: 1.0e-05  Theta: Matrix[[-0.0016569740612930289], [-0.12326358014294572], [-111.47686811047738]]
J: 1.580884649719601  Alpha: 1.0e-05  Theta: Matrix[[-0.0016571233527736234], [-0.12372084005243131], [-111.47636692126457]]
J: 1.5808835906710963  Alpha: 1.0e-05  Theta: Matrix[[-0.0016572726175860411], [-0.12417805026085695], [-111.47586573210509]]
J: 1.5808831413239819  Alpha: 1.0e-05  Theta: Matrix[[-0.00165742186803091], [-0.12463523410670607], [-111.47536454297435]]
.........

適切な予測を作成するもの:

J: 4.330234652497978  Alpha: 1.0  Theta: Matrix[[0.12388059701492538], [211.9910447761194], [-111.13731343283582]]
J: 4.330234652497978  Alpha: 0.1  Theta: Matrix[[0.08626965671641812], [152.3222144059701], [-118.07202388059702]]
J: 4.2958677406623815  Alpha: 0.1  Theta: Matrix[[0.048658716417910856], [92.65338403582082], [-125.0067343283582]]
J: 3.333594209265678  Alpha: 0.1  Theta: Matrix[[0.011644779104478219], [33.61767533134318], [-131.44443979104477]]
J: 0.4467735852246924  Alpha: 0.1  Theta: Matrix[[-0.014623104477611202], [-11.126378913433022], [-132.24166105074627]]
J: 3.333594209265678  Alpha: 0.1  Theta: Matrix[[0.01194378805970217], [31.177094038805805], [-126.89243925671643]]
J: 3.0930257965656063  Alpha: 0.01  Theta: Matrix[[0.009436400895523079], [26.892626149850567], [-126.92472924]]
J: 2.7493567080605392  Alpha: 0.01  Theta: Matrix[[0.007257365074627634], [23.13644550388053], [-126.8386038647761]]
J: 2.508788325211366  Alpha: 0.01  Theta: Matrix[[0.005466380895523164], [19.99261048238799], [-126.62851089164178]]
J: 2.405687589704577  Alpha: 0.01  Theta: Matrix[[0.004152999104478391], [17.61296913194023], [-126.28907722179103]]
J: 2.268219942362192  Alpha: 0.01  Theta: Matrix[[0.002959017910448543], [15.415473392238736], [-125.92224111492536]]
J: 2.1307522353180164  Alpha: 0.01  Theta: Matrix[[0.002093389253732125], [13.751072827761122], [-125.48597339134326]]
J: 2.027651529662123  Alpha: 0.01  Theta: Matrix[[0.0014367116417918252], [12.436814710149182], [-125.00961691402983]]
J: 1.9589177059909308  Alpha: 0.01  Theta: Matrix[[0.0009889847761201823], [11.44908667850739], [-124.49911195194028]]
J: 1.8558169406332465  Alpha: 0.01  Theta: Matrix[[0.0006606582089560022], [10.652638055522315], [-123.97004023522386]]
J: 1.8214500586485458  Alpha: 0.01  Theta: Matrix[[0.0004218823880604789], [9.988664770447688], [-123.42914782925371]]
J: 1.8214500884994413  Alpha: 0.01  Theta: Matrix[[0.0002428068653197179], [9.416182220312082], [-122.88082274064425]]
J: 1.8214500884994413  Alpha: 0.001  Theta: Matrix[[0.00023086931308091184], [9.369775500013574], [-122.82513353589798]]
J: 1.8214500884994413  Alpha: 0.001  Theta: Matrix[[0.00021893176084210577], [9.323368779715066], [-122.7694443311517]]
J: 1.8214500884994413  Alpha: 0.001  Theta: Matrix[[0.0002069942086032997], [9.276962059416558], [-122.71375512640543]]
J: 1.8214500884994413  Alpha: 0.001  Theta: Matrix[[0.00019505665636449364], [9.23055533911805], [-122.65806592165916]]
J: 1.8214500884994413  Alpha: 0.001  Theta: Matrix[[0.00018311910412568757], [9.184148618819542], [-122.60237671691289]]
J: 1.8214500884994413  Alpha: 0.001  Theta: Matrix[[0.0001711815518868815], [9.137741898521034], [-122.54668751216661]]
J: 1.8214500884994413  Alpha: 0.001  Theta: Matrix[[0.00015924399964807544], [9.091335178222526], [-122.49099830742034]]
J: 1.8214500884994413  Alpha: 0.001  Theta: Matrix[[0.00014730641755852312], [9.04492840598372], [-122.43530910670393]]
J: 1.8677695240029366  Alpha: 0.001  Theta: Matrix[[0.0001353688354689708], [8.998521633744915], [-122.37961990598751]]
J: 1.8462563443835032  Alpha: 0.0001  Theta: Matrix[[0.0001341750742749415], [8.993880951437452], [-122.374050986289]]
J: 1.8247430163841476  Alpha: 0.0001  Theta: Matrix[[0.00013298131308164604], [8.98924026913124], [-122.3684820665904]]
J: 1.803243007740144  Alpha: 0.0001  Theta: Matrix[[0.0001317875528781551], [8.984599588510665], [-122.36291314676808]]
J: 1.7875423426167685  Alpha: 0.0001  Theta: Matrix[[0.00013059512176735966], [8.979961171334951], [-122.35734406080917]]
J: 1.7870839229503594  Alpha: 0.0001  Theta: Matrix[[0.0001296573060241053], [8.97575636413016], [-122.35174314792931]]
J: 1.7870831481868632  Alpha: 0.0001  Theta: Matrix[[0.00012876197468911015], [8.971623907872633], [-122.34613692449842]]
J: 1.7870831468153818  Alpha: 0.0001  Theta: Matrix[[0.00012786672082037553], [8.967491583540149], [-122.34053069138426]]
J: 1.7870831468129538  Alpha: 0.0001  Theta: Matrix[[0.000126971467088789], [8.963359259441226], [-122.33492445825294]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.0001260762133574453], [8.959226935342718], [-122.3293182251216]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.00012518095962610202], [8.95509461124421], [-122.32371199199025]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.00012428570589475874], [8.950962287145702], [-122.3181057588589]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.00012339045216341546], [8.946829963047193], [-122.31249952572756]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.00012249519843207218], [8.942697638948685], [-122.30689329259621]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.00012159994470072888], [8.938565314850177], [-122.30128705946487]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.00012070469096938559], [8.934432990751668], [-122.29568082633352]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.0001198094372380423], [8.93030066665316], [-122.29007459320218]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.000118914183506699], [8.926168342554652], [-122.28446836007083]]
J: 1.7870831468129498  Alpha: 0.0001  Theta: Matrix[[0.00011801892977535571], [8.922036018456144], [-122.27886212693949]]
......

編集: シータがすべて 0 であるため、仮説の最初の反復は常に 0.5 を予測していることに気付きました。しかし、その後は常に1または0を予測します(コードに存在しない対数を避けるために0.00001または0.99999)。それは私には正しくないように思えます-あまりにも自信があります-そしておそらくこれが機能しない理由の鍵です.

4

4 に答える 4

3

実装には、少し非標準的なことがいくつかあります。

  1. まず、ロジスティック回帰の目的は、通常、の最小化問題として与えられます。

    lr(x[n],y[n])=log(1+exp(-y[n]*dot(w[n],x[n])))y[n]または _ 1_-1

    あなたはの同等の最大化問題の定式化を使用しているようです

    lr(x[n],y[n])=-y[n]*log(1+exp(-dot(w[n],x[n])))+(1-y[n])*(-dot(w[n],x[n])-log(1+exp(-dot(w[n],x[n])))

    ここで、y[n]は0または1のいずれかです(この定式化のy [n] = 0は、最初の定式化のy [n] = 1と同等です)。

    したがって、データセットで、ラベルが1または-1ではなく0または1であることを確認する必要があります。

  2. 次に、LR目標は通常、m(データセットのサイズ)で除算されません。ロジスティック回帰を確率モデルと見なす場合、このスケーリング係数は正しくありません。

  3. 最後に、実装に数値的な問題がある可能性があります(g関数で修正しようとしました)。Leon Bottouのsgdコード(http://leon.bottou.org/projects/sgd)には、次のように損失関数と導関数のより安定した計算があります(Cコードでは-彼は私が言及した最初のLR定式化を使用します):

    /* logloss(a,y) = log(1+exp(-a*y)) */
    double loss(double a, double y)
    {
      double z = a * y;
      if (z > 18) {
        return exp(-z);
      }
      if (z < -18) {
        return -z;
      }
      return log(1 + exp(-z));
    }
    
    /*  -dloss(a,y)/da */
    double dloss(double a, double y)
    {
      double z = a * y;
      if (z > 18) {
        return y * exp(-z);
      }
      if (z < -18){
        return y;
      }
      return y / (1 + exp(z));
    }
    

また、ストックl-bfgsルーチン(Rubyの実装に精通していない)の実行を検討する必要があります。これにより、目的と勾配の計算を正しく行うことに集中でき、学習率などについて心配する必要がなくなります。

于 2012-12-11T03:48:04.980 に答える
1

いくつかの考え:

  • J()との繰り返しの値を示していただけると助かると思いますalpha
  • 特徴として定数(バイアス)を含めますか?私の記憶が正しければ、これを行わないと、(直線) の直線はh() == 0.5​​ゼロを通過することになります。

  • 関数は、の対数尤度をJ()返しているように見えます (したがって、最小化する必要があります)。それでも、学習率が低下します。つまり、大きくなった場合、つまり悪化します。if (oldJ < newJ)J()

于 2012-12-07T20:06:48.127 に答える
0

浮動小数点数:

これを試して?Equalフロート間の比較は私にはあまり意味がありません。

def g(z)
    tmp = 1.0 / (1.0 + Math.exp(-z))   # Sigmoid function
    if (tmp >= 0.99999) then tmp = 0.99999 end    # These two things are here because ln(0) DNE, so we don't want to do ln(1 - 1.0) or ln(0.0)
    if (tmp <= 0.00001) then tmp = 0.00001 end
    return tmp
end

機能のスケーリング

2 つの機能を使用しているとおっしゃいましたが、それらはプレイヤー自身の評価と評価の差分であると思います。あれは正しいですか?

また、データの前処理ステップとして、いくつかの機能スケーリングを使用することも検討してください。

ここに画像の説明を入力. または、データ内の各特徴の値にゼロ平均と単位分散を持たせることにより、標準化方法を実行できます。

質問:

  • グラフの青い線と緑の線の違いは何ですか?
  • 非常に小さい学習率 (0.01 または 0.001 など) から始めようとしましたか?
  • 固定学習率だけを使用すると、どのような動作になるでしょうか? 0.001、0.01、0.1、0.5、1、10 などを試してみてください。結果をここに投稿してください。
于 2012-12-07T19:31:45.537 に答える