
function [xopt, iter, time, fv, nD, sparsity, avar, fs, neta] = test_SparsePCA(Xinitial, A, Dsq, lambda, A0, xi, q, Ll, Lu, tol, maxiter, Ftol)
% test_SparsePCA  Solve the sparse PCA model
%   min_{X in Ob} ||X^T A^T A X - D^2||_F^2 + lambda ||X||_1
% with MPGWH_solver and report statistics of the solution.
%
% Outputs besides those forwarded from MPGWH_solver (nD is its err output):
%   sparsity - fraction of entries of xopt.main that are exactly zero.
%   avar     - adjusted variance sum(diag(Rf).^2), where A*xopt.main = Q*Rf
%              is the economy-size QR factorization.
    costfun = @(x) f(x, A, Dsq, lambda);
    gradfun = @(x) gf(x, A, Dsq, lambda);

    [xopt, iter, time, fv, nD, fs, neta] = MPGWH_solver(costfun, gradfun, ...
        Xinitial, A, Dsq, lambda, A0, xi, q, Ll, Lu, tol, maxiter, Ftol);

    Xs = xopt.main;
    sparsity = nnz(Xs == 0) / numel(Xs);

    % Adjusted variance. Rf is upper triangular/trapezoidal, so for square Rf
    % this equals trace(Rf * Rf); summing the squared diagonal also works when
    % Rf is not square (A*Xs with fewer rows than columns), where Rf * Rf
    % would be a dimension mismatch. Named Rf to avoid shadowing function R.
    [~, Rf] = qr(A * Xs, 0);
    avar = sum(diag(Rf) .^ 2);
end

function [output, iter, time, fv, err, fs, neta] = MPGWH_solver(fhandle, gfhandle, x0, A, Dsq, lambda, A0, xi, q, Ll, Lu, tol, maxiter, Ftol)
% MPGWH_solver  Accelerated Riemannian proximal gradient method with a
% periodic safeguard step. It minimizes the cost computed by fhandle. The
% l1 part (weight lambda) is handled through the proximal mapping
% proximalmap4.
%
% Every N iterations (N adapts between minN and maxN) a line-searched
% proximal gradient step is taken from the reference point x0. If that step
% strictly beats the current accelerated iterate, the acceleration is
% restarted from it.
%
% Inputs:
%   fhandle   - @(x) returning [cost, x]; the returned x carries cached
%               fields that gfhandle relies on.
%   gfhandle  - @(x) returning the Riemannian gradient of the smooth part;
%               call it on an x returned by fhandle.
%   x0        - initial iterate: struct with field main (presumably an
%               n-by-p matrix with unit-norm columns - TODO confirm).
%   A, Dsq    - problem data; not used directly here (captured by handles).
%   lambda    - l1 weight, passed (scaled by 1/L) to the proximal mapping.
%   A0, xi, q - parameters of the acceleration sequence
%               (see mpgwh_accel_update below).
%   Ll        - initial value of the step parameter L (prox uses grad / L).
%   Lu        - cap used when L is increased after a failed safeguard line
%               search; also scales the stopping measure err.
%   tol       - stop when err = (||Dy||_F * Lu)^2 <= tol.
%   maxiter   - maximum number of iterations.
%   Ftol      - stop when the cost falls to Ftol + 1e-6 or below.
%
% Outputs:
%   output - final iterate (the initial point if no iteration ran).
%   iter   - number of completed iterations.
%   time   - seconds measured by tic/toc around the main loop.
%   fv     - cost at output.
%   err    - last stopping measure (inf if no iteration ran).
%   fs     - fs(1) is the initial cost; fs(k+1) is the min of the iterate
%            cost and the safeguard reference cost after iteration k.
%   neta   - neta(1) = 0; neta(k+1) = ||Dy||_F at iteration k.
    delta = 0.0001;   % sufficient-decrease constant of the safeguard search
    gamma = 0.5;      % backtracking shrink factor
    err = inf;
    A1 = A0;
    [A2, gamma1, beta] = mpgwh_accel_update(A1, xi, q);
    x1 = x0;
    y1 = x1;
    z1 = x1;
    [fx1, x1] = fhandle(x1);
    gfx1 = gfhandle(x1);
    L0 = Lu;
    L = Ll;
    iter = 0;
    fs(iter + 1) = fx1;
    neta(iter + 1) = 0;
    totalbt = 0;
    gfy1 = gfx1;
    % Invariant: fx0 is always the cost at the safeguard reference point x0.
    fx0 = fx1;
    gfx0 = gfx1;
    fprintf('iter:%d, f:%e, ngf:%e\n', iter, fx1, norm(gfx1, 'fro'));
    N = 5;
    maxN = 10;
    minN = 2;
    justrestart = 0;
    SafeGuardIter = N;
    btiter = 0;
    tic;
    while(err > tol && fx1 > Ftol + 1e-6 && iter < maxiter)
        % ---- safeguard step from x0 ----
        if(iter == SafeGuardIter)
            if(justrestart == 1)
                % Right after a restart, x0 == y1 and this direction was
                % already computed in the main step (Dybackup).
                Dx = Dybackup;
                justrestart = 0;
            else
                xc.main = proximalmap4(x0.main, gfx0 / L, lambda / L);
                Dx = Rinv(x0, xc);
            end
            alpha = 1;
            xc = R(x0, alpha * Dx);
            [fxc, xc] = fhandle(xc);
            btiter = 0;
            while(fxc > fx0 - delta * alpha * norm(Dx, 'fro')^2 && btiter < 3)
                alpha = alpha * gamma;
                xc = R(x0, alpha * Dx);
                [fxc, xc] = fhandle(xc);
                btiter = btiter + 1;
                totalbt = totalbt + 1;
            end
            % NOTE: btiter == 3 is treated as a failed search even if the
            % third backtracking step happened to satisfy the condition.
            if(btiter == 3)
                Lnew = min(L0, L * 1.1);
                if(Lnew ~= L)
                    % Retry the safeguard with a shorter step.
                    L = Lnew;
                    continue;
                end
                % L is already at its cap: retrying would recompute the same
                % direction and loop forever. Fall through and let the
                % strict comparison below decide whether to accept xc.
            end
            % If the safeguard takes effect, this must be a strict
            % inequality; otherwise the algorithm could stall at small N.
            % (iter ~= 0 also guards fx2, which is undefined before the
            % first main step.)
            if(iter ~= 0 && fxc < fx2)
                % Restart the acceleration from the safeguard point.
                gfxc = gfhandle(xc);
                y1 = xc; gfy1 = gfxc;
                x1 = xc; fx1 = fxc;
                A1 = A0;
                [A2, gamma1, beta] = mpgwh_accel_update(A1, xi, q);
                z1 = x1;
                if(N ~= maxN)
                    L = L * 1.1;
                end
                N = max(N - 1, minN);
                justrestart = 1;
            else
                N = min(N + 1, maxN);
            end
            % Move the reference point to the current iterate.
            x0 = x1;
            fx0 = fx1;
            gfx0 = gfhandle(x0);
        end

        % ---- Riemannian proximal gradient step from y1 ----
        x2.main = proximalmap4(y1.main, gfy1 / L, lambda / L);
        Dy = Rinv(y1, x2);
        if(justrestart == 1 && iter == SafeGuardIter)
            % Saved so the next safeguard check does not recompute it.
            Dybackup = Dy;
        end
        if(iter == SafeGuardIter)
            % Schedule the next safeguard check.
            SafeGuardIter = SafeGuardIter + N;
        end
        [fx2, x2] = fhandle(x2);

        % ---- acceleration: new auxiliary points z2 and y2 ----
        v = beta * Rinv(y1, z1) + gamma1 * Dy;
        tmp = VT(y1, x2, v - Dy);
        z2 = R(x2, tmp);
        A1 = A2;
        [A2, gamma1, beta] = mpgwh_accel_update(A1, xi, q);
        tau = beta * A2 / (gamma1 * A1 + beta * A2);
        zeta = tau * tmp;
        y2 = R(x2, zeta);

        % ---- bookkeeping ----
        iter = iter + 1;
        fs(iter + 1) = min(fx2, fx0);
        neta(iter + 1) = norm(Dy, 'fro');
        err = (norm(Dy, 'fro') * L0)^2;
        [~, y2] = fhandle(y2);   % populate cached fields needed by gfhandle
        gfy1 = gfhandle(y2);
        y1 = y2; z1 = z2;
        x1 = x2; fx1 = fx2;

        if(mod(iter, 500) == 0)
            fprintf('iter:%d, fx:%e, N:%d, btiter:%d, err:%e\n', iter, min(fx1, fx0), N, btiter, err);
        end
    end
    fprintf('iter:%d, fx:%e, err:%e, totalbt:%d\n', iter, min(fx1, fx0), err, totalbt);
    % After any iteration x1/fx1 equal the last x2/fx2. Unlike x2, they are
    % also defined when the loop body never ran.
    output = x1;
    fv = fx1;
    time = toc;
end

function [A2, gamma1, beta] = mpgwh_accel_update(A1, xi, q)
% mpgwh_accel_update  Given the current acceleration coefficient A1, return
% the next coefficient A2 and the extrapolation weights gamma1 and beta used
% by MPGWH_solver.
    A2 = (2 * xi * A1 + xi + sqrt(4 * xi^2 * A1 + xi^2 + 4 * xi * A1^2 * q)) / (xi - q) / 2;
    gamma1 = (A2 - A1) / (xi + A2 * q);
    beta = (xi + A1 * q) / (xi + A2 * q);
