Spark MLlib 的官方例子裡面提供的數據大部分是 libsvm 格式的。這其實是一種非常蛋疼的文件格式,和常見的二維表格形式相去甚遠,下圖是裡面的一個例子:

完整代碼
libsvm 文件的基本格式如下:
<label> <index1>:<value1> <index2>:<value2>…
label 為類別標識,index 為特徵序號,value 為特徵取值。如上圖中第一行中 0 為標簽,128:51 表示第 128 個特徵取值為 51。
Spark 固然提供了讀取 libsvm 文件的 API,然而如果想把這些數據放到別的庫 (比如 scikit-learn) 中使用,就不得不面臨一個格式轉換的問題了。由於 CSV 文件是廣大人民群眾喜聞樂見的文件格式,因此分別用 Python 和 Java 寫一個程序來進行轉換。我在網上查閱了一下,基本上全是 csv 轉 libsvm,很少有 libsvm 轉 csv 的,唯一的一個是 phraug 庫中的 libsvm2csv.py。但這個實現有兩個缺點: 一個是需要事先指定維度; 另一個是像上圖中的特徵序號是 128 - 658,這樣轉換完之後 0 - 127 維的特徵全為 0,就顯得多餘了,而比較好的做法是將全為 0 的特徵列一併去除。下面是基於 Python 的實現:
import sys
import csv
import numpy as np
def empty_table(input_file):
    """Build an all-zero table sized to hold the data in a libsvm file.

    Each non-blank line of ``input_file`` has the form
    ``<label> <index1>:<value1> <index2>:<value2> ...``.

    Args:
        input_file: Path of the libsvm file to scan.

    Returns:
        A ``numpy`` array of zeros with one row per non-blank line and
        ``max_index + 1`` columns, where ``max_index`` is the largest feature
        index found (column 0 is reserved for the label).

    Raises:
        ValueError: If a feature index is not an integer.
    """
    max_feature = 0
    count = 0
    with open(input_file, 'r', newline='') as f:
        for raw in f:
            # str.split() with no argument discards leading/trailing blanks
            # and collapses repeated separators, so trailing spaces at the end
            # of a line never produce an empty token.
            tokens = raw.split()
            if not tokens:
                continue  # blank line: not a data row
            count += 1
            # tokens[0] is the label (possibly non-integer); it is not a
            # feature index and must not be parsed as one.
            for pair in tokens[1:]:
                index = int(pair.split(":", 1)[0])
                if index > max_feature:
                    max_feature = index
    return np.zeros((count, max_feature + 1))
def write(input_file, output_file, table):
    """Fill ``table`` from a libsvm file and write it out as CSV.

    Column 0 of every output row is the label. Feature columns that are zero
    in every row are dropped; the label column is always kept, even when all
    labels are 0.

    Args:
        input_file: Path of the libsvm file to read.
        output_file: Path of the CSV file to create (overwritten if present).
        table: Pre-allocated 2-D float array, filled in place; it needs at
            least one row per non-blank input line and ``max_index + 1``
            columns.

    Raises:
        ValueError: If a label/value is not numeric, a feature is not of the
            form ``index:value``, or a feature index is smaller than 1.
    """
    rows = 0
    with open(input_file, 'r', newline='') as f:
        for raw in f:
            tokens = raw.split()
            if not tokens:
                continue  # blank line: not a data row
            table[rows, 0] = float(tokens[0])
            for pair in tokens[1:]:
                index, value = pair.split(":", 1)
                index = int(index)
                if index < 1:
                    # Index 0 would overwrite the label and a negative index
                    # would silently wrap around to the last columns.
                    raise ValueError(
                        "feature index must be >= 1, got %r" % pair)
                table[rows, index] = float(value)
            rows += 1

    # Ignore any surplus pre-allocated rows that no input line filled.
    table = table[:rows]

    # Keep column 0 (the label) unconditionally; keep a feature column only if
    # at least one row has a non-zero value in it.
    keep = np.ones(table.shape[1], dtype=bool)
    keep[1:] = table[:, 1:].any(axis=0)
    table = table[:, keep]

    # newline='' is what the csv module expects; without it every row is
    # followed by an extra blank line on Windows.
    with open(output_file, 'w', newline='') as f:
        csv.writer(f).writerows(table)
if __name__ == "__main__":
    # Usage: python <script> <input.libsvm> <output.csv>
    # First pass sizes the table, second pass fills it and writes the CSV.
    src_path, dst_path = sys.argv[1], sys.argv[2]
    write(src_path, dst_path, empty_table(src_path))
以下基於 Java 來實現,不得不說 Java 由於沒有 Numpy 這類庫的存在,寫起來要繁瑣得多。
import java.io.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
 * Converts a libsvm-format file ({@code <label> <index1>:<value1> <index2>:<value2> ...})
 * into a dense CSV file.
 *
 * <p>Column 0 of the output holds the label. Feature columns that are zero in every row are
 * dropped; the label column is always kept.
 */
public class LibsvmToCsv {

    /**
     * Command-line entry point.
     *
     * @param args {@code args[0]} is the libsvm input path, {@code args[1]} the CSV output path
     * @throws IOException if the input cannot be read or the output cannot be created
     */
    public static void main(String[] args) throws IOException {
        if (args.length < 2) {
            System.err.println("Usage: java LibsvmToCsv <input.libsvm> <output.csv>");
            System.exit(1);
        }
        String src = args[0];
        String dest = args[1];
        double[][] table = EmptyTable(src);
        double[][] newcsv = NewCsv(table, src);
        write(newcsv, dest);
    }

