"""
circseq_cup -- circular RNA analysis toolkits.

Usage: pack_and_assemble.py [options]

Options:
    -h --help                      Show this screen.
    -f FUSION --fusion=FUSION      TopHat BAM file. 
    -g GENOME --genome=GENOME      Genome FASTA file.
    -o PREFIX --output=PREFIX      Output prefix [default: circ].
"""


import os
import shutil
import subprocess
import sys
import tempfile
from collections import defaultdict

import pysam
from Bio import SeqIO
from docopt import docopt
from interval import Interval

# IUPAC nucleotide codes mapped to their complements: all lower-case codes
# first, then the same codes in upper case (S, W and N are self-complementary).
COMPLEMENT = dict(zip("atcgkmryswbvhdn" "ATCGKMRYSWBVHDN",
                      "tagcmkyrswvbdhn" "TAGCMKYRSWVBDHN"))


def complement(s):
    """Return the base-by-base complement of ``s`` (not reversed).

    Raises KeyError for any character that is not a key of ``COMPLEMENT``.
    """
    return "".join(COMPLEMENT[base] for base in s)


def rev_comp(seq):
    """Return the reverse complement of ``seq``."""
    forward = complement(seq)
    return forward[::-1]

# Get the joint sequence of each circle
def get_circ_info(real_name,genome_fa,n):
    chrom,sta,end=real_name.split('_')[:3]
    sta=int(sta);end=int(end)
    circ_l=end-sta
    fa=genome_fa
    circ_seq = fa.fetch(chrom,end-n,end)+fa.fetch(chrom,sta, sta+n)
    return  (circ_l,circ_seq)

def convert_fusion(fusion_bam,genome_fa):
    """
    Extract back-splice-supporting read pairs from the fusion BAM file.

    Read pairs come from ``parse_bam(fusion_bam)``. Each pair's reference
    (circle) name is resolved with ``fusion_bam.getrname``.
    ``get_circ_info`` turns that name into:

    * the junction position ``circ_l``;
    * the 20 nt junction sequence ``circ_seq``, fetched from ``genome_fa``.

    A pair is kept when either:

    * exactly one mate spans the junction by >= 10 nt on each side, that
      mate passes ``is_fusion``, and the other mate passes ``part_ok``; or
    * both mates span the junction and both pass ``is_fusion``.

    Kept pairs are written with ``output_reads`` whenever the reference
    name changes, and once more at the end. Per-reference counts of kept
    pairs are written with ``output_reads_num``.
    """
    splice_read = []
    fusion_reads_num = defaultdict(int)
    pre_name = ""
    real_name = ""
    # The junction info depends only on the reference name. Cache it for the
    # current name instead of re-fetching the genome for every read pair.
    info_name = None
    circ_l = circ_seq = None
    for read in parse_bam(fusion_bam):
        qname, rname, sta1, end1, seq1, qual1, sta2, end2, seq2, qual2, strand1 = read
        real_name = fusion_bam.getrname(rname)
        if real_name != pre_name:
            # Flush the pairs collected for the previous reference.
            # NOTE(review): pairs are assumed to arrive grouped by reference.
            # output_reads opens its file with 'w', so a later run of the same
            # reference would overwrite the earlier one.
            output_reads(pre_name, splice_read)
            pre_name = real_name
            splice_read = []
        if real_name != info_name:
            circ_l, circ_seq = get_circ_info(real_name, genome_fa, 10)
            info_name = real_name
        # coverN: mate N's aligned span extends >= 10 nt on both sides of circ_l.
        cover1 = sta1 + 10 <= circ_l <= end1 - 10
        cover2 = sta2 + 10 <= circ_l <= end2 - 10
        # okN: a spanning mate agrees with the reference junction sequence.
        ok1 = cover1 and is_fusion(sta1, end1, seq1, circ_l, circ_seq) == 1
        ok2 = cover2 and is_fusion(sta2, end2, seq2, circ_l, circ_seq) == 1
        if cover1 and cover2:
            supported = ok1 and ok2
        elif cover1:
            supported = ok1 and part_ok(sta2, end2, seq2, circ_l, circ_seq) == 1
        elif cover2:
            supported = ok2 and part_ok(sta1, end1, seq1, circ_l, circ_seq) == 1
        else:
            supported = False
        if supported:
            splice_read.append([qname, seq1, qual1, seq2, qual2, strand1])
            fusion_reads_num[real_name] += 1
    output_reads(real_name, splice_read)
    output_reads_num(fusion_reads_num)

def output_reads(name,splice_read):
    """Write one circle's accepted read pairs as interleaved FASTQ records.

    The target file is ``<temp_dir>/<output_prefix>_output/pack_reads/<name>``.
    ``temp_dir`` and ``output_prefix`` are module-level globals assigned in
    the ``__main__`` block.

    Each pair becomes two 4-line records that share ``qname``. The mate on
    the reverse strand is reverse-complemented and its quality string is
    reversed:

    * if ``strand1`` is "-", this is applied to mate 1;
    * if ``strand1`` is "+", it is applied to mate 2. ``parse_bam`` only
      pairs mates on opposite strands.

    Nothing is written for an empty list. Circles with more than 10000
    pairs are reported on stdout and skipped.
    """
    if not splice_read:
        return
    if len(splice_read) > 10000:
        print (name+'\tis\twith more than 10000 supporting PE and is discarded\n')
        return
    out_path = temp_dir + '/'+output_prefix+'_output/pack_reads/' + name
    record = '@%s\n%s\n+\n%s\n@%s\n%s\n+\n%s\n'
    with open(out_path, 'w') as handle:
        for qname, seq1, qual1, seq2, qual2, strand1 in splice_read:
            if strand1 == "-":
                seq1, qual1 = rev_comp(seq1), qual1[::-1]
            elif strand1 == "+":
                seq2, qual2 = rev_comp(seq2), qual2[::-1]
            handle.write(record % (qname, seq1, qual1, qname, seq2, qual2))

def output_reads_num(fusion_reads_num):
    """Write one tab-separated "name count" line per circle.

    The target file is
    ``<temp_dir>/<output_prefix>_output/<output_prefix>_reads_num``, where
    both names are module globals set in ``__main__``. Lines follow the
    iteration order of ``fusion_reads_num``.
    """
    counts_path = temp_dir + '/'+output_prefix+'_output/'+output_prefix+'_reads_num'
    with open(counts_path, 'w') as handle:
        handle.writelines('%s\t%d\n' % (name, num)
                          for name, num in fusion_reads_num.items())

def is_fusion(sta,end,seq,circ_l,circ_seq):
    """Check a junction-spanning read against the 20 nt joint sequence.

    The joint sequence ``circ_seq`` (10 nt from each end of the circle) is
    compared case-insensitively with the read ``seq``, which is aligned from
    ``sta`` to ``end``.

    * If the aligned span ``end - sta`` equals the read length, the 20 read
      bases centred on offset ``circ_l - sta`` are compared with
      ``circ_seq``, allowing one mismatch (``if_same``).
    * Otherwise (e.g. indels or clipping make the span differ from the read
      length), every 20 nt window of the read is searched for an exact
      match with ``circ_seq``.

    Returns 1 if the read supports the back-splice junction, else 0.
    """
    seq = seq.upper()
    circ_seq = circ_seq.upper()
    seq_len = len(seq)
    if end - sta == seq_len:
        offset = circ_l - sta
        if if_same(seq[offset - 10:offset + 10], circ_seq) == 1:
            return 1
    else:
        # Windows start at 0 .. seq_len - 20 inclusive, so the bound is
        # seq_len - 19. The old bound, seq_len - 20, skipped the last window
        # and never matched a read of exactly 20 nt.
        for i in range(0, seq_len - 19):
            if seq[i:i + 20] == circ_seq:
                return 1
    return 0

def part_ok(sta,end,seq,circ_l,circ_seq):
    """Check a mate that covers less than 20 nt of the joint sequence.

    The part of the read that overlaps the junction region is compared with
    the matching end of the 20 nt joint sequence ``circ_seq``, allowing one
    mismatch (``if_same``).

    Returns 1 in either of these cases, and 0 otherwise:

    * the junction position ``circ_l`` is not strictly inside
      ``(sta, end)``, so there is nothing to check;
    * the overlapping part passes ``if_same``.
    """
    if circ_l <= sta or circ_l >= end:
        return 1
    if sta > circ_l - 10 and sta < circ_l + 10:
        # The read starts within 10 nt before the junction. Assuming an
        # ungapped alignment, its first k bases line up with the last k
        # bases of circ_seq.
        k = circ_l - sta + 10
        if if_same(seq[:k], circ_seq[-k:]) == 1:
            return 1
    elif end > circ_l - 10 and end < circ_l + 10:
        # The read ends within 10 nt after the junction. Its last k bases
        # line up with the first k bases of circ_seq.
        k = end - circ_l + 10
        if if_same(seq[-k:], circ_seq[:k]) == 1:
            return 1
    return 0

def if_same(read_seq,circ_seq):
    """Check whether two sequences agree, allowing at most one mismatch.

    The comparison is case-insensitive and position by position over the
    length of ``read_seq``. ``circ_seq`` is indexed at the same positions,
    so it is expected to be at least as long; otherwise an IndexError is
    raised unless a second mismatch is found first.

    Returns 1 for at most one mismatch, else 0.
    """
    read_up = read_seq.upper()
    ref_up = circ_seq.upper()
    mismatches = 0
    for pos, base in enumerate(read_up):
        if base != ref_up[pos]:
            mismatches += 1
            if mismatches > 1:
                return 0
    return 1


def parse_bam(bam):
    """Yield opposite-strand primary read pairs from ``bam``, once per query.

    Secondary alignments are skipped. The first alignment seen for a query
    name is stored. A later alignment on the other strand completes the
    pair; a later alignment on the same strand is ignored. A query name is
    yielded at most once.

    Pending mates are discarded whenever ``rname`` changes, so pairs are only
    formed within a run of consecutive alignments on the same reference.

    Each yielded item is the list
    ``[qname, rname, sta1, end1, seq1, qual1, sta2, end2, seq2, qual2, strand1]``.
    Mate 1 is the stored alignment and mate 2 the completing one.
    ``strand1`` is '+' or '-'.
    """
    current_ref = 0
    first_mate = {}
    paired = {}
    for aln in bam:
        if aln.rname != current_ref:
            current_ref = aln.rname
            first_mate = {}
            paired = {}
        if aln.is_secondary:
            continue
        strand = '-' if aln.is_reverse else '+'
        name = aln.qname
        if name in first_mate and name in paired:
            continue
        if name not in first_mate:
            first_mate[name] = [name, aln.pos, aln.aend, aln.rname, strand, aln.seq, aln.qual]
            continue
        _, sta1, end1, _, strand1, seq1, qual1 = first_mate[name]
        if strand1 == strand:
            continue
        paired[name] = 1
        yield [name, aln.rname, sta1, end1, seq1, qual1,
               aln.pos, aln.aend, aln.seq, aln.qual, strand1]

def assemble_seq(output_prefix):
    """Assemble the packed reads of every circle with CAP3.

    Each file in ``<output_prefix>_output/pack_reads`` is passed to
    ``splicing_reads``, which writes accepted contigs to
    ``<cwd>/<output_prefix>_output/cap3_circ_res``.

    The scratch directory ``<output_prefix>_output/cap3_output`` is created
    first if missing, and removed afterwards.

    NOTE(review): splicing_reads builds its paths from the module-level
    ``output_prefix`` global, not from this argument. They are identical
    when the file is run as a script.
    """
    temp_dir = os.getcwd()
    out_dir = temp_dir + '/' + output_prefix + '_output'
    cap3_dir = out_dir + '/cap3_output'
    if not os.path.exists(cap3_dir):
        # Call os directly instead of os.system("mkdir " + path), so a prefix
        # containing spaces or shell metacharacters cannot be misparsed.
        os.mkdir(cap3_dir)
    with open(out_dir + '/cap3_circ_res', 'w') as outf:
        for name in os.listdir(output_prefix + "_output/pack_reads"):
            splicing_reads(name, outf)
    # Equivalent of `rm -rf`, without passing an unquoted path through a shell.
    shutil.rmtree(cap3_dir, ignore_errors=True)


def splicing_reads(i,outf):
    """Assemble one circle's packed reads with CAP3 and keep circular contigs.

    ``i`` is a file name in ``<output_prefix>_output/pack_reads``: the FASTQ
    written by ``output_reads``. ``output_prefix`` is the module-level global
    set in ``__main__``. Steps:

    1. Convert the reads to FASTA + QUAL in
       ``<output_prefix>_output/cap3_output``.
    2. Run ``cap3 <fasta> -o 20 -s 300 -j 40``. Its stdout goes to
       ``cap3.out`` in the working directory; the exit status is not checked.
    3. Pass every contig in the resulting ``.cap.contigs`` file to
       ``check_seq``, which writes accepted contigs to ``outf``.
    4. Delete the regular files in the cap3_output directory.
    """
    pack_path = output_prefix + "_output/pack_reads/" + i
    cap3_dir = output_prefix + "_output/cap3_output"
    cap3_base = cap3_dir + "/" + i
    SeqIO.convert(pack_path, "fastq", cap3_base + ".qual", "qual")
    SeqIO.convert(pack_path, "fastq", cap3_base, "fasta")
    # Use an argument list rather than a shell string: ``i`` comes from BAM
    # reference names and must not be interpreted by a shell.
    with open("cap3.out", "w") as cap3_log:
        subprocess.call(["cap3", cap3_base, "-o", "20", "-s", "300", "-j", "40"],
                        stdout=cap3_log)

    with open(cap3_base + ".cap.contigs") as contigs_file:
        all_text = contigs_file.read()
    # Minimal FASTA parsing: each '>' header closes the previous contig. The
    # first header closes an empty one, which check_seq ignores.
    chunks = []
    for line in all_text.split('\n'):
        if line[:1] == '>':
            check_seq(i, "".join(chunks), outf)
            chunks = []
        else:
            chunks.append(line)
    check_seq(i, "".join(chunks), outf)

    # Best-effort cleanup, like the previous `rm <dir>/*` (regular files only).
    for leftover in os.listdir(cap3_dir):
        path = os.path.join(cap3_dir, leftover)
        if os.path.isfile(path):
            os.remove(path)

