source: pyyaml/trunk/tests/lib3/test_constructor.py @ 330

Revision 330, 8.5 KB checked in by xi, 6 years ago (diff)

Share data files between Py2 and Py3 test suites.

Line 
1
2import yaml
3import pprint
4
5import datetime
6import yaml.tokens
7
8def execute(code):
9    global value
10    exec(code)
11    return value
12
13def _make_objects():
14    global MyLoader, MyDumper, MyTestClass1, MyTestClass2, MyTestClass3, YAMLObject1, YAMLObject2,  \
15            AnObject, AnInstance, AState, ACustomState, InitArgs, InitArgsWithState,    \
16            NewArgs, NewArgsWithState, Reduce, ReduceWithState, MyInt, MyList, MyDict,  \
17            FixedOffset, execute
18
19    class MyLoader(yaml.Loader):
20        pass
21    class MyDumper(yaml.Dumper):
22        pass
23
24    class MyTestClass1:
25        def __init__(self, x, y=0, z=0):
26            self.x = x
27            self.y = y
28            self.z = z
29        def __eq__(self, other):
30            if isinstance(other, MyTestClass1):
31                return self.__class__, self.__dict__ == other.__class__, other.__dict__
32            else:
33                return False
34
35    def construct1(constructor, node):
36        mapping = constructor.construct_mapping(node)
37        return MyTestClass1(**mapping)
38    def represent1(representer, native):
39        return representer.represent_mapping("!tag1", native.__dict__)
40
41    yaml.add_constructor("!tag1", construct1, Loader=MyLoader)
42    yaml.add_representer(MyTestClass1, represent1, Dumper=MyDumper)
43
44    class MyTestClass2(MyTestClass1, yaml.YAMLObject):
45        yaml_loader = MyLoader
46        yaml_dumper = MyDumper
47        yaml_tag = "!tag2"
48        def from_yaml(cls, constructor, node):
49            x = constructor.construct_yaml_int(node)
50            return cls(x=x)
51        from_yaml = classmethod(from_yaml)
52        def to_yaml(cls, representer, native):
53            return representer.represent_scalar(cls.yaml_tag, str(native.x))
54        to_yaml = classmethod(to_yaml)
55
56    class MyTestClass3(MyTestClass2):
57        yaml_tag = "!tag3"
58        def from_yaml(cls, constructor, node):
59            mapping = constructor.construct_mapping(node)
60            if '=' in mapping:
61                x = mapping['=']
62                del mapping['=']
63                mapping['x'] = x
64            return cls(**mapping)
65        from_yaml = classmethod(from_yaml)
66        def to_yaml(cls, representer, native):
67            return representer.represent_mapping(cls.yaml_tag, native.__dict__)
68        to_yaml = classmethod(to_yaml)
69
70    class YAMLObject1(yaml.YAMLObject):
71        yaml_loader = MyLoader
72        yaml_dumper = MyDumper
73        yaml_tag = '!foo'
74        def __init__(self, my_parameter=None, my_another_parameter=None):
75            self.my_parameter = my_parameter
76            self.my_another_parameter = my_another_parameter
77        def __eq__(self, other):
78            if isinstance(other, YAMLObject1):
79                return self.__class__, self.__dict__ == other.__class__, other.__dict__
80            else:
81                return False
82
83    class YAMLObject2(yaml.YAMLObject):
84        yaml_loader = MyLoader
85        yaml_dumper = MyDumper
86        yaml_tag = '!bar'
87        def __init__(self, foo=1, bar=2, baz=3):
88            self.foo = foo
89            self.bar = bar
90            self.baz = baz
91        def __getstate__(self):
92            return {1: self.foo, 2: self.bar, 3: self.baz}
93        def __setstate__(self, state):
94            self.foo = state[1]
95            self.bar = state[2]
96            self.baz = state[3]
97        def __eq__(self, other):
98            if isinstance(other, YAMLObject2):
99                return self.__class__, self.__dict__ == other.__class__, other.__dict__
100            else:
101                return False
102
103    class AnObject:
104        def __new__(cls, foo=None, bar=None, baz=None):
105            self = object.__new__(cls)
106            self.foo = foo
107            self.bar = bar
108            self.baz = baz
109            return self
110        def __cmp__(self, other):
111            return cmp((type(self), self.foo, self.bar, self.baz),
112                    (type(other), other.foo, other.bar, other.baz))
113        def __eq__(self, other):
114            return type(self) is type(other) and    \
115                    (self.foo, self.bar, self.baz) == (other.foo, other.bar, other.baz)
116
117    class AnInstance:
118        def __init__(self, foo=None, bar=None, baz=None):
119            self.foo = foo
120            self.bar = bar
121            self.baz = baz
122        def __cmp__(self, other):
123            return cmp((type(self), self.foo, self.bar, self.baz),
124                    (type(other), other.foo, other.bar, other.baz))
125        def __eq__(self, other):
126            return type(self) is type(other) and    \
127                    (self.foo, self.bar, self.baz) == (other.foo, other.bar, other.baz)
128
129    class AState(AnInstance):
130        def __getstate__(self):
131            return {
132                '_foo': self.foo,
133                '_bar': self.bar,
134                '_baz': self.baz,
135            }
136        def __setstate__(self, state):
137            self.foo = state['_foo']
138            self.bar = state['_bar']
139            self.baz = state['_baz']
140
141    class ACustomState(AnInstance):
142        def __getstate__(self):
143            return (self.foo, self.bar, self.baz)
144        def __setstate__(self, state):
145            self.foo, self.bar, self.baz = state
146
147    class NewArgs(AnObject):
148        def __getnewargs__(self):
149            return (self.foo, self.bar, self.baz)
150        def __getstate__(self):
151            return {}
152
153    class NewArgsWithState(AnObject):
154        def __getnewargs__(self):
155            return (self.foo, self.bar)
156        def __getstate__(self):
157            return self.baz
158        def __setstate__(self, state):
159            self.baz = state
160
161    InitArgs = NewArgs
162
163    InitArgsWithState = NewArgsWithState
164
165    class Reduce(AnObject):
166        def __reduce__(self):
167            return self.__class__, (self.foo, self.bar, self.baz)
168
169    class ReduceWithState(AnObject):
170        def __reduce__(self):
171            return self.__class__, (self.foo, self.bar), self.baz
172        def __setstate__(self, state):
173            self.baz = state
174
175    class MyInt(int):
176        def __eq__(self, other):
177            return type(self) is type(other) and int(self) == int(other)
178
179    class MyList(list):
180        def __init__(self, n=1):
181            self.extend([None]*n)
182        def __eq__(self, other):
183            return type(self) is type(other) and list(self) == list(other)
184
185    class MyDict(dict):
186        def __init__(self, n=1):
187            for k in range(n):
188                self[k] = None
189        def __eq__(self, other):
190            return type(self) is type(other) and dict(self) == dict(other)
191
192    class FixedOffset(datetime.tzinfo):
193        def __init__(self, offset, name):
194            self.__offset = datetime.timedelta(minutes=offset)
195            self.__name = name
196        def utcoffset(self, dt):
197            return self.__offset
198        def tzname(self, dt):
199            return self.__name
200        def dst(self, dt):
201            return datetime.timedelta(0)
202
203def _load_code(expression):
204    return eval(expression)
205
206def _serialize_value(data):
207    if isinstance(data, list):
208        return '[%s]' % ', '.join(map(_serialize_value, data))
209    elif isinstance(data, dict):
210        items = []
211        for key, value in data.items():
212            key = _serialize_value(key)
213            value = _serialize_value(value)
214            items.append("%s: %s" % (key, value))
215        items.sort()
216        return '{%s}' % ', '.join(items)
217    elif isinstance(data, datetime.datetime):
218        return repr(data.utctimetuple())
219    elif isinstance(data, float) and data != data:
220        return '?'
221    else:
222        return str(data)
223
224def test_constructor_types(data_filename, code_filename, verbose=False):
225    _make_objects()
226    native1 = None
227    native2 = None
228    try:
229        native1 = list(yaml.load_all(open(data_filename, 'rb'), Loader=MyLoader))
230        if len(native1) == 1:
231            native1 = native1[0]
232        native2 = _load_code(open(code_filename, 'rb').read())
233        try:
234            if native1 == native2:
235                return
236        except TypeError:
237            pass
238        if verbose:
239            print("SERIALIZED NATIVE1:")
240            print(_serialize_value(native1))
241            print("SERIALIZED NATIVE2:")
242            print(_serialize_value(native2))
243        assert _serialize_value(native1) == _serialize_value(native2), (native1, native2)
244    finally:
245        if verbose:
246            print("NATIVE1:")
247            pprint.pprint(native1)
248            print("NATIVE2:")
249            pprint.pprint(native2)
250
251test_constructor_types.unittest = ['.data', '.code']
252
253if __name__ == '__main__':
254    import sys, test_constructor
255    sys.modules['test_constructor'] = sys.modules['__main__']
256    import test_appliance
257    test_appliance.run(globals())
258
Note: See TracBrowser for help on using the repository browser.