view src/impl/x86/avx512f.c @ 25:92156fe32755

impl/ppc/altivec: update to new implementation. The signed average function is wrong: it needs to round up the number when only one of the operands is odd, but that doesn't necessarily seem to be true, because AltiVec is weird, and those are the quirks we need to emulate. Ugh. Also, the AltiVec backend uses the generic functions instead of fallbacks, because it does indeed use exactly the same memory structure as the generic implementation...
author Paper <paper@tflc.us>
date Sun, 24 Nov 2024 11:15:59 +0000
parents e49e70f7012f
children
line wrap: on
line source

/**
 * vec - a tiny SIMD vector library in C99
 * 
 * Copyright (c) 2024 Paper
 * 
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 * 
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 * 
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
**/

#include "vec/impl/x86/avx512f.h"
#include "vec/impl/generic.h"

#include <immintrin.h>

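// This is a stupid amount of work just to do these operations; is it really worth it?
// AVX-512F has no 8- or 16-bit lane arithmetic, so each 32-bit lane is split into its bytes
// (or 16-bit halves). Each piece is zero-extended, run through the 32-bit instruction and
// then repacked. The same note as in avx2.c applies here: these helpers do not handle sign
// bits properly. That is harmless for add, sub, mullo and the logical shifts, because only
// the low bits of each result are kept. It matters a lot for arithmetic right shifts, which
// is why the 8- and 16-bit arithmetic shifts fall back to the generic implementation.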
// Emulates a per-byte operation using the 32-bit instruction _mm512_<op>_epi32.
// For each of the four byte positions n in every 32-bit lane:
//   - the byte is isolated and zero-extended;
//   - the 32-bit op is applied;
//   - the result is truncated to 8 bits and or-ed back into position n.
// Expands to a statement that returns from the enclosing function; `vec1` and `vec2`
// must be in scope.
#define VEC_AVX512F_OPERATION_8x64(op, sign) \
	do { \
		union v##sign##int8x64_impl_data *lhs = (union v##sign##int8x64_impl_data *)&vec1; \
		union v##sign##int8x64_impl_data *rhs = (union v##sign##int8x64_impl_data *)&vec2; \
		__m512i x = lhs->avx512f; \
		__m512i y = rhs->avx512f; \
	\
		/* byte n of each lane, moved down to bits 0..7 with everything above cleared */ \
		__m512i x0 = _mm512_srli_epi32(_mm512_slli_epi32(x, 24), 24); \
		__m512i x1 = _mm512_srli_epi32(_mm512_slli_epi32(x, 16), 24); \
		__m512i x2 = _mm512_srli_epi32(_mm512_slli_epi32(x, 8), 24); \
		__m512i x3 = _mm512_srli_epi32(x, 24); \
		__m512i y0 = _mm512_srli_epi32(_mm512_slli_epi32(y, 24), 24); \
		__m512i y1 = _mm512_srli_epi32(_mm512_slli_epi32(y, 16), 24); \
		__m512i y2 = _mm512_srli_epi32(_mm512_slli_epi32(y, 8), 24); \
		__m512i y3 = _mm512_srli_epi32(y, 24); \
	\
		/* do the work at 32-bit width, one byte position at a time */ \
		__m512i r0 = _mm512_##op##_epi32(x0, y0); \
		__m512i r1 = _mm512_##op##_epi32(x1, y1); \
		__m512i r2 = _mm512_##op##_epi32(x2, y2); \
		__m512i r3 = _mm512_##op##_epi32(x3, y3); \
	\
		/* keep only the low byte of each result and slide it back to byte n */ \
		r0 = _mm512_srli_epi32(_mm512_slli_epi32(r0, 24), 24); \
		r1 = _mm512_srli_epi32(_mm512_slli_epi32(r1, 24), 16); \
		r2 = _mm512_srli_epi32(_mm512_slli_epi32(r2, 24), 8); \
		r3 = _mm512_slli_epi32(r3, 24); \
	\
		lhs->avx512f = _mm512_or_si512(_mm512_or_si512(r0, r1), _mm512_or_si512(r2, r3)); \
		return lhs->vec; \
	} while (0)

// Emulates a per-16-bit-lane operation using the 32-bit instruction _mm512_<op>_epi32.
// The low and high halves of every 32-bit lane are zero-extended, operated on separately,
// truncated to 16 bits and recombined. Expands to a statement that returns from the
// enclosing function; `vec1` and `vec2` must be in scope.
#define VEC_AVX512F_OPERATION_16x32(op, sign) \
	do { \
		union v##sign##int16x32_impl_data *lhs = (union v##sign##int16x32_impl_data *)&vec1; \
		union v##sign##int16x32_impl_data *rhs = (union v##sign##int16x32_impl_data *)&vec2; \
	\
		/* low halves: a left/right shift pair clears bits 16..31 (stands in for an AND with 0xFFFF) */ \
		__m512i lo = _mm512_##op##_epi32( \
			_mm512_srli_epi32(_mm512_slli_epi32(lhs->avx512f, 16), 16), \
			_mm512_srli_epi32(_mm512_slli_epi32(rhs->avx512f, 16), 16)); \
		/* high halves: a single logical right shift brings them down zero-extended */ \
		__m512i hi = _mm512_##op##_epi32( \
			_mm512_srli_epi32(lhs->avx512f, 16), \
			_mm512_srli_epi32(rhs->avx512f, 16)); \
	\
		/* truncate both results to 16 bits and put them back side by side */ \
		lo = _mm512_srli_epi32(_mm512_slli_epi32(lo, 16), 16); \
		hi = _mm512_slli_epi32(hi, 16); \
		lhs->avx512f = _mm512_or_si512(lo, hi); \
		return lhs->vec; \
	} while (0)

// Lane-wise addition. The 8- and 16-bit forms are emulated with 32-bit adds on
// zero-extended lanes. The 32- and 64-bit forms map directly onto the native instructions.
// Each form expands to a statement that returns the sum from the enclosing function.
#define VEC_AVX512F_ADD_8x64(sign)  VEC_AVX512F_OPERATION_8x64(add, sign)
#define VEC_AVX512F_ADD_16x32(sign) VEC_AVX512F_OPERATION_16x32(add, sign)

#define VEC_AVX512F_ADD_32x16(sign) \
	do { \
		union v##sign##int32x16_impl_data *lhs = (union v##sign##int32x16_impl_data *)&vec1; \
		const union v##sign##int32x16_impl_data *rhs = (const union v##sign##int32x16_impl_data *)&vec2; \
		__m512i sum = _mm512_add_epi32(lhs->avx512f, rhs->avx512f); \
	\
		lhs->avx512f = sum; \
		return lhs->vec; \
	} while (0)

#define VEC_AVX512F_ADD_64x8(sign) \
	do { \
		union v##sign##int64x8_impl_data *lhs = (union v##sign##int64x8_impl_data *)&vec1; \
		const union v##sign##int64x8_impl_data *rhs = (const union v##sign##int64x8_impl_data *)&vec2; \
		__m512i sum = _mm512_add_epi64(lhs->avx512f, rhs->avx512f); \
	\
		lhs->avx512f = sum; \
		return lhs->vec; \
	} while (0)

// Lane-wise subtraction. The 8- and 16-bit forms are emulated with 32-bit subtracts on
// zero-extended lanes. Only the low bits of each difference are kept, so wrap-around
// matches the narrow type. The 32- and 64-bit forms use the native instructions. Each form
// expands to a statement that returns the difference from the enclosing function.
#define VEC_AVX512F_SUB_8x64(sign)  VEC_AVX512F_OPERATION_8x64(sub, sign)
#define VEC_AVX512F_SUB_16x32(sign) VEC_AVX512F_OPERATION_16x32(sub, sign)

#define VEC_AVX512F_SUB_32x16(sign) \
	do { \
		union v##sign##int32x16_impl_data *lhs = (union v##sign##int32x16_impl_data *)&vec1; \
		const union v##sign##int32x16_impl_data *rhs = (const union v##sign##int32x16_impl_data *)&vec2; \
		__m512i diff = _mm512_sub_epi32(lhs->avx512f, rhs->avx512f); \
	\
		lhs->avx512f = diff; \
		return lhs->vec; \
	} while (0)

#define VEC_AVX512F_SUB_64x8(sign) \
	do { \
		union v##sign##int64x8_impl_data *lhs = (union v##sign##int64x8_impl_data *)&vec1; \
		const union v##sign##int64x8_impl_data *rhs = (const union v##sign##int64x8_impl_data *)&vec2; \
		__m512i diff = _mm512_sub_epi64(lhs->avx512f, rhs->avx512f); \
	\
		lhs->avx512f = diff; \
		return lhs->vec; \
	} while (0)

// Lane-wise multiplication, keeping the low bits of each product. The 8- and 16-bit forms
// are emulated with 32-bit low multiplies on zero-extended lanes. Each form expands to a
// statement that returns the product from the enclosing function.
#define VEC_AVX512F_MUL_8x64(sign)  VEC_AVX512F_OPERATION_8x64(mullo, sign)
#define VEC_AVX512F_MUL_16x32(sign) VEC_AVX512F_OPERATION_16x32(mullo, sign)

#define VEC_AVX512F_MUL_32x16(sign) \
	do { \
		union v##sign##int32x16_impl_data *lhs = (union v##sign##int32x16_impl_data *)&vec1; \
		const union v##sign##int32x16_impl_data *rhs = (const union v##sign##int32x16_impl_data *)&vec2; \
		__m512i prod = _mm512_mullo_epi32(lhs->avx512f, rhs->avx512f); \
	\
		lhs->avx512f = prod; \
		return lhs->vec; \
	} while (0)

// AVX-512F has no 64-bit low multiply (_mm512_mullo_epi64 needs AVX-512DQ), so the product
// is assembled from 32x32->64 unsigned multiplies:
//   (xh*2^32 + xl) * (yh*2^32 + yl) mod 2^64 == xl*yl + ((xh*yl + xl*yh) << 32)
#define VEC_AVX512F_MUL_64x8(sign) \
	do { \
		union v##sign##int64x8_impl_data *lhs = (union v##sign##int64x8_impl_data *)&vec1; \
		union v##sign##int64x8_impl_data *rhs = (union v##sign##int64x8_impl_data *)&vec2; \
		__m512i x = lhs->avx512f; \
		__m512i y = rhs->avx512f; \
	\
		__m512i lo_lo = _mm512_mul_epu32(x, y); \
		__m512i hi_lo = _mm512_mul_epu32(_mm512_srli_epi64(x, 32), y); \
		__m512i lo_hi = _mm512_mul_epu32(x, _mm512_srli_epi64(y, 32)); \
		__m512i cross = _mm512_slli_epi64(_mm512_add_epi64(hi_lo, lo_hi), 32); \
	\
		lhs->avx512f = _mm512_add_epi64(cross, lo_lo); \
		return lhs->vec; \
	} while (0)

// Lane-wise left shift by per-lane counts taken from vec2. The 8- and 16-bit forms are
// emulated with 32-bit variable shifts on zero-extended lanes. Each form expands to a
// statement that returns the shifted vector from the enclosing function.
#define VEC_AVX512F_LSHIFT_8x64(sign)  VEC_AVX512F_OPERATION_8x64(sllv, sign)
#define VEC_AVX512F_LSHIFT_16x32(sign) VEC_AVX512F_OPERATION_16x32(sllv, sign)

#define VEC_AVX512F_LSHIFT_32x16(sign) \
	do { \
		union v##sign##int32x16_impl_data *lhs = (union v##sign##int32x16_impl_data *)&vec1; \
		const union v##sign##int32x16_impl_data *counts = (const union v##sign##int32x16_impl_data *)&vec2; \
		__m512i shifted = _mm512_sllv_epi32(lhs->avx512f, counts->avx512f); \
	\
		lhs->avx512f = shifted; \
		return lhs->vec; \
	} while (0)

#define VEC_AVX512F_LSHIFT_64x8(sign) \
	do { \
		union v##sign##int64x8_impl_data *lhs = (union v##sign##int64x8_impl_data *)&vec1; \
		const union v##sign##int64x8_impl_data *counts = (const union v##sign##int64x8_impl_data *)&vec2; \
		__m512i shifted = _mm512_sllv_epi64(lhs->avx512f, counts->avx512f); \
	\
		lhs->avx512f = shifted; \
		return lhs->vec; \
	} while (0)

// Right shifts. The `aORl` argument is the letter `a` (arithmetic) or `l` (logical). It is
// token-pasted either into a helper macro name or into the intrinsic name
// (_mm512_srav_* or _mm512_srlv_*). Every form expands to a statement that returns from
// the enclosing function.

// Logical 8/16-bit shifts are exact on zero-extended lanes.
#define VEC_AVX512F_lRSHIFT_8x64(sign)  VEC_AVX512F_OPERATION_8x64(srlv, sign)
#define VEC_AVX512F_lRSHIFT_16x32(sign) VEC_AVX512F_OPERATION_16x32(srlv, sign)

// Arithmetic 8/16-bit shifts would lose the sign bit to zero-extension, so they defer to
// the generic implementation instead.
#define VEC_AVX512F_GENERIC_RSHIFT_(sign, bits, size) \
	do { \
		return v##sign##int##bits##x##size##_generic_rshift(vec1, vec2); \
	} while (0)

#define VEC_AVX512F_aRSHIFT_8x64(sign)  VEC_AVX512F_GENERIC_RSHIFT_(sign, 8, 64)
#define VEC_AVX512F_aRSHIFT_16x32(sign) VEC_AVX512F_GENERIC_RSHIFT_(sign, 16, 32)

// Pick the arithmetic or logical variant for the emulated widths.
#define VEC_AVX512F_RSHIFT_8x64(sign, aORl)  VEC_AVX512F_##aORl##RSHIFT_8x64(sign)
#define VEC_AVX512F_RSHIFT_16x32(sign, aORl) VEC_AVX512F_##aORl##RSHIFT_16x32(sign)

// Native widths: AVX-512F has variable arithmetic and logical shifts for 32- and 64-bit lanes.
#define VEC_AVX512F_RSHIFT_32x16(sign, aORl) \
	do { \
		union v##sign##int32x16_impl_data *lhs = (union v##sign##int32x16_impl_data *)&vec1; \
		const union v##sign##int32x16_impl_data *counts = (const union v##sign##int32x16_impl_data *)&vec2; \
		__m512i shifted = _mm512_sr##aORl##v_epi32(lhs->avx512f, counts->avx512f); \
	\
		lhs->avx512f = shifted; \
		return lhs->vec; \
	} while (0)

#define VEC_AVX512F_RSHIFT_64x8(sign, aORl) \
	do { \
		union v##sign##int64x8_impl_data *lhs = (union v##sign##int64x8_impl_data *)&vec1; \
		const union v##sign##int64x8_impl_data *counts = (const union v##sign##int64x8_impl_data *)&vec2; \
		__m512i shifted = _mm512_sr##aORl##v_epi64(lhs->avx512f, counts->avx512f); \
	\
		lhs->avx512f = shifted; \
		return lhs->vec; \
	} while (0)

// Unsigned lanes have no sign to propagate, so every right shift is logical and `aORl`
// is deliberately ignored.
#define VEC_AVX512F_uRSHIFT_8x64(sign, aORl)  VEC_AVX512F_RSHIFT_8x64(sign, l)
#define VEC_AVX512F_uRSHIFT_16x32(sign, aORl) VEC_AVX512F_RSHIFT_16x32(sign, l)
#define VEC_AVX512F_uRSHIFT_32x16(sign, aORl) VEC_AVX512F_RSHIFT_32x16(sign, l)
#define VEC_AVX512F_uRSHIFT_64x8(sign, aORl)  VEC_AVX512F_RSHIFT_64x8(sign, l)

// Defines the AVX-512F implementation for the vector type v<sign>int<bits>x<size>.
// `sign` is `u` for unsigned or empty for signed. The expansion contains:
//   - union v<sign>int<bits>x<size>_impl_data, which reinterprets the vector as an __m512i;
//   - static assertions that the vector type is at least as aligned and as large as __m512i;
//   - static load/store, add/sub/mul, and/or/xor and shift functions built on AVX-512F
//     intrinsics (add/sub/mul and the shifts delegate to the per-width VEC_AVX512F_* macros);
//   - the non-static const table v<sign>int<bits>x<size>_impl_avx512f. Its splat, div, avg,
//     not and comparison entries come from the generic implementation.
#define VEC_AVX512F_DEFINE_OPERATIONS_SIGN(sign, bits, size) \
	union v##sign##int##bits##x##size##_impl_data { \
		v##sign##int##bits##x##size vec; \
		__m512i avx512f; \
	}; \
	/* The union reinterprets the vector's storage as an __m512i, so the vector type must be at least as aligned and as large. */ \
	VEC_STATIC_ASSERT(VEC_ALIGNOF(__m512i) <= VEC_ALIGNOF(v##sign##int##bits##x##size), "vec: v" #sign "int" #bits "x" #size " alignment needs to be expanded to fit intrinsic type size"); \
	VEC_STATIC_ASSERT(sizeof(__m512i) <= sizeof(v##sign##int##bits##x##size), "vec: v" #sign "int" #bits "x" #size " needs to be expanded to fit intrinsic type size"); \
	/* Aligned load: _mm512_load_si512 expects a 64-byte-aligned address. */ \
	static v##sign##int##bits##x##size v##sign##int##bits##x##size##_avx512f_load_aligned(const vec_##sign##int##bits in[size]) \
	{ \
		union v##sign##int##bits##x##size##_impl_data vec; \
		vec.avx512f = _mm512_load_si512((const __m512i *)in); \
		return vec.vec; \
	} \
	/* Unaligned load. */ \
	static v##sign##int##bits##x##size v##sign##int##bits##x##size##_avx512f_load(const vec_##sign##int##bits in[size]) \
	{ \
		union v##sign##int##bits##x##size##_impl_data vec; \
		vec.avx512f = _mm512_loadu_si512((const __m512i *)in); \
		return vec.vec; \
	} \
	/* Aligned store: _mm512_store_si512 expects a 64-byte-aligned address. */ \
	static void v##sign##int##bits##x##size##_avx512f_store_aligned(v##sign##int##bits##x##size vec, vec_##sign##int##bits out[size]) \
	{ \
		_mm512_store_si512((__m512i *)out, ((union v##sign##int##bits##x##size##_impl_data *)&vec)->avx512f); \
	} \
	/* Unaligned store. */ \
	static void v##sign##int##bits##x##size##_avx512f_store(v##sign##int##bits##x##size vec, vec_##sign##int##bits out[size]) \
	{ \
		_mm512_storeu_si512((__m512i *)out, ((union v##sign##int##bits##x##size##_impl_data *)&vec)->avx512f); \
	} \
	/* Arithmetic: each VEC_AVX512F_ADD/SUB/MUL macro expands to a statement that returns the result. */ \
	static v##sign##int##bits##x##size v##sign##int##bits##x##size##_avx512f_add(v##sign##int##bits##x##size vec1, v##sign##int##bits##x##size vec2) \
	{ \
		VEC_AVX512F_ADD_##bits##x##size(sign); \
	} \
	\
	static v##sign##int##bits##x##size v##sign##int##bits##x##size##_avx512f_sub(v##sign##int##bits##x##size vec1, v##sign##int##bits##x##size vec2) \
	{ \
		VEC_AVX512F_SUB_##bits##x##size(sign); \
	} \
	\
	static v##sign##int##bits##x##size v##sign##int##bits##x##size##_avx512f_mul(v##sign##int##bits##x##size vec1, v##sign##int##bits##x##size vec2) \
	{ \
		VEC_AVX512F_MUL_##bits##x##size(sign); \
	} \
	/* Bitwise operations do not depend on lane width, so they act on the full 512 bits directly. */ \
	static v##sign##int##bits##x##size v##sign##int##bits##x##size##_avx512f_and(v##sign##int##bits##x##size vec1, v##sign##int##bits##x##size vec2) \
	{ \
		union v##sign##int##bits##x##size##_impl_data *vec1d = (union v##sign##int##bits##x##size##_impl_data *)&vec1; \
		union v##sign##int##bits##x##size##_impl_data *vec2d = (union v##sign##int##bits##x##size##_impl_data *)&vec2; \
	\
		vec1d->avx512f = _mm512_and_si512(vec1d->avx512f, vec2d->avx512f); \
		return vec1d->vec; \
	} \
	\
	static v##sign##int##bits##x##size v##sign##int##bits##x##size##_avx512f_or(v##sign##int##bits##x##size vec1, v##sign##int##bits##x##size vec2) \
	{ \
		union v##sign##int##bits##x##size##_impl_data *vec1d = (union v##sign##int##bits##x##size##_impl_data *)&vec1; \
		union v##sign##int##bits##x##size##_impl_data *vec2d = (union v##sign##int##bits##x##size##_impl_data *)&vec2; \
	\
		vec1d->avx512f = _mm512_or_si512(vec1d->avx512f, vec2d->avx512f); \
		return vec1d->vec; \
	} \
	\
	static v##sign##int##bits##x##size v##sign##int##bits##x##size##_avx512f_xor(v##sign##int##bits##x##size vec1, v##sign##int##bits##x##size vec2) \
	{ \
		union v##sign##int##bits##x##size##_impl_data *vec1d = (union v##sign##int##bits##x##size##_impl_data *)&vec1; \
		union v##sign##int##bits##x##size##_impl_data *vec2d = (union v##sign##int##bits##x##size##_impl_data *)&vec2; \
	\
		vec1d->avx512f = _mm512_xor_si512(vec1d->avx512f, vec2d->avx512f); \
		return vec1d->vec; \
	} \
	/* Shifts: the per-lane shift counts always come from an unsigned vector (vuint<bits>x<size>). */ \
	static v##sign##int##bits##x##size v##sign##int##bits##x##size##_avx512f_lshift(v##sign##int##bits##x##size vec1, vuint##bits##x##size vec2) \
	{ \
		VEC_AVX512F_LSHIFT_##bits##x##size(sign); \
	} \
	/* rshift: for sign == u this resolves to VEC_AVX512F_uRSHIFT_* (logical); for signed types to VEC_AVX512F_RSHIFT_* with `a` (arithmetic). */ \
	static v##sign##int##bits##x##size v##sign##int##bits##x##size##_avx512f_rshift(v##sign##int##bits##x##size vec1, vuint##bits##x##size vec2) \
	{ \
		VEC_AVX512F_##sign##RSHIFT_##bits##x##size(sign, a); \
	} \
	/* lrshift: logical right shift regardless of signedness. */ \
	static v##sign##int##bits##x##size v##sign##int##bits##x##size##_avx512f_lrshift(v##sign##int##bits##x##size vec1, vuint##bits##x##size vec2) \
	{ \
		VEC_AVX512F_RSHIFT_##bits##x##size(sign, l); \
	} \
	/* NOTE(review): positional initializer; the entry order must match the field order of v<sign>int<bits>x<size>_impl, which is declared outside this file (presumably in a vec header) - confirm when editing. */ \
	const v##sign##int##bits##x##size##_impl v##sign##int##bits##x##size##_impl_avx512f = { \
		v##sign##int##bits##x##size##_generic_splat, \
		v##sign##int##bits##x##size##_avx512f_load_aligned, \
		v##sign##int##bits##x##size##_avx512f_load, \
		v##sign##int##bits##x##size##_avx512f_store_aligned, \
		v##sign##int##bits##x##size##_avx512f_store, \
		v##sign##int##bits##x##size##_avx512f_add, \
		v##sign##int##bits##x##size##_avx512f_sub, \
		v##sign##int##bits##x##size##_avx512f_mul, \
		v##sign##int##bits##x##size##_generic_div, \
		v##sign##int##bits##x##size##_generic_avg, \
		v##sign##int##bits##x##size##_avx512f_and, \
		v##sign##int##bits##x##size##_avx512f_or, \
		v##sign##int##bits##x##size##_avx512f_xor, \
		v##sign##int##bits##x##size##_generic_not, \
		v##sign##int##bits##x##size##_avx512f_lshift, \
		v##sign##int##bits##x##size##_avx512f_rshift, \
		v##sign##int##bits##x##size##_avx512f_lrshift, \
		v##sign##int##bits##x##size##_generic_cmplt, \
		v##sign##int##bits##x##size##_generic_cmple, \
		v##sign##int##bits##x##size##_generic_cmpeq, \
		v##sign##int##bits##x##size##_generic_cmpge, \
		v##sign##int##bits##x##size##_generic_cmpgt, \
	};

// Instantiates both signednesses for one lane layout: `bits`-bit lanes, `size` lanes per
// 512-bit vector. The signed variant is requested by leaving the `sign` argument empty.
#define VEC_AVX512F_DEFINE_OPERATIONS(bits, size) \
	VEC_AVX512F_DEFINE_OPERATIONS_SIGN(u, bits, size) \
	VEC_AVX512F_DEFINE_OPERATIONS_SIGN(/* signed */, bits, size)

/* every lane width; bits * size == 512 in each case */
VEC_AVX512F_DEFINE_OPERATIONS(8, 64)   /* 64 lanes of  8 bits */
VEC_AVX512F_DEFINE_OPERATIONS(16, 32)  /* 32 lanes of 16 bits */
VEC_AVX512F_DEFINE_OPERATIONS(32, 16)  /* 16 lanes of 32 bits */
VEC_AVX512F_DEFINE_OPERATIONS(64, 8)   /*  8 lanes of 64 bits */