From 1535f139c17b5cfc6e2bcd32fb06263248596aef Mon Sep 17 00:00:00 2001 From: Jason Kulatunga Date: Mon, 12 Sep 2022 21:20:56 -0400 Subject: [PATCH] when making requests to the db, use the current user context. --- backend/pkg/auth/utils.go | 10 ++++----- backend/pkg/database/interface.go | 2 +- backend/pkg/database/sqlite_repository.go | 12 ++++++----- backend/pkg/web/middleware/require_auth.go | 24 +++++++++++++--------- 4 files changed, 27 insertions(+), 21 deletions(-) diff --git a/backend/pkg/auth/utils.go b/backend/pkg/auth/utils.go index b8cc4de2..cbcd4e43 100644 --- a/backend/pkg/auth/utils.go +++ b/backend/pkg/auth/utils.go @@ -30,7 +30,7 @@ func GenerateJWT(username string) (tokenString string, err error) { return } -func ValidateToken(signedToken string) (err error) { +func ValidateToken(signedToken string) (*JWTClaim, error) { token, err := jwt.ParseWithClaims( signedToken, &JWTClaim{}, @@ -42,16 +42,16 @@ func ValidateToken(signedToken string) (err error) { }, ) if err != nil { - return + return nil, err } claims, ok := token.Claims.(*JWTClaim) if !ok { err = errors.New("couldn't parse claims") - return + return nil, err } if claims.ExpiresAt < time.Now().Local().Unix() { err = errors.New("token expired") - return + return nil, err } - return + return claims, nil } diff --git a/backend/pkg/database/interface.go b/backend/pkg/database/interface.go index b8616ee2..adae76ce 100644 --- a/backend/pkg/database/interface.go +++ b/backend/pkg/database/interface.go @@ -11,7 +11,7 @@ type DatabaseRepository interface { CreateUser(context.Context, *models.User) error GetUserByEmail(context.Context, string) (*models.User, error) - GetCurrentUser() models.User + GetCurrentUser(context.Context) models.User UpsertResource(context.Context, models.ResourceFhir) error ListResources(context.Context, string, string) ([]models.ResourceFhir, error) diff --git a/backend/pkg/database/sqlite_repository.go b/backend/pkg/database/sqlite_repository.go index 28262d68..4e2bfd75 100644 --- a/backend/pkg/database/sqlite_repository.go +++ b/backend/pkg/database/sqlite_repository.go @@ -5,6 +5,7 @@ import ( "fmt" "github.com/fastenhealth/fastenhealth-onprem/backend/pkg/config" "github.com/fastenhealth/fastenhealth-onprem/backend/pkg/models" + "github.com/gin-gonic/gin" "github.com/glebarez/sqlite" "github.com/sirupsen/logrus" "gorm.io/gorm" @@ -98,9 +99,10 @@ func (sr *sqliteRepository) GetUserByEmail(ctx context.Context, username string) return &foundUser, result.Error } -func (sr *sqliteRepository) GetCurrentUser() models.User { +func (sr *sqliteRepository) GetCurrentUser(ctx context.Context) models.User { + ginCtx := ctx.(*gin.Context) var currentUser models.User - sr.gormClient.Model(models.User{}).First(¤tUser) + sr.gormClient.Model(models.User{}).First(¤tUser, models.User{Username: ginCtx.MustGet("AUTH_USERNAME").(string)}) return currentUser } @@ -128,7 +130,7 @@ func (sr *sqliteRepository) ListResources(ctx context.Context, sourceResourceTyp queryParam := models.ResourceFhir{ OriginBase: models.OriginBase{ - UserID: sr.GetCurrentUser().ID, + UserID: sr.GetCurrentUser(ctx).ID, SourceResourceType: sourceResourceType, }, } @@ -150,7 +152,7 @@ func (sr *sqliteRepository) ListResources(ctx context.Context, sourceResourceTyp //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// func (sr *sqliteRepository) CreateSource(ctx context.Context, providerCreds *models.Source) error { - providerCreds.UserID = sr.GetCurrentUser().ID + providerCreds.UserID = sr.GetCurrentUser(ctx).ID if sr.gormClient.WithContext(ctx).Model(&providerCreds). Where(models.Source{ @@ -166,7 +168,7 @@ func (sr *sqliteRepository) GetSources(ctx context.Context) ([]models.Source, er var providerCredentials []models.Source results := sr.gormClient.WithContext(ctx). - Where(models.Source{UserID: sr.GetCurrentUser().ID}). + Where(models.Source{UserID: sr.GetCurrentUser(ctx).ID}). Find(&providerCredentials) return providerCredentials, results.Error diff --git a/backend/pkg/web/middleware/require_auth.go b/backend/pkg/web/middleware/require_auth.go index 2f6491fe..0c56044e 100644 --- a/backend/pkg/web/middleware/require_auth.go +++ b/backend/pkg/web/middleware/require_auth.go @@ -9,30 +9,34 @@ import ( ) func RequireAuth() gin.HandlerFunc { - return func(context *gin.Context) { - authHeader := context.GetHeader("Authorization") + return func(c *gin.Context) { + authHeader := c.GetHeader("Authorization") authHeaderParts := strings.Split(authHeader, " ") if len(authHeaderParts) != 2 { log.Println("Authentication header is invalid: " + authHeader) - context.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": "request does not contain a valid token"}) - context.Abort() + c.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": "request does not contain a valid token"}) + c.Abort() return } tokenString := authHeaderParts[1] if tokenString == "" { - context.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": "request does not contain an access token"}) - context.Abort() + c.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": "request does not contain an access token"}) + c.Abort() return } - err := auth.ValidateToken(tokenString) + claim, err := auth.ValidateToken(tokenString) if err != nil { - context.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": err.Error()}) - context.Abort() + c.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": err.Error()}) + c.Abort() return } - context.Next() + + c.Set("AUTH_TOKEN", tokenString) + c.Set("AUTH_USERNAME", claim.Username) + + c.Next() } }