Implementing the k-means algorithm in Scala


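I won't explain the concept of the algorithm in detail, since a quick Google search turns up plenty of explanations. Here is the code directly; it is fairly thoroughly commented.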

Main program:

import scala.annotation.tailrec
import scala.io.Source
import scala.util.Random

/**
 * @author vincent
 *
 */
object LocalKMeans {
    def main(args: Array[String]): Unit = {
        val fileName = "/home/vincent/kmeans_data.txt"
        val knumbers = 3
        val rand = new Random()

        //  Read the text data: one point per line, x and y separated by whitespace (tab or spaces)
        val source = Source.fromFile(fileName)
        val lines = try source.getLines().toVector finally source.close()
        val points = lines.filter(_.trim.nonEmpty).map(line => {
            val parts = line.trim.split("\\s+").map(_.toDouble)
            Point(parts(0), parts(1))
        })

        //  Randomly pick k distinct points as the initial centroids.
        //  Drawing each one independently could pick the same point twice,
        //  and two equal centroids would never separate.
        val centroids = rand.shuffle(points.distinct).take(knumbers)
        val startTime = System.currentTimeMillis()
        println("initialize centroids:\n" + centroids.mkString("\n") + "\n")
        println("test points: \n" + points.mkString("\n") + "\n")

        val resultCentroids = kmeans(points, centroids, 0.001)

        val endTime = System.currentTimeMillis()
        val runTime = endTime - startTime
        println("run Time: " + runTime + "\nFinal centroids: \n" + resultCentroids.mkString("\n"))
    }

    //  Core function of the algorithm
    @tailrec
    def kmeans(points: Seq[Point], centroids: Seq[Point], epsilon: Double): Seq[Point] = {
        //  Split the data set into clusters, keyed by each point's nearest centroid
        val clusters = points.groupBy(closestCentroid(centroids, _))
        println("clusters: \n" + clusters.mkString("\n") + "\n")
        //  The mean of each cluster becomes its new centroid; a centroid with no points stays where it is
        val newCentroids = centroids.map(oldCentroid => {
            clusters.get(oldCentroid) match {
                case Some(pointsInCluster) => pointsInCluster.reduceLeft(_ + _) / pointsInCluster.length
                case None => oldCentroid
            }
        })
        //  How far each new centroid moved from the old one
        val movement = (centroids zip newCentroids).map({ case (a, b) => a distance b })
        println("Centroids changed by\n" + movement.map(d => "%.3f".format(d)).mkString("(", ", ", ")")
            + "\nto\n" + newCentroids.mkString(", ") + "\n")
        //  Keep iterating while any centroid moved more than epsilon (the minimum movement)
        if (movement.exists(_ > epsilon))
            kmeans(points, newCentroids, epsilon)
        else
            newCentroids
    }

    //  Find the centroid nearest to the given point
    def closestCentroid(centroids: Seq[Point], point: Point): Point = {
        centroids.reduceLeft((a, b) => if ((point distance a) < (point distance b)) a else b)
    }
}
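For reference, one call of `kmeans` computes the following. The notation is mine, not the author's: $p_1,\dots,p_n$ are the data points, $\mu_1,\dots,\mu_k$ the current centroids, $\lVert\cdot\rVert$ the Euclidean length (`pointLength`) and $\varepsilon$ the `epsilon` argument. The three lines mirror the three commented steps in the code: assign, average, decide whether to recurse.

```latex
\begin{aligned}
C_j &= \{\, p_i : \mu_j = \arg\min_{l} \lVert p_i - \mu_l \rVert \,\}, \qquad j = 1,\dots,k \\
\mu_j' &= \begin{cases} \dfrac{1}{|C_j|} \displaystyle\sum_{p \in C_j} p, & C_j \neq \varnothing \\[4pt] \mu_j, & C_j = \varnothing \end{cases} \\
&\text{recurse with } \mu' \text{ while } \max_j \lVert \mu_j' - \mu_j \rVert > \varepsilon, \text{ otherwise return } \mu'
\end{aligned}
```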

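Drawing each initial centroid independently with `nextInt(points.length)` can return the same point twice. Because `Point` is a case class, two equal centroids are the same `groupBy` key: they share one cluster, receive the same mean on every iteration, and never separate, so the run effectively ends with fewer than k clusters. The sketch below reproduces this on four made-up points; the data and the inline mean computation are illustrative only.

```scala
// Minimal Point with only the distance this demo needs
case class Point(x: Double, y: Double) {
  def distance(that: Point): Double =
    math.sqrt((x - that.x) * (x - that.x) + (y - that.y) * (y - that.y))
}

// Same nearest-centroid rule as the article's closestCentroid
def closestCentroid(centroids: Seq[Point], point: Point): Point =
  centroids.reduceLeft((a, b) => if ((point distance a) < (point distance b)) a else b)

// Made-up data: two well-separated pairs on the x axis
val points = Seq(Point(0.0, 0.0), Point(1.0, 0.0), Point(10.0, 0.0), Point(11.0, 0.0))

// Suppose the random draw returned points(0) twice: k = 3 but only 2 distinct centroids
val centroids = Seq(points(0), points(0), points(2))

// Equal case-class centroids are the same groupBy key, so only 2 clusters form
val clusters = points.groupBy(closestCentroid(centroids, _))
println(clusters.size)   // 2

// Both copies look up the same cluster and get the same mean, so they stay identical
val newCentroids = centroids.map(c => clusters.get(c) match {
  case Some(ps) => Point(ps.map(_.x).sum / ps.size, ps.map(_.y).sum / ps.size)
  case None     => c
})
println(newCentroids)    // List(Point(0.5,0.0), Point(0.5,0.0), Point(10.5,0.0))
```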
 

Custom Point class:

import scala.util.Random

/**
 * @author vincent
 *
 */
object Point {
    //  A random point in the square [0, 50) x [0, 50)
    def random(): Point = Point(Random.nextDouble() * 50, Random.nextDouble() * 50)
}

case class Point(x: Double, y: Double) {
    def +(that: Point) = Point(this.x + that.x, this.y + that.y)
    def -(that: Point) = Point(this.x - that.x, this.y - that.y)
    def /(d: Double) = Point(this.x / d, this.y / d)
    //  Euclidean length of the vector (x, y)
    def pointLength = math.sqrt(x * x + y * y)
    //  Euclidean distance between two points
    def distance(that: Point) = (this - that).pointLength
    override def toString = "(%.3f, %.3f)".format(x, y)
}
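A quick check of the `Point` operators, using the same `reduceLeft(_ + _) / length` expression that `kmeans` uses to compute a cluster mean. The class is repeated in trimmed form (no `toString` override, no companion object) so the snippet runs on its own; the sample points are made up.

```scala
// Trimmed copy of the article's Point (toString override and companion omitted)
case class Point(x: Double, y: Double) {
  def +(that: Point) = Point(x + that.x, y + that.y)
  def -(that: Point) = Point(x - that.x, y - that.y)
  def /(d: Double) = Point(x / d, y / d)
  def pointLength = math.sqrt(x * x + y * y)
  def distance(that: Point) = (this - that).pointLength
}

val cluster = Seq(Point(0.0, 0.0), Point(4.0, 0.0), Point(2.0, 6.0))

// Cluster mean, computed as in kmeans: sum the points, then divide by the count
val mean = cluster.reduceLeft(_ + _) / cluster.length
println(mean)   // Point(2.0,2.0)

// distance is the Euclidean length of the difference vector
val d = Point(0.0, 0.0) distance Point(3.0, 4.0)
println(d)      // 5.0
```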

Test dataset (the contents of kmeans_data.txt; each line holds the x and y coordinates of one point):

12.044996    36.412378
31.881257    33.677009
41.703139    46.170517
43.244406    6.991669
19.319000    27.926669
3.556824    40.935215
29.328655    33.303675
43.702858    22.305344
28.978940    28.905725
10.426760    40.311507
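To try the algorithm without creating `/home/vincent/kmeans_data.txt`, the sketch below embeds the ten test points and repeats the same assign / average / stop loop. Two choices are my assumptions, not the author's: the centroids start at the first three points (so every run gives the same result), and a `maxIter` cap bounds the recursion. `KMeansDemo`, `run` and `closest` are hypothetical names.

```scala
import scala.annotation.tailrec

// Trimmed copy of the article's Point
case class Point(x: Double, y: Double) {
  def +(that: Point) = Point(x + that.x, y + that.y)
  def /(d: Double) = Point(x / d, y / d)
  def distance(that: Point) = math.sqrt((x - that.x) * (x - that.x) + (y - that.y) * (y - that.y))
}

object KMeansDemo {
  // The ten test points above, embedded so no data file is needed
  val points: Vector[Point] = Vector(
    Point(12.044996, 36.412378), Point(31.881257, 33.677009),
    Point(41.703139, 46.170517), Point(43.244406, 6.991669),
    Point(19.319000, 27.926669), Point(3.556824, 40.935215),
    Point(29.328655, 33.303675), Point(43.702858, 22.305344),
    Point(28.978940, 28.905725), Point(10.426760, 40.311507))

  // Nearest centroid by Euclidean distance
  def closest(centroids: Seq[Point], p: Point): Point = centroids.minBy(c => p distance c)

  // Same assign / average / stop rule as kmeans in the article, plus a
  // hypothetical maxIter cap so that a tiny epsilon cannot recurse forever
  @tailrec
  def run(centroids: Seq[Point], epsilon: Double, maxIter: Int): Seq[Point] = {
    val clusters = points.groupBy(closest(centroids, _))
    val updated = centroids.map(c => clusters.get(c) match {
      case Some(ps) => ps.reduceLeft(_ + _) / ps.length
      case None     => c
    })
    val moved = centroids.zip(updated).exists { case (a, b) => (a distance b) > epsilon }
    if (moved && maxIter > 1) run(updated, epsilon, maxIter - 1) else updated
  }

  // Assumption: start from the first three points instead of a random draw
  val result: Seq[Point] = run(points.take(3), 0.001, 100)
}

KMeansDemo.result.foreach(println)
```

With this start the loop stops on its second pass, and the isolated point (41.703139, 46.170517) keeps a cluster to itself, a reminder that k-means results depend on the starting centroids.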
