77import struct
88from urllib .request import urlretrieve
99
10- import tensorops
10+ # Device optimizer requires DirectInput caching for all params (including biases).
11+ os .environ .setdefault ("TENSOROPS_DIRECT_INPUT_CACHE" , "1" )
12+ os .environ .setdefault ("TENSOROPS_DIRECT_INPUT_CACHE_MIN_LEN" , "1" )
13+
1114from tensorops .loss import CrossEntropyLoss
12- from tensorops .optim import AdamW
13- from tensorops .tensor import Tensor , TensorContext , LeakyReLU
15+ from tensorops .optim import AdamWDevice
16+ from tensorops .tensor import LeakyReLU , Tensor
1417from tensorops .utils .models import SequentialModel
15- from tensorops .utils .tensorutils import PlotterUtil
1618
1719MNIST_URLS = {
1820 "train_images" : "https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz" ,
@@ -48,8 +50,6 @@ def extract_images(filepath):
4850 images = []
4951 for _ in range (num_images ):
5052 image = list (f .read (rows * cols ))
51- # Normalise to match PyTorch (divided by 255)
52- image = [x / 255.0 for x in image ]
5353 images .append (image )
5454 return images
5555
@@ -84,7 +84,28 @@ def load_mnist():
8484 return (train_images , train_labels ), (test_images , test_labels )
8585
8686
87- (train_images , train_labels ), (test_images , test_labels ) = load_mnist ()
87+ def _normalise_images (images ):
88+ return [[x / 255.0 for x in image ] for image in images ]
89+
90+
91+ def init_layer_params (layer , rng ):
92+ """Match the PyTorch init_network_params helper (uniform -1..1)."""
93+ weights_out_in = [
94+ [rng .uniform (- 1.0 , 1.0 ) for _ in range (layer .num_input_tensors )]
95+ for _ in range (layer .num_output_tensors )
96+ ]
97+ # TensorOps expects (in, out) weights for x @ W.
98+ weights_in_out = [list (col ) for col in zip (* weights_out_in )]
99+ layer .output_weights .values = weights_in_out
100+ layer .output_bias .values = [
101+ [rng .uniform (- 1.0 , 1.0 ) for _ in range (layer .num_output_tensors )]
102+ ]
103+
104+
105+ def init_model_params (model , seed = 42 ):
106+ rng = random .Random (seed )
107+ for layer in model .model_layers :
108+ init_layer_params (layer , rng )
88109
89110
90111class MNISTModel (SequentialModel ):
@@ -93,12 +114,11 @@ def __init__(
93114 num_hidden_layers : int ,
94115 num_hidden_nodes : int ,
95116 loss_criterion ,
96- seed : int | None = None ,
97117 activation_function = LeakyReLU ,
98118 * ,
99119 batch_size : int = 1 ,
100120 ) -> None :
101- super ().__init__ (loss_criterion , seed , batch_size = batch_size )
121+ super ().__init__ (loss_criterion , None , batch_size = batch_size )
102122 self .activation_function = activation_function
103123 self .num_hidden_layers = num_hidden_layers
104124 with self .context :
@@ -107,125 +127,138 @@ def __init__(
107127 self .add_layer (
108128 num_hidden_nodes , num_hidden_nodes , self .activation_function
109129 )
110- # Final layer emits logits; softmax is handled inside CrossEntropyLoss .
111- self .add_layer (num_hidden_nodes , 10 , None )
130+ # Apply activation on the output layer to match the PyTorch example .
131+ self .add_layer (num_hidden_nodes , 10 , self . activation_function )
112132 # CrossEntropyLoss expects (logits, target).
113133 self .loss = self .loss_criterion (
114134 self .model_output_layer .layer_output , self .targets
115135 )
116136
117137 def forward (self , model_inputs : Tensor ) -> Tensor : # type: ignore[override]
118138 with self .context :
119- # Input must be (batch_size, 784) for this model.
120- for layer in self .model_layers :
121- layer .forward (model_inputs )
122- model_inputs = layer .layer_output
123- return model_inputs
139+ # Update only the input placeholder; the graph is already wired.
140+ if self .model_input_layer is None or self .model_output_layer is None :
141+ raise ValueError ("Model layers are not initialised" )
142+ if isinstance (model_inputs , Tensor ):
143+ self .model_input_layer .layer_input_tensors .values = model_inputs .values
144+ else :
145+ self .model_input_layer .layer_input_tensors .values = model_inputs
146+ return self .model_output_layer .layer_output
124147
125148
126- with TensorContext (device = tensorops .device .TensorOpsDevice .APPLE ) as tc :
149+ if __name__ == "__main__" :
150+ random .seed (42 )
151+ (train_images , train_labels ), (test_images , test_labels ) = load_mnist ()
152+
153+ train_images = _normalise_images (train_images )
154+ test_images = _normalise_images (test_images )
155+
127156 X_train , y_train , X_test , y_test = (
128- Tensor ( train_images , requires_grad = False ) ,
129- Tensor ( train_labels , requires_grad = False ) ,
130- Tensor ( test_images , requires_grad = False ) ,
131- Tensor ( test_labels , requires_grad = False ) ,
157+ train_images ,
158+ train_labels ,
159+ test_images ,
160+ test_labels ,
132161 )
133162
134- print (X_train .shape , y_train .shape , X_test .shape , y_test .shape )
135-
136- # impl model
137163 BATCH_SIZE = 256
138- N_EPOCHS = 5
164+ N_EPOCHS = 100
139165
140166 model = MNISTModel (
141167 2 ,
142168 256 ,
143169 CrossEntropyLoss (),
144- seed = 42 ,
145170 batch_size = BATCH_SIZE ,
146171 activation_function = LeakyReLU ,
147172 )
148- optim = AdamW (model .get_weights (), lr = 2e-4 )
149- # Enable gradient clipping to stabilise updates
150- optim .grad_clip_norm = 0.5
151- optim .grad_clip_value = 0.5
173+ init_model_params (model , seed = 42 )
174+
152175 model .train ()
176+ optim = AdamWDevice (model .get_weights (), lr = 2e-4 )
153177
154178 # Helper to create fixed-size batches (graph uses a fixed batch_size)
155- def _one_hot (label : int , num_classes : int = 10 ) -> list [float ]:
156- vec = [0.0 ] * num_classes
157- vec [int (label )] = 1.0
158- return vec
159-
160- def _has_nonfinite_grad (params : list [Tensor ]) -> bool :
161- import numpy as np
162-
163- for p in params :
164- g = getattr (p , "grads" , None )
165- if g is None :
166- continue
167- src = g .flat if getattr (g , "flat" , None ) is not None else g .values
168- if src is None :
169- continue
170- arr = np .array (src , dtype = float )
171- if not np .isfinite (arr ).all ():
172- return True
173- return False
174-
175- def get_batches (images : Tensor , labels : Tensor , batch_size : int ):
179+ def _one_hot_labels (labels : list [int ], num_classes : int = 10 ) -> list [list [float ]]:
180+ one_hot = [[0.0 ] * num_classes for _ in labels ]
181+ for i , lbl in enumerate (labels ):
182+ one_hot [i ][int (lbl )] = 1.0
183+ return one_hot
184+
185+ y_train_one_hot = _one_hot_labels (y_train )
186+ y_test_one_hot = _one_hot_labels (y_test )
187+
188+ def get_batches (images , labels_one_hot , batch_size : int , * , shuffle = True ):
176189 """Yield full (batch_size, 784) images and (batch_size, 10) one-hot labels."""
177- assert images .values is not None and labels .values is not None
178- n_samples = len (images .values )
190+ n_samples = len (images )
179191 indices = list (range (n_samples ))
180- random .shuffle (indices )
192+ if shuffle :
193+ random .shuffle (indices )
181194
182195 # Drop the last partial batch to keep shapes constant.
183196 for start_idx in range (0 , n_samples - batch_size + 1 , batch_size ):
184197 batch_indices = indices [start_idx : start_idx + batch_size ]
185- batch_images = [images .values [i ] for i in batch_indices ]
186- batch_labels = [_one_hot (int (labels .values [i ])) for i in batch_indices ]
187- yield (
188- Tensor (batch_images , requires_grad = False ),
189- Tensor (batch_labels , requires_grad = False ),
198+ batch_images = [images [i ] for i in batch_indices ]
199+ batch_labels = [labels_one_hot [i ] for i in batch_indices ]
200+ yield batch_images , batch_labels
201+
202+ def get_eval_batches (images , labels_one_hot , batch_size : int ):
203+ """Yield eval batches, padding the last batch to keep shapes constant."""
204+ n_samples = len (images )
205+
206+ for start_idx in range (0 , n_samples , batch_size ):
207+ batch_indices = list (
208+ range (start_idx , min (start_idx + batch_size , n_samples ))
190209 )
210+ valid_count = len (batch_indices )
211+ if valid_count < batch_size :
212+ batch_indices .extend ([batch_indices [- 1 ]] * (batch_size - valid_count ))
213+
214+ batch_images = [images [i ] for i in batch_indices ]
215+ batch_labels = [labels_one_hot [i ] for i in batch_indices ]
216+ yield batch_images , batch_labels , valid_count
217+
218+ # Reuse model input/target tensors to avoid per-batch Tensor allocations.
219+ input_tensor = model .model_input_layer .layer_input_tensors
220+ assert model .targets is not None
221+ target_tensor = model .targets
191222
192223 for epoch in range (N_EPOCHS ):
193224 if epoch % 10 == 0 :
194225 print (f"Epoch { epoch + 1 } " )
195226
196- for id_batch , (X_batch , y_batch ) in enumerate (get_batches (X_train , y_train , BATCH_SIZE )):
227+ for id_batch , (batch_images , batch_labels ) in enumerate (
228+ get_batches (X_train , y_train_one_hot , BATCH_SIZE , shuffle = True )
229+ ):
197230 model .zero_grad ()
198231
199- logits = model (X_batch , execute = False )
200- assert model .targets is not None
201- model .targets .values = y_batch .values
232+ input_tensor .values = batch_images
233+ target_tensor .values = batch_labels
202234
203235 model .context .forward (recompute = True )
204236 loss = model .loss
205- model .backward ()
206- optim .step ()
237+ model .backward (device_optim = optim )
207238
208- if id_batch % 100 == 0 :
239+ if id_batch % 250 == 0 :
209240 loss_value = loss .item ()
210241 print (f"Loss: { loss_value :.4f} " )
211242
243+ model .eval ()
212244 correct = 0
213245 total = 0
214246 import numpy as np
215-
216- # Disable dropout/etc if eval existed, here just forward pass
217- for X_batch , y_batch in get_batches (X_test , y_test , BATCH_SIZE ):
218- logits = model (X_batch , execute = False )
247+
248+ for batch_images , batch_labels , valid_count in get_eval_batches (
249+ X_test , y_test_one_hot , BATCH_SIZE
250+ ):
251+ input_tensor .values = batch_images
219252 model .context .forward (recompute = True )
220-
221- vals = np . array ( logits . flat )
222- vals = vals .reshape ((BATCH_SIZE , 10 ))
223- predicted = np .argmax (vals , axis = 1 )
224-
225- y_vals = np .array (y_batch . flat ).reshape ((BATCH_SIZE , 10 ))
226- target = np .argmax (y_vals , axis = 1 )
227-
253+ logits = model . model_output_layer . layer_output
254+
255+ vals = np . array ( logits . flat ) .reshape ((BATCH_SIZE , 10 ))
256+ predicted = np .argmax (vals , axis = 1 )[: valid_count ]
257+
258+ y_vals = np .array (batch_labels ).reshape ((BATCH_SIZE , 10 ))
259+ target = np .argmax (y_vals , axis = 1 )[: valid_count ]
260+
228261 correct += np .sum (predicted == target )
229- total += len ( target )
262+ total += valid_count
230263
231264 print (f"Test Accuracy: { correct / total :.4f} " )
0 commit comments