// Copyright 2015 Google Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
//! Utility functions for HTML escaping. Only useful when building your own
//! HTML renderer.
24 use std
::fmt
::{Arguments, Write as FmtWrite}
;
25 use std
::io
::{self, ErrorKind, Write}
;
26 use std
::str::from_utf8
;
/// Lookup table over the 128 ASCII byte values: `1` marks a byte that may be
/// copied into an href verbatim, `0` marks a byte that has to be escaped.
/// Bytes `>= 0x80` are outside the table and must be handled by the caller.
#[rustfmt::skip]
static HREF_SAFE: [u8; 128] = [
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x00..=0x0F
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x10..=0x1F
    0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, // 0x20..=0x2F
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, // 0x30..=0x3F
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x40..=0x4F
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, // 0x50..=0x5F
    0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x60..=0x6F
    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, // 0x70..=0x7F
];
/// Upper-case hexadecimal digits, indexed by nibble value; used to build
/// percent-encoded bytes.
static HEX_CHARS: &[u8] = b"0123456789ABCDEF";
/// HTML entity written in place of a literal `&` inside an href.
static AMP_ESCAPE: &str = "&amp;";
/// HTML entity written in place of a literal `'` inside an href. Despite the
/// name, this escapes the single quote, not a slash.
static SLASH_ESCAPE: &str = "&#x27;";
/// This wrapper exists because we can't have both a blanket implementation
/// for all types implementing `Write` and types of the form `&mut W` where
/// `W: StrWrite`. Since we need the latter a lot, we choose to wrap
/// `Write` types.
pub struct WriteWrapper<W>(pub W);
/// Trait that allows writing string slices. This is basically an extension
/// of `std::io::Write` in order to include `String`.
pub trait StrWrite {
    /// Writes the whole string slice `s` to the sink.
    fn write_str(&mut self, s: &str) -> io::Result<()>;

