From cc268dc0cfa57003f48dcd80c8b929916bcd89fc Mon Sep 17 00:00:00 2001
From: crides <zhuhaoqing@live.cn>
Date: Fri, 26 Aug 2022 12:32:57 -0500
Subject: [PATCH] sim done?

---
 demo/vehicle/demo10.py          |  71 ++++++-----
 demo/vehicle/demo11.py          |  76 ++++++-----
 demo/vehicle/demo7.py           |  38 ++++--
 verse/analysis/analysis_tree.py |   2 +-
 verse/analysis/incremental.py   |  14 +-
 verse/analysis/simulator.py     | 104 +++++++++++----
 verse/parser/parser.py          |   5 +-
 verse/scenario/scenario.py      | 220 ++++++++++++++++++++------------
 8 files changed, 341 insertions(+), 189 deletions(-)

diff --git a/demo/vehicle/demo10.py b/demo/vehicle/demo10.py
index 4cae82fc..9cbcf778 100644
--- a/demo/vehicle/demo10.py
+++ b/demo/vehicle/demo10.py
@@ -2,13 +2,12 @@ from verse.agents.example_agent import CarAgent, NPCAgent
 from verse.map.example_map import SimpleMap2
 from verse import Scenario
 from verse.plotter.plotter2D import *
-# from verse.plotter.plotter2D_old import plot_reachtube_tree, plot_map
+from verse.plotter.plotter2D_old import plot_reachtube_tree, plot_map, plot_simulation_tree
 from noisy_sensor import NoisyVehicleSensor
 
 from enum import Enum, auto
 import plotly.graph_objects as go
-import matplotlib.pyplot as plt
-
+import matplotlib.pyplot as plt 
 
 class LaneObjectMode(Enum):
     Vehicle = auto()
@@ -54,7 +53,7 @@ if __name__ == "__main__":
     scenario.add_agent(car)
     tmp_map = SimpleMap2()
     scenario.set_map(tmp_map)
-    scenario.set_sensor(NoisyVehicleSensor((1, 1), (0, 0)))
+    scenario.set_sensor(NoisyVehicleSensor((1,1), (0,0)))
     scenario.set_init(
         [
             [[5, -0.1, 0, 1.0], [5.5, 0.1, 0, 1.1]],
@@ -65,34 +64,38 @@ if __name__ == "__main__":
             (VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Vehicle),
         ]
     )
-    scenario.init_seg_length = 5
-    traces = scenario.verify(40, 0.05)
-
-    # fig = plt.figure(2)
-    # fig = plot_reachtube_tree(traces.root, 'car1', 0, [1], 'b', fig)
-    # fig = plot_reachtube_tree(traces.root, 'car2', 0, [1], 'r', fig)
-
-    scenario1 = Scenario()
-    car1 = CarAgent('car1', file_name=input_code_name)
-    scenario1.add_agent(car1)
-    car1 = NPCAgent('car2')
-    scenario1.add_agent(car1)
-    tmp_map1 = SimpleMap2()
-    scenario1.set_map(tmp_map1)
-    # scenario1.set_sensor(NoisyVehicleSensor((0,1), (0,0)))
-    scenario1.set_init(
-        [
-            [[5, -0.1, 0, 1.0], [6, 0.1, 0, 1.0]],
-            [[20, 0, 0, 0.5], [20, 0, 0, 0.5]],
-        ],
-        [
-            (VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Vehicle),
-            (VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Vehicle),
-        ]
-    )
-    scenario1.init_seg_length = 5
-    traces1 = scenario1.verify(40, 0.05, params={"bloating_method": 'GLOBAL'})
+    scenario.config.init_seg_length = 5
+    traces = scenario.simulate(40, 0.05)
+
+    fig = plt.figure(2)
+    fig = plot_simulation_tree(traces.root, 'car1', 0, [1], 'b', fig)
+    fig = plot_simulation_tree(traces.root, 'car2', 0, [1], 'r', fig)
+
+    # scenario1 = Scenario()
+    # car1 = CarAgent('car1', file_name=input_code_name)
+    # scenario1.add_agent(car1)
+    # car1 = NPCAgent('car2')
+    # scenario1.add_agent(car1)
+    # tmp_map1 = SimpleMap2()
+    # scenario1.set_map(tmp_map1)
+    # # scenario1.set_sensor(NoisyVehicleSensor((0,1), (0,0)))
+    # scenario1.set_init(
+    #     [
+    #         [[5, -0.1, 0, 1.0], [6, 0.1, 0, 1.0]],
+    #         [[20, 0, 0, 0.5], [20, 0, 0, 0.5]],
+    #     ],
+    #     [
+    #         (VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Vehicle),
+    #         (VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Vehicle),
+    #     ]
+    # )
+    # scenario1.init_seg_length = 5
+    # scenario.verify_method = 'GLOBAL'
+    # traces1 = scenario1.simulate(40, 0.05)
+
+    # fig = plot_reachtube_tree(traces1.root, 'car1', 0, [1], 'g', fig)
+    # fig = plot_reachtube_tree(traces1.root, 'car2', 0, [1], 'r', fig)
+
+
+    plt.show()
 
-    fig = go.Figure()
-    fig = reachtube_tree(traces, tmp_map, fig, 1, 2, [1, 2], 'lines', 'trace')
-    fig.show()
diff --git a/demo/vehicle/demo11.py b/demo/vehicle/demo11.py
index b76a9889..b32003ec 100644
--- a/demo/vehicle/demo11.py
+++ b/demo/vehicle/demo11.py
@@ -69,39 +69,55 @@ if __name__ == "__main__":
     )
     scenario.set_sensor(NoisyVehicleSensor((0.5,0.5), (0,0)))
 
-    scenario.init_seg_length = 5
-    traces = scenario.verify(40, 0.1, params={"bloating_method":'GLOBAL'})
+    import timeit
+    scenario.config.init_seg_length = 5
+    time = timeit.default_timer()
+    traces = scenario.verify(40, 0.1)
+    print("run1", timeit.default_timer() - time)
+    print(scenario.verifier.cache.cache)
 
     fig = plt.figure(2)
     fig = plot_reachtube_tree(traces.root, 'car1', 1, [2], 'b', fig)
     fig = plot_reachtube_tree(traces.root, 'car2', 1, [2], 'r', fig)
     fig = plot_reachtube_tree(traces.root, 'car3', 1, [2], 'r', fig)
     fig = plot_map(tmp_map, 'g', fig)
-
-    scenario1 = Scenario()
-    car = CarAgent('car1', file_name=input_code_name)
-    scenario1.add_agent(car)
-    car = NPCAgent('car2')
-    scenario1.add_agent(car)
-    car = NPCAgent('car3')
-    scenario1.add_agent(car)
-    tmp_map = SimpleMap3()
-    scenario1.set_map(tmp_map)
-    scenario1.set_init(
-        [
-            [[5, -0.5, 0, 1.0], [5.5, 0.5, 0, 1.0]],
-            [[20, -0.2, 0, 0.5], [20, 0.2, 0, 0.5]],
-            [[4-2.5, 2.8, 0, 1.0], [4.5-2.5, 3.2, 0, 1.0]],
-        ],
-        [
-            (VehicleMode.Normal, LaneMode.Lane1,),
-            (VehicleMode.Normal, LaneMode.Lane1,),
-            (VehicleMode.Normal, LaneMode.Lane0,),
-        ]
-    )
-
-    scenario1.init_seg_length = 5
-    traces = scenario1.verify(40, 0.1, params={"bloating_method":'GLOBAL'})
-
-    fig = plot_reachtube_tree(traces.root, 'car1', 1, [2], 'g', fig)
-    plt.show()
\ No newline at end of file
+    plt.show()
+
+    # time = timeit.default_timer()
+    # traces = scenario.verify(40, 0.1)
+    # print("run2", timeit.default_timer() - time)
+
+    # fig = plt.figure(2)
+    # fig = plot_reachtube_tree(traces.root, 'car1', 1, [2], 'b', fig)
+    # fig = plot_reachtube_tree(traces.root, 'car2', 1, [2], 'r', fig)
+    # fig = plot_reachtube_tree(traces.root, 'car3', 1, [2], 'r', fig)
+    # fig = plot_map(tmp_map, 'g', fig)
+    # plt.show()
+
+    # scenario1 = Scenario()
+    # car = CarAgent('car1', file_name=input_code_name)
+    # scenario1.add_agent(car)
+    # car = NPCAgent('car2')
+    # scenario1.add_agent(car)
+    # car = NPCAgent('car3')
+    # scenario1.add_agent(car)
+    # tmp_map = SimpleMap3()
+    # scenario1.set_map(tmp_map)
+    # scenario1.set_init(
+    #     [
+    #         [[5, -0.5, 0, 1.0], [5.5, 0.5, 0, 1.0]],
+    #         [[20, -0.2, 0, 0.5], [20, 0.2, 0, 0.5]],
+    #         [[4-2.5, 2.8, 0, 1.0], [4.5-2.5, 3.2, 0, 1.0]],
+    #     ],
+    #     [
+    #         (VehicleMode.Normal, LaneMode.Lane1,),
+    #         (VehicleMode.Normal, LaneMode.Lane1,),
+    #         (VehicleMode.Normal, LaneMode.Lane0,),
+    #     ]
+    # )
+
+    # scenario1.init_seg_length = 5
+    # traces = scenario1.verify(40, 0.1)
+
+    # fig = plot_reachtube_tree(traces.root, 'car1', 1, [2], 'g', fig)
+    # plt.show()
diff --git a/demo/vehicle/demo7.py b/demo/vehicle/demo7.py
index 8f67d3ac..5e7cffc1 100644
--- a/demo/vehicle/demo7.py
+++ b/demo/vehicle/demo7.py
@@ -1,5 +1,6 @@
 # SM: Noting some things about the example
 
