
function [xopt, iter, time, fv, nD, sparsity, avar, fs, neta] = test_SparsePCA(Xinitial, A, Dsq, lambda, L, tol, maxiter)
% TEST_SPARSEPCA  Run MPGWH_solver on the sparse PCA objective and report statistics.
%   Min_{X \in Ob} \| X^T A^T A X - D^2\|_F^2 + lambda \|X\|_1
%
%   Inputs (all forwarded unchanged to MPGWH_solver):
%     Xinitial - starting iterate; a struct whose field .main holds the matrix X
%     A        - data matrix appearing in the objective
%     Dsq      - vector placed on the diagonal of D^2 (via diag(Dsq))
%     lambda   - weight of the l1 term
%     L        - constant used by the solver to scale the gradient and the threshold
%     tol      - stopping tolerance on the solver's error measure
%     maxiter  - maximum number of solver iterations
%
%   Outputs:
%     xopt, iter, time, fv, fs, neta - as returned by MPGWH_solver
%     nD       - final value of the solver's error measure
%     sparsity - fraction of entries of xopt.main that are exactly zero
%     avar     - adjusted variance: sum of squared diagonal entries of the
%                R factor of the economy-size QR of A * xopt.main

    % Objective and gradient closures; A, Dsq and lambda are captured here.
    fhandle = @(x)f(x, A, Dsq, lambda);
    gfhandle = @(x)gf(x, A, Dsq, lambda);

    [xopt, iter, time, fv, nD, fs, neta] = MPGWH_solver(fhandle, gfhandle, Xinitial, A, Dsq, lambda, L, tol, maxiter);

    % Fraction of exact zeros in the solution.
    sparsity = nnz(xopt.main == 0) / numel(xopt.main);

    % Adjusted variance. For a square upper-triangular factor this equals
    % trace(Rfac * Rfac), but it is also defined when the factor is not
    % square (A with fewer rows than xopt.main has columns). The local is
    % named Rfac so that it does not shadow the function R in this file.
    [~, Rfac] = qr(A * xopt.main, 0);
    avar = sum(diag(Rfac) .^ 2);
end

function [output, iter, time, fv, err, fs, neta] = MPGWH_solver(fhandle, gfhandle, x0, A, Dsq, lambda, L, tol, maxiter)
% MPGWH_SOLVER  Fixed-step proximal gradient iteration on the oblique manifold.
%
%   Inputs:
%     fhandle  - objective handle, called as [fv, x] = fhandle(x); it also
%                returns the iterate struct (possibly with cached fields)
%     gfhandle - gradient handle, called as g = gfhandle(x)
%     x0       - starting iterate; a struct with the matrix in field .main
%     A, Dsq   - accepted for interface compatibility; not referenced here
%     lambda   - l1 weight; the proximal map is called with lambda / L
%     L        - constant; the gradient is scaled by 1 / L
%     tol      - the loop stops once err <= tol
%     maxiter  - the loop stops once iter reaches maxiter
%
%   Outputs:
%     output - final iterate struct
%     iter   - number of iterations performed
%     time   - elapsed seconds between the initial and the final progress print
%     fv     - objective value at the final iterate
%     err    - last value of norm(L * (x_new - x_old), 'fro')^2 (Inf if no
%              iteration was performed)
%     fs     - fs(k + 1) is the objective value after k iterations
%     neta   - neta(k + 1) is the Frobenius norm of Rinv(x_{k-1}, x_k);
%              neta(1) is 0

    x1 = x0;
    iter = 0;
    err = inf;
    [fv1, x1] = fhandle(x1);
    Rgf1 = gfhandle(x1);
    fs(iter + 1) = fv1;
    neta(iter + 1) = 0;
    fprintf('iter:%d, f:%e\n', iter, fv1);

    % Use an explicit timer handle so that tic/toc calls made elsewhere
    % cannot disturb this measurement.
    tStart = tic;
    while (iter < maxiter && err > tol)
        % Solve the Riemannian proximal mapping column by column.
        x2.main = proximalmap4(x1.main, Rgf1 / L, lambda / L);
        % Tangent vector at x1 pointing to x2 (only its norm is recorded).
        eta = Rinv(x1, x2);
        [fv2, x2] = fhandle(x2);

        % NOTE: no linesearch is performed; the full step x2 is always accepted.

        % Update the gradient, the histories and the stopping measure.
        Rgf2 = gfhandle(x2);
        iter = iter + 1;
        fs(iter + 1) = fv2;
        err = norm((x2.main - x1.main) * L, 'fro')^2;
        neta(iter + 1) = norm(eta, 'fro');
        if (mod(iter, 500) == 0)
            fprintf('iter:%d, f:%e, err:%e\n', iter, fv2, err);
        end
        x1 = x2;
        Rgf1 = Rgf2;
        fv1 = fv2;
    end

    % fv1 equals the most recent fv2 whenever the loop ran, and it is also
    % defined when the loop body never executes (e.g. maxiter <= 0).
    fprintf('iter:%d, f:%e, err:%e\n', iter, fv1, err);
    output = x1;
    fv = fv1;
    time = toc(tStart);
