From cc268dc0cfa57003f48dcd80c8b929916bcd89fc Mon Sep 17 00:00:00 2001 From: crides <zhuhaoqing@live.cn> Date: Fri, 26 Aug 2022 12:32:57 -0500 Subject: [PATCH] sim done? --- demo/vehicle/demo10.py | 71 ++++++----- demo/vehicle/demo11.py | 76 ++++++----- demo/vehicle/demo7.py | 38 ++++-- verse/analysis/analysis_tree.py | 2 +- verse/analysis/incremental.py | 14 +- verse/analysis/simulator.py | 104 +++++++++++---- verse/parser/parser.py | 5 +- verse/scenario/scenario.py | 220 ++++++++++++++++++++------------ 8 files changed, 341 insertions(+), 189 deletions(-) diff --git a/demo/vehicle/demo10.py b/demo/vehicle/demo10.py index 4cae82fc..9cbcf778 100644 --- a/demo/vehicle/demo10.py +++ b/demo/vehicle/demo10.py @@ -2,13 +2,12 @@ from verse.agents.example_agent import CarAgent, NPCAgent from verse.map.example_map import SimpleMap2 from verse import Scenario from verse.plotter.plotter2D import * -# from verse.plotter.plotter2D_old import plot_reachtube_tree, plot_map +from verse.plotter.plotter2D_old import plot_reachtube_tree, plot_map, plot_simulation_tree from noisy_sensor import NoisyVehicleSensor from enum import Enum, auto import plotly.graph_objects as go -import matplotlib.pyplot as plt - +import matplotlib.pyplot as plt class LaneObjectMode(Enum): Vehicle = auto() @@ -54,7 +53,7 @@ if __name__ == "__main__": scenario.add_agent(car) tmp_map = SimpleMap2() scenario.set_map(tmp_map) - scenario.set_sensor(NoisyVehicleSensor((1, 1), (0, 0))) + scenario.set_sensor(NoisyVehicleSensor((1,1), (0,0))) scenario.set_init( [ [[5, -0.1, 0, 1.0], [5.5, 0.1, 0, 1.1]], @@ -65,34 +64,38 @@ if __name__ == "__main__": (VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Vehicle), ] ) - scenario.init_seg_length = 5 - traces = scenario.verify(40, 0.05) - - # fig = plt.figure(2) - # fig = plot_reachtube_tree(traces.root, 'car1', 0, [1], 'b', fig) - # fig = plot_reachtube_tree(traces.root, 'car2', 0, [1], 'r', fig) - - scenario1 = Scenario() - car1 = CarAgent('car1', file_name=input_code_name) - scenario1.add_agent(car1) - car1 = NPCAgent('car2') - scenario1.add_agent(car1) - tmp_map1 = SimpleMap2() - scenario1.set_map(tmp_map1) - # scenario1.set_sensor(NoisyVehicleSensor((0,1), (0,0))) - scenario1.set_init( - [ - [[5, -0.1, 0, 1.0], [6, 0.1, 0, 1.0]], - [[20, 0, 0, 0.5], [20, 0, 0, 0.5]], - ], - [ - (VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Vehicle), - (VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Vehicle), - ] - ) - scenario1.init_seg_length = 5 - traces1 = scenario1.verify(40, 0.05, params={"bloating_method": 'GLOBAL'}) + scenario.config.init_seg_length = 5 + traces = scenario.simulate(40, 0.05) + + fig = plt.figure(2) + fig = plot_simulation_tree(traces.root, 'car1', 0, [1], 'b', fig) + fig = plot_simulation_tree(traces.root, 'car2', 0, [1], 'r', fig) + + # scenario1 = Scenario() + # car1 = CarAgent('car1', file_name=input_code_name) + # scenario1.add_agent(car1) + # car1 = NPCAgent('car2') + # scenario1.add_agent(car1) + # tmp_map1 = SimpleMap2() + # scenario1.set_map(tmp_map1) + # # scenario1.set_sensor(NoisyVehicleSensor((0,1), (0,0))) + # scenario1.set_init( + # [ + # [[5, -0.1, 0, 1.0], [6, 0.1, 0, 1.0]], + # [[20, 0, 0, 0.5], [20, 0, 0, 0.5]], + # ], + # [ + # (VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Vehicle), + # (VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Vehicle), + # ] + # ) + # scenario1.init_seg_length = 5 + # scenario.verify_method = 'GLOBAL' + # traces1 = scenario1.simulate(40, 0.05) + + # fig = plot_reachtube_tree(traces1.root, 'car1', 0, [1], 'g', fig) + # fig = plot_reachtube_tree(traces1.root, 'car2', 0, [1], 'r', fig) + + + plt.show() - fig = go.Figure() - fig = reachtube_tree(traces, tmp_map, fig, 1, 2, [1, 2], 'lines', 'trace') - fig.show() diff --git a/demo/vehicle/demo11.py b/demo/vehicle/demo11.py index b76a9889..b32003ec 100644 --- a/demo/vehicle/demo11.py +++ b/demo/vehicle/demo11.py @@ -69,39 +69,55 @@ if __name__ == "__main__": ) scenario.set_sensor(NoisyVehicleSensor((0.5,0.5), (0,0))) - scenario.init_seg_length = 5 - traces = scenario.verify(40, 0.1, params={"bloating_method":'GLOBAL'}) + import timeit + scenario.config.init_seg_length = 5 + time = timeit.default_timer() + traces = scenario.verify(40, 0.1) + print("run1", timeit.default_timer() - time) + print(scenario.verifier.cache.cache) fig = plt.figure(2) fig = plot_reachtube_tree(traces.root, 'car1', 1, [2], 'b', fig) fig = plot_reachtube_tree(traces.root, 'car2', 1, [2], 'r', fig) fig = plot_reachtube_tree(traces.root, 'car3', 1, [2], 'r', fig) fig = plot_map(tmp_map, 'g', fig) - - scenario1 = Scenario() - car = CarAgent('car1', file_name=input_code_name) - scenario1.add_agent(car) - car = NPCAgent('car2') - scenario1.add_agent(car) - car = NPCAgent('car3') - scenario1.add_agent(car) - tmp_map = SimpleMap3() - scenario1.set_map(tmp_map) - scenario1.set_init( - [ - [[5, -0.5, 0, 1.0], [5.5, 0.5, 0, 1.0]], - [[20, -0.2, 0, 0.5], [20, 0.2, 0, 0.5]], - [[4-2.5, 2.8, 0, 1.0], [4.5-2.5, 3.2, 0, 1.0]], - ], - [ - (VehicleMode.Normal, LaneMode.Lane1,), - (VehicleMode.Normal, LaneMode.Lane1,), - (VehicleMode.Normal, LaneMode.Lane0,), - ] - ) - - scenario1.init_seg_length = 5 - traces = scenario1.verify(40, 0.1, params={"bloating_method":'GLOBAL'}) - - fig = plot_reachtube_tree(traces.root, 'car1', 1, [2], 'g', fig) - plt.show() \ No newline at end of file + plt.show() + + # time = timeit.default_timer() + # traces = scenario.verify(40, 0.1) + # print("run2", timeit.default_timer() - time) + + # fig = plt.figure(2) + # fig = plot_reachtube_tree(traces.root, 'car1', 1, [2], 'b', fig) + # fig = plot_reachtube_tree(traces.root, 'car2', 1, [2], 'r', fig) + # fig = plot_reachtube_tree(traces.root, 'car3', 1, [2], 'r', fig) + # fig = plot_map(tmp_map, 'g', fig) + # plt.show() + + # scenario1 = Scenario() + # car = CarAgent('car1', file_name=input_code_name) + # scenario1.add_agent(car) + # car = NPCAgent('car2') + # scenario1.add_agent(car) + # car = NPCAgent('car3') + # scenario1.add_agent(car) + # tmp_map = SimpleMap3() + # scenario1.set_map(tmp_map) + # scenario1.set_init( + # [ + # [[5, -0.5, 0, 1.0], [5.5, 0.5, 0, 1.0]], + # [[20, -0.2, 0, 0.5], [20, 0.2, 0, 0.5]], + # [[4-2.5, 2.8, 0, 1.0], [4.5-2.5, 3.2, 0, 1.0]], + # ], + # [ + # (VehicleMode.Normal, LaneMode.Lane1,), + # (VehicleMode.Normal, LaneMode.Lane1,), + # (VehicleMode.Normal, LaneMode.Lane0,), + # ] + # ) + + # scenario1.init_seg_length = 5 + # traces = scenario1.verify(40, 0.1) + + # fig = plot_reachtube_tree(traces.root, 'car1', 1, [2], 'g', fig) + # plt.show() diff --git a/demo/vehicle/demo7.py b/demo/vehicle/demo7.py index 8f67d3ac..5e7cffc1 100644 --- a/demo/vehicle/demo7.py +++ b/demo/vehicle/demo7.py @@ -1,5 +1,6 @@ # SM: Noting some things about the example +import timeit from verse.agents.example_agent import CarAgent, NPCAgent from verse.map.example_map import SimpleMap4 from verse import Scenario @@ -8,7 +9,6 @@ from verse.plotter.plotter2D import * from enum import Enum, auto import plotly.graph_objects as go - class LaneObjectMode(Enum): Vehicle = auto() Ped = auto() # Pedestrians @@ -68,7 +68,7 @@ if __name__ == "__main__": scenario.set_map(tmp_map) scenario.set_init( [ - [[0, -0.0, 0, 1.0], [0.0, 0.0, 0, 1.0]], + [[0, 0.0, 0, 1.0], [0.0, 0.0, 0, 1.0]], [[10, 0, 0, 0.5], [10, 0, 0, 0.5]], [[14, 3, 0, 0.6], [14, 3, 0, 0.6]], [[20, 3, 0, 0.5], [20, 3, 0, 0.5]], @@ -99,12 +99,34 @@ if __name__ == "__main__": ], ) - traces = scenario.simulate(60, 0.1) + # time = timeit.default_timer() + # # import cProfile, pstats, io + # # from pstats import SortKey + # # pr = cProfile.Profile() + # # pr.enable() + # traces = scenario.verify(60, 0.1) + # # pr.disable() + # # s = io.StringIO() + # # sortby = SortKey.CUMULATIVE + # # ps = pstats.Stats(pr, stream=s).sort_stats(sortby) + # # ps.print_stats() + # # print(s.getvalue()) + # print("run1", timeit.default_timer() - time) + # fig = go.Figure() + # fig = reachtube_tree(traces, tmp_map, fig, 1, + # 2, 'lines', 'trace', print_dim_list=[1, 2]) + # fig.show() + + time = timeit.default_timer() + traces = scenario.simulate(60, 0.05) + print("run2", timeit.default_timer() - time) fig = go.Figure() - fig = simulation_tree(traces, tmp_map, fig, 1, 2, [1, 2], 'lines', 'trace') + fig = simulation_tree(traces, tmp_map, fig, 1, + 2, 'lines', 'trace', print_dim_list=[1, 2]) fig.show() - traces = scenario.verify(60, 0.1) - fig = go.Figure() - fig = reachtube_tree(traces, tmp_map, fig, 1, 2, [1, 2], 'lines', 'trace') - fig.show() + # traces = scenario.verify(60, 0.05) + # fig = go.Figure() + # fig = reachtube_tree(traces, tmp_map, fig, 1, + # 2, 'lines', 'trace', print_dim_list=[1, 2]) + # fig.show() diff --git a/verse/analysis/analysis_tree.py b/verse/analysis/analysis_tree.py index 792c5b4f..57b26bb0 100644 --- a/verse/analysis/analysis_tree.py +++ b/verse/analysis/analysis_tree.py @@ -121,4 +121,4 @@ class AnalysisTree: child_node = AnalysisTreeNode.from_dict(child_node_dict) parent_node.child.append(child_node) queue.append((child_node_dict, child_node)) - return AnalysisTree(root) \ No newline at end of file + return AnalysisTree(root) diff --git a/verse/analysis/incremental.py b/verse/analysis/incremental.py index 558e96d3..e4e6ab91 100644 --- a/verse/analysis/incremental.py +++ b/verse/analysis/incremental.py @@ -1,17 +1,18 @@ from collections import defaultdict from dataclasses import dataclass -from typing import DefaultDict, List, Tuple, Optional +from typing import DefaultDict, Dict, List, Tuple, Optional from verse.analysis import AnalysisTreeNode from intervaltree import IntervalTree from verse.analysis.dryvr import _EPSILON -from verse.parser.parser import ControllerIR +from verse.parser.parser import ControllerIR, ModePath @dataclass class CachedTransition: transition: int disc: List[str] cont: List[float] + paths: List[ModePath] @dataclass class CachedSegment: @@ -19,6 +20,8 @@ class CachedSegment: asserts: List[str] transitions: List[CachedTransition] controller: ControllerIR + run_num: int + node_id: int @dataclass class CachedTube: @@ -33,14 +36,15 @@ class SimTraceCache: def __init__(self): self.cache: DefaultDict[tuple, IntervalTree] = defaultdict(IntervalTree) - def add_segment(self, agent_id: str, node: AnalysisTreeNode): + def add_segment(self, agent_id: str, node: AnalysisTreeNode, trace: List[List[float]], transition_paths: List[List[ModePath]], run_num: int): + assert len(transition_paths) == len(node.child) key = (agent_id,) + tuple(node.mode[agent_id]) init = node.init[agent_id] tree = self.cache[key] for i, val in enumerate(init): if i == len(init) - 1: - transitions = [CachedTransition(len(n.trace[agent_id]), n.mode[agent_id], n.init[agent_id]) for n in node.child] - entry = CachedSegment(node.trace[agent_id], node.assert_hits.get(agent_id), transitions, node.agent[agent_id].controller) + transitions = [CachedTransition(len(n.trace[agent_id]), n.mode[agent_id], n.init[agent_id], p) for n, p in zip(node.child, transition_paths)] + entry = CachedSegment(trace, node.assert_hits.get(agent_id), transitions, node.agent[agent_id].controller, run_num, node.id) tree[val - _EPSILON:val + _EPSILON] = entry return entry else: diff --git a/verse/analysis/simulator.py b/verse/analysis/simulator.py index a1a831bd..7ee100e7 100644 --- a/verse/analysis/simulator.py +++ b/verse/analysis/simulator.py @@ -1,33 +1,69 @@ -from typing import List, Dict +from collections import defaultdict +from typing import List, Dict, Tuple import copy import itertools import functools import pprint +from verse.agents.base_agent import BaseAgent -from verse.analysis.incremental import SimTraceCache +from verse.analysis.incremental import CachedSegment, SimTraceCache +from verse.parser.parser import ControllerIR, ModePath, find pp = functools.partial(pprint.pprint, compact=True, width=100) # from verse.agents.base_agent import BaseAgent from verse.analysis.analysis_tree import AnalysisTreeNode, AnalysisTree +PathDiffs = List[Tuple[BaseAgent, ModePath]] + +def to_simulate(old_agents: Dict[str, BaseAgent], new_agents: Dict[str, BaseAgent], cached: Dict[str, CachedSegment]) -> Tuple[Dict[str, CachedSegment], PathDiffs]: + assert set(old_agents.keys()) == set(new_agents.keys()) + removed_paths, added_paths, reset_changed_paths = [], [], [] + for agent_id, old_agent in old_agents.items(): + new_agent = new_agents[agent_id] + old_ctlr, new_ctlr = old_agent.controller, new_agent.controller + assert old_ctlr.args == new_ctlr.args + def group_by_var(ctlr: ControllerIR) -> Dict[str, List[ModePath]]: + grouped = defaultdict(list) + for path in ctlr.paths: + grouped[path.var].append(path) + return dict(grouped) + old_grouped, new_grouped = group_by_var(old_ctlr), group_by_var(new_ctlr) + if set(old_grouped.keys()) != set(new_grouped.keys()): + raise NotImplementedError("different variable outputs") + for var, old_paths in old_grouped.items(): + new_paths = new_grouped[var] + for old, new in itertools.zip_longest(old_paths, new_paths): + if new == None: + removed_paths.append(old) + if old.cond != new.cond: + added_paths.append(new) + elif old.val != new.val: + reset_changed_paths.append(new) + new_cache = {} + for agent_id in cached: + segment = copy.deepcopy(cached[agent_id]) + new_transitions = [] + for trans in segment.transitions: + removed = False + for path in trans.paths: + if path in removed_paths: + removed = True + for rcp in reset_changed_paths: + if path.cond == rcp.cond: + path.val = rcp.val + if not removed: + new_transitions.append(trans) + new_cache[agent_id] = segment + return new_cache, added_paths + class Simulator: def __init__(self): self.simulation_tree = None self.cache = SimTraceCache() - def simulate( - self, - init_list, - init_mode_list, - static_list, - uncertain_param_list, - agent_list, - transition_graph, - time_horizon, - time_step, - lane_map - ): + def simulate(self, init_list, init_mode_list, static_list, uncertain_param_list, agent_list, + transition_graph, time_horizon, time_step, lane_map, run_num, past_runs): # Setup the root of the simulation tree root = AnalysisTreeNode( trace={}, @@ -59,10 +95,11 @@ class Simulator: if remain_time <= 0: continue # For trace not already simulated - cache_entries = {} + cached_segments = {} for agent_id in node.agent: if agent_id in node.trace: - cache_entries[agent_id] = self.cache.add_segment(agent_id, node) + # cached_segments[agent_id] = self.cache.add_segment(agent_id, node, run_num) + pass else: mode = node.mode[agent_id] init = node.init[agent_id] @@ -71,7 +108,7 @@ class Simulator: node.trace[agent_id] = cached.trace if len(cached.trace) < remain_time / time_step: node.trace[agent_id] += node.agent[agent_id].TC_simulate(mode, cached.trace[-1], remain_time - time_step * len(cached.trace), lane_map) - cache_entries[agent_id] = cached + cached_segments[agent_id] = cached else: # Simulate the trace starting from initial condition trace = node.agent[agent_id].TC_simulate( @@ -79,28 +116,39 @@ class Simulator: trace[:, 0] += node.start_time trace = trace.tolist() node.trace[agent_id] = trace - cache_entries[agent_id] = self.cache.add_segment(agent_id, node) + # cached_segments[agent_id] = self.cache.add_segment(agent_id, node, run_num) + # TODO: for now, make sure all the segments comes from the same node; maybe we can do + # something to combine results from different nodes in the future + node_ids = list(set((s.run_num, s.node_id) for s in cached_segments.values())) + assert len(node_ids) <= 1, f"{node_ids}" + if len(node_ids) == 1: + run_num, node_id = node_ids[0] + old_node = find(past_runs[run_num], lambda n: n.id == node_id) + assert old_node != None + new_cache, paths_to_sim = to_simulate(old_node.agent, node.agent, cached_segments) + else: + new_cache, paths_to_sim = {}, [] - asserts, transitions, transition_idx = transition_graph.get_transition_simulate_new( - node, cache_entries) + asserts, transitions, transition_idx = transition_graph.get_transition_simulate_new(new_cache, paths_to_sim) node.assert_hits = asserts pp({a: trace[transition_idx] for a, trace in node.trace.items()}) # truncate the computed trajectories from idx and store the content after truncate - truncated_trace = {} + truncated_trace, full_traces = {}, {} for agent_idx in node.agent: + full_traces[agent_idx] = node.trace[agent_idx] truncated_trace[agent_idx] = node.trace[agent_idx][transition_idx:] node.trace[agent_idx] = node.trace[agent_idx][:transition_idx+1] - # If there's no transitions (returned transitions is empty), continue - if not transitions: - continue - if asserts != None: print(transition_idx) pp({a: len(t) for a, t in node.trace.items()}) else: + # If there's no transitions (returned transitions is empty), continue + if not transitions: + continue + # Generate the transition combinations if multiple agents can transit at the same time step transition_list = list(transitions.values()) all_transition_combinations = itertools.product( @@ -109,6 +157,7 @@ class Simulator: # For each possible transition, construct the new node. # Obtain the new initial condition for agent having transition # copy the traces that are not under transition + transition_paths = [] for transition_combination in all_transition_combinations: next_node_mode = copy.deepcopy(node.mode) next_node_static = copy.deepcopy(node.static) @@ -119,7 +168,8 @@ class Simulator: next_node_init = {} next_node_trace = {} for transition in transition_combination: - transit_agent_idx, dest_mode, next_init = transition + transit_agent_idx, dest_mode, next_init, paths = transition + transition_paths.append(paths) if dest_mode is None: continue # next_node = AnalysisTreeNode(trace = {},init={},mode={},agent={}, child = [], start_time = 0) @@ -143,6 +193,8 @@ class Simulator: ) node.child.append(tmp) simulation_queue.append(tmp) + for agent_id in node.agent: + self.cache.add_segment(agent_id, node, full_traces[agent_id], transition_paths, run_num) # Put the node in the child of current node. Put the new node in the queue # node.child.append(AnalysisTreeNode( # trace = next_node_trace, diff --git a/verse/parser/parser.py b/verse/parser/parser.py index 88514ba1..e1c62de4 100644 --- a/verse/parser/parser.py +++ b/verse/parser/parser.py @@ -244,6 +244,10 @@ class ModePath: val: Any val_veri: ast.expr + def __eq__(self, other: object) -> bool: + # TODO: more general equivalence? + return self.cond == other.cond and self.val == other.val + @dataclass class ControllerIR: args: LambdaArgs @@ -322,7 +326,6 @@ class ControllerIR: cond = compile_expr(Env.trans_args(cond, False)) val = compile_expr(Env.trans_args(case.val, False)) paths.append(ModePath(cond, cond_veri, var, val, val_veri)) - return ControllerIR(controller.args, paths, asserts_sim, asserts_veri, env.state_defs, env.mode_defs) @dataclass diff --git a/verse/scenario/scenario.py b/verse/scenario/scenario.py index 7e523c13..071dfda0 100644 --- a/verse/scenario/scenario.py +++ b/verse/scenario/scenario.py @@ -1,4 +1,4 @@ -from lib2to3.pytree import Base +from pprint import pp from typing import DefaultDict, Optional, Tuple, List, Dict, Any import copy import itertools @@ -11,6 +11,7 @@ import numpy as np from verse.agents.base_agent import BaseAgent from verse.analysis.incremental import CachedSegment, CachedTransition +from verse.analysis.simulator import PathDiffs from verse.automaton import GuardExpressionAst, ResetExpression from verse.analysis import Simulator, Verifier, AnalysisTreeNode, AnalysisTree from verse.analysis.utils import sample_rect @@ -48,6 +49,77 @@ def pack_env(agent: BaseAgent, ego_ty_name: str, cont, disc, lane_map): packed[EGO] = state_ty(**packed[EGO]) return dict(packed.items()) +def check_transitions(agent: BaseAgent, guards, cont, disc, map, state, mode): + asserts = [] + satisfied_guard = [] + agent_id = agent.id + # Unsafety checking + ego_ty_name = find(agent.controller.args, lambda a: a.name == EGO).typ + packed_env = pack_env(agent, ego_ty_name, cont, disc, map) + + # Check safety conditions + for assertion in agent.controller.asserts: + if eval(assertion.pre, packed_env): + if not eval(assertion.cond, packed_env): + del packed_env["__builtins__"] + print(f"assert hit for {agent_id}: \"{assertion.label}\" @ {packed_env}") + asserts.append(assertion.label) + if len(asserts) != 0: + return asserts, satisfied_guard + + all_resets = defaultdict(list) + for disc_vars, path in guards: + env = pack_env(agent, ego_ty_name, cont, disc_vars, map) # TODO: diff disc -> disc_vars? + + # Collect all the hit guards for this agent at this time step + if eval(path.cond, env): + # If the guard can be satisfied, handle resets + all_resets[path.var].append((path.val, path)) + + iter_list = [] + for vals in all_resets.values(): + paths = [p for _, p in vals] + iter_list.append(zip(range(len(vals)), paths)) + pos_list = list(itertools.product(*iter_list)) + if len(pos_list) == 1 and pos_list[0] == (): + raise NotImplementedError("??") + for pos in pos_list: + next_init = copy.deepcopy(state) + dest = copy.deepcopy(mode) + possible_dest = [[elem] for elem in dest] + for j, (reset_idx, path) in enumerate(pos): + reset_variable = list(all_resets.keys())[j] + res = eval(all_resets[reset_variable][reset_idx][0], packed_env) + ego_type = agent.controller.state_defs[ego_ty_name] + if "mode" in reset_variable: + var_loc = ego_type.disc.index(reset_variable) + assert not isinstance(res, list), res + possible_dest[var_loc] = [(res, path)] + else: + var_loc = ego_type.cont.index(reset_variable) + next_init[var_loc] = res + print("possible_dest") + pp(possible_dest) + all_dest = list(itertools.product(*possible_dest)) + print("all_dest") + pp(all_dest) + if not all_dest: + warnings.warn( + f"Guard hit for mode {mode} for agent {agent_id} without available next mode") + all_dest.append(None) + for dest in all_dest: + assert isinstance(dest, tuple) + paths = [] + pure_dest = [] + for d in dest: + if isinstance(d, tuple): + pure_dest.append(d[0]) + paths.append(d[1]) + else: + pure_dest.append(d) + satisfied_guard.append((agent_id, pure_dest, next_init, paths)) + return None, satisfied_guard + @dataclass class ScenarioConfig: incremental: bool = False @@ -66,6 +138,7 @@ class Scenario: self.uncertain_param_dict = {} self.map = LaneMap() self.sensor = BaseSensor() + self.past_runs = [] # Parameters self.config = config @@ -193,7 +266,9 @@ class Scenario: uncertain_param_list.append(self.uncertain_param_dict[agent_id]) agent_list.append(self.agent_dict[agent_id]) print(init_list) - return self.simulator.simulate(init_list, init_mode_list, static_list, uncertain_param_list, agent_list, self, time_horizon, time_step, self.map) + tree = self.simulator.simulate(init_list, init_mode_list, static_list, uncertain_param_list, agent_list, self, time_horizon, time_step, self.map, len(self.past_runs), self.past_runs) + self.past_runs.append(tree) + return tree def verify(self, time_horizon, time_step, reachability_method='DRYVR', params={}) -> AnalysisTree: self.check_init() @@ -212,10 +287,10 @@ class Scenario: static_list.append(self.static_dict[agent_id]) uncertain_param_list.append(self.uncertain_param_dict[agent_id]) agent_list.append(self.agent_dict[agent_id]) - - res = self.verifier.compute_full_reachtube(init_list, init_mode_list, static_list, uncertain_param_list, agent_list, self, time_horizon, - time_step, self.map, self.config.init_seg_length, self.config.reachability_method, params) - return res + tree = self.verifier.compute_full_reachtube(init_list, init_mode_list, static_list, uncertain_param_list, agent_list, self, time_horizon, + time_step, self.map, self.config.init_seg_length, self.config.reachability_method, params, len(self.past_runs)) + self.past_runs.append(tree) + return tree def apply_reset(self, agent: BaseAgent, reset_list, all_agent_state) -> Tuple[str, np.ndarray]: lane_map = self.map @@ -326,106 +401,83 @@ class Scenario: # unrolled_variable, unrolled_variable_index = updater[variable] # disc_var_dict[unrolled_variable] = disc_var_dict[variable][unrolled_variable_index] - def get_transition_simulate_new(self, node: AnalysisTreeNode, cache: Dict[str, CachedSegment]) -> Tuple[Optional[Dict[str, List[str]]], Optional[Dict[str, List[Tuple[str, List[str], List[float]]]]], int]: + # def get_transition_simulate_new(self, diffed: Tuple[PathDiffs, PathDiffs, PathDiffs], cache: Dict[str, CachedSegment]) -> Tuple[Optional[Dict[str, List[str]]], Optional[Dict[str, List[Tuple[str, List[str], List[float]]]]], int]: + def get_transition_simulate_new(self, cache: Dict[str, CachedSegment], paths: PathDiffs, node: AnalysisTreeNode) -> Tuple[Optional[Dict[str, List[str]]], Optional[Dict[str, List[Tuple[str, List[str], List[float]]]]], int]: lane_map = self.map trace_length = len(list(node.trace.values())[0]) # For each agent agent_guard_dict = defaultdict(list) + cached_guards = defaultdict(list) - for agent_id in node.agent: + if not cache: + paths = [(agent, p) for agent in node.agent.values() for p in agent.controller.paths] + path_transitions = {} + else: + if len(paths) == 0: + transition = min(trans.transition for seg in cache.values() for trans in seg.transitions) + transitions = defaultdict(list) + for agent_id, seg in cache.items(): + # TODO: check for asserts + for tran in seg.transitions: + if tran.transition == transition: + for path in tran.paths: + transitions[agent_id].append((agent_id, tran.disc, tran.cont, path)) + return None, transitions, transition + + path_transitions = defaultdict(int) + for seg in cache.values(): + for tran in seg.transitions: + for p in tran.paths: + path_transitions[p.cond] = max(path_transitions[p.cond], tran.transition) + for agent_id, segment in cache.items(): + agent = node.agent[agent_id] + agent_mode = node.mode[agent_id] + if len(agent.controller.args) == 0: + continue + state_dict = {aid: (node.trace[aid][0], node.mode[aid], node.static[aid]) for aid in node.agent} + agent_paths = {p for tran in segment.transitions for p in tran.paths} + cont_var_dict_template, discrete_variable_dict, len_dict = self.sensor.sense(self, agent, state_dict, self.map) + for path in agent_paths: + cached_guards[agent_id].append((path, discrete_variable_dict, path_transitions[path.cond])) + + for agent, path in paths: # Get guard - agent: BaseAgent = self.agent_dict[agent_id] + agent_id = agent.id agent_mode = node.mode[agent_id] if len(agent.controller.args) == 0: continue - state_dict = {} - for tmp in node.agent: - state_dict[tmp] = (node.trace[tmp][0], - node.mode[tmp], node.static[tmp]) - cont_var_dict_template, discrete_variable_dict, len_dict = self.sensor.sense( - self, agent, state_dict, self.map) - paths = agent.controller.paths - for path in paths: - agent_guard_dict[agent_id].append( - (path.cond, discrete_variable_dict, path.var, path.val)) + state_dict = {aid: (node.trace[aid][0], node.mode[aid], node.static[aid]) for aid in node.agent} + cont_var_dict_template, discrete_variable_dict, len_dict = self.sensor.sense(self, agent, state_dict, self.map) + agent_guard_dict[agent_id].append((path, discrete_variable_dict)) transitions = defaultdict(list) # TODO: We can probably rewrite how guard hit are detected and resets are handled for simulation for idx in range(trace_length): satisfied_guard = [] - asserts = defaultdict(list) + all_asserts = defaultdict(list) for agent_id in agent_guard_dict: agent: BaseAgent = self.agent_dict[agent_id] - state_dict = {} - for tmp in node.agent: - state_dict[tmp] = (node.trace[tmp][idx], - node.mode[tmp], node.static[tmp]) + state_dict = {aid: (node.trace[aid][0], node.mode[aid], node.static[aid]) for aid in node.agent} agent_state, agent_mode, agent_static = state_dict[agent_id] agent_state = agent_state[1:] continuous_variable_dict, orig_disc_vars, _ = self.sensor.sense( self, agent, state_dict, self.map) - # Unsafety checking - ego_ty_name = find(agent.controller.args, lambda a: a.name == EGO).typ - packed_env = pack_env(agent, ego_ty_name, continuous_variable_dict, orig_disc_vars, self.map) - - # Check safety conditions - for assertion in agent.controller.asserts: - if eval(assertion.pre, packed_env): - if not eval(assertion.cond, packed_env): - del packed_env["__builtins__"] - print( - f"assert hit for {agent_id}: \"{assertion.label}\" @ {packed_env}") - asserts[agent_id].append(assertion.label) - if agent_id in asserts: - continue - - all_resets = defaultdict(list) - for guard_comp, discrete_variable_dict, var, reset in agent_guard_dict[agent_id]: - new_cont_var_dict = copy.deepcopy(continuous_variable_dict) - env = pack_env(agent, ego_ty_name, new_cont_var_dict, discrete_variable_dict, self.map) - - # Collect all the hit guards for this agent at this time step - if eval(guard_comp, env): - # If the guard can be satisfied, handle resets - all_resets[var].append(reset) - - iter_list = [] - for reset_var in all_resets: - iter_list.append(range(len(all_resets[reset_var]))) - pos_list = list(itertools.product(*iter_list)) - if len(pos_list) == 1 and pos_list[0] == (): - continue - for i in range(len(pos_list)): - pos = pos_list[i] - next_init = copy.deepcopy(agent_state) - dest = copy.deepcopy(agent_mode) - possible_dest = [[elem] for elem in dest] - for j, reset_idx in enumerate(pos): - reset_variable = list(all_resets.keys())[j] - res = eval(all_resets[reset_variable] - [reset_idx], packed_env) - ego_type = agent.controller.state_defs[ego_ty_name] - if "mode" in reset_variable: - var_loc = ego_type.disc.index(reset_variable) - if not isinstance(res, list): - res = [res] - possible_dest[var_loc] = res - else: - var_loc = ego_type.cont.index(reset_variable) - next_init[var_loc] = res - all_dest = list(itertools.product(*possible_dest)) - if not all_dest: - warnings.warn( - f"Guard hit for mode {agent_mode} for agent {agent_id} without available next mode") - all_dest.append(None) - for dest in all_dest: - satisfied_guard.append((agent_id, dest, next_init)) - if len(asserts) > 0: - return asserts, None, idx + unchecked_cache_guards = [g[:2] for g in cached_guards[agent_id] if g[2] < idx] # FIXME: off by 1? + asserts, satisfied = check_transitions(agent, agent_guard_dict[agent_id] + unchecked_cache_guards, continuous_variable_dict, orig_disc_vars, self.map, agent_state, agent_mode) + if asserts != None: + all_asserts[agent_id] = asserts + return all_asserts, transitions, idx + if len(satisfied) != 0: + satisfied_guard.extend(satisfied) + if len(all_asserts) > 0: + return all_asserts, transitions, idx if len(satisfied_guard) > 0: - for agent_idx, dest_mode, next_init in satisfied_guard: - transitions[agent_idx].append((agent_idx, dest_mode, next_init)) + print("satisfied_guard") + pp(satisfied_guard) + for agent_idx, dest_mode, next_init, paths in satisfied_guard: + transitions[agent_idx].append((agent_idx, dest_mode, next_init, paths)) + print("transitions", transitions) break return None, transitions, idx -- GitLab