Related
I am trying to estimate a Mixed-mixed multinomial logit model using the gmnl package. It works perfectly when not including Alternative Specific Constants (ASC), but it produces a weird error when incorporating them. The code below was taken (and adapted) from the original article published of the package.
Data preparation
options(digits = 3)
library("gmnl")
library("mlogit")
data("Electricity", package = "mlogit")
Electr <- mlogit.data(Electricity,
id.var = "id",
choice = "choice",
varying = 3:26,
shape = "wide",
sep = "")
####Alternative Specific Constants
Electr$asc2 <- as.numeric(Electr$alt == 2)
Electr$asc3 <- as.numeric(Electr$alt == 3)
Electr$asc4 <- as.numeric(Electr$alt == 4)
Latent Class Models (with ASC)
The code below works perfectly, even including the ASC in the second part of the formula (LC_ASC_in_formula) or explicitly with the regressors (LC_ASC_in_variables).
LC_ASC_in_formula <- gmnl(choice ~ pf + cl + loc + wk + tod + seas | 1 | 0 | 0 | 1,
data = Electr,
subset = 1:3000,
model = "lc",
panel = TRUE,
Q = 2)
summary(LC_ASC_in_formula)
LC_ASC_in_variables <- gmnl( choice ~ pf + cl + loc + wk + tod + seas +asc2 +asc3 +asc4 | 0 | 0 | 0 | 1,
data = Electr,
subset = 1:3000,
model = "lc",
panel = TRUE,
Q = 2)
summary(LC_ASC_in_variables)
## Are they the same?
logLik(LC_ASC_in_variables) == logLik(LC_ASC_in_formula)
## [1] TRUE
Mixed-mixed MNL model
This model is basically a Latent Class model, but inside each class, the parameters are random (follow a previously specified parametric distribution).
Mixed-mixed MNL model WITHOUT ASC
The model works just fine when the ASCs are omitted.
MM_no_ASC <- gmnl(choice ~ pf + cl + loc + wk + tod + seas | 0 | 0 | 0 | 1,
data = Electr,
subset = 1:3000,
model = "mm",
R = 5,
panel = TRUE,
ranp = c(pf = "n",cl = "n",loc = "n",wk = "n", tod = "n",seas= "n"),
Q = 2,
iterlim = 500)
However, it fails to estimate the model when including the ASC:
As part of the variables in the model.
MM_ASC_in_variables <- gmnl( choice ~ pf + cl + loc + wk + tod + seas +
asc2 +asc3 +asc4 | 0 | 0 | 0 | 1 ,
data = Electr,
subset = 1:3000,
model = "mm",
R = 5,
panel = TRUE,
ranp = c(pf = "n",cl = "n",loc = "n",wk = "n", tod = "n",seas= "n"),
Q = 2,
iterlim = 500)
> Error in if (distr == "n") { : missing value where TRUE/FALSE needed
and when including them in the third part of the formula.
MM_ASC_in_formula <- gmnl( choice ~ pf + cl + loc + wk + tod + seas | 1 | 0 | 0 | 1 ,
data = Electr,
subset = 1:3000,
model = "mm",
R = 5,
panel = TRUE,
ranp = c(pf = "n",cl = "n",loc = "n",wk = "n", tod = "n",seas= "n"),
Q = 2,
iterlim = 500)
> Error in if (distr == "n") { : missing value where TRUE/FALSE needed
Howeve, both ways to include the ASC parameters fail to initialize the model estimation. Hopefully, someone could help me to solve this issue. Thank you in advance.
Bonus1: Traceback of the error.
I reduced the number of observations included in the estimation (subset = 1:20) to see better the traceback() of the error shown below. But I couldn't spot the error myself.
MM_ASC_in_formula <- gmnl( choice ~ pf + cl + loc + wk + tod + seas | 1 | 0 | 0 | 1 ,
data = Electr,
subset = 1:20,
model = "mm",
R = 5,
panel = TRUE,
ranp = c(pf = "n",cl = "n",loc = "n",wk = "n", tod = "n",seas= "n"),
Q = 2,
iterlim = 500)
# Error in if (distr == "n") { : missing value where TRUE/FALSE needed
traceback()
# Estimating MM-MNL model
# Error in if (distr == "n") { : missing value where TRUE/FALSE needed
# > traceback()
# 14: Makeh.rcoef(beta[, q], stds[, q], ranp, Omega[, ((i - 1) * R +
# 1):(i * R), drop = FALSE], correlation, Pi = NULL, Slist = NULL,
# mvar = NULL)
# 13: fnOrig(theta, ...)
# 12: logLikFunc(theta, fnOrig = function (theta, y, X, H, Q, id = NULL,
# ranp, R, correlation, weights = NULL, haltons = NULL, seed = 12345,
# gradient = TRUE, get.bi = FALSE)
# {
# K <- ncol(X[[1]])
# J <- length(X)
# N <- nrow(X[[1]])
# panel <- !is.null(id)
# if (panel) {
# n <- length(unique(id))
# if (length(weights) == 1)
# weights <- rep(weights, N)
# }
# beta <- matrix(theta[1L:(K * Q)], nrow = K, ncol = Q)
# nstds <- if (!correlation)
# K * Q
# else (0.5 * K * (K + 1)) * Q
# stds <- matrix(theta[(K * Q + 1):(K * Q + nstds)], ncol = Q)
# rownames(beta) <- colnames(X[[1]])
# colnames(beta) <- colnames(stds) <- paste("class", 1:Q, sep = ":")
# gamma <- theta[-c(1L:(K * Q + nstds))]
# ew <- lapply(H, function(x) exp(crossprod(t(x), gamma)))
# sew <- suml(ew)
# Wnq <- lapply(ew, function(x) {
# v <- x/sew
# v[is.na(v)] <- 0
# as.vector(v)
# })
# Wnq <- Reduce(cbind, Wnq)
# set.seed(seed)
# Omega <- make.draws(R * ifelse(panel, n, N), K, haltons)
# XBr <- vector(mode = "list", length = J)
# for (j in 1:J) XBr[[j]] <- array(NA, dim = c(N, R, Q))
# nind <- ifelse(panel, n, N)
# if (panel)
# theIds <- unique(id)
# if (get.bi)
# bi <- array(NA, dim = c(nind, R, Q, K), dimnames = list(NULL,
# NULL, NULL, colnames(X[[1]])))
# for (i in 1:nind) {
# if (panel) {
# anid <- theIds[i]
# theRows <- which(id == anid)
# }
# else theRows <- i
# for (q in 1:Q) {
# bq <- Makeh.rcoef(beta[, q], stds[, q], ranp, Omega[,
# ((i - 1) * R + 1):(i * R), drop = FALSE], correlation,
# Pi = NULL, Slist = NULL, mvar = NULL)
# for (j in 1:J) {
# XBr[[j]][theRows, , q] <- crossprod(t(X[[j]][theRows,
# , drop = FALSE]), bq$br)
# }
# if (get.bi)
# bi[i, , q, ] <- t(bq$br)
# }
# }
# EXB <- lapply(XBr, function(x) exp(x))
# SEXB <- suml.array(EXB)
# Pntirq <- lapply(EXB, function(x) x/SEXB)
# Pnrq <- suml.array(mapply("*", Pntirq, y, SIMPLIFY = FALSE))
# if (panel)
# Pnrq <- apply(Pnrq, c(2, 3), tapply, id, prod)
# Pnq <- apply(Pnrq, c(1, 3), mean)
# WPnq <- Wnq * Pnq
# Ln <- apply(WPnq, 1, sum)
# if (get.bi)
# Qir <- list(wnq = Wnq, Ln = Ln, Pnrq = Pnrq)
# lnL <- if (panel)
# sum(log(Ln) * weights[!duplicated(id)])
# else sum(log(Ln) * weights)
# if (gradient) {
# lambda <- mapply(function(y, p) y - p, y, Pntirq, SIMPLIFY = FALSE)
# Wnq.mod <- aperm(repmat(Wnq/Ln, dimen = c(1, 1, R)),
# c(1, 3, 2))
# Qnq.mod <- Wnq.mod * Pnrq
# if (panel)
# Qnq.mod <- Qnq.mod[id, , ]
# eta <- lapply(lambda, function(x) x * Qnq.mod)
# dUdb <- dUds <- vector(mode = "list", length = J)
# for (j in 1:J) {
# dUdb[[j]] <- array(NA, dim = c(N, K, Q))
# dUds[[j]] <- array(NA, dim = c(N, nrow(stds), Q))
# }
# for (i in 1:nind) {
# if (panel) {
# anid <- theIds[i]
# theRows <- which(id == anid)
# }
# else theRows <- i
# for (q in 1:Q) {
# bq <- Makeh.rcoef(beta[, q], stds[, q], ranp,
# Omega[, ((i - 1) * R + 1):(i * R), drop = FALSE],
# correlation, Pi = NULL, Slist = NULL, mvar = NULL)
# for (j in 1:J) {
# dUdb[[j]][theRows, , q] <- tcrossprod(eta[[j]][theRows,
# , q, drop = TRUE], bq$d.mu)
# dUds[[j]][theRows, , q] <- tcrossprod(eta[[j]][theRows,
# , q, drop = TRUE], bq$d.sigma)
# }
# }
# }
# if (correlation) {
# vecX <- c()
# for (i in 1:K) {
# vecX <- c(vecX, i:K)
# }
# Xac <- lapply(X, function(x) x[, vecX])
# }
# else {
# Xac <- X
# }
# Xr <- lapply(X, function(x) x[, rep(1:K, Q)])
# Xacr <- lapply(Xac, function(x) x[, rep(1:ncol(Xac[[1]]),
# Q)])
# dUdb <- lapply(dUdb, function(x) matrix(x, nrow = N))
# dUds <- lapply(dUds, function(x) matrix(x, nrow = N))
# grad.beta <- suml(mapply("*", Xr, dUdb, SIMPLIFY = FALSE))/R
# grad.stds <- suml(mapply("*", Xacr, dUds, SIMPLIFY = FALSE))/R
# Qnq <- WPnq/Ln
# if (panel) {
# Wnq <- Wnq[id, ]
# H <- lapply(H, function(x) x[id, ])
# Qnq <- Qnq[id, ]
# }
# Wg <- vector(mode = "list", length = Q)
# IQ <- diag(Q)
# for (q in 1:Q) Wg[[q]] <- rowSums(Qnq * (repRows(IQ[q,
# ], N) - repCols(Wnq[, q], Q)))
# grad.gamma <- suml(mapply("*", H, Wg, SIMPLIFY = FALSE))
# gari <- cbind(grad.beta, grad.stds, grad.gamma)
# colnames(gari) <- names(theta)
# attr(lnL, "gradient") <- gari * weights
# }
# if (get.bi) {
# Pnjq <- lapply(Pntirq, function(x) apply(x, c(1, 3),
# mean))
# if (panel)
# Wnq <- Wnq[id, ]
# Pw <- lapply(Pnjq, function(x) x * Wnq)
# attr(lnL, "prob.alt") <- sapply(Pw, function(x) apply(x,
# 1, sum))
# attr(lnL, "prob.ind") <- Ln
# attr(lnL, "bi") <- bi
# attr(lnL, "Qir") <- Qir
# attr(lnL, "Wnq") <- Wnq
# }
# lnL
# },# weights = 1, R = 5, seed = 12345, ranp = c(pf = "n", cl = "n",
# loc = "n", wk = "n", tod = "n", seas = "n"), id = structure(c(1L,
# 1L, 1L, 1L, 1L), .Label = "1", class = "factor"), H = list(
# `1` = structure(0, .Dim = c(1L, 1L), .Dimnames = list(
# "1", "(class)2")), `2` = structure(1, .Dim = c(1L,
# 1L), .Dimnames = list("2", "(class)2"))), correlation = FALSE,
# haltons = NA, Q = 2)
# 11: eval(f, sys.frame(sys.parent()))
# 10: eval(f, sys.frame(sys.parent()))
# 9: callWithoutArgs(theta, fName = fName, args = names(formals(sumt)),
# ...)
# 8: (function (theta, fName, ...)
#
# 7: do.call(callWithoutSumt, argList)
# 6: maxOptim(fn = fn, grad = grad, hess = hess, start = start, method = "BFGS",
# fixed = fixed, constraints = constraints, finalHessian = finalHessian,
# parscale = parscale, control = mControl, ...)
# 5: maxRoutine(fn = logLik, grad = grad, hess = hess, start = start,
# constraints = constraints, ...)
# 4: maxLik(method = "bfgs", iterlim = 500, start = c(`class.1.2:(intercept)` = -4.85114128700713,
# `class.1.3:(intercept)` = -7.69322200825539, `class.1.4:(intercept)` = 5.01582959989182,
# class.1.pf = -1.60963678008691, class.1.cl = 0.109892050051351,
# class.1.loc = 18.3461318629584, class.1.wk = 5.01552145983325,
# class.1.tod = 6.12905713997904, class.1.seas = -4.37562129235275,
# `class.2.2:(intercept)` = -4.81114128700713, `class.2.3:(intercept)` = -7.6532220082554,
# `class.2.4:(intercept)` = 5.05582959989182, class.2.pf = -1.56963678008691,
# class.2.cl = 0.149892050051351, class.2.loc = 18.3861318629584,
# class.2.wk = 5.05552145983325, class.2.tod = 6.16905713997903,
# class.2.seas = -4.33562129235275, class.1.sd.pf = 0.08, class.1.sd.cl = 0.08,
# class.1.sd.loc = 0.08, class.1.sd.wk = 0.08, class.1.sd.tod = 0.08,
# class.1.sd.seas = 0.08, class.2.sd.pf = 0.12, class.2.sd.cl = 0.12,
# class.2.sd.loc = 0.12, class.2.sd.wk = 0.12, class.2.sd.tod = 0.12,
# class.2.sd.seas = 0.12, `(class)2` = 0), X = Xl, y = yl, gradient = gradient,
# weights = weights, logLik = ll.mnlogit, R = R, seed = seed,
# ranp = ranp, id = id, H = Hl, correlation = correlation,
# haltons = haltons, Q = Q)
# 3: eval(opt, sys.frame(which = nframe))
# 2: eval(opt, sys.frame(which = nframe))
# 1: gmnl(choice ~ pf + cl + loc + wk + tod + seas | 1 | 0 | 0 | 1,
# data = Electr, subset = 1:20, model = "mm", R = 5, panel = TRUE,
# ranp = c(pf = "n", cl = "n", loc = "n", wk = "n", tod = "n",
# seas = "n"), Q = 2, iterlim = 500)
Bonus2 :sessionInfo()
R version 4.1.2 (2021-11-01)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 19044)
Matrix products: default
attached base packages:
[1] grid stats graphics grDevices utils datasets
[7] methods base
other attached packages:
[1] here_1.0.1 strucchange_1.5-2 sandwich_3.0-1
[4] zoo_1.8-9 partykit_1.2-15 mvtnorm_1.1-3
[7] libcoin_1.0-9 mlogit_1.1-1 dfidx_0.0-4
[10] gmnl_1.1-3.2 Formula_1.2-4 maxLik_1.5-2
[13] miscTools_0.6-26 dplyr_1.0.7 nnet_7.3-17
Thank you in advance.
I am looking to have each frame of a scatter plot be filtered by another vector with a certain bin width and have it it roll through those. For example I can do this by:
library(ggplot2)
library(gganimate)
#example data
iris <- datasets::iris
#plot x and y
g <- ggplot(iris) + geom_point(aes(x = Petal.Width,y = Petal.Length))
#filter x and y by a third value with a bin width of 2 steping through by 0.5
g + transition_filter(transition_length = 1,
filter_length = 1,
4 < Sepal.Length & Sepal.Length < 6,
4.5 < Sepal.Length & Sepal.Length < 6.5,
5 < Sepal.Length & Sepal.Length < 7,
5.5 < Sepal.Length & Sepal.Length < 7.5,
6 < Sepal.Length & Sepal.Length < 8)
However - writing out each filter condition is tedious, and I would like to filter a different dataset with a ~20 binwidth steping through by 1 over a 300 point range so writing 100+ filters is not practical.
Is there another way to do this?
A while ago I wanted this exact function but didn't actually see anything in gganimate to do it, so I wrote something that would get the job done. Below is what I came up with, so I ended up rebuilding gganimate with this function included to avoid using :::.
I wrote this a while ago so I don't recall the exact intention of each argument at the moment of writing it (ALWAYS REMEMBER TO DOCUMENT YOUR CODE).
Here is what I recall
span : expression that can be evaluated within the data layers
size : how much data to be shown at once
enter_length/exit_length : Don't exactly recall how it works in relation to each other or size/span
range : a subset range
retain_data_order : logical - don't remember why this is here (sorry!)
library(gganimate)
#> Loading required package: ggplot2
library(rlang)
library(tweenr)
library(stringi)
get_row_event <- gganimate:::get_row_event
is_placeholder <- gganimate:::is_placeholder
recast_event_times <- gganimate:::recast_event_times
recast_times <- gganimate:::recast_times
TransitionSpan <- ggplot2::ggproto('TransitionSpan',
TransitionEvents,
finish_data = function (self, data, params)
{
lapply(data, function(d) {
split_panel <- stri_match(d$group, regex = "^(.+)<(.*)>(.*)$")
if (is.na(split_panel[1]))
return(list(d))
d$group <- match(d$group, unique(d$group))
empty_d <- d[0, , drop = FALSE]
d <- split(d, as.integer(split_panel[, 3]))
frames <- rep(list(empty_d), params$nframes)
frames[as.integer(names(d))] <- d
frames
})
},
setup_params = function(self, data, params) {
# browser()
params$start <- get_row_event(data, params$span_quo, "start")
time_class <- if (is_placeholder(params$start))
NULL
else params$start$class
end_quo <- expr(!!params$span_quo + diff(range(!!params$span_quo))*!!params$size_quo)
params$end <- get_row_event(data, end_quo, "end",
time_class)
params$enter_length <- get_row_event(data, params$enter_length_quo,
"enter_length", time_class)
params$exit_length <- get_row_event(data, params$exit_length_quo,
"exit_length", time_class)
params$require_stat <- is_placeholder(params$start) || is_placeholder(params$end) ||
is_placeholder(params$enter_length) || is_placeholder(params$exit_length)
static = lengths(params$start$values) == 0
params$row_id <- Map(function(st, end, en, ex, s) if (s)
character(0)
else paste(st, end, en, ex, sep = "_"), st = params$start$values,
end = params$end$values, en = params$enter_length$values,
ex = params$exit_length$values, s = static)
params
},
setup_params2 = function(self, data, params, row_vars) {
late_start <- FALSE
if (is_placeholder(params$start)) {
params$start <- get_row_event(data, params$start_quo, 'start', after = TRUE)
late_start <- TRUE
} else {
params$start$values <- lapply(row_vars$start, as.numeric)
}
size <- expr(!!params$size_quo)
time_class <- params$start$class
if (is_placeholder(params$end)) {
params$end <- get_row_event(data, params$end_quo, 'end', time_class, after = TRUE)
} else {
params$end$values <- lapply(row_vars$end, as.numeric)
}
if (is_placeholder(params$enter_length)) {
params$enter_length <- get_row_event(data, params$enter_length_quo, 'enter_length', time_class, after = TRUE)
} else {
params$enter_length$values <- lapply(row_vars$enter_length, as.numeric)
}
if (is_placeholder(params$exit_length)) {
params$exit_length <- get_row_event(data, params$exit_length_quo, 'exit_length', time_class, after = TRUE)
} else {
params$exit_length$values <- lapply(row_vars$exit_length, as.numeric)
}
times <- recast_event_times(params$start, params$end, params$enter_length, params$exit_length)
params$span_size <- diff(times$start$range)*eval_tidy(size)
range <- if (is.null(params$range)) {
low <- min(unlist(Map(function(start, enter) {
start - (if (length(enter) == 0) 0 else enter)
}, start = times$start$values, enter = times$enter_length$values)))
high <- max(unlist(Map(function(start, end, exit) {
(if (length(end) == 0) start else end) + (if (length(exit) == 0) 0 else exit)
}, start = times$start$values, end = times$end$values, exit = times$exit_length$values)))
range <- c(low, high)
} else {
if (!inherits(params$range, time_class)) {
stop('range must be given in the same class as time', call. = FALSE)
}
as.numeric(params$range)
}
full_length <- diff(range)
frame_time <- recast_times(
seq(range[1], range[2], length.out = params$nframes),
time_class
)
frame_length <- full_length / params$nframes
rep_frame <- round(params$span_size/frame_length)
lowerl <- c(rep(frame_time[1],rep_frame), frame_time[2:(params$nframes-rep_frame+1)])
upperl <- c(frame_time[1:(params$nframes-rep_frame)], rep(frame_time[params$nframes-rep_frame+1], rep_frame))
start <- lapply(times$start$values, function(x) {
round((params$nframes - 1) * (x - range[1])/full_length) + 1
})
end <- lapply(times$end$values, function(x) {
if (length(x) == 0) return(numeric())
round((params$nframes - 1) * (x - range[1])/full_length) + 1
})
enter_length <- lapply(times$enter_length$values, function(x) {
if (length(x) == 0) return(numeric())
round(x / frame_length)
})
exit_length <- lapply(times$exit_length$values, function(x) {
if (length(x) == 0) return(numeric())
round(x / frame_length)
})
params$range <- range
params$frame_time <- frame_time
static = lengths(start) == 0
params$row_id <- Map(function(st, end, en, ex, s) if (s) character(0) else paste(st, end, en, ex, sep = '_'),
st = start, end = end, en = enter_length, ex = exit_length, s = static)
params$lowerl <- lowerl
params$upperl <- upperl
params$frame_span <- upperl - lowerl
params$frame_info <- data.frame(
frame_time = frame_time,
lowerl = lowerl,
upperl = upperl,
frame_span = upperl - lowerl
)
params$nframes <- nrow(params$frame_info)
params
},
expand_panel = function(self, data, type, id, match, ease, enter, exit, params, layer_index) {
#browser()
row_vars <- self$get_row_vars(data)
if (is.null(row_vars))
return(data)
data$group <- paste0(row_vars$before, row_vars$after)
start <- as.numeric(row_vars$start)
end <- as.numeric(row_vars$end)
if (is.na(end[1]))
end <- NULL
enter_length <- as.numeric(row_vars$enter_length)
if (is.na(enter_length[1]))
enter_length <- NULL
exit_length <- as.numeric(row_vars$exit_length)
if (is.na(exit_length[1]))
exit_length <- NULL
data$.start <- start
all_frames <- tween_events(data, c(ease,"linear"),
params$nframes, !!start, !!end, c(1, params$nframes),
enter, exit, !!enter_length, !!exit_length)
if(params$retain_data_order){
all_frames <- all_frames[order(as.numeric(all_frames$.id)),]
} else {
all_frames <- all_frames[order(all_frames$.start, as.numeric(all_frames$.id)),]
}
all_frames$group <- paste0(all_frames$group, '<', all_frames$.frame, '>')
all_frames$.frame <- NULL
all_frames$.start <- NULL
all_frames
})
transition_span <- function(span, size = 0.5, enter_length = NULL, exit_length = NULL, range = NULL, retain_data_order = T){
span_quo <- enquo(span)
size_quo <- enquo(size)
enter_length_quo <- enquo(enter_length)
exit_length_quo <- enquo(exit_length)
gganimate:::require_quo(span_quo, "span")
ggproto(NULL, TransitionSpan,
params = list(span_quo = span_quo,
size_quo = size_quo, range = range, enter_length_quo = enter_length_quo,
exit_length_quo = exit_length_quo,
retain_data_order = retain_data_order))
}
g <- ggplot(iris) +
geom_point(aes(x = Petal.Width,y = Petal.Length, color = Sepal.Length)) +
viridis::scale_color_viridis()
a <- g + transition_span(Sepal.Length, .1, 1, 1)
animate(a, renderer = gganimate::gifski_renderer())
Created on 2021-08-11 by the reprex package (v2.0.0)
I have a database of matches with players and players'scores for each game. I am trying to create a rating variable for my prediction model. I am using formula from a blogpost.
Here is the dummy dataset:
df = data.frame(
matchid = c(1,1,1,1,1,1,1,1,1,1,2,2,2,2,2,2,2,2,2,2,3,3,3,3,3,3,3,3,3,3,4,4,4,4,4,4,4,4,4,4),
playerid = c(2,3,4,5,6,7,8,9,10,11,5,2,3,4,6,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,17,19,21,18,20,22,26,24,25,23),
point = c(52,38,34,33,16,19,16,8,10,2,38,37,31,34,21,18,18,13,9,-2,45,34,37,39,12,9,7,-3,-1,-8,47,38,31,17,26,32,28,17,16,9))
Here is my attempt using for loop. The for loop run extremely slow for 30000 games database. Please give me some pointers on how to improve this process / loop. I really have no idea.
## Initialize initial rating for each player
players_ratings = data.frame(playerid = unique(df$playerid),rating = 1000, stringsAsFactors = FALSE)
## Initialize unique matches
unique_matches = df$matchid %>% unique
## Matches with rating
relative_rating_matches = list(length(df))
### GENERATE RATING
for(index in 1:length(unique_matches)){
match = df %>% filter(matchid == unique_matches[[index]])
position = index
## UPDATE RATING
match = match %>% left_join(players_ratings,by = 'playerid')
relative_rating_matches[[position]] = match
print(match)
## BUILD ACTUAL RESULTS MATRIX
S = matrix(nrow = 10, ncol = 10)
rownames(S) = match$playerid
colnames(S) = match$playerid
for(i in 1:nrow(S)) {
for(j in 1:ncol(S)) {
player_row_point = as.numeric(match %>% filter(playerid == rownames(S)[i]) %>% select(point))
player_col_point = as.numeric(match %>% filter(playerid == colnames(S)[j]) %>% select(point))
S[i,j] = ifelse(player_col_point == player_row_point,0.5,
ifelse(player_col_point > player_row_point,1,0))
}
}
diag(S)= 0
print(S)
## BUILD EXPECTED WIN/LOSS MATRIX
E = matrix(nrow = 10, ncol = 10)
rownames(E) = match$playerid
colnames(E) = match$playerid
for(i in 1:nrow(E)) {
for(j in 1:ncol(E)) {
player_row_rating = as.numeric(match %>% filter(playerid == rownames(E)[i]) %>% select(rating))
player_col_rating = as.numeric(match %>% filter(playerid == colnames(E)[j]) %>% select(rating))
r = 1 + 10^((player_row_rating - player_col_rating)/400)
expected_result = 1/r
E[i,j] = expected_result
}
}
diag(E) = 0
print(E)
## GENERATE INCREMENTAL RATING
R = 20 * (S-E)
R = as.data.frame(colSums(R)) %>% rownames_to_column()
print(R)
## UPDATE EXISTING RATING DATABASE
for(i in 1:nrow(R)){
player_id = R[i,1]
incre_rating = ifelse(is.na(R[i,2]),0,R[i,2])
cur_rating = players_ratings[players_ratings$playerid == player_id,2]
players_ratings[players_ratings$playerid == player_id,2] = cur_rating + incre_rating
}
}
I have created a multipanel Taylor plot using openair package. I want to change the font size of 'correlation' and 'observed' and make it sentence case. I have used the following code:
TaylorDiagram(data, obs = "Observed", mod = "Predicted", group = "Method", type = "Station")
The task is now achieved by changing the source code of TaylorDiagram from openair package using the following code
library(lattice)
library(dplyr)
TaylorDiagram1 <- function(mydata, obs = "obs", mod = "mod", group = NULL, type = "default",
normalise = FALSE, cols = "brewer1",
rms.col = "darkgoldenrod", cor.col = "black", arrow.lwd = 3,
annotate = "centred\nRMS error",
key = TRUE, key.title = group, key.columns = 1,
key.pos = "right", strip = TRUE, auto.text = TRUE, ...) {
## get rid of R check annoyances
sd.mod <- R <- NULL
## greyscale handling
## set graphics
current.strip <- trellis.par.get("strip.background")
current.font <- trellis.par.get("fontsize")
## reset graphic parameters
on.exit(trellis.par.set(
fontsize = current.font
))
if (length(cols) == 1 && cols == "greyscale") {
trellis.par.set(list(strip.background = list(col = "white")))
## other local colours
method.col <- "greyscale"
} else {
method.col <- "default"
}
## extra.args setup
extra.args <- list(...)
## label controls (some local xlab, ylab management in code)
extra.args$xlab <- if ("xlab" %in% names(extra.args)) {
quickText(extra.args$xlab, auto.text)
} else {
NULL
}
extra.args$ylab <- if ("ylab" %in% names(extra.args)) {
quickText(extra.args$ylab, auto.text)
} else {
NULL
}
extra.args$main <- if ("main" %in% names(extra.args)) {
quickText(extra.args$main, auto.text)
} else {
quickText("", auto.text)
}
if ("fontsize" %in% names(extra.args)) {
trellis.par.set(fontsize = list(text = extra.args$fontsize))
}
if (!"layout" %in% names(extra.args)) {
extra.args$layout <- NULL
}
if (!"pch" %in% names(extra.args)) {
extra.args$pch <- 20
}
if (!"cex" %in% names(extra.args)) {
extra.args$cex <- 2
}
## #######################################################################################
## check to see if two data sets are present
combine <- FALSE
if (length(mod) == 2) combine <- TRUE
if (any(type %in% dateTypes)) {
vars <- c("date", obs, mod)
} else {
vars <- c(obs, mod)
}
## assume two groups do not exist
twoGrp <- FALSE
if (!missing(group)) if (any(group %in% type)) stop("Can't have 'group' also in 'type'.")
mydata <- cutData(mydata, type, ...)
if (missing(group)) {
if ((!"group" %in% type) & (!"group" %in% c(obs, mod))) {
mydata$group <- factor("group")
group <- "group"
npol <- 1
}
## don't overwrite a
} else { ## means that group is there
mydata <- cutData(mydata, group, ...)
}
## if group is present, need to add that list of variables unless it is
## a pre-defined date-based one
if (!missing(group)) {
npol <- length(unique((mydata[[group[1]]])))
## if group is of length 2
if (length(group) == 2L) {
twoGrp <- TRUE
grp1 <- group[1]
grp2 <- group[2]
if (missing(key.title)) key.title <- grp1
vars <- c(vars, grp1, grp2)
mydata$newgrp <- paste(mydata[[group[1]]], mydata[[group[2]]], sep = "-")
group <- "newgrp"
}
if (group %in% dateTypes | any(type %in% dateTypes)) {
vars <- unique(c(vars, "date", group))
} else {
vars <- unique(c(vars, group))
}
}
## data checks, for base and new data if necessary
mydata <- checkPrep(mydata, vars, type)
# check mod and obs are numbers
mydata <- checkNum(mydata, vars = c(obs, mod))
## remove missing data
mydata <- na.omit(mydata)
legend <- NULL
## function to calculate stats for TD
calcStats <- function(mydata, obs = obs, mod = mod) {
R <- cor(mydata[[obs]], mydata[[mod]], use = "pairwise")
sd.obs <- sd(mydata[[obs]])
sd.mod <- sd(mydata[[mod]])
if (normalise) {
sd.mod <- sd.mod / sd.obs
sd.obs <- 1
}
res <- data.frame(R, sd.obs, sd.mod)
res
}
vars <- c(group, type)
results <- group_by(mydata, UQS(syms(vars))) %>%
do(calcStats(., obs = obs, mod = mod[1]))
results.new <- NULL
if (combine) {
results.new <- group_by(mydata, UQS(syms(vars))) %>%
do(calcStats(., obs = obs, mod = mod[2]))
}
## if no group to plot, then add a dummy one to make xyplot work
if (is.null(group)) {
results$MyGroupVar <- factor("MyGroupVar")
group <- "MyGroupVar"
}
## set up colours
myColors <- openColours(cols, npol)
pch.orig <- extra.args$pch
## combined colours if two groups
if (twoGrp) {
myColors <- rep(
openColours(cols, length(unique(mydata[[grp1]]))),
each = length(unique(mydata[[grp2]]))
)
extra.args$pch <- rep(extra.args$pch, each = length(unique(mydata[[grp2]])))
}
## basic function for lattice call + defaults
temp <- paste(type, collapse = "+")
myform <- formula(paste("R ~ sd.mod", "|", temp, sep = ""))
scales <- list(x = list(rot = 0), y = list(rot = 0))
pol.name <- sapply(levels(mydata[, group]), function(x) quickText(x, auto.text))
if (key & npol > 1 & !combine) {
thecols <- unique(myColors)
if (twoGrp) {
pol.name <- levels(factor(mydata[[grp1]]))
}
key <- list(
points = list(col = thecols), pch = pch.orig,
cex = extra.args$cex, text = list(lab = pol.name, cex = 0.8),
space = key.pos, columns = key.columns,
title = quickText(key.title, auto.text),
cex.title = 0.8, lines.title = 3
)
} else if (key & npol > 1 & combine) {
key <- list(
lines = list(col = myColors[1:npol]), lwd = arrow.lwd,
text = list(lab = pol.name, cex = 0.8), space = key.pos,
columns = key.columns,
title = quickText(key.title, auto.text),
cex.title = 0.8, lines.title = 3
)
} else {
key <- NULL
}
## special wd layout
if (length(type) == 1 & type[1] == "wd" & is.null(extra.args$layout)) {
## re-order to make sensible layout
wds <- c("NW", "N", "NE", "W", "E", "SW", "S", "SE")
mydata$wd <- ordered(mydata$wd, levels = wds)
## see if wd is actually there or not
wd.ok <- sapply(wds, function(x) {
if (x %in% unique(mydata$wd)) FALSE else TRUE
})
skip <- c(wd.ok[1:4], TRUE, wd.ok[5:8])
mydata$wd <- factor(mydata$wd) ## remove empty factor levels
extra.args$layout <- c(3, 3)
if (!"skip" %in% names(extra.args)) {
extra.args$skip <- skip
}
}
if (!"skip" %in% names(extra.args)) {
extra.args$skip <- FALSE
}
## proper names of labelling ####################################################
stripName <- sapply(levels(mydata[, type[1]]), function(x) quickText(x, auto.text))
if (strip) strip <- strip.custom(factor.levels = stripName)
if (length(type) == 1) {
strip.left <- FALSE
} else { ## two conditioning variables
stripName <- sapply(levels(mydata[, type[2]]), function(x) quickText(x, auto.text))
strip.left <- strip.custom(factor.levels = stripName)
}
## #############################################################################
## no strip needed for single panel
if (length(type) == 1 & type[1] == "default") strip <- FALSE
## not sure how to evaluate "group" in xyplot, so change to a fixed name
id <- which(names(results) == group)
names(results)[id] <- "MyGroupVar"
maxsd <- 1.2 * max(results$sd.obs, results$sd.mod)
# xlim, ylim handling
if (!"ylim" %in% names(extra.args)) {
extra.args$ylim <- 1.12 * c(0, maxsd)
}
if (!"xlim" %in% names(extra.args)) {
extra.args$xlim <- 1.12 * c(0, maxsd)
}
## xlab, ylab local management
if (is.null(extra.args$ylab)) {
extra.args$ylab <- if (normalise) "standard deviation (normalised)" else "Standard deviation"
}
if (is.null(extra.args$xlab)) {
extra.args$xlab <- extra.args$ylab
}
## plot
xyplot.args <- list(
x = myform, data = results, groups = results$MyGroupVar,
aspect = 1,
type = "n",
as.table = TRUE,
scales = scales,
key = key,
par.strip.text = list(cex = 0.8),
strip = strip,
strip.left = strip.left,
panel = function(x, y, ...) {
## annotate each panel but don't need to do this for each grouping value
panel.taylor.setup(
x, y,
results = results, maxsd = maxsd,
cor.col = cor.col, rms.col = rms.col,
annotate = annotate, ...
)
## plot data in each panel
panel.superpose(
x, y,
panel.groups = panel.taylor, ...,
results = results, results.new = results.new,
combine = combine, myColors = myColors,
arrow.lwd = arrow.lwd
)
}
)
## reset for extra.args
xyplot.args <- listUpdate(xyplot.args, extra.args)
## plot
plt <- do.call(xyplot, xyplot.args)
if (length(type) == 1) plot(plt) else plot(useOuterStrips(plt, strip = strip, strip.left = strip.left))
newdata <- results
output <- list(plot = plt, data = newdata, call = match.call())
class(output) <- "openair"
invisible(output)
}
panel.taylor.setup <- function(x, y, subscripts, results, maxsd, cor.col, rms.col,
col.symbol, annotate, group.number, type, ...) {
## note, this assumes for each level of type there is a single measured value
## therefore, only the first is used i.e. results$sd.obs[subscripts[1]]
## This does not matter if normalise = TRUE because all sd.obs = 1.
## The data frame 'results' should contain a grouping variable 'MyGroupVar',
## 'type' e.g. season, R (correlation coef), sd.obs and sd.mod
xcurve <- cos(seq(0, pi / 2, by = 0.01)) * maxsd
ycurve <- sin(seq(0, pi / 2, by = 0.01)) * maxsd
llines(xcurve, ycurve, col = "black")
xcurve <- cos(seq(0, pi / 2, by = 0.01)) * results$sd.obs[subscripts[1]]
ycurve <- sin(seq(0, pi / 2, by = 0.01)) * results$sd.obs[subscripts[1]]
llines(xcurve, ycurve, col = "black", lty = 5)
corr.lines <- c(0.2, 0.4, 0.6, 0.8, 0.9)
## grid line with alpha transparency
theCol <- t(col2rgb(cor.col)) / 255
for (gcl in corr.lines) llines(
c(0, maxsd * gcl), c(0, maxsd * sqrt(1 - gcl ^ 2)),
col = rgb(theCol, alpha = 0.4), alpha = 0.5
)
bigtick <- acos(seq(0.1, 0.9, by = 0.1))
medtick <- acos(seq(0.05, 0.95, by = 0.1))
smltick <- acos(seq(0.91, 0.99, by = 0.01))
lsegments(
cos(bigtick) * maxsd, sin(bigtick) *
maxsd, cos(bigtick) * 0.96 * maxsd, sin(bigtick) * 0.96 * maxsd,
col = cor.col
)
lsegments(
cos(medtick) * maxsd, sin(medtick) *
maxsd, cos(medtick) * 0.98 * maxsd, sin(medtick) * 0.98 * maxsd,
col = cor.col
)
lsegments(
cos(smltick) * maxsd, sin(smltick) *
maxsd, cos(smltick) * 0.99 * maxsd, sin(smltick) * 0.99 * maxsd,
col = cor.col
)
## arcs for standard deviations (3 by default)
gamma <- pretty(c(0, maxsd), n = 5)
if (gamma[length(gamma)] > maxsd) {
gamma <- gamma[-length(gamma)]
}
labelpos <- seq(45, 70, length.out = length(gamma))
## some from plotrix
for (gindex in 1:length(gamma)) {
xcurve <- cos(seq(0, pi, by = 0.03)) * gamma[gindex] +
results$sd.obs[subscripts[1]]
endcurve <- which(xcurve < 0)
endcurve <- ifelse(length(endcurve), min(endcurve) - 1, 105)
ycurve <- sin(seq(0, pi, by = 0.03)) * gamma[gindex]
maxcurve <- xcurve * xcurve + ycurve * ycurve
startcurve <- which(maxcurve > maxsd * maxsd)
startcurve <- ifelse(length(startcurve), max(startcurve) + 1, 0)
llines(
xcurve[startcurve:endcurve], ycurve[startcurve:endcurve],
col = rms.col, lty = 5
)
ltext(
xcurve[labelpos[gindex]], ycurve[labelpos[gindex]],
gamma[gindex],
cex = 0.7, col = rms.col, pos = 1,
srt = 0, font = 2
)
ltext(
1.1 * maxsd, 1.05 * maxsd,
labels = annotate, cex = 0.7,
col = rms.col, pos = 2
)
}
## angles for R key
angles <- 180 * c(bigtick, acos(c(0.95, 0.99))) / pi
ltext(
cos(c(bigtick, acos(c(0.95, 0.99)))) *
1.06 * maxsd, sin(c(bigtick, acos(c(0.95, 0.99)))) *
1.06 * maxsd, c(seq(0.1, 0.9, by = 0.1), 0.95, 0.99),
cex = 0.7,
adj = 0.5, srt = angles, col = cor.col
)
ltext(
0.82 * maxsd, 0.82 * maxsd, "Correlation",
srt = 315, cex = 0.7,
col = cor.col
)
## measured point and text
lpoints(results$sd.obs[subscripts[1]], 0, pch = 20, col = "purple", cex = 1.5)
ltext(results$sd.obs[subscripts[1]], 0, "Observed", col = "purple", cex = 0.7, pos = 3)
}
panel.taylor <- function(x, y, subscripts, results, results.new, maxsd, cor.col,
rms.col, combine, col.symbol, myColors, group.number,
type, arrow.lwd, ...) {
R <- NULL
sd.mod <- NULL ## avoid R NOTEs
## Plot actual results by type and group if given
results <- transform(results, x = sd.mod * R, y = sd.mod * sin(acos(R)))
if (combine) {
results.new <- transform(results.new, x = sd.mod * R, y = sd.mod * sin(acos(R)))
larrows(
results$x[subscripts], results$y[subscripts],
results.new$x[subscripts], results.new$y[subscripts],
angle = 30, length = 0.1, col = myColors[group.number], lwd = arrow.lwd
)
} else {
lpoints(
results$x[subscripts], results$y[subscripts],
col.symbol = myColors[group.number], ...
)
}
}
startYear <- function(dat) as.numeric(format(min(dat[order(dat)]), "%Y"))
endYear <- function(dat) as.numeric(format(max(dat[order(dat)]), "%Y"))
startMonth <- function(dat) as.numeric(format(min(dat[order(dat)]), "%m"))
endMonth <- function(dat) as.numeric(format(max(dat[order(dat)]), "%m"))
## these are pre-defined type that need a field "date"; used by cutData
dateTypes <- c("year", "hour", "month", "season", "weekday", "weekend",
"monthyear", "gmtbst", "bstgmt", "dst", "daylight",
"seasonyear", "yearseason")
checkPrep <- function(mydata, Names, type, remove.calm = TRUE, remove.neg = TRUE,
strip.white = TRUE, wd = "wd") {
## deal with conditioning variable if present, if user-defined, must exist in data
## pre-defined types
## existing conditioning variables that only depend on date (which is checked)
conds <- c(
"default", "year", "hour", "month", "season", "weekday",
"weekend", "monthyear", "gmtbst", "bstgmt", "dst", "daylight",
"yearseason", "seasonyear"
)
all.vars <- unique(c(names(mydata), conds))
varNames <- c(Names, type) ## names we want to be there
matching <- varNames %in% all.vars
if (any(!matching)) {
## not all variables are present
stop(cat("Can't find the variable(s)", varNames[!matching], "\n"))
}
## add type to names if not in pre-defined list
if (any(type %in% conds == FALSE)) {
ids <- which(type %in% conds == FALSE)
Names <- c(Names, type[ids])
}
## if type already present in data frame
if (any(type %in% names(mydata))) {
ids <- which(type %in% names(mydata))
Names <- unique(c(Names, type[ids]))
}
## just select data needed
mydata <- mydata[, Names]
## if site is in the data set, check none are missing
## seems to be a problem for some KCL data...
if ("site" %in% names(mydata)) { ## split by site
## remove any NA sites
if (anyNA(mydata$site)) {
id <- which(is.na(mydata$site))
mydata <- mydata[-id, ]
}
}
## sometimes ratios are considered which can results in infinite values
## make sure all infinite values are set to NA
mydata[] <- lapply(mydata, function(x) {
replace(x, x == Inf | x == -Inf, NA)
})
if ("ws" %in% Names) {
if ("ws" %in% Names & is.numeric(mydata$ws)) {
## check for negative wind speeds
if (any(sign(mydata$ws[!is.na(mydata$ws)]) == -1)) {
if (remove.neg) { ## remove negative ws only if TRUE
warning("Wind speed <0; removing negative data")
mydata$ws[mydata$ws < 0] <- NA
}
}
}
}
## round wd to make processing obvious
## data already rounded to nearest 10 degress will not be affected
## data not rounded will be rounded to nearest 10 degrees
## assumes 10 is average of 5-15 etc
if (wd %in% Names) {
if (wd %in% Names & is.numeric(mydata[, wd])) {
## check for wd <0 or > 360
if (any(sign(mydata[[wd]][!is.na(mydata[[wd]])]) == -1 |
mydata[[wd]][!is.na(mydata[[wd]])] > 360)) {
warning("Wind direction < 0 or > 360; removing these data")
mydata[[wd]][mydata[[wd]] < 0] <- NA
mydata[[wd]][mydata[[wd]] > 360] <- NA
}
if (remove.calm) {
if ("ws" %in% names(mydata)) {
mydata[[wd]][mydata$ws == 0] <- NA ## set wd to NA where there are calms
mydata$ws[mydata$ws == 0] <- NA ## remove calm ws
}
mydata[[wd]][mydata[[wd]] == 0] <- 360 ## set any legitimate wd to 360
## round wd for use in functions - except windRose/pollutionRose
mydata[[wd]] <- 10 * ceiling(mydata[[wd]] / 10 - 0.5)
mydata[[wd]][mydata[[wd]] == 0] <- 360 # angles <5 should be in 360 bin
}
mydata[[wd]][mydata[[wd]] == 0] <- 360 ## set any legitimate wd to 360
}
}
## make sure date is ordered in time if present
if ("date" %in% Names) {
if ("POSIXlt" %in% class(mydata$date)) {
stop("date should be in POSIXct format not POSIXlt")
}
## if date in format dd/mm/yyyy hh:mm (basic check)
if (length(grep("/", as.character(mydata$date[1]))) > 0) {
mydata$date <- as.POSIXct(strptime(mydata$date, "%d/%m/%Y %H:%M"), "GMT")
}
## try and work with a factor date - but probably a problem in original data
if (is.factor(mydata$date)) {
warning("date field is a factor, check date format")
mydata$date <- as.POSIXct(mydata$date, "GMT")
}
mydata <- arrange(mydata, date)
## make sure date is the first field
if (names(mydata)[1] != "date") {
mydata <- mydata[c("date", setdiff(names(mydata), "date"))]
}
## check to see if there are any missing dates, stop if there are
ids <- which(is.na(mydata$date))
if (length(ids) > 0) {
mydata <- mydata[-ids, ]
warning(paste(
"Missing dates detected, removing",
length(ids), "lines"
), call. = FALSE)
}
## daylight saving time can cause terrible problems - best avoided!!
if (any(dst(mydata$date))) {
warning("Detected data with Daylight Saving Time, converting to UTC/GMT")
mydata$date <- lubridate::force_tz(mydata$date, tzone = "GMT")
}
}
if (strip.white) {
## set panel strip to white
suppressWarnings(trellis.par.set(list(strip.background = list(col = "white"))))
}
## return data frame
return(mydata)
}
# function to check variables are numeric, if not force with warning
checkNum <- function(mydata, vars) {
for (i in seq_along(vars)) {
if (!is.numeric(mydata[[vars[i]]])) {
mydata[[vars[i]]] <- as.numeric(as.character(mydata[[vars[i]]]))
warning(
paste(vars[i], "is not numeric, forcing to numeric..."),
call. = FALSE
)
}
}
return(mydata)
}
## listUpdate function
# [in development]
listUpdate <- function(a, b, drop.dots = TRUE,
subset.a = NULL, subset.b = NULL) {
if (drop.dots) {
a <- a[names(a) != "..."]
b <- b[names(b) != "..."]
}
if (!is.null(subset.a)) {
a <- a[names(a) %in% subset.a]
}
if (!is.null(subset.b)) {
b <- b[names(b) %in% subset.b]
}
if (length(names(b) > 0)) {
a <- modifyList(a, b)
}
a
}
Then use the following code for plotting
TaylorDiagram1(data, obs = "Observed", mod = "Predicted", group = "Method", type = "Station",
scales=list(alternating=1),normalise = TRUE,fontsize=12,
rms.col="black",auto.text=F,xlab="Standard deviation",
cex = 1, ylab="Standard deviation",par.settings = list( grid.pars = list(fontfamily = "serif")))
I am trying to inplement a simulation of a simplified blackjack game that will return the best policy at each state s.
The blackjack simulation seems to work properly, but i somehow get an error when trying to apply the Q learning algorithm to reach the optimal policy.
Here's my code, i believe it's well documented, error is in the Q-learning block, starting at ~line 170, it is also reproducible :
#Application reinforcement learning for black jack. We will suppose here that the croupier only has 1 pack of cards
#Initial tabs
packinit = c(rep(1,4), rep(2,4),rep(3,4),rep(4,4),rep(5,4),rep(6,4),rep(7,4),rep(8,4),
rep(9,4),rep(10,16))
#In our game and for simplicifaction of the problem, aces will always count as 1. Other figures are worth 10.
#If both player and croupier have same score, then player looses.
#Croupier will draw cards until he has 17 or more.
handPinit = NULL # will contain hand of player
handCinit = NULL # will contain hand of the croupier
list = list(handPinit, handCinit, packinit)
# Methods ####################################################################################
##############################################################################################
#Random integer, returns an integer to choose card
randInt = function(pack){
int = runif(1) * length(pack)
int = int+1
int = round(int, 0)
return(int)
}
#Picks a card, asimResults it to the desired hand and deletes it from the package.
pickC = function(hand, pack){
int = randInt(pack)
hand = c(hand, pack[int])
pack = pack[-int]
return(list(hand, pack))
}
score = function(handC){
return(sum(handC, na.rm = T))
}
printWinner = function(resultList){
res = resultList[[4]]
p = res[1]
c = res[2]
if((p > c && p <= 21) || (p <= 21 && c > 21)){
cat("Player has won with ", p, ", croupier has ", c, ".\n", sep = "")
}else{
cat("Player has lost with ", p, ", croupier has ", c, ".\n", sep = "")
}
}
#Black jack sim :
simulation = function(handP, handC, pack){
#Matrix to stock choice and next state, 1st is state, 2nd is choice, 3rd is reward, 4th is start state
cs = NULL
#pick first card
temp = NULL
temp = pickC(handP, pack)
handP = temp[[1]]
pack = temp[[2]]
temp = pickC(handC, pack)
handC = temp[[1]]
pack = temp[[2]]
#stock result
cs = rbind(cs, c(score(handP), 1, 0.1, 0))
#pick second card
temp = pickC(handP, pack)
handP = temp[[1]]
pack = temp[[2]]
temp = pickC(handC, pack)
handC = temp[[1]]
pack = temp[[2]]
#stock result
cs = rbind(cs, c(score(handP), 1, 0.1, cs[length(cs[,1]), 1]))
#reward stock final
reward = NULL
#to change with algo decision
while(score(handC) < 17){
#rand number to choose action, 1 = draw
rand = round(2*runif(1),0)
#if a = 1, draw a card
if(rand == 1 && score(handP) < 21){
temp = pickC(handP, pack)
handP = temp[[1]]
pack = temp[[2]]
cs = rbind(cs, c(score(handP), 1, 0.1, cs[length(cs[,1]), 1] ))
}else{
cs = rbind(cs, c(score(handP), 0, 0.1, cs[length(cs[,1]), 1]))
}
#if croupier < 17, he draws a card
if(score(handC) < 17){
temp = pickC(handC, pack)
handC = temp[[1]]
pack = temp[[2]]
}
}
#get scores
scores = c(score(handP), score(handC))
resultList = list(handP, handC, pack, scores)
#get reward
res = resultList[[4]]
p = res[1]
c = res[2]
if((p > c && p <= 21) || (p <= 21 && c > 21)){
reward = 100
}else{
reward = -50
}
#AsimResults reward as the reward of the last line of cs
cs[length(cs[,1]), 3] = reward
#return full list
resultList = list(handP, handC, pack, scores, cs)
return(resultList)
}
#Function for simulation, outputs tab containins states, actions and choices
simRand = function(k){
resultsRand = NULL
for(i in 1:k){
#init pack and hands
pack = c(rep(1,4), rep(2,4),rep(3,4),rep(4,4),rep(5,4),rep(6,4),rep(7,4),rep(8,4),
rep(9,4),rep(10,16))
handC = NULL
handP = NULL
#simulation k
res = simulation(handP, handC, pack)
resultsRand = rbind(resultsRand, res[[5]])
#resets for next iteration
pack = c(rep(1,4), rep(2,4),rep(3,4),rep(4,4),rep(5,4),rep(6,4),rep(7,4),rep(8,4),
rep(9,4),rep(10,16))
handC = NULL
handP = NULL
}
return(resultsRand)
}
#test
for(i in 1:10){
results = simulation(handPinit, handCinit, packinit)
printWinner(results)
}
#used to max the Qvalue decision
getRowMax = function(tab){
temp = tab[1]
for(i in 2:length(tab)){
if(tab[i] > temp){
temp = tab[i]
}
}
}
#####################################################################
#Q-learning
#####################################################################
#Represent sets of Q(s, a)
Qvalues = matrix(1, nrow = 30, ncol = 2)
simResults = simRand(1000)
#Hyperparameters
alpha = 0.9
discount = 0.1
#for all rows simulated, update qvalues.
for(i in 1:length(simResults[,1])){
st = simResults[i, 4] #st
a = simResults[i, 2] #a
stPlusOne = simResults[i, 1] #st+1
Qvalues[st, a] = Qvalues[st, a] + alpha * ( simResults[i,3] * discount * getRowMax(Qvalues[stPlusOne, ]) - Qvalues[st, a] )
}
As LucyMLi points out:
First you need to add return(temp) object to the getRowMax function.
But there is another issue with your simulation, because some of the
values in simResults[, 1] are 0, which means Qvalues[stPlusOne, ] will
be empty and thus you can't compute getRowMax().