1. Saving data to a database by calling dataset.write.format("jdbc").save
dataset.write.mode("append").format("jdbc")
  .option("driver", "oracle.jdbc.driver.OracleDriver")
  .option("url", "xxx")
  .option("dbtable", "tablename")
  .option("user", "xxx")
  .option("password", "xxx")
  .option("truncate", "true")
  .option("isolationLevel", "NONE") // do not use transactions
  .option("batchsize", 5000)        // batch size for inserts
  .save();
dataset.write returns a DataFrameWriter, the entry point for writing a Dataset to an external storage system. mode sets the save mode, e.g. append or overwrite. format selects the underlying output data source, e.g. jdbc, parquet or json. option adds output options for that data source, which is how the user, password, table, driver, batch size and so on are specified. Finally, save writes the contents of the DataFrame to the specified table.
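The same call can be written step by step to make the intermediate objects explicit (a minimal sketch, assuming dataset is a DataFrame, i.e. Dataset[Row]; the connection options are the placeholders from the example above):

import org.apache.spark.sql.{DataFrameWriter, Row, SaveMode}

// dataset.write returns a DataFrameWriter[Row]; mode/format/option all return the same writer
val writer: DataFrameWriter[Row] = dataset.write
  .mode(SaveMode.Append)          // equivalent to mode("append")
  .format("jdbc")                 // selects the JDBC data source
  .option("url", "xxx")
  .option("dbtable", "tablename")
  .option("user", "xxx")
  .option("password", "xxx")

writer.save()                     // nothing is written until save() is called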
2. Overall API call flow
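In short, the call chain traced in section 3 below is:
save() -> saveToV1Source() -> planForWriting() -> SaveIntoDataSourceCommand.run() -> JdbcRelationProvider.createRelation() -> saveTable() -> savePartition()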
3. Source code analysis
->save()
Inside save(), lookupDataSource() is called to resolve the source type and choose the write path. Since "jdbc" is passed in here, the returned cls is
org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider
which does not implement DataSourceV2, so execution continues into saveToV1Source().
def save(): Unit = {
  if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) {
    throw new AnalysisException("Hive data source can only be used with tables, you can not " +
      "write files of Hive data source directly.")
  }

  assertNotBucketed("save")

  val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf)
  if (classOf[DataSourceV2].isAssignableFrom(cls)) {
    val ds = cls.newInstance()
    ds match {
      case ws: WriteSupport =>
        val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
          ds = ds.asInstanceOf[DataSourceV2],
          conf = df.sparkSession.sessionState.conf)
        val options = new DataSourceOptions((sessionOptions ++ extraOptions).asJava)
        // Using a timestamp and a random UUID to distinguish different writing jobs. This is good
        // enough as there won't be tons of writing jobs created at the same second.
        val jobId = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US)
          .format(new Date()) + "-" + UUID.randomUUID()
        val writer = ws.createWriter(jobId, df.logicalPlan.schema, mode, options)
        if (writer.isPresent) {
          runCommand(df.sparkSession, "save") {
            WriteToDataSourceV2(writer.get(), df.planWithBarrier)
          }
        }

      // Streaming also uses the data source V2 API. So it may be that the data source implements
      // v2, but has no v2 implementation for batch writes. In that case, we fall back to saving
      // as though it's a V1 source.
      case _ => saveToV1Source()
    }
  } else {
    saveToV1Source()
  }
}
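Because lookupDataSource also accepts a fully-qualified provider class name, the short name "jdbc" and the provider class resolve to the same source; the following is equivalent to the .format("jdbc") call from section 1 (a sketch with placeholder connection options):

dataset.write.mode("append")
  .format("org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider")
  .option("url", "xxx")
  .option("dbtable", "tablename")
  .option("user", "xxx")
  .option("password", "xxx")
  .save()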
->saveToV1Source()
saveToV1Source() simply wraps DataSource(...).planForWriting(mode, df.planWithBarrier) in runCommand:
private def saveToV1Source(): Unit = {
  // Code path for data source v1.
  runCommand(df.sparkSession, "save") {
    DataSource(
      sparkSession = df.sparkSession,
      className = source,
      partitionColumns = partitioningColumns.getOrElse(Nil),
      options = extraOptions.toMap).planForWriting(mode, df.planWithBarrier)
  }
}
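Note that extraOptions is simply everything collected from the .option() calls; DataSource exposes it as the case-insensitive caseInsensitiveOptions map used below, and the JDBC options class wraps it the same way later, so JDBC option keys are matched case-insensitively. For example, both of the following configure the same 5000-row batch size:

dataset.write.format("jdbc").option("batchsize", "5000")
dataset.write.format("jdbc").option("batchSize", "5000")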
->planForWriting()
providingClass here is initialized lazily as
lazy val providingClass: Class[_] =
  DataSource.lookupDataSource(className, sparkSession.sessionState.conf)
Since the className passed in is "jdbc", providingClass again resolves to
org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider
JdbcRelationProvider extends CreatableRelationProvider, so the first case below matches and the write is planned as a
SaveIntoDataSourceCommand()
def planForWriting(mode: SaveMode, data: LogicalPlan): LogicalPlan = {
  if (data.schema.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) {
    throw new AnalysisException("Cannot save interval data type into external storage.")
  }

  providingClass.newInstance() match {
    case dataSource: CreatableRelationProvider =>
      SaveIntoDataSourceCommand(data, dataSource, caseInsensitiveOptions, mode)
    case format: FileFormat =>
      planForWritingFileFormat(format, mode, data)
    case _ =>
      sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.")
  }
}
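For reference, any class implementing CreatableRelationProvider is matched by the first case above. A minimal, purely illustrative provider (CountingProvider is hypothetical, not part of Spark) would look roughly like this; SaveIntoDataSourceCommand.run() would end up calling its createRelation with the DataFrame to be written:

import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider}
import org.apache.spark.sql.types.StructType

class CountingProvider extends CreatableRelationProvider {
  override def createRelation(
      ctx: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      data: DataFrame): BaseRelation = {
    // a real provider would persist `data` here; this one only counts the rows
    println(s"would write ${data.count()} rows with mode $mode and options $parameters")
    new BaseRelation {
      override def sqlContext: SQLContext = ctx
      override def schema: StructType = data.schema
    }
  }
}

It would be used like any other source, e.g. df.write.format("com.example.CountingProvider").save(), assuming the class is on the classpath (the package name is made up for the example).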
->SaveIntoDataSourceCommand()
Its run() method calls createRelation on the data source:
case class SaveIntoDataSourceCommand(
    query: LogicalPlan,
    dataSource: CreatableRelationProvider,
    options: Map[String, String],
    mode: SaveMode) extends RunnableCommand {

  override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query)

  override def run(sparkSession: SparkSession): Seq[Row] = {
    dataSource.createRelation(
      sparkSession.sqlContext, mode, options, Dataset.ofRows(sparkSession, query))

    Seq.empty[Row]
  }

  override def simpleString: String = {
    val redacted = SQLConf.get.redactOptions(options)
    s"SaveIntoDataSourceCommand ${dataSource}, ${redacted}, ${mode}"
  }
}
->createRelation()
Here the provider first checks whether the target table already exists in the database. If it does, then depending on the save mode and the truncate option it either truncates the table, or drops and recreates it, before calling saveTable to write the data; if it does not exist, it calls createTable to create it and then saveTable. Note the final createRelation(sqlContext, parameters) call at the end, which builds the read-path relation (a JDBCRelation) that is returned to the caller.
trait CreatableRelationProvider {
  def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      data: DataFrame): BaseRelation
}
This is an abstract method declared on the trait; the implementation that actually runs here is the one in JdbcRelationProvider:
override def createRelation(
    sqlContext: SQLContext,
    mode: SaveMode,
    parameters: Map[String, String],
    df: DataFrame): BaseRelation = {
  val options = new JDBCOptions(parameters)
  val isCaseSensitive = sqlContext.conf.caseSensitiveAnalysis

  val conn = JdbcUtils.createConnectionFactory(options)()
  try {
    val tableExists = JdbcUtils.tableExists(conn, options)
    if (tableExists) {
      mode match {
        case SaveMode.Overwrite =>
          if (options.isTruncate && isCascadingTruncateTable(options.url) == Some(false)) {
            // In this case, we should truncate table and then load.
            truncateTable(conn, options)
            val tableSchema = JdbcUtils.getSchemaOption(conn, options)
            saveTable(df, tableSchema, isCaseSensitive, options)
          } else {
            // Otherwise, do not truncate the table, instead drop and recreate it
            dropTable(conn, options.table)
            createTable(conn, df, options)
            saveTable(df, Some(df.schema), isCaseSensitive, options)
          }

        case SaveMode.Append =>
          val tableSchema = JdbcUtils.getSchemaOption(conn, options)
          saveTable(df, tableSchema, isCaseSensitive, options)

        case SaveMode.ErrorIfExists =>
          throw new AnalysisException(
            s"Table or view '${options.table}' already exists. SaveMode: ErrorIfExists.")

        case SaveMode.Ignore =>
          // With `SaveMode.Ignore` mode, if table already exists, the save operation is expected
          // to not save the contents of the DataFrame and to not change the existing data.
          // Therefore, it is okay to do nothing here and then just return the relation below.
      }
    } else {
      createTable(conn, df, options)
      saveTable(df, Some(df.schema), isCaseSensitive, options)
    }
  } finally {
    conn.close()
  }

  createRelation(sqlContext, parameters)
}
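The practical difference between the two Overwrite branches: with truncate=true (and a dialect whose TRUNCATE does not cascade), the existing table definition is kept and only its rows are replaced; by default the table is dropped and recreated from df.schema. A sketch of both variants, with placeholder options:

// truncate=true: TRUNCATE TABLE then INSERT; indexes, grants and column types survive
df.write.mode("overwrite").format("jdbc")
  .option("url", "xxx").option("dbtable", "tablename")
  .option("truncate", "true")
  .save()

// default overwrite: DROP TABLE, CREATE TABLE from df.schema, then INSERT
df.write.mode("overwrite").format("jdbc")
  .option("url", "xxx").option("dbtable", "tablename")
  .save()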
->saveTable()
This resolves the concrete JdbcDialect for the underlying database from the URL, builds the INSERT statement, and then calls savePartition() on each partition to do the actual writing.
def saveTable(
    df: DataFrame,
    tableSchema: Option[StructType],
    isCaseSensitive: Boolean,
    options: JDBCOptions): Unit = {
  val url = options.url
  val table = options.table
  val dialect = JdbcDialects.get(url)
  val rddSchema = df.schema
  val getConnection: () => Connection = createConnectionFactory(options)
  val batchSize = options.batchSize
  val isolationLevel = options.isolationLevel

  val insertStmt = getInsertStatement(table, rddSchema, tableSchema, isCaseSensitive, dialect)

  val repartitionedDF = options.numPartitions match {
    case Some(n) if n <= 0 => throw new IllegalArgumentException(
      s"Invalid value `$n` for parameter `${JDBCOptions.JDBC_NUM_PARTITIONS}` in table writing " +
        "via JDBC. The minimum value is 1.")
    case Some(n) if n < df.rdd.getNumPartitions => df.coalesce(n)
    case _ => df
  }
  repartitionedDF.rdd.foreachPartition(iterator => savePartition(
    getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel)
  )
}
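One detail worth noting: numPartitions only ever reduces the number of partitions (via coalesce), never increases it, so it effectively caps the number of concurrent JDBC connections. A sketch, assuming df currently has more than 8 partitions:

// saveTable will call df.coalesce(8): at most 8 tasks write in parallel,
// i.e. at most 8 simultaneous connections to the database
df.write.mode("append").format("jdbc")
  .option("url", "xxx").option("dbtable", "tablename")
  .option("numPartitions", "8")
  .option("batchsize", "5000")
  .save()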
->savePartition()
This is where the rows are actually written in batches through the plain JDBC API.
def savePartition(
    getConnection: () => Connection,
    table: String,
    iterator: Iterator[Row],
    rddSchema: StructType,
    insertStmt: String,
    batchSize: Int,
    dialect: JdbcDialect,
    isolationLevel: Int): Iterator[Byte] = {
  val conn = getConnection()
  ......
  val stmt = conn.prepareStatement(insertStmt)
  ......
  stmt.addBatch()
  ......
}
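Stripped of dialect handling, type-specific setters, transactions and error handling, the loop inside savePartition boils down to the standard JDBC batching pattern. The sketch below is a simplification, not the actual Spark code (writeBatch and its Iterator[Seq[Any]] input are made up for illustration):

import java.sql.Connection

def writeBatch(conn: Connection, insertStmt: String, rows: Iterator[Seq[Any]], batchSize: Int): Unit = {
  val stmt = conn.prepareStatement(insertStmt)
  try {
    var rowCount = 0
    while (rows.hasNext) {
      val row = rows.next()
      // Spark picks a typed setter per field from the schema (setLong, setString, ...);
      // setObject is used here only to keep the sketch short
      row.zipWithIndex.foreach { case (v, i) => stmt.setObject(i + 1, v.asInstanceOf[AnyRef]) }
      stmt.addBatch()
      rowCount += 1
      if (rowCount % batchSize == 0) {
        stmt.executeBatch()   // flush a full batch to the database
      }
    }
    if (rowCount % batchSize != 0) {
      stmt.executeBatch()     // flush the trailing partial batch
    }
  } finally {
    stmt.close()
  }
}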



