
function [xopt, iter, time, fv, nD, sparsity, avar, fs, neta] = test_SparsePCA(Xinitial, A, lambda, A0, xi, q, Ll, Lu, tol, maxiter, Ftol)
% Min_{x \in S^n} - x^T A^T A x + lambda \|x\|_1
%
% test_SparsePCA  Run MPGWH_solver on the sparse PCA objective above and
% report solution statistics.
%
%   Xinitial is the starting point (a struct with field .main). The remaining
%   solver parameters are forwarded unchanged to MPGWH_solver.
%
%   Besides the solver outputs (xopt, iter, time, fv, nD, fs, neta), returns:
%     sparsity - fraction of entries of xopt.main that are exactly zero
%     avar     - trace(R*R), where R is the economy-size QR factor of A*xopt.main
    objective = @(x)f(x, A, lambda);
    riemGrad = @(x)gf(x, A, lambda);

    [xopt, iter, time, fv, nD, fs, neta] = MPGWH_solver(objective, riemGrad, ...
        Xinitial, A, lambda, A0, xi, q, Ll, Lu, tol, maxiter, Ftol);

    % Exact zeros only; no thresholding is applied.
    sparsity = nnz(xopt.main == 0) / numel(xopt.main);

    % adjusted variance
    [~, Rfac] = qr(A * xopt.main, 0);
    avar = trace(Rfac * Rfac);
end

