¿Cómo guardo un modelo entrenado en PyTorch?

11 minutos de lectura

avatar de usuario
Wasi Ahmed

¿Cómo guardo un modelo entrenado en PyTorch? he leido eso:

  1. torch.save()/torch.load() es para guardar/cargar un objeto serializable.
  2. model.state_dict()/model.load_state_dict() es para guardar/cargar el estado del modelo.

  • Creo que es porque torch.save() también guarda todas las variables intermedias, como las salidas intermedias para el uso de propagación hacia atrás. Pero solo necesita guardar los parámetros del modelo, como peso/sesgo, etc. A veces, el primero puede ser mucho más grande que el segundo.

    –Dawei Yang

    18 de marzo de 2017 a las 17:36

  • probé torch.save(model, f) y torch.save(model.state_dict(), f). Los archivos guardados tienen el mismo tamaño. Ahora estoy confundido. Además, descubrí que usar pickle para guardar model.state_dict() es extremadamente lento. Creo que la mejor manera es usar torch.save(model.state_dict(), f) ya que tú manejas la creación del modelo, y torch maneja la carga de los pesos del modelo, eliminando así posibles problemas. Referencia: discutir.pytorch.org/t/saving-torch-models/838/4

    –Dawei Yang

    29 de marzo de 2017 a las 2:01


  • Parece que PyTorch ha abordado esto un poco más explícitamente en su sección de tutoriales—hay mucha información buena que no aparece en las respuestas aquí, incluido guardar más de un modelo a la vez y modelos de arranque en caliente.

    – whlteXbread

    24/03/2019 a las 21:55

  • que hay de malo en usar pickle?

    –Charlie Parker

    13 de julio de 2020 a las 18:23

  • @CharlieParker torch.save se basa en pickle. Lo siguiente es del tutorial vinculado anteriormente: “[torch.save] guardará todo el módulo usando el módulo pickle de Python. La desventaja de este enfoque es que los datos serializados están vinculados a las clases específicas y la estructura de directorio exacta utilizada cuando se guarda el modelo. La razón de esto es porque pickle no guarda la clase del modelo en sí. Más bien, guarda una ruta al archivo que contiene la clase, que se usa durante el tiempo de carga. Debido a esto, su código puede romperse de varias maneras cuando se usa en otros proyectos o después de refactorizaciones”.

    –David Miller

    14 de julio de 2020 a las 9:56


avatar de usuario
no loo

Fundar esta página en su repositorio de github:

Enfoque recomendado para guardar un modelo

Hay dos enfoques principales para serializar y restaurar un modelo.

El primero (recomendado) guarda y carga solo los parámetros del modelo:

torch.save(the_model.state_dict(), PATH)

Entonces despúes:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

El segundo guarda y carga todo el modelo:

torch.save(the_model, PATH)

Entonces despúes:

the_model = torch.load(PATH)

Sin embargo, en este caso, los datos serializados están vinculados a las clases específicas y la estructura de directorio exacta utilizada, por lo que pueden romperse de varias maneras cuando se usan en otros proyectos o después de algunas refactorizaciones serias.


Ver también: Guardar y cargar el modelo sección de los tutoriales oficiales de PyTorch.

  • Según @smth discutir.pytorch.org/t/guardar-y-cargar-un-modelo-en-pytorch/… el modelo se recarga para entrenar el modelo de forma predeterminada. por lo tanto, debe llamar manualmente a the_model.eval() después de la carga, si lo está cargando para inferencia, sin reanudar el entrenamiento.

    – WillZ

    15/07/2018 a las 22:30


  • el segundo método da el error stackoverflow.com/questions/53798009/… en Windows 10. No pude resolverlo

    – Gulzar

    16 de diciembre de 2018 a las 14:29

  • ¿Hay alguna opción para guardar sin necesidad de un acceso para la clase de modelo?

    – Michael D.

    11 de diciembre de 2019 a las 14:16

  • Con ese enfoque, ¿cómo realiza un seguimiento de los *args y **kwargs que necesita pasar para el caso de carga?

    – Mariano Camp

    9 abr 2020 a las 14:40

  • Hola chicos, ¿alguien podría decirme cuál es la extensión para el archivo modelo dict (.pth?) y la extensión para todo el archivo modelo (.pkl)? ¿Estoy en lo correcto?

    – Francia

    9 ago 2021 a las 15:35

avatar de usuario
Jadiel de Armas

Depende de lo que quieras hacer.

Caso #1: Guarde el modelo para usarlo usted mismo para la inferencia: guarda el modelo, lo restaura y luego cambia el modelo al modo de evaluación. Esto se hace porque normalmente tiene BatchNorm y Dropout capas que por defecto están en modo tren en construcción:

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

