#include "SM_Engine3DPCH.h"
#include "SM_KeyFrameSequence.h"
#include "SM_D3DMesh.h"
#include "SM_SpikeTube.h"
#include "SM_Shader.h"
#include "SM_Timer.h"
#include "SM_FVF.h"

using namespace ShaderManager;

typedef Vector3D vec3_t;

// Number of elements in a true array (must not be given a pointer).
// The argument is parenthesised so that an argument expression such as a cast
// is indexed as a whole rather than binding [0] to its last operand.
#define LENGTHOF(x) (sizeof(x)/sizeof((x)[0]))

// Out-parameter style wrappers over Vector3D's own cross product / normalisation.
inline void Cross(vec3_t& o, const vec3_t& a, const vec3_t& b) { o = Vector3D::Cross(a,b); }
inline void Normalize(vec3_t& v) { v.Normalize(); }

int SM_SpikeTube::Init(int iShader, IKFAnimable* pAnimable)
{
	// Remember the shader/animation sources and the creation time;
	// Render() animates relative to m_fStartTime.
	m_shader     = iShader;
	m_pAnimable  = pAnimable;
	m_fStartTime = Timer::GetTime();

	// Mesh uses nvertex (stride sizeof(nvertex)) and is drawn as a triangle list.
	m_mesh.Init(sizeof(nvertex), nvertex_fvf, false);
	m_mesh.PrimType() = D3DPT_TRIANGLELIST;

	// Initial geometry with no render context at t = 0, then the index buffer.
	UpdateMesh(0, 0.0f);
	BuildIndices2();

	// Initialisation cannot fail here.
	return 0;
}


void SM_SpikeTube::Render(RenderContext* prc, int iOutcode, float time)
{
	// Regenerate the geometry for the time elapsed since Init().
	UpdateMesh(prc, time - m_fStartTime);

	Shader* pShader = GetShader(m_shader);
	if (!pShader)
	{
		// No shader: draw once with whatever render state is current.
		m_mesh.Render();
		return;
	}

	// One draw per shader pass, with that pass's state applied first.
	for (int pass = 0; pass < pShader->m_uPasses; ++pass)
	{
		pShader->SetShaderState(pass);
		m_mesh.Render();
	}
}


void LookAt(const vec3_t& pos, const vec3_t& at, const vec3_t& up)
{
	Matrix4X4 mat;
	mat.LookAt(pos, at, up);
	SM_D3d::Device()->SetTransform(D3DTS_VIEW, (D3DMATRIX*)&mat);
}


// Rebuilds all vertices for the given time, refreshes the index buffer if the
// segment count changed, and recomputes smooth per-vertex normals.
//   prc  - render context; may be null (as from Init), in which case the tube is
//          built around x = 0 and the segment count is left unchanged.
//   time - animation time in seconds since Init (scaled by 1.5 below).
void SM_SpikeTube::UpdateMesh(RenderContext* prc, float time)
{
	m_mesh.LockVerts();

	// The camera's x position decides where along the axis the tube is built.
	m_camX = 0;
	if (prc)
		m_camX = prc->GetViewport()->m_v3dPosition.x;

	// Rebuild the skeleton/cell vertices. BuildSkeleton may change m_numSegments
	// (it is derived from the view when prc is non-null) and the index buffer
	// depends on it, so indices are rebuilt only when the count changed.
	const int oldNumSegs = m_numSegments;
	time *= 1.5f;	// animation runs at 1.5x the elapsed time
	BuildSkeleton(prc, time);

	if (m_numSegments != oldNumSegs)
		BuildIndices2();

	// Smooth normals: add each triangle's unnormalised face normal to its three
	// vertices (BuildSkeleton's helpers add vertices with zero normals), then
	// normalise. Only whole triangles are visited, so an index count that is not
	// a multiple of 3 cannot cause a read past the end of the index buffer.
	for (int i = 0; i + 2 < m_mesh.NumIndices(); i += 3)
	{
		const WORD i0 = m_mesh.GetIndex(i);
		const WORD i1 = m_mesh.GetIndex(i + 1);
		const WORD i2 = m_mesh.GetIndex(i + 2);
		MakeAddNormal(i0, i1, i2);
	}

	// NOTE(review): the vertex buffer stride is sizeof(nvertex) (see Init); only
	// .p/.n are touched through the ntxvertex view, which assumes both types share
	// that prefix — confirm.
	for (int n = 0; n < m_mesh.NumVerts(); n++)
	{
		ntxvertex& v = *(ntxvertex*)m_mesh.GetVert(n);
		Normalize(v.n);
	}

#ifdef SHOW_NORMALS
	// Debug: record a short segment from each vertex along its normal.
	// Uses its own loop variable so it also compiles under legacy MSVC for-scope rules.
	m_normals.clear();
	for (int dv = 0; dv < m_mesh.NumVerts(); dv++)
	{
		ntxvertex& v = *(ntxvertex*)m_mesh.GetVert(dv);
		m_normals.push_back(cvertex(v.p, 0xff0000));
		m_normals.push_back(cvertex(v.p + v.n * 2.0f, 0xffffffff));
	}
#endif

	m_mesh.UnlockVerts();
}


// Rebuilds the entire index buffer (triangle list) for the current m_numSegments.
//
// Expected vertex layout, in the order BuildSkeleton emits it:
//   for each segment: segRingSize skeleton-ring verts (AddSegRing), then for each
//   cell: (m_numPerSide-2) column verts followed by (m_numRings-1)*m_res
//   cell-ring verts (AddColumnsForRing / AddCell2);
//   then one trailing skeleton ring after the last segment.
// The first ASSERT checks that total against the mesh's vertex count.
void SM_SpikeTube::BuildIndices2()
{
	m_mesh.ResetIndices();

	// Verts in one skeleton ring: m_numPerSide per cell, minus the edge vertex
	// each cell shares with its neighbour.
	const int segRingSize = m_numCells*m_numPerSide-m_numCells;
	int numPerCell = m_numPerSide-2 + (m_numRings-1)*m_res; // column + rings
//		int numPerRing = segRingSize + m_numCells*numPerCell;
	int numPerSeg = m_numCells*numPerCell + segRingSize;

	ASSERT(m_numSegments*numPerSeg+segRingSize == m_mesh.NumVerts());

	// Boundary loop of one cell: top edge, right column, bottom edge, left column.
	// NOTE(review): the loop has 4*m_numPerSide-4 entries, but the buffer has m_res
	// slots and is filled before ASSERT(k <= m_res) runs, so a configuration with
	// 4*m_numPerSide-4 > m_res writes out of bounds. The stitch loop below also
	// assumes the loop length equals m_res exactly. Entries are WORD (16-bit).
	std::vector<WORD> outerRing(m_res);
	int k=0;
	
	for (int seg=0; seg<m_numSegments; seg++)
	{
		const int ofs = GetOfs(seg);
		for (int c=0; c<m_numCells; c++)
		{
			// Start of this cell's top edge within skeleton ring 'seg'.
			int o = (ofs + c*(m_numPerSide-1))%segRingSize;

			// build outer ring
			k=0;

			// top: m_numPerSide consecutive verts of skeleton ring 'seg' (wrapping)
			int i;
			for (i=0; i<m_numPerSide; i++)
				outerRing[k++] = seg * numPerSeg + (o+i)%segRingSize;
			
			// right side: the column of the next cell (cell index wraps to 0)
			int tro = seg * numPerSeg + segRingSize + ((c+1)%m_numCells)*numPerCell;
			for (i=0; i<m_numPerSide-2; i++)
				outerRing[k++] = tro+i;

			// bottom: matching verts of skeleton ring seg+1, using the same offset
			// as the top. They are written in reverse so the loop keeps circulating
			// in one direction.
			int tbo = (ofs + c*(m_numPerSide-1))%segRingSize;
			for (i=0; i<m_numPerSide; i++)
				outerRing[k+m_numPerSide-i-1] = (seg+1) * numPerSeg + (tbo+i)%segRingSize;
			k+=m_numPerSide;

			// left side: this cell's own column, also reversed
			int tlo = seg * numPerSeg + segRingSize + c*numPerCell;
			for (i=0; i<m_numPerSide-2; i++)
				outerRing[k+(m_numPerSide-2)-i-1] = tlo+i;
			k+=m_numPerSide-2;

			ASSERT(k <= m_res);

#if 0
			for (int t=0; t<m_res; t++)
			{
				char tc[128];
				sprintf(tc, "%d, ", outerRing[t]);
				OutputDebugString(tc);
			}
#endif

			// First vertex of this cell's first inner ring (immediately after its column).
			int v0 = seg * numPerSeg + segRingSize + c*numPerCell + m_numPerSide-2;

			// Stitch the boundary loop to the first inner ring: each pair of
			// neighbouring points on both loops forms a quad, emitted as two triangles.
			for (int x=0; x<m_res; x++)
			{
				int i0 = outerRing[x];
				int i1 = outerRing[(x+1)%m_res];
				int i2 = v0 + x;
				int i3 = v0 + (x+1)%m_res;

				ASSERT(i0 < m_mesh.NumVerts());
				ASSERT(i1 < m_mesh.NumVerts());
				ASSERT(i2 < m_mesh.NumVerts());
				ASSERT(i3 < m_mesh.NumVerts());

				m_mesh.AddIndex(i0);
				m_mesh.AddIndex(i3);
				m_mesh.AddIndex(i1);

				m_mesh.AddIndex(i0);
				m_mesh.AddIndex(i2);
				m_mesh.AddIndex(i3);
			}

			// Join each inner ring y to ring y+1 (m_numRings-1 inner rings give
			// m_numRings-2 bands), using the same two-triangle winding as above.
			for (int y=0; y<m_numRings-2; y++)
			{
				for (int x=0; x<m_res; x++)
				{
					const int i0 = v0 + y*m_res + x;
					const int i1 = v0 + y*m_res + (x+1)%m_res;
					const int i2 = v0 + (y+1)*m_res + x;
					const int i3 = v0 + (y+1)*m_res + (x+1)%m_res;

					ASSERT(i0 < m_mesh.NumVerts());
					ASSERT(i1 < m_mesh.NumVerts());
					ASSERT(i2 < m_mesh.NumVerts());
					ASSERT(i3 < m_mesh.NumVerts());

					m_mesh.AddIndex(i0);
					m_mesh.AddIndex(i3);
					m_mesh.AddIndex(i1);

					m_mesh.AddIndex(i0);
					m_mesh.AddIndex(i2);
					m_mesh.AddIndex(i3);
				}
			}
		}
	}
}

// Chooses the x range of the tube and emits all of its vertices.
// With a render context, the range (startX, m_numSegments) is fitted to the
// visible part of the x axis by frustum-testing spheres along it. Without one,
// the current m_numSegments is kept.
void SM_SpikeTube::BuildSkeleton(RenderContext* prc, float time)
{
	// Length of one segment along the tube axis (x). It used to be derived from
	// the cell chord, sqrtf(2*r*r - 2*r*r*cos(c_2pi/m_numCells)) with r = m_segRadius,
	// but is currently a fixed constant.
	const float h = 8;

	// Snap to a multiple of two segment lengths. fmodf keeps the sign of m_camX,
	// so this rounds toward zero rather than down.
	float startX = m_camX - fmodf(m_camX, h*2);

	if (prc)
	{
		// OutcodeSphere receives a named Vector3D below. Passing &Vector3D(...),
		// the address of a temporary, is ill-formed ISO C++ and only compiled as
		// an MSVC extension.
		// NOTE(review): the loops treat a return of -1 as "sphere not visible" — confirm.
		int retcode = 0;
		float x;

		// Walk backwards in steps of 2*h until the sphere leaves the view (or give
		// up after m_maxSegs*2 steps); that position becomes the start of the tube.
		int nBackPass = 0;
		for (x = startX; ; x -= h*2, nBackPass++)
		{
			Vector3D center(x, 0, 0);
			retcode = prc->OutcodeSphere(&center, m_segRadius+m_spikeSize);
			if (retcode == -1 || nBackPass > m_maxSegs*2)
				break;
		}
		startX = x;

		// Walk forwards one segment at a time. Once something visible has been
		// seen, stop at the first invisible position; the step count is the
		// number of segments needed.
		int segs = 0;
		bool foundValid = false;
		for (x = startX; ; x += h, segs++)
		{
			Vector3D center(x, 0, 0);
			retcode = prc->OutcodeSphere(&center, m_segRadius+m_spikeSize);
			if (retcode != -1)
				foundValid = true;
			if (foundValid && retcode == -1)
				break;
			if (segs > m_maxSegs)
				break;
		}

		// Too many segments: walk forward from the camera-aligned position to the
		// first invisible point (bounded by m_maxSegs steps). Then place exactly
		// m_maxSegs segments ending there.
		if (segs > m_maxSegs)
		{
			int nFwdPass = 0;
			for (x = m_camX - fmodf(m_camX, h*2); ; x += h*2, nFwdPass++)
			{
				Vector3D center(x, 0, 0);
				retcode = prc->OutcodeSphere(&center, m_segRadius+m_spikeSize);
				if (retcode == -1 || nFwdPass > m_maxSegs)
					break;
			}
			startX = x - h*m_maxSegs;

			segs = m_maxSegs;
		}

		m_numSegments = segs;
	}

	// Emit m_numSegments+1 skeleton rings. Every ring except the last is followed
	// by the columns and cells of the segment that starts at it. This ordering is
	// the vertex layout BuildIndices2 relies on.
	for (int seg = 0; seg < m_numSegments+1; seg++)
	{
		const float x = startX + seg * h;
		AddSegRing(x, h, time);

		const int ofs = GetOfs(seg);
		if (seg != m_numSegments)
			AddColumnsForRing(x, h, ofs, time);
	}
}

