Skip to content

Commit 8a24c15

Browse files
committed
Alphabetize, link to docs for functions
1 parent 806eb62 commit 8a24c15

File tree

2 files changed

+21
-15
lines changed

2 files changed

+21
-15
lines changed

README.rst

+8-2
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,15 @@ All of these features are on by default but can be disabled with flags.
6767
.. |numpy.random| replace:: ``numpy.random``
6868
__ https://numpy.org/doc/stable/reference/random/index.html
6969

70-
* If `TensorFlow <https://www.tensorflow.org/>`_ is installed, its random seed in ``tensorflow.random`` is reset at the start of every test.
70+
* If `PyTorch <https://pytorch.org/>`_ is installed, its random seed is reset at the start of every test with |torch.manual_seed()|__.
7171

72-
* If `PyTorch <https://pytorch.org/>`_ is installed, its random seed is reset at the start of every test. The random seed of each test is recorded, and can play a role in detecting flaky tests.
72+
.. |torch.manual_seed()| replace:: ``torch.manual_seed()``
73+
__ https://pytorch.org/docs/stable/generated/torch.manual_seed.html
74+
75+
* If `TensorFlow <https://www.tensorflow.org/>`_ is installed, its random seed is reset at the start of every test with |tensorflow.random.set_seed()|__.
76+
77+
.. |tensorflow.random.set_seed()| replace:: ``tensorflow.random.set_seed()``
78+
__ https://www.tensorflow.org/api_docs/python/tf/random/set_seed
7379

7480
* If additional random generators are used, they can be registered under the
7581
``pytest_randomly.random_seeder``

src/pytest_randomly/__init__.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,6 @@
6767
except ImportError: # pragma: no cover
6868
have_numpy = False
6969

70-
# tensorflow
71-
try:
72-
import tensorflow as tf
73-
74-
have_tensorflow = True
75-
except ImportError: # pragma: no cover
76-
have_tensorflow = False
77-
7870
# pytorch
7971
try:
8072
import torch
@@ -83,6 +75,14 @@
8375
except ImportError: # pragma: no cover
8476
have_pytorch = False
8577

78+
# tensorflow
79+
try:
80+
import tensorflow as tf
81+
82+
have_tensorflow = True
83+
except ImportError: # pragma: no cover
84+
have_tensorflow = False
85+
8686

8787
default_seed = random.Random().getrandbits(32)
8888

@@ -196,17 +196,17 @@ def _reseed(config: Config, offset: int = 0) -> int:
196196
else:
197197
np_random.set_state(np_random_states[numpy_seed])
198198

199+
if have_pytorch: # pragma: no branch
200+
torch.manual_seed(seed)
201+
if torch.cuda.is_available(): # Also seed CUDA if available
202+
torch.cuda.manual_seed_all(seed)
203+
199204
if have_tensorflow: # pragma: no branch
200205
tf.random.set_seed(seed)
201206
# TensorFlow 1.x compatibility
202207
if hasattr(tf, "compat"):
203208
tf.compat.v1.set_random_seed(seed)
204209

205-
if have_pytorch: # pragma: no branch
206-
torch.manual_seed(seed)
207-
if torch.cuda.is_available(): # Also seed CUDA if available
208-
torch.cuda.manual_seed_all(seed)
209-
210210
if entrypoint_reseeds is None:
211211
eps = entry_points(group="pytest_randomly.random_seeder")
212212
entrypoint_reseeds = [e.load() for e in eps]

0 commit comments

Comments
 (0)