
Commit 89af1ee

Merge pull request #639 from dimitri-yatsenko/custom-attribute-type
Custom attribute type: implement #627
2 parents fc21b5c + bb2e5fe commit 89af1ee

10 files changed: +315 −68 lines

datajoint/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -23,7 +23,7 @@
     'Manual', 'Lookup', 'Imported', 'Computed', 'Part',
     'Not', 'AndList', 'U', 'Diagram', 'Di', 'ERD',
     'set_password', 'kill',
-    'MatCell', 'MatStruct',
+    'MatCell', 'MatStruct', 'AttributeAdapter',
     'errors', 'DataJointError', 'key']

 from .version import __version__
@@ -38,6 +38,7 @@
 from .admin import set_password, kill
 from .blob import MatCell, MatStruct
 from .fetch import key
+from .attribute_adapter import AttributeAdapter
 from . import errors
 from .errors import DataJointError

datajoint/attribute_adapter.py

Lines changed: 73 additions & 0 deletions

@@ -0,0 +1,73 @@ (new file)

import re
import os
from .errors import DataJointError


ADAPTED_TYPE_SWITCH = "DJ_SUPPORT_ADAPTED_TYPES"


def _switch_adapated_types(on):
    """
    Enable (on=True) or disable (on=False) support for AttributeAdapter
    """
    if on:
        os.environ[ADAPTED_TYPE_SWITCH] = "TRUE"
    else:
        del os.environ[ADAPTED_TYPE_SWITCH]


def _support_adapted_types():
    """
    check if support for AttributeAdapter is enabled
    """
    return os.getenv(ADAPTED_TYPE_SWITCH, "FALSE").upper() == "TRUE"


class AttributeAdapter:
    """
    Base class for adapter objects for user-defined attribute types.
    """
    @property
    def attribute_type(self):
        """
        :return: a supported DataJoint attribute type to use; e.g. "longblob", "blob@store"
        """
        raise NotImplementedError('Undefined attribute adapter')

    def get(self, value):
        """
        convert a value retrieved from the attribute in a table into the adapted type
        :param value: value from the database
        :return: object of the adapted type
        """
        raise NotImplementedError('Undefined attribute adapter')

    def put(self, obj):
        """
        convert an object of the adapted type into a value that DataJoint can store in a table attribute
        :param obj: an object of the adapted type
        :return: value to store in the database
        """
        raise NotImplementedError('Undefined attribute adapter')


def get_adapter(context, adapter_name):
    """
    Extract the AttributeAdapter object by its name from the context and validate.
    """
    if not _support_adapted_types():
        raise DataJointError('Support for Adapted Attribute types is disabled.')
    adapter_name = adapter_name.lstrip('<').rstrip('>')
    try:
        adapter = context[adapter_name]
    except KeyError:
        raise DataJointError(
            "Attribute adapter '{adapter_name}' is not defined.".format(adapter_name=adapter_name)) from None
    if not isinstance(adapter, AttributeAdapter):
        raise DataJointError(
            "Attribute adapter '{adapter_name}' must be an instance of datajoint.AttributeAdapter".format(
                adapter_name=adapter_name))
    if not isinstance(adapter.attribute_type, str) or not re.match(r'^\w', adapter.attribute_type):
        raise DataJointError("Invalid attribute type {type} in attribute adapter '{adapter_name}'".format(
            type=adapter.attribute_type, adapter_name=adapter_name))
    return adapter
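For orientation, here is a minimal usage sketch of the adapter API introduced above. The GraphAdapter class, the adapter name "graph", the networkx dependency, the schema name, and the Connectivity table are illustrative assumptions, not part of this commit. Support must be switched on through the DJ_SUPPORT_ADAPTED_TYPES environment variable, the adapter instance must be reachable by name from the schema's declaration context, and the adapted type is referenced in the table definition in angle brackets.

import os
import datajoint as dj

os.environ['DJ_SUPPORT_ADAPTED_TYPES'] = 'TRUE'   # feature flag checked by get_adapter()


class GraphAdapter(dj.AttributeAdapter):
    # hypothetical adapter: store a networkx graph as an edge list in a longblob
    attribute_type = 'longblob'   # underlying DataJoint type used for storage

    def put(self, obj):
        # convert the adapted object into a storable value
        return list(obj.edges)

    def get(self, value):
        # rebuild the adapted object from the stored value
        import networkx as nx     # assumed third-party dependency for this example
        return nx.Graph(value)


graph = GraphAdapter()   # the instance name is what appears inside <...> in definitions

schema = dj.schema('test_adapted_types')   # assumed schema name


@schema
class Connectivity(dj.Manual):
    definition = """
    conn_id : int
    ---
    conn_graph = null : <graph>   # adapted attribute resolved via the 'graph' adapter
    """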

datajoint/declare.py

Lines changed: 53 additions & 31 deletions

@@ -6,6 +6,7 @@
 import pyparsing as pp
 import logging
 from .errors import DataJointError
+from .attribute_adapter import get_adapter

 from .utils import OrderedDict

@@ -27,21 +28,24 @@
     INTERNAL_ATTACH=r'attach$',
     EXTERNAL_ATTACH=r'attach@(?P<store>[a-z]\w*)$',
     FILEPATH=r'filepath@(?P<store>[a-z]\w*)$',
-    UUID=r'uuid$').items()}
+    UUID=r'uuid$',
+    ADAPTED=r'<.+>$'
+    ).items()}

-CUSTOM_TYPES = {'UUID', 'INTERNAL_ATTACH', 'EXTERNAL_ATTACH', 'EXTERNAL_BLOB', 'FILEPATH'}  # types stored in attribute comment
+# custom types are stored in attribute comment
+SPECIAL_TYPES = {'UUID', 'INTERNAL_ATTACH', 'EXTERNAL_ATTACH', 'EXTERNAL_BLOB', 'FILEPATH', 'ADAPTED'}
+NATIVE_TYPES = set(TYPE_PATTERN) - SPECIAL_TYPES
 EXTERNAL_TYPES = {'EXTERNAL_ATTACH', 'EXTERNAL_BLOB', 'FILEPATH'}  # data referenced by a UUID in external tables
 SERIALIZED_TYPES = {'EXTERNAL_ATTACH', 'INTERNAL_ATTACH', 'EXTERNAL_BLOB', 'INTERNAL_BLOB'}  # requires packing data

-assert set().union(CUSTOM_TYPES, EXTERNAL_TYPES, SERIALIZED_TYPES) <= set(TYPE_PATTERN)
+assert set().union(SPECIAL_TYPES, EXTERNAL_TYPES, SERIALIZED_TYPES) <= set(TYPE_PATTERN)


-def match_type(datatype):
-    for category, pattern in TYPE_PATTERN.items():
-        match = pattern.match(datatype)
-        if match:
-            return category, match
-    raise DataJointError('Unsupported data types "%s"' % datatype)
+def match_type(attribute_type):
+    try:
+        return next(category for category, pattern in TYPE_PATTERN.items() if pattern.match(attribute_type))
+    except StopIteration:
+        raise DataJointError("Unsupported attribute type {type}".format(type=attribute_type)) from None


 logger = logging.getLogger(__name__)

@@ -78,7 +82,8 @@ def build_attribute_parser():
     quoted = pp.QuotedString('"') ^ pp.QuotedString("'")
     colon = pp.Literal(':').suppress()
     attribute_name = pp.Word(pp.srange('[a-z]'), pp.srange('[a-z0-9_]')).setResultsName('name')
-    data_type = pp.Combine(pp.Word(pp.alphas) + pp.SkipTo("#", ignore=quoted)).setResultsName('type')
+    data_type = (pp.Combine(pp.Word(pp.alphas) + pp.SkipTo("#", ignore=quoted))
+                 ^ pp.QuotedString('<', endQuoteChar='>', unquoteResults=False)).setResultsName('type')
     default = pp.Literal('=').suppress() + pp.SkipTo(colon, ignore=quoted).setResultsName('default')
     comment = pp.Literal('#').suppress() + pp.restOfLine.setResultsName('comment')
     return attribute_name + pp.Optional(default) + colon + data_type + comment

@@ -168,8 +173,7 @@ def compile_foreign_key(line, context, attributes, primary_key, attr_sql, foreig
         raise DataJointError('Invalid foreign key attributes in "%s"' % line)
     try:
         raise DataJointError('Duplicate attributes "{attr}" in "{line}"'.format(
-            attr=next(attr for attr in result.new_attrs if attr in attributes),
-            line=line))
+            attr=next(attr for attr in result.new_attrs if attr in attributes), line=line))
     except StopIteration:
         pass  # the normal outcome

@@ -246,7 +250,7 @@ def prepare_declare(definition, context):
         elif re.match(r'^(unique\s+)?index[^:]*$', line, re.I):  # index
             compile_index(line, index_sql)
         else:
-            name, sql, store = compile_attribute(line, in_key, foreign_key_sql)
+            name, sql, store = compile_attribute(line, in_key, foreign_key_sql, context)
             if store:
                 external_stores.append(store)
             if in_key and name not in primary_key:

@@ -292,10 +296,9 @@ def _make_attribute_alter(new, old, primary_key):
     :param primary_key: primary key attributes
     :return: list of SQL ALTER commands
     """
-
     # parse attribute names
     name_regexp = re.compile(r"^`(?P<name>\w+)`")
-    original_regexp = re.compile(r'COMMENT "\{\s*(?P<name>\w+)\s*\}')
+    original_regexp = re.compile(r'COMMENT "{\s*(?P<name>\w+)\s*}')
     matched = ((name_regexp.match(d), original_regexp.search(d)) for d in new)
     new_names = OrderedDict((d.group('name'), n and n.group('name')) for d, n in matched)
     old_names = [name_regexp.search(d).group('name') for d in old]

@@ -380,13 +383,41 @@ def compile_index(line, index_sql):
         attrs=','.join('`%s`' % a for a in match.attr_list)))


-def compile_attribute(line, in_key, foreign_key_sql):
+def substitute_special_type(match, category, foreign_key_sql, context):
     """
-    Convert attribute definition from DataJoint format to SQL
+    :param match: dict with keys "type" and "comment" -- will be modified in place
+    :param category: attribute type category from TYPE_PATTERN
+    :param foreign_key_sql: list of foreign key declarations to add to
+    :param context: context for looking up user-defined attribute_type adapters
+    """
+    if category == 'UUID':
+        match['type'] = UUID_DATA_TYPE
+    elif category == 'INTERNAL_ATTACH':
+        match['type'] = 'LONGBLOB'
+    elif category in EXTERNAL_TYPES:
+        match['store'] = match['type'].split('@', 1)[1]
+        match['type'] = UUID_DATA_TYPE
+        foreign_key_sql.append(
+            "FOREIGN KEY (`{name}`) REFERENCES `{{database}}`.`{external_table_root}_{store}` (`hash`) "
+            "ON UPDATE RESTRICT ON DELETE RESTRICT".format(external_table_root=EXTERNAL_TABLE_ROOT, **match))
+    elif category == 'ADAPTED':
+        adapter = get_adapter(context, match['type'])
+        match['type'] = adapter.attribute_type
+        category = match_type(match['type'])
+        if category in SPECIAL_TYPES:
+            # recursive redefinition from user-defined datatypes
+            substitute_special_type(match, category, foreign_key_sql, context)
+    else:
+        assert False, 'Unknown special type'
+

+def compile_attribute(line, in_key, foreign_key_sql, context):
+    """
+    Convert attribute definition from DataJoint format to SQL
     :param line: attribution line
     :param in_key: set to True if attribute is in primary key set
-    :param foreign_key_sql:
+    :param foreign_key_sql: the list of foreign key declarations to add to
+    :param context: context in which to look up user-defined attribute type adapters
     :returns: (name, sql, is_external) -- attribute name and sql code for its declaration
     """
     try:

@@ -412,27 +443,18 @@ def compile_attribute(line, in_key, foreign_key_sql):
         match['default'] = 'NOT NULL'

     match['comment'] = match['comment'].replace('"', '\\"')  # escape double quotes in comment
-    category, type_match = match_type(match['type'])

     if match['comment'].startswith(':'):
         raise DataJointError('An attribute comment must not start with a colon in comment "{comment}"'.format(**match))

-    if category in CUSTOM_TYPES:
+    category = match_type(match['type'])
+    if category in SPECIAL_TYPES:
         match['comment'] = ':{type}:{comment}'.format(**match)  # insert custom type into comment
-        if category == 'UUID':
-            match['type'] = UUID_DATA_TYPE
-        elif category == 'INTERNAL_ATTACH':
-            match['type'] = 'LONGBLOB'
-        elif category in EXTERNAL_TYPES:
-            match['store'] = match['type'].split('@', 1)[1]
-            match['type'] = UUID_DATA_TYPE
-            foreign_key_sql.append(
-                "FOREIGN KEY (`{name}`) REFERENCES `{{database}}`.`{external_table_root}_{store}` (`hash`) "
-                "ON UPDATE RESTRICT ON DELETE RESTRICT".format(external_table_root=EXTERNAL_TABLE_ROOT, **match))
+        substitute_special_type(match, category, foreign_key_sql, context)

     if category in SERIALIZED_TYPES and match['default'] not in {'DEFAULT NULL', 'NOT NULL'}:
         raise DataJointError(
-            'The default value for a blob or attachment attributes can only be NULL in:\n%s' % line)
+            'The default value for a blob or attachment attributes can only be NULL in:\n{line}'.format(line=line))

     sql = ('`{name}` {type} {default}' + (' COMMENT "{comment}"' if match['comment'] else '')).format(**match)
     return match['name'], sql, match.get('store')
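To make the new substitution path concrete, here is a rough walk-through of what compile_attribute now does for an adapted declaration such as "conn_graph = null : <graph>  # adapted attribute". The graph adapter, the hand-built match dict, and the resulting SQL are assumptions carried over from the sketch earlier; in real use, match comes from the pyparsing attribute parser.

from datajoint.declare import match_type, substitute_special_type

match = {'name': 'conn_graph', 'type': '<graph>', 'comment': 'adapted attribute', 'default': 'DEFAULT NULL'}
foreign_key_sql = []

category = match_type(match['type'])                     # 'ADAPTED' -- '<graph>' matches r'<.+>$'
match['comment'] = ':{type}:{comment}'.format(**match)   # comment becomes ':<graph>:adapted attribute'
substitute_special_type(match, category, foreign_key_sql, context={'graph': graph})
# match['type'] is now the adapter's attribute_type ('longblob'), so the emitted column is roughly:
#   `conn_graph` longblob DEFAULT NULL COMMENT ":<graph>:adapted attribute"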

datajoint/fetch.py

Lines changed: 11 additions & 8 deletions

@@ -43,35 +43,38 @@ def _get(connection, attr, data, squeeze, download_path):
     """
     if data is None:
         return
+
     extern = connection.schemas[attr.database].external[attr.store] if attr.is_external else None

+    adapt = attr.adapter.get if attr.adapter else lambda x: x
+
     if attr.is_filepath:
-        return extern.fget(uuid.UUID(bytes=data))[0]
+        return adapt(extern.fget(uuid.UUID(bytes=data))[0])

     if attr.is_attachment:
-        # Steps:
+        # Steps:
         # 1. peek the filename from the blob without downloading remote
         # 2. check if the file already exists at download_path, verify checksum
         # 3. if exists and checksum passes then return the local filepath
         # 4. Otherwise, download the remote file and return the new filepath
         peek, size = extern.peek(uuid.UUID(bytes=data)) if attr.is_external else (data, len(data))
-        assert size is not None
+        assert size is not None
         filename = peek.split(b"\0", 1)[0].decode()
         size -= len(filename) + 1
         filepath = os.path.join(download_path, filename)
         if os.path.isfile(filepath) and size == os.path.getsize(filepath):
             local_checksum = hash.uuid_from_file(filepath, filename + '\0')
             remote_checksum = uuid.UUID(bytes=data) if attr.is_external else hash.uuid_from_buffer(data)
-            if local_checksum == remote_checksum:
-                return filepath  # the existing file is okay
+            if local_checksum == remote_checksum:
+                return adapt(filepath)  # the existing file is okay
         # Download remote attachment
         if attr.is_external:
             data = extern.get(uuid.UUID(bytes=data))
-        return attach.save(data, download_path)  # download file from remote store
+        return adapt(attach.save(data, download_path))  # download file from remote store

-    return uuid.UUID(bytes=data) if attr.uuid else (
+    return adapt(uuid.UUID(bytes=data) if attr.uuid else (
         blob.unpack(extern.get(uuid.UUID(bytes=data)) if attr.is_external else data, squeeze=squeeze)
-        if attr.is_blob else data)
+        if attr.is_blob else data))


 def _flatten_attribute_list(primary_key, attrs):
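In other words, on fetch the attribute is first materialized as its underlying type (blob unpacked, attachment or filepath resolved, UUID decoded) and the result is then passed through adapter.get(). Continuing the hypothetical GraphAdapter example, a round trip would look roughly like this; the insert-side conversion via put() happens in table code that is not part of this excerpt.

import networkx as nx

g = nx.lollipop_graph(4, 2)
Connectivity.insert1({'conn_id': 0, 'conn_graph': g})             # stored via GraphAdapter.put() as an edge list
fetched = (Connectivity & {'conn_id': 0}).fetch1('conn_graph')    # _get() unpacks the blob, then applies GraphAdapter.get()
assert isinstance(fetched, nx.Graph)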

datajoint/heading.py

Lines changed: 36 additions & 9 deletions

@@ -4,16 +4,17 @@
 import re
 import logging
 from .errors import DataJointError
-from .declare import UUID_DATA_TYPE, CUSTOM_TYPES, TYPE_PATTERN, EXTERNAL_TYPES
+from .declare import UUID_DATA_TYPE, SPECIAL_TYPES, TYPE_PATTERN, EXTERNAL_TYPES, NATIVE_TYPES
 from .utils import OrderedDict
+from .attribute_adapter import get_adapter, AttributeAdapter


 logger = logging.getLogger(__name__)

 default_attribute_properties = dict(  # these default values are set in computed attributes
     name=None, type='expression', in_key=False, nullable=False, default=None, comment='calculated attribute',
     autoincrement=False, numeric=None, string=None, uuid=False, is_blob=False, is_attachment=False, is_filepath=False,
-    is_external=False,
+    is_external=False, adapter=None,
     store=None, unsupported=False, sql_expression=None, database=None, dtype=object)

@@ -146,7 +147,7 @@ def as_sql(self):
     def __iter__(self):
         return iter(self.attributes)

-    def init_from_database(self, conn, database, table_name):
+    def init_from_database(self, conn, database, table_name, context):
         """
         initialize heading from a database table. The table must exist already.
         """

@@ -211,32 +212,53 @@ def init_from_database(self, conn, database, table_name):
                 numeric=any(TYPE_PATTERN[t].match(attr['type']) for t in ('DECIMAL', 'INTEGER', 'FLOAT')),
                 string=any(TYPE_PATTERN[t].match(attr['type']) for t in ('ENUM', 'TEMPORAL', 'STRING')),
                 is_blob=bool(TYPE_PATTERN['INTERNAL_BLOB'].match(attr['type'])),
-                uuid=False, is_attachment=False, is_filepath=False, store=None, is_external=False, sql_expression=None)
+                uuid=False, is_attachment=False, is_filepath=False, adapter=None,
+                store=None, is_external=False, sql_expression=None)

             if any(TYPE_PATTERN[t].match(attr['type']) for t in ('INTEGER', 'FLOAT')):
                 attr['type'] = re.sub(r'\(\d+\)', '', attr['type'], count=1)  # strip size off integers and floats
             attr['unsupported'] = not any((attr['is_blob'], attr['numeric'], attr['numeric']))
             attr.pop('Extra')

             # process custom DataJoint types
-            custom_type = re.match(r':(?P<type>[^:]+):(?P<comment>.*)', attr['comment'])
-            if custom_type:
-                attr.update(custom_type.groupdict(), unsupported=False)
+            special = re.match(r':(?P<type>[^:]+):(?P<comment>.*)', attr['comment'])
+            if special:
+                special = special.groupdict()
+                attr.update(special)
+            # process adapted attribute types
+            if special and TYPE_PATTERN['ADAPTED'].match(attr['type']):
+                assert context is not None, 'Declaration context is not set'
+                adapter_name = special['type']
                 try:
-                    category = next(c for c in CUSTOM_TYPES if TYPE_PATTERN[c].match(attr['type']))
+                    attr.update(adapter=get_adapter(context, adapter_name))
+                except DataJointError:
+                    # if no adapter, then delay the error until the first invocation
+                    attr.update(adapter=AttributeAdapter())
+                else:
+                    attr.update(type=attr['adapter'].attribute_type)
+                    if not any(r.match(attr['type']) for r in TYPE_PATTERN.values()):
+                        raise DataJointError(
+                            "Invalid attribute type '{type}' in adapter object <{adapter_name}>.".format(
+                                adapter_name=adapter_name, **attr))
+                    special = not any(TYPE_PATTERN[c].match(attr['type']) for c in NATIVE_TYPES)
+
+            if special:
+                try:
+                    category = next(c for c in SPECIAL_TYPES if TYPE_PATTERN[c].match(attr['type']))
                 except StopIteration:
                     if attr['type'].startswith('external'):
                         raise DataJointError('Legacy datatype `{type}`.'.format(**attr)) from None
                     raise DataJointError('Unknown attribute type `{type}`'.format(**attr)) from None
                 attr.update(
+                    unsupported=False,
                     is_attachment=category in ('INTERNAL_ATTACH', 'EXTERNAL_ATTACH'),
                     is_filepath=category == 'FILEPATH',
                     is_blob=category in ('INTERNAL_BLOB', 'EXTERNAL_BLOB'),  # INTERNAL_BLOB is not a custom type but is included for completeness
                     uuid=category == 'UUID',
                     is_external=category in EXTERNAL_TYPES,
                     store=attr['type'].split('@')[1] if category in EXTERNAL_TYPES else None)

-            if attr['in_key'] and (attr['is_blob'] or attr['is_attachment'] or attr['is_filepath']):
+            if attr['in_key'] and any((attr['is_blob'], attr['is_attachment'], attr['is_filepath'])):
                 raise DataJointError('Blob, attachment, or filepath attributes are not allowed in the primary key')

             if attr['string'] and attr['default'] is not None and attr['default'] not in sql_literals:

@@ -256,6 +278,11 @@ def init_from_database(self, conn, database, table_name):
             t = re.sub(r' unsigned$', '', t)  # remove unsigned
             assert (t, is_unsigned) in numeric_types, 'dtype not found for type %s' % t
             attr['dtype'] = numeric_types[(t, is_unsigned)]
+
+            if attr['adapter']:
+                # restore adapted type name
+                attr['type'] = adapter_name
+
         self.attributes = OrderedDict(((q['name'], Attribute(**q)) for q in attributes))

         # Read and tabulate secondary indexes
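For reference, a small sketch of how the adapted type survives the round trip through the column comment; the example comment string is an assumption matching the earlier sketch. The declared type is stored in MySQL as a ":<name>:" prefix on the column comment, parsed back by the regex shown above, resolved to an adapter from the declaration context, and finally restored as the displayed attribute type.

import re

# e.g. the declared column was stored as:
#   `conn_graph` longblob DEFAULT NULL COMMENT ":<graph>:adapted attribute"
comment = ':<graph>:adapted attribute'
special = re.match(r':(?P<type>[^:]+):(?P<comment>.*)', comment).groupdict()
assert special == {'type': '<graph>', 'comment': 'adapted attribute'}
# init_from_database() then calls get_adapter(context, '<graph>'), uses the adapter's
# attribute_type ('longblob') for type checks, and at the end restores attr['type']
# to the adapted type name so the heading displays '<graph>' rather than 'longblob'.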
