diff --git a/mrdna/segmentmodel.py b/mrdna/segmentmodel.py
index 7a563620a36850063107479f86ff9e4d4bb2979a..88949e744d2e90c85fa932d6b38ee685af58551e 100644
--- a/mrdna/segmentmodel.py
+++ b/mrdna/segmentmodel.py
@@ -3529,26 +3529,41 @@ proc calcforces {} {
             dx = dy = dz = max((dx,dy,dz))
         return np.array([dx,dy,dz])
 
-    def add_grid_potential(self, grid_file, scale=1, per_nucleotide=True):
+    def add_grid_potential(self, grid_file, scale=1, per_nucleotide=True, filter_fn=None):
         grid_file = Path(grid_file)
         if not grid_file.is_file():
             raise ValueError("Grid file {} does not exist".format(grid_file))
         if not grid_file.is_absolute():
             grid_file = Path.cwd() / grid_file
-        self.grid_potentials.append((grid_file,scale,per_nucleotide))
+        self.grid_potentials.append((grid_file,scale,per_nucleotide, filter_fn))
 
-    def _apply_grid_potentials_to_beads(self, bead_type_dict):
+    def _apply_grid_potentials_to_beads(self, bead_type_dict ):
         if len(self.grid_potentials) > 1:
             raise NotImplementedError("Multiple grid potentials are not yet supported")
 
-        for grid_file, scale, per_nucleotide in self.grid_potentials:
-            for key,particle_type in bead_type_dict.items():
-                if particle_type.name[0] == "O": continue
-                s = scale*particle_type.nts if per_nucleotide else scale
-                try:
-                    particle_type.grid = particle_type.grid + (grid_file, s)
-                except:
-                    particle_type.grid = tuple((grid_file, s))
+        def add_grid_to_type(particle_type):
+            if particle_type.name[0] == "O": return
+            s = scale*particle_type.nts if per_nucleotide else scale
+            try:
+                particle_type.grid = particle_type.grid + (grid_file, s)
+            except:
+                particle_type.grid = tuple((grid_file, s))
+
+        for grid_file, scale, per_nucleotide, filter_fn in self.grid_potentials:
+            if filter_fn is None:
+                for key,particle_type in bead_type_dict.items():
+                    add_grid_to_type(particle_type)
+            else:
+                grid_types = dict()
+                for b in filter(filter_fn, sum([seg.beads for seg in self.segments],[])):
+                    t = b.type_
+                    if t.name[0] == "O": continue
+                    if t not in grid_types:
+                        new_type = ParticleType(name=t.name+'G',charge=t.charge, parent=t)
+                        add_grid_to_type(new_type)
+                        grid_types[t] = new_type
+                    b.type_ = grid_types[t]
+
 
     def _generate_atomic_model(self, scale=1):
         ## TODO: deprecate