Allow using SQL functions for default values. #20

Open · wants to merge 1 commit into base: main
72 changes: 72 additions & 0 deletions README.md
@@ -242,3 +242,75 @@ except: print("Delete succeeded!")
| sqlite3.dbapi2.OperationalError | apsw.Error | General error, OperationalError is now proxied to apsw.Error |
| sqlite3.dbapi2.OperationalError | apsw.SQLError | When an error is due to flawed SQL statements |
| sqlite3.ProgrammingError | apsw.ConnectionClosedError | Caused by an improperly closed database file |

## Handling of default values

Default values are handled as expected, including expression-based
default values:

``` python
db.execute("""
DROP TABLE IF EXISTS migrations;
CREATE TABLE IF NOT EXISTS migrations (
id INTEGER PRIMARY KEY,
name TEXT DEFAULT 'foo',
cexpr TEXT DEFAULT ('abra' || 'cadabra'),
rand INTEGER DEFAULT (random()),
unix_epoch FLOAT DEFAULT (unixepoch('subsec')),
json_array JSON DEFAULT (json_array(1,2,3,4)),
inserted_at DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL
);
""")
```

    <apsw.Cursor>

``` python
migrations = Table(db, 'migrations')
migrations.default_values
```

    {'name': 'foo',
     'cexpr': SQLExpr: 'abra' || 'cadabra',
     'rand': SQLExpr: random(),
     'unix_epoch': SQLExpr: unixepoch('subsec'),
     'json_array': SQLExpr: json_array(1,2,3,4),
     'inserted_at': SQLExpr: CURRENT_TIMESTAMP}

``` python
assert all([type(x) is SQLExpr for x in list(migrations.default_values.values())[1:]])
```

``` python
migrations.insert(dict(id=0))
migrations.insert(dict(id=1))
```

    <Table migrations (id, name, cexpr, rand, unix_epoch, json_array, inserted_at)>

Default expressions are evaluated independently for each inserted row:

``` python
rows = list(migrations.rows)
rows
```

    [{'id': 0,
      'name': 'foo',
      'cexpr': 'abracadabra',
      'rand': 8201569685582150332,
      'unix_epoch': 1741481111.188,
      'json_array': '[1,2,3,4]',
      'inserted_at': '2025-03-09 00:45:11'},
     {'id': 1,
      'name': 'foo',
      'cexpr': 'abracadabra',
      'rand': 1625289491289542947,
      'unix_epoch': 1741481111.19,
      'json_array': '[1,2,3,4]',
      'inserted_at': '2025-03-09 00:45:11'}]

``` python
assert rows[0]['rand'] != rows[1]['rand']
```
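
For a further check, here is a minimal sketch (reusing the `migrations` table above; not part of the original example) showing that an explicitly supplied value suppresses the column default, while omitted columns still receive their SQL-expression defaults:

``` python
migrations.insert(dict(id=2, name='bar'))          # 'name' supplied explicitly; other columns omitted
row = [r for r in migrations.rows if r['id'] == 2][0]
assert row['name'] == 'bar'                        # explicit value wins over DEFAULT 'foo'
assert row['cexpr'] == 'abracadabra'               # omitted column still gets its expression default
```
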
19 changes: 16 additions & 3 deletions apswutils/db.py
@@ -1,5 +1,5 @@
# This file is from sqlite-utils and copyright and license is the same as that project
-__all__ = ['Database', 'Queryable', 'Table', 'View']
+__all__ = ['Database', 'Queryable', 'Table', 'View', 'SQLExpr']

from .utils import chunks, hash_record, suggest_column_types, types_for_column_types, column_affinity, find_spatialite, cursor_row2dict
from collections import namedtuple
@@ -8,6 +8,7 @@
from functools import cache
import contextlib, datetime, decimal, inspect, itertools, json, os, pathlib, re, secrets, textwrap, binascii, uuid, logging
import apsw, apsw.ext, apsw.bestpractice
from fastcore.all import asdict

logger = logging.getLogger('apsw')
logger.setLevel(logging.ERROR)
@@ -3121,6 +3122,7 @@ def insert_all(
        num_records_processed = 0
        # Fix up any records with square braces in the column names
        records = fix_square_braces(records)
        records = remove_default_sql_exprs(records)
        # We can only handle a max of 999 variables in a SQL insert, so
        # we need to adjust the batch_size down if we have too many cols
        records = iter(records)
@@ -3715,9 +3717,20 @@ def fix_square_braces(records: Iterable[Dict[str, Any]]):
        else:
            yield record

def remove_default_sql_exprs(records: Iterable[Dict[str, Any]]):
    for record in records:
        yield {k: v for k, v in asdict(record).items() if type(v) is not SQLExpr or not v.default}

class SQLExpr():
    def __init__(self, expr, default=False): self.expr, self.default = expr, default
    def __str__(self): return f'SQLExpr: {self.expr}'
    __repr__ = __str__

# Match anything that is not a single quote, then match anything that is an escaped single quote
# (any number of times), then repeat the whole process
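# e.g. "'foo'" and "'abra cadabra'" match (plain string literals); "random()" and "CURRENT_TIMESTAMP" do not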
_sql_string_datatype_matcher = re.compile(r"^'([^']*(\\')*)*'$")
def _decode_default_value(value):
-    if value.startswith("'") and value.endswith("'"):
+    if _sql_string_datatype_matcher.match(value):
        # It's a string
        return value[1:-1]
    if value.isdigit():
@@ -3732,4 +3745,4 @@ def _decode_default_value(value):
        return float(value)
    except ValueError:
        pass
-    return value
+    return SQLExpr(value, True)
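
For reference, a hedged sketch (using only the `SQLExpr` class and the `_decode_default_value` helper shown in this diff, and assuming the `apswutils.db` module path) of how schema default values get decoded:

``` python
from apswutils.db import SQLExpr
from apswutils.db import _decode_default_value  # private helper; imported here only for illustration

assert _decode_default_value("'foo'") == 'foo'   # quoted literal decodes to a plain string
val = _decode_default_value('random()')          # non-literal defaults stay as SQL expressions
assert type(val) is SQLExpr and val.default      # ...and are flagged as table defaults
assert str(val) == 'SQLExpr: random()'           # the repr shown in default_values above
```

Defaults flagged this way are then dropped from incoming records by `remove_default_sql_exprs`, so SQLite itself evaluates the default expression on insert.
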
153 changes: 152 additions & 1 deletion nbs/index.ipynb
@@ -577,6 +577,157 @@
"|sqlite3.dbapi2.OperationalError|apsw.SQLError|When an error is due to flawed SQL statements|\n",
"|sqlite3.ProgrammingError|apsw.ConnectionClosedError|Caused by an improperly closed database file|\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Handling of default values"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Default values are handled as expected, including expression-based default values:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<apsw.Cursor>"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db.execute(\"\"\"\n",
"DROP TABLE IF EXISTS migrations;\n",
"CREATE TABLE IF NOT EXISTS migrations (\n",
" id INTEGER PRIMARY KEY,\n",
" name TEXT DEFAULT 'foo',\n",
" cexpr TEXT DEFAULT ('abra' || 'cadabra'),\n",
" rand INTEGER DEFAULT (random()),\n",
" unix_epoch FLOAT DEFAULT (unixepoch('subsec')),\n",
" json_array JSON DEFAULT (json_array(1,2,3,4)),\n",
" inserted_at DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL\n",
");\n",
"\"\"\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'name': 'foo',\n",
" 'cexpr': SQLExpr: 'abra' || 'cadabra',\n",
" 'rand': SQLExpr: random(),\n",
" 'unix_epoch': SQLExpr: unixepoch('subsec'),\n",
" 'json_array': SQLExpr: json_array(1,2,3,4),\n",
" 'inserted_at': SQLExpr: CURRENT_TIMESTAMP}"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"migrations = Table(db, 'migrations')\n",
"migrations.default_values"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"assert all([type(x) is SQLExpr for x in list(migrations.default_values.values())[1:]])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Table migrations (id, name, cexpr, rand, unix_epoch, json_array, inserted_at)>"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"migrations.insert(dict(id=0))\n",
"migrations.insert(dict(id=1))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Default expressions are executed independently for each row on row insertion:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[{'id': 0,\n",
" 'name': 'foo',\n",
" 'cexpr': 'abracadabra',\n",
" 'rand': 8201569685582150332,\n",
" 'unix_epoch': 1741481111.188,\n",
" 'json_array': '[1,2,3,4]',\n",
" 'inserted_at': '2025-03-09 00:45:11'},\n",
" {'id': 1,\n",
" 'name': 'foo',\n",
" 'cexpr': 'abracadabra',\n",
" 'rand': 1625289491289542947,\n",
" 'unix_epoch': 1741481111.19,\n",
" 'json_array': '[1,2,3,4]',\n",
" 'inserted_at': '2025-03-09 00:45:11'}]"
]
},
"execution_count": null,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"rows = list(migrations.rows)\n",
"rows"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"assert rows[0]['rand'] != rows[1]['rand']"
]
}
],
"metadata": {
@@ -587,5 +738,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}