
#ifndef INTERPOLATE_NORMALS
#define INTERPOLATE_NORMALS 1
#endif

#ifndef MAX_BOUNCES
#define MAX_BOUNCES 4
#endif

#ifndef MAX_TRACE_LENGTH
#define MAX_TRACE_LENGTH 1024
#endif

uniform usampler3D  s_grid_marker;

//--------------- DDA Intersection support ----------------------------------

// Distance limits handed to findClosestDDAMultibounce().
struct ray_traversal_params
{
	float trace_range_primary;      // tmax used until the first surface hit
	float trace_range_secondary;    // tmax used for every segment after a hit (bounced rays)
};

// Default per-ray user payload. A shader that defines RT_TRAVERSAL_HAS_USER_RAY_STATE
// is expected to provide its own ray_state_user_data before this point.
#ifndef RT_TRAVERSAL_HAS_USER_RAY_STATE
struct ray_state_user_data
{
	uint dummy;     // placeholder member; not read anywhere in this file
};
#endif

// Mutable state of one ray while it is traced (and bounced) through the grid.
// Fields marked "not written here" are only touched outside this file
// (presumably by evaluate_material() or the caller -- confirm).
struct ray_state
{
	f16vec3             color;                  // not written here
	f16vec3             normal;                 // normal of the most recent hit (geometric, or interpolated for non-flat materials)
	vec3                dir;                    // current ray direction
	vec3                origin;                 // current ray origin; moved to the hit point on every hit
	float16_t           transparency;           // not written here
	bool                running;                // traversal stops once this is false
	int16_t             bounces;                // number of hits recorded so far
	int16_t             material;               // not written here
	int                 tests;                  // counter incremented per DDA iteration
	int                 face_tests;             // accumulated face_tests reported by findClosestBucket2()
	bool                hit;                    // not written here
	bool                left;                   // set when the ray was inside the grid and then stepped outside
	bool                inside_transparent;     // not written here
	uint                active_threads_factor;  // sum of ballot_count(true) taken at each bucket search
	uint                active_threads_samples; // number of bucket searches contributing to the sum above
	float               final_color_factor;     // not written here
	ray_state_user_data user_data;              // opaque payload forwarded to findClosestBucket2()
};

// user callbacks

// Shading callback; declared here and expected to be defined by the shader that includes this file.
// findClosestDDAMultibounce() calls it once per accepted hit, after state.origin has
// been moved to the hit point.
//   state                - ray state, modified in place
//   prev_state_origin    - ray origin before it was moved to the hit point
//   hit_face             - index of the face that was hit
//   hit_material_flags   - material flags of that face
//   bc                   - barycentric (y, z) pair of the hit as returned by barycentric_for_face_yz()
//   flip_normal_on_glass - true when the face is transparent and state.normal points along state.dir
void evaluate_material(in out ray_state state, in vec3 prev_state_origin, int hit_face, uint hit_material_flags, f16vec2 bc, bool flip_normal_on_glass);

//
#define GRID_SIZE in_bbox_data.grid_size_raytrace

//

// Looks up face `idx` and writes its three corner positions to p0, p1 and p2.
void build_triangle(uint idx, out vec3 p0, out vec3 p1, out vec3 p2)
{
	RTFace face = rt_get_face(idx);

	p0 = rt_get_vertex(face.v0);
	p1 = rt_get_vertex(face.v1);
	p2 = rt_get_vertex(face.v2);
}

// Full (x, y, z) barycentric coordinates of point `p` with respect to face `idx`,
// as computed by rt_barycentric_xyz() and narrowed to half precision.
f16vec3 barycentric_for_face(int idx, vec3 p)
{
	RTFace face = rt_get_face(idx);

	vec3 corner_a = rt_get_vertex(face.v0);
	vec3 corner_b = rt_get_vertex(face.v1);
	vec3 corner_c = rt_get_vertex(face.v2);

	return f16vec3(rt_barycentric_xyz(p, corner_a, corner_b, corner_c));
}

// Two-component variant of barycentric_for_face(): returns what rt_barycentric_yz()
// yields for point `p` on face `idx`, narrowed to half precision.
f16vec2 barycentric_for_face_yz(int idx, vec3 p)
{
	RTFace face = rt_get_face(idx);

	vec3 corner_a = rt_get_vertex(face.v0);
	vec3 corner_b = rt_get_vertex(face.v1);
	vec3 corner_c = rt_get_vertex(face.v2);

	return f16vec2(rt_barycentric_yz(p, corner_a, corner_b, corner_c));
}

// Unit geometric normal of face `fi`, following the face's vertex winding
// (cross product of the edges v0->v1 and v0->v2).
f16vec3 build_normal(int fi)
{
	vec3 a, b, c;
	build_triangle(fi, a, b, c);

	vec3 winding_normal = cross(b - a, c - a);
	return f16vec3(normalize(winding_normal));
}

//---------------------------------------------------------------------------
// Smooth shading normal for face `fi` at barycentric position `bc_yz`
// (bc_yz.x weights vertex 1, bc_yz.y weights vertex 2, the remainder weights vertex 0).
//
// NOTE: The interpolated normal follows the stored vertex normals, whereas the intersection
// code produces its own geometric normal. An earlier revision aligned the result with
// `ref_normal` through a dot-product sign flip; that step is currently disabled, so the
// normalized blend is returned unchanged.
// NOTE2: With INTERPOLATE_NORMALS != 1 interpolation is switched off and `ref_normal`
// is returned as-is.
f16vec3 interpolate_normal_from_bc_yz(int fi, f16vec2 bc_yz, f16vec3 ref_normal)
{
#if INTERPOLATE_NORMALS == 1
	RTFace face = rt_get_face(fi);
	uint ia = face.v0;
	uint ib = face.v1;
	uint ic = face.v2;

	f16vec3 normal_a = f16vec3(rt_get_vertex_normal(ia));
	f16vec3 normal_b = f16vec3(rt_get_vertex_normal(ib));
	f16vec3 normal_c = f16vec3(rt_get_vertex_normal(ic));

	float16_t weight_a = float16_t(1.0) - bc_yz.x - bc_yz.y;
	f16vec3 blended = normal_a * weight_a + normal_b * bc_yz.x + normal_c * bc_yz.y;

	return f16vec3(normalize(blended));
#else
	return ref_normal;
#endif
}

// Interpolates the first UV set of face `fi` at barycentric position `bc_yz`
// (bc_yz.x weights vertex 1, bc_yz.y weights vertex 2, the remainder weights vertex 0).
f16vec2 interpolate_uv_from_bc_yz(int fi, f16vec2 bc_yz)
{
	RTFace face = rt_get_face(fi);
	uint ia = face.v0;
	uint ib = face.v1;
	uint ic = face.v2;

	f16vec2 uv_a = f16vec2(rt_get_vertex_uv0(ia));
	f16vec2 uv_b = f16vec2(rt_get_vertex_uv0(ib));
	f16vec2 uv_c = f16vec2(rt_get_vertex_uv0(ic));

	float16_t weight_a = float16_t(1.0) - bc_yz.x - bc_yz.y;
	return uv_a * weight_a + uv_b * bc_yz.x + uv_c * bc_yz.y;
}

// Interpolates the vertex colors of face `fi` at barycentric position `bc_yz`
// (bc_yz.x weights vertex 1, bc_yz.y weights vertex 2, the remainder weights vertex 0).
f16vec4 interpolate_color_from_bc_yz(int fi, f16vec2 bc_yz)
{
	RTFace face = rt_get_face(fi);
	uint ia = face.v0;
	uint ib = face.v1;
	uint ic = face.v2;

	f16vec4 color_a = f16vec4(rt_get_vertex_color(ia));
	f16vec4 color_b = f16vec4(rt_get_vertex_color(ib));
	f16vec4 color_c = f16vec4(rt_get_vertex_color(ic));

	float16_t weight_a = float16_t(1.0) - bc_yz.x - bc_yz.y;
	return color_a * weight_a + color_b * bc_yz.x + color_c * bc_yz.y;
}

// Ray/triangle intersection in the Moller-Trumbore form (two-sided: either sign of the
// determinant is accepted).
//   orig, dir      - ray origin and direction
//   v0, v1, v2     - triangle corners
//   intersection_t - out: ray parameter of the hit; 0.0 when the test is rejected before t
//                    is computed, otherwise the computed t (even if it is too small to count)
//   out_normal     - out: unit normal cross(v1 - v0, v2 - v0) on a hit
// Returns true only when the ray hits the triangle at t > 1e-8.
//
// Both out parameters are given a defined value on every path; reading an `out` argument
// the callee never wrote is undefined in GLSL.
bool intersectTriangle(vec3 orig, vec3 dir, vec3 v0, vec3 v1, vec3 v2, out float intersection_t, out f16vec3 out_normal)
{
	intersection_t = 0.0;
	out_normal     = f16vec3(0.0);

	vec3 e1 = v1 - v0;
	vec3 e2 = v2 - v0;

	// det is the scalar triple product e1 . (dir x e2); it vanishes when the ray is
	// parallel to the triangle plane
	vec3 pvec = cross(dir, e2);
	float det = dot(e1, pvec);

	if (det < 1e-8 && det > -1e-8)
	{
		return false;
	}

	float inv_det = 1.0 / det;
	vec3 tvec = orig - v0;

	// first barycentric coordinate
	float u = dot(tvec, pvec) * inv_det;
	if (u < 0.0 || u > 1.0)
	{
		return false;
	}

	// second barycentric coordinate
	vec3 qvec = cross(tvec, e1);
	float v = dot(dir, qvec) * inv_det;
	if (v < 0.0 || u + v > 1.0)
	{
		// legacy debug marker, kept so this rejection path reports the same value as before
		out_normal = f16vec3(1.0, 0.0, 0.0);
		return false;
	}

	intersection_t = dot(e2, qvec) * inv_det;
	if (intersection_t > 1e-8)
	{
		out_normal = f16vec3(normalize(cross(e1, e2)));	// TODO: remove normalization and reuse above calcs
		return true;
	}
	return false;
}

// NOTE: trying to workaround compiler issues here....

// Result record returned by intersectTriangle2() / intersectTriangle3().
// Only `t` is meaningful on a miss; treat the other members as valid only when t >= 0.
struct intersection
{
	float t;            // ray parameter of the hit; -1.0 signals "no hit"
	f16vec3 normal;     // unit triangle normal, computed as cross(e2, e1)
	float16_t denom;    // dot(normal, dir); written by intersectTriangle3() only
	f16vec2 bc;			// barycentrics. only two included; written by intersectTriangle3() only
};

// Ray/triangle test that reports a hit only when the ray travels along the returned normal
// (dot(normal, dir) > 1e-7).
//
// The barycentric rejection follows Moller-Trumbore; the reported t is then recomputed from
// a ray/plane intersection. Original author's note: assigning the Moller-Trumbore t directly
// misbehaved (suspected miscompile), hence the recomputation.
//
// Returns an `intersection` with t == -1.0 on a miss. `denom` and `bc` are not computed by
// this variant and are returned as zero; `normal` is zero unless the barycentric test passed
// with t > 1e-7. All members are initialised so the caller never receives undefined values.
intersection intersectTriangle2(vec3 orig, vec3 dir, vec3 v0, vec3 v1, vec3 v2)
{
	intersection it;
	it.t      = -1.0;
	it.normal = f16vec3(0.0);
	it.denom  = float16_t(0.0);
	it.bc     = f16vec2(0.0);

	vec3 e1 = v1 - v0;
	vec3 e2 = v2 - v0;

	// det is the scalar triple product e1 . (dir x e2); ~0 means the ray is parallel to the plane
	vec3 pvec = cross(dir, e2);
	float det = dot(e1, pvec);
	if (abs(det) < 1e-7)
	{
		return it;
	}

	float inv_det = 1.0 / det;
	vec3 tvec = orig - v0;
	float u = dot(tvec, pvec) * inv_det;
	if (u < 0.0 || u > 1.0)
	{
		return it;
	}

	vec3 qvec = cross(tvec, e1);
	float v = dot(dir, qvec) * inv_det;
	if (v < 0.0 || u + v > 1.0)
	{
		return it;
	}

	float t = dot(e2, qvec) * inv_det;
	if (t > 1e-7)
	{
		it.normal = f16vec3(normalize(cross(e2, e1)));	// TODO: remove normalization and reuse above calcs

		// ray/plane intersection using the (half precision) normal
		float denom = dot(vec3(it.normal), dir);
		if (denom > 1e-7)
		{
			vec3 p0l0 = v0 - orig;
			it.t = dot(p0l0, vec3(it.normal)) / denom;
		}
	}
	return it;
}

// Two-sided ray/triangle test (Moller-Trumbore) returning t, normal, facing term and a
// barycentric pair.
//
// On a hit (t > 1e-7):
//   it.t      - ray parameter of the hit
//   it.normal - unit normal cross(e2, e1), half precision
//   it.denom  - dot(normal, dir); its sign tells which side of the triangle was hit
//   it.bc     - (u, v) each divided by denom
// On a miss it.t == -1.0 and the remaining members are zero. They are initialised up front
// so the returned struct never carries undefined values.
intersection intersectTriangle3(vec3 orig, vec3 dir, vec3 v0, vec3 v1, vec3 v2)
{
	intersection it;
	it.t      = -1.0;
	it.normal = f16vec3(0.0);
	it.denom  = float16_t(0.0);
	it.bc     = f16vec2(0.0);

	vec3 e1 = v1 - v0;
	vec3 e2 = v2 - v0;

	// det is the scalar triple product e1 . (dir x e2); ~0 means the ray is parallel to the plane
	vec3 pvec = cross(dir, e2);
	float det = dot(e1, pvec);
	if (abs(det) < 1e-7)
	{
		return it;
	}

	float inv_det = 1.0 / det;
	vec3 tvec = orig - v0;
	float u = dot(tvec, pvec) * inv_det;
	if (u < 0.0 || u > 1.0)
	{
		return it;
	}

	vec3 qvec = cross(tvec, e1);
	float v = dot(dir, qvec) * inv_det;
	if (v < 0.0 || u + v > 1.0)
	{
		return it;
	}

	float t = dot(e2, qvec) * inv_det;
	if (t > 1e-7)
	{
		it.normal = f16vec3(normalize(cross(e2, e1)));	// TODO: remove normalization and reuse above calcs

		// no sign test on denom here: hits from both sides are reported
		float denom = dot(vec3(it.normal), dir);

		it.t     = t;
		it.denom = float16_t(denom);

		it.bc.x = float16_t(u / denom);
		it.bc.y = float16_t(v / denom);
	}
	return it;
}

#if 0
// Reference implementation, compiled out by the surrounding `#if 0`: brute-force search
// over all `numFaces` faces for the nearest hit along the ray.
//   origin, dir  - ray
//   skip_fi      - face index to ignore
//   closest_fi   - out: nearest face index, or -1 when nothing was hit
//   closest_it   - out: t of the nearest hit (starts at 10000.0, so farther hits are ignored)
//   closest_norm - out: normal of the nearest hit; only written when a hit is found
// Returns closest_fi.
int findClosestNaive(vec3 origin, vec3 dir, int skip_fi, out int closest_fi, out float closest_it, out f16vec3 closest_norm)
{
	closest_fi = -1;
	closest_it = 10000.0;

	for(int fi = 0; fi < numFaces; fi++)
	{
		if (fi == skip_fi)
			continue;

		intersection it1, it2;

		vec3 p0, p1, p2;
		build_triangle(fi, p0, p1, p2);
		
		// Disabled variant: two one-sided tests, one per winding
		#if 0

		#if 1 // backfaces
		it1 = intersectTriangle2(origin, dir, p0, p1, p2);
		if (it1.t > 1e-8)
		{
			if (it1.t < closest_it)
			{
				closest_fi = fi;
				closest_it = it1.t;
				closest_norm = -it1.normal;
			}
		}
		#endif
		#if 1
		it2 = intersectTriangle2(origin, dir, p2, p1, p0);
		if (it2.t > 1e-8)
		{
			if (it2.t < closest_it)
			{
				closest_fi = fi;
				closest_it = it2.t;
				closest_norm = it2.normal;
			}
		}
		#endif

		#else // optimized

		
		// single two-sided test with reversed vertex order
		it1 = intersectTriangle3(origin, dir, p2, p1, p0);
		if (it1.t > 1e-7)
		{
			if (it1.t < closest_it)
			{
				closest_fi = fi;
				closest_it = it1.t;
				closest_norm = it1.normal;

				// with INNER_REFLECTION the normal is flipped according to the sign of denom
				#ifdef INNER_REFLECTION
				closest_norm = it1.denom > 0.0 ? it1.normal : -it1.normal;
				#endif
			}
		}

		#endif
	}

	return closest_fi;
}
#endif

#ifdef RT_TRAVERSAL_HAS_CUSTOM_INTERSECTION
// NOTE: This is a hack... provide proper interface...
// Declaration only: with RT_TRAVERSAL_HAS_CUSTOM_INTERSECTION the including shader supplies the body.
void findClosestBucket2(in out ray_state_user_data user_data, uint list_index, bool bucket_full, uint max_tests, vec3 origin, vec3 dir, int skip_fi, float max_t, out int closest_fi, out uint material_flags, out float closest_it, out f16vec3 closest_norm, out f16vec2 closest_bc, out int face_tests);
#else
// Finds the nearest face hit by the ray among the faces listed for one grid cell.
//
//   user_data      - unused by this default implementation
//   list_index     - index into in_faces_list_tails_data selecting the cell's face list
//   bucket_full    - unused by this default implementation
//   max_tests      - lists holding at least this many faces are skipped entirely;
//                    0xFFFFFFFF (what a caller's -1 converts to) disables the limit
//   origin, dir    - ray
//   skip_fi        - face index to ignore (typically the face the ray starts on)
//   max_t          - hits beyond this ray parameter are rejected
//   closest_fi     - out: index of the nearest accepted face, -1 when there is none
//   material_flags - out: material flags of that face (0 when there is none)
//   closest_it     - out: t of the nearest accepted hit, max_t when there is none
//   closest_norm   - out: hit normal (zero when there is none)
//   closest_bc     - out: barycentric pair from intersectTriangle3() (zero when there is none)
//   face_tests     - out: number of entries in the list, 0 when the list was skipped
//
// Every out parameter is assigned on every path; reading an `out` argument the callee never
// wrote is undefined in GLSL.
void findClosestBucket2(in out ray_state_user_data user_data, uint list_index, bool bucket_full, uint max_tests, vec3 origin, vec3 dir, int skip_fi, float max_t, out int closest_fi, out uint material_flags, out float closest_it, out f16vec3 closest_norm, out f16vec2 closest_bc, out int face_tests)
{
	face_tests     = 0;
	closest_fi     = -1;
	closest_it     = max_t;
	material_flags = 0u;
	closest_norm   = f16vec3(0.0);
	closest_bc     = f16vec2(0.0);

	// list layout: node_buffer[head] holds the entry count, the face indices follow
	uint head = in_faces_list_tails_data[list_index];
	uint cnt  = in_faces_list_data.node_buffer[head];

	if (max_tests != 0xFFFFFFFFu && cnt >= max_tests)
	{
		// skip the whole list instead of testing part of it, otherwise we would flicker
		return;
	}

	face_tests = int(cnt);

	// Pass 1: only track which face is nearest
	for (uint slot = 0u; slot < cnt; ++slot)
	{
		int fi = int(in_faces_list_data.node_buffer[head + slot + 1u]);
		if (fi == skip_fi)
			continue;

		RTFace rt_face = rt_get_face(fi);
		vec3 p0 = rt_get_vertex(rt_face.v0);
		vec3 p1 = rt_get_vertex(rt_face.v1);
		vec3 p2 = rt_get_vertex(rt_face.v2);

		intersection it1 = intersectTriangle3(origin, dir, p2, p1, p0);
		if (it1.t >= 0.0 && it1.t <= closest_it)
		{
			// single-sided materials reject hits where the ray travels along the normal
			if ((rt_face.material_flags & MaterialFlag_Doublesided) == 0)
			{
				if (dot(dir, it1.normal) >= 0.0)
					continue;
			}

			closest_fi = fi;
			closest_it = it1.t;
		}
	}

	// Pass 2: recompute the full intersection parameters for the winner only
	if (closest_fi != -1)
	{
		RTFace rt_face = rt_get_face(closest_fi);
		vec3 p0 = rt_get_vertex(rt_face.v0);
		vec3 p1 = rt_get_vertex(rt_face.v1);
		vec3 p2 = rt_get_vertex(rt_face.v2);

		intersection it1 = intersectTriangle3(origin, dir, p2, p1, p0);
		#ifdef INNER_REFLECTION
		closest_norm = it1.denom > 0.0 ? it1.normal : -it1.normal;
		#else
		closest_norm = it1.normal;
		#endif

		closest_bc     = it1.bc;
		material_flags = rt_face.material_flags;
	}
}
#endif

// Stepping state for the struct-based grid DDA (build_dda / step_dda / dda_scale_grid_res).
struct dda
{
	bool     is_high;       // cleared by build_dda() on a fresh start; toggled by the multi-resolution path
	vec3     res;           // cell size currently used for stepping
	vec3     ro;            // ray origin relative to in_bbox_data.bbox_raytrace_min
	vec3     ird;           // 1 / ray direction, per component
	vec3     delta;         // ray-parameter distance between consecutive cell boundaries, per axis
	vec3     t_max;         // ray parameter of the next cell boundary, per axis
	float    prev_next_t;   // next_t as it was before the last step_dda(); 0 after build_dda()
	vec3     prev_t_max;    // t_max as it was before the last step_dda()
	float    next_t;        // smallest component of t_max, refreshed by get_current_intersection_dda()

	vec3     pf;            // sample point (grid space) used to derive the current cell index
	float    tmin;          // stored by build_dda(); not read in this file
	float    tmax;          // traversal aborts once next_t reaches this (see dda_is_abort)
};

// True when no component of `icell` has a bit set outside the mask (GRID_RES-1),
// i.e. every coordinate is within 0 .. GRID_RES-1 (assumes GRID_RES is a power of two -- TODO confirm).
bool is_pos_inside_grid(i16vec3 icell)
{
	// OR the components together so a single mask test catches negative and too-large values alike
	int16_t combined_bits = icell.x | icell.y | icell.z;
	return (combined_bits & (~(GRID_RES-1))) == 0;
}

// (Re)initialises the DDA stepping state for a ray.
//   ro, rd     - ray origin (world space) and direction
//   tmin, tmax - stored into the state unchanged
//   restart    - false: fresh start; also resets the cell size to GRID_SIZE, clears is_high
//                and places the sample point at the origin.
//                true: keeps res, is_high and pf from the previous segment.
void build_dda(inout dda dda, vec3 ro, vec3 rd, float tmin, float tmax, bool restart)
{
	dda.ro = ro - in_bbox_data.bbox_raytrace_min.xyz;

	if (!restart)
	{
		dda.is_high = false;
		dda.res     = f16vec3(GRID_SIZE);
		dda.pf      = dda.ro;
	}

	// 1.0 on axes where the ray travels in the positive direction, else 0.0
	vec3 heading_positive = step(vec3(0.0), rd);

	dda.ird   = vec3(1.0) / rd;
	dda.delta = (heading_positive * 2.0 - 1.0) * dda.res * dda.ird;
	dda.t_max = ((floor(dda.ro / dda.res) + heading_positive) * dda.res - dda.ro) * dda.ird;

	dda.prev_t_max  = dda.t_max;
	dda.prev_next_t = 0.0f;

	dda.tmin = tmin;
	dda.tmax = tmax;
}

// Integer cell coordinates of the current sample point, always measured in
// GRID_SIZE cells regardless of the resolution the DDA is stepping at.
i16vec3 grd_icell_dda(inout dda dda)
{
	vec3 cell_coords = floor(dda.pf / vec3(GRID_SIZE));
	return i16vec3(cell_coords);
}

// True when the cell under the current sample point lies inside the grid.
bool is_inside_grid_dda(inout dda dda)
{
	return is_pos_inside_grid(grd_icell_dda(dda));
}

// True when the low-resolution marker covering the current cell is clear.
// The low-resolution lookup is addressed with the cell index shifted right by 2.
bool is_high_level_empty_dda(inout dda dda)
{
	#if 1
	ivec3 cell = grd_icell_dda(dda);
	return !rt_read_grid_marker_low_res(s_grid_marker, cell >> 2);
	#else
	// alternative: derive the coarse cell straight from the sample point
	ivec3 coarse_cell = ivec3(floor(dda.pf * ((vec3(1.0) / dda.res) * (dda.is_high ? 1.0 : 0.25))));
	return !rt_read_grid_marker_low_res(s_grid_marker, coarse_cell);
	#endif
}


// Refreshes next_t (ray parameter of the nearest upcoming cell boundary) and moves the
// sample point to the middle of the ray span between the previous boundary and that one.
// Returns the new next_t.
float get_current_intersection_dda(inout dda dda, in ray_state state)
{
	float nearest_boundary = min(dda.t_max.x, min(dda.t_max.y, dda.t_max.z));

	dda.next_t = nearest_boundary;
	dda.pf     = dda.ro + state.dir * (dda.prev_next_t + nearest_boundary) * 0.5;

	return nearest_boundary;
}

// True once the next boundary crossing lies at or beyond the trace range.
bool dda_is_abort(in dda dda)
{
	bool out_of_range = dda.next_t >= dda.tmax;
	return out_of_range;
}

// Advances the DDA across the nearest cell boundary, remembering the pre-step values
// in prev_next_t / prev_t_max.
void step_dda(inout dda dda)
{
	vec3 before = dda.t_max;

	dda.prev_next_t = dda.next_t;
	dda.prev_t_max  = before;

	// 1.0 on the axis (or axes, on ties) whose boundary is nearest, 0.0 elsewhere
	vec3 nearest_axis = step(before.xyz, before.yxy) * step(before.xyz, before.zzx);
	dda.t_max = before + nearest_axis * dda.delta;
}

// Switches the DDA to a cell size `scale` times the current one and re-derives the next
// boundary crossings from the point where the previous boundary was crossed.
// Returns the new next_t.
//
// NOTE: This can still fail occasionally: when the crossing point sits exactly on two cell
//       edges, floor() may put it one cell back. Assumed to be rare enough.
float dda_scale_grid_res(inout dda dda, in ray_state state, float scale)
{
	dda.res   *= scale;
	dda.delta *= scale;

	// ray parameter at which the previous boundary was crossed
	// (could be tracked explicitly instead of being recovered from prev_t_max)
	float prev_t_min = min(dda.prev_t_max.x, min(dda.prev_t_max.y, dda.prev_t_max.z));

	// 1.0 on the axis (or axes) that boundary belonged to, 0.0 elsewhere
	vec3 crossed_axis = vec3(equal(vec3(prev_t_min), dda.prev_t_max));

	// distance to the next boundary of the rescaled grid, measured from the crossing point
	vec3 next_ro          = dda.ro + state.dir * prev_t_min;
	vec3 heading_positive = step(vec3(0.0), state.dir);
	vec3 to_next_boundary = ((floor(next_ro / dda.res) + heading_positive) * dda.res - next_ro) * dda.ird;

	// on the crossed axis we stand on a boundary already, so a full cell step is taken there
	dda.t_max  = mix(to_next_boundary, dda.delta, crossed_axis);
	dda.t_max += prev_t_min;
	dda.next_t = min(dda.t_max.x, min(dda.t_max.y, dda.t_max.z));

	// the sample point is only refreshed when cells get smaller (scale < 1.0)
	if (scale < 1.0)
		dda.pf = dda.ro + state.dir * (dda.prev_next_t + dda.next_t) * 0.5;

	return dda.next_t;
}

// Reads the occupancy marker for cell `icell`: the high-resolution marker for mip 0,
// otherwise the low-resolution marker addressed with the cell index shifted right by 2.
bool fetch_grid_marker_for_cell(ivec3 icell, int mip)
{
	if (mip != 0)
		return rt_read_grid_marker_low_res(s_grid_marker, icell >> 2);

	return rt_read_grid_marker_high_res(s_grid_marker, icell);
}


// Number of currently active invocations in the subgroup for which `v` is true.
//
// The Vulkan path uses subgroupBallotBitCount(), which counts every bit of the ballot that
// belongs to the subgroup; summing only ballot.x and ballot.y would under-count on
// subgroups wider than 64 invocations. It comes from the same extension as subgroupBallot().
uint ballot_count(bool v)
{
	#ifndef SPIRV_VULKAN
	return uint(bitCount(ballotThreadNV(v)));
	#else
	return subgroupBallotBitCount(subgroupBallot(v));
	#endif
}

// Per-axis step direction for a ray: +1 where dir >= 0, -1 where dir < 0.
ivec3 dda_step_sign_from_dir(vec3 dir)
{
	vec3 heading_positive = step(vec3(0.0), dir);   // 1.0 where dir >= 0, else 0.0
	return ivec3(heading_positive * 2.0 - 1.0);
}

// Signed cell step: `cell_step` on axes where dir >= 0, -cell_step where dir < 0.
// With a 0/1 axis mask as cell_step this yields the integer cell offset for one DDA step.
ivec3 dda_step_from_dir(vec3 dir, vec3 cell_step)
{
	vec3 heading_positive = step(vec3(0.0), dir);
	vec3 signed_step = heading_positive * cell_step * 2.0 - cell_step;
	return ivec3(signed_step);
}

#if 1  // version without multiresolution grid
// Traces a ray through the uniform grid with a DDA and follows it across surface hits
// (single-resolution version).
//
// Each iteration looks at the current cell; cells whose high-resolution marker is set are
// searched with findClosestBucket2(). On a hit the ray state is moved to the hit point,
// evaluate_material() is called, and the stepping parameters are rebuilt from state.origin /
// state.dir so that traversal continues with the next segment.
//
//   traversal_params - trace_range_primary bounds the first segment, trace_range_secondary
//                      every segment after a hit
//   state            - ray state, updated in place (origin, normal, bounces, running, left, counters)
//   skip_fi          - face index excluded from the search; replaced locally by each face that is hit
//   closest_fi       - out: -1 initially, then whatever the most recent findClosestBucket2() call wrote
//   closest_it       - out: 1000000.0 initially, then whatever the most recent findClosestBucket2() call wrote
//   max_bounces      - state.running is cleared once state.bounces reaches this value
// Always returns 0.
int findClosestDDAMultibounce(ray_traversal_params traversal_params, inout ray_state state, int skip_fi, out int closest_fi, out float closest_it, int max_bounces)
{ 
    //state.dir = normalize(state.dir);

	closest_fi = -1;
	closest_it = 1000000.0;

	vec3 cellDimension = vec3(GRID_SIZE);

	float tmin = 0.0;
	float tmax = traversal_params.trace_range_primary;

	// deltaT: ray-parameter distance between cell boundaries per axis
	// nextCrossingT: ray parameter of the next boundary per axis
	vec3 deltaT, nextCrossingT; 

	// origin relative to the grid's minimum corner, and the cell containing it
	vec3 ro_cell = state.origin - in_bbox_data.bbox_raytrace_min.xyz;
	ivec3 icell = ivec3(floor(ro_cell / cellDimension));

	{
		vec3 s = step(vec3(0.0), state.dir);       // 0.0 or 1.0
		vec3 sgn = s * 2.0 - 1.0;                  // -1.0 or 1.0 

		deltaT = sgn * cellDimension / state.dir;    // same as 'dir'
		nextCrossingT = tmin + ((floor(ro_cell / cellDimension) + s) * cellDimension - ro_cell) / state.dir;
	}
 
	// walk through each cell of the grid and test for an intersection if
	// current cell contains geometry
	float rt = tmin;
	bool inside      = false;
	bool prev_inside = inside;

	int max_iter = MAX_TRACE_LENGTH;                        // iteration budget shared by all bounces we are tracking
	int threads_running = int(ballot_count(true));          // only referenced by the commented-out early-out below

	while(rt < tmax && max_iter >= 0)
	{
		bool hit = false;
		max_iter--;
		state.tests++;

		if (state.running == false)
			break;

		//if (int(ballot_count(state.hit)) >= threads_running / 2)
		//	break;
		
		// t at which the ray leaves the current cell
		rt          = tmin + min(nextCrossingT.x, min(nextCrossingT.y, nextCrossingT.z));
		prev_inside = prev_inside || inside;    // sticky: true once the ray has been inside the grid
		inside      = false;

		bool search_bucket = false;             // not used below
		uint icell_idx     = 0;
		bool bucket_full    = false;

#if 0
		if (fetch_grid_marker_for_cell(icell, 0) > 0)
#else
		if (is_pos_inside_grid(i16vec3(icell)))
#endif
		{
			inside = true;
			// linear cell index, x fastest
			icell_idx = icell.z * GRID_RES * GRID_RES + icell.y * GRID_RES + icell.x;
			bucket_full = fetch_grid_marker_for_cell(icell, 0);
		}

		// try to advance also other threads to increase coherency. this is mostly to skip empty space and sync threads
		if (true)
		{
			// an invocation is "done" when it reached an occupied cell or has left the grid
			bool done_criteria = (bucket_full) || (prev_inside == true && inside == false);
			uint hit_threads = ballot_count(done_criteria);
			int ii = 0;
			// up to 4 extra steps while fewer than 16 invocations are done
			while(hit_threads < 16 && ii < 4)
			{
				state.tests++;
				if (done_criteria == false)
				{
					max_iter--;

					// step dda
					vec3 mm = step(nextCrossingT.xyz, nextCrossingT.yxy) * step(nextCrossingT.xyz, nextCrossingT.zzx);
					icell += dda_step_from_dir(state.dir, mm);
					nextCrossingT += mm * deltaT;

					rt = tmin + min(nextCrossingT.x, min(nextCrossingT.y, nextCrossingT.z));

					inside = is_pos_inside_grid(i16vec3(icell));
					prev_inside = prev_inside || inside;

					if (inside)
					{
						icell_idx = icell.z * GRID_RES * GRID_RES + icell.y * GRID_RES + icell.x;
						bucket_full = fetch_grid_marker_for_cell(icell, 0);
					}

					done_criteria = (bucket_full) || (prev_inside == true && inside == false);
				}

				ii++;
				hit_threads = ballot_count(done_criteria);
			}
		}

		if (bucket_full)
		{
			f16vec2 closest_bc;
			f16vec3 closest_normal;
			uint    closest_material_flags;

			// occupancy statistics: how many invocations were active at this bucket search
			state.active_threads_factor  += ballot_count(true);
			state.active_threads_samples += 1;

			//state.tests += bucket_full;

			// NOTE(review): the branch below passes 8 arguments, while findClosestBucket2 is declared
			// above with 14 parameters, and `bucket_size` is not declared in this function. It looks
			// stale -- confirm USE_LINKED_LISTS is always defined.
			#ifndef USE_LINKED_LISTS
			int bucket_offset = int(in_buckets.offsets[icell_idx]);
			findClosestBucket2(bucket_offset, bucket_size, state.origin, state.dir, skip_fi, closest_fi, closest_it, state.normal);
			//findClosestBucket(icell_idx, state.origin, state.dir, skip_fi, closest_fi, closest_it, state.normal);
			#else
			int face_tests = 0;
			// -1 converts to the largest uint, which findClosestBucket2 treats as "no limit"
			uint max_tests = -1;
			// once more than 512 face tests have been spent, lists with 16 or more faces are skipped
			if (closest_fi == -1 && state.face_tests > 512)
				max_tests = 16;

			// hits are limited to the current cell: max_t is the cell exit (or the trace range)
			findClosestBucket2(state.user_data, icell_idx, bucket_full, max_tests, state.origin, state.dir, skip_fi, min(tmax, rt), closest_fi, closest_material_flags, closest_it, closest_normal, closest_bc, face_tests);
			state.face_tests += face_tests;
			#endif

			if (closest_fi != -1)
			{
				state.bounces += int16_t(1);
				state.normal   = closest_normal;

				hit     = true;
				skip_fi = closest_fi;       // do not re-hit the surface we are about to leave

				// stop after this hit when the bounce budget is used up or the material terminates rays
				if (state.bounces >= max_bounces || (closest_material_flags & MaterialFlag_RaytraceTerminate) != 0)
				{
					state.running = false;
					//state.hit = true;
					//break;
				}

				// barycentrics are recomputed from the hit position; closest_bc is not used here
				f16vec2 bc = barycentric_for_face_yz(closest_fi, state.origin + state.dir * closest_it);
				if ((closest_material_flags & MaterialFlag_Flat) == 0)
					state.normal = interpolate_normal_from_bc_yz(closest_fi, bc, state.normal);

				bool flip_normal_on_glass = false;
				if ((closest_material_flags & MaterialFlag_Transparent) != 0)
				{
					// align normal along the current view dir
					flip_normal_on_glass = dot(state.normal, state.dir) > 0.0 ? true : false;
				}

				// calculate new origin and recalculate tracing parameters for the bounce
				vec3 prev_state_origin = state.origin; // for lighting calculation
				state.origin = state.origin + state.dir * closest_it;

				// we hit solid object which is not perfectly rough, reflect
				evaluate_material(state, prev_state_origin, closest_fi, closest_material_flags, bc, flip_normal_on_glass);
				//state.color = TurboColormap(fract(closest_fi * 13.11301));
				//state.color.rg = bc.xy;
				//state.color.b = 1.0 - bc.x - bc.y;

				// rebuild stepping parameters TODO: factor this out
				// NOTE(review): icell is not recomputed here; presumably still valid because the hit was
				// limited to the current cell via max_t above -- confirm.
				if (max_bounces > 1) // && state.running)
				{
					vec3 ird = 1.0 / state.dir;
					ro_cell = state.origin - in_bbox_data.bbox_raytrace_min.xyz;
					//cell = floor(ro_cell / cellDimension);

					vec3 s = step(vec3(0.0), state.dir);            // 0.0 or 1.0
					vec3 sgn = s * 2.0 - 1.0;                       // -1.0 or 1.0 

					deltaT = sgn * cellDimension * ird;         // same as 'dir'
					nextCrossingT = tmin + ((floor(ro_cell / cellDimension) + s) * cellDimension - ro_cell) * ird;

					// the next segment starts at t = 0 from the new origin and uses the secondary range
					tmax = traversal_params.trace_range_secondary;
					rt = tmin;
				}
			}
		}

		// the ray was inside the grid earlier and is outside now: it has left the grid
		if (inside == false && prev_inside == true)
		{
			state.running = false;
			state.left    = true;
		}

		if (!hit)
		{
			// all components of minimum mask (i.e. x <= y && x <= z, y <= x && y <= z, z <= y && z <= x) 
			// are false except for the corresponding smallest component of dt (if no mask), which 
			// is the axis along which the ray should be incremented
			// stolen from https://github.com/guozhou/voxelizer/blob/master/raycasting_fs.glsl
			// NOTE: nextCrossingT == dt

			vec3 mm = step(nextCrossingT.xyz, nextCrossingT.yxy) * step(nextCrossingT.xyz, nextCrossingT.zzx);
			icell += dda_step_from_dir(state.dir, mm);
			nextCrossingT += mm * deltaT;
		}
	} 

	return 0;
} 
#else

