when making requests to the db, use the current user context.
This commit is contained in:
parent
3811599c19
commit
1535f139c1
|
@ -30,7 +30,7 @@ func GenerateJWT(username string) (tokenString string, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func ValidateToken(signedToken string) (err error) {
|
||||
func ValidateToken(signedToken string) (*JWTClaim, error) {
|
||||
token, err := jwt.ParseWithClaims(
|
||||
signedToken,
|
||||
&JWTClaim{},
|
||||
|
@ -42,16 +42,16 @@ func ValidateToken(signedToken string) (err error) {
|
|||
},
|
||||
)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
claims, ok := token.Claims.(*JWTClaim)
|
||||
if !ok {
|
||||
err = errors.New("couldn't parse claims")
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
if claims.ExpiresAt < time.Now().Local().Unix() {
|
||||
err = errors.New("token expired")
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
return
|
||||
return claims, nil
|
||||
}
|
||||
|
|
|
@ -11,7 +11,7 @@ type DatabaseRepository interface {
|
|||
|
||||
CreateUser(context.Context, *models.User) error
|
||||
GetUserByEmail(context.Context, string) (*models.User, error)
|
||||
GetCurrentUser() models.User
|
||||
GetCurrentUser(context.Context) models.User
|
||||
|
||||
UpsertResource(context.Context, models.ResourceFhir) error
|
||||
ListResources(context.Context, string, string) ([]models.ResourceFhir, error)
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"fmt"
|
||||
"github.com/fastenhealth/fastenhealth-onprem/backend/pkg/config"
|
||||
"github.com/fastenhealth/fastenhealth-onprem/backend/pkg/models"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/glebarez/sqlite"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gorm.io/gorm"
|
||||
|
@ -98,9 +99,10 @@ func (sr *sqliteRepository) GetUserByEmail(ctx context.Context, username string)
|
|||
return &foundUser, result.Error
|
||||
}
|
||||
|
||||
func (sr *sqliteRepository) GetCurrentUser() models.User {
|
||||
func (sr *sqliteRepository) GetCurrentUser(ctx context.Context) models.User {
|
||||
ginCtx := ctx.(*gin.Context)
|
||||
var currentUser models.User
|
||||
sr.gormClient.Model(models.User{}).First(¤tUser)
|
||||
sr.gormClient.Model(models.User{}).First(¤tUser, models.User{Username: ginCtx.MustGet("AUTH_USERNAME").(string)})
|
||||
|
||||
return currentUser
|
||||
}
|
||||
|
@ -128,7 +130,7 @@ func (sr *sqliteRepository) ListResources(ctx context.Context, sourceResourceTyp
|
|||
|
||||
queryParam := models.ResourceFhir{
|
||||
OriginBase: models.OriginBase{
|
||||
UserID: sr.GetCurrentUser().ID,
|
||||
UserID: sr.GetCurrentUser(ctx).ID,
|
||||
SourceResourceType: sourceResourceType,
|
||||
},
|
||||
}
|
||||
|
@ -150,7 +152,7 @@ func (sr *sqliteRepository) ListResources(ctx context.Context, sourceResourceTyp
|
|||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
func (sr *sqliteRepository) CreateSource(ctx context.Context, providerCreds *models.Source) error {
|
||||
providerCreds.UserID = sr.GetCurrentUser().ID
|
||||
providerCreds.UserID = sr.GetCurrentUser(ctx).ID
|
||||
|
||||
if sr.gormClient.WithContext(ctx).Model(&providerCreds).
|
||||
Where(models.Source{
|
||||
|
@ -166,7 +168,7 @@ func (sr *sqliteRepository) GetSources(ctx context.Context) ([]models.Source, er
|
|||
|
||||
var providerCredentials []models.Source
|
||||
results := sr.gormClient.WithContext(ctx).
|
||||
Where(models.Source{UserID: sr.GetCurrentUser().ID}).
|
||||
Where(models.Source{UserID: sr.GetCurrentUser(ctx).ID}).
|
||||
Find(&providerCredentials)
|
||||
|
||||
return providerCredentials, results.Error
|
||||
|
|
|
@ -9,30 +9,34 @@ import (
|
|||
)
|
||||
|
||||
func RequireAuth() gin.HandlerFunc {
|
||||
return func(context *gin.Context) {
|
||||
authHeader := context.GetHeader("Authorization")
|
||||
return func(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
authHeaderParts := strings.Split(authHeader, " ")
|
||||
|
||||
if len(authHeaderParts) != 2 {
|
||||
log.Println("Authentication header is invalid: " + authHeader)
|
||||
context.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": "request does not contain a valid token"})
|
||||
context.Abort()
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": "request does not contain a valid token"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
tokenString := authHeaderParts[1]
|
||||
|
||||
if tokenString == "" {
|
||||
context.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": "request does not contain an access token"})
|
||||
context.Abort()
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": "request does not contain an access token"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
err := auth.ValidateToken(tokenString)
|
||||
claim, err := auth.ValidateToken(tokenString)
|
||||
if err != nil {
|
||||
context.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": err.Error()})
|
||||
context.Abort()
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"success": false, "error": err.Error()})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
context.Next()
|
||||
|
||||
c.Set("AUTH_TOKEN", tokenString)
|
||||
c.Set("AUTH_USERNAME", claim.Username)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue