Source code for psi_io.mesh

r"""Mesh management utilities for PSI staggered grid data.

MAS and POT3D solve their equations on *staggered* (Yee-type) spherical grids
:math:`(r, \theta, \varphi)`.  Different physical quantities are located at
different positions within each grid cell so that identities built from the
discrete differential operators (curl, divergence) hold exactly on the grid.
Each axis of a multi-dimensional output array is independently classified as
either:

- **Main mesh**: quantity sampled at the cell-center nodes.
- **Half mesh**: quantity sampled at the face or edge midpoint, displaced by half a
  grid spacing along that axis.

Mesh codes
----------
A *mesh code* encodes the staggering of every axis in a single compact integer.
Each binary bit indicates, per axis, whether the data lives on the half mesh
(``1``) or the main mesh (``0``).  With PSI's Fortran column-major HDF convention
the **most-significant bit maps to the last numpy axis** (the radial :math:`r`
direction), so a three-bit code reads :math:`(r, \theta, \varphi)` MSB → LSB:

.. list-table::
   :header-rows: 1

   * - Code
     - :math:`r`
     - :math:`\theta`
     - :math:`\varphi`
     - Typical quantities
   * - ``0b100``
     - half
     - main
     - main
     - :math:`B_r` (MAS)
   * - ``0b010``
     - main
     - half
     - main
     - :math:`B_\theta` (MAS)
   * - ``0b001``
     - main
     - main
     - half
     - :math:`B_\varphi` (MAS)
   * - ``0b011``
     - main
     - half
     - half
     - :math:`v_r`, :math:`J_r` (MAS); :math:`B_r` (POT3D)
   * - ``0b101``
     - half
     - main
     - half
     - :math:`v_\theta`, :math:`J_\theta` (MAS); :math:`B_\theta` (POT3D)
   * - ``0b110``
     - half
     - half
     - main
     - :math:`v_\varphi`, :math:`J_\varphi` (MAS); :math:`B_\varphi` (POT3D)
   * - ``0b111``
     - half
     - half
     - half
     - scalars: :math:`T`, :math:`\rho`, :math:`p`, …
   * - ``0b000``
     - main
     - main
     - main
     - all-main; result of remeshing every axis

Accepted input forms for a mesh code are described by :data:`MeshCodeType`
(integer, string shorthand ``'main'``/``'half'``, or per-axis sequence).
The memory-order convention is described by :data:`ArrayOrdering`.

Public API
----------
:class:`Mesh`
    Immutable (frozen) dataclass wrapping a binary stagger :attr:`~Mesh.code` and
    its axis count :attr:`~Mesh.ndim`.  Each bit of the code classifies one axis as
    half mesh (``1``) or main mesh (``0``).
:data:`MeshCodeType`
    Type alias for the three accepted forms of a mesh stagger specification.
:data:`ArrayOrdering`
    Type alias for the memory-order string (``'F'`` or ``'C'``) accepted by
    :func:`remesh_array`.
:func:`remesh_array`
    Shift an array from one mesh stagger to another by averaging adjacent elements
    along each axis that needs to move from half mesh to main mesh.

Examples
--------
Convert a radial magnetic-field array (half-mesh in :math:`r`, the last numpy
axis) to the all-main mesh:

>>> import numpy as np
>>> from psi_io.mesh import remesh_array
>>> br = np.ones((128, 64, 57))   # shape (Nφ, Nθ, Nr); Nr is half-mesh size
>>> br_main = remesh_array(br, imesh=0b100, omesh='main')
>>> br_main.shape
(128, 64, 56)

Remesh a scalar quantity (all-half, ``0b111``) to all-main:

>>> rho = np.ones((128, 64, 57))
>>> remesh_array(rho, imesh=0b111, omesh='main').shape
(127, 63, 56)
"""
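
The bit layout described in the module docstring can be illustrated with a small standalone sketch. It does not use `psi_io`; the helper name `decode_mesh_code` is hypothetical and exists only for this illustration of the MSB-first reading of a mesh code.

```python
def decode_mesh_code(code: int, ndim: int) -> list:
    """Hypothetical helper: per-axis labels, most-significant bit first."""
    if not 0 <= code < (1 << ndim):
        raise ValueError(f"code 0b{code:b} does not fit in {ndim} bits")
    # Bit (ndim - 1 - i) holds the stagger of logical axis i,
    # so logical axis 0 (r for a 3-bit code) is read from the MSB.
    return ["half" if (code >> (ndim - 1 - i)) & 1 else "main"
            for i in range(ndim)]


print(decode_mesh_code(0b100, 3))  # ['half', 'main', 'main']  (MAS B_r)
print(decode_mesh_code(0b011, 3))  # ['main', 'half', 'half']  (MAS v_r)
```

The two printed rows correspond to the ``0b100`` and ``0b011`` entries of the table in the docstring.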

from __future__ import annotations

__all__ = [
    "Mesh",
    "remesh_array",
]

import functools
from collections.abc import Sequence as ABCSequence
from dataclasses import dataclass
from types import MappingProxyType
from typing import Callable, Sequence, Any, Union, Literal, Generator, Optional, Iterable

import numpy as np


_MESH_CODE_REVERSE_MAPPING = MappingProxyType({
    '1': 1, 'h': 1, 'half': 1, 'true': 1,
    '0': 0, 'm': 0, 'main': 0, 'false': 0
})
"""String-token to integer (0/1) lookup used to validate per-axis sequence mesh codes.

Each key is a recognized string representation of a single-axis stagger token.
The value ``1`` means half mesh; the value ``0`` means main mesh.

Keys
----
``'1'``, ``'h'``, ``'half'``, ``'true'``
    Map to ``1`` (half mesh).
``'0'``, ``'m'``, ``'main'``, ``'false'``
    Map to ``0`` (main mesh).

Notes
-----
The mapping itself holds only lowercase keys; callers convert each token with
``str(token).lower()`` before lookup, which makes matching case-insensitive.
Looking up an unrecognized token with ``.get()`` returns ``None``, which
callers treat as an error condition.

Examples
--------
>>> _MESH_CODE_REVERSE_MAPPING['half']
1
>>> _MESH_CODE_REVERSE_MAPPING['m']
0
>>> _MESH_CODE_REVERSE_MAPPING.get('x') is None
True
"""


MeshCodeType = Union[bool, int, Literal['main', 'half'], Sequence[Any]]
r"""Type alias for mesh stagger specifications accepted by :func:`remesh_array`.

A mesh stagger may be expressed in any of three forms:

- :class:`int`: binary-encoded stagger, one bit per axis (``1`` = half mesh,
  ``0`` = main mesh).  With PSI's Fortran HDF convention the most-significant
  bit maps to the last numpy axis (:math:`r`).  For example, ``0b100`` places
  the array on the half mesh only along :math:`r` (the last axis).
- ``'main'`` or ``'half'``: string shorthand that applies the same stagger to
  every axis uniformly.
- :class:`~typing.Sequence`: one element per array dimension; each element may
  be ``0``, ``1``, ``'m'``, ``'h'``, ``'main'``, ``'half'``, ``'true'``, or
  ``'false'``.

Examples
--------
Both forms below encode the same 3-D stagger (half only along :math:`r`); the
``'main'``/``'half'`` shorthand cannot express a mixed stagger:

>>> from psi_io.mesh import Mesh
>>> str(Mesh.parse(0b100, ndim=3))
'HALF, MAIN, MAIN'
>>> str(Mesh.parse([True, False, False], ndim=3))
'HALF, MAIN, MAIN'
"""
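
As a rough illustration of how a per-axis sequence collapses into an integer code, the following standalone sketch folds tokens MSB-first. The names `TOKENS` and `tokens_to_code` are hypothetical stand-ins for this example and are not part of `psi_io`.

```python
# Hypothetical token table for this sketch (1 = half mesh, 0 = main mesh).
TOKENS = {
    "1": 1, "h": 1, "half": 1, "true": 1,
    "0": 0, "m": 0, "main": 0, "false": 0,
}


def tokens_to_code(tokens):
    """Fold per-axis tokens into ``(code, ndim)``, first token = MSB."""
    code = 0
    for token in tokens:
        bit = TOKENS.get(str(token).lower())  # case-insensitive lookup
        if bit is None:
            raise ValueError(f"Invalid mesh code token {token!r}.")
        code = (code << 1) | bit  # shift earlier axes toward the MSB
    return code, len(tokens)


print(tokens_to_code("MMH"))                 # (1, 3)
print(tokens_to_code([True, False, True]))   # (5, 3)
print(tokens_to_code(["half", 0, "m"]))      # (4, 3)
```

Converting every token through ``str(...).lower()`` is what lets booleans, integers, and strings share one lookup table.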

MeshLike = Union['Mesh', MeshCodeType]
r"""Type alias for any accepted mesh specification, including an existing :class:`Mesh`.

This is a superset of :data:`MeshCodeType`; it additionally accepts a :class:`Mesh`
instance, which is passed through unchanged.

See Also
--------
MeshCodeType : Accepted forms that do not include :class:`Mesh` itself.
Mesh.parse : Normalizes any :data:`MeshLike` value into a :class:`Mesh`.

Examples
--------
>>> from psi_io.mesh import Mesh
>>> m = Mesh.parse(0b101, ndim=3)
>>> Mesh.parse(m) is m   # Mesh passthrough
True
>>> str(Mesh.parse('half', ndim=2))
'HALF, HALF'
"""

ArrayOrdering = Literal['F', 'C']
r"""Type alias for the memory-order convention accepted by :func:`remesh_array`.

Controls how the bits of a :data:`MeshCodeType` integer map to numpy array axes.

``'F'``
    Fortran (column-major) order, the default for PSI data.  Because PSI HDF
    files are written by Fortran code, the physical ``(r, θ, φ)`` axis ordering
    is **reversed** in numpy storage: the **last** numpy axis corresponds to
    :math:`r`, the middle to :math:`\theta`, and the **first** to :math:`\varphi`.
    The most-significant bit of the mesh code therefore maps to the last numpy axis.
    Use this setting whenever the array was loaded directly from a PSI HDF file.
``'C'``
    C (row-major) order.  Use when the array has been transposed to numpy-native
    axis order (first axis = first physical coordinate, e.g. shape ``(Nr, Nθ, Nφ)``),
    so that the most-significant bit maps to the first numpy axis.

Examples
--------
>>> import numpy as np
>>> from psi_io.mesh import remesh_array
>>> arr = np.ones((57, 64, 128))   # C-order: shape (Nr, Nθ, Nφ); Nr is half-mesh
>>> out = remesh_array(arr, imesh=0b100, omesh='main', order='C')
>>> out.shape
(56, 64, 128)
"""
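
The difference between the two orderings can be sketched without NumPy. The helper below is hypothetical (it is not a `psi_io` function); it only shows which numpy axis index each logical flag would land on under the convention described above.

```python
def flags_to_numpy_axes(flags, order="F"):
    """Hypothetical helper: numpy axis indices flagged for averaging.

    *flags* are booleans in logical (r, theta, phi) order, MSB first.
    """
    flags = list(flags)
    if order == "F":
        flags.reverse()  # logical axis 0 (r) lands on the last numpy axis
    elif order != "C":
        raise ValueError(f"order must be 'F' or 'C', got {order!r}")
    return [axis for axis, flag in enumerate(flags) if flag]


# A stagger that is half only along r (code 0b100):
print(flags_to_numpy_axes([True, False, False], order="F"))  # [2]
print(flags_to_numpy_axes([True, False, False], order="C"))  # [0]
```

With ``'F'`` the radial flag selects numpy axis 2, matching arrays of shape ``(Nφ, Nθ, Nr)``; with ``'C'`` it selects axis 0.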


def _coerce_mesh_target(method: Callable) -> Callable:
    """Coerce the *target* argument of a :class:`Mesh` method to a :class:`Mesh`.

    Decorator that normalizes the first argument after ``self`` (*target*) of
    a :class:`Mesh` method.  If *target* is already a :class:`Mesh` its
    ``ndim`` is verified to match ``self.ndim``.  If *target* is ``None`` it
    is replaced by ``self`` (a no-op target).  Otherwise :meth:`Mesh.parse` is
    called with ``ndim=self.ndim`` to produce the normalized :class:`Mesh`.

    Parameters
    ----------
    method : Callable
        The :class:`Mesh` method whose *target* argument should be coerced.
        The wrapper preserves the method's ``__name__``, ``__doc__``, and
        other attributes via :func:`functools.wraps`.

    Returns
    -------
    wrapper : Callable
        Wrapped version of *method* with automatic coercion of the *target*
        argument.

    Raises
    ------
    ValueError
        If *target* is a :class:`Mesh` whose ``ndim`` differs from
        ``self.ndim``.

    Examples
    --------
    The decorator is applied internally; here is the observable behavior it
    enables on :meth:`Mesh.remesh`:

    >>> from psi_io.mesh import Mesh
    >>> m = Mesh.parse(0b111, ndim=3)
    >>> m.remesh('main')         # string target is coerced automatically
    (True, True, True)
    >>> m.remesh(None)           # None is replaced by self → no-op
    (False, False, False)
    """
    @functools.wraps(method)
    def wrapper(self: 'Mesh', target: Any, *args: Any, **kwargs: Any) -> Any:
        if isinstance(target, Mesh):
            if self.ndim != target.ndim:
                raise ValueError(f"ndim mismatch: {self.ndim} vs {target.ndim}.")
        elif target is None:
            target = self
        else:
            target = Mesh.parse(target, ndim=self.ndim)
        return method(self, target, *args, **kwargs)
    return wrapper


@functools.total_ordering
@dataclass(frozen=True)
class Mesh:
    r"""Compact, immutable representation of a multi-axis mesh stagger code.

    A :class:`Mesh` is a frozen dataclass that bundles two integers: a binary
    stagger *code* and the axis count *ndim*.  Together they describe, for an
    *ndim*-dimensional array, which axes are sampled on the **half mesh** (face or
    edge midpoints, bit ``1``) versus the **main mesh** (cell-center nodes, bit
    ``0``).  Because it replaces the legacy two-member ``Mesh`` enum, a single
    instance now encodes the stagger of *every* axis at once rather than one axis
    per enum member.

    The bit-to-axis mapping follows PSI's Fortran column-major HDF convention: the
    **most-significant bit maps to the first logical axis** (physical :math:`r`),
    descending to the least-significant bit at the last axis (:math:`\varphi`).
    For a 3-bit code the axes therefore read :math:`(r, \theta, \varphi)` from
    MSB to LSB.  For example, ``code=0b100, ndim=3`` means :math:`r` is on the
    half mesh while :math:`\theta` and :math:`\varphi` are on the main mesh
    (the MAS :math:`B_r` stagger).

    Once built, a :class:`Mesh` supports rich operations: it iterates as a
    sequence of per-axis booleans (:meth:`__iter__`), indexes to a single-axis
    :class:`Mesh` (:meth:`__getitem__`), reports its remesh requirements against a
    target via :meth:`remesh` / the ``>>`` operator, reverses axis order with
    :meth:`reverse`, and acts as a plain integer code in index contexts
    (:meth:`__index__`).

    Prefer constructing via :meth:`parse` rather than calling the constructor
    directly: :meth:`parse` accepts every form described by :data:`MeshCodeType`
    (integers, the shorthand strings ``'main'``/``'half'``, per-axis token
    strings such as ``'MMH'``, and per-axis boolean/int sequences) and infers or
    validates *ndim* for you.  The direct constructor accepts only an explicit
    integer *code* and *ndim*.

    Parameters
    ----------
    code : int
        Binary-encoded stagger integer.  Bit ``i`` (counting from the LSB) sets
        the stagger of logical axis ``ndim - 1 - i``: ``1`` for half mesh, ``0``
        for main mesh.  Must fit within *ndim* bits (i.e. ``0 <= code < 2**ndim``).
    ndim : int
        Number of array dimensions (and therefore significant bits) represented
        by this code.  Fixes the width of the stagger field and the length of the
        iterated/indexed sequence.

    Raises
    ------
    ValueError
        If *code* has any bit set at or above position *ndim* (i.e.
        ``code >= 2**ndim``), since that bit could not correspond to a real axis.

    See Also
    --------
    Mesh.parse : Build a :class:`Mesh` from any :data:`MeshCodeType` form.
    Mesh.remesh : Per-axis flags for moving from this stagger to a target.
    remesh_array : Apply an actual half-to-main mesh shift to a NumPy array.

    Examples
    --------
    >>> from psi_io.mesh import Mesh
    >>> str(Mesh.parse(0b100, ndim=3))
    'HALF, MAIN, MAIN'
    >>> str(Mesh.parse('MMH', ndim=3))
    'MAIN, MAIN, HALF'
    >>> str(Mesh.parse([True, False, True], ndim=3))
    'HALF, MAIN, HALF'

    The direct constructor takes an explicit code and dimension count:

    >>> Mesh(code=0b100, ndim=3)
    Mesh(HALF, MAIN, MAIN)
    """

    code: int
    ndim: int

    def __post_init__(self) -> None:
        """Validate that *code* fits within *ndim* bits.

        Runs automatically after the dataclass constructor assigns :attr:`code`
        and :attr:`ndim`.  Guards the class invariant that every set bit of
        *code* corresponds to a real axis, so an over-wide code is rejected at
        construction time rather than producing a silently truncated stagger.

        Raises
        ------
        ValueError
            If *code* has any bits set at position *ndim* or higher (i.e.
            ``code >= 2**ndim``).

        Examples
        --------
        >>> from psi_io.mesh import Mesh
        >>> Mesh(code=0b100, ndim=3).code      # valid: 3 bits, MSB is bit 2
        4
        >>> import pytest
        >>> with pytest.raises(ValueError):
        ...     Mesh(code=0b1000, ndim=3)      # 4-bit value in 3-bit field
        """
        mask = (1 << self.ndim) - 1
        if self.code & ~mask:
            raise ValueError(f"Code 0b{self.code:b} exceeds {self.ndim} bits.")

    def __repr__(self) -> str:
        """Return an unambiguous string representation of this :class:`Mesh`.

        The representation uses the human-readable stagger labels from
        :meth:`__str__` rather than the raw integer code.

        Returns
        -------
        out : str
            String of the form ``'Mesh(<stagger>)'``, e.g.
            ``'Mesh(HALF, MAIN, MAIN)'``.

        Examples
        --------
        >>> from psi_io.mesh import Mesh
        >>> repr(Mesh.parse(0b101, ndim=3))
        'Mesh(HALF, MAIN, HALF)'
        """
        return f"Mesh({self})"

    def __len__(self) -> int:
        """Return the number of axes encoded by this :class:`Mesh`.

        Returns
        -------
        out : int
            Equal to :attr:`ndim`.

        Examples
        --------
        >>> from psi_io.mesh import Mesh
        >>> len(Mesh.parse(0b101, ndim=3))
        3
        """
        return self.ndim

    def __bool__(self) -> bool:
        """Return ``True`` if any axis is on the half mesh.

        Returns
        -------
        out : bool
            ``True`` when :attr:`code` is non-zero; ``False`` when all axes
            are on the main mesh (code ``0``).

        Examples
        --------
        >>> from psi_io.mesh import Mesh
        >>> bool(Mesh.parse('half', ndim=3))
        True
        >>> bool(Mesh.parse('main', ndim=3))
        False
        """
        return self.code != 0

    def __index__(self) -> int:
        """Return the raw integer mesh code for use in index contexts.

        Allows a :class:`Mesh` to be used wherever a plain integer code is
        expected (e.g., passed directly to :func:`remesh_array` as *imesh*).

        Returns
        -------
        out : int
            Equal to :attr:`code`.

        Examples
        --------
        >>> from psi_io.mesh import Mesh
        >>> m = Mesh.parse(0b110, ndim=3)
        >>> int(m)
        6
        >>> hex(m)
        '0x6'
        """
        return self.code

    def __lt__(self, other: Mesh | int) -> bool:
        """Compare this :class:`Mesh` to *other* by code value.

        Parameters
        ----------
        other : Mesh | int
            The object to compare against.  If a :class:`Mesh`, its
            ``ndim`` must equal ``self.ndim``.  If an :class:`int`, the
            comparison is against the raw code.

        Returns
        -------
        out : bool
            ``True`` when ``self.code < other`` (or ``other.code``).

        Raises
        ------
        ValueError
            If *other* is a :class:`Mesh` with a different ``ndim``.

        Examples
        --------
        >>> from psi_io.mesh import Mesh
        >>> Mesh.parse(0b001, ndim=3) < Mesh.parse(0b100, ndim=3)
        True
        >>> Mesh.parse(0b100, ndim=3) < 3
        False
        """
        if isinstance(other, Mesh):
            if self.ndim != other.ndim:
                raise ValueError(f"Cannot compare Mesh instances with different ndim: {self.ndim} vs {other.ndim}.")
            return self.code < other.code
        if isinstance(other, int):
            return self.code < other
        return NotImplemented

    def __getitem__(self, item: int | slice) -> Mesh:
        """Return a sub-:class:`Mesh` for the axis or slice specified by *item*.

        Parameters
        ----------
        item : int | slice
            Axis index or slice.  Negative integer indices are supported.
            Slicing follows the same semantics as Python sequences.

        Returns
        -------
        out : Mesh
            A new :class:`Mesh` containing only the selected axis or axes.
            The ``ndim`` of the result equals ``1`` for integer indexing and
            ``len(range(*item.indices(self.ndim)))`` for slice indexing.

        Raises
        ------
        IndexError
            If an integer *item* is out of range for this :class:`Mesh`.
        TypeError
            If *item* is neither an :class:`int` nor a :class:`slice`.

        Examples
        --------
        >>> from psi_io.mesh import Mesh
        >>> m = Mesh.parse(0b101, ndim=3)   # HALF, MAIN, HALF
        >>> str(m[0])
        'HALF'
        >>> str(m[1])
        'MAIN'
        >>> str(m[:2])
        'HALF, MAIN'
        """
        if isinstance(item, int):
            if item < 0:
                item += self.ndim
            if not 0 <= item < self.ndim:
                raise IndexError(f"Index {item} out of range for Mesh with ndim={self.ndim}.")
            return Mesh((self.code >> (self.ndim - 1 - item)) & 1, 1)
        if isinstance(item, slice):
            indices = range(*item.indices(self.ndim))
            code = 0
            for i in indices:
                code = (code << 1) | ((self.code >> (self.ndim - 1 - i)) & 1)
            return Mesh(code, len(indices))
        raise TypeError(f"Indices must be integers or slices, not {type(item).__name__!r}.")

    def __iter__(self) -> Generator[bool, None, None]:
        """Yield per-axis half-mesh flags MSB-first (logical axis order).

        Yields ``True`` for each axis on the half mesh, ``False`` for main.
        Iterating the result of :meth:`remesh` gives the remesh flags directly.

        Yields
        ------
        flag : bool
            ``True`` if the current axis is on the half mesh; ``False`` if on
            the main mesh.  Axes are yielded in logical order (MSB first).

        Examples
        --------
        >>> from psi_io.mesh import Mesh
        >>> list(Mesh.parse(0b101, ndim=3))
        [True, False, True]
        >>> list(Mesh.parse('main', ndim=3))
        [False, False, False]
        """
        for i in range(self.ndim):
            yield bool((self.code >> (self.ndim - 1 - i)) & 1)

    def __reversed__(self) -> Generator[bool, None, None]:
        """Yield per-axis half-mesh flags in reverse (LSB-first) order.

        Equivalent to iterating over :meth:`reverse`.

        Yields
        ------
        flag : bool
            ``True`` if the current axis is on the half mesh; ``False`` if on
            the main mesh.  Axes are yielded in reverse logical order (LSB
            first).

        Examples
        --------
        >>> from psi_io.mesh import Mesh
        >>> list(reversed(Mesh.parse(0b101, ndim=3)))
        [True, False, True]
        >>> list(reversed(Mesh.parse(0b100, ndim=3)))
        [False, False, True]
        """
        return iter(self.reverse())

    def __str__(self) -> str:
        """Return a human-readable per-axis stagger string.

        Returns
        -------
        out : str
            Comma-separated ``'HALF'``/``'MAIN'`` labels, one per axis in
            logical order (MSB first).  For example ``'HALF, MAIN, MAIN'``.

        Examples
        --------
        >>> from psi_io.mesh import Mesh
        >>> str(Mesh.parse(0b100, ndim=3))
        'HALF, MAIN, MAIN'
        >>> str(Mesh.parse('main', ndim=2))
        'MAIN, MAIN'
        """
        return ', '.join(
            'HALF' if (self.code >> (self.ndim - 1 - i)) & 1 else 'MAIN'
            for i in range(self.ndim)
        )

    def __rshift__(self, other: Optional[MeshLike]) -> tuple[bool, ...]:
        """Return remesh flags for the transition ``self`` → ``other``.

        Syntactic sugar for :meth:`remesh`.  Each ``True`` in the result
        indicates an axis that must be averaged (half → main) when
        transforming data from this mesh to *other*.

        Parameters
        ----------
        other : MeshLike | None
            Target mesh stagger in any form accepted by :data:`MeshLike`.
            ``None`` is a no-op (returns all ``False``).

        Returns
        -------
        out : tuple[bool, ...]
            Per-axis remesh flags; ``True`` means the axis needs averaging.

        Examples
        --------
        >>> from psi_io.mesh import Mesh
        >>> src = Mesh.parse(0b111, ndim=3)
        >>> src >> 'main'
        (True, True, True)
        >>> src >> None
        (False, False, False)
        """
        return self.remesh(other)

    def reverse(self) -> Mesh:
        """Return a new :class:`Mesh` with the axis order reversed (MSB to LSB).

        Useful for converting between C-order and F-order axis conventions,
        where the first physical axis becomes the last numpy axis or vice
        versa.

        Returns
        -------
        out : Mesh
            A new :class:`Mesh` with the same ``ndim`` but with the bit order
            of :attr:`code` reversed so that what was the MSB becomes the LSB.

        Examples
        --------
        >>> from psi_io.mesh import Mesh
        >>> m = Mesh.parse(0b100, ndim=3)   # HALF, MAIN, MAIN
        >>> str(m.reverse())
        'MAIN, MAIN, HALF'
        >>> str(Mesh.parse(0b110, ndim=3).reverse())
        'MAIN, HALF, HALF'
        """
        result, code = 0, self.code
        for _ in range(self.ndim):
            result = (result << 1) | (code & 1)
            code >>= 1
        return Mesh(result, self.ndim)

    @functools.singledispatchmethod
    @classmethod
    def parse(cls, mesh_code: MeshLike, ndim: Optional[int] = None):
        """Normalize *mesh_code* into a :class:`Mesh`.

        Parameters
        ----------
        mesh_code : MeshLike
            Stagger specification in any accepted form: an integer, the
            ``'main'``/``'half'`` shorthands, a per-axis sequence of tokens
            (``0``/``1``, ``'m'``/``'h'``, ``True``/``False``, etc.), or an
            existing :class:`Mesh` (returned as-is).
        ndim : int, optional
            Number of dimensions.  Required when *mesh_code* is an integer or
            the ``'main'``/``'half'`` shorthand; inferred from sequence length
            otherwise.  If both are provided they must agree.

        Returns
        -------
        out : Mesh
            Normalized :class:`Mesh` with the specified stagger and
            dimensionality.

        Raises
        ------
        ValueError
            If *ndim* is required but not supplied, if *ndim* conflicts with
            the sequence length, or if any token is unrecognized.
        TypeError
            If *mesh_code* is of an unsupported type.

        Examples
        --------
        >>> from psi_io.mesh import Mesh
        >>> str(Mesh.parse(0b100, ndim=3))
        'HALF, MAIN, MAIN'
        >>> str(Mesh.parse('half', ndim=2))
        'HALF, HALF'
        >>> str(Mesh.parse([1, 0, 1], ndim=3))
        'HALF, MAIN, HALF'
        >>> m = Mesh.parse(0b010, ndim=3)
        >>> Mesh.parse(m) is m
        True
        """
        # Fallback for types without a registered handler.  *ndim* is a named
        # parameter so that ``Mesh.parse(m, ndim=3)`` also works when *m* is
        # already a Mesh; it is not used on this path.
        if isinstance(mesh_code, Mesh):
            return mesh_code
        raise TypeError(f"Cannot convert {type(mesh_code).__name__!r} to Mesh.")

    @parse.register(bool)
    @classmethod
    def _(cls, mesh_code, ndim: int):
        # True -> every axis on the half mesh; False -> every axis on main.
        return cls(0 if not mesh_code else (1 << ndim) - 1, ndim)

    @parse.register(int)
    @classmethod
    def _(cls, mesh_code, ndim: int):
        return cls(mesh_code, ndim)

    @parse.register(str)
    @classmethod
    def _(cls, mesh_code, ndim: Optional[int] = None):
        # Compare against the lowercased token so that 'MAIN' and 'Main'
        # resolve to the main mesh rather than falling through to half.
        token = mesh_code.lower()
        if token in ('main', 'half'):
            if ndim is None:
                raise ValueError("ndim is required for 'main'/'half' shorthands.")
            return cls(0 if token == 'main' else (1 << ndim) - 1, ndim)
        # Any other string is read as one token per character, e.g. 'MMH'.
        seq = list(mesh_code)
        return cls.parse(seq, ndim)

    @parse.register(ABCSequence)
    @classmethod
    def _(cls, mesh_code, ndim: Optional[int] = None):
        if ndim is not None and ndim != len(mesh_code):
            raise ValueError(f"ndim={ndim} conflicts with sequence length {len(mesh_code)}.")
        code = 0
        for c in mesh_code:
            bit = _MESH_CODE_REVERSE_MAPPING.get(str(c).lower())
            if bit is None:
                raise ValueError(f"Invalid mesh code token {c!r}.")
            code = (code << 1) | bit
        return cls(code, ndim or len(mesh_code))

    @_coerce_mesh_target
    def remesh(self, target: Optional[MeshLike], strict: bool = True) -> tuple[bool, ...]:
        """Return per-axis flags indicating which axes require averaging.

        An axis needs remeshing when the source is on the half mesh (``1``) and
        the target is on the main mesh (``0``).  By default, requesting a
        main-to-half transition (upsampling) raises a :exc:`ValueError`.

        Parameters
        ----------
        target : MeshLike | None
            Desired output stagger; coerced to :class:`Mesh` via
            :func:`_coerce_mesh_target` if necessary.  ``None`` is treated
            as ``self`` (no-op: returns all ``False``).
        strict : bool, optional
            If ``True`` (default), raise :exc:`ValueError` when any axis in
            *target* is on the half mesh but the corresponding axis in
            ``self`` is already on the main mesh (main → half is not
            supported).  Set to ``False`` to silently ignore such axes.

        Returns
        -------
        out : tuple[bool, ...]
            Tuple of per-axis boolean flags in logical axis order (MSB first).
            ``True`` at position *i* means axis *i* must be averaged (half →
            main); ``False`` means no averaging is needed on that axis.

        Raises
        ------
        ValueError
            If *strict* is ``True`` and *target* requests a half-mesh axis
            where ``self`` is already on the main mesh.

        Examples
        --------
        >>> from psi_io.mesh import Mesh
        >>> src = Mesh.parse(0b111, ndim=3)   # all-half
        >>> src.remesh('main')
        (True, True, True)
        >>> src.remesh(None)                  # no-op: target == self
        (False, False, False)
        >>> src.remesh(0b101)                 # only theta needs averaging
        (False, True, False)
        """
        mask = (1 << self.ndim) - 1
        if strict and (~self.code & mask) & target.code:
            raise ValueError(f"Cannot remesh from {self} to {target}: main → half transitions are not supported.")
        return tuple(Mesh(self.code & (~target.code & mask), self.ndim))
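
The bit arithmetic behind the remesh flags can be reproduced on plain integers. The function `remesh_flags` below is a hypothetical standalone sketch of that logic, written for illustration; it is not the library's method and takes raw integer codes instead of `Mesh` objects.

```python
def remesh_flags(src: int, dst: int, ndim: int, strict: bool = True) -> tuple:
    """Hypothetical sketch: per-axis 'needs averaging' flags, MSB first."""
    mask = (1 << ndim) - 1
    # Axes that are main in the source but half in the target (upsampling).
    upsample = (~src & mask) & dst
    if strict and upsample:
        raise ValueError("main -> half transitions are not supported")
    # Axes that are half in the source and main in the target.
    needs_avg = src & ~dst & mask
    return tuple(bool((needs_avg >> (ndim - 1 - i)) & 1) for i in range(ndim))


print(remesh_flags(0b111, 0b000, 3))  # (True, True, True)
print(remesh_flags(0b111, 0b101, 3))  # (False, True, False)
print(remesh_flags(0b100, 0b100, 3))  # (False, False, False)
```

Masking with ``(1 << ndim) - 1`` after the bitwise NOT keeps the result inside the *ndim*-bit field, since Python integers are unbounded and ``~x`` is negative.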


def _average_adjacent(arr: np.ndarray, axis: int) -> np.ndarray:
    """Return the mean of adjacent element pairs along *axis*, reducing its size by one.

    Computes ``0.5 * (arr[..., :-1, ...] + arr[..., 1:, ...])`` where the
    ellipsis notation represents all other axes.  The result has the same
    shape as *arr* on every axis except *axis*, which shrinks by one element.

    Parameters
    ----------
    arr : np.ndarray
        Input array of any dtype and number of dimensions.
    axis : int
        Axis along which to average adjacent pairs.  Must satisfy
        ``0 <= axis < arr.ndim``.

    Returns
    -------
    out : np.ndarray
        Array with the same shape as *arr* except ``out.shape[axis] ==
        arr.shape[axis] - 1``.  The dtype is determined by NumPy's
        float promotion rules (typically ``float64`` for integer input).

    Raises
    ------
    ValueError
        If ``arr.shape[axis] < 2``, because at least two elements are needed
        to form one averaged pair.

    Examples
    --------
    >>> import numpy as np
    >>> from psi_io.mesh import _average_adjacent
    >>> arr = np.array([1.0, 3.0, 5.0, 7.0])
    >>> _average_adjacent(arr, axis=0)
    array([2., 4., 6.])
    >>> arr2d = np.array([[0.0, 2.0], [4.0, 6.0], [8.0, 10.0]])
    >>> _average_adjacent(arr2d, axis=0).shape
    (2, 2)
    """
    if arr.shape[axis] < 2:
        raise ValueError(f"Cannot remesh axis {axis} with size {arr.shape[axis]}."
                         f" Need at least 2 elements to average adjacent pairs.")
    # Build two index tuples that differ only along *axis*:
    # all-but-last elements and all-but-first elements.
    slc_lo = [slice(None)] * arr.ndim
    slc_hi = [slice(None)] * arr.ndim
    slc_lo[axis] = slice(None, -1)
    slc_hi[axis] = slice(1, None)
    return 0.5 * (arr[tuple(slc_lo)] + arr[tuple(slc_hi)])


def _remesh_array(data: np.ndarray,
                  remesh: Iterable[bool] | bool,
                  order: ArrayOrdering = 'F') -> np.ndarray:
    """Apply adjacent-element averaging on each axis where *remesh* is ``True``.

    Iterates over axes in numpy index order (0, 1, 2, …) and calls
    :func:`_average_adjacent` on each axis flagged for remeshing.  When
    *order* is ``'F'`` the *remesh* iterable is reversed before the loop so
    that the logical MSB-first ordering of :class:`Mesh` maps correctly to
    numpy's last-axis-first Fortran convention.

    Parameters
    ----------
    data : np.ndarray
        Input array on the source mesh stagger.
    remesh : Iterable[bool] | bool
        Per-axis remesh flags in logical axis order (MSB first), as returned
        by :meth:`Mesh.remesh`.  A single :class:`bool` is broadcast to all
        axes.
    order : ArrayOrdering, optional
        Memory-order convention; ``'F'`` (default) reverses *remesh* before
        iterating so that MSB maps to the last numpy axis.  ``'C'`` uses the
        flags as-is so that MSB maps to the first numpy axis.

    Returns
    -------
    out : np.ndarray
        Array with averaged values along every flagged axis.  Shape is reduced
        by one on each remeshed axis.

    Examples
    --------
    >>> import numpy as np
    >>> from psi_io.mesh import _remesh_array
    >>> arr = np.ones((4, 4, 4))
    >>> out = _remesh_array(arr, remesh=[False, False, True], order='F')
    >>> out.shape   # F-order: LSB flag (True) maps to numpy axis 0 (phi)
    (3, 4, 4)
    >>> out2 = _remesh_array(arr, remesh=True)
    >>> out2.shape  # all axes averaged
    (3, 3, 3)
    """
    if isinstance(remesh, bool):
        remesh = [remesh] * data.ndim
    if order == 'F':
        # Materialize first: reversed() needs a sequence, but *remesh* may be
        # any iterable (for example a generator).
        remesh = reversed(tuple(remesh))
    for i, shift in enumerate(remesh):
        if shift:
            data = _average_adjacent(data, i)
    return data


def remesh_array(data: np.ndarray,
                 imesh: MeshCodeType,
                 omesh: Optional[MeshCodeType] = None,
                 order: ArrayOrdering = 'F') -> np.ndarray:
    r"""Shift an array from one mesh stagger to another.

    Compares the source mesh *imesh* against the target mesh *omesh* axis by axis
    and applies adjacent-element averaging on every axis that needs to move from
    half mesh to main mesh.  Only the half → main direction is supported;
    requesting main → half raises :exc:`ValueError`.

    This is commonly needed before performing interpolation or arithmetic between
    quantities on different mesh positions.  For example, computing the magnitude of the
    magnetic field requires :math:`B_r`, :math:`B_\theta`, and
    :math:`B_\varphi` on the same mesh: each must be remeshed from its native
    stagger (``0b100``, ``0b010``, ``0b001``) to a common target before squaring
    and summing.

    If *omesh* is ``None``, the array is returned unchanged.

    Parameters
    ----------
    data : np.ndarray
        Input array on the stagger described by *imesh*.
    imesh : MeshCodeType
        Source mesh stagger in any form accepted by :data:`MeshCodeType`.
    omesh : MeshCodeType | None, optional
        Target mesh stagger.  ``None`` (default) is a no-op.  Pass ``0`` or
        ``'main'`` to move every half-mesh axis to the main mesh.
    order : ArrayOrdering, optional
        Memory-order convention controlling how mesh-code bits map to numpy
        axes; see :data:`ArrayOrdering`.  Defaults to ``'F'`` (Fortran /
        PSI HDF column-major: last numpy axis = :math:`r`).

    Returns
    -------
    out : np.ndarray
        Array on the target mesh.  Each remeshed axis is reduced by one element
        via adjacent averaging; axes that already match are left unchanged.

    Raises
    ------
    ValueError
        If any axis in *omesh* requests half mesh where *imesh* is already main
        (upsampling is not supported).

    See Also
    --------
    Mesh : Compact representation of a multi-axis mesh stagger code.
    MeshCodeType : Accepted forms for mesh stagger specifications.
    ArrayOrdering : Memory-order convention for bit-axis mapping.

    Examples
    --------
    Convert a radial magnetic-field array (half-mesh in :math:`r`, the last
    numpy axis) to the all-main mesh:

    >>> import numpy as np
    >>> from psi_io.mesh import remesh_array
    >>> br = np.ones((128, 64, 57))   # shape (Nφ, Nθ, Nr); Nr is half-mesh size
    >>> br_main = remesh_array(br, imesh=0b100, omesh='main')
    >>> br_main.shape
    (128, 64, 56)

    Remesh a scalar quantity (all-half, ``0b111``) to all-main:

    >>> rho = np.ones((128, 64, 57))
    >>> remesh_array(rho, imesh=0b111, omesh='main').shape
    (127, 63, 56)

    ``omesh=None`` is a no-op:

    >>> remesh_array(br, imesh=0b100).shape
    (128, 64, 57)

    No-op when source and target stagger already match:

    >>> remesh_array(br, imesh=0b100, omesh=0b100).shape
    (128, 64, 57)
    """
    if omesh is None:
        return data
    remesh = Mesh.parse(imesh, data.ndim).remesh(omesh)
    return _remesh_array(data, remesh, order=order)
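
To make the half-to-main shift concrete, here is a minimal pure-Python sketch of adjacent averaging along a single axis. The helper `average_adjacent_1d` is hypothetical and uses plain lists instead of NumPy arrays, so it only illustrates the arithmetic and the one-element shrink per remeshed axis.

```python
def average_adjacent_1d(values):
    """Hypothetical sketch: mean of each adjacent pair in a 1-D sequence."""
    values = list(values)
    if len(values) < 2:
        raise ValueError("Need at least 2 elements to average adjacent pairs.")
    # Pair element k with element k + 1; the result has len(values) - 1 items.
    return [0.5 * (lo + hi) for lo, hi in zip(values, values[1:])]


half_mesh_samples = [1.0, 3.0, 5.0, 7.0]
main_mesh_samples = average_adjacent_1d(half_mesh_samples)
print(main_mesh_samples)        # [2.0, 4.0, 6.0]
print(len(main_mesh_samples))   # 3
```

This mirrors why a half-mesh radial axis of 57 points becomes 56 points after remeshing in the docstring examples.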