function varargout = LRTR_SR1(fns,params)
% LRTR_SR1  Limited-memory trust-region method on a manifold using an SR1
%           approximation of the Hessian.
%
% Each outer iteration solves the trust-region subproblem with truncated CG
% (tCG), retracts the resulting step, updates the stored (s, y) pairs that
% define the Hessian approximation, adapts the radius Delta and accepts or
% rejects the candidate point.
%
% fns
% required:
%     fns.f(x) : return objective function value at x.
%     fns.gf(x) : return the gradient of the objective function at x.
% not required if params.manifold, params.retraction and params.vector_transport are provided
% (they are then taken from choose_manifold):
%     fns.inpro(x, v1, v2) : return the inner product g_x(v1, v2) of two tangent vectors v1 and v2 on T_x M.
%     fns.proj(x, eta) : return P_x(eta) by projecting eta to the tangent space at x.
%     fns.R(x, eta) : return R_x(eta), where R is a retraction, x is an element on the manifold and eta is a tangent vector of T_x M.
%     fns.invTranv(x1, d, x2, v) : return a tangent vector on T_x1 M, which is given by inverse vector transport that transports v \in T_x2 M to T_x1 M. Here x2 = R_x1(d).
%                                  A second output returns parameters that are passed on to fns.Tranv_Params.
%     fns.Tranv_Params(x1, d, x2, v, TranvParams) : return the transport of v \in T_x1 M to T_x2 M, x2 = R_x1(d);
%                                                   TranvParams are the parameters returned by fns.invTranv, reused to reduce the cost.
%     fns.rank1operator, fns.operadd : must be valid function handles (they are checked), but are not called in this file.
% only checked when params.StopCriterion == 4:
%     fns.Tranv(x1, d, x2, v) : return a tangent vector on T_x2 M, which is given by vector transport that transports v \in T_x1 M to T_{R_x1(d)} M. Here x2 = R_x1(d).
%
% params
% required:
%     params.x0 : initial approximation of minimizer.
%     params.B0 : initial approximation of Hessian (must be present; not read in this file).
% not required (default in brackets):
%     params.error [1e-5]       : tolerance of stopping criterion.
%     params.StopCriterion [2]  : stopping criterion, 1 means stop when relative error of objective function is less than tolerance,
%                                 2 means stop when norm of gradient is less than tolerance,
%                                 3 means stop when norm of gradient over initial norm of gradient is less than tolerance.
%     params.m [4]              : number of s and y pairs kept in storage.
%     params.max_t [300]        : the maximum number of iterations.
%     params.debug [1]          : '0' means almost silent, '1' means some output, '2' means a lot of output, '3' means more than you need to know.
%     params.c [0.1]            : acceptance or rejection constant.
%     params.tau1 [0.25]        : shrink radius constant.
%     params.tau2 [2]           : expansion radius constant.
%     params.max_Delta [200]    : max of Delta; the iteration stops when Delta exceeds it.
%     params.min_Delta [1e-4]   : min of Delta; the iteration stops when Delta falls below it.
%     params.theta [0.1]        : theta of the tCG residual stopping rule.
%     params.kappa [0.9]        : kappa of the tCG residual stopping rule.
%     params.min_innit [0]      : minimum number of tCG iterations before the residual rule may stop it.
%     params.max_innit [200]    : maximum number of tCG iterations.
%     params.err_x [1e-5], params.num_Grads [10] : defaults are set, but the values are not used in this file.
% The following give some examples of manifold. For each manifold, some retractions and vector transports are provided.
%     params.manifold : the manifold which objective function is on.
%         '1', sphere : S^{n - 1}
%         '2', Stiefel manifold : St(p, n)
%         '3', orthogonal group : O(n)
%         '4', grassmann manifold : Gr(p, n)
%
% outputs (varargout, in this order)
%     X        : cell array of iterates, X{1} = params.x0.
%     F        : objective value at each iterate.
%     G        : gradient norm at each iterate.
%     timecost : total elapsed time in seconds.
%     nf, ng   : number of function and gradient evaluations.
%     T        : elapsed time recorded after each iteration.
%     nV, nVp  : number of calls to fns.invTranv and to fns.Tranv_Params.
%     nR       : number of retractions.

    fprintf('LRTR_SR1\n')
    if nargin < 2,
       error('Invalid arguments: the number of arguments is invalid.');
    end
    % set default parameters and check the legal range of parameters
    params = set_default_para(params, 'StopCriterion', 2, 'int', 1, 3);
    params = set_default_para(params, 'm', 4, 'int', 0, inf);
    params = set_default_para(params, 'error', 1e-5, 'float', 0, 1);
    params = set_default_para(params, 'err_x', 1e-5, 'float', 0, 1);
    params = set_default_para(params, 'num_Grads', 10, 'int', 0, inf);
    params = set_default_para(params, 'max_t', 300, 'int', 1, inf);
    params = set_default_para(params, 'debug', 1, 'int', 0, 3);
    params = set_default_para(params, 'c', 0.1, 'float', 0, 0.5);
    params = set_default_para(params, 'tau1', 0.25, 'float', 0.1, 0.9);
    params = set_default_para(params, 'tau2', 2, 'float', 1.5, 5);
    params = set_default_para(params, 'max_Delta', 200, 'float', 1, inf);
    params = set_default_para(params, 'min_Delta', 1e-4, 'float', 0, 1);
    params = set_default_para(params, 'theta', 0.1, 'float', 0, 1);
    params = set_default_para(params, 'kappa', 0.9, 'float', 0, 1);
    params = set_default_para(params, 'min_innit', 0, 'int', 0, inf);
    params = set_default_para(params, 'max_innit', 200, 'int', 0, inf);
    % A named manifold overrides the geometric handles supplied in fns.
    if(isfield(params, 'manifold') && isfield(params, 'retraction') && isfield(params, 'vector_transport'))
        fs = choose_manifold(params.manifold, params.retraction, params.vector_transport);
        fns.R = fs.R;
        fns.proj = fs.proj;
        fns.inpro = fs.inpro;
        fns.Tranv_Params = fs.Tranv_Params;
        fns.invTranv = fs.invTranv;
        fns.rank1operator = fs.rank1operator;
        fns.operadd = fs.operadd;
        if(params.StopCriterion == 4)
            fns.Tranv = fs.Tranv;
        end
    end
    % Check whether function handles are legal
    check_function_handle(fns, 'f');
    check_function_handle(fns, 'gf');
    check_function_handle(fns, 'R');
    check_function_handle(fns, 'proj');
    check_function_handle(fns, 'inpro');
    check_function_handle(fns, 'Tranv_Params');
    check_function_handle(fns, 'invTranv');
    check_function_handle(fns, 'rank1operator');
    check_function_handle(fns, 'operadd');
    if(params.StopCriterion == 4)
        check_function_handle(fns, 'Tranv');
    end
    % Check whether required parameters are legal
    if(~isfield(params, 'x0') || ~isfield(params, 'B0'))
        error('Invalid arguments: missing initial x0 or B0');
    end
    % tCG is run without a preconditioner and with the manifold metric.
    fns.haveprec = 0;
    fns.g = fns.inpro;
    % Initialization
    c = params.c;
    tau1 = params.tau1;
    tau2 = params.tau2;
    m = params.m;
    err = inf;
    times = 0;
    Delta = 1;
    x = params.x0;
    err_g = 1;
    err_r = 1;
    gradf = fns.gf(x);
    fx = fns.f(x);
    X{1} = x;
    F(1) = fx;
    G(1) = sqrt(fns.inpro(x, gradf, gradf));
    gamma = 1;          % scaling of the initial Hessian approximation gamma * id
    start_index = 1;    % storage slot of the oldest (s, y) pair
    S = [];
    Y = [];
    SS = [];
    SY = [];
    nf = 1;
    ng = 1;
    nV = 0;
    nVp = 0;
    nR = 0;
    tic
    T(1) = 0;
    while (err > params.error && times < params.max_t)
        [eta,inner_it,stop_tCG] = tCG(fns, x, gradf, S, Y, SS, SY, start_index, gamma, 0, Delta, params.theta, params.kappa, params.min_innit, params.max_innit, 0, 0, times);
        xtemp = fns.R(x, eta);
        fxtemp = fns.f(xtemp);
        gradftemp = fns.gf(xtemp);
        % rho = actual reduction / reduction predicted by the quadratic model
        Hvtemp = Hv(fns, x, S, Y, SS, SY, eta, gamma, start_index);
        rho = (fx - fxtemp) / (- fns.inpro(x, gradf +  0.5 * Hvtemp, eta));%% - fns.inpro(x, eta,)
        %update B
        s = eta;
        [y, TranvParams] = fns.invTranv(x, eta, xtemp, gradftemp);
        nV = nV + 1;
        y = y - gradf;
        Temp = y - Hvtemp;
        denorm = fns.inpro(x, s, Temp);
        % Skip the SR1 update when <s, y - B s> is negligible relative to |s| |y - B s|.
        if(abs(denorm) > sqrt(eps) * sqrt(fns.inpro(x, s, s) * fns.inpro(x, Temp, Temp)))
            gamma = fns.inpro(x, y, y) / fns.inpro(x, s, y);
            if(length(S) < m)
                % Append at the next free slot. Indexing by the iteration
                % counter would leave empty cells in S and Y (and a zero row
                % in SS and SY) whenever an earlier update was skipped.
                k = length(S) + 1;
                S{k} = s;
                Y{k} = y;
            else
                % Storage is full: overwrite the oldest pair and advance the
                % circular start index.
                S{start_index} = s;
                Y{start_index} = y;
                start_index = start_index + 1;
                if(start_index > m)
                    start_index = 1;
                end
            end
            [SS, SY] = updateSS_SY(fns, x, S, Y, SS, SY, start_index);
        else
            fprintf('warnning! denormator is close to zero.\n');
        end
        % Adapt the trust-region radius.
        if(rho > 0.75)
            % Only expand when tCG stopped on negative curvature or at the boundary.
            if(stop_tCG == 2 || stop_tCG == 1)
                Delta = tau2 * Delta;
            end
            if(Delta > params.max_Delta)
                fprintf('reach max of Delta\n');
                break;
            end
        elseif(rho < 0.1)
            Delta = tau1 * Delta;
            if(Delta < params.min_Delta)
                fprintf('reach min of Delta\n');
                break;
            end
        end
        % Accept the candidate on sufficient model agreement, or on a tiny
        % but strict decrease of the objective.
        if(rho > c || (abs((fx - fxtemp)/(abs(fx) + 1)) < sqrt(eps) && fxtemp < fx))
            % Transport every stored pair to the tangent space at the new
            % point. The loop runs over the pairs that actually exist, so it
            % is also correct when updates have been skipped.
            for i = 1 : length(S)
                S{i} = fns.Tranv_Params(x, eta, xtemp, S{i}, TranvParams);
                Y{i} = fns.Tranv_Params(x, eta, xtemp, Y{i}, TranvParams);
            end
            nVp = nVp + 2 * length(S);
            err_r = abs((fx - fxtemp) / fxtemp);
            x = xtemp;
            gradf = gradftemp;
            fx = fxtemp;
            err_g = sqrt(fns.inpro(x, gradf, gradf));
        else
            if(params.debug > 0)
                fprintf('without updating x_k\n');
            end
        end
        times = times + 1;
        nf = nf + 1;
        ng = ng + 1;
        nR = nR + 1;
        X{times + 1} = x;
        F(times + 1) = fx;
        G(times + 1) = err_g;
        T(times + 1) = toc;
        if(params.StopCriterion == 1)
            err = err_r;
        elseif(params.StopCriterion == 2)
            err = err_g;
        elseif(params.StopCriterion == 3)
            err = err_g / G(1);
        end
        if(params.debug > 0)
            fprintf('times : %d, f(x) : %e, relative error : %e, norm gradf : %e \n', times, fx, err_r, err_g);
            fprintf('\t\t\t rho : %e, delta : %e, stop_tCG : %d, inner_it : %d\n', rho, Delta, stop_tCG, inner_it);
        end
    end
    timecost = toc;
    fprintf('Num. of Iter.: %d, time cost: %f seconds, Num. of Fun, Gra, VT and R: %d, %d, (%d+%d), %d \n', times, timecost, nf, ng, nV, nVp, nR)
    fprintf('f: %e, |gradf|: %e, |gradf/gradf0|: %e \n', F(end), G(end), G(end)/G(1))
    if(params.debug > 1)
        figure(1);clf
        scatter(1 : length(F), F, '.');
        ylabel('f(x_i)');
        xlabel('iteration number');
        figure(2);clf
        scatter(1 : length(G), log(G), '.');
        ylabel('log(|grad(x_i)|)');
        xlabel('iteration number');
    end
    varargout{1} = X;
    varargout{2} = F;
    varargout{3} = G;
    varargout{4} = timecost;
    varargout{5} = nf;
    varargout{6} = ng;
    varargout{7} = T;
    varargout{8} = nV;
    varargout{9} = nVp;
    varargout{10} = nR;
end

function [SSnew, SYnew] = updateSS_SY(fns, x, S, Y, SS, SY, start_index)
% updateSS_SY  Refresh the Gram matrices of the stored pairs after a new
%              (s, y) pair has been placed in S and Y.
%
% SSnew(i, j) holds fns.inpro(x, s_i, s_j) and SYnew(i, j) holds
% fns.inpro(x, s_i, y_j), with the newest pair in the last row and column.
% Only the entries involving the newest pair are computed; the remaining
% entries are copied from SS and SY. Both matrices are filled symmetrically.
%
%   S, Y        : cell arrays of stored tangent vectors (newest already inserted).
%   SS, SY      : matrices from the previous call ([] on the first call).
%   start_index : storage slot of the oldest pair; only used when the storage
%                 was already full, i.e. size(SS, 1) >= length(S).
    nPairs = length(S);
    nOld = size(SS, 1);
    SSnew = zeros(nPairs, nPairs);
    SYnew = zeros(nPairs, nPairs);

    if(nOld >= nPairs)
        % Storage was already full: drop the oldest pair (first row and
        % column) and shift the remaining entries up and to the left.
        keep = 2 : nPairs;
        SSnew(1 : nPairs - 1, 1 : nPairs - 1) = SS(keep, keep);
        SYnew(1 : nPairs - 1, 1 : nPairs - 1) = SY(keep, keep);
        % The newest pair sits in the slot just before start_index (circular).
        newest = start_index - 1;
        if(newest < 1)
            newest = newest + nPairs;
        end
        % Storage slots of the surviving pairs, ordered oldest first.
        nOlder = nPairs - 1;
        slot = (1 : nOlder) - 1 + start_index;
        wrapped = slot > nPairs;
        slot(wrapped) = slot(wrapped) - nPairs;
    else
        % Storage is still growing: the previous matrices are kept as the
        % leading block and the newest pair is the last stored one.
        SSnew(1 : nOld, 1 : nOld) = SS;
        SYnew(1 : nOld, 1 : nOld) = SY;
        newest = nPairs;
        nOlder = nOld;
        slot = 1 : nOld;
    end

    % Diagonal entries of the newest pair.
    sNew = S{newest};
    SSnew(nPairs, nPairs) = fns.inpro(x, sNew, sNew);
    SYnew(nPairs, nPairs) = fns.inpro(x, sNew, Y{newest});

    % Last row and column: products of the newest s with the older pairs.
    for k = 1 : nOlder
        sy = fns.inpro(x, sNew, Y{slot(k)});
        SYnew(nPairs, k) = sy;
        SYnew(k, nPairs) = sy;
        ss = fns.inpro(x, sNew, S{slot(k)});
        SSnew(nPairs, k) = ss;
        SSnew(k, nPairs) = ss;
    end
end

function output = Hv(fns, x, S, Y, SS, SY, v, gamma, start_index)
% Hv  Apply the limited-memory SR1 Hessian approximation to a tangent vector.
%
% Computes
%     output = gamma * v + sum_i w(i) * (y_i - gamma * s_i),
% where w solves (SY - gamma * SS) * w = [ <y_i - gamma * s_i, v> ]_i and
% <.,.> is fns.inpro at x.
%
%   S, Y        : cell arrays of the stored pairs (circular storage).
%   SS, SY      : matrices as produced by updateSS_SY (ordered oldest first).
%   v           : tangent vector the operator is applied to.
%   gamma       : scaling of the initial approximation gamma * id.
%   start_index : storage slot of the oldest pair; the pairs are visited in
%                 the order start_index, start_index + 1, ... (wrapping) so
%                 that they line up with the rows of SS and SY.
    m = length(S);
    output = gamma * v;
    if(m == 0)
        % No pair stored yet: the approximation is the scaled identity.
        % Returning here avoids solving an empty linear system.
        return;
    end
    % y_i - gamma * s_i in the same order as the rows of SS and SY
    YGS = cell(1, m);
    for i = 1 : m
        ind = i - 1 + start_index;
        if(ind > m)
            ind = ind - m;
        end
        YGS{i} = Y{ind} - gamma * S{ind};
    end
    M = SY - gamma * SS;
    V = zeros(m, 1);
    for i = 1 : m
        V(i) = fns.inpro(x, YGS{i}, v);
    end
    V = linsolve(M, V);
    for i = 1 : m
        output = output + V(i) * YGS{i};
    end
end

%         [eta,inner_it,stop_tCG] = tCG(fns, x, gradf, S, Y, SS, SY, start_index, gamma, 0, Delta, 1.0, 0.1, 0, 200, 0, 0, times);
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%% truncated CG %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [eta,inner_it,stop_tCG] = tCG(fns,x,grad,S,Y,SS,SY,start_index,gamma,eta,Delta,theta,kappa,min_inner,max_inner,useRand,debug, times)
% tCG - Truncated (Steihaug-Toint) Conjugate-Gradient
% minimize <eta,grad> + .5*<eta,Hess(eta)>
% subject to <eta,eta> <= Delta^2
%
% The Hessian is applied through Hv(fns, x, S, Y, SS, SY, ., gamma, start_index).
%
% inputs
%   fns         : struct with fns.g (inner product), fns.proj, fns.haveprec,
%                 fns.prec (only if fns.haveprec ~= 0) and fns.fhess (only if useRand).
%   x, grad     : current point and gradient at x.
%   S, Y, SS, SY, start_index, gamma : passed unchanged to Hv.
%   eta         : initial iterate; treated as zero unless useRand is set.
%   Delta       : trust-region radius.
%   theta,kappa : the loop stops when |r| <= |r0| * min(|r0|^theta, kappa).
%   min_inner   : the residual test is only applied from this iteration on.
%   max_inner   : maximum number of inner iterations.
%   useRand     : nonzero if a nonzero initial eta is supplied.
%   debug       : values > 1 and > 2 enable additional diagnostic output.
%   times       : not used.
% outputs
%   eta         : approximate solution of the subproblem.
%   inner_it    : number of inner iterations started (0 if max_inner < 1).
%   stop_tCG    : 1 negative curvature, 2 exceeded trust region,
%                 3 linear convergence, 4 superlinear convergence,
%                 5 maximum number of inner iterations reached.

   % all terms involving the trust-region radius will utilize an inner product
   % w.r.t. the preconditioner; this is because the iterates grow in
   % length w.r.t. the preconditioner, guaranteeing that we will not 
   % re-enter the trust-region
   % 
   % the following recurrences for Prec-based norms and inner 
   % products come from CGT2000, pg. 205, first edition
   % below, P is the preconditioner
   % 
   % <eta_k,P*delta_k> = beta_k-1 * ( <eta_k-1,P*delta_k-1> + alpha_k-1 |delta_k-1|^2_P )
   % |delta_k|^2_P = <r_k,z_k> + beta_k-1^2 |delta_k-1|^2_P
   % 
   % therefore, we need to keep track of 
   % 1)   |delta_k|^2_P 
   % 2)   <eta_k,P*delta_k> = <eta_k,delta_k>_P
   % 3)   |eta_k  |^2_P
   % 
   % initial values are given by:
   %    |delta_0|_P = <r,z>
   %    |eta_0|_P   = 0
   %    <eta_0,delta_0>_P = 0
   % because we take eta_0 = 0

   if useRand, % and therefore, fns.haveprec == 0
      % eta (presumably) ~= 0 was provided by the caller   
      r = grad+fns.fhess(x,eta);
      e_Pe = fns.g(x,eta,eta);
   else % and therefore, eta == 0
      % eta = 0*grad;
      r = grad;
      e_Pe = 0;
   end
   r_r = fns.g(x,r,r);
   norm_r = sqrt(r_r);
   norm_r0 = norm_r;

   % precondition the residual
   if fns.haveprec == 0,
      z = r;
   else
      z = fns.prec(x,r);
   end
   % compute z'*r
   z_r = fns.g(x,z,r);
   d_Pd = z_r;

   % initial search direction
   delta  = -z;
   if useRand, % and therefore, fns.haveprec == 0
      e_Pd = fns.g(x,eta,delta);
   else % and therefore, eta == 0
      e_Pd = 0;
   end

   % pre-assume termination b/c j == end
   stop_tCG = 5;

   % The iteration count is tracked explicitly: a for loop over an empty
   % range (max_inner < 1) leaves its loop variable empty, which would
   % otherwise be returned as inner_it.
   inner_it = 0;

   % begin inner/tCG loop
   for j = 1:max_inner,
      inner_it = j;

      % apply the SR1 Hessian approximation to the search direction
      Hxd = Hv(fns, x, S, Y, SS, SY, delta, gamma, start_index);

      % compute curvature
      d_Hd = fns.g(x,delta,Hxd);

      % DEBUGGING: check that <d,Hd> = <Hd,d>
      if debug > 1,
         Hd_d = fns.g(x,Hxd,delta);
         fprintf('DBG: |d_Hd - Hd_d| (abs/rel): %e/%e\n',abs(d_Hd-Hd_d),abs((d_Hd-Hd_d)/d_Hd));
      end

      alpha = z_r/d_Hd;
      % <neweta,neweta>_P = <eta,eta>_P + 2*alpha*<eta,delta>_P + alpha*alpha*<delta,delta>_P
      e_Pe_new = e_Pe + 2.0*alpha*e_Pd + alpha*alpha*d_Pd;

      if debug > 2,
         fprintf('DBG:   (r,r)  : %e\n',r_r);
         fprintf('DBG:   (d,Hd) : %e\n',d_Hd);
         fprintf('DBG:   alpha  : %e\n',alpha);
      end

      % check curvature and trust-region radius
      if d_Hd <= 0 || e_Pe_new >= Delta^2,
         % want
         %  ee = <eta,eta>_prec,x
         %  ed = <eta,delta>_prec,x
         %  dd = <delta,delta>_prec,x
         % tau is the positive root that puts eta + tau*delta on the boundary
         tau = (-e_Pd + sqrt(e_Pd*e_Pd + d_Pd*(Delta^2-e_Pe))) / d_Pd;
         if debug > 2,
            fprintf('DBG:     tau  : %e\n',tau);
         end
         eta = eta + tau*delta;
         if d_Hd <= 0,
            stop_tCG = 1;     % negative curvature
         else
            stop_tCG = 2;     % exceeded trust region
         end
         break;
      end

      % no negative curvature and eta_prop inside TR: accept it
      e_Pe = e_Pe_new;
      eta = eta + alpha*delta;

      % update the residual
      r = r + alpha*Hxd;
      % re-tangentalize r
      r = fns.proj(x,r);

      % compute new norm of r
      r_r = fns.g(x,r,r);
      norm_r = sqrt(r_r);

      % check kappa/theta stopping criterion
      if j >= min_inner && norm_r <= norm_r0*min(norm_r0^theta,kappa)
         % residual is small enough to quit
         if kappa < norm_r0^theta,
             stop_tCG = 3;  % linear convergence
         else
             stop_tCG = 4;  % superlinear convergence
         end
         break;
      end

      % precondition the residual
      if fns.haveprec == 0,
         z = r;
      else
         z = fns.prec(x,r);
      end

      % save the old z'*r
      zold_rold = z_r;
      % compute new z'*r
      z_r = fns.g(x,z,r);

      % compute new search direction
      beta = z_r/zold_rold;
      delta = -z + beta*delta;
      % update new P-norms and P-dots
      e_Pd = beta*(e_Pd + alpha*d_Pd);
      d_Pd = z_r + beta*beta*d_Pd;

   end  % of tCG loop
end