    /// Writes pre-compiled format arguments to the sink. Having this method
    /// lets the `write!` macro be used on any `StrWrite` implementor.
    fn write_fmt(&mut self, args: Arguments) -> io::Result<()>;
}
58 impl<W
> StrWrite
for WriteWrapper
<W
>
63 fn write_str(&mut self, s
: &str) -> io
::Result
<()> {
64 self.0.write_all(s
.as_bytes())
68 fn write_fmt(&mut self, args
: Arguments
) -> io
::Result
<()> {
69 self.0.write_fmt(args
)
73 impl<'w
> StrWrite
for String
{
75 fn write_str(&mut self, s
: &str) -> io
::Result
<()> {
81 fn write_fmt(&mut self, args
: Arguments
) -> io
::Result
<()> {
82 // FIXME: translate fmt error to io error?
83 FmtWrite
::write_fmt(self, args
).map_err(|_
| ErrorKind
::Other
.into())
87 impl<W
> StrWrite
for &'_
mut W
92 fn write_str(&mut self, s
: &str) -> io
::Result
<()> {
97 fn write_fmt(&mut self, args
: Arguments
) -> io
::Result
<()> {
98 (**self).write_fmt(args
)
102 /// Writes an href to the buffer, escaping href unsafe bytes.
103 pub fn escape_href
<W
>(mut w
: W
, s
: &str) -> io
::Result
<()>
107 let bytes
= s
.as_bytes();
109 for i
in 0..bytes
.len() {
111 if c
>= 0x80 || HREF_SAFE
[c
as usize] == 0 {
112 // character needing escape
114 // write partial substring up to mark
116 w
.write_str(&s
[mark
..i
])?
;
120 w
.write_str(AMP_ESCAPE
)?
;
123 w
.write_str(SLASH_ESCAPE
)?
;
126 let mut buf
= [0u8; 3];
128 buf
[1] = HEX_CHARS
[((c
as usize) >> 4) & 0xF];
129 buf
[2] = HEX_CHARS
[(c
as usize) & 0xF];
130 let escaped
= from_utf8(&buf
).unwrap();
131 w
.write_str(escaped
)?
;
134 mark
= i
+ 1; // all escaped characters are ASCII
137 w
.write_str(&s
[mark
..])
/// Builds the byte classification table used for HTML escaping: every entry is
/// `0` except for the four special bytes, which map to the index of their
/// replacement (`"` -> 1, `&` -> 2, `<` -> 3, `>` -> 4).
const fn create_html_escape_table() -> [u8; 256] {
    let mut table = [0; 256];
    table[b'"' as usize] = 1;
    table[b'&' as usize] = 2;
    table[b'<' as usize] = 3;
    table[b'>' as usize] = 4;
    table
}

/// Escape class of every possible byte value, computed at compile time.
static HTML_ESCAPE_TABLE: [u8; 256] = create_html_escape_table();
/// Replacement text for each escape class stored in `HTML_ESCAPE_TABLE`.
/// Index 0 belongs to bytes that need no escaping and is never written;
/// indices 1 through 4 are the entities for `"`, `&`, `<` and `>`.
static HTML_ESCAPES: [&str; 5] = ["", "&quot;", "&amp;", "&lt;", "&gt;"];
153 /// Writes the given string to the Write sink, replacing special HTML bytes
154 /// (<, >, &, ") by escape sequences
.
155 pub fn escape_html
<W
: StrWrite
>(w
: W
, s
: &str) -> io
::Result
<()> {
156 #[cfg(all(target_arch = "x86_64", feature = "simd"))]
158 simd
::escape_html(w
, s
)
160 #[cfg(not(all(target_arch = "x86_64", feature = "simd")))]
162 escape_html_scalar(w
, s
)
166 fn escape_html_scalar
<W
: StrWrite
>(mut w
: W
, s
: &str) -> io
::Result
<()> {
167 let bytes
= s
.as_bytes();
173 .position(|&c
| HTML_ESCAPE_TABLE
[c
as usize] != 0)
181 let escape
= HTML_ESCAPE_TABLE
[c
as usize];
182 let escape_seq
= HTML_ESCAPES
[escape
as usize];
183 w
.write_str(&s
[mark
..i
])?
;
184 w
.write_str(escape_seq
)?
;
186 mark
= i
; // all escaped characters are ASCII
188 w
.write_str(&s
[mark
..])
191 #[cfg(all(target_arch = "x86_64", feature = "simd"))]
194 use std
::arch
::x86_64
::*;
196 use std
::mem
::size_of
;
/// Number of bytes processed per SIMD step: the size of one `__m128i` (16).
const VECTOR_SIZE: usize = size_of::<__m128i>();
200 pub(crate) fn escape_html
<W
: StrWrite
>(mut w
: W
, s
: &str) -> io
::Result
<()> {
201 // The SIMD accelerated code uses the PSHUFB instruction, which is part
202 // of the SSSE3 instruction set. Further, we can only use this code if
203 // the buffer is at least one VECTOR_SIZE in length to prevent reading
204 // out of bounds. If either of these conditions is not met, we fall back
206 if is_x86_feature_detected
!("ssse3") && s
.len() >= VECTOR_SIZE
{
207 let bytes
= s
.as_bytes();
211 foreach_special_simd(bytes
, 0, |i
| {
212 let escape_ix
= *bytes
.get_unchecked(i
) as usize;
214 super::HTML_ESCAPES
[super::HTML_ESCAPE_TABLE
[escape_ix
] as usize];
215 w
.write_str(&s
.get_unchecked(mark
..i
))?
;
216 mark
= i
+ 1; // all escaped characters are ASCII
217 w
.write_str(replacement
)
219 w
.write_str(&s
.get_unchecked(mark
..))
222 super::escape_html_scalar(w
, s
)
/// Creates the lookup table for use in `compute_mask`.
///
/// The table is indexed by the low nibble of a byte. The slot for the low
/// nibble of each HTML special byte holds that byte itself, slot 0 holds 127,
/// and every other slot holds 0.
const fn create_lookup() -> [u8; 16] {
    let mut table = [0; 16];
    table[(b'<' & 0x0f) as usize] = b'<';
    table[(b'>' & 0x0f) as usize] = b'>';
    table[(b'&' & 0x0f) as usize] = b'&';
    table[(b'"' & 0x0f) as usize] = b'"';
    // No input byte can compare equal to its own lookup result through this
    // slot: bytes with a zero low nibble are never 127 (0x7F).
    table[0] = 0b0111_1111;
    table
}
237 #[target_feature(enable = "ssse3")]
238 /// Computes a byte mask at given offset in the byte buffer. Its first 16 (least significant)
239 /// bits correspond to whether there is an HTML special byte (&, <, ", >) at the 16 bytes
240 /// `bytes[offset..]`. For example, the mask `(1 << 3)` states that there is an HTML byte
241 /// at `offset + 3`. It is only safe to call this function when
242 /// `bytes.len() >= offset + VECTOR_SIZE`.
243 unsafe fn compute_mask(bytes
: &[u8], offset
: usize) -> i32 {
244 debug_assert
!(bytes
.len() >= offset
+ VECTOR_SIZE
);
246 let table
= create_lookup();
247 let lookup
= _mm_loadu_si128(table
.as_ptr() as *const __m128i
);
248 let raw_ptr
= bytes
.as_ptr().offset(offset
as isize) as *const __m128i
;
250 // Load the vector from memory.
251 let vector
= _mm_loadu_si128(raw_ptr
);
252 // We take the least significant 4 bits of every byte and use them as indices
253 // to map into the lookup vector.
254 // Note that shuffle maps bytes with their most significant bit set to lookup[0].
255 // Bytes that share their lower nibble with an HTML special byte get mapped to that
256 // corresponding special byte. Note that all HTML special bytes have distinct lower
257 // nibbles. Other bytes either get mapped to 0 or 127.
258 let expected
= _mm_shuffle_epi8(lookup
, vector
);
259 // We compare the original vector to the mapped output. Bytes that shared a lower
260 // nibble with an HTML special byte match *only* if they are that special byte. Bytes
261 // that have either a 0 lower nibble or their most significant bit set were mapped to
262 // 127 and will hence never match. All other bytes have non-zero lower nibbles but
263 // were mapped to 0 and will therefore also not match.
264 let matches
= _mm_cmpeq_epi8(expected
, vector
);
266 // Translate matches to a bitmask, where every 1 corresponds to a HTML special character
267 // and a 0 is a non-HTML byte.
268 _mm_movemask_epi8(matches
)
271 /// Calls the given function with the index of every byte in the given byteslice
272 /// that is either ", &, <, or > and for no other byte.
273 /// Make sure to only call this when `bytes.len() >= 16`, undefined behaviour may
275 #[target_feature(enable = "ssse3")]
276 unsafe fn foreach_special_simd
<F
>(
282 F
: FnMut(usize) -> io
::Result
<()>,
284 // The strategy here is to walk the byte buffer in chunks of VECTOR_SIZE (16)
285 // bytes at a time starting at the given offset. For each chunk, we compute a
286 // a bitmask indicating whether the corresponding byte is a HTML special byte.
287 // We then iterate over all the 1 bits in this mask and call the callback function
288 // with the corresponding index in the buffer.
289 // When the number of HTML special bytes in the buffer is relatively low, this
290 // allows us to quickly go through the buffer without a lookup and for every
293 debug_assert
!(bytes
.len() >= VECTOR_SIZE
);
294 let upperbound
= bytes
.len() - VECTOR_SIZE
;
295 while offset
< upperbound
{
296 let mut mask
= compute_mask(bytes
, offset
);
298 let ix
= mask
.trailing_zeros();
299 callback(offset
+ ix
as usize)?
;
300 mask ^
= mask
& -mask
;
302 offset
+= VECTOR_SIZE
;
305 // Final iteration. We align the read with the end of the slice and
306 // shift off the bytes at start we have already scanned.
307 let mut mask
= compute_mask(bytes
, upperbound
);
308 mask
>>= offset
- upperbound
;
310 let ix
= mask
.trailing_zeros();
311 callback(offset
+ ix
as usize)?
;
312 mask ^
= mask
& -mask
;
318 mod html_scan_tests
{
321 let mut vec
= Vec
::new();
323 super::foreach_special_simd("&aXaaaa.a'aa9a<>aab&".as_bytes(), 0, |ix
| {
328 assert_eq
!(vec
, vec
![0, 14, 15, 19]);
331 // only match these bytes, and when we match them, match them VECTOR_SIZE times
333 fn only_right_bytes_matched() {
335 let right_byte
= b
== b'
&'
|| b
== b'
<'
|| b
== b'
>'
|| b
== b'
"';
336 let vek = vec![b; super::VECTOR_SIZE];
337 let mut match_count = 0;
339 super::foreach_special_simd(&vek, 0, |_| {
345 assert!((match_count > 0) == (match_count == super::VECTOR_SIZE));
347 (match_count == super::VECTOR_SIZE),
349 "match_count
: {}
, byte
: {:?}
",