//
//  DDERungeKutta23.m
//  DDE23-Test
//
//  Created by ashley on 13/04/2008.
//  Copyright 2008 __MyCompanyName__. All rights reserved.
//

#import "DDERungeKutta23.h"

/* Error-ratio threshold for step growth in -solve: when an accepted step's
   scaled error errmax is at or below ERRCON, the next step is 5*D; otherwise
   it is 0.95*D*errmax^-0.2.
   NOTE(review): 1.89e-4 equals (5/0.9)^-5, the point where the two rules agree
   for a safety factor of 0.9. -solve uses 0.95, for which the crossover would be
   (5/0.95)^-5 ~ 2.48e-4. Confirm which was intended. */
#define ERRCON 1.89e-4

/* Fatal-error helper: writes str followed by a newline to standard output,
   then terminates the process with exit(-1) (status 255 on POSIX systems).
   Never returns. */
void error(char *str)
{
	puts(str);
	exit(-1);
}

/*
 * zeropos: estimate where a switch function crosses zero on [x1,x3].
 *
 * Fits q(u) = a + b*u + c*u^2 (u = x - x2) through the samples (x1,s1),
 * (x2,s2) and (x3,s3) and picks a root estimate. The function assumes s3 < s1.
 *   1. A real quadratic root inside [x1,x3] is preferred; the (-b-d)/2c root
 *      is tried first.
 *   2. If there is no real root, or neither root lies inside, it falls back to
 *      x3 when -s3 < s1 and to x1 otherwise.
 *   3. The root p = -a/b of the linear part a + b*u replaces that choice when
 *      |q(p)| is smaller. p is also used outright when c is exactly zero.
 * The estimate is clamped to [x1,x3]. If it does not exceed x1, it is nudged
 * by a relative factor of about 1e-8 so callers normally see a position after x1.
 *
 * Calls error(), which exits, if x1 == x2 or x2 == x3.
 */
double zeropos(double x1, double x2, double x3, double s1, double s2, double s3)
{
	/* relative nudge used to move a result off the left end point */
	const double udge = 1.00000001;
	double z, y, zpy;
	double a, b, c, d;
	double a1, b1, c1, p;
	int ok = 1;

	z = x3 - x2;
	y = x2 - x1;
	zpy = z + y;
	if (z == 0.0 || y == 0.0) {
		error("Error in switching: zero switch interval");
	}

	/* coefficients of q(u) = a + b*u + c*u^2 with u = x - x2 */
	a1 = a = s2;
	c1 = c = (z*s1 + y*s3 - zpy*s2)/(zpy*z*y);
	b1 = b = (s2 - s1)/y + c*y;
	d = b*b - 4.0*a*c;        /* discriminant */
	c *= 2.0;                 /* c now holds the 2c of the quadratic formula */

	/* root of the linear part - used directly in case c is numerically zero */
	p = -a/b;
	if (c == 0.0) {
		a = p;
	} else {
		if (d >= 0.0) {
			d = sqrt(d);
			a = (-b + d)/c;
			b = (-b - d)/c;
			if ((b >= -y) && (b <= z)) {
				a = b;
			} else if ((a < -y) || (a > z)) {
				ok = 0;
			}
		}
		if ((d < 0.0) || (!ok)) {
			/* no usable quadratic root inside the interval: take an end point */
			a = (-s3 < s1) ? z : -y;
		}
		/* check that the linear estimate is not better: compare |q| at both */
		z = a1 + a*b1 + a*a*c1;
		d = a1 + p*b1 + p*p*c1;
		if (fabs(z) > fabs(d)) {
			a = p;
		}
	}

	a += x2;
	/* keep the estimate inside [x1,x3]; p in particular may lie outside it */
	if (a > x3) {
		a = x3;
	}
	if (a < x1) {
		a = x1;
	}
	/* push a result sitting on x1 slightly to the right of it */
	if (a <= x1) {
		if (a == 0.0) {
			a = udge - 1.0;
		} else if (a < 0.0) {
			a /= udge;
		} else {
			a *= udge;
		}
	}

	return a;
}

@implementation DDERungeKutta23

/* Class-level flag; always YES for this solver, which can keep its own
   history buffers (see -initWithRhs:). */
+ (BOOL) hasHistory
{
	return YES;
}

/*
 * Initialiser.
 *
 * rhs supplies the problem sizes (nVariables, nSwitches, nHistory, nLag), the
 * "histsize" option and the callbacks used by the solver. rhs is retained and
 * is handed this solver through -setHistory:.
 *
 * When rhs has history variables, ring buffers of "histsize" records are
 * allocated and the solver acts as its own history store (-history returns
 * self). Otherwise -history returns nil.
 *
 * Returns nil, after printing a message, if histsize is not positive or the
 * history buffers cannot be allocated.
 */
- (id) initWithRhs:(id)rhs
{
	self = [super init];
	if (self == nil) {
		return nil;
	}

	_accepted = 0;
	_rejected = 0;

	[rhs setHistory:self];
	_rhs = [rhs retain];

	int ns = [rhs nVariables];
	int nsw = [rhs nSwitches];

	/* Runge-Kutta stage buffers used by -step:... */
	_k2 = (double *)calloc(ns, sizeof(double));
	_k3 = (double *)calloc(ns, sizeof(double));
	_k4 = (double *)calloc(ns, sizeof(double));

	/* scratch buffers used by -istep:... while locating switch roots */
	_sw1 = (double *)calloc(nsw, sizeof(double));
	_sw2 = (double *)calloc(nsw, sizeof(double));
	_s1 = (double *)calloc(ns, sizeof(double));
	_s2 = (double *)calloc(ns, sizeof(double));
	_err1 = (double *)calloc(ns, sizeof(double));
	_flicked = (int *)calloc(nsw, sizeof(int));

	_history = nil;
	_offset = -1L;   /* ring buffer is empty: the first record goes to slot 0 */
	_first = YES;

	int nhv = [rhs nHistory];
	if (nhv > 0) {
		long histsize = [[rhs option:@"histsize"] intValue];
		int nlag = [rhs nLag];
		int i;
		BOOL allocated;

		if (histsize <= 0) {
			/* -updateHistory: would write into a zero-length ring buffer */
			printf("History buffer size must be positive\n");
			/* undo the registration above so rhs keeps no pointer to a released
			   solver. NOTE(review): assumes -setHistory: accepts nil. */
			[rhs setHistory:nil];
			[self release];
			return nil;
		}

		_no = (long)nhv;
		_size = histsize;

		_lagmarker = (long **)calloc((size_t)nhv, sizeof(long *));
		_buff = (double **)calloc((size_t)nhv, sizeof(double *));
		_gbuff = (double **)calloc((size_t)nhv, sizeof(double *));
		_clock = (double *)calloc((size_t)_size, sizeof(double));
		_his = (double *)calloc((size_t)_no, sizeof(double));
		_ghis = (double *)calloc((size_t)_no, sizeof(double));

		allocated = (_lagmarker && _buff && _gbuff && _clock && _his && _ghis);
		for (i = 0; allocated && (i < nhv); i++) {
			_lagmarker[i] = (long *)calloc((size_t)nlag, sizeof(long));
			_buff[i] = (double *)calloc((size_t)_size, sizeof(double));
			_gbuff[i] = (double *)calloc((size_t)_size, sizeof(double));
			/* calloc(0, ...) may legitimately return NULL, so lag markers are
			   only checked when there are lags */
			if (!_buff[i] || !_gbuff[i] || ((nlag > 0) && !_lagmarker[i])) {
				allocated = NO;
			}
		}
		if (!allocated) {
			printf("History buffer too long\n");
			/* NOTE(review): assumes -setHistory: accepts nil */
			[rhs setHistory:nil];
			/* -dealloc frees whatever was allocated so far */
			[self release];
			return nil;
		}

		_history = self;
	}

	return self;
}

/* The problem definition passed to -initWithRhs: (not retained again here). */
- (id) rhs
{
	return _rhs;
}

/* self when history buffers were allocated in -initWithRhs:, otherwise nil. */
- (id) history
{
	return _history;
}

/*
 * Integrates from t0 to t1 and watches the switch functions.
 *
 * Switch i "flicks" when it goes from > 0 to <= 0.
 *
 * If no switch flicks over [t0,t1], news/newsws/newg/err hold the step result,
 * t1 is returned and *flickedswitch is -1.
 *
 * Otherwise the first crossing is hunted by splitting the interval at an
 * estimated root (zeropos), with at most 100 attempts. The returned value is
 * the time reached:
 *   - a split point, where *flickedswitch is the one switch that flicked; or
 *   - t1, when the two part-steps show no switch after all. In that case err
 *     combines both part-step errors and *flickedswitch is -1.
 * error() (which exits) is called if the search does not settle.
 *
 * If the rhs fails to evaluate a gradient, *success is NO and the outputs are
 * not meaningful.
 */
- (double) istep:(double *)sw0 newSwitches:(double *)newsws state:(double *)s0 newState:(double *)news 
				gradients:(double*)g newGradients:(double *)newg error:(double *)err 
				startTime:(double)t0 stopTime:(double)t1 flickedSwitch:(int *)flickedswitch success:(BOOL*)success
{
	id rhs = [self rhs];
	int nsw = [rhs nSwitches];
	int ns = [rhs nVariables];
	double dt = t1 - t0;
	int i;
	int k;
	int switches;

	*flickedswitch = -1;

	*success = [self step:s0 newState:news gradients:g newGradients:newg error:err time:t0 timeStep:dt];
	if (!*success) {
		/* the caller abandons the solve; don't run switch logic on garbage */
		return t0;
	}
	if (nsw != 0) {
		[rhs switchFunctions:newsws newState:news time:t1];
	}

	// Are there any switches?
	switches = 0;
	for (i = 0; i < nsw; i++) {
		if ((sw0[i] > 0.0) && (newsws[i] <= 0.0)) {
			_flicked[switches++] = i;
		}
	}
	// If no switches then it's an ordinary step
	if (switches == 0) {
		return t1;
	}

	// Logic for stepping to the first switch: start by splitting at the midpoint
	double sp1 = t0 + dt*0.5;
	// if k gets to 100 the routine fails
	for (k = 0; k < 100; k++) {
		// step to the current estimate of the first switch position
		if (![self step:s0 newState:_s1 gradients:g newGradients:newg error:err time:t0 timeStep:sp1 - t0]) {
			*success = NO;
			return t0;
		}
		[rhs switchFunctions:_sw1 newState:_s1 time:sp1];

		// switches flicked on [t0,sp1]
		switches = 0;
		for (i = 0; i < nsw; i++) {
			if ((sw0[i] > 0.0) && (_sw1[i] <= 0.0)) {
				_flicked[switches++] = i;
			}
		}

		// after the first pass, exactly one switch up to sp1: stop at sp1
		if ((k > 0) && (switches == 1)) {
			*flickedswitch = _flicked[0];
			for (i = 0; i < ns; i++) {
				news[i] = _s1[i];
			}
			for (i = 0; i < nsw; i++) {
				newsws[i] = _sw1[i];
			}
			return sp1;
		}

		// step on to the end of the interval (newg is both input and output here)
		if (![self step:_s1 newState:_s2 gradients:newg newGradients:newg error:_err1 time:sp1 timeStep:t1 - sp1]) {
			*success = NO;
			return t0;
		}
		[rhs switchFunctions:_sw2 newState:_s2 time:t1];

		// switches flicked on [sp1,t1] are appended to those found on [t0,sp1]
		for (i = 0; i < nsw; i++) {
			if ((_sw1[i] > 0.0) && (_sw2[i] <= 0.0)) {
				_flicked[switches++] = i;
			}
		}

		if (switches == 0) {
			// no switch after all: the two part-steps make up an ordinary step
			*flickedswitch = -1;
			for (i = 0; i < ns; i++) {
				news[i] = _s2[i];
				err[i] = sqrt(err[i]*err[i] + _err1[i]*_err1[i]);
			}
			for (i = 0; i < nsw; i++) {
				newsws[i] = _sw2[i];
			}
			return t1;
		}

		// having got this far, switch positions must be estimated:
		// minp is the earliest estimate and sp2 the next earliest
		double sp2 = t1;
		double minp = t1;
		for (i = 0; i < switches; i++) {
			int sw = _flicked[i];
			double zp;
			if ((t0 == t1) || (sp1 == t1) || (t0 == sp1)) {
				zp = t1;
			} else {
				zp = zeropos(t0, sp1, t1, sw0[sw], _sw1[sw], _sw2[sw]);
			}
			if (zp < minp) {
				sp2 = minp;
				minp = zp;
			} else if (zp < sp2) {
				sp2 = zp;
			}
		}

		// split just past minp, by a growing fraction of the gap to sp2
		sp1 = minp;
		double udge = 1e-9;
		double ds = sp2 - sp1;
		if (ds > 0.0) {
			do {
				sp1 += udge*ds;
				udge *= 10.0;
			} while (sp1 == minp);
		}
	}
	error("Problem with switch logic");
	return t0;
}

/*
 * Takes a single integration step from t to t+dt with the embedded
 * Runge-Kutta-Fehlberg pair RKF2(3)B.
 * Reference: E. Hairer, S.P. Norsett & G. Wanner 1987, Solving Ordinary
 * Differential Equations I, Springer-Verlag Berlin, p.170.
 *
 * The lower-order solution is used for the update. That is only valid because
 * the final-stage coefficients equal the lower-order weights (ci = b4i).
 *
 * gradients     derivative at (t, state); used as the first stage k1.
 * newState      receives the updated state at t+dt.
 * newGradients  receives the derivative at (t+dt, newState). It may be the
 *               same array as gradients.
 * error         receives the higher-order minus lower-order solution, for
 *               adaptive step control.
 *
 * Returns NO as soon as the rhs fails to evaluate a gradient.
 */
- (BOOL) step:(double *)state newState:(double *)newState gradients:(double *)gradients newGradients:(double *)newGradients
			  error:(double *)error time:(double)t timeStep:(double)dt
{
	// Embedded RKF table
	double	a2=0.25,  a3=27.0/40.0,
			b21= 0.25,
			b31=-189.0/800.0,  b32= 729.0/800.0,
			b41= 214.0/891.0,  b42= 1.0/33.0,     b43= 650.0/891.0,
			cc1= 533.0/2106.0, cc3= 800.0/1053.0, cc4=-1.0/78.0;

	int i;
	double *k1 = gradients;

	id rhs = [self rhs];
	int ns = [rhs nVariables];

	for (i = 0; i < ns; i++) {
		newState[i] = state[i] + (k1[i]*b21)*dt;
	}
	BOOL success = [rhs gradients:_k2 forState:newState atTime:(t + dt*a2)];
	if (!success) {
		return NO;
	}

	for (i = 0; i < ns; i++) {
		newState[i] = state[i] + (k1[i]*b31 + _k2[i]*b32)*dt;
	}
	success = [rhs gradients:_k3 forState:newState atTime:(t + dt*a3)];
	if (!success) {
		return NO;
	}

	for (i = 0; i < ns; i++) {
		newState[i] = state[i] + (k1[i]*b41 + _k2[i]*b42 + _k3[i]*b43)*dt;
	}
	success = [rhs gradients:_k4 forState:newState atTime:(t + dt)];
	if (!success) {
		return NO;
	}

	for (i = 0; i < ns; i++) {
		// read k1 before storing k4: newGradients may alias gradients (k1),
		// as it does when -istep: steps the second part of an interval
		error[i] = state[i] + (cc1*k1[i] + cc3*_k3[i] + cc4*_k4[i])*dt - newState[i];
		newGradients[i] = _k4[i];
	}

	return YES;
}

/* Number of steps accepted by the last -solve. */
- (int) accepted
{
	return _accepted;
}

/* Number of steps rejected by the last -solve. */
- (int) rejected
{
	return _rejected;
}

/*
 * Integrates the problem from option "tstart" to "tstop".
 *
 * Options read from rhs:
 *   tstep    initial step
 *   epsilon  error tolerance factor
 *   outstep  output control:
 *              0   no output
 *              < 0 output after every accepted step
 *              > 0 output at the first accepted step past each multiple of
 *                  outstep, plus before and after each switch map
 *
 * Step sizes adapt to the error estimate from -istep:.... When a switch
 * flicks, the state is passed through rhs's -map:time:switchNo:.
 *
 * Returns NO if rhs fails to evaluate gradients. All working buffers are
 * freed on every path.
 */
