I. SMOTE theory
(1).
SMOTE is a refinement of plain oversampling. Plain oversampling leaves the training set with many duplicated samples.
SMOTE stands for Synthetic Minority Over-sampling Technique.
Rather than directly resampling the minority class, SMOTE uses an algorithm to synthesize new minority-class samples.
For convenience, assume the positive class is the minority and the negative class the majority.
The algorithm for synthesizing a new positive sample is as follows (a small code sketch of the interpolation step appears after this quoted passage):
- Pick a positive sample s.
- Find the k samples nearest to s, where k can be 5, 10, or similar. These k neighbours may include both positive and negative samples.
- Randomly pick one of these k samples; call it r.
- Synthesize a new positive sample s′ as s′ = λs + (1 − λ)r, where λ is a random number in (0, 1). In other words, the new point lies on the line segment between r and s.
Repeating these steps generates as many positive samples as needed.
======= Update: added a few figures =======
The SMOTE steps, as illustrated by the figures in the original post:
1. Pick a positive sample (assuming the positive class is the minority).
2. Find the k nearest neighbours of that positive sample (here k = 5); in the figure the 5 neighbours are circled.
3. Randomly pick one of those k neighbours (circled in green in the figure).
4. Pick a random point on the line segment between the positive sample and the chosen neighbour; that point is the newly synthesized positive sample (marked with a green plus sign).
The above description is from http://sofasofa.io/forum_main_post.php?postid=1000817
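To make the synthesis step concrete, here is a minimal Java sketch of the interpolation described above; the class and method names are my own, not from the quoted post:

```java
import java.util.Random;

/**
 * A minimal sketch (not the quoted post's code) of the SMOTE synthesis step:
 * given a minority sample s and a chosen neighbour r, produce
 * s' = λ·s + (1 − λ)·r for a random λ in (0, 1).
 */
public class SmoteInterpolation {

    public static double[] synthesize(double[] s, double[] r, Random rng) {
        double lambda = rng.nextDouble(); // λ drawn uniformly from [0, 1)
        double[] synthetic = new double[s.length];
        for (int i = 0; i < s.length; i++) {
            // the new point lies on the line segment between s and r
            synthetic[i] = lambda * s[i] + (1.0 - lambda) * r[i];
        }
        return synthetic;
    }
}
```

Calling synthesize(s, r, new Random()) returns a point drawn uniformly from the segment between s and r.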
(2).
With this approach, the positive class is over-sampled by taking each minority class sample and introducing synthetic examples along the line segments joining any/all of the k minority class nearest neighbours. Depending upon the amount of over-sampling required, neighbours from the k nearest neighbours are randomly chosen. This process is illustrated in the following Figure, where x_i is the selected point, x_{i1} to x_{i4} are some selected nearest neighbours, and r_1 to r_4 are the synthetic data points created by the randomized interpolation. The implementation of this work uses only one nearest neighbour with the Euclidean distance, and balances both classes to a 50% distribution.
Synthetic samples are generated in the following way: Take the difference between the feature vector (sample) under consideration and its nearest neighbour. Multiply this difference by a random number between 0 and 1, and add it to the feature vector under consideration. This causes the selection of a random point along the line segment between two specific features. This approach effectively forces the decision region of the minority class to become more general. An example is detailed in the next Figure.
In short, the main idea is to form new minority class examples by interpolating between several minority class examples that lie together. In contrast with common replication techniques (for example, random oversampling), in which the decision region usually becomes more specific, SMOTE mitigates the overfitting problem by causing the decision boundaries for the minority class to become larger and to spread further into the majority class space, since it provides related minority class samples to learn from. In addition, selecting a small k value can reduce the risk of including noise in the data.
The above description is from https://sci2s.ugr.es/multi-imbalanced
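As a quick numeric illustration of the interpolation just described (the numbers are mine, not from the source): take x = (1, 2) with nearest neighbour x_i = (3, 6). The difference is x_i − x = (2, 4); with a random factor of 0.25, the synthetic sample is x + 0.25 · (2, 4) = (1.5, 3), i.e. a point one quarter of the way along the segment from x to x_i.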
II. Implementing SMOTE in Spark
The core code follows; the complete code is at https://github.com/jiangnanboy/spark-smote/blob/master/spark%20smote.txt
```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;

import scala.Tuple2;

import static org.apache.spark.sql.functions.count;
import static org.apache.spark.sql.functions.max;

/**
 * (1) For each sample x in the minority class X, compute its distance to all samples
 *     in X to obtain its k nearest neighbours.
 * (2) Set a sampling ratio according to the class imbalance to obtain the sampling
 *     rate sampling_rate. For each minority sample x, randomly pick sampling_rate of
 *     its k nearest neighbours, denoted x(1), x(2), ..., x(sampling_rate).
 * (3) For each chosen neighbour x(i) (i = 1, 2, ..., sampling_rate), build a new
 *     sample from the original one as
 *         xnew = x + rand(0,1) * (x(i) - x)
 *
 * http://sofasofa.io/forum_main_post.php?postid=1000817
 * http://sci2s.ugr.es/multi-imbalanced
 * @param session
 * @param labelFeatures
 * @param knn          number of nearest neighbours to consider for each sample
 * @param samplingRate neighbour sampling rate: (knn * samplingRate) neighbours are drawn from the k nearest
 * @param rationToMax  target size as a ratio of the largest class: 0.1 means each class is topped up to 1:10 of the largest class
 * @return
 */
public static Dataset<Row> smote(SparkSession session, Dataset<Row> labelFeatures, int knn, double samplingRate, double rationToMax) {

    // count the samples of each label
    Dataset<Row> labelCountDataset = labelFeatures.groupBy("label").agg(count("label").as("keyCount"));
    List<Row> listRow = labelCountDataset.collectAsList();
    ConcurrentMap<String, Long> keyCountConMap = new ConcurrentHashMap<>(); // sample count per label
    for (Row row : listRow)
        keyCountConMap.put(row.getString(0), row.getLong(1));
    Row maxSizeRow = labelCountDataset.select(max("keyCount").as("maxSize")).first();
    long maxSize = maxSizeRow.getAs("maxSize"); // size of the largest class

    JavaPairRDD<String, SparseVector> sparseVectorJPR = labelFeatures.toJavaRDD().mapToPair(row -> {
        String label = row.getString(0);
        SparseVector features = (SparseVector) row.get(1);
        return new Tuple2<String, SparseVector>(label, features);
    });

    // gather all feature vectors of each label into a single list
    JavaPairRDD<String, List<SparseVector>> combineByKeyPairRDD = sparseVectorJPR.combineByKey(sparseVector -> {
        List<SparseVector> list = new ArrayList<>();
        list.add(sparseVector);
        return list;
    }, (list, sparseVector) -> { list.add(sparseVector); return list; },
       (list_A, list_B) -> { list_A.addAll(list_B); return list_A; });

    JavaSparkContext jsc = JavaSparkContext.fromSparkContext(session.sparkContext());
    final Broadcast<ConcurrentMap<String, Long>> keyCountBroadcast = jsc.broadcast(keyCountConMap);
    final Broadcast<Long> maxSizeBroadcast = jsc.broadcast(maxSize);
    final Broadcast<Integer> knnBroadcast = jsc.broadcast(knn);
    final Broadcast<Double> samplingRateBroadcast = jsc.broadcast(samplingRate);
    final Broadcast<Double> rationToMaxBroadcast = jsc.broadcast(rationToMax);

    /*
     * transformation chain:
     *   JavaPairRDD<String, List<SparseVector>>  (grouped, then over-sampled where needed)
     *   -> JavaPairRDD<String, SparseVector>     (flattened)
     *   -> JavaRDD<Row>                          (label, features)
     */
    JavaPairRDD<String, List<SparseVector>> pairRDD = combineByKeyPairRDD
            .filter(slt -> slt._2().size() > 1) // k-NN needs at least two samples per label
            .mapToPair(slt -> {
                String label = slt._1();
                ConcurrentMap<String, Long> keySizeConMap = keyCountBroadcast.getValue();
                long oldSampleSize = keySizeConMap.get(label);
                long max = maxSizeBroadcast.getValue();
                double ration = rationToMaxBroadcast.getValue();
                int Knn = knnBroadcast.getValue();
                double rate = samplingRateBroadcast.getValue();
                // use the broadcast values inside the closure, not the driver-side variables
                if (oldSampleSize < max * ration) {
                    int needSampleSize = (int) (max * ration - oldSampleSize);
                    List<SparseVector> list = generateSample(slt._2(), needSampleSize, Knn, rate);
                    return new Tuple2<String, List<SparseVector>>(label, list);
                } else {
                    return slt;
                }
            });

    // flatten back into (label, features) rows
    JavaRDD<Row> javaRowRDD = pairRDD.flatMapToPair(slt -> {
        List<Tuple2<String, SparseVector>> floatPairList = new ArrayList<>();
        String label = slt._1();
        for (SparseVector sv : slt._2())
            floatPairList.add(new Tuple2<String, SparseVector>(label, sv));
        return floatPairList.iterator();
    }).map(svt -> RowFactory.create(svt._1(), svt._2()));

    // EncoderInit is a helper from the complete code in the repository linked above
    Dataset<Row> resultDataset = session.createDataset(javaRowRDD.rdd(), EncoderInit.getlabelFeaturesRowEncoder());
    return resultDataset;
}
```
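The helper generateSample is defined in the complete code at the repository linked earlier. As a rough sketch of what such a helper can look like (an assumption-laden reconstruction, not the repository's implementation: it assumes brute-force Euclidean k-NN within the minority class and returns the original samples plus the synthetic ones):

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;

import org.apache.spark.ml.linalg.SparseVector;
import org.apache.spark.ml.linalg.Vectors;

// Sketch only: returns the original samples plus needSampleSize synthetic ones,
// using brute-force Euclidean k-NN within the minority class.
private static List<SparseVector> generateSample(List<SparseVector> samples,
                                                 int needSampleSize, int knn, double samplingRate) {
    Random rng = new Random();
    List<SparseVector> result = new ArrayList<>(samples);
    int dim = samples.get(0).size();
    int generated = 0;
    while (generated < needSampleSize) {
        // step (1): pick a minority sample and rank the others by Euclidean distance
        SparseVector x = samples.get(rng.nextInt(samples.size()));
        List<SparseVector> neighbours = samples.stream()
                .filter(v -> v != x)
                .sorted(Comparator.comparingDouble((SparseVector v) -> Vectors.sqdist(x, v)))
                .limit(knn)
                .collect(Collectors.toList());
        // step (2): draw (knn * samplingRate) neighbours at random
        int draws = Math.max(1, (int) (knn * samplingRate));
        for (int d = 0; d < draws && generated < needSampleSize; d++) {
            SparseVector nb = neighbours.get(rng.nextInt(neighbours.size()));
            // step (3): xnew = x + rand(0,1) * (x(i) - x), applied per dimension
            double lambda = rng.nextDouble();
            double[] values = new double[dim];
            for (int i = 0; i < dim; i++) {
                values[i] = x.apply(i) + lambda * (nb.apply(i) - x.apply(i));
            }
            result.add(Vectors.dense(values).toSparse());
            generated++;
        }
    }
    return result;
}
```

A hypothetical invocation, assuming labelFeatures has a String "label" column and a SparseVector "features" column: smote(session, labelFeatures, 5, 0.4, 1.0) tops every under-represented class up to the size of the largest class (ratio 1.0), drawing 5 × 0.4 = 2 of each sample's 5 nearest neighbours per synthesis round.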