From 79a94c69f2a5ffde151ba1fe0fc36fcd3ff3b61e Mon Sep 17 00:00:00 2001
From: keyis2 <keyis2@illinois.edu>
Date: Sat, 28 Jan 2023 22:00:32 +0800
Subject: [PATCH] cont

---
 demo/AEB/exp2_straight.py            | 14 +++++-----
 demo/tacas2023/exp11/inc-expr.py     |  9 ++++---
 verse/analysis/analysis_tree.py      | 40 ++++++++++++++++++++++++++++
 verse/analysis/simulator.py          | 12 ++++++---
 verse/analysis/verifier.py           |  9 +++----
 verse/map/example_map/simple_map2.py |  1 +
 verse/plotter/plotter2D.py           | 17 ++++++------
 7 files changed, 77 insertions(+), 25 deletions(-)

diff --git a/demo/AEB/exp2_straight.py b/demo/AEB/exp2_straight.py
index dbc93198..8167b546 100644
--- a/demo/AEB/exp2_straight.py
+++ b/demo/AEB/exp2_straight.py
@@ -5,6 +5,7 @@ from verse.scenario import ScenarioConfig
 # from noisy_sensor import NoisyVehicleSensor
 from verse.plotter.plotter2D import *
 import os
+import ray
 
 from enum import Enum, auto
 import time
@@ -75,10 +76,10 @@ if __name__ == "__main__":
 #            (AgentMode.Normal, TrackMode.T0),
         ]
     )
-
+    ray.init(include_dashboard=True)
     start_time = time.time()
-    # traces = scenario.verify(40, 0.1, params={"bloating_method": 'GLOBAL'})
-    traces = scenario.simulate(100,0.1)
+    traces = scenario.verify(20, 0.1, params={"bloating_method": 'GLOBAL'})
+    # traces = scenario.simulate(100,0.1)
     run_time = time.time()-start_time 
     traces.dump(parent_dir+'/sim_straight.json')
 
@@ -93,7 +94,8 @@ if __name__ == "__main__":
     })
 
     fig = go.Figure()
-    fig = simulation_tree(traces, tmp_map, fig, 1, 2, None, 'lines', 'trace')
+    # fig = simulation_tree(traces, tmp_map, fig, 1, 2, None, 'lines', 'trace')
     # fig = simulation_anime(traces, tmp_map, fig, 1, 2,None, 'lines', 'trace', time_step=0.1)
-    # fig = reachtube_anime(traces, tmp_map, fig, 1, 2, None,'lines', 'trace', combine_rect=1)
-    fig.show()
\ No newline at end of file
+    fig = reachtube_tree(traces, tmp_map, fig, 1, 2, None,'lines', 'trace', combine_rect=1)
+    fig.show()
+    ray.shutdown()
\ No newline at end of file
diff --git a/demo/tacas2023/exp11/inc-expr.py b/demo/tacas2023/exp11/inc-expr.py
index 43eb2d6b..cacbc6cc 100644
--- a/demo/tacas2023/exp11/inc-expr.py
+++ b/demo/tacas2023/exp11/inc-expr.py
@@ -12,6 +12,7 @@ from verse.scenario.scenario import ScenarioConfig
 import functools, pprint
 pp = functools.partial(pprint.pprint, compact=True, width=130)
 from typing import List
+import ray
 
 class AgentMode(Enum):
     Normal = auto()
@@ -61,9 +62,10 @@ if 'p' in arg:
 def run(sim, meas=False):
     time = timeit.default_timer()
     if sim:
-        traces = scenario.simulate(60, 0.1)
+        traces = scenario.simulate(60, 0.1, seed=4)
     else:
-        traces = scenario.verify(60, 0.1)
+        traces = scenario.verify(1, 0.1)
+    dur = timeit.default_timer() - time
 
     if 'd' in arg:
         traces.dump_tree()
@@ -84,7 +86,7 @@ def run(sim, meas=False):
         cache_size = asizeof.asizeof(scenario.verifier.cache) + asizeof.asizeof(scenario.verifier.trans_cache)
     if meas:
         pp({
-            "dur": timeit.default_timer() - time,
+            "dur": dur,
             "cache_size": cache_size,
             "node_count": ((0 if sim else scenario.verifier.num_transitions), len(traces.nodes)),
             "hits": scenario.simulator.cache_hits if sim else (scenario.verifier.tube_cache_hits, scenario.verifier.trans_cache_hits),
@@ -164,6 +166,7 @@ if __name__ == "__main__":
         cont_inits = jerks(cont_inits, _jerks)
     scenario.set_init(cont_inits, *mode_inits)
 
+    ray.init()
     if 'b' in arg:
         run(sim, True)
     elif 'r' in arg:
diff --git a/verse/analysis/analysis_tree.py b/verse/analysis/analysis_tree.py
index 46a528a8..1da051c5 100644
--- a/verse/analysis/analysis_tree.py
+++ b/verse/analysis/analysis_tree.py
@@ -1,6 +1,7 @@
 from typing import List, Dict, Any
 import json
 from treelib import Tree
+import numpy as np
 
 class AnalysisTreeNode:
     """AnalysisTreeNode class
@@ -94,6 +95,35 @@ class AnalysisTreeNode:
             type = data['type'],
         )
 
+    def __eq__(self, __o: object) -> bool:
+        assert isinstance(__o, AnalysisTreeNode)  
+        if not (self.init==__o.init and 
+                self.mode==__o.mode and 
+                self.agent==__o.agent and 
+                self.start_time==__o.start_time and 
+                self.assert_hits==__o.assert_hits and 
+                self.type==__o.type and 
+                self.static==__o.static and 
+                self.uncertain_param==__o.uncertain_param and 
+                self.id==__o.id):
+            return False
+        if self.type=='simtrace':
+            for agent, trace in self.trace.items():
+                trace_other = __o.trace[agent]
+                for (step, step_other) in zip(trace, trace_other):
+                    if not np.allclose(step, step_other, equal_nan=True):
+                        print("diff in trace:", step, step_other)
+                        return False
+        elif self.type=='reachtube':
+            for agent, trace in self.trace.items():
+                trace_other = __o.trace[agent]
+                for (step, step_other) in zip(trace, trace_other):
+                    if not np.allclose(step, step_other, equal_nan=True):
+                        print("diff in trace:", step, step_other)
+                        return False
+        else:
+            raise ValueError
+
 class AnalysisTree:
     def __init__(self, root):
         self.root:AnalysisTreeNode = root
@@ -162,3 +192,13 @@ class AnalysisTree:
         for child in node.child:
             nid = AnalysisTree._dump_tree(child, tree, id, nid)
         return nid + 1
+
+
+    def __eq__(self, __o: object) -> bool:
+        assert isinstance(__o, AnalysisTree)
+        if len(self.nodes) != len(__o.nodes):
+            return False
+        for (node, node_other) in zip(self.nodes, __o.nodes):
+            if not (node == node_other):
+                return False
+        return True
diff --git a/verse/analysis/simulator.py b/verse/analysis/simulator.py
index f3108400..9a838ec2 100644
--- a/verse/analysis/simulator.py
+++ b/verse/analysis/simulator.py
@@ -1,5 +1,5 @@
 from typing import Dict, List, Optional, Tuple
-import copy, itertools, functools, pprint, ray
+import copy, itertools, functools, pprint, ray, time
 
 from verse.agents.base_agent import BaseAgent
 from verse.analysis.incremental import SimTraceCache, convert_sim_trans, to_simulate
@@ -164,7 +164,7 @@ class Simulator:
             print(len(next_nodes))
             return (node.id, next_nodes, node.trace)
 
-    def simulate(self, init_list, init_mode_list, static_list, uncertain_param_list, agent_list,
+    def simulate_par(self, init_list, init_mode_list, static_list, uncertain_param_list, agent_list,
                  transition_graph, time_horizon, time_step, lane_map, run_num, past_runs):
         # Setup the root of the simulation tree
         root = AnalysisTreeNode(
@@ -177,6 +177,7 @@ class Simulator:
             child=[],
             start_time=0,
         )
+        start=time.perf_counter()
         for i, agent in enumerate(agent_list):
             root.init[agent.id] = init_list[i]
             init_mode = [elem.name for elem in init_mode_list[i]]
@@ -223,9 +224,11 @@ class Simulator:
                 result_refs = remaining
         
         self.simulation_tree = AnalysisTree(root)
+        end=time.perf_counter()
+        print("simulate time in (s):", end-start)
         return self.simulation_tree
 
-    def simulate_simple(self, init_list, init_mode_list, static_list, uncertain_param_list, agent_list,
+    def simulate(self, init_list, init_mode_list, static_list, uncertain_param_list, agent_list,
                  transition_graph, time_horizon, time_step, lane_map, run_num, past_runs):
         # Setup the root of the simulation tree
         root = AnalysisTreeNode(
@@ -238,6 +241,7 @@ class Simulator:
             child=[],
             start_time=0,
         )
+        # start=time.perf_counter()
         for i, agent in enumerate(agent_list):
             root.init[agent.id] = init_list[i]
             init_mode = [elem.name for elem in init_mode_list[i]]
@@ -355,5 +359,7 @@ class Simulator:
             # simulation_queue += node.child
         
         self.simulation_tree = AnalysisTree(root)
+        # end=time.perf_counter()
+        # print("simulate time in (s):", end-start)
         return self.simulation_tree
 
diff --git a/verse/analysis/verifier.py b/verse/analysis/verifier.py
index 42b83b66..aab8480d 100644
--- a/verse/analysis/verifier.py
+++ b/verse/analysis/verifier.py
@@ -201,6 +201,7 @@ class Verifier:
         asserts, all_possible_transitions = transition_graph.get_transition_verify(new_cache, paths_to_sim, node)
         # pp(("transitions:", [(t[0], t[2]) for t in all_possible_transitions]))
         node.assert_hits = asserts
+        print(asserts)
         if asserts != None:
             asserts, idx = asserts
             for agent in node.agent:
@@ -325,7 +326,7 @@ class Verifier:
             child=[],
             start_time = 0,
             ndigits = 10,
-            type = 'simtrace',
+            type = 'reachtube',
             id = 0
         )
         for i, agent in enumerate(agent_list):
@@ -334,7 +335,6 @@ class Verifier:
             root.static[agent.id] = [elem.name for elem in static_list[i]]
             root.uncertain_param[agent.id] = uncertain_param_list[i]
             root.agent[agent.id] = agent
-            root.type = 'reachtube'
         verification_queue = [root]
         result_refs = []
         nodes = [root]
@@ -377,7 +377,7 @@ class Verifier:
 
         return self.reachtube_tree
 
-    def compute_full_reachtube_simple(
+    def compute_full_reachtube_ser(
         self,
         init_list: List[float],
         init_mode_list: List[str],
@@ -405,7 +405,7 @@ class Verifier:
             child=[],
             start_time = 0,
             ndigits = 10,
-            type = 'simtrace',
+            type = 'reachtube',
             id = 0
         )
         # root = AnalysisTreeNode()
@@ -417,7 +417,6 @@ class Verifier:
             root.static[agent.id] = init_static
             root.uncertain_param[agent.id] = uncertain_param_list[i]
             root.agent[agent.id] = agent
-            root.type = 'reachtube'
         verification_queue = []
         verification_queue.append(root)
         num_calls = 0
diff --git a/verse/map/example_map/simple_map2.py b/verse/map/example_map/simple_map2.py
index f7293dbb..f83df332 100644
--- a/verse/map/example_map/simple_map2.py
+++ b/verse/map/example_map/simple_map2.py
@@ -127,6 +127,7 @@ class SimpleMap4(LaneMap):
         }
 
     def left_lane(self, lane_mode):
+        print("left_lane", lane_mode)
         return self.left_dict[lane_mode]
 
     def right_lane(self,lane_mode):
diff --git a/verse/plotter/plotter2D.py b/verse/plotter/plotter2D.py
index ea77607d..19729b27 100644
--- a/verse/plotter/plotter2D.py
+++ b/verse/plotter/plotter2D.py
@@ -492,6 +492,7 @@ def reachtube_tree(root: Union[AnalysisTree, AnalysisTreeNode], map=None, fig=go
                         showlegend=False,
                     ))
                     previous_mode[agent_id] = node.mode[agent_id]
+            print(node.assert_hits)
             if node.assert_hits != None and agent_id in node.assert_hits[0]:
                 fig.add_trace(go.Scatter(x=[trace[-1, x_dim]], y=[trace[-1, y_dim]],
                                          mode='markers+text',
@@ -636,14 +637,14 @@ def reachtube_anime(root: Union[AnalysisTree, AnalysisTreeNode], map=None, fig=g
                         showlegend=False,
                     ))
                     previous_mode[agent_id] = node.mode[agent_id]
-            if node.assert_hits != None and agent_id in node.assert_hits:
-                fig.add_trace(go.Scatter(x=[trace[-1, x_dim]], y=[trace[-1, y_dim]],
-                                         mode='markers+text',
-                                         text=['HIT:\n' +
-                                               a for a in node.assert_hits[agent_id]],
-                                         textfont={'color': 'black'},
-                                         marker={'size': 4, 'color': 'black'},
-                                         showlegend=False))
+                if node.assert_hits != None and agent_id in node.assert_hits:
+                    fig.add_trace(go.Scatter(x=[trace[-1, x_dim]], y=[trace[-1, y_dim]],
+                                            mode='markers+text',
+                                            text=['HIT:\n' +
+                                                a for a in node.assert_hits[agent_id]],
+                                            textfont={'color': 'black'},
+                                            marker={'size': 4, 'color': 'black'},
+                                            showlegend=False))
         queue += node.child
     if scale_type == 'trace':
         fig.update_xaxes(
-- 
GitLab