#version 450

#ifndef SPIRV_ENABLED
#extension GL_NV_gpu_shader5 : enable
#else
#extension GL_EXT_shader_explicit_arithmetic_types : enable
#extension GL_ARB_shader_ballot : enable
#extension GL_KHR_shader_subgroup_ballot : enable
#define HAS_16BIT_TYPES
#endif

#include <shaders/materials/commons.glsl>
#include <shaders/geometry_partitioning/raytrace_buffers.glsl>
#include <shaders/geometry_partitioning/raytrace_commons.glsl>

// One input triangle produces at most one output triangle (3 vertices).
layout(triangles) in;
layout(triangle_strip, max_vertices = 3) out;

// Grid marker volumes touched by the single-cell fast path in main():
// imGridMarkers0 is passed to rt_write_grid_marker_high_res() in the counting pass,
// imGridMarkers2 is passed to rt_write_grid_marker_low_res() in the fill pass.
// Declared once: the declarations were identical with and without SPIRV_VULKAN.
layout(r32ui) uniform uimage3D imGridMarkers0;
layout(r32ui) uniform uimage3D imGridMarkers2;

// for projection we use calculated bounding box converted to the grid coords

// Per-corner data received from the previous stage (one array entry per triangle vertex).
// vLocalPos is the position main() translates/scales into grid coordinates; vIdx is
// forwarded unchanged through vTriIndex[].
#ifdef SPIRV_VULKAN

// Vulkan path: explicit location, colour and UV only exist when voxelizing.
layout(location = 1) in Vertex
{
	vec3    vLocalPos;
#ifdef PARTITION_VOXELIZE
	f16vec4 vColor;
	f16vec2 vUV0;
#endif
	uint    vIdx;
} vtx_inputs[];

#else

// Non-Vulkan path: all members are always present.
in Vertex
{
	vec3 vLocalPos;
	f16vec4 vColor;
	f16vec2 vUV0;
	uint    vIdx;
} vtx_inputs[];

#endif

// Outputs consumed by the fragment stage. Values marked `flat` are not interpolated,
// so integers can be used for them.
#ifdef SPIRV_VULKAN

layout(location = 0) flat out int vOrientationIndex;
#ifdef PARTITION_VOXELIZE
layout(location = 1) out vec3 vGridCoords;
layout(location = 2) out f16vec2 vUV0;
// vTriIndex[3] below takes locations 5..7, hence location 8 here
layout(location = 8) out vec4 vColor;
layout(location = 4) flat out f16vec3 vTriNormalForLighting;
#endif
layout(location = 3) flat out f16vec3 vTriNormal;  // this normal is for the scaled (nonuniform) coords, not real triangle normal

layout(location = 5) flat out uint vTriIndex[3];

#else

flat out int vOrientationIndex;
out vec3 vGridCoords;
out f16vec2 vUV0;
// main() writes vColor when PARTITION_VOXELIZE is defined; it was only declared on the
// Vulkan path, which left the identifier undeclared here.
// NOTE(review): assumes none of the included headers already declares vColor - confirm.
out vec4 vColor;
flat out uint vTriIndex[3];
// NOTE(review): vTriCoords is not written anywhere in this file - confirm it is still needed.
flat out vec3 vTriCoords[3];
flat out f16vec3 vTriNormal;  // this normal is for the scaled (nonuniform) coords, not real triangle normal
flat out f16vec3 vTriNormalForLighting;
//flat out float vTriArea;

#endif

// Per-draw parameters for the partitioning pass.
struct PartitionGeometryDrawParams
{
	// Index of the instance being drawn. main() uses it to look up
	// transformed_data_location[instance_idx - 1].last_face_idx, i.e. the face index
	// offset applied to gl_PrimitiveIDIn for every instance after the first one.
	uint instance_idx;
};


// Uniform block carrying the per-draw parameters (std140 layout).
layout(std140, row_major) uniform PartitionGeometryDrawParamsBuffer{
	PartitionGeometryDrawParams partition_geometry_params;
};

// Small-primitive test: tells whether the whole triangle falls into one grid cell.
//   v0, v1, v2 : triangle corners, in the same space as in_bbox_data.bbox_raytrace_min
//   p          : always receives the integer cell of the triangle's minimum corner
// Returns true when the minimum and maximum corners of the triangle's bounds floor to
// the same cell. No range check is done on p; the cell may lie outside the grid.
#ifdef PARTITION_RAYTRACE
bool is_contained_in_single_cell(vec3 v0, vec3 v1, vec3 v2, out ivec3 p)
{
	vec3 to_grid_offset = -in_bbox_data.bbox_raytrace_min.xyz;
	vec3 to_grid_scale  = vec3(1.0) / in_bbox_data.grid_size_raytrace.xyz;

	// translate first, then scale by the reciprocal (same operation order as main())
	vec3 g0 = (v0 + to_grid_offset) * to_grid_scale;
	vec3 g1 = (v1 + to_grid_offset) * to_grid_scale;
	vec3 g2 = (v2 + to_grid_offset) * to_grid_scale;

	// Both bounds are floored (the max side is deliberately not ceil'ed) because the
	// question is whether min and max share a cell, not how many cells are covered.
	ivec3 cell_lo = ivec3(floor(min(g0, min(g1, g2))));
	ivec3 cell_hi = ivec3(floor(max(g0, max(g1, g2))));

	p = cell_lo;
	return all(equal(cell_lo, cell_hi));
}
#endif

// Writes every per-vertex output for one corner of the current triangle and emits it.
//   corner                  : index (0..2) into vtx_inputs[] for per-corner attributes
//   grid_pos                : corner position in grid coords, already swizzled for the projection axis
//   orientation_index       : 0 = X dominant, 1 = Y dominant, 2 = Z dominant
//   primitive_id            : value written to gl_PrimitiveID
//   tri_normal              : triangle normal in the scaled (nonuniform) grid space
//   tri_normal_for_lighting : triangle normal in local space (only output when voxelizing)
// All outputs, including the flat ones, are rewritten for each vertex because output
// variables become undefined after EmitVertex().
void emit_partition_vertex(int corner, vec3 grid_pos, int orientation_index, int primitive_id,
                           vec3 tri_normal, vec3 tri_normal_for_lighting)
{
	// NOTE: On Vulkan seems we have to output gl_PrimitiveID for each vtx? At least RenderDoc reports discrepancy here
	vTriIndex[0] = vtx_inputs[0].vIdx;
	vTriIndex[1] = vtx_inputs[1].vIdx;
	vTriIndex[2] = vtx_inputs[2].vIdx;

#ifdef PARTITION_VOXELIZE
	vTriNormalForLighting = f16vec3(tri_normal_for_lighting);
#endif
	vTriNormal = f16vec3(tri_normal); // match tri-box intersection code

	gl_PrimitiveID = primitive_id;
	vOrientationIndex = orientation_index;
	// maps grid coords [0, GRID_RES] to [-1, 1] after the divide by w = GRID_RES
	gl_Position = vec4(grid_pos * 2.0 - vec3(GRID_RES), GRID_RES);

#ifdef PARTITION_VOXELIZE
	vGridCoords = grid_pos;
	vUV0   = vtx_inputs[corner].vUV0;
	// per-corner colour, matching how vUV0 is forwarded (was vtx_inputs[2] for every vertex)
	vColor = vtx_inputs[corner].vColor;
#endif

	//vTriArea = area;
	EmitVertex();
}

