fix(server): Fixing non parameters in quantize script `bigcode/starcoder` was an example. (#661)

This commit is contained in:
Nicolas Patry 2023-07-20 16:04:15 +02:00 committed by GitHub
parent 362883f259
commit 08b8eec1d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 1 deletions

View File

@ -812,10 +812,13 @@ def load_weights_pre_hook(module_name, weights, recursive=False):
tensor = weights.get_tensor(tensor_name)
setdeepattr(module, local_param, nn.Parameter(tensor))
else:
tensor = current_tensor.to(device=torch.device("cuda:0"))
if current_tensor.requires_grad:
tensor = nn.Parameter(tensor)
setdeepattr(
module,
local_param,
nn.Parameter(current_tensor.to(device=torch.device("cuda:0"))),
tensor
)
return inner