Hi list,

I have been testing a TPU slice (v3-32, 4 workers) with a simple MNIST model in JAX. The MNIST data came from Kaggle, and it was placed on every worker as:

/home/martin/mnist_extracted/test
/home/martin/mnist_extracted/train
/home/martin/mnist_extracted/val

My code (please see below) runs and seems to WORK, but it is slow for a TPU slice! I increased the batch size etc., but it didn't help. If JAX folks can help me with this, I would be very grateful. Any hint will be greatly appreciated.

Thanks in advance,
k
import jax
import jax.numpy as jnp
from functools import partial
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
from PIL import Image
import pickle
from jax.nn import swish, logsumexp
import time
# Initialize the slice
jax.distributed.initialize()
tpu_devices = jax.devices("tpu")
# Print information about each found TPU device
if tpu_devices:
    print("Available TPU devices:")
    for i, device in enumerate(tpu_devices):
        print(f"  Device {i}: {device}")
else:
    print("No TPU devices found.")
#---------------------------------------------------------------------
from tensorflow.keras import layers, models
import keras
train_ds = keras.preprocessing.image_dataset_from_directory(
    "/home/martin/mnist_extracted/train",
    image_size=(28, 28),
    batch_size=None,
    color_mode='grayscale'
)
test_ds = keras.preprocessing.image_dataset_from_directory(
    "/home/martin/mnist_extracted/test",
    image_size=(28, 28),
    batch_size=None,
    color_mode='grayscale',
    shuffle=False,
)
data_train=train_ds
data_test=test_ds
#---------------------------------------------------------------------
HEIGHT = 28
WIDTH = 28
CHANNELS = 1
NUM_PIXELS = HEIGHT * WIDTH * CHANNELS
NUM_LABELS = 10
NUM_DEVICES = jax.local_device_count() # jax.device_count()
BATCH_SIZE = 512 #128 #32
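# Each worker feeds its own local devices, so every tf.data batch below holds
# NUM_DEVICES * BATCH_SIZE images: one BATCH_SIZE chunk per local device.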
def preprocess(img, label):
    """Scale pixel values to [0, 1]."""
    return (tf.cast(img, tf.float32) / 255.0), label
train_data = tfds.as_numpy(
    data_train.map(preprocess).batch(NUM_DEVICES * BATCH_SIZE).prefetch(1)
)
test_data = tfds.as_numpy(
    data_test.map(preprocess).batch(NUM_DEVICES * BATCH_SIZE).prefetch(1)
)
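# Note: image_dataset_from_directory re-reads and re-decodes the PNG files on
# every pass over the dataset. Assuming the full dataset fits in host memory,
# one tweak worth trying is caching the preprocessed images, e.g.:
#   tfds.as_numpy(data_train.map(preprocess).cache()
#                 .batch(NUM_DEVICES * BATCH_SIZE).prefetch(1))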
from jax import grad, jit, vmap, value_and_grad
from jax import random
from jax.nn import swish, logsumexp, one_hot
LAYER_SIZES = [28*28, 512, 10]
PARAM_SCALE = 0.01
def init_network_params(sizes, key=random.PRNGKey(0), scale=1e-2):
    def random_layer_params(m, n, key, scale=1e-2):
        w_key, b_key = random.split(key)
        return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))
    keys = random.split(key, len(sizes))
    return [random_layer_params(m, n, k, scale) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]
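# Weights are created with shape (n, m) = (fan_out, fan_in), matching the
# jnp.dot(w, activations) convention used in predict() below.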
init_params = init_network_params(LAYER_SIZES, random.PRNGKey(0), scale=PARAM_SCALE)
def predict(params, image):
    activations = image
    for w, b in params[:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = swish(outputs)
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits
batched_predict = vmap(predict, in_axes=(None, 0))
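# vmap adds the batch dimension: params are broadcast to every example
# (in_axes=None) while images are mapped over their leading axis (in_axes=0).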
# ML PARAMETERS
INIT_LR = 0.1  # initial learning rate used in update() below (earlier runs tried 1e-3 and 1.0)
DECAY_RATE = 0.95
DECAY_STEPS = 5
NUM_EPOCHS = 50
def loss(params, images, targets):
    logits = batched_predict(params, images)
    # Normalize per image (axis=1, keepdims=True), not over the whole batch
    log_preds = logits - logsumexp(logits, axis=1, keepdims=True)
    return -jnp.mean(targets * log_preds)
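# This is softmax cross-entropy with one-hot targets; jnp.mean averages over
# both the batch and class axes, i.e. the usual CE scaled by a constant
# 1/NUM_LABELS, which the learning rate absorbs.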
@partial(jax.pmap, axis_name='devices', in_axes=(None, 0, 0, None), out_axes=(None, 0))
def update(params, x, y, epoch_number):
    loss_value, grads = value_and_grad(loss)(params, x, y)
    # Average the gradients across all devices in the slice
    grads = [(jax.lax.pmean(dw, 'devices'), jax.lax.pmean(db, 'devices')) for dw, db in grads]
    current_lr = INIT_LR * DECAY_RATE ** (epoch_number / DECAY_STEPS)
    return [(w - current_lr * dw, b - current_lr * db) for (w, b), (dw, db) in zip(params, grads)], loss_value
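# update() is plain data-parallel SGD: each device computes gradients on its
# own shard, pmean averages them across the slice, and every device applies
# the same step, so the params stay identical everywhere (hence the None in
# out_axes).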
train_data_iter = iter(train_data)
x, y = next(train_data_iter)
x = jnp.reshape(x, (NUM_DEVICES, BATCH_SIZE, NUM_PIXELS))
y = jnp.reshape(one_hot(y, NUM_LABELS), (NUM_DEVICES, BATCH_SIZE, NUM_LABELS))
print(x.shape, y.shape)
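# One warm-up step so the pmap compilation happens before the timed training
# loop; the returned params and loss are discarded (training below restarts
# from init_params).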
updated_params, loss_value = update(init_params, x, y, 0)
@jit
def batch_accuracy(params, images, targets):
    images = jnp.reshape(images, (len(images), NUM_PIXELS))
    predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
    return jnp.mean(predicted_class == targets)
def accuracy(params, data):
    accs = []
    for images, targets in data:
        accs.append(batch_accuracy(params, images, targets))
    return jnp.mean(jnp.array(accs))
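# Note: batch_accuracy runs on a single device per worker and accuracy()
# re-iterates the whole tf.data pipeline, so evaluating the full training set
# every epoch adds substantial wall-clock time on top of the reported epoch
# time.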
print("INITIAL TIME--------------------------: Current Time", time.time())
btim=time.time()
params = init_params
for epoch in range(NUM_EPOCHS):
    start_time = time.time()
    losses = []
    for x, y in train_data:
        num_elements = len(y)
        # The reshape assumes num_elements is divisible by NUM_DEVICES,
        # including the final (possibly partial) batch
        x = jnp.reshape(x, (NUM_DEVICES, num_elements // NUM_DEVICES, NUM_PIXELS))
        y = jnp.reshape(one_hot(y, NUM_LABELS), (NUM_DEVICES, num_elements // NUM_DEVICES, NUM_LABELS))
        params, loss_value = update(params, x, y, epoch)
        losses.append(jnp.sum(loss_value))
    epoch_time = time.time() - start_time
    train_acc = accuracy(params, train_data)
    test_acc = accuracy(params, test_data)
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set loss {}".format(jnp.mean(jnp.array(losses))))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))
print("END -------------------------------- Elapsed Time:",time.time()-btim)
#===============
#SAVE TO MODEL
#===============
with open("mnist_params.pkl", "wb") as f:
    pickle.dump(params, f)
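# On a multi-worker slice every process writes its own local copy of this
# file; after the pmean'd updates the params are identical on all workers,
# so any copy can be used for inference.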
time.sleep(5)  # give every worker time to finish writing before reading the file back
# LOAD
with open("mnist_params.pkl", "rb") as f:
    params = pickle.load(f)
# Re-define predict (identical to the definition above) so this inference
# section can run standalone after loading the pickled params
def predict(params, image):
    activations = image
    for w, b in params[:-1]:
        outputs = jnp.dot(w, activations) + b
        activations = swish(outputs)
    final_w, final_b = params[-1]
    logits = jnp.dot(final_w, activations) + final_b
    return logits
# ------------ Load your PNG image ------------
img = Image.open("/home/martin/mnist_extracted/test/7/12.png").convert("L")
img = img.resize((28, 28))
img_arr = np.array(img, dtype=np.float32) / 255.0
# MNIST digits are white-on-black; invert if the input looks black-on-white
if img_arr.mean() > 0.5:
    img_arr = 1.0 - img_arr
img_flat = img_arr.reshape(-1)
img_jax = jnp.array(img_flat)
# ------------ Predict ------------
logits = predict(params, img_jax)
digit = int(jnp.argmax(logits))
print("\nPredicted digit:", digit)