# Texture Generation with Neural Cellular Automata
# This script is a port to torch for R of
# https://colab.research.google.com/github/google-research/self-organising-systems/blob/master/notebooks/texture_nca_pytorch.ipynb
# by Alex Mordvintsev (@zzznah)

# Packages ----------------------------------------------------------------

library(torch)
library(torchvision)
library(purrr)
library(zeallot)

device <- if(cuda_is_available()) "cuda" else "cpu"

# Function definitions ----------------------------------------------------

# Gram matrix of a batch of feature maps.
#
# `input` is indexed as (b, c, h, w) by the einsum below. For every batch
# element the channel-by-channel inner products over all spatial positions
# are computed and then divided by the product of the last two dimensions
# of `input`. The result is indexed as (b, c, c).
gram_matrix <- function(input) {
  dims <- input$size()
  n_positions <- prod(tail(dims, 2))
  channel_products <- torch_einsum(
    "bchw, bdhw -> bcd",
    list(input, input)
  )
  channel_products$div(n_positions)
}

# Mean squared difference between two tensors.
#
# In this script it is applied pairwise to Gram matrices (see the training
# loop), yielding a scalar tensor per pair.
style_loss <- function(content, style) {
  residual <- content - style
  squared <- residual$square()
  squared$mean()
}

# Model definition --------------------------------------------------------

# Keep only the `features` sub-module of a pretrained VGG16; it is used purely
# as a fixed feature extractor for the style loss.
# NOTE(review): `pretrained = TRUE` presumably downloads the weights on first
# use, so this line needs network access -- confirm.
cnn <- model_vgg16(pretrained = TRUE)$features

# Module that walks through the layers of `cnn` in order, exactly like the
# sequential container would, but returns a list holding the output of every
# convolution layer instead of only the final activation.
# ReLU layers are not called directly: a fresh `nnf_relu()` is applied instead,
# so in-place operations are replaced by ones that return a new tensor and the
# stored convolution outputs are not overwritten.
features <- nn_module(
  initialize = function(cnn) {
    self$cnn <- cnn
  },
  forward = function(input) {
    activation <- input
    conv_outs <- list()

    for (idx in seq_along(self$cnn)) {
      layer <- self$cnn[[idx]]

      activation <- if (inherits(layer, "nn_relu")) {
        nnf_relu(activation)
      } else {
        layer(activation)
      }

      # Record the raw (pre-activation) output of each convolution.
      if (inherits(layer, "nn_conv2d")) {
        conv_outs[[length(conv_outs) + 1]] <- activation
      }
    }

    conv_outs
  }
)

# Wrap the VGG16 layers so that a forward pass returns all convolution outputs,
# then move the wrapped network to the selected device.
model <- features(cnn)
model$to(device = device)


# CA model ---------------------------------------------------------------

# Apply every kernel in `filters` to every channel of `x` independently.
#
# `x` is unpacked as (batch, channels, height, width). Channels are folded
# into the batch dimension so each one is convolved on its own, the input is
# padded by one pixel on every side with wrap-around ("circular") padding,
# and the result is reshaped back to (batch, -1, height, width).
# `filters` gets a singleton dimension inserted at position 2 before being
# passed to the convolution.
perchannel_conv <- function(x, filters) {
  c(n_batch, n_chan, height, width) %<-% x$shape

  per_channel <- x$reshape(c(n_batch * n_chan, 1, height, width))
  wrapped <- nnf_pad(per_channel, rep(1, 4), mode = "circular")
  kernels <- filters[, newaxis]
  responses <- nnf_conv2d(wrapped, kernels)

  responses$reshape(c(n_batch, -1, height, width))
}

# Fixed (non-learned) perception stage of the cellular automaton.
#
# Four 3x3 kernels are stacked into `self$filters` and applied to every
# channel separately by `perchannel_conv()`: identity, horizontal Sobel,
# vertical Sobel (the transpose of the horizontal one) and a Laplacian.
# The kernels are created directly on the global `device`.
perception <- nn_module(
  initialize = function() {
    # Identity: passes the cell's own value through unchanged.
    ident <- torch_tensor(rbind(
      c(0, 0, 0),
      c(0, 1, 0),
      c(0, 0, 0)
    ), device = device)

    # Horizontal gradient estimate, scaled by 1/8.
    sobel_x <- torch_tensor(rbind(
      c(-1, 0, 1),
      c(-2, 0, 2),
      c(-1, 0, 1)
    ), device = device)$div(8)

    # 9-point Laplacian, scaled by 1/16. The stencil must be symmetric and
    # its entries must sum to zero so that a constant field produces no
    # response; the middle row is therefore (2, -12, 2).
    lap <- torch_tensor(rbind(
      c(1, 2, 1),
      c(2, -12, 2),
      c(1, 2, 1)
    ), device = device)$div(16)

    self$filters <- torch_stack(list(
      ident,
      sobel_x,
      sobel_x$t(),
      lap
    ))
  },

  # Returns the per-channel responses of `x` to all four kernels.
  forward = function(x) {
    perchannel_conv(x, self$filters)
  }
)

