// DEFINES
// Compile-time limits. MAX_BONES is not referenced in the visible part of this file.
#define MAX_BONES				32
// Element count of the directionalLights array in the Lighting constant buffer.
#define MAX_DIRECTIONAL_LIGHTS	16
// Element count of the pointLights array in the Lighting constant buffer.
#define MAX_POINT_LIGHTS		16

// Single-precision approximations of pi and 2*pi.
// Neither macro is referenced in the visible part of this file (ggxD declares its own local pi).
#define PI		(3.14159f)
#define TWOPI	(6.28318f)

// VERTEX FORMAT
// Vertex input that carries the same attributes as DefaultVertexIn plus a set of
// four blend indices and four blend weights.
// No shader in the visible part of this file consumes this struct.
// NOTE(review): presumably used for skeletal skinning against a MAX_BONES-sized palette - confirm against the vertex shader that uses it.
struct AnimatedVertexIn
{
	float4 position					: POSITION;
	uint4 blendIndices				: BLEND_INDICES;	// four unsigned indices per vertex
	float4 blendWeights				: BLEND_WEIGHTS;	// four weights per vertex, one per blend index

	float4 color					: COLOR0;
	float2 uv						: TEXCOORD0;
	float3 normal					: NORMAL;
	float3 binormal					: BINORMAL;
	float3 tangent					: TANGENT;
};

// Vertex input without blend data: position, color, one UV set and a
// normal/binormal/tangent triple.
// No shader in the visible part of this file consumes this struct.
struct DefaultVertexIn
{
	float4 position					: POSITION;
	float4 color					: COLOR0;
	float2 uv						: TEXCOORD0;
	float3 normal					: NORMAL;
	float3 binormal					: BINORMAL;
	float3 tangent					: TANGENT;
};

// Vertex shader output / pixel shader input. Mirrors DefaultVertexIn, with the
// position bound to SV_POSITION and an additional wpos member.
// No shader in the visible part of this file produces or consumes this struct.
struct DefaultVertexOut
{
	float4 position					: SV_POSITION;
	float4 wpos						: WORLDPOS;		// NOTE(review): presumably the world-space position - confirm in the vertex shader
	float4 color					: COLOR0;
	float2 uv						: TEXCOORD0;
	float3 normal					: NORMAL;
	float3 binormal					: BINORMAL;
	float3 tangent					: TANGENT;
};

// Output of FullScreenVS: an SV_POSITION value and the UV it was derived from.
struct FullScreenVertexOut
{
	float4 position					: SV_POSITION;
	float2 uv						: TEXCOORD0;
};

// LIGHTING STRUCTS
// One directional light, stored as an element of the Lighting constant buffer.
struct DirectionalLight
{
	float4 direction;		// renderDirectionalLight negates .xyz to obtain the light vector; .w is not read in this file
	float4 colorIntensity;	// renderDirectionalLight passes .xyz as the light color and .w as the intensity
};

// One point light, stored as an element of the Lighting constant buffer.
// padding0 and padding1 round the range and attenuation members up to four
// floats each, so every member group occupies a whole float4.
struct PointLight
{
	float4 position;		// only .xyz is read by renderPointLight
	float4 colorIntensity;	// w is intensity
	float2 range;			// y is invRange
	float2 padding0;
	float3 attenuation;		// Follows attenuation equation (x + (y * d) + (z * d^2)) where d = distance between light and pixel
	float padding1;
};

// CONSTANT BUFFERS
// Constant buffer bound to slot b0: transform matrices, camera vectors and
// miscellaneous scalar parameters.
// Within the visible part of this file only cameraParams (getLinearDepth) and
// CameraEye (renderDirectionalLight) are read; ExtraParams appears only in
// commented-out code.
cbuffer PerFrame : register(b0)
{
	float4x4 World;
	float4x4 View;
	float4x4 Projection;
	float4x4 ViewInverse;
	float4x4 ProjectionInverse;

	float4 CameraEye;			// .xyz is used as the camera position when building the view vector
	float4 CameraDir;
	float4 CameraRight;
	float4 CameraUp;

	float4 cameraParams;		// x = near, y = far, z = near * far, w = far - near

	float4 GlobalAmbient;

	float4 SceneTime;			// x = dt, y = currentTime, z = prevTime
	float4 ScreenDimensions;	// z and w = inv x and inv y respectively

	float4 ExtraParams;			// x = alpha of geometry, y = diffuse coeff, z = specular coeff, w = specular power
}

// Constant buffer bound to slot b1: one ambient color followed by fixed-size
// arrays of directional and point lights.
// The buffer carries no active-light counts; none of its members are indexed in
// the visible part of this file.
cbuffer Lighting : register(b1)
{
	float4				ambLightColor;

	DirectionalLight	directionalLights[MAX_DIRECTIONAL_LIGHTS];	// MAX_DIRECTIONAL_LIGHTS elements
	PointLight			pointLights[MAX_POINT_LIGHTS];				// MAX_POINT_LIGHTS elements
};

// SAMPLERS ////////////////////////////////////////////////////////////////////////////////////
// Sampler objects are declared here without inline state or explicit register
// bindings.
// NOTE(review): filter and address modes are presumably configured by the application to match each name - confirm on the CPU side.
// Only PointClampSampler is referenced in the visible part of this file (by Sample3D).
SamplerState PointWrapSampler;
SamplerState BilinearWrapSampler;
SamplerState TrilinearWrapSampler;
SamplerState AnisoWrapSampler;

SamplerState PointClampSampler;
SamplerState BilinearClampSampler;
SamplerState TrilinearClampSampler;
SamplerState AnisoClampSampler;

// GLOBAL RESOURCES ////////////////////////////////////////////////////////////////////////////
// Deferred buffers
// The color, normal/depth and world-position buffers are currently disabled;
// only DepthBuffer is declared. It is not sampled in the visible part of this file.
//Texture2D<float4> ColorBuffer;
//Texture2D<float4> NormalDepthBuffer;
//Texture2D<float4> WorldPosBuffer;
Texture2D<float4> DepthBuffer;

// HELPER FUNCTIONS ////////////////////////////////////////////////////////////////////////////
// Converts a post-projection depth value into a view-space depth.
//   depth  : depth value after the projection transform
//   return : -(near * far) / (depth * (far - near) - far), built from the
//            pre-multiplied terms stored in cameraParams
float getLinearDepth(float depth)
{
	const float farPlane     = cameraParams.y;	// far
	const float nearTimesFar = cameraParams.z;	// near * far
	const float farMinusNear = cameraParams.w;	// far - near

	const float denominator = depth * farMinusNear - farPlane;
	return -(nearTimesFar / denominator);
}

// Treats a 2D texture holding slices laid out along its X axis as a 3D texture.
//   tex            : the 2D texture containing the slices
//   uv             : xy address within a slice, z selects the slice
//   sliceSize      : U-space step between consecutive slices
//   slicePixelSize : scale for uv.x; half of it is also added as an offset
//   textureWidth   : scales uv.z into a slice index and is also the upper clamp for that index
// Two point-sampled (clamped) taps are taken from neighbouring slices and
// blended by the fractional part of the scaled z coordinate.
float4 Sample3D(Texture2D<float4> tex, float3 uv, float sliceSize, float slicePixelSize, float textureWidth)
{
	// z coordinate expressed in slices; its floor picks the lower slice and its
	// remainder is the blend weight.
	const float scaledZ = uv.z * textureWidth;

	const float lowerSlice = min(floor(scaledZ), textureWidth);
	const float upperSlice = min(lowerSlice + 1.0f, textureWidth);

	// Horizontal offset inside a slice, shared by both taps.
	const float offsetInSlice = uv.x * (slicePixelSize * textureWidth) + (slicePixelSize * 0.5f);

	const float2 lowerCoord = float2(lowerSlice * sliceSize + offsetInSlice, uv.y);
	const float2 upperCoord = float2(upperSlice * sliceSize + offsetInSlice, uv.y);

	const float4 lowerColor = tex.Sample(PointClampSampler, lowerCoord);
	const float4 upperColor = tex.Sample(PointClampSampler, upperCoord);

	return lerp(lowerColor, upperColor, fmod(scaledZ, 1.0f));
}

// Impulse curve: with s = k * x, returns s * e^(1 - s).
// The result is 0 at s = 0 and reaches its maximum of 1 at s = 1.
//   k : scale applied to x
//   x : input value
float impulse(float k, float x)
{
	const float scaled = k * x;
	const float falloff = exp(1.0f - scaled);
	return scaled * falloff;
}

//float3 blinnPhong(float3 color, float3 viewDir, float3 lightDir, float3 lightColor, float lightIntensity, float3 normal, float shininess, float kd, float ks)
//{
//	float diffuseAmt = saturate(dot(lightDir, normal));
//	float3 halfDir = normalize(lightDir + viewDir);
//	float specularAmt = saturate(dot(halfDir, normal));
//
//	float3 diffuse = kd * diffuseAmt * color;
//	float3 specular = (diffuseAmt > 0.0f) ? pow(specularAmt, shininess) : float3(0.0f, 0.0f, 0.0f);
//	specular *= ks;
//
//	return (saturate(diffuse + specular) * lightColor * lightIntensity);
//}

// Blinn-Phong shading for a single light.
//   cameraDir      : direction from the surface towards the camera
//   lightDir       : direction from the surface towards the light
//   lightColor     : light color, scaled by lightIntensity
//   normal         : surface normal
//   shininess      : specular exponent
//   kd, ks         : diffuse and specular coefficients
// Returns the saturated sum of the diffuse and specular terms.
// The specular term is white (not tinted by lightColor) and is not gated by
// the Lambert term.
float3 blinnPhong(float3 cameraDir, float3 lightDir, float3 lightColor, float lightIntensity, float3 normal, float shininess, float kd, float ks)
{
	// Diffuse: Lambert cosine term, clamped to [0, 1].
	const float nDotL = saturate(dot(normal, lightDir));
	const float3 diffuseTerm = lightColor * lightIntensity * nDotL * kd;

	// Specular: half vector between the view and light directions.
	const float3 halfVector = normalize(cameraDir + lightDir);
	const float highlight = pow(saturate(dot(halfVector, normal)), shininess);
	const float3 specularTerm = float3(1.0f, 1.0f, 1.0f) * highlight * ks;

	return saturate(diffuseTerm + specularTerm);
}

// Computes the combined Fresnel/visibility helper pair used by ggx().
//   dotLH     : dot product of the light and half vectors
//   roughness : surface roughness; alpha = roughness^2
// Returns float2(vis, (1 - dotLH)^5 * vis), where
//   vis = 1 / (dotLH^2 * (1 - k^2) + k^2) and k = alpha / 2.
// ggx() blends the two components using F0.
float2 ggxFV(float dotLH, float roughness)
{
	// Visibility term.
	const float halfAlpha = (roughness * roughness) * 0.5f;
	const float halfAlphaSq = halfAlpha * halfAlpha;
	const float visibility = rcp(dotLH * dotLH * (1.0f - halfAlphaSq) + halfAlphaSq);

	// Fresnel weights: a constant 1 for the F0 part, (1 - dotLH)^5 for the rest.
	const float fresnelConstant = 1.0f;
	const float fresnelFalloff = pow(1.0f - dotLH, 5.0f);

	return float2(fresnelConstant * visibility, fresnelFalloff * visibility);
}

// GGX normal distribution term.
//   dotNH     : dot product of the normal and half vectors
//   roughness : surface roughness; alpha = roughness^2
// Returns alpha^2 / (pi * (dotNH^2 * (alpha^2 - 1) + 1)^2).
// The denominator is zero when roughness is 0 and dotNH is 1.
float ggxD(float dotNH, float roughness)
{
	const float pi = 3.14159f;

	const float roughnessSq = roughness * roughness;		// alpha
	const float roughnessPow4 = roughnessSq * roughnessSq;	// alpha^2

	const float base = dotNH * dotNH * (roughnessPow4 - 1.0f) + 1.0f;
	return roughnessPow4 / (pi * base * base);
}

// Scalar GGX specular term for one light.
//   n         : surface normal
//   v         : direction towards the viewer
//   l         : direction towards the light
//   roughness : surface roughness, forwarded to ggxD and ggxFV
//   F0        : weight of the first ggxFV component; (1 - F0) weights the second
// Returns dot(n, l) * D * FV with every dot product clamped to [0, 1].
float ggx(float3 n, float3 v, float3 l, float roughness, float F0)
{
	const float3 halfVector = normalize(v + l);

	const float nDotL = saturate(dot(n, l));
	const float lDotH = saturate(dot(l, halfVector));
	const float nDotH = saturate(dot(n, halfVector));

	const float distribution = ggxD(nDotH, roughness);

	// x carries the constant Fresnel part, y the (1 - dotLH)^5 part.
	const float2 fvTerms = ggxFV(lDotH, roughness);
	const float fresnelVisibility = F0 * fvTerms.x + (1.0f - F0) * fvTerms.y;

	return nDotL * distribution * fresnelVisibility;
}

// Lambert diffuse plus a white GGX specular lobe for a single light.
//   cameraDir      : direction from the surface towards the camera
//   lightDir       : direction from the surface towards the light
//   lightColor     : light color, scaled by lightIntensity (diffuse only)
//   normal         : surface normal
//   roughness, f0  : forwarded unchanged to ggx()
// Returns the saturated sum of both terms. The specular term is not tinted by
// lightColor or scaled by lightIntensity.
float3 cookTorrance(float3 cameraDir, float3 lightDir, float3 lightColor, float lightIntensity, float3 normal, float roughness, float f0)
{
	const float nDotL = saturate(dot(normal, lightDir));
	const float3 diffuseTerm = lightColor * lightIntensity * nDotL;

	const float specularAmount = ggx(normal, cameraDir, lightDir, roughness, f0);
	const float3 specularTerm = float3(1.0f, 1.0f, 1.0f) * specularAmount;

	return saturate(diffuseTerm + specularTerm);
}

// Shades a surface point with one directional light using cookTorrance().
//   light    : the directional light; its direction is negated to obtain the
//              vector pointing from the surface towards the light
//   position : surface position, in the same space as CameraEye
//   normal   : surface normal
// Roughness and f0 are hard-coded below rather than read from ExtraParams.
// NOTE(review): 2.0f is passed as f0, which ggx() uses as a blend weight
// (F0 and 1 - F0); a value above 1 makes the second weight negative - confirm this is intended.
float3 renderDirectionalLight(DirectionalLight light, float3 position, float3 normal)
{
	const float roughness = 0.2f;
	const float f0 = 2.0f;

	const float3 toCamera = normalize(CameraEye.xyz - position);
	const float3 toLight = -light.direction.xyz;

	const float3 lightColor = light.colorIntensity.xyz;
	const float lightIntensity = light.colorIntensity.w;

	return cookTorrance(toCamera, toLight, lightColor, lightIntensity, normal, roughness, f0);
}

// Shades a surface point with one point light (diffuse only).
//   light  : the point light; position.xyz, colorIntensity and attenuation are read
//   color  : surface color that the light contribution is multiplied with
//   wpos   : surface position, in the same space as light.position
//   normal : surface normal
// Returns color * lightColor * intensity * saturate(N.L) / attenuation, or black
// when the light intensity is not greater than zero.
// light.range is not applied by this function, and no specular term is added.
float3 renderPointLight(PointLight light, float3 color, float3 wpos, float3 normal)
{
	// Vector from the surface to the light; its length is the distance d used
	// by the attenuation polynomial.
	const float3 toLight = light.position.xyz - wpos;
	const float d = length(toLight);
	const float3 lightDir = normalize(toLight);

	const float ndotl = saturate(dot(normal, lightDir));

	// attenuation = x + y * d + z * d^2. A non-positive result disables
	// attenuation instead of dividing by zero or flipping the sign.
	const float attenuation = light.attenuation.x + (light.attenuation.y * d) + (light.attenuation.z * d * d);
	const float invAttenuation = (attenuation <= 0.0f) ? 1.0f : rcp(attenuation);

	const float3 lightColor = light.colorIntensity.xyz * light.colorIntensity.w;
	const float3 diffuse = color * lightColor * ndotl;

	float3 finalColor = float3(0.0f, 0.0f, 0.0f);

	// Lights whose intensity is not greater than zero contribute nothing.
	if(light.colorIntensity.w > 0.0f)
	{
		finalColor = diffuse * invAttenuation;
	}

	return finalColor;
}

// FULL SCREEN VERTEX SHADER
// Generates a screen-covering triangle from the vertex id alone; no vertex
// buffer input is read.
//   id 0 -> uv (0, 0), position (-1,  1)
//   id 1 -> uv (2, 0), position ( 3,  1)
//   id 2 -> uv (0, 2), position (-1, -3)
// The position is uv * (2, -2) + (-1, 1) with z = 0 and w = 1.
FullScreenVertexOut FullScreenVS(uint id : SV_VertexID)
{
	// Bit 1 of (id << 1) and of id select 0 or 2 for each UV component.
	const float2 uv = float2((id << 1) & 2, id & 2);
	const float2 clipXY = uv * float2(2.0f, -2.0f) + float2(-1.0f, 1.0f);

	FullScreenVertexOut result;
	result.uv = uv;
	result.position = float4(clipXY, 0.0f, 1.0f);
	return result;
}