/* Fit_Polynomial.bsh
 * IJ BAR: https://github.com/tferr/Scripts#scripts
 *
 * Implements polynomial fitting in ImageJ using the apache commons math library.
 * 2014.06, Tiago Ferreira
 *
 * N.B.: CurveFitter will be deprecated after commons-math3.3. It should be replaced by
 *       AbstractCurveFitter once Fiji adopts it.
 * TODO: Plot confidence bands (perhaps using stat.correlation.Covariance to determine the
 *       covariance matrix?)
 */

import ij.IJ;
import ij.ImagePlus;
import ij.WindowManager;
import ij.gui.GenericDialog;
import ij.gui.ImageWindow;
import ij.gui.Overlay;
import ij.gui.Plot;
import ij.gui.PlotWindow;
import ij.gui.Roi;
import ij.gui.TextRoi;
import ij.process.ImageProcessor;
import ij.util.Tools;
import org.apache.commons.math3.analysis.polynomials.PolynomialFunction;
import org.apache.commons.math3.fitting.CurveFitter;
import org.apache.commons.math3.optim.nonlinear.vector.jacobian.LevenbergMarquardtOptimizer;

// Default settings. getUserInput() replaces them with the values chosen in the dialog.
boolean verbose = true;	// log details and animate the guessing loop?
boolean guess = true;	// guess a 'polynomial of best fit' instead of using defDegree?
int minDegree = 2;		// the smallest polynomial degree tried when guessing the 'best fit'
int maxDegree = 40;		// the highest polynomial degree tried when guessing (also the slider maximum)
int defDegree = 12;		// the default polynomial degree (overwritten by the guessed one when guessing)


// Probe for commons-math3 before doing anything else: when the library is missing, tell
// the user (typically a non-Fiji user) where to get it and stop the script
try {
	Class.forName("org.apache.commons.math3.analysis.polynomials.PolynomialFunction");
} catch( ClassNotFoundException e ) {
	String jarLocation = "http://jenkins.imagej.net/job/Stable-Fiji/ws/Fiji.app/jars/commons-math3-3.2.jar";
	String header = "\n**** Fit Polynomial Error: Required library not found:\n"+ e +"\n \n";
	String instructions = "The apache commons math library must be installed in the plugins folder.\n"
		+ "Please install it in plugins/jars/ by double-clicking on the link below:\n";
	IJ.log(header + instructions + jarLocation);
	IJ.error("Dependencies missing! Check the Log window for details.");
	return;
}

/* Returns a two-element array holding the sum of the values in yArray (index 0) and the
 * sum of their squares (index 1). All elements of yArray are used, whatever the number of
 * points being fitted, so the function does not depend on the script's nPoints variable.
 * An empty array yields {0.0, 0.0}.
 */
double[] calculateSumYandY2(double[] yArray) {
	double sumY = 0.0; double sumY2 = 0.0;
	for (int i=0; i<yArray.length; i++) {
		double y = yArray[i];
		sumY += y;
		sumY2 += y*y;
	}
	return new double[] {sumY, sumY2};
}

/* Returns a readable name for a polynomial of the given degree, e.g. "Straight line" for 1,
 * "2nd degree polynomial" for 2 or "21st degree polynomial" for 21.
 */
String getDegreeString(int degree) {
	if (degree==1) return "Straight line";
	// English ordinals: 1st, 2nd, 3rd apply to the last digit, except for 11th, 12th, 13th
	int lastTwoDigits = Math.abs(degree % 100);
	int lastDigit = lastTwoDigits % 10;
	String suffix = "th";
	if (lastTwoDigits<11 || lastTwoDigits>13) {
		if (lastDigit==1) suffix = "st";
		else if (lastDigit==2) suffix = "nd";
		else if (lastDigit==3) suffix = "rd";
	}
	return degree + suffix +" degree polynomial";
}

/* Returns details of the fit: the general formula (y = ax^0 + bx^1 + ...) followed by one
 * line per coefficient. Coefficients are named a..z and then A..Z.
 *   fit     the fitted polynomial
 *   degree  the degree that was requested for the fit (degree+1 coefficients are listed)
 * The header of the message uses the script's polynomialName variable.
 */
String getLogMessage(PolynomialFunction fit, int degree) {
	double[] coefficients = fit.getCoefficients();
	StringBuilder formula = new StringBuilder("\ny = ");
	StringBuilder parameters = new StringBuilder();
	char cChar = 'a';
	for (int i=0; i<degree+1; i++) {
		// PolynomialFunction drops trailing zero coefficients, so its array can be shorter
		// than degree+1: the missing high-order terms are exactly zero
		double coefficient = (i<coefficients.length) ? coefficients[i] : 0.0;
		parameters.append("\n"+ cChar +" = "+ IJ.d2s(coefficient, 5, 9));
		formula.append(cChar +"x^"+ i);
		if (i<degree) formula.append(" + ");
		cChar++;
		if (i==25) cChar = 'A'; // this won't work for more than 52 parameters
	}
	return "\nPolynomial fit ("+ this.polynomialName +"):"+ formula.toString()
		+ parameters.toString() +"\n";
}

/* Returns RSquared, http://en.wikipedia.org/wiki/Coefficient_of_determination
 *   sampledData  the observed y-values
 *   fittedData   the y-values predicted by the fit at the same x-positions
 * R^2 is computed as 1 - SSres/SStot, with SSres = sum((y-f)^2) and SStot = sum((y-mean)^2).
 * The residuals are summed directly: the shortcut sum(y^2)-sum(f^2) only equals SSres
 * for an exactly converged least-squares fit. Returns 0 when there are no points or when
 * the sampled data has no variance.
 */
double getRSquared(double[] sampledData, double[] fittedData) {
	int n = Math.min(sampledData.length, fittedData.length);
	if (n==0)
		return 0.0;

	double sumY = 0.0;
	for (int i=0; i<n; i++)
		sumY += sampledData[i];
	double meanY = sumY/n;

	double sumMeanDiffSqr = 0.0;	// total sum of squares
	double sumResidualsSqr = 0.0;	// residual sum of squares
	for (int i=0; i<n; i++) {
		double meanDiff = sampledData[i] - meanY;
		double residual = sampledData[i] - fittedData[i];
		sumMeanDiffSqr += meanDiff*meanDiff;
		sumResidualsSqr += residual*residual;
	}

	double rSquared = 0.0;
	if (sumMeanDiffSqr > 0.0)
		rSquared = 1.0 - sumResidualsSqr/sumMeanDiffSqr;
	return rSquared;
}

/* Returns Adjusted RSquared, http://en.wikipedia.org/wiki/Coefficient_of_determination
 *   rSquared  the (unadjusted) coefficient of determination
 *   p         the number of regressors used to penalize rSquared
 * The number of observations is read from the script's nPoints variable.
 */
double getAdjustedRSquared(double rSquared, int p) {
	double unexplained = 1-rSquared;
	double penalty = (unexplained * p) / (this.nPoints-p-1);
	return rSquared - penalty;
}

/* Returns the number of points to be fitted, i.e., the length shared by xsampled and
 * ysampled. When the two arrays differ in length, both are first truncated to the
 * length of the shortest one.
 */
int getNumPoints() {
	int nX = this.xsampled.length;
	int nY = this.ysampled.length;
	if (nX==nY)
		return nX;

	// Unequal lengths: keep only the leading elements that exist in both arrays
	int common = (nX<nY) ? nX : nY;
	this.xsampled = Arrays.copyOf(this.xsampled, common);
	this.ysampled = Arrays.copyOf(this.ysampled, common);
	return common;
}

/* Tries to guess the polynomial of best fit while providing some visual feedback.
 * Fits every degree in [minDegree, maxDegree] to (xsampled, ysampled) and returns the
 * degree with the highest R^2 (bestDegree is left untouched if no fit has R^2 > 0).
 * In verbose mode each fit is logged, added to the plot and drawn in the plot window (pw);
 * otherwise only the progress bar is updated.
 * Side effects: updates the script variables curDegree, color, yfitted and bestDegree.
 * Exceptions thrown by the optimizer are not caught here.
 */
int guessDegree(Plot plot) {
	ImagePlus imp = this.pw.getImagePlus();
	imp.setHideOverlay(false);
	yfitted = new double[nPoints];
	double maxRSquared = 0; //double maxARSquared = 0;
	int nDegrees = this.maxDegree - this.minDegree + 1;
	for (int i=this.minDegree; i<=this.maxDegree; i++) {
		// curDegree lets the caller report the degree being processed if the optimizer fails
		curDegree = i; color = getColor(i);
		CurveFitter fitter = new CurveFitter(new LevenbergMarquardtOptimizer());
		for (int j=0; j<nPoints; j++)
			fitter.addObservedPoint(xsampled[j], ysampled[j]);
		// The degree of the fit is set by the length of the initial guess (all zeros)
		double[] best = fitter.fit(new PolynomialFunction.Parametric(), new double[i+1]);
		PolynomialFunction fitted = new PolynomialFunction(best);
		for (int k=0; k<nPoints; k++)
			yfitted[k] = fitted.value(xsampled[k]);
		double RSquared = getRSquared(ysampled, yfitted);
		// NOTE(review): i-1 is used as the number of regressors, although a polynomial of
		// degree i has i non-constant terms. Only affects the logged value -- confirm
		double ARSquared = getAdjustedRSquared(RSquared, i-1);
		if (RSquared>maxRSquared) { //(ARSquared>maxARSquared)
			bestDegree = i; maxRSquared = RSquared; //maxARSquared = ARSquared;
		}
		if (verbose) {
			IJ.log(getDegreeString(i) +": R^2= "+ RSquared +" Adj. R^2= "+ ARSquared);
			Roi roi = new TextRoi(0.0, 0.0, "Degree: "+ i);
			Overlay overlay = new Overlay();
			overlay.add(roi);
			overlay.setFillColor(Color.GRAY);
			overlay.setStrokeColor(colors[color]);
			imp.setOverlay(overlay);
			plot.setColor(colors[color]);
			plot.addPoints(xsampled, yfitted, Plot.LINE);
			pw.drawPlot(plot);
			if (i==this.maxDegree) overlay.clear();
		} else
			// Progress is relative to the range of degrees: 0-based index of the current
			// degree out of nDegrees (the bar is erased after the last one)
			IJ.showProgress(i - this.minDegree, nDegrees);
	}
	return bestDegree;
}

/* Returns an array of 256 'heatmap' colors. Each channel follows a downward parabola that
 * peaks at 255: blue peaks at index 0, green at index 127.5 and red at index 255. Red and
 * blue are clamped at 0 once their parabola becomes negative.
 */
Color[] parabolicColors() {
	double scale = -255/(127.5*127.5); // curvature shared by the three parabolas
	Color[] ramp = new Color[256];
	int i = 0;
	while (i<ramp.length) {
		int red = (int)Math.max(0, scale * (i-255)*(i-255) + 255);
		int green = (int)(scale * (i-127.5)*(i-127.5) + 255);
		int blue = (int)Math.max(0, scale * i*i + 255);
		ramp[i] = new Color(red, green, blue);
		i++;
	}
	return ramp;
}

/* Returns the 8-bit position of index relatively to [minDegree maxDegree]. The result is
 * always a valid index (0-255) for the colors array: indices outside the range map to the
 * nearest end of the ramp, and 0 is returned when minDegree equals maxDegree.
 */
int getColor(int index) {
	int range = this.maxDegree - this.minDegree;
	if (range==0) // single-degree range: avoid an integer division by zero
		return 0;
	int position = (255*(index-this.minDegree))/range;
	// e.g., a user-specified degree below minDegree would otherwise yield a negative index
	return Math.max(0, Math.min(255, position));
}

/* Adds color ramp to plot: a 256-pixel wide strip drawn with the colors array near the
 * bottom of the plot image, in which the position of the current color is marked in
 * black. The strip is labelled with minDegree, "Degree" and maxDegree.
 */
