Weighted Random

Say we have an array and we want to pick one element at random. Instead of giving every element the same probability, we want to assign each element a weight and have it picked with probability proportional to that weight. This is called weighted random selection.

For example, given the array ["a", "b", "c"] and the weights [1, 2, 3], each element's probability is its weight divided by the sum of all weights:

["a", "b", "c"]
[1, 2, 3]

=> 

sum = 1 + 2 + 3 = 6
[1/6, 2/6, 3/6]
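The normalization step above can be sketched in TypeScript. `toProbabilities` is a hypothetical helper name, and the sketch assumes the weights are non-negative with a positive sum.

```typescript
// Hypothetical helper: convert weights into probabilities by dividing
// each weight by the total. Assumes non-negative weights with a positive sum.
function toProbabilities(weights: number[]): number[] {
  const total = weights.reduce((sum, w) => sum + w, 0);
  if (total <= 0) {
    throw new Error("weights must sum to a positive number");
  }
  return weights.map((w) => w / total);
}

console.log(toProbabilities([1, 2, 3])); // approximately [0.1667, 0.3333, 0.5]
```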

So how do we solve this problem?

The most intuitive method is to repeat each element as many times as its weight. For example, the arrays above can be expanded into the following form:

["a", "b", "c"]
[1, 2, 3]

=> 

["a", "b", "b", "c", "c", "c"]

Then we can solve the problem by picking an element from the expanded array with equal probability.
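A minimal sketch of this expansion approach, assuming the weights are non-negative integers. `buildExpanded` and `pickUniform` are hypothetical helper names, and the optional `rand` parameter is an assumption added so the pick can be made deterministic.

```typescript
// Hypothetical helper: repeat each item as many times as its weight.
// Assumes weights are non-negative integers.
function buildExpanded<T>(items: T[], weights: number[]): T[] {
  const expanded: T[] = [];
  items.forEach((item, i) => {
    for (let k = 0; k < weights[i]; k++) {
      expanded.push(item);
    }
  });
  return expanded;
}

// Hypothetical helper: pick one entry with equal probability.
// Math.floor(rand() * n) is a uniform index in [0, n) when rand() is in [0, 1).
function pickUniform<T>(arr: T[], rand: () => number = Math.random): T {
  return arr[Math.floor(rand() * arr.length)];
}

const expanded = buildExpanded(["a", "b", "c"], [1, 2, 3]);
console.log(expanded); // [ 'a', 'b', 'b', 'c', 'c', 'c' ]
console.log(pickUniform(expanded)); // "a", "b" or "c" with probability 1/6, 2/6, 3/6
```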

But the expanded array has as many entries as the sum of the weights, so if the weights are very large, it can take a lot of memory. It also only works directly when the weights are integers.

A better approach is to use prefix sums.

Using the same weight array, let's calculate its prefix sums:

[1, 2, 3]

=>

[1, 3, 6]
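Computing the prefix sums takes a single pass over the weights. A short sketch, where `prefixSums` is a hypothetical helper name:

```typescript
// Hypothetical helper: prefix[i] = weights[0] + weights[1] + ... + weights[i].
function prefixSums(weights: number[]): number[] {
  const prefix: number[] = [];
  let running = 0;
  for (const w of weights) {
    running += w;
    prefix.push(running);
  }
  return prefix;
}

console.log(prefixSums([1, 2, 3])); // [ 1, 3, 6 ]
```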

What pattern can we derive from this prefix sum array? We can read it as a sequence of consecutive half-open ranges:

[1, 3, 6]

=>

[0, 1)
[1, 3)
[3, 6)

As you can see, the length of each range equals the weight of the corresponding element, so the fraction of [0, 6) that it covers is exactly that element's probability. Let's label each range with its element to see this more clearly:

a => [0, 1)
b => [1, 3)
c => [3, 6)

So we can pick a random value uniformly from [0, 6) and check which range it falls into. The element that owns that range is the result.
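To make the range lookup concrete, here is a sketch that maps a few fixed values to their elements, including values that sit exactly on a boundary. `rangeIndex` is a hypothetical helper name, and the sketch assumes the value lies in [0, total).

```typescript
// Hypothetical helper: return the index i whose half-open range
// [prefix[i - 1], prefix[i]) contains v (with prefix[-1] treated as 0).
// Assumes 0 <= v < prefix[prefix.length - 1].
function rangeIndex(prefix: number[], v: number): number {
  for (let i = 0; i < prefix.length; i++) {
    if (v < prefix[i]) {
      return i;
    }
  }
  throw new RangeError("v is outside [0, total)");
}

const items = ["a", "b", "c"];
const prefix = [1, 3, 6];
for (const v of [0, 0.99, 1, 2.5, 3, 5.99]) {
  console.log(v, items[rangeIndex(prefix, v)]);
}
// 0 a, 0.99 a, 1 b, 2.5 b, 3 c, 5.99 c (one pair per line)
```

Note that a boundary value such as 1 belongs to "b", not "a", because each range includes its start and excludes its end.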

With this in mind, let's test the idea on the LeetCode problem Random Pick with Weight. The code looks like this:

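class Solution {
    // accuWeights[i] = w[0] + w[1] + ... + w[i] (prefix sums of the weights)
    private accuWeights: number[];
    // Total weight; equals the last prefix sum
    private max: number;

    constructor(w: number[]) {
        this.accuWeights = [w[0]];
        for (let i = 1; i < w.length; i++) {
            this.accuWeights.push(this.accuWeights[i - 1] + w[i]);
        }
        this.max = this.accuWeights[this.accuWeights.length - 1];
    }

    pickIndex(): number {
        // Uniform random value in [0, max)
        const v = Math.random() * this.max;
        // Linear scan: index i owns the range [accuWeights[i - 1], accuWeights[i])
        for (let i = 0; i < this.accuWeights.length; i++) {
            if (v < this.accuWeights[i]) {
                return i;
            }
        }
        // Unreachable in practice because v < max; fall back to the last index
        return this.accuWeights.length - 1;
    }
}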

In pickIndex, we use a for loop to find the target range, so each pick takes O(n) time.

Because the prefix sum array is sorted in ascending order, we can reduce each pick to O(log n) time with binary search.


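class Solution {
    // accuWeights[i] = w[0] + w[1] + ... + w[i] (prefix sums of the weights)
    private accuWeights: number[];
    // Total weight; equals the last prefix sum
    private max: number;

    constructor(w: number[]) {
        this.accuWeights = [w[0]];
        for (let i = 1; i < w.length; i++) {
            this.accuWeights.push(this.accuWeights[i - 1] + w[i]);
        }
        this.max = this.accuWeights[this.accuWeights.length - 1];
    }

    pickIndex(): number {
        // Uniform random value in [0, max)
        const v = Math.random() * this.max;
        return this.binarySearch(v);
    }

    // Returns the smallest index i with v < accuWeights[i].
    // Such an index always exists because v < max = accuWeights[n - 1].
    binarySearch(v: number): number {
        let left = 0;
        let right = this.accuWeights.length - 1;

        while (left < right) {
            const mid = Math.floor((right - left) / 2) + left;
            if (v < this.accuWeights[mid]) {
                // mid may be the answer, so keep it in the search range
                right = mid;
            } else {
                // The answer lies strictly to the right of mid
                left = mid + 1;
            }
        }

        return left;
    }
}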
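As a variation on the class above, the same binary search can be packaged so that it returns items rather than indices and takes the random source as a parameter, which makes picks reproducible in tests. This is a sketch under stated assumptions: `WeightedPicker`, `pick`, and the `rand` parameter are hypothetical names, weights are assumed non-negative with a positive sum, and the input checks are additions not required by the LeetCode problem.

```typescript
// Hypothetical reusable picker built on prefix sums plus binary search.
// Assumes non-negative weights with a positive sum.
class WeightedPicker<T> {
  private readonly prefix: number[] = [];
  private readonly total: number;

  constructor(
    private readonly items: T[],
    weights: number[],
    private readonly rand: () => number = Math.random,
  ) {
    if (items.length === 0 || items.length !== weights.length) {
      throw new Error("items and weights must be non-empty and the same length");
    }
    let running = 0;
    for (const w of weights) {
      if (w < 0) throw new Error("weights must be non-negative");
      running += w;
      this.prefix.push(running);
    }
    if (running <= 0) throw new Error("weights must sum to a positive number");
    this.total = running;
  }

  pick(): T {
    // Uniform value in [0, total), assuming rand() returns values in [0, 1)
    const v = this.rand() * this.total;
    // Find the first prefix sum strictly greater than v
    let lo = 0;
    let hi = this.prefix.length - 1;
    while (lo < hi) {
      const mid = lo + Math.floor((hi - lo) / 2);
      if (v < this.prefix[mid]) {
        hi = mid;
      } else {
        lo = mid + 1;
      }
    }
    return this.items[lo];
  }
}

// Usage: count how often each item is picked
const picker = new WeightedPicker(["a", "b", "c"], [1, 2, 3]);
const counts = new Map<string, number>();
for (let i = 0; i < 60000; i++) {
  const x = picker.pick();
  counts.set(x, (counts.get(x) ?? 0) + 1);
}
console.log(["a", "b", "c"].map((k) => counts.get(k) ?? 0));
```

With 60,000 picks the counts should come out close to 10,000, 20,000 and 30,000, though the exact numbers vary from run to run. A zero weight gives an empty range, so that item is never returned.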