24
24
from test_all import set_up
25
25
from torch import nn , optim
26
26
from trainers import (
27
- TrainEvents ,
28
27
create_trainers ,
29
28
evaluate_function ,
30
- train_events_to_attr ,
31
29
train_function ,
32
30
)
33
31
from utils import (
@@ -92,88 +90,6 @@ def test_get_logger(tmp_path):
92
90
assert isinstance (logger_handler , types ), "Should be Ignite provided loggers or None"
93
91
94
92
95
- def test_train_fn ():
96
- model , optimizer , device , loss_fn , batch = set_up ()
97
- engine = Engine (lambda e , b : 1 )
98
- engine .register_events (* TrainEvents , event_to_attr = train_events_to_attr )
99
- backward = MagicMock ()
100
- optim = MagicMock ()
101
- engine .add_event_handler (TrainEvents .BACKWARD_COMPLETED , backward )
102
- engine .add_event_handler (TrainEvents .OPTIM_STEP_COMPLETED , optim )
103
- config = Namespace (use_amp = False )
104
- output = train_function (config , engine , batch , model , loss_fn , optimizer , device )
105
- assert isinstance (output , dict )
106
- assert hasattr (engine .state , "backward_completed" )
107
- assert hasattr (engine .state , "optim_step_completed" )
108
- assert engine .state .backward_completed == 1
109
- assert engine .state .optim_step_completed == 1
110
- assert backward .call_count == 1
111
- assert optim .call_count == 1
112
- assert backward .called
113
- assert optim .called
114
-
115
-
116
- def test_train_fn_event_filter ():
117
- model , optimizer , device , loss_fn , batch = set_up ()
118
- config = Namespace (use_amp = False )
119
- engine = Engine (lambda e , b : train_function (config , e , b , model , loss_fn , optimizer , device ))
120
- engine .register_events (* TrainEvents , event_to_attr = train_events_to_attr )
121
- backward = MagicMock ()
122
- optim = MagicMock ()
123
- engine .add_event_handler (TrainEvents .BACKWARD_COMPLETED (event_filter = lambda _ , x : (x % 2 == 0 ) or x == 3 ), backward )
124
- engine .add_event_handler (TrainEvents .OPTIM_STEP_COMPLETED (event_filter = lambda _ , x : (x % 2 == 0 ) or x == 3 ), optim )
125
- engine .run ([batch ] * 5 )
126
- assert hasattr (engine .state , "backward_completed" )
127
- assert hasattr (engine .state , "optim_step_completed" )
128
- assert engine .state .backward_completed == 5
129
- assert engine .state .optim_step_completed == 5
130
- assert backward .call_count == 3
131
- assert optim .call_count == 3
132
- assert backward .called
133
- assert optim .called
134
-
135
-
136
- def test_train_fn_every ():
137
- model , optimizer , device , loss_fn , batch = set_up ()
138
-
139
- config = Namespace (use_amp = False )
140
- engine = Engine (lambda e , b : train_function (config , e , b , model , loss_fn , optimizer , device ))
141
- engine .register_events (* TrainEvents , event_to_attr = train_events_to_attr )
142
- backward = MagicMock ()
143
- optim = MagicMock ()
144
- engine .add_event_handler (TrainEvents .BACKWARD_COMPLETED (every = 2 ), backward )
145
- engine .add_event_handler (TrainEvents .OPTIM_STEP_COMPLETED (every = 2 ), optim )
146
- engine .run ([batch ] * 5 )
147
- assert hasattr (engine .state , "backward_completed" )
148
- assert hasattr (engine .state , "optim_step_completed" )
149
- assert engine .state .backward_completed == 5
150
- assert engine .state .optim_step_completed == 5
151
- assert backward .call_count == 2
152
- assert optim .call_count == 2
153
- assert backward .called
154
- assert optim .called
155
-
156
-
157
- def test_train_fn_once ():
158
- model , optimizer , device , loss_fn , batch = set_up ()
159
- config = Namespace (use_amp = False )
160
- engine = Engine (lambda e , b : train_function (config , e , b , model , loss_fn , optimizer , device ))
161
- engine .register_events (* TrainEvents , event_to_attr = train_events_to_attr )
162
- backward = MagicMock ()
163
- optim = MagicMock ()
164
- engine .add_event_handler (TrainEvents .BACKWARD_COMPLETED (once = 3 ), backward )
165
- engine .add_event_handler (TrainEvents .OPTIM_STEP_COMPLETED (once = 3 ), optim )
166
- engine .run ([batch ] * 5 )
167
- assert hasattr (engine .state , "backward_completed" )
168
- assert hasattr (engine .state , "optim_step_completed" )
169
- assert engine .state .backward_completed == 5
170
- assert engine .state .optim_step_completed == 5
171
- assert backward .call_count == 1
172
- assert optim .call_count == 1
173
- assert backward .called
174
- assert optim .called
175
-
176
-
177
93
def test_evaluate_fn ():
178
94
model , optimizer , device , loss_fn , batch = set_up ()
179
95
engine = Engine (lambda e , b : 1 )
@@ -193,8 +109,6 @@ def test_create_trainers():
193
109
)
194
110
assert isinstance (trainer , Engine )
195
111
assert isinstance (evaluator , Engine )
196
- assert hasattr (trainer .state , "backward_completed" )
197
- assert hasattr (trainer .state , "optim_step_completed" )
198
112
199
113
200
114
def test_get_default_parser ():
0 commit comments