source: pyyaml/trunk/tests3/test_constructor.py @ 328

Revision 328, 8.8 KB checked in by xi, 5 years ago (diff)

Added basic support for Python 3 (Thanks idadesub(at)users(dot)sourceforge(dot)net).

RevLine 
[58]1
[322]2import yaml
3import pprint
[225]4
5import datetime
[146]6import yaml.tokens
[144]7
[322]8def execute(code):
[328]9    exec(code)
[322]10    return value
[58]11
[322]12def _make_objects():
13    global MyLoader, MyDumper, MyTestClass1, MyTestClass2, MyTestClass3, YAMLObject1, YAMLObject2,  \
14            AnObject, AnInstance, AState, ACustomState, InitArgs, InitArgsWithState,    \
15            NewArgs, NewArgsWithState, Reduce, ReduceWithState, MyInt, MyList, MyDict,  \
16            FixedOffset, execute
[58]17
[322]18    class MyLoader(yaml.Loader):
19        pass
20    class MyDumper(yaml.Dumper):
21        pass
[58]22
[322]23    class MyTestClass1:
24        def __init__(self, x, y=0, z=0):
25            self.x = x
26            self.y = y
27            self.z = z
28        def __eq__(self, other):
29            if isinstance(other, MyTestClass1):
30                return self.__class__, self.__dict__ == other.__class__, other.__dict__
31            else:
32                return False
[58]33
[322]34    def construct1(constructor, node):
[58]35        mapping = constructor.construct_mapping(node)
[322]36        return MyTestClass1(**mapping)
37    def represent1(representer, native):
38        return representer.represent_mapping("!tag1", native.__dict__)
[58]39
[322]40    yaml.add_constructor("!tag1", construct1, Loader=MyLoader)
41    yaml.add_representer(MyTestClass1, represent1, Dumper=MyDumper)
[144]42
[322]43    class MyTestClass2(MyTestClass1, yaml.YAMLObject):
44        yaml_loader = MyLoader
45        yaml_dumper = MyDumper
46        yaml_tag = "!tag2"
47        def from_yaml(cls, constructor, node):
48            x = constructor.construct_yaml_int(node)
49            return cls(x=x)
50        from_yaml = classmethod(from_yaml)
51        def to_yaml(cls, representer, native):
52            return representer.represent_scalar(cls.yaml_tag, str(native.x))
53        to_yaml = classmethod(to_yaml)
[144]54
[322]55    class MyTestClass3(MyTestClass2):
56        yaml_tag = "!tag3"
57        def from_yaml(cls, constructor, node):
58            mapping = constructor.construct_mapping(node)
59            if '=' in mapping:
60                x = mapping['=']
61                del mapping['=']
62                mapping['x'] = x
63            return cls(**mapping)
64        from_yaml = classmethod(from_yaml)
65        def to_yaml(cls, representer, native):
66            return representer.represent_mapping(cls.yaml_tag, native.__dict__)
67        to_yaml = classmethod(to_yaml)
[136]68
[322]69    class YAMLObject1(yaml.YAMLObject):
70        yaml_loader = MyLoader
71        yaml_dumper = MyDumper
72        yaml_tag = '!foo'
73        def __init__(self, my_parameter=None, my_another_parameter=None):
74            self.my_parameter = my_parameter
75            self.my_another_parameter = my_another_parameter
76        def __eq__(self, other):
77            if isinstance(other, YAMLObject1):
78                return self.__class__, self.__dict__ == other.__class__, other.__dict__
79            else:
80                return False
[136]81
[322]82    class YAMLObject2(yaml.YAMLObject):
83        yaml_loader = MyLoader
84        yaml_dumper = MyDumper
85        yaml_tag = '!bar'
86        def __init__(self, foo=1, bar=2, baz=3):
87            self.foo = foo
88            self.bar = bar
89            self.baz = baz
90        def __getstate__(self):
91            return {1: self.foo, 2: self.bar, 3: self.baz}
92        def __setstate__(self, state):
93            self.foo = state[1]
94            self.bar = state[2]
95            self.baz = state[3]
96        def __eq__(self, other):
97            if isinstance(other, YAMLObject2):
98                return self.__class__, self.__dict__ == other.__class__, other.__dict__
99            else:
100                return False
[136]101
[322]102    class AnObject(object):
103        def __new__(cls, foo=None, bar=None, baz=None):
104            self = object.__new__(cls)
105            self.foo = foo
106            self.bar = bar
107            self.baz = baz
108            return self
109        def __cmp__(self, other):
110            return cmp((type(self), self.foo, self.bar, self.baz),
111                    (type(other), other.foo, other.bar, other.baz))
112        def __eq__(self, other):
113            return type(self) is type(other) and    \
114                    (self.foo, self.bar, self.baz) == (other.foo, other.bar, other.baz)
[58]115
[322]116    class AnInstance:
117        def __init__(self, foo=None, bar=None, baz=None):
118            self.foo = foo
119            self.bar = bar
120            self.baz = baz
121        def __cmp__(self, other):
122            return cmp((type(self), self.foo, self.bar, self.baz),
123                    (type(other), other.foo, other.bar, other.baz))
124        def __eq__(self, other):
125            return type(self) is type(other) and    \
126                    (self.foo, self.bar, self.baz) == (other.foo, other.bar, other.baz)
[144]127
[322]128    class AState(AnInstance):
129        def __getstate__(self):
130            return {
131                '_foo': self.foo,
132                '_bar': self.bar,
133                '_baz': self.baz,
134            }
135        def __setstate__(self, state):
136            self.foo = state['_foo']
137            self.bar = state['_bar']
138            self.baz = state['_baz']
[144]139
[322]140    class ACustomState(AnInstance):
141        def __getstate__(self):
142            return (self.foo, self.bar, self.baz)
143        def __setstate__(self, state):
144            self.foo, self.bar, self.baz = state
[144]145
[322]146    class InitArgs(AnInstance):
147        def __getinitargs__(self):
148            return (self.foo, self.bar, self.baz)
149        def __getstate__(self):
150            return {}
[144]151
[322]152    class InitArgsWithState(AnInstance):
153        def __getinitargs__(self):
154            return (self.foo, self.bar)
155        def __getstate__(self):
156            return self.baz
157        def __setstate__(self, state):
158            self.baz = state
[144]159
[322]160    class NewArgs(AnObject):
161        def __getnewargs__(self):
162            return (self.foo, self.bar, self.baz)
163        def __getstate__(self):
164            return {}
[147]165
[322]166    class NewArgsWithState(AnObject):
167        def __getnewargs__(self):
168            return (self.foo, self.bar)
169        def __getstate__(self):
170            return self.baz
171        def __setstate__(self, state):
172            self.baz = state
[147]173
[322]174    class Reduce(AnObject):
175        def __reduce__(self):
176            return self.__class__, (self.foo, self.bar, self.baz)
[147]177
[322]178    class ReduceWithState(AnObject):
179        def __reduce__(self):
180            return self.__class__, (self.foo, self.bar), self.baz
181        def __setstate__(self, state):
182            self.baz = state
[147]183
[322]184    class MyInt(int):
185        def __eq__(self, other):
186            return type(self) is type(other) and int(self) == int(other)
[147]187
[322]188    class MyList(list):
189        def __init__(self, n=1):
190            self.extend([None]*n)
191        def __eq__(self, other):
192            return type(self) is type(other) and list(self) == list(other)
[147]193
[322]194    class MyDict(dict):
195        def __init__(self, n=1):
196            for k in range(n):
197                self[k] = None
198        def __eq__(self, other):
199            return type(self) is type(other) and dict(self) == dict(other)
[147]200
[322]201    class FixedOffset(datetime.tzinfo):
202        def __init__(self, offset, name):
203            self.__offset = datetime.timedelta(minutes=offset)
204            self.__name = name
205        def utcoffset(self, dt):
206            return self.__offset
207        def tzname(self, dt):
208            return self.__name
209        def dst(self, dt):
210            return datetime.timedelta(0)
[147]211
[322]212def _load_code(expression):
213    return eval(expression)
[147]214
[322]215def _serialize_value(data):
216    if isinstance(data, list):
217        return '[%s]' % ', '.join(map(_serialize_value, data))
218    elif isinstance(data, dict):
219        items = []
220        for key, value in data.items():
221            key = _serialize_value(key)
222            value = _serialize_value(value)
223            items.append("%s: %s" % (key, value))
224        items.sort()
225        return '{%s}' % ', '.join(items)
226    elif isinstance(data, datetime.datetime):
227        return repr(data.utctimetuple())
[325]228    elif isinstance(data, float) and data != data:
229        return '?'
[322]230    else:
231        return str(data)
[147]232
[322]233def test_constructor_types(data_filename, code_filename, verbose=False):
234    _make_objects()
235    native1 = None
236    native2 = None
237    try:
238        native1 = list(yaml.load_all(open(data_filename, 'rb'), Loader=MyLoader))
239        if len(native1) == 1:
240            native1 = native1[0]
241        native2 = _load_code(open(code_filename, 'rb').read())
[58]242        try:
[322]243            if native1 == native2:
244                return
245        except TypeError:
246            pass
247        if verbose:
[328]248            print("SERIALIZED NATIVE1:")
249            print(_serialize_value(native1))
250            print("SERIALIZED NATIVE2:")
251            print(_serialize_value(native2))
[322]252        assert _serialize_value(native1) == _serialize_value(native2), (native1, native2)
253    finally:
254        if verbose:
[328]255            print("NATIVE1:")
[322]256            pprint.pprint(native1)
[328]257            print("NATIVE2:")
[322]258            pprint.pprint(native2)
[58]259
[322]260test_constructor_types.unittest = ['.data', '.code']
[58]261
[322]262if __name__ == '__main__':
263    import sys, test_constructor
264    sys.modules['test_constructor'] = sys.modules['__main__']
265    import test_appliance
266    test_appliance.run(globals())
267
Note: See TracBrowser for help on using the repository browser.