def check_seq(i,s,outf):
    """Write contig ``s`` of circle ``i`` to ``outf`` if its two ends overlap.

    ``circ_ok`` returns the prefix/suffix overlap length and a flag ``diff``.
    ``diff`` is 1 when the overlap needed one mismatch or none was found.

    Nothing is written for an empty contig, a zero overlap, or a mismatched
    overlap shorter than 10 nt. Otherwise one tab-separated line
    ``<i>_<overlap>_<diff>`` / ``<s>`` is written.
    """
    overlap, diff = circ_ok(s)
    weak_overlap = diff == 1 and overlap < 10
    if weak_overlap or s == "" or overlap == 0:
        return
    outf.write('%s_%s_%s\t%s\n' % (i, overlap, diff, s))

def circ_ok(s):
    """Find the longest matching prefix/suffix of contig ``s``.

    Overlap lengths ``i`` are tried from ``len(s) // 2`` downwards:

    * contigs of 100 nt or less only try ``i >= 20``;
    * longer contigs try ``i >= 5``.

    Returns ``[i, 0]`` for the first (longest) exact prefix/suffix match.
    Otherwise returns ``[j, 1]``, where ``j`` is the longest length whose
    prefix and suffix differ by at most one mismatch (``if_same``), or 0 if
    there is none. An empty contig gives ``[0, 1]``.
    """
    if s == "":
        return [0, 1]
    len_s = len(s)
    # Bug fix: this used to test the builtin ``len`` instead of ``len_s``.
    # That raises TypeError on Python 3 and is always False on Python 2.
    min_range = 19 if len_s <= 100 else 4
    bias_same = 0
    for i in range(len_s // 2, min_range, -1):
        prefix, suffix = s[:i], s[-i:]
        if prefix == suffix:
            return [i, 0]
        # Only the longest approximate match is kept, so skip if_same once found.
        if bias_same == 0 and if_same(prefix, suffix) == 1:
            bias_same = i
    return [bias_same, 1]


if __name__ == '__main__':
    if len(sys.argv) == 1:
        sys.exit(__doc__)
    options = docopt(__doc__)
    # --fusion and --genome have no default in the usage text. Show the usage
    # instead of trying to open a file named None.
    if not options['--fusion'] or not options['--genome']:
        sys.exit(__doc__)
    try:
        fusion_bam = pysam.Samfile(options['--fusion'], 'rb')
    except Exception:  # top-level boundary: report any open failure to the user
        sys.exit('Please make sure %s is BAM file!' % options['--fusion'])
    try:
        genome_fa = pysam.Fastafile(options['--genome'])
    except Exception:
        sys.exit('Please make sure %s is a Fasta file and indexed!'
                 % options['--genome'])
    print('pack_seq Start ..' )
    # output_prefix and temp_dir are module globals that output_reads,
    # output_reads_num and splicing_reads read.
    output_prefix = options['--output']
    temp_dir = os.getcwd()
    pack_dir = temp_dir + '/'+output_prefix+'_output/pack_reads'
    if os.path.exists(pack_dir):
        # Use shutil instead of os.system('rm -rf ' + path). With an unquoted
        # path, a prefix containing spaces would make the shell delete the
        # wrong directories.
        shutil.rmtree(pack_dir)
    # makedirs also creates <prefix>_output itself when it does not exist yet.
    os.makedirs(pack_dir)
    convert_fusion(fusion_bam,genome_fa)
    print('pack_seq Done!' )
    print('assemble_seq Start ..' )
    assemble_seq(output_prefix)
    print('assemble_seq Done ..' )