function [X, F, G, T, timecost, nf, ng, nR, nH, nV, nVp] = driver_soft_ICA(C, method, params)
% Driver for the soft ICA (joint diagonalization) problem on the Stiefel manifold:
%   X^* = arg min_{Y in St(p, n)} - \sum_{i=1}^N ||diag(Y^T C_i Y)||_F^2,
% where each C_i = C(:, :, i) is a symmetric positive-semidefinite matrix.
%
% Inputs:
%   C      - array whose third dimension indexes the matrices C_i.
%   method - solver selector: 1 = RTR_SR1, 2 = LRTR_SR1, 3 = RTR_SD, 4 = rtr.
%   params - solver options struct, forwarded unchanged to the solver. This driver reads
%            params.retraction (2 or 4 selects the embedded-space gradient/Hessian, any
%            other value the quotient-space ones) and params.vector_transport (7, 13, 19
%            or 25 selects the intrinsic-representation versions of f, gradient, Hessian).
% Outputs:
%   Whatever the selected solver returns; nH, nV and nVp stay 0 when the chosen solver
%   does not report them.

    % These vector transports work on intrinsic (cell-array) representations.
    intrinsic = any(params.vector_transport == [7, 13, 19, 25]);
    if(intrinsic)
        fns.f = @(x)soft_ICA_f_intr(x, C);
    else
        fns.f = @(x)soft_ICA_f(x, C);
    end

    if(params.retraction == 2 || params.retraction == 4)
        fprintf('gradient of test function: stiefel manifold is considered as embedded space\n')
        if(intrinsic)
            fns.gf = @(x)soft_ICA_gf_embedded_intr(x, C);
            fns.hessianv = @(x, v)hessianv_embedded_intr(x, v, C);
        else
            fns.gf = @(x)soft_ICA_gf_embedded(x, C);
            fns.hessianv = @(x, v)hessianv_embedded(x, v, C);
        end
    else
        fprintf('gradient of quotient function: stiefel manifold is considered as quotient space\n')
        if(intrinsic)
            fns.gf = @(x)soft_ICA_gf_quotient_intr(x, C);
            fns.hessianv = @(x, v)hessianv_quotient_intr(x, v, C);
        else
            fns.gf = @(x)soft_ICA_gf_quotient(x, C);
            fns.hessianv = @(x, v)hessianv_quotient(x, v, C);
        end
    end
    % gf_exists takes (x, C); the handle must only capture variables defined here.
    fns.gf_exists = @(x)gf_exists(x, C);

    nH = 0;
    nV = 0;
    nVp = 0;
    if(method == 1)
        [X, F, G, timecost, nf, ng, nH, T, nV, nR] = RTR_SR1(fns, params);
    elseif(method == 2)
        [X, F, G, timecost, nf, ng, T, nV, nVp, nR] = LRTR_SR1(fns, params);
    elseif(method == 3)
        [X, F, G, timecost, nf, ng, T, nR] = RTR_SD(fns, params);
    elseif(method == 4)
        [X, F, G, timecost, nf, ng, nH, T, nR] = rtr(fns, params);
    else
        % Without this, the outputs would be left unassigned and MATLAB would fail obscurely.
        error('driver_soft_ICA: unknown method %d (expected 1, 2, 3 or 4).', method);
    end
end

