From 1d51865c0e4df2dd538543f7da0c95e48d3de091 Mon Sep 17 00:00:00 2001
From: crides <zhuhaoqing@live.cn>
Date: Thu, 6 Apr 2023 21:04:42 -0500
Subject: [PATCH] fix(example.intersection): make intermediate lanes

---
 verse/map/example_map/intersection.py | 21 ++++----
 verse/map/lane_map.py                 | 72 +++++++++++----------------
 2 files changed, 41 insertions(+), 52 deletions(-)

diff --git a/verse/map/example_map/intersection.py b/verse/map/example_map/intersection.py
index 0c876095..c216ec42 100644
--- a/verse/map/example_map/intersection.py
+++ b/verse/map/example_map/intersection.py
@@ -74,7 +74,8 @@ class Intersection(LaneMap):
 
     def h(self, lane_idx: str, mode_src: str, mode_dest: str) -> Optional[str]:
         # ret = self._h(lane_idx, mode_src, mode_dest)
-        # print("H", lane_idx, mode_src, mode_dest, "->", ret)
+        # if ret != None and ret != lane_idx:
+        #     print("H", lane_idx, mode_src, mode_dest, "->", ret)
         # return ret
 
     # def _h(self, lane_idx: str, mode_src: str, mode_dest: str) -> Optional[str]:
@@ -83,19 +84,19 @@ class Intersection(LaneMap):
             if dst_sw:
                 return None
             else:
-                lanes, ind = lane_idx.split("_")
+                lanes, _, ind = lane_idx.split("_", 2)
                 ind = int(ind)
-                if "Left" in mode_src and ind > 0:
-                    return f"{lanes}.{ind - 1}"
-                if "Right" in mode_src and ind < self.lanes - 1:
-                    return f"{lanes}.{ind + 1}"
-                return None
+                return f"{lanes}_{ind}"
         else:
             if dst_sw:
                 lanes, ind = lane_idx.split("_")
                 ind = int(ind)
-                if "Left" in mode_dest and ind > 0 or "Right" in mode_dest and ind < self.lanes - 1:
-                    return lane_idx
-                return None
+                if "Left" in mode_dest and ind > 0:
+                    new_ind = ind - 1
+                elif "Right" in mode_dest and ind < self.lanes - 1:
+                    new_ind = ind + 1
+                else:
+                    return None
+                return f"{lanes}_{ind}_{new_ind}"
             else:
                 return lane_idx
diff --git a/verse/map/lane_map.py b/verse/map/lane_map.py
index 7c2b3661..e91bf5bc 100644
--- a/verse/map/lane_map.py
+++ b/verse/map/lane_map.py
@@ -1,5 +1,5 @@
 from typing import Dict, List
-import copy
+import re
 from enum import Enum
 
 import numpy as np
@@ -7,6 +7,10 @@ import numpy as np
 from verse.map.lane_segment import AbstractLane
 from verse.map.lane import Lane
 
+VIRT_LANE = re.compile(r"M\d\d")
+PHYS_LANE = re.compile(r"T\d")
+INTERSECT_PHYS_LANE = re.compile(r"[A-Z]+_\d+")
+INTERSECT_VIRT_LANE = re.compile(r"[A-Z]+_\d+_\d+")
 
 class LaneMap:
     def __init__(self, lane_seg_list: List[Lane] = []):
@@ -25,70 +29,58 @@ class LaneMap:
             self.left_lane_dict[lane_seg.id] = []
             self.right_lane_dict[lane_seg.id] = []
 
+    @staticmethod
+    def get_phys_lane(lane):
+        # res = LaneMap._get_phys_lane(lane)
+        # print(f"phys({lane}) -> {res}")
+        # return res
+
+    # @staticmethod
+    # def _get_phys_lane(lane):
+        if isinstance(lane, Enum):
+            lane = lane.name
+        if VIRT_LANE.match(lane):
+            return f"T{lane[1]}"
+        if INTERSECT_VIRT_LANE.match(lane):
+            return lane.rsplit("_", 1)[0]
+        # if PHYS_LANE.match(lane) or INTERSECT_PHYS_LANE.match(lane):
+        #     return lane 
+        return lane
+
     def lane_geometry(self, lane_idx):
