Neural Finite Element part II: Einstein summation is all we need
Prologue
It has been almost 4 years since my last blog post, and how crazy the development of AI tools has been since I served as a beta tester for GitHub's GPT-based coding assistant Copilot back in 2021. Back then, Copilot's interaction interface was comically immature at times, but it was still an okay auto-completion tool (worse than VSCode's built-in IntelliSense for each language). Right now, Copilot helps me refactor complex multi-library code, re-design the variable allocation philosophy, and write unit-test case templates.
I drafted this post around the Xmas break of 2022, yet things changed so quickly that much of the writing became dated; I still decided to put up this post as an archive. As of now, part of the "neural" finite element has been implemented in the LA-FEM extension of FEALPy. In the previous post, I ambitiously billed it as "the first part among many on how to implement an FEM natively in PyTorch", yet I guess I have to shelve that attempt indefinitely.
How a vectorized MATLAB code computes integrals
There are so many occasions in finite element computations one needs to compute the following quantity:
\(\int_K f(x)\phi_j(x) \, \mathrm{d} x \text{ for each nodal basis } \phi_j, \text{ for each element }K \text{ in the triangulation}.\)
In the previous post, how to implement MATLAB's accumarray in PyTorch, when $\int_K f \phi_j$ is known to be piecewise constant on each $K$, torch.scatter is used to assemble an aggregated array according to a user-provided index array, just like MATLAB's accumarray function. Here we are looking at how to compute the integral itself. In Long Chen's iFEM (Chen, 2009), it is implemented by looping through the Gaussian quadrature points:
lambda = [1/3, 1/3, 1/3; ...
          0.6, 0.2, 0.2; ...
          0.2, 0.6, 0.2; ...
          0.2, 0.2, 0.6];
weight = [-27/48, 25/48, 25/48, 25/48];
nQuad = size(lambda,1); % number of quadrature points
bt = zeros(NT,3);
phi = lambda;
for p = 1:nQuad
    % quadrature points in the x-y coordinate
    pxy = phi(p,1)*node(elem(:,1),:) ...
        + phi(p,2)*node(elem(:,2),:) ...
        + phi(p,3)*node(elem(:,3),:);
    fp = pde.f(pxy);
    for i = 1:3
        bt(:,i) = bt(:,i) + weight(p)*phi(p,i)*fp;
    end
end
NT denotes $|\mathcal{T}_h|$, the number of elements. lambda stores the values of the barycentric coordinates at the Gaussian quadrature points, and for linear conforming elements the nodal basis phi happens to coincide with lambda. The array weight holds the weights associated with the quadrature points, i.e., how the values at those points are aggregated.
The MATLAB code snippet above computes, for each element $K$ and each local basis index $i$,
\(\mathtt{bt}(K, i) = \sum_{p=1}^{n_{\mathrm{quad}}} w_p\, f(x_p)\, \phi_i(x_p) \approx \frac{1}{|K|}\int_K f(x)\phi_i(x)\,\mathrm{d}x.\)
The resulting array bt is an (NT, 3) array representing one such row per element; multiplying each row by the element area afterwards yields the local load vectors.
The MATLAB implementation is already "vectorized", as there is no loop through 1:NT, i.e., over all the elements. However, it has one key problem: the dimensions are all hard-coded. For example, the number of nodal basis functions is fixed at three, and one has to manually make sure that the number of quadrature points matches the number of weights. Due to these constraints, one has to write down all the nodal basis functions explicitly before implementing anything. Of course, a more modern software design would decouple the construction of the basis from the numerical quadrature used to compute the various terms through object-oriented programming, which brings us to the next section.
Einsum is all we need
It turns out that the same MATLAB snippet above can be implemented with essentially a single einsum call! The original code I wrote over four years ago retains the elegance yet is basis-agnostic:
fK = torch.stack([pde.source(p) for p in quadPts], dim=0)
bt = torch.einsum("q, qn, qp, n->np", weight, fK, phi, area)
The Gaussian quadrature is implemented simply as a tensor contraction following the Einstein summation convention. In the einsum function's string input, q stands for the number of quadrature points, p stands for the number of nodal basis functions in each element, and n stands for the number of elements. None of these dimensions has to be specified explicitly. If I were to re-implement this integral today, I would break the two-liner into a more modular design, but the core coding philosophy would still center around einsum.
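To make the shapes concrete, below is a minimal, self-contained sketch on a two-triangle toy mesh; the mesh, the source term, and the way fK is evaluated are illustrative assumptions that mirror, but are not, the FEALPy/iFEM implementation.
import torch

# toy mesh: the unit square split into two triangles (illustrative data)
node = torch.tensor([[0., 0.], [1., 0.], [1., 1.], [0., 1.]])
elem = torch.tensor([[0, 1, 3], [1, 2, 3]])

# 4-point, degree-3 quadrature in barycentric coordinates (same rule as the MATLAB snippet)
lam = torch.tensor([[1/3, 1/3, 1/3],
                    [0.6, 0.2, 0.2],
                    [0.2, 0.6, 0.2],
                    [0.2, 0.2, 0.6]])
weight = torch.tensor([-27/48, 25/48, 25/48, 25/48])

# element areas from the cross product of two edge vectors: shape (n,)
v0, v1, v2 = node[elem[:, 0]], node[elem[:, 1]], node[elem[:, 2]]
e1, e2 = v1 - v0, v2 - v0
area = 0.5 * torch.abs(e1[:, 0] * e2[:, 1] - e1[:, 1] * e2[:, 0])

# physical quadrature points, shape (q, n, 2); for P1 the nodal basis phi equals lam
quadPts = torch.einsum("qv, nvd->qnd", lam, node[elem])
phi = lam                                  # (q, p), p = 3 local basis functions
fK = quadPts[..., 0]**2 + quadPts[..., 1]  # (q, n), e.g. f(x, y) = x^2 + y

# the whole quadrature loop in one contraction: sum_q w_q f(x_q) phi_p(x_q) |K|
bt = torch.einsum("q, qn, qp, n->np", weight, fK, phi, area)
print(bt.shape)  # torch.Size([2, 3])
Each row of bt holds the three local load-vector entries $\int_K f\phi_i \,\mathrm{d}x$ of one element, exactly the quantity the MATLAB loop assembles (up to the area scaling applied afterwards there).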
Another example: to compute the local stiffness matrix entries $\int_K \nabla\phi_i \cdot \nabla\phi_j \,\mathrm{d}x$ for higher-order elements using Gaussian quadrature, one can use
torch.einsum('q, qnid, qnjd, n->nij', weight, gradphi, gradphi, area)
Here gradphi is an array of shape (q, n, i, d) that stores the d-dimensional gradient vector of the i-th nodal basis function at the q-th quadrature point of the n-th element. The lowest-order implementation can be even simpler, saving the index q, since the gradients of linear basis functions are constant on each element.
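For the lowest-order case, here is a minimal sketch of what that q-free contraction could look like on the same kind of toy mesh (again, the mesh and variable names are illustrative assumptions, not library code):
import torch

# the same toy mesh: unit square split into two counterclockwise triangles
node = torch.tensor([[0., 0.], [1., 0.], [1., 1.], [0., 1.]])
elem = torch.tensor([[0, 1, 3], [1, 2, 3]])

v0, v1, v2 = node[elem[:, 0]], node[elem[:, 1]], node[elem[:, 2]]
e1, e2 = v1 - v0, v2 - v0
area = 0.5 * (e1[:, 0] * e2[:, 1] - e1[:, 1] * e2[:, 0])   # signed areas, positive here

# P1 gradients are constant per element: grad(lambda_i) = rot90(edge opposite vertex i) / (2|K|)
ve = torch.stack([v2 - v1, v0 - v2, v1 - v0], dim=1)       # (n, i, d) opposite edges
gradphi = torch.stack([-ve[..., 1], ve[..., 0]], dim=-1) / (2 * area)[:, None, None]

# local stiffness matrices, with the quadrature index q dropped
At = torch.einsum("nid, njd, n->nij", gradphi, gradphi, area)
print(At.shape)  # torch.Size([2, 3, 3])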
Ellipsis notation
What is even better with einsum is that the contraction rule can be left partially unspecified using the ellipsis "...". For example, I re-implemented the Fourier Neural Operator in a recent paper with my collaborators (Cao et al., 2025), and the frequency-domain multiplication operator is rewritten as follows in a dimension-agnostic fashion, just like the integral above:
torch.einsum("bi...,io...->bo...", x, w)
which represents the contraction rules:
1D: (b, c_i, x), (c_i, c_o, x) -> (b, c_o, x)
2D: (b, c_i, x, y), (c_i, c_o, x, y) -> (b, c_o, x, y)
3D: (b, c_i, x, y, t), (c_i, c_o, x, y, t) -> (b, c_o, x, y, t)
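A quick sketch of the same contraction in code; the function name spectral_mul, the channel counts, and the mode counts are made up for illustration and are not the exact module from the paper:
import torch

def spectral_mul(x, w):
    # x: (batch, c_in, *modes), w: (c_in, c_out, *modes); contract the channel
    # dimension only, with ... matching however many mode axes there are
    return torch.einsum("bi...,io...->bo...", x, w)

x1 = torch.randn(8, 4, 16, dtype=torch.cfloat)      # 1D: (b, c_i, x)
w1 = torch.randn(4, 6, 16, dtype=torch.cfloat)      # (c_i, c_o, x)
x2 = torch.randn(8, 4, 16, 16, dtype=torch.cfloat)  # 2D: (b, c_i, x, y)
w2 = torch.randn(4, 6, 16, 16, dtype=torch.cfloat)  # (c_i, c_o, x, y)
print(spectral_mul(x1, w1).shape)  # torch.Size([8, 6, 16])
print(spectral_mul(x2, w2).shape)  # torch.Size([8, 6, 16, 16])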
In the conventional vectorization philosophy, looping through a small fixed dimension, for example the number of quadrature points, is acceptable. In the era of tensorized computing, however, this actually hampers the scalability of the software. It turns out that einsum is essentially all we need to perform these aggregation-like operations.
LLM’s Offering
As of 2025, of course one wonders how LLMs would write the routine above. Currently in VSCode Copilot, I use GPT 4o-mini (kindly provided via the University of Missouri's OpenAI Enterprise API key) and Claude 4.0 (via GitHub Education). With a sequence of prompts, GPT 4o-mini gives the following runnable test code:
# Extend to integrate f(x,y) * φ_j over all elements in a triangulation
# Output: (n_elements, n_basis_per_element) array

# Affine map from reference to physical triangle
def affine_map_and_jacobian_np(v0, v1, v2, ref_points):
    v0, v1, v2 = map(np.array, (v0, v1, v2))
    T = np.stack((v1 - v0, v2 - v0), axis=1)  # 2x2
    mapped = ref_points @ T.T + v0            # (n, 2)
    J = abs(np.linalg.det(T))                 # scalar
    return mapped, J

def integrate_over_mesh(f, elements, vertices, quad_points, quad_weights):
    """
    Vectorized version: Computes ∫_{K} f(x, y) * φ_j(x, y) dx dy for all triangles in a mesh,
    using a custom quadrature rule, without explicit loops over elements.

    Parameters:
        f: function from R^2 -> R
        elements: (n_elements, 3) array of triangle vertex indices
        vertices: (n_vertices, 2) array of (x, y) coordinates
        quad_points: (n_qp, 2) quadrature points on reference triangle
        quad_weights: (n_qp,) corresponding weights

    Returns:
        (n_elements, 3) array of integrals for each φ_j per triangle
    """
    n_elements = elements.shape[0]
    n_qp = quad_points.shape[0]

    # Get triangle vertices: shape (n_elements, 3, 2)
    tri_vertices = vertices[elements]  # (n_elements, 3, 2)
    v0 = tri_vertices[:, 0, :]         # (n_elements, 2)
    v1 = tri_vertices[:, 1, :]
    v2 = tri_vertices[:, 2, :]

    # Compute transformation matrices and Jacobians
    T = np.stack((v1 - v0, v2 - v0), axis=2)  # (n_elements, 2, 2)
    J = np.abs(np.linalg.det(T))              # (n_elements,)

    # Map all quadrature points: use broadcasting
    mapped_qp = quad_points[None, :, :] @ T.transpose(0, 2, 1) \
        + v0[:, None, :]  # (n_elements, n_qp, 2)

    # Evaluate f at all mapped quadrature points
    f_vals = f(mapped_qp.reshape(-1, 2)).reshape(n_elements, n_qp)  # (n_elements, n_qp)

    # Evaluate basis functions at quadrature points (same for all elements)
    phi_vals = p1_basis_functions_np(quad_points)  # (n_qp, 3)

    # Apply quadrature rule with einsum
    # w_i * f_k_i * phi_i_j → sum over i (quadrature points)
    integrals = J[:, None] * np.einsum("i,ki,ij->kj", quad_weights, f_vals, phi_vals)
    # (n_elements, 3)
    return integrals
# Example mesh: 2 triangles covering unit square (0,0)-(1,0)-(1,1)-(0,1)
vertices = np.array([
    [0.0, 0.0],  # 0
    [1.0, 0.0],  # 1
    [1.0, 1.0],  # 2
    [0.0, 1.0]   # 3
], dtype=np.float32)
elements = np.array([
    [0, 1, 3],  # lower left triangle
    [1, 2, 3]   # upper right triangle
], dtype=np.int32)

# Define function f(x, y)
f_example_np = lambda x: x[:, 0]**2 + x[:, 1]

# Test with a different quadrature rule: 4-point quadrature (3rd-order accurate)
quad_points_4 = np.array([
    [1/3, 1/3],
    [0.6, 0.2],
    [0.2, 0.6],
    [0.2, 0.2]
], dtype=np.float32)
quad_weights_4 = np.array([
    -27/96,
    25/96,
    25/96,
    25/96
], dtype=np.float32)

# Compute integrals for the mesh
mesh_integrals = integrate_over_mesh(f_example_np,
                                     elements,
                                     vertices,
                                     quad_points_4,
                                     quad_weights_4)
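The snippet calls p1_basis_functions_np, which is not reproduced above, and it needs numpy imported; a minimal stand-in for the P1 reference-triangle basis that matches the expected (n_qp, 3) layout would be:
import numpy as np

def p1_basis_functions_np(points):
    """P1 (linear) basis on the reference triangle: 1 - x - y, x, y."""
    x, y = points[:, 0], points[:, 1]
    return np.stack((1.0 - x - y, x, y), axis=1)  # (n_qp, 3)
With these two pieces in place, mesh_integrals comes out with shape (2, 3), one row of $\int_K f\phi_j \,\mathrm{d}x$ values per triangle.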
Initially, both GPT 4o and Claude gave sequential code that loops through the elements. After I explained each variable's dimensions, they came up with similar code. I would say not bad at all, eh? Maybe I shall lose my job sooner rather than later.