Python Data Source API#

概要#

Python Data Source APIは、Spark 4.0で導入された新機能で、開発者がPythonでカスタムデータソースからの読み込みやカスタムデータシンクへの書き込みを可能にします。このガイドでは、APIの包括的な概要と、Pythonデータソースの作成、使用、管理方法についての説明を提供します。

簡単な例#

ここでは、正確に2行の合成データを生成する簡単なPythonデータソースを示します。この例では、外部ライブラリを使用せずにカスタムデータソースをセットアップする方法を示し、すぐに利用を開始するために必要な基本に焦点を当てています。

ステップ 1: データソースの定義

from pyspark.sql.datasource import DataSource, DataSourceReader
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

class SimpleDataSource(DataSource):
    """
    A simple data source for PySpark that generates exactly two rows of synthetic data.
    """

    @classmethod
    def name(cls):
        return "simple"

    def schema(self):
        return StructType([
            StructField("name", StringType()),
            StructField("age", IntegerType())
        ])

    def reader(self, schema: StructType):
        return SimpleDataSourceReader()

class SimpleDataSourceReader(DataSourceReader):

    def read(self, partition):
        yield ("Alice", 20)
        yield ("Bob", 30)

ステップ 2: データソースの登録

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

spark.dataSource.register(SimpleDataSource)

ステップ 3: データソースからの読み込み

spark.read.format("simple").load().show()

# +-----+---+
# | name|age|
# +-----+---+
# |Alice| 20|
# |  Bob| 30|
# +-----+---+

Python データソースの作成#

カスタムPythonデータソースを作成するには、DataSource の基底クラスをサブクラス化し、データの読み書きに必要なメソッドを実装する必要があります。

この例では、faker ライブラリを使用して合成データを生成する簡単なデータソースの作成を示します。faker ライブラリがインストールされており、Python環境からアクセス可能であることを確認してください。

データソースの定義

ソース名とスキーマを持つ新しいDataSource のサブクラスを作成することから始めます。

バッチまたはストリーミングクエリでソースまたはシンクとして使用するには、DataSourceの対応するメソッドを実装する必要があります。

機能のために実装する必要のあるメソッド

ソース

シンク

バッチ

reader()

writer()

ストリーミング

streamReader() または simpleStreamReader()

streamWriter()

from pyspark.sql.datasource import DataSource, DataSourceReader
from pyspark.sql.types import StructType

class FakeDataSource(DataSource):
    """
    A fake data source for PySpark to generate synthetic data using the `faker` library.
    Options:
    - numRows: specify number of rows to generate. Default value is 3.
    """

    @classmethod
    def name(cls):
        return "fake"

    def schema(self):
        return "name string, date string, zipcode string, state string"

    def reader(self, schema: StructType):
        return FakeDataSourceReader(schema, self.options)

    def writer(self, schema: StructType, overwrite: bool):
        return FakeDataSourceWriter(self.options)

    def streamReader(self, schema: StructType):
        return FakeStreamReader(schema, self.options)

    # Please skip the implementation of this method if streamReader has been implemented.
    def simpleStreamReader(self, schema: StructType):
        return SimpleStreamReader()

    def streamWriter(self, schema: StructType, overwrite: bool):
        return FakeStreamWriter(self.options)

Python データソースのバッチリーダーとライターの実装#

リーダーの実装

合成データを生成するためのリーダーロジックを定義します。スキーマの各フィールドを埋めるためにfaker ライブラリを使用します。

class FakeDataSourceReader(DataSourceReader):

    def __init__(self, schema, options):
        self.schema: StructType = schema
        self.options = options

    def read(self, partition):
        from faker import Faker
        fake = Faker()
        # Note: every value in this `self.options` dictionary is a string.
        num_rows = int(self.options.get("numRows", 3))
        for _ in range(num_rows):
            row = []
            for field in self.schema.fields:
                value = getattr(fake, field.name)()
                row.append(value)
            yield tuple(row)

