source: pyyaml/trunk/tests/test_constructor.py @ 144

Revision 144, 4.1 KB checked in by xi, 8 years ago (diff)

Add more unit tests.

Line 
1
2import test_appliance
3try:
4    import datetime
5except ImportError:
6    pass
7try:
8    set
9except NameError:
10    from sets import Set as set
11
12from yaml import *
13
14import xml.parsers
15
16class MyLoader(Loader):
17    pass
18class MyDumper(Dumper):
19    pass
20
21class MyTestClass1:
22
23    def __init__(self, x, y=0, z=0):
24        self.x = x
25        self.y = y
26        self.z = z
27
28    def __eq__(self, other):
29        if isinstance(other, MyTestClass1):
30            return self.__class__, self.__dict__ == other.__class__, other.__dict__
31        else:
32            return False
33
34def construct1(constructor, node):
35    mapping = constructor.construct_mapping(node)
36    return MyTestClass1(**mapping)
37def represent1(representer, native):
38    return representer.represent_mapping("!tag1", native.__dict__)
39
40add_constructor("!tag1", construct1, Loader=MyLoader)
41add_representer(MyTestClass1, represent1, Dumper=MyDumper)
42
43class MyTestClass2(MyTestClass1, YAMLObject):
44
45    yaml_loader = MyLoader
46    yaml_dumper = MyDumper
47    yaml_tag = "!tag2"
48
49    def from_yaml(cls, constructor, node):
50        x = constructor.construct_yaml_int(node)
51        return cls(x=x)
52    from_yaml = classmethod(from_yaml)
53
54    def to_yaml(cls, representer, native):
55        return representer.represent_scalar(cls.yaml_tag, str(native.x))
56    to_yaml = classmethod(to_yaml)
57
58class MyTestClass3(MyTestClass2):
59
60    yaml_tag = "!tag3"
61
62    def from_yaml(cls, constructor, node):
63        mapping = constructor.construct_mapping(node)
64        if '=' in mapping:
65            x = mapping['=']
66            del mapping['=']
67            mapping['x'] = x
68        return cls(**mapping)
69    from_yaml = classmethod(from_yaml)
70
71    def to_yaml(cls, representer, native):
72        return representer.represent_mapping(cls.yaml_tag, native.__dict__)
73    to_yaml = classmethod(to_yaml)
74
75class YAMLObject1(YAMLObject):
76
77    yaml_loader = MyLoader
78    yaml_dumper = MyDumper
79    yaml_tag = '!foo'
80
81    def __init__(self, my_parameter=None, my_another_parameter=None):
82        self.my_parameter = my_parameter
83        self.my_another_parameter = my_another_parameter
84
85    def __eq__(self, other):
86        if isinstance(other, YAMLObject1):
87            return self.__class__, self.__dict__ == other.__class__, other.__dict__
88        else:
89            return False
90
91class YAMLObject2(YAMLObject):
92
93    yaml_loader = MyLoader
94    yaml_dumper = MyDumper
95    yaml_tag = '!bar'
96
97    def __init__(self, foo=1, bar=2, baz=3):
98        self.foo = foo
99        self.bar = bar
100        self.baz = baz
101
102    def __getstate__(self):
103        return {1: self.foo, 2: self.bar, 3: self.baz}
104
105    def __setstate__(self, state):
106        self.foo = state[1]
107        self.bar = state[2]
108        self.baz = state[3]
109
110    def __eq__(self, other):
111        if isinstance(other, YAMLObject2):
112            return self.__class__, self.__dict__ == other.__class__, other.__dict__
113        else:
114            return False
115
116class TestConstructorTypes(test_appliance.TestAppliance):
117
118    def _testTypes(self, test_name, data_filename, code_filename):
119        data1 = None
120        data2 = None
121        try:
122            data1 = list(load_all(file(data_filename, 'rb'), Loader=MyLoader))
123            if len(data1) == 1:
124                data1 = data1[0]
125            data2 = eval(file(code_filename, 'rb').read())
126            self.failUnlessEqual(type(data1), type(data2))
127            try:
128                self.failUnlessEqual(data1, data2)
129            except AssertionError:
130                if isinstance(data1, dict):
131                    data1 = data1.items()
132                    data1.sort()
133                    data1 = repr(data1)
134                    data2 = data2.items()
135                    data2.sort()
136                    data2 = repr(data2)
137                if data1 != data2:
138                    raise
139        except:
140            print
141            print "DATA:"
142            print file(data_filename, 'rb').read()
143            print "CODE:"
144            print file(code_filename, 'rb').read()
145            print "NATIVES1:", data1
146            print "NATIVES2:", data2
147            raise
148
149TestConstructorTypes.add_tests('testTypes', '.data', '.code')
150
Note: See TracBrowser for help on using the repository browser.