Skip to content

Commit 74bcfde

Browse files
committed
implement INSERT ALL statement
1 parent 7baaa1a commit 74bcfde

File tree

4 files changed

+270
-1
lines changed

4 files changed

+270
-1
lines changed

src/snowflake/sqlalchemy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
CreateStage,
3838
CSVFormatter,
3939
ExternalStage,
40+
InsertMulti,
4041
JSONFormatter,
4142
MergeInto,
4243
PARQUETFormatter,

src/snowflake/sqlalchemy/base.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,49 @@ def visit_merge_into_clause(self, merge_into_clause, **kw):
205205
" SET %s" % sets if merge_into_clause.set else "",
206206
)
207207

208+
def visit_insert_multi(self, insert_all, **kw):
209+
clauses = []
210+
for condition, table, columns, values in insert_all.clauses:
211+
clauses.append(
212+
(
213+
f"WHEN {condition._compiler_dispatch(self, include_table=False, **kw)} THEN "
214+
if condition is not None
215+
else ""
216+
)
217+
+ f"INTO {table._compiler_dispatch(self, asfrom=True, **kw)}"
218+
+ (
219+
f" ({', '.join(c._compiler_dispatch(self, include_table=False, **kw) for c in columns)})"
220+
if columns
221+
else ""
222+
)
223+
+ (
224+
f" VALUES ({', '.join(v._compiler_dispatch(self, include_table=False, **kw) for v in values)})"
225+
if values
226+
else ""
227+
)
228+
)
229+
230+
source = insert_all.source._compiler_dispatch(self, **kw)
231+
if insert_all.else__:
232+
else_ = (
233+
f" ELSE {insert_all.else__[0]._compiler_dispatch(self, asfrom=True, **kw)}"
234+
+ (
235+
f" ({', '.join(c._compiler_dispatch(self, include_table=False, **kw) for c in insert_all.else__[1])})"
236+
if insert_all.else__[1]
237+
else ""
238+
)
239+
+ (
240+
f" VALUES ({', '.join(v._compiler_dispatch(self, include_table=False, **kw) for v in insert_all.else__[2])})"
241+
if insert_all.else__[2]
242+
else ""
243+
)
244+
)
245+
else:
246+
else_ = ""
247+
overwrite = " OVERWRITE" if insert_all.overwrite else ""
248+
condition = "FIRST" if insert_all.is_conditional and insert_all.first else "ALL"
249+
return f"INSERT{overwrite} {condition} {' '.join(clauses)}{else_} {source}"
250+
208251
def visit_copy_into(self, copy_into, **kw):
209252
if hasattr(copy_into, "formatter") and copy_into.formatter is not None:
210253
formatter = copy_into.formatter._compiler_dispatch(self, **kw)

src/snowflake/sqlalchemy/custom_commands.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,92 @@ def when_not_matched_then_insert(self):
9494
return clause
9595

9696

