# Wavelet Regression in Python


Last week I needed to get my head around wavelet regression techniques for a project I am working on. This post shows how to do basic wavelet regression in Python using PyWavelets. For references I used Chapter 9 of Wasserman's All of Nonparametric Statistics, Ogden's Essential Wavelets for Statistical Applications and Data Analysis, and Donoho and Johnstone's "Ideal spatial adaptation by wavelet shrinkage". I also recommend "An Introduction to Wavelets" by Amara Graps for a high-level overview and a brief history of the development of wavelet analysis. These references provide much more detail and guidance than I will here.

Wavelets are mathematical functions that provide an orthonormal basis for functions in the $L^2$ space. The use of wavelets is akin to the use of sines and cosines to represent $L^2$ functions in Fourier analysis. In fact, many treatments of wavelets start by reviewing Fourier series representations of functions. Legendre polynomials are another example of a complete orthonormal system.

The simplest wavelet is the Haar wavelet. It is not used much in practice, but it is convenient for fixing ideas. The Haar function is given by

$$\psi(x)=\begin{cases}1,&x \in [0,\frac{1}{2})\\ -1,&x \in [\frac{1}{2},1)\\ 0,&\text{otherwise}\end{cases}$$

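```python
import numpy as np
import pywt
import matplotlib.pyplot as plt
```

```python
# Sample the Haar scaling function (phi) and wavelet (psi) on a fine grid
w = pywt.Wavelet('haar')
phi, psi, x = w.wavefun(level=10)

fig, ax = plt.subplots()
ax.set_xlim(-.02, 1.02)
ax.plot(x, psi);
```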

The location and the scale (width) of this mother wavelet can be controlled by a translation index $k$ and a dilation index $j$, respectively.

$$\psi_{j,k}(x)=2^{j/2}\psi(2^{j}x-k)$$
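To make the dilation and translation concrete, here is a small NumPy sketch (not part of the original post) that builds the Haar $\psi_{j,k}$ for $j=0,1,2$ and checks numerically that they are orthonormal on $[0,1)$. The helper names haar_psi and psi_jk are my own, and the inner products are approximated with a midpoint Riemann sum, which is essentially exact here because no grid point lands on a Haar breakpoint.

```python
import numpy as np

def haar_psi(x):
    """Haar mother wavelet: 1 on [0, 1/2), -1 on [1/2, 1), 0 elsewhere."""
    x = np.asarray(x, dtype=float)
    return np.where((x >= 0) & (x < 0.5), 1.0,
                    np.where((x >= 0.5) & (x < 1), -1.0, 0.0))

def psi_jk(x, j, k):
    """Dilated and translated Haar wavelet 2**(j/2) * psi(2**j * x - k)."""
    return 2 ** (j / 2) * haar_psi(2 ** j * np.asarray(x, dtype=float) - k)

# Midpoint grid on [0, 1): no grid point falls on a Haar breakpoint for j <= 2
n_grid = 2 ** 12
x = (np.arange(n_grid) + 0.5) / n_grid
dx = 1.0 / n_grid

# Every (j, k) with j = 0, 1, 2 and k = 0, ..., 2**j - 1
index = [(j, k) for j in range(3) for k in range(2 ** j)]
basis = np.array([psi_jk(x, j, k) for j, k in index])

# Gram matrix of inner products <psi_jk, psi_j'k'> approximated by a Riemann sum
gram = basis @ basis.T * dx
print(np.round(gram, 6))
```

A (numerically) identity Gram matrix is what orthonormality means here: each coefficient can be computed with a single inner product, independently of the others.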

A wavelet system is fully specified once a scaling function, or father wavelet, is chosen. The scaling function captures the coarse, large-scale part of a function, so the orthonormal basis covers the original function space without needing mother wavelets at ever-coarser scales. The father wavelet for the Haar system is

$$\phi(x)=I_{[0,1)}(x)$$

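where $I_A(x)$ is an indicator function that equals 1 if $x$ is in the set $A$ and 0 otherwise. Together, $\phi$ and the functions $\psi_{j,k}$ with $j \ge 0$ and $0 \le k < 2^j$ form an orthonormal basis for $L^2[0,1]$. While the Haar function is a good introduction, the rest of the code uses the Daubechies wavelet with $N=8$ vanishing moments (db8 in PyWavelets). Note that db8 is the extremal-phase member of the Daubechies family; the most nearly symmetric Daubechies wavelets are the symlets (sym8 in PyWavelets).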
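PyWavelets exposes these properties on the Wavelet object, so a quick sketch can confirm them. This assumes the Wavelet attributes dec_len, vanishing_moments_psi, orthogonal, and symmetry behave as documented in recent PyWavelets releases.

```python
import pywt

db8 = pywt.Wavelet('db8')
sym8 = pywt.Wavelet('sym8')

for w in (db8, sym8):
    # dec_len: decomposition filter length; vanishing_moments_psi: vanishing moments of psi
    print(w.name, w.family_name, w.dec_len, w.vanishing_moments_psi,
          w.orthogonal, w.symmetry)
```

Both are 16-tap orthogonal filters with 8 vanishing moments; they differ in how the filter's phase is chosen, which is what makes sym8 closer to symmetric.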

```python
db8 = pywt.Wavelet('db8')
scaling, wavelet, x = db8.wavefun()

fig, axes = plt.subplots(1, 2, sharey=True, figsize=(8,6))
ax1, ax2 = axes

ax1.plot(x, scaling);
ax1.set_title('Scaling function, N=8');
ax1.set_ylim(-1.2, 1.2);

ax2.set_title('Wavelet, N=8');
ax2.tick_params(labelleft=False);
ax2.plot(x-x.mean(), wavelet);

fig.tight_layout()
```

For the current purposes, the objective of wavelet analysis is to recover the unknown function $f$ from noisy data

$$y_i = f(x_i) + e_i, i \in \left\{1,\dots,n\right\}$$

with $x_i = i/n$ without loss of generality and i.i.d. errors $e_i \sim N(0,\sigma^2)$. The function $f$ can be approximated by a finite sum of wavelets:

$$f_J(x) = \alpha\phi(x) + \sum_{j=0}^{J-1}\sum_{k=0}^{2^j-1}\beta_{jk}\psi_{jk}(x)$$

where

$$\alpha=\int_0^1f(x)\phi(x)dx\text{, }\beta_{jk}=\int_0^1f(x)\psi_{jk}(x)dx.$$

$\alpha$ and $\beta_{jk}$ are called the scaling coefficient and the detail coefficients, respectively. The basic idea is that the detail coefficients capture the finer details of the function, while the scaling, or smoothing, coefficient captures its overall form.
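As a sketch of this expansion for the Haar system, the code below approximates $\alpha$ and $\beta_{jk}$ by midpoint Riemann sums and rebuilds $f_J$. The test function and the helper names are hypothetical. Because the Haar approximation spaces are nested, the error should shrink as $J$ grows, and $f_J$ should equal the average of $f$ over each dyadic interval of length $2^{-J}$.

```python
import numpy as np

def haar_phi(x):
    """Haar father wavelet: indicator of [0, 1)."""
    x = np.asarray(x, dtype=float)
    return np.where((x >= 0) & (x < 1), 1.0, 0.0)

def haar_psi(x):
    """Haar mother wavelet: 1 on [0, 1/2), -1 on [1/2, 1), 0 elsewhere."""
    x = np.asarray(x, dtype=float)
    return np.where((x >= 0) & (x < 0.5), 1.0,
                    np.where((x >= 0.5) & (x < 1), -1.0, 0.0))

def haar_approx(f, J, n_grid=2 ** 14):
    """Return the grid, f on the grid, and f_J built from Riemann-sum coefficients."""
    x = (np.arange(n_grid) + 0.5) / n_grid  # midpoint grid on [0, 1)
    dx = 1.0 / n_grid
    fx = f(x)
    alpha = np.sum(fx * haar_phi(x)) * dx   # scaling coefficient alpha
    fJ = alpha * haar_phi(x)
    for j in range(J):
        for k in range(2 ** j):
            psi = 2 ** (j / 2) * haar_psi(2 ** j * x - k)
            beta = np.sum(fx * psi) * dx    # detail coefficient beta_jk
            fJ = fJ + beta * psi
    return x, fx, fJ

# Hypothetical smooth test function
f = lambda x: np.sin(2 * np.pi * x) + x ** 2

for J in (1, 3, 5, 7):
    x, fx, fJ = haar_approx(f, J)
    print(J, np.sqrt(np.mean((fx - fJ) ** 2)))  # RMS approximation error
```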

Donoho and Johnstone use four test functions that mimic properties of empirical data in domains where wavelets might be useful: Bumps, Blocks, HeaviSine, and Doppler. For completeness, they are defined here.

```python
import numpy as np

def doppler(x):
    """
    Parameters
    ----------
    x : array-like
        Domain of x is [0, 1]
    """
    x = np.asarray(x, dtype=float)
    if not np.all((x >= 0) & (x <= 1)):
        raise ValueError("Domain of doppler is x in [0, 1]")
    return np.sqrt(x*(1-x))*np.sin((2.1*np.pi)/(x+.05))

def blocks(x):
    """
    Piecewise constant function with jumps at t.

    The constant scaling factor 3.655606 is not present in Donoho and Johnstone.
    """
    K = lambda x : (1 + np.sign(x))/2.
    t = np.array([[.1, .13, .15, .23, .25, .4, .44, .65, .76, .78, .81]]).T
    h = np.array([[4, -5, 3, -4, 5, -4.2, 2.1, 4.3, -3.1, 2.1, -4.2]]).T
    return 3.655606 * np.sum(h*K(x-t), axis=0)

def bumps(x):
    """
    A sum of bumps with locations t at the same places as the jumps in blocks.
    The heights h and widths w vary, and the individual bumps are of the
    form K(x) = (1 + |x|)**-4
    """
    K = lambda x : (1. + np.abs(x)) ** -4.
    t = np.array([[.1, .13, .15, .23, .25, .4, .44, .65, .76, .78, .81]]).T
    h = np.array([[4, 5, 3, 4, 5, 4.2, 2.1, 4.3, 3.1, 5.1, 4.2]]).T
    w = np.array([[.005, .005, .006, .01, .01, .03, .01, .01, .005, .008, .005]]).T
    return np.sum(h*K((x-t)/w), axis=0)

def heavisine(x):
    """
    Sinusoid of period 1/2 with two jumps at t = .3 and .72
    """
    return 4 * np.sin(4*np.pi*x) - np.sign(x - .3) - np.sign(.72 - x)
```
```python
# Evaluate the four test functions on 2**11 equally spaced points in [0, 1]
x = np.linspace(0,1,2**11)
dop = doppler(x)
blk = blocks(x)
bmp = bumps(x)
hsin = heavisine(x)
```
```python
fig, axes = plt.subplots(2, 2, figsize=(10,10))
ax1 = axes[0,0]
ax2 = axes[0,1]
ax3 = axes[1,0]
ax4 = axes[1,1]

ax1.plot(x,dop)
ax1.set_title("Doppler")

ax2.plot(x,blk)
ax2.set_title("Blocks")

ax3.plot(x,bmp)
ax3.set_title("Bumps")

ax4.set_title("HeaviSine")
ax4.plot(x,hsin)

for ax in fig.axes:
    ax.tick_params(labelbottom=False, labelleft=False, bottom=False,
                   top=False, left=False, right=False)

fig.tight_layout();
```

An orthogonal discrete wavelet transform is invertible, so the wavelet coefficients fully describe the data. First, generate some data from the noisy Blocks function and apply a multilevel discrete wavelet transform to obtain the smoothing and detail coefficients. Then look at a common diagnostic plot of the wavelet coefficients, drawn with the coef_pyramid_plot function defined below.

```python
def coef_pyramid_plot(coefs, first=0, scale='uniform', ax=None):
    """
    Parameters
    ----------
    coefs : array-like
        Wavelet detail coefficients ordered from the coarsest level to the
        finest, e.g. the output of pywt.wavedec without the approximation
        coefficients. Level i is expected to hold 2**i coefficients.
    first : int, optional
        The first level to plot.
    scale : str {'uniform', 'level'}, optional
        Scale the coefficients using the same scale or independently by
        level.
    ax : Axes, optional
        Matplotlib Axes instance

    Returns
    -------
    Figure : Matplotlib figure instance
        Either the parent figure of ax or a new pyplot.Figure instance if
        ax is None.
    """

    if ax is None:
        import matplotlib.pyplot as plt
        fig = plt.figure()
        # 'facecolor' replaces the 'axisbg' keyword removed from Matplotlib
        ax = fig.add_subplot(111, facecolor='lightgrey')
    else:
        fig = ax.figure

    n_levels = len(coefs)
    n = 2**(n_levels - 1) # assumes periodization mode

    if scale == 'uniform':
        biggest = [np.max(np.abs(np.hstack(coefs)))] * n_levels
    else:
        # multiply by 2 so the highest bars only take up .5
        biggest = [np.max(np.abs(i))*2 for i in coefs]

    for i in range(first, n_levels):
        # spread the 2**i coefficients of level i evenly across [0, n]
        x = np.linspace(2**(n_levels - 2 - i), n - 2**(n_levels - 2 - i), 2**i)
        ymin = n_levels - i - 1 + first
        yheight = coefs[i]/biggest[i]
        ymax = yheight + ymin
        ax.vlines(x, ymin, ymax, linewidth=1.1)

    ax.set_xlim(0,n)
    ax.set_ylim(first - 1, n_levels)
    ax.yaxis.set_ticks(np.arange(n_levels-1,first-1,-1))
    ax.yaxis.set_ticklabels(np.arange(first,n_levels))
    ax.tick_params(top=False, right=False, direction='out', pad=6)
    ax.set_ylabel("Levels", fontsize=14)
    ax.grid(True, alpha=.85, color='white', axis='y', linestyle='-')
    ax.set_title('Wavelet Detail Coefficients', fontsize=16,
                 position=(.5,1.05))
    fig.subplots_adjust(top=.89)

    return fig
```

Generate the data and get the coefficients using the multilevel discrete wavelet transform. Plot the true coefficients and the noisy ones.
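Before calling pywt.wavedec, it may help to see what a multilevel orthogonal DWT does. The NumPy sketch below swaps in the Haar wavelet for db8 to keep the code short: each level splits the current approximation into scaled pairwise sums and differences, and the output list is ordered coarsest first, like pywt.wavedec. haar_dwt and haar_idwt are hypothetical helpers, and their detail coefficients may differ in sign from PyWavelets' Haar convention.

```python
import numpy as np

def haar_dwt(y, levels):
    """Multilevel orthonormal Haar DWT of a signal whose length is divisible by 2**levels.

    Returns [cA_levels, cD_levels, ..., cD_1], coarsest first.
    """
    a = np.asarray(y, dtype=float)
    details = []
    for _ in range(levels):
        even, odd = a[0::2], a[1::2]
        details.append((even - odd) / np.sqrt(2))  # detail: scaled pairwise differences
        a = (even + odd) / np.sqrt(2)              # approximation: scaled pairwise sums
    return [a] + details[::-1]

def haar_idwt(coefs):
    """Invert haar_dwt by undoing one level at a time, coarsest first."""
    a = np.asarray(coefs[0], dtype=float)
    for d in coefs[1:]:
        out = np.empty(2 * len(a))
        out[0::2] = (a + d) / np.sqrt(2)
        out[1::2] = (a - d) / np.sqrt(2)
        a = out
    return a

rng = np.random.default_rng(0)
y = rng.normal(size=2 ** 11)
coefs = haar_dwt(y, levels=11)
print([len(c) for c in coefs])  # [1, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]
print(np.allclose(haar_idwt(coefs), y))  # True: the transform is invertible
```

Because the transform is orthonormal, it also preserves the sum of squares, so i.i.d. $N(0,\sigma^2)$ noise in the data stays i.i.d. $N(0,\sigma^2)$ in the detail coefficients. That is what makes a single noise estimate from the finest level usable at every level.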

```python
from scipy import stats
import numpy as np

np.random.seed(12345)
blck = blocks(np.linspace(0,1,2**11))
nblck = blck + stats.norm().rvs(2**11)  # add N(0, 1) noise

# 'periodization' is the current name of the mode formerly called 'per';
# level=11 decomposes all the way down to a single approximation coefficient
true_coefs = pywt.wavedec(blck, 'db8', level=11, mode='periodization')
noisy_coefs = pywt.wavedec(nblck, 'db8', level=11, mode='periodization')

fig, axes = plt.subplots(2, 1, figsize=(9,14), sharex=True)

fig = coef_pyramid_plot(true_coefs[1:], ax=axes[0]) # omit smoothing coefs
axes[0].set_title("True Wavelet Detail Coefficients");

fig = coef_pyramid_plot(noisy_coefs[1:], ax=axes[1]);
axes[1].set_title("Noisy Wavelet Detail Coefficients");

fig.tight_layout()
```
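One thing to be aware of: level=11 goes deeper than what pywt.dwt_max_level reports as the maximum useful level for a 16-tap filter like db8, and recent PyWavelets versions may warn about boundary effects when asked to do so. The sketch below, assuming a recent PyWavelets, computes that maximum level and checks that the periodized transform at that depth is invertible and preserves energy. The random test signal is hypothetical.

```python
import numpy as np
import pywt

n = 2 ** 11
y = np.random.default_rng(0).normal(size=n)

# Deepest level pywt considers useful for a filter of db8's length (16 taps)
max_level = pywt.dwt_max_level(n, pywt.Wavelet('db8').dec_len)
print(max_level)

coefs = pywt.wavedec(y, 'db8', mode='periodization', level=max_level)
print([len(c) for c in coefs])

# Periodization with an orthogonal wavelet: invertible and energy preserving
recon = pywt.waverec(coefs, 'db8', mode='periodization')
print(np.allclose(recon, y))
print(np.isclose(sum(np.sum(c ** 2) for c in coefs), np.sum(y ** 2)))
```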

Notice that most of the coefficients of the true signal are zero. This is the idea of sparseness: many functions, smooth or otherwise, have a sparse representation in a wavelet basis. The detail coefficients are non-zero where the Blocks function is not flat. Because of the added noise, the noisy signal has many more non-zero detail coefficients at the higher resolutions. To recover the signal from the noisy coefficients, you threshold them: thresholding sets the small coefficients to zero on the assumption that they are noise. There are a number of ways to perform thresholding, and details can be found in the references. I will simply apply soft thresholding, which also shrinks the surviving coefficients toward zero by the threshold, using the universal threshold. The universal threshold is defined as

$$\lambda=\hat{\sigma}\sqrt{2\log n}$$

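where $n$ is the number of observations and $\hat{\sigma}$ is a robust estimate of the noise standard deviation computed from the finest-level detail coefficients. Here, I use the standardized median absolute deviation (MAD) available from statsmodels. Dividing by 0.6745, the 0.75 quantile of the standard normal distribution, makes the MAD a consistent estimate of $\sigma$ for Gaussian noise:

$$\hat{\sigma}=\frac{\operatorname{median}_k\left|\beta_{J-1,k}-\operatorname{median}_{k'}\beta_{J-1,k'}\right|}{0.6745}$$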
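Here is a NumPy-only sketch of the two formulas above. It uses simulated $N(0,1)$ noise as a stand-in for the finest-level detail coefficients, which is an assumption: with real data those coefficients also carry some signal. It also shows why a robust estimate is used, since a handful of large coefficients inflates the sample standard deviation but barely moves the MAD. The function names are my own.

```python
import numpy as np

def mad_sigma(d):
    """Median absolute deviation scaled by 1/0.6745, a robust estimate of sigma."""
    d = np.asarray(d, dtype=float)
    return np.median(np.abs(d - np.median(d))) / 0.6745

def universal_threshold(finest_detail, n):
    """lambda = sigma_hat * sqrt(2 log n), with sigma_hat from the finest detail level."""
    return mad_sigma(finest_detail) * np.sqrt(2 * np.log(n))

rng = np.random.default_rng(12345)
n = 2 ** 11
# Stand-in for the n/2 finest-level detail coefficients: pure N(0, 1) noise
finest = rng.normal(size=n // 2)
print(mad_sigma(finest), universal_threshold(finest, n))

# A few large "signal" coefficients inflate the standard deviation but not the MAD
contaminated = finest.copy()
contaminated[:20] += 50.0
print(np.std(contaminated), mad_sigma(contaminated))
```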

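```python
# stand_mad is not available in recent statsmodels releases; mad applies the
# same 0.6745 normalization by default
from statsmodels.robust import mad

sigma = mad(noisy_coefs[-1])
uthresh = sigma*np.sqrt(2*np.log(len(nblck)))

denoised = noisy_coefs[:]

# Soft-threshold every detail level and leave the smoothing coefficients alone;
# pywt.threshold replaces the older pywt.thresholding.soft
denoised[1:] = [pywt.threshold(c, value=uthresh, mode='soft') for c in denoised[1:]]
```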
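Written out by hand, the soft thresholding applied above and hard thresholding, a common alternative, look roughly like the sketch below. The function names are my own; pywt.threshold with mode='soft' should agree with soft_threshold.

```python
import numpy as np

def soft_threshold(d, lam):
    """Zero out coefficients with |d| <= lam and shrink the rest toward zero by lam."""
    d = np.asarray(d, dtype=float)
    return np.sign(d) * np.maximum(np.abs(d) - lam, 0.0)

def hard_threshold(d, lam):
    """Zero out coefficients with |d| <= lam and keep the rest unchanged."""
    d = np.asarray(d, dtype=float)
    return np.where(np.abs(d) > lam, d, 0.0)

d = np.array([-3.0, -1.0, -0.5, 0.0, 0.5, 1.0, 3.0])
print(soft_threshold(d, 1.0))
print(hard_threshold(d, 1.0))
```

Soft thresholding moves every surviving coefficient toward zero by $\lambda$, which gives a smoother estimate at the cost of some bias in the large coefficients; hard thresholding keeps them unchanged but is discontinuous at $\pm\lambda$.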

We can recover the signal by applying the inverse discrete wavelet transform to the thresholded coefficients.

```python
# The inverse DWT of the thresholded coefficients is the denoised estimate
signal = pywt.waverec(denoised, 'db8', mode='periodization')

fig, axes = plt.subplots(1, 2, sharey=True, sharex=True,
                         figsize=(10,8))
ax1, ax2 = axes

ax1.plot(signal)
ax1.set_xlim(0,2**10)
ax1.set_title("Recovered Signal")
ax1.margins(.1)

ax2.plot(nblck)
ax2.set_title("Noisy Signal")

for ax in fig.axes:
    ax.tick_params(labelbottom=False, top=False, bottom=False, left=False,
                   right=False)

fig.tight_layout()
```
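To tie the steps together, here is an end-to-end sketch that depends only on NumPy. It substitutes the Haar wavelet for db8 and a hypothetical piecewise-constant signal for Blocks, so it illustrates the procedure rather than reproducing the figures above: forward transform, MAD noise estimate from the finest level, universal soft threshold on the detail coefficients, and inverse transform.

```python
import numpy as np

def haar_dwt(y, levels):
    """Orthonormal multilevel Haar DWT, returned coarsest first."""
    a = np.asarray(y, dtype=float)
    details = []
    for _ in range(levels):
        even, odd = a[0::2], a[1::2]
        details.append((even - odd) / np.sqrt(2))
        a = (even + odd) / np.sqrt(2)
    return [a] + details[::-1]

def haar_idwt(coefs):
    """Inverse of haar_dwt."""
    a = np.asarray(coefs[0], dtype=float)
    for d in coefs[1:]:
        out = np.empty(2 * len(a))
        out[0::2] = (a + d) / np.sqrt(2)
        out[1::2] = (a - d) / np.sqrt(2)
        a = out
    return a

def denoise(y, levels):
    """Universal soft-threshold wavelet shrinkage with a Haar basis."""
    coefs = haar_dwt(y, levels)
    finest = coefs[-1]
    sigma = np.median(np.abs(finest - np.median(finest))) / 0.6745  # MAD estimate
    lam = sigma * np.sqrt(2 * np.log(len(y)))                        # universal threshold
    shrunk = [coefs[0]] + [np.sign(d) * np.maximum(np.abs(d) - lam, 0.0)
                           for d in coefs[1:]]                       # keep smoothing coef
    return haar_idwt(shrunk)

n = 2 ** 11
x = np.arange(n) / n
# Hypothetical piecewise-constant test signal (not the Donoho-Johnstone Blocks function)
f = np.select([x < 0.25, x < 0.5, x < 0.8], [0.0, 4.0, -2.0], default=3.0)
rng = np.random.default_rng(0)
y = f + rng.normal(size=n)

fhat = denoise(y, levels=11)
print("MSE noisy:   ", np.mean((y - f) ** 2))
print("MSE denoised:", np.mean((fhat - f) ** 2))
```

The Haar basis suits a piecewise-constant signal; for smoother signals a longer filter such as db8 usually gives a sparser representation, which is presumably why the post uses it.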

PyWavelets is a really cool project, but it could use a little more attention. It would be nice if some of the library functions were wrapped up so that they return wavelet objects that are easier to work with. For instance, if wavedec returned a discrete wavelet object that the thresholding functions also accepted, it would save me some keystrokes in the above. I do not have time to take on another project, but it is low-hanging fruit if someone else does. The code is MIT licensed and might even find a nice home in SciPy.