// Geometry stage of the grid partitioning pass.
// Converts the input triangle to grid coordinates, picks the axis to project along and
// emits the triangle together with the orientation index, which the fragment stage needs
// to place fragments back into the grid. With PARTITION_RAYTRACE, triangles that fit into
// a single in-range grid cell are registered directly here and nothing is emitted.
void main()
{
#ifdef PARTITION_RAYTRACE
	vec3 bbox_origin = in_bbox_data.bbox_raytrace_min.xyz;
	vec3 bbox_grid_size = in_bbox_data.grid_size_raytrace.xyz;
#endif
#ifdef PARTITION_VOXELIZE
	vec3 bbox_origin = in_bbox_data.bbox_voxelize_min.xyz;
	vec3 bbox_grid_size = in_bbox_data.grid_size_voxelize.xyz;
#endif

	vec3 v0 = vtx_inputs[0].vLocalPos.xyz;
	vec3 v1 = vtx_inputs[1].vLocalPos.xyz;
	vec3 v2 = vtx_inputs[2].vLocalPos.xyz;

	// Local-space normal: its direction is used for lighting, its absolute value to
	// choose the projection axis.
	vec3 local_normal = cross(v1 - v0, v2 - v0);
	vec3 tri_normal_for_lighting = normalize(local_normal);
	vec3 n = abs(local_normal);

	// Face indices are global across instances: offset by the faces of all previous instances.
	uint first_face_idx = 0;
	if (partition_geometry_params.instance_idx > 0)
		first_face_idx = transformed_data_location[partition_geometry_params.instance_idx - 1].last_face_idx + 1;

	int primitive_id = gl_PrimitiveIDIn + int(first_face_idx);

	// small primitive optimization. NOTE: seems we can not write both from GS and FS
	// so only the counting pass is optimized for now:(
#if defined(PARTITION_RAYTRACE) // && defined(PARTITION_RAYTRACE_COUNTING_PASS)
	{
		ivec3 cell;
		bool is_single_cell_triangle = is_contained_in_single_cell(v0, v1, v2, cell);

		// Only take the shortcut for cells inside the grid. A negative or too large
		// coordinate would otherwise wrap when converted to uint and index outside the
		// link list; such triangles go through the regular raster path instead.
		bool is_cell_in_grid = all(greaterThanEqual(cell, ivec3(0))) && all(lessThan(cell, ivec3(GRID_RES)));

		if (is_single_cell_triangle && is_cell_in_grid)
		{
			uint ccx = uint(cell.x);
			uint ccy = uint(cell.y);
			uint ccz = uint(cell.z);
			uint list_index = ccx + ccy * GRID_RES + ccz * GRID_RES * GRID_RES;

#ifdef PARTITION_RAYTRACE_COUNTING_PASS
			rt_link_list_increase_count(list_index);
			rt_write_grid_marker_high_res(imGridMarkers0, cell);
#else
			rt_link_list_push_value(list_index, primitive_id);
			rt_write_grid_marker_low_res(imGridMarkers2, cell >> 2);
#endif // PARTITION_RAYTRACE_COUNTING_PASS

			return;
		}
	}
#endif // PARTITION_RAYTRACE

	// translate and scale geometry. do not use bbox as we need grid-aligned offsets
	// TODO: we can actually convert bbox to be grid_size aligned (cell size aligned)
	vec3 s = -bbox_origin;
	v0 += s;
	v1 += s;
	v2 += s;

	vec3 f = vec3(1.0) / bbox_grid_size;
	v0 *= f;
	v1 *= f;
	v2 *= f;

	// drop (near-)degenerate triangles; this also keeps the division below safe
	float area = length(cross(v1 - v0, v2 - v0));
	if (area <= 0.001)
		return;

	// normal in the scaled grid space, rounded to half precision like the output it feeds
	vec3 tri_normal = f16vec3(cross(v1 - v0, v2 - v1) / area);

	// Project along the axis with the largest normal component. Comparisons are
	// non-strict so that a tie between the two largest components (e.g. |n| = (1,1,0))
	// still selects one of them, never an axis the triangle is edge-on to.
	vec3 grid_v0 = v0;
	vec3 grid_v1 = v1;
	vec3 grid_v2 = v2;
	int orientation_index;

	if (n.x >= n.y && n.x >= n.z)
	{
		// X dominant
		grid_v0 = grid_v0.yzx;
		grid_v1 = grid_v1.yzx;
		grid_v2 = grid_v2.yzx;
		orientation_index = 0;
	}
	else if (n.y >= n.z)
	{
		// Y dominant
		grid_v0 = grid_v0.zxy;
		grid_v1 = grid_v1.zxy;
		grid_v2 = grid_v2.zxy;
		orientation_index = 1;
	}
	else
	{
		// Z dominant: coordinates are used as they are
		orientation_index = 2;
	}

	emit_partition_vertex(0, grid_v0, orientation_index, primitive_id, tri_normal, tri_normal_for_lighting);
	emit_partition_vertex(1, grid_v1, orientation_index, primitive_id, tri_normal, tri_normal_for_lighting);
	emit_partition_vertex(2, grid_v2, orientation_index, primitive_id, tri_normal, tri_normal_for_lighting);

	EndPrimitive();
}

