diff --git a/demo/demo1.py b/demo/demo1.py index 21487dba13aa5034b0967437568223bd1ec35b65..b4c57e0c6047e4ab6d0feaec9a072a2697989757 100644 --- a/demo/demo1.py +++ b/demo/demo1.py @@ -69,7 +69,7 @@ if __name__ == "__main__": # plt.show() - traces = scenario.simulate(10) + traces = scenario.simulate(10, 0.01) fig = go.Figure() # fig = plotly_map(tmp_map, 'g', fig) # fig = plotly_simulation_tree( diff --git a/demo/demo3.py b/demo/demo3.py index 4500d5a003208005853b4d228217455a12374397..a459d2ecf1799151f86ef91b974ddcc3dbc5eb1a 100644 --- a/demo/demo3.py +++ b/demo/demo3.py @@ -35,14 +35,14 @@ class State: v = 0.0 vehicle_mode: VehicleMode = VehicleMode.Normal lane_mode: LaneMode = LaneMode.Lane0 - type: LaneObjectMode = LaneObjectMode.Vehicle + type_mode: LaneObjectMode = LaneObjectMode.Vehicle - def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode, type: LaneObjectMode): + def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode, type_mode: LaneObjectMode): pass if __name__ == "__main__": - input_code_name = './demo/example_controller4.py' + input_code_name = './example_controller4.py' scenario = Scenario() car = CarAgent('car1', file_name=input_code_name) @@ -63,14 +63,14 @@ if __name__ == "__main__": [[30, 0, 0, 0.5],[30, 0, 0, 0.5]], ], [ - (VehicleMode.Normal, LaneMode.Lane1), - (VehicleMode.Normal, LaneMode.Lane1), - (VehicleMode.Normal, LaneMode.Lane0), - (VehicleMode.Normal, LaneMode.Lane1), + (VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Vehicle), + (VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Vehicle), + (VehicleMode.Normal, LaneMode.Lane0, LaneObjectMode.Vehicle), + (VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Vehicle), ] ) - traces = scenario.simulate(70) - # traces = scenario.verify(70) + traces = scenario.simulate(70, 0.05) + # traces = scenario.verify(70, 0.05) # fig = plt.figure(2) # fig = plot_map(tmp_map, 'g', fig) diff --git a/demo/demo4.py b/demo/demo4.py index db61a07ce61085bf52b8be47483e8b5be84b0d08..a0025e590cdc268c8ae6d27accc5bed50bd614d0 100644 --- a/demo/demo4.py +++ b/demo/demo4.py @@ -79,8 +79,8 @@ if __name__ == "__main__": (VehicleMode.Normal, LaneMode.Lane3), ] ) - traces = scenario.simulate(80) - # traces = scenario.verify(80) + traces = scenario.simulate(80, 0.05) + # traces = scenario.verify(80, 0.05) # fig = plt.figure(2) # fig = plot_map(tmp_map, 'g', fig) diff --git a/demo/demo5.py b/demo/demo5.py index b4412151744fe597af87660b19a93227c10b7c7c..8c9ffdd990c1db4a8bd528114f2ba5f66204eac0 100644 --- a/demo/demo5.py +++ b/demo/demo5.py @@ -29,19 +29,6 @@ class LaneMode(Enum): Lane1 = auto() Lane2 = auto() -class State: - x = 0.0 - y = 0.0 - theta = 0.0 - v = 0.0 - vehicle_mode: VehicleMode = VehicleMode.Normal - lane_mode: LaneMode = LaneMode.Lane0 - type: LaneObjectMode = LaneObjectMode.Vehicle - - def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode, type: LaneObjectMode): - pass - - if __name__ == "__main__": input_code_name = './example_controller7.py' scenario = Scenario() @@ -72,8 +59,8 @@ if __name__ == "__main__": (VehicleMode.Normal, LaneMode.Lane0), ] ) - # traces = scenario.simulate(70) - traces = scenario.verify(60) + # traces = scenario.simulate(70, 0.05) + traces = scenario.verify(60, 0.05) fig = plt.figure(2) fig = plot_map(tmp_map, 'g', fig) diff --git a/demo/demo6.py b/demo/demo6.py index 4030efdfb6a143907cda881758a824f4c210244a..b2e6563b9a16357921aeab17038cbce4dcdcbdb0 100644 --- a/demo/demo6.py +++ b/demo/demo6.py @@ -30,17 +30,17 @@ class LaneMode(Enum): Lane2 = auto() Lane3 = auto() -class State: - x = 0.0 - y = 0.0 - theta = 0.0 - v = 0.0 - vehicle_mode: VehicleMode = VehicleMode.Normal - lane_mode: LaneMode = LaneMode.Lane0 - type: LaneObjectMode = LaneObjectMode.Vehicle +# class State: +# x = 0.0 +# y = 0.0 +# theta = 0.0 +# v = 0.0 +# vehicle_mode: VehicleMode = VehicleMode.Normal +# lane_mode: LaneMode = LaneMode.Lane0 +# type: LaneObjectMode = LaneObjectMode.Vehicle - def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode, type: LaneObjectMode): - pass +# def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode, type: LaneObjectMode): +# pass if __name__ == "__main__": @@ -83,8 +83,8 @@ if __name__ == "__main__": (VehicleMode.Normal, LaneMode.Lane3), ] ) - # traces = scenario.simulate(80) - traces = scenario.verify(50) + # traces = scenario.simulate(80, 0.05) + traces = scenario.verify(50, 0.05) fig = plt.figure(2) fig = plot_map(tmp_map, 'g', fig) diff --git a/demo/demo7.py b/demo/demo7.py index b3b5c0c43f928138d6260386fbd83b2c96d64642..3d6b1de9b250575aeec0c2b059a4edf83b6246fb 100644 --- a/demo/demo7.py +++ b/demo/demo7.py @@ -89,8 +89,8 @@ if __name__ == "__main__": (VehicleMode.Normal, LaneMode.Lane3), ] ) - traces = scenario.simulate(80) - # traces = scenario.verify(15) + traces = scenario.simulate(80, 0.05) + # traces = scenario.verify(15, 0.05) # fig = plt.figure(2) # fig = plot_map(tmp_map, 'g', fig) diff --git a/demo/demo8.py b/demo/demo8.py new file mode 100644 index 0000000000000000000000000000000000000000..99b4db8f2ab3251a48ec93afc2a9a169f12cde16 --- /dev/null +++ b/demo/demo8.py @@ -0,0 +1,78 @@ +from dryvr_plus_plus.example.example_agent.car_agent import CarAgent, NPCAgent +from dryvr_plus_plus.example.example_agent.sign_agent import SignAgent +from dryvr_plus_plus.scene_verifier.scenario.scenario import Scenario +from dryvr_plus_plus.example.example_map.simple_map2 import SimpleMap2, SimpleMap3, SimpleMap5, SimpleMap6 +from dryvr_plus_plus.plotter.plotter2D import * +from dryvr_plus_plus.example.example_sensor.fake_sensor import FakeSensor3 + +import matplotlib.pyplot as plt +import plotly.graph_objects as go +import numpy as np +from enum import Enum, auto + +class LaneObjectMode(Enum): + Vehicle = auto() + Ped = auto() # Pedestrians + Sign = auto() # Signs, stop signs, merge, yield etc. + Signal = auto() # Traffic lights + Obstacle = auto() # Static (to road/lane) obstacles + +class VehicleMode(Enum): + Normal = auto() + SwitchLeft = auto() + SwitchRight = auto() + Brake = auto() + +class LaneMode(Enum): + Lane0 = auto() + Lane1 = auto() + Lane2 = auto() + +class State: + x:float + y:float + theta:float + v:float + vehicle_mode: VehicleMode + lane_mode: LaneMode + type_mode: LaneObjectMode + + def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode, type_mode: LaneObjectMode): + pass + + +if __name__ == "__main__": + input_code_name = './example_controller9.py' + scenario = Scenario() + + car = CarAgent('car1', file_name=input_code_name) + scenario.add_agent(car) + sign = SignAgent('stop_sign') + scenario.add_agent(sign) + tmp_map = SimpleMap3() + scenario.set_map(tmp_map) + scenario.set_init( + [ + [[0, 0, 0, 1.0],[0, 0, 0, 1.0]], + [[15, 0, 0, 0.5],[15, 0, 0, 0.5]], + ], + [ + (VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Vehicle), + (VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Sign), + ] + ) + traces = scenario.simulate(13, 0.05) + # traces = scenario.verify(70, 0.05) + + # fig = plt.figure(2) + # fig = plot_map(tmp_map, 'g', fig) + # fig = plot_reachtube_tree(traces, 'car1', 1, [2], 'b', fig) + # fig = plot_reachtube_tree(traces, 'car2', 1, [2], 'r', fig) + # fig = plot_reachtube_tree(traces, 'car3', 1, [2], 'r', fig) + # fig = plot_reachtube_tree(traces, 'car4', 1, [2], 'r', fig) + # plt.show() + + fig = go.Figure() + fig = plotly_simulation_anime(traces, tmp_map, fig) + fig.show() + diff --git a/demo/example_controller3.py b/demo/example_controller3.py index 1f606e0c0463ed5e60950d9b481d5036f761f27f..95209827ff4bd10619d39a3a734cc525d85c81a7 100644 --- a/demo/example_controller3.py +++ b/demo/example_controller3.py @@ -27,9 +27,9 @@ class State: v = 0.0 vehicle_mode: VehicleMode = VehicleMode.Normal lane_mode: LaneMode = LaneMode.Lane0 - type: LaneObjectMode + type_mode: LaneObjectMode - def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode, type: LaneObjectMode): + def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode, type_mode: LaneObjectMode): self.data = [] def controller(ego:State, other:State, sign:State, lane_map:LaneMap): diff --git a/demo/example_controller4.py b/demo/example_controller4.py index 5f7f8f5a1885648d2e7a100b19073961cb6c9450..8645d7d812ffefe3237801a8aeb563d5deb01deb 100644 --- a/demo/example_controller4.py +++ b/demo/example_controller4.py @@ -27,15 +27,15 @@ class State: v = 0.0 vehicle_mode: VehicleMode = VehicleMode.Normal lane_mode: LaneMode = LaneMode.Lane0 - type: LaneObjectMode = LaneObjectMode.Vehicle + type_mode: LaneObjectMode = LaneObjectMode.Vehicle - def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode, type: LaneObjectMode): + def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode, type_mode: LaneObjectMode): pass def controller(ego:State, others:List[State], lane_map): output = copy.deepcopy(ego) if ego.vehicle_mode == VehicleMode.Normal: - if any((other.x-ego.x > 3 and other.x-ego.x < 5 and ego.lane_mode == other.lane_mode) for other in others): + if any((other.x-ego.x > 3 and other.x-ego.x < 5 and ego.lane_mode == other.lane_mode and other.type_mode==LaneObjectMode.Vehicle) for other in others): if lane_map.has_left(ego.lane_mode): output.vehicle_mode = VehicleMode.SwitchLeft if any((other.x-ego.x > 3 and other.x-ego.x < 5 and ego.lane_mode == other.lane_mode) for other in others): diff --git a/demo/example_controller5.py b/demo/example_controller5.py index 4294de0e7c0ef59c338a58d637b67898d39055c5..1e8f409efacacc066da76784e93bebc498b7b6d2 100644 --- a/demo/example_controller5.py +++ b/demo/example_controller5.py @@ -27,9 +27,9 @@ class State: v = 0.0 vehicle_mode: VehicleMode = VehicleMode.Normal lane_mode: LaneMode = LaneMode.Lane0 - type: LaneObjectMode = LaneObjectMode.Vehicle + type_mode: LaneObjectMode = LaneObjectMode.Vehicle - def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode, type: LaneObjectMode): + def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode, type_mode: LaneObjectMode): pass def controller(ego:State, others:List[State], lane_map): diff --git a/demo/example_controller7.py b/demo/example_controller7.py index 0c30737a351192832788586ef60b5c345df29494..655278de8f2a845d20f559881080547187120bb0 100644 --- a/demo/example_controller7.py +++ b/demo/example_controller7.py @@ -27,9 +27,9 @@ class State: v = 0.0 vehicle_mode: VehicleMode = VehicleMode.Normal lane_mode: LaneMode = LaneMode.Lane0 - type: LaneObjectMode = LaneObjectMode.Vehicle + type_mode: LaneObjectMode = LaneObjectMode.Vehicle - def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode, type: LaneObjectMode): + def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode, type_mode: LaneObjectMode): pass def controller(ego:State, others:List[State], lane_map): diff --git a/demo/example_controller9.py b/demo/example_controller9.py new file mode 100644 index 0000000000000000000000000000000000000000..1f7b38c95885533fc2c96f8fc249a15b18713f76 --- /dev/null +++ b/demo/example_controller9.py @@ -0,0 +1,47 @@ +from enum import Enum, auto +import copy +from typing import List + +class LaneObjectMode(Enum): + Vehicle = auto() + Ped = auto() # Pedestrians + Sign = auto() # Signs, stop signs, merge, yield etc. + Signal = auto() # Traffic lights + Obstacle = auto() # Static (to road/lane) obstacles + +class VehicleMode(Enum): + Normal = auto() + SwitchLeft = auto() + SwitchRight = auto() + Brake = auto() + Stop = auto() + +class LaneMode(Enum): + Lane0 = auto() + Lane1 = auto() + Lane2 = auto() + +class State: + x = 0.0 + y = 0.0 + theta = 0.0 + v = 0.0 + vehicle_mode: VehicleMode = VehicleMode.Normal + lane_mode: LaneMode = LaneMode.Lane0 + type_mode: LaneObjectMode = LaneObjectMode.Vehicle + + def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode, type_mode: LaneObjectMode): + pass + +def controller(ego:State, others:List[State], lane_map): + output = copy.deepcopy(ego) + # Detect the stop sign + if ego.vehicle_mode == VehicleMode.Normal: + if any(other.x - ego.x < 5 and other.x - ego.x > -1 and other.type_mode==LaneObjectMode.Sign for other in others): + output.vehicle_mode = VehicleMode.Brake + if ego.vehicle_mode == VehicleMode.Brake: + if ego.v <= 0: + output.vehicle_mode = VehicleMode.Stop + output.v = 0 + return output + diff --git a/dryvr_plus_plus/example/example_agent/ball_agent.py b/dryvr_plus_plus/example/example_agent/ball_agent.py index 9fd550e84538fa409b3cbf8220ce194d9819a57e..e4658ba8452e10dd67f02acf79b2d5322bb4f54d 100644 --- a/dryvr_plus_plus/example/example_agent/ball_agent.py +++ b/dryvr_plus_plus/example/example_agent/ball_agent.py @@ -31,10 +31,9 @@ class BallAgent(BaseAgent): vy_dot = 0 return [x_dot, y_dot, vx_dot, vy_dot] - def TC_simulate(self, mode: List[str], initialCondition, time_bound, lane_map:LaneMap=None)->np.ndarray: - # P1. Should TC_simulate really be part of the agent definition or should it be something more generic? - time_step = 0.05 - # P2. Looks like this should be a global parameter; some config file should be setting this. + def TC_simulate(self, mode: List[str], initialCondition, time_bound, time_step, lane_map:LaneMap=None)->np.ndarray: + # TODO: P1. Should TC_simulate really be part of the agent definition or should it be something more generic? + # TODO: P2. Looks like this should be a global parameter; some config file should be setting this. time_bound = float(time_bound) number_points = int(np.ceil(time_bound/time_step)) t = [round(i*time_step,10) for i in range(0,number_points)] diff --git a/dryvr_plus_plus/example/example_agent/car_agent.py b/dryvr_plus_plus/example/example_agent/car_agent.py index 3bee15fcc786d414fb599877dd039df23ba48e34..dde521e454e420d0e9a541a579c0b4d40592f101 100644 --- a/dryvr_plus_plus/example/example_agent/car_agent.py +++ b/dryvr_plus_plus/example/example_agent/car_agent.py @@ -37,8 +37,7 @@ class NPCAgent(BaseAgent): a = 0 return steering, a - def TC_simulate(self, mode: List[str], initialCondition, time_bound, lane_map:LaneMap=None)->np.ndarray: - time_step = 0.05 + def TC_simulate(self, mode: List[str], initialCondition, time_bound, time_step, lane_map:LaneMap=None)->np.ndarray: time_bound = float(time_bound) number_points = int(np.ceil(time_bound/time_step)) t = [i*time_step for i in range(0,number_points)] @@ -95,8 +94,7 @@ class CarAgent(BaseAgent): steering = np.clip(steering, -0.61, 0.61) return steering, a - def TC_simulate(self, mode: List[str], initialCondition, time_bound, lane_map:LaneMap=None)->np.ndarray: - time_step = 0.05 + def TC_simulate(self, mode: List[str], initialCondition, time_bound, time_step, lane_map:LaneMap=None)->np.ndarray: time_bound = float(time_bound) number_points = int(np.ceil(time_bound/time_step)) t = [round(i*time_step,10) for i in range(0,number_points)] diff --git a/dryvr_plus_plus/example/example_agent/sign_agent.py b/dryvr_plus_plus/example/example_agent/sign_agent.py index b6eb7fe383eb62f64a053312b5ef063f9cf40dab..2b97fa273ab00d7df9576fdd17a9fd21cbf2f0aa 100644 --- a/dryvr_plus_plus/example/example_agent/sign_agent.py +++ b/dryvr_plus_plus/example/example_agent/sign_agent.py @@ -7,8 +7,7 @@ class SignAgent(BaseAgent): self.id = id self.controller = EmptyAst() - def TC_simulate(self, mode, init, time_horizon, map=None): - time_step = 0.01 + def TC_simulate(self, mode, init, time_horizon, time_step, map=None): number_points = int(np.ceil(float(time_horizon)/time_step)) t = [i*time_step for i in range(0,number_points)] trace = [[0] + init] + [[i + time_step] + init for i in t] diff --git a/dryvr_plus_plus/scene_verifier/agents/base_agent.py b/dryvr_plus_plus/scene_verifier/agents/base_agent.py index 62d92d82e1a214554b24f6b1d45469d6e9d4cc68..96d912e91b5611cdda1785dd9c171b262c52b4df 100644 --- a/dryvr_plus_plus/scene_verifier/agents/base_agent.py +++ b/dryvr_plus_plus/scene_verifier/agents/base_agent.py @@ -5,5 +5,5 @@ class BaseAgent: self.controller = ControllerAst(code, file_name) self.id = id - def TC_simulate(self, mode, initialSet, time_horizon, map=None): + def TC_simulate(self, mode, initialSet, time_horizon, time_step, map=None): raise NotImplementedError \ No newline at end of file diff --git a/dryvr_plus_plus/scene_verifier/analysis/simulator.py b/dryvr_plus_plus/scene_verifier/analysis/simulator.py index c1099ecb985d9ec7f8d3471bfbe6406236917ca1..853122e70a76c806346dda383fa2fa90e47a8d98 100644 --- a/dryvr_plus_plus/scene_verifier/analysis/simulator.py +++ b/dryvr_plus_plus/scene_verifier/analysis/simulator.py @@ -11,7 +11,7 @@ class Simulator: def __init__(self): self.simulation_tree_root = None - def simulate(self, init_list, init_mode_list, agent_list:List[BaseAgent], transition_graph, time_horizon, lane_map): + def simulate(self, init_list, init_mode_list, agent_list:List[BaseAgent], transition_graph, time_horizon, time_step, lane_map): # Setup the root of the simulation tree root = AnalysisTreeNode( trace={}, @@ -43,7 +43,7 @@ class Simulator: # Simulate the trace starting from initial condition mode = node.mode[agent_id] init = node.init[agent_id] - trace = node.agent[agent_id].TC_simulate(mode, init, remain_time,lane_map) + trace = node.agent[agent_id].TC_simulate(mode, init, remain_time, time_step, lane_map) trace[:,0] += node.start_time node.trace[agent_id] = trace.tolist() diff --git a/dryvr_plus_plus/scene_verifier/analysis/verifier.py b/dryvr_plus_plus/scene_verifier/analysis/verifier.py index 123729a66b30cd660001a831f1f095f456c74adf..56a131b72dd8dd2883c0bd3ea146012c610afd10 100644 --- a/dryvr_plus_plus/scene_verifier/analysis/verifier.py +++ b/dryvr_plus_plus/scene_verifier/analysis/verifier.py @@ -21,6 +21,7 @@ class Verifier: agent_list:List[BaseAgent], transition_graph, time_horizon, + time_step, lane_map ): root = AnalysisTreeNode() @@ -53,6 +54,7 @@ class Verifier: cur_bloated_tube = calc_bloated_tube(mode, init, remain_time, + time_step, node.agent[agent_id].TC_simulate, 'PW', 100, diff --git a/dryvr_plus_plus/scene_verifier/code_parser/pythonparser.py b/dryvr_plus_plus/scene_verifier/code_parser/pythonparser.py index 3e5370ffbdb5922c880c29c6b4c009c722856e17..780c4bd2a2a06f5a730d6774de9ac82b3080e91f 100644 --- a/dryvr_plus_plus/scene_verifier/code_parser/pythonparser.py +++ b/dryvr_plus_plus/scene_verifier/code_parser/pythonparser.py @@ -249,7 +249,7 @@ class ControllerAst(): self.code = code self.tree = ast.parse(code) - self.statementtree, self.variables, self.modes, self.discrete_variables, self.state_object_dict, self.vars_dict, self.type_vars = self.initalwalktree(code, self.tree) + self.statementtree, self.variables, self.modes, self.discrete_variables, self.state_object_dict, self.vars_dict = self.initalwalktree(code, self.tree) self.vertices = [] self.vertexStrings = [] for vertex in itertools.product(*self.modes.values()): @@ -367,7 +367,6 @@ class ControllerAst(): def initalwalktree(self, code, tree): vars = [] discrete_vars = [] - type_vars = [] out = [] mode_dict = {} state_object_dict = {} @@ -390,9 +389,7 @@ class ControllerAst(): if "init" in item.name: for arg in item.args.args: if "self" not in arg.arg: - if "type" == arg.arg: - state_object_dict[node.name]["type"].append(arg.arg) - elif "mode" not in arg.arg: + if "mode" not in arg.arg: state_object_dict[node.name]['cont'].append(arg.arg) # vars.append(arg.arg) else: @@ -431,11 +428,8 @@ class ControllerAst(): for var in state_object_dict[arg_annotation]['disc']: discrete_vars.append(arg_name+"."+var) vars_dict[arg_name]['disc'].append(var) - for var in state_object_dict[arg_annotation]['type']: - type_vars.append(arg_name+"."+var) - vars_dict[arg_name]['type'].append(var) - - return [statementtree, vars, mode_dict, discrete_vars, state_object_dict, vars_dict, type_vars] + + return [statementtree, vars, mode_dict, discrete_vars, state_object_dict, vars_dict] ''' @@ -495,7 +489,6 @@ class EmptyAst(ControllerAst): 'type':[] } } - self.type_vars = [] self.variables = [] self.vars_dict = [] self.vertexStrings = ['Null,Normal'] diff --git a/dryvr_plus_plus/scene_verifier/dryvr/core/dryvrcore.py b/dryvr_plus_plus/scene_verifier/dryvr/core/dryvrcore.py index a52a0f324c7a2c0c9678ddba3f682f850e9d4205..bbb1ccfe00b07b393d6ae00e26a31d6fe7bd7530 100644 --- a/dryvr_plus_plus/scene_verifier/dryvr/core/dryvrcore.py +++ b/dryvr_plus_plus/scene_verifier/dryvr/core/dryvrcore.py @@ -245,6 +245,7 @@ def calc_bloated_tube( mode_label, initial_set, time_horizon, + time_step, sim_func, bloating_method, kvalue, @@ -275,11 +276,11 @@ def calc_bloated_tube( random.seed(4) cur_center = calcCenterPoint(initial_set[0], initial_set[1]) cur_delta = calcDelta(initial_set[0], initial_set[1]) - traces = [sim_func(mode_label, cur_center, time_horizon, lane_map)] + traces = [sim_func(mode_label, cur_center, time_horizon, time_step, lane_map)] # Simulate SIMTRACENUM times to learn the sensitivity for _ in range(sim_trace_num): new_init_point = randomPoint(initial_set[0], initial_set[1]) - traces.append(sim_func(mode_label, new_init_point, time_horizon, lane_map)) + traces.append(sim_func(mode_label, new_init_point, time_horizon, time_step, lane_map)) # Trim the trace to the same length traces = trimTraces(traces) diff --git a/dryvr_plus_plus/scene_verifier/scenario/scenario.py b/dryvr_plus_plus/scene_verifier/scenario/scenario.py index 1698eecc2fc2ee82341d935e38c98e101c8293f8..3b003429d3c5323a2b31dc0cd97b8898cd26f1df 100644 --- a/dryvr_plus_plus/scene_verifier/scenario/scenario.py +++ b/dryvr_plus_plus/scene_verifier/scenario/scenario.py @@ -65,8 +65,8 @@ class Scenario: trace = self.simulate(time_horizon) res_list.append(trace) return res_list - - def simulate(self, time_horizon): + + def simulate(self, time_horizon, time_step): init_list = [] init_mode_list = [] agent_list = [] @@ -75,9 +75,9 @@ class Scenario: init_mode_list.append(self.init_mode_dict[agent_id]) agent_list.append(self.agent_dict[agent_id]) print(init_list) - return self.simulator.simulate(init_list, init_mode_list, agent_list, self, time_horizon, self.map) + return self.simulator.simulate(init_list, init_mode_list, agent_list, self, time_horizon, time_step, self.map) - def verify(self, time_horizon): + def verify(self, time_horizon, time_step): init_list = [] init_mode_list = [] agent_list = [] @@ -89,7 +89,7 @@ class Scenario: init_list.append(init) init_mode_list.append(self.init_mode_dict[agent_id]) agent_list.append(self.agent_dict[agent_id]) - return self.verifier.compute_full_reachtube(init_list, init_mode_list, agent_list, self, time_horizon, self.map) + return self.verifier.compute_full_reachtube(init_list, init_mode_list, agent_list, self, time_horizon, time_step, self.map) def check_guard_hit(self, state_dict): lane_map = self.map