This article uses Spark SQL 2.1.0.
2. The code
Suppose we have an object that, besides simple fields such as String and int, also contains a Location object, i.e. a nested object:
import java.io.Serializable;

// A JavaBean whose location field is itself a JavaBean
public class Person implements Serializable {
    private static final long serialVersionUID = 1L;
    private String name;
    private int age;
    private Location location;

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public int getAge() {
        return age;
    }

    public void setAge(int age) {
        this.age = age;
    }

    public Location getLocation() {
        return location;
    }

    public void setLocation(Location location) {
        this.location = location;
    }
}
import java.io.Serializable;

public class Location implements Serializable {
    private static final long serialVersionUID = 1L;
    private String city;
    private String country;

    public String getCity() {
        return city;
    }

    public void setCity(String city) {
        this.city = city;
    }

    public String getCountry() {
        return country;
    }

    public void setCountry(String country) {
        this.country = country;
    }
}
The Spark SQL code:
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.SparkSession;

// Initialize the SparkSession
SparkSession spark = SparkSession
        .builder()
        .appName("Java Spark SQL Schema test")
        .master("local[*]")
        .getOrCreate();

// Create a JavaRDD of Person by parsing each line of the text file
JavaRDD<Person> peopleRDD = spark
        .read()
        .textFile("examples/src/main/resources/people.txt")
        .javaRDD()
        .map(line -> {
            String[] parts = line.split(",");
            Person person = new Person();
            person.setName(parts[0]);
            person.setAge(Integer.parseInt(parts[1].trim()));
            Location location = new Location();
            location.setCity(parts[2].trim());
            location.setCountry(parts[3].trim());
            person.setLocation(location);
            return person;
        });
Here people.txt has the following format:
Michael, 29, Guangzhou, China
Andy, 30, Shenzhen, China
Justin, 19, Shanghai, China
// Create a DataFrame by applying the schema of the Person class
Dataset<Row> peopleDF = spark.createDataFrame(peopleRDD, Person.class);
// Print the inferred schema
peopleDF.printSchema();
According to the Inferring the Schema Using Reflection section of the Spark SQL programming guide (http://spark.apache.org/docs/latest/sql-programming-guide.html#inferring-the-schema-using-reflection), nested JavaBeans are supported, so the schema is inferred automatically (bean reflection orders the fields alphabetically, which is why age comes first):
root
 |-- age: integer (nullable = false)
 |-- location: struct (nullable = true)
 |    |-- city: string (nullable = true)
 |    |-- country: string (nullable = true)
 |-- name: string (nullable = true)
However, running a SQL query against this DataFrame and showing the result throws a scala.MatchError, for example:
// Create a temporary view and run a SQL query
peopleDF.createOrReplaceTempView("people");
spark.sql("SELECT * FROM people WHERE age BETWEEN 13 AND 19").show();
The exception:
scala.MatchError: Location@7dc6f69a (of class Location)
    at org.apache.spark.sql.catalyst.CatalystTypeConverters$StructConverter.toCatalystImpl(CatalystTypeConverters.scala:236)
    at org.apache.spark.sql.catalyst.CatalystTypeConverters$StructConverter.toCatalystImpl(CatalystTypeConverters.scala:231)
    at org.apache.spark.sql.catalyst.CatalystTypeConverters$CatalystTypeConverter.toCatalyst(CatalystTypeConverters.scala:103)
    at org.apache.spark.sql.catalyst.CatalystTypeConverters$$anonfun$createToCatalystConverter$2.apply(CatalystTypeConverters.scala:383)
    at org.apache.spark.sql.SQLContext$$anonfun$beansToRows$1$$anonfun$apply$1.apply(SQLContext.scala:1113)
    at org.apache.spark.sql.SQLContext$$anonfun$beansToRows$1$$anonfun$apply$1.apply(SQLContext.scala:1113)
    at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
    at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
    at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33)
    at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:186)
    at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
    at scala.collection.mutable.ArrayOps$ofRef.map(ArrayOps.scala:186)
    at org.apache.spark.sql.SQLContext$$anonfun$beansToRows$1.apply(SQLContext.scala:1113)
    at org.apache.spark.sql.SQLContext$$anonfun$beansToRows$1.apply(SQLContext.scala:1111)
    at scala.collection.Iterator$$anon$11.next(Iterator.scala:409)
    at scala.collection.Iterator$$anon$11.next(Iterator.scala:409)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIterator.processNext(Unknown Source)
    at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
    at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$8$$anon$1.hasNext(WholeStageCodegenExec.scala:377)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:231)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:225)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:826)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:826)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:323)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:287)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
    at org.apache.spark.scheduler.Task.run(Task.scala:99)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:282)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
    at java.lang.Thread.run(Thread.java:745)
3. Solution
The schema is inferred correctly, but when the query runs, the bean-to-row conversion only handles the top-level bean's properties: the nested Location is neither a Row nor a Scala Product, so Catalyst's StructConverter (visible at the top of the stack trace) has no matching case for it and throws scala.MatchError. The workaround is to specify the schema programmatically and build the nested rows ourselves; Spark SQL supports creating a DataFrame by applying such a schema:
import java.util.ArrayList;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

// Create a JavaRDD of the raw text lines
JavaRDD<String> peopleRDD = spark.sparkContext()
        .textFile("examples/src/main/resources/people2.txt", 1)
        .toJavaRDD();

// Create the StructFields for name and age
StructField nameField = DataTypes.createStructField("name", DataTypes.StringType, true);
StructField ageField = DataTypes.createStructField("age", DataTypes.IntegerType, true);

// Create the nested struct for location
StructField cityField = DataTypes.createStructField("city", DataTypes.StringType, true);
StructField countryField = DataTypes.createStructField("country", DataTypes.StringType, true);
StructType locationStruct = DataTypes.createStructType(new StructField[] { cityField, countryField });
StructField locationField = DataTypes.createStructField("location", locationStruct, true);

// Assemble the complete StructType
List<StructField> fields = new ArrayList<>();
fields.add(nameField);
fields.add(ageField);
fields.add(locationField);
StructType schema = DataTypes.createStructType(fields);

// Convert the lines to Rows; the nested struct is itself a Row
JavaRDD<Row> rowRDD = peopleRDD.map(record -> {
    String[] attributes = record.split(",");
    return RowFactory.create(attributes[0], Integer.parseInt(attributes[1].trim()),
            RowFactory.create(attributes[2].trim(), attributes[3].trim()));
});

// Apply the schema to create the DataFrame
Dataset<Row> peopleDataFrame = spark.createDataFrame(rowRDD, schema);
// Print the schema
peopleDataFrame.printSchema();
The printed structure matches the inferred schema above, apart from the field order (we listed name first) and the nullability of age:
root
 |-- name: string (nullable = true)
 |-- age: integer (nullable = true)
 |-- location: struct (nullable = true)
 |    |-- city: string (nullable = true)
 |    |-- country: string (nullable = true)
// Create a temporary view and run a SQL query
peopleDataFrame.createOrReplaceTempView("people");
spark.sql("SELECT * FROM people").show(false);
Output:
+-------+---+-----------------+
|name |age|location |
+-------+---+-----------------+
|Michael|29 |[Guangzhou,China]|
|Andy |30 |[Shenzhen,China] |
|Justin |19 |[Shanghai,China] |
+-------+---+-----------------+
// Filter on a nested field in the WHERE clause
spark.sql("SELECT * FROM people WHERE location.city = 'Guangzhou'").show(false);
Output:
+-------+---+-----------------+
|name |age|location |
+-------+---+-----------------+
|Michael|29 |[Guangzhou,China]|
+-------+---+-----------------+
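The same dot notation also works in the SELECT list; for example, to project only the nested city field (a small extra query, not part of the original example):
// Project a nested field directly with dot notation
spark.sql("SELECT name, location.city FROM people").show(false);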
4. Other solutions
Another option is to create a corresponding UDT class that extends UserDefinedType<UserType>, implement its methods, and register it with UDTRegistration.register. For concrete references, see the UserDefinedType subclasses that ship with Spark, such as MatrixUDT. A sketch follows.
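Below is a minimal sketch of such a UDT for the Location class above (the LocationUDT name and the struct layout are this sketch's own choices), assuming Spark 2.1.0's UserDefinedType and UDTRegistration APIs. Both are internal, non-public APIs whose behavior may change between versions, so treat this as an illustration rather than a drop-in implementation:
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.types.*;
import org.apache.spark.unsafe.types.UTF8String;

public class LocationUDT extends UserDefinedType<Location> {
    private static final long serialVersionUID = 1L;

    @Override
    public DataType sqlType() {
        // Location is stored internally as struct<city: string, country: string>
        return DataTypes.createStructType(new StructField[] {
                DataTypes.createStructField("city", DataTypes.StringType, true),
                DataTypes.createStructField("country", DataTypes.StringType, true) });
    }

    @Override
    public Object serialize(Location obj) {
        // Convert the bean to Catalyst's internal row representation
        GenericInternalRow row = new GenericInternalRow(2);
        row.update(0, UTF8String.fromString(obj.getCity()));
        row.update(1, UTF8String.fromString(obj.getCountry()));
        return row;
    }

    @Override
    public Location deserialize(Object datum) {
        // Convert the internal row back to a Location bean
        InternalRow row = (InternalRow) datum;
        Location location = new Location();
        location.setCity(row.getString(0));
        location.setCountry(row.getString(1));
        return location;
    }

    @Override
    public Class<Location> userClass() {
        return Location.class;
    }
}
Register the UDT before creating the DataFrame:
// Map the user class to its UDT inside Catalyst
UDTRegistration.register(Location.class.getName(), LocationUDT.class.getName());
Since UserDefinedType is not public API in Spark 2.x, this approach couples the code to Spark internals; the programmatic schema in section 3 is the safer choice.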