Skip to content

Commit 80781ef

Browse files
aholyokeLinchin
andauthored
fix: Use except distinct and intersect distinct (#1094)
Co-authored-by: Lingqing Gan <[email protected]>
1 parent 9e0b117 commit 80781ef

File tree

2 files changed

+90
-0
lines changed

2 files changed

+90
-0
lines changed

sqlalchemy_bigquery/base.py

+2
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,8 @@ class BigQueryCompiler(_struct.SQLCompiler, vendored_postgresql.PGCompiler):
194194
compound_keywords = SQLCompiler.compound_keywords.copy()
195195
compound_keywords[selectable.CompoundSelect.UNION] = "UNION DISTINCT"
196196
compound_keywords[selectable.CompoundSelect.UNION_ALL] = "UNION ALL"
197+
compound_keywords[selectable.CompoundSelect.EXCEPT] = "EXCEPT DISTINCT"
198+
compound_keywords[selectable.CompoundSelect.INTERSECT] = "INTERSECT DISTINCT"
197199

198200
def __init__(self, dialect, statement, *args, **kwargs):
199201
if isinstance(statement, Column):

tests/unit/test_select.py

+88
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,94 @@ def test_typed_parameters(faux_conn, type_, val, btype, vrep):
168168
)
169169

170170

171+
def test_except(faux_conn):
172+
table = setup_table(
173+
faux_conn,
174+
"table",
175+
sqlalchemy.Column("id", sqlalchemy.Integer),
176+
sqlalchemy.Column("foo", sqlalchemy.Integer),
177+
)
178+
179+
s1 = sqlalchemy.select(table.c.foo).where(table.c.id >= 2)
180+
s2 = sqlalchemy.select(table.c.foo).where(table.c.id >= 4)
181+
182+
s3 = s1.except_(s2)
183+
184+
result = s3.compile(faux_conn).string
185+
186+
expected = (
187+
"SELECT `table`.`foo` \n"
188+
"FROM `table` \n"
189+
"WHERE `table`.`id` >= %(id_1:INT64)s EXCEPT DISTINCT SELECT `table`.`foo` \n"
190+
"FROM `table` \n"
191+
"WHERE `table`.`id` >= %(id_2:INT64)s"
192+
)
193+
assert result == expected
194+
195+
196+
def test_intersect(faux_conn):
197+
table = setup_table(
198+
faux_conn,
199+
"table",
200+
sqlalchemy.Column("id", sqlalchemy.Integer),
201+
sqlalchemy.Column("foo", sqlalchemy.Integer),
202+
)
203+
204+
s1 = sqlalchemy.select(table.c.foo).where(table.c.id >= 2)
205+
s2 = sqlalchemy.select(table.c.foo).where(table.c.id >= 4)
206+
207+
s3 = s1.intersect(s2)
208+
209+
result = s3.compile(faux_conn).string
210+
211+
expected = (
212+
"SELECT `table`.`foo` \n"
213+
"FROM `table` \n"
214+
"WHERE `table`.`id` >= %(id_1:INT64)s INTERSECT DISTINCT SELECT `table`.`foo` \n"
215+
"FROM `table` \n"
216+
"WHERE `table`.`id` >= %(id_2:INT64)s"
217+
)
218+
assert result == expected
219+
220+
221+
def test_union(faux_conn):
222+
table = setup_table(
223+
faux_conn,
224+
"table",
225+
sqlalchemy.Column("id", sqlalchemy.Integer),
226+
sqlalchemy.Column("foo", sqlalchemy.Integer),
227+
)
228+
229+
s1 = sqlalchemy.select(table.c.foo).where(table.c.id >= 2)
230+
s2 = sqlalchemy.select(table.c.foo).where(table.c.id >= 4)
231+
232+
s3 = s1.union(s2)
233+
234+
result = s3.compile(faux_conn).string
235+
236+
expected = (
237+
"SELECT `table`.`foo` \n"
238+
"FROM `table` \n"
239+
"WHERE `table`.`id` >= %(id_1:INT64)s UNION DISTINCT SELECT `table`.`foo` \n"
240+
"FROM `table` \n"
241+
"WHERE `table`.`id` >= %(id_2:INT64)s"
242+
)
243+
assert result == expected
244+
245+
s4 = s1.union_all(s2)
246+
247+
result = s4.compile(faux_conn).string
248+
249+
expected = (
250+
"SELECT `table`.`foo` \n"
251+
"FROM `table` \n"
252+
"WHERE `table`.`id` >= %(id_1:INT64)s UNION ALL SELECT `table`.`foo` \n"
253+
"FROM `table` \n"
254+
"WHERE `table`.`id` >= %(id_2:INT64)s"
255+
)
256+
assert result == expected
257+
258+
171259
def test_select_struct(faux_conn, metadata):
172260
from sqlalchemy_bigquery import STRUCT
173261

0 commit comments

Comments
 (0)