网站建设功能规划个人免费开发网站
#coding:utf-8
__author__ = 'jmh081701'
#本文件主要学习一种经典的聚类方法:k-means
#我们把这个算法用于一个RGB图像的聚类,看能出来的什么的效果
#k-means的原理:
'''
输入:x[1],x[2],x[3],...,x[n],其中每个x[i]都是m维的向量,给定聚类的数目k
1.随机生成k个代表元:z[1],z[2],...,z[k];每个z[i]都是第i类的中心元
2.repeat:更新 xi所述的类别ci,使得:|x[i]-z[ci]|最小更新 z[j],z[j]等于所在类别G[j]的所有样本的平均值
until:z不再改变
'''
import numpy as np
import math
import random
from PIL import Imagecnt=0
def calculate_zi(Gi,X):
#给定Gi,里面包含着属于这个类别的元素,然后计算这些元素的中心点
#在本实例中,Gi里面包含的是下标global cntsumi=np.zeros(len(X[0]))for each in Gi:cnt+=1sumi+=X[each]sumi/=(len(Gi)+0.000000001)zi=sumireturn zidef find_ci(xi,Z):#寻找离xi最近的中心元素ci,使得Z[ci]与xi之间的向量差的內积最小global cntdis_= np.inflen_=len(Z)rst_index = Nonefor i in range(len_):cnt+=1tmp_dist=np.dot(xi-Z[i],np.transpose(xi-Z[i]))if tmp_dist<dis_:rst_index=idis_=tmp_distreturn rst_indexdef k_mean(X,k):G=[] #G[i]={1,2,3...}表示属于第i类的样本在X中的索引,洗标Z=[] #Z[i] 第i类的中心点N=len(X)c=[] #c[i]=1,2,...,k;表示第i个样本属于第c[i]类tmpr=set()while len(Z)<k:r=random.randint(0,len(X)-1)if r not in tmpr:tmpr.add(r)Z.append(X[r])G.append(set())for i in range(N):c.append(0)#随机生成K个中心元素while True:group_flag=np.zeros(k)for i in range(N):new_ci = find_ci(X[i],Z)if c[i] != new_ci:#找到了更好的,把xi从原来的c[i]调到new_ci去,于是有两个组需要更新:new_ci,c[i]if i in G[c[i]]:G[c[i]].remove(i)group_flag[c[i]]=1 #把i从原来所属的组中移出来G[new_ci].add(i)group_flag[new_ci]=1 #把i加入到新的所属组去c[i]=new_ci#上面已经更新好了各元素的所属if np.sum(group_flag)==0:#没有组被修改breakfor i in range(k):if group_flag[i]==0:#未修改,无须重新计算continueelse:Z[i]=calculate_zi(list(G[i]),X)return Z,c,kdef test_rgb_img():filename=r"1.jpg"im = Image.open(filename)img = im.load()im.close()height = im.size[0]width= im.size[1]print(im.size)X=[]for i in range(0,height):for j in range(0,width):X.append(np.array(img[i,j]))Z,c,k=k_mean(X,8)#print(Z)new_im = Image.new("RGB",(height,width))for i in range(0,height):for j in range(0,width):index = i * width + jpix = list(Z[c[index]])for k in range(len(pix)):pix[k]=int(pix[k])new_im.putpixel((i,j),tuple(pix))new_im.show()
if __name__ == '__main__':test_rgb_img()print(cnt)
原图:
k=8的聚类结果:
k=4的聚类结果:
k=2:聚类结果
github地址:https://github.com/jmhIcoding/ml.git