# Source code for csky.plotting

# plotting.py

"""Plotting of skymaps, test statistic distributions, PDFs, etc.

TODO: This module is a complete mess.  Starting point for things to clean up:

* abandon colormaps.py; everyone should be using modern matplotlib by now.
* get rid of Plot and SkyPlot
* make SkyPlotter more flexible and generally tidy its interface
* probably get rid of plot_energy_pdf and plot_gauss_2d_angres_param
* add docstrings to everything that remains

"""

from __future__ import print_function

import copy
from cycler import cycler
import healpy

hp = healpy

try:
    from itertools import izip

    zip = izip
except ImportError:
    pass

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

pi = np.pi
import os
import sys

from . import utils

try:
    import histlite as hl
except:
    from icecube import histlite as hl


def _ensure_dir(dirname):
    """Make sure ``dirname`` exists and is a directory.

    Missing parent directories are created as needed (``os.makedirs``).  A
    concurrent creation of the same directory is tolerated.

    Args:
        dirname: Path of the directory to create.

    Returns:
        The ``dirname`` argument, unchanged.

    Raises:
        OSError: If the directory cannot be created, including when
            ``dirname`` already exists but is not a directory.
    """
    # ``os.errno`` is not available on Python 3.7+; use the errno module.
    import errno

    if not os.path.isdir(dirname):
        try:
            os.makedirs(dirname)
        except OSError as e:
            # EEXIST is acceptable only if the directory now exists (e.g. a
            # concurrent creator); an existing *file* must still be an error.
            if e.errno != errno.EEXIST or not os.path.isdir(dirname):
                raise
    return dirname


# "icecube" colormap for skymaps.  Each channel lists (x, value_below,
# value_above) anchor points in the format LinearSegmentedColormap expects;
# the map runs from white at the low end, through blue, to yellow at the top.
skymap_cmap = mpl.colors.LinearSegmentedColormap(
    "icecube",
    {
        "red": (
            (0.0, 0.0, 1.0),
            (0.05, 1.0, 1.0),
            (0.5, 0.0416, 0.0416),
            (0.6, 0.0416, 0.0416),
            (0.7, 1.0, 1.0),
            (1.0, 1.0, 1.0),
        ),
        "green": (
            (0.0, 0.0, 1.0),
            (0.05, 1.0, 1.0),
            (0.5, 0.0416, 0.0416),
            (0.6, 0.0, 0.0),
            (0.8, 0.5, 0.5),
            (1.0, 1.0, 1.0),
        ),
        "blue": (
            (0.0, 0.0, 1.0),
            (0.05, 1.0, 1.0),
            (0.4, 1.0, 1.0),
            (0.6, 1.0, 1.0),
            (0.7, 0.2, 0.2),
            (1.0, 0.0, 0.0),
        ),
    },
    N=256,
)


def _latex_preamble(sans=False):
    """Return the LaTeX preamble used by :func:`mpl_tex_rc` as one string.

    Recent matplotlib releases only accept a string for
    ``rcParams["text.latex.preamble"]`` (lists are rejected), so the lines are
    joined with newlines.  None of the lines contain commas, so older
    matplotlib versions, which split the value on commas, still see a single
    entry.

    Args:
        sans: If true, append the ``sansmath`` setup that switches math fonts
            to sans-serif.

    Returns:
        str: Newline-separated preamble lines.
    """
    lines = [
        r"\usepackage{amsmath}",
        r"\usepackage{amssymb}",
        r"\usepackage{amsthm}",
        r"\usepackage{bm}",
    ]
    if sans:
        # Previously these were missing commas and were silently concatenated
        # into one entry; they are now listed explicitly, one per line.
        lines += [
            r"\usepackage{sansmath}",
            r"\SetSymbolFont{operators}   {sans}{OT1}{cmss} {m}{n}",
            r"\SetSymbolFont{letters}     {sans}{OML}{cmbrm}{m}{it}",
            r"\SetSymbolFont{symbols}     {sans}{OMS}{cmbrs}{m}{n}",
            r"\SetSymbolFont{largesymbols}{sans}{OMX}{iwona}{m}{n}",
            r"\sansmath",
        ]
    return "\n".join(lines)


