# segmentmodel_from_lists.py
# -*- coding: utf-8 -*-

import pdb
import numpy as np
import os,sys

from ..segmentmodel import SegmentModel, SingleStrandedSegment, DoubleStrandedSegment
from .. import get_resource_path

def _three_prime_list_to_five_prime(three_prime):
    """
    Invert a list of 3' connections into a list of 5' connections.

    Args:
        three_prime: 1-D integer array-like; element i is the index of
            the nucleotide on the 3' side of nucleotide i, or a negative
            value if there is none.

    Returns:
        Integer numpy array of the same shape; element j is the index of
        the nucleotide whose 3' neighbor is j, or -1 if no nucleotide
        lists j as its 3' neighbor.
    """
    three_prime = np.asarray(three_prime)
    ## Integer dtype so the result can itself be used for indexing
    five_prime = -np.ones(three_prime.shape, dtype=int)
    has_three_prime = np.where(three_prime >= 0)[0]
    five_prime[three_prime[has_three_prime]] = has_three_prime
    return five_prime

def _three_prime_list_to_strands(three_prime):
    """
    Group nucleotides into strands by following their 3' connections.

    A strand is started at every nucleotide that no other nucleotide
    lists as its 3' neighbor (a 5' end) and is extended by following
    three_prime until a negative entry is reached.  Nucleotides that sit
    on a closed loop without any 5' end are never visited, so they do
    not appear in any returned strand.

    Args:
        three_prime: 1-D integer array-like; element i is the index of
            the nucleotide on the 3' side of nucleotide i, or a negative
            value if there is none.

    Returns:
        List of 1-D integer numpy arrays, one per strand, each holding
        the nucleotide indices in 5'-to-3' order.

    Raises:
        ValueError: if following the 3' connections from a 5' end never
            terminates (the list is inconsistent).
    """
    three_prime = np.asarray(three_prime)
    five_prime = _three_prime_list_to_five_prime(three_prime)
    five_prime_ends = np.where(five_prime < 0)[0]
    num_nt = len(three_prime)
    strands = []

    for nt_idx in five_prime_ends:
        strand = [nt_idx]
        while three_prime[nt_idx] >= 0:
            nt_idx = three_prime[nt_idx]
            strand.append(nt_idx)
            ## A valid strand cannot visit more nucleotides than exist;
            ## without this guard a malformed list would loop forever
            if len(strand) > num_nt:
                raise ValueError("three_prime connections form a loop that is reachable from a 5' end")
        ## The np.int alias was removed from NumPy; builtin int selects
        ## the same default integer dtype
        strands.append( np.array(strand, dtype=int) )
    return strands

def basepairs_and_stacks_to_helixmap(basepairs,stacks_above):
    """
    Assign every basepaired nucleotide to a helix and a rank within it.

    Helices are found by starting at each paired nucleotide that has no
    (valid) stack above it, jumping to its basepair partner, and then
    following stacks_above from that partner, labelling each visited
    nucleotide and its partner along the way.

    Args:
        basepairs: 1-D integer array-like; element i is the index of the
            nucleotide paired with nucleotide i, or a negative value.
        stacks_above: 1-D integer array-like; element i is the index of
            the nucleotide stacked above nucleotide i, or a negative
            value.  The argument is not modified.

    Returns:
        Tuple (helixmap, helixrank, is_fwd) of arrays shaped like
        basepairs:
            helixmap:  integer helix id of each nucleotide, -1 if none
            helixrank: float position along the helix (shared by both
                       partners of a basepair), -1 if none
            is_fwd:    integer flag; 1 by default, 0 for the partner of
                       each nucleotide visited while walking a helix
    """
    basepairs = np.asarray(basepairs)
    ## Work on a copy so the caller's stack list is left untouched
    stacks_above = np.array(stacks_above)

    ## The np.int alias was removed from NumPy; builtin int is equivalent
    helixmap = -np.ones(basepairs.shape, dtype=int)
    helixrank = -np.ones(basepairs.shape)
    is_fwd = np.ones(basepairs.shape, dtype=int)

    ## Remove stacks with nts lacking a basepairs
    nobp = np.where(basepairs < 0)[0]
    stacks_above[nobp] = -1
    stacks_above[np.isin(stacks_above, nobp)] = -1

    ## Paired nucleotides with nothing stacked above them terminate a helix
    end_ids = np.where( (stacks_above < 0) & (basepairs >= 0) )[0]

    hid = 0
    for end in end_ids:
        if helixmap[end] >= 0:
            ## Already labelled while walking from the opposite end
            continue
        rank = 0
        nt = basepairs[end]
        while True:
            bp = basepairs[nt]
            assert(helixmap[nt] == -1)
            assert(helixmap[bp] == -1)
            helixmap[nt] = helixmap[bp] = hid
            helixrank[nt] = helixrank[bp] = rank
            is_fwd[bp] = 0
            rank += 1

            nt = stacks_above[nt]
            if nt < 0 or basepairs[nt] < 0:
                break
        hid += 1
    return helixmap, helixrank, is_fwd