Caso #2: Guardar modelo para retomar el entrenamiento más tarde: si necesita seguir entrenando el modelo que está a punto de guardar, necesita guardar más que solo el modelo. También debe guardar el estado del optimizador, las épocas, la puntuación, etc. Lo haría así:

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

Para reanudar el entrenamiento harías cosas como: state = torch.load(filepath)y luego, para restaurar el estado de cada objeto individual, algo como esto:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

Ya que estás reanudando el entrenamiento, NO HAGA llamar model.eval() una vez que restablezca los estados al cargar.

Caso # 3: Modelo para ser utilizado por otra persona sin acceso a su código: En Tensorflow puedes crear un .pb archivo que define tanto la arquitectura como los pesos del modelo. Esto es muy útil, especialmente cuando se usa Tensorflow serve. La forma equivalente de hacer esto en Pytorch sería:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

Esta forma todavía no es a prueba de balas y dado que pytorch todavía está experimentando muchos cambios, no lo recomendaría.

  • ¿Hay un final de archivo recomendado para los 3 casos? ¿O es siempre .pth?

    – Verena Haunschmid

    12 de febrero de 2019 a las 8:23


  • En el Caso #3 torch.load devuelve solo un OrderedDict. ¿Cómo obtienes el modelo para hacer predicciones?

    – Alber8295

    12 de febrero de 2019 a las 10:44

  • Hola, ¿puedo saber cómo hacer el mencionado “Caso #2: Guardar modelo para retomar el entrenamiento más tarde”? Logré cargar el punto de control en el modelo, luego no pude ejecutar o reanudar para entrenar el modelo como “model.to(device) model = train_model_epoch(model, criterio, Optimizer, sched, epochs)”

    – dnez

    8 de marzo de 2019 a las 7:16

  • Hola, para el caso uno que es para inferencia, en el documento oficial de pytorch dice que debe guardar el optimizador state_dict para inferencia o para completar el entrenamiento. “Al guardar un punto de control general, que se utilizará para la inferencia o la reanudación del entrenamiento, debe guardar más que solo el state_dict del modelo. Es importante guardar también el state_dict del optimizador, ya que contiene búferes y parámetros que se actualizan a medida que el modelo entrena . “

    – Mohamed Awney

    21 de septiembre de 2019 a las 13:09

  • En el caso #3, la clase de modelo debe definirse en alguna parte.

    – Michael D.

    11 de diciembre de 2019 a las 13:41

avatar de usuario
prosti

los pepinillo La biblioteca de Python implementa protocolos binarios para serializar y deserializar un objeto de Python.

Cuando usted import torch (o cuando usa PyTorch) lo hará import pickle para ti y no necesitas llamar pickle.dump() y pickle.load() directamente, cuáles son los métodos para guardar y cargar el objeto.

En realidad, torch.save() y torch.load() envolverá pickle.dump() y pickle.load() para ti.

A state_dict la otra respuesta mencionada merece solo algunas notas más.

Qué state_dict Qué tenemos dentro de PyTorch? en realidad hay dos state_dicts.

El modelo PyTorch es torch.nn.Module que tiene model.parameters() llame para obtener parámetros aprendibles (w y b). Estos parámetros de aprendizaje, una vez establecidos aleatoriamente, se actualizarán con el tiempo a medida que aprendamos. Los parámetros que se pueden aprender son los primeros state_dict.

El segundo state_dict es el dict de estado del optimizador. Recuerda que el optimizador se utiliza para mejorar nuestros parámetros de aprendizaje. Pero el optimizador state_dict está arreglado. Nada que aprender allí.

Porque state_dict Los objetos son diccionarios de Python, se pueden guardar, actualizar, modificar y restaurar fácilmente, lo que agrega una gran modularidad a los modelos y optimizadores de PyTorch.

Vamos a crear un modelo súper simple para explicar esto:

import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Model weight:")    
print(model.weight)

print("Model bias:")    
print(model.bias)

print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

Este código generará lo siguiente:

Model's state_dict:
weight      torch.Size([2, 5])
bias      torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328,  0.1360,  0.1553, -0.1838, -0.0316],
        [ 0.0479,  0.1760,  0.1712,  0.2244,  0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state      {}
param_groups      [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]

Tenga en cuenta que este es un modelo mínimo. Puede intentar agregar una pila de secuencias

model = torch.nn.Sequential(
          torch.nn.Linear(D_in, H),
          torch.nn.Conv2d(A, B, C)
          torch.nn.Linear(H, D_out),
        )

Tenga en cuenta que solo las capas con parámetros aprendibles (capas convolucionales, capas lineales, etc.) y los búferes registrados (capas de normas por lotes) tienen entradas en el modelo. state_dict.

Las cosas que no se pueden aprender pertenecen al objeto del optimizador state_dictque contiene información sobre el estado del optimizador, así como los hiperparámetros utilizados.

El resto de la historia es la misma; en la fase de inferencia (esta es una fase en la que usamos el modelo después del entrenamiento) para predecir; predecimos basándonos en los parámetros que aprendimos. Entonces, para la inferencia, solo necesitamos guardar los parámetros. model.state_dict().

torch.save(model.state_dict(), filepath)

Y para usar más tarde model.load_state_dict(torch.load(filepath)) model.eval()

Nota: No olvides la última línea. model.eval() esto es crucial después de cargar el modelo.

Tampoco trates de guardar torch.save(model.parameters(), filepath). los model.parameters() es solo el objeto generador.

Por otra parte, torch.save(model, filepath) guarda el objeto del modelo en sí, pero tenga en cuenta que el modelo no tiene el optimizador state_dict. Verifique la otra excelente respuesta de @Jadiel de Armas para guardar el dictado de estado del optimizador.

  • Aunque no es una solución sencilla, ¡la esencia del problema está profundamente analizada! Votar a favor.

    –Jason Young

    2 de junio de 2020 a las 14:58


avatar de usuario
duro

Una convención común de PyTorch es guardar modelos usando una extensión de archivo .pt o .pth.

Guardar/Cargar todo el modelo

Ahorrar:

path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)

Carga:

(La clase de modelo debe definirse en alguna parte)

model.load_state_dict(torch.load(PATH))
model.eval()

avatar de usuario
Alegría Mazumder

Si desea guardar el modelo y desea reanudar el entrenamiento más tarde:

GPU única:
Ahorrar:

state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath="checkpoint.t7"
torch.save(state,savepath)

Carga:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

GPU múltiple:
Ahorrar

state = {
        'epoch': epoch,
        'state_dict': model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath="checkpoint.t7"
torch.save(state,savepath)

Carga:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

#Don't call DataParallel before loading the model otherwise you will get an error

model = nn.DataParallel(model) #ignore the line if you want to load on Single GPU

avatar de usuario
iacob

Guardar localmente

La forma en que guarde su modelo depende de cómo desee acceder a él en el futuro. Si puede llamar a una nueva instancia de la model clase, entonces todo lo que necesita hacer es guardar/cargar los pesos del modelo con model.state_dict():

# Save:
torch.save(old_model.state_dict(), PATH)

# Load:
new_model = TheModelClass(*args, **kwargs)
new_model.load_state_dict(torch.load(PATH))

Si no puede por alguna razón (o prefiere la sintaxis más simple), puede guardar el modelo completo (en realidad, una referencia a los archivos que definen el modelo, junto con su state_dict) con torch.save():

# Save:
torch.save(old_model, PATH)

# Load:
new_model = torch.load(PATH)

Pero dado que esta es una referencia a la ubicación de los archivos que definen la clase del modelo, este código no es portátil a menos que esos archivos también se transfieran a la misma estructura de directorios.

Guardar en la nube – TorchHub

Si desea que su modelo sea portátil, puede importarlo fácilmente con torch.hub. Si agrega una definición apropiada hubconf.py archivo a un repositorio de github, esto se puede llamar fácilmente desde PyTorch para permitir a los usuarios cargar su modelo con/sin pesos:

hubconf.py (github.com/repo_propietario/repo_nombre)

dependencies = ['torch']
from my_module import mymodel as _mymodel

def mymodel(pretrained=False, **kwargs):
    return _mymodel(pretrained=pretrained, **kwargs)

Cargando modelo:

new_model = torch.hub.load('repo_owner/repo_name', 'mymodel')
new_model_pretrained = torch.hub.load('repo_owner/repo_name', 'mymodel', pretrained=True)

avatar de usuario
Cristiano__

pip install pytorch-relámpago

asegúrese de que su modelo principal use pl.LightningModule en lugar de nn.Module

Guardando y cargando puntos de control usando pytorch lightning

import pytorch_lightning as pl

model = MyLightningModule(hparams)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
new_model = MyModel.load_from_checkpoint(checkpoint_path="example.ckpt")

¿Ha sido útil esta solución?

Esta web utiliza cookies propias y de terceros para su correcto funcionamiento y para fines analíticos y para mostrarte publicidad relacionada con sus preferencias en base a un perfil elaborado a partir de tus hábitos de navegación. Al hacer clic en el botón Aceptar, acepta el uso de estas tecnologías y el procesamiento de tus datos para estos propósitos. Configurar y más información
Privacidad