nn_graph function

Graph Network


Represents a neural network as a Graph that usually consists mostly of PipeOpModules.

nn_graph(graph, shapes_in, output_map = graph$output$name, list_output = FALSE)

Arguments

  • graph: (Graph)

    The Graph to wrap. It is not cloned.

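  • shapes_in: (named list of integer)

    Shapes of the tensors that go into the graph, e.g. c(NA, 10), where NA marks a dimension of unknown size such as the batch dimension. The names must be exactly graph$input$name, possibly in a different order.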

  • output_map: (character)

    Which of the graph's outputs to use. Must be a subset of graph$output$name.

  • list_output: (logical(1))

    Whether the output should be a list of tensors. If FALSE (default), length(output_map) must be 1.
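The output_map and list_output arguments only matter once the graph has more than one output. The sketch below is a hedged illustration, not taken from the package: it reuses the po("module_<n>", module = ...) pattern from the Examples section and assumes the default PipeOpModule channel names "input" and "output". A shared linear layer feeds two heads, and both head outputs are returned as a list.

```r
library(torch)
library(mlr3pipelines)
library(mlr3torch)

# A shared layer (module_1) whose output feeds two separate heads
graph = Graph$new()
graph$add_pipeop(po("module_1", module = nn_linear(10, 20)), clone = FALSE)
graph$add_pipeop(po("module_2", module = nn_linear(20, 1)), clone = FALSE)
graph$add_pipeop(po("module_3", module = nn_linear(20, 3)), clone = FALSE)
graph$add_edge("module_1", "module_2")
graph$add_edge("module_1", "module_3")

# Both heads are selected as outputs, so list_output = TRUE is required
network = nn_graph(
  graph,
  shapes_in = list(module_1.input = c(NA, 10)),
  output_map = c("module_2.output", "module_3.output"),
  list_output = TRUE
)

# One batch of 16 observations; the result holds one tensor per entry of output_map
out = network(module_1.input = torch_randn(16, 10))
```

With list_output = FALSE, output_map would instead have to name a single channel, for example output_map = "module_3.output", and the network would return that tensor directly.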

Returns

An nn_graph, which is a torch nn_module.

Examples

library(mlr3torch)

# Build a linear -> ReLU -> linear network as a Graph of PipeOpModules
graph = mlr3pipelines::Graph$new()
graph$add_pipeop(po("module_1", module = nn_linear(10, 20)), clone = FALSE)
graph$add_pipeop(po("module_2", module = nn_relu()), clone = FALSE)
graph$add_pipeop(po("module_3", module = nn_linear(20, 1)), clone = FALSE)
graph$add_edge("module_1", "module_2")
graph$add_edge("module_2", "module_3")

# Wrap the Graph; the input shape uses NA for the batch dimension
network = nn_graph(graph, shapes_in = list(module_1.input = c(NA, 10)))

# Forward pass on a batch of 16 observations
x = torch_randn(16, 10)
network(module_1.input = x)
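Because the returned nn_graph is a torch module, it can presumably be trained with ordinary autograd like any other nn_module. The following is a minimal sketch under stated assumptions: it uses torch's nnf_mse_loss() with random targets purely for illustration, and it assumes that po() stores the module object by reference, so that (with clone = FALSE and the Graph itself not being cloned) gradients show up on the original layer object lin1.

```r
library(torch)
library(mlr3pipelines)
library(mlr3torch)

# Keep a reference to the first layer to inspect its gradient afterwards
lin1 = nn_linear(10, 20)

graph = Graph$new()
graph$add_pipeop(po("module_1", module = lin1), clone = FALSE)
graph$add_pipeop(po("module_2", module = nn_relu()), clone = FALSE)
graph$add_pipeop(po("module_3", module = nn_linear(20, 1)), clone = FALSE)
graph$add_edge("module_1", "module_2")
graph$add_edge("module_2", "module_3")

network = nn_graph(graph, shapes_in = list(module_1.input = c(NA, 10)))

# Random inputs and targets, for illustration only
x = torch_randn(16, 10)
y = torch_randn(16, 1)

# Forward pass, mean squared error, and backpropagation
loss = nnf_mse_loss(network(module_1.input = x), y)
loss$backward()

# The gradient of the first layer's weight matrix (20 x 10 for nn_linear(10, 20))
grad = lin1$weight$grad
```

An optimizer such as optim_adam(lin1$parameters) could then take a step, but a full training loop is outside the scope of this sketch.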

See Also

Other Graph Network: ModelDescriptor(), TorchIngressToken(), mlr_learners_torch_model, mlr_pipeops_module, mlr_pipeops_torch, mlr_pipeops_torch_ingress, mlr_pipeops_torch_ingress_categ, mlr_pipeops_torch_ingress_ltnsr, mlr_pipeops_torch_ingress_num, model_descriptor_to_learner(), model_descriptor_to_module(), model_descriptor_union()