diff --git a/example_two_car_sign_lane_switch.py b/example_two_car_sign_lane_switch.py index c0f5a2dfafbfd625d286a87acc192c363b1f8a7e..9f62fc213dcd69798a947b44b305a304b838da15 100644 --- a/example_two_car_sign_lane_switch.py +++ b/example_two_car_sign_lane_switch.py @@ -29,20 +29,21 @@ class State: lane_mode: LaneMode obj_mode: LaneObjectMode - def __init__(self, x: float = 0, y: float = 0, theta: float = 0, v: float = 0, vehicle_mode: VehicleMode = VehicleMode.Normal, lane_mode: LaneMode = LaneMode.Lane0, obj_mode: LaneObjectMode = LaneObjectMode.Vehicle): - self.data = [] - self.x = x - self.y = y - self.theta = theta - self.v = v - self.vehicle_mode = vehicle_mode - self.lane_mode = lane_mode - self.obj_mode = obj_mode + def __init__(self, x: float = 0, y: float = 0, theta: float = 0, v: float = 0, vehicle_mode: VehicleMode = VehicleMode.Normal, lane_mode: LaneMode = LaneMode.Lane0, type: LaneObjectMode = LaneObjectMode.Vehicle): + pass + # self.data = [] + # self.x = x + # self.y = y + # self.theta = theta + # self.v = v + # self.vehicle_mode = vehicle_mode + # self.lane_mode = lane_mode + # self.obj_mode = obj_mode def controller(ego: State, other: State, sign: State, lane_map): output = copy.deepcopy(ego) if ego.vehicle_mode == VehicleMode.Normal: - if sign.x - ego.x < 3 and sign.x - ego.x > 0 and ego.lane_mode == sign.lane_mode: + if sign.type == LaneObjectMode.Obstacle and sign.x - ego.x < 3 and sign.x - ego.x > 0 and ego.lane_mode == sign.lane_mode: output.vehicle_mode = VehicleMode.SwitchLeft return output if lane_map.get_longitudinal_position(other.lane_mode, [other.x,other.y]) - lane_map.get_longitudinal_position(ego.lane_mode, [ego.x,ego.y]) > 3 \ @@ -95,9 +96,9 @@ if __name__ == "__main__": [[20, 3, 0, 0],[20, 3, 0, 0]], ], [ - (VehicleMode.Normal, LaneMode.Lane2), - (VehicleMode.Normal, LaneMode.Lane1), - (VehicleMode.Normal, LaneMode.Lane1), + (VehicleMode.Normal, LaneMode.Lane2, LaneObjectMode.Vehicle), + (VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Vehicle), + (VehicleMode.Normal, LaneMode.Lane1, LaneObjectMode.Obstacle), ] ) # simulator = Simulator() diff --git a/src/example/example_sensor/fake_sensor.py b/src/example/example_sensor/fake_sensor.py index cbdd47da0734a4fe800ca27739efc0f82bc5e5e5..6f580dc9fe0e9a3fb612cd40d38a1cf0cf3466db 100644 --- a/src/example/example_sensor/fake_sensor.py +++ b/src/example/example_sensor/fake_sensor.py @@ -20,14 +20,14 @@ def sets(d, thing, attrs, vals): def set_states_2d(cnts, disc, thing, val): state, mode = val sets(cnts, thing, ["x", "y", "theta", "v"], state[1:5]) - sets(disc, thing, ["vehicle_mode", "lane_mode"], mode.split(",")) + sets(disc, thing, ["vehicle_mode", "lane_mode", "type"], mode.split(",")) def set_states_3d(cnts, disc, thing, val): state, mode = val transp = np.transpose(np.array(state)[:, 1:5]) assert len(transp) == 4 sets(cnts, thing, ["x", "y", "theta", "v"], transp) - sets(disc, thing, ["vehicle_mode", "lane_mode"], mode.split(",")) + sets(disc, thing, ["vehicle_mode", "lane_mode", "type"], mode.split(",")) class FakeSensor2: def sense(self, scenario, agent, state_dict, lane_map): diff --git a/src/scene_verifier/code_parser/pythonparser.py b/src/scene_verifier/code_parser/pythonparser.py index d42a19f22abbb9e620de2f3934f484484779a638..0a7e2a9dedfda709333d4da3839528c2256e8e0a 100644 --- a/src/scene_verifier/code_parser/pythonparser.py +++ b/src/scene_verifier/code_parser/pythonparser.py @@ -250,7 +250,7 @@ class ControllerAst(): self.code = code self.tree = ast.parse(code) - self.statementtree, self.variables, self.modes, self.discrete_variables, self.state_object_dict, self.vars_dict = self.initalwalktree(code, self.tree) + self.statementtree, self.variables, self.modes, self.discrete_variables, self.state_object_dict, self.vars_dict, self.type_vars = self.initalwalktree(code, self.tree) self.vertices = [] self.vertexStrings = [] for vertex in itertools.product(*self.modes.values()): @@ -274,7 +274,7 @@ class ControllerAst(): A path is a list of statements, all guards and resets along the path. They are in the order they are encountered in the code. TODO: should we not force all modes be listed? Or rerun for each unknown/don't care node? Or add them all to the list ''' - def getNextModes(self, currentModes, getAllPaths= False): + def getNextModes(self, currentModes: List[str], getAllPaths= False) -> List[str]: #walk the tree and capture all paths that have modes that are listed. Path is a list of statements paths = [] rootid = self.statementtree.root @@ -370,6 +370,7 @@ class ControllerAst(): def initalwalktree(self, code, tree): vars = [] discrete_vars = [] + type_vars = [] out = [] mode_dict = {} state_object_dict = {} @@ -385,13 +386,15 @@ class ControllerAst(): mode_dict[modeType] = modes if isinstance(node, ast.ClassDef): if "State" in node.name: - state_object_dict[node.name] = {"cont":[],"disc":[]} + state_object_dict[node.name] = {"cont":[],"disc":[], "type": []} for item in node.body: if isinstance(item, ast.FunctionDef): if "init" in item.name: for arg in item.args.args: if "self" not in arg.arg: - if "mode" not in arg.arg: + if "type" == arg.arg: + state_object_dict[node.name]["type"].append(arg.arg) + elif "mode" not in arg.arg: state_object_dict[node.name]['cont'].append(arg.arg) # vars.append(arg.arg) else: @@ -410,20 +413,23 @@ class ControllerAst(): continue arg_annotation = arg.annotation.id arg_name = arg.arg - vars_dict[arg_name] = {'cont':[], 'disc':[]} + vars_dict[arg_name] = {'cont':[], 'disc':[], "type": []} for var in state_object_dict[arg_annotation]['cont']: vars.append(arg_name+"."+var) vars_dict[arg_name]['cont'].append(var) for var in state_object_dict[arg_annotation]['disc']: discrete_vars.append(arg_name+"."+var) vars_dict[arg_name]['disc'].append(var) + for var in state_object_dict[arg_annotation]['type']: + type_vars.append(arg_name+"."+var) + vars_dict[arg_name]['type'].append(var) # if "mode" not in arg.arg: # vars.append(arg.arg) # #todo: what to add for return values # else: # discrete_vars.append(arg.arg) - return [statementtree, vars, mode_dict, discrete_vars, state_object_dict, vars_dict] + return [statementtree, vars, mode_dict, discrete_vars, state_object_dict, vars_dict, type_vars] '''