Graph Network
Represents a neural network using a Graph that usually contains mostly PipeOpModules.
nn_graph(graph, shapes_in, output_map = graph$output$name, list_output = FALSE)
graph : (Graph)
The Graph to wrap. It is not cloned.
shapes_in : (named integer)
Shape information for the tensors that go into graph. The names must be graph$input$name, possibly in a different order.
output_map : (character)
Which of graph's outputs to use. Must be a subset of graph$output$name.
list_output : (logical(1))
Whether the output should be a list of tensors. If FALSE (the default), length(output_map) must be 1.
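A minimal sketch of how output_map and list_output interact, assuming a hypothetical branched graph in which module_1 feeds both module_2 and module_3. It also assumes that each PipeOpModule exposes an output channel named "<id>.output", mirroring the "module_1.input" input name used in the package example. Requesting two outputs requires list_output = TRUE, while selecting a single output lets the network return a plain tensor.

```r
library(mlr3pipelines)
library(torch)
library(mlr3torch)

# Hypothetical branched graph: module_1 feeds both module_2 and module_3
graph = mlr3pipelines::Graph$new()
graph$add_pipeop(po("module_1", module = nn_linear(10, 20)), clone = FALSE)
graph$add_pipeop(po("module_2", module = nn_relu()), clone = FALSE)
graph$add_pipeop(po("module_3", module = nn_linear(20, 1)), clone = FALSE)
graph$add_edge("module_1", "module_2")
graph$add_edge("module_1", "module_3")

x = torch_randn(16, 10)

# Two outputs are requested, so list_output must be TRUE
net_both = nn_graph(
  graph,
  shapes_in = list(module_1.input = c(NA, 10)),
  output_map = c("module_2.output", "module_3.output"),
  list_output = TRUE
)
out = net_both(module_1.input = x)  # list: (16, 20) tensor, then (16, 1) tensor

# A single selected output is returned as a plain tensor (list_output = FALSE)
net_one = nn_graph(
  graph,
  shapes_in = list(module_1.input = c(NA, 10)),
  output_map = "module_3.output"
)
single = net_one(module_1.input = x)  # tensor of shape (16, 1)
```

Because the graph is not cloned, both networks in this sketch share the same underlying modules and therefore the same parameters.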
Value: nn_graph
library(mlr3torch)

# Build a graph of three PipeOpModules: linear -> relu -> linear
graph = mlr3pipelines::Graph$new()
graph$add_pipeop(po("module_1", module = nn_linear(10, 20)), clone = FALSE)
graph$add_pipeop(po("module_2", module = nn_relu()), clone = FALSE)
graph$add_pipeop(po("module_3", module = nn_linear(20, 1)), clone = FALSE)
graph$add_edge("module_1", "module_2")
graph$add_edge("module_2", "module_3")

# Wrap the graph; NA marks the batch dimension of the input shape
network = nn_graph(graph, shapes_in = list(module_1.input = c(NA, 10)))

# Forward pass with a batch of 16 observations
x = torch_randn(16, 10)
network(module_1.input = x)
Other Graph Network: ModelDescriptor(), TorchIngressToken(), mlr_learners_torch_model, mlr_pipeops_module, mlr_pipeops_torch, mlr_pipeops_torch_ingress, mlr_pipeops_torch_ingress_categ, mlr_pipeops_torch_ingress_ltnsr, mlr_pipeops_torch_ingress_num, model_descriptor_to_learner(), model_descriptor_to_module(), model_descriptor_union()