when making requests to the db, use the current user context.

This commit is contained in:
Jason Kulatunga 2022-09-12 21:20:56 -04:00
parent 3811599c19
commit 1535f139c1
4 changed files with 27 additions and 21 deletions

View File

@ -30,7 +30,7 @@ func GenerateJWT(username string) (tokenString string, err error) {
return return
} }
func ValidateToken(signedToken string) (err error) { func ValidateToken(signedToken string) (*JWTClaim, error) {
token, err := jwt.ParseWithClaims( token, err := jwt.ParseWithClaims(
signedToken, signedToken,
&JWTClaim{}, &JWTClaim{},
@ -42,16 +42,16 @@ func ValidateToken(signedToken string) (err error) {
}, },
) )
if err != nil { if err != nil {
return return nil, err
} }
claims, ok := token.Claims.(*JWTClaim) claims, ok := token.Claims.(*JWTClaim)
if !ok { if !ok {
err = errors.New("couldn't parse claims") err = errors.New("couldn't parse claims")
return return nil, err
} }
if claims.ExpiresAt < time.Now().Local().Unix() { if claims.ExpiresAt < time.Now().Local().Unix() {
err = errors.New("token expired") err = errors.New("token expired")
return return nil, err
} }
return return claims, nil
} }

View File

@ -11,7 +11,7 @@ type DatabaseRepository interface {
CreateUser(context.Context, *models.User) error CreateUser(context.Context, *models.User) error
GetUserByEmail(context.Context, string) (*models.User, error) GetUserByEmail(context.Context, string) (*models.User, error)
GetCurrentUser() models.User GetCurrentUser(context.Context) models.User
UpsertResource(context.Context, models.ResourceFhir) error UpsertResource(context.Context, models.ResourceFhir) error
ListResources(context.Context, string, string) ([]models.ResourceFhir, error) ListResources(context.Context, string, string) ([]models.ResourceFhir, error)

View File

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"github.com/fastenhealth/fastenhealth-onprem/backend/pkg/config" "github.com/fastenhealth/fastenhealth-onprem/backend/pkg/config"
"github.com/fastenhealth/fastenhealth-onprem/backend/pkg/models" "github.com/fastenhealth/fastenhealth-onprem/backend/pkg/models"
"github.com/gin-gonic/gin"
"github.com/glebarez/sqlite" "github.com/glebarez/sqlite"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"gorm.io/gorm" "gorm.io/gorm"
@ -98,9 +99,10 @@ func (sr *sqliteRepository) GetUserByEmail(ctx context.Context, username string)
return &foundUser, result.Error return &foundUser, result.Error
} }
func (sr *sqliteRepository) GetCurrentUser() models.User { func (sr *sqliteRepository) GetCurrentUser(ctx context.Context) models.User {
ginCtx := ctx.(*gin.Context)
var currentUser models.User var currentUser models.User
sr.gormClient.Model(models.User{}).First(&currentUser) sr.gormClient.Model(models.User{}).First(&currentUser, models.User{Username: ginCtx.MustGet("AUTH_USERNAME").(string)})
return currentUser return currentUser
} }
@ -128,7 +130,7 @@ func (sr *sqliteRepository) ListResources(ctx context.Context, sourceResourceTyp
queryParam := models.ResourceFhir{ queryParam := models.ResourceFhir{
OriginBase: models.OriginBase{ OriginBase: models.OriginBase{
UserID: sr.GetCurrentUser().ID, UserID: sr.GetCurrentUser(ctx).ID,
SourceResourceType: sourceResourceType, SourceResourceType: sourceResourceType,
}, },
} }
@ -150,7 +152,7 @@ func (sr *sqliteRepository) ListResources(ctx context.Context, sourceResourceTyp
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
func (sr *sqliteRepository) CreateSource(ctx context.Context, providerCreds *models.Source) error { func (sr *sqliteRepository) CreateSource(ctx context.Context, providerCreds *models.Source) error {
providerCreds.UserID = sr.GetCurrentUser().ID providerCreds.UserID = sr.GetCurrentUser(ctx).ID
if sr.gormClient.WithContext(ctx).Model(&providerCreds). if sr.gormClient.WithContext(ctx).Model(&providerCreds).
Where(models.Source{ Where(models.Source{
@ -166,7 +168,7 @@ func (sr *sqliteRepository) GetSources(ctx context.Context) ([]models.Source, er
var providerCredentials []models.Source var providerCredentials []models.Source
results := sr.gormClient.WithContext(ctx). results := sr.gormClient.WithContext(ctx).
Where(models.Source{UserID: sr.GetCurrentUser().ID}). Where(models.Source{UserID: sr.GetCurrentUser(ctx).ID}).
Find(&providerCredentials) Find(&providerCredentials)
return providerCredentials, results.Error return providerCredentials, results.Error

View File

@ -9,30 +9,34 @@ import (
) )
func RequireAuth() gin.HandlerFunc { func RequireAuth() gin.HandlerFunc {
return func(context *gin.Context) { return func(c *gin.Context) {
authHeader := context.GetHeader("Authorization") authHeader := c.GetHeader("Authorization")
authHeaderParts := strings.Split(authHeader, " ") authHeaderParts := strings.Split(authHeader, " ")
if len(authHeaderParts) != 2 { if len(authHeaderParts) != 2 {
log.Println("Authentication header is invalid: " + authHeader) log.Println("Authentication header is invalid: " + authHeader)
context.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": "request does not contain a valid token"}) c.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": "request does not contain a valid token"})
context.Abort() c.Abort()
return return
} }
tokenString := authHeaderParts[1] tokenString := authHeaderParts[1]
if tokenString == "" { if tokenString == "" {
context.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": "request does not contain an access token"}) c.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": "request does not contain an access token"})
context.Abort() c.Abort()
return return
} }
err := auth.ValidateToken(tokenString) claim, err := auth.ValidateToken(tokenString)
if err != nil { if err != nil {
context.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": err.Error()}) c.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": err.Error()})
context.Abort() c.Abort()
return return
} }
context.Next()
c.Set("AUTH_TOKEN", tokenString)
c.Set("AUTH_USERNAME", claim.Username)
c.Next()
} }
} }