Operating MySQL and Hive with Spark and Writing Results Back to MySQL

Date: 2023-03-10 02:39:52

A recent project required statistical analysis over nearly 7 billion rows of data. Storing that much data in MySQL makes it very hard to read back, and even with a search engine the queries would still be extremely slow. After some research we decided to do the heavy aggregation on our company's big-data platform with Spark.

To make later development easier, I wrote a few utility classes that hide the boilerplate for accessing MySQL and Hive, so developers only need to focus on the business logic.

The utility classes are as follows:

I. Operating MySQL from Spark

1. Getting a Spark DataFrame from a SQL statement:

/**
 * Get a DataFrame from a MySQL database
 *
 * @param spark SparkSession
 * @param sql   query SQL
 * @return DataFrame
 */
def getDFFromMysql(spark: SparkSession, sql: String): DataFrame = {
  println(s"url:${mySqlConfig.url} user:${mySqlConfig.user} sql: ${sql}")
  spark.read.format("jdbc")
    .option("url", mySqlConfig.url)
    .option("user", mySqlConfig.user)
    .option("password", mySqlConfig.password)
    .option("driver", "com.mysql.jdbc.Driver")
    .option("query", sql)
    .load()
}
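A minimal usage sketch (my own example, not from the original project), assuming this method lives in the MySQLUtils object referenced later in this post and that mySqlConfig has been initialized; the SQL itself is only illustrative:

// Hypothetical usage of the reader above
val spark: SparkSession = SparkUtils.getProdSparkSession
val suppliers = MySQLUtils.getDFFromMysql(spark,
  "SELECT id, registered_money FROM scfc_supplier_info WHERE yn = 1")
suppliers.show(10)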

2. Writing a Spark DataFrame into a MySQL table

/**
 * Write the result into MySQL
 * @param df        DataFrame
 * @param mode      SaveMode
 * @param tableName target table name
 */
def writeIntoMySql(df: DataFrame, mode: SaveMode, tableName: String): Unit = {
  mode match {
    case SaveMode.Append    => appendDataIntoMysql(df, tableName)
    case SaveMode.Overwrite => overwriteMysqlData(df, tableName)
    case _                  => throw new Exception("Only Append and Overwrite are currently supported!")
  }
}
/**
 * Append a DataFrame to a MySQL table
 * @param df             DataFrame
 * @param mysqlTableName table name: database_name.table_name
 * @return
 */
def appendDataIntoMysql(df: DataFrame, mysqlTableName: String) = {
  df.write.mode(SaveMode.Append).jdbc(mySqlConfig.url, mysqlTableName, getMysqlProp)
}

/**
 * Overwrite the data in a MySQL table with a DataFrame
 * @param df             DataFrame
 * @param mysqlTableName table name: database_name.table_name
 * @return
 */
def overwriteMysqlData(df: DataFrame, mysqlTableName: String) = {
  // first truncate the MySQL table
  truncateMysqlTable(mysqlTableName)
  // then append the data
  df.write.mode(SaveMode.Append).jdbc(mySqlConfig.url, mysqlTableName, getMysqlProp)
}
/**
 * Truncate a MySQL table
 * @param mysqlTableName
 * @return true on success, false otherwise
 */
def truncateMysqlTable(mysqlTableName: String): Boolean = {
  val conn = MySQLPoolManager.getMysqlManager.getConnection // get a connection from the pool
  val statement = conn.createStatement()
  try {
    statement.execute(s"truncate table $mysqlTableName")
    true
  } catch {
    case e: Exception =>
      println(s"mysql truncateMysqlTable error:${ExceptionUtil.getExceptionStack(e)}")
      false
  } finally {
    statement.close()
    conn.close()
  }
}
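With these helpers in place, writing a result DataFrame back is a one-liner. A hedged usage sketch (the table name is illustrative, and MySQLUtils is the wrapper object assumed above):

// Hypothetical usage: append to, or rebuild, a result table
MySQLUtils.writeIntoMySql(resultDf, SaveMode.Append, "db_name.scfc_kanban_field_value")
MySQLUtils.writeIntoMySql(resultDf, SaveMode.Overwrite, "db_name.scfc_kanban_field_value")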

3. Deleting rows from a MySQL table by condition

/**
 * Delete rows from a table
 * @param mysqlTableName
 * @param condition
 * @return true on success, false otherwise
 */
def deleteMysqlTableData(mysqlTableName: String, condition: String): Boolean = {
  val conn = MySQLPoolManager.getMysqlManager.getConnection // get a connection from the pool
  val statement = conn.createStatement()
  try {
    statement.execute(s"delete from $mysqlTableName where $condition")
    true
  } catch {
    case e: Exception =>
      println(s"mysql deleteMysqlTable error:${ExceptionUtil.getExceptionStack(e)}")
      false
  } finally {
    statement.close()
    conn.close()
  }
}
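The helpers above get their connections from MySQLPoolManager, which this post does not show. A minimal sketch of what a c3p0-backed pool manager might look like (this is an assumption, not the project's actual implementation; the URL and credentials are placeholders):

import java.sql.Connection
import com.mchange.v2.c3p0.ComboPooledDataSource

// Hypothetical c3p0-backed pool manager assumed by the utilities above
object MySQLPoolManager {
  lazy val getMysqlManager: MysqlPool = new MysqlPool

  class MysqlPool {
    private val pool = new ComboPooledDataSource()
    pool.setDriverClass("com.mysql.jdbc.Driver")
    pool.setJdbcUrl("jdbc:mysql://IP:PORT/db_name?useUnicode=true&characterEncoding=UTF-8&useSSL=false")
    pool.setUser("user")
    pool.setPassword("password")
    pool.setMinPoolSize(5)   // pool sizing here is arbitrary
    pool.setMaxPoolSize(20)

    def getConnection: Connection = pool.getConnection
  }
}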

4. Saving a DataFrame to MySQL, creating the table automatically if it does not exist

/**
 * Save a DataFrame to MySQL; if the table does not exist, it is created automatically
 * @param tableName
 * @param resultDateFrame
 */
def saveDFtoDBCreateTableIfNotExist(tableName: String, resultDateFrame: DataFrame): Unit = {
  // if the table does not exist, create it from the DataFrame schema
  createTableIfNotExist(tableName, resultDateFrame)
  // verify that the table and the DataFrame agree on column count, names and order
  verifyFieldConsistency(tableName, resultDateFrame)
  // save the DataFrame
  saveDFtoDBUsePool(tableName, resultDateFrame)
}
/**
 * If the table does not exist, create it from the DataFrame's fields, keeping the DataFrame's
 * column order. If the DataFrame contains a field named id, it becomes the primary key
 * (int, auto-increment); the other fields are mapped to MySQL types from their Spark DataType.
 *
 * @param tableName table name
 * @param df        DataFrame
 * @return
 */
def createTableIfNotExist(tableName: String, df: DataFrame): AnyVal = {
  val con = MySQLPoolManager.getMysqlManager.getConnection
  val metaData = con.getMetaData
  val colResultSet = metaData.getColumns(null, "%", tableName, "%")
  // if the table does not exist, create it
  if (!colResultSet.next()) {
    // build the CREATE TABLE statement
    val sb = new StringBuilder(s"CREATE TABLE `$tableName` (")
    df.schema.fields.foreach(x =>
      if (x.name.equalsIgnoreCase("id")) {
        sb.append(s"`${x.name}` int(255) NOT NULL AUTO_INCREMENT PRIMARY KEY,") // a field named id becomes an auto-increment integer primary key
      } else {
        x.dataType match {
          case _: ByteType => sb.append(s"`${x.name}` int(100) DEFAULT NULL,")
          case _: ShortType => sb.append(s"`${x.name}` int(100) DEFAULT NULL,")
          case _: IntegerType => sb.append(s"`${x.name}` int(100) DEFAULT NULL,")
          case _: LongType => sb.append(s"`${x.name}` bigint(100) DEFAULT NULL,")
          case _: BooleanType => sb.append(s"`${x.name}` tinyint DEFAULT NULL,")
          case _: FloatType => sb.append(s"`${x.name}` float(50) DEFAULT NULL,")
          case _: DoubleType => sb.append(s"`${x.name}` double DEFAULT NULL,")
          case _: StringType => sb.append(s"`${x.name}` varchar(50) DEFAULT NULL,")
          case _: TimestampType => sb.append(s"`${x.name}` timestamp DEFAULT current_timestamp,")
          case _: DateType => sb.append(s"`${x.name}` date DEFAULT NULL,")
          case _ => throw new RuntimeException(s"nonsupport ${x.dataType} !!!")
        }
      }
    )
    sb.append(") ENGINE=InnoDB DEFAULT CHARSET=utf8")
    val sql_createTable = sb.deleteCharAt(sb.lastIndexOf(',')).toString()
    println(sql_createTable)
    val statement = con.createStatement()
    statement.execute(sql_createTable)
  }
}
/**
 * Verify that the table and the DataFrame agree on column count, names and order
 *
 * @param tableName table name
 * @param df        DataFrame
 */
def verifyFieldConsistency(tableName: String, df: DataFrame): Unit = {
  val con = MySQLPoolManager.getMysqlManager.getConnection
  val metaData = con.getMetaData
  val colResultSet = metaData.getColumns(null, "%", tableName, "%")
  colResultSet.last()
  val tableFieldNum = colResultSet.getRow
  val dfFieldNum = df.columns.length
  if (tableFieldNum != dfFieldNum) {
    throw new Exception(s"Table and DataFrame column counts differ!! table--$tableFieldNum but dataFrame--$dfFieldNum")
  }
  for (i <- 1 to tableFieldNum) {
    colResultSet.absolute(i)
    val tableFieldName = colResultSet.getString("COLUMN_NAME")
    val dfFieldName = df.columns.apply(i - 1)
    if (!tableFieldName.equals(dfFieldName)) {
      throw new Exception(s"Table and DataFrame column names differ!! table--'$tableFieldName' but dataFrame--'$dfFieldName'")
    }
  }
  colResultSet.beforeFirst()
}
/**
 * Write a DataFrame into MySQL through the c3p0 connection pool, binding each column
 * with the setter that matches its Spark DataType
 *
 * @param tableName       table name
 * @param resultDateFrame DataFrame
 */
def saveDFtoDBUsePool(tableName: String, resultDateFrame: DataFrame): Unit = {
  val colNumbers = resultDateFrame.columns.length
  val sql = getInsertSql(tableName, colNumbers)
  val columnDataTypes = resultDateFrame.schema.fields.map(_.dataType)
  resultDateFrame.foreachPartition(partitionRecords => {
    val conn = MySQLPoolManager.getMysqlManager.getConnection // get a connection from the pool
    val preparedStatement = conn.prepareStatement(sql)
    val metaData = conn.getMetaData.getColumns(null, "%", tableName, "%") // table metadata, used to bind NULLs
    try {
      conn.setAutoCommit(false)
      partitionRecords.foreach(record => {
        // note: the setXxx methods are 1-based, record.get() is 0-based
        for (i <- 1 to colNumbers) {
          val value = record.get(i - 1)
          val dataType = columnDataTypes(i - 1)
          if (value != null) { // if the value is not null, bind it with the setter for its type
            dataType match {
              case _: ByteType => preparedStatement.setInt(i, record.getAs[Int](i - 1))
              case _: ShortType => preparedStatement.setInt(i, record.getAs[Int](i - 1))
              case _: IntegerType => preparedStatement.setInt(i, record.getAs[Int](i - 1))
              case _: LongType => preparedStatement.setLong(i, record.getAs[Long](i - 1))
              case _: BooleanType => preparedStatement.setBoolean(i, record.getAs[Boolean](i - 1))
              case _: FloatType => preparedStatement.setFloat(i, record.getAs[Float](i - 1))
              case _: DoubleType => preparedStatement.setDouble(i, record.getAs[Double](i - 1))
              case _: StringType => preparedStatement.setString(i, record.getAs[String](i - 1))
              case _: TimestampType => preparedStatement.setTimestamp(i, record.getAs[Timestamp](i - 1))
              case _: DateType => preparedStatement.setDate(i, record.getAs[Date](i - 1))
              case _ => throw new RuntimeException(s"nonsupport ${dataType} !!!")
            }
          } else { // if the value is null, set a SQL NULL of the column's type
            metaData.absolute(i)
            preparedStatement.setNull(i, metaData.getInt("DATA_TYPE"))
          }
        }
        preparedStatement.addBatch()
      })
      preparedStatement.executeBatch()
      conn.commit()
    } catch {
      case e: Exception => println(s"@@ saveDFtoDBUsePool error: ${ExceptionUtil.getExceptionStack(e)}")
      // do some log
    } finally {
      preparedStatement.close()
      conn.close()
    }
  })
}
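saveDFtoDBUsePool calls getInsertSql, and the jdbc writes earlier call getMysqlProp; neither helper appears in this post. A minimal sketch of what they might look like inside MySQLUtils (assumed implementations, not the originals):

import java.util.Properties

// Hypothetical helpers assumed by the code above
def getMysqlProp: Properties = {
  val prop = new Properties()
  prop.setProperty("user", mySqlConfig.user)
  prop.setProperty("password", mySqlConfig.password)
  prop.setProperty("driver", "com.mysql.jdbc.Driver")
  prop
}

// Builds "INSERT INTO <table> VALUES (?, ?, ..., ?)" with one placeholder per column
def getInsertSql(tableName: String, colNumbers: Int): String = {
  val placeholders = List.fill(colNumbers)("?").mkString(", ")
  s"INSERT INTO $tableName VALUES ($placeholders)"
}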

II. Working with Spark

1. Switching the Spark environment

Define the environment profile in Profile.scala:

/**
 * @description environment profiles
 * scf
 * @author wangxuexing
 * @date 2019/12/23
 */
object Profile extends Enumeration {
  type Profile = Value
  /**
   * production
   */
  val PROD = Value("prod")
  /**
   * production test
   */
  val PROD_TEST = Value("prod_test")
  /**
   * development
   */
  val DEV = Value("dev")
  /**
   * current environment
   */
  val currentEvn = PROD
}

Define SparkUtils.scala:

import com.dmall.scf.Profile
import com.dmall.scf.dto.{Env, MySqlConfig}
import org.apache.spark.sql.{DataFrame, Encoder, SparkSession}

import scala.collection.JavaConversions._

/**
 * @description Spark utilities
 * scf
 * @author wangxuexing
 * @date 2019/12/23
 */
object SparkUtils {
// development environment

val DEV_URL = "jdbc:mysql://IP:PORT/db_name?useUnicode=true&characterEncoding=UTF-8&autoReconnect=true&failOverReadOnly=false&useSSL=false"

val DEV_USER = "user"

val DEV_PASSWORD = "password"

//production test environment

val PROD_TEST_URL = "jdbc:mysql://IP:PORT/db_name?useUnicode=true&characterEncoding=UTF-8&autoReconnect=true&failOverReadOnly=false&zeroDateTimeBehavior=convertToNull&useSSL=false"

val PROD_TEST_USER = "user"

val PROD_TEST_PASSWORD = "password"

//production environment

val PROD_URL = "jdbc:mysql://IP:PORT/db_name?useUnicode=true&characterEncoding=UTF-8&autoReconnect=true&failOverReadOnly=false&useSSL=false"

val PROD_USER = "user"

val PROD_PASSWORD = "password"

  def env = Profile.currentEvn

  /**
   * Get the environment settings
   * @return
   */
  def getEnv: Env = {
    env match {
      case Profile.DEV => Env(MySqlConfig(DEV_URL, DEV_USER, DEV_PASSWORD), SparkUtils.getDevSparkSession)
      case Profile.PROD =>
        Env(MySqlConfig(PROD_URL, PROD_USER, PROD_PASSWORD), SparkUtils.getProdSparkSession)
      case Profile.PROD_TEST =>
        Env(MySqlConfig(PROD_TEST_URL, PROD_TEST_USER, PROD_TEST_PASSWORD), SparkUtils.getProdSparkSession)
      case _ => throw new Exception("Unable to determine the environment")
    }
  }

  /**
   * Get the production SparkSession
   * @return
   */
  def getProdSparkSession: SparkSession = {
    SparkSession
      .builder()
      .appName("scf")
      .enableHiveSupport() // enable Hive support
      .getOrCreate()
  }

  /**
   * Get the development SparkSession
   * @return
   */
  def getDevSparkSession: SparkSession = {
    SparkSession
      .builder()
      .master("local[*]")
      .appName("local-1576939514234")
      .config("spark.sql.warehouse.dir", "C:\\data\\spark-ware") // if not set, defaults to C:\data\projects\parquet2dbs\spark-warehouse
      .enableHiveSupport() // enable Hive support
      .getOrCreate()
  }

  /**
   * Convert a DataFrame to a list of case class instances
   * @param df DataFrame
   * @tparam T case class
   * @return
   */
  def dataFrame2Bean[T: Encoder](df: DataFrame, clazz: Class[T]): List[T] = {
    val fieldNames = clazz.getDeclaredFields.map(f => f.getName).toList
    df.toDF(fieldNames: _*).as[T].collectAsList().toList
  }
}
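Env and MySqlConfig come from the com.dmall.scf.dto package, which is not included in this post. Judging from how getEnv uses them, they are presumably simple case classes along these lines (the field names are my assumption):

import org.apache.spark.sql.SparkSession

// Hypothetical DTOs assumed by SparkUtils.getEnv
case class MySqlConfig(url: String, user: String, password: String)
case class Env(mySqlConfig: MySqlConfig, sparkSession: SparkSession)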

III. Defining the Spark Processing Flow

Read data from MySQL or Hive -> apply the business logic -> write the results to MySQL.

1. Define the processing flow

SparkAction.scala

import com.dmall.scf.utils.{MySQLUtils, SparkUtils}
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}

/**
 * @description Defines the Spark processing flow
 * @author wangxuexing
 * @date 2019/12/23
 */
trait SparkAction[T] {
  /**
   * Define the flow
   */
  def execute(args: Array[String], spark: SparkSession) = {
    // 1. pre-processing
    preAction
    // 2. processing
    val df = action(spark, args)
    // 3. post-processing
    postAction(df)
  }

  /**
   * Pre-processing
   * @return
   */
  def preAction() = {
    // no pre-processing by default
  }

  /**
   * Processing
   * @param spark
   * @return
   */
  def action(spark: SparkSession, args: Array[String]): DataFrame

  /**
   * Post-processing, e.g. saving the result to MySQL
   * @param df
   */
  def postAction(df: DataFrame) = {
    // append the result to the table given by saveTable (e.g. scfc_supplier_run_field_value)
    MySQLUtils.writeIntoMySql(df, saveTable._1, saveTable._2)
  }

  /**
   * Save mode and table name
   * @return
   */
  def saveTable: (SaveMode, String)
}
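The post does not show how an action is launched. A rough driver sketch, assuming the Env fields from the DTO sketch above and using the RegMoneyDistributionAction defined later as an example (the object name and wiring are my assumptions):

// Hypothetical entry point: pick the environment, then run a concrete action
object JobRunner {
  def main(args: Array[String]): Unit = {
    val env = SparkUtils.getEnv // SparkSession + MySQL config for the current Profile
    try {
      RegMoneyDistributionAction.execute(args, env.sparkSession)
    } finally {
      env.sparkSession.stop()
    }
  }
}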

2. Implement the flow

KanbanAction.scala

import com.dmall.scf.SparkAction
import com.dmall.scf.dto.KanbanFieldValue
import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession}

import scala.collection.JavaConverters._

/**
 * @description
 * scf-spark
 * @author wangxuexing
 * @date 2020/1/10
 */
trait KanbanAction extends SparkAction[KanbanFieldValue] {
  /**
   * Build a DataFrame from the result list
   * @param resultList
   * @param spark
   * @return
   */
  def getDataFrame(resultList: List[KanbanFieldValue], spark: SparkSession): DataFrame = {
    // build the schema
    val fields = List(StructField("company_id", LongType, nullable = false),
      StructField("statistics_date", StringType, nullable = false),
      StructField("field_id", LongType, nullable = false),
      StructField("field_type", StringType, nullable = false),
      StructField("field_value", StringType, nullable = false),
      StructField("other_value", StringType, nullable = false))
    val schema = StructType(fields)
    // convert each result record into a Row
    val rowRDD = resultList.map(x => Row(x.companyId, x.statisticsDate, x.fieldId, x.fieldType, x.fieldValue, x.otherValue)).asJava
    // create the DataFrame from the rows and the schema
    spark.createDataFrame(rowRDD, schema)
  }

  /**
   * Save mode and table name
   *
   * @return
   */
  override def saveTable: (SaveMode, String) = (SaveMode.Append, "scfc_kanban_field_value")
}
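KanbanFieldValue is another DTO from com.dmall.scf.dto that the post does not list. Based on the schema and the field accesses in getDataFrame, it is probably a case class like this (an assumption):

// Hypothetical DTO matching the schema built in getDataFrame
case class KanbanFieldValue(companyId: Long,
                            statisticsDate: String,
                            fieldId: Long,
                            fieldType: String,
                            fieldValue: String,
                            otherValue: String)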

3. Implement the concrete business logic

import com.dmall.scf.dto.{KanbanFieldValue, RegisteredMoney}
import com.dmall.scf.utils.{DateUtils, MySQLUtils}
import org.apache.spark.sql.{DataFrame, SparkSession}

/**
 * @description
 * scf-spark: distribution of registered capital
 * @author wangxuexing
 * @date 2020/1/10
 */
object RegMoneyDistributionAction extends KanbanAction {
val CLASS_NAME = this.getClass.getSimpleName().filter(!_.equals('$'))

val RANGE_50W = BigDecimal(50)
val RANGE_100W = BigDecimal(100)
val RANGE_500W = BigDecimal(500)
val RANGE_1000W = BigDecimal(1000)

/**
 * Processing
 *
 * @param spark
 * @return
 */
override def action(spark: SparkSession, args: Array[String]): DataFrame = {
import spark.implicits._
if (args.length < 2) {
  throw new Exception("Please specify current year (1) or last year (2): 1|2")
}
val lastDay = DateUtils.addSomeDays(-1)
val (startDate, endDate, fieldId) = args(1) match {
  case "1" =>
    val startDate = DateUtils.isFirstDayOfYear match {
      case true => DateUtils.getFirstDateOfLastYear
      case false => DateUtils.getFirstDateOfCurrentYear
    }
    (startDate, DateUtils.formatNormalDateStr(lastDay), 44)
  case "2" =>
    val startDate = DateUtils.isFirstDayOfYear match {
      case true => DateUtils.getLast2YearFirstStr(DateUtils.YYYY_MM_DD)
      case false => DateUtils.getLastYearFirstStr(DateUtils.YYYY_MM_DD)
    }
    val endDate = DateUtils.isFirstDayOfYear match {
      case true => DateUtils.getLast2YearLastStr(DateUtils.YYYY_MM_DD)
      case false => DateUtils.getLastYearLastStr(DateUtils.YYYY_MM_DD)
    }
    (startDate, endDate, 45)
  case _ => throw new Exception("Invalid argument: pass 1 for the current year or 2 for last year")
}

val sql = s"""SELECT
id,
IFNULL(registered_money, 0) registered_money
FROM
scfc_supplier_info
WHERE
`status` = 3
AND yn = 1"""
val allDimension = MySQLUtils.getDFFromMysql(spark, sql)
val beanList = allDimension.map(x => RegisteredMoney(x.getLong(0), x.getDecimal(1)))
//val filterList = SparkUtils.dataFrame2Bean[RegisteredMoney](allDimension, classOf[RegisteredMoney])
val hiveSql = s"""
SELECT DISTINCT(a.company_id) supplier_ids
FROM wumart2dmall.wm_ods_cx_supplier_card_info a
JOIN wumart2dmall.wm_ods_jrbl_loan_dkzhxx b ON a.card_code = b.gshkahao
WHERE a.audit_status = '2'
AND b.jiluztai = '0'
AND to_date(b.gxinshij)>= '${startDate}'
AND to_date(b.gxinshij)<= '${endDate}'"""
println(hiveSql)
val supplierIds = spark.sql(hiveSql).collect().map(_.getLong(0))
val filterList = beanList.filter(x => supplierIds.contains(x.supplierId))

val range1 = spark.sparkContext.collectionAccumulator[Int]
val range2 = spark.sparkContext.collectionAccumulator[Int]
val range3 = spark.sparkContext.collectionAccumulator[Int]
val range4 = spark.sparkContext.collectionAccumulator[Int]
val range5 = spark.sparkContext.collectionAccumulator[Int]
filterList.foreach(x => {
  if (RANGE_50W.compare(x.registeredMoney) >= 0) {
    range1.add(1)
  } else if (RANGE_50W.compare(x.registeredMoney) < 0 && RANGE_100W.compare(x.registeredMoney) >= 0) {
    range2.add(1)
  } else if (RANGE_100W.compare(x.registeredMoney) < 0 && RANGE_500W.compare(x.registeredMoney) >= 0) {
    range3.add(1)
  } else if (RANGE_500W.compare(x.registeredMoney) < 0 && RANGE_1000W.compare(x.registeredMoney) >= 0) {
    range4.add(1)
  } else if (RANGE_1000W.compare(x.registeredMoney) < 0) {
    range5.add(1)
  }
})
val resultList = List(("50万元以下", range1.value.size()), ("50-100万元", range2.value.size()),
  ("100-500万元", range3.value.size()), ("500-1000万元", range4.value.size()),
  ("1000万元以上", range5.value.size())).map(x => {
  KanbanFieldValue(1, lastDay, fieldId, x._1, x._2.toString, "")
})

getDataFrame(resultList, spark)
}
}
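RegisteredMoney is likewise a DTO that is not shown in the post. Based on how it is constructed and used above, it is probably just (an assumption):

// Hypothetical DTO: supplier id and registered capital, as read from scfc_supplier_info
case class RegisteredMoney(supplierId: Long, registeredMoney: java.math.BigDecimal)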

For the full project source code, see:

https://github.com/barrywang88/spark-tool

https://gitee.com/barrywang/spark-tool