線形モデルの SHAP 値を計算したいと考えています。回帰には、サンプルの重みを使用する必要があります。
問題は、SHAP 値を計算する際に、サンプルの重みが実際に適切に適用されたかどうかを評価できないことです。
ここに例があります。
# Import libraries
import shap
import pandas
import numpy
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
# Setting up the data and the model
df.head()
weights Funnel Q23_1 Q23_2 Q23_3 Q23_4 Q23_5 Q23_6 Q23_7
Q23_8 Q23_9 Q23_10 Q23_11 Q23_12 Q23_13 Q23_14 Q23_15
847 0.75149 5.0 2.0 2.0 1.0 3.0 3.0 3.0 2.0 5.0 3.0 1.0 2.0 2.0 2.0 3.0 1.0
995 2.18378 2.0 1.0 1.0 1.0 3.0 3.0 3.0 1.0 4.0 2.0 2.0 1.0 2.0 2.0 2.0 2.0
14403 1.10852 2.0 1.0 1.0 1.0 2.0 2.0 4.0 1.0 5.0 1.0 2.0 2.0 1.0 3.0 3.0
1.0
13311 0.85934 4.0 2.0 2.0 3.0 3.0 2.0 3.0 3.0 4.0 4.0 3.0 2.0 2.0 3.0 3.0
2.0
17019 0.95337 2.0 1.0 1.0 2.0 3.0 2.0 2.0 3.0 2.0 2.0 2.0 3.0 1.0 1.0 1.0
2.0
Y = df_t.drop(['Funnel', 'weights'], axis=1)
X = df_t[['Funnel']]
lm = LinearRegression()
まず、重みなしで回帰を計算します。
fit = lm.fit(X,Y)
pred = fit.predict(X)
print("R2 - No Weights:", r2_score(Y,pred))
次に、重みを使用して回帰を計算します。
fit = lm.fit(X,Y, sample_weight=df['weights'])
pred = fit.predict(X)
print("R2 - Wit weights:", r2_score(Y, pred2, sample_weight=df['weights']))
これまでにわかったことから (結果を評価するために R や SPSS などの他のソフトウェア パッケージを使用してさまざまな組み合わせをテストしました)、正しい結果を得るにはfit()
関数と関数に重みを適用する必要がありますr2_score()
(上記の例を参照)。 . たとえば、重みをfit()
関数にのみ適用し、関数には適用しないr2_score()
場合、レポートされる R2 値は間違っています (つまり、モデルが間違っています)。関数にも重みを適用するとpredict()
、R2 値も間違っています (つまり、モデルが間違っています)。
fit = lm.fit(X,Y, sample_weight=df['weights'])
pred = fit.predict(X, sample_weight=df['weights'])
print("R2 - With something in between:", r2_score(Y, pred, sample_weight=df['weights']))
ただし、SHAP 値は Python でしか計算できないため、結果を評価する方法がありません。問題は、SHAP 値を正しく計算するためにサンプルの重みをどのように適用すればよいかということです。
フィット関数のみ (?):
fit = lm.fit(X,Y, sample_weight=df['weights'])
explainer = shap.LinearExplainer(fit, X, feature_dependence = 'independent')
shap_values = explainer.shap_values(X)
または、explainer()
関数内 (?):
fit = lm.fit(X,Y, sample_weight=df['weights'])
explainer = shap.LinearExplainer(fit, X, feature_dependence = 'independent',
sample_weight=df['weights'])
shap_values = explainer.shap_values(X)
他の可能性もあるかもしれませんが、どれが正しいかわかりません。
これは小さなデータサンプルです。
print(df.to_dict())
{'weights': {847: 0.75149, 995: 2.18378, 14403: 1.10852, 13311: 0.85934, 17019: 0.95337, 23707: 0.8899, 29562: 0.96819, 30627: 1.16261, 15187: 1.15915, 24179: 1.09833}, 'Funnel': {847: 5.0, 995: 2.0, 14403: 2.0, 13311: 4.0, 17019: 2.0, 23707: 2.0, 29562: 2.0, 30627: 4.0, 15187: 4.0, 24179: 5.0}, 'Q23_1': {847: 2.0, 995: 1.0, 14403: 1.0, 13311: 2.0, 17019: 1.0, 23707: 3.0, 29562: 1.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_2': {847: 2.0, 995: 1.0, 14403: 1.0, 13311: 2.0, 17019: 1.0, 23707: 2.0, 29562: 2.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_3': {847: 1.0, 995: 1.0, 14403: 1.0, 13311: 3.0, 17019: 2.0, 23707: 3.0, 29562: 2.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_4': {847: 3.0, 995: 3.0, 14403: 2.0, 13311: 3.0, 17019: 3.0, 23707: 3.0, 29562: 1.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_5': {847: 3.0, 995: 3.0, 14403: 2.0, 13311: 2.0, 17019: 2.0, 23707: 2.0, 29562: 3.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_6': {847: 3.0, 995: 3.0, 14403: 4.0, 13311: 3.0, 17019: 2.0, 23707: 4.0, 29562: 3.0, 30627: 1.0, 15187: 5.0, 24179: 2.0}, 'Q23_7': {847: 2.0, 995: 1.0, 14403: 1.0, 13311: 3.0, 17019: 3.0, 23707: 3.0, 29562: 2.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_8': {847: 5.0, 995: 4.0, 14403: 5.0, 13311: 4.0, 17019: 2.0, 23707: 4.0, 29562: 1.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_9': {847: 3.0, 995: 2.0, 14403: 1.0, 13311: 4.0, 17019: 2.0, 23707: 2.0, 29562: 1.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_10': {847: 1.0, 995: 2.0, 14403: 2.0, 13311: 3.0, 17019: 2.0, 23707: 2.0, 29562: 3.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_11': {847: 2.0, 995: 1.0, 14403: 2.0, 13311: 2.0, 17019: 3.0, 23707: 3.0, 29562: 2.0, 30627: 1.0, 15187: 2.0, 24179: 1.0}, 'Q23_12': {847: 2.0, 995: 2.0, 14403: 1.0, 13311: 2.0, 17019: 1.0, 23707: 2.0, 29562: 2.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_13': {847: 2.0, 995: 2.0, 14403: 3.0, 13311: 3.0, 17019: 1.0, 23707: 2.0, 29562: 4.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_14': {847: 3.0, 995: 2.0, 14403: 3.0, 13311: 3.0, 17019: 1.0, 23707: 3.0, 29562: 1.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}, 'Q23_15': {847: 1.0, 995: 2.0, 14403: 1.0, 13311: 2.0, 17019: 2.0, 23707: 3.0, 29562: 1.0, 30627: 1.0, 15187: 5.0, 24179: 1.0}}