Skip to content
This repository was archived by the owner on Jan 3, 2024. It is now read-only.

Commit 140fe1e

Browse files
committed
Add fragmented code
1 parent f45365d commit 140fe1e

File tree

4 files changed

+276
-0
lines changed

4 files changed

+276
-0
lines changed

pytorch-superpixels/__init__.py

Whitespace-only changes.

pytorch-superpixels/metrics.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
'''For superpixel validation'''
2+
3+
4+
def mask_accuracy(target, mask):
5+
target_s = torch.zeros_like(target)
6+
superpixels = mask.unique().numel()
7+
for superpixel in range(superpixels):
8+
# Define mask for cluster idx
9+
segment_mask = mask == superpixel
10+
# Take slices to select image, apply mask, mode for majority class
11+
target_s[segment_mask] = target[segment_mask].view(-1).mode()[0]
12+
accuracy = torch.mean((target == target_s).float())
13+
return accuracy
14+
15+
16+
def dataset_accuracy(superpixels):
17+
# Generate image list
18+
if superpixels is not None:
19+
image_list = get_image_list('trainval_super')
20+
else:
21+
image_list = get_image_list()
22+
23+
mask_acc = 0
24+
mask_dir = "SegmentationClass/{}_sp".format(superpixels)
25+
target_dir = "SegmentationClass/pre_encoded"
26+
for image_number in tqdm(image_list):
27+
mask_path = join(root, mask_dir, image_number + ".pt")
28+
target_path = join(root, target_dir, image_number + ".png")
29+
mask = torch.load(mask_path)
30+
target = io.imread(target_path)
31+
target = torch.from_numpy(target)
32+
mask_acc += mask_accuracy(target, mask)
33+
dataset_acc = mask_acc / len(image_list)
34+
return dataset_acc
35+
36+
37+
def find_smallest_object():
38+
# Generate image list
39+
image_list = get_image_list()
40+
smallest_object = 1e6
41+
for image_number in tqdm(image_list):
42+
target_name = image_number + ".png"
43+
target_path = join(root, "SegmentationClass/pre_encoded", target_name)
44+
target = io.imread(target_path)
45+
target = torch.from_numpy(target)
46+
object_size = torch.ne(target, 0).sum()
47+
if object_size < smallest_object:
48+
smallest_object = object_size
49+
print(smallest_object, image_number)
50+
return smallest_object
51+
52+
53+
def find_usable_images(split, superpixels):
54+
# Generate image list
55+
image_list = get_image_list(split)
56+
usable = []
57+
target_dir = join(
58+
root,
59+
"SegmentationClass/pre_encoded_{}_sp".format(superpixels)
60+
)
61+
for image_number in image_list:
62+
target_name = image_number + ".pt"
63+
target_path = join(target_dir, target_name)
64+
target = torch.load(target_path)
65+
if target.nonzero().numel() > 0:
66+
usable.append(image_number)
67+
return usable
68+
69+
70+
def fix_broken_images(superpixels):
71+
for split in ["train", "val", "trainval"]:
72+
usable = find_usable_images(split=split, superpixels=superpixels)
73+
super_path = join(root, "ImageSets/Segmentation", split + "_super.txt")
74+
if exists(super_path):
75+
remove(super_path)
76+
with open(super_path, "w+") as file:
77+
for image_number in usable:
78+
file.write(image_number + "\n")
79+
80+
81+
def find_size_variance(superpixels):
82+
# Generate image list
83+
if superpixels is not None:
84+
image_list = get_image_list('trainval_super')
85+
else:
86+
image_list = get_image_list()
87+
mask_dir = "SegmentationClass/{}_sp".format(superpixels)
88+
dataset_variance = 0
89+
for image_number in tqdm(image_list):
90+
mask_path = join(root, mask_dir, image_number + ".pt")
91+
mask = torch.load(mask_path)
92+
# Initialise number of superpixels tensors
93+
Q = mask.unique().numel()
94+
size = torch.zeros(Q)
95+
counter = torch.ones_like(mask)
96+
# Calculate the size of each superpixel
97+
size.put_(mask, counter.float(), True)
98+
# Calculate the mean and standard deviation of the sizes
99+
std = size.std()
100+
mean = size.mean()
101+
# Add to the variance of the total datasets
102+
dataset_variance += std / mean
103+
dataset_variance /= len(image_list)
104+
return dataset_variance

