Skip to content

Commit 3e98794

Browse files
committed
fix windows compatibility for tempfile
1 parent a893974 commit 3e98794

File tree

1 file changed

+58
-16
lines changed

1 file changed

+58
-16
lines changed

tests/test_sklearn.py

Lines changed: 58 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
# See the License for the specific language governing permissions and
2020
# limitations under the License.
2121
#
22+
import contextlib
2223
import itertools
24+
import os
2325
import tempfile
2426
import warnings
2527

@@ -47,6 +49,34 @@
4749
_DEVICES = ("cpu",)
4850

4951

52+
@contextlib.contextmanager
53+
def _windows_compatible_tempfile(mode="w+b", delete=True, **kwargs):
54+
"""Context manager for creating temporary files compatible with Windows.
55+
56+
On Windows, files opened with delete=True cannot be accessed by other
57+
processes or reopened. This context manager creates a temporary file
58+
with delete=False, yields its path, and ensures cleanup in a finally block.
59+
60+
Args:
61+
mode: File mode (default: "w+b")
62+
**kwargs: Additional arguments passed to NamedTemporaryFile
63+
64+
Yields:
65+
str: Path to the temporary file
66+
"""
67+
if not delete:
68+
raise ValueError("'delete' must be True")
69+
70+
with tempfile.NamedTemporaryFile(mode=mode, delete=False, **kwargs) as f:
71+
tempname = f.name
72+
73+
try:
74+
yield tempname
75+
finally:
76+
if os.path.exists(tempname):
77+
os.remove(tempname)
78+
79+
5080
def test_imports():
5181
import cebra
5282

@@ -1037,24 +1067,23 @@ def test_save_and_load(action, backend_save, backend_load, model_architecture,
10371067
device=device)
10381068

10391069
original_model = action(original_model)
1040-
with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile:
1070+
with _windows_compatible_tempfile(mode="w+b") as tempname:
10411071
if not check_if_fit(original_model):
10421072
with pytest.raises(ValueError):
1043-
original_model.save(savefile.name, backend=backend_save)
1073+
original_model.save(tempname, backend=backend_save)
10441074
else:
10451075
if "parametrized" in original_model.model_architecture and backend_save == "torch":
10461076
with pytest.raises(AttributeError):
1047-
original_model.save(savefile.name, backend=backend_save)
1077+
original_model.save(tempname, backend=backend_save)
10481078
else:
1049-
original_model.save(savefile.name, backend=backend_save)
1079+
original_model.save(tempname, backend=backend_save)
10501080

10511081
if (backend_load != "auto") and (backend_save != backend_load):
10521082
with pytest.raises(RuntimeError):
1053-
cebra_sklearn_cebra.CEBRA.load(savefile.name,
1054-
backend_load)
1083+
cebra_sklearn_cebra.CEBRA.load(tempname, backend_load)
10551084
else:
10561085
loaded_model = cebra_sklearn_cebra.CEBRA.load(
1057-
savefile.name, backend_load)
1086+
tempname, backend_load)
10581087
_assert_equal(original_model, loaded_model)
10591088
action(loaded_model)
10601089

@@ -1130,9 +1159,9 @@ def test_move_cpu_to_cuda_device(device):
11301159
device_str = f'cuda:{device_model.index}'
11311160
assert device_str == new_device
11321161

1133-
with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile:
1134-
cebra_model.save(savefile.name)
1135-
loaded_model = cebra_sklearn_cebra.CEBRA.load(savefile.name)
1162+
with _windows_compatible_tempfile(mode="w+b") as tempname:
1163+
cebra_model.save(tempname)
1164+
loaded_model = cebra_sklearn_cebra.CEBRA.load(tempname)
11361165

11371166
assert cebra_model.device == loaded_model.device
11381167
assert next(cebra_model.solver_.model.parameters()).device == next(
@@ -1159,9 +1188,9 @@ def test_move_cpu_to_mps_device(device):
11591188
device_model = next(cebra_model.solver_.model.parameters()).device
11601189
assert device_model.type == new_device
11611190

1162-
with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile:
1163-
cebra_model.save(savefile.name)
1164-
loaded_model = cebra_sklearn_cebra.CEBRA.load(savefile.name)
1191+
with _windows_compatible_tempfile(mode="w+b") as tempname:
1192+
cebra_model.save(tempname)
1193+
loaded_model = cebra_sklearn_cebra.CEBRA.load(tempname)
11651194

11661195
assert cebra_model.device == loaded_model.device
11671196
assert next(cebra_model.solver_.model.parameters()).device == next(
@@ -1198,9 +1227,9 @@ def test_move_mps_to_cuda_device(device):
11981227
device_str = f'cuda:{device_model.index}'
11991228
assert device_str == new_device
12001229

1201-
with tempfile.NamedTemporaryFile(mode="w+b", delete=True) as savefile:
1202-
cebra_model.save(savefile.name)
1203-
loaded_model = cebra_sklearn_cebra.CEBRA.load(savefile.name)
1230+
with _windows_compatible_tempfile(mode="w+b") as tempname:
1231+
cebra_model.save(tempname)
1232+
loaded_model = cebra_sklearn_cebra.CEBRA.load(tempname)
12041233

12051234
assert cebra_model.device == loaded_model.device
12061235
assert next(cebra_model.solver_.model.parameters()).device == next(
@@ -1561,3 +1590,16 @@ def test_non_writable_array():
15611590
embedding = cebra_model.transform(X)
15621591
assert isinstance(embedding, np.ndarray)
15631592
assert embedding.shape[0] == X.shape[0]
1593+
1594+
1595+
def test_read_write():
1596+
X = np.random.randn(100, 10)
1597+
y = np.random.randn(100, 2)
1598+
cebra_model = cebra.CEBRA(max_iterations=2, batch_size=32, device="cpu")
1599+
cebra_model.fit(X, y)
1600+
cebra_model.transform(X)
1601+
1602+
with _windows_compatible_tempfile(mode="w+b", delete=False) as tempname:
1603+
cebra_model.save(tempname)
1604+
loaded_model = cebra.CEBRA.load(tempname)
1605+
_assert_equal(cebra_model, loaded_model)

0 commit comments

Comments
 (0)