// Alternative traversal built on the `dda` struct helpers, with hooks for a two-level
// (multi-resolution) grid. Compiled out: it sits in the #else branch of `#if 1`.
//
// NOTE(review): this variant looks stale relative to the rest of the file:
//   - both findClosestBucket2() calls pass fewer arguments than the 14-parameter declaration above
//   - the evaluate_material() call passes a material index and the face index where the
//     prototype above expects (hit_face, hit_material_flags)
//   - `float rt` declared inside the loop shadows the outer `rt` tested by the loop condition,
//     so the loop condition never sees the updated value
// Re-enable only after reconciling these.
int findClosestDDAMultibounce(ray_traversal_params traversal_params, inout ray_state state, int skip_fi, out int closest_fi, out float closest_it, int max_bounces)
{ 
	dda dda;
	build_dda(dda, state.origin, state.dir, 0.0, traversal_params.trace_range_primary, false);

	closest_fi = -1;
	closest_it = 1000000.0;

	float tmin = 0.0;
	float tmax = traversal_params.trace_range_primary;
	float rt   = tmin;

	// walk through each cell of the grid and test for an intersection if
	// current cell contains geometry
	bool inside      = false;
	bool prev_inside = inside;

	int max_iter = MAX_TRACE_LENGTH;                        // iteration budget shared by all bounces we are tracking

	while(rt < tmax && max_iter >= 0)
	{
		bool hit = false;
		max_iter--;
		state.tests++;

		if (state.running == false)
			break;
		
		// NOTE(review): shadows the outer `rt` (see header)
		float rt = get_current_intersection_dda(dda, state);
		prev_inside = prev_inside || inside;
		inside      = is_inside_grid_dda(dda);

		// resolution switching is disabled by the constant condition below
		bool changed = false;
		if (false) // && max_bounces > 1)
		{
			bool change_to_high = (inside == false || (inside == true && dda.is_high == false && is_high_level_empty_dda(dda)));
			bool change_to_low  = (inside == true && (dda.is_high == true && is_high_level_empty_dda(dda) == false));

			uint change_to_high_ballot_cnt = ballot_count(change_to_high);

			// change to low is always executed!
			//if (change_to_high_ballot_cnt < 4)
			//	change_to_high = false;

			if (change_to_high)
			{
				// go to high res
				dda.is_high  = true;
				float change = 4.0;     // not used; the literal is passed directly below
				changed      = true;
				rt = dda_scale_grid_res(dda, state, 4.0);
			}
			else if (change_to_low)
			{
				// go to low res
				dda.is_high  = false;
				float change = 0.25;    // not used; the literal is passed directly below
				changed      = true;
				rt = dda_scale_grid_res(dda, state, 0.25);
			}
		}

		ivec3 icell        = ivec3(0);
		uint icell_idx     = 0;
		bool bucket_full   = false;

		// calculate cell index and check for intersections if valid (we don't do grid-box intersection yet)
		// check if we can skip whole high level cell. first the naive way
        if (inside)
		{
			if (dda.is_high == false)
			{
				icell = grd_icell_dda(dda);
				icell_idx = icell.z * GRID_RES * GRID_RES + icell.y * GRID_RES + icell.x;
				bucket_full = fetch_grid_marker_for_cell(icell, 0);
			}
		}

		// try to advance also other threads to increase coherency. this is mostly to skip empty space and sync threads
		// NOTE: This is (at current state) not compatible with multires solution:( Enable only when tracing high res grid but this also has some issues:(
		//if (false)
		if (true)
		{
			bool done_criteria = (changed == true) || (dda.is_high == true) || (bucket_full == true);// || (prev_inside == true && inside == false);
			uint hit_threads = ballot_count(done_criteria);
			int ii = 0;
			// up to 4 extra steps while fewer than 16 invocations are done
			while(hit_threads < 16 && ii < 4)
			{
				state.tests++;
				if (done_criteria == false)
				{
					max_iter--;
					step_dda(dda);
					rt = get_current_intersection_dda(dda, state); // this not only calculates 'rt' but also updates next_t and pf for the cell calculation

					inside = is_inside_grid_dda(dda);
					prev_inside = prev_inside || inside;

					if (inside)
					{
						icell = grd_icell_dda(dda);
						icell_idx = icell.z * GRID_RES * GRID_RES + icell.y * GRID_RES + icell.x;
						bucket_full = fetch_grid_marker_for_cell(icell, 0);
					}

					done_criteria = (changed == true) || (dda.is_high == true) || (bucket_full == true);// || (prev_inside == true && inside == false);
				}

				ii++;
				hit_threads = ballot_count(done_criteria);
			}
		}

		if (bucket_full)
		{
			f16vec2 closest_bc;
			f16vec3 closest_normal;

			state.active_threads_factor  += ballot_count(true);
			state.active_threads_samples += 1;

			// NOTE(review): neither call below matches the current findClosestBucket2 signature (see header)
			#ifndef USE_LINKED_LISTS
			int bucket_offset = int(in_buckets.offsets[icell_idx]);
			findClosestBucket2(bucket_offset, bucket_size, state.origin, state.dir, rt, skip_fi, closest_fi, closest_it, closest_normal);
			//findClosestBucket(icell_idx, state.origin, state.dir, skip_fi, closest_fi, closest_it, state.normal);
			#else
			int face_tests = 0;
			findClosestBucket2(state.user_data, icell_idx, bucket_full, state.origin, state.dir, skip_fi, rt, closest_fi, closest_it, closest_normal, closest_bc, face_tests);
			#endif

			if (closest_fi != -1)
			{
				state.bounces += int16_t(1);
				state.normal   = closest_normal;

				hit     = true;
				skip_fi = closest_fi;

				// material flags are looked up through the material table in this variant
				int hit_material = rt_get_triangle_material(closest_fi);
				if (state.bounces >= max_bounces || (materials.material_properties[hit_material].flags & MaterialFlag_RaytraceTerminate) != 0)
				{
					state.running = false;
					//state.hit = true;
					//break;
				}

				f16vec2 bc = barycentric_for_face_yz(closest_fi, state.origin + state.dir * closest_it);
				if ((materials.material_properties[hit_material].flags & MaterialFlag_Flat) == 0)
					state.normal = interpolate_normal_from_bc_yz(closest_fi, bc, state.normal);

				bool flip_normal_on_glass = false;
				if ((materials.material_properties[hit_material].flags & MaterialFlag_Transparent) != 0)
				{
					// align normal along the current view dir
					flip_normal_on_glass = dot(state.normal, state.dir) > 0.0 ? true : false;
				}

				// calculate new origin and recalculate tracing parameters for the bounce
				vec3 prev_state_origin = state.origin; // for lighting calculation
				state.origin = state.origin + state.dir * closest_it;

				// we hit solid object which is not perfectly rough, reflect
				// NOTE(review): argument order differs from the evaluate_material prototype (see header)
				evaluate_material(state, prev_state_origin, int16_t(rt_get_triangle_material(closest_fi)), closest_fi, bc, flip_normal_on_glass);

				//state.origin += state.dir * 50.2;

				// restart the DDA from the hit point, keeping the current cell size
				if (max_bounces > 1 && state.running)
					build_dda(dda, state.origin, state.dir, 0.0, traversal_params.trace_range_secondary, true);

				tmax = traversal_params.trace_range_secondary;
			}
		}

		// the ray was inside the grid earlier and is outside now: it has left the grid
		if (inside == false && prev_inside == true)
		{
			state.running = false;
			state.left    = true;
		}

		if (!hit)
		{
			step_dda(dda);
			if (dda_is_abort(dda))
				state.running = false;
		}
	} 

	return 0;
} 
#endif
