PySpark のテスト#

このガイドは、堅牢な PySpark コードのテストを作成するためのリファレンスです。

PySpark テストユーティリティのドキュメントを表示するには、こちらをご覧ください。

PySpark アプリケーションの構築#

PySpark アプリケーションを開始する方法の例を以下に示します。すでにテストする準備ができているアプリケーションがある場合は、「PySpark アプリケーションのテスト」セクションにスキップしてください。

まず、Spark Session を開始します。

[3]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

# Create a SparkSession
spark = SparkSession.builder.appName("Testing PySpark Example").getOrCreate()

次に、DataFrame を作成します。

[5]:
sample_data = [{"name": "John    D.", "age": 30},
  {"name": "Alice   G.", "age": 25},
  {"name": "Bob  T.", "age": 35},
  {"name": "Eve   A.", "age": 28}]

df = spark.createDataFrame(sample_data)

次に、DataFrame に変換関数を定義して適用しましょう。

[7]:
from pyspark.sql.functions import col, regexp_replace

# Remove additional spaces in name
def remove_extra_spaces(df, column_name):
    # Remove extra spaces from the specified column
    df_transformed = df.withColumn(column_name, regexp_replace(col(column_name), "\\s+", " "))

    return df_transformed

transformed_df = remove_extra_spaces(df, "name")

transformed_df.show()
+---+--------+
|age|    name|
+---+--------+
| 30| John D.|
| 25|Alice G.|
| 35|  Bob T.|
| 28|  Eve A.|
+---+--------+

PySpark アプリケーションのテスト#

では、PySpark 変換関数をテストしましょう。

1 つの方法は、結果の DataFrame を目視で確認することです。ただし、これは大きな DataFrame や入力サイズの場合、実用的でない可能性があります。

より良い方法は、テストを作成することです。以下に、コードをテストする方法の例をいくつか示します。以下の例は、Spark 3.5 以降のバージョンに適用されます。

これらの例は網羅的ではないことに注意してください。なぜなら、unittestpytest の代わりに利用できる多くのテストフレームワークの代替手段があるからです。組み込みの PySpark テストユーティリティ関数はスタンドアロンであり、任意のテストフレームワークまたは CI テストパイプラインと互換性があります。

オプション 1: PySpark 組み込みテストユーティリティ関数のみを使用する#

簡単なアドホック検証ケースでは、assertDataFrameEqualassertSchemaEqual などの PySpark テストユーティリティをスタンドアロンコンテキストで使用できます。ノートブックセッションで PySpark コードを簡単にテストできます。たとえば、2 つの DataFrame の等価性をアサートしたい場合

[10]:
import pyspark.testing
from pyspark.testing.utils import assertDataFrameEqual

# Example 1
df1 = spark.createDataFrame(data=[("1", 1000), ("2", 3000)], schema=["id", "amount"])
df2 = spark.createDataFrame(data=[("1", 1000), ("2", 3000)], schema=["id", "amount"])
assertDataFrameEqual(df1, df2)  # pass, DataFrames are identical
[11]:
# Example 2
df1 = spark.createDataFrame(data=[("1", 0.1), ("2", 3.23)], schema=["id", "amount"])
df2 = spark.createDataFrame(data=[("1", 0.109), ("2", 3.23)], schema=["id", "amount"])
assertDataFrameEqual(df1, df2, rtol=1e-1)  # pass, DataFrames are approx equal by rtol

2 つの DataFrame スキーマを比較することもできます。

[13]:
from pyspark.testing.utils import assertSchemaEqual
from pyspark.sql.types import StructType, StructField, ArrayType, DoubleType

s1 = StructType([StructField("names", ArrayType(DoubleType(), True), True)])
s2 = StructType([StructField("names", ArrayType(DoubleType(), True), True)])

assertSchemaEqual(s1, s2)  # pass, schemas are identical

オプション 2: Unit Test の使用#

より複雑なテストシナリオでは、テストフレームワークの使用を検討したくなるかもしれません。

最も人気のあるテストフレームワークのオプションの 1 つが単体テストです。組み込みの Python unittest ライブラリを使用して PySpark テストを作成する方法を説明します。

まず、Spark セッションが必要です。Spark セッションのセットアップとティアダウンを処理するために、unittest パッケージの @classmethod デコレーターを使用できます。

[15]:
import unittest

