
function [xopt, iter, time, fv, nD, sparsity, avar, fs, neta] = Driver_ManPG(Xinitial, A, Dsq, lambda, L, tol, maxiter, Ftol)
    % Driver_ManPG  Sparse PCA on the oblique manifold (unit-norm columns) via ManPG.
    %
    %   Solves  min_{X in Ob}  || X^T A^T A X - diag(Dsq) ||_F^2 + lambda * ||X||_1.
    %
    %   Inputs
    %     Xinitial  struct whose field .main is the starting point X0
    %     A         data matrix; the code forms A * X
    %     Dsq       vector whose diag(Dsq) is subtracted from X^T A^T A X
    %     lambda    l1 weight
    %     L         the initial proximal step size is 1/L
    %     tol, maxiter, Ftol  stopping parameters forwarded to solver
    %
    %   Outputs
    %     xopt      final iterate returned by solver (xopt.main is the solution)
    %     iter, time, fv, fs, neta  iteration count, wall time, final objective,
    %               objective history and ||D||_F history from solver
    %     nD        final stationarity measure (solver's err output)
    %     sparsity  fraction of entries of xopt.main with |x| < 1e-5
    %     avar      adjusted variance: sum of squared diagonal entries of the
    %               triangular factor Rfac in A * xopt.main = Q * Rfac

    % functions for the optimization problem
    fhandle = @(x)f(x, A, Dsq, lambda);
    gfhandle = @(x)gf(x, A, Dsq, lambda);
    fprox = @prox;
    fcalJ = @calJ;

    % functions for the manifold
    fcalA = @calA;
    fcalAstar = @calAstar;

    [xopt, iter, time, fv, nD, fs, neta] = solver(fhandle, gfhandle, fcalA, fcalAstar, fprox, fcalJ, Xinitial, L, lambda, tol, maxiter, Ftol);
    sparsity = sum(abs(xopt.main(:)) < 1e-5) / numel(xopt.main);

    % Adjusted variance. For a square upper-triangular factor this equals
    % trace(Rfac * Rfac). Unlike that product, it is also defined when A*X has
    % fewer rows than columns, where the economy factor is not square.
    % The factor is named Rfac so it does not shadow the retraction function R.
    [~, Rfac] = qr(A * xopt.main, 0);
    avar = sum(diag(Rfac) .^ 2);
end

function output = R(x, eta)
    % R  Retraction onto the set of matrices with unit-norm columns.
    %   Column j of the result is
    %       x_j * cos(||eta_j||) + sin(||eta_j||) * eta_j / ||eta_j||,
    %   or x_j * cos(||eta_j||) + eta_j when ||eta_j|| < eps. It is then
    %   rescaled to unit Euclidean norm. Only the field .main is populated.
    ncols = size(eta, 2);
    Y = zeros(size(eta));
    for j = 1 : ncols
        ej = eta(:, j);
        len = norm(ej);
        if(abs(len) < eps)
            % Avoid dividing by a (near-)zero length.
            step = ej;
        else
            step = ej / len * sin(len);
        end
        yj = x.main(:, j) * cos(len) + step;
        Y(:, j) = yj / norm(yj);
    end
    output.main = Y;
end

function [output, x] = f(x, A, Dsq, lambda)
    % f  Objective  lambda * ||X||_1 + || X' A' A X - diag(Dsq) ||_F^2  at X = x.main.
    %   Also returns x with two cached fields for later use by gf:
    %     x.AX          = A * X
    %     x.XtAtAXmDsq  = X' A' A X - diag(Dsq)
    AX = A * x.main;
    resid = AX' * AX - diag(Dsq);
    x.AX = AX;
    x.XtAtAXmDsq = resid;
    output = lambda * norm(x.main(:), 1) + norm(resid, 'fro')^2;
end

function output = gf(x, A, Dsq, lambda)
    % gf  Riemannian gradient of the smooth term || X' A' A X - diag(Dsq) ||_F^2
    %   at X = x.main, projected onto the tangent space of the unit-norm-column
    %   set. lambda is unused: the l1 term is handled by the proximal step.
    %
    %   x.AX and x.XtAtAXmDsq are normally cached by f(x, ...). They are
    %   recomputed here when absent, so gf no longer errors if called on a
    %   point that f has not evaluated.
    X = x.main;
    if ~isfield(x, 'AX') || ~isfield(x, 'XtAtAXmDsq')
        x.AX = A * X;
        x.XtAtAXmDsq = x.AX' * x.AX - diag(Dsq);
    end
    % Euclidean gradient of ||M||_F^2 with symmetric M = X' A' A X - D^2.
    Egf = 4 * A' * (x.AX * x.XtAtAXmDsq);
    % Tangent projection Egf - X * diag(diag(X' * Egf)). The needed diagonal
    % is formed column-wise (O(np)) instead of building the full p-by-p X' * Egf.
    output = Egf - bsxfun(@times, X, sum(X .* Egf, 1));
end

function output = prox(X, t, mu)
    % prox  Entrywise soft-thresholding with threshold t * mu
    %   (the proximal map of t * mu * ||.||_1).
    thr = t * mu;
    output = max(0, X - thr) + min(0, X + thr);
end

function output = calA(Z, U) % U \in St(p, n)
    % calA  Returns Z' * U + (Z' * U)', the symmetrized product.
    ZtU = Z' * U;
    output = ZtU' + ZtU;
end

function output = calAstar(Lambda, U) % U \in St(p, n)
    % calAstar  Returns U * (Lambda + Lambda'), used with calA as its adjoint.
    symL = Lambda' + Lambda;
    output = U * symL;
end

function output = calJ(y, eta, t, mu)
    % calJ  Applies the soft-threshold Jacobian to eta: keeps entries of eta
    %   where |y| exceeds the threshold t * mu and zeroes the rest.
    active = abs(y) > t * mu;
    output = eta .* active;
end

function [xopt, iter, time, fv, err, fs, neta] = solver(fhandle, gfhandle, fcalA, fcalAstar, fprox, fcalJ, x0, L, mu, tol, maxiter, Ftol)
    % solver  ManPG outer loop with Armijo backtracking and adaptive step size.
    %
    %   Each iteration:
    %     1. computes a direction D with finddir;
    %     2. backtracks along R(x1, alpha * D) until
    %            f(x2) <= f(x1) - delta * alpha * ||D||_F^2;
    %     3. accepts x2.
    %   If the Armijo condition still fails after maxbt shrinks, the subproblem
    %   tolerance is tightened and the iteration is retried.
    %
    %   Inputs
    %     fhandle   @(x) -> [objective value, x with cached fields]
    %     gfhandle  @(x) -> Riemannian gradient; called on outputs of fhandle
    %     fcalA, fcalAstar, fprox, fcalJ  operators forwarded to finddir
    %     x0        struct with field .main (starting point)
    %     L         the initial step size is t = 1/L (also its lower bound)
    %     mu        l1 weight forwarded to finddir
    %     tol       stop when err = (||x_{k+1} - x_k||_F / t)^2 <= tol
    %     maxiter   maximum number of accepted iterations
    %     Ftol      stop when the objective drops to Ftol + 1e-7 or below
    %
    %   Outputs
    %     xopt, fv  last accepted iterate and its objective value
    %               (the evaluated x0 and f(x0) if no step was accepted)
    %     iter      number of accepted iterations
    %     time      wall time of the main loop
    %     err       last stationarity measure (inf if no step was accepted)
    %     fs, neta  objective and ||D||_F histories; entry 1 is the start point
    delta = 0.0001;          % Armijo sufficient-decrease constant
    gamma = 0.5;             % backtracking shrink factor
    maxbt = 3;               % maximum number of step shrinks per iteration
    innertolmin = 1e-20;     % floor for the subproblem tolerance
    err = inf;
    [f1, x1] = fhandle(x0);
    gf1 = gfhandle(x1);
    t = 1 / L;
    t0 = t;
    iter = 0;
    fs = f1;
    neta = 0;
    p = size(x0.main, 2);
    Dinitial = zeros(p, p);  % warm start for the subproblem multipliers
    totalbt = 0;
    innertol = max(1e-13, min(1e-11, 1e-3 * sqrt(tol) * t^2));
    fprintf('iter:%d, f:%e, ngf:%e\n', iter, f1, norm(gf1, 'fro'));
    tic;
    while(err > tol && iter < maxiter && f1 > Ftol + 1e-7)
        [D, Dinitial] = finddir(x1, gf1, t, fcalA, fcalAstar, fprox, fcalJ, mu, Dinitial, innertol);
        nD2 = norm(D, 'fro')^2;
        alpha = 1;
        x2 = R(x1, alpha * D);
        [f2, x2] = fhandle(x2);
        btiter = 0;
        while(f2 > f1 - delta * alpha * nD2 && btiter < maxbt)
            alpha = alpha * gamma;
            x2 = R(x1, alpha * D);
            [f2, x2] = fhandle(x2);
            btiter = btiter + 1;
            totalbt = totalbt + 1;
        end
        % Test the Armijo condition itself rather than btiter == maxbt, so a
        % step that succeeds on the final shrink is not discarded.
        if(f2 > f1 - delta * alpha * nD2)
            % No sufficient decrease: D is probably too inaccurate, so tighten
            % the subproblem tolerance and retry from x1. Once the tolerance is
            % at its floor, a retry cannot change anything. Stop instead of
            % looping forever.
            if(innertol <= innertolmin)
                fprintf('line search failed at the minimal inner tolerance; stopping\n');
                break;
            end
            innertol = max(innertol * 1e-2, innertolmin);
            continue;
        end
        gf2 = gfhandle(x2);
        iter = iter + 1;
        fs(iter + 1) = f2;
        neta(iter + 1) = norm(D, 'fro');
        err = (norm(x2.main - x1.main, 'fro') / t)^2;

        x1 = x2; f1 = f2; gf1 = gf2;
        if(mod(iter, 500) == 0)
            fprintf('iter:%d, f:%e, err:%e, ngf:%e, btiter:%d\n', iter, f1, err, norm(gf1, 'fro'), btiter);
        end
        % Grow the step after an unshrunk step; otherwise shrink it, but never below 1/L.
        if(btiter == 0)
            t = t * 1.01;
        else
            t = max(t0, t / 1.01);
        end
    end
    fprintf('iter:%d, f:%e, err:%e, ngf:%e, totalbt:%d\n', iter, f1, err, norm(gf1, 'fro'), totalbt);
    % x1/f1 always hold the last accepted point, even when no step was taken.
    xopt = x1;
    fv = f1;
    time = toc;
end

% compute E(Lambda)
function ELambda = E(Lambda, BLambda, x, gfx, t, fcalA, fcalAstar, fprox, fcalJ, mmu)
    % E  Residual of the subproblem optimality condition at multiplier Lambda:
    %        E(Lambda) = fcalA(fprox(B, t, mmu) - x, x),
    %    with B = x - t * (gfx - fcalAstar(Lambda, x)).
    %    A precomputed B may be supplied in BLambda. Pass [] to have it formed here.
    %    fcalJ is accepted for a uniform signature but not used.
    if isempty(BLambda)
        BLambda = x - t * (gfx - fcalAstar(Lambda, x));
    end
    stepdir = fprox(BLambda, t, mmu) - x;
    ELambda = fcalA(stepdir, x);
end

% compute calG(Lambda)[d]
function GLambdad = GLd(Lambda, d, BLambda, Blocks, x, gfx, t, fcalA, fcalAstar, fprox, fcalJ, mmu)
    % GLd  Applies the linear map d -> t * fcalA(fcalJ(BLambda, fcalAstar(d, x), t, mmu), x).
    %   Lambda, Blocks, gfx and fprox are accepted for signature compatibility
    %   but are not read.
    lifted = fcalAstar(d, x);
    masked = fcalJ(BLambda, lifted, t, mmu);
    GLambdad = t * fcalA(masked, x);
end

function [output, Lambda, inneriter] = finddir(x, gfx, t, fcalA, fcalAstar, fprox, fcalJ, mmu, x0, innertol)
    % finddir  Computes the search direction one column at a time.
    %   Column j is handled by finddirSt, warm-started from the diagonal
    %   entry x0(j, j). Only the diagonal of the returned Lambda is filled.
    %   inneriter is the mean inner-iteration count over the columns.
    X = x.main;
    ncols = size(X, 2);
    output = zeros(size(X));
    Lambda = zeros(size(x0));
    counts = zeros(ncols, 1);
    for j = 1 : ncols
        [dj, lamj, counts(j)] = finddirSt(X(:, j), gfx(:, j), t, fcalA, fcalAstar, fprox, fcalJ, mmu, x0(j, j), innertol);
        output(:, j) = dj;
        Lambda(j, j) = lamj;
    end
    inneriter = mean(counts);
end

% Use semi-Newton to solve the subproblem and find the search direction
function [output, Lambda, inneriter] = finddirSt(x, gfx, t, fcalA, fcalAstar, fprox, fcalJ, mmu, x0, innertol)
    % finddirSt  Search direction for a single column, found by solving for the
    %   subproblem multiplier with a regularized semismooth Newton iteration.
    %
    %   With B(z) = x - t * (gfx - fcalAstar(z, x)), the iteration seeks z with
    %       E(z) = fcalA(fprox(B(z), t, mmu) - x, x) = 0
    %   and returns the direction fprox(B(z), t, mmu) - x at the final z.
    %   Its structure appears to follow an adaptive regularized semismooth
    %   Newton scheme: a shifted Newton step solved by CG, a non-monotone
    %   acceptance test, a ratio test rho, a projection step and a fixed-point
    %   fallback. NOTE(review): confirm the reference.
    %
    %   Inputs
    %     x, gfx    one column of the iterate and the matching gradient column
    %               (finddir passes x.main(:, i) and gfx(:, i))
    %     t         proximal step size
    %     fcalA, fcalAstar, fprox, fcalJ  operators passed through to E / GLd
    %     mmu       l1 weight used by fprox and fcalJ
    %     x0        initial multiplier (warm start)
    %     innertol  stop once ||E(z)||^2 <= innertol
    %
    %   Outputs
    %     output     direction fprox(B(z), t, mmu) - x at the final z
    %     Lambda     final multiplier z
    %     inneriter  number of Newton iterations performed (capped at 1000)
    lambda = 0.2;              % regularization weight: scales the shift mu and the CG tolerance
    nu = 0.99;                 % non-monotone acceptance factor for the full Newton step
    tau = 0.1;                 % CG tolerance factor
    eta1 = 0.2; eta2 = 0.75;   % thresholds on the ratio rho
    gamma1 = 3; gamma2 = 5;    % growth factors for lambda when rho is small
    alpha = 0.1;
    beta = 1 / alpha / 100;    % fixed-point step size (evaluates to 0.1)
    [n, p] = size(x);
    
    z = x0;
    BLambda = x - t * (gfx - fcalAstar(z, x));
    Fz = E(z, BLambda, x, gfx, t, fcalA, fcalAstar, fprox, fcalJ, mmu);
    
    nFz = norm(Fz, 'fro');
    nnls = 5;
    xi = zeros(nnls, 1); % recent residual norms for the non-monotone acceptance test
    xi(nnls) = nFz;
    maxiter = 1000;
    times = 0;
    Blocks = cell(p, 1); % passed to GLd, which does not read it
    while(nFz * nFz > innertol && times < maxiter) % iterate until ||E(z)||^2 <= innertol
        % Regularization shift for the Newton system, tied to the current residual norm.
        mu = lambda * max(min(nFz, 0.1), 1e-11);
        % Linear operator d -> GLd(...)[d] + mu * d (shifted Newton matrix).
        Axhandle = @(d)GLd(z, d, BLambda, Blocks, x, gfx, t, fcalA, fcalAstar, fprox, fcalJ, mmu) + mu * d;
        % Approximately solve the shifted system for -E(z), at most 30 CG iterations.
        [d, CGiter] = myCG(Axhandle, -Fz, tau, lambda * nFz, 30); % update d
        u = z + d;
        Fu = E(u, [], x, gfx, t, fcalA, fcalAstar, fprox, fcalJ, mmu); 
        nFu = norm(Fu, 'fro');
        
        if(nFu < nu * max(xi))
            % Full Newton step accepted: residual beats nu times the recent maximum.
            z = u;
            Fz = Fu;
            nFz = nFu;
            xi(mod(times, nnls) + 1) = nFz;
            status = 'success';
        else
            % rho = -<E(u), d> / ||d||^2 decides between the safeguard steps.
            rho = - sum(Fu(:) .* d(:)) / norm(d, 'fro')^2;
            if(rho >= eta1)
                % Projection-type step along E(u); keep it only if ||E|| does not grow.
                v = z - sum(sum(Fu .* (z - u))) / nFu^2 * Fu;
                Fv = E(v, [], x, gfx, t, fcalA, fcalAstar, fprox, fcalJ, mmu);
                nFv = norm(Fv, 'fro');
                if(nFv <= nFz)
                    z = v;
                    Fz = Fv;
                    nFz = nFv;
                    status = 'safegard success projection';
                else
                    % Fallback: fixed-point step z - beta * E(z).
                    z = z - beta * Fz;
                    Fz = E(z, [], x, gfx, t, fcalA, fcalAstar, fprox, fcalJ, mmu);
                    nFz = norm(Fz, 'fro');
                    status = 'safegard success fixed-point';
                end
            else
                % Step rejected: z is unchanged; only lambda is increased below.
%                 fprintf('unsuccessful step\n');
                status = 'safegard unsuccess';
            end
            % Adapt lambda: decrease for a large rho, increase otherwise.
            if(rho >= eta2)
                lambda = max(lambda / 4, 1e-5);
            elseif(rho >= eta1)
                lambda = (1 + gamma1) / 2 * lambda;
            else
                lambda = (gamma1 + gamma2) / 2 * lambda;
            end
        end
        % Recompute B(z) for the (possibly updated) multiplier.
        BLambda = x - t * (gfx - fcalAstar(z, x));
        % status is only consumed by the commented-out trace print below.
%         fprintf(['iter:%d, nFz:%f, xi:%f, ' status '\n'], times, nFz, max(xi));
        times = times + 1;
    end
    Lambda = z;
    inneriter = times;
    output = fprox(BLambda, t, mmu) - x;
end

function [output, k] = myCG(Axhandle, b, tau, lambdanFz, maxiter)
    % myCG  Conjugate gradient for Axhandle(sol) = b, starting from sol = 0.
    %   Stops when ||residual||_F <= tau * min(lambdanFz * ||sol||_F, 1)
    %   or after maxiter iterations. Returns the iterate and the count k.
    sol = zeros(size(b));
    res = b;
    pdir = res;
    k = 0;
    rr = res(:)' * res(:);
    while(norm(res, 'fro') > tau * min(lambdanFz * norm(sol, 'fro'), 1) && k < maxiter)
        Apdir = Axhandle(pdir);
        stepsize = rr / (pdir(:)' * Apdir(:));
        sol = sol + stepsize * pdir;
        res = res - stepsize * Apdir;
        rrnew = res(:)' * res(:);
        pdir = res + (rrnew / rr) * pdir;
        rr = rrnew;
        k = k + 1;
    end
    output = sol;
end