- (BOOL) solve
{
	int i;
	long iout = 1L;

	id rhs = [self rhs];
	int nsw = [rhs nSwitches];
	int ns = [rhs nVariables];

	double *state = (double *)calloc(ns, sizeof(double));
	[rhs initState:state];

	double t0 = [[rhs option:@"tstart"] doubleValue];
	double t1 = [[rhs option:@"tstop"] doubleValue];
	double dt = [[rhs option:@"tstep"] doubleValue];
	double eps = [[rhs option:@"epsilon"] doubleValue];
	double dout = [[rhs option:@"outstep"] doubleValue];

	/* state/news, sws/newsws and g/newg are swapped after each accepted step */
	double *newsws = (double *)calloc(nsw, sizeof(double));
	double *sws = (double *)calloc(nsw, sizeof(double));
	double *news = (double *)calloc(ns, sizeof(double));
	double *newg = (double *)calloc(ns, sizeof(double));
	double *err = (double *)calloc(ns, sizeof(double));
	double *e0 = (double *)calloc(ns, sizeof(double));
	double *scale = (double *)calloc((size_t)ns, sizeof(double));
	[rhs stateScale:scale];
	if (nsw > 0) {
		[rhs switchFunctions:sws newState:state time:t0];
	}

	double *g = (double *)calloc(ns, sizeof(double));

	_accepted = 0;
	_rejected = 0;

	double ti = t0;           /* origin of the output grid */
	double t = t0;            /* equals t0 at the top of every loop iteration */
	double D = dt;            /* current step size; dt tracks the previous one */
	double mindt = D*1e-9;    /* at or below this, steps are accepted without error control */
	double maxdt = D*100.0;

	int fixstep = 0;          /* never set here: error control is always on */

	int nhv = [rhs nHistory];

	BOOL success = [rhs gradients:g forState:state atTime:t0];
	BOOL ok = success;

	id history = [self history];
	if (ok) {
		if (nhv > 0) {
			[history updateHistory:g state:state time:t0];
		}
		if (dout != 0.0) {
			[rhs output:state gradients:g time:t];
		}
	}

	while (ok && (t0 < t1)) {
		double target = (t0 + D > t1) ? t1 : t0 + D;
		int swi = -1;
		t = [self istep:sws newSwitches:newsws state:state newState:news gradients:g newGradients:newg error:err 
				  startTime:t0 stopTime:target flickedSwitch:&swi success:&success];
		if (!success) {
			ok = NO;
			break;
		}

		/* largest error relative to tolerance (0 when error control is off) */
		double errmax = 0.0;
		if ((!fixstep) && (D > mindt)) {
			for (i = 0; i < ns; i++) {
				e0[i] = eps*(fabs(state[i]) + fabs((t - t0)*(g[i] + newg[i])*0.5) + scale[i]);
			}
			for (i = 0; i < ns; i++) {
				double rerr;
				if ((err[i] < 1e-150) && (err[i] > -1e-150)) {
					rerr = 0.0;
				} else {
					rerr = err[i]/e0[i];
				}
				rerr = fabs(rerr);
				if (rerr > errmax) {
					errmax = rerr;
				}
			}
		}

		if (errmax < 1.0) {
			/* step accepted: the new buffers become current */
			double *dum;
			_accepted++;
			dum = state; state = news; news = dum;
			dum = sws; sws = newsws; newsws = dum;
			dum = g; g = newg; newg = dum;
			[history updateHistory:g state:state time:t];
			if (dout < 0.0) {
				[rhs output:state gradients:g time:t];
			}
			if (dout > 0.0) {
				if (t > ti + iout*dout) {
					[rhs output:state gradients:g time:t];
					while (ti + iout*dout <= t) {
						iout++;
					}
				}
			}
			t0 = t;
			if (swi > -1) {
				/* a switch flicked at t: apply its map and restart from the mapped state */
				if (dout > 0.0) {
					[rhs output:state gradients:g time:t];
				}
				[rhs map:state time:t switchNo:swi];
				if (dout > 0.0) {
					[rhs output:state gradients:g time:t];
				}
				success = [rhs gradients:g forState:state atTime:t];
				if (!success) {
					ok = NO;
					break;
				}
				[history updateHistory:g state:state time:t];
			} else {
				if ((!fixstep) && (t < t1)) {
					D = (errmax > ERRCON) ? 0.95*D*pow(errmax, -0.2) : 5.0*D;
				}
				if (D > maxdt) {
					D = maxdt;
				}
			}
		} else {
			/* step rejected: shrink D relative to the step actually attempted */
			double Da = t - t0;
			_rejected++;
			D = 0.95*Da*pow(errmax, -0.25);
			D = (D < 0.1*Da) ? 0.1*Da : D;
			if (D < 1e-14*dt) {
				printf("Step size failure\n");
			}
			t = t0;
		}
		dt = D;
	}

	free(state);
	free(news);
	free(sws);
	free(newsws);
	free(g);
	free(newg);
	free(err);
	free(e0);
	free(scale);

	return ok;
}

/*
 * Appends a history record for time t.
 *
 * rhs fills _his/_ghis (via storeHistory:...) with the history variables and
 * their gradients. The record is then written to the next ring-buffer slot,
 * wrapping after _size records and overwriting the oldest one.
 *
 * Does nothing if no history buffers were allocated.
 */
