/*
This file is the driver for problems defined by Matlab functions.
The Matlab driver "DriverOPT.m" creates function handles for the cost function, together with
the parameters for both the solvers and the manifolds, which are passed to the
binary file generated by this C++ driver.

---- WH
*/

#ifndef DRIVERMEXPROB_H
#define DRIVERMEXPROB_H


#include <iostream>
#include "Others/randgen.h"
#include "Manifolds/Manifold.h"
#include "Problems/Problem.h"
//#include "Problems/SphereTxRQ/SphereTxRQ.h"
#include "Solvers/SolversSMLS.h"
#include <ctime>

#include "Manifolds/Euclidean.h"
//#include "Manifolds/Euclidean/EucVariable.h"

//#include "Manifolds/Stiefel/StieVector.h"
//#include "Manifolds/Stiefel/StieVariable.h"
//#include "Manifolds/Stiefel/Stiefel.h"

/*Linesearch based solvers*/
#include "Solvers/RSD.h"
#include "Solvers/RNewton.h"
#include "Solvers/RCG.h"
#include "Solvers/RBroydenFamily.h"
#include "Solvers/RWRBFGS.h"
#include "Solvers/RBFGS.h"
#include "Solvers/LRBFGS.h"
#include "Solvers/RGS.h"
#include "Solvers/LRBFGSSub.h"
#include "Solvers/RBFGSSub.h"
#include "Solvers/IRPG.h"
#include "Solvers/IARPG.h"

/*Trust-region based solvers*/
#include "Solvers/SolversSMTR.h"
#include "Solvers/RTRSD.h"
#include "Solvers/RTRNewton.h"
#include "Solvers/RTRSR1.h"
#include "Solvers/LRTRSR1.h"
#include "Solvers/LRTRSR1woR.h"

#include "Others/def.h"

#include "Manifolds/CFixedRankQ2F.h"
#include "Manifolds/CStiefel.h"
#include "Manifolds/CSymFixedRankQ.h"
#include "Manifolds/Euclidean.h"
#include "Manifolds/FixedRankE.h"
#include "Manifolds/FixedRankQ2F.h"
#include "Manifolds/Grassmann.h"
#include "Manifolds/SPDManifold.h"
#include "Manifolds/Sphere.h"
#include "Manifolds/Stiefel.h"
#include "Manifolds/SymFixedRankQ.h"

#ifdef MATLAB_MEX_FILE

#include "Problems/mexProblem.h"

using namespace ROPTLIB;

/*This function checks the number and formats of the input parameters.
nlhs: the number of outputs in mxArray format
plhs: the output objects in mxArray format
nrhs: the number of inputs in mxArray format
prhs: the input objects in mxArray format */
void DriverMexProb(int &nlhs, mxArray ** &plhs, int &nrhs, const mxArray ** &prhs);

/*This function creates a C++ solver based on the input parameters and runs it to obtain a solution.
SolverParams: MATLAB struct. Its "method" field selects the solver, and every field is also passed to
	the solver's SetParams. The optional fields "IsStopped" and "LinesearchInput" are callbacks
	evaluated with feval. Nonzero "IsCheckParams" / "IsCheckGradHess" enable extra checks.
Prob, initialX: the problem and the initial iterate.
plhs: receives the results in plhs[0]..plhs[15].
LSInput: C++ line-search function. It is installed only on a line-search solver whose
	LineSearch_LS equals LSSM_INPUTFUN.
initialLOPE: initial linear operator. It is passed to RBroydenFamily, RWRBFGS, RBFGS, RBFGSSub and RTRSR1.
If "method" is not recognized, a warning is printed and no solver is run.*/
void ParseSolverParamsAndOptimizing(const mxArray *SolverParams, const Problem *Prob, Variable *initialX, mxArray ** &plhs,
	realdp(*LSInput)(integer iter, const Variable &x1, const Vector &eta1, realdp initialstepsize, realdp initialslope, const Problem *prob, const Solvers *solver) = nullptr, LinearOPE *initialLOPE = nullptr);

/*This function creates all components of the product manifold; the cost function is always
defined on a product manifold. One manifold is created per element of the struct array ManiParams.
manifolds and powsinterval are allocated with new[] here and are not freed by this function.
Returns false if a manifold name is not recognized.*/
bool ParseManiParams(const mxArray *ManiParams, Manifold **&manifolds, integer &numoftype, integer *&powsinterval);

/*Create a manifold from its name and sizes. Returns nullptr (after printing a message) for an unknown name.*/
Manifold *GetAManifold(const char *name, integer n, integer m, integer p = 1);

///*Create an element based on the parameters. Note that the element is a component of the initial iterate.*/
//Element *GetAnElement(const char *name, integer n, integer m, integer p = 1);

namespace RMEX{
	mxArray *isstopped = nullptr;
	/*This function defines the stopping criterion that may be used in the C++ solver*/
	bool mexInnerStop(const Variable &x, const Vector &funSeries, integer lengthSeries, realdp ngf, realdp ngf0, const Problem *prob, const Solvers *solver)
	{
		mxArray *Xmx, *funSeriesmx, *ngfmx, *ngf0mx; /* *fmx,*/
		mexProblem::ObtainMxArrayFromElement(Xmx, &x);
        Vector funSeriesTrun(lengthSeries);
        realdp *funSeriesTrunptr = funSeriesTrun.ObtainWriteEntireData();
        const realdp *funSeriesptr = funSeries.ObtainReadData();
        for(integer i = 0; i < lengthSeries; i++)
            funSeriesTrunptr[i] = funSeriesptr[i];
        
		mexProblem::ObtainMxArrayFromElement(funSeriesmx, &funSeriesTrun);
//		fmx = mxCreateDoubleScalar(f);
		ngfmx = mxCreateDoubleScalar(ngf);
		ngf0mx = mxCreateDoubleScalar(ngf0);

		mxArray *lhs[1], *rhs[6];
		rhs[0] = const_cast<mxArray *> (isstopped);
		rhs[1] = const_cast<mxArray *> (Xmx);
		rhs[2] = const_cast<mxArray *> (funSeriesmx);
//		rhs[3] = const_cast<mxArray *> (fmx);
		rhs[3] = const_cast<mxArray *> (ngfmx);
		rhs[4] = const_cast<mxArray *> (ngf0mx);
		mexCallMATLAB(1, lhs, 5, rhs, "feval");
		realdp result = mxGetScalar(lhs[0]);
		mxDestroyArray(Xmx);
		mxDestroyArray(funSeriesmx);
//		mxDestroyArray(fmx);
		mxDestroyArray(ngfmx);
		mxDestroyArray(ngf0mx);
		mxDestroyArray(lhs[0]);
		return (result != 0);
	};
	mxArray *LinesearchInput = nullptr;
	/*This function defines the line search algorithm that may be used in the C++ solver*/
	realdp mexLinesearchInput(integer iter, const Variable &x1, const Vector &eta1, realdp initialstepsize, realdp initialslope, const Problem *prob, const Solvers *solver)
	{
		mxArray *Xmx, *eta1mx, *tmx, *smx, *imx;
		mexProblem::ObtainMxArrayFromElement(Xmx, &x1);
		mexProblem::ObtainMxArrayFromElement(eta1mx, &eta1);
		tmx = mxCreateDoubleScalar(initialstepsize);
		smx = mxCreateDoubleScalar(initialslope);
		imx = mxCreateDoubleScalar(iter);

		mxArray *lhs[1], *rhs[6];
		rhs[0] = const_cast<mxArray *> (LinesearchInput);
		rhs[1] = const_cast<mxArray *> (Xmx);
		rhs[2] = const_cast<mxArray *> (eta1mx);
		rhs[3] = const_cast<mxArray *> (tmx);
		rhs[4] = const_cast<mxArray *> (smx);
		rhs[5] = const_cast<mxArray *> (imx);
		mexCallMATLAB(1, lhs, 6, rhs, "feval");
		realdp result = mxGetScalar(lhs[0]);
		mxDestroyArray(Xmx);
		mxDestroyArray(eta1mx);
		mxDestroyArray(tmx);
		mxDestroyArray(smx);
		mxDestroyArray(imx);
		mxDestroyArray(lhs[0]);
		return result;
	};
};
void ParseSolverParamsAndOptimizing(const mxArray *SolverParams, const Problem *Prob, Variable *initialX, mxArray ** &plhs, realdp(*LSInput)(integer iter, const Variable &x1, const Vector &eta1, realdp initialstepsize, realdp initialslope, const Problem *prob, const Solvers *solver), LinearOPE *initialLOPE)
//void ParseSolverParamsAndOptimizing(const mxArray *SolverParams, Problem *Prob, Variable *initialX, mxArray **&plhs,
//	realdp(*LSInput)(integer iter, const Variable &x1, const Vector &eta1, realdp initialstepsize, realdp initialslope, const Problem *prob, const Solvers *solver), LinearOPE *initialLOPE)
{
	integer nfields = mxGetNumberOfFields(SolverParams);
	const char *name;
	mxArray *tmp;
	realdp value;
	PARAMSMAP params;
	std::string key;
	for (integer i = 0; i < nfields; i++)
	{
		name = mxGetFieldNameByNumber(SolverParams, i);
		tmp = mxGetFieldByNumber(SolverParams, 0, i);
		value = mxGetScalar(tmp);
		key.assign(name);
		params.insert(std::pair<std::string, realdp>(key, value));
	}

	tmp = mexProblem::GetFieldbyName(SolverParams, 0, "method");
	if (tmp == nullptr)
	{
		mexErrMsgTxt("A method must be specified.");
	}
    
	char methodname[30] = "";
	mxGetString(tmp, methodname, 30);
	std::string stdmethodname = methodname;
	Solvers *solver;
	if (stdmethodname == "RSD")
	{
		solver = new RSD(Prob, initialX);
	}
	else
	if (stdmethodname == "RNewton")
	{
		solver = new RNewton(Prob, initialX);
	}
	else
	if (stdmethodname == "RCG")
	{
		solver = new RCG(Prob, initialX);
	}
	else
	if (stdmethodname == "RBroydenFamily")
	{
		solver = new RBroydenFamily(Prob, initialX, initialLOPE);
	}
	else
	if (stdmethodname == "RWRBFGS")
	{
		solver = new RWRBFGS(Prob, initialX, initialLOPE);
	}
	else
	if (stdmethodname == "RBFGS")
	{
		solver = new RBFGS(Prob, initialX, initialLOPE);
	}
	else
	if (stdmethodname == "RBFGSSub")
	{
		solver = new RBFGSSub(Prob, initialX, initialLOPE);
	}
	else
	if (stdmethodname == "LRBFGSSub")
	{
		solver = new LRBFGSSub(Prob, initialX);
	}
	else
	if (stdmethodname == "RGS")
	{
		solver = new RGS(Prob, initialX);
	}
	else
	if (stdmethodname == "LRBFGS")
	{
		solver = new LRBFGS(Prob, initialX);
	}
	else
	if (stdmethodname == "RTRSD")
	{
		solver = new RTRSD(Prob, initialX);
	}
	else
	if (stdmethodname == "RTRNewton")
	{
		solver = new RTRNewton(Prob, initialX);
	}
	else
	if (stdmethodname == "RTRSR1")
	{
		solver = new RTRSR1(Prob, initialX, initialLOPE);
	}
	else
	if (stdmethodname == "LRTRSR1")
	{
		solver = new LRTRSR1(Prob, initialX);
	}
    else
    if (stdmethodname == "LRTRSR1woR")
    {
        solver = new LRTRSR1woR(Prob, initialX);
    }
    else
    if (stdmethodname == "IRPG")
    {
        solver = new IRPG(Prob, initialX);
    }
    else
    if (stdmethodname == "IARPG")
    {
        solver = new IARPG(Prob, initialX);
    }
	else
	{
		printf("Warning: Unrecognized solver: %s. return!\n", stdmethodname.c_str());
        return;
//		solver = new LRBFGS(Prob, initialX);
	}
    
	solver->SetParams(params);

	RMEX::isstopped = mexProblem::GetFieldbyName(SolverParams, 0, "IsStopped");
	if (RMEX::isstopped != nullptr)
	{
		solver->StopPtr = &RMEX::mexInnerStop;
	}

	RMEX::LinesearchInput = mexProblem::GetFieldbyName(SolverParams, 0, "LinesearchInput");
	if (RMEX::LinesearchInput != nullptr)
	{
		SolversSMLS *solverLS = dynamic_cast<SolversSMLS *> (solver);
		if (solverLS != nullptr)
			solverLS->LinesearchInput = &RMEX::mexLinesearchInput;
	}
    
	SolversSMLS *solverLS = dynamic_cast<SolversSMLS *> (solver);
	if(solverLS != nullptr)
	{
		if (solverLS->LineSearch_LS == LSSM_INPUTFUN && LSInput != nullptr)
		{
			solverLS->LinesearchInput = LSInput;
		}
	}

	tmp = mexProblem::GetFieldbyName(SolverParams, 0, "IsCheckParams");
	if (tmp != nullptr)
	{
		if (fabs(mxGetScalar(tmp)) > std::numeric_limits<double>::epsilon()) // if the value is nonzero
		{
			solver->CheckParams();
		}
	}
	solver->Run();
    
    Vector Xopttmp = solver->GetXopt();
	mexProblem::ObtainMxArrayFromElement(plhs[0], &Xopttmp);
	plhs[1] = mxCreateDoubleScalar(static_cast<double> (solver->Getfinalfun()));
    
    SolversSM *solverSM = dynamic_cast<SolversSM *> (solver);
    if(solverSM != nullptr)
    {
        plhs[2] = mxCreateDoubleScalar(static_cast<double> (solverSM->Getnormgf()));
        plhs[3] = mxCreateDoubleScalar(static_cast<double> (solverSM->Getnormgfgf0()));
    }
    SolversNSM *solverNSM = dynamic_cast<SolversNSM *> (solver);
    if(solverNSM != nullptr)
    {
        plhs[2] = mxCreateDoubleScalar(static_cast<double> (solverNSM->Getnormnd()));
        plhs[3] = mxCreateDoubleScalar(static_cast<double> (solverNSM->Getnormndnd0()));
    }
    
	plhs[4] = mxCreateDoubleScalar(static_cast<double> (solver->GetIter()));
	plhs[5] = mxCreateDoubleScalar(static_cast<double> (solver->Getnf()));
	plhs[6] = mxCreateDoubleScalar(static_cast<double> (solver->Getng()));
	plhs[7] = mxCreateDoubleScalar(static_cast<double> (solver->GetnR()));
	plhs[8] = mxCreateDoubleScalar(static_cast<double> (solver->GetnV()));
	plhs[9] = mxCreateDoubleScalar(static_cast<double> (solver->GetnVp()));
	plhs[10] = mxCreateDoubleScalar(static_cast<double> (solver->GetnH()));
	plhs[11] = mxCreateDoubleScalar(static_cast<double> (solver->GetComTime()));
	integer lengthSeries = solver->GetlengthSeries();
	plhs[12] = mxCreateDoubleMatrix(lengthSeries, 1, mxREAL);
	plhs[13] = mxCreateDoubleMatrix(lengthSeries, 1, mxREAL);
	plhs[14] = mxCreateDoubleMatrix(lengthSeries, 1, mxREAL);
    
    double *plhsfun = mxGetPr(plhs[12]), *plhsgrad = mxGetPr(plhs[13]), *plhstime = mxGetPr(plhs[14]); //--, *plhsdist = mxGetPr(plhs[15]);
    const double *tmpSeries = nullptr;
    tmpSeries = (solverSM == nullptr) ? solverNSM->GetdirSeries().ObtainReadData() : solverSM->GetgradSeries().ObtainReadData();
	for (integer i = 0; i < lengthSeries; i++)
	{
		plhsfun[i] = solver->GetfunSeries().ObtainReadData()[i];
		plhstime[i] = solver->GettimeSeries().ObtainReadData()[i];
        plhsgrad[i] = tmpSeries[i];
	}
    
	tmp = mexProblem::GetFieldbyName(SolverParams, 0, "IsCheckGradHess");
	if (tmp != nullptr)
	{
		if (fabs(mxGetScalar(tmp)) > std::numeric_limits<realdp>::epsilon()) // if the value is nonzero
		{
			Prob->CheckGradHessian(*initialX);
            Prob->CheckGradHessian(solver->GetXopt());
		}
	}
    plhs[15] = mxCreateDoubleMatrix(4, 1, mxREAL);
    double *plhseigHess = mxGetPr(plhs[15]);
    for(integer i = 0; i < 4; i++)
        plhseigHess[i] = 0;
    
	if (solver->Verbose >= DETAILED)
	{
        Vector MinMaxEigVals1 = Prob->MinMaxEigValHess(*initialX);
        plhseigHess[0] = MinMaxEigVals1.ObtainReadData()[0];
        plhseigHess[1] = MinMaxEigVals1.ObtainReadData()[1];
        
        Vector MinMaxEigVals2 = Prob->MinMaxEigValHess(solver->GetXopt());
        plhseigHess[2] = MinMaxEigVals2.ObtainReadData()[0];
        plhseigHess[3] = MinMaxEigVals2.ObtainReadData()[1];
	}
    
	delete solver;
};

bool ParseManiParams(const mxArray *ManiParams, Manifold **&manifolds, integer &numoftype, integer *&powsinterval)
{
	// Parse ManiParams
	numoftype = mxGetNumberOfElements(ManiParams);
	powsinterval = new integer[numoftype + 1];
	char name[30] = "";
	manifolds = new Manifold *[numoftype];
	integer n, p, m, Params;
	powsinterval[0] = 0;

	for (integer i = 0; i < numoftype; i++)
		powsinterval[i + 1] = powsinterval[i] + mxGetScalar(mexProblem::GetFieldbyName(ManiParams, i, "numofmani"));

	PARAMSMAP params;
	for (integer i = 0; i < numoftype; i++)
	{
		if (mxGetString(mexProblem::GetFieldbyName(ManiParams, i, "name"), name, 30))
			mexErrMsgTxt("error in getting manifold name!");
		n = mxGetScalar(mexProblem::GetFieldbyName(ManiParams, i, "n"));
		m = mxGetScalar(mexProblem::GetFieldbyName(ManiParams, i, "m"));
		p = mxGetScalar(mexProblem::GetFieldbyName(ManiParams, i, "p"));
		Params = mxGetScalar(mexProblem::GetFieldbyName(ManiParams, i, "ParamSet"));
		params[static_cast<std::string> ("ParamSet")] = Params;
		//		params.insert(std::pair<std::string, realdp>("ParamSet", Params));

		manifolds[i] = GetAManifold(name, n, m, p);
		manifolds[i]->SetParams(params);

		if (manifolds[i] == nullptr)
		{
			return false;
		}
	}

	return true;
};

/*Create a manifold from its name and sizes.
name: one of "CFixedRankQ2F", "CStiefel", "CSymFixedRankQ", "Euclidean", "FixedRankE", "FixedRankQ2F",
	"Grassmann", "SPDManifold", "Sphere", "Stiefel", "SymFixedRankQ".
n, m, p: sizes; each manifold passes only the ones its constructor takes (see the calls below).
Returns a new-allocated manifold, or nullptr (after printing a message) if name is not recognized.*/
Manifold *GetAManifold(const char *name, integer n, integer m, integer p)
{
	if (strcmp(name, "CFixedRankQ2F") == 0)
		return new CFixedRankQ2F(m, n, p);
	if (strcmp(name, "CStiefel") == 0)
		return new CStiefel(n, p);
	if (strcmp(name, "CSymFixedRankQ") == 0)
		return new CSymFixedRankQ(n, p);
	if (strcmp(name, "Euclidean") == 0)
		return new Euclidean(m, n);
	if (strcmp(name, "FixedRankE") == 0)
		return new FixedRankE(m, n, p);
	if (strcmp(name, "FixedRankQ2F") == 0)
		return new FixedRankQ2F(m, n, p); // must build the Q2F parameterization, not FixedRankE
	if (strcmp(name, "Grassmann") == 0)
		return new Grassmann(n, p);
	if (strcmp(name, "SPDManifold") == 0)
		return new SPDManifold(n);
	if (strcmp(name, "Sphere") == 0)
		return new Sphere(n);
	if (strcmp(name, "Stiefel") == 0)
		return new Stiefel(n, p);
	if (strcmp(name, "SymFixedRankQ") == 0)
		return new SymFixedRankQ(n, p);

	printf("Manifold: %s does not implemented in this library!\n", name);
	return nullptr;
}

//Element *GetAnElement(const char *name, integer n, integer m, integer p)
//{
//	if (strcmp(name, "Euclidean") == 0)
//	{
//		return new EucVariable(n, m);
//	}
//	else
//	if (strcmp(name, "Sphere") == 0)
//	{
//		return new SphereVariable(n);
//	}
//	else
//	if (strcmp(name, "Stiefel") == 0)
//	{
//		return new StieVariable(n, p);
//	}
//	else
//	if (strcmp(name, "Oblique") == 0)
//	{
//		return new ObliqueVariable(n, m);
//	}
//	else
//	if (strcmp(name, "ObliqueQ") == 0)
//	{
//		return new ObliqueQVariable(n, m);
//	}
//	else
//	if (strcmp(name, "LowRank") == 0)
//	{
//		return new LowRankVariable(n, m, p);
//	}
//	else
//	if (strcmp(name, "Multinomial") == 0)
//	{
//		return new MNVariable(n, m);
//	}
//	else
//	if (strcmp(name, "NStQOrth") == 0)
//	{
//		return new NSOVariable(n, p);
//	}
//	else
//	if (strcmp(name, "OrthGroup") == 0)
//	{
//		return new OrthGroupVariable(n);
//	}
//	else
//	if (strcmp(name, "L2Sphere") == 0)
//	{
//		return new L2SphereVariable(n);
//	}
//	else
//	if (strcmp(name, "SPDManifold") == 0)
//	{
//		return new SPDVariable(n);
//	}
//	else
//	if (strcmp(name, "CpxNStQOrth") == 0)
//	{
//		return new CSOVariable(n, p);
//	}
//	else
//	if (strcmp(name, "Grassmann") == 0)
//	{
//		return new GrassVariable(n, p);
//	}
//	else
//	if (strcmp(name, "EucPositive") == 0)
//	{
//		return new EucPosVariable(n, m);
//	}
//	else
//	if (strcmp(name, "SPDTensor") == 0)
//	{
//		return new SPDTVariable(n, m);
//	}
//	else
//	{
//		printf("Element: %s does not implemented in this library!\n", name);
//		return nullptr;
//	}
//};
#endif // end of MATLAB_MEX_FILE
#endif // end of DRIVERMEXPROB_H
