diff --git a/fn/iters.py b/fn/iters.py index 8dc5f06..c2f0891 100644 --- a/fn/iters.py +++ b/fn/iters.py @@ -219,9 +219,10 @@ def flatten(items): http://docs.python.org/3.4/library/itertools.html#itertools-recipes """ + str_type = basestring if version_info[0] < 3 else str for item in items: is_iterable = isinstance(item, Iterable) - is_string_or_bytes = isinstance(item, (str, bytes, bytearray)) + is_string_or_bytes = isinstance(item, (str_type, bytes, bytearray)) if is_iterable and not is_string_or_bytes: for i in flatten(item): yield i diff --git a/tests.py b/tests.py index a5454b4..0db2b98 100644 --- a/tests.py +++ b/tests.py @@ -621,11 +621,13 @@ def test_flatten(self): self.assertEqual([1,1,2,1,2,3], list(iters.flatten(generators))) # flat list should return itself self.assertEqual([1,2,3], list(iters.flatten([1,2,3]))) - # Don't flatten strings, bytes, or bytearrays + # Don't flatten strings/unicode, bytes, or bytearrays self.assertEqual([2,"abc",1], list(iters.flatten([2,"abc",1]))) self.assertEqual([2, b'abc', 1], list(iters.flatten([2, b'abc', 1]))) self.assertEqual([2, bytearray(b'abc'), 1], list(iters.flatten([2, bytearray(b'abc'), 1]))) + self.assertEqual([bytearray(b'abc'), b'\xd1\x8f'.decode('utf8'), b'y'], + list(iters.flatten([bytearray(b'abc'), b'\xd1\x8f'.decode('utf8'), b'y']))) def test_accumulate(self): self.assertEqual([1,3,6,10,15], list(iters.accumulate([1,2,3,4,5])))