#include <cassert>
#include <climits>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <iostream>

#include <conio.h>

#include "aritcode.h"
#include "image.h"

// Text-mode screen geometry. Each cell occupies two bytes in `image`
// (character byte followed by attribute byte, as main() indexes it).
const int IMAGE_WIDTH  = 80;
const int IMAGE_HEIGHT = 25;
const int IMAGE_SIZE   = IMAGE_HEIGHT * IMAGE_WIDTH;	// number of cells
const int IMAGE_BYTES  = 2 * IMAGE_SIZE;				// raw image size in bytes
const int MAX_MODELS   = 4;							// context masks 0 .. MAX_MODELS-1

// Scratch output buffer handed to the arithmetic coder (64 KiB).
static unsigned char tmp[1 << 16];

// Symbol counts per model and per coded bit: [model][cell*16 + bit][0 or 1].
static short modeldata[MAX_MODELS][16 * IMAGE_SIZE][2];

// Cell offsets that a model mask can add to its context: bit 0 selects the
// previous cell (-1), bit 1 the cell one row up (-IMAGE_WIDTH == -80).
// 10000 is not a real offset; it is padding so that a mask walk which reads one
// entry past the last offset it uses stays inside the array.
int models[] = {-1, -IMAGE_WIDTH, 10000};

// Returns true when earlier cell j shares cell i's context under model `mask`
// while bit b-1 of cell i is being coded:
//  - bits b..15 of the two cells (already coded for cell i) must agree, and
//  - for every bit k set in `mask`, the cells at offset models[k] from i and
//    from j must be identical.
// An offset that reaches before the start of the image (j + offset < 0) never
// matches. Because j < i, checking j alone also keeps i + offset in range.
static bool context_matches(const unsigned short* ptr, int i, int j, int mask, int b) {
	// The higher, already-coded bits of the current cell are part of every context.
	if((ptr[i] ^ ptr[j]) >> b)
		return false;

	// Walk the mask LSB first; bit k selects models[k]. Stop once no bits remain.
	for(int k = 0; (mask >> k) != 0; k++) {
		if(((mask >> k) & 1) == 0)
			continue;
		const int offset = models[k];
		if(j + offset < 0)
			return false;
		if(ptr[i + offset] != ptr[j + offset])
			return false;
	}
	return true;
}

// Fills modeldata[m][i*16 + bit] with {count of 0s, count of 1s} for every
// model mask m, cell i and bit position (15..0).
// Counts come from all earlier cells j < i whose context matches (see
// context_matches). The bit that was seen is incremented and the opposite
// count is halved (rounded up), so recent evidence dominates.
// Cost is O(MAX_MODELS * IMAGE_SIZE^2 * 16) context comparisons.
//
// ptr: IMAGE_SIZE cells, one 16-bit value per cell.
void calculate_modeldata(const unsigned short* ptr) {
	for(int m = 0; m < MAX_MODELS; m++) {
		for(int i = 0; i < IMAGE_SIZE; i++) {			//for each cell
			for(int b = 16; b > 0; b--) {				//for each bit, MSB first
				int n[2] = {0, 0};

				// Only strictly earlier cells may contribute. (The previous do/while
				// also visited j == 0 when i == 0, matching cell 0 against itself.)
				for(int j = 0; j < i; j++) {
					if(!context_matches(ptr, i, j, m, b))
						continue;
					const int pred_bit = (ptr[j] >> (b-1)) & 1;
					n[pred_bit]++;
					n[!pred_bit] = (n[!pred_bit] + 1) / 2;	// penalty: halve the other count
				}

				modeldata[m][i*16 + b-1][0] = static_cast<short>(n[0]);
				modeldata[m][i*16 + b-1][1] = static_cast<short>(n[1]);
			}
		}
	}
}

// Arithmetic-codes cells 1..IMAGE_SIZE-1 of `ptr` into `dst`, 16 bits per cell,
// MSB first. Requires calculate_modeldata() to have been run on the same cells.
//
// For each bit, the probability weights start at {baseprob0, baseprob1}. Each
// model then adds its counts. A model that has seen only one symbol value
// (one count is zero) gets its counts multiplied by boost_factor.
//
// Returns AritCodeEnd()'s result (callers treat it as a size in bits), or
// INT_MAX if AritCode() reports a failure.
//
// NOTE(review): cell 0 is never coded; presumably the decoder assumes it — confirm.
int compress(const unsigned short* ptr, unsigned char* dst, int baseprob0, int baseprob1, int boost_factor) {
	AritState as;
	AritCodeInit(&as, dst);

	for(int i = 1; i < IMAGE_SIZE; i++) {
		for(int b = 16; b > 0; b--) {
			const int bit = b - 1;

			// Explicit casts: brace-initialising unsigned short from int is a
			// narrowing conversion, ill-formed since C++11.
			unsigned short nt[2] = {
				static_cast<unsigned short>(baseprob0),
				static_cast<unsigned short>(baseprob1)
			};

			for(int m = 0; m < MAX_MODELS; m++) {
				const short* counts = modeldata[m][i*16 + bit];
				const int boost = (counts[0] && counts[1]) ? 1 : boost_factor;
				// NOTE(review): the sums wrap modulo 2^16 once they exceed 65535.
				// Counts can reach the thousands and boost_factor goes up to 24, so
				// this can happen. The wraparound is kept exactly as before because
				// the decoder must rebuild identical weights. Confirm how the decoder
				// accumulates before widening or rescaling these sums.
				nt[0] = static_cast<unsigned short>(nt[0] + counts[0] * boost);
				nt[1] = static_cast<unsigned short>(nt[1] + counts[1] * boost);
			}

			if(AritCode(&as, nt[0], nt[1], (ptr[i] >> bit) & 1))
				return INT_MAX;	//compression failed
		}
	}
	return AritCodeEnd(&as);
}

// Exhaustively tries every (bp0, bp1, boost_factor) combination in the ranges
// below. Each trial calls compress() into the scratch buffer `tmp`, and every new
// best result is printed.
//
// On return, best_bp0 / best_bp1 / best_boost_factor hold the best combination
// found. If every trial fails (compress() returns INT_MAX), they hold the first
// combination tried, so they are never left indeterminate.
//
// Returns the best size (as returned by compress()), or INT_MAX if every trial
// failed.
static int optimize_parameters(const unsigned short* image, int &best_bp0, int &best_bp1, int& best_boost_factor) {
	// Half-open search ranges; the base probabilities are currently pinned to 3.
	const int BP0_BEGIN = 3, BP0_END = 4;
	const int BP1_BEGIN = 3, BP1_END = 4;
	const int BOOST_BEGIN = 5, BOOST_END = 25;

	// Seed the outputs so callers never read uninitialised values.
	best_bp0 = BP0_BEGIN;
	best_bp1 = BP1_BEGIN;
	best_boost_factor = BOOST_BEGIN;
	int best_size = INT_MAX;

	for(int bp0 = BP0_BEGIN; bp0 < BP0_END; bp0++) {
		for(int bp1 = BP1_BEGIN; bp1 < BP1_END; bp1++) {
			for(int boost_factor = BOOST_BEGIN; boost_factor < BOOST_END; boost_factor++) {
				const int size = compress(image, tmp, bp0, bp1, boost_factor);
				if(size < best_size) {
					printf("bp0: %d bp1: %d boost_factor: %d  size: %.2f bytes\n", bp0, bp1, boost_factor, size / 8.0f);
					best_size = size;
					best_bp0 = bp0;
					best_bp1 = bp1;
					best_boost_factor = boost_factor;
				}
			}
		}
	}
	return best_size;
}

int main(int argc, const char* argv[]) {
	//dirty, heuristic and naive transform to improve compression
	int old_color = -1;
	for(int i = 0; i < IMAGE_SIZE; i++) {
		if(image[i*2] == 0x20) {
			image[i*2] = 0;
			image[i*2+1] &= 0xF0;		//clear foreground color
			if(image[i*2+1]) {
				image[i*2] = 0xDB;
				image[i*2+1]>>=4;
			}
		} else if(image[i*2] == 0xDB) {
			image[i*2+1] &= 0x0F;		//clear background color
		}
		
		if((image[i*2+1]>>4) > (image[i*2+1]&0x0F)) {
			if(image[i*2] == 0xDC) {
				image[i*2] = 0xDF;
				image[i*2+1] = (image[i*2+1]<<4)|(image[i*2+1]>>4);
			}
		}
		
		if(image[i*2] == 0x00 || image[i*2] == 0xDB) {
			if(old_color != 0) {
				if((image[i*2+1]<<4) == old_color) {
					image[i*2] = 0x00;
					image[i*2+1] <<= 4; 
				} else if((image[i*2+1]>>4) == old_color) {
					image[i*2] = 0xDB;
					image[i*2+1] >>= 4;
				}
			}
		}

		old_color = image[i*2+1];
	}

	//write file to disk
	std::ofstream imgfile("image_stripped", std::ios::binary);
	imgfile.write((char*)image, IMAGE_WIDTH*IMAGE_HEIGHT*2);
	imgfile.close();

	calculate_modeldata((unsigned short*)image);

	int best_bp0, best_bp1, best_boost_factor;
	optimize_parameters((unsigned short*)image, best_bp0, best_bp1, best_boost_factor);	

	memset(tmp, 0, sizeof(tmp));
	int size = compress((unsigned short*)image, tmp, best_bp0, best_bp1, best_boost_factor);
	printf("bp0: %d bp1: %d boost_factor: %d  size: %.2f bytes\n", best_bp0, best_bp1, best_boost_factor, size / 8.0f);
	printf("compressed size: %f\n", size / 8.0f);
	
	std::ofstream outfile("compressed", std::ios::binary);
	outfile.write((char*)tmp, (size+7)/8);
	outfile.close();
	
	return 0;
}