diff --git a/verse/analysis/analysis_tree.py b/verse/analysis/analysis_tree.py index 147b98f44bbcb481ddbd0ad518e218e5f107aa92..309cd3800afa6aa1ea103c196d8cee0e7414c608 100644 --- a/verse/analysis/analysis_tree.py +++ b/verse/analysis/analysis_tree.py @@ -1,6 +1,6 @@ from functools import reduce import pickle -from typing import List, Dict, Any, Optional, Tuple +from typing import Iterable, List, Dict, Any, Optional, Tuple, TypeVar import json from treelib import Tree import numpy.typing as nptyp, numpy as np, portion @@ -12,6 +12,14 @@ import matplotlib.pyplot as plt TraceType = nptyp.NDArray[np.float_] +T = TypeVar("T") + +def index_of(l: Iterable[T], item: T) -> Optional[int]: + for i, v in enumerate(l): + if v == item: + return i + return None + class AnalysisTreeNode: """AnalysisTreeNode class A AnalysisTreeNode stores the continous execution of the system without transition happening""" @@ -72,26 +80,22 @@ class AnalysisTreeNode: return rst_dict def get_track(self, agent_id, D): - if 'TrackMode' not in self.agent[agent_id].decision_logic.mode_defs: - return "" - for d in D: - if d in self.agent[agent_id].decision_logic.mode_defs['TrackMode'].modes: - return d - return "" + mode_def_names = list(self.agent[agent_id].decision_logic.state_defs.values())[0].disc_type + track_mode_ind = index_of(mode_def_names, "TrackMode") + if track_mode_ind == None: + return None + return D[track_mode_ind] def get_mode(self, agent_id, D): - res = [] - if 'TrackMode' not in self.agent[agent_id].decision_logic.mode_defs: + mode_def_names = list(self.agent[agent_id].decision_logic.state_defs.values())[0].disc_type + track_mode_ind = index_of(mode_def_names, "TrackMode") + if track_mode_ind == None: if len(D)==1: return D[0] return D - for d in D: - if d not in self.agent[agent_id].decision_logic.mode_defs['TrackMode'].modes: - res.append(d) - if len(res) == 1: - return res[0] - else: - return tuple(res) + if len(mode_def_names) == 2: + return D[1 - track_mode_ind] + return tuple(v for i, v in enumerate(D) if i != track_mode_ind) @staticmethod def from_dict(data) -> "AnalysisTreeNode": diff --git a/verse/map/lane_map.py b/verse/map/lane_map.py index e91bf5bc2b08369c73604793a503361364b12f36..34a80506bc8a6ea41b1fe05e3469c0f819cb4c70 100644 --- a/verse/map/lane_map.py +++ b/verse/map/lane_map.py @@ -103,7 +103,7 @@ class LaneMap: if isinstance(agent_mode_dest, Enum): agent_mode_dest = agent_mode_dest.name if self.h_dict == {}: - return "" + return None return self.h_dict[(lane_idx, agent_mode_src, agent_mode_dest)] def h_exist(self, lane_idx, agent_mode_src, agent_mode_dest):