Skip to content

Commit 10f7393

Browse files
simplify context management, move test initialization into setup_class from __init__ in nosetests
1 parent 06d5cc3 commit 10f7393

15 files changed

+82
-76
lines changed

datajoint/declare.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ def declare(full_table_name, definition, context):
194194
195195
:param full_table_name: full name of the table
196196
:param definition: DataJoint table definition
197-
:param context: dictionary of objects that might be referred to in the table. Usually this will be locals()
197+
:param context: dictionary of objects that might be referred to in the table.
198198
"""
199199
# split definition into lines
200200
definition = re.split(r'\s*\n\s*', definition.strip())

datajoint/schema.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,15 +218,14 @@ def process_relation_class(self, relation_class, context, assert_declared=False)
218218
relation_class.database = self.database
219219
relation_class._connection = self.connection
220220
relation_class._heading = Heading()
221-
relation_class._context = context
222221
# instantiate the class, declare the table if not already
223222
instance = relation_class()
224223
is_declared = instance.is_declared
225224
if not is_declared:
226225
if not self.create_tables or assert_declared:
227226
raise DataJointError('Table not declared %s' % instance.table_name)
228227
else:
229-
instance.declare()
228+
instance.declare(context)
230229
is_declared = is_declared or instance.is_declared
231230

232231
# fill values in Lookup tables from their contents property

datajoint/table.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,10 @@ class Table(QueryExpression):
2828
"""
2929
Table is an abstract class that represents a base relation, i.e. a table in the schema.
3030
To make it a concrete class, override the abstract properties specifying the connection,
31-
table name, database, context, and definition.
31+
table name, database, and definition.
3232
A Relation implements insert and delete methods in addition to inherited relational operators.
3333
"""
3434
_heading = None
35-
_context = None
3635
database = None
3736
_log_ = None
3837
_external_table = None
@@ -55,16 +54,12 @@ def heading(self):
5554
self._heading.init_from_database(self.connection, self.database, self.table_name)
5655
return self._heading
5756

58-
@property
59-
def context(self):
60-
return self._context
61-
62-
def declare(self):
57+
def declare(self, context=None):
6358
"""
6459
Use self.definition to declare the table in the schema.
6560
"""
6661
try:
67-
sql, uses_external = declare(self.full_table_name, self.definition, self._context)
62+
sql, uses_external = declare(self.full_table_name, self.definition, context)
6863
if uses_external:
6964
# trigger the creation of the external hash lookup for the current schema
7065
external_table = self.connection.schemas[self.database].external_table

datajoint/user_tables.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ class UserTable(Table, metaclass=OrderedClass):
6565
UserTable is initialized by the decorator generated by schema().
6666
"""
6767
_connection = None
68-
_context = None
6968
_heading = None
7069
tier_regexp = None
7170
_prefix = None
@@ -143,7 +142,6 @@ class Part(UserTable):
143142
"""
144143

145144
_connection = None
146-
_context = None
147145
_heading = None
148146
_master = None
149147

tests/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"""
77

88
import logging
9-
from os import environ
9+
from os import environ, remove
1010
import datajoint as dj
1111

1212
__author__ = 'Edgar Walker, Fabian Sinz, Dimitri Yatsenko'
@@ -47,3 +47,4 @@ def teardown_package():
4747
for db in cur.fetchall():
4848
conn.query('DROP DATABASE `{}`'.format(db[0]))
4949
conn.query('SET FOREIGN_KEY_CHECKS=1')
50+
remove("dj_local_conf.json")

tests/schema.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313

1414
@schema
15-
class Test(dj.Lookup):
15+
class TTest(dj.Lookup):
1616
definition = """
1717
key : int # key
1818
---
@@ -22,7 +22,7 @@ class Test(dj.Lookup):
2222

2323

2424
@schema
25-
class Test2(dj.Manual):
25+
class TTest2(dj.Manual):
2626
definition = """
2727
key : int # key
2828
---
@@ -31,7 +31,7 @@ class Test2(dj.Manual):
3131

3232

