Skip to content

Commit d84953a

Browse files
committed
Bidi Preview first commit
1 parent 66d5ed1 commit d84953a

File tree

82 files changed

+32547
-0
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

82 files changed

+32547
-0
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
![](../shared_assets/sis-header.jpeg)
2+
3+
## Private Preview
4+
5+
**Note:** This is an experimental preview feature and is not fully functional. Support may be limited to specific accounts.
6+
7+
## Bidirectional Custom Components in Streamlit in Snowflake
8+
9+
This app demonstrates how to use bidirectional custom components in Streamlit in Snowflake.
10+
11+
12+
13+
## Limitations
14+
* Only components which don't trigger requests to external services are supported.
15+
* Supported components include:
16+
* AgGrid (streamlit-aggrid)
17+
* ECharts (streamlit-echarts)
18+
19+
20+
## Troubleshooting
21+
* When using git to add component files, git add may ignore files in `build`, or `frontend` directories.
Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
from dataclasses import dataclass, field
2+
from typing import Any, List, Mapping, Union, Any
3+
from collections import defaultdict
4+
from st_aggrid.shared import DataReturnMode
5+
6+
import json
7+
import pandas as pd
8+
import numpy as np
9+
import inspect
10+
11+
12+
class AgGridReturn(Mapping):
13+
"""Class to hold AgGrid call return"""
14+
15+
# selected_rows: List[Mapping] = field(default_factory=list)
16+
# column_state = None
17+
# excel_blob = None
18+
19+
def __init__(
20+
self,
21+
originalData,
22+
gridOptions=None,
23+
data_return_mode=DataReturnMode.AS_INPUT,
24+
try_to_convert_back_to_original_types=True,
25+
conversion_errors="corce"
26+
) -> None:
27+
super().__init__()
28+
29+
# def ddict():
30+
# return defaultdict(ddict)
31+
32+
# self.__dict__ = ddict()
33+
34+
self.__component_value_set = False
35+
36+
self.__original_data = originalData
37+
self.__try_to_convert_back_to_original_types = (
38+
try_to_convert_back_to_original_types
39+
)
40+
self.__conversion_errors = conversion_errors
41+
self.__data_return_mode = data_return_mode
42+
43+
self.__dict__["grid_response"] = {"gridOptions": gridOptions}
44+
45+
46+
def _set_component_value(self, component_value):
47+
self.__component_value_set = True
48+
49+
self.__dict__["grid_response"] = component_value
50+
self.__dict__["grid_response"]["gridOptions"] = json.loads(
51+
self.__dict__["grid_response"]["gridOptions"]
52+
)
53+
54+
@property
55+
def grid_response(self):
56+
"""Raw response from component."""
57+
return self.__dict__["grid_response"]
58+
59+
@property
60+
def rows_id_after_sort_and_filter(self):
61+
"""The row indexes after sort and filter is applied"""
62+
return self.grid_response.get("rowIdsAfterSortAndFilter")
63+
64+
@property
65+
def rows_id_after_filter(self):
66+
"""The filtered row indexes"""
67+
return self.grid_response.get("rowIdsAfterFilter")
68+
69+
@property
70+
def grid_options(self):
71+
"""GridOptions as applied on the grid."""
72+
return self.grid_response.get("gridOptions", {})
73+
74+
@property
75+
def columns_state(self):
76+
"""Gets the state of the columns. Typically used when saving column state."""
77+
return self.grid_response.get("columnsState")
78+
79+
@property
80+
def grid_state(self):
81+
"""Gets the grid state. Tipically used on initialState option. (https://ag-grid.com/javascript-data-grid//grid-options/#reference-miscellaneous-initialState)"""
82+
return self.grid_response.get("gridState")
83+
84+
@property
85+
def selected_rows_id(self):
86+
"""Ids of selected rows"""
87+
return self.grid_state.get("rowSelection")
88+
89+
def __process_vanilla_df_response(
90+
self, nodes, __try_to_convert_back_to_original_types
91+
):
92+
data = pd.DataFrame([n.get("data", {}) for n in nodes if not n.get("group", False) == True])
93+
94+
if "__pandas_index" in data.columns:
95+
data.index = pd.Index(data["__pandas_index"], name="index")
96+
del data["__pandas_index"]
97+
98+
if __try_to_convert_back_to_original_types:
99+
original_types = self.grid_response["originalDtypes"]
100+
try:
101+
original_types.pop("__pandas_index")
102+
except:
103+
pass
104+
105+
numeric_columns = [
106+
k for k, v in original_types.items() if v in ["i", "u", "f"]
107+
]
108+
if numeric_columns:
109+
data.loc[:, numeric_columns] = data.loc[:, numeric_columns].apply(
110+
pd.to_numeric, errors=self.__conversion_errors
111+
)
112+
113+
text_columns = [
114+
k for k, v in original_types.items() if v in ["O", "S", "U"]
115+
]
116+
117+
if text_columns:
118+
data.loc[:, text_columns] = data.loc[:, text_columns].applymap(
119+
lambda x: np.nan if x is None else str(x)
120+
)
121+
122+
date_columns = [k for k, v in original_types.items() if v == "M"]
123+
if date_columns:
124+
data.loc[:, date_columns] = data.loc[:, date_columns].apply(
125+
pd.to_datetime, errors=self.__conversion_errors
126+
)
127+
128+
timedelta_columns = [k for k, v in original_types.items() if v == "m"]
129+
if timedelta_columns:
130+
131+
def cast_to_timedelta(s):
132+
try:
133+
return pd.Timedelta(s)
134+
except:
135+
return s
136+
137+
data.loc[:, timedelta_columns] = data.loc[:, timedelta_columns].apply(
138+
cast_to_timedelta
139+
)
140+
141+
return data
142+
143+
def __process_grouped_response(
144+
self, nodes, __try_to_convert_back_to_original_types, __data_return_mode
145+
):
146+
def travel_parent(o):
147+
148+
if o.get("parent", None) == None:
149+
return ""
150+
151+
return rf"""{travel_parent(o.get("parent"))}.{o.get("parent").get('key')}""".lstrip(
152+
"."
153+
)
154+
155+
data = [
156+
{**i.get("data"), **{"parent": travel_parent(i)}}
157+
for i in nodes
158+
if i.get("group", False) == False
159+
]
160+
data = pd.DataFrame(data).set_index("__pandas_index")
161+
data.index.name = ''
162+
groups = [{tuple(v1.split(".")[1:]): v2.drop('parent', axis=1)} for v1, v2 in data.groupby("parent")]
163+
return groups
164+
165+
def __get_data(self, onlySelected):
166+
data = self.__original_data if not onlySelected else None
167+
168+
if self.__component_value_set:
169+
nodes = self.grid_response.get("nodes",[])
170+
171+
if onlySelected:
172+
nodes = list(filter(lambda n: n.get('isSelected', False) == True, nodes))
173+
174+
if not nodes:
175+
return None
176+
177+
data = self.__process_vanilla_df_response(
178+
nodes,
179+
self.__try_to_convert_back_to_original_types and onlySelected
180+
)
181+
182+
183+
reindex_ids_map = {
184+
DataReturnMode.FILTERED: self.rows_id_after_filter,
185+
DataReturnMode.FILTERED_AND_SORTED:self.rows_id_after_sort_and_filter
186+
}
187+
188+
reindex_ids = reindex_ids_map.get(self.__data_return_mode, None)
189+
190+
if reindex_ids:
191+
reindex_ids = pd.Index(reindex_ids)
192+
193+
if onlySelected:
194+
reindex_ids = reindex_ids.intersection(data.index)
195+
196+
data = data.reindex(index=reindex_ids)
197+
198+
return data
199+
200+
@property
201+
def data(self):
202+
"Data from the grid. If rows are grouped, return only the leaf rows"
203+
204+
return self.__get_data(onlySelected=False)
205+
206+
@property
207+
def selected_data(self):
208+
"Selected Data from the grid."
209+
210+
return self.__get_data(onlySelected=True)
211+
212+
def __get_dataGroups(self, onlySelected):
213+
if self.__component_value_set:
214+
nodes = self.grid_response.get("nodes",[])
215+
216+
if onlySelected:
217+
#n.get('isSelected', True). Default is true bc agGrid sets undefined for half selected groups
218+
nodes = list(filter(lambda n: n.get('isSelected', True) == True, nodes))
219+
220+
if not nodes:
221+
return [{(''):self.__get_data(onlySelected)}]
222+
223+
response_has_groups = any((n.get("group", False) for n in nodes))
224+
225+
if response_has_groups:
226+
data = self.__process_grouped_response(
227+
nodes,
228+
self.__try_to_convert_back_to_original_types,
229+
self.__data_return_mode,
230+
)
231+
return data
232+
233+
return [{(''):self.__get_data(onlySelected)}]
234+
235+
@property
236+
def dataGroups(self):
237+
"returns grouped rows as a dictionary where keys are tuples of groupby strings and values are pandas.DataFrame"
238+
239+
return self.__get_dataGroups(onlySelected=False)
240+
241+
@property
242+
def selected_dataGroups(self):
243+
"returns selected rows as a dictionary where keys are tuples of grouped column names and values are pandas.DataFrame"
244+
245+
return self.__get_dataGroups(onlySelected=True)
246+
247+
@property
248+
def selected_rows(self):
249+
"""Returns with selected rows. If there are grouped rows return a dict of {key:pd.DataFrame}"""
250+
selected_items = pd.DataFrame(self.grid_response.get("selectedItems", {}))
251+
252+
if selected_items.empty:
253+
return None
254+
255+
if "__pandas_index" in selected_items.columns:
256+
selected_items.set_index("__pandas_index", inplace=True)
257+
selected_items.index.name = "index"
258+
259+
return selected_items
260+
261+
#TODO: implement event returns
262+
@property
263+
def event_data(self):
264+
"""Returns information about the event that triggered AgGrid Response"""
265+
return self.grid_response.get("eventData",None)
266+
267+
# Backwards compatibility with dict interface
268+
def __getitem__(self, __k):
269+
270+
try:
271+
return getattr(self, __k)
272+
except AttributeError:
273+
return self.__dict__.__getitem__(__k)
274+
275+
def __iter__(self):
276+
attrs = (x for x in inspect.getmembers(self) if not x[0].startswith('_'))
277+
return attrs.__iter__()
278+
279+
def __len__(self):
280+
attrs = [x for x in inspect.getmembers(self) if not x[0].startswith('_')]
281+
return attrs.__len__()
282+
283+
def keys(self):
284+
return [x[0] for x in inspect.getmembers(self) if not x[0].startswith('_')]
285+
286+
def values(self):
287+
return [x[1] for x in inspect.getmembers(self) if not x[0].startswith('_')]

0 commit comments

Comments
 (0)