#version 460

// GL_NV_gpu_shader5 is requested only when SPIRV_ENABLED is not defined.
// NOTE(review): presumably this provides the f16vec* types used by the input block below -- confirm.
#ifndef SPIRV_ENABLED
#extension GL_NV_gpu_shader5 : enable
#endif

// The shader is invoked once per input point.
layout(points) in;
// NOTE: Apparently VK doesn't allow 0 to be output... We might reconsider actually patching VS instead of using GS...
// main() in this file never emits a vertex; max_vertices = 1 is only there to satisfy the declaration.
layout(points, max_vertices = 1) out;

#include <shaders/materials/commons.glsl>
#include <shaders/commons_hlsl.glsl>
#include <shaders/materials/noise/noise3d.glsl>
#include <shaders/materials/commons_instancing_buffers.h>

#include <shaders/geometry_partitioning/raytrace_buffers.glsl>

// Destination buffer for the exported vertex attributes, written by the put_*() helpers below.
// It is addressed as a flat array of 32-bit floats; the packed normal and color words are
// stored into it through bit casts, so not every element is a meaningful float value.
layout(std430) buffer TransformedDataGeometryBuffer {
	float data[];
} out_vtx_data;

// Per-entity transform parameters. EntityTransformParams is not declared in this file
// (it comes from one of the includes); main() only reads transform_params.vCameraPosition.
layout(std140, row_major) uniform TransformParamsBuffer{
	EntityTransformParams transform_params;
};

// Sync with raytrace_partition_geometry_add_surface.comp

// Describes where each attribute of an output vertex lives inside out_vtx_data.data[].
// In this file only coord_stride / coord_offset are read (by put_coords(), as indices into
// the float array); put_normal() and put_uv0() use the VERTEX_*_FLOATS constants instead.
struct TransformGeometryOutputVertexParams
{
	int coord_stride;	// multiplied by the output vertex index in put_coords()
	int normal_stride;	// not read in this file
	int uv0_stride;		// not read in this file
	int coord_offset;	// added to the strided index in put_coords()
	int normal_offset;	// not read in this file
	int uv0_offset;		// not read in this file
	// NOTE(review): presumably padding to keep the struct a multiple of 16 bytes -- confirm against the host-side struct
	int _pad0;
	int _pad1;
};

// Parameters for one geometry-transform pass. Per the note above this struct, the layout has
// to stay in sync with raytrace_partition_geometry_add_surface.comp.
// main() in this file reads only output_vtx_params (through put_coords()), instance_idx and
// transform_normals; the remaining members are not referenced here.
struct TransformGeometryParams
{
	TransformGeometryOutputVertexParams output_vtx_params;

	uint instance_idx;	// which entry to setup in the TransformedDataLocation buffer
	uint surface_idx;
	uint voxelize;
	uint raytrace;

	uint material_idx;

	int _points_per_instance;
	int _faces_per_instance;
	// non-zero: export the normalized world-space normal; zero: export a zero normal (see main())
	int transform_normals;

	int calculate_bbox_raytrace;
	int calculate_bbox_voxelize;

	uint _instances_num;
	uint _pad1;

	// NOTE(review): this struct is used inside a std140 uniform block, where each vec3 is
	// aligned to 16 bytes -- verify that the host-side struct pads these the same way.
	vec3 default_bbox_raytrace_min;
	vec3 default_bbox_raytrace_max;
	vec3 default_bbox_voxelize_min;
	vec3 default_bbox_voxelize_max;
};


// Uniform block carrying the parameters of the current geometry-transform pass.
layout(std140, row_major) uniform TransformGeometryParamsBuffer {
	TransformGeometryParams transform_geometry_params;
};

// Disabled: this stage declares no per-vertex output block. All results are written to
// out_vtx_data by main().
#if 0
out Vertex
{
	vec3 vCoords;
	vec3 vNorm;
	vec3 vWorldNorm;
	vec3 vLocalPos;
	vec3 vWorldPos;
	vec4 vColor;
	vec2 vUV0;
	uint vIdx;
} vtx_output;
#endif

