We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 9f80bdc commit 6aeed97Copy full SHA for 6aeed97
pytensor/link/jax/linker.py
@@ -117,10 +117,8 @@ def create_thunk_inputs(self, storage_map):
117
for n in self.fgraph.inputs:
118
sinput = storage_map[n]
119
if isinstance(sinput[0], Generator):
120
- new_value = jax_typify(
121
- sinput[0], dtype=getattr(sinput[0], "dtype", None)
122
- )
123
- sinput[0] = new_value
+ # Neet to convert Generator into JAX PRNGkey
+ sinput[0] = jax_typify(sinput[0])
124
thunk_inputs.append(sinput)
125
126
return thunk_inputs
0 commit comments