diff --git a/models/user/user.go b/models/user/user.go index 454779b9ea..7f2249925b 100644 --- a/models/user/user.go +++ b/models/user/user.go @@ -613,7 +613,7 @@ func CreateUser(u *User, overwriteDefault ...*CreateUserOverwriteOptions) (err e } // validate data - if err := validateUser(u); err != nil { + if err := ValidateUser(u); err != nil { return err } @@ -803,19 +803,26 @@ func checkDupEmail(ctx context.Context, u *User) error { return nil } -// validateUser check if user is valid to insert / update into database -func validateUser(u *User) error { - if !setting.Service.AllowedUserVisibilityModesSlice.IsAllowedVisibility(u.Visibility) && !u.IsOrganization() { - return fmt.Errorf("visibility Mode not allowed: %s", u.Visibility.String()) +// ValidateUser check if user is valid to insert / update into database +func ValidateUser(u *User, cols ...string) error { + if len(cols) == 0 || util.SliceContainsString(cols, "visibility", true) { + if !setting.Service.AllowedUserVisibilityModesSlice.IsAllowedVisibility(u.Visibility) && !u.IsOrganization() { + return fmt.Errorf("visibility Mode not allowed: %s", u.Visibility.String()) + } } - u.Email = strings.ToLower(u.Email) - return ValidateEmail(u.Email) + if len(cols) == 0 || util.SliceContainsString(cols, "email", true) { + u.Email = strings.ToLower(u.Email) + if err := ValidateEmail(u.Email); err != nil { + return err + } + } + return nil } // UpdateUser updates user's information. func UpdateUser(ctx context.Context, u *User, changePrimaryEmail bool, cols ...string) error { - err := validateUser(u) + err := ValidateUser(u, cols...) if err != nil { return err } @@ -881,7 +888,7 @@ func UpdateUser(ctx context.Context, u *User, changePrimaryEmail bool, cols ...s // UpdateUserCols update user according special columns func UpdateUserCols(ctx context.Context, u *User, cols ...string) error { - if err := validateUser(u); err != nil { + if err := ValidateUser(u, cols...); err != nil { return err } diff --git a/models/user/user_test.go b/models/user/user_test.go index 8e78fee6b3..1abd0f0049 100644 --- a/models/user/user_test.go +++ b/models/user/user_test.go @@ -5,6 +5,7 @@ package user_test import ( "context" + "fmt" "math/rand" "strings" "testing" @@ -524,3 +525,21 @@ func TestIsUserVisibleToViewer(t *testing.T) { test(user31, user33, true) test(user31, nil, false) } + +func Test_ValidateUser(t *testing.T) { + oldSetting := setting.Service.AllowedUserVisibilityModesSlice + defer func() { + setting.Service.AllowedUserVisibilityModesSlice = oldSetting + }() + setting.Service.AllowedUserVisibilityModesSlice = []bool{true, false, true} + kases := map[*user_model.User]bool{ + {ID: 1, Visibility: structs.VisibleTypePublic}: true, + {ID: 2, Visibility: structs.VisibleTypeLimited}: false, + {ID: 2, Visibility: structs.VisibleTypeLimited, Email: "invalid"}: false, + {ID: 2, Visibility: structs.VisibleTypePrivate, Email: "valid@valid.com"}: true, + } + for kase, expected := range kases { + err := user_model.ValidateUser(kase) + assert.EqualValues(t, expected, err == nil, fmt.Sprintf("case: %+v", kase)) + } +}