function [output, iter, time, fv, err, fs, neta] = MPGWH_solver(fhandle, gfhandle, x0, A, lambda, A0, xi, q, Ll, Lu, tol, maxiter, Ftol)
% MPGWH_solver  Accelerated Riemannian proximal gradient loop.
%
%   Each iteration does the following:
%     1. Takes a proximal step from the extrapolated point y (proximalmap3,
%        gradient scaled by 1/Ll, l1 weight lambda/Ll).
%     2. Builds the next auxiliary point z with Rinv, VT and R.
%     3. Builds the next extrapolated point y with Rinv, VT and R.
%   The coefficients come from the scalar recursion A_k seeded by A0, xi and q.
%
%   Inputs
%     fhandle   @(x) -> [objective value, x with cached fields]
%     gfhandle  @(x) -> gradient used in the proximal step (gf in test_SparsePCA)
%     x0        starting point, a struct with field .main
%     A         unused here (kept for the caller's interface)
%     lambda    l1 weight; passed to proximalmap3 as lambda / Ll
%     A0, xi, q parameters of the A_k / gamma / beta / tau recursion
%     Ll        constant L of the proximal step
%     Lu        constant scaling the stopping measure err
%     tol       loop runs while err > tol
%     maxiter   maximum number of iterations
%     Ftol      unused (it only appeared in a disabled stopping test)
%
%   Outputs
%     output  last proximal iterate
%     iter    number of iterations performed
%     time    seconds from just before the loop until return (tic/toc)
%     fv      objective value at output
%     err     last stopping measure (norm(Dy, 'fro') * Lu)^2
%     fs      fs(1) = f(x0); fs(k+1) = min(f(x_k), f(x0))
%     neta    neta(1) = 0;   neta(k+1) = norm(Dy, 'fro') at iteration k
    L = Ll;   % constant of the proximal subproblem
    L0 = Lu;  % scale of the stopping measure

    % A_{k+1} as a function of A_k (same formula at setup and in the loop).
    nextA = @(a) (2 * xi * a + xi + sqrt(4 * xi^2 * a + xi^2 + 4 * xi * a^2 * q)) / (xi - q) / 2;

    err = inf;
    A1 = A0;
    A2 = nextA(A1);
    gamma1 = (A2 - A1) / (xi + A2 * q);
    beta = (xi + A1 * q) / (xi + A2 * q);

    y1 = x0;
    z1 = x0;
    [fx1, x1] = fhandle(x0);
    % f(x0) is loop-invariant: cache it rather than re-evaluating fhandle(x0)
    % every iteration and at every progress print.
    fx0 = fx1;
    gfx1 = gfhandle(x1);
    gfy1 = gfx1;

    iter = 0;
    fs(iter + 1) = fx1;
    neta(iter + 1) = 0;
    fprintf('iter:%d, f:%e, ngf:%e\n', iter, fx1, norm(gfx1, 'fro'));
    x2 = x1; fx2 = fx1;
    tic;
    while(err > tol && iter < maxiter)
        % solve the Riemannian proximal mapping
        x2.main = proximalmap3(y1.main, gfy1 / L, lambda / L);
        Dy = Rinv(y1, x2);
        [fx2, x2] = fhandle(x2);

        % compute iteration point
        v = beta * Rinv(y1, z1) + gamma1 * Dy;
        tmp = VT(y1, x2, v - Dy);
        z2 = R(x2, tmp);
        A1 = A2;
        A2 = nextA(A1);
        gamma1 = (A2 - A1) / (xi + A2 * q);
        beta = (xi + A1 * q) / (xi + A2 * q);
        tau = (xi + A1 * q) * A2 / ((A2 - A1) * A1 + (xi + A1 * q) * A2);
        y2 = R(x2, tau * tmp);

        % update
        iter = iter + 1;
        nDy = norm(Dy, 'fro');
        fs(iter + 1) = min(fx2, fx0);
        neta(iter + 1) = nDy;
        err = (nDy * L0)^2;

        % Refresh the cached fields of y before evaluating the gradient there.
        [~, y2] = fhandle(y2);
        gfy1 = gfhandle(y2);
        y1 = y2; z1 = z2;
        fx1 = fx2;

        if(mod(iter, 500) == 0)
            fprintf('iter:%d, fx:%e, err:%e\n', iter, min(fx1, fx0), err);
        end
    end
    fprintf('iter:%d, fx:%e, err:%e\n', iter, min(fx1, fx0), err);
    output = x2;
    fv = fx2;
    time = toc;
end

function output = Rinv(x, y)
% Rinv  Inverse exponential (logarithm) map on the unit sphere, Log_x(y).
%   x, y are structs whose .main fields are points on the unit sphere.
%   Returns (y - x*t) * acos(t) / sqrt(1 - t^2) with t = x.main' * y.main.
%   Returns a zero vector when t >= 1, i.e. when the points coincide.
%   NOTE: for antipodal points (t == -1) the log map is undefined and this
%   formula divides by zero.
    output = zeros(size(x.main));
    tmp = x.main' * y.main;
    % Rounding can push the inner product of two unit vectors slightly outside
    % [-1, 1]. acos/sqrt would then return complex values, so clamp it.
    tmp = min(max(tmp, -1), 1);
    if(tmp ~= 1)
        output = (y.main - x.main * tmp) * acos(tmp) / sqrt(1 - tmp * tmp);
    end
end

function output = R(x, eta)
% R  Retraction on the unit sphere.
%   Moves x.main along eta using x*cos(|eta|) + eta/|eta|*sin(|eta|), which
%   is the sphere exponential map when eta is tangent at x.main. It then
%   rescales the result to unit norm. For |eta| < eps the sin(t)/t factor is
%   replaced by its limit 1. Returns a struct with the single field .main.
    theta = norm(eta);
    if abs(theta) < eps
        direction = eta;
    else
        direction = eta / theta * sin(theta);
    end
    moved = x.main * cos(theta) + direction;
    output = struct('main', moved / norm(moved));
end

function output = VT(x, y, xix)
% VT  Vector transport from x.main to y.main on the unit sphere.
%   Returns xix - 2*(xix'*y)/||x+y||^2 * (x+y). For unit x.main and y.main
%   the result is orthogonal to y.main. It is undefined when x.main = -y.main,
%   because the denominator vanishes.
    s = x.main + y.main;
    coeff = 2 * (xix' * y.main) / norm(s)^2;
    output = xix - coeff * s;
end

function [output, x] = f(x, A, lambda)
% f  Sparse PCA objective  -||A*x.main||_F^2 + lambda*||x.main||_1.
%   Caches A*x.main in x.Ax (read later by gf) and returns the updated struct
%   as the second output.
    x.Ax = A * x.main;
    fitNorm = norm(x.Ax, 'fro');
    penalty = lambda * sum(abs(x.main(:)));
    output = -fitNorm * fitNorm + penalty;
end

function output = gf(x, A, lambda)
% gf  Riemannian gradient of the smooth term -||A*x||^2.
%   Forms the Euclidean gradient -2*A'*(A*x), using the x.Ax field cached by f.
%   It then subtracts x.main times the symmetric part of (egrad' * x.main).
%   lambda is not used: the l1 term is handled by the proximal map instead.
    egrad = -2 * (A' * x.Ax);
    G = egrad' * x.main;
    symPart = (G + G') / 2;
    output = egrad - x.main * symPart;
end

% for Oblique manifold
function [output, aveiter] = proximalmap4(Y, etaY, lambda)
% proximalmap4  Column-wise Riemannian proximal map (oblique manifold).
%   Applies proximalmap3 independently to each column pair (Y(:,j), etaY(:,j)).
%   Returns the stacked results and the mean number of inner iterations.
%   For a Y with no columns this mean is 0/0 = NaN.
    ncols = size(Y, 2);
    output = zeros(size(Y));
    iterCounts = zeros(1, ncols);
    for col = 1 : ncols
        [output(:, col), iterCounts(col)] = proximalmap3(Y(:, col), etaY(:, col), lambda);
    end
    aveiter = sum(iterCounts) / ncols;
end

%min_{x \in S^n} \|x\|_1 + \|Log_y(x) + etay\|^2 / 2 / lambda = \|x\|_1 + \| acos(x'y) * (I-y y')x / (sqrt(1 - (x'y)^2)) + etay \|_F^2 / 2 / lambda
function [output, iter] = proximalmap3(y, etay, lambda, tol, maxiter)
% proximalmap3  Fixed-point iteration for the Riemannian proximal map on S^n.
%
%   Each step does two things:
%     1. Soft-thresholds and normalizes v via proximalmap2, giving a candidate x.
%     2. Rebuilds v from the scalars x'*y and etay'*x.
%   The loop stops when both scalars change by at most tol between steps, or
%   after maxiter steps.
%
%   Inputs
%     y        point on the unit sphere (column vector)
%     etay     vector at y (e.g. a scaled gradient)
%     lambda   l1 weight passed to proximalmap2
%     tol      optional stopping tolerance (default 1e-10)
%     maxiter  optional iteration cap (default 50)
%
%   Outputs
%     output   final x (unit norm unless proximalmap2 returns otherwise)
%     iter     number of iterations performed
    if nargin < 4 || isempty(tol)
        tol = 1e-10;
    end
    if nargin < 5 || isempty(maxiter)
        maxiter = 50;
    end

    err = inf;
    x = (y - etay) / norm(y - etay);
    xyold = x' * y;
    etayxold = 0;
    v = y - etay;
    iter = 0;
    while(err > tol && iter < maxiter)
        x = proximalmap2(v, lambda);
        % x and y are unit vectors, but rounding can push x'*y slightly
        % outside [-1, 1]. acos/sqrt would then turn complex and corrupt v.
        xynew = min(max(x' * y, -1), 1);
        acosxy = acos(xynew);
        sint = sqrt(1 - xynew * xynew);
        etayxnew = etay' * x;
        if(abs(sint) < 1e-10)
            % x is (numerically) +-y: fall back to the unweighted direction.
            v = (y - etay);
        else
            v = acosxy / sint * (y - etay) - (acosxy / sint * etayxnew * xynew / sint / sint - etayxnew / sint / sint) * y;
        end
        err1 = abs(xyold - xynew);
        err2 = abs(etayxold - etayxnew);
        err = max(err1, err2);
        xyold = xynew;
        etayxold = etayxnew;
        iter = iter + 1;
    end
    output = x;
end

%min_{x \in S^n} \|x\|_1 + \|x - y\|^2 / 2 / lambda
function output = proximalmap2(Y, lambda)
% proximalmap2  Soft-threshold Y by lambda, then rescale each column to unit norm.
%   A column that thresholds to all zeros is replaced by the signed unit basis
%   vector at that column's largest-magnitude entry of Y. If that column of Y
%   is entirely zero, every signed basis vector is optimal, and +e_idx is
%   returned so the column still lies on the unit sphere.
    output = max(abs(Y) - lambda, 0) .* sign(Y);

    % Explicit dim 1: per-column argmax even when Y is a 1-by-n row.
    [~, idx] = max(abs(Y), [], 1);
    for i = 1 : size(Y, 2)
        nn = norm(output(:, i));
        if(nn ~= 0)
            output(:, i) = output(:, i) / nn;
        else
            s = sign(Y(idx(i), i));
            if(s == 0)
                s = 1;
            end
            output(idx(i), i) = s;
        end
    end
end
