[flax] 'dtype' should not be part of self._internal_dict (#609)

This commit is contained in:
Mishig Davaadorj 2022-09-22 11:46:31 +02:00 committed by GitHub
parent 4b8880a306
commit 534512bedb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 0 deletions

View File

@ -456,6 +456,9 @@ def flax_register_to_config(cls):
# Make sure init_kwargs override default kwargs
new_kwargs = {**default_kwargs, **init_kwargs}
# dtype should be part of `init_kwargs`, but not `new_kwargs`
if "dtype" in new_kwargs:
new_kwargs.pop("dtype")
# Get positional arguments aligned with kwargs
for i, arg in enumerate(args):