Neural Finite Element part I: how to implement MATLAB’s accumarray in PyTorch


MATLAB’s accumarray vs PyTorch’s scatter

The function accumarray is arguably the most used MATLAB function in finite element method (FEM) implementations written from the ground up, as opposed to calling a high-level interface to do the assembly for you.

For example, let $\phi_i$ be our locally supported finite element basis functions (either Lagrange or a vectorial basis like Nédélec), and say we want to produce a vector array b such that b[i] is

\[\int_{\Omega} f \cdot \phi_i = \int_{\operatorname{supp} \phi_i} f\cdot \phi_i = \sum_{K\subset \operatorname{supp} \phi_i}\int_{K} f\cdot \phi_i.\]

Knowing that $f$ is piecewise constant with respect to the current triangulation, what we want to do here is to “accumulate” $f$’s integral on each element $K$ that lies in each basis function $\phi_i$’s support.

Here is how we can use MATLAB’s accumarray, which is ultra fast and implemented in Fortran, to do this while avoiding the notoriously slow for loop in MATLAB. If one additionally needs a linear-complexity search in each iteration of the for loop, the algorithm becomes even slower.

The code below is copied from a work that has had, and continues to have, the most impactful influence on my academic career path: (Chen, 2009). It assumes f is a function handle (similar to a Callable in Python) and computes a linear approximation to the integral above for each $\phi_i$.
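
To spell out the quadrature: with a one-point rule at the centroid $c_K$ of a triangle $K$, and noting that every linear nodal basis function equals $1/3$ at the centroid (its barycentric coordinates are all $1/3$),

\[\int_{K} f\cdot \phi_i \approx |K|\, f(c_K)\, \phi_i(c_K) = \frac{|K|}{3}\, f(c_K),\]

which is exactly the quantity the bt array below stores column by column before the final accumulation.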

% N: number of nodes
% NT: number of elements
% elem: an NT by 3 array storing the element-to-node mapping indices
% node: an N by 2 array storing the coordinates of each node
% area: an NT by 1 array storing the area of each element
% f: a function handle that evaluates f at an array of points

bt = zeros(NT,3);
% linear quadrature in 2d triangle
% note the 1/3 of the Barycentric coords
% coincide with the nodal basis's value later
pxy = (1/3)*node(elem(:,1),:) ...
 + (1/3)*node(elem(:,2),:) ...
 + (1/3)*node(elem(:,3),:);
fp = f(pxy);
for i = 1:3
    bt(:,i) = bt(:,i) + (1/3)*fp;
end
bt = bt.*repmat(area,1,3);
b = accumarray(elem(:),bt(:),[N 1]);

The key line is the last one, which accumulates the integral contributions from each element to produce an array indexed by the nodes: elem stores the indices, and bt stores the values.

Translation using a small example

Here is a small example taken from one of my previous posts:

The following call adds up the entries of vals at the indices given by subs:

A = accumarray(subs,vals,sz)

For example, if you have the indices as follows

subs = [1 2 3 1 3 1 5 5]';
b = accumarray(subs,1);

Then b will be [3 1 2 0 2]', which “accumulates” the indices in the subs array. In this special case, by letting vals be 1, we are just counting the number of times each number appears in subs; for example, 1 appears three times, so the first entry of b is 3, and so on.

If further we do something like

subs = [1 2 3 1 3 1 5 5]';
vals = [0.8 1.0 0 0 -0.2 0.3 1.0 -0.5]';
b = accumarray(subs,vals);

Now b becomes [1.1 1 -0.2 0 0.5]'. For example, index 5 carries the vals 1.0 and -0.5, which add up to 0.5.

How do we do the same in PyTorch without using a list comprehension or a loop? The savior is a function called scatter_, whose manual reads:


Tensor.scatter_(dim, index, src, reduce=None) → Tensor
'''
Writes all values from the tensor src into self at the 
indices specified in the index tensor. For each value in src, 
its output index is specified by its index in src 
for dimension != dim and by the corresponding value in index for dimension = dim.

For a 3-D tensor, self is updated as:
'''
self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2
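
The 3-D rule above looks intimidating, but for a 1-D tensor like ours it collapses to a single line:

self[index[i]] = src[i]  # if dim == 0 (equivalently dim == -1 in 1-D)

and passing reduce='add' turns the assignment into an accumulation, self[index[i]] += src[i], which is precisely what accumarray does.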

I would say the example on the documentation site is not very intuitive to a finite element practitioner. But to replicate the result above, we can simply do the following (note that PyTorch’s indices follow Python’s tradition of starting from zero, as opposed to MATLAB’s anti-human indexing system):

import torch

subs = torch.tensor([0, 1, 2, 0, 2, 0, 4, 4])
vals = torch.tensor([0.8, 1.0, 0, 0, -0.2, 0.3, 1.0, -0.5])
b = torch.zeros((subs.max()+1, ))
b.scatter_(-1,
           index=subs,
           src=vals,
           reduce='add')

The result?

tensor([ 1.10,  1.00, -0.20,  0.00,  0.50])

A perfect match with MATLAB’s accumarray.
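
As a side note, two close relatives are worth knowing; this is a sketch of alternatives, not something the snippet above relies on. The counting special case accumarray(subs, 1) has a direct analogue in torch.bincount, and scatter_add_ performs the same accumulation without the reduce keyword (which, as far as I can tell, newer PyTorch releases are steering users away from):

# the counting special case, accumarray(subs, 1)
counts = torch.bincount(subs)        # tensor([3, 1, 2, 0, 2])

# the same accumulation via scatter_add_, no reduce= keyword
b2 = torch.zeros((subs.max()+1, ))
b2.scatter_add_(-1, subs, vals)      # tensor([ 1.10, 1.00, -0.20, 0.00, 0.50])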

Performance test

If we use the simple regular mesh generator in iFEM, a benchmark is as follows:

[node,elem] = squaremesh([0 1 0 1], 1/512);
N = size(node, 1);
NT = size(elem, 1);

t = zeros(100, 1);

for i = 1:100
    vals = randn(NT, 3);
    tic;
    b = accumarray(elem(:),vals(:),[N, 1]);
    t(i) = toc;
end

mean(t)

The average time to accumulate an array of size $513^2 = 263169$ is 0.0062739 seconds on an Intel 10850K CPU.

Now let us see how PyTorch fares against MATLAB’s Fortran implementation. Using an in-house PyTorch replica of the mesh generator:

import time
import torch

# rectangleMesh: an in-house PyTorch replica of iFEM's squaremesh
node, elem = rectangleMesh(x_range=(0, 1), y_range=(0, 1), h=1/512)
N, NT = node.size(0), elem.size(0)
dtype = torch.float64
t = []

for _ in range(100):
    vals = torch.randn((NT, 3), dtype=dtype)
    start = time.time()
    b = torch.zeros((N, ), dtype=dtype)
    b.scatter_(-1,
               index=elem.view(-1),
               src=vals.view(-1),
               reduce='add')
    t.append(time.time() - start)

print(f"{torch.tensor(t).mean():.7f}")

I am surprised that PyTorch actually pulls off a faster record than MATLAB! The average time is only 0.0045529 seconds on the CPU using double precision. Using single precision on a GPU (RTX 3090), the result is even crazier: a mere 0.0003962 seconds.
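
One caveat on the GPU number: CUDA kernels launch asynchronously, so timing with time.time() alone can under-report. A fairer sketch (assuming node, elem, and vals have already been moved to the GPU) synchronizes around the timed region:

torch.cuda.synchronize()             # flush pending kernels first
start = time.time()
b = torch.zeros((N, ), dtype=torch.float32, device='cuda')
b.scatter_(-1,
           index=elem.view(-1),
           src=vals.view(-1),
           reduce='add')
torch.cuda.synchronize()             # wait for scatter_ to actually finish
elapsed = time.time() - start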

In the next post of the series, I will talk about the difference between scatter_ and gather in PyTorch in the context of finite elements, and the scenarios in which to use each.

References

  1. Chen, L. (2009) iFEM: an integrated finite element methods package in MATLAB, Technical Report, University of California at Irvine. Available at: https://github.com/lyc102/ifem.
