@@ -9,28 +9,60 @@ def __init__(self, label_entries, n5_prefix):
99 self .label_entries = label_entries
1010 self .n5_prefix = n5_prefix
1111 self .s3_filesystem = s3fs .S3FileSystem ()
12- self .default_block_size = 300000
1312
1413 def write_json_to_s3 (self , id_dataset_path , loc_dataset_path , attributes ):
1514 """
16- Write attributes file into both the ID and LOC dataset directories on S3
15+ Write attributes file into both the ID and LOC dataset directories on S3.
1716 """
1817 bucket , key = id_dataset_path .replace ("s3://" , "" , 1 ).split ("/" , 1 )
19- json_path = key + ' /attributes.json'
20- json_bytes = json .dumps (attributes ).encode (' utf-8' )
21- s3 = boto3 .client ('s3' )
18+ json_path = key + " /attributes.json"
19+ json_bytes = json .dumps (attributes ).encode (" utf-8" )
20+ s3 = boto3 .client ("s3" )
2221 s3 .put_object (Bucket = bucket , Key = json_path , Body = json_bytes )
2322
2423 bucket , key = loc_dataset_path .replace ("s3://" , "" , 1 ).split ("/" , 1 )
25- json_path = key + '/attributes.json'
26- json_bytes = json .dumps (attributes ).encode ('utf-8' )
27- s3 = boto3 .client ('s3' )
24+ json_path = key + "/attributes.json"
25+ json_bytes = json .dumps (attributes ).encode ("utf-8" )
2826 s3 .put_object (Bucket = bucket , Key = json_path , Body = json_bytes )
2927
28+ def write_one_block_dataset (self , root , name , data , dtype , attrs ):
29+ """
30+ Write a points as one block/chunk.
31+ """
32+ data = np .asarray (data , dtype = dtype )
33+ chunks = tuple (max (1 , dim ) for dim in data .shape )
34+
35+ if name in root :
36+ arr = zarr .creation .create (
37+ shape = data .shape ,
38+ chunks = chunks ,
39+ dtype = dtype ,
40+ compressor = zarr .GZip (),
41+ store = root .store ,
42+ path = f"{ root .path } /{ name } " if root .path else name ,
43+ overwrite = True ,
44+ )
45+
46+ if data .size > 0 :
47+ arr [...] = data
48+ else :
49+ arr = root .create_dataset (
50+ name = name ,
51+ data = data ,
52+ dtype = dtype ,
53+ chunks = chunks ,
54+ compressor = zarr .GZip (),
55+ )
56+
57+ for k , v in attrs .items ():
58+ arr .attrs [k ] = v
59+
60+ return arr
61+
3062 def save_interest_points_to_n5 (self ):
3163 for label_entry in self .label_entries :
32- n5_path = label_entry [' ip_list' ][ ' n5_path' ]
33-
64+ n5_path = label_entry [" ip_list" ][ " n5_path" ]
65+
3466 if self .n5_prefix .startswith ("s3://" ):
3567 output_path = self .n5_prefix + n5_path + "/interestpoints"
3668 store = s3fs .S3Map (root = output_path , s3 = self .s3_filesystem , check = False )
@@ -40,6 +72,10 @@ def save_interest_points_to_n5(self):
4072 store = zarr .N5Store (output_path )
4173 root = zarr .group (store , overwrite = False )
4274
75+ root .attrs ["pointcloud" ] = "1.0.0"
76+ root .attrs ["type" ] = "list"
77+ root .attrs ["list version" ] = "1.0.0"
78+
4379 id_dataset = "id"
4480 loc_dataset = "loc"
4581
@@ -49,53 +85,39 @@ def save_interest_points_to_n5(self):
4985 attrs_dict = dict (root .attrs )
5086 self .write_json_to_s3 (id_path , loc_path , attrs_dict )
5187
52- interest_points = [point [1 ] for point in label_entry ['ip_list' ]['interest_points' ]]
53- interest_point_ids = np .arange (len (interest_points ), dtype = np .uint64 ).reshape (- 1 , 1 )
54- n = 3
55-
56- if len (interest_points ) > 0 :
57- if id_dataset in root :
58- del root [id_dataset ]
59- root .create_dataset (
60- id_dataset ,
61- data = interest_point_ids ,
62- dtype = 'u8' ,
63- chunks = (self .default_block_size ,),
64- compressor = zarr .GZip ()
65- )
66-
67- if loc_dataset in root :
68- del root [loc_dataset ]
69- root .create_dataset (
70- loc_dataset ,
71- data = interest_points ,
72- dtype = 'f8' ,
73- chunks = (self .default_block_size , n ),
74- compressor = zarr .GZip ()
75- )
76-
77- # save as empty lists
78- else :
79- if id_dataset in root :
80- del root [id_dataset ]
81- root .create_dataset (
82- id_dataset ,
83- shape = (0 ,),
84- dtype = 'u8' ,
85- chunks = (1 ,),
86- compressor = zarr .GZip ()
87- )
88-
89- if loc_dataset in root :
90- del root [loc_dataset ]
91- root .create_dataset (
92- loc_dataset ,
93- shape = (0 ,),
94- dtype = 'f8' ,
95- chunks = (1 ,),
96- compressor = zarr .GZip ()
97- )
88+ interest_points = np .asarray (
89+ [point [1 ] for point in label_entry ["ip_list" ]["interest_points" ]],
90+ dtype = np .float64 ,
91+ ).reshape (- 1 , 3 )
92+
93+ num_points = interest_points .shape [0 ]
94+
95+ interest_point_ids = np .arange (
96+ num_points ,
97+ dtype = np .uint64 ,
98+ ).reshape (- 1 , 1 )
99+
100+ self .write_one_block_dataset (
101+ root = root ,
102+ name = id_dataset ,
103+ data = interest_point_ids ,
104+ dtype = "u8" ,
105+ attrs = {
106+ "dimensions" : [num_points , 1 ],
107+ "blockSize" : [max (num_points , 1 ), 1 ],
108+ },
109+ )
110+
111+ self .write_one_block_dataset (
112+ root = root ,
113+ name = loc_dataset ,
114+ data = interest_points ,
115+ dtype = "f8" ,
116+ attrs = {
117+ "dimensions" : [num_points , 3 ],
118+ "blockSize" : [max (num_points , 1 ), 3 ],
119+ },
120+ )
98121
99122 def run (self ):
100- self .save_interest_points_to_n5 ()
101- return 1
123+ self .save_interest_points_to_n5 ()
0 commit comments