+import timeit
 from verse.agents.example_agent import CarAgent, NPCAgent
 from verse.map.example_map import SimpleMap4
 from verse import Scenario
@@ -8,7 +9,6 @@ from verse.plotter.plotter2D import *
 from enum import Enum, auto
 import plotly.graph_objects as go
 
-
 class LaneObjectMode(Enum):
     Vehicle = auto()
     Ped = auto()        # Pedestrians
@@ -68,7 +68,7 @@ if __name__ == "__main__":
     scenario.set_map(tmp_map)
     scenario.set_init(
         [
-            [[0, -0.0, 0, 1.0], [0.0, 0.0, 0, 1.0]],
+            [[0, 0.0, 0, 1.0], [0.0, 0.0, 0, 1.0]],
             [[10, 0, 0, 0.5], [10, 0, 0, 0.5]],
             [[14, 3, 0, 0.6], [14, 3, 0, 0.6]],
             [[20, 3, 0, 0.5], [20, 3, 0, 0.5]],
@@ -99,12 +99,34 @@ if __name__ == "__main__":
         ],
 
     )
-    traces = scenario.simulate(60, 0.1)
+    # time = timeit.default_timer()
+    # # import cProfile, pstats, io
+    # # from pstats import SortKey
+    # # pr = cProfile.Profile()
+    # # pr.enable()
+    # traces = scenario.verify(60, 0.1)
+    # # pr.disable()
+    # # s = io.StringIO()
+    # # sortby = SortKey.CUMULATIVE
+    # # ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
+    # # ps.print_stats()
+    # # print(s.getvalue())
+    # print("run1", timeit.default_timer() - time)
+    # fig = go.Figure()
+    # fig = reachtube_tree(traces, tmp_map, fig, 1,
+    #                       2, 'lines', 'trace', print_dim_list=[1, 2])
+    # fig.show()
+
+    time = timeit.default_timer()
+    traces = scenario.simulate(60, 0.05)
+    print("run2", timeit.default_timer() - time)
     fig = go.Figure()
-    fig = simulation_tree(traces, tmp_map, fig, 1, 2, [1, 2], 'lines', 'trace')
+    fig = simulation_tree(traces, tmp_map, fig, 1,
+                          2, 'lines', 'trace', print_dim_list=[1, 2])
     fig.show()
 
-    traces = scenario.verify(60, 0.1)
-    fig = go.Figure()
-    fig = reachtube_tree(traces, tmp_map, fig, 1, 2, [1, 2], 'lines', 'trace')
-    fig.show()
+    # traces = scenario.verify(60, 0.05)
+    # fig = go.Figure()
+    # fig = reachtube_tree(traces, tmp_map, fig, 1,
+    #                       2, 'lines', 'trace', print_dim_list=[1, 2])
+    # fig.show()
diff --git a/verse/analysis/analysis_tree.py b/verse/analysis/analysis_tree.py
index 792c5b4f..57b26bb0 100644
--- a/verse/analysis/analysis_tree.py
+++ b/verse/analysis/analysis_tree.py
@@ -121,4 +121,4 @@ class AnalysisTree:
                 child_node = AnalysisTreeNode.from_dict(child_node_dict)
                 parent_node.child.append(child_node)
                 queue.append((child_node_dict, child_node))
-        return AnalysisTree(root)
\ No newline at end of file
+        return AnalysisTree(root)
diff --git a/verse/analysis/incremental.py b/verse/analysis/incremental.py
index 558e96d3..e4e6ab91 100644
--- a/verse/analysis/incremental.py
+++ b/verse/analysis/incremental.py
@@ -1,17 +1,18 @@
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import DefaultDict, List, Tuple, Optional
+from typing import DefaultDict, Dict, List, Tuple, Optional
 from verse.analysis import AnalysisTreeNode
 from intervaltree import IntervalTree
 
 from verse.analysis.dryvr import _EPSILON
-from verse.parser.parser import ControllerIR
+from verse.parser.parser import ControllerIR, ModePath
 
 @dataclass
 class CachedTransition:
     transition: int
     disc: List[str]
     cont: List[float]
+    paths: List[ModePath]
 
 @dataclass
 class CachedSegment:
@@ -19,6 +20,8 @@ class CachedSegment:
     asserts: List[str]
     transitions: List[CachedTransition]
     controller: ControllerIR
+    run_num: int
+    node_id: int
 
 @dataclass
 class CachedTube:
@@ -33,14 +36,15 @@ class SimTraceCache:
     def __init__(self):
         self.cache: DefaultDict[tuple, IntervalTree] = defaultdict(IntervalTree)
 
-    def add_segment(self, agent_id: str, node: AnalysisTreeNode):
+    def add_segment(self, agent_id: str, node: AnalysisTreeNode, trace: List[List[float]], transition_paths: List[List[ModePath]], run_num: int):
+        assert len(transition_paths) == len(node.child)
         key = (agent_id,) + tuple(node.mode[agent_id])
         init = node.init[agent_id]
         tree = self.cache[key]
         for i, val in enumerate(init):
             if i == len(init) - 1:
