Estimating a 2-PL Model in Julia (Part 2)
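The EM algorithm is commonly used for estimation problems that involve latent variables, such as the latent trait $\theta$ in IRT.
In a general IRT model, the marginal likelihood function is

$$L(\boldsymbol\omega) = \prod_{i=1}^{N} \int P(\mathbf{y}_i \mid \theta; \boldsymbol\omega)\, g(\theta)\, d\theta,$$

assuming that observations are independently and identically distributed given $\boldsymbol\omega$. Here $\boldsymbol\omega$ is the vector of item parameters, $\mathbf{y}_i$ is the response vector of person $i$, and $g(\theta)$ is the density of the latent trait, and

$$P(\mathbf{y}_i \mid \theta; \boldsymbol\omega) = \prod_{j=1}^{J} P(y_{ij} \mid \theta; \boldsymbol\omega)$$

by local independence.
With the EM algorithm, we update our parameter estimates by iterating between two steps:
- E-step: obtain
$$Q(\boldsymbol\omega \mid \boldsymbol\omega_t) = \sum_{i=1}^{N} E\left[\log P(\mathbf{y}_i, \theta_i; \boldsymbol\omega) \mid \mathbf{y}_i; \boldsymbol\omega_t\right]$$
with respect to the conditional distribution $P(\theta_i \mid \mathbf{y}_i; \boldsymbol\omega_t)$, treating $\boldsymbol\omega_t$ as known;
- M-step: obtain new estimates $\boldsymbol\omega_{t+1}$ that maximize $Q(\boldsymbol\omega \mid \boldsymbol\omega_t)$.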
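The two steps can be organized as a generic loop. Below is a minimal sketch in Julia, assuming the E-step and M-step are supplied as functions; `run_em`, `estep`, and `mstep` are hypothetical names, and the stopping rule (maximum absolute change in the parameters) mirrors the one used later in this post.

```julia
# Generic EM driver (hypothetical helper, not from any package).
# `estep(par)` returns whatever expected quantities the M-step needs;
# `mstep(stats, par)` returns the updated parameter vector.
function run_em(estep, mstep, par0; par_tol=1e-5, max_iter=1000)
    par = copy(par0)
    for iter in 1:max_iter
        stats = estep(par)             # E-step: expectations under the current par
        par_new = mstep(stats, par)    # M-step: maximize the expected log-likelihood
        # Stop when the largest absolute parameter change is below par_tol
        if maximum(abs.(par_new .- par)) < par_tol
            return (estimate=par_new, num_iter=iter)
        end
        par = par_new
    end
    return (estimate=par, num_iter=max_iter)
end
```

For example, `run_em(p -> p ./ 2, (s, p) -> s .+ 1, [0.0])` iterates a toy contraction map (not an actual EM problem) to approximately its fixed point 2.0; it only exercises the loop.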
Note: the setup of the problem in this post deviates a bit from the one in the original Bock and Aitkin paper, which assumes a multinomial distribution of the sample counts of each response pattern. The resulting estimating equations should be the same.
E-Step
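First consider the E-step for the 2-PL model, which has

$$P(y_{ij} = 1 \mid \theta; a_j, d_j) = \frac{\exp(a_j \theta + d_j)}{1 + \exp(a_j \theta + d_j)},$$

with slope $a_j$ and intercept $d_j$ for item $j$, so that $\boldsymbol\omega = (a_1, \ldots, a_J, d_1, \ldots, d_J)$. Writing $\eta_j = a_j \theta + d_j$, the log-likelihood of a response vector is

$$\log P(\mathbf{y}_i \mid \theta; \boldsymbol\omega) = \sum_{j=1}^{J} \left[ y_{ij} \eta_j - \log\left(1 + e^{\eta_j}\right) \right].$$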
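The Bernoulli form $y \log P + (1 - y)\log(1 - P)$ and the logit form above are the same quantity; the logit form is what the code below uses. A small sketch in base Julia that checks the identity numerically (`p2pl` and the `loglik_pattern_*` names are hypothetical):

```julia
# 2-PL item response function: P(y = 1 | θ) = exp(aθ + d) / (1 + exp(aθ + d))
p2pl(θ, a, d) = 1 / (1 + exp(-(a * θ + d)))

# Bernoulli form: Σⱼ [yⱼ log Pⱼ + (1 - yⱼ) log(1 - Pⱼ)]
function loglik_pattern_direct(y, θ, a, d)
    sum(y[j] * log(p2pl(θ, a[j], d[j])) + (1 - y[j]) * log(1 - p2pl(θ, a[j], d[j]))
        for j in eachindex(y))
end

# Logit form: Σⱼ [yⱼ ηⱼ - log(1 + exp(ηⱼ))], with ηⱼ = aⱼθ + dⱼ
function loglik_pattern_logit(y, θ, a, d)
    η = a .* θ .+ d
    sum(y .* η .- log1p.(exp.(η)))
end
```

Here `log1p.(exp.(η))` can overflow for large η; `log1pexp` from LogExpFunctions, used later in the post, is the numerically stable version.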
In addition, because item responses in IRT are discrete, many respondents share the same response pattern. For example, the LSAT data has 1,000 observations but only 30 distinct response patterns of 5 binary items:
using RCall
lsat = rcopy(R"mirt::LSAT6")
30×6 DataFrame
Row │ Item_1 Item_2 Item_3 Item_4 Item_5 Freq
│ Int64 Int64 Int64 Int64 Int64 Int64
─────┼───────────────────────────────────────────────
1 │ 0 0 0 0 0 3
2 │ 0 0 0 0 1 6
3 │ 0 0 0 1 0 2
4 │ 0 0 0 1 1 11
5 │ 0 0 1 0 0 1
6 │ 0 0 1 0 1 1
7 │ 0 0 1 1 0 3
8 │ 0 0 1 1 1 4
⋮ │ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮
24 │ 1 1 0 0 1 56
25 │ 1 1 0 1 0 21
26 │ 1 1 0 1 1 173
27 │ 1 1 1 0 0 11
28 │ 1 1 1 0 1 61
29 │ 1 1 1 1 0 28
30 │ 1 1 1 1 1 298
15 rows omitted
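mirt::LSAT6 already comes in this collapsed form. For raw person-by-item data (e.g., the 1,000-row LSAT data in the ltm package), the patterns and their frequencies can be obtained by counting identical rows. A sketch in base Julia, where `collapse_patterns` is a hypothetical helper:

```julia
# Collapse an N × J matrix of 0/1 responses into unique patterns and counts.
# Returns (patterns, counts): an L × J matrix of unique rows (sorted
# lexicographically) and a length-L vector of their frequencies.
function collapse_patterns(Y::AbstractMatrix{<:Integer})
    counts = Dict{Vector{Int}, Int}()
    for row in eachrow(Y)
        key = Vector{Int}(row)
        counts[key] = get(counts, key, 0) + 1
    end
    keys_sorted = sort!(collect(keys(counts)))       # lexicographic row order
    patterns = permutedims(reduce(hcat, keys_sorted))  # L × J
    return patterns, [counts[k] for k in keys_sorted]
end
```

The likelihood then needs one term per unique pattern, weighted by its count.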
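So instead of computing the log-likelihood for 1,000 observations, we only need to do it for the $L = 30$ unique response patterns $\mathbf{y}_l$, $l = 1, \ldots, L$, with frequencies $n_l$ (the Freq column, $\sum_l n_l = 1{,}000$):

$$\log L(\boldsymbol\omega) = \sum_{l=1}^{L} n_l \log \int P(\mathbf{y}_l \mid \theta; \boldsymbol\omega)\, g(\theta)\, d\theta.$$

For the E-step, we first need the conditional (posterior) distribution of $\theta$ given each pattern,

$$P(\theta \mid \mathbf{y}_l; \boldsymbol\omega_t) = \frac{P(\mathbf{y}_l \mid \theta; \boldsymbol\omega_t)\, g(\theta)}{\int P(\mathbf{y}_l \mid \theta; \boldsymbol\omega_t)\, g(\theta)\, d\theta}.$$

We need a prior distribution $g(\theta)$; here it is the standard normal, $\theta \sim N(0, 1)$. The integral is approximated by quadrature with nodes $X_k$ and weights $A(X_k)$, $k = 1, \ldots, K$, so the posterior becomes a discrete distribution over the nodes:

$$P(X_k \mid \mathbf{y}_l; \boldsymbol\omega_t) = \frac{P(\mathbf{y}_l \mid X_k; \boldsymbol\omega_t)\, A(X_k)}{\sum_{k'=1}^{K} P(\mathbf{y}_l \mid X_{k'}; \boldsymbol\omega_t)\, A(X_{k'})}.$$

Note the subscript $t$: the posterior is evaluated at the current estimates $\boldsymbol\omega_t$, which are treated as fixed. Because $\log P(\mathbf{y}_l, \theta; \boldsymbol\omega) = \log P(\mathbf{y}_l \mid \theta; \boldsymbol\omega) + \log g(\theta)$, the E-step objective splits into two expectation terms:

$$Q(\boldsymbol\omega \mid \boldsymbol\omega_t) \approx \sum_{l=1}^{L} n_l \sum_{k=1}^{K} P(X_k \mid \mathbf{y}_l; \boldsymbol\omega_t) \log P(\mathbf{y}_l \mid X_k; \boldsymbol\omega) + \sum_{l=1}^{L} n_l \sum_{k=1}^{K} P(X_k \mid \mathbf{y}_l; \boldsymbol\omega_t) \log g(X_k).$$

Note that the second expectation term is not a function of the item parameters (it depends only on the prior and on $\boldsymbol\omega_t$), so it can be ignored when maximizing over $\boldsymbol\omega$.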
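Putting the E-step pieces together for a single response pattern: the sketch below (the helper `posterior_at_nodes` is hypothetical) computes the posterior weights over the quadrature nodes and, as a by-product, the log of the denominator, which is the quadrature approximation of $\log P(\mathbf{y}_l)$. It assumes nodes `θ` and weights `w` for an N(0, 1) prior, with weights summing to 1.

```julia
# Posterior weights P(X_k | y_l) over quadrature nodes for one pattern y_l,
# plus the quadrature approximation of log P(y_l).
function posterior_at_nodes(y_l, θ, w, a, d)
    η = θ .* a' .+ d'                                   # K × J logits aⱼX_k + dⱼ
    # log P(y_l | X_k) = Σⱼ [y_lⱼ ηₖⱼ - log(1 + exp(ηₖⱼ))]
    loglik = η * y_l .- vec(sum(log1p.(exp.(η)), dims=2))
    joint = w .* exp.(loglik)                           # P(y_l | X_k) A(X_k)
    marginal = sum(joint)                               # ≈ P(y_l)
    return joint ./ marginal, log(marginal)
end
```

Summing $n_l$ times the second return value over all patterns gives the approximate marginal log-likelihood, which can be handy for monitoring EM, since it should not decrease from one iteration to the next.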
M-Step
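In the M-step, we find new values $\boldsymbol\omega_{t+1}$ that maximize the first term of $Q(\boldsymbol\omega \mid \boldsymbol\omega_t)$, i.e., we solve $\partial Q / \partial \boldsymbol\omega = \mathbf{0}$, where $P_j(X_k) = P(y_j = 1 \mid X_k; a_j, d_j)$ denotes the 2-PL probability at quadrature node $X_k$.
Following the previous literature, let

$$\bar{n}_k = \sum_{l=1}^{L} n_l P(X_k \mid \mathbf{y}_l; \boldsymbol\omega_t), \qquad \bar{r}_{jk} = \sum_{l=1}^{L} n_l y_{lj} P(X_k \mid \mathbf{y}_l; \boldsymbol\omega_t),$$

where $P(X_k \mid \mathbf{y}_l; \boldsymbol\omega_t)$ is the posterior weight of node $X_k$ for pattern $\mathbf{y}_l$, $\bar{n}_k$ is the expected number of respondents at node $X_k$, and $\bar{r}_{jk}$ is the expected number of those respondents who answer item $j$ correctly,
for $j = 1, \ldots, J$ and $k = 1, \ldots, K$. The first term of $Q$ then becomes $\sum_{j}\sum_{k} \left[\bar{r}_{jk} \log P_j(X_k) + (\bar{n}_k - \bar{r}_{jk}) \log(1 - P_j(X_k))\right]$, and setting its derivatives with respect to $d_j$ and $a_j$ to zero gives the estimating equations

$$\sum_{k=1}^{K} \left[\bar{r}_{jk} - \bar{n}_k P_j(X_k)\right] = 0, \qquad \sum_{k=1}^{K} X_k \left[\bar{r}_{jk} - \bar{n}_k P_j(X_k)\right] = 0, \qquad j = 1, \ldots, J.$$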
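To check that these estimating equations are the gradient of the first term of $Q$, one can compare them with central finite differences for a single item. A sketch with made-up expected counts (`q_item`, `grad_item`, and the numbers in the check are hypothetical, for illustration only):

```julia
# Part of Q that depends on item j: Σₖ [r̄ₖ log Pₖ + (n̄ₖ - r̄ₖ) log(1 - Pₖ)]
function q_item(a, d, X, bar_n, bar_r)
    P = @. 1 / (1 + exp(-(a * X + d)))
    sum(@. bar_r * log(P) + (bar_n - bar_r) * log(1 - P))
end

# Analytic gradient: the left-hand sides of the two estimating equations
function grad_item(a, d, X, bar_n, bar_r)
    resid = bar_r .- bar_n .* (1 ./ (1 .+ exp.(-(a .* X .+ d))))
    (sum(X .* resid), sum(resid))    # (∂Q/∂aⱼ, ∂Q/∂dⱼ)
end

X = [-1.5, 0.0, 1.5]
bar_n = [20.0, 50.0, 30.0]
bar_r = [5.0, 30.0, 25.0]
a, d, h = 0.9, 0.2, 1e-6
ga, gd = grad_item(a, d, X, bar_n, bar_r)
fd_a = (q_item(a + h, d, X, bar_n, bar_r) - q_item(a - h, d, X, bar_n, bar_r)) / (2h)
fd_d = (q_item(a, d + h, X, bar_n, bar_r) - q_item(a, d - h, X, bar_n, bar_r)) / (2h)
```

The analytic and finite-difference values (`ga` vs. `fd_a`, `gd` vs. `fd_d`) should agree to several decimal places.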
Estimating a 2-PL Model with EM in Julia
Here is my attempt to implement the EM algorithm in Julia, following the steps laid out in Harwell (1988). First, load the packages:
using LinearAlgebra, LogExpFunctions
using FastGaussQuadrature: gausshermite
using NLsolve
using BenchmarkTools
Find $\bar{n}_k$ and $\bar{r}_{jk}$ (E-step)
# Helper for computing logits: ηᵢⱼ = aⱼθ + dⱼ
function compute_logits(θ, a, d)
    [θ[i] * a[j] + d[j] for i = eachindex(θ), j = eachindex(a)]
end
compute_logits (generic function with 1 method)
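compute_logits returns a K × J matrix whose rows index quadrature nodes and whose columns index items. The same matrix can also be built with broadcasting; a minimal sketch comparing the two (`compute_logits_bc` is a hypothetical name):

```julia
# Comprehension version, as in the post
compute_logits(θ, a, d) = [θ[i] * a[j] + d[j] for i = eachindex(θ), j = eachindex(a)]

# Broadcasting version: (K vector) .* (1 × J row) .+ (1 × J row) → K × J matrix
compute_logits_bc(θ, a, d) = θ .* a' .+ d'
```

Both allocate a single K × J matrix; the broadcast form just avoids explicit indexing.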
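function eloglik_2pl_em(y, n, θ, w, parₜ)
    num_items = size(y, 2)
    aₜ = parₜ[1:num_items]                 # current slopes
    dₜ = parₜ[num_items+1:end]             # current intercepts
    ηₜ = compute_logits(θ, aₜ, dₜ)         # K × J logits at the quadrature nodes
    sum1pexpη = sum(log1pexp, ηₜ, dims=2)  # Σⱼ log(1 + exp(ηₖⱼ)), K × 1
    # Column l holds A(Xₖ) P(y_l | Xₖ) for k = 1, …, K
    wpy_given_θ = Matrix{eltype(aₜ)}(undef, length(θ), length(n))
    for l in eachindex(n)
        wpy_given_θ[:, l] = w .* exp.(ηₜ * view(y, l, :) .- sum1pexpη)
    end
    # Normalize each column to get the posterior P(Xₖ | y_l)
    pθ_given_y = wpy_given_θ ./ sum(wpy_given_θ, dims=1)
    # Expected counts: bar_nₖ (length K) and bar_rⱼₖ (K × J)
    (bar_nₖ=pθ_given_y * n,
     bar_rⱼₖ=pθ_given_y * (n .* y))
end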
eloglik_2pl_em (generic function with 1 method)
# Test:
gh15 = gausshermite(15) # 15 quadrature points
([-4.499990707309391, -3.669950373404453, -2.9671669279056054, -2.3257324861738606, -1.7199925751864926, -1.136115585210924, -0.5650695832555779, -3.552713678800501e-15, 0.5650695832555779, 1.136115585210924, 1.7199925751864926, 2.3257324861738606, 2.9671669279056054, 3.669950373404453, 4.499990707309391], [1.5224758042535368e-9, 1.0591155477110773e-6, 0.00010000444123250024, 0.0027780688429127607, 0.030780033872546228, 0.15848891579593563, 0.41202868749889865, 0.5641003087264175, 0.41202868749889865, 0.15848891579593563, 0.030780033872546228, 0.0027780688429127607, 0.00010000444123250024, 1.0591155477110773e-6, 1.5224758042535368e-9])
gh15_node = gh15[1] .* √2
15-element Vector{Float64}:
-6.363947888829838
-5.190093591304782
-4.196207711269019
-3.289082424398771
-2.4324368270097634
-1.6067100690287344
-0.7991290683245511
-5.0242958677880805e-15
0.7991290683245511
1.6067100690287344
2.4324368270097634
3.289082424398771
4.196207711269019
5.190093591304782
6.363947888829838
gh15_weight = gh15[2] ./ √π
15-element Vector{Float64}:
8.589649899633383e-10
5.975419597920666e-7
5.642146405189039e-5
0.0015673575035499477
0.01736577449213769
0.08941779539984435
0.23246229360973225
0.31825951825951826
0.23246229360973225
0.08941779539984435
0.01736577449213769
0.0015673575035499477
5.642146405189039e-5
5.975419597920666e-7
8.589649899633383e-10
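Why √2 and √π: gausshermite(n) returns nodes $x_i$ and weights $w_i$ for integrals of the form $\int e^{-x^2} f(x)\,dx$. Substituting $\theta = \sqrt{2}\,x$ gives $\int f(\theta)\,\phi(\theta)\,d\theta = \frac{1}{\sqrt{\pi}} \int e^{-x^2} f(\sqrt{2}\,x)\,dx \approx \sum_i \frac{w_i}{\sqrt{\pi}} f(\sqrt{2}\,x_i)$, which is exactly the rescaling above. As a cross-check that needs only the LinearAlgebra standard library, the sketch below computes an N(0, 1) rule directly with the Golub-Welsch method (`normal_quadrature` is a hypothetical name); up to floating-point error it should reproduce gh15_node and gh15_weight.

```julia
using LinearAlgebra

# Golub-Welsch for θ ~ N(0, 1): the probabilists' Hermite polynomials satisfy
# x Heₙ(x) = Heₙ₊₁(x) + n Heₙ₋₁(x), so the Jacobi matrix has a zero diagonal
# and off-diagonal entries √1, √2, …, √(K-1).
function normal_quadrature(K)
    J = SymTridiagonal(zeros(K), sqrt.(1.0:K-1))
    F = eigen(J)
    nodes = F.values                     # roots of He_K, in ascending order
    weights = F.vectors[1, :] .^ 2       # total mass 1 for the N(0, 1) density
    return nodes, weights
end

x, w = normal_quadrature(15)
```

A 15-point rule integrates polynomials up to degree 29 exactly, so `sum(w)` and `sum(w .* x .^ 2)` (the N(0, 1) mean of 1 and of θ²) should both be 1 up to rounding.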
exp1 = eloglik_2pl_em(Matrix(lsat[:, 1:5]), lsat[:, 6],
                      gh15_node, gh15_weight, [ones(5); zeros(5)])
(bar_nₖ = [2.7564518778742383e-8, 2.0114861555321952e-5, 0.002128243156556154, 0.07560094100777943, 1.3352261691433547, 14.249248990842812, 89.7255889231157, 278.1775106946097, 365.8604302697971, 199.39830527246278, 46.43045474007526, 4.57238810409275, 0.17124982332670236, 0.001845015100526724, 2.6708429185723037e-6], bar_rⱼₖ = [4.138800953383068e-10 4.341143688898378e-11 … 8.455317850390647e-11 2.4966992659649684e-10; 9.488948442862142e-7 1.0961092559330794e-7 … 2.0276924419430197e-7 5.791071220863187e-7; … ; 0.0018436478096575996 0.0018377274514118978 … 0.0018394578991136525 0.0018424609777163434; 2.670229820441103e-6 2.667573675000273e-6 … 2.6683500732256074e-6 2.669698207072498e-6])
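The output is a named tuple: the first element, bar_nₖ, is a vector of length $K = 15$ holding $\bar{n}_k$, and the second element, bar_rⱼₖ, is a $K \times J$ (15 × 5) matrix holding $\bar{r}_{jk}$, with rows indexing the quadrature nodes and columns indexing the items.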
Solve estimating equations
function compute_probs(θ, a, d)
    [logistic(θ[i] * a[j] + d[j]) for i = eachindex(θ), j = eachindex(a)]
end
compute_probs (generic function with 1 method)
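function esteq_2pl_em(par, bar_r, bar_n, θ)
    num_items = size(bar_r, 2)
    a = par[1:num_items]
    d = par[num_items+1:end]
    # K × J matrix of bar_rⱼₖ - bar_nₖ Pⱼ(Xₖ)
    rmntpθ = bar_r .- bar_n .* compute_probs(θ, a, d)
    # Σₖ(…) for the dⱼ equations, then Σₖ Xₖ(…) for the aⱼ equations
    vec([sum(rmntpθ, dims=1) θ' * rmntpθ])
end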
esteq_2pl_em (generic function with 1 method)
# Test:
root1 = nlsolve(x -> esteq_2pl_em(x, exp1.bar_rⱼₖ, exp1.bar_nₖ, gh15_node),
                [ones(5); zeros(5)])
Results of Nonlinear Solver Algorithm
* Algorithm: Trust-region with dogleg and autoscaling
* Starting Point: [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
* Zero: [0.9441787368725837, 0.9412834339373086, 0.9702850411757558, 0.932179532744094, 0.9197528908890249, 2.166166737455263, 0.41020205101729856, -0.37879989990817486, 0.7262454452107591, 1.5321165384312954]
* Inf-norm of residuals: 0.000000
* Iterations: 6
* Convergence: true
* |x - x'| < 0.0e+00: false
* |f(x)| < 1.0e-08: true
* Function Calls (f): 7
* Jacobian Calls (df/dx): 7
The solution is stored in the zero field (root1.zero). These values are passed back to the E-step for the next iteration.
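The estimating equations for item $j$ involve only $a_j$ and $d_j$, so the 10-equation system actually separates into five independent 2 × 2 systems. The post solves them jointly with NLsolve; as an alternative, the sketch below solves one item with Newton-Raphson using analytic second derivatives (`mstep_item_newton` is a hypothetical helper, not part of NLsolve).

```julia
# Newton-Raphson for one item: solve Σₖ (r̄ₖ - n̄ₖ Pₖ) = 0 and Σₖ Xₖ (r̄ₖ - n̄ₖ Pₖ) = 0
function mstep_item_newton(bar_r, bar_n, X; a=1.0, d=0.0, tol=1e-10, max_iter=50)
    for _ in 1:max_iter
        P = @. 1 / (1 + exp(-(a * X + d)))
        resid = bar_r .- bar_n .* P
        g = [sum(X .* resid), sum(resid)]            # gradient w.r.t. (a, d)
        W = bar_n .* P .* (1 .- P)
        H = [sum(W .* X .^ 2) sum(W .* X);
             sum(W .* X)      sum(W)]                # negative Hessian
        step = H \ g
        a += step[1]
        d += step[2]
        maximum(abs.(step)) < tol && break
    end
    return a, d
end

# Synthetic check: expected counts generated from known item parameters
X = collect(range(-3, 3; length=11))
bar_n = 100 .* exp.(-X .^ 2 ./ 2)
a_true, d_true = 0.8, -0.4
bar_r = bar_n .* (1 ./ (1 .+ exp.(-(a_true .* X .+ d_true))))
a_hat, d_hat = mstep_item_newton(bar_r, bar_n, X)
```

Applied to the post's objects, one could run `[mstep_item_newton(exp1.bar_rⱼₖ[:, j], exp1.bar_nₖ, gh15_node) for j in 1:5]`. The results should agree with root1.zero (which lists the five aⱼ first, then the five dⱼ) up to solver tolerance, since both target the same equations.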
Iterations
We can do two more iterations:
exp2 = eloglik_2pl_em(Matrix(lsat[:, 1:5]), lsat[:, 6],
                      gh15_node, gh15_weight, root1.zero)
(bar_nₖ = [1.7202823621463788e-7, 0.00012478229640175463, 0.012917226708649957, 0.42950281900235543, 6.374685649928196, 47.54065170285358, 177.55073039082802, 324.1944490261614, 289.0821550106182, 125.99472138888558, 26.268899898287724, 2.460042585603299, 0.09015529832828216, 0.0009626597451594212, 1.3887248384670485e-6], bar_rⱼₖ = [5.372365387113749e-9 6.042186880320284e-10 … 1.1967481795709406e-9 3.692939772198374e-9; 1.1195428419388414e-5 1.486089377064305e-6 … 2.671355380687325e-6 7.62773320369415e-6; … ; 0.0009620030652981468 0.000959109431245618 … 0.0009598307020344003 0.0009612789437790556; 1.3884114755863066e-6 1.3870244275199353e-6 … 1.3873554006168715e-6 1.3880472692324014e-6])
root2 = nlsolve(x -> esteq_2pl_em(x, exp2.bar_rⱼₖ, exp2.bar_nₖ, gh15_node),
                root1.zero)
Results of Nonlinear Solver Algorithm
* Algorithm: Trust-region with dogleg and autoscaling
* Starting Point: [0.9441787368725837, 0.9412834339373086, 0.9702850411757558, 0.932179532744094, 0.9197528908890249, 2.166166737455263, 0.41020205101729856, -0.37879989990817486, 0.7262454452107591, 1.5321165384312954]
* Zero: [0.8851376144605885, 0.8780270229700008, 0.9305712408536134, 0.8617362606683743, 0.8410113215381403, 2.5424571803032903, 0.781691569809412, -0.0034125639910551494, 1.0943152659197337, 1.8930834808194317]
* Inf-norm of residuals: 0.000000
* Iterations: 4
* Convergence: true
* |x - x'| < 0.0e+00: false
* |f(x)| < 1.0e-08: true
* Function Calls (f): 5
* Jacobian Calls (df/dx): 5
exp3 = eloglik_2pl_em(Matrix(lsat[:, 1:5]), lsat[:, 6],
                      gh15_node, gh15_weight, root2.zero)
(bar_nₖ = [4.3023109522800455e-7, 0.0003040218519114245, 0.02987372125524563, 0.9055221912858598, 11.602389316066303, 71.03280737295191, 214.59437201659995, 325.67522637877084, 253.23782066569376, 101.01070654940031, 20.01802278052005, 1.826159972919833, 0.06609181187521956, 0.000701760614451534, 1.0099632324488185e-6], bar_rⱼₖ = [2.0531055047839344e-8 2.5014192320684056e-9 … 4.965693882437507e-9 1.5557900531690245e-8; 3.8265586589167865e-5 5.753192037170773e-6 … 1.0076467391887043e-5 2.8241538126234826e-5; … ; 0.0007012180412131716 0.0006987629848589311 … 0.0006992811247620892 0.0007004963056719968; 1.00968636885505e-6 1.0084204735673418e-6 … 1.008662595953656e-6 1.0092845394530493e-6])
root3 = nlsolve(x -> esteq_2pl_em(x, exp3.bar_rⱼₖ, exp3.bar_nₖ, gh15_node),
                root2.zero)
Results of Nonlinear Solver Algorithm
* Algorithm: Trust-region with dogleg and autoscaling
* Starting Point: [0.8851376144605885, 0.8780270229700008, 0.9305712408536134, 0.8617362606683743, 0.8410113215381403, 2.5424571803032903, 0.781691569809412, -0.0034125639910551494, 1.0943152659197337, 1.8930834808194317]
* Zero: [0.8479491477352421, 0.8323916555812338, 0.902644467313084, 0.8111833641429015, 0.7863596508859064, 2.6865158580988453, 0.9283333274643143, 0.1559502636058875, 1.2352944393730914, 2.0242412851515406]
* Inf-norm of residuals: 0.000000
* Iterations: 4
* Convergence: true
* |x - x'| < 0.0e+00: false
* |f(x)| < 1.0e-08: true
* Function Calls (f): 5
* Jacobian Calls (df/dx): 5
Stopping criteria
One way to stop the iteration is when the absolute change in the parameter estimates is less than a certain threshold (e.g., 0.00001). For example, the following shows the maximum absolute change in the parameter estimates from the second and the third iteration:
maximum(abs.(root3.zero .- root2.zero))
0.15936282759694267
which is still large, so more iterations are needed.
Benchmarking
We can wrap the steps into a function:
function estimate_2pl_em(y, n, init,
                         n_quadpts=101, par_tol=1e-5, rtol=1e-5, max_iter=1000)
    # Note: rtol is accepted but not used below
    parₜ = init
    parₜ₊₁ = parₜ
    # Convert y to matrix
    y = Matrix(y)
    # Obtain quadrature nodes and weights (rescaled for a N(0, 1) prior)
    ghq = gausshermite(n_quadpts)
    ghq_θ = ghq[1] .* √2
    ghq_w = ghq[2] ./ √π
    i = 1
    while i < max_iter
        # E-step: expected counts at the current estimates
        expₜ = eloglik_2pl_em(y, n, ghq_θ, ghq_w, parₜ)
        # M-step: solve the estimating equations, starting from the current estimates
        root = nlsolve(x -> esteq_2pl_em(x, expₜ.bar_rⱼₖ, expₜ.bar_nₖ, ghq_θ),
                       parₜ, autodiff=:forward)
        parₜ₊₁ = root.zero
        # Stop once the largest absolute parameter change is below par_tol
        if maximum(abs.(parₜ₊₁ - parₜ)) < par_tol
            break
        else
            parₜ = parₜ₊₁
            i += 1
        end
    end
    (estimate=parₜ₊₁, num_iter=i)
end
estimate_2pl_em (generic function with 5 methods)
@btime est_em = estimate_2pl_em(lsat[:, 1:5], lsat[:, 6], [ones(5); zeros(5)])
16.968 ms (16212 allocations: 29.76 MiB)
(estimate = [0.8256464036061023, 0.7227767138869371, 0.8907854481666115, 0.6883896452213724, 0.6568975184341934, 2.773226187628446, 0.9902095352382095, 0.24914112339506816, 1.284763785097769, 2.0532889359951128], num_iter = 66)
Compare to mirt
data("LSAT", package = "ltm")
library(mirt)
bench::mark(
  mirt = mirt(LSAT,
              verbose = FALSE, quadpts = 101,
              TOL = 1e-5)
)
# A tibble: 1 × 6
expression min median `itr/sec` mem_alloc `gc/sec`
<bch:expr> <bch:tm> <bch:tm> <dbl> <bch:byt> <dbl>
1 mirt 107ms 107ms 9.38 22.4MB 28.1
Remark
My implementation of the EM algorithm takes less time to run than direct MML (see my post in Part 1), but it does not compute the standard errors. It also probably uses a different convergence criterion than direct MML with Optim.jl, so it's hard to say which one is faster.