Skip to content

Commit 0cae63a

Browse files
committed
modified: examples/pytorch/mnist.py
modified: examples/tensorops/mnist.py modified: pyproject.toml modified: readme.md modified: tensorops/__init__.py modified: tensorops/loss.py modified: tensorops/model.py modified: tensorops/optim.py modified: tensorops/src/kernel.rs modified: tensorops/src/main.rs modified: tensorops/src/runtime.rs modified: tensorops/tensor.py modified: tensorops/utils/onnx_exporter.py deleted: test_ce.py deleted: test_ce2.py deleted: test_log.py deleted: test_log2.py deleted: test_log3.py deleted: test_matmul.py new file: test_suite.py
1 parent 22e55c7 commit 0cae63a

20 files changed

Lines changed: 2373 additions & 357 deletions

examples/pytorch/mnist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def forward(self, x):
125125
model = MNISTModel(2, 256).to(device)
126126

127127
model.train()
128-
BATCH_SIZE = 256
128+
BATCH_SIZE = 1024
129129
N_EPOCHS = 100
130130

131131
dataset = TensorDataset(X_train, y_train)

examples/tensorops/mnist.py

Lines changed: 112 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77
import struct
88
from urllib.request import urlretrieve
99

10-
import tensorops
10+
# Device optimizer requires DirectInput caching for all params (including biases).
11+
os.environ.setdefault("TENSOROPS_DIRECT_INPUT_CACHE", "1")
12+
os.environ.setdefault("TENSOROPS_DIRECT_INPUT_CACHE_MIN_LEN", "1")
13+
1114
from tensorops.loss import CrossEntropyLoss
12-
from tensorops.optim import AdamW
13-
from tensorops.tensor import Tensor, TensorContext, LeakyReLU
15+
from tensorops.optim import AdamWDevice
16+
from tensorops.tensor import LeakyReLU, Tensor
1417
from tensorops.utils.models import SequentialModel
15-
from tensorops.utils.tensorutils import PlotterUtil
1618

1719
MNIST_URLS = {
1820
"train_images": "https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz",
@@ -48,8 +50,6 @@ def extract_images(filepath):
4850
images = []
4951
for _ in range(num_images):
5052
image = list(f.read(rows * cols))
51-
# Normalise to match PyTorch (divided by 255)
52-
image = [x / 255.0 for x in image]
5353
images.append(image)
5454
return images
5555

@@ -84,7 +84,28 @@ def load_mnist():
8484
return (train_images, train_labels), (test_images, test_labels)
8585

8686

87-
(train_images, train_labels), (test_images, test_labels) = load_mnist()
87+
def _normalise_images(images):
88+
return [[x / 255.0 for x in image] for image in images]
89+
90+
91+
def init_layer_params(layer, rng):
92+
"""Match the PyTorch init_network_params helper (uniform -1..1)."""
93+
weights_out_in = [
94+
[rng.uniform(-1.0, 1.0) for _ in range(layer.num_input_tensors)]
95+
for _ in range(layer.num_output_tensors)
96+
]
97+
# TensorOps expects (in, out) weights for x @ W.
98+
weights_in_out = [list(col) for col in zip(*weights_out_in)]
99+
layer.output_weights.values = weights_in_out
100+
layer.output_bias.values = [
101+
[rng.uniform(-1.0, 1.0) for _ in range(layer.num_output_tensors)]
102+
]
103+
104+
105+
def init_model_params(model, seed=42):
106+
rng = random.Random(seed)
107+
for layer in model.model_layers:
108+
init_layer_params(layer, rng)
88109

89110

90111
class MNISTModel(SequentialModel):
@@ -93,12 +114,11 @@ def __init__(
93114
num_hidden_layers: int,
94115
num_hidden_nodes: int,
95116
loss_criterion,
96-
seed: int | None = None,
97117
activation_function=LeakyReLU,
98118
*,
99119
batch_size: int = 1,
100120
) -> None:
101-
super().__init__(loss_criterion, seed, batch_size=batch_size)
121+
super().__init__(loss_criterion, None, batch_size=batch_size)
102122
self.activation_function = activation_function
103123
self.num_hidden_layers = num_hidden_layers
104124
with self.context:
@@ -107,125 +127,138 @@ def __init__(
107127
self.add_layer(
108128
num_hidden_nodes, num_hidden_nodes, self.activation_function
109129
)
110-
# Final layer emits logits; softmax is handled inside CrossEntropyLoss.
111-
self.add_layer(num_hidden_nodes, 10, None)
130+
# Apply activation on the output layer to match the PyTorch example.
131+
self.add_layer(num_hidden_nodes, 10, self.activation_function)
112132
# CrossEntropyLoss expects (logits, target).
113133
self.loss = self.loss_criterion(
114134
self.model_output_layer.layer_output, self.targets
115135
)
116136

117137
def forward(self, model_inputs: Tensor) -> Tensor: # type: ignore[override]
118138
with self.context:
119-
# Input must be (batch_size, 784) for this model.
120-
for layer in self.model_layers:
121-
layer.forward(model_inputs)
122-
model_inputs = layer.layer_output
123-
return model_inputs
139+
# Update only the input placeholder; the graph is already wired.
140+
if self.model_input_layer is None or self.model_output_layer is None:
141+
raise ValueError("Model layers are not initialised")
142+
if isinstance(model_inputs, Tensor):
143+
self.model_input_layer.layer_input_tensors.values = model_inputs.values
144+
else:
145+
self.model_input_layer.layer_input_tensors.values = model_inputs
146+
return self.model_output_layer.layer_output
124147

125148

126-
with TensorContext(device=tensorops.device.TensorOpsDevice.APPLE) as tc:
149+
if __name__ == "__main__":
150+
random.seed(42)
151+
(train_images, train_labels), (test_images, test_labels) = load_mnist()
152+
153+
train_images = _normalise_images(train_images)
154+
test_images = _normalise_images(test_images)
155+
127156
X_train, y_train, X_test, y_test = (
128-
Tensor(train_images, requires_grad=False),
129-
Tensor(train_labels, requires_grad=False),
130-
Tensor(test_images, requires_grad=False),
131-
Tensor(test_labels, requires_grad=False),
157+
train_images,
158+
train_labels,
159+
test_images,
160+
test_labels,
132161
)
133162

134-
print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)
135-
136-
# impl model
137163
BATCH_SIZE = 256
138-
N_EPOCHS = 5
164+
N_EPOCHS = 100
139165

140166
model = MNISTModel(
141167
2,
142168
256,
143169
CrossEntropyLoss(),
144-
seed=42,
145170
batch_size=BATCH_SIZE,
146171
activation_function=LeakyReLU,
147172
)
148-
optim = AdamW(model.get_weights(), lr=2e-4)
149-
# Enable gradient clipping to stabilise updates
150-
optim.grad_clip_norm = 0.5
151-
optim.grad_clip_value = 0.5
173+
init_model_params(model, seed=42)
174+
152175
model.train()
176+
optim = AdamWDevice(model.get_weights(), lr=2e-4)
153177

154178
# Helper to create fixed-size batches (graph uses a fixed batch_size)
155-
def _one_hot(label: int, num_classes: int = 10) -> list[float]:
156-
vec = [0.0] * num_classes
157-
vec[int(label)] = 1.0
158-
return vec
159-
160-
def _has_nonfinite_grad(params: list[Tensor]) -> bool:
161-
import numpy as np
162-
163-
for p in params:
164-
g = getattr(p, "grads", None)
165-
if g is None:
166-
continue
167-
src = g.flat if getattr(g, "flat", None) is not None else g.values
168-
if src is None:
169-
continue
170-
arr = np.array(src, dtype=float)
171-
if not np.isfinite(arr).all():
172-
return True
173-
return False
174-
175-
def get_batches(images: Tensor, labels: Tensor, batch_size: int):
179+
def _one_hot_labels(labels: list[int], num_classes: int = 10) -> list[list[float]]:
180+
one_hot = [[0.0] * num_classes for _ in labels]
181+
for i, lbl in enumerate(labels):
182+
one_hot[i][int(lbl)] = 1.0
183+
return one_hot
184+
185+
y_train_one_hot = _one_hot_labels(y_train)
186+
y_test_one_hot = _one_hot_labels(y_test)
187+
188+
def get_batches(images, labels_one_hot, batch_size: int, *, shuffle=True):
176189
"""Yield full (batch_size, 784) images and (batch_size, 10) one-hot labels."""
177-
assert images.values is not None and labels.values is not None
178-
n_samples = len(images.values)
190+
n_samples = len(images)
179191
indices = list(range(n_samples))
180-
random.shuffle(indices)
192+
if shuffle:
193+
random.shuffle(indices)
181194

182195
# Drop the last partial batch to keep shapes constant.
183196
for start_idx in range(0, n_samples - batch_size + 1, batch_size):
184197
batch_indices = indices[start_idx : start_idx + batch_size]
185-
batch_images = [images.values[i] for i in batch_indices]
186-
batch_labels = [_one_hot(int(labels.values[i])) for i in batch_indices]
187-
yield (
188-
Tensor(batch_images, requires_grad=False),
189-
Tensor(batch_labels, requires_grad=False),
198+
batch_images = [images[i] for i in batch_indices]
199+
batch_labels = [labels_one_hot[i] for i in batch_indices]
200+
yield batch_images, batch_labels
201+
202+
def get_eval_batches(images, labels_one_hot, batch_size: int):
203+
"""Yield eval batches, padding the last batch to keep shapes constant."""
204+
n_samples = len(images)
205+
206+
for start_idx in range(0, n_samples, batch_size):
207+
batch_indices = list(
208+
range(start_idx, min(start_idx + batch_size, n_samples))
190209
)
210+
valid_count = len(batch_indices)
211+
if valid_count < batch_size:
212+
batch_indices.extend([batch_indices[-1]] * (batch_size - valid_count))
213+
214+
batch_images = [images[i] for i in batch_indices]
215+
batch_labels = [labels_one_hot[i] for i in batch_indices]
216+
yield batch_images, batch_labels, valid_count
217+
218+
# Reuse model input/target tensors to avoid per-batch Tensor allocations.
219+
input_tensor = model.model_input_layer.layer_input_tensors
220+
assert model.targets is not None
221+
target_tensor = model.targets
191222

192223
for epoch in range(N_EPOCHS):
193224
if epoch % 10 == 0:
194225
print(f"Epoch {epoch + 1}")
195226

196-
for id_batch, (X_batch, y_batch) in enumerate(get_batches(X_train, y_train, BATCH_SIZE)):
227+
for id_batch, (batch_images, batch_labels) in enumerate(
228+
get_batches(X_train, y_train_one_hot, BATCH_SIZE, shuffle=True)
229+
):
197230
model.zero_grad()
198231

199-
logits = model(X_batch, execute=False)
200-
assert model.targets is not None
201-
model.targets.values = y_batch.values
232+
input_tensor.values = batch_images
233+
target_tensor.values = batch_labels
202234

203235
model.context.forward(recompute=True)
204236
loss = model.loss
205-
model.backward()
206-
optim.step()
237+
model.backward(device_optim=optim)
207238

208-
if id_batch % 100 == 0:
239+
if id_batch % 250 == 0:
209240
loss_value = loss.item()
210241
print(f"Loss: {loss_value:.4f}")
211242

243+
model.eval()
212244
correct = 0
213245
total = 0
214246
import numpy as np
215-
216-
# Disable dropout/etc if eval existed, here just forward pass
217-
for X_batch, y_batch in get_batches(X_test, y_test, BATCH_SIZE):
218-
logits = model(X_batch, execute=False)
247+
248+
for batch_images, batch_labels, valid_count in get_eval_batches(
249+
X_test, y_test_one_hot, BATCH_SIZE
250+
):
251+
input_tensor.values = batch_images
219252
model.context.forward(recompute=True)
220-
221-
vals = np.array(logits.flat)
222-
vals = vals.reshape((BATCH_SIZE, 10))
223-
predicted = np.argmax(vals, axis=1)
224-
225-
y_vals = np.array(y_batch.flat).reshape((BATCH_SIZE, 10))
226-
target = np.argmax(y_vals, axis=1)
227-
253+
logits = model.model_output_layer.layer_output
254+
255+
vals = np.array(logits.flat).reshape((BATCH_SIZE, 10))
256+
predicted = np.argmax(vals, axis=1)[:valid_count]
257+
258+
y_vals = np.array(batch_labels).reshape((BATCH_SIZE, 10))
259+
target = np.argmax(y_vals, axis=1)[:valid_count]
260+
228261
correct += np.sum(predicted == target)
229-
total += len(target)
262+
total += valid_count
230263

231264
print(f"Test Accuracy: {correct / total:.4f}")

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ dependencies = [
88
"matplotlib==3.10.5",
99
"matplotlib-inline==0.1.6",
1010
"networkx==3.0rc1",
11+
"onnx==1.14.1",
1112
"tqdm==4.66.5",
1213
"setuptools==75.8.0",
1314
"maturin==1.8.3",

0 commit comments

Comments
 (0)