SHAP を plot する

前回の続き

SHAP の dependence_plotscatter_plot は便利ですが、 詳しく見たい場合に shap_values に含まれる値でしか色分けできないのと、色分けの方法をあまり細かく指定できないようです。



そのようなことがしたかったので、shap_values を array として扱って matplotlib で scatter plot しました。

name = "水温"
plt.scatter(
X[name], # 横軸
shap_values[:,X.columns.get_loc(name)], # 縦軸
s=4,
alpha=0.6,
c=month, # 「月」で色分けする
cmap="hsv",
)
plt.xlabel(name)
plt.ylabel(f"SHAP value for {name}")
plt.grid(True)
plt.colorbar(label='月')

shap_values に入っていない"月"という値との dependency を見てみました。秋の水温20℃と春の水温20℃ではSHAP値が違うようだ、ということがわかりました。 

コード:https://github.com/aoda2/shap_example/blob/main/shap_%E3%82%A2%E3%82%B8.ipynb