// Inputs from the previous stage. Both variants declare the same members; the SPIR-V
// variant additionally assigns explicit locations (instanceID at 0, the Vertex block at 1).
// main() reads vWorldNorm, vCameraRelativeWorldPos, vColor and vUV0 of element 0, plus
// instanceID[0]; the other members are not used in this file.
#ifndef SPIRV_ENABLED
in Vertex
{
	vec3    vCoords;
	vec3    vNorm;
	vec3    vWorldNorm;
	vec3    vLocalPos;
	vec3    vCameraRelativeWorldPos;
	f16vec4 vColor;
	f16vec2 vUV0;
} vtx_inputs[];

// Instance index of the incoming point; used in main() to offset gl_PrimitiveIDIn.
in uint instanceID[];
#else
layout(location = 1) in Vertex
{
	vec3    vCoords;
	vec3    vNorm;
	vec3    vWorldNorm;
	vec3    vLocalPos;
	vec3    vCameraRelativeWorldPos;
	f16vec4 vColor;
	f16vec2 vUV0;
} vtx_inputs[];

// Instance index of the incoming point; used in main() to offset gl_PrimitiveIDIn.
layout(location = 0) in uint instanceID[];
#endif

// raytracer export

// Writes the three components of position p for output vertex idx into out_vtx_data.
// The destination is idx * coord_stride + coord_offset (both taken from
// transform_geometry_params.output_vtx_params), measured in elements of the float array.
void put_coords(uint idx, vec3 p)
{
	uint first = idx * transform_geometry_params.output_vtx_params.coord_stride
		+ transform_geometry_params.output_vtx_params.coord_offset;

	// x, y, z go to three consecutive floats
	for (uint c = 0u; c < 3u; ++c)
		out_vtx_data.data[first + c] = p[c];
}

// NOTE: Needs to match rt_get_vertex_normal() in raytrace_commons.glsl 
// Stores normal p for output vertex idx as two 32-bit words of snorm16 pairs:
//   word 0 = packSnorm2x16(p.xy), word 1 = packSnorm2x16(vec2(p.z, 0.0)).
// The location comes from the VERTEX_NORMAL_*_FLOATS constants, not from
// transform_geometry_params.output_vtx_params.
// NOTE(review): the packed words are bit-cast into a float array; some bit patterns are
// NaN encodings -- confirm the target preserves them bit-exactly.
void put_normal(uint idx, vec3 p)
{
	uint first = idx * VERTEX_NORMAL_STRIDE_FLOATS + VERTEX_NORMAL_OFFSET_FLOATS;

	out_vtx_data.data[first + 0] = asfloat(packSnorm2x16(p.xy));
	out_vtx_data.data[first + 1] = asfloat(packSnorm2x16(vec2(p.z, 0.0)));
}

// Stores texture coordinate p for output vertex idx as two consecutive floats.
// The location comes from the VERTEX_UV0_*_FLOATS constants, not from
// transform_geometry_params.output_vtx_params.
void put_uv0(uint idx, vec2 p)
{
	uint first = idx * VERTEX_UV0_STRIDE_FLOATS + VERTEX_UV0_OFFSET_FLOATS;

	out_vtx_data.data[first] = p.x;
	out_vtx_data.data[first + 1] = p.y;
}

// Stores color c for output vertex idx as one 32-bit RGBA8 word, bit-cast into the float
// array: R in bits 24-31, G in bits 16-23, B in bits 8-15, A in bits 0-7.
// The location comes from the VERTEX_COLOR_*_FLOATS constants.
// NOTE(review): some packed words (e.g. opaque white) are NaN bit patterns when viewed as
// a float -- confirm the target preserves them bit-exactly.
void put_color(uint idx, vec4 c)
{
	uint color_offset = idx * VERTEX_COLOR_STRIDE_FLOATS + VERTEX_COLOR_OFFSET_FLOATS;

	// packUnorm4x8() clamps each channel to [0, 1] and rounds to the nearest 8-bit value
	// (the previous uint(c * 255.0) truncated, biasing every channel downwards).
	// It places the first component in the least significant byte, hence the .abgr
	// swizzle to keep R in the top byte and A in the bottom byte.
	uint packed_color = packUnorm4x8(c.abgr);

	out_vtx_data.data[color_offset + 0] = uintBitsToFloat(packed_color);
}


//in int gl_PrimitiveIDIn;

// Entry point, run once per input point. Works out which slot of out_vtx_data the vertex
// belongs to, keeps the transformed_data_location bookkeeping up to date, and writes the
// vertex position, normal, UV0 and color through the put_*() helpers. No vertex is emitted.
void main()
{
	// NOTE: Either there is a bug, I'm doing something wrong or this is how it should work...
	// gl_PrimitiveIDIn restarts per instance, so the vertices of the preceding instances
	// are added on top of it.
	uint src_vtx = uint(gl_PrimitiveIDIn) + instanceID[0] * geometry_information.vtx_num;

	// bookkeeping

	uint loc_idx = transform_geometry_params.instance_idx;
	uint dst_vtx = 0;

	// the first point of each instance resets the per-entry counters
	if (gl_PrimitiveIDIn == 0)
	{
		transformed_data_location[loc_idx].surface_idx = 0;
		transformed_data_location[loc_idx].last_face_idx = 0;
	}

	// updating last vtx idx
	// OPTIMIZATION: We can keep track and update as we go or just assume that full
	// geometry is submitted and update once? This should work but keep in mind that it
	// relies on the full geom submission!
#if 0
	// disabled variant: track the highest written index as vertices arrive
	if (loc_idx == 0)
	{
		if (src_vtx > transformed_data_location[loc_idx].last_vtx_idx)
			atomicMax(transformed_data_location[loc_idx].last_vtx_idx, src_vtx);
		dst_vtx = src_vtx;
	}
	else // need to update even if we don't export. don't have to do it for the first face as buffer is zeroed on the start of frame
	{
		dst_vtx = transformed_data_location[loc_idx - 1].last_vtx_idx + src_vtx + 1;
		if (dst_vtx > transformed_data_location[loc_idx].last_vtx_idx)
		{
			atomicMax(transformed_data_location[loc_idx].last_vtx_idx, dst_vtx);
		}
	}
#else
	// active variant: the invocation with src_vtx == 0 records the final index up front,
	// assuming all vtx_num * instance_count vertices get submitted
	if (loc_idx != 0)
	{
		// later entries start right after the range of the previous entry
		dst_vtx = transformed_data_location[loc_idx - 1].last_vtx_idx + 1 + src_vtx;
		if (src_vtx == 0)
		{
			transformed_data_location[loc_idx].last_vtx_idx =
				transformed_data_location[loc_idx - 1].last_vtx_idx
				+ (geometry_information.vtx_num * instance_params.instance_count - 1) + 1;
		}
	}
	else
	{
		// the first entry starts at the beginning of the buffer
		dst_vtx = src_vtx;
		if (src_vtx == 0)
			transformed_data_location[loc_idx].last_vtx_idx = geometry_information.vtx_num * instance_params.instance_count - 1;
	}
#endif

	// output data

	// camera-relative position plus the camera position from the transform parameters
	vec3 world_pos = vtx_inputs[0].vCameraRelativeWorldPos.xyz + transform_params.vCameraPosition;
	put_coords(dst_vtx, world_pos);

	// a zero normal is exported unless normal transformation is requested
	vec3 world_nrm = vec3(0.0);
	if (transform_geometry_params.transform_normals != 0)
		world_nrm = normalize(vtx_inputs[0].vWorldNorm.xyz);
	put_normal(dst_vtx, world_nrm);

	vec2 uv = vtx_inputs[0].vUV0;
	put_uv0(dst_vtx, uv);

	put_color(dst_vtx, vtx_inputs[0].vColor);
}

