Skip to content

Commit

Permalink
Use black formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
filipsalo committed Jun 27, 2020
1 parent 8093d76 commit 68997ba
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 61 deletions.
60 changes: 37 additions & 23 deletions streamxmlwriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def escape_attribute(value, encoding):
value = value.replace("&", "&")
if "<" in value:
value = value.replace("<", "&lt;")
if "\"" in value:
value = value.replace("\"", "&quot;")
if '"' in value:
value = value.replace('"', "&quot;")
return value.encode(encoding, "xmlcharrefreplace")


Expand All @@ -67,6 +67,7 @@ def _nssplitname(name):
return ("", name)
return tuple(name[1:].split("}", 1))


def _cname(name, nsmap, cnames):
"""Return a cname from its {ns}tag form."""
if not isinstance(name, tuple):
Expand All @@ -80,14 +81,15 @@ def _cname(name, nsmap, cnames):
break
else:
uri = ""
prefix = nsmap.setdefault(uri, "ns" + str(len(nsmap)+1))
prefix = nsmap.setdefault(uri, "ns" + str(len(nsmap) + 1))
if prefix:
cname = prefix + ":" + ncname
else:
cname = ncname
cnames[name] = cname
return cname


def sorter_factory(attrib_order):
"""Return a function that sorts a list of (key, value) pairs.
Expand All @@ -100,22 +102,26 @@ def sorter_factory(attrib_order):
attrib_order = {}
for tag, names in items:
tag = _nssplitname(tag)
attrib_order[tag] = dict([(_nssplitname(name), n)
for (n, name) in enumerate(names)])
attrib_order[tag] = dict(
[(_nssplitname(name), n) for (n, name) in enumerate(names)]
)
for tag, order in attrib_order.items():
order.setdefault(None, len(order))

def asort(pairs, tag):
"""Sort a list of ``(name, cname, value)`` tuples), using the
custom sort order for the given `tag` name."""

def key(item):
"""Return a sort key for a ``(name, cname, value)`` tuple."""
(ncname, cname, value) = item
if tag not in attrib_order:
return ncname
keys = attrib_order[tag]
return keys.get(ncname, keys[None]), ncname

pairs.sort(key=key)

return asort


Expand All @@ -142,8 +148,10 @@ class XMLSyntaxError(Exception):

class XMLWriter(object):
"""Stream XML writer"""
def __init__(self, file, encoding="utf-8",
pretty_print=False, sort=True, abbrev_empty=True):

def __init__(
self, file, encoding="utf-8", pretty_print=False, sort=True, abbrev_empty=True
):
"""
Create an `XMLWriter` that writes its output to `file`.
Expand Down Expand Up @@ -213,7 +221,7 @@ def start(self, tag, attributes=None, nsmap=None, **kwargs):
if self._tags:
_, old_namespaces, _ = self._tags[-1]
else:
old_namespaces = {'': ''}
old_namespaces = {"": ""}
namespaces = old_namespaces.copy()
if nsmap:
self._new_namespaces.update(reversed(item) for item in nsmap.items())
Expand All @@ -232,27 +240,29 @@ def start(self, tag, attributes=None, nsmap=None, **kwargs):
# Make cnames for the attributes
if attributes:
kwargs.update(attributes)
attributes = [(_nssplitname(name), value)
for (name, value) in kwargs.items()]
attributes = [(name, _cname(name, namespaces, cnames), value)
for (name, value) in attributes]
attributes = [(_nssplitname(name), value) for (name, value) in kwargs.items()]
attributes = [
(name, _cname(name, namespaces, cnames), value)
for (name, value) in attributes
]

# Write namespace declarations for all new mappings
for (uri, prefix) in sorted(namespaces.items(),
key=lambda x: x[1]):
for (uri, prefix) in sorted(namespaces.items(), key=lambda x: x[1]):
if uri not in old_namespaces or old_namespaces.get(uri) != prefix:
value = escape_attribute(uri, self.encoding)
if prefix:
self.write(" xmlns:", bytes(prefix, self.encoding), "=\"", value, "\"")
self.write(
" xmlns:", bytes(prefix, self.encoding), '="', value, '"'
)
else:
self.write(" xmlns=\"", value, "\"")
self.write(' xmlns="', value, '"')

# Write the attributes
if self._sort:
self._sort(attributes, tag)
for (name, cname, value) in attributes:
value = escape_attribute(value, self.encoding)
self.write(" ", cname, "=\"", value, "\"")
self.write(" ", cname, '="', value, '"')

self._new_namespaces = {}
self._start_tag_open = True
Expand All @@ -270,8 +280,9 @@ def end(self, tag=None):
if tag is not None:
tag = _nssplitname(tag)
if open_tag != tag:
raise XMLSyntaxError("Start and end tag mismatch: %s and /%s."
% (open_tag, tag))
raise XMLSyntaxError(
"Start and end tag mismatch: %s and /%s." % (open_tag, tag)
)
if self._start_tag_open:
if self._abbrev_empty:
self.write(" />")
Expand Down Expand Up @@ -336,20 +347,23 @@ def _close_start(self):
def declaration(self):
"""Write an XML declaration."""
if self._started:
raise XMLSyntaxError("Can't write XML declaration after"
" root element has been started.")
raise XMLSyntaxError(
"Can't write XML declaration after root element has been started."
)
if not self._wrote_declaration:
self.pi("xml", "version='1.0' encoding='" + self.encoding + "'")
self._wrote_declaration = True

xml = declaration

def _comment_or_pi(self, *data):
"""Write a comment or PI, using special rules for
pretty-printing."""
self._close_start()
if self._pretty_print:
if ((self._tags and not self._wrote_data) or
(self._started and not self._tags)):
if (self._tags and not self._wrote_data) or (
self._started and not self._tags
):
self.write("\n", INDENT * len(self._tags))
self.write(*data)
if self._pretty_print and not self._started:
Expand Down
100 changes: 62 additions & 38 deletions tests/test_streamxmlwriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ def test_single_element(self):
w = XMLWriter(BytesIO())
w.start("foo")
w.end()
self.assertOutput(w, b'<foo />')
self.assertOutput(w, b"<foo />")

def test_text_data(self):
w = XMLWriter(BytesIO())
w.start("foo")
w.data("bar")
w.end()
self.assertOutput(w, b'<foo>bar</foo>')
self.assertOutput(w, b"<foo>bar</foo>")

def test_single_attribute(self):
w = XMLWriter(BytesIO())
Expand All @@ -59,7 +59,7 @@ def test_sorted_attributes(self):

def test_escape_attributes(self):
w = XMLWriter(BytesIO())
w.start("foo", {"bar": "<>&\""})
w.start("foo", {"bar": '<>&"'})
w.end()
self.assertOutput(w, b'<foo bar="&lt;>&amp;&quot;" />')

Expand All @@ -68,18 +68,22 @@ def test_escape_character_data(self):
w.start("foo")
w.data("<>&")
w.end()
self.assertOutput(w, b'<foo>&lt;&gt;&amp;</foo>')
self.assertOutput(w, b"<foo>&lt;&gt;&amp;</foo>")

def test_file_encoding(self):
ts = [({},
b"<foo>\xc3\xa5\xc3\xa4\xc3\xb6\xe2\x98\x83\xe2\x9d\xa4</foo>"),
({"encoding": "us-ascii"},
b"<foo>&#229;&#228;&#246;&#9731;&#10084;</foo>"),
({"encoding": "iso-8859-1"},
b"<?xml version='1.0' encoding='iso-8859-1'?>" \
b"<foo>\xe5\xe4\xf6&#9731;&#10084;</foo>"),
({"encoding": "utf-8"},
b"<foo>\xc3\xa5\xc3\xa4\xc3\xb6\xe2\x98\x83\xe2\x9d\xa4</foo>")]
ts = [
({}, b"<foo>\xc3\xa5\xc3\xa4\xc3\xb6\xe2\x98\x83\xe2\x9d\xa4</foo>"),
({"encoding": "us-ascii"}, b"<foo>&#229;&#228;&#246;&#9731;&#10084;</foo>"),
(
{"encoding": "iso-8859-1"},
b"<?xml version='1.0' encoding='iso-8859-1'?>"
b"<foo>\xe5\xe4\xf6&#9731;&#10084;</foo>",
),
(
{"encoding": "utf-8"},
b"<foo>\xc3\xa5\xc3\xa4\xc3\xb6\xe2\x98\x83\xe2\x9d\xa4</foo>",
),
]
for (kwargs, output) in ts:
w = XMLWriter(BytesIO(), **kwargs)
w.start("foo")
Expand Down Expand Up @@ -133,14 +137,17 @@ def test_simple(self):
w.start("b")
w.start("c")
w.close()
self.assertOutput(w, b"""\
self.assertOutput(
w,
b"""\
<a>
<b>foo</b>
<b>bar</b>
<b>
<c />
</b>
</a>""")
</a>""",
)

