From 79a94c69f2a5ffde151ba1fe0fc36fcd3ff3b61e Mon Sep 17 00:00:00 2001 From: keyis2 <keyis2@illinois.edu> Date: Sat, 28 Jan 2023 22:00:32 +0800 Subject: [PATCH] cont --- demo/AEB/exp2_straight.py | 14 +++++----- demo/tacas2023/exp11/inc-expr.py | 9 ++++--- verse/analysis/analysis_tree.py | 40 ++++++++++++++++++++++++++++ verse/analysis/simulator.py | 12 ++++++--- verse/analysis/verifier.py | 9 +++---- verse/map/example_map/simple_map2.py | 1 + verse/plotter/plotter2D.py | 17 ++++++------ 7 files changed, 77 insertions(+), 25 deletions(-) diff --git a/demo/AEB/exp2_straight.py b/demo/AEB/exp2_straight.py index dbc93198..8167b546 100644 --- a/demo/AEB/exp2_straight.py +++ b/demo/AEB/exp2_straight.py @@ -5,6 +5,7 @@ from verse.scenario import ScenarioConfig # from noisy_sensor import NoisyVehicleSensor from verse.plotter.plotter2D import * import os +import ray from enum import Enum, auto import time @@ -75,10 +76,10 @@ if __name__ == "__main__": # (AgentMode.Normal, TrackMode.T0), ] ) - + ray.init(include_dashboard=True) start_time = time.time() - # traces = scenario.verify(40, 0.1, params={"bloating_method": 'GLOBAL'}) - traces = scenario.simulate(100,0.1) + traces = scenario.verify(20, 0.1, params={"bloating_method": 'GLOBAL'}) + # traces = scenario.simulate(100,0.1) run_time = time.time()-start_time traces.dump(parent_dir+'/sim_straight.json') @@ -93,7 +94,8 @@ if __name__ == "__main__": }) fig = go.Figure() - fig = simulation_tree(traces, tmp_map, fig, 1, 2, None, 'lines', 'trace') + # fig = simulation_tree(traces, tmp_map, fig, 1, 2, None, 'lines', 'trace') # fig = simulation_anime(traces, tmp_map, fig, 1, 2,None, 'lines', 'trace', time_step=0.1) - # fig = reachtube_anime(traces, tmp_map, fig, 1, 2, None,'lines', 'trace', combine_rect=1) - fig.show() \ No newline at end of file + fig = reachtube_tree(traces, tmp_map, fig, 1, 2, None,'lines', 'trace', combine_rect=1) + fig.show() + ray.shutdown() \ No newline at end of file diff --git a/demo/tacas2023/exp11/inc-expr.py b/demo/tacas2023/exp11/inc-expr.py index 43eb2d6b..cacbc6cc 100644 --- a/demo/tacas2023/exp11/inc-expr.py +++ b/demo/tacas2023/exp11/inc-expr.py @@ -12,6 +12,7 @@ from verse.scenario.scenario import ScenarioConfig import functools, pprint pp = functools.partial(pprint.pprint, compact=True, width=130) from typing import List +import ray class AgentMode(Enum): Normal = auto() @@ -61,9 +62,10 @@ if 'p' in arg: def run(sim, meas=False): time = timeit.default_timer() if sim: - traces = scenario.simulate(60, 0.1) + traces = scenario.simulate(60, 0.1, seed=4) else: - traces = scenario.verify(60, 0.1) + traces = scenario.verify(1, 0.1) + dur = timeit.default_timer() - time if 'd' in arg: traces.dump_tree() @@ -84,7 +86,7 @@ def run(sim, meas=False): cache_size = asizeof.asizeof(scenario.verifier.cache) + asizeof.asizeof(scenario.verifier.trans_cache) if meas: pp({ - "dur": timeit.default_timer() - time, + "dur": dur, "cache_size": cache_size, "node_count": ((0 if sim else scenario.verifier.num_transitions), len(traces.nodes)), "hits": scenario.simulator.cache_hits if sim else (scenario.verifier.tube_cache_hits, scenario.verifier.trans_cache_hits), @@ -164,6 +166,7 @@ if __name__ == "__main__": cont_inits = jerks(cont_inits, _jerks) scenario.set_init(cont_inits, *mode_inits) + ray.init() if 'b' in arg: run(sim, True) elif 'r' in arg: diff --git a/verse/analysis/analysis_tree.py b/verse/analysis/analysis_tree.py index 46a528a8..1da051c5 100644 --- a/verse/analysis/analysis_tree.py +++ b/verse/analysis/analysis_tree.py @@ -1,6 +1,7 @@ from typing import List, Dict, Any import json from treelib import Tree +import numpy as np class AnalysisTreeNode: """AnalysisTreeNode class @@ -94,6 +95,35 @@ class AnalysisTreeNode: type = data['type'], ) + def __eq__(self, __o: object) -> bool: + assert isinstance(__o, AnalysisTreeNode) + if not (self.init==__o.init and + self.mode==__o.mode and + self.agent==__o.agent and + self.start_time==__o.start_time and + self.assert_hits==__o.assert_hits and + self.type==__o.type and + self.static==__o.static and + self.uncertain_param==__o.uncertain_param and + self.id==__o.id): + return False + if self.type=='simtrace': + for agent, trace in self.trace.items(): + trace_other = __o.trace[agent] + for (step, step_other) in zip(trace, trace_other): + if not np.allclose(step, step_other, equal_nan=True): + print("diff in trace:", step, step_other) + return False + elif self.type=='reachtube': + for agent, trace in self.trace.items(): + trace_other = __o.trace[agent] + for (step, step_other) in zip(trace, trace_other): + if not np.allclose(step, step_other, equal_nan=True): + print("diff in trace:", step, step_other) + return False + else: + raise ValueError + class AnalysisTree: def __init__(self, root): self.root:AnalysisTreeNode = root @@ -162,3 +192,13 @@ class AnalysisTree: for child in node.child: nid = AnalysisTree._dump_tree(child, tree, id, nid) return nid + 1 + + + def __eq__(self, __o: object) -> bool: + assert isinstance(__o, AnalysisTree) + if len(self.nodes) != len(__o.nodes): + return False + for (node, node_other) in zip(self.nodes, __o.nodes): + if not (node == node_other): + return False + return True diff --git a/verse/analysis/simulator.py b/verse/analysis/simulator.py index f3108400..9a838ec2 100644 --- a/verse/analysis/simulator.py +++ b/verse/analysis/simulator.py @@ -1,5 +1,5 @@ from typing import Dict, List, Optional, Tuple -import copy, itertools, functools, pprint, ray +import copy, itertools, functools, pprint, ray, time from verse.agents.base_agent import BaseAgent from verse.analysis.incremental import SimTraceCache, convert_sim_trans, to_simulate @@ -164,7 +164,7 @@ class Simulator: print(len(next_nodes)) return (node.id, next_nodes, node.trace) - def simulate(self, init_list, init_mode_list, static_list, uncertain_param_list, agent_list, + def simulate_par(self, init_list, init_mode_list, static_list, uncertain_param_list, agent_list, transition_graph, time_horizon, time_step, lane_map, run_num, past_runs): # Setup the root of the simulation tree root = AnalysisTreeNode( @@ -177,6 +177,7 @@ class Simulator: child=[], start_time=0, ) + start=time.perf_counter() for i, agent in enumerate(agent_list): root.init[agent.id] = init_list[i] init_mode = [elem.name for elem in init_mode_list[i]] @@ -223,9 +224,11 @@ class Simulator: result_refs = remaining self.simulation_tree = AnalysisTree(root) + end=time.perf_counter() + print("simulate time in (s):", end-start) return self.simulation_tree - def simulate_simple(self, init_list, init_mode_list, static_list, uncertain_param_list, agent_list, + def simulate(self, init_list, init_mode_list, static_list, uncertain_param_list, agent_list, transition_graph, time_horizon, time_step, lane_map, run_num, past_runs): # Setup the root of the simulation tree root = AnalysisTreeNode( @@ -238,6 +241,7 @@ class Simulator: child=[], start_time=0, ) + # start=time.perf_counter() for i, agent in enumerate(agent_list): root.init[agent.id] = init_list[i] init_mode = [elem.name for elem in init_mode_list[i]] @@ -355,5 +359,7 @@ class Simulator: # simulation_queue += node.child self.simulation_tree = AnalysisTree(root) + # end=time.perf_counter() + # print("simulate time in (s):", end-start) return self.simulation_tree diff --git a/verse/analysis/verifier.py b/verse/analysis/verifier.py index 42b83b66..aab8480d 100644 --- a/verse/analysis/verifier.py +++ b/verse/analysis/verifier.py @@ -201,6 +201,7 @@ class Verifier: asserts, all_possible_transitions = transition_graph.get_transition_verify(new_cache, paths_to_sim, node) # pp(("transitions:", [(t[0], t[2]) for t in all_possible_transitions])) node.assert_hits = asserts + print(asserts) if asserts != None: asserts, idx = asserts for agent in node.agent: @@ -325,7 +326,7 @@ class Verifier: child=[], start_time = 0, ndigits = 10, - type = 'simtrace', + type = 'reachtube', id = 0 ) for i, agent in enumerate(agent_list): @@ -334,7 +335,6 @@ class Verifier: root.static[agent.id] = [elem.name for elem in static_list[i]] root.uncertain_param[agent.id] = uncertain_param_list[i] root.agent[agent.id] = agent - root.type = 'reachtube' verification_queue = [root] result_refs = [] nodes = [root] @@ -377,7 +377,7 @@ class Verifier: return self.reachtube_tree - def compute_full_reachtube_simple( + def compute_full_reachtube_ser( self, init_list: List[float], init_mode_list: List[str], @@ -405,7 +405,7 @@ class Verifier: child=[], start_time = 0, ndigits = 10, - type = 'simtrace', + type = 'reachtube', id = 0 ) # root = AnalysisTreeNode() @@ -417,7 +417,6 @@ class Verifier: root.static[agent.id] = init_static root.uncertain_param[agent.id] = uncertain_param_list[i] root.agent[agent.id] = agent - root.type = 'reachtube' verification_queue = [] verification_queue.append(root) num_calls = 0 diff --git a/verse/map/example_map/simple_map2.py b/verse/map/example_map/simple_map2.py index f7293dbb..f83df332 100644 --- a/verse/map/example_map/simple_map2.py +++ b/verse/map/example_map/simple_map2.py @@ -127,6 +127,7 @@ class SimpleMap4(LaneMap): } def left_lane(self, lane_mode): + print("left_lane", lane_mode) return self.left_dict[lane_mode] def right_lane(self,lane_mode): diff --git a/verse/plotter/plotter2D.py b/verse/plotter/plotter2D.py index ea77607d..19729b27 100644 --- a/verse/plotter/plotter2D.py +++ b/verse/plotter/plotter2D.py @@ -492,6 +492,7 @@ def reachtube_tree(root: Union[AnalysisTree, AnalysisTreeNode], map=None, fig=go showlegend=False, )) previous_mode[agent_id] = node.mode[agent_id] + print(node.assert_hits) if node.assert_hits != None and agent_id in node.assert_hits[0]: fig.add_trace(go.Scatter(x=[trace[-1, x_dim]], y=[trace[-1, y_dim]], mode='markers+text', @@ -636,14 +637,14 @@ def reachtube_anime(root: Union[AnalysisTree, AnalysisTreeNode], map=None, fig=g showlegend=False, )) previous_mode[agent_id] = node.mode[agent_id] - if node.assert_hits != None and agent_id in node.assert_hits: - fig.add_trace(go.Scatter(x=[trace[-1, x_dim]], y=[trace[-1, y_dim]], - mode='markers+text', - text=['HIT:\n' + - a for a in node.assert_hits[agent_id]], - textfont={'color': 'black'}, - marker={'size': 4, 'color': 'black'}, - showlegend=False)) + if node.assert_hits != None and agent_id in node.assert_hits: + fig.add_trace(go.Scatter(x=[trace[-1, x_dim]], y=[trace[-1, y_dim]], + mode='markers+text', + text=['HIT:\n' + + a for a in node.assert_hits[agent_id]], + textfont={'color': 'black'}, + marker={'size': 4, 'color': 'black'}, + showlegend=False)) queue += node.child if scale_type == 'trace': fig.update_xaxes( -- GitLab