首頁 > 軟體

教你如何讓spark sql寫mysql的時候支援update操作

2022-02-15 19:01:40

如何讓sparkSQL在對接mysql的時候,除了支援:Append、Overwrite、ErrorIfExists、Ignore;還要在支援update操作

1、首先了解背景

spark提供了一個列舉類,用來支撐對接資料來源的操作模式

通過原始碼檢視,很明顯,spark是不支援update操作的

2、如何讓sparkSQL支援update

關鍵的知識點就是:

我們正常在sparkSQL寫資料到mysql的時候:

大概的api是:

dataframe.write
          .format("sql.execution.customDatasource.jdbc")
          .option("jdbc.driver", "com.mysql.jdbc.Driver")
          .option("jdbc.url", "jdbc:mysql://localhost:3306/test?user=root&password=&useUnicode=true&characterEncoding=gbk&autoReconnect=true&failOverReadOnly=false")
          .option("jdbc.db", "test")
          .save()

那麼在底層中,spark會通過JDBC方言JdbcDialect , 將我們要插入的資料翻譯成:

insert into student (columns_1 , columns_2 , ...) values (? , ? , ....)

那麼通過方言解析出的sql語句就通過PrepareStatement的executeBatch(),將sql語句提交給mysql,然後資料插入;

那麼上面的sql語句很明顯,完全就是插入程式碼,並沒有我們期望的 update操作,類似:

UPDATE table_name SET field1=new-value1, field2=new-value2

但是mysql獨家支援這樣的sql語句:

INSERT INTO student (columns_1,columns_2)VALUES ('第一個欄位值','第二個欄位值') ON DUPLICATE KEY UPDATE columns_1 = '呵呵噠',columns_2 = '哈哈噠';

大概的意思就是,如果資料不存在則插入,如果資料存在,則 執行update操作;

因此,我們的切入點就是,讓sparkSQL內部對接JdbcDialect的時候,能夠生成這種sql:

INSERT INTO 表名稱 (columns_1,columns_2)VALUES ('第一個欄位值','第二個欄位值') ON DUPLICATE KEY UPDATE columns_1 = '呵呵噠',columns_2 = '哈哈噠';

3、改造原始碼前,需要了解整體的程式碼設計和執行流程

首先是:

dataframe.write

呼叫write方法就是為了返回一個類:DataFrameWriter

主要是因為DataFrameWriter是sparksql對接外部資料來源寫入的入口攜帶類,下面這些內容是給DataFrameWriter註冊的攜帶資訊

然後在出發save()操作後,就開始將資料寫入;

接下來看save()原始碼:

在上面的原始碼裡面主要是註冊DataSource範例,然後使用DataSource的write方法進行資料寫入

範例化DataSource的時候:

def save(): Unit = {
    assertNotBucketed("save")
    val dataSource = DataSource(
      df.sparkSession,
      className = source,//自定義資料來源的包路徑
      partitionColumns = partitioningColumns.getOrElse(Nil),//分割區欄位
      bucketSpec = getBucketSpec,//分桶(用於hive)
      options = extraOptions.toMap)//傳入的註冊資訊
    //mode:插入資料方式SaveMode , df:要插入的資料
    dataSource.write(mode, df)
  }

然後就是dataSource.write(mode, df)的細節,整段的邏輯就是:

根據providingClass.newInstance()去做模式匹配,然後匹配到哪裡,就執行哪裡的程式碼;

然後看下providingClass是什麼:

拿到包路徑.DefaultSource之後,程式進入:

那麼如果是資料庫作為寫入目標的話,就會走:dataSource.createRelation,直接跟進原始碼:

很明顯是個特質,因此哪裡實現了特質,程式就會走到哪裡了;

實現這個特質的地方就是:包路徑.DefaultSource , 然後就在這裡面去實現資料的插入和update的支援操作;

4、改造原始碼

根據程式碼的流程,最終sparkSQL 將資料寫入mysql的操作,會進入:包路徑.DefaultSource這個類裡面;

也就是說,在這個類裡面既要支援spark的正常插入操作(SaveMode),還要在支援update;

如果讓sparksql支援update操作,最關鍵的就是做一個判斷,比如:

if(isUpdate){
    sql語句:INSERT INTO student (columns_1,columns_2)VALUES ('第一個欄位值','第二個欄位值') ON DUPLICATE KEY UPDATE columns_1 = '呵呵噠',columns_2 = '哈哈噠';
}else{
    insert into student (columns_1 , columns_2 , ...) values (? , ? , ....)
}

但是,在spark生產sql語句的原始碼中,是這樣寫的:

