function [OUT_ALIGN cur_score] = RefineAlignments(IN_ALIGN, varargin)
%   RefineAlignments
%       Uses an iterative algorithm to improve a multiple alignment.  At
%       each iteration the sequences are split into two groups, columns
%       that are all gaps within a group are compressed, and the two
%       groups are re-aligned as profiles.  A re-alignment is kept only if
%       it increases the alignment score.  This is repeated NumReps times.
%
%   Written by Will Dampier.  Contact at wnd22@drexel.edu
%   Version 1.0 on 10/23/09
%   
%   [OUT_ALIGN SCORE] = RefineAlignments(IN_ALIGN, 'Name', Value, ...)
%
%   IN_ALIGN        A char-array representing a multiple alignment to be
%                   refined, one sequence per row with '-' for gaps.
%
%   OUT_ALIGN       The refined alignment (IN_ALIGN itself when it has
%                   fewer than two rows or no improvement was found).
%   SCORE           The score of OUT_ALIGN.
%
%
%   Optional Arguments (name/value pairs):
%
%       'Alphabet'      Either 'AA' or 'NT' to indicate whether this is an
%                       amino acid or nucleotide alignment.  Default 'AA'.
%
%       'GapPenalty'    The (negative) penalty for gaps.  Default -8.
%
%       'DistMat'       The scoring matrix used for the alignment score.
%                       Default is BLOSUM50 for 'AA' and NUC44 for 'NT'.
%                       A gap row and column (GapPenalty, with 0 for
%                       gap-vs-gap) is appended before use.
%                       NOTE: a custom matrix presumably has to be indexed
%                       like the aa2int/nt2int codes -- confirm.
%
%       'NumReps'       The number of refining iterations (>= 1).
%                       Default 100.
%
%       'NumTry'        The number of candidate splits tried at each
%                       iteration (positive integer).  Default 10.
%
%       'RefineType'    'STOCHASTIC' (the legacy spelling 'STOCASTIC' is
%                       also accepted), 'DETERMINISTIC' or 'MIXED'.
%                       'MIXED' alternates between the two each iteration.
%                       Default 'MIXED'.
%
%       'Display'       A logical indicating whether to print a message
%                       each time the score improves.  Default true.
%




RM_FRAC = 0.1;
NUM_TRIES = 10;
NUM_REPS = 100;
alpha = 'aa';
dist_mat = [];
gap_penalty = -8;
TYPE = 'mixed';
disp_flag = true;

% Options must come in name/value pairs; otherwise varargin{i+1} below
% would index past the end.
if mod(numel(varargin), 2) ~= 0
    error('RefineAlignments:BADARGS', ...
        'Optional arguments must be given as name/value pairs')
end

for i = 1:2:length(varargin)
    switch lower(varargin{i})
        
        case 'alphabet'
            if strcmpi('nt', varargin{i+1}) || strcmpi('aa', varargin{i+1})
                alpha = varargin{i+1};
            else
                error('RefineAlignments:BADALPHA', ...
                    'Arguement to "alphabet" must be "NT" or "AA"')
            end
        case 'gappenalty'
            if isnumeric(varargin{i+1}) && varargin{i+1} < 0
                gap_penalty = varargin{i+1};
            else
                error('RefineAlignments:BADGAP', ...
                    'Arguement to "GapPenalty" must be negative numeric')
            end
        case 'distmat'
            dist_mat = varargin{i+1};
            
        case 'numreps'
            % A single refinement pass (1) is a valid request.
            if isnumeric(varargin{i+1}) && isscalar(varargin{i+1}) && ...
                    varargin{i+1} >= 1
                NUM_REPS = varargin{i+1};
            else
                error('RefineAlignments:BADREPS', ...
                        'Arguement to NumReps must be a positive numeric')
            end
            
        case 'numtry'
            % Used as a column count (rand/false) and loop bound, so it
            % must be a positive integer.
            val = varargin{i+1};
            if isnumeric(val) && isscalar(val) && val >= 1 && mod(val, 1) == 0
                NUM_TRIES = val;
            else
                error('RefineAlignments:BADTRY', ...
                        'Argument to NumTry must be a positive integer')
            end
            
        case 'refinetype'
            if ischar(varargin{i+1})
                TYPE = varargin{i+1};
            else
                error('RefineAlignments:BADTYPE', ...
                        'Argument to RefineType must be a string')
            end
            
        case 'display'
            if islogical(varargin{i+1})
                disp_flag = varargin{i+1};
            else
                error('RefineAlignments:DISPLAY', ...
                        'Arguement to Display must be a boolean.')
            end
       
        otherwise
            error('RefineAlignments:BADARG', 'Unknown arguement: %s', ...
                    lower(varargin{i}))
        
    end
end

% Resolve the refinement strategy up front so an invalid value fails
% before any scoring work is done.
switch upper(TYPE)
    case 'MIXED'
        is_stoch = true;
        is_mix = true;
    case {'STOCHASTIC', 'STOCASTIC'}
        is_stoch = true;
        is_mix = false;
    case 'DETERMINISTIC'
        is_stoch = false;
        is_mix = false;
    otherwise
        error('RefineAlignments:BADTYPE', ...
            'RefineType must be "STOCHASTIC", "DETERMINISTIC" or "MIXED"')
end

% The letter-to-integer converter depends only on the alphabet, so it is
% needed whether or not the caller supplied a custom DistMat.
if strcmpi(alpha, 'nt')
    toint = @nt2int;
    if isempty(dist_mat)
        dist_mat = nuc44;
    end
else
    toint = @aa2int;
    if isempty(dist_mat)
        dist_mat = blosum50;
    end
end

% Append a gap row/column; gap-vs-gap scores 0.
dist_mat(end+1,:) = gap_penalty;
dist_mat(:,end+1) = gap_penalty;
dist_mat(end,end) = 0;


OUT_ALIGN = IN_ALIGN;
cur_score = CalculateScore(OUT_ALIGN, dist_mat, toint);

% With fewer than two sequences there is no way to split the alignment
% into two non-empty groups, so there is nothing to refine.
if size(OUT_ALIGN, 1) < 2
    return
end

for i = 1:NUM_REPS
    
    % In MIXED mode alternate between the stochastic and deterministic
    % split strategies on every iteration.
    if is_mix
        is_stoch = ~is_stoch;
    end
    
    rm_rows = GetRows(OUT_ALIGN, is_stoch, RM_FRAC, NUM_TRIES);
    
    % Try every candidate split; keep the first one that improves the
    % score and move on to the next iteration.
    for counter = 1:size(rm_rows, 2)
        new_mat = SplitAndReAlign(OUT_ALIGN, rm_rows(:,counter), alpha);
        
        new_score = CalculateScore(new_mat, dist_mat, toint);
        if new_score > cur_score
            if disp_flag
                disp_cell = {'On Iter: ', num2str(i), ' Improved by: ', ...
                    num2str(abs(new_score-cur_score)), ' and by ', ...
                    num2str(abs(size(new_mat,2)-size(OUT_ALIGN,2))), ...
                    ' columns', ' Using stoch:', num2str(is_stoch)};
                disp([disp_cell{:}])
            end
            cur_score = new_score;
            OUT_ALIGN = new_mat;
            break
        end
    end
end

function rows_mat = GetRows(MAT, STOCH, RM_FRAC, NUM_TRIES)
%   GetRows
%       A helper function which determines the rows used to split the
%       alignment into two groups.  Each column of the result is one
%       candidate split.
%
%   MAT         A char-array of the multiple alignment (one sequence per
%               row, '-' for gaps).
%   STOCH       A boolean indicating whether to use a stochastic approach.
%   RM_FRAC     The fraction of rows to remove.
%   NUM_TRIES   The number of tries (candidate splits) to return.
%
%   rows_mat    A size(MAT,1)-by-NUM_TRIES logical matrix; true marks the
%               rows to remove for that try.
%

num_seqs = size(MAT, 1);

if STOCH
    % Each row is independently marked for removal with probability
    % RM_FRAC.
    rows_mat = rand(num_seqs, NUM_TRIES) < RM_FRAC;
    % A column that is all true or all false would leave one side of the
    % split empty.  Flip the first row of every such column (a per-column
    % mask, not a scalar) so both groups are non-empty.
    degenerate = all(rows_mat, 1) | all(~rows_mat, 1);
    rows_mat(1, degenerate) = ~rows_mat(1, degenerate);
        
else
    gap_mask = MAT == '-';
    runs = FindRuns(gap_mask);
    % Score each column from its gap-run lengths.  Dimensions are given
    % explicitly so a single-row alignment is still scored per column.
    % NOTE(review): the original expression mode(runs.*(runs~=0)) equals
    % mode(runs), i.e. zeros take part in the mode -- confirm whether a
    % mode over non-zero run lengths was intended.
    score = mode(runs, 1) - min(runs, [], 1);
    [~, order] = sort(score, 'descend');
    % Take the NUM_TRIES best columns, cycling through them when the
    % alignment has fewer columns than requested tries.
    pick = order(mod(0:NUM_TRIES-1, numel(order)) + 1);
    % Within each chosen column, rank rows by run length (longest first).
    [~, norder] = sort(runs(:, pick), 1, 'descend');
    rows_mat = false(num_seqs, NUM_TRIES);
    num_remove = ceil(RM_FRAC * num_seqs);
    for k = 1:NUM_TRIES
        rows_mat(norder(1:num_remove, k), k) = true;
    end
end

function NEW_MAT = SplitAndReAlign(MAT, ROWS, alpha)
%   SplitAndReAlign
%       This will split the alignment matrix based on the provided rows,
%       build a profile for each of the two groups, and re-align the
%       profiles against each other.
%
%   MAT         A char-array of the alignment.
%   ROWS        A boolean-array; true rows form the first group and false
%               rows the second.
%   alpha       Either 'NT' or 'AA' to indicate which sequence profile to
%               create.
%
%   NEW_MAT     The re-aligned alignment, with rows in the same order as
%               MAT.  MAT is returned unchanged if either group is empty.
%

ROWS = logical(ROWS(:));

% A profile-profile alignment needs two non-empty groups.
if ~any(ROWS) || all(ROWS)
    NEW_MAT = MAT;
    return
end

mat1 = MAT(ROWS, :);
mat2 = MAT(~ROWS, :);

% Compress columns that are gaps in every sequence of a group.
mat1(:, all(mat1 == '-', 1)) = [];
mat2(:, all(mat2 == '-', 1)) = [];

pmat1 = seqprofile(mat1, 'alphabet', alpha);
pmat2 = seqprofile(mat2, 'alphabet', alpha);
try
    [~, ind1, ind2] = profalign(pmat1, pmat2);
catch
    % Best effort: if profalign fails, place both groups left-aligned
    % without re-aligning them (the caller only keeps the result if it
    % scores better).
    ind1 = 1:size(pmat1, 2);
    ind2 = 1:size(pmat2, 2);
end

% Scatter each group's columns to the positions chosen by the alignment;
% every cell not filled by a group remains a gap.
width = max([ind1(:); ind2(:)]);
NEW_MAT = repmat('-', size(MAT, 1), width);
NEW_MAT(ROWS, ind1) = mat1;
NEW_MAT(~ROWS, ind2) = mat2;
NEW_MAT(~isletter(NEW_MAT)) = '-';
% Drop columns that are gaps in every sequence.  The ':' row subscript is
% required: a bare logical row vector would be a linear index and would
% flatten the matrix into a single row.
NEW_MAT(:, all(NEW_MAT == '-', 1)) = [];




function s = CalculateScore(MAT, DIST, FUN, ALPHA)
%   CalculateScore
%       A function which calculates the score of the alignment.  This score
%       is the sum, over every residue of every sequence, of the DIST entry
%       for (consensus residue of that column, residue).
%
%   MAT         A char-array of a multiple alignment (one sequence per row).
%   DIST        A scoring matrix indexed by the integer codes returned by
%               FUN (including the appended gap row/column).
%   FUN         A function which converts letters to numbers such as
%               nt2int or aa2int.
%   ALPHA       (optional) 'AA' or 'NT', the alphabet used for the
%               consensus.  If omitted, 'NT' is used when FUN is nt2int
%               and 'AA' otherwise.
%   
%   s           The score for this alignment.
%

if nargin < 4 || isempty(ALPHA)
    % Compute the consensus over the same alphabet FUN/DIST expect, so a
    % nucleotide alignment does not get amino-acid consensus letters.
    if strcmpi(func2str(FUN), 'nt2int')
        ALPHA = 'NT';
    else
        ALPHA = 'AA';
    end
end

cons = seqconsensus(MAT, 'Alphabet', ALPHA);
num_cons = FUN(cons);
num_align = FUN(MAT);
% Build a matrix the size of num_align whose every row is the consensus,
% so that column-major flattening pairs each residue with the consensus
% residue of its own column.
ncons = repmat(num_cons(:).', size(num_align, 1), 1);
% The integer codes may come back as an integer class; convert to double
% so the computed linear indices cannot saturate.
lookupinds = sub2ind(size(DIST), double(ncons(:)), double(num_align(:)));
s = sum(DIST(lookupinds));

function run_len = FindRuns(mask)
%   FindRuns
%       Counts consecutive runs of 1's along the rows of a boolean array.
%       Each output element is the position of that element within its
%       run (1 for the first element of a run, 2 for the second, ...) and
%       0 wherever the input is 0.
%
%   MASK = [1 1 1 0 0 0 1 0 1 0 1 1;
%           0 1 1 1 0 1 1 0 0 1 1 1];
%   RL   = [1 2 3 0 0 0 1 0 1 0 1 2;
%           0 1 2 3 0 1 2 0 0 1 2 3];
%
%   mask        An m-by-n logical (or 0/1 numeric) array.
%   run_len     An m-by-n double array of within-run counts.
%

if isempty(mask)
    run_len = zeros(size(mask));
    return
end

num_rows = size(mask, 1);
% Transpose so each input row becomes a column, and prepend a zero so
% every column starts with a reset point; then flatten column-wise.
flat = [zeros(1, num_rows); mask.'];
flat = flat(:);
% At every zero store minus the length of the run of ones that precedes
% it, so the cumulative sum drops back to 0 there and restarts counting.
zero_pos = find(~flat);
flat(zero_pos) = [0; 1 - diff(zero_pos)];
run_len = reshape(cumsum(flat), [], num_rows).';
% Remove the padding column introduced by the prepended zeros.
run_len(:, 1) = [];