diff --git a/demo/example_controller4.py b/demo/example_controller4.py index d57ab38473d3313fd0937a4b9355c2a1133e30b2..dccf72553e4cf2a4b0dee63e0915ba666a90dbb3 100644 --- a/demo/example_controller4.py +++ b/demo/example_controller4.py @@ -32,7 +32,7 @@ class State: def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode, type_mode: LaneObjectMode): pass -def controller(ego:State, others:State, lane_map): +def controller(ego:State, others:List[State], lane_map): output = copy.deepcopy(ego) test = lambda other: other.x-ego.x > 3 and other.x-ego.x < 5 and ego.lane_mode == other.lane_mode if ego.vehicle_mode == VehicleMode.Normal: diff --git a/demo/example_controller5.py b/demo/example_controller5.py index 9d7bb281fe2fa3c169915c007d970090969328a9..1e8f409efacacc066da76784e93bebc498b7b6d2 100644 --- a/demo/example_controller5.py +++ b/demo/example_controller5.py @@ -32,7 +32,7 @@ class State: def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode, type_mode: LaneObjectMode): pass -def controller(ego:State, others:State, lane_map): +def controller(ego:State, others:List[State], lane_map): output = copy.deepcopy(ego) if ego.vehicle_mode == VehicleMode.Normal: if any((lane_map.get_longitudinal_position(other.lane_mode, [other.x,other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x,ego.y]) > 3 \ diff --git a/demo/example_controller6.py b/demo/example_controller6.py index abf5cacbc69fd9d7db6bf804208000942948b788..abb73f48d2cec3213d01ea389743989dcccac733 100644 --- a/demo/example_controller6.py +++ b/demo/example_controller6.py @@ -32,7 +32,7 @@ class State: def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode, type: LaneObjectMode): pass -def controller(ego:State, others:State, lane_map): +def controller(ego:State, others:List[State], lane_map): output = copy.deepcopy(ego) if ego.vehicle_mode == VehicleMode.Normal: if any((other.x-ego.x > 3 and other.x-ego.x < 5 and ego.lane_mode == other.lane_mode) for other in others): diff --git a/demo/example_controller7.py b/demo/example_controller7.py index 59249f704516f1aef2270cd0be1dcbde32a06248..a7f2a75b0e3957b41bf3e097eb15f3422ed9cc33 100644 --- a/demo/example_controller7.py +++ b/demo/example_controller7.py @@ -47,7 +47,7 @@ def car_right(ego, others, lane_map): lane_map.get_longitudinal_position(other.lane_mode, [other.x,other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x,ego.y]) >-3 and \ other.lane_mode==lane_map.right_lane(ego.lane_mode)) for other in others) -def controller(ego:State, others:State, lane_map): +def controller(ego:State, others:List[State], lane_map): output = copy.deepcopy(ego) if ego.vehicle_mode == VehicleMode.Normal: if car_front(ego, others, lane_map): diff --git a/demo/example_controller8.py b/demo/example_controller8.py index 35085827d80263f9b6a73fe5c17a7487f385819e..ae6ba309419ecb7721ac1a63c70e8f80254a2a57 100644 --- a/demo/example_controller8.py +++ b/demo/example_controller8.py @@ -48,7 +48,7 @@ def car_right(ego, others, lane_map): lane_map.get_longitudinal_position(other.lane_mode, [other.x,other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x,ego.y]) >-3 and \ other.lane_mode==lane_map.right_lane(ego.lane_mode)) for other in others) -def controller(ego:State, others:State, lane_map): +def controller(ego:State, others:List[State], lane_map): output = copy.deepcopy(ego) if ego.vehicle_mode == VehicleMode.Normal: # Switch left if left lane is empty diff --git a/demo/example_controller9.py b/demo/example_controller9.py index b1ce3f15630522d92b8e8c52917f20ea8d02b777..1f7b38c95885533fc2c96f8fc249a15b18713f76 100644 --- a/demo/example_controller9.py +++ b/demo/example_controller9.py @@ -33,7 +33,7 @@ class State: def __init__(self, x, y, theta, v, vehicle_mode: VehicleMode, lane_mode: LaneMode, type_mode: LaneObjectMode): pass -def controller(ego:State, others:State, lane_map): +def controller(ego:State, others:List[State], lane_map): output = copy.deepcopy(ego) # Detect the stop sign if ego.vehicle_mode == VehicleMode.Normal: diff --git a/dryvr_plus_plus/scene_verifier/code_parser/parser.py b/dryvr_plus_plus/scene_verifier/code_parser/parser.py index 0813b10b0fffef709f4f7e76643000b7bb697922..db50129afaa6f6e8fb36dacd36b4f616ad713022 100644 --- a/dryvr_plus_plus/scene_verifier/code_parser/parser.py +++ b/dryvr_plus_plus/scene_verifier/code_parser/parser.py @@ -143,13 +143,15 @@ class Lambda: and isinstance(a.annotation.value, ast.Name) \ and a.annotation.value.id == 'List': typ = handle_simple_ann(a.annotation.slice) + is_list = True else: typ = handle_simple_ann(a.annotation) - args.append((a.arg, typ)) + is_list = False + args.append((a.arg, typ, is_list)) else: - args.append((a.arg, None)) + args.append((a.arg, None, False)) env.push() - for a, typ in args: + for a, typ, is_list in args: env.add_hole(a, typ) ret = None if isinstance(tree, ast.FunctionDef): @@ -167,7 +169,7 @@ class Lambda: def apply(self, args: List[ast.expr]) -> Tuple[List[Assert], ast.expr]: ret = copy.deepcopy(self.body) - subst = ArgSubstituter({k: v for (k, _), v in zip(self.args, args)}) + subst = ArgSubstituter({k: v for (k, _, _), v in zip(self.args, args)}) ret = subst.visit(ret) def visit_assert(a: Assert): a = copy.deepcopy(a) diff --git a/dryvr_plus_plus/scene_verifier/scenario/scenario.py b/dryvr_plus_plus/scene_verifier/scenario/scenario.py index fc18057e91f63bf18dffc16193a73056f77d4a8e..3970e8034b7023ada715a0e18dde1da3d630b75d 100644 --- a/dryvr_plus_plus/scene_verifier/scenario/scenario.py +++ b/dryvr_plus_plus/scene_verifier/scenario/scenario.py @@ -256,7 +256,7 @@ class Scenario: continuous_variable_dict, orig_disc_vars, _ = self.sensor.sense(self, agent, state_dict, self.map) # Unsafety checking ego_ty_name = find(agent.controller.controller.args, lambda a: a[0] == EGO)[1] - def pack_env(agent, cont, disc, map): + def pack_env(agent: BaseAgent, cont, disc, map): env = copy.deepcopy(cont) env.update(disc) @@ -265,8 +265,15 @@ class Scenario: for k, v in env.items(): k = k.split(".") packed[k[0]][k[1]] = v - others_keys = list(packed[OTHERS].keys()) - packed[OTHERS] = [state_ty(**{k: packed[OTHERS][k][i] for k in others_keys}) for i in range(len(packed[OTHERS][others_keys[0]]))] + for arg, arg_type, is_list in agent.controller.controller.args: + if arg != EGO and 'map' not in arg: + other = arg + others_keys = list(packed[other].keys()) + if is_list: + packed[other] = [state_ty(**{k: packed[other][k][i] for k in others_keys}) for i in range(len(packed[other][others_keys[0]]))] + else: + packed[other] = state_ty(**{k: packed[other][k] for k in others_keys}) + packed[EGO] = state_ty(**packed[EGO]) map_var = find(agent.controller.controller.args, lambda a: "map" in a[0]) if map_var != None: diff --git a/dryvr_plus_plus/scene_verifier/sensor/base_sensor.py b/dryvr_plus_plus/scene_verifier/sensor/base_sensor.py index 2847b58fe6dadce346ed35aed9e98ee9844df456..52f054c8852cfec086c3fce7b811e79340cdf323 100644 --- a/dryvr_plus_plus/scene_verifier/sensor/base_sensor.py +++ b/dryvr_plus_plus/scene_verifier/sensor/base_sensor.py @@ -54,7 +54,7 @@ class BaseSensor(): controller_args = agent.controller.controller.args arg_type = None for arg in controller_args: - if arg[0] == 'ego': + if arg[0] != 'ego': arg_type = arg[1] break if arg_type is None: @@ -66,8 +66,10 @@ class BaseSensor(): else: controller_args = agent.controller.controller.args arg_type = None + arg_name = None for arg in controller_args: - if arg[0] == 'others': + if arg[0] != 'ego' and 'map' not in arg[0]: + arg_name = arg[0] arg_type = arg[1] break if arg_type is None: @@ -75,7 +77,7 @@ class BaseSensor(): cont_var = agent.controller.state_defs[arg_type].cont disc_var = agent.controller.state_defs[arg_type].disc stat_var = agent.controller.state_defs[arg_type].static - add_states_2d(cont, disc, 'others', state_dict[agent_id], cont_var, disc_var, stat_var) + add_states_2d(cont, disc, arg_name, state_dict[agent_id], cont_var, disc_var, stat_var) else: for agent_id in state_dict: @@ -96,8 +98,10 @@ class BaseSensor(): else: controller_args = agent.controller.controller.args arg_type = None + arg_name = None for arg in controller_args: - if arg[0] == 'others': + if arg[0] != 'ego' and 'map' not in arg[0]: + arg_name = arg[0] arg_type = arg[1] break if arg_type is None: @@ -105,6 +109,6 @@ class BaseSensor(): cont_var = agent.controller.state_defs[arg_type].cont disc_var = agent.controller.state_defs[arg_type].disc stat_var = agent.controller.state_defs[arg_type].static - add_states_3d(cont, disc, 'others', state_dict[agent_id], cont_var, disc_var, stat_var) + add_states_3d(cont, disc, arg_name, state_dict[agent_id], cont_var, disc_var, stat_var) return cont, disc, len_dict \ No newline at end of file