diff --git a/Rhapso/split_dataset/split_images.py b/Rhapso/split_dataset/split_images.py index 0126528..231a196 100644 --- a/Rhapso/split_dataset/split_images.py +++ b/Rhapso/split_dataset/split_images.py @@ -8,16 +8,16 @@ import math class SplitImages: - def __init__(self, target_image_size, target_overlap, min_step_size, data_gloabl, n5_path, point_density, min_points, max_points, + def __init__(self, target_image_size, target_overlap, min_step_size, data_global, n5_path, point_density, min_points, max_points, error, excludeRadius): self.target_image_size = target_image_size self.target_overlap = target_overlap self.min_step_size = min_step_size - self.data_global = data_gloabl - self.image_loader_df = data_gloabl['image_loader'] - self.view_setups_df = data_gloabl['view_setups'] - self.view_registrations_df = data_gloabl['view_registrations'] - self.view_interest_points_df = data_gloabl['view_interest_points'] + self.data_global = data_global + self.image_loader_df = data_global['image_loader'] + self.view_setups_df = data_global['view_setups'] + self.view_registrations_df = data_global['view_registrations'] + self.view_interest_points_df = data_global['view_interest_points'] self.n5_path = n5_path self.point_density = point_density self.min_points = min_points @@ -96,21 +96,24 @@ def split_dims(self, input, i, final_size, overlap): to_val = 0 from_val = input_min[i] - while to_val < input[i]: + while to_val < input[i]-1: to_val = min(input[i], from_val + final_size - 1) dim_intervals.append((from_val, to_val)) from_val = to_val - overlap + 1 return dim_intervals - def last_image_size(self, l, s, o): - num = l - 2 * (s - o) - o - den = s - o - rem = num % den if num >= 0 else -((-num) % den) - size = o + rem - if size < 0: - size = l + size - return size + + def last_image_size(self, L, S, O): + stride = S - O + if not (0 <= O < S): + raise ValueError("Require 0 <= O < S") + if L <= 0: + raise ValueError("Require L > 0") + + start_last = ((max(L - S, 0)) // stride) * stride + return L - start_last # will be S when it fits perfectly + def distribute_intervals_fixed_overlap(self, input): input = list(map(int, input.split())) @@ -127,8 +130,8 @@ def distribute_intervals_fixed_overlap(self, input): length = input[i] if length <= self.target_image_size[i]: - pass - + dim_intervals.append((0, length - 1)) + else: l = length s = self.target_image_size[i] @@ -340,137 +343,141 @@ def split_images(self, timepoints, interest_points, fake_label): new_registrations[(new_view_id_key)] = new_view_registration new_v_ip_l = [] - - old_v_ip_l = { - 'points': interest_points[old_view_id], - 'setup': old_id, - 'timepoint': t, - } - - id = 0 - new_ip1 = [] - old_ip_l1 = old_v_ip_l['points'] - old_ip_1 = deepcopy(old_ip_l1['points']) - - for ip in old_ip_1: - if self.contains(ip, interval): - l = deepcopy(ip) - for j in range(len(interval[0])): - l[j] -= interval[0][j] - - new_ip1.append((id, l)) - id += 1 - - new_ip_l1 = { - 'base_directory': old_ip_l1['base_path'], - 'corresponding_interest_points': None, - 'interest_points': new_ip1, - 'modified_corresponding_interest_points': None, - 'modified_interest_points': None, - 'n5_path': f"interestpoints.n5/tpId_{t}_viewSetupId_{new_view_id['setup']}/beads_split", - 'xml_n5_path': f"tpId_{t}_viewSetupId_{new_view_id['setup']}/{fake_label}", - "parameters": old_ip_l1['parameters_split'] - } - - new_v_ip_l.append({ - 'label': "beads_split", - 'ip_list': new_ip_l1 - }) - - new_ip = [] - id = 0 - - for j in range(i): - other_interval = intervals[j] - intersection = self.intersect(interval, other_interval) + if old_view_id in interest_points: + old_v_ip_l = { + 'points': interest_points[old_view_id], + 'setup': old_id, + 'timepoint': t, + } + + id = 0 + new_ip1 = [] + old_ip_l1 = old_v_ip_l['points'] + old_ip_1 = deepcopy(old_ip_l1['points']) - if not self.is_empty(intersection): - other_setup = interval_to_view_setup[(tuple(other_interval[0]), tuple(other_interval[1]))] - other_view_id = f"timepoint: {t}, setup: {other_setup['id']}" - other_ip_list = new_interest_points[other_view_id] - - n = len(interval[0]) - num_pixels = 1 - - for k in range(n): - num_pixels *= (intersection[1][k] - intersection[0][k] + 1) - - num_points = min(self.max_points, max(self.min_points, math.ceil(self.point_density * num_pixels / (100.0*100.0*100.0)))) - other_points = (next((x for x in other_ip_list if x.get("label") == fake_label), {"ip_list": {}})["ip_list"].get("interest_points") or []) - other_id = len(other_points) - - tree2 = None - search2 = None + for ip in old_ip_1: + if self.contains(ip, interval): + l = deepcopy(ip) + for j in range(len(interval[0])): + l[j] -= interval[0][j] + + new_ip1.append((id, l)) + id += 1 + + new_ip_l1 = { + 'base_directory': old_ip_l1['base_path'], + 'corresponding_interest_points': None, + 'interest_points': new_ip1, + 'modified_corresponding_interest_points': None, + 'modified_interest_points': None, + 'n5_path': f"interestpoints.n5/tpId_{t}_viewSetupId_{new_view_id['setup']}/beads_split", + 'xml_n5_path': f"tpId_{t}_viewSetupId_{new_view_id['setup']}/{fake_label}", + "parameters": old_ip_l1['parameters_split'] + } + + new_v_ip_l.append({ + 'label': "beads_split", + 'ip_list': new_ip_l1 + }) + + if self.max_points > 0: + new_ip = [] + id = 0 - if self.exclude_radius > 0: - other_ip_global = [] + for j in range(i): + other_interval = intervals[j] + intersection = self.intersect(interval, other_interval) - for k, ip in enumerate(other_points): - l = deepcopy(ip[1]) + if not self.is_empty(intersection): + other_setup = interval_to_view_setup[(tuple(other_interval[0]), tuple(other_interval[1]))] + other_view_id = f"timepoint: {t}, setup: {other_setup['id']}" + other_ip_list = new_interest_points[other_view_id] - for m in range(n): - l[m] = l[m] + other_interval[0][m] - - other_ip_global.append((k, l)) + n = len(interval[0]) + num_pixels = 1 - if len(other_ip_global) > 0: - coords = np.vstack([l for _, l in other_ip_global]) - tree2 = cKDTree(coords) + for k in range(n): + num_pixels *= (intersection[1][k] - intersection[0][k] + 1) + + num_points = min(self.max_points, max(self.min_points, math.ceil(self.point_density * num_pixels / (100.0*100.0*100.0)))) + other_points = (next((x for x in other_ip_list if x.get("label") == fake_label), {"ip_list": {}})["ip_list"].get("interest_points") or []) + other_id = len(other_points) - def search2(q_point_global, radius=self.exclude_radius): - idxs = tree2.query_ball_point(np.asarray(q_point_global, float), radius) - return [other_ip_global[k] for k in idxs] - else: tree2 = None search2 = None - - else: - tree2 = None - search2 = None - - tmp = [0.0] * n - for k in range(num_points): - p = [0.0] * n - op = [0.0] * n - - for d in range(n): - l = rnd.random() * (intersection[1][d] - intersection[0][d] + 1) + intersection[0][d] - p[d] = (l + (rnd.random() - 0.5) * self.error) - interval[0][d] - op[d] = (l + (rnd.random() - 0.5) * self.error) - other_interval[0][d] - tmp[d] = l - - num_neighbors = 0 - if self.exclude_radius > 0: - tmp_ip = (0, np.asarray(tmp, dtype=float)) + if self.exclude_radius > 0: + other_ip_global = [] + + for k, ip in enumerate(other_points): + l = deepcopy(ip[1]) + + for m in range(n): + l[m] = l[m] + other_interval[0][m] + + other_ip_global.append((k, l)) + + if len(other_ip_global) > 0: + coords = np.vstack([l for _, l in other_ip_global]) + tree2 = cKDTree(coords) + + def search2(q_point_global, radius=self.exclude_radius): + idxs = tree2.query_ball_point(np.asarray(q_point_global, float), radius) + return [other_ip_global[k] for k in idxs] + else: + tree2 = None + search2 = None - if search2 is not None: - neighbors = search2(tmp_ip[1], self.exclude_radius) - num_neighbors += len(neighbors) - - if num_neighbors == 0: - new_ip.append((id, p)) - other_points.append((other_id, op)) - id += 1 - other_id += 1 + else: + tree2 = None + search2 = None + + tmp = [0.0] * n + + for k in range(num_points): + p = [0.0] * n + op = [0.0] * n + + for d in range(n): + l = rnd.random() * (intersection[1][d] - intersection[0][d] + 1) + intersection[0][d] + p[d] = (l + (rnd.random() - 0.5) * self.error) - interval[0][d] + op[d] = (l + (rnd.random() - 0.5) * self.error) - other_interval[0][d] + tmp[d] = l + + num_neighbors = 0 + if self.exclude_radius > 0: + tmp_ip = (0, np.asarray(tmp, dtype=float)) + + if search2 is not None: + neighbors = search2(tmp_ip[1], self.exclude_radius) + num_neighbors += len(neighbors) + + if num_neighbors == 0: + new_ip.append((id, p)) + other_points.append((other_id, op)) + id += 1 + other_id += 1 + + next(x for x in other_ip_list if x.get("label") == fake_label)["ip_list"]["interest_points"] = other_points - next(x for x in other_ip_list if x.get("label") == fake_label)["ip_list"]["interest_points"] = other_points - - new_ip_l = { - 'base_directory': old_ip_l1['base_path'], - 'corresponding_interest_points': None, - 'interest_points': new_ip, - 'modified_corresponding_interest_points': None, - 'modified_interest_points': None, - 'n5_path': f"interestpoints.n5/tpId_{t}_viewSetupId_{new_view_id['setup']}/{fake_label}", - 'xml_n5_path': f"tpId_{t}_viewSetupId_{new_view_id['setup']}/{fake_label}", - "parameters": old_ip_l1['parameters_fake'] - } - - new_v_ip_l.append({ - 'label': fake_label, - 'ip_list': new_ip_l - }) + new_ip_l = { + 'base_directory': old_ip_l1['base_path'], + 'corresponding_interest_points': None, + 'interest_points': new_ip, + 'modified_corresponding_interest_points': None, + 'modified_interest_points': None, + 'n5_path': f"interestpoints.n5/tpId_{t}_viewSetupId_{new_view_id['setup']}/{fake_label}", + 'xml_n5_path': f"tpId_{t}_viewSetupId_{new_view_id['setup']}/{fake_label}", + "parameters": old_ip_l1['parameters_fake'] + } + + new_v_ip_l.append({ + 'label': fake_label, + 'ip_list': new_ip_l + }) + + if len(new_v_ip_l) > 0: + new_interest_points[new_view_id_key] = new_v_ip_l self.setup_definition.append({ 'interval': interval, @@ -484,7 +491,6 @@ def search2(q_point_global, radius=self.exclude_radius): 'old_models': transform_list }) - new_interest_points[new_view_id_key] = new_v_ip_l new_id += 1 return new_interest_points @@ -493,6 +499,10 @@ def load_interest_points(self, fake_label): full_path = self.n5_path + "interestpoints.n5" interest_points = {} + # Skip loading interest points if dataframe is empty + if self.view_interest_points_df.empty: + return {} + if full_path.startswith("s3://"): path = full_path.rstrip("/") s3 = s3fs.S3FileSystem(anon=False) @@ -503,6 +513,7 @@ def load_interest_points(self, fake_label): store = zarr.N5Store(full_path) root = zarr.open(store, mode="r") + for _, row in self.view_interest_points_df.iterrows(): view_id = f"timepoint: {row['timepoint']}, setup: {row['setup']}" interestpoints_prefix = f"{row['path']}/interestpoints/loc/"