//===-- metal_matrix -------------------------------------------------------===//
// Copyright (c) 2014 Apple Inc. All rights reserved
//===----------------------------------------------------------------------===//

#ifndef __METAL_MATRIX_H
#define __METAL_MATRIX_H

#include <metal_geometric>

namespace metal {
  // 2.2 Matrices
  // Primary template for matrix<T, numCols, numRows>.
  // It deliberately exposes nothing: its only member is a private default
  // constructor, so it cannot be default-constructed from outside the class.
  // Usable matrices come from the partial specializations (e.g. the half
  // specializations that follow). numRows defaults to numCols, so
  // matrix<T,N> names the same type as matrix<T,N,N>.
  template <typename T, int numCols, int numRows = numCols>
  class matrix {
  private:
    matrix() {}
  };


#pragma mark matrix2 half
  // matrix<half,2,numRows>: a half-precision matrix made of 2 columns, each a
  // vec<half,numRows>, held in the array `cols` (cols[c] is column c).
  // Accessors, assignments and compound operators are overloaded once per
  // Metal address space (thread, device, threadgroup, constant) of the object
  // and, where there is one, of the argument.
  template <int numRows> class matrix<half,2,numRows> {
    typedef half T;
    enum { numCols = 2 };
    vec<T,numRows> cols[numCols];

  public:
    // 2.2.2 Accessing Matrix Components
    // m[c] returns column c by reference. The index is not range-checked.
    METAL_FUNC thread vec<T,numRows>& operator[] (int c) thread { return cols[c]; }
    METAL_FUNC const thread vec<T,numRows>& operator[] (int c) thread const { return cols[c]; }
    METAL_FUNC device vec<T,numRows>& operator[] (int c) device { return cols[c]; }
    METAL_FUNC const device vec<T,numRows>& operator[] (int c) device const { return cols[c]; }
    METAL_FUNC threadgroup vec<T,numRows>& operator[] (int c) threadgroup { return cols[c]; }
    METAL_FUNC const threadgroup vec<T,numRows>& operator[] (int c) threadgroup const { return cols[c]; }
    METAL_FUNC constant vec<T,numRows>& operator[] (int c) constant { return cols[c]; }
    METAL_FUNC const constant vec<T,numRows>& operator[] (int c) constant const { return cols[c]; }

    // 2.2.4 Matrix Constructors

    // Performs no explicit initialization of the columns.
    METAL_FUNC matrix() { }

    // Diagonal matrix: every component is zero except cols[c][c] for each
    // c that is both a valid column and a valid row index, which gets val.
    // One overload per address space of val.
    METAL_FUNC explicit matrix(const thread T& val) {
      for (int c = 0; c < numCols; ++c) {
        vec<T,numRows> column = vec<T,numRows>(T(0));
        if (c < numRows)
          column[c] = val;
        cols[c] = column;
      }
    }
    METAL_FUNC explicit matrix(const device T& val) {
      for (int c = 0; c < numCols; ++c) {
        vec<T,numRows> column = vec<T,numRows>(T(0));
        if (c < numRows)
          column[c] = val;
        cols[c] = column;
      }
    }
    METAL_FUNC explicit matrix(const threadgroup T& val) {
      for (int c = 0; c < numCols; ++c) {
        vec<T,numRows> column = vec<T,numRows>(T(0));
        if (c < numRows)
          column[c] = val;
        cols[c] = column;
      }
    }
    METAL_FUNC explicit matrix(const constant T& val) {
      for (int c = 0; c < numCols; ++c) {
        vec<T,numRows> column = vec<T,numRows>(T(0));
        if (c < numRows)
          column[c] = val;
        cols[c] = column;
      }
    }

    // Column-by-column copy of a same-typed matrix from any address space.
    METAL_FUNC matrix(const thread matrix& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = static_cast< vec<T,numRows> >(m[c]);
    }
    METAL_FUNC matrix(const device matrix& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = static_cast< vec<T,numRows> >(m[c]);
    }
    METAL_FUNC matrix(const threadgroup matrix& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = static_cast< vec<T,numRows> >(m[c]);
    }
    METAL_FUNC matrix(const constant matrix& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = static_cast< vec<T,numRows> >(m[c]);
    }

    // Explicit conversion from a float matrix of the same shape; each column
    // is converted to vec<half,numRows> with static_cast.
    METAL_FUNC explicit matrix(const thread matrix<float,numCols,numRows>& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = static_cast< vec<T,numRows> >(m[c]);
    }
    METAL_FUNC explicit matrix(const device matrix<float,numCols,numRows>& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = static_cast< vec<T,numRows> >(m[c]);
    }
    METAL_FUNC explicit matrix(const threadgroup matrix<float,numCols,numRows>& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = static_cast< vec<T,numRows> >(m[c]);
    }
    METAL_FUNC explicit matrix(const constant matrix<float,numCols,numRows>& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = static_cast< vec<T,numRows> >(m[c]);
    }

    // Assembles the matrix from its two column vectors.
    METAL_FUNC matrix(const vec<T,numRows> c0, const vec<T,numRows> c1) {
      cols[0] = c0;
      cols[1] = c1;
    }

    // Matrix Assignments
    // Column-wise copy, one overload for every pairing of destination and
    // source address space. Returns a reference to the destination.
    METAL_FUNC thread matrix& operator=(const thread matrix& m) thread {
      for (int c = 0; c < numCols; ++c) cols[c] = m[c];
      return *this;
    }
    METAL_FUNC thread matrix& operator=(const device matrix& m) thread {
      for (int c = 0; c < numCols; ++c) cols[c] = m[c];
      return *this;
    }
    METAL_FUNC thread matrix& operator=(const threadgroup matrix& m) thread {
      for (int c = 0; c < numCols; ++c) cols[c] = m[c];
      return *this;
    }
    METAL_FUNC thread matrix& operator=(const constant matrix& m) thread {
      for (int c = 0; c < numCols; ++c) cols[c] = m[c];
      return *this;
    }
    METAL_FUNC device matrix& operator=(const thread matrix& m) device {
      for (int c = 0; c < numCols; ++c) cols[c] = m[c];
      return *this;
    }
    METAL_FUNC device matrix& operator=(const device matrix& m) device {
      for (int c = 0; c < numCols; ++c) cols[c] = m[c];
      return *this;
    }
    METAL_FUNC device matrix& operator=(const threadgroup matrix& m) device {
      for (int c = 0; c < numCols; ++c) cols[c] = m[c];
      return *this;
    }
    METAL_FUNC device matrix& operator=(const constant matrix& m) device {
      for (int c = 0; c < numCols; ++c) cols[c] = m[c];
      return *this;
    }
    METAL_FUNC threadgroup matrix& operator=(const thread matrix& m) threadgroup {
      for (int c = 0; c < numCols; ++c) cols[c] = m[c];
      return *this;
    }
    METAL_FUNC threadgroup matrix& operator=(const device matrix& m) threadgroup {
      for (int c = 0; c < numCols; ++c) cols[c] = m[c];
      return *this;
    }
    METAL_FUNC threadgroup matrix& operator=(const threadgroup matrix& m) threadgroup {
      for (int c = 0; c < numCols; ++c) cols[c] = m[c];
      return *this;
    }
    METAL_FUNC threadgroup matrix& operator=(const constant matrix& m) threadgroup {
      for (int c = 0; c < numCols; ++c) cols[c] = m[c];
      return *this;
    }
    METAL_FUNC constant matrix& operator=(const thread matrix& m) constant {
      for (int c = 0; c < numCols; ++c) cols[c] = m[c];
      return *this;
    }
    METAL_FUNC constant matrix& operator=(const device matrix& m) constant {
      for (int c = 0; c < numCols; ++c) cols[c] = m[c];
      return *this;
    }
    METAL_FUNC constant matrix& operator=(const threadgroup matrix& m) constant {
      for (int c = 0; c < numCols; ++c) cols[c] = m[c];
      return *this;
    }
    METAL_FUNC constant matrix& operator=(const constant matrix& m) constant {
      for (int c = 0; c < numCols; ++c) cols[c] = m[c];
      return *this;
    }

    // 3.2 Matrix Operators

    // Scales every column by the scalar v.
    METAL_FUNC thread matrix& operator*= (const T v) thread {
      for (int c = 0; c < numCols; ++c) cols[c] *= v;
      return *this;
    }
    METAL_FUNC device matrix& operator*= (const T v) device {
      for (int c = 0; c < numCols; ++c) cols[c] *= v;
      return *this;
    }
    METAL_FUNC threadgroup matrix& operator*= (const T v) threadgroup {
      for (int c = 0; c < numCols; ++c) cols[c] *= v;
      return *this;
    }
    METAL_FUNC constant matrix& operator*= (const T v) constant {
      for (int c = 0; c < numCols; ++c) cols[c] *= v;
      return *this;
    }

    // Column-wise addition of m; grouped by the address space of *this.
    METAL_FUNC thread matrix& operator+= (const thread matrix& m) thread {
      for (int c = 0; c < numCols; ++c) cols[c] += m[c];
      return *this;
    }
    METAL_FUNC thread matrix& operator+= (const device matrix& m) thread {
      for (int c = 0; c < numCols; ++c) cols[c] += m[c];
      return *this;
    }
    METAL_FUNC thread matrix& operator+= (const threadgroup matrix& m) thread {
      for (int c = 0; c < numCols; ++c) cols[c] += m[c];
      return *this;
    }
    METAL_FUNC thread matrix& operator+= (const constant matrix& m) thread {
      for (int c = 0; c < numCols; ++c) cols[c] += m[c];
      return *this;
    }
    METAL_FUNC device matrix& operator+= (const thread matrix& m) device {
      for (int c = 0; c < numCols; ++c) cols[c] += m[c];
      return *this;
    }
    METAL_FUNC device matrix& operator+= (const device matrix& m) device {
      for (int c = 0; c < numCols; ++c) cols[c] += m[c];
      return *this;
    }
    METAL_FUNC device matrix& operator+= (const threadgroup matrix& m) device {
      for (int c = 0; c < numCols; ++c) cols[c] += m[c];
      return *this;
    }
    METAL_FUNC device matrix& operator+= (const constant matrix& m) device {
      for (int c = 0; c < numCols; ++c) cols[c] += m[c];
      return *this;
    }
    METAL_FUNC threadgroup matrix& operator+= (const thread matrix& m) threadgroup {
      for (int c = 0; c < numCols; ++c) cols[c] += m[c];
      return *this;
    }
    METAL_FUNC threadgroup matrix& operator+= (const device matrix& m) threadgroup {
      for (int c = 0; c < numCols; ++c) cols[c] += m[c];
      return *this;
    }
    METAL_FUNC threadgroup matrix& operator+= (const threadgroup matrix& m) threadgroup {
      for (int c = 0; c < numCols; ++c) cols[c] += m[c];
      return *this;
    }
    METAL_FUNC threadgroup matrix& operator+= (const constant matrix& m) threadgroup {
      for (int c = 0; c < numCols; ++c) cols[c] += m[c];
      return *this;
    }
    METAL_FUNC constant matrix& operator+= (const thread matrix& m) constant {
      for (int c = 0; c < numCols; ++c) cols[c] += m[c];
      return *this;
    }
    METAL_FUNC constant matrix& operator+= (const device matrix& m) constant {
      for (int c = 0; c < numCols; ++c) cols[c] += m[c];
      return *this;
    }
    METAL_FUNC constant matrix& operator+= (const threadgroup matrix& m) constant {
      for (int c = 0; c < numCols; ++c) cols[c] += m[c];
      return *this;
    }
    METAL_FUNC constant matrix& operator+= (const constant matrix& m) constant {
      for (int c = 0; c < numCols; ++c) cols[c] += m[c];
      return *this;
    }

    // Column-wise subtraction of m; grouped by the address space of *this.
    METAL_FUNC thread matrix& operator-= (const thread matrix& m) thread {
      for (int c = 0; c < numCols; ++c) cols[c] -= m[c];
      return *this;
    }
    METAL_FUNC thread matrix& operator-= (const device matrix& m) thread {
      for (int c = 0; c < numCols; ++c) cols[c] -= m[c];
      return *this;
    }
    METAL_FUNC thread matrix& operator-= (const threadgroup matrix& m) thread {
      for (int c = 0; c < numCols; ++c) cols[c] -= m[c];
      return *this;
    }
    METAL_FUNC thread matrix& operator-= (const constant matrix& m) thread {
      for (int c = 0; c < numCols; ++c) cols[c] -= m[c];
      return *this;
    }
    METAL_FUNC device matrix& operator-= (const thread matrix& m) device {
      for (int c = 0; c < numCols; ++c) cols[c] -= m[c];
      return *this;
    }
    METAL_FUNC device matrix& operator-= (const device matrix& m) device {
      for (int c = 0; c < numCols; ++c) cols[c] -= m[c];
      return *this;
    }
    METAL_FUNC device matrix& operator-= (const threadgroup matrix& m) device {
      for (int c = 0; c < numCols; ++c) cols[c] -= m[c];
      return *this;
    }
    METAL_FUNC device matrix& operator-= (const constant matrix& m) device {
      for (int c = 0; c < numCols; ++c) cols[c] -= m[c];
      return *this;
    }
    METAL_FUNC threadgroup matrix& operator-= (const thread matrix& m) threadgroup {
      for (int c = 0; c < numCols; ++c) cols[c] -= m[c];
      return *this;
    }
    METAL_FUNC threadgroup matrix& operator-= (const device matrix& m) threadgroup {
      for (int c = 0; c < numCols; ++c) cols[c] -= m[c];
      return *this;
    }
    METAL_FUNC threadgroup matrix& operator-= (const threadgroup matrix& m) threadgroup {
      for (int c = 0; c < numCols; ++c) cols[c] -= m[c];
      return *this;
    }
    METAL_FUNC threadgroup matrix& operator-= (const constant matrix& m) threadgroup {
      for (int c = 0; c < numCols; ++c) cols[c] -= m[c];
      return *this;
    }
    METAL_FUNC constant matrix& operator-= (const thread matrix& m) constant {
      for (int c = 0; c < numCols; ++c) cols[c] -= m[c];
      return *this;
    }
    METAL_FUNC constant matrix& operator-= (const device matrix& m) constant {
      for (int c = 0; c < numCols; ++c) cols[c] -= m[c];
      return *this;
    }
    METAL_FUNC constant matrix& operator-= (const threadgroup matrix& m) constant {
      for (int c = 0; c < numCols; ++c) cols[c] -= m[c];
      return *this;
    }
    METAL_FUNC constant matrix& operator-= (const constant matrix& m) constant {
      for (int c = 0; c < numCols; ++c) cols[c] -= m[c];
      return *this;
    }
  };


#pragma mark matrix3 half
  // matrix<half,3,numRows>: a half-precision matrix made of 3 columns, each a
  // vec<half,numRows>, held in the array `cols` (cols[c] is column c).
  // Accessors, assignments and compound operators are overloaded once per
  // Metal address space (thread, device, threadgroup, constant) of the object
  // and, where there is one, of the argument.
  template <int numRows> class matrix<half,3,numRows> {
    typedef half T;
    enum { numCols = 3 };
    vec<T,numRows> cols[numCols];

  public:
    // 2.2.2 Accessing Matrix Components
    // m[c] returns column c by reference. The index is not range-checked.
    METAL_FUNC thread vec<T,numRows>& operator[] (int c) thread { return cols[c]; }
    METAL_FUNC const thread vec<T,numRows>& operator[] (int c) thread const { return cols[c]; }
    METAL_FUNC device vec<T,numRows>& operator[] (int c) device { return cols[c]; }
    METAL_FUNC const device vec<T,numRows>& operator[] (int c) device const { return cols[c]; }
    METAL_FUNC threadgroup vec<T,numRows>& operator[] (int c) threadgroup { return cols[c]; }
    METAL_FUNC const threadgroup vec<T,numRows>& operator[] (int c) threadgroup const { return cols[c]; }
    METAL_FUNC constant vec<T,numRows>& operator[] (int c) constant { return cols[c]; }
    METAL_FUNC const constant vec<T,numRows>& operator[] (int c) constant const { return cols[c]; }

    // 2.2.4 Matrix Constructors

    // Performs no explicit initialization of the columns.
    METAL_FUNC matrix() { }

    // Diagonal matrix: every component is zero except cols[c][c] for each
    // c that is both a valid column and a valid row index, which gets val.
    // With numRows == 2 the third column is therefore all zero.
    // One overload per address space of val.
    METAL_FUNC explicit matrix(const thread T& val) {
      for (int c = 0; c < numCols; ++c) {
        vec<T,numRows> column = vec<T,numRows>(T(0));
        if (c < numRows)
          column[c] = val;
        cols[c] = column;
      }
    }
    METAL_FUNC explicit matrix(const device T& val) {
      for (int c = 0; c < numCols; ++c) {
        vec<T,numRows> column = vec<T,numRows>(T(0));
        if (c < numRows)
          column[c] = val;
        cols[c] = column;
      }
    }
    METAL_FUNC explicit matrix(const threadgroup T& val) {
      for (int c = 0; c < numCols; ++c) {
        vec<T,numRows> column = vec<T,numRows>(T(0));
        if (c < numRows)
          column[c] = val;
        cols[c] = column;
      }
    }
    METAL_FUNC explicit matrix(const constant T& val) {
      for (int c = 0; c < numCols; ++c) {
        vec<T,numRows> column = vec<T,numRows>(T(0));
        if (c < numRows)
          column[c] = val;
        cols[c] = column;
      }
    }

    // Column-by-column copy of a same-typed matrix from any address space.
    METAL_FUNC matrix(const thread matrix& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = static_cast< vec<T,numRows> >(m[c]);
    }
    METAL_FUNC matrix(const device matrix& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = static_cast< vec<T,numRows> >(m[c]);
    }
    METAL_FUNC matrix(const threadgroup matrix& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = static_cast< vec<T,numRows> >(m[c]);
    }
    METAL_FUNC matrix(const constant matrix& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = static_cast< vec<T,numRows> >(m[c]);
    }

    // Explicit conversion from a float matrix of the same shape; each column
    // is converted to vec<half,numRows> with static_cast.
    METAL_FUNC explicit matrix(const thread matrix<float,numCols,numRows>& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = static_cast< vec<T,numRows> >(m[c]);
    }
    METAL_FUNC explicit matrix(const device matrix<float,numCols,numRows>& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = static_cast< vec<T,numRows> >(m[c]);
    }
    METAL_FUNC explicit matrix(const threadgroup matrix<float,numCols,numRows>& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = static_cast< vec<T,numRows> >(m[c]);
    }
    METAL_FUNC explicit matrix(const constant matrix<float,numCols,numRows>& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = static_cast< vec<T,numRows> >(m[c]);
    }

    // Assembles the matrix from its three column vectors.
    METAL_FUNC matrix(const vec<T,numRows> c0, const vec<T,numRows> c1, const vec<T,numRows> c2) {
      cols[0] = c0;
      cols[1] = c1;
      cols[2] = c2;
    }

    // Matrix Assignments
    // Column-wise copy, one overload for every pairing of destination and
    // source address space. Returns a reference to the destination.
    METAL_FUNC thread matrix& operator=(const thread matrix& m) thread {
      for (int c = 0; c < numCols; ++c) cols[c] = m[c];
      return *this;
    }
    METAL_FUNC thread matrix& operator=(const device matrix& m) thread {
      for (int c = 0; c < numCols; ++c) cols[c] = m[c];
      return *this;
    }
    METAL_FUNC thread matrix& operator=(const threadgroup matrix& m) thread {
      for (int c = 0; c < numCols; ++c) cols[c] = m[c];
      return *this;
    }
    METAL_FUNC thread matrix& operator=(const constant matrix& m) thread {
      for (int c = 0; c < numCols; ++c) cols[c] = m[c];
      return *this;
    }
    METAL_FUNC device matrix& operator=(const thread matrix& m) device {
      for (int c = 0; c < numCols; ++c) cols[c] = m[c];
      return *this;
    }
    METAL_FUNC device matrix& operator=(const device matrix& m) device {
      for (int c = 0; c < numCols; ++c) cols[c] = m[c];
      return *this;
    }
    METAL_FUNC device matrix& operator=(const threadgroup matrix& m) device {
      for (int c = 0; c < numCols; ++c) cols[c] = m[c];
      return *this;
    }
    METAL_FUNC device matrix& operator=(const constant matrix& m) device {
      for (int c = 0; c < numCols; ++c) cols[c] = m[c];
      return *this;
    }
    METAL_FUNC threadgroup matrix& operator=(const thread matrix& m) threadgroup {
      for (int c = 0; c < numCols; ++c) cols[c] = m[c];
      return *this;
    }
    METAL_FUNC threadgroup matrix& operator=(const device matrix& m) threadgroup {
      for (int c = 0; c < numCols; ++c) cols[c] = m[c];
      return *this;
    }
    METAL_FUNC threadgroup matrix& operator=(const threadgroup matrix& m) threadgroup {
      for (int c = 0; c < numCols; ++c) cols[c] = m[c];
      return *this;
    }
    METAL_FUNC threadgroup matrix& operator=(const constant matrix& m) threadgroup {
      for (int c = 0; c < numCols; ++c) cols[c] = m[c];
      return *this;
    }
    METAL_FUNC constant matrix& operator=(const thread matrix& m) constant {
      for (int c = 0; c < numCols; ++c) cols[c] = m[c];
      return *this;
    }
    METAL_FUNC constant matrix& operator=(const device matrix& m) constant {
      for (int c = 0; c < numCols; ++c) cols[c] = m[c];
      return *this;
    }
    METAL_FUNC constant matrix& operator=(const threadgroup matrix& m) constant {
      for (int c = 0; c < numCols; ++c) cols[c] = m[c];
      return *this;
    }
    METAL_FUNC constant matrix& operator=(const constant matrix& m) constant {
      for (int c = 0; c < numCols; ++c) cols[c] = m[c];
      return *this;
    }

    // 3.2 Matrix Operators

    // Scales every column by the scalar v.
    METAL_FUNC thread matrix& operator*= (const T v) thread {
      for (int c = 0; c < numCols; ++c) cols[c] *= v;
      return *this;
    }
    METAL_FUNC device matrix& operator*= (const T v) device {
      for (int c = 0; c < numCols; ++c) cols[c] *= v;
      return *this;
    }
    METAL_FUNC threadgroup matrix& operator*= (const T v) threadgroup {
      for (int c = 0; c < numCols; ++c) cols[c] *= v;
      return *this;
    }
    METAL_FUNC constant matrix& operator*= (const T v) constant {
      for (int c = 0; c < numCols; ++c) cols[c] *= v;
      return *this;
    }

    // Column-wise addition of m; grouped by the address space of *this.
    METAL_FUNC thread matrix& operator+= (const thread matrix& m) thread {
      for (int c = 0; c < numCols; ++c) cols[c] += m[c];
      return *this;
    }
    METAL_FUNC thread matrix& operator+= (const device matrix& m) thread {
      for (int c = 0; c < numCols; ++c) cols[c] += m[c];
      return *this;
    }
    METAL_FUNC thread matrix& operator+= (const threadgroup matrix& m) thread {
      for (int c = 0; c < numCols; ++c) cols[c] += m[c];
      return *this;
    }
    METAL_FUNC thread matrix& operator+= (const constant matrix& m) thread {
      for (int c = 0; c < numCols; ++c) cols[c] += m[c];
      return *this;
    }
    METAL_FUNC device matrix& operator+= (const thread matrix& m) device {
      for (int c = 0; c < numCols; ++c) cols[c] += m[c];
      return *this;
    }
    METAL_FUNC device matrix& operator+= (const device matrix& m) device {
      for (int c = 0; c < numCols; ++c) cols[c] += m[c];
      return *this;
    }
    METAL_FUNC device matrix& operator+= (const threadgroup matrix& m) device {
      for (int c = 0; c < numCols; ++c) cols[c] += m[c];
      return *this;
    }
    METAL_FUNC device matrix& operator+= (const constant matrix& m) device {
      for (int c = 0; c < numCols; ++c) cols[c] += m[c];
      return *this;
    }
    METAL_FUNC threadgroup matrix& operator+= (const thread matrix& m) threadgroup {
      for (int c = 0; c < numCols; ++c) cols[c] += m[c];
      return *this;
    }
    METAL_FUNC threadgroup matrix& operator+= (const device matrix& m) threadgroup {
      for (int c = 0; c < numCols; ++c) cols[c] += m[c];
      return *this;
    }
    METAL_FUNC threadgroup matrix& operator+= (const threadgroup matrix& m) threadgroup {
      for (int c = 0; c < numCols; ++c) cols[c] += m[c];
      return *this;
    }
    METAL_FUNC threadgroup matrix& operator+= (const constant matrix& m) threadgroup {
      for (int c = 0; c < numCols; ++c) cols[c] += m[c];
      return *this;
    }
    METAL_FUNC constant matrix& operator+= (const thread matrix& m) constant {
      for (int c = 0; c < numCols; ++c) cols[c] += m[c];
      return *this;
    }
    METAL_FUNC constant matrix& operator+= (const device matrix& m) constant {
      for (int c = 0; c < numCols; ++c) cols[c] += m[c];
      return *this;
    }
    METAL_FUNC constant matrix& operator+= (const threadgroup matrix& m) constant {
      for (int c = 0; c < numCols; ++c) cols[c] += m[c];
      return *this;
    }
    METAL_FUNC constant matrix& operator+= (const constant matrix& m) constant {
      for (int c = 0; c < numCols; ++c) cols[c] += m[c];
      return *this;
    }

    // Column-wise subtraction of m; grouped by the address space of *this.
    METAL_FUNC thread matrix& operator-= (const thread matrix& m) thread {
      for (int c = 0; c < numCols; ++c) cols[c] -= m[c];
      return *this;
    }
    METAL_FUNC thread matrix& operator-= (const device matrix& m) thread {
      for (int c = 0; c < numCols; ++c) cols[c] -= m[c];
      return *this;
    }
    METAL_FUNC thread matrix& operator-= (const threadgroup matrix& m) thread {
      for (int c = 0; c < numCols; ++c) cols[c] -= m[c];
      return *this;
    }
    METAL_FUNC thread matrix& operator-= (const constant matrix& m) thread {
      for (int c = 0; c < numCols; ++c) cols[c] -= m[c];
      return *this;
    }
    METAL_FUNC device matrix& operator-= (const thread matrix& m) device {
      for (int c = 0; c < numCols; ++c) cols[c] -= m[c];
      return *this;
    }
    METAL_FUNC device matrix& operator-= (const device matrix& m) device {
      for (int c = 0; c < numCols; ++c) cols[c] -= m[c];
      return *this;
    }
    METAL_FUNC device matrix& operator-= (const threadgroup matrix& m) device {
      for (int c = 0; c < numCols; ++c) cols[c] -= m[c];
      return *this;
    }
    METAL_FUNC device matrix& operator-= (const constant matrix& m) device {
      for (int c = 0; c < numCols; ++c) cols[c] -= m[c];
      return *this;
    }
    METAL_FUNC threadgroup matrix& operator-= (const thread matrix& m) threadgroup {
      for (int c = 0; c < numCols; ++c) cols[c] -= m[c];
      return *this;
    }
    METAL_FUNC threadgroup matrix& operator-= (const device matrix& m) threadgroup {
      for (int c = 0; c < numCols; ++c) cols[c] -= m[c];
      return *this;
    }
    METAL_FUNC threadgroup matrix& operator-= (const threadgroup matrix& m) threadgroup {
      for (int c = 0; c < numCols; ++c) cols[c] -= m[c];
      return *this;
    }
    METAL_FUNC threadgroup matrix& operator-= (const constant matrix& m) threadgroup {
      for (int c = 0; c < numCols; ++c) cols[c] -= m[c];
      return *this;
    }
    METAL_FUNC constant matrix& operator-= (const thread matrix& m) constant {
      for (int c = 0; c < numCols; ++c) cols[c] -= m[c];
      return *this;
    }
    METAL_FUNC constant matrix& operator-= (const device matrix& m) constant {
      for (int c = 0; c < numCols; ++c) cols[c] -= m[c];
      return *this;
    }
    METAL_FUNC constant matrix& operator-= (const threadgroup matrix& m) constant {
      for (int c = 0; c < numCols; ++c) cols[c] -= m[c];
      return *this;
    }
    METAL_FUNC constant matrix& operator-= (const constant matrix& m) constant {
      for (int c = 0; c < numCols; ++c) cols[c] -= m[c];
      return *this;
    }
  };

#pragma mark matrix4 half
  // matrix<half,4,numRows>: half-precision matrix with 4 columns of numRows
  // elements each, stored as an array of column vectors. Accessors, assignments
  // and compound operators are overloaded for every address space the object
  // (and, where applicable, the operand) may live in.
  template <int numRows> class matrix<half,4,numRows> {
    typedef half T;
    enum { numCols = 4 };
    typedef matrix<T,numCols,numRows> this_type;  // shorthand for this specialization
    vec<T,numRows> cols[numCols];                  // cols[c] is column c
  public:
    // 2.2.2 Accessing Matrix Components
    // operator[](c) returns column c by reference in the object's own address
    // space. The index is not range-checked.
    METAL_FUNC thread vec<T,numRows>& operator[](int c) thread { return cols[c]; }
    METAL_FUNC const thread vec<T,numRows>& operator[](int c) thread const { return cols[c]; }
    METAL_FUNC device vec<T,numRows>& operator[](int c) device { return cols[c]; }
    METAL_FUNC const device vec<T,numRows>& operator[](int c) device const { return cols[c]; }
    METAL_FUNC threadgroup vec<T,numRows>& operator[](int c) threadgroup { return cols[c]; }
    METAL_FUNC const threadgroup vec<T,numRows>& operator[](int c) threadgroup const { return cols[c]; }
    METAL_FUNC constant vec<T,numRows>& operator[](int c) constant { return cols[c]; }
    METAL_FUNC const constant vec<T,numRows>& operator[](int c) constant const { return cols[c]; }

    // 2.2.4 Matrix Constructors
    // Default constructor: performs no explicit initialization of the columns.
    METAL_FUNC matrix() { }

    // Scalar constructors: entry [c][c] is val for every c < min(numCols, numRows);
    // every other entry is zero. One overload per address space of val.
    METAL_FUNC explicit matrix(const thread T& val) {
      for (int c = 0; c < numCols; ++c) {
        vec<T,numRows> column(T(0));
        if (c < numRows)
          column[c] = val;
        cols[c] = column;
      }
    }
    METAL_FUNC explicit matrix(const device T& val) {
      for (int c = 0; c < numCols; ++c) {
        vec<T,numRows> column(T(0));
        if (c < numRows)
          column[c] = val;
        cols[c] = column;
      }
    }
    METAL_FUNC explicit matrix(const threadgroup T& val) {
      for (int c = 0; c < numCols; ++c) {
        vec<T,numRows> column(T(0));
        if (c < numRows)
          column[c] = val;
        cols[c] = column;
      }
    }
    METAL_FUNC explicit matrix(const constant T& val) {
      for (int c = 0; c < numCols; ++c) {
        vec<T,numRows> column(T(0));
        if (c < numRows)
          column[c] = val;
        cols[c] = column;
      }
    }

    // Copy constructors from a same-typed matrix in any address space.
    METAL_FUNC matrix(const thread this_type& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = static_cast< vec<T,numRows> >(m[c]);
    }
    METAL_FUNC matrix(const device this_type& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = static_cast< vec<T,numRows> >(m[c]);
    }
    METAL_FUNC matrix(const threadgroup this_type& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = static_cast< vec<T,numRows> >(m[c]);
    }
    METAL_FUNC matrix(const constant this_type& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = static_cast< vec<T,numRows> >(m[c]);
    }

    // Explicit conversions from a float matrix of identical dimensions; each
    // column is converted with static_cast.
    METAL_FUNC explicit matrix(const thread matrix<float,numCols,numRows>& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = static_cast< vec<T,numRows> >(m[c]);
    }
    METAL_FUNC explicit matrix(const device matrix<float,numCols,numRows>& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = static_cast< vec<T,numRows> >(m[c]);
    }
    METAL_FUNC explicit matrix(const threadgroup matrix<float,numCols,numRows>& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = static_cast< vec<T,numRows> >(m[c]);
    }
    METAL_FUNC explicit matrix(const constant matrix<float,numCols,numRows>& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = static_cast< vec<T,numRows> >(m[c]);
    }

    // Column-wise constructor: c0..c3 become columns 0..3.
    METAL_FUNC matrix(const vec<T,numRows> c0, const vec<T,numRows> c1, const vec<T,numRows> c2, const vec<T,numRows> c3)
    {
      cols[0] = c0; cols[1] = c1;
      cols[2] = c2; cols[3] = c3;
    }

    // Matrix Assignments
    // Column-by-column copy from a same-typed matrix, one overload per
    // (destination, source) address-space pair. Each returns *this.
    METAL_FUNC thread this_type& operator=(const thread this_type& m) thread {
      for (int c = 0; c < numCols; ++c) { cols[c] = m[c]; }
      return *this;
    }
    METAL_FUNC thread this_type& operator=(const device this_type& m) thread {
      for (int c = 0; c < numCols; ++c) { cols[c] = m[c]; }
      return *this;
    }
    METAL_FUNC thread this_type& operator=(const threadgroup this_type& m) thread {
      for (int c = 0; c < numCols; ++c) { cols[c] = m[c]; }
      return *this;
    }
    METAL_FUNC thread this_type& operator=(const constant this_type& m) thread {
      for (int c = 0; c < numCols; ++c) { cols[c] = m[c]; }
      return *this;
    }
    METAL_FUNC device this_type& operator=(const thread this_type& m) device {
      for (int c = 0; c < numCols; ++c) { cols[c] = m[c]; }
      return *this;
    }
    METAL_FUNC device this_type& operator=(const device this_type& m) device {
      for (int c = 0; c < numCols; ++c) { cols[c] = m[c]; }
      return *this;
    }
    METAL_FUNC device this_type& operator=(const threadgroup this_type& m) device {
      for (int c = 0; c < numCols; ++c) { cols[c] = m[c]; }
      return *this;
    }
    METAL_FUNC device this_type& operator=(const constant this_type& m) device {
      for (int c = 0; c < numCols; ++c) { cols[c] = m[c]; }
      return *this;
    }
    METAL_FUNC threadgroup this_type& operator=(const thread this_type& m) threadgroup {
      for (int c = 0; c < numCols; ++c) { cols[c] = m[c]; }
      return *this;
    }
    METAL_FUNC threadgroup this_type& operator=(const device this_type& m) threadgroup {
      for (int c = 0; c < numCols; ++c) { cols[c] = m[c]; }
      return *this;
    }
    METAL_FUNC threadgroup this_type& operator=(const threadgroup this_type& m) threadgroup {
      for (int c = 0; c < numCols; ++c) { cols[c] = m[c]; }
      return *this;
    }
    METAL_FUNC threadgroup this_type& operator=(const constant this_type& m) threadgroup {
      for (int c = 0; c < numCols; ++c) { cols[c] = m[c]; }
      return *this;
    }
    METAL_FUNC constant this_type& operator=(const thread this_type& m) constant {
      for (int c = 0; c < numCols; ++c) { cols[c] = m[c]; }
      return *this;
    }
    METAL_FUNC constant this_type& operator=(const device this_type& m) constant {
      for (int c = 0; c < numCols; ++c) { cols[c] = m[c]; }
      return *this;
    }
    METAL_FUNC constant this_type& operator=(const threadgroup this_type& m) constant {
      for (int c = 0; c < numCols; ++c) { cols[c] = m[c]; }
      return *this;
    }
    METAL_FUNC constant this_type& operator=(const constant this_type& m) constant {
      for (int c = 0; c < numCols; ++c) { cols[c] = m[c]; }
      return *this;
    }

    // 3.2 Matrix Operators
    // Scalar scaling: multiplies every column by v in place.
    METAL_FUNC thread this_type& operator*=(const T v) thread {
      for (int c = 0; c < numCols; ++c) { cols[c] *= v; }
      return *this;
    }
    METAL_FUNC device this_type& operator*=(const T v) device {
      for (int c = 0; c < numCols; ++c) { cols[c] *= v; }
      return *this;
    }
    METAL_FUNC threadgroup this_type& operator*=(const T v) threadgroup {
      for (int c = 0; c < numCols; ++c) { cols[c] *= v; }
      return *this;
    }
    METAL_FUNC constant this_type& operator*=(const T v) constant {
      for (int c = 0; c < numCols; ++c) { cols[c] *= v; }
      return *this;
    }

    // Column-wise in-place addition of a same-typed matrix.
    METAL_FUNC thread this_type& operator+=(const thread this_type& m) thread {
      for (int c = 0; c < numCols; ++c) { cols[c] += m[c]; }
      return *this;
    }
    METAL_FUNC device this_type& operator+=(const thread this_type& m) device {
      for (int c = 0; c < numCols; ++c) { cols[c] += m[c]; }
      return *this;
    }
    METAL_FUNC threadgroup this_type& operator+=(const thread this_type& m) threadgroup {
      for (int c = 0; c < numCols; ++c) { cols[c] += m[c]; }
      return *this;
    }
    METAL_FUNC constant this_type& operator+=(const thread this_type& m) constant {
      for (int c = 0; c < numCols; ++c) { cols[c] += m[c]; }
      return *this;
    }
    METAL_FUNC thread this_type& operator+=(const device this_type& m) thread {
      for (int c = 0; c < numCols; ++c) { cols[c] += m[c]; }
      return *this;
    }
    METAL_FUNC device this_type& operator+=(const device this_type& m) device {
      for (int c = 0; c < numCols; ++c) { cols[c] += m[c]; }
      return *this;
    }
    METAL_FUNC threadgroup this_type& operator+=(const device this_type& m) threadgroup {
      for (int c = 0; c < numCols; ++c) { cols[c] += m[c]; }
      return *this;
    }
    METAL_FUNC constant this_type& operator+=(const device this_type& m) constant {
      for (int c = 0; c < numCols; ++c) { cols[c] += m[c]; }
      return *this;
    }
    METAL_FUNC thread this_type& operator+=(const threadgroup this_type& m) thread {
      for (int c = 0; c < numCols; ++c) { cols[c] += m[c]; }
      return *this;
    }
    METAL_FUNC device this_type& operator+=(const threadgroup this_type& m) device {
      for (int c = 0; c < numCols; ++c) { cols[c] += m[c]; }
      return *this;
    }
    METAL_FUNC threadgroup this_type& operator+=(const threadgroup this_type& m) threadgroup {
      for (int c = 0; c < numCols; ++c) { cols[c] += m[c]; }
      return *this;
    }
    METAL_FUNC constant this_type& operator+=(const threadgroup this_type& m) constant {
      for (int c = 0; c < numCols; ++c) { cols[c] += m[c]; }
      return *this;
    }
    METAL_FUNC thread this_type& operator+=(const constant this_type& m) thread {
      for (int c = 0; c < numCols; ++c) { cols[c] += m[c]; }
      return *this;
    }
    METAL_FUNC device this_type& operator+=(const constant this_type& m) device {
      for (int c = 0; c < numCols; ++c) { cols[c] += m[c]; }
      return *this;
    }
    METAL_FUNC threadgroup this_type& operator+=(const constant this_type& m) threadgroup {
      for (int c = 0; c < numCols; ++c) { cols[c] += m[c]; }
      return *this;
    }
    METAL_FUNC constant this_type& operator+=(const constant this_type& m) constant {
      for (int c = 0; c < numCols; ++c) { cols[c] += m[c]; }
      return *this;
    }

    // Column-wise in-place subtraction of a same-typed matrix.
    METAL_FUNC thread this_type& operator-=(const thread this_type& m) thread {
      for (int c = 0; c < numCols; ++c) { cols[c] -= m[c]; }
      return *this;
    }
    METAL_FUNC device this_type& operator-=(const thread this_type& m) device {
      for (int c = 0; c < numCols; ++c) { cols[c] -= m[c]; }
      return *this;
    }
    METAL_FUNC threadgroup this_type& operator-=(const thread this_type& m) threadgroup {
      for (int c = 0; c < numCols; ++c) { cols[c] -= m[c]; }
      return *this;
    }
    METAL_FUNC constant this_type& operator-=(const thread this_type& m) constant {
      for (int c = 0; c < numCols; ++c) { cols[c] -= m[c]; }
      return *this;
    }
    METAL_FUNC thread this_type& operator-=(const device this_type& m) thread {
      for (int c = 0; c < numCols; ++c) { cols[c] -= m[c]; }
      return *this;
    }
    METAL_FUNC device this_type& operator-=(const device this_type& m) device {
      for (int c = 0; c < numCols; ++c) { cols[c] -= m[c]; }
      return *this;
    }
    METAL_FUNC threadgroup this_type& operator-=(const device this_type& m) threadgroup {
      for (int c = 0; c < numCols; ++c) { cols[c] -= m[c]; }
      return *this;
    }
    METAL_FUNC constant this_type& operator-=(const device this_type& m) constant {
      for (int c = 0; c < numCols; ++c) { cols[c] -= m[c]; }
      return *this;
    }
    METAL_FUNC thread this_type& operator-=(const threadgroup this_type& m) thread {
      for (int c = 0; c < numCols; ++c) { cols[c] -= m[c]; }
      return *this;
    }
    METAL_FUNC device this_type& operator-=(const threadgroup this_type& m) device {
      for (int c = 0; c < numCols; ++c) { cols[c] -= m[c]; }
      return *this;
    }
    METAL_FUNC threadgroup this_type& operator-=(const threadgroup this_type& m) threadgroup {
      for (int c = 0; c < numCols; ++c) { cols[c] -= m[c]; }
      return *this;
    }
    METAL_FUNC constant this_type& operator-=(const threadgroup this_type& m) constant {
      for (int c = 0; c < numCols; ++c) { cols[c] -= m[c]; }
      return *this;
    }
    METAL_FUNC thread this_type& operator-=(const constant this_type& m) thread {
      for (int c = 0; c < numCols; ++c) { cols[c] -= m[c]; }
      return *this;
    }
    METAL_FUNC device this_type& operator-=(const constant this_type& m) device {
      for (int c = 0; c < numCols; ++c) { cols[c] -= m[c]; }
      return *this;
    }
    METAL_FUNC threadgroup this_type& operator-=(const constant this_type& m) threadgroup {
      for (int c = 0; c < numCols; ++c) { cols[c] -= m[c]; }
      return *this;
    }
    METAL_FUNC constant this_type& operator-=(const constant this_type& m) constant {
      for (int c = 0; c < numCols; ++c) { cols[c] -= m[c]; }
      return *this;
    }
  };

#pragma mark matrix2 float
  // matrix<float,2,numRows>: single-precision matrix with 2 columns of numRows
  // elements each, stored as an array of column vectors. Accessors, assignments
  // and compound operators are overloaded for every address space the object
  // (and, where applicable, the operand) may live in.
  template <int numRows> class matrix<float,2,numRows> {
    typedef float T;
    enum { numCols = 2 };
    typedef matrix<T,numCols,numRows> this_type;  // shorthand for this specialization
    vec<T,numRows> cols[numCols];                  // cols[c] is column c
  public:
    // 2.2.2 Accessing Matrix Components
    // operator[](c) returns column c by reference in the object's own address
    // space. The index is not range-checked.
    METAL_FUNC thread vec<T,numRows>& operator[](int c) thread { return cols[c]; }
    METAL_FUNC const thread vec<T,numRows>& operator[](int c) thread const { return cols[c]; }
    METAL_FUNC device vec<T,numRows>& operator[](int c) device { return cols[c]; }
    METAL_FUNC const device vec<T,numRows>& operator[](int c) device const { return cols[c]; }
    METAL_FUNC threadgroup vec<T,numRows>& operator[](int c) threadgroup { return cols[c]; }
    METAL_FUNC const threadgroup vec<T,numRows>& operator[](int c) threadgroup const { return cols[c]; }
    METAL_FUNC constant vec<T,numRows>& operator[](int c) constant { return cols[c]; }
    METAL_FUNC const constant vec<T,numRows>& operator[](int c) constant const { return cols[c]; }

    // 2.2.4 Matrix Constructors
    // Default constructor: performs no explicit initialization of the columns.
    METAL_FUNC matrix() { }

    // Scalar constructors: entry [c][c] is val for every c < min(numCols, numRows);
    // every other entry is zero. One overload per address space of val.
    METAL_FUNC explicit matrix(const thread T& val) {
      for (int c = 0; c < numCols; ++c) {
        vec<T,numRows> column(T(0));
        if (c < numRows)
          column[c] = val;
        cols[c] = column;
      }
    }
    METAL_FUNC explicit matrix(const device T& val) {
      for (int c = 0; c < numCols; ++c) {
        vec<T,numRows> column(T(0));
        if (c < numRows)
          column[c] = val;
        cols[c] = column;
      }
    }
    METAL_FUNC explicit matrix(const threadgroup T& val) {
      for (int c = 0; c < numCols; ++c) {
        vec<T,numRows> column(T(0));
        if (c < numRows)
          column[c] = val;
        cols[c] = column;
      }
    }
    METAL_FUNC explicit matrix(const constant T& val) {
      for (int c = 0; c < numCols; ++c) {
        vec<T,numRows> column(T(0));
        if (c < numRows)
          column[c] = val;
        cols[c] = column;
      }
    }

    // Copy constructors from a same-typed matrix in any address space.
    METAL_FUNC matrix(const thread this_type& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = static_cast< vec<T,numRows> >(m[c]);
    }
    METAL_FUNC matrix(const device this_type& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = static_cast< vec<T,numRows> >(m[c]);
    }
    METAL_FUNC matrix(const threadgroup this_type& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = static_cast< vec<T,numRows> >(m[c]);
    }
    METAL_FUNC matrix(const constant this_type& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = static_cast< vec<T,numRows> >(m[c]);
    }

    // Explicit conversions from a half matrix of identical dimensions; each
    // column is converted with static_cast.
    METAL_FUNC explicit matrix(const thread matrix<half,numCols,numRows>& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = static_cast< vec<T,numRows> >(m[c]);
    }
    METAL_FUNC explicit matrix(const device matrix<half,numCols,numRows>& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = static_cast< vec<T,numRows> >(m[c]);
    }
    METAL_FUNC explicit matrix(const threadgroup matrix<half,numCols,numRows>& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = static_cast< vec<T,numRows> >(m[c]);
    }
    METAL_FUNC explicit matrix(const constant matrix<half,numCols,numRows>& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = static_cast< vec<T,numRows> >(m[c]);
    }

    // Column-wise constructor: c0 and c1 become columns 0 and 1.
    METAL_FUNC matrix(const vec<T,numRows> c0, const vec<T,numRows> c1)
    {
      cols[0] = c0; cols[1] = c1;
    }

    // Matrix Assignments
    // Column-by-column copy from a same-typed matrix, one overload per
    // (destination, source) address-space pair. Each returns *this.
    METAL_FUNC thread this_type& operator=(const thread this_type& m) thread {
      for (int c = 0; c < numCols; ++c) { cols[c] = m[c]; }
      return *this;
    }
    METAL_FUNC thread this_type& operator=(const device this_type& m) thread {
      for (int c = 0; c < numCols; ++c) { cols[c] = m[c]; }
      return *this;
    }
    METAL_FUNC thread this_type& operator=(const threadgroup this_type& m) thread {
      for (int c = 0; c < numCols; ++c) { cols[c] = m[c]; }
      return *this;
    }
    METAL_FUNC thread this_type& operator=(const constant this_type& m) thread {
      for (int c = 0; c < numCols; ++c) { cols[c] = m[c]; }
      return *this;
    }
    METAL_FUNC device this_type& operator=(const thread this_type& m) device {
      for (int c = 0; c < numCols; ++c) { cols[c] = m[c]; }
      return *this;
    }
    METAL_FUNC device this_type& operator=(const device this_type& m) device {
      for (int c = 0; c < numCols; ++c) { cols[c] = m[c]; }
      return *this;
    }
    METAL_FUNC device this_type& operator=(const threadgroup this_type& m) device {
      for (int c = 0; c < numCols; ++c) { cols[c] = m[c]; }
      return *this;
    }
    METAL_FUNC device this_type& operator=(const constant this_type& m) device {
      for (int c = 0; c < numCols; ++c) { cols[c] = m[c]; }
      return *this;
    }
    METAL_FUNC threadgroup this_type& operator=(const thread this_type& m) threadgroup {
      for (int c = 0; c < numCols; ++c) { cols[c] = m[c]; }
      return *this;
    }
    METAL_FUNC threadgroup this_type& operator=(const device this_type& m) threadgroup {
      for (int c = 0; c < numCols; ++c) { cols[c] = m[c]; }
      return *this;
    }
    METAL_FUNC threadgroup this_type& operator=(const threadgroup this_type& m) threadgroup {
      for (int c = 0; c < numCols; ++c) { cols[c] = m[c]; }
      return *this;
    }
    METAL_FUNC threadgroup this_type& operator=(const constant this_type& m) threadgroup {
      for (int c = 0; c < numCols; ++c) { cols[c] = m[c]; }
      return *this;
    }
    METAL_FUNC constant this_type& operator=(const thread this_type& m) constant {
      for (int c = 0; c < numCols; ++c) { cols[c] = m[c]; }
      return *this;
    }
    METAL_FUNC constant this_type& operator=(const device this_type& m) constant {
      for (int c = 0; c < numCols; ++c) { cols[c] = m[c]; }
      return *this;
    }
    METAL_FUNC constant this_type& operator=(const threadgroup this_type& m) constant {
      for (int c = 0; c < numCols; ++c) { cols[c] = m[c]; }
      return *this;
    }
    METAL_FUNC constant this_type& operator=(const constant this_type& m) constant {
      for (int c = 0; c < numCols; ++c) { cols[c] = m[c]; }
      return *this;
    }

    // 3.2 Matrix Operators
    // Scalar scaling: multiplies every column by v in place.
    METAL_FUNC thread this_type& operator*=(const T v) thread {
      for (int c = 0; c < numCols; ++c) { cols[c] *= v; }
      return *this;
    }
    METAL_FUNC device this_type& operator*=(const T v) device {
      for (int c = 0; c < numCols; ++c) { cols[c] *= v; }
      return *this;
    }
    METAL_FUNC threadgroup this_type& operator*=(const T v) threadgroup {
      for (int c = 0; c < numCols; ++c) { cols[c] *= v; }
      return *this;
    }
    METAL_FUNC constant this_type& operator*=(const T v) constant {
      for (int c = 0; c < numCols; ++c) { cols[c] *= v; }
      return *this;
    }

    // Column-wise in-place addition of a same-typed matrix.
    METAL_FUNC thread this_type& operator+=(const thread this_type& m) thread {
      for (int c = 0; c < numCols; ++c) { cols[c] += m[c]; }
      return *this;
    }
    METAL_FUNC device this_type& operator+=(const thread this_type& m) device {
      for (int c = 0; c < numCols; ++c) { cols[c] += m[c]; }
      return *this;
    }
    METAL_FUNC threadgroup this_type& operator+=(const thread this_type& m) threadgroup {
      for (int c = 0; c < numCols; ++c) { cols[c] += m[c]; }
      return *this;
    }
    METAL_FUNC constant this_type& operator+=(const thread this_type& m) constant {
      for (int c = 0; c < numCols; ++c) { cols[c] += m[c]; }
      return *this;
    }
    METAL_FUNC thread this_type& operator+=(const device this_type& m) thread {
      for (int c = 0; c < numCols; ++c) { cols[c] += m[c]; }
      return *this;
    }
    METAL_FUNC device this_type& operator+=(const device this_type& m) device {
      for (int c = 0; c < numCols; ++c) { cols[c] += m[c]; }
      return *this;
    }
    METAL_FUNC threadgroup this_type& operator+=(const device this_type& m) threadgroup {
      for (int c = 0; c < numCols; ++c) { cols[c] += m[c]; }
      return *this;
    }
    METAL_FUNC constant this_type& operator+=(const device this_type& m) constant {
      for (int c = 0; c < numCols; ++c) { cols[c] += m[c]; }
      return *this;
    }
    METAL_FUNC thread this_type& operator+=(const threadgroup this_type& m) thread {
      for (int c = 0; c < numCols; ++c) { cols[c] += m[c]; }
      return *this;
    }
    METAL_FUNC device this_type& operator+=(const threadgroup this_type& m) device {
      for (int c = 0; c < numCols; ++c) { cols[c] += m[c]; }
      return *this;
    }
    METAL_FUNC threadgroup this_type& operator+=(const threadgroup this_type& m) threadgroup {
      for (int c = 0; c < numCols; ++c) { cols[c] += m[c]; }
      return *this;
    }
    METAL_FUNC constant this_type& operator+=(const threadgroup this_type& m) constant {
      for (int c = 0; c < numCols; ++c) { cols[c] += m[c]; }
      return *this;
    }
    METAL_FUNC thread this_type& operator+=(const constant this_type& m) thread {
      for (int c = 0; c < numCols; ++c) { cols[c] += m[c]; }
      return *this;
    }
    METAL_FUNC device this_type& operator+=(const constant this_type& m) device {
      for (int c = 0; c < numCols; ++c) { cols[c] += m[c]; }
      return *this;
    }
    METAL_FUNC threadgroup this_type& operator+=(const constant this_type& m) threadgroup {
      for (int c = 0; c < numCols; ++c) { cols[c] += m[c]; }
      return *this;
    }
    METAL_FUNC constant this_type& operator+=(const constant this_type& m) constant {
      for (int c = 0; c < numCols; ++c) { cols[c] += m[c]; }
      return *this;
    }

    // Column-wise in-place subtraction of a same-typed matrix.
    METAL_FUNC thread this_type& operator-=(const thread this_type& m) thread {
      for (int c = 0; c < numCols; ++c) { cols[c] -= m[c]; }
      return *this;
    }
    METAL_FUNC device this_type& operator-=(const thread this_type& m) device {
      for (int c = 0; c < numCols; ++c) { cols[c] -= m[c]; }
      return *this;
    }
    METAL_FUNC threadgroup this_type& operator-=(const thread this_type& m) threadgroup {
      for (int c = 0; c < numCols; ++c) { cols[c] -= m[c]; }
      return *this;
    }
    METAL_FUNC constant this_type& operator-=(const thread this_type& m) constant {
      for (int c = 0; c < numCols; ++c) { cols[c] -= m[c]; }
      return *this;
    }
    METAL_FUNC thread this_type& operator-=(const device this_type& m) thread {
      for (int c = 0; c < numCols; ++c) { cols[c] -= m[c]; }
      return *this;
    }
    METAL_FUNC device this_type& operator-=(const device this_type& m) device {
      for (int c = 0; c < numCols; ++c) { cols[c] -= m[c]; }
      return *this;
    }
    METAL_FUNC threadgroup this_type& operator-=(const device this_type& m) threadgroup {
      for (int c = 0; c < numCols; ++c) { cols[c] -= m[c]; }
      return *this;
    }
    METAL_FUNC constant this_type& operator-=(const device this_type& m) constant {
      for (int c = 0; c < numCols; ++c) { cols[c] -= m[c]; }
      return *this;
    }
    METAL_FUNC thread this_type& operator-=(const threadgroup this_type& m) thread {
      for (int c = 0; c < numCols; ++c) { cols[c] -= m[c]; }
      return *this;
    }
    METAL_FUNC device this_type& operator-=(const threadgroup this_type& m) device {
      for (int c = 0; c < numCols; ++c) { cols[c] -= m[c]; }
      return *this;
    }
    METAL_FUNC threadgroup this_type& operator-=(const threadgroup this_type& m) threadgroup {
      for (int c = 0; c < numCols; ++c) { cols[c] -= m[c]; }
      return *this;
    }
    METAL_FUNC constant this_type& operator-=(const threadgroup this_type& m) constant {
      for (int c = 0; c < numCols; ++c) { cols[c] -= m[c]; }
      return *this;
    }
    METAL_FUNC thread this_type& operator-=(const constant this_type& m) thread {
      for (int c = 0; c < numCols; ++c) { cols[c] -= m[c]; }
      return *this;
    }
    METAL_FUNC device this_type& operator-=(const constant this_type& m) device {
      for (int c = 0; c < numCols; ++c) { cols[c] -= m[c]; }
      return *this;
    }
    METAL_FUNC threadgroup this_type& operator-=(const constant this_type& m) threadgroup {
      for (int c = 0; c < numCols; ++c) { cols[c] -= m[c]; }
      return *this;
    }
    METAL_FUNC constant this_type& operator-=(const constant this_type& m) constant {
      for (int c = 0; c < numCols; ++c) { cols[c] -= m[c]; }
      return *this;
    }
  };


#pragma mark matrix3 float
  // matrix<float,3,numRows>: a float matrix with 3 columns and numRows rows.
  // Storage is column-major: cols[c] holds column c, and operator[] selects a column.
  template <int numRows> class matrix<float,3,numRows> {
    typedef float T;
    enum { numCols = 3 };
    vec<T,numRows> cols[numCols]; // column-major storage
  public:
    // 2.2.2 Accessing Matrix Components
    // Unchecked column access. Each address space the matrix can live in gets a
    // mutable and a const overload returning a reference in that same space.
    METAL_FUNC thread vec<T,numRows>& operator[] (int c) thread { return cols[c]; }
    METAL_FUNC const thread vec<T,numRows>& operator[] (int c) thread const { return cols[c]; }
    METAL_FUNC device vec<T,numRows>& operator[] (int c) device { return cols[c]; }
    METAL_FUNC const device vec<T,numRows>& operator[] (int c) device const { return cols[c]; }
    METAL_FUNC threadgroup vec<T,numRows>& operator[] (int c) threadgroup { return cols[c]; }
    METAL_FUNC const threadgroup vec<T,numRows>& operator[] (int c) threadgroup const { return cols[c]; }
    METAL_FUNC constant vec<T,numRows>& operator[] (int c) constant { return cols[c]; }
    METAL_FUNC const constant vec<T,numRows>& operator[] (int c) constant const { return cols[c]; }

    // 2.2.4 Matrix Constructors

    // Default constructor: the columns are left uninitialized.
    METAL_FUNC matrix() { }

    // Diagonal constructors: every element is zero except element [c][c],
    // which is set to val for each column c that also has a row c.
    METAL_FUNC explicit matrix(thread const T& val) {
      for (int c = 0; c < numCols; ++c) {
        vec<T,numRows> column(T(0));
        if (c < numRows)
          column[c] = val;
        cols[c] = column;
      }
    }
    METAL_FUNC explicit matrix(device const T& val) {
      for (int c = 0; c < numCols; ++c) {
        vec<T,numRows> column(T(0));
        if (c < numRows)
          column[c] = val;
        cols[c] = column;
      }
    }
    METAL_FUNC explicit matrix(threadgroup const T& val) {
      for (int c = 0; c < numCols; ++c) {
        vec<T,numRows> column(T(0));
        if (c < numRows)
          column[c] = val;
        cols[c] = column;
      }
    }
    METAL_FUNC explicit matrix(constant const T& val) {
      for (int c = 0; c < numCols; ++c) {
        vec<T,numRows> column(T(0));
        if (c < numRows)
          column[c] = val;
        cols[c] = column;
      }
    }

    // Copy constructors from a matrix of the same type in any address space.
    METAL_FUNC matrix(const thread matrix<T,numCols,numRows>& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = vec<T,numRows>(m[c]);
    }
    METAL_FUNC matrix(const device matrix<T,numCols,numRows>& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = vec<T,numRows>(m[c]);
    }
    METAL_FUNC matrix(const threadgroup matrix<T,numCols,numRows>& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = vec<T,numRows>(m[c]);
    }
    METAL_FUNC matrix(const constant matrix<T,numCols,numRows>& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = vec<T,numRows>(m[c]);
    }

    // Explicit conversions from a half matrix of the same shape; each column
    // is converted half -> float with a vector cast.
    METAL_FUNC explicit matrix(const thread matrix<half,numCols,numRows>& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = vec<T,numRows>(m[c]);
    }
    METAL_FUNC explicit matrix(const device matrix<half,numCols,numRows>& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = vec<T,numRows>(m[c]);
    }
    METAL_FUNC explicit matrix(const threadgroup matrix<half,numCols,numRows>& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = vec<T,numRows>(m[c]);
    }
    METAL_FUNC explicit matrix(const constant matrix<half,numCols,numRows>& m) {
      for (int c = 0; c < numCols; ++c)
        cols[c] = vec<T,numRows>(m[c]);
    }

    // Builds the matrix from its three columns, given left to right.
    METAL_FUNC matrix(const vec<T,numRows> c0, const vec<T,numRows> c1, const vec<T,numRows> c2)
    {
      cols[0] = c0; cols[1] = c1; cols[2] = c2;
    }

    // Matrix Assignments
    // Column-by-column copy for every (destination space, source space) pair.
    // NOTE(review): the constant-qualified assignment and compound-assignment
    // overloads write through a constant-space object; constant memory is
    // read-only in Metal, so these presumably fail if ever instantiated — confirm.
    METAL_FUNC thread matrix<T,numCols,numRows>& operator=(const thread matrix<T,numCols,numRows>& m) thread
    { for (int c = 0; c < numCols; ++c) cols[c] = m[c]; return *this; }
    METAL_FUNC thread matrix<T,numCols,numRows>& operator=(const device matrix<T,numCols,numRows>& m) thread
    { for (int c = 0; c < numCols; ++c) cols[c] = m[c]; return *this; }
    METAL_FUNC thread matrix<T,numCols,numRows>& operator=(const threadgroup matrix<T,numCols,numRows>& m) thread
    { for (int c = 0; c < numCols; ++c) cols[c] = m[c]; return *this; }
    METAL_FUNC thread matrix<T,numCols,numRows>& operator=(const constant matrix<T,numCols,numRows>& m) thread
    { for (int c = 0; c < numCols; ++c) cols[c] = m[c]; return *this; }
    METAL_FUNC device matrix<T,numCols,numRows>& operator=(const thread matrix<T,numCols,numRows>& m) device
    { for (int c = 0; c < numCols; ++c) cols[c] = m[c]; return *this; }
    METAL_FUNC device matrix<T,numCols,numRows>& operator=(const device matrix<T,numCols,numRows>& m) device
    { for (int c = 0; c < numCols; ++c) cols[c] = m[c]; return *this; }
    METAL_FUNC device matrix<T,numCols,numRows>& operator=(const threadgroup matrix<T,numCols,numRows>& m) device
    { for (int c = 0; c < numCols; ++c) cols[c] = m[c]; return *this; }
    METAL_FUNC device matrix<T,numCols,numRows>& operator=(const constant matrix<T,numCols,numRows>& m) device
    { for (int c = 0; c < numCols; ++c) cols[c] = m[c]; return *this; }
    METAL_FUNC threadgroup matrix<T,numCols,numRows>& operator=(const thread matrix<T,numCols,numRows>& m) threadgroup
    { for (int c = 0; c < numCols; ++c) cols[c] = m[c]; return *this; }
    METAL_FUNC threadgroup matrix<T,numCols,numRows>& operator=(const device matrix<T,numCols,numRows>& m) threadgroup
    { for (int c = 0; c < numCols; ++c) cols[c] = m[c]; return *this; }
    METAL_FUNC threadgroup matrix<T,numCols,numRows>& operator=(const threadgroup matrix<T,numCols,numRows>& m) threadgroup
    { for (int c = 0; c < numCols; ++c) cols[c] = m[c]; return *this; }
    METAL_FUNC threadgroup matrix<T,numCols,numRows>& operator=(const constant matrix<T,numCols,numRows>& m) threadgroup
    { for (int c = 0; c < numCols; ++c) cols[c] = m[c]; return *this; }
    METAL_FUNC constant matrix<T,numCols,numRows>& operator=(const thread matrix<T,numCols,numRows>& m) constant
    { for (int c = 0; c < numCols; ++c) cols[c] = m[c]; return *this; }
    METAL_FUNC constant matrix<T,numCols,numRows>& operator=(const device matrix<T,numCols,numRows>& m) constant
    { for (int c = 0; c < numCols; ++c) cols[c] = m[c]; return *this; }
    METAL_FUNC constant matrix<T,numCols,numRows>& operator=(const threadgroup matrix<T,numCols,numRows>& m) constant
    { for (int c = 0; c < numCols; ++c) cols[c] = m[c]; return *this; }
    METAL_FUNC constant matrix<T,numCols,numRows>& operator=(const constant matrix<T,numCols,numRows>& m) constant
    { for (int c = 0; c < numCols; ++c) cols[c] = m[c]; return *this; }

    // 3.2 Matrix Operators

    // In-place scaling: multiply every column by the scalar v.
    METAL_FUNC thread matrix<T,numCols,numRows>& operator*= (const T v) thread
    { for (int c = 0; c < numCols; ++c) cols[c] *= v; return *this; }
    METAL_FUNC device matrix<T,numCols,numRows>& operator*= (const T v) device
    { for (int c = 0; c < numCols; ++c) cols[c] *= v; return *this; }
    METAL_FUNC threadgroup matrix<T,numCols,numRows>& operator*= (const T v) threadgroup
    { for (int c = 0; c < numCols; ++c) cols[c] *= v; return *this; }
    METAL_FUNC constant matrix<T,numCols,numRows>& operator*= (const T v) constant
    { for (int c = 0; c < numCols; ++c) cols[c] *= v; return *this; }

    // In-place addition, column by column, for every (source space, destination space) pair.
    METAL_FUNC thread matrix<T,numCols,numRows>& operator+= (thread const matrix<T,numCols,numRows>& m) thread
    { for (int c = 0; c < numCols; ++c) cols[c] += m[c]; return *this; }
    METAL_FUNC device matrix<T,numCols,numRows>& operator+= (thread const matrix<T,numCols,numRows>& m) device
    { for (int c = 0; c < numCols; ++c) cols[c] += m[c]; return *this; }
    METAL_FUNC threadgroup matrix<T,numCols,numRows>& operator+= (thread const matrix<T,numCols,numRows>& m) threadgroup
    { for (int c = 0; c < numCols; ++c) cols[c] += m[c]; return *this; }
    METAL_FUNC constant matrix<T,numCols,numRows>& operator+= (thread const matrix<T,numCols,numRows>& m) constant
    { for (int c = 0; c < numCols; ++c) cols[c] += m[c]; return *this; }
    METAL_FUNC thread matrix<T,numCols,numRows>& operator+= (device const matrix<T,numCols,numRows>& m) thread
    { for (int c = 0; c < numCols; ++c) cols[c] += m[c]; return *this; }
    METAL_FUNC device matrix<T,numCols,numRows>& operator+= (device const matrix<T,numCols,numRows>& m) device
    { for (int c = 0; c < numCols; ++c) cols[c] += m[c]; return *this; }
    METAL_FUNC threadgroup matrix<T,numCols,numRows>& operator+= (device const matrix<T,numCols,numRows>& m) threadgroup
    { for (int c = 0; c < numCols; ++c) cols[c] += m[c]; return *this; }
    METAL_FUNC constant matrix<T,numCols,numRows>& operator+= (device const matrix<T,numCols,numRows>& m) constant
    { for (int c = 0; c < numCols; ++c) cols[c] += m[c]; return *this; }
    METAL_FUNC thread matrix<T,numCols,numRows>& operator+= (threadgroup const matrix<T,numCols,numRows>& m) thread
    { for (int c = 0; c < numCols; ++c) cols[c] += m[c]; return *this; }
    METAL_FUNC device matrix<T,numCols,numRows>& operator+= (threadgroup const matrix<T,numCols,numRows>& m) device
    { for (int c = 0; c < numCols; ++c) cols[c] += m[c]; return *this; }
    METAL_FUNC threadgroup matrix<T,numCols,numRows>& operator+= (threadgroup const matrix<T,numCols,numRows>& m) threadgroup
    { for (int c = 0; c < numCols; ++c) cols[c] += m[c]; return *this; }
    METAL_FUNC constant matrix<T,numCols,numRows>& operator+= (threadgroup const matrix<T,numCols,numRows>& m) constant
    { for (int c = 0; c < numCols; ++c) cols[c] += m[c]; return *this; }
    METAL_FUNC thread matrix<T,numCols,numRows>& operator+= (constant const matrix<T,numCols,numRows>& m) thread
    { for (int c = 0; c < numCols; ++c) cols[c] += m[c]; return *this; }
    METAL_FUNC device matrix<T,numCols,numRows>& operator+= (constant const matrix<T,numCols,numRows>& m) device
    { for (int c = 0; c < numCols; ++c) cols[c] += m[c]; return *this; }
    METAL_FUNC threadgroup matrix<T,numCols,numRows>& operator+= (constant const matrix<T,numCols,numRows>& m) threadgroup
    { for (int c = 0; c < numCols; ++c) cols[c] += m[c]; return *this; }
    METAL_FUNC constant matrix<T,numCols,numRows>& operator+= (constant const matrix<T,numCols,numRows>& m) constant
    { for (int c = 0; c < numCols; ++c) cols[c] += m[c]; return *this; }

    // In-place subtraction, column by column, for every (source space, destination space) pair.
    METAL_FUNC thread matrix<T,numCols,numRows>& operator-= (thread const matrix<T,numCols,numRows>& m) thread
    { for (int c = 0; c < numCols; ++c) cols[c] -= m[c]; return *this; }
    METAL_FUNC device matrix<T,numCols,numRows>& operator-= (thread const matrix<T,numCols,numRows>& m) device
    { for (int c = 0; c < numCols; ++c) cols[c] -= m[c]; return *this; }
    METAL_FUNC threadgroup matrix<T,numCols,numRows>& operator-= (thread const matrix<T,numCols,numRows>& m) threadgroup
    { for (int c = 0; c < numCols; ++c) cols[c] -= m[c]; return *this; }
    METAL_FUNC constant matrix<T,numCols,numRows>& operator-= (thread const matrix<T,numCols,numRows>& m) constant
    { for (int c = 0; c < numCols; ++c) cols[c] -= m[c]; return *this; }
    METAL_FUNC thread matrix<T,numCols,numRows>& operator-= (device const matrix<T,numCols,numRows>& m) thread
    { for (int c = 0; c < numCols; ++c) cols[c] -= m[c]; return *this; }
    METAL_FUNC device matrix<T,numCols,numRows>& operator-= (device const matrix<T,numCols,numRows>& m) device
    { for (int c = 0; c < numCols; ++c) cols[c] -= m[c]; return *this; }
    METAL_FUNC threadgroup matrix<T,numCols,numRows>& operator-= (device const matrix<T,numCols,numRows>& m) threadgroup
    { for (int c = 0; c < numCols; ++c) cols[c] -= m[c]; return *this; }
    METAL_FUNC constant matrix<T,numCols,numRows>& operator-= (device const matrix<T,numCols,numRows>& m) constant
    { for (int c = 0; c < numCols; ++c) cols[c] -= m[c]; return *this; }
    METAL_FUNC thread matrix<T,numCols,numRows>& operator-= (threadgroup const matrix<T,numCols,numRows>& m) thread
    { for (int c = 0; c < numCols; ++c) cols[c] -= m[c]; return *this; }
    METAL_FUNC device matrix<T,numCols,numRows>& operator-= (threadgroup const matrix<T,numCols,numRows>& m) device
    { for (int c = 0; c < numCols; ++c) cols[c] -= m[c]; return *this; }
    METAL_FUNC threadgroup matrix<T,numCols,numRows>& operator-= (threadgroup const matrix<T,numCols,numRows>& m) threadgroup
    { for (int c = 0; c < numCols; ++c) cols[c] -= m[c]; return *this; }
    METAL_FUNC constant matrix<T,numCols,numRows>& operator-= (threadgroup const matrix<T,numCols,numRows>& m) constant
    { for (int c = 0; c < numCols; ++c) cols[c] -= m[c]; return *this; }
    METAL_FUNC thread matrix<T,numCols,numRows>& operator-= (constant const matrix<T,numCols,numRows>& m) thread
    { for (int c = 0; c < numCols; ++c) cols[c] -= m[c]; return *this; }
    METAL_FUNC device matrix<T,numCols,numRows>& operator-= (constant const matrix<T,numCols,numRows>& m) device
    { for (int c = 0; c < numCols; ++c) cols[c] -= m[c]; return *this; }
    METAL_FUNC threadgroup matrix<T,numCols,numRows>& operator-= (constant const matrix<T,numCols,numRows>& m) threadgroup
    { for (int c = 0; c < numCols; ++c) cols[c] -= m[c]; return *this; }
    METAL_FUNC constant matrix<T,numCols,numRows>& operator-= (constant const matrix<T,numCols,numRows>& m) constant
    { for (int c = 0; c < numCols; ++c) cols[c] -= m[c]; return *this; }
  };

#pragma mark matrix4 float
  template <int numRows> class matrix<float,4,numRows> {
    typedef float T;
    enum { numCols = 4 };
    vec<T,numRows> cols[numCols];
  public:
    // 2.2.2 Accessing Matrix Components
    METAL_FUNC thread  vec<T,numRows>& operator[] (int r) thread  { return cols[r]; }
    METAL_FUNC const thread  vec<T,numRows>& operator[] (int r) thread  const { return cols[r]; }
    METAL_FUNC device  vec<T,numRows>& operator[] (int r) device  { return cols[r]; }
    METAL_FUNC const device  vec<T,numRows>& operator[] (int r) device  const { return cols[r]; }
    METAL_FUNC threadgroup  vec<T,numRows>& operator[] (int r) threadgroup  { return cols[r]; }
    METAL_FUNC const threadgroup  vec<T,numRows>& operator[] (int r) threadgroup  const { return cols[r]; }
    METAL_FUNC constant  vec<T,numRows>& operator[] (int r) constant  { return cols[r]; }
    METAL_FUNC const constant  vec<T,numRows>& operator[] (int r) constant  const { return cols[r]; }

    // 2.2.4 Matrix Constructors
    METAL_FUNC matrix() { }

    METAL_FUNC explicit matrix(const thread  T& val) {
      for (int r=0;r!=numCols;++r) {
        cols[r] = vec<T,numRows>(T(0));
        if (r < numRows)
          cols[r][r] = val;
      }
    }
    METAL_FUNC explicit matrix(const device  T& val) {
      for (int r=0;r!=numCols;++r) {
        cols[r] = vec<T,numRows>(T(0));
        if (r < numRows)
          cols[r][r] = val;
      }
    }
    METAL_FUNC explicit matrix(const threadgroup  T& val) {
      for (int r=0;r!=numCols;++r) {
        cols[r] = vec<T,numRows>(T(0));
        if (r < numRows)
          cols[r][r] = val;
      }
    }
    METAL_FUNC explicit matrix(const constant  T& val) {
      for (int r=0;r!=numCols;++r) {
        cols[r] = vec<T,numRows>(T(0));
        if (r < numRows)
          cols[r][r] = val;
      }
    }

    METAL_FUNC matrix(const thread  matrix<T,numCols,numRows>& m) {
      for (int r=0;r!=numCols;++r)
        cols[r] = static_cast< vec<T,numRows> >(m[r]);
    }
    METAL_FUNC matrix(const device  matrix<T,numCols,numRows>& m) {
      for (int r=0;r!=numCols;++r)
        cols[r] = static_cast< vec<T,numRows> >(m[r]);
    }
    METAL_FUNC matrix(const threadgroup  matrix<T,numCols,numRows>& m) {
      for (int r=0;r!=numCols;++r)
        cols[r] = static_cast< vec<T,numRows> >(m[r]);
    }
    METAL_FUNC matrix(const constant  matrix<T,numCols,numRows>& m) {
      for (int r=0;r!=numCols;++r)
        cols[r] = static_cast< vec<T,numRows> >(m[r]);
    }

    METAL_FUNC explicit matrix(const thread  matrix<half,numCols,numRows>& m) {
      for (int r=0;r!=numCols;++r)
        cols[r] = static_cast< vec<T,numRows> >(m[r]);
    }
    METAL_FUNC explicit matrix(const device  matrix<half,numCols,numRows>& m) {
      for (int r=0;r!=numCols;++r)
        cols[r] = static_cast< vec<T,numRows> >(m[r]);
    }
    METAL_FUNC explicit matrix(const threadgroup  matrix<half,numCols,numRows>& m) {
      for (int r=0;r!=numCols;++r)
        cols[r] = static_cast< vec<T,numRows> >(m[r]);
    }
    METAL_FUNC explicit matrix(const constant  matrix<half,numCols,numRows>& m) {
      for (int r=0;r!=numCols;++r)
        cols[r] = static_cast< vec<T,numRows> >(m[r]);
    }

    METAL_FUNC matrix(const vec<T,numRows> c0, const vec<T,numRows> c1, const vec<T,numRows> c2, const vec<T,numRows> c3)
    {
      cols[0] = c0;
      cols[1] = c1;
      cols[2] = c2;
      cols[3] = c3;
    }

    // Matrix Assignments
    METAL_FUNC thread  matrix<T,numCols,numRows>& operator=(const thread  matrix<T,numCols,numRows>& m) thread  {
      for (int r=0;r!=numCols;++r)
        cols[r] = m[r];
      return *this;
    }
    METAL_FUNC thread  matrix<T,numCols,numRows>& operator=(const device  matrix<T,numCols,numRows>& m) thread  {
      for (int r=0;r!=numCols;++r)
        cols[r] = m[r];
      return *this;
    }
    METAL_FUNC thread  matrix<T,numCols,numRows>& operator=(const threadgroup  matrix<T,numCols,numRows>& m) thread  {
      for (int r=0;r!=numCols;++r)
        cols[r] = m[r];
      return *this;
    }
    METAL_FUNC thread  matrix<T,numCols,numRows>& operator=(const constant  matrix<T,numCols,numRows>& m) thread  {
      for (int r=0;r!=numCols;++r)
        cols[r] = m[r];
      return *this;
    }
    METAL_FUNC device  matrix<T,numCols,numRows>& operator=(const thread  matrix<T,numCols,numRows>& m) device  {
      for (int r=0;r!=numCols;++r)
        cols[r] = m[r];
      return *this;
    }
    METAL_FUNC device  matrix<T,numCols,numRows>& operator=(const device  matrix<T,numCols,numRows>& m) device  {
      for (int r=0;r!=numCols;++r)
        cols[r] = m[r];
      return *this;
    }
    METAL_FUNC device  matrix<T,numCols,numRows>& operator=(const threadgroup  matrix<T,numCols,numRows>& m) device  {
      for (int r=0;r!=numCols;++r)
        cols[r] = m[r];
      return *this;
    }
    METAL_FUNC device  matrix<T,numCols,numRows>& operator=(const constant  matrix<T,numCols,numRows>& m) device  {
      for (int r=0;r!=numCols;++r)
        cols[r] = m[r];
      return *this;
    }
    METAL_FUNC threadgroup  matrix<T,numCols,numRows>& operator=(const thread  matrix<T,numCols,numRows>& m) threadgroup  {
      for (int r=0;r!=numCols;++r)
        cols[r] = m[r];
      return *this;
    }
    METAL_FUNC threadgroup  matrix<T,numCols,numRows>& operator=(const device  matrix<T,numCols,numRows>& m) threadgroup  {
      for (int r=0;r!=numCols;++r)
        cols[r] = m[r];
      return *this;
    }
    METAL_FUNC threadgroup  matrix<T,numCols,numRows>& operator=(const threadgroup  matrix<T,numCols,numRows>& m) threadgroup  {
      for (int r=0;r!=numCols;++r)
        cols[r] = m[r];
      return *this;
    }
    METAL_FUNC threadgroup  matrix<T,numCols,numRows>& operator=(const constant  matrix<T,numCols,numRows>& m) threadgroup  {
      for (int r=0;r!=numCols;++r)
        cols[r] = m[r];
      return *this;
    }
    METAL_FUNC constant  matrix<T,numCols,numRows>& operator=(const thread  matrix<T,numCols,numRows>& m) constant  {
      for (int r=0;r!=numCols;++r)
        cols[r] = m[r];
      return *this;
    }
    METAL_FUNC constant  matrix<T,numCols,numRows>& operator=(const device  matrix<T,numCols,numRows>& m) constant  {
      for (int r=0;r!=numCols;++r)
        cols[r] = m[r];
      return *this;
    }
    METAL_FUNC constant  matrix<T,numCols,numRows>& operator=(const threadgroup  matrix<T,numCols,numRows>& m) constant  {
      for (int r=0;r!=numCols;++r)
        cols[r] = m[r];
      return *this;
    }
    METAL_FUNC constant  matrix<T,numCols,numRows>& operator=(const constant  matrix<T,numCols,numRows>& m) constant  {
      for (int r=0;r!=numCols;++r)
        cols[r] = m[r];
      return *this;
    }

    // 3.2 Matrix Operators
  METAL_FUNC thread  matrix<T,numCols,numRows>& operator*= (const T v) thread 
  {
    for (int r=0;r!=numCols;++r)
      cols[r] *= v;
    return *this;
  }
  METAL_FUNC device  matrix<T,numCols,numRows>& operator*= (const T v) device 
  {
    for (int r=0;r!=numCols;++r)
      cols[r] *= v;
    return *this;
  }
  METAL_FUNC threadgroup  matrix<T,numCols,numRows>& operator*= (const T v) threadgroup 
  {
    for (int r=0;r!=numCols;++r)
      cols[r] *= v;
    return *this;
  }
  METAL_FUNC constant  matrix<T,numCols,numRows>& operator*= (const T v) constant 
  {
    for (int r=0;r!=numCols;++r)
      cols[r] *= v;
    return *this;
  }
    
  METAL_FUNC thread  matrix<T,numCols,numRows>& operator+= (thread const matrix<T,numCols,numRows>& m) thread 
  {
    for (int i=0;i!=numCols;++i)
      cols[i] += m[i];
    return *this;
  }
  METAL_FUNC device  matrix<T,numCols,numRows>& operator+= (thread const matrix<T,numCols,numRows>& m) device 
  {
    for (int i=0;i!=numCols;++i)
      cols[i] += m[i];
    return *this;
  }
  METAL_FUNC threadgroup  matrix<T,numCols,numRows>& operator+= (thread const matrix<T,numCols,numRows>& m) threadgroup 
  {
    for (int i=0;i!=numCols;++i)
      cols[i] += m[i];
    return *this;
  }
  METAL_FUNC constant  matrix<T,numCols,numRows>& operator+= (thread const matrix<T,numCols,numRows>& m) constant 
  {
    for (int i=0;i!=numCols;++i)
      cols[i] += m[i];
    return *this;
  }
  METAL_FUNC thread  matrix<T,numCols,numRows>& operator+= (device const matrix<T,numCols,numRows>& m) thread 
  {
    for (int i=0;i!=numCols;++i)
      cols[i] += m[i];
    return *this;
  }
  METAL_FUNC device  matrix<T,numCols,numRows>& operator+= (device const matrix<T,numCols,numRows>& m) device 
  {
    for (int i=0;i!=numCols;++i)
      cols[i] += m[i];
    return *this;
  }
  METAL_FUNC threadgroup  matrix<T,numCols,numRows>& operator+= (device const matrix<T,numCols,numRows>& m) threadgroup 
  {
    for (int i=0;i!=numCols;++i)
      cols[i] += m[i];
    return *this;
  }
  METAL_FUNC constant  matrix<T,numCols,numRows>& operator+= (device const matrix<T,numCols,numRows>& m) constant 
  {
    for (int i=0;i!=numCols;++i)
      cols[i] += m[i];
    return *this;
  }
  METAL_FUNC thread  matrix<T,numCols,numRows>& operator+= (threadgroup const matrix<T,numCols,numRows>& m) thread 
  {
    for (int i=0;i!=numCols;++i)
      cols[i] += m[i];
    return *this;
  }
  METAL_FUNC device  matrix<T,numCols,numRows>& operator+= (threadgroup const matrix<T,numCols,numRows>& m) device 
  {
    for (int i=0;i!=numCols;++i)
      cols[i] += m[i];
    return *this;
  }
  METAL_FUNC threadgroup  matrix<T,numCols,numRows>& operator+= (threadgroup const matrix<T,numCols,numRows>& m) threadgroup 
  {
    for (int i=0;i!=numCols;++i)
      cols[i] += m[i];
    return *this;
  }
  METAL_FUNC constant  matrix<T,numCols,numRows>& operator+= (threadgroup const matrix<T,numCols,numRows>& m) constant 
  {
    for (int i=0;i!=numCols;++i)
      cols[i] += m[i];
    return *this;
  }
  METAL_FUNC thread  matrix<T,numCols,numRows>& operator+= (constant const matrix<T,numCols,numRows>& m) thread 
  {
    for (int i=0;i!=numCols;++i)
      cols[i] += m[i];
    return *this;
  }
  METAL_FUNC device  matrix<T,numCols,numRows>& operator+= (constant const matrix<T,numCols,numRows>& m) device 
  {
    for (int i=0;i!=numCols;++i)
      cols[i] += m[i];
    return *this;
  }
  METAL_FUNC threadgroup  matrix<T,numCols,numRows>& operator+= (constant const matrix<T,numCols,numRows>& m) threadgroup 
  {
    for (int i=0;i!=numCols;++i)
      cols[i] += m[i];
    return *this;
  }
  METAL_FUNC constant  matrix<T,numCols,numRows>& operator+= (constant const matrix<T,numCols,numRows>& m) constant 
  {
    for (int i=0;i!=numCols;++i)
      cols[i] += m[i];
    return *this;
  }

  METAL_FUNC thread  matrix<T,numCols,numRows>& operator-= (thread const matrix<T,numCols,numRows>& m) thread 
  {
    for (int i=0;i!=numCols;++i)
      cols[i] -= m[i];
    return *this;
  }
  METAL_FUNC device  matrix<T,numCols,numRows>& operator-= (thread const matrix<T,numCols,numRows>& m) device 
  {
    for (int i=0;i!=numCols;++i)
      cols[i] -= m[i];
    return *this;
  }
  METAL_FUNC threadgroup  matrix<T,numCols,numRows>& operator-= (thread const matrix<T,numCols,numRows>& m) threadgroup 
  {
    for (int i=0;i!=numCols;++i)
      cols[i] -= m[i];
    return *this;
  }
  METAL_FUNC constant  matrix<T,numCols,numRows>& operator-= (thread const matrix<T,numCols,numRows>& m) constant 
  {
    for (int i=0;i!=numCols;++i)
      cols[i] -= m[i];
    return *this;
  }
  METAL_FUNC thread  matrix<T,numCols,numRows>& operator-= (device const matrix<T,numCols,numRows>& m) thread 
  {
    for (int i=0;i!=numCols;++i)
      cols[i] -= m[i];
    return *this;
  }
  METAL_FUNC device  matrix<T,numCols,numRows>& operator-= (device const matrix<T,numCols,numRows>& m) device 
  {
    for (int i=0;i!=numCols;++i)
      cols[i] -= m[i];
    return *this;
  }
  METAL_FUNC threadgroup  matrix<T,numCols,numRows>& operator-= (device const matrix<T,numCols,numRows>& m) threadgroup 
  {
    for (int i=0;i!=numCols;++i)
      cols[i] -= m[i];
    return *this;
  }
  METAL_FUNC constant  matrix<T,numCols,numRows>& operator-= (device const matrix<T,numCols,numRows>& m) constant 
  {
    for (int i=0;i!=numCols;++i)
      cols[i] -= m[i];
    return *this;
  }
  METAL_FUNC thread  matrix<T,numCols,numRows>& operator-= (threadgroup const matrix<T,numCols,numRows>& m) thread 
  {
    for (int i=0;i!=numCols;++i)
      cols[i] -= m[i];
    return *this;
  }
  METAL_FUNC device  matrix<T,numCols,numRows>& operator-= (threadgroup const matrix<T,numCols,numRows>& m) device 
  {
    for (int i=0;i!=numCols;++i)
      cols[i] -= m[i];
    return *this;
  }
  METAL_FUNC threadgroup  matrix<T,numCols,numRows>& operator-= (threadgroup const matrix<T,numCols,numRows>& m) threadgroup 
  {
    for (int i=0;i!=numCols;++i)
      cols[i] -= m[i];
    return *this;
  }
  METAL_FUNC constant  matrix<T,numCols,numRows>& operator-= (threadgroup const matrix<T,numCols,numRows>& m) constant 
  {
    for (int i=0;i!=numCols;++i)
      cols[i] -= m[i];
    return *this;
  }
  METAL_FUNC thread  matrix<T,numCols,numRows>& operator-= (constant const matrix<T,numCols,numRows>& m) thread 
  {
    for (int i=0;i!=numCols;++i)
      cols[i] -= m[i];
    return *this;
  }
  METAL_FUNC device  matrix<T,numCols,numRows>& operator-= (constant const matrix<T,numCols,numRows>& m) device 
  {
    for (int i=0;i!=numCols;++i)
      cols[i] -= m[i];
    return *this;
  }
  METAL_FUNC threadgroup  matrix<T,numCols,numRows>& operator-= (constant const matrix<T,numCols,numRows>& m) threadgroup 
  {
    for (int i=0;i!=numCols;++i)
      cols[i] -= m[i];
    return *this;
  }
  METAL_FUNC constant  matrix<T,numCols,numRows>& operator-= (constant const matrix<T,numCols,numRows>& m) constant 
  {
    for (int i=0;i!=numCols;++i)
      cols[i] -= m[i];
    return *this;
  }

  };

  // 3.2 Matrix Operators (non-member)
#pragma mark operator* vector matrix
  // Row vector times a thread-space matrix: entry c of the result is dot(column c of m, v).
  template <typename T, int numCols, int numRows>
  METAL_FUNC vec<T,numCols> operator* (const vec<T,numRows> v, thread const matrix<T,numCols,numRows>& m)
  {
    vec<T,numCols> result(T(0));
    for (int c = 0; c < numCols; c++)
      result[c] = dot(m[c], v);
    return result;
  }
  // Row vector times a device-space matrix: entry c of the result is dot(column c of m, v).
  template <typename T, int numCols, int numRows>
  METAL_FUNC vec<T,numCols> operator* (const vec<T,numRows> v, device const matrix<T,numCols,numRows>& m)
  {
    vec<T,numCols> result(T(0));
    for (int c = 0; c < numCols; c++)
      result[c] = dot(m[c], v);
    return result;
  }
  // Row vector times a threadgroup-space matrix: entry c of the result is dot(column c of m, v).
  template <typename T, int numCols, int numRows>
  METAL_FUNC vec<T,numCols> operator* (const vec<T,numRows> v, threadgroup const matrix<T,numCols,numRows>& m)
  {
    vec<T,numCols> result(T(0));
    for (int c = 0; c < numCols; c++)
      result[c] = dot(m[c], v);
    return result;
  }
  // Row vector times a constant-space matrix: entry c of the result is dot(column c of m, v).
  template <typename T, int numCols, int numRows>
  METAL_FUNC vec<T,numCols> operator* (const vec<T,numRows> v, constant const matrix<T,numCols,numRows>& m)
  {
    vec<T,numCols> result(T(0));
    for (int c = 0; c < numCols; c++)
      result[c] = dot(m[c], v);
    return result;
  }

#pragma mark operator* matrix vector
  // Thread-space matrix times column vector: columns of m weighted by the
  // components of v, accumulated left to right starting from column 0.
  template <typename T, int numCols, int numRows>
  METAL_FUNC vec<T,numRows> operator* (thread const matrix<T,numCols,numRows>& m, const vec<T,numCols> v)
  {
    vec<T,numRows> sum = v[0] * m[0];
    for (int c = 1; c < numCols; c++)
      sum += v[c] * m[c];
    return sum;
  }
  // Device-space matrix times column vector: columns of m weighted by the
  // components of v, accumulated left to right starting from column 0.
  template <typename T, int numCols, int numRows>
  METAL_FUNC vec<T,numRows> operator* (device const matrix<T,numCols,numRows>& m, const vec<T,numCols> v)
  {
    vec<T,numRows> sum = v[0] * m[0];
    for (int c = 1; c < numCols; c++)
      sum += v[c] * m[c];
    return sum;
  }
  // Threadgroup-space matrix times column vector: columns of m weighted by the
  // components of v, accumulated left to right starting from column 0.
  template <typename T, int numCols, int numRows>
  METAL_FUNC vec<T,numRows> operator* (threadgroup const matrix<T,numCols,numRows>& m, const vec<T,numCols> v)
  {
    vec<T,numRows> sum = v[0] * m[0];
    for (int c = 1; c < numCols; c++)
      sum += v[c] * m[c];
    return sum;
  }
  // Constant-space matrix times column vector: columns of m weighted by the
  // components of v, accumulated left to right starting from column 0.
  template <typename T, int numCols, int numRows>
  METAL_FUNC vec<T,numRows> operator* (constant const matrix<T,numCols,numRows>& m, const vec<T,numCols> v)
  {
    vec<T,numRows> sum = v[0] * m[0];
    for (int c = 1; c < numCols; c++)
      sum += v[c] * m[c];
    return sum;
  }

#pragma mark operator* matrix matrix
  // Matrix * matrix. m0 has numCols columns of numRows entries; m1 has
  // numCols_ columns of numCols entries. The product has numCols_ columns of
  // numRows entries, and column c of the product is m0 * m1[c] (the
  // matrix * vector operator applied to column c of m1).
  // One overload for every pairing of operand address spaces.

  // thread m0, thread m1
  template <typename T, int numCols, int numRows, int numCols_>
  METAL_FUNC matrix<T,numCols_,numRows> operator* (thread const matrix<T,numCols,numRows>& m0, thread const matrix<T,numCols_,numCols>& m1)
  {
    matrix<T,numCols_,numRows> prod;
    for (int col = 0; col < numCols_; ++col)
      prod[col] = m0 * m1[col];
    return prod;
  }
  // thread m0, device m1
  template <typename T, int numCols, int numRows, int numCols_>
  METAL_FUNC matrix<T,numCols_,numRows> operator* (thread const matrix<T,numCols,numRows>& m0, device const matrix<T,numCols_,numCols>& m1)
  {
    matrix<T,numCols_,numRows> prod;
    for (int col = 0; col < numCols_; ++col)
      prod[col] = m0 * m1[col];
    return prod;
  }
  // thread m0, threadgroup m1
  template <typename T, int numCols, int numRows, int numCols_>
  METAL_FUNC matrix<T,numCols_,numRows> operator* (thread const matrix<T,numCols,numRows>& m0, threadgroup const matrix<T,numCols_,numCols>& m1)
  {
    matrix<T,numCols_,numRows> prod;
    for (int col = 0; col < numCols_; ++col)
      prod[col] = m0 * m1[col];
    return prod;
  }
  // thread m0, constant m1
  template <typename T, int numCols, int numRows, int numCols_>
  METAL_FUNC matrix<T,numCols_,numRows> operator* (thread const matrix<T,numCols,numRows>& m0, constant const matrix<T,numCols_,numCols>& m1)
  {
    matrix<T,numCols_,numRows> prod;
    for (int col = 0; col < numCols_; ++col)
      prod[col] = m0 * m1[col];
    return prod;
  }
  // device m0, thread m1
  template <typename T, int numCols, int numRows, int numCols_>
  METAL_FUNC matrix<T,numCols_,numRows> operator* (device const matrix<T,numCols,numRows>& m0, thread const matrix<T,numCols_,numCols>& m1)
  {
    matrix<T,numCols_,numRows> prod;
    for (int col = 0; col < numCols_; ++col)
      prod[col] = m0 * m1[col];
    return prod;
  }
  // device m0, device m1
  template <typename T, int numCols, int numRows, int numCols_>
  METAL_FUNC matrix<T,numCols_,numRows> operator* (device const matrix<T,numCols,numRows>& m0, device const matrix<T,numCols_,numCols>& m1)
  {
    matrix<T,numCols_,numRows> prod;
    for (int col = 0; col < numCols_; ++col)
      prod[col] = m0 * m1[col];
    return prod;
  }
  // device m0, threadgroup m1
  template <typename T, int numCols, int numRows, int numCols_>
  METAL_FUNC matrix<T,numCols_,numRows> operator* (device const matrix<T,numCols,numRows>& m0, threadgroup const matrix<T,numCols_,numCols>& m1)
  {
    matrix<T,numCols_,numRows> prod;
    for (int col = 0; col < numCols_; ++col)
      prod[col] = m0 * m1[col];
    return prod;
  }
  // device m0, constant m1
  template <typename T, int numCols, int numRows, int numCols_>
  METAL_FUNC matrix<T,numCols_,numRows> operator* (device const matrix<T,numCols,numRows>& m0, constant const matrix<T,numCols_,numCols>& m1)
  {
    matrix<T,numCols_,numRows> prod;
    for (int col = 0; col < numCols_; ++col)
      prod[col] = m0 * m1[col];
    return prod;
  }
  // threadgroup m0, thread m1
  template <typename T, int numCols, int numRows, int numCols_>
  METAL_FUNC matrix<T,numCols_,numRows> operator* (threadgroup const matrix<T,numCols,numRows>& m0, thread const matrix<T,numCols_,numCols>& m1)
  {
    matrix<T,numCols_,numRows> prod;
    for (int col = 0; col < numCols_; ++col)
      prod[col] = m0 * m1[col];
    return prod;
  }
  // threadgroup m0, device m1
  template <typename T, int numCols, int numRows, int numCols_>
  METAL_FUNC matrix<T,numCols_,numRows> operator* (threadgroup const matrix<T,numCols,numRows>& m0, device const matrix<T,numCols_,numCols>& m1)
  {
    matrix<T,numCols_,numRows> prod;
    for (int col = 0; col < numCols_; ++col)
      prod[col] = m0 * m1[col];
    return prod;
  }
  // threadgroup m0, threadgroup m1
  template <typename T, int numCols, int numRows, int numCols_>
  METAL_FUNC matrix<T,numCols_,numRows> operator* (threadgroup const matrix<T,numCols,numRows>& m0, threadgroup const matrix<T,numCols_,numCols>& m1)
  {
    matrix<T,numCols_,numRows> prod;
    for (int col = 0; col < numCols_; ++col)
      prod[col] = m0 * m1[col];
    return prod;
  }
  // threadgroup m0, constant m1
  template <typename T, int numCols, int numRows, int numCols_>
  METAL_FUNC matrix<T,numCols_,numRows> operator* (threadgroup const matrix<T,numCols,numRows>& m0, constant const matrix<T,numCols_,numCols>& m1)
  {
    matrix<T,numCols_,numRows> prod;
    for (int col = 0; col < numCols_; ++col)
      prod[col] = m0 * m1[col];
    return prod;
  }
  // constant m0, thread m1
  template <typename T, int numCols, int numRows, int numCols_>
  METAL_FUNC matrix<T,numCols_,numRows> operator* (constant const matrix<T,numCols,numRows>& m0, thread const matrix<T,numCols_,numCols>& m1)
  {
    matrix<T,numCols_,numRows> prod;
    for (int col = 0; col < numCols_; ++col)
      prod[col] = m0 * m1[col];
    return prod;
  }
  // constant m0, device m1
  template <typename T, int numCols, int numRows, int numCols_>
  METAL_FUNC matrix<T,numCols_,numRows> operator* (constant const matrix<T,numCols,numRows>& m0, device const matrix<T,numCols_,numCols>& m1)
  {
    matrix<T,numCols_,numRows> prod;
    for (int col = 0; col < numCols_; ++col)
      prod[col] = m0 * m1[col];
    return prod;
  }
  // constant m0, threadgroup m1
  template <typename T, int numCols, int numRows, int numCols_>
  METAL_FUNC matrix<T,numCols_,numRows> operator* (constant const matrix<T,numCols,numRows>& m0, threadgroup const matrix<T,numCols_,numCols>& m1)
  {
    matrix<T,numCols_,numRows> prod;
    for (int col = 0; col < numCols_; ++col)
      prod[col] = m0 * m1[col];
    return prod;
  }
  // constant m0, constant m1
  template <typename T, int numCols, int numRows, int numCols_>
  METAL_FUNC matrix<T,numCols_,numRows> operator* (constant const matrix<T,numCols,numRows>& m0, constant const matrix<T,numCols_,numCols>& m1)
  {
    matrix<T,numCols_,numRows> prod;
    for (int col = 0; col < numCols_; ++col)
      prod[col] = m0 * m1[col];
    return prod;
  }

#pragma mark operator* scalar matrix
  // Scalar * matrix: multiplies every column of m by v (scalar on the left).
  // One overload per address space of the matrix operand.

  // v * (thread matrix)
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator* (const T v, thread const matrix<T,numCols,numRows>& m)
  {
    matrix<T,numCols,numRows> scaled;
    for (int col = 0; col < numCols; ++col)
      scaled[col] = v * m[col];
    return scaled;
  }
  // v * (device matrix)
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator* (const T v, device const matrix<T,numCols,numRows>& m)
  {
    matrix<T,numCols,numRows> scaled;
    for (int col = 0; col < numCols; ++col)
      scaled[col] = v * m[col];
    return scaled;
  }
  // v * (threadgroup matrix)
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator* (const T v, threadgroup const matrix<T,numCols,numRows>& m)
  {
    matrix<T,numCols,numRows> scaled;
    for (int col = 0; col < numCols; ++col)
      scaled[col] = v * m[col];
    return scaled;
  }
  // v * (constant matrix)
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator* (const T v, constant const matrix<T,numCols,numRows>& m)
  {
    matrix<T,numCols,numRows> scaled;
    for (int col = 0; col < numCols; ++col)
      scaled[col] = v * m[col];
    return scaled;
  }

#pragma mark operator* matrix scalar
  // Matrix * scalar: multiplies every column of m by v (scalar on the right).
  // One overload per address space of the matrix operand.

  // (thread matrix) * v
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator* (thread const matrix<T,numCols,numRows>& m, const T v)
  {
    matrix<T,numCols,numRows> scaled;
    for (int col = 0; col < numCols; ++col)
      scaled[col] = m[col] * v;
    return scaled;
  }
  // (device matrix) * v
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator* (device const matrix<T,numCols,numRows>& m, const T v)
  {
    matrix<T,numCols,numRows> scaled;
    for (int col = 0; col < numCols; ++col)
      scaled[col] = m[col] * v;
    return scaled;
  }
  // (threadgroup matrix) * v
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator* (threadgroup const matrix<T,numCols,numRows>& m, const T v)
  {
    matrix<T,numCols,numRows> scaled;
    for (int col = 0; col < numCols; ++col)
      scaled[col] = m[col] * v;
    return scaled;
  }
  // (constant matrix) * v
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator* (constant const matrix<T,numCols,numRows>& m, const T v)
  {
    matrix<T,numCols,numRows> scaled;
    for (int col = 0; col < numCols; ++col)
      scaled[col] = m[col] * v;
    return scaled;
  }

#pragma mark operator+
  // Element-wise matrix addition. Copies m0 into a local matrix, then
  // accumulates m1 into it with the member operator+= (declared in the matrix
  // class, not shown here). One overload for every pairing of address spaces.

  // thread m0 + thread m1
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator+ (thread const matrix<T,numCols,numRows>& m0, thread const matrix<T,numCols,numRows>& m1)
  {
    matrix<T,numCols,numRows> sum = m0;
    sum += m1;
    return sum;
  }
  // thread m0 + device m1
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator+ (thread const matrix<T,numCols,numRows>& m0, device const matrix<T,numCols,numRows>& m1)
  {
    matrix<T,numCols,numRows> sum = m0;
    sum += m1;
    return sum;
  }
  // thread m0 + threadgroup m1
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator+ (thread const matrix<T,numCols,numRows>& m0, threadgroup const matrix<T,numCols,numRows>& m1)
  {
    matrix<T,numCols,numRows> sum = m0;
    sum += m1;
    return sum;
  }
  // thread m0 + constant m1
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator+ (thread const matrix<T,numCols,numRows>& m0, constant const matrix<T,numCols,numRows>& m1)
  {
    matrix<T,numCols,numRows> sum = m0;
    sum += m1;
    return sum;
  }
  // device m0 + thread m1
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator+ (device const matrix<T,numCols,numRows>& m0, thread const matrix<T,numCols,numRows>& m1)
  {
    matrix<T,numCols,numRows> sum = m0;
    sum += m1;
    return sum;
  }
  // device m0 + device m1
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator+ (device const matrix<T,numCols,numRows>& m0, device const matrix<T,numCols,numRows>& m1)
  {
    matrix<T,numCols,numRows> sum = m0;
    sum += m1;
    return sum;
  }
  // device m0 + threadgroup m1
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator+ (device const matrix<T,numCols,numRows>& m0, threadgroup const matrix<T,numCols,numRows>& m1)
  {
    matrix<T,numCols,numRows> sum = m0;
    sum += m1;
    return sum;
  }
  // device m0 + constant m1
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator+ (device const matrix<T,numCols,numRows>& m0, constant const matrix<T,numCols,numRows>& m1)
  {
    matrix<T,numCols,numRows> sum = m0;
    sum += m1;
    return sum;
  }
  // threadgroup m0 + thread m1
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator+ (threadgroup const matrix<T,numCols,numRows>& m0, thread const matrix<T,numCols,numRows>& m1)
  {
    matrix<T,numCols,numRows> sum = m0;
    sum += m1;
    return sum;
  }
  // threadgroup m0 + device m1
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator+ (threadgroup const matrix<T,numCols,numRows>& m0, device const matrix<T,numCols,numRows>& m1)
  {
    matrix<T,numCols,numRows> sum = m0;
    sum += m1;
    return sum;
  }
  // threadgroup m0 + threadgroup m1
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator+ (threadgroup const matrix<T,numCols,numRows>& m0, threadgroup const matrix<T,numCols,numRows>& m1)
  {
    matrix<T,numCols,numRows> sum = m0;
    sum += m1;
    return sum;
  }
  // threadgroup m0 + constant m1
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator+ (threadgroup const matrix<T,numCols,numRows>& m0, constant const matrix<T,numCols,numRows>& m1)
  {
    matrix<T,numCols,numRows> sum = m0;
    sum += m1;
    return sum;
  }
  // constant m0 + thread m1
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator+ (constant const matrix<T,numCols,numRows>& m0, thread const matrix<T,numCols,numRows>& m1)
  {
    matrix<T,numCols,numRows> sum = m0;
    sum += m1;
    return sum;
  }
  // constant m0 + device m1
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator+ (constant const matrix<T,numCols,numRows>& m0, device const matrix<T,numCols,numRows>& m1)
  {
    matrix<T,numCols,numRows> sum = m0;
    sum += m1;
    return sum;
  }
  // constant m0 + threadgroup m1
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator+ (constant const matrix<T,numCols,numRows>& m0, threadgroup const matrix<T,numCols,numRows>& m1)
  {
    matrix<T,numCols,numRows> sum = m0;
    sum += m1;
    return sum;
  }
  // constant m0 + constant m1
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator+ (constant const matrix<T,numCols,numRows>& m0, constant const matrix<T,numCols,numRows>& m1)
  {
    matrix<T,numCols,numRows> sum = m0;
    sum += m1;
    return sum;
  }

#pragma mark operator-
  // Element-wise matrix subtraction (m0 - m1). Copies m0 into a local matrix,
  // then subtracts m1 from it with the member operator-= (declared in the
  // matrix class, not shown here). One overload for every pairing of address
  // spaces.

  // thread m0 - thread m1
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator- (thread const matrix<T,numCols,numRows>& m0, thread const matrix<T,numCols,numRows>& m1)
  {
    matrix<T,numCols,numRows> diff = m0;
    diff -= m1;
    return diff;
  }
  // thread m0 - device m1
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator- (thread const matrix<T,numCols,numRows>& m0, device const matrix<T,numCols,numRows>& m1)
  {
    matrix<T,numCols,numRows> diff = m0;
    diff -= m1;
    return diff;
  }
  // thread m0 - threadgroup m1
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator- (thread const matrix<T,numCols,numRows>& m0, threadgroup const matrix<T,numCols,numRows>& m1)
  {
    matrix<T,numCols,numRows> diff = m0;
    diff -= m1;
    return diff;
  }
  // thread m0 - constant m1
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator- (thread const matrix<T,numCols,numRows>& m0, constant const matrix<T,numCols,numRows>& m1)
  {
    matrix<T,numCols,numRows> diff = m0;
    diff -= m1;
    return diff;
  }
  // device m0 - thread m1
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator- (device const matrix<T,numCols,numRows>& m0, thread const matrix<T,numCols,numRows>& m1)
  {
    matrix<T,numCols,numRows> diff = m0;
    diff -= m1;
    return diff;
  }
  // device m0 - device m1
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator- (device const matrix<T,numCols,numRows>& m0, device const matrix<T,numCols,numRows>& m1)
  {
    matrix<T,numCols,numRows> diff = m0;
    diff -= m1;
    return diff;
  }
  // device m0 - threadgroup m1
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator- (device const matrix<T,numCols,numRows>& m0, threadgroup const matrix<T,numCols,numRows>& m1)
  {
    matrix<T,numCols,numRows> diff = m0;
    diff -= m1;
    return diff;
  }
  // device m0 - constant m1
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator- (device const matrix<T,numCols,numRows>& m0, constant const matrix<T,numCols,numRows>& m1)
  {
    matrix<T,numCols,numRows> diff = m0;
    diff -= m1;
    return diff;
  }
  // threadgroup m0 - thread m1
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator- (threadgroup const matrix<T,numCols,numRows>& m0, thread const matrix<T,numCols,numRows>& m1)
  {
    matrix<T,numCols,numRows> diff = m0;
    diff -= m1;
    return diff;
  }
  // threadgroup m0 - device m1
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator- (threadgroup const matrix<T,numCols,numRows>& m0, device const matrix<T,numCols,numRows>& m1)
  {
    matrix<T,numCols,numRows> diff = m0;
    diff -= m1;
    return diff;
  }
  // threadgroup m0 - threadgroup m1
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator- (threadgroup const matrix<T,numCols,numRows>& m0, threadgroup const matrix<T,numCols,numRows>& m1)
  {
    matrix<T,numCols,numRows> diff = m0;
    diff -= m1;
    return diff;
  }
  // threadgroup m0 - constant m1
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator- (threadgroup const matrix<T,numCols,numRows>& m0, constant const matrix<T,numCols,numRows>& m1)
  {
    matrix<T,numCols,numRows> diff = m0;
    diff -= m1;
    return diff;
  }
  // constant m0 - thread m1
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator- (constant const matrix<T,numCols,numRows>& m0, thread const matrix<T,numCols,numRows>& m1)
  {
    matrix<T,numCols,numRows> diff = m0;
    diff -= m1;
    return diff;
  }
  // constant m0 - device m1
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator- (constant const matrix<T,numCols,numRows>& m0, device const matrix<T,numCols,numRows>& m1)
  {
    matrix<T,numCols,numRows> diff = m0;
    diff -= m1;
    return diff;
  }
  // constant m0 - threadgroup m1
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator- (constant const matrix<T,numCols,numRows>& m0, threadgroup const matrix<T,numCols,numRows>& m1)
  {
    matrix<T,numCols,numRows> diff = m0;
    diff -= m1;
    return diff;
  }
  // constant m0 - constant m1
  template <typename T, int numCols, int numRows>
  METAL_FUNC matrix<T,numCols,numRows> operator- (constant const matrix<T,numCols,numRows>& m0, constant const matrix<T,numCols,numRows>& m1)
  {
    matrix<T,numCols,numRows> diff = m0;
    diff -= m1;
    return diff;
  }

  // 5.7 Matrix Functions (non-member)

#pragma mark transpose
  // transpose(m): for a matrix with numCols columns of numRows entries, returns
  // the matrix with numRows columns of numCols entries such that
  // result[row][col] == m[col][row]. Every element is written exactly once.
  // One overload per address space of the argument.

  // transpose of a thread matrix
  template <typename T, int numCols, int numRows = numCols>
  METAL_FUNC matrix<T,numRows,numCols> transpose(thread  const matrix<T,numCols,numRows>& m) {
    matrix<T,numRows,numCols> t;
    for (int row = 0; row < numRows; ++row)
      for (int col = 0; col < numCols; ++col)
        t[row][col] = m[col][row];
    return t;
  }
  // transpose of a device matrix
  template <typename T, int numCols, int numRows = numCols>
  METAL_FUNC matrix<T,numRows,numCols> transpose(device  const matrix<T,numCols,numRows>& m) {
    matrix<T,numRows,numCols> t;
    for (int row = 0; row < numRows; ++row)
      for (int col = 0; col < numCols; ++col)
        t[row][col] = m[col][row];
    return t;
  }
  // transpose of a threadgroup matrix
  template <typename T, int numCols, int numRows = numCols>
  METAL_FUNC matrix<T,numRows,numCols> transpose(threadgroup  const matrix<T,numCols,numRows>& m) {
    matrix<T,numRows,numCols> t;
    for (int row = 0; row < numRows; ++row)
      for (int col = 0; col < numCols; ++col)
        t[row][col] = m[col][row];
    return t;
  }
  // transpose of a constant matrix
  template <typename T, int numCols, int numRows = numCols>
  METAL_FUNC matrix<T,numRows,numCols> transpose(constant  const matrix<T,numCols,numRows>& m) {
    matrix<T,numRows,numCols> t;
    for (int row = 0; row < numRows; ++row)
      for (int col = 0; col < numCols; ++col)
        t[row][col] = m[col][row];
    return t;
  }
  
#pragma mark determinant
  // determinant of a 2x2 matrix: m[0][0]*m[1][1] - m[0][1]*m[1][0].
  // One overload per address space of the argument.