end

function output = Rinv(x, y)
% RINV  Column-wise inverse of the map R (sphere logarithm) on the oblique manifold.
%   For every column i, with c = x.main(:, i)' * y.main(:, i), returns
%       (y_i - c * x_i) * acos(c) / sqrt(1 - c^2),
%   and uses the factor 1 when c equals 1 (coinciding columns).
%
%   Inputs:
%     x, y   - structs whose field .main holds matrices of identical size
%   Output:
%     output - matrix of the same size as x.main
%
%   Antipodal columns (c == -1) are not handled specially: the scale factor
%   is a division by zero there.

    n = size(x.main, 1);

    % Column-wise inner products; the explicit dimension keeps this
    % column-wise even when the matrices have a single row.
    xy = sum(x.main .* y.main, 1);

    % Rounding can push an inner product of unit vectors slightly outside
    % [-1, 1], which would make acos and sqrt return complex values. Clamp
    % with logical indexing so that NaN entries are left untouched.
    c = xy;
    c(c > 1) = 1;
    c(c < -1) = -1;

    scale = ones(size(c));
    moving = (c ~= 1);
    scale(moving) = acos(c(moving)) ./ sqrt(1 - c(moving) .* c(moving));

    % Project y onto the tangent space at x, then rescale each column.
    output = (y.main - x.main .* repmat(xy, n, 1)) .* repmat(scale, n, 1);
end

function output = R(x, eta)
% R  Move each column of x.main along the matching column of eta and renormalise.
%   Column i of the result is
%       x_i * cos(t_i) + eta_i * sin(t_i) / t_i,   t_i = norm of eta_i,
%   divided by its own Euclidean norm. The factor sin(t_i) / t_i is replaced
%   by 1 when t_i < eps.
%
%   Inputs:
%     x   - struct whose field .main holds the base-point matrix
%     eta - matrix of the same size as x.main
%   Output:
%     output - struct with the resulting matrix in field .main

    numRows = size(eta, 1);

    % Length of every column of eta.
    etaLen = sqrt(sum(eta .* eta));

    % sin(t) / t, with the 0 / 0 case patched to its limit value 1.
    sincRatio = sin(etaLen) ./ etaLen;
    sincRatio(etaLen < eps) = 1;

    moved = x.main .* repmat(cos(etaLen), numRows, 1) + eta .* repmat(sincRatio, numRows, 1);

    % Rescale every column to unit length.
    invLen = 1 ./ sqrt(sum(moved .* moved));
    output.main = moved .* repmat(invLen, numRows, 1);
end

