使用 TensorFlow 學習 kNN 分類演算法

以數字手寫辨識實作為範例

Leo Chiu
7 min readAug 10, 2018
Photo by Toa Heftiba on Unsplash

最近鄰居法 ( k Nearest Neighbor, kNN) 是一個可以分類回歸分析資料的一種非參數 (nonparametric) 演算法,為什麼稱作非參數演算法?

參數演算法 V.S. 非參數演算法

兩者的差別在於是否針對問題假設 (hypothesis),舉一個例子,我們想要透過機器學習的演算法進行房價預測,所以我們假設「房價與屋齡、坪數、地點之間的關係為 y = wx + b」,透過機器學習的演算法學習參數 wb,進而預測房價,這稱作參數演算法

反之,非參數演算法即是不經由假設便可以學習任何的函式,同上,我們想要預測房價,以 kNN 這個演算法來說,透過求出最鄰近 k 個點的房價平均值,達成預測房價的目的。kNN 不需要任何的假設,直接藉由演算法求得房價,這種方法稱作非參數演算法

非參數演算法在過去的經驗很少時相當地有效,例如在美國賣場中有時會把尿布與啤酒放在一起,因為老爸在買尿布時想到晚上要看球賽,然後就會順手拿了一手的啤酒。在過去誰又會想到尿布加啤酒的組合呢?在沒有經驗的情況下,非參數演算法便能夠發揮它的作用,透過演算法找出這種神奇的組合。

kNN 可以說是機器學習中最簡單的演算法之一,因為它不需要複雜的計算就能夠實現。接下來,我們用 TensorFlow 與入門經常使用的手寫數字辨識範例,帶大家實現 kNN 的演算法。

kNN 分類

手寫數字辨識屬於一種分類 (Classification) 的問題,輸入一張 0 ~ 9 的影像,並預測該影像為 0 ~ 9 的哪一個數字。

首先,我們要先知道 kNN 分類運作的模式,kNN 的目標在於找出最鄰近的 k 個點,並透過「多數決」的方式決定該點屬於哪一類。以底下這張圖來說,假設我們要找出「?」屬於藍色或是紅色,於是先尋找鄰近的 5 個點,接著我們發現藍色與紅色點的數量比例為 3 : 2,所以經由「多數決」的結果可以判定「?」屬於藍色。

kNN 分類演算法實現數字手寫辨識

以手寫數字辨識來說,同樣地,我們想要知道一張影像屬於 0 ~ 9 哪一個數字,先求出該影像與其他影像的歐幾里得距離,得知最鄰近的 k 張影像後,再透過多數決判定該影像的分類。

計算距離有很多種,常見的像是這個範例使用的歐幾里得距離,還有曼哈頓距離餘弦相似度漢明距離,如果有興趣各種計算距離的方法可以參考這篇文章

引入所需相依套件、資料集

首先,引入我們需要的幾個套件,在這個範例中,使用的 TensorFlow 版本為 1.9,因為有使用到 tf.contrib 的函式,該函式庫的函式未來有可能會修改,所以請注意電腦中的 TensorFlow 版本是否已經不包含在這個範例中使用的函式。

我們使用的 Dataset 是 TensorFlow 提供的 MNIST,如果要使用其他來源的 MNIST 也可以,例如 kaggle 提供的 MNIST

定義訓練資料集與測試資料集

TensorFlow 提供的 MNIST 分成訓練資料 (testing data) 與測試資料 (training data),我們將測試資料做為欲分類的影像,目標即是在訓練資料中找出相近的影像。

原始的訓練資料與測試資料分別有 55000 與 10000 筆,如果將訓練資料 55000 筆資料全部作為 kNN 分類的資料將會導致運算時間過長,因為 kNN 是一個時間複雜度為 O(ndk) 的演算法,n 為欲分類的影像數量,d 為需要計算距離的影像數量,k 為一張影像中的像素個數。

在這個範例中,如果 n=10000、d=55000、k=764,則 big-O 相當地可觀,所以我們將 d 縮小成 10000,藉此減少 kNN 的計算量。

定義運算過程(建立計算圖)

首先,我們定義 X 為輸入訓練資料的節點,y 為輸入預測的影像的節點。因為一張手寫辨識的影像大小為 28 ✕ 28 = 784,所以節點的大小為 784。

placeholder 是 TensorFlow 中一個很有趣的用法,可以預先在神經網路中 佔位,而不用事先定義好輸入的資料,這也是 TensorFlow 中計算圖(Computing Graph)的概念,之後再使用 Session 將資料傳入計算圖中。

placeholder 中定義節點的形狀 shape=(None, num_inputs)None 的意思是二維矩陣的 row 事先尚未定義,也就是「不定義輸入資料的筆數」,隨著透過 Session 將資料傳入計算圖時,動態符合輸入資料的筆數。

在手寫數字辨識這個範例中,將影像的像素作為特徵,使用歐幾里得距離計算兩張影像之間的距離。

在計算完距離後,我們想要知道最鄰近的 7 張影像,所以使用 tf.contrib 中的函式排序張量,將結果進行切片 (slice),取得距離最接近的 7 個距離。

執行 kNN 演算法(運行計算圖)

訓練時,透過 Session 傳遞資料至計算圖中,別忘了記得要初始化所有的變數。top_k 能夠取得 7 個最接近的距離,所以我們可以利用 top_k 反查詢 dist 的索引值,而 dist 的索引值可以對應至 test_label 該影像的原始分類。

接著,我們要對 7 張最鄰近的影像進行投票,使用的是 Counter 這個資料結構, Counter 可以幫助我們計算在串列中數值出現的次數,並取得出現次數最多的影像類別,則該類別為 kNN 所預測的影像類別。我們將預測分類與真實分類都存入 prediction,為了在 kNN 執行完後計算辨識的準確率。

辨識結果

我們將 prediction 放入 DataFrame 中,方便我們計算準確率。最後可以得到準確率大約為 0.9391,是一個不錯的準確率,但在手寫數字辨識中還不算是最好的。如果你有興趣在手寫數字辨識這個問題中得到更好的結果,可以嘗試使用深度學習的模型,在這個網站中有許多方法在手寫數字辨識這個問題上得到 99 % 以上的準確率。

結論

最近鄰居法 ( kNN) 是一個機器學習中最簡單的演算法之一,在這篇文章中提供使用 TensorFlow 實現 kNN 的方法,經由一個簡單的範例,讓你了解 kNN 的運作。

在你執行完這個範例後,應該會發現執行時間非常地長,因為 kNN 的時間複雜度為 O(ndk),在特徵為高維度或是訓練資料過多時容易造成運行時間過長的問題。

如果你想嘗試更快更好的模型,而且對於運用類神經網路分類手寫數字有興趣,可以參考延伸閱讀。

閱讀資料

延伸閱讀

--

--

Leo Chiu

每天進步一點點,在終點遇見更好的自己。 Instragram 小帳:@leo.web.dev