
#include "mex.h"
// calculate Integral Image on Square of Difference between two images
void calc_II_on_SoD(double *a, double *b, int win_h, int win_w, int a_x_l, int a_y_l, int b_x_l, int b_y_l,
					int aw, int ah, int bw, int bh, int C,
					/*output*/ double *int_win_pad);
void calc_min_patch_dist(double *int_win_pad, int patch_w, int win_h, int win_w, int a_x_l, int a_y_l, int b_x_l, 
						 int b_y_l, int ah, int bh,
						 /*output*/ int *nnf_XY, double *nnf_dist);

void calc_NN( double* a, double* b, int aw, int bw, int ah, int bh,double *int_win_pad, int patch_w, int num_of_colors,
			 /*output*/ int *nnf_XY, double *nnf_dist);

void UpdateSIFTpack(double* a,/*output*/double* b, int aw,int bw,int ah,int bh,int patch_w,int num_of_colors, int* nnf_XY,int* counts);



/*******************************************************************************/
/* mexFUNCTION - gateway routine for use with MATLAB                           */
/* Calculates SIFTpack of an image. Inputs: Dense SIFTpack of an image (A), initial SIFTpack (B), number of iterations (nIter) . Output: SIFTpack, asignments        */
/*******************************************************************************/
void mexFunction(int nout, mxArray *pout[], int nin, const mxArray *pin[]) {

	if (nin < 3) { mexErrMsgTxt("Mex error: nnf_gt called with < 3 input arguments!"); }

	const mxArray *A = pin[0];	
	const mxArray *B = pin[1];
	int nIter = int(mxGetScalar(pin[2]));	

	if ( !mxIsDouble(A) || !mxIsDouble(B) ){ 
		mexErrMsgTxt("Mex error: Input image A or B are not double!");
	}

	int ah = mxGetDimensions(A)[0];
	int aw = mxGetDimensions(A)[1];
	int bh = mxGetDimensions(B)[0];
	int bw = mxGetDimensions(B)[1];
	int patch_w = 4;
	int num_of_colors = 8;
	double *a =(double*) mxGetData(A);
	double *b = (double*)mxGetData(B);

	int* nnf_XY = new int [aw*ah];
	double* nnf_dist = new double [aw*ah];
	double* int_win_pad = new double [(aw+1)*(ah+1)];
	int* counts = new int [bw*bh];
	
	for (int i=0; i<nIter;++i){
		calc_NN( a, b, aw, bw, ah, bh, int_win_pad, patch_w, num_of_colors, nnf_XY, nnf_dist);
		UpdateSIFTpack(a, b, aw, bw, ah, bh, patch_w, num_of_colors, nnf_XY,counts);
	}
	mwSize dims[3] = { bh, bw, num_of_colors };
	mxArray *SIFTpack = mxCreateNumericArray(3, dims, mxDOUBLE_CLASS, mxREAL);
	double *data = (double *) mxGetData(SIFTpack);

	for (int i=0;i<num_of_colors*bw*bh;++i){
		data[i] = b[i];
	}


	pout[0] = SIFTpack;
	if (nout>1){
		mwSize dims[3] = { ah, aw, 2 };
		mxArray *nnf = mxCreateNumericArray(3, dims, mxDOUBLE_CLASS, mxREAL);
		double *data = (double *) mxGetData(nnf);
		double *xchan = &data[0];
		double *ychan = &data[aw*ah];
		for (int x = 0; x < aw; x++) {

			int *nnf_XY_col   =  &nnf_XY[x*ah];
			for (int y = 0; y < ah; y++) {
				int map = nnf_XY_col[y];
				xchan[y+x*ah] = double(map/bh);
				ychan[y+x*ah] = double(map%bh);

			}
		}
		pout[1] = nnf;
	}




} // end mexFunction

	// Exhaustive nearest-neighbour search between the patch_w x patch_w patches of a and b.
	//
	// Every relative shift of b against a is visited once. For each shift the
	// overlapping window is handed to calc_II_on_SoD (integral image of the squared
	// difference summed over num_of_colors layers) and then to calc_min_patch_dist,
	// which lowers nnf_dist / nnf_XY wherever this shift gives a smaller distance.
	//
	//   a, b         : column-major images, aw x ah and bw x bh pixels per layer
	//   int_win_pad  : scratch of (aw+1)*(ah+1) doubles, zeroed here
	//   nnf_dist     : output, aw*ah doubles, reset to 100000000 before the search
	//   nnf_XY       : output, aw*ah ints; entries never improved upon are left untouched
	//
	// NOTE(review): a position whose best distance is >= 100000000 keeps whatever
	// nnf_XY held on entry - confirm that the data range rules this out.
	void calc_NN( double* a, double* b, int aw, int bw, int ah, int bh,double *int_win_pad, int patch_w, int num_of_colors,
		/*output*/ int *nnf_XY, double *nnf_dist) {

		const int pad_len  = (aw+1)*(ah+1);
		const int a_pixels = aw*ah;
		const int shifts_y = ah+bh-2*patch_w+1;
		const int shifts_x = aw+bw-2*patch_w+1;

		for (int k = 0; k < pad_len; ++k) {
			int_win_pad[k] = 0;
		}
		for (int k = 0; k < a_pixels; ++k) {
			nnf_dist[k] = 100000000;
		}

		// Vertical extent of the overlap: rows [a_top, a_bottom] of a against rows
		// starting at b_top of b. It starts as a's first patch_w rows over b's last
		// patch_w rows. The far edge in b is implied by the window height, so it is
		// not tracked.
		int a_top    = 0;
		int a_bottom = patch_w-1;
		int b_top    = bh-patch_w;

		for (int sy = 0; sy < shifts_y; ++sy) {
			const int win_h = a_bottom - a_top + 1;

			// Horizontal extent, restarted for every vertical shift.
			int a_left  = 0;
			int a_right = patch_w-1;
			int b_left  = bw-patch_w;

			for (int sx = 0; sx < shifts_x; ++sx) {
				const int win_w = a_right - a_left + 1;

				// integral image of the squared difference over the overlap
				calc_II_on_SoD(a, b, win_h, win_w, a_left, a_top, b_left, b_top, aw, ah, bw, bh, num_of_colors, int_win_pad);

				// patch distances for this shift, keeping the per-position minimum
				calc_min_patch_dist(int_win_pad, patch_w, win_h, win_w, a_left, a_top, b_left, b_top, ah, bh, nnf_XY, nnf_dist);

				// Slide one column. Each test looks at the value from before this
				// step: b's start moves toward 0 first, and only once it sits at 0
				// does a's start advance.
				if (b_left == 0) {
					++a_left;
				} else if (b_left > 0) {
					--b_left;
				}
				// a's end grows until it reaches the last column.
				if (a_right+1 <= aw-1) {
					++a_right;
				}
			}

			// Same sliding rule along the rows.
			if (b_top == 0) {
				++a_top;
			} else if (b_top > 0) {
				--b_top;
			}
			if (a_bottom+1 <= ah-1) {
				++a_bottom;
			}
		}
	}


