INT = auto()
FLOAT = auto()
IDENT = auto()
+ STRING = auto()
+ ARITHMETIC_OP = auto()
COUNT = auto()
ast_node_type_as_str_map: {AstNodeType : str} = {
- AstNodeType.EXPR : "Expr",
- AstNodeType.STMT : "Stmt",
- AstNodeType.INT : "Int",
- AstNodeType.FLOAT : "Float",
- AstNodeType.IDENT : "Identifier",
+ AstNodeType.EXPR : "Expr",
+ AstNodeType.STMT : "Stmt",
+ AstNodeType.INT : "Int",
+ AstNodeType.FLOAT : "Float",
+ AstNodeType.IDENT : "Identifier",
+ AstNodeType.STRING : "String",
+ AstNodeType.ARITHMETIC_OP: "Arithmetic Op",
}
assert len(ast_node_type_as_str_map) == AstNodeType.COUNT-1, "Every AstNodeType is not handled in ast_node_type_as_str_map"
NOTE: ยข is a zero length string, meaning nothing will be substituted
Grammar:
Statement => Ident :? Ident? =? Expression* ;
- Expression => Name Binop Name
+ Expression => Name (Binop Name)*
Name => LitValue | Ident
LitValue => Int | Float | String
BinaryOperator => ArithmeticOp | ComparisionOp | LogicalOp
self.typ = typ
class AstNodeStatement(AstNode):
- def __init__(self, token: Token):
+ def __init__(self, token: Token, var_name, var_type, expr: AstNode):
super().__init__(token, AstNodeType.STMT)
+ self.var_name = var_name
+ self.var_type = var_type
+ self.expr = expr
+
+ def __repr__(self):
+ return f"{self.var_name.__repr__()} : {self.var_type.__repr__()} = {self.expr.__repr__()}"
class AstNodeExpression(AstNode):
- def __init__(self, token: Token):
+ def __init__(self, token: Token, lhs, binop, rhs):
super().__init__(token, AstNodeType.EXPR)
+ self.lhs = lhs
+ self.binop = binop
+ self.rhs = rhs
+
+ def __repr__(self):
+ return f"{self.lhs.__repr__()} {self.binop.__repr__()} {self.rhs.__repr__()}"
class AstNodeInt(AstNode):
def __init__(self, token: Token, value: int):
super().__init__(token, AstNodeType.INT)
self.value = value
+ def __repr__(self):
+ return f"{self.value}"
+
class AstNodeIdentifier(AstNode):
def __init__(self, token: Token, name: str):
super().__init__(token, AstNodeType.IDENT)
self.name = name
+ def __repr__(self):
+ return f"{self.name}"
+
+class AstNodeString(AstNode):
+ def __init__(self, token: Token, string: str):
+ super().__init__(token, AstNodeType.STRING)
+ self.string = string
+
+ def __repr__(self):
+ return f"'{self.string}'"
+
+class AstNodeArithmeticOp(AstNode):
+ def __init__(self, token: Token):
+ super().__init__(token, AstNodeType.ARITHMETIC_OP)
+ self.op = self.token.lexeme
+
+ def __repr__(self):
+ return f"{self.op}"
+
+class ParseError(IntEnum):
+ EOF = auto()
+ UNEXPECTED_TOKEN = auto()
+ COUNT = auto()
+
class Parser:
def __init__(self, tokens):
self.tokens = tokens
stmt = self.parseStatement()
return stmt
- def parseStatement(self) -> AstNodeStatement:
+ def parseStatement(self) -> AstNode | ParseError | None:
tokens = self.tokens
# Variable name
- ident_ast = self.parseIdentifier()
- var_typ = None
+ var_name_ast = self.parseIdentifier()
+ var_type_ast = None
# Check if colon is there
if len(tokens) >= 1 and tokens[0].typ == TokenType.COLON:
# TODO: Should i make an AstNode for the colon too?
colon = tokens.pop(0)
- var_typ = None if len(tokens) <= 0 else tokens.pop(0)
- if var_typ == None or var_typ.typ != TokenType.IDENT:
- if var_typ == None: error_msg = "Reached End of File"
- else: error_msg = f"Got {token_type_as_str_map[var_typ.typ]}"
- self.syntax_error("Expected type of variable after colon, but %s" % error_msg, colon)
+ var_type_ast = self.parseIdentifier()
+ if var_type_ast == ParseError.EOF:
+ self.syntax_error("Expected type of variable after colon, but reached end of file", colon)
+ elif var_type_ast == ParseError.UNEXPECTED_TOKEN:
+ unexpected_token = tokens[0]
+ self.syntax_error(f"Expected type of variable after colon, but got {token_type_as_str_map[unexpected_token.typ]}", colon)
- dlog(ident_ast.token)
- if var_typ != None:
- dlog(var_typ)
+ # dlog(var_name_ast)
+ # if var_type_ast != None:
+ # dlog(var_type_ast)
- print("TODO: Implement parseStatement()")
- exit(1)
- # expr = self.parseExpr()
+ equal = None
+ if len(tokens) >= 1 and tokens[0].typ == TokenType.EQUAL:
+ equal = tokens.pop(0)
- def parseIdentifier(self) -> AstNodeIdentifier:
- if len(self.tokens) <= 0: return None
- if self.tokens[0].typ != TokenType.IDENT: return None
+ expr = self.parseExpression()
+
+ # if expr != ParseError.EOF:
+ # dlog(expr)
+
+ if len(tokens) <= 0:
+ self.syntax_error(f"Expected ; but reached EOF")
+
+ semicolon = tokens.pop(0)
+ if semicolon.typ != TokenType.SEMICOLON:
+ fatal("We don't support Statements with more than one expressions yet!")
+
+ return AstNodeStatement(var_name_ast.token, var_name_ast, var_type_ast, expr)
+
+
+ def parseIdentifier(self) -> AstNode | ParseError | None:
+ if len(self.tokens) <= 0: return ParseError.EOF
+ if self.tokens[0].typ != TokenType.IDENT: return ParseError.UNEXPECTED_TOKEN
ident_token = self.tokens.pop(0)
return AstNodeIdentifier(ident_token, ident_token.lexeme)
- def parseExpr(self) -> AstNodeExpression:
- pass
+ def parseExpression(self) -> AstNode | ParseError | None:
+ if len(self.tokens) <= 0: return ParseError.EOF
+ t = self.tokens[0]
+ lhs = self.parseName()
+ binop = self.parseBinOp()
+ rhs = self.parseName()
+
+ return AstNodeExpression(t, lhs, binop, rhs)
+
+ def parseName(self) -> AstNode | ParseError | None:
+ if len(self.tokens) <= 0: return ParseError.EOF
+ name = self.parseLiteralValue()
+ if name == None:
+ return self.parseIdentifier()
+ else: return name
+
+ assert False, "UNREACHABLE!"
+
+ def parseLiteralValue(self) -> AstNode | ParseError | None:
+ if len(self.tokens) <= 0: return ParseError.EOF
+ t = self.tokens.pop(0)
+ if t.typ == TokenType.INT:
+ return AstNodeInt(t, int(t.lexeme))
+ elif t.typ == TokenType.FLOAT:
+ return AstNodeFloat(t, float(t.lexeme))
+ elif t.typ == TokenType.STRING:
+ return AstNodeString(t, t.lexeme)
+
+ return None
+
+ def parseBinOp(self) -> AstNode | ParseError | None:
+ arithmeticOp = self.parseArithmeticOp()
+ if arithmeticOp == None:
+ comparisionOp = self.parseComparisionOp()
+ if comparionOp == None:
+ return self.parseLogicalOp()
+ else:
+ return comparionOp
+ else: return arithmeticOp
+
+ assert False, "UNREACHABLE!"
+ def parseArithmeticOp(self) -> AstNode | ParseError | None:
+ if len(self.tokens) <= 0: return ParseError.EOF
+ t = self.tokens.pop(0)
+
+ if t.typ in [ TokenType.PLUS, TokenType.MINUS, TokenType.DIVIDE, TokenType.MODULUS ]:
+ return AstNodeArithmeticOp(t)
+
+ return None
def main():
program: str = sys.argv.pop(0)
# for t in tokens:
# pprint.pp(str(t))
- parser.parse()
+ print(parser.parse())
if __name__ == '__main__':
main()