#ifndef _VECTOR3_H
#define _VECTOR3_H

#include <assert.h>
#include <math.h>

// 3-dimensional floating point vector.
//
// Serves for points, directions and normals, and presumably also for RGB
// colors (see luminance() and self_pow()). Components are public and are
// declared in the order x, y, z.

class Vector3
{
public:
	// Leaves the components uninitialized (no zeroing cost); use zero()
	// when a zero vector is needed.
	Vector3() {}

	// Reads three floats v[0], v[1], v[2]; v must point to at least three.
	Vector3(const float* v) : x(v[0]), y(v[1]), z(v[2]) {}

	Vector3(float x, float y, float z) : x(x), y(y), z(z) {}

	// Component-wise minimum. On ties (or when a comparison involves NaN)
	// the component of b is returned.
	Vector3 min(const Vector3& b) const
	{
		return Vector3(
			(x < b.x) ? x : b.x,
			(y < b.y) ? y : b.y,
			(z < b.z) ? z : b.z);
	}

	// Component-wise maximum. On ties (or when a comparison involves NaN)
	// the component of b is returned.
	Vector3 max(const Vector3& b) const
	{
		return Vector3(
			(x > b.x) ? x : b.x,
			(y > b.y) ? y : b.y,
			(z > b.z) ? z : b.z);
	}

	// Dot product.
	float dot(const Vector3& b) const
	{
		return x * b.x + y * b.y + z * b.z;
	}

	// Euclidean length.
	float length() const
	{
		return sqrtf(dot(*this));
	}

	// Squared length (avoids the square root).
	float length2() const
	{
		return dot(*this);
	}

	// Unit vector in the same direction. A zero vector yields non-finite
	// components (division by zero length).
	Vector3 normalize() const
	{
		float m = 1.0f / length();
		return Vector3(x * m, y * m, z * m);
	}

	// Cross product (*this) x b.
	Vector3 cross(const Vector3& b) const
	{
		return Vector3(
			y * b.z - z * b.y,
			z * b.x - x * b.z,
			x * b.y - y * b.x);
	}

	// In-place component-wise multiplication; returns *this for chaining.
	Vector3& self_scale(const Vector3& s)
	{
		x *= s.x;
		y *= s.y;
		z *= s.z;
		return *this;
	}

	// Raises each component to the power p in place; returns *this.
	Vector3& self_pow(float p)
	{
		x = powf(x, p);
		y = powf(y, p);
		z = powf(z, p);
		return *this;
	}

	// Weighted sum with the ITU-R BT.601 luma weights (x, y, z read as
	// R, G, B). Float literals keep the computation in single precision
	// instead of silently promoting to double and narrowing on return.
	float luminance() const
	{
		return 0.299f * x + 0.587f * y + 0.114f * z;
	}

	Vector3 operator+(const Vector3& b) const
	{
		return Vector3(
			x + b.x,
			y + b.y,
			z + b.z);
	}

	Vector3 operator-(const Vector3& b) const
	{
		return Vector3(
			x - b.x,
			y - b.y,
			z - b.z);
	}

	Vector3 operator-() const
	{
		return Vector3(-x, -y, -z);
	}

	Vector3 operator*(float s) const
	{
		return Vector3(x * s, y * s, z * s);
	}

	Vector3 operator/(float s) const
	{
		return Vector3(x / s, y / s, z / s);
	}

	Vector3& operator+=(const Vector3& b)
	{
		x += b.x;
		y += b.y;
		z += b.z;
		return *this;
	}

	Vector3& operator*=(float s)
	{
		x *= s;
		y *= s;
		z *= s;
		return *this;
	}

	Vector3& operator/=(float s)
	{
		x /= s;
		y /= s;
		z /= s;
		return *this;
	}

	// Component access by index: 0 -> x, 1 -> y, 2 -> z.
	// The member is selected explicitly: indexing past &x is undefined
	// behavior even though x, y, z are usually laid out contiguously.
	const float& operator[](int n) const
	{
		assert(n >= 0 && n < 3);
		return (n == 0) ? x : ((n == 1) ? y : z);
	}

	float& operator[](int n)
	{
		assert(n >= 0 && n < 3);
		return (n == 0) ? x : ((n == 1) ? y : z);
	}

	// The vector (0, 0, 0).
	static inline Vector3 zero()
	{
		return Vector3(0.0f, 0.0f, 0.0f);
	}

public:
	float x, y, z;
};

// Uniform scaling: every component of a multiplied by s.
inline Vector3 scale(const Vector3& a, float s)
{
	return a * s;
}

// Component-wise (Hadamard) product of a and b.
inline Vector3 scale(const Vector3& a, const Vector3& b)
{
	Vector3 product(a);
	product.self_scale(b);
	return product;
}

// Free-function form of Vector3::normalize(): v scaled to unit length.
// A zero vector yields non-finite components.
inline Vector3 normalize(const Vector3& v)
{
	return v.normalize();
}

// Free-function form of Vector3::cross(): the cross product a x b.
inline Vector3 cross(const Vector3& a, const Vector3& b)
{
	return a.cross(b);
}


// Component-wise minimum; on ties (or NaN comparisons) b's component wins,
// exactly like Vector3::min.
inline Vector3 min(const Vector3& a, const Vector3& b)
{
	return a.min(b);
}

// Component-wise maximum. Unlike Vector3::max, ties go to a's component
// (only observable for +0/-0); a NaN comparison still selects b's component.
inline Vector3 max(const Vector3& a, const Vector3& b)
{
	const float mx = (b.x <= a.x) ? a.x : b.x;
	const float my = (b.y <= a.y) ? a.y : b.y;
	const float mz = (b.z <= a.z) ? a.z : b.z;
	return Vector3(mx, my, mz);
}

// Intersect axis-aligned bounding box.

inline bool aaboxIntersect(const Vector3& p, const Vector3& d,
                           const Vector3& min, const Vector3& max,
                           float &a, float &b)
{
	float t1;
	float t2;
	float t3;
	float t4;
	bool hit1 = false;
	bool hit2 = false;

	assert(d.x != 0.0f || d.y != 0.0f || d.z != 0.0f);

	/* TODO: branch probabilities? */

	if (d.x == 0.0f)
	{
		if (p.x < min.x || p.x > max.x)
			return false;
	}
	else if (d.x > 0.0f)
	{
		t1 = (min.x - p.x) / d.x;
		t2 = (max.x - p.x) / d.x;
		hit1 = true;
	}
	else
	{
		t1 = (max.x - p.x) / d.x;
		t2 = (min.x - p.x) / d.x;
		hit1 = true;
	}
	assert(!hit1 || t1 <= t2);

	if (d.y == 0.0f)
	{
		if (p.y < min.y || p.y > max.y)
			return false;
	}
	else if (d.y > 0.0f)
	{
		t3 = (min.y - p.y) / d.y;
		t4 = (max.y - p.y) / d.y;
		hit2 = true;
	}
	else
	{
		t3 = (max.y - p.y) / d.y;
		t4 = (min.y - p.y) / d.y;
		hit2 = true;
	}
	assert(!hit2 || t3 <= t4);

	if (hit1 && hit2)
	{
		if (t2 < t3 || t4 < t1)
			return false;

		if (t1 < t3)
			t1 = t3;
		if (t2 > t4)
			t2 = t4;

		assert(t1 <= t2);
	}

	if (!hit1 && hit2)
	{
		t1 = t3;
		t2 = t4;
		hit1 = true;
	}

	if (d.z == 0.0f)
	{
		if (p.z < min.z || p.z > max.z)
			return false;
		hit2 = false;
	}
	else if (d.z > 0.0f)
	{
		t3 = (min.z - p.z) / d.z;
		t4 = (max.z - p.z) / d.z;
		hit2 = true;
	}
	else
	{
		t3 = (max.z - p.z) / d.z;
		t4 = (min.z - p.z) / d.z;
		hit2 = true;
	}

	if (!hit1 && !hit2)
		return false;

	assert(!hit2 || t3 <= t4);

	if (hit1 && hit2)
	{
		if (t2 < t3 || t4 < t1)
			return false;

		if (t1 < t3)
			t1 = t3;
		if (t2 > t4)
			t2 = t4;

		assert(t1 <= t2);
	}

	if (!hit1 && hit2)
	{
		a = t3;
		b = t4;
	}
	else
	{
		a = t1;
		b = t2;
	}

	return true;
}

// Intersect a ray with a triangle (Moller-Trumbore algorithm).
//
// O is the ray origin and D its direction; V0, V1, V2 are the triangle's
// vertices. On a hit, returns true and stores:
//   _t - ray parameter of the hit point O + _t*D (always >= 0),
//   _u, _v - barycentric coordinates, i.e. the hit point is
//            (1 - _u - _v)*V0 + _u*V1 + _v*V2.
// On a miss, returns false and leaves the outputs untouched.
// Both windings are accepted; rays (nearly) parallel to the triangle's
// plane (|det| < 1e-5) are treated as misses.

inline bool triangleIntersect(const Vector3& O, const Vector3& D,
                              const Vector3& V0, const Vector3& V1,
                              const Vector3& V2,
                              float &_t, float &_u, float &_v)
{
	const Vector3 E1 = V1 - V0;
	const Vector3 E2 = V2 - V0;
	const Vector3 P = D.cross(E2);

	const float det = E1.dot(P);

	if (fabsf(det) < 0.00001f)
		return false;

	// The range tests below compare the unnormalized u, v, t against det to
	// postpone the division until a hit is certain. Those comparisons only
	// hold for det > 0, so for the opposite winding (det < 0) flip the sign
	// of det and of every numerator. Negation is exact, so results for
	// det > 0 are bit-identical to the single-sided form.
	const float sign = (det < 0.0f) ? -1.0f : 1.0f;
	const float absDet = det * sign;

	const Vector3 T = O - V0;

	const float u = T.dot(P) * sign;
	if (u < 0.0f || u > absDet)
		return false;

	const Vector3 Q = T.cross(E1);

	const float v = D.dot(Q) * sign;
	if (v < 0.0f || u + v > absDet)
		return false;

	const float t = E2.dot(Q) * sign;
	if (t < 0.0f)
		return false;

	const float invDet = 1.0f / absDet;

	_t = t * invDet;
	_u = u * invDet;
	_v = v * invDet;

	return true;
}

// Maps two numbers in [0, 1] to a point on the unit sphere; uniform inputs
// give points uniformly distributed over the surface (equal-area mapping:
// sy picks the height y in [-1, 1], sx the azimuth around the y axis).

inline Vector3 pointOnSphere(float sx, float sy)
{
	const float azimuth = sx * 3.14159265f * 2.0f;
	const float height = sy * 2.0f - 1.0f;
	// Radius of the horizontal circle at this height.
	const float ringRadius = sqrtf(1.0f - height * height);
	return Vector3(ringRadius * cosf(azimuth),
	               height,
	               ringRadius * sinf(azimuth));
}

// Converts spherical angles to a direction expressed in a frame built
// around n: theta is the polar angle measured from n, phi the azimuth
// around it. The result is unit length when n is.
// Whether the directions are uniformly distributed depends entirely on how
// the caller samples theta and phi; uniform theta itself does NOT give a
// uniform hemisphere. NOTE(review): originally marked "TODO: check this!".

inline Vector3 pointOnHemisphere(const Vector3& n, float theta, float phi)
{
	// Helper axis chosen to be far from parallel to n so the cross product
	// below is well conditioned.
	const Vector3 helper = (fabs(n.x) < 0.5f) ?
		Vector3(1.0f, 0.0f, 0.0f) : Vector3(0.0f, 1.0f, 0.0f);

	const Vector3 tangent = n.cross(helper).normalize();
	const Vector3 bitangent = tangent.cross(n);

	const float sinTheta = sinf(theta);
	const Vector3 alongTangent = tangent * sinTheta * cosf(phi);
	const Vector3 alongBitangent = bitangent * sinTheta * sinf(phi);
	const Vector3 alongNormal = n * cosf(theta);

	return alongTangent + alongBitangent + alongNormal;
}

// Cosine-weighted direction on the hemisphere around n (Malley's method):
// a point on the unit disk with angle 2*pi*x and radius sqrt(y) is lifted
// straight up onto the hemisphere. Presumably x and y are uniform in
// [0, 1] and n is unit length -- confirm at call sites.

inline Vector3 pointOnHemisphereCosine(const Vector3& n, float x, float y)
{
	const Vector3 helper = (fabs(n.x) < 0.5f) ?
		Vector3(1.0f, 0.0f, 0.0f) : Vector3(0.0f, 1.0f, 0.0f);

	const Vector3 tangent = n.cross(helper).normalize();
	const Vector3 bitangent = tangent.cross(n);

	// Point on the unit disk (uniform in area for uniform x, y).
	const float angle = x * (2.0f * 3.14159265f);
	const float radius = sqrt(y);
	const float diskX = cosf(angle) * radius;
	const float diskY = sinf(angle) * radius;

	// Lift onto the hemisphere; rounding may push r2 to 1 or above.
	const float r2 = diskX * diskX + diskY * diskY;
	const float height = (r2 >= 1.0f) ? 0.0f : sqrtf(1.0f - r2);

	return tangent * diskX + bitangent * diskY + n * height;
}

// Mirror-reflects the direction d about the surface normal n:
//   r = d - 2 (d . n) n
// This is a true reflection (|r| == |d|) only when n has unit length.
// The incident direction d must be the starting term; starting from n
// would yield a vector that is always parallel to the normal.

inline Vector3 reflectRay(const Vector3& d, const Vector3& n)
{
	return d - n * (2.0f * d.dot(n));
}

// Refracts the direction d_ through a surface with normal n (Snell's law).
// eta1 is the refractive index on the side the ray comes from, eta2 on the
// far side. The transmitted component is built along -n, so n is expected
// to be unit length and to face the incoming ray (d_ . n <= 0) -- callers
// presumably flip n for rays leaving a medium; confirm at call sites.
// On total internal reflection no refracted direction exists; the mirror
// reflection of d_ is returned instead of NaN components.
// TODO: give only eta1 / eta2

inline Vector3 refractRay(const Vector3& d_, const Vector3& n,
                         float eta1, float eta2)
{
	Vector3 d = -d_;
	float dot1 = d.dot(n);	// cosine of the incidence angle
	float eta = eta1 / eta2;

	// Squared cosine of the transmission angle; negative means TIR.
	float k = 1 - eta*eta * (1 - dot1*dot1);
	if (k < 0.0f)
		return d_ + n * (2.0f * dot1);	// d_ - 2 (d_ . n) n, as d_ . n == -dot1

	Vector3 v1 = (d - n * dot1) * -eta;
	Vector3 v2 = n * -sqrtf(k);

	return v1 + v2;
}

#endif
