#version 450

// Workgroup size: 64 invocations along x. main() uses gl_GlobalInvocationID.x as the
// particle index, so each invocation handles at most one particle.
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;

// Per-dispatch uniform parameters shared by all invocations.
layout (set = 0, binding = 0) uniform Context {
    // Time step used to integrate velocity, gravity and particle growth in this shader.
    float g_simulation_step_seconds;
    // Not read by any active code in this file (only by commented-out code).
    // NOTE(review): presumably elapsed application time — confirm against the host side.
    float g_app_time;

    // Further uniform members are pulled in textually from sdf_params.glsl
    // (contents not visible here).
    #include "sdf_params.glsl"
};

#include "particles.glsl"

// Particle state that main() reads as input for this step. It is not declared readonly:
// main() also writes a respawned particle back into it when a particle is reset.
// The buffer length determines how many particles are simulated.
layout (set = 0, binding = 1) buffer ParticlesCurrent {
    Particle buf[];
} particles_current;

// Output particle state produced by this step; write-only from this shader.
// main() indexes it with the same index as particles_current.
// NOTE(review): assumed to be at least as long as particles_current — confirm on the host side.
layout (set = 0, binding = 2) buffer writeonly ParticlesNext {
    Particle buf[];
} particles_next;

// Acceleration added to the velocity every step (scaled by g_simulation_step_seconds).
const vec3 GRAVITY = vec3(0, -10.0, 0);
// Constant amount subtracted from the speed along the surface normal on each bounce.
const float BOUNCE_VELOCITY_LOSS = 0.01;
// Factor applied to the speed along the surface normal on each bounce.
const float BOUNCE_HARDNESS = 0.6;
// Upper bound on the distance a particle is pushed along the normal per push-out iteration.
const float MAX_OUTWARD_MOVEMENT = 1000;
// Extra push-out distance (scaled by g_simulation_step_seconds) added on top of the
// penetration depth when a particle starts a step inside the surface.
const float MIN_OUTWARD_MOVEMENT = 1;
// Rate at which the particle size grows (scaled by g_simulation_step_seconds);
// main() clamps the resulting size to [0, 1].
const float PARTICLE_GROWTH = 1;
// Maximum speed; sim_step() rescales faster velocities down to this magnitude.
const float MAX_VELOCITY = 50.5;

#include "sdf_function.glsl"
#include "sdf_utils.glsl"

// Advances one particle by a single simulation step against the scene SDF.
//
// pos: in/out particle position; replaced with the position after the step.
// vel: in/out particle velocity; replaced with the (possibly reflected and
//      speed-clamped) velocity after the step.
//
// Steps:
//   1. If the particle starts inside the surface (sdf < 0), push it along the SDF
//      normal until it is outside or the iteration budget is exhausted.
//   2. Move it by vel * g_simulation_step_seconds.
//   3. If it ends up inside the surface, either reflect the velocity at the estimated
//      crossing point (the step started outside) or push it outward again (the
//      initial push-out did not succeed).
//   4. Clamp the speed to MAX_VELOCITY.
void sim_step(inout vec3 pos, inout vec3 vel) {
    // Bounds both push-out loops so they always terminate.
    const int MAX_PUSH_ITERATIONS = 100;

    float sd_val = sdf(pos);
    vec3 new_pos = pos;

    if (sd_val < 0.0) {
        // Starting inside the surface: walk outward along the normal. Each iteration moves
        // by the current penetration depth plus a small time-scaled margin.
        for (int i = 0; i < MAX_PUSH_ITERATIONS; i++) {
            vec3 normal = calculate_sdf_normal(new_pos);
            float movement = min(-sd_val + MIN_OUTWARD_MOVEMENT * g_simulation_step_seconds, MAX_OUTWARD_MOVEMENT);
            new_pos += normal * movement;
            sd_val = sdf(new_pos);
            if (sd_val >= 0.0) {
                break;
            }
        }
    }

    // Position where this step's motion begins. sd_val is the SDF value at this point,
    // so the crossing-point interpolation below must start here and not at the original
    // (possibly penetrating) pos.
    vec3 start_pos = new_pos;

    new_pos += vel * g_simulation_step_seconds;
    vec3 new_vel = vel;
    float new_sd_val = sdf(new_pos);

    if (new_sd_val < 0.0) {
        // Collision.
        if (sd_val >= 0.0) {
            // The step started outside the surface: reflect.
            // Linear estimate of the fraction of the step at which the SDF crosses zero.
            // sd_val >= 0 and new_sd_val < 0, so the denominator is strictly positive.
            float ratio = sd_val / (sd_val - new_sd_val);

            vec3 intersection = start_pos + (new_pos - start_pos) * ratio;
            vec3 normal = calculate_sdf_normal(intersection);

            // Split the velocity into its component along the normal and the remainder.
            float normal_speed = dot(vel, normal);
            vec3 vel_normal = normal * normal_speed;
            vec3 vel_tangent = vel - vel_normal;

            // Outgoing normal speed: damped, never negative, always pointing along +normal.
            float bounce_speed = max(0.0, abs(normal_speed) * BOUNCE_HARDNESS - BOUNCE_VELOCITY_LOSS);
            new_vel = normal * bounce_speed + vel_tangent;

            // Spend the remainder of the step travelling with the reflected velocity.
            new_pos = intersection + new_vel * g_simulation_step_seconds * (1.0 - ratio);
        } else {
            // Still inside after the initial push-out: move the particle outward by twice
            // the penetration depth per iteration and leave the velocity unchanged.
            for (int i = 0; i < MAX_PUSH_ITERATIONS; i++) {
                vec3 normal = calculate_sdf_normal(new_pos);
                float movement = min(-new_sd_val * 2.0, MAX_OUTWARD_MOVEMENT);
                new_pos += normal * movement;
                new_sd_val = sdf(new_pos);
                if (new_sd_val >= 0.0) {
                    break;
                }
            }
        }
    }

    // Clamp the speed while preserving direction.
    float vel_mag = length(new_vel);
    if (vel_mag > MAX_VELOCITY) {
        new_vel = new_vel / vel_mag * MAX_VELOCITY;
    }

    pos = new_pos;
    vel = new_vel;
}


// Compute entry point: simulates one particle per invocation.
//
// Reads the particle at gl_GlobalInvocationID.x from particles_current, applies gravity,
// runs sim_step() for collision handling, grows the particle size, and writes the result
// to particles_next. A particle whose y coordinate drops below -100 is replaced by
// init_particle(index) in both buffers instead.
void main() {
    uint num_particles = particles_current.buf.length();

    // Invocations past the end of the buffer (last, partially filled workgroup) do nothing.
    uint index = gl_GlobalInvocationID.x;
    if (index >= num_particles) {
        return;
    }

    vec3 pos = particles_current.buf[index].pos;
    vec3 vel = particles_current.buf[index].vel;
    float size = particles_current.buf[index].size;
    float pad = particles_current.buf[index]._pad;

    // Apply gravity.
    vec3 next_vel = vel + GRAVITY * g_simulation_step_seconds;
    vec3 next_pos = pos;

    // Simulate collisions.
    sim_step(next_pos, next_vel);

    // Grow the particle.
    size = clamp(size + PARTICLE_GROWTH * g_simulation_step_seconds, 0.0, 1.0);

    // Reset if the particle falls too deep.
    if (next_pos.y < -100.0) {
        Particle p = init_particle(index);
        particles_current.buf[index] = p;
        particles_next.buf[index] = p;
        // Return here so the respawned particle is not overwritten by the regular
        // write below with the fallen-through state.
        return;
    }

    particles_next.buf[index] = Particle(next_pos, next_vel, size, pad);
}


