diff --git a/fn/iters.py b/fn/iters.py
index 8dc5f06..c2f0891 100644
--- a/fn/iters.py
+++ b/fn/iters.py
@@ -219,9 +219,10 @@ def flatten(items):
 
     http://docs.python.org/3.4/library/itertools.html#itertools-recipes
     """
+    str_type = basestring if version_info[0] < 3 else str
     for item in items:
         is_iterable = isinstance(item, Iterable)
-        is_string_or_bytes = isinstance(item, (str, bytes, bytearray))
+        is_string_or_bytes = isinstance(item, (str_type, bytes, bytearray))
         if is_iterable and not is_string_or_bytes:
             for i in flatten(item):
                 yield i
diff --git a/tests.py b/tests.py
index a5454b4..0db2b98 100644
--- a/tests.py
+++ b/tests.py
@@ -621,11 +621,13 @@ def test_flatten(self):
         self.assertEqual([1,1,2,1,2,3], list(iters.flatten(generators)))
         # flat list should return itself
         self.assertEqual([1,2,3], list(iters.flatten([1,2,3])))
-        # Don't flatten strings, bytes, or bytearrays
+        # Don't flatten strings/unicode, bytes, or bytearrays
         self.assertEqual([2,"abc",1], list(iters.flatten([2,"abc",1])))
         self.assertEqual([2, b'abc', 1], list(iters.flatten([2, b'abc', 1])))
         self.assertEqual([2, bytearray(b'abc'), 1],
                          list(iters.flatten([2, bytearray(b'abc'), 1])))
+        self.assertEqual([bytearray(b'abc'), b'\xd1\x8f'.decode('utf8'), b'y'],
+                         list(iters.flatten([bytearray(b'abc'), b'\xd1\x8f'.decode('utf8'), b'y'])))
 
     def test_accumulate(self):
         self.assertEqual([1,3,6,10,15], list(iters.accumulate([1,2,3,4,5])))