function output = soft_ICA_f(Y, C)
% Soft ICA cost at the n-by-p matrix Y:
%   f(Y) = - sum_k ||diag(Y' * C(:, :, k) * Y)||^2.
    output = 0;
    nmat = size(C, 3);
    k = 1;
    while k <= nmat
        d = diag(Y' * C(:, :, k) * Y);
        output = output - norm(d)^2;
        k = k + 1;
    end
end

function output = soft_ICA_f_intr(Y, C)
% Soft ICA cost for a point given as a cell array; only Y{1} (the n-by-p matrix) is used:
%   f = - sum_k ||diag(Y{1}' * C(:, :, k) * Y{1})||^2.
    nmat = size(C, 3);
    X = Y{1};
    output = 0;
    for k = 1 : nmat
        d = diag(X' * C(:, :, k) * X);
        output = output - norm(d)^2;
    end
end

function output = soft_ICA_gf_embedded(Y, C)
% Riemannian gradient of the soft ICA cost with the Stiefel manifold viewed as an
% embedded submanifold. G is the Euclidean gradient -4 * sum_k C_k * Y * Diag(Y' * C_k * Y).
% It is projected onto the tangent space at Y: G - Y * sym(Y' * G), with
% sym(A) = (A + A') / 2.
    [nr, nc] = size(Y);
    G = zeros(nr, nc);
    for k = 1 : size(C, 3)
        CY = C(:, :, k) * Y;
        G = G - 4 * CY * diag(diag(Y' * CY));
    end
    YtG = Y' * G;
    output = G - Y * (YtG + YtG') / 2;
end

function output = soft_ICA_gf_embedded_intr(Y, C)
% Embedded-metric Riemannian gradient in intrinsic coordinates.
% Y is a cell array: Y{1} is the n-by-p point, and Y{2} is an n-by-(n-p) matrix
% (presumably an orthonormal complement of Y{1}; TODO confirm with the caller).
% The output is a column vector of length n*p - p*(p+1)/2:
%   - the first p*(p-1)/2 entries are the strictly-lower-triangular entries (column-major)
%     of -sqrt(2) * skew(Y{1}' * G);
%   - the remaining entries are Y{2}' * G stacked by columns.
% Here G is the Euclidean gradient.
    x1 = Y{1};
    x2 = Y{2};
    [n, p] = size(x1);
    G = zeros(n, p);
    for k = 1 : size(C, 3)
        CX = C(:, :, k) * x1;
        G = G - 4 * CX * diag(diag(x1' * CX));
    end
    W = 0.5 * x1' * G;
    W = - sqrt(2) * (W - W');
    nskew = 0.5 * p * (p - 1);
    K = x2' * G;
    output = zeros(n * p - 0.5 * p * (p + 1), 1);
    output(1 : nskew) = W(tril(true(p), -1));
    output(nskew + 1 : end) = K(:);
end

function output = soft_ICA_gf_quotient(Y, C)
% Riemannian gradient of the soft ICA cost with the Stiefel manifold viewed as a
% quotient space. G is the Euclidean gradient -4 * sum_k C_k * Y * Diag(Y' * C_k * Y),
% and the result is G - Y * G' * Y.
    [nr, nc] = size(Y);
    G = zeros(nr, nc);
    for k = 1 : size(C, 3)
        Ck = C(:, :, k);
        G = G - 4 * Ck * Y * diag(diag(Y' * Ck * Y));
    end
    output = G - Y * G' * Y;
end

function output = gf_exists(~, ~)
% Flag reporting that a gradient is available: always returns 1 and ignores both
% arguments (the point and the data array).
    output = 1;
end

function output = soft_ICA_gf_quotient_intr(Y, C)
% Quotient-space Riemannian gradient in intrinsic coordinates.
% Y is a cell array: Y{1} is the n-by-p point, and Y{2} is an n-by-(n-p) matrix
% (presumably an orthonormal complement of Y{1}; TODO confirm with the caller).
% The output is a column vector of length n*p - p*(p+1)/2:
%   - the first p*(p-1)/2 entries are the strictly-lower-triangular entries (column-major)
%     of M' - M, where M = Y{1}' * G;
%   - the remaining entries are Y{2}' * G stacked by columns.
% Here G is the Euclidean gradient.
    x1 = Y{1};
    x2 = Y{2};
    [n, p] = size(x1);
    G = zeros(n, p);
    for k = 1 : size(C, 3)
        CX = C(:, :, k) * x1;
        G = G - 4 * CX * diag(diag(x1' * CX));
    end

    M = x1' * G;
    W = M' - M;
    K = x2' * G;
    nskew = 0.5 * p * (p - 1);
    output = zeros(n * p - 0.5 * p * (p + 1), 1);
    output(1 : nskew) = W(tril(true(p), -1));
    output(nskew + 1 : end) = K(:);
end

function output = hessianv_embedded(Y, v, C)
% Riemannian Hessian of the soft ICA cost at Y applied to the tangent vector v,
% with the Stiefel manifold viewed as embedded in R^{n x p}:
%   Hess f(Y)[v] = P_Y( D egrad(Y)[v] - v * sym(Y' * egrad(Y)) ),
% where P_Y(Z) = Z - Y * sym(Y' * Z) and sym(A) = (A + A') / 2.
% Assumes each C(:, :, i) is symmetric, as stated in the driver header.
    N = size(C, 3);
    [n, p] = size(Y);
    Q1 = zeros(n, p); % D egrad(Y)[v]: directional derivative of the Euclidean gradient
    Q2 = zeros(n, p); % egrad(Y): Euclidean gradient
    for i = 1 : N
        temp1 = C(:, :, i) * Y;
        temp2 = Y' * temp1;
        % The factor 2 uses diag(v' * C_i * Y) == diag(Y' * C_i * v), valid for symmetric C_i.
        Q1 = Q1 - 4 * C(:, :, i) * (v * diag(diag(temp2)) + 2 * Y * diag(diag(v' * temp1)));
        Q2 = Q2 - 4 * temp1 * diag(diag(temp2));
    end
    temp2 = Y' * Q2;
    Z = Q1 - v * (temp2 + temp2') / 2;
    % Project the whole expression: the curvature term v * sym(Y' * Q2) is generally not
    % tangent at Y, so projecting Q1 alone would leave a normal component in the result.
    temp1 = Y' * Z;
    output = Z - Y * (temp1 + temp1') / 2;
end

function output = hessianv_embedded_intr(Y, v, C)
% Embedded-metric Riemannian Hessian-vector product in intrinsic coordinates.
% Y is a cell array whose first entry Y{1} is the n-by-p point. v holds intrinsic
% coordinates; it is expanded with embedded_full_c, the Hessian is applied as
%   P_Y( D egrad(Y)[v] - v * sym(Y' * egrad(Y)) ),
% and the result is converted back with embedded_intr_c.
% Assumes each C(:, :, i) is symmetric, as stated in the driver header.
    v = embedded_full_c(Y, v);
    Y1 = Y{1};

    N = size(C, 3);
    [n, p] = size(Y1);
    Q1 = zeros(n, p); % D egrad(Y)[v]
    Q2 = zeros(n, p); % egrad(Y)
    for i = 1 : N
        temp1 = C(:, :, i) * Y1;
        temp2 = Y1' * temp1;
        Q1 = Q1 - 4 * C(:, :, i) * (v * diag(diag(temp2)) + 2 * Y1 * diag(diag(v' * temp1)));
        Q2 = Q2 - 4 * temp1 * diag(diag(temp2));
    end
    temp2 = Y1' * Q2;
    Z = Q1 - v * (temp2 + temp2') / 2;
    % Project everything onto the tangent space. embedded_intr_c reads only the strict lower
    % triangle of Y1' * output, which is a faithful coordinate only if output is tangent.
    temp1 = Y1' * Z;
    output = Z - Y1 * (temp1 + temp1') / 2;

    output = embedded_intr_c(Y, output);
end

function output = hessianv_quotient(Y, v, C)
% Hessian-vector product of the soft ICA cost at Y applied to v, for the Stiefel manifold
% viewed as a quotient space (the branch the driver selects when params.retraction is not
% 2 or 4).
% Q1 is the directional derivative of the Euclidean gradient along v, and Q2 is the
% Euclidean gradient. The first two terms of the result apply the same map Q -> Q - Y*Q'*Y
% used by soft_ICA_gf_quotient; the remaining terms are correction terms built from Q2 and v.
% NOTE(review): the derivation of the correction terms (which quotient metric/connection)
% is not visible in this file; confirm against the reference derivation before editing.
% Assumes each C(:, :, i) is symmetric (the factor 2 in Q1 relies on it).
    N = size(C, 3);
    [n, p] = size(Y);
    Q1 = zeros(n, p);
    Q2 = zeros(n, p);
    for i = 1 : N
        temp1 = C(:, :, i) * Y;
        temp2 = Y' * temp1;
        
        Q1 = Q1 - 4 * C(:, :, i) * (v * diag(diag(temp2)) + 2 * Y * diag(diag(v' * temp1)));
        Q2 = Q2 - 4 * temp1 * diag(diag(temp2));
    end
    % Q1: directional derivative of the Euclidean gradient at Y along v (D egrad(Y)[v])
    % Q2: Euclidean gradient at Y (egrad(Y))
    temp1 = Q2' * v;
    temp2 = v * Q2';
    temp3 = v * (Y' * Q2);
    % Last term: 0.5 * (I - Y*Y') * temp3, computed without forming Y*Y'.
    output = Q1 - Y * (Q1)' * Y - Y * (temp1 - temp1') / 2 - (temp2 - temp2') / 2 * Y - 0.5 * (temp3 - Y * (Y' * temp3));
end

function output = hessianv_quotient_intr(Y, v, C)
% Intrinsic-coordinate version of hessianv_quotient. Y is a cell array whose first entry
% Y{1} is the n-by-p point. v holds intrinsic coordinates; it is expanded with
% quotient_full_c, the same quotient-space formula as hessianv_quotient is applied, and the
% result is converted back with quotient_intr_c.
% NOTE(review): the derivation of the correction terms is not visible in this file;
% confirm against the reference derivation before editing.
% Assumes each C(:, :, i) is symmetric (the factor 2 in Q1 relies on it).
    v = quotient_full_c(Y, v);
    Y1 = Y{1};
    
    N = size(C, 3);
    [n, p] = size(Y1);
    Q1 = zeros(n, p);
    Q2 = zeros(n, p);
    for i = 1 : N
        temp1 = C(:, :, i) * Y1;
        temp2 = Y1' * temp1;
        
        Q1 = Q1 - 4 * C(:, :, i) * (v * diag(diag(temp2)) + 2 * Y1 * diag(diag(v' * temp1)));
        Q2 = Q2 - 4 * temp1 * diag(diag(temp2));
    end
    % Q1: directional derivative of the Euclidean gradient at Y1 along v (D egrad(Y1)[v])
    % Q2: Euclidean gradient at Y1 (egrad(Y1))
    temp1 = Q2' * v;
    temp2 = v * Q2';
    temp3 = v * (Y1' * Q2);
    % Last term: 0.5 * (I - Y1*Y1') * temp3, computed without forming Y1*Y1'.
    output = Q1 - Y1 * (Q1)' * Y1 - Y1 * (temp1 - temp1') / 2 - (temp2 - temp2') / 2 * Y1 - 0.5 * (temp3 - Y1 * (Y1' * temp3));
    
    output = quotient_intr_c(Y, output);
end

function output = M_diag(Y, D)
% Scale column k of the matrix Y by D(k, k), i.e. Y * diag(diag(D)) for the first
% size(Y, 2) diagonal entries of D; entries of D beyond that are ignored.
    [rows, cols] = size(Y);
    output = zeros(rows, cols);
    for col = 1 : cols
        output(:, col) = D(col, col) * Y(:, col);
    end
end

function output = build_bases_embedded(x)
% Build an (n*p)-by-(n*p - p*(p+1)/2) matrix whose columns are vectorized tangent
% directions at the n-by-p point x. Column block k (rows (k-1)*n+1 : k*n) corresponds to
% column k of the tangent matrix.
%   - The first p*(p-1)/2 columns (pairs a < b, a outer) put -x(:, b)/sqrt(2) in block a
%     and x(:, a)/sqrt(2) in block b.
%   - The remaining columns place the first n-p columns of null(x') in each block in turn.
    [n, p] = size(x);
    xperp = null(x');      % computed from the unscaled x
    xs = x / sqrt(2);
    output = zeros(n * p, n * p - p * (p + 1) / 2);
    rowsOf = @(k) (k - 1) * n + 1 : k * n;
    col = 0;
    for a = 1 : p - 1
        for b = a + 1 : p
            col = col + 1;
            output(rowsOf(a), col) = - xs(:, b);
            output(rowsOf(b), col) = xs(:, a);
        end
    end
    for k = 1 : p
        output(rowsOf(k), col + 1 : col + n - p) = xperp(1 : n, 1 : n - p);
        col = col + n - p;
    end
end

function output = embedded_intr_c(x, eta)
% Convert an n-by-p tangent matrix eta at the point x{1} into intrinsic coordinates
% (embedded metric).
% x is a cell array: x{1} is n-by-p, and x{2} is n-by-(n-p) (presumably an orthonormal
% complement of x{1}; TODO confirm with the caller).
% Returns a column vector of length n*p - p*(p+1)/2:
%   - the strictly-lower-triangular entries (column-major) of -sqrt(2) * x{1}' * eta;
%   - followed by x{2}' * eta stacked by columns.
% Only the lower triangle of x{1}' * eta is read, so eta is expected to be tangent
% (x{1}' * eta skew-symmetric).
    x1 = x{1};
    x2 = x{2};
    r2 = sqrt(2);
    [n, p] = size(x1);
    omega = - r2 * (x1' * eta); % scaled skew part
    K = x2' * eta;              % normal-complement part
    output = zeros(n * p - 0.5 * p * (p + 1), 1);
    indx = find(tril(ones(size(omega)), -1));
    output(1 : 0.5 * p * (p - 1)) = omega(indx);
    output(0.5 * p * (p - 1) + 1 : end) = reshape(K, [], 1);
end

function output = quotient_intr_c(x, eta)
% Convert an n-by-p tangent matrix eta at x{1} into intrinsic coordinates (quotient
% representation).
% x is a cell array: x{1} is n-by-p, and x{2} is n-by-(n-p).
% Returns a column vector of length n*p - p*(p+1)/2:
%   - the strictly-lower-triangular entries (column-major) of -(x{1}' * eta);
%   - followed by x{2}' * eta stacked by columns.
    [n, p] = size(x{1});
    W = -(x{1}' * eta);
    K = x{2}' * eta;
    nskew = 0.5 * p * (p - 1);
    output = zeros(n * p - 0.5 * p * (p + 1), 1);
    output(1 : nskew) = W(tril(true(size(W)), -1));
    output(nskew + 1 : end) = K(:);
end

function output = embedded_full_c(x, v)
% Map intrinsic coordinates v (embedded metric) back to an n-by-p tangent matrix at x{1}:
%   output = x{1} * Omega + x{2} * K.
%   - Omega = L' - L is skew, where L is strictly lower triangular and holds
%     v(1 : p*(p-1)/2) / sqrt(2) in column-major order.
%   - K is the (n-p)-by-p reshape of the remaining entries of v.
    [n, p] = size(x{1});
    nskew = 0.5 * p * (p - 1);
    L = zeros(p);
    L(tril(true(p), -1)) = v(1 : nskew) / sqrt(2);
    Omega = L' - L;
    K = reshape(v(nskew + 1 : end), (n - p), p);
    output = x{1} * Omega + x{2} * K;
end

function output = quotient_full_c(x, v)
% Map intrinsic coordinates v (quotient representation) back to an n-by-p tangent matrix
% at x{1}: output = x{1} * Omega + x{2} * K.
%   - Omega = L' - L is skew, where L is strictly lower triangular and holds
%     v(1 : p*(p-1)/2) in column-major order.
%   - K is the (n-p)-by-p reshape of the remaining entries of v.
    [n, p] = size(x{1});
    nskew = 0.5 * p * (p - 1);
    L = zeros(p);
    L(tril(true(p), -1)) = v(1 : nskew);
    Omega = L' - L;
    K = reshape(v(nskew + 1 : end), (n - p), p);
    output = x{1} * Omega + x{2} * K;
end
