|
5 | 5 | "colab": {
|
6 | 6 | "provenance": [],
|
7 | 7 | "gpuType": "T4",
|
8 |
| - "authorship_tag": "ABX9TyNJ/r/NM/jYEo7auYrkb7pZ", |
| 8 | + "authorship_tag": "ABX9TyPcTRx4Ut5mkhdIsmwBpSfx", |
9 | 9 | "include_colab_link": true
|
10 | 10 | },
|
11 | 11 | "kernelspec": {
|
|
184 | 184 | "id": "lT70RwhUq8O-"
|
185 | 185 | }
|
186 | 186 | },
|
| 187 | + {
| 188 | + "cell_type": "code",
| 189 | + "source": [
| 190 | + "import os\n",
| 191 | + "import numpy as np\n",
| 192 | + "import tensorflow as tf\n",
| 193 | + "from tensorflow.keras.preprocessing.image import load_img, img_to_array\n",
| 194 | + "from skimage.draw import random_shapes\n",
| 195 | + "from sklearn.model_selection import train_test_split\n",
| 196 | + "import matplotlib.pyplot as plt\n",
| 197 | + "\n",
| 198 | + "# Parameters\n",
| 199 | + "DATA_DIR = \"./places2\" # Path to your dataset\n",
| 200 | + "IMG_SIZE = (128, 128) # Target image size\n",
| 201 | + "BATCH_SIZE = 32\n",
| 202 | + "MASK_TYPE = \"random\" # Options: \"random\", \"rectangle\"\n",
| 203 | + "\n",
| 204 | + "# Generate a binary mask (1 = keep pixel, 0 = hole to inpaint)\n",
| 205 | + "def generate_mask(img_size, mask_type=\"random\"):\n",
| 206 | + "    # tf.numpy_function passes strings as bytes and shapes as arrays\n",
| 207 | + "    if isinstance(mask_type, bytes):\n",
| 208 | + "        mask_type = mask_type.decode()\n",
| 209 | + "    img_size = tuple(int(s) for s in img_size)\n",
| 210 | + "    if mask_type == \"random\":\n",
| 211 | + "        # Random irregular mask; channel_axis=None replaces the\n",
| 212 | + "        # multichannel flag removed in scikit-image >= 0.19\n",
| 213 | + "        mask, _ = random_shapes(img_size, max_shapes=5, min_size=50, max_size=100, channel_axis=None)\n",
| 214 | + "        mask = (mask == 255).astype(np.float32) # Background stays 1, shapes become 0 holes\n",
| 215 | + "    elif mask_type == \"rectangle\":\n",
| 216 | + "        # Rectangular hole at a random position\n",
| 217 | + "        mask = np.ones(img_size, dtype=np.float32)\n",
| 218 | + "        x1, y1 = np.random.randint(0, img_size[0] // 2), np.random.randint(0, img_size[1] // 2)\n",
| 219 | + "        x2, y2 = np.random.randint(x1 + 1, img_size[0]), np.random.randint(y1 + 1, img_size[1])\n",
| 220 | + "        mask[x1:x2, y1:y2] = 0\n",
| 221 | + "    else:\n",
| 222 | + "        raise ValueError(f\"Unknown mask_type: {mask_type}\")\n",
| 223 | + "    return mask\n",
| 224 | + "\n",
| 225 | + "# Load and preprocess a single image\n",
| 226 | + "def load_and_preprocess_image(img_path, img_size):\n",
| 227 | + "    # img_path arrives as bytes when called via tf.numpy_function\n",
| 228 | + "    if isinstance(img_path, bytes):\n",
| 229 | + "        img_path = img_path.decode()\n",
| 230 | + "    img = load_img(img_path, target_size=tuple(img_size))\n",
| 231 | + "    img = img_to_array(img).astype(np.float32) / 255.0 # Normalize to [0, 1]\n",
| 232 | + "    return img\n",
| 233 | + "\n",
| 234 | + "# Create a TensorFlow Dataset\n",
| 235 | + "def create_dataset(data_dir, img_size, batch_size, mask_type=\"random\"):\n",
| 236 | + "    # Get list of image paths\n",
| 237 | + "    image_paths = [os.path.join(data_dir, img_name) for img_name in os.listdir(data_dir)\n",
| 238 | + "                   if img_name.lower().endswith(('.jpg', '.png', '.jpeg'))]\n",
| 239 | + "\n",
| 240 | + "    if not image_paths:\n",
| 241 | + "        raise ValueError(f\"No images found in {data_dir}. Please check the dataset path.\")\n",
| 242 | + "\n",
| 243 | + "    # Split into training and validation sets\n",
| 244 | + "    train_paths, val_paths = train_test_split(image_paths, test_size=0.2, random_state=42)\n",
| 245 | + "\n",
| 246 | + "    # Function to load and preprocess images and masks\n",
| 247 | + "    def process_image(img_path):\n",
| 248 | + "        # Load and preprocess image\n",
| 249 | + "        img = tf.numpy_function(load_and_preprocess_image, [img_path, img_size], tf.float32)\n",
| 250 | + "        img.set_shape(img_size + (3,)) # Set shape explicitly\n",
| 251 | + "\n",
| 252 | + "        # Generate mask\n",
| 253 | + "        mask = tf.numpy_function(generate_mask, [img_size, mask_type], tf.float32)\n",
| 254 | + "        mask.set_shape(img_size) # Set shape explicitly\n",
| 255 | + "\n",
| 256 | + "        # Apply mask to image (zero out hole pixels)\n",
| 257 | + "        masked_img = img * tf.expand_dims(mask, axis=-1)\n",
| 258 | + "\n",
| 259 | + "        return masked_img, img, mask\n",
| 260 | + "\n",
| 261 | + "    # Create TensorFlow Datasets; shuffle the training paths\n",
| 262 | + "    train_dataset = tf.data.Dataset.from_tensor_slices(train_paths).shuffle(len(train_paths))\n",
| 263 | + "    train_dataset = train_dataset.map(process_image, num_parallel_calls=tf.data.AUTOTUNE)\n",
| 264 | + "    train_dataset = train_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)\n",
| 265 | + "\n",
| 266 | + "    val_dataset = tf.data.Dataset.from_tensor_slices(val_paths)\n",
| 267 | + "    val_dataset = val_dataset.map(process_image, num_parallel_calls=tf.data.AUTOTUNE)\n",
| 268 | + "    val_dataset = val_dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)\n",
| 269 | + "\n",
| 270 | + "    return train_dataset, val_dataset\n",
| 271 | + "\n",
| 272 | + "# Create the dataset\n",
| 273 | + "try:\n",
| 274 | + "    train_dataset, val_dataset = create_dataset(DATA_DIR, IMG_SIZE, BATCH_SIZE, MASK_TYPE)\n",
| 275 | + "    print(\"Dataset created successfully.\")\n",
| 276 | + "except Exception as e:\n",
| 277 | + "    print(f\"Error creating dataset: {e}\")\n",
| 278 | + "\n",
| 279 | + "# Visualize a sample batch\n",
| 280 | + "def visualize_batch(dataset):\n",
| 281 | + "    for masked_images, original_images, masks in dataset.take(1):\n",
| 282 | + "        plt.figure(figsize=(10, 5))\n",
| 283 | + "        for i in range(min(3, masked_images.shape[0])): # Display up to 3 samples\n",
| 284 | + "            plt.subplot(3, 3, i * 3 + 1)\n",
| 285 | + "            plt.title(\"Masked Image\")\n",
| 286 | + "            plt.imshow(masked_images[i])\n",
| 287 | + "            plt.axis(\"off\")\n",
| 288 | + "\n",
| 289 | + "            plt.subplot(3, 3, i * 3 + 2)\n",
| 290 | + "            plt.title(\"Original Image\")\n",
| 291 | + "            plt.imshow(original_images[i])\n",
| 292 | + "            plt.axis(\"off\")\n",
| 293 | + "\n",
| 294 | + "            plt.subplot(3, 3, i * 3 + 3)\n",
| 295 | + "            plt.title(\"Mask\")\n",
| 296 | + "            plt.imshow(masks[i], cmap='gray')\n",
| 297 | + "            plt.axis(\"off\")\n",
| 298 | + "        plt.show()\n",
| 299 | + "\n",
| 300 | + "# Visualize a batch from the training dataset\n",
| 301 | + "if 'train_dataset' in locals():\n",
| 302 | + "    visualize_batch(train_dataset)\n",
| 303 | + "else:\n",
| 304 | + "    print(\"Dataset not available for visualization.\")"
| 305 | + ],
| 306 | + "metadata": {
| 307 | + "id": "IB80Tsl8GEDY"
| 308 | + },
| 309 | + "execution_count": null,
| 310 | + "outputs": []
| 311 | + },
187 | 312 | {
188 | 313 | "cell_type": "markdown",
189 | 314 | "source": [