Skip to content

Commit 1cd5e23

Browse files
committed
Add support for copy.deepcopy on instances of RandObj
- Add tests that copying an object retains its determinism. - Fix resulting bugs from that. - Marginally improve MultiVar performance by changing a dict to a list.
1 parent 4f1e7e0 commit 1cd5e23

File tree

5 files changed

+87
-50
lines changed

5 files changed

+87
-50
lines changed

constrainedrandom/internal/multivar.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class MultiVarProblem:
3737
def __init__(
3838
self,
3939
parent: 'RandObj',
40-
vars: Dict[str, 'RandVar'],
40+
vars: List['RandVar'],
4141
constraints: Iterable[utils.ConstraintAndVars],
4242
max_iterations: int,
4343
max_domain_size: int,
@@ -71,15 +71,14 @@ def determine_order(self, with_values: Dict[str, Any]) -> List[List['RandVar']]:
7171
# to solve at the same time.
7272
# The best case is to simply solve them all at once, if possible, however it is
7373
# likely that the domain will be too large.
74-
vars = []
75-
7674
# If values are provided, simply don't add those variables to the ordering problem.
7775
if problem_changed:
78-
for name, var in self.vars.items():
79-
if name not in with_values:
76+
vars = []
77+
for var in self.vars:
78+
if var.name not in with_values:
8079
vars.append(var)
8180
else:
82-
vars = list(self.vars.values())
81+
vars = list(self.vars)
8382

8483
# Use order hints first, remaining variables can be placed anywhere the domain
8584
# isn't too large.

constrainedrandom/internal/randvar.py

+33-6
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,37 @@
44
import constraint
55
from functools import partial
66
from itertools import product
7-
from typing import Any, Callable, Iterable, Optional, Union
7+
from typing import Any, Callable, Iterable, List, Optional
88
import random
99

1010
from .. import utils
1111
from ..debug import RandomizationDebugInfo, RandomizationFail
1212
from ..random import dist
1313

1414

15+
def get_and_call(getter: Callable, member_fn: str, *args: List[Any]):
16+
'''
17+
This is a very strange workaround for a very strange issue.
18+
``copy.deepcopy`` can handle a ``partial`` for all other members
19+
of ``random.Random``, but not ``getrandbits``. I.e. it correctly
20+
copies the other functions and their instance of ``random.Random``,
21+
but not ``getrandbits``. The reason for this is unknown.
22+
23+
This function therefore exists to work around that issue
24+
by getting ``getrandbits`` and calling it. I tried many
25+
other approaches, but this was the only one that worked.
26+
27+
:param getter: Getter to call, returning an object that
28+
has a member function with name ``member_fn``.
29+
:param member_fn: Member function of the the object returned
30+
by ``getter``.
31+
:param args: Arguments to supply to ``member_fn``.
32+
'''
33+
callable_obj = getter()
34+
fn = getattr(callable_obj, member_fn)
35+
return fn(*args)
36+
37+
1538
class RandVar:
1639
'''
1740
Randomizable variable. For internal use with :class:`RandObj`.
@@ -99,6 +122,10 @@ def create_randomizer(self) -> Callable:
99122
We do this to create a more optimal randomizer than the user might
100123
have specified that is functionally equivalent.
101124
125+
We always return a ``partial`` because these work with
126+
``copy.deepcopy``, whereas locally-defined functions and
127+
lambdas can only ever have one instance.
128+
102129
:return: a function as described.
103130
:raises TypeError: if the domain is of a bad type.
104131
'''
@@ -111,7 +138,9 @@ def create_randomizer(self) -> Callable:
111138
return self.fn
112139
elif self.bits is not None:
113140
self.domain = range(0, 1 << self.bits)
114-
return partial(self._get_random().getrandbits, self.bits)
141+
# This is still faster than doing self._get_random().randrange(self.bits << 1),
142+
# it seems that getrandbits is 10x faster than randrange.
143+
return partial(get_and_call, self._get_random, 'getrandbits', self.bits)
115144
else:
116145
# Handle possible types of domain.
117146
is_range = isinstance(self.domain, range)
@@ -146,10 +175,8 @@ def create_randomizer(self) -> Callable:
146175
debug_info.add_failure(debug_fail)
147176
raise utils.RandomizationError("Variable was unsolvable. Check constraints.", debug_info)
148177
solution_list = [s[self.name] for s in solutions]
149-
def solution_picker(solns):
150-
return self._get_random().choice(solns)
151178
self.check_constraints = False
152-
return partial(solution_picker, solution_list)
179+
return partial(self._get_random().choice, solution_list)
153180
elif is_range:
154181
return partial(self._get_random().randrange, self.domain.start, self.domain.stop)
155182
elif is_list_or_tuple:
@@ -308,7 +335,7 @@ def randomize_list_csp(
308335
problem = constraint.Problem()
309336
possible_values = self.get_constraint_domain()
310337
# Prune possibilities according to scalar constraints.
311-
possible_values[:] = [x for x in possible_values \
338+
possible_values = [x for x in possible_values \
312339
if all(constr(val) for val in x for constr in constraints)]
313340
problem.addVariable(self.name, possible_values)
314341
for con in list_constraints:

constrainedrandom/randobj.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -318,18 +318,20 @@ def randomize(
318318
problem_changed = True
319319
# If a variable becomes constrained due to temporary multi-variable
320320
# constraints, we must respect single var temporary constraints too.
321-
for var, constrs in tmp_single_var_constraints.items():
321+
for var, constrs in sorted(tmp_single_var_constraints.items()):
322322
if var in constrained_vars:
323323
for constr in constrs:
324324
constraints.append((constr, (var,)))
325325

326326
# Don't allow non-determinism when iterating over a set
327327
constrained_vars = sorted(constrained_vars)
328+
# Don't allow non-determinism when iterating over a dict
329+
random_vars = sorted(self._random_vars.items())
328330

329331
# Process concrete values - use these preferentially
330332
with_values = with_values if with_values is not None else {}
331333

332-
for name, random_var in self._random_vars.items():
334+
for name, random_var in random_vars:
333335
if name in with_values:
334336
result[name] = with_values[name]
335337
else:
@@ -379,7 +381,7 @@ def randomize(
379381
if problem_changed or self._problem_changed or self._multi_var_problem is None:
380382
multi_var_problem = MultiVarProblem(
381383
self,
382-
{name: var for name, var in self._random_vars.items() if name in constrained_vars},
384+
[var for var_name, var in random_vars if var_name in constrained_vars],
383385
constraints,
384386
max_iterations=self._max_iterations,
385387
max_domain_size=self._max_domain_size,

constrainedrandom/random.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from . import utils
99

1010

11-
def weighted_choice(choices_dict: utils.Dist, _random: Optional[random.Random]=random) -> Any:
11+
def weighted_choice(choices_dict: utils.Dist, _random: Optional[random.Random]=None) -> Any:
1212
'''
1313
Wrapper around ``random.choices``, allowing the user to specify weights in a dictionary.
1414
@@ -24,9 +24,12 @@ def weighted_choice(choices_dict: utils.Dist, _random: Optional[random.Random]=r
2424
# 0 will be chosen 25% of the time, 1 25% of the time and 'foo' 50% of the time
2525
value = weighted_choice({0: 25, 1: 25, 'foo': 50})
2626
'''
27+
if _random is None:
28+
_random = random
2729
return _random.choices(tuple(choices_dict.keys()), weights=tuple(choices_dict.values()))
2830

29-
def dist(dist_dict: utils.Dist, _random: Optional[random.Random]=random) -> Any:
31+
32+
def dist(dist_dict: utils.Dist, _random: Optional[random.Random]=None) -> Any:
3033
'''
3134
Random distribution. As :func:`weighted_choice`, but allows ``range`` to be used as
3235
a key to the dictionary, which if chosen is then evaluated as a random range.
@@ -46,6 +49,8 @@ def dist(dist_dict: utils.Dist, _random: Optional[random.Random]=random) -> Any:
4649
# and 'foo' 50% of the time
4750
value = dist({0: 25, range(1, 10): 25, 'foo': 50})
4851
'''
52+
if _random is None:
53+
_random = random
4954
answer = weighted_choice(choices_dict=dist_dict, _random=_random)[0]
5055
if isinstance(answer, range):
5156
return _random.randrange(answer.start, answer.stop)

tests/testutils.py

+37-33
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import timeit
77
import unittest
8+
from copy import deepcopy
89
from typing import Any, Dict, List
910

1011
from constrainedrandom import RandObj, RandomizationError
@@ -72,7 +73,7 @@ def tmp_check(self, results) -> None:
7273
'''
7374
pass
7475

75-
def randomize_and_time(self, randobj, iterations, tmp_constraints=None, tmp_values=None) -> Dict[str, Any]:
76+
def randomize_and_time(self, randobj, iterations, tmp_constraints=None, tmp_values=None) -> List[Dict[str, Any]]:
7677
'''
7778
Call randobj.randomize() iterations times, time it, print performance stats,
7879
return the results.
@@ -93,7 +94,6 @@ def randomize_and_time(self, randobj, iterations, tmp_constraints=None, tmp_valu
9394
print(f'{self.get_full_test_name()} took {time_taken:.4g}s for {iterations} iterations ({hz:.1f}Hz)')
9495
return results
9596

96-
9797
def test_randobj(self) -> None:
9898
'''
9999
Reusable test function to randomize a RandObj for a number of iterations and perform checks.
@@ -109,6 +109,8 @@ def test_randobj(self) -> None:
109109

110110
# Test with seed 0
111111
randobj = self.get_randobj(0)
112+
# Take a copy of the randobj for use later
113+
randobj_copy = deepcopy(randobj)
112114
if self.EXPECT_FAILURE:
113115
self.assertRaises(RandomizationError, randobj.randomize)
114116
else:
@@ -128,41 +130,43 @@ def test_randobj(self) -> None:
128130
add_results = self.randomize_and_time(randobj, self.iterations, tmp_values=tmp_values)
129131
self.tmp_check(add_results)
130132

131-
# Test again with seed 0, ensuring results are the same
133+
# Test again with seed 0, ensuring results are the same.
134+
# Also test the copy we took earlier.
132135
randobj0 = self.get_randobj(0)
133-
if self.EXPECT_FAILURE:
134-
self.assertRaises(RandomizationError, randobj0.randomize)
135-
else:
136-
results0 = self.randomize_and_time(randobj0, self.iterations)
137-
assertListOfDictsEqual(self, results, results0, "Non-determinism detected, results were not equal")
138-
if do_tmp_checks:
139-
# Check applying temporary constraints is also deterministic
140-
tmp_results0 = self.randomize_and_time(randobj0, self.iterations, tmp_constraints, tmp_values)
141-
assertListOfDictsEqual(
142-
self,
143-
tmp_results,
144-
tmp_results0,
145-
"Non-determinism detected, results were not equal with temp constraints"
146-
)
147-
# Check temporary constraints don't break base randomization determinism
148-
post_tmp_results0 = self.randomize_and_time(randobj0, self.iterations)
149-
assertListOfDictsEqual(
150-
self,
151-
post_tmp_results,
152-
post_tmp_results0,
153-
"Non-determinism detected, results were not equal after temp constraints"
154-
)
155-
# Add temporary constraints permanently, see what happens
156-
if tmp_constraints is not None:
157-
for constr, vars in tmp_constraints:
158-
randobj0.add_constraint(constr, vars)
159-
add_results0 = self.randomize_and_time(randobj0, self.iterations, tmp_values=tmp_values)
136+
for tmp_randobj in [randobj0, randobj_copy]:
137+
if self.EXPECT_FAILURE:
138+
self.assertRaises(RandomizationError, tmp_randobj.randomize)
139+
else:
140+
results0 = self.randomize_and_time(tmp_randobj, self.iterations)
141+
assertListOfDictsEqual(self, results, results0, "Non-determinism detected, results were not equal")
142+
if do_tmp_checks:
143+
# Check applying temporary constraints is also deterministic
144+
tmp_results0 = self.randomize_and_time(tmp_randobj, self.iterations, tmp_constraints, tmp_values)
145+
assertListOfDictsEqual(
146+
self,
147+
tmp_results,
148+
tmp_results0,
149+
"Non-determinism detected, results were not equal with temp constraints"
150+
)
151+
# Check temporary constraints don't break base randomization determinism
152+
post_tmp_results0 = self.randomize_and_time(tmp_randobj, self.iterations)
160153
assertListOfDictsEqual(
161154
self,
162-
add_results,
163-
add_results0,
164-
"Non-determinism detected, results were not equal after constraints added"
155+
post_tmp_results,
156+
post_tmp_results0,
157+
"Non-determinism detected, results were not equal after temp constraints"
165158
)
159+
# Add temporary constraints permanently, see what happens
160+
if tmp_constraints is not None:
161+
for constr, vars in tmp_constraints:
162+
tmp_randobj.add_constraint(constr, vars)
163+
add_results0 = self.randomize_and_time(tmp_randobj, self.iterations, tmp_values=tmp_values)
164+
assertListOfDictsEqual(
165+
self,
166+
add_results,
167+
add_results0,
168+
"Non-determinism detected, results were not equal after constraints added"
169+
)
166170

167171
# Test with seed 1, ensuring results are different
168172
randobj1 = self.get_randobj(1)

0 commit comments

Comments
 (0)