Skip to content

Commit 3cfe0f9

Browse files
Merge pull request #224 from AllenNeuralDynamics/seanf
Fixed split dataset save points implementation to overwrite existing …
2 parents 78f4ba4 + 8775dc8 commit 3cfe0f9

2 files changed

Lines changed: 87 additions & 65 deletions

File tree

Rhapso/pipelines/ray/local/alignment_pipeline.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -161,10 +161,10 @@
161161

162162
# -- ALIGNMENT PIPELINE --
163163
interest_point_detection.run()
164-
# interest_point_matching_rigid.run()
165-
# solver_rigid.run()
166-
# interest_point_matching_affine.run()
167-
# solver_affine.run()
168-
# split_dataset.run()
169-
# interest_point_matching_split_affine.run()
170-
# solver_split_affine.run()
164+
interest_point_matching_rigid.run()
165+
solver_rigid.run()
166+
interest_point_matching_affine.run()
167+
solver_affine.run()
168+
split_dataset.run()
169+
interest_point_matching_split_affine.run()
170+
solver_split_affine.run()

Rhapso/split_dataset/save_points.py

Lines changed: 80 additions & 58 deletions
Original file line number | Diff line number | Diff line change
@@ -9,28 +9,60 @@ def __init__(self, label_entries, n5_prefix):
99
self.label_entries = label_entries
1010
self.n5_prefix = n5_prefix
1111
self.s3_filesystem = s3fs.S3FileSystem()
12-
self.default_block_size = 300000
1312

1413
def write_json_to_s3(self, id_dataset_path, loc_dataset_path, attributes):
1514
"""
16-
Write attributes file into both the ID and LOC dataset directories on S3
15+
Write attributes file into both the ID and LOC dataset directories on S3.
1716
"""
1817
bucket, key = id_dataset_path.replace("s3://", "", 1).split("/", 1)
19-
json_path = key + '/attributes.json'
20-
json_bytes = json.dumps(attributes).encode('utf-8')
21-
s3 = boto3.client('s3')
18+
json_path = key + "/attributes.json"
19+
json_bytes = json.dumps(attributes).encode("utf-8")
20+
s3 = boto3.client("s3")
2221
s3.put_object(Bucket=bucket, Key=json_path, Body=json_bytes)
2322

2423
bucket, key = loc_dataset_path.replace("s3://", "", 1).split("/", 1)
25-
json_path = key + '/attributes.json'
26-
json_bytes = json.dumps(attributes).encode('utf-8')
27-
s3 = boto3.client('s3')
24+
json_path = key + "/attributes.json"
25+
json_bytes = json.dumps(attributes).encode("utf-8")
2826
s3.put_object(Bucket=bucket, Key=json_path, Body=json_bytes)
2927

28+
def write_one_block_dataset(self, root, name, data, dtype, attrs):
29+
"""
30+
Write the points as a single block/chunk.
31+
"""
32+
data = np.asarray(data, dtype=dtype)
33+
chunks = tuple(max(1, dim) for dim in data.shape)
34+
35+
if name in root:
36+
arr = zarr.creation.create(
37+
shape=data.shape,
38+
chunks=chunks,
39+
dtype=dtype,
40+
compressor=zarr.GZip(),
41+
store=root.store,
42+
path=f"{root.path}/{name}" if root.path else name,
43+
overwrite=True,
44+
)
45+
46+
if data.size > 0:
47+
arr[...] = data
48+
else:
49+
arr = root.create_dataset(
50+
name=name,
51+
data=data,
52+
dtype=dtype,
53+
chunks=chunks,
54+
compressor=zarr.GZip(),
55+
)
56+
57+
for k, v in attrs.items():
58+
arr.attrs[k] = v
59+
60+
return arr
61+
3062
def save_interest_points_to_n5(self):
3163
for label_entry in self.label_entries:
32-
n5_path = label_entry['ip_list']['n5_path']
33-
64+
n5_path = label_entry["ip_list"]["n5_path"]
65+
3466
if self.n5_prefix.startswith("s3://"):
3567
output_path = self.n5_prefix + n5_path + "/interestpoints"
3668
store = s3fs.S3Map(root=output_path, s3=self.s3_filesystem, check=False)
@@ -40,6 +72,10 @@ def save_interest_points_to_n5(self):
4072
store = zarr.N5Store(output_path)
4173
root = zarr.group(store, overwrite=False)
4274

75+
root.attrs["pointcloud"] = "1.0.0"
76+
root.attrs["type"] = "list"
77+
root.attrs["list version"] = "1.0.0"
78+
4379
id_dataset = "id"
4480
loc_dataset = "loc"
4581

@@ -49,53 +85,39 @@ def save_interest_points_to_n5(self):
4985
attrs_dict = dict(root.attrs)
5086
self.write_json_to_s3(id_path, loc_path, attrs_dict)
5187

52-
interest_points = [point[1] for point in label_entry['ip_list']['interest_points']]
53-
interest_point_ids = np.arange(len(interest_points), dtype=np.uint64).reshape(-1, 1)
54-
n = 3
55-
56-
if len(interest_points) > 0:
57-
if id_dataset in root:
58-
del root[id_dataset]
59-
root.create_dataset(
60-
id_dataset,
61-
data=interest_point_ids,
62-
dtype='u8',
63-
chunks=(self.default_block_size,),
64-
compressor=zarr.GZip()
65-
)
66-
67-
if loc_dataset in root:
68-
del root[loc_dataset]
69-
root.create_dataset(
70-
loc_dataset,
71-
data=interest_points,
72-
dtype='f8',
73-
chunks=(self.default_block_size, n),
74-
compressor=zarr.GZip()
75-
)
76-
77-
# save as empty lists
78-
else:
79-
if id_dataset in root:
80-
del root[id_dataset]
81-
root.create_dataset(
82-
id_dataset,
83-
shape=(0,),
84-
dtype='u8',
85-
chunks=(1,),
86-
compressor=zarr.GZip()
87-
)
88-
89-
if loc_dataset in root:
90-
del root[loc_dataset]
91-
root.create_dataset(
92-
loc_dataset,
93-
shape=(0,),
94-
dtype='f8',
95-
chunks=(1,),
96-
compressor=zarr.GZip()
97-
)
88+
interest_points = np.asarray(
89+
[point[1] for point in label_entry["ip_list"]["interest_points"]],
90+
dtype=np.float64,
91+
).reshape(-1, 3)
92+
93+
num_points = interest_points.shape[0]
94+
95+
interest_point_ids = np.arange(
96+
num_points,
97+
dtype=np.uint64,
98+
).reshape(-1, 1)
99+
100+
self.write_one_block_dataset(
101+
root=root,
102+
name=id_dataset,
103+
data=interest_point_ids,
104+
dtype="u8",
105+
attrs={
106+
"dimensions": [num_points, 1],
107+
"blockSize": [max(num_points, 1), 1],
108+
},
109+
)
110+
111+
self.write_one_block_dataset(
112+
root=root,
113+
name=loc_dataset,
114+
data=interest_points,
115+
dtype="f8",
116+
attrs={
117+
"dimensions": [num_points, 3],
118+
"blockSize": [max(num_points, 1), 3],
119+
},
120+
)
98121

99122
def run(self):
100-
self.save_interest_points_to_n5()
101-
return 1
123+
self.save_interest_points_to_n5()

0 commit comments

Comments
 (0)