- (void) updateHistory:(double *)gradients state:(double *)state time:(double) t
{
	if (_clock == NULL) {
		return;
	}

	if (_first) {
		_firstTime = t;
		_first = NO;
	}

	id rhs = [self rhs];

	[rhs storeHistory:_his gradientsHistory:_ghis gradients:gradients state:state time:t];
	_lastTime = t;
	_offset++;
	if (_offset == _size) {
		_offset = 0L;
	}
	_clock[_offset] = t;

	long j;
	for (j = 0; j < _no; j++) {
		_buff[j][_offset] = _his[j];
		_gbuff[j][_offset] = _ghis[j];
	}
}

/*
 * Returns history variable i at past time t.
 *
 * Uses cubic Hermite interpolation between the two stored records that
 * bracket t, using their values and gradients.
 *
 * markno selects a per-(variable, lag) search marker that remembers where the
 * previous lookup ended, so the search starts near the last answer.
 *
 * Calls error(), which exits, if nothing has been stored yet or if t precedes
 * everything still held in the ring buffer.
 */
- (double) pastValue:(int)i time:(double)t mark:(int)markno
{
	long k1;
	long k;
	long offset;
	long offsetplus;
	long size;
	double res;
	double *y;
	double *g;
	double *x;

	if (_offset < 0L) {
		/* no record yet: _clock[_offset] would read before the buffer */
		error("No history has been stored yet");
	}

	y = _buff[i];     /* stored values of history variable i */
	g = _gbuff[i];    /* stored gradients of history variable i */
	x = _clock;       /* stored record times */
	offset = _offset; /* slot of the newest record */
	size = _size;
	if (x[offset] == t) {
		return y[offset];
	}

	/* slot after the newest; holds the oldest record once the buffer has wrapped */
	offsetplus = offset + 1L;
	if (offsetplus == size) {
		offsetplus = 0L;
	}

	/* resume from where the previous lookup with this marker ended */
	k = _lagmarker[i][markno];
	k1 = k + 1L;
	if ((k1 >= size) || (k1 < 0L)) {
		k1 = 0L;
	}

	/* walk forward to the first record at or after t (stopping at the newest) */
	while ((x[k1] < t) && (k1 != offset)) {
		k1++;
		if (k1 == size) {
			k1 = 0L;
		}
	}
	k = (k1 == 0L) ? size - 1L : k1 - 1L;

	/* walk back to the last record at or before t (stopping at offsetplus) */
	while ((x[k] > t) && (k != offsetplus)) {
		k = (k == 0L) ? size - 1L : k - 1L;
	}

	k1 = k + 1L;
	if (k1 == size) {
		k1 = 0L;
	}
	if (t < x[k]) {
		error("Lag too large for history buffer");
	}

	double x0 = x[k];
	double x1 = x[k1];
	if ((t > x1) && (x0 == x1)) {
		/* repeated time stamps (e.g. the two records written at a switch): extrapolate linearly */
		res = y[k1] + (t - x1)*g[k1];
	} else {
		/* cubic Hermite interpolation on [x0,x1] */
		double xx0 = t - x0;
		double xx1 = t - x1;
		double xx12 = xx1*xx1;
		double xx02 = xx0*xx0;
		double h = x1 - x0;
		if (h != 0.0) {
			res = ((g[k]*xx0*xx12 + g[k1]*xx1*xx02 +
				   (y[k]*(2.0*xx0 + h)*xx12 - y[k1]*(2.0*xx1 - h)*xx02)/h)/(h*h));
		} else {
			res = y[k];
		}
	}

	_lagmarker[i][markno] = k;
	return res;
}

/* Frees every buffer allocated in -initWithRhs:. It tolerates a partially
   allocated history, which can happen when the initialiser bailed out. */
- (void) dealloc
{
	long i;

	free(_flicked);
	free(_err1);
	free(_s1);
	free(_s2);
	free(_sw1);
	free(_sw2);

	free(_k2);
	free(_k3);
	free(_k4);

	/* _no is 0 when no history was requested */
	for (i = 0; i < _no; i++) {
		if (_lagmarker != NULL) {
			free(_lagmarker[i]);
		}
		if (_buff != NULL) {
			free(_buff[i]);
		}
		if (_gbuff != NULL) {
			free(_gbuff[i]);
		}
	}
	free(_lagmarker);
	free(_buff);
	free(_gbuff);
	free(_clock);
	free(_his);
	free(_ghis);
	[_rhs release];

	[super dealloc];
}

@end