#version 330 core

// Transforms for the current draw: ModelViewMatrix takes particle positions to
// eye space (used for the billboard axes and the near fade in main()), and
// ModelViewProjectionMatrix takes them to clip space (gl_Position).
uniform mat4 ModelViewMatrix;
uniform mat4 ModelViewProjectionMatrix;

// Sprite-sheet layout: .xy = size of one tile in texture coordinates,
// .zw = number of tiles along the horizontal / vertical axis of the sheet.
uniform vec4 AnimationParams; // tile_width, tile_height, width_num_tiles, height_num_tiles

// One point in, one quad (4-vertex triangle strip) out.
// The input primitive layout is declared before the unsized input array below
// so that array is sized by an earlier input layout qualifier, as the GLSL
// geometry-shader input rules describe.
layout(points) in;
layout(triangle_strip, max_vertices = 4) out;

// Per-particle attributes from the vertex shader: .x radius, .y alpha scale.
// NOTE(review): .z is labelled 1/mass, but main() also passes it to get_tile()
// as the animation position (1.0 -> first tile); confirm it stays in [0, 1].
in vec3 gs_params[]; // radius, energy, 1/mass

// Sprite-sheet texture coordinate and the fade factor computed in main().
out vec2 uv;
out float alpha_fade;

// Writes the per-vertex outputs and emits one vertex of the current strip.
//   VTX: position in the same space as gl_in[0].gl_Position (vec3); it is
//        transformed by ModelViewProjectionMatrix.
//   UV:  sprite-sheet texture coordinate (vec2).
// NOTE: reads a variable named `alpha` that must be in scope wherever the macro
// is expanded (main() declares it); its value is copied to alpha_fade.
// Kept on a single line: backslash line continuation only entered GLSL in 4.20,
// so a multi-line #define is not portable at #version 330. Arguments are
// parenthesized, and do/while(false) makes the expansion one statement that is
// safe under an unbraced if/else.
#define EMIT_VERTEX(VTX, UV) do { uv = (UV); alpha_fade = alpha; gl_Position = ModelViewProjectionMatrix * vec4((VTX), 1.0); EmitVertex(); } while (false)

// Returns the texture-coordinate origin of one tile of the sprite sheet.
//   index: animation position; 1.0 selects tile 0, and decreasing values walk
//          through the tiles with x advancing first, then y (row by row).
// The tile index is clamped to the sheet, so index == 0 (which would be one
// past the last tile) and out-of-range values still land on a valid tile.
// Assumes AnimationParams.zw hold positive whole numbers (tile counts).
vec2 get_tile(float index)
{
	int cols = int(AnimationParams.z);
	int rows = int(AnimationParams.w);
	int tile_count = cols * rows;

	int tile_index = clamp(int(floor((1.0 - index) * float(tile_count))), 0, tile_count - 1);

	// Integer div/mod instead of floor(a / b) on floats: GPU float division is
	// not exact (the spec allows ~2.5 ULP), so e.g. 6.0/3.0 may floor to 1 and
	// push x to cols, i.e. outside the current row.
	int y = tile_index / cols;
	int x = tile_index % cols;
	return vec2(x, y) / AnimationParams.zw;
}

// Expands the incoming point into a textured quad facing the camera.
// Points whose radius (gs_params[0].x) is not greater than zero emit nothing.
void main()
{
	float radius = gs_params[0].x;

	// Negated comparison (not `radius <= 0.0`) so a NaN radius is skipped too.
	if ( !(radius > 0.0) )
	{
		return;
	}

	// Billboard basis: the first two rows of the model-view upper 3x3. These are
	// the eye-space X/Y directions expressed in the particle's space, exactly
	// only if that 3x3 is orthonormal (assumption; not checked here).
	vec3 right = vec3(ModelViewMatrix[0].x, ModelViewMatrix[1].x, ModelViewMatrix[2].x);
	vec3 up    = vec3(ModelViewMatrix[0].y, ModelViewMatrix[1].y, ModelViewMatrix[2].y);

	vec3 center = gl_in[0].gl_Position.xyz;

	// Depth along the -Z view direction, minus one unit. smoothstep makes the
	// sprite invisible within 1 unit of the eye and fade in over the next unit.
	// The result is then scaled by gs_params[0].y.
	// EMIT_VERTEX reads this value through the local name `alpha`.
	vec3 view_pos = vec3(ModelViewMatrix * vec4(center, 1.0));
	float depth_past_near = dot(vec3(0.0, 0.0, -1.0), view_pos) - 1.0;
	float alpha = smoothstep(0.0, 1.0, depth_past_near) * gs_params[0].y;

	// Corners named by their sign along (right, up).
	vec3 corner_pp = center + (right + up)*radius;
	vec3 corner_mp = center + (up - right)*radius;
	vec3 corner_mm = center + (-right - up)*radius;
	vec3 corner_pm = center + (right - up)*radius;

	vec2 tile_origin = get_tile(gs_params[0].z);
	vec2 tile_size = AnimationParams.xy;

	// Strip order (+,+) (-,+) (+,-) (-,-) yields two triangles covering the quad.
	EMIT_VERTEX(corner_pp, tile_origin + vec2(tile_size.x, 0.0));
	EMIT_VERTEX(corner_mp, tile_origin);
	EMIT_VERTEX(corner_pm, tile_origin + tile_size);
	EMIT_VERTEX(corner_mm, tile_origin + vec2(0.0, tile_size.y));

	EndPrimitive();
}
