Extending Autograd
Adding operations to autograd requires implementing a new autograd_function for each operation. Recall that autograd_functions are what autograd uses to compute the results and gradients, and to encode the operation history. Every new function requires you to implement two methods:
- forward() - the code that performs the operation. Its first argument is always the context ctx, followed by as many arguments as you need (some of them optional, if you give them default values).
- backward() - the gradient formula. It receives the gradient of the loss with respect to the forward's output and must return a named list with the gradients with respect to each input.
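As a minimal sketch of this structure (the example and the exp_fn name are illustrative, not part of the original vignette), here is the exponential function written as an autograd_function:

exp_fn <- autograd_function(
  forward = function(ctx, input) {
    result <- input$exp()
    # save the output: d/dx exp(x) = exp(x)
    ctx$save_for_backward(result = result)
    result
  },
  backward = function(ctx, grad_output) {
    s <- ctx$saved_variables
    # gradient w.r.t. `input`, named after the forward argument
    list(input = grad_output * s$result)
  }
)

Note how backward() returns a named list whose names match the forward arguments; the larger examples below follow the same pattern.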
Note
It’s the user’s responsibility to use the special functions in the forward’s ctx properly in order to ensure that the new autograd_function works properly with the autograd engine.
- save_for_backward() must be used when saving inputs or outputs of the forward to be used later in the backward.
- mark_dirty() must be used to mark any input that is modified in place by the forward function (see the sketch below).
- mark_non_differentiable() must be used to tell the engine if an output is not differentiable.
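As an illustrative sketch of mark_dirty() (not taken from the original document; the inplace_relu name and the exact mark_dirty() call are assumptions based on the description above), here is an in-place ReLU that modifies its input and therefore marks it dirty:

inplace_relu <- autograd_function(
  forward = function(ctx, input) {
    # modify the input in place, so autograd must be told it is now "dirty"
    result <- input$clamp_(min = 0)
    # assumed usage: pass the tensor(s) that were modified in place
    ctx$mark_dirty(input)
    ctx$save_for_backward(result = result)
    result
  },
  backward = function(ctx, grad_output) {
    s <- ctx$saved_variables
    # the gradient passes through wherever the clamped output is positive
    mask <- (s$result > 0)$to(dtype = grad_output$dtype)
    list(input = grad_output * mask)
  }
)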
Examples
Below you can find code for a linear function:
linear <- autograd_function(
  forward = function(ctx, input, weight, bias = NULL) {
    # save the tensors needed to compute the gradients in the backward pass
    ctx$save_for_backward(input = input, weight = weight, bias = bias)
    output <- input$mm(weight$t())
    if (!is.null(bias))
      output <- output + bias$unsqueeze(1)$expand_as(output)
    output
  },
  backward = function(ctx, grad_output) {
    s <- ctx$saved_variables
    grads <- list(
      input = NULL,
      weight = NULL,
      bias = NULL
    )
    # only compute the gradients that are actually required
    if (ctx$needs_input_grad$input)
      grads$input <- grad_output$mm(s$weight)
    if (ctx$needs_input_grad$weight)
      grads$weight <- grad_output$t()$mm(s$input)
    if (!is.null(s$bias) && ctx$needs_input_grad$bias)
      grads$bias <- grad_output$sum(dim = 1)
    grads
  }
)
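A short usage sketch may help; the shapes and variable names here are chosen for illustration and are not part of the original example:

x <- torch_randn(3, 4, requires_grad = TRUE)
w <- torch_randn(5, 4, requires_grad = TRUE)
b <- torch_randn(5, requires_grad = TRUE)

y <- linear(x, w, b)
y$sum()$backward()

x$grad  # gradient w.r.t. the input,  shape (3, 4)
w$grad  # gradient w.r.t. the weight, shape (5, 4)
b$grad  # gradient w.r.t. the bias,   shape (5)

Because backward() checks ctx$needs_input_grad, gradients are only computed for the arguments that actually require them.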
Here, we give an additional example of a function that is parametrized by non-Tensor arguments:
mul_constant <- autograd_function(
  forward = function(ctx, tensor, constant) {
    # `constant` is not a tensor, but it can still be stashed for the backward pass
    ctx$save_for_backward(constant = constant)
    tensor * constant
  },
  backward = function(ctx, grad_output) {
    v <- ctx$saved_variables
    # only `tensor` gets a gradient; the non-tensor `constant` is simply omitted
    list(
      tensor = grad_output * v$constant
    )
  }
)
x <- torch_tensor(1, requires_grad = TRUE)
o <- mul_constant(x, 2)
o$backward()
x$grad
#> torch_tensor
#> 2
#> [ CPUFloatType{1} ]