幾張GIF理解K-均值聚類原理
k均值聚類數學推導與python實現
前文說了k均值聚類,他是基于中心的聚類方法,通過迭代將樣本分到k個類中,使每個樣本與其所屬類的中心或均值最近。
今天我們看一下無監督學習之聚類方法的另一種算法,層次聚類:
層次聚類前提假設類別直接存在層次關系,通過計算不同類別數據點間的相似度來創建一棵有層次的嵌套聚類樹。在聚類樹中,不同類別的原始數據點是樹的最低層,樹的頂層是一個聚類的根節點。創建聚類樹有聚合聚類(自下而上合并)和分裂聚類(自上而下分裂)兩種方法,分裂聚類一般很少使用,不做介紹。
聚合聚類
聚合聚類具體過程
對于給定的樣本集合,開始將每個樣本分到一個類,然后再按照一定的規則(比如類間距最?。瑢M足規則的類進行合并,反復進行,直到滿足停止條件。聚合聚類三要素有:
①距離或相似度(閔可夫斯基距離,相關系數、夾角余弦)
②合并規則(最長/短距離,中心距離,平均距離)
③停止條件(類個數或類直徑達到或超過閾值)
聚合聚類算法
輸入:n個樣本組成的樣本集合及樣本間距離
輸出:樣本集合的層次化聚類
(1)計算n個樣本兩兩之間歐氏距離{dij}
(2)構造n個類,每個類只包含一個樣本
(3)合并類間距最小的兩個類,構造一個新類
(4)計算新類與其他各類的距離,若類的個數為1,終止計算,否則回到(3)
動畫表示:
python實現及案例
import queue
import math
import copy
import numpy as np
import matplotlib.pyplot as plt
class clusterNode:
def __init__(self, value, id=[],left=None, right=None, distance=-1, count=-1, check = 0):
'''
value: 該節點的數值,合并節點時等于原來節點值的平均值
id:節點的id,包含該節點下的所有單個元素
left和right:合并得到該節點的兩個子節點
distance:兩個子節點的距離
count:該節點所包含的單個元素個數
check:標識符,用于遍歷時記錄該節點是否被遍歷過
'''
self.value = value
self.id = id
self.left = left
self.right = right
self.distance = distance
self.count = count
self.check = check
def show(self):
#顯示節點相關屬性
print(self.value,' ',self.left.id if self.left!=None else None,' ',/
self.right.id if self.right!=None else None,' ',self.distance,' ',self.count)
class hcluster:
def distance(self,x,y):
#計算兩個節點的距離,可以換成別的距離
return math.sqrt(pow((x.value-y.value),2))
def minDist(self,dataset):
#計算所有節點中距離最小的節點對
mindist = 1000
for i in range(len(dataset)-1):
if dataset[i].check == 1:
#略過合并過的節點
continue
for j in range(i+1,len(dataset)):
if dataset[j].check == 1:
continue
dist = self.distance(dataset[i],dataset[j])
if dist < mindist:
mindist = dist
x, y = i, j
return mindist, x, y
#返回最小距離、距離最小的兩個節點的索引
def fit(self,data):
dataset = [clusterNode(value=item,id=[(chr(ord('a')+i))],count=1) for i,item in enumerate(data)]
#將輸入的數據元素轉化成節點,并存入節點的列表
length = len(dataset)
Backup = copy.deepcopy(dataset)
#備份數據
while(True):
mindist, x, y = self.minDist(dataset)
dataset[x].check = 1
dataset[y].check = 1
tmpid = copy.deepcopy(dataset[x].id)
tmpid.extend(dataset[y].id)
dataset.append(clusterNode(value=(dataset[x].value+dataset[y].value)/2,id=tmpid,/
left=dataset[x],right=dataset[y],distance=mindist,count=dataset[x].count+dataset[y].count))
#生成新節點
if len(tmpid) == length:
#當新生成的節點已經包含所有元素時,退出循環,完成聚類
break
for item in dataset:
item.show()
return dataset
def show(self,dataset,num):
plt.figure(1)
showqueue = queue.Queue()
#存放節點信息的隊列
showqueue.put(dataset[len(dataset) - 1])
#存入根節點
showqueue.put(num)
#存入根節點的中心橫坐標
while not showqueue.empty():
index = showqueue.get()
#當前繪制的節點
i = showqueue.get()
#當前繪制節點中心的橫坐標
left = i - (index.count)/2
right = i + (index.count)/2
if index.left != None:
x = [left,right]
y = [index.distance,index.distance]
plt.plot(x,y)
x = [left,left]
y = [index.distance,index.left.distance]
plt.plot(x,y)
showqueue.put(index.left)
showqueue.put(left)
if index.right != None:
x = [right,right]
y = [index.distance,index.right.distance]
plt.plot(x,y)
showqueue.put(index.right)
showqueue.put(right)
plt.show()
def setData(num):
#生成num個隨機數據
Data = list(np.random.randint(1,100,size=num))
return Data
if name == '__main__':
num = 20
dataset = setData(num)
h = hcluster()
resultset = h.fit(dataset)
h.show(resultset,num)
本文由博客一文多發平臺 OpenWrite 發布!
審核編輯 黃昊宇
-
機器學習
+關注
關注
66文章
8428瀏覽量
132806 -
深度學習
+關注
關注
73文章
5508瀏覽量
121306
發布評論請先 登錄
相關推薦
評論