diff --git a/guided_mrmp/simulator.py b/guided_mrmp/simulator.py
index 830b72bed055b66a2515f7fcef581204cfe5f3ff..f71c33fea35b0da3a8497d1052a07f4072a2131f 100644
--- a/guided_mrmp/simulator.py
+++ b/guided_mrmp/simulator.py
@@ -50,15 +50,15 @@ class Simulator:
 
         # Get the controls from the policy
         x_mpc, controls = self.policy.advance(self.state, show_plots=self.settings['simulator']['show_collision_resolution'])
-        # # Update the state of each robot
-        # for i in range(self.num_robots):
-        #     new_state = self.dynamics_models[i].next_state(self.state[i], controls[i], dt)
-        #     self.robots[i].current_position = new_state
-        #     self.state[i] = new_state
+        # Update the state of each robot
+        next_states = []
+        for i in range(self.num_robots):
+            next_states.append(self.policy.dynamics.next_state(self.state[i], controls[i], self.policy.DT))
+
+        self.state = next_states
 
         # Update the time
         self.time += dt
-        return x_mpc, controls
 
     def run(self, show_plots=False):
         """
@@ -90,13 +90,7 @@ class Simulator:
             if show_plots: self.plot_current_world_state()
 
             # get the next control for all robots
-            x_mpc, controls = self.advance(self.state, self.policy.DT)
-
-            next_states = []
-            for i in range(self.num_robots):
-                next_states.append(self.policy.dynamics.next_state(self.state[i], controls[i], self.policy.DT))
-
-            self.state = next_states
+            self.advance(self.state, self.policy.DT)
 
             self.state = np.array(self.state)
 
             for i in range(self.num_robots):