--- /dev/null
+"""\\r
+A library of useful helper classes to the SAX classes, for the\r
+convenience of application and driver writers.\r
+"""\r
+\r
+import os, urlparse, urllib, types\r
+import io\r
+import sys\r
+import handler\r
+import xmlreader\r
+\r
+try:\r
+ _StringTypes = [types.StringType, types.UnicodeType]\r
+except AttributeError:\r
+ _StringTypes = [types.StringType]\r
+\r
+def __dict_replace(s, d):\r
+ """Replace substrings of a string using a dictionary."""\r
+ for key, value in d.items():\r
+ s = s.replace(key, value)\r
+ return s\r
+\r
+def escape(data, entities={}):\r
+ """Escape &, <, and > in a string of data.\r
+\r
+ You can escape other strings of data by passing a dictionary as\r
+ the optional entities parameter. The keys and values must all be\r
+ strings; each key will be replaced with its corresponding value.\r
+ """\r
+\r
+ # must do ampersand first\r
+ data = data.replace("&", "&")\r
+ data = data.replace(">", ">")\r
+ data = data.replace("<", "<")\r
+ if entities:\r
+ data = __dict_replace(data, entities)\r
+ return data\r
+\r
+def unescape(data, entities={}):\r
+ """Unescape &, <, and > in a string of data.\r
+\r
+ You can unescape other strings of data by passing a dictionary as\r
+ the optional entities parameter. The keys and values must all be\r
+ strings; each key will be replaced with its corresponding value.\r
+ """\r
+ data = data.replace("<", "<")\r
+ data = data.replace(">", ">")\r
+ if entities:\r
+ data = __dict_replace(data, entities)\r
+ # must do ampersand last\r
+ return data.replace("&", "&")\r
+\r
+def quoteattr(data, entities={}):\r
+ """Escape and quote an attribute value.\r
+\r
+ Escape &, <, and > in a string of data, then quote it for use as\r
+ an attribute value. The \" character will be escaped as well, if\r
+ necessary.\r
+\r
+ You can escape other strings of data by passing a dictionary as\r
+ the optional entities parameter. The keys and values must all be\r
+ strings; each key will be replaced with its corresponding value.\r
+ """\r
+ entities = entities.copy()\r
+ entities.update({'\n': ' ', '\r': ' ', '\t':'	'})\r
+ data = escape(data, entities)\r
+ if '"' in data:\r
+ if "'" in data:\r
+ data = '"%s"' % data.replace('"', """)\r
+ else:\r
+ data = "'%s'" % data\r
+ else:\r
+ data = '"%s"' % data\r
+ return data\r
+\r
+\r
+def _gettextwriter(out, encoding):\r
+ if out is None:\r
+ import sys\r
+ out = sys.stdout\r
+\r
+ if isinstance(out, io.RawIOBase):\r
+ buffer = io.BufferedIOBase(out)\r
+ # Keep the original file open when the TextIOWrapper is\r
+ # destroyed\r
+ buffer.close = lambda: None\r
+ else:\r
+ # This is to handle passed objects that aren't in the\r
+ # IOBase hierarchy, but just have a write method\r
+ buffer = io.BufferedIOBase()\r
+ buffer.writable = lambda: True\r
+ buffer.write = out.write\r
+ try:\r
+ # TextIOWrapper uses this methods to determine\r
+ # if BOM (for UTF-16, etc) should be added\r
+ buffer.seekable = out.seekable\r
+ buffer.tell = out.tell\r
+ except AttributeError:\r
+ pass\r
+ # wrap a binary writer with TextIOWrapper\r
+ return _UnbufferedTextIOWrapper(buffer, encoding=encoding,\r
+ errors='xmlcharrefreplace',\r
+ newline='\n')\r
+\r
+\r
+class _UnbufferedTextIOWrapper(io.TextIOWrapper):\r
+ def write(self, s):\r
+ super(_UnbufferedTextIOWrapper, self).write(s)\r
+ self.flush()\r
+\r
+\r
+class XMLGenerator(handler.ContentHandler):\r
+\r
+ def __init__(self, out=None, encoding="iso-8859-1"):\r
+ handler.ContentHandler.__init__(self)\r
+ out = _gettextwriter(out, encoding)\r
+ self._write = out.write\r
+ self._flush = out.flush\r
+ self._ns_contexts = [{}] # contains uri -> prefix dicts\r
+ self._current_context = self._ns_contexts[-1]\r
+ self._undeclared_ns_maps = []\r
+ self._encoding = encoding\r
+\r
+ def _qname(self, name):\r
+ """Builds a qualified name from a (ns_url, localname) pair"""\r
+ if name[0]:\r
+ # Per http://www.w3.org/XML/1998/namespace, The 'xml' prefix is\r
+ # bound by definition to http://www.w3.org/XML/1998/namespace. It\r
+ # does not need to be declared and will not usually be found in\r
+ # self._current_context.\r
+ if 'http://www.w3.org/XML/1998/namespace' == name[0]:\r
+ return 'xml:' + name[1]\r
+ # The name is in a non-empty namespace\r
+ prefix = self._current_context[name[0]]\r
+ if prefix:\r
+ # If it is not the default namespace, prepend the prefix\r
+ return prefix + ":" + name[1]\r
+ # Return the unqualified name\r
+ return name[1]\r
+\r
+ # ContentHandler methods\r
+\r
+ def startDocument(self):\r
+ self._write(u'<?xml version="1.0" encoding="%s"?>\n' %\r
+ self._encoding)\r
+\r
+ def endDocument(self):\r
+ self._flush()\r
+\r
+ def startPrefixMapping(self, prefix, uri):\r
+ self._ns_contexts.append(self._current_context.copy())\r
+ self._current_context[uri] = prefix\r
+ self._undeclared_ns_maps.append((prefix, uri))\r
+\r
+ def endPrefixMapping(self, prefix):\r
+ self._current_context = self._ns_contexts[-1]\r
+ del self._ns_contexts[-1]\r
+\r
+ def startElement(self, name, attrs):\r
+ self._write(u'<' + name)\r
+ for (name, value) in attrs.items():\r
+ self._write(u' %s=%s' % (name, quoteattr(value)))\r
+ self._write(u'>')\r
+\r
+ def endElement(self, name):\r
+ self._write(u'</%s>' % name)\r
+\r
+ def startElementNS(self, name, qname, attrs):\r
+ self._write(u'<' + self._qname(name))\r
+\r
+ for prefix, uri in self._undeclared_ns_maps:\r
+ if prefix:\r
+ self._write(u' xmlns:%s="%s"' % (prefix, uri))\r
+ else:\r
+ self._write(u' xmlns="%s"' % uri)\r
+ self._undeclared_ns_maps = []\r
+\r
+ for (name, value) in attrs.items():\r
+ self._write(u' %s=%s' % (self._qname(name), quoteattr(value)))\r
+ self._write(u'>')\r
+\r
+ def endElementNS(self, name, qname):\r
+ self._write(u'</%s>' % self._qname(name))\r
+\r
+ def characters(self, content):\r
+ if not isinstance(content, unicode):\r
+ content = unicode(content, self._encoding)\r
+ self._write(escape(content))\r
+\r
+ def ignorableWhitespace(self, content):\r
+ if not isinstance(content, unicode):\r
+ content = unicode(content, self._encoding)\r
+ self._write(content)\r
+\r
+ def processingInstruction(self, target, data):\r
+ self._write(u'<?%s %s?>' % (target, data))\r
+\r
+\r
+class XMLFilterBase(xmlreader.XMLReader):\r
+ """This class is designed to sit between an XMLReader and the\r
+ client application's event handlers. By default, it does nothing\r
+ but pass requests up to the reader and events on to the handlers\r
+ unmodified, but subclasses can override specific methods to modify\r
+ the event stream or the configuration requests as they pass\r
+ through."""\r
+\r
+ def __init__(self, parent = None):\r
+ xmlreader.XMLReader.__init__(self)\r
+ self._parent = parent\r
+\r
+ # ErrorHandler methods\r
+\r
+ def error(self, exception):\r
+ self._err_handler.error(exception)\r
+\r
+ def fatalError(self, exception):\r
+ self._err_handler.fatalError(exception)\r
+\r
+ def warning(self, exception):\r
+ self._err_handler.warning(exception)\r
+\r
+ # ContentHandler methods\r
+\r
+ def setDocumentLocator(self, locator):\r
+ self._cont_handler.setDocumentLocator(locator)\r
+\r
+ def startDocument(self):\r
+ self._cont_handler.startDocument()\r
+\r
+ def endDocument(self):\r
+ self._cont_handler.endDocument()\r
+\r
+ def startPrefixMapping(self, prefix, uri):\r
+ self._cont_handler.startPrefixMapping(prefix, uri)\r
+\r
+ def endPrefixMapping(self, prefix):\r
+ self._cont_handler.endPrefixMapping(prefix)\r
+\r
+ def startElement(self, name, attrs):\r
+ self._cont_handler.startElement(name, attrs)\r
+\r
+ def endElement(self, name):\r
+ self._cont_handler.endElement(name)\r
+\r
+ def startElementNS(self, name, qname, attrs):\r
+ self._cont_handler.startElementNS(name, qname, attrs)\r
+\r
+ def endElementNS(self, name, qname):\r
+ self._cont_handler.endElementNS(name, qname)\r
+\r
+ def characters(self, content):\r
+ self._cont_handler.characters(content)\r
+\r
+ def ignorableWhitespace(self, chars):\r
+ self._cont_handler.ignorableWhitespace(chars)\r
+\r
+ def processingInstruction(self, target, data):\r
+ self._cont_handler.processingInstruction(target, data)\r
+\r
+ def skippedEntity(self, name):\r
+ self._cont_handler.skippedEntity(name)\r
+\r
+ # DTDHandler methods\r
+\r
+ def notationDecl(self, name, publicId, systemId):\r
+ self._dtd_handler.notationDecl(name, publicId, systemId)\r
+\r
+ def unparsedEntityDecl(self, name, publicId, systemId, ndata):\r
+ self._dtd_handler.unparsedEntityDecl(name, publicId, systemId, ndata)\r
+\r
+ # EntityResolver methods\r
+\r
+ def resolveEntity(self, publicId, systemId):\r
+ return self._ent_handler.resolveEntity(publicId, systemId)\r
+\r
+ # XMLReader methods\r
+\r
+ def parse(self, source):\r
+ self._parent.setContentHandler(self)\r
+ self._parent.setErrorHandler(self)\r
+ self._parent.setEntityResolver(self)\r
+ self._parent.setDTDHandler(self)\r
+ self._parent.parse(source)\r
+\r
+ def setLocale(self, locale):\r
+ self._parent.setLocale(locale)\r
+\r
+ def getFeature(self, name):\r
+ return self._parent.getFeature(name)\r
+\r
+ def setFeature(self, name, state):\r
+ self._parent.setFeature(name, state)\r
+\r
+ def getProperty(self, name):\r
+ return self._parent.getProperty(name)\r
+\r
+ def setProperty(self, name, value):\r
+ self._parent.setProperty(name, value)\r
+\r
+ # XMLFilter methods\r
+\r
+ def getParent(self):\r
+ return self._parent\r
+\r
+ def setParent(self, parent):\r
+ self._parent = parent\r
+\r
+# --- Utility functions\r
+\r
+def prepare_input_source(source, base = ""):\r
+ """This function takes an InputSource and an optional base URL and\r
+ returns a fully resolved InputSource object ready for reading."""\r
+\r
+ if type(source) in _StringTypes:\r
+ source = xmlreader.InputSource(source)\r
+ elif hasattr(source, "read"):\r
+ f = source\r
+ source = xmlreader.InputSource()\r
+ source.setByteStream(f)\r
+ if hasattr(f, "name"):\r
+ source.setSystemId(f.name)\r
+\r
+ if source.getByteStream() is None:\r
+ try:\r
+ sysid = source.getSystemId()\r
+ basehead = os.path.dirname(os.path.normpath(base))\r
+ encoding = sys.getfilesystemencoding()\r
+ if isinstance(sysid, unicode):\r
+ if not isinstance(basehead, unicode):\r
+ try:\r
+ basehead = basehead.decode(encoding)\r
+ except UnicodeDecodeError:\r
+ sysid = sysid.encode(encoding)\r
+ else:\r
+ if isinstance(basehead, unicode):\r
+ try:\r
+ sysid = sysid.decode(encoding)\r
+ except UnicodeDecodeError:\r
+ basehead = basehead.encode(encoding)\r
+ sysidfilename = os.path.join(basehead, sysid)\r
+ isfile = os.path.isfile(sysidfilename)\r
+ except UnicodeError:\r
+ isfile = False\r
+ if isfile:\r
+ source.setSystemId(sysidfilename)\r
+ f = open(sysidfilename, "rb")\r
+ else:\r
+ source.setSystemId(urlparse.urljoin(base, source.getSystemId()))\r
+ f = urllib.urlopen(source.getSystemId())\r
+\r
+ source.setByteStream(f)\r
+\r
+ return source\r