Fast numerically stable moving average
Calculating the moving average of a time series is a pretty simple problem. Just loop over each position, and add up the elements over the window:
void simple_ma(double* in, double* out, int N, int W) {
    for (int i = 0; i < N - W + 1; i++) {
        double sum = 0;
        for (int j = 0; j < W; j++) {
            sum += in[i + j];
        }
        out[i] = sum / W;
    }
}
However, this has a time complexity of O(N·W), where N is the length of the series and W is the window size.
Naive approach
There is a very simple trick to speed the calculation up and reach a time complexity of O(N):
void summing_ma(double* in, double* out, int N, int W) {
    double sum = 0;
    for (int i = 0; i < W; i++) {
        sum += in[i];
    }
    out[0] = sum / W;
    for (int i = W; i < N; i++) {
        sum = sum + in[i] - in[i - W];
        out[i - W + 1] = sum / W;
    }
}
This is fast; however, it fails miserably when your numbers are of wildly different scales. Let's check its output for the following input (with a window size of W = 3):
double in[] = {1, 1, 1, 1e17, 1, 1, 1, 1};
// Output
1 3.333e+16 3.333e+16 3.333e+16 0 0
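For reference, the output above can be reproduced with a small driver along these lines (the main function and the window size W = 3 here are my additions, not part of the original listing):

#include <stdio.h>

// Hypothetical test driver; assumes the summing_ma() definition above.
int main(void) {
    double in[] = {1, 1, 1, 1e17, 1, 1, 1, 1};
    double out[6];
    summing_ma(in, out, 8, 3);
    for (int i = 0; i < 6; i++) {
        printf("%.4g ", out[i]);  // prints: 1 3.333e+16 3.333e+16 3.333e+16 0 0
    }
    printf("\n");
    return 0;
}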
After the fifth value, the result becomes zero, a relative error of 100% (the true average of {1, 1, 1} is 1). What has happened? We just hit the main pitfall of floating point calculations: catastrophic cancellation.
To understand catastrophic cancellation, let's review how floating point numbers are represented. A floating point number can be written as ±significand × 2^exponent, where the significand holds a fixed number of bits (53 for double precision). Multiplying two floating point numbers is easy: multiply the significands and add the exponents. Addition is trickier: the exponents must first be made equal by shifting the significand of the smaller operand, and the bits shifted out are simply dropped. If the operands differ in magnitude by more than a factor of about 2^53, the smaller one is dropped entirely and the sum does not change at all:
1e17 + 1 = 1e17
That is, adding 1 to 1e17 does not change the value. How does this relate to our buffered moving average? When 1e17 is added to the sum, the variable sum becomes 1e17 (the smaller terms are lost to rounding). On the fifth iteration, when 1e17 leaves the window, this is what the implementation does:

(sum + 1) - 1e17 = (1e17 + 1) - 1e17 = 1e17 - 1e17 = 0

So the output for the fifth average is zero. To make matters worse, sum is corrupted now! The contributions of the elements still inside the window have been wiped out, so the average of every subsequent window will also be zero, introducing large errors forever.
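You can observe the cancellation in isolation with a two-line experiment (my own snippet, not from the original code): the same three terms give different results depending on the order of operations.

#include <stdio.h>

int main(void) {
    // The order used by summing_ma: the 1 is absorbed into 1e17 and lost.
    printf("%g\n", (1e17 + 1.0) - 1e17);  // prints 0
    // Cancel the large terms first and the 1 survives.
    printf("%g\n", (1e17 - 1e17) + 1.0);  // prints 1
    return 0;
}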
This is a pretty serious flaw of the algorithm. Imagine you are collecting measurements, a million timestamps long, and an erroneous number sneaks in somewhere; maybe some preprocessing step divides by a near-zero value and produces a huge number. Our fast moving average algorithm will produce nonsensical results for everything after the bad measurement, making the output useless. That is, a single measurement error invalidates your million-element time series!
Fast and stable summing
The core problem with the summing_ma algorithm is that it buffers a single total sum, so subtractions are needed to remove the values leaving the window. We should instead cache partial sums in such a way that no subtraction is needed to refresh the result.
The idea is that instead of summing the window linearly, left to right, we calculate the sum in a tree-like order:
The bottom row holds the input numbers, and inner nodes store the sum of their children. The sum of the input numbers can be found in the root node. After changing 1 to 10 in the input window we get:
I've marked the nodes that were updated in red. Note that for each level of the tree, only one node is updated. The number of levels, and thus the number of updates per step, is O(log W), instead of the O(W) additions a full recomputation would need.
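As a small, self-contained example (my own numbers, not the ones from the figures above): take the window {1, 2, 3, 4} and then replace the 1 with a 10. Only the nodes marked with * change, one per level:

      10                      19*
     /  \                    /   \
    3    7        ->       12*    7
   / \  / \                / \   / \
  1   2 3   4            10*  2 3   4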
The pseudocode of the algorithm looks like this:
1. Initialize the tree by copying the values of the first window to the leaves and calculating the inner node values.
2. Divide the value in the root node by W and put it in the output array.
3. Point pos to the first leaf node.
4. Update the node at pos with the new value entering the window.
5. Update all the ancestors of pos and calculate the new average from the root node.
6. Point pos to the next leaf (wrapping around to the first leaf after the last one).
7. Repeat from Step 4 until finished.
The binary tree is almost-complete, meaning every level has all the nodes, except maybe for the last one. This allows for a very efficient implementation, storing the tree as an array:
Each level is stored contiguously, starting with the root node. Moving up and down in the tree is easy. For the node at index i,
- the parent's index is (i - 1) / 2 (integer division),
- the left child's index is 2i + 1,
- the right child's index is 2i + 2.
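In code, these index computations are one-liners (the helper names below are mine; the implementation that follows simply inlines the arithmetic):

// Illustrative index helpers for the array-backed tree:
static inline int parent(int i)      { return (i - 1) / 2; }
static inline int left_child(int i)  { return 2 * i + 1; }
static inline int right_child(int i) { return 2 * i + 2; }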
The actual implementation is the following:
#include <math.h>    // ceil, log2
#include <stdlib.h>  // malloc, free
#include <string.h>  // memset

void tree_ma(double* in, double* out, int N, int W) {
    // Allocate buffer
    int d = (int)ceil(log2(W));
    double* buffer = (double*)malloc((1 << (d + 1)) * sizeof(double));
    memset(buffer, 0, (1 << (d + 1)) * sizeof(double));
    // Initialize the buffer, first the leaf nodes
    for (int i = 0; i < W; i++) {
        buffer[(1 << d) - 1 + i] = in[i];
    }
    // Initialize the non-leaf nodes
    for (int i = (1 << d) - 2; i >= 0; i--) {
        buffer[i] = buffer[2 * i + 1] + buffer[2 * i + 2];
    }
    out[0] = buffer[0] / W; // Insert the first element
    int pos = 0; // Position of the oldest element in the buffer
    for (int i = W; i < N; i++) {
        // Overwrite the oldest element in the buffer
        buffer[(1 << d) - 1 + pos] = in[i];
        // Update the tree
        for (int k = (1 << d) - 1 + pos; k > 0;) {
            k = (k - 1) / 2;
            buffer[k] = buffer[2 * k + 1] + buffer[2 * k + 2];
        }
        // Save the average
        out[i - W + 1] = buffer[0] / W;
        // Step the buffer index
        pos = (pos + 1) % W;
    }
    free(buffer);
}
First, the leaves of the tree are set by copying the first W elements of the input. Then, the intermediate nodes are calculated backwards, which corresponds to a bottom-up traversal of the tree. Finally, the window is slid over the rest of the input: each new element overwrites the oldest leaf, the ancestors of that leaf are recalculated up to the root, and the root divided by W gives the next output value.
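To check that the tree version really fixes the earlier failure, here is a small driver (again my addition, assuming the tree_ma() definition above) run on the same pathological input:

#include <stdio.h>

int main(void) {
    double in[] = {1, 1, 1, 1e17, 1, 1, 1, 1};
    double out[6];
    tree_ma(in, out, 8, 3);
    for (int i = 0; i < 6; i++) {
        printf("%.4g ", out[i]);  // prints: 1 3.333e+16 3.333e+16 3.333e+16 1 1
    }
    printf("\n");
    return 0;
}

Once the huge value leaves the window, the averages recover to the exact value of 1, because no stale running sum is carried over.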
Execution time
The following plot shows the running time of the different algorithms. I used a time series with a fixed length of 1M elements and varied the window size.
As you can see, the SIMD version of the simple algorithm is about 4 times faster, since AVX2 allows 4 double precision additions in parallel. The tree-based approach is slower at small window sizes, but it quickly overtakes the non-SIMD simple algorithm as the window grows, since its per-output cost is O(log W) rather than O(W).
Further extensions
The tree-wise addition fixed the catastrophic cancellation issue of the naive algorithm. However, we gained even more than that. The proposed algorithm is essentially pairwise summation, whose worst-case rounding error grows only as O(ε log W), compared to O(ε W) for sequential summation (with ε the machine epsilon).
The tree-wise summation algorithm is trivially parallelizable. If you have multiple time series, you can use SIMD instructions to calculate 4 time series at the same time, speeding up the computation even more.
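A sketch of what this could look like (my own assumptions: the four series are interleaved so that in[4*t + s] is series s at time step t, AVX is available, and the compiler supports C11); it is tree_ma with every scalar addition replaced by a 4-wide vector addition:

#include <immintrin.h>  // AVX intrinsics
#include <math.h>
#include <stdlib.h>

// Hypothetical 4-series variant: in and out hold 4 interleaved values per time step.
void tree_ma_x4(const double* in, double* out, int N, int W) {
    int d = (int)ceil(log2(W));
    int nodes = 1 << (d + 1);
    __m256d* buffer = (__m256d*)aligned_alloc(32, nodes * sizeof(__m256d));
    for (int i = 0; i < nodes; i++) buffer[i] = _mm256_setzero_pd();
    // Leaves: one vector holds the i-th sample of all 4 series
    for (int i = 0; i < W; i++)
        buffer[(1 << d) - 1 + i] = _mm256_loadu_pd(in + 4 * i);
    // Inner nodes, bottom-up
    for (int i = (1 << d) - 2; i >= 0; i--)
        buffer[i] = _mm256_add_pd(buffer[2 * i + 1], buffer[2 * i + 2]);
    __m256d w = _mm256_set1_pd((double)W);
    _mm256_storeu_pd(out, _mm256_div_pd(buffer[0], w));
    int pos = 0;
    for (int i = W; i < N; i++) {
        buffer[(1 << d) - 1 + pos] = _mm256_loadu_pd(in + 4 * i);
        for (int k = (1 << d) - 1 + pos; k > 0;) {
            k = (k - 1) / 2;
            buffer[k] = _mm256_add_pd(buffer[2 * k + 1], buffer[2 * k + 2]);
        }
        _mm256_storeu_pd(out + 4 * (i - W + 1), _mm256_div_pd(buffer[0], w));
        pos = (pos + 1) % W;
    }
    free(buffer);
}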
The trick we exploited was that we can calculate the sum of a set of numbers recursively: split the list in two, calculate the sums of the sublists, and add the results together. Can we do the same with more complicated statistics, like standard deviation or correlation? While I haven't tried implementing them, both have formulas that compute the value recursively. The variance of a list can be computed from the two sublists' variances using the formula of Chan et al. Schubert and Gertz provide similar update equations for covariance.
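For reference, the variance merge step could look roughly like this (a sketch of Chan et al.'s update formula; the struct and function names are mine, not code from this post):

// n: element count, mean: average, m2: sum of squared deviations from the mean.
// The variance of the list is m2 / n (or m2 / (n - 1) for the sample variance).
typedef struct { double n, mean, m2; } Stats;

Stats merge_stats(Stats a, Stats b) {
    Stats r;
    double delta = b.mean - a.mean;
    r.n = a.n + b.n;
    r.mean = a.mean + delta * b.n / r.n;
    r.m2 = a.m2 + b.m2 + delta * delta * a.n * b.n / r.n;
    return r;
}

Storing such (n, mean, m2) triples in the tree nodes instead of plain sums should give a sliding-window variance with the same O(log W) update cost.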