沒有任何的判斷邏輯,就是最後生成一個:

INSERT INTO TABLE (欄位1 , 欄位2....) VALUES (? , ? ...)

所以首要的任務就是 ,怎麼能讓當前程式碼支援:ON DUPLICATE KEY UPDATE

可以做個大膽的設計,就是在insertStatement這個方法中做個如下的判斷

def insertStatement(conn: Connection, savemode:CustomSaveMode , table: String, rddSchema: StructType, dialect: JdbcDialect)
      : PreparedStatement = {
    val columns = rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
    val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
    if(savemode == CustomSaveMode.update){
        //TODO 如果是update,就組裝成ON DUPLICATE KEY UPDATE的模式處理
        s"INSERT INTO $table ($columns) VALUES ($placeholders) ON DUPLICATE KEY UPDATE $duplicateSetting"
    }esle{
        val sql = s"INSERT INTO $table ($columns) VALUES ($placeholders)"
        conn.prepareStatement(sql)
    }
    
  }

這樣,在使用者傳遞進來的savemode模式,我們進行校驗,如果是update操作,就返回對應的sql語句!

所以按照上面的邏輯,我們程式碼這樣寫:

這樣我們就拿到了對應的sql語句;

但是隻有這個sql語句還是不行的,因為在spark中會執行jdbc的prepareStatement操作,這裡面會涉及到遊標。

即jdbc在遍歷這個sql的時候,原始碼會這樣做:

看下makeSetter:

所謂有坑就是:

insert into table (欄位1 , 欄位2, 欄位3) values (? , ? , ?)

那麼當前在原始碼中返回的陣列長度應該是3:

val setters: Array[JDBCValueSetter] = rddSchema.fields.map(_.dataType)
        .map(makeSetter(conn, dialect, _)).toArray

但是如果我們此時支援了update操作,既:

insert into table (欄位1 , 欄位2, 欄位3) values (? , ? , ?) ON DUPLICATE KEY UPDATE 欄位1 = ?,欄位2 = ?,欄位3=?;

那麼很明顯,上面的sql語句提供了6個? , 但在規定欄位長度的時候只有3

這樣的話,後面的update操作就無法執行,程式報錯!

所以我們需要有一個 識別機制,既:

if(isupdate){
    val numFields = rddSchema.fields.length * 2
}else{
    val numFields = rddSchema.fields.length
}

row[1,2,3] setter(0,1) //index of setter , index of row setter(1,2) setter(2,3) setter(3,1) setter(4,2) setter(5,3)

所以在prepareStatment中的預留位置應該是row的兩倍,而且應該是類似這樣的一個邏輯

因此,程式碼改造前樣子:

改造後的樣子:

