]> git.proxmox.com Git - rustc.git/blob - vendor/rand/src/distributions/dirichlet.rs
New upstream version 1.51.0+dfsg1
[rustc.git] / vendor / rand / src / distributions / dirichlet.rs
1 // Copyright 2018 Developers of the Rand project.
2 // Copyright 2013 The Rust Project Developers.
3 //
4 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
5 // https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6 // <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
7 // option. This file may not be copied, modified, or distributed
8 // except according to those terms.
9
//! The Dirichlet distribution.
11 #![allow(deprecated)]
12 #![allow(clippy::all)]
13
14 use crate::distributions::gamma::Gamma;
15 use crate::distributions::Distribution;
16 use crate::Rng;
17
/// The Dirichlet distribution `Dirichlet(alpha)`.
///
/// The Dirichlet distribution is a family of continuous multivariate
/// probability distributions parameterized by a vector alpha of positive reals.
/// It is a multivariate generalization of the beta distribution.
#[deprecated(since = "0.7.0", note = "moved to rand_distr crate")]
#[derive(Clone, Debug)]
pub struct Dirichlet {
    /// Concentration parameters (alpha). The constructors guarantee there are
    /// at least two of them and that every entry is strictly positive.
    alpha: Vec<f64>,
}

impl Dirichlet {
    /// Construct a new `Dirichlet` with the given alpha parameter `alpha`.
    ///
    /// # Panics
    /// - if `alpha.len() < 2`
    /// - if any element of `alpha` is not strictly positive (`<= 0.0` or NaN)
    #[inline]
    pub fn new<V: Into<Vec<f64>>>(alpha: V) -> Dirichlet {
        let a = alpha.into();
        assert!(a.len() > 1);
        // `x > 0.0` is false for NaN, so NaN entries are rejected as well.
        assert!(a.iter().all(|&x| x > 0.0));

        Dirichlet { alpha: a }
    }

    /// Construct a new `Dirichlet` with the given shape parameter `alpha` and `size`.
    ///
    /// Every one of the `size` concentration parameters is set to `alpha`.
    ///
    /// # Panics
    /// - if `alpha` is not strictly positive (`alpha <= 0.0` or NaN)
    /// - if `size < 2`
    #[inline]
    pub fn new_with_param(alpha: f64, size: usize) -> Dirichlet {
        assert!(alpha > 0.0);
        assert!(size > 1);
        Dirichlet {
            alpha: vec![alpha; size],
        }
    }
}
60
61 impl Distribution<Vec<f64>> for Dirichlet {
62 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<f64> {
63 let n = self.alpha.len();
64 let mut samples = vec![0.0f64; n];
65 let mut sum = 0.0f64;
66
67 for i in 0..n {
68 let g = Gamma::new(self.alpha[i], 1.0);
69 samples[i] = g.sample(rng);
70 sum += samples[i];
71 }
72 let invacc = 1.0 / sum;
73 for i in 0..n {
74 samples[i] *= invacc;
75 }
76 samples
77 }
78 }
79
#[cfg(test)]
mod test {
    use super::Dirichlet;
    use crate::distributions::Distribution;

    /// Asserts that `samples` is a valid point on the probability simplex:
    /// exactly `expected_len` strictly positive coordinates that sum to one
    /// (up to floating-point rounding).
    fn assert_on_simplex(samples: &[f64], expected_len: usize) {
        assert_eq!(samples.len(), expected_len);
        for &x in samples {
            assert!(x > 0.0);
        }
        let total: f64 = samples.iter().sum();
        assert!((total - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_dirichlet() {
        let d = Dirichlet::new(vec![1.0, 2.0, 3.0]);
        let mut rng = crate::test::rng(221);
        let samples = d.sample(&mut rng);
        assert_on_simplex(&samples, 3);
    }

    #[test]
    fn test_dirichlet_with_param() {
        let alpha = 0.5f64;
        let size = 2;
        let d = Dirichlet::new_with_param(alpha, size);
        let mut rng = crate::test::rng(221);
        let samples = d.sample(&mut rng);
        assert_on_simplex(&samples, size);
    }

    #[test]
    #[should_panic]
    fn test_dirichlet_invalid_length() {
        Dirichlet::new_with_param(0.5f64, 1);
    }

    #[test]
    #[should_panic]
    fn test_dirichlet_invalid_alpha() {
        Dirichlet::new_with_param(0.0f64, 2);
    }

    #[test]
    #[should_panic]
    fn test_dirichlet_new_invalid_length() {
        Dirichlet::new(vec![1.0]);
    }

    #[test]
    #[should_panic]
    fn test_dirichlet_new_invalid_alpha() {
        Dirichlet::new(vec![1.0, 0.0, 2.0]);
    }
}