|
67 | 67 | except ImportError: # pragma: no cover
|
68 | 68 | have_numpy = False
|
69 | 69 |
|
70 |
| -# tensorflow |
71 |
| -try: |
72 |
| - import tensorflow as tf |
73 |
| - |
74 |
| - have_tensorflow = True |
75 |
| -except ImportError: # pragma: no cover |
76 |
| - have_tensorflow = False |
77 |
| - |
78 | 70 | # pytorch
|
79 | 71 | try:
|
80 | 72 | import torch
|
|
83 | 75 | except ImportError: # pragma: no cover
|
84 | 76 | have_pytorch = False
|
85 | 77 |
|
| 78 | +# tensorflow |
| 79 | +try: |
| 80 | + import tensorflow as tf |
| 81 | + |
| 82 | + have_tensorflow = True |
| 83 | +except ImportError: # pragma: no cover |
| 84 | + have_tensorflow = False |
| 85 | + |
86 | 86 |
|
87 | 87 | default_seed = random.Random().getrandbits(32)
|
88 | 88 |
|
@@ -196,17 +196,17 @@ def _reseed(config: Config, offset: int = 0) -> int:
|
196 | 196 | else:
|
197 | 197 | np_random.set_state(np_random_states[numpy_seed])
|
198 | 198 |
|
| 199 | + if have_pytorch: # pragma: no branch |
| 200 | + torch.manual_seed(seed) |
| 201 | + if torch.cuda.is_available(): # Also seed CUDA if available |
| 202 | + torch.cuda.manual_seed_all(seed) |
| 203 | + |
199 | 204 | if have_tensorflow: # pragma: no branch
|
200 | 205 | tf.random.set_seed(seed)
|
201 | 206 | # TensorFlow 1.x compatibility
|
202 | 207 | if hasattr(tf, "compat"):
|
203 | 208 | tf.compat.v1.set_random_seed(seed)
|
204 | 209 |
|
205 |
| - if have_pytorch: # pragma: no branch |
206 |
| - torch.manual_seed(seed) |
207 |
| - if torch.cuda.is_available(): # Also seed CUDA if available |
208 |
| - torch.cuda.manual_seed_all(seed) |
209 |
| - |
210 | 210 | if entrypoint_reseeds is None:
|
211 | 211 | eps = entry_points(group="pytest_randomly.random_seeder")
|
212 | 212 | entrypoint_reseeds = [e.load() for e in eps]
|
|
0 commit comments