blake3_avx512.c
1#include "blake3_impl.h"
2
3#include <immintrin.h>
4
5#define _mm_shuffle_ps2(a, b, c) \
6 (_mm_castps_si128( \
7 _mm_shuffle_ps(_mm_castsi128_ps(a), _mm_castsi128_ps(b), (c))))
8
9INLINE __m128i loadu_128(const uint8_t src[16]) {
10 return _mm_loadu_si128((const __m128i *)src);
11}
12
13INLINE __m256i loadu_256(const uint8_t src[32]) {
14 return _mm256_loadu_si256((const __m256i *)src);
15}
16
17INLINE __m512i loadu_512(const uint8_t src[64]) {
18 return _mm512_loadu_si512((const __m512i *)src);
19}
20
21INLINE void storeu_128(__m128i src, uint8_t dest[16]) {
22 _mm_storeu_si128((__m128i *)dest, src);
23}
24
25INLINE void storeu_256(__m256i src, uint8_t dest[32]) {
26 _mm256_storeu_si256((__m256i *)dest, src);
27}
28
29INLINE void storeu_512(__m512i src, uint8_t dest[64]) {
30 _mm512_storeu_si512((__m512i *)dest, src);
31}
32
33INLINE __m128i add_128(__m128i a, __m128i b) { return _mm_add_epi32(a, b); }
34
35INLINE __m256i add_256(__m256i a, __m256i b) { return _mm256_add_epi32(a, b); }
36
37INLINE __m512i add_512(__m512i a, __m512i b) { return _mm512_add_epi32(a, b); }
38
39INLINE __m128i xor_128(__m128i a, __m128i b) { return _mm_xor_si128(a, b); }
40
41INLINE __m256i xor_256(__m256i a, __m256i b) { return _mm256_xor_si256(a, b); }
42
43INLINE __m512i xor_512(__m512i a, __m512i b) { return _mm512_xor_si512(a, b); }
44
45INLINE __m128i set1_128(uint32_t x) { return _mm_set1_epi32((int32_t)x); }
46
47INLINE __m256i set1_256(uint32_t x) { return _mm256_set1_epi32((int32_t)x); }
48
49INLINE __m512i set1_512(uint32_t x) { return _mm512_set1_epi32((int32_t)x); }
50
51INLINE __m128i set4(uint32_t a, uint32_t b, uint32_t c, uint32_t d) {
52 return _mm_setr_epi32((int32_t)a, (int32_t)b, (int32_t)c, (int32_t)d);
53}
54
55INLINE __m128i rot16_128(__m128i x) { return _mm_ror_epi32(x, 16); }
56
57INLINE __m256i rot16_256(__m256i x) { return _mm256_ror_epi32(x, 16); }
58
59INLINE __m512i rot16_512(__m512i x) { return _mm512_ror_epi32(x, 16); }
60
61INLINE __m128i rot12_128(__m128i x) { return _mm_ror_epi32(x, 12); }
62
63INLINE __m256i rot12_256(__m256i x) { return _mm256_ror_epi32(x, 12); }
64
65INLINE __m512i rot12_512(__m512i x) { return _mm512_ror_epi32(x, 12); }
66
67INLINE __m128i rot8_128(__m128i x) { return _mm_ror_epi32(x, 8); }
68
69INLINE __m256i rot8_256(__m256i x) { return _mm256_ror_epi32(x, 8); }
70
71INLINE __m512i rot8_512(__m512i x) { return _mm512_ror_epi32(x, 8); }
72
73INLINE __m128i rot7_128(__m128i x) { return _mm_ror_epi32(x, 7); }
74
75INLINE __m256i rot7_256(__m256i x) { return _mm256_ror_epi32(x, 7); }
76
77INLINE __m512i rot7_512(__m512i x) { return _mm512_ror_epi32(x, 7); }
78
79/*
80 * ----------------------------------------------------------------------------
81 * compress_avx512
82 * ----------------------------------------------------------------------------
83 */
84
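// Added annotation: g1 and g2 are the two halves of the BLAKE3 G function,
// applied to all four columns (or, after diagonalize, all four diagonals) of
// the state at once. g1 performs the 16- and 12-bit rotations and g2 the 8-
// and 7-bit rotations, each mixing in one vector of four message words.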
85INLINE void g1(__m128i *row0, __m128i *row1, __m128i *row2, __m128i *row3,
86 __m128i m) {
87 *row0 = add_128(add_128(*row0, m), *row1);
88 *row3 = xor_128(*row3, *row0);
89 *row3 = rot16_128(*row3);
90 *row2 = add_128(*row2, *row3);
91 *row1 = xor_128(*row1, *row2);
92 *row1 = rot12_128(*row1);
93}
94
95INLINE void g2(__m128i *row0, __m128i *row1, __m128i *row2, __m128i *row3,
96 __m128i m) {
97 *row0 = add_128(add_128(*row0, m), *row1);
98 *row3 = xor_128(*row3, *row0);
99 *row3 = rot8_128(*row3);
100 *row2 = add_128(*row2, *row3);
101 *row1 = xor_128(*row1, *row2);
102 *row1 = rot7_128(*row1);
103}
104
105// Note the optimization here of leaving row1 as the unrotated row, rather than
106// row0. All the message loads below are adjusted to compensate for this. See
107// discussion at https://github.com/sneves/blake2-avx2/pull/4
108INLINE void diagonalize(__m128i *row0, __m128i *row2, __m128i *row3) {
109 *row0 = _mm_shuffle_epi32(*row0, _MM_SHUFFLE(2, 1, 0, 3));
110 *row3 = _mm_shuffle_epi32(*row3, _MM_SHUFFLE(1, 0, 3, 2));
111 *row2 = _mm_shuffle_epi32(*row2, _MM_SHUFFLE(0, 3, 2, 1));
112}
113
114INLINE void undiagonalize(__m128i *row0, __m128i *row2, __m128i *row3) {
115 *row0 = _mm_shuffle_epi32(*row0, _MM_SHUFFLE(0, 3, 2, 1));
116 *row3 = _mm_shuffle_epi32(*row3, _MM_SHUFFLE(1, 0, 3, 2));
117 *row2 = _mm_shuffle_epi32(*row2, _MM_SHUFFLE(2, 1, 0, 3));
118}
119
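// Added annotation: compress_pre sets up the 4x4 state (CV in rows 0-1, IV
// constants in row 2, counter/block length/flags in row 3), loads the 64-byte
// block as four vectors, and runs all seven rounds. Callers finalize the
// result differently for in-place compression vs. extended output.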
120INLINE void compress_pre(__m128i rows[4], const uint32_t cv[8],
121 const uint8_t block[BLAKE3_BLOCK_LEN],
122 uint8_t block_len, uint64_t counter, uint8_t flags) {
123 rows[0] = loadu_128((uint8_t *)&cv[0]);
124 rows[1] = loadu_128((uint8_t *)&cv[4]);
125 rows[2] = set4(IV[0], IV[1], IV[2], IV[3]);
126 rows[3] = set4(counter_low(counter), counter_high(counter),
127 (uint32_t)block_len, (uint32_t)flags);
128
129 __m128i m0 = loadu_128(&block[sizeof(__m128i) * 0]);
130 __m128i m1 = loadu_128(&block[sizeof(__m128i) * 1]);
131 __m128i m2 = loadu_128(&block[sizeof(__m128i) * 2]);
132 __m128i m3 = loadu_128(&block[sizeof(__m128i) * 3]);
133
134 __m128i t0, t1, t2, t3, tt;
135
136 // Round 1. The first round permutes the message words from the original
137 // input order, into the groups that get mixed in parallel.
138 t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(2, 0, 2, 0)); // 6 4 2 0
139 g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
140 t1 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 3, 1)); // 7 5 3 1
141 g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
142 diagonalize(&rows[0], &rows[2], &rows[3]);
143 t2 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(2, 0, 2, 0)); // 14 12 10 8
144 t2 = _mm_shuffle_epi32(t2, _MM_SHUFFLE(2, 1, 0, 3)); // 12 10 8 14
145 g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
146 t3 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 1, 3, 1)); // 15 13 11 9
147 t3 = _mm_shuffle_epi32(t3, _MM_SHUFFLE(2, 1, 0, 3)); // 13 11 9 15
148 g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
149 undiagonalize(&rows[0], &rows[2], &rows[3]);
150 m0 = t0;
151 m1 = t1;
152 m2 = t2;
153 m3 = t3;
154
155 // Round 2. This round and all following rounds apply a fixed permutation
156 // to the message words from the round before.
157 t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
158 t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
159 g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
160 t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
161 tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
162 t1 = _mm_blend_epi16(tt, t1, 0xCC);
163 g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
164 diagonalize(&rows[0], &rows[2], &rows[3]);
165 t2 = _mm_unpacklo_epi64(m3, m1);
166 tt = _mm_blend_epi16(t2, m2, 0xC0);
167 t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
168 g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
169 t3 = _mm_unpackhi_epi32(m1, m3);
170 tt = _mm_unpacklo_epi32(m2, t3);
171 t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
172 g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
173 undiagonalize(&rows[0], &rows[2], &rows[3]);
174 m0 = t0;
175 m1 = t1;
176 m2 = t2;
177 m3 = t3;
178
179 // Round 3
180 t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
181 t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
182 g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
183 t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
184 tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
185 t1 = _mm_blend_epi16(tt, t1, 0xCC);
186 g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
187 diagonalize(&rows[0], &rows[2], &rows[3]);
188 t2 = _mm_unpacklo_epi64(m3, m1);
189 tt = _mm_blend_epi16(t2, m2, 0xC0);
190 t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
191 g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
192 t3 = _mm_unpackhi_epi32(m1, m3);
193 tt = _mm_unpacklo_epi32(m2, t3);
194 t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
195 g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
196 undiagonalize(&rows[0], &rows[2], &rows[3]);
197 m0 = t0;
198 m1 = t1;
199 m2 = t2;
200 m3 = t3;
201
202 // Round 4
203 t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
204 t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
205 g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
206 t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
207 tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
208 t1 = _mm_blend_epi16(tt, t1, 0xCC);
209 g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
210 diagonalize(&rows[0], &rows[2], &rows[3]);
211 t2 = _mm_unpacklo_epi64(m3, m1);
212 tt = _mm_blend_epi16(t2, m2, 0xC0);
213 t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
214 g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
215 t3 = _mm_unpackhi_epi32(m1, m3);
216 tt = _mm_unpacklo_epi32(m2, t3);
217 t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
218 g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
219 undiagonalize(&rows[0], &rows[2], &rows[3]);
220 m0 = t0;
221 m1 = t1;
222 m2 = t2;
223 m3 = t3;
224
225 // Round 5
226 t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
227 t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
228 g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
229 t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
230 tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
231 t1 = _mm_blend_epi16(tt, t1, 0xCC);
232 g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
233 diagonalize(&rows[0], &rows[2], &rows[3]);
234 t2 = _mm_unpacklo_epi64(m3, m1);
235 tt = _mm_blend_epi16(t2, m2, 0xC0);
236 t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
237 g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
238 t3 = _mm_unpackhi_epi32(m1, m3);
239 tt = _mm_unpacklo_epi32(m2, t3);
240 t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
241 g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
242 undiagonalize(&rows[0], &rows[2], &rows[3]);
243 m0 = t0;
244 m1 = t1;
245 m2 = t2;
246 m3 = t3;
247
248 // Round 6
249 t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
250 t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
251 g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
252 t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
253 tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
254 t1 = _mm_blend_epi16(tt, t1, 0xCC);
255 g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
256 diagonalize(&rows[0], &rows[2], &rows[3]);
257 t2 = _mm_unpacklo_epi64(m3, m1);
258 tt = _mm_blend_epi16(t2, m2, 0xC0);
259 t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
260 g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
261 t3 = _mm_unpackhi_epi32(m1, m3);
262 tt = _mm_unpacklo_epi32(m2, t3);
263 t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
264 g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
265 undiagonalize(&rows[0], &rows[2], &rows[3]);
266 m0 = t0;
267 m1 = t1;
268 m2 = t2;
269 m3 = t3;
270
271 // Round 7
272 t0 = _mm_shuffle_ps2(m0, m1, _MM_SHUFFLE(3, 1, 1, 2));
273 t0 = _mm_shuffle_epi32(t0, _MM_SHUFFLE(0, 3, 2, 1));
274 g1(&rows[0], &rows[1], &rows[2], &rows[3], t0);
275 t1 = _mm_shuffle_ps2(m2, m3, _MM_SHUFFLE(3, 3, 2, 2));
276 tt = _mm_shuffle_epi32(m0, _MM_SHUFFLE(0, 0, 3, 3));
277 t1 = _mm_blend_epi16(tt, t1, 0xCC);
278 g2(&rows[0], &rows[1], &rows[2], &rows[3], t1);
279 diagonalize(&rows[0], &rows[2], &rows[3]);
280 t2 = _mm_unpacklo_epi64(m3, m1);
281 tt = _mm_blend_epi16(t2, m2, 0xC0);
282 t2 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(1, 3, 2, 0));
283 g1(&rows[0], &rows[1], &rows[2], &rows[3], t2);
284 t3 = _mm_unpackhi_epi32(m1, m3);
285 tt = _mm_unpacklo_epi32(m2, t3);
286 t3 = _mm_shuffle_epi32(tt, _MM_SHUFFLE(0, 1, 3, 2));
287 g2(&rows[0], &rows[1], &rows[2], &rows[3], t3);
288 undiagonalize(&rows[0], &rows[2], &rows[3]);
289}
290
291void blake3_compress_xof_avx512(const uint32_t cv[8],
292 const uint8_t block[BLAKE3_BLOCK_LEN],
293 uint8_t block_len, uint64_t counter,
294 uint8_t flags, uint8_t out[64]) {
295 __m128i rows[4];
296 compress_pre(rows, cv, block, block_len, counter, flags);
297 storeu_128(xor_128(rows[0], rows[2]), &out[0]);
298 storeu_128(xor_128(rows[1], rows[3]), &out[16]);
299 storeu_128(xor_128(rows[2], loadu_128((uint8_t *)&cv[0])), &out[32]);
300 storeu_128(xor_128(rows[3], loadu_128((uint8_t *)&cv[4])), &out[48]);
301}
302
303void blake3_compress_in_place_avx512(uint32_t cv[8],
304 const uint8_t block[BLAKE3_BLOCK_LEN],
305 uint8_t block_len, uint64_t counter,
306 uint8_t flags) {
307 __m128i rows[4];
308 compress_pre(rows, cv, block, block_len, counter, flags);
309 storeu_128(xor_128(rows[0], rows[2]), (uint8_t *)&cv[0]);
310 storeu_128(xor_128(rows[1], rows[3]), (uint8_t *)&cv[4]);
311}
312
313/*
314 * ----------------------------------------------------------------------------
315 * hash4_avx512
316 * ----------------------------------------------------------------------------
317 */
318
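// Added annotation: round_fn4 runs one full BLAKE3 round over four hash
// states held in transposed (structure-of-arrays) form: the column step mixes
// v[0..3] with v[4..7], v[8..11], and v[12..15], then the diagonal step mixes
// the rotated pairings. Message words are selected through MSG_SCHEDULE[r].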
319INLINE void round_fn4(__m128i v[16], __m128i m[16], size_t r) {
320 v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][0]]);
321 v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][2]]);
322 v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][4]]);
323 v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][6]]);
324 v[0] = add_128(v[0], v[4]);
325 v[1] = add_128(v[1], v[5]);
326 v[2] = add_128(v[2], v[6]);
327 v[3] = add_128(v[3], v[7]);
328 v[12] = xor_128(v[12], v[0]);
329 v[13] = xor_128(v[13], v[1]);
330 v[14] = xor_128(v[14], v[2]);
331 v[15] = xor_128(v[15], v[3]);
332 v[12] = rot16_128(v[12]);
333 v[13] = rot16_128(v[13]);
334 v[14] = rot16_128(v[14]);
335 v[15] = rot16_128(v[15]);
336 v[8] = add_128(v[8], v[12]);
337 v[9] = add_128(v[9], v[13]);
338 v[10] = add_128(v[10], v[14]);
339 v[11] = add_128(v[11], v[15]);
340 v[4] = xor_128(v[4], v[8]);
341 v[5] = xor_128(v[5], v[9]);
342 v[6] = xor_128(v[6], v[10]);
343 v[7] = xor_128(v[7], v[11]);
344 v[4] = rot12_128(v[4]);
345 v[5] = rot12_128(v[5]);
346 v[6] = rot12_128(v[6]);
347 v[7] = rot12_128(v[7]);
348 v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][1]]);
349 v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][3]]);
350 v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][5]]);
351 v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][7]]);
352 v[0] = add_128(v[0], v[4]);
353 v[1] = add_128(v[1], v[5]);
354 v[2] = add_128(v[2], v[6]);
355 v[3] = add_128(v[3], v[7]);
356 v[12] = xor_128(v[12], v[0]);
357 v[13] = xor_128(v[13], v[1]);
358 v[14] = xor_128(v[14], v[2]);
359 v[15] = xor_128(v[15], v[3]);
360 v[12] = rot8_128(v[12]);
361 v[13] = rot8_128(v[13]);
362 v[14] = rot8_128(v[14]);
363 v[15] = rot8_128(v[15]);
364 v[8] = add_128(v[8], v[12]);
365 v[9] = add_128(v[9], v[13]);
366 v[10] = add_128(v[10], v[14]);
367 v[11] = add_128(v[11], v[15]);
368 v[4] = xor_128(v[4], v[8]);
369 v[5] = xor_128(v[5], v[9]);
370 v[6] = xor_128(v[6], v[10]);
371 v[7] = xor_128(v[7], v[11]);
372 v[4] = rot7_128(v[4]);
373 v[5] = rot7_128(v[5]);
374 v[6] = rot7_128(v[6]);
375 v[7] = rot7_128(v[7]);
376
377 v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][8]]);
378 v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][10]]);
379 v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][12]]);
380 v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][14]]);
381 v[0] = add_128(v[0], v[5]);
382 v[1] = add_128(v[1], v[6]);
383 v[2] = add_128(v[2], v[7]);
384 v[3] = add_128(v[3], v[4]);
385 v[15] = xor_128(v[15], v[0]);
386 v[12] = xor_128(v[12], v[1]);
387 v[13] = xor_128(v[13], v[2]);
388 v[14] = xor_128(v[14], v[3]);
389 v[15] = rot16_128(v[15]);
390 v[12] = rot16_128(v[12]);
391 v[13] = rot16_128(v[13]);
392 v[14] = rot16_128(v[14]);
393 v[10] = add_128(v[10], v[15]);
394 v[11] = add_128(v[11], v[12]);
395 v[8] = add_128(v[8], v[13]);
396 v[9] = add_128(v[9], v[14]);
397 v[5] = xor_128(v[5], v[10]);
398 v[6] = xor_128(v[6], v[11]);
399 v[7] = xor_128(v[7], v[8]);
400 v[4] = xor_128(v[4], v[9]);
401 v[5] = rot12_128(v[5]);
402 v[6] = rot12_128(v[6]);
403 v[7] = rot12_128(v[7]);
404 v[4] = rot12_128(v[4]);
405 v[0] = add_128(v[0], m[(size_t)MSG_SCHEDULE[r][9]]);
406 v[1] = add_128(v[1], m[(size_t)MSG_SCHEDULE[r][11]]);
407 v[2] = add_128(v[2], m[(size_t)MSG_SCHEDULE[r][13]]);
408 v[3] = add_128(v[3], m[(size_t)MSG_SCHEDULE[r][15]]);
409 v[0] = add_128(v[0], v[5]);
410 v[1] = add_128(v[1], v[6]);
411 v[2] = add_128(v[2], v[7]);
412 v[3] = add_128(v[3], v[4]);
413 v[15] = xor_128(v[15], v[0]);
414 v[12] = xor_128(v[12], v[1]);
415 v[13] = xor_128(v[13], v[2]);
416 v[14] = xor_128(v[14], v[3]);
417 v[15] = rot8_128(v[15]);
418 v[12] = rot8_128(v[12]);
419 v[13] = rot8_128(v[13]);
420 v[14] = rot8_128(v[14]);
421 v[10] = add_128(v[10], v[15]);
422 v[11] = add_128(v[11], v[12]);
423 v[8] = add_128(v[8], v[13]);
424 v[9] = add_128(v[9], v[14]);
425 v[5] = xor_128(v[5], v[10]);
426 v[6] = xor_128(v[6], v[11]);
427 v[7] = xor_128(v[7], v[8]);
428 v[4] = xor_128(v[4], v[9]);
429 v[5] = rot7_128(v[5]);
430 v[6] = rot7_128(v[6]);
431 v[7] = rot7_128(v[7]);
432 v[4] = rot7_128(v[4]);
433}
434
435INLINE void transpose_vecs_128(__m128i vecs[4]) {
436 // Interleave 32-bit lanes. The low unpack is lanes 00/11 and the high is
437 // 22/33. Note that this doesn't split the vector into two lanes, as the
438 // AVX2 counterparts do.
439 __m128i ab_01 = _mm_unpacklo_epi32(vecs[0], vecs[1]);
440 __m128i ab_23 = _mm_unpackhi_epi32(vecs[0], vecs[1]);
441 __m128i cd_01 = _mm_unpacklo_epi32(vecs[2], vecs[3]);
442 __m128i cd_23 = _mm_unpackhi_epi32(vecs[2], vecs[3]);
443
444 // Interleave 64-bit lanes.
445 __m128i abcd_0 = _mm_unpacklo_epi64(ab_01, cd_01);
446 __m128i abcd_1 = _mm_unpackhi_epi64(ab_01, cd_01);
447 __m128i abcd_2 = _mm_unpacklo_epi64(ab_23, cd_23);
448 __m128i abcd_3 = _mm_unpackhi_epi64(ab_23, cd_23);
449
450 vecs[0] = abcd_0;
451 vecs[1] = abcd_1;
452 vecs[2] = abcd_2;
453 vecs[3] = abcd_3;
454}
455
456INLINE void transpose_msg_vecs4(const uint8_t *const *inputs,
457 size_t block_offset, __m128i out[16]) {
458 out[0] = loadu_128(&inputs[0][block_offset + 0 * sizeof(__m128i)]);
459 out[1] = loadu_128(&inputs[1][block_offset + 0 * sizeof(__m128i)]);
460 out[2] = loadu_128(&inputs[2][block_offset + 0 * sizeof(__m128i)]);
461 out[3] = loadu_128(&inputs[3][block_offset + 0 * sizeof(__m128i)]);
462 out[4] = loadu_128(&inputs[0][block_offset + 1 * sizeof(__m128i)]);
463 out[5] = loadu_128(&inputs[1][block_offset + 1 * sizeof(__m128i)]);
464 out[6] = loadu_128(&inputs[2][block_offset + 1 * sizeof(__m128i)]);
465 out[7] = loadu_128(&inputs[3][block_offset + 1 * sizeof(__m128i)]);
466 out[8] = loadu_128(&inputs[0][block_offset + 2 * sizeof(__m128i)]);
467 out[9] = loadu_128(&inputs[1][block_offset + 2 * sizeof(__m128i)]);
468 out[10] = loadu_128(&inputs[2][block_offset + 2 * sizeof(__m128i)]);
469 out[11] = loadu_128(&inputs[3][block_offset + 2 * sizeof(__m128i)]);
470 out[12] = loadu_128(&inputs[0][block_offset + 3 * sizeof(__m128i)]);
471 out[13] = loadu_128(&inputs[1][block_offset + 3 * sizeof(__m128i)]);
472 out[14] = loadu_128(&inputs[2][block_offset + 3 * sizeof(__m128i)]);
473 out[15] = loadu_128(&inputs[3][block_offset + 3 * sizeof(__m128i)]);
474 for (size_t i = 0; i < 4; ++i) {
475 _mm_prefetch((const void *)&inputs[i][block_offset + 256], _MM_HINT_T0);
476 }
477 transpose_vecs_128(&out[0]);
478 transpose_vecs_128(&out[4]);
479 transpose_vecs_128(&out[8]);
480 transpose_vecs_128(&out[12]);
481}
482
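// Added annotation: produces the per-lane counter words. Lane i gets
// counter + i when increment_counter is set, otherwise every lane gets the
// same counter. The 64-bit results are split into low and high 32-bit vectors.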
483INLINE void load_counters4(uint64_t counter, bool increment_counter,
484 __m128i *out_lo, __m128i *out_hi) {
485 uint64_t mask = (increment_counter ? ~0 : 0);
486 __m256i mask_vec = _mm256_set1_epi64x(mask);
487 __m256i deltas = _mm256_setr_epi64x(0, 1, 2, 3);
488 deltas = _mm256_and_si256(mask_vec, deltas);
489 __m256i counters =
490 _mm256_add_epi64(_mm256_set1_epi64x((int64_t)counter), deltas);
491 *out_lo = _mm256_cvtepi64_epi32(counters);
492 *out_hi = _mm256_cvtepi64_epi32(_mm256_srli_epi64(counters, 32));
493}
494
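// Added annotation: hashes four inputs in parallel using 128-bit vectors (one
// 32-bit lane per input), relying on AVX-512VL for the rotate instructions.
// Each input's 32-byte chaining value is written consecutively to out.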
495static
496void blake3_hash4_avx512(const uint8_t *const *inputs, size_t blocks,
497 const uint32_t key[8], uint64_t counter,
498 bool increment_counter, uint8_t flags,
499 uint8_t flags_start, uint8_t flags_end, uint8_t *out) {
500 __m128i h_vecs[8] = {
501 set1_128(key[0]), set1_128(key[1]), set1_128(key[2]), set1_128(key[3]),
502 set1_128(key[4]), set1_128(key[5]), set1_128(key[6]), set1_128(key[7]),
503 };
504 __m128i counter_low_vec, counter_high_vec;
505 load_counters4(counter, increment_counter, &counter_low_vec,
506 &counter_high_vec);
507 uint8_t block_flags = flags | flags_start;
508
509 for (size_t block = 0; block < blocks; block++) {
510 if (block + 1 == blocks) {
511 block_flags |= flags_end;
512 }
513 __m128i block_len_vec = set1_128(BLAKE3_BLOCK_LEN);
514 __m128i block_flags_vec = set1_128(block_flags);
515 __m128i msg_vecs[16];
516 transpose_msg_vecs4(inputs, block * BLAKE3_BLOCK_LEN, msg_vecs);
517
518 __m128i v[16] = {
519 h_vecs[0], h_vecs[1], h_vecs[2], h_vecs[3],
520 h_vecs[4], h_vecs[5], h_vecs[6], h_vecs[7],
521 set1_128(IV[0]), set1_128(IV[1]), set1_128(IV[2]), set1_128(IV[3]),
522 counter_low_vec, counter_high_vec, block_len_vec, block_flags_vec,
523 };
524 round_fn4(v, msg_vecs, 0);
525 round_fn4(v, msg_vecs, 1);
526 round_fn4(v, msg_vecs, 2);
527 round_fn4(v, msg_vecs, 3);
528 round_fn4(v, msg_vecs, 4);
529 round_fn4(v, msg_vecs, 5);
530 round_fn4(v, msg_vecs, 6);
531 h_vecs[0] = xor_128(v[0], v[8]);
532 h_vecs[1] = xor_128(v[1], v[9]);
533 h_vecs[2] = xor_128(v[2], v[10]);
534 h_vecs[3] = xor_128(v[3], v[11]);
535 h_vecs[4] = xor_128(v[4], v[12]);
536 h_vecs[5] = xor_128(v[5], v[13]);
537 h_vecs[6] = xor_128(v[6], v[14]);
538 h_vecs[7] = xor_128(v[7], v[15]);
539
540 block_flags = flags;
541 }
542
543 transpose_vecs_128(&h_vecs[0]);
544 transpose_vecs_128(&h_vecs[4]);
545 // The first four vecs now contain the first half of each output, and the
546 // second four vecs contain the second half of each output.
547 storeu_128(h_vecs[0], &out[0 * sizeof(__m128i)]);
548 storeu_128(h_vecs[4], &out[1 * sizeof(__m128i)]);
549 storeu_128(h_vecs[1], &out[2 * sizeof(__m128i)]);
550 storeu_128(h_vecs[5], &out[3 * sizeof(__m128i)]);
551 storeu_128(h_vecs[2], &out[4 * sizeof(__m128i)]);
552 storeu_128(h_vecs[6], &out[5 * sizeof(__m128i)]);
553 storeu_128(h_vecs[3], &out[6 * sizeof(__m128i)]);
554 storeu_128(h_vecs[7], &out[7 * sizeof(__m128i)]);
555}
556
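// Added annotation: the XOF kernels below broadcast a single CV and block
// across all lanes and give each lane its own counter value, producing 4, 8,
// or 16 consecutive 64-byte output blocks per call.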
557static
558void blake3_xof4_avx512(const uint32_t cv[8],
559 const uint8_t block[BLAKE3_BLOCK_LEN],
560 uint8_t block_len, uint64_t counter, uint8_t flags,
561 uint8_t out[4 * 64]) {
562 __m128i h_vecs[8] = {
563 set1_128(cv[0]), set1_128(cv[1]), set1_128(cv[2]), set1_128(cv[3]),
564 set1_128(cv[4]), set1_128(cv[5]), set1_128(cv[6]), set1_128(cv[7]),
565 };
566 uint32_t block_words[16];
567 load_block_words(block, block_words);
568 __m128i msg_vecs[16];
569 for (size_t i = 0; i < 16; i++) {
570 msg_vecs[i] = set1_128(block_words[i]);
571 }
572 __m128i counter_low_vec, counter_high_vec;
573 load_counters4(counter, true, &counter_low_vec, &counter_high_vec);
574 __m128i block_len_vec = set1_128(block_len);
575 __m128i block_flags_vec = set1_128(flags);
576 __m128i v[16] = {
577 h_vecs[0], h_vecs[1], h_vecs[2], h_vecs[3],
578 h_vecs[4], h_vecs[5], h_vecs[6], h_vecs[7],
579 set1_128(IV[0]), set1_128(IV[1]), set1_128(IV[2]), set1_128(IV[3]),
580 counter_low_vec, counter_high_vec, block_len_vec, block_flags_vec,
581 };
582 round_fn4(v, msg_vecs, 0);
583 round_fn4(v, msg_vecs, 1);
584 round_fn4(v, msg_vecs, 2);
585 round_fn4(v, msg_vecs, 3);
586 round_fn4(v, msg_vecs, 4);
587 round_fn4(v, msg_vecs, 5);
588 round_fn4(v, msg_vecs, 6);
589 for (size_t i = 0; i < 8; i++) {
590 v[i] = xor_128(v[i], v[i+8]);
591 v[i+8] = xor_128(v[i+8], h_vecs[i]);
592 }
593 transpose_vecs_128(&v[0]);
594 transpose_vecs_128(&v[4]);
595 transpose_vecs_128(&v[8]);
596 transpose_vecs_128(&v[12]);
597 for (size_t i = 0; i < 4; i++) {
598 storeu_128(v[i+ 0], &out[(4*i+0) * sizeof(__m128i)]);
599 storeu_128(v[i+ 4], &out[(4*i+1) * sizeof(__m128i)]);
600 storeu_128(v[i+ 8], &out[(4*i+2) * sizeof(__m128i)]);
601 storeu_128(v[i+12], &out[(4*i+3) * sizeof(__m128i)]);
602 }
603}
604
605/*
606 * ----------------------------------------------------------------------------
607 * hash8_avx512
608 * ----------------------------------------------------------------------------
609 */
610
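// Added annotation: round_fn8 is the 8-way analogue of round_fn4, operating
// on 256-bit vectors (one 32-bit lane per hash state).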
611INLINE void round_fn8(__m256i v[16], __m256i m[16], size_t r) {
612 v[0] = add_256(v[0], m[(size_t)MSG_SCHEDULE[r][0]]);
613 v[1] = add_256(v[1], m[(size_t)MSG_SCHEDULE[r][2]]);
614 v[2] = add_256(v[2], m[(size_t)MSG_SCHEDULE[r][4]]);
615 v[3] = add_256(v[3], m[(size_t)MSG_SCHEDULE[r][6]]);
616 v[0] = add_256(v[0], v[4]);
617 v[1] = add_256(v[1], v[5]);
618 v[2] = add_256(v[2], v[6]);
619 v[3] = add_256(v[3], v[7]);
620 v[12] = xor_256(v[12], v[0]);
621 v[13] = xor_256(v[13], v[1]);
622 v[14] = xor_256(v[14], v[2]);
623 v[15] = xor_256(v[15], v[3]);
624 v[12] = rot16_256(v[12]);
625 v[13] = rot16_256(v[13]);
626 v[14] = rot16_256(v[14]);
627 v[15] = rot16_256(v[15]);
628 v[8] = add_256(v[8], v[12]);
629 v[9] = add_256(v[9], v[13]);
630 v[10] = add_256(v[10], v[14]);
631 v[11] = add_256(v[11], v[15]);
632 v[4] = xor_256(v[4], v[8]);
633 v[5] = xor_256(v[5], v[9]);
634 v[6] = xor_256(v[6], v[10]);
635 v[7] = xor_256(v[7], v[11]);
636 v[4] = rot12_256(v[4]);
637 v[5] = rot12_256(v[5]);
638 v[6] = rot12_256(v[6]);
639 v[7] = rot12_256(v[7]);
640 v[0] = add_256(v[0], m[(size_t)MSG_SCHEDULE[r][1]]);
641 v[1] = add_256(v[1], m[(size_t)MSG_SCHEDULE[r][3]]);
642 v[2] = add_256(v[2], m[(size_t)MSG_SCHEDULE[r][5]]);
643 v[3] = add_256(v[3], m[(size_t)MSG_SCHEDULE[r][7]]);
644 v[0] = add_256(v[0], v[4]);
645 v[1] = add_256(v[1], v[5]);
646 v[2] = add_256(v[2], v[6]);
647 v[3] = add_256(v[3], v[7]);
648 v[12] = xor_256(v[12], v[0]);
649 v[13] = xor_256(v[13], v[1]);
650 v[14] = xor_256(v[14], v[2]);
651 v[15] = xor_256(v[15], v[3]);
652 v[12] = rot8_256(v[12]);
653 v[13] = rot8_256(v[13]);
654 v[14] = rot8_256(v[14]);
655 v[15] = rot8_256(v[15]);
656 v[8] = add_256(v[8], v[12]);
657 v[9] = add_256(v[9], v[13]);
658 v[10] = add_256(v[10], v[14]);
659 v[11] = add_256(v[11], v[15]);
660 v[4] = xor_256(v[4], v[8]);
661 v[5] = xor_256(v[5], v[9]);
662 v[6] = xor_256(v[6], v[10]);
663 v[7] = xor_256(v[7], v[11]);
664 v[4] = rot7_256(v[4]);
665 v[5] = rot7_256(v[5]);
666 v[6] = rot7_256(v[6]);
667 v[7] = rot7_256(v[7]);
668
669 v[0] = add_256(v[0], m[(size_t)MSG_SCHEDULE[r][8]]);
670 v[1] = add_256(v[1], m[(size_t)MSG_SCHEDULE[r][10]]);
671 v[2] = add_256(v[2], m[(size_t)MSG_SCHEDULE[r][12]]);
672 v[3] = add_256(v[3], m[(size_t)MSG_SCHEDULE[r][14]]);
673 v[0] = add_256(v[0], v[5]);
674 v[1] = add_256(v[1], v[6]);
675 v[2] = add_256(v[2], v[7]);
676 v[3] = add_256(v[3], v[4]);
677 v[15] = xor_256(v[15], v[0]);
678 v[12] = xor_256(v[12], v[1]);
679 v[13] = xor_256(v[13], v[2]);
680 v[14] = xor_256(v[14], v[3]);
681 v[15] = rot16_256(v[15]);
682 v[12] = rot16_256(v[12]);
683 v[13] = rot16_256(v[13]);
684 v[14] = rot16_256(v[14]);
685 v[10] = add_256(v[10], v[15]);
686 v[11] = add_256(v[11], v[12]);
687 v[8] = add_256(v[8], v[13]);
688 v[9] = add_256(v[9], v[14]);
689 v[5] = xor_256(v[5], v[10]);
690 v[6] = xor_256(v[6], v[11]);
691 v[7] = xor_256(v[7], v[8]);
692 v[4] = xor_256(v[4], v[9]);
693 v[5] = rot12_256(v[5]);
694 v[6] = rot12_256(v[6]);
695 v[7] = rot12_256(v[7]);
696 v[4] = rot12_256(v[4]);
697 v[0] = add_256(v[0], m[(size_t)MSG_SCHEDULE[r][9]]);
698 v[1] = add_256(v[1], m[(size_t)MSG_SCHEDULE[r][11]]);
699 v[2] = add_256(v[2], m[(size_t)MSG_SCHEDULE[r][13]]);
700 v[3] = add_256(v[3], m[(size_t)MSG_SCHEDULE[r][15]]);
701 v[0] = add_256(v[0], v[5]);
702 v[1] = add_256(v[1], v[6]);
703 v[2] = add_256(v[2], v[7]);
704 v[3] = add_256(v[3], v[4]);
705 v[15] = xor_256(v[15], v[0]);
706 v[12] = xor_256(v[12], v[1]);
707 v[13] = xor_256(v[13], v[2]);
708 v[14] = xor_256(v[14], v[3]);
709 v[15] = rot8_256(v[15]);
710 v[12] = rot8_256(v[12]);
711 v[13] = rot8_256(v[13]);
712 v[14] = rot8_256(v[14]);
713 v[10] = add_256(v[10], v[15]);
714 v[11] = add_256(v[11], v[12]);
715 v[8] = add_256(v[8], v[13]);
716 v[9] = add_256(v[9], v[14]);
717 v[5] = xor_256(v[5], v[10]);
718 v[6] = xor_256(v[6], v[11]);
719 v[7] = xor_256(v[7], v[8]);
720 v[4] = xor_256(v[4], v[9]);
721 v[5] = rot7_256(v[5]);
722 v[6] = rot7_256(v[6]);
723 v[7] = rot7_256(v[7]);
724 v[4] = rot7_256(v[4]);
725}
726
727INLINE void transpose_vecs_256(__m256i vecs[8]) {
728 // Interleave 32-bit lanes. The low unpack is lanes 00/11/44/55, and the high
729 // is 22/33/66/77.
730 __m256i ab_0145 = _mm256_unpacklo_epi32(vecs[0], vecs[1]);
731 __m256i ab_2367 = _mm256_unpackhi_epi32(vecs[0], vecs[1]);
732 __m256i cd_0145 = _mm256_unpacklo_epi32(vecs[2], vecs[3]);
733 __m256i cd_2367 = _mm256_unpackhi_epi32(vecs[2], vecs[3]);
734 __m256i ef_0145 = _mm256_unpacklo_epi32(vecs[4], vecs[5]);
735 __m256i ef_2367 = _mm256_unpackhi_epi32(vecs[4], vecs[5]);
736 __m256i gh_0145 = _mm256_unpacklo_epi32(vecs[6], vecs[7]);
737 __m256i gh_2367 = _mm256_unpackhi_epi32(vecs[6], vecs[7]);
738
739 // Interleave 64-bit lanes. The low unpack is lanes 00/22 and the high is
740 // 11/33.
741 __m256i abcd_04 = _mm256_unpacklo_epi64(ab_0145, cd_0145);
742 __m256i abcd_15 = _mm256_unpackhi_epi64(ab_0145, cd_0145);
743 __m256i abcd_26 = _mm256_unpacklo_epi64(ab_2367, cd_2367);
744 __m256i abcd_37 = _mm256_unpackhi_epi64(ab_2367, cd_2367);
745 __m256i efgh_04 = _mm256_unpacklo_epi64(ef_0145, gh_0145);
746 __m256i efgh_15 = _mm256_unpackhi_epi64(ef_0145, gh_0145);
747 __m256i efgh_26 = _mm256_unpacklo_epi64(ef_2367, gh_2367);
748 __m256i efgh_37 = _mm256_unpackhi_epi64(ef_2367, gh_2367);
749
750 // Interleave 128-bit lanes.
751 vecs[0] = _mm256_permute2x128_si256(abcd_04, efgh_04, 0x20);
752 vecs[1] = _mm256_permute2x128_si256(abcd_15, efgh_15, 0x20);
753 vecs[2] = _mm256_permute2x128_si256(abcd_26, efgh_26, 0x20);
754 vecs[3] = _mm256_permute2x128_si256(abcd_37, efgh_37, 0x20);
755 vecs[4] = _mm256_permute2x128_si256(abcd_04, efgh_04, 0x31);
756 vecs[5] = _mm256_permute2x128_si256(abcd_15, efgh_15, 0x31);
757 vecs[6] = _mm256_permute2x128_si256(abcd_26, efgh_26, 0x31);
758 vecs[7] = _mm256_permute2x128_si256(abcd_37, efgh_37, 0x31);
759}
760
761INLINE void transpose_msg_vecs8(const uint8_t *const *inputs,
762 size_t block_offset, __m256i out[16]) {
763 out[0] = loadu_256(&inputs[0][block_offset + 0 * sizeof(__m256i)]);
764 out[1] = loadu_256(&inputs[1][block_offset + 0 * sizeof(__m256i)]);
765 out[2] = loadu_256(&inputs[2][block_offset + 0 * sizeof(__m256i)]);
766 out[3] = loadu_256(&inputs[3][block_offset + 0 * sizeof(__m256i)]);
767 out[4] = loadu_256(&inputs[4][block_offset + 0 * sizeof(__m256i)]);
768 out[5] = loadu_256(&inputs[5][block_offset + 0 * sizeof(__m256i)]);
769 out[6] = loadu_256(&inputs[6][block_offset + 0 * sizeof(__m256i)]);
770 out[7] = loadu_256(&inputs[7][block_offset + 0 * sizeof(__m256i)]);
771 out[8] = loadu_256(&inputs[0][block_offset + 1 * sizeof(__m256i)]);
772 out[9] = loadu_256(&inputs[1][block_offset + 1 * sizeof(__m256i)]);
773 out[10] = loadu_256(&inputs[2][block_offset + 1 * sizeof(__m256i)]);
774 out[11] = loadu_256(&inputs[3][block_offset + 1 * sizeof(__m256i)]);
775 out[12] = loadu_256(&inputs[4][block_offset + 1 * sizeof(__m256i)]);
776 out[13] = loadu_256(&inputs[5][block_offset + 1 * sizeof(__m256i)]);
777 out[14] = loadu_256(&inputs[6][block_offset + 1 * sizeof(__m256i)]);
778 out[15] = loadu_256(&inputs[7][block_offset + 1 * sizeof(__m256i)]);
779 for (size_t i = 0; i < 8; ++i) {
780 _mm_prefetch((const void *)&inputs[i][block_offset + 256], _MM_HINT_T0);
781 }
782 transpose_vecs_256(&out[0]);
783 transpose_vecs_256(&out[8]);
784}
785
786INLINE void load_counters8(uint64_t counter, bool increment_counter,
787 __m256i *out_lo, __m256i *out_hi) {
788 uint64_t mask = (increment_counter ? ~0 : 0);
789 __m512i mask_vec = _mm512_set1_epi64(mask);
790 __m512i deltas = _mm512_setr_epi64(0, 1, 2, 3, 4, 5, 6, 7);
791 deltas = _mm512_and_si512(mask_vec, deltas);
792 __m512i counters =
793 _mm512_add_epi64(_mm512_set1_epi64((int64_t)counter), deltas);
794 *out_lo = _mm512_cvtepi64_epi32(counters);
795 *out_hi = _mm512_cvtepi64_epi32(_mm512_srli_epi64(counters, 32));
796}
797
798static
799void blake3_hash8_avx512(const uint8_t *const *inputs, size_t blocks,
800 const uint32_t key[8], uint64_t counter,
801 bool increment_counter, uint8_t flags,
802 uint8_t flags_start, uint8_t flags_end, uint8_t *out) {
803 __m256i h_vecs[8] = {
804 set1_256(key[0]), set1_256(key[1]), set1_256(key[2]), set1_256(key[3]),
805 set1_256(key[4]), set1_256(key[5]), set1_256(key[6]), set1_256(key[7]),
806 };
807 __m256i counter_low_vec, counter_high_vec;
808 load_counters8(counter, increment_counter, &counter_low_vec,
809 &counter_high_vec);
810 uint8_t block_flags = flags | flags_start;
811
812 for (size_t block = 0; block < blocks; block++) {
813 if (block + 1 == blocks) {
814 block_flags |= flags_end;
815 }
816 __m256i block_len_vec = set1_256(BLAKE3_BLOCK_LEN);
817 __m256i block_flags_vec = set1_256(block_flags);
818 __m256i msg_vecs[16];
819 transpose_msg_vecs8(inputs, block * BLAKE3_BLOCK_LEN, msg_vecs);
820
821 __m256i v[16] = {
822 h_vecs[0], h_vecs[1], h_vecs[2], h_vecs[3],
823 h_vecs[4], h_vecs[5], h_vecs[6], h_vecs[7],
824 set1_256(IV[0]), set1_256(IV[1]), set1_256(IV[2]), set1_256(IV[3]),
825 counter_low_vec, counter_high_vec, block_len_vec, block_flags_vec,
826 };
827 round_fn8(v, msg_vecs, 0);
828 round_fn8(v, msg_vecs, 1);
829 round_fn8(v, msg_vecs, 2);
830 round_fn8(v, msg_vecs, 3);
831 round_fn8(v, msg_vecs, 4);
832 round_fn8(v, msg_vecs, 5);
833 round_fn8(v, msg_vecs, 6);
834 h_vecs[0] = xor_256(v[0], v[8]);
835 h_vecs[1] = xor_256(v[1], v[9]);
836 h_vecs[2] = xor_256(v[2], v[10]);
837 h_vecs[3] = xor_256(v[3], v[11]);
838 h_vecs[4] = xor_256(v[4], v[12]);
839 h_vecs[5] = xor_256(v[5], v[13]);
840 h_vecs[6] = xor_256(v[6], v[14]);
841 h_vecs[7] = xor_256(v[7], v[15]);
842
843 block_flags = flags;
844 }
845
846 transpose_vecs_256(h_vecs);
847 storeu_256(h_vecs[0], &out[0 * sizeof(__m256i)]);
848 storeu_256(h_vecs[1], &out[1 * sizeof(__m256i)]);
849 storeu_256(h_vecs[2], &out[2 * sizeof(__m256i)]);
850 storeu_256(h_vecs[3], &out[3 * sizeof(__m256i)]);
851 storeu_256(h_vecs[4], &out[4 * sizeof(__m256i)]);
852 storeu_256(h_vecs[5], &out[5 * sizeof(__m256i)]);
853 storeu_256(h_vecs[6], &out[6 * sizeof(__m256i)]);
854 storeu_256(h_vecs[7], &out[7 * sizeof(__m256i)]);
855}
856
857static
858void blake3_xof8_avx512(const uint32_t cv[8],
859 const uint8_t block[BLAKE3_BLOCK_LEN],
860 uint8_t block_len, uint64_t counter, uint8_t flags,
861 uint8_t out[8 * 64]) {
862 __m256i h_vecs[8] = {
863 set1_256(cv[0]), set1_256(cv[1]), set1_256(cv[2]), set1_256(cv[3]),
864 set1_256(cv[4]), set1_256(cv[5]), set1_256(cv[6]), set1_256(cv[7]),
865 };
866 uint32_t block_words[16];
867 load_block_words(block, block_words);
868 __m256i msg_vecs[16];
869 for (size_t i = 0; i < 16; i++) {
870 msg_vecs[i] = set1_256(block_words[i]);
871 }
872 __m256i counter_low_vec, counter_high_vec;
873 load_counters8(counter, true, &counter_low_vec, &counter_high_vec);
874 __m256i block_len_vec = set1_256(block_len);
875 __m256i block_flags_vec = set1_256(flags);
876 __m256i v[16] = {
877 h_vecs[0], h_vecs[1], h_vecs[2], h_vecs[3],
878 h_vecs[4], h_vecs[5], h_vecs[6], h_vecs[7],
879 set1_256(IV[0]), set1_256(IV[1]), set1_256(IV[2]), set1_256(IV[3]),
880 counter_low_vec, counter_high_vec, block_len_vec, block_flags_vec,
881 };
882 round_fn8(v, msg_vecs, 0);
883 round_fn8(v, msg_vecs, 1);
884 round_fn8(v, msg_vecs, 2);
885 round_fn8(v, msg_vecs, 3);
886 round_fn8(v, msg_vecs, 4);
887 round_fn8(v, msg_vecs, 5);
888 round_fn8(v, msg_vecs, 6);
889 for (size_t i = 0; i < 8; i++) {
890 v[i] = xor_256(v[i], v[i+8]);
891 v[i+8] = xor_256(v[i+8], h_vecs[i]);
892 }
893 transpose_vecs_256(&v[0]);
894 transpose_vecs_256(&v[8]);
895 for (size_t i = 0; i < 8; i++) {
896 storeu_256(v[i+0], &out[(2*i+0) * sizeof(__m256i)]);
897 storeu_256(v[i+8], &out[(2*i+1) * sizeof(__m256i)]);
898 }
899}
900
901/*
902 * ----------------------------------------------------------------------------
903 * hash16_avx512
904 * ----------------------------------------------------------------------------
905 */
906
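// Added annotation: round_fn16 is the 16-way analogue of round_fn4 and
// round_fn8, operating on full 512-bit vectors.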
907INLINE void round_fn16(__m512i v[16], __m512i m[16], size_t r) {
908 v[0] = add_512(v[0], m[(size_t)MSG_SCHEDULE[r][0]]);
909 v[1] = add_512(v[1], m[(size_t)MSG_SCHEDULE[r][2]]);
910 v[2] = add_512(v[2], m[(size_t)MSG_SCHEDULE[r][4]]);
911 v[3] = add_512(v[3], m[(size_t)MSG_SCHEDULE[r][6]]);
912 v[0] = add_512(v[0], v[4]);
913 v[1] = add_512(v[1], v[5]);
914 v[2] = add_512(v[2], v[6]);
915 v[3] = add_512(v[3], v[7]);
916 v[12] = xor_512(v[12], v[0]);
917 v[13] = xor_512(v[13], v[1]);
918 v[14] = xor_512(v[14], v[2]);
919 v[15] = xor_512(v[15], v[3]);
920 v[12] = rot16_512(v[12]);
921 v[13] = rot16_512(v[13]);
922 v[14] = rot16_512(v[14]);
923 v[15] = rot16_512(v[15]);
924 v[8] = add_512(v[8], v[12]);
925 v[9] = add_512(v[9], v[13]);
926 v[10] = add_512(v[10], v[14]);
927 v[11] = add_512(v[11], v[15]);
928 v[4] = xor_512(v[4], v[8]);
929 v[5] = xor_512(v[5], v[9]);
930 v[6] = xor_512(v[6], v[10]);
931 v[7] = xor_512(v[7], v[11]);
932 v[4] = rot12_512(v[4]);
933 v[5] = rot12_512(v[5]);
934 v[6] = rot12_512(v[6]);
935 v[7] = rot12_512(v[7]);
936 v[0] = add_512(v[0], m[(size_t)MSG_SCHEDULE[r][1]]);
937 v[1] = add_512(v[1], m[(size_t)MSG_SCHEDULE[r][3]]);
938 v[2] = add_512(v[2], m[(size_t)MSG_SCHEDULE[r][5]]);
939 v[3] = add_512(v[3], m[(size_t)MSG_SCHEDULE[r][7]]);
940 v[0] = add_512(v[0], v[4]);
941 v[1] = add_512(v[1], v[5]);
942 v[2] = add_512(v[2], v[6]);
943 v[3] = add_512(v[3], v[7]);
944 v[12] = xor_512(v[12], v[0]);
945 v[13] = xor_512(v[13], v[1]);
946 v[14] = xor_512(v[14], v[2]);
947 v[15] = xor_512(v[15], v[3]);
948 v[12] = rot8_512(v[12]);
949 v[13] = rot8_512(v[13]);
950 v[14] = rot8_512(v[14]);
951 v[15] = rot8_512(v[15]);
952 v[8] = add_512(v[8], v[12]);
953 v[9] = add_512(v[9], v[13]);
954 v[10] = add_512(v[10], v[14]);
955 v[11] = add_512(v[11], v[15]);
956 v[4] = xor_512(v[4], v[8]);
957 v[5] = xor_512(v[5], v[9]);
958 v[6] = xor_512(v[6], v[10]);
959 v[7] = xor_512(v[7], v[11]);
960 v[4] = rot7_512(v[4]);
961 v[5] = rot7_512(v[5]);
962 v[6] = rot7_512(v[6]);
963 v[7] = rot7_512(v[7]);
964
965 v[0] = add_512(v[0], m[(size_t)MSG_SCHEDULE[r][8]]);
966 v[1] = add_512(v[1], m[(size_t)MSG_SCHEDULE[r][10]]);
967 v[2] = add_512(v[2], m[(size_t)MSG_SCHEDULE[r][12]]);
968 v[3] = add_512(v[3], m[(size_t)MSG_SCHEDULE[r][14]]);
969 v[0] = add_512(v[0], v[5]);
970 v[1] = add_512(v[1], v[6]);
971 v[2] = add_512(v[2], v[7]);
972 v[3] = add_512(v[3], v[4]);
973 v[15] = xor_512(v[15], v[0]);
974 v[12] = xor_512(v[12], v[1]);
975 v[13] = xor_512(v[13], v[2]);
976 v[14] = xor_512(v[14], v[3]);
977 v[15] = rot16_512(v[15]);
978 v[12] = rot16_512(v[12]);
979 v[13] = rot16_512(v[13]);
980 v[14] = rot16_512(v[14]);
981 v[10] = add_512(v[10], v[15]);
982 v[11] = add_512(v[11], v[12]);
983 v[8] = add_512(v[8], v[13]);
984 v[9] = add_512(v[9], v[14]);
985 v[5] = xor_512(v[5], v[10]);
986 v[6] = xor_512(v[6], v[11]);
987 v[7] = xor_512(v[7], v[8]);
988 v[4] = xor_512(v[4], v[9]);
989 v[5] = rot12_512(v[5]);
990 v[6] = rot12_512(v[6]);
991 v[7] = rot12_512(v[7]);
992 v[4] = rot12_512(v[4]);
993 v[0] = add_512(v[0], m[(size_t)MSG_SCHEDULE[r][9]]);
994 v[1] = add_512(v[1], m[(size_t)MSG_SCHEDULE[r][11]]);
995 v[2] = add_512(v[2], m[(size_t)MSG_SCHEDULE[r][13]]);
996 v[3] = add_512(v[3], m[(size_t)MSG_SCHEDULE[r][15]]);
997 v[0] = add_512(v[0], v[5]);
998 v[1] = add_512(v[1], v[6]);
999 v[2] = add_512(v[2], v[7]);
1000 v[3] = add_512(v[3], v[4]);
1001 v[15] = xor_512(v[15], v[0]);
1002 v[12] = xor_512(v[12], v[1]);
1003 v[13] = xor_512(v[13], v[2]);
1004 v[14] = xor_512(v[14], v[3]);
1005 v[15] = rot8_512(v[15]);
1006 v[12] = rot8_512(v[12]);
1007 v[13] = rot8_512(v[13]);
1008 v[14] = rot8_512(v[14]);
1009 v[10] = add_512(v[10], v[15]);
1010 v[11] = add_512(v[11], v[12]);
1011 v[8] = add_512(v[8], v[13]);
1012 v[9] = add_512(v[9], v[14]);
1013 v[5] = xor_512(v[5], v[10]);
1014 v[6] = xor_512(v[6], v[11]);
1015 v[7] = xor_512(v[7], v[8]);
1016 v[4] = xor_512(v[4], v[9]);
1017 v[5] = rot7_512(v[5]);
1018 v[6] = rot7_512(v[6]);
1019 v[7] = rot7_512(v[7]);
1020 v[4] = rot7_512(v[4]);
1021}
1022
1023// 0b10001000, or lanes a0/a2/b0/b2 in little-endian order
1024#define LO_IMM8 0x88
1025
1026INLINE __m512i unpack_lo_128(__m512i a, __m512i b) {
1027 return _mm512_shuffle_i32x4(a, b, LO_IMM8);
1028}
1029
1030// 0b11011101, or lanes a1/a3/b1/b3 in little-endian order
1031#define HI_IMM8 0xdd
1032
1033INLINE __m512i unpack_hi_128(__m512i a, __m512i b) {
1034 return _mm512_shuffle_i32x4(a, b, HI_IMM8);
1035}
1036
1037INLINE void transpose_vecs_512(__m512i vecs[16]) {
1038 // Interleave 32-bit lanes. The _0 unpack is lanes
1039 // 0/0/1/1/4/4/5/5/8/8/9/9/12/12/13/13, and the _2 unpack is lanes
1040 // 2/2/3/3/6/6/7/7/10/10/11/11/14/14/15/15.
1041 __m512i ab_0 = _mm512_unpacklo_epi32(vecs[0], vecs[1]);
1042 __m512i ab_2 = _mm512_unpackhi_epi32(vecs[0], vecs[1]);
1043 __m512i cd_0 = _mm512_unpacklo_epi32(vecs[2], vecs[3]);
1044 __m512i cd_2 = _mm512_unpackhi_epi32(vecs[2], vecs[3]);
1045 __m512i ef_0 = _mm512_unpacklo_epi32(vecs[4], vecs[5]);
1046 __m512i ef_2 = _mm512_unpackhi_epi32(vecs[4], vecs[5]);
1047 __m512i gh_0 = _mm512_unpacklo_epi32(vecs[6], vecs[7]);
1048 __m512i gh_2 = _mm512_unpackhi_epi32(vecs[6], vecs[7]);
1049 __m512i ij_0 = _mm512_unpacklo_epi32(vecs[8], vecs[9]);
1050 __m512i ij_2 = _mm512_unpackhi_epi32(vecs[8], vecs[9]);
1051 __m512i kl_0 = _mm512_unpacklo_epi32(vecs[10], vecs[11]);
1052 __m512i kl_2 = _mm512_unpackhi_epi32(vecs[10], vecs[11]);
1053 __m512i mn_0 = _mm512_unpacklo_epi32(vecs[12], vecs[13]);
1054 __m512i mn_2 = _mm512_unpackhi_epi32(vecs[12], vecs[13]);
1055 __m512i op_0 = _mm512_unpacklo_epi32(vecs[14], vecs[15]);
1056 __m512i op_2 = _mm512_unpackhi_epi32(vecs[14], vecs[15]);
1057
1058 // Interleave 64-bit lanes. The _0 unpack is lanes
1059 // 0/0/0/0/4/4/4/4/8/8/8/8/12/12/12/12, the _1 unpack is lanes
1060 // 1/1/1/1/5/5/5/5/9/9/9/9/13/13/13/13, the _2 unpack is lanes
1061 // 2/2/2/2/6/6/6/6/10/10/10/10/14/14/14/14, and the _3 unpack is lanes
1062 // 3/3/3/3/7/7/7/7/11/11/11/11/15/15/15/15.
1063 __m512i abcd_0 = _mm512_unpacklo_epi64(ab_0, cd_0);
1064 __m512i abcd_1 = _mm512_unpackhi_epi64(ab_0, cd_0);
1065 __m512i abcd_2 = _mm512_unpacklo_epi64(ab_2, cd_2);
1066 __m512i abcd_3 = _mm512_unpackhi_epi64(ab_2, cd_2);
1067 __m512i efgh_0 = _mm512_unpacklo_epi64(ef_0, gh_0);
1068 __m512i efgh_1 = _mm512_unpackhi_epi64(ef_0, gh_0);
1069 __m512i efgh_2 = _mm512_unpacklo_epi64(ef_2, gh_2);
1070 __m512i efgh_3 = _mm512_unpackhi_epi64(ef_2, gh_2);
1071 __m512i ijkl_0 = _mm512_unpacklo_epi64(ij_0, kl_0);
1072 __m512i ijkl_1 = _mm512_unpackhi_epi64(ij_0, kl_0);
1073 __m512i ijkl_2 = _mm512_unpacklo_epi64(ij_2, kl_2);
1074 __m512i ijkl_3 = _mm512_unpackhi_epi64(ij_2, kl_2);
1075 __m512i mnop_0 = _mm512_unpacklo_epi64(mn_0, op_0);
1076 __m512i mnop_1 = _mm512_unpackhi_epi64(mn_0, op_0);
1077 __m512i mnop_2 = _mm512_unpacklo_epi64(mn_2, op_2);
1078 __m512i mnop_3 = _mm512_unpackhi_epi64(mn_2, op_2);
1079
1080 // Interleave 128-bit lanes. The _0 unpack is
1081 // 0/0/0/0/8/8/8/8/0/0/0/0/8/8/8/8, the _1 unpack is
1082 // 1/1/1/1/9/9/9/9/1/1/1/1/9/9/9/9, and so on.
1083 __m512i abcdefgh_0 = unpack_lo_128(abcd_0, efgh_0);
1084 __m512i abcdefgh_1 = unpack_lo_128(abcd_1, efgh_1);
1085 __m512i abcdefgh_2 = unpack_lo_128(abcd_2, efgh_2);
1086 __m512i abcdefgh_3 = unpack_lo_128(abcd_3, efgh_3);
1087 __m512i abcdefgh_4 = unpack_hi_128(abcd_0, efgh_0);
1088 __m512i abcdefgh_5 = unpack_hi_128(abcd_1, efgh_1);
1089 __m512i abcdefgh_6 = unpack_hi_128(abcd_2, efgh_2);
1090 __m512i abcdefgh_7 = unpack_hi_128(abcd_3, efgh_3);
1091 __m512i ijklmnop_0 = unpack_lo_128(ijkl_0, mnop_0);
1092 __m512i ijklmnop_1 = unpack_lo_128(ijkl_1, mnop_1);
1093 __m512i ijklmnop_2 = unpack_lo_128(ijkl_2, mnop_2);
1094 __m512i ijklmnop_3 = unpack_lo_128(ijkl_3, mnop_3);
1095 __m512i ijklmnop_4 = unpack_hi_128(ijkl_0, mnop_0);
1096 __m512i ijklmnop_5 = unpack_hi_128(ijkl_1, mnop_1);
1097 __m512i ijklmnop_6 = unpack_hi_128(ijkl_2, mnop_2);
1098 __m512i ijklmnop_7 = unpack_hi_128(ijkl_3, mnop_3);
1099
1100 // Interleave 128-bit lanes again for the final outputs.
1101 vecs[0] = unpack_lo_128(abcdefgh_0, ijklmnop_0);
1102 vecs[1] = unpack_lo_128(abcdefgh_1, ijklmnop_1);
1103 vecs[2] = unpack_lo_128(abcdefgh_2, ijklmnop_2);
1104 vecs[3] = unpack_lo_128(abcdefgh_3, ijklmnop_3);
1105 vecs[4] = unpack_lo_128(abcdefgh_4, ijklmnop_4);
1106 vecs[5] = unpack_lo_128(abcdefgh_5, ijklmnop_5);
1107 vecs[6] = unpack_lo_128(abcdefgh_6, ijklmnop_6);
1108 vecs[7] = unpack_lo_128(abcdefgh_7, ijklmnop_7);
1109 vecs[8] = unpack_hi_128(abcdefgh_0, ijklmnop_0);
1110 vecs[9] = unpack_hi_128(abcdefgh_1, ijklmnop_1);
1111 vecs[10] = unpack_hi_128(abcdefgh_2, ijklmnop_2);
1112 vecs[11] = unpack_hi_128(abcdefgh_3, ijklmnop_3);
1113 vecs[12] = unpack_hi_128(abcdefgh_4, ijklmnop_4);
1114 vecs[13] = unpack_hi_128(abcdefgh_5, ijklmnop_5);
1115 vecs[14] = unpack_hi_128(abcdefgh_6, ijklmnop_6);
1116 vecs[15] = unpack_hi_128(abcdefgh_7, ijklmnop_7);
1117}
1118
1119INLINE void transpose_msg_vecs16(const uint8_t *const *inputs,
1120 size_t block_offset, __m512i out[16]) {
1121 out[0] = loadu_512(&inputs[0][block_offset]);
1122 out[1] = loadu_512(&inputs[1][block_offset]);
1123 out[2] = loadu_512(&inputs[2][block_offset]);
1124 out[3] = loadu_512(&inputs[3][block_offset]);
1125 out[4] = loadu_512(&inputs[4][block_offset]);
1126 out[5] = loadu_512(&inputs[5][block_offset]);
1127 out[6] = loadu_512(&inputs[6][block_offset]);
1128 out[7] = loadu_512(&inputs[7][block_offset]);
1129 out[8] = loadu_512(&inputs[8][block_offset]);
1130 out[9] = loadu_512(&inputs[9][block_offset]);
1131 out[10] = loadu_512(&inputs[10][block_offset]);
1132 out[11] = loadu_512(&inputs[11][block_offset]);
1133 out[12] = loadu_512(&inputs[12][block_offset]);
1134 out[13] = loadu_512(&inputs[13][block_offset]);
1135 out[14] = loadu_512(&inputs[14][block_offset]);
1136 out[15] = loadu_512(&inputs[15][block_offset]);
1137 for (size_t i = 0; i < 16; ++i) {
1138 _mm_prefetch((const void *)&inputs[i][block_offset + 256], _MM_HINT_T0);
1139 }
1140 transpose_vecs_512(out);
1141}
1142
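// Added annotation: 16-lane counterpart of load_counters4/load_counters8.
// Since there is no 1024-bit vector to hold sixteen 64-bit counters, the low
// and high 32-bit words are computed directly, with the carry into the high
// word derived manually below.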
1143INLINE void load_counters16(uint64_t counter, bool increment_counter,
1144 __m512i *out_lo, __m512i *out_hi) {
1145 const __m512i mask = _mm512_set1_epi32(-(int32_t)increment_counter);
1146 const __m512i deltas = _mm512_set_epi32(15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0);
1147 const __m512i masked_deltas = _mm512_and_si512(deltas, mask);
1148 const __m512i low_words = _mm512_add_epi32(
1149 _mm512_set1_epi32((int32_t)counter),
1150 masked_deltas);
1151 // The carry bit is 1 if the high bit of the word was 1 before addition and is
1152 // 0 after.
1153 // NOTE: It would be a bit more natural to use _mm512_cmp_epu32_mask to
1154 // compute the carry bits here, and originally we did, but that intrinsic is
1155 // broken under GCC 5.4. See https://github.com/BLAKE3-team/BLAKE3/issues/271.
1156 const __m512i carries = _mm512_srli_epi32(
1157 _mm512_andnot_si512(
1158 low_words, // 0 after (gets inverted by andnot)
1159 _mm512_set1_epi32((int32_t)counter)), // and 1 before
1160 31);
1161 const __m512i high_words = _mm512_add_epi32(
1162 _mm512_set1_epi32((int32_t)(counter >> 32)),
1163 carries);
1164 *out_lo = low_words;
1165 *out_hi = high_words;
1166}
1167
1168static
1169void blake3_hash16_avx512(const uint8_t *const *inputs, size_t blocks,
1170 const uint32_t key[8], uint64_t counter,
1171 bool increment_counter, uint8_t flags,
1172 uint8_t flags_start, uint8_t flags_end,
1173 uint8_t *out) {
1174 __m512i h_vecs[8] = {
1175 set1_512(key[0]), set1_512(key[1]), set1_512(key[2]), set1_512(key[3]),
1176 set1_512(key[4]), set1_512(key[5]), set1_512(key[6]), set1_512(key[7]),
1177 };
1178 __m512i counter_low_vec, counter_high_vec;
1179 load_counters16(counter, increment_counter, &counter_low_vec,
1180 &counter_high_vec);
1181 uint8_t block_flags = flags | flags_start;
1182
1183 for (size_t block = 0; block < blocks; block++) {
1184 if (block + 1 == blocks) {
1185 block_flags |= flags_end;
1186 }
1187 __m512i block_len_vec = set1_512(BLAKE3_BLOCK_LEN);
1188 __m512i block_flags_vec = set1_512(block_flags);
1189 __m512i msg_vecs[16];
1190 transpose_msg_vecs16(inputs, block * BLAKE3_BLOCK_LEN, msg_vecs);
1191
1192 __m512i v[16] = {
1193 h_vecs[0], h_vecs[1], h_vecs[2], h_vecs[3],
1194 h_vecs[4], h_vecs[5], h_vecs[6], h_vecs[7],
1195 set1_512(IV[0]), set1_512(IV[1]), set1_512(IV[2]), set1_512(IV[3]),
1196 counter_low_vec, counter_high_vec, block_len_vec, block_flags_vec,
1197 };
1198 round_fn16(v, msg_vecs, 0);
1199 round_fn16(v, msg_vecs, 1);
1200 round_fn16(v, msg_vecs, 2);
1201 round_fn16(v, msg_vecs, 3);
1202 round_fn16(v, msg_vecs, 4);
1203 round_fn16(v, msg_vecs, 5);
1204 round_fn16(v, msg_vecs, 6);
1205 h_vecs[0] = xor_512(v[0], v[8]);
1206 h_vecs[1] = xor_512(v[1], v[9]);
1207 h_vecs[2] = xor_512(v[2], v[10]);
1208 h_vecs[3] = xor_512(v[3], v[11]);
1209 h_vecs[4] = xor_512(v[4], v[12]);
1210 h_vecs[5] = xor_512(v[5], v[13]);
1211 h_vecs[6] = xor_512(v[6], v[14]);
1212 h_vecs[7] = xor_512(v[7], v[15]);
1213
1214 block_flags = flags;
1215 }
1216
1217 // transpose_vecs_512 operates on a 16x16 matrix of words, but we only have 8
1218 // state vectors. Pad the matrix with zeros. After transposition, store the
1219 // lower half of each vector.
1220 __m512i padded[16] = {
1221 h_vecs[0], h_vecs[1], h_vecs[2], h_vecs[3],
1222 h_vecs[4], h_vecs[5], h_vecs[6], h_vecs[7],
1223 set1_512(0), set1_512(0), set1_512(0), set1_512(0),
1224 set1_512(0), set1_512(0), set1_512(0), set1_512(0),
1225 };
1226 transpose_vecs_512(padded);
1227 _mm256_mask_storeu_epi32(&out[0 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[0]));
1228 _mm256_mask_storeu_epi32(&out[1 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[1]));
1229 _mm256_mask_storeu_epi32(&out[2 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[2]));
1230 _mm256_mask_storeu_epi32(&out[3 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[3]));
1231 _mm256_mask_storeu_epi32(&out[4 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[4]));
1232 _mm256_mask_storeu_epi32(&out[5 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[5]));
1233 _mm256_mask_storeu_epi32(&out[6 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[6]));
1234 _mm256_mask_storeu_epi32(&out[7 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[7]));
1235 _mm256_mask_storeu_epi32(&out[8 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[8]));
1236 _mm256_mask_storeu_epi32(&out[9 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[9]));
1237 _mm256_mask_storeu_epi32(&out[10 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[10]));
1238 _mm256_mask_storeu_epi32(&out[11 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[11]));
1239 _mm256_mask_storeu_epi32(&out[12 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[12]));
1240 _mm256_mask_storeu_epi32(&out[13 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[13]));
1241 _mm256_mask_storeu_epi32(&out[14 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[14]));
1242 _mm256_mask_storeu_epi32(&out[15 * sizeof(__m256i)], (__mmask8)-1, _mm512_castsi512_si256(padded[15]));
1243}
1244
1245static
1246void blake3_xof16_avx512(const uint32_t cv[8],
1247 const uint8_t block[BLAKE3_BLOCK_LEN],
1248 uint8_t block_len, uint64_t counter, uint8_t flags,
1249 uint8_t out[16 * 64]) {
1250 __m512i h_vecs[8] = {
1251 set1_512(cv[0]), set1_512(cv[1]), set1_512(cv[2]), set1_512(cv[3]),
1252 set1_512(cv[4]), set1_512(cv[5]), set1_512(cv[6]), set1_512(cv[7]),
1253 };
1254 uint32_t block_words[16];
1255 load_block_words(block, block_words);
1256 __m512i msg_vecs[16];
1257 for (size_t i = 0; i < 16; i++) {
1258 msg_vecs[i] = set1_512(block_words[i]);
1259 }
1260 __m512i counter_low_vec, counter_high_vec;
1261 load_counters16(counter, true, &counter_low_vec, &counter_high_vec);
1262 __m512i block_len_vec = set1_512(block_len);
1263 __m512i block_flags_vec = set1_512(flags);
1264 __m512i v[16] = {
1265 h_vecs[0], h_vecs[1], h_vecs[2], h_vecs[3],
1266 h_vecs[4], h_vecs[5], h_vecs[6], h_vecs[7],
1267 set1_512(IV[0]), set1_512(IV[1]), set1_512(IV[2]), set1_512(IV[3]),
1268 counter_low_vec, counter_high_vec, block_len_vec, block_flags_vec,
1269 };
1270 round_fn16(v, msg_vecs, 0);
1271 round_fn16(v, msg_vecs, 1);
1272 round_fn16(v, msg_vecs, 2);
1273 round_fn16(v, msg_vecs, 3);
1274 round_fn16(v, msg_vecs, 4);
1275 round_fn16(v, msg_vecs, 5);
1276 round_fn16(v, msg_vecs, 6);
1277 for (size_t i = 0; i < 8; i++) {
1278 v[i] = xor_512(v[i], v[i+8]);
1279 v[i+8] = xor_512(v[i+8], h_vecs[i]);
1280 }
1281 transpose_vecs_512(&v[0]);
1282 for (size_t i = 0; i < 16; i++) {
1283 storeu_512(v[i], &out[i * sizeof(__m512i)]);
1284 }
1285}
1286
1287/*
1288 * ----------------------------------------------------------------------------
1289 * hash_many_avx512
1290 * ----------------------------------------------------------------------------
1291 */
1292
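// Added annotation: serial fallback for a single input. Compresses each
// 64-byte block in place and copies the final 32-byte chaining value to out.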
1293INLINE void hash_one_avx512(const uint8_t *input, size_t blocks,
1294 const uint32_t key[8], uint64_t counter,
1295 uint8_t flags, uint8_t flags_start,
1296 uint8_t flags_end, uint8_t out[BLAKE3_OUT_LEN]) {
1297 uint32_t cv[8];
1298 memcpy(cv, key, BLAKE3_KEY_LEN);
1299 uint8_t block_flags = flags | flags_start;
1300 while (blocks > 0) {
1301 if (blocks == 1) {
1302 block_flags |= flags_end;
1303 }
1304 blake3_compress_in_place_avx512(cv, input, BLAKE3_BLOCK_LEN, counter,
1305 block_flags);
1306 input = &input[BLAKE3_BLOCK_LEN];
1307 blocks -= 1;
1308 block_flags = flags;
1309 }
1310 memcpy(out, cv, BLAKE3_OUT_LEN);
1311}
1312
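// Added annotation: top-level dispatch. Peels off batches of 16, 8, and 4
// inputs for the wide kernels above, then handles any remaining inputs one at
// a time, advancing the counter per input when increment_counter is set.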
1313void blake3_hash_many_avx512(const uint8_t *const *inputs, size_t num_inputs,
1314 size_t blocks, const uint32_t key[8],
1315 uint64_t counter, bool increment_counter,
1316 uint8_t flags, uint8_t flags_start,
1317 uint8_t flags_end, uint8_t *out) {
1318 while (num_inputs >= 16) {
1319 blake3_hash16_avx512(inputs, blocks, key, counter, increment_counter, flags,
1320 flags_start, flags_end, out);
1321 if (increment_counter) {
1322 counter += 16;
1323 }
1324 inputs += 16;
1325 num_inputs -= 16;
1326 out = &out[16 * BLAKE3_OUT_LEN];
1327 }
1328 while (num_inputs >= 8) {
1329 blake3_hash8_avx512(inputs, blocks, key, counter, increment_counter, flags,
1330 flags_start, flags_end, out);
1331 if (increment_counter) {
1332 counter += 8;
1333 }
1334 inputs += 8;
1335 num_inputs -= 8;
1336 out = &out[8 * BLAKE3_OUT_LEN];
1337 }
1338 while (num_inputs >= 4) {
1339 blake3_hash4_avx512(inputs, blocks, key, counter, increment_counter, flags,
1340 flags_start, flags_end, out);
1341 if (increment_counter) {
1342 counter += 4;
1343 }
1344 inputs += 4;
1345 num_inputs -= 4;
1346 out = &out[4 * BLAKE3_OUT_LEN];
1347 }
1348 while (num_inputs > 0) {
1349 hash_one_avx512(inputs[0], blocks, key, counter, flags, flags_start,
1350 flags_end, out);
1351 if (increment_counter) {
1352 counter += 1;
1353 }
1354 inputs += 1;
1355 num_inputs -= 1;
1356 out = &out[BLAKE3_OUT_LEN];
1357 }
1358}
1359
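// Added annotation: extended-output entry point. Emits outblocks consecutive
// 64-byte blocks for a single CV/block pair, batching 16, 8, or 4 blocks at a
// time and falling back to single compress_xof calls for the remainder.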
1360void blake3_xof_many_avx512(const uint32_t cv[8],
1361 const uint8_t block[BLAKE3_BLOCK_LEN],
1362 uint8_t block_len, uint64_t counter, uint8_t flags,
1363 uint8_t* out, size_t outblocks) {
1364 while (outblocks >= 16) {
1365 blake3_xof16_avx512(cv, block, block_len, counter, flags, out);
1366 counter += 16;
1367 outblocks -= 16;
1368 out += 16 * BLAKE3_BLOCK_LEN;
1369 }
1370 while (outblocks >= 8) {
1371 blake3_xof8_avx512(cv, block, block_len, counter, flags, out);
1372 counter += 8;
1373 outblocks -= 8;
1374 out += 8 * BLAKE3_BLOCK_LEN;
1375 }
1376 while (outblocks >= 4) {
1377 blake3_xof4_avx512(cv, block, block_len, counter, flags, out);
1378 counter += 4;
1379 outblocks -= 4;
1380 out += 4 * BLAKE3_BLOCK_LEN;
1381 }
1382 while (outblocks > 0) {
1383 blake3_compress_xof_avx512(cv, block, block_len, counter, flags, out);
1384 counter += 1;
1385 outblocks -= 1;
1386 out += BLAKE3_BLOCK_LEN;
1387 }
1388}