From 1d51865c0e4df2dd538543f7da0c95e48d3de091 Mon Sep 17 00:00:00 2001 From: crides <zhuhaoqing@live.cn> Date: Thu, 6 Apr 2023 21:04:42 -0500 Subject: [PATCH] fix(example.intersection): make intermediate lanes --- verse/map/example_map/intersection.py | 21 ++++---- verse/map/lane_map.py | 72 +++++++++++---------------- 2 files changed, 41 insertions(+), 52 deletions(-) diff --git a/verse/map/example_map/intersection.py b/verse/map/example_map/intersection.py index 0c876095..c216ec42 100644 --- a/verse/map/example_map/intersection.py +++ b/verse/map/example_map/intersection.py @@ -74,7 +74,8 @@ class Intersection(LaneMap): def h(self, lane_idx: str, mode_src: str, mode_dest: str) -> Optional[str]: # ret = self._h(lane_idx, mode_src, mode_dest) - # print("H", lane_idx, mode_src, mode_dest, "->", ret) + # if ret != None and ret != lane_idx: + # print("H", lane_idx, mode_src, mode_dest, "->", ret) # return ret # def _h(self, lane_idx: str, mode_src: str, mode_dest: str) -> Optional[str]: @@ -83,19 +84,19 @@ class Intersection(LaneMap): if dst_sw: return None else: - lanes, ind = lane_idx.split("_") + lanes, _, ind = lane_idx.split("_", 2) ind = int(ind) - if "Left" in mode_src and ind > 0: - return f"{lanes}.{ind - 1}" - if "Right" in mode_src and ind < self.lanes - 1: - return f"{lanes}.{ind + 1}" - return None + return f"{lanes}_{ind}" else: if dst_sw: lanes, ind = lane_idx.split("_") ind = int(ind) - if "Left" in mode_dest and ind > 0 or "Right" in mode_dest and ind < self.lanes - 1: - return lane_idx - return None + if "Left" in mode_dest and ind > 0: + new_ind = ind - 1 + elif "Right" in mode_dest and ind < self.lanes - 1: + new_ind = ind + 1 + else: + return None + return f"{lanes}_{ind}_{new_ind}" else: return lane_idx diff --git a/verse/map/lane_map.py b/verse/map/lane_map.py index 7c2b3661..e91bf5bc 100644 --- a/verse/map/lane_map.py +++ b/verse/map/lane_map.py @@ -1,5 +1,5 @@ from typing import Dict, List -import copy +import re from enum import Enum import numpy as np @@ -7,6 +7,10 @@ import numpy as np from verse.map.lane_segment import AbstractLane from verse.map.lane import Lane +VIRT_LANE = re.compile(r"M\d\d") +PHYS_LANE = re.compile(r"T\d") +INTERSECT_PHYS_LANE = re.compile(r"[A-Z]+_\d+") +INTERSECT_VIRT_LANE = re.compile(r"[A-Z]+_\d+_\d+") class LaneMap: def __init__(self, lane_seg_list: List[Lane] = []): @@ -25,70 +29,58 @@ class LaneMap: self.left_lane_dict[lane_seg.id] = [] self.right_lane_dict[lane_seg.id] = [] + @staticmethod + def get_phys_lane(lane): + # res = LaneMap._get_phys_lane(lane) + # print(f"phys({lane}) -> {res}") + # return res + + # @staticmethod + # def _get_phys_lane(lane): + if isinstance(lane, Enum): + lane = lane.name + if VIRT_LANE.match(lane): + return f"T{lane[1]}" + if INTERSECT_VIRT_LANE.match(lane): + return lane.rsplit("_", 1)[0] + # if PHYS_LANE.match(lane) or INTERSECT_PHYS_LANE.match(lane): + # return lane + return lane + def lane_geometry(self, lane_idx): - if isinstance(lane_idx, Enum): - lane_idx = lane_idx.name - if len(lane_idx) == 3: - src_lane = f"T{lane_idx[1]}" - else: - src_lane = lane_idx - return self.lane_dict[src_lane].get_geometry() + return self.lane_dict[LaneMap.get_phys_lane(lane_idx)].get_geometry() def get_longitudinal_position(self, lane_idx: str, position: np.ndarray) -> float: - if isinstance(lane_idx, Enum): - lane_idx = lane_idx.name - if len(lane_idx) == 3: - src_lane = f"T{lane_idx[1]}" - else: - src_lane = lane_idx if not isinstance(position, np.ndarray): position = np.array(position) # print(self.lane_dict) - lane = self.lane_dict[src_lane] + lane = self.lane_dict[LaneMap.get_phys_lane(lane_idx)] return lane.get_longitudinal_position(position) def get_lateral_distance(self, lane_idx: str, position: np.ndarray) -> float: - if isinstance(lane_idx, Enum): - lane_idx = lane_idx.name - if len(lane_idx) == 3: - src_lane = f"T{lane_idx[1]}" - else: - src_lane = lane_idx if not isinstance(position, np.ndarray): position = np.array(position) - lane = self.lane_dict[src_lane] + lane = self.lane_dict[LaneMap.get_phys_lane(lane_idx)] return lane.get_lateral_distance(position) def get_altitude(self, lane_idx, position: np.ndarray) -> float: raise NotImplementedError def get_lane_heading(self, lane_idx: str, position: np.ndarray) -> float: - if len(lane_idx) == 3: - src_lane = f"T{lane_idx[1]}" - else: - src_lane = lane_idx if not isinstance(position, np.ndarray): position = np.array(position) - lane = self.lane_dict[src_lane] + lane = self.lane_dict[LaneMap.get_phys_lane(lane_idx)] return lane.get_heading(position) def get_lane_segment(self, lane_idx: str, position: np.ndarray) -> AbstractLane: - if len(lane_idx) == 3: - src_lane = f"T{lane_idx[1]}" - else: - src_lane = lane_idx if not isinstance(position, np.ndarray): position = np.array(position) - lane = self.lane_dict[src_lane] + lane = self.lane_dict[LaneMap.get_phys_lane(lane_idx)] seg_idx, segment = lane.get_lane_segment(position) return segment def get_speed_limit(self, lane_idx: str) -> float: - if len(lane_idx) == 3: - src_lane = f"T{lane_idx[1]}" - else: - src_lane = lane_idx - lane: Lane = self.lane_dict[src_lane] + lane: Lane = self.lane_dict[LaneMap.get_phys_lane(lane_idx)] # print(lane.get_speed_limit()) return lane.get_speed_limit() @@ -100,11 +92,7 @@ class LaneMap: return ret_dict def get_lane_width(self, lane_idx: str) -> float: - if len(lane_idx) == 3: - src_lane = f"T{lane_idx[1]}" - else: - src_lane = lane_idx - lane: Lane = self.lane_dict[src_lane] + lane: Lane = self.lane_dict[LaneMap.get_phys_lane(lane_idx)] return lane.get_lane_width() def h(self, lane_idx, agent_mode_src, agent_mode_dest): -- GitLab