Skip to content

Commit 485f776

Browse files
add object state in functional
1 parent d1ab538 commit 485f776

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

keras/src/models/functional.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,13 @@ class Functional(Function, Model):
9898
"""
9999

100100
def __new__(cls, *args, **kwargs):
101-
obj = super().__new__(cls, *args, **kwargs)
102-
return obj
101+
102+
if backend.backend() == "jax" and is_nnx_enabled():
103+
instance = super(Functional, cls).__new__(cls)
104+
from flax import nnx
105+
106+
vars(instance)["_object__state"] = nnx.object.ObjectState()
107+
return typing.cast(cls, super().__new__(cls))
103108

104109

105110
@tracking.no_automatic_dependency_tracking

0 commit comments

Comments
 (0)