Index: pyyaml/trunk/tests/test_yaml_ext.py
===================================================================
--- pyyaml/trunk/tests/test_yaml_ext.py	(revision 312)
+++ pyyaml/trunk/tests/test_yaml_ext.py	(revision 322)
@@ -1,195 +1,273 @@
-
-import unittest, test_appliance
 
 import _yaml, yaml
-
-test_appliance.TestAppliance.SKIP_EXT = '.skip-ext'
-
-class TestCVersion(unittest.TestCase):
-
-    def testCVersion(self):
-        self.failUnlessEqual("%s.%s.%s" % _yaml.get_version(), _yaml.get_version_string())
-
-class TestCLoader(test_appliance.TestAppliance):
-
-    def _testCScannerFileInput(self, test_name, data_filename, canonical_filename):
-        self._testCScanner(test_name, data_filename, canonical_filename, True)
-
-    def _testCScanner(self, test_name, data_filename, canonical_filename, file_input=False, Loader=yaml.Loader):
-        if file_input:
-            data = file(data_filename, 'r')
-        else:
-            data = file(data_filename, 'r').read()
-        tokens = list(yaml.scan(data, Loader=Loader))
-        ext_tokens = []
+import types, pprint
+
+yaml.PyBaseLoader = yaml.BaseLoader
+yaml.PySafeLoader = yaml.SafeLoader
+yaml.PyLoader = yaml.Loader
+yaml.PyBaseDumper = yaml.BaseDumper
+yaml.PySafeDumper = yaml.SafeDumper
+yaml.PyDumper = yaml.Dumper
+
+old_scan = yaml.scan
+def new_scan(stream, Loader=yaml.CLoader):
+    return old_scan(stream, Loader)
+
+old_parse = yaml.parse
+def new_parse(stream, Loader=yaml.CLoader):
+    return old_parse(stream, Loader)
+
+old_compose = yaml.compose
+def new_compose(stream, Loader=yaml.CLoader):
+    return old_compose(stream, Loader)
+
+old_compose_all = yaml.compose_all
+def new_compose_all(stream, Loader=yaml.CLoader):
+    return old_compose_all(stream, Loader)
+
+old_load = yaml.load
+def new_load(stream, Loader=yaml.CLoader):
+    return old_load(stream, Loader)
+
+old_load_all = yaml.load_all
+def new_load_all(stream, Loader=yaml.CLoader):
+    return old_load_all(stream, Loader)
+
+old_safe_load = yaml.safe_load
+def new_safe_load(stream):
+    return old_load(stream, yaml.CSafeLoader)
+
+old_safe_load_all = yaml.safe_load_all
+def new_safe_load_all(stream):
+    return old_load_all(stream, yaml.CSafeLoader)
+
+old_emit = yaml.emit
+def new_emit(events, stream=None, Dumper=yaml.CDumper, **kwds):
+    return old_emit(events, stream, Dumper, **kwds)
+
+old_serialize = yaml.serialize
+def new_serialize(node, stream, Dumper=yaml.CDumper, **kwds):
+    return old_serialize(node, stream, Dumper, **kwds)
+
+old_serialize_all = yaml.serialize_all
+def new_serialize_all(nodes, stream=None, Dumper=yaml.CDumper, **kwds):
+    return old_serialize_all(nodes, stream, Dumper, **kwds)
+
+old_dump = yaml.dump
+def new_dump(data, stream=None, Dumper=yaml.CDumper, **kwds):
+    return old_dump(data, stream, Dumper, **kwds)
+
+old_dump_all = yaml.dump_all
+def new_dump_all(documents, stream=None, Dumper=yaml.CDumper, **kwds):
+    return old_dump_all(documents, stream, Dumper, **kwds)
+
+old_safe_dump = yaml.safe_dump
+def new_safe_dump(data, stream=None, **kwds):
+    return old_dump(data, stream, yaml.CSafeDumper, **kwds)
+
+old_safe_dump_all = yaml.safe_dump_all
+def new_safe_dump_all(documents, stream=None, **kwds):
+    return old_dump_all(documents, stream, yaml.CSafeDumper, **kwds)
+
+def _set_up():
+    yaml.BaseLoader = yaml.CBaseLoader
+    yaml.SafeLoader = yaml.CSafeLoader
+    yaml.Loader = yaml.CLoader
+    yaml.BaseDumper = yaml.CBaseDumper
+    yaml.SafeDumper = yaml.CSafeDumper
+    yaml.Dumper = yaml.CDumper
+    yaml.scan = new_scan
+    yaml.parse = new_parse
+    yaml.compose = new_compose
+    yaml.compose_all = new_compose_all
+    yaml.load = new_load
+    yaml.load_all = new_load_all
+    yaml.safe_load = new_safe_load
+    yaml.safe_load_all = new_safe_load_all
+    yaml.emit = new_emit
+    yaml.serialize = new_serialize
+    yaml.serialize_all = new_serialize_all
+    yaml.dump = new_dump
+    yaml.dump_all = new_dump_all
+    yaml.safe_dump = new_safe_dump
+    yaml.safe_dump_all = new_safe_dump_all
+
+def _tear_down():
+    yaml.BaseLoader = yaml.PyBaseLoader
+    yaml.SafeLoader = yaml.PySafeLoader
+    yaml.Loader = yaml.PyLoader
+    yaml.BaseDumper = yaml.PyBaseDumper
+    yaml.SafeDumper = yaml.PySafeDumper
+    yaml.Dumper = yaml.PyDumper
+    yaml.scan = old_scan
+    yaml.parse = old_parse
+    yaml.compose = old_compose
+    yaml.compose_all = old_compose_all
+    yaml.load = old_load
+    yaml.load_all = old_load_all
+    yaml.safe_load = old_safe_load
+    yaml.safe_load_all = old_safe_load_all
+    yaml.emit = old_emit
+    yaml.serialize = old_serialize
+    yaml.serialize_all = old_serialize_all
+    yaml.dump = old_dump
+    yaml.dump_all = old_dump_all
+    yaml.safe_dump = old_safe_dump
+    yaml.safe_dump_all = old_safe_dump_all
+
+def test_c_version(verbose=False):
+    if verbose:
+        print _yaml.get_version()
+        print _yaml.get_version_string()
+    assert ("%s.%s.%s" % _yaml.get_version()) == _yaml.get_version_string(),    \
+            (_yaml.get_version(), _yaml.get_version_string())
+
+def _compare_scanners(py_data, c_data, verbose):
+    py_tokens = list(yaml.scan(py_data, Loader=yaml.PyLoader))
+    c_tokens = []
+    try:
+        for token in yaml.scan(c_data, Loader=yaml.CLoader):
+            c_tokens.append(token)
+        assert len(py_tokens) == len(c_tokens), (len(py_tokens), len(c_tokens))
+        for py_token, c_token in zip(py_tokens, c_tokens):
+            assert py_token.__class__ == c_token.__class__, (py_token, c_token)
+            if hasattr(py_token, 'value'):
+                assert py_token.value == c_token.value, (py_token, c_token)
+            if isinstance(py_token, yaml.StreamEndToken):
+                continue
+            py_start = (py_token.start_mark.index, py_token.start_mark.line, py_token.start_mark.column)
+            py_end = (py_token.end_mark.index, py_token.end_mark.line, py_token.end_mark.column)
+            c_start = (c_token.start_mark.index, c_token.start_mark.line, c_token.start_mark.column)
+            c_end = (c_token.end_mark.index, c_token.end_mark.line, c_token.end_mark.column)
+            assert py_start == c_start, (py_start, c_start)
+            assert py_end == c_end, (py_end, c_end)
+    finally:
+        if verbose:
+            print "PY_TOKENS:"
+            pprint.pprint(py_tokens)
+            print "C_TOKENS:"
+            pprint.pprint(c_tokens)
+
+def test_c_scanner(data_filename, canonical_filename, verbose=False):
+    _compare_scanners(open(data_filename, 'rb'),
+            open(data_filename, 'rb'), verbose)
+    _compare_scanners(open(data_filename, 'rb').read(),
+            open(data_filename, 'rb').read(), verbose)
+    _compare_scanners(open(canonical_filename, 'rb'),
+            open(canonical_filename, 'rb'), verbose)
+    _compare_scanners(open(canonical_filename, 'rb').read(),
+            open(canonical_filename, 'rb').read(), verbose)
+
+test_c_scanner.unittest = ['.data', '.canonical']
+test_c_scanner.skip = ['.skip-ext']
+
+def _compare_parsers(py_data, c_data, verbose):
+    py_events = list(yaml.parse(py_data, Loader=yaml.PyLoader))
+    c_events = []
+    try:
+        for event in yaml.parse(c_data, Loader=yaml.CLoader):
+            c_events.append(event)
+        assert len(py_events) == len(c_events), (len(py_events), len(c_events))
+        for py_event, c_event in zip(py_events, c_events):
+            for attribute in ['__class__', 'anchor', 'tag', 'implicit',
+                                'value', 'explicit', 'version', 'tags']:
+                py_value = getattr(py_event, attribute, None)
+                c_value = getattr(c_event, attribute, None)
+                assert py_value == c_value, (py_event, c_event, attribute)
+    finally:
+        if verbose:
+            print "PY_EVENTS:"
+            pprint.pprint(py_events)
+            print "C_EVENTS:"
+            pprint.pprint(c_events)
+
+def test_c_parser(data_filename, canonical_filename, verbose=False):
+    _compare_parsers(open(data_filename, 'rb'),
+            open(data_filename, 'rb'), verbose)
+    _compare_parsers(open(data_filename, 'rb').read(),
+            open(data_filename, 'rb').read(), verbose)
+    _compare_parsers(open(canonical_filename, 'rb'),
+            open(canonical_filename, 'rb'), verbose)
+    _compare_parsers(open(canonical_filename, 'rb').read(),
+            open(canonical_filename, 'rb').read(), verbose)
+
+test_c_parser.unittest = ['.data', '.canonical']
+test_c_parser.skip = ['.skip-ext']
+
+def _compare_emitters(data, verbose):
+    events = list(yaml.parse(data, Loader=yaml.PyLoader))
+    c_data = yaml.emit(events, Dumper=yaml.CDumper)
+    if verbose:
+        print c_data
+    py_events = list(yaml.parse(c_data, Loader=yaml.PyLoader))
+    c_events = list(yaml.parse(c_data, Loader=yaml.CLoader))
+    try:
+        assert len(events) == len(py_events), (len(events), len(py_events))
+        assert len(events) == len(c_events), (len(events), len(c_events))
+        for event, py_event, c_event in zip(events, py_events, c_events):
+            for attribute in ['__class__', 'anchor', 'tag', 'implicit',
+                                'value', 'explicit', 'version', 'tags']:
+                value = getattr(event, attribute, None)
+                py_value = getattr(py_event, attribute, None)
+                c_value = getattr(c_event, attribute, None)
+                if attribute == 'tag' and value in [None, u'!'] \
+                        and py_value in [None, u'!'] and c_value in [None, u'!']:
+                    continue
+                if attribute == 'explicit' and (py_value or c_value):
+                    continue
+                assert value == py_value, (event, py_event, attribute)
+                assert value == c_value, (event, c_event, attribute)
+    finally:
+        if verbose:
+            print "EVENTS:"
+            pprint.pprint(events)
+            print "PY_EVENTS:"
+            pprint.pprint(py_events)
+            print "C_EVENTS:"
+            pprint.pprint(c_events)
+
+def test_c_emitter(data_filename, canonical_filename, verbose=False):
+    _compare_emitters(open(data_filename, 'rb').read(), verbose)
+    _compare_emitters(open(canonical_filename, 'rb').read(), verbose)
+
+test_c_emitter.unittest = ['.data', '.canonical']
+test_c_emitter.skip = ['.skip-ext']
+
+def wrap_ext_function(function):
+    def wrapper(*args, **kwds):
+        _set_up()
         try:
-            if file_input:
-                data = file(data_filename, 'r')
-            for token in yaml.scan(data, Loader=yaml.CLoader):
-                ext_tokens.append(token)
-            self.failUnlessEqual(len(tokens), len(ext_tokens))
-            for token, ext_token in zip(tokens, ext_tokens):
-                self.failUnlessEqual(token.__class__, ext_token.__class__)
-                if not isinstance(token, yaml.StreamEndToken):
-                    self.failUnlessEqual((token.start_mark.index, token.start_mark.line, token.start_mark.column),
-                            (ext_token.start_mark.index, ext_token.start_mark.line, ext_token.start_mark.column))
-                    self.failUnlessEqual((token.end_mark.index, token.end_mark.line, token.end_mark.column),
-                            (ext_token.end_mark.index, ext_token.end_mark.line, ext_token.end_mark.column))
-                if hasattr(token, 'value'):
-                    self.failUnlessEqual(token.value, ext_token.value)
-        except:
-            print
-            print "DATA:"
-            print file(data_filename, 'rb').read()
-            print "TOKENS:", tokens
-            print "EXT_TOKENS:", ext_tokens
-            raise
-
-    def _testCParser(self, test_name, data_filename, canonical_filename, Loader=yaml.Loader):
-        data = file(data_filename, 'r').read()
-        events = list(yaml.parse(data, Loader=Loader))
-        ext_events = []
-        try:
-            for event in yaml.parse(data, Loader=yaml.CLoader):
-                ext_events.append(event)
-                #print "EVENT:", event
-            self.failUnlessEqual(len(events), len(ext_events))
-            for event, ext_event in zip(events, ext_events):
-                self.failUnlessEqual(event.__class__, ext_event.__class__)
-                if hasattr(event, 'anchor'):
-                    self.failUnlessEqual(event.anchor, ext_event.anchor)
-                if hasattr(event, 'tag'):
-                    self.failUnlessEqual(event.tag, ext_event.tag)
-                if hasattr(event, 'implicit'):
-                    self.failUnlessEqual(event.implicit, ext_event.implicit)
-                if hasattr(event, 'value'):
-                    self.failUnlessEqual(event.value, ext_event.value)
-                if hasattr(event, 'explicit'):
-                    self.failUnlessEqual(event.explicit, ext_event.explicit)
-                if hasattr(event, 'version'):
-                    self.failUnlessEqual(event.version, ext_event.version)
-                if hasattr(event, 'tags'):
-                    self.failUnlessEqual(event.tags, ext_event.tags)
-        except:
-            print
-            print "DATA:"
-            print file(data_filename, 'rb').read()
-            print "EVENTS:", events
-            print "EXT_EVENTS:", ext_events
-            raise
-
-TestCLoader.add_tests('testCScanner', '.data', '.canonical')
-TestCLoader.add_tests('testCScannerFileInput', '.data', '.canonical')
-TestCLoader.add_tests('testCParser', '.data', '.canonical')
-
-class TestCEmitter(test_appliance.TestAppliance):
-
-    def _testCEmitter(self, test_name, data_filename, canonical_filename, Loader=yaml.Loader):
-        data1 = file(data_filename, 'r').read()
-        events = list(yaml.parse(data1, Loader=Loader))
-        data2 = yaml.emit(events, Dumper=yaml.CDumper)
-        ext_events = []
-        try:
-            for event in yaml.parse(data2):
-                ext_events.append(event)
-            self.failUnlessEqual(len(events), len(ext_events))
-            for event, ext_event in zip(events, ext_events):
-                self.failUnlessEqual(event.__class__, ext_event.__class__)
-                if hasattr(event, 'anchor'):
-                    self.failUnlessEqual(event.anchor, ext_event.anchor)
-                if hasattr(event, 'tag'):
-                    if not (event.tag in ['!', None] and ext_event.tag in ['!', None]):
-                        self.failUnlessEqual(event.tag, ext_event.tag)
-                if hasattr(event, 'implicit'):
-                    self.failUnlessEqual(event.implicit, ext_event.implicit)
-                if hasattr(event, 'value'):
-                    self.failUnlessEqual(event.value, ext_event.value)
-                if hasattr(event, 'explicit') and event.explicit:
-                    self.failUnlessEqual(event.explicit, ext_event.explicit)
-                if hasattr(event, 'version'):
-                    self.failUnlessEqual(event.version, ext_event.version)
-                if hasattr(event, 'tags'):
-                    self.failUnlessEqual(event.tags, ext_event.tags)
-        except:
-            print
-            print "DATA1:"
-            print data1
-            print "DATA2:"
-            print data2
-            print "EVENTS:", events
-            print "EXT_EVENTS:", ext_events
-            raise
-
-TestCEmitter.add_tests('testCEmitter', '.data', '.canonical')
-
-yaml.BaseLoader = yaml.CBaseLoader
-yaml.SafeLoader = yaml.CSafeLoader
-yaml.Loader = yaml.CLoader
-yaml.BaseDumper = yaml.CBaseDumper
-yaml.SafeDumper = yaml.CSafeDumper
-yaml.Dumper = yaml.CDumper
-old_scan = yaml.scan
-def scan(stream, Loader=yaml.CLoader):
-    return old_scan(stream, Loader)
-yaml.scan = scan
-old_parse = yaml.parse
-def parse(stream, Loader=yaml.CLoader):
-    return old_parse(stream, Loader)
-yaml.parse = parse
-old_compose = yaml.compose
-def compose(stream, Loader=yaml.CLoader):
-    return old_compose(stream, Loader)
-yaml.compose = compose
-old_compose_all = yaml.compose_all
-def compose_all(stream, Loader=yaml.CLoader):
-    return old_compose_all(stream, Loader)
-yaml.compose_all = compose_all
-old_load_all = yaml.load_all
-def load_all(stream, Loader=yaml.CLoader):
-    return old_load_all(stream, Loader)
-yaml.load_all = load_all
-old_load = yaml.load
-def load(stream, Loader=yaml.CLoader):
-    return old_load(stream, Loader)
-yaml.load = load
-def safe_load_all(stream):
-    return yaml.load_all(stream, yaml.CSafeLoader)
-yaml.safe_load_all = safe_load_all
-def safe_load(stream):
-    return yaml.load(stream, yaml.CSafeLoader)
-yaml.safe_load = safe_load
-old_emit = yaml.emit
-def emit(events, stream=None, Dumper=yaml.CDumper, **kwds):
-    return old_emit(events, stream, Dumper, **kwds)
-yaml.emit = emit
-old_serialize_all = yaml.serialize_all
-def serialize_all(nodes, stream=None, Dumper=yaml.CDumper, **kwds):
-    return old_serialize_all(nodes, stream, Dumper, **kwds)
-yaml.serialize_all = serialize_all
-old_serialize = yaml.serialize
-def serialize(node, stream, Dumper=yaml.CDumper, **kwds):
-    return old_serialize(node, stream, Dumper, **kwds)
-yaml.serialize = serialize
-old_dump_all = yaml.dump_all
-def dump_all(documents, stream=None, Dumper=yaml.CDumper, **kwds):
-    return old_dump_all(documents, stream, Dumper, **kwds)
-yaml.dump_all = dump_all
-old_dump = yaml.dump
-def dump(data, stream=None, Dumper=yaml.CDumper, **kwds):
-    return old_dump(data, stream, Dumper, **kwds)
-yaml.dump = dump
-def safe_dump_all(documents, stream=None, **kwds):
-    return yaml.dump_all(documents, stream, yaml.CSafeDumper, **kwds)
-yaml.safe_dump_all = safe_dump_all
-def safe_dump(data, stream=None, **kwds):
-    return yaml.dump(data, stream, yaml.CSafeDumper, **kwds)
-yaml.safe_dump = safe_dump
-
-from test_yaml import *
-
-def main(module='__main__'):
-    unittest.main(module)
+            function(*args, **kwds)
+        finally:
+            _tear_down()
+    wrapper.func_name = '%s_ext' % function.func_name
+    wrapper.unittest = function.unittest
+    wrapper.skip = getattr(function, 'skip', [])+['.skip-ext']
+    return wrapper
+
+def wrap_ext(collections):
+    functions = []
+    if not isinstance(collections, list):
+        collections = [collections]
+    for collection in collections:
+        if not isinstance(collection, dict):
+            collection = vars(collection)
+        keys = collection.keys()
+        keys.sort()
+        for key in keys:
+            value = collection[key]
+            if isinstance(value, types.FunctionType) and hasattr(value, 'unittest'):
+                functions.append(wrap_ext_function(value))
+    for function in functions:
+        assert function.func_name not in globals()
+        globals()[function.func_name] = function
+
+import test_tokens, test_structure, test_errors, test_resolver, test_constructor,   \
+        test_emitter, test_representer, test_recursive
+wrap_ext([test_tokens, test_structure, test_errors, test_resolver, test_constructor,
+        test_emitter, test_representer, test_recursive])
 
 if __name__ == '__main__':
-    main()
-
+    import test_appliance
+    test_appliance.run(globals())
+
