1919# See the License for the specific language governing permissions and
2020# limitations under the License.
2121#
22+ import contextlib
2223import itertools
24+ import os
2325import tempfile
2426import warnings
2527
4749 _DEVICES = ("cpu" ,)
4850
4951
52+ @contextlib .contextmanager
53+ def _windows_compatible_tempfile (mode = "w+b" , delete = True , ** kwargs ):
54+ """Context manager for creating temporary files compatible with Windows.
55+
56+ On Windows, files opened with delete=True cannot be accessed by other
57+ processes or reopened. This context manager creates a temporary file
58+ with delete=False, yields its path, and ensures cleanup in a finally block.
59+
60+ Args:
61+ mode: File mode (default: "w+b")
62+ **kwargs: Additional arguments passed to NamedTemporaryFile
63+
64+ Yields:
65+ str: Path to the temporary file
66+ """
67+ if not delete :
68+ raise ValueError ("'delete' must be True" )
69+
70+ with tempfile .NamedTemporaryFile (mode = mode , delete = False , ** kwargs ) as f :
71+ tempname = f .name
72+
73+ try :
74+ yield tempname
75+ finally :
76+ if os .path .exists (tempname ):
77+ os .remove (tempname )
78+
79+
5080def test_imports ():
5181 import cebra
5282
@@ -1037,24 +1067,23 @@ def test_save_and_load(action, backend_save, backend_load, model_architecture,
10371067 device = device )
10381068
10391069 original_model = action (original_model )
1040- with tempfile . NamedTemporaryFile (mode = "w+b" , delete = True ) as savefile :
1070+ with _windows_compatible_tempfile (mode = "w+b" ) as tempname :
10411071 if not check_if_fit (original_model ):
10421072 with pytest .raises (ValueError ):
1043- original_model .save (savefile . name , backend = backend_save )
1073+ original_model .save (tempname , backend = backend_save )
10441074 else :
10451075 if "parametrized" in original_model .model_architecture and backend_save == "torch" :
10461076 with pytest .raises (AttributeError ):
1047- original_model .save (savefile . name , backend = backend_save )
1077+ original_model .save (tempname , backend = backend_save )
10481078 else :
1049- original_model .save (savefile . name , backend = backend_save )
1079+ original_model .save (tempname , backend = backend_save )
10501080
10511081 if (backend_load != "auto" ) and (backend_save != backend_load ):
10521082 with pytest .raises (RuntimeError ):
1053- cebra_sklearn_cebra .CEBRA .load (savefile .name ,
1054- backend_load )
1083+ cebra_sklearn_cebra .CEBRA .load (tempname , backend_load )
10551084 else :
10561085 loaded_model = cebra_sklearn_cebra .CEBRA .load (
1057- savefile . name , backend_load )
1086+ tempname , backend_load )
10581087 _assert_equal (original_model , loaded_model )
10591088 action (loaded_model )
10601089
@@ -1130,9 +1159,9 @@ def test_move_cpu_to_cuda_device(device):
11301159 device_str = f'cuda:{ device_model .index } '
11311160 assert device_str == new_device
11321161
1133- with tempfile . NamedTemporaryFile (mode = "w+b" , delete = True ) as savefile :
1134- cebra_model .save (savefile . name )
1135- loaded_model = cebra_sklearn_cebra .CEBRA .load (savefile . name )
1162+ with _windows_compatible_tempfile (mode = "w+b" ) as tempname :
1163+ cebra_model .save (tempname )
1164+ loaded_model = cebra_sklearn_cebra .CEBRA .load (tempname )
11361165
11371166 assert cebra_model .device == loaded_model .device
11381167 assert next (cebra_model .solver_ .model .parameters ()).device == next (
@@ -1159,9 +1188,9 @@ def test_move_cpu_to_mps_device(device):
11591188 device_model = next (cebra_model .solver_ .model .parameters ()).device
11601189 assert device_model .type == new_device
11611190
1162- with tempfile . NamedTemporaryFile (mode = "w+b" , delete = True ) as savefile :
1163- cebra_model .save (savefile . name )
1164- loaded_model = cebra_sklearn_cebra .CEBRA .load (savefile . name )
1191+ with _windows_compatible_tempfile (mode = "w+b" ) as tempname :
1192+ cebra_model .save (tempname )
1193+ loaded_model = cebra_sklearn_cebra .CEBRA .load (tempname )
11651194
11661195 assert cebra_model .device == loaded_model .device
11671196 assert next (cebra_model .solver_ .model .parameters ()).device == next (
@@ -1198,9 +1227,9 @@ def test_move_mps_to_cuda_device(device):
11981227 device_str = f'cuda:{ device_model .index } '
11991228 assert device_str == new_device
12001229
1201- with tempfile . NamedTemporaryFile (mode = "w+b" , delete = True ) as savefile :
1202- cebra_model .save (savefile . name )
1203- loaded_model = cebra_sklearn_cebra .CEBRA .load (savefile . name )
1230+ with _windows_compatible_tempfile (mode = "w+b" ) as tempname :
1231+ cebra_model .save (tempname )
1232+ loaded_model = cebra_sklearn_cebra .CEBRA .load (tempname )
12041233
12051234 assert cebra_model .device == loaded_model .device
12061235 assert next (cebra_model .solver_ .model .parameters ()).device == next (
@@ -1561,3 +1590,16 @@ def test_non_writable_array():
15611590 embedding = cebra_model .transform (X )
15621591 assert isinstance (embedding , np .ndarray )
15631592 assert embedding .shape [0 ] == X .shape [0 ]
1593+
1594+
1595+ def test_read_write ():
1596+ X = np .random .randn (100 , 10 )
1597+ y = np .random .randn (100 , 2 )
1598+ cebra_model = cebra .CEBRA (max_iterations = 2 , batch_size = 32 , device = "cpu" )
1599+ cebra_model .fit (X , y )
1600+ cebra_model .transform (X )
1601+
1602+ with _windows_compatible_tempfile (mode = "w+b" , delete = False ) as tempname :
1603+ cebra_model .save (tempname )
1604+ loaded_model = cebra .CEBRA .load (tempname )
1605+ _assert_equal (cebra_model , loaded_model )
0 commit comments