]> www.git.momoyon.org Git - lang.git/commitdiff
WIP: TODO: using exceptions for ParseError...
authorahmedsamyh <ahmedsamyh10@gmail.com>
Wed, 20 Nov 2024 18:25:10 +0000 (23:25 +0500)
committerahmedsamyh <ahmedsamyh10@gmail.com>
Wed, 20 Nov 2024 18:25:10 +0000 (23:25 +0500)
- More type checking stuff.
- ignore tags file.

.gitignore
main.py

index a6e9d26f625b90a4913b0082813d8336d0d73aa4..ad5da304c02b6d5bc13e154448885e31809f6777 100644 (file)
@@ -1,2 +1,3 @@
 test.sh
 lang.sh
+tags
diff --git a/main.py b/main.py
index 9261a7337f860ea57b1f8af716b5c0178091d689..26939a9c49f68a2d216d95fab2fd7b7e6092a32c 100644 (file)
--- a/main.py
+++ b/main.py
@@ -2,7 +2,7 @@ import sys
 from enum import IntEnum, auto
 import pprint
 
-from typing import Tuple, List
+from typing import Tuple, List, cast
 
 DEBUG = True
 
@@ -335,10 +335,10 @@ class Lexer:
         return None
 
     def lex(self) -> List[Token]:
-        tokens: [Token] = []
-        token = self.next_token()
+        tokens: List[Token] = []
+        token: Token | None = self.next_token()
         while token != None:
-            tokens.append(token)
+            tokens.append(cast(Token, token))
             token = self.next_token()
         return tokens
 
@@ -353,7 +353,7 @@ class AstNodeType(IntEnum):
     BINARY_OP = auto()
     COUNT = auto()
 