ライターの実装

各データパーティションを処理し、行数をカウントし、書き込みが成功した後に合計行数を表示するか、書き込みプロセスが失敗した場合は失敗したタスクの数を表示する、ダミーのデータソースライターを作成します。

from dataclasses import dataclass
from typing import Iterator, List

from pyspark.sql.types import Row
from pyspark.sql.datasource import DataSource, DataSourceWriter, WriterCommitMessage

@dataclass
class SimpleCommitMessage(WriterCommitMessage):
    partition_id: int
    count: int

class FakeDataSourceWriter(DataSourceWriter):

    def write(self, rows: Iterator[Row]) -> SimpleCommitMessage:
        from pyspark import TaskContext

        context = TaskContext.get()
        partition_id = context.partitionId()
        cnt = sum(1 for _ in rows)
        return SimpleCommitMessage(partition_id=partition_id, count=cnt)

    def commit(self, messages: List[SimpleCommitMessage]) -> None:
        total_count = sum(message.count for message in messages)
        print(f"Total number of rows: {total_count}")

    def abort(self, messages: List[SimpleCommitMessage]) -> None:
        failed_count = sum(message is None for message in messages)
        print(f"Number of failed tasks: {failed_count}")

Python データソースのストリーミングリーダーとライターの実装#

ストリームリーダーの実装

これは、各マイクロバッチで2行を生成するダミーのストリーミングデータリーダーです。streamReaderインスタンスは、各マイクロバッチで2ずつ増加する整数オフセットを持っています。

class RangePartition(InputPartition):
    def __init__(self, start, end):
        self.start = start
        self.end = end

class FakeStreamReader(DataSourceStreamReader):
    def __init__(self, schema, options):
        self.current = 0

    def initialOffset(self) -> dict:
        """
        Return the initial start offset of the reader.
        """
        return {"offset": 0}

    def latestOffset(self) -> dict:
        """
        Return the current latest offset that the next microbatch will read to.
        """
        self.current += 2
        return {"offset": self.current}

    def partitions(self, start: dict, end: dict):
        """
        Plans the partitioning of the current microbatch defined by start and end offset,
        it needs to return a sequence of :class:`InputPartition` object.
        """
        return [RangePartition(start["offset"], end["offset"])]

    def commit(self, end: dict):
        """
        This is invoked when the query has finished processing data before end offset, this can be used to clean up resource.
        """
        pass

    def read(self, partition) -> Iterator[Tuple]:
        """
        Takes a partition as an input and read an iterator of tuples from the data source.
        """
        start, end = partition.start, partition.end
        for i in range(start, end):
            yield (i, str(i))

シンプルなストリームリーダーの実装

データソースのスループットが低く、パーティショニングを必要としない場合は、DataSourceStreamReaderの代わりにSimpleDataSourceStreamReaderを実装できます。

readable streaming data source については、simpleStreamReader() と streamReader() のいずれかを実装する必要があります。simpleStreamReader() は streamReader() が実装されていない場合にのみ呼び出されます。

これは、SimpleDataSourceStreamReaderインターフェースで実装された、各バッチで2行を生成する同じダミーのストリーミングリーダーです。

class SimpleStreamReader(SimpleDataSourceStreamReader):
    def initialOffset(self):
        """
        Return the initial start offset of the reader.
        """
        return {"offset": 0}

    def read(self, start: dict) -> (Iterator[Tuple], dict):
        """
        Takes start offset as an input, return an iterator of tuples and the start offset of next read.
        """
        start_idx = start["offset"]
        it = iter([(i,) for i in range(start_idx, start_idx + 2)])
        return (it, {"offset": start_idx + 2})

    def readBetweenOffsets(self, start: dict, end: dict) -> Iterator[Tuple]:
        """
        Takes start and end offset as input and read an iterator of data deterministically.
        This is called whe query replay batches during restart or after failure.
        """
        start_idx = start["offset"]
        end_idx = end["offset"]
        return iter([(i,) for i in range(start_idx, end_idx)])

    def commit(self, end):
        """
        This is invoked when the query has finished processing data before end offset, this can be used to clean up resource.
        """
        pass

