Efficiently Sampling a Conditioned Distribution
In this lesson, we will learn how to implement the conditioned distribution from the previous lesson without rejection sampling, so that we avoid its possibly long-running loop.
In the previous lesson, we concluded that the predicates and projections used by Where and Select on discrete distributions must be pure functions:
- They must complete normally, produce no side effects, consume no external state, and never change their behavior.
- They must produce the same result when called with the same argument, every time.
If we impose these restrictions, we can get some big performance wins out of Where and Select. Let’s see how.
Dealing With the “Long-Running Loop”
The biggest problem we face is the possibly long-running loop in Where. We are “rejection sampling” the distribution, and if the predicate rejects most samples, that can take a long time. Is there a way to directly produce a new distribution that can be efficiently sampled?
Of course, there is. Let’s make a helper method:
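public static IDiscreteDistribution<T> ToWeighted<T>(this IEnumerable<T> items, IEnumerable<int> weights)
{
    // Materialize the items once so that each index lookup is cheap.
    var list = items.ToList();
    // Sample a weighted index, then project it onto the corresponding item.
    return WeightedInteger.Distribution(weights).Select(i => list[i]);
}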
There’s an additional helper method that we are going to need in a couple of lessons, so let’s just make it now:
// Convenience overload, so callers can write items.ToWeighted(1, 2, 3).
public static IDiscreteDistribution<T> ToWeighted<T>(this IEnumerable<T> items,
    params int[] weights) => items.ToWeighted((IEnumerable<int>)weights);
And now we can delete our Conditioned class altogether and replace it with:
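public static IDiscreteDistribution<T> Where<T>(this IDiscreteDistribution<T> d, Func<T, bool> predicate)
{
    // Keep only the support elements that satisfy the predicate...
    var s = d.Support().Where(predicate).ToList();
    // ...and give each survivor its original weight.
    return s.ToWeighted(s.Select(t => d.Weight(t)));
}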
Recall that the WeightedInteger factory will throw an exception if the support is empty, and return a Singleton or Bernoulli if its size is one or two.
Exercise: We’re doing probabilistic workflows here; it seems strange that we are either 100% rejecting or 100% accepting in Where. Can you write an optimized implementation of this method?
public static IDiscreteDistribution<T> Where<T>
(this IDiscreteDistribution<T> d, Func<T, IDiscreteDistribution<bool>> predicate)
That is, we accept each T with a probability distribution given by a probabilistic predicate.
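One possible approach to the exercise, sketched here rather than taken from the course, avoids sampling the predicate at all. Suppose predicate(t) gives true weight wTrue and false weight wFalse; then t is accepted with probability wTrue / (wFalse + wTrue), so its conditioned weight should be proportional to d.Weight(t) * wTrue / (wFalse + wTrue). To keep integer weights, the sketch multiplies every term by the least common multiple of the denominators, then divides out the common factor. The class ProbabilisticWhereSketch, the helper ConditionedWeights and its tuple-valued parameter are hypothetical, chosen so the arithmetic can be shown without the rest of the distribution types.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class ProbabilisticWhereSketch
{
    private static long Gcd(long a, long b)
    {
        while (b != 0) (a, b) = (b, a % b);
        return a;
    }

    // Hypothetical helper. For each t in support (in order), returns an integer
    // weight proportional to weight(t) * P(accept t), where
    // P(accept t) = WeightTrue / (WeightFalse + WeightTrue) from predicateWeights(t).
    // Assumes non-negative weights and a positive total for every predicate;
    // checked arithmetic throws OverflowException instead of silently wrapping.
    public static List<int> ConditionedWeights<T>(
        IReadOnlyList<T> support,
        Func<T, int> weight,
        Func<T, (int WeightFalse, int WeightTrue)> predicateWeights)
    {
        var p = support.Select(predicateWeights).ToList();
        var totals = p.Select(x => (long)x.WeightFalse + x.WeightTrue).ToList();

        // L = lcm of all totals, so L / total is an integer for every t.
        long lcm = totals.Aggregate(1L, (acc, n) => checked(acc / Gcd(acc, n) * n));

        // weight(t) * WeightTrue(t) / total(t), scaled up by L.
        var raw = support
            .Select((t, i) => checked((long)weight(t) * p[i].WeightTrue * (lcm / totals[i])))
            .ToList();

        // Divide out the common factor to keep the weights small.
        long g = raw.Aggregate(0L, (acc, w) => Gcd(acc, w));
        return raw.Select(w => g == 0 ? 0 : checked((int)(w / g))).ToList();
    }
}
```

To turn this into the requested Where, one could take s = d.Support().ToList(), call ConditionedWeights(s, d.Weight, t => (predicate(t).Weight(false), predicate(t).Weight(true))), and pass the result to s.ToWeighted. For example, base weights 1, 1 and 2 with acceptance probabilities 1/2, 1 and 1/3 give weights 3, 6 and 4. If no element can ever be accepted, every weight is zero, so the support is empty and the WeightedInteger factory should throw, just as it does for the deterministic Where. The least common multiple can grow quickly when there are many distinct denominators.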