3333
@schema
34-
class Test3(dj.Manual):
34+
class TTest3(dj.Manual):
3535
definition = """
3636
key : int
3737
---
@@ -40,19 +40,19 @@ class Test3(dj.Manual):
4040

4141

4242
@schema
43-
class TestExtra(dj.Manual):
43+
class TTestExtra(dj.Manual):
4444
"""
4545
clone of Test but with an extra field
4646
"""
47-
definition = Test.definition + "\nextra : int # extra int\n"
47+
definition = TTest.definition + "\nextra : int # extra int\n"
4848

4949

5050
@schema
51-
class TestNoExtra(dj.Manual):
51+
class TTestNoExtra(dj.Manual):
5252
"""
5353
clone of Test but with no extra fields
5454
"""
55-
definition = Test.definition
55+
definition = TTest.definition
5656

5757

5858
@schema

tests/schema_simple.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ class DataB(dj.Lookup):
143143

144144

145145
@schema
146-
class TestUpdate(dj.Lookup):
146+
class TTestUpdate(dj.Lookup):
147147
definition = """
148148
primary_key : int
149149
---

tests/test_blob2.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
2+
import os
23
import datajoint as dj
3-
from nose.tools import assert_equal, assert_true, assert_list_equal, assert_tuple_equal
4+
from nose.tools import assert_equal, assert_true, assert_list_equal, assert_tuple_equal, assert_false
45

56
from . import PREFIX, CONN_INFO
67

@@ -49,7 +50,10 @@ def insert_blobs():
4950

5051

5152
class TestFetch:
52-
def __init__(self):
53+
54+
@classmethod
55+
def setup_class(cls):
56+
assert_false(dj.config['safemode'], 'safemode must be disabled')
5357
Blob().delete()
5458
insert_blobs()
5559

tests/test_connection.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,10 @@ class Subjects(dj.Manual):
5555
species = "mouse" : enum('mouse', 'monkey', 'human') # species
5656
"""
5757

58-
def __init__(self):
59-
self.relation = self.Subjects()
60-
self.conn = dj.conn(**CONN_INFO)
58+
@classmethod
59+
def setup_class(cls):
60+
cls.relation = cls.Subjects()
61+
cls.conn = dj.conn(**CONN_INFO)
6162

6263
def teardown(self):
6364
self.relation.delete_quick()

tests/test_fetch.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010

1111
class TestFetch:
1212

13-
def __init__(self):
14-
self.subject = schema.Subject()
15-
self.lang = schema.Language()
13+
@classmethod
14+
def setup_class(cls):
15+
cls.subject = schema.Subject()
16+
cls.lang = schema.Language()
1617

1718
def test_getattribute(self):
1819
"""Testing Fetch.__call__ with attributes"""

tests/test_nan.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,14 @@ class NanTest(dj.Manual):
1616

1717

1818
class TestNaNInsert:
19-
def __init__(self):
20-
self.rel = NanTest()
19+
@classmethod
20+
def setup_class(cls):
21+
cls.rel = NanTest()
2122
with dj.config(safemode=False):
22-
self.rel.delete()
23+
cls.rel.delete()
2324
a = np.array([0, 1/3, np.nan, np.pi, np.nan])
24-
self.rel.insert(((i, value) for i, value in enumerate(a)))
25-
self.a = a
25+
cls.rel.insert(((i, value) for i, value in enumerate(a)))
26+
cls.a = a
2627

2728
def test_insert_nan(self):
2829
"""Test fetching of null values"""

tests/test_privileges.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88

99
class TestUnprivileged:
1010

11-
def __init__(self):
11+
@classmethod
12+
def setup_class(cls):
1213
"""A connection with only SELECT privilege to djtest schemas"""
13-
self.connection = dj.Connection(host=environ.get('DJ_TEST_HOST', 'localhost'), user='djview', password='djview')
14+
cls.connection = dj.Connection(host=environ.get('DJ_TEST_HOST', 'localhost'), user='djview', password='djview')
1415

1516
@raises(dj.DataJointError)
1617
def test_fail_create_schema(self):

tests/test_relation.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,24 +24,24 @@ class TestRelation:
2424
Test base relations: insert, delete
2525
"""
2626