-ast_node_type_as_str_map: {AstNodeType : str} = {
+ast_node_type_as_str_map: dict[AstNodeType, str] = {
     AstNodeType.EXPR    : "Expr",
     AstNodeType.STMT    : "Stmt",
     AstNodeType.INT     : "Int",
@@ -444,6 +444,20 @@ class ParseError(IntEnum):
     NAH = auto()
     COUNT = auto()
 
+parse_error_as_str_map: dict[ParseError, str] = {
+    ParseError.EOF : "Eof",
+    ParseError.UNEXPECTED_TOKEN : "Unexpected Token",
+    ParseError.NAH : "Nah",
+}
+assert len(parse_error_as_str_map) == ParseError.COUNT-1, "Every ParseError is not handled in parse_error_as_str_map"
+
+class ParseException(Exception):
+    def __init__(self, typ: ParseError):
+        self.typ = typ
+
+    def __repr__(self):
+        return f"Parse Exception: {parse_error_as_str_map[self.typ]}"
+
 class Parser:
     def __init__(self, tokens):
         self.tokens = tokens
@@ -453,12 +467,14 @@ class Parser:
 
     def parse(self) -> AstNode:
         stmt = self.parseStatement()
+        assert isinstance(stmt, AstNode)
         return stmt
 
-    def parseStatement(self) -> AstNode | ParseError | None:
+    def parseStatement(self) -> AstNode | None:
         tokens = self.tokens
         # Variable name
         var_name_ast = self.parseIdentifier()
+        # TODO: Check for var_name_ast
         var_type_ast = None
         # Check if colon is there
         if len(tokens) >= 1 and tokens[0].typ == TokenType.COLON:
@@ -475,16 +491,23 @@ class Parser:
         # if var_type_ast != None:
         #     dlog(var_type_ast)
 
-        equal = None
+        equal: Token | None = None
         if len(tokens) >= 1 and tokens[0].typ == TokenType.EQUAL:
             equal = tokens.pop(0)
 
+        if equal == None:
+            if len(tokens) <= 0: raise ParseException(ParseError.EOF)
+            semicolon = tokens.pop(0)
+            if semicolon.typ != TokenType.SEMICOLON:
+                fatal("We don't support Statements with more than one expressions yet!")
+
         expr = self.parseExpression()
 
         # dlog(f"EXPR: {expr}")
 
         if isinstance(expr, ParseError):
             if expr == ParseError.EOF:
+                assert isinstance(equal, Token)
                 self.syntax_error(f"Expected ; but reached EOF", equal)
 
         semicolon = tokens.pop(0)
@@ -494,14 +517,14 @@ class Parser:
         return AstNodeStatement(var_name_ast.token, var_name_ast, var_type_ast, expr)
 
 
-    def parseIdentifier(self) -> AstNode | ParseError | None:
-        if len(self.tokens) <= 0: return ParseError.EOF
-        if self.tokens[0].typ != TokenType.IDENT: return ParseError.UNEXPECTED_TOKEN
+    def parseIdentifier(self) -> AstNode | None:
+        if len(self.tokens) <= 0: raise ParseException(ParseError.EOF)
+        if self.tokens[0].typ != TokenType.IDENT: return ParseException(ParseError.UNEXPECTED_TOKEN)
         ident_token = self.tokens.pop(0)
         return AstNodeIdentifier(ident_token, ident_token.lexeme)
 
-    def parseExpression(self) -> AstNode | ParseError | None:
-        if len(self.tokens) <= 0: return ParseError.EOF
+    def parseExpression(self) -> AstNode | None:
+        if len(self.tokens) <= 0: return ParseException(ParseError.EOF)
         t = self.tokens[0]
 
         lhs = self.parseName()
@@ -530,15 +553,15 @@ class Parser:
 
         return current_node
 
-    def parseName(self) -> AstNode | ParseError | None:
-        if len(self.tokens) <= 0: return ParseError.EOF
+    def parseName(self) -> AstNode | None:
+        if len(self.tokens) <= 0: return ParseException(ParseError.EOF)
         name = self.parseLiteralValue()
         if isinstance(name, ParseError):
             return self.parseIdentifier()
         return name
 
-    def parseLiteralValue(self) -> AstNode | ParseError | None:
-        if len(self.tokens) <= 0: return ParseError.EOF
+    def parseLiteralValue(self) -> AstNode | None:
+        if len(self.tokens) <= 0: return ParseException(ParseError.EOF)
         t = self.tokens.pop(0)
         if t.typ == TokenType.INT:
             return AstNodeInt(t, int(t.lexeme))
@@ -549,7 +572,7 @@ class Parser:
 
         return None
 
-    def parseBinOp(self) -> AstNode | ParseError | None:
+    def parseBinOp(self) -> AstNode | None:
         # TODO: Check if Operator predecence is correct
 
         arithmeticOp = self.parseArithmeticOp()
@@ -568,8 +591,8 @@ class Parser:
 
         assert False, "UNREACHABLE!"
 
-    def parseComparisonOp(self) -> AstNode | ParseError | None:
-        if len(self.tokens) <= 0: return ParseError.EOF
+    def parseComparisonOp(self) -> AstNode | None:
+        if len(self.tokens) <= 0: return ParseException(ParseError.EOF)
         t = self.tokens.pop(0)
 
         if t.typ in [ TokenType.GT, TokenType.GTE, TokenType.LT, TokenType.LTE, TokenType.EQUAL_EQUAL, TokenType.NOT_EQUAL ]:
@@ -577,8 +600,8 @@ class Parser:
 
         return None
 
-    def parseLogicalOp(self) -> AstNode | ParseError | None:
-        if len(self.tokens) <= 0: return ParseError.EOF
+    def parseLogicalOp(self) -> AstNode | None:
+        if len(self.tokens) <= 0: return ParseException(ParseError.EOF)
         t = self.tokens.pop(0)
 
         if t.typ in [ TokenType.LOGICAL_AND, TokenType.LOGICAL_OR ]:
@@ -586,8 +609,8 @@ class Parser:
 
         return None
 
-    def parseArithmeticOp(self) -> AstNode | ParseError | None:
-        if len(self.tokens) <= 0: return ParseError.EOF
+    def parseArithmeticOp(self) -> AstNode | None:
+        if len(self.tokens) <= 0: return ParseException(ParseError.EOF)
         t = self.tokens.pop(0)
 
         if t.typ in [ TokenType.PLUS, TokenType.MINUS, TokenType.DIVIDE, TokenType.MODULUS ]:
@@ -595,8 +618,8 @@ class Parser:
 
         return None
 
-    def parseBinaryArithmeticOp(self) -> AstNode | ParseError | None:
-        if len(self.tokens) <= 0: return ParseError.EOF
+    def parseBinaryArithmeticOp(self) -> AstNode | None:
+        if len(self.tokens) <= 0: return ParseException(ParseError.EOF)
         t = self.tokens.pop(0)
 
         if t.typ in [ TokenType.BINARY_AND, TokenType.BINARY_OR, TokenType.BINARY_NOT ]:
@@ -621,10 +644,10 @@ def main():
     # TODO: Parse
     parser = Parser(tokens)
 
-    for t in tokens:
-        pprint.pp(str(t))
+    for t in tokens:
+        pprint.pp(str(t))
 
-    print(parser.parse())
+    print(parser.parse())
 
 if __name__ == '__main__':
     main()