%% Test scale (vary 'p', 'n', or 'lambda')
% Runs RPG, ManPG, ManPG_Ada, RAPG and AR_RAPG on randomly generated data
% while sweeping one problem parameter (selected by `mode`), records the
% iteration count, run time and sparsity of every run, writes the results
% to CSV files under data/, and finally plots the per-setting averages.
clear; clc;

mode    = 'lambda';   % choose 'p', 'n', or 'lambda'
num_i   = 4;     % number of experiments (values of p, n, or lambda)
num_alg = 5;     % RPG, ManPG, ManPG-Ada, RAPG, AR-RAPG
num_run = 10;    % number of random runs

% Any mode other than 'p' or 'n' is handled as the lambda sweep. Keep that
% fallback, but warn so that a typo in `mode` does not go unnoticed.
if ~any(strcmp(mode, {'p', 'n', 'lambda'}))
    warning('Unrecognized mode ''%s''; treating it as ''lambda''.', mode);
end

% Swept values, defined once so that the experiment loop and the exported
% x-axis CSV always agree.
if strcmp(mode,'p')
    xvals = (1:num_i)';
    xfile = 'data/p.csv';
elseif strcmp(mode,'n')
    xvals = (16*2.^(1:num_i))';
    xfile = 'data/n.csv';
else
    xvals = (0.25*2.^(1:num_i))';
    xfile = 'data/lambda.csv';
end

% Initialize storage: (experiment, algorithm, run)
Iter_all     = zeros(num_i, num_alg, num_run);
Time_all     = zeros(num_i, num_alg, num_run);
Sparsity_all = zeros(num_i, num_alg, num_run);

% Settings that do not depend on the experiment or the run
maxiter = 10000;
A0 = 0.001; xi = 1; theta = 1; mu = 1; L0 = 0.5;

for i = 1:num_i
    % ----- Problem scale -----
    % Baseline configuration; exactly one entry is replaced by the swept value.
    p = 4; m = 20; n = 128; lambda = 2;
    if strcmp(mode,'p')
        p = xvals(i);        % vary p, fix n and lambda
    elseif strcmp(mode,'n')
        n = xvals(i);        % vary n, fix p and lambda
    else
        lambda = xvals(i);   % vary lambda, fix p and n
    end

    for r = 1:num_run
        rng(10*r);  % different random seed for each run

        % Generate problem data: random m-by-n matrix whose columns are
        % centered and then scaled to unit Euclidean norm.
        A = randn(m, n);
        A = A - repmat(mean(A, 1), m, 1);
        A = A ./ repmat(sqrt(sum(A .* A)), m, 1);

        % Leading p right singular vectors serve as the starting point.
        [~, S, V] = svd(A, 'econ');
        PCAV = V(:, 1:p);
        D = diag(S(1:p, 1:p));   % vector of the p leading singular values
        Dsq = D.^2;

        Xinitial.main = PCAV;
        Lu = norm(Dsq)^2 * 8;
        Ll = Lu / 4;
        tol = 1e-10*n*p;

        q = (mu - L0) / (theta * Ll - L0);

        % Run algorithms (fv1 from RPG is passed on to the later solvers)
        [xopt1, iter1, time1, fv1, nD, sparsity1, avar, fs1, neta1] = RPG(Xinitial, A, Dsq, lambda, Ll, tol, maxiter);

        [xopt2, iter2, time2, fv2, nD, sparsity2, avar, fs2, neta2] = ManPG(Xinitial, A, Dsq, lambda, Ll, tol, maxiter);

        [xopt3, iter3, time3, fv3, nD, sparsity3, avar, fs3, neta3] = ManPG_Ada(Xinitial, A, Dsq, lambda, Ll, tol, maxiter, fv1);

        [xopt4, iter4, time4, fv4, nD, sparsity4, avar, fs4, neta4] = RAPG(Xinitial, A, Dsq, lambda, A0, xi, q, Ll, Lu, tol, maxiter, fv1);

        [xopt5, iter5, time5, fv5, nD, sparsity5, avar, fs5, neta5] = AR_RAPG(Xinitial, A, Dsq, lambda, A0, xi, q, Ll, Lu, tol, maxiter, fv1);

        % Store results
        Iter_all(i, :, r)     = [iter1, iter2, iter3, iter4, iter5];
        Time_all(i, :, r)     = [time1, time2, time3, time4, time5];
        Sparsity_all(i, :, r) = [sparsity1, sparsity2, sparsity3, sparsity4, sparsity5];
    end
end

% Save data
if ~exist('data', 'dir')
    mkdir('data');
end

writematrix(xvals, xfile);

% X(:,:) flattens each 3-D array to num_i-by-(num_alg*num_run); the columns
% hold all algorithms of run 1 first, then all algorithms of run 2, etc.
writematrix(Iter_all(:,:),     ['data/Iter_all_' mode '.csv']);
writematrix(Time_all(:,:),     ['data/Time_all_' mode '.csv']);
writematrix(Sparsity_all(:,:), ['data/Sparsity_all_' mode '.csv']);

%% Plot results with average
plot_boxplots(Iter_all, Time_all, Sparsity_all, mode);

%% ====== Plot Functions ======

function plot_boxplots(Iter_all, Time_all, Sparsity_all, mode)
% PLOT_BOXPLOTS Plot the run-averaged iteration count, time and sparsity.
%   plot_boxplots(Iter_all, Time_all, Sparsity_all, mode) opens three
%   figures, one per metric. Each figure shows one line per algorithm: the
%   mean over the third array dimension (runs) against the swept values.
%   Despite the function name, line plots of means are drawn, not box plots.
%
%   Inputs
%     Iter_all, Time_all, Sparsity_all : arrays indexed as
%         (experiment, algorithm, run).
%     mode : 'p' or 'n'; any other value is treated as the lambda sweep.
%
%   Every figure is saved as EPS and PDF in the folder 'figure', which is
%   created if it does not exist yet.
    num_i = size(Iter_all,1);

    cfg.labels = {'RPG','ManPG','ManPG-Ada','RAPG','AR-RAPG'};
    [cfg.colors, cfg.markers, cfg.styles] = get_alg_styles(cfg.labels);
    cfg.lw = 1.4;
    cfg.ms = 8;

    % ----- x-axis values, limits and legend locations per mode -----
    % The limits are derived from the swept values; for num_i = 4 they are
    % [0.5,4.5] ('p'), [0,288] ('n') and [0,4.5] (lambda).
    if strcmp(mode,'p')
        cfg.xvals  = 1:num_i;
        cfg.xlab   = '$p$';
        cfg.xrange = [0.5, num_i + 0.5];
        loc = {'northwest', 'northwest', 'southeast'};
    elseif strcmp(mode,'n')
        cfg.xvals  = 16*2.^(1:num_i);  % e.g. [32,64,128,256]
        cfg.xlab   = '$n$';
        cfg.xrange = [0, cfg.xvals(1) + cfg.xvals(end)];
        loc = {'northwest', 'northwest', 'northeast'};
    else
        cfg.xvals  = 0.25*2.^(1:num_i);  % e.g. [0.5,1,2,4]
        cfg.xlab   = '$\lambda$';
        cfg.xrange = [0, cfg.xvals(1) + cfg.xvals(end)];
        loc = {'northeast', 'northeast', 'northwest'};
    end

    % saveas does not create missing folders, so make sure the target exists.
    out_dir = 'figure';
    if ~exist(out_dir, 'dir')
        mkdir(out_dir);
    end

    % ===== Iteration =====
    fig1 = draw_mean_figure(Iter_all, cfg, [300 300 400 300], 'Iteration', loc{1});
    save_figure(fig1, out_dir, ['Iteration_' mode]);

    % ===== Time =====
    fig2 = draw_mean_figure(Time_all, cfg, [700 300 400 300], 'Time (s)', loc{2});
    save_figure(fig2, out_dir, ['Time_' mode]);

    % ===== Sparsity =====
    fig3 = draw_mean_figure(Sparsity_all, cfg, [1100 300 400 300], 'Sparsity', loc{3});

    % ---- Auto adjust y-axis range to the spread of all individual runs ----
    ymin = min(Sparsity_all(:));
    ymax = max(Sparsity_all(:));
    if ymax > ymin   % ylim needs strictly increasing limits; else keep auto
        padding = 0.01 * (ymax - ymin);  % margin
        ylim([ymin - padding, ymax + padding]);
    end

    save_figure(fig3, out_dir, ['Sparsity_' mode]);
