# Segment Tree


## Range Sum Query - Mutable

The problem: given an array such as `[1,2,3,4,5]`, design a class like the one below that supports updating an element and querying the sum of a range.

``````typescript
class NumArray {
  constructor(nums: number[]) {

  }

  update(index: number, val: number): void {

  }

  sumRange(left: number, right: number): number {

  }
}
``````

`update` sets the value at the given index. `sumRange` returns the sum of the elements from index `left` to index `right`, inclusive.

The simplest solution is to store the input array as-is and add up the elements from `left` to `right` every time `sumRange` is called.

This makes `update` `O(1)`, but every `sumRange` call costs `O(n)`.
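A minimal sketch of this naive approach, assuming a hypothetical class name `NaiveNumArray` so it doesn't clash with the `NumArray` skeleton above. It copies the input, overwrites a slot on `update`, and loops over the range on every `sumRange`.

``````typescript
// Hypothetical NaiveNumArray: keeps the array as-is.
class NaiveNumArray {
  private readonly nums: number[];

  constructor(nums: number[]) {
    // Copy so later updates don't mutate the caller's array.
    this.nums = [...nums];
  }

  // O(1): overwrite a single slot.
  update(index: number, val: number): void {
    this.nums[index] = val;
  }

  // O(n): add up every element from left to right (inclusive).
  sumRange(left: number, right: number): number {
    let sum = 0;
    for (let i = left; i <= right; i++) {
      sum += this.nums[i];
    }
    return sum;
  }
}

const naive = new NaiveNumArray([1, 2, 3, 4, 5]);
console.log(naive.sumRange(1, 3)); // 9
naive.update(2, 10);
console.log(naive.sumRange(1, 3)); // 16
``````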

Can we do better?

Another idea is to precompute a 2D table (like the one below) that stores the sum for every possible combination of `left` and `right`.

``````
   0  1  2  3  4
0  1  3  6 10 15
1     2  5  9 14
2        3  7 12
3           4  9
4              5

row = left, column = right
``````

This makes `sumRange` `O(1)`, which is fast. The problem is that the table needs `O(n^2)` space and `O(n^2)` time to build, and an `update` at index `i` has to change every entry whose range covers `i`, which is `O(n^2)` entries in the worst case. For a large array, this approach quickly becomes impractical.
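A minimal sketch of the precomputed table, using a hypothetical class name `TableNumArray`. Row `left`, column `right` holds the sum of `nums[left..right]`, exactly like the table above; `update` has to patch every cell whose range covers the changed index.

``````typescript
// Hypothetical TableNumArray: precomputes sums[left][right] for every pair.
class TableNumArray {
  private readonly sums: number[][];

  constructor(nums: number[]) {
    const n = nums.length;
    // O(n^2) time and space: one row per left index.
    this.sums = Array.from({ length: n }, () => new Array<number>(n).fill(0));
    for (let left = 0; left < n; left++) {
      let running = 0;
      for (let right = left; right < n; right++) {
        running += nums[right];
        this.sums[left][right] = running;
      }
    }
  }

  // O(1): a single table lookup.
  sumRange(left: number, right: number): number {
    return this.sums[left][right];
  }

  // Every cell (left, right) with left <= index <= right must change:
  // (index + 1) * (n - index) cells, which is O(n^2) in the worst case.
  update(index: number, val: number): void {
    const n = this.sums.length;
    const diff = val - this.sums[index][index]; // sums[i][i] is the current nums[i]
    for (let left = 0; left <= index; left++) {
      for (let right = index; right < n; right++) {
        this.sums[left][right] += diff;
      }
    }
  }
}

const table = new TableNumArray([1, 2, 3, 4, 5]);
console.log(table.sumRange(1, 3)); // 9
table.update(2, 10);
console.log(table.sumRange(1, 3)); // 16
``````

Even on this five-element array, `update(2, 10)` touches 3 × 3 = 9 cells, while the naive version touches one.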

So let's look at a special data structure called a segment tree, which handles both range queries and updates in `O(log(n))` time while using `O(n)` space.

In short, a segment tree is a binary tree in which each node stores the answer (here, the sum) for one contiguous range of the array.

Let's first see how to create a segment tree and then how to use it to solve range query problems.

As before, let's use the array `[1,2,3,4,5]` as an example.

Like most binary tree algorithms, building a segment tree is a recursive process. The root node stores the sum from the left index (0) to the right index (4). We then compute the mid index (2): the left child stores the sum from the left index (0) to the mid index (2), and the right child stores the sum from mid + 1 (3) to the right index (4). We repeat this for every node until the left index equals the right index, at which point the node is a leaf holding a single element.

Following this process, we get the segment tree below. Every node has a left index, a right index, a value (the sum), a left child, and a right child.

``````
                  [0,4]15
                /         \
          [0,2]6           [3,4]9
          /     \          /     \
     [0,1]3   [2,2]3   [3,3]4   [4,4]5
      /    \
 [0,0]1  [1,1]2
``````
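To double-check the diagram, here is a minimal sketch that performs the same recursive split and lists each node as `[left,right]sum`, level by level. The helper name `describeLevels` is hypothetical, and it uses `Math.floor` for the mid index, which matches `~~` for non-negative indices.

``````typescript
// Hypothetical helper: returns "[l,r]sum" labels grouped by depth (root = depth 0).
function describeLevels(nums: number[]): string[][] {
  const levels: string[][] = [];

  // Returns the sum of nums[left..right] and records this node's label.
  function build(left: number, right: number, depth: number): number {
    let sum: number;
    if (left === right) {
      sum = nums[left]; // leaf: a single element
    } else {
      const mid = Math.floor((left + right) / 2);
      sum = build(left, mid, depth + 1) + build(mid + 1, right, depth + 1);
    }
    if (!levels[depth]) levels[depth] = [];
    levels[depth].push(`[${left},${right}]${sum}`);
    return sum;
  }

  if (nums.length > 0) build(0, nums.length - 1, 0);
  return levels;
}

for (const level of describeLevels([1, 2, 3, 4, 5])) {
  console.log(level.join("  "));
}
// [0,4]15
// [0,2]6  [3,4]9
// [0,1]3  [2,2]3  [3,3]4  [4,4]5
// [0,0]1  [1,1]2
``````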

Given this tree, how do we run a range query on it?

Just like building the tree, answering a range query is a recursive process: we start the query at the root node and apply the same steps recursively.

Say the current node covers the range `[0,4]` and has sum 15.

Case 1: If the query range has no overlap with the node, like `[-5,-1]` or `[5,9]`, we simply return 0.

Case 2: If the query range fully contains `[0,4]`, like `[0,9]` or `[-9,4]`, we return this node's sum.

Case 3: If the query range only partially covers the node, like `[0,1]` or `[2,4]`, this node's sum alone can't give the answer. In that case, we run the same query on the left child and the right child recursively and return the sum of the two results.
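The three cases can be traced without any sums at all. This minimal sketch (hypothetical helpers `classify` and `traceQuery`) walks the same index ranges as the tree above and records which case applies to each node it visits. Only partially covered nodes recurse; a leaf is always either fully covered or not covered, so the recursion stops.

``````typescript
type Overlap = "none" | "full" | "partial";

// Hypothetical helper: which case applies to node range [nodeLeft, nodeRight].
function classify(nodeLeft: number, nodeRight: number, queryLeft: number, queryRight: number): Overlap {
  if (queryRight < nodeLeft || queryLeft > nodeRight) return "none"; // case 1
  if (queryLeft <= nodeLeft && nodeRight <= queryRight) return "full"; // case 2
  return "partial"; // case 3
}

// Visits nodes in the same order as the recursive query and logs each case.
function traceQuery(n: number, queryLeft: number, queryRight: number): string[] {
  const visited: string[] = [];
  function visit(left: number, right: number): void {
    const overlap = classify(left, right, queryLeft, queryRight);
    visited.push(`[${left},${right}] ${overlap}`);
    if (overlap !== "partial") return; // cases 1 and 2 stop here
    const mid = Math.floor((left + right) / 2);
    visit(left, mid);
    visit(mid + 1, right);
  }
  if (n > 0) visit(0, n - 1);
  return visited;
}

console.log(traceQuery(5, 1, 3).join("\n"));
``````

For the query `[1,3]`, the fully covered nodes are `[1,1]`, `[2,2]` and `[3,3]`, so the answer is 2 + 3 + 4 = 9, and every other visited node either recurses (case 3) or returns 0 (case 1).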

With all this in mind, let's see it in code.

``````typescript
class SegmentTreeNode {
  value: number; // sum of nums[leftIndex..rightIndex]
  leftIndex: number;
  rightIndex: number;
  leftChild: SegmentTreeNode | null;
  rightChild: SegmentTreeNode | null;

  constructor(
    value: number,
    leftIndex: number,
    rightIndex: number,
    leftChild: SegmentTreeNode | null = null,
    rightChild: SegmentTreeNode | null = null
  ) {
    this.value = value;
    this.leftIndex = leftIndex;
    this.rightIndex = rightIndex;
    this.leftChild = leftChild;
    this.rightChild = rightChild;
  }
}

class NumArray {
  private segmentTree: SegmentTreeNode | null;
  private nums: number[];

  constructor(nums: number[]) {
    this.nums = nums;
    // An empty array has no ranges, so there is no tree to build.
    this.segmentTree = nums.length > 0 ? this.createTree(0, nums.length - 1) : null;
  }

  // Create the tree recursively until leftIndex === rightIndex (a leaf).
  createTree(leftIndex: number, rightIndex: number): SegmentTreeNode {
    if (leftIndex === rightIndex) {
      return new SegmentTreeNode(this.nums[leftIndex], leftIndex, rightIndex);
    }

    const mid = Math.floor((leftIndex + rightIndex) / 2);
    const leftChild = this.createTree(leftIndex, mid);
    const rightChild = this.createTree(mid + 1, rightIndex);

    return new SegmentTreeNode(leftChild.value + rightChild.value, leftIndex, rightIndex, leftChild, rightChild);
  }

  // Update is simple: if index is in the current node's range, apply the difference.
  // Only the nodes on one root-to-leaf path contain index, so this is O(log(n)).
  update(index: number, val: number): void {
    const diff = val - this.nums[index];
    this.nums[index] = val;
    this.updateUtil(index, diff, this.segmentTree);
  }

  updateUtil(index: number, diff: number, node: SegmentTreeNode | null): void {
    if (node === null) return;
    if (index < node.leftIndex || index > node.rightIndex) return; // index not in this range
    node.value += diff; // apply diff
    this.updateUtil(index, diff, node.leftChild);
    this.updateUtil(index, diff, node.rightChild);
  }

  // This is the range query part.
  sumRange(left: number, right: number): number {
    return this.sumRangeUtil(this.segmentTree, left, right);
  }

  sumRangeUtil(node: SegmentTreeNode | null, leftIndex: number, rightIndex: number): number {
    if (node === null) return 0;
    // Case 1: no overlap
    if (rightIndex < node.leftIndex || leftIndex > node.rightIndex) return 0;
    // Case 2: the query range fully contains this node's range
    if (leftIndex <= node.leftIndex && rightIndex >= node.rightIndex) return node.value;
    // Case 3: partial overlap, so ask both children
    return this.sumRangeUtil(node.leftChild, leftIndex, rightIndex) + this.sumRangeUtil(node.rightChild, leftIndex, rightIndex);
  }
}
``````
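A common alternative layout, not the one used above, stores the same tree in a flat array instead of linked nodes: node `i` has its children at `2i` and `2i + 1`, and the root is at index 1. The sketch below assumes a hypothetical class name `ArraySegmentTree`; the split, the three query cases, and the `O(log(n))` costs are the same as in `NumArray` above, and `4n` slots is a safe upper bound for this recursive layout.

``````typescript
// Hypothetical ArraySegmentTree: tree[node] holds the sum of one index range.
class ArraySegmentTree {
  private readonly n: number;
  private readonly tree: number[];

  constructor(nums: number[]) {
    this.n = nums.length;
    this.tree = new Array<number>(4 * Math.max(1, this.n)).fill(0);
    if (this.n > 0) this.build(nums, 1, 0, this.n - 1);
  }

  // Fill tree[node] with the sum of nums[left..right].
  private build(nums: number[], node: number, left: number, right: number): void {
    if (left === right) {
      this.tree[node] = nums[left];
      return;
    }
    const mid = Math.floor((left + right) / 2);
    this.build(nums, 2 * node, left, mid);
    this.build(nums, 2 * node + 1, mid + 1, right);
    this.tree[node] = this.tree[2 * node] + this.tree[2 * node + 1];
  }

  update(index: number, val: number): void {
    if (index < 0 || index >= this.n) return; // ignore out-of-range indices
    this.updateAt(1, 0, this.n - 1, index, val);
  }

  // Descend only into the child containing index, then recompute sums on the way back up.
  private updateAt(node: number, left: number, right: number, index: number, val: number): void {
    if (left === right) {
      this.tree[node] = val;
      return;
    }
    const mid = Math.floor((left + right) / 2);
    if (index <= mid) {
      this.updateAt(2 * node, left, mid, index, val);
    } else {
      this.updateAt(2 * node + 1, mid + 1, right, index, val);
    }
    this.tree[node] = this.tree[2 * node] + this.tree[2 * node + 1];
  }

  sumRange(left: number, right: number): number {
    if (this.n === 0) return 0;
    return this.query(1, 0, this.n - 1, left, right);
  }

  // Same three cases: no overlap, full cover, partial overlap.
  private query(node: number, left: number, right: number, queryLeft: number, queryRight: number): number {
    if (queryRight < left || queryLeft > right) return 0;
    if (queryLeft <= left && right <= queryRight) return this.tree[node];
    const mid = Math.floor((left + right) / 2);
    return (
      this.query(2 * node, left, mid, queryLeft, queryRight) +
      this.query(2 * node + 1, mid + 1, right, queryLeft, queryRight)
    );
  }
}

const tree = new ArraySegmentTree([1, 2, 3, 4, 5]);
console.log(tree.sumRange(1, 3)); // 9
tree.update(2, 10);
console.log(tree.sumRange(1, 3)); // 16
``````

The flat array avoids one object per node, at the cost of some unused slots; the node-based version above is arguably easier to read.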

## Range Sum Query 2D - Mutable

The problem: given a 2D matrix, design a class that supports updating a cell and querying the sum of a rectangular region.

``````typescript
class NumMatrix {
  constructor(matrix: number[][]) {

  }

  update(row: number, col: number, val: number): void {

  }

  sumRegion(row1: number, col1: number, row2: number, col2: number): number {

  }
}
``````

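As you can see, this problem is basically the 2D version of the previous one.

With one small change, splitting each node's rectangle into up to four quadrants instead of two halves, the same segment tree idea solves this problem too. `update` still follows a single path from the root to a leaf, so it stays logarithmic. Queries are not guaranteed to be: a query covering one full row of an `n × n` matrix partially overlaps a whole row of quadrants at every level, so it can visit `O(n)` nodes.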
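The only real change from the 1D version is how a node's rectangle is split. Here is that step in isolation, as a minimal sketch mirroring the `createTree` logic below. The `Rect` tuple type and the `splitRect` name are hypothetical; rectangles are `[row1, col1, row2, col2]`, inclusive on both ends.

``````typescript
// [row1, col1, row2, col2], inclusive on both ends.
type Rect = [number, number, number, number];

// Hypothetical helper: the child rectangles a node would get.
function splitRect([row1, col1, row2, col2]: Rect): Rect[] {
  if (row1 === row2 && col1 === col2) return []; // a single cell is a leaf
  const midRow = Math.floor((row1 + row2) / 2);
  const midCol = Math.floor((col1 + col2) / 2);
  if (row1 === row2) {
    // One row: only the columns can be split.
    return [[row1, col1, row2, midCol], [row1, midCol + 1, row2, col2]];
  }
  if (col1 === col2) {
    // One column: only the rows can be split.
    return [[row1, col1, midRow, col2], [midRow + 1, col1, row2, col2]];
  }
  return [
    [row1, col1, midRow, midCol], // top-left
    [row1, midCol + 1, midRow, col2], // top-right
    [midRow + 1, col1, row2, midCol], // bottom-left
    [midRow + 1, midCol + 1, row2, col2], // bottom-right
  ];
}

// A 2 x 3 matrix splits into [0,0,0,1], [0,2,0,2], [1,0,1,1], [1,2,1,2].
console.log(splitRect([0, 0, 1, 2]));
``````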

``````typescript
class SegmentTreeNode {
  val: number; // sum of the cells in rows row1..row2, columns col1..col2
  row1: number;
  col1: number;
  row2: number;
  col2: number;
  topLeftChild: SegmentTreeNode | null;
  topRightChild: SegmentTreeNode | null;
  bottomLeftChild: SegmentTreeNode | null;
  bottomRightChild: SegmentTreeNode | null;

  constructor(
    val: number,
    row1: number,
    col1: number,
    row2: number,
    col2: number,
    topLeftChild: SegmentTreeNode | null = null,
    topRightChild: SegmentTreeNode | null = null,
    bottomLeftChild: SegmentTreeNode | null = null,
    bottomRightChild: SegmentTreeNode | null = null
  ) {
    this.val = val;
    this.row1 = row1;
    this.col1 = col1;
    this.row2 = row2;
    this.col2 = col2;
    this.topLeftChild = topLeftChild;
    this.topRightChild = topRightChild;
    this.bottomLeftChild = bottomLeftChild;
    this.bottomRightChild = bottomRightChild;
  }
}

class NumMatrix {
  matrix: number[][];
  segmentTree: SegmentTreeNode | null;

  constructor(matrix: number[][]) {
    this.matrix = matrix;
    // Guard against an empty matrix, where matrix[0] would be undefined.
    this.segmentTree =
      matrix.length > 0 && matrix[0].length > 0
        ? this.createTree(0, 0, matrix.length - 1, matrix[0].length - 1)
        : null;
  }

  // Split the rectangle into up to four quadrants until it is a single cell.
  createTree(row1: number, col1: number, row2: number, col2: number): SegmentTreeNode {
    if (row1 === row2 && col1 === col2) {
      return new SegmentTreeNode(this.matrix[row1][col1], row1, col1, row2, col2);
    }

    const midRow = Math.floor((row1 + row2) / 2);
    const midCol = Math.floor((col1 + col2) / 2);

    // A single row can only be split into left and right halves.
    if (row1 === row2) {
      const topLeftChild = this.createTree(row1, col1, row2, midCol);
      const topRightChild = this.createTree(row1, midCol + 1, row2, col2);
      return new SegmentTreeNode(
        topLeftChild.val + topRightChild.val,
        row1, col1, row2, col2,
        topLeftChild, topRightChild, null, null
      );
    }

    // A single column can only be split into top and bottom halves.
    if (col1 === col2) {
      const topLeftChild = this.createTree(row1, col1, midRow, col2);
      const bottomLeftChild = this.createTree(midRow + 1, col1, row2, col2);
      return new SegmentTreeNode(
        topLeftChild.val + bottomLeftChild.val,
        row1, col1, row2, col2,
        topLeftChild, null, bottomLeftChild, null
      );
    }

    const topLeftChild = this.createTree(row1, col1, midRow, midCol);
    const topRightChild = this.createTree(row1, midCol + 1, midRow, col2);
    const bottomLeftChild = this.createTree(midRow + 1, col1, row2, midCol);
    const bottomRightChild = this.createTree(midRow + 1, midCol + 1, row2, col2);
    return new SegmentTreeNode(
      topLeftChild.val + topRightChild.val + bottomLeftChild.val + bottomRightChild.val,
      row1, col1, row2, col2,
      topLeftChild, topRightChild, bottomLeftChild, bottomRightChild
    );
  }

  // Apply the difference to every node whose rectangle contains (row, col).
  update(row: number, col: number, val: number): void {
    const diff = val - this.matrix[row][col];
    this.matrix[row][col] = val;
    this.updateUtil(this.segmentTree, row, col, diff);
  }

  updateUtil(node: SegmentTreeNode | null, row: number, col: number, diff: number): void {
    if (node === null) return;
    if (row < node.row1 || row > node.row2 || col < node.col1 || col > node.col2) return;
    node.val += diff;
    this.updateUtil(node.topLeftChild, row, col, diff);
    this.updateUtil(node.topRightChild, row, col, diff);
    this.updateUtil(node.bottomLeftChild, row, col, diff);
    this.updateUtil(node.bottomRightChild, row, col, diff);
  }

  sumRegion(row1: number, col1: number, row2: number, col2: number): number {
    return this.sumRegionUtil(this.segmentTree, row1, col1, row2, col2);
  }

  // Same three cases as the 1D version: no overlap, full cover, partial overlap.
  sumRegionUtil(node: SegmentTreeNode | null, row1: number, col1: number, row2: number, col2: number): number {
    if (node === null) return 0;
    if (row1 > node.row2 || col1 > node.col2 || row2 < node.row1 || col2 < node.col1) return 0;
    if (node.row1 >= row1 && node.row2 <= row2 && node.col1 >= col1 && node.col2 <= col2) return node.val;
    return (
      this.sumRegionUtil(node.topLeftChild, row1, col1, row2, col2) +
      this.sumRegionUtil(node.topRightChild, row1, col1, row2, col2) +
      this.sumRegionUtil(node.bottomLeftChild, row1, col1, row2, col2) +
      this.sumRegionUtil(node.bottomRightChild, row1, col1, row2, col2)
    );
  }
}
``````
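When testing the segment-tree `NumMatrix`, a brute-force version makes a handy reference: feed both the same updates and queries and compare the answers. A minimal sketch, assuming a hypothetical class name `NaiveNumMatrix`; it has `O(1)` updates and an `O(rows × cols)` query, the 2D counterpart of the naive 1D approach.

``````typescript
// Hypothetical NaiveNumMatrix: O(1) update, O(rows * cols) query.
class NaiveNumMatrix {
  private readonly matrix: number[][];

  constructor(matrix: number[][]) {
    // Deep-copy the rows so updates don't leak into the caller's matrix.
    this.matrix = matrix.map((row) => [...row]);
  }

  update(row: number, col: number, val: number): void {
    this.matrix[row][col] = val;
  }

  // Sum over rows row1..row2 and columns col1..col2, inclusive.
  sumRegion(row1: number, col1: number, row2: number, col2: number): number {
    let sum = 0;
    for (let r = row1; r <= row2; r++) {
      for (let c = col1; c <= col2; c++) {
        sum += this.matrix[r][c];
      }
    }
    return sum;
  }
}

const grid = [
  [3, 0, 1, 4, 2],
  [5, 6, 3, 2, 1],
  [1, 2, 0, 1, 5],
  [4, 1, 0, 1, 7],
  [1, 0, 3, 0, 5],
];
const reference = new NaiveNumMatrix(grid);
console.log(reference.sumRegion(2, 1, 4, 3)); // 8
reference.update(3, 2, 2);
console.log(reference.sumRegion(2, 1, 4, 3)); // 10
``````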