def SegmentModelFromPdb(*args,**kwargs):
    """
    Incomplete: load a universe, rename atoms and locate basepairs/stacks.

    All positional and keyword arguments are forwarded unchanged to
    mda.Universe.

    NOTE(review): this function looks unfinished.  `mda`,
    find_base_position_orientation, find_basepairs and find_stacks are
    neither imported nor defined in the visible part of this file, and
    the body ends after computing `bps` and `stacks` without using or
    returning them (so the function returns None).
    `mda` is presumably MDAnalysis -- TODO confirm.
    """
    u = mda.Universe(*args,**kwargs)
    ## Rename atoms selected by "resname DT* and name C7" to "C5M"
    for a in u.select_atoms("resname DT* and name C7").atoms:
        a.name = "C5M"

    ## Find basepairs and base stacks
    centers,transforms = find_base_position_orientation(u)
    bps = find_basepairs(u, centers, transforms)
    stacks = find_stacks(u, centers, transforms)

    ## Build map from residue index to helix index

def set_splines(seg, coordinates, hid, hmap, hrank):
    """
    Give a segment a spline through the mean position of each rank.

    For helix `hid`, the coordinates of all nucleotides sharing a rank
    are averaged to one point per rank; the points and their contour
    values (rank divided by the highest rank) are handed to
    seg.set_splines.  When the highest rank is 0, the single mean
    position is used twice with contour values 0 and 1.

    Args:
        seg: object providing set_splines(coords, contours)
        coordinates: 2-D numpy array indexed as coordinates[nt, :]
        hid: helix id selecting the nucleotides of this segment
        hmap: array with the helix id of each nucleotide
        hrank: array with the rank of each nucleotide within its helix
    """
    in_helix = (hmap == hid)
    top_rank = np.max( hrank[in_helix] )

    def _centroid(members):
        ## Mean position of the listed nucleotides
        return np.mean( [coordinates[i,:] for i in members], axis=0 )

    if top_rank == 0:
        center = _centroid( np.where(in_helix)[0] )
        points = [center,center]
        params = [0,1]
    else:
        ranks = range(int(top_rank)+1)
        points = [_centroid( np.where(in_helix * (hrank == k))[0] )
                  for k in ranks]
        params = [float(k)/top_rank for k in ranks]

    seg.set_splines(np.array(points), params)


