@@ -48,6 +48,7 @@ def __init__(self, index_dtype, target_text_dict_dtype):
         :param target_text_dict_dtype: type for the target text dictionary array
         """
         assert pa.types.is_integer(index_dtype)
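+        # Dictionary-encoding the target text lets spans that share a document store its string once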
+        assert pa.types.is_dictionary(target_text_dict_dtype)

         fields = [
             pa.field(self.BEGINS_NAME, index_dtype),
@@ -70,52 +71,31 @@ class ArrowTokenSpanType(pa.PyExtensionType):

     BEGINS_NAME = "token_begins"
     ENDS_NAME = "token_ends"
-    TARGET_TEXT_DICT_NAME = "token_spans"
+    TOKENS_NAME = "tokens"

-    def __init__(self, index_dtype, target_text, num_char_span_splits):
+    def __init__(self, index_dtype, token_dict_dtype):
         """
         Create an instance of a TokenSpan data type with given index type and
-        target text that will be stored in Field metadata.
+        token dictionary type.

-        :param index_dtype:
-        :param target_text:
+        :param index_dtype: type for the begin/end token index arrays
+        :param token_dict_dtype: type for the tokens dictionary array
         """
         assert pa.types.is_integer(index_dtype)
-        self.num_char_span_splits = num_char_span_splits
-
-        # Store target text as field metadata
-        metadata = {self.TARGET_TEXT_KEY: target_text}
+        assert pa.types.is_dictionary(token_dict_dtype)

-        token_span_fields = [
-            pa.field(self.BEGINS_NAME, index_dtype, metadata=metadata),
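+        # Storage struct: begin/end token indices plus the dictionary-encoded tokens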
+        fields = [
+            pa.field(self.BEGINS_NAME, index_dtype),
             pa.field(self.ENDS_NAME, index_dtype),
+            pa.field(self.TOKENS_NAME, token_dict_dtype),
         ]

-        # Span arrays fit into single fields
-        if num_char_span_splits == 0:
-            char_span_fields = [
-                pa.field(ArrowSpanType.BEGINS_NAME, index_dtype),
-                pa.field(ArrowSpanType.ENDS_NAME, index_dtype)
-            ]
-        # Store splits of Span as multiple fields
-        else:
-            char_span_fields = []
-            for i in range(num_char_span_splits):
-                n = "_{}".format(i)
-                begin_field = pa.field(ArrowSpanType.BEGINS_NAME + n, index_dtype)
-                end_field = pa.field(ArrowSpanType.ENDS_NAME + n, index_dtype)
-                char_span_fields.extend([begin_field, end_field])
-
-        fields = token_span_fields + char_span_fields
-
         pa.PyExtensionType.__init__(self, pa.struct(fields))

     def __reduce__(self):
         index_dtype = self.storage_type[self.BEGINS_NAME].type
-        metadata = self.storage_type[self.BEGINS_NAME].metadata
-        target_text = metadata[self.TARGET_TEXT_KEY].decode()
-        num_char_span_splits = self.num_char_span_splits
-        return ArrowTokenSpanType, (index_dtype, target_text, num_char_span_splits)
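+        # __reduce__ drives pickling; rebuild the constructor args from the storage schema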
+        token_dict_dtype = self.storage_type[self.TOKENS_NAME].type
+        return ArrowTokenSpanType, (index_dtype, token_dict_dtype)


 def span_to_arrow(char_span: SpanArray) -> pa.ExtensionArray:
@@ -198,62 +178,33 @@ def token_span_to_arrow(token_span: TokenSpanArray) -> pa.ExtensionArray:
     # Create arrays for begins/ends
     token_begins_array = pa.array(token_span.begin_token)
     token_ends_array = pa.array(token_span.end_token)
-    token_span_arrays = [token_begins_array, token_ends_array]
-
-    num_char_span_splits = 0
-
-    # If TokenSpan arrays have greater length than Span arrays, pad Span
-    if len(token_span.begin_token) > len(token_span.tokens.begin):
-
-        padding = np.zeros(len(token_span.begin_token) - len(token_span.tokens.begin),
-                           token_span.tokens.begin.dtype)
-
-        isnull = np.append(np.full(len(token_span.tokens.begin), False), np.full(len(padding), True))
-        char_begins_padded = np.append(token_span.tokens.begin, padding)
-        char_ends_padded = np.append(token_span.tokens.end, padding)
-        char_begins_array = pa.array(char_begins_padded, mask=isnull)
-        char_ends_array = pa.array(char_ends_padded, mask=isnull)
-        char_span_arrays = [char_begins_array, char_ends_array]

-    # If TokenSpan arrays have less length than Span arrays, split Span into multiple arrays
-    elif len(token_span.begin_token) < len(token_span.tokens.begin):
-
-        char_begins_array = pa.array(token_span.tokens.begin)
-        char_ends_array = pa.array(token_span.tokens.end)
-
-        char_span_arrays = []
-        while len(char_begins_array) >= len(token_begins_array):
-            char_begins_split = char_begins_array[:len(token_begins_array)]
-            char_ends_split = char_ends_array[:len(token_ends_array)]
-
-            char_span_arrays.extend([char_begins_split, char_ends_split])
-            num_char_span_splits += 1
+    # If every row shares the same tokens SpanArray, store it once; otherwise keep one per row
+    assert len(token_span.tokens) > 0
+    if all(token is token_span.tokens[0] for token in token_span.tokens):
+        tokens_arrays = [token_span.tokens[0]]
+        tokens_indices = pa.array([0] * len(token_span.tokens))
+    else:
+        tokens_arrays = token_span.tokens
+        tokens_indices = pa.array(range(len(tokens_arrays)))
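+    # e.g. three rows over one shared document -> tokens_arrays has one entry, tokens_indices == [0, 0, 0]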

-            char_begins_array = char_begins_array[len(token_begins_array):]
-            char_ends_array = char_ends_array[len(token_ends_array):]
+    # Convert each token SpanArray to Arrow and take its raw struct storage
+    arrow_tokens_arrays = [span_to_arrow(sa).storage for sa in tokens_arrays]
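+    # Raw storage is used because pa.concat_arrays below cannot concatenate extension arrays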
-        # Pad the final split
-        if len(char_begins_array) > 0:
-            padding = np.zeros(len(token_begins_array) - len(char_begins_array),
-                               token_span.tokens.begin.dtype)
-            isnull = np.append(np.full(len(char_begins_array), False), np.full(len(padding), True))
-            char_begins_padded = np.append(char_begins_array.to_numpy(), padding)
-            char_ends_padded = np.append(char_ends_array.to_numpy(), padding)
-            char_begins_split = pa.array(char_begins_padded, mask=isnull)
-            char_ends_split = pa.array(char_ends_padded, mask=isnull)
-            char_span_arrays.extend([char_begins_split, char_ends_split])
-            num_char_span_splits += 1
+    # Create a list array where each element is the storage of one ArrowSpanArray
+    # TODO: building this directly fails with pyarrow.lib.ArrowNotImplementedError:
+    #  ('Sequence converter for type dictionary<values=string, indices=int8, ordered=0>
+    #  not implemented', 'Conversion failed for column ts1 with type TokenSpanDtype')
+    # arrow_tokens_arrays_array = pa.array(arrow_tokens_arrays, type=pa.list_(arrow_tokens_arrays[0].type))
+    # Offsets into the flattened values must be cumulative: [0, len0, len0 + len1, ...]
+    offsets = np.cumsum([0] + [len(a) for a in arrow_tokens_arrays])
+    values = pa.concat_arrays(arrow_tokens_arrays)  # TODO: can't concat extension arrays?
+    arrow_tokens_arrays_array = pa.ListArray.from_arrays(offsets, values)
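+    # e.g. token counts [4, 6] give offsets [0, 4, 10]; list i is values[offsets[i]:offsets[i + 1]]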

-    # TokenSpan arrays are equal length to Span arrays
-    else:
-        char_begins_array = pa.array(token_span.tokens.begin)
-        char_ends_array = pa.array(token_span.tokens.end)
-        char_span_arrays = [char_begins_array, char_ends_array]
+    # Create a dictionary array mapping each row's tokens index to its list of span values
+    tokens_dict_array = pa.DictionaryArray.from_arrays(tokens_indices, arrow_tokens_arrays_array)
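+    # Dictionary encoding keeps one copy of each distinct token list plus a small index per row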

-    typ = ArrowTokenSpanType(token_begins_array.type, token_span.target_text, num_char_span_splits)
+    typ = ArrowTokenSpanType(token_begins_array.type, tokens_dict_array.type)
     fields = list(typ.storage_type)

-    storage = pa.StructArray.from_arrays(token_span_arrays + char_span_arrays, fields=fields)
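+    # Arrays must be listed in the same order as the fields of typ.storage_type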
+    storage = pa.StructArray.from_arrays(
+        [token_begins_array, token_ends_array, tokens_dict_array], fields=fields)

     return pa.ExtensionArray.from_storage(typ, storage)

@@ -273,46 +224,41 @@ def arrow_to_token_span(extension_array: pa.ExtensionArray) -> TokenSpanArray:

     assert pa.types.is_struct(extension_array.storage.type)

-    # Get target text from the begins field metadata and decode string
-    metadata = extension_array.storage.type[ArrowTokenSpanType.BEGINS_NAME].metadata
-    target_text = metadata[ArrowSpanType.TARGET_TEXT_KEY]
-    if isinstance(target_text, bytes):
-        target_text = target_text.decode()
-
     # Get the begins/ends pyarrow arrays
     token_begins_array = extension_array.storage.field(ArrowTokenSpanType.BEGINS_NAME)
     token_ends_array = extension_array.storage.field(ArrowTokenSpanType.ENDS_NAME)

-    # Check if CharSpans have been split
-    num_char_span_splits = extension_array.type.num_char_span_splits
-    if num_char_span_splits > 0:
-        char_begins_splits = []
-        char_ends_splits = []
-        for i in range(num_char_span_splits):
-            char_begins_splits.append(
-                extension_array.storage.field(ArrowSpanType.BEGINS_NAME + "_{}".format(i)))
-            char_ends_splits.append(
-                extension_array.storage.field(ArrowSpanType.ENDS_NAME + "_{}".format(i)))
-        char_begins_array = pa.concat_arrays(char_begins_splits)
-        char_ends_array = pa.concat_arrays(char_ends_splits)
-    else:
-        char_begins_array = extension_array.storage.field(ArrowSpanType.BEGINS_NAME)
-        char_ends_array = extension_array.storage.field(ArrowSpanType.ENDS_NAME)
+    # Get the tokens as a dictionary array whose indices map to lists of ArrowSpanArray storage
+    tokens_dict_array = extension_array.storage.field(ArrowTokenSpanType.TOKENS_NAME)
+    tokens_indices = tokens_dict_array.indices
+    arrow_tokens_arrays_array = tokens_dict_array.dictionary
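+    # e.g. indices [0, 0, 0] and a one-entry dictionary when all rows shared one document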
+
+    # Break up the list of ArrowSpanArrays and convert each back to a SpanArray
+    tokens_arrays = []
+    span_type = None
+    for i in range(1, len(arrow_tokens_arrays_array.offsets)):
+        start = arrow_tokens_arrays_array.offsets[i - 1].as_py()
+        stop = arrow_tokens_arrays_array.offsets[i].as_py()
+        arrow_tokens_array = arrow_tokens_arrays_array.values[start:stop]
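+        # Slicing the values array is zero-copy; this is one token list's struct storage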
+
+        # Make an instance of ArrowSpanType from the first token list's schema
+        if span_type is None:
+            begins_array = arrow_tokens_array.field(ArrowSpanType.BEGINS_NAME)
+            target_text_dict_array = arrow_tokens_array.field(ArrowSpanType.TARGET_TEXT_DICT_NAME)
+            span_type = ArrowSpanType(begins_array.type, target_text_dict_array.type)
+
+        # Wrap the storage in the extension type to convert back to a SpanArray
+        tokens_array = arrow_to_span(pa.ExtensionArray.from_storage(span_type, arrow_tokens_array))
+        tokens_arrays.append(tokens_array)

-    # Remove any trailing padding
-    if char_begins_array.null_count > 0:
-        char_begins_array = char_begins_array[:-char_begins_array.null_count]
-        char_ends_array = char_ends_array[:-char_ends_array.null_count]
+    # Map each row's token index back to its token SpanArray
+    tokens = [tokens_arrays[i.as_py()] for i in tokens_indices]
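+    # Rows that shared a tokens array again share one SpanArray object after the round trip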

     # Zero-copy convert arrays to numpy
     token_begins = token_begins_array.to_numpy()
     token_ends = token_ends_array.to_numpy()
-    char_begins = char_begins_array.to_numpy()
-    char_ends = char_ends_array.to_numpy()

-    # Create the SpanArray, then the TokenSpanArray
-    char_span = SpanArray(target_text, char_begins, char_ends)
-    return TokenSpanArray(char_span, token_begins, token_ends)
+    return TokenSpanArray(tokens, token_begins, token_ends)


 class ArrowTensorType(pa.PyExtensionType):