source: pyyaml/trunk/tests/test_constructor.py @ 222

Revision 222, 7.7 KB checked in by xi, 8 years ago (diff)

Subclass all base classes from object.

Hold references to the objects being represented (should fix #22).

The value of a mapping node is represented as a list of pairs (key, value)
now.

Sort dictionary items (fix #23).

Recursive structures are now loaded and dumped correctly, including complex
structures like recursive tuples (fix #5). Thanks Peter Murphy for the patches.
To make it possible, representer functions are allowed to be generators.
In this case, the first generated value is an object. Other values produced
by the representer are ignored.

Make Representer not try to guess !!pairs when a list is represented.
You need to construct a !!pairs node explicitly now.

Do not check for duplicate mapping keys as it didn't work correctly anyway.

Line 
1
2import test_appliance
3try:
4    import datetime
5except ImportError:
6    pass
7try:
8    set
9except NameError:
10    from sets import Set as set
11
12from yaml import *
13
14import yaml.tokens
15
16class MyLoader(Loader):
17    pass
18class MyDumper(Dumper):
19    pass
20
21class MyTestClass1:
22
23    def __init__(self, x, y=0, z=0):
24        self.x = x
25        self.y = y
26        self.z = z
27
28    def __eq__(self, other):
29        if isinstance(other, MyTestClass1):
30            return self.__class__, self.__dict__ == other.__class__, other.__dict__
31        else:
32            return False
33
34def construct1(constructor, node):
35    mapping = constructor.construct_mapping(node)
36    return MyTestClass1(**mapping)
37def represent1(representer, native):
38    return representer.represent_mapping("!tag1", native.__dict__)
39
40add_constructor("!tag1", construct1, Loader=MyLoader)
41add_representer(MyTestClass1, represent1, Dumper=MyDumper)
42
43class MyTestClass2(MyTestClass1, YAMLObject):
44
45    yaml_loader = MyLoader
46    yaml_dumper = MyDumper
47    yaml_tag = "!tag2"
48
49    def from_yaml(cls, constructor, node):
50        x = constructor.construct_yaml_int(node)
51        return cls(x=x)
52    from_yaml = classmethod(from_yaml)
53
54    def to_yaml(cls, representer, native):
55        return representer.represent_scalar(cls.yaml_tag, str(native.x))
56    to_yaml = classmethod(to_yaml)
57
58class MyTestClass3(MyTestClass2):
59
60    yaml_tag = "!tag3"
61
62    def from_yaml(cls, constructor, node):
63        mapping = constructor.construct_mapping(node)
64        if '=' in mapping:
65            x = mapping['=']
66            del mapping['=']
67            mapping['x'] = x
68        return cls(**mapping)
69    from_yaml = classmethod(from_yaml)
70
71    def to_yaml(cls, representer, native):
72        return representer.represent_mapping(cls.yaml_tag, native.__dict__)
73    to_yaml = classmethod(to_yaml)
74
75class YAMLObject1(YAMLObject):
76
77    yaml_loader = MyLoader
78    yaml_dumper = MyDumper
79    yaml_tag = '!foo'
80
81    def __init__(self, my_parameter=None, my_another_parameter=None):
82        self.my_parameter = my_parameter
83        self.my_another_parameter = my_another_parameter
84
85    def __eq__(self, other):
86        if isinstance(other, YAMLObject1):
87            return self.__class__, self.__dict__ == other.__class__, other.__dict__
88        else:
89            return False
90
91class YAMLObject2(YAMLObject):
92
93    yaml_loader = MyLoader
94    yaml_dumper = MyDumper
95    yaml_tag = '!bar'
96
97    def __init__(self, foo=1, bar=2, baz=3):
98        self.foo = foo
99        self.bar = bar
100        self.baz = baz
101
102    def __getstate__(self):
103        return {1: self.foo, 2: self.bar, 3: self.baz}
104
105    def __setstate__(self, state):
106        self.foo = state[1]
107        self.bar = state[2]
108        self.baz = state[3]
109
110    def __eq__(self, other):
111        if isinstance(other, YAMLObject2):
112            return self.__class__, self.__dict__ == other.__class__, other.__dict__
113        else:
114            return False
115
116class AnObject(object):
117
118    def __new__(cls, foo=None, bar=None, baz=None):
119        self = object.__new__(cls)
120        self.foo = foo
121        self.bar = bar
122        self.baz = baz
123        return self
124
125    def __cmp__(self, other):
126        return cmp((type(self), self.foo, self.bar, self.baz),
127                (type(other), other.foo, other.bar, other.baz))
128
129    def __eq__(self, other):
130        return type(self) is type(other) and    \
131                (self.foo, self.bar, self.baz) == (other.foo, other.bar, other.baz)
132
133class AnInstance:
134
135    def __init__(self, foo=None, bar=None, baz=None):
136        self.foo = foo
137        self.bar = bar
138        self.baz = baz
139
140    def __cmp__(self, other):
141        return cmp((type(self), self.foo, self.bar, self.baz),
142                (type(other), other.foo, other.bar, other.baz))
143
144    def __eq__(self, other):
145        return type(self) is type(other) and    \
146                (self.foo, self.bar, self.baz) == (other.foo, other.bar, other.baz)
147
148class AState(AnInstance):
149
150    def __getstate__(self):
151        return {
152            '_foo': self.foo,
153            '_bar': self.bar,
154            '_baz': self.baz,
155        }
156
157    def __setstate__(self, state):
158        self.foo = state['_foo']
159        self.bar = state['_bar']
160        self.baz = state['_baz']
161
162class ACustomState(AnInstance):
163
164    def __getstate__(self):
165        return (self.foo, self.bar, self.baz)
166
167    def __setstate__(self, state):
168        self.foo, self.bar, self.baz = state
169
170class InitArgs(AnInstance):
171
172    def __getinitargs__(self):
173        return (self.foo, self.bar, self.baz)
174
175    def __getstate__(self):
176        return {}
177
178class InitArgsWithState(AnInstance):
179
180    def __getinitargs__(self):
181        return (self.foo, self.bar)
182
183    def __getstate__(self):
184        return self.baz
185
186    def __setstate__(self, state):
187        self.baz = state
188
189class NewArgs(AnObject):
190
191    def __getnewargs__(self):
192        return (self.foo, self.bar, self.baz)
193
194    def __getstate__(self):
195        return {}
196
197class NewArgsWithState(AnObject):
198
199    def __getnewargs__(self):
200        return (self.foo, self.bar)
201
202    def __getstate__(self):
203        return self.baz
204
205    def __setstate__(self, state):
206        self.baz = state
207
208class Reduce(AnObject):
209
210    def __reduce__(self):
211        return self.__class__, (self.foo, self.bar, self.baz)
212
213class ReduceWithState(AnObject):
214
215    def __reduce__(self):
216        return self.__class__, (self.foo, self.bar), self.baz
217
218    def __setstate__(self, state):
219        self.baz = state
220
221class MyInt(int):
222
223    def __eq__(self, other):
224        return type(self) is type(other) and int(self) == int(other)
225
226class MyList(list):
227
228    def __init__(self, n=1):
229        self.extend([None]*n)
230
231    def __eq__(self, other):
232        return type(self) is type(other) and list(self) == list(other)
233
234class MyDict(dict):
235
236    def __init__(self, n=1):
237        for k in range(n):
238            self[k] = None
239
240    def __eq__(self, other):
241        return type(self) is type(other) and dict(self) == dict(other)
242
243def execute(code):
244    exec code
245    return value
246
247class TestConstructorTypes(test_appliance.TestAppliance):
248
249    def _testTypes(self, test_name, data_filename, code_filename):
250        data1 = None
251        data2 = None
252        try:
253            data1 = list(load_all(file(data_filename, 'rb'), Loader=MyLoader))
254            if len(data1) == 1:
255                data1 = data1[0]
256            data2 = eval(file(code_filename, 'rb').read())
257            self.failUnlessEqual(type(data1), type(data2))
258            try:
259                self.failUnlessEqual(data1, data2)
260            except AssertionError:
261                if isinstance(data1, dict):
262                    data1 = [(repr(key), value) for key, value in data1.items()]
263                    data1.sort()
264                    data1 = repr(data1)
265                    data2 = [(repr(key), value) for key, value in data2.items()]
266                    data2.sort()
267                    data2 = repr(data2)
268                    if data1 != data2:
269                        raise
270                elif isinstance(data1, list):
271                    self.failUnlessEqual(type(data1), type(data2))
272                    self.failUnlessEqual(len(data1), len(data2))
273                    for item1, item2 in zip(data1, data2):
274                        if (item1 != item1 or (item1 == 0.0 and item1 == 1.0)) and  \
275                                (item2 != item2 or (item2 == 0.0 and item2 == 1.0)):
276                            continue
277                        self.failUnlessEqual(item1, item2)
278                else:
279                    raise
280        except:
281            print
282            print "DATA:"
283            print file(data_filename, 'rb').read()
284            print "CODE:"
285            print file(code_filename, 'rb').read()
286            print "NATIVES1:", data1
287            print "NATIVES2:", data2
288            raise
289
290TestConstructorTypes.add_tests('testTypes', '.data', '.code')
291
Note: See TracBrowser for help on using the repository browser.