Skip to content

Commit 6dbbfe0

Browse files
committed
Initial mods
1 parent fdf40e0 commit 6dbbfe0

File tree

2 files changed

+64
-3
lines changed

2 files changed

+64
-3
lines changed

text_extensions_for_pandas/array/arrow_conversion.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,40 @@ class ArrowTokenSpanType(pa.PyExtensionType):
6868
PyArrow extension type definition for conversions to/from TokenSpan columns
6969
"""
7070

71+
BEGINS_NAME = "token_begins"
72+
ENDS_NAME = "token_ends"
73+
74+
def __init__(self, index_dtype, token_dict_dtype):
75+
"""
76+
Create an instance of a TokenSpan data type with given index type and
77+
target text that will be stored in Field metadata.
78+
79+
:param index_dtype:
80+
:param target_text:
81+
"""
82+
assert pa.types.is_integer(index_dtype)
83+
84+
token_span_fields = [
85+
pa.field(self.BEGINS_NAME, pa.dictionary(index_dtype, token_dict_dtype)),
86+
pa.field(self.ENDS_NAME, index_dtype),
87+
]
88+
89+
90+
fields = token_span_fields
91+
92+
pa.PyExtensionType.__init__(self, pa.struct(fields))
93+
94+
def __reduce__(self):
95+
index_dtype = self.storage_type[self.ENDS_NAME].type
96+
token_dict_dtype = self.storage_type[self.BEGINS_NAME].type.value_type
97+
return ArrowTokenSpanType, (index_dtype, token_dict_dtype)
98+
99+
100+
class ArrowTokenSpanTypeBAK(pa.PyExtensionType):
101+
"""
102+
PyArrow extension type definition for conversions to/from TokenSpan columns
103+
"""
104+
71105
BEGINS_NAME = "token_begins"
72106
ENDS_NAME = "token_ends"
73107
TARGET_TEXT_DICT_NAME = "token_spans"
@@ -192,6 +226,34 @@ def token_span_to_arrow(token_span: TokenSpanArray) -> pa.ExtensionArray:
192226
extension array can be serialized and transferred with standard
193227
Arrow protocols.
194228
229+
:param token_span: A TokenSpanArray to be converted
230+
:return: pyarrow.ExtensionArray containing TokenSpan data
231+
"""
232+
# Create arrays for begins/ends
233+
token_begins_array = pa.array(token_span.begin_token)
234+
token_ends_array = pa.array(token_span.end_token)
235+
#token_span_arrays = [token_begins_array, token_ends_array]
236+
237+
arrow_span_ext_array = span_to_arrow(token_span.tokens[0])
238+
arrow_span_array = arrow_span_ext_array.storage
239+
240+
token_begins_dict_array = pa.DictionaryArray.from_arrays(token_begins_array, arrow_span_array)
241+
242+
typ = ArrowTokenSpanType(token_begins_array.type, arrow_span_array.type)
243+
fields = list(typ.storage_type)
244+
245+
storage = pa.StructArray.from_arrays([token_begins_dict_array, token_ends_array], fields=fields)
246+
247+
return pa.ExtensionArray.from_storage(typ, storage)
248+
249+
250+
def token_span_to_arrow_BAK(token_span: TokenSpanArray) -> pa.ExtensionArray:
251+
"""
252+
Convert a TokenSpanArray to a pyarrow.ExtensionArray with a type
253+
of ArrowTokenSpanType and struct as the storage type. The resulting
254+
extension array can be serialized and transferred with standard
255+
Arrow protocols.
256+
195257
:param token_span: A TokenSpanArray to be converted
196258
:return: pyarrow.ExtensionArray containing TokenSpan data
197259
"""

text_extensions_for_pandas/array/test_token_span.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,6 @@ def test_as_frame(self):
365365
self.assertEqual(len(df), len(arr))
366366

367367

368-
@pytest.mark.skip("Feather not yet reimplemented")
369368
class TokenSpanArrayIOTests(ArrayTestBase):
370369

371370
def do_roundtrip(self, df):
@@ -384,7 +383,7 @@ def test_feather(self):
384383
self.do_roundtrip(df1)
385384

386385
# More token spans than tokens
387-
ts2 = TokenSpanArray(toks, [0, 1, 2, 3, 0, 2, 0], [1, 2, 3, 4, 2, 4, 4])
386+
"""ts2 = TokenSpanArray(toks, [0, 1, 2, 3, 0, 2, 0], [1, 2, 3, 4, 2, 4, 4])
388387
df2 = pd.DataFrame({"ts2": ts2})
389388
self.do_roundtrip(df2)
390389
@@ -405,7 +404,7 @@ def test_feather(self):
405404
406405
# All columns together, TokenSpan arrays padded as needed
407406
df = pd.concat([df1, df2, df3, df4], axis=1)
408-
self.do_roundtrip(df)
407+
self.do_roundtrip(df)"""
409408

410409

411410
@pytest.fixture

0 commit comments

Comments
 (0)