diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py
index ecf5545bc..728b9c390 100644
--- a/python/datafusion/__init__.py
+++ b/python/datafusion/__init__.py
@@ -31,20 +31,16 @@
# The following imports are okay to remain as opaque to the user.
from ._internal import Config
from .catalog import Catalog, Database, Table
-from .common import (
- DFSchema,
-)
+from .common import DFSchema
from .context import (
+ DataframeDisplayConfig,
RuntimeEnvBuilder,
SessionConfig,
SessionContext,
SQLOptions,
)
from .dataframe import DataFrame
-from .expr import (
- Expr,
- WindowFrame,
-)
+from .expr import Expr, WindowFrame
from .io import read_avro, read_csv, read_json, read_parquet
from .plan import ExecutionPlan, LogicalPlan
from .record_batch import RecordBatch, RecordBatchStream
@@ -60,6 +56,7 @@
"DFSchema",
"DataFrame",
"Database",
+ "DataframeDisplayConfig",
"ExecutionPlan",
"Expr",
"LogicalPlan",
diff --git a/python/datafusion/context.py b/python/datafusion/context.py
index 1429a4975..c579a054b 100644
--- a/python/datafusion/context.py
+++ b/python/datafusion/context.py
@@ -19,7 +19,7 @@
from __future__ import annotations
-from typing import TYPE_CHECKING, Any, Protocol
+from typing import TYPE_CHECKING, Any, Optional, Protocol
try:
from warnings import deprecated # Python 3.13+
@@ -32,6 +32,7 @@
from datafusion.record_batch import RecordBatchStream
from datafusion.udf import AggregateUDF, ScalarUDF, WindowUDF
+from ._internal import DataframeDisplayConfig as DataframeDisplayConfigInternal
from ._internal import RuntimeEnvBuilder as RuntimeEnvBuilderInternal
from ._internal import SessionConfig as SessionConfigInternal
from ._internal import SessionContext as SessionContextInternal
@@ -78,6 +79,106 @@ class TableProviderExportable(Protocol):
def __datafusion_table_provider__(self) -> object: ... # noqa: D105
+class DataframeDisplayConfig:
+ """Configuration for displaying DataFrame results.
+
+ This class allows you to control how DataFrames are displayed in Python.
+ """
+
+ def __init__(
+ self,
+ max_table_bytes: Optional[int] = None,
+ min_table_rows: Optional[int] = None,
+ max_cell_length: Optional[int] = None,
+ max_table_rows_in_repr: Optional[int] = None,
+ ) -> None:
+ """Create a new :py:class:`DataframeDisplayConfig` instance.
+
+ Args:
+ max_table_bytes: Maximum bytes to display for table presentation
+ (default: 2MB)
+ min_table_rows: Minimum number of table rows to display
+ (default: 20)
+ max_cell_length: Maximum length of a cell before it gets minimized
+ (default: 25)
+ max_table_rows_in_repr: Maximum number of rows to display in repr
+ string output (default: 10)
+ """
+ # Validate values if they are not None
+ if max_table_bytes is not None:
+ self._validate_positive(max_table_bytes, "max_table_bytes")
+ if min_table_rows is not None:
+ self._validate_positive(min_table_rows, "min_table_rows")
+ if max_cell_length is not None:
+ self._validate_positive(max_cell_length, "max_cell_length")
+ if max_table_rows_in_repr is not None:
+ self._validate_positive(max_table_rows_in_repr, "max_table_rows_in_repr")
+ self.config_internal = DataframeDisplayConfigInternal(
+ max_table_bytes=max_table_bytes,
+ min_table_rows=min_table_rows,
+ max_cell_length=max_cell_length,
+ max_table_rows_in_repr=max_table_rows_in_repr,
+ )
+
+ def _validate_positive(self, value: int, name: str) -> None:
+ """Validate that the given value is positive.
+
+ Args:
+ value: The value to validate
+ name: The name of the parameter for the error message
+
+ Raises:
+ ValueError: If the value is not positive
+ """
+ if value <= 0:
+ error_message = f"{name} must be greater than 0"
+ raise ValueError(error_message)
+
+ @property
+ def max_table_bytes(self) -> int:
+ """Get the maximum bytes to display for table presentation."""
+ return self.config_internal.max_table_bytes
+
+ @max_table_bytes.setter
+ def max_table_bytes(self, value: int) -> None:
+ """Set the maximum bytes to display for table presentation."""
+ self._validate_positive(value, "max_table_bytes")
+ self.config_internal.max_table_bytes = value
+
+ @property
+ def min_table_rows(self) -> int:
+ """Get the minimum number of table rows to display."""
+ return self.config_internal.min_table_rows
+
+ @min_table_rows.setter
+ def min_table_rows(self, value: int) -> None:
+ """Set the minimum number of table rows to display."""
+ self._validate_positive(value, "min_table_rows")
+ self.config_internal.min_table_rows = value
+
+ @property
+ def max_cell_length(self) -> int:
+ """Get the maximum length of a cell before it gets minimized."""
+ return self.config_internal.max_cell_length
+
+ @max_cell_length.setter
+ def max_cell_length(self, value: int) -> None:
+ """Set the maximum length of a cell before it gets minimized."""
+ self._validate_positive(value, "max_cell_length")
+ self.config_internal.max_cell_length = value
+
+ @property
+ def max_table_rows_in_repr(self) -> int:
+ """Get the maximum number of rows to display in repr string output."""
+ return self.config_internal.max_table_rows_in_repr
+
+ @max_table_rows_in_repr.setter
+ def max_table_rows_in_repr(self, value: int) -> None:
+ """Set the maximum number of rows to display in repr string output."""
+ self._validate_positive(value, "max_table_rows_in_repr")
+ self.config_internal.max_table_rows_in_repr = value
+
+
class SessionConfig:
"""Session configuration options."""
@@ -470,6 +571,7 @@ def __init__(
self,
config: SessionConfig | None = None,
runtime: RuntimeEnvBuilder | None = None,
+ display_config: DataframeDisplayConfig | None = None,
) -> None:
"""Main interface for executing queries with DataFusion.
@@ -480,7 +582,7 @@ def __init__(
Args:
config: Session configuration options.
runtime: Runtime configuration options.
-
+ display_config: DataFrame display configuration options.
Example usage:
The following example demonstrates how to use the context to execute
@@ -493,8 +595,10 @@ def __init__(
"""
config = config.config_internal if config is not None else None
runtime = runtime.config_internal if runtime is not None else None
-
- self.ctx = SessionContextInternal(config, runtime)
+ display_config = (
+ display_config.config_internal if display_config is not None else None
+ )
+ self.ctx = SessionContextInternal(config, runtime, display_config)
@classmethod
def global_ctx(cls) -> SessionContext:
@@ -508,6 +612,40 @@ def global_ctx(cls) -> SessionContext:
wrapper.ctx = internal_ctx
return wrapper
+ def with_display_config(
+ self,
+ max_table_bytes: Optional[int] = None,
+ min_table_rows: Optional[int] = None,
+ max_cell_length: Optional[int] = None,
+ max_table_rows_in_repr: Optional[int] = None,
+ ) -> SessionContext:
+ """Configure the display options for DataFrames.
+
+ Args:
+ max_table_bytes: Maximum bytes to display for table presentation
+ (default: 2MB)
+ min_table_rows: Minimum number of table rows to display
+ (default: 20)
+ max_cell_length: Maximum length of a cell before it gets minimized
+ (default: 25)
+ max_table_rows_in_repr: Maximum number of rows to display in repr
+ string output (default: 10)
+
+ Returns:
+ A new :py:class:`SessionContext` object with the updated display settings.
+ """
+ display_config = DataframeDisplayConfig(
+ max_table_bytes=max_table_bytes,
+ min_table_rows=min_table_rows,
+ max_cell_length=max_cell_length,
+ max_table_rows_in_repr=max_table_rows_in_repr,
+ )
+
+ klass = self.__class__
+ obj = klass.__new__(klass)
+ obj.ctx = self.ctx.with_display_config(display_config.config_internal)
+ return obj
+
def enable_url_table(self) -> SessionContext:
"""Control if local files can be queried as tables.
@@ -806,9 +944,11 @@ def register_parquet(
file_extension,
skip_metadata,
schema,
- [sort_list_to_raw_sort_list(exprs) for exprs in file_sort_order]
- if file_sort_order is not None
- else None,
+ (
+ [sort_list_to_raw_sort_list(exprs) for exprs in file_sort_order]
+ if file_sort_order is not None
+ else None
+ ),
)
def register_csv(
diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py
index eda13930d..e10613d9b 100644
--- a/python/tests/test_dataframe.py
+++ b/python/tests/test_dataframe.py
@@ -29,6 +29,7 @@
literal,
)
from datafusion import functions as f
+from datafusion.context import DataframeDisplayConfig
from datafusion.expr import Window
from pyarrow.csv import write_csv
@@ -51,6 +52,134 @@ def df():
return ctx.from_arrow(batch)
+@pytest.fixture
+def data():
+ return [{"a": 1, "b": "x" * 50, "c": 3}] * 100
+
+
+@pytest.fixture
+def span_expandable_class():
+ return '{('x' * 10)}" in html_repr2
+
+
+def test_display_config_in_init(data):
+ # Test display config directly in SessionContext constructor
+ display_config = DataframeDisplayConfig(
+ max_table_bytes=1024,
+ min_table_rows=5,
+ max_cell_length=10,
+ max_table_rows_in_repr=3,
+ )
+
+ ctx = SessionContext(display_config=display_config)
+ df1 = ctx.from_pylist(data)
+ html_repr1 = df1._repr_html_()
+
+ # Create a context with custom display config through the with_display_config method
+ ctx2 = ctx.with_display_config(
+ max_table_bytes=1024,
+ min_table_rows=5,
+ max_cell_length=10,
+ max_table_rows_in_repr=3,
+ )
+ df2 = ctx2.from_pylist(data)
+ html_repr2 = df2._repr_html_()
+
+ # Both methods should result in equivalent display configuration
+ assert normalize_uuid(html_repr1) == normalize_uuid(html_repr2)
+
+
@pytest.fixture
def struct_df():
ctx = SessionContext()
@@ -1261,3 +1390,130 @@ def test_dataframe_repr_html(df) -> None:
body_lines = [f"
{v} | " for inner in body_data for v in inner]
body_pattern = "(.*?)".join(body_lines)
assert len(re.findall(body_pattern, output, re.DOTALL)) == 1
+
+
+def test_display_config_affects_repr(data):
+ max_table_rows_in_repr = 3
+ # Create a context with custom display config
+ ctx = SessionContext().with_display_config(
+ max_table_rows_in_repr=max_table_rows_in_repr
+ )
+
+ # Create a DataFrame with more rows than the display limit
+ df = ctx.from_pylist(data)
+
+ repr_str = repr(df)
+
+ # The representation should show truncated data (3 rows as specified)
+ assert (
+ # 5 = 1 header row + 3 separator line + 1 truncation message
+ repr_str.count("\n") <= max_table_rows_in_repr + 5
+ )
+ assert "Data truncated" in repr_str
+
+ # Create a context with larger display limit
+ max_table_rows_in_repr = 100
+ ctx2 = SessionContext().with_display_config(
+ max_table_rows_in_repr=max_table_rows_in_repr
+ )
+
+ df2 = ctx2.from_pylist(data)
+ repr_str2 = repr(df2)
+
+ # Should show all data without truncation message
+ assert (
+ # 4 = 1 header row + 3 separator lines
+ repr_str2.count("\n") == max_table_rows_in_repr + 4
+ ) # All rows should be shown
+ assert "Data truncated" not in repr_str2
+
+
+def test_display_config_affects_html_repr(data, span_expandable_class):
+ # Create a context with custom display config to show only a small cell length
+ ctx = SessionContext().with_display_config(max_cell_length=5)
+
+ # Create a DataFrame with a column containing long strings
+ df = ctx.from_pylist(data)
+
+ # Get the HTML representation
+ html_str = df._repr_html_()
+
+ # The cell should be truncated to 5 characters and have expansion button
+ assert ">xxxxx" in html_str # 5 character limit
+ assert span_expandable_class in html_str
+
+ # Create a context with larger cell length limit
+ ctx2 = SessionContext().with_display_config(max_cell_length=60)
+
+ df2 = ctx2.from_pylist(data)
+ html_str2 = df2._repr_html_()
+
+ # String shouldn't be truncated (or at least not in the same way)
+ assert span_expandable_class not in html_str2
+
+
+def test_display_config_rows_limit_in_html(data):
+ max_table_rows = 5
+ # Create a context with custom display config to limit rows
+ ctx = SessionContext().with_display_config(
+ max_table_rows_in_repr=max_table_rows,
+ )
+
+ # Create a DataFrame with 10 rows
+ df = ctx.from_pylist(data)
+
+ # Get the HTML representation
+ html_str = df._repr_html_()
+
+ # Only a few rows should be shown and there should be a truncation message
+ row_count = html_str.count("") - 1 # Subtract 1 for header row
+ assert row_count <= max_table_rows
+ assert "Data truncated" in html_str
+
+ # Create a context with larger row limit
+ max_table_rows = 100
+ ctx2 = SessionContext().with_display_config(
+ max_table_rows_in_repr=max_table_rows
+ ) # Show more rows
+
+ df2 = ctx2.from_pylist(data)
+ html_str2 = df2._repr_html_()
+
+ # Should show all rows
+ row_count2 = html_str2.count("
") - 1 # Subtract 1 for header row
+ assert row_count2 == max_table_rows
+ assert "Data truncated" not in html_str2
+
+
+def test_display_config_max_bytes_limit(data):
+ min_table_rows = 10
+ max_table_rows = 20
+ # Create a context with custom display config with very small byte limit
+ ctx = SessionContext().with_display_config(
+ min_table_rows=min_table_rows,
+ max_table_rows_in_repr=max_table_rows,
+ max_table_bytes=100,
+ ) # Very small limit
+
+ # Create a DataFrame with large content
+ df = ctx.from_pylist(data)
+
+ # Get the HTML representation
+ html_str = df._repr_html_()
+
+ # Due to small byte limit, we should see truncation
+ row_count = html_str.count("
") - 1 # Subtract 1 for header row
+ assert row_count <= min_table_rows # Should not show all 10 rows
+ assert "Data truncated" in html_str
+
+ # With a larger byte limit
+ ctx2 = SessionContext().with_display_config(
+ max_table_bytes=10 * 1024 * 1024 # 10 MB, much more than needed
+ )
+
+ df2 = ctx2.from_pylist(data)
+ html_str2 = df2._repr_html_()
+
+ # Should show all rows
+ row_count2 = html_str2.count("
") - 1 # Subtract 1 for header row
+ assert row_count2 >= min_table_rows # Should show more than min_table_rows
diff --git a/src/context.rs b/src/context.rs
index 0db0f4d7e..6147cceff 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -72,6 +72,56 @@ use datafusion_ffi::table_provider::{FFI_TableProvider, ForeignTableProvider};
use pyo3::types::{PyCapsule, PyDict, PyList, PyTuple, PyType};
use tokio::task::JoinHandle;
+/// Configuration for displaying DataFrames
+#[pyclass(name = "DataframeDisplayConfig", module = "datafusion", subclass)]
+#[derive(Clone)]
+pub struct PyDataframeDisplayConfig {
+ /// Maximum bytes to display for table presentation (default: 2MB)
+ #[pyo3(get, set)]
+ pub max_table_bytes: usize,
+ /// Minimum number of table rows to display (default: 20)
+ #[pyo3(get, set)]
+ pub min_table_rows: usize,
+ /// Maximum length of a cell before it gets minimized (default: 25)
+ #[pyo3(get, set)]
+ pub max_cell_length: usize,
+ /// Maximum number of rows to display in repr string output (default: 10)
+ #[pyo3(get, set)]
+ pub max_table_rows_in_repr: usize,
+}
+
+#[pymethods]
+impl PyDataframeDisplayConfig {
+ #[new]
+ #[pyo3(signature = (max_table_bytes=None, min_table_rows=None, max_cell_length=None, max_table_rows_in_repr=None))]
+ fn new(
+ max_table_bytes: Option,
+ min_table_rows: Option,
+ max_cell_length: Option,
+ max_table_rows_in_repr: Option,
+ ) -> Self {
+ let default = Self::default();
+ Self {
+ max_table_bytes: max_table_bytes.unwrap_or(default.max_table_bytes),
+ min_table_rows: min_table_rows.unwrap_or(default.min_table_rows),
+ max_cell_length: max_cell_length.unwrap_or(default.max_cell_length),
+ max_table_rows_in_repr: max_table_rows_in_repr
+ .unwrap_or(default.max_table_rows_in_repr),
+ }
+ }
+}
+
+impl Default for PyDataframeDisplayConfig {
+ fn default() -> Self {
+ Self {
+ max_table_bytes: 2 * 1024 * 1024, // 2 MB
+ min_table_rows: 20,
+ max_cell_length: 25,
+ max_table_rows_in_repr: 10,
+ }
+ }
+}
+
/// Configuration options for a SessionContext
#[pyclass(name = "SessionConfig", module = "datafusion", subclass)]
#[derive(Clone, Default)]
@@ -269,15 +319,17 @@ impl PySQLOptions {
#[derive(Clone)]
pub struct PySessionContext {
pub ctx: SessionContext,
+ pub display_config: PyDataframeDisplayConfig,
}
#[pymethods]
impl PySessionContext {
- #[pyo3(signature = (config=None, runtime=None))]
+ #[pyo3(signature = (config=None, runtime=None, display_config=None))]
#[new]
pub fn new(
config: Option,
runtime: Option,
+ display_config: Option,
) -> PyDataFusionResult {
let config = if let Some(c) = config {
c.config
@@ -295,22 +347,33 @@ impl PySessionContext {
.with_runtime_env(runtime)
.with_default_features()
.build();
+
Ok(PySessionContext {
ctx: SessionContext::new_with_state(session_state),
+ display_config: display_config.unwrap_or_default(),
})
}
pub fn enable_url_table(&self) -> PyResult {
Ok(PySessionContext {
ctx: self.ctx.clone().enable_url_table(),
+ display_config: self.display_config.clone(),
})
}
+ pub fn with_display_config(&self, display_config: PyDataframeDisplayConfig) -> Self {
+ Self {
+ ctx: self.ctx.clone(),
+ display_config,
+ }
+ }
+
#[classmethod]
#[pyo3(signature = ())]
fn global_ctx(_cls: &Bound<'_, PyType>) -> PyResult {
Ok(Self {
ctx: get_global_ctx().clone(),
+ display_config: PyDataframeDisplayConfig::default(),
})
}
@@ -394,7 +457,7 @@ impl PySessionContext {
pub fn sql(&mut self, query: &str, py: Python) -> PyDataFusionResult {
let result = self.ctx.sql(query);
let df = wait_for_future(py, result)?;
- Ok(PyDataFrame::new(df))
+ Ok(PyDataFrame::new(df, self.display_config.clone()))
}
#[pyo3(signature = (query, options=None))]
@@ -411,7 +474,7 @@ impl PySessionContext {
};
let result = self.ctx.sql_with_options(query, options);
let df = wait_for_future(py, result)?;
- Ok(PyDataFrame::new(df))
+ Ok(PyDataFrame::new(df, self.display_config.clone()))
}
#[pyo3(signature = (partitions, name=None, schema=None))]
@@ -446,13 +509,16 @@ impl PySessionContext {
let table = wait_for_future(py, self._table(&table_name))?;
- let df = PyDataFrame::new(table);
+ let df = PyDataFrame::new(table, self.display_config.clone());
Ok(df)
}
/// Create a DataFrame from an existing logical plan
pub fn create_dataframe_from_logical_plan(&mut self, plan: PyLogicalPlan) -> PyDataFrame {
- PyDataFrame::new(DataFrame::new(self.ctx.state(), plan.plan.as_ref().clone()))
+ PyDataFrame::new(
+ DataFrame::new(self.ctx.state(), plan.plan.as_ref().clone()),
+ self.display_config.clone(),
+ )
}
/// Construct datafusion dataframe from Python list
@@ -820,7 +886,7 @@ impl PySessionContext {
pub fn table(&self, name: &str, py: Python) -> PyResult {
let x = wait_for_future(py, self.ctx.table(name))
.map_err(|e| PyKeyError::new_err(e.to_string()))?;
- Ok(PyDataFrame::new(x))
+ Ok(PyDataFrame::new(x, self.display_config.clone()))
}
pub fn table_exist(&self, name: &str) -> PyDataFusionResult {
@@ -828,7 +894,10 @@ impl PySessionContext {
}
pub fn empty_table(&self) -> PyDataFusionResult {
- Ok(PyDataFrame::new(self.ctx.read_empty()?))
+ Ok(PyDataFrame::new(
+ self.ctx.read_empty()?,
+ self.display_config.clone(),
+ ))
}
pub fn session_id(&self) -> String {
@@ -863,7 +932,7 @@ impl PySessionContext {
let result = self.ctx.read_json(path, options);
wait_for_future(py, result)?
};
- Ok(PyDataFrame::new(df))
+ Ok(PyDataFrame::new(df, self.display_config.clone()))
}
#[allow(clippy::too_many_arguments)]
@@ -908,12 +977,12 @@ impl PySessionContext {
let paths = path.extract::>()?;
let paths = paths.iter().map(|p| p as &str).collect::>();
let result = self.ctx.read_csv(paths, options);
- let df = PyDataFrame::new(wait_for_future(py, result)?);
+ let df = PyDataFrame::new(wait_for_future(py, result)?, self.display_config.clone());
Ok(df)
} else {
let path = path.extract::()?;
let result = self.ctx.read_csv(path, options);
- let df = PyDataFrame::new(wait_for_future(py, result)?);
+ let df = PyDataFrame::new(wait_for_future(py, result)?, self.display_config.clone());
Ok(df)
}
}
@@ -951,7 +1020,7 @@ impl PySessionContext {
.collect();
let result = self.ctx.read_parquet(path, options);
- let df = PyDataFrame::new(wait_for_future(py, result)?);
+ let df = PyDataFrame::new(wait_for_future(py, result)?, self.display_config.clone());
Ok(df)
}
@@ -976,12 +1045,12 @@ impl PySessionContext {
let read_future = self.ctx.read_avro(path, options);
wait_for_future(py, read_future)?
};
- Ok(PyDataFrame::new(df))
+ Ok(PyDataFrame::new(df, self.display_config.clone()))
}
pub fn read_table(&self, table: &PyTable) -> PyDataFusionResult {
let df = self.ctx.read_table(table.table())?;
- Ok(PyDataFrame::new(df))
+ Ok(PyDataFrame::new(df, self.display_config.clone()))
}
fn __repr__(&self) -> PyResult {
@@ -1097,6 +1166,9 @@ impl From for SessionContext {
impl From for PySessionContext {
fn from(ctx: SessionContext) -> PySessionContext {
- PySessionContext { ctx }
+ PySessionContext {
+ ctx,
+ display_config: PyDataframeDisplayConfig::default(),
+ }
}
}
diff --git a/src/dataframe.rs b/src/dataframe.rs
index be10b8c28..1fb925347 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -43,6 +43,7 @@ use pyo3::types::{PyCapsule, PyTuple, PyTupleMethods};
use tokio::task::JoinHandle;
use crate::catalog::PyTable;
+use crate::context::PyDataframeDisplayConfig;
use crate::errors::{py_datafusion_err, PyDataFusionError};
use crate::expr::sort_expr::to_sort_expressions;
use crate::physical_plan::PyExecutionPlan;
@@ -72,9 +73,6 @@ impl PyTableProvider {
PyTable::new(table_provider)
}
}
-const MAX_TABLE_BYTES_TO_DISPLAY: usize = 2 * 1024 * 1024; // 2 MB
-const MIN_TABLE_ROWS_TO_DISPLAY: usize = 20;
-const MAX_LENGTH_CELL_WITHOUT_MINIMIZE: usize = 25;
/// A PyDataFrame is a representation of a logical plan and an API to compose statements.
/// Use it to build a plan and `.collect()` to execute the plan and collect the result.
@@ -83,12 +81,16 @@ const MAX_LENGTH_CELL_WITHOUT_MINIMIZE: usize = 25;
#[derive(Clone)]
pub struct PyDataFrame {
df: Arc,
+ display_config: Arc,
}
impl PyDataFrame {
/// creates a new PyDataFrame
- pub fn new(df: DataFrame) -> Self {
- Self { df: Arc::new(df) }
+ pub fn new(df: DataFrame, display_config: PyDataframeDisplayConfig) -> Self {
+ Self {
+ df: Arc::new(df),
+ display_config: Arc::new(display_config),
+ }
}
}
@@ -116,10 +118,17 @@ impl PyDataFrame {
}
fn __repr__(&self, py: Python) -> PyDataFusionResult {
+ // Collect record batches for display
let (batches, has_more) = wait_for_future(
py,
- collect_record_batches_to_display(self.df.as_ref().clone(), 10, 10),
+ collect_record_batches_to_display(
+ self.df.as_ref().clone(),
+ self.display_config.min_table_rows,
+ self.display_config.max_table_rows_in_repr,
+ self.display_config.max_table_bytes,
+ ),
)?;
+
if batches.is_empty() {
// This should not be reached, but do it for safety since we index into the vector below
return Ok("No data to display".to_string());
@@ -141,8 +150,9 @@ impl PyDataFrame {
py,
collect_record_batches_to_display(
self.df.as_ref().clone(),
- MIN_TABLE_ROWS_TO_DISPLAY,
- usize::MAX,
+ self.display_config.min_table_rows,
+ self.display_config.max_table_rows_in_repr,
+ self.display_config.max_table_bytes,
),
)?;
if batches.is_empty() {
@@ -218,8 +228,8 @@ impl PyDataFrame {
for (col, formatter) in batch_formatter.iter().enumerate() {
let cell_data = formatter.value(batch_row).to_string();
// From testing, primitive data types do not typically get larger than 21 characters
- if cell_data.len() > MAX_LENGTH_CELL_WITHOUT_MINIMIZE {
- let short_cell_data = &cell_data[0..MAX_LENGTH_CELL_WITHOUT_MINIMIZE];
+ if cell_data.len() > self.display_config.max_cell_length {
+ let short_cell_data = &cell_data[0..self.display_config.max_cell_length];
cells.push(format!("
@@ -269,7 +279,7 @@ impl PyDataFrame {
fn describe(&self, py: Python) -> PyDataFusionResult {
let df = self.df.as_ref().clone();
let stat_df = wait_for_future(py, df.describe())?;
- Ok(Self::new(stat_df))
+ Ok(Self::new(stat_df, self.display_config.as_ref().clone()))
}
/// Returns the schema from the logical plan
@@ -299,31 +309,31 @@ impl PyDataFrame {
fn select_columns(&self, args: Vec) -> PyDataFusionResult {
let args = args.iter().map(|s| s.as_ref()).collect::>();
let df = self.df.as_ref().clone().select_columns(&args)?;
- Ok(Self::new(df))
+ Ok(Self::new(df, self.display_config.as_ref().clone()))
}
#[pyo3(signature = (*args))]
fn select(&self, args: Vec) -> PyDataFusionResult {
let expr = args.into_iter().map(|e| e.into()).collect();
let df = self.df.as_ref().clone().select(expr)?;
- Ok(Self::new(df))
+ Ok(Self::new(df, self.display_config.as_ref().clone()))
}
#[pyo3(signature = (*args))]
fn drop(&self, args: Vec) -> PyDataFusionResult {
let cols = args.iter().map(|s| s.as_ref()).collect::>();
let df = self.df.as_ref().clone().drop_columns(&cols)?;
- Ok(Self::new(df))
+ Ok(Self::new(df, self.display_config.as_ref().clone()))
}
fn filter(&self, predicate: PyExpr) -> PyDataFusionResult {
let df = self.df.as_ref().clone().filter(predicate.into())?;
- Ok(Self::new(df))
+ Ok(Self::new(df, self.display_config.as_ref().clone()))
}
fn with_column(&self, name: &str, expr: PyExpr) -> PyDataFusionResult {
let df = self.df.as_ref().clone().with_column(name, expr.into())?;
- Ok(Self::new(df))
+ Ok(Self::new(df, self.display_config.as_ref().clone()))
}
fn with_columns(&self, exprs: Vec) -> PyDataFusionResult {
@@ -333,7 +343,7 @@ impl PyDataFrame {
let name = format!("{}", expr.schema_name());
df = df.with_column(name.as_str(), expr)?
}
- Ok(Self::new(df))
+ Ok(Self::new(df, self.display_config.as_ref().clone()))
}
/// Rename one column by applying a new projection. This is a no-op if the column to be
@@ -344,27 +354,27 @@ impl PyDataFrame {
.as_ref()
.clone()
.with_column_renamed(old_name, new_name)?;
- Ok(Self::new(df))
+ Ok(Self::new(df, self.display_config.as_ref().clone()))
}
fn aggregate(&self, group_by: Vec, aggs: Vec) -> PyDataFusionResult {
let group_by = group_by.into_iter().map(|e| e.into()).collect();
let aggs = aggs.into_iter().map(|e| e.into()).collect();
let df = self.df.as_ref().clone().aggregate(group_by, aggs)?;
- Ok(Self::new(df))
+ Ok(Self::new(df, self.display_config.as_ref().clone()))
}
#[pyo3(signature = (*exprs))]
fn sort(&self, exprs: Vec) -> PyDataFusionResult {
let exprs = to_sort_expressions(exprs);
let df = self.df.as_ref().clone().sort(exprs)?;
- Ok(Self::new(df))
+ Ok(Self::new(df, self.display_config.as_ref().clone()))
}
#[pyo3(signature = (count, offset=0))]
fn limit(&self, count: usize, offset: usize) -> PyDataFusionResult {
let df = self.df.as_ref().clone().limit(offset, Some(count))?;
- Ok(Self::new(df))
+ Ok(Self::new(df, self.display_config.as_ref().clone()))
}
/// Executes the plan, returning a list of `RecordBatch`es.
@@ -381,7 +391,7 @@ impl PyDataFrame {
/// Cache DataFrame.
fn cache(&self, py: Python) -> PyDataFusionResult {
let df = wait_for_future(py, self.df.as_ref().clone().cache())?;
- Ok(Self::new(df))
+ Ok(Self::new(df, self.display_config.as_ref().clone()))
}
/// Executes this DataFrame and collects all results into a vector of vector of RecordBatch
@@ -406,7 +416,7 @@ impl PyDataFrame {
/// Filter out duplicate rows
fn distinct(&self) -> PyDataFusionResult {
let df = self.df.as_ref().clone().distinct()?;
- Ok(Self::new(df))
+ Ok(Self::new(df, self.display_config.as_ref().clone()))
}
fn join(
@@ -440,7 +450,7 @@ impl PyDataFrame {
&right_keys,
None,
)?;
- Ok(Self::new(df))
+ Ok(Self::new(df, self.display_config.as_ref().clone()))
}
fn join_on(
@@ -469,7 +479,7 @@ impl PyDataFrame {
.as_ref()
.clone()
.join_on(right.df.as_ref().clone(), join_type, exprs)?;
- Ok(Self::new(df))
+ Ok(Self::new(df, self.display_config.as_ref().clone()))
}
/// Print the query plan
@@ -502,7 +512,7 @@ impl PyDataFrame {
.as_ref()
.clone()
.repartition(Partitioning::RoundRobinBatch(num))?;
- Ok(Self::new(new_df))
+ Ok(Self::new(new_df, self.display_config.as_ref().clone()))
}
/// Repartition a `DataFrame` based on a logical partitioning scheme.
@@ -514,7 +524,7 @@ impl PyDataFrame {
.as_ref()
.clone()
.repartition(Partitioning::Hash(expr, num))?;
- Ok(Self::new(new_df))
+ Ok(Self::new(new_df, self.display_config.as_ref().clone()))
}
/// Calculate the union of two `DataFrame`s, preserving duplicate rows.The
@@ -530,7 +540,7 @@ impl PyDataFrame {
self.df.as_ref().clone().union(py_df.df.as_ref().clone())?
};
- Ok(Self::new(new_df))
+ Ok(Self::new(new_df, self.display_config.as_ref().clone()))
}
/// Calculate the distinct union of two `DataFrame`s. The
@@ -541,7 +551,7 @@ impl PyDataFrame {
.as_ref()
.clone()
.union_distinct(py_df.df.as_ref().clone())?;
- Ok(Self::new(new_df))
+ Ok(Self::new(new_df, self.display_config.as_ref().clone()))
}
#[pyo3(signature = (column, preserve_nulls=true))]
@@ -554,7 +564,7 @@ impl PyDataFrame {
.as_ref()
.clone()
.unnest_columns_with_options(&[column], unnest_options)?;
- Ok(Self::new(df))
+ Ok(Self::new(df, self.display_config.as_ref().clone()))
}
#[pyo3(signature = (columns, preserve_nulls=true))]
@@ -572,7 +582,7 @@ impl PyDataFrame {
.as_ref()
.clone()
.unnest_columns_with_options(&cols, unnest_options)?;
- Ok(Self::new(df))
+ Ok(Self::new(df, self.display_config.as_ref().clone()))
}
/// Calculate the intersection of two `DataFrame`s. The two `DataFrame`s must have exactly the same schema
@@ -582,13 +592,13 @@ impl PyDataFrame {
.as_ref()
.clone()
.intersect(py_df.df.as_ref().clone())?;
- Ok(Self::new(new_df))
+ Ok(Self::new(new_df, self.display_config.as_ref().clone()))
}
/// Calculate the exception of two `DataFrame`s. The two `DataFrame`s must have exactly the same schema
fn except_all(&self, py_df: PyDataFrame) -> PyDataFusionResult {
let new_df = self.df.as_ref().clone().except(py_df.df.as_ref().clone())?;
- Ok(Self::new(new_df))
+ Ok(Self::new(new_df, self.display_config.as_ref().clone()))
}
/// Write a `DataFrame` to a CSV file.
@@ -835,7 +845,7 @@ fn record_batch_into_schema(
) -> Result {
let schema = Arc::new(schema.clone());
let base_schema = record_batch.schema();
- if base_schema.fields().len() == 0 {
+ if base_schema.fields().is_empty() {
// Nothing to project
return Ok(RecordBatch::new_empty(schema));
}
@@ -886,6 +896,7 @@ async fn collect_record_batches_to_display(
df: DataFrame,
min_rows: usize,
max_rows: usize,
+ max_bytes: usize,
) -> Result<(Vec, bool), DataFusionError> {
let partitioned_stream = df.execute_stream_partitioned().await?;
let mut stream = futures::stream::iter(partitioned_stream).flatten();
@@ -894,9 +905,7 @@ async fn collect_record_batches_to_display(
let mut record_batches = Vec::default();
let mut has_more = false;
- while (size_estimate_so_far < MAX_TABLE_BYTES_TO_DISPLAY && rows_so_far < max_rows)
- || rows_so_far < min_rows
- {
+ while (size_estimate_so_far < max_bytes && rows_so_far < max_rows) || rows_so_far < min_rows {
let mut rb = match stream.next().await {
None => {
break;
@@ -909,8 +918,8 @@ async fn collect_record_batches_to_display(
if rows_in_rb > 0 {
size_estimate_so_far += rb.get_array_memory_size();
- if size_estimate_so_far > MAX_TABLE_BYTES_TO_DISPLAY {
- let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / size_estimate_so_far as f32;
+ if size_estimate_so_far > max_bytes {
+ let ratio = max_bytes as f32 / size_estimate_so_far as f32;
let total_rows = rows_in_rb + rows_so_far;
let mut reduced_row_num = (total_rows as f32 * ratio).round() as usize;
diff --git a/src/lib.rs b/src/lib.rs
index 6eeda0878..d31e650f2 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -83,6 +83,7 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::()?;
m.add_class::()?;
m.add_class::()?;
+ m.add_class::()?;
m.add_class::()?;
m.add_class::()?;
m.add_class::()?;
|