// Fills int_win_pad with the integral image of the squared difference between a
// win_w x win_h window of a (top-left pixel a_x_l, a_y_l) and the same-sized
// window of b (top-left pixel b_x_l, b_y_l), summed over C layers.
//
// Both images are column-major with layer strides aw*ah and bw*bh. The integral
// image uses a row stride of ah+1; window pixel (x, y) is stored at
// int_win_pad[(x+1)*(ah+1) + (y+1)]. Row 0 and column 0 are only read, never
// written, so the caller must have zeroed them.
void calc_II_on_SoD(double *a, double *b, int win_h, int win_w, int a_x_l, int a_y_l, int b_x_l, int b_y_l,
					int aw, int ah, int bw, int bh, int C,
					/*output*/ double *int_win_pad){
	const int a_plane    = aw*ah;
	const int b_plane    = bw*bh;
	const int pad_stride = ah+1;

	for (int col = 0; col < win_w; ++col) {
		// First pixel of this window column in layer 0 of each image.
		const double *a_px = a + (a_x_l+col)*ah + a_y_l;
		const double *b_px = b + (b_x_l+col)*bh + b_y_l;

		// Integral-image column already completed (left) and the one being filled.
		const double *pad_left = int_win_pad + col*pad_stride + 1;
		double       *pad_here = int_win_pad + (col+1)*pad_stride + 1;

		// Running sum of squared differences down the current column.
		double column_sum = 0;
		for (int row = 0; row < win_h; ++row) {
			double ssd = 0;
			for (int layer = 0; layer < C; ++layer) {
				const double d = a_px[row + layer*a_plane] - b_px[row + layer*b_plane];
				ssd += d*d;
			}
			column_sum = column_sum + ssd;
			pad_here[row] = column_sum + pad_left[row];
		}
	}
}


// Reads every patch_w x patch_w box sum from the integral image int_win_pad
// (row stride ah+1) of a win_w x win_h window and keeps the smallest distance
// seen so far for each patch position of a.
//
// Window patch (x, y) corresponds to position (a_x_l+x, a_y_l+y) in a and
// (b_x_l+x, b_y_l+y) in b. When its box sum is strictly smaller than
// nnf_dist[(a_x_l+x)*ah + a_y_l+y], that entry is replaced and the matching
// nnf_XY entry is set to the linear index (b_y_l+y) + (b_x_l+x)*bh.
void calc_min_patch_dist(double *int_win_pad, int patch_w, int win_h, int win_w, int a_x_l, int a_y_l, int b_x_l, 
						 int b_y_l, int ah, int bh,
						 /*output*/ int *nnf_XY, double *nnf_dist){
	const int pad_stride = ah+1;
	const int n_cols = win_w-patch_w+1;
	const int n_rows = win_h-patch_w+1;

	for (int col = 0; col < n_cols; ++col) {
		// Integral-image columns bounding the patch on its left and right side.
		const double *ii_left  = int_win_pad + col*pad_stride;
		const double *ii_right = int_win_pad + (col+patch_w)*pad_stride;

		// Linear index of the first patch of this column in a and in b.
		const int a_base = (a_x_l+col)*ah + a_y_l;
		const int b_base = (b_x_l+col)*bh + b_y_l;

		for (int row = 0; row < n_rows; ++row) {
			// Box sum from the four corners of the integral image.
			const double box = ii_left[row] - ii_left[row+patch_w]
							 - ii_right[row] + ii_right[row+patch_w];

			// Ties keep the earlier match.
			if (!(box < nnf_dist[a_base+row])) {
				continue;
			}
			nnf_dist[a_base+row] = box;
			nnf_XY[a_base+row]   = b_base + row;
		}
	}
}

// Rebuilds b from a using the current assignments.
//
// For every patch position (i, j) of a, nnf_XY[i*ah + j] is read as the linear
// index row + column*bh of the matched patch in b. The patch_w x patch_w pixels
// of that a-patch are added, layer by layer, onto the matched pixels of b, and
// counts records how many patches landed on each pixel of b. Every pixel of b
// with a non-zero count is then overwritten by the mean of what landed on it;
// pixels with a zero count keep their previous value.
//
//   a       : aw x ah x num_of_colors, column-major, read only
//   b       : bw x bh x num_of_colors, column-major, updated in place
//   counts  : bw*bh ints, reset and filled here
void UpdateSIFTpack(double* a,/*output*/double* b, int aw,int bw,int ah,int bh,int patch_w,int num_of_colors, int* nnf_XY, int* counts){
	const int b_plane = bw*bh;
	const int a_plane = aw*ah;
	const int b_total = bw*bh*num_of_colors;

	double* sums = new double [b_total];

	for (int k = 0; k < b_plane; ++k) {
		counts[k] = 0;
	}
	for (int k = 0; k < b_total; ++k) {
		sums[k] = 0;
	}

	const int n_cols = aw-patch_w+1;
	const int n_rows = ah-patch_w+1;

	for (int i = 0; i < n_cols; ++i) {
		for (int j = 0; j < n_rows; ++j) {
			const int target = nnf_XY[i*ah + j];
			const int bx = target/bh;
			const int by = target%bh;

			// One vote per covered pixel of b, independent of the layer count.
			for (int x = 0; x < patch_w; ++x) {
				for (int y = 0; y < patch_w; ++y) {
					++counts[(bx+x)*bh + by + y];
				}
			}

			// Accumulate the patch of a onto b, one layer at a time.
			for (int layer = 0; layer < num_of_colors; ++layer) {
				const double *src = a    + layer*a_plane;
				double       *dst = sums + layer*b_plane;
				for (int x = 0; x < patch_w; ++x) {
					for (int y = 0; y < patch_w; ++y) {
						dst[(bx+x)*bh + by + y] += src[(i+x)*ah + j + y];
					}
				}
			}
		}
	}

	// Turn the sums into means; pixels nobody voted for are left alone.
	for (int px = 0; px < b_plane; ++px) {
		const int votes = counts[px];
		if (!votes) {
			continue;
		}
		for (int layer = 0; layer < num_of_colors; ++layer) {
			b[layer*b_plane + px] = sums[layer*b_plane + px]/votes;
		}
	}

	delete[] sums;
}