end

function fig = draw_mean_figure(metric, cfg, fig_pos, ylab, legend_loc)
% DRAW_MEAN_FIGURE Open a figure and plot the mean over runs per algorithm.
%   metric is indexed as (experiment, algorithm, run); cfg carries the
%   x-axis settings and the per-algorithm line styles.
    fig = figure('Position', fig_pos); hold on;

    % Averaging along dimension 3 keeps the result num_i-by-num_alg even
    % when there is a single experiment or a single run.
    mean_vals = mean(metric, 3);
    for alg = 1:size(metric, 2)
        plot(cfg.xvals, mean_vals(:, alg), 'Color', cfg.colors(alg,:), ...
            'Marker', cfg.markers{alg}, 'LineStyle', cfg.styles{alg}, ...
            'LineWidth', cfg.lw, 'MarkerSize', cfg.ms);
    end

    xlabel(cfg.xlab, 'Interpreter', 'latex', 'FontSize', 25);
    ylabel(ylab, 'FontSize', 15);
    xticks(cfg.xvals); xticklabels(string(cfg.xvals)); % force correct labels
    ax = gca;  % Get current axes handle
    ax.XAxis.FontSize = 12;  % x-axis tick labels font size
    ax.YAxis.FontSize = 12;  % y-axis tick labels font size
    box on;

    xlim(cfg.xrange);
    legend(cfg.labels, 'Location', legend_loc);
end

function save_figure(fig, out_dir, stem)
% SAVE_FIGURE Write fig to out_dir as <stem>.eps (color EPS) and <stem>.pdf.
    saveas(fig, fullfile(out_dir, [stem '.eps']), 'epsc');
    saveas(fig, fullfile(out_dir, [stem '.pdf']));
end

function [colors, markers, styles] = get_alg_styles(labels)
% GET_ALG_STYLES Look up the plot color, marker and line style per label.
%   labels  : cell array of algorithm names.
%   colors  : length(labels)-by-3 matrix of RGB triplets, one row per label.
%   markers : length(labels)-by-1 cell array of marker symbols.
%   styles  : length(labels)-by-1 cell array of line styles (always '-').
%   Labels that are not in the table below get black with an 'x' marker.
    known   = {'RPG', 'ManPG', 'ManPG-Ada', 'RAPG', 'AR-RAPG'};
    palette = [0      0.4470 0.7410; ...
               0.8500 0.3250 0.0980; ...
               0.9290 0.6940 0.1250; ...
               0.4940 0.1840 0.5560; ...
               0.4660 0.6740 0.1880];
    glyphs  = {'o', '<', 's', '^', 'v'};

    % Start from the fallback style and overwrite the recognized entries.
    count   = length(labels);
    colors  = zeros(count, 3);
    markers = repmat({'x'}, count, 1);
    styles  = repmat({'-'}, count, 1);

    for idx = 1:count
        hit = find(strcmp(labels{idx}, known), 1);
        if ~isempty(hit)
            colors(idx, :) = palette(hit, :);
            markers{idx}   = glyphs{hit};
        end
    end
end
