Source code for multimin.plotting

##################################################################
#                                                                #
# MultiMin: Multivariate Gaussian fitting                        #
#                                                                #
# Authors: Jorge I. Zuluaga                                      #
#                                                                #
##################################################################
# License: GNU Affero General Public License v3 (AGPL-3.0)       #
##################################################################

"""
Visualization and plotting utilities for MultiMin package.

Contains:
- MultiPlot: Grid plotting for N-dimensional data projections
- multimin_watermark: Add watermark to plots
"""

import warnings
import numpy as np
from matplotlib import pyplot as plt

# Import from package modules
from .base import MultiMinBase
from .util import Util
from .version import __version__


# =============================================================================
# VISUALIZATION
# =============================================================================
def multimin_watermark(ax, frac=1 / 4, alpha=1):
    """Add the MultiMin watermark to a 2D or 3D plot.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes where the watermark will be placed.
    frac : float, optional
        Factor multiplying the axes height to obtain the font scale
        (fontsize is ``8 * frac * height``), default 1/4.
    alpha : float, optional
        Transparency of the watermark text, default 1.

    Returns
    -------
    matplotlib.text.Text or None
        The text artist, or None when ``multimin.show_watermark`` is falsy.
    """
    # Import show_watermark from the main module at runtime
    import multimin as mn

    if not mn.show_watermark:
        return None

    # Height of the axes, converted from display units with the figure DPI
    axh = (
        ax.get_window_extent()
        .transformed(ax.get_figure().dpi_scale_trans.inverted())
        .height
    )
    fig_factor = frac * axh

    # Options of the watermark
    args = dict(
        rotation=270,
        ha="left",
        va="top",
        transform=ax.transAxes,
        color="pink",
        fontsize=8 * fig_factor,
        zorder=100,
        alpha=alpha,
    )

    # Text of the watermark
    mark = f"MultiMin {__version__}"

    # 3D axes expose ``add_collection3d``; they need ``text2D`` to place text
    # in axes coordinates. Only a missing attribute selects the 2D path.
    try:
        ax.add_collection3d
        plt_text = ax.text2D
    except AttributeError:
        plt_text = ax.text

    return plt_text(1, 1, mark, **args)
