source: pyyaml/trunk/tests/test_constructor.py @ 147

Revision 147, 7.0 KB checked in by xi, 8 years ago (diff)

Add support for pickling/unpickling python objects.

Line 
1
2import test_appliance
3try:
4    import datetime
5except ImportError:
6    pass
7try:
8    set
9except NameError:
10    from sets import Set as set
11
12from yaml import *
13
14import yaml.tokens
15
16class MyLoader(Loader):
17    pass
18class MyDumper(Dumper):
19    pass
20
21class MyTestClass1:
22
23    def __init__(self, x, y=0, z=0):
24        self.x = x
25        self.y = y
26        self.z = z
27
28    def __eq__(self, other):
29        if isinstance(other, MyTestClass1):
30            return self.__class__, self.__dict__ == other.__class__, other.__dict__
31        else:
32            return False
33
34def construct1(constructor, node):
35    mapping = constructor.construct_mapping(node)
36    return MyTestClass1(**mapping)
37def represent1(representer, native):
38    return representer.represent_mapping("!tag1", native.__dict__)
39
40add_constructor("!tag1", construct1, Loader=MyLoader)
41add_representer(MyTestClass1, represent1, Dumper=MyDumper)
42
43class MyTestClass2(MyTestClass1, YAMLObject):
44
45    yaml_loader = MyLoader
46    yaml_dumper = MyDumper
47    yaml_tag = "!tag2"
48
49    def from_yaml(cls, constructor, node):
50        x = constructor.construct_yaml_int(node)
51        return cls(x=x)
52    from_yaml = classmethod(from_yaml)
53
54    def to_yaml(cls, representer, native):
55        return representer.represent_scalar(cls.yaml_tag, str(native.x))
56    to_yaml = classmethod(to_yaml)
57
58class MyTestClass3(MyTestClass2):
59
60    yaml_tag = "!tag3"
61
62    def from_yaml(cls, constructor, node):
63        mapping = constructor.construct_mapping(node)
64        if '=' in mapping:
65            x = mapping['=']
66            del mapping['=']
67            mapping['x'] = x
68        return cls(**mapping)
69    from_yaml = classmethod(from_yaml)
70
71    def to_yaml(cls, representer, native):
72        return representer.represent_mapping(cls.yaml_tag, native.__dict__)
73    to_yaml = classmethod(to_yaml)
74
75class YAMLObject1(YAMLObject):
76
77    yaml_loader = MyLoader
78    yaml_dumper = MyDumper
79    yaml_tag = '!foo'
80
81    def __init__(self, my_parameter=None, my_another_parameter=None):
82        self.my_parameter = my_parameter
83        self.my_another_parameter = my_another_parameter
84
85    def __eq__(self, other):
86        if isinstance(other, YAMLObject1):
87            return self.__class__, self.__dict__ == other.__class__, other.__dict__
88        else:
89            return False
90
91class YAMLObject2(YAMLObject):
92
93    yaml_loader = MyLoader
94    yaml_dumper = MyDumper
95    yaml_tag = '!bar'
96
97    def __init__(self, foo=1, bar=2, baz=3):
98        self.foo = foo
99        self.bar = bar
100        self.baz = baz
101
102    def __getstate__(self):
103        return {1: self.foo, 2: self.bar, 3: self.baz}
104
105    def __setstate__(self, state):
106        self.foo = state[1]
107        self.bar = state[2]
108        self.baz = state[3]
109
110    def __eq__(self, other):
111        if isinstance(other, YAMLObject2):
112            return self.__class__, self.__dict__ == other.__class__, other.__dict__
113        else:
114            return False
115
116class AnObject(object):
117
118    def __new__(cls, foo=None, bar=None, baz=None):
119        self = object.__new__(cls)
120        self.foo = foo
121        self.bar = bar
122        self.baz = baz
123        return self
124
125    def __cmp__(self, other):
126        return cmp((type(self), self.foo, self.bar, self.baz),
127                (type(other), other.foo, other.bar, other.baz))
128
129    def __eq__(self, other):
130        return type(self) is type(other) and    \
131                (self.foo, self.bar, self.baz) == (other.foo, other.bar, other.baz)
132
133class AnInstance:
134
135    def __init__(self, foo=None, bar=None, baz=None):
136        self.foo = foo
137        self.bar = bar
138        self.baz = baz
139
140    def __cmp__(self, other):
141        return cmp((type(self), self.foo, self.bar, self.baz),
142                (type(other), other.foo, other.bar, other.baz))
143
144    def __eq__(self, other):
145        return type(self) is type(other) and    \
146                (self.foo, self.bar, self.baz) == (other.foo, other.bar, other.baz)
147
148class AState(AnInstance):
149
150    def __getstate__(self):
151        return {
152            '_foo': self.foo,
153            '_bar': self.bar,
154            '_baz': self.baz,
155        }
156
157    def __setstate__(self, state):
158        self.foo = state['_foo']
159        self.bar = state['_bar']
160        self.baz = state['_baz']
161
162class ACustomState(AnInstance):
163
164    def __getstate__(self):
165        return (self.foo, self.bar, self.baz)
166
167    def __setstate__(self, state):
168        self.foo, self.bar, self.baz = state
169
170class InitArgs(AnInstance):
171
172    def __getinitargs__(self):
173        return (self.foo, self.bar, self.baz)
174
175    def __getstate__(self):
176        return {}
177
178class InitArgsWithState(AnInstance):
179
180    def __getinitargs__(self):
181        return (self.foo, self.bar)
182
183    def __getstate__(self):
184        return self.baz
185
186    def __setstate__(self, state):
187        self.baz = state
188
189class NewArgs(AnObject):
190
191    def __getnewargs__(self):
192        return (self.foo, self.bar, self.baz)
193
194    def __getstate__(self):
195        return {}
196
197class NewArgsWithState(AnObject):
198
199    def __getnewargs__(self):
200        return (self.foo, self.bar)
201
202    def __getstate__(self):
203        return self.baz
204
205    def __setstate__(self, state):
206        self.baz = state
207
208class Reduce(AnObject):
209
210    def __reduce__(self):
211        return self.__class__, (self.foo, self.bar, self.baz)
212
213class ReduceWithState(AnObject):
214
215    def __reduce__(self):
216        return self.__class__, (self.foo, self.bar), self.baz
217
218    def __setstate__(self, state):
219        self.baz = state
220
221class MyInt(int):
222
223    def __eq__(self, other):
224        return type(self) is type(other) and int(self) == int(other)
225
226class MyList(list):
227
228    def __init__(self, n=1):
229        self.extend([None]*n)
230
231    def __eq__(self, other):
232        return type(self) is type(other) and list(self) == list(other)
233
234class MyDict(dict):
235
236    def __init__(self, n=1):
237        for k in range(n):
238            self[k] = None
239
240    def __eq__(self, other):
241        return type(self) is type(other) and dict(self) == dict(other)
242
243class TestConstructorTypes(test_appliance.TestAppliance):
244
245    def _testTypes(self, test_name, data_filename, code_filename):
246        data1 = None
247        data2 = None
248        try:
249            data1 = list(load_all(file(data_filename, 'rb'), Loader=MyLoader))
250            if len(data1) == 1:
251                data1 = data1[0]
252            data2 = eval(file(code_filename, 'rb').read())
253            self.failUnlessEqual(type(data1), type(data2))
254            try:
255                self.failUnlessEqual(data1, data2)
256            except AssertionError:
257                if isinstance(data1, dict):
258                    data1 = data1.items()
259                    data1.sort()
260                    data1 = repr(data1)
261                    data2 = data2.items()
262                    data2.sort()
263                    data2 = repr(data2)
264                if data1 != data2:
265                    raise
266        except:
267            print
268            print "DATA:"
269            print file(data_filename, 'rb').read()
270            print "CODE:"
271            print file(code_filename, 'rb').read()
272            print "NATIVES1:", data1
273            print "NATIVES2:", data2
274            raise
275
276TestConstructorTypes.add_tests('testTypes', '.data', '.code')
277
Note: See TracBrowser for help on using the repository browser.