make sure we handle error if the current user is invalid.

This commit is contained in:
Jason Kulatunga 2023-01-15 11:07:41 -08:00
parent cdf7f83777
commit e68900b1bc
3 changed files with 92 additions and 23 deletions

View File

@ -14,7 +14,7 @@ type DatabaseRepository interface {
CreateUser(context.Context, *models.User) error CreateUser(context.Context, *models.User) error
GetUserByUsername(context.Context, string) (*models.User, error) GetUserByUsername(context.Context, string) (*models.User, error)
GetCurrentUser(context.Context) *models.User GetCurrentUser(ctx context.Context) (*models.User, error)
GetSummary(ctx context.Context) (*models.Summary, error) GetSummary(ctx context.Context) (*models.Summary, error)

View File

@ -122,7 +122,8 @@ func (sr *SqliteRepository) GetUserByUsername(ctx context.Context, username stri
return &foundUser, result.Error return &foundUser, result.Error
} }
func (sr *SqliteRepository) GetCurrentUser(ctx context.Context) *models.User { //TODO: check for error, right now we return a nil which may cause a panic.
func (sr *SqliteRepository) GetCurrentUser(ctx context.Context) (*models.User, error) {
username := ctx.Value(pkg.ContextKeyTypeAuthUsername) username := ctx.Value(pkg.ContextKeyTypeAuthUsername)
if username == nil { if username == nil {
ginCtx := ctx.(*gin.Context) ginCtx := ctx.(*gin.Context)
@ -130,9 +131,13 @@ func (sr *SqliteRepository) GetCurrentUser(ctx context.Context) *models.User {
} }
var currentUser models.User var currentUser models.User
sr.GormClient.First(&currentUser, models.User{Username: username.(string)}) result := sr.GormClient.First(&currentUser, models.User{Username: username.(string)})
return &currentUser if result.Error != nil {
return nil, fmt.Errorf("could not retrieve current user: %v", result.Error)
}
return &currentUser, nil
} }
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
@ -140,6 +145,10 @@ func (sr *SqliteRepository) GetCurrentUser(ctx context.Context) *models.User {
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
func (sr *SqliteRepository) GetSummary(ctx context.Context) (*models.Summary, error) { func (sr *SqliteRepository) GetSummary(ctx context.Context) (*models.Summary, error) {
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
if currentUserErr != nil {
return nil, currentUserErr
}
// we want a count of all resources for this user by type // we want a count of all resources for this user by type
var resourceCountResults []map[string]interface{} var resourceCountResults []map[string]interface{}
@ -151,7 +160,7 @@ func (sr *SqliteRepository) GetSummary(ctx context.Context) (*models.Summary, er
Select("source_resource_type as resource_type, count(*) as count"). Select("source_resource_type as resource_type, count(*) as count").
Group("source_resource_type"). Group("source_resource_type").
Where(models.OriginBase{ Where(models.OriginBase{
UserID: sr.GetCurrentUser(ctx).ID, UserID: currentUser.ID,
}). }).
Scan(&resourceCountResults) Scan(&resourceCountResults)
if result.Error != nil { if result.Error != nil {
@ -235,7 +244,12 @@ func (sr *SqliteRepository) UpsertRawResource(ctx context.Context, sourceCredent
//this method will upsert a resource, however it will not create associations. //this method will upsert a resource, however it will not create associations.
func (sr *SqliteRepository) UpsertResource(ctx context.Context, wrappedResourceModel *models.ResourceFhir) (bool, error) { func (sr *SqliteRepository) UpsertResource(ctx context.Context, wrappedResourceModel *models.ResourceFhir) (bool, error) {
wrappedResourceModel.UserID = sr.GetCurrentUser(ctx).ID currentUser, currentUserErr := sr.GetCurrentUser(ctx)
if currentUserErr != nil {
return false, currentUserErr
}
wrappedResourceModel.UserID = currentUser.ID
wrappedResourceModel.RelatedResourceFhir = nil wrappedResourceModel.RelatedResourceFhir = nil
cachedResourceRaw := wrappedResourceModel.ResourceRaw cachedResourceRaw := wrappedResourceModel.ResourceRaw
@ -266,10 +280,14 @@ func (sr *SqliteRepository) UpsertResource(ctx context.Context, wrappedResourceM
} }
func (sr *SqliteRepository) ListResources(ctx context.Context, queryOptions models.ListResourceQueryOptions) ([]models.ResourceFhir, error) { func (sr *SqliteRepository) ListResources(ctx context.Context, queryOptions models.ListResourceQueryOptions) ([]models.ResourceFhir, error) {
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
if currentUserErr != nil {
return nil, currentUserErr
}
queryParam := models.ResourceFhir{ queryParam := models.ResourceFhir{
OriginBase: models.OriginBase{ OriginBase: models.OriginBase{
UserID: sr.GetCurrentUser(ctx).ID, UserID: currentUser.ID,
}, },
} }
@ -305,9 +323,14 @@ func (sr *SqliteRepository) ListResources(ctx context.Context, queryOptions mode
} }
func (sr *SqliteRepository) GetResourceBySourceType(ctx context.Context, sourceResourceType string, sourceResourceId string) (*models.ResourceFhir, error) { func (sr *SqliteRepository) GetResourceBySourceType(ctx context.Context, sourceResourceType string, sourceResourceId string) (*models.ResourceFhir, error) {
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
if currentUserErr != nil {
return nil, currentUserErr
}
queryParam := models.ResourceFhir{ queryParam := models.ResourceFhir{
OriginBase: models.OriginBase{ OriginBase: models.OriginBase{
UserID: sr.GetCurrentUser(ctx).ID, UserID: currentUser.ID,
SourceResourceType: sourceResourceType, SourceResourceType: sourceResourceType,
SourceResourceID: sourceResourceId, SourceResourceID: sourceResourceId,
}, },
@ -322,6 +345,11 @@ func (sr *SqliteRepository) GetResourceBySourceType(ctx context.Context, sourceR
} }
func (sr *SqliteRepository) GetResourceBySourceId(ctx context.Context, sourceId string, sourceResourceId string) (*models.ResourceFhir, error) { func (sr *SqliteRepository) GetResourceBySourceId(ctx context.Context, sourceId string, sourceResourceId string) (*models.ResourceFhir, error) {
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
if currentUserErr != nil {
return nil, currentUserErr
}
sourceIdUUID, err := uuid.Parse(sourceId) sourceIdUUID, err := uuid.Parse(sourceId)
if err != nil { if err != nil {
return nil, err return nil, err
@ -329,7 +357,7 @@ func (sr *SqliteRepository) GetResourceBySourceId(ctx context.Context, sourceId
queryParam := models.ResourceFhir{ queryParam := models.ResourceFhir{
OriginBase: models.OriginBase{ OriginBase: models.OriginBase{
UserID: sr.GetCurrentUser(ctx).ID, UserID: currentUser.ID,
SourceID: sourceIdUUID, SourceID: sourceIdUUID,
SourceResourceID: sourceResourceId, SourceResourceID: sourceResourceId,
}, },
@ -345,6 +373,10 @@ func (sr *SqliteRepository) GetResourceBySourceId(ctx context.Context, sourceId
// Get the patient for each source (for the current user) // Get the patient for each source (for the current user)
func (sr *SqliteRepository) GetPatientForSources(ctx context.Context) ([]models.ResourceFhir, error) { func (sr *SqliteRepository) GetPatientForSources(ctx context.Context) ([]models.ResourceFhir, error) {
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
if currentUserErr != nil {
return nil, currentUserErr
}
//SELECT * FROM resource_fhirs WHERE user_id = "" and source_resource_type = "Patient" GROUP BY source_id //SELECT * FROM resource_fhirs WHERE user_id = "" and source_resource_type = "Patient" GROUP BY source_id
@ -358,7 +390,7 @@ func (sr *SqliteRepository) GetPatientForSources(ctx context.Context) ([]models.
Model(models.ResourceFhir{}). Model(models.ResourceFhir{}).
//Group("source_id"). //broken in Postgres. //Group("source_id"). //broken in Postgres.
Where(models.OriginBase{ Where(models.OriginBase{
UserID: sr.GetCurrentUser(ctx).ID, UserID: currentUser.ID,
SourceResourceType: "Patient", SourceResourceType: "Patient",
}). }).
Find(&wrappedResourceModels) Find(&wrappedResourceModels)
@ -370,6 +402,11 @@ func (sr *SqliteRepository) GetPatientForSources(ctx context.Context) ([]models.
// Generate a graph // Generate a graph
// return list of root nodes, and their flattened related resources. // return list of root nodes, and their flattened related resources.
func (sr *SqliteRepository) GetFlattenedResourceGraph(ctx context.Context) ([]*models.ResourceFhir, []*models.ResourceFhir, error) { func (sr *SqliteRepository) GetFlattenedResourceGraph(ctx context.Context) ([]*models.ResourceFhir, []*models.ResourceFhir, error) {
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
if currentUserErr != nil {
return nil, nil, currentUserErr
}
// Get list of all resources // Get list of all resources
wrappedResourceModels, err := sr.ListResources(ctx, models.ListResourceQueryOptions{}) wrappedResourceModels, err := sr.ListResources(ctx, models.ListResourceQueryOptions{})
if err != nil { if err != nil {
@ -383,7 +420,7 @@ func (sr *SqliteRepository) GetFlattenedResourceGraph(ctx context.Context) ([]*m
result := sr.GormClient.WithContext(ctx). result := sr.GormClient.WithContext(ctx).
Table("related_resources"). Table("related_resources").
Where(models.RelatedResource{ Where(models.RelatedResource{
ResourceFhirUserID: sr.GetCurrentUser(ctx).ID, ResourceFhirUserID: currentUser.ID,
}). }).
Scan(&relatedResourceRelationships) Scan(&relatedResourceRelationships)
if result.Error != nil { if result.Error != nil {
@ -528,10 +565,13 @@ func (sr *SqliteRepository) GetFlattenedResourceGraph(ctx context.Context) ([]*m
//related resources that are "Condition" or "Encounter" //related resources that are "Condition" or "Encounter"
func (sr *SqliteRepository) AddReciprocalResourceAssociations(ctx context.Context, source *models.SourceCredential, resource *models.ResourceFhir, relatedSource *models.SourceCredential, relatedResource *models.ResourceFhir) error { func (sr *SqliteRepository) AddReciprocalResourceAssociations(ctx context.Context, source *models.SourceCredential, resource *models.ResourceFhir, relatedSource *models.SourceCredential, relatedResource *models.ResourceFhir) error {
//ensure that the sources are "owned" by the same user //ensure that the sources are "owned" by the same user
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
if currentUserErr != nil {
return currentUserErr
}
if source.UserID != relatedSource.UserID { if source.UserID != relatedSource.UserID {
return fmt.Errorf("user id's must match when adding associations") return fmt.Errorf("user id's must match when adding associations")
} else if source.UserID != sr.GetCurrentUser(ctx).ID { } else if source.UserID != currentUser.ID {
return fmt.Errorf("user id's must match current user") return fmt.Errorf("user id's must match current user")
} }
@ -647,9 +687,14 @@ func (sr *SqliteRepository) AddResourceAssociation(ctx context.Context, source *
} }
func (sr *SqliteRepository) RemoveResourceAssociation(ctx context.Context, source *models.SourceCredential, resourceType string, resourceId string, relatedSource *models.SourceCredential, relatedResourceType string, relatedResourceId string) error { func (sr *SqliteRepository) RemoveResourceAssociation(ctx context.Context, source *models.SourceCredential, resourceType string, resourceId string, relatedSource *models.SourceCredential, relatedResourceType string, relatedResourceId string) error {
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
if currentUserErr != nil {
return currentUserErr
}
if source.UserID != relatedSource.UserID { if source.UserID != relatedSource.UserID {
return fmt.Errorf("user id's must match when adding associations") return fmt.Errorf("user id's must match when adding associations")
} else if source.UserID != sr.GetCurrentUser(ctx).ID { } else if source.UserID != currentUser.ID {
return fmt.Errorf("user id's must match current user") return fmt.Errorf("user id's must match current user")
} }
@ -679,7 +724,10 @@ func (sr *SqliteRepository) RemoveResourceAssociation(ctx context.Context, sourc
// - add AddResourceAssociation for all resources linked to the Composition resource // - add AddResourceAssociation for all resources linked to the Composition resource
// - store the Composition resource // - store the Composition resource
func (sr *SqliteRepository) AddResourceComposition(ctx context.Context, compositionTitle string, resources []*models.ResourceFhir) error { func (sr *SqliteRepository) AddResourceComposition(ctx context.Context, compositionTitle string, resources []*models.ResourceFhir) error {
currentUser := sr.GetCurrentUser(ctx) currentUser, currentUserErr := sr.GetCurrentUser(ctx)
if currentUserErr != nil {
return currentUserErr
}
//generate placeholder source //generate placeholder source
placeholderSource := models.SourceCredential{UserID: currentUser.ID, SourceType: "manual", ModelBase: models.ModelBase{ID: uuid.MustParse("00000000-0000-0000-0000-000000000000")}} placeholderSource := models.SourceCredential{UserID: currentUser.ID, SourceType: "manual", ModelBase: models.ModelBase{ID: uuid.MustParse("00000000-0000-0000-0000-000000000000")}}
@ -821,7 +869,11 @@ func (sr *SqliteRepository) AddResourceComposition(ctx context.Context, composit
//////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
func (sr *SqliteRepository) CreateSource(ctx context.Context, sourceCreds *models.SourceCredential) error { func (sr *SqliteRepository) CreateSource(ctx context.Context, sourceCreds *models.SourceCredential) error {
sourceCreds.UserID = sr.GetCurrentUser(ctx).ID currentUser, currentUserErr := sr.GetCurrentUser(ctx)
if currentUserErr != nil {
return currentUserErr
}
sourceCreds.UserID = currentUser.ID
//Assign will **always** update the source credential in the DB with data passed into this function. //Assign will **always** update the source credential in the DB with data passed into this function.
return sr.GormClient.WithContext(ctx). return sr.GormClient.WithContext(ctx).
@ -833,6 +885,11 @@ func (sr *SqliteRepository) CreateSource(ctx context.Context, sourceCreds *model
} }
func (sr *SqliteRepository) GetSource(ctx context.Context, sourceId string) (*models.SourceCredential, error) { func (sr *SqliteRepository) GetSource(ctx context.Context, sourceId string) (*models.SourceCredential, error) {
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
if currentUserErr != nil {
return nil, currentUserErr
}
sourceUUID, err := uuid.Parse(sourceId) sourceUUID, err := uuid.Parse(sourceId)
if err != nil { if err != nil {
return nil, err return nil, err
@ -840,13 +897,18 @@ func (sr *SqliteRepository) GetSource(ctx context.Context, sourceId string) (*mo
var sourceCred models.SourceCredential var sourceCred models.SourceCredential
results := sr.GormClient.WithContext(ctx). results := sr.GormClient.WithContext(ctx).
Where(models.SourceCredential{UserID: sr.GetCurrentUser(ctx).ID, ModelBase: models.ModelBase{ID: sourceUUID}}). Where(models.SourceCredential{UserID: currentUser.ID, ModelBase: models.ModelBase{ID: sourceUUID}}).
First(&sourceCred) First(&sourceCred)
return &sourceCred, results.Error return &sourceCred, results.Error
} }
func (sr *SqliteRepository) GetSourceSummary(ctx context.Context, sourceId string) (*models.SourceSummary, error) { func (sr *SqliteRepository) GetSourceSummary(ctx context.Context, sourceId string) (*models.SourceSummary, error) {
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
if currentUserErr != nil {
return nil, currentUserErr
}
sourceUUID, err := uuid.Parse(sourceId) sourceUUID, err := uuid.Parse(sourceId)
if err != nil { if err != nil {
return nil, err return nil, err
@ -870,7 +932,7 @@ func (sr *SqliteRepository) GetSourceSummary(ctx context.Context, sourceId strin
Select("source_id, source_resource_type as resource_type, count(*) as count"). Select("source_id, source_resource_type as resource_type, count(*) as count").
Group("source_resource_type"). Group("source_resource_type").
Where(models.OriginBase{ Where(models.OriginBase{
UserID: sr.GetCurrentUser(ctx).ID, UserID: currentUser.ID,
SourceID: sourceUUID, SourceID: sourceUUID,
}). }).
Scan(&resourceTypeCounts) Scan(&resourceTypeCounts)
@ -885,7 +947,7 @@ func (sr *SqliteRepository) GetSourceSummary(ctx context.Context, sourceId strin
var wrappedPatientResourceModel models.ResourceFhir var wrappedPatientResourceModel models.ResourceFhir
results := sr.GormClient.WithContext(ctx). results := sr.GormClient.WithContext(ctx).
Where(models.OriginBase{ Where(models.OriginBase{
UserID: sr.GetCurrentUser(ctx).ID, UserID: currentUser.ID,
SourceResourceType: "Patient", SourceResourceType: "Patient",
SourceID: sourceUUID, SourceID: sourceUUID,
}). }).
@ -900,10 +962,14 @@ func (sr *SqliteRepository) GetSourceSummary(ctx context.Context, sourceId strin
} }
func (sr *SqliteRepository) GetSources(ctx context.Context) ([]models.SourceCredential, error) { func (sr *SqliteRepository) GetSources(ctx context.Context) ([]models.SourceCredential, error) {
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
if currentUserErr != nil {
return nil, currentUserErr
}
var sourceCreds []models.SourceCredential var sourceCreds []models.SourceCredential
results := sr.GormClient.WithContext(ctx). results := sr.GormClient.WithContext(ctx).
Where(models.SourceCredential{UserID: sr.GetCurrentUser(ctx).ID}). Where(models.SourceCredential{UserID: currentUser.ID}).
Find(&sourceCreds) Find(&sourceCreds)
return sourceCreds, results.Error return sourceCreds, results.Error

View File

@ -210,9 +210,10 @@ func (suite *RepositoryTestSuite) TestGetCurrentUser_WithContextBackgroundAuthUs
require.NoError(suite.T(), err) require.NoError(suite.T(), err)
//test //test
userModelResult := dbRepo.GetCurrentUser(context.WithValue(context.Background(), pkg.ContextKeyTypeAuthUsername, "test_username")) userModelResult, err := dbRepo.GetCurrentUser(context.WithValue(context.Background(), pkg.ContextKeyTypeAuthUsername, "test_username"))
//assert //assert
require.NoError(suite.T(), err)
require.NotNil(suite.T(), userModelResult) require.NotNil(suite.T(), userModelResult)
require.Equal(suite.T(), userModelResult.Username, "test_username") require.Equal(suite.T(), userModelResult.Username, "test_username")
} }
@ -235,9 +236,10 @@ func (suite *RepositoryTestSuite) TestGetCurrentUser_WithGinContextBackgroundAut
//test //test
ginContext := gin.Context{} ginContext := gin.Context{}
ginContext.Set(pkg.ContextKeyTypeAuthUsername, "test_username") ginContext.Set(pkg.ContextKeyTypeAuthUsername, "test_username")
userModelResult := dbRepo.GetCurrentUser(&ginContext) userModelResult, err := dbRepo.GetCurrentUser(&ginContext)
//assert //assert
require.NoError(suite.T(), err)
require.NotNil(suite.T(), userModelResult) require.NotNil(suite.T(), userModelResult)
require.Equal(suite.T(), userModelResult.Username, "test_username") require.Equal(suite.T(), userModelResult.Username, "test_username")
} }
@ -251,9 +253,10 @@ func (suite *RepositoryTestSuite) TestGetCurrentUser_WithContextBackgroundAuthUs
require.NoError(suite.T(), err) require.NoError(suite.T(), err)
//test //test
userModelResult := dbRepo.GetCurrentUser(context.WithValue(context.Background(), pkg.ContextKeyTypeAuthUsername, "test_username")) userModelResult, err := dbRepo.GetCurrentUser(context.WithValue(context.Background(), pkg.ContextKeyTypeAuthUsername, "test_username"))
//assert //assert
require.Error(suite.T(), err)
require.Nil(suite.T(), userModelResult) require.Nil(suite.T(), userModelResult)
} }