diff --git a/mrdna/segmentmodel.py b/mrdna/segmentmodel.py index 7a563620a36850063107479f86ff9e4d4bb2979a..88949e744d2e90c85fa932d6b38ee685af58551e 100644 --- a/mrdna/segmentmodel.py +++ b/mrdna/segmentmodel.py @@ -3529,26 +3529,41 @@ proc calcforces {} { dx = dy = dz = max((dx,dy,dz)) return np.array([dx,dy,dz]) - def add_grid_potential(self, grid_file, scale=1, per_nucleotide=True): + def add_grid_potential(self, grid_file, scale=1, per_nucleotide=True, filter_fn=None): grid_file = Path(grid_file) if not grid_file.is_file(): raise ValueError("Grid file {} does not exist".format(grid_file)) if not grid_file.is_absolute(): grid_file = Path.cwd() / grid_file - self.grid_potentials.append((grid_file,scale,per_nucleotide)) + self.grid_potentials.append((grid_file,scale,per_nucleotide, filter_fn)) - def _apply_grid_potentials_to_beads(self, bead_type_dict): + def _apply_grid_potentials_to_beads(self, bead_type_dict ): if len(self.grid_potentials) > 1: raise NotImplementedError("Multiple grid potentials are not yet supported") - for grid_file, scale, per_nucleotide in self.grid_potentials: - for key,particle_type in bead_type_dict.items(): - if particle_type.name[0] == "O": continue - s = scale*particle_type.nts if per_nucleotide else scale - try: - particle_type.grid = particle_type.grid + (grid_file, s) - except: - particle_type.grid = tuple((grid_file, s)) + def add_grid_to_type(particle_type): + if particle_type.name[0] == "O": return + s = scale*particle_type.nts if per_nucleotide else scale + try: + particle_type.grid = particle_type.grid + (grid_file, s) + except: + particle_type.grid = tuple((grid_file, s)) + + for grid_file, scale, per_nucleotide, filter_fn in self.grid_potentials: + if filter_fn is None: + for key,particle_type in bead_type_dict.items(): + add_grid_to_type(particle_type) + else: + grid_types = dict() + for b in filter(filter_fn, sum([seg.beads for seg in self.segments],[])): + t = b.type_ + if t.name[0] == "O": continue + if t not in grid_types: + new_type = ParticleType(name=t.name+'G',charge=t.charge, parent=t) + add_grid_to_type(new_type) + grid_types[t] = new_type + b.type_ = grid_types[t] + def _generate_atomic_model(self, scale=1): ## TODO: deprecate