diff --git a/verse/analysis/analysis_tree.py b/verse/analysis/analysis_tree.py index c1d20a576671c7a71f76e4cfc3bd49e3e4dcb409..ddb7532f9adf9b192cbbae126c238d1ea3d49e90 100644 --- a/verse/analysis/analysis_tree.py +++ b/verse/analysis/analysis_tree.py @@ -110,8 +110,6 @@ class AnalysisTree: node_id = 0 while queue: node = queue.pop(0) - print("NODE: ", node) - node.id = node_id res.append(node) node_id += 1 queue += node.child diff --git a/verse/analysis/simulator.py b/verse/analysis/simulator.py index 26091880065f6ebe137b11d066b0c368bf4af714..74e8c794e942e3fb1ff8fba96066fda6a595aa24 100644 --- a/verse/analysis/simulator.py +++ b/verse/analysis/simulator.py @@ -1,14 +1,12 @@ -from typing import List, Tuple -import copy -import itertools -import functools +from typing import Dict, List, Optional, Tuple +import copy, itertools, functools, pprint, ray -import pprint from verse.agents.base_agent import BaseAgent - from verse.analysis.incremental import SimTraceCache, convert_sim_trans, to_simulate from verse.analysis.utils import dedup +from verse.map.lane_map import LaneMap from verse.parser.parser import ModePath, find +# from verse.scenario.scenario import Scenario pp = functools.partial(pprint.pprint, compact=True, width=130) # from verse.agents.base_agent import BaseAgent @@ -20,7 +18,6 @@ PathDiffs = List[Tuple[BaseAgent, ModePath]] def red(s): return "\x1b[31m" + s + "\x1b[0m" #]] - class Simulator: def __init__(self, config): self.simulation_tree = None @@ -28,6 +25,146 @@ class Simulator: self.config = config self.cache_hits = (0, 0) + @ray.remote + def simulate_one(self, node: AnalysisTreeNode, remain_time: float, time_step: float, lane_map: LaneMap, run_num: int, past_runs: List[AnalysisTree], transition_graph: "Scenario") -> Tuple[int, List[AnalysisTreeNode], Dict[str, list]]: + print(f"node id: {node.id}") + cached_segments = {} + for agent_id in node.agent: + mode = node.mode[agent_id] + init = node.init[agent_id] + if self.config.incremental: + # pp(("check hit", agent_id, mode, init)) + cached = self.cache.check_hit(agent_id, mode, init, node.init) + if cached != None: + self.cache_hits = self.cache_hits[0] + 1, self.cache_hits[1] + else: + self.cache_hits = self.cache_hits[0], self.cache_hits[1] + 1 + # pp(("check hit res", agent_id, len(cached.transitions) if cached != None else None)) + else: + cached = None + if agent_id in node.trace: + if cached != None: + cached_segments[agent_id] = cached + else: + if cached != None: + node.trace[agent_id] = cached.trace + if len(cached.trace) < remain_time / time_step: + node.trace[agent_id] += node.agent[agent_id].TC_simulate(mode, cached.trace[-1], remain_time - time_step * len(cached.trace), lane_map) + cached_segments[agent_id] = cached + else: + # pp(("sim", agent_id, *mode, *init)) + # Simulate the trace starting from initial condition + trace = node.agent[agent_id].TC_simulate( + mode, init, remain_time, time_step, lane_map) + trace[:, 0] += node.start_time + trace = trace.tolist() + node.trace[agent_id] = trace + # pp(("cached_segments", cached_segments.keys())) + # TODO: for now, make sure all the segments comes from the same node; maybe we can do + # something to combine results from different nodes in the future + node_ids = list(set((s.run_num, s.node_id) for s in cached_segments.values())) + # assert len(node_ids) <= 1, f"{node_ids}" + new_cache, paths_to_sim = {}, [] + if len(node_ids) == 1 and len(cached_segments.keys()) == len(node.agent): + old_run_num, old_node_id = node_ids[0] + if old_run_num != run_num: + old_node = find(past_runs[old_run_num].nodes, lambda n: n.id == old_node_id) + assert old_node != None + new_cache, paths_to_sim = to_simulate(old_node.agent, node.agent, cached_segments) + # pp(("to sim", new_cache.keys(), len(paths_to_sim))) + # else: + # print("!!!") + + asserts, transitions, transition_idx = transition_graph.get_transition_simulate(new_cache, paths_to_sim, node) + # pp(("transitions:", transition_idx, transitions)) + + node.assert_hits = asserts + # pp(("next init:", {a: trace[transition_idx] for a, trace in node.trace.items()})) + + # truncate the computed trajectories from idx and store the content after truncate + truncated_trace, full_traces = {}, {} + for agent_idx in node.agent: + full_traces[agent_idx] = node.trace[agent_idx] + if transitions: + truncated_trace[agent_idx] = node.trace[agent_idx][transition_idx:] + node.trace[agent_idx] = node.trace[agent_idx][:transition_idx+1] + + if asserts != None: # FIXME + return (node.id, [], node.trace) + # print(transition_idx) + # pp({a: len(t) for a, t in node.trace.items()}) + else: + # If there's no transitions (returned transitions is empty), continue + if not transitions: + if self.config.incremental: + for agent_id in node.agent: + if agent_id not in cached_segments: + self.cache.add_segment(agent_id, node, [], full_traces[agent_id], [], transition_idx, run_num) + # print(red("no trans")) + return (node.id, [], node.trace) + + transit_agents = transitions.keys() + # pp(("transit agents", transit_agents)) + if self.config.incremental: + for agent_id in node.agent: + transition = transitions[agent_id] if agent_id in transit_agents else [] + if agent_id in cached_segments: + cached_segments[agent_id].transitions.extend(convert_sim_trans(agent_id, transit_agents, node.init, transition, transition_idx)) + cached_segments[agent_id].transitions = dedup(cached_segments[agent_id].transitions, lambda i: (i.disc, i.cont, i.inits)) + # pre_len = len(cached_segments[agent_id].transitions) + # pp(("dedup!", pre_len, len(cached_segments[agent_id].transitions))) + else: + self.cache.add_segment(agent_id, node, transit_agents, full_traces[agent_id], transition, transition_idx, run_num) + # pp(("cached inits", self.cache.get_cached_inits(3))) + # Generate the transition combinations if multiple agents can transit at the same time step + transition_list = list(transitions.values()) + all_transition_combinations = itertools.product( + *transition_list) + + # For each possible transition, construct the new node. + # Obtain the new initial condition for agent having transition + # copy the traces that are not under transition + all_transition_paths = [] + next_nodes = [] + for transition_combination in all_transition_combinations: + transition_paths = [] + next_node_mode = copy.deepcopy(node.mode) + next_node_static = copy.deepcopy(node.static) + next_node_uncertain_param = copy.deepcopy(node.uncertain_param) + next_node_agent = node.agent + next_node_start_time = list( + truncated_trace.values())[0][0][0] + next_node_init = {} + next_node_trace = {} + for transition in transition_combination: + transit_agent_idx, dest_mode, next_init, paths = transition + if dest_mode is None: + continue + transition_paths.extend(paths) + # next_node = AnalysisTreeNode(trace = {},init={},mode={},agent={}, child = [], start_time = 0) + next_node_mode[transit_agent_idx] = dest_mode + next_node_init[transit_agent_idx] = next_init + for agent_idx in next_node_agent: + if agent_idx not in next_node_init: + next_node_trace[agent_idx] = truncated_trace[agent_idx] + next_node_init[agent_idx] = truncated_trace[agent_idx][0][1:] + + all_transition_paths.append(transition_paths) + tmp = AnalysisTreeNode( + trace=next_node_trace, + init=next_node_init, + mode=next_node_mode, + static=next_node_static, + uncertain_param=next_node_uncertain_param, + agent=next_node_agent, + child=[], + start_time=next_node_start_time, + type='simtrace' + ) + next_nodes.append(tmp) + print(len(next_nodes)) + return (node.id, next_nodes, node.trace) + def simulate(self, init_list, init_mode_list, static_list, uncertain_param_list, agent_list, transition_graph, time_horizon, time_step, max_height, lane_map, run_num, past_runs): # Setup the root of the simulation tree @@ -54,171 +191,41 @@ class Simulator: root.agent[agent.id] = agent root.type = 'simtrace' - simulation_queue = [] - simulation_queue.append(root) + root.id = 0 # FIXME + simulation_queue = [root] + result_refs = [] + nodes = [root] # Perform BFS through the simulation tree to loop through all possible transitions - while simulation_queue != []: - node: AnalysisTreeNode = simulation_queue.pop(0) - # Setup the root of the simulation tree - - # pp(("start sim", node.start_time, {a: (*node.mode[a], *node.init[a]) for a in node.mode})) - remain_time = round(time_horizon - node.start_time, 10) - if remain_time <= 0: - continue - # For trace not already simulated - cached_segments = {} - for agent_id in node.agent: - mode = node.mode[agent_id] - init = node.init[agent_id] - if self.config.incremental: - # pp(("check hit", agent_id, mode, init)) - cached = self.cache.check_hit(agent_id, mode, init, node.init) - if cached != None: - self.cache_hits = self.cache_hits[0] + 1, self.cache_hits[1] - else: - self.cache_hits = self.cache_hits[0], self.cache_hits[1] + 1 - # pp(("check hit res", agent_id, len(cached.transitions) if cached != None else None)) - else: - cached = None - if agent_id in node.trace: - if cached != None: - cached_segments[agent_id] = cached - else: - if cached != None: - node.trace[agent_id] = cached.trace - if len(cached.trace) < remain_time / time_step: - node.trace[agent_id] += node.agent[agent_id].TC_simulate(mode, cached.trace[-1], remain_time - time_step * len(cached.trace), lane_map) - cached_segments[agent_id] = cached - else: - # pp(("sim", agent_id, *mode, *init)) - # Simulate the trace starting from initial condition - trace = node.agent[agent_id].TC_simulate( - mode, init, remain_time, time_step, lane_map) - trace[:, 0] += node.start_time - trace = trace.tolist() - node.trace[agent_id] = trace - # pp(("cached_segments", cached_segments.keys())) - # TODO: for now, make sure all the segments comes from the same node; maybe we can do - # something to combine results from different nodes in the future - node_ids = list(set((s.run_num, s.node_id) for s in cached_segments.values())) - # assert len(node_ids) <= 1, f"{node_ids}" - new_cache, paths_to_sim = {}, [] - if len(node_ids) == 1 and len(cached_segments.keys()) == len(node.agent): - old_run_num, old_node_id = node_ids[0] - if old_run_num != run_num: - old_node = find(past_runs[old_run_num].nodes, lambda n: n.id == old_node_id) - assert old_node != None - new_cache, paths_to_sim = to_simulate(old_node.agent, node.agent, cached_segments) - # pp(("to sim", new_cache.keys(), len(paths_to_sim))) - # else: - # print("!!!") - - asserts, transitions, transition_idx = transition_graph.get_transition_simulate(new_cache, paths_to_sim, node) - # pp(("transitions:", transition_idx, transitions)) - - node.assert_hits = asserts - # pp(("next init:", {a: trace[transition_idx] for a, trace in node.trace.items()})) - - # truncate the computed trajectories from idx and store the content after truncate - truncated_trace, full_traces = {}, {} - for agent_idx in node.agent: - full_traces[agent_idx] = node.trace[agent_idx] - if transitions: - truncated_trace[agent_idx] = node.trace[agent_idx][transition_idx:] - node.trace[agent_idx] = node.trace[agent_idx][:transition_idx+1] - if (node.height >= max_height): - print("max depth reached") - continue - - if asserts != None: - pass - # print(transition_idx) - # pp({a: len(t) for a, t in node.trace.items()}) - else: - # If there's no transitions (returned transitions is empty), continue - if not transitions: - if self.config.incremental: - for agent_id in node.agent: - if agent_id not in cached_segments: - self.cache.add_segment(agent_id, node, [], full_traces[agent_id], [], transition_idx, run_num) - # print(red("no trans")) + while True: + wait = False + if len(simulation_queue) > 0: + node: AnalysisTreeNode = simulation_queue.pop(0) + # pp(("start sim", node.start_time, {a: (*node.mode[a], *node.init[a]) for a in node.mode})) + remain_time = round(time_horizon - node.start_time, 10) + if remain_time <= 0: continue - - - transit_agents = transitions.keys() - # pp(("transit agents", transit_agents)) - if self.config.incremental: - for agent_id in node.agent: - transition = transitions[agent_id] if agent_id in transit_agents else [] - if agent_id in cached_segments: - cached_segments[agent_id].transitions.extend(convert_sim_trans(agent_id, transit_agents, node.init, transition, transition_idx)) - pre_len = len(cached_segments[agent_id].transitions) - cached_segments[agent_id].transitions = dedup(cached_segments[agent_id].transitions, lambda i: (i.disc, i.cont, i.inits)) - # pp(("dedup!", pre_len, len(cached_segments[agent_id].transitions))) - else: - self.cache.add_segment(agent_id, node, transit_agents, full_traces[agent_id], transition, transition_idx, run_num) - # pp(("cached inits", self.cache.get_cached_inits(3))) - # Generate the transition combinations if multiple agents can transit at the same time step - transition_list = list(transitions.values()) - all_transition_combinations = itertools.product( - *transition_list) - - # For each possible transition, construct the new node. - # Obtain the new initial condition for agent having transition - # copy the traces that are not under transition - all_transition_paths = [] - for transition_combination in all_transition_combinations: - transition_paths = [] - next_node_mode = copy.deepcopy(node.mode) - next_node_static = copy.deepcopy(node.static) - next_node_uncertain_param = copy.deepcopy(node.uncertain_param) - next_node_agent = node.agent - next_node_start_time = list( - truncated_trace.values())[0][0][0] - next_node_init = {} - next_node_trace = {} - for transition in transition_combination: - transit_agent_idx, dest_mode, next_init, paths = transition - if dest_mode is None: - continue - transition_paths.extend(paths) - # next_node = AnalysisTreeNode(trace = {},init={},mode={},agent={}, child = [], start_time = 0) - next_node_mode[transit_agent_idx] = dest_mode - next_node_init[transit_agent_idx] = next_init - for agent_idx in next_node_agent: - if agent_idx not in next_node_init: - next_node_trace[agent_idx] = truncated_trace[agent_idx] - next_node_init[agent_idx] = truncated_trace[agent_idx][0][1:] - - all_transition_paths.append(transition_paths) - - tmp = AnalysisTreeNode( - trace=next_node_trace, - init=next_node_init, - mode=next_node_mode, - static=next_node_static, - uncertain_param=next_node_uncertain_param, - agent=next_node_agent, - height=node.height+1, - child=[], - start_time=next_node_start_time, - type='simtrace' - ) - node.child.append(tmp) - simulation_queue.append(tmp) - # print(red("end sim")) - # Put the node in the child of current node. Put the new node in the queue - # node.child.append(AnalysisTreeNode( - # trace = next_node_trace, - # init = next_node_init, - # mode = next_node_mode, - # agent = next_node_agent, - # child = [], - # start_time = next_node_start_time - # )) - # simulation_queue += node.child - #checkHeight(root, max_height) - + # For trace not already simulated + result_refs.append(self.simulate_one.remote(self, node, remain_time, time_step, lane_map, run_num, past_runs, transition_graph)) + if len(result_refs) >= self.config.parallel_sim_ahead: + wait = True + elif len(result_refs) > 0: + wait = True + else: + break + print(len(simulation_queue), len(result_refs)) + if wait: + [res], remaining = ray.wait(result_refs) + id, next_nodes, traces = ray.get(res) + print("got id:", id) + nodes[id].child = next_nodes + nodes[id].trace = traces + last_id = nodes[-1].id + for i, node in enumerate(next_nodes): + node.id = i + 1 + last_id + simulation_queue.extend(next_nodes) + nodes.extend(next_nodes) + result_refs = remaining + self.simulation_tree = AnalysisTree(root) return self.simulation_tree @@ -253,6 +260,9 @@ class Simulator: # Perform BFS through the simulation tree to loop through all possible transitions while simulation_queue != []: node: AnalysisTreeNode = simulation_queue.pop(0) + if (node.height >= max_height): + print("max depth reached") + continue #continue if we are at the depth limit pp(("start sim", node.start_time, {a: (*node.mode[a], *node.init[a]) for a in node.mode})) @@ -287,9 +297,6 @@ class Simulator: if transitions or asserts: truncated_trace[agent_idx] = node.trace[agent_idx][transition_idx:] node.trace[agent_idx] = node.trace[agent_idx][:transition_idx+1] - if (node.height >= max_height): - print("max depth reached") - continue if asserts != None: pass @@ -344,7 +351,7 @@ class Simulator: static=next_node_static, uncertain_param=next_node_uncertain_param, agent=next_node_agent, - height=node.height+1, + height=node.height + 1, child=[], start_time=next_node_start_time, type='simtrace' diff --git a/verse/scenario/scenario.py b/verse/scenario/scenario.py index 9637ebf171b18cb084e26027bba52043aa13f03c..1976fc7851ae7054551d8c9610bb7daacf05acbd 100644 --- a/verse/scenario/scenario.py +++ b/verse/scenario/scenario.py @@ -155,6 +155,7 @@ class ScenarioConfig: unsafe_continue: bool = False init_seg_length: int = 1000 reachability_method: str = 'DRYVR' + parallel_sim_ahead: int = 8 class Scenario: def __init__(self, config=ScenarioConfig()):