27-
def __init__(self):
28-
self.test = schema.Test()
29-
self.test_extra = schema.TestExtra()
30-
self.test_no_extra = schema.TestNoExtra()
31-
self.user = schema.User()
32-
self.subject = schema.Subject()
33-
self.experiment = schema.Experiment()
34-
self.trial = schema.Trial()
35-
self.ephys = schema.Ephys()
36-
self.channel = schema.Ephys.Channel()
37-
self.img = schema.Image()
38-
self.trash = schema.UberTrash()
27+
@classmethod
28+
def setup_class(cls):
29+
cls.test = schema.TTest()
30+
cls.test_extra = schema.TTestExtra()
31+
cls.test_no_extra = schema.TTestNoExtra()
32+
cls.user = schema.User()
33+
cls.subject = schema.Subject()
34+
cls.experiment = schema.Experiment()
35+
cls.trial = schema.Trial()
36+
cls.ephys = schema.Ephys()
37+
cls.channel = schema.Ephys.Channel()
38+
cls.img = schema.Image()
39+
cls.trash = schema.UberTrash()
3940

4041
def test_contents(self):
4142
"""
4243
test the ability of tables to self-populate using the contents property
4344
"""
44-
4545
# test contents
4646
assert_true(self.user)
4747
assert_true(len(self.user) == len(self.user.contents))
@@ -96,9 +96,9 @@ def test_wrong_insert_type(self):
9696
self.user.insert1(3)
9797

9898
def test_insert_select(self):
99-
schema.Test2.delete()
100-
schema.Test2.insert(schema.Test)
101-
assert_equal(len(schema.Test2()), len(schema.Test()))
99+
schema.TTest2.delete()
100+
schema.TTest2.insert(schema.TTest)
101+
assert_equal(len(schema.TTest2()), len(schema.TTest()))
102102
original_length = len(self.subject)
103103
self.subject.insert(self.subject.proj(
104104
'real_id', 'date_of_birth', 'subject_notes', subject_id='subject_id+1000', species='"human"'))
@@ -205,9 +205,13 @@ def test_blob_insert(self):
205205
def test_drop(self):
206206
"""Tests dropping tables"""
207207
dj.config['safemode'] = True
208-
with patch.object(utils, "input", create=True, return_value='yes'):
209-
self.trash.drop()
210-
dj.config['safemode'] = False
208+
try:
209+
with patch.object(utils, "input", create=True, return_value='yes'):
210+
self.trash.drop()
211+
except:
212+
pass
213+
finally:
214+
dj.config['safemode'] = False
211215
self.trash.fetch()
212216

213217
def test_table_regexp(self):

tests/test_relation_u.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,17 @@ class TestU:
88
Test base relations: insert, delete
99
"""
1010

11-
def __init__(self):
12-
self.user = schema.User()
13-
self.language = schema.Language()
14-
self.subject = schema.Subject()
15-
self.experiment = schema.Experiment()
16-
self.trial = schema.Trial()
17-
self.ephys = schema.Ephys()
18-
self.channel = schema.Ephys.Channel()
19-
self.img = schema.Image()
20-
self.trash = schema.UberTrash()
11+
@classmethod
12+
def setup_class(cls):
13+
cls.user = schema.User()
14+
cls.language = schema.Language()
15+
cls.subject = schema.Subject()
16+
cls.experiment = schema.Experiment()
17+
cls.trial = schema.Trial()
18+
cls.ephys = schema.Ephys()
19+
cls.channel = schema.Ephys.Channel()
20+
cls.img = schema.Image()
21+
cls.trash = schema.UberTrash()
2122

2223
def test_restriction(self):
2324
language_set = {s[1] for s in self.language.contents}
@@ -60,7 +61,7 @@ def test_aggregations(self):
6061
assert_equal((rel & 'language="English"').fetch1('number_of_speakers'), 3)
6162

6263
def test_argmax(self):
63-
rel = schema.Test()
64+
rel = schema.TTest()
6465
# get the tuples corresponding to maximum value
6566
mx = (rel * dj.U().aggr(rel, mx='max(value)')) & 'mx=value'
6667
assert_equal(mx.fetch('value')[0], max(rel.fetch('value')))

0 commit comments

Comments
 (0)