make sure we handle error if the current user is invalid.
This commit is contained in:
parent
cdf7f83777
commit
e68900b1bc
|
@ -14,7 +14,7 @@ type DatabaseRepository interface {
|
||||||
CreateUser(context.Context, *models.User) error
|
CreateUser(context.Context, *models.User) error
|
||||||
|
|
||||||
GetUserByUsername(context.Context, string) (*models.User, error)
|
GetUserByUsername(context.Context, string) (*models.User, error)
|
||||||
GetCurrentUser(context.Context) *models.User
|
GetCurrentUser(ctx context.Context) (*models.User, error)
|
||||||
|
|
||||||
GetSummary(ctx context.Context) (*models.Summary, error)
|
GetSummary(ctx context.Context) (*models.Summary, error)
|
||||||
|
|
||||||
|
|
|
@ -122,7 +122,8 @@ func (sr *SqliteRepository) GetUserByUsername(ctx context.Context, username stri
|
||||||
return &foundUser, result.Error
|
return &foundUser, result.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sr *SqliteRepository) GetCurrentUser(ctx context.Context) *models.User {
|
//TODO: check for error, right now we return a nil which may cause a panic.
|
||||||
|
func (sr *SqliteRepository) GetCurrentUser(ctx context.Context) (*models.User, error) {
|
||||||
username := ctx.Value(pkg.ContextKeyTypeAuthUsername)
|
username := ctx.Value(pkg.ContextKeyTypeAuthUsername)
|
||||||
if username == nil {
|
if username == nil {
|
||||||
ginCtx := ctx.(*gin.Context)
|
ginCtx := ctx.(*gin.Context)
|
||||||
|
@ -130,9 +131,13 @@ func (sr *SqliteRepository) GetCurrentUser(ctx context.Context) *models.User {
|
||||||
}
|
}
|
||||||
|
|
||||||
var currentUser models.User
|
var currentUser models.User
|
||||||
sr.GormClient.First(¤tUser, models.User{Username: username.(string)})
|
result := sr.GormClient.First(¤tUser, models.User{Username: username.(string)})
|
||||||
|
|
||||||
return ¤tUser
|
if result.Error != nil {
|
||||||
|
return nil, fmt.Errorf("could not retrieve current user: %v", result.Error)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ¤tUser, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -140,6 +145,10 @@ func (sr *SqliteRepository) GetCurrentUser(ctx context.Context) *models.User {
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
func (sr *SqliteRepository) GetSummary(ctx context.Context) (*models.Summary, error) {
|
func (sr *SqliteRepository) GetSummary(ctx context.Context) (*models.Summary, error) {
|
||||||
|
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
|
||||||
|
if currentUserErr != nil {
|
||||||
|
return nil, currentUserErr
|
||||||
|
}
|
||||||
|
|
||||||
// we want a count of all resources for this user by type
|
// we want a count of all resources for this user by type
|
||||||
var resourceCountResults []map[string]interface{}
|
var resourceCountResults []map[string]interface{}
|
||||||
|
@ -151,7 +160,7 @@ func (sr *SqliteRepository) GetSummary(ctx context.Context) (*models.Summary, er
|
||||||
Select("source_resource_type as resource_type, count(*) as count").
|
Select("source_resource_type as resource_type, count(*) as count").
|
||||||
Group("source_resource_type").
|
Group("source_resource_type").
|
||||||
Where(models.OriginBase{
|
Where(models.OriginBase{
|
||||||
UserID: sr.GetCurrentUser(ctx).ID,
|
UserID: currentUser.ID,
|
||||||
}).
|
}).
|
||||||
Scan(&resourceCountResults)
|
Scan(&resourceCountResults)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
|
@ -235,7 +244,12 @@ func (sr *SqliteRepository) UpsertRawResource(ctx context.Context, sourceCredent
|
||||||
|
|
||||||
//this method will upsert a resource, however it will not create associations.
|
//this method will upsert a resource, however it will not create associations.
|
||||||
func (sr *SqliteRepository) UpsertResource(ctx context.Context, wrappedResourceModel *models.ResourceFhir) (bool, error) {
|
func (sr *SqliteRepository) UpsertResource(ctx context.Context, wrappedResourceModel *models.ResourceFhir) (bool, error) {
|
||||||
wrappedResourceModel.UserID = sr.GetCurrentUser(ctx).ID
|
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
|
||||||
|
if currentUserErr != nil {
|
||||||
|
return false, currentUserErr
|
||||||
|
}
|
||||||
|
|
||||||
|
wrappedResourceModel.UserID = currentUser.ID
|
||||||
wrappedResourceModel.RelatedResourceFhir = nil
|
wrappedResourceModel.RelatedResourceFhir = nil
|
||||||
cachedResourceRaw := wrappedResourceModel.ResourceRaw
|
cachedResourceRaw := wrappedResourceModel.ResourceRaw
|
||||||
|
|
||||||
|
@ -266,10 +280,14 @@ func (sr *SqliteRepository) UpsertResource(ctx context.Context, wrappedResourceM
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sr *SqliteRepository) ListResources(ctx context.Context, queryOptions models.ListResourceQueryOptions) ([]models.ResourceFhir, error) {
|
func (sr *SqliteRepository) ListResources(ctx context.Context, queryOptions models.ListResourceQueryOptions) ([]models.ResourceFhir, error) {
|
||||||
|
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
|
||||||
|
if currentUserErr != nil {
|
||||||
|
return nil, currentUserErr
|
||||||
|
}
|
||||||
|
|
||||||
queryParam := models.ResourceFhir{
|
queryParam := models.ResourceFhir{
|
||||||
OriginBase: models.OriginBase{
|
OriginBase: models.OriginBase{
|
||||||
UserID: sr.GetCurrentUser(ctx).ID,
|
UserID: currentUser.ID,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -305,9 +323,14 @@ func (sr *SqliteRepository) ListResources(ctx context.Context, queryOptions mode
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sr *SqliteRepository) GetResourceBySourceType(ctx context.Context, sourceResourceType string, sourceResourceId string) (*models.ResourceFhir, error) {
|
func (sr *SqliteRepository) GetResourceBySourceType(ctx context.Context, sourceResourceType string, sourceResourceId string) (*models.ResourceFhir, error) {
|
||||||
|
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
|
||||||
|
if currentUserErr != nil {
|
||||||
|
return nil, currentUserErr
|
||||||
|
}
|
||||||
|
|
||||||
queryParam := models.ResourceFhir{
|
queryParam := models.ResourceFhir{
|
||||||
OriginBase: models.OriginBase{
|
OriginBase: models.OriginBase{
|
||||||
UserID: sr.GetCurrentUser(ctx).ID,
|
UserID: currentUser.ID,
|
||||||
SourceResourceType: sourceResourceType,
|
SourceResourceType: sourceResourceType,
|
||||||
SourceResourceID: sourceResourceId,
|
SourceResourceID: sourceResourceId,
|
||||||
},
|
},
|
||||||
|
@ -322,6 +345,11 @@ func (sr *SqliteRepository) GetResourceBySourceType(ctx context.Context, sourceR
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sr *SqliteRepository) GetResourceBySourceId(ctx context.Context, sourceId string, sourceResourceId string) (*models.ResourceFhir, error) {
|
func (sr *SqliteRepository) GetResourceBySourceId(ctx context.Context, sourceId string, sourceResourceId string) (*models.ResourceFhir, error) {
|
||||||
|
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
|
||||||
|
if currentUserErr != nil {
|
||||||
|
return nil, currentUserErr
|
||||||
|
}
|
||||||
|
|
||||||
sourceIdUUID, err := uuid.Parse(sourceId)
|
sourceIdUUID, err := uuid.Parse(sourceId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -329,7 +357,7 @@ func (sr *SqliteRepository) GetResourceBySourceId(ctx context.Context, sourceId
|
||||||
|
|
||||||
queryParam := models.ResourceFhir{
|
queryParam := models.ResourceFhir{
|
||||||
OriginBase: models.OriginBase{
|
OriginBase: models.OriginBase{
|
||||||
UserID: sr.GetCurrentUser(ctx).ID,
|
UserID: currentUser.ID,
|
||||||
SourceID: sourceIdUUID,
|
SourceID: sourceIdUUID,
|
||||||
SourceResourceID: sourceResourceId,
|
SourceResourceID: sourceResourceId,
|
||||||
},
|
},
|
||||||
|
@ -345,6 +373,10 @@ func (sr *SqliteRepository) GetResourceBySourceId(ctx context.Context, sourceId
|
||||||
|
|
||||||
// Get the patient for each source (for the current user)
|
// Get the patient for each source (for the current user)
|
||||||
func (sr *SqliteRepository) GetPatientForSources(ctx context.Context) ([]models.ResourceFhir, error) {
|
func (sr *SqliteRepository) GetPatientForSources(ctx context.Context) ([]models.ResourceFhir, error) {
|
||||||
|
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
|
||||||
|
if currentUserErr != nil {
|
||||||
|
return nil, currentUserErr
|
||||||
|
}
|
||||||
|
|
||||||
//SELECT * FROM resource_fhirs WHERE user_id = "" and source_resource_type = "Patient" GROUP BY source_id
|
//SELECT * FROM resource_fhirs WHERE user_id = "" and source_resource_type = "Patient" GROUP BY source_id
|
||||||
|
|
||||||
|
@ -358,7 +390,7 @@ func (sr *SqliteRepository) GetPatientForSources(ctx context.Context) ([]models.
|
||||||
Model(models.ResourceFhir{}).
|
Model(models.ResourceFhir{}).
|
||||||
//Group("source_id"). //broken in Postgres.
|
//Group("source_id"). //broken in Postgres.
|
||||||
Where(models.OriginBase{
|
Where(models.OriginBase{
|
||||||
UserID: sr.GetCurrentUser(ctx).ID,
|
UserID: currentUser.ID,
|
||||||
SourceResourceType: "Patient",
|
SourceResourceType: "Patient",
|
||||||
}).
|
}).
|
||||||
Find(&wrappedResourceModels)
|
Find(&wrappedResourceModels)
|
||||||
|
@ -370,6 +402,11 @@ func (sr *SqliteRepository) GetPatientForSources(ctx context.Context) ([]models.
|
||||||
// Generate a graph
|
// Generate a graph
|
||||||
// return list of root nodes, and their flattened related resources.
|
// return list of root nodes, and their flattened related resources.
|
||||||
func (sr *SqliteRepository) GetFlattenedResourceGraph(ctx context.Context) ([]*models.ResourceFhir, []*models.ResourceFhir, error) {
|
func (sr *SqliteRepository) GetFlattenedResourceGraph(ctx context.Context) ([]*models.ResourceFhir, []*models.ResourceFhir, error) {
|
||||||
|
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
|
||||||
|
if currentUserErr != nil {
|
||||||
|
return nil, nil, currentUserErr
|
||||||
|
}
|
||||||
|
|
||||||
// Get list of all resources
|
// Get list of all resources
|
||||||
wrappedResourceModels, err := sr.ListResources(ctx, models.ListResourceQueryOptions{})
|
wrappedResourceModels, err := sr.ListResources(ctx, models.ListResourceQueryOptions{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -383,7 +420,7 @@ func (sr *SqliteRepository) GetFlattenedResourceGraph(ctx context.Context) ([]*m
|
||||||
result := sr.GormClient.WithContext(ctx).
|
result := sr.GormClient.WithContext(ctx).
|
||||||
Table("related_resources").
|
Table("related_resources").
|
||||||
Where(models.RelatedResource{
|
Where(models.RelatedResource{
|
||||||
ResourceFhirUserID: sr.GetCurrentUser(ctx).ID,
|
ResourceFhirUserID: currentUser.ID,
|
||||||
}).
|
}).
|
||||||
Scan(&relatedResourceRelationships)
|
Scan(&relatedResourceRelationships)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
|
@ -528,10 +565,13 @@ func (sr *SqliteRepository) GetFlattenedResourceGraph(ctx context.Context) ([]*m
|
||||||
//related resources that are "Condition" or "Encounter"
|
//related resources that are "Condition" or "Encounter"
|
||||||
func (sr *SqliteRepository) AddReciprocalResourceAssociations(ctx context.Context, source *models.SourceCredential, resource *models.ResourceFhir, relatedSource *models.SourceCredential, relatedResource *models.ResourceFhir) error {
|
func (sr *SqliteRepository) AddReciprocalResourceAssociations(ctx context.Context, source *models.SourceCredential, resource *models.ResourceFhir, relatedSource *models.SourceCredential, relatedResource *models.ResourceFhir) error {
|
||||||
//ensure that the sources are "owned" by the same user
|
//ensure that the sources are "owned" by the same user
|
||||||
|
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
|
||||||
|
if currentUserErr != nil {
|
||||||
|
return currentUserErr
|
||||||
|
}
|
||||||
if source.UserID != relatedSource.UserID {
|
if source.UserID != relatedSource.UserID {
|
||||||
return fmt.Errorf("user id's must match when adding associations")
|
return fmt.Errorf("user id's must match when adding associations")
|
||||||
} else if source.UserID != sr.GetCurrentUser(ctx).ID {
|
} else if source.UserID != currentUser.ID {
|
||||||
return fmt.Errorf("user id's must match current user")
|
return fmt.Errorf("user id's must match current user")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -647,9 +687,14 @@ func (sr *SqliteRepository) AddResourceAssociation(ctx context.Context, source *
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sr *SqliteRepository) RemoveResourceAssociation(ctx context.Context, source *models.SourceCredential, resourceType string, resourceId string, relatedSource *models.SourceCredential, relatedResourceType string, relatedResourceId string) error {
|
func (sr *SqliteRepository) RemoveResourceAssociation(ctx context.Context, source *models.SourceCredential, resourceType string, resourceId string, relatedSource *models.SourceCredential, relatedResourceType string, relatedResourceId string) error {
|
||||||
|
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
|
||||||
|
if currentUserErr != nil {
|
||||||
|
return currentUserErr
|
||||||
|
}
|
||||||
|
|
||||||
if source.UserID != relatedSource.UserID {
|
if source.UserID != relatedSource.UserID {
|
||||||
return fmt.Errorf("user id's must match when adding associations")
|
return fmt.Errorf("user id's must match when adding associations")
|
||||||
} else if source.UserID != sr.GetCurrentUser(ctx).ID {
|
} else if source.UserID != currentUser.ID {
|
||||||
return fmt.Errorf("user id's must match current user")
|
return fmt.Errorf("user id's must match current user")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -679,7 +724,10 @@ func (sr *SqliteRepository) RemoveResourceAssociation(ctx context.Context, sourc
|
||||||
// - add AddResourceAssociation for all resources linked to the Composition resource
|
// - add AddResourceAssociation for all resources linked to the Composition resource
|
||||||
// - store the Composition resource
|
// - store the Composition resource
|
||||||
func (sr *SqliteRepository) AddResourceComposition(ctx context.Context, compositionTitle string, resources []*models.ResourceFhir) error {
|
func (sr *SqliteRepository) AddResourceComposition(ctx context.Context, compositionTitle string, resources []*models.ResourceFhir) error {
|
||||||
currentUser := sr.GetCurrentUser(ctx)
|
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
|
||||||
|
if currentUserErr != nil {
|
||||||
|
return currentUserErr
|
||||||
|
}
|
||||||
|
|
||||||
//generate placeholder source
|
//generate placeholder source
|
||||||
placeholderSource := models.SourceCredential{UserID: currentUser.ID, SourceType: "manual", ModelBase: models.ModelBase{ID: uuid.MustParse("00000000-0000-0000-0000-000000000000")}}
|
placeholderSource := models.SourceCredential{UserID: currentUser.ID, SourceType: "manual", ModelBase: models.ModelBase{ID: uuid.MustParse("00000000-0000-0000-0000-000000000000")}}
|
||||||
|
@ -821,7 +869,11 @@ func (sr *SqliteRepository) AddResourceComposition(ctx context.Context, composit
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
func (sr *SqliteRepository) CreateSource(ctx context.Context, sourceCreds *models.SourceCredential) error {
|
func (sr *SqliteRepository) CreateSource(ctx context.Context, sourceCreds *models.SourceCredential) error {
|
||||||
sourceCreds.UserID = sr.GetCurrentUser(ctx).ID
|
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
|
||||||
|
if currentUserErr != nil {
|
||||||
|
return currentUserErr
|
||||||
|
}
|
||||||
|
sourceCreds.UserID = currentUser.ID
|
||||||
|
|
||||||
//Assign will **always** update the source credential in the DB with data passed into this function.
|
//Assign will **always** update the source credential in the DB with data passed into this function.
|
||||||
return sr.GormClient.WithContext(ctx).
|
return sr.GormClient.WithContext(ctx).
|
||||||
|
@ -833,6 +885,11 @@ func (sr *SqliteRepository) CreateSource(ctx context.Context, sourceCreds *model
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sr *SqliteRepository) GetSource(ctx context.Context, sourceId string) (*models.SourceCredential, error) {
|
func (sr *SqliteRepository) GetSource(ctx context.Context, sourceId string) (*models.SourceCredential, error) {
|
||||||
|
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
|
||||||
|
if currentUserErr != nil {
|
||||||
|
return nil, currentUserErr
|
||||||
|
}
|
||||||
|
|
||||||
sourceUUID, err := uuid.Parse(sourceId)
|
sourceUUID, err := uuid.Parse(sourceId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -840,13 +897,18 @@ func (sr *SqliteRepository) GetSource(ctx context.Context, sourceId string) (*mo
|
||||||
|
|
||||||
var sourceCred models.SourceCredential
|
var sourceCred models.SourceCredential
|
||||||
results := sr.GormClient.WithContext(ctx).
|
results := sr.GormClient.WithContext(ctx).
|
||||||
Where(models.SourceCredential{UserID: sr.GetCurrentUser(ctx).ID, ModelBase: models.ModelBase{ID: sourceUUID}}).
|
Where(models.SourceCredential{UserID: currentUser.ID, ModelBase: models.ModelBase{ID: sourceUUID}}).
|
||||||
First(&sourceCred)
|
First(&sourceCred)
|
||||||
|
|
||||||
return &sourceCred, results.Error
|
return &sourceCred, results.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sr *SqliteRepository) GetSourceSummary(ctx context.Context, sourceId string) (*models.SourceSummary, error) {
|
func (sr *SqliteRepository) GetSourceSummary(ctx context.Context, sourceId string) (*models.SourceSummary, error) {
|
||||||
|
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
|
||||||
|
if currentUserErr != nil {
|
||||||
|
return nil, currentUserErr
|
||||||
|
}
|
||||||
|
|
||||||
sourceUUID, err := uuid.Parse(sourceId)
|
sourceUUID, err := uuid.Parse(sourceId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -870,7 +932,7 @@ func (sr *SqliteRepository) GetSourceSummary(ctx context.Context, sourceId strin
|
||||||
Select("source_id, source_resource_type as resource_type, count(*) as count").
|
Select("source_id, source_resource_type as resource_type, count(*) as count").
|
||||||
Group("source_resource_type").
|
Group("source_resource_type").
|
||||||
Where(models.OriginBase{
|
Where(models.OriginBase{
|
||||||
UserID: sr.GetCurrentUser(ctx).ID,
|
UserID: currentUser.ID,
|
||||||
SourceID: sourceUUID,
|
SourceID: sourceUUID,
|
||||||
}).
|
}).
|
||||||
Scan(&resourceTypeCounts)
|
Scan(&resourceTypeCounts)
|
||||||
|
@ -885,7 +947,7 @@ func (sr *SqliteRepository) GetSourceSummary(ctx context.Context, sourceId strin
|
||||||
var wrappedPatientResourceModel models.ResourceFhir
|
var wrappedPatientResourceModel models.ResourceFhir
|
||||||
results := sr.GormClient.WithContext(ctx).
|
results := sr.GormClient.WithContext(ctx).
|
||||||
Where(models.OriginBase{
|
Where(models.OriginBase{
|
||||||
UserID: sr.GetCurrentUser(ctx).ID,
|
UserID: currentUser.ID,
|
||||||
SourceResourceType: "Patient",
|
SourceResourceType: "Patient",
|
||||||
SourceID: sourceUUID,
|
SourceID: sourceUUID,
|
||||||
}).
|
}).
|
||||||
|
@ -900,10 +962,14 @@ func (sr *SqliteRepository) GetSourceSummary(ctx context.Context, sourceId strin
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sr *SqliteRepository) GetSources(ctx context.Context) ([]models.SourceCredential, error) {
|
func (sr *SqliteRepository) GetSources(ctx context.Context) ([]models.SourceCredential, error) {
|
||||||
|
currentUser, currentUserErr := sr.GetCurrentUser(ctx)
|
||||||
|
if currentUserErr != nil {
|
||||||
|
return nil, currentUserErr
|
||||||
|
}
|
||||||
|
|
||||||
var sourceCreds []models.SourceCredential
|
var sourceCreds []models.SourceCredential
|
||||||
results := sr.GormClient.WithContext(ctx).
|
results := sr.GormClient.WithContext(ctx).
|
||||||
Where(models.SourceCredential{UserID: sr.GetCurrentUser(ctx).ID}).
|
Where(models.SourceCredential{UserID: currentUser.ID}).
|
||||||
Find(&sourceCreds)
|
Find(&sourceCreds)
|
||||||
|
|
||||||
return sourceCreds, results.Error
|
return sourceCreds, results.Error
|
||||||
|
|
|
@ -210,9 +210,10 @@ func (suite *RepositoryTestSuite) TestGetCurrentUser_WithContextBackgroundAuthUs
|
||||||
require.NoError(suite.T(), err)
|
require.NoError(suite.T(), err)
|
||||||
|
|
||||||
//test
|
//test
|
||||||
userModelResult := dbRepo.GetCurrentUser(context.WithValue(context.Background(), pkg.ContextKeyTypeAuthUsername, "test_username"))
|
userModelResult, err := dbRepo.GetCurrentUser(context.WithValue(context.Background(), pkg.ContextKeyTypeAuthUsername, "test_username"))
|
||||||
|
|
||||||
//assert
|
//assert
|
||||||
|
require.NoError(suite.T(), err)
|
||||||
require.NotNil(suite.T(), userModelResult)
|
require.NotNil(suite.T(), userModelResult)
|
||||||
require.Equal(suite.T(), userModelResult.Username, "test_username")
|
require.Equal(suite.T(), userModelResult.Username, "test_username")
|
||||||
}
|
}
|
||||||
|
@ -235,9 +236,10 @@ func (suite *RepositoryTestSuite) TestGetCurrentUser_WithGinContextBackgroundAut
|
||||||
//test
|
//test
|
||||||
ginContext := gin.Context{}
|
ginContext := gin.Context{}
|
||||||
ginContext.Set(pkg.ContextKeyTypeAuthUsername, "test_username")
|
ginContext.Set(pkg.ContextKeyTypeAuthUsername, "test_username")
|
||||||
userModelResult := dbRepo.GetCurrentUser(&ginContext)
|
userModelResult, err := dbRepo.GetCurrentUser(&ginContext)
|
||||||
|
|
||||||
//assert
|
//assert
|
||||||
|
require.NoError(suite.T(), err)
|
||||||
require.NotNil(suite.T(), userModelResult)
|
require.NotNil(suite.T(), userModelResult)
|
||||||
require.Equal(suite.T(), userModelResult.Username, "test_username")
|
require.Equal(suite.T(), userModelResult.Username, "test_username")
|
||||||
}
|
}
|
||||||
|
@ -251,9 +253,10 @@ func (suite *RepositoryTestSuite) TestGetCurrentUser_WithContextBackgroundAuthUs
|
||||||
require.NoError(suite.T(), err)
|
require.NoError(suite.T(), err)
|
||||||
|
|
||||||
//test
|
//test
|
||||||
userModelResult := dbRepo.GetCurrentUser(context.WithValue(context.Background(), pkg.ContextKeyTypeAuthUsername, "test_username"))
|
userModelResult, err := dbRepo.GetCurrentUser(context.WithValue(context.Background(), pkg.ContextKeyTypeAuthUsername, "test_username"))
|
||||||
|
|
||||||
//assert
|
//assert
|
||||||
|
require.Error(suite.T(), err)
|
||||||
require.Nil(suite.T(), userModelResult)
|
require.Nil(suite.T(), userModelResult)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue