¿Cómo muestro una sola imagen en PyTorch?

4 minutos de lectura

Avatar de usuario de Tom Hale
tom hale

¿Cómo muestro un PyTorch? Tensor de forma (3, 224, 224) que representa una imagen RGB de 224×224? Usando plt.imshow(image) da el error:

TypeError: dimensiones no válidas para datos de imagen

Avatar de usuario de Tom Hale
tom hale

Dado un Tensor representando la imagen, utilice .permute() para poner los canales como la última dimensión:

plt.imshow(  tensor_image.permute(1, 2, 0)  )

Nota: permute no copia ni asigna memoriay from_numpy() tampoco

  • Vaya, gracias… Esto funcionó para mí… Estaba tratando de hacer tensor_image.numpy().reshape([224,224,3]) y visualízalo usando cv2.imshow() Pero no estaba obteniendo la imagen real… ¿qué está mal aquí?

    – Devashish Prasad

    04/06/2020 a las 14:50

  • @DevashishPrasad El problema es que reshape([224,224,3]) no hace lo mismo que permute(1, 2, 0) hace. El permute La función es similar a transponer una matriz, donde las filas se convierten en columnas y las columnas en filas. El reshape función hace algo totalmente ajeno que no sé cómo describir de manera concisa. En breve, reshape es la función incorrecta.

    – Tanner Swett

    1 de marzo de 2021 a las 17:53

  • cual es la forma de tensor_image ?

    –Charlie Parker

    16 de noviembre de 2022 a las 23:02

  • Una alternativa posiblemente más legible es plt.imshow(torch.einsum('cwh->whc', tensor_image))

    – rusheb

    28 de diciembre de 2022 a las 17:27

avatar de usuario de trsvchn
trsvchn

Como se puede ver matplotlib funciona bien incluso sin conversión a numpy formación. Pero PyTorch Tensors (“Tensores de imagen”) son canales primero, así que para usarlos con matplotlib necesitas remodelarlo:

Código:

from scipy.misc import face
import matplotlib.pyplot as plt
import torch

np_image = face()
print(type(np_image), np_image.shape)
tensor_image = torch.from_numpy(np_image)
print(type(tensor_image), tensor_image.shape)
# reshape to channel first:
tensor_image = tensor_image.view(tensor_image.shape[2], tensor_image.shape[0], tensor_image.shape[1])
print(type(tensor_image), tensor_image.shape)

# If you try to plot image with shape (C, H, W)
# You will get TypeError:
# plt.imshow(tensor_image)

# So we need to reshape it to (H, W, C):
tensor_image = tensor_image.view(tensor_image.shape[1], tensor_image.shape[2], tensor_image.shape[0])
print(type(tensor_image), tensor_image.shape)

plt.imshow(tensor_image)
plt.show()

Producción:

<class 'numpy.ndarray'> (768, 1024, 3)
<class 'torch.Tensor'> torch.Size([768, 1024, 3])
<class 'torch.Tensor'> torch.Size([3, 768, 1024])
<class 'torch.Tensor'> torch.Size([768, 1024, 3])

  • Hmm, no funciona para mí, vea la pregunta actualizada con la forma del tensor.

    – Tom Hale

    14 de diciembre de 2018 a las 4:40

Avatar de usuario de Tom Hale
tom hale

Dado que la imagen se carga como se describe y se almacena en la variable image:

plt.imshow(transforms.ToPILImage()(image), interpolation="bicubic")
#transforms.ToPILImage()(image).show() # Alternatively

o como Soumith sugirió:

def show(img):
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')

  • import torchvision.transforms # tal vez agregue la importación al código

    –Fuji

    10 de mayo de 2021 a las 4:46

avatar de usuario de iacob
iacob

Los módulos PyTorch que procesan datos de imagen esperan tensores en el formato profundidad × alto × ancho.1

Mientras que PILLow y Matplotlib esperan matrices de imágenes en el formato alto × ancho × profundidad.2

Puede convertir fácilmente tensores a/de este formato con una transformación TorchVision:

from torchvision import transforms.functional as F

F.to_pil_image(image_tensor)

O permutando directamente los ejes:

image_tensor.permute(1,2,0)

  1. Los módulos de PyTorch que tratan con datos de imagen requieren que los tensores se distribuyan como profundidad × alto × ancho : canales, altura y ancho, respectivamente.

  2. Tenga en cuenta cómo tenemos que usar permute para cambiar el orden de los ejes de profundidad × alto × ancho a alto × ancho × fondo para que coincida con lo que espera Matplotlib.

Avatar de usuario de Tom Hale
tom hale

Un ejemplo completo dado un nombre de ruta de imagen img_path:

from PIL import Image
image = Image.open(img_path)
plt.imshow(transforms.ToPILImage()(transforms.ToTensor()(image)), interpolation="bicubic")

Tenga en cuenta que transforms.* devolver un claserazón por la cual el corchete funky.

Avatar de usuario de TheExorcist
El exorcista

La antorcha tiene forma de canal, altura, ancho, necesita convertirla en altura, ancho, canal para permutar.

plt.imshow(white_torch.permute(1, 2, 0))

O directamente si quieres

import torch
import torchvision
from torchvision.io import read_image
import torchvision.transforms as T

!wget 'https://images.unsplash.com/photo-1553284965-83fd3e82fa5a?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxleHBsb3JlLWZlZWR8NHx8fGVufDB8fHx8&w=1000&q=80'  -O white_horse.jpg

white_torch = torchvision.io.read_image('white_horse.jpg')

T.ToPILImage()(white_torch)

ingrese la descripción de la imagen aquí

avatar de usuario de aravinda_gn
aravinda_gn

Usa show_image de fastai

from fastai.vision.all import show_image

ingrese la descripción de la imagen aquí

ingrese la descripción de la imagen aquí

¿Ha sido útil esta solución?