1r"""Mesh management utilities for PSI staggered grid data.
2
3MAS and POT3D solve their equations on *staggered* (Yee-type) spherical grids
4:math:`(r, \theta, \varphi)`. Different physical quantities are located at
5different positions within each grid cell so that discrete differential operators
6(curl, divergence) are exactly satisfied at the discrete level. Each axis of a
7multi-dimensional output array is independently classified as either:
8
9- **Main mesh** — quantity sampled at the cell-center nodes.
10- **Half mesh** — quantity sampled at the face or edge midpoint, displaced by half a
11 grid spacing along that axis.
12
13Mesh codes
14----------
15A *mesh code* encodes the staggering of every axis in a single compact integer.
16Each binary bit indicates, per axis, whether the data lives on the half mesh
17(``1``) or the main mesh (``0``). With PSI's Fortran column-major HDF convention
18the **most-significant bit maps to the last numpy axis** (the radial :math:`r`
19direction), so a three-bit code reads :math:`(r, \theta, \varphi)` MSB → LSB:
20
21.. list-table::
22 :header-rows: 1
23
24 * - Code
25 - :math:`r`
26 - :math:`\theta`
27 - :math:`\varphi`
28 - Typical quantities
29 * - ``0b100``
30 - half
31 - main
32 - main
33 - :math:`B_r` (MAS)
34 * - ``0b010``
35 - main
36 - half
37 - main
38 - :math:`B_\theta` (MAS)
39 * - ``0b001``
40 - main
41 - main
42 - half
43 - :math:`B_\varphi` (MAS)
44 * - ``0b011``
45 - main
46 - half
47 - half
48 - :math:`v_r`, :math:`J_r` (MAS); :math:`B_r` (POT3D)
49 * - ``0b101``
50 - half
51 - main
52 - half
53 - :math:`v_\theta`, :math:`J_\theta` (MAS); :math:`B_\theta` (POT3D)
54 * - ``0b110``
55 - half
56 - half
57 - main
58 - :math:`v_\varphi`, :math:`J_\varphi` (MAS); :math:`B_\varphi` (POT3D)
59 * - ``0b111``
60 - half
61 - half
62 - half
63 - scalars: :math:`T`, :math:`\rho`, :math:`p`, …
64 * - ``0b000``
65 - main
66 - main
67 - main
68 - all-main; result of remeshing every axis
69
70Accepted input forms for a mesh code are described by :data:`MeshCodeType`
71(integer, string shorthand ``'main'``/``'half'``, or per-axis sequence).
72The memory-order convention is described by :data:`ArrayOrdering`.
73
74Public API
75----------
76:class:`Mesh`
77 Immutable (frozen) dataclass wrapping a binary stagger :attr:`~Mesh.code` and
78 its axis count :attr:`~Mesh.ndim`. Each bit of the code classifies one axis as
79 half mesh (``1``) or main mesh (``0``).
80:data:`MeshCodeType`
81 Type alias for the three accepted forms of a mesh stagger specification.
82:data:`ArrayOrdering`
83 Type alias for the memory-order string (``'F'`` or ``'C'``) accepted by
84 :func:`remesh_array`.
85:func:`remesh_array`
86 Shift an array from one mesh stagger to another by averaging adjacent elements
87 along each axis that needs to move from half mesh to main mesh.
88
89Examples
90--------
91Convert a radial magnetic-field array (half-mesh in :math:`r`, the last numpy
92axis) to the all-main mesh:
93
94>>> import numpy as np
95>>> from psi_io.mesh import remesh_array
96>>> br = np.ones((128, 64, 57)) # shape (Nφ, Nθ, Nr); Nr is half-mesh size
97>>> br_main = remesh_array(br, imesh=0b100, omesh='main')
98>>> br_main.shape
99(128, 64, 56)
100
101Remesh a scalar quantity (all-half, ``0b111``) to all-main:
102
103>>> rho = np.ones((128, 64, 57))
104>>> remesh_array(rho, imesh=0b111, omesh='main').shape
105(127, 63, 56)
106"""
107
108from __future__ import annotations
109
110__all__ = [
111 "Mesh",
112 "remesh_array",
113]
114
115import functools
116from collections.abc import Sequence as ABCSequence
117from dataclasses import dataclass
118from types import MappingProxyType
119from typing import Callable, Sequence, Any, Union, Literal, Generator, Optional, Iterable
120
121import numpy as np
122
123
124_MESH_CODE_REVERSE_MAPPING = MappingProxyType({
125 '1': 1, 'h': 1, 'half': 1, 'true': 1,
126 '0': 0, 'm': 0, 'main': 0, 'false': 0
127})
128"""String-token to integer (0/1) lookup used to validate per-axis sequence mesh codes.
129
130Each key is a recognized string representation of a single-axis stagger token.
131The value ``1`` means half mesh; the value ``0`` means main mesh.
132
133Keys
134----
135``'1'``, ``'h'``, ``'half'``, ``'true'``
136 Map to ``1`` (half mesh).
137``'0'``, ``'m'``, ``'main'``, ``'false'``
138 Map to ``0`` (main mesh).
139
140Notes
141-----
142Token matching is performed case-insensitively via :func:`str.lower` before
143lookup. An unrecognized token returns ``None``, which callers treat as an
144error condition.
145
146Examples
147--------
148>>> _MESH_CODE_REVERSE_MAPPING['half']
1491
150>>> _MESH_CODE_REVERSE_MAPPING['m']
1510
152>>> _MESH_CODE_REVERSE_MAPPING.get('x') is None
153True
154"""
155
156
157MeshCodeType = Union[bool, int, Literal['main', 'half'], Sequence[Any]]
158r"""Type alias for mesh stagger specifications accepted by :func:`remesh_array`.
159
160A mesh stagger may be expressed in any of three equivalent forms:
161
162- :class:`int` — binary-encoded stagger, one bit per axis (``1`` = half mesh,
163 ``0`` = main mesh). With PSI's Fortran HDF convention the most-significant
164 bit maps to the last numpy axis (:math:`r`). For example, ``0b100`` places
165 the array on the half mesh only along :math:`r` (the last axis).
166- ``'main'`` or ``'half'`` — string shorthand that applies the same stagger to
167 every axis uniformly.
168- :class:`~typing.Sequence` — one element per array dimension; each element may
169 be ``0``, ``1``, ``'m'``, ``'h'``, ``'main'``, ``'half'``, ``'true'``, or
170 ``'false'``.
171
172Examples
173--------
174All three forms below encode the same 3-D stagger (half only along :math:`r`):
175
176>>> from psi_io.mesh import Mesh
177>>> str(Mesh.parse(0b100, ndim=3))
178'HALF, MAIN, MAIN'
179>>> str(Mesh.parse([True, False, False], ndim=3))
180'HALF, MAIN, MAIN'
181"""
182
183MeshLike = Union['Mesh', MeshCodeType]
184r"""Type alias for any accepted mesh specification, including an existing :class:`Mesh`.
185
186This is a superset of :data:`MeshCodeType`; it additionally accepts a :class:`Mesh`
187instance, which is passed through unchanged.
188
189See Also
190--------
191MeshCodeType : Accepted forms that do not include :class:`Mesh` itself.
192Mesh.parse : Normalizes any :data:`MeshLike` value into a :class:`Mesh`.
193
194Examples
195--------
196>>> from psi_io.mesh import Mesh
197>>> m = Mesh.parse(0b101, ndim=3)
198>>> Mesh.parse(m, ndim=3) is m # Mesh passthrough
199True
200>>> str(Mesh.parse('half', ndim=2))
201'HALF, HALF'
202"""
203
204ArrayOrdering = Literal['F', 'C']
205r"""Type alias for the memory-order convention accepted by :func:`remesh_array`.
206
207Controls how the bits of a :data:`MeshCodeType` integer map to numpy array axes.
208
209``'F'``
210 Fortran (column-major) order — the default for PSI data. Because PSI HDF
211 files are written by Fortran code, the physical ``(r, θ, φ)`` axis ordering
212 is **reversed** in numpy storage: the **last** numpy axis corresponds to
213 :math:`r`, the middle to :math:`\theta`, and the **first** to :math:`\varphi`.
214 The most-significant bit of the mesh code therefore maps to the last numpy axis.
215 Use this setting whenever the array was loaded directly from a PSI HDF file.
216``'C'``
217 C (row-major) order. Use when the array has been transposed to numpy-native
218 axis order (first axis = first physical coordinate, e.g. shape ``(Nr, Nθ, Nφ)``),
219 so that the most-significant bit maps to the first numpy axis.
220
221Examples
222--------
223>>> import numpy as np
224>>> from psi_io.mesh import remesh_array
225>>> arr = np.ones((57, 64, 128)) # C-order: shape (Nr, Nθ, Nφ); Nr is half-mesh
226>>> out = remesh_array(arr, imesh=0b100, omesh='main', order='C')
227>>> out.shape
228(56, 64, 128)
229"""
230
231
232def _coerce_mesh_target(method: Callable) -> Callable:
233 """Coerce the *target* argument of a :class:`Mesh` method to a :class:`Mesh`.
234
235 Decorator that normalizes the second positional argument (*target*) of a
236 bound :class:`Mesh` method. If *target* is already a :class:`Mesh` its
237 ``ndim`` is verified to match ``self.ndim``. If *target* is ``None`` it
238 is replaced by ``self`` (a no-op target). Otherwise :meth:`Mesh.parse` is
239 called with ``ndim=self.ndim`` to produce the normalized :class:`Mesh`.
240
241 Parameters
242 ----------
243 method : Callable
244 The bound :class:`Mesh` method whose second positional argument should
245 be coerced. The wrapper preserves the method's ``__name__``,
246 ``__doc__``, and other attributes via :func:`functools.wraps`.
247
248 Returns
249 -------
250 wrapper : Callable
251 Wrapped version of *method* with automatic coercion of the *target*
252 argument.
253
254 Raises
255 ------
256 ValueError
257 If *target* is a :class:`Mesh` whose ``ndim`` differs from
258 ``self.ndim``.
259
260 Examples
261 --------
262 The decorator is applied internally; here is the observable behavior it
263 enables on :meth:`Mesh.remesh`:
264
265 >>> from psi_io.mesh import Mesh
266 >>> m = Mesh.parse(0b111, ndim=3)
267 >>> m.remesh('main') # string target is coerced automatically
268 (True, True, True)
269 >>> m.remesh(None) # None is replaced by self → no-op
270 (False, False, False)
271 """
272 @functools.wraps(method)
273 def wrapper(self: 'Mesh', target: Any, *args: Any, **kwargs: Any) -> Any:
274 if isinstance(target, Mesh):
275 if self.ndim != target.ndim:
276 raise ValueError(f"ndim mismatch: {self.ndim} vs {target.ndim}.")
277 elif target is None:
278 target = self
279 else:
280 target = Mesh.parse(target, ndim=self.ndim)
281 return method(self, target, *args, **kwargs)
282 return wrapper
283
284
[docs]
285@functools.total_ordering
286@dataclass(frozen=True)
287class Mesh:
288 r"""Compact, immutable representation of a multi-axis mesh stagger code.
289
290 A :class:`Mesh` is a frozen dataclass that bundles two integers: a binary
291 stagger *code* and the axis count *ndim*. Together they describe, for a
292 *ndim*-dimensional array, which axes are sampled on the **half mesh** (face or
293 edge midpoints, bit ``1``) versus the **main mesh** (cell-center nodes, bit
294 ``0``). Because it replaces the legacy two-member ``Mesh`` enum, a single
295 instance now encodes the stagger of *every* axis at once rather than one axis
296 per enum member.
297
298 The bit-to-axis mapping follows PSI's Fortran column-major HDF convention: the
299 **most-significant bit maps to the first logical axis** (physical :math:`r`),
300 descending to the least-significant bit at the last axis (:math:`\varphi`).
301 For a 3-bit code the axes therefore read :math:`(r, \theta, \varphi)` from
302 MSB to LSB. For example, ``code=0b100, ndim=3`` means :math:`r` is on the
303 half mesh while :math:`\theta` and :math:`\varphi` are on the main mesh
304 (the MAS :math:`B_r` stagger).
305
306 Once built, a :class:`Mesh` supports rich operations: it iterates as a
307 sequence of per-axis booleans (:meth:`__iter__`), indexes to a single-axis
308 :class:`Mesh` (:meth:`__getitem__`), reports its remesh requirements against a
309 target via :meth:`remesh` / the ``>>`` operator, reverses axis order with
310 :meth:`reverse`, and acts as a plain integer code in index contexts
311 (:meth:`__index__`).
312
313 Prefer constructing via :meth:`parse` rather than calling the constructor
314 directly: :meth:`parse` accepts every form described by :data:`MeshCodeType`
315 (integers, the shorthand strings ``'main'``/``'half'``, per-axis token
316 strings such as ``'MMH'``, and per-axis boolean/int sequences) and infers or
317 validates *ndim* for you. The direct constructor accepts only an explicit
318 integer *code* and *ndim*.
319
320 Parameters
321 ----------
322 code : int
323 Binary-encoded stagger integer. Bit ``i`` (counting from the LSB) sets
324 the stagger of logical axis ``ndim - 1 - i``: ``1`` for half mesh, ``0``
325 for main mesh. Must fit within *ndim* bits (i.e. ``0 <= code < 2**ndim``).
326 ndim : int
327 Number of array dimensions (and therefore significant bits) represented
328 by this code. Fixes the width of the stagger field and the length of the
329 iterated/indexed sequence.
330
331 Raises
332 ------
333 ValueError
334 If *code* has any bit set at or above position *ndim* (i.e.
335 ``code >= 2**ndim``), since that bit could not correspond to a real axis.
336
337 See Also
338 --------
339 Mesh.parse : Build a :class:`Mesh` from any :data:`MeshCodeType` form.
340 Mesh.remesh : Per-axis flags for moving from this stagger to a target.
341 remesh_array : Apply an actual half-to-main mesh shift to a NumPy array.
342
343 Examples
344 --------
345 >>> from psi_io.mesh import Mesh
346 >>> str(Mesh.parse(0b100, ndim=3))
347 'HALF, MAIN, MAIN'
348 >>> str(Mesh.parse('MMH', ndim=3))
349 'MAIN, MAIN, HALF'
350 >>> str(Mesh.parse([True, False, True], ndim=3))
351 'HALF, MAIN, HALF'
352
353 The direct constructor takes an explicit code and dimension count:
354
355 >>> Mesh(code=0b100, ndim=3)
356 Mesh(HALF, MAIN, MAIN)
357 """
358
359 code: int
360 ndim: int
361
362 def __post_init__(self) -> None:
363 """Validate that *code* fits within *ndim* bits.
364
365 Runs automatically after the dataclass constructor assigns :attr:`code`
366 and :attr:`ndim`. Guards the class invariant that every set bit of
367 *code* corresponds to a real axis, so an over-wide code is rejected at
368 construction time rather than producing a silently truncated stagger.
369
370 Raises
371 ------
372 ValueError
373 If *code* has any bits set at position *ndim* or higher (i.e.
374 ``code >= 2**ndim``).
375
376 Examples
377 --------
378 >>> from psi_io.mesh import Mesh
379 >>> Mesh(code=0b100, ndim=3).code # valid: 3 bits, MSB is bit 2
380 4
381 >>> import pytest
382 >>> with pytest.raises(ValueError):
383 ... Mesh(code=0b1000, ndim=3) # 4-bit value in 3-bit field
384 """
385 mask = (1 << self.ndim) - 1
386 if self.code & ~mask:
387 raise ValueError(f"Code 0b{self.code:b} exceeds {self.ndim} bits.")
388
389 def __repr__(self) -> str:
390 """Return an unambiguous string representation of this :class:`Mesh`.
391
392 The representation uses the human-readable stagger labels from
393 :meth:`__str__` rather than the raw integer code.
394
395 Returns
396 -------
397 out : str
398 String of the form ``'Mesh(<stagger>)'``, e.g.
399 ``'Mesh(HALF, MAIN, MAIN)'``.
400
401 Examples
402 --------
403 >>> from psi_io.mesh import Mesh
404 >>> repr(Mesh.parse(0b101, ndim=3))
405 'Mesh(HALF, MAIN, HALF)'
406 """
407 return f"Mesh({self})"
408
409 def __len__(self) -> int:
410 """Return the number of axes encoded by this :class:`Mesh`.
411
412 Returns
413 -------
414 out : int
415 Equal to :attr:`ndim`.
416
417 Examples
418 --------
419 >>> from psi_io.mesh import Mesh
420 >>> len(Mesh.parse(0b101, ndim=3))
421 3
422 """
423 return self.ndim
424
425 def __bool__(self) -> bool:
426 """Return ``True`` if any axis is on the half mesh.
427
428 Returns
429 -------
430 out : bool
431 ``True`` when :attr:`code` is non-zero; ``False`` when all axes
432 are on the main mesh (code ``0``).
433
434 Examples
435 --------
436 >>> from psi_io.mesh import Mesh
437 >>> bool(Mesh.parse('half', ndim=3))
438 True
439 >>> bool(Mesh.parse('main', ndim=3))
440 False
441 """
442 return self.code != 0
443
444 def __index__(self) -> int:
445 """Return the raw integer mesh code for use in index contexts.
446
447 Allows a :class:`Mesh` to be used wherever a plain integer code is
448 expected (e.g., passed directly to :func:`remesh_array` as *imesh*).
449
450 Returns
451 -------
452 out : int
453 Equal to :attr:`code`.
454
455 Examples
456 --------
457 >>> from psi_io.mesh import Mesh
458 >>> m = Mesh.parse(0b110, ndim=3)
459 >>> int(m)
460 6
461 >>> hex(m)
462 '0x6'
463 """
464 return self.code
465
466 def __lt__(self, other: Mesh | int) -> bool:
467 """Compare this :class:`Mesh` to *other* by code value.
468
469 Parameters
470 ----------
471 other : Mesh | int
472 The object to compare against. If a :class:`Mesh`, its
473 ``ndim`` must equal ``self.ndim``. If an :class:`int`, the
474 comparison is against the raw code.
475
476 Returns
477 -------
478 out : bool
479 ``True`` when ``self.code < other`` (or ``other.code``).
480
481 Raises
482 ------
483 ValueError
484 If *other* is a :class:`Mesh` with a different ``ndim``.
485
486 Examples
487 --------
488 >>> from psi_io.mesh import Mesh
489 >>> Mesh.parse(0b001, ndim=3) < Mesh.parse(0b100, ndim=3)
490 True
491 >>> Mesh.parse(0b100, ndim=3) < 3
492 False
493 """
494 if isinstance(other, Mesh):
495 if self.ndim != other.ndim:
496 raise ValueError(f"Cannot compare MeshCodes with different ndim: {self.ndim} vs {other.ndim}.")
497 return self.code < other.code
498 if isinstance(other, int):
499 return self.code < other
500 return NotImplemented
501
502 def __getitem__(self, item: int | slice) -> Mesh:
503 """Return a sub-:class:`Mesh` for the axis or slice specified by *item*.
504
505 Parameters
506 ----------
507 item : int | slice
508 Axis index or slice. Negative integer indices are supported.
509 Slicing follows the same semantics as Python sequences.
510
511 Returns
512 -------
513 out : Mesh
514 A new :class:`Mesh` containing only the selected axis or axes.
515 The ``ndim`` of the result equals ``1`` for integer indexing and
516 ``len(range(*item.indices(self.ndim)))`` for slice indexing.
517
518 Raises
519 ------
520 IndexError
521 If an integer *item* is out of range for this :class:`Mesh`.
522 TypeError
523 If *item* is neither an :class:`int` nor a :class:`slice`.
524
525 Examples
526 --------
527 >>> from psi_io.mesh import Mesh
528 >>> m = Mesh.parse(0b101, ndim=3) # HALF, MAIN, HALF
529 >>> str(m[0])
530 'HALF'
531 >>> str(m[1])
532 'MAIN'
533 >>> str(m[:2])
534 'HALF, MAIN'
535 """
536 if isinstance(item, int):
537 if item < 0:
538 item += self.ndim
539 if not 0 <= item < self.ndim:
540 raise IndexError(f"Index {item} out of range for Mesh with ndim={self.ndim}.")
541 return Mesh((self.code >> (self.ndim - 1 - item)) & 1, 1)
542 if isinstance(item, slice):
543 indices = range(*item.indices(self.ndim))
544 code = 0
545 for i in indices:
546 code = (code << 1) | ((self.code >> (self.ndim - 1 - i)) & 1)
547 return Mesh(code, len(indices))
548 raise TypeError(f"Indices must be integers or slices, not {type(item).__name__!r}.")
549
550 def __iter__(self) -> Generator[bool, None, None]:
551 """Yield per-axis half-mesh flags MSB-first (logical axis order).
552
553 Yields ``True`` for each axis on the half mesh, ``False`` for main.
554 Iterating the result of :meth:`remesh` gives the remesh flags directly.
555
556 Yields
557 ------
558 flag : bool
559 ``True`` if the current axis is on the half mesh; ``False`` if on
560 the main mesh. Axes are yielded in logical order (MSB first).
561
562 Examples
563 --------
564 >>> from psi_io.mesh import Mesh
565 >>> list(Mesh.parse(0b101, ndim=3))
566 [True, False, True]
567 >>> list(Mesh.parse('main', ndim=3))
568 [False, False, False]
569 """
570 for i in range(self.ndim):
571 yield bool((self.code >> (self.ndim - 1 - i)) & 1)
572
573 def __reversed__(self) -> Generator[bool, None, None]:
574 """Yield per-axis half-mesh flags in reverse (LSB-first) order.
575
576 Equivalent to iterating over :meth:`reverse`.
577
578 Yields
579 ------
580 flag : bool
581 ``True`` if the current axis is on the half mesh; ``False`` if on
582 the main mesh. Axes are yielded in reverse logical order (LSB
583 first).
584
585 Examples
586 --------
587 >>> from psi_io.mesh import Mesh
588 >>> list(reversed(Mesh.parse(0b101, ndim=3)))
589 [True, False, True]
590 >>> list(reversed(Mesh.parse(0b100, ndim=3)))
591 [False, False, True]
592 """
593 return iter(self.reverse())
594
595 def __str__(self) -> str:
596 """Return a human-readable per-axis stagger string.
597
598 Returns
599 -------
600 out : str
601 Comma-separated ``'HALF'``/``'MAIN'`` labels, one per axis in
602 logical order (MSB first). For example ``'HALF, MAIN, MAIN'``.
603
604 Examples
605 --------
606 >>> from psi_io.mesh import Mesh
607 >>> str(Mesh.parse(0b100, ndim=3))
608 'HALF, MAIN, MAIN'
609 >>> str(Mesh.parse('main', ndim=2))
610 'MAIN, MAIN'
611 """
612 return ', '.join(
613 'HALF' if (self.code >> (self.ndim - 1 - i)) & 1 else 'MAIN'
614 for i in range(self.ndim)
615 )
616
617 def __rshift__(self, other: Optional[MeshLike]) -> tuple[bool]:
618 """Return remesh flags for the transition ``self`` → ``other``.
619
620 Syntactic sugar for :meth:`remesh`. Each ``True`` in the result
621 indicates an axis that must be averaged (half → main) when
622 transforming data from this mesh to *other*.
623
624 Parameters
625 ----------
626 other : MeshLike | None
627 Target mesh stagger in any form accepted by :data:`MeshLike`.
628 ``None`` is a no-op (returns all ``False``).
629
630 Returns
631 -------
632 out : tuple[bool, ...]
633 Per-axis remesh flags; ``True`` means the axis needs averaging.
634
635 Examples
636 --------
637 >>> from psi_io.mesh import Mesh
638 >>> src = Mesh.parse(0b111, ndim=3)
639 >>> src >> 'main'
640 (True, True, True)
641 >>> src >> None
642 (False, False, False)
643 """
644 return self.remesh(other)
645
[docs]
646 def reverse(self) -> Mesh:
647 """Return a new :class:`Mesh` with the axis order reversed (MSB to LSB).
648
649 Useful for converting between C-order and F-order axis conventions,
650 where the first physical axis becomes the last numpy axis or vice
651 versa.
652
653 Returns
654 -------
655 out : Mesh
656 A new :class:`Mesh` with the same ``ndim`` but with the bit order
657 of :attr:`code` reversed so that what was the MSB becomes the LSB.
658
659 Examples
660 --------
661 >>> from psi_io.mesh import Mesh
662 >>> m = Mesh.parse(0b100, ndim=3) # HALF, MAIN, MAIN
663 >>> str(m.reverse())
664 'MAIN, MAIN, HALF'
665 >>> str(Mesh.parse(0b110, ndim=3).reverse())
666 'MAIN, HALF, HALF'
667 """
668 result, code = 0, self.code
669 for _ in range(self.ndim):
670 result = (result << 1) | (code & 1)
671 code >>= 1
672 return Mesh(result, self.ndim)
673
[docs]
674 @functools.singledispatchmethod
675 @classmethod
676 def parse(cls, mesh_code: MeshLike, *args: Any):
677 """Normalize *mesh_code* into a :class:`Mesh`.
678
679 Parameters
680 ----------
681 mesh_code : MeshLike
682 Stagger specification in any accepted form: an integer, the
683 ``'main'``/``'half'`` shorthands, a per-axis sequence of tokens
684 (``0``/``1``, ``'m'``/``'h'``, ``True``/``False``, etc.), or an
685 existing :class:`Mesh` (returned as-is).
686 ndim : int, optional
687 Number of dimensions. Required when *mesh_code* is an integer or
688 the ``'main'``/``'half'`` shorthand; inferred from sequence length
689 otherwise. If both are provided they must agree.
690
691 Returns
692 -------
693 out : Mesh
694 Normalized :class:`Mesh` with the specified stagger and
695 dimensionality.
696
697 Raises
698 ------
699 ValueError
700 If *ndim* is required but not supplied, if *ndim* conflicts with
701 the sequence length, or if any token is unrecognized.
702 TypeError
703 If *mesh_code* is of an unsupported type.
704
705 Examples
706 --------
707 >>> from psi_io.mesh import Mesh
708 >>> str(Mesh.parse(0b100, ndim=3))
709 'HALF, MAIN, MAIN'
710 >>> str(Mesh.parse('half', ndim=2))
711 'HALF, HALF'
712 >>> str(Mesh.parse([1, 0, 1], ndim=3))
713 'HALF, MAIN, HALF'
714 >>> m = Mesh.parse(0b010, ndim=3)
715 >>> Mesh.parse(m) is m
716 True
717 """
718 if isinstance(mesh_code, Mesh):
719 return mesh_code
720 raise TypeError(f"Cannot convert {type(mesh_code).__name__!r} to Mesh.")
721
722 @parse.register(bool)
723 @classmethod
724 def _(cls, mesh_code, ndim: int):
725 return cls(0 if not mesh_code else (1 << ndim) - 1, ndim)
726
727 @parse.register(int)
728 @classmethod
729 def _(cls, mesh_code, ndim: int):
730 return cls(mesh_code, ndim)
731
732 @parse.register(str)
733 @classmethod
734 def _(cls, mesh_code, ndim: Optional[int] = None):
735 if mesh_code.lower() in ('main', 'half'):
736 if ndim is None:
737 raise ValueError("ndim is required for 'main'/'half' shorthands.")
738 return cls(0 if mesh_code == 'main' else (1 << ndim) - 1, ndim)
739 seq = list(mesh_code)
740 return cls.parse(seq, ndim)
741
[docs]
742 @parse.register(ABCSequence)
743 @classmethod
744 def _(cls, mesh_code, ndim: Optional[int] = None):
745 if ndim is not None and ndim != len(mesh_code):
746 raise ValueError(f"ndim={ndim} conflicts with sequence length {len(mesh_code)}.")
747 code = 0
748 for c in mesh_code:
749 bit = _MESH_CODE_REVERSE_MAPPING.get(str(c).lower())
750 if bit is None:
751 raise ValueError(f"Invalid mesh code token {c!r}.")
752 code = (code << 1) | bit
753 return cls(code, ndim or len(mesh_code))
754
[docs]
755 @_coerce_mesh_target
756 def remesh(self, target: Optional[MeshLike], strict: bool = True) -> tuple[bool]:
757 """Return per-axis flags indicating which axes require averaging.
758
759 An axis needs remeshing when the source is on the half mesh (``1``) and
760 the target is on the main mesh (``0``). By default, requesting a
761 main-to-half transition (upsampling) raises a :exc:`ValueError`.
762
763 Parameters
764 ----------
765 target : MeshLike | None
766 Desired output stagger; coerced to :class:`Mesh` via
767 :func:`_coerce_mesh_target` if necessary. ``None`` is treated
768 as ``self`` (no-op: returns all ``False``).
769 strict : bool, optional
770 If ``True`` (default), raise :exc:`ValueError` when any axis in
771 *target* is on the half mesh but the corresponding axis in
772 ``self`` is already on the main mesh (main → half is not
773 supported). Set to ``False`` to silently ignore such axes.
774
775 Returns
776 -------
777 out : tuple[bool, ...]
778 Tuple of per-axis boolean flags in logical axis order (MSB first).
779 ``True`` at position *i* means axis *i* must be averaged (half →
780 main); ``False`` means no averaging is needed on that axis.
781
782 Raises
783 ------
784 ValueError
785 If *strict* is ``True`` and *target* requests a half-mesh axis
786 where ``self`` is already on the main mesh.
787
788 Examples
789 --------
790 >>> from psi_io.mesh import Mesh
791 >>> src = Mesh.parse(0b111, ndim=3) # all-half
792 >>> src.remesh('main')
793 (True, True, True)
794 >>> src.remesh(None) # no-op: target == self
795 (False, False, False)
796 >>> src.remesh(0b101) # only theta needs averaging
797 (False, True, False)
798 """
799 mask = (1 << self.ndim) - 1
800 if strict and (~self.code & mask) & target.code:
801 raise ValueError(f"Cannot remesh from {self} to {target}: main → half transitions are not supported.")
802 return tuple(Mesh(self.code & (~target.code & mask), self.ndim))
803
804
805def _average_adjacent(arr: np.ndarray,
806 axis: int
807 ) -> np.ndarray:
808 """Return the mean of adjacent element pairs along *axis*, reducing its size by one.
809
810 Computes ``0.5 * (arr[..., :-1, ...] + arr[..., 1:, ...])`` where the
811 ellipsis notation represents all other axes. The result has the same
812 shape as *arr* on every axis except *axis*, which shrinks by one element.
813
814 Parameters
815 ----------
816 arr : np.ndarray
817 Input array of any dtype and number of dimensions.
818 axis : int
819 Axis along which to average adjacent pairs. Must satisfy
820 ``0 <= axis < arr.ndim``.
821
822 Returns
823 -------
824 out : np.ndarray
825 Array with the same shape as *arr* except ``out.shape[axis] ==
826 arr.shape[axis] - 1``. The dtype is determined by NumPy's
827 float promotion rules (typically ``float64`` for integer input).
828
829 Raises
830 ------
831 ValueError
832 If ``arr.shape[axis] < 2``, because at least two elements are needed
833 to form one averaged pair.
834
835 Examples
836 --------
837 >>> import numpy as np
838 >>> from psi_io.mesh import _average_adjacent
839 >>> arr = np.array([1.0, 3.0, 5.0, 7.0])
840 >>> _average_adjacent(arr, axis=0)
841 array([2., 4., 6.])
842 >>> arr2d = np.array([[0.0, 2.0], [4.0, 6.0], [8.0, 10.0]])
843 >>> _average_adjacent(arr2d, axis=0).shape
844 (2, 2)
845 """
846 if arr.shape[axis] < 2:
847 raise ValueError(f"Cannot remesh axis {axis} with size {arr.shape[axis]}."
848 f" Need at least 2 elements to average adjacent pairs.")
849 slc_lo = [slice(None)] * arr.ndim
850 slc_hi = [slice(None)] * arr.ndim
851 slc_lo[axis] = slice(None, -1)
852 slc_hi[axis] = slice(1, None)
853 return 0.5 * (arr[tuple(slc_lo)] + arr[tuple(slc_hi)])
854
855
856def _remesh_array(data: np.ndarray,
857 remesh: Iterable[bool] | bool,
858 order: ArrayOrdering = 'F') -> np.ndarray:
859 """Apply adjacent-element averaging on each axis where *remesh* is ``True``.
860
861 Iterates over axes in numpy index order (0, 1, 2, …) and calls
862 :func:`_average_adjacent` on each axis flagged for remeshing. When
863 *order* is ``'F'`` the *remesh* iterable is reversed before the loop so
864 that the logical MSB-first ordering of :class:`Mesh` maps correctly to
865 numpy's last-axis-first Fortran convention.
866
867 Parameters
868 ----------
869 data : np.ndarray
870 Input array on the source mesh stagger.
871 remesh : Iterable[bool] | bool
872 Per-axis remesh flags in logical axis order (MSB first), as returned
873 by :meth:`Mesh.remesh`. A single :class:`bool` is broadcast to all
874 axes.
875 order : ArrayOrdering, optional
876 Memory-order convention; ``'F'`` (default) reverses *remesh* before
877 iterating so that MSB maps to the last numpy axis. ``'C'`` uses the
878 flags as-is so that MSB maps to the first numpy axis.
879
880 Returns
881 -------
882 out : np.ndarray
883 Array with averaged values along every flagged axis. Shape is reduced
884 by one on each remeshed axis.
885
886 Examples
887 --------
888 >>> import numpy as np
889 >>> from psi_io.mesh import _remesh_array
890 >>> arr = np.ones((4, 4, 4))
891 >>> out = _remesh_array(arr, remesh=[False, False, True], order='F')
892 >>> out.shape # F-order: LSB flag (True) maps to numpy axis 0 (phi)
893 (3, 4, 4)
894 >>> out2 = _remesh_array(arr, remesh=True)
895 >>> out2.shape # all axes averaged
896 (3, 3, 3)
897 """
898 if isinstance(remesh, bool):
899 remesh = [remesh] * data.ndim
900 if order == 'F':
901 remesh = reversed(remesh)
902 for i, shift in enumerate(remesh):
903 if shift:
904 data = _average_adjacent(data, i)
905 return data
906
907
[docs]
908def remesh_array(data: np.ndarray,
909 imesh: MeshCodeType,
910 omesh: Optional[MeshCodeType] = None,
911 order: ArrayOrdering = 'F') -> np.ndarray:
912 r"""Shift an array from one mesh stagger to another.
913
914 Compares the source mesh *imesh* against the target mesh *omesh* axis by axis
915 and applies adjacent-element averaging on every axis that needs to move from
916 half mesh to main mesh. Only the half → main direction is supported;
917 requesting main → half raises :exc:`ValueError`.
918
919 This is commonly needed before performing interpolation or arithmetic between
920 quantities on different mesh positions. For example, computing the magnitude of the
921 magnetic field requires :math:`B_r`, :math:`B_\theta`, and
922 :math:`B_\varphi` on the same mesh: each must be remeshed from its native
923 stagger (``0b100``, ``0b010``, ``0b001``) to a common target before squaring
924 and summing.
925
926 If *omesh* is ``None``, the array is returned unchanged.
927
928 Parameters
929 ----------
930 data : np.ndarray
931 Input array on the stagger described by *imesh*.
932 imesh : MeshCodeType
933 Source mesh stagger in any form accepted by :data:`MeshCodeType`.
934 omesh : MeshCodeType | None, optional
935 Target mesh stagger. ``None`` (default) is a no-op. Pass ``0`` or
936 ``'main'`` to move every half-mesh axis to the main mesh.
937 order : ArrayOrdering, optional
938 Memory-order convention controlling how mesh-code bits map to numpy
939 axes; see :data:`ArrayOrdering`. Defaults to ``'F'`` (Fortran /
940 PSI HDF column-major: last numpy axis = :math:`r`).
941
942 Returns
943 -------
944 out : np.ndarray
945 Array on the target mesh. Each remeshed axis is reduced by one element
946 via adjacent averaging; axes that already match are left unchanged.
947
948 Raises
949 ------
950 ValueError
951 If any axis in *omesh* requests half mesh where *imesh* is already main
952 (upsampling is not supported).
953
954 See Also
955 --------
956 Mesh : Compact representation of a multi-axis mesh stagger code.
957 MeshCodeType : Accepted forms for mesh stagger specifications.
958 ArrayOrdering : Memory-order convention for bit-axis mapping.
959
960 Examples
961 --------
962 Convert a radial magnetic-field array (half-mesh in :math:`r`, the last
963 numpy axis) to the all-main mesh:
964
965 >>> import numpy as np
966 >>> from psi_io.mesh import remesh_array
967 >>> br = np.ones((128, 64, 57)) # shape (Nφ, Nθ, Nr); Nr is half-mesh size
968 >>> br_main = remesh_array(br, imesh=0b100, omesh='main')
969 >>> br_main.shape
970 (128, 64, 56)
971
972 Remesh a scalar quantity (all-half, ``0b111``) to all-main:
973
974 >>> rho = np.ones((128, 64, 57))
975 >>> remesh_array(rho, imesh=0b111, omesh='main').shape
976 (127, 63, 56)
977
978 ``omesh=None`` is a no-op:
979
980 >>> remesh_array(br, imesh=0b100).shape
981 (128, 64, 57)
982
983 No-op when source and target stagger already match:
984
985 >>> remesh_array(br, imesh=0b100, omesh=0b100).shape
986 (128, 64, 57)
987 """
988 if omesh is None:
989 return data
990 remesh = Mesh.parse(imesh, data.ndim).remesh(omesh)
991 return _remesh_array(data, remesh, order=order)