Skip to content

Commit 4a006bf

Browse files
committed
fix!: Fix get_table_names and get_view_names without a default dataset
1 parent 155e5b0 commit 4a006bf

File tree

3 files changed

+65
-87
lines changed

3 files changed

+65
-87
lines changed

sqlalchemy_bigquery/base.py

+25-34
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,9 @@
2222
import datetime
2323
from decimal import Decimal
2424
import random
25-
import operator
2625
import uuid
2726

2827
from google import auth
29-
import google.api_core.exceptions
3028
from google.cloud.bigquery import dbapi
3129
from google.cloud.bigquery.table import (
3230
RangePartitioning,
@@ -1047,11 +1045,6 @@ def dbapi(cls):
10471045
def import_dbapi(cls):
10481046
return dbapi
10491047

1050-
@staticmethod
1051-
def _build_formatted_table_id(table):
1052-
"""Build '<dataset_id>.<table_id>' string using given table."""
1053-
return "{}.{}".format(table.reference.dataset_id, table.table_id)
1054-
10551048
@staticmethod
10561049
def _add_default_dataset_to_job_config(job_config, project_id, dataset_id):
10571050
# If dataset_id is set, then we know the job_config isn't None
@@ -1100,36 +1093,34 @@ def create_connect_args(self, url):
11001093
)
11011094
return ([], {"client": client})
11021095

1103-
def _get_table_or_view_names(self, connection, item_types, schema=None):
1104-
current_schema = schema or self.dataset_id
1105-
get_table_name = (
1106-
self._build_formatted_table_id
1107-
if self.dataset_id is None
1108-
else operator.attrgetter("table_id")
1109-
)
1096+
def _get_default_schema_name(self, connection) -> str:
1097+
return connection.dialect.dataset_id
11101098

