NightPxy 个人技术博客

Sql-自定义数据源

Posted on By NightPxy

数据源核心类

package org.apache.spark.sql.sources 中 物理文件 sql\core\src\main\scala\org\apache\spark\sql\sources\interfaces.scala

数据源

BaseRelation

BaseRelation 代表一个抽象的数据源

/**
 * Abstract representation of a data source: something that is bound to a
 * SQLContext and exposes rows with a known schema.
 */
abstract class BaseRelation {
  /** The SQLContext this relation belongs to. */
  def sqlContext: SQLContext
  /** The schema of the rows exposed by this relation. */
  def schema: StructType

  /** Size of the relation in bytes; defaults to `sqlContext.conf.defaultSizeInBytes`. */
  def sizeInBytes: Long = sqlContext.conf.defaultSizeInBytes
   /**
   * Whether the values produced by the scan still have to be converted to the
   * internal representation, for example:
   *      java.lang.String => UTF8String
   *      java.lang.Decimal => Decimal
   *
   * If true (the default), the scan returns Row objects holding external types
   * and they are converted (Row => InternalRow).
   * If false, no column type conversion is performed and the scan is expected
   * to already return InternalRow.
   */
  def needConversion: Boolean = true

  // Filter handling: by default every filter passed in is returned unchanged.
  // NOTE(review): presumably the returned filters are the ones this source
  // cannot evaluate itself - confirm against the Spark API docs.
  def unhandledFilters(filters: Array[Filter]): Array[Filter] = filters
}

TableScan

/** Describes how a BaseRelation performs a full table scan. */
trait TableScan {
  /** Returns the rows of the relation as an RDD of Row. */
  def buildScan(): RDD[Row]
}

PrunedScan

/** Describes how a BaseRelation performs a scan with column pruning. */
trait PrunedScan {
  /**
   * @param requiredColumns names of the columns the scan has to produce
   */
  def buildScan(requiredColumns: Array[String]): RDD[Row]
}

PrunedFilteredScan

/** Describes how a BaseRelation performs a scan with column pruning plus predicate push-down. */
trait PrunedFilteredScan {
  /**
   * @param requiredColumns names of the columns the scan has to produce
   * @param filters         predicates pushed down to the data source
   */
  def buildScan(
    requiredColumns: Array[String], 
    filters: Array[Filter]): RDD[Row]
}

InsertableRelation

/** Describes how a BaseRelation writes data. */
trait InsertableRelation {
  /**
  * Writes `data` into the relation.
  *
  * overwrite: if true, the existing data should be replaced by the new data;
  *            if false, the new data should be appended.
  *
  * insert relies on the following three assumptions:
  *   1.  The field order of data: DataFrame is exactly the same as in the BaseRelation.
  *   2.  The schema is assumed to stay unchanged across the insert.
  *   3.  Fields of data: DataFrame are not guaranteed to be non-nullable;
  *       the implementation has to check nullability itself if it matters.
  */
  def insert(data: DataFrame, overwrite: Boolean): Unit
}

CatalystScan

/**
 * Describes a scan that is driven by Catalyst: the required columns and the
 * filters are passed as Catalyst `Attribute` / `Expression` objects rather than
 * as column names and `Filter` objects.
 */
trait CatalystScan {
  def buildScan(
    requiredColumns: Seq[Attribute], 
    filters: Seq[Expression]): RDD[Row]
}

数据源构造器

RelationProvider

数据源构造类,数据源以这个类的实现作为构造依托
如果按给定名称找不到对应的类,Spark SQL 会在该名称后追加 “.DefaultSource” 再试一次(例如 “org.apache.spark.sql.json” 会被解析为 “org.apache.spark.sql.json.DefaultSource”)
RelationProvider 表示以自推断Schema的形式创建数据源

/**
 * Factory for a data source whose schema is produced by the source itself
 * (no schema is passed in by the caller).
 */
trait RelationProvider {
  /**
   * @param sqlContext the SQLContext the relation is created for
   * @param parameters string key/value options supplied for the data source
   * @return the relation built from the given options
   */
  def createRelation(
  	sqlContext: SQLContext, 
  	parameters: Map[String, String]): BaseRelation
}

SchemaRelationProvider

与 RelationProvider 类似 不同在于

  • RelationProvider 是以内部推断的形式产生Schema
  • SchemaRelationProvider 是以用户传入的形式产生Schema
  • 具体采用哪种方式取决于数据源本身,但也可以同时继承这两个 trait,从而既可以内部推断又可以由外部传入
/**
 * Factory for a data source whose schema is supplied by the caller instead of
 * being produced by the source itself.
 */
trait SchemaRelationProvider {
  /**
   * @param sqlContext the SQLContext the relation is created for
   * @param parameters string key/value options supplied for the data source
   * @param schema     the schema passed in by the user
   * @return the relation built from the given options and schema
   */
  def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      schema: StructType): BaseRelation
}

