Torch Integration

Custom serialization functions may be registered to handle reference objects of external pointer type.

This allows tensors from the torch package to be used seamlessly in ‘mirai’ computations.

Setup Steps

  1. Register the serialization and unserialization functions by supplying them as a list to the ‘refhook’ argument of serialization(), specifying ‘class’ as ‘torch_tensor’ and ‘vec’ as TRUE.

  2. Set up daemons using daemons(). This may be done before or after calling serialization().

  3. Use everywhere() to make the torch package available on all daemons for convenience (optional).

library(mirai)
library(torch)

serialization(refhook = list(torch:::torch_serialize, torch::torch_load),
              class = "torch_tensor",
              vec = TRUE)
daemons(1)
#> [1] 1
everywhere(library(torch))
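
As a rough illustration of what ‘vec’ refers to, the base-R sketch below contrasts the two function shapes described in the mirai documentation. With vec = FALSE, the functions serialize and restore one reference object at a time. With vec = TRUE, they accept and return a list of reference objects, which is the shape the torch pair above is used in. Everything here is an assumption for illustration: make_fake() builds a classed environment standing in for a tensor, and ser_one(), unser_one(), ser_vec() and unser_vec() are hypothetical stand-ins, not mirai or torch functions.

```r
# Hypothetical stand-in for a tensor: a classed environment holding data
make_fake <- function(data) {
  e <- new.env()
  e$data <- data
  class(e) <- "fake_tensor"
  e
}

# vec = FALSE shape: one reference object in, one raw vector out (and back)
ser_one   <- function(x) serialize(x$data, NULL)
unser_one <- function(r) make_fake(unserialize(r))

# vec = TRUE shape: a list of reference objects in, one raw vector out (and back)
ser_vec   <- function(xs) serialize(lapply(xs, function(x) x$data), NULL)
unser_vec <- function(r) lapply(unserialize(r), make_fake)

tensors <- list(make_fake(1:3), make_fake(c(0.5, -0.5)))
batch   <- ser_vec(tensors)   # a single raw vector for the whole list
back    <- unser_vec(batch)   # a list of rebuilt objects, in the same order
back[[2]]$data
#> [1]  0.5 -0.5
```

Handling the whole list in one call means one serialization call covers every tensor in a message, rather than one call per tensor.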

Example Usage

The example below creates a convolutional neural network using torch::nn_module().

A set of model parameters is also specified.

The model specification and parameters are then passed to and initialized within a ‘mirai’.

model <- nn_module(
  initialize = function(in_size, out_size) {
    self$conv1 <- nn_conv2d(in_size, out_size, 5)
    # conv2 receives the out_size channels produced by conv1
    self$conv2 <- nn_conv2d(out_size, out_size, 5)
  },
  forward = function(x) {
    x <- self$conv1(x)
    x <- nnf_relu(x)
    x <- self$conv2(x)
    x <- nnf_relu(x)
    x
  }
)

params <- list(in_size = 1, out_size = 20)

m <- mirai(do.call(model, params), model = model, params = params)

call_mirai(m)$data
#> An `nn_module` containing 10,540 parameters.
#> 
#> ── Modules ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
#> • conv1: <nn_conv2d> #520 parameters
#> • conv2: <nn_conv2d> #10,020 parameters

The returned model is an object containing many tensor elements.

m$data$parameters$conv1.weight
#> torch_tensor
#> (1,1,.,.) = 
#>  -0.1196  0.1559 -0.0212  0.0140 -0.0787
#>   0.1100  0.1774 -0.0593  0.0627  0.0117
#>   0.1667 -0.0139 -0.0993  0.1091 -0.1227
#>  -0.1142 -0.0439 -0.1376  0.0252  0.1501
#>   0.0055 -0.0656 -0.0991 -0.0952 -0.0686
#> 
#> (2,1,.,.) = 
#>  -0.0989  0.0575  0.0692  0.1024 -0.0078
#>  -0.0911 -0.1030  0.0104  0.1177 -0.0083
#>  -0.0638  0.1606  0.0528  0.0920 -0.0958
#>   0.0667  0.0672  0.0642  0.0450  0.1796
#>  -0.1445 -0.0755  0.1973 -0.0961  0.1017
#> 
#> (3,1,.,.) = 
#>   0.1813 -0.0337  0.1923  0.0747  0.1205
#>   0.1879  0.1059  0.0944 -0.0888 -0.0202
#>   0.1901  0.1758 -0.0801  0.1417 -0.0559
#>  -0.0497  0.0390 -0.0177  0.0538 -0.1835
#>  -0.1041 -0.0133  0.0808 -0.1594  0.1890
#> 
#> (4,1,.,.) = 
#>   0.1759 -0.0181  0.0007  0.0052 -0.0026
#>   0.1767  0.1281 -0.0077 -0.1438  0.1677
#>  -0.0566  0.1038 -0.1956 -0.0490  0.1493
#>   0.1112  0.0389  0.1229  0.1681  0.1956
#>   0.1550  0.1753  0.0500 -0.0603 -0.1313
#> 
#> (5,1,.,.) = 
#>  -0.0235 -0.1409 -0.0469  0.0176  0.1700
#> ... [the output was truncated (use n=-1 to disable)]
#> [ CPUFloatType{20,1,5,5} ][ requires_grad = TRUE ]

Model parameters are then usually passed to an optimiser.

This can also be initialized within a ‘mirai’ process.

optim <- mirai(optim_rmsprop(params = params), params = m$data$parameters)

call_mirai(optim)$data
#> <optim_rmsprop>
#>   Inherits from: <torch_optimizer>
#>   Public:
#>     add_param_group: function (param_group) 
#>     clone: function (deep = FALSE) 
#>     defaults: list
#>     initialize: function (params, lr = 0.01, alpha = 0.99, eps = 1e-08, weight_decay = 0, 
#>     load_state_dict: function (state_dict, ..., .refer_to_state_dict = FALSE) 
#>     param_groups: list
#>     state: State, R6
#>     state_dict: function () 
#>     step: function (closure = NULL) 
#>     zero_grad: function () 
#>   Private:
#>     step_helper: function (closure, loop_fun)

daemons(0)
#> [1] 0

Above, tensors and complex objects containing tensors were passed seamlessly between host and daemon processes, in the same way as any other R object.

The custom serialization in mirai leverages R’s own native ‘refhook’ mechanism to allow such completely transparent usage. It is designed to be fast and efficient: data copies are minimised, and the ‘official’ serialization methods from the torch package are used directly.
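
To make the ‘refhook’ mechanism concrete, the base-R sketch below round-trips a list through serialize() and unserialize() with a pair of hooks. It illustrates R's mechanism only and is not mirai's implementation. As assumptions, a classed environment built by the hypothetical make_fake() stands in for a torch tensor, and a local environment named store stands in for wherever real hooks would place the tensor data.

```r
# Hypothetical side store for the payloads the hooks take out of the stream
store <- new.env()

# Hypothetical stand-in for an external-pointer-backed tensor
make_fake <- function(data) {
  e <- new.env()
  e$data <- data
  class(e) <- "fake_tensor"
  e
}

# Output hook: R calls this for reference objects it meets (environments,
# external pointers, weak references). Returning a character vector writes
# that key in place of the object; returning NULL serializes it normally.
out_hook <- function(x) {
  if (inherits(x, "fake_tensor")) {
    key <- paste0("obj", length(ls(store)) + 1L)
    assign(key, x$data, envir = store)
    key
  } else {
    NULL
  }
}

# Input hook: receives the key written above and rebuilds the object
in_hook <- function(key) make_fake(get(key, envir = store))

payload  <- list(weight = make_fake(c(1.5, 2.5)), note = "model")
bytes    <- serialize(payload, NULL, refhook = out_hook)
restored <- unserialize(bytes, refhook = in_hook)
restored$weight$data
#> [1] 1.5 2.5
```

Because the output hook returns NULL for anything that is not a fake_tensor, every other object in the payload is serialized exactly as it would be without hooks.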