library("igraph")
library("fpc")
library("popbio")

source("http://www.cis.jhu.edu/~parky/HSBM/ccc_utils.r")

## Reproducible simulation setup: R blocks, each with a random size in
## {300, 400, ..., 700} and a random motif label in {1, 2, 3}.
set.seed(1234)
R <- 8
nR.vec <- 200 + 100 * ceiling(runif(R) * 5)
nR.vec
# [1] 300 600 600 600 700 600 300 400
## Motif labels for the SBM
motif.vec <- sample.int(3, R, replace = TRUE)
motif.vec
# [1] 2 2 3 2 1 3 1 3
## Per-vertex labels: the vertex's block index (vlab) and its block's motif (mlab)
vlab <- rep(seq_len(R), times = nR.vec)
mlab <- rep(motif.vec, times = nR.vec)

## Between-block edge probability (passed to sample_hierarchical_sbm below).
p <- 0.01

## Per-motif SBM templates; element m describes motif label m:
##   motif.B[[m]]   - 3 x 3 subblock connection-probability matrix
##   motif.rho[[m]] - fraction of a block's vertices in each of its 3 subblocks
motif.B <- local({
    with_diag <- function(off, d) {
        B <- matrix(off, 3, 3)
        diag(B) <- d
        B
    }
    list(
        with_diag(0.25, 0.4),                  # motif 1: uniform assortative
        with_diag(0.2, c(0.25, 0.8, 0.25)),    # motif 2: dense subblock 2
        with_diag(0.25, c(0.3, 0.3, 0.7))      # motif 3: dense subblock 3
    )
})
motif.rho <- list(
    c(0.25, 0.5, 0.25),
    c(0.3, 0.4, 0.3),
    c(0.4, 0.2, 0.4)
)

## Reject unknown motif labels instead of silently treating them as motif 3
## (the previous if/else chain fell through to its last branch).
stopifnot(all(motif.vec %in% seq_along(motif.B)))
## Block r gets the template of its motif label motif.vec[r].
B.list <- motif.B[motif.vec]
rho.list <- motif.rho[motif.vec]
## Show the per-block subblock proportions and connection matrices. Blocks with
## the same motif label share identical parameters. The comments below record
## the output of the original run.
rho.list
# [[1]]
# [1] 0.3 0.4 0.3
# 
# [[2]]
# [1] 0.3 0.4 0.3
# 
# [[3]]
# [1] 0.4 0.2 0.4
# 
# [[4]]
# [1] 0.3 0.4 0.3
# 
# [[5]]
# [1] 0.25 0.50 0.25
# 
# [[6]]
# [1] 0.4 0.2 0.4
# 
# [[7]]
# [1] 0.25 0.50 0.25
# 
# [[8]]
# [1] 0.4 0.2 0.4
B.list
# [[1]]
#      [,1] [,2] [,3]
# [1,] 0.25  0.2 0.20
# [2,] 0.20  0.8 0.20
# [3,] 0.20  0.2 0.25
# 
# [[2]]
#      [,1] [,2] [,3]
# [1,] 0.25  0.2 0.20
# [2,] 0.20  0.8 0.20
# [3,] 0.20  0.2 0.25
# 
# [[3]]
#      [,1] [,2] [,3]
# [1,] 0.30 0.25 0.25
# [2,] 0.25 0.30 0.25
# [3,] 0.25 0.25 0.70
# 
# [[4]]
#      [,1] [,2] [,3]
# [1,] 0.25  0.2 0.20
# [2,] 0.20  0.8 0.20
# [3,] 0.20  0.2 0.25
# 
# [[5]]
#      [,1] [,2] [,3]
# [1,] 0.40 0.25 0.25
# [2,] 0.25 0.40 0.25
# [3,] 0.25 0.25 0.40
# 
# [[6]]
#      [,1] [,2] [,3]
# [1,] 0.30 0.25 0.25
# [2,] 0.25 0.30 0.25
# [3,] 0.25 0.25 0.70
# 
# [[7]]
#      [,1] [,2] [,3]
# [1,] 0.40 0.25 0.25
# [2,] 0.25 0.40 0.25
# [3,] 0.25 0.25 0.40
# 
# [[8]]
#      [,1] [,2] [,3]
# [1,] 0.30 0.25 0.25
# [2,] 0.25 0.30 0.25
# [3,] 0.25 0.25 0.70
## Subblock labels. Block x owns consecutive subblock ids
## sub.offset[x] + 1, ..., sub.offset[x] + length(rho.list[[x]]), which is
## 3 per block here. Each vertex is given the id of its subblock.
## Subblock sizes are rounded because rep() truncates non-integer `times`.
## For example, 0.57 * 100 is 56.99999999999999 and would become 56.
svec <- lapply(seq_len(R), function(x) round(rho.list[[x]] * nR.vec[x]))
sub.offset <- cumsum(c(0, lengths(svec)))
slab <- unlist(lapply(seq_len(R), function(x) {
    rep(seq_along(svec[[x]]), times = svec[[x]]) + sub.offset[x]
}))
if (length(slab) != sum(nR.vec)) {
    warning("subblock sizes do not sum to the block sizes nR.vec", call. = FALSE)
}
    

