make sure we handle error if the current user is invalid.

This commit is contained in:
Jason Kulatunga 2023-01-15 11:07:41 -08:00
parent cdf7f83777
commit e68900b1bc
3 changed files with 92 additions and 23 deletions

View File

@ -14,7 +14,7 @@ type DatabaseRepository interface {
CreateUser(context.Context, *models.User) error
GetUserByUsername(context.Context, string) (*models.User, error)
GetCurrentUser(context.Context) *models.User
GetCurrentUser(ctx context.Context) (*models.User, error)
GetSummary(ctx context.Context) (*models.Summary, error)

View File

@ -122,7 +122,8 @@ func (sr *SqliteRepository) GetUserByUsername(ctx context.Context, username stri
return &foundUser, result.Error
}
func (sr *SqliteRepository) GetCurrentUser(ctx context.Context) *models.User {
//TODO: check for error, right now we return a nil which may cause a panic.
func (sr *SqliteRepository) GetCurrentUser(ctx context.Context) (*models.User, error) {
username := ctx.Value(pkg.ContextKeyTypeAuthUsername)
if username == nil {
ginCtx := ctx.(*gin.Context)
@ -130,9 +131,13 @@ func (sr *SqliteRepository) GetCurrentUser(ctx context.Context) *models.User {
}
var currentUser models.User
sr.GormClient.First(&currentUser, models.User{Username: username.(string)})
result := sr.GormClient.First(&currentUser, models.User{Username: username.(string)})
return &currentUser
if result.Error != nil {
return nil, fmt.Errorf("could not retrieve current user: %v", result.Error)
}
return &currentUser, nil
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
@ -140,6 +145,10 @@ func (sr *SqliteRepository) GetCurrentUser(ctx context.Context) *models.User {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
func (sr *SqliteRepository) GetSummary(ctx context.Context) (*models.Summary, error) {
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
if currentUserErr != nil {
return nil, currentUserErr
}
// we want a count of all resources for this user by type
var resourceCountResults []map[string]interface{}
@ -151,7 +160,7 @@ func (sr *SqliteRepository) GetSummary(ctx context.Context) (*models.Summary, er
Select("source_resource_type as resource_type, count(*) as count").
Group("source_resource_type").
Where(models.OriginBase{
UserID: sr.GetCurrentUser(ctx).ID,
UserID: currentUser.ID,
}).
Scan(&resourceCountResults)
if result.Error != nil {
@ -235,7 +244,12 @@ func (sr *SqliteRepository) UpsertRawResource(ctx context.Context, sourceCredent
//this method will upsert a resource, however it will not create associations.
func (sr *SqliteRepository) UpsertResource(ctx context.Context, wrappedResourceModel *models.ResourceFhir) (bool, error) {
wrappedResourceModel.UserID = sr.GetCurrentUser(ctx).ID
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
if currentUserErr != nil {
return false, currentUserErr
}
wrappedResourceModel.UserID = currentUser.ID
wrappedResourceModel.RelatedResourceFhir = nil
cachedResourceRaw := wrappedResourceModel.ResourceRaw
@ -266,10 +280,14 @@ func (sr *SqliteRepository) UpsertResource(ctx context.Context, wrappedResourceM
}
func (sr *SqliteRepository) ListResources(ctx context.Context, queryOptions models.ListResourceQueryOptions) ([]models.ResourceFhir, error) {
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
if currentUserErr != nil {
return nil, currentUserErr
}
queryParam := models.ResourceFhir{
OriginBase: models.OriginBase{
UserID: sr.GetCurrentUser(ctx).ID,
UserID: currentUser.ID,
},
}
@ -305,9 +323,14 @@ func (sr *SqliteRepository) ListResources(ctx context.Context, queryOptions mode
}
func (sr *SqliteRepository) GetResourceBySourceType(ctx context.Context, sourceResourceType string, sourceResourceId string) (*models.ResourceFhir, error) {
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
if currentUserErr != nil {
return nil, currentUserErr
}
queryParam := models.ResourceFhir{
OriginBase: models.OriginBase{
UserID: sr.GetCurrentUser(ctx).ID,
UserID: currentUser.ID,
SourceResourceType: sourceResourceType,
SourceResourceID: sourceResourceId,
},
@ -322,6 +345,11 @@ func (sr *SqliteRepository) GetResourceBySourceType(ctx context.Context, sourceR
}
func (sr *SqliteRepository) GetResourceBySourceId(ctx context.Context, sourceId string, sourceResourceId string) (*models.ResourceFhir, error) {
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
if currentUserErr != nil {
return nil, currentUserErr
}
sourceIdUUID, err := uuid.Parse(sourceId)
if err != nil {
return nil, err
@ -329,7 +357,7 @@ func (sr *SqliteRepository) GetResourceBySourceId(ctx context.Context, sourceId
queryParam := models.ResourceFhir{
OriginBase: models.OriginBase{
UserID: sr.GetCurrentUser(ctx).ID,
UserID: currentUser.ID,
SourceID: sourceIdUUID,
SourceResourceID: sourceResourceId,
},
@ -345,6 +373,10 @@ func (sr *SqliteRepository) GetResourceBySourceId(ctx context.Context, sourceId
// Get the patient for each source (for the current user)
func (sr *SqliteRepository) GetPatientForSources(ctx context.Context) ([]models.ResourceFhir, error) {
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
if currentUserErr != nil {
return nil, currentUserErr
}
//SELECT * FROM resource_fhirs WHERE user_id = "" and source_resource_type = "Patient" GROUP BY source_id
@ -358,7 +390,7 @@ func (sr *SqliteRepository) GetPatientForSources(ctx context.Context) ([]models.
Model(models.ResourceFhir{}).
//Group("source_id"). //broken in Postgres.
Where(models.OriginBase{
UserID: sr.GetCurrentUser(ctx).ID,
UserID: currentUser.ID,
SourceResourceType: "Patient",
}).
Find(&wrappedResourceModels)
@ -370,6 +402,11 @@ func (sr *SqliteRepository) GetPatientForSources(ctx context.Context) ([]models.
// Generate a graph
// return list of root nodes, and their flattened related resources.
func (sr *SqliteRepository) GetFlattenedResourceGraph(ctx context.Context) ([]*models.ResourceFhir, []*models.ResourceFhir, error) {
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
if currentUserErr != nil {
return nil, nil, currentUserErr
}
// Get list of all resources
wrappedResourceModels, err := sr.ListResources(ctx, models.ListResourceQueryOptions{})
if err != nil {
@ -383,7 +420,7 @@ func (sr *SqliteRepository) GetFlattenedResourceGraph(ctx context.Context) ([]*m
result := sr.GormClient.WithContext(ctx).
Table("related_resources").
Where(models.RelatedResource{
ResourceFhirUserID: sr.GetCurrentUser(ctx).ID,
ResourceFhirUserID: currentUser.ID,
}).
Scan(&relatedResourceRelationships)
if result.Error != nil {
@ -528,10 +565,13 @@ func (sr *SqliteRepository) GetFlattenedResourceGraph(ctx context.Context) ([]*m
//related resources that are "Condition" or "Encounter"
func (sr *SqliteRepository) AddReciprocalResourceAssociations(ctx context.Context, source *models.SourceCredential, resource *models.ResourceFhir, relatedSource *models.SourceCredential, relatedResource *models.ResourceFhir) error {
//ensure that the sources are "owned" by the same user
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
if currentUserErr != nil {
return currentUserErr
}
if source.UserID != relatedSource.UserID {
return fmt.Errorf("user id's must match when adding associations")
} else if source.UserID != sr.GetCurrentUser(ctx).ID {
} else if source.UserID != currentUser.ID {
return fmt.Errorf("user id's must match current user")
}
@ -647,9 +687,14 @@ func (sr *SqliteRepository) AddResourceAssociation(ctx context.Context, source *
}
func (sr *SqliteRepository) RemoveResourceAssociation(ctx context.Context, source *models.SourceCredential, resourceType string, resourceId string, relatedSource *models.SourceCredential, relatedResourceType string, relatedResourceId string) error {
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
if currentUserErr != nil {
return currentUserErr
}
if source.UserID != relatedSource.UserID {
return fmt.Errorf("user id's must match when adding associations")
} else if source.UserID != sr.GetCurrentUser(ctx).ID {
} else if source.UserID != currentUser.ID {
return fmt.Errorf("user id's must match current user")
}
@ -679,7 +724,10 @@ func (sr *SqliteRepository) RemoveResourceAssociation(ctx context.Context, sourc
// - add AddResourceAssociation for all resources linked to the Composition resource
// - store the Composition resource
func (sr *SqliteRepository) AddResourceComposition(ctx context.Context, compositionTitle string, resources []*models.ResourceFhir) error {
currentUser := sr.GetCurrentUser(ctx)
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
if currentUserErr != nil {
return currentUserErr
}
//generate placeholder source
placeholderSource := models.SourceCredential{UserID: currentUser.ID, SourceType: "manual", ModelBase: models.ModelBase{ID: uuid.MustParse("00000000-0000-0000-0000-000000000000")}}
@ -821,7 +869,11 @@ func (sr *SqliteRepository) AddResourceComposition(ctx context.Context, composit
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
func (sr *SqliteRepository) CreateSource(ctx context.Context, sourceCreds *models.SourceCredential) error {
sourceCreds.UserID = sr.GetCurrentUser(ctx).ID
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
if currentUserErr != nil {
return currentUserErr
}
sourceCreds.UserID = currentUser.ID
//Assign will **always** update the source credential in the DB with data passed into this function.
return sr.GormClient.WithContext(ctx).
@ -833,6 +885,11 @@ func (sr *SqliteRepository) CreateSource(ctx context.Context, sourceCreds *model
}
func (sr *SqliteRepository) GetSource(ctx context.Context, sourceId string) (*models.SourceCredential, error) {
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
if currentUserErr != nil {
return nil, currentUserErr
}
sourceUUID, err := uuid.Parse(sourceId)
if err != nil {
return nil, err
@ -840,13 +897,18 @@ func (sr *SqliteRepository) GetSource(ctx context.Context, sourceId string) (*mo
var sourceCred models.SourceCredential
results := sr.GormClient.WithContext(ctx).
Where(models.SourceCredential{UserID: sr.GetCurrentUser(ctx).ID, ModelBase: models.ModelBase{ID: sourceUUID}}).
Where(models.SourceCredential{UserID: currentUser.ID, ModelBase: models.ModelBase{ID: sourceUUID}}).
First(&sourceCred)
return &sourceCred, results.Error
}
func (sr *SqliteRepository) GetSourceSummary(ctx context.Context, sourceId string) (*models.SourceSummary, error) {
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
if currentUserErr != nil {
return nil, currentUserErr
}
sourceUUID, err := uuid.Parse(sourceId)
if err != nil {
return nil, err
@ -870,7 +932,7 @@ func (sr *SqliteRepository) GetSourceSummary(ctx context.Context, sourceId strin
Select("source_id, source_resource_type as resource_type, count(*) as count").
Group("source_resource_type").
Where(models.OriginBase{
UserID: sr.GetCurrentUser(ctx).ID,
UserID: currentUser.ID,
SourceID: sourceUUID,
}).
Scan(&resourceTypeCounts)
@ -885,7 +947,7 @@ func (sr *SqliteRepository) GetSourceSummary(ctx context.Context, sourceId strin
var wrappedPatientResourceModel models.ResourceFhir
results := sr.GormClient.WithContext(ctx).
Where(models.OriginBase{
UserID: sr.GetCurrentUser(ctx).ID,
UserID: currentUser.ID,
SourceResourceType: "Patient",
SourceID: sourceUUID,
}).
@ -900,10 +962,14 @@ func (sr *SqliteRepository) GetSourceSummary(ctx context.Context, sourceId strin
}
func (sr *SqliteRepository) GetSources(ctx context.Context) ([]models.SourceCredential, error) {
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
if currentUserErr != nil {
return nil, currentUserErr
}
var sourceCreds []models.SourceCredential
results := sr.GormClient.WithContext(ctx).
Where(models.SourceCredential{UserID: sr.GetCurrentUser(ctx).ID}).
Where(models.SourceCredential{UserID: currentUser.ID}).
Find(&sourceCreds)
return sourceCreds, results.Error

View File

@ -210,9 +210,10 @@ func (suite *RepositoryTestSuite) TestGetCurrentUser_WithContextBackgroundAuthUs
require.NoError(suite.T(), err)
//test
userModelResult := dbRepo.GetCurrentUser(context.WithValue(context.Background(), pkg.ContextKeyTypeAuthUsername, "test_username"))
userModelResult, err := dbRepo.GetCurrentUser(context.WithValue(context.Background(), pkg.ContextKeyTypeAuthUsername, "test_username"))
//assert
require.NoError(suite.T(), err)
require.NotNil(suite.T(), userModelResult)
require.Equal(suite.T(), userModelResult.Username, "test_username")
}
@ -235,9 +236,10 @@ func (suite *RepositoryTestSuite) TestGetCurrentUser_WithGinContextBackgroundAut
//test
ginContext := gin.Context{}
ginContext.Set(pkg.ContextKeyTypeAuthUsername, "test_username")
userModelResult := dbRepo.GetCurrentUser(&ginContext)
userModelResult, err := dbRepo.GetCurrentUser(&ginContext)
//assert
require.NoError(suite.T(), err)
require.NotNil(suite.T(), userModelResult)
require.Equal(suite.T(), userModelResult.Username, "test_username")
}
@ -251,9 +253,10 @@ func (suite *RepositoryTestSuite) TestGetCurrentUser_WithContextBackgroundAuthUs
require.NoError(suite.T(), err)
//test
userModelResult := dbRepo.GetCurrentUser(context.WithValue(context.Background(), pkg.ContextKeyTypeAuthUsername, "test_username"))
userModelResult, err := dbRepo.GetCurrentUser(context.WithValue(context.Background(), pkg.ContextKeyTypeAuthUsername, "test_username"))
//assert
require.Error(suite.T(), err)
require.Nil(suite.T(), userModelResult)
}