Spark Data Input – DataFrame/Dataset Generation

The earlier article "Spark Data Input – RDD Generation" explained that Spark's data input falls into three main categories. Here we start looking at the third one: how Spark loads data to produce a DataFrame/Dataset.

First, what exactly are DataFrame and Dataset? The authoritative description comes from the Spark source code itself:

type DataFrame = Dataset[Row]

// As this alias shows, a DataFrame is simply a Dataset whose element type is Row
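
To make the relationship concrete, here is a minimal sketch (the Person case class and the local SparkSession are illustrative assumptions, not from the Spark source) showing how the untyped and typed views convert into each other:

import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}

case class Person(name: String, age: Int)  // hypothetical domain class

val spark = SparkSession.builder().master("local[*]").appName("demo").getOrCreate()
import spark.implicits._

// A DataFrame, i.e. Dataset[Row]: columns are accessed generically by name
val df: DataFrame = Seq(Person("Alice", 30), Person("Bob", 25)).toDF()

// The same data viewed as a strongly typed Dataset[Person]
val ds: Dataset[Person] = df.as[Person]

// Going back to the untyped view
val backToDf: DataFrame = ds.toDF()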

So we will focus mainly on Dataset:

/**
 * A Dataset is a strongly typed collection of domain-specific objects that can be transformed
 * in parallel using functional or relational operations. Each Dataset also has an untyped view
 * called a `DataFrame`, which is a Dataset of [[Row]].
 *
 * Operations available on Datasets are divided into transformations and actions. Transformations
 * are the ones that produce new Datasets, and actions are the ones that trigger computation and
 * return results. Example transformations include map, filter, select, and aggregate (`groupBy`).
 * Example actions count, show, or writing data out to file systems.
 *
 * Datasets are "lazy", i.e. computations are only triggered when an action is invoked. Internally,
 * a Dataset represents a logical plan that describes the computation required to produce the data.
 * When an action is invoked, Spark's query optimizer optimizes the logical plan and generates a
 * physical plan for efficient execution in a parallel and distributed manner. To explore the
 * logical plan as well as optimized physical plan, use the `explain` function.
 *
 * To efficiently support domain-specific objects, an [[Encoder]] is required. The encoder maps
 * the domain specific type `T` to Spark's internal type system. For example, given a class `Person`
 * with two fields, `name` (string) and `age` (int), an encoder is used to tell Spark to generate
 * code at runtime to serialize the `Person` object into a binary structure. This binary structure
 * often has much lower memory footprint as well as are optimized for efficiency in data processing
 * (e.g. in a columnar format). To understand the internal binary representation for data, use the
 * `schema` function.
 *
 * There are typically two ways to create a Dataset. The most common way is by pointing Spark
 * to some files on storage systems, using the `read` function available on a `SparkSession`.
 * {{{
 *   val people = spark.read.parquet("...").as[Person]  // Scala
 *   Dataset<Person> people = spark.read().parquet("...").as(Encoders.bean(Person.class)); // Java
 * }}}
 *
 * Datasets can also be created through transformations available on existing Datasets. For example,
 * the following creates a new Dataset by applying a filter on the existing one:
 * {{{
 *   val names = people.map(_.name)  // in Scala; names is a Dataset[String]
 *   Dataset<String> names = people.map((Person p) -> p.name, Encoders.STRING));
 * }}}
 *
 * Dataset operations can also be untyped, through various domain-specific-language (DSL)
 * functions defined in: Dataset (this class), [[Column]], and [[functions]]. These operations
 * are very similar to the operations available in the data frame abstraction in R or Python.
 *
 * To select a column from the Dataset, use `apply` method in Scala and `col` in Java.
 * {{{
 *   val ageCol = people("age")  // in Scala
 *   Column ageCol = people.col("age"); // in Java
 * }}}
 *
 * Note that the [[Column]] type can also be manipulated through its various functions.
 * {{{
 *   // The following creates a new column that increases everybody's age by 10.
 *   people("age") + 10  // in Scala
 *   people.col("age").plus(10);  // in Java
 * }}}
 *
 * A more concrete example in Scala:
 * {{{
 *   // To create Dataset[Row] using SparkSession
 *   val people = spark.read.parquet("...")
 *   val department = spark.read.parquet("...")
 *
 *   people.filter("age > 30")
 *     .join(department, people("deptId") === department("id"))
 *     .groupBy(department("name"), people("gender"))
 *     .agg(avg(people("salary")), max(people("age")))
 * }}}
 *
 * and in Java:
 * {{{
 *   // To create Dataset<Row> using SparkSession
 *   Dataset<Row> people = spark.read().parquet("...");
 *   Dataset<Row> department = spark.read().parquet("...");
 *
 *   people.filter(people.col("age").gt(30))
 *     .join(department, people.col("deptId").equalTo(department.col("id")))
 *     .groupBy(department.col("name"), people.col("gender"))
 *     .agg(avg(people.col("salary")), max(people.col("age")));
 * }}}
 *
 * @groupname basic Basic Dataset functions
 * @groupname action Actions
 * @groupname untypedrel Untyped transformations
 * @groupname typedrel Typed transformations
 *
 * @since 1.6.0
 */
@InterfaceStability.Stable
class Dataset[T] private[sql](
    @transient val sparkSession: SparkSession,
    @DeveloperApi @InterfaceStability.Unstable @transient val queryExecution: QueryExecution,
    encoder: Encoder[T])
  extends Serializable {

Paraphrasing this documentation:

A Dataset is a strongly typed collection of domain-specific objects that can be transformed in parallel using functional or relational operations. Each Dataset also has an untyped view called a DataFrame, which is a Dataset of Row.

The operations available on a Dataset fall into two categories: transformations and actions. Transformations produce new Datasets, while actions trigger computation and return results. Common transformations include map, filter, select, and aggregations such as groupBy; common actions include count, show, and writing data out to external file systems.

Datasets are lazy: computation is only triggered when an action is invoked. Internally, a Dataset represents a logical plan describing the computation required to produce the data. When an action is invoked, Spark's query optimizer optimizes that logical plan and generates a physical plan for efficient parallel and distributed execution. The explain function can be used to inspect the logical plan as well as the optimized physical plan.

An Encoder is required to support domain-specific objects efficiently. The encoder maps the domain-specific type T onto Spark's internal type system. For example, given a Person class with fields name (string) and age (int), the encoder tells Spark to generate code at runtime that serializes a Person object into a binary structure. This binary structure usually has a much smaller memory footprint and is optimized for efficient data processing (for example, a columnar format). The schema function can be used to inspect the internal binary representation of the data.

There are two common ways to create a Dataset. The most common is to point Spark at files on a storage system, using the read function available on SparkSession:

val people = spark.read.parquet("...").as[Person]  // Scala
Dataset<Person> people = spark.read().parquet("...").as(Encoders.bean(Person.class)); // Java

A Dataset can also be created by applying a transformation to an existing Dataset, for example by calling map on it:

val names = people.map(_.name)  // in Scala; names is a Dataset[String]
Dataset<String> names = people.map((Person p) -> p.name, Encoders.STRING);

Dataset operations can also be untyped, through the various domain-specific-language (DSL) functions defined in Dataset, Column, and functions. These operations are very similar to those available on the data frame abstraction in R or Python.

To select a column from a Dataset, use the apply method in Scala and the col method in Java:

val ageCol = people("age")  // in Scala
Column ageCol = people.col("age"); // in Java

Note that the Column type can also be manipulated through its various functions:

// The following creates a new column that increases everybody's age by 10.
people("age") + 10  // in Scala
people.col("age").plus(10);  // in Java

A more concrete example in Scala:

// To create Dataset[Row] using SparkSession
val people = spark.read.parquet("...")
val department = spark.read.parquet("...")

people.filter("age > 30")
.join(department, people("deptId") === department("id"))
.groupBy(department("name"), people("gender"))
.agg(avg(people("salary")), max(people("age")))

And in Java:

// To create Dataset<Row> using SparkSession
Dataset<Row> people = spark.read().parquet("...");
Dataset<Row> department = spark.read().parquet("...");

people.filter(people.col("age").gt(30))
    .join(department, people.col("deptId").equalTo(department.col("id")))
    .groupBy(department.col("name"), people.col("gender"))
    .agg(avg(people.col("salary")), max(people.col("age")));

As this shows, the entry point for loading data into a Dataset is SparkSession's read method. There is also a readStream method; as the names suggest, read is used for batch reads while readStream is used for streaming reads:

  /**
   * Returns a [[DataFrameReader]] that can be used to read non-streaming data in as a
   * `DataFrame`.
   * {{{
   *   sparkSession.read.parquet("/path/to/file.parquet")
   *   sparkSession.read.schema(schema).json("/path/to/file.json")
   * }}}
   *
   * @since 2.0.0
   */
  def read: DataFrameReader = new DataFrameReader(self)

  /**
   * Returns a `DataStreamReader` that can be used to read streaming data in as a `DataFrame`.
   * {{{
   *   sparkSession.readStream.parquet("/path/to/directory/of/parquet/files")
   *   sparkSession.readStream.schema(schema).json("/path/to/directory/of/json/files")
   * }}}
   *
   * @since 2.0.0
   */
  @InterfaceStability.Evolving
  def readStream: DataStreamReader = new DataStreamReader(self)
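
As a quick usage sketch (assuming an existing SparkSession named spark; paths and options are placeholders, not taken from the article's example):

// Batch entry point: configure a DataFrameReader fluently, then load
val batchDf = spark.read
  .format("json")
  .option("multiLine", "true")
  .load("/path/to/file.json")

// Streaming entry point: the same fluent style on a DataStreamReader;
// file-based streaming sources generally need an explicit schema up front
val streamDf = spark.readStream
  .schema(batchDf.schema)
  .json("/path/to/directory/of/json/files")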

This also shows that a Dataset carries a flag indicating whether its sources are streaming:

  /**
   * Returns true if this Dataset contains one or more sources that continuously
   * return data as it arrives. A Dataset that reads data from a streaming source
   * must be executed as a `StreamingQuery` using the `start()` method in
   * `DataStreamWriter`. Methods that return a single answer, e.g. `count()` or
   * `collect()`, will throw an [[AnalysisException]] when there is a streaming
   * source present.
   *
   * @group streaming
   * @since 2.0.0
   */
  @InterfaceStability.Evolving
  def isStreaming: Boolean = logicalPlan.isStreaming
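
A small sketch of how this flag behaves (again assuming an existing SparkSession spark and placeholder paths):

import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

val schema = StructType(Seq(StructField("name", StringType), StructField("age", IntegerType)))

spark.read.parquet("/path/to/file.parquet").isStreaming        // false

val streamDf = spark.readStream.schema(schema).parquet("/path/to/directory/of/parquet/files")
streamDf.isStreaming                                           // true
// streamDf.count()  // would throw AnalysisException because a streaming source is present;
//                   // such a Dataset must be executed via writeStream...start()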

For a streaming data source, readStream is called to obtain a DataStreamReader; for batch data, read is called to obtain a DataFrameReader.

In this article we look at the batch path first: reading data to produce a DataFrame/Dataset.

Take reading a Parquet file as an example:

sparkSession.read.parquet("/path/to/file.parquet")

Locate the parquet method on DataFrameReader:

  /**
   * Loads a Parquet file, returning the result as a `DataFrame`. See the documentation
   * on the other overloaded `parquet()` method for more details.
   *
   * @since 2.0.0
   */
  def parquet(path: String): DataFrame = {
    // This method ensures that calls that explicit need single argument works, see SPARK-16009
    parquet(Seq(path): _*)
  }

It delegates further to the varargs overload:

  /**
   * Loads a Parquet file, returning the result as a `DataFrame`.
   *
   * You can set the following Parquet-specific option(s) for reading Parquet files:
   * <ul>
   * <li>`mergeSchema` (default is the value specified in `spark.sql.parquet.mergeSchema`): sets
   * whether we should merge schemas collected from all Parquet part-files. This will override
   * `spark.sql.parquet.mergeSchema`.</li>
   * </ul>
   * @since 1.4.0
   */
  @scala.annotation.varargs
  def parquet(paths: String*): DataFrame = {
    format("parquet").load(paths: _*)
  }
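
The mergeSchema option mentioned in this comment can be supplied per read via option, or session-wide through the config it overrides; a sketch with placeholder paths (assuming an existing SparkSession spark):

// Per read: merge the schemas collected from all Parquet part-files
val merged = spark.read
  .option("mergeSchema", "true")
  .parquet("/path/to/table/key=1", "/path/to/table/key=2")

// Or session-wide, via the config that the option overrides
spark.conf.set("spark.sql.parquet.mergeSchema", "true")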

Inside, the format method is called to set the data source format to parquet:

  /**
   * Specifies the input data source format.
   *
   * @since 1.4.0
   */
  def format(source: String): DataFrameReader = {
    this.source = source
    this
  }
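
In other words, the parquet shortcut is just sugar over the generic format/load pair; the two reads below should be equivalent (placeholder path, existing SparkSession spark assumed):

val df1 = spark.read.parquet("/path/to/file.parquet")
val df2 = spark.read.format("parquet").load("/path/to/file.parquet")

The same format/load pair is how external or third-party data sources are addressed, with format taking the source's short name or fully qualified class name.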

Then DataFrameReader's load method is invoked:

  /**
   * Loads input in as a `DataFrame`, for data sources that support multiple paths.
   * Only works if the source is a HadoopFsRelationProvider.
   *
   * @since 1.6.0
   */
  @scala.annotation.varargs
  def load(paths: String*): DataFrame = {
    if (source.toLowerCase(Locale.ROOT) == DDLUtils.HIVE_PROVIDER) {
      throw new AnalysisException("Hive data source can only be used with tables, you can not " +
        "read files of Hive data source directly.")
    }

    val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf)
    if (classOf[DataSourceV2].isAssignableFrom(cls)) {
      val ds = cls.newInstance().asInstanceOf[DataSourceV2]
      if (ds.isInstanceOf[ReadSupport]) {
        val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
          ds = ds, conf = sparkSession.sessionState.conf)
        val pathsOption = {
          val objectMapper = new ObjectMapper()
          DataSourceOptions.PATHS_KEY -> objectMapper.writeValueAsString(paths.toArray)
        }
        Dataset.ofRows(sparkSession, DataSourceV2Relation.create(
          ds, sessionOptions ++ extraOptions.toMap + pathsOption,
          userSpecifiedSchema = userSpecifiedSchema))
      } else {
        loadV1Source(paths: _*)
      }
    } else {
      loadV1Source(paths: _*)
    }
  }

Let's walk through the flow of this load method. In outline: load first resolves the data source implementation class via DataSource.lookupDataSource; if that class implements DataSourceV2 and provides ReadSupport, it builds a DataSourceV2Relation, otherwise it falls back to loadV1Source.

Let's first focus on the DataSource.lookupDataSource method, which looks up the implementation class for the given source mainly via the Java SPI (ServiceLoader) mechanism:

  /** Given a provider name, look up the data source class definition. */
  def lookupDataSource(provider: String, conf: SQLConf): Class[_] = {
    val provider1 = backwardCompatibilityMap.getOrElse(provider, provider) match {
      case name if name.equalsIgnoreCase("orc") &&
          conf.getConf(SQLConf.ORC_IMPLEMENTATION) == "native" =>
        classOf[OrcFileFormat].getCanonicalName
      case name if name.equalsIgnoreCase("orc") &&
          conf.getConf(SQLConf.ORC_IMPLEMENTATION) == "hive" =>
        "org.apache.spark.sql.hive.orc.OrcFileFormat"
      case "com.databricks.spark.avro" if conf.replaceDatabricksSparkAvroEnabled =>
        "org.apache.spark.sql.avro.AvroFileFormat"
      case name => name
    }
    val provider2 = s"$provider1.DefaultSource"
    val loader = Utils.getContextOrSparkClassLoader
    val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader)

    try {
      serviceLoader.asScala.filter(_.shortName().equalsIgnoreCase(provider1)).toList match {
        // the provider format did not match any given registered aliases
        case Nil =>
          try {
            Try(loader.loadClass(provider1)).orElse(Try(loader.loadClass(provider2))) match {
              case Success(dataSource) =>
                // Found the data source using fully qualified path
                dataSource
              case Failure(error) =>
                if (provider1.startsWith("org.apache.spark.sql.hive.orc")) {
                  throw new AnalysisException(
                    "Hive built-in ORC data source must be used with Hive support enabled. " +
                    "Please use the native ORC data source by setting 'spark.sql.orc.impl' to " +
                    "'native'")
                } else if (provider1.toLowerCase(Locale.ROOT) == "avro" ||
                  provider1 == "com.databricks.spark.avro" ||
                  provider1 == "org.apache.spark.sql.avro") {
                  throw new AnalysisException(
                    s"Failed to find data source: $provider1. Avro is built-in but external data " +
                    "source module since Spark 2.4. Please deploy the application as per " +
                    "the deployment section of \"Apache Avro Data Source Guide\".")
                } else if (provider1.toLowerCase(Locale.ROOT) == "kafka") {
                  throw new AnalysisException(
                    s"Failed to find data source: $provider1. Please deploy the application as " +
                    "per the deployment section of " +
                    "\"Structured Streaming + Kafka Integration Guide\".")
                } else {
                  throw new ClassNotFoundException(
                    s"Failed to find data source: $provider1. Please find packages at " +
                      "http://spark.apache.org/third-party-projects.html",
                    error)
                }
            }
          } catch {
            case e: NoClassDefFoundError => // This one won't be caught by Scala NonFatal
              // NoClassDefFoundError's class name uses "/" rather than "." for packages
              val className = e.getMessage.replaceAll("/", ".")
              if (spark2RemovedClasses.contains(className)) {
                throw new ClassNotFoundException(s"$className was removed in Spark 2.0. " +
                  "Please check if your library is compatible with Spark 2.0", e)
              } else {
                throw e
              }
          }
        case head :: Nil =>
          // there is exactly one registered alias
          head.getClass
        case sources =>
          // There are multiple registered aliases for the input. If there is single datasource
          // that has "org.apache.spark" package in the prefix, we use it considering it is an
          // internal datasource within Spark.
          val sourceNames = sources.map(_.getClass.getName)
          val internalSources = sources.filter(_.getClass.getName.startsWith("org.apache.spark"))
          if (internalSources.size == 1) {
            logWarning(s"Multiple sources found for $provider1 (${sourceNames.mkString(", ")}), " +
              s"defaulting to the internal datasource (${internalSources.head.getClass.getName}).")
            internalSources.head.getClass
          } else {
            throw new AnalysisException(s"Multiple sources found for $provider1 " +
              s"(${sourceNames.mkString(", ")}), please specify the fully qualified class name.")
          }
      }
    } catch {
      case e: ServiceConfigurationError if e.getCause.isInstanceOf[NoClassDefFoundError] =>
        // NoClassDefFoundError's class name uses "/" rather than "." for packages
        val className = e.getCause.getMessage.replaceAll("/", ".")
        if (spark2RemovedClasses.contains(className)) {
          throw new ClassNotFoundException(s"Detected an incompatible DataSourceRegister. " +
            "Please remove the incompatible library from classpath or upgrade it. " +
            s"Error: ${e.getMessage}", e)
        } else {
          throw e
        }
    }
  }
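
For context, this is the standard Java ServiceLoader mechanism: a data source library implements DataSourceRegister (usually together with a provider trait such as RelationProvider) and lists the implementation class in META-INF/services/org.apache.spark.sql.sources.DataSourceRegister. A hedged sketch of what such a registration might look like (the package, class name, and short name below are hypothetical):

package com.example.datasource

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider}

// A minimal V1-style provider registering the short name "example"
class DefaultSource extends RelationProvider with DataSourceRegister {

  override def shortName(): String = "example"

  override def createRelation(
      sqlContext: SQLContext,
      parameters: Map[String, String]): BaseRelation = {
    // Build and return a BaseRelation for the given options (omitted in this sketch)
    ???
  }
}

// In the library's resources, one line registers the class for ServiceLoader:
// META-INF/services/org.apache.spark.sql.sources.DataSourceRegister
//   com.example.datasource.DefaultSource

With that in place, spark.read.format("example").load(...) would resolve to this class through the ServiceLoader path above; built-in formats such as parquet are matched the same way via their shortName.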

For the parquet format, Spark already ships a built-in implementation class, namely org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat.

ParquetFileFormat is not a DataSourceV2 implementation, so execution takes the loadV1Source branch. That branch first builds a DataSource object via DataSource.apply, then calls resolveRelation to produce a BaseRelation (in practice an org.apache.spark.sql.execution.datasources.HadoopFsRelation), and finally calls SparkSession's baseRelationToDataFrame to turn that BaseRelation into a DataFrame. baseRelationToDataFrame in turn calls Dataset.ofRows, which news up the Dataset object, as sketched below.
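
A simplified sketch of that call chain, i.e. roughly what loadV1Source and baseRelationToDataFrame look like in this version of Spark (treat it as an outline rather than an exact listing):

  private def loadV1Source(paths: String*) = {
    // Code path for data source v1: build a DataSource, resolve it to a BaseRelation,
    // then wrap that relation in a DataFrame
    sparkSession.baseRelationToDataFrame(
      DataSource.apply(
        sparkSession,
        paths = paths,
        userSpecifiedSchema = userSpecifiedSchema,
        className = source,
        options = extraOptions.toMap).resolveRelation())
  }

  def baseRelationToDataFrame(baseRelation: BaseRelation): DataFrame = {
    Dataset.ofRows(self, LogicalRelation(baseRelation))
  }

And Dataset.ofRows, which constructs the final Dataset: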

  def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = {
    val qe = sparkSession.sessionState.executePlan(logicalPlan)
    qe.assertAnalyzed()
    new Dataset[Row](sparkSession, qe, RowEncoder(qe.analyzed.schema))
  }
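
To see the end result of this chain, you can inspect the plans behind the returned DataFrame; a quick sketch with a placeholder path (existing SparkSession spark assumed):

val df = spark.read.parquet("/path/to/file.parquet")

// For the V1 path, the analyzed logical plan is a LogicalRelation wrapping a HadoopFsRelation
df.explain(true)             // parsed/analyzed/optimized logical plans plus the physical plan
println(df.queryExecution)   // the QueryExecution handed to the Dataset constructor above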

At this point the flow by which Spark's data input produces a DataFrame/Dataset is clear. Of course, no data has actually been loaded yet; in a later article we will look at how the source data is actually read when the job runs.