diff --git a/src/textual/widgets/_data_table.py b/src/textual/widgets/_data_table.py index 951dc133c7..fb6d53ded5 100644 --- a/src/textual/widgets/_data_table.py +++ b/src/textual/widgets/_data_table.py @@ -1013,13 +1013,14 @@ def get_column(self, column_key: ColumnKey | str) -> Iterable[CellType]: Raises: ColumnDoesNotExist: If there is no column corresponding to the key. """ - if column_key not in self._column_locations: - raise ColumnDoesNotExist(f"Column key {column_key!r} is not valid.") - data = self._data - for row_metadata in self.ordered_rows: - row_key = row_metadata.key - yield data[row_key][column_key] + try: + for row_metadata in self.ordered_rows: + row_key = row_metadata.key + row = data[row_key] + yield row[column_key] + except KeyError: + raise ColumnDoesNotExist(f"Column key {column_key!r} is not valid.") def get_column_at(self, column_index: int) -> Iterable[CellType]: """Get the values from the column at a given index. @@ -1051,9 +1052,10 @@ def get_column_index(self, column_key: ColumnKey | str) -> int: Raises: ColumnDoesNotExist: If the column key does not exist. """ - if column_key not in self._column_locations: + column_index = self._column_locations.get(column_key) + if column_index is None: raise ColumnDoesNotExist(f"No column exists for column_key={column_key!r}") - return self._column_locations.get(column_key) + return column_index def _clear_caches(self) -> None: self._row_render_cache.clear() @@ -1075,7 +1077,12 @@ def get_row_height(self, row_key: RowKey) -> int: """ if row_key is self._header_row_key: return self.header_height - return self.rows[row_key].height + + row = self.rows.get(row_key) + if row is None: + raise RowDoesNotExist(f"Row key {row_key!r} is not valid.") + + return row.height def notify_style_update(self) -> None: self._row_render_cache.clear() @@ -1466,14 +1473,19 @@ def _update_dimensions(self, new_rows: Iterable[RowKey]) -> None: self._total_row_height + header_height, ) - def _get_cell_region(self, coordinate: Coordinate) -> Region: + def _get_cell_region(self, coordinate: Coordinate) -> Region | None: """Get the region of the cell at the given spatial coordinate.""" if not self.is_valid_coordinate(coordinate): - return Region(0, 0, 0, 0) + return None row_index, column_index = coordinate row_key = self._row_locations.get_key(row_index) - row = self.rows[row_key] + if row_key is None: + return None + + row = self.rows.get(row_key) + if row is None: + return None # The x-coordinate of a cell is the sum of widths of the data cells to the left # plus the width of the render width of the longest row label. @@ -1485,6 +1497,8 @@ def _get_cell_region(self, coordinate: Coordinate) -> Region: + self._row_label_column_width ) column_key = self._column_locations.get_key(column_index) + if column_key is None: + return None width = self.columns[column_key].get_render_width(self) height = row.height y = sum(ordered_row.height for ordered_row in self.ordered_rows[:row_index]) @@ -1493,28 +1507,36 @@ def _get_cell_region(self, coordinate: Coordinate) -> Region: cell_region = Region(x, y, width, height) return cell_region - def _get_row_region(self, row_index: int) -> Region: + def _get_row_region(self, row_index: int) -> Region | None: """Get the region of the row at the given index.""" if not self.is_valid_row_index(row_index): - return Region(0, 0, 0, 0) + return None rows = self.rows row_key = self._row_locations.get_key(row_index) - row = rows[row_key] + if row_key is None: + return None + + row = rows.get(row_key) + if row is None: + return None + row_width = ( sum(column.get_render_width(self) for column in self.columns.values()) + self._row_label_column_width ) y = sum(ordered_row.height for ordered_row in self.ordered_rows[:row_index]) + if self.show_header: y += self.header_height + row_region = Region(0, y, row_width, row.height) return row_region - def _get_column_region(self, column_index: int) -> Region: + def _get_column_region(self, column_index: int) -> Region | None: """Get the region of the column at the given index.""" if not self.is_valid_column_index(column_index): - return Region(0, 0, 0, 0) + return None columns = self.columns x = ( @@ -1525,6 +1547,9 @@ def _get_column_region(self, column_index: int) -> Region: + self._row_label_column_width ) column_key = self._column_locations.get_key(column_index) + if column_key is None: + return None + width = columns[column_key].get_render_width(self) header_height = self.header_height if self.show_header else 0 height = self._total_row_height + header_height @@ -1737,7 +1762,10 @@ def remove_row(self, row_key: RowKey | str) -> None: self.check_idle() index_to_delete = self._row_locations.get(row_key) - new_row_locations = TwoWayDict({}) + if index_to_delete is None: + raise RowDoesNotExist(f"Row key {row_key!r} is not valid.") + + new_row_locations = TwoWayDict[RowKey, int]({}) for row_location_key in self._row_locations: row_index = self._row_locations.get(row_location_key) if row_index > index_to_delete: @@ -1833,6 +1861,8 @@ def refresh_coordinate(self, coordinate: Coordinate) -> Self: if not self.is_valid_coordinate(coordinate): return self region = self._get_cell_region(coordinate) + if region is None: + return self self._refresh_region(region) return self @@ -1849,6 +1879,8 @@ def refresh_row(self, row_index: int) -> Self: return self region = self._get_row_region(row_index) + if region is None: + return self self._refresh_region(region) return self @@ -1865,6 +1897,8 @@ def refresh_column(self, column_index: int) -> Self: return self region = self._get_column_region(column_index) + if region is None: + return self self._refresh_region(region) return self @@ -2535,13 +2569,21 @@ def _scroll_cursor_into_view(self, animate: bool = False) -> None: top, _, _, left = fixed_offset if self.cursor_type == "row": - x, y, width, height = self._get_row_region(self.cursor_row) + row_region = self._get_row_region(self.cursor_row) + if row_region is None: + return + x, y, width, height = row_region region = Region(int(self.scroll_x) + left, y, width - left, height) elif self.cursor_type == "column": - x, y, width, height = self._get_column_region(self.cursor_column) + column_region = self._get_column_region(self.cursor_column) + if column_region is None: + return + x, y, width, height = column_region region = Region(x, int(self.scroll_y) + top, width, height - top) else: region = self._get_cell_region(self.cursor_coordinate) + if region is None: + return self.scroll_to_region(region, animate=animate, spacing=fixed_offset, force=True)