Skip to content

Commit 82ec7a8

Browse files
committed
Avoid returning NaN from Gamma::sample
This changes the order of multiplications used to compute the result to avoid multiplying zero with an expression that can overflow to +inf. Note that the parameter combinations which could lead to this (shape very close to zero, scale very close to the max float value) continue to be inaccurately handled; the Gamma distribution sampler will now tend to return zero instead of NaN for them. The limit (shape 0, scale inf) is not well defined.
1 parent a6a9f7b commit 82ec7a8

File tree

2 files changed

+43
-22
lines changed

2 files changed

+43
-22
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2222
- Fix panic in `FisherF::new` on almost zero parameters (#39)
2323
- Fix panic in `NormalInverseGaussian::new` with very large `alpha`; this is a Value-breaking change (#40)
2424
- Error instead of producing `-inf` output for `Exp` when `lambda` is `-0.0` (#44)
25+
- Avoid returning NaN from `Gamma::sample`; this is a Value-breaking change and also affects `ChiSquared` and `Dirichlet` (#46)
2526

2627
## [0.5.2]
2728

src/gamma.rs

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,10 @@ where
173173
return Err(Error::ScaleTooSmall);
174174
}
175175

176-
let repr = if shape == F::one() {
177-
One(Exp::new(F::one() / scale).map_err(|_| Error::ScaleTooLarge)?)
176+
let repr = if shape == F::infinity() || scale == F::infinity() {
177+
One(Exp::new(F::zero()).unwrap())
178+
} else if shape == F::one() {
179+
One(Exp::new(F::one() / scale).unwrap())
178180
} else if shape < F::one() {
179181
Small(GammaSmallShape::new_raw(shape, scale))
180182
} else {
@@ -212,6 +214,28 @@ where
212214
d,
213215
}
214216
}
217+
218+
fn sample_unscaled<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
219+
// Marsaglia & Tsang method, 2000
220+
loop {
221+
let x: F = rng.sample(StandardNormal);
222+
let v_cbrt = F::one() + self.c * x;
223+
if v_cbrt <= F::zero() {
224+
continue;
225+
}
226+
227+
let v = v_cbrt * v_cbrt * v_cbrt;
228+
let u: F = rng.sample(Open01);
229+
230+
let x_sqr = x * x;
231+
if u < F::one() - F::from(0.0331).unwrap() * x_sqr * x_sqr
232+
|| u.ln() < F::from(0.5).unwrap() * x_sqr + self.d * (F::one() - v + v.ln())
233+
{
234+
// `x` is concentrated enough that `v` should always be finite
235+
return v;
236+
}
237+
}
238+
}
215239
}
216240

217241
impl<F> Distribution<F> for Gamma<F>
@@ -238,35 +262,22 @@ where
238262
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
239263
let u: F = rng.sample(Open01);
240264

241-
self.large_shape.sample(rng) * u.powf(self.inv_shape)
265+
let a = self.large_shape.sample_unscaled(rng);
266+
let b = u.powf(self.inv_shape);
267+
// Multiplying numbers with `scale` can overflow, so do it last to avoid
268+
// producing NaN = inf * 0.0. All the other terms are finite and small.
269+
(a * b * self.large_shape.d) * self.large_shape.scale
242270
}
243271
}
272+
244273
impl<F> Distribution<F> for GammaLargeShape<F>
245274
where
246275
F: Float,
247276
StandardNormal: Distribution<F>,
248277
Open01: Distribution<F>,
249278
{
250279
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
251-
// Marsaglia & Tsang method, 2000
252-
loop {
253-
let x: F = rng.sample(StandardNormal);
254-
let v_cbrt = F::one() + self.c * x;
255-
if v_cbrt <= F::zero() {
256-
// a^3 <= 0 iff a <= 0
257-
continue;
258-
}
259-
260-
let v = v_cbrt * v_cbrt * v_cbrt;
261-
let u: F = rng.sample(Open01);
262-
263-
let x_sqr = x * x;
264-
if u < F::one() - F::from(0.0331).unwrap() * x_sqr * x_sqr
265-
|| u.ln() < F::from(0.5).unwrap() * x_sqr + self.d * (F::one() - v + v.ln())
266-
{
267-
return self.d * v * self.scale;
268-
}
269-
}
280+
self.sample_unscaled(rng) * (self.d * self.scale)
270281
}
271282
}
272283

@@ -278,4 +289,13 @@ mod test {
278289
fn gamma_distributions_can_be_compared() {
279290
assert_eq!(Gamma::new(1.0, 2.0), Gamma::new(1.0, 2.0));
280291
}
292+
293+
#[test]
294+
fn gamma_extreme_values() {
295+
let d = Gamma::new(f64::infinity(), 2.0).unwrap();
296+
assert_eq!(d.sample(&mut crate::test::rng(21)), f64::infinity());
297+
298+
let d = Gamma::new(2.0, f64::infinity()).unwrap();
299+
assert_eq!(d.sample(&mut crate::test::rng(21)), f64::infinity());
300+
}
281301
}

0 commit comments

Comments
 (0)