]> git.proxmox.com Git - perlmod.git/blob - perlmod/src/de.rs
fbaeaf7069b1a638e8887899d8f59897426521a5
[perlmod.git] / perlmod / src / de.rs
1 //! Serde deserializer for perl values.
2
3 use std::marker::PhantomData;
4
5 use serde::de::{self, Deserialize, DeserializeSeed, MapAccess, SeqAccess, Visitor};
6
7 use crate::error::Error;
8 use crate::scalar::Type;
9 use crate::Value;
10 use crate::{array, ffi, hash};
11
12 /// Perl [`Value`](crate::Value) deserializer.
13 struct Deserializer<'de> {
14 input: Value,
15 option_allowed: bool,
16 _lifetime: PhantomData<&'de Value>,
17 }
18
19 /// Deserialize a perl [`Value`](crate::Value).
20 ///
21 /// Note that this causes all the underlying data to be copied recursively.
22 pub fn from_value<T>(input: Value) -> Result<T, Error>
23 where
24 T: serde::de::DeserializeOwned,
25 {
26 let mut deserializer = Deserializer::<'static>::from_value(input);
27 let out = T::deserialize(&mut deserializer)?;
28 Ok(out)
29 }
30
31 /// Deserialize a reference to a perl [`Value`](crate::Value).
32 ///
33 /// Note that this causes all the underlying data to be copied recursively, except for data
34 /// deserialized to `&[u8]` or `&str`, which will reference the "original" value (whatever that
35 /// means for perl).
36 pub fn from_ref_value<'de, T>(input: &'de Value) -> Result<T, Error>
37 where
38 T: Deserialize<'de>,
39 {
40 let mut deserializer = Deserializer::<'de>::from_value(input.clone_ref());
41 let out = T::deserialize(&mut deserializer)?;
42 Ok(out)
43 }
44
45 impl<'deserializer> Deserializer<'deserializer> {
46 pub fn from_value(input: Value) -> Self {
47 Deserializer {
48 input,
49 option_allowed: true,
50 _lifetime: PhantomData,
51 }
52 }
53
54 fn deref_current(&mut self) -> Result<(), Error> {
55 while let Value::Reference(_) = &self.input {
56 self.input = self.input.dereference().ok_or_else(|| {
57 Error::new("failed to dereference a reference while deserializing")
58 })?;
59 }
60 Ok(())
61 }
62
63 fn sanity_check(&mut self) -> Result<(), Error> {
64 if let Value::Scalar(value) = &self.input {
65 match value.ty() {
66 Type::Scalar(_) => Ok(()),
67 Type::Other(other) => Err(Error(format!(
68 "cannot deserialize weird magic perl values ({})",
69 other
70 ))),
71 // These are impossible as they are all handled by different Value enum types:
72 Type::Reference => Error::fail("Value::Scalar: containing a reference"),
73 Type::Array => Error::fail("Value::Scalar: containing an array"),
74 Type::Hash => Error::fail("Value::Scalar: containing a hash"),
75 }
76 } else {
77 Ok(())
78 }
79 }
80
81 fn get(&mut self) -> Result<&Value, Error> {
82 self.deref_current()?;
83 self.sanity_check()?;
84 Ok(&self.input)
85 }
86
87 /// deserialize_any, preferring a string value
88 fn deserialize_any_string<'de, V>(&mut self, visitor: V) -> Result<V::Value, Error>
89 where
90 V: Visitor<'de>,
91 {
92 match self.get()? {
93 Value::Scalar(value) => match value.ty() {
94 Type::Scalar(flags) => {
95 use crate::scalar::Flags;
96
97 if flags.contains(Flags::STRING) {
98 let s = unsafe { str_set_wrong_lifetime(value.pv_string_utf8()) };
99 visitor.visit_borrowed_str(s)
100 } else if flags.contains(Flags::DOUBLE) {
101 visitor.visit_f64(value.nv())
102 } else if flags.contains(Flags::INTEGER) {
103 visitor.visit_i64(value.iv() as i64)
104 } else if flags.is_empty() {
105 visitor.visit_none()
106 } else {
107 visitor.visit_unit()
108 }
109 }
110 _ => unreachable!(),
111 },
112 Value::Hash(value) => visitor.visit_map(HashAccess::new(value)),
113 Value::Array(value) => visitor.visit_seq(ArrayAccess::new(value)),
114 Value::Reference(_) => unreachable!(),
115 }
116 }
117
118 /// deserialize_any, preferring an integer value
119 fn deserialize_any_iv<'de, V>(&mut self, visitor: V) -> Result<V::Value, Error>
120 where
121 V: Visitor<'de>,
122 {
123 match self.get()? {
124 Value::Scalar(value) => match value.ty() {
125 Type::Scalar(flags) => {
126 use crate::scalar::Flags;
127
128 if flags.contains(Flags::INTEGER) {
129 visitor.visit_i64(value.iv() as i64)
130 } else if flags.contains(Flags::DOUBLE) {
131 visitor.visit_f64(value.nv())
132 } else if flags.contains(Flags::STRING) {
133 let s = unsafe { str_set_wrong_lifetime(value.pv_string_utf8()) };
134 visitor.visit_borrowed_str(s)
135 } else {
136 visitor.visit_unit()
137 }
138 }
139 _ => unreachable!(),
140 },
141 Value::Hash(value) => visitor.visit_map(HashAccess::new(value)),
142 Value::Array(value) => visitor.visit_seq(ArrayAccess::new(value)),
143 Value::Reference(_) => unreachable!(),
144 }
145 }
146
147 /// deserialize_any, preferring a float value
148 fn deserialize_any_nv<'de, V>(&mut self, visitor: V) -> Result<V::Value, Error>
149 where
150 V: Visitor<'de>,
151 {
152 match self.get()? {
153 Value::Scalar(value) => match value.ty() {
154 Type::Scalar(flags) => {
155 use crate::scalar::Flags;
156
157 if flags.contains(Flags::DOUBLE) {
158 visitor.visit_f64(value.nv())
159 } else if flags.contains(Flags::INTEGER) {
160 visitor.visit_i64(value.iv() as i64)
161 } else if flags.contains(Flags::STRING) {
162 let s = unsafe { str_set_wrong_lifetime(value.pv_string_utf8()) };
163 visitor.visit_borrowed_str(s)
164 } else {
165 visitor.visit_unit()
166 }
167 }
168 _ => unreachable!(),
169 },
170 Value::Hash(value) => visitor.visit_map(HashAccess::new(value)),
171 Value::Array(value) => visitor.visit_seq(ArrayAccess::new(value)),
172 Value::Reference(_) => unreachable!(),
173 }
174 }
175 }
176
/// Re-label a string slice with an arbitrary lifetime `'b`.
///
/// This is used only for `Value`s in our deserializer. Serde requires borrowed data to outlive
/// the deserializer, and the deserializer is fed from a borrowed `Value`. That `Value` keeps
/// references to all of its contained data within perl and outlives the deserializer.
///
/// # Safety
///
/// The caller must guarantee that the memory behind `s` stays valid and unmodified for as long
/// as the returned reference is used.
unsafe fn str_set_wrong_lifetime<'a, 'b>(s: &'a str) -> &'b str {
    &*(s as *const str)
}
184
185 impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
186 type Error = Error;
187
188 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Error>
189 where
190 V: Visitor<'de>,
191 {
192 self.deserialize_any_string(visitor)
193 }
194
195 fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Error>
196 where
197 V: Visitor<'de>,
198 {
199 match self.get()? {
200 Value::Scalar(value) => match value.ty() {
201 Type::Scalar(flags) => {
202 use crate::scalar::Flags;
203
204 if flags.is_empty() || flags.intersects(Flags::INTEGER | Flags::DOUBLE) {
205 visitor.visit_bool(unsafe { ffi::RSPL_SvTRUE(value.sv()) })
206 } else {
207 Error::fail("expected bool value")
208 }
209 }
210 _ => unreachable!(),
211 },
212 Value::Hash(value) => visitor.visit_map(HashAccess::new(value)),
213 Value::Array(value) => visitor.visit_seq(ArrayAccess::new(value)),
214 Value::Reference(_) => unreachable!(),
215 }
216 }
217
218 fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Error>
219 where
220 V: Visitor<'de>,
221 {
222 self.deserialize_any_iv(visitor)
223 }
224
225 fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Error>
226 where
227 V: Visitor<'de>,
228 {
229 self.deserialize_any_iv(visitor)
230 }
231
232 fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Error>
233 where
234 V: Visitor<'de>,
235 {
236 self.deserialize_any_iv(visitor)
237 }
238
239 fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Error>
240 where
241 V: Visitor<'de>,
242 {
243 self.deserialize_any_iv(visitor)
244 }
245
246 fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Error>
247 where
248 V: Visitor<'de>,
249 {
250 self.deserialize_any_iv(visitor)
251 }
252
253 fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Error>
254 where
255 V: Visitor<'de>,
256 {
257 self.deserialize_any_iv(visitor)
258 }
259
260 fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Error>
261 where
262 V: Visitor<'de>,
263 {
264 self.deserialize_any_iv(visitor)
265 }
266
267 fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Error>
268 where
269 V: Visitor<'de>,
270 {
271 self.deserialize_any_iv(visitor)
272 }
273
274 fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Error>
275 where
276 V: Visitor<'de>,
277 {
278 self.deserialize_any_nv(visitor)
279 }
280
281 fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Error>
282 where
283 V: Visitor<'de>,
284 {
285 self.deserialize_any_nv(visitor)
286 }
287
288 fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Error>
289 where
290 V: Visitor<'de>,
291 {
292 match self.get()? {
293 Value::Scalar(value) => match value.ty() {
294 Type::Scalar(flags) => {
295 use crate::scalar::Flags;
296
297 if flags.contains(Flags::INTEGER) {
298 let c = value.iv();
299 if c < 0x100 {
300 visitor.visit_char(c as u8 as char)
301 } else {
302 visitor.visit_i64(c as i64)
303 }
304 } else if flags.contains(Flags::DOUBLE) {
305 visitor.visit_f64(value.nv())
306 } else if flags.contains(Flags::STRING) {
307 let s = value.pv_string_utf8();
308 let mut chars = s.chars();
309 match chars.next() {
310 Some(ch) if chars.next().is_none() => visitor.visit_char(ch),
311 _ => {
312 let s = unsafe { str_set_wrong_lifetime(value.pv_string_utf8()) };
313 visitor.visit_borrowed_str(s)
314 }
315 }
316 } else {
317 visitor.visit_unit()
318 }
319 }
320 _ => unreachable!(),
321 },
322 Value::Hash(value) => visitor.visit_map(HashAccess::new(value)),
323 Value::Array(value) => visitor.visit_seq(ArrayAccess::new(value)),
324 Value::Reference(_) => unreachable!(),
325 }
326 }
327
328 fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Error>
329 where
330 V: Visitor<'de>,
331 {
332 self.deserialize_any(visitor)
333 }
334
335 fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Error>
336 where
337 V: Visitor<'de>,
338 {
339 self.deserialize_any(visitor)
340 }
341
342 fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Error>
343 where
344 V: Visitor<'de>,
345 {
346 match self.get()? {
347 Value::Scalar(value) => match value.ty() {
348 Type::Scalar(flags) => {
349 use crate::scalar::Flags;
350
351 if flags.contains(Flags::STRING) {
352 let bytes = value.pv_bytes();
353 let bytes: &'de [u8] =
354 unsafe { std::slice::from_raw_parts(bytes.as_ptr(), bytes.len()) };
355 visitor.visit_borrowed_bytes(bytes)
356 } else if flags.contains(Flags::DOUBLE) {
357 visitor.visit_f64(value.nv())
358 } else if flags.contains(Flags::INTEGER) {
359 visitor.visit_i64(value.iv() as i64)
360 } else {
361 visitor.visit_unit()
362 }
363 }
364 _ => unreachable!(),
365 },
366 Value::Hash(value) => visitor.visit_map(HashAccess::new(value)),
367 Value::Array(value) => visitor.visit_seq(ArrayAccess::new(value)),
368 Value::Reference(_) => unreachable!(),
369 }
370 }
371
372 fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Error>
373 where
374 V: Visitor<'de>,
375 {
376 self.deserialize_bytes(visitor)
377 }
378
379 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Error>
380 where
381 V: Visitor<'de>,
382 {
383 if self.option_allowed {
384 if let Value::Scalar(value) = self.get()? {
385 if let Type::Scalar(flags) = value.ty() {
386 if flags.is_empty() {
387 return visitor.visit_none();
388 }
389 }
390 }
391 self.option_allowed = false;
392 let res = visitor.visit_some(&mut *self);
393 self.option_allowed = true;
394 res
395 } else {
396 self.deserialize_any(visitor)
397 }
398 }
399
400 fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Error>
401 where
402 V: Visitor<'de>,
403 {
404 self.deserialize_any(visitor)
405 }
406
407 fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value, Error>
408 where
409 V: Visitor<'de>,
410 {
411 self.deserialize_any(visitor)
412 }
413
414 fn deserialize_newtype_struct<V>(
415 self,
416 _name: &'static str,
417 visitor: V,
418 ) -> Result<V::Value, Error>
419 where
420 V: Visitor<'de>,
421 {
422 self.deserialize_any(visitor)
423 }
424
425 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Error>
426 where
427 V: Visitor<'de>,
428 {
429 self.deserialize_any(visitor)
430 }
431
432 fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Error>
433 where
434 V: Visitor<'de>,
435 {
436 self.deserialize_any(visitor)
437 }
438
439 fn deserialize_tuple_struct<V>(
440 self,
441 _name: &'static str,
442 _len: usize,
443 visitor: V,
444 ) -> Result<V::Value, Error>
445 where
446 V: Visitor<'de>,
447 {
448 self.deserialize_any(visitor)
449 }
450
451 fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Error>
452 where
453 V: Visitor<'de>,
454 {
455 self.deserialize_any(visitor)
456 }
457
458 fn deserialize_struct<V>(
459 self,
460 _name: &'static str,
461 _fields: &'static [&'static str],
462 visitor: V,
463 ) -> Result<V::Value, Error>
464 where
465 V: Visitor<'de>,
466 {
467 self.deserialize_map(visitor)
468 }
469
470 fn deserialize_enum<V>(
471 self,
472 _name: &'static str,
473 _variants: &'static [&'static str],
474 visitor: V,
475 ) -> Result<V::Value, Error>
476 where
477 V: Visitor<'de>,
478 {
479 self.deserialize_any(visitor)
480 }
481
482 fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Error>
483 where
484 V: Visitor<'de>,
485 {
486 self.deserialize_any(visitor)
487 }
488
489 fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Error>
490 where
491 V: Visitor<'de>,
492 {
493 self.deserialize_any(visitor)
494 }
495 }
496
497 /// Serde `MapAccess` intermediate type.
498 pub struct HashAccess<'a> {
499 hash: &'a hash::Hash,
500 entry: *mut ffi::HE,
501 finished: bool,
502 at_value: bool,
503 }
504
505 impl<'a> HashAccess<'a> {
506 pub fn new(value: &'a hash::Hash) -> Self {
507 drop(value.shared_iter()); // reset iterator
508 Self {
509 hash: &value,
510 entry: std::ptr::null_mut(),
511 finished: false,
512 at_value: false,
513 }
514 }
515 }
516
517 impl<'de, 'a> MapAccess<'de> for HashAccess<'a> {
518 type Error = Error;
519
520 fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Error>
521 where
522 K: DeserializeSeed<'de>,
523 {
524 if self.finished {
525 return Ok(None);
526 }
527
528 if self.entry.is_null() {
529 self.entry = unsafe { ffi::RSPL_hv_iternext(self.hash.hv()) };
530 if self.entry.is_null() {
531 self.finished = true;
532 return Ok(None);
533 }
534 } else if self.at_value {
535 return Error::fail("map access value skipped");
536 }
537
538 self.at_value = true;
539
540 let key = unsafe { Value::from_raw_ref(ffi::RSPL_hv_iterkeysv(self.entry)) };
541 seed.deserialize(&mut Deserializer::from_value(key))
542 .map(Some)
543 }
544
545 fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Error>
546 where
547 V: DeserializeSeed<'de>,
548 {
549 if self.finished {
550 return Error::fail("map access value requested after end");
551 }
552
553 if self.entry.is_null() || !self.at_value {
554 return Error::fail("map access key skipped");
555 }
556
557 self.at_value = false;
558
559 let value =
560 unsafe { Value::from_raw_ref(ffi::RSPL_hv_iterval(self.hash.hv(), self.entry)) };
561 self.entry = std::ptr::null_mut();
562
563 seed.deserialize(&mut Deserializer::from_value(value))
564 }
565 }
566
567 /// Serde `SeqAccess` intermediate type.
568 pub struct ArrayAccess<'a> {
569 iter: array::Iter<'a>,
570 }
571
572 impl<'a> ArrayAccess<'a> {
573 pub fn new(value: &'a array::Array) -> Self {
574 Self { iter: value.iter() }
575 }
576 }
577
578 impl<'de, 'a> SeqAccess<'de> for ArrayAccess<'a> {
579 type Error = Error;
580
581 fn next_element_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Error>
582 where
583 K: DeserializeSeed<'de>,
584 {
585 self.iter
586 .next()
587 .map(move |value| seed.deserialize(&mut Deserializer::from_value(value)))
588 .transpose()
589 }
590 }