Encabezados de fila y columna en las subparcelas de matplotlib

5 minutos de lectura

¿Cuál es la mejor práctica para agregar una fila y un encabezado de columna a una cuadrícula de subparcelas generadas en un bucle en matplotlib? Puedo pensar en un par, pero no particularmente bueno:

  1. Para columnas, con un contador para su ciclo puede usar set_title() solo para la primera fila. Para filas esto no funciona. tendrías que dibujar text fuera de las parcelas.
  2. Agrega una fila adicional de subparcelas en la parte superior y una columna adicional de subparcelas a la izquierda, y dibuja texto en el medio de esa subparcela.

¿Puede sugerir una mejor alternativa?

ingrese la descripción de la imagen aquí

avatar de usuario
jose kington

Hay varias maneras de hacer esto. La manera fácil es explotar las etiquetas y y los títulos de la trama y luego usar fig.tight_layout() para dejar sitio a las etiquetas. Alternativamente, puede colocar texto adicional en la ubicación correcta con annotate y luego hacer espacio para ello semi-manualmente.


Si no tiene etiquetas y en sus ejes, es fácil explotar el título y la etiqueta y de la primera fila y columna de ejes.

import matplotlib.pyplot as plt

cols = ['Column {}'.format(col) for col in range(1, 4)]
rows = ['Row {}'.format(row) for row in ['A', 'B', 'C', 'D']]

fig, axes = plt.subplots(nrows=4, ncols=3, figsize=(12, 8))

for ax, col in zip(axes[0], cols):
    ax.set_title(col)

for ax, row in zip(axes[:,0], rows):
    ax.set_ylabel(row, rotation=0, size="large")

fig.tight_layout()
plt.show()

ingrese la descripción de la imagen aquí


Si tiene etiquetas y, o si prefiere un poco más de flexibilidad, puede usar annotate para colocar las etiquetas. Esto es más complicado, pero le permite tener títulos de gráficos individuales, ylabels, etc. además de las etiquetas de fila y columna.

import matplotlib.pyplot as plt
from matplotlib.transforms import offset_copy


cols = ['Column {}'.format(col) for col in range(1, 4)]
rows = ['Row {}'.format(row) for row in ['A', 'B', 'C', 'D']]

fig, axes = plt.subplots(nrows=4, ncols=3, figsize=(12, 8))
plt.setp(axes.flat, xlabel="X-label", ylabel="Y-label")

pad = 5 # in points

for ax, col in zip(axes[0], cols):
    ax.annotate(col, xy=(0.5, 1), xytext=(0, pad),
                xycoords="axes fraction", textcoords="offset points",
                size="large", ha="center", va="baseline")

for ax, row in zip(axes[:,0], rows):
    ax.annotate(row, xy=(0, 0.5), xytext=(-ax.yaxis.labelpad - pad, 0),
                xycoords=ax.yaxis.label, textcoords="offset points",
                size="large", ha="right", va="center")

fig.tight_layout()
# tight_layout doesn't take these labels into account. We'll need 
# to make some room. These numbers are are manually tweaked. 
# You could automatically calculate them, but it's a pain.
fig.subplots_adjust(left=0.15, top=0.95)

plt.show()

ingrese la descripción de la imagen aquí

  • Los métodos is_first_col(), is_last_col(), is_first_row() y is_last_row() también puede ser conveniente en este contexto.

    – Gerrit

    19 de junio de 2017 a las 10:22

  • También como nota, anotar matplotlib tiene la opción de rotación, por lo que si desea rotar su etiqueta 90 grados, simplemente agregue el argumento rotation = 90

    – mathishard.pero nos encanta

    26 mayo 2020 a las 21:41

  • Disculpe, ¿qué hace el número 8 en figsize = (12, 8)?

    – Amirhossein

    14 de marzo de 2021 a las 8:13

  • @Amirhosein El 8 es la altura de la figura en pulgadas y el 12 es el ancho, por lo que (12,8) asegura una relación de aspecto de 4×3. Debido a que hay subparcelas de 3×3, esto da como resultado un diseño sensato de subparcelas de paisaje.

    – espuma78

    25 de abril de 2021 a las 18:49

La respuesta anterior funciona. Simplemente no es que en la segunda versión de la respuesta, tienes:

for ax, row in zip(axes[:,0], rows):
    ax.annotate(col, xy=(0, 0.5), xytext=(-ax.yaxis.labelpad-pad,0),
                xycoords=ax.yaxis.label, textcoords="offset points",
                size="large", ha="right", va="center")

en vez de:

for ax, row in zip(axes[:,0], rows):
    ax.annotate(row,xy=(0, 0.5), xytext=(-ax.yaxis.labelpad-pad,0),                    
                xycoords=ax.yaxis.label, textcoords="offset points",
                size="large", ha="right", va="center")

avatar de usuario
paime

Según la respuesta de Joe Kington, puse una función que se puede reutilizar en una base de código:

Acepta como argumentos:

  • fig : La figura que contiene los ejes a trabajar
  • row_headers, col_headers : una secuencia de cadenas para ser encabezados
  • row_pad, col_pad: int valor para ajustar el relleno
  • rotate_row_headers: si rotar 90° los encabezados de fila
  • **text_kwargs: remitido a ax.annotate(...)

Función aquí, ejemplos a continuación:

import numpy as np

def add_headers(
    fig,
    *,
    row_headers=None,
    col_headers=None,
    row_pad=1,
    col_pad=5,
    rotate_row_headers=True,
    **text_kwargs
):
    # Based on https://stackoverflow.com/a/25814386

    axes = fig.get_axes()

    for ax in axes:
        sbs = ax.get_subplotspec()

        # Putting headers on cols
        if (col_headers is not None) and sbs.is_first_row():
            ax.annotate(
                col_headers[sbs.colspan.start],
                xy=(0.5, 1),
                xytext=(0, col_pad),
                xycoords="axes fraction",
                textcoords="offset points",
                ha="center",
                va="baseline",
                **text_kwargs,
            )

        # Putting headers on rows
        if (row_headers is not None) and sbs.is_first_col():
            ax.annotate(
                row_headers[sbs.rowspan.start],
                xy=(0, 0.5),
                xytext=(-ax.yaxis.labelpad - row_pad, 0),
                xycoords=ax.yaxis.label,
                textcoords="offset points",
                ha="right",
                va="center",
                rotation=rotate_row_headers * 90,
                **text_kwargs,
            )

Aquí hay un ejemplo de cómo usarlo en una cuadrícula estándar (ningún eje abarca varias filas/columnas):

import random
import matplotlib.pyplot as plt

mosaic = [
    ["A0", "A1", "A2"],
    ["B0", "B1", "B2"],
]
row_headers = ["Row A", "Row B"]
col_headers = ["Col 0", "Col 1", "Col 2"]

subplots_kwargs = dict(sharex=True, sharey=True, figsize=(10, 6))
fig, axes = plt.subplot_mosaic(mosaic, **subplots_kwargs)

font_kwargs = dict(fontfamily="monospace", fontweight="bold", fontsize="large")
add_headers(fig, col_headers=col_headers, row_headers=row_headers, **font_kwargs)

plt.show()

resultado: rejilla regular

Si algunos ejes abarcan varias filas/columnas, se vuelve un poco menos sencillo asignar encabezados de filas/columnas correctamente. No logré resolverlo desde dentro de la función, pero teniendo cuidado con el dado row_headers y col_headers argumentos es suficiente para que funcione fácilmente:

mosaic = [
    ["A0", "A1", "A1", "A2"],
    ["A0", "A1", "A1", "A2"],
    ["B0", "B1", "B1", "B2"],
]

row_headers = ["A", "A", "B"]  # or
row_headers = ["A", None, "B"]  # or
row_headers = {0: "A", 2: "B"}

col_headers = ["0", "1", "1", "2"]  # or
col_headers = ["0", "1", None, "2"]  # or
col_headers = {0: "0", 1: "1", 3: "2"}

fig, axes = plt.subplot_mosaic(mosaic, **subplots_kwargs)
add_headers(fig, col_headers=col_headers, row_headers=row_headers, **font_kwargs)
plt.show()

resultado: rejilla no regular

¿Ha sido útil esta solución?