def mpl_tex_rc(sans=False):
    """Configure matplotlib to render all text with LaTeX.

    Args:
        sans: If true, use a sans-serif text font plus sans-serif math (via the
            ``sansmath`` package); otherwise use a serif font.
    """
    import matplotlib.pyplot as plt
    import matplotlib as mpl

    plt.rc("text", usetex=True)
    plt.rc("font", family="sans-serif" if sans else "serif")
    mpl.rcParams["text.latex.preamble"] = _latex_preamble(sans=sans)


def saving(plot_dir, basename, fig=None, exts="png pdf", **kw):
    """Save a figure as ``{plot_dir}/{basename}.{ext}`` for each extension.

    Args:
        plot_dir: Output directory; passed to ``utils.ensure_dir`` first.
        basename: File name without extension.
        fig: Figure to save; the current pyplot figure when None.
        exts: Whitespace-separated extensions, one output file per entry.
        **kw: Forwarded unchanged to ``savefig``.
    """
    utils.ensure_dir(plot_dir)
    print("-> {}/{}".format(plot_dir, basename))
    sys.stdout.flush()
    target = plt.gcf() if fig is None else fig
    stem = "{}/{}".format(plot_dir, basename)
    for extension in exts.split():
        target.savefig("{}.{}".format(stem, extension), **kw)


class Plot(object):
    """Base class for plots that own a single (cleared) matplotlib figure."""

    def __init__(self, fig=None):
        """Attach to a figure and clear it.

        Args:
            fig: A figure object, an integer pyplot figure number, or None to
                create a new pyplot figure.
        """
        if fig is None:
            fig = plt.figure()
        elif isinstance(fig, int):
            fig = plt.figure(fig)
        self.fig = fig
        self.fig.clf()
        # Figures not managed by pyplot may lack ``number``; only a missing
        # attribute is tolerated (previously a bare ``except`` hid any error).
        self.fignum = getattr(self.fig, "number", None)

    def save(self, dir, basename, exts="png pdf"):
        """Save the figure as ``{dir}/{basename}.{ext}`` for each extension.

        Args:
            dir: Output directory; created if missing.
            basename: File name without extension.
            exts: Whitespace-separated extensions, one file per entry.
        """
        _ensure_dir(dir)
        for ext in exts.split():
            filename = "{}/{}.{}".format(dir, basename, ext)
            print("-> {} ...".format(filename))
            self.fig.savefig(filename)

    def close(self):
        """Close the figure through pyplot."""
        plt.close(self.fig)
