如何讓spark sql寫mysql的時候支持update操作


 

 

 

如何讓sparkSQL在對接mysql的時候,除了支持:Append、Overwrite、ErrorIfExists、Ignore;還要在支持update操作

1、首先了解背景

spark提供了一個枚舉類,用來支撐對接數據源的操作模式

 

 

 通過源碼查看,很明顯,spark是不支持update操作的

2、如何讓sparkSQL支持update

關鍵的知識點就是:

我們正常在sparkSQL寫數據到mysql的時候:

大概的api是:

dataframe.write
        .format("sql.execution.customDatasource.jdbc")
        .option("jdbc.driver", "com.mysql.jdbc.Driver")
        .option("jdbc.url", "jdbc:mysql://localhost:3306/test?user=root&password=&useUnicode=true&characterEncoding=gbk&autoReconnect=true&failOverReadOnly=false")
        .option("jdbc.db", "test")
        .save()

那么在底層中,spark會通過JDBC方言JdbcDialect , 將我們要插入的數據翻譯成:

insert into student (columns_1 , columns_2 , ...) values (? , ? , ....)

那么通過方言解析出的sql語句就通過PrepareStatement的executeBatch(),將sql語句提交給mysql,然后數據插入;

那么上面的sql語句很明顯,完全就是插入代碼,並沒有我們期望的 update操作,類似:

UPDATE table_name SET field1=new-value1, field2=new-value2

但是mysql獨家支持這樣的sql語句:

INSERT INTO student (columns_1,columns_2)VALUES ('第一個字段值','第二個字段值') ON DUPLICATE KEY UPDATE columns_1 = '呵呵噠',columns_2 = '哈哈噠';

大概的意思就是,如果數據不存在則插入,如果數據存在,則 執行update操作;

因此,我們的切入點就是,讓sparkSQL內部對接JdbcDialect的時候,能夠生成這種sql:

INSERT INTO 表名稱 (columns_1,columns_2)VALUES ('第一個字段值','第二個字段值') ON DUPLICATE KEY UPDATE columns_1 = '呵呵噠',columns_2 = '哈哈噠';

 

3、改造源碼前,需要了解整體的代碼設計和執行流程

首先是:

dataframe.write

調用write方法就是為了返回一個類:DataFrameWriter

主要是因為DataFrameWriter是sparksql對接外部數據源寫入的入口攜帶類,下面這些內容是給DataFrameWriter注冊的攜帶信息

 

 

 

然后在出發save()操作后,就開始將數據寫入;

接下來看save()源碼:

 

 

 

在上面的源碼里面主要是注冊DataSource實例,然后使用DataSource的write方法進行數據寫入

實例化DataSource的時候:

def save(): Unit = {
   assertNotBucketed("save")
   val dataSource = DataSource(
     df.sparkSession,
     className = source,//自定義數據源的包路徑
     partitionColumns = partitioningColumns.getOrElse(Nil),//分區字段
     bucketSpec = getBucketSpec,//分桶(用於hive)
     options = extraOptions.toMap)//傳入的注冊信息
//mode:插入數據方式SaveMode , df:要插入的數據
   dataSource.write(mode, df)
}

然后就是dataSource.write(mode, df)的細節,整段的邏輯就是:

根據providingClass.newInstance()去做模式匹配,然后匹配到哪里,就執行哪里的代碼;

 

 

 然后看下providingClass是什么:

 

 

 

 

 

 拿到包路徑.DefaultSource之后,程序進入:

 

 

 那么如果是數據庫作為寫入目標的話,就會走:dataSource.createRelation,直接跟進源碼:

 

 

 

很明顯是個特質,因此哪里實現了特質,程序就會走到哪里了;

實現這個特質的地方就是:包路徑.DefaultSource , 然后就在這里面去實現數據的插入和update的支持操作;

4、改造源碼

根據代碼的流程,最終sparkSQL 將數據寫入mysql的操作,會進入:包路徑.DefaultSource這個類里面;

也就是說,在這個類里面既要支持spark的正常插入操作(SaveMode),還要在支持update;

如果讓sparksql支持update操作,最關鍵的就是做一個判斷,比如:

if(isUpdate){
sql語句:INSERT INTO student (columns_1,columns_2)VALUES ('第一個字段值','第二個字段值') ON DUPLICATE KEY UPDATE columns_1 = '呵呵噠',columns_2 = '哈哈噠';
}else{
   insert into student (columns_1 , columns_2 , ...) values (? , ? , ....)
}

但是,在spark生產sql語句的源碼中,是這樣寫的:

 

 

 

沒有任何的判斷邏輯,就是最后生成一個:

INSERT INTO TABLE (字段1 , 字段2....) VALUES (? , ? ...)

所以首要的任務就是 ,怎么能讓當前代碼支持:ON DUPLICATE KEY UPDATE