try {
      if (supportsTransactions) {
        conn.setAutoCommit(false) // Everything in the same db transaction.
        conn.setTransactionIsolation(finalIsolationLevel)
      }
//      val stmt = insertStatement(conn, table, rddSchema, dialect)
      //此處採用最新自己的sql語句,封裝成prepareStatement
      val stmt = conn.prepareStatement(sqlStmt)
      println(sqlStmt)
      /**
        * 在mysql中有這樣的操作:
        * INSERT INTO user_admin_t (_id,password) VALUES ('1','第一次插入的密碼')
        * INSERT INTO user_admin_t (_id,password)VALUES ('1','第一次插入的密碼') ON DUPLICATE KEY UPDATE _id = 'UpId',password = 'upPassword';
        * 如果是下面的ON DUPLICATE KEY操作,那麼在prepareStatement中的遊標會擴增一倍
        * 並且如果沒有update操作,那麼他的遊標是從0開始計數的
        * 如果是update操作,要算上之前的insert操作
        * */
        //makeSetter也要適配update操作,即遊標問題
​
      val isUpdate = saveMode == CustomSaveMode.Update
      val setters: Array[JDBCValueSetter] = isUpdate match {
        case true =>
          val setters: Array[JDBCValueSetter] = rddSchema.fields.map(_.dataType)
            .map(makeSetter(conn, dialect, _)).toArray
          Array.fill(2)(setters).flatten
        case _ =>
          rddSchema.fields.map(_.dataType)
      val numFieldsLength = rddSchema.fields.length
      val numFields = isUpdate match{
        case true => numFieldsLength *2
        case _ => numFieldsLength
      val cursorBegin = numFields / 2
      try {
        var rowCount = 0
        while (iterator.hasNext) {
          val row = iterator.next()
          var i = 0
          while (i < numFields) {
            if(isUpdate){
              //需要判斷當前遊標是否走到了ON DUPLICATE KEY UPDATE
              i < cursorBegin match{
                  //說明還沒走到update階段
                case true =>
                  //row.isNullAt 判空,則設定空值
                  if (row.isNullAt(i)) {
                    stmt.setNull(i + 1, nullTypes(i))
                  } else {
                    setters(i).apply(stmt, row, i, 0)
                  }
                  //說明走到了update階段
                case false =>
                  if (row.isNullAt(i - cursorBegin)) {
                    //pos - offset
                    stmt.setNull(i + 1, nullTypes(i - cursorBegin))
                    setters(i).apply(stmt, row, i, cursorBegin)
              }
            }else{
              if (row.isNullAt(i)) {
                stmt.setNull(i + 1, nullTypes(i))
              } else {
                setters(i).apply(stmt, row, i ,0)
            }
            //捲動遊標
            i = i + 1
          }
          stmt.addBatch()
          rowCount += 1
          if (rowCount % batchSize == 0) {
            stmt.executeBatch()
            rowCount = 0
        }
        if (rowCount > 0) {
          stmt.executeBatch()
      } finally {
        stmt.close()
        conn.commit()
      committed = true
      Iterator.empty
    } catch {
      case e: SQLException =>
        val cause = e.getNextException
        if (cause != null && e.getCause != cause) {
          if (e.getCause == null) {
            e.initCause(cause)
          } else {
            e.addSuppressed(cause)
        throw e
    } finally {
      if (!committed) {
        // The stage must fail.  We got here through an exception path, so
        // let the exception through unless rollback() or close() want to
        // tell the user about another problem.
        if (supportsTransactions) {
          conn.rollback()
        conn.close()
      } else {
        // The stage must succeed.  We cannot propagate any exception close() might throw.
        try {
          conn.close()
        } catch {
          case e: Exception => logWarning("Transaction succeeded, but closing failed", e)
// A `JDBCValueSetter` is responsible for setting a value from `Row` into a field for
  // `PreparedStatement`. The last argument `Int` means the index for the value to be set
  // in the SQL statement and also used for the value in `Row`.
  //PreparedStatement, Row, position , cursor
  private type JDBCValueSetter = (PreparedStatement, Row, Int , Int) => Unit
​
  private def makeSetter(
      conn: Connection,
      dialect: JdbcDialect,
      dataType: DataType): JDBCValueSetter = dataType match {
    case IntegerType =>
      (stmt: PreparedStatement, row: Row, pos: Int,cursor:Int) =>
        stmt.setInt(pos + 1, row.getInt(pos - cursor))
    case LongType =>
        stmt.setLong(pos + 1, row.getLong(pos - cursor))
    case DoubleType =>
        stmt.setDouble(pos + 1, row.getDouble(pos - cursor))
    case FloatType =>
        stmt.setFloat(pos + 1, row.getFloat(pos - cursor))
    case ShortType =>
        stmt.setInt(pos + 1, row.getShort(pos - cursor))
    case ByteType =>
        stmt.setInt(pos + 1, row.getByte(pos - cursor))
    case BooleanType =>
        stmt.setBoolean(pos + 1, row.getBoolean(pos - cursor))
    case StringType =>
//        println(row.getString(pos))
        stmt.setString(pos + 1, row.getString(pos - cursor))
    case BinaryType =>
        stmt.setBytes(pos + 1, row.getAs[Array[Byte]](pos - cursor))
    case TimestampType =>
        stmt.setTimestamp(pos + 1, row.getAs[java.sql.Timestamp](pos - cursor))
    case DateType =>
        stmt.setDate(pos + 1, row.getAs[java.sql.Date](pos - cursor))
    case t: DecimalType =>
        stmt.setBigDecimal(pos + 1, row.getDecimal(pos - cursor))
    case ArrayType(et, _) =>
      // remove type length parameters from end of type name
      val typeName = getJdbcType(et, dialect).databaseTypeDefinition
        .toLowerCase.split("\(")(0)
        val array = conn.createArrayOf(
          typeName,
          row.getSeq[AnyRef](pos - cursor).toArray)
        stmt.setArray(pos + 1, array)
    case _ =>
      (_: PreparedStatement, _: Row, pos: Int,cursor:Int) =>
        throw new IllegalArgumentException(
          s"Can't translate non-null value for field $pos")
  }

完整程式碼:

https://github.com/niutaofan/bazinga

到此這篇關於教你如何讓spark sql寫mysql的時候支援update操作的文章就介紹到這了,更多相關spark sql寫mysql支援update內容請搜尋it145.com以前的文章或繼續瀏覽下面的相關文章希望大家以後多多支援it145.com!


IT145.com E-mail:sddin#qq.com