usuario637140
Conozco dos formas de excluir elementos de un cálculo del cálculo de gradiente backward
Método 1: usando with torch.no_grad()
with torch.no_grad():
y = reward + gamma * torch.max(net.forward(x))
loss = criterion(net.forward(torch.from_numpy(o)), y)
loss.backward();
Método 2: usando .detach()
y = reward + gamma * torch.max(net.forward(x))
loss = criterion(net.forward(torch.from_numpy(o)), y.detach())
loss.backward();
¿Hay alguna diferencia entre estos dos? ¿Hay ventajas/desventajas en cualquiera de los dos?
tensor.detach()
crea un tensor que comparte almacenamiento con un tensor que no requiere graduación. Separa la salida del gráfico computacional. Por lo tanto, ningún gradiente se retropropagará a lo largo de esta variable.
el envoltorio with torch.no_grad()
establecer temporalmente todos los requires_grad
marca a falso. torch.no_grad
dice que ninguna operación debe construir el gráfico.
La diferencia es que uno se refiere solo a una variable dada en la que se llama. El otro afecta a todas las operaciones que tienen lugar dentro de la with
declaración. También, torch.no_grad
usará menos memoria porque sabe desde el principio que no se necesitan gradientes, por lo que no necesita mantener resultados intermedios.
Obtenga más información sobre las diferencias entre estos junto con ejemplos de aquí.
detach()
Un ejemplo sin detach()
:
from torchviz import make_dot
x=torch.ones(2, requires_grad=True)
y=2*x
z=3+x
r=(y+z).sum()
make_dot(r)
El resultado final en verde. r
es una raíz del gráfico computacional AD y en azul es el tensor hoja.
Otro ejemplo con detach()
:
from torchviz import make_dot
x=torch.ones(2, requires_grad=True)
y=2*x
z=3+x.detach()
r=(y+z).sum()
make_dot(r)
Esto es lo mismo que:
from torchviz import make_dot
x=torch.ones(2, requires_grad=True)
y=2*x
z=3+x.data
r=(y+z).sum()
make_dot(r)
Pero, x.data
es la forma antigua (notación), y x.detach()
es la nueva forma.
¿Cuál es la diferencia con x.detach()
print(x)
print(x.detach())
Afuera:
tensor([1., 1.], requires_grad=True)
tensor([1., 1.])
Entonces x.detach()
es una forma de quitar requires_grad
y lo que obtienes es un nuevo separado tensor (separado del gráfico computacional AD).
antorcha.no_grad
torch.no_grad
es en realidad una clase.
x=torch.ones(2, requires_grad=True)
with torch.no_grad():
y = x * 2
print(y.requires_grad)
Afuera:
False
De help(torch.no_grad)
:
Deshabilitar el cálculo de gradiente es útil para la inferencia, cuando está seguro | que no llamaras :meth:
Tensor.backward()
. Reducirá la memoria | consumo para cálculos que de otro modo tendríanrequires_grad=True
. |
| En este modo, el resultado de cada cálculo tendrá |requires_grad=False
incluso cuando las entradas tienenrequires_grad=True
.
-
Gracias por las respuestas… brinda una descripción rápida e intuitiva de las funciones .data y detach en el gráfico de cálculo
– Kashyap
25 de diciembre de 2020 a las 4:00
-
@prosti ¿Cuál es la forma completa y el significado de AD?
– Purushothaman Srikanth
27 de febrero de 2021 a las 12:36
-
SHAGUN SHARMA
Una explicación simple y profunda es que el uso de with torch.no_grad()
se comporta como un bucle donde todo lo escrito en él tendrá allí requires_grad
argumento establecido como False
aunque temporalmente. Por lo tanto, no es necesario especificar nada más allá de esto si necesita detener la retropropagación de los gradientes de ciertas variables o funciones.
Sin embargo, torch.detach()
simplemente separa la variable del gráfico de cálculo de gradiente como sugiere el nombre. Pero esto se usa cuando esta especificación debe proporcionarse para un número limitado de variables o funciones, por ejemplo. generalmente, mientras se muestran los resultados de pérdida y precisión después de que finaliza una época en el entrenamiento de la red neuronal porque en ese momento, solo consumía recursos, ya que su gradiente no importará durante la visualización de los resultados.
-
¡Simple! Esta es una gran respuesta.
–Bryce Wayne
21 de enero de 2021 a las 18:52