Commit 378c69c

Basic tests passing

1 parent 6dbbfe0 commit 378c69c

1 file changed: 55 additions, 171 deletions

text_extensions_for_pandas/array/arrow_conversion.py
@@ -48,6 +48,7 @@ def __init__(self, index_dtype, target_text_dict_dtype):
         :param target_text_dict_dtype: type for the target text dictionary array
         """
         assert pa.types.is_integer(index_dtype)
+        assert pa.types.is_dictionary(target_text_dict_dtype)
 
         fields = [
             pa.field(self.BEGINS_NAME, index_dtype),
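For context, the storage layout built here pairs the integer begin/end fields with a dictionary-encoded target-text field. A minimal sketch, not part of the commit and with dtype choices that are only assumptions, of values that satisfy the two asserts:

import pyarrow as pa

# Hypothetical dtypes for illustration: integer offsets plus a dictionary
# type mapping small integer codes to (usually repeated) document strings.
index_dtype = pa.int32()
target_text_dict_dtype = pa.dictionary(pa.int8(), pa.string())

assert pa.types.is_integer(index_dtype)
assert pa.types.is_dictionary(target_text_dict_dtype)

Dictionary encoding stores each distinct target text once and refers to it by index, which is why the new assert checks for a dictionary type rather than a plain string type.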
@@ -70,86 +71,31 @@ class ArrowTokenSpanType(pa.PyExtensionType):
 
     BEGINS_NAME = "token_begins"
     ENDS_NAME = "token_ends"
+    TOKENS_NAME = "tokens"
 
     def __init__(self, index_dtype, token_dict_dtype):
         """
         Create an instance of a TokenSpan data type with given index type and
         target text that will be stored in Field metadata.
 
-        :param index_dtype:
-        :param target_text:
-        """
-        assert pa.types.is_integer(index_dtype)
-
-        token_span_fields = [
-            pa.field(self.BEGINS_NAME, pa.dictionary(index_dtype, token_dict_dtype)),
-            pa.field(self.ENDS_NAME, index_dtype),
-        ]
-
-
-        fields = token_span_fields
-
-        pa.PyExtensionType.__init__(self, pa.struct(fields))
-
-    def __reduce__(self):
-        index_dtype = self.storage_type[self.ENDS_NAME].type
-        token_dict_dtype = self.storage_type[self.BEGINS_NAME].type.value_type
-        return ArrowTokenSpanType, (index_dtype, token_dict_dtype)
-
-
-class ArrowTokenSpanTypeBAK(pa.PyExtensionType):
-    """
-    PyArrow extension type definition for conversions to/from TokenSpan columns
-    """
-
-    BEGINS_NAME = "token_begins"
-    ENDS_NAME = "token_ends"
-    TARGET_TEXT_DICT_NAME = "token_spans"
-
-    def __init__(self, index_dtype, target_text, num_char_span_splits):
-        """
-        Create an instance of a TokenSpan data type with given index type and
-        target text that will be stored in Field metadata.
-
-        :param index_dtype:
-        :param target_text:
+        :param index_dtype: type for the begin, end index arrays
+        :param token_dict_dtype: type for the tokens dictionary array
         """
         assert pa.types.is_integer(index_dtype)
-        self.num_char_span_splits = num_char_span_splits
+        assert pa.types.is_dictionary(token_dict_dtype)
 
-        # Store target text as field metadata
-        metadata = {self.TARGET_TEXT_KEY: target_text}
-
-        token_span_fields = [
-            pa.field(self.BEGINS_NAME, index_dtype, metadata=metadata),
+        fields = [
+            pa.field(self.BEGINS_NAME, index_dtype),
             pa.field(self.ENDS_NAME, index_dtype),
+            pa.field(self.TOKENS_NAME, token_dict_dtype),
         ]
 
-        # Span arrays fit into single fields
-        if num_char_span_splits == 0:
-            char_span_fields = [
-                pa.field(ArrowSpanType.BEGINS_NAME, index_dtype),
-                pa.field(ArrowSpanType.ENDS_NAME, index_dtype)
-            ]
-        # Store splits of Span as multiple fields
-        else:
-            char_span_fields = []
-            for i in range(num_char_span_splits):
-                n = "_{}".format(i)
-                begin_field = pa.field(ArrowSpanType.BEGINS_NAME + n, index_dtype)
-                end_field = pa.field(ArrowSpanType.ENDS_NAME + n, index_dtype)
-                char_span_fields.extend([begin_field, end_field])
-
-        fields = token_span_fields + char_span_fields
-
         pa.PyExtensionType.__init__(self, pa.struct(fields))
 
     def __reduce__(self):
         index_dtype = self.storage_type[self.BEGINS_NAME].type
-        metadata = self.storage_type[self.BEGINS_NAME].metadata
-        target_text = metadata[self.TARGET_TEXT_KEY].decode()
-        num_char_span_splits = self.num_char_span_splits
-        return ArrowTokenSpanType, (index_dtype, target_text, num_char_span_splits)
+        token_dict_dtype = self.storage_type[self.TOKENS_NAME].type
+        return ArrowTokenSpanType, (index_dtype, token_dict_dtype)
 
 
 def span_to_arrow(char_span: SpanArray) -> pa.ExtensionArray:
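A PyExtensionType must implement __reduce__ so that pickle, and Arrow IPC reads, can rebuild the type from plain constructor arguments; both __reduce__ methods above recover those arguments from the struct storage type. A minimal sketch of the same pattern with a toy type (names are illustrative, and this relies on the PyExtensionType API of the pyarrow versions this commit targets; later releases deprecate PyExtensionType in favor of registered ExtensionType subclasses):

import pickle
import pyarrow as pa

class ToySpanType(pa.PyExtensionType):
    # Same pattern as ArrowTokenSpanType: keep all state in the struct
    # storage type, then recover constructor arguments in __reduce__.
    def __init__(self, index_dtype):
        pa.PyExtensionType.__init__(
            self, pa.struct([pa.field("begins", index_dtype)]))

    def __reduce__(self):
        # pickle calls this and rebuilds the type as ToySpanType(int32()).
        return ToySpanType, (self.storage_type.field("begins").type,)

typ = ToySpanType(pa.int32())
assert pickle.loads(pickle.dumps(typ)).storage_type == typ.storage_type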
@@ -232,90 +178,33 @@ def token_span_to_arrow(token_span: TokenSpanArray) -> pa.ExtensionArray:
     # Create arrays for begins/ends
     token_begins_array = pa.array(token_span.begin_token)
     token_ends_array = pa.array(token_span.end_token)
-    #token_span_arrays = [token_begins_array, token_ends_array]
-
-    arrow_span_ext_array = span_to_arrow(token_span.tokens[0])
-    arrow_span_array = arrow_span_ext_array.storage
-
-    token_begins_dict_array = pa.DictionaryArray.from_arrays(token_begins_array, arrow_span_array)
-
-    typ = ArrowTokenSpanType(token_begins_array.type, arrow_span_array.type)
-    fields = list(typ.storage_type)
-
-    storage = pa.StructArray.from_arrays([token_begins_dict_array, token_ends_array], fields=fields)
-
-    return pa.ExtensionArray.from_storage(typ, storage)
-
-
-def token_span_to_arrow_BAK(token_span: TokenSpanArray) -> pa.ExtensionArray:
-    """
-    Convert a TokenSpanArray to a pyarrow.ExtensionArray with a type
-    of ArrowTokenSpanType and struct as the storage type. The resulting
-    extension array can be serialized and transferred with standard
-    Arrow protocols.
-
-    :param token_span: A TokenSpanArray to be converted
-    :return: pyarrow.ExtensionArray containing TokenSpan data
-    """
-    # Create arrays for begins/ends
-    token_begins_array = pa.array(token_span.begin_token)
-    token_ends_array = pa.array(token_span.end_token)
-    token_span_arrays = [token_begins_array, token_ends_array]
-
-    num_char_span_splits = 0
-
-    # If TokenSpan arrays have greater length than Span arrays, pad Span
-    if len(token_span.begin_token) > len(token_span.tokens.begin):
-
-        padding = np.zeros(len(token_span.begin_token) - len(token_span.tokens.begin),
-                           token_span.tokens.begin.dtype)
-
-        isnull = np.append(np.full(len(token_span.tokens.begin), False), np.full(len(padding), True))
-        char_begins_padded = np.append(token_span.tokens.begin, padding)
-        char_ends_padded = np.append(token_span.tokens.end, padding)
-        char_begins_array = pa.array(char_begins_padded, mask=isnull)
-        char_ends_array = pa.array(char_ends_padded, mask=isnull)
-        char_span_arrays = [char_begins_array, char_ends_array]
-
-    # If TokenSpan arrays have less length than Span arrays, split Span into multiple arrays
-    elif len(token_span.begin_token) < len(token_span.tokens.begin):
-
-        char_begins_array = pa.array(token_span.tokens.begin)
-        char_ends_array = pa.array(token_span.tokens.end)
 
-        char_span_arrays = []
-        while len(char_begins_array) >= len(token_begins_array):
-            char_begins_split = char_begins_array[:len(token_begins_array)]
-            char_ends_split = char_ends_array[:len(token_ends_array)]
-
-            char_span_arrays.extend([char_begins_split, char_ends_split])
-            num_char_span_splits += 1
+    # Get either single document as a list or use a list of all if multiple docs
+    assert len(token_span.tokens) > 0
+    if all([token is token_span.tokens[0] for token in token_span.tokens]):
+        tokens_arrays = [token_span.tokens[0]]
+        tokens_indices = pa.array([0] * len(token_span.tokens))
+    else:
+        tokens_arrays = token_span.tokens
+        tokens_indices = pa.array(range(len(tokens_arrays)))
 
-            char_begins_array = char_begins_array[len(token_begins_array):]
-            char_ends_array = char_ends_array[len(token_ends_array):]
+    # Convert each token SpanArray to Arrow and get as raw storage
+    arrow_tokens_arrays = [span_to_arrow(sa).storage for sa in tokens_arrays]
 
-        # Pad the final split
-        if len(char_begins_array) > 0:
-            padding = np.zeros(len(token_begins_array) - len(char_begins_array),
-                               token_span.tokens.begin.dtype)
-            isnull = np.append(np.full(len(char_begins_array), False), np.full(len(padding), True))
-            char_begins_padded = np.append(char_begins_array.to_numpy(), padding)
-            char_ends_padded = np.append(char_ends_array.to_numpy(), padding)
-            char_begins_split = pa.array(char_begins_padded, mask=isnull)
-            char_ends_split = pa.array(char_ends_padded, mask=isnull)
-            char_span_arrays.extend([char_begins_split, char_ends_split])
-            num_char_span_splits += 1
+    # Create a list array where each element is an ArrowSpanArray
+    # TODO: pyarrow.lib.ArrowNotImplementedError: ('Sequence converter for type dictionary<values=string, indices=int8, ordered=0> not implemented', 'Conversion failed for column ts1 with type TokenSpanDtype')
+    #arrow_tokens_arrays_array = pa.array(arrow_tokens_arrays, type=pa.list_(arrow_tokens_arrays[0].type))
+    offsets = [0] + [len(a) for a in arrow_tokens_arrays]
+    values = pa.concat_arrays(arrow_tokens_arrays)  # TODO: can't concat extension arrays?
+    arrow_tokens_arrays_array = pa.ListArray.from_arrays(offsets, values)
 
-    # TokenSpan arrays are equal length to Span arrays
-    else:
-        char_begins_array = pa.array(token_span.tokens.begin)
-        char_ends_array = pa.array(token_span.tokens.end)
-        char_span_arrays = [char_begins_array, char_ends_array]
+    # Create a dictionary array mapping each token SpanArray index used to the list of ArrowSpanArrays
+    tokens_dict_array = pa.DictionaryArray.from_arrays(tokens_indices, arrow_tokens_arrays_array)
 
-    typ = ArrowTokenSpanType(token_begins_array.type, token_span.target_text, num_char_span_splits)
+    typ = ArrowTokenSpanType(token_begins_array.type, tokens_dict_array.type)
     fields = list(typ.storage_type)
 
-    storage = pa.StructArray.from_arrays(token_span_arrays + char_span_arrays, fields=fields)
+    storage = pa.StructArray.from_arrays([token_begins_array, token_ends_array, tokens_dict_array], fields=fields)
 
     return pa.ExtensionArray.from_storage(typ, storage)
 
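The new encoding makes every row of a single-document TokenSpanArray share one token array: the dictionary indices all repeat 0 while the dictionary values hold the token storage exactly once, wrapped in a one-element ListArray. A standalone sketch of that shape on toy data (the struct fields here are placeholders, not the real ArrowSpanType storage):

import pyarrow as pa

# Toy stand-in for one document's token span storage (a struct array).
token_storage = pa.array([
    {"begins": 0, "ends": 5},
    {"begins": 6, "ends": 11},
])

# Offsets are cumulative end positions, so a second document's tokens would
# append another offset rather than restarting from zero.
offsets = pa.array([0, len(token_storage)], type=pa.int32())
tokens_list = pa.ListArray.from_arrays(offsets, token_storage)

# Three token spans over the same document: every index points at list
# element 0, so the token array is stored once however many rows use it.
indices = pa.array([0, 0, 0], type=pa.int32())
tokens_dict = pa.DictionaryArray.from_arrays(indices, tokens_list)
print(tokens_dict.type)  # dictionary<values=list<...>, indices=int32, ordered=0>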

@@ -335,46 +224,41 @@ def arrow_to_token_span(extension_array: pa.ExtensionArray) -> TokenSpanArray:
 
     assert pa.types.is_struct(extension_array.storage.type)
 
-    # Get target text from the begins field metadata and decode string
-    metadata = extension_array.storage.type[ArrowTokenSpanType.BEGINS_NAME].metadata
-    target_text = metadata[ArrowSpanType.TARGET_TEXT_KEY]
-    if isinstance(target_text, bytes):
-        target_text = target_text.decode()
-
     # Get the begins/ends pyarrow arrays
     token_begins_array = extension_array.storage.field(ArrowTokenSpanType.BEGINS_NAME)
     token_ends_array = extension_array.storage.field(ArrowTokenSpanType.ENDS_NAME)
 
-    # Check if CharSpans have been split
-    num_char_span_splits = extension_array.type.num_char_span_splits
-    if num_char_span_splits > 0:
-        char_begins_splits = []
-        char_ends_splits = []
-        for i in range(num_char_span_splits):
-            char_begins_splits.append(
-                extension_array.storage.field(ArrowSpanType.BEGINS_NAME + "_{}".format(i)))
-            char_ends_splits.append(
-                extension_array.storage.field(ArrowSpanType.ENDS_NAME + "_{}".format(i)))
-        char_begins_array = pa.concat_arrays(char_begins_splits)
-        char_ends_array = pa.concat_arrays(char_ends_splits)
-    else:
-        char_begins_array = extension_array.storage.field(ArrowSpanType.BEGINS_NAME)
-        char_ends_array = extension_array.storage.field(ArrowSpanType.ENDS_NAME)
+    # Get the tokens as a dictionary array where indices map to a list of ArrowSpanArrays
+    tokens_dict_array = extension_array.storage.field(ArrowTokenSpanType.TOKENS_NAME)
+    tokens_indices = tokens_dict_array.indices
+    arrow_tokens_arrays_array = tokens_dict_array.dictionary
+
+    # Break up the list of ArrowSpanArrays and convert back to individual SpanArrays
+    tokens_arrays = []
+    span_type = None
+    for i in range(1, len(arrow_tokens_arrays_array.offsets)):
+        start = arrow_tokens_arrays_array.offsets[i - 1].as_py()
+        stop = arrow_tokens_arrays_array.offsets[i].as_py()
+        arrow_tokens_array = arrow_tokens_arrays_array.values[start:stop]
+
+        # Make an instance of ArrowSpanType
+        if span_type is None:
+            begins_array = arrow_tokens_array.field(ArrowSpanType.BEGINS_NAME)
+            target_text_dict_array = arrow_tokens_array.field(ArrowSpanType.TARGET_TEXT_DICT_NAME)
+            span_type = ArrowSpanType(begins_array.type, target_text_dict_array.type)
+
+        # Re-make the Arrow extension type to convert back to a SpanArray
+        tokens_array = arrow_to_span(pa.ExtensionArray.from_storage(span_type, arrow_tokens_array))
+        tokens_arrays.append(tokens_array)
 
-    # Remove any trailing padding
-    if char_begins_array.null_count > 0:
-        char_begins_array = char_begins_array[:-char_begins_array.null_count]
-        char_ends_array = char_ends_array[:-char_ends_array.null_count]
+    # Map the token indices to the actual token SpanArray for each element in the TokenSpanArray
+    tokens = [tokens_arrays[i.as_py()] for i in tokens_indices]
 
     # Zero-copy convert arrays to numpy
     token_begins = token_begins_array.to_numpy()
     token_ends = token_ends_array.to_numpy()
-    char_begins = char_begins_array.to_numpy()
-    char_ends = char_ends_array.to_numpy()
 
-    # Create the SpanArray, then the TokenSpanArray
-    char_span = SpanArray(target_text, char_begins, char_ends)
-    return TokenSpanArray(char_span, token_begins, token_ends)
+    return TokenSpanArray(tokens, token_begins, token_ends)
 
 
 class ArrowTensorType(pa.PyExtensionType):
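On the read side, the offset-walking loop recovers each document's token storage by slicing the flat values child of the ListArray between consecutive offsets. A self-contained sketch of the same idiom on toy data, not the commit's types:

import pyarrow as pa

values = pa.array([10, 20, 30, 40, 50])
list_array = pa.ListArray.from_arrays(pa.array([0, 2, 5], type=pa.int32()), values)

# Mirror the loop over arrow_tokens_arrays_array.offsets above: each pair of
# consecutive offsets brackets one list element's slice of the values child.
pieces = []
for i in range(1, len(list_array.offsets)):
    start = list_array.offsets[i - 1].as_py()
    stop = list_array.offsets[i].as_py()
    pieces.append(list_array.values[start:stop])

print([p.to_pylist() for p in pieces])  # [[10, 20], [30, 40, 50]]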
