
openaivec.spark

Asynchronous Spark UDFs for the OpenAI and Azure OpenAI APIs.

This module provides builder classes (ResponsesUDFBuilder, EmbeddingsUDFBuilder) for creating asynchronous Spark pandas UDFs that call either the public OpenAI API or Azure OpenAI. The UDFs generate responses (plain text or structured output) and embeddings from Spark DataFrame columns. Within each batch of rows they use asyncio to issue API requests concurrently, which can improve throughput for these I/O-bound calls. The module also provides helper UDFs for counting tokens (count_tokens_udf) and splitting text into token-bounded chunks (split_to_chunks_udf).
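The batching and concurrency knobs exposed by the builders (batch_size, max_concurrency) can be pictured with a small asyncio sketch. It does not call openaivec: fake_api_call and process_in_batches are hypothetical stand-ins, and the actual scheduling inside openaivec.aio.pandas_ext may differ. The sketch splits inputs into batches and uses a semaphore so that no more than max_concurrency requests are in flight at once.

```python
import asyncio
from typing import List


async def fake_api_call(batch: List[str]) -> List[int]:
    """Hypothetical stand-in for a network request: returns each input's length."""
    await asyncio.sleep(0.01)  # simulate I/O latency
    return [len(text) for text in batch]


async def process_in_batches(
    texts: List[str], batch_size: int = 128, max_concurrency: int = 8
) -> List[int]:
    """Send `texts` in batches, with at most `max_concurrency` requests pending."""
    semaphore = asyncio.Semaphore(max_concurrency)

    async def run_one(batch: List[str]) -> List[int]:
        async with semaphore:  # waits while max_concurrency requests are in flight
            return await fake_api_call(batch)

    batches = [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]
    # gather preserves input order, so results line up with the original rows
    results = await asyncio.gather(*(run_one(b) for b in batches))
    return [value for batch_result in results for value in batch_result]
```

For example, asyncio.run(process_in_batches(texts, batch_size=32, max_concurrency=4)) issues one request per 32 texts and never has more than four requests outstanding.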

Setup

First, obtain a Spark session:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

Next, instantiate UDF builders with your OpenAI API key (or Azure credentials) and model/deployment names, then register the desired UDFs:

import os
from openaivec.spark import ResponsesUDFBuilder, EmbeddingsUDFBuilder
from pydantic import BaseModel

# Option 1: Using OpenAI
resp_builder = ResponsesUDFBuilder.of_openai(
    api_key=os.getenv("OPENAI_API_KEY"),
    model_name="gpt-4o-mini", # Model for responses
)
emb_builder = EmbeddingsUDFBuilder.of_openai(
    api_key=os.getenv("OPENAI_API_KEY"),
    model_name="text-embedding-3-small", # Model for embeddings
)

# Option 2: Using Azure OpenAI
# resp_builder = ResponsesUDFBuilder.of_azure_openai(
#     api_key=os.getenv("AZURE_OPENAI_KEY"),
#     endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
#     api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
#     model_name="your-resp-deployment-name", # Deployment for responses
# )
# emb_builder = EmbeddingsUDFBuilder.of_azure_openai(
#     api_key=os.getenv("AZURE_OPENAI_KEY"),
#     endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
#     api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
#     model_name="your-emb-deployment-name", # Deployment for embeddings
# )

# Define a Pydantic model for structured responses (optional)
class Translation(BaseModel):
    en: str
    fr: str
    # ... other languages

# Register the asynchronous responses UDF
spark.udf.register(
    "translate_async",
    resp_builder.build(
        instructions="Translate the text to multiple languages.",
        response_format=Translation,
    ),
)

# Register the asynchronous embeddings UDF
spark.udf.register(
    "embed_async",
    emb_builder.build(),
)

You can now invoke the UDFs from Spark SQL:

SELECT
    text,
    translate_async(text) AS translation,
    embed_async(text) AS embedding
FROM your_table;

Note: This module relies on the openaivec.aio.pandas_ext extension for its core asynchronous logic.
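The two setup options differ only in which factory is called and which credentials it takes. One way to switch between them without editing code is to derive the factory arguments from environment variables. The helper below is hypothetical (not part of openaivec); it reuses the variable names from the example above and assumes that a set AZURE_OPENAI_ENDPOINT means Azure should be used.

```python
import os
from typing import Dict, Mapping, Optional


def builder_kwargs_from_env(
    model_name: str, env: Optional[Mapping[str, str]] = None
) -> Dict[str, str]:
    """Hypothetical helper returning kwargs for of_openai or of_azure_openai.

    If AZURE_OPENAI_ENDPOINT is set, Azure-style kwargs are returned
    (model_name is then the deployment name); otherwise public OpenAI kwargs.
    """
    env = os.environ if env is None else env
    endpoint = env.get("AZURE_OPENAI_ENDPOINT")
    if endpoint:
        return {
            "api_key": env["AZURE_OPENAI_KEY"],
            "endpoint": endpoint,
            "api_version": env["AZURE_OPENAI_API_VERSION"],
            "model_name": model_name,
        }
    return {"api_key": env["OPENAI_API_KEY"], "model_name": model_name}
```

With the returned dict, call ResponsesUDFBuilder.of_azure_openai(**kwargs) when it contains an "endpoint" key and ResponsesUDFBuilder.of_openai(**kwargs) otherwise; the same pattern applies to EmbeddingsUDFBuilder.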

EmbeddingsUDFBuilder dataclass

Builder for asynchronous Spark pandas UDFs for creating embeddings.

Configures and builds UDFs that leverage openaivec.aio.pandas_ext.embeddings to generate vector embeddings from OpenAI models asynchronously. An instance stores authentication parameters and the model name.

Attributes:

api_key (str): OpenAI or Azure API key.
endpoint (Optional[str]): Azure endpoint base URL. None for public OpenAI.
api_version (Optional[str]): Azure API version. Ignored for public OpenAI.
model_name (str): Deployment name (Azure) or model name (OpenAI) for embeddings.

Source code in src/openaivec/spark.py
@dataclass(frozen=True)
class EmbeddingsUDFBuilder:
    """Builder for asynchronous Spark pandas UDFs for creating embeddings.

    Configures and builds UDFs that leverage `openaivec.aio.pandas_ext.embeddings`
    to generate vector embeddings from OpenAI models asynchronously.
    An instance stores authentication parameters and the model name.

    Attributes:
        api_key (str): OpenAI or Azure API key.
        endpoint (Optional[str]): Azure endpoint base URL. None for public OpenAI.
        api_version (Optional[str]): Azure API version. Ignored for public OpenAI.
        model_name (str): Deployment name (Azure) or model name (OpenAI) for embeddings.
    """

    # Params for OpenAI SDK
    api_key: str
    endpoint: str | None
    api_version: str | None

    # Params for Embeddings API
    model_name: str

    @classmethod
    def of_openai(cls, api_key: str, model_name: str) -> "EmbeddingsUDFBuilder":
        """Creates a builder configured for the public OpenAI API.

        Args:
            api_key (str): The OpenAI API key.
            model_name (str): The OpenAI model name for embeddings (e.g., "text-embedding-3-small").

        Returns:
            EmbeddingsUDFBuilder: A builder instance configured for OpenAI embeddings.
        """
        return cls(api_key=api_key, endpoint=None, api_version=None, model_name=model_name)

    @classmethod
    def of_azure_openai(cls, api_key: str, endpoint: str, api_version: str, model_name: str) -> "EmbeddingsUDFBuilder":
        """Creates a builder configured for Azure OpenAI.

        Args:
            api_key (str): The Azure OpenAI API key.
            endpoint (str): The Azure OpenAI endpoint URL.
            api_version (str): The Azure OpenAI API version (e.g., "2024-02-01").
            model_name (str): The Azure OpenAI deployment name for embeddings.

        Returns:
            EmbeddingsUDFBuilder: A builder instance configured for Azure OpenAI embeddings.
        """
        return cls(api_key=api_key, endpoint=endpoint, api_version=api_version, model_name=model_name)

    def build(self, batch_size: int = 128, max_concurrency: int = 8) -> UserDefinedFunction:
        """Builds the asynchronous pandas UDF for generating embeddings.

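        Args:
            batch_size (int): Number of rows per async batch request passed to the underlying
                `pandas_ext` function. Defaults to 128.
            max_concurrency (int): Maximum number of concurrent requests the underlying
                `pandas_ext` function may have in flight. Defaults to 8.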

        Returns:
            UserDefinedFunction: A Spark pandas UDF configured to generate embeddings asynchronously,
                returning an `ArrayType(FloatType())` column.
        """

        @pandas_udf(returnType=ArrayType(FloatType()))
        def embeddings_udf(col: Iterator[pd.Series]) -> Iterator[pd.Series]:
            _initialize(self.api_key, self.endpoint, self.api_version)
            pandas_ext.embeddings_model(self.model_name)

            for part in col:
                embeddings: pd.Series = asyncio.run(
                    part.aio.embeddings(batch_size=batch_size, max_concurrency=max_concurrency)
                )
                yield embeddings.map(lambda x: x.tolist())

        return embeddings_udf
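Because build() defines the pandas UDF as a closure over self, PySpark serializes the builder together with the function (using cloudpickle) when it ships the UDF to executors. The frozen dataclass keeps that captured state to a few plain strings and prevents it from being changed after construction. The sketch below uses a hypothetical DemoBuilder with the same fields to show both properties with the standard library only; note that, as a consequence, the API key travels inside the serialized UDF.

```python
import json
from dataclasses import FrozenInstanceError, asdict, dataclass
from typing import Optional


@dataclass(frozen=True)
class DemoBuilder:
    """Hypothetical stand-in with the same fields as EmbeddingsUDFBuilder."""

    api_key: str
    endpoint: Optional[str]
    api_version: Optional[str]
    model_name: str

    @classmethod
    def of_openai(cls, api_key: str, model_name: str) -> "DemoBuilder":
        # Public OpenAI: no endpoint or API version is needed.
        return cls(api_key=api_key, endpoint=None, api_version=None, model_name=model_name)


builder = DemoBuilder.of_openai(api_key="sk-test", model_name="text-embedding-3-small")

# The whole builder state is plain data, so it round-trips through JSON.
restored = DemoBuilder(**json.loads(json.dumps(asdict(builder))))
print(restored == builder)

# frozen=True rejects attribute assignment after construction.
try:
    builder.model_name = "other-model"
except FrozenInstanceError:
    print("builder is immutable")
```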

build(batch_size=128, max_concurrency=8)

Builds the asynchronous pandas UDF for generating embeddings.

Parameters:

batch_size (int, default 128): Number of rows per async batch request passed to the underlying pandas_ext function.
max_concurrency (int, default 8): Maximum number of concurrent requests the underlying pandas_ext function may have in flight.

Returns:

UserDefinedFunction: A Spark pandas UDF that generates embeddings asynchronously and returns an ArrayType(FloatType()) column.
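The embeddings UDF uses the iterator-of-Series form of a pandas UDF: setup (_initialize and pandas_ext.embeddings_model) runs once per task, and each incoming batch of rows is processed with its own asyncio.run call. Because every concurrently running task has its own event loop, the number of requests in flight across the cluster can reach roughly max_concurrency times the number of concurrently running tasks. The standard-library sketch below mirrors that control flow with plain lists standing in for pd.Series; fake_embed is a hypothetical stand-in for the API call.

```python
import asyncio
from typing import Iterator, List


async def fake_embed(batch: List[str]) -> List[List[float]]:
    """Hypothetical async embedding call returning a toy 2-d vector per text."""
    await asyncio.sleep(0)  # yield to the event loop, as a real network call would
    return [[float(len(t)), float(t.count(" "))] for t in batch]


def embeddings_like_udf(col: Iterator[List[str]]) -> Iterator[List[List[float]]]:
    # One-time, per-task setup would happen here (the real UDF initializes
    # the OpenAI client and sets the embeddings model before looping).
    for part in col:
        # Each batch gets its own event loop via asyncio.run, as in the source above.
        yield asyncio.run(fake_embed(part))


for result in embeddings_like_udf(iter([["a b", "c"], ["dd"]])):
    print(result)
```

The real UDF additionally converts each embedding with tolist() before yielding, so the ArrayType(FloatType()) column holds plain lists of floats.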


of_azure_openai(api_key, endpoint, api_version, model_name) classmethod

Creates a builder configured for Azure OpenAI.

Parameters:

api_key (str, required): The Azure OpenAI API key.
endpoint (str, required): The Azure OpenAI endpoint URL.
api_version (str, required): The Azure OpenAI API version (e.g., "2024-02-01").
model_name (str, required): The Azure OpenAI deployment name for embeddings.

Returns:

EmbeddingsUDFBuilder: A builder instance configured for Azure OpenAI embeddings.


of_openai(api_key, model_name) classmethod

Creates a builder configured for the public OpenAI API.

Parameters:

api_key (str, required): The OpenAI API key.
model_name (str, required): The OpenAI model name for embeddings (e.g., "text-embedding-3-small").

Returns:

EmbeddingsUDFBuilder: A builder instance configured for OpenAI embeddings.


ResponsesUDFBuilder dataclass

Builder for asynchronous Spark pandas UDFs for generating responses.

Configures and builds UDFs that leverage openaivec.aio.pandas_ext.responses to generate text or structured responses from OpenAI models asynchronously. An instance stores authentication parameters and the model name.

Attributes:

api_key (str): OpenAI or Azure API key.
endpoint (Optional[str]): Azure endpoint base URL. None for public OpenAI.
api_version (Optional[str]): Azure API version. Ignored for public OpenAI.
model_name (str): Deployment name (Azure) or model name (OpenAI) for responses.

Source code in src/openaivec/spark.py
@dataclass(frozen=True)
class ResponsesUDFBuilder:
    """Builder for asynchronous Spark pandas UDFs for generating responses.

    Configures and builds UDFs that leverage `openaivec.aio.pandas_ext.responses`
    to generate text or structured responses from OpenAI models asynchronously.
    An instance stores authentication parameters and the model name.

    Attributes:
        api_key (str): OpenAI or Azure API key.
        endpoint (Optional[str]): Azure endpoint base URL. None for public OpenAI.
        api_version (Optional[str]): Azure API version. Ignored for public OpenAI.
        model_name (str): Deployment name (Azure) or model name (OpenAI) for responses.
    """

    # Params for OpenAI SDK
    api_key: str
    endpoint: str | None
    api_version: str | None

    # Params for Responses API
    model_name: str

    @classmethod
    def of_openai(cls, api_key: str, model_name: str) -> "ResponsesUDFBuilder":
        """Creates a builder configured for the public OpenAI API.

        Args:
            api_key (str): The OpenAI API key.
            model_name (str): The OpenAI model name for responses (e.g., "gpt-4o-mini").

        Returns:
            ResponsesUDFBuilder: A builder instance configured for OpenAI responses.
        """
        return cls(api_key=api_key, endpoint=None, api_version=None, model_name=model_name)

    @classmethod
    def of_azure_openai(cls, api_key: str, endpoint: str, api_version: str, model_name: str) -> "ResponsesUDFBuilder":
        """Creates a builder configured for Azure OpenAI.

        Args:
            api_key (str): The Azure OpenAI API key.
            endpoint (str): The Azure OpenAI endpoint URL.
            api_version (str): The Azure OpenAI API version (e.g., "2024-02-01").
            model_name (str): The Azure OpenAI deployment name for responses.

        Returns:
            ResponsesUDFBuilder: A builder instance configured for Azure OpenAI responses.
        """
        return cls(api_key=api_key, endpoint=endpoint, api_version=api_version, model_name=model_name)

    def build(
        self,
        instructions: str,
        response_format: Type[T] = str,
        batch_size: int = 128,
        temperature: float = 0.0,
        top_p: float = 1.0,
        max_concurrency: int = 8,
    ) -> UserDefinedFunction:
        """Builds the asynchronous pandas UDF for generating responses.

        Args:
            instructions (str): The system prompt or instructions for the model.
            response_format (Type[T]): The desired output format. Either `str` for plain text
                or a Pydantic `BaseModel` for structured JSON output. Defaults to `str`.
            batch_size (int): Number of rows per async batch request passed to the underlying
                `pandas_ext` function. Defaults to 128.
            temperature (float): Sampling temperature (0.0 to 2.0). Defaults to 0.0.
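            top_p (float): Nucleus sampling parameter. Defaults to 1.0.
            max_concurrency (int): Maximum number of concurrent requests the underlying
                `pandas_ext` function may have in flight. Defaults to 8.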

        Returns:
            UserDefinedFunction: A Spark pandas UDF configured to generate responses asynchronously.
                Output schema is `StringType` or a struct derived from `response_format`.

        Raises:
            ValueError: If `response_format` is not `str` or a Pydantic `BaseModel`.
        """
        if issubclass(response_format, BaseModel):
            spark_schema = _pydantic_to_spark_schema(response_format)
            json_schema_string = serialize_base_model(response_format)

            @pandas_udf(returnType=spark_schema)
            def structure_udf(col: Iterator[pd.Series]) -> Iterator[pd.DataFrame]:
                _initialize(self.api_key, self.endpoint, self.api_version)
                pandas_ext.responses_model(self.model_name)

                for part in col:
                    predictions: pd.Series = asyncio.run(
                        part.aio.responses(
                            instructions=instructions,
                            response_format=deserialize_base_model(json_schema_string),
                            batch_size=batch_size,
                            temperature=temperature,
                            top_p=top_p,
                            max_concurrency=max_concurrency,
                        )
                    )
                    yield pd.DataFrame(predictions.map(_safe_dump).tolist())

            return structure_udf

        elif issubclass(response_format, str):

            @pandas_udf(returnType=StringType())
            def string_udf(col: Iterator[pd.Series]) -> Iterator[pd.Series]:
                _initialize(self.api_key, self.endpoint, self.api_version)
                pandas_ext.responses_model(self.model_name)

                for part in col:
                    predictions: pd.Series = asyncio.run(
                        part.aio.responses(
                            instructions=instructions,
                            response_format=str,
                            batch_size=batch_size,
                            temperature=temperature,
                            top_p=top_p,
                            max_concurrency=max_concurrency,
                        )
                    )
                    yield predictions.map(_safe_cast_str)

            return string_udf

        else:
            raise ValueError(f"Unsupported response_format: {response_format}")

build(instructions, response_format=str, batch_size=128, temperature=0.0, top_p=1.0, max_concurrency=8)

Builds the asynchronous pandas UDF for generating responses.

Parameters:

instructions (str, required): The system prompt or instructions for the model.
response_format (Type[T], default str): The desired output format: str for plain text, or a Pydantic BaseModel subclass for structured JSON output.
batch_size (int, default 128): Number of rows per async batch request passed to the underlying pandas_ext function.
temperature (float, default 0.0): Sampling temperature (0.0 to 2.0).
top_p (float, default 1.0): Nucleus sampling parameter.
max_concurrency (int, default 8): Maximum number of concurrent requests the underlying pandas_ext function may have in flight.

Returns:

UserDefinedFunction: A Spark pandas UDF that generates responses asynchronously. The output column is StringType when response_format is str, or a struct whose fields are derived from response_format.

Raises:

ValueError: If response_format is a class other than str or a Pydantic BaseModel subclass. (A value that is not a class at all, such as a model instance, makes issubclass raise TypeError instead.)
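For a BaseModel response_format, the UDF's return type is a Spark struct derived from the model's fields (via _pydantic_to_spark_schema), and each prediction is converted to a dict (via _safe_dump) so that pd.DataFrame(...) yields one column per field. Neither helper's implementation is shown on this page. The sketch below illustrates the idea with a standard-library dataclass instead of Pydantic; the type mapping and the null-row behavior for missing results are assumptions, not openaivec's actual rules.

```python
from dataclasses import asdict, dataclass, fields
from typing import Any, Dict, Optional

# Hypothetical mapping from Python annotations to Spark SQL type names.
_SPARK_TYPES = {str: "STRING", int: "BIGINT", float: "DOUBLE", bool: "BOOLEAN"}


@dataclass
class Translation:
    en: str
    fr: str


def to_spark_ddl(model: type) -> str:
    """Render a flat dataclass as a Spark DDL struct body, e.g. 'en STRING, fr STRING'."""
    return ", ".join(f"{f.name} {_SPARK_TYPES[f.type]}" for f in fields(model))


def safe_dump(obj: Optional[Any], model: type) -> Dict[str, Any]:
    """Convert one result to a dict; a missing result becomes a row of nulls (assumption)."""
    if obj is None:
        return {f.name: None for f in fields(model)}
    return asdict(obj)


print(to_spark_ddl(Translation))
print([safe_dump(r, Translation) for r in [Translation(en="hello", fr="bonjour"), None]])
```

Producing one dict per row with the same keys is what lets a list of results become a DataFrame whose columns match the struct fields.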


of_azure_openai(api_key, endpoint, api_version, model_name) classmethod

Creates a builder configured for Azure OpenAI.

Parameters:

api_key (str, required): The Azure OpenAI API key.
endpoint (str, required): The Azure OpenAI endpoint URL.
api_version (str, required): The Azure OpenAI API version (e.g., "2024-02-01").
model_name (str, required): The Azure OpenAI deployment name for responses.

Returns:

ResponsesUDFBuilder: A builder instance configured for Azure OpenAI responses.


of_openai(api_key, model_name) classmethod

Creates a builder configured for the public OpenAI API.

Parameters:

api_key (str, required): The OpenAI API key.
model_name (str, required): The OpenAI model name for responses (e.g., "gpt-4o-mini").

Returns:

ResponsesUDFBuilder: A builder instance configured for OpenAI responses.


count_tokens_udf(model_name='gpt-4o')

Create a pandas‑UDF that counts tokens for every string cell.

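The UDF uses tiktoken to approximate tokenisation and caches the resulting Encoding object in a module-level global, so it is created once per Python worker process. Non-string cells (including nulls) are counted as 0.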

Parameters:

model_name (str, default 'gpt-4o'): Model identifier understood by tiktoken.

Returns:

A pandas UDF producing an IntegerType column with token counts.

Source code in src/openaivec/spark.py
def count_tokens_udf(model_name: str = "gpt-4o"):
    """Create a pandas‑UDF that counts tokens for every string cell.

    The UDF uses *tiktoken* to approximate tokenisation and caches the
    resulting ``Encoding`` object per executor.

    Args:
        model_name: Model identifier understood by ``tiktoken``.

    Returns:
        A pandas UDF producing an ``IntegerType`` column with token counts.
    """

    @pandas_udf(IntegerType())
    def fn(col: Iterator[pd.Series]) -> Iterator[pd.Series]:
        global _TIKTOKEN_ENC
        if _TIKTOKEN_ENC is None:
            _TIKTOKEN_ENC = tiktoken.encoding_for_model(model_name)

        for part in col:
            yield part.map(lambda x: len(_TIKTOKEN_ENC.encode(x)) if isinstance(x, str) else 0)

    return fn
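As written, count_tokens_udf (and split_to_chunks_udf) store a single encoding in the module-level global _TIKTOKEN_ENC without recording which model it belongs to. If the same Python worker process runs UDFs built with different model_name values, the encoding created first appears to be reused for all of them. A cache keyed by model name avoids that. The sketch below uses functools.lru_cache; the toy tokenizers are hypothetical stand-ins for tiktoken.encoding_for_model so the example runs without tiktoken.

```python
from functools import lru_cache
from typing import Callable, Iterable, List

# Hypothetical stand-ins for tiktoken encodings: each "model" tokenizes differently.
_TOY_TOKENIZERS = {
    "words": lambda s: s.split(),
    "chars": lambda s: list(s),
}


@lru_cache(maxsize=None)
def get_encoder(model_name: str) -> Callable[[str], List[str]]:
    """Create the encoder once per model name (and per Python worker process).

    In real code this body would call tiktoken.encoding_for_model(model_name).
    """
    return _TOY_TOKENIZERS[model_name]


def count_tokens(values: Iterable[object], model_name: str) -> List[int]:
    encode = get_encoder(model_name)
    # Non-string cells (e.g. None) count as 0, matching count_tokens_udf.
    return [len(encode(v)) if isinstance(v, str) else 0 for v in values]


print(count_tokens(["a b c", None], "words"))
print(count_tokens(["a b c", None], "chars"))
```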

split_to_chunks_udf(model_name, max_tokens, sep)

Create a pandas‑UDF that splits text into token‑bounded chunks.

Parameters:

model_name (str, required): Model identifier passed to tiktoken.
max_tokens (int, required): Maximum tokens allowed per chunk.
sep (List[str], required): Ordered list of separator strings used by TextChunker.

Returns:

A pandas UDF producing an ArrayType(StringType()) column whose values are lists of chunks respecting the max_tokens limit. Non-string cells (including nulls) produce an empty list.

Source code in src/openaivec/spark.py
def split_to_chunks_udf(model_name: str, max_tokens: int, sep: List[str]):
    """Create a pandas‑UDF that splits text into token‑bounded chunks.

    Args:
        model_name: Model identifier passed to *tiktoken*.
        max_tokens: Maximum tokens allowed per chunk.
        sep: Ordered list of separator strings used by ``TextChunker``.

    Returns:
        A pandas UDF producing an ``ArrayType(StringType())`` column whose
            values are lists of chunks respecting the ``max_tokens`` limit.
    """

    @pandas_udf(ArrayType(StringType()))
    def fn(col: Iterator[pd.Series]) -> Iterator[pd.Series]:
        global _TIKTOKEN_ENC
        if _TIKTOKEN_ENC is None:
            _TIKTOKEN_ENC = tiktoken.encoding_for_model(model_name)

        chunker = TextChunker(_TIKTOKEN_ENC)

        for part in col:
            yield part.map(lambda x: chunker.split(x, max_tokens=max_tokens, sep=sep) if isinstance(x, str) else [])

    return fn
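TextChunker's algorithm is not shown on this page. The sketch below is one plausible reading of "ordered list of separator strings": a greedy packer that splits on the first separator, packs pieces into chunks of at most max_tokens, and falls back to the next separator for any piece that is still too large. A whitespace word count stands in for tiktoken. It is an illustration under those assumptions, not openaivec's implementation.

```python
from typing import Callable, List


def split_to_chunks(
    text: str,
    max_tokens: int,
    sep: List[str],
    count: Callable[[str], int] = lambda s: len(s.split()),
) -> List[str]:
    """Greedy, separator-ordered chunking sketch (not openaivec's TextChunker).

    With no separators left, an oversized piece is kept as-is.
    """
    if count(text) <= max_tokens or not sep:
        return [text]

    head, rest = sep[0], sep[1:]
    chunks: List[str] = []
    current = ""
    for piece in text.split(head):
        if count(piece) > max_tokens:
            # Flush the pending chunk, then break the oversized piece down further.
            if current:
                chunks.append(current)
                current = ""
            chunks.extend(split_to_chunks(piece, max_tokens, rest, count))
            continue
        candidate = piece if not current else current + head + piece
        if count(candidate) <= max_tokens:
            current = candidate  # piece still fits into the current chunk
        else:
            chunks.append(current)  # current chunk is full; start a new one
            current = piece
    if current:
        chunks.append(current)
    return chunks


print(split_to_chunks("a b c\n\nd e\n\nf g h i j", 5, ["\n\n", " "]))
```

Trying coarse separators (such as paragraph breaks) before fine ones keeps related text together and only splits more aggressively where a piece would otherwise exceed the limit.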