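Program overview
This project takes the ml-100k movie rating dataset as input and implements a user-based collaborative filtering algorithm. The final prediction MAE is 0.84, and after optimization the run over the 100,000 ratings takes under 2 minutes.
Collaborative filtering (CF) mines users' historical behavior data to discover their preferences, then predicts the products they are likely to enjoy and recommends them. It is what sits behind the familiar "You may also like" and "People who bought this item also liked" features.
Program / dataset download
Code analysis
Import modules, paths, and preset parameters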
# -*- coding: utf-8 -*-
import numpy as np
from numba import jit
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import mean_absolute_error as MAE
import os
###################### Parameters ######################
topK = 20  # number of nearest-neighbor users
testRate = 0.2  # fraction of ratings held out for testing
seed = 2  # random seed
testCount = 1000  # number of ratings finally used for testing (otherwise the run is too slow)
########################################################
# Display Chinese labels correctly
plt.rcParams['font.sans-serif']=['SimHei']
# Display the minus sign correctly
plt.rcParams['axes.unicode_minus']=False
# Directories
baseDir = ''  # current directory
staticDir = os.path.join(baseDir,'Static')  # static files directory
resultDir = os.path.join(baseDir,'Result')  # results directory
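Load the data and set the indexes (this makes lookups much, much faster), then look at the first 5 rows of each table
# Read the data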
users = pd.read_table(staticDir+'/電影評分數據/u.user',sep="|",names=['user_id','age','sex','occupation','zip_code'],encoding='latin-1',engine='python')
ratings = pd.read_table(staticDir+'/電影評分數據/u.data', sep='\t', names=['user_id', 'movie_id', 'rating', 'unix_timestamp'],encoding='latin-1',engine='python')
movies = pd.read_table(staticDir+'/電影評分數據/u.item',engine='python', sep='|',header=None,encoding='latin-1',names=['movie_id','title','release_date','video_release_date','IMDb_URL','unknown','Action','Adventure','Animation','Children','Comedy','Crime','Documentary','Drama','Fantasy','Film-Noir','Horror','Musical','Mystery','Romance','Sci-Fi','Thriller','War','Western'])
movies = movies.iloc[:,:5]
# Set the indexes
users = users.set_index(['user_id'],drop=False)
movies = movies.set_index(['movie_id'],drop=False)
ratings = ratings.set_index(['user_id','movie_id'],drop=False)
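The later steps rely on looking rows up through these indexes. As a minimal sketch of what that lookup pattern looks like, here is a toy table (made-up values, not rows from ml-100k; the names `toy`, `rows_user_1`, `rows_movie_10` and `one_rating` are illustrative only) indexed the same way as `ratings`:

```python
import pandas as pd

# Toy ratings table (made-up values, not taken from ml-100k).
toy = pd.DataFrame({
    'user_id':  [1, 1, 2, 3],
    'movie_id': [10, 20, 10, 30],
    'rating':   [4, 5, 3, 2],
})
# Same pattern as above: index on (user_id, movie_id) and keep the columns.
# Sorting the index is an extra step added here; it is not in the original code.
toy = toy.set_index(['user_id', 'movie_id'], drop=False).sort_index()

# All rows for user 1, whatever the movie.
rows_user_1 = toy.loc[(1, slice(None)), :]
# All rows for movie 10, whatever the user.
rows_movie_10 = toy.loc[(slice(None), 10), :]
# A single rating, addressed by the full (user_id, movie_id) key.
one_rating = toy.loc[(2, 10), 'rating']
```

The `slice(None)` entry means "any value at this index level", which is how the per-user and per-movie lookups further down are written.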
User data
users.head()
user_id (index) | user_id | age | sex | occupation | zip_code |
---|---|---|---|---|---|
1 | 1 | 24 | M | technician | 85711 |
2 | 2 | 53 | F | other | 94043 |
3 | 3 | 23 | M | writer | 32067 |
4 | 4 | 24 | M | technician | 43537 |
5 | 5 | 33 | F | other | 15213 |
Movie data
movies.head()
movie_id (index) | movie_id | title | release_date | video_release_date | IMDb_URL |
---|---|---|---|---|---|
1 | 1 | Toy Story (1995) | 01-Jan-1995 | NaN | http://us.imdb.com/M/title-exact?Toy%20Story%2... |
2 | 2 | GoldenEye (1995) | 01-Jan-1995 | NaN | http://us.imdb.com/M/title-exact?GoldenEye%20(... |
3 | 3 | Four Rooms (1995) | 01-Jan-1995 | NaN | http://us.imdb.com/M/title-exact?Four%20Rooms%... |
4 | 4 | Get Shorty (1995) | 01-Jan-1995 | NaN | http://us.imdb.com/M/title-exact?Get%20Shorty%... |
5 | 5 | Copycat (1995) | 01-Jan-1995 | NaN | http://us.imdb.com/M/title-exact?Copycat%20(1995) |
Rating data (behavior data)
ratings.head()
user_id (index) | movie_id (index) | user_id | movie_id | rating | unix_timestamp |
---|---|---|---|---|---|
196 | 242 | 196 | 242 | 3 | 881250949 |
186 | 302 | 186 | 302 | 3 | 891717742 |
22 | 377 | 22 | 377 | 1 | 878887116 |
244 | 51 | 244 | 51 | 2 | 880606923 |
166 | 346 | 166 | 346 | 1 | 886397596 |
Split the rating data into a test set and a training set
# Split the behavior data into a test set and a training set
np.random.seed(seed)
testIndex = np.random.choice(range(ratings.shape[0]),size=int(ratings.shape[0]*testRate),replace=False)
testRatings = ratings.iloc[testIndex,:]
trainIndex = list(set(range(ratings.shape[0]))-set(testIndex))
trainRatings = ratings.iloc[trainIndex,:]
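The split above draws the test positions without replacement and takes everything else as training data. A minimal sketch of the same idea on a toy size follows; it swaps in a single shuffled permutation from `np.random.default_rng` rather than `np.random.seed` plus `np.random.choice`, so it will not select the same rows as the original code, and `n_rows`, `test_rate`, `test_idx` and `train_idx` are illustrative names:

```python
import numpy as np

n_rows = 10        # stand-in for ratings.shape[0]
test_rate = 0.2    # plays the role of testRate
rng = np.random.default_rng(2)

# One shuffle gives both parts: the first chunk is test, the rest is train.
perm = rng.permutation(n_rows)
n_test = int(n_rows * test_rate)
test_idx, train_idx = perm[:n_test], perm[n_test:]
```

Because both parts come from one permutation, they are disjoint and together cover every row, with no set difference needed.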
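Add a users column to the movie data that lists which users watched each movie, then look at the first 5 rows
# Collect, for every movie, the ids of the users who watched it in the training set
def calUsers(movieId):
    # Viewing records for this movie
    views = trainRatings.loc[(slice(None),movieId),:]
    users = views['user_id'].values.tolist()
    return users
movies['users'] = movies['movie_id'].apply(calUsers)
movies.head()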
movie_id (index) | movie_id | title | release_date | video_release_date | IMDb_URL | users |
---|---|---|---|---|---|---|
1 | 1 | Toy Story (1995) | 01-Jan-1995 | NaN | http://us.imdb.com/M/title-exact?Toy%20Story%2... | [308, 287, 148, 280, 66, 109, 181, 95, 189, 14... |
2 | 2 | GoldenEye (1995) | 01-Jan-1995 | NaN | http://us.imdb.com/M/title-exact?GoldenEye%20(... | [5, 268, 276, 87, 250, 201, 64, 13, 213, 373, ... |
3 | 3 | Four Rooms (1995) | 01-Jan-1995 | NaN | http://us.imdb.com/M/title-exact?Four%20Rooms%... | [181, 81, 130, 49, 320, 145, 95, 99, 267, 417,... |
4 | 4 | Get Shorty (1995) | 01-Jan-1995 | NaN | http://us.imdb.com/M/title-exact?Get%20Shorty%... | [99, 19, 207, 295, 201, 10, 308, 328, 109, 334... |
5 | 5 | Copycat (1995) | 01-Jan-1995 | NaN | http://us.imdb.com/M/title-exact?Copycat%20(1995) | [293, 43, 311, 109, 344, 145, 314, 308, 280, 4... |
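Add a movies column to the user data that lists the movie ids each user watched in the training set, then look at the first 5 rows
# Collect, for every user, the ids of the movies watched in the training set
def calMovies(userId):
    # Viewing records for this user
    views = trainRatings.loc[(userId,slice(None)),:]
    movies = views['movie_id'].values.tolist()
    return movies
users['movies'] = users['user_id'].apply(calMovies)
users.head()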
user_id (index) | user_id | age | sex | occupation | zip_code | movies |
---|---|---|---|---|---|---|
1 | 1 | 24 | M | technician | 85711 | [61, 189, 33, 160, 20, 202, 171, 265, 117, 47,... |
2 | 2 | 53 | F | other | 94043 | [292, 251, 314, 297, 312, 281, 13, 303, 308, 2... |
3 | 3 | 23 | M | writer | 32067 | [335, 245, 337, 343, 323, 331, 294, 332, 334, ... |
4 | 4 | 24 | M | technician | 43537 | [264, 303, 361, 357, 260, 356, 294, 288, 50, 2... |
5 | 5 | 33 | F | other | 15213 | [2, 439, 225, 110, 454, 424, 363, 98, 102, 211... |
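The two lookups above (users per movie, movies per user) call `.loc` once per id. As an alternative sketch, not the approach used in the original code, the same two mappings can be built in one pass each with `groupby`. The toy frame below has a plain integer index; a frame indexed like `trainRatings`, where `user_id` and `movie_id` are both index levels and columns, would probably need `reset_index(drop=True)` first so the group key is not ambiguous. All names here are illustrative:

```python
import pandas as pd

# Toy training ratings with a plain integer index (made-up values).
toy_train = pd.DataFrame({
    'user_id':  [1, 1, 2, 3, 3],
    'movie_id': [10, 20, 10, 20, 30],
})

# movie_id -> list of users who rated that movie
users_by_movie = toy_train.groupby('movie_id')['user_id'].apply(list)
# user_id -> list of movies that user rated
movies_by_user = toy_train.groupby('user_id')['movie_id'].apply(list)
```

One difference to keep in mind: ids with no training rows simply do not appear in the grouped result, so they would need to be filled with empty lists before being attached as a column.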
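Compute the similarity matrix and look at its first 5 rows and 5 columns. Each element is the similarity between a pair of users: the number of movies both users rated in the training set, divided by the geometric mean of their two movie counts, that is, sim(u, v) = |N(u) ∩ N(v)| / sqrt(|N(u)| * |N(v)|), where N(u) is the set of movies user u rated.
# Compute the similarity matrix
sims = pd.DataFrame(0.0,columns=users.index,index=users.index)  # float fill, since the cells will hold similarities
def calSim(userId1,userId2):
    # Movies of user 1
    user1Items = users.loc[userId1,'movies']
    # Movies of user 2
    user2Items = users.loc[userId2,'movies']
    # Movies the two users have in common
    cross = list(set(user1Items) & set(user2Items))
    # Similarity (max(1e-1, ...) avoids dividing by zero for a user with no training movies)
    sim = len(cross)/((max(1e-1,len(user1Items))*max(1e-1,len(user2Items)))**0.5)
    return sim
def fillSims(row):
    userIds = pd.Series(row.index)
    row[:] = userIds.apply(calSim,args=(row.name,))
    return row
sims = sims.apply(fillSims,axis=1)
sims.iloc[:5,:5]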
user_id | 1 | 2 | 3 | 4 | 5 |
---|---|---|---|---|---|
1 | 1.000000 | 0.116008 | 0.069067 | 0.064449 | 0.275500 |
2 | 0.116008 | 1.000000 | 0.142887 | 0.133333 | 0.036380 |
3 | 0.069067 | 0.142887 | 1.000000 | 0.238145 | 0.012377 |
4 | 0.064449 | 0.133333 | 0.238145 | 1.000000 | 0.040423 |
5 | 0.275500 | 0.036380 | 0.012377 | 0.040423 | 1.000000 |
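To make the similarity measure concrete, here is a small sketch on made-up data. The first part computes one pair directly from two movie lists, as `calSim` does; the second part is an assumed vectorized variant (not taken from the original code) that gets the whole matrix from a 0/1 user-by-movie matrix with one matrix product. The names `overlap_similarity`, `watched` and `sim_matrix` are illustrative:

```python
from math import sqrt
import numpy as np

def overlap_similarity(items_u, items_v):
    """Shared items divided by the geometric mean of the two item counts."""
    set_u, set_v = set(items_u), set(items_v)
    if not set_u or not set_v:
        return 0.0  # no items means no overlap, so the similarity is 0
    return len(set_u & set_v) / sqrt(len(set_u) * len(set_v))

# Two users with 4 movies each, 2 of them shared: 2 / sqrt(4 * 4) = 0.5
pair_sim = overlap_similarity([10, 20, 30, 40], [20, 40, 50, 60])

# Vectorized form: rows are users, columns are movies, 1 means "rated".
watched = np.array([
    [1, 1, 1, 1, 0, 0],
    [0, 1, 0, 1, 1, 1],
    [0, 0, 0, 0, 1, 0],
], dtype=float)
common = watched @ watched.T           # shared-movie counts for every pair
counts = watched.sum(axis=1)           # movies rated per user
denom = np.sqrt(np.outer(counts, counts))
sim_matrix = common / np.maximum(denom, 1e-12)  # guard against empty users
```

The matrix is symmetric with ones on the diagonal, which matches the shape of the 5 by 5 excerpt shown above.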
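Add a near column to the user data that lists each user's nearest neighbors, then look at the first 5 rows
# Compute each user's nearest neighbors
def calNearUsers(userId):
    # This user's similarity vector, sorted in descending order; skip the first entry (the user itself) and keep the topK most similar users
    nearUserIds = sims.loc[:,userId].sort_values(ascending=False).iloc[1:topK+1]
    nearUserIds = nearUserIds.index.tolist()
    return nearUserIds
users['near'] = users['user_id'].apply(calNearUsers)
users.head()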
user_id (index) | user_id | age | sex | occupation | zip_code | movies | near |
---|---|---|---|---|---|---|---|
1 | 1 | 24 | M | technician | 85711 | [61, 189, 33, 160, 20, 202, 171, 265, 117, 47,... | [457, 435, 916, 648, 933, 276, 864, 297, 805, ... |
2 | 2 | 53 | F | other | 94043 | [292, 251, 314, 297, 312, 281, 13, 303, 308, 2... | [701, 673, 926, 131, 306, 569, 937, 520, 486, ... |
3 | 3 | 23 | M | writer | 32067 | [335, 245, 337, 343, 323, 331, 294, 332, 334, ... | [752, 489, 784, 587, 863, 529, 783, 428, 126, ... |
4 | 4 | 24 | M | technician | 43537 | [264, 303, 361, 357, 260, 356, 294, 288, 50, 2... | [33, 816, 750, 408, 443, 783, 725, 596, 355, 6... |
5 | 5 | 33 | F | other | 15213 | [2, 439, 225, 110, 454, 424, 363, 98, 102, 211... | [222, 648, 407, 56, 495, 254, 497, 457, 727, 1... |
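The neighbor selection can be sketched on a small made-up similarity matrix. This version masks the user's own entry explicitly instead of dropping the first element after sorting, which is an assumption-level variation: the two agree whenever the user's own similarity is the unique maximum, but masking also behaves predictably if another user has an identical movie set. `sim_matrix`, `user_ids` and `top_k_neighbors` are illustrative names:

```python
import numpy as np

# Made-up symmetric similarity matrix for four users.
sim_matrix = np.array([
    [1.0, 0.2, 0.7, 0.5],
    [0.2, 1.0, 0.1, 0.9],
    [0.7, 0.1, 1.0, 0.3],
    [0.5, 0.9, 0.3, 1.0],
])
user_ids = [101, 102, 103, 104]

def top_k_neighbors(row_pos, k):
    """Return the ids of the k users most similar to the user at row_pos."""
    sims_row = sim_matrix[row_pos].copy()
    sims_row[row_pos] = -np.inf                 # never pick the user itself
    order = np.argsort(-sims_row, kind='stable')[:k]
    return [user_ids[p] for p in order]

neighbors_of_101 = top_k_neighbors(0, 2)
neighbors_of_102 = top_k_neighbors(1, 2)
```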
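Predict the ratings in the test set and look at the first 5 rows. The final test MAE is 0.84. The prediction is the similarity-weighted average of the ratings that the target user's nearest neighbors gave the movie: pred(u, i) = sum over v of sim(u, v) * r(v, i), divided by sum over v of sim(u, v), where v ranges over the topK neighbors of user u who rated movie i in the training set.
def predict(row):
    '''Predict a rating'''
    userId = row['user_id']
    movieId = row['movie_id']
    # The topK nearest neighbors of this user
    nearUserIds = users.loc[userId,'near']
    # Users who rated this movie in the training set
    itemUserIds = movies.loc[movieId,'users']
    # Neighbors who also rated the movie
    cross = list(set(nearUserIds) & set(itemUserIds))
    # Predicted rating
    up = 0  # numerator
    down = 0  # denominator
    for nearUserId in cross:
        sim = sims.loc[nearUserId,userId]
        down += sim
        # The neighbor's rating of this movie
        score = trainRatings.loc[(nearUserId,movieId),'rating']
        up += score * sim
    if up == 0:
        # No neighbor rated the movie, so no prediction is possible
        return None
    else:
        return up/down
# Run the test
testRatings = testRatings.copy()  # work on a copy so adding a column does not trigger SettingWithCopyWarning
testRatings['predict'] = testRatings.apply(predict,axis=1)
testRatings = testRatings.dropna()
mae = MAE(testRatings['rating'],testRatings['predict'])
print('Test set MAE: %.2f'%mae)
testRatings.head()
Test set MAE: 0.84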
user_id (index) | movie_id (index) | user_id | movie_id | rating | unix_timestamp | predict |
---|---|---|---|---|---|---|
157 | 273 | 157 | 273 | 5 | 886889876 | 3.740361 |
405 | 1065 | 405 | 1065 | 1 | 885546069 | 3.790151 |
244 | 550 | 244 | 550 | 1 | 880602264 | 3.164083 |
378 | 768 | 378 | 768 | 4 | 880333598 | 2.786181 |
919 | 111 | 919 | 111 | 4 | 875288681 | 3.691650 |
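As a worked example of the weighted-average prediction and of the MAE, here is a dependency-free sketch on made-up numbers. It follows the same numerator and denominator as `predict`, but checks the denominator for zero rather than the numerator, and it takes plain dictionaries as input; `predict_rating`, `mean_abs_error` and the sample values are illustrative, not taken from the dataset:

```python
def predict_rating(neighbor_sims, neighbor_ratings):
    """Similarity-weighted average over the neighbors who rated the movie.

    neighbor_sims:    {neighbor_id: similarity to the target user}
    neighbor_ratings: {neighbor_id: that neighbor's rating of the movie}
    """
    numerator = 0.0
    denominator = 0.0
    for neighbor_id, sim in neighbor_sims.items():
        if neighbor_id in neighbor_ratings:   # skip neighbors who did not rate it
            numerator += sim * neighbor_ratings[neighbor_id]
            denominator += sim
    if denominator == 0:
        return None                           # nothing to base a prediction on
    return numerator / denominator

def mean_abs_error(actual, predicted):
    """Mean absolute error between two equal-length sequences."""
    return sum(abs(a - p) for a, p in zip(actual, predicted)) / len(actual)

# Neighbor 9 did not rate the movie, so only neighbors 7 and 8 count:
# (0.5 * 4 + 0.25 * 2) / (0.5 + 0.25) = 2.5 / 0.75
pred = predict_rating({7: 0.5, 8: 0.25, 9: 0.25}, {7: 4, 8: 2})
no_pred = predict_rating({7: 0.5}, {})
toy_mae = mean_abs_error([4, 3], [pred, 3.5])
```

Rows for which no neighbor rated the movie get no prediction, which is why the original code calls `dropna()` before computing the MAE.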