source: pyyaml/trunk/tests/test_constructor.py @ 325

Revision 325, 9.0 KB checked in by xi, 5 years ago (diff)

Minor 2.3 and win32 compatibility fixes; clarify the 'feature not found' message in setup.py.

RevLine 
[58]1
[322]2import yaml
3import pprint
[225]4
5import datetime
[58]6try:
[59]7    set
8except NameError:
9    from sets import Set as set
[146]10import yaml.tokens
[144]11
[322]12def execute(code):
13    exec code
14    return value
[58]15
[322]16def _make_objects():
17    global MyLoader, MyDumper, MyTestClass1, MyTestClass2, MyTestClass3, YAMLObject1, YAMLObject2,  \
18            AnObject, AnInstance, AState, ACustomState, InitArgs, InitArgsWithState,    \
19            NewArgs, NewArgsWithState, Reduce, ReduceWithState, MyInt, MyList, MyDict,  \
20            FixedOffset, execute
[58]21
[322]22    class MyLoader(yaml.Loader):
23        pass
24    class MyDumper(yaml.Dumper):
25        pass
[58]26
[322]27    class MyTestClass1:
28        def __init__(self, x, y=0, z=0):
29            self.x = x
30            self.y = y
31            self.z = z
32        def __eq__(self, other):
33            if isinstance(other, MyTestClass1):
34                return self.__class__, self.__dict__ == other.__class__, other.__dict__
35            else:
36                return False
[58]37
[322]38    def construct1(constructor, node):
[58]39        mapping = constructor.construct_mapping(node)
[322]40        return MyTestClass1(**mapping)
41    def represent1(representer, native):
42        return representer.represent_mapping("!tag1", native.__dict__)
[58]43
[322]44    yaml.add_constructor("!tag1", construct1, Loader=MyLoader)
45    yaml.add_representer(MyTestClass1, represent1, Dumper=MyDumper)
[144]46
[322]47    class MyTestClass2(MyTestClass1, yaml.YAMLObject):
48        yaml_loader = MyLoader
49        yaml_dumper = MyDumper
50        yaml_tag = "!tag2"
51        def from_yaml(cls, constructor, node):
52            x = constructor.construct_yaml_int(node)
53            return cls(x=x)
54        from_yaml = classmethod(from_yaml)
55        def to_yaml(cls, representer, native):
56            return representer.represent_scalar(cls.yaml_tag, str(native.x))
57        to_yaml = classmethod(to_yaml)
[144]58
[322]59    class MyTestClass3(MyTestClass2):
60        yaml_tag = "!tag3"
61        def from_yaml(cls, constructor, node):
62            mapping = constructor.construct_mapping(node)
63            if '=' in mapping:
64                x = mapping['=']
65                del mapping['=']
66                mapping['x'] = x
67            return cls(**mapping)
68        from_yaml = classmethod(from_yaml)
69        def to_yaml(cls, representer, native):
70            return representer.represent_mapping(cls.yaml_tag, native.__dict__)
71        to_yaml = classmethod(to_yaml)
[136]72
[322]73    class YAMLObject1(yaml.YAMLObject):
74        yaml_loader = MyLoader
75        yaml_dumper = MyDumper
76        yaml_tag = '!foo'
77        def __init__(self, my_parameter=None, my_another_parameter=None):
78            self.my_parameter = my_parameter
79            self.my_another_parameter = my_another_parameter
80        def __eq__(self, other):
81            if isinstance(other, YAMLObject1):
82                return self.__class__, self.__dict__ == other.__class__, other.__dict__
83            else:
84                return False
[136]85
[322]86    class YAMLObject2(yaml.YAMLObject):
87        yaml_loader = MyLoader
88        yaml_dumper = MyDumper
89        yaml_tag = '!bar'
90        def __init__(self, foo=1, bar=2, baz=3):
91            self.foo = foo
92            self.bar = bar
93            self.baz = baz
94        def __getstate__(self):
95            return {1: self.foo, 2: self.bar, 3: self.baz}
96        def __setstate__(self, state):
97            self.foo = state[1]
98            self.bar = state[2]
99            self.baz = state[3]
100        def __eq__(self, other):
101            if isinstance(other, YAMLObject2):
102                return self.__class__, self.__dict__ == other.__class__, other.__dict__
103            else:
104                return False
[136]105
[322]106    class AnObject(object):
107        def __new__(cls, foo=None, bar=None, baz=None):
108            self = object.__new__(cls)
109            self.foo = foo
110            self.bar = bar
111            self.baz = baz
112            return self
113        def __cmp__(self, other):
114            return cmp((type(self), self.foo, self.bar, self.baz),
115                    (type(other), other.foo, other.bar, other.baz))
116        def __eq__(self, other):
117            return type(self) is type(other) and    \
118                    (self.foo, self.bar, self.baz) == (other.foo, other.bar, other.baz)
[58]119
[322]120    class AnInstance:
121        def __init__(self, foo=None, bar=None, baz=None):
122            self.foo = foo
123            self.bar = bar
124            self.baz = baz
125        def __cmp__(self, other):
126            return cmp((type(self), self.foo, self.bar, self.baz),
127                    (type(other), other.foo, other.bar, other.baz))
128        def __eq__(self, other):
129            return type(self) is type(other) and    \
130                    (self.foo, self.bar, self.baz) == (other.foo, other.bar, other.baz)
[144]131
[322]132    class AState(AnInstance):
133        def __getstate__(self):
134            return {
135                '_foo': self.foo,
136                '_bar': self.bar,
137                '_baz': self.baz,
138            }
139        def __setstate__(self, state):
140            self.foo = state['_foo']
141            self.bar = state['_bar']
142            self.baz = state['_baz']
[144]143
[322]144    class ACustomState(AnInstance):
145        def __getstate__(self):
146            return (self.foo, self.bar, self.baz)
147        def __setstate__(self, state):
148            self.foo, self.bar, self.baz = state
[144]149
[322]150    class InitArgs(AnInstance):
151        def __getinitargs__(self):
152            return (self.foo, self.bar, self.baz)
153        def __getstate__(self):
154            return {}
[144]155
[322]156    class InitArgsWithState(AnInstance):
157        def __getinitargs__(self):
158            return (self.foo, self.bar)
159        def __getstate__(self):
160            return self.baz
161        def __setstate__(self, state):
162            self.baz = state
[144]163
[322]164    class NewArgs(AnObject):
165        def __getnewargs__(self):
166            return (self.foo, self.bar, self.baz)
167        def __getstate__(self):
168            return {}
[147]169
[322]170    class NewArgsWithState(AnObject):
171        def __getnewargs__(self):
172            return (self.foo, self.bar)
173        def __getstate__(self):
174            return self.baz
175        def __setstate__(self, state):
176            self.baz = state
[147]177
[322]178    class Reduce(AnObject):
179        def __reduce__(self):
180            return self.__class__, (self.foo, self.bar, self.baz)
[147]181
[322]182    class ReduceWithState(AnObject):
183        def __reduce__(self):
184            return self.__class__, (self.foo, self.bar), self.baz
185        def __setstate__(self, state):
186            self.baz = state
[147]187
[322]188    class MyInt(int):
189        def __eq__(self, other):
190            return type(self) is type(other) and int(self) == int(other)
[147]191
[322]192    class MyList(list):
193        def __init__(self, n=1):
194            self.extend([None]*n)
195        def __eq__(self, other):
196            return type(self) is type(other) and list(self) == list(other)
[147]197
[322]198    class MyDict(dict):
199        def __init__(self, n=1):
200            for k in range(n):
201                self[k] = None
202        def __eq__(self, other):
203            return type(self) is type(other) and dict(self) == dict(other)
[147]204
[322]205    class FixedOffset(datetime.tzinfo):
206        def __init__(self, offset, name):
207            self.__offset = datetime.timedelta(minutes=offset)
208            self.__name = name
209        def utcoffset(self, dt):
210            return self.__offset
211        def tzname(self, dt):
212            return self.__name
213        def dst(self, dt):
214            return datetime.timedelta(0)
[147]215
[322]216def _load_code(expression):
217    return eval(expression)
[147]218
[322]219def _serialize_value(data):
220    if isinstance(data, list):
221        return '[%s]' % ', '.join(map(_serialize_value, data))
222    elif isinstance(data, dict):
223        items = []
224        for key, value in data.items():
225            key = _serialize_value(key)
226            value = _serialize_value(value)
227            items.append("%s: %s" % (key, value))
228        items.sort()
229        return '{%s}' % ', '.join(items)
230    elif isinstance(data, datetime.datetime):
231        return repr(data.utctimetuple())
232    elif isinstance(data, unicode):
233        return data.encode('utf-8')
[325]234    elif isinstance(data, float) and data != data:
235        return '?'
[322]236    else:
237        return str(data)
[147]238
[322]239def test_constructor_types(data_filename, code_filename, verbose=False):
240    _make_objects()
241    native1 = None
242    native2 = None
243    try:
244        native1 = list(yaml.load_all(open(data_filename, 'rb'), Loader=MyLoader))
245        if len(native1) == 1:
246            native1 = native1[0]
247        native2 = _load_code(open(code_filename, 'rb').read())
[58]248        try:
[322]249            if native1 == native2:
250                return
251        except TypeError:
252            pass
253        if verbose:
254            print "SERIALIZED NATIVE1:"
255            print _serialize_value(native1)
256            print "SERIALIZED NATIVE2:"
257            print _serialize_value(native2)
258        assert _serialize_value(native1) == _serialize_value(native2), (native1, native2)
259    finally:
260        if verbose:
261            print "NATIVE1:"
262            pprint.pprint(native1)
263            print "NATIVE2:"
264            pprint.pprint(native2)
[58]265
[322]266test_constructor_types.unittest = ['.data', '.code']
[58]267
[322]268if __name__ == '__main__':
269    import sys, test_constructor
270    sys.modules['test_constructor'] = sys.modules['__main__']
271    import test_appliance
272    test_appliance.run(globals())
273
Note: See TracBrowser for help on using the repository browser.