    /**
     * Builds an all-zero table sized to hold the data of a libsvm file.
     *
     * @param src path of the libsvm file to scan
     * @return a zero-filled table with one row per non-blank line and {@code maxIndex + 1}
     *     columns, where {@code maxIndex} is the largest feature index in the file
     * @throws IOException if the file cannot be read
     * @throws IllegalArgumentException if a feature token is malformed or its index is below 1
     */
    public static double[][] EmptyTable(String src) throws IOException {
        int maxFeatures = 0;
        int count = 0;
        try (BufferedReader br = new BufferedReader(new FileReader(src))) {
            String line;
            while ((line = br.readLine()) != null) {
                String[] tokens = splitTokens(line);
                if (tokens.length == 0) {
                    continue; // blank line: not a data row
                }
                count++;
                // tokens[0] is the label, not a feature index, so start at 1.
                for (int i = 1; i < tokens.length; i++) {
                    maxFeatures = Math.max(maxFeatures, featureIndex(tokens[i]));
                }
            }
        }
        return new double[count][maxFeatures + 1];
    }

    /**
     * Fills {@code newTable} from the libsvm file and returns a copy without all-zero feature
     * columns.
     *
     * @param newTable pre-allocated table (normally from {@link #EmptyTable}); filled in place
     * @param src path of the libsvm file to read
     * @return a new table with the label in column 0 followed by every feature column that has
     *     at least one non-zero value; an empty table if {@code newTable} has no rows
     * @throws IOException if the file cannot be read
     * @throws IllegalArgumentException if a feature token is malformed or its index is below 1
     * @throws NumberFormatException if a label or value is not numeric
     */
    public static double[][] NewCsv(double[][] newTable, String src) throws IOException {
        try (BufferedReader br = new BufferedReader(new FileReader(src))) {
            int row = 0;
            String line;
            while ((line = br.readLine()) != null) {
                String[] tokens = splitTokens(line);
                if (tokens.length == 0) {
                    continue; // must skip the same lines as EmptyTable so row counts agree
                }
                double[] cells = newTable[row++];
                // Labels may be real-valued (e.g. regression targets), so parse as double.
                cells[0] = Double.parseDouble(tokens[0]);
                for (int i = 1; i < tokens.length; i++) {
                    String pair = tokens[i];
                    int index = featureIndex(pair);
                    cells[index] = Double.parseDouble(pair.substring(pair.indexOf(':') + 1));
                }
            }
        }

        if (newTable.length == 0 || newTable[0].length == 0) {
            return new double[newTable.length][0];
        }

        int totalCols = newTable[0].length;
        boolean[] keep = new boolean[totalCols];
        keep[0] = true; // the label column is kept even when every label is 0
        int keptCols = 1;
        for (int col = 1; col < totalCols; col++) {
            for (double[] cells : newTable) {
                if (cells[col] != 0.0) {
                    keep[col] = true;
                    keptCols++;
                    break; // one non-zero value is enough to keep the column
                }
            }
        }

        double[][] newCsv = new double[newTable.length][keptCols];
        for (int r = 0; r < newTable.length; r++) {
            int out = 0;
            for (int col = 0; col < totalCols; col++) {
                if (keep[col]) {
                    newCsv[r][out++] = newTable[r][col];
                }
            }
        }
        return newCsv;
    }

    /**
     * Writes the table as comma-separated values, one table row per line.
     *
     * @param table rows to write
     * @param path path of the CSV file to create (overwritten if present)
     * @throws FileNotFoundException if the file cannot be created or opened for writing
     * @throws UncheckedIOException if writing fails after the file was opened; failing loudly
     *     avoids leaving a silently truncated CSV behind
     */
    public static void write(double[][] table, String path) throws FileNotFoundException {
        try (BufferedWriter bw =
                new BufferedWriter(new OutputStreamWriter(new FileOutputStream(path)))) {
            for (double[] row : table) {
                StringBuilder sb = new StringBuilder();
                for (int i = 0; i < row.length; i++) {
                    if (i > 0) {
                        sb.append(',');
                    }
                    sb.append(row[i]);
                }
                bw.write(sb.toString());
                bw.newLine();
            }
        } catch (FileNotFoundException e) {
            throw e;
        } catch (IOException e) {
            throw new UncheckedIOException("Failed to write CSV to " + path, e);
        }
    }

    /** Splits a line on runs of whitespace; returns an empty array for a blank line. */
    private static String[] splitTokens(String line) {
        String trimmed = line.trim();
        return trimmed.isEmpty() ? new String[0] : trimmed.split("\\s+");
    }

    /** Extracts and validates the feature index of an {@code index:value} token. */
    private static int featureIndex(String pair) {
        int colon = pair.indexOf(':');
        if (colon < 0) {
            throw new IllegalArgumentException(
                    "Malformed libsvm feature, expected <index>:<value> but got: " + pair);
        }
        int index = Integer.parseInt(pair.substring(0, colon));
        if (index < 1) {
            // Column 0 is reserved for the label; index 0 would be overwritten by it.
            throw new IllegalArgumentException("Feature index must be >= 1 but got: " + pair);
        }
        return index;
    }
}
/