def model_from_basepair_stack_3prime(coordinates, basepair, stack, three_prime, sequence=None):
    """
    Creates a SegmentModel object from lists of each nucleotide's
    basepair, its stack (on 3' side) and its 3'-connected nucleotide

    The first argument should be an N-by-3 numpy array containing the
    coordinates of each nucleotide, where N is the number of
    nucleotides. The following three arguments should be integer lists
    where the i-th element corresponds to the i-th nucleotide; the
    list element should be the integer index of the corresponding
    basepaired / stacked / phosphodiester-bonded nucleotide. If there
    is no such nucleotide, the value should be -1.

    Basepaired nucleotides are grouped into double-stranded segments;
    every uninterrupted run of unpaired nucleotides along a strand
    becomes a single-stranded segment.  Strands are traced from their
    5' ends only.

    Args:
        coordinates:  N-by-3 numpy array of nucleotide positions
        basepair:  List of each nucleotide's basepair's index
        stack:  List containing index of the nucleotide stacked on the 3' of each nucleotide
        three_prime:  List containing index of the nucleotide bonded to the 3' end of each nucleotide
        sequence:  Optional length-N sequence indexed by nucleotide; if
            None, model.randomize_unset_sequence() is called instead

    Returns:
        SegmentModel

    Raises:
        TypeError: if an index list cannot be converted to an integer array
        ValueError: if an index list is not one-dimensional, the lists
            differ in length, or sequence has the wrong length
    """

    ## Validate input
    try:
        ## The np.int alias was removed from NumPy; builtin int is equivalent
        basepair,stack,three_prime = [np.array(a,dtype=int)
                                      for a in (basepair,stack,three_prime)]
    except Exception:
        raise TypeError("One or more of the input lists could not be converted into a numpy array")

    ## Check the converted arrays (the raw inputs may be plain lists)
    inputs = (basepair,stack,three_prime)
    if any(a.ndim != 1 for a in inputs):
        raise ValueError("One or more the input lists has the wrong dimensionality")

    if len(set(len(a) for a in inputs)) != 1:
        raise ValueError("Inputs are not the same length")

    num_nt = len(basepair)
    if sequence is not None and len(sequence) != num_nt:
        raise ValueError("The 'sequence' parameter is the wrong length {} != {}".format(len(sequence),num_nt))

    bps = basepair              # alias

    ## Build map of dsDNA helices and strands
    hmap,hrank,fwd = basepairs_and_stacks_to_helixmap(bps,stack)
    double_stranded_helices = np.unique(hmap[hmap >= 0])
    strands = _three_prime_list_to_strands(three_prime)

    ## Add ssDNA to hmap
    ## (a structure without any duplex has no last helix id to read)
    if double_stranded_helices.size > 0:
        num_ds = double_stranded_helices[-1]+1
    else:
        num_ds = 0
    assert( np.all(bps[hmap < 0] == -1) )

    next_hid = num_ds
    for s in strands:
        ## Positions along the strand (not nucleotide indices) that are
        ## unpaired, so runs are found in 5'-to-3' order regardless of
        ## how the nucleotides happen to be numbered
        ss_pos = np.where(hmap[s] < 0)[0]
        if len(ss_pos) == 0: continue
        breaks = np.where( np.diff(ss_pos) != 1 )[0] + 1
        for run in np.split(ss_pos, breaks):
            ids = s[run]
            hmap[ids] = next_hid
            hrank[ids] = np.arange(len(ids))
            next_hid += 1

    single_stranded_helices = np.arange(num_ds,next_hid)

    ## Create double-stranded segments
    doubleSegments = []
    for hid in double_stranded_helices:
        seg = DoubleStrandedSegment(name=str(hid),
                                    num_bp = np.sum(hmap==hid)//2)
        set_splines(seg, coordinates, hid, hmap, hrank)

        ## Segment list position must equal the helix id used in hmap
        assert(hid == len(doubleSegments))
        doubleSegments.append(seg)

    ## Create single-stranded segments
    singleSegments = []
    for hid in single_stranded_helices:
        seg = SingleStrandedSegment(name=str(hid),
                                    num_nt = np.sum(hmap==hid))
        set_splines(seg, coordinates, hid, hmap, hrank)

        assert(hid == len(doubleSegments) + len(singleSegments))
        singleSegments.append(seg)

    ## Find crossovers and 5prime/3prime ends
    crossovers,prime5,prime3 = [[],[],[]]
    for s in strands:
        ## A crossover joins a nucleotide to its 3' neighbor on the same
        ## strand whenever the two lie in different segments
        tmp = np.where(np.diff(hmap[s]) != 0)[0]
        crossovers.extend( zip(s[tmp], s[tmp+1]) )
        prime5.append(s[0])
        prime3.append(s[-1])

    ## Add connections
    allSegments = doubleSegments+singleSegments
    for r1,r2 in crossovers:
        seg1,seg2 = [allSegments[hmap[i]] for i in (r1,r2)]
        nt1,nt2 = [hrank[i] for i in (r1,r2)]
        f1,f2 = [fwd[i] for i in (r1,r2)]
        ## NOTE(review): ranks run from 0 to (count-1), so comparing
        ## against num_nt rather than num_nt-1 may never match the far
        ## end of a segment -- confirm against the segment classes
        if nt1 in (0,seg1.num_nt) or nt2 in (0,seg2.num_nt):
            seg1.add_crossover(nt1,seg2,nt2,[f1,f2],type_="terminal_crossover")
        else:
            seg1.add_crossover(nt1,seg2,nt2,[f1,f2])

    ## Add 5prime/3prime ends
    for r in prime5:
        seg = allSegments[hmap[r]]
        seg.add_5prime(hrank[r],fwd[r])
    for r in prime3:
        seg = allSegments[hmap[r]]
        seg.add_3prime(hrank[r],fwd[r])

    ## Assign sequence
    if sequence is not None:
        for hid,seg in enumerate(allSegments):
            resids = np.where( (hmap==hid)*(fwd==1) )[0]
            ## Order by position along the segment; this equals index
            ## order when nucleotides are numbered 5'-to-3'
            resids = resids[np.argsort(hrank[resids], kind="stable")]
            seg.sequence = [sequence[r] for r in resids]

    ## Build model
    model = SegmentModel( allSegments,
                          max_basepairs_per_bead = 5,
                          max_nucleotides_per_bead = 5,
                          local_twist = False,
                          dimensions=(5000,5000,5000))

    if sequence is None:
        model.randomize_unset_sequence()

    return model