私は、人々がプレイしたチェス ゲームの結果を収集する Web サイトに取り組んでいます。プレーヤーのレーティングと、対戦相手とのレーティングの差を見て、勝ち (緑)、引き分け (青)、負け (赤) を表すドットでグラフをプロットします。
この情報を使用して、ロジスティック回帰アルゴリズムも実装して、勝利と勝利/引き分けのカットオフを分類しました。評価と差を 2 つの特徴として使用して、分類器を取得し、分類器が予測を変更する場所の境界をグラフに描画します。
勾配降下、コスト関数、およびシグモイド関数のコードは次のとおりです。
def gradient_descent()
oldJ = 0
newJ = J()
alpha = 1.0 # Learning rate
run = 0
while (run < 100) do
tmpTheta = Array.new
for j in 0...numFeatures do
sum = 0
for i in 0...m do
sum += ((h(training_data[:x][i]) - training_data[:y][i][0]) * training_data[:x][i][j])
end
tmpTheta[j] = Array.new
tmpTheta[j][0] = theta[j, 0] - (alpha / m) * sum # Alpha * partial derivative of J with respect to theta_j
end
self.theta = Matrix.rows(tmpTheta)
oldJ = newJ
newJ = J()
run += 1
if (run == 100 && (oldJ - newJ > 0.001)) then run -= 20 end # Do 20 more if the error is still going down a fair amount.
if (oldJ < newJ)
alpha /= 10
end
end
end
def J()
sum = 0
for i in 0...m
sum += ((training_data[:y][i][0] * Math.log(h(training_data[:x][i])))
+ ((1 - training_data[:y][i][0]) * Math.log(1 - h(training_data[:x][i]))))
end
return (-1.0 / m) * sum
end
def h(x)
if (x.class != 'Matrix') # In case it comes in as a row vector or an array
x = Matrix.rows([x]) # [x] because if it's a row vector we want [[a, b]] to get an array whose first row is x.
end
x = x.transpose # x is supposed to be a column vector, and theta^ a row vector, so theta^*x is a number.
return g((theta.transpose * x)[0, 0]) # theta^ * x gives [[z]], so get [0, 0] of that for the number z.
end
def g(z)
tmp = 1.0 / (1.0 + Math.exp(-z)) # Sigmoid function
if (tmp == 1.0) then tmp = 0.99999 end # These two things are here because ln(0) DNE, so we don't want to do ln(1 - 1.0) or ln(0.0)
if (tmp == 0.0) then tmp = 0.00001 end
return tmp
end
私自身のチェス プロファイルを表すデータ セットでこれをテストすると、満足できる合理的な結果が得られます。
しばらく、私は幸せでした。私が試したすべての例は、興味深いチャートを示しました。次に、250 以上のトーナメントに出場し、1,000 以上のゲームを経験したケビン・カオというプレイヤーを、非常に大規模なトレーニング セットで試してみました。結果は明らかに間違っていました:
まあ、それはダメでした。そこで、最初のアイデアとして、初期学習率を 1.0 から 100.0 に上げました。これにより、Kevin にとって適切な結果が得られました。
残念ながら、自分自身と小さなデータセットで試してみたところ、予測の 1 つに対して 0 で平らな線が表示されるという奇妙な現象が発生しました。
シータを確認したところ、[[2.3707682771730836]、[21.22408286825226]、[-19081.906528679192]] でした。3 番目のトレーニング変数 (x_0 = 1 であるため、実際には 2 番目) は評価の差であるため、その差が正のわずかなビットである場合、ロジスティック回帰の式はかなり負になり、シグモイド関数は y = 0 を予測します。差はわずかに正であり、同様に、大きく跳ね上がり、y = 1 を予測します。
初期学習率を 100.0 から 1.0 に戻し、代わりによりゆっくりと減少させることにしました。そのため、コスト関数が増加したときに 10 分の 1 に減らす代わりに、2 分の 1 に減らしました。
残念ながら、これは私の結果をまったく変えませんでした。勾配降下のループ数を 100 から 1000 に増やしても、間違った結果を予測し続けました。
私はまだロジスティック回帰の初心者です (coursera の機械学習クラスを終えたばかりで、そこで学んだアルゴリズムを実装しようとするのはこれが初めてです) ので、自分の直感の範囲に到達しました。ここで何が間違っているのか、何が間違っているのか、どうすれば修正できるのかを誰かが教えてくれたら、とても感謝しています。
編集: 約 300 のデータ ポイントを持つ別のデータ セットでも試してみたところ、平らな緑色の線と通常の青色の線が得られました。アルゴリズムは基本的に両方で同じですが、マルチクラス分類を行っているため、y の結果が若干異なります。
編集: 人々がそれを求めたので、平坦化された線の勾配降下の各反復の J、アルファ、およびシータは次のとおりです。
J: 1.7679949412730092 Alpha: 1.0 Theta: Matrix[[-0.004477611940298508], [0.2835820895522388], [-123.63880597014925]]
J: 0.6873432218114784 Alpha: 0.1 Theta: Matrix[[-0.008057848266678727], [-8.033992854843122], [-118.62571350649955]]
J: 2.7493579020963597 Alpha: 0.1 Theta: Matrix[[0.0035837099422764904], [10.036108977992713], [-114.29679460799208]]
J: 2.5431564907845736 Alpha: 0.01 Theta: Matrix[[0.002061352330336195], [7.255061503962862], [-113.88091708799209]]
J: 2.268221136398013 Alpha: 0.01 Theta: Matrix[[0.0008076454646645536], [4.923257856798684], [-113.43169704202194]]
J: 2.02765281325063 Alpha: 0.01 Theta: Matrix[[-0.00014755931145485107], [3.0843409102315205], [-112.95644762679805]]
J: 1.821451342237053 Alpha: 0.01 Theta: Matrix[[-0.0008639634905593289], [1.6548476959031622], [-112.46627318829059]]
J: 1.8214513720879484 Alpha: 0.01 Theta: Matrix[[-0.0013117163263802246], [0.6758826956046561], [-111.9660989569473]]
J: 1.8214513720879484 Alpha: 0.001 Theta: Matrix[[-0.0013535066248876874], [0.5834935043210742], [-111.91600392423089]]
J: 1.7870844304014568 Alpha: 0.001 Theta: Matrix[[-0.0013952969233951501], [0.49110431303749225], [-111.86590889151448]]
J: 1.7870844304014568 Alpha: 0.001 Theta: Matrix[[-0.0014341021771264934], [0.40365238581361185], [-111.81578997843985]]
J: 1.7870844304014568 Alpha: 0.001 Theta: Matrix[[-0.0014729074308578367], [0.31620045858973145], [-111.76567106536523]]
J: 1.752717488714965 Alpha: 0.001 Theta: Matrix[[-0.0015115010626209136], [0.22904945780472585], [-111.71555130580272]]
J: 1.752717488714965 Alpha: 0.001 Theta: Matrix[[-0.001544336226800018], [0.15110191314800955], [-111.66540851236988]]
J: 1.770809597429665 Alpha: 0.001 Theta: Matrix[[-0.0015771713909791226], [0.07315436849129325], [-111.61526571893704]]
J: 1.7297985323807161 Alpha: 0.0001 Theta: Matrix[[-0.00158045491336022], [0.06535960382896211], [-111.61025143962061]]
J: 1.718350722631126 Alpha: 0.0001 Theta: Matrix[[-0.0015837319880072584], [0.05757622586497872], [-111.60523715385645]]
J: 1.7183505768797593 Alpha: 0.0001 Theta: Matrix[[-0.0015867170175074515], [0.05030859963032436], [-111.60022257604714]]
J: 1.7183505768793688 Alpha: 0.0001 Theta: Matrix[[-0.0015897020324328638], [0.04304099913473299], [-111.59520799822326]]
J: 1.7183505768793688 Alpha: 0.0001 Theta: Matrix[[-0.0015926870473582369], [0.03577339863921061], [-111.59019342039937]]
J: 1.7183505768793688 Alpha: 0.0001 Theta: Matrix[[-0.00159567206228361], [0.028505798143688237], [-111.58517884257549]]
J: 1.7183505768793688 Alpha: 0.0001 Theta: Matrix[[-0.001598657077208983], [0.02123819764816586], [-111.5801642647516]]
J: 1.7183505768793688 Alpha: 0.0001 Theta: Matrix[[-0.001601642092134356], [0.013970597152643486], [-111.57514968692772]]
J: 1.7183505768793688 Alpha: 0.0001 Theta: Matrix[[-0.001604627107059729], [0.006702996657121109], [-111.57013510910383]]
J: 1.7183505768793688 Alpha: 0.0001 Theta: Matrix[[-0.0016076121219851022], [-0.0005646038384012671], [-111.56512053127994]]
J: 1.7183505768793688 Alpha: 0.0001 Theta: Matrix[[-0.0016105971369104752], [-0.007832204333923645], [-111.56010595345606]]
J: 1.7183505768793688 Alpha: 0.0001 Theta: Matrix[[-0.0016135821518358483], [-0.01509980482944602], [-111.55509137563217]]
J: 1.7183505768793688 Alpha: 0.0001 Theta: Matrix[[-0.0016165671667612213], [-0.022367405324968396], [-111.55007679780829]]
J: 1.7183505768793688 Alpha: 0.0001 Theta: Matrix[[-0.0016195521816865944], [-0.02963500582049077], [-111.5450622199844]]
J: 1.7183505768793688 Alpha: 0.0001 Theta: Matrix[[-0.0016225371966119674], [-0.03690260631601315], [-111.54004764216052]]
J: 1.7183505768793688 Alpha: 0.0001 Theta: Matrix[[-0.0016255222115373405], [-0.04417020681153553], [-111.53503306433663]]
J: 1.7183505768793688 Alpha: 0.0001 Theta: Matrix[[-0.0016285072264627136], [-0.05143780730705791], [-111.53001848651274]]
J: 1.7183505768793688 Alpha: 0.0001 Theta: Matrix[[-0.0016314922443731613], [-0.05870541239661013], [-111.52500390868587]]
J: 1.7183505768793688 Alpha: 0.0001 Theta: Matrix[[-0.0016344772622834192], [-0.06597301748587016], [-111.519989330859]]
J: 1.7183505768793688 Alpha: 0.0001 Theta: Matrix[[-0.0016374622664495802], [-0.07324060142296517], [-111.51497475304588]]
J: 1.7183505768793688 Alpha: 0.0001 Theta: Matrix[[-0.001640217664533409], [-0.08015482159935092], [-111.50996040483884]]
J: 1.7183505768793688 Alpha: 0.0001 Theta: Matrix[[-0.0016455906875599943], [-0.0937712290880118], [-111.49993184619791]]
J: 1.994702022407994 Alpha: 0.0001 Theta: Matrix[[-0.0016482771980077554], [-0.10057943119248941], [-111.49491756687851]]
J: 1.9789198631246232 Alpha: 1.0e-05 Theta: Matrix[[-0.0016485458502465615], [-0.10126025363935508], [-111.49441613894419]]
J: 1.948354991984789 Alpha: 1.0e-05 Theta: Matrix[[-0.0016490831547241735], [-0.10262189853308641], [-111.49341328307554]]
J: 1.9331013621188657 Alpha: 1.0e-05 Theta: Matrix[[-0.0016493518069629796], [-0.10330272097995208], [-111.49291185514122]]
J: 1.9178620371528292 Alpha: 1.0e-05 Theta: Matrix[[-0.0016496204592017856], [-0.10398354342681772], [-111.49241042720689]]
J: 1.902623825636303 Alpha: 1.0e-05 Theta: Matrix[[-0.0016498891114405914], [-0.10466436587368326], [-111.49190899927257]]
J: 1.8873858680247269 Alpha: 1.0e-05 Theta: Matrix[[-0.0016501577636793972], [-0.10534518832054848], [-111.49140757133824]]
J: 1.8721478527437034 Alpha: 1.0e-05 Theta: Matrix[[-0.0016504264159182024], [-0.10602601076741257], [-111.49090614340392]]
J: 1.8569098083540256 Alpha: 1.0e-05 Theta: Matrix[[-0.0016506950681570054], [-0.10670683321427255], [-111.4904047154696]]
J: 1.8416717846532462 Alpha: 1.0e-05 Theta: Matrix[[-0.0016509637203958004], [-0.10738765566111781], [-111.48990328753527]]
J: 1.8264337702403803 Alpha: 1.0e-05 Theta: Matrix[[-0.0016512323726345674], [-0.10806847810791036], [-111.48940185960095]]
J: 1.8111957469624462 Alpha: 1.0e-05 Theta: Matrix[[-0.0016515010251717409], [-0.1087493010703349], [-111.48890043166602]]
J: 1.7959577228777213 Alpha: 1.0e-05 Theta: Matrix[[-0.001651769677708553], [-0.10943012403208266], [-111.4883990037311]]
J: 1.7807196990939538 Alpha: 1.0e-05 Theta: Matrix[[-0.0016520383302440706], [-0.11011094699140556], [-111.48789757579618]]
J: 1.7654816767669712 Alpha: 1.0e-05 Theta: Matrix[[-0.0016523069827749494], [-0.11079176994204029], [-111.48739614786128]]
J: 1.7197677244765115 Alpha: 1.0e-05 Theta: Matrix[[-0.0016531129399852717], [-0.11283423807786983], [-111.4858918640573]]
J: 1.7045300185036796 Alpha: 1.0e-05 Theta: Matrix[[-0.0016533815914621833], [-0.11351505905442376], [-111.48539043612449]]
J: 1.689293134633683 Alpha: 1.0e-05 Theta: Matrix[[-0.0016536502402002386], [-0.11419587490110002], [-111.48488900819716]]
J: 1.674059195452273 Alpha: 1.0e-05 Theta: Matrix[[-0.001653918879126327], [-0.1148766723699622], [-111.48438758028945]]
J: 1.6588357959146847 Alpha: 1.0e-05 Theta: Matrix[[-0.0016541874829120791], [-0.11555740402097447], [-111.48388615245203]]
J: 1.6436500186219352 Alpha: 1.0e-05 Theta: Matrix[[-0.0016544559609891405], [-0.1162379002196091], [-111.48338472486603]]
J: 1.6285972611659707 Alpha: 1.0e-05 Theta: Matrix[[-0.001654723991174496], [-0.11691755751707966], [-111.4828832981758]]
J: 1.6139994752963014 Alpha: 1.0e-05 Theta: Matrix[[-0.0016549904481917704], [-0.11759426827073645], [-111.48238187463193]]
J: 1.600799606845299 Alpha: 1.0e-05 Theta: Matrix[[-0.0016552516449943116], [-0.11826112664220582], [-111.48188046160847]]
J: 1.5908244528084288 Alpha: 1.0e-05 Theta: Matrix[[-0.0016554977759847996], [-0.1188997667477244], [-111.48137907871664]]
J: 1.5851960976828814 Alpha: 1.0e-05 Theta: Matrix[[-0.0016557144987826046], [-0.11948332530842007], [-111.4808777546412]]
J: 1.5826817076400923 Alpha: 1.0e-05 Theta: Matrix[[-0.0016558999497352893], [-0.12000831170339445], [-111.48037649310945]]
J: 1.5816354848004566 Alpha: 1.0e-05 Theta: Matrix[[-0.0016560658987327093], [-0.12049677093659837], [-111.4798752705816]]
J: 1.581199878569286 Alpha: 1.0e-05 Theta: Matrix[[-0.0016562224426970157], [-0.12096761454376066], [-111.47937406686383]]
J: 1.5810169018926878 Alpha: 1.0e-05 Theta: Matrix[[-0.0016563748211790893], [-0.12143065620486218], [-111.47887287147701]]
J: 1.5809396242131868 Alpha: 1.0e-05 Theta: Matrix[[-0.0016565254040880424], [-0.1218903347622732], [-111.47837167968135]]
J: 1.5809069017613124 Alpha: 1.0e-05 Theta: Matrix[[-0.0016566752202995195], [-0.12234857730581448], [-111.47787048941908]]
J: 1.5808930296490606 Alpha: 1.0e-05 Theta: Matrix[[-0.001656824710233385], [-0.12280620875454971], [-111.47736929980935]]
J: 1.580887145848097 Alpha: 1.0e-05 Theta: Matrix[[-0.0016569740612930289], [-0.12326358014294572], [-111.47686811047738]]
J: 1.580884649719601 Alpha: 1.0e-05 Theta: Matrix[[-0.0016571233527736234], [-0.12372084005243131], [-111.47636692126457]]
J: 1.5808835906710963 Alpha: 1.0e-05 Theta: Matrix[[-0.0016572726175860411], [-0.12417805026085695], [-111.47586573210509]]
J: 1.5808831413239819 Alpha: 1.0e-05 Theta: Matrix[[-0.00165742186803091], [-0.12463523410670607], [-111.47536454297435]]
.........
適切な予測を作成するもの:
J: 4.330234652497978 Alpha: 1.0 Theta: Matrix[[0.12388059701492538], [211.9910447761194], [-111.13731343283582]]
J: 4.330234652497978 Alpha: 0.1 Theta: Matrix[[0.08626965671641812], [152.3222144059701], [-118.07202388059702]]
J: 4.2958677406623815 Alpha: 0.1 Theta: Matrix[[0.048658716417910856], [92.65338403582082], [-125.0067343283582]]
J: 3.333594209265678 Alpha: 0.1 Theta: Matrix[[0.011644779104478219], [33.61767533134318], [-131.44443979104477]]
J: 0.4467735852246924 Alpha: 0.1 Theta: Matrix[[-0.014623104477611202], [-11.126378913433022], [-132.24166105074627]]
J: 3.333594209265678 Alpha: 0.1 Theta: Matrix[[0.01194378805970217], [31.177094038805805], [-126.89243925671643]]
J: 3.0930257965656063 Alpha: 0.01 Theta: Matrix[[0.009436400895523079], [26.892626149850567], [-126.92472924]]
J: 2.7493567080605392 Alpha: 0.01 Theta: Matrix[[0.007257365074627634], [23.13644550388053], [-126.8386038647761]]
J: 2.508788325211366 Alpha: 0.01 Theta: Matrix[[0.005466380895523164], [19.99261048238799], [-126.62851089164178]]
J: 2.405687589704577 Alpha: 0.01 Theta: Matrix[[0.004152999104478391], [17.61296913194023], [-126.28907722179103]]
J: 2.268219942362192 Alpha: 0.01 Theta: Matrix[[0.002959017910448543], [15.415473392238736], [-125.92224111492536]]
J: 2.1307522353180164 Alpha: 0.01 Theta: Matrix[[0.002093389253732125], [13.751072827761122], [-125.48597339134326]]
J: 2.027651529662123 Alpha: 0.01 Theta: Matrix[[0.0014367116417918252], [12.436814710149182], [-125.00961691402983]]
J: 1.9589177059909308 Alpha: 0.01 Theta: Matrix[[0.0009889847761201823], [11.44908667850739], [-124.49911195194028]]
J: 1.8558169406332465 Alpha: 0.01 Theta: Matrix[[0.0006606582089560022], [10.652638055522315], [-123.97004023522386]]
J: 1.8214500586485458 Alpha: 0.01 Theta: Matrix[[0.0004218823880604789], [9.988664770447688], [-123.42914782925371]]
J: 1.8214500884994413 Alpha: 0.01 Theta: Matrix[[0.0002428068653197179], [9.416182220312082], [-122.88082274064425]]
J: 1.8214500884994413 Alpha: 0.001 Theta: Matrix[[0.00023086931308091184], [9.369775500013574], [-122.82513353589798]]
J: 1.8214500884994413 Alpha: 0.001 Theta: Matrix[[0.00021893176084210577], [9.323368779715066], [-122.7694443311517]]
J: 1.8214500884994413 Alpha: 0.001 Theta: Matrix[[0.0002069942086032997], [9.276962059416558], [-122.71375512640543]]
J: 1.8214500884994413 Alpha: 0.001 Theta: Matrix[[0.00019505665636449364], [9.23055533911805], [-122.65806592165916]]
J: 1.8214500884994413 Alpha: 0.001 Theta: Matrix[[0.00018311910412568757], [9.184148618819542], [-122.60237671691289]]
J: 1.8214500884994413 Alpha: 0.001 Theta: Matrix[[0.0001711815518868815], [9.137741898521034], [-122.54668751216661]]
J: 1.8214500884994413 Alpha: 0.001 Theta: Matrix[[0.00015924399964807544], [9.091335178222526], [-122.49099830742034]]
J: 1.8214500884994413 Alpha: 0.001 Theta: Matrix[[0.00014730641755852312], [9.04492840598372], [-122.43530910670393]]
J: 1.8677695240029366 Alpha: 0.001 Theta: Matrix[[0.0001353688354689708], [8.998521633744915], [-122.37961990598751]]
J: 1.8462563443835032 Alpha: 0.0001 Theta: Matrix[[0.0001341750742749415], [8.993880951437452], [-122.374050986289]]
J: 1.8247430163841476 Alpha: 0.0001 Theta: Matrix[[0.00013298131308164604], [8.98924026913124], [-122.3684820665904]]
J: 1.803243007740144 Alpha: 0.0001 Theta: Matrix[[0.0001317875528781551], [8.984599588510665], [-122.36291314676808]]
J: 1.7875423426167685 Alpha: 0.0001 Theta: Matrix[[0.00013059512176735966], [8.979961171334951], [-122.35734406080917]]
J: 1.7870839229503594 Alpha: 0.0001 Theta: Matrix[[0.0001296573060241053], [8.97575636413016], [-122.35174314792931]]
J: 1.7870831481868632 Alpha: 0.0001 Theta: Matrix[[0.00012876197468911015], [8.971623907872633], [-122.34613692449842]]
J: 1.7870831468153818 Alpha: 0.0001 Theta: Matrix[[0.00012786672082037553], [8.967491583540149], [-122.34053069138426]]
J: 1.7870831468129538 Alpha: 0.0001 Theta: Matrix[[0.000126971467088789], [8.963359259441226], [-122.33492445825294]]
J: 1.7870831468129498 Alpha: 0.0001 Theta: Matrix[[0.0001260762133574453], [8.959226935342718], [-122.3293182251216]]
J: 1.7870831468129498 Alpha: 0.0001 Theta: Matrix[[0.00012518095962610202], [8.95509461124421], [-122.32371199199025]]
J: 1.7870831468129498 Alpha: 0.0001 Theta: Matrix[[0.00012428570589475874], [8.950962287145702], [-122.3181057588589]]
J: 1.7870831468129498 Alpha: 0.0001 Theta: Matrix[[0.00012339045216341546], [8.946829963047193], [-122.31249952572756]]
J: 1.7870831468129498 Alpha: 0.0001 Theta: Matrix[[0.00012249519843207218], [8.942697638948685], [-122.30689329259621]]
J: 1.7870831468129498 Alpha: 0.0001 Theta: Matrix[[0.00012159994470072888], [8.938565314850177], [-122.30128705946487]]
J: 1.7870831468129498 Alpha: 0.0001 Theta: Matrix[[0.00012070469096938559], [8.934432990751668], [-122.29568082633352]]
J: 1.7870831468129498 Alpha: 0.0001 Theta: Matrix[[0.0001198094372380423], [8.93030066665316], [-122.29007459320218]]
J: 1.7870831468129498 Alpha: 0.0001 Theta: Matrix[[0.000118914183506699], [8.926168342554652], [-122.28446836007083]]
J: 1.7870831468129498 Alpha: 0.0001 Theta: Matrix[[0.00011801892977535571], [8.922036018456144], [-122.27886212693949]]
......
編集: シータがすべて 0 であるため、仮説の最初の反復は常に 0.5 を予測していることに気付きました。しかし、その後は常に1または0を予測します(コードに存在しない対数を避けるために0.00001または0.99999)。それは私には正しくないように思えます-あまりにも自信があります-そしておそらくこれが機能しない理由の鍵です.