% Compare method 1 (IRPG, configured via SolverParams) against method 2 (LADM)
% of TestFRankETextureInpainting on texture inpainting.
%
% For every (n, r, LADM lambda) setting and every trial:
%   1. A random n-by-n patch of the texture in texture1.mat is extracted.
%   2. The patch is divided by 255 and then by its spectral norm.
%   3. Roughly half of its pixels are set to zero. funvalue treats zero
%      entries as unobserved.
%   4. Both solvers start from the same rank-r truncated SVD of the
%      corrupted patch.
%
% Per-run statistics are stored in tab(stat, column, trial):
%   stat 1: iteration count        stat 2: funvalue (residual on observed entries)
%   stat 3: funerr (vs. clean)     stat 4: sparsity of the coefficient matrix
%   stat 5: numerical rank         stat 6: ComTime reported by the solver
% Columns 1..ncase hold method-1 results; ncase+1..2*ncase hold method-2 results.
% The trial-averaged table is written as LaTeX rows to IRPGvsLADMtab.txt.

clear all;

ns = [256, 256, 256, 256]; % larger sizes (e.g. 512) may be appended
rs = [20, 40, 60, 80];
LADMlambdas = [0.01, 0.04, 0.16, 0.64];
ncase = numel(ns);

numtest = 1;
nstats = 6;      % rows of tab, see header
rankTol = 1e-4;  % singular values above this count toward the numerical rank

tab = zeros(nstats, 2 * ncase, numtest);

S = load('texture1.mat', 'cdata');
% grayflower = rgb2gray(S.cdata);
grayflower = double(S.cdata);

% Settings shared by every run (loop invariant).
type = 1;        % 1: haar, 2: dct
LADMmu = 0.1;
LADMrho = 1.1;
LADMeta = 3;
IRPGlambda = 0.0002;

SolverParams.method = 'IRPG'; % IRPG IARPG
SolverParams.IsCheckParams = 1;
SolverParams.Variant = 2; % 0: RPG without adaptive stepsize, 1: RPG with adaptive stepsize, 2: BB
SolverParams.LengthW = 1;
SolverParams.OutputGap = 10;
SolverParams.Max_Iteration = 500;
SolverParams.Stop_Criterion = 3;
SolverParams.Tolerance = 1e-3;
SolverParams.ProxMapType = 0;
SolverParams.Verbose = 2;