## Sample the hierarchical SBM (see igraph's sample_hierarchical_sbm): the
## sum(nR.vec) vertices are split into blocks of sizes nR.vec. Block r uses
## proportions rho.list[[r]] and matrix B.list[[r]], and p is passed as the
## between-block probability.
g <- sample_hierarchical_sbm(sum(nR.vec), nR.vec, rho.list, B.list, p)
## One colour per motif label, looked up for each block.
mycol <- rainbow(max(mlab))[motif.vec]
## sprintf avoids the stray spaces that paste()'s default sep = " " added.
## The subblock count m is derived from rho.list instead of being hard-coded.
plotmemb(g[], vlab,
         main = sprintf("A, R = %d, m = %d", max(vlab), max(lengths(rho.list))),
         drawborder = TRUE, lwd = .01, lcol = mycol, lwdb = 2)
## Step 1: adjacency spectral embedding of g into dmax dimensions, then pick
## the embedding dimension dhat from elbows of the column norms of Xhat.
dmax <- 50
Xhat <- embed_adjacency_matrix(g, dmax, options = list(maxiter = 10000))$X
## Column norms of Xhat. The name `eval` masks base::eval only as a variable;
## function calls such as eval(...) still resolve to the base function.
eval <- sqrt(colSums(Xhat^2))
## getElbows() comes from ccc_utils.r; the recorded output shows it returning
## 3 candidate dimensions. Use FALSE, not F, because F can be reassigned.
(dhat <- getElbows(eval, 3, plot = FALSE))
# [1]  8 14 16
## Keep the first elbow as the embedding dimension.
dhat <- dhat[1]

## Row-normalize the first dhat embedding coordinates, projecting each vertex
## onto the unit sphere.
## drop = FALSE keeps a one-column matrix when dhat == 1, so rowSums() still
## works. Zero-norm rows stay zero rows instead of becoming NaN.
Xd <- Xhat[, seq_len(dhat), drop = FALSE]
row.norm <- sqrt(rowSums(Xd^2))
row.norm[row.norm == 0] <- 1
sXhat <- Xd / row.norm
## Try between floor(Rmax / 2) and Rmax vertex clusters. With usepam = FALSE,
## pamk() clusters with clara rather than pam.
Rmax <- 1.5 * dhat
pamkout <- pamk(sXhat, krange = floor(Rmax / 2):Rmax, usepam = FALSE)
pamkout$nc
# [1] 8
## Full component names: avoid relying on `$` partial matching.
membp <- pamkout$pamobject$clustering
(tablep <- table(membp))
# membp
#   1   2   3   4   5   6   7   8 
# 300 600 600 600 700 600 300 400
mycol2 <- rainbow(max(membp))
plotmemb(g[], membp, main = paste0("Skmeans, Rhat = ", max(membp)),
         drawborder = TRUE, lwd = .01, lcol = mycol2, lwdb = 2)
