Python User-defined Table Functions (UDTFs)

Spark 3.5 introduces the Python user-defined table function (UDTF), a new type of user-defined function. Unlike scalar functions that return a single result value from each call, each UDTF is invoked in the FROM clause of a query and returns an entire table as output. Each UDTF call can accept zero or more arguments. These arguments can either be scalar expressions or table arguments that represent entire input tables.
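
For illustration, once a UDTF has been registered it is invoked in the FROM clause and can mix scalar and table arguments. In this minimal sketch the function name my_udtf and the table t are purely hypothetical; registration is covered later on this page.

# Illustrative only: `my_udtf` and `t` are hypothetical names; see the registration examples below.
spark.sql("SELECT * FROM my_udtf(42, TABLE(t))").show()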

Implementing a Python UDTF

To implement a Python UDTF, you first have to define a class implementing the following methods:

class PythonUDTF:

    def __init__(self) -> None:
        """
        Initializes the user-defined table function (UDTF). This is optional.

        This method serves as the default constructor and is called once when the
        UDTF is instantiated on the executor side.

        Any class fields assigned in this method will be available for subsequent
        calls to the `eval` and `terminate` methods. This class instance will remain
        alive until all rows in the current partition have been consumed by the `eval`
        method.

        Notes
        -----
        - You cannot create or reference the Spark session within the UDTF. Any
          attempt to do so will result in a serialization error.
        - If the below `analyze` method is implemented, it is also possible to define this
          method as: `__init__(self, analyze_result: AnalyzeResult)`. In this case, the result
          of the `analyze` method is passed into all future instantiations of this UDTF class.
          In this way, the UDTF may inspect the schema and metadata of the output table as
          needed during execution of other methods in this class. Note that it is possible to
          create a subclass of the `AnalyzeResult` class if desired for purposes of passing
          custom information generated just once during UDTF analysis to other method calls;
          this can be especially useful if this initialization is expensive.
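
        Examples
        --------
        A minimal, optional sketch: accept the `AnalyzeResult` produced by `analyze` and keep
        it for later use by `eval` and `terminate` (the stored field name is illustrative).

        >>> def __init__(self, analyze_result: AnalyzeResult) -> None:
        ...     self.num_output_columns = len(analyze_result.schema.fields)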
        """
        ...

    @staticmethod
    def analyze(*args: AnalyzeArgument) -> AnalyzeResult:
        """
        Static method to compute the output schema of a particular call to this function in
        response to the arguments provided.

        This method is optional and only needed if the registration of the UDTF did not provide
        a static output schema to be used for all calls to the function. In this context,
        `output schema` refers to the ordered list of the names and types of the columns in the
        function's result table.

        This method accepts zero or more parameters mapping 1:1 with the arguments provided to
        the particular UDTF call under consideration. Each parameter is an instance of the
        `AnalyzeArgument` class.

        `AnalyzeArgument` fields
        ------------------------
        dataType: DataType
            Indicates the type of the provided input argument to this particular UDTF call.
            For input table arguments, this is a StructType representing the table's columns.
        value: Optional[Any]
            The value of the provided input argument to this particular UDTF call. This is
            `None` for table arguments, or for scalar arguments whose values are not constant.
        isTable: bool
            This is true if the provided input argument to this particular UDTF call is a
            table argument.
        isConstantExpression: bool
            This is true if the provided input argument to this particular UDTF call is either a
            literal or other constant-foldable scalar expression.

        This method returns an instance of the `AnalyzeResult` class which includes the result
        table's schema as a StructType. If the UDTF accepts an input table argument, then the
        `AnalyzeResult` can also include a requested way to partition and order the rows of
        the input table across several UDTF calls. See below for more information about UDTF
        table arguments and how to call them in SQL queries, including the WITH SINGLE
        PARTITION clause (corresponding to the `withSinglePartition` field here), PARTITION BY
        clause (corresponding to the `partitionBy` field here), ORDER BY clause (corresponding
        to the `orderBy` field here), and passing table subqueries as arguments (corresponding
        to the `select` field here).

        `AnalyzeResult` fields
        ----------------------
        schema: StructType
            The schema of the result table.
        withSinglePartition: bool = False
            If True, the query planner will arrange a repartitioning operation from the previous
            execution stage such that all rows of the input table are consumed by the `eval`
            method from exactly one instance of the UDTF class.
        partitionBy: Sequence[PartitioningColumn] = field(default_factory=tuple)
            If non-empty, the query planner will arrange a repartitioning such that all rows
            with each unique combination of values of the partitioning expressions are consumed
            by a separate unique instance of the UDTF class.
        orderBy: Sequence[OrderingColumn] = field(default_factory=tuple)
            If non-empty, this specifies the requested ordering of rows within each partition.
        select: Sequence[SelectedColumn] = field(default_factory=tuple)
            If non-empty, this is a sequence of expressions that the UDTF is specifying for
            Catalyst to evaluate against the columns in the input TABLE argument. The UDTF then
            receives one input attribute for each name in the list, in the order they are
            listed.

        Notes
        -----
        - It is possible for the `analyze` method to accept the exact arguments expected,
          mapping 1:1 with the arguments provided to the UDTF call.
        - The `analyze` method can instead choose to accept positional arguments if desired
          (using `*args`) or keyword arguments (using `**kwargs`).

        Examples
        --------
        This is an `analyze` implementation that returns one output column for each word in the
        input string argument.

        >>> @staticmethod
        ... def analyze(text: AnalyzeArgument) -> AnalyzeResult:
        ...     schema = StructType()
        ...     for index, word in enumerate(text.value.split(" ")):
        ...         schema = schema.add(f"word_{index}", StringType())
        ...     return AnalyzeResult(schema=schema)

        Same as above, but using *args to accept the arguments.

        >>> @staticmethod
        ... def analyze(*args) -> AnalyzeResult:
        ...     assert len(args) == 1, "This function accepts one argument only"
        ...     assert args[0].dataType == StringType(), "Only string arguments are supported"
        ...     text = args[0].value
        ...     schema = StructType()
        ...     for index, word in enumerate(text.split(" ")):
        ...         schema = schema.add(f"word_{index}", StringType())
        ...     return AnalyzeResult(schema=schema)

        Same as above, but using **kwargs to accept the arguments.

        >>> @staticmethod
        ... def analyze(**kwargs) -> AnalyzeResult:
        ...     assert len(kwargs) == 1, "This function accepts one argument only"
        ...     assert "text" in kwargs, "An argument named 'text' is required"
        ...     assert kwargs["text"].dataType == StringType(), "Only strings are supported"
        ...     text = kwargs["text"].value
        ...     schema = StructType()
        ...     for index, word in enumerate(text.split(" ")):
        ...         schema = schema.add(f"word_{index}", StringType())
        ...     return AnalyzeResult(schema=schema)

        This is an `analyze` implementation that returns a constant output schema, but adds
        custom information in the result metadata to be consumed by future `__init__` method
        calls:

        >>> @staticmethod
        ... def analyze(text: AnalyzeArgument) -> AnalyzeResult:
        ...     @dataclass
        ...     class AnalyzeResultWithOtherMetadata(AnalyzeResult):
        ...         num_words: int
        ...         num_articles: int
        ...     words = text.value.split(" ")
        ...     return AnalyzeResultWithOtherMetadata(
        ...         schema=StructType()
        ...             .add("word", StringType())
        ...             .add("total", IntegerType()),
        ...         num_words=len(words),
        ...         num_articles=len([
        ...             word for word in words
        ...             if word == 'a' or word == 'an' or word == 'the']))

        This is an `analyze` implementation that returns a constant output schema, and also
        requests to select a subset of columns from the input table and for the input table to
        be partitioned across several UDTF calls based on the values of the `date` column.
        A SQL query may call this UDTF with a table argument like "SELECT * FROM udtf(TABLE(t))".
        Then this `analyze` method specifies additional constraints on the input table:
        (1) The input table must be partitioned across several UDTF calls based on the month
            of each value in the `date` column.
        (2) The rows within each partition will arrive ordered by the `date` column.
        (3) The UDTF will only receive the `date` and `word` columns from the input table.

        >>> @staticmethod
        ... def analyze(*args) -> AnalyzeResult:
        ...     assert len(args) == 1, "This function accepts one argument only"
        ...     assert args[0].isTable, "Only table arguments are supported"
        ...     return AnalyzeResult(
        ...         schema=StructType()
        ...             .add("month", DateType())
        ...             .add("longest_word", IntegerType()),
        ...         partitionBy=[
        ...             PartitioningColumn("extract(month from date)")],
        ...         orderBy=[
        ...             OrderingColumn("date")],
        ...         select=[
        ...             SelectedColumn("date"),
        ...             SelectedColumn(
        ...               name="length(word)",
        ...               alias="length_word")])
        """
        ...

    def eval(self, *args: Any) -> Iterator[Any]:
        """
        Evaluates the function using the given input arguments.

        This method is required and must be implemented.

        Argument Mapping:
        - Each provided scalar expression maps to exactly one value in the
          `*args` list.
        - Each provided table argument maps to a pyspark.sql.Row object containing
          the columns in the order they appear in the provided input table,
          and with the names computed by the query analyzer.

        This method is called on every input row, and can produce zero or more
        output rows. Each element in the output tuple corresponds to one column
        specified in the return type of the UDTF.

        Parameters
        ----------
        *args : Any
            Arbitrary positional arguments representing the input to the UDTF.

        Yields
        ------
        tuple
            A tuple, list, or pyspark.sql.Row object representing a single row in the UDTF
            result table. Yield as many times as needed to produce multiple rows.

        Notes
        -----
        - It is also possible for UDTFs to accept the exact arguments expected, along with
          their types.
        - UDTFs can instead accept keyword arguments during the function call if needed.
        - The `eval` method can raise a `SkipRestOfInputTableException` to indicate that the
          UDTF wants to skip consuming all remaining rows from the current partition of the
          input table. This will cause the UDTF to proceed directly to the `terminate` method.
        - The `eval` method can raise any other exception to indicate that the UDTF should be
          aborted entirely. This will cause the UDTF to skip the `terminate` method and proceed
          directly to the `cleanup` method, and then the exception will be propagated to the
          query processor causing the invoking query to fail.

        Examples
        --------
        This `eval` method returns one row and one column for each input.

        >>> def eval(self, x: int):
        ...     yield (x, )

        This `eval` method returns two rows and two columns for each input.

        >>> def eval(self, x: int, y: int):
        ...     yield (x + y, x - y)
        ...     yield (y + x, y - x)

        Same as above, but using *args to accept the arguments:

        >>> def eval(self, *args):
        ...     assert len(args) == 2, "This function accepts two integer arguments only"
        ...     x = args[0]
        ...     y = args[1]
        ...     yield (x + y, x - y)
        ...     yield (y + x, y - x)

        Same as above, but using **kwargs to accept the arguments:

        >>> def eval(self, **kwargs):
        ...     assert len(kwargs) == 2, "This function accepts two integer arguments only"
        ...     x = kwargs["x"]
        ...     y = kwargs["y"]
        ...     yield (x + y, x - y)
        ...     yield (y + x, y - x)
        """
        ...

    def terminate(self) -> Iterator[Any]:
        """
        Called when the UDTF has successfully processed all input rows.

        This method is optional to implement and is useful for performing any
        finalization operations after the UDTF has finished processing
        all rows. It can also be used to yield additional rows if needed.
        Table functions that consume all rows in the entire input partition
        and then compute and return the entire output table can do so from
        this method as well (please be mindful of memory usage when doing
        this).

        If any exceptions occur during input row processing, this method
        won't be called.

        Yields
        ------
        tuple
            A tuple representing a single row in the UDTF result table.
            Yield this if you want to return additional rows during termination.

        Examples
        --------
        >>> def terminate(self) -> Iterator[Any]:
        ...     yield "done", None
        """
        ...

    def cleanup(self) -> None:
        """
        Invoked after the UDTF completes processing input rows.

        This method is optional to implement and is useful for final cleanup
        regardless of whether the UDTF processed all input rows successfully
        or was aborted due to exceptions.

        Examples
        --------
        >>> def cleanup(self) -> None:
        ...     self.conn.close()
        """
        ...

Defining the Output Schema

The return type of the UDTF defines the schema of the table it outputs.

You can specify it either after the @udtf decorator or as the result from the analyze method.

It must be either a StructType, for example:

StructType().add("c1", StringType())

Or a DDL string representing a struct type, for example:

c1: string
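
For reference, either form can be passed as the return type when declaring a UDTF. The sketch below shows both spellings of the same one-column schema; the class and column names are purely illustrative.

from pyspark.sql.functions import udtf
from pyspark.sql.types import StringType, StructType

# Both declarations describe the same one-column output schema.
@udtf(returnType=StructType().add("c1", StringType()))
class EchoStructSchema:
    def eval(self, x: str):
        yield (x,)

@udtf(returnType="c1: string")
class EchoDDLSchema:
    def eval(self, x: str):
        yield (x,)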

Emitting Output Rows

The eval and terminate methods then emit zero or more output rows conforming to this schema by yielding tuples, lists, or pyspark.sql.Row objects.

For example, here we return a row by providing a tuple of three elements:

def eval(self, x, y, z):
    yield (x, y, z)

It is also acceptable to omit the parentheses:

def eval(self, x, y, z):
    yield x, y, z

Remember to add a trailing comma if returning a row with only one column!

def eval(self, x, y, z):
    yield x,

It is also possible to yield a pyspark.sql.Row object.

def eval(self, x, y, z):
    from pyspark.sql.types import Row
    yield Row(x, y, z)

Here is an example of yielding output rows from the terminate method using a Python list. Typically, the purpose is to store state inside the class from earlier steps of the UDTF evaluation.

def terminate(self):
    yield [self.x, self.y, self.z]

Registering and Using Python UDTFs in SQL

Python UDTFs can be registered and used in SQL queries.

from pyspark.sql.functions import udtf

@udtf(returnType="word: string")
class WordSplitter:
    def eval(self, text: str):
        for word in text.split(" "):
            yield (word.strip(),)

# Register the UDTF for use in Spark SQL.
spark.udtf.register("split_words", WordSplitter)

# Example: Using the UDTF in SQL.
spark.sql("SELECT * FROM split_words('hello world')").show()
# +-----+
# | word|
# +-----+
# |hello|
# |world|
# +-----+

# Example: Using the UDTF with a lateral join in SQL.
# The lateral join allows us to reference the columns and aliases
# in the previous FROM clause items as inputs to the UDTF.
spark.sql(
    "SELECT * FROM VALUES ('Hello World'), ('Apache Spark') t(text), "
    "LATERAL split_words(text)"
).show()
# +------------+------+
# |        text|  word|
# +------------+------+
# | Hello World| Hello|
# | Hello World| World|
# |Apache Spark|Apache|
# |Apache Spark| Spark|
# +------------+------+

Arrow Optimization

Apache Arrow is an in-memory columnar data format used in Spark to efficiently transfer data between Java and Python processes. Apache Arrow is disabled by default for Python UDTFs.

Arrow can improve performance when each input row generates a large result table from the UDTF.

To enable Arrow optimization, set the spark.sql.execution.pythonUDTF.arrow.enabled configuration to true. You can also enable it by specifying the useArrow parameter when declaring the UDTF.
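
For example, the configuration can be enabled for the current session as in this minimal sketch (it can equally be set in spark-defaults.conf or when building the session):

# Enable Arrow-optimized execution for all Python UDTFs in this session.
spark.conf.set("spark.sql.execution.pythonUDTF.arrow.enabled", "true")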

from pyspark.sql.functions import udtf

@udtf(returnType="c1: int, c2: int", useArrow=True)
class PlusOne:
    def eval(self, x: int):
        yield x, x + 1

For more details, see Apache Arrow in PySpark.

UDTF Examples with Scalar Arguments

Here is a simple example of a UDTF class implementation:

# Define the UDTF class and implement the required `eval` method.
class SquareNumbers:
    def eval(self, start: int, end: int):
        for num in range(start, end + 1):
            yield (num, num * num)

To make use of the UDTF, you first need to instantiate it using the udtf function:

from pyspark.sql.functions import lit, udtf

# Create a UDTF using the class definition and the `udtf` function.
square_num = udtf(SquareNumbers, returnType="num: int, squared: int")

# Invoke the UDTF in PySpark.
square_num(lit(1), lit(3)).show()
# +---+-------+
# |num|squared|
# +---+-------+
# |  1|      1|
# |  2|      4|
# |  3|      9|
# +---+-------+

Another way to create a UDTF is to apply the @udtf decorator directly to the class:

from pyspark.sql.functions import lit, udtf

# Define a UDTF using the `udtf` decorator directly on the class.
@udtf(returnType="num: int, squared: int")
class SquareNumbers:
    def eval(self, start: int, end: int):
        for num in range(start, end + 1):
            yield (num, num * num)

# Invoke the UDTF in PySpark using the SquareNumbers class directly.
SquareNumbers(lit(1), lit(3)).show()
# +---+-------+
# |num|squared|
# +---+-------+
# |  1|      1|
# |  2|      4|
# |  3|      9|
# +---+-------+

Here is a Python UDTF that expands a date range into individual dates:

from datetime import datetime, timedelta
from pyspark.sql.functions import lit, udtf

@udtf(returnType="date: string")
class DateExpander:
    def eval(self, start_date: str, end_date: str):
        current = datetime.strptime(start_date, '%Y-%m-%d')
        end = datetime.strptime(end_date, '%Y-%m-%d')
        while current <= end:
            yield (current.strftime('%Y-%m-%d'),)
            current += timedelta(days=1)

DateExpander(lit("2023-02-25"), lit("2023-03-01")).show()
# +----------+
# |      date|
# +----------+
# |2023-02-25|
# |2023-02-26|
# |2023-02-27|
# |2023-02-28|
# |2023-03-01|
# +----------+

Here is a Python UDTF with __init__ and terminate:

from pyspark.sql.functions import udtf

@udtf(returnType="cnt: int")
class CountUDTF:
    def __init__(self):
        # Initialize the counter to 0 when an instance of the class is created.
        self.count = 0

    def eval(self, x: int):
        # Increment the counter by 1 for each input value received.
        self.count += 1

    def terminate(self):
        # Yield the final count when the UDTF is done processing.
        yield self.count,

spark.udtf.register("count_udtf", CountUDTF)
spark.sql("SELECT * FROM range(0, 10, 1, 1), LATERAL count_udtf(id)").show()
# +---+---+
# | id|cnt|
# +---+---+
# |  9| 10|
# +---+---+
spark.sql("SELECT * FROM range(0, 10, 1, 2), LATERAL count_udtf(id)").show()
# +---+---+
# | id|cnt|
# +---+---+
# |  4|  5|
# |  9|  5|
# +---+---+

Accepting an Input Table Argument

The UDTF examples above show functions that accept scalar input arguments, such as integers or strings.

However, any Python UDTF can also accept an input table as an argument, and this can work in conjunction with scalar input arguments in the same function definition. You can have only one such table argument as input.

A SQL query can then provide an input table using the TABLE keyword followed by an appropriate table identifier wrapped in parentheses, such as TABLE(t). Alternatively, you can pass a table subquery, such as TABLE(SELECT a, b, c FROM t) or TABLE(SELECT t1.a, t2.b FROM t1 INNER JOIN t2 USING (key)).

The input table argument is then represented as a pyspark.sql.Row argument to the eval method, with one call to the eval method for each row in the input table.

from pyspark.sql.functions import udtf
from pyspark.sql.types import Row

@udtf(returnType="id: int")
class FilterUDTF:
    def eval(self, row: Row):
        if row["id"] > 5:
            yield row["id"],

spark.udtf.register("filter_udtf", FilterUDTF)

spark.sql("SELECT * FROM filter_udtf(TABLE(SELECT * FROM range(10)))").show()
# +---+
# | id|
# +---+
# |  6|
# |  7|
# |  8|
# |  9|
# +---+

When calling a UDTF with a table argument, a SQL query can request that the input table be partitioned across several UDTF calls based on the values of one or more columns of the input table. To do so, specify the PARTITION BY clause in the function call after the TABLE argument. This guarantees that all input rows with each unique combination of values of the partitioning columns are consumed by exactly one instance of the UDTF class.

Note that in addition to simple column references, the PARTITION BY clause also accepts arbitrary expressions based on columns of the input table. For example, you can specify the LENGTH of a string, extract a month from a date, or concatenate two values.

It is also possible to specify WITH SINGLE PARTITION instead of PARTITION BY to request just one single partition in which all input rows are consumed by exactly one instance of the UDTF class.

Within each partition, you can optionally specify the order in which the UDTF's eval method consumes the input rows. To do so, specify an ORDER BY clause after the PARTITION BY or WITH SINGLE PARTITION clause described above.

from pyspark.sql.functions import udtf
from pyspark.sql.types import Row

# Define and register a UDTF.
@udtf(returnType="a: string, b: int")
class FilterUDTF:
    def __init__(self):
        self.key = ""
        self.max = 0

    def eval(self, row: Row):
        self.key = row["a"]
        self.max = max(self.max, row["b"])

    def terminate(self):
        yield self.key, self.max

spark.udtf.register("filter_udtf", FilterUDTF)

# Create an input table with some example values.
spark.sql("DROP TABLE IF EXISTS values_table")
spark.sql("CREATE TABLE values_table (a STRING, b INT)")
spark.sql("INSERT INTO values_table VALUES ('abc', 2), ('abc', 4), ('def', 6), ('def', 8)")
spark.table("values_table").show()
# +-------+----+
# |     a |  b |
# +-------+----+
# | "abc" | 2  |
# | "abc" | 4  |
# | "def" | 6  |
# | "def" | 8  |
# +-------+----+

# Query the UDTF with the input table as an argument, and a directive to partition the input
# rows such that all rows with each unique value of the `a` column are processed by the same
# instance of the UDTF class. Within each partition, the rows are ordered by the `b` column.
spark.sql("""
    SELECT * FROM filter_udtf(TABLE(values_table) PARTITION BY a ORDER BY b) ORDER BY 1
    """).show()
# +-------+----+
# |     a |  b |
# +-------+----+
# | "abc" | 4  |
# | "def" | 8  |
# +-------+----+

# Query the UDTF with the input table as an argument, and a directive to partition the input
# rows such that all rows with each unique result of evaluating the "LENGTH(a)" expression are
# processed by the same instance of the UDTF class. Within each partition, the rows are ordered
# by the `b` column.
spark.sql("""
    SELECT * FROM filter_udtf(TABLE(values_table) PARTITION BY LENGTH(a) ORDER BY b) ORDER BY 1
    """).show()
# +-------+---+
# |     a | b |
# +-------+---+
# | "def" | 8 |
# +-------+---+

# Query the UDTF with the input table as an argument, and a directive to consider all the input
# rows in one single partition such that exactly one instance of the UDTF class consumes all of
# the input rows. Within each partition, the rows are ordered by the `b` column.
spark.sql("""
    SELECT * FROM filter_udtf(TABLE(values_table) WITH SINGLE PARTITION ORDER BY b) ORDER BY 1
    """).show()
# +-------+----+
# |     a |  b |
# +-------+----+
# | "def" | 8 |
# +-------+----+

# Clean up.
spark.sql("DROP TABLE values_table")

Note that for each of these ways of partitioning the input table when calling UDTFs in SQL queries, there is a corresponding way for the UDTF's analyze method to specify the same partitioning method automatically instead.

For example, instead of calling a UDTF as SELECT * FROM udtf(TABLE(t) PARTITION BY a), you can update the analyze method to set the field partitionBy=[PartitioningColumn("a")] and simply call the function as SELECT * FROM udtf(TABLE(t)).

By the same token, instead of specifying TABLE(t) WITH SINGLE PARTITION in the SQL query, make analyze set the field withSinglePartition=true and then just pass TABLE(t).

Instead of passing TABLE(t) ORDER BY b in the SQL query, make analyze set orderBy=[OrderingColumn("b")] and then just pass TABLE(t).

Instead of passing TABLE(SELECT a FROM t) in the SQL query, make analyze set select=[SelectedColumn("a")] and then just pass TABLE(t).
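
As a rough sketch of how these fields fit together (this example is not part of the original documentation; the table columns a and b are placeholders, and depending on the PySpark version these helper classes may need to be imported from pyspark.sql.udtf rather than pyspark.sql.functions), the following UDTF requests the same partitioning, ordering, and column selection as the SQL clauses above, so a query can simply pass TABLE(t):

from pyspark.sql.functions import (
    AnalyzeArgument, AnalyzeResult, OrderingColumn, PartitioningColumn, SelectedColumn, udtf)
from pyspark.sql.types import IntegerType, StringType, StructType

@udtf
class MaxPerKeyUDTF:
    @staticmethod
    def analyze(*args: AnalyzeArgument) -> AnalyzeResult:
        assert len(args) == 1 and args[0].isTable, "This function accepts one TABLE argument only"
        return AnalyzeResult(
            schema=StructType().add("a", StringType()).add("max_b", IntegerType()),
            partitionBy=[PartitioningColumn("a")],              # instead of: TABLE(t) PARTITION BY a
            orderBy=[OrderingColumn("b")],                      # instead of: ... ORDER BY b
            select=[SelectedColumn("a"), SelectedColumn("b")])  # instead of: TABLE(SELECT a, b FROM t)

    def __init__(self):
        self.key = ""
        self.max = 0

    def eval(self, row):
        self.key = row["a"]
        self.max = max(self.max, row["b"])

    def terminate(self):
        yield self.key, self.max

spark.udtf.register("max_per_key_udtf", MaxPerKeyUDTF)
# The partitioning now comes from `analyze`, so the call site stays simple:
# spark.sql("SELECT * FROM max_per_key_udtf(TABLE(t)) ORDER BY 1").show()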