% This script runs the numerical experiments of LADM and our IARPG algorithm
% in subsection 5.2.

clear;
clc;

% Fix the random seed so that every run is reproducible.
rng('default');
% seed = floor(rand() * 100000);
seed = 20;
fprintf('TestFRankETextureInpainting seed:%d\n', seed);
rng(seed);


% Select the algorithm: 'IARPG' or 'LADM'.
TestMethod = 'IARPG';
% TestMethod = 'LADM';


% fullfile inserts the platform's file separator, so these paths work on
% Windows as well as on Linux/macOS.
InputFileName = fullfile('..', '..', 'data', 'artificial_image', 'image.png');    % known image file name
MaskFileName = fullfile('..', '..', 'data', 'artificial_image', 'mask.png');      % mask file name
OutputFileName = fullfile('..', '..', 'experiment5.2', 'artificial_image', 'result', ...
    TestMethod, [TestMethod, '.png']);                                             % output image file name (in png format)

A = double(imread(InputFileName)); % read known image
% NOTE: assumes a grayscale (2-D) image; norm() and the [m, n] size split
% below do not handle RGB (3-D) arrays.

E = A;       % output canvas: known pixels are kept, masked pixels are filled in later
A_true = A;  % untouched copy, used as the PSNR reference


% Normalize by the matrix 2-norm (largest singular value) of the image.
normA = norm(A);
A = A / normA;
[m, n] = size(A);  % only needed by the random initialization alternative below

Mask = logical(imread(MaskFileName)); % read mask; true marks the pixels to recover


r = 2;  % the exact rank is 2

A(Mask) = 0;  % unknown pixels are zeroed before being handed to the solver


% Initial point: best rank-r approximation (truncated SVD) of the masked image.
[U, D, V] = svd(A);
U = U(:, 1:r);
D = D(1:r, 1:r);
V = V(:, 1:r);

% Random initial point (alternative):
%     U = orth(randn(m, r));
%     D = randn(r, r);
%     V = orth(randn(n, r));

Xinitial.main = U * D * V';
Xinitial.U = U;
Xinitial.D = D;
Xinitial.V = V;
% type also selects how xopt1.main is mapped back to an image below:
% 2 -> dct2, anything else -> haarFWT_2d_inverse.
type = 2;
% Parameters forwarded to the solver; presumably only used by LADM
% (method == 2) -- TODO confirm against TestFRankETextureInpainting.
LADMmu = 0.1;
LADMrho = 1.1;
LADMeta = 3;


% Method code passed to the solver (1 = IARPG, 2 = LADM) and its weight lambda.
switch TestMethod
    case 'IARPG'
        method = 1;
        lambda = 0.00004;
    case 'LADM'
        method = 2;
        lambda = 0.00001;
    otherwise
        error('Unknown TestMethod ''%s''; expected ''IARPG'' or ''LADM''.', TestMethod);
end


SolverParams.method = 'IARPG'; % IRPG IARPG
SolverParams.IsCheckParams = 1;
%     SolverParams.RPGVariant = 0; %0: RPG without adaptive stepsize, 1: RPG with adaptive stepsize
SolverParams.LengthW = 1;
SolverParams.OutputGap = 10;
SolverParams.Max_Iteration = 500;
% SolverParams.SMtol = 1e-2;
SolverParams.Stop_Criterion = 3;
SolverParams.Tolerance = 1e-3;
%     SolverParams.Min_Iteration = 10;
SolverParams.Verbose = 2;

tic;
xopt1 = TestFRankETextureInpainting(sparse(A), lambda, Xinitial, type, method, SolverParams, LADMmu, LADMrho, LADMeta);
time = toc

xopt1.main = xopt1.main * normA;  % undo the normalization

% Map the solver output back to the image domain with the transform that
% matches `type`, then fill in only the unknown pixels.
if type == 2
    rA = dct2(xopt1.main);
else
    rA = haarFWT_2d_inverse(xopt1.main);
end
E(Mask) = rA(Mask);

% imwrite does not create missing folders, so make sure the target exists.
outputDir = fileparts(OutputFileName);
if ~isempty(outputDir) && ~exist(outputDir, 'dir')
    mkdir(outputDir);
end
imwrite(uint8(E), OutputFileName); % save the output

% print PSNR
PSNR = psnr(uint8(E), uint8(A_true));
fprintf('PSNR:%.2f\n', PSNR);


% Show the output: panel 1 is the masked input, panel 2 the IARPG result,
% panel 3 the LADM result (only the panel of the method just run is drawn).
figure(1);
subplot(1, 3, 1);
imshow(uint8(normA * A));  % A still has the masked pixels set to 0
title('Masked image');
subplot(1, 3, 1 + method);
imshow(uint8(E));
title(['Recovered image ', TestMethod]);




function M = haarFWT_2d(M)
% HAARFWT_2D  Multilevel 2-D Haar transform (forward).
%   M = haarFWT_2d(M) transforms along dimension 1 first, then along
%   dimension 2. Along each dimension only the leading K entries take part,
%   where K is the largest power of two not exceeding that dimension's
%   length; any trailing entries are left unchanged. At each level (half
%   size k = K/2, K/4, ..., 1) the first 2*k entries are replaced by the k
%   pairwise sums followed by the k pairwise differences, each divided by
%   sqrt(2). haarFWT_2d_inverse reverses this. Not called by the script above.
[nRows, nCols] = size(M);
s = sqrt(2);

% ---- dimension 1 -------------------------------------------------------
K = 1;
while 2 * K <= nRows
    K = 2 * K;
end
half = K / 2;
while half >= 1
    a = M(1:2:2 * half - 1, :);   % odd-indexed entries of each pair
    b = M(2:2:2 * half, :);       % even-indexed entries of each pair
    M(1:2 * half, :) = [(a + b) / s; (a - b) / s];
    half = half / 2;
end

% ---- dimension 2 -------------------------------------------------------
K = 1;
while 2 * K <= nCols
    K = 2 * K;
end
half = K / 2;
while half >= 1
    a = M(:, 1:2:2 * half - 1);
    b = M(:, 2:2:2 * half);
    M(:, 1:2 * half) = [(a + b) / s, (a - b) / s];
    half = half / 2;
end
end

function M = haarFWT_2d_inverse(M)
% HAARFWT_2D_INVERSE  Inverse of haarFWT_2d.
%   M = haarFWT_2d_inverse(M) undoes dimension 2 first and dimension 1
%   second (the reverse of the forward order). At each level (half size
%   k = 1, 2, 4, ... while 2*k fits in the dimension) the first k entries
%   ("sums") and the next k entries ("differences") are recombined into 2*k
%   interleaved entries, each divided by sqrt(2). Entries beyond the
%   largest power of two are left unchanged.
[nRows, nCols] = size(M);
s = sqrt(2);

% ---- dimension 2 -------------------------------------------------------
k = 1;
while 2 * k <= nCols
    lo = M(:, 1:k);            % copies taken before M is overwritten
    hi = M(:, k + 1:2 * k);
    M(:, 1:2:2 * k - 1) = (lo + hi) / s;
    M(:, 2:2:2 * k) = (lo - hi) / s;
    k = 2 * k;
end

% ---- dimension 1 -------------------------------------------------------
k = 1;
while 2 * k <= nRows
    lo = M(1:k, :);
    hi = M(k + 1:2 * k, :);
    M(1:2:2 * k - 1, :) = (lo + hi) / s;
    M(2:2:2 * k, :) = (lo - hi) / s;
    k = 2 * k;
end
end