Source code for psi_io._mesh

r"""Mesh management utilities for PSI staggered grid data.

MAS and POT3D solve their equations on *staggered* (Yee-type) spherical grids
:math:`(r, \theta, \varphi)`.  Different physical quantities are located at
different positions within each grid cell so that the discrete differential
operators (curl, divergence) reproduce their continuous identities exactly; for
example, :math:`\nabla \cdot \mathbf{B} = 0` is satisfied at the discrete level.
Each axis of a multi-dimensional output array is independently classified as
either:

- **Main mesh**: quantity sampled at the cell-center nodes.
- **Half mesh**: quantity sampled at the face or edge midpoint, displaced by half a
  grid spacing along that axis.

Mesh codes
----------
A *mesh code* encodes the staggering of every axis in a single compact integer.
Each bit indicates, for one axis, whether the data lives on the half mesh
(``1``) or the main mesh (``0``).  With PSI's Fortran column-major HDF convention
the **most-significant bit maps to the last numpy axis** (the radial :math:`r`
direction), so a three-bit code reads :math:`(r, \theta, \varphi)` MSB → LSB:

.. list-table::
   :header-rows: 1

   * - Code
     - :math:`r`
     - :math:`\theta`
     - :math:`\varphi`
     - Typical quantities
   * - ``0b100``
     - half
     - main
     - main
     - :math:`B_r` (MAS)
   * - ``0b010``
     - main
     - half
     - main
     - :math:`B_\theta` (MAS)
   * - ``0b001``
     - main
     - main
     - half
     - :math:`B_\varphi` (MAS)
   * - ``0b011``
     - main
     - half
     - half
     - :math:`v_r`, :math:`J_r` (MAS); :math:`B_r` (POT3D)
   * - ``0b101``
     - half
     - main
     - half
     - :math:`v_\theta`, :math:`J_\theta` (MAS); :math:`B_\theta` (POT3D)
   * - ``0b110``
     - half
     - half
     - main
     - :math:`v_\varphi`, :math:`J_\varphi` (MAS); :math:`B_\varphi` (POT3D)
   * - ``0b111``
     - half
     - half
     - half
     - scalars: :math:`T`, :math:`\rho`, :math:`p`, …
   * - ``0b000``
     - main
     - main
     - main
     - all-main; result of remeshing every axis

Accepted input forms for a mesh code are described by :data:`MeshCodeType`
(integer, string shorthand ``'main'``/``'half'``, or per-axis sequence).
The memory-order convention is described by :data:`ArrayOrdering`.

Public API
----------
:class:`Mesh`
    Enum with two members, :attr:`~Mesh.MAIN` and :attr:`~Mesh.HALF`,
    representing the two mesh positions.
:data:`MeshCodeType`
    Type alias for the three accepted forms of a mesh stagger specification.
:data:`ArrayOrdering`
    Type alias for the memory-order string (``'F'`` or ``'C'``) accepted by
    :func:`remesh_array`.
:func:`remesh_array`
    Shift an array from one mesh stagger to another by averaging adjacent elements
    along each axis that needs to move from half mesh to main mesh.

Examples
--------
Convert a radial magnetic-field array (half-mesh in :math:`r`, the last numpy
axis) to the all-main mesh:

>>> import numpy as np
>>> from psi_io._mesh import remesh_array
>>> br = np.ones((128, 64, 57))   # shape (Nφ, Nθ, Nr); Nr is half-mesh size
>>> br_main = remesh_array(br, imesh=0b100, omesh='main')
>>> br_main.shape
(128, 64, 56)

Remesh a scalar quantity (all-half, ``0b111``) to all-main:

>>> rho = np.ones((128, 64, 57))
>>> remesh_array(rho, imesh=0b111, omesh='main').shape
(127, 63, 56)
"""

from __future__ import annotations

__all__ = [
    "Mesh",
    "remesh_array",
]

import enum
from types import MappingProxyType
from typing import Sequence, Any, Union, Literal, Generator, Optional, Iterable

import numpy as np


_MESH_CODE_REVERSE_MAPPING = MappingProxyType({
    '1': 1, 'h': 1, 'half': 1, 'true': 1,
    '0': 0, 'm': 0, 'main': 0, 'false': 0
})
"""String-token → integer (0/1) lookup used to parse every mesh-code form and by :meth:`Mesh._missing_`."""


MeshCodeType = Union[int, Literal['main', 'half'], Sequence[Any]]
"""Type alias for mesh stagger specifications accepted by :func:`remesh_array`.

A mesh stagger may be expressed in any of three equivalent forms:

- :class:`int`: binary-encoded stagger, one bit per axis (``1`` = half mesh,
  ``0`` = main mesh).  With PSI's Fortran HDF convention the most-significant
  bit maps to the last numpy axis (:math:`r`).  For example, ``0b100`` places
  the array on the half mesh only along :math:`r` (the last axis).
- ``'main'`` or ``'half'``: string shorthand that applies the same stagger to
  every axis uniformly.
- :class:`~typing.Sequence`: one element per array dimension; each element may
  be ``0``, ``1``, ``'m'``, ``'h'``, ``'main'``, ``'half'``, ``'true'``, or
  ``'false'`` (matched case-insensitively, so :class:`Mesh` members and
  booleans also work).  Elements are read in the same order as the bits of the
  integer form (first element = most-significant bit), so ``('h', 'm', 'm')``
  is equivalent to ``0b100``.
"""

ArrayOrdering = Literal['F', 'C']
"""Type alias for the memory-order convention accepted by :func:`remesh_array`.

Controls how the bits of a :data:`MeshCodeType` integer (and, equivalently, the
elements of a per-axis sequence) map to numpy array axes.

``'F'``
    Fortran (column-major) order, the default for PSI data.  Because PSI HDF
    files are written by Fortran code, the physical ``(r, θ, φ)`` axis ordering
    is **reversed** in numpy storage: the **last** numpy axis corresponds to
    :math:`r`, the middle to :math:`\\theta`, and the **first** to :math:`\\varphi`.
    The most-significant bit of the mesh code therefore maps to the last numpy axis.
    Use this setting whenever the array was loaded directly from a PSI HDF file.
``'C'``
    C (row-major) order.  Use when the array has been transposed to numpy-native
    axis order (first axis = first physical coordinate, e.g. shape ``(Nr, Nθ, Nφ)``),
    so that the most-significant bit maps to the first numpy axis.
"""


class Mesh(enum.Enum):
    """Enum identifying the stagger position of one array axis.

    MAS and POT3D solve their equations on Yee-type staggered spherical grids.
    Each axis of a multi-dimensional output array is independently classified as
    :attr:`MAIN` (cell-center position) or :attr:`HALF` (cell-face/edge position,
    displaced by half a grid spacing along that axis).

    The stagger arrangement is physically motivated:

    - Magnetic field components (:math:`B_r`, :math:`B_\\theta`, :math:`B_\\varphi`)
      are face-centred (each component lives on the face through which it is the
      outward normal) so that :math:`\\nabla \\cdot \\mathbf{B} = 0` is satisfied
      exactly at the discrete level.
    - Current density components follow from
      :math:`\\mathbf{J} = \\nabla \\times \\mathbf{B}` and are therefore
      edge-centred (half mesh on the two transverse axes).
    - Scalar quantities (temperature, density, pressure) occupy the cell corners,
      which correspond to the half-mesh position on all three axes (``0b111``).

    :class:`Mesh` members appear as elements of the normalized mesh tuple returned
    by :attr:`~psi_io._models.Props.mesh` and accepted by :func:`remesh_array`.

    Attributes
    ----------
    MAIN : Mesh
        Cell-center mesh position; its value is ``0``.
    HALF : Mesh
        Cell-face/edge mesh position, offset by half a grid spacing; its value is ``1``.

    Examples
    --------
    >>> from psi_io._mesh import Mesh
    >>> Mesh.MAIN.value
    0
    >>> Mesh.HALF.value
    1
    >>> Mesh('half')
    <Mesh.HALF: 1>
    >>> Mesh('m')
    <Mesh.MAIN: 0>
    >>> str(Mesh.HALF)
    'HALF'
    """

    HALF = 1
    MAIN = 0

    @classmethod
    def _missing_(cls, key: Any) -> Optional[Mesh]:
        """Look up *key* in :data:`_MESH_CODE_REVERSE_MAPPING`.

        Returns ``None`` for an unrecognized key, which makes :class:`~enum.Enum`
        raise :class:`ValueError`.
        """
        code_ = _MESH_CODE_REVERSE_MAPPING.get(str(key).lower())
        return cls(code_) if code_ is not None else None

    def __str__(self) -> str:
        """Return the enum member name (``'MAIN'`` or ``'HALF'``)."""
        return str(self.name)


def _normalize_mesh_code(mesh_code: MeshCodeType, ndim: int) -> tuple[Mesh, ...]:
    """Convert *mesh_code* to a length-*ndim* tuple of :class:`Mesh` members.

    Parameters
    ----------
    mesh_code : MeshCodeType
        Integer, ``'main'``/``'half'`` shorthand, or per-axis sequence.
    ndim : int
        Number of array dimensions.

    Returns
    -------
    out : tuple[Mesh, ...]
        One :class:`Mesh` member per axis, ordered most-significant bit first.

    Raises
    ------
    ValueError
        If *mesh_code* does not describe exactly *ndim* axes or contains an
        unrecognized token.
    """
    if isinstance(mesh_code, int):
        # Reject codes wider than ndim bits: the surplus high bits could not be
        # paired with any array axis further down the pipeline.
        if not 0 <= mesh_code < (1 << ndim):
            raise ValueError(f'Mesh code {mesh_code:#b} does not fit in data ndim {ndim}.')
        mesh_code = format(mesh_code, f'0{ndim}b')
    elif mesh_code == 'main':
        mesh_code = '0' * ndim
    elif mesh_code == 'half':
        mesh_code = '1' * ndim
    elif len(mesh_code) != ndim:
        raise ValueError(f'Mesh code length {len(mesh_code)} does not match data ndim {ndim}.')
    try:
        return tuple(Mesh(_MESH_CODE_REVERSE_MAPPING[str(c).lower()]) for c in mesh_code)
    except KeyError as e:
        raise ValueError(f"Invalid mesh code character '{e.args[0]}'. "
                         f"Valid characters are: {', '.join(_MESH_CODE_REVERSE_MAPPING.keys())}") from None


def _average_adjacent(arr: np.ndarray,
                      axis: int
                      ) -> np.ndarray:
    """Return the mean of adjacent element pairs along *axis*, reducing that dimension by one."""
    if arr.shape[axis] < 2:
        raise ValueError(f"Cannot remesh axis {axis} with size {arr.shape[axis]}."
                         " Need at least 2 elements to average adjacent pairs.")
    slc_lo = [slice(None)] * arr.ndim
    slc_hi = [slice(None)] * arr.ndim
    slc_lo[axis] = slice(None, -1)
    slc_hi[axis] = slice(1, None)
    return 0.5 * (arr[tuple(slc_lo)] + arr[tuple(slc_hi)])


def _remesh_array(data: np.ndarray,
                  remesh: Iterable[bool] | bool,
                  order: ArrayOrdering = 'F') -> np.ndarray:
    """Apply :func:`_average_adjacent` on each axis where *remesh* is ``True``."""
    if isinstance(remesh, bool):
        remesh = [remesh] * data.ndim
    # Materialize the flags first: reversing needs a sequence (not a generator),
    # and validating every axis up front avoids partial work if a flag raises.
    remesh = list(remesh)
    if order == 'F':
        remesh = remesh[::-1]
    for i, shift in enumerate(remesh):
        if shift:
            data = _average_adjacent(data, i)
    return data


def _parse_remesh(imesh: tuple[Mesh, ...],
                  omesh: tuple[Mesh, ...],
                  order: ArrayOrdering = 'F'
                  ) -> Generator[bool, None, None]:
    """Yield per-axis remesh flags (``True`` = half→main) by comparing *imesh* to *omesh*.

    Flags are yielded in numpy axis order: for ``order='F'`` the mesh tuples
    (most-significant bit first) are reversed before comparison.
    """
    if order == 'F':
        imesh, omesh = reversed(imesh), reversed(omesh)
    for im, om in zip(imesh, omesh):
        if im == om:
            yield False
        elif im == Mesh.HALF and om == Mesh.MAIN:
            yield True
        elif im == Mesh.MAIN and om == Mesh.HALF:
            raise ValueError("Cannot remesh from MAIN mesh to HALF mesh.")
        else:
            raise ValueError(f"Invalid mesh combination: {im} -> {om}.")


def remesh_array(data: np.ndarray,
                 imesh: MeshCodeType,
                 omesh: Optional[MeshCodeType] = None,
                 order: ArrayOrdering = 'F') -> np.ndarray:
    """Shift an array from one mesh stagger to another.

    Compares the source mesh *imesh* against the target mesh *omesh* axis by axis
    and applies adjacent-element averaging on every axis that needs to move from
    half mesh to main mesh.  Only the half → main direction is supported;
    requesting main → half raises :class:`ValueError`.

    This is commonly needed before performing interpolation or arithmetic between
    quantities on different mesh positions.  For example, computing the magnitude
    of the magnetic field requires :math:`B_r`, :math:`B_\\theta`, and
    :math:`B_\\varphi` on the same mesh: each must be remeshed from its native
    stagger (``0b100``, ``0b010``, ``0b001``) to a common target before squaring
    and summing.

    If *omesh* is ``None``, the array is returned unchanged.

    Parameters
    ----------
    data : np.ndarray
        Input array on the stagger described by *imesh*.
    imesh : MeshCodeType
        Source mesh stagger in any form accepted by :data:`MeshCodeType`.
    omesh : MeshCodeType or None, optional
        Target mesh stagger.  ``None`` (default) is a no-op.  Pass ``0`` or
        ``'main'`` to move every half-mesh axis to the main mesh.
    order : ArrayOrdering, optional
        Memory-order convention controlling how mesh-code bits map to numpy
        axes; see :data:`ArrayOrdering`.  Defaults to ``'F'`` (Fortran /
        PSI HDF column-major: last numpy axis = :math:`r`).

    Returns
    -------
    out : np.ndarray
        Array on the target mesh.  Each remeshed axis is reduced by one element
        via adjacent averaging; axes that already match are left unchanged.

    Raises
    ------
    ValueError
        If any axis in *omesh* requests half mesh where *imesh* is main
        (upsampling is not supported), if a mesh code is malformed (for example,
        a sequence whose length differs from ``data.ndim``, or an unrecognized
        token), or if an axis to be remeshed has fewer than two elements.

    See Also
    --------
    Mesh : Enum representing the two mesh positions.
    MeshCodeType : Accepted forms for mesh stagger specifications.
    ArrayOrdering : Memory-order convention for bit–axis mapping.

    Examples
    --------
    Convert a radial magnetic-field array (half-mesh in :math:`r`, the last
    numpy axis) to the all-main mesh:

    >>> import numpy as np
    >>> from psi_io._mesh import remesh_array
    >>> br = np.ones((128, 64, 57))  # shape (Nφ, Nθ, Nr); Nr is half-mesh size
    >>> br_main = remesh_array(br, imesh=0b100, omesh='main')
    >>> br_main.shape
    (128, 64, 56)

    Remesh a scalar quantity (all-half, ``0b111``) to all-main:

    >>> rho = np.ones((128, 64, 57))
    >>> remesh_array(rho, imesh=0b111, omesh='main').shape
    (127, 63, 56)

    ``omesh=None`` is a no-op:

    >>> remesh_array(br, imesh=0b100).shape
    (128, 64, 57)

    No-op when source and target stagger already match:

    >>> remesh_array(br, imesh=0b100, omesh=0b100).shape
    (128, 64, 57)
    """
    if omesh is None:
        return data
    imesh_norm = _normalize_mesh_code(imesh, data.ndim)
    omesh_norm = _normalize_mesh_code(omesh, data.ndim)
    # _parse_remesh already maps mesh-code order onto numpy axes (reversing for
    # 'F'), so its flags are in numpy axis order and are applied as 'C'.
    remesh_flags = _parse_remesh(imesh_norm, omesh_norm, order)
    return _remesh_array(data, remesh_flags, order='C')
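
The bit-to-axis bookkeeping behind `remesh_array` (decode the integer most-significant bit first, reverse for `order='F'`, then flag every axis that moves from half to main) can be sketched in plain Python. This is a minimal illustration, not psi_io's implementation: `decode_mesh_code` and `remesh_flags` are hypothetical names, and only the integer form of a mesh code is handled.

```python
from __future__ import annotations


def decode_mesh_code(code: int, ndim: int) -> list[int]:
    """Hypothetical helper: per-axis stagger bits (1 = half, 0 = main), MSB first."""
    if not 0 <= code < (1 << ndim):
        raise ValueError(f"mesh code {code:#b} does not fit in {ndim} bits")
    return [int(bit) for bit in format(code, f"0{ndim}b")]


def remesh_flags(imesh: int, omesh: int, ndim: int, order: str = "F") -> list[bool]:
    """Hypothetical helper: True for each numpy axis that must move half -> main."""
    src = decode_mesh_code(imesh, ndim)
    dst = decode_mesh_code(omesh, ndim)
    if order == "F":
        # PSI/Fortran layout: the MSB belongs to the LAST numpy axis, so flip
        # the MSB-first bit lists into numpy axis order before pairing.
        src, dst = src[::-1], dst[::-1]
    flags = []
    for axis, (s, d) in enumerate(zip(src, dst)):
        if s == 0 and d == 1:
            raise ValueError(f"numpy axis {axis}: main -> half is not supported")
        flags.append(s == 1 and d == 0)
    return flags


# B_r in MAS (0b100) is half-mesh only along r, numpy axis 2 in the 'F' layout.
print(remesh_flags(0b100, 0b000, ndim=3))             # [False, False, True]
# After transposing to (Nr, Ntheta, Nphi) ('C'), r is numpy axis 0 instead.
print(remesh_flags(0b100, 0b000, ndim=3, order="C"))  # [True, False, False]
```

The same code therefore flags a different numpy axis depending only on the memory-order convention, which is why `order` must match how the array was loaded.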
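
The `remesh_array` docstring notes that computing the field magnitude needs all three magnetic-field components on a common mesh. Below is a minimal numpy sketch of that workflow. It assumes PSI's Fortran layout (numpy axes ordered φ, θ, r). As the one-element shrink of the averaging implies, it also assumes that each half-mesh axis has one more point than the matching main-mesh axis. `average_adjacent` is a hypothetical stand-in for the module's private `_average_adjacent`, written with `np.take` instead of slice tuples. The grid sizes are made up.

```python
import numpy as np


def average_adjacent(arr: np.ndarray, axis: int) -> np.ndarray:
    """Hypothetical stand-in for half -> main averaging; `axis` shrinks by one."""
    n = arr.shape[axis]
    if n < 2:
        raise ValueError(f"axis {axis} needs at least 2 points, got {n}")
    lo = np.take(arr, np.arange(n - 1), axis=axis)  # elements 0 .. n-2
    hi = np.take(arr, np.arange(1, n), axis=axis)   # elements 1 .. n-1
    return 0.5 * (lo + hi)


# Made-up main-mesh sizes; numpy axes follow PSI's Fortran layout (phi, theta, r),
# and each half-mesh axis is assumed to carry one extra point.
n_p, n_t, n_r = 4, 3, 5
br = np.full((n_p, n_t, n_r + 1), 1.0)  # 0b100: half in r     -> numpy axis 2
bt = np.full((n_p, n_t + 1, n_r), 2.0)  # 0b010: half in theta -> numpy axis 1
bp = np.full((n_p + 1, n_t, n_r), 2.0)  # 0b001: half in phi   -> numpy axis 0

# Bring every component to the all-main mesh (0b000), then combine.
br_m = average_adjacent(br, axis=2)
bt_m = average_adjacent(bt, axis=1)
bp_m = average_adjacent(bp, axis=0)
bmag = np.sqrt(br_m**2 + bt_m**2 + bp_m**2)

print(bmag.shape)         # (4, 3, 5)
print(float(bmag.max()))  # 3.0, i.e. sqrt(1 + 4 + 4)

# On a linear ramp the averaged values land exactly on the midpoints.
print(average_adjacent(np.array([0.0, 1.0, 2.0, 3.0]), axis=0))  # [0.5 1.5 2.5]
```

Constant fields keep the arithmetic easy to check by hand. With real data, each of the three averaging calls corresponds to `remesh_array(b, imesh=code, omesh='main')` using that component's native code.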