CreatableRelationProvider

数据源的输出(写入)流程会调用此方法。如果要为数据源提供输出功能,就必须实现此 trait

/**
 * Factory used on the output (write) path of a data source; a source that
 * wants to support writing has to implement it.
 */
trait CreatableRelationProvider {
  /**
   * @param sqlContext the SQLContext the relation is created for
   * @param mode       the SaveMode requested for the write
   * @param parameters string key/value options supplied for the data source
   * @param data       the DataFrame to be written
   * @return a relation for the written target
   */
  def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      data: DataFrame): BaseRelation
}

StreamSourceProvider

/**
 * Factory for streaming sources.
 * NOTE(review): semantics below are read off the signatures only - confirm
 * against the Spark Structured Streaming API docs.
 */
trait StreamSourceProvider {
  /**
   * Resolves the schema of the source before it is created.
   *
   * @param schema optional user-supplied schema
   * @return a (String, StructType) pair; presumably the source name and its schema
   */
  def sourceSchema(
      sqlContext: SQLContext,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): (String, StructType)

  /**
   * Creates the streaming Source.
   *
   * @param metadataPath a path handed to the source; presumably where it may keep its metadata
   * @param schema       optional user-supplied schema
   */
  def createSource(
      sqlContext: SQLContext,
      metadataPath: String,
      schema: Option[StructType],
      providerName: String,
      parameters: Map[String, String]): Source
}

/** Factory for streaming sinks. */
trait StreamSinkProvider {
  /**
   * Creates the streaming Sink.
   *
   * @param parameters       string key/value options supplied for the sink
   * @param partitionColumns names of the partitioning columns
   * @param outputMode       the OutputMode requested for the streaming query
   */
  def createSink(
      sqlContext: SQLContext,
      parameters: Map[String, String],
      partitionColumns: Seq[String],
      outputMode: OutputMode): Sink
}

JDBC 源码解析

JDBC 数据源定义

org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation


/**
 * JDBC data source relation.
 *
 * @param parts        the partitions the table is read through
 * @param jdbcOptions  the JDBC options (url, table, ...) of this relation
 * @param sparkSession the owning session (transient, second parameter list)
 */
private[sql] case class JDBCRelation(
    parts: Array[Partition], 
    jdbcOptions: JDBCOptions)(@transient val sparkSession: SparkSession)
  extends BaseRelation // the data source abstraction
  with PrunedFilteredScan // supports column pruning + predicate push-down
  with InsertableRelation // supports writing

BaseRelation 实现

  // The SQLContext is taken from the owning SparkSession.
  override def sqlContext: SQLContext = sparkSession.sqlContext

  // No type conversion: the JDBC scan is treated as already producing
  // InternalRow, so the scan result is only cast to RDD[Row] and never converted.
  override val needConversion: Boolean = false

  // Schema of the JDBC relation: resolved from the table itself, then
  // overridden by the user-supplied `customSchema` option when one is present.
  override val schema: StructType = {
    val inferred = JDBCRDD.resolveTable(jdbcOptions)
    jdbcOptions.customSchema
      .map { userSchema =>
        JdbcUtils.getCustomSchema(
          inferred, userSchema, sparkSession.sessionState.conf.resolver)
      }
      .getOrElse(inferred)
  }

  // A filter is reported as unhandled when JDBCRDD.compileFilter yields nothing
  // for it under the dialect of this relation's JDBC url.
  override def unhandledFilters(filters: Array[Filter]): Array[Filter] =
    filters.filterNot { candidate =>
      JDBCRDD.compileFilter(candidate, JdbcDialects.get(jdbcOptions.url)).isDefined
    }

PrunedFilteredScan 实现

override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
    // Delegate the pruned + filtered scan to JDBCRDD over this relation's partitions.
    val scanned = JDBCRDD.scanTable(
      sparkSession.sparkContext, schema, requiredColumns, filters, parts, jdbcOptions)
    // The result is only cast (not converted) to RDD[Row].
    scanned.asInstanceOf[RDD[Row]]
  }

InsertableRelation 实现

override def insert(data: DataFrame, overwrite: Boolean): Unit = {
    // Map the boolean flag onto the corresponding SaveMode, then write through
    // the DataFrame JDBC writer using this relation's url, table and properties.
    val saveMode = if (overwrite) SaveMode.Overwrite else SaveMode.Append
    val writer = data.write.mode(saveMode)
    writer.jdbc(jdbcOptions.url, jdbcOptions.table, jdbcOptions.asProperties)
  }

JDBC数据源构造器

org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider


// Factory (provider) of the JDBC data source.

class JdbcRelationProvider 
  extends CreatableRelationProvider // the JDBC data source supports output (writing)
  with RelationProvider // the schema of the JDBC data source is produced internally
  with DataSourceRegister // registers the short name of the JDBC data source

RelationProvider 实现


// Read-side factory method. Because this is the RelationProvider variant, no
// schema is passed in: the schema has to be produced by the relation itself
// (the original author notes that JDBC derives it from the schema information
// returned by the database - column names, column types, ...).
//
// The partitioning options must be either all absent or all present; when no
// partition column is given, `null` is handed to JDBCRelation.columnPartition.

override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    import JDBCOptions._

    val jdbcOptions = new JDBCOptions(parameters)
    val lowerBound = jdbcOptions.lowerBound
    val upperBound = jdbcOptions.upperBound
    val numPartitions = jdbcOptions.numPartitions

    val partitionInfo = jdbcOptions.partitionColumn match {
      case None =>
        assert(lowerBound.isEmpty && upperBound.isEmpty, "When 'partitionColumn' is not specified, " +
          s"'$JDBC_LOWER_BOUND' and '$JDBC_UPPER_BOUND' are expected to be empty")
        null
      case Some(column) =>
        assert(lowerBound.nonEmpty && upperBound.nonEmpty && numPartitions.nonEmpty,
          s"When 'partitionColumn' is specified, '$JDBC_LOWER_BOUND', '$JDBC_UPPER_BOUND', and " +
            s"'$JDBC_NUM_PARTITIONS' are also required")
        JDBCPartitioningInfo(column, lowerBound.get, upperBound.get, numPartitions.get)
    }

    // Build the partition set and return the JDBC relation.
    val parts = JDBCRelation.columnPartition(partitionInfo)
    JDBCRelation(parts, jdbcOptions)(sqlContext.sparkSession)
  }

CreatableRelationProvider 实现


// Write-side (output) factory method of the JDBC data source.
//
// Opens one connection, saves `df` according to `mode` and to whether the
// target table already exists, always closes the connection, and finally
// returns a relation for the target by delegating to the read-side
// createRelation(sqlContext, parameters).

override def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      df: DataFrame): BaseRelation = {
    val options = new JDBCOptions(parameters)
    val isCaseSensitive = sqlContext.conf.caseSensitiveAnalysis

    val conn = JdbcUtils.createConnectionFactory(options)()
    try {
      if (!JdbcUtils.tableExists(conn, options)) {
        // No table yet: create it from the DataFrame schema and load it.
        createTable(conn, df, options)
        saveTable(df, Some(df.schema), isCaseSensitive, options)
      } else {
        mode match {
          case SaveMode.Overwrite =>
            val truncateInPlace =
              options.isTruncate && isCascadingTruncateTable(options.url) == Some(false)
            if (truncateInPlace) {
              // Truncate the existing table and load into its current schema.
              truncateTable(conn, options)
              val existingSchema = JdbcUtils.getSchemaOption(conn, options)
              saveTable(df, existingSchema, isCaseSensitive, options)
            } else {
              // Do not truncate: drop the table, recreate it, then load.
              dropTable(conn, options.table)
              createTable(conn, df, options)
              saveTable(df, Some(df.schema), isCaseSensitive, options)
            }

          case SaveMode.Append =>
            val existingSchema = JdbcUtils.getSchemaOption(conn, options)
            saveTable(df, existingSchema, isCaseSensitive, options)

          case SaveMode.ErrorIfExists =>
            throw new AnalysisException(
              s"Table or view '${options.table}' already exists. SaveMode: ErrorIfExists.")

          case SaveMode.Ignore =>
            // The table exists, so nothing is written and the existing data is
            // left untouched; the relation is simply returned below.
        }
      }
    } finally {
      conn.close()
    }

    createRelation(sqlContext, parameters)
  }

JDBC RDD

// This is the core of the JDBC RDD.
// A JDBC partition is essentially a SQL condition: each partition pulls a
// different slice of the data by running the query with a different WHERE
// condition.
// The set of conditions (the partition set) is assembled in the Relation and
// handed to JDBCRDD, which then fetches the data partition by partition.
//
// whereClause: the SQL condition selecting this partition's rows
// idx:         the index of this partition
case class JDBCPartition(whereClause: String, idx: Int) extends Partition {
  override def index: Int = idx
}