end

function output = Rinv(x, y)
% Rinv  Inverse exponential map (Riemannian logarithm) on the oblique
% manifold, applied column by column. With c_i = x_i' * y_i,
%   Log_{x_i}(y_i) = (y_i - c_i x_i) * acos(c_i) / sqrt(1 - c_i^2),
% and the factor acos(c)/sqrt(1 - c^2) is replaced by its limit 1 when
% c_i == 1.
% x, y : structs whose field main is an n-by-p matrix (presumably with
%        unit-norm columns - TODO confirm at callers).
% output : n-by-p matrix.
% NOTE: c_i == -1 (antipodal columns) divides by zero; Log is undefined.
    [n, p] = size(x.main);
    c = sum(x.main .* y.main);
    % Rounding can push |c_i| slightly past 1, where acos and sqrt return
    % complex values. Clamp to the valid domain (NaN is left untouched).
    c(c > 1) = 1;
    c(c < -1) = -1;
    scale = ones(1, p);
    mask = (c ~= 1);
    % (1 - c)(1 + c) is more accurate than 1 - c^2 when c is close to +-1.
    scale(mask) = acos(c(mask)) ./ sqrt((1 - c(mask)) .* (1 + c(mask)));
    output = (y.main - x.main .* repmat(c, n, 1)) .* repmat(scale, n, 1);
end

function output = R(x, eta)
% R  Exponential-map retraction on the oblique manifold, column by column:
%   R_x(eta)_i = cos(|eta_i|) x_i + sin(|eta_i|) / |eta_i| * eta_i,
% after which every column is rescaled to unit 2-norm. Columns whose
% norm |eta_i| is below eps use the limit sin(t)/t = 1.
% x   : struct with field main (n-by-p).
% eta : n-by-p matrix.
% Returns a struct whose only field is main.
    theta = sqrt(sum(eta .* eta));
    sinc = sin(theta) ./ theta;
    sinc(theta < eps) = 1;

    moved = bsxfun(@times, x.main, cos(theta)) + bsxfun(@times, eta, sinc);

    invnorms = 1 ./ sqrt(sum(moved .* moved));
    output.main = bsxfun(@times, moved, invnorms);
end