-                transitions = [CachedTransition(len(n.trace[agent_id]), n.mode[agent_id], n.init[agent_id]) for n in node.child]
-                entry = CachedSegment(node.trace[agent_id], node.assert_hits.get(agent_id), transitions, node.agent[agent_id].controller)
+                transitions = [CachedTransition(len(n.trace[agent_id]), n.mode[agent_id], n.init[agent_id], p) for n, p in zip(node.child, transition_paths)]
+                entry = CachedSegment(trace, node.assert_hits.get(agent_id), transitions, node.agent[agent_id].controller, run_num, node.id)
                 tree[val - _EPSILON:val + _EPSILON] = entry
                 return entry
             else:
diff --git a/verse/analysis/simulator.py b/verse/analysis/simulator.py
index a1a831bd..7ee100e7 100644
--- a/verse/analysis/simulator.py
+++ b/verse/analysis/simulator.py
@@ -1,33 +1,69 @@
-from typing import List, Dict
+from collections import defaultdict
+from typing import List, Dict, Tuple
 import copy
 import itertools
 import functools
 
 import pprint
+from verse.agents.base_agent import BaseAgent
 
-from verse.analysis.incremental import SimTraceCache
+from verse.analysis.incremental import CachedSegment, SimTraceCache
+from verse.parser.parser import ControllerIR, ModePath, find
 pp = functools.partial(pprint.pprint, compact=True, width=100)
 
 # from verse.agents.base_agent import BaseAgent
 from verse.analysis.analysis_tree import AnalysisTreeNode, AnalysisTree
 
+PathDiffs = List[Tuple[BaseAgent, ModePath]]
+
+def to_simulate(old_agents: Dict[str, BaseAgent], new_agents: Dict[str, BaseAgent], cached: Dict[str, CachedSegment]) -> Tuple[Dict[str, CachedSegment], PathDiffs]:
+    assert set(old_agents.keys()) == set(new_agents.keys())
+    removed_paths, added_paths, reset_changed_paths = [], [], []
+    for agent_id, old_agent in old_agents.items():
+        new_agent = new_agents[agent_id]
+        old_ctlr, new_ctlr = old_agent.controller, new_agent.controller
+        assert old_ctlr.args == new_ctlr.args
+        def group_by_var(ctlr: ControllerIR) -> Dict[str, List[ModePath]]:
+            grouped = defaultdict(list)
+            for path in ctlr.paths:
+                grouped[path.var].append(path)
+            return dict(grouped)
+        old_grouped, new_grouped = group_by_var(old_ctlr), group_by_var(new_ctlr)
+        if set(old_grouped.keys()) != set(new_grouped.keys()):
+            raise NotImplementedError("different variable outputs")
+        for var, old_paths in old_grouped.items():
+            new_paths = new_grouped[var]
+            for old, new in itertools.zip_longest(old_paths, new_paths):
+                if new == None:
+                    removed_paths.append(old)
+                if old.cond != new.cond:
+                    added_paths.append(new)
+                elif old.val != new.val:
+                    reset_changed_paths.append(new)
+    new_cache = {}
+    for agent_id in cached:
+        segment = copy.deepcopy(cached[agent_id])
+        new_transitions = []
+        for trans in segment.transitions:
+            removed = False
+            for path in trans.paths:
+                if path in removed_paths:
+                    removed = True
+                for rcp in reset_changed_paths:
+                    if path.cond == rcp.cond:
+                        path.val = rcp.val
+            if not removed:
+                new_transitions.append(trans)
+        new_cache[agent_id] = segment
+    return new_cache, added_paths
+
 class Simulator:
     def __init__(self):
         self.simulation_tree = None
         self.cache = SimTraceCache()
 
