"""
This utility makes a combined heatmap and scatter plot -- a heatmap
for data in high-density regions, a scatter plot to show outliers in
low-density regions.
"""
import numpy as np
from matplotlib import rcParams
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from scipy.stats import binned_statistic_2d
try:
from astropy.units import Quantity
from astropy.units import dimensionless_unscaled as du
have_astropy = True
except:
have_astropy = False
def _contrast_color(cmap):
    """
    Returns the RGB color (as an array of shape (1, 3)) whose channels
    are each 0 or 1, chosen to be as far as possible from the color
    that the color map cmap assigns to its minimum value.
    """
    # pyplot.get_cmap is used rather than matplotlib.cm.get_cmap, since
    # the latter was removed in Matplotlib 3.9
    c0 = np.asarray(plt.get_cmap(cmap)(0)[:3])
    return np.atleast_2d(np.where(c0 < 0.5, 1.0, 0.0))


def _strip_heat_limit(val, scale, ux, uy, uw):
    """
    Converts a heatmap limit (vmin or vmax) to a plain number; only
    called when astropy is available. The units expected for val depend
    on scale:
       'total' ==> same units as w
       'max' or 'frac' ==> dimensionless
       'density' ==> units of w / (x*y)
       'normed' ==> units of 1 / (x*y)
    A value of None is passed through unchanged.
    """
    if val is None:
        return None
    if scale == 'total':
        return (val / uw).to(du).value
    if scale == 'density':
        return (val * ux * uy / uw).to(du).value
    if scale == 'normed':
        return (val * ux * uy).to(du).value
    if scale in ('max', 'frac') and isinstance(val, Quantity):
        return val.to(du).value
    # Plain numbers for 'max' / 'frac'; unknown scale values are
    # rejected later, when the histogram is scaled
    return val


def heatscatter(x, y,
                aspect = 'equal',
                ax = None,
                cmap = rcParams["image.cmap"],
                log = True,
                nxbin = 50,
                nybin = 50,
                scale = 'density',
                scat_alpha = 0.5,
                scat_color = None,
                scat_size = 5,
                vmax = None,
                vmin = None,
                vmin_contour = False,
                vmin_contour_alpha = 1.0,
                vmin_contour_color = None,
                vmin_frac = 0.999,
                w = None,
                xlim = None,
                ylim = None):
    """
    Makes a 2d heatmap plus scatter plot, where data in high-density
    regions are shown using a heatmap, and data in low-density regions
    are shown using scatter points.

    Parameters:
       x : array or astropy.Quantity, shape (N,)
          x coordinates of points
       y : array or astropy.Quantity, shape (N,)
          y coordinates of points
       aspect : 'auto' | 'equal' | num
          aspect ratio of the heatmap image; passed as the aspect
          keyword of imshow
       ax : matplotlib.Axes
          axes into which to draw plot; defaults to the currently
          active axes
       cmap : str | Colormap
          color map to use for heatmap
       log : bool
          if True, use a logarithmic (base 10) scale for the heatmap
       nxbin : int
          number of bins in x direction
       nybin : int
          number of bins in y direction
       scale : 'total' | 'max' | 'frac' | 'density' | 'normed'
          method used to scale bin values; 'total' means that the
          quantity shown is the unscaled sum of the weights of all
          points in that bin; 'max' means that bins are normalized so
          the value in the largest bin is unity; 'frac' means that the
          value shown is the fraction of the weight in each bin;
          'density' means that the bins are normalized so that the
          quantity shown is the sum of weights divided by the size /
          area of the bin; 'normed' means that bins are normalized so
          that the area integral is unity
       scat_alpha : float
          alpha value for scatter plot points; setting this to zero
          suppresses scatter points entirely
       scat_color : string
          color to use for scatter points; if None, a color chosen
          for contrast with the lowest color in cmap is used
       scat_size : float
          size of marker for scatter points; equivalent to s parameter
          in matplotlib.pyplot.scatter
       vmax : float or astropy.Quantity
          maximum value for heatmap; equivalent to the vmax keyword
          for matplotlib.pyplot.imshow
       vmin : float or astropy.Quantity
          minimum value for heatmap; points in regions where the
          heatmap value is < vmin are shown as scatter points; if left
          as None, a value will be computed from vmin_frac
       vmin_contour : bool
          if True, draw a contour at the boundary between the heat
          mapped region and the scatter point region
       vmin_contour_alpha : float
          alpha value of the threshold contour
       vmin_contour_color : string
          color of the threshold contour; if None, a color chosen for
          contrast with the lowest color in cmap is used
       vmin_frac : float
          fraction of the sum of the weights at which to switch from
          showing the data as a heatmap to showing individual scatter
          points; used to compute vmin
       w : array or astropy.Quantity, shape (N,)
          weights of individual points; if None, all points have
          weight 1
       xlim, ylim : arraylike, shape (2,)
          plotting limits in x and y directions; default to the
          minimum and maximum of the data; may be given in descending
          order to produce an inverted axis

    Returns:
       img : matplotlib.image.AxesImage
          a handle to the heatmap, created by imshow

    Raises:
       ValueError, if scale is not one of the recognized values, or
       if log is True and no bin has a positive value
    """

    # Set the weights if not specified
    if w is None:
        w = np.ones(np.shape(x))

    # Set the scatter point and contour colors if not specified; we
    # set these to obtain maximum contrast from the color used for the
    # color table minimum value, since they will be plotted against
    # a background of this color.
    if scat_color is None:
        scat_color = _contrast_color(cmap)
    if vmin_contour_color is None:
        vmin_contour_color = _contrast_color(cmap)

    # If the data we have been passed are astropy quantities, handle
    # any unit conversions here, then change everything here to flat
    # numpy arrays; many of the functions we use below don't play
    # nicely with astropy units. Without astropy, or for inputs that
    # are not Quantities, the inputs are used as they are.
    x_, y_, w_ = x, y, w
    xlim_, ylim_ = xlim, ylim
    vmin_, vmax_ = vmin, vmax
    if have_astropy:
        ux = uy = uw = du

        # Strip x
        if isinstance(x, Quantity):
            ux = x.unit
            x_ = (x/ux).to(du).value
            if xlim is not None:
                xlim_ = np.array([(xlim[0]/ux).to(du).value,
                                  (xlim[1]/ux).to(du).value])

        # Strip y
        if isinstance(y, Quantity):
            uy = y.unit
            y_ = (y/uy).to(du).value
            if ylim is not None:
                ylim_ = np.array([(ylim[0]/uy).to(du).value,
                                  (ylim[1]/uy).to(du).value])

        # Strip w
        if isinstance(w, Quantity):
            uw = w.unit
            w_ = (w/uw).to(du).value

        # Strip vmin and vmax; the units expected for these depend on
        # the choice of the scale parameter
        vmin_ = _strip_heat_limit(vmin, scale, ux, uy, uw)
        vmax_ = _strip_heat_limit(vmax, scale, ux, uy, uw)

    # Define the grid on which to compute the heatmap
    if xlim is None:
        xlim_ = [np.amin(x_), np.amax(x_)]
    if ylim is None:
        ylim_ = [np.amin(y_), np.amax(y_)]
    xgrd = np.linspace(xlim_[0], xlim_[1], nxbin+1)
    ygrd = np.linspace(ylim_[0], ylim_[1], nybin+1)
    xgrd_h = 0.5*(xgrd[1:]+xgrd[:-1])
    ygrd_h = 0.5*(ygrd[1:]+ygrd[:-1])
    xx, yy = np.meshgrid(xgrd_h, ygrd_h)

    # Get 2D histogram; note that we have to handle the case of
    # inverted axis limits with care, because binned_statistic_2d
    # doesn't natively support them
    xlim1 = np.sort(xlim_)
    ylim1 = np.sort(ylim_)
    binsum, xe, ye, binidx \
        = binned_statistic_2d(x_, y_, w_,
                              statistic='sum',
                              bins=[nxbin, nybin],
                              range = [[float(xlim1[0]), float(xlim1[1])],
                                       [float(ylim1[0]), float(ylim1[1])]],
                              expand_binnumbers=True)
    if xlim_[0] > xlim_[1]:
        # Mirror the histogram, its edges, and the per-point bin
        # numbers so that they follow the inverted x axis
        binsum = binsum[::-1, :]
        xe = xe[::-1]
        binidx[0,:] = nxbin+1 - binidx[0,:]
    if ylim_[0] > ylim_[1]:
        binsum = binsum[:, ::-1]
        ye = ye[::-1]
        binidx[1,:] = nybin+1 - binidx[1,:]

    # Set z
    if scale == 'total':
        z = binsum
    elif scale == 'max':
        z = binsum / np.amax(binsum)
    elif scale == 'frac':
        z = binsum / np.sum(w_)
    elif scale == 'density':
        z = binsum / np.abs((xe[1]-xe[0])*(ye[1]-ye[0]))
    elif scale == 'normed':
        z = binsum / np.abs((xe[1]-xe[0])*(ye[1]-ye[0])) / np.sum(w_)
    else:
        raise ValueError("unknown scale parameter: "+str(scale))

    # Compute vmin if not specified: sort the bin values, and pick the
    # first one at which the normalized cumulative sum exceeds
    # 1 - vmin_frac
    if vmin is None:
        zsort = np.sort(z, axis=None)
        csum = np.cumsum(zsort)
        csum = csum / csum[-1]
        vmin_ = zsort[np.argmax(csum > 1.0-vmin_frac)]

    # Take log if requested; non-positive bins are set to the smallest
    # positive bin value so that the log is finite everywhere
    if log:
        if np.amax(z) <= 0.0:
            raise ValueError("cannot use log scale: no positive z values")
        z[z <= 0] = np.amin(z[z > 0])
        z = np.log10(z)
        vmin_ = np.log10(vmin_)
        if vmax_ is not None:
            vmax_ = np.log10(vmax_)

    # Get indices of individual points to show; bin numbers of 0 or
    # nbin+1 mark points outside the plot range, which are never shown
    flag = np.logical_and.reduce((binidx[0,:] > 0,
                                  binidx[1,:] > 0,
                                  binidx[0,:] <= binsum.shape[0],
                                  binidx[1,:] <= binsum.shape[1]))
    scatteridx = np.zeros(len(x_), dtype=bool)
    scatteridx[flag] \
        = z[binidx[0,flag]-1, binidx[1,flag]-1] < vmin_

    # Get axes
    if ax is None:
        ax = plt.gca()

    # Plot contour at vmin if requested
    if vmin_contour:
        ax.contour(xx, yy, np.transpose(z),
                   levels = [vmin_],
                   colors = vmin_contour_color,
                   alpha = vmin_contour_alpha,
                   linestyles = '-')

    # Plot scatter points outside contour; the unit-stripped
    # coordinates are used so that the points are consistent with the
    # heatmap extent and the axis limits set below
    if scat_alpha > 0:
        ax.scatter(x_[scatteridx],
                   y_[scatteridx],
                   color=scat_color,
                   s=scat_size,
                   alpha=scat_alpha,
                   edgecolor='none')

    # Plot density map
    img = ax.imshow(np.transpose(z),
                    origin='lower', aspect=aspect,
                    vmin=vmin_, vmax=vmax_, cmap=cmap,
                    extent=[xlim_[0], xlim_[1], ylim_[0], ylim_[1]])

    # Set plot range
    ax.set_xlim(xlim_)
    ax.set_ylim(ylim_)

    # Return the handle to the heatmap image
    return img