view src/impl/arm/neon.c @ 27:d00b95f95dd1 default tip

impl/arm/neon: it compiles again, but is untested
author Paper <paper@tflc.us>
date Mon, 25 Nov 2024 00:33:02 -0500
parents e26874655738
children
line wrap: on
line source

/**
 * vec - a tiny SIMD vector library in C99
 * 
 * Copyright (c) 2024 Paper
 * 
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 * 
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 * 
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
**/

#include "vec/impl/arm/neon.h"
#include "vec/impl/generic.h"

#include <arm_neon.h>

// There is LOTS of preprocessor hacking here (as if the other files
// weren't bad enough... lol)

// Expands to the complete NEON backend for one integer vector type.
//
//   sign  - `u` for the unsigned type, left empty for the signed type
//   csign - not referenced anywhere in the expansion
//   bits  - bit width of a single element
//   size  - number of elements in the vector
//
// The expansion defines:
//   - a union overlaying the library vector type with the <arm_neon.h> type
//     of the same element width and count,
//   - static functions for load, store, add, sub, mul, lshift, and, or, xor,
//   - the v{sign}int{bits}x{size}_impl_neon table, whose remaining slots are
//     filled with the `_fallback_` functions.
//
// Intrinsic names are built by token pasting (v<op>_{sign}{bits}), so the
// helper macros v<op>_{bits}, vstore_lane_{bits} and
// vreinterpret_{bits}_u{bits} must be defined at the point of expansion.
// lshift refers to `union vuint{bits}x{size}_impl_data`, so the unsigned
// variant of a given bits/size pair has to be expanded before the signed one.
#define VEC_DEFINE_OPERATIONS_SIGN(sign, csign, bits, size) \
	/* overlay of the library vector type and the NEON register type */ \
	union v##sign##int##bits##x##size##_impl_data { \
		v##sign##int##bits##x##size vec; \
		sign##int##bits##x##size##_t neon; \
	}; \
	\
	/* the NEON type must fit inside the library type, in both alignment and size */ \
	VEC_STATIC_ASSERT(VEC_ALIGNOF(sign##int##bits##x##size##_t) <= VEC_ALIGNOF(v##sign##int##bits##x##size), "vec: v" #sign "int" #bits "x" #size " alignment needs to be expanded to fit intrinsic type size"); \
	VEC_STATIC_ASSERT(sizeof(sign##int##bits##x##size##_t) <= sizeof(v##sign##int##bits##x##size), "vec: v" #sign "int" #bits "x" #size " needs to be expanded to fit intrinsic type size"); \
	\
	static v##sign##int##bits##x##size v##sign##int##bits##x##size##_neon_load_aligned(const vec_##sign##int##bits in[size]) \
	{ \
		union v##sign##int##bits##x##size##_impl_data vec; \
		vec.neon = vld1_##sign##bits(in); \
		return vec.vec; \
	} \
	\
	static void v##sign##int##bits##x##size##_neon_store_aligned(v##sign##int##bits##x##size vec, vec_##sign##int##bits out[size]) \
	{ \
		vstore_lane_##bits(sign, ((union v##sign##int##bits##x##size##_impl_data *)&vec)->neon, out); \
	} \
	\
	static v##sign##int##bits##x##size v##sign##int##bits##x##size##_neon_add(v##sign##int##bits##x##size vec1, v##sign##int##bits##x##size vec2) \
	{ \
		union v##sign##int##bits##x##size##_impl_data *vec1d = (union v##sign##int##bits##x##size##_impl_data *)&vec1; \
		union v##sign##int##bits##x##size##_impl_data *vec2d = (union v##sign##int##bits##x##size##_impl_data *)&vec2; \
	\
		vec1d->neon = vadd_##sign##bits(vec1d->neon, vec2d->neon); \
		return vec1d->vec; \
	} \
	\
	static v##sign##int##bits##x##size v##sign##int##bits##x##size##_neon_sub(v##sign##int##bits##x##size vec1, v##sign##int##bits##x##size vec2) \
	{ \
		union v##sign##int##bits##x##size##_impl_data *vec1d = (union v##sign##int##bits##x##size##_impl_data *)&vec1; \
		union v##sign##int##bits##x##size##_impl_data *vec2d = (union v##sign##int##bits##x##size##_impl_data *)&vec2; \
	\
		vec1d->neon = vsub_##sign##bits(vec1d->neon, vec2d->neon); \
		return vec1d->vec; \
	} \
	\
	static v##sign##int##bits##x##size v##sign##int##bits##x##size##_neon_mul(v##sign##int##bits##x##size vec1, v##sign##int##bits##x##size vec2) \
	{ \
		union v##sign##int##bits##x##size##_impl_data *vec1d = (union v##sign##int##bits##x##size##_impl_data *)&vec1; \
		union v##sign##int##bits##x##size##_impl_data *vec2d = (union v##sign##int##bits##x##size##_impl_data *)&vec2; \
	\
		vec1d->neon = vmul_##sign##bits(vec1d->neon, vec2d->neon); \
		return vec1d->vec; \
	} \
	\
	static v##sign##int##bits##x##size v##sign##int##bits##x##size##_neon_lshift(v##sign##int##bits##x##size vec1, vuint##bits##x##size vec2) \
	{ \
		union v##sign##int##bits##x##size##_impl_data *vec1d = (union v##sign##int##bits##x##size##_impl_data *)&vec1; \
		union vuint##bits##x##size##_impl_data *vec2d = (union vuint##bits##x##size##_impl_data *)&vec2; \
	\
		/* vshl takes its per-lane shift counts as a signed vector, so the */ \
		/* unsigned counts are reinterpreted through the helper macro      */ \
		vec1d->neon = vshl_##sign##bits(vec1d->neon, vreinterpret_##bits##_u##bits(vec2d->neon)); \
		return vec1d->vec; \
	} \
	\
	static v##sign##int##bits##x##size v##sign##int##bits##x##size##_neon_and(v##sign##int##bits##x##size vec1, v##sign##int##bits##x##size vec2) \
	{ \
		union v##sign##int##bits##x##size##_impl_data *vec1d = (union v##sign##int##bits##x##size##_impl_data *)&vec1; \
		union v##sign##int##bits##x##size##_impl_data *vec2d = (union v##sign##int##bits##x##size##_impl_data *)&vec2; \
	\
		vec1d->neon = vand_##sign##bits(vec1d->neon, vec2d->neon); \
		return vec1d->vec; \
	} \
	\
	static v##sign##int##bits##x##size v##sign##int##bits##x##size##_neon_or(v##sign##int##bits##x##size vec1, v##sign##int##bits##x##size vec2) \
	{ \
		union v##sign##int##bits##x##size##_impl_data *vec1d = (union v##sign##int##bits##x##size##_impl_data *)&vec1; \
		union v##sign##int##bits##x##size##_impl_data *vec2d = (union v##sign##int##bits##x##size##_impl_data *)&vec2; \
	\
		vec1d->neon = vorr_##sign##bits(vec1d->neon, vec2d->neon); \
		return vec1d->vec; \
	} \
	\
	static v##sign##int##bits##x##size v##sign##int##bits##x##size##_neon_xor(v##sign##int##bits##x##size vec1, v##sign##int##bits##x##size vec2) \
	{ \
		union v##sign##int##bits##x##size##_impl_data *vec1d = (union v##sign##int##bits##x##size##_impl_data *)&vec1; \
		union v##sign##int##bits##x##size##_impl_data *vec2d = (union v##sign##int##bits##x##size##_impl_data *)&vec2; \
	\
		vec1d->neon = veor_##sign##bits(vec1d->neon, vec2d->neon); \
		return vec1d->vec; \
	} \
	\
	/* operation table; the same load and the same store function each fill */ \
	/* two consecutive slots (presumably aligned + unaligned - confirm      */ \
	/* against the _impl struct declaration)                                */ \
	static v##sign##int##bits##x##size##_impl v##sign##int##bits##x##size##_impl_neon = { \
		v##sign##int##bits##x##size##_fallback_splat, \
		v##sign##int##bits##x##size##_neon_load_aligned, \
		v##sign##int##bits##x##size##_neon_load_aligned, \
		v##sign##int##bits##x##size##_neon_store_aligned, \
		v##sign##int##bits##x##size##_neon_store_aligned, \
		v##sign##int##bits##x##size##_neon_add, \
		v##sign##int##bits##x##size##_neon_sub, \
		v##sign##int##bits##x##size##_neon_mul, \
		v##sign##int##bits##x##size##_fallback_div, \
		v##sign##int##bits##x##size##_fallback_avg, \
		v##sign##int##bits##x##size##_neon_and, \
		v##sign##int##bits##x##size##_neon_or, \
		v##sign##int##bits##x##size##_neon_xor, \
		v##sign##int##bits##x##size##_fallback_not, \
		v##sign##int##bits##x##size##_neon_lshift, \
		v##sign##int##bits##x##size##_fallback_rshift, \
		v##sign##int##bits##x##size##_fallback_lrshift, \
		v##sign##int##bits##x##size##_fallback_cmplt, \
		v##sign##int##bits##x##size##_fallback_cmple, \
		v##sign##int##bits##x##size##_fallback_cmpeq, \
		v##sign##int##bits##x##size##_fallback_cmpge, \
		v##sign##int##bits##x##size##_fallback_cmpgt, \
	};

// Instantiates VEC_DEFINE_OPERATIONS_SIGN twice for a vector of `lanes`
// elements that are `width` bits wide: first the unsigned flavour (sign `u`),
// then the signed flavour (empty sign). The order is kept as-is.
#define VEC_DEFINE_OPERATIONS(width, lanes) \
	VEC_DEFINE_OPERATIONS_SIGN(u, U, width, lanes) \
	VEC_DEFINE_OPERATIONS_SIGN( ,  , width, lanes)

// Ok, we'll start out with the 64-bit types.

// Helper names used while expanding VEC_DEFINE_OPERATIONS_SIGN for the
// 64-bit vector types. That macro pastes `sign` and `bits` onto an intrinsic
// stem; with an empty `sign` the result is e.g. `vadd_8`, which is aliased
// here to the signed intrinsic `vadd_s8`. The unsigned spellings (`vadd_u8`,
// ...) are already the names <arm_neon.h> uses and need no alias.

// --- 8-bit elements ---
#define vadd_8 vadd_s8
#define vsub_8 vsub_s8
#define vmul_8 vmul_s8
#define vshl_8 vshl_s8
#define vand_8 vand_s8
#define vorr_8 vorr_s8
#define veor_8 veor_s8
#define vld1_8 vld1_s8
#define vget_lane_8 vget_lane_s8
#define vreinterpret_8_u8(x) vreinterpret_s8_u8(x)

// --- 16-bit elements ---
#define vadd_16 vadd_s16
#define vsub_16 vsub_s16
#define vmul_16 vmul_s16
#define vshl_16 vshl_s16
#define vand_16 vand_s16
#define vorr_16 vorr_s16
#define veor_16 veor_s16
#define vld1_16 vld1_s16
#define vget_lane_16 vget_lane_s16
#define vreinterpret_16_u16(x) vreinterpret_s16_u16(x)

// --- 32-bit elements ---
#define vadd_32 vadd_s32
#define vsub_32 vsub_s32
#define vmul_32 vmul_s32
#define vshl_32 vshl_s32
#define vand_32 vand_s32
#define vorr_32 vorr_s32
#define veor_32 veor_s32
#define vld1_32 vld1_s32
#define vget_lane_32 vget_lane_s32
#define vreinterpret_32_u32(x) vreinterpret_s32_u32(x)

// vstore_lane_<bits>(sgn, src, dst): copies every lane of the vector `src`
// into dst[0..lanes-1], one vget_lane_<sgn><bits> call per lane. The lane
// index is written out literally in every call.
#define vstore_lane_8(sgn, src, dst) \
	do { \
		dst[0] = vget_lane_##sgn##8(src, 0); \
		dst[1] = vget_lane_##sgn##8(src, 1); \
		dst[2] = vget_lane_##sgn##8(src, 2); \
		dst[3] = vget_lane_##sgn##8(src, 3); \
		dst[4] = vget_lane_##sgn##8(src, 4); \
		dst[5] = vget_lane_##sgn##8(src, 5); \
		dst[6] = vget_lane_##sgn##8(src, 6); \
		dst[7] = vget_lane_##sgn##8(src, 7); \
	} while (0)
#define vstore_lane_16(sgn, src, dst) \
	do { \
		dst[0] = vget_lane_##sgn##16(src, 0); \
		dst[1] = vget_lane_##sgn##16(src, 1); \
		dst[2] = vget_lane_##sgn##16(src, 2); \
		dst[3] = vget_lane_##sgn##16(src, 3); \
	} while (0)
#define vstore_lane_32(sgn, src, dst) \
	do { \
		dst[0] = vget_lane_##sgn##32(src, 0); \
		dst[1] = vget_lane_##sgn##32(src, 1); \
	} while (0)

// Generate the implementations for the three 64-bit vector shapes.
VEC_DEFINE_OPERATIONS(8, 8)
VEC_DEFINE_OPERATIONS(16, 4)
VEC_DEFINE_OPERATIONS(32, 2)

// Drop every 64-bit helper macro again so that the names can be reused.

// --- 8-bit elements ---
#undef vadd_8
#undef vsub_8
#undef vmul_8
#undef vshl_8
#undef vand_8
#undef vorr_8
#undef veor_8
#undef vld1_8
#undef vget_lane_8
#undef vstore_lane_8
#undef vreinterpret_8_u8

// --- 16-bit elements ---
#undef vadd_16
#undef vsub_16
#undef vmul_16
#undef vshl_16
#undef vand_16
#undef vorr_16
#undef veor_16
#undef vld1_16
#undef vget_lane_16
#undef vstore_lane_16
#undef vreinterpret_16_u16

// --- 32-bit elements ---
#undef vadd_32
#undef vsub_32
#undef vmul_32
#undef vshl_32
#undef vand_32
#undef vorr_32
#undef veor_32
#undef vld1_32
#undef vget_lane_32
#undef vstore_lane_32
#undef vreinterpret_32_u32

///////////////////////////////////////////////////////////////////////////////
// 128-bit

// Now we can go ahead and do the 128-bit ones.

// NEON doesn't have native 64-bit multiplication, so we have
// to do it ourselves
// Lane-wise multiply of two int64x2_t vectors, keeping the low 64 bits of
// each product (two's-complement wrap-around).
//
// Writing a = (a_hi << 32) + a_lo and b = (b_hi << 32) + b_lo, the low
// 64 bits of the product are
//     ((a_hi * b_lo + a_lo * b_hi) << 32) + a_lo * b_lo
// These bits are the same for signed and unsigned operands, so the whole
// computation is carried out on the unsigned reinterpretation.
static inline int64x2_t vmulq_s64(const int64x2_t a, const int64x2_t b)
{
    const uint64x2_t ua = vreinterpretq_u64_s64(a);
    const uint64x2_t ub = vreinterpretq_u64_s64(b);

    // low 32 bits of every 64-bit lane
    const uint32x2_t a_lo = vmovn_u64(ua);
    const uint32x2_t b_lo = vmovn_u64(ub);

    // Swap the 32-bit halves of `a` inside each 64-bit lane so that the
    // 32-bit multiply pairs a_hi with b_lo and a_lo with b_hi. Without the
    // swap it would pair lo*lo and hi*hi, which is not the cross term.
    const uint32x4_t cross = vmulq_u32(vreinterpretq_u32_u64(ub), vrev64q_u32(vreinterpretq_u32_u64(ua)));

    // pairwise-add the two cross products, move them into the upper half,
    // then accumulate the full 64-bit a_lo * b_lo
    return vreinterpretq_s64_u64(vmlal_u32(vshlq_n_u64(vpaddlq_u32(cross), 32), a_lo, b_lo));
}

// Lane-wise multiply of two uint64x2_t vectors, keeping the low 64 bits of
// each product.
//
// Writing a = (a_hi << 32) + a_lo and b = (b_hi << 32) + b_lo, the low
// 64 bits of the product are
//     ((a_hi * b_lo + a_lo * b_hi) << 32) + a_lo * b_lo
static inline uint64x2_t vmulq_u64(const uint64x2_t a, const uint64x2_t b)
{
    // low 32 bits of every 64-bit lane
    const uint32x2_t a_lo = vmovn_u64(a);
    const uint32x2_t b_lo = vmovn_u64(b);

    // Swap the 32-bit halves of `a` inside each 64-bit lane so that the
    // 32-bit multiply pairs a_hi with b_lo and a_lo with b_hi. Without the
    // swap it would pair lo*lo and hi*hi, which is not the cross term.
    const uint32x4_t cross = vmulq_u32(vreinterpretq_u32_u64(b), vrev64q_u32(vreinterpretq_u32_u64(a)));

    // pairwise-add the two cross products, move them into the upper half,
    // then accumulate the full 64-bit a_lo * b_lo
    return vmlal_u32(vshlq_n_u64(vpaddlq_u32(cross), 32), a_lo, b_lo);
}

// Helper names used while expanding VEC_DEFINE_OPERATIONS_SIGN for the
// 128-bit vector types. That macro pastes `sign` and `bits` onto a
// non-`q` stem (vadd_, vld1_, ...), so both the suffix-less signed
// spelling and the `u` spelling are redirected to the `q` intrinsics here.
// 64-bit multiplication resolves to the vmulq_s64 / vmulq_u64 helpers.

// --- 8-bit elements ---
#define vadd_8 vaddq_s8
#define vadd_u8 vaddq_u8
#define vsub_8 vsubq_s8
#define vsub_u8 vsubq_u8
#define vmul_8 vmulq_s8
#define vmul_u8 vmulq_u8
#define vshl_8 vshlq_s8
#define vshl_u8 vshlq_u8
#define vand_8 vandq_s8
#define vand_u8 vandq_u8
#define vorr_8 vorrq_s8
#define vorr_u8 vorrq_u8
#define veor_8 veorq_s8
#define veor_u8 veorq_u8
#define vld1_8 vld1q_s8
#define vld1_u8 vld1q_u8
#define vget_lane_8 vgetq_lane_s8
#define vget_lane_u8 vgetq_lane_u8
#define vreinterpret_8_u8(x) vreinterpretq_s8_u8(x)

// --- 16-bit elements ---
#define vadd_16 vaddq_s16
#define vadd_u16 vaddq_u16
#define vsub_16 vsubq_s16
#define vsub_u16 vsubq_u16
#define vmul_16 vmulq_s16
#define vmul_u16 vmulq_u16
#define vshl_16 vshlq_s16
#define vshl_u16 vshlq_u16
#define vand_16 vandq_s16
#define vand_u16 vandq_u16
#define vorr_16 vorrq_s16
#define vorr_u16 vorrq_u16
#define veor_16 veorq_s16
#define veor_u16 veorq_u16
#define vld1_16 vld1q_s16
#define vld1_u16 vld1q_u16
#define vget_lane_16 vgetq_lane_s16
#define vget_lane_u16 vgetq_lane_u16
#define vreinterpret_16_u16(x) vreinterpretq_s16_u16(x)

// --- 32-bit elements ---
#define vadd_32 vaddq_s32
#define vadd_u32 vaddq_u32
#define vsub_32 vsubq_s32
#define vsub_u32 vsubq_u32
#define vmul_32 vmulq_s32
#define vmul_u32 vmulq_u32
#define vshl_32 vshlq_s32
#define vshl_u32 vshlq_u32
#define vand_32 vandq_s32
#define vand_u32 vandq_u32
#define vorr_32 vorrq_s32
#define vorr_u32 vorrq_u32
#define veor_32 veorq_s32
#define veor_u32 veorq_u32
#define vld1_32 vld1q_s32
#define vld1_u32 vld1q_u32
#define vget_lane_32 vgetq_lane_s32
#define vget_lane_u32 vgetq_lane_u32
#define vreinterpret_32_u32(x) vreinterpretq_s32_u32(x)

// --- 64-bit elements ---
#define vadd_64 vaddq_s64
#define vadd_u64 vaddq_u64
#define vsub_64 vsubq_s64
#define vsub_u64 vsubq_u64
#define vmul_64 vmulq_s64
#define vmul_u64 vmulq_u64
#define vshl_64 vshlq_s64
#define vshl_u64 vshlq_u64
#define vand_64 vandq_s64
#define vand_u64 vandq_u64
#define vorr_64 vorrq_s64
#define vorr_u64 vorrq_u64
#define veor_64 veorq_s64
#define veor_u64 veorq_u64
#define vld1_64 vld1q_s64
#define vld1_u64 vld1q_u64
#define vget_lane_64 vgetq_lane_s64
#define vget_lane_u64 vgetq_lane_u64
#define vreinterpret_64_u64(x) vreinterpretq_s64_u64(x)

// vstore_lane_<bits>(sgn, src, dst): copies every lane of the vector `src`
// into dst[0..lanes-1], one vget_lane_<sgn><bits> call per lane. The lane
// index is written out literally in every call.
#define vstore_lane_8(sgn, src, dst) \
	do { \
		dst[0] = vget_lane_##sgn##8(src, 0); \
		dst[1] = vget_lane_##sgn##8(src, 1); \
		dst[2] = vget_lane_##sgn##8(src, 2); \
		dst[3] = vget_lane_##sgn##8(src, 3); \
		dst[4] = vget_lane_##sgn##8(src, 4); \
		dst[5] = vget_lane_##sgn##8(src, 5); \
		dst[6] = vget_lane_##sgn##8(src, 6); \
		dst[7] = vget_lane_##sgn##8(src, 7); \
		dst[8] = vget_lane_##sgn##8(src, 8); \
		dst[9] = vget_lane_##sgn##8(src, 9); \
		dst[10] = vget_lane_##sgn##8(src, 10); \
		dst[11] = vget_lane_##sgn##8(src, 11); \
		dst[12] = vget_lane_##sgn##8(src, 12); \
		dst[13] = vget_lane_##sgn##8(src, 13); \
		dst[14] = vget_lane_##sgn##8(src, 14); \
		dst[15] = vget_lane_##sgn##8(src, 15); \
	} while (0)
#define vstore_lane_16(sgn, src, dst) \
	do { \
		dst[0] = vget_lane_##sgn##16(src, 0); \
		dst[1] = vget_lane_##sgn##16(src, 1); \
		dst[2] = vget_lane_##sgn##16(src, 2); \
		dst[3] = vget_lane_##sgn##16(src, 3); \
		dst[4] = vget_lane_##sgn##16(src, 4); \
		dst[5] = vget_lane_##sgn##16(src, 5); \
		dst[6] = vget_lane_##sgn##16(src, 6); \
		dst[7] = vget_lane_##sgn##16(src, 7); \
	} while (0)
#define vstore_lane_32(sgn, src, dst) \
	do { \
		dst[0] = vget_lane_##sgn##32(src, 0); \
		dst[1] = vget_lane_##sgn##32(src, 1); \
		dst[2] = vget_lane_##sgn##32(src, 2); \
		dst[3] = vget_lane_##sgn##32(src, 3); \
	} while (0)
#define vstore_lane_64(sgn, src, dst) \
	do { \
		dst[0] = vget_lane_##sgn##64(src, 0); \
		dst[1] = vget_lane_##sgn##64(src, 1); \
	} while (0)

// Generate the implementations for the four 128-bit vector shapes.
VEC_DEFINE_OPERATIONS(8, 16)
VEC_DEFINE_OPERATIONS(16, 8)
VEC_DEFINE_OPERATIONS(32, 4)
VEC_DEFINE_OPERATIONS(64, 2)

// Drop every 128-bit helper macro again. This includes the unsigned aliases
// (vadd_u8 and friends) and the vreinterpret_* helpers: left defined, names
// such as vadd_u8 would keep resolving to the `q` intrinsics for the rest of
// the translation unit.

// --- 8-bit elements ---
#undef vadd_8
#undef vadd_u8
#undef vsub_8
#undef vsub_u8
#undef vmul_8
#undef vmul_u8
#undef vshl_8
#undef vshl_u8
#undef veor_8
#undef veor_u8
#undef vorr_8
#undef vorr_u8
#undef vand_8
#undef vand_u8
#undef vld1_8
#undef vld1_u8
#undef vget_lane_8
#undef vget_lane_u8
#undef vstore_lane_8
#undef vreinterpret_8_u8

// --- 16-bit elements ---
#undef vadd_16
#undef vadd_u16
#undef vsub_16
#undef vsub_u16
#undef vmul_16
#undef vmul_u16
#undef vshl_16
#undef vshl_u16
#undef veor_16
#undef veor_u16
#undef vorr_16
#undef vorr_u16
#undef vand_16
#undef vand_u16
#undef vld1_16
#undef vld1_u16
#undef vget_lane_16
#undef vget_lane_u16
#undef vstore_lane_16
#undef vreinterpret_16_u16

// --- 32-bit elements ---
#undef vadd_32
#undef vadd_u32
#undef vsub_32
#undef vsub_u32
#undef vmul_32
#undef vmul_u32
#undef vshl_32
#undef vshl_u32
#undef veor_32
#undef veor_u32
#undef vorr_32
#undef vorr_u32
#undef vand_32
#undef vand_u32
#undef vld1_32
#undef vld1_u32
#undef vget_lane_32
#undef vget_lane_u32
#undef vstore_lane_32
#undef vreinterpret_32_u32

// --- 64-bit elements ---
#undef vadd_64
#undef vadd_u64
#undef vsub_64
#undef vsub_u64
#undef vmul_64
#undef vmul_u64
#undef vshl_64
#undef vshl_u64
#undef veor_64
#undef veor_u64
#undef vorr_64
#undef vorr_u64
#undef vand_64
#undef vand_u64
#undef vld1_64
#undef vld1_u64
#undef vget_lane_64
#undef vget_lane_u64
#undef vstore_lane_64
#undef vreinterpret_64_u64