source: branches/pyyaml3000/tests/test_appliance.py @ 44

Revision 44, 11.3 KB checked in by xi, 9 years ago (diff)

All tests passed! Scanner and Parser seem to be correct.

Line 
1
2import unittest, os
3
4class TestAppliance(unittest.TestCase):
5
6    DATA = 'tests/data'
7
8    tests = {}
9    for filename in os.listdir(DATA):
10        if os.path.isfile(os.path.join(DATA, filename)):
11            root, ext = os.path.splitext(filename)
12            tests.setdefault(root, []).append(ext)
13
14    def add_tests(cls, method_name, *extensions):
15        for test in cls.tests:
16            available_extensions = cls.tests[test]
17            for ext in extensions:
18                if ext not in available_extensions:
19                    break
20            else:
21                filenames = [os.path.join(cls.DATA, test+ext) for ext in extensions]
22                def test_method(self, test=test, filenames=filenames):
23                    getattr(self, '_'+method_name)(test, *filenames)
24                test = test.replace('-', '_')
25                test_method.__name__ = '%s_%s' % (method_name, test)
26                setattr(cls, test_method.__name__, test_method)
27    add_tests = classmethod(add_tests)
28
29class Node:
30    def __repr__(self):
31        args = []
32        for attribute in ['anchor', 'tag', 'value']:
33            if hasattr(self, attribute):
34                args.append(repr(getattr(self, attribute)))
35        return "%s(%s)" % (self.__class__.__name__, ', '.join(args))
36
37class AliasNode(Node):
38    def __init__(self, anchor):
39        self.anchor = anchor
40
41class ScalarNode(Node):
42    def __init__(self, anchor, tag, value):
43        self.anchor = anchor
44        self.tag = tag
45        self.value = value
46
47class SequenceNode(Node):
48    def __init__(self, anchor, tag, value):
49        self.anchor = anchor
50        self.tag = tag
51        self.value = value
52
53class MappingNode(Node):
54    def __init__(self, anchor, tag, value):
55        self.anchor = anchor
56        self.tag = tag
57        self.value = value
58
59class Token:
60    def __repr__(self):
61        args = []
62        if hasattr(self, 'value'):
63            args.append(repr(self.value))
64        return "%s(%s)" % (self.__class__.__name__, ''.join(args))
65
66class EndToken(Token):
67    pass
68
69class DirectiveToken(Token):
70    pass
71
72class DocumentStartToken(Token):
73    pass
74
75class SequenceStartToken(Token):
76    pass
77
78class MappingStartToken(Token):
79    pass
80
81class SequenceEndToken(Token):
82    pass
83
84class MappingEndToken(Token):
85    pass
86
87class KeyToken(Token):
88    pass
89
90class ValueToken(Token):
91    pass
92
93class EntryToken(Token):
94    pass
95
96class AliasToken(Token):
97    def __init__(self, value):
98        self.value = value
99
100class AnchorToken(Token):
101    def __init__(self, value):
102        self.value = value
103
104class TagToken(Token):
105    def __init__(self, value):
106        self.value = value
107
108class ScalarToken(Token):
109    def __init__(self, value):
110        self.value = value
111
112class Error(Exception):
113    pass
114
115class CanonicalScanner:
116
117    def __init__(self, source, data):
118        self.source = source
119        self.data = unicode(data, 'utf-8')+u'\0'
120        self.index = 0
121
122    def scan(self):
123        #print self.data[self.index:]
124        tokens = []
125        while True:
126            self.find_token()
127            ch = self.data[self.index]
128            if ch == u'\0':
129                tokens.append(EndToken())
130                break
131            elif ch == u'%':
132                tokens.append(self.scan_directive())
133            elif ch == u'-' and self.data[self.index:self.index+3] == u'---':
134                self.index += 3
135                tokens.append(DocumentStartToken())
136            elif ch == u'[':
137                self.index += 1
138                tokens.append(SequenceStartToken())
139            elif ch == u'{':
140                self.index += 1
141                tokens.append(MappingStartToken())
142            elif ch == u']':
143                self.index += 1
144                tokens.append(SequenceEndToken())
145            elif ch == u'}':
146                self.index += 1
147                tokens.append(MappingEndToken())
148            elif ch == u'?':
149                self.index += 1
150                tokens.append(KeyToken())
151            elif ch == u':':
152                self.index += 1
153                tokens.append(ValueToken())
154            elif ch == u',':
155                self.index += 1
156                tokens.append(EntryToken())
157            elif ch == u'*' or ch == u'&':
158                tokens.append(self.scan_alias())
159            elif ch == u'!':
160                tokens.append(self.scan_tag())
161            elif ch == u'"':
162                tokens.append(self.scan_scalar())
163            else:
164                raise Error("invalid token")
165        return tokens
166
167    DIRECTIVE = u'%YAML 1.1'
168
169    def scan_directive(self):
170        if self.data[self.index:self.index+len(self.DIRECTIVE)] == self.DIRECTIVE and \
171                self.data[self.index+len(self.DIRECTIVE)] in u' \n\0':
172            self.index += len(self.DIRECTIVE)
173            return DirectiveToken()
174
175    def scan_alias(self):
176        if self.data[self.index] == u'*':
177            TokenClass = AliasToken
178        else:
179            TokenClass = AnchorToken
180        self.index += 1
181        start = self.index
182        while self.data[self.index] not in u', \n\0':
183            self.index += 1
184        value = self.data[start:self.index]
185        return TokenClass(value)
186
187    def scan_tag(self):
188        self.index += 1
189        start = self.index
190        while self.data[self.index] not in u' \n\0':
191            self.index += 1
192        value = self.data[start:self.index]
193        if value[0] == u'!':
194            value = 'tag:yaml.org,2002:'+value[1:]
195        else:
196            value = value[1:-1]
197        return TagToken(value)
198
199    QUOTE_CODES = {
200        'x': 2,
201        'u': 4,
202        'U': 8,
203    }
204
205    QUOTE_REPLACES = {
206        u'\\': u'\\',
207        u'\"': u'\"',
208        u' ': u' ',
209        u'a': u'\x07',
210        u'b': u'\x08',
211        u'e': u'\x1B',
212        u'f': u'\x0C',
213        u'n': u'\x0A',
214        u'r': u'\x0D',
215        u't': u'\x09',
216        u'v': u'\x0B',
217        u'N': u'\u0085',
218        u'L': u'\u2028',
219        u'P': u'\u2029',
220        u'_': u'_',
221        u'0': u'\x00',
222
223    }
224
225    def scan_scalar(self):
226        self.index += 1
227        chunks = []
228        start = self.index
229        ignore_spaces = False
230        while self.data[self.index] != u'"':
231            if self.data[self.index] == u'\\':
232                ignore_spaces = False
233                chunks.append(self.data[start:self.index])
234                self.index += 1
235                ch = self.data[self.index]
236                self.index += 1
237                if ch == u'\n':
238                    ignore_spaces = True
239                elif ch in self.QUOTE_CODES:
240                    length = self.QUOTE_CODES[ch]
241                    code = int(self.data[self.index:self.index+length], 16)
242                    chunks.append(unichr(code))
243                    self.index += length
244                else:
245                    chunks.append(self.QUOTE_REPLACES[ch])
246                start = self.index
247            elif self.data[self.index] == u'\n':
248                chunks.append(u' ')
249                self.index += 1
250                ignore_spaces = True
251            elif ignore_spaces and self.data[self.index] == u' ':
252                self.index += 1
253                start = self.index
254            else:
255                ignore_spaces = False
256                self.index += 1
257        chunks.append(self.data[start:self.index])
258        self.index += 1
259        return ScalarToken(u''.join(chunks))
260
261    def find_token(self):
262        found = False
263        while not found:
264            while self.data[self.index] in u' \t':
265                self.index += 1
266            if self.data[self.index] == u'#':
267                while self.data[self.index] != u'\n':
268                    self.index += 1
269            if self.data[self.index] == u'\n':
270                self.index += 1
271            else:
272                found = True
273
274class CanonicalParser:
275
276    def __init__(self, source, data):
277        self.scanner = CanonicalScanner(source, data)
278
279    # stream: document* END
280    def parse_stream(self):
281        documents = []
282        while not self.test_token(EndToken):
283            if self.test_token(DirectiveToken, DocumentStartToken):
284                documents.append(self.parse_document())
285            else:
286                raise Error("document is expected, got "+repr(self.tokens[self.index]))
287        return documents
288
289    # document: DIRECTIVE? DOCUMENT-START node?
290    def parse_document(self):
291        node = None
292        if self.test_token(DirectiveToken):
293            self.consume_token(DirectiveToken)
294        self.consume_token(DocumentStartToken)
295        if self.test_token(TagToken, AliasToken, AnchorToken, TagToken,
296                SequenceStartToken, MappingStartToken, ScalarToken):
297            node = self.parse_node()
298        return node
299
300    # node: ALIAS | ANCHOR? TAG? (SCALAR|sequence|mapping)
301    def parse_node(self):
302        if self.test_token(AliasToken):
303            return AliasNode(self.get_value())
304        else:
305            anchor = None
306            if self.test_token(AnchorToken):
307                anchor = self.get_value()
308            tag = None
309            if self.test_token(TagToken):
310                tag = self.get_value()
311            if self.test_token(ScalarToken):
312                return ScalarNode(anchor, tag, self.get_value())
313            elif self.test_token(SequenceStartToken):
314                return SequenceNode(anchor, tag, self.parse_sequence())
315            elif self.test_token(MappingStartToken):
316                return MappingNode(anchor, tag, self.parse_mapping())
317            else:
318                raise Error("SCALAR, '[', or '{' is expected, got "+repr(self.tokens[self.index]))
319
320    # sequence: SEQUENCE-START (node (ENTRY node)*)? ENTRY? SEQUENCE-END
321    def parse_sequence(self):
322        values = []
323        self.consume_token(SequenceStartToken)
324        if not self.test_token(SequenceEndToken):
325            values.append(self.parse_node())
326            while not self.test_token(SequenceEndToken):
327                self.consume_token(EntryToken)
328                if not self.test_token(SequenceEndToken):
329                    values.append(self.parse_node())
330        self.consume_token(SequenceEndToken)
331        return values
332
333    # mapping: MAPPING-START (map_entry (ENTRY map_entry)*)? ENTRY? MAPPING-END
334    def parse_mapping(self):
335        values = []
336        self.consume_token(MappingStartToken)
337        if not self.test_token(MappingEndToken):
338            values.append(self.parse_map_entry())
339            while not self.test_token(MappingEndToken):
340                self.consume_token(EntryToken)
341                if not self.test_token(MappingEndToken):
342                    values.append(self.parse_map_entry())
343        self.consume_token(MappingEndToken)
344        return values
345
346    # map_entry: KEY node VALUE node
347    def parse_map_entry(self):
348        self.consume_token(KeyToken)
349        key = self.parse_node()
350        self.consume_token(ValueToken)
351        value = self.parse_node()
352        return (key, value)
353
354    def test_token(self, *choices):
355        for choice in choices:
356            if isinstance(self.tokens[self.index], choice):
357                return True
358        return False
359
360    def consume_token(self, cls):
361        if not isinstance(self.tokens[self.index], cls):
362            raise Error("unexpected token "+repr(self.tokens[self.index]))
363        self.index += 1
364
365    def get_value(self):
366        value = self.tokens[self.index].value
367        self.index += 1
368        return value
369
370    def parse(self):
371        self.tokens = self.scanner.scan()
372        self.index = 0
373        return self.parse_stream()
374
Note: See TracBrowser for help on using the repository browser.