Vector Optimized Library of Kernels 3.3.0
Architecture-tuned implementations of math kernels
Loading...
Searching...
No Matches
volk_avx_intrinsics.h
Go to the documentation of this file.
1/* -*- c++ -*- */
2/*
3 * Copyright 2015 Free Software Foundation, Inc.
4 * Copyright 2023-2026 Magnus Lundmark <magnuslundmark@gmail.com>
5 *
6 * This file is part of VOLK
7 *
8 * SPDX-License-Identifier: LGPL-3.0-or-later
9 */
10
11/*
12 * This file is intended to hold AVX intrinsics.
13 * They should be used in VOLK kernels to avoid copy-pasta.
14 */
15
16#ifndef INCLUDE_VOLK_VOLK_AVX_INTRINSICS_H_
17#define INCLUDE_VOLK_VOLK_AVX_INTRINSICS_H_
18#include <immintrin.h>
19
20/*
21 * Newton-Raphson refined reciprocal square root: 1/sqrt(a)
22 * One iteration doubles precision from ~12-bit to ~24-bit
23 * x1 = x0 * (1.5 - 0.5 * a * x0^2)
24 * Handles edge cases: +0 → +Inf, +Inf → 0
25 */
26static inline __m256 _mm256_rsqrt_nr_ps(const __m256 a)
27{
28 const __m256 HALF = _mm256_set1_ps(0.5f);
29 const __m256 THREE_HALFS = _mm256_set1_ps(1.5f);
30
31 const __m256 x0 = _mm256_rsqrt_ps(a); // +Inf for +0, 0 for +Inf
32
33 // Newton-Raphson: x1 = x0 * (1.5 - 0.5 * a * x0^2)
34 __m256 x1 = _mm256_mul_ps(
35 x0,
36 _mm256_sub_ps(THREE_HALFS,
37 _mm256_mul_ps(HALF, _mm256_mul_ps(_mm256_mul_ps(x0, x0), a))));
38
39 // For +0 and +Inf inputs, x0 is correct but NR produces NaN due to Inf*0
40 // Blend: use x0 where a == +0 or a == +Inf, else use x1
41 // AVX-only: use SSE2 integer compare, then reconstruct AVX mask
42 __m128i a_lo = _mm256_castsi256_si128(_mm256_castps_si256(a));
43 __m128i a_hi = _mm_castps_si128(_mm256_extractf128_ps(a, 1));
44 __m128i zero_si = _mm_setzero_si128();
45 __m128i inf_si = _mm_set1_epi32(0x7F800000);
46 __m128i zero_mask_lo = _mm_cmpeq_epi32(a_lo, zero_si);
47 __m128i zero_mask_hi = _mm_cmpeq_epi32(a_hi, zero_si);
48 __m128i inf_mask_lo = _mm_cmpeq_epi32(a_lo, inf_si);
49 __m128i inf_mask_hi = _mm_cmpeq_epi32(a_hi, inf_si);
50 __m128 mask_lo = _mm_castsi128_ps(_mm_or_si128(zero_mask_lo, inf_mask_lo));
51 __m128 mask_hi = _mm_castsi128_ps(_mm_or_si128(zero_mask_hi, inf_mask_hi));
52 __m256 special_mask =
53 _mm256_insertf128_ps(_mm256_castps128_ps256(mask_lo), mask_hi, 1);
54 return _mm256_blendv_ps(x1, x0, special_mask);
55}
56
57/*
58 * Approximate arctan(x) via polynomial expansion
59 * on the interval [-1, 1]
60 *
61 * Maximum relative error ~6.5e-7
62 * Polynomial evaluated via Horner's method
63 */
64static inline __m256 _mm256_arctan_poly_avx(const __m256 x)
65{
66 const __m256 a1 = _mm256_set1_ps(+0x1.ffffeap-1f);
67 const __m256 a3 = _mm256_set1_ps(-0x1.55437p-2f);
68 const __m256 a5 = _mm256_set1_ps(+0x1.972be6p-3f);
69 const __m256 a7 = _mm256_set1_ps(-0x1.1436ap-3f);
70 const __m256 a9 = _mm256_set1_ps(+0x1.5785aap-4f);
71 const __m256 a11 = _mm256_set1_ps(-0x1.2f3004p-5f);
72 const __m256 a13 = _mm256_set1_ps(+0x1.01a37cp-7f);
73
74 const __m256 x_times_x = _mm256_mul_ps(x, x);
75 __m256 arctan;
76 arctan = a13;
77 arctan = _mm256_mul_ps(x_times_x, arctan);
78 arctan = _mm256_add_ps(arctan, a11);
79 arctan = _mm256_mul_ps(x_times_x, arctan);
80 arctan = _mm256_add_ps(arctan, a9);
81 arctan = _mm256_mul_ps(x_times_x, arctan);
82 arctan = _mm256_add_ps(arctan, a7);
83 arctan = _mm256_mul_ps(x_times_x, arctan);
84 arctan = _mm256_add_ps(arctan, a5);
85 arctan = _mm256_mul_ps(x_times_x, arctan);
86 arctan = _mm256_add_ps(arctan, a3);
87 arctan = _mm256_mul_ps(x_times_x, arctan);
88 arctan = _mm256_add_ps(arctan, a1);
89 arctan = _mm256_mul_ps(x, arctan);
90
91 return arctan;
92}
93
94/*
95 * Approximate arcsin(x) via polynomial expansion
96 * P(u) such that asin(x) = x * P(x^2) on |x| <= 0.5
97 *
98 * Maximum relative error ~1.5e-6
99 * Polynomial evaluated via Horner's method
100 */
101static inline __m256 _mm256_arcsin_poly_avx(const __m256 x)
102{
103 const __m256 c0 = _mm256_set1_ps(0x1.ffffcep-1f);
104 const __m256 c1 = _mm256_set1_ps(0x1.55b648p-3f);
105 const __m256 c2 = _mm256_set1_ps(0x1.24d192p-4f);
106 const __m256 c3 = _mm256_set1_ps(0x1.0a788p-4f);
107
108 const __m256 u = _mm256_mul_ps(x, x);
109 __m256 p = c3;
110 p = _mm256_mul_ps(u, p);
111 p = _mm256_add_ps(p, c2);
112 p = _mm256_mul_ps(u, p);
113 p = _mm256_add_ps(p, c1);
114 p = _mm256_mul_ps(u, p);
115 p = _mm256_add_ps(p, c0);
116
117 return _mm256_mul_ps(x, p);
118}
119
120static inline __m256 _mm256_complexmul_ps(__m256 x, __m256 y)
121{
122 __m256 yl, yh, tmp1, tmp2;
123 yl = _mm256_moveldup_ps(y); // Load yl with cr,cr,dr,dr ...
124 yh = _mm256_movehdup_ps(y); // Load yh with ci,ci,di,di ...
125 tmp1 = _mm256_mul_ps(x, yl); // tmp1 = ar*cr,ai*cr,br*dr,bi*dr ...
126 x = _mm256_shuffle_ps(x, x, 0xB1); // Re-arrange x to be ai,ar,bi,br ...
127 tmp2 = _mm256_mul_ps(x, yh); // tmp2 = ai*ci,ar*ci,bi*di,br*di
128
129 // ar*cr-ai*ci, ai*cr+ar*ci, br*dr-bi*di, bi*dr+br*di
130 return _mm256_addsub_ps(tmp1, tmp2);
131}
132
133static inline __m256 _mm256_conjugate_ps(__m256 x)
134{
135 const __m256 conjugator = _mm256_setr_ps(0, -0.f, 0, -0.f, 0, -0.f, 0, -0.f);
136 return _mm256_xor_ps(x, conjugator); // conjugate y
137}
138
139static inline __m256 _mm256_complexconjugatemul_ps(const __m256 x, const __m256 y)
140{
141 const __m256 nswap = _mm256_permute_ps(x, 0xb1);
142 const __m256 dreal = _mm256_moveldup_ps(y);
143 const __m256 dimag = _mm256_movehdup_ps(y);
144
145 const __m256 conjugator = _mm256_setr_ps(0, -0.f, 0, -0.f, 0, -0.f, 0, -0.f);
146 const __m256 dimagconj = _mm256_xor_ps(dimag, conjugator);
147 const __m256 multreal = _mm256_mul_ps(x, dreal);
148 const __m256 multimag = _mm256_mul_ps(nswap, dimagconj);
149 return _mm256_add_ps(multreal, multimag);
150}
151
152static inline __m256 _mm256_normalize_ps(__m256 val)
153{
154 __m256 tmp1 = _mm256_mul_ps(val, val);
155 tmp1 = _mm256_hadd_ps(tmp1, tmp1);
156 tmp1 = _mm256_shuffle_ps(tmp1, tmp1, _MM_SHUFFLE(3, 1, 2, 0)); // equals 0xD8
157 tmp1 = _mm256_sqrt_ps(tmp1);
158 return _mm256_div_ps(val, tmp1);
159}
160
161static inline __m256 _mm256_magnitudesquared_ps(__m256 cplxValue1, __m256 cplxValue2)
162{
163 __m256 complex1, complex2;
164 cplxValue1 = _mm256_mul_ps(cplxValue1, cplxValue1); // Square the values
165 cplxValue2 = _mm256_mul_ps(cplxValue2, cplxValue2); // Square the Values
166 complex1 = _mm256_permute2f128_ps(cplxValue1, cplxValue2, 0x20);
167 complex2 = _mm256_permute2f128_ps(cplxValue1, cplxValue2, 0x31);
168 return _mm256_hadd_ps(complex1, complex2); // Add the I2 and Q2 values
169}
170
171static inline __m256 _mm256_magnitude_ps(__m256 cplxValue1, __m256 cplxValue2)
172{
173 return _mm256_sqrt_ps(_mm256_magnitudesquared_ps(cplxValue1, cplxValue2));
174}
175
176static inline __m256 _mm256_scaled_norm_dist_ps(const __m256 symbols0,
177 const __m256 symbols1,
178 const __m256 points0,
179 const __m256 points1,
180 const __m256 scalar)
181{
182 /*
183 * Calculate: |y - x|^2 * SNR_lin
184 * Consider 'symbolsX' and 'pointsX' to be complex float
185 * 'symbolsX' are 'y' and 'pointsX' are 'x'
186 */
187 const __m256 diff0 = _mm256_sub_ps(symbols0, points0);
188 const __m256 diff1 = _mm256_sub_ps(symbols1, points1);
189 const __m256 norms = _mm256_magnitudesquared_ps(diff0, diff1);
190 return _mm256_mul_ps(norms, scalar);
191}
192
193static inline __m256 _mm256_polar_sign_mask(__m128i fbits)
194{
195 __m256 sign_mask_dummy = _mm256_setzero_ps();
196 const __m128i zeros = _mm_set1_epi8(0x00);
197 const __m128i sign_extract = _mm_set1_epi8(0x80);
198 const __m128i shuffle_mask0 = _mm_setr_epi8(0xff,
199 0xff,
200 0xff,
201 0x00,
202 0xff,
203 0xff,
204 0xff,
205 0x01,
206 0xff,
207 0xff,
208 0xff,
209 0x02,
210 0xff,
211 0xff,
212 0xff,
213 0x03);
214 const __m128i shuffle_mask1 = _mm_setr_epi8(0xff,
215 0xff,
216 0xff,
217 0x04,
218 0xff,
219 0xff,
220 0xff,
221 0x05,
222 0xff,
223 0xff,
224 0xff,
225 0x06,
226 0xff,
227 0xff,
228 0xff,
229 0x07);
230
231 fbits = _mm_cmpgt_epi8(fbits, zeros);
232 fbits = _mm_and_si128(fbits, sign_extract);
233 __m128i sign_bits0 = _mm_shuffle_epi8(fbits, shuffle_mask0);
234 __m128i sign_bits1 = _mm_shuffle_epi8(fbits, shuffle_mask1);
235
236 __m256 sign_mask =
237 _mm256_insertf128_ps(sign_mask_dummy, _mm_castsi128_ps(sign_bits0), 0x0);
238 return _mm256_insertf128_ps(sign_mask, _mm_castsi128_ps(sign_bits1), 0x1);
239 // This is the desired function call. Though it seems to be missing in GCC.
240 // Compare: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html
241 // return _mm256_set_m128(_mm_castsi128_ps(sign_bits1),
242 // _mm_castsi128_ps(sign_bits0));
243}
244
245static inline void
246_mm256_polar_deinterleave(__m256* llr0, __m256* llr1, __m256 src0, __m256 src1)
247{
248 // deinterleave values
249 __m256 part0 = _mm256_permute2f128_ps(src0, src1, 0x20);
250 __m256 part1 = _mm256_permute2f128_ps(src0, src1, 0x31);
251 *llr0 = _mm256_shuffle_ps(part0, part1, 0x88);
252 *llr1 = _mm256_shuffle_ps(part0, part1, 0xdd);
253}
254
/*
 * Min-sum LLR update for polar decoding: after de-interleaving src0/src1
 * into llr0/llr1, each output is sign(llr0)*sign(llr1)*min(|llr0|,|llr1|).
 */
static inline __m256 _mm256_polar_minsum_llrs(__m256 src0, __m256 src1)
{
    const __m256 sign_mask = _mm256_set1_ps(-0.0f);
    // All-ones with the sign bits cleared: 0x7FFFFFFF per lane.
    const __m256 abs_mask =
        _mm256_andnot_ps(sign_mask, _mm256_castsi256_ps(_mm256_set1_epi8(0xff)));

    __m256 llr0, llr1;
    _mm256_polar_deinterleave(&llr0, &llr1, src0, src1);

    // XOR of the two sign bits is the sign of the product; OR it back
    // onto the smaller magnitude.
    const __m256 sign =
        _mm256_xor_ps(_mm256_and_ps(llr0, sign_mask), _mm256_and_ps(llr1, sign_mask));
    const __m256 magnitude =
        _mm256_min_ps(_mm256_and_ps(llr0, abs_mask), _mm256_and_ps(llr1, abs_mask));
    return _mm256_or_ps(magnitude, sign);
}
271
272static inline __m256 _mm256_polar_fsign_add_llrs(__m256 src0, __m256 src1, __m128i fbits)
273{
274 // prepare sign mask for correct +-
275 __m256 sign_mask = _mm256_polar_sign_mask(fbits);
276
277 __m256 llr0, llr1;
278 _mm256_polar_deinterleave(&llr0, &llr1, src0, src1);
279
280 // calculate result
281 llr0 = _mm256_xor_ps(llr0, sign_mask);
282 __m256 dst = _mm256_add_ps(llr0, llr1);
283 return dst;
284}
285
287 __m256 sq_acc, __m256 acc, __m256 val, __m256 rec, __m256 aux)
288{
289 aux = _mm256_mul_ps(aux, val);
290 aux = _mm256_sub_ps(aux, acc);
291 aux = _mm256_mul_ps(aux, aux);
292 aux = _mm256_mul_ps(aux, rec);
293 return _mm256_add_ps(sq_acc, aux);
294}
295
296/*
297 * Polynomial coefficients for log2(x)/(x-1) on [1, 2]
298 * Generated with Sollya: remez(log2(x)/(x-1), 6, [1+1b-20, 2])
299 * Max error: ~1.55e-6
300 *
301 * Usage: log2(x) ≈ poly(x) * (x - 1) for x ∈ [1, 2]
302 * Polynomial evaluated via Horner's method
303 */
304static inline __m256 _mm256_log2_poly_avx(const __m256 x)
305{
306 const __m256 c0 = _mm256_set1_ps(+0x1.a8a726p+1f);
307 const __m256 c1 = _mm256_set1_ps(-0x1.0b7f7ep+2f);
308 const __m256 c2 = _mm256_set1_ps(+0x1.05d9ccp+2f);
309 const __m256 c3 = _mm256_set1_ps(-0x1.4d476cp+1f);
310 const __m256 c4 = _mm256_set1_ps(+0x1.04fc3ap+0f);
311 const __m256 c5 = _mm256_set1_ps(-0x1.c97982p-3f);
312 const __m256 c6 = _mm256_set1_ps(+0x1.57aa42p-6f);
313
314 // Horner's method: c0 + x*(c1 + x*(c2 + ...))
315 __m256 poly = c6;
316 poly = _mm256_add_ps(_mm256_mul_ps(poly, x), c5);
317 poly = _mm256_add_ps(_mm256_mul_ps(poly, x), c4);
318 poly = _mm256_add_ps(_mm256_mul_ps(poly, x), c3);
319 poly = _mm256_add_ps(_mm256_mul_ps(poly, x), c2);
320 poly = _mm256_add_ps(_mm256_mul_ps(poly, x), c1);
321 poly = _mm256_add_ps(_mm256_mul_ps(poly, x), c0);
322 return poly;
323}
324
325#endif /* INCLUDE_VOLK_VOLK_AVX_INTRINSICS_H_ */