Skip to content
Snippets Groups Projects
Commit 3c12dcb3 authored by crides's avatar crides
Browse files

perf(guard): pre-compile guard & assert ASTs before evaling

parent 2da8e420
No related branches found
No related tags found
2 merge requests!7Keyi tmp,!1Merge refactor to main
...@@ -922,14 +922,12 @@ class GuardExpressionAst: ...@@ -922,14 +922,12 @@ class GuardExpressionAst:
cont_var_updater = {} cont_var_updater = {}
for i in range(len(self.ast_list)): for i in range(len(self.ast_list)):
root = self.ast_list[i] root = self.ast_list[i]
j = 0 nodes = ast.walk(root)
while j < sum(1 for _ in ast.walk(root)): for node in nodes:
# TODO: Find a faster way to access nodes in the tree # TODO: Find a faster way to access nodes in the tree
node = list(ast.walk(root))[j]
if isinstance(node, Reduction): if isinstance(node, Reduction):
new_node = self.unroll_any_all_new(node, cont_var_dict, disc_var_dict, len_dict, cont_var_updater) new_node = self.unroll_any_all_new(node, cont_var_dict, disc_var_dict, len_dict, cont_var_updater)
root = NodeSubstituter(node, new_node).visit(root) root = NodeSubstituter(node, new_node).visit(root)
j += 1
self.ast_list[i] = root self.ast_list[i] = root
return cont_var_updater return cont_var_updater
......
...@@ -21,6 +21,17 @@ from dryvr_plus_plus.scene_verifier.map.lane_map import LaneMap ...@@ -21,6 +21,17 @@ from dryvr_plus_plus.scene_verifier.map.lane_map import LaneMap
EGO, OTHERS = "ego", "others" EGO, OTHERS = "ego", "others"
def merge_conds(c):
if len(c) == 0:
return ast.Constant(True)
if len(c) == 1:
return c[0]
else:
return ast.BoolOp(ast.And(), c)
def compile_expr(e):
return compile(ast.fix_missing_locations(ast.Expression(e)), "", "eval")
class Scenario: class Scenario:
def __init__(self): def __init__(self):
self.agent_dict = {} self.agent_dict = {}
...@@ -219,6 +230,7 @@ class Scenario: ...@@ -219,6 +230,7 @@ class Scenario:
# For each agent # For each agent
agent_guard_dict = defaultdict(list) agent_guard_dict = defaultdict(list)
compiled_asserts = {}
for agent_id in node.agent: for agent_id in node.agent:
# Get guard # Get guard
...@@ -226,6 +238,12 @@ class Scenario: ...@@ -226,6 +238,12 @@ class Scenario:
agent_mode = node.mode[agent_id] agent_mode = node.mode[agent_id]
if agent.controller.controller == None: if agent.controller.controller == None:
continue continue
def compile_assert(i, a):
p = compile_expr(merge_conds(a.pre))
cond = compile_expr(a.cond)
label = a.label if a.label != None else f"<assert {i}>"
return p, cond, label
compiled_asserts[agent_id] = [compile_assert(i, a) for i, a in enumerate(agent.controller.controller.asserts)]
# TODO-PARSER: update how we get all next modes # TODO-PARSER: update how we get all next modes
# The getNextModes function will return # The getNextModes function will return
state_dict = {} state_dict = {}
...@@ -239,7 +257,8 @@ class Scenario: ...@@ -239,7 +257,8 @@ class Scenario:
# copy.deepcopy(guard_expression.ast_list[0].operand) # copy.deepcopy(guard_expression.ast_list[0].operand)
# can_satisfy = guard_expression.fast_pre_process(discrete_variable_dict) # can_satisfy = guard_expression.fast_pre_process(discrete_variable_dict)
continuous_variable_updater = guard_expression.parse_any_all_new(cont_var_dict_template, discrete_variable_dict, len_dict) continuous_variable_updater = guard_expression.parse_any_all_new(cont_var_dict_template, discrete_variable_dict, len_dict)
agent_guard_dict[agent_id].append((guard_expression, continuous_variable_updater, copy.deepcopy(discrete_variable_dict), reset)) guard_comp = compile_expr(merge_conds(guard_expression.ast_list))
agent_guard_dict[agent_id].append((guard_comp, continuous_variable_updater, discrete_variable_dict, reset))
transitions = defaultdict(list) transitions = defaultdict(list)
# TODO: We can probably rewrite how guard hit are detected and resets are handled for simulation # TODO: We can probably rewrite how guard hit are detected and resets are handled for simulation
...@@ -286,12 +305,10 @@ class Scenario: ...@@ -286,12 +305,10 @@ class Scenario:
return eval(compile(ast.fix_missing_locations(ast.Expression(expr)), "", "eval"), env) return eval(compile(ast.fix_missing_locations(ast.Expression(expr)), "", "eval"), env)
# Check safety conditions # Check safety conditions
for i, a in enumerate(agent.controller.controller.asserts): for pre, cond, label in compiled_asserts[agent_id]:
pre_sat = all(eval_expr(p, packed_env) for p in a.pre) if eval(pre, packed_env):
if pre_sat: del packed_env["__builtins__"]
cond_sat = eval_expr(a.cond, packed_env) if not eval(cond, packed_env):
if not cond_sat:
label = a.label if a.label != None else f"<assert {i}>"
del packed_env["__builtins__"] del packed_env["__builtins__"]
print(f"assert hit for {agent_id}: \"{label}\" @ {packed_env}") print(f"assert hit for {agent_id}: \"{label}\" @ {packed_env}")
asserts[agent_id].append(label) asserts[agent_id].append(label)
...@@ -299,18 +316,11 @@ class Scenario: ...@@ -299,18 +316,11 @@ class Scenario:
continue continue
all_resets = defaultdict(list) all_resets = defaultdict(list)
for guard_expression, continuous_variable_updater, discrete_variable_dict, reset in agent_guard_dict[agent_id]: for guard_comp, continuous_variable_updater, discrete_variable_dict, reset in agent_guard_dict[agent_id]:
new_cont_var_dict = copy.deepcopy(continuous_variable_dict) new_cont_var_dict = copy.deepcopy(continuous_variable_dict)
one_step_guard = guard_expression.ast_list
self.apply_cont_var_updater(new_cont_var_dict, continuous_variable_updater) self.apply_cont_var_updater(new_cont_var_dict, continuous_variable_updater)
env = pack_env(agent, new_cont_var_dict, discrete_variable_dict, self.map) env = pack_env(agent, new_cont_var_dict, discrete_variable_dict, self.map)
if len(one_step_guard) == 0: guard_satisfied = eval(guard_comp, env)
raise ValueError("empty guard")
if len(one_step_guard) == 1:
one_step_guard = one_step_guard[0]
elif len(one_step_guard) > 1:
one_step_guard = ast.BoolOp(ast.And(), one_step_guard)
guard_satisfied = eval_expr(one_step_guard, env)
# Collect all the hit guards for this agent at this time step # Collect all the hit guards for this agent at this time step
if guard_satisfied: if guard_satisfied:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment