from enum import IntEnum, auto
import pprint
-from typing import Tuple, List
+from typing import Tuple, List, cast
DEBUG = True
return None
def lex(self) -> List[Token]:
+ """Return every token produced by next_token(), in order, stopping at the first None."""
- tokens: [Token] = []
- token = self.next_token()
+ tokens: List[Token] = []
+ token: Token | None = self.next_token()
- while token != None:
+ # Identity test (PEP 8); it also narrows `token` to Token for type
+ # checkers, so the append below needs no cast().
+ while token is not None:
tokens.append(token)
token = self.next_token()
return tokens
BINARY_OP = auto()
COUNT = auto()
-ast_node_type_as_str_map: {AstNodeType : str} = {
+ast_node_type_as_str_map: dict[AstNodeType, str] = {
AstNodeType.EXPR : "Expr",
AstNodeType.STMT : "Stmt",
AstNodeType.INT : "Int",
NAH = auto()
COUNT = auto()
+parse_error_as_str_map: dict[ParseError, str] = {  # display text per ParseError; read by ParseException.__repr__
+ ParseError.EOF : "Eof",
+ ParseError.UNEXPECTED_TOKEN : "Unexpected Token",
+ ParseError.NAH : "Nah",
+}
+assert len(parse_error_as_str_map) == ParseError.COUNT-1, "Every ParseError is not handled in parse_error_as_str_map"  # presumably COUNT is the last auto() member (auto() starts at 1), making COUNT-1 the number of real members -- confirm against the enum
+
+class ParseException(Exception):
+ """Exception raised by the parser; `typ` holds the ParseError kind."""
+ def __init__(self, typ: ParseError):
+ super().__init__(typ)
+ self.typ = typ
+
+ def __repr__(self):
+ return f"Parse Exception: {parse_error_as_str_map[self.typ]}"
+
+ # Tracebacks and print() use str(), not repr(); without this an uncaught
+ # ParseException shows only the raw enum value.
+ def __str__(self):
+ return repr(self)
+
class Parser:
def __init__(self, tokens):  # tokens: the list the parse* methods inspect and pop(0) from
self.tokens = tokens  # stored without copying, so parsing consumes the caller's list
def parse(self) -> AstNode:  # parses one statement via parseStatement() and returns its node
stmt = self.parseStatement()
+ assert isinstance(stmt, AstNode)  # parseStatement is annotated AstNode | None; this narrows the type (asserts are skipped under python -O)
return stmt
- def parseStatement(self) -> AstNode | ParseError | None:
+ def parseStatement(self) -> AstNode | None:
tokens = self.tokens
# Variable name
var_name_ast = self.parseIdentifier()
+ # TODO: Check for var_name_ast
var_type_ast = None
# Check if colon is there
if len(tokens) >= 1 and tokens[0].typ == TokenType.COLON:
# if var_type_ast != None:
# dlog(var_type_ast)
- equal = None
+ equal: Token | None = None
if len(tokens) >= 1 and tokens[0].typ == TokenType.EQUAL:
equal = tokens.pop(0)
+ if equal == None:
+ if len(tokens) <= 0: raise ParseException(ParseError.EOF)
+ semicolon = tokens.pop(0)
+ if semicolon.typ != TokenType.SEMICOLON:
+ fatal("We don't support Statements with more than one expressions yet!")
+
expr = self.parseExpression()
# dlog(f"EXPR: {expr}")
if isinstance(expr, ParseError):
if expr == ParseError.EOF:
+ assert isinstance(equal, Token)
self.syntax_error(f"Expected ; but reached EOF", equal)
semicolon = tokens.pop(0)
return AstNodeStatement(var_name_ast.token, var_name_ast, var_type_ast, expr)
- def parseIdentifier(self) -> AstNode | ParseError | None:
- if len(self.tokens) <= 0: return ParseError.EOF
- if self.tokens[0].typ != TokenType.IDENT: return ParseError.UNEXPECTED_TOKEN
+ def parseIdentifier(self) -> AstNode | None:
+ """Pop the leading IDENT token and return an AstNodeIdentifier for it.
+
+ Raises ParseException(EOF) when no tokens remain and
+ ParseException(UNEXPECTED_TOKEN) when the next token is not an IDENT.
+ """
+ if len(self.tokens) <= 0: raise ParseException(ParseError.EOF)
+ # Raise, not return: a returned exception object is truthy and callers
+ # would go on to use it as if it were an AstNode.
+ if self.tokens[0].typ != TokenType.IDENT: raise ParseException(ParseError.UNEXPECTED_TOKEN)
ident_token = self.tokens.pop(0)
return AstNodeIdentifier(ident_token, ident_token.lexeme)
- def parseExpression(self) -> AstNode | ParseError | None:
- if len(self.tokens) <= 0: return ParseError.EOF
+ def parseExpression(self) -> AstNode | None:
+ if len(self.tokens) <= 0: return ParseException(ParseError.EOF)
t = self.tokens[0]
lhs = self.parseName()
return current_node
- def parseName(self) -> AstNode | ParseError | None:
- if len(self.tokens) <= 0: return ParseError.EOF
+ def parseName(self) -> AstNode | None:
+ """Parse an identifier or, failing that, a literal value.
+
+ Raises ParseException(EOF) when no tokens remain.
+ """
+ if len(self.tokens) <= 0: raise ParseException(ParseError.EOF)
- name = self.parseLiteralValue()
- if isinstance(name, ParseError):
- return self.parseIdentifier()
- return name
+ # Nothing returns ParseError values any more, so the old
+ # isinstance(name, ParseError) fallback could never fire, and
+ # parseLiteralValue pops the token it inspects. Peek at the token type
+ # and dispatch instead, so an identifier is not consumed and lost.
+ if self.tokens[0].typ == TokenType.IDENT:
+ return self.parseIdentifier()
+ return self.parseLiteralValue()
- def parseLiteralValue(self) -> AstNode | ParseError | None:
- if len(self.tokens) <= 0: return ParseError.EOF
+ def parseLiteralValue(self) -> AstNode | None:
+ """Pop the next token and return an AstNodeInt when it is an INT.
+
+ The fall-through visible here returns None.
+ Raises ParseException(EOF) when no tokens remain.
+ """
+ if len(self.tokens) <= 0: raise ParseException(ParseError.EOF)
+ # NOTE(review): the token is consumed even when it is not a literal.
t = self.tokens.pop(0)
if t.typ == TokenType.INT:
return AstNodeInt(t, int(t.lexeme))
return None
- def parseBinOp(self) -> AstNode | ParseError | None:
+ def parseBinOp(self) -> AstNode | None:
# TODO: Check if Operator predecence is correct
arithmeticOp = self.parseArithmeticOp()
assert False, "UNREACHABLE!"
- def parseComparisonOp(self) -> AstNode | ParseError | None:
- if len(self.tokens) <= 0: return ParseError.EOF
+ def parseComparisonOp(self) -> AstNode | None:
+ if len(self.tokens) <= 0: return ParseException(ParseError.EOF)
t = self.tokens.pop(0)
if t.typ in [ TokenType.GT, TokenType.GTE, TokenType.LT, TokenType.LTE, TokenType.EQUAL_EQUAL, TokenType.NOT_EQUAL ]:
return None
- def parseLogicalOp(self) -> AstNode | ParseError | None:
- if len(self.tokens) <= 0: return ParseError.EOF
+ def parseLogicalOp(self) -> AstNode | None:
+ if len(self.tokens) <= 0: return ParseException(ParseError.EOF)
t = self.tokens.pop(0)
if t.typ in [ TokenType.LOGICAL_AND, TokenType.LOGICAL_OR ]:
return None
- def parseArithmeticOp(self) -> AstNode | ParseError | None:
- if len(self.tokens) <= 0: return ParseError.EOF
+ def parseArithmeticOp(self) -> AstNode | None:
+ if len(self.tokens) <= 0: return ParseException(ParseError.EOF)
t = self.tokens.pop(0)
if t.typ in [ TokenType.PLUS, TokenType.MINUS, TokenType.DIVIDE, TokenType.MODULUS ]:
return None
- def parseBinaryArithmeticOp(self) -> AstNode | ParseError | None:
- if len(self.tokens) <= 0: return ParseError.EOF
+ def parseBinaryArithmeticOp(self) -> AstNode | None:
+ if len(self.tokens) <= 0: return ParseException(ParseError.EOF)
t = self.tokens.pop(0)
if t.typ in [ TokenType.BINARY_AND, TokenType.BINARY_OR, TokenType.BINARY_NOT ]:
# TODO: Parse
parser = Parser(tokens)
- for t in tokens:
- pprint.pp(str(t))
+ # for t in tokens:
+ # pprint.pp(str(t))
- # print(parser.parse())
+ print(parser.parse())
if __name__ == '__main__':
main()