  // 2x2 determinant, thread matrix
  template<typename T>
  METAL_FUNC T determinant(thread  const matrix<T,2,2>& m) {
    const T m00 = m[0][0];
    const T m01 = m[0][1];
    const T m10 = m[1][0];
    const T m11 = m[1][1];
    return m00 * m11 - m01 * m10;
  }
  // 2x2 determinant, device matrix
  template<typename T>
  METAL_FUNC T determinant(device  const matrix<T,2,2>& m) {
    const T m00 = m[0][0];
    const T m01 = m[0][1];
    const T m10 = m[1][0];
    const T m11 = m[1][1];
    return m00 * m11 - m01 * m10;
  }
  // 2x2 determinant, threadgroup matrix
  template<typename T>
  METAL_FUNC T determinant(threadgroup  const matrix<T,2,2>& m) {
    const T m00 = m[0][0];
    const T m01 = m[0][1];
    const T m10 = m[1][0];
    const T m11 = m[1][1];
    return m00 * m11 - m01 * m10;
  }
  // 2x2 determinant, constant matrix
  template<typename T>
  METAL_FUNC T determinant(constant  const matrix<T,2,2>& m) {
    const T m00 = m[0][0];
    const T m01 = m[0][1];
    const T m10 = m[1][0];
    const T m11 = m[1][1];
    return m00 * m11 - m01 * m10;
  }

  // determinant of a 3x3 matrix by cofactor expansion along m[0]:
  //   m[0][0]*c0 - m[0][1]*c1 + m[0][2]*c2
  // where each cK is the 2x2 determinant built from m[1] and m[2] with index K
  // removed. One overload per address space of the argument.

  // 3x3 determinant, thread matrix
  template<typename T>
  METAL_FUNC T determinant(thread  const matrix<T,3,3>& m) {
    const T c0 = m[1][1]*m[2][2] - m[1][2]*m[2][1];
    const T c1 = m[1][0]*m[2][2] - m[1][2]*m[2][0];
    const T c2 = m[1][0]*m[2][1] - m[1][1]*m[2][0];
    return m[0][0]*c0 - m[0][1]*c1 + m[0][2]*c2;
  }
  // 3x3 determinant, device matrix
  template<typename T>
  METAL_FUNC T determinant(device  const matrix<T,3,3>& m) {
    const T c0 = m[1][1]*m[2][2] - m[1][2]*m[2][1];
    const T c1 = m[1][0]*m[2][2] - m[1][2]*m[2][0];
    const T c2 = m[1][0]*m[2][1] - m[1][1]*m[2][0];
    return m[0][0]*c0 - m[0][1]*c1 + m[0][2]*c2;
  }
  // 3x3 determinant, threadgroup matrix
  template<typename T>
  METAL_FUNC T determinant(threadgroup  const matrix<T,3,3>& m) {
    const T c0 = m[1][1]*m[2][2] - m[1][2]*m[2][1];
    const T c1 = m[1][0]*m[2][2] - m[1][2]*m[2][0];
    const T c2 = m[1][0]*m[2][1] - m[1][1]*m[2][0];
    return m[0][0]*c0 - m[0][1]*c1 + m[0][2]*c2;
  }
  // 3x3 determinant, constant matrix
  template<typename T>
  METAL_FUNC T determinant(constant  const matrix<T,3,3>& m) {
    const T c0 = m[1][1]*m[2][2] - m[1][2]*m[2][1];
    const T c1 = m[1][0]*m[2][2] - m[1][2]*m[2][0];
    const T c2 = m[1][0]*m[2][1] - m[1][1]*m[2][0];
    return m[0][0]*c0 - m[0][1]*c1 + m[0][2]*c2;
  }
  
  // determinant of a 4x4 matrix by cofactor expansion along m[0], with each
  // 3x3 cofactor expanded along m[1]. The six 2x2 minors of m[2] and m[3]
  // appear twice each in that expansion, so they are computed once up front:
  //   sAB = m[2][A]*m[3][B] - m[2][B]*m[3][A]
  // Each minor is the same sub-expression the fully inlined formula uses.
  // One overload per address space of the argument.

  // 4x4 determinant, thread matrix
  template<typename T>
  METAL_FUNC T determinant(thread  const matrix<T,4,4>& m) {
    const T s23 = m[2][2]*m[3][3] - m[2][3]*m[3][2];
    const T s13 = m[2][1]*m[3][3] - m[2][3]*m[3][1];
    const T s12 = m[2][1]*m[3][2] - m[2][2]*m[3][1];
    const T s03 = m[2][0]*m[3][3] - m[2][3]*m[3][0];
    const T s02 = m[2][0]*m[3][2] - m[2][2]*m[3][0];
    const T s01 = m[2][0]*m[3][1] - m[2][1]*m[3][0];
    return m[0][0]*(m[1][1]*s23 - m[1][2]*s13 + m[1][3]*s12)
         - m[0][1]*(m[1][0]*s23 - m[1][2]*s03 + m[1][3]*s02)
         + m[0][2]*(m[1][0]*s13 - m[1][1]*s03 + m[1][3]*s01)
         - m[0][3]*(m[1][0]*s12 - m[1][1]*s02 + m[1][2]*s01);
  }
  // 4x4 determinant, device matrix
  template<typename T>
  METAL_FUNC T determinant(device  const matrix<T,4,4>& m) {
    const T s23 = m[2][2]*m[3][3] - m[2][3]*m[3][2];
    const T s13 = m[2][1]*m[3][3] - m[2][3]*m[3][1];
    const T s12 = m[2][1]*m[3][2] - m[2][2]*m[3][1];
    const T s03 = m[2][0]*m[3][3] - m[2][3]*m[3][0];
    const T s02 = m[2][0]*m[3][2] - m[2][2]*m[3][0];
    const T s01 = m[2][0]*m[3][1] - m[2][1]*m[3][0];
    return m[0][0]*(m[1][1]*s23 - m[1][2]*s13 + m[1][3]*s12)
         - m[0][1]*(m[1][0]*s23 - m[1][2]*s03 + m[1][3]*s02)
         + m[0][2]*(m[1][0]*s13 - m[1][1]*s03 + m[1][3]*s01)
         - m[0][3]*(m[1][0]*s12 - m[1][1]*s02 + m[1][2]*s01);
  }
  // 4x4 determinant, threadgroup matrix
  template<typename T>
  METAL_FUNC T determinant(threadgroup  const matrix<T,4,4>& m) {
    const T s23 = m[2][2]*m[3][3] - m[2][3]*m[3][2];
    const T s13 = m[2][1]*m[3][3] - m[2][3]*m[3][1];
    const T s12 = m[2][1]*m[3][2] - m[2][2]*m[3][1];
    const T s03 = m[2][0]*m[3][3] - m[2][3]*m[3][0];
    const T s02 = m[2][0]*m[3][2] - m[2][2]*m[3][0];
    const T s01 = m[2][0]*m[3][1] - m[2][1]*m[3][0];
    return m[0][0]*(m[1][1]*s23 - m[1][2]*s13 + m[1][3]*s12)
         - m[0][1]*(m[1][0]*s23 - m[1][2]*s03 + m[1][3]*s02)
         + m[0][2]*(m[1][0]*s13 - m[1][1]*s03 + m[1][3]*s01)
         - m[0][3]*(m[1][0]*s12 - m[1][1]*s02 + m[1][2]*s01);
  }
  // 4x4 determinant, constant matrix
  template<typename T>
  METAL_FUNC T determinant(constant  const matrix<T,4,4>& m) {
    const T s23 = m[2][2]*m[3][3] - m[2][3]*m[3][2];
    const T s13 = m[2][1]*m[3][3] - m[2][3]*m[3][1];
    const T s12 = m[2][1]*m[3][2] - m[2][2]*m[3][1];
    const T s03 = m[2][0]*m[3][3] - m[2][3]*m[3][0];
    const T s02 = m[2][0]*m[3][2] - m[2][2]*m[3][0];
    const T s01 = m[2][0]*m[3][1] - m[2][1]*m[3][0];
    return m[0][0]*(m[1][1]*s23 - m[1][2]*s13 + m[1][3]*s12)
         - m[0][1]*(m[1][0]*s23 - m[1][2]*s03 + m[1][3]*s02)
         + m[0][2]*(m[1][0]*s13 - m[1][1]*s03 + m[1][3]*s01)
         - m[0][3]*(m[1][0]*s12 - m[1][1]*s02 + m[1][2]*s01);
  }


  // Named matrix types. TCxR is matrix<T,C,R>: C columns (first number),
  // each holding R rows (second number); e.g. float2x3 == matrix<float,2,3>.

  // half-precision element type
  using half2x2  = matrix<half,2,2>;
  using half2x3  = matrix<half,2,3>;
  using half2x4  = matrix<half,2,4>;
  using half3x2  = matrix<half,3,2>;
  using half3x3  = matrix<half,3,3>;
  using half3x4  = matrix<half,3,4>;
  using half4x2  = matrix<half,4,2>;
  using half4x3  = matrix<half,4,3>;
  using half4x4  = matrix<half,4,4>;

  // single-precision element type
  using float2x2 = matrix<float,2,2>;
  using float2x3 = matrix<float,2,3>;
  using float2x4 = matrix<float,2,4>;
  using float3x2 = matrix<float,3,2>;
  using float3x3 = matrix<float,3,3>;
  using float3x4 = matrix<float,3,4>;
  using float4x2 = matrix<float,4,2>;
  using float4x3 = matrix<float,4,3>;
  using float4x4 = matrix<float,4,4>;

};  // namespace metal

#endif // __METAL_MATRIX_H
