#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
circseq_cup -- circular RNA analysis toolkits.

Usage: circ_anotation.py [options]

Options:
    -h --help                      Show this screen.
    -c CIRCSEQ --circ=CIRCSEQ      Cap3 assembled contigs
    -p PAIRNUM --pair=PAIRNUM      PE number supporting the joints
    -g GENOME --genome=GENOME      Genome FASTA file.
    -r REF --ref=REF               Gene annotation.
    -o PREFIX --output=PREFIX      Output prefix [default: circ].
"""


from docopt import docopt
import sys
import pysam
from collections import defaultdict
from interval import Interval
import tempfile
import os

# IUPAC nucleotide codes paired with their complements, position by position.
# Upper- and lower-case letters are both covered; any other character is
# absent from the table on purpose.
COMPLEMENT = dict(zip(
    "atcgkmryswbvhdn" "ATCGKMRYSWBVHDN",
    "tagcmkyrswvbdhn" "TAGCMKYRSWVBDHN",
))


def complement(s):
    """Return the base-by-base complement of *s* (same orientation).

    Raises KeyError for any character that is not in COMPLEMENT.
    """
    return "".join(COMPLEMENT[base] for base in s)


def rev_comp(seq):
    """Return the reverse complement of *seq*."""
    complemented = complement(seq)
    return complemented[::-1]

# Contig sequence per circle index; filled in by check_seq().
circ_hash = {}
# Boundary tags "<left> <right>" per circle index; filled in by check_boundary().
circ_boundary = {}

def converse_index(index):
    """Split a '<chrom>_<start>_<end>[_...]' index into (chrom, start, end).

    Only the first three underscore-separated fields are used; start and end
    are returned as ints.
    """
    chrom, start_txt, end_txt = index.split('_')[0:3]
    start = int(start_txt)
    stop = int(end_txt)
    return chrom, start, stop

def converse_index_all(index):
    """Split a '<chrom>_<start>_<end>_<ovlap>_<ovlap_err>' index.

    The index must have exactly five underscore-separated fields; the four
    numeric ones are returned as ints after the chromosome name.
    """
    chrom, sta, end, ovlap, ovlap_err = index.split('_')
    numbers = [int(field) for field in (sta, end, ovlap, ovlap_err)]
    return (chrom,) + tuple(numbers)


def print_list(clist):
    """Print every element of *clist* on its own line, then one empty line.

    The single-argument call form of print is used because it produces the
    same output under Python 2 and Python 3, whereas the print statement is a
    syntax error on Python 3.
    """
    for item in clist:
        print(item)
    print("")

def read_file(f):
    """Read a two-column, whitespace-separated file into a dict.

    Each non-blank line must hold exactly two fields, ``<index> <value>``;
    the value is kept as a string. Blank lines are skipped. A line with any
    other number of fields raises ValueError.

    The file is opened with a context manager so the handle is released even
    when a malformed line aborts parsing.
    """
    chash = {}
    with open(f) as cfile:
        for line in cfile:
            fields = line.split()
            if not fields:
                # Tolerate empty lines (e.g. a trailing newline at EOF).
                continue
            index, seq = fields
            chash[index] = seq
    return chash

def check_seq(circ_hash_prime,pair_hash_prime,genome_fa,all_stas,all_ends):
    """Filter the assembled contigs and group the survivors by chromosome.

    Every contig in *circ_hash_prime* is run through check_splicing() and
    check_repeat(); a contig that passes both is handed to check_boundary(),
    which may propose a realigned index. The kept contigs are stored in the
    module-level ``circ_hash`` under the (possibly realigned) index.

    Parameters:
        circ_hash_prime: dict mapping circle index -> contig sequence.
        pair_hash_prime: dict mapping '<chrom>_<start>_<end>' -> supporting
            pair count (kept as read, not converted).
        genome_fa: genome object passed through to the check_* helpers.
        all_stas, all_ends: per-chromosome exon start / end collections.

    Returns:
        (fusions, pair_hash) where *fusions* maps chromosome ->
        list of [start, end, index] built from everything in ``circ_hash``,
        and *pair_hash* maps kept index -> supporting pair count.
    """
    global circ_hash
    pair_hash = {}
    outlist_spl = []
    outlist_rep = []
    change_list = []
    for index, seq in circ_hash_prime.items():
        bs, outlist_spl = check_splicing(seq, index, genome_fa, 10, outlist_spl)
        # The repeat check runs even when the splicing check failed, so both
        # counters reported below cover every input contig.
        br, outlist_rep = check_repeat(index, genome_fa, 20, outlist_rep)
        if bs == 1 and br == 1:
            new_index, change_list = check_boundary(index, genome_fa, all_stas, all_ends, change_list)
            chrom, sta, end = index.split('_')[:3]
            num = pair_hash_prime[chrom + "_" + sta + "_" + end]
            # "Y" means the original coordinates are kept (no realignment).
            kept_index = index if new_index == "Y" else new_index
            circ_hash[kept_index] = seq
            pair_hash[kept_index] = num

    # Single-argument print calls behave identically on Python 2 and 3.
    print("Check_splicing Delete:" + str(len(outlist_spl)))
    #print_list(outlist_spl)
    print("Check_repeat Delete:" + str(len(outlist_rep)))
    #print_list(outlist_rep)
    print("Check_boundary Change:" + str(len(change_list)))
    #print_list(change_list)

    fusions = defaultdict(list)
    for index in circ_hash:
        chrom, sta, end = converse_index(index)
        fusions[chrom].append([sta, end, index])
    return fusions, pair_hash

# Discard circle sites whose flanking linear sequence is the same at both joints
def check_repeat(index,fa,n,out):
    """Flag a circle site whose two joints share the same flanking sequence.

    The *n* bases before the end are compared with the *n* bases before the
    start, then the *n* bases after the end with the *n* bases after the
    start (one mismatch tolerated by if_same_ignore_strd). If either pair
    matches, *index* is appended to *out* and 0 is returned; otherwise 1.

    Returns:
        (flag, out) with flag 1 = keep, 0 = rejected.
    """
    chrom, sta, end = converse_index(index)
    # (end-side from, end-side to, start-side from, start-side to)
    flank_windows = (
        (end - n, end, sta - n, sta),
        (end, end + n, sta, sta + n),
    )
    for e_from, e_to, s_from, s_to in flank_windows:
        end_side = fa.fetch(chrom, e_from, e_to).upper()
        sta_side = fa.fetch(chrom, s_from, s_to).upper()
        if if_same_ignore_strd(end_side, sta_side) == 1:
            out.append(index)
            return 0, out
    return 1, out

# The contigs without backsplicing sites are excluded
def check_splicing(circ_seq,index,fa,n,out):
    """Check that the contig contains the expected back-splice junction.

    The junction sequence is the *n* genomic bases ending at ``end`` followed
    by the *n* bases starting at ``sta``. Every window of 2*n bases in
    *circ_seq* is compared against it with if_same() (either strand, one
    mismatch allowed).

    Returns:
        (1, out) when a window matches; otherwise *index* is appended to
        *out* and (0, out) is returned.
    """
    chrom, sta, end = converse_index(index)
    spl_seq = fa.fetch(chrom, end - n, end).upper() + fa.fetch(chrom, sta, sta + n).upper()
    window = n * 2
    # "+ 1" so the last full window (starting at len - 2n) is tested too;
    # without it a junction at the very end of the contig, or a contig of
    # exactly 2n bases, could never match.
    for i in range(len(circ_seq) - window + 1):
        if if_same(spl_seq, circ_seq[i:i + window]) == 1:
            return 1, out
    out.append(index)
    return 0, out

def check_boundary(index,fa,all_stas,all_ends,out):
    """Record boundary tags for *index* and try to realign it if untagged.

    The tags from get_boundary_tag() are stored in the module-level
    ``circ_boundary`` as "<left> <right>". When both tags are 0,
    if_adjust_boundary() is asked for shifted coordinates; a successful shift
    is tagged as well and logged in *out* as "<old> to <new>".

    Returns:
        ("Y", out) when the index is kept unchanged, else (new_index, out).
    """
    global circ_boundary
    left_tag, right_tag = get_boundary_tag(index, all_stas, all_ends)
    circ_boundary[index] = " ".join([str(left_tag), str(right_tag)])
    if left_tag == 0 and right_tag == 0:
        candidate = if_adjust_boundary(index, fa, all_stas, all_ends)
        if candidate != "Y":
            new_left, new_right = get_boundary_tag(candidate, all_stas, all_ends)
            circ_boundary[candidate] = " ".join([str(new_left), str(new_right)])
            out.append(index + " to " + candidate)
            return candidate, out
    return "Y", out


def if_adjust_boundary(index,fa,all_stas,all_ends):
    """Look for a nearby shift of the circle site that gives better boundaries.

    A shift by k bases is only considered while the bases slid over at the
    start and at the end are identical. Four scans are run in order:
    downstream then upstream for a splicing signal (if_splicing_site), then
    downstream then upstream for a boundary tagged by get_boundary_tag().

    Returns:
        "Y" when the current coordinates already carry a splicing signal or
        no acceptable shift exists; otherwise the shifted index string
        '<chrom>_<start>_<end>_<rest of the original index>'.
    """
    chrom, sta, end = converse_index(index)

    def relabel(shift):
        # Same index with both coordinates moved by `shift`; the trailing
        # fields of the original index are carried over unchanged.
        suffix = index.split('_', 3)[3]
        return "_".join([chrom, str(sta + shift), str(end + shift), suffix])

    def downstream_windows():
        # 20 bases from 2 before the start, and 20 bases from the end onwards.
        head = fa.fetch(chrom, sta - 2, sta + 18).upper()
        tail = fa.fetch(chrom, end, end + 20).upper()
        return head, tail

    def upstream_windows():
        # 20 bases up to the start, and 20 bases up to 2 past the end.
        head = fa.fetch(chrom, sta - 20, sta).upper()
        tail = fa.fetch(chrom, end - 18, end + 2).upper()
        return head, tail

    def on_tagged_boundary(candidate):
        left_tag, right_tag = get_boundary_tag(candidate, all_stas, all_ends)
        return left_tag != 0 or right_tag != 0

    # To make start and end positions satisfy the splicing signals
    head, tail = downstream_windows()
    if if_splicing_site(head[:2], tail[:2]) == 1:
        return "Y"
    for k in range(17):
        if head[k + 2] != tail[k]:
            break
        candidate = relabel(k + 1)
        if if_splicing_site(head[k + 1:k + 3], tail[k + 1:k + 3]) == 1:
            return candidate
    head, tail = upstream_windows()
    for k in range(17, 0, -1):
        if head[k + 2] != tail[k]:
            break
        candidate = relabel(k - 18)
        if if_splicing_site(head[k:k + 2], tail[k:k + 2]) == 1:
            return candidate

    # To make start and end positions satisfy the annotated spliced sites
    head, tail = downstream_windows()
    for k in range(17):
        if head[k + 2] != tail[k]:
            break
        candidate = relabel(k + 1)
        if on_tagged_boundary(candidate):
            return candidate
    head, tail = upstream_windows()
    for k in range(17, 0, -1):
        if head[k + 2] != tail[k]:
            break
        candidate = relabel(k - 18)
        if on_tagged_boundary(candidate):
            return candidate
    return "Y"

          
def if_splicing_site(s1,s2):
    """Return 1 if (s1, s2) is the dinucleotide pair AG/GT or AC/CT, else 0."""
    accepted_pairs = (("AG", "GT"), ("AC", "CT"))
    return 1 if (s1, s2) in accepted_pairs else 0

def get_boundary_tag(index,all_stas,all_ends):
    """Tag the start and end of a circle site against known exon boundaries.

    Returns:
        (left_b, right_b). left_b is 1 when the start is a known exon start
        on that chromosome, -1 when it is a known exon end, 0 otherwise.
        right_b is 1 when the end is a known exon end, -1 when it is a known
        exon start, 0 otherwise.
    """
    chrom, sta, end = converse_index(index)
    known_starts = all_stas[chrom]
    known_ends = all_ends[chrom]
    left_b = 1 if sta in known_starts else (-1 if sta in known_ends else 0)
    right_b = 1 if end in known_ends else (-1 if end in known_starts else 0)
    return left_b, right_b

#check the consistency of two sequences, allowing one mismatch
def if_same(seq1,seq2):
    """Return 1 if seq1 matches seq2 on either strand with <= 1 mismatch.

    seq1 is compared position by position, over len(seq1) bases, first with
    seq2 and then with the reverse complement of seq2. Returns 0 when both
    comparisons have more than one mismatch.
    """
    width = len(seq1)

    def capped_mismatches(other):
        # Count mismatches, giving up as soon as a second one is seen.
        count = 0
        for pos in range(width):
            if seq1[pos] != other[pos]:
                count += 1
            if count > 1:
                break
        return count

    forward_err = capped_mismatches(seq2)
    reverse_err = capped_mismatches(rev_comp(seq2))
    return 1 if (forward_err <= 1 or reverse_err <= 1) else 0

def if_same_ignore_strd(seq1,seq2):
    """Return 1 if seq1 and seq2 differ at no more than one position, else 0.

    Only the given orientation is compared, over len(seq1) positions.
    """
    mismatches = 0
    for pos, base in enumerate(seq1):
        if base != seq2[pos]:
            mismatches += 1
            if mismatches > 1:
                return 0
    return 1

def sort_keys(my_hash):
    """Return the keys of *my_hash* in output order.

    Keys look like '<chrom>_<start>_<end>...'. Keys whose chromosome name has
    only digits after its first three characters come first, ordered by that
    number, then start, then end. All other keys follow, ordered by
    chromosome name, then start, then end.
    """
    def numeric_order(key):
        parts = key.split('_')
        return (int(parts[0][3:]), int(parts[1]), int(parts[2]))

    def lexical_order(key):
        parts = key.split('_')
        return (parts[0], int(parts[1]), int(parts[2]))

    numbered = []
    named = []
    for key in my_hash:
        chrom = key.split('_')[0]
        bucket = numbered if chrom[3:].isdigit() else named
        bucket.append(key)
    return sorted(numbered, key=numeric_order) + sorted(named, key=lexical_order)

#compare the contig with the union of all exons in the reference annotation
def compare_annotation(index,isos,genome_fa,gene_info):
    """Compare a contig with the pooled exons of all overlapping isoforms.

    Exons of every isoform whose span maps onto [start, end] (per
    Interval.mapto) are pooled, combined through Interval, and mapped onto
    the circle site. The genomic sequence of the resulting blocks is the
    reference that the contig is fitted to with adjust_assemble_seq().
    NOTE(review): the exact merge/clip semantics live in the project-local
    Interval class and are not visible here.

    When ovlap_err is non-zero the contig is tried both with the leading and
    with the trailing ``ovlap`` bases removed, and the fit with fewer errors
    wins (the leading-trim fit on a tie).

    Returns:
        (adjusted_seq, ref_seqs, map_err, len_tag, strand, gene, block) where
        ref_seqs holds one exon sequence per line and gene is the
        '/'-joined second field of each matching isoform id.
    """
    chrom, sta, end, ovlap, ovlap_err = converse_index_all(index)
    exon_pairs = []
    gene_mark = defaultdict(list)
    for iso in isos:
        for iso_id in iso:
            exon_starts = gene_info[iso_id][0]
            exon_ends = gene_info[iso_id][-1]
            span = [exon_starts[0], exon_ends[-1]]
            if len(Interval.mapto([sta, end], span)) > 0:
                gene_mark[iso_id.split()[1]] = 1
                for j in range(0, len(exon_starts)):
                    exon_pairs.append([exon_starts[j], exon_ends[j]])
    pooled = Interval(exon_pairs)
    block = Interval.mapto([sta, end], pooled.interval)
    gene = "/".join(gene_mark.keys())

    exon_seqs = [genome_fa.fetch(chrom, blk[0], blk[1]).upper() for blk in block]
    ref_seq = "".join(exon_seqs)
    ref_seqs = "".join(exon_seq + "\n" for exon_seq in exon_seqs)

    if ovlap_err == 0:
        ad_seq, mp_err, len_tag, strd = adjust_assemble_seq(ref_seq, circ_hash[index][ovlap:])
        return ad_seq, ref_seqs, mp_err, len_tag, strd, gene, block

    head_fit = adjust_assemble_seq(ref_seq, circ_hash[index][ovlap:])
    tail_fit = adjust_assemble_seq(ref_seq, circ_hash[index][:-ovlap])
    # The length tag reported is always the one from the trailing-trim fit.
    len_tag = tail_fit[2]
    chosen = head_fit if head_fit[1] <= tail_fit[1] else tail_fit
    return chosen[0], ref_seqs, chosen[1], len_tag, chosen[3], gene, block

#compare the contig with each linear transcript in turn
def compare_annotation_onebyone(index,isos,genome_fa,gene_info):
    """Compare a contig with each overlapping isoform separately.

    For every isoform whose span maps onto [start, end], its exons are mapped
    onto the circle site and concatenated into a reference. Only references
    whose length equals the contig length minus ``ovlap`` are fitted with
    adjust_assemble_seq(); the fit with the fewest errors so far is kept
    (strictly fewer, so the earliest isoform wins ties).

    Returns:
        (adjusted_seq, ref_seqs, map_err, len_tag, strand, gene, block).
        len_tag is always 0; map_err stays 5000 and the other fields stay
        empty when no isoform qualified. gene is the third field of the
        winning isoform id.
    """
    chrom, sta, end, ovlap, ovlap_err = converse_index_all(index)
    ad_seq = ""
    r_seq = ""
    mp_err = 5000
    len_tag = 0
    strd = "+"
    gene = ""
    blk = []
    for iso in isos:
        for iso_id in iso:
            exon_starts = gene_info[iso_id][0]
            exon_ends = gene_info[iso_id][-1]
            span = [exon_starts[0], exon_ends[-1]]
            if len(Interval.mapto([sta, end], span)) == 0:
                continue
            exon_pairs = []
            for j in range(0, len(exon_starts)):
                exon_pairs.append([exon_starts[j], exon_ends[j]])
            temp_gene = iso_id.split()[2]
            block = Interval.mapto([sta, end], exon_pairs)

            exon_seqs = [genome_fa.fetch(chrom, b[0], b[1]).upper() for b in block]
            ref_seq = "".join(exon_seqs)
            ref_seqs = "".join(exon_seq + "\n" for exon_seq in exon_seqs)

            contig = circ_hash[index]
            if len(ref_seq) != len(contig) - ovlap:
                continue

            cand_seq, cand_err, _unused, cand_strd = adjust_assemble_seq(ref_seq, contig[ovlap:])
            if ovlap_err != 0:
                # Also try trimming the overlap from the other end; it only
                # replaces the first fit when it has strictly fewer errors.
                alt_seq, alt_err, _unused, alt_strd = adjust_assemble_seq(ref_seq, contig[:-ovlap])
                if alt_err < cand_err:
                    cand_seq, cand_err, cand_strd = alt_seq, alt_err, alt_strd
            if cand_err < mp_err:
                ad_seq = cand_seq
                mp_err = cand_err
                strd = cand_strd
                r_seq = ref_seqs
                blk = block
                gene = temp_gene
    return ad_seq, r_seq, mp_err, len_tag, strd, gene, blk



#compare the contig with the genomic sequence between the backsplicing sites
def compare_annotation_circposition(index,genome_fa):
    """Compare a contig with the whole genomic span of the circle site.

    The reference is the genome sequence from start to end. The contig, with
    its leading ``ovlap`` bases removed, is fitted to it with
    adjust_assemble_seq(). When ovlap_err is non-zero the contig with its
    trailing ``ovlap`` bases removed is fitted as well, and replaces the
    first fit only if it has strictly fewer errors.

    Returns:
        (adjusted_seq, ref_seq, map_err, strand).
    """
    chrom, sta, end, ovlap, ovlap_err = converse_index_all(index)
    ref_seq = genome_fa.fetch(chrom, sta, end).upper()
    contig = circ_hash[index]
    ad_seq, mp_err, _unused, strd = adjust_assemble_seq(ref_seq, contig[ovlap:])
    # ovlap_err is an int (see converse_index_all); comparing it with the
    # string "0" was always False, which made the single-fit path unreachable.
    if ovlap_err != 0:
        alt_seq, alt_err, _unused, alt_strd = adjust_assemble_seq(ref_seq, contig[:-ovlap])
        if alt_err < mp_err:
            ad_seq, mp_err, strd = alt_seq, alt_err, alt_strd
    return ad_seq, ref_seq, mp_err, strd

#To adjust the circle sequence to reference and get the number of mismatch sites
def adjust_assemble_seq(ref_seq,circ_seq):
    """Rotate the circular contig to best match the reference.

    Every rotation of *circ_seq* is compared with *ref_seq* over the length
    of the shorter of the two, first as given ("+") and then as its reverse
    complement ("-"). The first rotation with no mismatch is returned at
    once; otherwise the first rotation with the fewest mismatches wins.

    Returns:
        (sequence, errors, len_tag, strand). len_tag is -1 when the reference
        is shorter than the contig, 1 when it is longer, 0 when equal. The
        length difference is always included in errors; for an exact match
        errors is the length difference alone. If the contig is empty the
        result is ("", 5000 + length difference, len_tag, "+").
    """
    ref_len = len(ref_seq)
    circ_len = len(circ_seq)
    if ref_len < circ_len:
        len_tag, add_err = -1, circ_len - ref_len
    elif ref_len > circ_len:
        len_tag, add_err = 1, ref_len - circ_len
    else:
        len_tag, add_err = 0, 0

    best_err, best_seq, best_strd = 5000, "", "+"
    for strand in ("+", "-"):
        template = circ_seq if strand == "+" else rev_comp(circ_seq)
        for shift in range(circ_len):
            rotated = template[shift:] + template[:shift]
            # zip stops at the shorter sequence, i.e. min(ref_len, circ_len).
            mismatches = sum(1 for a, b in zip(rotated, ref_seq) if a != b)
            if mismatches == 0:
                return (rotated, add_err, len_tag, strand)
            if mismatches < best_err:
                best_err, best_seq, best_strd = mismatches, rotated, strand
    return (best_seq, best_err + add_err, len_tag, best_strd)

def converse_block_seq(sta,block):
    """Format blocks as offsets relative to *sta*.

    Each block [a, b] becomes '<a-sta>-<b-sta>'; the pieces are separated by
    single spaces and wrapped in parentheses. An empty *block* gives "none".
    """
    origin = int(sta)
    pieces = [str(blk[0] - origin) + "-" + str(blk[1] - origin) for blk in block]
    if not pieces:
        return "none"
    return "(" + " ".join(pieces) + ")"

def annotate_seq(circ_f,pair_f,ref_f, genome_fa, output_f):
    """
    Align fusion junctions to gene annotations and write the result files.

    Parameters:
        circ_f: path of the two-column file with the assembled contigs.
        pair_f: path of the two-column file with supporting pair counts.
        ref_f: path of the gene annotation file (see parse_ref).
        genome_fa: genome object offering fetch(chrom, start, end).
        output_f: path of the summary file; a second file with the suffix
            '_dtails' receives the detailed records.

    Circles overlapping an annotated gene are written first (tagged "exon"
    or "intron"), followed by the remaining ones (tagged "inter"). A circle
    is dropped when its overlap is shorter than 10 with more than one mapping
    error, or when it has fewer than two supporting pairs.
    """
    fa = genome_fa
    genes, gene_info, all_stas, all_ends = parse_ref(ref_f)  # gene annotations
    # The exon boundaries are only ever used for membership tests, and those
    # run many times per candidate; sets avoid scanning a whole chromosome's
    # exon list on every test.
    all_stas = defaultdict(set, ((c, set(v)) for c, v in all_stas.items()))
    all_ends = defaultdict(set, ((c, set(v)) for c, v in all_ends.items()))
    circ_hash_prime = read_file(circ_f)  # read the contigs assembled by cap3
    pair_hash_prime = read_file(pair_f)  # read supporting PE number info
    fusions, pair_hash = check_seq(circ_hash_prime, pair_hash_prime, genome_fa, all_stas, all_ends)

    outlist_ovlap_short = []
    iso_seq = defaultdict(list)
    index_mark = set()  # circle sites that overlap at least one annotation
    for chrom in genes:
        # overlap gene annotations with circle sites
        result = Interval.overlapwith(genes[chrom].interval,
                                      fusions[chrom])
        for itl in result:
            # extract gene annotations
            iso = [x for x in itl[2:] if x.startswith('iso')]
            # for each overlapped circle site
            for fus in itl[(2 + len(iso)):]:
                index_mark.add(fus)
                iso_seq[fus].append(iso)

    # to make order of the output
    iso_seq_keys = sort_keys(iso_seq)
    circ_hash_keys = sort_keys(circ_hash)

    # The context manager closes both files even if a record fails to format.
    with open(output_f, 'w') as f_res, open(output_f + '_dtails', 'w') as f_res_dtails:
        # Exon and intron circRNAs
        for index in iso_seq_keys:
            pair_num = int(pair_hash[index])
            chrom, sta, end = converse_index(index)
            bound_tag = circ_boundary[index]
            spl_site = fa.fetch(chrom, sta-2, sta).upper() + " " + fa.fetch(chrom, end, end+2).upper()
            isos = iso_seq[index]
            adjust_seq, ref_seq, map_err, len_tag, strand, gene, block = compare_annotation(index, isos, genome_fa, gene_info)
            exon_tag = "exon"
            if len_tag == 1:  # the union of exons between the circle sites is longer than the contig
                ad_seq, r_seq, mp_err, l_tag, strd, gn, blk = compare_annotation_onebyone(index, isos, genome_fa, gene_info)
                if mp_err < 2:  # one of the linear transcripts matched
                    adjust_seq = ad_seq
                    ref_seq = r_seq
                    map_err = mp_err
                    strand = strd
                    gene = gn
                    block = blk
                    len_tag = l_tag
            elif len_tag == -1:  # the union of exons between the circle sites is shorter than the contig
                ad_seq, r_seq, mp_err, strd = compare_annotation_circposition(index, genome_fa)
                if mp_err < map_err or 0 == len(ref_seq):  # the genomic sequence between the backsplicing sites matched
                    if 0 == len(ref_seq):
                        exon_tag = "intron"
                    adjust_seq = ad_seq
                    ref_seq = r_seq + "\n"
                    map_err = mp_err
                    strand = strd
                    len_tag = 0
            chrom, sta, end, ovlap, ovlap_err = converse_index_all(index)
            blk_seq = converse_block_seq(sta, block)
            len_seq = len(adjust_seq)
            if ovlap < 10 and map_err > 1:
                outlist_ovlap_short.append(index)
                continue
            if pair_num <= 1:    # supporting PEs for each circle should >=2
                continue
            if map_err == 0:
                f_res_dtails.write('%s_%d_%d\novlap:%d\tovlap_err:%d\tpair_num:%d\tcirc_len:%d\tmap_err:%d\tgene:%s\tbound:%s %s\texon_block:%s\t%s\ncirc_seq:\n%s\n\n' % \
                    (chrom,sta,end,ovlap,ovlap_err,pair_num,len_seq,map_err,gene,spl_site,bound_tag,blk_seq,strand,adjust_seq))
            else:
                f_res_dtails.write('%s_%d_%d\novlap:%d\tovlap_err:%d\tpair_num:%d\tcirc_len:%d\tmap_err:%d\tgene:%s\tbound:%s %s\texon_block:%s\t%s\nref_seq:\n%scirc_seq:\n%s\n\n' % \
                    (chrom,sta,end,ovlap,ovlap_err,pair_num,len_seq,map_err,gene,spl_site,bound_tag,blk_seq,strand,ref_seq,adjust_seq))
            f_res.write('%s_%d_%d\t%d\t%d\t%s\t%d\t%d\t%s\t%d\t%d\n%s\n'%(chrom,sta,end,ovlap,ovlap_err,exon_tag,len_seq,pair_num,spl_site,map_err,len_tag,adjust_seq))

        # Intergenic circRNAs
        for index in circ_hash_keys:
            if index in index_mark:
                # Already reported in the exon/intron pass above.
                continue
            pair_num = int(pair_hash[index])
            chrom, sta, end, ovlap, ovlap_err = converse_index_all(index)
            spl_site = fa.fetch(chrom, sta-2, sta).upper() + " " + fa.fetch(chrom, end, end+2).upper()
            adjust_seq, ref_seq, map_err, strand = compare_annotation_circposition(index, genome_fa)
            len_seq = len(adjust_seq)
            if ovlap < 10 and map_err > 1:
                outlist_ovlap_short.append(index)
                continue
            if pair_num <= 1:  # supporting PEs for each circle should >=2
                continue
            if map_err == 0:
                f_res_dtails.write('%s_%d_%d\novlap:%d\tovlap_err:%d\tpair_num:%d\tcirc_len:%d\tmap_err:%d\tgene:none\tbound:%s\t%s\ncirc_seq:\n%s\n\n' % \
                    (chrom,sta,end,ovlap,ovlap_err,pair_num,len_seq,map_err,spl_site,strand,adjust_seq))
            else:
                f_res_dtails.write('%s_%d_%d\novlap:%d\tovlap_err:%d\tpair_num:%d\tcirc_len:%d\tmap_err:%d\tgene:none\tbound:%s\t%s\nref_seq:\n%s\ncirc_seq:\n%s\n\n' % \
                    (chrom,sta,end,ovlap,ovlap_err,pair_num,len_seq,map_err,spl_site,strand,ref_seq,adjust_seq))
            f_res.write('%s_%d_%d\t%d\t%d\tinter\t%d\t%d\t%s\t%d\t0\n%s\n'%(chrom,sta,end,ovlap,ovlap_err,len_seq,pair_num,spl_site,map_err,adjust_seq))

    # Single-argument print call: same output on Python 2 and 3.
    print("Overlap_short Delete:" + str(len(outlist_ovlap_short)))
    #print_list(outlist_ovlap_short)

    print "Overlap_short Delete:"+str(len(outlist_ovlap_short))
    #print_list(outlist_ovlap_short)



def parse_ref(ref_file):
    """Parse the gene annotation file.

    Each non-blank line is split on whitespace: fields 0-3 are gene id,
    isoform id, chromosome and strand; fields 9 and 10 are comma-terminated
    lists of exon starts and exon ends. Blank lines are skipped.

    Returns:
        (genes, gene_info, all_stas, all_ends) where
        genes maps chromosome -> Interval built from
            [first exon start, last exon end, total_id] entries,
        gene_info maps total_id -> [starts, ends],
        all_stas / all_ends map chromosome -> list of every exon start / end.
        total_id is 'iso', gene id, isoform id, chromosome and strand joined
        by tabs.
    """
    genes = defaultdict(list)
    gene_info = {}
    all_stas = defaultdict(list)
    all_ends = defaultdict(list)
    with open(ref_file, 'r') as f:
        for line in f:
            fields = line.split()  # split once instead of once per column
            if not fields:
                # A blank line (e.g. at EOF) carries no annotation.
                continue
            gene_id, iso_id, chrom, strand = fields[:4]
            total_id = '\t'.join(['iso', gene_id, iso_id, chrom, strand])
            # The lists end with a comma, hence the dropped last element.
            starts = [int(x) for x in fields[9].split(',')[:-1]]
            ends = [int(x) for x in fields[10].split(',')[:-1]]
            genes[chrom].append([starts[0], ends[-1], total_id])
            gene_info[total_id] = [starts, ends]
            all_stas[chrom].extend(starts)
            all_ends[chrom].extend(ends)
    for chrom in genes:
        genes[chrom] = Interval(genes[chrom])
    return (genes, gene_info, all_stas, all_ends)



# Command-line entry point: parse the options, open the genome and run the
# annotation. Results go to <cwd>/<prefix>_output/<prefix>_res.
if __name__ == '__main__':
    if len(sys.argv) == 1:
        sys.exit(__doc__)
    options = docopt(__doc__)
    try:
        genome_fa = pysam.Fastafile(options['--genome'])
    except Exception:
        # Not a bare except: Ctrl-C and SystemExit must not be reported as a
        # broken FASTA file.
        sys.exit('Please make sure %s is a Fasta file and indexed!'
                 % options['--genome'])
    # Single-argument print calls behave identically on Python 2 and 3.
    print("circ_anotation.py Start ..")
    ref_f = options['--ref']
    output_prefix = options['--output']
    output_f = os.getcwd()+'/'+output_prefix+'_output/'+output_prefix + '_res'
    annotate_seq(options['--circ'], options['--pair'], ref_f, genome_fa, output_f)
    print("circ_anotation.py Done!")

