//===-- metal_atomic -------------------------------------------------------===//
// Copyright (c) 2014 Apple Inc. All rights reserved
//===----------------------------------------------------------------------===//

#ifndef __METAL_ATOMIC
#define __METAL_ATOMIC

#include <metal_compute>

// Numeric mem_order value for relaxed ordering. It is the value assigned to
// metal::memory_order_relaxed and is what the air.atomic.* intrinsics declared
// in this header receive as their mem_order argument.
#define _AIR_MEM_ORDER_RELAXED 0x0

namespace metal {
  // 2.3 Atomic Data Types
  // Signed atomic integer. The payload is private, so it is only reachable
  // through the atomic_* functions in this header; those functions cast a
  // pointer to this struct to a raw int/uint pointer, which relies on the
  // struct holding exactly this one member.
  struct atomic_int {
    private:
    // Never named directly in this header; `unused` presumably suppresses the
    // unused-private-field diagnostic.
    __attribute__((unused)) int t;
  };

  // Unsigned atomic integer. The payload is private, so it is only reachable
  // through the atomic_* functions in this header; those functions cast a
  // pointer to this struct to a raw uint pointer, which relies on the struct
  // holding exactly this one member.
  struct atomic_uint {
    private:
    // Never named directly in this header; `unused` presumably suppresses the
    // unused-private-field diagnostic.
    __attribute__((unused)) uint t;
  };

  // 5.12 Atomic Functions
  // Memory orderings accepted by the *_explicit atomic functions. Only relaxed
  // ordering is defined; its value equals _AIR_MEM_ORDER_RELAXED, the constant
  // the AIR intrinsics take as their mem_order argument.
  enum memory_order {
    memory_order_relaxed = _AIR_MEM_ORDER_RELAXED
  };

  // 5.12.1 Atomic Store Functions
  // AIR store intrinsics. The __asm labels bind each declaration to the
  // corresponding air.atomic.*.store.i32 symbol: "global" takes a device
  // pointer, "local" takes a threadgroup pointer.
  METAL_INTERNAL void _air_atomic_global_store_i32(device uint *object, uint desired, uint mem_order, uint mem_scope) __asm("air.atomic.global.store.i32");
  METAL_INTERNAL void _air_atomic_local_store_i32(threadgroup uint *object, uint desired, uint mem_order, uint mem_scope) __asm("air.atomic.local.store.i32");

  // atomic_store_explicit: forwards `desired` and `order` to the AIR store
  // intrinsic for the address space of `obj`.
  //   obj     - atomic object; reinterpreted as a pointer to its 32-bit word.
  //   desired - value to store; signed values are passed as their uint bits.
  //   order   - memory ordering, forwarded as the intrinsic's mem_order.
  // Device objects use device scope; threadgroup objects use work-group scope.
  METAL_FUNC void atomic_store_explicit(volatile device atomic_int* obj, int desired, memory_order order) {
    device uint *word = (device uint *)obj;
    _air_atomic_global_store_i32(word, (uint)desired, order, _AIR_MEM_SCOPE_DEVICE);
  }
  METAL_FUNC void atomic_store_explicit(volatile device atomic_uint* obj, uint desired, memory_order order) {
    device uint *word = (device uint *)obj;
    _air_atomic_global_store_i32(word, (uint)desired, order, _AIR_MEM_SCOPE_DEVICE);
  }
  METAL_FUNC void atomic_store_explicit(volatile threadgroup atomic_int* obj, int desired, memory_order order) {
    threadgroup uint *word = (threadgroup uint *)obj;
    _air_atomic_local_store_i32(word, (uint)desired, order, _AIR_MEM_SCOPE_WORK_GROUP);
  }
  METAL_FUNC void atomic_store_explicit(volatile threadgroup atomic_uint* obj, uint desired, memory_order order) {
    threadgroup uint *word = (threadgroup uint *)obj;
    _air_atomic_local_store_i32(word, (uint)desired, order, _AIR_MEM_SCOPE_WORK_GROUP);
  }

  // 5.12.2 Atomic Load Functions
  // AIR load intrinsics. The __asm labels bind each declaration to the
  // corresponding air.atomic.*.load.i32 symbol: "global" takes a device
  // pointer, "local" takes a threadgroup pointer.
  METAL_INTERNAL uint _air_atomic_global_load_i32(device uint *object, uint mem_order, uint mem_scope) __asm("air.atomic.global.load.i32");
  METAL_INTERNAL uint _air_atomic_local_load_i32(threadgroup uint *object, uint mem_order, uint mem_scope) __asm("air.atomic.local.load.i32");

  // atomic_load_explicit: returns the result of the AIR load intrinsic for the
  // address space of `obj`.
  //   obj   - atomic object; reinterpreted as a pointer to its 32-bit word.
  //   order - memory ordering, forwarded as the intrinsic's mem_order.
  // The intrinsic yields a uint; the atomic_int overloads convert it to int.
  // Device objects use device scope; threadgroup objects use work-group scope.
  METAL_FUNC int atomic_load_explicit(volatile device atomic_int* obj, memory_order order) {
    uint bits = _air_atomic_global_load_i32((device uint *)obj, order, _AIR_MEM_SCOPE_DEVICE);
    return (int)bits;
  }
  METAL_FUNC uint atomic_load_explicit(volatile device atomic_uint* obj, memory_order order) {
    device uint *word = (device uint *)obj;
    return _air_atomic_global_load_i32(word, order, _AIR_MEM_SCOPE_DEVICE);
  }
  METAL_FUNC int atomic_load_explicit(volatile threadgroup atomic_int* obj, memory_order order) {
    uint bits = _air_atomic_local_load_i32((threadgroup uint *)obj, order, _AIR_MEM_SCOPE_WORK_GROUP);
    return (int)bits;
  }
  METAL_FUNC uint atomic_load_explicit(volatile threadgroup atomic_uint* obj, memory_order order) {
    threadgroup uint *word = (threadgroup uint *)obj;
    return _air_atomic_local_load_i32(word, order, _AIR_MEM_SCOPE_WORK_GROUP);
  }

  // 5.12.3 Atomic Exchange Functions
  // AIR exchange intrinsics. The __asm labels bind each declaration to the
  // corresponding air.atomic.*.xchg.i32 symbol: "global" takes a device
  // pointer, "local" takes a threadgroup pointer.
  METAL_INTERNAL uint _air_atomic_global_xchg_i32(device uint *object, uint desired, uint mem_order, uint mem_scope) __asm("air.atomic.global.xchg.i32");
  METAL_INTERNAL uint _air_atomic_local_xchg_i32(threadgroup uint *object, uint desired, uint mem_order, uint mem_scope) __asm("air.atomic.local.xchg.i32");

  // atomic_exchange_explicit: forwards `desired` to the AIR exchange intrinsic
  // for the address space of `obj` and returns the intrinsic's result
  // (presumably the value held before the exchange - confirm against the
  // Metal specification).
  //   obj     - atomic object; reinterpreted as a pointer to its 32-bit word.
  //   desired - replacement value; signed values are passed as their uint bits.
  //   order   - memory ordering, forwarded as the intrinsic's mem_order. It is
  //             passed through rather than hard-coded so this function honours
  //             its parameter like the store/load/fetch functions do; the only
  //             enumerator, memory_order_relaxed, equals _AIR_MEM_ORDER_RELAXED.
  // Device objects use device scope; threadgroup objects use work-group scope.
  METAL_FUNC int atomic_exchange_explicit(volatile device atomic_int *obj, int desired, memory_order order) {
    return _air_atomic_global_xchg_i32((device uint *)obj, (uint)desired, order, _AIR_MEM_SCOPE_DEVICE);
  }
  METAL_FUNC uint atomic_exchange_explicit(volatile device atomic_uint *obj, uint desired, memory_order order) {
    return _air_atomic_global_xchg_i32((device uint *)obj, (uint)desired, order, _AIR_MEM_SCOPE_DEVICE);
  }
  METAL_FUNC int atomic_exchange_explicit(volatile threadgroup atomic_int *obj, int desired, memory_order order) {
    return _air_atomic_local_xchg_i32((threadgroup uint *)obj, (uint)desired, order, _AIR_MEM_SCOPE_WORK_GROUP);
  }
  METAL_FUNC uint atomic_exchange_explicit(volatile threadgroup atomic_uint *obj, uint desired, memory_order order) {
    return _air_atomic_local_xchg_i32((threadgroup uint *)obj, (uint)desired, order, _AIR_MEM_SCOPE_WORK_GROUP);
  }

