Skip to content

JITable grid and transform creation #632

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Aug 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/regression_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ jobs:
python -m pip install --upgrade pip
pip install -r devtools/dev-requirements.txt
pip install matplotlib==3.5.0
- name: Set Swap Space
uses: pierotofy/set-swap-space@master
with:
swap-size-gb: 10
- name: Test with pytest
run: |
pwd
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/unittest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ jobs:
python -m pip install --upgrade pip
pip install -r devtools/dev-requirements.txt
pip install matplotlib==3.5.0
- name: Set Swap Space
uses: pierotofy/set-swap-space@master
with:
swap-size-gb: 10
- name: Test with pytest
run: |
pwd
Expand Down
124 changes: 78 additions & 46 deletions desc/basis.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Classes for spectral bases and functions for evaluation."""

import functools
from abc import ABC, abstractmethod
from math import factorial

Expand Down Expand Up @@ -773,14 +774,7 @@ def evaluate(
lm = lm[lmidx]
m = m[midx]

# some logic here to use the fastest method, assuming that you're not using
# "unique" within jit/AD since that doesn't work
if unique and (np.max(modes[:, 0]) <= 24):
radial_fun = zernike_radial_poly
else:
radial_fun = zernike_radial

radial = radial_fun(r[:, np.newaxis], lm[:, 0], lm[:, 1], dr=derivatives[0])
radial = zernike_radial(r[:, np.newaxis], lm[:, 0], lm[:, 1], dr=derivatives[0])
poloidal = fourier(t[:, np.newaxis], m, 1, derivatives[1])

if unique:
Expand Down Expand Up @@ -817,7 +811,7 @@ def change_resolution(self, L, M, sym=None):


class ChebyshevDoubleFourierBasis(Basis):
"""3D basis: tensor product of Chebyshev poynomials and two Fourier series.
"""3D basis: tensor product of Chebyshev polynomials and two Fourier series.

Fourier series in both the poloidal and toroidal coordinates.

Expand Down Expand Up @@ -856,7 +850,7 @@ def _get_modes(self, L=0, M=0, N=0):
Parameters
----------
L : int
Maximum radial resoltuion.
Maximum radial resolution.
M : int
Maximum poloidal resolution.
N : int
Expand Down Expand Up @@ -1117,6 +1111,8 @@ def evaluate(
lm = modes[:, :2]

if unique:
# TODO: can avoid this here by using grid.unique_idx etc
# and adding unique_modes attributes to basis
_, ridx, routidx = np.unique(
r, return_index=True, return_inverse=True, axis=0
)
Expand All @@ -1142,14 +1138,7 @@ def evaluate(
m = m[midx]
n = n[nidx]

# some logic here to use the fastest method, assuming that you're not using
# "unique" within jit/AD since that doesn't work
if unique and (np.max(modes[:, 0]) <= 24):
radial_fun = zernike_radial_poly
else:
radial_fun = zernike_radial

radial = radial_fun(r[:, np.newaxis], lm[:, 0], lm[:, 1], dr=derivatives[0])
radial = zernike_radial(r[:, np.newaxis], lm[:, 0], lm[:, 1], dr=derivatives[0])
poloidal = fourier(t[:, np.newaxis], m, dt=derivatives[1])
toroidal = fourier(z[:, np.newaxis], n, NFP=self.NFP, dt=derivatives[2])
if unique:
Expand Down Expand Up @@ -1192,7 +1181,7 @@ def change_resolution(self, L, M, N, NFP=None, sym=None):
self._set_up()


def polyder_vec(p, m):
def polyder_vec(p, m, exact=False):
"""Vectorized version of polyder.

For differentiating multiple polynomials of the same degree
Expand All @@ -1204,13 +1193,23 @@ def polyder_vec(p, m):
each column is a power of x
m : int >=0
order of derivative
exact : bool
Whether to use exact integer arithmetic (not compatible with JAX, but may be
needed for very high degree polynomials)

Returns
-------
der : ndarray, shape(N,M)
polynomial coefficients for derivative in descending order

"""
if exact:
return _polyder_exact(p, m)
else:
return _polyder_jax(p, m)


def _polyder_exact(p, m):
factorial = np.math.factorial
m = np.asarray(m, dtype=int) # order of derivative
p = np.atleast_2d(p)
Expand All @@ -1224,6 +1223,24 @@ def polyder_vec(p, m):
p = np.roll(D * p, m, axis=1)
idx = np.arange(p.shape[1])
p = np.where(idx < m, 0, p)
return p


@jit
def _polyder_jax(p, m):
p = jnp.atleast_2d(p)
m = jnp.asarray(m).astype(int)
order = p.shape[1] - 1
D = jnp.arange(order, -1, -1)

def body(i, Di):
return Di * jnp.maximum(D - i, 1)

D = fori_loop(0, m, body, jnp.ones_like(D))

p = jnp.roll(D * p, m, axis=1)
idx = jnp.arange(p.shape[1])
p = jnp.where(idx < m, 0, p)

return p

Expand Down Expand Up @@ -1253,31 +1270,36 @@ def polyval_vec(p, x, prec=None):
Each row corresponds to a polynomial, each column to a value of x

"""
if prec is not None and prec > 18:
return _polyval_exact(p, x, prec)
else:
return _polyval_jax(p, x)


def _polyval_exact(p, x, prec):
p = np.atleast_2d(p)
x = np.atleast_1d(x).flatten()
# for modest to large arrays, faster to find unique values and
# only evaluate those. Have to cast to float because np.unique
# can't handle object types like python native int
unq_x, xidx = np.unique(x, return_inverse=True)
_, pidx, outidx = np.unique(
p.astype(float), return_index=True, return_inverse=True, axis=0
)
unq_p = p[pidx]
# TODO: possibly multithread this bit
mpmath.mp.dps = prec
y = np.array([np.asarray(mpmath.polyval(list(pi), x)) for pi in p])
return y.astype(float)

if prec is not None and prec > 18:
# TODO: possibly multithread this bit
mpmath.mp.dps = prec
y = np.array([np.asarray(mpmath.polyval(list(pi), unq_x)) for pi in unq_p])
else:
npoly = unq_p.shape[0] # number of polynomials
order = unq_p.shape[1] # order of polynomials
nx = len(unq_x) # number of coordinates
y = np.zeros((npoly, nx))

for k in range(order):
y = y * unq_x + np.atleast_2d(unq_p[:, k]).T
@jit
def _polyval_jax(p, x):
p = jnp.atleast_2d(p)
x = jnp.atleast_1d(x).flatten()
npoly = p.shape[0] # number of polynomials
order = p.shape[1] # order of polynomials
nx = len(x) # number of coordinates
y = jnp.zeros((npoly, nx))

def body(k, y):
return y * x + jnp.atleast_2d(p[:, k]).T

y = fori_loop(0, order, body, y)

return y[outidx][:, xidx].astype(float)
return y.astype(float)


def zernike_radial_coeffs(l, m, exact=True):
Expand Down Expand Up @@ -1337,7 +1359,7 @@ def zernike_radial_coeffs(l, m, exact=True):
return c


def zernike_radial_poly(r, l, m, dr=0):
def zernike_radial_poly(r, l, m, dr=0, exact="auto"):
"""Radial part of zernike polynomials.

Evaluates basis functions using numpy to
Expand All @@ -1356,21 +1378,30 @@ def zernike_radial_poly(r, l, m, dr=0):
azimuthal mode number(s)
dr : int
order of derivative (Default = 0)
exact : {"auto", True, False}
Whether to use exact/extended precision arithmetic. Slower but more accurate.
"auto" will use higher accuracy when needed.

Returns
-------
y : ndarray, shape(N,K)
basis function(s) evaluated at specified points

"""
coeffs = zernike_radial_coeffs(l, m)
lmax = np.max(l)
coeffs = polyder_vec(coeffs, dr)
# this should give accuracy of ~1e-10 in the eval'd polynomials
prec = int(0.4 * lmax + 8.4)
if exact == "auto":
exact = np.max(l) > 54
if exact:
# this should give accuracy of ~1e-10 in the eval'd polynomials
lmax = np.max(l)
prec = int(0.4 * lmax + 8.4)
else:
prec = None
coeffs = zernike_radial_coeffs(l, m, exact=exact)
coeffs = polyder_vec(coeffs, dr, exact=exact)
return polyval_vec(coeffs, r, prec=prec).T


