[SPARK-10289] [SQL] A direct write API for testing Parquet #8454

Closed
@@ -17,11 +17,15 @@

package org.apache.spark.sql.execution.datasources.parquet

import scala.collection.JavaConverters._
import scala.collection.JavaConverters.{collectionAsScalaIterableConverter, mapAsJavaMapConverter, seqAsJavaListConverter}
Contributor

Why the specific imports?

Contributor Author

I thought we should be explicit and avoid wildcard imports according to our style guide, but I just realized it's OK to have them for implicit methods.
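For reference, here is a minimal sketch (not part of the diff) of the two import styles under discussion; either form brings the implicit Seq-to-java.util.List converter into scope:

import scala.collection.JavaConverters._                          // wildcard import
// import scala.collection.JavaConverters.seqAsJavaListConverter  // explicit import, per the style guide

val javaList: java.util.List[Int] = Seq(1, 2, 3).asJava           // compiles with either import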


import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{Path, PathFilter}
import org.apache.parquet.hadoop.ParquetFileReader
import org.apache.parquet.schema.MessageType
import org.apache.parquet.hadoop.api.WriteSupport
import org.apache.parquet.hadoop.api.WriteSupport.WriteContext
import org.apache.parquet.hadoop.{ParquetFileReader, ParquetWriter}
import org.apache.parquet.io.api.RecordConsumer
import org.apache.parquet.schema.{MessageType, MessageTypeParser}

import org.apache.spark.sql.QueryTest

@@ -38,11 +42,10 @@ private[sql] abstract class ParquetCompatibilityTest extends QueryTest with Parq
val fs = fsPath.getFileSystem(configuration)
val parquetFiles = fs.listStatus(fsPath, new PathFilter {
override def accept(path: Path): Boolean = pathFilter(path)
}).toSeq
}).toSeq.asJava

val footers =
ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles.asJava, true)
footers.iterator().next().getParquetMetadata.getFileMetaData.getSchema
val footers = ParquetFileReader.readAllFootersInParallel(configuration, parquetFiles, true)
footers.asScala.head.getParquetMetadata.getFileMetaData.getSchema
}

protected def logParquetSchema(path: String): Unit = {
@@ -53,8 +56,69 @@ private[sql] abstract class ParquetCompatibilityTest extends QueryTest with Parq
}
}

object ParquetCompatibilityTest {
def makeNullable[T <: AnyRef](i: Int)(f: => T): T = {
if (i % 3 == 0) null.asInstanceOf[T] else f
private[sql] object ParquetCompatibilityTest {
implicit class RecordConsumerDSL(consumer: RecordConsumer) {
def message(f: => Unit): Unit = {
consumer.startMessage()
f
consumer.endMessage()
}

def group(f: => Unit): Unit = {
consumer.startGroup()
f
consumer.endGroup()
}

def field(name: String, index: Int)(f: => Unit): Unit = {
consumer.startField(name, index)
f
consumer.endField(name, index)
}
}

/**
* A testing Parquet [[WriteSupport]] implementation used to write manually constructed Parquet
* records with arbitrary structures.
*/
private class DirectWriteSupport(schema: MessageType, metadata: Map[String, String])
extends WriteSupport[RecordConsumer => Unit] {

private var recordConsumer: RecordConsumer = _

override def init(configuration: Configuration): WriteContext = {
new WriteContext(schema, metadata.asJava)
}

override def write(recordWriter: RecordConsumer => Unit): Unit = {
recordWriter.apply(recordConsumer)
}

override def prepareForWrite(recordConsumer: RecordConsumer): Unit = {
this.recordConsumer = recordConsumer
}
}

/**
* Writes arbitrary messages conforming to a given `schema` to a Parquet file located by `path`.
* Records are produced by `recordWriters`.
*/
def writeDirect(path: String, schema: String, recordWriters: (RecordConsumer => Unit)*): Unit = {
writeDirect(path, schema, Map.empty[String, String], recordWriters: _*)
}

/**
* Writes arbitrary messages conforming to a given `schema` to a Parquet file located by `path`
* with given user-defined key-value `metadata`. Records are produced by `recordWriters`.
*/
def writeDirect(
path: String,
schema: String,
metadata: Map[String, String],
recordWriters: (RecordConsumer => Unit)*): Unit = {
val messageType = MessageTypeParser.parseMessageType(schema)
val writeSupport = new DirectWriteSupport(messageType, metadata)
val parquetWriter = new ParquetWriter[RecordConsumer => Unit](new Path(path), writeSupport)
try recordWriters.foreach(parquetWriter.write) finally parquetWriter.close()
}
}
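Not part of the PR: a minimal usage sketch of the new API, with a made-up schema and output path. Each closure passed to writeDirect produces one record, and the RecordConsumerDSL wrappers pair the corresponding start*/end* calls on the underlying RecordConsumer. Since ParquetCompatibilityTest is private[sql], code like this only compiles from within the org.apache.spark.sql package (for example, the suite below).

import org.apache.parquet.io.api.Binary
import org.apache.spark.sql.execution.datasources.parquet.ParquetCompatibilityTest._

// Hypothetical schema and path, for illustration only.
val schema =
  """message example {
    |  required int32 id;
    |  required binary name (UTF8);
    |}
  """.stripMargin

writeDirect("/tmp/direct-write-example.parquet", schema, { rc =>
  rc.message {                          // startMessage() ... endMessage()
    rc.field("id", 0) {                 // startField("id", 0) ... endField("id", 0)
      rc.addInteger(42)
    }
    rc.field("name", 1) {
      rc.addBinary(Binary.fromString("val_42"))
    }
  }
})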
@@ -33,11 +33,9 @@ class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with Shar
""".stripMargin)

checkAnswer(sqlContext.read.parquet(parquetFilePath.toString), (0 until 10).map { i =>
def nullable[T <: AnyRef]: ( => T) => T = makeNullable[T](i)

val suits = Array("SPADES", "HEARTS", "DIAMONDS", "CLUBS")

Row(
val nonNullablePrimitiveValues = Seq(
i % 2 == 0,
i.toByte,
(i + 1).toShort,
@@ -50,18 +48,15 @@
s"val_$i",
s"val_$i",
// Thrift ENUM values are converted to Parquet binaries containing UTF-8 strings
suits(i % 4),

nullable(i % 2 == 0: java.lang.Boolean),
nullable(i.toByte: java.lang.Byte),
nullable((i + 1).toShort: java.lang.Short),
nullable(i + 2: Integer),
nullable((i * 10).toLong: java.lang.Long),
nullable(i.toDouble + 0.2d: java.lang.Double),
nullable(s"val_$i"),
nullable(s"val_$i"),
nullable(suits(i % 4)),
suits(i % 4))

val nullablePrimitiveValues = if (i % 3 == 0) {
Seq.fill(nonNullablePrimitiveValues.length)(null)
} else {
nonNullablePrimitiveValues
}

val complexValues = Seq(
Seq.tabulate(3)(n => s"arr_${i + n}"),
// Thrift `SET`s are converted to Parquet `LIST`s
Seq(i),
@@ -71,6 +66,83 @@ class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with Shar
Row(Seq.tabulate(3)(j => i + j + m), s"val_${i + m}")
}
}.toMap)

Row(nonNullablePrimitiveValues ++ nullablePrimitiveValues ++ complexValues: _*)
})
}

test("SPARK-10136 list of primitive list") {
withTempPath { dir =>
val path = dir.getCanonicalPath

// This Parquet schema is translated from the following Thrift schema:
//
// struct ListOfPrimitiveList {
// 1: list<list<i32>> f;
// }
val schema =
s"""message ListOfPrimitiveList {
| required group f (LIST) {
| repeated group f_tuple (LIST) {
| repeated int32 f_tuple_tuple;
| }
| }
|}
""".stripMargin

writeDirect(path, schema, { rc =>
rc.message {
rc.field("f", 0) {
rc.group {
rc.field("f_tuple", 0) {
rc.group {
rc.field("f_tuple_tuple", 0) {
rc.addInteger(0)
rc.addInteger(1)
}
}

rc.group {
rc.field("f_tuple_tuple", 0) {
rc.addInteger(2)
rc.addInteger(3)
}
}
}
}
}
}
}, { rc =>
rc.message {
rc.field("f", 0) {
rc.group {
rc.field("f_tuple", 0) {
rc.group {
rc.field("f_tuple_tuple", 0) {
rc.addInteger(4)
rc.addInteger(5)
}
}

rc.group {
rc.field("f_tuple_tuple", 0) {
rc.addInteger(6)
rc.addInteger(7)
}
}
}
}
}
}
})

logParquetSchema(path)

checkAnswer(
sqlContext.read.parquet(path),
Seq(
Row(Seq(Seq(0, 1), Seq(2, 3))),
Row(Seq(Seq(4, 5), Seq(6, 7)))))
}
}
}