r"""Mesh management utilities for PSI staggered grid data.

MAS and POT3D solve their equations on *staggered* (Yee-type) spherical grids
:math:`(r, \theta, \varphi)`. Different physical quantities are located at
different positions within each grid cell so that the discrete curl and divergence
operators satisfy identities such as :math:`\nabla \cdot \mathbf{B} = 0` exactly at
the discrete level. Each axis of a multi-dimensional output array is independently
classified as either:

- **Main mesh**: quantity sampled at the cell-center nodes.
- **Half mesh**: quantity sampled at the cell boundaries (face or edge midpoints)
  along that axis, displaced by half a grid spacing from the main mesh. Along a
  half-mesh axis the array holds one more point than along the main mesh.
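
A minimal NumPy sketch of the half → main move along one axis, assuming a
hypothetical uniform 1-D grid whose cell faces sit at integer positions: averaging
each adjacent pair of half-mesh samples yields one fewer value, located at the
cell centers.

```python
import numpy as np

# Hypothetical 1-D example: five half-mesh samples at the cell faces
# x = 0, 1, 2, 3, 4 of a linear profile f(x) = x.
x_half = np.arange(5.0)
f_half = x_half.copy()

# Averaging adjacent pairs moves the data to the main mesh, i.e. the four
# cell centers x = 0.5, 1.5, 2.5, 3.5 (one point fewer than the half mesh).
f_main = 0.5 * (f_half[:-1] + f_half[1:])

print(f_main)        # [0.5 1.5 2.5 3.5]
print(f_main.shape)  # (4,)
```

On a uniform grid the two-point mean reproduces a linear profile exactly.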

Mesh codes
----------
A *mesh code* encodes the staggering of every axis in a single compact integer.
Each binary bit indicates, per axis, whether the data lives on the half mesh
(``1``) or the main mesh (``0``). With PSI's Fortran column-major HDF convention
the **most-significant bit maps to the last numpy axis** (the radial :math:`r`
direction), so a three-bit code reads :math:`(r, \theta, \varphi)` MSB → LSB:

.. list-table::
   :header-rows: 1

   * - Code
     - :math:`r`
     - :math:`\theta`
     - :math:`\varphi`
     - Typical quantities
   * - ``0b100``
     - half
     - main
     - main
     - :math:`B_r` (MAS)
   * - ``0b010``
     - main
     - half
     - main
     - :math:`B_\theta` (MAS)
   * - ``0b001``
     - main
     - main
     - half
     - :math:`B_\varphi` (MAS)
   * - ``0b011``
     - main
     - half
     - half
     - :math:`v_r`, :math:`J_r` (MAS); :math:`B_r` (POT3D)
   * - ``0b101``
     - half
     - main
     - half
     - :math:`v_\theta`, :math:`J_\theta` (MAS); :math:`B_\theta` (POT3D)
   * - ``0b110``
     - half
     - half
     - main
     - :math:`v_\varphi`, :math:`J_\varphi` (MAS); :math:`B_\varphi` (POT3D)
   * - ``0b111``
     - half
     - half
     - half
     - scalars: :math:`T`, :math:`\rho`, :math:`p`, …
   * - ``0b000``
     - main
     - main
     - main
     - all-main; result of remeshing every axis
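
To make the bit-to-axis mapping concrete, here is a small stand-alone sketch. The
helper names ``decode_mesh_code`` and ``numpy_axis_meshes`` are hypothetical, not
part of this module; the sketch reads a code MSB first as :math:`(r, \theta, \varphi)`
and then reverses it to obtain the numpy axis order of a Fortran-ordered PSI HDF array.

```python
def decode_mesh_code(code: int, ndim: int = 3) -> tuple[str, ...]:
    # format(code, '03b') writes the bits MSB first, e.g. 0b100 -> '100'.
    bits = format(code, f'0{ndim}b')
    if code < 0 or len(bits) != ndim:
        raise ValueError(f'{code} is not a valid {ndim}-bit mesh code')
    # MSB-first labels follow the physical order (r, theta, phi).
    return tuple('half' if b == '1' else 'main' for b in bits)


def numpy_axis_meshes(code: int, ndim: int = 3) -> tuple[str, ...]:
    # With Fortran ('F') ordering the numpy axes are (phi, theta, r), i.e. the
    # physical order reversed, so the MSB describes the last numpy axis.
    return tuple(reversed(decode_mesh_code(code, ndim)))


print(decode_mesh_code(0b100))   # ('half', 'main', 'main')  for (r, theta, phi)
print(numpy_axis_meshes(0b100))  # ('main', 'main', 'half')  for numpy axes 0, 1, 2
```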

Accepted input forms for a mesh code are described by :data:`MeshCodeType`
(integer, string shorthand ``'main'``/``'half'``, or per-axis sequence).
The memory-order convention is described by :data:`ArrayOrdering`.

Public API
----------
:class:`Mesh`
    Enum with two members, :attr:`~Mesh.MAIN` and :attr:`~Mesh.HALF`, representing
    the two mesh positions.
:data:`MeshCodeType`
    Type alias for the three accepted forms of a mesh stagger specification.
:data:`ArrayOrdering`
    Type alias for the memory-order string (``'F'`` or ``'C'``) accepted by
    :func:`remesh_array`.
:func:`remesh_array`
    Shift an array from one mesh stagger to another by averaging adjacent elements
    along each axis that needs to move from half mesh to main mesh.

Examples
--------
Convert a radial magnetic-field array (half mesh in :math:`r`, the last numpy
axis) to the all-main mesh:

>>> import numpy as np
>>> from psi_io._mesh import remesh_array
>>> br = np.ones((128, 64, 57))  # shape (Nφ, Nθ, Nr); the r axis is on the half mesh
>>> br_main = remesh_array(br, imesh=0b100, omesh='main')
>>> br_main.shape
(128, 64, 56)

Remesh a scalar quantity (all-half, ``0b111``) to all-main:

>>> rho = np.ones((128, 64, 57))
>>> remesh_array(rho, imesh=0b111, omesh='main').shape
(127, 63, 56)
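
Because remeshing only moves axes from half to main, vector components with
different native staggers can be combined once each is brought to the all-main
mesh, as in computing :math:`|\mathbf{B}|`. The sketch below is a stand-alone
approximation of :func:`remesh_array` for that case only (all-main target, Fortran
ordering); the helper ``to_main`` and the grid sizes are hypothetical.

```python
import numpy as np


def to_main(arr: np.ndarray, code: int) -> np.ndarray:
    # Hypothetical helper, not part of psi_io: average adjacent pairs along
    # every half-mesh axis. Fortran ('F') ordering is assumed, so character i
    # of the MSB-first bit string refers to numpy axis ndim - 1 - i.
    bits = format(code, f'0{arr.ndim}b')
    for i, bit in enumerate(bits):
        if bit == '1':
            ax = arr.ndim - 1 - i
            lo = [slice(None)] * arr.ndim
            hi = [slice(None)] * arr.ndim
            lo[ax] = slice(None, -1)
            hi[ax] = slice(1, None)
            arr = 0.5 * (arr[tuple(lo)] + arr[tuple(hi)])
    return arr


# Hypothetical main-mesh sizes in numpy order (phi, theta, r).
nphi, nt, nr = 8, 6, 5
br = np.ones((nphi, nt, nr + 1))   # 0b100: half mesh in r (last axis)
bt = np.ones((nphi, nt + 1, nr))   # 0b010: half mesh in theta
bp = np.ones((nphi + 1, nt, nr))   # 0b001: half mesh in phi (first axis)

bmag = np.sqrt(to_main(br, 0b100)**2
               + to_main(bt, 0b010)**2
               + to_main(bp, 0b001)**2)
print(bmag.shape)  # (8, 6, 5)
```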
"""

from __future__ import annotations

__all__ = [
    "Mesh",
    "remesh_array",
]

import enum
from types import MappingProxyType
from typing import Sequence, Any, Union, Literal, Generator, Optional, Iterable

import numpy as np


_MESH_CODE_REVERSE_MAPPING = MappingProxyType({
    '1': 1, 'h': 1, 'half': 1, 'true': 1,
    '0': 0, 'm': 0, 'main': 0, 'false': 0
})
"""String-token → integer (0/1) lookup used by :class:`Mesh` and the mesh-code parser to validate per-axis tokens."""


MeshCodeType = Union[int, Literal['main', 'half'], Sequence[Any]]
"""Type alias for mesh stagger specifications accepted by :func:`remesh_array`.

A mesh stagger may be expressed in any of three equivalent forms:

- :class:`int`: binary-encoded stagger, one bit per axis (``1`` = half mesh,
  ``0`` = main mesh). With PSI's Fortran HDF convention the most-significant
  bit maps to the last numpy axis (:math:`r`). For example, ``0b100`` places
  the array on the half mesh only along :math:`r` (the last axis).
- ``'main'`` or ``'half'``: string shorthand that applies the same stagger to
  every axis uniformly.
- :class:`~typing.Sequence`: one element per array dimension; each element may
  be ``0``, ``1``, ``'m'``, ``'h'``, ``'main'``, ``'half'``, ``'true'``, or
  ``'false'``. Matching is case-insensitive and applied to ``str(element)``, so
  ``True``/``False`` and :class:`Mesh` members are accepted as well.
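
A stand-alone sketch of how the three forms can be reduced to one 0/1 flag per
axis. The function ``normalize`` and the ``TOKENS`` table are hypothetical
stand-ins that mirror the spellings listed above, not this module's own helpers.

```python
from typing import Any, Sequence, Union

# Hypothetical token table mirroring the spellings listed above.
TOKENS = {'1': 1, 'h': 1, 'half': 1, 'true': 1,
          '0': 0, 'm': 0, 'main': 0, 'false': 0}


def normalize(code: Union[int, str, Sequence[Any]], ndim: int) -> tuple[int, ...]:
    # Reduce any accepted form to one 0/1 flag per axis, MSB first
    # (physical order), before any 'F'/'C' axis mapping is applied.
    if isinstance(code, int):
        code = format(code, f'0{ndim}b')
    elif isinstance(code, str) and code.lower() in ('main', 'half'):
        code = ('0' if code.lower() == 'main' else '1') * ndim
    # Matching on str(element).lower() also accepts True/False; an unknown
    # token raises KeyError in this sketch.
    flags = tuple(TOKENS[str(c).lower()] for c in code)
    if len(flags) != ndim:
        raise ValueError(f'expected {ndim} axes, got {len(flags)}')
    return flags


print(normalize(0b011, 3))                # (0, 1, 1)
print(normalize('half', 3))               # (1, 1, 1)
print(normalize(['main', 'h', True], 3))  # (0, 1, 1)
```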
"""

ArrayOrdering = Literal['F', 'C']
"""Type alias for the memory-order convention accepted by :func:`remesh_array`.

Controls how the bits of a :data:`MeshCodeType` integer map to numpy array axes.
The setting selects an axis convention only; it does not inspect or change the
array's actual memory layout.

``'F'``
    Fortran (column-major) order, the default for PSI data. Because PSI HDF
    files are written by Fortran code, the physical ``(r, θ, φ)`` axis ordering
    is **reversed** in numpy storage: the **last** numpy axis corresponds to
    :math:`r`, the middle to :math:`\\theta`, and the **first** to :math:`\\varphi`.
    The most-significant bit of the mesh code therefore maps to the last numpy axis.
    Use this setting whenever the array was loaded directly from a PSI HDF file.
``'C'``
    C (row-major) order. Use when the array has been transposed to numpy-native
    axis order (first axis = first physical coordinate, e.g. shape ``(Nr, Nθ, Nφ)``),
    so that the most-significant bit maps to the first numpy axis.
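
The sketch below illustrates that the two settings describe the same data in
different axis layouts: remeshing a Fortran-ordered array along its last axis and
transposing gives the same result as transposing first and remeshing along the
first axis. The helper ``avg`` and the random test array are hypothetical.

```python
import numpy as np


def avg(arr: np.ndarray, axis: int) -> np.ndarray:
    # Mean of adjacent pairs along one axis (the half -> main step).
    n = arr.shape[axis]
    return 0.5 * (np.take(arr, np.arange(n - 1), axis=axis)
                  + np.take(arr, np.arange(1, n), axis=axis))


rng = np.random.default_rng(0)
br_f = rng.random((4, 3, 6))   # 'F' layout as read from HDF: (Nphi, Ntheta, Nr+1)
br_c = br_f.T                  # 'C' layout after transposing: (Nr+1, Ntheta, Nphi)

# Code 0b100 (half mesh in r): under 'F' the MSB selects the last axis,
# under 'C' it selects the first axis. Both give the same values.
out_f = avg(br_f, axis=br_f.ndim - 1)
out_c = avg(br_c, axis=0)
print(out_f.shape, out_c.shape)     # (4, 3, 5) (5, 3, 4)
print(np.allclose(out_f, out_c.T))  # True
```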
"""


class Mesh(enum.Enum):
    """Enum identifying the stagger position of one array axis.

    MAS and POT3D solve their equations on Yee-type staggered spherical grids.
    Each axis of a multi-dimensional output array is independently classified as
    :attr:`MAIN` (cell-center position) or :attr:`HALF` (cell-face/edge position,
    displaced by half a grid spacing along that axis).

    The stagger arrangement is physically motivated:

    - Magnetic field components (:math:`B_r`, :math:`B_\\theta`, :math:`B_\\varphi`)
      are face-centered (each component lives on the cell face to which it is
      normal), so that :math:`\\nabla \\cdot \\mathbf{B} = 0` is satisfied
      exactly at the discrete level.
    - Current density components follow from
      :math:`\\mathbf{J} = \\nabla \\times \\mathbf{B}` and are therefore
      edge-centered (half mesh on the two transverse axes).
    - Scalar quantities (temperature, density, pressure) occupy the cell corners,
      which correspond to the half-mesh position on all three axes (``0b111``).

    :class:`Mesh` members appear as elements of the normalized mesh tuple returned
    by :attr:`~psi_io._models.Props.mesh` and accepted by :func:`remesh_array`.

    Attributes
    ----------
    MAIN : int
        Cell-center mesh position; encoded as ``0``.
    HALF : int
        Cell-face/edge mesh position, offset by half a grid spacing; encoded as ``1``.

    Examples
    --------
    >>> from psi_io._mesh import Mesh
    >>> Mesh.MAIN.value
    0
    >>> Mesh.HALF.value
    1
    >>> Mesh('half')
    <Mesh.HALF: 1>
    >>> Mesh('m')
    <Mesh.MAIN: 0>
    >>> str(Mesh.HALF)
    'HALF'
    """

    HALF = 1
    MAIN = 0

    @classmethod
    def _missing_(cls, key: Any) -> Optional[Mesh]:
        """Look up *key* in :data:`_MESH_CODE_REVERSE_MAPPING`.

        Returning ``None`` for an unrecognized key lets :class:`enum.Enum` raise
        its usual :class:`ValueError`.
        """
        code_ = _MESH_CODE_REVERSE_MAPPING.get(str(key).lower())
        return cls(code_) if code_ is not None else None

    def __str__(self) -> str:
        """Return the enum member name (``'MAIN'`` or ``'HALF'``)."""
        return str(self.name)


def _normalize_mesh_code(mesh_code: MeshCodeType, ndim: int) -> tuple[Mesh, ...]:
    """Convert *mesh_code* to a length-*ndim* tuple of :class:`Mesh` members.

    Parameters
    ----------
    mesh_code : MeshCodeType
        Integer, ``'main'``/``'half'`` shorthand, or per-axis sequence.
    ndim : int
        Number of array dimensions.

    Returns
    -------
    out : tuple[Mesh, ...]
        One :class:`Mesh` member per axis, most-significant bit first.

    Raises
    ------
    ValueError
        If *mesh_code* does not describe exactly *ndim* axes or contains an
        unrecognized token.
    """
    if isinstance(mesh_code, int):
        mesh_code = format(mesh_code, f'0{ndim}b')
    elif isinstance(mesh_code, str) and mesh_code.lower() == 'main':
        mesh_code = '0' * ndim
    elif isinstance(mesh_code, str) and mesh_code.lower() == 'half':
        mesh_code = '1' * ndim
    # Check every form, including integers that need more than ndim bits.
    if len(mesh_code) != ndim:
        raise ValueError(f'Mesh code length {len(mesh_code)} does not match data ndim {ndim}.')
    try:
        return tuple(Mesh(_MESH_CODE_REVERSE_MAPPING[str(c).lower()]) for c in mesh_code)
    except KeyError as e:
        raise ValueError(f"Invalid mesh code token '{e.args[0]}'. "
                         f"Valid tokens are: {', '.join(_MESH_CODE_REVERSE_MAPPING.keys())}") from None


def _average_adjacent(arr: np.ndarray,
                      axis: int
                      ) -> np.ndarray:
    """Return the mean of adjacent element pairs along *axis*, reducing that dimension by one."""
    if arr.shape[axis] < 2:
        raise ValueError(f"Cannot remesh axis {axis} with size {arr.shape[axis]}."
                         " Need at least 2 elements to average adjacent pairs.")
    slc_lo = [slice(None)] * arr.ndim
    slc_hi = [slice(None)] * arr.ndim
    slc_lo[axis] = slice(None, -1)
    slc_hi[axis] = slice(1, None)
    return 0.5 * (arr[tuple(slc_lo)] + arr[tuple(slc_hi)])


def _remesh_array(data: np.ndarray,
                  remesh: Iterable[bool] | bool,
                  order: ArrayOrdering = 'F') -> np.ndarray:
    """Apply :func:`_average_adjacent` on each axis where *remesh* is ``True``.

    With ``order='F'`` the flags are given MSB first and are reversed, so the
    first flag applies to the last numpy axis.
    """
    if isinstance(remesh, bool):
        remesh = [remesh] * data.ndim
    # Materialize the flags so that generators can be reversed (and so any
    # validation error they raise surfaces before averaging starts).
    remesh = list(remesh)
    if order == 'F':
        remesh = remesh[::-1]
    for i, shift in enumerate(remesh):
        if shift:
            data = _average_adjacent(data, i)
    return data


def _parse_remesh(imesh: tuple[Mesh, ...],
                  omesh: tuple[Mesh, ...],
                  order: ArrayOrdering = 'F'
                  ) -> Generator[bool, None, None]:
    """Yield per-axis remesh flags (``True`` = half→main) by comparing *imesh* to *omesh*.

    With ``order='F'`` both tuples are reversed first, so the flags are yielded
    in numpy axis order.
    """
    if order == 'F':
        imesh, omesh = reversed(imesh), reversed(omesh)
    for im, om in zip(imesh, omesh):
        if im == om:
            yield False
        elif im == Mesh.HALF and om == Mesh.MAIN:
            yield True
        elif im == Mesh.MAIN and om == Mesh.HALF:
            raise ValueError("Cannot remesh from MAIN mesh to HALF mesh.")
        else:
            raise ValueError(f"Invalid mesh combination: {im} → {om}.")


def remesh_array(data: np.ndarray,
                 imesh: MeshCodeType,
                 omesh: Optional[MeshCodeType] = None,
                 order: ArrayOrdering = 'F') -> np.ndarray:
    """Shift an array from one mesh stagger to another.

    Compares the source mesh *imesh* against the target mesh *omesh* axis by axis
    and applies adjacent-element averaging on every axis that needs to move from
    half mesh to main mesh. Only the half → main direction is supported;
    requesting main → half raises :class:`ValueError`.

    This is commonly needed before performing interpolation or arithmetic between
    quantities on different mesh positions. For example, computing the magnitude of
    the magnetic field requires :math:`B_r`, :math:`B_\\theta`, and
    :math:`B_\\varphi` on the same mesh: each must be remeshed from its native
    stagger (``0b100``, ``0b010``, ``0b001``) to a common target before squaring
    and summing. Because only half → main moves are possible, that common target
    is the all-main mesh ``0b000``.

    If *omesh* is ``None``, the array is returned unchanged.

    Parameters
    ----------
    data : np.ndarray
        Input array on the stagger described by *imesh*.
    imesh : MeshCodeType
        Source mesh stagger in any form accepted by :data:`MeshCodeType`.
    omesh : MeshCodeType or None, optional
        Target mesh stagger. ``None`` (default) is a no-op. Pass ``0`` or
        ``'main'`` to move every half-mesh axis to the main mesh.
    order : ArrayOrdering, optional
        Memory-order convention controlling how mesh-code bits map to numpy
        axes; see :data:`ArrayOrdering`. Defaults to ``'F'`` (Fortran /
        PSI HDF column-major: last numpy axis = :math:`r`).

    Returns
    -------
    out : np.ndarray
        Array on the target mesh. Each remeshed axis is reduced by one element
        via adjacent averaging; axes that already match are left unchanged.

    Raises
    ------
    ValueError
        If any axis in *omesh* requests half mesh where *imesh* is main
        (upsampling is not supported), if either mesh code is malformed (for
        example, a per-axis sequence whose length differs from ``data.ndim``),
        or if an axis to be remeshed has fewer than two elements.

    See Also
    --------
    Mesh : Enum representing the two mesh positions.
    MeshCodeType : Accepted forms for mesh stagger specifications.
    ArrayOrdering : Memory-order convention for bit–axis mapping.

    Examples
    --------
    Convert a radial magnetic-field array (half mesh in :math:`r`, the last
    numpy axis) to the all-main mesh:

    >>> import numpy as np
    >>> from psi_io._mesh import remesh_array
    >>> br = np.ones((128, 64, 57))  # shape (Nφ, Nθ, Nr); the r axis is on the half mesh
    >>> br_main = remesh_array(br, imesh=0b100, omesh='main')
    >>> br_main.shape
    (128, 64, 56)

    Remesh a scalar quantity (all-half, ``0b111``) to all-main:

    >>> rho = np.ones((128, 64, 57))
    >>> remesh_array(rho, imesh=0b111, omesh='main').shape
    (127, 63, 56)

    ``omesh=None`` is a no-op:

    >>> remesh_array(br, imesh=0b100).shape
    (128, 64, 57)

    No-op when source and target stagger already match:

    >>> remesh_array(br, imesh=0b100, omesh=0b100).shape
    (128, 64, 57)
    """
    if omesh is None:
        return data
    imesh_norm = _normalize_mesh_code(imesh, data.ndim)
    omesh_norm = _normalize_mesh_code(omesh, data.ndim)
    remesh_flags = _parse_remesh(imesh_norm, omesh_norm, order)
    # _parse_remesh already yields flags in numpy axis order, so apply them as-is.
    return _remesh_array(data, remesh_flags, order='C')