ストリームライターの実装

これは、各マイクロバッチのメタデータ情報をローカルパスに書き込むストリーミングデータライターです。

class SimpleCommitMessage(WriterCommitMessage):
   partition_id: int
   count: int

class FakeStreamWriter(DataSourceStreamWriter):
   def __init__(self, options):
       self.options = options
       self.path = self.options.get("path")
       assert self.path is not None

   def write(self, iterator):
       """
       Write the data and return the commit message of that partition
       """
       from pyspark import TaskContext
       context = TaskContext.get()
       partition_id = context.partitionId()
       cnt = 0
       for row in iterator:
           cnt += 1
       return SimpleCommitMessage(partition_id=partition_id, count=cnt)

   def commit(self, messages, batchId) -> None:
       """
       Receives a sequence of :class:`WriterCommitMessage` when all write tasks succeed and decides what to do with it.
       In this FakeStreamWriter, we write the metadata of the microbatch(number of rows and partitions) into a json file inside commit().
       """
       status = dict(num_partitions=len(messages), rows=sum(m.count for m in messages))
       with open(os.path.join(self.path, f"{batchId}.json"), "a") as file:
           file.write(json.dumps(status) + "\n")

   def abort(self, messages, batchId) -> None:
       """
       Receives a sequence of :class:`WriterCommitMessage` from successful tasks when some tasks fail and decides what to do with it.
       In this FakeStreamWriter, we write a failure message into a txt file inside abort().
       """
       with open(os.path.join(self.path, f"{batchId}.txt"), "w") as file:
           file.write(f"failed in batch {batchId}")

シリアライズ要件#

ユーザー定義のDataSource、DataSourceReader、DataSourceWriter、DataSourceStreamReader、およびDataSourceStreamWriterとそのメソッドは、pickleによってシリアライズ可能である必要があります。

メソッド内で使用されるライブラリは、メソッド内でインポートする必要があります。たとえば、以下のコードの `read()` メソッド内では TaskContext をインポートする必要があります。

def read(self, partition):
    from pyspark import TaskContext
    context = TaskContext.get()

Python データソースの使用#

Python データソースをバッチクエリで使用する

データソースを定義した後、使用する前に登録する必要があります。

spark.dataSource.register(FakeDataSource)

Python データソースからの読み込み

デフォルトのスキーマとオプションで、fake datasource から読み込みます。

spark.read.format("fake").load().show()

# +-----------+----------+-------+-------+
# |       name|      date|zipcode|  state|
# +-----------+----------+-------+-------+
# |Carlos Cobb|2018-07-15|  73003|Indiana|
# | Eric Scott|1991-08-22|  10085|  Idaho|
# | Amy Martin|1988-10-28|  68076| Oregon|
# +-----------+----------+-------+-------+

カスタムスキーマで、fake datasource から読み込みます。

spark.read.format("fake").schema("name string, company string").load().show()

# +---------------------+--------------+
# |name                 |company       |
# +---------------------+--------------+
# |Tanner Brennan       |Adams Group   |
# |Leslie Maxwell       |Santiago Group|
# |Mrs. Jacqueline Brown|Maynard Inc   |
# +---------------------+--------------+

異なる行数で、fake datasource から読み込みます。

spark.read.format("fake").option("numRows", 5).load().show()

# +--------------+----------+-------+------------+
# |          name|      date|zipcode|       state|
# +--------------+----------+-------+------------+
# |  Pam Mitchell|1988-10-20|  23788|   Tennessee|
# |Melissa Turner|1996-06-14|  30851|      Nevada|
# |  Brian Ramsey|2021-08-21|  55277|  Washington|
# |  Caitlin Reed|1983-06-22|  89813|Pennsylvania|
# | Douglas James|2007-01-18|  46226|     Alabama|
# +--------------+----------+-------+------------+

