diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index ecf5545bc..728b9c390 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -31,20 +31,16 @@ # The following imports are okay to remain as opaque to the user. from ._internal import Config from .catalog import Catalog, Database, Table -from .common import ( - DFSchema, -) +from .common import DFSchema from .context import ( + DataframeDisplayConfig, RuntimeEnvBuilder, SessionConfig, SessionContext, SQLOptions, ) from .dataframe import DataFrame -from .expr import ( - Expr, - WindowFrame, -) +from .expr import Expr, WindowFrame from .io import read_avro, read_csv, read_json, read_parquet from .plan import ExecutionPlan, LogicalPlan from .record_batch import RecordBatch, RecordBatchStream @@ -60,6 +56,7 @@ "DFSchema", "DataFrame", "Database", + "DataframeDisplayConfig", "ExecutionPlan", "Expr", "LogicalPlan", diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 1429a4975..c579a054b 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -19,7 +19,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Protocol +from typing import TYPE_CHECKING, Any, Optional, Protocol try: from warnings import deprecated # Python 3.13+ @@ -32,6 +32,7 @@ from datafusion.record_batch import RecordBatchStream from datafusion.udf import AggregateUDF, ScalarUDF, WindowUDF +from ._internal import DataframeDisplayConfig as DataframeDisplayConfigInternal from ._internal import RuntimeEnvBuilder as RuntimeEnvBuilderInternal from ._internal import SessionConfig as SessionConfigInternal from ._internal import SessionContext as SessionContextInternal @@ -78,6 +79,106 @@ class TableProviderExportable(Protocol): def __datafusion_table_provider__(self) -> object: ... # noqa: D105 +class DataframeDisplayConfig: + """Configuration for displaying DataFrame results. 
+ + This class allows you to control how DataFrames are displayed in Python. + """ + + def __init__( + self, + max_table_bytes: Optional[int] = None, + min_table_rows: Optional[int] = None, + max_cell_length: Optional[int] = None, + max_table_rows_in_repr: Optional[int] = None, + ) -> None: + """Create a new :py:class:`DataframeDisplayConfig` instance. + + Args: + max_table_bytes: Maximum bytes to display for table presentation + (default: 2MB) + min_table_rows: Minimum number of table rows to display + (default: 20) + max_cell_length: Maximum length of a cell before it gets minimized + (default: 25) + max_table_rows_in_repr: Maximum number of rows to display in repr + and HTML output (default: 10) + """ + # Validate values if they are not None + if max_table_bytes is not None: + self._validate_positive(max_table_bytes, "max_table_bytes") + if min_table_rows is not None: + self._validate_positive(min_table_rows, "min_table_rows") + if max_cell_length is not None: + self._validate_positive(max_cell_length, "max_cell_length") + if max_table_rows_in_repr is not None: + self._validate_positive(max_table_rows_in_repr, "max_table_rows_in_repr") + self.config_internal = DataframeDisplayConfigInternal( + max_table_bytes=max_table_bytes, + min_table_rows=min_table_rows, + max_cell_length=max_cell_length, + max_table_rows_in_repr=max_table_rows_in_repr, + ) + + def _validate_positive(self, value: int, name: str) -> None: + """Validate that the given value is positive.
+ + Args: + value: The value to validate + name: The name of the parameter for the error message + + Raises: + ValueError: If the value is not positive + """ + if value <= 0: + error_message = f"{name} must be greater than 0" + raise ValueError(error_message) + + @property + def max_table_bytes(self) -> int: + """Get the maximum bytes to display for table presentation.""" + return self.config_internal.max_table_bytes + + @max_table_bytes.setter + def max_table_bytes(self, value: int) -> None: + """Set the maximum bytes to display for table presentation.""" + self._validate_positive(value, "max_table_bytes") + self.config_internal.max_table_bytes = value + + @property + def min_table_rows(self) -> int: + """Get the minimum number of table rows to display.""" + return self.config_internal.min_table_rows + + @min_table_rows.setter + def min_table_rows(self, value: int) -> None: + """Set the minimum number of table rows to display.""" + self._validate_positive(value, "min_table_rows") + self.config_internal.min_table_rows = value + + @property + def max_cell_length(self) -> int: + """Get the maximum length of a cell before it gets minimized.""" + return self.config_internal.max_cell_length + + @max_cell_length.setter + def max_cell_length(self, value: int) -> None: + """Set the maximum length of a cell before it gets minimized.""" + self._validate_positive(value, "max_cell_length") + self.config_internal.max_cell_length = value + + @property + def max_table_rows_in_repr(self) -> int: + """Get the maximum number of rows to display in repr string output.""" + return self.config_internal.max_table_rows_in_repr + + @max_table_rows_in_repr.setter + def max_table_rows_in_repr(self, value: int) -> None: + """Set the maximum number of rows to display in repr string output.""" + self._validate_positive(value, "max_table_rows_in_repr") + self.config_internal.max_table_rows_in_repr = value + + class SessionConfig: """Session configuration options.""" @@ -470,6 +571,7 @@ def 
__init__( self, config: SessionConfig | None = None, runtime: RuntimeEnvBuilder | None = None, + display_config: DataframeDisplayConfig | None = None, ) -> None: """Main interface for executing queries with DataFusion. @@ -480,7 +582,7 @@ def __init__( Args: config: Session configuration options. runtime: Runtime configuration options. - + display_config: DataFrame display configuration options. Example usage: The following example demonstrates how to use the context to execute @@ -493,8 +595,10 @@ def __init__( """ config = config.config_internal if config is not None else None runtime = runtime.config_internal if runtime is not None else None - - self.ctx = SessionContextInternal(config, runtime) + display_config = ( + display_config.config_internal if display_config is not None else None + ) + self.ctx = SessionContextInternal(config, runtime, display_config) @classmethod def global_ctx(cls) -> SessionContext: @@ -508,6 +612,40 @@ def global_ctx(cls) -> SessionContext: wrapper.ctx = internal_ctx return wrapper + def with_display_config( + self, + max_table_bytes: Optional[int] = None, + min_table_rows: Optional[int] = None, + max_cell_length: Optional[int] = None, + max_table_rows_in_repr: Optional[int] = None, + ) -> SessionContext: + """Configure the display options for DataFrames. + + Args: + max_table_bytes: Maximum bytes to display for table presentation + (default: 2MB) + min_table_rows: Minimum number of table rows to display + (default: 20) + max_cell_length: Maximum length of a cell before it gets minimized + (default: 25) + max_table_rows_in_repr: Maximum number of rows to display in repr + string output (default: 10) + + Returns: + A new :py:class:`SessionContext` object with the updated display settings. 
+ """ + display_config = DataframeDisplayConfig( + max_table_bytes=max_table_bytes, + min_table_rows=min_table_rows, + max_cell_length=max_cell_length, + max_table_rows_in_repr=max_table_rows_in_repr, + ) + + klass = self.__class__ + obj = klass.__new__(klass) + obj.ctx = self.ctx.with_display_config(display_config.config_internal) + return obj + def enable_url_table(self) -> SessionContext: """Control if local files can be queried as tables. @@ -806,9 +944,11 @@ def register_parquet( file_extension, skip_metadata, schema, - [sort_list_to_raw_sort_list(exprs) for exprs in file_sort_order] - if file_sort_order is not None - else None, + ( + [sort_list_to_raw_sort_list(exprs) for exprs in file_sort_order] + if file_sort_order is not None + else None + ), ) def register_csv( diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index eda13930d..e10613d9b 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -29,6 +29,7 @@ literal, ) from datafusion import functions as f +from datafusion.context import DataframeDisplayConfig from datafusion.expr import Window from pyarrow.csv import write_csv @@ -51,6 +52,134 @@ def df(): return ctx.from_arrow(batch) +@pytest.fixture +def data(): + return [{"a": 1, "b": "x" * 50, "c": 3}] * 100 + + +@pytest.fixture +def span_expandable_class(): + return '" in html_repr2 + + +def test_display_config_in_init(data): + # Test display config directly in SessionContext constructor + display_config = DataframeDisplayConfig( + max_table_bytes=1024, + min_table_rows=5, + max_cell_length=10, + max_table_rows_in_repr=3, + ) + + ctx = SessionContext(display_config=display_config) + df1 = ctx.from_pylist(data) + html_repr1 = df1._repr_html_() + + # Create a context with custom display config through the with_display_config method + ctx2 = ctx.with_display_config( + max_table_bytes=1024, + min_table_rows=5, + max_cell_length=10, + max_table_rows_in_repr=3, + ) + df2 = ctx2.from_pylist(data) + 
html_repr2 = df2._repr_html_() + + # Both methods should result in equivalent display configuration + assert normalize_uuid(html_repr1) == normalize_uuid(html_repr2) + + @pytest.fixture def struct_df(): ctx = SessionContext() @@ -1261,3 +1390,130 @@ def test_dataframe_repr_html(df) -> None: body_lines = [f"{v}" for inner in body_data for v in inner] body_pattern = "(.*?)".join(body_lines) assert len(re.findall(body_pattern, output, re.DOTALL)) == 1 + + +def test_display_config_affects_repr(data): + max_table_rows_in_repr = 3 + # Create a context with custom display config + ctx = SessionContext().with_display_config( + max_table_rows_in_repr=max_table_rows_in_repr + ) + + # Create a DataFrame with more rows than the display limit + df = ctx.from_pylist(data) + + repr_str = repr(df) + + # The representation should show truncated data (3 rows as specified) + assert ( + # 5 = 1 header row + 3 separator line + 1 truncation message + repr_str.count("\n") <= max_table_rows_in_repr + 5 + ) + assert "Data truncated" in repr_str + + # Create a context with larger display limit + max_table_rows_in_repr = 100 + ctx2 = SessionContext().with_display_config( + max_table_rows_in_repr=max_table_rows_in_repr + ) + + df2 = ctx2.from_pylist(data) + repr_str2 = repr(df2) + + # Should show all data without truncation message + assert ( + # 4 = 1 header row + 3 separator lines + repr_str2.count("\n") == max_table_rows_in_repr + 4 + ) # All rows should be shown + assert "Data truncated" not in repr_str2 + + +def test_display_config_affects_html_repr(data, span_expandable_class): + # Create a context with custom display config to show only a small cell length + ctx = SessionContext().with_display_config(max_cell_length=5) + + # Create a DataFrame with a column containing long strings + df = ctx.from_pylist(data) + + # Get the HTML representation + html_str = df._repr_html_() + + # The cell should be truncated to 5 characters and have expansion button + assert ">xxxxx" in html_str # 5 
character limit + assert span_expandable_class in html_str + + # Create a context with larger cell length limit + ctx2 = SessionContext().with_display_config(max_cell_length=60) + + df2 = ctx2.from_pylist(data) + html_str2 = df2._repr_html_() + + # String shouldn't be truncated (or at least not in the same way) + assert span_expandable_class not in html_str2 + + +def test_display_config_rows_limit_in_html(data): + max_table_rows = 5 + # Create a context with custom display config to limit rows + ctx = SessionContext().with_display_config( + max_table_rows_in_repr=max_table_rows, + ) + + # Create a DataFrame with 100 rows + df = ctx.from_pylist(data) + + # Get the HTML representation + html_str = df._repr_html_() + + # Only a few rows should be shown and there should be a truncation message + row_count = html_str.count("") - 1 # Subtract 1 for header row + assert row_count <= max_table_rows + assert "Data truncated" in html_str + + # Create a context with larger row limit + max_table_rows = 100 + ctx2 = SessionContext().with_display_config( + max_table_rows_in_repr=max_table_rows + ) # Show more rows + + df2 = ctx2.from_pylist(data) + html_str2 = df2._repr_html_() + + # Should show all rows + row_count2 = html_str2.count("") - 1 # Subtract 1 for header row + assert row_count2 == max_table_rows + assert "Data truncated" not in html_str2 + + +def test_display_config_max_bytes_limit(data): + min_table_rows = 10 + max_table_rows = 20 + # Create a context with custom display config with very small byte limit + ctx = SessionContext().with_display_config( + min_table_rows=min_table_rows, + max_table_rows_in_repr=max_table_rows, + max_table_bytes=100, + ) # Very small limit + + # Create a DataFrame with large content + df = ctx.from_pylist(data) + + # Get the HTML representation + html_str = df._repr_html_() + + # Due to small byte limit, we should see truncation + row_count = html_str.count("") - 1 # Subtract 1 for header row + assert row_count <= min_table_rows # Should
not show all 100 rows + assert "Data truncated" in html_str + + # With a larger byte limit + ctx2 = SessionContext().with_display_config( + max_table_bytes=10 * 1024 * 1024 # 10 MB, much more than needed + ) + + df2 = ctx2.from_pylist(data) + html_str2 = df2._repr_html_() + + # Should show at least min_table_rows rows + row_count2 = html_str2.count("") - 1 # Subtract 1 for header row + assert row_count2 >= min_table_rows # Should show at least min_table_rows diff --git a/src/context.rs b/src/context.rs index 0db0f4d7e..6147cceff 100644 --- a/src/context.rs +++ b/src/context.rs @@ -72,6 +72,56 @@ use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider}; use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType}; use tokio::task::JoinHandle; +/// Configuration for displaying DataFrames +#[pyclass(name = "DataframeDisplayConfig", module = "datafusion", subclass)] +#[derive(Clone)] +pub struct PyDataframeDisplayConfig { + /// Maximum bytes to display for table presentation (default: 2MB) + #[pyo3(get, set)] + pub max_table_bytes: usize, + /// Minimum number of table rows to display (default: 20) + #[pyo3(get, set)] + pub min_table_rows: usize, + /// Maximum length of a cell before it gets minimized (default: 25) + #[pyo3(get, set)] + pub max_cell_length: usize, + /// Maximum number of rows to display in repr string output (default: 10) + #[pyo3(get, set)] + pub max_table_rows_in_repr: usize, +} + +#[pymethods] +impl PyDataframeDisplayConfig { + #[new] + #[pyo3(signature = (max_table_bytes=None, min_table_rows=None, max_cell_length=None, max_table_rows_in_repr=None))] + fn new( + max_table_bytes: Option<usize>, + min_table_rows: Option<usize>, + max_cell_length: Option<usize>, + max_table_rows_in_repr: Option<usize>, + ) -> Self { + let default = Self::default(); + Self { + max_table_bytes: max_table_bytes.unwrap_or(default.max_table_bytes), + min_table_rows: min_table_rows.unwrap_or(default.min_table_rows), + max_cell_length: max_cell_length.unwrap_or(default.max_cell_length), +
max_table_rows_in_repr: max_table_rows_in_repr + .unwrap_or(default.max_table_rows_in_repr), + } + } +} + +impl Default for PyDataframeDisplayConfig { + fn default() -> Self { + Self { + max_table_bytes: 2 * 1024 * 1024, // 2 MB + min_table_rows: 20, + max_cell_length: 25, + max_table_rows_in_repr: 10, + } + } +} + /// Configuration options for a SessionContext #[pyclass(name = "SessionConfig", module = "datafusion", subclass)] #[derive(Clone, Default)] @@ -269,15 +319,17 @@ impl PySQLOptions { #[derive(Clone)] pub struct PySessionContext { pub ctx: SessionContext, + pub display_config: PyDataframeDisplayConfig, } #[pymethods] impl PySessionContext { - #[pyo3(signature = (config=None, runtime=None))] + #[pyo3(signature = (config=None, runtime=None, display_config=None))] #[new] pub fn new( config: Option, runtime: Option, + display_config: Option, ) -> PyDataFusionResult { let config = if let Some(c) = config { c.config @@ -295,22 +347,33 @@ impl PySessionContext { .with_runtime_env(runtime) .with_default_features() .build(); + Ok(PySessionContext { ctx: SessionContext::new_with_state(session_state), + display_config: display_config.unwrap_or_default(), }) } pub fn enable_url_table(&self) -> PyResult { Ok(PySessionContext { ctx: self.ctx.clone().enable_url_table(), + display_config: self.display_config.clone(), }) } + pub fn with_display_config(&self, display_config: PyDataframeDisplayConfig) -> Self { + Self { + ctx: self.ctx.clone(), + display_config, + } + } + #[classmethod] #[pyo3(signature = ())] fn global_ctx(_cls: &Bound<'_, PyType>) -> PyResult { Ok(Self { ctx: get_global_ctx().clone(), + display_config: PyDataframeDisplayConfig::default(), }) } @@ -394,7 +457,7 @@ impl PySessionContext { pub fn sql(&mut self, query: &str, py: Python) -> PyDataFusionResult { let result = self.ctx.sql(query); let df = wait_for_future(py, result)?; - Ok(PyDataFrame::new(df)) + Ok(PyDataFrame::new(df, self.display_config.clone())) } #[pyo3(signature = (query, options=None))] 
@@ -411,7 +474,7 @@ impl PySessionContext { }; let result = self.ctx.sql_with_options(query, options); let df = wait_for_future(py, result)?; - Ok(PyDataFrame::new(df)) + Ok(PyDataFrame::new(df, self.display_config.clone())) } #[pyo3(signature = (partitions, name=None, schema=None))] @@ -446,13 +509,16 @@ impl PySessionContext { let table = wait_for_future(py, self._table(&table_name))?; - let df = PyDataFrame::new(table); + let df = PyDataFrame::new(table, self.display_config.clone()); Ok(df) } /// Create a DataFrame from an existing logical plan pub fn create_dataframe_from_logical_plan(&mut self, plan: PyLogicalPlan) -> PyDataFrame { - PyDataFrame::new(DataFrame::new(self.ctx.state(), plan.plan.as_ref().clone())) + PyDataFrame::new( + DataFrame::new(self.ctx.state(), plan.plan.as_ref().clone()), + self.display_config.clone(), + ) } /// Construct datafusion dataframe from Python list @@ -820,7 +886,7 @@ impl PySessionContext { pub fn table(&self, name: &str, py: Python) -> PyResult { let x = wait_for_future(py, self.ctx.table(name)) .map_err(|e| PyKeyError::new_err(e.to_string()))?; - Ok(PyDataFrame::new(x)) + Ok(PyDataFrame::new(x, self.display_config.clone())) } pub fn table_exist(&self, name: &str) -> PyDataFusionResult { @@ -828,7 +894,10 @@ impl PySessionContext { } pub fn empty_table(&self) -> PyDataFusionResult { - Ok(PyDataFrame::new(self.ctx.read_empty()?)) + Ok(PyDataFrame::new( + self.ctx.read_empty()?, + self.display_config.clone(), + )) } pub fn session_id(&self) -> String { @@ -863,7 +932,7 @@ impl PySessionContext { let result = self.ctx.read_json(path, options); wait_for_future(py, result)? 
}; - Ok(PyDataFrame::new(df)) + Ok(PyDataFrame::new(df, self.display_config.clone())) } #[allow(clippy::too_many_arguments)] @@ -908,12 +977,12 @@ impl PySessionContext { let paths = path.extract::>()?; let paths = paths.iter().map(|p| p as &str).collect::>(); let result = self.ctx.read_csv(paths, options); - let df = PyDataFrame::new(wait_for_future(py, result)?); + let df = PyDataFrame::new(wait_for_future(py, result)?, self.display_config.clone()); Ok(df) } else { let path = path.extract::()?; let result = self.ctx.read_csv(path, options); - let df = PyDataFrame::new(wait_for_future(py, result)?); + let df = PyDataFrame::new(wait_for_future(py, result)?, self.display_config.clone()); Ok(df) } } @@ -951,7 +1020,7 @@ impl PySessionContext { .collect(); let result = self.ctx.read_parquet(path, options); - let df = PyDataFrame::new(wait_for_future(py, result)?); + let df = PyDataFrame::new(wait_for_future(py, result)?, self.display_config.clone()); Ok(df) } @@ -976,12 +1045,12 @@ impl PySessionContext { let read_future = self.ctx.read_avro(path, options); wait_for_future(py, read_future)? 
}; - Ok(PyDataFrame::new(df)) + Ok(PyDataFrame::new(df, self.display_config.clone())) } pub fn read_table(&self, table: &PyTable) -> PyDataFusionResult { let df = self.ctx.read_table(table.table())?; - Ok(PyDataFrame::new(df)) + Ok(PyDataFrame::new(df, self.display_config.clone())) } fn __repr__(&self) -> PyResult { @@ -1097,6 +1166,9 @@ impl From for SessionContext { impl From for PySessionContext { fn from(ctx: SessionContext) -> PySessionContext { - PySessionContext { ctx } + PySessionContext { + ctx, + display_config: PyDataframeDisplayConfig::default(), + } } } diff --git a/src/dataframe.rs b/src/dataframe.rs index be10b8c28..1fb925347 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -43,6 +43,7 @@ use pyo3::types::{PyCapsule, PyTuple, PyTupleMethods}; use tokio::task::JoinHandle; use crate::catalog::PyTable; +use crate::context::PyDataframeDisplayConfig; use crate::errors::{py_datafusion_err, PyDataFusionError}; use crate::expr::sort_expr::to_sort_expressions; use crate::physical_plan::PyExecutionPlan; @@ -72,9 +73,6 @@ impl PyTableProvider { PyTable::new(table_provider) } } -const MAX_TABLE_BYTES_TO_DISPLAY: usize = 2 * 1024 * 1024; // 2 MB -const MIN_TABLE_ROWS_TO_DISPLAY: usize = 20; -const MAX_LENGTH_CELL_WITHOUT_MINIMIZE: usize = 25; /// A PyDataFrame is a representation of a logical plan and an API to compose statements. /// Use it to build a plan and `.collect()` to execute the plan and collect the result. 
@@ -83,12 +81,16 @@ const MAX_LENGTH_CELL_WITHOUT_MINIMIZE: usize = 25; #[derive(Clone)] pub struct PyDataFrame { df: Arc, + display_config: Arc, } impl PyDataFrame { /// creates a new PyDataFrame - pub fn new(df: DataFrame) -> Self { - Self { df: Arc::new(df) } + pub fn new(df: DataFrame, display_config: PyDataframeDisplayConfig) -> Self { + Self { + df: Arc::new(df), + display_config: Arc::new(display_config), + } } } @@ -116,10 +118,17 @@ impl PyDataFrame { } fn __repr__(&self, py: Python) -> PyDataFusionResult { + // Collect record batches for display let (batches, has_more) = wait_for_future( py, - collect_record_batches_to_display(self.df.as_ref().clone(), 10, 10), + collect_record_batches_to_display( + self.df.as_ref().clone(), + self.display_config.min_table_rows, + self.display_config.max_table_rows_in_repr, + self.display_config.max_table_bytes, + ), )?; + if batches.is_empty() { // This should not be reached, but do it for safety since we index into the vector below return Ok("No data to display".to_string()); @@ -141,8 +150,9 @@ impl PyDataFrame { py, collect_record_batches_to_display( self.df.as_ref().clone(), - MIN_TABLE_ROWS_TO_DISPLAY, - usize::MAX, + self.display_config.min_table_rows, + self.display_config.max_table_rows_in_repr, + self.display_config.max_table_bytes, ), )?; if batches.is_empty() { @@ -218,8 +228,8 @@ impl PyDataFrame { for (col, formatter) in batch_formatter.iter().enumerate() { let cell_data = formatter.value(batch_row).to_string(); // From testing, primitive data types do not typically get larger than 21 characters - if cell_data.len() > MAX_LENGTH_CELL_WITHOUT_MINIMIZE { - let short_cell_data = &cell_data[0..MAX_LENGTH_CELL_WITHOUT_MINIMIZE]; + if cell_data.len() > self.display_config.max_cell_length { + let short_cell_data = &cell_data[0..self.display_config.max_cell_length]; cells.push(format!("
@@ -269,7 +279,7 @@ impl PyDataFrame { fn describe(&self, py: Python) -> PyDataFusionResult { let df = self.df.as_ref().clone(); let stat_df = wait_for_future(py, df.describe())?; - Ok(Self::new(stat_df)) + Ok(Self::new(stat_df, self.display_config.as_ref().clone())) } /// Returns the schema from the logical plan @@ -299,31 +309,31 @@ impl PyDataFrame { fn select_columns(&self, args: Vec) -> PyDataFusionResult { let args = args.iter().map(|s| s.as_ref()).collect::>(); let df = self.df.as_ref().clone().select_columns(&args)?; - Ok(Self::new(df)) + Ok(Self::new(df, self.display_config.as_ref().clone())) } #[pyo3(signature = (*args))] fn select(&self, args: Vec) -> PyDataFusionResult { let expr = args.into_iter().map(|e| e.into()).collect(); let df = self.df.as_ref().clone().select(expr)?; - Ok(Self::new(df)) + Ok(Self::new(df, self.display_config.as_ref().clone())) } #[pyo3(signature = (*args))] fn drop(&self, args: Vec) -> PyDataFusionResult { let cols = args.iter().map(|s| s.as_ref()).collect::>(); let df = self.df.as_ref().clone().drop_columns(&cols)?; - Ok(Self::new(df)) + Ok(Self::new(df, self.display_config.as_ref().clone())) } fn filter(&self, predicate: PyExpr) -> PyDataFusionResult { let df = self.df.as_ref().clone().filter(predicate.into())?; - Ok(Self::new(df)) + Ok(Self::new(df, self.display_config.as_ref().clone())) } fn with_column(&self, name: &str, expr: PyExpr) -> PyDataFusionResult { let df = self.df.as_ref().clone().with_column(name, expr.into())?; - Ok(Self::new(df)) + Ok(Self::new(df, self.display_config.as_ref().clone())) } fn with_columns(&self, exprs: Vec) -> PyDataFusionResult { @@ -333,7 +343,7 @@ impl PyDataFrame { let name = format!("{}", expr.schema_name()); df = df.with_column(name.as_str(), expr)? } - Ok(Self::new(df)) + Ok(Self::new(df, self.display_config.as_ref().clone())) } /// Rename one column by applying a new projection. 
This is a no-op if the column to be @@ -344,27 +354,27 @@ impl PyDataFrame { .as_ref() .clone() .with_column_renamed(old_name, new_name)?; - Ok(Self::new(df)) + Ok(Self::new(df, self.display_config.as_ref().clone())) } fn aggregate(&self, group_by: Vec, aggs: Vec) -> PyDataFusionResult { let group_by = group_by.into_iter().map(|e| e.into()).collect(); let aggs = aggs.into_iter().map(|e| e.into()).collect(); let df = self.df.as_ref().clone().aggregate(group_by, aggs)?; - Ok(Self::new(df)) + Ok(Self::new(df, self.display_config.as_ref().clone())) } #[pyo3(signature = (*exprs))] fn sort(&self, exprs: Vec) -> PyDataFusionResult { let exprs = to_sort_expressions(exprs); let df = self.df.as_ref().clone().sort(exprs)?; - Ok(Self::new(df)) + Ok(Self::new(df, self.display_config.as_ref().clone())) } #[pyo3(signature = (count, offset=0))] fn limit(&self, count: usize, offset: usize) -> PyDataFusionResult { let df = self.df.as_ref().clone().limit(offset, Some(count))?; - Ok(Self::new(df)) + Ok(Self::new(df, self.display_config.as_ref().clone())) } /// Executes the plan, returning a list of `RecordBatch`es. @@ -381,7 +391,7 @@ impl PyDataFrame { /// Cache DataFrame. 
fn cache(&self, py: Python) -> PyDataFusionResult { let df = wait_for_future(py, self.df.as_ref().clone().cache())?; - Ok(Self::new(df)) + Ok(Self::new(df, self.display_config.as_ref().clone())) } /// Executes this DataFrame and collects all results into a vector of vector of RecordBatch @@ -406,7 +416,7 @@ impl PyDataFrame { /// Filter out duplicate rows fn distinct(&self) -> PyDataFusionResult { let df = self.df.as_ref().clone().distinct()?; - Ok(Self::new(df)) + Ok(Self::new(df, self.display_config.as_ref().clone())) } fn join( @@ -440,7 +450,7 @@ impl PyDataFrame { &right_keys, None, )?; - Ok(Self::new(df)) + Ok(Self::new(df, self.display_config.as_ref().clone())) } fn join_on( @@ -469,7 +479,7 @@ impl PyDataFrame { .as_ref() .clone() .join_on(right.df.as_ref().clone(), join_type, exprs)?; - Ok(Self::new(df)) + Ok(Self::new(df, self.display_config.as_ref().clone())) } /// Print the query plan @@ -502,7 +512,7 @@ impl PyDataFrame { .as_ref() .clone() .repartition(Partitioning::RoundRobinBatch(num))?; - Ok(Self::new(new_df)) + Ok(Self::new(new_df, self.display_config.as_ref().clone())) } /// Repartition a `DataFrame` based on a logical partitioning scheme. @@ -514,7 +524,7 @@ impl PyDataFrame { .as_ref() .clone() .repartition(Partitioning::Hash(expr, num))?; - Ok(Self::new(new_df)) + Ok(Self::new(new_df, self.display_config.as_ref().clone())) } /// Calculate the union of two `DataFrame`s, preserving duplicate rows.The @@ -530,7 +540,7 @@ impl PyDataFrame { self.df.as_ref().clone().union(py_df.df.as_ref().clone())? }; - Ok(Self::new(new_df)) + Ok(Self::new(new_df, self.display_config.as_ref().clone())) } /// Calculate the distinct union of two `DataFrame`s. 
The @@ -541,7 +551,7 @@ impl PyDataFrame { .as_ref() .clone() .union_distinct(py_df.df.as_ref().clone())?; - Ok(Self::new(new_df)) + Ok(Self::new(new_df, self.display_config.as_ref().clone())) } #[pyo3(signature = (column, preserve_nulls=true))] @@ -554,7 +564,7 @@ impl PyDataFrame { .as_ref() .clone() .unnest_columns_with_options(&[column], unnest_options)?; - Ok(Self::new(df)) + Ok(Self::new(df, self.display_config.as_ref().clone())) } #[pyo3(signature = (columns, preserve_nulls=true))] @@ -572,7 +582,7 @@ impl PyDataFrame { .as_ref() .clone() .unnest_columns_with_options(&cols, unnest_options)?; - Ok(Self::new(df)) + Ok(Self::new(df, self.display_config.as_ref().clone())) } /// Calculate the intersection of two `DataFrame`s. The two `DataFrame`s must have exactly the same schema @@ -582,13 +592,13 @@ impl PyDataFrame { .as_ref() .clone() .intersect(py_df.df.as_ref().clone())?; - Ok(Self::new(new_df)) + Ok(Self::new(new_df, self.display_config.as_ref().clone())) } /// Calculate the exception of two `DataFrame`s. The two `DataFrame`s must have exactly the same schema fn except_all(&self, py_df: PyDataFrame) -> PyDataFusionResult { let new_df = self.df.as_ref().clone().except(py_df.df.as_ref().clone())?; - Ok(Self::new(new_df)) + Ok(Self::new(new_df, self.display_config.as_ref().clone())) } /// Write a `DataFrame` to a CSV file. 
@@ -835,7 +845,7 @@ fn record_batch_into_schema( ) -> Result { let schema = Arc::new(schema.clone()); let base_schema = record_batch.schema(); - if base_schema.fields().len() == 0 { + if base_schema.fields().is_empty() { // Nothing to project return Ok(RecordBatch::new_empty(schema)); } @@ -886,6 +896,7 @@ async fn collect_record_batches_to_display( df: DataFrame, min_rows: usize, max_rows: usize, + max_bytes: usize, ) -> Result<(Vec, bool), DataFusionError> { let partitioned_stream = df.execute_stream_partitioned().await?; let mut stream = futures::stream::iter(partitioned_stream).flatten(); @@ -894,9 +905,7 @@ async fn collect_record_batches_to_display( let mut record_batches = Vec::default(); let mut has_more = false; - while (size_estimate_so_far < MAX_TABLE_BYTES_TO_DISPLAY && rows_so_far < max_rows) - || rows_so_far < min_rows - { + while (size_estimate_so_far < max_bytes && rows_so_far < max_rows) || rows_so_far < min_rows { let mut rb = match stream.next().await { None => { break; @@ -909,8 +918,8 @@ async fn collect_record_batches_to_display( if rows_in_rb > 0 { size_estimate_so_far += rb.get_array_memory_size(); - if size_estimate_so_far > MAX_TABLE_BYTES_TO_DISPLAY { - let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / size_estimate_so_far as f32; + if size_estimate_so_far > max_bytes { + let ratio = max_bytes as f32 / size_estimate_so_far as f32; let total_rows = rows_in_rb + rows_so_far; let mut reduced_row_num = (total_rows as f32 * ratio).round() as usize; diff --git a/src/lib.rs b/src/lib.rs index 6eeda0878..d31e650f2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -83,6 +83,7 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?;