-        if isinstance(lane_idx, Enum):
-            lane_idx = lane_idx.name
-        if len(lane_idx) == 3:
-            src_lane = f"T{lane_idx[1]}"
-        else:
-            src_lane = lane_idx 
-        return self.lane_dict[src_lane].get_geometry()
+        return self.lane_dict[LaneMap.get_phys_lane(lane_idx)].get_geometry()
 
     def get_longitudinal_position(self, lane_idx: str, position: np.ndarray) -> float:
-        if isinstance(lane_idx, Enum):
-            lane_idx = lane_idx.name
-        if len(lane_idx) == 3:
-            src_lane = f"T{lane_idx[1]}"
-        else:
-            src_lane = lane_idx 
         if not isinstance(position, np.ndarray):
             position = np.array(position)
         # print(self.lane_dict)
-        lane = self.lane_dict[src_lane]
+        lane = self.lane_dict[LaneMap.get_phys_lane(lane_idx)]
         return lane.get_longitudinal_position(position)
 
     def get_lateral_distance(self, lane_idx: str, position: np.ndarray) -> float:
-        if isinstance(lane_idx, Enum):
-            lane_idx = lane_idx.name
-        if len(lane_idx) == 3:
-            src_lane = f"T{lane_idx[1]}"
-        else:
-            src_lane = lane_idx 
         if not isinstance(position, np.ndarray):
             position = np.array(position)
-        lane = self.lane_dict[src_lane]
+        lane = self.lane_dict[LaneMap.get_phys_lane(lane_idx)]
         return lane.get_lateral_distance(position)
 
     def get_altitude(self, lane_idx, position: np.ndarray) -> float:
         raise NotImplementedError
 
     def get_lane_heading(self, lane_idx: str, position: np.ndarray) -> float:
-        if len(lane_idx) == 3:
-            src_lane = f"T{lane_idx[1]}"
-        else:
-            src_lane = lane_idx 
         if not isinstance(position, np.ndarray):
             position = np.array(position)
-        lane = self.lane_dict[src_lane]
+        lane = self.lane_dict[LaneMap.get_phys_lane(lane_idx)]
         return lane.get_heading(position)
 
     def get_lane_segment(self, lane_idx: str, position: np.ndarray) -> AbstractLane:
-        if len(lane_idx) == 3:
-            src_lane = f"T{lane_idx[1]}"
-        else:
-            src_lane = lane_idx 
         if not isinstance(position, np.ndarray):
             position = np.array(position)
-        lane = self.lane_dict[src_lane]
+        lane = self.lane_dict[LaneMap.get_phys_lane(lane_idx)]
         seg_idx, segment = lane.get_lane_segment(position)
         return segment
 
     def get_speed_limit(self, lane_idx: str) -> float:
-        if len(lane_idx) == 3:
-            src_lane = f"T{lane_idx[1]}"
-        else:
-            src_lane = lane_idx 
-        lane: Lane = self.lane_dict[src_lane]
+        lane: Lane = self.lane_dict[LaneMap.get_phys_lane(lane_idx)]
         # print(lane.get_speed_limit())
         return lane.get_speed_limit()
 
@@ -100,11 +92,7 @@ class LaneMap:
         return ret_dict
 
     def get_lane_width(self, lane_idx: str) -> float:
-        if len(lane_idx) == 3:
-            src_lane = f"T{lane_idx[1]}"
-        else:
-            src_lane = lane_idx 
-        lane: Lane = self.lane_dict[src_lane]
+        lane: Lane = self.lane_dict[LaneMap.get_phys_lane(lane_idx)]
         return lane.get_lane_width()
 
     def h(self, lane_idx, agent_mode_src, agent_mode_dest):
-- 
GitLab