Skip to content
Snippets Groups Projects
Commit 83699f47 authored by crides's avatar crides
Browse files

fix(simulation): handle case when only one agent is in the scenario

parent 3c12dcb3
No related branches found
No related tags found
2 merge requests!7Keyi tmp,!1Merge refactor to main
......@@ -287,18 +287,28 @@ class Scenario:
for arg, arg_type, is_list in agent.controller.controller.args:
if arg != EGO and 'map' not in arg:
other = arg
others_keys = list(packed[other].keys())
if is_list:
packed[other] = [state_ty(**{k: packed[other][k][i] for k in others_keys}) for i in range(len(packed[other][others_keys[0]]))]
if other in packed:
others_keys = list(packed[other].keys())
if is_list:
packed[other] = [state_ty(**{k: packed[other][k][i] for k in others_keys}) for i in range(len(packed[other][others_keys[0]]))]
else:
other_val = packed[other][others_keys[0]]
if isinstance(other_val, list):
other_val_len = len(other_val)
raise ValueError(f"Expected one {ego_ty_name} for {other}, got {other_val_len}")
packed[other] = state_ty(**{k: packed[other][k] for k in others_keys})
else:
packed[other] = state_ty(**{k: packed[other][k] for k in others_keys})
if is_list:
packed[other] = []
else:
raise ValueError(f"Expected one {ego_ty_name} for {other}, got none")
packed[EGO] = state_ty(**packed[EGO])
map_var = find(agent.controller.controller.args, lambda a: "map" in a[0])
if map_var != None:
packed[map_var[0]] = map
packed: Dict[str, Any] = dict(packed.items())
packed.update(env)
# packed.update(env)
return packed
packed_env = pack_env(agent, continuous_variable_dict, orig_disc_vars, self.map)
def eval_expr(expr, env):
......@@ -368,9 +378,8 @@ class Scenario:
def get_transition_verify_new(self, node:AnalysisTreeNode):
lane_map = self.map
possible_transitions = []
agent_guard_dict = {}
agent_guard_dict = defaultdict(list)
for agent_id in node.agent:
agent:BaseAgent = self.agent_dict[agent_id]
if agent.controller.controller == None:
......@@ -394,27 +403,41 @@ class Scenario:
guard_can_satisfied = guard_expression.evaluate_guard_disc(agent, discrete_variable_dict, cont_var_dict_template, self.map)
if not guard_can_satisfied:
continue
if agent_id not in agent_guard_dict:
agent_guard_dict[agent_id] = [(guard_expression, cont_var_updater, copy.deepcopy(discrete_variable_dict), reset)]
else:
agent_guard_dict[agent_id].append((guard_expression, cont_var_updater, copy.deepcopy(discrete_variable_dict), reset))
agent_guard_dict[agent_id].append((guard_expression, cont_var_updater, copy.deepcopy(discrete_variable_dict), reset))
trace_length = int(len(list(node.trace.values())[0])/2)
guard_hits = []
guard_hit_bool = False
for idx in range(0,trace_length):
any_contained = False
hits = []
state_dict = {}
for tmp in node.agent:
state_dict[tmp] = (node.trace[tmp][idx*2:idx*2+2], node.mode[tmp], node.static[tmp])
asserts = defaultdict(list)
for agent_id in agent_guard_dict:
agent:BaseAgent = self.agent_dict[agent_id]
agent_state, agent_mode, agent_static = state_dict[agent_id]
agent_state = agent_state[1:]
continuous_variable_dict, _, _ = self.sensor.sense(self, agent, state_dict, self.map)
continuous_variable_dict, orig_disc_vars, _ = self.sensor.sense(self, agent, state_dict, self.map)
resets = defaultdict(list)
# Check safety conditions
for i, a in enumerate(agent.controller.controller.asserts):
pre_expr = merge_conds(a.pre)
def eval_expr(expr):
ge = GuardExpressionAst(expr)
sat = ge.evaluate_guard_hybrid(agent, orig_disc_vars, continuous_variable_dict, self.map)
if sat:
sat, _ = ge.evaluate_guard_cont(agent, continuous_variable_dict, self.map)
return sat
if eval_expr(pre_expr):
if not eval_expr(a.cond):
label = a.label if a.label != None else f"<assert {i}>"
print(f"assert hit for {agent_id}: \"{label}\"")
asserts[agent_id].append(label)
if agent_id in asserts:
continue
for guard_expression, continuous_variable_updater, discrete_variable_dict, reset in agent_guard_dict[agent_id]:
new_cont_var_dict = copy.deepcopy(continuous_variable_dict)
one_step_guard:GuardExpressionAst = copy.deepcopy(guard_expression)
......@@ -435,19 +458,20 @@ class Scenario:
continue
for i in range(len(combined_reset_list)):
# a list of reset expression
hits.append((agent_id, guard_expression, combined_reset_list[i]))
hits.append((agent_id, combined_reset_list[i]))
if len(asserts) > 0:
return (asserts, idx), None
if hits != []:
guard_hits.append((hits, state_dict, idx))
guard_hit_bool = True
if hits == [] and guard_hit_bool:
else:
break
if any_contained:
break
reset_dict = {}
reset_idx_dict = {}
reset_dict = {}#defaultdict(lambda: defaultdict(list))
reset_idx_dict = {}#defaultdict(lambda: defaultdict(list))
for hits, all_agent_state, hit_idx in guard_hits:
for agent_id, guard_list, reset_list in hits:
for agent_id, reset_list in hits:
# TODO: Need to change this function to handle the new reset expression and then I am done
dest_list,reset_rect = self.apply_reset(node.agent[agent_id], reset_list, all_agent_state)
if agent_id not in reset_dict:
......@@ -464,6 +488,7 @@ class Scenario:
reset_dict[agent_id][dest].append(reset_rect)
reset_idx_dict[agent_id][dest].append(hit_idx)
possible_transitions = []
# Combine reset rects and construct transitions
for agent in reset_dict:
for dest in reset_dict[agent]:
......@@ -480,8 +505,6 @@ class Scenario:
combined_rect = combined_rect.tolist()
min_idx = min(reset_idx_dict[agent][dest])
max_idx = max(reset_idx_dict[agent][dest])
transition = (
agent, node.mode[agent], dest, combined_rect, (min_idx, max_idx))
possible_transitions.append(transition)
possible_transitions.append((agent, node.mode[agent], dest, combined_rect, (min_idx, max_idx)))
# Return result
return possible_transitions
return None, possible_transitions
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment