1、概念
VectorSlicer是一種轉換器,它接受特征向量並輸出帶有原始特征子數組的新特征向量。這對於從向量列中提取特征很有用。
VectorSlicer接受具有指定索引的向量列,然后輸出一個新的向量列,其值通過這些索引選擇。索引有兩種類型,
整數索引,代表向量setIndices()的索引。
字符串索引,表示向量中的要素名稱setNames()。這需要向量列具有AttributeGroup,因為實現在Attribute的名稱字段上匹配。
以整數和字符串指定均可接受。此外,可以同時使用整數索引和字符串名稱。必須至少選擇一項功能。不允許同時選擇相同的特征項,因此所選索引和名稱之間不能有重疊。請注意,如果選擇了特征名稱,則在遇到空輸入屬性時將引發異常。
輸出矢量將首先對具有選定索引的特征(以給定的順序)進行排序,然后是選定名稱(以給定的順序)進行排序。
2、code
package com.home.spark.ml import java.util import org.apache.spark.SparkConf import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} import org.apache.spark.ml.feature.VectorSlicer import org.apache.spark.ml.linalg.Vectors import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{Row, SparkSession} /** * @Description: 向量切片 * 此類采用特征向量,並輸出帶有原始特征子數組的新特征向量。 * * 可以使用索引(setIndices())或名稱(setNames())來指定功能的子集。 必須至少選擇一項功能。 不允許使用重復功能,因此所選索引和名稱之間不能有重疊。 * * 輸出矢量將首先對具有選定索引的要素(以給定的順序)進行排序,然后是選定名稱(以給定的順序)進行排序。 **/ object Ex_VectorSlicer { def main(args: Array[String]): Unit = { val conf: SparkConf = new SparkConf(true).setMaster("local[2]").setAppName("spark ml") val spark = SparkSession.builder().config(conf).getOrCreate() val data = util.Arrays.asList( Row(Vectors.sparse(3, Seq((0, -2.0), (1, 2.3)))), Row(Vectors.dense(-2.0, 2.3, 0.0)) ) val defaultAttr = NumericAttribute.defaultAttr val attrs = Array("f1", "f2", "f3").map(defaultAttr.withName) val attrGroup = new AttributeGroup("userFeatures", attrs.asInstanceOf[Array[Attribute]]) val dataset = spark.createDataFrame(data, StructType(Array(attrGroup.toStructField()))) val slicer = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features") slicer.setIndices(Array(1)).setNames(Array("f3")) // or slicer.setIndices(Array(1, 2)), or slicer.setNames(Array("f2", "f3")) val output = slicer.transform(dataset) output.show(false) val dataset2 = spark.createDataFrame( Seq( (0, 18, 1.0, Vectors.dense(0.0, 10.0, 0.5,99.0,100), 1.0), (0, 18, 1.0, Vectors.dense(0.0, 11.0,15.0,22.0,101), 0.0)) ).toDF("id", "hour", "mobile", "userFeatures", "clicked") val slicer2 = new VectorSlicer().setInputCol("userFeatures").setOutputCol("features") slicer2.setIndices(Array(1,3,4)) val result = slicer2.transform(dataset2) result.show(false) spark.stop() } }
+--------------------+-------------+
|userFeatures |features |
+--------------------+-------------+
|(3,[0,1],[-2.0,2.3])|(2,[0],[2.3])|
|[-2.0,2.3,0.0] |[2.3,0.0] |
+--------------------+-------------+
+---+----+------+--------------------------+-------+-----------------+
|id |hour|mobile|userFeatures |clicked|features |
+---+----+------+--------------------------+-------+-----------------+
|0 |18 |1.0 |[0.0,10.0,0.5,99.0,100.0] |1.0 |[10.0,99.0,100.0]|
|0 |18 |1.0 |[0.0,11.0,15.0,22.0,101.0]|0.0 |[11.0,22.0,101.0]|
+---+----+------+--------------------------+-------+-----------------+
