% This script runs the numerical experiments for the LADM algorithm in
% Subsection 5.3.

clear;
clc;

rng('default');
% seed = floor(rand() * 100000);
seed = 20;
rng(seed);

%% select missing types and rates

% select missing type
MissingType = 'R';  % random corruption
% MissingType = 'S';  % random small block
% MissingType = 'M';  % random medium blocks
% MissingType = 'L';  % large central blocks


% select the missing rate
if strcmp(MissingType,'R')
    MissingRate = 0.4:0.1:0.7;
else
    MissingRate = 0.1:0.1:0.2;
end

%% paths
% fullfile inserts the platform's file separator, so the same script runs
% unchanged on Windows and on Linux/macOS.
DataDir   = fullfile('..', '..', 'data');
ResultDir = fullfile('..', '..', 'experiment5.3', 'texture_database', 'result', MissingType, 'LADM');

%% settings shared by every image (loop-invariant, so set once)

% rank selection: keep the smallest r whose leading singular values carry a
% fraction t2 of the Frobenius norm, capped at floor(t1*min(m,n))
t1 = 0.2;
t2 = 0.995;

type = 2;
LADMmu = 0.1;
LADMrho = 1.1;
LADMeta = 3;

% method: 1: AManPG, 2: LADM
%     method = 1;
%     lambda = 0.0001;
method = 2;
lambda = 0.08;

SolverParams.method = 'IARPG'; % IRPG IARPG
SolverParams.IsCheckParams = 1;
%     SolverParams.RPGVariant = 0; %0: RPG without adaptive stepsize, 1: RPG with adaptive stepsize
SolverParams.LengthW = 1;
SolverParams.OutputGap = 10;
SolverParams.Max_Iteration = 500;
% SolverParams.SMtol = 1e-2;
SolverParams.Stop_Criterion = 3;
SolverParams.Tolerance = 1e-3;
%     SolverParams.Min_Iteration = 10;
SolverParams.Verbose = 2;

%% test the LADM algorithm
% time_LADM(k) is the total solver time over all images at MissingRate(k).
% It is indexed by loop position, not by round(rate*10): the latter writes
% past the preallocated length for the 'R' rates 0.4:0.1:0.7.
time_LADM = zeros(size(MissingRate));
for k = 1:numel(MissingRate)
    rate = MissingRate(k);
    RateStr = num2str(rate);

    PngDir = fullfile(ResultDir, 'png', RateStr);
    MatDir = fullfile(ResultDir, 'mat', RateStr);
    % imwrite/save fail when the target folder is missing, so create it first
    if ~exist(PngDir, 'dir'), mkdir(PngDir); end
    if ~exist(MatDir, 'dir'), mkdir(MatDir); end

    sum_time = 0;
    for i = 1:112
        if i == 14 % the 14th image is not in the database
            continue;
        end
        InputFileName = fullfile(DataDir, 'texture_database', ['D', num2str(i), '.png']); % known image file name
        if strcmp(MissingType, 'L')
            MaskFileName = fullfile(DataDir, ['mask_', MissingType], ['M', RateStr, '.png']); % one mask per rate
        else
            MaskFileName = fullfile(DataDir, ['mask_', MissingType], RateStr, ['M', num2str(i), '.png']);
        end
        OutputFileName = fullfile(PngDir, ['LADM', num2str(i), '.png']);     % output (png format)
        OutputFileName_mat = fullfile(MatDir, ['LADM', num2str(i), '.mat']); % output (mat format)

        A = double(imread(InputFileName));
        E = A; % full image; its missing pixels are overwritten with the reconstruction below

        normA = norm(A);
        A = A / normA;
        [m, n] = size(A);

        % Mask is true at the missing pixels (they are zeroed before solving
        % and replaced by the reconstruction in E at the end)
        Mask = logical(imread(MaskFileName));

        % Fill the missing pixels with the mean of the KNOWN pixels only, so the
        % rank estimate does not see the true values of the missing pixels.
        A(Mask) = sum(A(~Mask)) / sum(~Mask(:));
        [U, D, V] = svd(A);

        % determine r: first index where the cumulative energy reaches t2
        % (falls back to full rank if the threshold is never reached)
        diagD = diag(D);
        normD = sqrt(sum(diagD.^2));
        r = find(sqrt(cumsum(diagD.^2)) / normD >= t2, 1, 'first');
        if isempty(r)
            r = min(m, n);
        end
        r = min(r, floor(t1 * min(m, n)));

        % the solver receives the image with the missing entries set to zero
        A(Mask) = 0;

        U = U(:, 1:r);
        D = D(1:r, 1:r);
        V = V(:, 1:r);

        % rank-r truncated SVD of the mean-filled image as the starting point
        %     U = orth(randn(m, r));
        %     D = randn(r, r);
        %     V = orth(randn(n, r));
        Xinitial.main = U * D * V';
        Xinitial.U = U;
        Xinitial.D = D;
        Xinitial.V = V;

        %     [xopt, f, gf, gfgf0, iter, nf, ng, nR, nV, nVp, nH, ComTime, funs, grads, times] = ...
        %         TestFRankETextureInpainting(sparse(A), lambda, Xinitial, type, method, SolverParams);

        % Use a timer handle: a bare toc would report time since the most
        % recent tic anywhere, including any tic issued inside the solver.
        t0 = tic;
        [xopt1] = TestFRankETextureInpainting(sparse(A), lambda, Xinitial, type, method, SolverParams, LADMmu, LADMrho, LADMeta);
        sum_time = sum_time + toc(t0);

        xopt1.main = xopt1.main * normA; % undo the normalization
        % NOTE(review): the solver output is passed through dct2 to obtain the
        % image; presumably type = 2 selects a DCT-based model -- confirm.
        rA = dct2(xopt1.main);
        save(OutputFileName_mat, 'rA');
        E(Mask) = rA(Mask);
        imwrite(uint8(E), OutputFileName);
    end
    time_LADM(k) = sum_time;
end
time_LADM


%% Helper functions: 2-D Haar wavelet transform and its inverse (currently not called by the script above)

function M = haarFWT_2d(M)
%HAARFWT_2D Multilevel 2-D Haar wavelet transform with 1/sqrt(2) scaling.
%   M = HAARFWT_2D(M) runs a full-depth 1-D Haar decomposition along
%   dimension 1 and then along dimension 2.
%
%   Along each dimension only the leading P entries are transformed. P is
%   the largest power of two not exceeding that dimension's size; any
%   entries beyond P are returned unchanged.
%
%   Each level takes the first SPAN entries and pairs them as (x1,x2),
%   (x3,x4), ... It writes the sums (a+b)/sqrt(2) into positions
%   1..SPAN/2 and the differences (a-b)/sqrt(2) into SPAN/2+1..SPAN. SPAN
%   then halves until it reaches 1.
%
%   See also HAARFWT_2D_INVERSE.
    s = sqrt(2);
    [rows, cols] = size(M);

    % --- along dimension 1 ---
    span = 1;
    while 2 * span <= rows
        span = 2 * span;
    end
    while span > 1
        odd  = M(1:2:span, :);
        even = M(2:2:span, :);
        M(1:span, :) = [odd + even; odd - even] / s;
        span = span / 2;
    end

    % --- along dimension 2 ---
    span = 1;
    while 2 * span <= cols
        span = 2 * span;
    end
    while span > 1
        odd  = M(:, 1:2:span);
        even = M(:, 2:2:span);
        M(:, 1:span) = [odd + even, odd - even] / s;
        span = span / 2;
    end
end

function M = haarFWT_2d_inverse(M)
%HAARFWT_2D_INVERSE Inverse of HAARFWT_2D.
%   M = HAARFWT_2D_INVERSE(M) undoes the multilevel Haar transform, first
%   along dimension 2 and then along dimension 1.
%
%   Each pass starts from the coarsest level (H = 1) and doubles H while
%   2*H fits in the dimension. At each level, the entries 1..H (sums) and
%   H+1..2H (differences) are recombined into interleaved pairs:
%   odd positions get (lo+hi)/sqrt(2) and even positions get
%   (lo-hi)/sqrt(2). Entries beyond the largest power of two are left
%   unchanged.
%
%   See also HAARFWT_2D.
    s = sqrt(2);
    [rows, cols] = size(M);

    % --- along dimension 2 ---
    h = 1;
    while 2 * h <= cols
        lo = M(:, 1:h);
        hi = M(:, h+1:2*h);
        M(:, 1:2:2*h) = (lo + hi) / s;
        M(:, 2:2:2*h) = (lo - hi) / s;
        h = 2 * h;
    end

    % --- along dimension 1 ---
    h = 1;
    while 2 * h <= rows
        lo = M(1:h, :);
        hi = M(h+1:2*h, :);
        M(1:2:2*h, :) = (lo + hi) / s;
        M(2:2:2*h, :) = (lo - hi) / s;
        h = 2 * h;
    end
end