pytorch-superpixels/preprocess.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
'''For pre-processing'''
2+
3+
4+
def create_masks(numSegments=100, limOverseg=None):
5+
# Generate image list
6+
image_list = get_image_list()
7+
for image_number in tqdm(image_list):
8+
# Load image/target pair
9+
image_name = image_number + ".jpg"
10+
target_name = image_number + ".png"
11+
image_path = join(root, "JPEGImages", image_name)
12+
target_path = join(root, "SegmentationClass/pre_encoded", target_name)
13+
image = img_as_float(io.imread(image_path))
14+
target = io.imread(target_path)
15+
target = torch.from_numpy(target)
16+
# Create mask for image/target pair
17+
mask, target_s = create_mask(
18+
image=image,
19+
target=target,
20+
numSegments=numSegments,
21+
limOverseg=limOverseg
22+
)
23+
24+
# Save for later
25+
image_save_dir = join(
26+
root,
27+
"SegmentationClass/{}_sp".format(numSegments)
28+
)
29+
target_s_save_dir = join(
30+
root,
31+
"SegmentationClass/pre_encoded_{}_sp".format(numSegments)
32+
)
33+
if not exists(image_save_dir):
34+
mkdir(image_save_dir)
35+
if not exists(target_s_save_dir):
36+
mkdir(target_s_save_dir)
37+
save_name = image_number + ".pt"
38+
image_save_path = join(image_save_dir, save_name)
39+
target_s_save_path = join(target_s_save_dir, save_name)
40+
torch.save(mask, image_save_path)
41+
torch.save(target_s, target_s_save_path)
42+
43+
44+
def create_mask(image, target, numSegments, limOverseg):
45+
# Perform SLIC segmentation
46+
mask = slic(image, n_segments=numSegments, slic_zero=True)
47+
mask = torch.from_numpy(mask)
48+
49+
if limOverseg is not None:
50+
# Oversegmentation step
51+
superpixels = mask.unique().numel()
52+
overseg = superpixels
53+
for superpixel in range(superpixels):
54+
overseg -= 1
55+
# Define mask for superpixel
56+
segment_mask = mask == superpixel
57+
# Classes in this superpixel
58+
classes = target[segment_mask].unique(sorted=True)
59+
# Check if superpixel is on target boundary
60+
on_boundary = classes.numel() > 1
61+
# If current superpixel is on a gt boundary
62+
if on_boundary:
63+
# Find how many of each class is in superpixel
64+
class_hist = torch.bincount(target[segment_mask])
65+
# Remove zero elements
66+
class_hist = class_hist[class_hist.nonzero()].float()
67+
# Find minority class in superpixel
68+
min_class = min(class_hist)
69+
# Is the minority class large enough for oversegmentation
70+
above_threshold = min_class > class_hist.sum() * limOverseg
71+
if above_threshold:
72+
# Leaving one class in supperpixel be
73+
for c in classes[1:]:
74+
# Adding to the oversegmentation offset
75+
overseg += 1
76+
# Add offset to class c in the mask
77+
mask[segment_mask] += (target[segment_mask]
78+
== c).long() * overseg
79+
80+
# (Re)define how many superpixels there are and create target_s
81+
superpixels = mask.unique().numel()
82+
target_s = torch.zeros(superpixels, dtype=torch.long)
83+
for superpixel in range(superpixels):
84+
# Define mask for superpixel
85+
segment_mask = mask == superpixel
86+
# Apply mask, the mode for majority class
87+
target_s[superpixel] = target[segment_mask].view(-1).mode()[0]
88+
return mask, target_s
89+
90+
91+
def get_image_list(split=None):
92+
if split is None:
93+
image_list_path = join(root, "ImageSets/Segmentation/trainval.txt")
94+
else:
95+
image_list_path = join(root, "ImageSets/Segmentation/", split + ".txt")
96+
image_list = tuple(open(image_list_path, "r"))
97+
image_list = [id_.rstrip() for id_ in image_list]
98+
return image_list

pytorch-superpixels/runtime.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import torch
2+
from os.path import join
3+
from os.path import exists
4+
from os.path import dirname
5+
from os.path import abspath
6+
from os import mkdir
7+
from os import remove
8+
from os import listdir
9+
from tqdm import tqdm
10+
from skimage import io
11+
from skimage.util import img_as_float
12+
from skimage.segmentation import slic
13+
14+
# Define absolute path for accessing dataset files
15+
package_dir = dirname(abspath(__file__))
16+
dataset_dir = "../../datasets/VOCdevkit/VOC2011"
17+
root = join(package_dir, dataset_dir)
18+
'''For use during runtime'''
19+
20+
21+
def convert_to_superpixels(input, target, mask):
22+
# Extract size data from input and target
23+
images, c, h, w = input.size()
24+
if images > 1:
25+
raise RuntimeError("Not implemented for batch sizes greater than 1")
26+
# Initialise vairables to use
27+
Q = mask.unique().numel()
28+
output = torch.zeros((Q, c), device=input.device)
29+
size = torch.zeros(Q, device=input.device)
30+
counter = torch.ones(mask.size(), device=input.device)
31+
# Calculate the size of each superpixel
32+
size.put_(mask, counter, True)
33+
# Calculate the mean value of each superpixel
34+
input = input.view(c, -1)
35+
mask = mask.view(1, -1).repeat(c, 1)
36+
arange = torch.arange(start=1, end=c, device=input.device)
37+
mask[arange, :] += Q * arange.view(-1, 1)
38+
output = output.put_(mask, input, True).view(c, Q).t()
39+
output = (output.t() / size).t()
40+
return output, target.view(-1), size
41+
42+
43+
def convert_to_pixels(input, output, mask):
44+
n, c, h, w = output.size()
45+
for k in range(c):
46+
output[0, k, :, :] = torch.gather(
47+
input[:, k], 0, mask.view(-1)).view(h, w)
48+
return output
49+
50+
51+
def to_super_to_pixels(input, mask):
52+
target = torch.tensor([])
53+
input_s, _, _ = convert_to_superpixels(input, target, mask)
54+
output = convert_to_pixels(input_s, input, mask)
55+
return output
56+
57+
58+
def setup_superpixels(superpixels):
59+
image_save_dir = join(
60+
root,
61+
"SegmentationClass/{}_sp".format(superpixels)
62+
)
63+
target_s_save_dir = join(
64+
root,
65+
"SegmentationClass/pre_encoded_{}_sp".format(superpixels)
66+
)
67+
dirs = [image_save_dir, target_s_save_dir]
68+
dataset_len = len(get_image_list())
69+
if not any(exists(x) and len(listdir(x)) == dataset_len for x in dirs):
70+
print("Superpixel dataset of scale {} superpixels either doesn't exist or is incomplete".format(superpixels))
71+
print("Generating superpixel dataset now...")
72+
create_masks(superpixels)
73+
74+
fix_broken_images(superpixels)

0 commit comments

Comments
 (0)