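Business scenario:
In the current project, Spark computes results from the raw data, and those results are written to MySQL. The write has two constraints:
1. The target table already exists in MySQL and has an auto-increment primary key, id.
2. When the DataFrame is written, the id column must not be written explicitly, because MySQL generates it.
Requirements:
1. Only specified columns are written, i.e. a subset of the table's columns.
2. When a written record has the same key as an existing row, that row must be updated instead of a new row being inserted.
Analysis:
Spark can already write a DataFrame to a database, and the behavior is controlled by SaveMode: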
/**
 * SaveMode is used to specify the expected behavior of saving a DataFrame to a data source.
 *
 * @since 1.3.0
 */
public enum SaveMode {
  /**
   * Append mode means that when saving a DataFrame to a data source, if data/table already exists,
   * contents of the DataFrame are expected to be appended to existing data.
   *
   * @since 1.3.0
   */
  Append,
  /**
   * Overwrite mode means that when saving a DataFrame to a data source,
   * if data/table already exists, existing data is expected to be overwritten by the contents of
   * the DataFrame.
   *
   * @since 1.3.0
   */
  Overwrite,
  /**
   * ErrorIfExists mode means that when saving a DataFrame to a data source, if data already exists,
   * an exception is expected to be thrown.
   *
   * @since 1.3.0
   */
  ErrorIfExists,
  /**
   * Ignore mode means that when saving a DataFrame to a data source, if data already exists,
   * the save operation is expected to not save the contents of the DataFrame and to not
   * change the existing data.
   *
   * @since 1.3.0
   */
  Ignore
}
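With this approach, however, the DataFrame's columns must map onto the columns of the MySQL target table and are written as a whole. It is simple, but it does not fit our requirements, and none of the four modes updates an existing row when the key is already present. So we take a different route: write the data over plain JDBC with an INSERT statement we build ourselves, which matches the business requirements much better.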
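To make requirement 2 concrete, here is a small in-memory model (an illustration, not MySQL code) of the semantics we want from MySQL's `INSERT ... ON DUPLICATE KEY UPDATE`: insert when the key is new, otherwise overwrite only the listed columns. One caveat worth stating: since `id` is generated by MySQL and never written, the conflict cannot come from `id`. MySQL triggers the update on a duplicate value in any PRIMARY KEY or UNIQUE index, so in practice the target table needs a unique index over the business key. `UpsertModel`, `Record` and `upsert` are hypothetical names.

```scala
import scala.collection.mutable

// In-memory model of INSERT ... ON DUPLICATE KEY UPDATE, keyed on a
// unique business key (hypothetical illustration only).
object UpsertModel {
  // one table row: column name -> value
  type Record = Map[String, Any]

  def upsert(table: mutable.Map[String, Record], key: String, row: Record,
             updateColumns: Seq[String]): Unit =
    table.get(key) match {
      case None =>
        // no conflicting key: behaves like a plain INSERT
        table(key) = row
      case Some(existing) =>
        // conflicting key: only the listed columns are overwritten
        table(key) = existing ++ updateColumns.flatMap(c => row.get(c).map(v => c -> v))
    }
}
```

Inserting `name=a, cnt=1` and then `name=b, cnt=2` under the same key with `updateColumns = Seq("cnt")` leaves one row, `name=a, cnt=2`.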
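Implementation (development environment: IntelliJ IDEA):
Approach: write the data through a c3p0 connection pool. This lets us build the SQL ourselves and insert exactly the columns we need, at the cost of a more verbose implementation.
Preparation:
First, add the required dependencies.
sbt project: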
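The core of the approach is an INSERT that names its columns explicitly, so the auto-increment `id` can simply be left out and MySQL fills it in. A minimal sketch, assuming the column names come from the DataFrame and the auto-increment column is called `id`. `ColumnListInsert.build` is a hypothetical helper, not part of the author's code.

```scala
// Hypothetical helper: builds an INSERT that names only the chosen columns,
// so MySQL generates the auto-increment column itself.
object ColumnListInsert {
  def build(table: String, columns: Seq[String], autoColumn: String = "id"): String = {
    // drop the auto-increment column; MySQL assigns it
    val cols = columns.filterNot(_.equalsIgnoreCase(autoColumn))
    require(cols.nonEmpty, "no columns left to insert")
    val colList = cols.map(c => s"`$c`").mkString(",")
    val placeholders = cols.map(_ => "?").mkString(",")
    s"insert into `$table`($colList) values($placeholders)"
  }
}
```

For example, `ColumnListInsert.build("audit_result", Seq("id", "name", "cnt"))` yields ``insert into `audit_result`(`name`,`cnt`) values(?,?)``.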
Open build.sbt and add the MySQL Connector/J and c3p0 dependencies to libraryDependencies.
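A minimal build.sbt sketch, assuming the same artifact coordinates as the Maven dependencies for Maven projects (mysql-connector-java 6.0.6 and c3p0 0.9.5). Both are plain Java libraries, so `%` is used rather than `%%`, which would append the Scala version to the artifact name.

```scala
// build.sbt fragment (sketch): same coordinates as the Maven dependencies
libraryDependencies ++= Seq(
  "mysql" % "mysql-connector-java" % "6.0.6",
  "com.mchange" % "c3p0" % "0.9.5"
)
```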
Maven project:
<dependency>
    <groupId>mysql</groupId>
    <artifactId>mysql-connector-java</artifactId>
    <version>6.0.6</version>
</dependency>
<dependency>
    <groupId>com.mchange</groupId>
    <artifactId>c3p0</artifactId>
    <version>0.9.5</version>
</dependency>
I usually put the database-related utility classes in a separate DBUtils package.
Step 1: define a class that reads the configuration file
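package cn.com.xxx.audit.DBUtils

import java.util.Properties

object PropertiyUtils {
  /** Reads `propertityKey` from the classpath resource `fileName` */
  def getFileProperties(fileName: String, propertityKey: String): String = {
    val in = this.getClass.getClassLoader.getResourceAsStream(fileName)
    if (in == null) throw new IllegalArgumentException(s"$fileName not found on the classpath")
    try {
      val prop = new Properties()
      prop.load(in)
      prop.getProperty(propertityKey)
    } finally {
      in.close() // release the resource stream
    }
  }
}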
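getFileProperties reopens and reparses db.properties on every call, and MySqlPool calls it eight times. A sketch of a variant that loads each file once and reports a missing file or key clearly. `CachedProperties` and its `openStream` parameter are hypothetical; the parameter exists only so the loader can be exercised without a classpath resource.

```scala
import java.io.InputStream
import java.util.Properties
import scala.collection.concurrent.TrieMap

// Sketch: load each properties file once, fail loudly on a missing file or key.
class CachedProperties(openStream: String => InputStream) {
  private val cache = TrieMap.empty[String, Properties]

  def get(fileName: String, key: String): String = {
    // load() runs only when the file is not cached yet
    val props = cache.getOrElseUpdate(fileName, load(fileName))
    Option(props.getProperty(key)).getOrElse(
      throw new NoSuchElementException(s"$key not found in $fileName"))
  }

  private def load(fileName: String): Properties = {
    val in = openStream(fileName)
    if (in == null) throw new IllegalArgumentException(s"$fileName not found")
    try { val p = new Properties(); p.load(in); p } finally in.close()
  }
}

object CachedProperties {
  // default: read from the classpath, as PropertiyUtils does
  def fromClasspath: CachedProperties =
    new CachedProperties(name => getClass.getClassLoader.getResourceAsStream(name))
}
```

Usage would be `CachedProperties.fromClasspath.get("db.properties", "mysql.jdbc.url")`, keeping one instance around instead of creating a new one per lookup.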
Step 2: create a configuration file, db.properties, in the resources directory, with one key=value pair per line
db.properties
mysql.jdbc.url=jdbc:mysql://localhost:3306/test?serverTimezone=UTC
mysql.jdbc.host=127.0.0.1
mysql.jdbc.port=3306
mysql.jdbc.user=root
mysql.jdbc.password=123456
mysql.pool.jdbc.minPoolSize=20
mysql.pool.jdbc.maxPoolSize=50
mysql.pool.jdbc.acquireIncrement=10
mysql.pool.jdbc.maxStatements=50
mysql.driver=com.mysql.jdbc.Driver
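The pool settings are plain strings in this file and get converted with `.toInt` one by one. A sketch of parsing them once into a typed value and checking them before the pool is built, using the key names from the file above. `PoolSettings` and `parse` are hypothetical names.

```scala
import java.util.Properties

// Sketch: typed, validated view of the pool settings in db.properties.
final case class PoolSettings(url: String, user: String, minPoolSize: Int,
                              maxPoolSize: Int, acquireIncrement: Int, maxStatements: Int)

object PoolSettings {
  def parse(props: Properties): PoolSettings = {
    def str(key: String): String =
      Option(props.getProperty(key)).map(_.trim).getOrElse(
        throw new IllegalArgumentException(s"missing key: $key"))
    def int(key: String): Int = str(key).toInt
    val s = PoolSettings(
      url = str("mysql.jdbc.url"),
      user = str("mysql.jdbc.user"),
      minPoolSize = int("mysql.pool.jdbc.minPoolSize"),
      maxPoolSize = int("mysql.pool.jdbc.maxPoolSize"),
      acquireIncrement = int("mysql.pool.jdbc.acquireIncrement"),
      maxStatements = int("mysql.pool.jdbc.maxStatements"))
    // an inverted min/max would otherwise only show up at runtime
    require(s.minPoolSize <= s.maxPoolSize, "minPoolSize must not exceed maxPoolSize")
    s
  }
}
```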
Step 3: define a connection-pool class that reads the configuration file and creates the database connection pool
package cn.com.xxx.audit.DBUtils

import java.sql.Connection

import com.mchange.v2.c3p0.ComboPooledDataSource

class MySqlPool extends Serializable {
  private val cpds: ComboPooledDataSource = new ComboPooledDataSource(true)

  // shorthand for reading a key from db.properties
  private def conf(key: String): String = PropertiyUtils.getFileProperties("db.properties", key)

  try {
    cpds.setJdbcUrl(conf("mysql.jdbc.url"))
    cpds.setDriverClass(conf("mysql.driver"))
    cpds.setUser(conf("mysql.jdbc.user"))
    cpds.setPassword(conf("mysql.jdbc.password"))
    cpds.setMinPoolSize(conf("mysql.pool.jdbc.minPoolSize").toInt)
    cpds.setMaxPoolSize(conf("mysql.pool.jdbc.maxPoolSize").toInt)
    cpds.setAcquireIncrement(conf("mysql.pool.jdbc.acquireIncrement").toInt)
    cpds.setMaxStatements(conf("mysql.pool.jdbc.maxStatements").toInt)
  } catch {
    case e: Exception => e.printStackTrace()
  }

  /** Borrows a connection from the pool; returns null if none can be obtained */
  def getConnection: Connection = {
    try {
      cpds.getConnection()
    } catch {
      case ex: Exception =>
        ex.printStackTrace()
        null
    }
  }

  /** Shuts the pool down */
  def close(): Unit = {
    try {
      cpds.close()
    } catch {
      case ex: Exception => ex.printStackTrace()
    }
  }
}
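Every caller of getConnection has to remember to close the connection, which for a c3p0 pooled connection returns it to the pool, and has to cope with the null that getConnection returns on failure. A sketch of a small loan helper that does both. `Loan.using` is a hypothetical name; it works with any `AutoCloseable`, including `java.sql.Connection` and `Statement`.

```scala
// Sketch of the loan pattern: the resource is always closed, even if the body throws.
object Loan {
  def using[R <: AutoCloseable, A](resource: R)(body: R => A): A = {
    // MySqlPool.getConnection returns null on failure; fail early and clearly
    require(resource != null, "resource is null (e.g. getConnection failed)")
    try body(resource) finally resource.close()
  }
}
```

A call would look like `Loan.using(MySqlPoolManager.getMysqlManager.getConnection) { conn => ... }`, with the JDBC work inside the block.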
Step 4: create a pool manager object from which database connections are obtained
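package cn.com.winner.audit.DBUtils

object MySqlPoolManager {
  var mysqlManager: MySqlPool = _

  /** Returns the pool for this JVM, creating it on first use */
  def getMysqlManager: MySqlPool = synchronized {
    if (mysqlManager == null) {
      mysqlManager = new MySqlPool
    }
    // read inside the lock so every thread sees the initialized pool
    mysqlManager
  }
}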
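MySqlPoolManager guards creation with an explicit synchronized block. In Scala, a `lazy val` in an object gives the same thread-safe, exactly-once initialization with less code. A sketch, assuming a single pool per JVM is what is wanted. `PoolHolder` and `created` are hypothetical, and `new Object` stands in for `new MySqlPool` so the sketch runs without c3p0.

```scala
import java.util.concurrent.atomic.AtomicInteger

// Sketch: lazy val initialization runs once, even under concurrent first access.
object PoolHolder {
  // counts initializations, for demonstration only
  val created = new AtomicInteger(0)

  lazy val pool: AnyRef = {
    created.incrementAndGet()
    new Object // stand-in for `new MySqlPool`
  }
}
```

Because each Spark executor is its own JVM, this still yields one pool per executor, which matches how getMysqlManager is used inside foreachPartition.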
Step 5: the object that performs the database operations
package cn.com.winner.audit.DBUtils

import java.sql.{Date, PreparedStatement, Timestamp, Types}
import java.util.Properties

import org.apache.log4j.Logger
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SQLContext}

object OperatorMySql {
  val logger: Logger = Logger.getLogger(this.getClass.getSimpleName)

  /** Number of rows sent to MySQL per executeBatch() call */
  private val BatchSize = 100

  /** Maps a Spark SQL type to the java.sql.Types code passed to setNull */
  private def sqlTypeOf(dataType: DataType): Int = dataType match {
    case _: ByteType => Types.TINYINT
    case _: ShortType => Types.SMALLINT
    case _: IntegerType => Types.INTEGER
    case _: LongType => Types.BIGINT
    case _: BooleanType => Types.BOOLEAN
    case _: FloatType => Types.FLOAT
    case _: DoubleType => Types.DOUBLE
    case _: StringType => Types.VARCHAR
    case _: TimestampType => Types.TIMESTAMP
    case _: DateType => Types.DATE
    case _ => throw new RuntimeException(s"nonsupport $dataType !!!")
  }

  /** Binds field `fieldIndex` of `record` to the 1-based placeholder `paramIndex` */
  private def setParam(ps: PreparedStatement, paramIndex: Int, record: Row,
                       fieldIndex: Int, dataType: DataType): Unit = {
    if (record.isNullAt(fieldIndex)) {
      ps.setNull(paramIndex, sqlTypeOf(dataType))
    } else {
      dataType match {
        case _: ByteType => ps.setByte(paramIndex, record.getByte(fieldIndex))
        case _: ShortType => ps.setShort(paramIndex, record.getShort(fieldIndex))
        case _: IntegerType => ps.setInt(paramIndex, record.getInt(fieldIndex))
        case _: LongType => ps.setLong(paramIndex, record.getLong(fieldIndex))
        case _: BooleanType => ps.setBoolean(paramIndex, record.getBoolean(fieldIndex))
        case _: FloatType => ps.setFloat(paramIndex, record.getFloat(fieldIndex))
        case _: DoubleType => ps.setDouble(paramIndex, record.getDouble(fieldIndex))
        case _: StringType => ps.setString(paramIndex, record.getString(fieldIndex))
        case _: TimestampType => ps.setTimestamp(paramIndex, record.getAs[Timestamp](fieldIndex))
        case _: DateType => ps.setDate(paramIndex, record.getAs[Date](fieldIndex))
        case _ => throw new RuntimeException(s"nonsupport $dataType !!!")
      }
    }
  }

  /**
   * Writes all columns of the DataFrame to MySQL through the c3p0 pool.
   * The DataFrame columns must match the table columns in number and order.
   *
   * @param tableName       table name
   * @param resultDateFrame the DataFrame to write
   */
  def saveDFtoDBUsePool(tableName: String, resultDateFrame: DataFrame): Unit = {
    val colNumbers = resultDateFrame.columns.length
    val sql = getInsertSql(tableName, colNumbers)
    val columnDataTypes = resultDateFrame.schema.fields.map(_.dataType)
    resultDateFrame.foreachPartition((partitionRecords: Iterator[Row]) => {
      val conn = MySqlPoolManager.getMysqlManager.getConnection
      val prepareStatement = conn.prepareStatement(sql)
      try {
        conn.setAutoCommit(false)
        var count = 0
        partitionRecords.foreach(record => {
          for (i <- 1 to colNumbers) {
            setParam(prepareStatement, i, record, i - 1, columnDataTypes(i - 1))
          }
          prepareStatement.addBatch()
          count += 1
          if (count % BatchSize == 0) prepareStatement.executeBatch()
        })
        prepareStatement.executeBatch() // flush the remaining rows
        conn.commit()
      } catch {
        case e: Exception =>
          conn.rollback()
          logger.error(s"@@ saveDFtoDBUsePool ${e.getMessage}", e)
          throw e // fail the task instead of silently dropping the partition
      } finally {
        prepareStatement.close()
        conn.close()
      }
    })
  }

  /**
   * Builds "insert into <table> values(?,?,...)"
   */
  def getInsertSql(tableName: String, colNumbers: Int): String =
    s"insert into $tableName values(" + Seq.fill(colNumbers)("?").mkString(",") + ")"

  /**
   * Returns the MySQL connection settings as a tuple (url, user, password)
   *
   * @return
   */
  def getMysqlInfo: (String, String, String) = {
    val jdbcURL = PropertiyUtils.getFileProperties("db.properties", "mysql.jdbc.url")
    val userName = PropertiyUtils.getFileProperties("db.properties", "mysql.jdbc.user")
    val password = PropertiyUtils.getFileProperties("db.properties", "mysql.jdbc.password")
    (jdbcURL, userName, password)
  }

  /**
   * Loads a DataFrame from MySQL
   *
   * @param sqlContext     sqlContext
   * @param mysqlTableName table name
   * @param queryCondition filter condition
   * @return
   */
  def getDFFromeMysql(sqlContext: SQLContext, mysqlTableName: String, queryCondition: String = ""): DataFrame = {
    val (jdbcURL, userName, password) = getMysqlInfo
    val prop = new Properties()
    prop.put("user", userName)
    prop.put("password", password)
    // in Scala, == calls equals (and is null-safe), unlike == in Java
    if (null == queryCondition || "" == queryCondition) {
      sqlContext.read.jdbc(jdbcURL, mysqlTableName, prop)
    } else {
      sqlContext.read.jdbc(jdbcURL, mysqlTableName, prop).where(queryCondition)
    }
  }

  /**
   * Drops a table
   *
   * @param SQLContext
   * @param mysqlTableName
   * @return true if the table was dropped
   */
  def dropMysqlTable(SQLContext: SQLContext, mysqlTableName: String): Boolean = {
    val conn = MySqlPoolManager.getMysqlManager.getConnection
    val statement = conn.createStatement()
    try {
      statement.execute(s"drop table $mysqlTableName")
      true // execute() returns false for DDL, so success is reported explicitly
    } catch {
      case e: Exception =>
        println(s"mysql drop MysqlTable error:${e.getMessage}")
        false
    } finally {
      statement.close()
      conn.close()
    }
  }

  /**
   * Deletes rows from a table
   *
   * @param SQLContext
   * @param mysqlTableName table name
   * @param condition      condition, i.e. the text that follows WHERE
   * @return true if the DELETE ran without error
   */
  def deleteMysqlTableData(SQLContext: SQLContext, mysqlTableName: String, condition: String): Boolean = {
    val conn = MySqlPoolManager.getMysqlManager.getConnection
    val statement = conn.createStatement()
    try {
      statement.executeUpdate(s"delete from $mysqlTableName where $condition")
      true
    } catch {
      case e: Exception =>
        println(s"mysql delete MysqlTableNameData error:${e.getMessage}")
        false
    } finally {
      statement.close()
      conn.close()
    }
  }

  /**
   * Saves the DataFrame to MySQL, creating the table first if it does not exist
   *
   * @param tableName
   * @param resultDataFrame
   */
  def saveDFtoDBCreateTableIfNotExists(tableName: String, resultDataFrame: DataFrame): Unit = {
    // create the table from the DataFrame schema if it is missing
    createTableIfNotExist(tableName, resultDataFrame)
    // check that table and DataFrame agree on column count, names and order
    verifyFieldConsistency(tableName, resultDataFrame)
    // write the DataFrame
    saveDFtoDBUsePool(tableName, resultDataFrame)
  }

  /**
   * Creates the table if it does not exist
   *
   * @param tableName
   * @param df
   */
  def createTableIfNotExist(tableName: String, df: DataFrame): Unit = {
    val conn = MySqlPoolManager.getMysqlManager.getConnection
    try {
      val colResultSet = conn.getMetaData.getColumns(conn.getCatalog, "%", tableName, "%")
      val tableExists = colResultSet.next()
      colResultSet.close()
      // the table does not exist: create it
      if (!tableExists) {
        // build the CREATE TABLE statement
        val sb = new StringBuilder(s"create table `$tableName` (")
        df.schema.fields.foreach(x => {
          if (x.name.equalsIgnoreCase("id")) {
            // a column named id becomes a non-null, auto-increment primary key
            sb.append(s"`${x.name}` int(255) not null auto_increment primary key,")
          } else {
            x.dataType match {
              case _: ByteType => sb.append(s"`${x.name}` int(100) default null,")
              case _: ShortType => sb.append(s"`${x.name}` int(100) default null,")
              case _: IntegerType => sb.append(s"`${x.name}` int(100) default null,")
              case _: LongType => sb.append(s"`${x.name}` bigint(100) default null,")
              case _: BooleanType => sb.append(s"`${x.name}` tinyint default null,")
              case _: FloatType => sb.append(s"`${x.name}` float default null,")
              case _: DoubleType => sb.append(s"`${x.name}` double default null,")
              case _: StringType => sb.append(s"`${x.name}` varchar(50) default null,")
              case _: TimestampType => sb.append(s"`${x.name}` timestamp default current_timestamp,")
              case _: DateType => sb.append(s"`${x.name}` date default null,")
              case _ => throw new RuntimeException(s"non support ${x.dataType}!!!")
            }
          }
        })
        sb.deleteCharAt(sb.lastIndexOf(",")) // drop the trailing comma
        sb.append(") engine = InnoDB default charset=utf8")
        val sqlCreateTable = sb.toString()
        println(sqlCreateTable)
        val statement = conn.createStatement()
        try statement.execute(sqlCreateTable) finally statement.close()
      }
    } finally {
      conn.close()
    }
  }

  /**
   * Builds the insert-or-update statement:
   * insert into t(c1,...,cn) values(?,...,?) on duplicate key update u1=?,...,uk=?
   *
   * @param tableName
   * @param cols
   * @param updateColumns
   * @return
   */
  def getInsertOrUpdateSql(tableName: String, cols: Array[String], updateColumns: Array[String]): String = {
    val colList = cols.mkString(",")
    val placeholders = cols.map(_ => "?").mkString(",")
    val updates = updateColumns.map(c => s"$c=?").mkString(",")
    s"insert into $tableName($colList) values($placeholders) on duplicate key update $updates"
  }

  /**
   * Inserts the DataFrame rows, updating `updateColumns` when a row hits an existing key
   *
   * @param tableName
   * @param resultDateFrame the DataFrame to store
   * @param updateColumns   the columns to update on a duplicate key
   * @return true if every partition was written and committed
   */
  def insertOrUpdateDFtoDBUserPool(tableName: String, resultDateFrame: DataFrame, updateColumns: Array[String]): Boolean = {
    val colNumbers = resultDateFrame.columns.length
    val sql = getInsertOrUpdateSql(tableName, resultDateFrame.columns, updateColumns)
    val columnDataTypes = resultDateFrame.schema.fields.map(_.dataType)
    // resolve the positions of the update columns once, on the driver
    val updateIndexes = updateColumns.map(c => resultDateFrame.schema.fieldIndex(c))
    println(s"\n$sql")
    try {
      resultDateFrame.foreachPartition((partitionRecords: Iterator[Row]) => {
        val conn = MySqlPoolManager.getMysqlManager.getConnection
        val prepareStatement = conn.prepareStatement(sql)
        try {
          conn.setAutoCommit(false)
          var count = 0
          partitionRecords.foreach(record => {
            // placeholders 1..colNumbers: the inserted values
            for (i <- 1 to colNumbers) {
              setParam(prepareStatement, i, record, i - 1, columnDataTypes(i - 1))
            }
            // placeholders after that: the "on duplicate key update" values
            for (j <- updateIndexes.indices) {
              val fieldIndex = updateIndexes(j)
              setParam(prepareStatement, colNumbers + j + 1, record, fieldIndex, columnDataTypes(fieldIndex))
            }
            prepareStatement.addBatch()
            count += 1
            // send a batch every 100 rows
            if (count % BatchSize == 0) prepareStatement.executeBatch()
          })
          prepareStatement.executeBatch() // flush the remaining rows
          conn.commit()
        } catch {
          case e: Exception =>
            conn.rollback()
            logger.error(s"@@ ${e.getMessage}", e)
            throw e // fail the task so the driver sees the error
        } finally {
          prepareStatement.close()
          conn.close()
        }
      })
      true
    } catch {
      // a failed partition fails the Spark job, which surfaces here on the driver
      case e: Exception =>
        logger.error(s"insertOrUpdateDFtoDBUserPool failed: ${e.getMessage}", e)
        false
    }
  }

  /**
   * Checks that the table and the DataFrame have the same columns, in the same order
   */
  def verifyFieldConsistency(tableName: String, df: DataFrame): Unit = {
    val conn = MySqlPoolManager.getMysqlManager.getConnection
    try {
      val colResultSet = conn.getMetaData.getColumns(conn.getCatalog, "%", tableName, "%")
      colResultSet.last()
      val tableFieldNum = colResultSet.getRow
      val dfFieldNum = df.columns.length
      if (tableFieldNum != dfFieldNum) {
        throw new Exception(s"table $tableName has $tableFieldNum columns, the DataFrame has $dfFieldNum")
      }
      for (i <- 1 to tableFieldNum) {
        colResultSet.absolute(i)
        val tableFieldName = colResultSet.getString("COLUMN_NAME")
        val dfFieldName = df.columns(i - 1)
        if (!tableFieldName.equals(dfFieldName)) {
          throw new Exception(s"column $i: table has $tableFieldName, the DataFrame has $dfFieldName")
        }
      }
      colResultSet.close()
    } finally {
      conn.close()
    }
  }
}
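The two easy-to-get-wrong parts of insertOrUpdateDFtoDBUserPool are the placeholder numbering and the batching. In the statement from getInsertOrUpdateSql, placeholders 1..n carry the inserted columns and n+1..n+k carry the update columns, each of which must be read from its own field index in the row. The "batch size 100" design flushes every 100 rows plus once more for the remainder. A pure-Scala sketch of both, with no JDBC involved. `UpsertBookkeeping`, `parameterLayout` and `simulateBatches` are hypothetical names.

```scala
import scala.collection.mutable.ArrayBuffer

object UpsertBookkeeping {
  /** (1-based JDBC placeholder, 0-based DataFrame field) pairs for
   *  insert into t(c1..cn) values(?..?) on duplicate key update u1=?,... */
  def parameterLayout(columns: Seq[String], updateColumns: Seq[String]): Seq[(Int, Int)] = {
    val insertPart = columns.indices.map(i => (i + 1, i))
    val updatePart = updateColumns.zipWithIndex.map { case (c, j) =>
      val fieldIndex = columns.indexOf(c)
      require(fieldIndex >= 0, s"update column $c is not in the DataFrame")
      (columns.length + j + 1, fieldIndex)
    }
    insertPart ++ updatePart
  }

  /** Sizes of the executeBatch() calls made for `rows` rows */
  def simulateBatches(rows: Int, batchSize: Int): List[Int] = {
    val flushed = ArrayBuffer.empty[Int]
    var pending = 0
    var count = 0
    for (_ <- 1 to rows) {
      pending += 1 // addBatch()
      count += 1
      if (count % batchSize == 0) { flushed += pending; pending = 0 } // executeBatch()
    }
    if (pending > 0) flushed += pending // final executeBatch() after the loop
    flushed.toList
  }
}
```

For columns `name, cnt, day` with update column `cnt`, the layout is `(1,0), (2,1), (3,2), (4,1)`: placeholder 4 reads field 1 again. A 250-row partition with batch size 100 is sent as batches of 100, 100 and 50.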
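Step 6: call these methods to run our own insert, delete, update and query operations against the database, instead of going through the DataFrame's built-in write API. This is considerably more flexible.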
package cn.com.xxx.audit

import cn.com.winner.audit.DBUtils.{OperatorMySql, PropertiyUtils}
import cn.com.winner.common.until.{DateOperator, DateUtil}
import org.apache.spark.HashPartitioner
import org.apache.spark.sql.DataFrame

/**
 * Persists the results
 */
object SaveData {
  /**
   * Writes the DataFrames to the MySQL result table
   *
   * @param tableName  name of the target table
   * @param ResultDFs  the DataFrames to save
   * @param updateCols the columns to update on a duplicate key
   * @return the status returned by insertOrUpdateDFtoDBUserPool
   */
  def saveToMysql(tableName: String, ResultDFs: Array[DataFrame], updateCols: Array[String]): Boolean = {
    // merge the DataFrames into one
    val resultDF = LoadData.mergeDF(ResultDFs.toVector)
    // call OperatorMySql's insert-or-update method, which writes with the hand-built SQL
    OperatorMySql.insertOrUpdateDFtoDBUserPool(tableName, resultDF, updateCols)
  }
}
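Because saveToMysql passes `updateCols` straight into the SQL, a typo or an empty array only shows up as a failed statement on the executors. A sketch of a driver-side check one might run first, assuming `dfColumns` is `resultDF.columns` and that `id` must never be updated. `UpdateColsCheck.validate` is a hypothetical helper.

```scala
// Hypothetical pre-flight check for the columns passed as updateCols.
object UpdateColsCheck {
  def validate(dfColumns: Seq[String], updateCols: Seq[String]): Either[String, Seq[String]] = {
    val missing = updateCols.filterNot(c => dfColumns.contains(c))
    if (updateCols.isEmpty)
      Left("updateCols is empty: ON DUPLICATE KEY UPDATE needs at least one column")
    else if (missing.nonEmpty)
      Left(s"not in DataFrame: ${missing.mkString(", ")}")
    else if (updateCols.exists(_.equalsIgnoreCase("id")))
      Left("id is auto-increment and must not be updated")
    else
      Right(updateCols)
  }
}
```

On a `Left`, the caller can log the message and skip the write instead of launching a Spark job that is bound to fail.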
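The SQL built in Step 5 is tailored to my own requirements. You can build different SQL statements for your own needs in the same way, and call different methods to process the DataFrame.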
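As one example of building a different statement in the same style, here is a sketch of a plain UPDATE keyed on chosen columns, for cases where rows are known to exist already. `CustomSql.getUpdateSql` is a hypothetical name, not part of OperatorMySql. When binding, the SET values come first, followed by the key values in the same order as `keyCols`.

```scala
// Sketch: UPDATE statement built the same way as getInsertOrUpdateSql.
object CustomSql {
  def getUpdateSql(tableName: String, setCols: Seq[String], keyCols: Seq[String]): String = {
    require(setCols.nonEmpty && keyCols.nonEmpty, "need at least one SET column and one key column")
    val set = setCols.map(c => s"$c=?").mkString(",")
    val where = keyCols.map(c => s"$c=?").mkString(" and ")
    s"update $tableName set $set where $where"
  }
}
```

For example, `getUpdateSql("t", Seq("cnt", "name"), Seq("day"))` produces `update t set cnt=?,name=? where day=?`.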