Skip to content

Commit

Permalink
Reworked attribute sorting a bit and made it namespace-aware (needs tests)
Browse files: browse the repository at this point in the history
  • Loading branch information
filipsalo committed Apr 11, 2010
1 parent 1133b6c commit 3d3b8f7
Showing 1 changed file with 56 additions and 55 deletions.
111 changes: 56 additions & 55 deletions streamxmlwriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,34 @@ def escape_cdata(data, encoding):
return data.encode(encoding, "xmlcharrefreplace")


def _nssplitname(name):
if name is None:
return None
if not name[0] == "{":
return ("", name)
return tuple(name[1:].split("}", 1))

def _cname(name, nsmap, cnames):
"""Return a cname from its {ns}tag form."""
if not isinstance(name, tuple):
name = _nssplitname(name)
if name in cnames:
return cnames[name]
uri, ncname = name
if not uri:
for uri in nsmap:
if not nsmap[uri]:
break
else:
uri = ""
prefix = nsmap.setdefault(uri, "ns" + str(len(nsmap)+1))
if prefix:
cname = prefix + ":" + ncname
else:
cname = ncname
cnames[name] = cname
return cname

def sorter_factory(attrib_order):
"""Return a function that sorts a list of (key, value) pairs.
Expand All @@ -68,25 +96,24 @@ def sorter_factory(attrib_order):
class.
"""
attrib_order = attrib_order.copy()
for tag, names in attrib_order.iteritems():
attrib_order[tag] = dict((name, n) for (n, name) in enumerate(names))
attrib_order[tag] = dict([(_nssplitname(name), n)
for (n, name) in enumerate(names)])
for tag, order in attrib_order.iteritems():
order[None] = len(order)
order.setdefault(None, len(order))

def asort(pairs, tag):
"""Sort a list of ``(key, value)`` pairs), using the custom
sort order for the given `tag` name."""
"""Sort a list of ``(name, cname, value)`` tuples), using the
custom sort order for the given `tag` name."""
def key(item):
"""Return a sort key for a ``(key, value)`` pair."""
(_, (_, name)), _ = item
"""Return a sort key for a ``(name, cname, value)`` tuple."""
(ncname, cname, value) = item
if tag not in attrib_order:
return name
return ncname
keys = attrib_order[tag]
if name in keys:
return keys[name], name
else:
return keys[None], name
return sorted(pairs, key=key)
return keys.get(ncname, keys[None]), ncname
pairs.sort(key=key)
return asort


Expand Down Expand Up @@ -143,9 +170,11 @@ def __init__(self, file, encoding="utf-8",
self.encoding = encoding
self._pretty_print = pretty_print
self._sort = sort
self._abbrev_empty = abbrev_empty
if isinstance(sort, dict):
self._sort = sorter_factory(sort)
elif sort:
self._sort = lambda attributes, tag: attributes.sort()
self._abbrev_empty = abbrev_empty
self._tags = []
self._start_tag_open = False
self._new_namespaces = {}
Expand All @@ -155,30 +184,6 @@ def __init__(self, file, encoding="utf-8",
self.declaration()
self._wrote_data = False

def _cname(self, name, nsmap, cnames):
"""Return a cname from its {ns}tag form."""
if name in cnames:
return cnames[name]
if not name[0] == "{":
for uri in nsmap:
if not nsmap[uri]:
break
else:
uri = ""
name = "{" + uri + "}" + name
uri, ncname = name[1:].split("}", 1)
if uri not in nsmap:
prefix = "ns" + str(len(nsmap)+1)
nsmap[uri] = prefix
else:
prefix = nsmap[uri]
if prefix:
cname = prefix + ":" + ncname
else:
cname = ncname
cnames[name] = cname, (uri, ncname)
return cname, (uri, ncname)

def start(self, tag, attributes=None, nsmap=None, **kwargs):
"""Open a new `tag` element.
Expand Down Expand Up @@ -212,37 +217,33 @@ def start(self, tag, attributes=None, nsmap=None, **kwargs):
cnames = {}

# Write tag name (cname)
tag, _ = self._cname(tag, namespaces, cnames)
tag = _cname(tag, namespaces, cnames)
self.write("<" + tag)

# Make cnames for the attributes
if attributes:
kwargs.update(attributes)
attributes = sorted([(self._cname(name, namespaces, cnames), value)
for (name, value) in kwargs.iteritems()])
attributes = [(_nssplitname(name), value)
for (name, value) in kwargs.iteritems()]
attributes = [(name, _cname(name, namespaces, cnames), value)
for (name, value) in attributes]

# Write namespace declarations for all new mappings
for (uri, prefix) in sorted(namespaces.iteritems(),
key=lambda x: x[1]):
if uri not in old_namespaces or old_namespaces.get(uri) != prefix:
value = escape_attribute(uri, self.encoding)
if prefix:
self.write(" xmlns:" + prefix + "=\""
+ escape_attribute(uri, self.encoding)
+ "\"")
self.write(" xmlns:" + prefix + "=\"" + value + "\"")
else:
self.write(" xmlns=\""
+ escape_attribute(uri, self.encoding)
+ "\"")
self.write(" xmlns=\"" + value + "\"")

# Write the attributes
if callable(self._sort):
attributes = self._sort(attributes, tag)
elif self._sort:
attributes.sort(key=lambda x: x[0][1])
for ((cname, name), value) in attributes:
self.write(" " + cname + "=\""
+ escape_attribute(value, self.encoding)
+ "\"")
if self._sort:
self._sort(attributes, tag)
for (name, cname, value) in attributes:
value = escape_attribute(value, self.encoding)
self.write(" " + cname + "=\"" + value + "\"")

self._new_namespaces = {}
self._start_tag_open = True
Expand All @@ -258,7 +259,7 @@ def end(self, tag=None):
"""
open_tag, namespaces, cnames = self._tags.pop()
if tag is not None:
tag, _ = self._cname(tag, namespaces, cnames)
tag = _cname(tag, namespaces, cnames)
if open_tag != tag:
raise XMLSyntaxError("Start and end tag mismatch: %s and /%s."
% (open_tag, tag))
Expand Down

0 comments on commit 3d3b8f7

Please sign in to comment.