A recent project required statistical analysis over nearly 7 billion records. Storing that volume in MySQL makes reads impractically slow, and even a search engine struggles at this scale. After some research we decided to build the analysis on our company's big-data platform together with Spark.
To make later development easier, I wrote a few utility classes that hide the MySQL and Hive plumbing, so developers only need to focus on the business logic.
The utility classes are listed below.
I. Operating MySQL from Spark
1. Get a Spark DataFrame from a SQL statement:
/**
 * Get a DataFrame from the MySQL database
 *
 * @param spark SparkSession
 * @param sql   query SQL
 * @return DataFrame
 */
def getDFFromMysql(spark: SparkSession, sql: String): DataFrame = {
  println(s"url:${mySqlConfig.url} user:${mySqlConfig.user} sql: ${sql}")
  spark.read.format("jdbc")
    .option("url", mySqlConfig.url)
    .option("user", mySqlConfig.user)
    .option("password", mySqlConfig.password)
    .option("driver", "com.mysql.jdbc.Driver")
    .option("query", sql)
    .load()
}
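A quick usage sketch (assuming the method lives in the MySQLUtils object imported later in this post; the SQL and column names are illustrative only):

// Obtain a SparkSession, e.g. from the SparkUtils object defined in section II below.
val spark: SparkSession = SparkUtils.getProdSparkSession
// Load supplier data from MySQL into a DataFrame.
val supplierDF = MySQLUtils.getDFFromMysql(spark,
  "SELECT id, IFNULL(registered_money, 0) registered_money FROM scfc_supplier_info WHERE yn = 1")
supplierDF.show(10)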
2. Write a Spark DataFrame to a MySQL table
/**
 * Write the result into MySQL
 *
 * @param df        DataFrame
 * @param mode      SaveMode
 * @param tableName target table name
 */
def writeIntoMySql(df: DataFrame, mode: SaveMode, tableName: String): Unit = {
  mode match {
    case SaveMode.Append => appendDataIntoMysql(df, tableName)
    case SaveMode.Overwrite => overwriteMysqlData(df, tableName)
    case _ => throw new Exception("Only Append and Overwrite are currently supported!")
  }
}
/**
 * Append a dataset to a MySQL table
 *
 * @param df             DataFrame
 * @param mysqlTableName table name: database_name.table_name
 * @return
 */
def appendDataIntoMysql(df: DataFrame, mysqlTableName: String) = {
  df.write.mode(SaveMode.Append).jdbc(mySqlConfig.url, mysqlTableName, getMysqlProp)
}
/**
 * Overwrite a MySQL table with a dataset
 *
 * @param df             DataFrame
 * @param mysqlTableName table name: database_name.table_name
 * @return
 */
def overwriteMysqlData(df: DataFrame, mysqlTableName: String) = {
  // First clear the existing rows in the MySQL table
  truncateMysqlTable(mysqlTableName)
  // Then append the new data
  df.write.mode(SaveMode.Append).jdbc(mySqlConfig.url, mysqlTableName, getMysqlProp)
}
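Both methods above pass getMysqlProp to DataFrameWriter.jdbc; that helper is not listed in the post. A minimal sketch, assuming it simply wraps the credentials from mySqlConfig in a java.util.Properties:

import java.util.Properties

// Sketch only: build the JDBC properties expected by DataFrameWriter.jdbc.
def getMysqlProp: Properties = {
  val prop = new Properties()
  prop.setProperty("user", mySqlConfig.user)
  prop.setProperty("password", mySqlConfig.password)
  prop.setProperty("driver", "com.mysql.jdbc.Driver")
  prop
}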
/**
 * Truncate a MySQL table
 *
 * @param mysqlTableName table name
 * @return true if the statement executed successfully
 */
def truncateMysqlTable(mysqlTableName: String): Boolean = {
  val conn = MySQLPoolManager.getMysqlManager.getConnection // get a connection from the pool
  val preparedStatement = conn.createStatement()
  try {
    preparedStatement.execute(s"truncate table $mysqlTableName")
  } catch {
    case e: Exception =>
      println(s"mysql truncateMysqlTable error:${ExceptionUtil.getExceptionStack(e)}")
      false
  } finally {
    preparedStatement.close()
    conn.close()
  }
}
3. Delete rows from a MySQL table by condition
/**
 * Delete rows from a table
 *
 * @param mysqlTableName table name
 * @param condition      WHERE condition
 * @return true if the statement executed successfully
 */
def deleteMysqlTableData(mysqlTableName: String, condition: String): Boolean = {
  val conn = MySQLPoolManager.getMysqlManager.getConnection // get a connection from the pool
  val preparedStatement = conn.createStatement()
  try {
    preparedStatement.execute(s"delete from $mysqlTableName where $condition")
  } catch {
    case e: Exception =>
      println(s"mysql deleteMysqlTable error:${ExceptionUtil.getExceptionStack(e)}")
      false
  } finally {
    preparedStatement.close()
    conn.close()
  }
}
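The methods above depend on two helpers that are not listed in this post: MySQLPoolManager, which hands out pooled JDBC connections (the post later mentions c3p0), and ExceptionUtil.getExceptionStack, used for logging. Minimal sketches under those assumptions (the pool settings and placeholder URL/credentials are illustrative):

import java.io.{PrintWriter, StringWriter}
import com.mchange.v2.c3p0.ComboPooledDataSource

object MySQLPoolManager {
  // Lazily initialized c3p0 pool; replace the placeholders with the real MySQL settings.
  private lazy val pool: ComboPooledDataSource = {
    val ds = new ComboPooledDataSource()
    ds.setDriverClass("com.mysql.jdbc.Driver")
    ds.setJdbcUrl("jdbc:mysql://IP:PORT/db_name?useUnicode=true&characterEncoding=UTF-8&useSSL=false")
    ds.setUser("user")
    ds.setPassword("password")
    ds.setMaxPoolSize(20)
    ds
  }

  // Callers use MySQLPoolManager.getMysqlManager.getConnection
  def getMysqlManager: ComboPooledDataSource = pool
}

object ExceptionUtil {
  // Render an exception's stack trace as a String for logging.
  def getExceptionStack(e: Throwable): String = {
    val sw = new StringWriter()
    e.printStackTrace(new PrintWriter(sw))
    sw.toString
  }
}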
4. Save a DataFrame to MySQL, automatically creating the table if it does not exist
/**
 * Save a DataFrame to MySQL, creating the table automatically if it does not exist
 *
 * @param tableName       table name
 * @param resultDateFrame DataFrame to save
 */
def saveDFtoDBCreateTableIfNotExist(tableName: String, resultDateFrame: DataFrame) {
  // If the table does not exist, create it from the DataFrame schema
  createTableIfNotExist(tableName, resultDateFrame)
  // Verify that the table and the DataFrame have the same column count, names and order
  verifyFieldConsistency(tableName, resultDateFrame)
  // Save the DataFrame
  saveDFtoDBUsePool(tableName, resultDateFrame)
}
/**
 * If the table does not exist, create it from the DataFrame's fields, keeping the DataFrame's column order.
 * If the DataFrame contains a field named "id", it becomes the primary key (int, auto-increment);
 * the other fields are mapped to MySQL types based on their Spark DataType.
 *
 * @param tableName table name
 * @param df        DataFrame
 * @return
 */
def createTableIfNotExist(tableName: String, df: DataFrame): AnyVal = {
  val con = MySQLPoolManager.getMysqlManager.getConnection
  val metaData = con.getMetaData
  val colResultSet = metaData.getColumns(null, "%", tableName, "%")
  // If the table does not exist, create it
  if (!colResultSet.next()) {
    // Build the CREATE TABLE statement
    val sb = new StringBuilder(s"CREATE TABLE `$tableName` (")
    df.schema.fields.foreach(x =>
      if (x.name.equalsIgnoreCase("id")) {
        // A field named "id" becomes an auto-increment integer primary key
        sb.append(s"`${x.name}` int(255) NOT NULL AUTO_INCREMENT PRIMARY KEY,")
      } else {
        x.dataType match {
          case _: ByteType => sb.append(s"`${x.name}` int(100) DEFAULT NULL,")
          case _: ShortType => sb.append(s"`${x.name}` int(100) DEFAULT NULL,")
          case _: IntegerType => sb.append(s"`${x.name}` int(100) DEFAULT NULL,")
          case _: LongType => sb.append(s"`${x.name}` bigint(100) DEFAULT NULL,")
          case _: BooleanType => sb.append(s"`${x.name}` tinyint DEFAULT NULL,")
          case _: FloatType => sb.append(s"`${x.name}` float(50) DEFAULT NULL,")
          case _: DoubleType => sb.append(s"`${x.name}` double(50) DEFAULT NULL,")
          case _: StringType => sb.append(s"`${x.name}` varchar(50) DEFAULT NULL,")
          case _: TimestampType => sb.append(s"`${x.name}` timestamp DEFAULT current_timestamp,")
          case _: DateType => sb.append(s"`${x.name}` date DEFAULT NULL,")
          case _ => throw new RuntimeException(s"nonsupport ${x.dataType} !!!")
        }
      }
    )
    sb.append(") ENGINE=InnoDB DEFAULT CHARSET=utf8")
    val sql_createTable = sb.deleteCharAt(sb.lastIndexOf(',')).toString()
    println(sql_createTable)
    val statement = con.createStatement()
    statement.execute(sql_createTable)
  }
}
/**
 * Verify that the table and the DataFrame have the same column count, names and order
 *
 * @param tableName table name
 * @param df        DataFrame
 */
def verifyFieldConsistency(tableName: String, df: DataFrame): Unit = {
  val con = MySQLPoolManager.getMysqlManager.getConnection
  val metaData = con.getMetaData
  val colResultSet = metaData.getColumns(null, "%", tableName, "%")
  colResultSet.last()
  val tableFiledNum = colResultSet.getRow
  val dfFiledNum = df.columns.length
  if (tableFiledNum != dfFiledNum) {
    throw new Exception(s"Column counts differ!! table--$tableFiledNum but dataFrame--$dfFiledNum")
  }
  for (i <- 1 to tableFiledNum) {
    colResultSet.absolute(i)
    val tableFileName = colResultSet.getString("COLUMN_NAME")
    val dfFiledName = df.columns.apply(i - 1)
    if (!tableFileName.equals(dfFiledName)) {
      throw new Exception(s"Column names differ!! table--'$tableFileName' but dataFrame--'$dfFiledName'")
    }
  }
  colResultSet.beforeFirst()
}
/**
 * Write a DataFrame to MySQL through the c3p0 connection pool, binding each column
 * according to its Spark DataType.
 *
 * @param tableName       table name
 * @param resultDateFrame DataFrame to save
 */
def saveDFtoDBUsePool(tableName: String, resultDateFrame: DataFrame) {
  val colNumbers = resultDateFrame.columns.length
  val sql = getInsertSql(tableName, colNumbers)
  val columnDataTypes = resultDateFrame.schema.fields.map(_.dataType)
  resultDateFrame.foreachPartition(partitionRecords => {
    val conn = MySQLPoolManager.getMysqlManager.getConnection // get a connection from the pool
    val preparedStatement = conn.prepareStatement(sql)
    val metaData = conn.getMetaData.getColumns(null, "%", tableName, "%") // table metadata, used for NULL binding
    try {
      conn.setAutoCommit(false)
      partitionRecords.foreach(record => {
        // Note: setXxx methods are 1-based, record.get() is 0-based
        for (i <- 1 to colNumbers) {
          val value = record.get(i - 1)
          val dateType = columnDataTypes(i - 1)
          if (value != null) {
            // Non-null value: bind it as a String first, then overwrite with the type-specific setter
            preparedStatement.setString(i, value.toString)
            dateType match {
              case _: ByteType => preparedStatement.setInt(i, record.getAs[Int](i - 1))
              case _: ShortType => preparedStatement.setInt(i, record.getAs[Int](i - 1))
              case _: IntegerType => preparedStatement.setInt(i, record.getAs[Int](i - 1))
              case _: LongType => preparedStatement.setLong(i, record.getAs[Long](i - 1))
              case _: BooleanType => preparedStatement.setBoolean(i, record.getAs[Boolean](i - 1))
              case _: FloatType => preparedStatement.setFloat(i, record.getAs[Float](i - 1))
              case _: DoubleType => preparedStatement.setDouble(i, record.getAs[Double](i - 1))
              case _: StringType => preparedStatement.setString(i, record.getAs[String](i - 1))
              case _: TimestampType => preparedStatement.setTimestamp(i, record.getAs[Timestamp](i - 1))
              case _: DateType => preparedStatement.setDate(i, record.getAs[Date](i - 1))
              case _ => throw new RuntimeException(s"nonsupport ${dateType} !!!")
            }
          } else {
            // Null value: bind SQL NULL using the column's JDBC type
            metaData.absolute(i)
            preparedStatement.setNull(i, metaData.getInt("DATA_TYPE"))
          }
        }
        preparedStatement.addBatch()
      })
      preparedStatement.executeBatch()
      conn.commit()
    } catch {
      case e: Exception =>
        println(s"@@ saveDFtoDBUsePool error: ${ExceptionUtil.getExceptionStack(e)}")
        // do some log
    } finally {
      preparedStatement.close()
      conn.close()
    }
  })
}
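saveDFtoDBUsePool also calls a getInsertSql helper that is not shown in the post; presumably it just builds a parameterized INSERT with one placeholder per column, roughly like this sketch:

// Sketch only: build "insert into <table> values (?,?,...,?)" with one placeholder per column.
def getInsertSql(tableName: String, colNumbers: Int): String = {
  val placeholders = List.fill(colNumbers)("?").mkString(",")
  s"insert into $tableName values ($placeholders)"
}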
II. Working with Spark
1. Switching between Spark environments
Define the environments in Profile.scala:
/**
 * @description
 * scf
 * @author wangxuexing
 * @date 2019/12/23
 */
object Profile extends Enumeration {
  type Profile = Value
  /**
   * Production environment
   */
  val PROD = Value("prod")
  /**
   * Production test environment
   */
  val PROD_TEST = Value("prod_test")
  /**
   * Development environment
   */
  val DEV = Value("dev")
  /**
   * Set the current environment
   */
  val currentEvn = PROD
}
Define SparkUtil.scala:
import com.dmall.scf.Profile
import com.dmall.scf.dto.{Env, MySqlConfig}
import org.apache.spark.sql.{DataFrame, Encoder, SparkSession}
import scala.collection.JavaConversions._

/**
 * @description Spark utilities
 * scf
 * @author wangxuexing
 * @date 2019/12/23
 */
object SparkUtils {
  // Development environment
  val DEV_URL = "jdbc:mysql://IP:PORT/db_name?useUnicode=true&characterEncoding=UTF-8&autoReconnect=true&failOverReadOnly=false&useSSL=false"
  val DEV_USER = "user"
  val DEV_PASSWORD = "password"
  // Production test environment
  val PROD_TEST_URL = "jdbc:mysql://IP:PORT/db_name?useUnicode=true&characterEncoding=UTF-8&autoReconnect=true&failOverReadOnly=false&zeroDateTimeBehavior=convertToNull&useSSL=false"
  val PROD_TEST_USER = "user"
  val PROD_TEST_PASSWORD = "password"
  // Production environment
  val PROD_URL = "jdbc:mysql://IP:PORT/db_name?useUnicode=true&characterEncoding=UTF-8&autoReconnect=true&failOverReadOnly=false&useSSL=false"
  val PROD_USER = "user"
  val PROD_PASSWORD = "password"

  def env = Profile.currentEvn

  /**
   * Get the environment settings
   * @return
   */
  def getEnv: Env = {
    env match {
      case Profile.DEV => Env(MySqlConfig(DEV_URL, DEV_USER, DEV_PASSWORD), SparkUtils.getDevSparkSession)
      case Profile.PROD => Env(MySqlConfig(PROD_URL, PROD_USER, PROD_PASSWORD), SparkUtils.getProdSparkSession)
      case Profile.PROD_TEST => Env(MySqlConfig(PROD_TEST_URL, PROD_TEST_USER, PROD_TEST_PASSWORD), SparkUtils.getProdSparkSession)
      case _ => throw new Exception("Unable to determine the environment")
    }
  }

  /**
   * Get the production SparkSession
   * @return
   */
  def getProdSparkSession: SparkSession = {
    SparkSession
      .builder()
      .appName("scf")
      .enableHiveSupport() // enable Hive support
      .getOrCreate()
  }

  /**
   * Get the development SparkSession
   * @return
   */
  def getDevSparkSession: SparkSession = {
    SparkSession
      .builder()
      .master("local[*]")
      .appName("local-1576939514234")
      .config("spark.sql.warehouse.dir", "C:\\data\\spark-ware") // defaults to C:\data\projects\parquet2dbs\spark-warehouse if not set
      .enableHiveSupport() // enable Hive support
      .getOrCreate()
  }

  /**
   * Convert a DataFrame into a list of case class instances
   * @param df DataFrame
   * @tparam T case class
   * @return
   */
  def dataFrame2Bean[T: Encoder](df: DataFrame, clazz: Class[T]): List[T] = {
    val fieldNames = clazz.getDeclaredFields.map(f => f.getName).toList
    df.toDF(fieldNames: _*).as[T].collectAsList().toList
  }
}
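SparkUtils returns an Env that bundles the MySQL configuration and the SparkSession. The two DTOs in com.dmall.scf.dto are not listed in the post; they are presumably simple case classes along these lines (the field names are assumptions):

import org.apache.spark.sql.SparkSession

// JDBC connection settings for one environment.
case class MySqlConfig(url: String, user: String, password: String)

// Everything an action needs at runtime: the MySQL config plus the SparkSession.
case class Env(mySqlConfig: MySqlConfig, sparkSession: SparkSession)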
III. Defining the Spark processing flow
The flow is: read data from MySQL or Hive -> apply the business logic -> write the results back to MySQL.
1. Define the processing flow
SparkAction.scala
import com.dmall.scf.utils.{MySQLUtils, SparkUtils}
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}

/**
 * @description Defines the Spark processing flow
 * @author wangxuexing
 * @date 2019/12/23
 */
trait SparkAction[T] {
  /**
   * The overall flow
   */
  def execute(args: Array[String], spark: SparkSession) = {
    // 1. Pre-processing
    preAction()
    // 2. Main processing
    val df = action(spark, args)
    // 3. Post-processing
    postAction(df)
  }

  /**
   * Pre-processing
   * @return
   */
  def preAction() = {
    // no pre-processing by default
  }

  /**
   * Main processing
   * @param spark
   * @return
   */
  def action(spark: SparkSession, args: Array[String]): DataFrame

  /**
   * Post-processing, e.g. saving the result to MySQL
   * @param df
   */
  def postAction(df: DataFrame) = {
    // Append the result to the table returned by saveTable (e.g. scfc_supplier_run_field_value)
    MySQLUtils.writeIntoMySql(df, saveTable._1, saveTable._2)
  }

  /**
   * Save mode and table name
   * @return
   */
  def saveTable: (SaveMode, String)
}
2. Implement the flow
KanbanAction.scala
import com.dmall.scf.SparkAction
import com.dmall.scf.dto.KanbanFieldValue
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}
import scala.collection.JavaConverters._

/**
 * @description
 * scf-spark
 * @author wangxuexing
 * @date 2020/1/10
 */
trait KanbanAction extends SparkAction[KanbanFieldValue] {
  /**
   * Build a DataFrame from the result list
   * @param resultList
   * @param spark
   * @return
   */
  def getDataFrame(resultList: List[KanbanFieldValue], spark: SparkSession): DataFrame = {
    // Build the schema
    val fields = List(
      StructField("company_id", LongType, nullable = false),
      StructField("statistics_date", StringType, nullable = false),
      StructField("field_id", LongType, nullable = false),
      StructField("field_type", StringType, nullable = false),
      StructField("field_value", StringType, nullable = false),
      StructField("other_value", StringType, nullable = false))
    val schema = StructType(fields)
    // Convert the records to Rows
    val rowRDD = resultList.map(x => Row(x.companyId, x.statisticsDate, x.fieldId, x.fieldType, x.fieldValue, x.otherValue)).asJava
    // Create the DataFrame
    spark.createDataFrame(rowRDD, schema)
  }

  /**
   * Save mode and table name
   *
   * @return
   */
  override def saveTable: (SaveMode, String) = (SaveMode.Append, "scfc_kanban_field_value")
}
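getDataFrame builds Rows from KanbanFieldValue instances. The DTO itself is not listed in the post, but from the schema above it is presumably a case class like this (field names and types are inferred from the StructType):

// One row of the scfc_kanban_field_value result table (sketch).
case class KanbanFieldValue(companyId: Long,
                            statisticsDate: String,
                            fieldId: Long,
                            fieldType: String,
                            fieldValue: String,
                            otherValue: String)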
3. Implement the concrete business logic
import com.dmall.scf.dto.{KanbanFieldValue, RegisteredMoney}
import com.dmall.scf.utils.{DateUtils, MySQLUtils}
import org.apache.spark.sql.{DataFrame, SparkSession}

/**
 * @description
 * scf-spark: distribution of registered capital
 * @author wangxuexing
 * @date 2020/1/10
 */
object RegMoneyDistributionAction extends KanbanAction {
  val CLASS_NAME = this.getClass.getSimpleName().filter(!_.equals('$'))
  val RANGE_50W = BigDecimal(50)
  val RANGE_100W = BigDecimal(100)
  val RANGE_500W = BigDecimal(500)
  val RANGE_1000W = BigDecimal(1000)

  /**
   * Main processing
   *
   * @param spark
   * @return
   */
  override def action(spark: SparkSession, args: Array[String]): DataFrame = {
    import spark.implicits._
    if (args.length < 2) {
      throw new Exception("Please specify current year (1) or last year (2): 1|2")
    }
    val lastDay = DateUtils.addSomeDays(-1)
    val (starDate, endDate, filedId) = args(1) match {
      case "1" =>
        val startDate = DateUtils.isFirstDayOfYear match {
          case true => DateUtils.getFirstDateOfLastYear
          case false => DateUtils.getFirstDateOfCurrentYear
        }
        (startDate, DateUtils.formatNormalDateStr(lastDay), 44)
      case "2" =>
        val startDate = DateUtils.isFirstDayOfYear match {
          case true => DateUtils.getLast2YearFirstStr(DateUtils.YYYY_MM_DD)
          case false => DateUtils.getLastYearFirstStr(DateUtils.YYYY_MM_DD)
        }
        val endDate = DateUtils.isFirstDayOfYear match {
          case true => DateUtils.getLast2YearLastStr(DateUtils.YYYY_MM_DD)
          case false => DateUtils.getLastYearLastStr(DateUtils.YYYY_MM_DD)
        }
        (startDate, endDate, 45)
      case _ => throw new Exception("Invalid argument: pass 1 for the current year or 2 for last year")
    }
    // Load supplier registered capital from MySQL
    val sql =
      s"""SELECT id, IFNULL(registered_money, 0) registered_money
          FROM scfc_supplier_info
          WHERE `status` = 3 AND yn = 1"""
    val allDimension = MySQLUtils.getDFFromMysql(spark, sql)
    val beanList = allDimension.map(x => RegisteredMoney(x.getLong(0), x.getDecimal(1)))
    //val filterList = SparkUtils.dataFrame2Bean[RegisteredMoney](allDimension, classOf[RegisteredMoney])
    // Load the eligible supplier ids from Hive
    val hiveSql =
      s"""
         SELECT DISTINCT(a.company_id) supplier_ids
         FROM wm_ods_cx_supplier_card_info a
         JOIN wm_ods_jrbl_loan_dkzhxx b ON a.card_code = b.gshkahao
         WHERE a.audit_status = '2' AND b.jiluztai = '0'
         AND to_date(b.gxinshij) >= '${starDate}' AND to_date(b.gxinshij) <= '${endDate}'"""
    println(hiveSql)
    val supplierIds = spark.sql(hiveSql).collect().map(_.getLong(0))
    val filterList = beanList.filter(x => supplierIds.contains(x.supplierId))
    // Count suppliers per registered-capital range using accumulators
    val range1 = spark.sparkContext.collectionAccumulator[Int]
    val range2 = spark.sparkContext.collectionAccumulator[Int]
    val range3 = spark.sparkContext.collectionAccumulator[Int]
    val range4 = spark.sparkContext.collectionAccumulator[Int]
    val range5 = spark.sparkContext.collectionAccumulator[Int]
    filterList.foreach(x => {
      if (RANGE_50W.compare(x.registeredMoney) >= 0) {
        range1.add(1)
      } else if (RANGE_50W.compare(x.registeredMoney) < 0 && RANGE_100W.compare(x.registeredMoney) >= 0) {
        range2.add(1)
      } else if (RANGE_100W.compare(x.registeredMoney) < 0 && RANGE_500W.compare(x.registeredMoney) >= 0) {
        range3.add(1)
      } else if (RANGE_500W.compare(x.registeredMoney) < 0 && RANGE_1000W.compare(x.registeredMoney) >= 0) {
        range4.add(1)
      } else if (RANGE_1000W.compare(x.registeredMoney) < 0) {
        range5.add(1)
      }
    })
    val resultList = List(
      ("50萬元以下", range1.value.size()),
      ("50-100萬元", range2.value.size()),
      ("100-500萬元", range3.value.size()),
      ("500-1000萬元", range4.value.size()),
      ("1000萬元以上", range5.value.size())).map(x => {
      KanbanFieldValue(1, lastDay, filedId, x._1, x._2.toString, "")
    })
    getDataFrame(resultList, spark)
  }
}
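Finally, a hypothetical entry point that ties the pieces together. The object name Main and the sparkSession field on Env are assumptions (matching the Env sketch earlier); the actual driver in the project may differ:

import com.dmall.scf.utils.SparkUtils

object Main {
  def main(args: Array[String]): Unit = {
    // The action reads args(1): "1" = current year, "2" = last year.
    val env = SparkUtils.getEnv
    val spark = env.sparkSession // field name as assumed in the Env sketch above
    RegMoneyDistributionAction.execute(args, spark)
    spark.stop()
  }
}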
The full project source code is available at:
https://github.com/barrywang88/spark-tool
https://gitee.com/barrywang/spark-tool