void makeLegend(Plot plot) {
	ImageProcessor ip = plot.getProcessor();
	int bottom = ip.getHeight()-4;
	int left = (int)(0.5 * (Plot.LEFT_MARGIN + ip.getWidth() - 256));
	int top = bottom-5;

	// One vertical line per color; the current color is replaced by a black marker
	for (int i=0; i<256; i++) {
		if (i==color)
			ip.setColor(Color.BLACK);
		else
			ip.setColor(colors[i]);
		ip.drawLine(left+i, bottom, left+i, top);
	}

	// Labels sit on top of the strip
	ip.setColor(Color.BLACK);
	ip.setJustification(ImageProcessor.CENTER_JUSTIFY);
	ip.drawString("Degree", left+128, top);
	ip.drawString(""+this.minDegree, left, top);
	ip.drawString(""+this.maxDegree, left+256, top);
}

/* Prompts user for settings. Returns false if dialog is dismissed, in which case no
 * setting is modified. Otherwise updates defDegree, minDegree, maxDegree, guess and
 * verbose from the dialog:
 *  - the degree is at least 1; an invalid number keeps the previous defDegree
 *  - the range is typed as "min-max" (or "min max"); bounds that cannot be parsed keep
 *    their previous value, a reversed range is swapped and the lower bound is at least 1
 */
boolean getUserInput() {
	GenericDialog gd = new GenericDialog("Fit Polynomial to Curve...");
	gd.addSlider("Polynomial degree:", 1, maxDegree, defDegree);
	gd.addCheckbox("Guess \"best fit\" (ignores specified degree)", guess);
	gd.addStringField("          Degree range:", minDegree +"-"+ maxDegree, 12);
	gd.addCheckbox("Log details (animates retrieval of \"best fit\")", verbose);
	gd.addHelp("http://fiji.sc/Sholl_Analysis#Complementary_Tools");
	gd.showDialog();
	if (gd.wasCanceled())
		return false;

	// NaN must be detected before casting to int: (int)NaN is silently 0
	double degree = gd.getNextNumber();
	if (!Double.isNaN(degree))
		defDegree = (int)Math.max(1, degree);

	String[] minAndMax = Tools.split(gd.getNextString(), " -");
	int lower = minDegree;
	int upper = maxDegree;
	if (minAndMax.length>=1) {
		double parsedMin = gd.parseDouble(minAndMax[0]);
		if (!Double.isNaN(parsedMin)) lower = (int)parsedMin;
	}
	if (minAndMax.length==2) {
		double parsedMax = gd.parseDouble(minAndMax[1]);
		if (!Double.isNaN(parsedMax)) upper = (int)parsedMax;
	}
	if (upper<lower) { // range typed backwards, e.g., "40-2"
		int swap = lower; lower = upper; upper = swap;
	}
	lower = Math.max(1, lower);
	upper = Math.max(lower, upper);
	minDegree = lower; maxDegree = upper;

	guess = gd.getNextBoolean();
	verbose = gd.getNextBoolean();
	return true;
}

// Variables shared by the functions above and the main body of the script below.

// Data retrieved from the plot window (sampled) and values of the fitted polynomial
double[] xsampled, ysampled, xfitted, yfitted;
// color: current index in the colors array; curDegree: degree being processed by
// guessDegree(); nPoints: number of points being fitted; bestDegree: best degree so far
int color, curDegree, nPoints, bestDegree = 1;
// Readable name of the final polynomial, as returned by getDegreeString()
String polynomialName;
// Color ramp returned by parabolicColors()
Color[] colors;
// Window from which the data is retrieved, or the window displaying the new plot
PlotWindow pw;

// Retrieve the sampled data from the frontmost image, which must be a plot window
ImagePlus imp = WindowManager.getCurrentImage();
if (imp==null) {
	IJ.error("There are no plots open.");
	return;
}
ImageWindow activeWindow = imp.getWindow();
if (activeWindow==null || !(activeWindow instanceof PlotWindow)) {
	IJ.error(imp.getTitle() +" is not a plot window.");
	return;
}
pw = (PlotWindow)activeWindow;
xsampled = Tools.toDouble(pw.getXValues());
ysampled = Tools.toDouble(pw.getYValues());

// Exit if the user dismissed the dialog. The timer is started only afterwards, so that
// the reported time does not include the time spent in the prompt
if (!getUserInput()) return;
long startTime = System.currentTimeMillis();
nPoints = getNumPoints();
if (nPoints==0)
	{ IJ.error("Could not retrieve values from plot"); return; }

IJ.showStatus("Retrieving fits. Please wait...");
colors = parabolicColors();
String plotTitle = imp.getTitle();
// A title starting with "Polynomial fit" is taken as a plot created by this script: in
// that case the existing window (pw) is redrawn instead of opening a new one
boolean plotExists = plotTitle.startsWith("Polynomial fit");
// The new plot starts with the sampled data; fitted curves are added to it later on
Plot plot = new Plot("Polynomial fit "+ plotTitle, "", "", xsampled, ysampled);

// Fit the data: optionally guess the degree first, then perform the final fit with
// defDegree. Any exception (typically thrown by the optimizer) is reported to the user
// and terminates the script.
try {

	if (guess) { // Run fitting loop if degree is to be guessed
		if (verbose) {
			IJ.log("\nGuessing polynomial of 'best fit' for "+ plotTitle +":");
			if (!plotExists) // Display plot for visual feedback
				{ pw = plot.show(); plotExists = true; }
		}
		defDegree = guessDegree(plot);
	}

	// Perform final fit
	CurveFitter fitter = new CurveFitter(new LevenbergMarquardtOptimizer());
	for (int i=0; i<nPoints; i++)
		fitter.addObservedPoint(xsampled[i], ysampled[i]);

	// Compute optimal coefficients. The polynomial degree is deduced from the length of
	// the array containing the initial guess for the coefficients (all set to zero here)
	double[] best = fitter.fit(new PolynomialFunction.Parametric(), new double[defDegree+1]);

	// Construct the polynomial that best fits the data
	PolynomialFunction fitted = new PolynomialFunction(best);

	// Evaluate the polynomial at the sampled x-positions
	xfitted = new double[nPoints];
	yfitted = new double[nPoints];
	for (int i=0; i<nPoints; i++) {
		xfitted[i] = xsampled[i];
		yfitted[i] = fitted.value(xsampled[i]);
	}

	// polynomialName must be set before getLogMessage() is called, since it reads it
	polynomialName = getDegreeString(defDegree);
	if (verbose) IJ.log("\nFitting "+ plotTitle + ":"+ getLogMessage(fitted, defDegree));

} catch (Exception e) {

	// Clear the progress bar, which guessDegree() may have left partially filled
	IJ.showProgress(0,0);
	IJ.showStatus("Done. "+ (System.currentTimeMillis()-startTime)/1000.0 +" seconds");
	// When guessing, report the degree that was being processed when the failure occurred
	int errD = (guess) ? curDegree : defDegree;
	IJ.error("Optimizer failed when dealing with a polynomial of degree\n"
		+ errD +" ["+ nPoints +" data points] and will now terminate (see console\n"
		+ "for details)... Perhaps settings should be revised?" );
	IJ.log("\n*** Exception when parsing data from "+ plotTitle +":\n"+ e);
	if (guess)
		IJ.log(">>> Polynomial of 'best fit' until exception: "+ getDegreeString(bestDegree));
	else
		IJ.log(">>> Specified polynomial: "+ getDegreeString(defDegree));
	return;

}

// Add the final fit to the plot as a thicker line, using the color assigned to its degree
double RSquared = getRSquared(ysampled, yfitted);
color = getColor(Math.min(defDegree,maxDegree));
plot.setColor(colors[color]);
plot.setLineWidth(2);
plot.addPoints(xfitted, yfitted, Plot.LINE);

// Restore drawing defaults and describe the fit in the top-left corner of the plot
plot.setColor(Color.BLACK);
plot.setLineWidth(1);
String plotLabel = polynomialName;
if (guess)
	plotLabel = polynomialName +" (guessed)";
String fitSummary = plotLabel +"  R^2= "+ IJ.d2s(RSquared, 5, 9) +"  "+ nPoints +" data points";
plot.addLabel(0.00, 0, fitSummary);
if (guess)
	makeLegend(plot);

// Redraw the window already on display, or open a new one
if (!plotExists)
	plot.show();
else
	pw.drawPlot(plot);

double elapsedSeconds = (System.currentTimeMillis()-startTime)/1000.0;
IJ.showStatus("Done. "+ elapsedSeconds +" seconds");
IJ.showProgress(0,0);