Estimating a 2-PL Model in Julia (Part 2)
The EM algorithm is commonly used for estimation problems that involve latent variables, such as the latent trait $\theta$ in IRT models.
In a general IRT model, the marginal log-likelihood function is

$$\ell(\xi) = \sum_{i=1}^{N} \log \int P(\mathbf{y}_i \mid \theta, \xi)\, f(\theta)\, d\theta,$$

assuming that observations are independently and identically distributed given the latent trait $\theta$, where

$$P(\mathbf{y}_i \mid \theta, \xi) = \prod_{j=1}^{J} P(y_{ij} \mid \theta, \xi)$$

is the likelihood of person $i$'s responses to the $J$ items under local independence, and where $\xi$ denotes the vector of item parameters and $f(\theta)$ is the density of the latent trait.
With the EM algorithm, we update our parameter estimates by iterating between two steps:
- E-step: obtain the expected complete-data log-likelihood
$$Q(\xi \mid \xi_t) = E\!\left[\log L(\xi; \mathbf{y}, \theta) \mid \mathbf{y}, \xi_t\right],$$
where the expectation is taken with respect to the conditional distribution $P(\theta \mid \mathbf{y}, \xi_t)$, treating the current estimates $\xi_t$ as known;
- M-step: obtain new estimates $\xi_{t+1}$ that maximize $Q(\xi \mid \xi_t)$.
Note: the setup of the problem in this post deviates a bit from the one in the original Bock and Aitkin paper, which assumes a multinomial distribution of the sample counts of each response pattern. The resulting estimating equations should be the same.
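As a rough sketch (with hypothetical e_step and m_step functions passed in as arguments; the 2-PL versions of these computations are implemented later in this post), one way to write the generic EM loop in Julia is:
# Minimal sketch of a generic EM driver; `e_step` and `m_step` are
# hypothetical functions supplied by the caller, so nothing here is
# specific to the 2-PL model.
function em_sketch(e_step, m_step, data, par_init; tol=1e-5, max_iter=1000)
    par = par_init
    for _ in 1:max_iter
        expected = e_step(data, par)       # E-step: expected complete-data quantities
        par_new  = m_step(data, expected)  # M-step: maximize Q(ξ | ξₜ)
        maximum(abs.(par_new .- par)) < tol && return par_new
        par = par_new
    end
    return par
end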
E-Step
First consider the E-step for the 2-PL model, which has

$$P(y_{ij} = 1 \mid \theta_i) = \operatorname{logistic}(a_j \theta_i + d_j) = \frac{\exp(a_j \theta_i + d_j)}{1 + \exp(a_j \theta_i + d_j)},$$

where $a_j$ is the slope (discrimination) and $d_j$ the intercept of item $j$.
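For instance, for a hypothetical item with slope $a = 1$ and intercept $d = 0.5$, the probability of a correct response at $\theta = 0$ is about 0.62:
# Illustration only: 2-PL response probability for a made-up item
using LogExpFunctions: logistic
p_correct(θ; a=1.0, d=0.5) = logistic(a * θ + d)
p_correct(0.0)  # ≈ 0.62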
In addition, because item responses in IRT are discrete, the data can be summarized by their distinct response patterns. The LSAT data, for example, has 1,000 observations but only 30 response patterns of 5 binary items:
using RCall
lsat = rcopy(R"mirt::LSAT6")
30×6 DataFrame
Row │ Item_1 Item_2 Item_3 Item_4 Item_5 Freq
│ Int64 Int64 Int64 Int64 Int64 Int64
─────┼───────────────────────────────────────────────
1 │ 0 0 0 0 0 3
2 │ 0 0 0 0 1 6
3 │ 0 0 0 1 0 2
4 │ 0 0 0 1 1 11
5 │ 0 0 1 0 0 1
6 │ 0 0 1 0 1 1
7 │ 0 0 1 1 0 3
8 │ 0 0 1 1 1 4
⋮ │ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮
24 │ 1 1 0 0 1 56
25 │ 1 1 0 1 0 21
26 │ 1 1 0 1 1 173
27 │ 1 1 1 0 0 11
28 │ 1 1 1 0 1 61
29 │ 1 1 1 1 0 28
30 │ 1 1 1 1 1 298
15 rows omitted
So instead of computing the log-likelihood for 1,000 observations, we only need to do it for the 30 distinct response patterns $\mathbf{y}_l$, with frequency weights $n_l$, $l = 1, \ldots, 30$.
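As a quick check, the pattern frequencies in the Freq column add back up to the sample size:
sum(lsat.Freq)  # 1000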
For the E-step, we first need the conditional distribution of the latent trait given an observed response pattern,

$$P(\theta \mid \mathbf{y}_l, \xi_t) = \frac{P(\mathbf{y}_l \mid \theta, \xi_t)\, f(\theta)}{\int P(\mathbf{y}_l \mid \theta, \xi_t)\, f(\theta)\, d\theta}.$$

We need a prior distribution $f(\theta)$ for the latent trait, which is typically taken to be standard normal. Note the subscript $t$: the conditional distribution is evaluated at the provisional item parameter estimates $\xi_t$ from the previous iteration. The expected complete-data log-likelihood is then

$$Q(\xi \mid \xi_t) = \sum_{l} n_l\, E\!\left[\log P(\mathbf{y}_l \mid \theta, \xi) \mid \mathbf{y}_l, \xi_t\right] + \sum_{l} n_l\, E\!\left[\log f(\theta) \mid \mathbf{y}_l, \xi_t\right].$$

Note that the second expectation term is not a function of the item parameters (it is a function of the latent density $f(\theta)$ only), so it can be dropped when maximizing over $\xi$.
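As a small illustration of this posterior (using made-up item parameters and a coarse grid of five θ values in place of the quadrature nodes introduced later), the weights for the all-zero response pattern could be computed as:
using LogExpFunctions: logistic
# Hypothetical values for illustration only
a, d = ones(5), zeros(5)                 # made-up item parameters
θ_grid = [-2.0, -1.0, 0.0, 1.0, 2.0]     # coarse grid standing in for the quadrature nodes
w_grid = [0.05, 0.25, 0.40, 0.25, 0.05]  # prior weights at the grid points
y = zeros(5)                             # the all-zero response pattern
# P(y | θₖ) under the 2-PL model at each grid point
p = [logistic.(a .* θk .+ d) for θk in θ_grid]
py_given_θ = [prod(pk .^ y .* (1 .- pk) .^ (1 .- y)) for pk in p]
# Posterior P(θₖ | y) ∝ w(θₖ) P(y | θₖ)
post = w_grid .* py_given_θ ./ sum(w_grid .* py_given_θ)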
M-Step
In the M-step, we find new values $\xi_{t+1} = (\mathbf{a}_{t+1}, \mathbf{d}_{t+1})$ that maximize $Q(\xi \mid \xi_t)$, where the integral over $\theta$ is approximated by Gauss–Hermite quadrature with nodes $\theta_k$ and weights $w_k$, $k = 1, \ldots, K$. Following the previous literature, let

$$\bar n_k = \sum_{l} n_l\, P(\theta_k \mid \mathbf{y}_l, \xi_t), \qquad \bar r_{jk} = \sum_{l} n_l\, y_{lj}\, P(\theta_k \mid \mathbf{y}_l, \xi_t),$$

where $\bar n_k$ is the expected number of persons at node $\theta_k$ and $\bar r_{jk}$ is the expected number of correct responses to item $j$ at node $\theta_k$. Setting the derivatives of $Q$ with respect to the item parameters to zero gives the estimating equations

$$\sum_{k=1}^{K} \left[\bar r_{jk} - \bar n_k P_j(\theta_k)\right] = 0, \qquad \sum_{k=1}^{K} \theta_k \left[\bar r_{jk} - \bar n_k P_j(\theta_k)\right] = 0,$$

for $j = 1, \ldots, J$, where $P_j(\theta_k) = \operatorname{logistic}(a_j \theta_k + d_j)$.
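These estimating equations follow from the standard derivation: writing $\eta_{jk} = a_j \theta_k + d_j$, the part of $Q$ that involves item $j$ is a weighted Bernoulli log-likelihood, and differentiating it gives

$$Q_j = \sum_{k} \left[\bar r_{jk}\, \eta_{jk} - \bar n_k \log\{1 + \exp(\eta_{jk})\}\right], \qquad \frac{\partial Q_j}{\partial d_j} = \sum_{k} \left[\bar r_{jk} - \bar n_k P_j(\theta_k)\right], \qquad \frac{\partial Q_j}{\partial a_j} = \sum_{k} \theta_k \left[\bar r_{jk} - \bar n_k P_j(\theta_k)\right].$$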
Estimating a 2-PL Model with EM in Julia
Here is my attempt to implement the EM algorithm in Julia, following the steps laid out in Harwell (1988). First, load the packages:
using LinearAlgebra, LogExpFunctions
using FastGaussQuadrature: gausshermite
using NLsolve
using BenchmarkTools
Find $\bar{n}_k$ and $\bar{r}_{jk}$
# Helper for computing logits: ηᵢⱼ = aⱼθᵢ + dⱼ
function compute_logits(θ, a, d)
    [θ[i] * a[j] + d[j]
     for i = eachindex(θ), j = eachindex(a)]
end
compute_logits (generic function with 1 method)
function eloglik_2pl_em(y, n, θ, w, parₜ)
    num_items = size(y, 2)
    aₜ = parₜ[1:num_items]          # provisional slopes
    dₜ = parₜ[num_items+1:end]      # provisional intercepts
    ηₜ = compute_logits(θ, aₜ, dₜ)  # logits at each quadrature node × item
    sum1pexpη = sum(log1pexp, ηₜ, dims=2)
    # Unnormalized posterior of θ for each response pattern: w(θₖ) P(yₗ | θₖ)
    wpy_given_θ = Matrix{eltype(aₜ)}(undef, length(θ), length(n))
    for l in eachindex(n)
        wpy_given_θ[:, l] = w .* exp.(ηₜ * view(y, l, :) .- sum1pexpη)
    end
    # Normalize over the nodes to get P(θₖ | yₗ)
    pθ_given_y = wpy_given_θ ./ sum(wpy_given_θ, dims=1)
    # Expected counts: n̄ₖ (persons at node k) and r̄ⱼₖ (correct responses to item j at node k)
    (bar_nₖ=pθ_given_y * n,
     bar_rⱼₖ=pθ_given_y * (n .* y))
end
eloglik_2pl_em (generic function with 1 method)
# Test:
gh15 = gausshermite(15) # 15 quadrature points
([-4.499990707309391, -3.669950373404453, -2.9671669279056054, -2.3257324861738606, -1.7199925751864926, -1.136115585210924, -0.5650695832555779, -3.552713678800501e-15, 0.5650695832555779, 1.136115585210924, 1.7199925751864926, 2.3257324861738606, 2.9671669279056054, 3.669950373404453, 4.499990707309391], [1.5224758042535368e-9, 1.0591155477110773e-6, 0.00010000444123250024, 0.0027780688429127607, 0.030780033872546228, 0.15848891579593563, 0.41202868749889865, 0.5641003087264175, 0.41202868749889865, 0.15848891579593563, 0.030780033872546228, 0.0027780688429127607, 0.00010000444123250024, 1.0591155477110773e-6, 1.5224758042535368e-9])
gh15_node = gh15[1] .* √2
15-element Vector{Float64}:
-6.363947888829838
-5.190093591304782
-4.196207711269019
-3.289082424398771
-2.4324368270097634
-1.6067100690287344
-0.7991290683245511
-5.0242958677880805e-15
0.7991290683245511
1.6067100690287344
2.4324368270097634
3.289082424398771
4.196207711269019
5.190093591304782
6.363947888829838
gh15_weight = gh15[2] ./ √π
15-element Vector{Float64}:
8.589649899633383e-10
5.975419597920666e-7
5.642146405189039e-5
0.0015673575035499477
0.01736577449213769
0.08941779539984435
0.23246229360973225
0.31825951825951826
0.23246229360973225
0.08941779539984435
0.01736577449213769
0.0015673575035499477
5.642146405189039e-5
5.975419597920666e-7
8.589649899633383e-10
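Rescaling the nodes by √2 and the weights by 1/√π turns the Gauss–Hermite rule into an approximation of an integral against the standard normal density, so the rescaled weights should sum to one:
sum(gh15_weight)  # ≈ 1.0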
exp1 = eloglik_2pl_em(Matrix(lsat[:, 1:5]), lsat[:, 6],
    gh15_node, gh15_weight,
    [ones(5); zeros(5)])
(bar_nₖ = [2.7564518778742383e-8, 2.0114861555321952e-5, 0.002128243156556154, 0.07560094100777943, 1.3352261691433547, 14.249248990842812, 89.7255889231157, 278.1775106946097, 365.8604302697971, 199.39830527246278, 46.43045474007526, 4.57238810409275, 0.17124982332670236, 0.001845015100526724, 2.6708429185723037e-6], bar_rⱼₖ = [4.138800953383068e-10 4.341143688898378e-11 … 8.455317850390647e-11 2.4966992659649684e-10; 9.488948442862142e-7 1.0961092559330794e-7 … 2.0276924419430197e-7 5.791071220863187e-7; … ; 0.0018436478096575996 0.0018377274514118978 … 0.0018394578991136525 0.0018424609777163434; 2.670229820441103e-6 2.667573675000273e-6 … 2.6683500732256074e-6 2.669698207072498e-6])
The output is a tuple: the first element bar_nₖ is a vector of length 15 (one expected count per quadrature node), and the second element bar_rⱼₖ is a 15 × 5 matrix of expected numbers of correct responses to each item at each node.
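As a quick sanity check, the expected counts over the nodes add back up to the sample size, and bar_rⱼₖ has one column per item:
sum(exp1.bar_nₖ)    # ≈ 1000.0
size(exp1.bar_rⱼₖ)  # (15, 5)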
Solve estimating equations
# Item response probabilities P(yᵢⱼ = 1 | θᵢ) = logistic(aⱼθᵢ + dⱼ)
function compute_probs(θ, a, d)
    [logistic(θ[i] * a[j] + d[j]) for i = eachindex(θ), j = eachindex(a)]
end
compute_probs (generic function with 1 method)
function esteq_2pl_em(par, bar_r, bar_n, θ)
    num_items = size(bar_r, 2)
    a = par[1:num_items]
    d = par[num_items+1:end]
    # Residuals r̄ⱼₖ − n̄ₖ Pⱼ(θₖ) for each node k (rows) and item j (columns)
    rmntpθ = bar_r .- bar_n .* compute_probs(θ, a, d)
    # Stack the J intercept equations and the J slope equations into one vector
    vec([sum(rmntpθ, dims=1) θ' * rmntpθ])
end
esteq_2pl_em (generic function with 1 method)
# Test:
root1 = nlsolve(x -> esteq_2pl_em(x,
        exp1.bar_rⱼₖ, exp1.bar_nₖ, gh15_node),
    [ones(5); zeros(5)])
Results of Nonlinear Solver Algorithm
* Algorithm: Trust-region with dogleg and autoscaling
* Starting Point: [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
* Zero: [0.9441787368725837, 0.9412834339373086, 0.9702850411757558, 0.932179532744094, 0.9197528908890249, 2.166166737455263, 0.41020205101729856, -0.37879989990817486, 0.7262454452107591, 1.5321165384312954]
* Inf-norm of residuals: 0.000000
* Iterations: 6
* Convergence: true
* |x - x'| < 0.0e+00: false
* |f(x)| < 1.0e-08: true
* Function Calls (f): 7
* Jacobian Calls (df/dx): 7
The solution is contained in the zero field. These updated parameter values will be passed back to the E-step.
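For example, the updated slopes and intercepts can be pulled out of root1.zero in the same order used for the starting values (the names a₁ and d₁ below are just for illustration):
a₁ = root1.zero[1:5]   # updated slopes aⱼ
d₁ = root1.zero[6:10]  # updated intercepts dⱼ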
Iterations
We can do two more iterations:
exp2 = eloglik_2pl_em(Matrix(lsat[:, 1:5]), lsat[:, 6],
    gh15_node, gh15_weight,
    root1.zero)
(bar_nₖ = [1.7202823621463788e-7, 0.00012478229640175463, 0.012917226708649957, 0.42950281900235543, 6.374685649928196, 47.54065170285358, 177.55073039082802, 324.1944490261614, 289.0821550106182, 125.99472138888558, 26.268899898287724, 2.460042585603299, 0.09015529832828216, 0.0009626597451594212, 1.3887248384670485e-6], bar_rⱼₖ = [5.372365387113749e-9 6.042186880320284e-10 … 1.1967481795709406e-9 3.692939772198374e-9; 1.1195428419388414e-5 1.486089377064305e-6 … 2.671355380687325e-6 7.62773320369415e-6; … ; 0.0009620030652981468 0.000959109431245618 … 0.0009598307020344003 0.0009612789437790556; 1.3884114755863066e-6 1.3870244275199353e-6 … 1.3873554006168715e-6 1.3880472692324014e-6])
root2 = nlsolve(x -> esteq_2pl_em(x,
        exp2.bar_rⱼₖ, exp2.bar_nₖ, gh15_node),
    root1.zero)
Results of Nonlinear Solver Algorithm
* Algorithm: Trust-region with dogleg and autoscaling
* Starting Point: [0.9441787368725837, 0.9412834339373086, 0.9702850411757558, 0.932179532744094, 0.9197528908890249, 2.166166737455263, 0.41020205101729856, -0.37879989990817486, 0.7262454452107591, 1.5321165384312954]
* Zero: [0.8851376144605885, 0.8780270229700008, 0.9305712408536134, 0.8617362606683743, 0.8410113215381403, 2.5424571803032903, 0.781691569809412, -0.0034125639910551494, 1.0943152659197337, 1.8930834808194317]
* Inf-norm of residuals: 0.000000
* Iterations: 4
* Convergence: true
* |x - x'| < 0.0e+00: false
* |f(x)| < 1.0e-08: true
* Function Calls (f): 5
* Jacobian Calls (df/dx): 5
exp3 = eloglik_2pl_em(Matrix(lsat[:, 1:5]), lsat[:, 6],
    gh15_node, gh15_weight,
    root2.zero)
(bar_nₖ = [4.3023109522800455e-7, 0.0003040218519114245, 0.02987372125524563, 0.9055221912858598, 11.602389316066303, 71.03280737295191, 214.59437201659995, 325.67522637877084, 253.23782066569376, 101.01070654940031, 20.01802278052005, 1.826159972919833, 0.06609181187521956, 0.000701760614451534, 1.0099632324488185e-6], bar_rⱼₖ = [2.0531055047839344e-8 2.5014192320684056e-9 … 4.965693882437507e-9 1.5557900531690245e-8; 3.8265586589167865e-5 5.753192037170773e-6 … 1.0076467391887043e-5 2.8241538126234826e-5; … ; 0.0007012180412131716 0.0006987629848589311 … 0.0006992811247620892 0.0007004963056719968; 1.00968636885505e-6 1.0084204735673418e-6 … 1.008662595953656e-6 1.0092845394530493e-6])
root3 = nlsolve(x -> esteq_2pl_em(x,
        exp3.bar_rⱼₖ, exp3.bar_nₖ, gh15_node),
    root2.zero)
Results of Nonlinear Solver Algorithm
* Algorithm: Trust-region with dogleg and autoscaling
* Starting Point: [0.8851376144605885, 0.8780270229700008, 0.9305712408536134, 0.8617362606683743, 0.8410113215381403, 2.5424571803032903, 0.781691569809412, -0.0034125639910551494, 1.0943152659197337, 1.8930834808194317]
* Zero: [0.8479491477352421, 0.8323916555812338, 0.902644467313084, 0.8111833641429015, 0.7863596508859064, 2.6865158580988453, 0.9283333274643143, 0.1559502636058875, 1.2352944393730914, 2.0242412851515406]
* Inf-norm of residuals: 0.000000
* Iterations: 4
* Convergence: true
* |x - x'| < 0.0e+00: false
* |f(x)| < 1.0e-08: true
* Function Calls (f): 5
* Jacobian Calls (df/dx): 5
Stopping criteria
One way to stop the iteration is when the absolute change in the parameter estimates is less than a certain threshold (e.g., 0.00001). For example, the following shows the maximum absolute change in the parameter estimates between the second and third iterations:
maximum(abs.(root3.zero .- root2.zero))
0.15936282759694267
which is still pretty large, so we should do more iterations.
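For illustration, the stopping rule could be written as a small helper (the full loop in the next section applies the same check inline):
# Hypothetical helper: stop once no parameter moves by more than `tol`
converged(par_new, par_old; tol=1e-5) = maximum(abs.(par_new .- par_old)) < tol
converged(root3.zero, root2.zero)  # false, so keep iterating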
Benchmarking
We can wrap the steps into a function:
function estimate_2pl_em(y, n, init,
    n_quadpts=101, par_tol=1e-5, rtol=1e-5, max_iter=1000)
    parₜ = init
    parₜ₊₁ = parₜ
    # Convert y to matrix
    y = Matrix(y)
    # Obtain quadrature nodes and weights (rescaled for a standard normal prior)
    ghq = gausshermite(n_quadpts)
    ghq_θ = ghq[1] .* √2
    ghq_w = ghq[2] ./ √π
    i = 1
    while i < max_iter
        # E-step: expected counts n̄ₖ and r̄ⱼₖ given the current parameter values
        expₜ = eloglik_2pl_em(y, n,
            ghq_θ, ghq_w, parₜ)
        # M-step: solve the estimating equations for new parameter values
        root = nlsolve(x -> esteq_2pl_em(x,
                expₜ.bar_rⱼₖ, expₜ.bar_nₖ, ghq_θ),
            parₜ, autodiff=:forward)
        parₜ₊₁ = root.zero
        # Stop when the largest absolute change in the estimates is below par_tol
        if maximum(abs.(parₜ₊₁ - parₜ)) < par_tol
            break
        else
            parₜ = parₜ₊₁
            i += 1
        end
    end
    (estimate=parₜ₊₁, num_iter=i)
end
estimate_2pl_em (generic function with 5 methods)
@btime est_em = estimate_2pl_em(lsat[:, 1:5], lsat[:, 6], [ones(5); zeros(5)])
  16.968 ms (16212 allocations: 29.76 MiB)
(estimate = [0.8256464036061023, 0.7227767138869371, 0.8907854481666115, 0.6883896452213724, 0.6568975184341934, 2.773226187628446, 0.9902095352382095, 0.24914112339506816, 1.284763785097769, 2.0532889359951128], num_iter = 66)
Compare to mirt
data("LSAT", package = "ltm")
library(mirt)
bench::mark(
mirt = mirt(LSAT,
verbose = FALSE, quadpts = 101,
TOL = 1e-5)
)
# A tibble: 1 × 6
expression min median `itr/sec` mem_alloc `gc/sec`
<bch:expr> <bch:tm> <bch:tm> <dbl> <bch:byt> <dbl>
1 mirt 107ms 107ms 9.38 22.4MB 28.1
Remark
My implementation of the EM algorithm takes less time to run than direct MML (see my post in Part 1), but it does not compute the standard errors. Also, it probably uses a different convergence criterion than direct MML with Optim.jl, so it is hard to say which one is truly faster.