Python、機械学習

【Python】複数のグラフを並べて表示する方法(plt.figure()、plt.subplot()、fig, ax = plt.subplots() )

Pythonでデータ分析をするとき、複数のグラフを並べて表示することがよくあります。

Kaggleのkernelsを見ても、多くの人が複数のグラフを並べて考察をしています。

ここでは、グラフの並べ方を説明します。大きく、2通りの書き方があります。好きな方を使っていただければと思います。

方法1:plt.figure()とplt.subplot()のセット

まずは、グラフを配置するエリアを確保する:plt.figure()

まずは、グラフを配置するエリアを確保します。plt.figure()を使えばOKです。

plt.figure(figsize=(6, 4))
plt.show()

<Figure size 432×288 with 0 Axes>

figsize=()で、エリアの広さを指定します。デフォルトがfigsize=(6, 4)ですので、何も書かない場合は(6, 4)のエリアになります。

次に、エリアの中にグラフスペースを作る

次に、エリアの中にグラフスペースを作ります。plt.subplot()を使えばOKです。

plt.figure(figsize=(6, 4))

plt.subplot(1, 1, 1)
plt.plot([1,2], [2,1])

plt.show()

plt.subplot(1, 1, 1)で、エリア中にグラフスペースを1×1個置き、そのうちの1番目(左上)のグラフスペースを指定している状態になります。

その後、plt.plot()のようにグラフ描画のコードを描けばOKです。

plt.subplot(1, 1, 1)ですと、グラフスペースが1個だけです。1個だけでいいなら、次のような書き方でもOKです。

plt.plot([1,2], [2,1])
plt.show()

グラフを2つ以上並べる場合

グラフを2つ並べる場合は、このように書きます。

plt.figure(figsize=(6, 4))

plt.subplot(1,2,1)
plt.plot([1,2], [2,1])

plt.subplot(1,2,2)
plt.hist([1, 2, 3, 3, 4, 5, 3, 4, 6, 5, 6, 8, 5, 2,1], bins=5)
plt.show()

エリア中にグラフエリアを1×2個置き、左上から数えて1番目のところがplt.sublot(1, 2, 1)、左上から数えて2番目のところがplt.subplot(1, 2, 2)になります。

方法2:fig, ax = plt.subplots()

この書き方もよく見かけます。

まず最初にfig, ax = plt.subplots()と書き、行方向、列方向にグラフを何個ずつ配置するのか、サイズはどうするかをカッコ内に記入します。

その後、plt.subplots_adjust()で幅方向、高さ方向のグラフ間距離を定めます。これを書かないと窮屈な感じになるので、設定した方がいいです。

そして、ax[行方向、列方向]としてグラフエリアを指定しつつ、どんなプロットにするか、タイトルはどうするか等を紐づけていきます。

fig, ax = plt.subplots(2, 2, figsize=(6, 4))
plt.subplots_adjust(wspace=0.4, hspace=0.6) # デフォルトは共に 0.2

ax[0, 0].plot([1,2], [2,1])
ax[0, 0].set_title('lineplot')
ax[0, 0].set_xlabel('x')
ax[0, 0].set_ylabel('y')

ax[0, 1].hist([1, 2, 3, 3, 4, 5, 3, 4, 6, 5, 6, 8, 5, 2,1], 
bins=5)
ax[0, 1].set_title('histplot')
ax[0, 1].set_xlabel('x')
ax[0, 1].set_ylabel('y')

ax[1, 0].plot([1,2], [2,1], color='red')
ax[1, 0].set_title('lineplot')
ax[1, 0].set_xlabel('x')
ax[1, 0].set_ylabel('y')

ax[1, 1].hist([1, 2, 3, 3, 4, 5, 3, 4, 6, 5, 6, 8, 5, 2,1], 
bins=5, color='red')
ax[1, 1].set_title('histplot')
ax[1, 1].set_xlabel('x')
ax[1, 1].set_ylabel('y')

注意点!横一線、または縦一線の並べ方のとき!

fig, ax = plt.subplots(1, 2)やfig, ax = plt.subplots(2, 1)のように、どちらかが1の時は要注意です。

ax[]の角カッコ内の数字は1つだけになります!2つ書いてしまうとエラーとなります。

fig, ax = plt.subplots(1, 2, figsize=(6, 4))
plt.subplots_adjust(wspace=0.4, hspace=0.6)

ax[0].plot([1,2], [2,1])
ax[0].set_title('lineplot')
ax[0].set_xlabel('x')
ax[0].set_ylabel('y')

ax[1].hist([1, 2, 3, 3, 4, 5, 3, 4, 6, 5, 6, 8, 5, 2,1], bins=5)
ax[1].set_title('histplot')
ax[1].set_xlabel('x')
ax[1].set_ylabel('y')
fig, ax = plt.subplots(2, 1, figsize=(6, 4))
plt.subplots_adjust(wspace=0.4, hspace=0.6)

ax[0].plot([1,2], [2,1])
ax[0].set_title('lineplot')
ax[0].set_xlabel('x')
ax[0].set_ylabel('y')

ax[1].hist([1, 2, 3, 3, 4, 5, 3, 4, 6, 5, 6, 8, 5, 2,1], bins=5)
ax[1].set_title('histplot')
ax[1].set_xlabel('x')
ax[1].set_ylabel('y')