when making requests to the db, use the current user context.

This commit is contained in:
Jason Kulatunga 2022-09-12 21:20:56 -04:00
parent 3811599c19
commit 1535f139c1
4 changed files with 27 additions and 21 deletions

View File

@ -30,7 +30,7 @@ func GenerateJWT(username string) (tokenString string, err error) {
return
}
func ValidateToken(signedToken string) (err error) {
func ValidateToken(signedToken string) (*JWTClaim, error) {
token, err := jwt.ParseWithClaims(
signedToken,
&JWTClaim{},
@ -42,16 +42,16 @@ func ValidateToken(signedToken string) (err error) {
},
)
if err != nil {
return
return nil, err
}
claims, ok := token.Claims.(*JWTClaim)
if !ok {
err = errors.New("couldn't parse claims")
return
return nil, err
}
if claims.ExpiresAt < time.Now().Local().Unix() {
err = errors.New("token expired")
return
return nil, err
}
return
return claims, nil
}

View File

@ -11,7 +11,7 @@ type DatabaseRepository interface {
CreateUser(context.Context, *models.User) error
GetUserByEmail(context.Context, string) (*models.User, error)
GetCurrentUser() models.User
GetCurrentUser(context.Context) models.User
UpsertResource(context.Context, models.ResourceFhir) error
ListResources(context.Context, string, string) ([]models.ResourceFhir, error)

View File

@ -5,6 +5,7 @@ import (
"fmt"
"github.com/fastenhealth/fastenhealth-onprem/backend/pkg/config"
"github.com/fastenhealth/fastenhealth-onprem/backend/pkg/models"
"github.com/gin-gonic/gin"
"github.com/glebarez/sqlite"
"github.com/sirupsen/logrus"
"gorm.io/gorm"
@ -98,9 +99,10 @@ func (sr *sqliteRepository) GetUserByEmail(ctx context.Context, username string)
return &foundUser, result.Error
}
func (sr *sqliteRepository) GetCurrentUser() models.User {
func (sr *sqliteRepository) GetCurrentUser(ctx context.Context) models.User {
ginCtx := ctx.(*gin.Context)
var currentUser models.User
sr.gormClient.Model(models.User{}).First(&currentUser)
sr.gormClient.Model(models.User{}).First(&currentUser, models.User{Username: ginCtx.MustGet("AUTH_USERNAME").(string)})
return currentUser
}
@ -128,7 +130,7 @@ func (sr *sqliteRepository) ListResources(ctx context.Context, sourceResourceTyp
queryParam := models.ResourceFhir{
OriginBase: models.OriginBase{
UserID: sr.GetCurrentUser().ID,
UserID: sr.GetCurrentUser(ctx).ID,
SourceResourceType: sourceResourceType,
},
}
@ -150,7 +152,7 @@ func (sr *sqliteRepository) ListResources(ctx context.Context, sourceResourceTyp
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
func (sr *sqliteRepository) CreateSource(ctx context.Context, providerCreds *models.Source) error {
providerCreds.UserID = sr.GetCurrentUser().ID
providerCreds.UserID = sr.GetCurrentUser(ctx).ID
if sr.gormClient.WithContext(ctx).Model(&providerCreds).
Where(models.Source{
@ -166,7 +168,7 @@ func (sr *sqliteRepository) GetSources(ctx context.Context) ([]models.Source, er
var providerCredentials []models.Source
results := sr.gormClient.WithContext(ctx).
Where(models.Source{UserID: sr.GetCurrentUser().ID}).
Where(models.Source{UserID: sr.GetCurrentUser(ctx).ID}).
Find(&providerCredentials)
return providerCredentials, results.Error

View File

@ -9,30 +9,34 @@ import (
)
func RequireAuth() gin.HandlerFunc {
return func(context *gin.Context) {
authHeader := context.GetHeader("Authorization")
return func(c *gin.Context) {
authHeader := c.GetHeader("Authorization")
authHeaderParts := strings.Split(authHeader, " ")
if len(authHeaderParts) != 2 {
log.Println("Authentication header is invalid: " + authHeader)
context.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": "request does not contain a valid token"})
context.Abort()
c.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": "request does not contain a valid token"})
c.Abort()
return
}
tokenString := authHeaderParts[1]
if tokenString == "" {
context.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": "request does not contain an access token"})
context.Abort()
c.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": "request does not contain an access token"})
c.Abort()
return
}
err := auth.ValidateToken(tokenString)
claim, err := auth.ValidateToken(tokenString)
if err != nil {
context.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": err.Error()})
context.Abort()
c.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": err.Error()})
c.Abort()
return
}
context.Next()
c.Set("AUTH_TOKEN", tokenString)
c.Set("AUTH_USERNAME", claim.Username)
c.Next()
}
}