-    def simulate(
-        self, 
-        init_list, 
-        init_mode_list, 
-        static_list, 
-        uncertain_param_list, 
-        agent_list, 
-        transition_graph, 
-        time_horizon, 
-        time_step, 
-        lane_map
-    ):
+    def simulate(self, init_list, init_mode_list, static_list, uncertain_param_list, agent_list,
+                 transition_graph, time_horizon, time_step, lane_map, run_num, past_runs):
         # Setup the root of the simulation tree
         root = AnalysisTreeNode(
             trace={},
@@ -59,10 +95,11 @@ class Simulator:
             if remain_time <= 0:
                 continue
             # For trace not already simulated
-            cache_entries = {}
+            cached_segments = {}
             for agent_id in node.agent:
                 if agent_id in node.trace:
-                    cache_entries[agent_id] = self.cache.add_segment(agent_id, node)
+                    # cached_segments[agent_id] = self.cache.add_segment(agent_id, node, run_num)
+                    pass
                 else:
                     mode = node.mode[agent_id]
                     init = node.init[agent_id]
@@ -71,7 +108,7 @@ class Simulator:
                         node.trace[agent_id] = cached.trace
                         if len(cached.trace) < remain_time / time_step:
                             node.trace[agent_id] += node.agent[agent_id].TC_simulate(mode, cached.trace[-1], remain_time - time_step * len(cached.trace), lane_map)
-                        cache_entries[agent_id] = cached
+                        cached_segments[agent_id] = cached
                     else:
                         # Simulate the trace starting from initial condition
                         trace = node.agent[agent_id].TC_simulate(
@@ -79,28 +116,39 @@ class Simulator:
                         trace[:, 0] += node.start_time
                         trace = trace.tolist()
                         node.trace[agent_id] = trace
-                        cache_entries[agent_id] = self.cache.add_segment(agent_id, node)
+                        # cached_segments[agent_id] = self.cache.add_segment(agent_id, node, run_num)
+            # TODO: for now, make sure all the segments comes from the same node; maybe we can do
+            # something to combine results from different nodes in the future
+            node_ids = list(set((s.run_num, s.node_id) for s in cached_segments.values()))
+            assert len(node_ids) <= 1, f"{node_ids}"
+            if len(node_ids) == 1:
+                run_num, node_id = node_ids[0]
+                old_node = find(past_runs[run_num], lambda n: n.id == node_id)
+                assert old_node != None
+                new_cache, paths_to_sim = to_simulate(old_node.agent, node.agent, cached_segments)
+            else:
+                new_cache, paths_to_sim = {}, []
 
-            asserts, transitions, transition_idx = transition_graph.get_transition_simulate_new(
-                node, cache_entries)
+            asserts, transitions, transition_idx = transition_graph.get_transition_simulate_new(new_cache, paths_to_sim)
 
             node.assert_hits = asserts
             pp({a: trace[transition_idx] for a, trace in node.trace.items()})
 
             # truncate the computed trajectories from idx and store the content after truncate
-            truncated_trace = {}
+            truncated_trace, full_traces = {}, {}
             for agent_idx in node.agent:
+                full_traces[agent_idx] = node.trace[agent_idx]
                 truncated_trace[agent_idx] = node.trace[agent_idx][transition_idx:]
                 node.trace[agent_idx] = node.trace[agent_idx][:transition_idx+1]
 
-            # If there's no transitions (returned transitions is empty), continue
-            if not transitions:
-                continue
-
             if asserts != None:
                 print(transition_idx)
                 pp({a: len(t) for a, t in node.trace.items()})
             else:
+                # If there's no transitions (returned transitions is empty), continue
+                if not transitions:
+                    continue
+
                 # Generate the transition combinations if multiple agents can transit at the same time step
                 transition_list = list(transitions.values())
                 all_transition_combinations = itertools.product(
@@ -109,6 +157,7 @@ class Simulator:
                 # For each possible transition, construct the new node.
                 # Obtain the new initial condition for agent having transition
                 # copy the traces that are not under transition
+                transition_paths = []
                 for transition_combination in all_transition_combinations:
                     next_node_mode = copy.deepcopy(node.mode)
                     next_node_static = copy.deepcopy(node.static)
@@ -119,7 +168,8 @@ class Simulator:
                     next_node_init = {}
                     next_node_trace = {}
                     for transition in transition_combination:
-                        transit_agent_idx, dest_mode, next_init = transition
+                        transit_agent_idx, dest_mode, next_init, paths = transition
+                        transition_paths.append(paths)
                         if dest_mode is None:
                             continue
                         # next_node = AnalysisTreeNode(trace = {},init={},mode={},agent={}, child = [], start_time = 0)
@@ -143,6 +193,8 @@ class Simulator:
                     )
                     node.child.append(tmp)
                     simulation_queue.append(tmp)
+                for agent_id in node.agent:
+                    self.cache.add_segment(agent_id, node, full_traces[agent_id], transition_paths, run_num)
                 # Put the node in the child of current node. Put the new node in the queue
             #     node.child.append(AnalysisTreeNode(
             #         trace = next_node_trace,
diff --git a/verse/parser/parser.py b/verse/parser/parser.py
index 88514ba1..e1c62de4 100644
--- a/verse/parser/parser.py
+++ b/verse/parser/parser.py
@@ -244,6 +244,10 @@ class ModePath:
     val: Any
     val_veri: ast.expr
 
+    def __eq__(self, other: object) -> bool:
+        # TODO: more general equivalence?
+        return self.cond == other.cond and self.val == other.val
+
 @dataclass
 class ControllerIR:
     args: LambdaArgs
@@ -322,7 +326,6 @@ class ControllerIR:
                     cond = compile_expr(Env.trans_args(cond, False))
                     val = compile_expr(Env.trans_args(case.val, False))
                     paths.append(ModePath(cond, cond_veri, var, val, val_veri))
-
         return ControllerIR(controller.args, paths, asserts_sim, asserts_veri, env.state_defs, env.mode_defs)
 
 @dataclass
diff --git a/verse/scenario/scenario.py b/verse/scenario/scenario.py
index 7e523c13..071dfda0 100644
--- a/verse/scenario/scenario.py
+++ b/verse/scenario/scenario.py
@@ -1,4 +1,4 @@
-from lib2to3.pytree import Base
+from pprint import pp
 from typing import DefaultDict, Optional, Tuple, List, Dict, Any
 import copy
 import itertools
@@ -11,6 +11,7 @@ import numpy as np
 
 from verse.agents.base_agent import BaseAgent
 from verse.analysis.incremental import CachedSegment, CachedTransition
+from verse.analysis.simulator import PathDiffs
 from verse.automaton import GuardExpressionAst, ResetExpression
 from verse.analysis import Simulator, Verifier, AnalysisTreeNode, AnalysisTree
 from verse.analysis.utils import sample_rect
@@ -48,6 +49,77 @@ def pack_env(agent: BaseAgent, ego_ty_name: str, cont, disc, lane_map):
     packed[EGO] = state_ty(**packed[EGO])
     return dict(packed.items())
 
+def check_transitions(agent: BaseAgent, guards, cont, disc, map, state, mode):
+    asserts = []
+    satisfied_guard = []
+    agent_id = agent.id
+    # Unsafety checking
+    ego_ty_name = find(agent.controller.args, lambda a: a.name == EGO).typ
+    packed_env = pack_env(agent, ego_ty_name, cont, disc, map)
+
+    # Check safety conditions
+    for assertion in agent.controller.asserts:
+        if eval(assertion.pre, packed_env):
+            if not eval(assertion.cond, packed_env):
+                del packed_env["__builtins__"]
+                print(f"assert hit for {agent_id}: \"{assertion.label}\" @ {packed_env}")
+                asserts.append(assertion.label)
+    if len(asserts) != 0:
+        return asserts, satisfied_guard
+
+    all_resets = defaultdict(list)
+    for disc_vars, path in guards:
+        env = pack_env(agent, ego_ty_name, cont, disc_vars, map)    # TODO: diff disc -> disc_vars?
+
+        # Collect all the hit guards for this agent at this time step
+        if eval(path.cond, env):
+            # If the guard can be satisfied, handle resets
+            all_resets[path.var].append((path.val, path))
+
+    iter_list = []
+    for vals in all_resets.values():
+        paths = [p for _, p in vals]
+        iter_list.append(zip(range(len(vals)), paths))
+    pos_list = list(itertools.product(*iter_list))
+    if len(pos_list) == 1 and pos_list[0] == ():
+        raise NotImplementedError("??")
+    for pos in pos_list:
+        next_init = copy.deepcopy(state)
+        dest = copy.deepcopy(mode)
+        possible_dest = [[elem] for elem in dest]
+        for j, (reset_idx, path) in enumerate(pos):
+            reset_variable = list(all_resets.keys())[j]
+            res = eval(all_resets[reset_variable][reset_idx][0], packed_env)
+            ego_type = agent.controller.state_defs[ego_ty_name]
+            if "mode" in reset_variable:
+                var_loc = ego_type.disc.index(reset_variable)
+                assert not isinstance(res, list), res
+                possible_dest[var_loc] = [(res, path)]
+            else:
+                var_loc = ego_type.cont.index(reset_variable)
+                next_init[var_loc] = res
+        print("possible_dest")
+        pp(possible_dest)
+        all_dest = list(itertools.product(*possible_dest))
+        print("all_dest")
+        pp(all_dest)
+        if not all_dest:
+            warnings.warn(
+                f"Guard hit for mode {mode} for agent {agent_id} without available next mode")
+            all_dest.append(None)
+        for dest in all_dest:
+            assert isinstance(dest, tuple)
+            paths = []
+            pure_dest = []
+            for d in dest:
+                if isinstance(d, tuple):
+                    pure_dest.append(d[0])
+                    paths.append(d[1])
+                else:
+                    pure_dest.append(d)
+            satisfied_guard.append((agent_id, pure_dest, next_init, paths))
+    return None, satisfied_guard
+
 @dataclass
 class ScenarioConfig:
     incremental: bool = False
@@ -66,6 +138,7 @@ class Scenario:
         self.uncertain_param_dict = {}
         self.map = LaneMap()
         self.sensor = BaseSensor()
+        self.past_runs = []
 
         # Parameters
         self.config = config
@@ -193,7 +266,9 @@ class Scenario:
             uncertain_param_list.append(self.uncertain_param_dict[agent_id])
             agent_list.append(self.agent_dict[agent_id])
         print(init_list)
-        return self.simulator.simulate(init_list, init_mode_list, static_list, uncertain_param_list, agent_list, self, time_horizon, time_step, self.map)
+        tree = self.simulator.simulate(init_list, init_mode_list, static_list, uncertain_param_list, agent_list, self, time_horizon, time_step, self.map, len(self.past_runs), self.past_runs)
+        self.past_runs.append(tree)
+        return tree
 
     def verify(self, time_horizon, time_step, reachability_method='DRYVR', params={}) -> AnalysisTree:
         self.check_init()
@@ -212,10 +287,10 @@ class Scenario:
             static_list.append(self.static_dict[agent_id])
             uncertain_param_list.append(self.uncertain_param_dict[agent_id])
             agent_list.append(self.agent_dict[agent_id])
-
-        res = self.verifier.compute_full_reachtube(init_list, init_mode_list, static_list, uncertain_param_list, agent_list, self, time_horizon,
-                                                   time_step, self.map, self.config.init_seg_length, self.config.reachability_method, params)
-        return res
+        tree = self.verifier.compute_full_reachtube(init_list, init_mode_list, static_list, uncertain_param_list, agent_list, self, time_horizon,
+                                                    time_step, self.map, self.config.init_seg_length, self.config.reachability_method, params, len(self.past_runs))
+        self.past_runs.append(tree)
+        return tree
 
     def apply_reset(self, agent: BaseAgent, reset_list, all_agent_state) -> Tuple[str, np.ndarray]:
         lane_map = self.map
@@ -326,106 +401,83 @@ class Scenario:
     #         unrolled_variable, unrolled_variable_index = updater[variable]
     #         disc_var_dict[unrolled_variable] = disc_var_dict[variable][unrolled_variable_index]
 
-    def get_transition_simulate_new(self, node: AnalysisTreeNode, cache: Dict[str, CachedSegment]) -> Tuple[Optional[Dict[str, List[str]]], Optional[Dict[str, List[Tuple[str, List[str], List[float]]]]], int]:
+    # def get_transition_simulate_new(self, diffed: Tuple[PathDiffs, PathDiffs, PathDiffs], cache: Dict[str, CachedSegment]) -> Tuple[Optional[Dict[str, List[str]]], Optional[Dict[str, List[Tuple[str, List[str], List[float]]]]], int]:
+    def get_transition_simulate_new(self, cache: Dict[str, CachedSegment], paths: PathDiffs, node: AnalysisTreeNode) -> Tuple[Optional[Dict[str, List[str]]], Optional[Dict[str, List[Tuple[str, List[str], List[float]]]]], int]:
         lane_map = self.map
         trace_length = len(list(node.trace.values())[0])
 
         # For each agent
         agent_guard_dict = defaultdict(list)
+        cached_guards = defaultdict(list)
 
-        for agent_id in node.agent:
+        if not cache:
+            paths = [(agent, p) for agent in node.agent.values() for p in agent.controller.paths]
+            path_transitions = {}
+        else:
+            if len(paths) == 0:
+                transition = min(trans.transition for seg in cache.values() for trans in seg.transitions)
+                transitions = defaultdict(list)
+                for agent_id, seg in cache.items():
+                    # TODO: check for asserts
+                    for tran in seg.transitions:
+                        if tran.transition == transition:
+                            for path in tran.paths:
+                                transitions[agent_id].append((agent_id, tran.disc, tran.cont, path))
+                return None, transitions, transition
+
+            path_transitions = defaultdict(int)
+            for seg in cache.values():
+                for tran in seg.transitions:
+                    for p in tran.paths:
+                        path_transitions[p.cond] = max(path_transitions[p.cond], tran.transition)
+            for agent_id, segment in cache.items():
+                agent = node.agent[agent_id]
+                agent_mode = node.mode[agent_id]
+                if len(agent.controller.args) == 0:
+                    continue
+                state_dict = {aid: (node.trace[aid][0], node.mode[aid], node.static[aid]) for aid in node.agent}
+                agent_paths = {p for tran in segment.transitions for p in tran.paths}
+                cont_var_dict_template, discrete_variable_dict, len_dict = self.sensor.sense(self, agent, state_dict, self.map)
+                for path in agent_paths:
+                    cached_guards[agent_id].append((path, discrete_variable_dict, path_transitions[path.cond]))
+
+        for agent, path in paths:
             # Get guard
-            agent: BaseAgent = self.agent_dict[agent_id]
+            agent_id = agent.id
             agent_mode = node.mode[agent_id]
             if len(agent.controller.args) == 0:
                 continue
-            state_dict = {}
-            for tmp in node.agent:
-                state_dict[tmp] = (node.trace[tmp][0],
-                                   node.mode[tmp], node.static[tmp])
-            cont_var_dict_template, discrete_variable_dict, len_dict = self.sensor.sense(
-                self, agent, state_dict, self.map)
-            paths = agent.controller.paths
-            for path in paths:
-                agent_guard_dict[agent_id].append(
-                    (path.cond, discrete_variable_dict, path.var, path.val))
+            state_dict = {aid: (node.trace[aid][0], node.mode[aid], node.static[aid]) for aid in node.agent}
+            cont_var_dict_template, discrete_variable_dict, len_dict = self.sensor.sense(self, agent, state_dict, self.map)
+            agent_guard_dict[agent_id].append((path, discrete_variable_dict))
 
         transitions = defaultdict(list)
         # TODO: We can probably rewrite how guard hit are detected and resets are handled for simulation
         for idx in range(trace_length):
             satisfied_guard = []
-            asserts = defaultdict(list)
+            all_asserts = defaultdict(list)
             for agent_id in agent_guard_dict:
                 agent: BaseAgent = self.agent_dict[agent_id]
-                state_dict = {}
-                for tmp in node.agent:
-                    state_dict[tmp] = (node.trace[tmp][idx],
-                                       node.mode[tmp], node.static[tmp])
+                state_dict = {aid: (node.trace[aid][0], node.mode[aid], node.static[aid]) for aid in node.agent}
                 agent_state, agent_mode, agent_static = state_dict[agent_id]
                 agent_state = agent_state[1:]
                 continuous_variable_dict, orig_disc_vars, _ = self.sensor.sense(
                     self, agent, state_dict, self.map)
-                # Unsafety checking
-                ego_ty_name = find(agent.controller.args, lambda a: a.name == EGO).typ
-                packed_env = pack_env(agent, ego_ty_name, continuous_variable_dict, orig_disc_vars, self.map)
-
-                # Check safety conditions
-                for assertion in agent.controller.asserts:
-                    if eval(assertion.pre, packed_env):
-                        if not eval(assertion.cond, packed_env):
-                            del packed_env["__builtins__"]
-                            print(
-                                f"assert hit for {agent_id}: \"{assertion.label}\" @ {packed_env}")
-                            asserts[agent_id].append(assertion.label)
-                if agent_id in asserts:
-                    continue
-
-                all_resets = defaultdict(list)
-                for guard_comp, discrete_variable_dict, var, reset in agent_guard_dict[agent_id]:
-                    new_cont_var_dict = copy.deepcopy(continuous_variable_dict)
-                    env = pack_env(agent, ego_ty_name, new_cont_var_dict, discrete_variable_dict, self.map)
-
-                    # Collect all the hit guards for this agent at this time step
-                    if eval(guard_comp, env):
-                        # If the guard can be satisfied, handle resets
-                        all_resets[var].append(reset)
-
-                iter_list = []
-                for reset_var in all_resets:
-                    iter_list.append(range(len(all_resets[reset_var])))
-                pos_list = list(itertools.product(*iter_list))
-                if len(pos_list) == 1 and pos_list[0] == ():
-                    continue
-                for i in range(len(pos_list)):
-                    pos = pos_list[i]
-                    next_init = copy.deepcopy(agent_state)
-                    dest = copy.deepcopy(agent_mode)
-                    possible_dest = [[elem] for elem in dest]
-                    for j, reset_idx in enumerate(pos):
-                        reset_variable = list(all_resets.keys())[j]
-                        res = eval(all_resets[reset_variable]
-                                   [reset_idx], packed_env)
-                        ego_type = agent.controller.state_defs[ego_ty_name]
-                        if "mode" in reset_variable:
-                            var_loc = ego_type.disc.index(reset_variable)
-                            if not isinstance(res, list):
-                                res = [res]
-                            possible_dest[var_loc] = res
-                        else:
-                            var_loc = ego_type.cont.index(reset_variable)
-                            next_init[var_loc] = res
-                    all_dest = list(itertools.product(*possible_dest))
-                    if not all_dest:
-                        warnings.warn(
-                            f"Guard hit for mode {agent_mode} for agent {agent_id} without available next mode")
-                        all_dest.append(None)
-                    for dest in all_dest:
-                        satisfied_guard.append((agent_id, dest, next_init))
-            if len(asserts) > 0:
-                return asserts, None, idx
+                unchecked_cache_guards = [g[:2] for g in cached_guards[agent_id] if g[2] < idx]     # FIXME: off by 1?
+                asserts, satisfied = check_transitions(agent, agent_guard_dict[agent_id] + unchecked_cache_guards, continuous_variable_dict, orig_disc_vars, self.map, agent_state, agent_mode)
+                if asserts != None:
+                    all_asserts[agent_id] = asserts
+                    return all_asserts, transitions, idx
+                if len(satisfied) != 0:
+                    satisfied_guard.extend(satisfied)
+            if len(all_asserts) > 0:
+                return all_asserts, transitions, idx
             if len(satisfied_guard) > 0:
-                for agent_idx, dest_mode, next_init in satisfied_guard:
-                    transitions[agent_idx].append((agent_idx, dest_mode, next_init))
+                print("satisfied_guard")
+                pp(satisfied_guard)
+                for agent_idx, dest_mode, next_init, paths in satisfied_guard:
+                    transitions[agent_idx].append((agent_idx, dest_mode, next_init, paths))
+                print("transitions", transitions)
                 break
         return None, transitions, idx
 
-- 
GitLab