Skip to content
Snippets Groups Projects
Commit ab5ca285 authored by li213's avatar li213
Browse files

working on integrating improved parser with verification

parent 1d524264
No related branches found
No related tags found
No related merge requests found
...@@ -42,35 +42,35 @@ class State: ...@@ -42,35 +42,35 @@ class State:
if __name__ == "__main__": if __name__ == "__main__":
input_code_name = './example_controller4.py' input_code_name = './demo/example_controller4.py'
scenario = Scenario() scenario = Scenario()
car = CarAgent('car1', file_name=input_code_name) car = CarAgent('car1', file_name=input_code_name)
scenario.add_agent(car) scenario.add_agent(car)
car = NPCAgent('car2') car = NPCAgent('car2')
scenario.add_agent(car) scenario.add_agent(car)
car = NPCAgent('car3') # car = NPCAgent('car3')
scenario.add_agent(car) # scenario.add_agent(car)
car = NPCAgent('car4') # car = NPCAgent('car4')
scenario.add_agent(car) # scenario.add_agent(car)
tmp_map = SimpleMap3() tmp_map = SimpleMap3()
scenario.set_map(tmp_map) scenario.set_map(tmp_map)
scenario.set_init( scenario.set_init(
[ [
[[0, -0.2, 0, 1.0],[0.01, 0.2, 0, 1.0]], [[0, -0.2, 0, 1.0],[0.01, 0.2, 0, 1.0]],
[[10, 0, 0, 0.5],[10, 0, 0, 0.5]], [[10, 0, 0, 0.5],[10, 0, 0, 0.5]],
[[20, 3, 0, 0.5],[20, 3, 0, 0.5]], # [[20, 3, 0, 0.5],[20, 3, 0, 0.5]],
[[30, 0, 0, 0.5],[30, 0, 0, 0.5]], # [[30, 0, 0, 0.5],[30, 0, 0, 0.5]],
], ],
[ [
(VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Vehicle), (VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Vehicle),
(VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Vehicle), (VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Vehicle),
(VehicleMode.Normal, LaneMode.Lane0, LaneObjectMode.Vehicle), # (VehicleMode.Normal, LaneMode.Lane0, LaneObjectMode.Vehicle),
(VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Vehicle), # (VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Vehicle),
] ]
) )
traces = scenario.simulate(70, 0.05) # traces = scenario.simulate(70, 0.05)
# traces = scenario.verify(70, 0.05) traces = scenario.verify(70, 0.05)
# fig = plt.figure(2) # fig = plt.figure(2)
# fig = plot_map(tmp_map, 'g', fig) # fig = plot_map(tmp_map, 'g', fig)
......
...@@ -18,6 +18,7 @@ class Verifier: ...@@ -18,6 +18,7 @@ class Verifier:
self, self,
init_list: List[float], init_list: List[float],
init_mode_list: List[str], init_mode_list: List[str],
static_list: List[str],
agent_list:List[BaseAgent], agent_list:List[BaseAgent],
transition_graph, transition_graph,
time_horizon, time_horizon,
...@@ -29,6 +30,8 @@ class Verifier: ...@@ -29,6 +30,8 @@ class Verifier:
root.init[agent.id] = init_list[i] root.init[agent.id] = init_list[i]
init_mode = [elem.name for elem in init_mode_list[i]] init_mode = [elem.name for elem in init_mode_list[i]]
root.mode[agent.id] = init_mode root.mode[agent.id] = init_mode
init_static = [elem.name for elem in static_list[i]]
root.static[agent.id] = init_static
root.agent[agent.id] = agent root.agent[agent.id] = agent
root.type = 'reachtube' root.type = 'reachtube'
self.reachtube_tree_root = root self.reachtube_tree_root = root
......
...@@ -13,6 +13,7 @@ from dryvr_plus_plus.scene_verifier.map.lane_map import LaneMap ...@@ -13,6 +13,7 @@ from dryvr_plus_plus.scene_verifier.map.lane_map import LaneMap
from dryvr_plus_plus.scene_verifier.map.lane_segment import AbstractLane from dryvr_plus_plus.scene_verifier.map.lane_segment import AbstractLane
from dryvr_plus_plus.scene_verifier.utils.utils import * from dryvr_plus_plus.scene_verifier.utils.utils import *
from dryvr_plus_plus.scene_verifier.code_parser.parser import Reduction, ReductionType from dryvr_plus_plus.scene_verifier.code_parser.parser import Reduction, ReductionType
from dryvr_plus_plus.scene_verifier.agents.base_agent import BaseAgent
class LogicTreeNode: class LogicTreeNode:
def __init__(self, data, child = [], val = None, mode_guard = None): def __init__(self, data, child = [], val = None, mode_guard = None):
...@@ -108,7 +109,7 @@ class GuardExpressionAst: ...@@ -108,7 +109,7 @@ class GuardExpressionAst:
# Therefore we are not be able to get free symbols from it # Therefore we are not be able to get free symbols from it
# Thus we need to replace "==" to something else # Thus we need to replace "==" to something else
symbols_map = {v: k for k, v in self.cont_variables.items()} symbols_map = {v: k for k, v in self.cont_variables.items() if k in guard_str}
for vars in reversed(self.cont_variables): for vars in reversed(self.cont_variables):
guard_str = guard_str.replace(vars, self.cont_variables[vars]) guard_str = guard_str.replace(vars, self.cont_variables[vars])
...@@ -502,7 +503,7 @@ class GuardExpressionAst: ...@@ -502,7 +503,7 @@ class GuardExpressionAst:
res = res and tmp res = res and tmp
return res return res
def _evaluate_guard_disc(self, root, agent, disc_var_dict, cont_var_dict, lane_map): def _evaluate_guard_disc(self, root, agent:BaseAgent, disc_var_dict, cont_var_dict, lane_map):
""" """
Recursively called function to evaluate guard with only discrete variables Recursively called function to evaluate guard with only discrete variables
The function will evaluate all guards with discrete variables and replace the nodes with discrete guards by The function will evaluate all guards with discrete variables and replace the nodes with discrete guards by
...@@ -578,9 +579,9 @@ class GuardExpressionAst: ...@@ -578,9 +579,9 @@ class GuardExpressionAst:
root = ast.parse('False').body[0].value root = ast.parse('False').body[0].value
else: else:
# TODO-PARSER: Handle This # TODO-PARSER: Handle This
for mode_name in agent.controller.modes: for mode_name in agent.controller.mode_defs:
# TODO-PARSER: Handle This # TODO-PARSER: Handle This
if res in agent.controller.modes[mode_name]: if res in agent.controller.mode_defs[mode_name].modes:
res = mode_name+'.'+res res = mode_name+'.'+res
break break
root = ast.parse(str(res)).body[0].value root = ast.parse(str(res)).body[0].value
...@@ -593,14 +594,14 @@ class GuardExpressionAst: ...@@ -593,14 +594,14 @@ class GuardExpressionAst:
if expr in disc_var_dict: if expr in disc_var_dict:
val = disc_var_dict[expr] val = disc_var_dict[expr]
# TODO-PARSER: Handle This # TODO-PARSER: Handle This
for mode_name in agent.controller.modes: for mode_name in agent.controller.mode_defs:
# TODO-PARSER: Handle This # TODO-PARSER: Handle This
if val in agent.controller.modes[mode_name]: if val in agent.controller.mode_defs[mode_name].modes:
val = mode_name+'.'+val val = mode_name+'.'+val
break break
return val, root return val, root
# TODO-PARSER: Handle This # TODO-PARSER: Handle This
elif root.value.id in agent.controller.modes: elif root.value.id in agent.controller.mode_defs:
return expr, root return expr, root
else: else:
return True, root return True, root
...@@ -622,9 +623,9 @@ class GuardExpressionAst: ...@@ -622,9 +623,9 @@ class GuardExpressionAst:
if expr in disc_var_dict: if expr in disc_var_dict:
val = disc_var_dict[expr] val = disc_var_dict[expr]
# TODO-PARSER: Handle This # TODO-PARSER: Handle This
for mode_name in agent.controller.modes: for mode_name in agent.controller.mode_defs:
# TODO-PARSER: Handle This # TODO-PARSER: Handle This
if val in agent.controller.modes[mode_name]: if val in agent.controller.mode_defs[mode_name].modes:
val = mode_name + '.' + val val = mode_name + '.' + val
break break
return val, root return val, root
......
...@@ -2,9 +2,9 @@ from typing import Tuple, List, Dict, Any ...@@ -2,9 +2,9 @@ from typing import Tuple, List, Dict, Any
import copy import copy
import itertools import itertools
import warnings import warnings
from collections import defaultdict
import numpy as np import numpy as np
from sympy import Q
from dryvr_plus_plus.scene_verifier.agents.base_agent import BaseAgent from dryvr_plus_plus.scene_verifier.agents.base_agent import BaseAgent
from dryvr_plus_plus.scene_verifier.automaton.guard import GuardExpressionAst from dryvr_plus_plus.scene_verifier.automaton.guard import GuardExpressionAst
...@@ -90,6 +90,7 @@ class Scenario: ...@@ -90,6 +90,7 @@ class Scenario:
def verify(self, time_horizon, time_step): def verify(self, time_horizon, time_step):
init_list = [] init_list = []
init_mode_list = [] init_mode_list = []
static_list = []
agent_list = [] agent_list = []
for agent_id in self.agent_dict: for agent_id in self.agent_dict:
init = self.init_dict[agent_id] init = self.init_dict[agent_id]
...@@ -98,14 +99,36 @@ class Scenario: ...@@ -98,14 +99,36 @@ class Scenario:
init = [init, init] init = [init, init]
init_list.append(init) init_list.append(init)
init_mode_list.append(self.init_mode_dict[agent_id]) init_mode_list.append(self.init_mode_dict[agent_id])
static_list.append(self.static_dict[agent_id])
agent_list.append(self.agent_dict[agent_id]) agent_list.append(self.agent_dict[agent_id])
return self.verifier.compute_full_reachtube(init_list, init_mode_list, agent_list, self, time_horizon, time_step, self.map) return self.verifier.compute_full_reachtube(init_list, init_mode_list, static_list, agent_list, self, time_horizon, time_step, self.map)
def apply_reset(self, agent: BaseAgent, reset_list, all_agent_state) -> Tuple[str, np.ndarray]:
# reset_expr = ResetExpression(reset_list)
# continuous_variable_dict, discrete_variable_dict, _ = self.sensor.sense(self, agent, all_agent_state, self.map)
# dest = reset_expr.get_dest(agent, all_agent_state[agent.id], discrete_variable_dict, self.map)
# rect = reset_expr.apply_reset_continuous(agent, continuous_variable_dict, self.map)
# return dest, rect
dest = []
rect = []
agent_state, agent_mode, agent_static = all_agent_state[agent.id]
def apply_reset(self, agent, reset_list, all_agent_state) -> Tuple[str, np.ndarray]: # First get the transition destinations
reset_expr = ResetExpression(reset_list) dest = copy.deepcopy(agent_mode)
continuous_variable_dict, discrete_variable_dict, _ = self.sensor.sense(self, agent, all_agent_state, self.map) possible_dest = [[elem] for elem in dest]
dest = reset_expr.get_dest(agent, all_agent_state[agent.id], discrete_variable_dict, self.map) ego_type = agent.controller.ego_type
rect = reset_expr.apply_reset_continuous(agent, continuous_variable_dict, self.map) for reset in reset_list:
reset_variable = reset.var
expr = reset.expr
if "mode" in reset_variable:
for var_loc, discrete_variable_ego in enumerate(agent.controller.state_defs[ego_type].disc):
if discrete_variable_ego == reset_variable:
break
if 'map' in expr:
for var in discrete_variable_dict
# Then get the transition updated rect
return dest, rect return dest, rect
def apply_cont_var_updater(self,cont_var_dict, updater): def apply_cont_var_updater(self,cont_var_dict, updater):
...@@ -131,11 +154,11 @@ class Scenario: ...@@ -131,11 +154,11 @@ class Scenario:
agent_mode = node.mode[agent_id] agent_mode = node.mode[agent_id]
# TODO-PARSER: update how we get all next modes # TODO-PARSER: update how we get all next modes
# The getNextModes function will return # The getNextModes function will return
paths = agent.controller.getNextModes()
state_dict = {} state_dict = {}
for tmp in node.agent: for tmp in node.agent:
state_dict[tmp] = (node.trace[tmp][0], node.mode[tmp], node.static[tmp]) state_dict[tmp] = (node.trace[tmp][0], node.mode[tmp], node.static[tmp])
cont_var_dict_template, discrete_variable_dict, len_dict = self.sensor.sense(self, agent, state_dict, self.map) cont_var_dict_template, discrete_variable_dict, len_dict = self.sensor.sense(self, agent, state_dict, self.map)
paths = agent.controller.getNextModes()
for path in paths: for path in paths:
guard_list = path[0] guard_list = path[0]
reset = path[1] reset = path[1]
...@@ -251,16 +274,11 @@ class Scenario: ...@@ -251,16 +274,11 @@ class Scenario:
cont_var_dict_template, discrete_variable_dict, length_dict = self.sensor.sense(self, agent, state_dict, self.map) cont_var_dict_template, discrete_variable_dict, length_dict = self.sensor.sense(self, agent, state_dict, self.map)
# TODO-PARSER: Get equivalent for this function # TODO-PARSER: Get equivalent for this function
paths = agent.controller.getNextModes(agent_mode) paths = agent.controller.getNextModes()
for path in paths: for path in paths:
# Construct the guard expression # Construct the guard expression
guard_list = [] guard_list = path[0]
reset_list = [] reset = path[1]
for item in path:
if isinstance(item, Guard):
guard_list.append(item)
elif isinstance(item, Reset):
reset_list.append(item)
guard_expression = GuardExpressionAst(guard_list) guard_expression = GuardExpressionAst(guard_list)
cont_var_updater = guard_expression.parse_any_all_new(cont_var_dict_template, discrete_variable_dict, length_dict) cont_var_updater = guard_expression.parse_any_all_new(cont_var_dict_template, discrete_variable_dict, length_dict)
...@@ -269,9 +287,9 @@ class Scenario: ...@@ -269,9 +287,9 @@ class Scenario:
if not guard_can_satisfied: if not guard_can_satisfied:
continue continue
if agent_id not in agent_guard_dict: if agent_id not in agent_guard_dict:
agent_guard_dict[agent_id] = [(guard_expression, cont_var_updater, copy.deepcopy(discrete_variable_dict), reset_list)] agent_guard_dict[agent_id] = [(guard_expression, cont_var_updater, copy.deepcopy(discrete_variable_dict), reset)]
else: else:
agent_guard_dict[agent_id].append((guard_expression, cont_var_updater, copy.deepcopy(discrete_variable_dict), reset_list)) agent_guard_dict[agent_id].append((guard_expression, cont_var_updater, copy.deepcopy(discrete_variable_dict), reset))
trace_length = int(len(list(node.trace.values())[0])/2) trace_length = int(len(list(node.trace.values())[0])/2)
guard_hits = [] guard_hits = []
...@@ -285,10 +303,11 @@ class Scenario: ...@@ -285,10 +303,11 @@ class Scenario:
for agent_id in agent_guard_dict: for agent_id in agent_guard_dict:
agent:BaseAgent = self.agent_dict[agent_id] agent:BaseAgent = self.agent_dict[agent_id]
agent_state, agent_mode = state_dict[agent_id] agent_state, agent_mode, agent_static = state_dict[agent_id]
agent_state = agent_state[1:] agent_state = agent_state[1:]
continuous_variable_dict, _, _ = self.sensor.sense(self, agent, state_dict, self.map) continuous_variable_dict, _, _ = self.sensor.sense(self, agent, state_dict, self.map)
for guard_expression, continuous_variable_updater, discrete_variable_dict, reset_list in agent_guard_dict[agent_id]: resets = defaultdict(list)
for guard_expression, continuous_variable_updater, discrete_variable_dict, reset in agent_guard_dict[agent_id]:
new_cont_var_dict = copy.deepcopy(continuous_variable_dict) new_cont_var_dict = copy.deepcopy(continuous_variable_dict)
one_step_guard:GuardExpressionAst = copy.deepcopy(guard_expression) one_step_guard:GuardExpressionAst = copy.deepcopy(guard_expression)
...@@ -298,8 +317,17 @@ class Scenario: ...@@ -298,8 +317,17 @@ class Scenario:
continue continue
guard_satisfied, is_contained = one_step_guard.evaluate_guard_cont(agent, new_cont_var_dict, self.map) guard_satisfied, is_contained = one_step_guard.evaluate_guard_cont(agent, new_cont_var_dict, self.map)
any_contained = any_contained or is_contained any_contained = any_contained or is_contained
# TODO: Can we also store the cont and disc var dict so we don't have to call sensor again?
if guard_satisfied: if guard_satisfied:
hits.append((agent_id, guard_list, reset_list)) reset_expr = ResetExpression(reset)
resets[reset_expr.var].append(reset_expr)
# Perform combination over all possible resets to generate all possible real resets
combined_reset_list = list(itertools.product(*resets.values()))
if len(combined_reset_list)==1 and combined_reset_list[0]==():
continue
for i in range(len(combined_reset_list)):
# a list of reset expression
hits.append((agent_id, guard_expression, combined_reset_list[i]))
if hits != []: if hits != []:
guard_hits.append((hits, state_dict, idx)) guard_hits.append((hits, state_dict, idx))
guard_hit_bool = True guard_hit_bool = True
...@@ -312,6 +340,7 @@ class Scenario: ...@@ -312,6 +340,7 @@ class Scenario:
reset_idx_dict = {} reset_idx_dict = {}
for hits, all_agent_state, hit_idx in guard_hits: for hits, all_agent_state, hit_idx in guard_hits:
for agent_id, guard_list, reset_list in hits: for agent_id, guard_list, reset_list in hits:
# TODO: Need to change this function to handle the new reset expression and then I am done
dest_list,reset_rect = self.apply_reset(node.agent[agent_id], reset_list, all_agent_state) dest_list,reset_rect = self.apply_reset(node.agent[agent_id], reset_list, all_agent_state)
if agent_id not in reset_dict: if agent_id not in reset_dict:
reset_dict[agent_id] = {} reset_dict[agent_id] = {}
......
numpy~=1.22.1 numpy~=1.22.1
scipy~=1.8.0 scipy~=1.6.1
matplotlib~=3.4.2 matplotlib~=3.4.2
polytope~=0.2.3 polytope~=0.2.3
pyvista~=0.32.1 pyvista~=0.32.1
......
...@@ -13,7 +13,7 @@ setup( ...@@ -13,7 +13,7 @@ setup(
python_requires='>=3.8', python_requires='>=3.8',
install_requires=[ install_requires=[
"numpy~=1.22.1", "numpy~=1.22.1",
"scipy~=1.8.0", "scipy~=1.6.1",
"matplotlib~=3.4.2", "matplotlib~=3.4.2",
"polytope~=0.2.3", "polytope~=0.2.3",
"pyvista~=0.32.1", "pyvista~=0.32.1",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment