
Commit 104bbd0

load and update
1 parent 7d274f0 commit 104bbd0

File tree

1 file changed: +116 -1 lines changed


CM_GAN_Jan5.ipynb

+116 -1
@@ -5,7 +5,7 @@
 "colab": {
 "provenance": [],
 "gpuType": "T4",
-"authorship_tag": "ABX9TyNJ/r/NM/jYEo7auYrkb7pZ",
+"authorship_tag": "ABX9TyPcTRx4Ut5mkhdIsmwBpSfx",
 "include_colab_link": true
 },
 "kernelspec": {
@@ -184,6 +184,121 @@
 "id": "lT70RwhUq8O-"
 }
 },
+{
+"cell_type": "code",
+"source": [
+"import os\n",
+"import numpy as np\n",
+"import tensorflow as tf\n",
+"from tensorflow.keras.preprocessing.image import load_img, img_to_array\n",
+"from skimage.draw import random_shapes\n",
+"from sklearn.model_selection import train_test_split\n",
+"import matplotlib.pyplot as plt\n",
+"\n",
+"# Parameters\n",
+"DATA_DIR = \"./places2\" # Path to your dataset\n",
+"IMG_SIZE = (128, 128) # Target image size\n",
+"BATCH_SIZE = 32\n",
+"MASK_TYPE = \"random\" # Options: \"random\", \"rectangle\"\n",
+"\n",
+"# Generate a binary mask\n",
+"def generate_mask(img_size, mask_type=\"random\"):\n",
+"    if mask_type == \"random\":\n",
+"        # Generate random irregular mask\n",
+"        mask, _ = random_shapes(img_size, max_shapes=5, min_size=50, max_size=100, multichannel=False)\n",
+"        mask = (mask == 255).astype(np.float32) # Convert to binary mask\n",
+"    elif mask_type == \"rectangle\":\n",
+"        # Generate rectangular mask\n",
+"        mask = np.ones(img_size, dtype=np.float32)\n",
+"        x1, y1 = np.random.randint(0, img_size[0] // 2), np.random.randint(0, img_size[1] // 2)\n",
+"        x2, y2 = np.random.randint(x1, img_size[0]), np.random.randint(y1, img_size[1])\n",
+"        mask[x1:x2, y1:y2] = 0\n",
+"    return mask\n",
+"\n",
+"# Load and preprocess a single image\n",
+"def load_and_preprocess_image(img_path, img_size):\n",
+"    img = load_img(img_path, target_size=img_size)\n",
+"    img = img_to_array(img) / 255.0 # Normalize to [0, 1]\n",
+"    return img\n",
+"\n",
+"# Create a TensorFlow Dataset\n",
+"def create_dataset(data_dir, img_size, batch_size, mask_type=\"random\"):\n",
+"    # Get list of image paths\n",
+"    image_paths = [os.path.join(data_dir, img_name) for img_name in os.listdir(data_dir)\n",
+"                   if img_name.endswith(('.jpg', '.png', '.jpeg'))]\n",
+"\n",
+"    if not image_paths:\n",
+"        raise ValueError(f\"No images found in {data_dir}. Please check the dataset path.\")\n",
+"\n",
+"    # Split into training and validation sets\n",
+"    train_paths, val_paths = train_test_split(image_paths, test_size=0.2, random_state=42)\n",
+"\n",
+"    # Function to load and preprocess images and masks\n",
+"    def process_image(img_path):\n",
+"        # Load and preprocess image\n",
+"        img = tf.numpy_function(load_and_preprocess_image, [img_path, img_size], tf.float32)\n",
+"        img.set_shape(img_size + (3,)) # Set shape explicitly\n",
+"\n",
+"        # Generate mask\n",
+"        mask = tf.numpy_function(generate_mask, [img_size, mask_type], tf.float32)\n",
+"        mask.set_shape(img_size) # Set shape explicitly\n",
+"\n",
+"        # Apply mask to image\n",
+"        masked_img = img * tf.expand_dims(mask, axis=-1)\n",
+"\n",
+"        return masked_img, img, mask\n",
+"\n",
+"    # Create TensorFlow Dataset\n",
+"    train_dataset = tf.data.Dataset.from_tensor_slices(train_paths)\n",
+"    train_dataset = train_dataset.map(process_image, num_parallel_calls=tf.data.AUTOTUNE)\n",
+"    train_dataset = train_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)\n",
+"\n",
+"    val_dataset = tf.data.Dataset.from_tensor_slices(val_paths)\n",
+"    val_dataset = val_dataset.map(process_image, num_parallel_calls=tf.data.AUTOTUNE)\n",
+"    val_dataset = val_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)\n",
+"\n",
+"    return train_dataset, val_dataset\n",
+"\n",
+"# Create the dataset\n",
+"try:\n",
+"    train_dataset, val_dataset = create_dataset(DATA_DIR, IMG_SIZE, BATCH_SIZE, MASK_TYPE)\n",
+"    print(\"Dataset created successfully.\")\n",
+"except Exception as e:\n",
+"    print(f\"Error creating dataset: {e}\")\n",
+"\n",
+"# Visualize a sample batch\n",
+"def visualize_batch(dataset):\n",
+"    for masked_images, original_images, masks in dataset.take(1):\n",
+"        plt.figure(figsize=(10, 5))\n",
+"        for i in range(3): # Display 3 samples\n",
+"            plt.subplot(3, 3, i * 3 + 1)\n",
+"            plt.title(\"Masked Image\")\n",
+"            plt.imshow(masked_images[i])\n",
+"            plt.axis(\"off\")\n",
+"\n",
+"            plt.subplot(3, 3, i * 3 + 2)\n",
+"            plt.title(\"Original Image\")\n",
+"            plt.imshow(original_images[i])\n",
+"            plt.axis(\"off\")\n",
+"\n",
+"            plt.subplot(3, 3, i * 3 + 3)\n",
+"            plt.title(\"Mask\")\n",
+"            plt.imshow(masks[i], cmap='gray')\n",
+"            plt.axis(\"off\")\n",
+"        plt.show()\n",
+"\n",
+"# Visualize a batch from the training dataset\n",
+"if 'train_dataset' in locals():\n",
+"    visualize_batch(train_dataset)\n",
+"else:\n",
+"    print(\"Dataset not available for visualization.\")"
+],
+"metadata": {
+"id": "IB80Tsl8GEDY"
+},
+"execution_count": null,
+"outputs": []
+},
 {
 "cell_type": "markdown",
 "source": [

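A note on the cell added above: tf.numpy_function hands string tensors to the wrapped Python callables as numpy bytes objects, so inside process_image both img_path and mask_type arrive as bytes rather than str. The comparison mask_type == "random" then matches neither branch, and generate_mask raises an UnboundLocalError because mask is never assigned. The sketch below is not part of this commit; it shows one way the helpers could decode their arguments first, reusing the names and IMG_SIZE from the cell above and keeping only the rectangle branch for brevity. Separately, depending on the installed scikit-image version, random_shapes may expect channel_axis=None in place of the deprecated multichannel=False.

import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import load_img, img_to_array

IMG_SIZE = (128, 128)  # assumed; matches the cell above

def load_and_preprocess_image(img_path, img_size):
    # tf.numpy_function delivers string tensors as numpy bytes objects
    if isinstance(img_path, bytes):
        img_path = img_path.decode("utf-8")
    img = load_img(img_path, target_size=tuple(int(s) for s in img_size))
    return (img_to_array(img) / 255.0).astype(np.float32)  # normalize to [0, 1]

def generate_mask(img_size, mask_type=b"rectangle"):
    # mask_type also arrives as bytes; decode before comparing to a str literal
    if isinstance(mask_type, bytes):
        mask_type = mask_type.decode("utf-8")
    h, w = (int(s) for s in img_size)
    mask = np.ones((h, w), dtype=np.float32)
    if mask_type == "rectangle":
        x1, y1 = np.random.randint(0, h // 2), np.random.randint(0, w // 2)
        x2, y2 = np.random.randint(x1 + 1, h), np.random.randint(y1 + 1, w)
        mask[x1:x2, y1:y2] = 0.0  # zero out the rectangular hole
    return mask

def process_image(img_path):
    img = tf.numpy_function(load_and_preprocess_image, [img_path, IMG_SIZE], tf.float32)
    img.set_shape(IMG_SIZE + (3,))
    mask = tf.numpy_function(generate_mask, [IMG_SIZE, "rectangle"], tf.float32)
    mask.set_shape(IMG_SIZE)
    masked_img = img * tf.expand_dims(mask, axis=-1)  # keep pixels where mask == 1
    return masked_img, img, mask

With the arguments decoded, process_image can be mapped over a tf.data.Dataset of path strings exactly as the committed cell already does.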
0 commit comments
