From dcc18ca2f00042518387e12a7bc9afd2bf039bd0 Mon Sep 17 00:00:00 2001
From: Yangge Li <li213@illinois.edu>
Date: Thu, 21 Apr 2022 16:00:58 -0500
Subject: [PATCH] further refine how variables are stored in ControllerAst,
 deciding the interface for setting map and sensor for user

---
 example_two_car_lane_switch.py | 42 ++++++++++++++++++----------------
 pythonparser.py                | 33 ++++++++++++++++++++++----
 2 files changed, 50 insertions(+), 25 deletions(-)

diff --git a/example_two_car_lane_switch.py b/example_two_car_lane_switch.py
index 01571762..14b3d14d 100644
--- a/example_two_car_lane_switch.py
+++ b/example_two_car_lane_switch.py
@@ -21,30 +21,30 @@ class State:
     def __init__(self):
         self.data = []
 
-def controller(ego_x, ego_y, ego_theta, ego_v, ego_vehicle_mode, ego_lane_mode, others_x, others_y, others_theta, others_v, others_vehicle_mode, others_lane_mode, map):
-    # output = ego
-    output_vehicle_mode = ego_vehicle_mode
-    output_lane_mode = ego_lane_mode
-    if ego_vehicle_mode == VehicleMode.Normal:
-        if ego_lane_mode == LaneMode.Lane0:
-            if others_x - ego_x > 3 and others_x - ego_x < 5 and map.can_swtich_left(ego_lane_mode):
-                output_vehicle_mode = VehicleMode.SwitchLeft
-                output_lane_mode = map.switch_left(ego_lane_mode)
-            if others_x - ego_x > 3 and others_x - ego_x < 5:
-                output_vehicle_mode = VehicleMode.SwitchRight
-    if ego_vehicle_mode == VehicleMode.SwitchLeft:
-        if ego_lane_mode == LaneMode.Lane0:
-            if ego_x - others_x > 10:
-                output_vehicle_mode = VehicleMode.Normal
-    if ego_vehicle_mode == VehicleMode.SwitchRight:
-        if ego_lane_mode == LaneMode.Lane0:
-            if ego_x - others_x > 10:
-                output_vehicle_mode = VehicleMode.Normal
+def controller(ego:State, other:State, map):
+    output = ego
+    if ego.vehicle_mode == VehicleMode.Normal:
+        if ego.lane_mode == LaneMode.Lane0:
+            if other.x - ego.x > 3 and other.x - ego.x < 5 and map.has_left(ego.lane_mode):
+                output.vehicle_mode = VehicleMode.SwitchLeft
+                output.lane_mode = map.left_lane(ego.lane_mode)
+            if other.x - ego.x > 3 and other.x - ego.x < 5:
+                output.vehicle_mode = VehicleMode.SwitchRight
+    if ego.vehicle_mode == VehicleMode.SwitchLeft:
+        if ego.lane_mode == LaneMode.Lane0:
+            if ego.x - other.x > 10:
+                output.vehicle_mode = VehicleMode.Normal
+    if ego.vehicle_mode == VehicleMode.SwitchRight:
+        if ego.lane_mode == LaneMode.Lane0:
+            if ego.x - other.x > 10:
+                output.vehicle_mode = VehicleMode.Normal
 
-    return output_vehicle_mode, output_lane_mode
+    return output
     
 from ourtool.agents.car_agent import CarAgent
 from ourtool.scenario.scenario import Scenario
+from user.sensor import SimpleSensor
+from user.map import SimpleMap
 import matplotlib.pyplot as plt 
 import numpy as np
 
@@ -56,6 +56,8 @@ if __name__ == "__main__":
     scenario.add_agent(car)
     car = CarAgent('car2', file_name=input_code_name)
     scenario.add_agent(car)
+    scenario.add_map(SimpleMap())
+    # scenario.set_sensor(SimpleSensor())
     scenario.set_init(
         [[0,0,0,1.0], [10,0,0,0.5]],
         [
diff --git a/pythonparser.py b/pythonparser.py
index 12857f6e..02d940e6 100644
--- a/pythonparser.py
+++ b/pythonparser.py
@@ -239,7 +239,7 @@ class ControllerAst():
 
         self.code = code
         self.tree = ast.parse(code)
-        self.statementtree, self.variables, self.modes, self.discrete_variables = self.initalwalktree(code, self.tree)
+        self.statementtree, self.variables, self.modes, self.discrete_variables, self.state_object_dict, self.vars_dict = self.initalwalktree(code, self.tree)
         self.vertices = []
         self.vertexStrings = []
         for vertex in itertools.product(*self.modes.values()):
@@ -250,7 +250,6 @@ class ControllerAst():
             self.vertexStrings.append(vertexstring)
         self.paths = None
 
-
     '''
     Function to populate paths variable with all paths of the controller.
     '''
@@ -362,6 +361,8 @@ class ControllerAst():
         discrete_vars = []
         out = []
         mode_dict = {}
+        state_object_dict = {}
+        vars_dict = {}
         for node in ast.walk(tree): #don't think we want to walk the whole thing because lose ordering/depth
             # Get all the modes
             if isinstance(node, ast.ClassDef):
@@ -373,21 +374,43 @@ class ControllerAst():
                     mode_dict[modeType] = modes
             if isinstance(node, ast.ClassDef):
                 if "State" in node.name:
+                    state_object_dict[node.name] = {"cont":[],"disc":[]}
                     for item in node.body:
                         if isinstance(item, ast.FunctionDef):
                             if "init" in item.name:
                                 for arg in item.args.args:
                                     if "self" not in arg.arg:
                                         if "mode" not in arg.arg:
-                                            vars.append(arg.arg)
+                                            state_object_dict[node.name]['cont'].append(arg.arg)
+                                            # vars.append(arg.arg)
                                         else:
-                                            discrete_vars.append(arg.arg)
+                                            state_object_dict[node.name]['disc'].append(arg.arg)
+                                            # discrete_vars.append(arg.arg)
             if isinstance(node, ast.FunctionDef):
                 if node.name == 'controller':
                     #print(node.body)
                     statementtree = self.parsenodelist(code, node.body, False, Tree(), None)
                     #print(type(node.args))
-        return [statementtree, vars, mode_dict, discrete_vars]
+                    args = node.args.args
+                    for arg in args:
+                        if arg.annotation is None:
+                            continue
+                        arg_annotation = arg.annotation.id
+                        arg_name = arg.arg
+                        vars_dict[arg_name] = {'cont':[], 'disc':[]}
+                        for var in state_object_dict[arg_annotation]['cont']:
+                            vars.append(arg_name+"."+var)
+                            vars_dict[arg_name]['cont'].append(var)
+                        for var in state_object_dict[arg_annotation]['disc']:
+                            discrete_vars.append(arg_name+"."+var)
+                            vars_dict[arg_name]['disc'].append(var)
+
+                        # if "mode" not in arg.arg:
+                        #     vars.append(arg.arg)
+                        #     #todo: what to add for return values
+                        # else:
+                        #     discrete_vars.append(arg.arg)
+        return [statementtree, vars, mode_dict, discrete_vars, state_object_dict, vars_dict]
 
 
     '''
-- 
GitLab