1099+
def _get_table_or_view_names(self, connection, item_types, schema=None):
11111100
client = connection.connection._client
1112-
datasets = client.list_datasets()
1113-
1114-
result = []
1115-
for dataset in datasets:
1116-
if current_schema is not None and current_schema != dataset.dataset_id:
1117-
continue
1118-
1119-
try:
1120-
tables = client.list_tables(
1121-
dataset.reference, page_size=self.list_tables_page_size
1101+
# `schema=None` means to search the default schema. If one isn't set in the
1102+
# connection string, then we have nothing to search so return an empty list.
1103+
#
1104+
# When using Alembic with `include_schemas=False`, it expects to compare to a
1105+
# single schema. If `include_schemas=True`, it will enumerate all schemas and
1106+
# then call `get_table_names`/`get_view_names` for each schema.
1107+
current_schema = schema or self.default_schema_name
1108+
if current_schema is None:
1109+
return []
1110+
try:
1111+
return [
1112+
table.table_id
1113+
for table in client.list_tables(
1114+
current_schema, page_size=self.list_tables_page_size
11221115
)
1123-
for table in tables:
1124-
if table.table_type in item_types:
1125-
result.append(get_table_name(table))
1126-
except google.api_core.exceptions.NotFound:
1127-
# It's possible that the dataset was deleted between when we
1128-
# fetched the list of datasets and when we try to list the
1129-
# tables from it. See:
1130-
# https://github.com/googleapis/python-bigquery-sqlalchemy/issues/105
1131-
pass
1132-
return result
1116+
if table.table_type in item_types
1117+
]
1118+
except NotFound:
1119+
# It's possible that the dataset was deleted between when we
1120+
# fetched the list of datasets and when we try to list the
1121+
# tables from it. See:
1122+
# https://github.com/googleapis/python-bigquery-sqlalchemy/issues/105
1123+
return []
11331124

11341125
@staticmethod
11351126
def _split_table_name(full_table_name):

tests/unit/fauxdbi.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -482,8 +482,8 @@ def list_tables(self, dataset, page_size):
482482
google.cloud.bigquery.table.TableListItem(
483483
dict(
484484
tableReference=dict(
485-
projectId=dataset.project,
486-
datasetId=dataset.dataset_id,
485+
projectId="myproject",
486+
datasetId=dataset,
487487
tableId=row["name"],
488488
),
489489
type=row["type"].upper(),

tests/unit/test_sqlalchemy_bigquery.py

+38-51
Original file line numberDiff line numberDiff line change
@@ -65,83 +65,70 @@ def table_item(dataset_id, table_id, type_="TABLE"):
6565

6666

6767
@pytest.mark.parametrize(
68-
["datasets_list", "tables_lists", "expected"],
68+
["dataset", "tables_list", "expected"],
6969
[
70-
([], [], []),
71-
([dataset_item("dataset_1")], [[]], []),
70+
(None, [], []),
71+
("dataset", [], []),
7272
(
73-
[dataset_item("dataset_1"), dataset_item("dataset_2")],
73+
"dataset",
7474
[
75-
[table_item("dataset_1", "d1t1"), table_item("dataset_1", "d1t2")],
76-
[
77-
table_item("dataset_2", "d2t1"),
78-
table_item("dataset_2", "d2view", type_="VIEW"),
79-
table_item("dataset_2", "d2ext", type_="EXTERNAL"),
80-
table_item("dataset_2", "d2mv", type_="MATERIALIZED_VIEW"),
81-
],
75+
table_item("dataset", "t1"),
76+
table_item("dataset", "view", type_="VIEW"),
77+
table_item("dataset", "ext", type_="EXTERNAL"),
78+
table_item("dataset", "mv", type_="MATERIALIZED_VIEW"),
8279
],
83-
["dataset_1.d1t1", "dataset_1.d1t2", "dataset_2.d2t1", "dataset_2.d2ext"],
80+
["t1", "ext"],
8481
),
8582
(
86-
[dataset_item("dataset_1"), dataset_item("dataset_deleted")],
87-
[
88-
[table_item("dataset_1", "d1t1")],
89-
google.api_core.exceptions.NotFound("dataset_deleted"),
90-
],
91-
["dataset_1.d1t1"],
83+
"dataset",
84+
google.api_core.exceptions.NotFound("dataset_deleted"),
85+
[],
9286
),
9387
],
9488
)
9589
def test_get_table_names(
96-
engine_under_test, mock_bigquery_client, datasets_list, tables_lists, expected
90+
engine_under_test, mock_bigquery_client, dataset, tables_list, expected
9791
):
98-
mock_bigquery_client.list_datasets.return_value = datasets_list
99-
mock_bigquery_client.list_tables.side_effect = tables_lists
100-
table_names = sqlalchemy.inspect(engine_under_test).get_table_names()
101-
mock_bigquery_client.list_datasets.assert_called_once()
102-
assert mock_bigquery_client.list_tables.call_count == len(datasets_list)
92+
mock_bigquery_client.list_tables.side_effect = [tables_list]
93+
table_names = sqlalchemy.inspect(engine_under_test).get_table_names(schema=dataset)
94+
if dataset:
95+
mock_bigquery_client.list_tables.assert_called_once()
96+
else:
97+
mock_bigquery_client.list_tables.assert_not_called()
10398
assert list(sorted(table_names)) == list(sorted(expected))
10499

105100

106101
@pytest.mark.parametrize(
107-
["datasets_list", "tables_lists", "expected"],
102+
["dataset", "tables_list", "expected"],
108103
[
109-
([], [], []),
110-
([dataset_item("dataset_1")], [[]], []),
104+
(None, [], []),
105+
("dataset", [], []),
111106
(
112-
[dataset_item("dataset_1"), dataset_item("dataset_2")],
107+
"dataset",
113108
[
114-
[
115-
table_item("dataset_1", "d1t1"),
116-
table_item("dataset_1", "d1view", type_="VIEW"),
117-
],
118-
[
119-
table_item("dataset_2", "d2t1"),
120-
table_item("dataset_2", "d2view", type_="VIEW"),
121-
table_item("dataset_2", "d2ext", type_="EXTERNAL"),
122-
table_item("dataset_2", "d2mv", type_="MATERIALIZED_VIEW"),
123-
],
109+
table_item("dataset", "t1"),
110+
table_item("dataset", "view", type_="VIEW"),
111+
table_item("dataset", "ext", type_="EXTERNAL"),
112+
table_item("dataset", "mv", type_="MATERIALIZED_VIEW"),
124113
],
125-
["dataset_1.d1view", "dataset_2.d2view", "dataset_2.d2mv"],
114+
["view", "mv"],
126115
),
127116
(
128-
[dataset_item("dataset_1"), dataset_item("dataset_deleted")],
129-
[
130-
[table_item("dataset_1", "d1view", type_="VIEW")],
131-
google.api_core.exceptions.NotFound("dataset_deleted"),
132-
],
133-
["dataset_1.d1view"],
117+
"dataset_deleted",
118+
google.api_core.exceptions.NotFound("dataset_deleted"),
119+
[],
134120
),
135121
],
136122
)
137123
def test_get_view_names(
138-
inspector_under_test, mock_bigquery_client, datasets_list, tables_lists, expected
124+
inspector_under_test, mock_bigquery_client, dataset, tables_list, expected
139125
):
140-
mock_bigquery_client.list_datasets.return_value = datasets_list
141-
mock_bigquery_client.list_tables.side_effect = tables_lists
142-
view_names = inspector_under_test.get_view_names()
143-
mock_bigquery_client.list_datasets.assert_called_once()
144-
assert mock_bigquery_client.list_tables.call_count == len(datasets_list)
126+
mock_bigquery_client.list_tables.side_effect = [tables_list]
127+
view_names = inspector_under_test.get_view_names(schema=dataset)
128+
if dataset:
129+
mock_bigquery_client.list_tables.assert_called_once()
130+
else:
131+
mock_bigquery_client.list_tables.assert_not_called()
145132
assert list(sorted(view_names)) == list(sorted(expected))
146133

147134

0 commit comments

Comments
 (0)