Fourier Convolutions in PyTorch

Math and code for efficiently computing large convolutions with FFTs.

Frank Odom
Towards Data Science


Photo by Faye Cornish on Unsplash

Note: Complete methods for 1D, 2D, and 3D Fourier convolutions are provided in this GitHub repo. I also provide PyTorch modules for easily adding Fourier convolutions to a trainable model.

Convolutions

Convolutions are ubiquitous in data analysis. For decades, they’ve been used in signal and image processing. More recently, they became an important ingredient in modern neural networks. You’ve probably encountered convolutions if you work with data at all.

Mathematically, convolutions are expressed as:

$$(f * g)(x) = \int_{-\infty}^{\infty} f(y)\, g(x - y)\, dy$$

Although discrete convolutions are more common in computational applications, I’ll work with the continuous form for most of this article, because the Convolution Theorem (discussed below) is much easier to prove using continuous variables. After that, we will return to the discrete case and implement it in PyTorch using Fourier transforms. Discrete convolutions can be viewed as an approximation of continuous ones, where continuous functions are discretized on a regular grid, so we will not re-prove the Convolution Theorem for the discrete case. One caveat: for the discrete Fourier transform, the theorem describes circular convolution, where the arrays wrap around at their edges. This is why we will need to pad and crop our arrays later on.
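The following check makes the discrete caveat concrete. It uses NumPy only for brevity, and nothing later depends on it. The helper `circular_convolve` is a hypothetical name. It computes a circular convolution directly and compares the result with the FFT route.

```python
import numpy as np


def circular_convolve(f: np.ndarray, g: np.ndarray) -> np.ndarray:
    """Direct circular convolution: out[n] = sum_m f[m] * g[(n - m) mod N]."""
    n_points = len(f)
    out = np.zeros(n_points)
    for n in range(n_points):
        for m in range(n_points):
            out[n] += f[m] * g[(n - m) % n_points]
    return out


rng = np.random.default_rng(0)
f = rng.standard_normal(16)
g = rng.standard_normal(16)

direct = circular_convolve(f, g)
# Discrete Convolution Theorem: pointwise product of DFTs, then inverse DFT.
# For real inputs the imaginary part is only round-off noise.
via_fft = np.fft.ifft(np.fft.fft(f) * np.fft.fft(g)).real

print(np.allclose(direct, via_fft))  # True
```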

Convolution Theorem

Mathematically, the Convolution Theorem can be stated as:

$$\mathcal{F}\{f * g\} = \mathcal{F}\{f\} \cdot \mathcal{F}\{g\}$$

where the continuous Fourier transform is (up to a normalization constant):

$$\mathcal{F}\{f\}(k) = \int_{-\infty}^{\infty} f(x)\, e^{-ikx}\, dx$$

In other words, convolution in position space is equivalent to direct (pointwise) multiplication in frequency space. This idea is fairly non-intuitive, but proving the Convolution Theorem is surprisingly easy for the continuous case. To do that, start by writing out the left-hand side of the equation:

$$\mathcal{F}\{f * g\}(k) = \int_{-\infty}^{\infty} \left[ \int_{-\infty}^{\infty} f(y)\, g(x - y)\, dy \right] e^{-ikx}\, dx$$

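Now switch the order of integration, make a substitution of variables (x = y + z), and separate the two integrals:

$$\begin{aligned}
\mathcal{F}\{f * g\}(k) &= \int_{-\infty}^{\infty} f(y) \left[ \int_{-\infty}^{\infty} g(x - y)\, e^{-ikx}\, dx \right] dy \\
&= \int_{-\infty}^{\infty} f(y) \left[ \int_{-\infty}^{\infty} g(z)\, e^{-ik(y + z)}\, dz \right] dy \\
&= \left[ \int_{-\infty}^{\infty} f(y)\, e^{-iky}\, dy \right] \left[ \int_{-\infty}^{\infty} g(z)\, e^{-ikz}\, dz \right] \\
&= \mathcal{F}\{f\}(k) \cdot \mathcal{F}\{g\}(k)
\end{aligned}$$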

So What?

Why should we care about all of this? Because the fast Fourier transform (FFT) has a lower algorithmic complexity than direct convolution. Direct convolution of two inputs of size n has complexity O(n²), because we pass over every element in g for each element in f. FFTs can be computed in O(n log n) time, so they are much faster when the input arrays are large. In those cases, we can use the Convolution Theorem to compute convolutions in frequency space: transform both inputs, multiply them, and then perform the inverse Fourier transform to get back to position space.
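The sketch below illustrates the FFT route. It uses NumPy and a hypothetical helper, `fft_linear_convolve`. Both inputs are zero-padded to the full output length, so the circular convolution from the discrete theorem matches an ordinary linear convolution. The result is then checked against `np.convolve`. This only verifies correctness; it does not measure speed.

```python
import numpy as np


def fft_linear_convolve(f: np.ndarray, g: np.ndarray) -> np.ndarray:
    """Full linear convolution of two 1D arrays via the Convolution Theorem."""
    # Zero-pad both inputs to the full output length, so the circular
    # convolution computed by the DFT never wraps around.
    out_len = len(f) + len(g) - 1
    f_hat = np.fft.rfft(f, n=out_len)
    g_hat = np.fft.rfft(g, n=out_len)
    return np.fft.irfft(f_hat * g_hat, n=out_len)


rng = np.random.default_rng(0)
signal = rng.standard_normal(4096)
kernel = rng.standard_normal(1025)

direct = np.convolve(signal, kernel)  # direct sum, "full" mode by default
via_fft = fft_linear_convolve(signal, kernel)
print(np.allclose(direct, via_fft))  # True
```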

Direct convolutions are still faster when the inputs are small (e.g., 3×3 convolution kernels). In machine learning applications, it’s more common to use small kernel sizes, so deep learning libraries like PyTorch and TensorFlow only provide implementations of direct convolutions. But there are plenty of real-world use cases with large kernel sizes, where Fourier convolutions are more efficient.

PyTorch Implementation

Now, I’ll demonstrate how to implement a Fourier convolution function in PyTorch. It should mimic the functionality of torch.nn.functional.convNd and leverage FFTs behind the curtain, without any additional work from the user. As such, it should accept three Tensors (signal, kernel, and optionally bias) and the padding to apply to the input. Conceptually, the inner workings of this function will be:

1. Pad the input arrays.
2. Compute the Fourier transforms.
3. Multiply the transformed Tensors.
4. Compute the inverse transform.
5. Add the bias and return.

Let’s incrementally build the FFT convolution according to the order of operations shown above. For this example, I’ll just build a 1D Fourier convolution, but it is straightforward to extend this to 2D and 3D convolutions. Or visit my GitHub repo, where I’ve implemented a generic N-dimensional Fourier convolution method.

1 — Pad the Input Arrays

We need signal and kernel to have the same size after padding, because their Fourier transforms will be multiplied element-wise. Apply the initial padding to signal, and then adjust the padding for kernel to match.

Notice that I only pad kernel on one side. We want the original kernel on the left-hand side of the padded array, so that it aligns with the start of the signal array.
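Here is a minimal PyTorch sketch of this step. It assumes conv1d-style shapes: signal is (batch, in_channels, length), kernel is (out_channels, in_channels, kernel_size), and padding is a single integer applied to both sides of the signal. The specific sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

signal = torch.randn(2, 3, 100)  # (batch, in_channels, length)
kernel = torch.randn(4, 3, 9)    # (out_channels, in_channels, kernel_size)
padding = 4

# 1. Pad the signal on both sides, as conv1d's `padding` argument would.
signal_ = F.pad(signal, [padding, padding])

# Pad the kernel on the right only, so it matches the padded signal length
# while the original kernel values stay aligned with the start of the signal.
kernel_padding = [0, signal_.size(-1) - kernel.size(-1)]
padded_kernel = F.pad(kernel, kernel_padding)

print(signal_.shape, padded_kernel.shape)  # torch.Size([2, 3, 108]) torch.Size([4, 3, 108])
```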

2 — Compute Fourier Transforms

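This is very easy, because N-dimensional FFTs are already implemented in PyTorch’s torch.fft module. We simply use the built-in function and compute the FFT along the last dimension of each Tensor. Since both inputs are real-valued, the real-input variant (torch.fft.rfftn) is sufficient, and it stores only the non-redundant half of the spectrum.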
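This is a sketch of step 2 for the 1D case, using arbitrary conv1d-style shapes. It calls `torch.fft.rfft` along the last dimension. For higher-dimensional inputs, `torch.fft.rfftn` with a tuple of dims plays the same role. Because the inputs are real, a length-n array yields only n // 2 + 1 frequencies.

```python
import torch
import torch.nn.functional as F

signal = torch.randn(2, 3, 100)  # (batch, in_channels, length)
kernel = torch.randn(4, 3, 9)    # (out_channels, in_channels, kernel_size)
signal_ = F.pad(signal, [4, 4])
padded_kernel = F.pad(kernel, [0, signal_.size(-1) - kernel.size(-1)])

# 2. Real-input FFT along the last (spatial) dimension only; the batch and
# channel dimensions are left untouched.
signal_fr = torch.fft.rfft(signal_, dim=-1)
kernel_fr = torch.fft.rfft(padded_kernel, dim=-1)

print(signal_fr.shape, signal_fr.dtype)  # torch.Size([2, 3, 55]) torch.complex64
```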

3 — Multiply the Transformed Tensors

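Surprisingly, this is the trickiest part of our function, for two reasons. (1) PyTorch convolutions operate on multi-dimensional Tensors, so our signal and kernel Tensors are actually three-dimensional: signal has shape (batch, in_channels, length), and kernel has shape (out_channels, in_channels, kernel_size). From this equation in the PyTorch docs, we see that a matrix multiplication is performed over the first two dimensions, summing over the input channels (ignoring the bias term):

$$\text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + \sum_{k=0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)$$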

Source: https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv1d

We’ll need to include this matrix multiplication, as well as the direct multiplication over the transformed dimensions.

(2) As shown in the docs, PyTorch actually implements cross-correlation instead of convolution. (The same is true for TensorFlow and other deep learning libraries.) Cross-correlation is very closely related to convolution, but with an important sign change:

$$(f \star g)(x) = \int_{-\infty}^{\infty} f(y)\, g(y - x)\, dy$$

This effectively reverses the orientation of the kernel (g), compared to convolution. Rather than manually flipping the kernel, we correct for this by taking the complex conjugate of our kernel in Fourier space. This is significantly faster and more memory efficient, since we do not need to create an entirely new Tensor. (A brief demonstration of how/why this works is included in the appendix at the end of the article.)
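Here is a quick numerical check of this idea in the discrete setting, using NumPy for brevity and random data. A direct circular cross-correlation, computed with PyTorch's sign convention, matches the inverse FFT of the signal's transform multiplied by the complex conjugate of the kernel's transform.

```python
import numpy as np

rng = np.random.default_rng(0)
f = rng.standard_normal(16)  # signal
g = rng.standard_normal(16)  # kernel, assumed already padded to the signal length
n_points = len(f)

# Direct circular cross-correlation, with the same sign convention as
# PyTorch's conv1d: out[n] = sum_m f[(n + m) mod N] * g[m]
direct = np.array([
    sum(f[(n + m) % n_points] * g[m] for m in range(n_points))
    for n in range(n_points)
])

# Same result in frequency space: conjugate the kernel's transform
# instead of flipping the kernel itself.
via_fft = np.fft.ifft(np.fft.fft(f) * np.conj(np.fft.fft(g))).real

print(np.allclose(direct, via_fft))  # True
```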

PyTorch 1.7 brings improved support for complex numbers, but many operations on complex-valued Tensors are not supported in autograd yet. For now, we have to write our own complex_matmul method as a patch. It’s not ideal, but it works and likely won’t break for future versions.
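Below is one possible `complex_matmul`, sketched under two assumptions. The transformed signal has shape (batch, in_channels, *freqs), and the transformed kernel has shape (out_channels, in_channels, *freqs). Groups are not supported, and the signature is an illustrative assumption that may differ from the one in the repo. The function computes the real and imaginary parts with real-valued `torch.einsum` calls, in the spirit of the workaround described above. It sums over input channels and multiplies pointwise over the frequency dimensions.

```python
from functools import partial

import torch


def complex_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Channel-wise complex matrix product (illustrative sketch).

    a: (batch, in_channels, *freqs), e.g. the transformed signal
    b: (out_channels, in_channels, *freqs), e.g. the transformed kernel
    returns: (batch, out_channels, *freqs)
    """
    matmul = partial(torch.einsum, "bi...,oi...->bo...")
    # (ar + i*ai)(br + i*bi) = (ar*br - ai*bi) + i*(ai*br + ar*bi)
    real = matmul(a.real, b.real) - matmul(a.imag, b.imag)
    imag = matmul(a.imag, b.real) + matmul(a.real, b.imag)
    return torch.complex(real, imag)


torch.manual_seed(0)
signal_fr = torch.randn(2, 3, 55, dtype=torch.complex64)  # stand-in spectra
kernel_fr = torch.randn(4, 3, 55, dtype=torch.complex64)
# Step 3: conjugate the kernel spectrum, then contract over input channels.
kernel_conj = torch.complex(kernel_fr.real, -kernel_fr.imag)
output_fr = complex_matmul(signal_fr, kernel_conj)
print(output_fr.shape)  # torch.Size([2, 4, 55])
```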

4 — Compute the Inverse Transform

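Computing the inverse transform is straightforward using torch.fft.irfftn. Then, crop out the extra array padding, keeping only the first (padded signal length − kernel size + 1) elements, which is the same output length that conv1d produces.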
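This is a self-contained sketch of step 4 for the simplest case: one batch element, one channel, and no extra padding. With a single channel, step 3 reduces to a pointwise product. The sketch uses the 1D `torch.fft.irfft` and deliberately picks an odd signal length to show why its `n` argument matters. The conjugate is written out explicitly with `torch.complex`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
signal = torch.randn(1, 1, 101)  # odd length on purpose
kernel = torch.randn(1, 1, 9)
n = signal.size(-1)

# Steps 1-3 in miniature: right-pad the kernel, transform, multiply by the
# complex conjugate of the kernel spectrum (cross-correlation convention).
padded_kernel = F.pad(kernel, [0, n - kernel.size(-1)])
signal_fr = torch.fft.rfft(signal, dim=-1)
kernel_fr = torch.fft.rfft(padded_kernel, dim=-1)
output_fr = signal_fr * torch.complex(kernel_fr.real, -kernel_fr.imag)

# 4. Inverse transform. Passing n restores the original length; without it,
# irfft assumes an even length (here 100 instead of 101) and the values are wrong.
output = torch.fft.irfft(output_fr, n=n, dim=-1)

# Crop: only the first n - kernel_size + 1 positions are free of wrap-around,
# which is exactly the output length of an unpadded conv1d.
output = output[..., : n - kernel.size(-1) + 1]

print(torch.allclose(output, F.conv1d(signal, kernel), atol=1e-4))
```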

5 — Add Bias and Return

Adding the bias term is also very easy. Remember that the bias has one element for each channel in the output array, so reshape it to broadcast over the batch and spatial dimensions.
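Here is a small sketch of step 5. The helper `add_bias` is hypothetical. It reshapes a bias of shape (out_channels,) to (1, out_channels, 1, ..., 1) so that it broadcasts across the batch and all spatial dimensions. The same code therefore also covers the 2D and 3D cases.

```python
import torch


def add_bias(output: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    """Add a per-channel bias to a (batch, channels, *spatial) tensor."""
    # (1, channels, 1, ..., 1): broadcasts over batch and spatial dims.
    bias_shape = (1, -1) + (1,) * (output.ndim - 2)
    return output + bias.view(bias_shape)


output = torch.randn(2, 4, 100)  # (batch, out_channels, length)
bias = torch.randn(4)            # one value per output channel
result = add_bias(output, bias)
```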

Put It All Together

For completeness, let’s compile all of these snippets into a cohesive function.
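The following sketch puts the five steps together for the 1D case. It assumes stride 1, no dilation, and no groups. `fft_conv1d` is a hypothetical name, and the details may differ from the implementation in the GitHub repo. The kernel conjugate is written out with `torch.complex` so that only the real and imaginary parts pass through `complex_matmul`.

```python
from functools import partial
from typing import Optional

import torch
import torch.nn.functional as F


def complex_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """(batch, in, *freqs) x (out, in, *freqs) -> (batch, out, *freqs), complex."""
    matmul = partial(torch.einsum, "bi...,oi...->bo...")
    real = matmul(a.real, b.real) - matmul(a.imag, b.imag)
    imag = matmul(a.imag, b.real) + matmul(a.real, b.imag)
    return torch.complex(real, imag)


def fft_conv1d(
    signal: torch.Tensor,
    kernel: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    padding: int = 0,
) -> torch.Tensor:
    """FFT-based stand-in for F.conv1d (stride 1, no dilation, no groups)."""
    # 1. Pad the signal on both sides, then right-pad the kernel to match.
    signal = F.pad(signal, [padding, padding])
    n = signal.size(-1)
    padded_kernel = F.pad(kernel, [0, n - kernel.size(-1)])

    # 2. Real-input FFTs along the last dimension.
    signal_fr = torch.fft.rfft(signal, dim=-1)
    kernel_fr = torch.fft.rfft(padded_kernel, dim=-1)

    # 3. Conjugate the kernel spectrum (cross-correlation, like PyTorch),
    #    then sum over input channels.
    kernel_fr = torch.complex(kernel_fr.real, -kernel_fr.imag)
    output_fr = complex_matmul(signal_fr, kernel_fr)

    # 4. Inverse FFT back to length n, then keep the valid region.
    output = torch.fft.irfft(output_fr, n=n, dim=-1)
    output = output[..., : n - kernel.size(-1) + 1]

    # 5. Add one bias value per output channel.
    if bias is not None:
        output = output + bias.view(1, -1, 1)
    return output


torch.manual_seed(0)
signal = torch.randn(3, 3, 1024)
kernel = torch.randn(2, 3, 129)
bias = torch.randn(2)

y_fft = fft_conv1d(signal, kernel, bias=bias, padding=64)
y_direct = F.conv1d(signal, kernel, bias=bias, padding=64)
print((y_fft - y_direct).abs().max())
```

The printed maximum absolute difference should be small, on the order of float32 round-off for sums of this size.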

Test Against Direct Convolution

Finally, we’ll confirm that this is numerically equivalent to direct 1D convolution using torch.nn.functional.conv1d. We construct random Tensors for all inputs, and measure the relative difference in the output values.

Each element differs by about 1e-5, which is pretty accurate, considering that we’re using 32-bit precision! Let’s also perform a quick benchmark to measure the speed of each method:
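This is a rough timing harness, sketched under two assumptions: CPU execution, and a recent PyTorch release in which complex `torch.einsum` and `.conj()` work natively, so the `complex_matmul` workaround is not needed here. `fft_conv1d_native` and `time_it` are hypothetical helpers, padding and bias are omitted, and the tensor sizes are arbitrary. The harness reports the maximum error relative to the largest output of `torch.nn.functional.conv1d`, along with the average time per call. Actual timings depend heavily on hardware.

```python
import time

import torch
import torch.nn.functional as F


def fft_conv1d_native(signal: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
    """Compact FFT cross-correlation (stride 1, no padding, no bias)."""
    n = signal.size(-1)
    signal_fr = torch.fft.rfft(signal, dim=-1)
    # rfft's n argument zero-pads the kernel on the right to length n.
    kernel_fr = torch.fft.rfft(kernel, n=n, dim=-1)
    # Native complex einsum: sum over input channels with the conjugated kernel.
    output_fr = torch.einsum("bif,oif->bof", signal_fr, kernel_fr.conj())
    output = torch.fft.irfft(output_fr, n=n, dim=-1)
    return output[..., : n - kernel.size(-1) + 1]


def time_it(fn, *args, repeats: int = 10) -> float:
    """Average wall-clock seconds per call, after one warm-up call."""
    fn(*args)
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args)
    return (time.perf_counter() - start) / repeats


torch.manual_seed(0)
signal = torch.randn(4, 4, 4096)  # (batch, in_channels, length)
kernel = torch.randn(4, 4, 1025)  # (out_channels, in_channels, kernel_size)

direct = F.conv1d(signal, kernel)
fourier = fft_conv1d_native(signal, kernel)
rel_diff = ((fourier - direct).abs().max() / direct.abs().max()).item()

t_direct = time_it(F.conv1d, signal, kernel)
t_fourier = time_it(fft_conv1d_native, signal, kernel)
print(f"relative difference: {rel_diff:.1e}")
print(f"direct:  {t_direct * 1e3:.1f} ms per call")
print(f"fourier: {t_fourier * 1e3:.1f} ms per call")
```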

Measured benchmarks will change significantly with the machine you’re using. (I’m testing with a very old MacBook Pro.) For a kernel size of 1025, it appears that Fourier convolution is over 10 times faster.

Conclusion

I hope this has provided a thorough introduction to Fourier convolutions. I think this is a really cool trick, and there are lots of real-world applications where it can be used. I also love math, so it’s fun to see this intersection of programming and pure mathematics. All comments and constructive criticism are welcome and encouraged, and please clap if you enjoyed the article!

Appendix

Convolution vs. Cross-Correlation

Earlier in the article, we implemented cross-correlation by taking the complex conjugate of our kernel in Fourier space. I claimed that this effectively reverses the orientation of the kernel, and now I’d like to demonstrate why that is. First, remember the formulae for convolution and cross-correlation:

$$(f * g)(x) = \int_{-\infty}^{\infty} f(y)\, g(x - y)\, dy, \qquad (f \star g)(x) = \int_{-\infty}^{\infty} f(y)\, g(y - x)\, dy$$

Then, let’s look at the Fourier transform of our kernel (g):

$$G(k) = \int_{-\infty}^{\infty} g(x)\, e^{-ikx}\, dx$$

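Take the complex conjugate of G. Note that the kernel g(x) is real-valued, so it is unaffected by conjugation. Then, make a change of variables (y = -x) and simplify the expression:

$$\overline{G(k)} = \int_{-\infty}^{\infty} g(x)\, e^{ikx}\, dx = \int_{-\infty}^{\infty} g(-y)\, e^{-iky}\, dy = \mathcal{F}\{g(-x)\}(k)$$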

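So we’ve effectively flipped the orientation of the kernel! Because cross-correlation with g is the same as convolution with the flipped kernel g(-x), the Convolution Theorem then gives $\mathcal{F}\{f \star g\} = \mathcal{F}\{f\} \cdot \overline{G}$, which is exactly the product computed in step 3.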
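The same identity can be checked numerically for the discrete Fourier transform, here with NumPy and a random real-valued kernel. In the discrete case, g(-x) becomes an index reversal modulo the array length, so g[0] stays in place. The expression `np.roll(g[::-1], 1)` builds that reversed array.

```python
import numpy as np

rng = np.random.default_rng(0)
g = rng.standard_normal(8)  # real-valued kernel

# Discrete analogue of g(-x): reverse the indices modulo N, so
# g_flipped[n] = g[(-n) % N] = [g[0], g[N-1], ..., g[1]].
g_flipped = np.roll(g[::-1], 1)

lhs = np.conj(np.fft.fft(g))  # conjugate of the kernel's transform
rhs = np.fft.fft(g_flipped)   # transform of the flipped kernel
print(np.allclose(lhs, rhs))  # True
```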