Python データソースへの書き込み

カスタム場所にデータを書き込むには、mode() 句を指定してください。サポートされているモードは appendoverwrite です。

df = spark.range(0, 10, 1, 5)
df.write.format("fake").mode("append").save()

# You can check the Spark log (standard error) to see the output of the write operation.
# Total number of rows: 10

Python データソースをストリーミングクエリで使用する

Python データソースを登録したら、`format()` に短い名前または完全名を渡すことで、`readStream()` のソースまたは `writeStream()` のシンクとしてストリーミングクエリでも使用できます。

fake Python データソースから読み込み、コンソールに書き込むクエリを開始します。

query = spark.readStream.format("fake").load().writeStream.format("console").start()

# +---+
# | id|
# +---+
# |  0|
# |  1|
# +---+
# +---+
# | id|
# +---+
# |  2|
# |  3|
# +---+

同じデータソースをストリーミングリーダーとライターでも使用できます。

query = spark.readStream.format("fake").load().writeStream.format("fake").start("/output_path")

パフォーマンス向上のための直接 Arrow バッチサポートを備えた Python Data Source Reader#

Python Datasource Reader は、Arrow バッチの直接生成をサポートしており、データ処理パフォーマンスを大幅に向上させることができます。効率的な Arrow フォーマットを使用することで、この機能は従来の行ごとのデータ処理のオーバーヘッドを回避し、特に大規模なデータセットで、パフォーマンスを最大1桁向上させます。

Arrow バッチサポートの有効化: この機能を有効にするには、カスタム DataSource を、DataSourceReader (または DataSourceStreamReader) 実装の read メソッド内で pyarrow.RecordBatch オブジェクトを返すことで、Arrow バッチを生成するように設定します。このメソッドは、データ処理を簡素化し、I/O 操作の数を削減します。これは、大規模なデータ処理タスクに特に役立ちます。

Arrow バッチの例: 以下は、Arrow バッチサポートを使用した基本的な Data Source の実装方法を示す例です。

from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
from pyspark.sql import SparkSession
import pyarrow as pa

# Define the ArrowBatchDataSource
class ArrowBatchDataSource(DataSource):
    """
    A Data Source for testing Arrow Batch Serialization
    """

    @classmethod
    def name(cls):
        return "arrowbatch"

    def schema(self):
        return "key int, value string"

    def reader(self, schema: str):
        return ArrowBatchDataSourceReader(schema, self.options)

# Define the ArrowBatchDataSourceReader
class ArrowBatchDataSourceReader(DataSourceReader):
    def __init__(self, schema, options):
        self.schema: str = schema
        self.options = options

    def read(self, partition):
        # Create Arrow Record Batch
        keys = pa.array([1, 2, 3, 4, 5], type=pa.int32())
        values = pa.array(["one", "two", "three", "four", "five"], type=pa.string())
        schema = pa.schema([("key", pa.int32()), ("value", pa.string())])
        record_batch = pa.RecordBatch.from_arrays([keys, values], schema=schema)
        yield record_batch

    def partitions(self):
        # Define the number of partitions
        num_part = 1
        return [InputPartition(i) for i in range(num_part)]

# Initialize the Spark Session
spark = SparkSession.builder.appName("ArrowBatchExample").getOrCreate()

# Register the ArrowBatchDataSource
spark.dataSource.register(ArrowBatchDataSource)

# Load data using the custom data source
df = spark.read.format("arrowbatch").load()

df.show()

使用上の注意#

  • Data Source の解決中、同じ名前の組み込みおよび Scala/Java Data Source は Python Data Source よりも優先されます。Python Data Source を明示的に使用するには、その名前が他の Data Source と競合しないことを確認してください。