source: branches/pyyaml3000/tests/test_appliance.py @ 45

Revision 45, 11.6 KB checked in by xi, 8 years ago (diff)

Stream and Marker are cleaned up.

Line 
1
2import unittest, os
3
4class TestAppliance(unittest.TestCase):
5
6    DATA = 'tests/data'
7
8    all_tests = {}
9    for filename in os.listdir(DATA):
10        if os.path.isfile(os.path.join(DATA, filename)):
11            root, ext = os.path.splitext(filename)
12            all_tests.setdefault(root, []).append(ext)
13
14    def add_tests(cls, method_name, *extensions):
15        for test in cls.all_tests:
16            available_extensions = cls.all_tests[test]
17            for ext in extensions:
18                if ext not in available_extensions:
19                    break
20            else:
21                filenames = [os.path.join(cls.DATA, test+ext) for ext in extensions]
22                def test_method(self, test=test, filenames=filenames):
23                    getattr(self, '_'+method_name)(test, *filenames)
24                test = test.replace('-', '_')
25                try:
26                    test_method.__name__ = '%s_%s' % (method_name, test)
27                except TypeError:
28                    import new
29                    test_method = new.function(test_method.func_code, test_method.func_globals,
30                            '%s_%s' % (method_name, test), test_method.func_defaults,
31                            test_method.func_closure)
32                setattr(cls, test_method.__name__, test_method)
33    add_tests = classmethod(add_tests)
34
35class Node:
36    def __repr__(self):
37        args = []
38        for attribute in ['anchor', 'tag', 'value']:
39            if hasattr(self, attribute):
40                args.append(repr(getattr(self, attribute)))
41        return "%s(%s)" % (self.__class__.__name__, ', '.join(args))
42
43class AliasNode(Node):
44    def __init__(self, anchor):
45        self.anchor = anchor
46
47class ScalarNode(Node):
48    def __init__(self, anchor, tag, value):
49        self.anchor = anchor
50        self.tag = tag
51        self.value = value
52
53class SequenceNode(Node):
54    def __init__(self, anchor, tag, value):
55        self.anchor = anchor
56        self.tag = tag
57        self.value = value
58
59class MappingNode(Node):
60    def __init__(self, anchor, tag, value):
61        self.anchor = anchor
62        self.tag = tag
63        self.value = value
64
65class Token:
66    def __repr__(self):
67        args = []
68        if hasattr(self, 'value'):
69            args.append(repr(self.value))
70        return "%s(%s)" % (self.__class__.__name__, ''.join(args))
71
72class EndToken(Token):
73    pass
74
75class DirectiveToken(Token):
76    pass
77
78class DocumentStartToken(Token):
79    pass
80
81class SequenceStartToken(Token):
82    pass
83
84class MappingStartToken(Token):
85    pass
86
87class SequenceEndToken(Token):
88    pass
89
90class MappingEndToken(Token):
91    pass
92
93class KeyToken(Token):
94    pass
95
96class ValueToken(Token):
97    pass
98
99class EntryToken(Token):
100    pass
101
102class AliasToken(Token):
103    def __init__(self, value):
104        self.value = value
105
106class AnchorToken(Token):
107    def __init__(self, value):
108        self.value = value
109
110class TagToken(Token):
111    def __init__(self, value):
112        self.value = value
113
114class ScalarToken(Token):
115    def __init__(self, value):
116        self.value = value
117
118class Error(Exception):
119    pass
120
121class CanonicalScanner:
122
123    def __init__(self, source, data):
124        self.source = source
125        self.data = unicode(data, 'utf-8')+u'\0'
126        self.index = 0
127
128    def scan(self):
129        #print self.data[self.index:]
130        tokens = []
131        while True:
132            self.find_token()
133            ch = self.data[self.index]
134            if ch == u'\0':
135                tokens.append(EndToken())
136                break
137            elif ch == u'%':
138                tokens.append(self.scan_directive())
139            elif ch == u'-' and self.data[self.index:self.index+3] == u'---':
140                self.index += 3
141                tokens.append(DocumentStartToken())
142            elif ch == u'[':
143                self.index += 1
144                tokens.append(SequenceStartToken())
145            elif ch == u'{':
146                self.index += 1
147                tokens.append(MappingStartToken())
148            elif ch == u']':
149                self.index += 1
150                tokens.append(SequenceEndToken())
151            elif ch == u'}':
152                self.index += 1
153                tokens.append(MappingEndToken())
154            elif ch == u'?':
155                self.index += 1
156                tokens.append(KeyToken())
157            elif ch == u':':
158                self.index += 1
159                tokens.append(ValueToken())
160            elif ch == u',':
161                self.index += 1
162                tokens.append(EntryToken())
163            elif ch == u'*' or ch == u'&':
164                tokens.append(self.scan_alias())
165            elif ch == u'!':
166                tokens.append(self.scan_tag())
167            elif ch == u'"':
168                tokens.append(self.scan_scalar())
169            else:
170                raise Error("invalid token")
171        return tokens
172
173    DIRECTIVE = u'%YAML 1.1'
174
175    def scan_directive(self):
176        if self.data[self.index:self.index+len(self.DIRECTIVE)] == self.DIRECTIVE and \
177                self.data[self.index+len(self.DIRECTIVE)] in u' \n\0':
178            self.index += len(self.DIRECTIVE)
179            return DirectiveToken()
180
181    def scan_alias(self):
182        if self.data[self.index] == u'*':
183            TokenClass = AliasToken
184        else:
185            TokenClass = AnchorToken
186        self.index += 1
187        start = self.index
188        while self.data[self.index] not in u', \n\0':
189            self.index += 1
190        value = self.data[start:self.index]
191        return TokenClass(value)
192
193    def scan_tag(self):
194        self.index += 1
195        start = self.index
196        while self.data[self.index] not in u' \n\0':
197            self.index += 1
198        value = self.data[start:self.index]
199        if value[0] == u'!':
200            value = 'tag:yaml.org,2002:'+value[1:]
201        else:
202            value = value[1:-1]
203        return TagToken(value)
204
205    QUOTE_CODES = {
206        'x': 2,
207        'u': 4,
208        'U': 8,
209    }
210
211    QUOTE_REPLACES = {
212        u'\\': u'\\',
213        u'\"': u'\"',
214        u' ': u' ',
215        u'a': u'\x07',
216        u'b': u'\x08',
217        u'e': u'\x1B',
218        u'f': u'\x0C',
219        u'n': u'\x0A',
220        u'r': u'\x0D',
221        u't': u'\x09',
222        u'v': u'\x0B',
223        u'N': u'\u0085',
224        u'L': u'\u2028',
225        u'P': u'\u2029',
226        u'_': u'_',
227        u'0': u'\x00',
228
229    }
230
231    def scan_scalar(self):
232        self.index += 1
233        chunks = []
234        start = self.index
235        ignore_spaces = False
236        while self.data[self.index] != u'"':
237            if self.data[self.index] == u'\\':
238                ignore_spaces = False
239                chunks.append(self.data[start:self.index])
240                self.index += 1
241                ch = self.data[self.index]
242                self.index += 1
243                if ch == u'\n':
244                    ignore_spaces = True
245                elif ch in self.QUOTE_CODES:
246                    length = self.QUOTE_CODES[ch]
247                    code = int(self.data[self.index:self.index+length], 16)
248                    chunks.append(unichr(code))
249                    self.index += length
250                else:
251                    chunks.append(self.QUOTE_REPLACES[ch])
252                start = self.index
253            elif self.data[self.index] == u'\n':
254                chunks.append(u' ')
255                self.index += 1
256                ignore_spaces = True
257            elif ignore_spaces and self.data[self.index] == u' ':
258                self.index += 1
259                start = self.index
260            else:
261                ignore_spaces = False
262                self.index += 1
263        chunks.append(self.data[start:self.index])
264        self.index += 1
265        return ScalarToken(u''.join(chunks))
266
267    def find_token(self):
268        found = False
269        while not found:
270            while self.data[self.index] in u' \t':
271                self.index += 1
272            if self.data[self.index] == u'#':
273                while self.data[self.index] != u'\n':
274                    self.index += 1
275            if self.data[self.index] == u'\n':
276                self.index += 1
277            else:
278                found = True
279
280class CanonicalParser:
281
282    def __init__(self, source, data):
283        self.scanner = CanonicalScanner(source, data)
284
285    # stream: document* END
286    def parse_stream(self):
287        documents = []
288        while not self.test_token(EndToken):
289            if self.test_token(DirectiveToken, DocumentStartToken):
290                documents.append(self.parse_document())
291            else:
292                raise Error("document is expected, got "+repr(self.tokens[self.index]))
293        return documents
294
295    # document: DIRECTIVE? DOCUMENT-START node?
296    def parse_document(self):
297        node = None
298        if self.test_token(DirectiveToken):
299            self.consume_token(DirectiveToken)
300        self.consume_token(DocumentStartToken)
301        if self.test_token(TagToken, AliasToken, AnchorToken, TagToken,
302                SequenceStartToken, MappingStartToken, ScalarToken):
303            node = self.parse_node()
304        return node
305
306    # node: ALIAS | ANCHOR? TAG? (SCALAR|sequence|mapping)
307    def parse_node(self):
308        if self.test_token(AliasToken):
309            return AliasNode(self.get_value())
310        else:
311            anchor = None
312            if self.test_token(AnchorToken):
313                anchor = self.get_value()
314            tag = None
315            if self.test_token(TagToken):
316                tag = self.get_value()
317            if self.test_token(ScalarToken):
318                return ScalarNode(anchor, tag, self.get_value())
319            elif self.test_token(SequenceStartToken):
320                return SequenceNode(anchor, tag, self.parse_sequence())
321            elif self.test_token(MappingStartToken):
322                return MappingNode(anchor, tag, self.parse_mapping())
323            else:
324                raise Error("SCALAR, '[', or '{' is expected, got "+repr(self.tokens[self.index]))
325
326    # sequence: SEQUENCE-START (node (ENTRY node)*)? ENTRY? SEQUENCE-END
327    def parse_sequence(self):
328        values = []
329        self.consume_token(SequenceStartToken)
330        if not self.test_token(SequenceEndToken):
331            values.append(self.parse_node())
332            while not self.test_token(SequenceEndToken):
333                self.consume_token(EntryToken)
334                if not self.test_token(SequenceEndToken):
335                    values.append(self.parse_node())
336        self.consume_token(SequenceEndToken)
337        return values
338
339    # mapping: MAPPING-START (map_entry (ENTRY map_entry)*)? ENTRY? MAPPING-END
340    def parse_mapping(self):
341        values = []
342        self.consume_token(MappingStartToken)
343        if not self.test_token(MappingEndToken):
344            values.append(self.parse_map_entry())
345            while not self.test_token(MappingEndToken):
346                self.consume_token(EntryToken)
347                if not self.test_token(MappingEndToken):
348                    values.append(self.parse_map_entry())
349        self.consume_token(MappingEndToken)
350        return values
351
352    # map_entry: KEY node VALUE node
353    def parse_map_entry(self):
354        self.consume_token(KeyToken)
355        key = self.parse_node()
356        self.consume_token(ValueToken)
357        value = self.parse_node()
358        return (key, value)
359
360    def test_token(self, *choices):
361        for choice in choices:
362            if isinstance(self.tokens[self.index], choice):
363                return True
364        return False
365
366    def consume_token(self, cls):
367        if not isinstance(self.tokens[self.index], cls):
368            raise Error("unexpected token "+repr(self.tokens[self.index]))
369        self.index += 1
370
371    def get_value(self):
372        value = self.tokens[self.index].value
373        self.index += 1
374        return value
375
376    def parse(self):
377        self.tokens = self.scanner.scan()
378        self.index = 0
379        return self.parse_stream()
380
Note: See TracBrowser for help on using the repository browser.