Skip to content
3 changes: 3 additions & 0 deletions src/JuMP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ mutable struct GenericModel{T<:Real} <: AbstractModel
# A dictionary to store timing information from the JuMP macros.
enable_macro_timing::Bool
macro_times::Dict{Tuple{LineNumberNode,String},Float64}
# We use `Any` as key because we haven't defined `GenericNonlinearExpr` yet
subexpressions::Dict{UInt64,MOI.ScalarNonlinearFunction}
end

value_type(::Type{GenericModel{T}}) where {T} = T
Expand Down Expand Up @@ -251,6 +253,7 @@ function direct_generic_model(
Dict{Any,MOI.ConstraintIndex}(),
false,
Dict{Tuple{LineNumberNode,String},Float64}(),
Dict{UInt64,MOI.ScalarNonlinearFunction}(),
)
end

Expand Down
18 changes: 16 additions & 2 deletions src/constraints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,10 @@ function moi_function(constraint::AbstractConstraint)
return moi_function(jump_function(constraint))
end

function moi_function(model, constraint::AbstractConstraint)
return moi_function(model, jump_function(constraint))
end

"""
moi_set(constraint::AbstractConstraint)

Expand Down Expand Up @@ -1016,6 +1020,17 @@ function _moi_add_constraint(
return MOI.add_constraint(model, f, s)
end

function check_belongs_to_model(f::Vector, model)
for func in f
check_belongs_to_model(func, model)
end
end

function moi_function(model, f)
check_belongs_to_model(f, model)
return moi_function(f)
end

"""
add_constraint(
model::GenericModel,
Expand All @@ -1032,10 +1047,9 @@ function add_constraint(
name::String = "",
)
con = model_convert(model, con)
func, set = moi_function(model, con), moi_set(con)
# The type of backend(model) is unknown so we directly redirect to another
# function.
check_belongs_to_model(con, model)
func, set = moi_function(con), moi_set(con)
cindex = _moi_add_constraint(
backend(model),
func,
Expand Down
29 changes: 24 additions & 5 deletions src/nlp_expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -569,18 +569,33 @@ end

moi_function(x::Number) = x

function moi_function(f::GenericNonlinearExpr{V}) where {V}
function moi_function(
model::JuMP.GenericModel,
f::GenericNonlinearExpr{V},
) where {V}
cache = model.subexpressions
key = objectid(f)
if haskey(cache, key)
return cache[key]
end
ret = MOI.ScalarNonlinearFunction(f.head, similar(f.args))
stack = Tuple{MOI.ScalarNonlinearFunction,Int,GenericNonlinearExpr{V}}[]
for i in length(f.args):-1:1
if f.args[i] isa GenericNonlinearExpr{V}
push!(stack, (ret, i, f.args[i]))
elseif f.args[i] isa AbstractJuMPScalar
ret.args[i] = moi_function(model, f.args[i])
else
ret.args[i] = moi_function(f.args[i])
end
end
while !isempty(stack)
parent, i, arg = pop!(stack)
arg_key = objectid(arg)
if haskey(cache, arg_key)
parent.args[i] = cache[arg_key]
continue
end
child = MOI.ScalarNonlinearFunction(arg.head, similar(arg.args))
parent.args[i] = child
for j in length(arg.args):-1:1
Expand All @@ -590,7 +605,9 @@ function moi_function(f::GenericNonlinearExpr{V}) where {V}
child.args[j] = moi_function(arg.args[j])
end
end
cache[arg_key] = child
end
cache[key] = ret
return ret
end

Expand Down Expand Up @@ -1187,7 +1204,8 @@ function moi_function(f::AbstractVector{<:GenericNonlinearExpr})
end

function MOI.VectorNonlinearFunction(f::Vector{<:AbstractJuMPScalar})
return MOI.VectorNonlinearFunction(map(moi_function, f))
model = owner_model(first(f))
return MOI.VectorNonlinearFunction(moi_function.(model, f))
end

"""
Expand Down Expand Up @@ -1233,7 +1251,7 @@ x
```
"""
function simplify(model::GenericModel, f::AbstractJuMPScalar)
g = MOI.Nonlinear.SymbolicAD.simplify(moi_function(f))
g = MOI.Nonlinear.SymbolicAD.simplify(moi_function(model, f))
return jump_function(model, g)
end

Expand Down Expand Up @@ -1284,7 +1302,8 @@ function derivative(
f::AbstractJuMPScalar,
x::GenericVariableRef{T},
) where {T}
df_dx = MOI.Nonlinear.SymbolicAD.derivative(moi_function(f), index(x))
df_dx =
MOI.Nonlinear.SymbolicAD.derivative(moi_function(model, f), index(x))
return jump_function(model, MOI.Nonlinear.SymbolicAD.simplify!(df_dx))
end

Expand Down Expand Up @@ -1329,7 +1348,7 @@ julia> ∇f[y]
```
"""
function gradient(model::GenericModel{T}, f::AbstractJuMPScalar) where {T}
g = moi_function(f)
g = moi_function(model, f)
∇f = Dict{GenericVariableRef{T},Any}()
for xi in MOI.Nonlinear.SymbolicAD.variables(g)
df_dx = MOI.Nonlinear.SymbolicAD.simplify!(
Expand Down
4 changes: 2 additions & 2 deletions src/objective.jl
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ end

function set_objective_function(model::GenericModel, func::AbstractJuMPScalar)
check_belongs_to_model(func, model)
set_objective_function(model, moi_function(func))
set_objective_function(model, moi_function(model, func))
return
end

Expand All @@ -296,7 +296,7 @@ function set_objective_function(
for f in func
check_belongs_to_model(f, model)
end
set_objective_function(model, moi_function(func))
set_objective_function(model, moi_function(model, func))
return
end

Expand Down
12 changes: 6 additions & 6 deletions test/Containers/test_DenseAxisArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -740,11 +740,11 @@ end

function test_containers_denseaxisarray_kwarg_indexing_slicing()
Containers.@container(x[i=2:3, j=1:2], i + j)
y = x[i=2, j = :]
y = x[i=2, j=:]
@test y[j=2] == 4
y = x[i = :, j=1]
y = x[i=:, j=1]
@test y[i=3] == 4
y = x[i = :, j = :]
y = x[i=:, j=:]
@test y[i=3, j=1] == 4
return
end
Expand Down Expand Up @@ -801,11 +801,11 @@ end
function test_containers_denseaxisarrayview_kwarg_indexing_slicing()
Containers.@container(a[i=2:3, j=1:2], i + j)
x = view(a, :, :)
y = x[i=2, j = :]
y = x[i=2, j=:]
@test y[j=2] == 4
y = x[i = :, j=1]
y = x[i=:, j=1]
@test y[i=3] == 4
y = x[i = :, j = :]
y = x[i=:, j=:]
@test y[i=3, j=1] == 4
return
end
Expand Down
6 changes: 3 additions & 3 deletions test/Containers/test_SparseAxisArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -296,11 +296,11 @@ end

function test_containers_sparseaxisarray_kwarg_indexing_slicing()
Containers.@container(x[i=2:3, j=1:2], i + j, container = SparseAxisArray)
y = x[i=2, j = :]
y = x[i=2, j=:]
@test y[j=2] == 4
y = x[i = :, j=1]
y = x[i=:, j=1]
@test y[i=3] == 4
y = x[i = :, j = :]
y = x[i=:, j=:]
@test y[i=3, j=1] == 4
return
end
Expand Down
Loading
Loading