Class representing the context.
ptr
: (Dev related) pointer to the context's C++ object.
needs_input_grad
: named list of booleans, one per argument of forward, indicating whether that argument requires a gradient.
saved_variables
: list of objects that were saved for backward via save_for_backward().
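As an illustration of how these fields are typically read inside backward(), here is a minimal sketch modelled on a linear layer without bias. It assumes the torch package is installed; the names `linear`, `input` and `weight` are chosen for this example only.

```r
library(torch)

# Illustrative linear function (no bias): output = input %*% t(weight)
linear <- autograd_function(
  forward = function(ctx, input, weight) {
    # Save what backward() will need, under explicit names
    ctx$save_for_backward(input = input, weight = weight)
    input$mm(weight$t())
  },
  backward = function(ctx, grad_output) {
    s <- ctx$saved_variables
    grads <- list(input = NULL, weight = NULL)
    # Only compute the gradients that are actually required
    if (ctx$needs_input_grad$input) {
      grads$input <- grad_output$mm(s$weight)
    }
    if (ctx$needs_input_grad$weight) {
      grads$weight <- grad_output$t()$mm(s$input)
    }
    grads
  }
)

input <- torch_randn(4, 3)                         # does not require grad
weight <- torch_randn(2, 3, requires_grad = TRUE)  # requires grad
out <- linear(input, weight)
out$sum()$backward()
weight$grad  # gradient with the same shape as weight
```

Checking needs_input_grad lets backward() skip work for arguments that do not need a gradient, here `input`.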
new()
(Dev related) Initializes the context. Not intended to be called by users.
AutogradContext$new(
ptr,
env,
argument_names = NULL,
argument_needs_grad = NULL
)
ptr
: pointer to the C++ object
env
: environment that encloses both forward and backward
argument_names
: names of forward arguments
argument_needs_grad
: whether each argument in forward needs grad.
save_for_backward()
Saves given objects for a future call to backward().
This should be called at most once, and only from inside the forward() method.
Later, saved objects can be accessed through the saved_variables attribute. Before returning them to the user, a check is made to ensure they weren't used in any in-place operation that modified their content.
Arguments are not limited to tensors: any kind of R object can be saved.
AutogradContext$save_for_backward(...)
...
: any kind of R object that will be saved for the backward pass. It's common to pass named arguments.
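A minimal sketch of the save-then-reuse pattern, assuming the torch package is installed. The function name `exp2` is chosen for this example; it reimplements the exponential so that backward() can reuse the forward result.

```r
library(torch)

# Custom exp(): forward saves its result, backward reuses it
exp2 <- autograd_function(
  forward = function(ctx, i) {
    result <- i$exp()
    # Named arguments become the names in ctx$saved_variables
    ctx$save_for_backward(result = result)
    result
  },
  backward = function(ctx, grad_output) {
    saved <- ctx$saved_variables
    # d/di exp(i) = exp(i), which is the saved forward result
    list(i = grad_output * saved$result)
  }
)

x <- torch_tensor(c(0, 1, 2), requires_grad = TRUE)
y <- exp2(x)
y$sum()$backward()
x$grad  # should match exp(x)
```

Passing named arguments to save_for_backward() keeps the lookup in backward() readable, since the saved objects are retrieved by the same names.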
mark_non_differentiable()
Marks outputs as non-differentiable.
This should be called at most once, only from inside the forward() method, and all arguments should be outputs.
This will mark outputs as not requiring gradients, increasing the efficiency of backward computation. You still need to accept a gradient for each output in backward(), but it will always be a zero tensor with the same shape as the corresponding output.
This is used, for example, for the indices returned from a max function.
AutogradContext$mark_non_differentiable(...)
...
: non-differentiable outputs.
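A minimal sketch of marking an output as non-differentiable, assuming the torch package is installed. The step function `step_fn` is a hypothetical example, not part of the package; its output is piecewise constant, so no useful gradient flows through it.

```r
library(torch)

# Hypothetical step function: 1 where x > 0, otherwise 0
step_fn <- autograd_function(
  forward = function(ctx, x) {
    out <- (x > 0)$to(dtype = torch_float())
    # Tell autograd that this output does not require a gradient
    ctx$mark_non_differentiable(out)
    out
  },
  backward = function(ctx, grad_output) {
    # backward() must still accept one gradient per output
    list(x = torch_zeros_like(grad_output))
  }
)

x <- torch_tensor(c(-1, 0.5, 2), requires_grad = TRUE)
y <- step_fn(x)
y$requires_grad  # inspect whether the output is tracked by autograd
```

Following the description above, the marked output should report that it does not require a gradient even though the input does.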
mark_dirty()
Marks given tensors as modified in an in-place operation.
This should be called at most once, only from inside the forward() method, and all arguments should be inputs.
Every tensor that has been modified in-place in a call to forward() should be passed to this function, to ensure the correctness of the autograd checks. It doesn't matter whether the function is called before or after the modification.
AutogradContext$mark_dirty(...)
...
: tensors that are modified in-place.
clone()
The objects of this class are cloneable with this method.
AutogradContext$clone(deep = FALSE)
deep
: Whether to make a deep clone.