replace_head function

Replace the head of a network

Replaces the head of the network with a linear layer with d_out output classes.

Replace the head of a network Replaces the head of the network with a linear layer with d_out classes.

replace_head(network, d_out)

Arguments

  • network: (torch::nn_module)

    The network whose head is replaced.

  • d_out: (integer(1))

    The number of output classes.
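To make the idea of "replacing the head" concrete, here is a minimal sketch using the torch package for R. It assumes a network whose final classification layer is an nn_linear stored in a field named fc. Both the helper replace_fc_head and the toy network toy_net are hypothetical and exist only for illustration; replace_head itself may locate and rebuild the head differently depending on the class of the network.

```r
library(torch)

# Hypothetical helper (not part of any package): swaps the final
# nn_linear layer stored in `network$fc` for a freshly initialised
# linear layer with `d_out` output units.
replace_fc_head = function(network, d_out) {
  # Keep the input width of the old head so the new layer still
  # matches the features produced by the body of the network.
  d_in = network$fc$in_features
  network$fc = nn_linear(in_features = d_in, out_features = d_out)
  network
}

# Hypothetical toy network: a small body followed by an `fc` head
# with 5 output classes.
toy_net = nn_module(
  "toy_net",
  initialize = function() {
    self$body = nn_sequential(nn_linear(8, 16), nn_relu())
    self$fc = nn_linear(16, 5)
  },
  forward = function(x) {
    self$fc(self$body(x))
  }
)

net = toy_net()
net = replace_fc_head(net, d_out = 3)

# A batch of 4 inputs with 8 features now yields 3 outputs each.
out = net(torch_randn(4, 8))
print(out$shape)
```

Only the head is rebuilt; the weights of the body are left untouched, which is what makes this kind of replacement useful when adapting a pretrained network to a task with a different number of classes.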