97+
class InsertMulti(UpdateBase):
98+
__visit_name__ = "insert_multi"
99+
_bind = None
100+
101+
def __init__(self, source, overwrite=False, first=False):
102+
self.source = source
103+
self.overwrite = overwrite
104+
self.first = first
105+
self.clauses = []
106+
self.else__ = None
107+
108+
@property
109+
def is_conditional(self):
110+
return any(condition is not None for condition, _, _, _ in self.clauses)
111+
112+
def __repr__(self):
113+
clauses = []
114+
for condition, table, columns, values in self.clauses:
115+
clauses.append(
116+
(f"WHEN {condition!r} THEN " if condition is not None else "")
117+
+ f" INTO {table!r}"
118+
+ (f"({', '.join(repr(c) for c in columns)})" if columns else "")
119+
+ (f" VALUES ({', '.join(str(v) for v in values)})" if values else "")
120+
)
121+
else_ = f" ELSE {self.else__!r}" if self.else__ else ""
122+
overwrite = " OVERWRITE" if self.overwrite else ""
123+
condition = "FIRST" if self.is_conditional and self.first else "ALL"
124+
return (
125+
f"INSERT{overwrite} {condition} {', '.join(clauses)}{else_} {self.source}"
126+
)
127+
128+
def _adapt_columns(self, columns, coll):
129+
"""Make sure all columns are column instances from the given table, not strings"""
130+
if columns is None:
131+
return None
132+
return [coll[c] if isinstance(c, str) else c for c in columns]
133+
134+
def into(self, table, columns=None, values=None):
135+
if self.is_conditional:
136+
raise ValueError(
137+
"Cannot add an unconditional clause to a Conditional multi-table insert"
138+
)
139+
if columns and values:
140+
assert len(columns) == len(
141+
values
142+
), "columns and values must be of the same length"
143+
self.clauses.append(
144+
(
145+
None,
146+
table,
147+
self._adapt_columns(columns, table.c),
148+
self._adapt_columns(values, self.source.selected_columns),
149+
)
150+
)
151+
return self
152+
153+
def when(self, condition, table, columns=None, values=None):
154+
if self.clauses and not self.is_conditional:
155+
raise ValueError(
156+
"Cannot add a conditional clause to an Unconditional multi-table insert"
157+
)
158+
if columns and values:
159+
assert len(columns) == len(
160+
values
161+
), "columns and values must be of the same length"
162+
self.clauses.append(
163+
(
164+
condition,
165+
table,
166+
self._adapt_columns(columns, table.c),
167+
self._adapt_columns(values, self.source.selected_columns),
168+
)
169+
)
170+
return self
171+
172+
def else_(self, table, columns=None, values=None):
173+
if self.clauses and not self.is_conditional:
174+
raise ValueError("Cannot set ELSE on an Unconditional multi-table insert")
175+
self.else__ = (
176+
table,
177+
self._adapt_columns(columns, table.c),
178+
self._adapt_columns(values, self.source.selected_columns),
179+
)
180+
return self
181+
182+
97183
class FilesOption:
98184
"""
99185
Class to represent FILES option for the snowflake COPY INTO statement

tests/test_core.py

Lines changed: 140 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
Table,
3030
UniqueConstraint,
3131
dialects,
32+
func,
3233
inspect,
3334
text,
3435
)
@@ -39,7 +40,7 @@
3940
import snowflake.connector.errors
4041
import snowflake.sqlalchemy.snowdialect
4142
from snowflake.connector import Error, ProgrammingError, connect
42-
from snowflake.sqlalchemy import URL, MergeInto, dialect
43+
from snowflake.sqlalchemy import URL, InsertMulti, MergeInto, dialect
4344
from snowflake.sqlalchemy._constants import (
4445
APPLICATION_NAME,
4546
SNOWFLAKE_SQLALCHEMY_VERSION,
@@ -1290,6 +1291,144 @@ def test_deterministic_merge_into(sql_compiler):
12901291
)
12911292

12921293

1294+
def test_unconditional_insert_all(sql_compiler):
1295+
meta = MetaData()
1296+
users1 = Table(
1297+
"users1",
1298+
meta,
1299+
Column("id", Integer, Sequence("user_id_seq"), primary_key=True),
1300+
Column("name", String),
1301+
Column("fullname", String),
1302+
Column("created_at", DateTime),
1303+
)
1304+
users2 = Table(
1305+
"users2",
1306+
meta,
1307+
Column("id", Integer, Sequence("user_id_seq2"), primary_key=True),
1308+
Column("name", String),
1309+
Column("full/name", String),
1310+
)
1311+
onboarding_users = Table(
1312+
"onboarding_users",
1313+
meta,
1314+
Column("id", Integer, Sequence("new_user_id_seq"), primary_key=True),
1315+
Column("name", String),
1316+
Column("fullname", String),
1317+
Column("delete", Boolean),
1318+
)
1319+
insert_all = (
1320+
InsertMulti(
1321+
select(
1322+
onboarding_users.c.id,
1323+
onboarding_users.c.name,
1324+
onboarding_users.c.fullname,
1325+
)
1326+
)
1327+
.into(users1)
1328+
.into(users2)
1329+
)
1330+
assert (
1331+
sql_compiler(insert_all) == "INSERT ALL INTO users1 INTO users2 "
1332+
"SELECT onboarding_users.id, onboarding_users.name, onboarding_users.fullname "
1333+
"FROM onboarding_users"
1334+
)
1335+
1336+
stmt = select(
1337+
onboarding_users.c.id,
1338+
onboarding_users.c.name,
1339+
onboarding_users.c.fullname,
1340+
onboarding_users.c.delete,
1341+
)
1342+
insert_all = (
1343+
InsertMulti(stmt)
1344+
.into(
1345+
users1,
1346+
["id", "name", users1.c.fullname, users1.c.created_at],
1347+
[
1348+
"id",
1349+
"name",
1350+
stmt.selected_columns.fullname,
1351+
func.now(),
1352+
],
1353+
)
1354+
.into(
1355+
users2,
1356+
[users2.c.name, users2.c["full/name"]],
1357+
[stmt.selected_columns.fullname, stmt.selected_columns.name],
1358+
)
1359+
)
1360+
assert (
1361+
sql_compiler(insert_all) == "INSERT ALL "
1362+
"INTO users1 (id, name, fullname, created_at) VALUES (id, name, fullname, CURRENT_TIMESTAMP) "
1363+
'INTO users2 (name, "full/name") VALUES (fullname, name) '
1364+
"SELECT onboarding_users.id, onboarding_users.name, onboarding_users.fullname, "
1365+
'onboarding_users."delete" FROM onboarding_users'
1366+
)
1367+
1368+
1369+
def test_conditional_insert_multi(sql_compiler):
1370+
meta = MetaData()
1371+
users1 = Table(
1372+
"users1",
1373+
meta,
1374+
Column("id", Integer, Sequence("user_id_seq"), primary_key=True),
1375+
Column("name", String),
1376+
Column("fullname", String),
1377+
)
1378+
users2 = Table(
1379+
"users2",
1380+
meta,
1381+
Column("id", Integer, Sequence("user_id_seq2"), primary_key=True),
1382+
Column("name", String),
1383+
Column("full/name", String),
1384+
)
1385+
onboarding_users = Table(
1386+
"onboarding_users",
1387+
meta,
1388+
Column("id", Integer, Sequence("new_user_id_seq"), primary_key=True),
1389+
Column("name", String),
1390+
Column("fullname", String),
1391+
Column("delete", Boolean),
1392+
)
1393+
stmt = select(
1394+
onboarding_users.c.id,
1395+
onboarding_users.c.name,
1396+
onboarding_users.c.fullname,
1397+
onboarding_users.c.delete,
1398+
)
1399+
insert_all = (
1400+
InsertMulti(stmt)
1401+
.when(
1402+
stmt.selected_columns.delete,
1403+
users1,
1404+
values=[
1405+
stmt.selected_columns.id,
1406+
stmt.selected_columns.name,
1407+
stmt.selected_columns.fullname,
1408+
],
1409+
)
1410+
.when(
1411+
~stmt.selected_columns.delete,
1412+
users2,
1413+
[users2.c.id, users2.c.name, users2.c["full/name"]],
1414+
[
1415+
stmt.selected_columns.id,
1416+
stmt.selected_columns.name,
1417+
stmt.selected_columns.fullname,
1418+
],
1419+
)
1420+
.else_(users1)
1421+
)
1422+
assert (
1423+
sql_compiler(insert_all) == "INSERT ALL "
1424+
'WHEN "delete" THEN INTO users1 VALUES (id, name, fullname) '
1425+
'WHEN NOT "delete" THEN INTO users2 (id, name, "full/name") VALUES (id, name, fullname) '
1426+
"ELSE users1 "
1427+
"SELECT onboarding_users.id, onboarding_users.name, onboarding_users.fullname, "
1428+
'onboarding_users."delete" FROM onboarding_users'
1429+
)
1430+
1431+
12931432
def test_comments(engine_testaccount):
12941433
"""Tests strictly reading column comment through SQLAlchemy"""
12951434
table_name = random_string(5, choices=string.ascii_uppercase)

0 commit comments

Comments
 (0)