Stochastic Taylor Derivative Estimator (STDE)

Author: Shi Zekun

Published on January 13, 2025

How to tackle the curse of dimensionality and the exponential curse in derivative order


The problem

Suppose we want to solve optimization problems where the loss function $f$ contains differential operators:

$$\argmin_{\theta} f(\mathbf{x}, u_{\theta}(\mathbf{x}), \mathcal{D}^{\alpha^{(1)}} u_{\theta}(\mathbf{x}), \dots, \mathcal{D}^{\alpha^{(n)}} u_{\theta}(\mathbf{x})), \quad u_{\theta}:\mathbb{R}^{d} \to \mathbb{R}^{d'},$$

where $u_{\theta}$ is a neural network. A prominent example of this problem is physics-informed neural networks (PINNs), where the loss is the PDE residual.
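For instance (a standard PINN setup, stated here for concreteness rather than taken from the paper), for the Poisson equation $\nabla^{2} u = f$ on a domain $\Omega$, the interior residual loss over $N$ collocation points reads

$$\min_{\theta} \frac{1}{N}\sum_{i=1}^{N} \left( \nabla^{2} u_{\theta}(\mathbf{x}_{i}) - f(\mathbf{x}_{i}) \right)^{2}, \quad \mathbf{x}_{i} \in \Omega,$$

so every training step has to evaluate the second-order operator $\nabla^{2} u_{\theta}$ at a batch of points.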

Naturally, we will be using auto-diff (AD) to handle the differential operators $\mathcal{D}^{\alpha^{(i)}}$. But how do we handle high-order derivatives like $\frac{\partial^2 u_{\theta}}{\partial x^2}$? The simplest way would be to apply backward mode AD (backpropagation) twice. In JAX, this can be done as

```python
# works when u is scalar-in, scalar-out; for vector inputs one would use jax.hessian
jax.grad(jax.grad(u))
```

and in PyTorch

```python
# x is a scalar tensor with requires_grad=True and u(x) returns a scalar
u_x = torch.autograd.grad(u(x), x, create_graph=True)[0]  # 1st derivative, keep graph
u_xx = torch.autograd.grad(u_x, x)[0]                     # 2nd derivative
```

However, doing this runs into both a curse of dimensionality and an exponential curse in derivative order! A short asymptotic analysis (for an $L$-layer MLP with hidden width $h$, input dimension $d$, and derivative order $k$) shows that the memory scales as $\mathcal{O}({\color{red}2^{k-1}}({\color{orange}d}+(L-1)h))$ and the compute scales as $\mathcal{O}({\color{red}2^{k}}({\color{orange}d}h+(L-1)h^{2}))$. Note the factor of ${\color{orange}d}$ (the curse of dimensionality) and the factor of ${\color{red}2^{k}}$ (the exponential curse in derivative order).
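To make the cost concrete, here is a minimal JAX sketch (my illustration, not code from the paper) of the naive route for the Laplacian: build the full $d \times d$ Hessian with nested AD and take its trace.

```python
import jax
import jax.numpy as jnp

def u(x):
    # toy scalar field u: R^d -> R
    return jnp.sum(jnp.sin(x)) * jnp.sum(x ** 2)

def naive_laplacian(x):
    # jax.hessian composes forward- and reverse-mode AD and materializes
    # the whole d x d Hessian, only to sum its diagonal afterwards.
    return jnp.trace(jax.hessian(u)(x))

x = jnp.ones(1000)          # already 10^6 Hessian entries at d = 1,000
print(naive_laplacian(x))   # exact, but memory/compute grow with d (and worse for higher k)
```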

Why repeatedly applying backward mode AD is bad

Now, let's look at the compute graph of repeated backward mode AD and see why it is a bad idea. Suppose we have a 4-layer MLP $u=F_{4}\circ F_{3} \circ F_2 \circ F_1$ with hidden size $h$. Denote the activations as $\mathbf{y}_{i}=F_{i}(\mathbf{y}_{i-1})$ and the intermediate cotangents as $\mathbf{v}_{i}^{\top}=\mathbf{v}_{i-1}^{\top}\frac{\partial F_{L-i+1}}{\partial \mathbf{y}_{L-i}}$.

Performing backward mode AD computes the vector-Jacobian product (VJP) $\mathbf{v}^{\top}\frac{\partial F}{\partial \mathbf{x}}$ with the cotangent $\mathbf{v}^{\top}=1$. The first row of the compute graph below depicts the VJP: first the forward pass is performed, then the backward pass. Notice that the backward pass can only start once the forward pass is completed, since the activations are needed.
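In JAX terms, a single VJP looks like this (a minimal sketch of my own, not from the post):

```python
import jax
import jax.numpy as jnp

u = lambda x: jnp.sum(jnp.sin(x))   # scalar output, so the cotangent is just v = 1.0
x = jnp.ones(3)

y, u_vjp = jax.vjp(u, x)            # forward pass: u(x) is computed, activations are stored
(grad_x,) = u_vjp(1.0)              # backward pass: v^T (du/dx) with v = 1
print(grad_x)                       # equals cos(x) here
```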

[Figure: compute graph of a single VJP, i.e., a forward pass followed by a backward pass.]

Now, suppose we apply VJP twice. This essentially treats the entire VJP compute graph (the blue box) as the forward pass, and as before, a backward pass is created for each node in the forward pass, and we now have twice as many activations to store:

[Figure: compute graph of two nested VJPs, where the first VJP graph becomes the forward pass of the second.]

From this, we see that the compute graph doubles in size with every repeated application of backward mode AD, which is the origin of the exponential curse in derivative order.

Amortization

The cost of an expensive operation can be amortized over an iterative optimization procedure by using a cheap stochastic estimator. The most well-known example is stochastic gradient descent (SGD), which uses a cheaply computed gradient estimate instead of the full gradient.
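As a reminder (standard SGD, written out here for comparison with what follows), for a loss averaged over $N$ samples the mini-batch estimate is

$$\nabla_{\theta} \frac{1}{N}\sum_{i=1}^{N} \ell_{i}(\theta) \;\approx\; \frac{1}{B}\sum_{j=1}^{B} \nabla_{\theta} \ell_{I_{j}}(\theta), \quad I_{j} \sim \mathrm{Uniform}\{1,\dots,N\},$$

which is unbiased and whose per-step cost is independent of $N$.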

The idea of SGD was extended to differential operators in the recent work SDGD (Hu et al. 2023). Take the Laplacian operator as an example: the Laplacian of a function is the sum of the diagonal elements of its Hessian. If we treat these diagonal elements as data, then at each gradient descent step we can use a mini-batch instead of the full batch by sampling uniformly among the diagonal elements:

$$\nabla^2 =\sum_{i=1}^d \frac{\partial^2}{\partial x_{i}^2} \;\approx\; \frac{d}{B}\sum_{j=1}^{B} \frac{\partial^2}{\partial x^2_{I_{j}}}, \quad I_{j} \sim \mathrm{Uniform}\{1,\dots,d\}.$$

Essentially, SDGD converts a $d$-dimensional problem into $B$ one-dimensional problems. By employing amortization, SDGD effectively removes the curse of dimensionality:

  • memory complexity: $\mathcal{O}\{{\color{red}2^{k-1}}({\color{orange}d}+(L-1)h)\} \to \mathcal{O}\{{\color{orange}B} ({\color{red}2^{k-1}}({\color{orange}1}+(L-1)h))\}$.
  • computation complexity: $\mathcal{O}\{{\color{red}2^{k}}({\color{orange}d}h+(L-1)h^{2})\} \to \mathcal{O}\{{\color{orange}B}{\color{red}2^{k}}(h+(L-1)h^{2})\}$.
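A minimal JAX sketch of this coordinate-sampling estimator (my illustration of the idea, not the SDGD reference implementation): each sampled coordinate only needs one 1-D second derivative, computed here with a Hessian-vector product.

```python
import jax
import jax.numpy as jnp

def sdgd_laplacian(u, x, key, batch=16):
    """Unbiased estimate of the Laplacian of scalar u at x via sampled coordinates."""
    d = x.shape[0]
    idx = jax.random.randint(key, (batch,), 0, d)      # sampled dimensions I_j

    def second_deriv(i):
        e_i = jnp.zeros(d).at[i].set(1.0)              # unit vector along dimension i
        hv = jax.jvp(jax.grad(u), (x,), (e_i,))[1]     # Hessian-vector product H e_i
        return hv[i]                                   # (H e_i)_i = d^2 u / dx_i^2

    return (d / batch) * jnp.sum(jax.vmap(second_deriv)(idx))

# usage: sdgd_laplacian(lambda x: jnp.sum(jnp.sin(x)), jnp.ones(100), jax.random.PRNGKey(0))
```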

With SDGD, one can solve 100,000-dimensional PDEs on a 40 GB A100 GPU in around 12 hours. That is certainly a big step forward, but the exponential curse in derivative order persists, as the asymptotics above show. The amortization would be much more efficient if the exponential scaling in derivative order were also reduced. This is exactly the goal of STDE (Shi et al. 2024).

Generalizing the Hutchinson trace estimator (HTE)

The construction of STDE is very much inspired by HTE (Hutchinson 1989), a cheap stochastic estimator for the trace of a matrix. The gist of HTE can be written down in just a few lines:

$$\begin{split} \text{tr}(\mathbf{A}) =& \mathbb{E}_{\mathbf{v} \sim p(\mathbf{v})}\left[ \mathbf{v}^{\mathrm{T}} \mathbf{A} \mathbf{v}\right] =\mathbf{A} \cdot {\mathbb{E}_{\mathbf{v} \sim p(\mathbf{v})}\left[ \mathbf{v} \mathbf{v}^{\mathrm{T}}\right]} =\mathbf{A} \cdot {\mathbf{I}}, \quad \mathbf{v} \in \mathbb{R}^d \\ \approx& \frac{1}{B} \sum_{i=1}^{B} \mathbf{v}^{(i)\mathrm{T}} \mathbf{A} \mathbf{v}^{(i)}, \quad \mathbf{v}^{(i)} \sim p(\mathbf{v}) \end{split}$$

where the highlighted part, i.e., the requirement that the distribution $p$ is isotropic ($\mathbb{E}[\mathbf{v}\mathbf{v}^{\mathrm{T}}]=\mathbf{I}$), is the constraint that must be satisfied for the equality to hold. Geometrically, this constraint says that the random rank-1 projection $\mathbf{v} \mathbf{v}^{\mathrm{T}}$ is, in expectation, a constant map (in fact the identity).
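A minimal JAX sketch of the estimator (my illustration, with Gaussian probes), applied to the Hessian of a scalar function so that the $d \times d$ matrix never has to be formed:

```python
import jax
import jax.numpy as jnp

def hte_trace_of_hessian(u, x, key, batch=64):
    """Hutchinson estimate of tr(H), i.e. the Laplacian of u at x."""
    d = x.shape[0]
    vs = jax.random.normal(key, (batch, d))            # Gaussian probes: E[v v^T] = I

    def quadratic_form(v):
        hv = jax.jvp(jax.grad(u), (x,), (v,))[1]       # Hessian-vector product H v
        return jnp.dot(v, hv)                          # v^T H v

    return jnp.mean(jax.vmap(quadratic_form)(vs))

# usage: hte_trace_of_hessian(lambda x: jnp.sum(jnp.sin(x)), jnp.ones(100), jax.random.PRNGKey(0))
```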

HTE can be applied to obtain a stochastic estimate of the Laplacian, since the Laplacian is the trace of the Hessian. Can we extend the construction to arbitrary differential operators? The first step is to write a differential operator $\mathcal{L}$ in the following form:

$$\mathcal{L}u(\mathbf{a}) = D^{k}_{u}(\mathbf{a}) \cdot \mathbf{C}(\mathcal{L}),$$

where $D^{k}_{u}(\mathbf{a})$ is the $k$-th order derivative tensor of $u$ at the point $\mathbf{a}$, $\mathbf{C}(\mathcal{L})$ is a coefficient tensor of the same shape as $D^{k}_{u}(\mathbf{a})$, and $\cdot$ denotes full contraction. For example, the Laplacian can be written in this form as

$$\nabla^2 u(\mathbf{a}) = \sum_{i=1}^d \frac{\partial^2 u}{\partial x^2_{i}} = D^{2}_{u}(\mathbf{a}) \cdot \underbrace{\mathbf{I}}_{\mathbf{C}(\nabla^2)}.$$

So HTE applied to the Laplacian can be interpreted as a randomized rank-1 decomposition of the coefficient tensor $\mathbf{I}$. To generalize the idea to an arbitrary differential operator, one just needs to replace the isotropy condition with a randomized rank-1 decomposition of a general coefficient tensor $\mathbf{C}(\mathcal{L})$:

$$\mathbb{E}_{p}\left[\otimes_{i=1}^{k}\mathbf{v}^{(i)}\right] = \mathbf{C}(\mathcal{L}).$$
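For the Laplacian ($k=2$, $\mathbf{C}=\mathbf{I}$), two familiar choices of $p$ satisfy this condition (both standard facts, stated here for illustration):

$$\mathbb{E}_{\mathbf{v}\sim\mathcal{N}(\mathbf{0},\mathbf{I})}\left[\mathbf{v}\otimes\mathbf{v}\right]=\mathbf{I}, \qquad \mathbb{E}_{I\sim\mathrm{Uniform}\{1,\dots,d\}}\left[d\,\mathbf{e}_{I}\otimes\mathbf{e}_{I}\right]=\mathbf{I};$$

the first recovers the Gaussian HTE above, and the second recovers the coordinate sampling of SDGD (up to the $d/B$ scaling).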

With a distribution $p$ satisfying the above, we can write the action of the differential operator on the network $u$ as an expectation over random projections of $D^{k}_{u}$, which can be estimated via Monte Carlo:

$$\mathcal{L}u(\mathbf{a}) = \mathbb{E}_{p}\left[D^{k}_{u}(\mathbf{a}) \cdot \otimes_{i=1}^{k}\mathbf{v}^{(i)}\right].$$
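Written out (just instantiating the expectation above), the Monte Carlo estimator is

$$\mathcal{L}u(\mathbf{a}) \approx \frac{1}{B}\sum_{b=1}^{B} D^{k}_{u}(\mathbf{a}) \cdot \otimes_{i=1}^{k}\mathbf{v}^{(i)}_{b}, \quad \left(\mathbf{v}^{(1)}_{b},\dots,\mathbf{v}^{(k)}_{b}\right) \sim p.$$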

The challenge now is to compute the random projection $D^{k}_{u}(\mathbf{a}) \cdot \otimes_{i=1}^{k}\mathbf{v}^{(i)}$ efficiently. One needs to avoid computing the full derivative tensor $D^{k}_{u}$, since its size grows exponentially with the derivative order $k$ (for $d=1{,}000$ and $k=3$ it already has $10^{9}$ entries):

  • $k=1$ (Jacobian):
$$D^{1}_{u}=\left[ \frac{\partial u}{\partial x_{1}} \ \dots \ \frac{\partial u}{\partial x_{d}}\right] \in \mathbb{R}^{d}$$
  • $k=2$ (Hessian):
$$D^{2}_{u}=\begin{bmatrix} \frac{\partial^2 u}{\partial x_{1} \partial x_{1}} & \dots & \frac{\partial^2 u}{\partial x_{1} \partial x_{d}} \\ \vdots & & \vdots \\ \frac{\partial^2 u}{\partial x_{d}\partial x_{1}} & \dots & \frac{\partial^2 u}{\partial x_{d} \partial x_{d}} \end{bmatrix} \in \mathbb{R}^{d\times d}$$
  • $k=3$: $D^{3}_{u}=\left[ \frac{\partial ^{3} u}{ \partial x_{i} \partial x_{j} \partial x_{k}} \right]_{ijk} \in \mathbb{R}^{d\times d\times d}$

It turns out that one can use high-order directional derivatives to compute this random projection efficiently. In the following, we will discuss

  1. how to express an arbitrary projection of the derivative tensor in terms of high-order directional derivatives, and
  2. how high-order directional derivatives can be computed efficiently.

High-order directional derivatives

First, let's review the concept of a first-order directional derivative. Suppose we have a scalar-valued function $u: \mathbb{R}^{d} \to \mathbb{R}$. The directional derivative of $u$ at a point $\mathbf{a}$ in the direction $\mathbf{v}$ is the rate of change of $u$ along the curve $g(t)=\mathbf{a}+ t \mathbf{v}$:

$$\partial u(\mathbf{a}, \mathbf{v}) :=\partial_{\mathbf{v}} u(\mathbf{a}) = \frac{\mathrm{d}}{\mathrm{d}t}[u\circ g](0) = \frac{\partial u}{\partial \mathbf{x}} \mathbf{v}.$$

The last equality comes from the chain rule, from which we see that the directional derivative is a Jacobian-vector product (JVP). One important thing to notice is that, regardless of the input dimension of $u$, its restriction to the curve, $u\circ g$, is always a one-dimensional function, as the illustration below shows. This also means that $\mathbf{v}$ always has the same dimension as the input.
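In JAX, this is exactly what jax.jvp computes; a tiny sketch (my illustration):

```python
import jax
import jax.numpy as jnp

u = lambda x: jnp.sum(jnp.sin(x))        # scalar field on R^d
a = jnp.ones(3)
v = jnp.array([1.0, 0.0, 2.0])

# first-order directional derivative: d/dt u(a + t v) at t = 0
_, dudv = jax.jvp(u, (a,), (v,))
print(dudv)                              # equals cos(1) * (1 + 0 + 2) here
```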

[Figure: the restriction of $u$ to the curve $g$ is a one-dimensional function of $t$.]

Now we are ready to generalize this concept to higher orders. Suppose the curve is not a straight line, so that it has non-zero derivatives up to, say, order $k$. Let $\mathbf{v}^{(n)}=\frac{\partial^n g}{\partial t^n} \big\vert_{t=0}$; we call $\mathbf{v}^{(n)}$ the $n$-th input tangent. We can now compute up to the $k$-th order rate of change of $u$ along this curve $g$, since $u\circ g$ has non-zero derivatives up to order $k$. With this, we define the directional derivative of order $k$ as the $k$-th order rate of change along a curve with input tangents $\{\mathbf{v}^{(n)}\}_{n=1}^{k}$:

$$\partial^{k}u(\mathbf{a}, \mathbf{v}^{(1)}, \dots, \mathbf{v}^{(k)}) = \frac{\partial^k}{\partial t^k} [u \circ g](0).$$

The above expression can be expanded with Faà di Bruno's formula, which can be understood as the high-order chain rule:

$$\frac{\partial^k}{\partial t^k} [u\circ g](t) = \sum_{\substack{(p_1, \dots, p_{k})\in \mathbb{N}^{k}, \\ \sum_{i=1}^k i\cdot p_i=k}} \frac{k!}{\prod_{i}^{k} p_{i}! (i!)^{p_{i}}} \cdot {\color{title}D_{u}^{\sum_{i=1}^k p_{i}} (\mathbf{a})_{d_1, \dots , d_{\sum_{i=1}^k p_{i}}} \cdot \prod_{j=1}^{k} \left( \frac{1}{j!} v^{(j)}_{d_{j}} \right)^{p_{j}}}.$$

The important observation is that an arbitrary projection $D^{k}_{u}(\mathbf{a}) \cdot \otimes_{i=1}^{k}\mathbf{v}^{(i)}$ can be recovered from the terms of this formula!

The actual procedure for expressing an arbitrary projection in terms of high-order directional derivatives is involved, so I'll omit it here. To build some intuition, here are all possible contractions for $k=2$ expressed through high-order directional derivatives:

$$\begin{aligned} \partial^{2}u(\mathbf{a},\mathbf{v},\mathbf{0})&=\frac{\partial^2 u}{\partial x_{i} \partial x_{j}}v_{i}v_{j}, \\ \partial^{3}u(\mathbf{a},\mathbf{v},\mathbf{v}',\mathbf{0})-\partial^{3}u(\mathbf{a},\mathbf{v},\mathbf{0},\mathbf{0})&=\frac{\partial^2 u}{\partial x_{i} \partial x_{j}}v_{i}v'_{j}. \end{aligned}$$

The highest order of directional derivative required is $3$ for the case above. In general, this order does not get too big: in the paper, I show that an arbitrary contraction of the derivative tensor $D_{u}^{k}$ can be computed with $\partial^{l}$, where $l$ is at most $k(k+1)/2$.
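As a quick sanity check of the first identity (my illustration, using nested first-order JVPs, i.e., exactly the kind of composition that Taylor mode will let us avoid):

```python
import jax
import jax.numpy as jnp

u = lambda x: jnp.sum(jnp.sin(x)) * x[0]      # some smooth scalar field
a = jnp.array([0.3, -1.2, 2.0])
v = jnp.array([1.0, 0.5, -2.0])

# second-order directional derivative of t -> u(a + t v) at t = 0, via jvp-of-jvp
first = lambda x: jax.jvp(u, (x,), (v,))[1]
d2 = jax.jvp(first, (a,), (v,))[1]

print(d2, v @ jax.hessian(u)(a) @ v)          # the two numbers agree
```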

Finally, it is worth mentioning that the above result on the forward propagation of univariate Taylor series generalizes earlier work (Griewank, Utke, and Walther 2000), which only uses first-order perturbations (input tangents).

Taylor-mode AD

Forward mode AD computes JVPs, which are first-order directional derivatives. Analogously, Taylor-mode AD (Bettencourt, Johnson, and Duvenaud 2019) computes high-order directional derivatives using only forward passes. From Faà di Bruno's formula, one can see that the order-$k$ directional derivative depends on all input tangents from order $1$ to $k$. So instead of computing $\partial^{k} u$ directly, one computes the whole Taylor tower $\mathrm{d}^{k} u = (u, \partial^{1} u, \dots, \partial^{k} u)$, as depicted below.

[Figure: compute graph of Taylor-mode AD, which propagates the whole Taylor tower $\mathrm{d}^{k} u$ in the forward pass.]

From the compute graph, we see that the scaling in $k$ is now linear instead of exponential, and the computation can be parallelized.
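Putting the pieces together, here is a sketch of an STDE-style Laplacian estimate built on JAX's Taylor-mode primitive jax.experimental.jet. This is my illustration rather than the paper's implementation, and it assumes jet's documented convention that, with input tangents $(\mathbf{v}, \mathbf{0})$, the second output coefficient equals the second-order directional derivative $\partial^{2}u(\mathbf{x}, \mathbf{v}, \mathbf{0}) = \mathbf{v}^{\top} \mathbf{H} \mathbf{v}$; the final lines check the estimate against the exact Laplacian of a toy function.

```python
import jax
import jax.numpy as jnp
from jax.experimental import jet

def stde_laplacian(u, x, key, batch=4096):
    """STDE-style estimate of the Laplacian of scalar u at x using Taylor-mode AD."""
    d = x.shape[0]
    vs = jax.random.normal(key, (batch, d))            # isotropic tangents: E[v v^T] = I

    def second_dir_deriv(v):
        # propagate the 2nd-order Taylor tower in ONE forward pass (no nested AD)
        _, (_, d2u) = jet.jet(u, (x,), ((v, jnp.zeros_like(v)),))
        return d2u                                     # assumed to equal v^T H v (see above)

    return jnp.mean(jax.vmap(second_dir_deriv)(vs))

# sanity check on a toy function with a known Laplacian
u = lambda x: jnp.sum(jnp.sin(x))
x = jnp.ones(4)
print(stde_laplacian(u, x, jax.random.PRNGKey(0)),     # Monte Carlo estimate
      jnp.trace(jax.hessian(u)(x)))                    # exact value, for comparison
```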

Experiment results

Now you may wonder how much speedup we can expect from STDE in practice, and where the performance gain comes from.

In section 5.2 of the STDE paper (Shi et al. 2024), I did an ablation study on a two-body Allen-Cahn equation, which is a nonlinear PDE with a zero boundary condition:

$$\begin{split} \mathcal{L} u(\mathbf{x}) = \nabla^2 u(\mathbf{x}) + u(\mathbf{x}) - u(\mathbf{x})^{3} =& f(\mathbf{x}), \quad \mathbf{x} \in \mathbb{B}^{d} \\ u(\mathbf{x}) =& 0, \quad\quad \mathbf{x} \in \partial\mathbb{B}^{d} \end{split}$$

where the source term is chosen to ensure that the solution is effectively high-dimensional

$$f(\mathbf{x}) = \mathcal{L} \left\{ (1- \|\mathbf{x}\|_{2}^{2}) \left( \sum_{i=1}^{d-1} c_{i} \sin ( x_{i} + \cos (x_{i+1}) + x_{i+1} \cos (x_{i}) ) \right)\right\}, \quad c_{i}\sim \mathcal{N}(0,1).$$

We will amortize the PINN training by using a stochastic estimator for the Laplacian term. The original implementation of the baseline method SDGD uses a for-loop to iterate through the sampled dimensions and is implemented in PyTorch (first row of the tables below). To separate the sources of the performance gain, I implemented SDGD in JAX (second row) as well as more efficient variants of the first-order AD approach (rows 3 and 4). For more details on these, see Appendix A in the paper.

I also included Forward Laplacian (Li et al. 2023), which provides a constant-factor optimization for the computation of the Laplacian operator by removing redundancy in the AD pipeline, but it is not randomized. As expected, it performs very well in the low-dimensional cases but does not scale well with dimension.

Table 1: Speed Ablation

| Speed (it/s) $\uparrow$ | 100 D | 1K D | 10K D | 100K D | 1M D |
|---|---|---|---|---|---|
| Backward mode SDGD (PyTorch) (Hu et al. 2023) | 55.56 | 3.70 | 1.85 | 0.23 | OOM |
| Backward mode SDGD | 40.63 | 37.04 | 29.85 | OOM | OOM |
| Parallelized backward mode SDGD | 1376.84 | 845.21 | 216.83 | 29.24 | OOM |
| Forward-over-Backward SDGD | 778.18 | 560.91 | 193.91 | 27.18 | OOM |
| Forward Laplacian (Li et al. 2023) | 1974.50 | 373.73 | 32.15 | OOM | OOM |
| STDE | 1035.09 | 1054.39 | 454.16 | 156.90 | 13.61 |

For the 1M D case, the model converges in under 10k steps, which takes only around 10 minutes!

Table 2: Memory Ablation

| Memory (MB) $\downarrow$ | 100 D | 1K D | 10K D | 100K D | 1M D |
|---|---|---|---|---|---|
| Backward mode SDGD (PyTorch) (Hu et al. 2023) | 1328 | 1788 | 4527 | 32777 | OOM |
| Backward mode SDGD | 553 | 565 | 1217 | OOM | OOM |
| Parallelized backward mode SDGD | 539 | 579 | 1177 | 4931 | OOM |
| Forward-over-Backward SDGD | 537 | 579 | 1519 | 4929 | OOM |
| Forward Laplacian (Li et al. 2023) | 507 | 913 | 5505 | OOM | OOM |
| STDE | 543 | 537 | 795 | 1073 | 6235 |

The memory savings of STDE are significant: for the 1M D case, only about 6 GB of memory is required, whereas all other methods need more than 40 GB.


Note

STDE received the best paper award at NeurIPS 2024. The code can be found here.


References

Bettencourt, Jesse, Matthew J. Johnson, and David Duvenaud. 2019. “Taylor-Mode Automatic Differentiation for Higher-Order Derivatives in JAX.” In Program Transformations for Ml Workshop at Neurips 2019. https://openreview.net/forum?id=SkxEF3FNPH.
Griewank, Andreas, Jean Utke, and Andrea Walther. 2000. “Evaluating Higher Derivative Tensors by Forward Propagation of Univariate Taylor Series.” Mathematics of Computation 69 (231): 1117–31. https://doi.org/10.1090/s0025-5718-00-01120-0.
Hu, Zheyuan, Khemraj Shukla, George Em Karniadakis, and Kenji Kawaguchi. 2023. “Tackling the Curse of Dimensionality with Physics-Informed Neural Networks.” arXiv. https://doi.org/10.48550/arXiv.2307.12306.
Hutchinson, M.F. 1989. “A Stochastic Estimator of the Trace of the Influence Matrix for Laplacian Smoothing Splines.” Communications in Statistics - Simulation and Computation 18 (3): 1059–76. https://doi.org/10.1080/03610918908812806.
Li, Ruichen, Haotian Ye, Du Jiang, Xuelan Wen, Chuwei Wang, Zhe Li, Xiang Li, et al. 2023. “Forward Laplacian: A New Computational Framework for Neural Network-Based Variational Monte Carlo.” arXiv. https://doi.org/10.48550/arXiv.2307.08214.
Shi, Zekun, Zheyuan Hu, Min Lin, and Kenji Kawaguchi. 2024. “Stochastic Taylor Derivative Estimator: Efficient Amortization for Arbitrary Differential Operators.” In The Thirty-Eighth Annual Conference on Neural Information Processing Systems. https://openreview.net/forum?id=J2wI2rCG2u.