function output = VT(x, y, xix)
% VT  Column-wise transform of xix built from the columns of x.main and y.main.
%   For every column i, with s = x_i + y_i,
%       output_i = xix_i - (2 * (xix_i' * y_i) / norm(s)^2) * s.
%
%   Inputs:
%     x, y - structs whose field .main holds matrices of identical size
%     xix  - matrix with the same number of columns as x.main
%   Output:
%     output - matrix of the same size as x.main

    numCols = size(x.main, 2);
    output = zeros(size(x.main));
    for col = 1 : numCols
        colSum = x.main(:, col) + y.main(:, col);
        weight = 2 * (xix(:, col)' * y.main(:, col)) / norm(colSum)^2;
        output(:, col) = xix(:, col) - weight * colSum;
    end
end

function [output, x] = f(x, A, Dsq, lambda)
% F  Objective value lambda * ||X||_1 + ||X' * A' * A * X - diag(Dsq)||_F^2.
%
%   Inputs:
%     x      - struct whose field .main holds X
%     A      - data matrix
%     Dsq    - vector placed on the diagonal via diag(Dsq)
%     lambda - weight of the entry-wise l1 norm
%   Outputs:
%     output - objective value
%     x      - the input struct with two cached fields added:
%              .AX = A * X and .XtAtAXmDsq = (A*X)' * (A*X) - diag(Dsq)
%              (both are read by gf)

    l1Part = norm(x.main(:), 1) * lambda;

    % Cache the products so the gradient does not have to recompute them.
    x.AX = A * x.main;
    x.XtAtAXmDsq = (x.AX' * x.AX - diag(Dsq));

    fitPart = norm(x.XtAtAXmDsq, 'fro')^2;
    output = l1Part + fitPart;
end

function output = gf(x, A, Dsq, lambda)
% GF  Gradient of the smooth term, with the per-column component along X removed.
%   Computes E = 4 * A' * (x.AX * x.XtAtAXmDsq) and returns
%   E - X * diag(diag(X' * E)).
%
%   Inputs:
%     x      - struct with fields .main (X), .AX and .XtAtAXmDsq; the last
%              two are the cached products produced by f
%     A      - data matrix
%     Dsq    - not referenced in the body
%     lambda - not referenced in the body (the l1 term does not enter here)
%   Output:
%     output - matrix of the same size as x.main

    cachedProduct = x.AX * x.XtAtAXmDsq;
    euclidGrad = 4 * A' * cachedProduct;

    % Per-column coefficients x_i' * e_i, kept as a diagonal matrix.
    columnCoeffs = diag(diag(x.main' * euclidGrad));
    output = euclidGrad - x.main * columnCoeffs;
end

% for Oblique manifold
function [output, aveiter] = proximalmap4(Y, etaY, lambda)
% PROXIMALMAP4  Apply proximalmap3 independently to every column of Y.
%
%   Inputs:
%     Y      - matrix whose columns are processed one at a time
%     etaY   - matrix of the same size as Y; column i is paired with Y(:, i)
%     lambda - scalar passed through to proximalmap3
%   Outputs:
%     output  - matrix of the same size as Y holding the per-column results
%     aveiter - mean number of inner iterations reported by proximalmap3

    numCols = size(Y, 2);
    output = zeros(size(Y));
    innerIters = zeros(1, numCols);
    for col = 1 : numCols
        [output(:, col), innerIters(col)] = proximalmap3(Y(:, col), etaY(:, col), lambda);
    end
    aveiter = sum(innerIters) / numCols;
end

% for sphere
%min_{x \in S^n} \|x\|_1 + \|Log_y(x) + etay\|^2 / 2 / lambda = \|x\|_1 + \| acos(x'y) * (I-y y')x / (sqrt(1 - (x'y)^2)) + etay \|_F^2 / 2 / lambda
function [output, iter] = proximalmap3(y, etay, lambda)
% PROXIMALMAP3  Fixed-point iteration for the proximal map at a single column y.
%   Every pass applies proximalmap2 to a vector v and then rebuilds v from
%   the inner products x' * y and etay' * x. The loop stops when both inner
%   products change by at most 1e-10 between passes, or after 50 passes.
%
%   Inputs:
%     y      - base-point column vector
%     etay   - column vector of the same size as y
%     lambda - threshold passed to proximalmap2
%   Outputs:
%     output - last point returned by proximalmap2
%     iter   - number of passes performed

    err = inf;
    x = (y - etay) / norm(y - etay);
    xyold = x' * y;
    etayxold = 0;
    v = y - etay;
    maxiter = 50;
    iter = 0;
    while (err > 1e-10 && iter < maxiter)
        x = proximalmap2(v, lambda);
        xynew = x' * y;
        % Rounding can push the inner product of two unit vectors slightly
        % outside [-1, 1], which would make acos and sqrt below return
        % complex values. The comparisons leave a NaN untouched.
        if (xynew > 1)
            xynew = 1;
        elseif (xynew < -1)
            xynew = -1;
        end
        acosxy = acos(xynew);
        sint = sqrt(1 - xynew * xynew);
        etayxnew = etay' * x;
        if (abs(sint) < 1e-10)
            % x is (numerically) aligned with +/- y: restart from the initial vector.
            v = (y - etay);
        else
            v = acosxy / sint * (y - etay) - (acosxy / sint * etayxnew * xynew / sint / sint - etayxnew / sint / sint) * y;
        end
        % Convergence is measured on the two scalars that determine v.
        err1 = abs(xyold - xynew); err2 = abs(etayxold - etayxnew);
        err = max(err1, err2);
        xyold = xynew;
        etayxold = etayxnew;
        iter = iter + 1;
    end
    output = x;
end

%min_{x \in S^n} \|x\|_1 + \|x - y\|^2 / 2 / lambda
function output = proximalmap2(Y, lambda)
% PROXIMALMAP2  Soft-threshold every column of Y and rescale it to unit length.
%   Each entry is shrunk towards zero by lambda. A column that survives is
%   divided by its Euclidean norm; a column that is thresholded to all zeros
%   gets a single entry, at the position of the largest |Y| in that column,
%   set to the sign of that entry of Y.
%
%   Inputs:
%     Y      - matrix (or column vector) to process
%     lambda - threshold
%   Output:
%     output - matrix of the same size as Y

    % Entry-wise soft thresholding.
    output = sign(Y) .* max(abs(Y) - lambda, 0);

    % Row index of the largest-magnitude entry in every column.
    [~, peakRow] = max(abs(Y));

    numCols = size(Y, 2);
    for col = 1 : numCols
        colLen = norm(output(:, col));
        if (colLen == 0)
            output(peakRow(col), col) = sign(Y(peakRow(col), col));
        else
            output(:, col) = output(:, col) / colLen;
        end
    end
end

