[Python]Add legend on cluster visualization use Matplotlib

Posted by John on 2019-09-20
Words 777 and Reading Time 3 Minutes
Viewed Times

這篇主要在介紹如何正確在Matplotlib上為每種cluster加上圖例(legend)。

前言

資料視覺化在不管是機器學習/深度學習都是很重要的一個環節,有了資料視覺化能夠很快速的讓其他人了解你的資料/模型。 但是在呈現dataset的時候,資料往往維度不是2維,無法透過x, y軸呈現出來。這時候通常會透過降維的方法來把高維度的資料壓縮到低維度呈現,常見的降維方法有PCA和t-SNE。 所以,你以為這篇要來介紹這兩個方法背後的原理嗎?

5efa825b04849f9665c723e29793ac46

詳細的資料網路上都有很多,這邊丟幾篇給有興趣的自己去看:

而在其中t-SNE又因為某些特性,使得他在低維度的時候能比PCA有更好的視覺化效果,使用方始也很簡單,sklearn上就有支援,所以這篇也沒有打算教怎麼使用。 [透過Matplotlib呈現cluster結果] 當把資料降到低維度(2維或3維)後,就可以透過Matplotlib來呈現cluster。下面介紹兩種呈現的方式:

1.

相信有在學深度學習相關知識的人對下面這張圖都不陌生:

0001 對,這張圖是莫煩CNN教程中提供的t-SNE visualization code,如果去看code就會知道他是寫了一個迴圈,針對每一個point去做plt.text(),並同時計算該點的color

1
2
3
4
5
6
7
8
9
10
def plot_with_labels(lowDWeights, labels):
plt.cla()
X, Y = lowDWeights[:, 0], lowDWeights[:, 1]
for x, y, s in zip(X, Y, labels):
c = cm.rainbow(int(255 * s / 9)); plt.text(x, y, s, backgroundcolor=c, fontsize=9)
plt.xlim(X.min(), X.max());
plt.ylim(Y.min(), Y.max());
plt.title('Visualize last layer');
plt.show();
plt.pause(0.01)

2.

如果不想以文字呈現的話,也可以用散布圖(plt.scatter())來呈現,這時候用一行解決,把整個x, y array餵進去,而不用寫個迴圈去做:

X, Y = lowDWeights[:, 0], lowDWeights[:, 1]
plt.scatter(X, Y, c=labels, marker=marker, cmap=cmap)

常用matplotlib的人就會知道,以往要在圖片上加上圖例,都會在每次的繪圖function指定lebel,最後呼叫plt.legend():

1
2
3
4
plt.plot(......, label = '圖例1')
plt.plot(......, label = '圖例2')
...
plt.legend()

不過對於cluster而言,我們想要的應該是每一個cluster一個color,但又不想用迴圈一個一個點慢慢加,這時候可以使用matplotlib version (3.1.0)後提供的new method:

Now, PathCollection provides a method legend_elements() to obtain the handles and labels for a scatter plot in an automated way. This makes creating a legend for a scatter plot as easy as

使用方法如下:

  1. 把原本的plt.scatter用個變數接好
  2. 透過呼叫該變數的function legend_elements()取得總共有哪些labels數(matplotlib會自動算好)
  3. legend_elements()會回傳一個list,把list拆開餵到plt.legend()裡面

上面步驟看不懂沒關係,咱們直接看code,兩行而已:

1
2
3
4
5
# plot
scatter = plt.scatter(x_arr, y_arr, c=labels, marker=marker, cmap=cmap)

# produce a legend with the unique colors from the scatter
plt.legend(*scatter.legend_elements())

000.PNG 就john,結束!


>