private[jdbc] class JDBCRDD(
    sc: SparkContext,
    getConnection: () => Connection,
    schema: StructType,
    columns: Array[String],
    filters: Array[Filter],
    partitions: Array[Partition],
    url: String,
    options: JDBCOptions)
  extends RDD[InternalRow](sc, Nil) {

  /**
   * Retrieve the list of partitions corresponding to this RDD.
   * They are exactly the `partitions` array passed to the constructor.
   */
  override def getPartitions: Array[Partition] = partitions

  /**
   * Comma-separated projection list built from `columns`, ready to be spliced
   * into the SELECT statement. When no column is requested the literal `1` is
   * selected instead.
   */
  private val columnList: String =
    if (columns.isEmpty) "1" else columns.mkString(",")

  /**
   * The pushed-down `filters` rendered as one SQL condition: every filter that
   * JDBCRDD.compileFilter can translate is parenthesised and the pieces are
   * joined with AND. Empty when nothing could be translated.
   */
  private val filterWhereClause: String = {
    val compiled = filters.flatMap(f => JDBCRDD.compileFilter(f, JdbcDialects.get(url)))
    val parenthesised = compiled.map(predicate => s"($predicate)")
    parenthesised.mkString(" AND ")
  }

  /**
   * Builds the WHERE clause for one partition by combining the pushed-down
   * filter condition (if any) with the partition's own condition (if any).
   * Returns the empty string when neither is present.
   */
  private def getWhereClause(part: JDBCPartition): String = {
    val hasFilters = filterWhereClause.length > 0
    Option(part.whereClause) match {
      case Some(partitionPredicate) if hasFilters =>
        s"WHERE ($filterWhereClause) AND ($partitionPredicate)"
      case Some(partitionPredicate) =>
        s"WHERE $partitionPredicate"
      case None if hasFilters =>
        s"WHERE $filterWhereClause"
      case None =>
        ""
    }
  }

  /**
   * Runs the SQL query for one partition against the JDBC driver.
   *
   * Opens a connection, lets the dialect prepare it, optionally runs the
   * configured session-init statement, then executes
   * `SELECT <columns> FROM <table> <where>` and exposes the result set as an
   * iterator of InternalRow.
   *
   * @param thePart the partition to compute; must be a JDBCPartition
   * @param context the task context, used for cleanup registration and input metrics
   * @return an iterator over the rows of this partition; the JDBC resources are
   *         released through `close()`, which is registered both as a task
   *         completion listener and as the CompletionIterator callback
   */
  override def compute(thePart: Partition, context: TaskContext): Iterator[InternalRow] = {
    // The JDBC handles are assigned further down but must be reachable from
    // close(), which can run at any point after registration - hence the vars.
    var closed = false
    var rs: ResultSet = null
    var stmt: PreparedStatement = null
    var conn: Connection = null

    // Idempotent, best-effort cleanup: each resource is closed independently so
    // a failure on one does not prevent closing the others.
    def close(): Unit = {
      if (!closed) {
        try {
          if (null != rs) {
            rs.close()
          }
        } catch {
          case e: Exception => logWarning("Exception closing resultset", e)
        }
        try {
          if (null != stmt) {
            stmt.close()
          }
        } catch {
          case e: Exception => logWarning("Exception closing statement", e)
        }
        try {
          if (null != conn) {
            // Commit an open transaction before closing when auto-commit is off.
            if (!conn.isClosed && !conn.getAutoCommit) {
              try {
                conn.commit()
              } catch {
                case NonFatal(e) => logWarning("Exception committing transaction", e)
              }
            }
            conn.close()
          }
          logInfo("closed connection")
        } catch {
          case e: Exception => logWarning("Exception closing connection", e)
        }
        closed = true
      }
    }

    // Registered before any resource is opened so cleanup also runs when the
    // task ends without the iterator being fully consumed.
    context.addTaskCompletionListener{ _ => close() }

    val inputMetrics = context.taskMetrics().inputMetrics
    val part = thePart.asInstanceOf[JDBCPartition]
    conn = getConnection()
    val dialect = JdbcDialects.get(url)
    import scala.collection.JavaConverters._
    dialect.beforeFetch(conn, options.asProperties.asScala.toMap)

    // This executes a generic SQL statement (or PL/SQL block) before reading
    // the table/query via JDBC. Use this feature to initialize the database
    // session environment, e.g. for optimizations and/or troubleshooting.
    options.sessionInitStatement.foreach { sql =>
      val statement = conn.prepareStatement(sql)
      logInfo(s"Executing sessionInitStatement: $sql")
      try {
        statement.execute()
      } finally {
        statement.close()
      }
    }

    // H2's JDBC driver does not support the setSchema() method.  We pass a
    // fully-qualified table name in the SELECT statement.  I don't know how to
    // talk about a table in a completely portable way.

    val myWhereClause = getWhereClause(part)

    val sqlText = s"SELECT $columnList FROM ${options.table} $myWhereClause"
    stmt = conn.prepareStatement(sqlText,
        ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
    stmt.setFetchSize(options.fetchSize)
    rs = stmt.executeQuery()
    val rowsIterator = JdbcUtils.resultSetToSparkInternalRows(rs, schema, inputMetrics)

    CompletionIterator[InternalRow, Iterator[InternalRow]](
      new InterruptibleIterator(context, rowsIterator), close())
  }