Skip to content
Snippets Groups Projects
Commit 0f9cfed7 authored by crides's avatar crides
Browse files

parse: fix if handling

parent 6b99b0d6
No related branches found
No related tags found
No related merge requests found
......@@ -3,8 +3,30 @@ from typing import List, Dict, Union, Optional, TypeAlias, Any
from dataclasses import dataclass
from enum import Enum, auto
class Argument():
pass
def dbg(msg, *rest):
print(f"\x1b[31m{msg}\x1b[m", end="")
for i, a in enumerate(rest[:5]):
print(f" \x1b[3{i+2}m{a}\x1b[m", end="")
if rest[5:]:
print("", rest[5:])
else:
print()
ScopeValue: TypeAlias = Union[ast.AST, "CondVal", "Lambda", Dict[str, "ScopeValue"]] # TODO
@dataclass
class CondValElem:
cond: List[ScopeValue]
val: ScopeValue
def __eq__(self, o) -> bool:
if o == None or len(self.cond) != len(o.cond):
return False
return all(ir_eq(sc, oc) for sc, oc in zip(self.cond, o.cond)) and ir_eq(self.val, o.val)
@dataclass
class CondVal:
elems: List[CondValElem]
class ReductionType(Enum):
Any = auto()
......@@ -30,14 +52,10 @@ class Reduction:
it: str
value: ast.AST
class If:
test: ast.expr
true: ast.expr
false: Optional[ast.expr]
def __init__(self, test, true, false=None):
self.test = test
self.true = true
self.false = false
def __eq__(self, o) -> bool:
if o == None:
return False
return self.op == o.op and self.it == o.it and ir_eq(self.expr, o.expr) and ir_eq(self.value, o.value)
@dataclass
class Lambda:
......@@ -54,35 +72,46 @@ class Lambda:
for a in args:
scope.set(a, ast.arg(a))
ret = None
for node in tree.body:
ret = proc(node, scope)
if isinstance(tree, ast.FunctionDef):
for node in tree.body:
ret = proc(node, scope)
else:
ret = proc(tree.body, scope)
scope.pop()
return Lambda(args, ret)
def apply(self, args: List[ast.expr]) -> ast.expr:
ret = copy.deepcopy(self.body)
args = {k: v for k, v in zip(self.args, args)}
return ArgSubstituter(args).visit(ret)
return ArgSubstituter({k: v for k, v in zip(self.args, args)}).visit(ret)
ast_dump = lambda node, dump=False: ast.dump(node, indent=2) if dump else ast.unparse(node)
def ir_dump(node, dump=False):
if node == None:
return "None"
if isinstance(node, Lambda):
return f"<Lambda args: {node.args} body: {ir_dump(node.body, dump)}>"
if isinstance(node, CondVal):
return f"<CondVal{''.join(f' [{ir_dump(e.val, dump)} if {ir_dump(e.cond, dump)}]' for e in node.elems)}>"
if isinstance(node, ast.If):
return f"<{{{ast_dump(node, dump)}}}>"
if isinstance(node, Reduction):
return f"<Reduction {node.op}({ast_dump(node.expr, dump)} for {node.it} in {ast_dump(node.value, dump)}>"
# if isinstance(node, If):
# if node.false == None:
# return f"<If test: {node.test} true: {ir_dump(node.true)}>"
# return f"<If test: {node.test} true: {ir_dump(node.true)} false: {ir_dump(node.false)}>"
return f"<Reduction {node.op} {ast_dump(node.expr, dump)} for {node.it} in {ast_dump(node.value, dump)}>"
elif isinstance(node, dict):
return "<Object " + " ".join(f"{k}: {ir_dump(v, dump)}" for k, v in node.items()) + ">"
elif isinstance(node, list):
return f"[{', '.join(ir_dump(n, dump) for n in node)}]"
else:
return ast_dump(node, dump)
ScopeValue: TypeAlias = Union[ast.AST, If, Lambda, Dict[str, "ScopeValue"]] # TODO
def ir_eq(a: Optional[ScopeValue], b: Optional[ScopeValue]) -> bool:
return ir_dump(a) == ir_dump(b)
def sv_eq(a: Optional[ScopeValue], b: Optional[ScopeValue]) -> bool:
if isinstance(a, ast.AST) and isinstance(b, ast.AST):
return ir_eq(a, b)
return a == b
class Scope:
scopes: List[Dict[str, ScopeValue]]
def __init__(self):
......@@ -115,7 +144,7 @@ class Scope:
class ArgSubstituter(ast.NodeTransformer):
args: Dict[str, ast.expr]
def __init__(self, args):
def __init__(self, args: Dict[str, ast.expr]):
super().__init__()
self.args = args
......@@ -125,25 +154,93 @@ class ArgSubstituter(ast.NodeTransformer):
self.generic_visit(node)
return node # XXX needed?
def merge_if(test, true, false, scope: Dict[str, ScopeValue]):
def merge_if(test: ast.expr, trues: Scope, falses: Scope, scope: Scope):
# `true`, `false` and `scope` should have the same level
for true, false in zip(trues.scopes, falses.scopes):
merge_if_single(test, true, false, scope)
def merge_if_single(test, true: Dict[str, ScopeValue], false: Dict[str, ScopeValue], scope: Union[Scope, Dict[str, ScopeValue]]):
dbg("merge if single", ir_dump(test), true.keys(), false.keys())
def lookup(s, k):
if isinstance(s, Scope):
return s.lookup(k)
return s.get(k)
def assign(s, k, v):
if isinstance(s, Scope):
s.set(k, v)
else:
s[k] = v
for var in set(true.keys()).union(set(false.keys())):
if true.get(var) != None and false.get(var) != None:
assert isinstance(true.get(var), dict) == isinstance(false.get(var), dict)
if isinstance(true.get(var), dict):
if not isinstance(scope.get(var), dict):
if var in scope:
print("???", var, scope[var])
scope[var] = {}
merge_if(test, true.get(var, {}), false.get(var, {}), scope[var])
var_true, var_false = true.get(var), false.get(var)
if sv_eq(var_true, var_false):
continue
if var_true != None and var_false != None:
assert isinstance(var_true, dict) == isinstance(var_false, dict)
dbg("merge", var, ir_dump(test), ir_dump(var_true), ir_dump(var_false))
if isinstance(var_true, dict):
if not isinstance(lookup(scope, var), dict):
if lookup(scope, var) != None:
dbg("???", var, lookup(scope, var))
dbg("if.merge.obj.init")
assign(scope, var, {})
var_true_emp, var_false_emp, var_scope = true.get(var, {}), false.get(var, {}), lookup(scope, var)
assert isinstance(var_true_emp, dict) and isinstance(var_false_emp, dict) and isinstance(var_scope, dict)
merge_if_single(test, var_true_emp, var_false_emp, var_scope)
else:
if true.get(var) == None:
scope[var] = ast.If(ast.UnaryOp(ast.Not(), test), [false.get(var)], [])
elif false.get(var) == None:
scope[var] = ast.If(test, [true.get(var)], [])
else:
scope[var] = ast.If(test, [true.get(var)], [false.get(var)])
if_val = merge_if_val(test, var_true, var_false, lookup(scope, var))
print(ir_dump(if_val))
assign(scope, var, if_val)
def merge_if_val(test, true: Optional[ScopeValue], false: Optional[ScopeValue], orig: Optional[ScopeValue]) -> CondVal:
dbg("merge val", ir_dump(test), ir_dump(true), ir_dump(false), ir_dump(orig), false == orig)
def merge_cond(test, val):
if isinstance(val, CondVal):
for elem in val.elems:
elem.cond.append(test)
return val
else:
return CondVal([CondValElem([test], val)])
def as_cv(a):
if a == None:
return None
if not isinstance(a, CondVal):
return CondVal([CondValElem([], a)])
return a
true, false, orig = as_cv(true), as_cv(false), as_cv(orig)
dbg("merge convert", ir_dump(true), ir_dump(false), ir_dump(orig))
if orig != None:
for orig_cve in orig.elems:
if true != None and orig_cve in true.elems:
true.elems.remove(orig_cve)
if false != None and orig_cve in false.elems:
false.elems.remove(orig_cve)
dbg("merge diff", ir_dump(true), ir_dump(false), ir_dump(orig))
if true != None and len(true.elems) == 0:
true = None
if false != None and len(false.elems) == 0:
false = None
if true == None and false == None:
raise Exception("no need for merge?")
elif true == None:
ret = merge_cond(ast.UnaryOp(ast.Not(), test), false)
elif false == None:
ret = merge_cond(test, true)
elif sv_eq(false, orig):
if isinstance(orig, CondVal):
ret = CondVal(merge_cond(test, true).elems + orig.elems)
else:
assert orig != None
ret = CondVal(merge_cond(test, true).elems + [CondValElem([], orig)])
else:
merge_true, merge_false = merge_cond(test, true), merge_cond(ast.UnaryOp(ast.Not(), test), false)
ret = CondVal(merge_true.elems + merge_false.elems)
if orig != None:
return CondVal(ret.elems + orig.elems)
return ret
def proc_assign(target: ast.AST, val, scope: Scope):
dbg("proc_assign", ast.unparse(target), val)
if isinstance(target, ast.Name):
if isinstance(val, ast.AST):
scope.set(target.id, proc(val, scope))
......@@ -151,6 +248,7 @@ def proc_assign(target: ast.AST, val, scope: Scope):
scope.set(target.id, val)
elif isinstance(target, ast.Attribute):
if proc(target.value, scope) == None:
dbg("proc.assign.obj.init")
proc_assign(target.value, {}, scope)
obj = proc(target.value, scope)
obj[target.attr] = val
......@@ -169,16 +267,21 @@ def proc(node: ast.AST, scope: Scope) -> Any:
elif isinstance(node, ast.If):
test = proc(node.test, scope)
true_scope = copy.deepcopy(scope)
true_scope.push()
for true in node.body:
dbg("true", true)
proc(true, true_scope)
false_scope = copy.deepcopy(scope)
false_scope.push()
for false in node.orelse:
proc(false, false_scope)
merge_if(test, true_scope.scopes[0], false_scope.scopes[0], scope.scopes[0])
merge_if(test, true_scope, false_scope, scope)
# Definition/Assignment
elif isinstance(node, ast.Import) or isinstance(node, ast.ImportFrom):
for alias in node.names:
if alias.asname == None:
scope.set(alias.name, ast.arg(alias.name))
else:
scope.set(alias.asname, ast.arg(alias.asname))
elif isinstance(node, ast.Assign):
if len(node.targets) == 1:
proc_assign(node.targets[0], node.value, scope)
......@@ -196,7 +299,7 @@ def proc(node: ast.AST, scope: Scope) -> Any:
elif isinstance(node, ast.Lambda):
return Lambda.from_ast(node, scope)
elif isinstance(node, ast.ClassDef):
pass
scope.set(node.name, ast.arg(node.name))
# Expressions
elif isinstance(node, ast.UnaryOp):
......@@ -243,8 +346,6 @@ def proc(node: ast.AST, scope: Scope) -> Any:
scope.pop()
expr = cond_trans(expr, ast.BoolOp(ast.And(), ifs))
return Reduction(op, expr, target.id, proc(iter, scope))
print(ast.dump(node))
print(proc(node.func.value, scope))
elif isinstance(node, ast.Return):
return proc(node.value, scope) if node.value != None else None
elif isinstance(node, ast.IfExp):
......@@ -266,8 +367,8 @@ def parse(fn: str):
root = ast.parse(cont, fn)
scope = Scope()
proc(root, scope)
scope.dump(True)
print(ir_dump(scope.lookup("controller").body["mode"].test))
scope.dump()
# print(ir_dump(scope.lookup("controller").body["mode"].test))
if __name__ == "__main__":
import sys
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment