source: pyyaml/trunk/tests/test_constructor.py @ 225

Revision 225, 8.2 KB checked in by xi, 8 years ago (diff)

Fix timestamp constructing and representing (close #25).

RevLine 
[58]1
2import test_appliance
[225]3
4import datetime
[58]5try:
[59]6    set
7except NameError:
8    from sets import Set as set
[58]9
10from yaml import *
11
[146]12import yaml.tokens
[144]13
[136]14class MyLoader(Loader):
[58]15    pass
[144]16class MyDumper(Dumper):
17    pass
[58]18
19class MyTestClass1:
20
21    def __init__(self, x, y=0, z=0):
22        self.x = x
23        self.y = y
24        self.z = z
25
26    def __eq__(self, other):
[144]27        if isinstance(other, MyTestClass1):
28            return self.__class__, self.__dict__ == other.__class__, other.__dict__
29        else:
30            return False
[58]31
32def construct1(constructor, node):
33    mapping = constructor.construct_mapping(node)
34    return MyTestClass1(**mapping)
[144]35def represent1(representer, native):
36    return representer.represent_mapping("!tag1", native.__dict__)
[58]37
[144]38add_constructor("!tag1", construct1, Loader=MyLoader)
39add_representer(MyTestClass1, represent1, Dumper=MyDumper)
[58]40
41class MyTestClass2(MyTestClass1, YAMLObject):
42
[136]43    yaml_loader = MyLoader
[144]44    yaml_dumper = MyDumper
[58]45    yaml_tag = "!tag2"
46
47    def from_yaml(cls, constructor, node):
48        x = constructor.construct_yaml_int(node)
49        return cls(x=x)
50    from_yaml = classmethod(from_yaml)
51
[144]52    def to_yaml(cls, representer, native):
53        return representer.represent_scalar(cls.yaml_tag, str(native.x))
54    to_yaml = classmethod(to_yaml)
55
[58]56class MyTestClass3(MyTestClass2):
57
58    yaml_tag = "!tag3"
59
60    def from_yaml(cls, constructor, node):
61        mapping = constructor.construct_mapping(node)
62        if '=' in mapping:
63            x = mapping['=']
64            del mapping['=']
65            mapping['x'] = x
66        return cls(**mapping)
67    from_yaml = classmethod(from_yaml)
68
[144]69    def to_yaml(cls, representer, native):
70        return representer.represent_mapping(cls.yaml_tag, native.__dict__)
71    to_yaml = classmethod(to_yaml)
72
[136]73class YAMLObject1(YAMLObject):
[144]74
[136]75    yaml_loader = MyLoader
[144]76    yaml_dumper = MyDumper
[136]77    yaml_tag = '!foo'
78
79    def __init__(self, my_parameter=None, my_another_parameter=None):
80        self.my_parameter = my_parameter
81        self.my_another_parameter = my_another_parameter
82
83    def __eq__(self, other):
84        if isinstance(other, YAMLObject1):
85            return self.__class__, self.__dict__ == other.__class__, other.__dict__
86        else:
87            return False
88
[144]89class YAMLObject2(YAMLObject):
[58]90
[144]91    yaml_loader = MyLoader
92    yaml_dumper = MyDumper
93    yaml_tag = '!bar'
94
95    def __init__(self, foo=1, bar=2, baz=3):
96        self.foo = foo
97        self.bar = bar
98        self.baz = baz
99
100    def __getstate__(self):
101        return {1: self.foo, 2: self.bar, 3: self.baz}
102
103    def __setstate__(self, state):
104        self.foo = state[1]
105        self.bar = state[2]
106        self.baz = state[3]
107
108    def __eq__(self, other):
109        if isinstance(other, YAMLObject2):
110            return self.__class__, self.__dict__ == other.__class__, other.__dict__
111        else:
112            return False
113
[147]114class AnObject(object):
115
116    def __new__(cls, foo=None, bar=None, baz=None):
117        self = object.__new__(cls)
118        self.foo = foo
119        self.bar = bar
120        self.baz = baz
121        return self
122
123    def __cmp__(self, other):
124        return cmp((type(self), self.foo, self.bar, self.baz),
125                (type(other), other.foo, other.bar, other.baz))
126
127    def __eq__(self, other):
128        return type(self) is type(other) and    \
129                (self.foo, self.bar, self.baz) == (other.foo, other.bar, other.baz)
130
131class AnInstance:
132
133    def __init__(self, foo=None, bar=None, baz=None):
134        self.foo = foo
135        self.bar = bar
136        self.baz = baz
137
138    def __cmp__(self, other):
139        return cmp((type(self), self.foo, self.bar, self.baz),
140                (type(other), other.foo, other.bar, other.baz))
141
142    def __eq__(self, other):
143        return type(self) is type(other) and    \
144                (self.foo, self.bar, self.baz) == (other.foo, other.bar, other.baz)
145
146class AState(AnInstance):
147
148    def __getstate__(self):
149        return {
150            '_foo': self.foo,
151            '_bar': self.bar,
152            '_baz': self.baz,
153        }
154
155    def __setstate__(self, state):
156        self.foo = state['_foo']
157        self.bar = state['_bar']
158        self.baz = state['_baz']
159
160class ACustomState(AnInstance):
161
162    def __getstate__(self):
163        return (self.foo, self.bar, self.baz)
164
165    def __setstate__(self, state):
166        self.foo, self.bar, self.baz = state
167
168class InitArgs(AnInstance):
169
170    def __getinitargs__(self):
171        return (self.foo, self.bar, self.baz)
172
173    def __getstate__(self):
174        return {}
175
176class InitArgsWithState(AnInstance):
177
178    def __getinitargs__(self):
179        return (self.foo, self.bar)
180
181    def __getstate__(self):
182        return self.baz
183
184    def __setstate__(self, state):
185        self.baz = state
186
187class NewArgs(AnObject):
188
189    def __getnewargs__(self):
190        return (self.foo, self.bar, self.baz)
191
192    def __getstate__(self):
193        return {}
194
195class NewArgsWithState(AnObject):
196
197    def __getnewargs__(self):
198        return (self.foo, self.bar)
199
200    def __getstate__(self):
201        return self.baz
202
203    def __setstate__(self, state):
204        self.baz = state
205
206class Reduce(AnObject):
207
208    def __reduce__(self):
209        return self.__class__, (self.foo, self.bar, self.baz)
210
211class ReduceWithState(AnObject):
212
213    def __reduce__(self):
214        return self.__class__, (self.foo, self.bar), self.baz
215
216    def __setstate__(self, state):
217        self.baz = state
218
219class MyInt(int):
220
221    def __eq__(self, other):
222        return type(self) is type(other) and int(self) == int(other)
223
224class MyList(list):
225
226    def __init__(self, n=1):
227        self.extend([None]*n)
228
229    def __eq__(self, other):
230        return type(self) is type(other) and list(self) == list(other)
231
232class MyDict(dict):
233
234    def __init__(self, n=1):
235        for k in range(n):
236            self[k] = None
237
238    def __eq__(self, other):
239        return type(self) is type(other) and dict(self) == dict(other)
240
[225]241class FixedOffset(datetime.tzinfo):
242
243    def __init__(self, offset, name):
244        self.__offset = datetime.timedelta(minutes=offset)
245        self.__name = name
246
247    def utcoffset(self, dt):
248        return self.__offset
249
250    def tzname(self, dt):
251        return self.__name
252
253    def dst(self, dt):
254        return datetime.timedelta(0)
255
256
[222]257def execute(code):
258    exec code
259    return value
260
[144]261class TestConstructorTypes(test_appliance.TestAppliance):
262
[58]263    def _testTypes(self, test_name, data_filename, code_filename):
[136]264        data1 = None
265        data2 = None
[58]266        try:
[136]267            data1 = list(load_all(file(data_filename, 'rb'), Loader=MyLoader))
268            if len(data1) == 1:
269                data1 = data1[0]
270            data2 = eval(file(code_filename, 'rb').read())
[144]271            self.failUnlessEqual(type(data1), type(data2))
[58]272            try:
[136]273                self.failUnlessEqual(data1, data2)
[225]274            except (AssertionError, TypeError):
[136]275                if isinstance(data1, dict):
[150]276                    data1 = [(repr(key), value) for key, value in data1.items()]
[136]277                    data1.sort()
278                    data1 = repr(data1)
[150]279                    data2 = [(repr(key), value) for key, value in data2.items()]
[136]280                    data2.sort()
281                    data2 = repr(data2)
[173]282                    if data1 != data2:
283                        raise
284                elif isinstance(data1, list):
285                    self.failUnlessEqual(type(data1), type(data2))
286                    self.failUnlessEqual(len(data1), len(data2))
287                    for item1, item2 in zip(data1, data2):
288                        if (item1 != item1 or (item1 == 0.0 and item1 == 1.0)) and  \
289                                (item2 != item2 or (item2 == 0.0 and item2 == 1.0)):
290                            continue
[225]291                        if isinstance(item1, datetime.datetime):
292                            item1 = item1.utctimetuple()
293                        if isinstance(item2, datetime.datetime):
294                            item2 = item2.utctimetuple()
[173]295                        self.failUnlessEqual(item1, item2)
296                else:
[58]297                    raise
298        except:
299            print
300            print "DATA:"
301            print file(data_filename, 'rb').read()
302            print "CODE:"
303            print file(code_filename, 'rb').read()
[136]304            print "NATIVES1:", data1
305            print "NATIVES2:", data2
[58]306            raise
307
[144]308TestConstructorTypes.add_tests('testTypes', '.data', '.code')
[58]309
Note: See TracBrowser for help on using the repository browser.