make sure we handle error if the current user is invalid.
This commit is contained in:
parent
cdf7f83777
commit
e68900b1bc
|
@ -14,7 +14,7 @@ type DatabaseRepository interface {
|
|||
CreateUser(context.Context, *models.User) error
|
||||
|
||||
GetUserByUsername(context.Context, string) (*models.User, error)
|
||||
GetCurrentUser(context.Context) *models.User
|
||||
GetCurrentUser(ctx context.Context) (*models.User, error)
|
||||
|
||||
GetSummary(ctx context.Context) (*models.Summary, error)
|
||||
|
||||
|
|
|
@ -122,7 +122,8 @@ func (sr *SqliteRepository) GetUserByUsername(ctx context.Context, username stri
|
|||
return &foundUser, result.Error
|
||||
}
|
||||
|
||||
func (sr *SqliteRepository) GetCurrentUser(ctx context.Context) *models.User {
|
||||
//TODO: check for error, right now we return a nil which may cause a panic.
|
||||
func (sr *SqliteRepository) GetCurrentUser(ctx context.Context) (*models.User, error) {
|
||||
username := ctx.Value(pkg.ContextKeyTypeAuthUsername)
|
||||
if username == nil {
|
||||
ginCtx := ctx.(*gin.Context)
|
||||
|
@ -130,9 +131,13 @@ func (sr *SqliteRepository) GetCurrentUser(ctx context.Context) *models.User {
|
|||
}
|
||||
|
||||
var currentUser models.User
|
||||
sr.GormClient.First(¤tUser, models.User{Username: username.(string)})
|
||||
result := sr.GormClient.First(¤tUser, models.User{Username: username.(string)})
|
||||
|
||||
return ¤tUser
|
||||
if result.Error != nil {
|
||||
return nil, fmt.Errorf("could not retrieve current user: %v", result.Error)
|
||||
}
|
||||
|
||||
return ¤tUser, nil
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -140,6 +145,10 @@ func (sr *SqliteRepository) GetCurrentUser(ctx context.Context) *models.User {
|
|||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
func (sr *SqliteRepository) GetSummary(ctx context.Context) (*models.Summary, error) {
|
||||
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
|
||||
if currentUserErr != nil {
|
||||
return nil, currentUserErr
|
||||
}
|
||||
|
||||
// we want a count of all resources for this user by type
|
||||
var resourceCountResults []map[string]interface{}
|
||||
|
@ -151,7 +160,7 @@ func (sr *SqliteRepository) GetSummary(ctx context.Context) (*models.Summary, er
|
|||
Select("source_resource_type as resource_type, count(*) as count").
|
||||
Group("source_resource_type").
|
||||
Where(models.OriginBase{
|
||||
UserID: sr.GetCurrentUser(ctx).ID,
|
||||
UserID: currentUser.ID,
|
||||
}).
|
||||
Scan(&resourceCountResults)
|
||||
if result.Error != nil {
|
||||
|
@ -235,7 +244,12 @@ func (sr *SqliteRepository) UpsertRawResource(ctx context.Context, sourceCredent
|
|||
|
||||
//this method will upsert a resource, however it will not create associations.
|
||||
func (sr *SqliteRepository) UpsertResource(ctx context.Context, wrappedResourceModel *models.ResourceFhir) (bool, error) {
|
||||
wrappedResourceModel.UserID = sr.GetCurrentUser(ctx).ID
|
||||
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
|
||||
if currentUserErr != nil {
|
||||
return false, currentUserErr
|
||||
}
|
||||
|
||||
wrappedResourceModel.UserID = currentUser.ID
|
||||
wrappedResourceModel.RelatedResourceFhir = nil
|
||||
cachedResourceRaw := wrappedResourceModel.ResourceRaw
|
||||
|
||||
|
@ -266,10 +280,14 @@ func (sr *SqliteRepository) UpsertResource(ctx context.Context, wrappedResourceM
|
|||
}
|
||||
|
||||
func (sr *SqliteRepository) ListResources(ctx context.Context, queryOptions models.ListResourceQueryOptions) ([]models.ResourceFhir, error) {
|
||||
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
|
||||
if currentUserErr != nil {
|
||||
return nil, currentUserErr
|
||||
}
|
||||
|
||||
queryParam := models.ResourceFhir{
|
||||
OriginBase: models.OriginBase{
|
||||
UserID: sr.GetCurrentUser(ctx).ID,
|
||||
UserID: currentUser.ID,
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -305,9 +323,14 @@ func (sr *SqliteRepository) ListResources(ctx context.Context, queryOptions mode
|
|||
}
|
||||
|
||||
func (sr *SqliteRepository) GetResourceBySourceType(ctx context.Context, sourceResourceType string, sourceResourceId string) (*models.ResourceFhir, error) {
|
||||
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
|
||||
if currentUserErr != nil {
|
||||
return nil, currentUserErr
|
||||
}
|
||||
|
||||
queryParam := models.ResourceFhir{
|
||||
OriginBase: models.OriginBase{
|
||||
UserID: sr.GetCurrentUser(ctx).ID,
|
||||
UserID: currentUser.ID,
|
||||
SourceResourceType: sourceResourceType,
|
||||
SourceResourceID: sourceResourceId,
|
||||
},
|
||||
|
@ -322,6 +345,11 @@ func (sr *SqliteRepository) GetResourceBySourceType(ctx context.Context, sourceR
|
|||
}
|
||||
|
||||
func (sr *SqliteRepository) GetResourceBySourceId(ctx context.Context, sourceId string, sourceResourceId string) (*models.ResourceFhir, error) {
|
||||
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
|
||||
if currentUserErr != nil {
|
||||
return nil, currentUserErr
|
||||
}
|
||||
|
||||
sourceIdUUID, err := uuid.Parse(sourceId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -329,7 +357,7 @@ func (sr *SqliteRepository) GetResourceBySourceId(ctx context.Context, sourceId
|
|||
|
||||
queryParam := models.ResourceFhir{
|
||||
OriginBase: models.OriginBase{
|
||||
UserID: sr.GetCurrentUser(ctx).ID,
|
||||
UserID: currentUser.ID,
|
||||
SourceID: sourceIdUUID,
|
||||
SourceResourceID: sourceResourceId,
|
||||
},
|
||||
|
@ -345,6 +373,10 @@ func (sr *SqliteRepository) GetResourceBySourceId(ctx context.Context, sourceId
|
|||
|
||||
// Get the patient for each source (for the current user)
|
||||
func (sr *SqliteRepository) GetPatientForSources(ctx context.Context) ([]models.ResourceFhir, error) {
|
||||
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
|
||||
if currentUserErr != nil {
|
||||
return nil, currentUserErr
|
||||
}
|
||||
|
||||
//SELECT * FROM resource_fhirs WHERE user_id = "" and source_resource_type = "Patient" GROUP BY source_id
|
||||
|
||||
|
@ -358,7 +390,7 @@ func (sr *SqliteRepository) GetPatientForSources(ctx context.Context) ([]models.
|
|||
Model(models.ResourceFhir{}).
|
||||
//Group("source_id"). //broken in Postgres.
|
||||
Where(models.OriginBase{
|
||||
UserID: sr.GetCurrentUser(ctx).ID,
|
||||
UserID: currentUser.ID,
|
||||
SourceResourceType: "Patient",
|
||||
}).
|
||||
Find(&wrappedResourceModels)
|
||||
|
@ -370,6 +402,11 @@ func (sr *SqliteRepository) GetPatientForSources(ctx context.Context) ([]models.
|
|||
// Generate a graph
|
||||
// return list of root nodes, and their flattened related resources.
|
||||
func (sr *SqliteRepository) GetFlattenedResourceGraph(ctx context.Context) ([]*models.ResourceFhir, []*models.ResourceFhir, error) {
|
||||
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
|
||||
if currentUserErr != nil {
|
||||
return nil, nil, currentUserErr
|
||||
}
|
||||
|
||||
// Get list of all resources
|
||||
wrappedResourceModels, err := sr.ListResources(ctx, models.ListResourceQueryOptions{})
|
||||
if err != nil {
|
||||
|
@ -383,7 +420,7 @@ func (sr *SqliteRepository) GetFlattenedResourceGraph(ctx context.Context) ([]*m
|
|||
result := sr.GormClient.WithContext(ctx).
|
||||
Table("related_resources").
|
||||
Where(models.RelatedResource{
|
||||
ResourceFhirUserID: sr.GetCurrentUser(ctx).ID,
|
||||
ResourceFhirUserID: currentUser.ID,
|
||||
}).
|
||||
Scan(&relatedResourceRelationships)
|
||||
if result.Error != nil {
|
||||
|
@ -528,10 +565,13 @@ func (sr *SqliteRepository) GetFlattenedResourceGraph(ctx context.Context) ([]*m
|
|||
//related resources that are "Condition" or "Encounter"
|
||||
func (sr *SqliteRepository) AddReciprocalResourceAssociations(ctx context.Context, source *models.SourceCredential, resource *models.ResourceFhir, relatedSource *models.SourceCredential, relatedResource *models.ResourceFhir) error {
|
||||
//ensure that the sources are "owned" by the same user
|
||||
|
||||
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
|
||||
if currentUserErr != nil {
|
||||
return currentUserErr
|
||||
}
|
||||
if source.UserID != relatedSource.UserID {
|
||||
return fmt.Errorf("user id's must match when adding associations")
|
||||
} else if source.UserID != sr.GetCurrentUser(ctx).ID {
|
||||
} else if source.UserID != currentUser.ID {
|
||||
return fmt.Errorf("user id's must match current user")
|
||||
}
|
||||
|
||||
|
@ -647,9 +687,14 @@ func (sr *SqliteRepository) AddResourceAssociation(ctx context.Context, source *
|
|||
}
|
||||
|
||||
func (sr *SqliteRepository) RemoveResourceAssociation(ctx context.Context, source *models.SourceCredential, resourceType string, resourceId string, relatedSource *models.SourceCredential, relatedResourceType string, relatedResourceId string) error {
|
||||
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
|
||||
if currentUserErr != nil {
|
||||
return currentUserErr
|
||||
}
|
||||
|
||||
if source.UserID != relatedSource.UserID {
|
||||
return fmt.Errorf("user id's must match when adding associations")
|
||||
} else if source.UserID != sr.GetCurrentUser(ctx).ID {
|
||||
} else if source.UserID != currentUser.ID {
|
||||
return fmt.Errorf("user id's must match current user")
|
||||
}
|
||||
|
||||
|
@ -679,7 +724,10 @@ func (sr *SqliteRepository) RemoveResourceAssociation(ctx context.Context, sourc
|
|||
// - add AddResourceAssociation for all resources linked to the Composition resource
|
||||
// - store the Composition resource
|
||||
func (sr *SqliteRepository) AddResourceComposition(ctx context.Context, compositionTitle string, resources []*models.ResourceFhir) error {
|
||||
currentUser := sr.GetCurrentUser(ctx)
|
||||
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
|
||||
if currentUserErr != nil {
|
||||
return currentUserErr
|
||||
}
|
||||
|
||||
//generate placeholder source
|
||||
placeholderSource := models.SourceCredential{UserID: currentUser.ID, SourceType: "manual", ModelBase: models.ModelBase{ID: uuid.MustParse("00000000-0000-0000-0000-000000000000")}}
|
||||
|
@ -821,7 +869,11 @@ func (sr *SqliteRepository) AddResourceComposition(ctx context.Context, composit
|
|||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
func (sr *SqliteRepository) CreateSource(ctx context.Context, sourceCreds *models.SourceCredential) error {
|
||||
sourceCreds.UserID = sr.GetCurrentUser(ctx).ID
|
||||
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
|
||||
if currentUserErr != nil {
|
||||
return currentUserErr
|
||||
}
|
||||
sourceCreds.UserID = currentUser.ID
|
||||
|
||||
//Assign will **always** update the source credential in the DB with data passed into this function.
|
||||
return sr.GormClient.WithContext(ctx).
|
||||
|
@ -833,6 +885,11 @@ func (sr *SqliteRepository) CreateSource(ctx context.Context, sourceCreds *model
|
|||
}
|
||||
|
||||
func (sr *SqliteRepository) GetSource(ctx context.Context, sourceId string) (*models.SourceCredential, error) {
|
||||
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
|
||||
if currentUserErr != nil {
|
||||
return nil, currentUserErr
|
||||
}
|
||||
|
||||
sourceUUID, err := uuid.Parse(sourceId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -840,13 +897,18 @@ func (sr *SqliteRepository) GetSource(ctx context.Context, sourceId string) (*mo
|
|||
|
||||
var sourceCred models.SourceCredential
|
||||
results := sr.GormClient.WithContext(ctx).
|
||||
Where(models.SourceCredential{UserID: sr.GetCurrentUser(ctx).ID, ModelBase: models.ModelBase{ID: sourceUUID}}).
|
||||
Where(models.SourceCredential{UserID: currentUser.ID, ModelBase: models.ModelBase{ID: sourceUUID}}).
|
||||
First(&sourceCred)
|
||||
|
||||
return &sourceCred, results.Error
|
||||
}
|
||||
|
||||
func (sr *SqliteRepository) GetSourceSummary(ctx context.Context, sourceId string) (*models.SourceSummary, error) {
|
||||
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
|
||||
if currentUserErr != nil {
|
||||
return nil, currentUserErr
|
||||
}
|
||||
|
||||
sourceUUID, err := uuid.Parse(sourceId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -870,7 +932,7 @@ func (sr *SqliteRepository) GetSourceSummary(ctx context.Context, sourceId strin
|
|||
Select("source_id, source_resource_type as resource_type, count(*) as count").
|
||||
Group("source_resource_type").
|
||||
Where(models.OriginBase{
|
||||
UserID: sr.GetCurrentUser(ctx).ID,
|
||||
UserID: currentUser.ID,
|
||||
SourceID: sourceUUID,
|
||||
}).
|
||||
Scan(&resourceTypeCounts)
|
||||
|
@ -885,7 +947,7 @@ func (sr *SqliteRepository) GetSourceSummary(ctx context.Context, sourceId strin
|
|||
var wrappedPatientResourceModel models.ResourceFhir
|
||||
results := sr.GormClient.WithContext(ctx).
|
||||
Where(models.OriginBase{
|
||||
UserID: sr.GetCurrentUser(ctx).ID,
|
||||
UserID: currentUser.ID,
|
||||
SourceResourceType: "Patient",
|
||||
SourceID: sourceUUID,
|
||||
}).
|
||||
|
@ -900,10 +962,14 @@ func (sr *SqliteRepository) GetSourceSummary(ctx context.Context, sourceId strin
|
|||
}
|
||||
|
||||
func (sr *SqliteRepository) GetSources(ctx context.Context) ([]models.SourceCredential, error) {
|
||||
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
|
||||
if currentUserErr != nil {
|
||||
return nil, currentUserErr
|
||||
}
|
||||
|
||||
var sourceCreds []models.SourceCredential
|
||||
results := sr.GormClient.WithContext(ctx).
|
||||
Where(models.SourceCredential{UserID: sr.GetCurrentUser(ctx).ID}).
|
||||
Where(models.SourceCredential{UserID: currentUser.ID}).
|
||||
Find(&sourceCreds)
|
||||
|
||||
return sourceCreds, results.Error
|
||||
|
|
|
@ -210,9 +210,10 @@ func (suite *RepositoryTestSuite) TestGetCurrentUser_WithContextBackgroundAuthUs
|
|||
require.NoError(suite.T(), err)
|
||||
|
||||
//test
|
||||
userModelResult := dbRepo.GetCurrentUser(context.WithValue(context.Background(), pkg.ContextKeyTypeAuthUsername, "test_username"))
|
||||
userModelResult, err := dbRepo.GetCurrentUser(context.WithValue(context.Background(), pkg.ContextKeyTypeAuthUsername, "test_username"))
|
||||
|
||||
//assert
|
||||
require.NoError(suite.T(), err)
|
||||
require.NotNil(suite.T(), userModelResult)
|
||||
require.Equal(suite.T(), userModelResult.Username, "test_username")
|
||||
}
|
||||
|
@ -235,9 +236,10 @@ func (suite *RepositoryTestSuite) TestGetCurrentUser_WithGinContextBackgroundAut
|
|||
//test
|
||||
ginContext := gin.Context{}
|
||||
ginContext.Set(pkg.ContextKeyTypeAuthUsername, "test_username")
|
||||
userModelResult := dbRepo.GetCurrentUser(&ginContext)
|
||||
userModelResult, err := dbRepo.GetCurrentUser(&ginContext)
|
||||
|
||||
//assert
|
||||
require.NoError(suite.T(), err)
|
||||
require.NotNil(suite.T(), userModelResult)
|
||||
require.Equal(suite.T(), userModelResult.Username, "test_username")
|
||||
}
|
||||
|
@ -251,9 +253,10 @@ func (suite *RepositoryTestSuite) TestGetCurrentUser_WithContextBackgroundAuthUs
|
|||
require.NoError(suite.T(), err)
|
||||
|
||||
//test
|
||||
userModelResult := dbRepo.GetCurrentUser(context.WithValue(context.Background(), pkg.ContextKeyTypeAuthUsername, "test_username"))
|
||||
userModelResult, err := dbRepo.GetCurrentUser(context.WithValue(context.Background(), pkg.ContextKeyTypeAuthUsername, "test_username"))
|
||||
|
||||
//assert
|
||||
require.Error(suite.T(), err)
|
||||
require.Nil(suite.T(), userModelResult)
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue