Skip to content

Commit

Permalink
Support py3, don't support py2
Browse files Browse the repository at this point in the history
  • Loading branch information
filipsalo committed Jun 27, 2020
1 parent b251e99 commit 66525cc
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 113 deletions.
13 changes: 4 additions & 9 deletions README.markdown
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ attributes.
Usage example
-------------

>>> from cStringIO import StringIO
>>> output = StringIO()
>>> from io import BytesIO
>>> output = BytesIO()

>>> from streamxmlwriter import XMLWriter
>>> writer = XMLWriter(output, pretty_print=True)
Expand All @@ -30,13 +30,8 @@ Usage example
>>> writer.start("empty")
>>> writer.close()

>>> print output.getvalue()
<foo xmlns:a="http://example.org/ns" two="2" a:one="1">
<bar>something</bar>
<!--hello-->
<a:baz x="y">whatnot</a:baz>
<empty />
</foo>
>>> output.getvalue()
b'<foo xmlns:a="http://example.org/ns" two="2" a:one="1">\n <bar>something</bar>\n <!--hello-->\n <a:baz x="y">whatnot</a:baz>\n <empty />\n</foo>'


The API
Expand Down
52 changes: 29 additions & 23 deletions streamxmlwriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def sorter_factory(attrib_order):
tag = _nssplitname(tag)
attrib_order[tag] = dict([(_nssplitname(name), n)
for (n, name) in enumerate(names)])
for tag, order in attrib_order.iteritems():
for tag, order in attrib_order.items():
order.setdefault(None, len(order))

def asort(pairs, tag):
Expand All @@ -127,8 +127,9 @@ def tostring(element, *args, **kwargs):
keyword arguments are passed on to the underlying `XMLWriter`.
"""
import cStringIO
out = cStringIO.StringIO()
import io

out = io.BytesIO()
writer = XMLWriter(out, *args, **kwargs)
writer.element(element)
writer.close()
Expand Down Expand Up @@ -169,7 +170,6 @@ def __init__(self, file, encoding="utf-8",
"""
self.file = file
self.write = file.write
self.encoding = encoding
self._pretty_print = pretty_print
self._sort = sort
Expand All @@ -187,6 +187,12 @@ def __init__(self, file, encoding="utf-8",
self.declaration()
self._wrote_data = False

def write(self, *data):
for datum in data:
if not isinstance(datum, bytes):
datum = bytes(datum, self.encoding)
self.file.write(datum)

def start(self, tag, attributes=None, nsmap=None, **kwargs):
"""Open a new `tag` element.
Expand All @@ -201,18 +207,18 @@ def start(self, tag, attributes=None, nsmap=None, **kwargs):
self.write(">")
self._start_tag_open = False
if self._pretty_print and self._tags and not self._wrote_data:
self.write("\n" + INDENT * len(self._tags))
self.write("\n", INDENT * len(self._tags))

# Copy old namespaces and cnames
if self._tags:
_, old_namespaces, _ = self._tags[-1]
else:
old_namespaces = {'': None}
old_namespaces = {'': ''}
namespaces = old_namespaces.copy()
if nsmap:
self._new_namespaces.update(reversed(item) for item in nsmap.iteritems())
self._new_namespaces.update(reversed(item) for item in nsmap.items())
values = self._new_namespaces.values()
for uri, prefix in namespaces.items():
for uri, prefix in list(namespaces.items()):
if prefix in values:
del namespaces[uri]

Expand All @@ -221,32 +227,32 @@ def start(self, tag, attributes=None, nsmap=None, **kwargs):

# Write tag name (cname)
tag = _nssplitname(tag)
self.write("<" + _cname(tag, namespaces, cnames))
self.write("<", _cname(tag, namespaces, cnames))

# Make cnames for the attributes
if attributes:
kwargs.update(attributes)
attributes = [(_nssplitname(name), value)
for (name, value) in kwargs.iteritems()]
for (name, value) in kwargs.items()]
attributes = [(name, _cname(name, namespaces, cnames), value)
for (name, value) in attributes]

# Write namespace declarations for all new mappings
for (uri, prefix) in sorted(namespaces.iteritems(),
for (uri, prefix) in sorted(namespaces.items(),
key=lambda x: x[1]):
if uri not in old_namespaces or old_namespaces.get(uri) != prefix:
value = escape_attribute(uri, self.encoding)
if prefix:
self.write(" xmlns:" + prefix + "=\"" + value + "\"")
self.write(" xmlns:", bytes(prefix, self.encoding), "=\"", value, "\"")
else:
self.write(" xmlns=\"" + value + "\"")
self.write(" xmlns=\"", value, "\"")

# Write the attributes
if self._sort:
self._sort(attributes, tag)
for (name, cname, value) in attributes:
value = escape_attribute(value, self.encoding)
self.write(" " + cname + "=\"" + value + "\"")
self.write(" ", cname, "=\"", value, "\"")

self._new_namespaces = {}
self._start_tag_open = True
Expand All @@ -270,12 +276,12 @@ def end(self, tag=None):
if self._abbrev_empty:
self.write(" />")
else:
self.write("></" + _cname(open_tag, namespaces, cnames) + ">")
self.write("></", _cname(open_tag, namespaces, cnames), ">")
self._start_tag_open = False
else:
if self._pretty_print and not self._wrote_data:
self.write("\n" + INDENT * len(self._tags))
self.write("</" + _cname(open_tag, namespaces, cnames) + ">")
self.write("\n", INDENT * len(self._tags))
self.write("</", _cname(open_tag, namespaces, cnames), ">")
self._wrote_data = False

def start_ns(self, prefix, uri):
Expand Down Expand Up @@ -337,25 +343,25 @@ def declaration(self):
self._wrote_declaration = True
xml = declaration

def _comment_or_pi(self, data):
def _comment_or_pi(self, *data):
"""Write a comment or PI, using special rules for
pretty-printing."""
self._close_start()
if self._pretty_print:
if ((self._tags and not self._wrote_data) or
(self._started and not self._tags)):
self.write("\n" + INDENT * len(self._tags))
self.write(data)
self.write("\n", INDENT * len(self._tags))
self.write(*data)
if self._pretty_print and not self._started:
self.write("\n")

def comment(self, data):
"""Add an XML comment."""
self._comment_or_pi("<!--" + escape_cdata(data, self.encoding) + "-->")
self._comment_or_pi("<!--", escape_cdata(data, self.encoding), "-->")

def pi(self, target, data):
"""Add an XML processing instruction."""
self._comment_or_pi("<?" + target + " " + data + "?>")
self._comment_or_pi("<?", target, " ", data, "?>")

def close(self):
"""Close all open elements."""
Expand Down Expand Up @@ -389,7 +395,7 @@ def iterwrite(self, events):

def delayed_iterator(iterable):
iterable = iter(iterable)
previous = iterable.next()
previous = next(iterable)
for item in iterable:
yield previous
previous = item
Expand Down
Loading

0 comments on commit 66525cc

Please sign in to comment.