6f15ac60b2
This PR adds a `force_downcast_after` argument to `FastRMSNorm.forward`, for use by the Gemma model. See https://github.com/huggingface/transformers/pull/29402 and https://github.com/huggingface/transformers/pull/29729 for reference.

Setting `force_downcast_after=True` performs the `hidden_states * weight` multiplication in f32 and only then downcasts the result to half precision. This differs slightly from the current implementation, which first casts `hidden_states` to half precision and then multiplies by the weight.
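The numerical difference between the two orderings can be illustrated with a small pure-Python emulation. This is a sketch, not the `FastRMSNorm` code: it assumes a float16 weight and emulates f16/f32 rounding with the standard `struct` module (`"e"` and `"f"` formats) instead of PyTorch tensors, and `to_f16`, `to_f32`, `apply_weight` and `rms_norm` are hypothetical helpers introduced only for illustration.

```python
import math
import struct
from typing import List


def to_f16(x: float) -> float:
    """Round x to the nearest IEEE 754 half-precision value (ties to even)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_f32(x: float) -> float:
    """Round x to the nearest IEEE 754 single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def apply_weight(normed: float, weight: float, force_downcast_after: bool) -> float:
    """Scale one f32 normalized value by a half-precision weight (hypothetical helper)."""
    if force_downcast_after:
        # Multiply in f32 (a half weight upcasts to f32 exactly), then downcast once.
        return to_f16(to_f32(normed * weight))
    # Current ordering: downcast the normalized value first, then multiply in half.
    return to_f16(to_f16(normed) * weight)


def rms_norm(hidden: List[float], weight: List[float], eps: float = 1e-6,
             force_downcast_after: bool = False) -> List[float]:
    """Emulated RMSNorm over one half-precision row, with statistics in f32."""
    xs = [to_f32(h) for h in hidden]  # every half value is exact in f32
    sq_sum = 0.0
    for x in xs:
        sq_sum = to_f32(sq_sum + to_f32(x * x))
    variance = to_f32(sq_sum / len(xs))
    # Approximates an f32 rsqrt by rounding the f64 result to f32.
    inv_rms = to_f32(1.0 / math.sqrt(to_f32(variance + eps)))
    normed = [to_f32(x * inv_rms) for x in xs]
    return [apply_weight(n, w, force_downcast_after) for n, w in zip(normed, weight)]


# The exact product 1.000732421875 * 3.0 is 3.002197265625.
print(apply_weight(1.000732421875, 3.0, force_downcast_after=True))   # 3.001953125
print(apply_weight(1.000732421875, 3.0, force_downcast_after=False))  # 3.00390625
```

In this example, downcasting after the f32 multiply gives the half value nearest to the exact product, while downcasting first rounds twice and lands one half-precision step further away. With a weight of exactly 1.0 the two orderings agree, since multiplying by 1.0 introduces no extra rounding.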