Summed Area Table

A summed area table is a data structure that works like a 2D version of a prefix sum. If you are not familiar with prefix sums, check my previous article about them. Like a prefix sum, a summed area table can be used to get the sum of a rectangular subset of a grid efficiently.
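As a quick reminder of the 1D idea, here is a minimal sketch of a prefix sum. The names buildPrefixSum and rangeSum are made up for this illustration, and the extra leading 0 in the array is just one common convention, not the only way to lay it out.

```typescript
// Hypothetical helper names, used only for this illustration.
function buildPrefixSum(nums: number[]): number[] {
    // prefix[k] holds the sum of nums[0..k-1], so prefix[0] is 0.
    const prefix: number[] = Array.from({ length: nums.length + 1 }, () => 0);
    for (let k = 0; k < nums.length; k++) {
        prefix[k + 1] = prefix[k] + nums[k];
    }
    return prefix;
}

// Sum of nums[lo..hi], both ends inclusive, in O(1).
function rangeSum(prefix: number[], lo: number, hi: number): number {
    return prefix[hi + 1] - prefix[lo];
}

const prefix = buildPrefixSum([1, 2, 3, 4, 5]);
console.log(rangeSum(prefix, 1, 3)); // 2 + 3 + 4 = 9
```

The summed area table applies the same trick in two dimensions: pay once to build the sums, then answer every range query with a few lookups.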

Let's take box blur as an example. Box blur slides a kernel window over the whole image, calculates the average of the pixels under the window at each position, and writes that average into the blurred image. A 3*3 (radius = 1) kernel looks like this.

1 1 1
1 1 1
1 1 1

The blur process for a single window position looks like this.

kernel   pixel     multiply       sum        average
-----------------------------------------------
1 1 1    1 2 3    1*1 1*2 1*3    -  -  -    -   -   -
1 1 1  * 4 5 6 => 1*4 1*5 1*6 => - sum - => - sum/9 -
1 1 1    7 8 9    1*7 1*8 1*9    -  -  -    -   -   -

sum/9 = (1+2+3+4+5+6+7+8+9)/9

Once we know what the task is, it is not hard to implement. The LeetCode Image Smoother problem is the same task, so let's solve it in code.

function imageSmoother(img: number[][]): number[][] {
    const img2 = Array.from({ length: img.length }, (_, i: number) => {
        return Array.from({ length: img[i].length }, (_, j: number) => {
            return img[i][j];
        });
    });

    const radius = 1;
    for (let i = 0; i < img.length; i++) {
        for (let j = 0; j < img[i].length; j++) {
            let sum = 0;
            let count = 0;
            for (let ii = i - radius; ii <= i + radius; ii++) {
                for (let jj = j - radius; jj <= j + radius; jj++) {
                    if (ii < 0) continue;
                    if (jj < 0) continue;
                    if (ii >= img.length) continue;
                    if (jj >= img[i].length) continue;
                    count += 1;
                    sum += img[ii][jj];
                }
            }
            img2[i][j] = Math.floor(sum / count);
        }
    }

    return img2;
};

As you can see, the process is very clear. We iterate over all pixels and calculate the window average for each pixel. Let the image size be m * n and the window size be w * w (with w = 2 * radius + 1); then the time complexity is O(mnww).

How can we speed this up? That's what a summed area table is for. Like a prefix sum, each entry of a summed area table holds the sum of all values in the rectangle from the top-left corner to the current coordinate. See the example below.

img
1 2 3
4 5 6
7 8 9

table
1       1+2          1+2+3
1+4     1+2+4+5      1+2+3+4+5+6
1+4+7   1+2+4+5+7+8  1+2+3+4+5+6+7+8+9

Now let's consider how to calculate this table efficiently. We certainly shouldn't calculate each position from scratch. If you look at the example above closely, you may notice that every value in the table is related to its left, top and top-left neighbors. If the current position has coordinate (i,j), we can calculate its value using the formula below, where any term that falls outside the table (when i = 0 or j = 0) counts as 0.

table[i][j] = img[i][j] + table[i-1][j] + table[i][j-1] - table[i-1][j-1];

With this formula, each entry takes O(1) work, so we can build the whole table in O(mn) time.
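The formula can be sketched as a small standalone function. The name buildSummedAreaTable is hypothetical, and treating out-of-range neighbors as 0 is an assumption made here so that the first row and column need no special loops.

```typescript
// Hypothetical helper name; builds the table for a rectangular grid.
function buildSummedAreaTable(img: number[][]): number[][] {
    const m = img.length;
    const n = m > 0 ? img[0].length : 0;
    const table: number[][] = Array.from({ length: m }, () =>
        Array.from({ length: n }, () => 0)
    );
    for (let i = 0; i < m; i++) {
        for (let j = 0; j < n; j++) {
            // Neighbors outside the grid contribute 0.
            const top = i > 0 ? table[i - 1][j] : 0;
            const left = j > 0 ? table[i][j - 1] : 0;
            const topLeft = i > 0 && j > 0 ? table[i - 1][j - 1] : 0;
            table[i][j] = img[i][j] + top + left - topLeft;
        }
    }
    return table;
}

console.log(buildSummedAreaTable([
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9],
]));
// [[1, 3, 6], [5, 12, 21], [12, 27, 45]]
```

The printed values match the sums written out in the table example above, for instance 1+2+4+5 = 12 in the middle.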

The next question is how this table helps solve the problem.

Suppose we have the image below. We want to calculate the sum of the rectangle whose top-left corner is x and whose bottom-right corner is y.

. . . . . . . . .
. a . . c . . . .
. . x . . . . . .
. . . . . . . . .
. b . . y . . . .
. . . . . . . . .

With the summed area table, we can calculate the sum of this target area using the formula below.

const target = table[yi][yj] - table[bi][bj] - table[ci][cj] + table[ai][aj]

It's not hard to understand: we start from the sum at y and subtract the extra sums to the left (b) and above (c). Because the top-left area (a) gets subtracted twice, we add it back once.
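Here is a minimal sketch of that query on the 3*3 example. The function name regionSum and its parameters are made up for this illustration: (r1, c1) plays the role of x and (r2, c2) the role of y, so a, b and c sit at (r1-1, c1-1), (r2, c1-1) and (r1-1, c2). Corners that fall outside the table are assumed to contribute 0.

```typescript
// Hypothetical helper: sum of img[r1..r2][c1..c2], both ends inclusive.
function regionSum(
    table: number[][],
    r1: number,
    c1: number,
    r2: number,
    c2: number
): number {
    const total = table[r2][c2];                                      // y
    const top = r1 > 0 ? table[r1 - 1][c2] : 0;                       // c
    const left = c1 > 0 ? table[r2][c1 - 1] : 0;                      // b
    const topLeft = r1 > 0 && c1 > 0 ? table[r1 - 1][c1 - 1] : 0;     // a
    return total - top - left + topLeft;
}

// Summed area table of the image [[1,2,3],[4,5,6],[7,8,9]].
const table = [
    [1, 3, 6],
    [5, 12, 21],
    [12, 27, 45],
];
console.log(regionSum(table, 1, 1, 2, 2)); // 5 + 6 + 8 + 9 = 28
```

In this call the lookup is 45 - 6 - 12 + 1 = 28, which takes four reads no matter how large the rectangle is.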

OK, let's write all of this in code now.


function imageSmoother(img: number[][]): number[][] {
    const m = img.length;
    const n = img[0].length;

    const img2 = Array.from({ length: m }, (_, i: number) => {
        return Array.from({ length: n }, (_, j: number) => {
            return img[i][j];
        });
    });

    // summed area table
    const t = Array.from({ length: m }, () => Array.from({ length: n }, () => 0));
    t[0][0] = img[0][0];
    for (let i = 1; i < m; i++) {
        t[i][0] = t[i - 1][0] + img[i][0];
    }
    for (let j = 1; j < n; j++) {
        t[0][j] = t[0][j - 1] + img[0][j];
    }
    for (let i = 1; i < m; i++) {
        for (let j = 1; j < n; j++) {
            t[i][j] = img[i][j] + t[i - 1][j] + t[i][j - 1] - t[i - 1][j - 1];
        }
    }

    const radius = 1;

    for (let i = 0; i < img.length; i++) {
        for (let j = 0; j < img[i].length; j++) {
            // calculate top index and its value
            const topI = i - radius - 1;
            const topJ = j + radius >= n ? n - 1 : j + radius;
            const topValue = topI >= 0 ? t[topI][topJ] : 0;

            // calculate left index and its value
            const leftI = i + radius >= m ? m - 1 : i + radius;
            const leftJ = j - radius - 1;
            const leftValue = leftJ >= 0 ? t[leftI][leftJ] : 0;

            // calculate top-left index and its value
            const topLeftI = i - radius - 1;
            const topLeftJ = j - radius - 1;
            const topLeftValue = (topLeftI >= 0 && topLeftJ >= 0) ? t[topLeftI][topLeftJ] : 0;

            // calculate bottom-right index and its value
            const bottomRightI = i + radius >= m ? m - 1 : i + radius;
            const bottomRightJ = j + radius >= n ? n - 1 : j + radius;
            const bottomRightValue = t[bottomRightI][bottomRightJ];

            const sum = bottomRightValue - topValue - leftValue + topLeftValue;

            // calculate valid numbers of cells
            const row = (i - 0 > radius ? radius : i - 0) + 1 + (m - i - 1 > radius ? radius : m - i - 1);
            const column = (j - 0 > radius ? radius : j - 0) + 1 + (n - j - 1 > radius ? radius : n - j - 1);
            const count = row * column;

            img2[i][j] = Math.floor(sum / count);
        }
    }

    return img2;

};

The code above may seem like a lot, but the whole process is straightforward. Most of the extra code handles edge cases.

As you can see, with a summed area table, we can calculate every pixel's window sum in O(1) time. Building the table takes O(mn), so the total time complexity is reduced to O(mn).
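If the edge-case branches feel heavy, one alternative (a variation, not the code above) is to pad the table with an extra row and column of zeros and clamp the window to the image. The sketch below assumes a rectangular image with non-negative values; boxBlur and its radius parameter are names chosen for this illustration.

```typescript
// Hypothetical variant: box blur with a zero-padded summed area table.
function boxBlur(img: number[][], radius: number): number[][] {
    const m = img.length;
    const n = m > 0 ? img[0].length : 0;

    // p[i + 1][j + 1] = sum of img[0..i][0..j]; row 0 and column 0 stay 0.
    const p: number[][] = Array.from({ length: m + 1 }, () =>
        Array.from({ length: n + 1 }, () => 0)
    );
    for (let i = 0; i < m; i++) {
        for (let j = 0; j < n; j++) {
            p[i + 1][j + 1] = img[i][j] + p[i][j + 1] + p[i + 1][j] - p[i][j];
        }
    }

    const out: number[][] = Array.from({ length: m }, () =>
        Array.from({ length: n }, () => 0)
    );
    for (let i = 0; i < m; i++) {
        for (let j = 0; j < n; j++) {
            // Clamp the window to the image bounds.
            const r1 = Math.max(0, i - radius);
            const c1 = Math.max(0, j - radius);
            const r2 = Math.min(m - 1, i + radius);
            const c2 = Math.min(n - 1, j + radius);

            // Thanks to the padding, no index here can be negative.
            const sum = p[r2 + 1][c2 + 1] - p[r1][c2 + 1] - p[r2 + 1][c1] + p[r1][c1];
            const count = (r2 - r1 + 1) * (c2 - c1 + 1);
            out[i][j] = Math.floor(sum / count);
        }
    }
    return out;
}

console.log(boxBlur([[1, 1, 1], [1, 0, 1], [1, 1, 1]], 1));
// [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
```

The padding costs one extra row and column of memory, but the clamped corners double as the cell count, and the running time stays O(mn) for any radius.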