class SkyPlot(Plot):
    """Skymap plotter built on ``healpy.mollview``.

    The map is drawn on construction.  :meth:`graticule`, :meth:`colorbar`,
    :meth:`show_gp` and :meth:`show_gc` add decorations afterwards.
    """

    def __init__(self, fig=None, m=None, rot=None, coord="C", *a, **kw):
        """Draw ``m`` with ``healpy.mollview`` on this plot's figure.

        Args:
            fig: Passed to :meth:`Plot.__init__` (figure, number, or None).
            m: HEALPix map to draw.
            rot: Rotation passed to healpy.  Defaults to 180 when the last
                character of ``coord`` is "C", otherwise 0.
            coord: healpy coordinate spec (e.g. "C", "G", "GC").
            *a, **kw: Forwarded to ``healpy.mollview``.  ``unit``, ``title``,
                ``format`` and ``cmap`` receive defaults if absent.
        """
        Plot.__init__(self, fig)
        self.m = m
        self.coord = coord
        if rot is None:
            rot = 180 if coord[-1] == "C" else 0
        self.rot = rot
        self.a = a
        self.kw = kw
        self.kw.setdefault("unit", "")
        self.kw.setdefault("title", "")
        self.kw.setdefault("format", "%.1f")
        self.kw.setdefault("cmap", "afmhot")
        cmap = self.kw["cmap"]
        if isinstance(cmap, str):
            cmap = plt.get_cmap(cmap)
        # Work on a copy so set_under never mutates a colormap owned by the
        # caller or by matplotlib's global registry.
        cmap = copy.copy(cmap)
        cmap.set_under("w")
        self.kw["cmap"] = cmap
        healpy.mollview(m, fig=self.fignum, rot=self.rot, coord=self.coord, *a, **kw)
        self.mollax = self.fig.get_axes()[0]

    def graticule(self, wpad=None, *a, **kw):
        """Draw a graticule with longitude edge labels and latitude labels.

        Args:
            wpad: If given, minimum horizontal padding (figure fraction) for
                the map axes; the axes are narrowed to respect it.
            *a, **kw: Forwarded to ``healpy.graticule``; ``alpha`` defaults to
                0.5.

        Raises:
            ValueError: If the last character of ``coord`` is neither "C" nor
                "G" (no labels are defined for other frames).
        """
        # Make this plot's figure current before drawing with healpy.
        plt.figure(self.fignum)
        # The alpha default belongs to this graticule call, not to the stored
        # mollview kwargs in self.kw.
        kw.setdefault("alpha", 0.5)
        healpy.graticule(*a, **kw)
        usetex = mpl.rcParams["text.usetex"]
        frame = self.coord[-1]
        if frame == "C":
            if usetex:
                labels = r"\textbf{0h} \textbf{24h}"
            else:
                labels = "0h 24h"
        elif frame == "G":
            if usetex:
                labels = r"\textbf{--180}$^\circ$ \textbf{+180}$^\circ$"
            else:
                labels = r"$-180^\circ$ $+180^\circ$"
        else:
            raise ValueError(
                "graticule labels not implemented for coord {!r}".format(self.coord)
            )
        labels = labels.split()
        # NOTE(review): the ``withdash``/``dash*`` Text options were deprecated
        # in matplotlib 3.1 and later removed; these calls may fail on modern
        # matplotlib -- TODO confirm against the installed version.
        lons = self.rot - 180, self.rot + 179.9
        healpy.projtext(
            lons[0],
            0,
            labels[0],
            lonlat=True,
            ha="left",
            va="center",
            withdash=True,
            dashpad=2,
            dashlength=0.01,
            dashdirection=1,
        )
        healpy.projtext(
            lons[1],
            0,
            labels[1],
            lonlat=True,
            ha="right",
            va="center",
            withdash=True,
            dashpad=8,
            dashlength=0.01,
        )
        # Latitude labels, all placed at longitude rot + 179.9.
        lon = 179.9 + self.rot
        for lat in [-60, -30, 30, 60]:
            healpy.projtext(
                lon,
                1.1 * lat,
                format(lat, "+.0f"),
                lonlat=True,
                ha="right",
                va="center",
                withdash=True,
                dashpad=abs(lat),
                dashlength=0.01,
            )
        if wpad is not None:
            bounds = list(self.mollax.get_position().bounds)
            if bounds[0] < wpad:
                bounds[0] = wpad
                bounds[2] = 1 - 2 * wpad
            self.mollax.set_position(bounds)
        plt.draw()

    def colorbar(self, unit=""):
        """Draw a horizontal colorbar for the map image and return it.

        If the map was drawn with ``norm="log"``, ticks are placed at the
        integer powers of ten inside the map's NaN-ignoring value range.

        Args:
            unit: Colorbar label; omitted when empty.
        """
        log = self.kw.get("norm", "") == "log"
        cb_kw = dict(orientation="horizontal", fraction=0.1, shrink=0.5, pad=0.05)
        if log:
            vmin, vmax = np.nanmin(self.m), np.nanmax(self.m)
            vmin = 10 ** np.ceil(np.log10(vmin))
            vmax = 10 ** np.floor(np.log10(vmax))
            cb_kw.update(
                format=mpl.ticker.LogFormatterMathtext(),
                ticks=10 ** np.arange(np.log10(vmin), np.log10(vmax) + 1),
            )
        cb = self.fig.colorbar(self.mollax.get_images()[0], ax=self.mollax, **cb_kw)
        if unit:
            cb.set_label(unit)
        return cb

    def show_gp(self, **kw):
        """Draw the galactic plane as a line; ``kw`` goes to ``healpy.projplot``."""
        lon = np.linspace(0, 360, 1000)
        lat = np.zeros_like(lon)
        kw["lonlat"] = True
        kw["coord"] = "G" + self.coord[-1]
        healpy.projplot(lon, lat, **kw)

    def show_gc(self, **kw):
        """Mark the galactic center; ``kw`` goes to ``healpy.projscatter``."""
        kw["lonlat"] = True
        kw["coord"] = "G" + self.coord[-1]
        healpy.projscatter(0, 0, **kw)
class SkyPlotter(object):
    """Skymap plotter using matplotlib directly for projections.

    HEALPix maps are resampled onto a regular longitude/latitude grid and drawn
    with ``pcolormesh`` on caller-supplied axes.  ``projection`` is stored as
    an attribute; this class does not create axes itself.
    """

    def __init__(self, coord="C", projection="aitoff", pc_kw=None, cb_kw=None):
        """
        Args:
            coord: Output frame, "C" or "G" (healpy coordinate letters).
            projection: Projection name, stored as ``self.projection``.
            pc_kw: Default ``pcolormesh`` keyword arguments for
                :meth:`plot_map`.
            cb_kw: Default ``Figure.colorbar`` keyword arguments for
                :meth:`plot_map`; ``orientation``, ``shrink`` and ``pad`` get
                defaults if absent.

        Raises:
            NotImplementedError: If ``coord`` is not "C" or "G".
        """
        self.coord = coord
        self.projection = projection
        # Copy so that neither a shared default nor the caller's dict is
        # mutated by the setdefault calls below.
        self.pc_kw = dict(pc_kw) if pc_kw is not None else {}
        self.cb_kw = dict(cb_kw) if cb_kw is not None else {}
        if self.coord not in ["C", "G"]:
            raise NotImplementedError('coord "{}" not yet supported'.format(self.coord))
        self.cb_kw.setdefault("orientation", "horizontal")
        self.cb_kw.setdefault("shrink", 0.5)
        self.cb_kw.setdefault("pad", 0.08)

    def thetaphi_to_mpl(self, theta, phi):
        """Convert colatitude/longitude (radians) to matplotlib x/y.

        Returns:
            tuple: ``(x, y)`` arrays with ``x = pi - phi`` wrapped so that
            ``x <= pi`` and ``y = pi/2 - theta``.
        """
        theta, phi = np.atleast_1d(theta), np.atleast_1d(phi)
        x = pi - phi
        x[x > pi] -= 2 * pi
        y = pi / 2 - theta
        return x, y

    def plot_gp(self, ax, color=".5", s=0.3, strip=0.0, **kw):
        """Scatter-plot the galactic plane on ``ax``.

        Args:
            ax: Matplotlib axes.
            color, s: Marker color and size for ``ax.scatter``.
            strip: If positive, also draw the curves offset by ``strip``
                radians in galactic latitude on either side of the plane.
            **kw: Forwarded to ``ax.scatter``.

        Raises:
            ValueError: If ``self.coord`` is not "C" or "G".
        """
        lon_gal = np.linspace(-pi, pi, 3000)
        theta_b = pi / 2 * np.ones_like(lon_gal)
        if self.coord == "C":
            r = healpy.Rotator(coord="GC")
            theta, phi = r(theta_b, lon_gal)
            if strip > 0:
                theta_up, phi_up = r(theta_b + strip, lon_gal)
                theta_down, phi_down = r(theta_b - strip, lon_gal)
        elif self.coord == "G":
            theta, phi = theta_b, lon_gal
            if strip > 0:
                # Previously phi_up/phi_down were never set here (NameError).
                theta_up, phi_up = theta + strip, lon_gal
                theta_down, phi_down = theta - strip, lon_gal
        else:
            raise ValueError("bad coord {}".format(self.coord))
        x, y = self.thetaphi_to_mpl(theta, phi)
        ax.scatter(x, y, color=color, marker=".", s=s, **kw)
        if strip > 0:
            x_up, y_up = self.thetaphi_to_mpl(theta_up, phi_up)
            x_down, y_down = self.thetaphi_to_mpl(theta_down, phi_down)
            ax.scatter(x_up, y_up, color=color, marker=".", s=s, **kw)
            ax.scatter(x_down, y_down, color=color, marker=".", s=s, **kw)

    def plot_sgp(self, ax, color=".5", s=0.3, **kw):
        """Scatter-plot the supergalactic plane on ``ax`` (needs ``icecube.astro``)."""
        from icecube import astro

        ras = np.linspace(0.0, 361.0, 1500) * np.pi / 180.0
        decls_0 = np.zeros(len(ras))
        if self.coord == "C":
            lon, lat = astro.supergal_to_equa(ras, decls_0)
        elif self.coord == "G":
            lon, lat = astro.supergal_to_gal(ras, decls_0)
        else:
            raise ValueError("bad coord {}".format(self.coord))
        x, y = self.thetaphi_to_mpl(np.pi / 2 - lat, lon)
        ax.scatter(x, y, color=color, marker=".", s=s, **kw)

    def plot_gc(self, ax, color=".5", s=15, **kw):
        """Mark the galactic center on ``ax``; ``kw`` goes to ``ax.scatter``."""
        lon_gal = 0
        theta_b = pi / 2
        if self.coord == "C":
            r = healpy.Rotator(coord="GC")
            theta, phi = r(theta_b, lon_gal)
        elif self.coord == "G":
            theta, phi = theta_b, lon_gal
        else:
            raise ValueError("bad coord {}".format(self.coord))
        x, y = self.thetaphi_to_mpl(theta, phi)
        ax.scatter(x, y, color=color, s=s, **kw)

    def rotate_map(self, m, **kw):
        """Return ``m`` resampled through ``healpy.rotator.Rotator(**kw)``.

        Each pixel's direction is rotated and the input map is interpolated
        there with ``healpy.get_interp_val``.
        """
        nside = healpy.get_nside(m)
        r = healpy.rotator.Rotator(**kw)
        theta, phi = healpy.pix2ang(nside, np.arange(len(m)))
        theta_rot, phi_rot = r(theta, phi)
        return healpy.get_interp_val(m, theta_rot, phi_rot)

    def map_to_latlonz(self, m, N=1000):
        """Sample HEALPix map ``m`` on a regular grid for ``pcolormesh``.

        Returns:
            tuple: ``(lat, lon, Z)`` where ``lat`` has ``N`` values in
            [-pi/2, pi/2], ``lon`` has ``2N`` values in [-pi, pi] and ``Z`` has
            shape ``(N, 2N)``.
        """
        x = np.linspace(pi, -pi, 2 * N)
        y = np.linspace(pi, 0, N)
        X, Y = np.meshgrid(x, y)
        r = healpy.rotator.Rotator(rot=(-180, 0, 0))
        YY, XX = r(Y.ravel(), X.ravel())
        pix = healpy.ang2pix(healpy.get_nside(m), YY, XX)
        Z = np.reshape(m[pix], X.shape)
        lon = x[::-1]
        lat = pi / 2 - y
        return lat, lon, Z

    def plot_map(
        self,
        ax,
        m,
        unit="",
        n_ticks=5,
        pc_kw=None,
        cb_kw=None,
        ticks=None,
        log=False,
        titleticks=False,
        nohr=False,
    ):
        """Draw HEALPix map ``m`` on ``ax`` with a colorbar.

        Args:
            ax: Matplotlib axes to draw on.
            m: HEALPix map, or None to draw nothing.
            unit: Colorbar label; omitted when empty.
            n_ticks: Number of evenly spaced colorbar ticks (linear scale only).
            pc_kw: Overrides for ``self.pc_kw`` passed to ``pcolormesh``.
            cb_kw: Overrides for ``self.cb_kw`` passed to ``Figure.colorbar``.
            ticks: Explicit colorbar ticks (linear scale only).
            log: Use a ``LogNorm`` built from ``vmin``/``vmax`` in the
                pcolormesh kwargs.
            titleticks: Hide latitude tick labels above 70 degrees.
            nohr: Skip the "0h"/"24h" edge annotations.

        Returns:
            ``(pcolormesh, colorbar)``, or None when ``m`` is None.
        """
        if m is None:
            return
        lat, lon, Z = self.map_to_latlonz(m)
        pc_opts = copy.deepcopy(self.pc_kw)
        pc_opts.update(pc_kw or {})
        if log:
            vmin = pc_opts.pop("vmin", None)
            vmax = pc_opts.pop("vmax", None)
            pc_opts["norm"] = mpl.colors.LogNorm(vmin, vmax)
        pc = ax.pcolormesh(lon, lat, Z, **pc_opts)

        def yfmt(n, *a):
            # Latitude labels in degrees; the equator label is left blank.
            n = np.degrees(n)
            if titleticks and n > 70:
                return ""
            fmt = r"${:+.0f}^\circ$"
            return fmt.format(n) if n else ""

        ax.xaxis.set_ticks(np.radians(np.arange(-180, 180, 30)))
        ax.xaxis.set_major_formatter(mpl.ticker.FuncFormatter(lambda *a: ""))
        ax.yaxis.set_major_formatter(mpl.ticker.FuncFormatter(yfmt))
        if not nohr:
            ann_opts = dict(
                xycoords="axes fraction", textcoords="offset pixels", verticalalignment="center"
            )
            ax.annotate(r"0h", xy=(1, 0.5), xytext=(10, 0), horizontalalignment="left", **ann_opts)
            ax.annotate(r"24h", xy=(0, 0.5), xytext=(-10, 0), horizontalalignment="right", **ann_opts)
        cb_opts = copy.deepcopy(self.cb_kw)
        cb_opts.update(cb_kw or {})
        cb = ax.figure.colorbar(pc, ax=ax, **cb_opts)
        if unit:
            cb.set_label(unit)
        if not log:
            if ticks is None:
                vmin, vmax = pc.get_clim()
                ticks = np.linspace(vmin, vmax, n_ticks)
            cb.set_ticks(ticks)
        return pc, cb
def ud_grade_interp(m, nside):
    """Resample HEALPix map ``m`` to resolution ``nside`` by interpolation.

    Each output pixel is evaluated with ``hp.get_interp_val`` at the (ra, dec)
    positions returned by ``SkyScanner.get_healpix_grid(nside)``.
    """
    from .trial import SkyScanner

    new_ra, new_dec = SkyScanner.get_healpix_grid(nside)
    return hp.get_interp_val(m, pi / 2 - new_dec, new_ra)


def plot_energy_pdf(ax, ana, gamma, bins=400, range=None, **kw):
    """Plot ``ana.energy_pdf_ratio_model`` evaluated at ``gamma`` on ``ax``.

    Args:
        ax: Axes to draw on.
        ana: Object providing ``energy_pdf_ratio_model``.
        gamma: Passed as ``gamma=`` when evaluating the model (presumably the
            spectral index -- confirm against the model).
        bins: Number of bins per dimension.
        range: Histogram range; defaults to the model's ``range``.
        **kw: Forwarded to ``hl.plot2d``.

    Returns:
        The result of ``hl.plot2d``.
    """
    pdf = ana.energy_pdf_ratio_model

    def f(sd, lE):
        return pdf(utils.Events(sindec=sd, log10energy=lE))(gamma=gamma)[0]

    if range is None:
        range = pdf.range
    h = hl.hist_from_eval(f, vectorize=False, bins=(bins, bins), range=range)
    return hl.plot2d(ax, h, **kw)


