Experiment
What to Submit
Code + documentation + data results, packaged as a zip file named "StudentID_Name_Assignment3".
Code: preferably Python or R.
Documentation: one (or a combination) of the following forms:
- Jupyter Notebook (exported as HTML)
- Markdown
- Source code
Data results: a CSV file.
Please make sure the submitted results do not require us to rerun anything, and that the logic of the code and documentation is clear and easy to follow.
Dataset Description
The dataset records one day of transactions at a retailer, including item ID, item category, and sales amount. Each transaction carries a customer ID along with related customer attributes, e.g. married or unmarried.
Your Task
In the marital-status column ("Married"), 1 means married, 0 means unmarried, and blank means unknown. For customers whose marital status is unknown, use your ingenuity to fill in the blanks.
Submit a data result in CSV format containing only the rows that were originally unknown, with the unknowns replaced by your predictions.
Your documentation must specifically address the following questions:
1. How confident are you in your predictions? Quantify the evaluation and describe your evaluation process.
Fairly confident. This experiment evaluates the binary classification results using the AUC metric and accuracy; the evaluation process is shown in the modeling sections below.
2. How could your predictions be improved further?
This work already uses grid search to train the best model, k-fold cross-validation to select the best model, and a switch to a random forest classifier. Beyond that, I believe the following would help:
- Train and predict with more model families
- Model ensembling and fusion (see the sketch after this list)
- Handle the data more carefully and realistically, i.e. better feature engineering. For example, I initially treated the age bracket (age) and years in the city (YearsInCity) as categorical features and one-hot encoded them directly, which destroys their ordinal meaning. Encoding them as numbers instead, so they match reality, improved both AUC and accuracy.
- Data augmentation
- More hyperparameter tuning
- More careful preprocessing, e.g. standardization
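On ensembling, below is a minimal sketch of probability-level fusion. It assumes dt_model and rf_model are two classifiers fitted on the same features column, as the decision tree and random forest trained later in this notebook are; vector_to_array requires Spark 3.0+, and the names here are illustrative.
from pyspark.ml.functions import vector_to_array
from pyspark.sql import functions as F
# Average the two models' predicted probability of class 1, then threshold at 0.5.
with_id = test_df.withColumn("rid", F.monotonically_increasing_id())
p_dt = dt_model.transform(with_id).select("rid", vector_to_array("probability")[1].alias("p_dt"))
p_rf = rf_model.transform(with_id).select("rid", vector_to_array("probability")[1].alias("p_rf"))
fused = (p_dt.join(p_rf, "rid")
         .withColumn("p_avg", (F.col("p_dt") + F.col("p_rf")) / 2)
         .withColumn("prediction", (F.col("p_avg") >= 0.5).cast("double")))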
3. How else do you think the retailer could use this kind of dataset?
Mine what people frequently need together: within a short time window, which items does a customer tend to buy in succession? Such items could then be sold as bundles with discounts.
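As a concrete illustration of such basket mining, here is a hedged sketch using Spark's FPGrowth, assuming df is the transactions DataFrame loaded below and treating each customer's purchases for the day as one basket; the minSupport and minConfidence values are illustrative only.
from pyspark.ml.fpm import FPGrowth
from pyspark.sql import functions as F
# One basket per customer: the set of items they bought that day.
baskets = df.groupBy("CustomerID").agg(F.collect_set("ItemID").alias("items"))
fp = FPGrowth(itemsCol="items", minSupport=0.01, minConfidence=0.3)
fp_model = fp.fit(baskets)
fp_model.freqItemsets.show(5)      # frequently co-purchased item sets
fp_model.associationRules.show(5)  # candidate rules for bundles and discounts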
Code and Models
Environment Setup
The experiment was run on Google Colab. The first step installs pyspark and changes into the directory containing the CSV data files.
!pip install pyspark
Collecting pyspark
  Downloading pyspark-3.2.0.tar.gz (281.3 MB)
Collecting py4j==0.10.9.2
  Downloading py4j-0.10.9.2-py2.py3-none-any.whl (198 kB)
Building wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... done
  Created wheel for pyspark: filename=pyspark-3.2.0-py2.py3-none-any.whl size=281805912 sha256=f287bbfe5f7b98931b38aa4cd87d34eb0c285b2ab74ced7d49488722fd0ac9ab
  Stored in directory: /root/.cache/pip/wheels/0b/de/d2/9be5d59d7331c6c2a7c1b6d1a4f463ce107332b1ecd4e80718
Successfully built pyspark
Installing collected packages: py4j, pyspark
Successfully installed py4j-0.10.9.2 pyspark-3.2.0
import os
path = '/content/drive/MyDrive/作業/數據智能技術/預測結婚狀態'
os.chdir(path)
!ls
1.csv 4.csv pyspark二分類-是否結婚.ipynb result RetailCustomerSales2.csv
3.csv 6.csv res result1
Import the pyspark packages
from pyspark.context import SparkContext
from pyspark.sql import SparkSession
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier
Load the data file RetailCustomerSales2.csv with spark.read, create the DataFrame df, and take a quick look at the data.
spark = SparkSession\
.builder\
.appName("MaritalStatusClassification")\
.getOrCreate()
df = spark.read.format("csv").option("header", "true").option("delimiter", ",").load(r"RetailCustomerSales2.csv")
print(df.count())
df.printSchema()
517407
root
|-- CustomerID: string (nullable = true)
|-- ItemID: string (nullable = true)
|-- Sex: string (nullable = true)
|-- Age: string (nullable = true)
|-- Profession: string (nullable = true)
|-- CityType: string (nullable = true)
|-- YearsInCity: string (nullable = true)
|-- Married: string (nullable = true)
|-- ItemCategory1: string (nullable = true)
|-- ItemCategory2: string (nullable = true)
|-- ItemCategory3: string (nullable = true)
|-- Amount: string (nullable = true)
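Note that every column is read as a string because inferSchema is off. As an optional variant (a sketch, not what the rest of this notebook uses), Spark could infer numeric types at load time:
# Optional: let Spark infer numeric columns (e.g. Amount) at load time.
df_typed = (spark.read.option("header", "true")
            .option("inferSchema", "true")
            .csv("RetailCustomerSales2.csv"))
Here we keep the all-string schema and cast columns explicitly during cleaning.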
Take a quick look at the first five rows to get a feel for the data.
df.show(5)
+----------+---------+---+----+----------+--------+-----------+-------+-------------+-------------+-------------+------+
|CustomerID| ItemID|Sex| Age|Profession|CityType|YearsInCity|Married|ItemCategory1|ItemCategory2|ItemCategory3|Amount|
+----------+---------+---+----+----------+--------+-----------+-------+-------------+-------------+-------------+------+
| 1000001|P00069042| F|0-17| 10| A| 2| 0| 3| null| null| 8370|
| 1000001|P00248942| F|0-17| 10| A| 2| 0| 1| 6| 14| 15200|
| 1000001|P00087842| F|0-17| 10| A| 2| 0| 12| null| null| 1422|
| 1000001|P00085442| F|0-17| 10| A| 2| 0| 12| 14| null| 1057|
| 1000001|P00085942| F|0-17| 10| A| 2| 0| 2| 4| 8| 12842|
+----------+---------+---+----+----------+--------+-----------+-------+-------------+-------------+-------------+------+
only showing top 5 rows
Feature Engineering
Data Cleaning
First, nulls and special values need handling:
- The ItemCategory1, ItemCategory2, and ItemCategory3 columns contain null values, which we convert to 0. Before converting, check whether 0 already occurs in the data, using df.groupby('ItemCategory1').count().show().
- The YearsInCity column has one special value, 4+, which we convert to 4 to preserve its numeric meaning.
The conversion is implemented with a UDF, and the string columns are cast to their actual types.
First, inspect the values each column holds, to avoid conversion mistakes.
df.groupby('ItemCategory1').count().show()
df.groupby('ItemCategory2').count().show()
df.groupby('ItemCategory3').count().show()
df.groupby('YearsInCity').count().show()
+-------------+------+
|ItemCategory1| count|
+-------------+------+
| 7| 3493|
| 15| 5992|
| 11| 23121|
| 3| 19207|
| 8|107589|
| 16| 9323|
| 5|143167|
| 18| 2960|
| 17| 546|
| 6| 19306|
| 9| 391|
| 1|133215|
| 10| 4857|
| 4| 11123|
| 12| 3763|
| 13| 5210|
| 14| 1455|
| 2| 22689|
+-------------+------+
+-------------+------+
|ItemCategory2| count|
+-------------+------+
| 7| 587|
| 15| 36055|
| 11| 13451|
| 3| 2761|
| 8| 60502|
| 16| 40923|
| 5| 24928|
| 18| 2624|
| 17| 12639|
| 6| 15592|
| 9| 5413|
| 10| 2870|
| 4| 24384|
| 12| 5241|
| 13| 9942|
| 14| 52222|
| 2| 46770|
| null|160503|
+-------------+------+
+-------------+------+
|ItemCategory3| count|
+-------------+------+
| 15| 26694|
| 11| 1701|
| 3| 584|
| 8| 11923|
| 16| 30920|
| 5| 15809|
| 18| 4381|
| 17| 15848|
| 6| 4640|
| 9| 11021|
| 10| 1639|
| 4| 1792|
| 12| 8816|
| 13| 5143|
| 14| 17451|
| null|359045|
+-------------+------+
+-----------+------+
|YearsInCity| count|
+-----------+------+
| 3| 89565|
| 0| 68774|
| 4+| 79392|
| 1|183627|
| 2| 96049|
+-----------+------+
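For the null-to-0 step alone, a plain DataFrame.fillna would also do; a one-line sketch (the UDF defined next additionally handles "4+" and the age brackets):
# Equivalent null handling without a UDF; the values are still strings here.
df_filled = df.fillna("0", subset=["ItemCategory2", "ItemCategory3"])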
Then, define the UDF to handle the special values.
from pyspark.sql.functions import udf
# Map age brackets to ordinal numbers, nulls to "0", and "4+" to "4".
# The UDF returns mixed string/float values; the actual column types are
# fixed by the .cast(...) calls in the select below.
def replace_col(x):
    if x == "0-17":
        return 1.0
    elif x == "18-25":
        return 2.0
    elif x == "26-35":
        return 3.0
    elif x == "36-45":
        return 4.0
    elif x == "46-50":
        return 5.0
    elif x == "51-55":
        return 6.0
    elif x == "55+":
        return 7.0
    elif x is None:
        return "0"
    elif x == "4+":
        return "4"
    return x
replace_col = udf(replace_col)
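As a hedged alternative (a sketch, not the code used in the rest of this notebook), the UDF could declare a DoubleType return and map everything to numbers up front, making the later .cast("double") unnecessary; ORDINAL_MAP and replace_col_numeric are illustrative names.
from pyspark.sql.functions import udf
from pyspark.sql.types import DoubleType
# Age brackets and "4+" mapped to ordinal numbers.
ORDINAL_MAP = {"0-17": 1.0, "18-25": 2.0, "26-35": 3.0, "36-45": 4.0,
               "46-50": 5.0, "51-55": 6.0, "55+": 7.0, "4+": 4.0}
@udf(returnType=DoubleType())
def replace_col_numeric(x):
    if x is None:
        return 0.0                # nulls become 0
    if x in ORDINAL_MAP:
        return ORDINAL_MAP[x]     # coded brackets become ordinals
    try:
        return float(x)           # plain numeric strings pass through
    except ValueError:
        return None               # anything unexpected stays null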
Cast the columns to their actual types and apply the UDF.
from pyspark.sql.functions import col
import pyspark.sql.types
# ['ItemID', 'Sex', 'Age', 'Profession', 'CityType','YearsInCity','Married'] + |ItemCategory1|ItemCategory2|ItemCategory3|Amount|
clean_df = df.select(['ItemID', 'Sex', 'CityType', 'Profession'] + [replace_col(col('age')).cast("double").alias('age')] +
[replace_col(col('YearsInCity')).cast("double").alias('YearsInCity')] +
[replace_col(col('Amount')).cast("double").alias('Amount')] +
[replace_col(col(column)).cast("string").alias(column) for column in df.columns[8:11]] +
[col('Married').cast("double").alias('Married')])
clean_df.printSchema()
clean_df.show()
root
|-- ItemID: string (nullable = true)
|-- Sex: string (nullable = true)
|-- CityType: string (nullable = true)
|-- Profession: string (nullable = true)
|-- age: double (nullable = true)
|-- YearsInCity: double (nullable = true)
|-- Amount: double (nullable = true)
|-- ItemCategory1: string (nullable = true)
|-- ItemCategory2: string (nullable = true)
|-- ItemCategory3: string (nullable = true)
|-- Married: double (nullable = true)
+---------+---+--------+----------+---+-----------+-------+-------------+-------------+-------------+-------+
| ItemID|Sex|CityType|Profession|age|YearsInCity| Amount|ItemCategory1|ItemCategory2|ItemCategory3|Married|
+---------+---+--------+----------+---+-----------+-------+-------------+-------------+-------------+-------+
|P00069042| F| A| 10|1.0| 2.0| 8370.0| 3| 0| 0| 0.0|
|P00248942| F| A| 10|1.0| 2.0|15200.0| 1| 6| 14| 0.0|
|P00087842| F| A| 10|1.0| 2.0| 1422.0| 12| 0| 0| 0.0|
|P00085442| F| A| 10|1.0| 2.0| 1057.0| 12| 14| 0| 0.0|
|P00085942| F| A| 10|1.0| 2.0|12842.0| 2| 4| 8| 0.0|
|P00102642| F| A| 10|1.0| 2.0| 2763.0| 4| 8| 9| 0.0|
|P00110842| F| A| 10|1.0| 2.0|11769.0| 1| 2| 5| 0.0|
|P00004842| F| A| 10|1.0| 2.0|13645.0| 3| 4| 12| 0.0|
|P00117942| F| A| 10|1.0| 2.0| 8839.0| 5| 15| 0| 0.0|
|P00258742| F| A| 10|1.0| 2.0| 6910.0| 5| 0| 0| 0.0|
|P00142242| F| A| 10|1.0| 2.0| 7882.0| 8| 0| 0| 0.0|
|P00000142| F| A| 10|1.0| 2.0|13650.0| 3| 4| 5| 0.0|
|P00297042| F| A| 10|1.0| 2.0| 7839.0| 8| 0| 0| 0.0|
|P00059442| F| A| 10|1.0| 2.0|16622.0| 6| 8| 16| 0.0|
| P0096542| F| A| 10|1.0| 2.0|13627.0| 3| 4| 12| 0.0|
|P00184942| F| A| 10|1.0| 2.0|19219.0| 1| 8| 17| 0.0|
|P00051842| F| A| 10|1.0| 2.0| 2849.0| 4| 8| 0| 0.0|
|P00214842| F| A| 10|1.0| 2.0|11011.0| 14| 0| 0| 0.0|
|P00165942| F| A| 10|1.0| 2.0|10003.0| 8| 0| 0| 0.0|
|P00111842| F| A| 10|1.0| 2.0| 8094.0| 8| 0| 0| 0.0|
+---------+---+--------+----------+---+-----------+-------+-------------+-------------+-------------+-------+
only showing top 20 rows
One-Hot Encoding
Because 'Amount', 'YearsInCity', and 'Age' carry real numeric meaning (for example, the older a customer, the more likely they are to be married), we keep them as numbers and do not one-hot encode them.
Discrete (categorical) features are encoded as one-hot vectors.
The flow is StringIndexer --> OneHotEncoder --> VectorAssembler
columns = ['Amount', 'YearsInCity', 'Age', 'ItemID', 'Sex', 'CityType', 'Profession', 'ItemCategory1', 'ItemCategory2', 'ItemCategory3']
def oneHotEncoder(col, df):
    # Index the string column, then expand the index into a one-hot vector.
    stringIndexer = StringIndexer(inputCol=col, outputCol=col+"Index")
    model = stringIndexer.fit(df)
    indexed = model.transform(df)
    oneHotEncoder = OneHotEncoder(dropLast=False, inputCol=col+"Index", outputCol=col+"Vec")
    encoder = oneHotEncoder.fit(indexed)
    return encoder.transform(indexed)
for i in range(3, len(columns)):
clean_df = oneHotEncoder(columns[i], clean_df)
clean_df.show()
+---------+---+--------+----------+---+-----------+-------+-------------+-------------+-------------+-------+-----------+-------------------+--------+-------------+-------------+-------------+---------------+---------------+------------------+----------------+------------------+----------------+------------------+----------------+
| ItemID|Sex|CityType|Profession|age|YearsInCity| Amount|ItemCategory1|ItemCategory2|ItemCategory3|Married|ItemIDIndex| ItemIDVec|SexIndex| SexVec|CityTypeIndex| CityTypeVec|ProfessionIndex| ProfessionVec|ItemCategory1Index|ItemCategory1Vec|ItemCategory2Index|ItemCategory2Vec|ItemCategory3Index|ItemCategory3Vec|
+---------+---+--------+----------+---+-----------+-------+-------------+-------------+-------------+-------+-----------+-------------------+--------+-------------+-------------+-------------+---------------+---------------+------------------+----------------+------------------+----------------+------------------+----------------+
|P00069042| F| A| 10|1.0| 2.0| 8370.0| 3| 0| 0| 0.0| 758.0| (3620,[758],[1.0])| 1.0|(2,[1],[1.0])| 2.0|(3,[2],[1.0])| 12.0|(21,[12],[1.0])| 6.0| (18,[6],[1.0])| 0.0| (18,[0],[1.0])| 0.0| (16,[0],[1.0])|
|P00248942| F| A| 10|1.0| 2.0|15200.0| 1| 6| 14| 0.0| 181.0| (3620,[181],[1.0])| 1.0|(2,[1],[1.0])| 2.0|(3,[2],[1.0])| 12.0|(21,[12],[1.0])| 1.0| (18,[1],[1.0])| 8.0| (18,[8],[1.0])| 3.0| (16,[3],[1.0])|
|P00087842| F| A| 10|1.0| 2.0| 1422.0| 12| 0| 0| 0.0| 1506.0|(3620,[1506],[1.0])| 1.0|(2,[1],[1.0])| 2.0|(3,[2],[1.0])| 12.0|(21,[12],[1.0])| 12.0| (18,[12],[1.0])| 0.0| (18,[0],[1.0])| 0.0| (16,[0],[1.0])|
|P00085442| F| A| 10|1.0| 2.0| 1057.0| 12| 14| 0| 0.0| 475.0| (3620,[475],[1.0])| 1.0|(2,[1],[1.0])| 2.0|(3,[2],[1.0])| 12.0|(21,[12],[1.0])| 12.0| (18,[12],[1.0])| 2.0| (18,[2],[1.0])| 0.0| (16,[0],[1.0])|
|P00085942| F| A| 10|1.0| 2.0|12842.0| 2| 4| 8| 0.0| 42.0| (3620,[42],[1.0])| 1.0|(2,[1],[1.0])| 2.0|(3,[2],[1.0])| 12.0|(21,[12],[1.0])| 4.0| (18,[4],[1.0])| 7.0| (18,[7],[1.0])| 6.0| (16,[6],[1.0])|
|P00102642| F| A| 10|1.0| 2.0| 2763.0| 4| 8| 9| 0.0| 17.0| (3620,[17],[1.0])| 1.0|(2,[1],[1.0])| 2.0|(3,[2],[1.0])| 12.0|(21,[12],[1.0])| 7.0| (18,[7],[1.0])| 1.0| (18,[1],[1.0])| 7.0| (16,[7],[1.0])|
|P00110842| F| A| 10|1.0| 2.0|11769.0| 1| 2| 5| 0.0| 15.0| (3620,[15],[1.0])| 1.0|(2,[1],[1.0])| 2.0|(3,[2],[1.0])| 12.0|(21,[12],[1.0])| 1.0| (18,[1],[1.0])| 3.0| (18,[3],[1.0])| 5.0| (16,[5],[1.0])|
|P00004842| F| A| 10|1.0| 2.0|13645.0| 3| 4| 12| 0.0| 809.0| (3620,[809],[1.0])| 1.0|(2,[1],[1.0])| 2.0|(3,[2],[1.0])| 12.0|(21,[12],[1.0])| 6.0| (18,[6],[1.0])| 7.0| (18,[7],[1.0])| 8.0| (16,[8],[1.0])|
|P00117942| F| A| 10|1.0| 2.0| 8839.0| 5| 15| 0| 0.0| 11.0| (3620,[11],[1.0])| 1.0|(2,[1],[1.0])| 2.0|(3,[2],[1.0])| 12.0|(21,[12],[1.0])| 0.0| (18,[0],[1.0])| 5.0| (18,[5],[1.0])| 0.0| (16,[0],[1.0])|
|P00258742| F| A| 10|1.0| 2.0| 6910.0| 5| 0| 0| 0.0| 40.0| (3620,[40],[1.0])| 1.0|(2,[1],[1.0])| 2.0|(3,[2],[1.0])| 12.0|(21,[12],[1.0])| 0.0| (18,[0],[1.0])| 0.0| (18,[0],[1.0])| 0.0| (16,[0],[1.0])|
|P00142242| F| A| 10|1.0| 2.0| 7882.0| 8| 0| 0| 0.0| 2284.0|(3620,[2284],[1.0])| 1.0|(2,[1],[1.0])| 2.0|(3,[2],[1.0])| 12.0|(21,[12],[1.0])| 2.0| (18,[2],[1.0])| 0.0| (18,[0],[1.0])| 0.0| (16,[0],[1.0])|
|P00000142| F| A| 10|1.0| 2.0|13650.0| 3| 4| 5| 0.0| 30.0| (3620,[30],[1.0])| 1.0|(2,[1],[1.0])| 2.0|(3,[2],[1.0])| 12.0|(21,[12],[1.0])| 6.0| (18,[6],[1.0])| 7.0| (18,[7],[1.0])| 5.0| (16,[5],[1.0])|
|P00297042| F| A| 10|1.0| 2.0| 7839.0| 8| 0| 0| 0.0| 757.0| (3620,[757],[1.0])| 1.0|(2,[1],[1.0])| 2.0|(3,[2],[1.0])| 12.0|(21,[12],[1.0])| 2.0| (18,[2],[1.0])| 0.0| (18,[0],[1.0])| 0.0| (16,[0],[1.0])|
|P00059442| F| A| 10|1.0| 2.0|16622.0| 6| 8| 16| 0.0| 9.0| (3620,[9],[1.0])| 1.0|(2,[1],[1.0])| 2.0|(3,[2],[1.0])| 12.0|(21,[12],[1.0])| 5.0| (18,[5],[1.0])| 1.0| (18,[1],[1.0])| 1.0| (16,[1],[1.0])|
| P0096542| F| A| 10|1.0| 2.0|13627.0| 3| 4| 12| 0.0| 504.0| (3620,[504],[1.0])| 1.0|(2,[1],[1.0])| 2.0|(3,[2],[1.0])| 12.0|(21,[12],[1.0])| 6.0| (18,[6],[1.0])| 7.0| (18,[7],[1.0])| 8.0| (16,[8],[1.0])|
|P00184942| F| A| 10|1.0| 2.0|19219.0| 1| 8| 17| 0.0| 5.0| (3620,[5],[1.0])| 1.0|(2,[1],[1.0])| 2.0|(3,[2],[1.0])| 12.0|(21,[12],[1.0])| 1.0| (18,[1],[1.0])| 1.0| (18,[1],[1.0])| 4.0| (16,[4],[1.0])|
|P00051842| F| A| 10|1.0| 2.0| 2849.0| 4| 8| 0| 0.0| 935.0| (3620,[935],[1.0])| 1.0|(2,[1],[1.0])| 2.0|(3,[2],[1.0])| 12.0|(21,[12],[1.0])| 7.0| (18,[7],[1.0])| 1.0| (18,[1],[1.0])| 0.0| (16,[0],[1.0])|
|P00214842| F| A| 10|1.0| 2.0|11011.0| 14| 0| 0| 0.0| 954.0| (3620,[954],[1.0])| 1.0|(2,[1],[1.0])| 2.0|(3,[2],[1.0])| 12.0|(21,[12],[1.0])| 15.0| (18,[15],[1.0])| 0.0| (18,[0],[1.0])| 0.0| (16,[0],[1.0])|
|P00165942| F| A| 10|1.0| 2.0|10003.0| 8| 0| 0| 0.0| 1835.0|(3620,[1835],[1.0])| 1.0|(2,[1],[1.0])| 2.0|(3,[2],[1.0])| 12.0|(21,[12],[1.0])| 2.0| (18,[2],[1.0])| 0.0| (18,[0],[1.0])| 0.0| (16,[0],[1.0])|
|P00111842| F| A| 10|1.0| 2.0| 8094.0| 8| 0| 0| 0.0| 232.0| (3620,[232],[1.0])| 1.0|(2,[1],[1.0])| 2.0|(3,[2],[1.0])| 12.0|(21,[12],[1.0])| 2.0| (18,[2],[1.0])| 0.0| (18,[0],[1.0])| 0.0| (16,[0],[1.0])|
+---------+---+--------+----------+---+-----------+-------+-------------+-------------+-------------+-------+-----------+-------------------+--------+-------------+-------------+-------------+---------------+---------------+------------------+----------------+------------------+----------------+------------------+----------------+
only showing top 20 rows
Split the data by whether Married is null into the training set train and the real test set real_test.
Then split train further into a training set and a test set, to make later tuning easier.
real_test = clean_df.filter("Married is null")
train = clean_df.filter("Married is not null")
train_df, test_df = train.randomSplit([0.7, 0.3])
train_df.cache()
test_df.cache()
DataFrame[ItemID: string, Sex: string, CityType: string, Profession: string, age: double, YearsInCity: double, Amount: double, ItemCategory1: string, ItemCategory2: string, ItemCategory3: string, Married: double, ItemIDIndex: double, ItemIDVec: vector, SexIndex: double, SexVec: vector, CityTypeIndex: double, CityTypeVec: vector, ProfessionIndex: double, ProfessionVec: vector, ItemCategory1Index: double, ItemCategory1Vec: vector, ItemCategory2Index: double, ItemCategory2Vec: vector, ItemCategory3Index: double, ItemCategory3Vec: vector]
Assemble the needed feature columns into a single vector column, uniformly named features. Modeling then only needs this combined feature column.
assemblerInputs = []
columns = ['Amount', 'YearsInCity', 'Age', 'ItemID', 'Sex', 'CityType', 'Profession', 'ItemCategory1', 'ItemCategory2', 'ItemCategory3']
for i in range(3, len(columns)):
assemblerInputs.append(columns[i] + "Vec")
assembler = VectorAssembler(inputCols=assemblerInputs, outputCol="features")
train_df=assembler.transform(train_df)
test_df=assembler.transform(test_df)
test_df.columns
['ItemID',
'Sex',
'CityType',
'Profession',
'age',
'YearsInCity',
'Amount',
'ItemCategory1',
'ItemCategory2',
'ItemCategory3',
'Married',
'ItemIDIndex',
'ItemIDVec',
'SexIndex',
'SexVec',
'CityTypeIndex',
'CityTypeVec',
'ProfessionIndex',
'ProfessionVec',
'ItemCategory1Index',
'ItemCategory1Vec',
'ItemCategory2Index',
'ItemCategory2Vec',
'ItemCategory3Index',
'ItemCategory3Vec',
'features']
Training and Predicting with a Decision Tree
from pyspark.ml.classification import DecisionTreeClassifier
dt = DecisionTreeClassifier(labelCol="Married", featuresCol="features",impurity="gini",maxDepth=25, maxBins=14)
dt_model=dt.fit(train_df)
dt_model
DecisionTreeClassificationModel: uid=DecisionTreeClassifier_80db86b6e90b, depth=25, numNodes=7791, numClasses=2, numFeatures=3698
Apply the trained model to the datasets.
predictions_train_df = dt_model.transform(train_df)
predictions_test_df = dt_model.transform(test_df)
Show a small sample of the predicted classes and probabilities.
predictions_test_df.select('rawPrediction','probability', 'prediction','Married').take(10)
[Row(rawPrediction=DenseVector([3605.0, 2429.0]), probability=DenseVector([0.5974, 0.4026]), prediction=0.0, Married=0.0),
Row(rawPrediction=DenseVector([3605.0, 2429.0]), probability=DenseVector([0.5974, 0.4026]), prediction=0.0, Married=0.0),
Row(rawPrediction=DenseVector([3605.0, 2429.0]), probability=DenseVector([0.5974, 0.4026]), prediction=0.0, Married=0.0),
Row(rawPrediction=DenseVector([3605.0, 2429.0]), probability=DenseVector([0.5974, 0.4026]), prediction=0.0, Married=1.0),
Row(rawPrediction=DenseVector([3605.0, 2429.0]), probability=DenseVector([0.5974, 0.4026]), prediction=0.0, Married=1.0),
Row(rawPrediction=DenseVector([1070.0, 882.0]), probability=DenseVector([0.5482, 0.4518]), prediction=0.0, Married=0.0),
Row(rawPrediction=DenseVector([59.0, 105.0]), probability=DenseVector([0.3598, 0.6402]), prediction=1.0, Married=0.0),
Row(rawPrediction=DenseVector([322.0, 250.0]), probability=DenseVector([0.5629, 0.4371]), prediction=0.0, Married=1.0),
Row(rawPrediction=DenseVector([73.0, 707.0]), probability=DenseVector([0.0936, 0.9064]), prediction=1.0, Married=1.0),
Row(rawPrediction=DenseVector([73.0, 707.0]), probability=DenseVector([0.0936, 0.9064]), prediction=1.0, Married=1.0)]
Evaluate the model with AUC
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
auc_evaluator = BinaryClassificationEvaluator(rawPredictionCol="rawPrediction", labelCol="Married", metricName="areaUnderROC")  # evaluate with AUC
acc_evaluator = MulticlassClassificationEvaluator(labelCol="Married", predictionCol="prediction", metricName="accuracy")  # evaluate with accuracy
auc = auc_evaluator.evaluate(predictions_train_df)
print('Training-set AUC:', auc)
acc = acc_evaluator.evaluate(predictions_train_df)
print('Training-set accuracy:', acc)
auc = auc_evaluator.evaluate(predictions_test_df)
print('Test-set AUC:', auc)
acc = acc_evaluator.evaluate(predictions_test_df)
print('Test-set accuracy:', acc)
Training-set AUC: 0.5171257423962644
Training-set accuracy: 0.6445869235551476
Test-set AUC: 0.5165740574304908
Test-set accuracy: 0.6181565099304023
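To quantify confidence beyond these single numbers, a confusion matrix is a cheap addition; a minimal sketch on the test predictions:
# Rows: true Married label; columns: predicted label (0.0 / 1.0).
predictions_test_df.groupBy("Married").pivot("prediction").count().show()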
Pipeline Modeling
The data-processing steps designed above can be reused as a standard workflow; only the dataset needs to be swapped in.
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, OneHotEncoder,VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier
Reload the cleaned data.
clean_df = df.select(['ItemID', 'Sex', 'CityType', 'Profession'] + [replace_col(col('age')).cast("double").alias('age')] +
[replace_col(col('YearsInCity')).cast("double").alias('YearsInCity')] +
[replace_col(col('Amount')).cast("double").alias('Amount')] +
[replace_col(col(column)).cast("string").alias(column) for column in df.columns[8:11]] +
[col('Married').cast("double").alias('Married')])
clean_df.printSchema()
clean_df.show()
root
|-- ItemID: string (nullable = true)
|-- Sex: string (nullable = true)
|-- CityType: string (nullable = true)
|-- Profession: string (nullable = true)
|-- age: double (nullable = true)
|-- YearsInCity: double (nullable = true)
|-- Amount: double (nullable = true)
|-- ItemCategory1: string (nullable = true)
|-- ItemCategory2: string (nullable = true)
|-- ItemCategory3: string (nullable = true)
|-- Married: double (nullable = true)
+---------+---+--------+----------+---+-----------+-------+-------------+-------------+-------------+-------+
| ItemID|Sex|CityType|Profession|age|YearsInCity| Amount|ItemCategory1|ItemCategory2|ItemCategory3|Married|
+---------+---+--------+----------+---+-----------+-------+-------------+-------------+-------------+-------+
|P00069042| F| A| 10|1.0| 2.0| 8370.0| 3| 0| 0| 0.0|
|P00248942| F| A| 10|1.0| 2.0|15200.0| 1| 6| 14| 0.0|
|P00087842| F| A| 10|1.0| 2.0| 1422.0| 12| 0| 0| 0.0|
|P00085442| F| A| 10|1.0| 2.0| 1057.0| 12| 14| 0| 0.0|
|P00085942| F| A| 10|1.0| 2.0|12842.0| 2| 4| 8| 0.0|
|P00102642| F| A| 10|1.0| 2.0| 2763.0| 4| 8| 9| 0.0|
|P00110842| F| A| 10|1.0| 2.0|11769.0| 1| 2| 5| 0.0|
|P00004842| F| A| 10|1.0| 2.0|13645.0| 3| 4| 12| 0.0|
|P00117942| F| A| 10|1.0| 2.0| 8839.0| 5| 15| 0| 0.0|
|P00258742| F| A| 10|1.0| 2.0| 6910.0| 5| 0| 0| 0.0|
|P00142242| F| A| 10|1.0| 2.0| 7882.0| 8| 0| 0| 0.0|
|P00000142| F| A| 10|1.0| 2.0|13650.0| 3| 4| 5| 0.0|
|P00297042| F| A| 10|1.0| 2.0| 7839.0| 8| 0| 0| 0.0|
|P00059442| F| A| 10|1.0| 2.0|16622.0| 6| 8| 16| 0.0|
| P0096542| F| A| 10|1.0| 2.0|13627.0| 3| 4| 12| 0.0|
|P00184942| F| A| 10|1.0| 2.0|19219.0| 1| 8| 17| 0.0|
|P00051842| F| A| 10|1.0| 2.0| 2849.0| 4| 8| 0| 0.0|
|P00214842| F| A| 10|1.0| 2.0|11011.0| 14| 0| 0| 0.0|
|P00165942| F| A| 10|1.0| 2.0|10003.0| 8| 0| 0| 0.0|
|P00111842| F| A| 10|1.0| 2.0| 8094.0| 8| 0| 0| 0.0|
+---------+---+--------+----------+---+-----------+-------+-------------+-------------+-------------+-------+
only showing top 20 rows
Use a Pipeline to make the modeling workflow standardized and repeatable.
columns = ['Amount', 'YearsInCity', 'Age', 'ItemID', 'Sex', 'CityType', 'Profession', 'ItemCategory1', 'ItemCategory2', 'ItemCategory3']
# indexers = [StringIndexer(inputCol=column, outputCol=column+"Index") for column in columns[3:]]
# encoders = [OneHotEncoder(dropLast=False, inputCol=column + "Index", outputCol=column+"Vec") for column in columns[3:]]
for i in range(3, len(columns)):
clean_df = oneHotEncoder(columns[i], clean_df)
assembler = VectorAssembler(inputCols=assemblerInputs, outputCol="features")
dt = DecisionTreeClassifier(labelCol="Married", featuresCol="features",impurity="gini",maxDepth=20, maxBins=15)
stages = []
# stages.extend(indexers)
# stages.extend(encoders)
stages.append(assembler)
stages.append(dt)
pipeline = Pipeline(stages=stages)
real_test = clean_df.filter("Married is null")
train = clean_df.filter("Married is not null")
test_df, train_df = train.randomSplit([0.3, 0.7])
train_df.cache()
test_df.cache()
DataFrame[ItemID: string, Sex: string, CityType: string, Profession: string, age: double, YearsInCity: double, Amount: double, ItemCategory1: string, ItemCategory2: string, ItemCategory3: string, Married: double, ItemIDIndex: double, ItemIDVec: vector, SexIndex: double, SexVec: vector, CityTypeIndex: double, CityTypeVec: vector, ProfessionIndex: double, ProfessionVec: vector, ItemCategory1Index: double, ItemCategory1Vec: vector, ItemCategory2Index: double, ItemCategory2Vec: vector, ItemCategory3Index: double, ItemCategory3Vec: vector]
Train with the Pipeline
pipelineModel = pipeline.fit(train_df)
pipelineModel.stages[-1]
DecisionTreeClassificationModel: uid=DecisionTreeClassifier_a0b9cd35629d, depth=20, numNodes=4343, numClasses=2, numFeatures=3698
# Use toDebugString[:1000] to view the first 1000 characters of the trained model's decision rules
print(pipelineModel.stages[-1].toDebugString[:1000])
DecisionTreeClassificationModel: uid=DecisionTreeClassifier_a0b9cd35629d, depth=20, numNodes=4343, numClasses=2, numFeatures=3698
If (feature 3625 in {1.0})
If (feature 3620 in {0.0})
If (feature 3654 in {1.0})
If (feature 3622 in {0.0})
If (feature 1578 in {1.0})
Predict: 1.0
Else (feature 1578 not in {1.0})
If (feature 2155 in {1.0})
Predict: 1.0
Else (feature 2155 not in {1.0})
If (feature 2474 in {1.0})
Predict: 1.0
Else (feature 2474 not in {1.0})
If (feature 2654 in {1.0})
Predict: 1.0
Else (feature 2654 not in {1.0})
If (feature 2783 in {1.0})
Predict: 1.0
Else (feature 2783 not in {1.0})
If (feature 373 in {1.0})
If (feature 3623 in {0.0})
Predict: 0.0
Else (feature 3623 not in {0.0})
Predict: 1.0
Else (feature 373 not in {1.0})
If (feature 356 in {1.0})
Predict with the Pipeline
predictions_train_df = pipelineModel.transform(train_df)
predictions_test_df = pipelineModel.transform(test_df)
predictions_test_df.show(10)
+---------+---+--------+----------+---+-----------+-------+-------------+-------------+-------------+-------+-----------+-----------------+--------+-------------+-------------+-------------+---------------+---------------+------------------+----------------+------------------+----------------+------------------+----------------+--------------------+---------------+--------------------+----------+
| ItemID|Sex|CityType|Profession|age|YearsInCity| Amount|ItemCategory1|ItemCategory2|ItemCategory3|Married|ItemIDIndex| ItemIDVec|SexIndex| SexVec|CityTypeIndex| CityTypeVec|ProfessionIndex| ProfessionVec|ItemCategory1Index|ItemCategory1Vec|ItemCategory2Index|ItemCategory2Vec|ItemCategory3Index|ItemCategory3Vec| features| rawPrediction| probability|prediction|
+---------+---+--------+----------+---+-----------+-------+-------------+-------------+-------------+-------+-----------+-----------------+--------+-------------+-------------+-------------+---------------+---------------+------------------+----------------+------------------+----------------+------------------+----------------+--------------------+---------------+--------------------+----------+
|P00000142| F| A| 0|2.0| 1.0|13382.0| 3| 4| 5| 0.0| 30.0|(3620,[30],[1.0])| 1.0|(2,[1],[1.0])| 2.0|(3,[2],[1.0])| 1.0| (21,[1],[1.0])| 6.0| (18,[6],[1.0])| 7.0| (18,[7],[1.0])| 5.0| (16,[5],[1.0])|(3698,[30,3621,36...|[1746.0,1062.0]|[0.62179487179487...| 0.0|
|P00000142| F| A| 0|3.0| 0.0|13292.0| 3| 4| 5| 1.0| 30.0|(3620,[30],[1.0])| 1.0|(2,[1],[1.0])| 2.0|(3,[2],[1.0])| 1.0| (21,[1],[1.0])| 6.0| (18,[6],[1.0])| 7.0| (18,[7],[1.0])| 5.0| (16,[5],[1.0])|(3698,[30,3621,36...|[1746.0,1062.0]|[0.62179487179487...| 0.0|
|P00000142| F| A| 0|4.0| 0.0|10848.0| 3| 4| 5| 0.0| 30.0|(3620,[30],[1.0])| 1.0|(2,[1],[1.0])| 2.0|(3,[2],[1.0])| 1.0| (21,[1],[1.0])| 6.0| (18,[6],[1.0])| 7.0| (18,[7],[1.0])| 5.0| (16,[5],[1.0])|(3698,[30,3621,36...|[1746.0,1062.0]|[0.62179487179487...| 0.0|
|P00000142| F| A| 0|4.0| 1.0|13353.0| 3| 4| 5| 1.0| 30.0|(3620,[30],[1.0])| 1.0|(2,[1],[1.0])| 2.0|(3,[2],[1.0])| 1.0| (21,[1],[1.0])| 6.0| (18,[6],[1.0])| 7.0| (18,[7],[1.0])| 5.0| (16,[5],[1.0])|(3698,[30,3621,36...|[1746.0,1062.0]|[0.62179487179487...| 0.0|
|P00000142| F| A| 1|2.0| 2.0|13317.0| 3| 4| 5| 0.0| 30.0|(3620,[30],[1.0])| 1.0|(2,[1],[1.0])| 2.0|(3,[2],[1.0])| 3.0| (21,[3],[1.0])| 6.0| (18,[6],[1.0])| 7.0| (18,[7],[1.0])| 5.0| (16,[5],[1.0])|(3698,[30,3621,36...| [901.0,776.0]|[0.53726893261776...| 0.0|
|P00000142| F| A| 1|3.0| 1.0| 8347.0| 3| 4| 5| 0.0| 30.0|(3620,[30],[1.0])| 1.0|(2,[1],[1.0])| 2.0|(3,[2],[1.0])| 3.0| (21,[3],[1.0])| 6.0| (18,[6],[1.0])| 7.0| (18,[7],[1.0])| 5.0| (16,[5],[1.0])|(3698,[30,3621,36...| [901.0,776.0]|[0.53726893261776...| 0.0|
|P00000142| F| A| 14|3.0| 2.0|10704.0| 3| 4| 5| 0.0| 30.0|(3620,[30],[1.0])| 1.0|(2,[1],[1.0])| 2.0|(3,[2],[1.0])| 8.0| (21,[8],[1.0])| 6.0| (18,[6],[1.0])| 7.0| (18,[7],[1.0])| 5.0| (16,[5],[1.0])|(3698,[30,3621,36...| [6.0,26.0]| [0.1875,0.8125]| 1.0|
|P00000142| F| A| 2|2.0| 2.0|10783.0| 3| 4| 5| 1.0| 30.0|(3620,[30],[1.0])| 1.0|(2,[1],[1.0])| 2.0|(3,[2],[1.0])| 7.0| (21,[7],[1.0])| 6.0| (18,[6],[1.0])| 7.0| (18,[7],[1.0])| 5.0| (16,[5],[1.0])|(3698,[30,3621,36...|[1330.0,1274.0]|[0.51075268817204...| 0.0|
|P00000142| F| A| 20|3.0| 1.0| 5708.0| 3| 4| 5| 1.0| 30.0|(3620,[30],[1.0])| 1.0|(2,[1],[1.0])| 2.0|(3,[2],[1.0])| 5.0| (21,[5],[1.0])| 6.0| (18,[6],[1.0])| 7.0| (18,[7],[1.0])| 5.0| (16,[5],[1.0])|(3698,[30,3621,36...| [84.0,677.0]|[0.11038107752956...| 1.0|
|P00000142| F| A| 3|3.0| 1.0|13411.0| 3| 4| 5| 0.0| 30.0|(3620,[30],[1.0])| 1.0|(2,[1],[1.0])| 2.0|(3,[2],[1.0])| 11.0|(21,[11],[1.0])| 6.0| (18,[6],[1.0])| 7.0| (18,[7],[1.0])| 5.0| (16,[5],[1.0])|(3698,[30,3621,36...| [1516.0,144.0]|[0.91325301204819...| 0.0|
+---------+---+--------+----------+---+-----------+-------+-------------+-------------+-------------+-------+-----------+-----------------+--------+-------------+-------------+-------------+---------------+---------------+------------------+----------------+------------------+----------------+------------------+----------------+--------------------+---------------+--------------------+----------+
only showing top 10 rows
Evaluate the model's AUC and accuracy
auc = auc_evaluator.evaluate(predictions_train_df)
print('Training-set AUC:', auc)
acc = acc_evaluator.evaluate(predictions_train_df)
print('Training-set accuracy:', acc)
auc = auc_evaluator.evaluate(predictions_test_df)
print('Test-set AUC:', auc)
acc = acc_evaluator.evaluate(predictions_test_df)
print('Test-set accuracy:', acc)
Training-set AUC: 0.5178070790844446
Training-set accuracy: 0.6361669440160775
Test-set AUC: 0.5189265745824964
Test-set accuracy: 0.6186551241455227
Optimization
Grid Search
A machine-learning model should be tuned by testing different hyperparameter settings:
- Use ParamGridBuilder to build a grid over several model parameters: two values for impurity, three for maxDepth, and three for maxBins
- TrainValidationSplit ranks the AUC obtained for each parameter combination to find the best setting
from pyspark.ml.tuning import ParamGridBuilder,TrainValidationSplit
dt = DecisionTreeClassifier(labelCol="Married", featuresCol="features")
paramGrid = ParamGridBuilder()\
.addGrid(dt.impurity, ["gini","entropy"])\
.addGrid(dt.maxDepth, [15, 20, 25])\
.addGrid(dt.maxBins, [20, 25, 30])\
.build()
tvs = TrainValidationSplit(estimator=dt,evaluator=auc_evaluator,estimatorParamMaps=paramGrid,trainRatio=0.8)
stages = stages[:-1]
stages.append(tvs)
tvs_pipeline = Pipeline(stages = stages)
tvs_pipelineModel =tvs_pipeline.fit(train_df)
bestModel=tvs_pipelineModel.stages[-1].bestModel
bestModel
DecisionTreeClassificationModel: uid=DecisionTreeClassifier_0fb7d8a2ab72, depth=15, numNodes=1581, numClasses=2, numFeatures=3698
predictions_train_df = tvs_pipelineModel.transform(train_df)
predictions_test_df = tvs_pipelineModel.transform(test_df)
auc = auc_evaluator.evaluate(predictions_train_df)
print('Training-set AUC:', auc)
acc = acc_evaluator.evaluate(predictions_train_df)
print('Training-set accuracy:', acc)
auc = auc_evaluator.evaluate(predictions_test_df)
print('Test-set AUC:', auc)
acc = acc_evaluator.evaluate(predictions_test_df)
print('Test-set accuracy:', acc)
Training-set AUC: 0.5229827538027995
Training-set accuracy: 0.6236324931533902
Test-set AUC: 0.5238498834275354
Test-set accuracy: 0.6152576002609298
crossValidation Model Evaluation
Going further, crossValidation performs K-fold training and validation to obtain a more stable model. K-fold cross-validation gives reliable, stable estimates and reduces overfitting; 10-fold is common in practice. Larger k gives better estimates but takes proportionally more time.
from pyspark.ml.tuning import CrossValidator
cv = CrossValidator(estimator=dt, evaluator=auc_evaluator, estimatorParamMaps=paramGrid, numFolds=3)
stages = stages[:-1]
stages.append(cv)
cv_pipeline = Pipeline(stages = stages)
cv_pipelineModel = cv_pipeline.fit(train_df)
predictions_train_df = cv_pipelineModel.transform(train_df)
predictions_test_df = cv_pipelineModel.transform(test_df)
auc = auc_evaluator.evaluate(predictions_train_df)
print('Training-set AUC:', auc)
acc = acc_evaluator.evaluate(predictions_train_df)
print('Training-set accuracy:', acc)
auc = auc_evaluator.evaluate(predictions_test_df)
print('Test-set AUC:', auc)
acc = acc_evaluator.evaluate(predictions_test_df)
print('Test-set accuracy:', acc)
Training-set AUC: 0.53819033912944
Training-set accuracy: 0.617298752375283
Test-set AUC: 0.543748238239042
Test-set accuracy: 0.613242342342334
Changing the Model
For example, train on the data with a random forest, RandomForestClassifier.
from pyspark.ml.classification import RandomForestClassifier
rf = RandomForestClassifier(labelCol="Married", featuresCol="features", numTrees=40)
stages = stages[:-1]
stages.append(rf)
rf_pipeline = Pipeline(stages=stages)
rf_pipelineModel = rf_pipeline.fit(train_df)
predictions_train_df = rf_pipelineModel.transform(train_df)
predictions_test_df = rf_pipelineModel.transform(test_df)
auc = auc_evaluator.evaluate(predictions_train_df)
print('Training-set AUC:', auc)
acc = acc_evaluator.evaluate(predictions_train_df)
print('Training-set accuracy:', acc)
auc = auc_evaluator.evaluate(predictions_test_df)
print('Test-set AUC:', auc)
acc = acc_evaluator.evaluate(predictions_test_df)
print('Test-set accuracy:', acc)
Training-set AUC: 0.6516250949994256
Training-set accuracy: 0.7350481045323578
Test-set AUC: 0.6321354323217168
Test-set accuracy: 0.7080452396836528
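Random forests also expose per-feature importances, which help check, for instance, whether the 3,620 one-hot ItemID columns crowd out the demographic features; a minimal sketch, assuming rf_pipelineModel as fitted above:
rf_model = rf_pipelineModel.stages[-1]     # the fitted RandomForestClassificationModel
importances = rf_model.featureImportances  # SparseVector over all 3698 features
top10 = sorted(zip(importances.indices, importances.values),
               key=lambda kv: -kv[1])[:10]
print(top10)                               # the ten most influential feature indices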
With the random forest, the AUC improves markedly. Combining it with TrainValidationSplit to find the best model:
from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.classification import RandomForestClassifier
paramGrid = ParamGridBuilder()\
.addGrid(rf.impurity, [ "gini","entropy"])\
.addGrid(rf.maxDepth, [15,20,25])\
.addGrid(rf.maxBins, [10,15,20])\
.addGrid(rf.numTrees, [20,30,40])\
.build()
rftvs = TrainValidationSplit(estimator=rf, evaluator=auc_evaluator, estimatorParamMaps=paramGrid, trainRatio=0.8)
stages = stages[:-1]
stages.append(rftvs)
rftvs_pipeline = Pipeline(stages=stages)
rftvs_pipelineModel = rftvs_pipeline.fit(train_df)
predictions_train_df = rftvs_pipelineModel.transform(train_df)
predictions_test_df = rftvs_pipelineModel.transform(test_df)
auc = auc_evaluator.evaluate(predictions_train_df)
print('Training-set AUC:', auc)
acc = acc_evaluator.evaluate(predictions_train_df)
print('Training-set accuracy:', acc)
auc = auc_evaluator.evaluate(predictions_test_df)
print('Test-set AUC:', auc)
acc = acc_evaluator.evaluate(predictions_test_df)
print('Test-set accuracy:', acc)
Training-set AUC: 0.683578899079912
Training-set accuracy: 0.755178412678924
Test-set AUC: 0.689013140345656
Test-set accuracy: 0.742304412393941
And combining it with crossValidation to find the best model:
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
rfcv = CrossValidator(estimator=rf, evaluator=auc_evaluator,estimatorParamMaps=paramGrid, numFolds=3)
stages = stages[:-1]
stages.append(rfcv)
rfcv_pipeline = Pipeline(stages=stages)
rfcv_pipelineModel = rfcv_pipeline.fit(train_df)
predictions_train_df = rfcv_pipelineModel.transform(train_df)
predictions_test_df = rfcv_pipelineModel.transform(test_df)
auc = auc_evaluator.evaluate(predictions_train_df)
print('Training-set AUC:', auc)
acc = acc_evaluator.evaluate(predictions_train_df)
print('Training-set accuracy:', acc)
auc = auc_evaluator.evaluate(predictions_test_df)
print('Test-set AUC:', auc)
acc = acc_evaluator.evaluate(predictions_test_df)
print('Test-set accuracy:', acc)
Training-set AUC: 0.717670600078623
Training-set accuracy: 0.796372867834134
Test-set AUC: 0.705657664714809
Test-set accuracy: 0.783345671239597
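For reproducibility, it helps to log which hyperparameters the cross-validation actually selected; a minimal sketch:
# The last stage is a CrossValidatorModel; its bestModel is the winning forest.
best_rf = rfcv_pipelineModel.stages[-1].bestModel
print("numTrees:", best_rf.getNumTrees)
print("maxDepth:", best_rf.getMaxDepth())
print("maxBins:", best_rf.getMaxBins())
print("impurity:", best_rf.getImpurity())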
Results
Use the model with the best AUC above to predict the unknown rows, and save the result file.
predictions = rfcv_pipelineModel.transform(real_test)
columns = ['ItemID', 'Age', 'Sex', 'Profession', 'CityType', 'YearsInCity', 'ItemCategory1', 'ItemCategory2', 'ItemCategory3', 'Amount']
result = predictions.select(columns + ["prediction"])  # a flat list of column names, not a nested list
result.repartition(1).write.csv("./result", encoding="utf-8", header=True)
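Since the assignment asks for the originally unknown rows with Married filled in, a hedged final variant (hypothetical, not the notebook's saved output) could present the prediction under the original column name before writing:
# Hypothetical variant: rename the prediction column back to Married.
result_named = (predictions
                .select(columns + ["prediction"])
                .withColumnRenamed("prediction", "Married"))
result_named.repartition(1).write.csv("./result_named", encoding="utf-8", header=True)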