Python、機械学習

【Python】データ分析におけるyとpredict_yのプロット方法(実際の値と予測値の差を可視化しよう)

回帰問題において機械学習を行い、実際の目的変数yに対して、予測値predict_yが得られたとしましょう。

機械学習モデルの精度を評価するためにR2やRMSE、MAEといった指標を計算するかと思います。

それに加えて、yとpredict_yの差を目視で確認することが重要です。

ここでは、Pythonでyとpredict_yの関係をプロットする方法を説明します。

全てのプロットが収まるように散布図を作成する

まずは必要なライブラリをインポートします。

import numpy as np
from matplotlib import pyplot as plt

こちらに、目的変数yと、機械学習から予測した目的変数predicted_yを適当に準備しました。
このdfをprintで表示するとこのようになります。

y= np.array([90, 30, 12, 45, 60, 23, 11, 7, 76, 
51, 34, 42, 12, 94, 28, 99, 1, 55, 71, 33])

predict_y= np.array([93, 28, 10, 40, 63, 21, 12, 
5, 79, 48, 27, 54, 4, 99, 40, 85, 10, 35, 45, 38])

まずはplt.figure()でグラフエリアを確保し、yとpredict_yの散布図を描くコードを作ります。

次に、全プロットが収まるようにするため、プロットエリアの範囲を決めます。

# グラフエリアを設定し、散布図を描く。
plt.figure(figsize=(6, 6))
plt.scatter(y, predict_y)

# yの最大値、最小値を計算する。
y_max = np.max(y)
y_min = np.min(y)

# predict_yの最大値、最小値を計算する。
predict_y_max = np.max(predict_y)
predict_y_min = np.min(predict_y)

# 全てのプロットが収まるようにするには、yとpredict_y両方のうち
# 最も小さい値、最も大きい値を縦軸横軸の範囲にすればいい。
axis_max = max(y_max, predict_y_max)
axis_min = min(y_min, predict_y_min)

plt.xlim(axis_min, axis_max)
plt.ylim(axis_min, axis_max)

plt.xlabel('y')
plt.ylabel('predict_y')

plt.show()

ちょっと窮屈ですね・・・。

余白を持たせる

少し余白を持たせましょう。縦横ともに、プロットエリアの長さの5%の余白を両側に持たせたいと思います。

前述のPythonコードの「plt.xlim(axis_min, axis_max)」の手前に、次のような文を入れます。

# プロットエリアの長さは縦横ともにaxis_max-axis_min。
# これの5%の長さ分の余白を取る。
axis_max = axis_max + (axis_max-axis_min)*0.05
axis_min = axis_min + (axis_max-axis_min)*0.05

散布図に余裕ができ、見やすくなりました。

y = predict_yの直線を入れる

各プロットが理想状態からどのくらい乖離しているかを可視化するため、y = predict_yの直線を入れます。

この直線から遠いプロットはyとpredict_yの乖離が大きく、予測精度が悪いと言えます。逆も然りです。

前述のPythonコードの「plt.xlim(axis_min, axis_max)」の手前に、次のような文を入れましょう。

#y=predicted_yの直線を引く。
plt.plot([axis_min, axis_max], [axis_min, axis_max])

yとpredict_yの一致度合いが分かりやすくなりました。

あとはお好みで直線の色やスタイルを変えてもらえればと思います。

例えば、直線を黒にするなら’k’を書き足します。

# y=predicted_yの直線を引く。
plt.plot([axis_min, axis_max], [axis_min, axis_max], 'k')

点線にしたいなら、’–‘を書き足します。

# y=predicted_yの直線を引く
plt.plot([axis_min, axis_max], [axis_min, axis_max], '--')


ちなみに、’-‘は実線になります。plt.plot()のデフォルトが実線ですので、実線でいいなら、敢えて’-‘を書く必要はありません。

蛇足ですが、黒の点線にしたい時は’–k’とします。

データ分析の最後にぜひ使ってください。