  // 5.12.4 Atomic Compare and Exchange functions
  // AIR weak compare-exchange intrinsics. The __asm labels bind each
  // declaration to the corresponding air.atomic.*.cmpxchg.weak.i32 symbol:
  // "global" takes a device pointer, "local" takes a threadgroup pointer.
  METAL_INTERNAL uint _air_atomic_global_cmpxchg_weak_i32(device uint *object, thread uint *expected, uint desired, uint mem_order_success, uint mem_order_failure, uint mem_scope) __asm("air.atomic.global.cmpxchg.weak.i32");
  METAL_INTERNAL uint _air_atomic_local_cmpxchg_weak_i32(threadgroup uint *object, thread uint *expected, uint desired, uint mem_order_success, uint mem_order_failure, uint mem_scope) __asm("air.atomic.local.cmpxchg.weak.i32");

  // atomic_compare_exchange_weak_explicit: forwards to the AIR weak
  // compare-exchange intrinsic for the address space of `obj` and returns its
  // uint result converted to bool.
  //   obj      - atomic object; reinterpreted as a pointer to its 32-bit word.
  //   expected - thread-local comparison value, passed by pointer (presumably
  //              updated with the observed value on failure - confirm against
  //              the Metal specification).
  //   desired  - replacement value; signed values are passed as their uint bits.
  //   succ     - ordering forwarded as mem_order_success.
  //   fail     - ordering forwarded as mem_order_failure.
  // Device objects use device scope. Threadgroup objects use work-group scope,
  // matching the threadgroup store/load/exchange overloads in this header.
  METAL_FUNC bool atomic_compare_exchange_weak_explicit(volatile device atomic_int *obj, thread int *expected, int desired, memory_order succ, memory_order fail) {
    return _air_atomic_global_cmpxchg_weak_i32((device uint *)obj, (thread uint *)expected, (uint)desired, succ, fail, _AIR_MEM_SCOPE_DEVICE);
  }
  METAL_FUNC bool atomic_compare_exchange_weak_explicit(volatile device atomic_uint *obj, thread uint *expected, uint desired, memory_order succ, memory_order fail) {
    return _air_atomic_global_cmpxchg_weak_i32((device uint *)obj, (thread uint *)expected, (uint)desired, succ, fail, _AIR_MEM_SCOPE_DEVICE);
  }
  METAL_FUNC bool atomic_compare_exchange_weak_explicit(volatile threadgroup atomic_int *obj, thread int *expected, int desired, memory_order succ, memory_order fail) {
    return _air_atomic_local_cmpxchg_weak_i32((threadgroup uint *)obj, (thread uint *)expected, (uint)desired, succ, fail, _AIR_MEM_SCOPE_WORK_GROUP);
  }
  METAL_FUNC bool atomic_compare_exchange_weak_explicit(volatile threadgroup atomic_uint *obj, thread uint *expected, uint desired, memory_order succ, memory_order fail) {
    return _air_atomic_local_cmpxchg_weak_i32((threadgroup uint *)obj, (thread uint *)expected, (uint)desired, succ, fail, _AIR_MEM_SCOPE_WORK_GROUP);
  }

  // 5.12.5 Atomic Fetch and Modify functions
  // AIR add intrinsics, bound by __asm label to air.atomic.*.add.{s,u}.i32.
  // "global" takes a device pointer, "local" a threadgroup pointer; the s/u
  // variants take signed/unsigned operands.
  METAL_INTERNAL int _air_atomic_global_add_s_i32(device int *object, int operand, uint mem_order, uint mem_scope) __asm("air.atomic.global.add.s.i32");
  METAL_INTERNAL uint _air_atomic_global_add_u_i32(device uint *object, uint operand, uint mem_order, uint mem_scope) __asm("air.atomic.global.add.u.i32");
  METAL_INTERNAL int _air_atomic_local_add_s_i32(threadgroup int *object, int operand, uint mem_order, uint mem_scope) __asm("air.atomic.local.add.s.i32");
  METAL_INTERNAL uint _air_atomic_local_add_u_i32(threadgroup uint *object, uint operand, uint mem_order, uint mem_scope) __asm("air.atomic.local.add.u.i32");

  // atomic_fetch_add_explicit: forwards `arg` and `order` to the AIR add
  // intrinsic matching the signedness and address space of `obj`, and returns
  // the intrinsic's result (presumably the value held before the addition -
  // confirm against the Metal specification).
  // Device objects use device scope. Threadgroup objects use work-group scope,
  // matching the threadgroup store/load/exchange overloads in this header.
  METAL_FUNC int atomic_fetch_add_explicit(volatile device atomic_int *obj, int arg, memory_order order) {
    return _air_atomic_global_add_s_i32((device int *)obj, arg, order, _AIR_MEM_SCOPE_DEVICE);
  }
  METAL_FUNC uint atomic_fetch_add_explicit(volatile device atomic_uint *obj, uint arg, memory_order order) {
    return _air_atomic_global_add_u_i32((device uint *)obj, arg, order, _AIR_MEM_SCOPE_DEVICE);
  }
  METAL_FUNC int atomic_fetch_add_explicit(volatile threadgroup atomic_int *obj, int arg, memory_order order) {
    return _air_atomic_local_add_s_i32((threadgroup int *)obj, arg, order, _AIR_MEM_SCOPE_WORK_GROUP);
  }
  METAL_FUNC uint atomic_fetch_add_explicit(volatile threadgroup atomic_uint *obj, uint arg, memory_order order) {
    return _air_atomic_local_add_u_i32((threadgroup uint *)obj, arg, order, _AIR_MEM_SCOPE_WORK_GROUP);
  }
  // AIR subtract intrinsics, bound by __asm label to
  // air.atomic.*.sub.{s,u}.i32. "global" takes a device pointer, "local" a
  // threadgroup pointer; the s/u variants take signed/unsigned operands.
  METAL_INTERNAL int _air_atomic_global_sub_s_i32(device int *object, int operand, uint mem_order, uint mem_scope) __asm("air.atomic.global.sub.s.i32");
  METAL_INTERNAL uint _air_atomic_global_sub_u_i32(device uint *object, uint operand, uint mem_order, uint mem_scope) __asm("air.atomic.global.sub.u.i32");
  METAL_INTERNAL int _air_atomic_local_sub_s_i32(threadgroup int *object, int operand, uint mem_order, uint mem_scope) __asm("air.atomic.local.sub.s.i32");
  METAL_INTERNAL uint _air_atomic_local_sub_u_i32(threadgroup uint *object, uint operand, uint mem_order, uint mem_scope) __asm("air.atomic.local.sub.u.i32");

  // atomic_fetch_sub_explicit: forwards `arg` and `order` to the AIR subtract
  // intrinsic matching the signedness and address space of `obj`, and returns
  // the intrinsic's result (presumably the value held before the subtraction -
  // confirm against the Metal specification).
  // Device objects use device scope. Threadgroup objects use work-group scope,
  // matching the threadgroup store/load/exchange overloads in this header.
  METAL_FUNC int atomic_fetch_sub_explicit(volatile device atomic_int *obj, int arg, memory_order order) {
    return _air_atomic_global_sub_s_i32((device int *)obj, arg, order, _AIR_MEM_SCOPE_DEVICE);
  }
  METAL_FUNC uint atomic_fetch_sub_explicit(volatile device atomic_uint *obj, uint arg, memory_order order) {
    return _air_atomic_global_sub_u_i32((device uint *)obj, arg, order, _AIR_MEM_SCOPE_DEVICE);
  }
  METAL_FUNC int atomic_fetch_sub_explicit(volatile threadgroup atomic_int *obj, int arg, memory_order order) {
    return _air_atomic_local_sub_s_i32((threadgroup int *)obj, arg, order, _AIR_MEM_SCOPE_WORK_GROUP);
  }
  METAL_FUNC uint atomic_fetch_sub_explicit(volatile threadgroup atomic_uint *obj, uint arg, memory_order order) {
    return _air_atomic_local_sub_u_i32((threadgroup uint *)obj, arg, order, _AIR_MEM_SCOPE_WORK_GROUP);
  }
  // AIR bitwise-or intrinsics, bound by __asm label to
  // air.atomic.*.or.{s,u}.i32. "global" takes a device pointer, "local" a
  // threadgroup pointer; the s/u variants take signed/unsigned operands.
  METAL_INTERNAL int _air_atomic_global_or_s_i32(device int *object, int operand, uint mem_order, uint mem_scope) __asm("air.atomic.global.or.s.i32");
  METAL_INTERNAL uint _air_atomic_global_or_u_i32(device uint *object, uint operand, uint mem_order, uint mem_scope) __asm("air.atomic.global.or.u.i32");
  METAL_INTERNAL int _air_atomic_local_or_s_i32(threadgroup int *object, int operand, uint mem_order, uint mem_scope) __asm("air.atomic.local.or.s.i32");
  METAL_INTERNAL uint _air_atomic_local_or_u_i32(threadgroup uint *object, uint operand, uint mem_order, uint mem_scope) __asm("air.atomic.local.or.u.i32");

