From c680d60e7a6422918a9227f1a61e2623c5aeed15 Mon Sep 17 00:00:00 2001
From: Tim Saucer
Date: Sun, 1 Feb 2026 10:21:46 -0500
Subject: [PATCH] Implement all CSV options with a builder pattern

---
 docs/source/user-guide/io/csv.rst |  19 +++
 examples/csv-read-options.py      |  96 +++++++++++
 python/datafusion/__init__.py     |   3 +
 python/datafusion/context.py      | 156 +++++++++--------
 python/datafusion/io.py           |  11 +-
 python/datafusion/options.py      | 273 ++++++++++++++++++++++++++++++
 python/tests/test_context.py      |  65 +++++++
 python/tests/test_sql.py          |   2 +-
 src/context.rs                    |  74 ++------
 src/lib.rs                        |   5 +
 src/options.rs                    | 142 ++++++++++++++++
 11 files changed, 712 insertions(+), 134 deletions(-)
 create mode 100644 examples/csv-read-options.py
 create mode 100644 python/datafusion/options.py
 create mode 100644 src/options.rs

diff --git a/docs/source/user-guide/io/csv.rst b/docs/source/user-guide/io/csv.rst
index 144b6615c..db31caa0f 100644
--- a/docs/source/user-guide/io/csv.rst
+++ b/docs/source/user-guide/io/csv.rst
@@ -36,3 +36,22 @@ An alternative is to use :py:func:`~datafusion.context.SessionContext.register_c
 
     ctx.register_csv("file", "file.csv")
     df = ctx.table("file")
+
+If you require additional control over how to read the CSV file, you can use
+:py:class:`~datafusion.options.CsvReadOptions` to set a variety of options.
+
+.. code-block:: python
+
+    from datafusion import CsvReadOptions
+    options = (
+        CsvReadOptions()
+        .with_has_header(True)  # File contains a header row
+        .with_delimiter(";")  # Use ; as the delimiter instead of ,
+        .with_comment("#")  # Skip lines starting with #
+        .with_escape("\\")  # Escape character
+        .with_null_regex(r"^(null|NULL|N/A)$")  # Treat these as NULL
+        .with_truncated_rows(True)  # Allow rows to have incomplete columns
+        .with_file_compression_type("gzip")  # Read gzipped CSV
+        .with_file_extension(".gz")  # File extension other than .csv
+    )
+    df = ctx.read_csv("data.csv.gz", options=options)
diff --git a/examples/csv-read-options.py b/examples/csv-read-options.py
new file mode 100644
index 000000000..a5952d950
--- /dev/null
+++ b/examples/csv-read-options.py
@@ -0,0 +1,96 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Example demonstrating CsvReadOptions usage."""
+
+from datafusion import CsvReadOptions, SessionContext
+
+# Create a SessionContext
+ctx = SessionContext()
+
+# Example 1: Using CsvReadOptions with default values
+print("Example 1: Default CsvReadOptions")
+options = CsvReadOptions()
+df = ctx.read_csv("data.csv", options=options)
+
+# Example 2: Using CsvReadOptions with custom parameters
+print("\nExample 2: Custom CsvReadOptions")
+options = CsvReadOptions(
+    has_header=True,
+    delimiter=",",
+    quote='"',
+    schema_infer_max_records=1000,
+    file_extension=".csv",
+)
+df = ctx.read_csv("data.csv", options=options)
+
+# Example 3: Using the builder pattern (recommended for readability)
+print("\nExample 3: Builder pattern")
+options = (
+    CsvReadOptions()
+    .with_has_header(True)  # noqa: FBT003
+    .with_delimiter("|")
+    .with_quote("'")
+    .with_schema_infer_max_records(500)
+    .with_truncated_rows(False)  # noqa: FBT003
+    .with_newlines_in_values(True)  # noqa: FBT003
+)
+df = ctx.read_csv("data.csv", options=options)
+
+# Example 4: Advanced options
+print("\nExample 4: Advanced options")
+options = (
+    CsvReadOptions()
+    .with_has_header(True)  # noqa: FBT003
+    .with_delimiter(",")
+    .with_comment("#")  # Skip lines starting with #
+    .with_escape("\\")  # Escape character
+    .with_null_regex(r"^(null|NULL|N/A)$")  # Treat these as NULL
+    .with_truncated_rows(True)  # noqa: FBT003
+    .with_file_compression_type("gzip")  # Read gzipped CSV
+    .with_file_extension(".gz")
+)
+df = ctx.read_csv("data.csv.gz", options=options)
+
+# Example 5: Register CSV table with options
+print("\nExample 5: Register CSV table")
+options = CsvReadOptions().with_has_header(True).with_delimiter(",")  # noqa: FBT003
+ctx.register_csv("my_table", "data.csv", options=options)
+df = ctx.sql("SELECT * FROM my_table")
+
+# Example 6: Backward compatibility (without options)
+print("\nExample 6: Backward compatibility")
+# Still works the old way!
+df = ctx.read_csv("data.csv", has_header=True, delimiter=",")
+
+print("\nAll examples completed!")
+print("\nFor all available options, see the CsvReadOptions documentation:")
+print(" - has_header: bool")
+print(" - delimiter: str")
+print(" - quote: str")
+print(" - terminator: str | None")
+print(" - escape: str | None")
+print(" - comment: str | None")
+print(" - newlines_in_values: bool")
+print(" - schema: pa.Schema | None")
+print(" - schema_infer_max_records: int")
+print(" - file_extension: str")
+print(" - table_partition_cols: list[tuple[str, pa.DataType]]")
+print(" - file_compression_type: str")
+print(" - file_sort_order: list[list[SortExpr]]")
+print(" - null_regex: str | None")
+print(" - truncated_rows: bool")
diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py
index 784d4ccc6..2e6f81166 100644
--- a/python/datafusion/__init__.py
+++ b/python/datafusion/__init__.py
@@ -54,6 +54,7 @@
 from .dataframe_formatter import configure_formatter
 from .expr import Expr, WindowFrame
 from .io import read_avro, read_csv, read_json, read_parquet
+from .options import CsvReadOptions
 from .plan import ExecutionPlan, LogicalPlan
 from .record_batch import RecordBatch, RecordBatchStream
 from .user_defined import (
@@ -75,6 +76,7 @@
     "AggregateUDF",
     "Catalog",
     "Config",
+    "CsvReadOptions",
     "DFSchema",
     "DataFrame",
     "DataFrameWriteOptions",
@@ -106,6 +108,7 @@
     "lit",
     "literal",
     "object_store",
+    "options",
     "read_avro",
     "read_csv",
     "read_json",
diff --git a/python/datafusion/context.py b/python/datafusion/context.py
index 7dc06eb17..260e9805b 100644
--- a/python/datafusion/context.py
+++ b/python/datafusion/context.py
@@ -34,6 +34,11 @@
 from datafusion.catalog import Catalog
 from datafusion.dataframe import DataFrame
 from datafusion.expr import sort_list_to_raw_sort_list
+from datafusion.options import (
+    DEFAULT_MAX_INFER_SCHEMA,
+    CsvReadOptions,
+    _convert_table_partition_cols,
+)
 from datafusion.record_batch import RecordBatchStream
 
 from ._internal import RuntimeEnvBuilder as RuntimeEnvBuilderInternal
@@ -584,7 +589,7 @@ def register_listing_table(
         """
         if table_partition_cols is None:
             table_partition_cols = []
-        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
+        table_partition_cols = _convert_table_partition_cols(table_partition_cols)
         self.ctx.register_listing_table(
             name,
             str(path),
@@ -905,7 +910,7 @@ def register_parquet(
         """
         if table_partition_cols is None:
             table_partition_cols = []
-        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
+        table_partition_cols = _convert_table_partition_cols(table_partition_cols)
         self.ctx.register_parquet(
             name,
             str(path),
@@ -924,9 +929,10 @@ def register_csv(
         schema: pa.Schema | None = None,
         has_header: bool = True,
         delimiter: str = ",",
-        schema_infer_max_records: int = 1000,
+        schema_infer_max_records: int = DEFAULT_MAX_INFER_SCHEMA,
         file_extension: str = ".csv",
         file_compression_type: str | None = None,
+        options: CsvReadOptions | None = None,
     ) -> None:
         """Register a CSV file as a table.
 
@@ -946,18 +952,46 @@ def register_csv(
             file_extension: File extension; only files with this extension are
                 selected for data input.
             file_compression_type: File compression type.
+            options: Set advanced options for CSV reading. This cannot be
+                combined with any of the other options in this method.
""" - path = [str(p) for p in path] if isinstance(path, list) else str(path) + path_arg = [str(p) for p in path] if isinstance(path, list) else str(path) + + if options is not None and ( + schema is not None + or not has_header + or delimiter != "," + or schema_infer_max_records != DEFAULT_MAX_INFER_SCHEMA + or file_extension != ".csv" + or file_compression_type is not None + ): + message = ( + "Combining CsvReadOptions parameter with additional options " + "is not supported. Use CsvReadOptions to set parameters." + ) + warnings.warn( + message, + category=UserWarning, + stacklevel=2, + ) + + options = ( + options + if options is not None + else CsvReadOptions( + schema=schema, + has_header=has_header, + delimiter=delimiter, + schema_infer_max_records=schema_infer_max_records, + file_extension=file_extension, + file_compression_type=file_compression_type, + ) + ) self.ctx.register_csv( name, - path, - schema, - has_header, - delimiter, - schema_infer_max_records, - file_extension, - file_compression_type, + path_arg, + options.to_inner(), ) def register_json( @@ -988,7 +1022,7 @@ def register_json( """ if table_partition_cols is None: table_partition_cols = [] - table_partition_cols = self._convert_table_partition_cols(table_partition_cols) + table_partition_cols = _convert_table_partition_cols(table_partition_cols) self.ctx.register_json( name, str(path), @@ -1021,7 +1055,7 @@ def register_avro( """ if table_partition_cols is None: table_partition_cols = [] - table_partition_cols = self._convert_table_partition_cols(table_partition_cols) + table_partition_cols = _convert_table_partition_cols(table_partition_cols) self.ctx.register_avro( name, str(path), schema, file_extension, table_partition_cols ) @@ -1101,7 +1135,7 @@ def read_json( """ if table_partition_cols is None: table_partition_cols = [] - table_partition_cols = self._convert_table_partition_cols(table_partition_cols) + table_partition_cols = _convert_table_partition_cols(table_partition_cols) return DataFrame( self.ctx.read_json( str(path), @@ -1119,10 +1153,11 @@ def read_csv( schema: pa.Schema | None = None, has_header: bool = True, delimiter: str = ",", - schema_infer_max_records: int = 1000, + schema_infer_max_records: int = DEFAULT_MAX_INFER_SCHEMA, file_extension: str = ".csv", table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, file_compression_type: str | None = None, + options: CsvReadOptions | None = None, ) -> DataFrame: """Read a CSV data source. @@ -1140,26 +1175,51 @@ def read_csv( selected for data input. table_partition_cols: Partition columns. file_compression_type: File compression type. + options: Set advanced options for CSV reading. This cannot be + combined with any of the other options in this method. Returns: DataFrame representation of the read CSV files """ - if table_partition_cols is None: - table_partition_cols = [] - table_partition_cols = self._convert_table_partition_cols(table_partition_cols) + path_arg = [str(p) for p in path] if isinstance(path, list) else str(path) + + if options is not None and ( + schema is not None + or not has_header + or delimiter != "," + or schema_infer_max_records != DEFAULT_MAX_INFER_SCHEMA + or file_extension != ".csv" + or table_partition_cols is not None + or file_compression_type is not None + ): + message = ( + "Combining CsvReadOptions parameter with additional options " + "is not supported. Use CsvReadOptions to set parameters." 
+ ) + warnings.warn( + message, + category=UserWarning, + stacklevel=2, + ) - path = [str(p) for p in path] if isinstance(path, list) else str(path) + options = ( + options + if options is not None + else CsvReadOptions( + schema=schema, + has_header=has_header, + delimiter=delimiter, + schema_infer_max_records=schema_infer_max_records, + file_extension=file_extension, + table_partition_cols=table_partition_cols, + file_compression_type=file_compression_type, + ) + ) return DataFrame( self.ctx.read_csv( - path, - schema, - has_header, - delimiter, - schema_infer_max_records, - file_extension, - table_partition_cols, - file_compression_type, + path_arg, + options.to_inner(), ) ) @@ -1197,7 +1257,7 @@ def read_parquet( """ if table_partition_cols is None: table_partition_cols = [] - table_partition_cols = self._convert_table_partition_cols(table_partition_cols) + table_partition_cols = _convert_table_partition_cols(table_partition_cols) file_sort_order = self._convert_file_sort_order(file_sort_order) return DataFrame( self.ctx.read_parquet( @@ -1231,7 +1291,7 @@ def read_avro( """ if file_partition_cols is None: file_partition_cols = [] - file_partition_cols = self._convert_table_partition_cols(file_partition_cols) + file_partition_cols = _convert_table_partition_cols(file_partition_cols) return DataFrame( self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension) ) @@ -1263,41 +1323,3 @@ def _convert_file_sort_order( if file_sort_order is not None else None ) - - @staticmethod - def _convert_table_partition_cols( - table_partition_cols: list[tuple[str, str | pa.DataType]], - ) -> list[tuple[str, pa.DataType]]: - warn = False - converted_table_partition_cols = [] - - for col, data_type in table_partition_cols: - if isinstance(data_type, str): - warn = True - if data_type == "string": - converted_data_type = pa.string() - elif data_type == "int": - converted_data_type = pa.int32() - else: - message = ( - f"Unsupported literal data type '{data_type}' for partition " - "column. Supported types are 'string' and 'int'" - ) - raise ValueError(message) - else: - converted_data_type = data_type - - converted_table_partition_cols.append((col, converted_data_type)) - - if warn: - message = ( - "using literals for table_partition_cols data types is deprecated," - "use pyarrow types instead" - ) - warnings.warn( - message, - category=DeprecationWarning, - stacklevel=2, - ) - - return converted_table_partition_cols diff --git a/python/datafusion/io.py b/python/datafusion/io.py index 67dbc730f..4f9c3c516 100644 --- a/python/datafusion/io.py +++ b/python/datafusion/io.py @@ -31,6 +31,8 @@ from datafusion.dataframe import DataFrame from datafusion.expr import Expr + from .options import CsvReadOptions + def read_parquet( path: str | pathlib.Path, @@ -126,6 +128,7 @@ def read_csv( file_extension: str = ".csv", table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, file_compression_type: str | None = None, + options: CsvReadOptions | None = None, ) -> DataFrame: """Read a CSV data source. @@ -147,15 +150,12 @@ def read_csv( selected for data input. table_partition_cols: Partition columns. file_compression_type: File compression type. + options: Set advanced options for CSV reading. This cannot be + combined with any of the other options in this method. 
Returns: DataFrame representation of the read CSV files """ - if table_partition_cols is None: - table_partition_cols = [] - - path = [str(p) for p in path] if isinstance(path, list) else str(path) - return SessionContext.global_ctx().read_csv( path, schema, @@ -165,6 +165,7 @@ def read_csv( file_extension, table_partition_cols, file_compression_type, + options, ) diff --git a/python/datafusion/options.py b/python/datafusion/options.py new file mode 100644 index 000000000..648cffa76 --- /dev/null +++ b/python/datafusion/options.py @@ -0,0 +1,273 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Options for reading various file formats.""" + +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING + +import pyarrow as pa + +if TYPE_CHECKING: + from datafusion.expr import SortExpr + +from ._internal import options + +__all__ = ["CsvReadOptions"] + +DEFAULT_MAX_INFER_SCHEMA = 1000 + + +class CsvReadOptions: + """Options for reading CSV files. + + This class provides a builder pattern for configuring CSV reading options. + All methods starting with ``with_`` return ``self`` to allow method chaining. + """ + + def __init__( + self, + *, + has_header: bool = True, + delimiter: str = ",", + quote: str = '"', + terminator: str | None = None, + escape: str | None = None, + comment: str | None = None, + newlines_in_values: bool = False, + schema: pa.Schema | None = None, + schema_infer_max_records: int = DEFAULT_MAX_INFER_SCHEMA, + file_extension: str = ".csv", + table_partition_cols: list[tuple[str, pa.DataType]] | None = None, + file_compression_type: str = "", + file_sort_order: list[list[SortExpr]] | None = None, + null_regex: str | None = None, + truncated_rows: bool = False, + ) -> None: + """Initialize CsvReadOptions. + + Args: + has_header: Does the CSV file have a header row? If schema inference + is run on a file with no headers, default column names are created. + delimiter: Column delimiter character. Must be a single ASCII character. + quote: Quote character for fields containing delimiters or newlines. + Must be a single ASCII character. + terminator: Optional line terminator character. If ``None``, uses CRLF. + Must be a single ASCII character. + escape: Optional escape character for quotes. Must be a single ASCII + character. + comment: If specified, lines beginning with this character are ignored. + Must be a single ASCII character. + newlines_in_values: Whether newlines in quoted values are supported. + Parsing newlines in quoted values may be affected by execution + behavior such as parallel file scanning. Setting this to ``True`` + ensures that newlines in values are parsed successfully, which may + reduce performance. + schema: Optional PyArrow schema representing the CSV files. 
If ``None``, + the CSV reader will try to infer it based on data in the file. + schema_infer_max_records: Maximum number of rows to read from CSV files + for schema inference if needed. + file_extension: File extension; only files with this extension are + selected for data input. + table_partition_cols: Partition columns as a list of tuples of + (column_name, data_type). + file_compression_type: File compression type. Supported values are + ``"gzip"``, ``"bz2"``, ``"xz"``, ``"zstd"``, or empty string for + uncompressed. + file_sort_order: Optional sort order of the files as a list of sort + expressions per file. + null_regex: Optional regex pattern to match null values in the CSV. + truncated_rows: Whether to allow truncated rows when parsing. By default + this is ``False`` and will error if the CSV rows have different + lengths. When set to ``True``, it will allow records with less than + the expected number of columns and fill the missing columns with + nulls. If the record's schema is not nullable, it will still return + an error. + """ + validate_single_character("delimiter", delimiter) + validate_single_character("quote", quote) + validate_single_character("terminator", terminator) + validate_single_character("escape", escape) + validate_single_character("comment", comment) + + self.has_header = has_header + self.delimiter = delimiter + self.quote = quote + self.terminator = terminator + self.escape = escape + self.comment = comment + self.newlines_in_values = newlines_in_values + self.schema = schema + self.schema_infer_max_records = schema_infer_max_records + self.file_extension = file_extension + self.table_partition_cols = table_partition_cols or [] + self.file_compression_type = file_compression_type + self.file_sort_order = file_sort_order or [] + self.null_regex = null_regex + self.truncated_rows = truncated_rows + + def with_has_header(self, has_header: bool) -> CsvReadOptions: + """Configure whether the CSV has a header row.""" + self.has_header = has_header + return self + + def with_delimiter(self, delimiter: str) -> CsvReadOptions: + """Configure the column delimiter.""" + self.delimiter = delimiter + return self + + def with_quote(self, quote: str) -> CsvReadOptions: + """Configure the quote character.""" + self.quote = quote + return self + + def with_terminator(self, terminator: str | None) -> CsvReadOptions: + """Configure the line terminator character.""" + self.terminator = terminator + return self + + def with_escape(self, escape: str | None) -> CsvReadOptions: + """Configure the escape character.""" + self.escape = escape + return self + + def with_comment(self, comment: str | None) -> CsvReadOptions: + """Configure the comment character.""" + self.comment = comment + return self + + def with_newlines_in_values(self, newlines_in_values: bool) -> CsvReadOptions: + """Configure whether newlines in values are supported.""" + self.newlines_in_values = newlines_in_values + return self + + def with_schema(self, schema: pa.Schema | None) -> CsvReadOptions: + """Configure the schema.""" + self.schema = schema + return self + + def with_schema_infer_max_records( + self, schema_infer_max_records: int + ) -> CsvReadOptions: + """Configure maximum records for schema inference.""" + self.schema_infer_max_records = schema_infer_max_records + return self + + def with_file_extension(self, file_extension: str) -> CsvReadOptions: + """Configure the file extension filter.""" + self.file_extension = file_extension + return self + + def with_table_partition_cols( + self, 
table_partition_cols: list[tuple[str, pa.DataType]] + ) -> CsvReadOptions: + """Configure table partition columns.""" + self.table_partition_cols = table_partition_cols + return self + + def with_file_compression_type(self, file_compression_type: str) -> CsvReadOptions: + """Configure file compression type.""" + self.file_compression_type = file_compression_type + return self + + def with_file_sort_order( + self, file_sort_order: list[list[SortExpr]] + ) -> CsvReadOptions: + """Configure file sort order.""" + self.file_sort_order = file_sort_order + return self + + def with_null_regex(self, null_regex: str | None) -> CsvReadOptions: + """Configure null value regex pattern.""" + self.null_regex = null_regex + return self + + def with_truncated_rows(self, truncated_rows: bool) -> CsvReadOptions: + """Configure whether to allow truncated rows.""" + self.truncated_rows = truncated_rows + return self + + def to_inner(self) -> options.CsvReadOptions: + """Convert this object into the underlying Rust structure. + + This is intended for internal use only. + """ + return options.CsvReadOptions( + has_header=self.has_header, + delimiter=ord(self.delimiter[0]) if self.delimiter else ord(","), + quote=ord(self.quote[0]) if self.quote else ord('"'), + terminator=ord(self.terminator[0]) if self.terminator else None, + escape=ord(self.escape[0]) if self.escape else None, + comment=ord(self.comment[0]) if self.comment else None, + newlines_in_values=self.newlines_in_values, + schema=self.schema, + schema_infer_max_records=self.schema_infer_max_records, + file_extension=self.file_extension, + table_partition_cols=_convert_table_partition_cols( + self.table_partition_cols + ), + file_compression_type=self.file_compression_type or "", + file_sort_order=self.file_sort_order or [], + null_regex=self.null_regex, + truncated_rows=self.truncated_rows, + ) + + +def validate_single_character(name: str, value: str | None) -> None: + if value is not None and len(value) != 1: + message = f"{name} must be a single character" + raise ValueError(message) + + +def _convert_table_partition_cols( + table_partition_cols: list[tuple[str, str | pa.DataType]], +) -> list[tuple[str, pa.DataType]]: + warn = False + converted_table_partition_cols = [] + + for col, data_type in table_partition_cols: + if isinstance(data_type, str): + warn = True + if data_type == "string": + converted_data_type = pa.string() + elif data_type == "int": + converted_data_type = pa.int32() + else: + message = ( + f"Unsupported literal data type '{data_type}' for partition " + "column. 
Supported types are 'string' and 'int'" + ) + raise ValueError(message) + else: + converted_data_type = data_type + + converted_table_partition_cols.append((col, converted_data_type)) + + if warn: + message = ( + "using literals for table_partition_cols data types is deprecated," + "use pyarrow types instead" + ) + warnings.warn( + message, + category=DeprecationWarning, + stacklevel=2, + ) + + return converted_table_partition_cols diff --git a/python/tests/test_context.py b/python/tests/test_context.py index bd65305ed..57e785c05 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -710,3 +710,68 @@ def test_create_dataframe_with_global_ctx(batch): result = df.collect()[0].column(0) assert result == pa.array([4, 5, 6]) + + +def test_csv_read_options_builder_pattern(): + """Test CsvReadOptions builder pattern.""" + from datafusion import CsvReadOptions + + options = ( + CsvReadOptions() + .with_has_header(False) # noqa: FBT003 + .with_delimiter("|") + .with_quote("'") + .with_schema_infer_max_records(2000) + .with_truncated_rows(True) # noqa: FBT003 + .with_newlines_in_values(True) # noqa: FBT003 + .with_file_extension(".tsv") + ) + assert options.has_header is False + assert options.delimiter == "|" + assert options.quote == "'" + assert options.schema_infer_max_records == 2000 + assert options.truncated_rows is True + assert options.newlines_in_values is True + assert options.file_extension == ".tsv" + + +@pytest.mark.parametrize( + ("as_read", "global_ctx"), + [ + (True, True), + (True, False), + (False, False), + ], +) +def test_read_csv_with_options(tmp_path, as_read, global_ctx): + """Test reading CSV with CsvReadOptions.""" + from datafusion import CsvReadOptions, SessionContext + + # Create a test CSV file + csv_path = tmp_path / "test.csv" + csv_content = "name;age;city\nAlice;30;New York\nBob;25\n#Charlie;35;Paris" + csv_path.write_text(csv_content) + + ctx = SessionContext() + + # Test with CsvReadOptions + options = CsvReadOptions( + has_header=True, delimiter=";", comment="#", truncated_rows=True + ) + + if as_read: + if global_ctx: + from datafusion.io import read_csv + + df = read_csv(str(csv_path), options=options) + else: + df = ctx.read_csv(str(csv_path), options=options) + else: + ctx.register_csv("test_table", str(csv_path), options=options) + df = ctx.sql("SELECT * FROM test_table") + + # Verify the data + result = df.collect() + assert len(result) == 1 + assert result[0].num_columns == 3 + assert result[0].column(0).to_pylist() == ["Alice", "Bob", None] diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py index 85afd021f..48c374660 100644 --- a/python/tests/test_sql.py +++ b/python/tests/test_sql.py @@ -92,7 +92,7 @@ def test_register_csv(ctx, tmp_path): result = pa.Table.from_batches(result) assert result.schema == alternative_schema - with pytest.raises(ValueError, match="Delimiter must be a single character"): + with pytest.raises(ValueError, match="delimiter must be a single character"): ctx.register_csv("csv4", path, delimiter="wrong") with pytest.raises( diff --git a/src/context.rs b/src/context.rs index ad4fc36b1..36d662903 100644 --- a/src/context.rs +++ b/src/context.rs @@ -60,9 +60,9 @@ use crate::dataframe::PyDataFrame; use crate::dataset::Dataset; use crate::errors::{py_datafusion_err, PyDataFusionError, PyDataFusionResult}; use crate::expr::sort_expr::PySortExpr; +use crate::options::PyCsvReadOptions; use crate::physical_plan::PyExecutionPlan; use crate::record_batch::PyRecordBatchStream; -use 
crate::sql::exceptions::py_value_err; use crate::sql::logical::PyLogicalPlan; use crate::sql::util::replace_placeholders_with_strings; use crate::store::StorageContexts; @@ -710,38 +710,18 @@ impl PySessionContext { #[allow(clippy::too_many_arguments)] #[pyo3(signature = (name, path, - schema=None, - has_header=true, - delimiter=",", - schema_infer_max_records=1000, - file_extension=".csv", - file_compression_type=None))] + options=None))] pub fn register_csv( &self, name: &str, path: &Bound<'_, PyAny>, - schema: Option>, - has_header: bool, - delimiter: &str, - schema_infer_max_records: usize, - file_extension: &str, - file_compression_type: Option, + options: Option<&PyCsvReadOptions>, py: Python, ) -> PyDataFusionResult<()> { - let delimiter = delimiter.as_bytes(); - if delimiter.len() != 1 { - return Err(PyDataFusionError::PythonError(py_value_err( - "Delimiter must be a single character", - ))); - } - - let mut options = CsvReadOptions::new() - .has_header(has_header) - .delimiter(delimiter[0]) - .schema_infer_max_records(schema_infer_max_records) - .file_extension(file_extension) - .file_compression_type(parse_file_compression_type(file_compression_type)?); - options.schema = schema.as_ref().map(|x| &x.0); + let options = options + .map(|opts| opts.try_into()) + .transpose()? + .unwrap_or_default(); if path.is_instance_of::() { let paths = path.extract::>()?; @@ -963,45 +943,17 @@ impl PySessionContext { #[allow(clippy::too_many_arguments)] #[pyo3(signature = ( path, - schema=None, - has_header=true, - delimiter=",", - schema_infer_max_records=1000, - file_extension=".csv", - table_partition_cols=vec![], - file_compression_type=None))] + options=None))] pub fn read_csv( &self, path: &Bound<'_, PyAny>, - schema: Option>, - has_header: bool, - delimiter: &str, - schema_infer_max_records: usize, - file_extension: &str, - table_partition_cols: Vec<(String, PyArrowType)>, - file_compression_type: Option, + options: Option<&PyCsvReadOptions>, py: Python, ) -> PyDataFusionResult { - let delimiter = delimiter.as_bytes(); - if delimiter.len() != 1 { - return Err(PyDataFusionError::PythonError(py_value_err( - "Delimiter must be a single character", - ))); - }; - - let mut options = CsvReadOptions::new() - .has_header(has_header) - .delimiter(delimiter[0]) - .schema_infer_max_records(schema_infer_max_records) - .file_extension(file_extension) - .table_partition_cols( - table_partition_cols - .into_iter() - .map(|(name, ty)| (name, ty.0)) - .collect::>(), - ) - .file_compression_type(parse_file_compression_type(file_compression_type)?); - options.schema = schema.as_ref().map(|x| &x.0); + let options = options + .map(|opts| opts.try_into()) + .transpose()? 
+ .unwrap_or_default(); if path.is_instance_of::() { let paths = path.extract::>()?; diff --git a/src/lib.rs b/src/lib.rs index 9483a5252..47805f5a8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,6 +43,7 @@ pub mod errors; pub mod expr; #[allow(clippy::borrow_deref_ref)] mod functions; +mod options; pub mod physical_plan; mod pyarrow_filter_expression; pub mod pyarrow_util; @@ -125,6 +126,10 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> { store::init_module(&store)?; m.add_submodule(&store)?; + let options = PyModule::new(py, "options")?; + options::init_module(&options)?; + m.add_submodule(&options)?; + // Register substrait as a submodule #[cfg(feature = "substrait")] setup_substrait_module(py, &m)?; diff --git a/src/options.rs b/src/options.rs new file mode 100644 index 000000000..a37664b2e --- /dev/null +++ b/src/options.rs @@ -0,0 +1,142 @@ +use arrow::datatypes::{DataType, Schema}; +use arrow::pyarrow::PyArrowType; +use datafusion::prelude::CsvReadOptions; +use pyo3::prelude::{PyModule, PyModuleMethods}; +use pyo3::{pyclass, pymethods, Bound, PyResult}; + +use crate::context::parse_file_compression_type; +use crate::errors::PyDataFusionError; +use crate::expr::sort_expr::PySortExpr; + +/// Options for reading CSV files +#[pyclass(name = "CsvReadOptions", module = "datafusion.options", frozen)] +pub struct PyCsvReadOptions { + pub has_header: bool, + pub delimiter: u8, + pub quote: u8, + pub terminator: Option, + pub escape: Option, + pub comment: Option, + pub newlines_in_values: bool, + pub schema: Option>, + pub schema_infer_max_records: usize, + pub file_extension: String, + pub table_partition_cols: Vec<(String, PyArrowType)>, + pub file_compression_type: String, + pub file_sort_order: Vec>, + pub null_regex: Option, + pub truncated_rows: bool, +} + +#[pymethods] +impl PyCsvReadOptions { + #[allow(clippy::too_many_arguments)] + #[pyo3(signature = ( + has_header=true, + delimiter=b',', + quote=b'"', + terminator=None, + escape=None, + comment=None, + newlines_in_values=false, + schema=None, + schema_infer_max_records=1000, + file_extension=".csv".to_string(), + table_partition_cols=vec![], + file_compression_type="".to_string(), + file_sort_order=vec![], + null_regex=None, + truncated_rows=false + ))] + #[new] + fn new( + has_header: bool, + delimiter: u8, + quote: u8, + terminator: Option, + escape: Option, + comment: Option, + newlines_in_values: bool, + schema: Option>, + schema_infer_max_records: usize, + file_extension: String, + table_partition_cols: Vec<(String, PyArrowType)>, + file_compression_type: String, + file_sort_order: Vec>, + null_regex: Option, + truncated_rows: bool, + ) -> Self { + Self { + has_header, + delimiter, + quote, + terminator, + escape, + comment, + newlines_in_values, + schema, + schema_infer_max_records, + file_extension, + table_partition_cols, + file_compression_type, + file_sort_order, + null_regex, + truncated_rows, + } + } +} + +impl<'a> TryFrom<&'a PyCsvReadOptions> for CsvReadOptions<'a> { + type Error = PyDataFusionError; + + fn try_from(value: &'a PyCsvReadOptions) -> Result, Self::Error> { + let partition_cols: Vec<(String, DataType)> = value + .table_partition_cols + .iter() + .map(|(name, dtype)| (name.clone(), dtype.0.clone())) + .collect(); + + let compression = parse_file_compression_type(Some(value.file_compression_type.clone()))?; + + let sort_order: Vec> = value + .file_sort_order + .iter() + .map(|inner| { + inner + .iter() + .map(|sort_expr| sort_expr.sort.clone()) + .collect() + }) + .collect(); + + // 
Explicit struct initialization to catch upstream changes + let mut options = CsvReadOptions { + has_header: value.has_header, + delimiter: value.delimiter, + quote: value.quote, + terminator: value.terminator, + escape: value.escape, + comment: value.comment, + newlines_in_values: value.newlines_in_values, + schema: None, // Will be set separately due to lifetime constraints + schema_infer_max_records: value.schema_infer_max_records, + file_extension: value.file_extension.as_str(), + table_partition_cols: partition_cols, + file_compression_type: compression, + file_sort_order: sort_order, + null_regex: value.null_regex.clone(), + truncated_rows: value.truncated_rows, + }; + + // Set schema separately to handle the lifetime + options.schema = value.schema.as_ref().map(|s| &s.0); + + Ok(options) + } +} + +pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; + + Ok(()) +}