diff --git a/docs/subplots.py b/docs/subplots.py index ced109c96..a1b309ec9 100644 --- a/docs/subplots.py +++ b/docs/subplots.py @@ -373,8 +373,9 @@ # `~matplotlib.figure.Figure.supxlabel` and `~matplotlib.figure.Figure.supylabel`, # these labels are aligned between gridspec edges rather than figure edges. # #. Supporting five sharing "levels". These values can be passed to `sharex`, -# `sharey`, or `share`, or assigned to :rcraw:`subplots.share`. The levels -# are defined as follows: +# `sharey`, or `share`, or assigned to :rcraw:`subplots.share`. +# UltraPlot supports five explicit sharing levels plus ``'auto'``. +# The levels are defined as follows: # # * ``False`` or ``0``: Axis sharing is disabled. # * ``'labels'``, ``'labs'``, or ``1``: Axis labels are shared, but nothing else. @@ -384,6 +385,14 @@ # in the same row or column of the :class:`~ultraplot.gridspec.GridSpec`; a space # or empty plot will add the labels, but not break the limit sharing. See below # for a more complex example. +# * ``'limits'``, ``'lims'``, or ``2``: As above, plus share limits/scales/ticks. +# * ``True`` or ``3``: As above, plus hide inner tick labels. +# * ``'all'`` or ``4``: As above, plus share limits across the full subplot grid. +# * ``'auto'`` (default): Start from level ``3`` and only share compatible axes. +# This suppresses warnings for mixed axis families (e.g., cartesian + polar). +# +# Explicit sharing levels still force sharing attempts and may warn when +# incompatible axes are encountered. # # The below examples demonstrate the effect of various axis and label sharing # settings on the appearance of several subplot grids. @@ -422,6 +431,20 @@ import ultraplot as uplt import numpy as np +# The default `share='auto'` keeps incompatible axis families unshared. +fig, axs = uplt.subplots(ncols=2, proj=("cart", "polar")) +x = np.linspace(0, 2 * np.pi, 100) +axs[0].plot(x, np.sin(x)) +axs[1].plot(x, np.abs(np.sin(2 * x))) +axs.format( + suptitle="Auto sharing with mixed cartesian and polar axes", + title=("cartesian", "polar"), +) + +# %% +import ultraplot as uplt +import numpy as np + state = np.random.RandomState(51423) # Plots with minimum and maximum sharing settings diff --git a/ultraplot/axes/base.py b/ultraplot/axes/base.py index 11835af93..78ac489ec 100644 --- a/ultraplot/axes/base.py +++ b/ultraplot/axes/base.py @@ -1702,21 +1702,39 @@ def shared(paxs): iax._sharey_setup(left) # External axes sharing, sometimes overrides panel axes sharing - # Share x axes - parent, *children = self._get_share_axes("x") - for child in children: - child._sharex_setup(parent) - # Share y axes - parent, *children = self._get_share_axes("y") - for child in children: - child._sharey_setup(parent) - # Global sharing, use the reference subplot because why not + # Share x axes within compatible groups + axes_x = self._get_share_axes("x") + for group in self.figure._partition_share_axes(axes_x, "x"): + if not group: + continue + parent, *children = group + for child in children: + child._sharex_setup(parent) + + # Share y axes within compatible groups + axes_y = self._get_share_axes("y") + for group in self.figure._partition_share_axes(axes_y, "y"): + if not group: + continue + parent, *children = group + for child in children: + child._sharey_setup(parent) + + # Global sharing, use the reference subplot where compatible ref = self.figure._subplot_dict.get(self.figure._refnum, None) - if self is not ref: + if self is not ref and ref is not None: if self.figure._sharex > 3: - self._sharex_setup(ref, labels=False) + ok, reason = self.figure._share_axes_compatible(ref, self, "x") + if ok: + self._sharex_setup(ref, labels=False) + else: + self.figure._warn_incompatible_share("x", ref, self, reason) if self.figure._sharey > 3: - self._sharey_setup(ref, labels=False) + ok, reason = self.figure._share_axes_compatible(ref, self, "y") + if ok: + self._sharey_setup(ref, labels=False) + else: + self.figure._warn_incompatible_share("y", ref, self, reason) def _artist_fully_clipped(self, artist): """ diff --git a/ultraplot/axes/cartesian.py b/ultraplot/axes/cartesian.py index e975356e1..696639beb 100644 --- a/ultraplot/axes/cartesian.py +++ b/ultraplot/axes/cartesian.py @@ -869,13 +869,31 @@ def _apply_log_formatter_on_scale(self, s): self._update_formatter(s, "log") def set_xscale(self, value, **kwargs): + fig = getattr(self, "figure", None) + if ( + fig is not None + and hasattr(fig, "_is_auto_share_mode") + and fig._is_auto_share_mode("x") + ): + self._unshare(which="x") result = super().set_xscale(value, **kwargs) self._apply_log_formatter_on_scale("x") + if fig is not None and hasattr(fig, "_refresh_auto_share"): + fig._refresh_auto_share("x") return result def set_yscale(self, value, **kwargs): + fig = getattr(self, "figure", None) + if ( + fig is not None + and hasattr(fig, "_is_auto_share_mode") + and fig._is_auto_share_mode("y") + ): + self._unshare(which="y") result = super().set_yscale(value, **kwargs) self._apply_log_formatter_on_scale("y") + if fig is not None and hasattr(fig, "_refresh_auto_share"): + fig._refresh_auto_share("y") return result def _update_formatter( diff --git a/ultraplot/figure.py b/ultraplot/figure.py index aebb9e777..45999ad0c 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -104,7 +104,7 @@ figsize : 2-tuple, optional Tuple specifying the figure ``(width, height)``. sharex, sharey, share \ -: {0, False, 1, 'labels', 'labs', 2, 'limits', 'lims', 3, True, 4, 'all'}, \ +: {0, False, 1, 'labels', 'labs', 2, 'limits', 'lims', 3, True, 4, 'all', 'auto'}, \ default: :rc:`subplots.share` The axis sharing "level" for the *x* axis, *y* axis, or both axes. Options are as follows: @@ -119,6 +119,11 @@ row and leftmost column of subplots. * ``4`` or ``'all'``: As above but also share the axis limits, scales, and tick locations between subplots not in the same row or column. + * ``'auto'``: Start from level ``3`` and only share axes that are compatible + (for example, mixed cartesian and polar axes are kept unshared). + + Explicit sharing levels (``0`` to ``4`` and aliases) still force sharing + attempts and can emit warnings for incompatible axes. spanx, spany, span : bool or {0, 1}, default: :rc:`subplots.span` Whether to use "spanning" axis labels for the *x* axis, *y* axis, or both @@ -550,8 +555,9 @@ class Figure(mfigure.Figure): "1 or 'labels' or 'labs' (share axis labels), " "2 or 'limits' or 'lims' (share axis limits and axis labels), " "3 or True (share axis limits, axis labels, and tick labels), " - "or 4 or 'all' (share axis labels and tick labels in the same gridspec " - "rows and columns and share axis limits across all subplots)." + "4 or 'all' (share axis labels and tick labels in the same gridspec " + "rows and columns and share axis limits across all subplots), " + "or 'auto' (start unshared and share only compatible axes)." ) _space_message = ( "To set the left, right, bottom, top, wspace, or hspace gridspec values, " @@ -795,14 +801,25 @@ def __init__( translate = {"labels": 1, "labs": 1, "limits": 2, "lims": 2, "all": 4} sharex = _not_none(sharex, share, rc["subplots.share"]) sharey = _not_none(sharey, share, rc["subplots.share"]) - sharex = 3 if sharex is True else translate.get(sharex, sharex) - sharey = 3 if sharey is True else translate.get(sharey, sharey) - if sharex not in range(5): - raise ValueError(f"Invalid sharex={sharex!r}. " + self._share_message) - if sharey not in range(5): - raise ValueError(f"Invalid sharey={sharey!r}. " + self._share_message) + + def _normalize_share(value): + auto = isinstance(value, str) and value.lower() == "auto" + if auto: + return 3, True + value = 3 if value is True else translate.get(value, value) + if value not in range(5): + raise ValueError( + f"Invalid sharing value {value!r}. " + self._share_message + ) + return int(value), False + + sharex, sharex_auto = _normalize_share(sharex) + sharey, sharey_auto = _normalize_share(sharey) self._sharex = int(sharex) self._sharey = int(sharey) + self._sharex_auto = bool(sharex_auto) + self._sharey_auto = bool(sharey_auto) + self._share_incompat_warned = False # Translate span and align settings spanx = _not_none( @@ -881,6 +898,210 @@ def draw(self, renderer): self._apply_share_label_groups() super().draw(renderer) + def _is_auto_share_mode(self, which: str) -> bool: + """Return whether a given axis uses auto-share mode.""" + if which not in ("x", "y"): + return False + return bool(getattr(self, f"_share{which}_auto", False)) + + def _axis_unit_signature(self, ax, which: str): + """Return a lightweight signature for axis unit/converter compatibility.""" + axis_obj = getattr(ax, f"{which}axis", None) + if axis_obj is None: + return None + if hasattr(axis_obj, "get_converter"): + converter = axis_obj.get_converter() + else: + converter = getattr(axis_obj, "converter", None) + units = getattr(axis_obj, "units", None) + if hasattr(axis_obj, "get_units"): + units = axis_obj.get_units() + if converter is None and units is None: + return None + if isinstance(units, (str, bytes)): + unit_tag = units + elif units is not None: + unit_tag = type(units).__name__ + else: + unit_tag = None + converter_tag = type(converter).__name__ if converter is not None else None + return (converter_tag, unit_tag) + + def _share_axes_compatible(self, ref, other, which: str): + """Check whether two axes are compatible for sharing along one axis.""" + if ref is None or other is None: + return False, "missing reference axis" + if ref is other: + return True, None + if which not in ("x", "y"): + return True, None + + # External container axes should only share with the same external class. + ref_external = hasattr(ref, "has_external_axes") and ref.has_external_axes() + other_external = ( + hasattr(other, "has_external_axes") and other.has_external_axes() + ) + if ref_external or other_external: + if not (ref_external and other_external): + return False, "external and non-external axes cannot be shared" + ref_ext = ref.get_external_axes() + other_ext = other.get_external_axes() + if type(ref_ext) is not type(other_ext): + return False, "different external projection classes" + + # GeoAxes are only share-compatible with same rectilinear projection family. + ref_geo = isinstance(ref, paxes.GeoAxes) + other_geo = isinstance(other, paxes.GeoAxes) + if ref_geo or other_geo: + if not (ref_geo and other_geo): + return False, "geo and non-geo axes cannot be shared" + if not ref._is_rectilinear() or not other._is_rectilinear(): + return False, "non-rectilinear GeoAxes cannot be shared" + if type(getattr(ref, "projection", None)) is not type( + getattr(other, "projection", None) + ): + return False, "different Geo projection classes" + + # Polar and non-polar should not share. + ref_polar = isinstance(ref, paxes.PolarAxes) + other_polar = isinstance(other, paxes.PolarAxes) + if ref_polar != other_polar: + return False, "polar and non-polar axes cannot be shared" + + # Non-geo external axes are generally Cartesian-like in UltraPlot. + if not ref_geo and not other_geo and not (ref_external or other_external): + if not ( + isinstance(ref, paxes.CartesianAxes) + and isinstance(other, paxes.CartesianAxes) + ): + return False, "different axis families" + + # Scale compatibility along the active axis. + get_scale_ref = getattr(ref, f"get_{which}scale", None) + get_scale_other = getattr(other, f"get_{which}scale", None) + if callable(get_scale_ref) and callable(get_scale_other): + if get_scale_ref() != get_scale_other(): + return False, "different axis scales" + + # Units/converters must match if both are established. + uref = self._axis_unit_signature(ref, which) + uother = self._axis_unit_signature(other, which) + if uref != uother and (uref is not None or uother is not None): + return False, "different axis unit domains" + + return True, None + + def _warn_incompatible_share(self, which: str, ref, other, reason: str) -> None: + """Warn once per figure for explicit incompatible sharing.""" + if self._is_auto_share_mode(which): + return + if bool(self._share_incompat_warned): + return + self._share_incompat_warned = True + warnings._warn_ultraplot( + f"Skipping incompatible {which}-axis sharing for {type(ref).__name__} and {type(other).__name__}: {reason}." + ) + + def _partition_share_axes(self, axes, which: str): + """Partition a candidate share list into compatible sub-groups.""" + groups = [] + for ax in axes: + if ax is None: + continue + placed = False + first_mismatch = None + for group in groups: + ok, reason = self._share_axes_compatible(group[0], ax, which) + if ok: + group.append(ax) + placed = True + break + if first_mismatch is None: + first_mismatch = (group[0], reason) + if not placed: + groups.append([ax]) + if first_mismatch is not None: + ref, reason = first_mismatch + self._warn_incompatible_share(which, ref, ax, reason) + return groups + + def _iter_shared_groups(self, which: str, *, panels: bool = True): + """Yield unique shared groups for one axis direction.""" + if which not in ("x", "y"): + return + get_grouper = f"get_shared_{which}_axes" + seen = set() + for ax in self._iter_axes(hidden=False, children=False, panels=panels): + get_shared = getattr(ax, get_grouper, None) + if not callable(get_shared): + continue + siblings = list(get_shared().get_siblings(ax)) + if len(siblings) < 2: + continue + key = frozenset(map(id, siblings)) + if key in seen: + continue + seen.add(key) + yield siblings + + def _join_shared_group(self, which: str, ref, other) -> None: + """Join an axis to a shared group and copy the shared axis state.""" + ref._shared_axes[which].join(ref, other) + axis = getattr(other, f"{which}axis") + ref_axis = getattr(ref, f"{which}axis") + setattr(other, f"_share{which}", ref) + axis.major = ref_axis.major + axis.minor = ref_axis.minor + if which == "x": + lim = ref.get_xlim() + other.set_xlim(*lim, emit=False, auto=ref.get_autoscalex_on()) + else: + lim = ref.get_ylim() + other.set_ylim(*lim, emit=False, auto=ref.get_autoscaley_on()) + axis._scale = ref_axis._scale + + def _refresh_auto_share(self, which: Optional[str] = None) -> None: + """Recompute auto-sharing groups after local axis-state changes.""" + axes = list(self._iter_axes(hidden=False, children=True, panels=True)) + targets = ("x", "y") if which is None else (which,) + for target in targets: + if not self._is_auto_share_mode(target): + continue + for ax in axes: + if hasattr(ax, "_unshare"): + ax._unshare(which=target) + for ax in self._iter_axes(hidden=False, children=False, panels=False): + if hasattr(ax, "_apply_auto_share"): + ax._apply_auto_share() + self._autoscale_shared_limits(target) + + def _autoscale_shared_limits(self, which: str) -> None: + """Recompute shared data limits for each compatible shared-axis group.""" + if which not in ("x", "y"): + return + + share_level = self._sharex if which == "x" else self._sharey + if share_level <= 1: + return + + get_auto = f"get_autoscale{which}_on" + for siblings in self._iter_shared_groups(which, panels=True): + for sib in siblings: + relim = getattr(sib, "relim", None) + if callable(relim): + relim() + + ref = siblings[0] + for sib in siblings: + auto = getattr(sib, get_auto, None) + if callable(auto) and auto(): + ref = sib + break + + autoscale_view = getattr(ref, "autoscale_view", None) + if callable(autoscale_view): + autoscale_view(scalex=(which == "x"), scaley=(which == "y")) + def _snap_axes_to_pixel_grid(self, renderer) -> None: """ Snap visible axes bounds to the renderer pixel grid. @@ -1026,6 +1247,10 @@ def _compute_baseline_tick_state(self, group_axes, axis: str, label_keys): if getattr(axi, "_panel_side", None): continue + # Non-rectilinear GeoAxes should keep independent gridliner labels. + if isinstance(axi, paxes.GeoAxes) and not axi._is_rectilinear(): + return {}, True + # Supported axes types if not isinstance( axi, (paxes.CartesianAxes, paxes._CartopyAxes, paxes._BasemapAxes) @@ -1798,27 +2023,6 @@ def _add_subplot(self, *args, **kwargs): # Don't pass _subplot_spec as a keyword argument to avoid it being # propagated to Axes.set() or other methods that don't accept it ax = super().add_subplot(ss, **kwargs) - # Allow sharing for GeoAxes if rectilinear - if self._sharex or self._sharey: - if len(self.axes) > 1 and isinstance(ax, paxes.GeoAxes): - # Compare it with a reference - ref = next(self._iter_axes(hidden=False, children=False, panels=False)) - unshare = False - if not ax._is_rectilinear(): - unshare = True - elif hasattr(ax, "projection") and hasattr(ref, "projection"): - if ax.projection != ref.projection: - unshare = True - if unshare: - self._unshare_axes() - # Only warn once. Note, if axes are reshared - # the warning is not reset. This is however, - # very unlikely to happen as GeoAxes are not - # typically shared and unshared. - warnings._warn_ultraplot( - f"GeoAxes can only be shared for rectilinear projections, {ax.projection=} is not a rectilinear projection." - ) - if ax.number: self._subplot_dict[ax.number] = ax return ax @@ -1886,30 +2090,21 @@ def get_key(ax): key = get_key(ax) groups.setdefault(key, []).append(ax) - # Re-join axes per group - for group in groups.values(): - ref = group[0] - for other in group[1:]: - ref._shared_axes[which].join(ref, other) - # The following manual adjustments are necessary because the - # join method does not automatically propagate the sharing state - # and axis properties to the other axes. This ensures that the - # shared axes behave consistently. - if which == "x": - other._sharex = ref - other.xaxis.major = ref.xaxis.major - other.xaxis.minor = ref.xaxis.minor - lim = ref.get_xlim() - other.set_xlim(*lim, emit=False, auto=ref.get_autoscalex_on()) - other.xaxis._scale = ref.xaxis._scale - if which == "y": - # This logic is from sharey - other._sharey = ref - other.yaxis.major = ref.yaxis.major - other.yaxis.minor = ref.yaxis.minor - lim = ref.get_ylim() - other.set_ylim(*lim, emit=False, auto=ref.get_autoscaley_on()) - other.yaxis._scale = ref.yaxis._scale + # Re-join axes per compatible subgroup + for raw_group in groups.values(): + if which in ("x", "y"): + subgroups = self._partition_share_axes(raw_group, which) + else: + subgroups = [raw_group] + for group in subgroups: + if not group: + continue + ref = group[0] + for other in group[1:]: + if which in ("x", "y"): + self._join_shared_group(which, ref, other) + else: + ref._shared_axes[which].join(ref, other) def _add_subplots( self, diff --git a/ultraplot/internals/rcsetup.py b/ultraplot/internals/rcsetup.py index 40d52af6b..63c1caa4d 100644 --- a/ultraplot/internals/rcsetup.py +++ b/ultraplot/internals/rcsetup.py @@ -2083,11 +2083,13 @@ def copy(self): "Default width of the reference subplot." + _addendum_in, ), "subplots.share": ( - True, - _validate_belongs(0, 1, 2, 3, 4, False, "labels", "limits", True, "all"), + "auto", + _validate_belongs( + 0, 1, 2, 3, 4, False, "labels", "limits", True, "all", "auto" + ), "The axis sharing level, one of ``0``, ``1``, ``2``, or ``3``, or the " - "more intuitive aliases ``False``, ``'labels'``, ``'limits'``, or ``True``. " - "See `~ultraplot.figure.Figure` for details.", + "more intuitive aliases ``False``, ``'labels'``, ``'limits'``, ``True``, " + "or ``'auto'``. See `~ultraplot.figure.Figure` for details.", ), "subplots.span": ( True, diff --git a/ultraplot/tests/test_figure.py b/ultraplot/tests/test_figure.py index afcd259a3..1c6dfd8b3 100644 --- a/ultraplot/tests/test_figure.py +++ b/ultraplot/tests/test_figure.py @@ -1,5 +1,7 @@ import multiprocessing as mp import os +import warnings +from datetime import datetime, timedelta import numpy as np import pytest @@ -299,6 +301,110 @@ def test_suptitle_kw_position_reverted(ha, expectation): uplt.close("all") +def _share_sibling_count(ax, which: str) -> int: + return len(list(ax._shared_axes[which].get_siblings(ax))) + + +def test_default_share_mode_is_auto(): + fig, axs = uplt.subplots(ncols=2) + assert fig._sharex_auto is True + assert fig._sharey_auto is True + + +def test_auto_share_skips_mixed_cartesian_polar_without_warning(recwarn): + fig, axs = uplt.subplots(ncols=2, proj=("cart", "polar"), share="auto") + + ultra_warnings = [ + w + for w in recwarn + if issubclass(w.category, uplt.internals.warnings.UltraPlotWarning) + ] + assert len(ultra_warnings) == 0 + + for which in ("x", "y"): + assert _share_sibling_count(axs[0], which) == 1 + assert _share_sibling_count(axs[1], which) == 1 + + +def test_explicit_share_warns_for_mixed_cartesian_polar(): + with warnings.catch_warnings(record=True) as record: + warnings.simplefilter("always", uplt.internals.warnings.UltraPlotWarning) + fig, axs = uplt.subplots(ncols=2, proj=("cart", "polar"), share="all") + incompatible = [ + w + for w in record + if issubclass(w.category, uplt.internals.warnings.UltraPlotWarning) + and "Skipping incompatible" in str(w.message) + ] + assert len(incompatible) == 1 + + +def test_auto_share_local_yscale_change_splits_group(): + fig, axs = uplt.subplots(ncols=2, share="auto") + fig.canvas.draw() + + assert _share_sibling_count(axs[0], "y") == 2 + assert _share_sibling_count(axs[1], "y") == 2 + + axs[0].format(yscale="log") + fig.canvas.draw() + + assert axs[0].get_yscale() == "log" + assert axs[1].get_yscale() == "linear" + assert _share_sibling_count(axs[0], "y") == 1 + assert _share_sibling_count(axs[1], "y") == 1 + + +def test_auto_share_grid_yscale_change_keeps_shared_limits(): + fig, axs = uplt.subplots(ncols=2, share="auto") + x = np.linspace(1, 10, 100) + axs[0].plot(x, x) + axs[1].plot(x, 100 * x) + + axs.format(yscale="log") + fig.canvas.draw() + + assert _share_sibling_count(axs[0], "y") == 2 + assert _share_sibling_count(axs[1], "y") == 2 + + ymin, ymax = axs[0].get_ylim() + assert ymax > 500 + assert ymin > 0 + + +def test_auto_share_splits_mixed_x_unit_domains_after_refresh(): + fig, axs = uplt.subplots(ncols=2, share="auto") + fig.canvas.draw() + + # Start from independent x groups so each axis can establish units separately. + for axi in axs: + axi._unshare(which="x") + assert _share_sibling_count(axs[0], "x") == 1 + assert _share_sibling_count(axs[1], "x") == 1 + + t0 = datetime(2020, 1, 1) + axs[0].plot([t0, t0 + timedelta(days=1)], [0, 1]) + axs[1].plot([0.0, 1.0], [0, 1]) + + fig._refresh_auto_share("x") + fig.canvas.draw() + + sig0 = fig._axis_unit_signature(axs[0], "x") + sig1 = fig._axis_unit_signature(axs[1], "x") + assert sig0 != sig1 + assert _share_sibling_count(axs[0], "x") == 1 + assert _share_sibling_count(axs[1], "x") == 1 + + +def test_explicit_sharey_propagates_scale_changes(): + fig, axs = uplt.subplots(ncols=2, sharey=True) + axs[0].format(yscale="log") + fig.canvas.draw() + + assert axs[0].get_yscale() == "log" + assert axs[1].get_yscale() == "log" + + def test_subplots_pixelsnap_aligns_axes_bounds(): with uplt.rc.context({"subplots.pixelsnap": True}): fig, axs = uplt.subplots(ncols=2, nrows=2) diff --git a/ultraplot/tests/test_projections.py b/ultraplot/tests/test_projections.py index 7784e42ff..e97b7dbfc 100644 --- a/ultraplot/tests/test_projections.py +++ b/ultraplot/tests/test_projections.py @@ -46,6 +46,18 @@ def test_cartopy_labels(): return fig +def test_cartopy_labels_not_shared_for_non_rectilinear(): + """ + Non-rectilinear cartopy axes should keep independent gridliner labels. + """ + fig, axs = uplt.subplots(ncols=2, proj="robin", refwidth=3) + axs.format(coast=True, labels=True) + fig.canvas.draw() + + assert axs[0]._is_ticklabel_on("labelleft") + assert axs[1]._is_ticklabel_on("labelleft") + + @pytest.mark.mpl_image_compare def test_cartopy_contours(rng): """