class MultiPlot(MultiMinBase):
    """
    Create a grid of plots showing the projections of N-dimensional data.

    Parameters
    ----------
    properties : dict
        Properties to be shown, dictionary of dictionaries (N entries).
        Keys are the names of the properties, ex. "q". Each value is a dictionary with:

        * label: label used in axis, string
        * range: range for the property, tuple (2), or None

    figsize : int, optional
        Base size for panels (the size of the figure will be M x figsize), default 3.
    fontsize : int, optional
        Base fontsize, default 10.
    direction : str, optional
        Direction of ticks in panels, default 'out'.
    marginals : bool, optional
        If True the grid is N x N and the diagonal panels are reserved for
        marginal distributions; otherwise the grid is (N-1) x (N-1). Default False.

    Attributes
    ----------
    N : int
        Number of properties.
    M : int
        Size of grid matrix (M=N with marginals, M=N-1 otherwise, at least 1).
    fw : int
        Figsize.
    fs : int
        Fontsize.
    fig : matplotlib.figure.Figure
        Figure handle.
    axs : numpy.ndarray
        Matrix with subplots, axes handles (MxM).
    axp : dict
        Subplots addressed by property names, dictionary of dictionaries
        (``axp[x_property][y_property]``).
    properties : list
        List of properties names, list of strings (N).
    dproperties : dict
        The dictionary passed as ``properties``.

    Methods
    -------
    tight_layout()
        Adjust the layout of the figure.
    set_labels(**args)
        Set labels parameters.
    set_ranges()
        Set ranges in panels according to ranges defined in dproperties.
    reset_ranges()
        Reset ranges to the limits of the last plotted data.
    set_tick_params(**args)
        Set tick parameters.
    sample_hist(data, colorbar=False, **args)
        Create a 2d-histograms of data on all panels of the MultiPlot.
    sample_scatter(data, **args)
        Scatter plot on all panels of the MultiPlot.
    mog_pdf(mog, **args)
        Plot the PDF of a MoG on all panels of the MultiPlot.
    mog_contour(mog, **args)
        Plot the contours of a MoG on all panels of the MultiPlot.
    """

    def __init__(
        self,
        properties,
        figsize=3,
        fontsize=10,
        direction="out",
        marginals=False,
    ):
        # Basic attributes
        self.dproperties = properties
        self.properties = list(properties.keys())
        self.data = None

        # Secondary attributes
        self.N = len(properties)
        if self.N == 0:
            raise ValueError("properties must contain at least one entry")
        self.marginals = marginals
        self.M = max(1, self.N) if self.marginals else max(1, self.N - 1)
        self._univariate = self.N == 1

        # Optional properties
        self.fw = figsize
        self.fs = fontsize

        # Twin axes used for marginal distributions, keyed by property name
        self._twin_axes = {}

        # Univariate: single 1D panel
        if self._univariate:
            self.fig, ax = plt.subplots(
                1, 1, constrained_layout=True, figsize=(self.fw * 1.5, self.fw)
            )
            self.axs = np.array([[ax]])
            self.constrained = True
            self.single = True
            prop0 = self.properties[0]
            self.axp = {prop0: {prop0: ax}}
            ax.set_xlabel(self.dproperties[prop0]["label"], fontsize=fontsize)
            self.tight_layout()
            return

        # Create figure and axes, preferring constrained layout
        grid_kw = dict(
            figsize=(self.M * self.fw, self.M * self.fw),
            sharex="col",
            sharey="row",
        )
        try:
            self.fig, self.axs = plt.subplots(
                self.M, self.M, constrained_layout=True, **grid_kw
            )
            self.constrained = True
        except Exception:
            # Best-effort fallback when constrained_layout is not accepted
            self.fig, self.axs = plt.subplots(self.M, self.M, **grid_kw)
            self.constrained = False

        # plt.subplots returns a bare Axes (not an array) for a 1x1 grid
        if not isinstance(self.axs, np.ndarray):
            self.axs = np.array([[self.axs]])
            self.single = True
        else:
            self.single = False

        # Create named axes: axp[x-property][y-property]
        self.axp = {prop: dict() for prop in self.properties}
        for j, propj in enumerate(self.properties):
            for i, propi in enumerate(self.properties):
                if self.marginals:
                    # Lower triangle and diagonal: column j is x, row i is y
                    if i >= j:
                        self.axp[propj][propi] = self.axs[i][j]
                    continue
                if i == j:
                    continue
                if i < j:
                    # Symmetric alias of the panel registered for (propi, propj)
                    self.axp[propj][propi] = self.axp[propi][propj]
                else:
                    # Without marginals row 0 corresponds to the second property
                    self.axp[propj][propi] = self.axs[i - 1][j]

        # Deactivate unused panels (upper triangle)
        for i in range(self.M):
            for j in range(i + 1, self.M):
                self.axs[i][j].axis("off")

        # Place ticks
        for i in range(self.M):
            for j in range(i + 1):
                self.axs[i, j].tick_params(axis="both", direction=direction)
        for i in range(self.M):
            self.axs[i, 0].tick_params(axis="y", direction="out")
            self.axs[self.M - 1, i].tick_params(axis="x", direction="out")

        # Set properties of panels
        self.set_labels()
        self.set_ranges()
        self.set_tick_params()
        self.tight_layout()

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _twin_axis(self, prop):
        """Return the twin y-axis of the diagonal panel of `prop`, creating it once."""
        if not hasattr(self, "_twin_axes"):
            self._twin_axes = {}
        if prop not in self._twin_axes:
            self._twin_axes[prop] = self.axp[prop][prop].twinx()
        return self._twin_axes[prop]

    def _hide_marginal_yaxis(self, prop):
        """Hide the y ticks of both the diagonal panel of `prop` and its twin."""
        self._twin_axis(prop).yaxis.set_visible(False)
        self.axp[prop][prop].tick_params(
            axis="y", left=False, right=False, labelleft=False
        )

    def _apply_diagonal_xlim(self, prop):
        """Apply the configured range of `prop` (if any) to its diagonal panel."""
        limits = self.dproperties[prop]["range"]
        if limits is not None:
            lo, hi = limits
            self.axp[prop][prop].set_xlim(lo, hi)

    def _plot_marginal_hist(self, prop, values, hist_opts):
        """Draw a 1D histogram of `values` on the twin axis of the diagonal panel."""
        self._twin_axis(prop).hist(values, **hist_opts)
        self._hide_marginal_yaxis(prop)

    def _plot_marginal_curve(self, prop, x, y, margs):
        """Draw the curve (x, y) on the twin axis of the diagonal panel."""
        curve_opts = dict(color="k", lw=1)
        curve_opts.update(margs)
        ax_marg = self._twin_axis(prop)
        ax_marg.plot(x, y, **curve_opts)
        # Start at 0 (like histograms) to avoid a vertical offset
        ax_marg.set_ylim(0, None)
        self._hide_marginal_yaxis(prop)

    def _range_or_data(self, prop, column):
        """Return the configured range of `prop`, or the min/max of `column`."""
        limits = self.dproperties[prop]["range"]
        if limits is not None:
            lo, hi = limits
            return lo, hi
        return column.min(), column.max()

    @staticmethod
    def _normalized_weights(mog):
        """Return the weights of `mog` scaled to sum to one (uniform if unusable)."""
        weights = getattr(mog, "weights", None)
        if weights is None:
            return np.full(mog.ngauss, 1.0 / mog.ngauss)
        w = np.asarray(weights, dtype=float)
        total = float(np.sum(w))
        if total <= 0:
            return np.ones_like(w) / max(1, w.size)
        return w / total

    @staticmethod
    def _mog_range(mog, k, nsig=4.0):
        """Return [lo, hi] for variable `k`: finite domain bounds, else means +/- nsig sigma."""
        bounds = getattr(mog, "_domain_bounds", None)
        if bounds is not None:
            lo, hi = bounds[k]
            if np.isfinite(lo) and np.isfinite(hi):
                return [float(lo), float(hi)]
        mu_min = float(np.min(mog.mus[:, k]))
        mu_max = float(np.max(mog.mus[:, k]))
        sig_max = float(np.max(mog.sigmas[:, k]))
        lo = mu_min - nsig * sig_max
        hi = mu_max + nsig * sig_max
        if not np.isfinite(lo) or not np.isfinite(hi) or lo == hi:
            lo, hi = mu_min - 1.0, mu_max + 1.0
        return [lo, hi]

    def _panel_range(self, mog, k, prop, ax, axis_idx=0):
        """Return the plotting range of variable `k` on `ax` (axis_idx 0 for x, 1 for y).

        Priority: range configured in dproperties, current limits of an axis
        that already holds data, then the range derived from the MoG.
        """
        limits = self.dproperties[prop]["range"]
        if limits is not None:
            return limits
        has_data = (
            ax.has_data()
            or len(ax.collections) > 0
            or len(ax.images) > 0
            or len(ax.lines) > 0
        )
        if has_data:
            return ax.get_xlim() if axis_idx == 0 else ax.get_ylim()
        return self._mog_range(mog, k)

    @staticmethod
    def _marginal_pdf(mog, index, x, weights):
        """Return sum_k weights[k] * N(x | mu_k[index], sigma_k[index])."""
        from scipy.stats import norm

        y = np.zeros_like(x)
        for k in range(mog.ngauss):
            y += weights[k] * norm.pdf(
                x, loc=mog.mus[k, index], scale=mog.sigmas[k, index]
            )
        return y

    @staticmethod
    def _contour_panel(ax, sub_mog, i, j, Xi, Yi, style_opts):
        """Draw the contours of `sub_mog` in the (i, j) plane on `ax`.

        The remaining variables are fixed at the weighted mean of the components.
        Returns the contour set, or None when the PDF has no positive finite maximum.
        """
        X_full = np.zeros((Xi.size, sub_mog.nvars))
        X_full[:] = np.average(sub_mog.mus, axis=0, weights=sub_mog.weights)
        X_full[:, i] = Xi.ravel()
        X_full[:, j] = Yi.ravel()
        Z = sub_mog.pdf(X_full).reshape(Xi.shape)

        current_opts = style_opts.copy()
        if isinstance(style_opts.get("levels"), int):
            # Adjust levels to avoid white frame
            zmax = Z.max()
            if not np.isfinite(zmax) or zmax <= 0:
                # Levels would not be increasing; nothing to draw
                return None
            current_opts["levels"] = np.linspace(
                0.1 * zmax, 0.95 * zmax, style_opts["levels"]
            )
        return ax.contour(Xi, Yi, Z, **current_opts)

    def _legend_above(self, ax, handles, labels):
        """Place a frameless one-row legend above `ax` and leave room for it."""
        ax.legend(
            handles,
            labels,
            loc="lower center",
            bbox_to_anchor=(0.5, 1.02),
            ncol=len(handles),
            frameon=False,
        )
        self.fig.subplots_adjust(top=0.88)

    def _finalize(self, relabel=False):
        """Re-apply labels (optional), ranges, ticks and layout; add the watermark once."""
        if relabel:
            self.set_labels()
        self.set_ranges()
        self.set_tick_params()
        self.tight_layout()
        if not getattr(self, "_watermark_added", False):
            if getattr(self, "_univariate", False):
                # Larger frac for single panel (match 2-panel size)
                multimin_watermark(self.axs[0][0], frac=0.5)
            else:
                multimin_watermark(self.axs[0][0], frac=1 / 4 * self.axs.shape[0])
            self._watermark_added = True

    def _clear_label_texts(self):
        """Remove the text labels added by a previous call to set_labels."""
        for artist in getattr(self, "_label_texts", []):
            try:
                artist.remove()
            except (ValueError, NotImplementedError):
                # Already detached from its axes
                pass
        self._label_texts = []

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------
    def tight_layout(self):
        """
        Adjust the layout of the figure.

        If constrained layout is not in use the spacing between subplots is
        set from the base figsize. ``fig.tight_layout()`` is then always
        called, silencing the warning about the layout having changed to tight.
        """
        if not self.constrained:
            self.fig.subplots_adjust(wspace=self.fw / 100.0, hspace=self.fw / 100.0)
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore", message="The figure layout has changed to tight"
            )
            self.fig.tight_layout()

    def set_tick_params(self, **args):
        """
        Set tick parameters.

        Ex. set_tick_params(labelsize=10)

        Parameters
        ----------
        **args : dict
            Same arguments as tick_params method.
        """
        opts = dict(axis="both", which="major", labelsize=0.8 * self.fs)
        opts.update(args)
        for row in range(self.M):
            for col in range(self.M):
                self.axs[row][col].tick_params(**opts)

    def set_ranges(self):
        """
        Set ranges in panels according to ranges defined in dproperties.

        Properties whose range is None are left untouched.
        """
        if getattr(self, "_univariate", False):
            prop = self.properties[0]
            if self.dproperties[prop]["range"] is not None:
                self.axs[0][0].set_xlim(self.dproperties[prop]["range"])
            return

        for i, propi in enumerate(self.properties):
            range_i = self.dproperties[propi]["range"]
            # Marginals: set x-range on diagonal
            if self.marginals and range_i is not None:
                self.axp[propi][propi].set_xlim(range_i)
            for propj in self.properties[i + 1:]:
                range_j = self.dproperties[propj]["range"]
                if range_i is not None:
                    self.axp[propi][propj].set_xlim(range_i)
                if range_j is not None:
                    self.axp[propi][propj].set_ylim(range_j)

    def reset_ranges(self):
        """
        Reset ranges to match the data limits.

        Overwrites the "range" entry of every property in dproperties with the
        min/max of the last plotted data. Does nothing if no data was plotted.
        """
        if self.data is None:
            return
        for i, prop in enumerate(self.properties):
            dmin, dmax = self.data[:, i].min(), self.data[:, i].max()
            # Force range to data limits (overriding default 4-sigma extents of PDF)
            self.dproperties[prop]["range"] = [dmin, dmax]
        self.set_ranges()

    def set_labels(self, **args):
        """
        Set labels parameters.

        Ex. set_labels(fontsize=12)

        Text labels added by a previous call are replaced, not stacked.

        Parameters
        ----------
        **args : dict
            Common arguments of set_xlabel, set_ylabel and text.
        """
        opts = dict(fontsize=self.fs)
        opts.update(args)

        # Univariate: only the x-label of the single panel applies
        if getattr(self, "_univariate", False):
            prop = self.properties[0]
            self.axs[0][0].set_xlabel(self.dproperties[prop]["label"], **opts)
            self.tight_layout()
            return

        # This method is called again by the plotting methods; drop old texts
        self._clear_label_texts()

        # x-labels (bottom row)
        xprops = self.properties if self.marginals else self.properties[:-1]
        for col, prop in enumerate(xprops):
            label = self.dproperties[prop]["label"]
            self.axs[self.M - 1][col].set_xlabel(label, **opts)

        # y-labels (first column). Without marginals properties[1:] map to rows
        # 0..M-1; with marginals they map to rows 1..M-1 (row 0 is a marginal).
        first_row = 1 if self.marginals else 0
        for row, prop in enumerate(self.properties[1:], first_row):
            label = self.dproperties[prop]["label"]
            self.axs[row][0].set_ylabel(label, rotation=90, labelpad=10, **opts)

        # Inner text labels, placed on the (hidden) panels above the diagonal
        for i in range(1, self.M):
            label = self.dproperties[self.properties[i]]["label"]
            panel = self.axs[i - 1][i]
            self._label_texts.append(
                panel.text(
                    0.5, 0.0, label, ha="center", transform=panel.transAxes, **opts
                )
            )
            # Vertical label (row) - hidden for marginals=True
            if not self.marginals:
                self._label_texts.append(
                    panel.text(
                        0.0,
                        0.5,
                        label,
                        rotation=270,
                        va="center",
                        transform=panel.transAxes,
                        **opts,
                    )
                )

        label = self.dproperties[self.properties[0]]["label"]
        if not self.single:
            panel = self.axs[0][1]
            self._label_texts.append(
                panel.text(
                    0.0,
                    1.0,
                    label,
                    rotation=0,
                    ha="left",
                    va="top",
                    transform=panel.transAxes,
                    **opts,
                )
            )

        if not self.marginals:
            label = self.dproperties[self.properties[-1]]["label"]
            panel = self.axs[-1][-1]
            self._label_texts.append(
                panel.text(
                    1.05,
                    0.5,
                    label,
                    rotation=270,
                    ha="left",
                    va="center",
                    transform=panel.transAxes,
                    **opts,
                )
            )

        self.tight_layout()

    def sample_hist(self, data, colorbar=False, **args):
        """
        Create a 2d-histograms of data on all panels of the MultiPlot.

        Ex. G.sample_hist(data, bins=100, cmap='viridis', margs=dict(color='blue'))

        Parameters
        ----------
        data : numpy.ndarray
            Data to be histogramed (n=len(data)), numpy array (nxN).
        colorbar : bool, optional
            Include a colorbar? (default False).
        **args : dict
            All arguments of hist2d method. Can include 'margs' dict with arguments
            for marginal plots. If margs=None, marginals are not drawn.

        Returns
        -------
        hist : list
            List of histogram instances (empty in the univariate case).

        Examples
        --------
        >>> properties = {
        ...     'Q': {'label': r"$Q$", 'range': None},
        ...     'E': {'label': r"$C$", 'range': None},
        ...     'I': {'label': r"$I$", 'range': None},
        ... }
        >>> G = mm.MultiPlot(properties, figsize=3)
        >>> hargs = dict(bins=100, cmap='viridis', margs=dict(color='blue'))
        >>> hist = G.sample_hist(udata, **hargs)
        """
        self.data = data

        # Extract margs dict for marginal plots
        margs = args.pop("margs", {})

        opts = dict(args)
        # Default zorder for histogram (background)
        opts.setdefault("zorder", -100)

        # Univariate: 1D histogram
        if getattr(self, "_univariate", False):
            ax = self.axs[0][0]
            hargs_1d = {k: v for k, v in opts.items() if k != "cmap"}
            hargs_1d.setdefault("bins", min(50, max(10, len(data) // 20)))
            hargs_1d.setdefault("density", True)
            hargs_1d.setdefault("label", "sample histogram")
            ax.hist(data[:, 0], **hargs_1d)
            ax.yaxis.set_label_position("left")
            ax.set_ylabel("density")
            # Legend: if no twin yet, add legend for histogram only
            handles, labels = ax.get_legend_handles_labels()
            if handles and getattr(self, "_ax_twin", None) is None:
                self._legend_above(ax, handles, labels)
            self._finalize()
            return []

        hist = []
        for i, propi in enumerate(self.properties):
            xmin, xmax = self._range_or_data(propi, data[:, i])

            for j, propj in enumerate(self.properties[i + 1:], i + 1):
                ymin, ymax = self._range_or_data(propj, data[:, j])
                panel = self.axp[propi][propj]

                opts["range"] = [[xmin, xmax], [ymin, ymax]]
                h, xe, ye, im = panel.hist2d(data[:, i], data[:, j], **opts)
                hist.append(im)

                if colorbar:
                    # Create color bar on top of the panel
                    from mpl_toolkits.axes_grid1 import make_axes_locatable

                    divider = make_axes_locatable(panel)
                    cax = divider.append_axes("top", size="9%", pad=0.1)
                    self.fig.add_axes(cax)
                    cticks = np.linspace(h.min(), h.max(), 10)[2:-1]
                    self.fig.colorbar(
                        im,
                        ax=panel,
                        cax=cax,
                        orientation="horizontal",
                        ticks=cticks,
                    )
                    cax.xaxis.set_tick_params(
                        labelsize=0.5 * self.fs, direction="in", pad=-0.8 * self.fs
                    )
                    # Factor out a common power of ten from the tick labels
                    xt = cax.get_xticks()
                    m, e = Util.mantisa_exp(xt.mean())
                    # Float base: an integer base fails for negative NumPy exponents
                    scale = 10.0 ** e
                    cax.set_xticklabels(["%.1f" % (x / scale) for x in xt])
                    cax.text(
                        0,
                        0.5,
                        r"$\times 10^{%d}$" % e,
                        ha="left",
                        va="center",
                        transform=cax.transAxes,
                        fontsize=6,
                        color="w",
                    )

            # Marginals for sample_hist
            if self.marginals and margs is not None:
                self._apply_diagonal_xlim(propi)
                hargs_marg = dict(
                    bins=opts.get("bins", 20),
                    histtype="step",
                    density=True,
                    color="k",
                )
                hargs_marg.update(margs)
                # Set range for 1D hist if dproperties has it
                if self.dproperties[propi]["range"] is not None:
                    hargs_marg["range"] = self.dproperties[propi]["range"]
                self._plot_marginal_hist(propi, data[:, i], hargs_marg)

        self._finalize(relabel=True)
        return hist

    def sample_scatter(self, data, nbins=20, **args):
        """
        Scatter plot on all panels of the MultiPlot.

        Ex. G.sample_scatter(data, s=0.2, color='r', margs=dict(color='b'))

        Parameters
        ----------
        data : numpy.ndarray
            Data to be plotted (n=len(data)), numpy array (nxN).
        nbins : int, optional
            Number of bins for marginal histograms (default 20). A 'bins' entry
            in margs takes precedence.
        **args : dict
            All arguments of scatter method. Can include 'margs' dict with arguments
            for marginal plots. If margs=None, marginals are not drawn.

        Returns
        -------
        scatter : list
            List of scatter instances.

        Examples
        --------
        >>> # With marginals (blue histogram)
        >>> sargs = dict(s=0.2, edgecolor='None', color='r', margs=dict(color='b'))
        >>> hist = G.sample_scatter(udata, **sargs)

        >>> # Without marginals
        >>> sargs = dict(s=0.2, edgecolor='None', color='r', margs=None)
        >>> hist = G.sample_scatter(udata, **sargs)
        """
        self.data = data

        # Extract margs dict for marginal plots
        margs = args.pop("margs", {})

        # Univariate: scatter on a twin y-axis so data range is independent of PDF/density
        if getattr(self, "_univariate", False):
            ax = self.axs[0][0]
            # Reuse the twin axis of a previous call instead of stacking new ones
            ax_twin = getattr(self, "_ax_twin", None)
            if ax_twin is None:
                ax_twin = ax.twinx()
            x = data[:, 0]
            # Random vertical positions only spread the points; they carry no data
            y_jitter = np.random.uniform(0, 1, size=len(x))
            sargs_1d = dict(args)
            sargs_1d.setdefault("label", "sample")
            # Default zorder for scatter (foreground)
            sargs_1d.setdefault("zorder", 100)
            sc = ax_twin.scatter(x, y_jitter, **sargs_1d)
            ax_twin.set_ylim(0, 1)
            ax_twin.set_yticks([])
            prop_name = self.properties[0]
            ax_twin.set_ylabel(
                "sample " + self.dproperties[prop_name]["label"], fontsize=self.fs
            )
            self._ax_twin = ax_twin  # store for reference
            # Legend: combine primary ax (e.g. histogram) and twin (sample scatter)
            handles, labels = ax.get_legend_handles_labels()
            h2, l2 = ax_twin.get_legend_handles_labels()
            handles, labels = handles + h2, labels + l2
            if handles:
                self._legend_above(ax, handles, labels)
            self._finalize()
            return [sc]

        # Default zorder for scatter (foreground)
        args.setdefault("zorder", 100)

        scatter = []
        for i, propi in enumerate(self.properties):
            if self.marginals:
                self._apply_diagonal_xlim(propi)
                # Only draw marginals if margs is not None
                if margs is not None:
                    # nbins is a default so that margs may override it
                    marg_opts = dict(
                        bins=nbins, histtype="step", density=True, color="k", lw=1
                    )
                    marg_opts.update(margs)
                    self._plot_marginal_hist(propi, data[:, i], marg_opts)

            for j, propj in enumerate(self.properties[i + 1:], i + 1):
                scatter.append(
                    self.axp[propi][propj].scatter(data[:, i], data[:, j], **args)
                )

        self._finalize(relabel=True)
        return scatter

    def mog_pdf(self, mog, grid_size=200, **args):
        """
        Plot the PDF of a MoG on all panels of the MultiPlot.

        Ex. G.mog_pdf(mog, color='k', lw=2, margs=dict(color='blue'))

        In each 2D panel the PDF is evaluated on a grid of the two panel
        variables with the remaining variables fixed at the weighted mean of
        the component means.

        Parameters
        ----------
        mog : MixtureOfGaussians
            MoG object to plot.
        grid_size : int, optional
            Number of points for the grid (default 200).
        **args : dict
            Arguments for the plot function (e.g. color, linewidth in the
            univariate case; cmap, zorder, colorbar in the multivariate case).
            Can include 'margs' dict with arguments for marginal plots.
            If margs=None, marginals are not drawn.
        """
        # Extract margs dict for marginal plots
        margs = args.pop("margs", {})

        opts = dict(color="k", lw=2)
        opts.update(args)
        # Default zorder for PDF (background): low is back, high is front
        opts.setdefault("zorder", -100)
        npts = int(grid_size)

        if getattr(self, "_univariate", False):
            # Filter out arguments not supported by ax.plot
            plot_opts = {
                k: v for k, v in opts.items() if k not in ("cmap", "colorbar")
            }
            plot_opts.setdefault("label", "PDF")
            ax = self.axs[0][0]

            limits = self.dproperties[self.properties[0]]["range"]
            if limits is not None:
                xmin, xmax = limits
            else:
                xmin, xmax = self._mog_range(mog, 0)

            x = np.linspace(xmin, xmax, npts)
            y = mog.pdf(x.reshape(-1, 1))
            ax.plot(x, y, **plot_opts)

            # Grow the y-limits if needed so the curve is fully visible
            if y.size > 0:
                current_ylim = ax.get_ylim()
                new_ymax = max(current_ylim[1], float(np.max(y)) * 1.05)
                ax.set_ylim(0, new_ymax)

            self._finalize()
            return

        # Multivariate case
        w = self._normalized_weights(mog)
        base_point = np.average(mog.mus, axis=0, weights=w)

        first_im = None
        cmap = args.get("cmap", "Spectral_r")
        zorder = opts["zorder"]

        for i, propi in enumerate(self.properties):
            for j, propj in enumerate(self.properties[i + 1:], i + 1):
                ax = self.axp[propi][propj]
                x_min, x_max = self._panel_range(mog, i, propi, ax, 0)
                y_min, y_max = self._panel_range(mog, j, propj, ax, 1)

                xs = np.linspace(float(x_min), float(x_max), npts)
                ys = np.linspace(float(y_min), float(y_max), npts)
                # "xy" indexing gives the (row=y, col=x) layout pcolormesh expects
                xx, yy = np.meshgrid(xs, ys, indexing="xy")

                X_full = np.tile(base_point, (xx.size, 1))
                X_full[:, i] = xx.ravel()
                X_full[:, j] = yy.ravel()
                zz = np.asarray(mog.pdf(X_full), dtype=float).reshape(xx.shape)

                im = ax.pcolormesh(
                    xx, yy, zz, shading="auto", cmap=cmap, zorder=zorder
                )
                if first_im is None:
                    first_im = (ax, im)

            # Marginals for mog_pdf
            if self.marginals and margs is not None:
                ax = self.axp[propi][propi]
                x_min, x_max = self._panel_range(mog, i, propi, ax, 0)
                x = np.linspace(float(x_min), float(x_max), npts)
                # Normalized weights keep the curve comparable to density histograms
                y = self._marginal_pdf(mog, i, x, w)
                self._plot_marginal_curve(propi, x, y, margs)
                ax.set_xlim(x_min, x_max)

        # Colorbar (on the first panel) if requested through args
        if args.get("colorbar", False) and first_im is not None:
            ax0, im0 = first_im
            from mpl_toolkits.axes_grid1 import make_axes_locatable

            divider = make_axes_locatable(ax0)
            cax = divider.append_axes("top", size="9%", pad=0.1)
            self.fig.add_axes(cax)
            vmin = float(np.nanmin(im0.get_array()))
            vmax = float(np.nanmax(im0.get_array()))
            if np.isfinite(vmin) and np.isfinite(vmax) and vmin != vmax:
                cticks = np.linspace(vmin, vmax, 8)[1:-1]
            else:
                cticks = None
            self.fig.colorbar(
                im0, ax=ax0, cax=cax, orientation="horizontal", ticks=cticks
            )
            cax.xaxis.set_tick_params(
                labelsize=0.5 * self.fs, direction="in", pad=-0.8 * self.fs
            )

        self._finalize()

    def mog_contour(self, mog, grid_size=200, **args):
        """
        Plot the contours of a MoG on all panels of the MultiPlot.

        Ex. G.mog_contour(mog, levels=5, cmap='Reds', margs=dict(color='blue'))

        Nothing is drawn for a univariate MultiPlot.

        Parameters
        ----------
        mog : MixtureOfGaussians
            MoG object to plot.
        grid_size : int, optional
            Number of points for the grid (default 200).
        **args : dict
            Arguments for contour function. Defaults: levels=5, cmap='Reds'.
            An integer 'levels' is converted to that many levels between 10%
            and 95% of the maximum of the PDF in each panel.
            Additional keys handled here (not passed to contour):

            * margs: dict with arguments for marginal plots; if None,
              marginals are not drawn.
            * decomp: if True, plot each Gaussian component separately with
              its own color (default False).
            * legend: if True and decomp=True, add a legend describing the
              components (default True).
        """
        # Extract margs dict for marginal plots
        margs = args.pop("margs", {})

        opts = dict(levels=5, cmap="Reds", legend=True)
        opts.update(args)
        # Control flags of this method; they must not reach ax.contour
        decomp = opts.pop("decomp", False)
        show_legend = opts.pop("legend", True)

        if getattr(self, "_univariate", False):
            # Contours don't make sense in 1D
            return

        npts = int(grid_size)

        if decomp:
            # Needed to build single-component mixtures
            from .mog import MixtureOfGaussians
            from matplotlib.lines import Line2D

        # Legend entries, collected when decomp=True
        legend_handles = []
        legend_labels = []

        for i, propi in enumerate(self.properties):
            limits_i = self.dproperties[propi]["range"]
            if limits_i is not None:
                xmin, xmax = limits_i
            elif i + 1 < self.N:
                xmin, xmax = self.axp[propi][self.properties[i + 1]].get_xlim()
            else:
                # Last property has no 2D panel as x variable; value is unused
                xmin, xmax = (0, 1)

            for j, propj in enumerate(self.properties[i + 1:], i + 1):
                ax = self.axp[propi][propj]
                limits_j = self.dproperties[propj]["range"]
                if limits_j is not None:
                    ymin, ymax = limits_j
                else:
                    ymin, ymax = ax.get_ylim()

                # Evaluation grid
                xi = np.linspace(xmin, xmax, npts)
                yi = np.linspace(ymin, ymax, npts)
                Xi, Yi = np.meshgrid(xi, yi)

                if not decomp:
                    self._contour_panel(ax, mog, i, j, Xi, Yi, opts)
                    continue

                # Decomposition: plot each component
                for k in range(mog.ngauss):
                    mu_k = mog.mus[k : k + 1]
                    sigma_k = mog.Sigmas[k : k + 1]

                    # Correlations are implicit in Sigmas, so they are not passed
                    comp_mog = MixtureOfGaussians(
                        mus=mu_k,
                        Sigmas=sigma_k,
                        weights=[1.0],
                        domain=getattr(mog, "domain", None),
                        normalize_weights=False,
                    )

                    # Style for component: cycle colors C0, C1, ...
                    color = f"C{k % 10}"
                    comp_opts = opts.copy()
                    comp_opts["colors"] = color
                    comp_opts.pop("cmap", None)  # Remove cmap to use colors

                    self._contour_panel(ax, comp_mog, i, j, Xi, Yi, comp_opts)

                    # Collect legend info (only for the first 2D panel)
                    if len(legend_handles) < mog.ngauss:
                        legend_handles.append(Line2D([0], [0], color=color, lw=2))

                        mu_i = mu_k[0, i]
                        mu_j = mu_k[0, j]
                        # Clip at 0 to be safe against tiny negative variances
                        sig_i = np.sqrt(max(0, sigma_k[0, i, i]))
                        sig_j = np.sqrt(max(0, sigma_k[0, j, j]))
                        # Correlation from the covariance matrix
                        if sig_i > 0 and sig_j > 0:
                            rho_val = sigma_k[0, i, j] / (sig_i * sig_j)
                        else:
                            rho_val = 0.0

                        label = rf"Comp {k + 1}: $\mu$=({mu_i:.2f}, {mu_j:.2f}), $\sigma$=({sig_i:.2f}, {sig_j:.2f}), $\rho$={rho_val:.2f}"
                        legend_labels.append(label)

            # Marginals for mog_contour
            if self.marginals and margs is not None:
                ax = self.axp[propi][propi]
                if limits_i is not None:
                    x_min, x_max = limits_i
                else:
                    x_min, x_max = self._mog_range(mog, i)

                x = np.linspace(float(x_min), float(x_max), npts)
                y = self._marginal_pdf(mog, i, x, self._normalized_weights(mog))
                self._plot_marginal_curve(propi, x, y, margs)
                ax.set_xlim(x_min, x_max)

        if decomp and legend_handles and show_legend:
            # Legend anchored to the top-left panel
            self.axs[0][0].legend(
                legend_handles,
                legend_labels,
                loc="upper right",
                frameon=True,
                fontsize=6,
            )

        self._finalize()