# Neural cellular automaton.
#
# `chn` is the number of state channels per cell and `hidden_n` the width of
# the hidden 1x1 convolution. One update step perceives the neighbourhood
# (4 filter responses per channel, hence `chn * 4` inputs), passes it through
# two 1x1 convolutions with a ReLU in between, and adds the result to the
# state for a random subset of cells.
CA <- nn_module(
  "CA",
  initialize = function(chn = 12, hidden_n = 96) {
    self$chn <- chn
    self$w1 <- nn_conv2d(chn * 4, hidden_n, 1)
    self$w2 <- nn_conv2d(hidden_n, chn, 1, bias = FALSE)
    # With a zeroed, bias-free output layer the initial update is exactly 0.
    nn_init_zeros_(self$w2$weight)
    self$perception <- perception()
  },
  # `update_rate` is the probability that a given cell applies its update.
  forward = function(x, update_rate = 0.5) {
    perceived <- self$perception(x)
    hidden <- torch_relu(self$w1(perceived))
    delta <- self$w2(hidden)

    # One random draw per cell (single channel, broadcast over all channels).
    fire <- torch_rand_like(delta[, 1, .., drop = FALSE]) < update_rate
    x + delta * fire
  },
  # `n` all-zero states of `size` x `size` cells on the global `device`.
  seed = function(n, size = 128) {
    torch_zeros(n, self$chn, size, size, device = device)
  }
)


# Read image and preprocess -----------------------------------------------

# Per-channel mean and standard deviation used by normalize() and
# denormalize() below (three values each, one per colour channel).
# NOTE(review): these look like the usual ImageNet statistics expected by
# pretrained torchvision models -- confirm.
norm_mean <- c(0.485, 0.456, 0.406)
norm_std <- c(0.229, 0.224, 0.225)

# Standardise an image tensor channel-wise with `norm_mean` and `norm_std`.
normalize <- function(img) {
  transform_normalize(img, mean = norm_mean, std = norm_std)
}

# Inverse of normalize(): normalising with mean `-m / s` and std `1 / s`
# computes `(x + m / s) * s = x * s + m`, which undoes `(x - m) / s`.
denormalize <- function(img) {
  inverse_mean <- -norm_mean / norm_std
  inverse_std <- 1 / norm_std
  transform_normalize(img, mean = inverse_mean, std = inverse_std)
}

# Draw the first image of a normalised batch on the current graphics device.
#
# `x` is a batch of normalised images; only the first element is shown. The
# image is de-normalised, reordered from channel-first to channel-last, moved
# to the CPU and clamped to [0, 1] before being converted to a raster.
# Called for its side effect; returns `NULL` invisibly.
plot_img <- function(x) {
  im <- denormalize(x)[1,..]$
    permute(c(2, 3, 1))$
    to(device = "cpu")$
    clamp(0,1) %>% # make it [0,1]
    as.array()

  # Remove all margins for the plot and make sure the previous graphical
  # parameters are restored even if plotting fails.
  op <- par(mar = rep(0, 4))
  on.exit(par(op), add = TRUE)

  plot(as.raster(im), asp = NA)
  invisible(NULL)
}

# Read the image at `path` and prepare it as network input: convert it to a
# tensor, resize it to 128 x 128, add a leading batch dimension, normalise
# it and move it to the global `device`.
load_image <- function(path) {
  pixels <- transform_to_tensor(base_loader(path))
  resized <- transform_resize(pixels, c(128, 128))
  batch <- resized[newaxis, ..]
  normalize(batch)$to(device = device)
}

# Download the style reference image to a temporary file, load it as a
# normalised tensor and display it.
img_path <- tempfile(fileext = ".jpg")
download.file(
  "https://upload.wikimedia.org/wikipedia/commons/thumb/0/04/Tempera%2C_charcoal_and_gouache_mountain_painting_by_Nicholas_Roerich.jpg/301px-Tempera%2C_charcoal_and_gouache_mountain_painting_by_Nicholas_Roerich.jpg",
  img_path,
  # A JPEG is binary data: without "wb" the default text mode corrupts the
  # file on Windows.
  mode = "wb"
)

img <- load_image(img_path)
plot_img(img)


# Compute target style ----------------------------------------------------

# Style descriptor of a batch of images: the Gram matrices of the first five
# convolution outputs of `model`, returned as a list of five tensors.
compute_style <- function(imgs) {
  style_layers <- 1:5 # first five convolutions
  conv_outs <- model(imgs)
  map(conv_outs[style_layers], gram_matrix)
}

# Gram matrices of the reference image. They are computed once, without
# autograd tracking, and serve as the fixed target of the style loss.
with_no_grad({
  target_style <- compute_style(img)
})


# Setup training ----------------------------------------------------------

# Cellular automaton with the default channel/hidden sizes, on `device`.
ca <- CA()
ca$to(device = device)
# Only the CA's parameters are optimised. The step scheduler multiplies the
# learning rate by `gamma` every `step_size` calls to `scheduler$step()`.
optimizer <- optim_adam(ca$parameters, lr = 1e-3)
scheduler <- lr_step(optimizer, step_size = 2000, gamma = 0.3)

# Pool of 1024 CA states, all starting as the all-zero seed. Training batches
# are drawn from, and written back to, this pool.
with_no_grad({
  pool <- ca$seed(1024)
})

# Training loop.
#
# Each iteration draws 4 states from the pool, runs the CA on them for a
# random number of steps (32 to 96), compares the Gram matrices of the first
# three channels with `target_style`, takes one optimiser step on
# norm-normalised gradients and writes the evolved states back to the pool.
pb <- progress::progress_bar$new(total = 4000, format = ":current/:total [:bar] :elapsed/:eta")
for (i in 1:4000) {
  with_no_grad({
    batch_idx <- sample.int(pool$size(1), 4)
    x <- pool[batch_idx,..]

    # Every 8th iteration restart one batch element from the blank seed.
    # Otherwise the pool soon contains only long-evolved states and the CA is
    # no longer trained to grow the texture from the seed, which is the state
    # generation starts from. This mirrors the notebook this script ports.
    if (i %% 8 == 0) {
      x[1,..] <- ca$seed(1)[1,..]
    }
  })

  step_n <- sample(32:96, 1)

  for (k in 1:step_n) {
    x <- ca(x)
  }

  # Only the first three state channels are interpreted as the image.
  styles <- compute_style(x[,1:3,..])
  loss <- purrr::map2(styles, target_style, style_loss) %>%
    purrr::reduce(~.x + .y)

  with_no_grad({

    loss$backward()

    # Rescale each parameter's gradient to (approximately) unit norm; the
    # small constant guards against division by zero.
    for (p in ca$parameters) {
      p$grad$div_(p$grad$norm()+1e-8)
    }

    optimizer$step()
    optimizer$zero_grad()
    scheduler$step()

    # Store the evolved states so later iterations continue from them.
    pool[batch_idx,..] <- x

  })
  pb$tick()
}

# Persist the trained CA module to "ca.pt" (relative to the working directory).
torch_save(ca, "ca.pt")

# NCA video ---------------------------------------------------------------

# Draw 300 frames of the CA evolving from a single blank 128 x 128 seed.
#
# The number of CA steps between consecutive frames doubles every 30 frames
# (1, 2, 4, ...) and is capped at 16. Each frame plots the first three state
# channels via plot_img(). No autograd history is recorded.
create_frames <- function() {
  state <- ca$seed(1, 128)
  for (frame in 1:300) {
    n_steps <- min(2^(frame %/% 30), 16)
    with_no_grad({
      for (step in seq_len(n_steps)) {
        state <- ca(state)
      }
    })
    plot_img(state[1, 1:3, .., drop = FALSE])
  }
}

# Capture every plot drawn by create_frames() and encode them as a 128 x 128
# animated GIF written to "nca_video.gif".
# NOTE(review): `delay` is presumably the display time per frame in seconds --
# confirm against the gifski documentation.
gifski::save_gif(
  create_frames(),
  gif_file = "nca_video.gif",
  width = 128,
  height = 128,
  delay = 0.1
)