可以做個大膽的設計,就是在insertStatement這個方法中做個如下的判斷

def insertStatement(conn: Connection, savemode:CustomSaveMode , table: String, rddSchema: StructType, dialect: JdbcDialect)
    : PreparedStatement = {
   val columns = rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
   val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
   if(savemode == CustomSaveMode.update){
  //TODO 如果是update,就組裝成ON DUPLICATE KEY UPDATE的模式處理
       s"INSERT INTO $table ($columns) VALUES ($placeholders) ON DUPLICATE KEY UPDATE $duplicateSetting"
  }esle{
       val sql = s"INSERT INTO $table ($columns) VALUES ($placeholders)"
  conn.prepareStatement(sql)
  }
   
}

這樣,在用戶傳遞進來的savemode模式,我們進行校驗,如果是update操作,就返回對應的sql語句!

所以按照上面的邏輯,我們代碼這樣寫:

 

 

 

這樣我們就拿到了對應的sql語句;

但是只有這個sql語句還是不行的,因為在spark中會執行jdbc的prepareStatement操作,這里面會涉及到游標。

即jdbc在遍歷這個sql的時候,源碼會這樣做:

 

 

 看下makeSetter:

 

 

 

所謂有坑就是:

insert into table (字段1 , 字段2, 字段3) values (? , ? , ?)

那么當前在源碼中返回的數組長度應該是3:

val setters: Array[JDBCValueSetter] = rddSchema.fields.map(_.dataType)
      .map(makeSetter(conn, dialect, _)).toArray

但是如果我們此時支持了update操作,既:

insert into table (字段1 , 字段2, 字段3) values (? , ? , ?) ON DUPLICATE KEY UPDATE 字段1 = ?,字段2 = ?,字段3=?;

那么很明顯,上面的sql語句提供了6個? , 但在規定字段長度的時候只有3

 

 

 

這樣的話,后面的update操作就無法執行,程序報錯!

所以我們需要有一個 識別機制,既:

if(isupdate){
    val numFields = rddSchema.fields.length * 2
}else{
    val numFields = rddSchema.fields.length
}

 

 

row[1,2,3] setter(0,1) //index of setter , index of row setter(1,2) setter(2,3) setter(3,1) setter(4,2) setter(5,3)

所以在prepareStatment中的占位符應該是row的兩倍,而且應該是類似這樣的一個邏輯

因此,代碼改造前樣子:

 

 

 

 

改造后的樣子:

