|
| 1 | +from dataclasses import dataclass, field |
| 2 | +from typing import Any, List, Mapping, Union, Any |
| 3 | +from collections import defaultdict |
| 4 | +from st_aggrid.shared import DataReturnMode |
| 5 | + |
| 6 | +import json |
| 7 | +import pandas as pd |
| 8 | +import numpy as np |
| 9 | +import inspect |
| 10 | + |
| 11 | + |
| 12 | +class AgGridReturn(Mapping): |
| 13 | + """Class to hold AgGrid call return""" |
| 14 | + |
| 15 | + # selected_rows: List[Mapping] = field(default_factory=list) |
| 16 | + # column_state = None |
| 17 | + # excel_blob = None |
| 18 | + |
| 19 | + def __init__( |
| 20 | + self, |
| 21 | + originalData, |
| 22 | + gridOptions=None, |
| 23 | + data_return_mode=DataReturnMode.AS_INPUT, |
| 24 | + try_to_convert_back_to_original_types=True, |
| 25 | + conversion_errors="corce" |
| 26 | + ) -> None: |
| 27 | + super().__init__() |
| 28 | + |
| 29 | + # def ddict(): |
| 30 | + # return defaultdict(ddict) |
| 31 | + |
| 32 | + # self.__dict__ = ddict() |
| 33 | + |
| 34 | + self.__component_value_set = False |
| 35 | + |
| 36 | + self.__original_data = originalData |
| 37 | + self.__try_to_convert_back_to_original_types = ( |
| 38 | + try_to_convert_back_to_original_types |
| 39 | + ) |
| 40 | + self.__conversion_errors = conversion_errors |
| 41 | + self.__data_return_mode = data_return_mode |
| 42 | + |
| 43 | + self.__dict__["grid_response"] = {"gridOptions": gridOptions} |
| 44 | + |
| 45 | + |
| 46 | + def _set_component_value(self, component_value): |
| 47 | + self.__component_value_set = True |
| 48 | + |
| 49 | + self.__dict__["grid_response"] = component_value |
| 50 | + self.__dict__["grid_response"]["gridOptions"] = json.loads( |
| 51 | + self.__dict__["grid_response"]["gridOptions"] |
| 52 | + ) |
| 53 | + |
| 54 | + @property |
| 55 | + def grid_response(self): |
| 56 | + """Raw response from component.""" |
| 57 | + return self.__dict__["grid_response"] |
| 58 | + |
| 59 | + @property |
| 60 | + def rows_id_after_sort_and_filter(self): |
| 61 | + """The row indexes after sort and filter is applied""" |
| 62 | + return self.grid_response.get("rowIdsAfterSortAndFilter") |
| 63 | + |
| 64 | + @property |
| 65 | + def rows_id_after_filter(self): |
| 66 | + """The filtered row indexes""" |
| 67 | + return self.grid_response.get("rowIdsAfterFilter") |
| 68 | + |
| 69 | + @property |
| 70 | + def grid_options(self): |
| 71 | + """GridOptions as applied on the grid.""" |
| 72 | + return self.grid_response.get("gridOptions", {}) |
| 73 | + |
| 74 | + @property |
| 75 | + def columns_state(self): |
| 76 | + """Gets the state of the columns. Typically used when saving column state.""" |
| 77 | + return self.grid_response.get("columnsState") |
| 78 | + |
| 79 | + @property |
| 80 | + def grid_state(self): |
| 81 | + """Gets the grid state. Tipically used on initialState option. (https://ag-grid.com/javascript-data-grid//grid-options/#reference-miscellaneous-initialState)""" |
| 82 | + return self.grid_response.get("gridState") |
| 83 | + |
| 84 | + @property |
| 85 | + def selected_rows_id(self): |
| 86 | + """Ids of selected rows""" |
| 87 | + return self.grid_state.get("rowSelection") |
| 88 | + |
| 89 | + def __process_vanilla_df_response( |
| 90 | + self, nodes, __try_to_convert_back_to_original_types |
| 91 | + ): |
| 92 | + data = pd.DataFrame([n.get("data", {}) for n in nodes if not n.get("group", False) == True]) |
| 93 | + |
| 94 | + if "__pandas_index" in data.columns: |
| 95 | + data.index = pd.Index(data["__pandas_index"], name="index") |
| 96 | + del data["__pandas_index"] |
| 97 | + |
| 98 | + if __try_to_convert_back_to_original_types: |
| 99 | + original_types = self.grid_response["originalDtypes"] |
| 100 | + try: |
| 101 | + original_types.pop("__pandas_index") |
| 102 | + except: |
| 103 | + pass |
| 104 | + |
| 105 | + numeric_columns = [ |
| 106 | + k for k, v in original_types.items() if v in ["i", "u", "f"] |
| 107 | + ] |
| 108 | + if numeric_columns: |
| 109 | + data.loc[:, numeric_columns] = data.loc[:, numeric_columns].apply( |
| 110 | + pd.to_numeric, errors=self.__conversion_errors |
| 111 | + ) |
| 112 | + |
| 113 | + text_columns = [ |
| 114 | + k for k, v in original_types.items() if v in ["O", "S", "U"] |
| 115 | + ] |
| 116 | + |
| 117 | + if text_columns: |
| 118 | + data.loc[:, text_columns] = data.loc[:, text_columns].applymap( |
| 119 | + lambda x: np.nan if x is None else str(x) |
| 120 | + ) |
| 121 | + |
| 122 | + date_columns = [k for k, v in original_types.items() if v == "M"] |
| 123 | + if date_columns: |
| 124 | + data.loc[:, date_columns] = data.loc[:, date_columns].apply( |
| 125 | + pd.to_datetime, errors=self.__conversion_errors |
| 126 | + ) |
| 127 | + |
| 128 | + timedelta_columns = [k for k, v in original_types.items() if v == "m"] |
| 129 | + if timedelta_columns: |
| 130 | + |
| 131 | + def cast_to_timedelta(s): |
| 132 | + try: |
| 133 | + return pd.Timedelta(s) |
| 134 | + except: |
| 135 | + return s |
| 136 | + |
| 137 | + data.loc[:, timedelta_columns] = data.loc[:, timedelta_columns].apply( |
| 138 | + cast_to_timedelta |
| 139 | + ) |
| 140 | + |
| 141 | + return data |
| 142 | + |
| 143 | + def __process_grouped_response( |
| 144 | + self, nodes, __try_to_convert_back_to_original_types, __data_return_mode |
| 145 | + ): |
| 146 | + def travel_parent(o): |
| 147 | + |
| 148 | + if o.get("parent", None) == None: |
| 149 | + return "" |
| 150 | + |
| 151 | + return rf"""{travel_parent(o.get("parent"))}.{o.get("parent").get('key')}""".lstrip( |
| 152 | + "." |
| 153 | + ) |
| 154 | + |
| 155 | + data = [ |
| 156 | + {**i.get("data"), **{"parent": travel_parent(i)}} |
| 157 | + for i in nodes |
| 158 | + if i.get("group", False) == False |
| 159 | + ] |
| 160 | + data = pd.DataFrame(data).set_index("__pandas_index") |
| 161 | + data.index.name = '' |
| 162 | + groups = [{tuple(v1.split(".")[1:]): v2.drop('parent', axis=1)} for v1, v2 in data.groupby("parent")] |
| 163 | + return groups |
| 164 | + |
| 165 | + def __get_data(self, onlySelected): |
| 166 | + data = self.__original_data if not onlySelected else None |
| 167 | + |
| 168 | + if self.__component_value_set: |
| 169 | + nodes = self.grid_response.get("nodes",[]) |
| 170 | + |
| 171 | + if onlySelected: |
| 172 | + nodes = list(filter(lambda n: n.get('isSelected', False) == True, nodes)) |
| 173 | + |
| 174 | + if not nodes: |
| 175 | + return None |
| 176 | + |
| 177 | + data = self.__process_vanilla_df_response( |
| 178 | + nodes, |
| 179 | + self.__try_to_convert_back_to_original_types and onlySelected |
| 180 | + ) |
| 181 | + |
| 182 | + |
| 183 | + reindex_ids_map = { |
| 184 | + DataReturnMode.FILTERED: self.rows_id_after_filter, |
| 185 | + DataReturnMode.FILTERED_AND_SORTED:self.rows_id_after_sort_and_filter |
| 186 | + } |
| 187 | + |
| 188 | + reindex_ids = reindex_ids_map.get(self.__data_return_mode, None) |
| 189 | + |
| 190 | + if reindex_ids: |
| 191 | + reindex_ids = pd.Index(reindex_ids) |
| 192 | + |
| 193 | + if onlySelected: |
| 194 | + reindex_ids = reindex_ids.intersection(data.index) |
| 195 | + |
| 196 | + data = data.reindex(index=reindex_ids) |
| 197 | + |
| 198 | + return data |
| 199 | + |
| 200 | + @property |
| 201 | + def data(self): |
| 202 | + "Data from the grid. If rows are grouped, return only the leaf rows" |
| 203 | + |
| 204 | + return self.__get_data(onlySelected=False) |
| 205 | + |
| 206 | + @property |
| 207 | + def selected_data(self): |
| 208 | + "Selected Data from the grid." |
| 209 | + |
| 210 | + return self.__get_data(onlySelected=True) |
| 211 | + |
| 212 | + def __get_dataGroups(self, onlySelected): |
| 213 | + if self.__component_value_set: |
| 214 | + nodes = self.grid_response.get("nodes",[]) |
| 215 | + |
| 216 | + if onlySelected: |
| 217 | + #n.get('isSelected', True). Default is true bc agGrid sets undefined for half selected groups |
| 218 | + nodes = list(filter(lambda n: n.get('isSelected', True) == True, nodes)) |
| 219 | + |
| 220 | + if not nodes: |
| 221 | + return [{(''):self.__get_data(onlySelected)}] |
| 222 | + |
| 223 | + response_has_groups = any((n.get("group", False) for n in nodes)) |
| 224 | + |
| 225 | + if response_has_groups: |
| 226 | + data = self.__process_grouped_response( |
| 227 | + nodes, |
| 228 | + self.__try_to_convert_back_to_original_types, |
| 229 | + self.__data_return_mode, |
| 230 | + ) |
| 231 | + return data |
| 232 | + |
| 233 | + return [{(''):self.__get_data(onlySelected)}] |
| 234 | + |
| 235 | + @property |
| 236 | + def dataGroups(self): |
| 237 | + "returns grouped rows as a dictionary where keys are tuples of groupby strings and values are pandas.DataFrame" |
| 238 | + |
| 239 | + return self.__get_dataGroups(onlySelected=False) |
| 240 | + |
| 241 | + @property |
| 242 | + def selected_dataGroups(self): |
| 243 | + "returns selected rows as a dictionary where keys are tuples of grouped column names and values are pandas.DataFrame" |
| 244 | + |
| 245 | + return self.__get_dataGroups(onlySelected=True) |
| 246 | + |
| 247 | + @property |
| 248 | + def selected_rows(self): |
| 249 | + """Returns with selected rows. If there are grouped rows return a dict of {key:pd.DataFrame}""" |
| 250 | + selected_items = pd.DataFrame(self.grid_response.get("selectedItems", {})) |
| 251 | + |
| 252 | + if selected_items.empty: |
| 253 | + return None |
| 254 | + |
| 255 | + if "__pandas_index" in selected_items.columns: |
| 256 | + selected_items.set_index("__pandas_index", inplace=True) |
| 257 | + selected_items.index.name = "index" |
| 258 | + |
| 259 | + return selected_items |
| 260 | + |
| 261 | + #TODO: implement event returns |
| 262 | + @property |
| 263 | + def event_data(self): |
| 264 | + """Returns information about the event that triggered AgGrid Response""" |
| 265 | + return self.grid_response.get("eventData",None) |
| 266 | + |
| 267 | + # Backwards compatibility with dict interface |
| 268 | + def __getitem__(self, __k): |
| 269 | + |
| 270 | + try: |
| 271 | + return getattr(self, __k) |
| 272 | + except AttributeError: |
| 273 | + return self.__dict__.__getitem__(__k) |
| 274 | + |
| 275 | + def __iter__(self): |
| 276 | + attrs = (x for x in inspect.getmembers(self) if not x[0].startswith('_')) |
| 277 | + return attrs.__iter__() |
| 278 | + |
| 279 | + def __len__(self): |
| 280 | + attrs = [x for x in inspect.getmembers(self) if not x[0].startswith('_')] |
| 281 | + return attrs.__len__() |
| 282 | + |
| 283 | + def keys(self): |
| 284 | + return [x[0] for x in inspect.getmembers(self) if not x[0].startswith('_')] |
| 285 | + |
| 286 | + def values(self): |
| 287 | + return [x[1] for x in inspect.getmembers(self) if not x[0].startswith('_')] |
0 commit comments