for k = 1 : ncase
    n1 = ns(k);
    n2 = n1;
    r = rs(k);
    LADMlambda = LADMlambdas(k);

    for i = 1 : numtest
        fprintf('case %d, trial %d\n', k, i);
        rng(i);

        % Random n1-by-n2 patch, scaled to [0,1] and then to unit spectral norm.
        startxidx = randi(size(grayflower, 1) - n1 + 1);
        startyidx = randi(size(grayflower, 2) - n2 + 1);
        A = grayflower(startxidx : startxidx + n1 - 1, startyidx : startyidx + n2 - 1) / 255;
        A = A / norm(A);
        Atrue = A;
        [m, n] = size(A);

        % Hide roughly half of the pixels by zeroing them.
        A(rand(m, n) > 0.5) = 0;

        % Common starting point: rank-r truncated SVD of the corrupted patch.
        % Field order (main, U, D, V) is kept as before in case the solver
        % depends on it.
        [U, D, V] = svd(A);
        U = U(:, 1:r);
        D = D(1:r, 1:r);
        V = V(:, 1:r);
        Xinitial = struct('main', U * D * V', 'U', U, 'D', D, 'V', V);

        % method: 1: AManPG, 2: LADM
        % All 15 outputs are requested so nargout matches the original call.
        method = 1;
        lambda = IRPGlambda;
        [xopt1, ~, ~, ~, iter1, ~, ~, ~, ~, ~, ~, ComTime1, ~, ~, ~] = ...
            TestFRankETextureInpainting(sparse(A), lambda, Xinitial, type, method, SolverParams, LADMmu, LADMrho, LADMeta);
        % Compute each metric once. It is shown on the console and reused for tab.
        stats1 = [funvalue(xopt1, A, type), funerr(xopt1, Atrue, type), ...
                  sparsity(xopt1, type), sum(svd(xopt1.main) > rankTol)]

        method = 2;
        lambda = LADMlambda;
        [xopt2, iter2, ComTime2] = ...
            TestFRankETextureInpainting(sparse(A), lambda, Xinitial, type, method, SolverParams, LADMmu, LADMrho, LADMeta);
        stats2 = [funvalue(xopt2, A, type), funerr(xopt2, Atrue, type), ...
                  sparsity(xopt2, type), sum(svd(xopt2.main) > rankTol)]

        % double() keeps the column double even if the solver returns
        % integer-class counters. Mixed concatenation would otherwise
        % truncate to the integer class.
        tab(:, k, i) = [double(iter1); stats1(:); double(ComTime1)];
        tab(:, k + ncase, i) = [double(iter2); stats2(:); double(ComTime2)];
    end
end

avetab = mean(tab, 3);

outfile = 'IRPGvsLADMtab.txt';
fout = fopen(outfile, 'w');
if (fout < 0)
    error('Cannot open %s for writing.', outfile);
end
for i = 1 : size(avetab, 1)
    cells = cell(1, size(avetab, 2));
    for j = 1 : size(avetab, 2)
        if (i == 1 || i == 5)
            % Iteration count and rank are printed as integers.
            cells{j} = sprintf('%d', round(avetab(i, j)));
        else
            cells{j} = sprintf('$%s$', outputfloat(avetab(i, j)));
        end
    end
    % One LaTeX table row: entries joined by ' & ', terminated by '\\'.
    fprintf(fout, '%s\\\\\n', strjoin(cells, ' & '));
end
fclose(fout);

function str = outputfloat(x)
% OUTPUTFLOAT Format a scalar as a 3-significant-digit mantissa with a
% power-of-ten subscript, for use inside a LaTeX math environment.
%   Examples: 0.0123 -> '1.23_{-2}', 3.14159 -> '3.14', -2500 -> '-2.50_{3}'.
%   Returns '0.00' for zero and num2str(x) for NaN/Inf.
    if (x == 0)
        % log10(0) is -Inf, which would otherwise yield NaN garbage.
        str = '0.00';
        return;
    end
    if (~isfinite(x))
        str = num2str(x);
        return;
    end
    if (x < 0)
        sn = '-';
        x = -x;
    else
        sn = '';
    end
    p = floor(log10(x));
    m = round(x / 10^p * 100) / 100;
    if (m >= 10)
        % Rounding carried into the next decade (e.g. 9.999 -> 10.00);
        % renormalise so the mantissa stays in [1, 10).
        m = m / 10;
        p = p + 1;
    end
    strx = sprintf('%3.2f', m);
    if (p ~= 0)
        str = [sn strx '_{' num2str(p) '}'];
    else
        str = [sn strx];
    end
end

function output = funvalue(x, A, type)
% FUNVALUE Relative residual of the reconstruction on the nonzero entries of A.
%   The coefficient matrix x.main is mapped to image space, using dct2 when
%   type == 2 and haarFWT_2d_inverse otherwise. Entries where A is zero are
%   then forced to zero in the reconstruction. The result is
%   ||A - recovered||_F / ||A||_F.
%   NOTE(review): the type == 2 branch applies dct2 rather than idct2;
%   confirm this matches how the solver represents DCT coefficients.
    if type == 2
        recovered = dct2(x.main);
    else
        recovered = haarFWT_2d_inverse(x.main);
    end

    % Zero entries of A are treated as unobserved and excluded from the fit.
    unobserved = (A == 0);
    recovered(unobserved) = 0;

    residualNorm = norm(A - recovered, 'fro');
    output = residualNorm / norm(A, 'fro');
end

function output = funerr(x, Atrue, type)
% FUNERR Relative reconstruction error with respect to the clean image.
%   The coefficient matrix x.main is mapped to image space, using dct2 when
%   type == 2 and haarFWT_2d_inverse otherwise. The result is
%   ||Atrue - recovered||_F / ||Atrue||_F over all entries.
    if type == 2
        recovered = dct2(x.main);
    else
        recovered = haarFWT_2d_inverse(x.main);
    end
    errNorm = norm(Atrue - recovered, 'fro');
    refNorm = norm(Atrue, 'fro');
    output = errNorm / refNorm;
end

function output = sparsity(x, type) %#ok<INUSD>
% SPARSITY Fraction of entries of x.main whose magnitude is below 1e-4.
%   TYPE is accepted but not used.
    isTiny = abs(x.main) < 1e-4;
    output = sum(isTiny(:)) / numel(isTiny);
end


function M = haarFWT_2d(M)
% HAARFWT_2D Full-depth orthonormal 2-D Haar wavelet transform.
%   The transform is applied along dimension 1 first, then along dimension 2.
%   At each level, the pairs (a, b) in the active leading block are replaced
%   by [(a + b) / sqrt(2); (a - b) / sqrt(2)]: sums go to the first half of
%   the block, differences to the second half. The block then halves.
%   Along a dimension of length n, only the leading 2^floor(log2(n))
%   entries are transformed; any remainder is left untouched.
    [nRows, nCols] = size(M);
    s = sqrt(2);

    % Dimension 1: start from the largest power of two <= nRows.
    len = 1;
    while 2 * len <= nRows
        len = 2 * len;
    end
    while len > 1
        upper = M(1:2:len, :);
        lower = M(2:2:len, :);
        M(1:len, :) = [(upper + lower) / s; (upper - lower) / s];
        len = len / 2;
    end

    % Dimension 2: same butterflies, applied across columns.
    len = 1;
    while 2 * len <= nCols
        len = 2 * len;
    end
    while len > 1
        leftCols = M(:, 1:2:len);
        rightCols = M(:, 2:2:len);
        M(:, 1:len) = [(leftCols + rightCols) / s, (leftCols - rightCols) / s];
        len = len / 2;
    end
end

function M = haarFWT_2d_inverse(M)
% HAARFWT_2D_INVERSE Inverse of the full-depth orthonormal 2-D Haar transform.
%   Undoes the dimension-2 levels first, then the dimension-1 levels, each
%   from the coarsest level to the finest. At each level, the coarse
%   entries c (first half of the active block) and detail entries d
%   (second half) are interleaved back as (c + d) / sqrt(2) at odd
%   positions and (c - d) / sqrt(2) at even positions. The block then
%   doubles while it still fits in the dimension.
    [nRows, nCols] = size(M);
    s = sqrt(2);

    % Dimension 2.
    len = 1;
    while 2 * len <= nCols
        coarse = M(:, 1:len);
        detail = M(:, len + 1:2 * len);
        M(:, 1:2:2 * len) = (coarse + detail) / s;
        M(:, 2:2:2 * len) = (coarse - detail) / s;
        len = 2 * len;
    end

    % Dimension 1.
    len = 1;
    while 2 * len <= nRows
        coarse = M(1:len, :);
        detail = M(len + 1:2 * len, :);
        M(1:2:2 * len, :) = (coarse + detail) / s;
        M(2:2:2 * len, :) = (coarse - detail) / s;
        len = 2 * len;
    end
end

