Vector Optimized Library of Kernels 3.3.0
Architecture-tuned implementations of math kernels
Loading...
Searching...
No Matches
volk_avx512_intrinsics.h
Go to the documentation of this file.
1/* -*- c++ -*- */
2/*
3 * Copyright 2024-2026 Magnus Lundmark <magnuslundmark@gmail.com>
4 *
5 * This file is part of VOLK
6 *
7 * SPDX-License-Identifier: LGPL-3.0-or-later
8 */
9
10/*
11 * This file is intended to hold AVX512 intrinsics.
12 * They should be used in VOLK kernels to avoid copy-paste.
13 */
14
15#ifndef INCLUDE_VOLK_VOLK_AVX512_INTRINSICS_H_
16#define INCLUDE_VOLK_VOLK_AVX512_INTRINSICS_H_
17#include <immintrin.h>
18
// Newton-Raphson refined reciprocal square root: 1/sqrt(a)
// _mm512_rsqrt14_ps returns an estimate x0 with relative error below 2^-14;
// one Newton-Raphson step roughly squares that error, bringing x1 close to
// full single precision:
//   x1 = x0 * (1.5 - 0.5 * (a * x0) * x0)
// The product is formed as (a * x0) * x0 rather than (x0 * x0) * a: a * x0 is
// about sqrt(a), so no intermediate leaves the normal float range. Forming
// x0 * x0 first overflows to +Inf for subnormal a below ~2.9e-39 (giving a
// -Inf result) and goes subnormal (losing precision) for a above ~8.5e37.
// Edge cases: +0 -> +Inf, -0 -> -Inf, +Inf -> +0 (the raw estimate is kept
// for these lanes); negative inputs and NaN -> NaN.
// Requires AVX512F
static inline __m512 _mm512_rsqrt_nr_ps(const __m512 a)
{
    const __m512 HALF = _mm512_set1_ps(0.5f);
    const __m512 THREE_HALFS = _mm512_set1_ps(1.5f);

    const __m512 x0 = _mm512_rsqrt14_ps(a); // +/-Inf for +/-0, +0 for +Inf

    // Newton-Raphson: x1 = x0 * (1.5 - 0.5 * (a * x0) * x0)
    const __m512 a_x0_x0 = _mm512_mul_ps(_mm512_mul_ps(a, x0), x0);
    const __m512 x1 = _mm512_mul_ps(x0, _mm512_fnmadd_ps(HALF, a_x0_x0, THREE_HALFS));

    // For +/-0 and +Inf inputs x0 is already the exact answer, but the NR step
    // evaluates 0 * Inf = NaN. Keep x0 in those lanes. The ordered float
    // compare against zero matches both +0 and -0 (a bitwise compare with 0
    // would miss -0 and return NaN for it).
    const __mmask16 zero_mask = _mm512_cmp_ps_mask(a, _mm512_setzero_ps(), _CMP_EQ_OQ);
    const __mmask16 inf_mask =
        _mm512_cmpeq_epi32_mask(_mm512_castps_si512(a), _mm512_set1_epi32(0x7F800000));
    return _mm512_mask_blend_ps(zero_mask | inf_mask, x1, x0);
}
44
// Gather the real parts of two vectors of interleaved complex floats.
// z1 and z2 each hold 8 complex values laid out as [re0, im0, re1, im1, ...].
// Result lanes 0-7 are the real parts of z1, lanes 8-15 those of z2.
// Requires AVX512F
static inline __m512 _mm512_real(const __m512 z1, const __m512 z2)
{
    // permutex2var: indices 0..15 select from z1, 16..31 select from z2.
    // Every even index is a real component.
    const __m512i even_lanes =
        _mm512_setr_epi32(0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30);
    return _mm512_permutex2var_ps(z1, even_lanes, z2);
}
55
// Gather the imaginary parts of two vectors of interleaved complex floats.
// z1 and z2 each hold 8 complex values laid out as [re0, im0, re1, im1, ...].
// Result lanes 0-7 are the imaginary parts of z1, lanes 8-15 those of z2.
// Requires AVX512F
static inline __m512 _mm512_imag(const __m512 z1, const __m512 z2)
{
    // permutex2var: indices 0..15 select from z1, 16..31 select from z2.
    // Every odd index is an imaginary component.
    const __m512i odd_lanes =
        _mm512_setr_epi32(1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31);
    return _mm512_permutex2var_ps(z1, odd_lanes, z2);
}
66
// Approximate arctan(x) for x in [-1, 1] with an odd polynomial:
//   arctan(x) ~= x * (a1 + a3*x^2 + a5*x^4 + ... + a13*x^12)
// The bracket is evaluated in x^2 with Horner's scheme, one FMA per
// coefficient. Maximum relative error ~6.5e-7 on [-1, 1]; the polynomial is
// only fitted on that interval, so callers presumably range-reduce larger |x|.
// Requires AVX512F
static inline __m512 _mm512_arctan_poly_avx512(const __m512 x)
{
    const __m512 xx = _mm512_mul_ps(x, x);

    // Horner, from the highest-order coefficient (a13) down to a1.
    __m512 p = _mm512_set1_ps(+0x1.01a37cp-7f);                  // a13
    p = _mm512_fmadd_ps(p, xx, _mm512_set1_ps(-0x1.2f3004p-5f)); // a11
    p = _mm512_fmadd_ps(p, xx, _mm512_set1_ps(+0x1.5785aap-4f)); // a9
    p = _mm512_fmadd_ps(p, xx, _mm512_set1_ps(-0x1.1436ap-3f));  // a7
    p = _mm512_fmadd_ps(p, xx, _mm512_set1_ps(+0x1.972be6p-3f)); // a5
    p = _mm512_fmadd_ps(p, xx, _mm512_set1_ps(-0x1.55437p-2f));  // a3
    p = _mm512_fmadd_ps(p, xx, _mm512_set1_ps(+0x1.ffffeap-1f)); // a1

    return _mm512_mul_ps(p, x);
}
96
// Approximate arcsin(x) for |x| <= 0.5 as asin(x) ~= x * P(x^2), where
//   P(u) = c0 + c1*u + c2*u^2 + c3*u^3
// is evaluated with Horner's scheme (FMA per step).
// Maximum relative error ~1.5e-6 on |x| <= 0.5.
// Requires AVX512F
static inline __m512 _mm512_arcsin_poly_avx512(const __m512 x)
{
    const __m512 u = _mm512_mul_ps(x, x);

    __m512 p = _mm512_set1_ps(0x1.0a788p-4f);                  // c3
    p = _mm512_fmadd_ps(p, u, _mm512_set1_ps(0x1.24d192p-4f)); // c2
    p = _mm512_fmadd_ps(p, u, _mm512_set1_ps(0x1.55b648p-3f)); // c1
    p = _mm512_fmadd_ps(p, u, _mm512_set1_ps(0x1.ffffcep-1f)); // c0

    return _mm512_mul_ps(p, x);
}
119
// Complex multiply of 8 interleaved complex pairs: x * y
// For x = a+bi and y = c+di: (a+bi)(c+di) = (ac-bd) + i(ad+bc)
// x, y and the result are laid out as [re0, im0, re1, im1, ...].
// The two partial products are each rounded, then subtracted (real lanes)
// or added (imaginary lanes); no fused multiply-add is involved.
// Requires AVX512F
static inline __m512 _mm512_complexmul_ps(const __m512 x, const __m512 y)
{
    const __m512 y_re = _mm512_moveldup_ps(y);        // [c, c, ...]
    const __m512 y_im = _mm512_movehdup_ps(y);        // [d, d, ...]
    const __m512 x_flip = _mm512_permute_ps(x, 0xB1); // [b, a, ...]

    const __m512 p = _mm512_mul_ps(x, y_re);      // [a*c, b*c, ...]
    const __m512 q = _mm512_mul_ps(x_flip, y_im); // [b*d, a*d, ...]

    // Even lanes (real parts) take p - q, odd lanes (imaginary parts) p + q.
    const __m512 diff = _mm512_sub_ps(p, q);
    const __m512 sum = _mm512_add_ps(p, q);
    return _mm512_mask_blend_ps((__mmask16)0x5555, sum, diff);
}
141
// Complex conjugate multiply of 8 interleaved complex pairs: x * conj(y)
// For x = a+bi and y = c+di: (a+bi)(c-di) = (ac+bd) + i(bc-ad)
// x, y and the result are laid out as [re0, im0, re1, im1, ...].
// Requires AVX512F
static inline __m512 _mm512_complexconjugatemul_ps(const __m512 x, const __m512 y)
{
    const __m512 y_re = _mm512_moveldup_ps(y);        // [c, c, ...]
    const __m512 y_im = _mm512_movehdup_ps(y);        // [d, d, ...]
    const __m512 x_flip = _mm512_permute_ps(x, 0xB1); // [b, a, ...]

    // Flip the sign bit in the odd lanes only: [d, -d, ...]. An integer XOR
    // keeps this AVX512F-only (the float XOR intrinsic needs AVX512DQ).
    const __m512i y_im_bits = _mm512_castps_si512(y_im);
    const __m512i sign_bit = _mm512_castps_si512(_mm512_set1_ps(-0.0f));
    const __m512 y_im_conj = _mm512_castsi512_ps(
        _mm512_mask_xor_epi32(y_im_bits, (__mmask16)0xAAAA, y_im_bits, sign_bit));

    // even lanes: b*d + (a*c); odd lanes: a*(-d) + (b*c)
    return _mm512_fmadd_ps(x_flip, y_im_conj, _mm512_mul_ps(x, y_re));
}
176
// Normalize 8 interleaved complex values to unit magnitude: z / |z|
// val and the result are laid out as [re0, im0, re1, im1, ...].
// |z| = sqrt(re^2 + im^2) is computed without scaling, so a zero input
// yields 0/0 = NaN in both lanes of that pair.
// Requires AVX512F
static inline __m512 _mm512_normalize_ps(const __m512 val)
{
    const __m512 sq = _mm512_mul_ps(val, val); // [re0^2, im0^2, ...]

    // Swap each (re, im) pair so every lane also sees its partner's square:
    // [im0^2, re0^2, im1^2, re1^2, ...]
    const __m512 partner_sq = _mm512_shuffle_ps(sq, sq, 0xB1);

    // Both lanes of a pair now hold re^2 + im^2.
    const __m512 norm2 = _mm512_add_ps(sq, partner_sq);

    return _mm512_div_ps(val, _mm512_sqrt_ps(norm2));
}
198
// Minimax polynomial approximation of sin(x) on [-pi/4, pi/4]
// Coefficients via Remez algorithm (Sollya); max |error| < 7.3e-9.
//   sin(x) ~= x + x^3 * (s1 + x^2 * (s2 + x^2 * s3))
// Requires AVX512F
static inline __m512 _mm512_sin_poly_avx512(const __m512 x)
{
    const __m512 xx = _mm512_mul_ps(x, x);
    const __m512 xxx = _mm512_mul_ps(xx, x);

    __m512 q = _mm512_set1_ps(-0x1.9ab22ap-13f);                 // s3
    q = _mm512_fmadd_ps(q, xx, _mm512_set1_ps(+0x1.110be2p-7f)); // s2
    q = _mm512_fmadd_ps(q, xx, _mm512_set1_ps(-0x1.555552p-3f)); // s1

    // x + x^3 * q, fused into a single rounding.
    return _mm512_fmadd_ps(q, xxx, x);
}
219
// Minimax polynomial approximation of cos(x) on [-pi/4, pi/4]
// Coefficients via Remez algorithm (Sollya); max |error| < 1.1e-7.
//   cos(x) ~= 1 + x^2 * (c1 + x^2 * (c2 + x^2 * c3))
// Requires AVX512F
static inline __m512 _mm512_cos_poly_avx512(const __m512 x)
{
    const __m512 xx = _mm512_mul_ps(x, x);

    __m512 q = _mm512_set1_ps(-0x1.661be2p-10f);                 // c3
    q = _mm512_fmadd_ps(q, xx, _mm512_set1_ps(+0x1.554a46p-5f)); // c2
    q = _mm512_fmadd_ps(q, xx, _mm512_set1_ps(-0x1.fffff4p-2f)); // c1

    // 1 + x^2 * q, fused into a single rounding.
    return _mm512_fmadd_ps(q, xx, _mm512_set1_ps(1.0f));
}
240
// Degree-6 polynomial approximating log2(x) / (x - 1) for x in [1, 2]
// Generated with Sollya: remez(log2(x)/(x-1), 6, [1+1b-20, 2])
// Max error: ~1.55e-6
//
// Returns only the polynomial value; the caller recovers the logarithm as
//   log2(x) ~= poly(x) * (x - 1)    for x in [1, 2]
// (x is presumably the mantissa after exponent extraction; confirm at call sites).
// Evaluated in x directly with Horner's scheme, one FMA per coefficient.
// Requires AVX512F
static inline __m512 _mm512_log2_poly_avx512(const __m512 x)
{
    __m512 p = _mm512_set1_ps(+0x1.57aa42p-6f);                 // c6
    p = _mm512_fmadd_ps(x, p, _mm512_set1_ps(-0x1.c97982p-3f)); // c5
    p = _mm512_fmadd_ps(x, p, _mm512_set1_ps(+0x1.04fc3ap+0f)); // c4
    p = _mm512_fmadd_ps(x, p, _mm512_set1_ps(-0x1.4d476cp+1f)); // c3
    p = _mm512_fmadd_ps(x, p, _mm512_set1_ps(+0x1.05d9ccp+2f)); // c2
    p = _mm512_fmadd_ps(x, p, _mm512_set1_ps(-0x1.0b7f7ep+2f)); // c1
    p = _mm512_fmadd_ps(x, p, _mm512_set1_ps(+0x1.a8a726p+1f)); // c0
    return p;
}
270
271#endif /* INCLUDE_VOLK_VOLK_AVX512_INTRINSICS_H_ */