## Step 2: build a matrix S between the estimated blocks, then cluster the
## blocks into motifs.
## NOTE(review): reembed() and computeS() come from ccc_utils.r. Presumably
## reembed() returns one 3-dimensional embedding per cluster of membp, and
## computeS() compares them with bandwidth sigma. Confirm against that file.
sigma <- 0.5
X.list <- reembed(g[], 3, membp)
S <- computeS(X.list, sigma)
## Label the rows and columns of S with the estimated block ids.
colnames(S) <- 1:pamkout$nc
rownames(S) <- 1:pamkout$nc
## Heat map of S with rows and columns ordered by the true motif labels.
## This assumes estimated block k matches true block k, as in the recorded run.
my.image2(S[order(motif.vec), order(motif.vec)],
          text.cex = 0, round = 2, srt = 0,
          border = "gray70", label.cex = 1.5)
## pamk() treats S as a dissimilarity matrix and tries 2 .. (#blocks - 1) motifs.
graphs.cluster <- pamk(S, krange = 2:(length(X.list) - 1), diss = TRUE)
Yhat <- graphs.cluster[["pamobject"]][["clustering"]]
numc <- graphs.cluster[["nc"]]
numc
# [1] 3
## Step 3: for each estimated motif i, pool the re-embedded points of its
## member blocks into one point cloud Xi. Each block after the first is
## aligned to the points pooled so far via find.transform() from ccc_utils.r.
## Xi is then clustered with pamk(), which gives the motif's estimated
## subblock proportions (rhohat) and connection matrix (Bhat).
Bhat.list <- vector("list", numc)
rhohat.list <- vector("list", numc)
for (i in seq_len(numc)) {
    idx.i <- which(Yhat == i)
    Xi <- NULL
    for (j in idx.i) {
        if (is.null(Xi)) {
            Xi <- X.list[[j]]
        } else {
            ## Named Tj, not T, so the TRUE shorthand is not overwritten.
            Tj <- find.transform(X.list[[j]], Xi)
            Xi <- rbind(Xi, X.list[[j]] %*% Tj)
        }
    }
    Xi.pamk <- pamk(Xi)
    pam.i <- Xi.pamk$pamobject
    ## Fraction of pooled points in each cluster 1..nc. tabulate() keeps one
    ## entry per cluster id even if a cluster were empty.
    rhohat.list[[i]] <- tabulate(pam.i$clustering, nbins = Xi.pamk$nc) / nrow(Xi)
    ## Bhat[j1, j2] = <medoid j1, medoid j2>, i.e. M %*% t(M) for the medoid
    ## matrix M. unname() keeps the result free of dimnames, like the old loop.
    Bhat.list[[i]] <- unname(tcrossprod(pam.i$medoids))
}
## Show the estimated connection matrices and subblock proportions, one per
## estimated motif. The comments below record the original run's output.
## In that run, two motifs were estimated with 2 subblocks, although every
## true template in B.list is 3 x 3.
Bhat.list
# [[1]]
#           [,1]      [,2]
# [1,] 0.2191020 0.2028399
# [2,] 0.2028399 0.7951436
# 
# [[2]]
#           [,1]      [,2]
# [1,] 0.2743059 0.2492344
# [2,] 0.2492344 0.7204546
# 
# [[3]]
#           [,1]      [,2]      [,3]
# [1,] 0.3924297 0.2535544 0.2684659
# [2,] 0.2535544 0.4080792 0.2626907
# [3,] 0.2684659 0.2626907 0.4050112
rhohat.list
# [[1]]
# [1] 0.6 0.4
# 
# [[2]]
# [1] 0.6 0.4
# 
# [[3]]
# [1] 0.258 0.491 0.251