try {
     if (supportsTransactions) {
       conn.setAutoCommit(false) // Everything in the same db transaction.
       conn.setTransactionIsolation(finalIsolationLevel)
    }
//     val stmt = insertStatement(conn, table, rddSchema, dialect)
     //此處采用最新自己的sql語句,封裝成prepareStatement
     val stmt = conn.prepareStatement(sqlStmt)
     println(sqlStmt)
     /**
       * 在mysql中有這樣的操作:
       * INSERT INTO user_admin_t (_id,password) VALUES ('1','第一次插入的密碼')
       * INSERT INTO user_admin_t (_id,password)VALUES ('1','第一次插入的密碼') ON DUPLICATE KEY UPDATE _id = 'UpId',password = 'upPassword';
       * 如果是下面的ON DUPLICATE KEY操作,那么在prepareStatement中的游標會擴增一倍
       * 並且如果沒有update操作,那么他的游標是從0開始計數的
       * 如果是update操作,要算上之前的insert操作
       * */
       //makeSetter也要適配update操作,即游標問題


     val isUpdate = saveMode == CustomSaveMode.Update

     val setters: Array[JDBCValueSetter] = isUpdate match {
       case true =>
         val setters: Array[JDBCValueSetter] = rddSchema.fields.map(_.dataType)
          .map(makeSetter(conn, dialect, _)).toArray
         Array.fill(2)(setters).flatten
       case _ =>
         rddSchema.fields.map(_.dataType)
          .map(makeSetter(conn, dialect, _)).toArray
    }


     val numFieldsLength = rddSchema.fields.length
     val numFields = isUpdate match{
       case true => numFieldsLength *2
       case _ => numFieldsLength
    }
     val cursorBegin = numFields / 2
     try {
       var rowCount = 0
       while (iterator.hasNext) {
         val row = iterator.next()
         var i = 0
         while (i < numFields) {
           if(isUpdate){
             //需要判斷當前游標是否走到了ON DUPLICATE KEY UPDATE
             i < cursorBegin match{
                 //說明還沒走到update階段
               case true =>
                 //row.isNullAt 判空,則設置空值
                 if (row.isNullAt(i)) {
                   stmt.setNull(i + 1, nullTypes(i))
                } else {
                   setters(i).apply(stmt, row, i, 0)
                }
                 //說明走到了update階段
               case false =>
                 if (row.isNullAt(i - cursorBegin)) {
                   //pos - offset
                   stmt.setNull(i + 1, nullTypes(i - cursorBegin))
                } else {
                   setters(i).apply(stmt, row, i, cursorBegin)
                }
            }
          }else{
             if (row.isNullAt(i)) {
               stmt.setNull(i + 1, nullTypes(i))
            } else {
               setters(i).apply(stmt, row, i ,0)
            }
          }
           //滾動游標
           i = i + 1
        }
         stmt.addBatch()
         rowCount += 1
         if (rowCount % batchSize == 0) {
           stmt.executeBatch()
           rowCount = 0
        }
      }
       if (rowCount > 0) {
         stmt.executeBatch()
      }
    } finally {
       stmt.close()
    }
     if (supportsTransactions) {
       conn.commit()
    }
     committed = true
     Iterator.empty
  } catch {
     case e: SQLException =>
       val cause = e.getNextException
       if (cause != null && e.getCause != cause) {
         if (e.getCause == null) {
           e.initCause(cause)
        } else {
           e.addSuppressed(cause)
        }
      }
       throw e
  } finally {
     if (!committed) {
       // The stage must fail. We got here through an exception path, so
       // let the exception through unless rollback() or close() want to
       // tell the user about another problem.
       if (supportsTransactions) {
         conn.rollback()
      }
       conn.close()
    } else {
       // The stage must succeed. We cannot propagate any exception close() might throw.
       try {
         conn.close()
      } catch {
         case e: Exception => logWarning("Transaction succeeded, but closing failed", e)
      }
    }

 

// A `JDBCValueSetter` is responsible for setting a value from `Row` into a field for
 // `PreparedStatement`. The last argument `Int` means the index for the value to be set
 // in the SQL statement and also used for the value in `Row`.
 //PreparedStatement, Row, position , cursor
 private type JDBCValueSetter = (PreparedStatement, Row, Int , Int) => Unit

 private def makeSetter(
     conn: Connection,
     dialect: JdbcDialect,
     dataType: DataType): JDBCValueSetter = dataType match {
   case IntegerType =>
    (stmt: PreparedStatement, row: Row, pos: Int,cursor:Int) =>
       stmt.setInt(pos + 1, row.getInt(pos - cursor))

   case LongType =>
    (stmt: PreparedStatement, row: Row, pos: Int,cursor:Int) =>
       stmt.setLong(pos + 1, row.getLong(pos - cursor))

   case DoubleType =>
    (stmt: PreparedStatement, row: Row, pos: Int,cursor:Int) =>
       stmt.setDouble(pos + 1, row.getDouble(pos - cursor))

   case FloatType =>
    (stmt: PreparedStatement, row: Row, pos: Int,cursor:Int) =>
       stmt.setFloat(pos + 1, row.getFloat(pos - cursor))

   case ShortType =>
    (stmt: PreparedStatement, row: Row, pos: Int,cursor:Int) =>
       stmt.setInt(pos + 1, row.getShort(pos - cursor))

   case ByteType =>
    (stmt: PreparedStatement, row: Row, pos: Int,cursor:Int) =>
       stmt.setInt(pos + 1, row.getByte(pos - cursor))

   case BooleanType =>
    (stmt: PreparedStatement, row: Row, pos: Int,cursor:Int) =>
       stmt.setBoolean(pos + 1, row.getBoolean(pos - cursor))

   case StringType =>
    (stmt: PreparedStatement, row: Row, pos: Int,cursor:Int) =>
//       println(row.getString(pos))
       stmt.setString(pos + 1, row.getString(pos - cursor))

   case BinaryType =>
    (stmt: PreparedStatement, row: Row, pos: Int,cursor:Int) =>
       stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos - cursor))

   case TimestampType =>
    (stmt: PreparedStatement, row: Row, pos: Int,cursor:Int) =>
       stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos - cursor))

   case DateType =>
    (stmt: PreparedStatement, row: Row, pos: Int,cursor:Int) =>
       stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos - cursor))

   case t: DecimalType =>
    (stmt: PreparedStatement, row: Row, pos: Int,cursor:Int) =>
       stmt.setBigDecimal(pos + 1, row.getDecimal(pos - cursor))

   case ArrayType(et, _) =>
     // remove type length parameters from end of type name
     val typeName = getJdbcType(et, dialect).databaseTypeDefinition
      .toLowerCase.split("\\(")(0)
    (stmt: PreparedStatement, row: Row, pos: Int,cursor:Int) =>
       val array = conn.createArrayOf(
         typeName,
         row.getSeq[AnyRef](pos - cursor).toArray)
       stmt.setArray(pos + 1, array)

   case _ =>
    (_: PreparedStatement, _: Row, pos: Int,cursor:Int) =>
       throw new IllegalArgumentException(
         s"Can't translate non-null value for field $pos")
}

 

完整代碼:

https://github.com/niutaofan/bazinga

 


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM