function fs = sphere(retraction, vector_transport)
% SPHERE returns the function handles needed by Riemannian optimization
% routines on the unit sphere S^{n - 1} = {x in R^n : norm(x) = 1}.
% Points and tangent vectors are n-by-1 columns; operators on T_x M are
% stored as n-by-n matrices.
% NOTE(review): this name shadows MATLAB's built-in plotting function sphere
% when this file is on the path.
%
% Input:
%     retraction : which retraction fs.R uses (a number, or a char such as '1')
%         1, exponential mapping
%         2, R_x(eta) = (x + eta) / norm(x + eta)
%     vector_transport : which vector transport fs.Tranv uses (number or char)
%         1, parallel transport
%         2, isometric vector transport (canonical angles)
%         3, differentiated retraction 2
%         4, vector transport by projection
%     Options 3 and 4 build the target tangent space from (x + d) / norm(x + d),
%     i.e. they assume retraction 2.
% Output:
%     fs.inpro(x, v1, v2) : the inner product g_x(v1, v2) of tangent vectors v1, v2 in T_x M.
%     fs.Hv(x, H, v) : the tangent vector H(v) obtained by applying the operator H to v in T_x M.
%     fs.invBv(x, B, v) : a tangent vector vv in T_x M satisfying B vv = v.
%     fs.phi() : the coefficient of the RBroyden family update formula.
%     fs.proj(x, eta) : P_x(eta), the orthogonal projection of eta onto T_x M.
%     fs.SRTV(x, m, seed) : a cell array of m random tangent vectors in T_x M;
%                           seed is passed to rand('state', seed).
%     fs.R(x, eta) : R_x(eta) for the selected retraction R.
%     fs.Tranv(x1, d, x2, v, TranvParams) : the vector transport of v in T_x1 M to
%                           T_x2 M, where x2 = R_x1(d); the second output TranvParams is [].
%     fs.Tranv_Params : the same handle as fs.Tranv.
%     fs.invTranv(x1, d, x2, v, TranvParams) : the inverse transport of v in T_x2 M
%                           back to T_x1 M, where x2 = R_x1(d).
%     fs.TranH(x1, d, x2, H1) : an operator H2 on T_x2 M such that for any v in T_x2 M,
%                           H2(v) = Tran(H1(Tran^{-1}(v))), where Tran is fs.Tranv(x1, d, x2, .).
%     fs.rank1operator(x, v1, v2) : the operator H on T_x M with H(v) = g_x(v2, v) * v1.
%     fs.operadd(x, H1, H2) : the operator H on T_x M with H(v) = H1(v) + H2(v).
% An unrecognized retraction or vector_transport code raises an error instead of
% returning a struct with fs.R / fs.Tranv missing.

    fprintf('unit sphere\n');
    % The options are documented as '1', '2', ...; a char never equals a numeric
    % code (e.g. '1' == 1 is false), so convert char codes to numbers first.
    if ischar(retraction)
        retraction = str2double(retraction);
    end
    if ischar(vector_transport)
        vector_transport = str2double(vector_transport);
    end

    fs.Hv = @Hv;
    fs.invBv = @invBv;
    fs.rank1operator = @rank1operator;
    fs.operadd = @operadd;
    fs.inpro = @inpro;
    fs.phi = @phi;
    fs.proj = @proj;
    fs.SRTV = @SRTV;
    if(retraction == 1)
        fprintf('exponential mapping\n')
        fs.R = @R_exp;
    elseif(retraction == 2)
        fprintf('retraction: (x + eta) / norm(x + eta)\n')
        fs.R = @R_normalized;
    else
        error('sphere:invalidRetraction', 'retraction must be 1 or 2.');
    end
    if(vector_transport == 1)
        fprintf('parallel vector transport\n')
        fs.Tranv = @Tranv_parallel;
        fs.Tranv_Params = @Tranv_parallel;
        fs.invTranv = @invTranv_parallel;
        fs.TranH = @TranH_parallel;
    elseif(vector_transport == 2)
        fprintf('vector transport : reflection based on L space\n')
        fs.Tranv = @Tranv_iso;
        fs.invTranv = @invTranv_iso;
        fs.Tranv_Params = @Tranv_iso;
        fs.TranH = @TranH_iso;
    elseif(vector_transport == 3)
        fprintf('vector transport : differentiated retraction\n')
        fs.Tranv = @Tranv_diff;
        fs.invTranv = @invTranv_diff;
        fs.Tranv_Params = @Tranv_diff;
        fs.TranH = @TranH_diff;
    elseif(vector_transport == 4)
        fprintf('vector transport : projection\n')
        fs.Tranv = @Tranv_proj;
        fs.invTranv = @invTranv_proj;
        fs.Tranv_Params = @Tranv_proj;
        fs.TranH = @TranH_proj;
    else
        error('sphere:invalidVectorTransport', 'vector_transport must be 1, 2, 3 or 4.');
    end
end

function output = Hv(x, H, v)
% Apply the operator H (stored as a matrix) to the tangent vector v.
% x is not used here (presumably kept for a uniform manifold interface).
    output = mtimes(H, v);
end

function output = invBv(x, B, v)
% Solve B * vv = v for a tangent vector vv in T_x M.
% The tangency constraint x' * vv = 0 is appended as an extra equation, giving
% the rectangular system [B; x'] * vv = [v; 0], which linsolve solves
% (least squares for a non-square coefficient matrix).
    stacked_B = [B; x'];
    stacked_v = [v; 0];
    output = linsolve(stacked_B, stacked_v);
end

function output = phi()
% Coefficient of the RBroyden family update formula; this implementation
% always returns 0 (presumably the BFGS member of the family — confirm with caller).
    output = zeros(1, 1);
end

function output = proj(x, eta)
% Orthogonal projection of eta onto the tangent space T_x M = {v : x' * v = 0}
% (exact for unit-norm x). If eta has several columns, each is projected.
    coeff = x' * eta;
    output = eta - x * coeff;
end

function output = SRTV(x, m, seed)
% Return m random tangent vectors in T_x M as a 1-by-m cell array.
% Each vector is basis_x * u, where basis_x = null(x') is an orthonormal basis
% of T_x M and the coefficients u are drawn as rand - 0.5, i.e. in (-0.5, 0.5).
% If seed is given, the legacy generator is reset with rand('state', seed) so
% results are reproducible; if seed is omitted, the current generator state is used.
% m = 0 returns an empty 1-by-0 cell array.
    if nargin >= 3
        rand('state', seed);
    end
    n = length(x);
    u = rand(n - 1, m) - 0.5;
    basis_x = null(x');
    TV = basis_x * u;
    % Preallocate so m = 0 still assigns output and the loop does not grow it.
    output = cell(1, m);
    for i = 1 : m
        output{i} = TV(:, i);
    end
end

function output = rank1operator(x, v1, v2)
% Matrix of the rank-one operator v -> (v2' * v) * v1, i.e. the outer product
% v1 * v2'. x is not used here.
    output = v1 * ctranspose(v2);
end

function output = operadd(x, H1, H2)
% Sum of two operators on T_x M, represented as matrices. x is not used here.
    output = plus(H1, H2);
end

function output = inpro(x, v1, v2)
% Inner product g_x(v1, v2) = v1' * v2 (the Euclidean metric restricted to
% T_x M). x is not used here.
    output = ctranspose(v1) * v2;
end

%% exponential mapping
function output = R_exp(x, eta) % retraction
% Exponential map on the unit sphere:
%     Exp_x(eta) = cos(norm(eta)) * x + (sin(norm(eta)) / norm(eta)) * eta.
% For eta = 0 the coefficient sin(0) / 0 is NaN, so x is returned directly
% (Exp_x(0) = x) instead of a NaN vector.
    neta = norm(eta);
    if neta == 0
        output = x;
        return;
    end
    output = x * cos(neta) + (sin(neta) / neta) * eta;
end

%% (x + eta) / norm(x + eta)
function output = R_normalized(x, eta) % retraction
% Retraction by normalization: step in the ambient space, then rescale back
% to unit norm, R_x(eta) = (x + eta) / norm(x + eta).
    stepped = x + eta;
    output = stepped / norm(stepped);
end

%% parallel vector transport
function [output, TranvParams] = Tranv_parallel(x, d, x2, v, TranvParams)
% Parallel transport of v from T_x M to T_x2 M:
%     v - 2 * (v' * x2) / norm(x + x2)^2 * (x + x2),
% which for unit x, x2 equals v - (x2' * v) / (1 + x' * x2) * (x + x2).
% Undefined when x2 = -x (norm(x + x2) = 0).
% d and the incoming TranvParams are unused; TranvParams is returned empty.
    midsum = x + x2;
    len = norm(midsum);
    coeff = 2 * v' * x2 / len / len;
    output = v - coeff * midsum;
    TranvParams = [];
end

function [output, TranvParams] = invTranv_parallel(x, d, x2, v, TranvParams)
% Inverse of the parallel transport: maps v in T_x2 M back to T_x M via
%     v - 2 * (v' * x) / norm(x + x2)^2 * (x + x2).
% Undefined when x2 = -x (norm(x + x2) = 0).
% d and the incoming TranvParams are unused; TranvParams is returned empty.
    midsum = x + x2;
    len = norm(midsum);
    coeff = 2 * v' * x / len / len;
    output = v - coeff * midsum;
    TranvParams = [];
end

function output = TranH_parallel(x1, d, x2, H)
% Transport the operator H (matrix) from T_x1 M to T_x2 M by parallel transport.
% With c = 2 / norm(x1 + x2)^2, the transport and its inverse are
%     T = I - c * (x1 + x2) * x2',   T^{-1} = I - c * (x1 + x2) * x1',
% and the result T * H * T^{-1} is evaluated in expanded form:
%     H - c (x1 + x2) x2' H - c H (x1 + x2) x1' + c^2 (x2' H (x1 + x2)) (x1 + x2) x1'.
% d is unused. Undefined when x2 = -x1.
    x12 = x1 + x2;
    nn = norm(x12);
    temp1 = x2' * H;   % row vector x2' * H
    temp2 = H * x12;   % column vector H * (x1 + x2)
    output = H - 2 / nn / nn * x12 * (temp1) - 2 / nn / nn * (temp2 - 2 / nn / nn * (temp1 * x12) * x12) * x1';
end

%% isometric vector transport
function [output, TranvParams] = Tranv_iso(x, d, x2, v, TranvParams)
% Isometric vector transport from T_x M to T_x2 M (canonical-angle construction).
% q is the unit vector in T_x M pointing towards x2, and tilde_q is the unit vector
% in T_x2 M pointing towards x. The component of v along q is sent to -tilde_q,
% and the part of v orthogonal to q is left unchanged.
% When x2 = +/- x, r = 0 and q would be 0/0 (NaN). T_x M and T_x2 M coincide
% in that case, so v is returned unchanged.
% d and the incoming TranvParams are unused; TranvParams is returned empty.
    TranvParams = [];
    r = x2 - x * (x' * x2);
    tilde_r = x - x2 * (x2' * x);
    nr = norm(r);
    ntr = norm(tilde_r);
    if nr == 0 || ntr == 0
        output = v;
        return;
    end
    q = r / nr;
    tilde_q = tilde_r / ntr;
    output = (v + (- tilde_q - q) * (q' * v)); %% the sign of tilde_q is negative, because sign(q' * tilde_q) is negative.
end

function [output, TranvParams] = invTranv_iso(x, d, x2, v, TranvParams)
% Inverse of the isometric (canonical-angle) vector transport: maps v in T_x2 M
% back to T_x M using the same reflection construction with the roles of x and x2
% swapped.
% When x2 = +/- x, r = 0 and q would be 0/0 (NaN). The tangent spaces coincide
% in that case, so v is returned unchanged.
% d and the incoming TranvParams are unused; TranvParams is returned empty.
    TranvParams = [];
    r = x - x2 * (x2' * x);
    tilde_r = x2 - x * (x' * x2);
    nr = norm(r);
    ntr = norm(tilde_r);
    if nr == 0 || ntr == 0
        output = v;
        return;
    end
    q = r / nr;
    tilde_q = tilde_r / ntr;
    output = (v + (- tilde_q - q) * (q' * v)); %% the sign of tilde_q is negative, because sign(q' * tilde_q) is negative.
end

function output = TranH_iso(x1, d, x2, H)
% Transport the operator H (matrix) from T_x1 M to T_x2 M with the isometric
% transport T = I - (tilde_q + q) * q' and its inverse
% T^{-1} = I - (q + tilde_q) * tilde_q', returning T * H * T^{-1} in expanded form.
% When x2 = +/- x1, q and tilde_q would be 0/0 (NaN). T is then the identity
% on the shared tangent space, so H is returned unchanged.
% d is unused.
    r = x2 - x1 * (x1' * x2);
    tilde_r = x1 - x2 * (x2' * x1);
    nr = norm(r);
    ntr = norm(tilde_r);
    if nr == 0 || ntr == 0
        output = H;
        return;
    end
    q = r / nr;
    tilde_q = tilde_r / ntr;
    output = H + (H * (- q - tilde_q)) * tilde_q' - (tilde_q + q) * (q' * H + (q' * (H * (- q - tilde_q))) * tilde_q');
end

%% differentiated retraction
function [output, TranvParams] = Tranv_diff(x, d, x2, v, TranvParams)
% Differentiated retraction of R_x(d) = (x + d) / norm(x + d), applied to v:
%     (I - y * y') * v / norm(x + d),   y = (x + d) / norm(x + d).
% x2 and the incoming TranvParams are unused; TranvParams is returned empty.
    TranvParams = [];
    shifted = x + d;
    len = norm(shifted);
    tangential = v - shifted * (shifted' * v) / (len * len);
    output = tangential / len;
end

function [output, TranvParams] = invTranv_diff(x, d, x2, v, TranvParams)
% Inverse of the differentiated-retraction transport: for v in T_x2 M, returns
% the u in T_x M (x' * u = 0) that Tranv_diff maps to v:
%     u = norm(x + d) * (v - (x + d) * (x' * v) / (x' * (x + d))).
% x2 and the incoming TranvParams are unused; TranvParams is returned empty.
    TranvParams = [];
    shifted = x + d;
    scale = norm(shifted);
    alpha = (x' * v) / (x' * shifted);
    output = scale * (v - shifted * alpha);
end

function output = TranH_diff(x1, d, x2, H)
% Transport the operator H (matrix) with the differentiated retraction of
% (x1 + d) / norm(x1 + d): returns T * H * T^{-1}. The norm(x1 + d) factors of
% T and T^{-1} cancel, leaving
%     (I - y * y') * H * (I - (x1 + d) * x1' / (x1' * (x1 + d))),  y = (x1 + d) / norm(x1 + d).
% x2 is unused.
    shifted = x1 + d;
    len = norm(shifted);
    right_applied = H - (H * shifted) * (x1' / (x1' * shifted));
    output = right_applied - shifted * ((shifted' * right_applied) / (len * len));
end

%% vector transport by projection
function [output, TranvParams] = Tranv_proj(x, d, x2, v, TranvParams)
% Vector transport by projection: orthogonally project v onto the tangent space
% at (x + d) / norm(x + d). This equals T_x2 M only when x2 comes from the
% normalization retraction. x2 and the incoming TranvParams are unused;
% TranvParams is returned empty.
    TranvParams = [];
    shifted = x + d;
    len = norm(shifted);
    output = v - shifted * (shifted' * v) / (len * len);
end

function [output, TranvParams] = invTranv_proj(x, d, x2, v, TranvParams)
% Inverse of the projection transport: for v tangent at (x + d) / norm(x + d),
% return the u in T_x M (x' * u = 0) whose projection is v, namely
%     u = v - (x + d) * (x' * v) / (x' * (x + d)).
% x2 and the incoming TranvParams are unused; TranvParams is returned empty.
% (The unused local norm(x + d) computed by the previous version was removed.)
    output = (v - (x + d) * ((x' * v) / (x' * (x + d))));
    TranvParams = [];
end

function output = TranH_proj(x1, d, x2, H)
% Transport the operator H (matrix) with the projection transport: returns
%     (I - y * y') * H * (I - (x1 + d) * x1' / (x1' * (x1 + d))),  y = (x1 + d) / norm(x1 + d),
% i.e. T * H * T^{-1} for T the projection onto the tangent space at y.
% x2 is unused.
    z = x1 + d;
    zz = norm(z);
    Hright = H - (H * z) * (x1' / (x1' * z));
    output = Hright - z * ((z' * Hright) / (zz * zz));
end