class PySparkTestCase(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.spark = SparkSession.builder.appName("Testing PySpark Example").getOrCreate()


    @classmethod
    def tearDownClass(cls):
        cls.spark.stop()

次に、unittest クラスを作成しましょう。

[17]:
from pyspark.testing.utils import assertDataFrameEqual

class TestTranformation(PySparkTestCase):
    def test_single_space(self):
        sample_data = [{"name": "John    D.", "age": 30},
                       {"name": "Alice   G.", "age": 25},
                       {"name": "Bob  T.", "age": 35},
                       {"name": "Eve   A.", "age": 28}]

        # Create a Spark DataFrame
        original_df = spark.createDataFrame(sample_data)

        # Apply the transformation function from before
        transformed_df = remove_extra_spaces(original_df, "name")

        expected_data = [{"name": "John D.", "age": 30},
        {"name": "Alice G.", "age": 25},
        {"name": "Bob T.", "age": 35},
        {"name": "Eve A.", "age": 28}]

        expected_df = spark.createDataFrame(expected_data)

        assertDataFrameEqual(transformed_df, expected_df)

実行されると、unittest は「test」で始まるすべての関数を検出します。

オプション 3: Pytest の使用#

最も人気のある Python テストフレームワークの 1 つである pytest を使用してテストを作成することもできます。

pytest フィクスチャを使用すると、テスト間で Spark セッションを共有し、テスト完了時にそれを破棄できます。

[20]:
import pytest

@pytest.fixture
def spark_fixture():
    spark = SparkSession.builder.appName("Testing PySpark Example").getOrCreate()
    yield spark

その後、テストを次のように定義できます。

[22]:
import pytest
from pyspark.testing.utils import assertDataFrameEqual

def test_single_space(spark_fixture):
    sample_data = [{"name": "John    D.", "age": 30},
                   {"name": "Alice   G.", "age": 25},
                   {"name": "Bob  T.", "age": 35},
                   {"name": "Eve   A.", "age": 28}]

    # Create a Spark DataFrame
    original_df = spark_fixture.createDataFrame(sample_data)

    # Apply the transformation function from before
    transformed_df = remove_extra_spaces(original_df, "name")

    expected_data = [{"name": "John D.", "age": 30},
    {"name": "Alice G.", "age": 25},
    {"name": "Bob T.", "age": 35},
    {"name": "Eve A.", "age": 28}]

    expected_df = spark_fixture.createDataFrame(expected_data)

    assertDataFrameEqual(transformed_df, expected_df)

pytest コマンドでテストファイルを実行すると、「test」で始まるすべての関数が検出されます。

すべてをまとめる!#

単体テストの例で、すべてのステップをまとめて見てみましょう。

[25]:
# pkg/etl.py
import unittest

from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.sql.functions import regexp_replace
from pyspark.testing.utils import assertDataFrameEqual

# Create a SparkSession
spark = SparkSession.builder.appName("Sample PySpark ETL").getOrCreate()

sample_data = [{"name": "John    D.", "age": 30},
  {"name": "Alice   G.", "age": 25},
  {"name": "Bob  T.", "age": 35},
  {"name": "Eve   A.", "age": 28}]

df = spark.createDataFrame(sample_data)

# Define DataFrame transformation function
def remove_extra_spaces(df, column_name):
    # Remove extra spaces from the specified column using regexp_replace
    df_transformed = df.withColumn(column_name, regexp_replace(col(column_name), "\\s+", " "))

    return df_transformed
[26]:
# pkg/test_etl.py
import unittest

from pyspark.sql import SparkSession

# Define unit test base class
class PySparkTestCase(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.spark = SparkSession.builder.appName("Sample PySpark ETL").getOrCreate()

    @classmethod
    def tearDownClass(cls):
        cls.spark.stop()

# Define unit test
class TestTranformation(PySparkTestCase):
    def test_single_space(self):
        sample_data = [{"name": "John    D.", "age": 30},
                        {"name": "Alice   G.", "age": 25},
                        {"name": "Bob  T.", "age": 35},
                        {"name": "Eve   A.", "age": 28}]

        # Create a Spark DataFrame
        original_df = spark.createDataFrame(sample_data)

        # Apply the transformation function from before
        transformed_df = remove_extra_spaces(original_df, "name")

        expected_data = [{"name": "John D.", "age": 30},
        {"name": "Alice G.", "age": 25},
        {"name": "Bob T.", "age": 35},
        {"name": "Eve A.", "age": 28}]

        expected_df = spark.createDataFrame(expected_data)

        assertDataFrameEqual(transformed_df, expected_df)
[27]:
unittest.main(argv=[''], verbosity=0, exit=False)
Ran 1 test in 1.734s

OK
[27]:
<unittest.main.TestProgram at 0x174539db0>