diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index af3206dd..53d3ea42 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -238,7 +238,7 @@ class PhiMLP(nn.Module): ) # llama weights are up_proj and down_proj and bias=False - self.up_proj = TensorParallelRowLinear.load( + self.up_proj = TensorParallelColumnLinear.load( config, prefix=f"{prefix}.fc1", weights=weights,