使用 TensorFlow 學習 kNN 分類演算法

Photo by Toa Heftiba on Unsplash

最近鄰居法 ( k Nearest Neighbor, kNN) 是一個可以分類回歸分析資料的一種非參數 (nonparametric) 演算法,為什麼稱作非參數演算法?

參數演算法 V.S. 非參數演算法

兩者的差別在於是否針對問題假設 (hypothesis),舉一個例子,我們想要透過機器學習的演算法進行房價預測,所以我們假設「房價與屋齡、坪數、地點之間的關係為 y = wx + b」,透過機器學習的演算法學習參數 wb,進而預測房價,這稱作參數演算法

反之,非參數演算法即是不經由假設便可以學習任何的函式,同上,我們想要預測房價,以 kNN 這個演算法來說,透過求出最鄰近 k 個點的房價平均值,達成預測房價的目的。kNN 不需要任何的假設,直接藉由演算法求得房價,這種方法稱作非參數演算法

非參數演算法在過去的經驗很少時相當地有效,例如在美國賣場中有時會把尿布與啤酒放在一起,因為老爸在買尿布時想到晚上要看球賽,然後就會順手拿了一手的啤酒。在過去誰又會想到尿布加啤酒的組合呢?在沒有經驗的情況下,非參數演算法便能夠發揮它的作用,透過演算法找出這種神奇的組合。

kNN 可以說是機器學習中最簡單的演算法之一,因為它不需要複雜的計算就能夠實現。接下來,我們用 TensorFlow 與入門經常使用的手寫數字辨識範例,帶大家實現 kNN 的演算法。

kNN 分類

手寫數字辨識屬於一種分類 (Classification) 的問題,輸入一張 0 ~ 9 的影像,並預測該影像為 0 ~ 9 的哪一個數字。

首先,我們要先知道 kNN 分類運作的模式,kNN 的目標在於找出最鄰近的 k 個點,並透過「多數決」的方式決定該點屬於哪一類。以底下這張圖來說,假設我們要找出「?」屬於藍色或是紅色,於是先尋找鄰近的 5 個點,接著我們發現藍色與紅色點的數量比例為 3 : 2,所以經由「多數決」的結果可以判定「?」屬於藍色。

kNN 分類演算法實現數字手寫辨識

以手寫數字辨識來說,同樣地,我們想要知道一張影像屬於 0 ~ 9 哪一個數字,先求出該影像與其他影像的歐幾里得距離,得知最鄰近的 k 張影像後,再透過多數決判定該影像的分類。

計算距離有很多種,常見的像是這個範例使用的歐幾里得距離,還有曼哈頓距離餘弦相似度漢明距離,如果有興趣各種計算距離的方法可以參考這篇文章

引入所需相依套件、資料集

首先,引入我們需要的幾個套件,在這個範例中,使用的 TensorFlow 版本為 1.9,因為有使用到 tf.contrib 的函式,該函式庫的函式未來有可能會修改,所以請注意電腦中的 TensorFlow 版本是否已經不包含在這個範例中使用的函式。

我們使用的 Dataset 是 TensorFlow 提供的 MNIST,如果要使用其他來源的 MNIST 也可以,例如 kaggle 提供的 MNIST

定義訓練資料集與測試資料集

TensorFlow 提供的 MNIST 分成訓練資料 (testing data) 與測試資料 (training data),我們將測試資料做為欲分類的影像,目標即是在訓練資料中找出相近的影像。

原始的訓練資料與測試資料分別有 55000 與 10000 筆,如果將訓練資料 55000 筆資料全部作為 kNN 分類的資料將會導致運算時間過長,因為 kNN 是一個時間複雜度為 O(ndk) 的演算法,n 為欲分類的影像數量,d 為需要計算距離的影像數量,k 為一張影像中的像素個數。

在這個範例中,如果 n=10000、d=55000、k=764,則 big-O 相當地可觀,所以我們將 d 縮小成 10000,藉此減少 kNN 的計算量。

定義運算過程(建立計算圖)

首先,我們定義 X 為輸入訓練資料的節點,y 為輸入預測的影像的節點。因為一張手寫辨識的影像大小為 28 ✕ 28 = 784,所以節點的大小為 784。

placeholder 是 TensorFlow 中一個很有趣的用法,可以預先在神經網路中 佔位,而不用事先定義好輸入的資料,這也是 TensorFlow 中計算圖(Computing Graph)的概念,之後再使用 Session 將資料傳入計算圖中。

placeholder 中定義節點的形狀 shape=(None, num_inputs)None 的意思是二維矩陣的 row 事先尚未定義,也就是「不定義輸入資料的筆數」,隨著透過 Session 將資料傳入計算圖時,動態符合輸入資料的筆數。

在手寫數字辨識這個範例中,將影像的像素作為特徵,使用歐幾里得距離計算兩張影像之間的距離。

在計算完距離後,我們想要知道最鄰近的 7 張影像,所以使用 tf.contrib 中的函式排序張量,將結果進行切片 (slice),取得距離最接近的 7 個距離。

執行 kNN 演算法(運行計算圖)

訓練時,透過 Session 傳遞資料至計算圖中,別忘了記得要初始化所有的變數。top_k 能夠取得 7 個最接近的距離,所以我們可以利用 top_k 反查詢 dist 的索引值,而 dist 的索引值可以對應至 test_label 該影像的原始分類。

接著,我們要對 7 張最鄰近的影像進行投票,使用的是 Counter 這個資料結構, Counter 可以幫助我們計算在串列中數值出現的次數,並取得出現次數最多的影像類別,則該類別為 kNN 所預測的影像類別。我們將預測分類與真實分類都存入 prediction,為了在 kNN 執行完後計算辨識的準確率。

辨識結果

我們將 prediction 放入 DataFrame 中,方便我們計算準確率。最後可以得到準確率大約為 0.9391,是一個不錯的準確率,但在手寫數字辨識中還不算是最好的。如果你有興趣在手寫數字辨識這個問題中得到更好的結果,可以嘗試使用深度學習的模型,在這個網站中有許多方法在手寫數字辨識這個問題上得到 99 % 以上的準確率。

結論

最近鄰居法 ( kNN) 是一個機器學習中最簡單的演算法之一,在這篇文章中提供使用 TensorFlow 實現 kNN 的方法,經由一個簡單的範例,讓你了解 kNN 的運作。

在你執行完這個範例後,應該會發現執行時間非常地長,因為 kNN 的時間複雜度為 O(ndk),在特徵為高維度或是訓練資料過多時容易造成運行時間過長的問題。

如果你想嘗試更快更好的模型,而且對於運用類神經網路分類手寫數字有興趣,可以參考延伸閱讀。

閱讀資料

延伸閱讀

--

--

--

每天進步一點點,在終點遇見更好的自己。

Love podcasts or audiobooks? Learn on the go with our new app.

Get the Medium app

A button that says 'Download on the App Store', and if clicked it will lead you to the iOS App store
A button that says 'Get it on, Google Play', and if clicked it will lead you to the Google Play store
Airwaves

Airwaves

每天進步一點點,在終點遇見更好的自己。

More from Medium

Quoted Insurance Plan Prediction

Predictor Importance for Bagged Trees in Classification Learner App??

Handwritten digit recognition on MNIST dataset using python

#100DaysOfMLCode