Python圖像識別(聚類):用 K-Means 對圖像顏色做矢量量化


 1 # -*- coding: utf-8 -*-
 2 """
 3 Created on Fri Sep 21 15:37:26 2018
 4 
 5 @author: zhen
 6 """
 7 from PIL import Image
 8 import numpy as np
 9 from sklearn.cluster import KMeans
10 import matplotlib
11 import matplotlib.pyplot as plt
12 
def restore_image(cb, cluster, shape):
    """Rebuild an image by replacing every pixel's cluster label with its codebook colour.

    Parameters
    ----------
    cb : array-like, shape (n_clusters, channels) or (n_clusters,)
        Codebook; entry ``k`` is the colour used for label ``k``
        (e.g. ``KMeans.cluster_centers_``).  A 1-D codebook puts the same
        value in every channel.
    cluster : array-like of int
        One label per pixel in row-major order; only the first
        ``row * col`` entries are used.
    shape : tuple (row, col, channels)
        Shape of the image to rebuild.

    Returns
    -------
    numpy.ndarray
        float64 array of shape ``shape``.
    """
    row, col, dummy = shape
    labels = np.asarray(cluster)[: row * col]
    centers = np.asarray(cb)
    if centers.ndim == 1:
        # One scalar per label: give it a channel axis so it broadcasts
        # across every channel.
        centers = centers[:, np.newaxis]
    image = np.empty((row, col, dummy))
    # One vectorised fancy-indexing step replaces a Python loop over every
    # pixel; assigning into the float64 buffer broadcasts exactly as a
    # per-pixel ``image[r, c] = cb[k]`` would.
    image[...] = centers[labels].reshape((row, col) + centers.shape[1:])
    return image
20 
def show_scatter(a, bins=10):
    """Plot how often each colour occurs and show the figures.

    Two figures are drawn, then displayed with ``plt.show()``:

    * figure 1: a 3-D scatter over a ``bins x bins x bins`` colour
      histogram.  Each marker's area is proportional to its bin count,
      normalised so the fullest bin has size 100.
    * figure 2: the non-empty bin counts, sorted in descending order.

    Parameters
    ----------
    a : array-like, shape (n_pixels, 3)
        Pixel colours with every channel scaled to [0, 1]; values outside
        that range are not counted.
    bins : int, optional
        Number of histogram bins per channel (default 10).
    """
    # Importing mplot3d registers the '3d' projection; matplotlib < 3.2
    # fails with "Unknown projection '3d'" without it.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

    N = bins
    density, _ = np.histogramdd(a, bins=[N, N, N], range=[(0, 1), (0, 1), (0, 1)])
    peak = density.max()
    if peak > 0:  # an empty input gives an all-zero histogram; avoid 0/0 = NaN
        density /= peak
    # indexing='ij' makes d[0][i, j, k] == i, d[1][i, j, k] == j and
    # d[2][i, j, k] == k, matching density[i, j, k].  The default 'xy'
    # indexing swaps the first two axes and would draw every count at the
    # wrong (channel 0, channel 1) position.
    d = np.meshgrid(np.arange(N), np.arange(N), np.arange(N), indexing='ij')

    fig = plt.figure(1, facecolor='w')
    ax = fig.add_subplot(111, projection='3d')
    # Marker area encodes the bin count.  No cmap is passed: without `c`
    # values matplotlib would ignore it anyway.
    ax.scatter(d[0], d[1], d[2], s=100 * density, marker='o', depthshade=True)
    ax.set_xlabel(u'')
    ax.set_ylabel(u'')
    ax.set_zlabel(u'')
    plt.title(u'圖像顏色三維頻數分布', fontsize=20)

    plt.figure(2, facecolor='w')
    # Non-empty bins only, most frequent first.
    den = np.sort(density[density > 0])[::-1]
    t = np.arange(len(den))
    plt.plot(t, den, 'r-', t, den, 'go', lw=2)
    plt.title(u'圖像顏色頻數分布', fontsize=18)
    plt.grid(True)

    plt.show()
47       
if __name__ == '__main__':
    import sys

    # Render the Chinese titles with a CJK font, and draw minus signs as an
    # ASCII hyphen (SimHei typically lacks the Unicode minus glyph).
    matplotlib.rcParams['font.sans-serif'] = [u'SimHei']
    matplotlib.rcParams['axes.unicode_minus'] = False

    # Number of clusters (= colours in the quantised image); the article
    # compares 2, 6 and 30.
    num_vq = 2

    # Image path: optional first command-line argument, otherwise the
    # author's original location.
    path = sys.argv[1] if len(sys.argv) > 1 else 'C:/Users/zhen/.spyder-py3/images/Lena.png'

    # `with` closes the file handle.  convert('RGB') copes with RGBA (alpha is
    # dropped, as the old [:, :, :3] slice did), palette and greyscale images.
    # The builtin float replaces np.float, which NumPy 1.24 removed.
    with Image.open(path) as im:
        image = np.array(im.convert('RGB')).astype(float) / 255
    image_v = image.reshape((-1, 3))

    # n_init=10 pins the historical default; scikit-learn 1.4 switched the
    # default to 'auto' (a single run for k-means++ initialisation).
    kmeans = KMeans(n_clusters=num_vq, init='k-means++', n_init=10)
    show_scatter(image_v)

    N = image_v.shape[0]  # total number of pixels
    # Fit the centres on a random 70% sample of pixels (drawn with
    # replacement) to save time, then label every pixel.
    idx = np.random.randint(0, N, size=int(N * 0.7))
    image_sample = image_v[idx]
    kmeans.fit(image_sample)
    result = kmeans.predict(image_v)  # cluster label of every pixel
    print('聚類結果:\n', result)
    print('聚類中心:\n', kmeans.cluster_centers_)

    plt.figure(figsize=(15, 8), facecolor='w')
    plt.subplot(211)
    plt.axis('off')
    plt.title(u'原始圖片', fontsize=18)
    plt.imshow(image)

    plt.subplot(212)
    vq_image = restore_image(kmeans.cluster_centers_, result, image.shape)
    plt.axis('off')
    plt.title(u'聚類個數:%d' % num_vq, fontsize=20)
    plt.imshow(vq_image)

    # Pass pad by keyword: positional use was deprecated and is rejected by
    # current matplotlib, where tight_layout's parameters are keyword-only.
    plt.tight_layout(pad=1.2)
    plt.show()

結果:

      

  1.當k=2時:

  

       

  2.當k=6時:

    

        

  3.當k=30時:

    

       

總結:聚類個數 k 較少時,算法運算速度快,但還原出的圖像顏色失真明顯;k 較多時,還原效果更接近原圖,但運算速度慢,而且顏色碼本變大、壓縮效果變差(即文中所說的「過擬合」:k 越大,結果越接近直接保存原圖,量化也就失去了意義)。因此,選擇恰當的 k 值對聚類效果的影響非常明顯。

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯繫本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM