]>
Commit | Line | Data |
---|---|---|
f20569fa XL |
1 | use std::arch::x86_64::{ |
2 | __m256i, | |
3 | _mm256_and_si256, | |
4 | _mm256_cmpeq_epi8, | |
5 | _mm256_extract_epi64, | |
6 | _mm256_loadu_si256, | |
7 | _mm256_sad_epu8, | |
8 | _mm256_set1_epi8, | |
9 | _mm256_setzero_si256, | |
10 | _mm256_sub_epi8, | |
11 | _mm256_xor_si256, | |
12 | }; | |
13 | ||
/// Broadcasts the unsigned byte `a` into all 32 byte lanes of an AVX2 vector.
///
/// AVX2 only exposes a signed `_mm256_set1_epi8`, so the byte's bit pattern is
/// reinterpreted as `i8` before broadcasting; every lane ends up holding exactly `a`.
#[target_feature(enable = "avx2")]
pub unsafe fn _mm256_set1_epu8(a: u8) -> __m256i {
    let same_bits = a as i8;
    _mm256_set1_epi8(same_bits)
}
18 | ||
/// Byte-wise inequality: each lane of the result is `0xFF` where `a` and `b` differ
/// and `0x00` where they are equal.
///
/// AVX2 has no direct "not equal" compare, so the equality mask is inverted by XOR-ing
/// it with an all-ones vector.
#[target_feature(enable = "avx2")]
pub unsafe fn mm256_cmpneq_epi8(a: __m256i, b: __m256i) -> __m256i {
    let all_ones = _mm256_set1_epi8(-1);
    let equal_lanes = _mm256_cmpeq_epi8(a, b);
    _mm256_xor_si256(all_ones, equal_lanes)
}
23 | ||
/// Lookup table used to mask the final, overlapping vector of a scan.
///
/// It holds 32 zero bytes followed by 32 `0xFF` bytes. A 32-byte window starting at
/// index `r` (for `r` in `0..=32`) is therefore zero in its first `32 - r` lanes and
/// all-ones in its last `r` lanes.
const MASK: [u8; 64] = [
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
    0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
];
30 | ||
/// Loads the 32 bytes `slice[offset..offset + 32]` as an unaligned AVX2 vector.
///
/// # Safety
///
/// The CPU must support AVX2, and `offset + 32 <= slice.len()` must hold. The load goes
/// through a raw pointer and does no bounds checking of its own in release builds. In
/// debug builds a violated bound panics instead of reading out of bounds.
#[target_feature(enable = "avx2")]
unsafe fn mm256_from_offset(slice: &[u8], offset: usize) -> __m256i {
    // Written as `offset <= len - 32` (guarded by `len >= 32`) so the check itself
    // cannot overflow. It compiles away entirely when debug assertions are off.
    debug_assert!(
        slice.len() >= 32 && offset <= slice.len() - 32,
        "mm256_from_offset: 32-byte read at offset {} overruns slice of length {}",
        offset,
        slice.len()
    );
    _mm256_loadu_si256(slice.as_ptr().add(offset) as *const __m256i)
}
35 | ||
/// Adds up all 32 unsigned byte lanes of `*u8s` and returns the total.
///
/// `_mm256_sad_epu8` against zero collapses each group of 8 bytes into one 64-bit
/// partial sum (at most 8 * 255). The four partials are then extracted and added.
#[target_feature(enable = "avx2")]
unsafe fn sum(u8s: &__m256i) -> usize {
    let partials = _mm256_sad_epu8(*u8s, _mm256_setzero_si256());
    let low_half = _mm256_extract_epi64(partials, 0) + _mm256_extract_epi64(partials, 1);
    let high_half = _mm256_extract_epi64(partials, 2) + _mm256_extract_epi64(partials, 3);
    (low_half + high_half) as usize
}
44 | ||
45 | #[target_feature(enable = "avx2")] | |
46 | pub unsafe fn chunk_count(haystack: &[u8], needle: u8) -> usize { | |
47 | assert!(haystack.len() >= 32); | |
48 | ||
49 | let mut offset = 0; | |
50 | let mut count = 0; | |
51 | ||
52 | let needles = _mm256_set1_epu8(needle); | |
53 | ||
54 | // 8160 | |
55 | while haystack.len() >= offset + 32 * 255 { | |
56 | let mut counts = _mm256_setzero_si256(); | |
57 | for _ in 0..255 { | |
58 | counts = _mm256_sub_epi8( | |
59 | counts, | |
60 | _mm256_cmpeq_epi8(mm256_from_offset(haystack, offset), needles) | |
61 | ); | |
62 | offset += 32; | |
63 | } | |
64 | count += sum(&counts); | |
65 | } | |
66 | ||
67 | // 4096 | |
68 | if haystack.len() >= offset + 32 * 128 { | |
69 | let mut counts = _mm256_setzero_si256(); | |
70 | for _ in 0..128 { | |
71 | counts = _mm256_sub_epi8( | |
72 | counts, | |
73 | _mm256_cmpeq_epi8(mm256_from_offset(haystack, offset), needles) | |
74 | ); | |
75 | offset += 32; | |
76 | } | |
77 | count += sum(&counts); | |
78 | } | |
79 | ||
80 | // 32 | |
81 | let mut counts = _mm256_setzero_si256(); | |
82 | for i in 0..(haystack.len() - offset) / 32 { | |
83 | counts = _mm256_sub_epi8( | |
84 | counts, | |
85 | _mm256_cmpeq_epi8(mm256_from_offset(haystack, offset + i * 32), needles) | |
86 | ); | |
87 | } | |
88 | if haystack.len() % 32 != 0 { | |
89 | counts = _mm256_sub_epi8( | |
90 | counts, | |
91 | _mm256_and_si256( | |
92 | _mm256_cmpeq_epi8(mm256_from_offset(haystack, haystack.len() - 32), needles), | |
93 | mm256_from_offset(&MASK, haystack.len() % 32) | |
94 | ) | |
95 | ); | |
96 | } | |
97 | count += sum(&counts); | |
98 | ||
99 | count | |
100 | } | |
101 | ||
102 | #[target_feature(enable = "avx2")] | |
103 | unsafe fn is_leading_utf8_byte(u8s: __m256i) -> __m256i { | |
104 | mm256_cmpneq_epi8(_mm256_and_si256(u8s, _mm256_set1_epu8(0b1100_0000)), _mm256_set1_epu8(0b1000_0000)) | |
105 | } | |
106 | ||
107 | #[target_feature(enable = "avx2")] | |
108 | pub unsafe fn chunk_num_chars(utf8_chars: &[u8]) -> usize { | |
109 | assert!(utf8_chars.len() >= 32); | |
110 | ||
111 | let mut offset = 0; | |
112 | let mut count = 0; | |
113 | ||
114 | // 8160 | |
115 | while utf8_chars.len() >= offset + 32 * 255 { | |
116 | let mut counts = _mm256_setzero_si256(); | |
117 | ||
118 | for _ in 0..255 { | |
119 | counts = _mm256_sub_epi8( | |
120 | counts, | |
121 | is_leading_utf8_byte(mm256_from_offset(utf8_chars, offset)) | |
122 | ); | |
123 | offset += 32; | |
124 | } | |
125 | count += sum(&counts); | |
126 | } | |
127 | ||
128 | // 4096 | |
129 | if utf8_chars.len() >= offset + 32 * 128 { | |
130 | let mut counts = _mm256_setzero_si256(); | |
131 | for _ in 0..128 { | |
132 | counts = _mm256_sub_epi8( | |
133 | counts, | |
134 | is_leading_utf8_byte(mm256_from_offset(utf8_chars, offset)) | |
135 | ); | |
136 | offset += 32; | |
137 | } | |
138 | count += sum(&counts); | |
139 | } | |
140 | ||
141 | // 32 | |
142 | let mut counts = _mm256_setzero_si256(); | |
143 | for i in 0..(utf8_chars.len() - offset) / 32 { | |
144 | counts = _mm256_sub_epi8( | |
145 | counts, | |
146 | is_leading_utf8_byte(mm256_from_offset(utf8_chars, offset + i * 32)) | |
147 | ); | |
148 | } | |
149 | if utf8_chars.len() % 32 != 0 { | |
150 | counts = _mm256_sub_epi8( | |
151 | counts, | |
152 | _mm256_and_si256( | |
153 | is_leading_utf8_byte(mm256_from_offset(utf8_chars, utf8_chars.len() - 32)), | |
154 | mm256_from_offset(&MASK, utf8_chars.len() % 32) | |
155 | ) | |
156 | ); | |
157 | } | |
158 | count += sum(&counts); | |
159 | ||
160 | count | |
161 | } |