void SM_SpikeTube::AddCell2(float a0, float a1, float midX, float time)
{
	const float aMid = .5f*(a0+a1);
	float segRadius = 10;

	// find basis for this cell
	const vec3_t N(0, cos(aMid), sin(aMid));
	const vec3_t B(1,0,0);
	const vec3_t T(0, -sin(aMid), cos(aMid));
	//const vec3_t origin(segRadius*N);

	const vec3_t p0 = vec3_t(0, m_segRadius*cos(a0), m_segRadius*sin(a0));
	const vec3_t p1 = vec3_t(0, m_segRadius*cos(a1), m_segRadius*sin(a1));

	const float tLength = (p1-p0).Length();
	const vec3_t origin = .5f*(p1+p0);

	// precalc sins and coss for inner loop
	static std::vector<float> sinus;//(m_res);
	static std::vector<float> cosinus;//(m_res);
	if (!sinus.size())
	{
		sinus.resize(m_res);
		cosinus.resize(m_res);
		for (int i=0; i<m_res; i++)
		{
			float a = -float(i)/float(m_res)*c_2pi + 135.0f*c_pi/180;
			cosinus[i] = cos(a);
			sinus[i] = -sin(a);
		}
	}

	// for each ring
	for (int r=1; r<m_numRings; r++)
	{
		float tr = float(r)/float(m_numRings-1);
		float ringRadius = (1.0f-tr) * tLength/2;
		if (r == 0)
			ringRadius = tLength / sqrtf(2.0f);

		const float dispScale = GetSpikeScale(midX, time);//midX/63.232);
		const float normDisp = tr*tr*dispScale;

		const vec3_t spikeFactor = N*normDisp*m_spikeSize + vec3_t(midX + normDisp*normDisp*normDisp*m_spikeSize,0,0);

		// for each point on ring
		vec3_t p;
		for (int n=0; n<m_res; n++)
		{
			//float a = -float(n)/float(m_res)*c_2pi + 135.0f*c_pi/180;
			//float px = ringRadius*cos(a);
			//float py = -ringRadius*sin(a);

			float px = ringRadius*cosinus[n];
			float py = ringRadius*sinus[n];

			// transform by basis
#if 1
			p = origin + px*T;
			p = p/p.Length()*GetRadius(py*B.x+midX, time);
			p += py*B;
			p += spikeFactor;
			//vec3_t p = vec3_t(m_segRadius*cos(ainc), midY, m_segRadius*sin(ainc)) + py*B;// + N*normDisp*m_spikeSize + vec3_t(0,normDisp*normDisp*normDisp*m_spikeSize,0);
#else
			vec3_t p = origin + vec3_t(midX,0,0) + px*T + py*B + N*normDisp*m_spikeSize + vec3_t(0,normDisp*normDisp*normDisp*m_spikeSize,0);
#endif
			//p.y += GetVertOfs(p.x,time);
			m_mesh.AddVert(&ntxvertex(p, vec3_t(0,0,0), 0,0));
		}
	}	
}

void SM_SpikeTube::AddColumnsForRing(float x, float w, int ofs, float time)
{
	for (int c=0; c<m_numCells; c++)
	{
		const int segRingSize = m_numCells*m_numPerSide-m_numCells;
		int n = (c*(m_numPerSide-1)+ofs);//%segRingSize;
		int n2 = ((c+1)*(m_numPerSide-1)+ofs);//%segRingSize;

		const float a = float(n)/float(segRingSize)*c_2pi;
		const float a2 = float(n2)/float(segRingSize)*c_2pi;

		vec3_t p(0, cos(a), sin(a));

		// add column 
		for (int i=1; i<m_numPerSide-1; i++)
		{
			float t = float(i)/float(m_numPerSide-1);
			const float segRadius = GetRadius(x+t*w, time);
			const float y = 0;//GetVertOfs(x+t*w,time);
			m_mesh.AddVert(&nvertex(vec3_t(x+t*w,y,0)+segRadius*p, vec3_t(0,0,0)));
		}

		// add cell
		AddCell2(a, a2, x + .5f*w, time);
	}
}

void SM_SpikeTube::AddSegRing(float x, float w, float time)
{
	const int segRingSize = m_numCells*m_numPerSide-m_numCells;
	const float segRadius = GetRadius(x, time);
	for (int i=0; i<segRingSize; i++)
	{
		const float t = float(i)/float(segRingSize);
		const float a = t*c_2pi;

		m_mesh.AddVert(&nvertex(vec3_t(x, segRadius*cos(a)/*+GetVertOfs(x,time)*/, segRadius*sin(a)), vec3_t(0,0,0)));
	}
}

void SM_SpikeTube::MakeAddNormal(int i0, int i1, int i2)
{
	ntxvertex& v0 = *(ntxvertex*)m_mesh.GetVert(i0);
	ntxvertex& v1 = *(ntxvertex*)m_mesh.GetVert(i1);
	ntxvertex& v2 = *(ntxvertex*)m_mesh.GetVert(i2);

	vec3_t u = v2.p-v0.p;
	vec3_t v = v1.p-v0.p;
	vec3_t n;
	
	Cross(n, u, v);
	//Normalize(n);

	v0.n += n;
	v1.n += n;
	v2.n += n;
}