def plot_gauss_2d_angres_param(sigma_param, bins=400, range=None, figscale=3, **kw):
    """Plot a 2D Gaussian angular-resolution parametrization.

    Rows show the declination width, the right-ascension width (both converted
    from radians to degrees) and the normalization.  Columns show the base
    histogram, the smoothed histogram (only if ``hdec`` is not ``hdec_base``)
    and the fitted function (only if ``sdec`` is not None).

    Args:
        sigma_param: Object providing ``hdec_base``, ``hdec``, ``hra_base``,
            ``hra``, ``hnorm``, ``sdec``, ``sra``, ``snorm`` and ``range``.
        bins: Bins used when evaluating the fitted functions (previously
            ignored; the default matches the old hard-coded 400).
        range: Evaluation range for the fitted functions; defaults to
            ``sigma_param.range`` (previously always used).
        figscale: Size in inches of each panel.
        **kw: Forwarded to ``hl.plot2d``.  For the normalization row, ``vmin``
            is dropped and ``vmax`` is forced to 1.

    Returns:
        tuple: ``(fig, axs, out)``.  ``axs`` is a ``(3, ncol)`` array of axes
        and ``out`` holds the ``plot2d`` results (None for unused panels).
    """
    sp = sigma_param
    if range is None:
        range = sp.range
    smoothed_bins = int(sp.hdec_base is not sp.hdec)
    fitted = int(sp.sdec is not None)
    ncol = 1 + smoothed_bins + fitted
    nrow = 3
    # squeeze=False keeps axs 2D even when ncol == 1; otherwise axs[i, 0] fails.
    fig, axs = plt.subplots(
        nrow, ncol, figsize=(figscale * ncol, figscale * nrow), squeeze=False
    )
    out = np.empty_like(axs)
    j_smooth = smoothed_bins
    j_fit = smoothed_bins + fitted

    def eval_fit(func):
        # Evaluate a fitted parametrization on the requested grid.
        return hl.hist_from_eval(func, vectorize=False, bins=bins, range=range)

    # dec
    i = 0
    out[i, 0] = hl.plot2d(axs[i, 0], sp.hdec_base * 180 / pi, **kw)
    if smoothed_bins:
        out[i, j_smooth] = hl.plot2d(axs[i, j_smooth], sp.hdec * 180 / pi, **kw)
    if fitted:

        def fdec(sd, lE):
            return sp.sdec(sd, lE)

        out[i, j_fit] = hl.plot2d(axs[i, j_fit], eval_fit(fdec) * 180 / pi, **kw)

    # ra
    i = 1
    out[i, 0] = hl.plot2d(axs[i, 0], sp.hra_base * 180 / pi, **kw)
    if smoothed_bins:
        out[i, j_smooth] = hl.plot2d(axs[i, j_smooth], sp.hra * 180 / pi, **kw)
    if fitted:

        def fra(sd, lE):
            return sp.sra(sd, lE)

        out[i, j_fit] = hl.plot2d(axs[i, j_fit], eval_fit(fra) * 180 / pi, **kw)

    # norm: no base histogram, so column 0 stays empty when smoothing exists
    nkw = copy.deepcopy(kw)
    nkw.pop("vmin", 0)
    nkw["vmax"] = 1
    i = 2
    out[i, j_smooth] = hl.plot2d(axs[i, j_smooth], sp.hnorm, **nkw)
    if fitted:

        def fnorm(sd, lE):
            return sp.snorm(sd, lE)

        out[i, j_fit] = hl.plot2d(axs[i, j_fit], eval_fit(fnorm), **nkw)

    for ax in np.ravel(axs):
        ax.set_xlabel(r"$\sin(\delta_\mathsf{reco})$")
        ax.set_ylabel(r"$\log_{10}(E_\mathsf{reco})$")
    for o in out[0]:
        o["colorbar"].set_label(r"estimated $\sigma_\delta~[^\circ]$")
    for o in out[1]:
        o["colorbar"].set_label(r"estimated $\sigma_\alpha~[^\circ]$")
    for o, ax in zip(out[2], axs[2]):
        if o is None:
            ax.set_visible(False)
        else:
            o["colorbar"].set_label(r"normalization")
    return fig, axs, out


soft_colors = ["#004466", "#d06050", "#2aca80", "#dd9388", "#caca68"]
friendly_colors = ["#184b68", "#cf4d30", "#62badb", "#e797b4", "#eec9b4", "#f7dede"]
mpl_colors_orig = np.array(mpl.rcParamsDefault["axes.prop_cycle"].by_key()["color"])
mpl_colors = mpl_colors_orig[[0, 3, 2, 1, 4, 5, 6, 7]]


def mrichman_mpl(tex=True, sans=True, colors=mpl_colors):
    """Apply a personal matplotlib style (color cycle, grid, lines, dpi).

    Args:
        tex: If true, also enable LaTeX text rendering via :func:`mpl_tex_rc`.
        sans: Passed to :func:`mpl_tex_rc`.
        colors: Colors for the axes color cycle (previously ignored in favor
            of ``mpl_colors``, which remains the default).
    """
    # Feature-detect rather than compare version strings lexicographically.
    if "axes.prop_cycle" in mpl.rcParams:
        mpl.rcParams["axes.prop_cycle"] = cycler("color", colors)
    else:
        mpl.rcParams["axes.color_cycle"] = colors
    mpl.rcParams["grid.linestyle"] = ":"
    mpl.rcParams["lines.linewidth"] = 2
    mpl.rcParams["figure.facecolor"] = mpl.rcParams["savefig.facecolor"] = "w"
    mpl.rcParams["legend.framealpha"] = 1
    # my laptop display is quad-HD
    mpl.rcParams["figure.dpi"] = 120
    mpl.rcParams["savefig.dpi"] = 150
    if tex:
        mpl_tex_rc(sans=sans)