//===-- metal_geometric ----------------------------------------------------===//
// Copyright (c) 2014 Apple Inc. All rights reserved
//===----------------------------------------------------------------------===//

#ifndef __METAL_GEOMETRIC
#define __METAL_GEOMETRIC

#include <metal_math>

namespace metal {
// 5.7 Geometric Functions
  

  // Dot product of two half2 vectors, bound via __asm to "air.dot.v2f16".
  METAL_ASM half dot(vec<half,2> x, vec<half,2> y) __asm("air.dot.v2f16");

  // Squared Euclidean length of x, i.e. dot(x, x).
  METAL_FUNC half length_squared(vec<half,2> x) {
    return dot(x, x);
  }

  // Euclidean length of x: square root of the squared length.
  METAL_FUNC half length(vec<half,2> x) {
    const half len_sq = length_squared(x);
    return sqrt(len_sq);
  }

  // x scaled by the reciprocal square root of its squared length.
  // There is no special handling for zero-length or extreme inputs.
  METAL_FUNC vec<half,2> normalize(vec<half,2> x) {
    const half inv_len = rsqrt(length_squared(x));
    return x * inv_len;
  }

  // Euclidean distance between x and y: length(x - y).
  METAL_FUNC half distance(vec<half,2> x, vec<half,2> y) {
    const vec<half,2> delta = x - y;
    return length(delta);
  }

  // Squared Euclidean distance between x and y: length_squared(x - y).
  METAL_FUNC half distance_squared(vec<half,2> x, vec<half,2> y) {
    const vec<half,2> delta = x - y;
    return length_squared(delta);
  }

  // Returns N when dot(Nref, I) is negative, otherwise -N.
  // A NaN dot product compares false and therefore yields -N.
  METAL_FUNC vec<half,2> faceforward(vec<half,2> N, vec<half,2> I, vec<half,2> Nref) {
    if (dot(Nref, I) < half(0.0))
      return N;
    return -N;
  }

  // Reflects I about N: I - 2 * dot(N, I) * N.
  METAL_FUNC vec<half,2> reflect(vec<half,2> I, vec<half,2> N) {
    const half n_dot_i = dot(N, I);
    return I - half(2) * n_dot_i * N;
  }

  // Refracts I through a surface with normal N, where eta is presumably the
  // ratio of refractive indices. With k = 1 - eta^2 * (1 - dot(N, I)^2),
  // returns the zero vector when k < 0 (total internal reflection) and
  // eta * I - (eta * dot(N, I) + sqrt(k)) * N otherwise. I and N are not
  // normalized here.
  METAL_FUNC vec<half,2> refract(vec<half,2> I, vec<half,2> N, half eta) {
    const half n_dot_i = dot(N, I);
    const half k = half(1.0) - eta * eta * (half(1.0) - n_dot_i * n_dot_i);
    if (k < half(0.0))
      return vec<half,2>(0.0);
    return eta * I - (eta * n_dot_i + sqrt(k)) * N;
  }


  // Dot product of two half3 vectors, bound via __asm to "air.dot.v3f16".
  METAL_ASM half dot(vec<half,3> x, vec<half,3> y) __asm("air.dot.v3f16");

  // Squared Euclidean length of x, i.e. dot(x, x).
  METAL_FUNC half length_squared(vec<half,3> x) {
    return dot(x, x);
  }

  // Euclidean length of x: square root of the squared length.
  METAL_FUNC half length(vec<half,3> x) {
    const half len_sq = length_squared(x);
    return sqrt(len_sq);
  }

  // x scaled by the reciprocal square root of its squared length.
  // There is no special handling for zero-length or extreme inputs.
  METAL_FUNC vec<half,3> normalize(vec<half,3> x) {
    const half inv_len = rsqrt(length_squared(x));
    return x * inv_len;
  }

  // Euclidean distance between x and y: length(x - y).
  METAL_FUNC half distance(vec<half,3> x, vec<half,3> y) {
    const vec<half,3> delta = x - y;
    return length(delta);
  }

  // Squared Euclidean distance between x and y: length_squared(x - y).
  METAL_FUNC half distance_squared(vec<half,3> x, vec<half,3> y) {
    const vec<half,3> delta = x - y;
    return length_squared(delta);
  }

  // Returns N when dot(Nref, I) is negative, otherwise -N.
  // A NaN dot product compares false and therefore yields -N.
  METAL_FUNC vec<half,3> faceforward(vec<half,3> N, vec<half,3> I, vec<half,3> Nref) {
    if (dot(Nref, I) < half(0.0))
      return N;
    return -N;
  }

  // Reflects I about N: I - 2 * dot(N, I) * N.
  METAL_FUNC vec<half,3> reflect(vec<half,3> I, vec<half,3> N) {
    const half n_dot_i = dot(N, I);
    return I - half(2) * n_dot_i * N;
  }

  // Refracts I through a surface with normal N, where eta is presumably the
  // ratio of refractive indices. With k = 1 - eta^2 * (1 - dot(N, I)^2),
  // returns the zero vector when k < 0 (total internal reflection) and
  // eta * I - (eta * dot(N, I) + sqrt(k)) * N otherwise. I and N are not
  // normalized here.
  METAL_FUNC vec<half,3> refract(vec<half,3> I, vec<half,3> N, half eta) {
    const half n_dot_i = dot(N, I);
    const half k = half(1.0) - eta * eta * (half(1.0) - n_dot_i * n_dot_i);
    if (k < half(0.0))
      return vec<half,3>(0.0);
    return eta * I - (eta * n_dot_i + sqrt(k)) * N;
  }


  // Dot product of two half4 vectors, bound via __asm to "air.dot.v4f16".
  METAL_ASM half dot(vec<half,4> x, vec<half,4> y) __asm("air.dot.v4f16");

  // Squared Euclidean length of x, i.e. dot(x, x).
  METAL_FUNC half length_squared(vec<half,4> x) {
    return dot(x, x);
  }

  // Euclidean length of x: square root of the squared length.
  METAL_FUNC half length(vec<half,4> x) {
    const half len_sq = length_squared(x);
    return sqrt(len_sq);
  }

  // x scaled by the reciprocal square root of its squared length.
  // There is no special handling for zero-length or extreme inputs.
  METAL_FUNC vec<half,4> normalize(vec<half,4> x) {
    const half inv_len = rsqrt(length_squared(x));
    return x * inv_len;
  }

  // Euclidean distance between x and y: length(x - y).
  METAL_FUNC half distance(vec<half,4> x, vec<half,4> y) {
    const vec<half,4> delta = x - y;
    return length(delta);
  }

  // Squared Euclidean distance between x and y: length_squared(x - y).
  METAL_FUNC half distance_squared(vec<half,4> x, vec<half,4> y) {
    const vec<half,4> delta = x - y;
    return length_squared(delta);
  }

  // Returns N when dot(Nref, I) is negative, otherwise -N.
  // A NaN dot product compares false and therefore yields -N.
  METAL_FUNC vec<half,4> faceforward(vec<half,4> N, vec<half,4> I, vec<half,4> Nref) {
    if (dot(Nref, I) < half(0.0))
      return N;
    return -N;
  }

  // Reflects I about N: I - 2 * dot(N, I) * N.
  METAL_FUNC vec<half,4> reflect(vec<half,4> I, vec<half,4> N) {
    const half n_dot_i = dot(N, I);
    return I - half(2) * n_dot_i * N;
  }

  // Refracts I through a surface with normal N, where eta is presumably the
  // ratio of refractive indices. With k = 1 - eta^2 * (1 - dot(N, I)^2),
  // returns the zero vector when k < 0 (total internal reflection) and
  // eta * I - (eta * dot(N, I) + sqrt(k)) * N otherwise. I and N are not
  // normalized here.
  METAL_FUNC vec<half,4> refract(vec<half,4> I, vec<half,4> N, half eta) {
    const half n_dot_i = dot(N, I);
    const half k = half(1.0) - eta * eta * (half(1.0) - n_dot_i * n_dot_i);
    if (k < half(0.0))
      return vec<half,4>(0.0);
    return eta * I - (eta * n_dot_i + sqrt(k)) * N;
  }


  // Cross product x × y of two 3-component half vectors, computed
  // component by component.
  METAL_FUNC vec<half,3> cross(vec<half,3> x, vec<half,3> y) {
    const half cx = x[1] * y[2] - y[1] * x[2];
    const half cy = x[2] * y[0] - y[2] * x[0];
    const half cz = x[0] * y[1] - y[0] * x[1];
    return vec<half,3>(cx, cy, cz);
  }

  

  // Dot product of two float2 vectors, bound via __asm to "air.dot.v2f32".
  METAL_ASM float dot(vec<float,2> x, vec<float,2> y) __asm("air.dot.v2f32");

  // Squared Euclidean length of x, i.e. dot(x, x). Shared by the fast and
  // precise variants below.
  METAL_FUNC float length_squared(vec<float,2> x) {
    return dot(x, x);
  }

  namespace fast {
    // Euclidean length as sqrt(dot(x, x)), with no overflow/underflow
    // rescaling.
    METAL_FUNC float length(vec<float,2> x) {
      const float len_sq = length_squared(x);
      return sqrt(len_sq);
    }

    // x multiplied by rsqrt of its squared length. There are no special
    // cases for zero, infinite, or extreme-magnitude inputs.
    METAL_FUNC vec<float,2> normalize(vec<float,2> x) {
      const float inv_len = rsqrt(length_squared(x));
      return x * inv_len;
    }

    // Distance between x and y via fast::length(x - y).
    METAL_FUNC float distance(vec<float,2> x, vec<float,2> y) {
      const vec<float,2> delta = x - y;
      return length(delta);
    }
  }
  
  namespace precise {
  	// Euclidean length of x. The squared length is computed first. If it
  	// overflowed to infinity, or is small enough that underflow may have
  	// cost precision, x is rescaled by a power of two and the squared length
  	// is recomputed. The root is then scaled back by the inverse factor.
  	METAL_FUNC float length(vec<float,2> x) {
		float lenSq = length_squared(x);
		if (isinf(lenSq))
		{
			// Overflow (or an infinite component): shrink x by 2^-66, then
			// scale the root back up by 2^66. An infinite component keeps
			// lenSq infinite, so the result is +inf.
			x *= float(0x1.0p-66);
			lenSq = length_squared(x);
			return float(0x1.0p+66) * sqrt(lenSq);
		}
		else if (lenSq < float(FLT_MIN)/float(FLT_EPSILON))
		{
			// Tiny magnitude (lenSq below FLT_MIN/FLT_EPSILON): grow x by
			// 2^64, then scale the root back down by 2^-64.
			x *= float(0x1.0p+64);
			lenSq = length_squared(x);
			return float(0x1.0p-64) * sqrt(lenSq);
		}
		return sqrt(lenSq);
  	}
  
  	// Unit vector in the direction of x, using the same power-of-two
  	// rescaling as precise::length. Special cases handled below:
  	//   - any infinite component: every result component is the quiet NaN
  	//     with bit pattern 0x7fc00000;
  	//   - all-zero x: returned unchanged, via the small-magnitude branch;
  	//   - a NaN component: propagates through rsqrt into every component.
  	METAL_FUNC vec<float,2> normalize(vec<float,2> x) {
		if (any(isinf(x)))
			return vec<float,2>(as_type<float>((uint)(0x7fc00000)));
		
		float lenSq = length_squared(x);
		if (isinf(lenSq))
		{
			// x is finite here, but its squared length overflowed. Shrink x
			// by 2^-66 so the squared length becomes representable. The scale
			// factor cancels in the final x * rsqrt(lenSq).
			x *= float(0x1.0p-66);
			lenSq = length_squared(x);
			if (isinf(lenSq))
			{
				// NOTE(review): this branch appears unreachable. Infinite
				// components already returned above, and finite components
				// scaled by 2^-66 cannot overflow the squared length. Also,
				// r (+/-1 where x is infinite, 0 elsewhere) is returned
				// without renormalization. Confirm the intent.
				bool2 Ts = isinf(x);
				float2 r = select(vec<float,2>(0), vec<float,2>(1.0), Ts);
				return copysign(r, x);
			}
		}
		else if (lenSq < float(FLT_MIN)/float(FLT_EPSILON))
		{
			// Tiny magnitude: grow x by 2^64. The factor cancels in the final
			// x * rsqrt(lenSq). If the squared length is still zero, x is
			// returned as-is; for an all-zero input that is the zero vector.
			// NOTE(review): if denormals are not flushed, a tiny nonzero x
			// whose scaled squares still underflow is returned scaled by 2^64
			// rather than normalized. Confirm against the target's denormal
			// behavior.
			x *= float(0x1.0p+64);
			lenSq = length_squared(x);
			if (lenSq == float(0.0))
				return x;
		}
		return x * rsqrt(lenSq);
  	}
  
  	// Distance between x and y via precise::length(x - y).
  	METAL_FUNC float distance(vec<float,2> x, vec<float,2> y) {
  		return length(x - y);
  	}
  }
  
  // Euclidean length of x. Resolves to precise::length unless the shader is
  // compiled with __FAST_MATH__ defined, in which case fast::length is used.
  METAL_FUNC float length(vec<float,2> x) {
#if !defined(__FAST_MATH__)
    return precise::length(x);
#else
    return fast::length(x);
#endif
  }

  // Unit vector in the direction of x. Selects precise::normalize or, under
  // __FAST_MATH__, fast::normalize.
  METAL_FUNC vec<float,2> normalize(vec<float,2> x) {
#if !defined(__FAST_MATH__)
    return precise::normalize(x);
#else
    return fast::normalize(x);
#endif
  }

  // Distance between x and y. Selects precise::distance or, under
  // __FAST_MATH__, fast::distance.
  METAL_FUNC float distance(vec<float,2> x, vec<float,2> y) {
#if !defined(__FAST_MATH__)
    return precise::distance(x, y);
#else
    return fast::distance(x, y);
#endif
  }

  // Squared distance between x and y. It is the same in both math modes.
  METAL_FUNC float distance_squared(vec<float,2> x, vec<float,2> y) {
    const vec<float,2> delta = x - y;
    return length_squared(delta);
  }

  // Returns N when dot(Nref, I) is negative, otherwise -N.
  // A NaN dot product compares false and therefore yields -N.
  METAL_FUNC vec<float,2> faceforward(vec<float,2> N, vec<float,2> I, vec<float,2> Nref) {
    if (dot(Nref, I) < float(0.0))
      return N;
    return -N;
  }

  // Reflects I about N: I - 2 * dot(N, I) * N.
  METAL_FUNC vec<float,2> reflect(vec<float,2> I, vec<float,2> N) {
    const float n_dot_i = dot(N, I);
    return I - float(2) * n_dot_i * N;
  }

  // Refracts I through a surface with normal N, where eta is presumably the
  // ratio of refractive indices. With k = 1 - eta^2 * (1 - dot(N, I)^2),
  // returns the zero vector when k < 0 (total internal reflection) and
  // eta * I - (eta * dot(N, I) + sqrt(k)) * N otherwise. I and N are not
  // normalized here.
  METAL_FUNC vec<float,2> refract(vec<float,2> I, vec<float,2> N, float eta) {
    const float n_dot_i = dot(N, I);
    const float k = float(1.0) - eta * eta * (float(1.0) - n_dot_i * n_dot_i);
    if (k < float(0.0))
      return vec<float,2>(0.0);
    return eta * I - (eta * n_dot_i + sqrt(k)) * N;
  }


  // Dot product of two float3 vectors, bound via __asm to "air.dot.v3f32".
  METAL_ASM float dot(vec<float,3> x, vec<float,3> y) __asm("air.dot.v3f32");

  // Squared Euclidean length of x, i.e. dot(x, x). Shared by the fast and
  // precise variants below.
  METAL_FUNC float length_squared(vec<float,3> x) {
    return dot(x, x);
  }

  namespace fast {
    // Euclidean length as sqrt(dot(x, x)), with no overflow/underflow
    // rescaling.
    METAL_FUNC float length(vec<float,3> x) {
      const float len_sq = length_squared(x);
      return sqrt(len_sq);
    }

    // x multiplied by rsqrt of its squared length. There are no special
    // cases for zero, infinite, or extreme-magnitude inputs.
    METAL_FUNC vec<float,3> normalize(vec<float,3> x) {
      const float inv_len = rsqrt(length_squared(x));
      return x * inv_len;
    }

    // Distance between x and y via fast::length(x - y).
    METAL_FUNC float distance(vec<float,3> x, vec<float,3> y) {
      const vec<float,3> delta = x - y;
      return length(delta);
    }
  }
  
  namespace precise {
  	// Euclidean length of x. The squared length is computed first. If it
  	// overflowed to infinity, or is small enough that underflow may have
  	// cost precision, x is rescaled by a power of two and the squared length
  	// is recomputed. The root is then scaled back by the inverse factor.
  	METAL_FUNC float length(vec<float,3> x) {
		float lenSq = length_squared(x);
		if (isinf(lenSq))
		{
			// Overflow (or an infinite component): shrink x by 2^-66, then
			// scale the root back up by 2^66. An infinite component keeps
			// lenSq infinite, so the result is +inf.
			x *= float(0x1.0p-66);
			lenSq = length_squared(x);
			return float(0x1.0p+66) * sqrt(lenSq);
		}
		else if (lenSq < float(FLT_MIN)/float(FLT_EPSILON))
		{
			// Tiny magnitude (lenSq below FLT_MIN/FLT_EPSILON): grow x by
			// 2^64, then scale the root back down by 2^-64.
			x *= float(0x1.0p+64);
			lenSq = length_squared(x);
			return float(0x1.0p-64) * sqrt(lenSq);
		}
		return sqrt(lenSq);
  	}
  
  	// Unit vector in the direction of x, using the same power-of-two
  	// rescaling as precise::length. Special cases handled below:
  	//   - any infinite component: every result component is the quiet NaN
  	//     with bit pattern 0x7fc00000;
  	//   - all-zero x: returned unchanged, via the small-magnitude branch;
  	//   - a NaN component: propagates through rsqrt into every component.
  	METAL_FUNC vec<float,3> normalize(vec<float,3> x) {
		if (any(isinf(x)))
			return vec<float,3>(as_type<float>((uint)(0x7fc00000)));
		
		float lenSq = length_squared(x);
		if (isinf(lenSq))
		{
			// x is finite here, but its squared length overflowed. Shrink x
			// by 2^-66 so the squared length becomes representable. The scale
			// factor cancels in the final x * rsqrt(lenSq).
			x *= float(0x1.0p-66);
			lenSq = length_squared(x);
			if (isinf(lenSq))
			{
				// NOTE(review): this branch appears unreachable. Infinite
				// components already returned above, and finite components
				// scaled by 2^-66 cannot overflow the squared length. Also,
				// r (+/-1 where x is infinite, 0 elsewhere) is returned
				// without renormalization. Confirm the intent.
				bool3 Ts = isinf(x);
				float3 r = select(vec<float,3>(0), vec<float,3>(1.0), Ts);
				return copysign(r, x);
			}
		}
		else if (lenSq < float(FLT_MIN)/float(FLT_EPSILON))
		{
			// Tiny magnitude: grow x by 2^64. The factor cancels in the final
			// x * rsqrt(lenSq). If the squared length is still zero, x is
			// returned as-is; for an all-zero input that is the zero vector.
			// NOTE(review): if denormals are not flushed, a tiny nonzero x
			// whose scaled squares still underflow is returned scaled by 2^64
			// rather than normalized. Confirm against the target's denormal
			// behavior.
			x *= float(0x1.0p+64);
			lenSq = length_squared(x);
			if (lenSq == float(0.0))
				return x;
		}
		return x * rsqrt(lenSq);
  	}
  
  	// Distance between x and y via precise::length(x - y).
  	METAL_FUNC float distance(vec<float,3> x, vec<float,3> y) {
  		return length(x - y);
  	}
  }
  
  // Euclidean length of x. Resolves to precise::length unless the shader is
  // compiled with __FAST_MATH__ defined, in which case fast::length is used.
  METAL_FUNC float length(vec<float,3> x) {
#if !defined(__FAST_MATH__)
    return precise::length(x);
#else
    return fast::length(x);
#endif
  }

  // Unit vector in the direction of x. Selects precise::normalize or, under
  // __FAST_MATH__, fast::normalize.
  METAL_FUNC vec<float,3> normalize(vec<float,3> x) {
#if !defined(__FAST_MATH__)
    return precise::normalize(x);
#else
    return fast::normalize(x);
#endif
  }

  // Distance between x and y. Selects precise::distance or, under
  // __FAST_MATH__, fast::distance.
  METAL_FUNC float distance(vec<float,3> x, vec<float,3> y) {
#if !defined(__FAST_MATH__)
    return precise::distance(x, y);
#else
    return fast::distance(x, y);
#endif
  }

  // Squared distance between x and y. It is the same in both math modes.
  METAL_FUNC float distance_squared(vec<float,3> x, vec<float,3> y) {
    const vec<float,3> delta = x - y;
    return length_squared(delta);
  }

  // Returns N when dot(Nref, I) is negative, otherwise -N.
  // A NaN dot product compares false and therefore yields -N.
  METAL_FUNC vec<float,3> faceforward(vec<float,3> N, vec<float,3> I, vec<float,3> Nref) {
    if (dot(Nref, I) < float(0.0))
      return N;
    return -N;
  }

  // Reflects I about N: I - 2 * dot(N, I) * N.
  METAL_FUNC vec<float,3> reflect(vec<float,3> I, vec<float,3> N) {
    const float n_dot_i = dot(N, I);
    return I - float(2) * n_dot_i * N;
  }

  // Refracts I through a surface with normal N, where eta is presumably the
  // ratio of refractive indices. With k = 1 - eta^2 * (1 - dot(N, I)^2),
  // returns the zero vector when k < 0 (total internal reflection) and
  // eta * I - (eta * dot(N, I) + sqrt(k)) * N otherwise. I and N are not
  // normalized here.
  METAL_FUNC vec<float,3> refract(vec<float,3> I, vec<float,3> N, float eta) {
    const float n_dot_i = dot(N, I);
    const float k = float(1.0) - eta * eta * (float(1.0) - n_dot_i * n_dot_i);
    if (k < float(0.0))
      return vec<float,3>(0.0);
    return eta * I - (eta * n_dot_i + sqrt(k)) * N;
  }


  // Dot product of two float4 vectors, bound via __asm to "air.dot.v4f32".
  METAL_ASM float dot(vec<float,4> x, vec<float,4> y) __asm("air.dot.v4f32");

  // Squared Euclidean length of x, i.e. dot(x, x). Shared by the fast and
  // precise variants below.
  METAL_FUNC float length_squared(vec<float,4> x) {
    return dot(x, x);
  }

  namespace fast {
    // Euclidean length as sqrt(dot(x, x)), with no overflow/underflow
    // rescaling.
    METAL_FUNC float length(vec<float,4> x) {
      const float len_sq = length_squared(x);
      return sqrt(len_sq);
    }

    // x multiplied by rsqrt of its squared length. There are no special
    // cases for zero, infinite, or extreme-magnitude inputs.
    METAL_FUNC vec<float,4> normalize(vec<float,4> x) {
      const float inv_len = rsqrt(length_squared(x));
      return x * inv_len;
    }

    // Distance between x and y via fast::length(x - y).
    METAL_FUNC float distance(vec<float,4> x, vec<float,4> y) {
      const vec<float,4> delta = x - y;
      return length(delta);
    }
  }
  
  namespace precise {
  	// Euclidean length of x. The squared length is computed first. If it
  	// overflowed to infinity, or is small enough that underflow may have
  	// cost precision, x is rescaled by a power of two and the squared length
  	// is recomputed. The root is then scaled back by the inverse factor.
  	METAL_FUNC float length(vec<float,4> x) {
		float lenSq = length_squared(x);
		if (isinf(lenSq))
		{
			// Overflow (or an infinite component): shrink x by 2^-66, then
			// scale the root back up by 2^66. An infinite component keeps
			// lenSq infinite, so the result is +inf.
			x *= float(0x1.0p-66);
			lenSq = length_squared(x);
			return float(0x1.0p+66) * sqrt(lenSq);
		}
		else if (lenSq < float(FLT_MIN)/float(FLT_EPSILON))
		{
			// Tiny magnitude (lenSq below FLT_MIN/FLT_EPSILON): grow x by
			// 2^64, then scale the root back down by 2^-64.
			x *= float(0x1.0p+64);
			lenSq = length_squared(x);
			return float(0x1.0p-64) * sqrt(lenSq);
		}
		return sqrt(lenSq);
  	}
  
  	// Unit vector in the direction of x, using the same power-of-two
  	// rescaling as precise::length. Special cases handled below:
  	//   - any infinite component: every result component is the quiet NaN
  	//     with bit pattern 0x7fc00000;
  	//   - all-zero x: returned unchanged, via the small-magnitude branch;
  	//   - a NaN component: propagates through rsqrt into every component.
  	METAL_FUNC vec<float,4> normalize(vec<float,4> x) {
		if (any(isinf(x)))
			return vec<float,4>(as_type<float>((uint)(0x7fc00000)));
		
		float lenSq = length_squared(x);
		if (isinf(lenSq))
		{
			// x is finite here, but its squared length overflowed. Shrink x
			// by 2^-66 so the squared length becomes representable. The scale
			// factor cancels in the final x * rsqrt(lenSq).
			x *= float(0x1.0p-66);
			lenSq = length_squared(x);
			if (isinf(lenSq))
			{
				// NOTE(review): this branch appears unreachable. Infinite
				// components already returned above, and finite components
				// scaled by 2^-66 cannot overflow the squared length (4
				// terms each below 2^124). Also, r (+/-1 where x is infinite,
				// 0 elsewhere) is returned without renormalization. Confirm
				// the intent.
				bool4 Ts = isinf(x);
				float4 r = select(vec<float,4>(0), vec<float,4>(1.0), Ts);
				return copysign(r, x);
			}
		}
		else if (lenSq < float(FLT_MIN)/float(FLT_EPSILON))
		{
			// Tiny magnitude: grow x by 2^64. The factor cancels in the final
			// x * rsqrt(lenSq). If the squared length is still zero, x is
			// returned as-is; for an all-zero input that is the zero vector.
			// NOTE(review): if denormals are not flushed, a tiny nonzero x
			// whose scaled squares still underflow is returned scaled by 2^64
			// rather than normalized. Confirm against the target's denormal
			// behavior.
			x *= float(0x1.0p+64);
			lenSq = length_squared(x);
			if (lenSq == float(0.0))
				return x;
		}
		return x * rsqrt(lenSq);
  	}
  
  	// Distance between x and y via precise::length(x - y).
  	METAL_FUNC float distance(vec<float,4> x, vec<float,4> y) {
  		return length(x - y);
  	}
  }
  
  // Euclidean length of x. Resolves to precise::length unless the shader is
  // compiled with __FAST_MATH__ defined, in which case fast::length is used.
  METAL_FUNC float length(vec<float,4> x) {
#if !defined(__FAST_MATH__)
    return precise::length(x);
#else
    return fast::length(x);
#endif
  }

  // Unit vector in the direction of x. Selects precise::normalize or, under
  // __FAST_MATH__, fast::normalize.
  METAL_FUNC vec<float,4> normalize(vec<float,4> x) {
#if !defined(__FAST_MATH__)
    return precise::normalize(x);
#else
    return fast::normalize(x);
#endif
  }

  // Distance between x and y. Selects precise::distance or, under
  // __FAST_MATH__, fast::distance.
  METAL_FUNC float distance(vec<float,4> x, vec<float,4> y) {
#if !defined(__FAST_MATH__)
    return precise::distance(x, y);
#else
    return fast::distance(x, y);
#endif
  }

  // Squared distance between x and y. It is the same in both math modes.
  METAL_FUNC float distance_squared(vec<float,4> x, vec<float,4> y) {
    const vec<float,4> delta = x - y;
    return length_squared(delta);
  }

  // Returns N when dot(Nref, I) is negative, otherwise -N.
  // A NaN dot product compares false and therefore yields -N.
  METAL_FUNC vec<float,4> faceforward(vec<float,4> N, vec<float,4> I, vec<float,4> Nref) {
    if (dot(Nref, I) < float(0.0))
      return N;
    return -N;
  }

  // Reflects I about N: I - 2 * dot(N, I) * N.
  METAL_FUNC vec<float,4> reflect(vec<float,4> I, vec<float,4> N) {
    const float n_dot_i = dot(N, I);
    return I - float(2) * n_dot_i * N;
  }

  // Refracts I through a surface with normal N, where eta is presumably the
  // ratio of refractive indices. With k = 1 - eta^2 * (1 - dot(N, I)^2),
  // returns the zero vector when k < 0 (total internal reflection) and
  // eta * I - (eta * dot(N, I) + sqrt(k)) * N otherwise. I and N are not
  // normalized here.
  METAL_FUNC vec<float,4> refract(vec<float,4> I, vec<float,4> N, float eta) {
    const float n_dot_i = dot(N, I);
    const float k = float(1.0) - eta * eta * (float(1.0) - n_dot_i * n_dot_i);
    if (k < float(0.0))
      return vec<float,4>(0.0);
    return eta * I - (eta * n_dot_i + sqrt(k)) * N;
  }


  // Cross product x × y of two 3-component float vectors, computed
  // component by component.
  METAL_FUNC vec<float,3> cross(vec<float,3> x, vec<float,3> y) {
    const float cx = x[1] * y[2] - y[1] * x[2];
    const float cy = x[2] * y[0] - y[2] * x[0];
    const float cz = x[0] * y[1] - y[0] * x[1];
    return vec<float,3>(cx, cy, cz);
  }



} // namespace metal

#endif // __METAL_GEOMETRIC
