diff --git a/notion/block.py b/notion/block.py index f19784f9..5c72b30e 100644 --- a/notion/block.py +++ b/notion/block.py @@ -877,6 +877,30 @@ class CalloutBlock(BasicBlock): _type = "callout" +class TableRowBlock(BasicBlock): + _type = "table_row" + + +class TableBlock(BasicBlock): + _type = "table" + + table_block_column_order = field_map( + "format.table_block_column_order", + ) + + def set_columns(self, num_columns): + self._columns = [f"{i:04d}" for i in range(num_columns)] + self.table_block_column_order = self._columns + + def add_row(self, row): + row_block = self.children.add_new(TableRowBlock) + with self._client.as_atomic_transaction(): + for col_id, cell in zip(self._columns, row): + attr = property_map(f"{col_id}") + attr.fset(row_block, str(cell)) + return row_block + + BLOCK_TYPES = { cls._type: cls for cls in locals().values() diff --git a/notion/client.py b/notion/client.py index 3d748934..2ff67c21 100644 --- a/notion/client.py +++ b/notion/client.py @@ -1,5 +1,4 @@ import hashlib -import json import re import uuid @@ -53,6 +52,7 @@ def create_session(client_specified_retry=None): ) adapter = HTTPAdapter(max_retries=retry) session.mount("https://", adapter) + session.headers.update({"content-type": "application/json"}) return session @@ -130,6 +130,10 @@ def _update_user_info(self): self._store.store_recordmap(records) self.current_user = self.get_user(list(records["notion_user"].keys())[0]) self.current_space = self.get_space(list(records["space"].keys())[0]) + + self.session.headers.update({"x-notion-active-user-header": + self.session.cookies.get("notion_user_id")}) + return records def get_email_uid(self): @@ -140,7 +144,6 @@ def get_email_uid(self): } def set_user_by_uid(self, user_id): - self.session.headers.update({"x-notion-active-user-header": user_id}) self._update_user_info() def set_user_by_email(self, email): @@ -158,15 +161,15 @@ def get_top_level_pages(self): records = self._update_user_info() return [self.get_block(bid) for bid in records["block"].keys()] - def get_record_data(self, table, id, force_refresh=False): - return self._store.get(table, id, force_refresh=force_refresh) + def get_record_data(self, table, id, force_refresh=False, limit=100): + return self._store.get(table, id, force_refresh=force_refresh, limit=limit) - def get_block(self, url_or_id, force_refresh=False): + def get_block(self, url_or_id, force_refresh=False, limit=100): """ Retrieve an instance of a subclass of Block that maps to the block/page identified by the URL or ID passed in. """ block_id = extract_id(url_or_id) - block = self.get_record_data("block", block_id, force_refresh=force_refresh) + block = self.get_record_data("block", block_id, force_refresh=force_refresh, limit=limit) if not block: return None if block.get("parent_table") == "collection": @@ -306,11 +309,11 @@ def in_transaction(self): """ return hasattr(self, "_transaction_operations") - def search_pages_with_parent(self, parent_id, search=""): + def search_pages_with_parent(self, parent_id, search="", limit=100): data = { "query": search, "parentId": parent_id, - "limit": 10000, + "limit": limit, "spaceId": self.current_space.id, } response = self.post("searchPagesWithParent", data).json() diff --git a/notion/collection.py b/notion/collection.py index 748cc067..12eab046 100644 --- a/notion/collection.py +++ b/notion/collection.py @@ -360,6 +360,7 @@ def __init__( sort=[], calendar_by="", group_by="", + limit=100 ): assert not ( aggregate and aggregations @@ -374,25 +375,40 @@ def __init__( self.sort = _normalize_query_data(sort, collection) self.calendar_by = _normalize_property_name(calendar_by, collection) self.group_by = _normalize_property_name(group_by, collection) + self.limit = limit self._client = collection._client def execute(self): result_class = QUERY_RESULT_TYPES.get(self.type, QueryResult) + kwargs = { + 'collection_id':self.collection.id, + 'collection_view_id':self.collection_view.id, + 'search':self.search, + 'type':self.type, + 'aggregate':self.aggregate, + 'aggregations':self.aggregations, + 'filter':self.filter, + 'sort':self.sort, + 'calendar_by':self.calendar_by, + 'group_by':self.group_by, + 'limit':0 + } + + if self.limit == -1: + # fetch remote total + result = self._client.query_collection( + **kwargs + ) + self.limit = result.get("total",-1) + + kwargs['limit'] = self.limit + return result_class( self.collection, self._client.query_collection( - collection_id=self.collection.id, - collection_view_id=self.collection_view.id, - search=self.search, - type=self.type, - aggregate=self.aggregate, - aggregations=self.aggregations, - filter=self.filter, - sort=self.sort, - calendar_by=self.calendar_by, - group_by=self.group_by, + **kwargs ), self, ) @@ -704,6 +720,7 @@ def __init__(self, collection, result, query): self.collection = collection self._client = collection._client self._block_ids = self._get_block_ids(result) + self.total = result.get("total", -1) self.aggregates = result.get("aggregationResults", []) self.aggregate_ids = [ agg.get("id") for agg in (query.aggregate or query.aggregations) @@ -711,7 +728,7 @@ def __init__(self, collection, result, query): self.query = query def _get_block_ids(self, result): - return result["blockIds"] + return result['reducerResults']['collection_group_results']["blockIds"] def _get_block(self, id): block = CollectionRowBlock(self._client, id) @@ -754,7 +771,6 @@ def __contains__(self, item): return False return item_id in self._block_ids - class TableQueryResult(QueryResult): _type = "table" diff --git a/notion/store.py b/notion/store.py index 57620c96..59fe76fc 100644 --- a/notion/store.py +++ b/notion/store.py @@ -174,14 +174,14 @@ def get_role(self, table, id, force_refresh=False): self.get(table, id, force_refresh=force_refresh) return self._role[table].get(id, None) - def get(self, table, id, force_refresh=False): + def get(self, table, id, force_refresh=False, limit=100): id = extract_id(id) # look up the record in the current local dataset result = self._get(table, id) # if it's not found, try refreshing the record from the server if result is Missing or force_refresh: if table == "block": - self.call_load_page_chunk(id) + self.call_load_page_chunk(id,limit=limit) else: self.call_get_record_values(**{table: id}) result = self._get(table, id) @@ -269,15 +269,17 @@ def get_current_version(self, table, id): else: return -1 - def call_load_page_chunk(self, page_id): + def call_load_page_chunk(self, page_id, limit=100): if self._client.in_transaction(): self._pages_to_refresh.append(page_id) return data = { - "pageId": page_id, - "limit": 100000, + "page": { + "id": page_id, + }, + "limit": limit, "cursor": {"stack": []}, "chunkNumber": 0, "verticalColumns": False, @@ -310,6 +312,7 @@ def call_query_collection( sort=[], calendar_by="", group_by="", + limit=50 ): assert not ( @@ -323,21 +326,25 @@ def call_query_collection( sort = [sort] data = { - "collectionId": collection_id, - "collectionViewId": collection_view_id, + "collection": { + "id": collection_id, + "spaceId": self._client.current_space.id + }, + "collectionView": { + "id": collection_view_id, + "spaceId": self._client.current_space.id + }, "loader": { - "limit": 10000, - "loadContentCover": True, + 'reducers': { + 'collection_group_results': { + 'limit': limit, + 'type': 'results', + }, + }, "searchQuery": search, - "userLocale": "en", + 'sort': sort, "userTimeZone": str(get_localzone()), - "type": type, - }, - "query": { - "aggregate": aggregate, - "aggregations": aggregations, - "filter": filter, - "sort": sort, + "type": 'reducer', }, }