  // atomic_fetch_or_explicit: forwards `arg` and `order` to the AIR bitwise-or
  // intrinsic matching the signedness and address space of `obj`, and returns
  // the intrinsic's result (presumably the value held before the operation -
  // confirm against the Metal specification).
  // Device objects use device scope. Threadgroup objects use work-group scope,
  // matching the threadgroup store/load/exchange overloads in this header.
  METAL_FUNC int atomic_fetch_or_explicit(volatile device atomic_int *obj, int arg, memory_order order) {
    return _air_atomic_global_or_s_i32((device int *)obj, arg, order, _AIR_MEM_SCOPE_DEVICE);
  }
  METAL_FUNC uint atomic_fetch_or_explicit(volatile device atomic_uint *obj, uint arg, memory_order order) {
    return _air_atomic_global_or_u_i32((device uint *)obj, arg, order, _AIR_MEM_SCOPE_DEVICE);
  }
  METAL_FUNC int atomic_fetch_or_explicit(volatile threadgroup atomic_int *obj, int arg, memory_order order) {
    return _air_atomic_local_or_s_i32((threadgroup int *)obj, arg, order, _AIR_MEM_SCOPE_WORK_GROUP);
  }
  METAL_FUNC uint atomic_fetch_or_explicit(volatile threadgroup atomic_uint *obj, uint arg, memory_order order) {
    return _air_atomic_local_or_u_i32((threadgroup uint *)obj, arg, order, _AIR_MEM_SCOPE_WORK_GROUP);
  }
  // AIR bitwise-xor intrinsics, bound by __asm label to
  // air.atomic.*.xor.{s,u}.i32. "global" takes a device pointer, "local" a
  // threadgroup pointer; the s/u variants take signed/unsigned operands.
  METAL_INTERNAL int _air_atomic_global_xor_s_i32(device int *object, int operand, uint mem_order, uint mem_scope) __asm("air.atomic.global.xor.s.i32");
  METAL_INTERNAL uint _air_atomic_global_xor_u_i32(device uint *object, uint operand, uint mem_order, uint mem_scope) __asm("air.atomic.global.xor.u.i32");
  METAL_INTERNAL int _air_atomic_local_xor_s_i32(threadgroup int *object, int operand, uint mem_order, uint mem_scope) __asm("air.atomic.local.xor.s.i32");
  METAL_INTERNAL uint _air_atomic_local_xor_u_i32(threadgroup uint *object, uint operand, uint mem_order, uint mem_scope) __asm("air.atomic.local.xor.u.i32");

  // atomic_fetch_xor_explicit: forwards `arg` and `order` to the AIR
  // bitwise-xor intrinsic matching the signedness and address space of `obj`,
  // and returns the intrinsic's result (presumably the value held before the
  // operation - confirm against the Metal specification).
  // Device objects use device scope. Threadgroup objects use work-group scope,
  // matching the threadgroup store/load/exchange overloads in this header.
  METAL_FUNC int atomic_fetch_xor_explicit(volatile device atomic_int *obj, int arg, memory_order order) {
    return _air_atomic_global_xor_s_i32((device int *)obj, arg, order, _AIR_MEM_SCOPE_DEVICE);
  }
  METAL_FUNC uint atomic_fetch_xor_explicit(volatile device atomic_uint *obj, uint arg, memory_order order) {
    return _air_atomic_global_xor_u_i32((device uint *)obj, arg, order, _AIR_MEM_SCOPE_DEVICE);
  }
  METAL_FUNC int atomic_fetch_xor_explicit(volatile threadgroup atomic_int *obj, int arg, memory_order order) {
    return _air_atomic_local_xor_s_i32((threadgroup int *)obj, arg, order, _AIR_MEM_SCOPE_WORK_GROUP);
  }
  METAL_FUNC uint atomic_fetch_xor_explicit(volatile threadgroup atomic_uint *obj, uint arg, memory_order order) {
    return _air_atomic_local_xor_u_i32((threadgroup uint *)obj, arg, order, _AIR_MEM_SCOPE_WORK_GROUP);
  }
  // AIR bitwise-and intrinsics, bound by __asm label to
  // air.atomic.*.and.{s,u}.i32. "global" takes a device pointer, "local" a
  // threadgroup pointer; the s/u variants take signed/unsigned operands.
  METAL_INTERNAL int _air_atomic_global_and_s_i32(device int *object, int operand, uint mem_order, uint mem_scope) __asm("air.atomic.global.and.s.i32");
  METAL_INTERNAL uint _air_atomic_global_and_u_i32(device uint *object, uint operand, uint mem_order, uint mem_scope) __asm("air.atomic.global.and.u.i32");
  METAL_INTERNAL int _air_atomic_local_and_s_i32(threadgroup int *object, int operand, uint mem_order, uint mem_scope) __asm("air.atomic.local.and.s.i32");
  METAL_INTERNAL uint _air_atomic_local_and_u_i32(threadgroup uint *object, uint operand, uint mem_order, uint mem_scope) __asm("air.atomic.local.and.u.i32");

  // atomic_fetch_and_explicit: forwards `arg` and `order` to the AIR
  // bitwise-and intrinsic matching the signedness and address space of `obj`,
  // and returns the intrinsic's result (presumably the value held before the
  // operation - confirm against the Metal specification).
  // Device objects use device scope. Threadgroup objects use work-group scope,
  // matching the threadgroup store/load/exchange overloads in this header.
  METAL_FUNC int atomic_fetch_and_explicit(volatile device atomic_int *obj, int arg, memory_order order) {
    return _air_atomic_global_and_s_i32((device int *)obj, arg, order, _AIR_MEM_SCOPE_DEVICE);
  }
  METAL_FUNC uint atomic_fetch_and_explicit(volatile device atomic_uint *obj, uint arg, memory_order order) {
    return _air_atomic_global_and_u_i32((device uint *)obj, arg, order, _AIR_MEM_SCOPE_DEVICE);
  }
  METAL_FUNC int atomic_fetch_and_explicit(volatile threadgroup atomic_int *obj, int arg, memory_order order) {
    return _air_atomic_local_and_s_i32((threadgroup int *)obj, arg, order, _AIR_MEM_SCOPE_WORK_GROUP);
  }
  METAL_FUNC uint atomic_fetch_and_explicit(volatile threadgroup atomic_uint *obj, uint arg, memory_order order) {
    return _air_atomic_local_and_u_i32((threadgroup uint *)obj, arg, order, _AIR_MEM_SCOPE_WORK_GROUP);
  }
  // AIR minimum intrinsics, bound by __asm label to
  // air.atomic.*.min.{s,u}.i32. "global" takes a device pointer, "local" a
  // threadgroup pointer; the s/u variants take signed/unsigned operands.
  METAL_INTERNAL int _air_atomic_global_min_s_i32(device int *object, int operand, uint mem_order, uint mem_scope) __asm("air.atomic.global.min.s.i32");
  METAL_INTERNAL uint _air_atomic_global_min_u_i32(device uint *object, uint operand, uint mem_order, uint mem_scope) __asm("air.atomic.global.min.u.i32");
  METAL_INTERNAL int _air_atomic_local_min_s_i32(threadgroup int *object, int operand, uint mem_order, uint mem_scope) __asm("air.atomic.local.min.s.i32");
  METAL_INTERNAL uint _air_atomic_local_min_u_i32(threadgroup uint *object, uint operand, uint mem_order, uint mem_scope) __asm("air.atomic.local.min.u.i32");

  // atomic_fetch_min_explicit: forwards `arg` and `order` to the AIR minimum
  // intrinsic matching the signedness and address space of `obj`, and returns
  // the intrinsic's result (presumably the value held before the operation -
  // confirm against the Metal specification).
  // Device objects use device scope. Threadgroup objects use work-group scope,
  // matching the threadgroup store/load/exchange overloads in this header.
  METAL_FUNC int atomic_fetch_min_explicit(volatile device atomic_int *obj, int arg, memory_order order) {
    return _air_atomic_global_min_s_i32((device int *)obj, arg, order, _AIR_MEM_SCOPE_DEVICE);
  }
  METAL_FUNC uint atomic_fetch_min_explicit(volatile device atomic_uint *obj, uint arg, memory_order order) {
    return _air_atomic_global_min_u_i32((device uint *)obj, arg, order, _AIR_MEM_SCOPE_DEVICE);
  }
  METAL_FUNC int atomic_fetch_min_explicit(volatile threadgroup atomic_int *obj, int arg, memory_order order) {
    return _air_atomic_local_min_s_i32((threadgroup int *)obj, arg, order, _AIR_MEM_SCOPE_WORK_GROUP);
  }
  METAL_FUNC uint atomic_fetch_min_explicit(volatile threadgroup atomic_uint *obj, uint arg, memory_order order) {
    return _air_atomic_local_min_u_i32((threadgroup uint *)obj, arg, order, _AIR_MEM_SCOPE_WORK_GROUP);
  }
  // AIR maximum intrinsics, bound by __asm label to
  // air.atomic.*.max.{s,u}.i32. "global" takes a device pointer, "local" a
  // threadgroup pointer; the s/u variants take signed/unsigned operands.
  METAL_INTERNAL int _air_atomic_global_max_s_i32(device int *object, int operand, uint mem_order, uint mem_scope) __asm("air.atomic.global.max.s.i32");
  METAL_INTERNAL uint _air_atomic_global_max_u_i32(device uint *object, uint operand, uint mem_order, uint mem_scope) __asm("air.atomic.global.max.u.i32");
  METAL_INTERNAL int _air_atomic_local_max_s_i32(threadgroup int *object, int operand, uint mem_order, uint mem_scope) __asm("air.atomic.local.max.s.i32");
  METAL_INTERNAL uint _air_atomic_local_max_u_i32(threadgroup uint *object, uint operand, uint mem_order, uint mem_scope) __asm("air.atomic.local.max.u.i32");

  // atomic_fetch_max_explicit: forwards `arg` and `order` to the AIR maximum
  // intrinsic matching the signedness and address space of `obj`, and returns
  // the intrinsic's result (presumably the value held before the operation -
  // confirm against the Metal specification).
  // Device objects use device scope. Threadgroup objects use work-group scope,
  // matching the threadgroup store/load/exchange overloads in this header.
  METAL_FUNC int atomic_fetch_max_explicit(volatile device atomic_int *obj, int arg, memory_order order) {
    return _air_atomic_global_max_s_i32((device int *)obj, arg, order, _AIR_MEM_SCOPE_DEVICE);
  }
  METAL_FUNC uint atomic_fetch_max_explicit(volatile device atomic_uint *obj, uint arg, memory_order order) {
    return _air_atomic_global_max_u_i32((device uint *)obj, arg, order, _AIR_MEM_SCOPE_DEVICE);
  }
  METAL_FUNC int atomic_fetch_max_explicit(volatile threadgroup atomic_int *obj, int arg, memory_order order) {
    return _air_atomic_local_max_s_i32((threadgroup int *)obj, arg, order, _AIR_MEM_SCOPE_WORK_GROUP);
  }
  METAL_FUNC uint atomic_fetch_max_explicit(volatile threadgroup atomic_uint *obj, uint arg, memory_order order) {
    return _air_atomic_local_max_u_i32((threadgroup uint *)obj, arg, order, _AIR_MEM_SCOPE_WORK_GROUP);
  }
} // namespace metal

#endif // __METAL_ATOMIC