function output = VT(x, y, xix)
% VT  Vector transport on the oblique manifold, column by column:
%   out_i = xi_i - 2 (xi_i' y_i) / ||x_i + y_i||^2 * (x_i + y_i).
% For unit-norm x_i, y_i and xi_i orthogonal to x_i, out_i is orthogonal
% to y_i (moves a tangent vector at x into the tangent space at y).
% x, y : structs with field main (n-by-p).
% xix  : n-by-p matrix of vectors to transport.
% NOTE: a column with y_i == -x_i divides by zero.
    output = zeros(size(x.main));
    ncols = size(x.main, 2);
    for k = 1 : ncols
        xik = xix(:, k);
        s = x.main(:, k) + y.main(:, k);
        coef = 2 * (xik' * y.main(:, k)) / norm(s)^2;
        output(:, k) = xik - coef * s;
    end
end

function [output, x] = f(x, A, Dsq, lambda)
% f  Sparse PCA cost with X = x.main:
%   lambda * ||X(:)||_1 + ||X' A' A X - diag(Dsq)||_F^2.
% The returned struct caches A*X (field AX) and X'A'AX - diag(Dsq)
% (field XtAtAXmDsq); gf reads these fields, so evaluate f on x first.
    AX = A * x.main;
    x.AX = AX;
    x.XtAtAXmDsq = (AX' * AX - diag(Dsq));
    output = lambda * norm(x.main(:), 1) + norm(x.XtAtAXmDsq, 'fro')^2;
end

function output = gf(x, A, Dsq, lambda)
% gf  Riemannian gradient on the oblique manifold of the smooth part
%   ||X' A' A X - diag(Dsq)||_F^2,   X = x.main,
% of the cost evaluated by f. The Euclidean term is
%   Egf = 4 A' (A X) (X'A'AX - diag(Dsq)),
% and each column is projected by removing its component along X:
%   grad_i = Egf_i - X_i (X_i' Egf_i).
% x must carry the cached fields AX and XtAtAXmDsq produced by f for the
% same x.main. Dsq and lambda are unused here (uniform handle signature).
    X = x.main;
    Egf = 4 * A' * (x.AX * x.XtAtAXmDsq);
    % Only diag(X' * Egf) is needed: compute it as column-wise dot products
    % (O(n p)) instead of forming the full p-by-p product (O(n p^2)).
    d = sum(X .* Egf, 1);
    output = Egf - bsxfun(@times, X, d);
end

% proximalmap4  Riemannian proximal mapping on the oblique manifold, applied
% column by column: column k of output is proximalmap3(Y(:, k), etaY(:, k),
% lambda). aveiter is the mean number of inner passes reported by
% proximalmap3 (NaN when Y has no columns, since it is 0/0).
function [output, aveiter] = proximalmap4(Y, etaY, lambda)
    ncols = size(Y, 2);
    output = zeros(size(Y));
    passes = zeros(1, ncols);
    for k = 1 : ncols
        [output(:, k), passes(k)] = proximalmap3(Y(:, k), etaY(:, k), lambda);
    end
    aveiter = sum(passes) / ncols;
end

% proximalmap3  Riemannian proximal mapping for a single column (unit sphere):
%   min_{x \in S^n} \|x\|_1 + \|Log_y(x) + etay\|^2 / 2 / lambda
%   = \|x\|_1 + \| acos(x'y) * (I-y y')x / (sqrt(1 - (x'y)^2)) + etay \|_F^2 / 2 / lambda
% Each pass forms a vector v from y, etay and the scalars x'y and etay'x of
% the current x, then sets x = proximalmap2(v, lambda) (soft-thresholding
% followed by normalization). It stops when both scalars change by at most
% 1e-10 between passes, or after maxiter = 50 passes.
% Inputs:
%   y      - base point, column vector (presumably unit-norm - TODO confirm).
%   etay   - column vector of the same length (proximalmap4 passes a column
%            of the scaled gradient it receives).
%   lambda - scalar weight (the solver passes lambda / L).
% Outputs:
%   output - final x.
%   iter   - number of passes performed.
function [output, iter] = proximalmap3(y, etay, lambda)
    err = inf;
    % This x only seeds xyold; it is overwritten on the first pass.
    % NOTE(review): gives NaN if y == etay (division by zero norm).
    x = (y - etay) / norm(y - etay);
    xyold = x' * y;
    etayxold = 0;
    % The first pass thresholds y - etay directly.
    v = y - etay;
    %theta = min(pi-0.01, norm(etay));
    maxiter = 50;
    iter = 0;
    while(err > 1e-10 && iter < maxiter)
        x = proximalmap2(v, lambda);
        xynew = x' * y;
        % NOTE(review): rounding can make |xynew| slightly larger than 1, in
        % which case acos and sqrt below return complex values; consider
        % clamping xynew to [-1, 1] - verify against expected results.
        acosxy = acos(xynew);
        sint = sqrt(1 - xynew * xynew);
        etayxnew = etay' * x;
        if(abs(sint) < 1e-10)
            % |x'y| is essentially 1: the coefficients below would divide by
            % ~0, so fall back to the first-pass vector.
            v = (y - etay);
        else
            % NOTE(review): for sint small but above 1e-10, the bracketed
            % coefficient divides a nearly cancelling difference by sint^2,
            % which may lose precision - confirm the threshold is adequate.
            v = acosxy / sint * (y - etay) - (acosxy / sint * etayxnew * xynew / sint / sint - etayxnew / sint / sint) * y;
        end
        % Converged when both tracked scalars have stopped changing.
        err1 = abs(xyold - xynew); err2 = abs(etayxold - etayxnew);
        err = max(err1, err2);
        xyold = xynew;
        etayxold = etayxnew;
%         x = proximalmap2(v, lambda);
        iter = iter + 1;
%         fprintf('iter:%d, fv:%e, x''y:%e, etay''x:%e err1:%e, err2:%e\n', iter, fprox3(x, y, lambda, etay), xynew, etayxnew, err1, err2);
    end
    % Note: output is the x from the last proximalmap2 call, not a final
    % thresholding of the last v.
    output = x;
end

% proximalmap2  Solves, for each column y of Y,
%   min_{x \in S^n} \|x\|_1 + \|x - y\|^2 / 2 / lambda
% On the sphere this reduces to soft-thresholding y by lambda and scaling
% the result to unit 2-norm. If thresholding zeroes the whole column, x is
% the signed coordinate vector at the entry of largest |y| (first such
% entry on ties).
% NOTE: an all-zero input column yields an all-zero output column (sign(0)).
function output = proximalmap2(Y, lambda)
    output = sign(Y) .* max(abs(Y) - lambda, 0);

    [~, peak] = max(abs(Y));
    for k = 1 : size(Y, 2)
        colnorm = norm(output(:, k));
        if(colnorm == 0)
            output(peak(k), k) = sign(Y(peak(k), k));
        else
            output(:, k) = output(:, k) / colnorm;
        end
    end
end
