# Packages ----------------------------------------------------------------

library(torch)
library(torchvision)


# Datasets ----------------------------------------------------------------

# Root directory handed to the dataset constructors below.
dir <- "~/Downloads/tiny-imagenet"

# Default to the CPU and switch to the GPU only when torch reports that CUDA
# is available.
device <- "cpu"
if (cuda_is_available()) {
  device <- "cuda"
}

# Pipe-friendly wrapper around an object's `$to()` method.
#
# x:      any object exposing a `$to(device = ...)` method (used below on the
#         tensors produced by `transform_to_tensor()`).
# device: the device specification forwarded to `$to()`.
#
# Returns whatever `x$to(device = device)` returns.
to_device <- function(x, device) {
  moved <- x$to(device = device)
  moved
}

# Dataset used for training. No `split` is passed, so the constructor's default
# split is used (presumably the training split -- confirm against the
# torchvision documentation).
train_ds <- tiny_imagenet_dataset(
  dir,
  download = TRUE,
  transform = function(x) {
    # Per-item preprocessing, applied in this order:
    # convert to a tensor, move it to `device`, resize to 64 x 64.
    img <- transform_to_tensor(x)
    img <- to_device(img, device)
    transform_resize(img, c(64, 64))
  }
)

# Dataset used for validation: same root directory and preprocessing as the
# training dataset, but reading the "val" split.
valid_ds <- tiny_imagenet_dataset(
  dir,
  download = TRUE,
  split = "val",
  transform = function(x) {
    # Per-item preprocessing, applied in this order:
    # convert to a tensor, move it to `device`, resize to 64 x 64.
    img <- transform_to_tensor(x)
    img <- to_device(img, device)
    transform_resize(img, c(64, 64))
  }
)

# Batches of 32 for both loaders. Only the training loader shuffles; both set
# `drop_last = TRUE`.
train_dl <- dataloader(
  train_ds,
  batch_size = 32,
  shuffle = TRUE,
  drop_last = TRUE
)
valid_dl <- dataloader(
  valid_ds,
  batch_size = 32,
  shuffle = FALSE,
  drop_last = TRUE
)


# Model -------------------------------------------------------------------

# AlexNet without pretrained weights; the number of output classes is taken
# from the classes of the training dataset.
model <- model_alexnet(
  pretrained = FALSE,
  num_classes = length(train_ds$classes)
)
model$to(device = device)

# Adam over all model parameters, with no hyperparameters overridden.
optimizer <- optim_adam(model$parameters)

# Step scheduler configured with step_size = 1 and a factor of 0.95; it is
# advanced once per epoch by the training loop.
scheduler <- lr_step(optimizer, step_size = 1, gamma = 0.95)

# Classification loss used by `train_step()`.
loss_fn <- nn_cross_entropy_loss()


# Training loop -----------------------------------------------------------

# Run a single optimisation step on one batch.
#
# batch: a list-like object whose first element holds the model inputs and
#        whose second element holds the targets; both are moved to `device`.
#
# Uses the script-level `model`, `optimizer`, `loss_fn` and `device`.
# Returns the loss object produced by `loss_fn` for this batch (after
# `backward()` and the optimizer step have been run).
train_step <- function(batch) {
  # Clear gradients left over from the previous step.
  optimizer$zero_grad()

  inputs <- batch[[1]]$to(device = device)
  output <- model(inputs)

  targets <- batch[[2]]$to(device = device)
  loss <- loss_fn(output, targets)

  # Backpropagate, then update the parameters.
  loss$backward()
  optimizer$step()

  loss
}

# Compute the top-5 accuracy of `model` on one validation batch.
#
# batch: a list-like object whose first element holds the model inputs and
#        whose second element holds the target labels.
#
# Uses the script-level `model` and `device`. The model is put in eval mode
# for the forward pass and is always switched back to train mode on exit,
# including when the forward pass raises an error.
#
# Returns a single R number: the fraction of items in the batch whose label
# is among the five highest-scoring predictions.
valid_step <- function(batch) {
  model$eval()
  on.exit(model$train(), add = TRUE)

  # Evaluation needs no gradients: skip building the autograd graph so that
  # no memory is spent on it.
  with_no_grad({
    pred <- model(batch[[1]]$to(device = device))
  })

  # Keep only the indices (second element) of the 5 largest scores per row.
  # NOTE(review): `$add(1L)` shifts every index by one, presumably to turn
  # 0-based indices into 1-based class labels. Confirm that the installed
  # torch version still returns 0-based indices from `torch_topk()`.
  pred <- torch_topk(pred, k = 5, dim = 2, largest = TRUE, sorted = TRUE)[[2]]
  pred <- pred$add(1L)
  pred <- pred$to(device = torch_device("cpu"))

  # Labels as a single column so they are compared against each of the five
  # predicted indices; an item counts as correct if any of them matches.
  correct <- batch[[2]]$view(c(-1, 1))$eq(pred)$any(dim = 2)
  correct$to(dtype = torch_float32())$mean()$item()
}

# Main loop: 50 epochs, each consisting of one pass over `train_dl` (with a
# progress bar showing the running mean loss), one pass over `valid_dl`, one
# scheduler step and a one-line summary.
for (epoch in 1:50) {

  pb <- progress::progress_bar$new(
    total = length(train_dl),
    format = "[:bar] :eta Loss: :loss"
  )

  # Per-batch losses are stored in a preallocated vector and the running mean
  # comes from a running sum, instead of growing `l` with c() and re-averaging
  # the whole history on every tick.
  l <- numeric(length(train_dl))
  n_train <- 0L
  loss_total <- 0
  for (b in enumerate(train_dl)) {
    loss <- train_step(b)
    n_train <- n_train + 1L
    l[n_train] <- loss$item()
    loss_total <- loss_total + l[n_train]
    pb$tick(tokens = list(loss = loss_total / n_train))
  }
  # Keep only the slots that were actually filled.
  l <- l[seq_len(n_train)]

  # Per-batch top-5 accuracies, collected the same way.
  acc <- numeric(length(valid_dl))
  n_valid <- 0L
  for (b in enumerate(valid_dl)) {
    accuracy <- valid_step(b)
    n_valid <- n_valid + 1L
    acc[n_valid] <- accuracy
  }
  acc <- acc[seq_len(n_valid)]

  # Advance the learning-rate schedule once per epoch.
  scheduler$step()
  cat(sprintf(
    "[epoch %d]: Loss = %3f, Acc= %3f \n",
    epoch, mean(l), mean(acc)
  ))
}