def test_comment(self):
w = XMLWriter(BytesIO(), pretty_print=True)
Expand Down Expand Up @@ -218,8 +225,9 @@ def test_default_unbinding(self):
w.start_ns("", "")
w.start("foo")
w.close()
self.assertOutput(w, b'<foo xmlns="http://example.org/ns">'
b'<foo xmlns="" /></foo>')
self.assertOutput(
w, b'<foo xmlns="http://example.org/ns">' b'<foo xmlns="" /></foo>'
)

def test_prefix_rebinding(self):
w = XMLWriter(BytesIO())
Expand All @@ -228,53 +236,68 @@ def test_prefix_rebinding(self):
w.start_ns("a", "http://example.org/ns2")
w.start("{http://example.org/ns2}foo")
w.close()
self.assertOutput(w,b'<a:foo xmlns:a="http://example.org/ns">'
b'<a:foo xmlns:a="http://example.org/ns2" />'
b'</a:foo>')
self.assertOutput(
w,
b'<a:foo xmlns:a="http://example.org/ns">'
b'<a:foo xmlns:a="http://example.org/ns2" />'
b"</a:foo>",
)

def test_attributes_same_local_name(self):
w = XMLWriter(BytesIO())
w.start_ns("a", "http://example.org/ns1")
w.start_ns("b", "http://example.org/ns2")
w.start("foo")
w.start("bar", {"{http://example.org/ns1}attr": "1",
"{http://example.org/ns2}attr": "2"})
w.start(
"bar",
{"{http://example.org/ns1}attr": "1", "{http://example.org/ns2}attr": "2"},
)
w.close()
self.assertOutput(w, b'<foo xmlns:a="http://example.org/ns1"'
b' xmlns:b="http://example.org/ns2">'
b'<bar a:attr="1" b:attr="2" />'
b'</foo>')
self.assertOutput(
w,
b'<foo xmlns:a="http://example.org/ns1"'
b' xmlns:b="http://example.org/ns2">'
b'<bar a:attr="1" b:attr="2" />'
b"</foo>",
)

def test_attributes_same_local_one_prefixed(self):
w = XMLWriter(BytesIO())
w.start_ns("a", "http://example.org/ns")
w.start("foo")
w.start("bar", {"{http://example.org/ns}attr": "1",
"attr": "2"})
w.start("bar", {"{http://example.org/ns}attr": "1", "attr": "2"})
w.close()
self.assertOutput(w, b'<foo xmlns:a="http://example.org/ns">'
b'<bar attr="2" a:attr="1" />'
b'</foo>')
self.assertOutput(
w,
b'<foo xmlns:a="http://example.org/ns">'
b'<bar attr="2" a:attr="1" />'
b"</foo>",
)

def test_attributes_same_local_one_prefixed_one_default(self):
w = XMLWriter(BytesIO())
w.start_ns("", "http://example.org/ns1")
w.start_ns("a", "http://example.org/ns2")
w.start("{http://example.org/ns1}foo")
w.start("{http://example.org/ns1}bar",
{"{http://example.org/ns1}attr": "1",
"{http://example.org/ns2}attr": "2"})
w.start(
"{http://example.org/ns1}bar",
{"{http://example.org/ns1}attr": "1", "{http://example.org/ns2}attr": "2"},
)
w.close()
self.assertOutput(w, b'<foo xmlns="http://example.org/ns1"'
b' xmlns:a="http://example.org/ns2">'
b'<bar attr="1" a:attr="2" />'
b'</foo>')
self.assertOutput(
w,
b'<foo xmlns="http://example.org/ns1"'
b' xmlns:a="http://example.org/ns2">'
b'<bar attr="1" a:attr="2" />'
b"</foo>",
)


class TestIterwrite(XMLWriterTestCase):
def test_basic(self):
from lxml import etree
from io import BytesIO

w = XMLWriter(BytesIO())
xml = b"""\
<!--comment before--><?pi before?><foo xmlns="http://example.org/ns1">
Expand All @@ -293,6 +316,7 @@ def test_basic(self):
def test_chunked_text(self):
from lxml import etree
from io import BytesIO

for padding in (16382, 32755):
padding = b" " * padding
w = XMLWriter(BytesIO())
Expand Down

0 comments on commit 68997ba

Please sign in to comment.