]>
Commit | Line | Data |
---|---|---|
dfeec247 XL |
1 | use core::mem; |
2 | ||
3 | // The following ~400 lines of code exists for exactly one purpose, which is | |
4 | // to optimize this code: | |
5 | // | |
6 | // byte_slice.iter().position(|&b| b > 0x7F).unwrap_or(byte_slice.len()) | |
7 | // | |
8 | // Yes... Overengineered is a word that comes to mind, but this is effectively | |
9 | // a very similar problem to memchr, and virtually nobody has been able to | |
10 | // resist optimizing the crap out of that (except for perhaps the BSD and MUSL | |
11 | // folks). In particular, this routine makes a very common case (ASCII) very | |
12 | // fast, which seems worth it. We do stop short of adding AVX variants of the | |
13 | // code below in order to retain our sanity and also to avoid needing to deal | |
14 | // with runtime target feature detection. RESIST! | |
15 | // | |
16 | // In order to understand the SIMD version below, it would be good to read this | |
17 | // comment describing how my memchr routine works: | |
18 | // https://github.com/BurntSushi/rust-memchr/blob/b0a29f267f4a7fad8ffcc8fe8377a06498202883/src/x86/sse2.rs#L19-L106 | |
19 | // | |
20 | // The primary difference with memchr is that for ASCII, we can do a bit less | |
21 | // work. In particular, we don't need to detect the presence of a specific | |
22 | // byte, but rather, whether any byte has its most significant bit set. That | |
23 | // means we can effectively skip the _mm_cmpeq_epi8 step and jump straight to | |
24 | // _mm_movemask_epi8. | |
25 | ||
// Size in bytes of a `usize`, i.e. the chunk size used by the fallback
// (non-x86_64) implementation.
#[cfg(any(test, not(target_arch = "x86_64")))]
const USIZE_BYTES: usize = mem::size_of::<usize>();
// The fallback's unrolled main loop consumes two `usize` words per iteration.
#[cfg(any(test, not(target_arch = "x86_64")))]
const FALLBACK_LOOP_SIZE: usize = 2 * USIZE_BYTES;

// This is a mask where the most significant bit of each byte in the usize
// is set. We test this bit to determine whether a character is ASCII or not.
// Namely, a single byte is regarded as an ASCII codepoint if and only if its
// most significant bit is not set.
#[cfg(any(test, not(target_arch = "x86_64")))]
const ASCII_MASK_U64: u64 = 0x8080808080808080;
#[cfg(any(test, not(target_arch = "x86_64")))]
const ASCII_MASK: usize = ASCII_MASK_U64 as usize;
39 | ||
40 | /// Returns the index of the first non ASCII byte in the given slice. | |
41 | /// | |
42 | /// If slice only contains ASCII bytes, then the length of the slice is | |
43 | /// returned. | |
44 | pub fn first_non_ascii_byte(slice: &[u8]) -> usize { | |
45 | #[cfg(not(target_arch = "x86_64"))] | |
46 | { | |
47 | first_non_ascii_byte_fallback(slice) | |
48 | } | |
49 | ||
50 | #[cfg(target_arch = "x86_64")] | |
51 | { | |
52 | first_non_ascii_byte_sse2(slice) | |
53 | } | |
54 | } | |
55 | ||
/// Fallback implementation of `first_non_ascii_byte` for targets without
/// SSE2 (also compiled under `test` so it can be exercised everywhere).
///
/// Strategy: slices shorter than one word are scanned byte-at-a-time.
/// Otherwise, one unaligned `usize` is checked up front, the pointer is
/// rounded up to the next word boundary, and the main loop tests two
/// aligned words per iteration via `ASCII_MASK`. The tail is finished by
/// the byte-at-a-time loop.
#[cfg(any(test, not(target_arch = "x86_64")))]
fn first_non_ascii_byte_fallback(slice: &[u8]) -> usize {
    let align = USIZE_BYTES - 1;
    let start_ptr = slice.as_ptr();
    // One-past-the-end pointer; `slice[slice.len()..]` is the empty tail.
    let end_ptr = slice[slice.len()..].as_ptr();
    let mut ptr = start_ptr;

    unsafe {
        // Too short for even a single word read: go byte by byte.
        if slice.len() < USIZE_BYTES {
            return first_non_ascii_byte_slow(start_ptr, end_ptr, ptr);
        }

        // Check the first (possibly unaligned) word. A non-zero
        // `chunk & ASCII_MASK` means some byte has its high bit set.
        let chunk = read_unaligned_usize(ptr);
        let mask = chunk & ASCII_MASK;
        if mask != 0 {
            return first_non_ascii_byte_mask(mask);
        }

        // Advance to the next USIZE_BYTES-aligned address. This may re-scan
        // up to USIZE_BYTES - 1 bytes already checked above, which is fine
        // since they were all ASCII.
        ptr = ptr_add(ptr, USIZE_BYTES - (start_ptr as usize & align));
        debug_assert!(ptr > start_ptr);
        debug_assert!(ptr_sub(end_ptr, USIZE_BYTES) >= start_ptr);
        if slice.len() >= FALLBACK_LOOP_SIZE {
            // Main loop: two aligned word loads per iteration.
            while ptr <= ptr_sub(end_ptr, FALLBACK_LOOP_SIZE) {
                debug_assert_eq!(0, (ptr as usize) % USIZE_BYTES);

                let a = *(ptr as *const usize);
                let b = *(ptr_add(ptr, USIZE_BYTES) as *const usize);
                if (a | b) & ASCII_MASK != 0 {
                    // What a kludge. We wrap the position finding code into
                    // a non-inlineable function, which makes the codegen in
                    // the tight loop above a bit better by avoiding a
                    // couple extra movs. We pay for it by two additional
                    // stores, but only in the case of finding a non-ASCII
                    // byte.
                    #[inline(never)]
                    unsafe fn findpos(
                        start_ptr: *const u8,
                        ptr: *const u8,
                    ) -> usize {
                        // Re-load both words and locate the offending byte.
                        let a = *(ptr as *const usize);
                        let b = *(ptr_add(ptr, USIZE_BYTES) as *const usize);

                        let mut at = sub(ptr, start_ptr);
                        let maska = a & ASCII_MASK;
                        if maska != 0 {
                            return at + first_non_ascii_byte_mask(maska);
                        }

                        // Not in `a`, so it must be in `b` (the caller saw
                        // `(a | b) & ASCII_MASK != 0`).
                        at += USIZE_BYTES;
                        let maskb = b & ASCII_MASK;
                        debug_assert!(maskb != 0);
                        return at + first_non_ascii_byte_mask(maskb);
                    }
                    return findpos(start_ptr, ptr);
                }
                ptr = ptr_add(ptr, FALLBACK_LOOP_SIZE);
            }
        }
        // Fewer than FALLBACK_LOOP_SIZE bytes remain: finish byte by byte.
        first_non_ascii_byte_slow(start_ptr, end_ptr, ptr)
    }
}
117 | ||
/// SSE2 implementation of `first_non_ascii_byte` for x86_64 (SSE2 is part
/// of the x86_64 baseline, so no runtime feature detection is required).
///
/// `_mm_movemask_epi8` collects the most significant bit of each of the 16
/// bytes in a vector, so a non-zero mask means some byte is >= 0x80 and
/// `trailing_zeros` yields the index of the first such byte.
#[cfg(target_arch = "x86_64")]
fn first_non_ascii_byte_sse2(slice: &[u8]) -> usize {
    use core::arch::x86_64::*;

    const VECTOR_SIZE: usize = mem::size_of::<__m128i>();
    const VECTOR_ALIGN: usize = VECTOR_SIZE - 1;
    // The main loop is unrolled to process four vectors (64 bytes) per
    // iteration.
    const VECTOR_LOOP_SIZE: usize = 4 * VECTOR_SIZE;

    let start_ptr = slice.as_ptr();
    // One-past-the-end pointer; `slice[slice.len()..]` is the empty tail.
    let end_ptr = slice[slice.len()..].as_ptr();
    let mut ptr = start_ptr;

    unsafe {
        // Too short for even one vector load: go byte by byte.
        if slice.len() < VECTOR_SIZE {
            return first_non_ascii_byte_slow(start_ptr, end_ptr, ptr);
        }

        // Check the first (possibly unaligned) vector.
        let chunk = _mm_loadu_si128(ptr as *const __m128i);
        let mask = _mm_movemask_epi8(chunk);
        if mask != 0 {
            return mask.trailing_zeros() as usize;
        }

        // Advance to the next VECTOR_SIZE-aligned address. This may re-scan
        // up to VECTOR_SIZE - 1 bytes already checked above, which is fine
        // since they were all ASCII.
        ptr = ptr.add(VECTOR_SIZE - (start_ptr as usize & VECTOR_ALIGN));
        debug_assert!(ptr > start_ptr);
        debug_assert!(end_ptr.sub(VECTOR_SIZE) >= start_ptr);
        if slice.len() >= VECTOR_LOOP_SIZE {
            // Main loop: four aligned vector loads per iteration, OR'd
            // together so one movemask tests all 64 bytes at once.
            while ptr <= ptr_sub(end_ptr, VECTOR_LOOP_SIZE) {
                debug_assert_eq!(0, (ptr as usize) % VECTOR_SIZE);

                let a = _mm_load_si128(ptr as *const __m128i);
                let b = _mm_load_si128(ptr.add(VECTOR_SIZE) as *const __m128i);
                let c =
                    _mm_load_si128(ptr.add(2 * VECTOR_SIZE) as *const __m128i);
                let d =
                    _mm_load_si128(ptr.add(3 * VECTOR_SIZE) as *const __m128i);

                let or1 = _mm_or_si128(a, b);
                let or2 = _mm_or_si128(c, d);
                let or3 = _mm_or_si128(or1, or2);
                if _mm_movemask_epi8(or3) != 0 {
                    // Some byte in this 64-byte window is non-ASCII; re-test
                    // each vector in order to pin down its exact position.
                    let mut at = sub(ptr, start_ptr);
                    let mask = _mm_movemask_epi8(a);
                    if mask != 0 {
                        return at + mask.trailing_zeros() as usize;
                    }

                    at += VECTOR_SIZE;
                    let mask = _mm_movemask_epi8(b);
                    if mask != 0 {
                        return at + mask.trailing_zeros() as usize;
                    }

                    at += VECTOR_SIZE;
                    let mask = _mm_movemask_epi8(c);
                    if mask != 0 {
                        return at + mask.trailing_zeros() as usize;
                    }

                    // Must be in `d`, since the OR of all four was non-zero.
                    at += VECTOR_SIZE;
                    let mask = _mm_movemask_epi8(d);
                    debug_assert!(mask != 0);
                    return at + mask.trailing_zeros() as usize;
                }
                ptr = ptr_add(ptr, VECTOR_LOOP_SIZE);
            }
        }
        // Handle any remaining full vectors with unaligned loads, one at a
        // time.
        while ptr <= end_ptr.sub(VECTOR_SIZE) {
            debug_assert!(sub(end_ptr, ptr) >= VECTOR_SIZE);

            let chunk = _mm_loadu_si128(ptr as *const __m128i);
            let mask = _mm_movemask_epi8(chunk);
            if mask != 0 {
                return sub(ptr, start_ptr) + mask.trailing_zeros() as usize;
            }
            ptr = ptr.add(VECTOR_SIZE);
        }
        // Fewer than VECTOR_SIZE bytes remain: finish byte by byte.
        first_non_ascii_byte_slow(start_ptr, end_ptr, ptr)
    }
}
198 | ||
199 | #[inline(always)] | |
200 | unsafe fn first_non_ascii_byte_slow( | |
201 | start_ptr: *const u8, | |
202 | end_ptr: *const u8, | |
203 | mut ptr: *const u8, | |
204 | ) -> usize { | |
205 | debug_assert!(start_ptr <= ptr); | |
206 | debug_assert!(ptr <= end_ptr); | |
207 | ||
208 | while ptr < end_ptr { | |
209 | if *ptr > 0x7F { | |
210 | return sub(ptr, start_ptr); | |
211 | } | |
212 | ptr = ptr.offset(1); | |
213 | } | |
214 | sub(end_ptr, start_ptr) | |
215 | } | |
216 | ||
/// Compute the position of the first non-ASCII byte in the given mask.
///
/// The mask should be computed by `chunk & ASCII_MASK`, where `chunk` is
/// 8 contiguous bytes of the slice being checked where *at least* one of
/// those bytes is not an ASCII byte.
///
/// The position returned is always in the inclusive range [0, 7].
#[cfg(any(test, not(target_arch = "x86_64")))]
fn first_non_ascii_byte_mask(mask: usize) -> usize {
    // The byte lowest in memory sits in the least significant bits on
    // little-endian targets and the most significant bits on big-endian
    // ones, so count from the appropriate end and divide down to bytes.
    #[cfg(target_endian = "little")]
    let bits = mask.trailing_zeros() as usize;
    #[cfg(target_endian = "big")]
    let bits = mask.leading_zeros() as usize;
    bits / 8
}
235 | ||
/// Increment the given pointer by the given amount.
unsafe fn ptr_add(ptr: *const u8, amt: usize) -> *const u8 {
    // `add(amt)` is defined as `offset(amt as isize)`, so the amount must
    // fit in an `isize`.
    debug_assert!(amt < ::core::isize::MAX as usize);
    ptr.add(amt)
}
241 | ||
/// Decrement the given pointer by the given amount.
unsafe fn ptr_sub(ptr: *const u8, amt: usize) -> *const u8 {
    // `sub(amt)` is defined as `offset((amt as isize).wrapping_neg())`, so
    // the amount must fit in an `isize`.
    debug_assert!(amt < ::core::isize::MAX as usize);
    ptr.sub(amt)
}
247 | ||
/// Read a native-endian `usize` from `ptr` without any alignment
/// requirement.
///
/// Uses the standard library's `read_unaligned` rather than hand-rolling
/// the equivalent `copy_nonoverlapping` into a zeroed local; the semantics
/// are identical but the intent is explicit and no `mut` scratch value is
/// needed.
///
/// # Safety
///
/// `ptr` must be valid for reads of `mem::size_of::<usize>()` bytes.
#[cfg(any(test, not(target_arch = "x86_64")))]
unsafe fn read_unaligned_usize(ptr: *const u8) -> usize {
    core::ptr::read_unaligned(ptr as *const usize)
}
256 | ||
/// Subtract `b` from `a` and return the difference. `a` should be greater
/// than or equal to `b`.
fn sub(a: *const u8, b: *const u8) -> usize {
    debug_assert!(a >= b);
    let (hi, lo) = (a as usize, b as usize);
    hi - lo
}
263 | ||
#[cfg(test)]
mod tests {
    use super::*;

    // Our testing approach here is to try and exhaustively test every case.
    // This includes the position at which a non-ASCII byte occurs in addition
    // to the alignment of the slice that we're searching.

    // All-ASCII input of every length 0..517 (several multiples of the
    // unrolled loop sizes past 512) must report the slice's full length.
    #[test]
    fn positive_fallback_forward() {
        for i in 0..517 {
            let s = "a".repeat(i);
            assert_eq!(
                i,
                first_non_ascii_byte_fallback(s.as_bytes()),
                "i: {:?}, len: {:?}, s: {:?}",
                i,
                s.len(),
                s
            );
        }
    }

    // Same all-ASCII sweep for the SSE2 path (only meaningful on x86_64).
    #[test]
    #[cfg(target_arch = "x86_64")]
    fn positive_sse2_forward() {
        for i in 0..517 {
            let b = "a".repeat(i).into_bytes();
            assert_eq!(b.len(), first_non_ascii_byte_sse2(&b));
        }
    }

    // `i` ASCII bytes followed by non-ASCII (multi-byte snowmen), sliced at
    // every alignment 0..65 to exercise the prologue/loop/tail boundaries;
    // the first non-ASCII byte must be found at `i - align` (saturating).
    #[test]
    fn negative_fallback_forward() {
        for i in 0..517 {
            for align in 0..65 {
                let mut s = "a".repeat(i);
                s.push_str("☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃");
                // `get` fails when `align` splits a multi-byte snowman; in
                // that case test the empty slice instead.
                let s = s.get(align..).unwrap_or("");
                assert_eq!(
                    i.saturating_sub(align),
                    first_non_ascii_byte_fallback(s.as_bytes()),
                    "i: {:?}, align: {:?}, len: {:?}, s: {:?}",
                    i,
                    align,
                    s.len(),
                    s
                );
            }
        }
    }

    // Same non-ASCII/alignment sweep for the SSE2 path.
    #[test]
    #[cfg(target_arch = "x86_64")]
    fn negative_sse2_forward() {
        for i in 0..517 {
            for align in 0..65 {
                let mut s = "a".repeat(i);
                s.push_str("☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃☃");
                let s = s.get(align..).unwrap_or("");
                assert_eq!(
                    i.saturating_sub(align),
                    first_non_ascii_byte_sse2(s.as_bytes()),
                    "i: {:?}, align: {:?}, len: {:?}, s: {:?}",
                    i,
                    align,
                    s.len(),
                    s
                );
            }
        }
    }
}