@functools.partial(jit, static_argnums=3)
def zernike_radial(r, l, m, dr=0):
"""Radial part of zernike polynomials.

Expand Down Expand Up @@ -1493,6 +1524,7 @@ def powers(rho, l, dr=0):
return polyval_vec(coeffs, rho).T


@functools.partial(jit, static_argnums=2)
def chebyshev(r, l, dr=0):
"""Shifted Chebyshev polynomial.

Expand All @@ -1511,8 +1543,8 @@ def chebyshev(r, l, dr=0):
basis function(s) evaluated at specified points

"""
r, l = map(jnp.asarray, (r, l))
x = 2 * r - 1 # shift
x, l, dr = map(jnp.asarray, (x, l, dr))
if dr == 0:
return jnp.cos(l * jnp.arccos(x))
else:
Expand Down
19 changes: 16 additions & 3 deletions desc/compute/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def _get_derivs_1_key(key):
return {key: np.unique(val, axis=0).tolist() for key, val in derivs.items()}


def get_profiles(keys, obj, grid=None, has_axis=False, **kwargs):
def get_profiles(keys, obj, grid=None, has_axis=False, jitable=False, **kwargs):
"""Get profiles needed to compute a given quantity on a given grid.

Parameters
Expand All @@ -261,6 +261,8 @@ def get_profiles(keys, obj, grid=None, has_axis=False, **kwargs):
Grid to compute quantity on.
has_axis : bool
Whether the grid to compute on has a node on the magnetic axis.
jitable: bool
Whether to skip certain checks so that this operation works under JIT

Returns
-------
Expand All @@ -286,6 +288,8 @@ def get_profiles(keys, obj, grid=None, has_axis=False, **kwargs):
return profiles
for val in profiles.values():
if val is not None:
if jitable and hasattr(val, "_transform"):
val._transform.method = "jitable"
val.grid = grid
return profiles

Expand Down Expand Up @@ -325,7 +329,7 @@ def get_params(keys, obj, has_axis=False, **kwargs):
return params


def get_transforms(keys, obj, grid, **kwargs):
def get_transforms(keys, obj, grid, jitable=False, **kwargs):
"""Get transforms needed to compute a given quantity on a given grid.

Parameters
Expand All @@ -336,6 +340,8 @@ def get_transforms(keys, obj, grid, **kwargs):
Object to compute quantity for.
grid : Grid
Grid to compute quantity on
jitable: bool
Whether to skip certain checks so that this operation works under JIT

Returns
-------
Expand All @@ -347,13 +353,18 @@ def get_transforms(keys, obj, grid, **kwargs):
from desc.basis import DoubleFourierSeries
from desc.transform import Transform

method = "jitable" if jitable else "auto"
keys = [keys] if isinstance(keys, str) else keys
derivs = get_derivs(keys, obj, has_axis=grid.axis.size)
transforms = {"grid": grid}
for c in derivs.keys():
if hasattr(obj, c + "_basis"):
transforms[c] = Transform(
grid, getattr(obj, c + "_basis"), derivs=derivs[c], build=True
grid,
getattr(obj, c + "_basis"),
derivs=derivs[c],
build=True,
method=method,
)
elif c == "B":
transforms["B"] = Transform(
Expand All @@ -367,6 +378,7 @@ def get_transforms(keys, obj, grid, **kwargs):
derivs=derivs["B"],
build=True,
build_pinv=True,
method=method,
)
elif c == "w":
transforms["w"] = Transform(
Expand All @@ -380,6 +392,7 @@ def get_transforms(keys, obj, grid, **kwargs):
derivs=derivs["w"],
build=True,
build_pinv=True,
method=method,
)
elif c == "rotmat":
transforms["rotmat"] = obj.rotmat
Expand Down
Loading