mirror of https://github.com/go-gitea/gitea.git
Oauth2 consumer (#679)
* initial stuff for oauth2 login, fails on: * login button on the signIn page to start the OAuth2 flow and a callback for each provider Only GitHub is implemented for now * show login button only when the OAuth2 consumer is configured (and activated) * create macaron group for oauth2 urls * prevent net/http in modules (other then oauth2) * use a new data sessions oauth2 folder for storing the oauth2 session data * add missing 2FA when this is enabled on the user * add password option for OAuth2 user , for use with git over http and login to the GUI * add tip for registering a GitHub OAuth application * at startup of Gitea register all configured providers and also on adding/deleting of new providers * custom handling of errors in oauth2 request init + show better tip * add ExternalLoginUser model and migration script to add it to database * link a external account to an existing account (still need to handle wrong login and signup) and remove if user is removed * remove the linked external account from the user his settings * if user is unknown we allow him to register a new account or link it to some existing account * sign up with button on signin page (als change OAuth2Provider structure so we can store basic stuff about providers) * from gorilla/sessions docs: "Important Note: If you aren't using gorilla/mux, you need to wrap your handlers with context.ClearHandler as or else you will leak memory!" (we're using gorilla/sessions for storing oauth2 sessions) * use updated goth lib that now supports getting the OAuth2 user if the AccessToken is still valid instead of re-authenticating (prevent flooding the OAuth2 provider)
This commit is contained in:
parent
fd941db246
commit
01d957677f
17
cmd/web.go
17
cmd/web.go
|
@ -41,6 +41,7 @@ import (
|
||||||
"github.com/go-macaron/toolbox"
|
"github.com/go-macaron/toolbox"
|
||||||
"github.com/urfave/cli"
|
"github.com/urfave/cli"
|
||||||
macaron "gopkg.in/macaron.v1"
|
macaron "gopkg.in/macaron.v1"
|
||||||
|
context2 "github.com/gorilla/context"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CmdWeb represents the available web sub-command.
|
// CmdWeb represents the available web sub-command.
|
||||||
|
@ -210,6 +211,13 @@ func runWeb(ctx *cli.Context) error {
|
||||||
m.Post("/sign_up", bindIgnErr(auth.RegisterForm{}), user.SignUpPost)
|
m.Post("/sign_up", bindIgnErr(auth.RegisterForm{}), user.SignUpPost)
|
||||||
m.Get("/reset_password", user.ResetPasswd)
|
m.Get("/reset_password", user.ResetPasswd)
|
||||||
m.Post("/reset_password", user.ResetPasswdPost)
|
m.Post("/reset_password", user.ResetPasswdPost)
|
||||||
|
m.Group("/oauth2", func() {
|
||||||
|
m.Get("/:provider", user.SignInOAuth)
|
||||||
|
m.Get("/:provider/callback", user.SignInOAuthCallback)
|
||||||
|
})
|
||||||
|
m.Get("/link_account", user.LinkAccount)
|
||||||
|
m.Post("/link_account_signin", bindIgnErr(auth.SignInForm{}), user.LinkAccountPostSignIn)
|
||||||
|
m.Post("/link_account_signup", bindIgnErr(auth.RegisterForm{}), user.LinkAccountPostRegister)
|
||||||
m.Group("/two_factor", func() {
|
m.Group("/two_factor", func() {
|
||||||
m.Get("", user.TwoFactor)
|
m.Get("", user.TwoFactor)
|
||||||
m.Post("", bindIgnErr(auth.TwoFactorAuthForm{}), user.TwoFactorPost)
|
m.Post("", bindIgnErr(auth.TwoFactorAuthForm{}), user.TwoFactorPost)
|
||||||
|
@ -236,6 +244,7 @@ func runWeb(ctx *cli.Context) error {
|
||||||
Post(bindIgnErr(auth.NewAccessTokenForm{}), user.SettingsApplicationsPost)
|
Post(bindIgnErr(auth.NewAccessTokenForm{}), user.SettingsApplicationsPost)
|
||||||
m.Post("/applications/delete", user.SettingsDeleteApplication)
|
m.Post("/applications/delete", user.SettingsDeleteApplication)
|
||||||
m.Route("/delete", "GET,POST", user.SettingsDelete)
|
m.Route("/delete", "GET,POST", user.SettingsDelete)
|
||||||
|
m.Combo("/account_link").Get(user.SettingsAccountLinks).Post(user.SettingsDeleteAccountLink)
|
||||||
m.Group("/two_factor", func() {
|
m.Group("/two_factor", func() {
|
||||||
m.Get("", user.SettingsTwoFactor)
|
m.Get("", user.SettingsTwoFactor)
|
||||||
m.Post("/regenerate_scratch", user.SettingsTwoFactorRegenerateScratch)
|
m.Post("/regenerate_scratch", user.SettingsTwoFactorRegenerateScratch)
|
||||||
|
@ -671,11 +680,11 @@ func runWeb(ctx *cli.Context) error {
|
||||||
var err error
|
var err error
|
||||||
switch setting.Protocol {
|
switch setting.Protocol {
|
||||||
case setting.HTTP:
|
case setting.HTTP:
|
||||||
err = runHTTP(listenAddr, m)
|
err = runHTTP(listenAddr, context2.ClearHandler(m))
|
||||||
case setting.HTTPS:
|
case setting.HTTPS:
|
||||||
err = runHTTPS(listenAddr, setting.CertFile, setting.KeyFile, m)
|
err = runHTTPS(listenAddr, setting.CertFile, setting.KeyFile, context2.ClearHandler(m))
|
||||||
case setting.FCGI:
|
case setting.FCGI:
|
||||||
err = fcgi.Serve(nil, m)
|
err = fcgi.Serve(nil, context2.ClearHandler(m))
|
||||||
case setting.UnixSocket:
|
case setting.UnixSocket:
|
||||||
if err := os.Remove(listenAddr); err != nil && !os.IsNotExist(err) {
|
if err := os.Remove(listenAddr); err != nil && !os.IsNotExist(err) {
|
||||||
log.Fatal(4, "Failed to remove unix socket directory %s: %v", listenAddr, err)
|
log.Fatal(4, "Failed to remove unix socket directory %s: %v", listenAddr, err)
|
||||||
|
@ -691,7 +700,7 @@ func runWeb(ctx *cli.Context) error {
|
||||||
if err = os.Chmod(listenAddr, os.FileMode(setting.UnixSocketPermission)); err != nil {
|
if err = os.Chmod(listenAddr, os.FileMode(setting.UnixSocketPermission)); err != nil {
|
||||||
log.Fatal(4, "Failed to set permission of unix socket: %v", err)
|
log.Fatal(4, "Failed to set permission of unix socket: %v", err)
|
||||||
}
|
}
|
||||||
err = http.Serve(listener, m)
|
err = http.Serve(listener, context2.ClearHandler(m))
|
||||||
default:
|
default:
|
||||||
log.Fatal(4, "Invalid protocol: %s", setting.Protocol)
|
log.Fatal(4, "Invalid protocol: %s", setting.Protocol)
|
||||||
}
|
}
|
||||||
|
|
|
@ -847,3 +847,43 @@ func IsErrUploadNotExist(err error) bool {
|
||||||
func (err ErrUploadNotExist) Error() string {
|
func (err ErrUploadNotExist) Error() string {
|
||||||
return fmt.Sprintf("attachment does not exist [id: %d, uuid: %s]", err.ID, err.UUID)
|
return fmt.Sprintf("attachment does not exist [id: %d, uuid: %s]", err.ID, err.UUID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ___________ __ .__ .____ .__ ____ ___
|
||||||
|
// \_ _____/__ ____/ |_ ___________ ____ _____ | | | | ____ ____ |__| ____ | | \______ ___________
|
||||||
|
// | __)_\ \/ /\ __\/ __ \_ __ \/ \\__ \ | | | | / _ \ / ___\| |/ \ | | / ___// __ \_ __ \
|
||||||
|
// | \> < | | \ ___/| | \/ | \/ __ \| |__ | |__( <_> ) /_/ > | | \ | | /\___ \\ ___/| | \/
|
||||||
|
// /_______ /__/\_ \ |__| \___ >__| |___| (____ /____/ |_______ \____/\___ /|__|___| / |______//____ >\___ >__|
|
||||||
|
// \/ \/ \/ \/ \/ \/ /_____/ \/ \/ \/
|
||||||
|
|
||||||
|
// ErrExternalLoginUserAlreadyExist represents a "ExternalLoginUserAlreadyExist" kind of error.
|
||||||
|
type ErrExternalLoginUserAlreadyExist struct {
|
||||||
|
ExternalID string
|
||||||
|
UserID int64
|
||||||
|
LoginSourceID int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsErrExternalLoginUserAlreadyExist checks if an error is a ExternalLoginUserAlreadyExist.
|
||||||
|
func IsErrExternalLoginUserAlreadyExist(err error) bool {
|
||||||
|
_, ok := err.(ErrExternalLoginUserAlreadyExist)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (err ErrExternalLoginUserAlreadyExist) Error() string {
|
||||||
|
return fmt.Sprintf("external login user already exists [externalID: %s, userID: %d, loginSourceID: %d]", err.ExternalID, err.UserID, err.LoginSourceID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrExternalLoginUserNotExist represents a "ExternalLoginUserNotExist" kind of error.
|
||||||
|
type ErrExternalLoginUserNotExist struct {
|
||||||
|
UserID int64
|
||||||
|
LoginSourceID int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsErrExternalLoginUserNotExist checks if an error is a ExternalLoginUserNotExist.
|
||||||
|
func IsErrExternalLoginUserNotExist(err error) bool {
|
||||||
|
_, ok := err.(ErrExternalLoginUserNotExist)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (err ErrExternalLoginUserNotExist) Error() string {
|
||||||
|
return fmt.Sprintf("external login user link does not exists [userID: %d, loginSourceID: %d]", err.UserID, err.LoginSourceID)
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,74 @@
|
||||||
|
// Copyright 2017 The Gitea Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a MIT-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package models
|
||||||
|
|
||||||
|
import "github.com/markbates/goth"
|
||||||
|
|
||||||
|
// ExternalLoginUser makes the connecting between some existing user and additional external login sources
|
||||||
|
type ExternalLoginUser struct {
|
||||||
|
ExternalID string `xorm:"NOT NULL"`
|
||||||
|
UserID int64 `xorm:"NOT NULL"`
|
||||||
|
LoginSourceID int64 `xorm:"NOT NULL"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetExternalLogin checks if a externalID in loginSourceID scope already exists
|
||||||
|
func GetExternalLogin(externalLoginUser *ExternalLoginUser) (bool, error) {
|
||||||
|
return x.Get(externalLoginUser)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListAccountLinks returns a map with the ExternalLoginUser and its LoginSource
|
||||||
|
func ListAccountLinks(user *User) ([]*ExternalLoginUser, error) {
|
||||||
|
externalAccounts := make([]*ExternalLoginUser, 0, 5)
|
||||||
|
err := x.Where("user_id=?", user.ID).
|
||||||
|
Desc("login_source_id").
|
||||||
|
Find(&externalAccounts)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return externalAccounts, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LinkAccountToUser link the gothUser to the user
|
||||||
|
func LinkAccountToUser(user *User, gothUser goth.User) error {
|
||||||
|
loginSource, err := GetActiveOAuth2LoginSourceByName(gothUser.Provider)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
externalLoginUser := &ExternalLoginUser{
|
||||||
|
ExternalID: gothUser.UserID,
|
||||||
|
UserID: user.ID,
|
||||||
|
LoginSourceID: loginSource.ID,
|
||||||
|
}
|
||||||
|
has, err := x.Get(externalLoginUser)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
} else if has {
|
||||||
|
return ErrExternalLoginUserAlreadyExist{gothUser.UserID, user.ID, loginSource.ID}
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = x.Insert(externalLoginUser)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveAccountLink will remove all external login sources for the given user
|
||||||
|
func RemoveAccountLink(user *User, loginSourceID int64) (int64, error) {
|
||||||
|
deleted, err := x.Delete(&ExternalLoginUser{UserID: user.ID, LoginSourceID: loginSourceID})
|
||||||
|
if err != nil {
|
||||||
|
return deleted, err
|
||||||
|
}
|
||||||
|
if deleted < 1 {
|
||||||
|
return deleted, ErrExternalLoginUserNotExist{user.ID, loginSourceID}
|
||||||
|
}
|
||||||
|
return deleted, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveAllAccountLinks will remove all external login sources for the given user
|
||||||
|
func RemoveAllAccountLinks(user *User) error {
|
||||||
|
_, err := x.Delete(&ExternalLoginUser{UserID: user.ID})
|
||||||
|
return err
|
||||||
|
}
|
|
@ -22,6 +22,7 @@ import (
|
||||||
"code.gitea.io/gitea/modules/auth/ldap"
|
"code.gitea.io/gitea/modules/auth/ldap"
|
||||||
"code.gitea.io/gitea/modules/auth/pam"
|
"code.gitea.io/gitea/modules/auth/pam"
|
||||||
"code.gitea.io/gitea/modules/log"
|
"code.gitea.io/gitea/modules/log"
|
||||||
|
"code.gitea.io/gitea/modules/auth/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
// LoginType represents an login type.
|
// LoginType represents an login type.
|
||||||
|
@ -30,19 +31,21 @@ type LoginType int
|
||||||
// Note: new type must append to the end of list to maintain compatibility.
|
// Note: new type must append to the end of list to maintain compatibility.
|
||||||
const (
|
const (
|
||||||
LoginNoType LoginType = iota
|
LoginNoType LoginType = iota
|
||||||
LoginPlain // 1
|
LoginPlain // 1
|
||||||
LoginLDAP // 2
|
LoginLDAP // 2
|
||||||
LoginSMTP // 3
|
LoginSMTP // 3
|
||||||
LoginPAM // 4
|
LoginPAM // 4
|
||||||
LoginDLDAP // 5
|
LoginDLDAP // 5
|
||||||
|
LoginOAuth2 // 6
|
||||||
)
|
)
|
||||||
|
|
||||||
// LoginNames contains the name of LoginType values.
|
// LoginNames contains the name of LoginType values.
|
||||||
var LoginNames = map[LoginType]string{
|
var LoginNames = map[LoginType]string{
|
||||||
LoginLDAP: "LDAP (via BindDN)",
|
LoginLDAP: "LDAP (via BindDN)",
|
||||||
LoginDLDAP: "LDAP (simple auth)", // Via direct bind
|
LoginDLDAP: "LDAP (simple auth)", // Via direct bind
|
||||||
LoginSMTP: "SMTP",
|
LoginSMTP: "SMTP",
|
||||||
LoginPAM: "PAM",
|
LoginPAM: "PAM",
|
||||||
|
LoginOAuth2: "OAuth2",
|
||||||
}
|
}
|
||||||
|
|
||||||
// SecurityProtocolNames contains the name of SecurityProtocol values.
|
// SecurityProtocolNames contains the name of SecurityProtocol values.
|
||||||
|
@ -57,6 +60,7 @@ var (
|
||||||
_ core.Conversion = &LDAPConfig{}
|
_ core.Conversion = &LDAPConfig{}
|
||||||
_ core.Conversion = &SMTPConfig{}
|
_ core.Conversion = &SMTPConfig{}
|
||||||
_ core.Conversion = &PAMConfig{}
|
_ core.Conversion = &PAMConfig{}
|
||||||
|
_ core.Conversion = &OAuth2Config{}
|
||||||
)
|
)
|
||||||
|
|
||||||
// LDAPConfig holds configuration for LDAP login source.
|
// LDAPConfig holds configuration for LDAP login source.
|
||||||
|
@ -115,6 +119,23 @@ func (cfg *PAMConfig) ToDB() ([]byte, error) {
|
||||||
return json.Marshal(cfg)
|
return json.Marshal(cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OAuth2Config holds configuration for the OAuth2 login source.
|
||||||
|
type OAuth2Config struct {
|
||||||
|
Provider string
|
||||||
|
ClientID string
|
||||||
|
ClientSecret string
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromDB fills up an OAuth2Config from serialized format.
|
||||||
|
func (cfg *OAuth2Config) FromDB(bs []byte) error {
|
||||||
|
return json.Unmarshal(bs, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToDB exports an SMTPConfig to a serialized format.
|
||||||
|
func (cfg *OAuth2Config) ToDB() ([]byte, error) {
|
||||||
|
return json.Marshal(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
// LoginSource represents an external way for authorizing users.
|
// LoginSource represents an external way for authorizing users.
|
||||||
type LoginSource struct {
|
type LoginSource struct {
|
||||||
ID int64 `xorm:"pk autoincr"`
|
ID int64 `xorm:"pk autoincr"`
|
||||||
|
@ -162,6 +183,8 @@ func (source *LoginSource) BeforeSet(colName string, val xorm.Cell) {
|
||||||
source.Cfg = new(SMTPConfig)
|
source.Cfg = new(SMTPConfig)
|
||||||
case LoginPAM:
|
case LoginPAM:
|
||||||
source.Cfg = new(PAMConfig)
|
source.Cfg = new(PAMConfig)
|
||||||
|
case LoginOAuth2:
|
||||||
|
source.Cfg = new(OAuth2Config)
|
||||||
default:
|
default:
|
||||||
panic("unrecognized login source type: " + com.ToStr(*val))
|
panic("unrecognized login source type: " + com.ToStr(*val))
|
||||||
}
|
}
|
||||||
|
@ -203,6 +226,11 @@ func (source *LoginSource) IsPAM() bool {
|
||||||
return source.Type == LoginPAM
|
return source.Type == LoginPAM
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsOAuth2 returns true of this source is of the OAuth2 type.
|
||||||
|
func (source *LoginSource) IsOAuth2() bool {
|
||||||
|
return source.Type == LoginOAuth2
|
||||||
|
}
|
||||||
|
|
||||||
// HasTLS returns true of this source supports TLS.
|
// HasTLS returns true of this source supports TLS.
|
||||||
func (source *LoginSource) HasTLS() bool {
|
func (source *LoginSource) HasTLS() bool {
|
||||||
return ((source.IsLDAP() || source.IsDLDAP()) &&
|
return ((source.IsLDAP() || source.IsDLDAP()) &&
|
||||||
|
@ -250,6 +278,11 @@ func (source *LoginSource) PAM() *PAMConfig {
|
||||||
return source.Cfg.(*PAMConfig)
|
return source.Cfg.(*PAMConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OAuth2 returns OAuth2Config for this source, if of OAuth2 type.
|
||||||
|
func (source *LoginSource) OAuth2() *OAuth2Config {
|
||||||
|
return source.Cfg.(*OAuth2Config)
|
||||||
|
}
|
||||||
|
|
||||||
// CreateLoginSource inserts a LoginSource in the DB if not already
|
// CreateLoginSource inserts a LoginSource in the DB if not already
|
||||||
// existing with the given name.
|
// existing with the given name.
|
||||||
func CreateLoginSource(source *LoginSource) error {
|
func CreateLoginSource(source *LoginSource) error {
|
||||||
|
@ -261,12 +294,16 @@ func CreateLoginSource(source *LoginSource) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = x.Insert(source)
|
_, err = x.Insert(source)
|
||||||
|
if err == nil && source.IsOAuth2() {
|
||||||
|
oAuth2Config := source.OAuth2()
|
||||||
|
oauth2.RegisterProvider(source.Name, oAuth2Config.Provider, oAuth2Config.ClientID, oAuth2Config.ClientSecret)
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoginSources returns a slice of all login sources found in DB.
|
// LoginSources returns a slice of all login sources found in DB.
|
||||||
func LoginSources() ([]*LoginSource, error) {
|
func LoginSources() ([]*LoginSource, error) {
|
||||||
auths := make([]*LoginSource, 0, 5)
|
auths := make([]*LoginSource, 0, 6)
|
||||||
return auths, x.Find(&auths)
|
return auths, x.Find(&auths)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -285,6 +322,11 @@ func GetLoginSourceByID(id int64) (*LoginSource, error) {
|
||||||
// UpdateSource updates a LoginSource record in DB.
|
// UpdateSource updates a LoginSource record in DB.
|
||||||
func UpdateSource(source *LoginSource) error {
|
func UpdateSource(source *LoginSource) error {
|
||||||
_, err := x.Id(source.ID).AllCols().Update(source)
|
_, err := x.Id(source.ID).AllCols().Update(source)
|
||||||
|
if err == nil && source.IsOAuth2() {
|
||||||
|
oAuth2Config := source.OAuth2()
|
||||||
|
oauth2.RemoveProvider(source.Name)
|
||||||
|
oauth2.RegisterProvider(source.Name, oAuth2Config.Provider, oAuth2Config.ClientID, oAuth2Config.ClientSecret)
|
||||||
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -296,6 +338,18 @@ func DeleteSource(source *LoginSource) error {
|
||||||
} else if count > 0 {
|
} else if count > 0 {
|
||||||
return ErrLoginSourceInUse{source.ID}
|
return ErrLoginSourceInUse{source.ID}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
count, err = x.Count(&ExternalLoginUser{LoginSourceID: source.ID})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
} else if count > 0 {
|
||||||
|
return ErrLoginSourceInUse{source.ID}
|
||||||
|
}
|
||||||
|
|
||||||
|
if source.IsOAuth2() {
|
||||||
|
oauth2.RemoveProvider(source.Name)
|
||||||
|
}
|
||||||
|
|
||||||
_, err = x.Id(source.ID).Delete(new(LoginSource))
|
_, err = x.Id(source.ID).Delete(new(LoginSource))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -444,7 +498,7 @@ func LoginViaSMTP(user *User, login, password string, sourceID int64, cfg *SMTPC
|
||||||
idx := strings.Index(login, "@")
|
idx := strings.Index(login, "@")
|
||||||
if idx == -1 {
|
if idx == -1 {
|
||||||
return nil, ErrUserNotExist{0, login, 0}
|
return nil, ErrUserNotExist{0, login, 0}
|
||||||
} else if !com.IsSliceContainsStr(strings.Split(cfg.AllowedDomains, ","), login[idx+1:]) {
|
} else if !com.IsSliceContainsStr(strings.Split(cfg.AllowedDomains, ","), login[idx + 1:]) {
|
||||||
return nil, ErrUserNotExist{0, login, 0}
|
return nil, ErrUserNotExist{0, login, 0}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -526,6 +580,27 @@ func LoginViaPAM(user *User, login, password string, sourceID int64, cfg *PAMCon
|
||||||
return user, CreateUser(user)
|
return user, CreateUser(user)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ________ _____ __ .__ ________
|
||||||
|
// \_____ \ / _ \ __ ___/ |_| |__ \_____ \
|
||||||
|
// / | \ / /_\ \| | \ __\ | \ / ____/
|
||||||
|
// / | \/ | \ | /| | | Y \/ \
|
||||||
|
// \_______ /\____|__ /____/ |__| |___| /\_______ \
|
||||||
|
// \/ \/ \/ \/
|
||||||
|
|
||||||
|
// OAuth2Provider describes the display values of a single OAuth2 provider
|
||||||
|
type OAuth2Provider struct {
|
||||||
|
Name string
|
||||||
|
DisplayName string
|
||||||
|
Image string
|
||||||
|
}
|
||||||
|
|
||||||
|
// OAuth2Providers contains the map of registered OAuth2 providers in Gitea (based on goth)
|
||||||
|
// key is used to map the OAuth2Provider with the goth provider type (also in LoginSource.OAuth2Config.Provider)
|
||||||
|
// value is used to store display data
|
||||||
|
var OAuth2Providers = map[string]OAuth2Provider{
|
||||||
|
"github": {Name: "github", DisplayName:"GitHub", Image: "/img/github.png"},
|
||||||
|
}
|
||||||
|
|
||||||
// ExternalUserLogin attempts a login using external source types.
|
// ExternalUserLogin attempts a login using external source types.
|
||||||
func ExternalUserLogin(user *User, login, password string, source *LoginSource, autoRegister bool) (*User, error) {
|
func ExternalUserLogin(user *User, login, password string, source *LoginSource, autoRegister bool) (*User, error) {
|
||||||
if !source.IsActived {
|
if !source.IsActived {
|
||||||
|
@ -560,7 +635,7 @@ func UserSignIn(username, password string) (*User, error) {
|
||||||
|
|
||||||
if hasUser {
|
if hasUser {
|
||||||
switch user.LoginType {
|
switch user.LoginType {
|
||||||
case LoginNoType, LoginPlain:
|
case LoginNoType, LoginPlain, LoginOAuth2:
|
||||||
if user.ValidatePassword(password) {
|
if user.ValidatePassword(password) {
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
@ -580,12 +655,16 @@ func UserSignIn(username, password string) (*User, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sources := make([]*LoginSource, 0, 3)
|
sources := make([]*LoginSource, 0, 5)
|
||||||
if err = x.UseBool().Find(&sources, &LoginSource{IsActived: true}); err != nil {
|
if err = x.UseBool().Find(&sources, &LoginSource{IsActived: true}); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, source := range sources {
|
for _, source := range sources {
|
||||||
|
if source.IsOAuth2() {
|
||||||
|
// don't try to authenticate against OAuth2 sources
|
||||||
|
continue
|
||||||
|
}
|
||||||
authUser, err := ExternalUserLogin(nil, username, password, source, true)
|
authUser, err := ExternalUserLogin(nil, username, password, source, true)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return authUser, nil
|
return authUser, nil
|
||||||
|
@ -596,3 +675,58 @@ func UserSignIn(username, password string) (*User, error) {
|
||||||
|
|
||||||
return nil, ErrUserNotExist{user.ID, user.Name, 0}
|
return nil, ErrUserNotExist{user.ID, user.Name, 0}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetActiveOAuth2ProviderLoginSources returns all actived LoginOAuth2 sources
|
||||||
|
func GetActiveOAuth2ProviderLoginSources() ([]*LoginSource, error) {
|
||||||
|
sources := make([]*LoginSource, 0, 1)
|
||||||
|
if err := x.UseBool().Find(&sources, &LoginSource{IsActived: true, Type: LoginOAuth2}); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return sources, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetActiveOAuth2LoginSourceByName returns a OAuth2 LoginSource based on the given name
|
||||||
|
func GetActiveOAuth2LoginSourceByName(name string) (*LoginSource, error) {
|
||||||
|
loginSource := &LoginSource{
|
||||||
|
Name: name,
|
||||||
|
Type: LoginOAuth2,
|
||||||
|
IsActived: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
has, err := x.UseBool().Get(loginSource)
|
||||||
|
if !has || err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return loginSource, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetActiveOAuth2Providers returns the map of configured active OAuth2 providers
|
||||||
|
// key is used as technical name (like in the callbackURL)
|
||||||
|
// values to display
|
||||||
|
func GetActiveOAuth2Providers() (map[string]OAuth2Provider, error) {
|
||||||
|
// Maybe also seperate used and unused providers so we can force the registration of only 1 active provider for each type
|
||||||
|
|
||||||
|
loginSources, err := GetActiveOAuth2ProviderLoginSources()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
providers := make(map[string]OAuth2Provider)
|
||||||
|
for _, source := range loginSources {
|
||||||
|
providers[source.Name] = OAuth2Providers[source.OAuth2().Provider]
|
||||||
|
}
|
||||||
|
|
||||||
|
return providers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// InitOAuth2 initialize the OAuth2 lib and register all active OAuth2 providers in the library
|
||||||
|
func InitOAuth2() {
|
||||||
|
oauth2.Init()
|
||||||
|
loginSources, _ := GetActiveOAuth2ProviderLoginSources()
|
||||||
|
|
||||||
|
for _, source := range loginSources {
|
||||||
|
oAuth2Config := source.OAuth2()
|
||||||
|
oauth2.RegisterProvider(source.Name, oAuth2Config.Provider, oAuth2Config.ClientID, oAuth2Config.ClientSecret)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -84,6 +84,8 @@ var migrations = []Migration{
|
||||||
NewMigration("create repo unit table and add units for all repos", addUnitsToTables),
|
NewMigration("create repo unit table and add units for all repos", addUnitsToTables),
|
||||||
// v17 -> v18
|
// v17 -> v18
|
||||||
NewMigration("set protect branches updated with created", setProtectedBranchUpdatedWithCreated),
|
NewMigration("set protect branches updated with created", setProtectedBranchUpdatedWithCreated),
|
||||||
|
// v18 -> v19
|
||||||
|
NewMigration("add external login user", addExternalLoginUser),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Migrate database to current version
|
// Migrate database to current version
|
||||||
|
|
|
@ -0,0 +1,25 @@
|
||||||
|
// Copyright 2016 Gitea. All rights reserved.
|
||||||
|
// Use of this source code is governed by a MIT-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package migrations
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/go-xorm/xorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExternalLoginUser makes the connecting between some existing user and additional external login sources
|
||||||
|
type ExternalLoginUser struct {
|
||||||
|
ExternalID string `xorm:"NOT NULL"`
|
||||||
|
UserID int64 `xorm:"NOT NULL"`
|
||||||
|
LoginSourceID int64 `xorm:"NOT NULL"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func addExternalLoginUser(x *xorm.Engine) error {
|
||||||
|
if err := x.Sync2(new(ExternalLoginUser)); err != nil {
|
||||||
|
return fmt.Errorf("Sync2: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
|
@ -196,6 +196,11 @@ func (u *User) IsLocal() bool {
|
||||||
return u.LoginType <= LoginPlain
|
return u.LoginType <= LoginPlain
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsOAuth2 returns true if user login type is LoginOAuth2.
|
||||||
|
func (u *User) IsOAuth2() bool {
|
||||||
|
return u.LoginType == LoginOAuth2
|
||||||
|
}
|
||||||
|
|
||||||
// HasForkedRepo checks if user has already forked a repository with given ID.
|
// HasForkedRepo checks if user has already forked a repository with given ID.
|
||||||
func (u *User) HasForkedRepo(repoID int64) bool {
|
func (u *User) HasForkedRepo(repoID int64) bool {
|
||||||
_, has := HasForkedRepo(u.ID, repoID)
|
_, has := HasForkedRepo(u.ID, repoID)
|
||||||
|
@ -397,6 +402,11 @@ func (u *User) ValidatePassword(passwd string) bool {
|
||||||
return subtle.ConstantTimeCompare([]byte(u.Passwd), []byte(newUser.Passwd)) == 1
|
return subtle.ConstantTimeCompare([]byte(u.Passwd), []byte(newUser.Passwd)) == 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsPasswordSet checks if the password is set or left empty
|
||||||
|
func (u *User) IsPasswordSet() bool {
|
||||||
|
return !u.ValidatePassword("")
|
||||||
|
}
|
||||||
|
|
||||||
// UploadAvatar saves custom avatar for user.
|
// UploadAvatar saves custom avatar for user.
|
||||||
// FIXME: split uploads to different subdirs in case we have massive users.
|
// FIXME: split uploads to different subdirs in case we have massive users.
|
||||||
func (u *User) UploadAvatar(data []byte) error {
|
func (u *User) UploadAvatar(data []byte) error {
|
||||||
|
@ -947,6 +957,12 @@ func deleteUser(e *xorm.Session, u *User) error {
|
||||||
return fmt.Errorf("clear assignee: %v", err)
|
return fmt.Errorf("clear assignee: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ***** START: ExternalLoginUser *****
|
||||||
|
if err = RemoveAllAccountLinks(u); err != nil {
|
||||||
|
return fmt.Errorf("ExternalLoginUser: %v", err)
|
||||||
|
}
|
||||||
|
// ***** END: ExternalLoginUser *****
|
||||||
|
|
||||||
if _, err = e.Id(u.ID).Delete(new(User)); err != nil {
|
if _, err = e.Id(u.ID).Delete(new(User)); err != nil {
|
||||||
return fmt.Errorf("Delete: %v", err)
|
return fmt.Errorf("Delete: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -1190,6 +1206,11 @@ func GetUserByEmail(email string) (*User, error) {
|
||||||
return nil, ErrUserNotExist{0, email, 0}
|
return nil, ErrUserNotExist{0, email, 0}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetUser checks if a user already exists
|
||||||
|
func GetUser(user *User) (bool, error) {
|
||||||
|
return x.Get(user)
|
||||||
|
}
|
||||||
|
|
||||||
// SearchUserOptions contains the options for searching
|
// SearchUserOptions contains the options for searching
|
||||||
type SearchUserOptions struct {
|
type SearchUserOptions struct {
|
||||||
Keyword string
|
Keyword string
|
||||||
|
|
|
@ -179,7 +179,7 @@ func AssignForm(form interface{}, data map[string]interface{}) {
|
||||||
func getRuleBody(field reflect.StructField, prefix string) string {
|
func getRuleBody(field reflect.StructField, prefix string) string {
|
||||||
for _, rule := range strings.Split(field.Tag.Get("binding"), ";") {
|
for _, rule := range strings.Split(field.Tag.Get("binding"), ";") {
|
||||||
if strings.HasPrefix(rule, prefix) {
|
if strings.HasPrefix(rule, prefix) {
|
||||||
return rule[len(prefix) : len(rule)-1]
|
return rule[len(prefix): len(rule) - 1]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
|
@ -237,7 +237,7 @@ func validate(errs binding.Errors, data map[string]interface{}, f Form, l macaro
|
||||||
}
|
}
|
||||||
|
|
||||||
if errs[0].FieldNames[0] == field.Name {
|
if errs[0].FieldNames[0] == field.Name {
|
||||||
data["Err_"+field.Name] = true
|
data["Err_" + field.Name] = true
|
||||||
|
|
||||||
trName := field.Tag.Get("locale")
|
trName := field.Tag.Get("locale")
|
||||||
if len(trName) == 0 {
|
if len(trName) == 0 {
|
||||||
|
|
|
@ -12,7 +12,7 @@ import (
|
||||||
// AuthenticationForm form for authentication
|
// AuthenticationForm form for authentication
|
||||||
type AuthenticationForm struct {
|
type AuthenticationForm struct {
|
||||||
ID int64
|
ID int64
|
||||||
Type int `binding:"Range(2,5)"`
|
Type int `binding:"Range(2,6)"`
|
||||||
Name string `binding:"Required;MaxSize(30)"`
|
Name string `binding:"Required;MaxSize(30)"`
|
||||||
Host string
|
Host string
|
||||||
Port int
|
Port int
|
||||||
|
@ -36,6 +36,9 @@ type AuthenticationForm struct {
|
||||||
TLS bool
|
TLS bool
|
||||||
SkipVerify bool
|
SkipVerify bool
|
||||||
PAMServiceName string
|
PAMServiceName string
|
||||||
|
Oauth2Provider string
|
||||||
|
Oauth2Key string
|
||||||
|
Oauth2Secret string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate validates fields
|
// Validate validates fields
|
||||||
|
|
|
@ -0,0 +1,105 @@
|
||||||
|
// Copyright 2017 The Gitea Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a MIT-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package oauth2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"code.gitea.io/gitea/modules/setting"
|
||||||
|
"code.gitea.io/gitea/modules/log"
|
||||||
|
"github.com/gorilla/sessions"
|
||||||
|
"github.com/markbates/goth"
|
||||||
|
"github.com/markbates/goth/gothic"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"github.com/satori/go.uuid"
|
||||||
|
"path/filepath"
|
||||||
|
"github.com/markbates/goth/providers/github"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
sessionUsersStoreKey = "gitea-oauth2-sessions"
|
||||||
|
providerHeaderKey = "gitea-oauth2-provider"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Init initialize the setup of the OAuth2 library
|
||||||
|
func Init() {
|
||||||
|
sessionDir := filepath.Join(setting.AppDataPath, "sessions", "oauth2")
|
||||||
|
if err := os.MkdirAll(sessionDir, 0700); err != nil {
|
||||||
|
log.Fatal(4, "Fail to create dir %s: %v", sessionDir, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
gothic.Store = sessions.NewFilesystemStore(sessionDir, []byte(sessionUsersStoreKey))
|
||||||
|
|
||||||
|
gothic.SetState = func(req *http.Request) string {
|
||||||
|
return uuid.NewV4().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
gothic.GetProviderName = func(req *http.Request) (string, error) {
|
||||||
|
return req.Header.Get(providerHeaderKey), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Auth OAuth2 auth service
|
||||||
|
func Auth(provider string, request *http.Request, response http.ResponseWriter) error {
|
||||||
|
// not sure if goth is thread safe (?) when using multiple providers
|
||||||
|
request.Header.Set(providerHeaderKey, provider)
|
||||||
|
|
||||||
|
// don't use the default gothic begin handler to prevent issues when some error occurs
|
||||||
|
// normally the gothic library will write some custom stuff to the response instead of our own nice error page
|
||||||
|
//gothic.BeginAuthHandler(response, request)
|
||||||
|
|
||||||
|
url, err := gothic.GetAuthURL(response, request)
|
||||||
|
if err == nil {
|
||||||
|
http.Redirect(response, request, url, http.StatusTemporaryRedirect)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProviderCallback handles OAuth callback, resolve to a goth user and send back to original url
|
||||||
|
// this will trigger a new authentication request, but because we save it in the session we can use that
|
||||||
|
func ProviderCallback(provider string, request *http.Request, response http.ResponseWriter) (goth.User, error) {
|
||||||
|
// not sure if goth is thread safe (?) when using multiple providers
|
||||||
|
request.Header.Set(providerHeaderKey, provider)
|
||||||
|
|
||||||
|
user, err := gothic.CompleteUserAuth(response, request)
|
||||||
|
if err != nil {
|
||||||
|
return user, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterProvider register a OAuth2 provider in goth lib
|
||||||
|
func RegisterProvider(providerName, providerType, clientID, clientSecret string) {
|
||||||
|
provider := createProvider(providerName, providerType, clientID, clientSecret)
|
||||||
|
|
||||||
|
if provider != nil {
|
||||||
|
goth.UseProviders(provider)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveProvider removes the given OAuth2 provider from the goth lib
|
||||||
|
func RemoveProvider(providerName string) {
|
||||||
|
delete(goth.GetProviders(), providerName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// used to create different types of goth providers
|
||||||
|
func createProvider(providerName, providerType, clientID, clientSecret string) goth.Provider {
|
||||||
|
callbackURL := setting.AppURL + "user/oauth2/" + providerName + "/callback"
|
||||||
|
|
||||||
|
var provider goth.Provider
|
||||||
|
|
||||||
|
switch providerType {
|
||||||
|
case "github":
|
||||||
|
provider = github.New(clientID, clientSecret, callbackURL, "user:email")
|
||||||
|
}
|
||||||
|
|
||||||
|
// always set the name if provider is created so we can support multiple setups of 1 provider
|
||||||
|
if provider != nil {
|
||||||
|
provider.SetName(providerName)
|
||||||
|
}
|
||||||
|
|
||||||
|
return provider
|
||||||
|
}
|
|
@ -143,7 +143,7 @@ func (f *AddEmailForm) Validate(ctx *macaron.Context, errs binding.Errors) bindi
|
||||||
|
|
||||||
// ChangePasswordForm form for changing password
|
// ChangePasswordForm form for changing password
|
||||||
type ChangePasswordForm struct {
|
type ChangePasswordForm struct {
|
||||||
OldPassword string `form:"old_password" binding:"Required;MinSize(1);MaxSize(255)"`
|
OldPassword string `form:"old_password" binding:"MaxSize(255)"`
|
||||||
Password string `form:"password" binding:"Required;MaxSize(255)"`
|
Password string `form:"password" binding:"Required;MaxSize(255)"`
|
||||||
Retype string `form:"retype"`
|
Retype string `form:"retype"`
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,8 +5,11 @@ dashboard = Dashboard
|
||||||
explore = Explore
|
explore = Explore
|
||||||
help = Help
|
help = Help
|
||||||
sign_in = Sign In
|
sign_in = Sign In
|
||||||
|
sign_in_with = Sign in with
|
||||||
sign_out = Sign Out
|
sign_out = Sign Out
|
||||||
sign_up = Sign Up
|
sign_up = Sign Up
|
||||||
|
link_account = Link Account
|
||||||
|
link_account_signin_or_signup = Login with existing credentials to link your existing account to these new account, or sign up for a new account
|
||||||
register = Register
|
register = Register
|
||||||
website = Website
|
website = Website
|
||||||
version = Version
|
version = Version
|
||||||
|
@ -277,6 +280,7 @@ applications = Applications
|
||||||
orgs = Organizations
|
orgs = Organizations
|
||||||
delete = Delete Account
|
delete = Delete Account
|
||||||
twofa = Two-Factor Authentication
|
twofa = Two-Factor Authentication
|
||||||
|
account_link = External Accounts
|
||||||
uid = Uid
|
uid = Uid
|
||||||
|
|
||||||
public_profile = Public Profile
|
public_profile = Public Profile
|
||||||
|
@ -379,6 +383,13 @@ then_enter_passcode = Then enter the passcode the application gives you:
|
||||||
passcode_invalid = That passcode is invalid. Try again.
|
passcode_invalid = That passcode is invalid. Try again.
|
||||||
twofa_enrolled = Your account has now been enrolled in two-factor authentication. Make sure to save your scratch token (%s), as it will only be shown once!
|
twofa_enrolled = Your account has now been enrolled in two-factor authentication. Make sure to save your scratch token (%s), as it will only be shown once!
|
||||||
|
|
||||||
|
manage_account_links = Manage account links
|
||||||
|
manage_account_links_desc = External accounts linked to this account
|
||||||
|
account_links_not_available = There are no external accounts linked to this account
|
||||||
|
remove_account_link = Remove linked account
|
||||||
|
remove_account_link_desc = Delete this account link will remove all related access for your account. Do you want to continue?
|
||||||
|
remove_account_link_success = Account link has been removed successfully!
|
||||||
|
|
||||||
delete_account = Delete Your Account
|
delete_account = Delete Your Account
|
||||||
delete_prompt = The operation will delete your account permanently, and <strong>CANNOT</strong> be undone!
|
delete_prompt = The operation will delete your account permanently, and <strong>CANNOT</strong> be undone!
|
||||||
confirm_delete_account = Confirm Deletion
|
confirm_delete_account = Confirm Deletion
|
||||||
|
@ -1106,8 +1117,12 @@ auths.allowed_domains_helper = Leave it empty to not restrict any domains. Multi
|
||||||
auths.enable_tls = Enable TLS Encryption
|
auths.enable_tls = Enable TLS Encryption
|
||||||
auths.skip_tls_verify = Skip TLS Verify
|
auths.skip_tls_verify = Skip TLS Verify
|
||||||
auths.pam_service_name = PAM Service Name
|
auths.pam_service_name = PAM Service Name
|
||||||
|
auths.oauth2_provider = OAuth2 provider
|
||||||
|
auths.oauth2_clientID = Client ID (Key)
|
||||||
|
auths.oauth2_clientSecret = Client Secret
|
||||||
auths.enable_auto_register = Enable Auto Registration
|
auths.enable_auto_register = Enable Auto Registration
|
||||||
auths.tips = Tips
|
auths.tips = Tips
|
||||||
|
auths.tip.github = Register a new OAuth application on https://github.com/settings/applications/new and use <host>/user/oauth2/<Authentication Name>/callback as "Authorization callback URL"
|
||||||
auths.edit = Edit Authentication Setting
|
auths.edit = Edit Authentication Setting
|
||||||
auths.activated = This authentication is activated
|
auths.activated = This authentication is activated
|
||||||
auths.new_success = New authentication '%s' has been added successfully.
|
auths.new_success = New authentication '%s' has been added successfully.
|
||||||
|
|
|
@ -2983,3 +2983,24 @@ footer .ui.language .menu {
|
||||||
.ui.user.list .item .description a:hover {
|
.ui.user.list .item .description a:hover {
|
||||||
text-decoration: underline;
|
text-decoration: underline;
|
||||||
}
|
}
|
||||||
|
.user.link-account:not(.icon) {
|
||||||
|
padding-top: 15px;
|
||||||
|
padding-bottom: 5px;
|
||||||
|
}
|
||||||
|
.signin .oauth2 div {
|
||||||
|
display: inline-block;
|
||||||
|
}
|
||||||
|
.signin .oauth2 div p {
|
||||||
|
margin: 10px 5px 0 0;
|
||||||
|
float: left;
|
||||||
|
}
|
||||||
|
.signin .oauth2 a {
|
||||||
|
margin-right: 5px;
|
||||||
|
}
|
||||||
|
.signin .oauth2 a:last-child {
|
||||||
|
margin-right: 0px;
|
||||||
|
}
|
||||||
|
.signin .oauth2 img {
|
||||||
|
width: 32px;
|
||||||
|
height: 32px;
|
||||||
|
}
|
Binary file not shown.
After Width: | Height: | Size: 1.1 KiB |
|
@ -1019,9 +1019,9 @@ function initAdmin() {
|
||||||
// New authentication
|
// New authentication
|
||||||
if ($('.admin.new.authentication').length > 0) {
|
if ($('.admin.new.authentication').length > 0) {
|
||||||
$('#auth_type').change(function () {
|
$('#auth_type').change(function () {
|
||||||
$('.ldap, .dldap, .smtp, .pam, .has-tls').hide();
|
$('.ldap, .dldap, .smtp, .pam, .oauth2, .has-tls').hide();
|
||||||
|
|
||||||
$('.ldap input[required], .dldap input[required], .smtp input[required], .pam input[required], .has-tls input[required]').removeAttr('required');
|
$('.ldap input[required], .dldap input[required], .smtp input[required], .pam input[required], .oauth2 input[required] .has-tls input[required]').removeAttr('required');
|
||||||
|
|
||||||
var authType = $(this).val();
|
var authType = $(this).val();
|
||||||
switch (authType) {
|
switch (authType) {
|
||||||
|
@ -1042,6 +1042,10 @@ function initAdmin() {
|
||||||
$('.dldap').show();
|
$('.dldap').show();
|
||||||
$('.dldap div.required:not(.ldap) input').attr('required', 'required');
|
$('.dldap div.required:not(.ldap) input').attr('required', 'required');
|
||||||
break;
|
break;
|
||||||
|
case '6': // OAuth2
|
||||||
|
$('.oauth2').show();
|
||||||
|
$('.oauth2 input').attr('required', 'required');
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (authType == '2' || authType == '5') {
|
if (authType == '2' || authType == '5') {
|
||||||
|
|
|
@ -53,6 +53,7 @@ var (
|
||||||
{models.LoginNames[models.LoginDLDAP], models.LoginDLDAP},
|
{models.LoginNames[models.LoginDLDAP], models.LoginDLDAP},
|
||||||
{models.LoginNames[models.LoginSMTP], models.LoginSMTP},
|
{models.LoginNames[models.LoginSMTP], models.LoginSMTP},
|
||||||
{models.LoginNames[models.LoginPAM], models.LoginPAM},
|
{models.LoginNames[models.LoginPAM], models.LoginPAM},
|
||||||
|
{models.LoginNames[models.LoginOAuth2], models.LoginOAuth2},
|
||||||
}
|
}
|
||||||
securityProtocols = []dropdownItem{
|
securityProtocols = []dropdownItem{
|
||||||
{models.SecurityProtocolNames[ldap.SecurityProtocolUnencrypted], ldap.SecurityProtocolUnencrypted},
|
{models.SecurityProtocolNames[ldap.SecurityProtocolUnencrypted], ldap.SecurityProtocolUnencrypted},
|
||||||
|
@ -75,6 +76,14 @@ func NewAuthSource(ctx *context.Context) {
|
||||||
ctx.Data["AuthSources"] = authSources
|
ctx.Data["AuthSources"] = authSources
|
||||||
ctx.Data["SecurityProtocols"] = securityProtocols
|
ctx.Data["SecurityProtocols"] = securityProtocols
|
||||||
ctx.Data["SMTPAuths"] = models.SMTPAuths
|
ctx.Data["SMTPAuths"] = models.SMTPAuths
|
||||||
|
ctx.Data["OAuth2Providers"] = models.OAuth2Providers
|
||||||
|
|
||||||
|
// only the first as default
|
||||||
|
for key := range models.OAuth2Providers {
|
||||||
|
ctx.Data["oauth2_provider"] = key
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
ctx.HTML(200, tplAuthNew)
|
ctx.HTML(200, tplAuthNew)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -113,6 +122,14 @@ func parseSMTPConfig(form auth.AuthenticationForm) *models.SMTPConfig {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func parseOAuth2Config(form auth.AuthenticationForm) *models.OAuth2Config {
|
||||||
|
return &models.OAuth2Config{
|
||||||
|
Provider: form.Oauth2Provider,
|
||||||
|
ClientID: form.Oauth2Key,
|
||||||
|
ClientSecret: form.Oauth2Secret,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// NewAuthSourcePost response for adding an auth source
|
// NewAuthSourcePost response for adding an auth source
|
||||||
func NewAuthSourcePost(ctx *context.Context, form auth.AuthenticationForm) {
|
func NewAuthSourcePost(ctx *context.Context, form auth.AuthenticationForm) {
|
||||||
ctx.Data["Title"] = ctx.Tr("admin.auths.new")
|
ctx.Data["Title"] = ctx.Tr("admin.auths.new")
|
||||||
|
@ -124,6 +141,7 @@ func NewAuthSourcePost(ctx *context.Context, form auth.AuthenticationForm) {
|
||||||
ctx.Data["AuthSources"] = authSources
|
ctx.Data["AuthSources"] = authSources
|
||||||
ctx.Data["SecurityProtocols"] = securityProtocols
|
ctx.Data["SecurityProtocols"] = securityProtocols
|
||||||
ctx.Data["SMTPAuths"] = models.SMTPAuths
|
ctx.Data["SMTPAuths"] = models.SMTPAuths
|
||||||
|
ctx.Data["OAuth2Providers"] = models.OAuth2Providers
|
||||||
|
|
||||||
hasTLS := false
|
hasTLS := false
|
||||||
var config core.Conversion
|
var config core.Conversion
|
||||||
|
@ -138,6 +156,8 @@ func NewAuthSourcePost(ctx *context.Context, form auth.AuthenticationForm) {
|
||||||
config = &models.PAMConfig{
|
config = &models.PAMConfig{
|
||||||
ServiceName: form.PAMServiceName,
|
ServiceName: form.PAMServiceName,
|
||||||
}
|
}
|
||||||
|
case models.LoginOAuth2:
|
||||||
|
config = parseOAuth2Config(form)
|
||||||
default:
|
default:
|
||||||
ctx.Error(400)
|
ctx.Error(400)
|
||||||
return
|
return
|
||||||
|
@ -178,6 +198,7 @@ func EditAuthSource(ctx *context.Context) {
|
||||||
|
|
||||||
ctx.Data["SecurityProtocols"] = securityProtocols
|
ctx.Data["SecurityProtocols"] = securityProtocols
|
||||||
ctx.Data["SMTPAuths"] = models.SMTPAuths
|
ctx.Data["SMTPAuths"] = models.SMTPAuths
|
||||||
|
ctx.Data["OAuth2Providers"] = models.OAuth2Providers
|
||||||
|
|
||||||
source, err := models.GetLoginSourceByID(ctx.ParamsInt64(":authid"))
|
source, err := models.GetLoginSourceByID(ctx.ParamsInt64(":authid"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -187,16 +208,20 @@ func EditAuthSource(ctx *context.Context) {
|
||||||
ctx.Data["Source"] = source
|
ctx.Data["Source"] = source
|
||||||
ctx.Data["HasTLS"] = source.HasTLS()
|
ctx.Data["HasTLS"] = source.HasTLS()
|
||||||
|
|
||||||
|
if source.IsOAuth2() {
|
||||||
|
ctx.Data["CurrentOAuth2Provider"] = models.OAuth2Providers[source.OAuth2().Provider]
|
||||||
|
}
|
||||||
ctx.HTML(200, tplAuthEdit)
|
ctx.HTML(200, tplAuthEdit)
|
||||||
}
|
}
|
||||||
|
|
||||||
// EditAuthSourcePost resposne for editing auth source
|
// EditAuthSourcePost response for editing auth source
|
||||||
func EditAuthSourcePost(ctx *context.Context, form auth.AuthenticationForm) {
|
func EditAuthSourcePost(ctx *context.Context, form auth.AuthenticationForm) {
|
||||||
ctx.Data["Title"] = ctx.Tr("admin.auths.edit")
|
ctx.Data["Title"] = ctx.Tr("admin.auths.edit")
|
||||||
ctx.Data["PageIsAdmin"] = true
|
ctx.Data["PageIsAdmin"] = true
|
||||||
ctx.Data["PageIsAdminAuthentications"] = true
|
ctx.Data["PageIsAdminAuthentications"] = true
|
||||||
|
|
||||||
ctx.Data["SMTPAuths"] = models.SMTPAuths
|
ctx.Data["SMTPAuths"] = models.SMTPAuths
|
||||||
|
ctx.Data["OAuth2Providers"] = models.OAuth2Providers
|
||||||
|
|
||||||
source, err := models.GetLoginSourceByID(ctx.ParamsInt64(":authid"))
|
source, err := models.GetLoginSourceByID(ctx.ParamsInt64(":authid"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -221,6 +246,8 @@ func EditAuthSourcePost(ctx *context.Context, form auth.AuthenticationForm) {
|
||||||
config = &models.PAMConfig{
|
config = &models.PAMConfig{
|
||||||
ServiceName: form.PAMServiceName,
|
ServiceName: form.PAMServiceName,
|
||||||
}
|
}
|
||||||
|
case models.LoginOAuth2:
|
||||||
|
config = parseOAuth2Config(form)
|
||||||
default:
|
default:
|
||||||
ctx.Error(400)
|
ctx.Error(400)
|
||||||
return
|
return
|
||||||
|
|
|
@ -54,6 +54,7 @@ func GlobalInit() {
|
||||||
log.Fatal(4, "Failed to initialize ORM engine: %v", err)
|
log.Fatal(4, "Failed to initialize ORM engine: %v", err)
|
||||||
}
|
}
|
||||||
models.HasEngine = true
|
models.HasEngine = true
|
||||||
|
models.InitOAuth2()
|
||||||
|
|
||||||
models.LoadRepoConfig()
|
models.LoadRepoConfig()
|
||||||
models.NewRepoContext()
|
models.NewRepoContext()
|
||||||
|
|
|
@ -59,7 +59,7 @@ func HTTP(ctx *context.Context) {
|
||||||
isWiki := false
|
isWiki := false
|
||||||
if strings.HasSuffix(reponame, ".wiki") {
|
if strings.HasSuffix(reponame, ".wiki") {
|
||||||
isWiki = true
|
isWiki = true
|
||||||
reponame = reponame[:len(reponame)-5]
|
reponame = reponame[:len(reponame) - 5]
|
||||||
}
|
}
|
||||||
|
|
||||||
repoUser, err := models.GetUserByName(username)
|
repoUser, err := models.GetUserByName(username)
|
||||||
|
@ -191,9 +191,9 @@ func HTTP(ctx *context.Context) {
|
||||||
|
|
||||||
var lastLine int64
|
var lastLine int64
|
||||||
for {
|
for {
|
||||||
head := input[lastLine : lastLine+2]
|
head := input[lastLine: lastLine + 2]
|
||||||
if head[0] == '0' && head[1] == '0' {
|
if head[0] == '0' && head[1] == '0' {
|
||||||
size, err := strconv.ParseInt(string(input[lastLine+2:lastLine+4]), 16, 32)
|
size, err := strconv.ParseInt(string(input[lastLine + 2:lastLine + 4]), 16, 32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(4, "%v", err)
|
log.Error(4, "%v", err)
|
||||||
return
|
return
|
||||||
|
@ -204,7 +204,7 @@ func HTTP(ctx *context.Context) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
line := input[lastLine : lastLine+size]
|
line := input[lastLine: lastLine + size]
|
||||||
idx := bytes.IndexRune(line, '\000')
|
idx := bytes.IndexRune(line, '\000')
|
||||||
if idx > -1 {
|
if idx > -1 {
|
||||||
line = line[:idx]
|
line = line[:idx]
|
||||||
|
@ -370,7 +370,7 @@ func gitCommand(dir string, args ...string) []byte {
|
||||||
|
|
||||||
func getGitConfig(option, dir string) string {
|
func getGitConfig(option, dir string) string {
|
||||||
out := string(gitCommand(dir, "config", option))
|
out := string(gitCommand(dir, "config", option))
|
||||||
return out[0 : len(out)-1]
|
return out[0: len(out) - 1]
|
||||||
}
|
}
|
||||||
|
|
||||||
func getConfigSetting(service, dir string) bool {
|
func getConfigSetting(service, dir string) bool {
|
||||||
|
@ -501,7 +501,7 @@ func updateServerInfo(dir string) []byte {
|
||||||
}
|
}
|
||||||
|
|
||||||
func packetWrite(str string) []byte {
|
func packetWrite(str string) []byte {
|
||||||
s := strconv.FormatInt(int64(len(str)+4), 16)
|
s := strconv.FormatInt(int64(len(str) + 4), 16)
|
||||||
if len(s)%4 != 0 {
|
if len(s)%4 != 0 {
|
||||||
s = strings.Repeat("0", 4-len(s)%4) + s
|
s = strings.Repeat("0", 4-len(s)%4) + s
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,10 @@ import (
|
||||||
"code.gitea.io/gitea/modules/context"
|
"code.gitea.io/gitea/modules/context"
|
||||||
"code.gitea.io/gitea/modules/log"
|
"code.gitea.io/gitea/modules/log"
|
||||||
"code.gitea.io/gitea/modules/setting"
|
"code.gitea.io/gitea/modules/setting"
|
||||||
|
"net/http"
|
||||||
|
"code.gitea.io/gitea/modules/auth/oauth2"
|
||||||
|
"github.com/markbates/goth"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -30,6 +34,7 @@ const (
|
||||||
tplResetPassword base.TplName = "user/auth/reset_passwd"
|
tplResetPassword base.TplName = "user/auth/reset_passwd"
|
||||||
tplTwofa base.TplName = "user/auth/twofa"
|
tplTwofa base.TplName = "user/auth/twofa"
|
||||||
tplTwofaScratch base.TplName = "user/auth/twofa_scratch"
|
tplTwofaScratch base.TplName = "user/auth/twofa_scratch"
|
||||||
|
tplLinkAccount base.TplName = "user/auth/link_account"
|
||||||
)
|
)
|
||||||
|
|
||||||
// AutoSignIn reads cookie and try to auto-login.
|
// AutoSignIn reads cookie and try to auto-login.
|
||||||
|
@ -61,7 +66,7 @@ func AutoSignIn(ctx *context.Context) (bool, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if val, _ := ctx.GetSuperSecureCookie(
|
if val, _ := ctx.GetSuperSecureCookie(
|
||||||
base.EncodeMD5(u.Rands+u.Passwd), setting.CookieRememberName); val != u.Name {
|
base.EncodeMD5(u.Rands + u.Passwd), setting.CookieRememberName); val != u.Name {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -109,6 +114,13 @@ func SignIn(ctx *context.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
oauth2Providers, err := models.GetActiveOAuth2Providers()
|
||||||
|
if err != nil {
|
||||||
|
ctx.Handle(500, "UserSignIn", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctx.Data["OAuth2Providers"] = oauth2Providers
|
||||||
|
|
||||||
ctx.HTML(200, tplSignIn)
|
ctx.HTML(200, tplSignIn)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -116,6 +128,13 @@ func SignIn(ctx *context.Context) {
|
||||||
func SignInPost(ctx *context.Context, form auth.SignInForm) {
|
func SignInPost(ctx *context.Context, form auth.SignInForm) {
|
||||||
ctx.Data["Title"] = ctx.Tr("sign_in")
|
ctx.Data["Title"] = ctx.Tr("sign_in")
|
||||||
|
|
||||||
|
oauth2Providers, err := models.GetActiveOAuth2Providers()
|
||||||
|
if err != nil {
|
||||||
|
ctx.Handle(500, "UserSignIn", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctx.Data["OAuth2Providers"] = oauth2Providers
|
||||||
|
|
||||||
if ctx.HasError() {
|
if ctx.HasError() {
|
||||||
ctx.HTML(200, tplSignIn)
|
ctx.HTML(200, tplSignIn)
|
||||||
return
|
return
|
||||||
|
@ -277,7 +296,7 @@ func handleSignInFull(ctx *context.Context, u *models.User, remember bool, obeyR
|
||||||
if remember {
|
if remember {
|
||||||
days := 86400 * setting.LogInRememberDays
|
days := 86400 * setting.LogInRememberDays
|
||||||
ctx.SetCookie(setting.CookieUserName, u.Name, days, setting.AppSubURL)
|
ctx.SetCookie(setting.CookieUserName, u.Name, days, setting.AppSubURL)
|
||||||
ctx.SetSuperSecureCookie(base.EncodeMD5(u.Rands+u.Passwd),
|
ctx.SetSuperSecureCookie(base.EncodeMD5(u.Rands + u.Passwd),
|
||||||
setting.CookieRememberName, u.Name, days, setting.AppSubURL)
|
setting.CookieRememberName, u.Name, days, setting.AppSubURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -309,6 +328,333 @@ func handleSignInFull(ctx *context.Context, u *models.User, remember bool, obeyR
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SignInOAuth handles the OAuth2 login buttons
|
||||||
|
func SignInOAuth(ctx *context.Context) {
|
||||||
|
provider := ctx.Params(":provider")
|
||||||
|
|
||||||
|
loginSource, err := models.GetActiveOAuth2LoginSourceByName(provider)
|
||||||
|
if err != nil {
|
||||||
|
ctx.Handle(500, "SignIn", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// try to do a direct callback flow, so we don't authenticate the user again but use the valid accesstoken to get the user
|
||||||
|
user, gothUser, err := oAuth2UserLoginCallback(loginSource, ctx.Req.Request, ctx.Resp)
|
||||||
|
if err == nil && user != nil {
|
||||||
|
// we got the user without going through the whole OAuth2 authentication flow again
|
||||||
|
handleOAuth2SignIn(user, gothUser, ctx, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err = oauth2.Auth(loginSource.Name, ctx.Req.Request, ctx.Resp)
|
||||||
|
if err != nil {
|
||||||
|
ctx.Handle(500, "SignIn", err)
|
||||||
|
}
|
||||||
|
// redirect is done in oauth2.Auth
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignInOAuthCallback handles the callback from the given provider
|
||||||
|
func SignInOAuthCallback(ctx *context.Context) {
|
||||||
|
provider := ctx.Params(":provider")
|
||||||
|
|
||||||
|
// first look if the provider is still active
|
||||||
|
loginSource, err := models.GetActiveOAuth2LoginSourceByName(provider)
|
||||||
|
if err != nil {
|
||||||
|
ctx.Handle(500, "SignIn", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if loginSource == nil {
|
||||||
|
ctx.Handle(500, "SignIn", errors.New("No valid provider found, check configured callback url in provider"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
u, gothUser, err := oAuth2UserLoginCallback(loginSource, ctx.Req.Request, ctx.Resp)
|
||||||
|
|
||||||
|
handleOAuth2SignIn(u, gothUser, ctx, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleOAuth2SignIn(u *models.User, gothUser goth.User, ctx *context.Context, err error) {
|
||||||
|
if err != nil {
|
||||||
|
ctx.Handle(500, "UserSignIn", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if u == nil {
|
||||||
|
// no existing user is found, request attach or new account
|
||||||
|
ctx.Session.Set("linkAccountGothUser", gothUser)
|
||||||
|
ctx.Redirect(setting.AppSubURL + "/user/link_account")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If this user is enrolled in 2FA, we can't sign the user in just yet.
|
||||||
|
// Instead, redirect them to the 2FA authentication page.
|
||||||
|
_, err = models.GetTwoFactorByUID(u.ID)
|
||||||
|
if err != nil {
|
||||||
|
if models.IsErrTwoFactorNotEnrolled(err) {
|
||||||
|
ctx.Session.Set("uid", u.ID)
|
||||||
|
ctx.Session.Set("uname", u.Name)
|
||||||
|
|
||||||
|
// Clear whatever CSRF has right now, force to generate a new one
|
||||||
|
ctx.SetCookie(setting.CSRFCookieName, "", -1, setting.AppSubURL)
|
||||||
|
|
||||||
|
// Register last login
|
||||||
|
u.SetLastLogin()
|
||||||
|
if err := models.UpdateUser(u); err != nil {
|
||||||
|
ctx.Handle(500, "UpdateUser", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if redirectTo, _ := url.QueryUnescape(ctx.GetCookie("redirect_to")); len(redirectTo) > 0 {
|
||||||
|
ctx.SetCookie("redirect_to", "", -1, setting.AppSubURL)
|
||||||
|
ctx.Redirect(redirectTo)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Redirect(setting.AppSubURL + "/")
|
||||||
|
} else {
|
||||||
|
ctx.Handle(500, "UserSignIn", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// User needs to use 2FA, save data and redirect to 2FA page.
|
||||||
|
ctx.Session.Set("twofaUid", u.ID)
|
||||||
|
ctx.Session.Set("twofaRemember", false)
|
||||||
|
ctx.Redirect(setting.AppSubURL + "/user/two_factor")
|
||||||
|
}
|
||||||
|
|
||||||
|
// OAuth2UserLoginCallback attempts to handle the callback from the OAuth2 provider and if successful
|
||||||
|
// login the user
|
||||||
|
func oAuth2UserLoginCallback(loginSource *models.LoginSource, request *http.Request, response http.ResponseWriter) (*models.User, goth.User, error) {
|
||||||
|
gothUser, err := oauth2.ProviderCallback(loginSource.Name, request, response)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, goth.User{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
user := &models.User{
|
||||||
|
LoginName: gothUser.UserID,
|
||||||
|
LoginType: models.LoginOAuth2,
|
||||||
|
LoginSource: loginSource.ID,
|
||||||
|
}
|
||||||
|
|
||||||
|
hasUser, err := models.GetUser(user)
|
||||||
|
if err != nil {
|
||||||
|
return nil, goth.User{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasUser {
|
||||||
|
return user, goth.User{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// search in external linked users
|
||||||
|
externalLoginUser := &models.ExternalLoginUser{
|
||||||
|
ExternalID: gothUser.UserID,
|
||||||
|
LoginSourceID: loginSource.ID,
|
||||||
|
}
|
||||||
|
hasUser, err = models.GetExternalLogin(externalLoginUser)
|
||||||
|
if err != nil {
|
||||||
|
return nil, goth.User{}, err
|
||||||
|
}
|
||||||
|
if hasUser {
|
||||||
|
user, err = models.GetUserByID(externalLoginUser.UserID)
|
||||||
|
return user, goth.User{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// no user found to login
|
||||||
|
return nil, gothUser, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// LinkAccount shows the page where the user can decide to login or create a new account
|
||||||
|
func LinkAccount(ctx *context.Context) {
|
||||||
|
ctx.Data["Title"] = ctx.Tr("link_account")
|
||||||
|
ctx.Data["LinkAccountMode"] = true
|
||||||
|
ctx.Data["EnableCaptcha"] = setting.Service.EnableCaptcha
|
||||||
|
ctx.Data["DisableRegistration"] = setting.Service.DisableRegistration
|
||||||
|
ctx.Data["ShowRegistrationButton"] = false
|
||||||
|
|
||||||
|
// use this to set the right link into the signIn and signUp templates in the link_account template
|
||||||
|
ctx.Data["SignInLink"] = setting.AppSubURL + "/user/link_account_signin"
|
||||||
|
ctx.Data["SignUpLink"] = setting.AppSubURL + "/user/link_account_signup"
|
||||||
|
|
||||||
|
gothUser := ctx.Session.Get("linkAccountGothUser")
|
||||||
|
if gothUser == nil {
|
||||||
|
ctx.Handle(500, "UserSignIn", errors.New("not in LinkAccount session"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Data["user_name"] = gothUser.(goth.User).NickName
|
||||||
|
ctx.Data["email"] = gothUser.(goth.User).Email
|
||||||
|
|
||||||
|
ctx.HTML(200, tplLinkAccount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LinkAccountPostSignIn handle the coupling of external account with another account using signIn
|
||||||
|
func LinkAccountPostSignIn(ctx *context.Context, signInForm auth.SignInForm) {
|
||||||
|
ctx.Data["Title"] = ctx.Tr("link_account")
|
||||||
|
ctx.Data["LinkAccountMode"] = true
|
||||||
|
ctx.Data["LinkAccountModeSignIn"] = true
|
||||||
|
ctx.Data["EnableCaptcha"] = setting.Service.EnableCaptcha
|
||||||
|
ctx.Data["DisableRegistration"] = setting.Service.DisableRegistration
|
||||||
|
ctx.Data["ShowRegistrationButton"] = false
|
||||||
|
|
||||||
|
// use this to set the right link into the signIn and signUp templates in the link_account template
|
||||||
|
ctx.Data["SignInLink"] = setting.AppSubURL + "/user/link_account_signin"
|
||||||
|
ctx.Data["SignUpLink"] = setting.AppSubURL + "/user/link_account_signup"
|
||||||
|
|
||||||
|
gothUser := ctx.Session.Get("linkAccountGothUser")
|
||||||
|
if gothUser == nil {
|
||||||
|
ctx.Handle(500, "UserSignIn", errors.New("not in LinkAccount session"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.HasError() {
|
||||||
|
ctx.HTML(200, tplLinkAccount)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := models.UserSignIn(signInForm.UserName, signInForm.Password)
|
||||||
|
if err != nil {
|
||||||
|
if models.IsErrUserNotExist(err) {
|
||||||
|
ctx.RenderWithErr(ctx.Tr("form.username_password_incorrect"), tplLinkAccount, &signInForm)
|
||||||
|
} else {
|
||||||
|
ctx.Handle(500, "UserLinkAccount", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If this user is enrolled in 2FA, we can't sign the user in just yet.
|
||||||
|
// Instead, redirect them to the 2FA authentication page.
|
||||||
|
_, err = models.GetTwoFactorByUID(u.ID)
|
||||||
|
if err != nil {
|
||||||
|
if models.IsErrTwoFactorNotEnrolled(err) {
|
||||||
|
models.LinkAccountToUser(u, gothUser.(goth.User))
|
||||||
|
handleSignIn(ctx, u, signInForm.Remember)
|
||||||
|
} else {
|
||||||
|
ctx.Handle(500, "UserLinkAccount", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// User needs to use 2FA, save data and redirect to 2FA page.
|
||||||
|
ctx.Session.Set("twofaUid", u.ID)
|
||||||
|
ctx.Session.Set("twofaRemember", signInForm.Remember)
|
||||||
|
ctx.Session.Set("linkAccount", true)
|
||||||
|
|
||||||
|
ctx.Redirect(setting.AppSubURL + "/user/two_factor")
|
||||||
|
}
|
||||||
|
|
||||||
|
// LinkAccountPostRegister handle the creation of a new account for an external account using signUp
|
||||||
|
func LinkAccountPostRegister(ctx *context.Context, cpt *captcha.Captcha, form auth.RegisterForm) {
|
||||||
|
ctx.Data["Title"] = ctx.Tr("link_account")
|
||||||
|
ctx.Data["LinkAccountMode"] = true
|
||||||
|
ctx.Data["LinkAccountModeRegister"] = true
|
||||||
|
ctx.Data["EnableCaptcha"] = setting.Service.EnableCaptcha
|
||||||
|
ctx.Data["DisableRegistration"] = setting.Service.DisableRegistration
|
||||||
|
ctx.Data["ShowRegistrationButton"] = false
|
||||||
|
|
||||||
|
// use this to set the right link into the signIn and signUp templates in the link_account template
|
||||||
|
ctx.Data["SignInLink"] = setting.AppSubURL + "/user/link_account_signin"
|
||||||
|
ctx.Data["SignUpLink"] = setting.AppSubURL + "/user/link_account_signup"
|
||||||
|
|
||||||
|
gothUser := ctx.Session.Get("linkAccountGothUser")
|
||||||
|
if gothUser == nil {
|
||||||
|
ctx.Handle(500, "UserSignUp", errors.New("not in LinkAccount session"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx.HasError() {
|
||||||
|
ctx.HTML(200, tplLinkAccount)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if setting.Service.DisableRegistration {
|
||||||
|
ctx.Error(403)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if setting.Service.EnableCaptcha && !cpt.VerifyReq(ctx.Req) {
|
||||||
|
ctx.Data["Err_Captcha"] = true
|
||||||
|
ctx.RenderWithErr(ctx.Tr("form.captcha_incorrect"), tplLinkAccount, &form)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if (len(strings.TrimSpace(form.Password)) > 0 || len(strings.TrimSpace(form.Retype)) > 0) && form.Password != form.Retype {
|
||||||
|
ctx.Data["Err_Password"] = true
|
||||||
|
ctx.RenderWithErr(ctx.Tr("form.password_not_match"), tplLinkAccount, &form)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(strings.TrimSpace(form.Password)) > 0 && len(form.Password) < setting.MinPasswordLength {
|
||||||
|
ctx.Data["Err_Password"] = true
|
||||||
|
ctx.RenderWithErr(ctx.Tr("auth.password_too_short", setting.MinPasswordLength), tplLinkAccount, &form)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
loginSource, err := models.GetActiveOAuth2LoginSourceByName(gothUser.(goth.User).Provider)
|
||||||
|
if err != nil {
|
||||||
|
ctx.Handle(500, "CreateUser", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
u := &models.User{
|
||||||
|
Name: form.UserName,
|
||||||
|
Email: form.Email,
|
||||||
|
Passwd: form.Password,
|
||||||
|
IsActive: !setting.Service.RegisterEmailConfirm,
|
||||||
|
LoginType: models.LoginOAuth2,
|
||||||
|
LoginSource: loginSource.ID,
|
||||||
|
LoginName: gothUser.(goth.User).UserID,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := models.CreateUser(u); err != nil {
|
||||||
|
switch {
|
||||||
|
case models.IsErrUserAlreadyExist(err):
|
||||||
|
ctx.Data["Err_UserName"] = true
|
||||||
|
ctx.RenderWithErr(ctx.Tr("form.username_been_taken"), tplLinkAccount, &form)
|
||||||
|
case models.IsErrEmailAlreadyUsed(err):
|
||||||
|
ctx.Data["Err_Email"] = true
|
||||||
|
ctx.RenderWithErr(ctx.Tr("form.email_been_used"), tplLinkAccount, &form)
|
||||||
|
case models.IsErrNameReserved(err):
|
||||||
|
ctx.Data["Err_UserName"] = true
|
||||||
|
ctx.RenderWithErr(ctx.Tr("user.form.name_reserved", err.(models.ErrNameReserved).Name), tplLinkAccount, &form)
|
||||||
|
case models.IsErrNamePatternNotAllowed(err):
|
||||||
|
ctx.Data["Err_UserName"] = true
|
||||||
|
ctx.RenderWithErr(ctx.Tr("user.form.name_pattern_not_allowed", err.(models.ErrNamePatternNotAllowed).Pattern), tplLinkAccount, &form)
|
||||||
|
default:
|
||||||
|
ctx.Handle(500, "CreateUser", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Trace("Account created: %s", u.Name)
|
||||||
|
|
||||||
|
// Auto-set admin for the only user.
|
||||||
|
if models.CountUsers() == 1 {
|
||||||
|
u.IsAdmin = true
|
||||||
|
u.IsActive = true
|
||||||
|
if err := models.UpdateUser(u); err != nil {
|
||||||
|
ctx.Handle(500, "UpdateUser", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send confirmation email
|
||||||
|
if setting.Service.RegisterEmailConfirm && u.ID > 1 {
|
||||||
|
models.SendActivateAccountMail(ctx.Context, u)
|
||||||
|
ctx.Data["IsSendRegisterMail"] = true
|
||||||
|
ctx.Data["Email"] = u.Email
|
||||||
|
ctx.Data["Hours"] = setting.Service.ActiveCodeLives / 60
|
||||||
|
ctx.HTML(200, TplActivate)
|
||||||
|
|
||||||
|
if err := ctx.Cache.Put("MailResendLimit_"+u.LowerName, u.LowerName, 180); err != nil {
|
||||||
|
log.Error(4, "Set cache(MailResendLimit) fail: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.Redirect(setting.AppSubURL + "/user/login")
|
||||||
|
}
|
||||||
|
|
||||||
// SignOut sign out from login status
|
// SignOut sign out from login status
|
||||||
func SignOut(ctx *context.Context) {
|
func SignOut(ctx *context.Context) {
|
||||||
ctx.Session.Delete("uid")
|
ctx.Session.Delete("uid")
|
||||||
|
@ -328,11 +674,7 @@ func SignUp(ctx *context.Context) {
|
||||||
|
|
||||||
ctx.Data["EnableCaptcha"] = setting.Service.EnableCaptcha
|
ctx.Data["EnableCaptcha"] = setting.Service.EnableCaptcha
|
||||||
|
|
||||||
if setting.Service.DisableRegistration {
|
ctx.Data["DisableRegistration"] = setting.Service.DisableRegistration
|
||||||
ctx.Data["DisableRegistration"] = true
|
|
||||||
ctx.HTML(200, tplSignUp)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx.HTML(200, tplSignUp)
|
ctx.HTML(200, tplSignUp)
|
||||||
}
|
}
|
||||||
|
@ -540,7 +882,7 @@ func ForgotPasswdPost(ctx *context.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !u.IsLocal() {
|
if !u.IsLocal() && !u.IsOAuth2() {
|
||||||
ctx.Data["Err_Email"] = true
|
ctx.Data["Err_Email"] = true
|
||||||
ctx.RenderWithErr(ctx.Tr("auth.non_local_account"), tplForgotPassword, nil)
|
ctx.RenderWithErr(ctx.Tr("auth.non_local_account"), tplForgotPassword, nil)
|
||||||
return
|
return
|
||||||
|
|
|
@ -37,6 +37,7 @@ const (
|
||||||
tplSettingsApplications base.TplName = "user/settings/applications"
|
tplSettingsApplications base.TplName = "user/settings/applications"
|
||||||
tplSettingsTwofa base.TplName = "user/settings/twofa"
|
tplSettingsTwofa base.TplName = "user/settings/twofa"
|
||||||
tplSettingsTwofaEnroll base.TplName = "user/settings/twofa_enroll"
|
tplSettingsTwofaEnroll base.TplName = "user/settings/twofa_enroll"
|
||||||
|
tplSettingsAccountLink base.TplName = "user/settings/account_link"
|
||||||
tplSettingsDelete base.TplName = "user/settings/delete"
|
tplSettingsDelete base.TplName = "user/settings/delete"
|
||||||
tplSecurity base.TplName = "user/security"
|
tplSecurity base.TplName = "user/security"
|
||||||
)
|
)
|
||||||
|
@ -200,7 +201,7 @@ func SettingsPasswordPost(ctx *context.Context, form auth.ChangePasswordForm) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if !ctx.User.ValidatePassword(form.OldPassword) {
|
if ctx.User.IsPasswordSet() && !ctx.User.ValidatePassword(form.OldPassword) {
|
||||||
ctx.Flash.Error(ctx.Tr("settings.password_incorrect"))
|
ctx.Flash.Error(ctx.Tr("settings.password_incorrect"))
|
||||||
} else if form.Password != form.Retype {
|
} else if form.Password != form.Retype {
|
||||||
ctx.Flash.Error(ctx.Tr("form.password_not_match"))
|
ctx.Flash.Error(ctx.Tr("form.password_not_match"))
|
||||||
|
@ -631,6 +632,49 @@ func SettingsTwoFactorEnrollPost(ctx *context.Context, form auth.TwoFactorAuthFo
|
||||||
ctx.Redirect(setting.AppSubURL + "/user/settings/two_factor")
|
ctx.Redirect(setting.AppSubURL + "/user/settings/two_factor")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SettingsAccountLinks render the account links settings page
|
||||||
|
func SettingsAccountLinks(ctx *context.Context) {
|
||||||
|
ctx.Data["Title"] = ctx.Tr("settings")
|
||||||
|
ctx.Data["PageIsSettingsAccountLink"] = true
|
||||||
|
|
||||||
|
accountLinks, err := models.ListAccountLinks(ctx.User)
|
||||||
|
if err != nil {
|
||||||
|
ctx.Handle(500, "ListAccountLinks", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// map the provider display name with the LoginSource
|
||||||
|
sources := make(map[*models.LoginSource]string)
|
||||||
|
for _, externalAccount := range accountLinks {
|
||||||
|
if loginSource, err := models.GetLoginSourceByID(externalAccount.LoginSourceID); err == nil {
|
||||||
|
var providerDisplayName string
|
||||||
|
if loginSource.IsOAuth2() {
|
||||||
|
providerTechnicalName := loginSource.OAuth2().Provider
|
||||||
|
providerDisplayName = models.OAuth2Providers[providerTechnicalName].DisplayName
|
||||||
|
} else {
|
||||||
|
providerDisplayName = loginSource.Name
|
||||||
|
}
|
||||||
|
sources[loginSource] = providerDisplayName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ctx.Data["AccountLinks"] = sources
|
||||||
|
|
||||||
|
ctx.HTML(200, tplSettingsAccountLink)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SettingsDeleteAccountLink delete a single account link
|
||||||
|
func SettingsDeleteAccountLink(ctx *context.Context) {
|
||||||
|
if _, err := models.RemoveAccountLink(ctx.User, ctx.QueryInt64("loginSourceID")); err != nil {
|
||||||
|
ctx.Flash.Error("RemoveAccountLink: " + err.Error())
|
||||||
|
} else {
|
||||||
|
ctx.Flash.Success(ctx.Tr("settings.remove_account_link_success"))
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx.JSON(200, map[string]interface{}{
|
||||||
|
"redirect": setting.AppSubURL + "/user/settings/account_link",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// SettingsDelete render user suicide page and response for delete user himself
|
// SettingsDelete render user suicide page and response for delete user himself
|
||||||
func SettingsDelete(ctx *context.Context) {
|
func SettingsDelete(ctx *context.Context) {
|
||||||
ctx.Data["Title"] = ctx.Tr("settings")
|
ctx.Data["Title"] = ctx.Tr("settings")
|
||||||
|
|
|
@ -142,6 +142,32 @@
|
||||||
</div>
|
</div>
|
||||||
{{end}}
|
{{end}}
|
||||||
|
|
||||||
|
<!-- OAuth2 -->
|
||||||
|
{{if .Source.IsOAuth2}}
|
||||||
|
{{ $cfg:=.Source.OAuth2 }}
|
||||||
|
<div class="inline required field">
|
||||||
|
<label>{{.i18n.Tr "admin.auths.oauth2_provider"}}</label>
|
||||||
|
<div class="ui selection type dropdown">
|
||||||
|
<input type="hidden" id="oauth2_provider" name="oauth2_provider" value="{{$cfg.Provider}}" required>
|
||||||
|
<div class="text">{{.CurrentOAuth2Provider.DisplayName}}</div>
|
||||||
|
<i class="dropdown icon"></i>
|
||||||
|
<div class="menu">
|
||||||
|
{{range $key, $value := .OAuth2Providers}}
|
||||||
|
<div class="item" data-value="{{$key}}">{{$value.DisplayName}}</div>
|
||||||
|
{{end}}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class="required field">
|
||||||
|
<label for="oauth2_key">{{.i18n.Tr "admin.auths.oauth2_clientID"}}</label>
|
||||||
|
<input id="oauth2_key" name="oauth2_key" value="{{$cfg.ClientID}}" required>
|
||||||
|
</div>
|
||||||
|
<div class="required field">
|
||||||
|
<label for="oauth2_secret">{{.i18n.Tr "admin.auths.oauth2_clientSecret"}}</label>
|
||||||
|
<input id="oauth2_secret" name="oauth2_secret" value="{{$cfg.ClientSecret}}" required>
|
||||||
|
</div>
|
||||||
|
{{end}}
|
||||||
|
|
||||||
<div class="inline field {{if not .Source.IsSMTP}}hide{{end}}">
|
<div class="inline field {{if not .Source.IsSMTP}}hide{{end}}">
|
||||||
<div class="ui checkbox">
|
<div class="ui checkbox">
|
||||||
<label><strong>{{.i18n.Tr "admin.auths.enable_tls"}}</strong></label>
|
<label><strong>{{.i18n.Tr "admin.auths.enable_tls"}}</strong></label>
|
||||||
|
|
|
@ -133,6 +133,31 @@
|
||||||
<input id="pam_service_name" name="pam_service_name" value="{{.pam_service_name}}" />
|
<input id="pam_service_name" name="pam_service_name" value="{{.pam_service_name}}" />
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
<!-- OAuth2 -->
|
||||||
|
<div class="oauth2 field {{if not (eq .type 6)}}hide{{end}}">
|
||||||
|
<div class="inline required field">
|
||||||
|
<label>{{.i18n.Tr "admin.auths.oauth2_provider"}}</label>
|
||||||
|
<div class="ui selection type dropdown">
|
||||||
|
<input type="hidden" id="oauth2_provider" name="oauth2_provider" value="{{.oauth2_provider}}">
|
||||||
|
<div class="text">{{.oauth2_provider}}</div>
|
||||||
|
<i class="dropdown icon"></i>
|
||||||
|
<div class="menu">
|
||||||
|
{{range $key, $value := .OAuth2Providers}}
|
||||||
|
<div class="item" data-value="{{$key}}">{{$value.DisplayName}}</div>
|
||||||
|
{{end}}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div class="required field">
|
||||||
|
<label for="oauth2_key">{{.i18n.Tr "admin.auths.oauth2_clientID"}}</label>
|
||||||
|
<input id="oauth2_key" name="oauth2_key" value="{{.oauth2_key}}">
|
||||||
|
</div>
|
||||||
|
<div class="required field">
|
||||||
|
<label for="oauth2_secret">{{.i18n.Tr "admin.auths.oauth2_clientSecret"}}</label>
|
||||||
|
<input id="oauth2_secret" name="oauth2_secret" value="{{.oauth2_secret}}">
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
<div class="ldap field">
|
<div class="ldap field">
|
||||||
<div class="ui checkbox">
|
<div class="ui checkbox">
|
||||||
<label><strong>{{.i18n.Tr "admin.auths.attributes_in_bind"}}</strong></label>
|
<label><strong>{{.i18n.Tr "admin.auths.attributes_in_bind"}}</strong></label>
|
||||||
|
@ -170,6 +195,8 @@
|
||||||
<div class="ui attached segment">
|
<div class="ui attached segment">
|
||||||
<h5>GMail Settings:</h5>
|
<h5>GMail Settings:</h5>
|
||||||
<p>Host: smtp.gmail.com, Port: 587, Enable TLS Encryption: true</p>
|
<p>Host: smtp.gmail.com, Port: 587, Enable TLS Encryption: true</p>
|
||||||
|
<h5>OAuth GitHub:</h5>
|
||||||
|
<p>{{.i18n.Tr "admin.auths.tip.github"}}</p>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
|
|
@ -43,7 +43,7 @@
|
||||||
<input id="email" name="email" type="email" value="{{.User.Email}}" autofocus required>
|
<input id="email" name="email" type="email" value="{{.User.Email}}" autofocus required>
|
||||||
</div>
|
</div>
|
||||||
<input class="fake" type="password">
|
<input class="fake" type="password">
|
||||||
<div class="local field {{if .Err_Password}}error{{end}} {{if not (eq .User.LoginSource 0)}}hide{{end}}">
|
<div class="local field {{if .Err_Password}}error{{end}} {{if not (or (.User.IsLocal) (.User.IsOAuth2))}}hide{{end}}">
|
||||||
<label for="password">{{.i18n.Tr "password"}}</label>
|
<label for="password">{{.i18n.Tr "password"}}</label>
|
||||||
<input id="password" name="password" type="password">
|
<input id="password" name="password" type="password">
|
||||||
<p class="help">{{.i18n.Tr "admin.users.password_helper"}}</p>
|
<p class="help">{{.i18n.Tr "admin.users.password_helper"}}</p>
|
||||||
|
|
|
@ -0,0 +1,13 @@
|
||||||
|
{{template "base/head" .}}
|
||||||
|
<div class="user link-account">
|
||||||
|
<div class="ui middle very relaxed page grid">
|
||||||
|
<div class="column">
|
||||||
|
<p class="large center">
|
||||||
|
{{.i18n.Tr "link_account_signin_or_signup"}}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{{template "user/auth/signin_inner" .}}
|
||||||
|
{{template "user/auth/signup_inner" .}}
|
||||||
|
{{template "base/footer" .}}
|
|
@ -1,44 +1,3 @@
|
||||||
{{template "base/head" .}}
|
{{template "base/head" .}}
|
||||||
<div class="user signin">
|
{{template "user/auth/signin_inner" .}}
|
||||||
<div class="ui middle very relaxed page grid">
|
|
||||||
<div class="column">
|
|
||||||
<form class="ui form" action="{{.Link}}" method="post">
|
|
||||||
{{.CsrfTokenHtml}}
|
|
||||||
<h3 class="ui top attached header">
|
|
||||||
{{.i18n.Tr "sign_in"}}
|
|
||||||
</h3>
|
|
||||||
<div class="ui attached segment">
|
|
||||||
{{template "base/alert" .}}
|
|
||||||
<div class="required inline field {{if .Err_UserName}}error{{end}}">
|
|
||||||
<label for="user_name">{{.i18n.Tr "home.uname_holder"}}</label>
|
|
||||||
<input id="user_name" name="user_name" value="{{.user_name}}" autofocus required>
|
|
||||||
</div>
|
|
||||||
<div class="required inline field {{if .Err_Password}}error{{end}}">
|
|
||||||
<label for="password">{{.i18n.Tr "password"}}</label>
|
|
||||||
<input id="password" name="password" type="password" value="{{.password}}" autocomplete="off" required>
|
|
||||||
</div>
|
|
||||||
<div class="inline field">
|
|
||||||
<label></label>
|
|
||||||
<div class="ui checkbox">
|
|
||||||
<label>{{.i18n.Tr "auth.remember_me"}}</label>
|
|
||||||
<input name="remember" type="checkbox">
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div class="inline field">
|
|
||||||
<label></label>
|
|
||||||
<button class="ui green button">{{.i18n.Tr "sign_in"}}</button>
|
|
||||||
<a href="{{AppSubUrl}}/user/forget_password">{{.i18n.Tr "auth.forget_password"}}</a>
|
|
||||||
</div>
|
|
||||||
{{if .ShowRegistrationButton}}
|
|
||||||
<div class="inline field">
|
|
||||||
<label></label>
|
|
||||||
<a href="{{AppSubUrl}}/user/sign_up">{{.i18n.Tr "auth.sign_up_now" | Str2html}}</a>
|
|
||||||
</div>
|
|
||||||
{{end}}
|
|
||||||
</div>
|
|
||||||
</form>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
{{template "base/footer" .}}
|
{{template "base/footer" .}}
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
<div class="user signin{{if .LinkAccountMode}} icon{{end}}">
|
||||||
|
<div class="ui middle very relaxed page grid">
|
||||||
|
<div class="column">
|
||||||
|
<form class="ui form" action="{{if not .LinkAccountMode}}{{.Link}}{{else}}{{.SignInLink}}{{end}}" method="post">
|
||||||
|
{{.CsrfTokenHtml}}
|
||||||
|
<h3 class="ui top attached header">
|
||||||
|
{{.i18n.Tr "sign_in"}}
|
||||||
|
</h3>
|
||||||
|
<div class="ui attached segment">
|
||||||
|
{{if or (not .LinkAccountMode) (and .LinkAccountMode .LinkAccountModeSignIn)}}
|
||||||
|
{{template "base/alert" .}}
|
||||||
|
{{end}}
|
||||||
|
<div class="required inline field {{if and (.Err_UserName) (or (not .LinkAccountMode) (and .LinkAccountMode .LinkAccountModeSignIn))}}error{{end}}">
|
||||||
|
<label for="user_name">{{.i18n.Tr "home.uname_holder"}}</label>
|
||||||
|
<input id="user_name" name="user_name" value="{{.user_name}}" autofocus required>
|
||||||
|
</div>
|
||||||
|
<div class="required inline field {{if and (.Err_Password) (or (not .LinkAccountMode) (and .LinkAccountMode .LinkAccountModeSignIn))}}error{{end}}">
|
||||||
|
<label for="password">{{.i18n.Tr "password"}}</label>
|
||||||
|
<input id="password" name="password" type="password" value="{{.password}}" autocomplete="off" required>
|
||||||
|
</div>
|
||||||
|
{{if not .LinkAccountMode}}
|
||||||
|
<div class="inline field">
|
||||||
|
<label></label>
|
||||||
|
<div class="ui checkbox">
|
||||||
|
<label>{{.i18n.Tr "auth.remember_me"}}</label>
|
||||||
|
<input name="remember" type="checkbox">
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
<div class="inline field">
|
||||||
|
<label></label>
|
||||||
|
<button class="ui green button">{{.i18n.Tr "sign_in"}}</button>
|
||||||
|
<a href="{{AppSubUrl}}/user/forget_password">{{.i18n.Tr "auth.forget_password"}}</a>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{{if .ShowRegistrationButton}}
|
||||||
|
<div class="inline field">
|
||||||
|
<label></label>
|
||||||
|
<a href="{{AppSubUrl}}/user/sign_up">{{.i18n.Tr "auth.sign_up_now" | Str2html}}</a>
|
||||||
|
</div>
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{if .OAuth2Providers}}
|
||||||
|
<div class="ui attached segment">
|
||||||
|
<div class="oauth2 center">
|
||||||
|
<div>
|
||||||
|
<p>{{.i18n.Tr "sign_in_with"}}</p>{{range $key, $value := .OAuth2Providers}}<a href="{{AppSubUrl}}/user/oauth2/{{$key}}"><img alt="{{$value.DisplayName}}" title="{{$value.DisplayName}}" src="{{AppSubUrl}}{{$value.Image}}"></a>{{end}}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{{end}}
|
||||||
|
</div>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
|
@ -1,56 +1,3 @@
|
||||||
{{template "base/head" .}}
|
{{template "base/head" .}}
|
||||||
<div class="user signup">
|
{{template "user/auth/signup_inner" .}}
|
||||||
<div class="ui middle very relaxed page grid">
|
|
||||||
<div class="column">
|
|
||||||
<form class="ui form" action="{{.Link}}" method="post">
|
|
||||||
{{.CsrfTokenHtml}}
|
|
||||||
<h3 class="ui top attached header">
|
|
||||||
{{if .IsSocialLogin}}{{.i18n.Tr "social_sign_in" | Str2html}}{{else}}{{.i18n.Tr "sign_up"}}{{end}}
|
|
||||||
</h3>
|
|
||||||
<div class="ui attached segment">
|
|
||||||
{{template "base/alert" .}}
|
|
||||||
{{if .DisableRegistration}}
|
|
||||||
<p>{{.i18n.Tr "auth.disable_register_prompt"}}</p>
|
|
||||||
{{else}}
|
|
||||||
<div class="required inline field {{if .Err_UserName}}error{{end}}">
|
|
||||||
<label for="user_name">{{.i18n.Tr "username"}}</label>
|
|
||||||
<input id="user_name" name="user_name" value="{{.user_name}}" autofocus required>
|
|
||||||
</div>
|
|
||||||
<div class="required inline field {{if .Err_Email}}error{{end}}">
|
|
||||||
<label for="email">{{.i18n.Tr "email"}}</label>
|
|
||||||
<input id="email" name="email" type="email" value="{{.email}}" required>
|
|
||||||
</div>
|
|
||||||
<div class="required inline field {{if .Err_Password}}error{{end}}">
|
|
||||||
<label for="password">{{.i18n.Tr "password"}}</label>
|
|
||||||
<input id="password" name="password" type="password" value="{{.password}}" autocomplete="off" required>
|
|
||||||
</div>
|
|
||||||
<div class="required inline field {{if .Err_Password}}error{{end}}">
|
|
||||||
<label for="retype">{{.i18n.Tr "re_type"}}</label>
|
|
||||||
<input id="retype" name="retype" type="password" value="{{.retype}}" autocomplete="off" required>
|
|
||||||
</div>
|
|
||||||
{{if .EnableCaptcha}}
|
|
||||||
<div class="inline field">
|
|
||||||
<label></label>
|
|
||||||
{{.Captcha.CreateHtml}}
|
|
||||||
</div>
|
|
||||||
<div class="required inline field {{if .Err_Captcha}}error{{end}}">
|
|
||||||
<label for="captcha">{{.i18n.Tr "captcha"}}</label>
|
|
||||||
<input id="captcha" name="captcha" value="{{.captcha}}" autocomplete="off">
|
|
||||||
</div>
|
|
||||||
{{end}}
|
|
||||||
|
|
||||||
<div class="inline field">
|
|
||||||
<label></label>
|
|
||||||
<button class="ui green button">{{.i18n.Tr "auth.create_new_account"}}</button>
|
|
||||||
</div>
|
|
||||||
<div class="inline field">
|
|
||||||
<label></label>
|
|
||||||
<a href="{{AppSubUrl}}/user/login">{{if .IsSocialLogin}}{{.i18n.Tr "auth.social_register_helper_msg"}}{{else}}{{.i18n.Tr "auth.register_helper_msg"}}{{end}}</a>
|
|
||||||
</div>
|
|
||||||
{{end}}
|
|
||||||
</div>
|
|
||||||
</form>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
{{template "base/footer" .}}
|
{{template "base/footer" .}}
|
||||||
|
|
|
@ -0,0 +1,59 @@
|
||||||
|
<div class="user signup{{if .LinkAccountMode}} icon{{end}}">
|
||||||
|
<div class="ui middle very relaxed page grid">
|
||||||
|
<div class="column">
|
||||||
|
<form class="ui form" action="{{if not .LinkAccountMode}}{{.Link}}{{else}}{{.SignUpLink}}{{end}}" method="post">
|
||||||
|
{{.CsrfTokenHtml}}
|
||||||
|
<h3 class="ui top attached header">
|
||||||
|
{{.i18n.Tr "sign_up"}}
|
||||||
|
</h3>
|
||||||
|
<div class="ui attached segment">
|
||||||
|
{{if or (not .LinkAccountMode) (and .LinkAccountMode .LinkAccountModeRegister)}}
|
||||||
|
{{template "base/alert" .}}
|
||||||
|
{{end}}
|
||||||
|
{{if .DisableRegistration}}
|
||||||
|
<p>{{.i18n.Tr "auth.disable_register_prompt"}}</p>
|
||||||
|
{{else}}
|
||||||
|
<div class="required inline field {{if and (.Err_UserName) (or (not .LinkAccountMode) (and .LinkAccountMode .LinkAccountModeRegister))}}error{{end}}">
|
||||||
|
<label for="user_name">{{.i18n.Tr "username"}}</label>
|
||||||
|
<input id="user_name" name="user_name" value="{{.user_name}}" autofocus required>
|
||||||
|
</div>
|
||||||
|
<div class="required inline field {{if .Err_Email}}error{{end}}">
|
||||||
|
<label for="email">{{.i18n.Tr "email"}}</label>
|
||||||
|
<input id="email" name="email" type="email" value="{{.email}}" required>
|
||||||
|
</div>
|
||||||
|
<div class="required inline field {{if and (.Err_Password) (or (not .LinkAccountMode) (and .LinkAccountMode .LinkAccountModeRegister))}}error{{end}}">
|
||||||
|
<label for="password">{{.i18n.Tr "password"}}</label>
|
||||||
|
<input id="password" name="password" type="password" value="{{.password}}" autocomplete="off" required>
|
||||||
|
</div>
|
||||||
|
<div class="required inline field {{if and (.Err_Password) (or (not .LinkAccountMode) (and .LinkAccountMode .LinkAccountModeRegister))}}error{{end}}">
|
||||||
|
<label for="retype">{{.i18n.Tr "re_type"}}</label>
|
||||||
|
<input id="retype" name="retype" type="password" value="{{.retype}}" autocomplete="off" required>
|
||||||
|
</div>
|
||||||
|
{{if .EnableCaptcha}}
|
||||||
|
<div class="inline field">
|
||||||
|
<label></label>
|
||||||
|
{{.Captcha.CreateHtml}}
|
||||||
|
</div>
|
||||||
|
<div class="required inline field {{if .Err_Captcha}}error{{end}}">
|
||||||
|
<label for="captcha">{{.i18n.Tr "captcha"}}</label>
|
||||||
|
<input id="captcha" name="captcha" value="{{.captcha}}" autocomplete="off">
|
||||||
|
</div>
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
<div class="inline field">
|
||||||
|
<label></label>
|
||||||
|
<button class="ui green button">{{.i18n.Tr "auth.create_new_account"}}</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{{if not .LinkAccountMode}}
|
||||||
|
<div class="inline field">
|
||||||
|
<label></label>
|
||||||
|
<a href="{{AppSubUrl}}/user/login">{{.i18n.Tr "auth.register_helper_msg"}}</a>
|
||||||
|
</div>
|
||||||
|
{{end}}
|
||||||
|
{{end}}
|
||||||
|
</div>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
|
@ -0,0 +1,48 @@
|
||||||
|
{{template "base/head" .}}
|
||||||
|
<div class="user settings account_link">
|
||||||
|
<div class="ui container">
|
||||||
|
<div class="ui grid">
|
||||||
|
{{template "user/settings/navbar" .}}
|
||||||
|
<div class="twelve wide column content">
|
||||||
|
{{template "base/alert" .}}
|
||||||
|
<h4 class="ui top attached header">
|
||||||
|
{{.i18n.Tr "settings.manage_account_links"}}
|
||||||
|
</h4>
|
||||||
|
<div class="ui attached segment">
|
||||||
|
<div class="ui key list">
|
||||||
|
<div class="item">
|
||||||
|
{{.i18n.Tr "settings.manage_account_links_desc"}}
|
||||||
|
</div>
|
||||||
|
{{if .AccountLinks}}
|
||||||
|
{{range $loginSource, $provider := .AccountLinks}}
|
||||||
|
<div class="item ui grid">
|
||||||
|
<div class="column">
|
||||||
|
<strong>{{$provider}}</strong>
|
||||||
|
{{if $loginSource.IsActived}}<span class="text red">{{$.i18n.Tr "settings.active"}}</span>{{end}}
|
||||||
|
<div class="ui right">
|
||||||
|
<button class="ui red tiny button delete-button" data-url="{{$.Link}}" data-id="{{$loginSource.ID}}">
|
||||||
|
{{$.i18n.Tr "settings.delete_key"}}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{{end}}
|
||||||
|
{{end}}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="ui small basic delete modal">
|
||||||
|
<div class="ui icon header">
|
||||||
|
<i class="trash icon"></i>
|
||||||
|
{{.i18n.Tr "settings.remove_account_link"}}
|
||||||
|
</div>
|
||||||
|
<div class="content">
|
||||||
|
<p>{{.i18n.Tr "settings.remove_account_link_desc"}}</p>
|
||||||
|
</div>
|
||||||
|
{{template "base/delete_modal_actions" .}}
|
||||||
|
</div>
|
||||||
|
{{template "base/footer" .}}
|
|
@ -22,6 +22,9 @@
|
||||||
<a class="{{if .PageIsSettingsTwofa}}active{{end}} item" href="{{AppSubUrl}}/user/settings/two_factor">
|
<a class="{{if .PageIsSettingsTwofa}}active{{end}} item" href="{{AppSubUrl}}/user/settings/two_factor">
|
||||||
{{.i18n.Tr "settings.twofa"}}
|
{{.i18n.Tr "settings.twofa"}}
|
||||||
</a>
|
</a>
|
||||||
|
<a class="{{if .PageIsSettingsAccountLink}}active{{end}} item" href="{{AppSubUrl}}/user/settings/account_link">
|
||||||
|
{{.i18n.Tr "settings.account_link"}}
|
||||||
|
</a>
|
||||||
<a class="{{if .PageIsSettingsDelete}}active{{end}} item" href="{{AppSubUrl}}/user/settings/delete">
|
<a class="{{if .PageIsSettingsDelete}}active{{end}} item" href="{{AppSubUrl}}/user/settings/delete">
|
||||||
{{.i18n.Tr "settings.delete"}}
|
{{.i18n.Tr "settings.delete"}}
|
||||||
</a>
|
</a>
|
||||||
|
|
|
@ -9,13 +9,15 @@
|
||||||
{{.i18n.Tr "settings.change_password"}}
|
{{.i18n.Tr "settings.change_password"}}
|
||||||
</h4>
|
</h4>
|
||||||
<div class="ui attached segment">
|
<div class="ui attached segment">
|
||||||
{{if .SignedUser.IsLocal}}
|
{{if or (.SignedUser.IsLocal) (.SignedUser.IsOAuth2)}}
|
||||||
<form class="ui form" action="{{.Link}}" method="post">
|
<form class="ui form" action="{{.Link}}" method="post">
|
||||||
{{.CsrfTokenHtml}}
|
{{.CsrfTokenHtml}}
|
||||||
|
{{if .SignedUser.IsPasswordSet}}
|
||||||
<div class="required field {{if .Err_OldPassword}}error{{end}}">
|
<div class="required field {{if .Err_OldPassword}}error{{end}}">
|
||||||
<label for="old_password">{{.i18n.Tr "settings.old_password"}}</label>
|
<label for="old_password">{{.i18n.Tr "settings.old_password"}}</label>
|
||||||
<input id="old_password" name="old_password" type="password" autocomplete="off" autofocus required>
|
<input id="old_password" name="old_password" type="password" autocomplete="off" autofocus required>
|
||||||
</div>
|
</div>
|
||||||
|
{{end}}
|
||||||
<div class="required field {{if .Err_Password}}error{{end}}">
|
<div class="required field {{if .Err_Password}}error{{end}}">
|
||||||
<label for="password">{{.i18n.Tr "settings.new_password"}}</label>
|
<label for="password">{{.i18n.Tr "settings.new_password"}}</label>
|
||||||
<input id="password" name="password" type="password" autocomplete="off" required>
|
<input id="password" name="password" type="password" autocomplete="off" required>
|
||||||
|
|
|
@ -0,0 +1,27 @@
|
||||||
|
Copyright (c) 2012 Rodrigo Moraes. All rights reserved.
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions are
|
||||||
|
met:
|
||||||
|
|
||||||
|
* Redistributions of source code must retain the above copyright
|
||||||
|
notice, this list of conditions and the following disclaimer.
|
||||||
|
* Redistributions in binary form must reproduce the above
|
||||||
|
copyright notice, this list of conditions and the following disclaimer
|
||||||
|
in the documentation and/or other materials provided with the
|
||||||
|
distribution.
|
||||||
|
* Neither the name of Google Inc. nor the names of its
|
||||||
|
contributors may be used to endorse or promote products derived from
|
||||||
|
this software without specific prior written permission.
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||||
|
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||||
|
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||||
|
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||||
|
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||||
|
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||||
|
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||||
|
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||||
|
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@ -0,0 +1,10 @@
|
||||||
|
context
|
||||||
|
=======
|
||||||
|
[![Build Status](https://travis-ci.org/gorilla/context.png?branch=master)](https://travis-ci.org/gorilla/context)
|
||||||
|
|
||||||
|
gorilla/context is a general purpose registry for global request variables.
|
||||||
|
|
||||||
|
> Note: gorilla/context, having been born well before `context.Context` existed, does not play well
|
||||||
|
> with the shallow copying of the request that [`http.Request.WithContext`](https://golang.org/pkg/net/http/#Request.WithContext) (added to net/http Go 1.7 onwards) performs. You should either use *just* gorilla/context, or moving forward, the new `http.Request.Context()`.
|
||||||
|
|
||||||
|
Read the full documentation here: http://www.gorillatoolkit.org/pkg/context
|
|
@ -0,0 +1,143 @@
|
||||||
|
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package context
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
mutex sync.RWMutex
|
||||||
|
data = make(map[*http.Request]map[interface{}]interface{})
|
||||||
|
datat = make(map[*http.Request]int64)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Set stores a value for a given key in a given request.
|
||||||
|
func Set(r *http.Request, key, val interface{}) {
|
||||||
|
mutex.Lock()
|
||||||
|
if data[r] == nil {
|
||||||
|
data[r] = make(map[interface{}]interface{})
|
||||||
|
datat[r] = time.Now().Unix()
|
||||||
|
}
|
||||||
|
data[r][key] = val
|
||||||
|
mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns a value stored for a given key in a given request.
|
||||||
|
func Get(r *http.Request, key interface{}) interface{} {
|
||||||
|
mutex.RLock()
|
||||||
|
if ctx := data[r]; ctx != nil {
|
||||||
|
value := ctx[key]
|
||||||
|
mutex.RUnlock()
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
mutex.RUnlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOk returns stored value and presence state like multi-value return of map access.
|
||||||
|
func GetOk(r *http.Request, key interface{}) (interface{}, bool) {
|
||||||
|
mutex.RLock()
|
||||||
|
if _, ok := data[r]; ok {
|
||||||
|
value, ok := data[r][key]
|
||||||
|
mutex.RUnlock()
|
||||||
|
return value, ok
|
||||||
|
}
|
||||||
|
mutex.RUnlock()
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAll returns all stored values for the request as a map. Nil is returned for invalid requests.
|
||||||
|
func GetAll(r *http.Request) map[interface{}]interface{} {
|
||||||
|
mutex.RLock()
|
||||||
|
if context, ok := data[r]; ok {
|
||||||
|
result := make(map[interface{}]interface{}, len(context))
|
||||||
|
for k, v := range context {
|
||||||
|
result[k] = v
|
||||||
|
}
|
||||||
|
mutex.RUnlock()
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
mutex.RUnlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAllOk returns all stored values for the request as a map and a boolean value that indicates if
|
||||||
|
// the request was registered.
|
||||||
|
func GetAllOk(r *http.Request) (map[interface{}]interface{}, bool) {
|
||||||
|
mutex.RLock()
|
||||||
|
context, ok := data[r]
|
||||||
|
result := make(map[interface{}]interface{}, len(context))
|
||||||
|
for k, v := range context {
|
||||||
|
result[k] = v
|
||||||
|
}
|
||||||
|
mutex.RUnlock()
|
||||||
|
return result, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes a value stored for a given key in a given request.
|
||||||
|
func Delete(r *http.Request, key interface{}) {
|
||||||
|
mutex.Lock()
|
||||||
|
if data[r] != nil {
|
||||||
|
delete(data[r], key)
|
||||||
|
}
|
||||||
|
mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear removes all values stored for a given request.
|
||||||
|
//
|
||||||
|
// This is usually called by a handler wrapper to clean up request
|
||||||
|
// variables at the end of a request lifetime. See ClearHandler().
|
||||||
|
func Clear(r *http.Request) {
|
||||||
|
mutex.Lock()
|
||||||
|
clear(r)
|
||||||
|
mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// clear is Clear without the lock.
|
||||||
|
func clear(r *http.Request) {
|
||||||
|
delete(data, r)
|
||||||
|
delete(datat, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Purge removes request data stored for longer than maxAge, in seconds.
|
||||||
|
// It returns the amount of requests removed.
|
||||||
|
//
|
||||||
|
// If maxAge <= 0, all request data is removed.
|
||||||
|
//
|
||||||
|
// This is only used for sanity check: in case context cleaning was not
|
||||||
|
// properly set some request data can be kept forever, consuming an increasing
|
||||||
|
// amount of memory. In case this is detected, Purge() must be called
|
||||||
|
// periodically until the problem is fixed.
|
||||||
|
func Purge(maxAge int) int {
|
||||||
|
mutex.Lock()
|
||||||
|
count := 0
|
||||||
|
if maxAge <= 0 {
|
||||||
|
count = len(data)
|
||||||
|
data = make(map[*http.Request]map[interface{}]interface{})
|
||||||
|
datat = make(map[*http.Request]int64)
|
||||||
|
} else {
|
||||||
|
min := time.Now().Unix() - int64(maxAge)
|
||||||
|
for r := range data {
|
||||||
|
if datat[r] < min {
|
||||||
|
clear(r)
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mutex.Unlock()
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearHandler wraps an http.Handler and clears request values at the end
|
||||||
|
// of a request lifetime.
|
||||||
|
func ClearHandler(h http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
defer Clear(r)
|
||||||
|
h.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
|
@ -0,0 +1,88 @@
|
||||||
|
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
/*
|
||||||
|
Package context stores values shared during a request lifetime.
|
||||||
|
|
||||||
|
Note: gorilla/context, having been born well before `context.Context` existed,
|
||||||
|
does not play well > with the shallow copying of the request that
|
||||||
|
[`http.Request.WithContext`](https://golang.org/pkg/net/http/#Request.WithContext)
|
||||||
|
(added to net/http Go 1.7 onwards) performs. You should either use *just*
|
||||||
|
gorilla/context, or moving forward, the new `http.Request.Context()`.
|
||||||
|
|
||||||
|
For example, a router can set variables extracted from the URL and later
|
||||||
|
application handlers can access those values, or it can be used to store
|
||||||
|
sessions values to be saved at the end of a request. There are several
|
||||||
|
others common uses.
|
||||||
|
|
||||||
|
The idea was posted by Brad Fitzpatrick to the go-nuts mailing list:
|
||||||
|
|
||||||
|
http://groups.google.com/group/golang-nuts/msg/e2d679d303aa5d53
|
||||||
|
|
||||||
|
Here's the basic usage: first define the keys that you will need. The key
|
||||||
|
type is interface{} so a key can be of any type that supports equality.
|
||||||
|
Here we define a key using a custom int type to avoid name collisions:
|
||||||
|
|
||||||
|
package foo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/gorilla/context"
|
||||||
|
)
|
||||||
|
|
||||||
|
type key int
|
||||||
|
|
||||||
|
const MyKey key = 0
|
||||||
|
|
||||||
|
Then set a variable. Variables are bound to an http.Request object, so you
|
||||||
|
need a request instance to set a value:
|
||||||
|
|
||||||
|
context.Set(r, MyKey, "bar")
|
||||||
|
|
||||||
|
The application can later access the variable using the same key you provided:
|
||||||
|
|
||||||
|
func MyHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// val is "bar".
|
||||||
|
val := context.Get(r, foo.MyKey)
|
||||||
|
|
||||||
|
// returns ("bar", true)
|
||||||
|
val, ok := context.GetOk(r, foo.MyKey)
|
||||||
|
// ...
|
||||||
|
}
|
||||||
|
|
||||||
|
And that's all about the basic usage. We discuss some other ideas below.
|
||||||
|
|
||||||
|
Any type can be stored in the context. To enforce a given type, make the key
|
||||||
|
private and wrap Get() and Set() to accept and return values of a specific
|
||||||
|
type:
|
||||||
|
|
||||||
|
type key int
|
||||||
|
|
||||||
|
const mykey key = 0
|
||||||
|
|
||||||
|
// GetMyKey returns a value for this package from the request values.
|
||||||
|
func GetMyKey(r *http.Request) SomeType {
|
||||||
|
if rv := context.Get(r, mykey); rv != nil {
|
||||||
|
return rv.(SomeType)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMyKey sets a value for this package in the request values.
|
||||||
|
func SetMyKey(r *http.Request, val SomeType) {
|
||||||
|
context.Set(r, mykey, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
Variables must be cleared at the end of a request, to remove all values
|
||||||
|
that were stored. This can be done in an http.Handler, after a request was
|
||||||
|
served. Just call Clear() passing the request:
|
||||||
|
|
||||||
|
context.Clear(r)
|
||||||
|
|
||||||
|
...or use ClearHandler(), which conveniently wraps an http.Handler to clear
|
||||||
|
variables at the end of a request lifetime.
|
||||||
|
|
||||||
|
The Routers from the packages gorilla/mux and gorilla/pat call Clear()
|
||||||
|
so if you are using either of them you don't need to clear the context manually.
|
||||||
|
*/
|
||||||
|
package context
|
|
@ -0,0 +1,27 @@
|
||||||
|
Copyright (c) 2012 Rodrigo Moraes. All rights reserved.
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions are
|
||||||
|
met:
|
||||||
|
|
||||||
|
* Redistributions of source code must retain the above copyright
|
||||||
|
notice, this list of conditions and the following disclaimer.
|
||||||
|
* Redistributions in binary form must reproduce the above
|
||||||
|
copyright notice, this list of conditions and the following disclaimer
|
||||||
|
in the documentation and/or other materials provided with the
|
||||||
|
distribution.
|
||||||
|
* Neither the name of Google Inc. nor the names of its
|
||||||
|
contributors may be used to endorse or promote products derived from
|
||||||
|
this software without specific prior written permission.
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||||
|
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||||
|
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||||
|
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||||
|
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||||
|
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||||
|
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||||
|
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||||
|
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@ -0,0 +1,299 @@
|
||||||
|
gorilla/mux
|
||||||
|
===
|
||||||
|
[![GoDoc](https://godoc.org/github.com/gorilla/mux?status.svg)](https://godoc.org/github.com/gorilla/mux)
|
||||||
|
[![Build Status](https://travis-ci.org/gorilla/mux.svg?branch=master)](https://travis-ci.org/gorilla/mux)
|
||||||
|
|
||||||
|
![Gorilla Logo](http://www.gorillatoolkit.org/static/images/gorilla-icon-64.png)
|
||||||
|
|
||||||
|
http://www.gorillatoolkit.org/pkg/mux
|
||||||
|
|
||||||
|
Package `gorilla/mux` implements a request router and dispatcher for matching incoming requests to
|
||||||
|
their respective handler.
|
||||||
|
|
||||||
|
The name mux stands for "HTTP request multiplexer". Like the standard `http.ServeMux`, `mux.Router` matches incoming requests against a list of registered routes and calls a handler for the route that matches the URL or other conditions. The main features are:
|
||||||
|
|
||||||
|
* It implements the `http.Handler` interface so it is compatible with the standard `http.ServeMux`.
|
||||||
|
* Requests can be matched based on URL host, path, path prefix, schemes, header and query values, HTTP methods or using custom matchers.
|
||||||
|
* URL hosts and paths can have variables with an optional regular expression.
|
||||||
|
* Registered URLs can be built, or "reversed", which helps maintaining references to resources.
|
||||||
|
* Routes can be used as subrouters: nested routes are only tested if the parent route matches. This is useful to define groups of routes that share common conditions like a host, a path prefix or other repeated attributes. As a bonus, this optimizes request matching.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
* [Install](#install)
|
||||||
|
* [Examples](#examples)
|
||||||
|
* [Matching Routes](#matching-routes)
|
||||||
|
* [Static Files](#static-files)
|
||||||
|
* [Registered URLs](#registered-urls)
|
||||||
|
* [Full Example](#full-example)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Install
|
||||||
|
|
||||||
|
With a [correctly configured](https://golang.org/doc/install#testing) Go toolchain:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
go get -u github.com/gorilla/mux
|
||||||
|
```
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
|
||||||
|
Let's start registering a couple of URL paths and handlers:
|
||||||
|
|
||||||
|
```go
|
||||||
|
func main() {
|
||||||
|
r := mux.NewRouter()
|
||||||
|
r.HandleFunc("/", HomeHandler)
|
||||||
|
r.HandleFunc("/products", ProductsHandler)
|
||||||
|
r.HandleFunc("/articles", ArticlesHandler)
|
||||||
|
http.Handle("/", r)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Here we register three routes mapping URL paths to handlers. This is equivalent to how `http.HandleFunc()` works: if an incoming request URL matches one of the paths, the corresponding handler is called passing (`http.ResponseWriter`, `*http.Request`) as parameters.
|
||||||
|
|
||||||
|
Paths can have variables. They are defined using the format `{name}` or `{name:pattern}`. If a regular expression pattern is not defined, the matched variable will be anything until the next slash. For example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
r := mux.NewRouter()
|
||||||
|
r.HandleFunc("/products/{key}", ProductHandler)
|
||||||
|
r.HandleFunc("/articles/{category}/", ArticlesCategoryHandler)
|
||||||
|
r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler)
|
||||||
|
```
|
||||||
|
|
||||||
|
The names are used to create a map of route variables which can be retrieved calling `mux.Vars()`:
|
||||||
|
|
||||||
|
```go
|
||||||
|
vars := mux.Vars(request)
|
||||||
|
category := vars["category"]
|
||||||
|
```
|
||||||
|
|
||||||
|
And this is all you need to know about the basic usage. More advanced options are explained below.
|
||||||
|
|
||||||
|
### Matching Routes
|
||||||
|
|
||||||
|
Routes can also be restricted to a domain or subdomain. Just define a host pattern to be matched. They can also have variables:
|
||||||
|
|
||||||
|
```go
|
||||||
|
r := mux.NewRouter()
|
||||||
|
// Only matches if domain is "www.example.com".
|
||||||
|
r.Host("www.example.com")
|
||||||
|
// Matches a dynamic subdomain.
|
||||||
|
r.Host("{subdomain:[a-z]+}.domain.com")
|
||||||
|
```
|
||||||
|
|
||||||
|
There are several other matchers that can be added. To match path prefixes:
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.PathPrefix("/products/")
|
||||||
|
```
|
||||||
|
|
||||||
|
...or HTTP methods:
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.Methods("GET", "POST")
|
||||||
|
```
|
||||||
|
|
||||||
|
...or URL schemes:
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.Schemes("https")
|
||||||
|
```
|
||||||
|
|
||||||
|
...or header values:
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.Headers("X-Requested-With", "XMLHttpRequest")
|
||||||
|
```
|
||||||
|
|
||||||
|
...or query values:
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.Queries("key", "value")
|
||||||
|
```
|
||||||
|
|
||||||
|
...or to use a custom matcher function:
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.MatcherFunc(func(r *http.Request, rm *RouteMatch) bool {
|
||||||
|
return r.ProtoMajor == 0
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
...and finally, it is possible to combine several matchers in a single route:
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.HandleFunc("/products", ProductsHandler).
|
||||||
|
Host("www.example.com").
|
||||||
|
Methods("GET").
|
||||||
|
Schemes("http")
|
||||||
|
```
|
||||||
|
|
||||||
|
Setting the same matching conditions again and again can be boring, so we have a way to group several routes that share the same requirements. We call it "subrouting".
|
||||||
|
|
||||||
|
For example, let's say we have several URLs that should only match when the host is `www.example.com`. Create a route for that host and get a "subrouter" from it:
|
||||||
|
|
||||||
|
```go
|
||||||
|
r := mux.NewRouter()
|
||||||
|
s := r.Host("www.example.com").Subrouter()
|
||||||
|
```
|
||||||
|
|
||||||
|
Then register routes in the subrouter:
|
||||||
|
|
||||||
|
```go
|
||||||
|
s.HandleFunc("/products/", ProductsHandler)
|
||||||
|
s.HandleFunc("/products/{key}", ProductHandler)
|
||||||
|
s.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler)
|
||||||
|
```
|
||||||
|
|
||||||
|
The three URL paths we registered above will only be tested if the domain is `www.example.com`, because the subrouter is tested first. This is not only convenient, but also optimizes request matching. You can create subrouters combining any attribute matchers accepted by a route.
|
||||||
|
|
||||||
|
Subrouters can be used to create domain or path "namespaces": you define subrouters in a central place and then parts of the app can register its paths relatively to a given subrouter.
|
||||||
|
|
||||||
|
There's one more thing about subroutes. When a subrouter has a path prefix, the inner routes use it as base for their paths:
|
||||||
|
|
||||||
|
```go
|
||||||
|
r := mux.NewRouter()
|
||||||
|
s := r.PathPrefix("/products").Subrouter()
|
||||||
|
// "/products/"
|
||||||
|
s.HandleFunc("/", ProductsHandler)
|
||||||
|
// "/products/{key}/"
|
||||||
|
s.HandleFunc("/{key}/", ProductHandler)
|
||||||
|
// "/products/{key}/details"
|
||||||
|
s.HandleFunc("/{key}/details", ProductDetailsHandler)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Static Files
|
||||||
|
|
||||||
|
Note that the path provided to `PathPrefix()` represents a "wildcard": calling
|
||||||
|
`PathPrefix("/static/").Handler(...)` means that the handler will be passed any
|
||||||
|
request that matches "/static/*". This makes it easy to serve static files with mux:
|
||||||
|
|
||||||
|
```go
|
||||||
|
func main() {
|
||||||
|
var dir string
|
||||||
|
|
||||||
|
flag.StringVar(&dir, "dir", ".", "the directory to serve files from. Defaults to the current dir")
|
||||||
|
flag.Parse()
|
||||||
|
r := mux.NewRouter()
|
||||||
|
|
||||||
|
// This will serve files under http://localhost:8000/static/<filename>
|
||||||
|
r.PathPrefix("/static/").Handler(http.StripPrefix("/static/", http.FileServer(http.Dir(dir))))
|
||||||
|
|
||||||
|
srv := &http.Server{
|
||||||
|
Handler: r,
|
||||||
|
Addr: "127.0.0.1:8000",
|
||||||
|
// Good practice: enforce timeouts for servers you create!
|
||||||
|
WriteTimeout: 15 * time.Second,
|
||||||
|
ReadTimeout: 15 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Fatal(srv.ListenAndServe())
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Registered URLs
|
||||||
|
|
||||||
|
Now let's see how to build registered URLs.
|
||||||
|
|
||||||
|
Routes can be named. All routes that define a name can have their URLs built, or "reversed". We define a name calling `Name()` on a route. For example:
|
||||||
|
|
||||||
|
```go
|
||||||
|
r := mux.NewRouter()
|
||||||
|
r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler).
|
||||||
|
Name("article")
|
||||||
|
```
|
||||||
|
|
||||||
|
To build a URL, get the route and call the `URL()` method, passing a sequence of key/value pairs for the route variables. For the previous route, we would do:
|
||||||
|
|
||||||
|
```go
|
||||||
|
url, err := r.Get("article").URL("category", "technology", "id", "42")
|
||||||
|
```
|
||||||
|
|
||||||
|
...and the result will be a `url.URL` with the following path:
|
||||||
|
|
||||||
|
```
|
||||||
|
"/articles/technology/42"
|
||||||
|
```
|
||||||
|
|
||||||
|
This also works for host variables:
|
||||||
|
|
||||||
|
```go
|
||||||
|
r := mux.NewRouter()
|
||||||
|
r.Host("{subdomain}.domain.com").
|
||||||
|
Path("/articles/{category}/{id:[0-9]+}").
|
||||||
|
HandlerFunc(ArticleHandler).
|
||||||
|
Name("article")
|
||||||
|
|
||||||
|
// url.String() will be "http://news.domain.com/articles/technology/42"
|
||||||
|
url, err := r.Get("article").URL("subdomain", "news",
|
||||||
|
"category", "technology",
|
||||||
|
"id", "42")
|
||||||
|
```
|
||||||
|
|
||||||
|
All variables defined in the route are required, and their values must conform to the corresponding patterns. These requirements guarantee that a generated URL will always match a registered route -- the only exception is for explicitly defined "build-only" routes which never match.
|
||||||
|
|
||||||
|
Regex support also exists for matching Headers within a route. For example, we could do:
|
||||||
|
|
||||||
|
```go
|
||||||
|
r.HeadersRegexp("Content-Type", "application/(text|json)")
|
||||||
|
```
|
||||||
|
|
||||||
|
...and the route will match both requests with a Content-Type of `application/json` as well as `application/text`
|
||||||
|
|
||||||
|
There's also a way to build only the URL host or path for a route: use the methods `URLHost()` or `URLPath()` instead. For the previous route, we would do:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// "http://news.domain.com/"
|
||||||
|
host, err := r.Get("article").URLHost("subdomain", "news")
|
||||||
|
|
||||||
|
// "/articles/technology/42"
|
||||||
|
path, err := r.Get("article").URLPath("category", "technology", "id", "42")
|
||||||
|
```
|
||||||
|
|
||||||
|
And if you use subrouters, host and path defined separately can be built as well:
|
||||||
|
|
||||||
|
```go
|
||||||
|
r := mux.NewRouter()
|
||||||
|
s := r.Host("{subdomain}.domain.com").Subrouter()
|
||||||
|
s.Path("/articles/{category}/{id:[0-9]+}").
|
||||||
|
HandlerFunc(ArticleHandler).
|
||||||
|
Name("article")
|
||||||
|
|
||||||
|
// "http://news.domain.com/articles/technology/42"
|
||||||
|
url, err := r.Get("article").URL("subdomain", "news",
|
||||||
|
"category", "technology",
|
||||||
|
"id", "42")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Full Example
|
||||||
|
|
||||||
|
Here's a complete, runnable example of a small `mux` based server:
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"log"
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
)
|
||||||
|
|
||||||
|
func YourHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Write([]byte("Gorilla!\n"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
r := mux.NewRouter()
|
||||||
|
// Routes consist of a path and a handler function.
|
||||||
|
r.HandleFunc("/", YourHandler)
|
||||||
|
|
||||||
|
// Bind to a port and pass our router in
|
||||||
|
log.Fatal(http.ListenAndServe(":8000", r))
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
BSD licensed. See the LICENSE file for details.
|
|
@ -0,0 +1,26 @@
|
||||||
|
// +build !go1.7
|
||||||
|
|
||||||
|
package mux
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gorilla/context"
|
||||||
|
)
|
||||||
|
|
||||||
|
func contextGet(r *http.Request, key interface{}) interface{} {
|
||||||
|
return context.Get(r, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func contextSet(r *http.Request, key, val interface{}) *http.Request {
|
||||||
|
if val == nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
context.Set(r, key, val)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func contextClear(r *http.Request) {
|
||||||
|
context.Clear(r)
|
||||||
|
}
|
|
@ -0,0 +1,24 @@
|
||||||
|
// +build go1.7
|
||||||
|
|
||||||
|
package mux
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
func contextGet(r *http.Request, key interface{}) interface{} {
|
||||||
|
return r.Context().Value(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func contextSet(r *http.Request, key, val interface{}) *http.Request {
|
||||||
|
if val == nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.WithContext(context.WithValue(r.Context(), key, val))
|
||||||
|
}
|
||||||
|
|
||||||
|
func contextClear(r *http.Request) {
|
||||||
|
return
|
||||||
|
}
|
|
@ -0,0 +1,235 @@
|
||||||
|
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
/*
|
||||||
|
Package mux implements a request router and dispatcher.
|
||||||
|
|
||||||
|
The name mux stands for "HTTP request multiplexer". Like the standard
|
||||||
|
http.ServeMux, mux.Router matches incoming requests against a list of
|
||||||
|
registered routes and calls a handler for the route that matches the URL
|
||||||
|
or other conditions. The main features are:
|
||||||
|
|
||||||
|
* Requests can be matched based on URL host, path, path prefix, schemes,
|
||||||
|
header and query values, HTTP methods or using custom matchers.
|
||||||
|
* URL hosts and paths can have variables with an optional regular
|
||||||
|
expression.
|
||||||
|
* Registered URLs can be built, or "reversed", which helps maintaining
|
||||||
|
references to resources.
|
||||||
|
* Routes can be used as subrouters: nested routes are only tested if the
|
||||||
|
parent route matches. This is useful to define groups of routes that
|
||||||
|
share common conditions like a host, a path prefix or other repeated
|
||||||
|
attributes. As a bonus, this optimizes request matching.
|
||||||
|
* It implements the http.Handler interface so it is compatible with the
|
||||||
|
standard http.ServeMux.
|
||||||
|
|
||||||
|
Let's start registering a couple of URL paths and handlers:
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
r := mux.NewRouter()
|
||||||
|
r.HandleFunc("/", HomeHandler)
|
||||||
|
r.HandleFunc("/products", ProductsHandler)
|
||||||
|
r.HandleFunc("/articles", ArticlesHandler)
|
||||||
|
http.Handle("/", r)
|
||||||
|
}
|
||||||
|
|
||||||
|
Here we register three routes mapping URL paths to handlers. This is
|
||||||
|
equivalent to how http.HandleFunc() works: if an incoming request URL matches
|
||||||
|
one of the paths, the corresponding handler is called passing
|
||||||
|
(http.ResponseWriter, *http.Request) as parameters.
|
||||||
|
|
||||||
|
Paths can have variables. They are defined using the format {name} or
|
||||||
|
{name:pattern}. If a regular expression pattern is not defined, the matched
|
||||||
|
variable will be anything until the next slash. For example:
|
||||||
|
|
||||||
|
r := mux.NewRouter()
|
||||||
|
r.HandleFunc("/products/{key}", ProductHandler)
|
||||||
|
r.HandleFunc("/articles/{category}/", ArticlesCategoryHandler)
|
||||||
|
r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler)
|
||||||
|
|
||||||
|
Groups can be used inside patterns, as long as they are non-capturing (?:re). For example:
|
||||||
|
|
||||||
|
r.HandleFunc("/articles/{category}/{sort:(?:asc|desc|new)}", ArticlesCategoryHandler)
|
||||||
|
|
||||||
|
The names are used to create a map of route variables which can be retrieved
|
||||||
|
calling mux.Vars():
|
||||||
|
|
||||||
|
vars := mux.Vars(request)
|
||||||
|
category := vars["category"]
|
||||||
|
|
||||||
|
And this is all you need to know about the basic usage. More advanced options
|
||||||
|
are explained below.
|
||||||
|
|
||||||
|
Routes can also be restricted to a domain or subdomain. Just define a host
|
||||||
|
pattern to be matched. They can also have variables:
|
||||||
|
|
||||||
|
r := mux.NewRouter()
|
||||||
|
// Only matches if domain is "www.example.com".
|
||||||
|
r.Host("www.example.com")
|
||||||
|
// Matches a dynamic subdomain.
|
||||||
|
r.Host("{subdomain:[a-z]+}.domain.com")
|
||||||
|
|
||||||
|
There are several other matchers that can be added. To match path prefixes:
|
||||||
|
|
||||||
|
r.PathPrefix("/products/")
|
||||||
|
|
||||||
|
...or HTTP methods:
|
||||||
|
|
||||||
|
r.Methods("GET", "POST")
|
||||||
|
|
||||||
|
...or URL schemes:
|
||||||
|
|
||||||
|
r.Schemes("https")
|
||||||
|
|
||||||
|
...or header values:
|
||||||
|
|
||||||
|
r.Headers("X-Requested-With", "XMLHttpRequest")
|
||||||
|
|
||||||
|
...or query values:
|
||||||
|
|
||||||
|
r.Queries("key", "value")
|
||||||
|
|
||||||
|
...or to use a custom matcher function:
|
||||||
|
|
||||||
|
r.MatcherFunc(func(r *http.Request, rm *RouteMatch) bool {
|
||||||
|
return r.ProtoMajor == 0
|
||||||
|
})
|
||||||
|
|
||||||
|
...and finally, it is possible to combine several matchers in a single route:
|
||||||
|
|
||||||
|
r.HandleFunc("/products", ProductsHandler).
|
||||||
|
Host("www.example.com").
|
||||||
|
Methods("GET").
|
||||||
|
Schemes("http")
|
||||||
|
|
||||||
|
Setting the same matching conditions again and again can be boring, so we have
|
||||||
|
a way to group several routes that share the same requirements.
|
||||||
|
We call it "subrouting".
|
||||||
|
|
||||||
|
For example, let's say we have several URLs that should only match when the
|
||||||
|
host is "www.example.com". Create a route for that host and get a "subrouter"
|
||||||
|
from it:
|
||||||
|
|
||||||
|
r := mux.NewRouter()
|
||||||
|
s := r.Host("www.example.com").Subrouter()
|
||||||
|
|
||||||
|
Then register routes in the subrouter:
|
||||||
|
|
||||||
|
s.HandleFunc("/products/", ProductsHandler)
|
||||||
|
s.HandleFunc("/products/{key}", ProductHandler)
|
||||||
|
s.HandleFunc("/articles/{category}/{id:[0-9]+}"), ArticleHandler)
|
||||||
|
|
||||||
|
The three URL paths we registered above will only be tested if the domain is
|
||||||
|
"www.example.com", because the subrouter is tested first. This is not
|
||||||
|
only convenient, but also optimizes request matching. You can create
|
||||||
|
subrouters combining any attribute matchers accepted by a route.
|
||||||
|
|
||||||
|
Subrouters can be used to create domain or path "namespaces": you define
|
||||||
|
subrouters in a central place and then parts of the app can register its
|
||||||
|
paths relatively to a given subrouter.
|
||||||
|
|
||||||
|
There's one more thing about subroutes. When a subrouter has a path prefix,
|
||||||
|
the inner routes use it as base for their paths:
|
||||||
|
|
||||||
|
r := mux.NewRouter()
|
||||||
|
s := r.PathPrefix("/products").Subrouter()
|
||||||
|
// "/products/"
|
||||||
|
s.HandleFunc("/", ProductsHandler)
|
||||||
|
// "/products/{key}/"
|
||||||
|
s.HandleFunc("/{key}/", ProductHandler)
|
||||||
|
// "/products/{key}/details"
|
||||||
|
s.HandleFunc("/{key}/details", ProductDetailsHandler)
|
||||||
|
|
||||||
|
Note that the path provided to PathPrefix() represents a "wildcard": calling
|
||||||
|
PathPrefix("/static/").Handler(...) means that the handler will be passed any
|
||||||
|
request that matches "/static/*". This makes it easy to serve static files with mux:
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
var dir string
|
||||||
|
|
||||||
|
flag.StringVar(&dir, "dir", ".", "the directory to serve files from. Defaults to the current dir")
|
||||||
|
flag.Parse()
|
||||||
|
r := mux.NewRouter()
|
||||||
|
|
||||||
|
// This will serve files under http://localhost:8000/static/<filename>
|
||||||
|
r.PathPrefix("/static/").Handler(http.StripPrefix("/static/", http.FileServer(http.Dir(dir))))
|
||||||
|
|
||||||
|
srv := &http.Server{
|
||||||
|
Handler: r,
|
||||||
|
Addr: "127.0.0.1:8000",
|
||||||
|
// Good practice: enforce timeouts for servers you create!
|
||||||
|
WriteTimeout: 15 * time.Second,
|
||||||
|
ReadTimeout: 15 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Fatal(srv.ListenAndServe())
|
||||||
|
}
|
||||||
|
|
||||||
|
Now let's see how to build registered URLs.
|
||||||
|
|
||||||
|
Routes can be named. All routes that define a name can have their URLs built,
|
||||||
|
or "reversed". We define a name calling Name() on a route. For example:
|
||||||
|
|
||||||
|
r := mux.NewRouter()
|
||||||
|
r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler).
|
||||||
|
Name("article")
|
||||||
|
|
||||||
|
To build a URL, get the route and call the URL() method, passing a sequence of
|
||||||
|
key/value pairs for the route variables. For the previous route, we would do:
|
||||||
|
|
||||||
|
url, err := r.Get("article").URL("category", "technology", "id", "42")
|
||||||
|
|
||||||
|
...and the result will be a url.URL with the following path:
|
||||||
|
|
||||||
|
"/articles/technology/42"
|
||||||
|
|
||||||
|
This also works for host variables:
|
||||||
|
|
||||||
|
r := mux.NewRouter()
|
||||||
|
r.Host("{subdomain}.domain.com").
|
||||||
|
Path("/articles/{category}/{id:[0-9]+}").
|
||||||
|
HandlerFunc(ArticleHandler).
|
||||||
|
Name("article")
|
||||||
|
|
||||||
|
// url.String() will be "http://news.domain.com/articles/technology/42"
|
||||||
|
url, err := r.Get("article").URL("subdomain", "news",
|
||||||
|
"category", "technology",
|
||||||
|
"id", "42")
|
||||||
|
|
||||||
|
All variables defined in the route are required, and their values must
|
||||||
|
conform to the corresponding patterns. These requirements guarantee that a
|
||||||
|
generated URL will always match a registered route -- the only exception is
|
||||||
|
for explicitly defined "build-only" routes which never match.
|
||||||
|
|
||||||
|
Regex support also exists for matching Headers within a route. For example, we could do:
|
||||||
|
|
||||||
|
r.HeadersRegexp("Content-Type", "application/(text|json)")
|
||||||
|
|
||||||
|
...and the route will match both requests with a Content-Type of `application/json` as well as
|
||||||
|
`application/text`
|
||||||
|
|
||||||
|
There's also a way to build only the URL host or path for a route:
|
||||||
|
use the methods URLHost() or URLPath() instead. For the previous route,
|
||||||
|
we would do:
|
||||||
|
|
||||||
|
// "http://news.domain.com/"
|
||||||
|
host, err := r.Get("article").URLHost("subdomain", "news")
|
||||||
|
|
||||||
|
// "/articles/technology/42"
|
||||||
|
path, err := r.Get("article").URLPath("category", "technology", "id", "42")
|
||||||
|
|
||||||
|
And if you use subrouters, host and path defined separately can be built
|
||||||
|
as well:
|
||||||
|
|
||||||
|
r := mux.NewRouter()
|
||||||
|
s := r.Host("{subdomain}.domain.com").Subrouter()
|
||||||
|
s.Path("/articles/{category}/{id:[0-9]+}").
|
||||||
|
HandlerFunc(ArticleHandler).
|
||||||
|
Name("article")
|
||||||
|
|
||||||
|
// "http://news.domain.com/articles/technology/42"
|
||||||
|
url, err := r.Get("article").URL("subdomain", "news",
|
||||||
|
"category", "technology",
|
||||||
|
"id", "42")
|
||||||
|
*/
|
||||||
|
package mux
|
|
@ -0,0 +1,542 @@
|
||||||
|
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package mux
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"path"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewRouter returns a new router instance.
|
||||||
|
func NewRouter() *Router {
|
||||||
|
return &Router{namedRoutes: make(map[string]*Route), KeepContext: false}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Router registers routes to be matched and dispatches a handler.
|
||||||
|
//
|
||||||
|
// It implements the http.Handler interface, so it can be registered to serve
|
||||||
|
// requests:
|
||||||
|
//
|
||||||
|
// var router = mux.NewRouter()
|
||||||
|
//
|
||||||
|
// func main() {
|
||||||
|
// http.Handle("/", router)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// Or, for Google App Engine, register it in a init() function:
|
||||||
|
//
|
||||||
|
// func init() {
|
||||||
|
// http.Handle("/", router)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// This will send all incoming requests to the router.
|
||||||
|
type Router struct {
|
||||||
|
// Configurable Handler to be used when no route matches.
|
||||||
|
NotFoundHandler http.Handler
|
||||||
|
// Parent route, if this is a subrouter.
|
||||||
|
parent parentRoute
|
||||||
|
// Routes to be matched, in order.
|
||||||
|
routes []*Route
|
||||||
|
// Routes by name for URL building.
|
||||||
|
namedRoutes map[string]*Route
|
||||||
|
// See Router.StrictSlash(). This defines the flag for new routes.
|
||||||
|
strictSlash bool
|
||||||
|
// See Router.SkipClean(). This defines the flag for new routes.
|
||||||
|
skipClean bool
|
||||||
|
// If true, do not clear the request context after handling the request.
|
||||||
|
// This has no effect when go1.7+ is used, since the context is stored
|
||||||
|
// on the request itself.
|
||||||
|
KeepContext bool
|
||||||
|
// see Router.UseEncodedPath(). This defines a flag for all routes.
|
||||||
|
useEncodedPath bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Match matches registered routes against the request.
|
||||||
|
func (r *Router) Match(req *http.Request, match *RouteMatch) bool {
|
||||||
|
for _, route := range r.routes {
|
||||||
|
if route.Match(req, match) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Closest match for a router (includes sub-routers)
|
||||||
|
if r.NotFoundHandler != nil {
|
||||||
|
match.Handler = r.NotFoundHandler
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeHTTP dispatches the handler registered in the matched route.
|
||||||
|
//
|
||||||
|
// When there is a match, the route variables can be retrieved calling
|
||||||
|
// mux.Vars(request).
|
||||||
|
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||||
|
if !r.skipClean {
|
||||||
|
path := req.URL.Path
|
||||||
|
if r.useEncodedPath {
|
||||||
|
path = getPath(req)
|
||||||
|
}
|
||||||
|
// Clean path to canonical form and redirect.
|
||||||
|
if p := cleanPath(path); p != path {
|
||||||
|
|
||||||
|
// Added 3 lines (Philip Schlump) - It was dropping the query string and #whatever from query.
|
||||||
|
// This matches with fix in go 1.2 r.c. 4 for same problem. Go Issue:
|
||||||
|
// http://code.google.com/p/go/issues/detail?id=5252
|
||||||
|
url := *req.URL
|
||||||
|
url.Path = p
|
||||||
|
p = url.String()
|
||||||
|
|
||||||
|
w.Header().Set("Location", p)
|
||||||
|
w.WriteHeader(http.StatusMovedPermanently)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var match RouteMatch
|
||||||
|
var handler http.Handler
|
||||||
|
if r.Match(req, &match) {
|
||||||
|
handler = match.Handler
|
||||||
|
req = setVars(req, match.Vars)
|
||||||
|
req = setCurrentRoute(req, match.Route)
|
||||||
|
}
|
||||||
|
if handler == nil {
|
||||||
|
handler = http.NotFoundHandler()
|
||||||
|
}
|
||||||
|
if !r.KeepContext {
|
||||||
|
defer contextClear(req)
|
||||||
|
}
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns a route registered with the given name.
|
||||||
|
func (r *Router) Get(name string) *Route {
|
||||||
|
return r.getNamedRoutes()[name]
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRoute returns a route registered with the given name. This method
|
||||||
|
// was renamed to Get() and remains here for backwards compatibility.
|
||||||
|
func (r *Router) GetRoute(name string) *Route {
|
||||||
|
return r.getNamedRoutes()[name]
|
||||||
|
}
|
||||||
|
|
||||||
|
// StrictSlash defines the trailing slash behavior for new routes. The initial
|
||||||
|
// value is false.
|
||||||
|
//
|
||||||
|
// When true, if the route path is "/path/", accessing "/path" will redirect
|
||||||
|
// to the former and vice versa. In other words, your application will always
|
||||||
|
// see the path as specified in the route.
|
||||||
|
//
|
||||||
|
// When false, if the route path is "/path", accessing "/path/" will not match
|
||||||
|
// this route and vice versa.
|
||||||
|
//
|
||||||
|
// Special case: when a route sets a path prefix using the PathPrefix() method,
|
||||||
|
// strict slash is ignored for that route because the redirect behavior can't
|
||||||
|
// be determined from a prefix alone. However, any subrouters created from that
|
||||||
|
// route inherit the original StrictSlash setting.
|
||||||
|
func (r *Router) StrictSlash(value bool) *Router {
|
||||||
|
r.strictSlash = value
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// SkipClean defines the path cleaning behaviour for new routes. The initial
|
||||||
|
// value is false. Users should be careful about which routes are not cleaned
|
||||||
|
//
|
||||||
|
// When true, if the route path is "/path//to", it will remain with the double
|
||||||
|
// slash. This is helpful if you have a route like: /fetch/http://xkcd.com/534/
|
||||||
|
//
|
||||||
|
// When false, the path will be cleaned, so /fetch/http://xkcd.com/534/ will
|
||||||
|
// become /fetch/http/xkcd.com/534
|
||||||
|
func (r *Router) SkipClean(value bool) *Router {
|
||||||
|
r.skipClean = value
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// UseEncodedPath tells the router to match the encoded original path
|
||||||
|
// to the routes.
|
||||||
|
// For eg. "/path/foo%2Fbar/to" will match the path "/path/{var}/to".
|
||||||
|
// This behavior has the drawback of needing to match routes against
|
||||||
|
// r.RequestURI instead of r.URL.Path. Any modifications (such as http.StripPrefix)
|
||||||
|
// to r.URL.Path will not affect routing when this flag is on and thus may
|
||||||
|
// induce unintended behavior.
|
||||||
|
//
|
||||||
|
// If not called, the router will match the unencoded path to the routes.
|
||||||
|
// For eg. "/path/foo%2Fbar/to" will match the path "/path/foo/bar/to"
|
||||||
|
func (r *Router) UseEncodedPath() *Router {
|
||||||
|
r.useEncodedPath = true
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// parentRoute
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// getNamedRoutes returns the map where named routes are registered.
|
||||||
|
func (r *Router) getNamedRoutes() map[string]*Route {
|
||||||
|
if r.namedRoutes == nil {
|
||||||
|
if r.parent != nil {
|
||||||
|
r.namedRoutes = r.parent.getNamedRoutes()
|
||||||
|
} else {
|
||||||
|
r.namedRoutes = make(map[string]*Route)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return r.namedRoutes
|
||||||
|
}
|
||||||
|
|
||||||
|
// getRegexpGroup returns regexp definitions from the parent route, if any.
|
||||||
|
func (r *Router) getRegexpGroup() *routeRegexpGroup {
|
||||||
|
if r.parent != nil {
|
||||||
|
return r.parent.getRegexpGroup()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Router) buildVars(m map[string]string) map[string]string {
|
||||||
|
if r.parent != nil {
|
||||||
|
m = r.parent.buildVars(m)
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// Route factories
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// NewRoute registers an empty route.
|
||||||
|
func (r *Router) NewRoute() *Route {
|
||||||
|
route := &Route{parent: r, strictSlash: r.strictSlash, skipClean: r.skipClean, useEncodedPath: r.useEncodedPath}
|
||||||
|
r.routes = append(r.routes, route)
|
||||||
|
return route
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle registers a new route with a matcher for the URL path.
|
||||||
|
// See Route.Path() and Route.Handler().
|
||||||
|
func (r *Router) Handle(path string, handler http.Handler) *Route {
|
||||||
|
return r.NewRoute().Path(path).Handler(handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleFunc registers a new route with a matcher for the URL path.
|
||||||
|
// See Route.Path() and Route.HandlerFunc().
|
||||||
|
func (r *Router) HandleFunc(path string, f func(http.ResponseWriter,
|
||||||
|
*http.Request)) *Route {
|
||||||
|
return r.NewRoute().Path(path).HandlerFunc(f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Headers registers a new route with a matcher for request header values.
|
||||||
|
// See Route.Headers().
|
||||||
|
func (r *Router) Headers(pairs ...string) *Route {
|
||||||
|
return r.NewRoute().Headers(pairs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Host registers a new route with a matcher for the URL host.
|
||||||
|
// See Route.Host().
|
||||||
|
func (r *Router) Host(tpl string) *Route {
|
||||||
|
return r.NewRoute().Host(tpl)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MatcherFunc registers a new route with a custom matcher function.
|
||||||
|
// See Route.MatcherFunc().
|
||||||
|
func (r *Router) MatcherFunc(f MatcherFunc) *Route {
|
||||||
|
return r.NewRoute().MatcherFunc(f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Methods registers a new route with a matcher for HTTP methods.
|
||||||
|
// See Route.Methods().
|
||||||
|
func (r *Router) Methods(methods ...string) *Route {
|
||||||
|
return r.NewRoute().Methods(methods...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Path registers a new route with a matcher for the URL path.
|
||||||
|
// See Route.Path().
|
||||||
|
func (r *Router) Path(tpl string) *Route {
|
||||||
|
return r.NewRoute().Path(tpl)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PathPrefix registers a new route with a matcher for the URL path prefix.
|
||||||
|
// See Route.PathPrefix().
|
||||||
|
func (r *Router) PathPrefix(tpl string) *Route {
|
||||||
|
return r.NewRoute().PathPrefix(tpl)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Queries registers a new route with a matcher for URL query values.
|
||||||
|
// See Route.Queries().
|
||||||
|
func (r *Router) Queries(pairs ...string) *Route {
|
||||||
|
return r.NewRoute().Queries(pairs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Schemes registers a new route with a matcher for URL schemes.
|
||||||
|
// See Route.Schemes().
|
||||||
|
func (r *Router) Schemes(schemes ...string) *Route {
|
||||||
|
return r.NewRoute().Schemes(schemes...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildVarsFunc registers a new route with a custom function for modifying
|
||||||
|
// route variables before building a URL.
|
||||||
|
func (r *Router) BuildVarsFunc(f BuildVarsFunc) *Route {
|
||||||
|
return r.NewRoute().BuildVarsFunc(f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Walk walks the router and all its sub-routers, calling walkFn for each route
|
||||||
|
// in the tree. The routes are walked in the order they were added. Sub-routers
|
||||||
|
// are explored depth-first.
|
||||||
|
func (r *Router) Walk(walkFn WalkFunc) error {
|
||||||
|
return r.walk(walkFn, []*Route{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SkipRouter is used as a return value from WalkFuncs to indicate that the
|
||||||
|
// router that walk is about to descend down to should be skipped.
|
||||||
|
var SkipRouter = errors.New("skip this router")
|
||||||
|
|
||||||
|
// WalkFunc is the type of the function called for each route visited by Walk.
|
||||||
|
// At every invocation, it is given the current route, and the current router,
|
||||||
|
// and a list of ancestor routes that lead to the current route.
|
||||||
|
type WalkFunc func(route *Route, router *Router, ancestors []*Route) error
|
||||||
|
|
||||||
|
func (r *Router) walk(walkFn WalkFunc, ancestors []*Route) error {
|
||||||
|
for _, t := range r.routes {
|
||||||
|
if t.regexp == nil || t.regexp.path == nil || t.regexp.path.template == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err := walkFn(t, r, ancestors)
|
||||||
|
if err == SkipRouter {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, sr := range t.matchers {
|
||||||
|
if h, ok := sr.(*Router); ok {
|
||||||
|
err := h.walk(walkFn, ancestors)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if h, ok := t.handler.(*Router); ok {
|
||||||
|
ancestors = append(ancestors, t)
|
||||||
|
err := h.walk(walkFn, ancestors)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ancestors = ancestors[:len(ancestors)-1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// Context
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// RouteMatch stores information about a matched route.
|
||||||
|
type RouteMatch struct {
|
||||||
|
Route *Route
|
||||||
|
Handler http.Handler
|
||||||
|
Vars map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
type contextKey int
|
||||||
|
|
||||||
|
const (
|
||||||
|
varsKey contextKey = iota
|
||||||
|
routeKey
|
||||||
|
)
|
||||||
|
|
||||||
|
// Vars returns the route variables for the current request, if any.
|
||||||
|
func Vars(r *http.Request) map[string]string {
|
||||||
|
if rv := contextGet(r, varsKey); rv != nil {
|
||||||
|
return rv.(map[string]string)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CurrentRoute returns the matched route for the current request, if any.
|
||||||
|
// This only works when called inside the handler of the matched route
|
||||||
|
// because the matched route is stored in the request context which is cleared
|
||||||
|
// after the handler returns, unless the KeepContext option is set on the
|
||||||
|
// Router.
|
||||||
|
func CurrentRoute(r *http.Request) *Route {
|
||||||
|
if rv := contextGet(r, routeKey); rv != nil {
|
||||||
|
return rv.(*Route)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setVars(r *http.Request, val interface{}) *http.Request {
|
||||||
|
return contextSet(r, varsKey, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setCurrentRoute(r *http.Request, val interface{}) *http.Request {
|
||||||
|
return contextSet(r, routeKey, val)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// Helpers
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// getPath returns the escaped path if possible; doing what URL.EscapedPath()
|
||||||
|
// which was added in go1.5 does
|
||||||
|
func getPath(req *http.Request) string {
|
||||||
|
if req.RequestURI != "" {
|
||||||
|
// Extract the path from RequestURI (which is escaped unlike URL.Path)
|
||||||
|
// as detailed here as detailed in https://golang.org/pkg/net/url/#URL
|
||||||
|
// for < 1.5 server side workaround
|
||||||
|
// http://localhost/path/here?v=1 -> /path/here
|
||||||
|
path := req.RequestURI
|
||||||
|
path = strings.TrimPrefix(path, req.URL.Scheme+`://`)
|
||||||
|
path = strings.TrimPrefix(path, req.URL.Host)
|
||||||
|
if i := strings.LastIndex(path, "?"); i > -1 {
|
||||||
|
path = path[:i]
|
||||||
|
}
|
||||||
|
if i := strings.LastIndex(path, "#"); i > -1 {
|
||||||
|
path = path[:i]
|
||||||
|
}
|
||||||
|
return path
|
||||||
|
}
|
||||||
|
return req.URL.Path
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanPath returns the canonical path for p, eliminating . and .. elements.
|
||||||
|
// Borrowed from the net/http package.
|
||||||
|
func cleanPath(p string) string {
|
||||||
|
if p == "" {
|
||||||
|
return "/"
|
||||||
|
}
|
||||||
|
if p[0] != '/' {
|
||||||
|
p = "/" + p
|
||||||
|
}
|
||||||
|
np := path.Clean(p)
|
||||||
|
// path.Clean removes trailing slash except for root;
|
||||||
|
// put the trailing slash back if necessary.
|
||||||
|
if p[len(p)-1] == '/' && np != "/" {
|
||||||
|
np += "/"
|
||||||
|
}
|
||||||
|
|
||||||
|
return np
|
||||||
|
}
|
||||||
|
|
||||||
|
// uniqueVars returns an error if two slices contain duplicated strings.
|
||||||
|
func uniqueVars(s1, s2 []string) error {
|
||||||
|
for _, v1 := range s1 {
|
||||||
|
for _, v2 := range s2 {
|
||||||
|
if v1 == v2 {
|
||||||
|
return fmt.Errorf("mux: duplicated route variable %q", v2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkPairs returns the count of strings passed in, and an error if
|
||||||
|
// the count is not an even number.
|
||||||
|
func checkPairs(pairs ...string) (int, error) {
|
||||||
|
length := len(pairs)
|
||||||
|
if length%2 != 0 {
|
||||||
|
return length, fmt.Errorf(
|
||||||
|
"mux: number of parameters must be multiple of 2, got %v", pairs)
|
||||||
|
}
|
||||||
|
return length, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// mapFromPairsToString converts variadic string parameters to a
|
||||||
|
// string to string map.
|
||||||
|
func mapFromPairsToString(pairs ...string) (map[string]string, error) {
|
||||||
|
length, err := checkPairs(pairs...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
m := make(map[string]string, length/2)
|
||||||
|
for i := 0; i < length; i += 2 {
|
||||||
|
m[pairs[i]] = pairs[i+1]
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// mapFromPairsToRegex converts variadic string paramers to a
|
||||||
|
// string to regex map.
|
||||||
|
func mapFromPairsToRegex(pairs ...string) (map[string]*regexp.Regexp, error) {
|
||||||
|
length, err := checkPairs(pairs...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
m := make(map[string]*regexp.Regexp, length/2)
|
||||||
|
for i := 0; i < length; i += 2 {
|
||||||
|
regex, err := regexp.Compile(pairs[i+1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
m[pairs[i]] = regex
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// matchInArray returns true if the given string value is in the array.
|
||||||
|
func matchInArray(arr []string, value string) bool {
|
||||||
|
for _, v := range arr {
|
||||||
|
if v == value {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// matchMapWithString returns true if the given key/value pairs exist in a given map.
|
||||||
|
func matchMapWithString(toCheck map[string]string, toMatch map[string][]string, canonicalKey bool) bool {
|
||||||
|
for k, v := range toCheck {
|
||||||
|
// Check if key exists.
|
||||||
|
if canonicalKey {
|
||||||
|
k = http.CanonicalHeaderKey(k)
|
||||||
|
}
|
||||||
|
if values := toMatch[k]; values == nil {
|
||||||
|
return false
|
||||||
|
} else if v != "" {
|
||||||
|
// If value was defined as an empty string we only check that the
|
||||||
|
// key exists. Otherwise we also check for equality.
|
||||||
|
valueExists := false
|
||||||
|
for _, value := range values {
|
||||||
|
if v == value {
|
||||||
|
valueExists = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !valueExists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// matchMapWithRegex returns true if the given key/value pairs exist in a given map compiled against
|
||||||
|
// the given regex
|
||||||
|
func matchMapWithRegex(toCheck map[string]*regexp.Regexp, toMatch map[string][]string, canonicalKey bool) bool {
|
||||||
|
for k, v := range toCheck {
|
||||||
|
// Check if key exists.
|
||||||
|
if canonicalKey {
|
||||||
|
k = http.CanonicalHeaderKey(k)
|
||||||
|
}
|
||||||
|
if values := toMatch[k]; values == nil {
|
||||||
|
return false
|
||||||
|
} else if v != nil {
|
||||||
|
// If value was defined as an empty string we only check that the
|
||||||
|
// key exists. Otherwise we also check for equality.
|
||||||
|
valueExists := false
|
||||||
|
for _, value := range values {
|
||||||
|
if v.MatchString(value) {
|
||||||
|
valueExists = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !valueExists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
|
@ -0,0 +1,316 @@
|
||||||
|
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package mux
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newRouteRegexp parses a route template and returns a routeRegexp,
|
||||||
|
// used to match a host, a path or a query string.
|
||||||
|
//
|
||||||
|
// It will extract named variables, assemble a regexp to be matched, create
|
||||||
|
// a "reverse" template to build URLs and compile regexps to validate variable
|
||||||
|
// values used in URL building.
|
||||||
|
//
|
||||||
|
// Previously we accepted only Python-like identifiers for variable
|
||||||
|
// names ([a-zA-Z_][a-zA-Z0-9_]*), but currently the only restriction is that
|
||||||
|
// name and pattern can't be empty, and names can't contain a colon.
|
||||||
|
func newRouteRegexp(tpl string, matchHost, matchPrefix, matchQuery, strictSlash, useEncodedPath bool) (*routeRegexp, error) {
|
||||||
|
// Check if it is well-formed.
|
||||||
|
idxs, errBraces := braceIndices(tpl)
|
||||||
|
if errBraces != nil {
|
||||||
|
return nil, errBraces
|
||||||
|
}
|
||||||
|
// Backup the original.
|
||||||
|
template := tpl
|
||||||
|
// Now let's parse it.
|
||||||
|
defaultPattern := "[^/]+"
|
||||||
|
if matchQuery {
|
||||||
|
defaultPattern = "[^?&]*"
|
||||||
|
} else if matchHost {
|
||||||
|
defaultPattern = "[^.]+"
|
||||||
|
matchPrefix = false
|
||||||
|
}
|
||||||
|
// Only match strict slash if not matching
|
||||||
|
if matchPrefix || matchHost || matchQuery {
|
||||||
|
strictSlash = false
|
||||||
|
}
|
||||||
|
// Set a flag for strictSlash.
|
||||||
|
endSlash := false
|
||||||
|
if strictSlash && strings.HasSuffix(tpl, "/") {
|
||||||
|
tpl = tpl[:len(tpl)-1]
|
||||||
|
endSlash = true
|
||||||
|
}
|
||||||
|
varsN := make([]string, len(idxs)/2)
|
||||||
|
varsR := make([]*regexp.Regexp, len(idxs)/2)
|
||||||
|
pattern := bytes.NewBufferString("")
|
||||||
|
pattern.WriteByte('^')
|
||||||
|
reverse := bytes.NewBufferString("")
|
||||||
|
var end int
|
||||||
|
var err error
|
||||||
|
for i := 0; i < len(idxs); i += 2 {
|
||||||
|
// Set all values we are interested in.
|
||||||
|
raw := tpl[end:idxs[i]]
|
||||||
|
end = idxs[i+1]
|
||||||
|
parts := strings.SplitN(tpl[idxs[i]+1:end-1], ":", 2)
|
||||||
|
name := parts[0]
|
||||||
|
patt := defaultPattern
|
||||||
|
if len(parts) == 2 {
|
||||||
|
patt = parts[1]
|
||||||
|
}
|
||||||
|
// Name or pattern can't be empty.
|
||||||
|
if name == "" || patt == "" {
|
||||||
|
return nil, fmt.Errorf("mux: missing name or pattern in %q",
|
||||||
|
tpl[idxs[i]:end])
|
||||||
|
}
|
||||||
|
// Build the regexp pattern.
|
||||||
|
fmt.Fprintf(pattern, "%s(?P<%s>%s)", regexp.QuoteMeta(raw), varGroupName(i/2), patt)
|
||||||
|
|
||||||
|
// Build the reverse template.
|
||||||
|
fmt.Fprintf(reverse, "%s%%s", raw)
|
||||||
|
|
||||||
|
// Append variable name and compiled pattern.
|
||||||
|
varsN[i/2] = name
|
||||||
|
varsR[i/2], err = regexp.Compile(fmt.Sprintf("^%s$", patt))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Add the remaining.
|
||||||
|
raw := tpl[end:]
|
||||||
|
pattern.WriteString(regexp.QuoteMeta(raw))
|
||||||
|
if strictSlash {
|
||||||
|
pattern.WriteString("[/]?")
|
||||||
|
}
|
||||||
|
if matchQuery {
|
||||||
|
// Add the default pattern if the query value is empty
|
||||||
|
if queryVal := strings.SplitN(template, "=", 2)[1]; queryVal == "" {
|
||||||
|
pattern.WriteString(defaultPattern)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !matchPrefix {
|
||||||
|
pattern.WriteByte('$')
|
||||||
|
}
|
||||||
|
reverse.WriteString(raw)
|
||||||
|
if endSlash {
|
||||||
|
reverse.WriteByte('/')
|
||||||
|
}
|
||||||
|
// Compile full regexp.
|
||||||
|
reg, errCompile := regexp.Compile(pattern.String())
|
||||||
|
if errCompile != nil {
|
||||||
|
return nil, errCompile
|
||||||
|
}
|
||||||
|
// Done!
|
||||||
|
return &routeRegexp{
|
||||||
|
template: template,
|
||||||
|
matchHost: matchHost,
|
||||||
|
matchQuery: matchQuery,
|
||||||
|
strictSlash: strictSlash,
|
||||||
|
useEncodedPath: useEncodedPath,
|
||||||
|
regexp: reg,
|
||||||
|
reverse: reverse.String(),
|
||||||
|
varsN: varsN,
|
||||||
|
varsR: varsR,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// routeRegexp stores a regexp to match a host or path and information to
|
||||||
|
// collect and validate route variables.
|
||||||
|
type routeRegexp struct {
|
||||||
|
// The unmodified template.
|
||||||
|
template string
|
||||||
|
// True for host match, false for path or query string match.
|
||||||
|
matchHost bool
|
||||||
|
// True for query string match, false for path and host match.
|
||||||
|
matchQuery bool
|
||||||
|
// The strictSlash value defined on the route, but disabled if PathPrefix was used.
|
||||||
|
strictSlash bool
|
||||||
|
// Determines whether to use encoded path from getPath function or unencoded
|
||||||
|
// req.URL.Path for path matching
|
||||||
|
useEncodedPath bool
|
||||||
|
// Expanded regexp.
|
||||||
|
regexp *regexp.Regexp
|
||||||
|
// Reverse template.
|
||||||
|
reverse string
|
||||||
|
// Variable names.
|
||||||
|
varsN []string
|
||||||
|
// Variable regexps (validators).
|
||||||
|
varsR []*regexp.Regexp
|
||||||
|
}
|
||||||
|
|
||||||
|
// Match matches the regexp against the URL host or path.
|
||||||
|
func (r *routeRegexp) Match(req *http.Request, match *RouteMatch) bool {
|
||||||
|
if !r.matchHost {
|
||||||
|
if r.matchQuery {
|
||||||
|
return r.matchQueryString(req)
|
||||||
|
}
|
||||||
|
path := req.URL.Path
|
||||||
|
if r.useEncodedPath {
|
||||||
|
path = getPath(req)
|
||||||
|
}
|
||||||
|
return r.regexp.MatchString(path)
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.regexp.MatchString(getHost(req))
|
||||||
|
}
|
||||||
|
|
||||||
|
// url builds a URL part using the given values.
|
||||||
|
func (r *routeRegexp) url(values map[string]string) (string, error) {
|
||||||
|
urlValues := make([]interface{}, len(r.varsN))
|
||||||
|
for k, v := range r.varsN {
|
||||||
|
value, ok := values[v]
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("mux: missing route variable %q", v)
|
||||||
|
}
|
||||||
|
urlValues[k] = value
|
||||||
|
}
|
||||||
|
rv := fmt.Sprintf(r.reverse, urlValues...)
|
||||||
|
if !r.regexp.MatchString(rv) {
|
||||||
|
// The URL is checked against the full regexp, instead of checking
|
||||||
|
// individual variables. This is faster but to provide a good error
|
||||||
|
// message, we check individual regexps if the URL doesn't match.
|
||||||
|
for k, v := range r.varsN {
|
||||||
|
if !r.varsR[k].MatchString(values[v]) {
|
||||||
|
return "", fmt.Errorf(
|
||||||
|
"mux: variable %q doesn't match, expected %q", values[v],
|
||||||
|
r.varsR[k].String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return rv, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getURLQuery returns a single query parameter from a request URL.
|
||||||
|
// For a URL with foo=bar&baz=ding, we return only the relevant key
|
||||||
|
// value pair for the routeRegexp.
|
||||||
|
func (r *routeRegexp) getURLQuery(req *http.Request) string {
|
||||||
|
if !r.matchQuery {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
templateKey := strings.SplitN(r.template, "=", 2)[0]
|
||||||
|
for key, vals := range req.URL.Query() {
|
||||||
|
if key == templateKey && len(vals) > 0 {
|
||||||
|
return key + "=" + vals[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *routeRegexp) matchQueryString(req *http.Request) bool {
|
||||||
|
return r.regexp.MatchString(r.getURLQuery(req))
|
||||||
|
}
|
||||||
|
|
||||||
|
// braceIndices returns the first level curly brace indices from a string.
|
||||||
|
// It returns an error in case of unbalanced braces.
|
||||||
|
func braceIndices(s string) ([]int, error) {
|
||||||
|
var level, idx int
|
||||||
|
var idxs []int
|
||||||
|
for i := 0; i < len(s); i++ {
|
||||||
|
switch s[i] {
|
||||||
|
case '{':
|
||||||
|
if level++; level == 1 {
|
||||||
|
idx = i
|
||||||
|
}
|
||||||
|
case '}':
|
||||||
|
if level--; level == 0 {
|
||||||
|
idxs = append(idxs, idx, i+1)
|
||||||
|
} else if level < 0 {
|
||||||
|
return nil, fmt.Errorf("mux: unbalanced braces in %q", s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if level != 0 {
|
||||||
|
return nil, fmt.Errorf("mux: unbalanced braces in %q", s)
|
||||||
|
}
|
||||||
|
return idxs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// varGroupName builds a capturing group name for the indexed variable.
|
||||||
|
func varGroupName(idx int) string {
|
||||||
|
return "v" + strconv.Itoa(idx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// routeRegexpGroup
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// routeRegexpGroup groups the route matchers that carry variables.
|
||||||
|
type routeRegexpGroup struct {
|
||||||
|
host *routeRegexp
|
||||||
|
path *routeRegexp
|
||||||
|
queries []*routeRegexp
|
||||||
|
}
|
||||||
|
|
||||||
|
// setMatch extracts the variables from the URL once a route matches.
|
||||||
|
func (v *routeRegexpGroup) setMatch(req *http.Request, m *RouteMatch, r *Route) {
|
||||||
|
// Store host variables.
|
||||||
|
if v.host != nil {
|
||||||
|
host := getHost(req)
|
||||||
|
matches := v.host.regexp.FindStringSubmatchIndex(host)
|
||||||
|
if len(matches) > 0 {
|
||||||
|
extractVars(host, matches, v.host.varsN, m.Vars)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
path := req.URL.Path
|
||||||
|
if r.useEncodedPath {
|
||||||
|
path = getPath(req)
|
||||||
|
}
|
||||||
|
// Store path variables.
|
||||||
|
if v.path != nil {
|
||||||
|
matches := v.path.regexp.FindStringSubmatchIndex(path)
|
||||||
|
if len(matches) > 0 {
|
||||||
|
extractVars(path, matches, v.path.varsN, m.Vars)
|
||||||
|
// Check if we should redirect.
|
||||||
|
if v.path.strictSlash {
|
||||||
|
p1 := strings.HasSuffix(path, "/")
|
||||||
|
p2 := strings.HasSuffix(v.path.template, "/")
|
||||||
|
if p1 != p2 {
|
||||||
|
u, _ := url.Parse(req.URL.String())
|
||||||
|
if p1 {
|
||||||
|
u.Path = u.Path[:len(u.Path)-1]
|
||||||
|
} else {
|
||||||
|
u.Path += "/"
|
||||||
|
}
|
||||||
|
m.Handler = http.RedirectHandler(u.String(), 301)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Store query string variables.
|
||||||
|
for _, q := range v.queries {
|
||||||
|
queryURL := q.getURLQuery(req)
|
||||||
|
matches := q.regexp.FindStringSubmatchIndex(queryURL)
|
||||||
|
if len(matches) > 0 {
|
||||||
|
extractVars(queryURL, matches, q.varsN, m.Vars)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getHost tries its best to return the request host.
|
||||||
|
func getHost(r *http.Request) string {
|
||||||
|
if r.URL.IsAbs() {
|
||||||
|
return r.URL.Host
|
||||||
|
}
|
||||||
|
host := r.Host
|
||||||
|
// Slice off any port information.
|
||||||
|
if i := strings.Index(host, ":"); i != -1 {
|
||||||
|
host = host[:i]
|
||||||
|
}
|
||||||
|
return host
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractVars(input string, matches []int, names []string, output map[string]string) {
|
||||||
|
for i, name := range names {
|
||||||
|
output[name] = input[matches[2*i+2]:matches[2*i+3]]
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,636 @@
|
||||||
|
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package mux
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Route stores information to match a request and build URLs.
|
||||||
|
type Route struct {
|
||||||
|
// Parent where the route was registered (a Router).
|
||||||
|
parent parentRoute
|
||||||
|
// Request handler for the route.
|
||||||
|
handler http.Handler
|
||||||
|
// List of matchers.
|
||||||
|
matchers []matcher
|
||||||
|
// Manager for the variables from host and path.
|
||||||
|
regexp *routeRegexpGroup
|
||||||
|
// If true, when the path pattern is "/path/", accessing "/path" will
|
||||||
|
// redirect to the former and vice versa.
|
||||||
|
strictSlash bool
|
||||||
|
// If true, when the path pattern is "/path//to", accessing "/path//to"
|
||||||
|
// will not redirect
|
||||||
|
skipClean bool
|
||||||
|
// If true, "/path/foo%2Fbar/to" will match the path "/path/{var}/to"
|
||||||
|
useEncodedPath bool
|
||||||
|
// If true, this route never matches: it is only used to build URLs.
|
||||||
|
buildOnly bool
|
||||||
|
// The name used to build URLs.
|
||||||
|
name string
|
||||||
|
// Error resulted from building a route.
|
||||||
|
err error
|
||||||
|
|
||||||
|
buildVarsFunc BuildVarsFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Route) SkipClean() bool {
|
||||||
|
return r.skipClean
|
||||||
|
}
|
||||||
|
|
||||||
|
// Match matches the route against the request.
|
||||||
|
func (r *Route) Match(req *http.Request, match *RouteMatch) bool {
|
||||||
|
if r.buildOnly || r.err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Match everything.
|
||||||
|
for _, m := range r.matchers {
|
||||||
|
if matched := m.Match(req, match); !matched {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Yay, we have a match. Let's collect some info about it.
|
||||||
|
if match.Route == nil {
|
||||||
|
match.Route = r
|
||||||
|
}
|
||||||
|
if match.Handler == nil {
|
||||||
|
match.Handler = r.handler
|
||||||
|
}
|
||||||
|
if match.Vars == nil {
|
||||||
|
match.Vars = make(map[string]string)
|
||||||
|
}
|
||||||
|
// Set variables.
|
||||||
|
if r.regexp != nil {
|
||||||
|
r.regexp.setMatch(req, match, r)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// Route attributes
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// GetError returns an error resulted from building the route, if any.
|
||||||
|
func (r *Route) GetError() error {
|
||||||
|
return r.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildOnly sets the route to never match: it is only used to build URLs.
|
||||||
|
func (r *Route) BuildOnly() *Route {
|
||||||
|
r.buildOnly = true
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler --------------------------------------------------------------------
|
||||||
|
|
||||||
|
// Handler sets a handler for the route.
|
||||||
|
func (r *Route) Handler(handler http.Handler) *Route {
|
||||||
|
if r.err == nil {
|
||||||
|
r.handler = handler
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandlerFunc sets a handler function for the route.
|
||||||
|
func (r *Route) HandlerFunc(f func(http.ResponseWriter, *http.Request)) *Route {
|
||||||
|
return r.Handler(http.HandlerFunc(f))
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHandler returns the handler for the route, if any.
|
||||||
|
func (r *Route) GetHandler() http.Handler {
|
||||||
|
return r.handler
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
// Name sets the name for the route, used to build URLs.
|
||||||
|
// If the name was registered already it will be overwritten.
|
||||||
|
func (r *Route) Name(name string) *Route {
|
||||||
|
if r.name != "" {
|
||||||
|
r.err = fmt.Errorf("mux: route already has name %q, can't set %q",
|
||||||
|
r.name, name)
|
||||||
|
}
|
||||||
|
if r.err == nil {
|
||||||
|
r.name = name
|
||||||
|
r.getNamedRoutes()[name] = r
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetName returns the name for the route, if any.
|
||||||
|
func (r *Route) GetName() string {
|
||||||
|
return r.name
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// Matchers
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// matcher types try to match a request.
|
||||||
|
type matcher interface {
|
||||||
|
Match(*http.Request, *RouteMatch) bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// addMatcher adds a matcher to the route.
|
||||||
|
func (r *Route) addMatcher(m matcher) *Route {
|
||||||
|
if r.err == nil {
|
||||||
|
r.matchers = append(r.matchers, m)
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// addRegexpMatcher adds a host or path matcher and builder to a route.
|
||||||
|
func (r *Route) addRegexpMatcher(tpl string, matchHost, matchPrefix, matchQuery bool) error {
|
||||||
|
if r.err != nil {
|
||||||
|
return r.err
|
||||||
|
}
|
||||||
|
r.regexp = r.getRegexpGroup()
|
||||||
|
if !matchHost && !matchQuery {
|
||||||
|
if len(tpl) == 0 || tpl[0] != '/' {
|
||||||
|
return fmt.Errorf("mux: path must start with a slash, got %q", tpl)
|
||||||
|
}
|
||||||
|
if r.regexp.path != nil {
|
||||||
|
tpl = strings.TrimRight(r.regexp.path.template, "/") + tpl
|
||||||
|
}
|
||||||
|
}
|
||||||
|
rr, err := newRouteRegexp(tpl, matchHost, matchPrefix, matchQuery, r.strictSlash, r.useEncodedPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for _, q := range r.regexp.queries {
|
||||||
|
if err = uniqueVars(rr.varsN, q.varsN); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if matchHost {
|
||||||
|
if r.regexp.path != nil {
|
||||||
|
if err = uniqueVars(rr.varsN, r.regexp.path.varsN); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.regexp.host = rr
|
||||||
|
} else {
|
||||||
|
if r.regexp.host != nil {
|
||||||
|
if err = uniqueVars(rr.varsN, r.regexp.host.varsN); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if matchQuery {
|
||||||
|
r.regexp.queries = append(r.regexp.queries, rr)
|
||||||
|
} else {
|
||||||
|
r.regexp.path = rr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.addMatcher(rr)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Headers --------------------------------------------------------------------
|
||||||
|
|
||||||
|
// headerMatcher matches the request against header values.
|
||||||
|
type headerMatcher map[string]string
|
||||||
|
|
||||||
|
func (m headerMatcher) Match(r *http.Request, match *RouteMatch) bool {
|
||||||
|
return matchMapWithString(m, r.Header, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Headers adds a matcher for request header values.
|
||||||
|
// It accepts a sequence of key/value pairs to be matched. For example:
|
||||||
|
//
|
||||||
|
// r := mux.NewRouter()
|
||||||
|
// r.Headers("Content-Type", "application/json",
|
||||||
|
// "X-Requested-With", "XMLHttpRequest")
|
||||||
|
//
|
||||||
|
// The above route will only match if both request header values match.
|
||||||
|
// If the value is an empty string, it will match any value if the key is set.
|
||||||
|
func (r *Route) Headers(pairs ...string) *Route {
|
||||||
|
if r.err == nil {
|
||||||
|
var headers map[string]string
|
||||||
|
headers, r.err = mapFromPairsToString(pairs...)
|
||||||
|
return r.addMatcher(headerMatcher(headers))
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// headerRegexMatcher matches the request against the route given a regex for the header
|
||||||
|
type headerRegexMatcher map[string]*regexp.Regexp
|
||||||
|
|
||||||
|
func (m headerRegexMatcher) Match(r *http.Request, match *RouteMatch) bool {
|
||||||
|
return matchMapWithRegex(m, r.Header, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HeadersRegexp accepts a sequence of key/value pairs, where the value has regex
|
||||||
|
// support. For example:
|
||||||
|
//
|
||||||
|
// r := mux.NewRouter()
|
||||||
|
// r.HeadersRegexp("Content-Type", "application/(text|json)",
|
||||||
|
// "X-Requested-With", "XMLHttpRequest")
|
||||||
|
//
|
||||||
|
// The above route will only match if both the request header matches both regular expressions.
|
||||||
|
// It the value is an empty string, it will match any value if the key is set.
|
||||||
|
func (r *Route) HeadersRegexp(pairs ...string) *Route {
|
||||||
|
if r.err == nil {
|
||||||
|
var headers map[string]*regexp.Regexp
|
||||||
|
headers, r.err = mapFromPairsToRegex(pairs...)
|
||||||
|
return r.addMatcher(headerRegexMatcher(headers))
|
||||||
|
}
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// Host -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
// Host adds a matcher for the URL host.
|
||||||
|
// It accepts a template with zero or more URL variables enclosed by {}.
|
||||||
|
// Variables can define an optional regexp pattern to be matched:
|
||||||
|
//
|
||||||
|
// - {name} matches anything until the next dot.
|
||||||
|
//
|
||||||
|
// - {name:pattern} matches the given regexp pattern.
|
||||||
|
//
|
||||||
|
// For example:
|
||||||
|
//
|
||||||
|
// r := mux.NewRouter()
|
||||||
|
// r.Host("www.example.com")
|
||||||
|
// r.Host("{subdomain}.domain.com")
|
||||||
|
// r.Host("{subdomain:[a-z]+}.domain.com")
|
||||||
|
//
|
||||||
|
// Variable names must be unique in a given route. They can be retrieved
|
||||||
|
// calling mux.Vars(request).
|
||||||
|
func (r *Route) Host(tpl string) *Route {
|
||||||
|
r.err = r.addRegexpMatcher(tpl, true, false, false)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// MatcherFunc ----------------------------------------------------------------
|
||||||
|
|
||||||
|
// MatcherFunc is the function signature used by custom matchers.
|
||||||
|
type MatcherFunc func(*http.Request, *RouteMatch) bool
|
||||||
|
|
||||||
|
// Match returns the match for a given request.
|
||||||
|
func (m MatcherFunc) Match(r *http.Request, match *RouteMatch) bool {
|
||||||
|
return m(r, match)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MatcherFunc adds a custom function to be used as request matcher.
|
||||||
|
func (r *Route) MatcherFunc(f MatcherFunc) *Route {
|
||||||
|
return r.addMatcher(f)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Methods --------------------------------------------------------------------
|
||||||
|
|
||||||
|
// methodMatcher matches the request against HTTP methods.
|
||||||
|
type methodMatcher []string
|
||||||
|
|
||||||
|
func (m methodMatcher) Match(r *http.Request, match *RouteMatch) bool {
|
||||||
|
return matchInArray(m, r.Method)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Methods adds a matcher for HTTP methods.
|
||||||
|
// It accepts a sequence of one or more methods to be matched, e.g.:
|
||||||
|
// "GET", "POST", "PUT".
|
||||||
|
func (r *Route) Methods(methods ...string) *Route {
|
||||||
|
for k, v := range methods {
|
||||||
|
methods[k] = strings.ToUpper(v)
|
||||||
|
}
|
||||||
|
return r.addMatcher(methodMatcher(methods))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Path -----------------------------------------------------------------------
|
||||||
|
|
||||||
|
// Path adds a matcher for the URL path.
|
||||||
|
// It accepts a template with zero or more URL variables enclosed by {}. The
|
||||||
|
// template must start with a "/".
|
||||||
|
// Variables can define an optional regexp pattern to be matched:
|
||||||
|
//
|
||||||
|
// - {name} matches anything until the next slash.
|
||||||
|
//
|
||||||
|
// - {name:pattern} matches the given regexp pattern.
|
||||||
|
//
|
||||||
|
// For example:
|
||||||
|
//
|
||||||
|
// r := mux.NewRouter()
|
||||||
|
// r.Path("/products/").Handler(ProductsHandler)
|
||||||
|
// r.Path("/products/{key}").Handler(ProductsHandler)
|
||||||
|
// r.Path("/articles/{category}/{id:[0-9]+}").
|
||||||
|
// Handler(ArticleHandler)
|
||||||
|
//
|
||||||
|
// Variable names must be unique in a given route. They can be retrieved
|
||||||
|
// calling mux.Vars(request).
|
||||||
|
func (r *Route) Path(tpl string) *Route {
|
||||||
|
r.err = r.addRegexpMatcher(tpl, false, false, false)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// PathPrefix -----------------------------------------------------------------
|
||||||
|
|
||||||
|
// PathPrefix adds a matcher for the URL path prefix. This matches if the given
|
||||||
|
// template is a prefix of the full URL path. See Route.Path() for details on
|
||||||
|
// the tpl argument.
|
||||||
|
//
|
||||||
|
// Note that it does not treat slashes specially ("/foobar/" will be matched by
|
||||||
|
// the prefix "/foo") so you may want to use a trailing slash here.
|
||||||
|
//
|
||||||
|
// Also note that the setting of Router.StrictSlash() has no effect on routes
|
||||||
|
// with a PathPrefix matcher.
|
||||||
|
func (r *Route) PathPrefix(tpl string) *Route {
|
||||||
|
r.err = r.addRegexpMatcher(tpl, false, true, false)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
// Queries adds a matcher for URL query values.
|
||||||
|
// It accepts a sequence of key/value pairs. Values may define variables.
|
||||||
|
// For example:
|
||||||
|
//
|
||||||
|
// r := mux.NewRouter()
|
||||||
|
// r.Queries("foo", "bar", "id", "{id:[0-9]+}")
|
||||||
|
//
|
||||||
|
// The above route will only match if the URL contains the defined queries
|
||||||
|
// values, e.g.: ?foo=bar&id=42.
|
||||||
|
//
|
||||||
|
// It the value is an empty string, it will match any value if the key is set.
|
||||||
|
//
|
||||||
|
// Variables can define an optional regexp pattern to be matched:
|
||||||
|
//
|
||||||
|
// - {name} matches anything until the next slash.
|
||||||
|
//
|
||||||
|
// - {name:pattern} matches the given regexp pattern.
|
||||||
|
func (r *Route) Queries(pairs ...string) *Route {
|
||||||
|
length := len(pairs)
|
||||||
|
if length%2 != 0 {
|
||||||
|
r.err = fmt.Errorf(
|
||||||
|
"mux: number of parameters must be multiple of 2, got %v", pairs)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
for i := 0; i < length; i += 2 {
|
||||||
|
if r.err = r.addRegexpMatcher(pairs[i]+"="+pairs[i+1], false, false, true); r.err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// Schemes --------------------------------------------------------------------
|
||||||
|
|
||||||
|
// schemeMatcher matches the request against URL schemes.
|
||||||
|
type schemeMatcher []string
|
||||||
|
|
||||||
|
func (m schemeMatcher) Match(r *http.Request, match *RouteMatch) bool {
|
||||||
|
return matchInArray(m, r.URL.Scheme)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Schemes adds a matcher for URL schemes.
|
||||||
|
// It accepts a sequence of schemes to be matched, e.g.: "http", "https".
|
||||||
|
func (r *Route) Schemes(schemes ...string) *Route {
|
||||||
|
for k, v := range schemes {
|
||||||
|
schemes[k] = strings.ToLower(v)
|
||||||
|
}
|
||||||
|
return r.addMatcher(schemeMatcher(schemes))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildVarsFunc --------------------------------------------------------------
|
||||||
|
|
||||||
|
// BuildVarsFunc is the function signature used by custom build variable
|
||||||
|
// functions (which can modify route variables before a route's URL is built).
|
||||||
|
type BuildVarsFunc func(map[string]string) map[string]string
|
||||||
|
|
||||||
|
// BuildVarsFunc adds a custom function to be used to modify build variables
|
||||||
|
// before a route's URL is built.
|
||||||
|
func (r *Route) BuildVarsFunc(f BuildVarsFunc) *Route {
|
||||||
|
r.buildVarsFunc = f
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// Subrouter ------------------------------------------------------------------
|
||||||
|
|
||||||
|
// Subrouter creates a subrouter for the route.
|
||||||
|
//
|
||||||
|
// It will test the inner routes only if the parent route matched. For example:
|
||||||
|
//
|
||||||
|
// r := mux.NewRouter()
|
||||||
|
// s := r.Host("www.example.com").Subrouter()
|
||||||
|
// s.HandleFunc("/products/", ProductsHandler)
|
||||||
|
// s.HandleFunc("/products/{key}", ProductHandler)
|
||||||
|
// s.HandleFunc("/articles/{category}/{id:[0-9]+}"), ArticleHandler)
|
||||||
|
//
|
||||||
|
// Here, the routes registered in the subrouter won't be tested if the host
|
||||||
|
// doesn't match.
|
||||||
|
func (r *Route) Subrouter() *Router {
|
||||||
|
router := &Router{parent: r, strictSlash: r.strictSlash}
|
||||||
|
r.addMatcher(router)
|
||||||
|
return router
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// URL building
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// URL builds a URL for the route.
|
||||||
|
//
|
||||||
|
// It accepts a sequence of key/value pairs for the route variables. For
|
||||||
|
// example, given this route:
|
||||||
|
//
|
||||||
|
// r := mux.NewRouter()
|
||||||
|
// r.HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler).
|
||||||
|
// Name("article")
|
||||||
|
//
|
||||||
|
// ...a URL for it can be built using:
|
||||||
|
//
|
||||||
|
// url, err := r.Get("article").URL("category", "technology", "id", "42")
|
||||||
|
//
|
||||||
|
// ...which will return an url.URL with the following path:
|
||||||
|
//
|
||||||
|
// "/articles/technology/42"
|
||||||
|
//
|
||||||
|
// This also works for host variables:
|
||||||
|
//
|
||||||
|
// r := mux.NewRouter()
|
||||||
|
// r.Host("{subdomain}.domain.com").
|
||||||
|
// HandleFunc("/articles/{category}/{id:[0-9]+}", ArticleHandler).
|
||||||
|
// Name("article")
|
||||||
|
//
|
||||||
|
// // url.String() will be "http://news.domain.com/articles/technology/42"
|
||||||
|
// url, err := r.Get("article").URL("subdomain", "news",
|
||||||
|
// "category", "technology",
|
||||||
|
// "id", "42")
|
||||||
|
//
|
||||||
|
// All variables defined in the route are required, and their values must
|
||||||
|
// conform to the corresponding patterns.
|
||||||
|
func (r *Route) URL(pairs ...string) (*url.URL, error) {
|
||||||
|
if r.err != nil {
|
||||||
|
return nil, r.err
|
||||||
|
}
|
||||||
|
if r.regexp == nil {
|
||||||
|
return nil, errors.New("mux: route doesn't have a host or path")
|
||||||
|
}
|
||||||
|
values, err := r.prepareVars(pairs...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var scheme, host, path string
|
||||||
|
if r.regexp.host != nil {
|
||||||
|
// Set a default scheme.
|
||||||
|
scheme = "http"
|
||||||
|
if host, err = r.regexp.host.url(values); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if r.regexp.path != nil {
|
||||||
|
if path, err = r.regexp.path.url(values); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &url.URL{
|
||||||
|
Scheme: scheme,
|
||||||
|
Host: host,
|
||||||
|
Path: path,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// URLHost builds the host part of the URL for a route. See Route.URL().
|
||||||
|
//
|
||||||
|
// The route must have a host defined.
|
||||||
|
func (r *Route) URLHost(pairs ...string) (*url.URL, error) {
|
||||||
|
if r.err != nil {
|
||||||
|
return nil, r.err
|
||||||
|
}
|
||||||
|
if r.regexp == nil || r.regexp.host == nil {
|
||||||
|
return nil, errors.New("mux: route doesn't have a host")
|
||||||
|
}
|
||||||
|
values, err := r.prepareVars(pairs...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
host, err := r.regexp.host.url(values)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &url.URL{
|
||||||
|
Scheme: "http",
|
||||||
|
Host: host,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// URLPath builds the path part of the URL for a route. See Route.URL().
|
||||||
|
//
|
||||||
|
// The route must have a path defined.
|
||||||
|
func (r *Route) URLPath(pairs ...string) (*url.URL, error) {
|
||||||
|
if r.err != nil {
|
||||||
|
return nil, r.err
|
||||||
|
}
|
||||||
|
if r.regexp == nil || r.regexp.path == nil {
|
||||||
|
return nil, errors.New("mux: route doesn't have a path")
|
||||||
|
}
|
||||||
|
values, err := r.prepareVars(pairs...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
path, err := r.regexp.path.url(values)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &url.URL{
|
||||||
|
Path: path,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPathTemplate returns the template used to build the
|
||||||
|
// route match.
|
||||||
|
// This is useful for building simple REST API documentation and for instrumentation
|
||||||
|
// against third-party services.
|
||||||
|
// An error will be returned if the route does not define a path.
|
||||||
|
func (r *Route) GetPathTemplate() (string, error) {
|
||||||
|
if r.err != nil {
|
||||||
|
return "", r.err
|
||||||
|
}
|
||||||
|
if r.regexp == nil || r.regexp.path == nil {
|
||||||
|
return "", errors.New("mux: route doesn't have a path")
|
||||||
|
}
|
||||||
|
return r.regexp.path.template, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHostTemplate returns the template used to build the
|
||||||
|
// route match.
|
||||||
|
// This is useful for building simple REST API documentation and for instrumentation
|
||||||
|
// against third-party services.
|
||||||
|
// An error will be returned if the route does not define a host.
|
||||||
|
func (r *Route) GetHostTemplate() (string, error) {
|
||||||
|
if r.err != nil {
|
||||||
|
return "", r.err
|
||||||
|
}
|
||||||
|
if r.regexp == nil || r.regexp.host == nil {
|
||||||
|
return "", errors.New("mux: route doesn't have a host")
|
||||||
|
}
|
||||||
|
return r.regexp.host.template, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// prepareVars converts the route variable pairs into a map. If the route has a
|
||||||
|
// BuildVarsFunc, it is invoked.
|
||||||
|
func (r *Route) prepareVars(pairs ...string) (map[string]string, error) {
|
||||||
|
m, err := mapFromPairsToString(pairs...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return r.buildVars(m), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Route) buildVars(m map[string]string) map[string]string {
|
||||||
|
if r.parent != nil {
|
||||||
|
m = r.parent.buildVars(m)
|
||||||
|
}
|
||||||
|
if r.buildVarsFunc != nil {
|
||||||
|
m = r.buildVarsFunc(m)
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// parentRoute
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
// parentRoute allows routes to know about parent host and path definitions.
|
||||||
|
type parentRoute interface {
|
||||||
|
getNamedRoutes() map[string]*Route
|
||||||
|
getRegexpGroup() *routeRegexpGroup
|
||||||
|
buildVars(map[string]string) map[string]string
|
||||||
|
}
|
||||||
|
|
||||||
|
// getNamedRoutes returns the map where named routes are registered.
|
||||||
|
func (r *Route) getNamedRoutes() map[string]*Route {
|
||||||
|
if r.parent == nil {
|
||||||
|
// During tests router is not always set.
|
||||||
|
r.parent = NewRouter()
|
||||||
|
}
|
||||||
|
return r.parent.getNamedRoutes()
|
||||||
|
}
|
||||||
|
|
||||||
|
// getRegexpGroup returns regexp definitions from this route.
|
||||||
|
func (r *Route) getRegexpGroup() *routeRegexpGroup {
|
||||||
|
if r.regexp == nil {
|
||||||
|
if r.parent == nil {
|
||||||
|
// During tests router is not always set.
|
||||||
|
r.parent = NewRouter()
|
||||||
|
}
|
||||||
|
regexp := r.parent.getRegexpGroup()
|
||||||
|
if regexp == nil {
|
||||||
|
r.regexp = new(routeRegexpGroup)
|
||||||
|
} else {
|
||||||
|
// Copy.
|
||||||
|
r.regexp = &routeRegexpGroup{
|
||||||
|
host: regexp.host,
|
||||||
|
path: regexp.path,
|
||||||
|
queries: regexp.queries,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return r.regexp
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
Copyright (c) 2012 Rodrigo Moraes. All rights reserved.
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions are
|
||||||
|
met:
|
||||||
|
|
||||||
|
* Redistributions of source code must retain the above copyright
|
||||||
|
notice, this list of conditions and the following disclaimer.
|
||||||
|
* Redistributions in binary form must reproduce the above
|
||||||
|
copyright notice, this list of conditions and the following disclaimer
|
||||||
|
in the documentation and/or other materials provided with the
|
||||||
|
distribution.
|
||||||
|
* Neither the name of Google Inc. nor the names of its
|
||||||
|
contributors may be used to endorse or promote products derived from
|
||||||
|
this software without specific prior written permission.
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||||
|
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||||
|
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||||
|
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||||
|
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||||
|
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||||
|
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||||
|
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||||
|
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@ -0,0 +1,78 @@
|
||||||
|
securecookie
|
||||||
|
============
|
||||||
|
[![GoDoc](https://godoc.org/github.com/gorilla/securecookie?status.svg)](https://godoc.org/github.com/gorilla/securecookie) [![Build Status](https://travis-ci.org/gorilla/securecookie.png?branch=master)](https://travis-ci.org/gorilla/securecookie)
|
||||||
|
|
||||||
|
securecookie encodes and decodes authenticated and optionally encrypted
|
||||||
|
cookie values.
|
||||||
|
|
||||||
|
Secure cookies can't be forged, because their values are validated using HMAC.
|
||||||
|
When encrypted, the content is also inaccessible to malicious eyes. It is still
|
||||||
|
recommended that sensitive data not be stored in cookies, and that HTTPS be used
|
||||||
|
to prevent cookie [replay attacks](https://en.wikipedia.org/wiki/Replay_attack).
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
|
||||||
|
To use it, first create a new SecureCookie instance:
|
||||||
|
|
||||||
|
```go
|
||||||
|
// Hash keys should be at least 32 bytes long
|
||||||
|
var hashKey = []byte("very-secret")
|
||||||
|
// Block keys should be 16 bytes (AES-128) or 32 bytes (AES-256) long.
|
||||||
|
// Shorter keys may weaken the encryption used.
|
||||||
|
var blockKey = []byte("a-lot-secret")
|
||||||
|
var s = securecookie.New(hashKey, blockKey)
|
||||||
|
```
|
||||||
|
|
||||||
|
The hashKey is required, used to authenticate the cookie value using HMAC.
|
||||||
|
It is recommended to use a key with 32 or 64 bytes.
|
||||||
|
|
||||||
|
The blockKey is optional, used to encrypt the cookie value -- set it to nil
|
||||||
|
to not use encryption. If set, the length must correspond to the block size
|
||||||
|
of the encryption algorithm. For AES, used by default, valid lengths are
|
||||||
|
16, 24, or 32 bytes to select AES-128, AES-192, or AES-256.
|
||||||
|
|
||||||
|
Strong keys can be created using the convenience function GenerateRandomKey().
|
||||||
|
|
||||||
|
Once a SecureCookie instance is set, use it to encode a cookie value:
|
||||||
|
|
||||||
|
```go
|
||||||
|
func SetCookieHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
value := map[string]string{
|
||||||
|
"foo": "bar",
|
||||||
|
}
|
||||||
|
if encoded, err := s.Encode("cookie-name", value); err == nil {
|
||||||
|
cookie := &http.Cookie{
|
||||||
|
Name: "cookie-name",
|
||||||
|
Value: encoded,
|
||||||
|
Path: "/",
|
||||||
|
Secure: true,
|
||||||
|
HttpOnly: true,
|
||||||
|
}
|
||||||
|
http.SetCookie(w, cookie)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Later, use the same SecureCookie instance to decode and validate a cookie
|
||||||
|
value:
|
||||||
|
|
||||||
|
```go
|
||||||
|
func ReadCookieHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if cookie, err := r.Cookie("cookie-name"); err == nil {
|
||||||
|
value := make(map[string]string)
|
||||||
|
if err = s2.Decode("cookie-name", cookie.Value, &value); err == nil {
|
||||||
|
fmt.Fprintf(w, "The value of foo is %q", value["foo"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
We stored a map[string]string, but secure cookies can hold any value that
|
||||||
|
can be encoded using `encoding/gob`. To store custom types, they must be
|
||||||
|
registered first using gob.Register(). For basic types this is not needed;
|
||||||
|
it works out of the box. An optional JSON encoder that uses `encoding/json` is
|
||||||
|
available for types compatible with JSON.
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
BSD licensed. See the LICENSE file for details.
|
|
@ -0,0 +1,61 @@
|
||||||
|
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
/*
|
||||||
|
Package securecookie encodes and decodes authenticated and optionally
|
||||||
|
encrypted cookie values.
|
||||||
|
|
||||||
|
Secure cookies can't be forged, because their values are validated using HMAC.
|
||||||
|
When encrypted, the content is also inaccessible to malicious eyes.
|
||||||
|
|
||||||
|
To use it, first create a new SecureCookie instance:
|
||||||
|
|
||||||
|
var hashKey = []byte("very-secret")
|
||||||
|
var blockKey = []byte("a-lot-secret")
|
||||||
|
var s = securecookie.New(hashKey, blockKey)
|
||||||
|
|
||||||
|
The hashKey is required, used to authenticate the cookie value using HMAC.
|
||||||
|
It is recommended to use a key with 32 or 64 bytes.
|
||||||
|
|
||||||
|
The blockKey is optional, used to encrypt the cookie value -- set it to nil
|
||||||
|
to not use encryption. If set, the length must correspond to the block size
|
||||||
|
of the encryption algorithm. For AES, used by default, valid lengths are
|
||||||
|
16, 24, or 32 bytes to select AES-128, AES-192, or AES-256.
|
||||||
|
|
||||||
|
Strong keys can be created using the convenience function GenerateRandomKey().
|
||||||
|
|
||||||
|
Once a SecureCookie instance is set, use it to encode a cookie value:
|
||||||
|
|
||||||
|
func SetCookieHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
value := map[string]string{
|
||||||
|
"foo": "bar",
|
||||||
|
}
|
||||||
|
if encoded, err := s.Encode("cookie-name", value); err == nil {
|
||||||
|
cookie := &http.Cookie{
|
||||||
|
Name: "cookie-name",
|
||||||
|
Value: encoded,
|
||||||
|
Path: "/",
|
||||||
|
}
|
||||||
|
http.SetCookie(w, cookie)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Later, use the same SecureCookie instance to decode and validate a cookie
|
||||||
|
value:
|
||||||
|
|
||||||
|
func ReadCookieHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if cookie, err := r.Cookie("cookie-name"); err == nil {
|
||||||
|
value := make(map[string]string)
|
||||||
|
if err = s2.Decode("cookie-name", cookie.Value, &value); err == nil {
|
||||||
|
fmt.Fprintf(w, "The value of foo is %q", value["foo"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
We stored a map[string]string, but secure cookies can hold any value that
|
||||||
|
can be encoded using encoding/gob. To store custom types, they must be
|
||||||
|
registered first using gob.Register(). For basic types this is not needed;
|
||||||
|
it works out of the box.
|
||||||
|
*/
|
||||||
|
package securecookie
|
|
@ -0,0 +1,25 @@
|
||||||
|
// +build gofuzz
|
||||||
|
|
||||||
|
package securecookie
|
||||||
|
|
||||||
|
var hashKey = []byte("very-secret12345")
|
||||||
|
var blockKey = []byte("a-lot-secret1234")
|
||||||
|
var s = New(hashKey, blockKey)
|
||||||
|
|
||||||
|
type Cookie struct {
|
||||||
|
B bool
|
||||||
|
I int
|
||||||
|
S string
|
||||||
|
}
|
||||||
|
|
||||||
|
func Fuzz(data []byte) int {
|
||||||
|
datas := string(data)
|
||||||
|
var c Cookie
|
||||||
|
if err := s.Decode("fuzz", datas, &c); err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if _, err := s.Encode("fuzz", c); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return 1
|
||||||
|
}
|
|
@ -0,0 +1,646 @@
|
||||||
|
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package securecookie
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/aes"
|
||||||
|
"crypto/cipher"
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"crypto/subtle"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/gob"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"hash"
|
||||||
|
"io"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Error is the interface of all errors returned by functions in this library.
|
||||||
|
type Error interface {
|
||||||
|
error
|
||||||
|
|
||||||
|
// IsUsage returns true for errors indicating the client code probably
|
||||||
|
// uses this library incorrectly. For example, the client may have
|
||||||
|
// failed to provide a valid hash key, or may have failed to configure
|
||||||
|
// the Serializer adequately for encoding value.
|
||||||
|
IsUsage() bool
|
||||||
|
|
||||||
|
// IsDecode returns true for errors indicating that a cookie could not
|
||||||
|
// be decoded and validated. Since cookies are usually untrusted
|
||||||
|
// user-provided input, errors of this type should be expected.
|
||||||
|
// Usually, the proper action is simply to reject the request.
|
||||||
|
IsDecode() bool
|
||||||
|
|
||||||
|
// IsInternal returns true for unexpected errors occurring in the
|
||||||
|
// securecookie implementation.
|
||||||
|
IsInternal() bool
|
||||||
|
|
||||||
|
// Cause, if it returns a non-nil value, indicates that this error was
|
||||||
|
// propagated from some underlying library. If this method returns nil,
|
||||||
|
// this error was raised directly by this library.
|
||||||
|
//
|
||||||
|
// Cause is provided principally for debugging/logging purposes; it is
|
||||||
|
// rare that application logic should perform meaningfully different
|
||||||
|
// logic based on Cause. See, for example, the caveats described on
|
||||||
|
// (MultiError).Cause().
|
||||||
|
Cause() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// errorType is a bitmask giving the error type(s) of an cookieError value.
|
||||||
|
type errorType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
usageError = errorType(1 << iota)
|
||||||
|
decodeError
|
||||||
|
internalError
|
||||||
|
)
|
||||||
|
|
||||||
|
type cookieError struct {
|
||||||
|
typ errorType
|
||||||
|
msg string
|
||||||
|
cause error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e cookieError) IsUsage() bool { return (e.typ & usageError) != 0 }
|
||||||
|
func (e cookieError) IsDecode() bool { return (e.typ & decodeError) != 0 }
|
||||||
|
func (e cookieError) IsInternal() bool { return (e.typ & internalError) != 0 }
|
||||||
|
|
||||||
|
func (e cookieError) Cause() error { return e.cause }
|
||||||
|
|
||||||
|
func (e cookieError) Error() string {
|
||||||
|
parts := []string{"securecookie: "}
|
||||||
|
if e.msg == "" {
|
||||||
|
parts = append(parts, "error")
|
||||||
|
} else {
|
||||||
|
parts = append(parts, e.msg)
|
||||||
|
}
|
||||||
|
if c := e.Cause(); c != nil {
|
||||||
|
parts = append(parts, " - caused by: ", c.Error())
|
||||||
|
}
|
||||||
|
return strings.Join(parts, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
errGeneratingIV = cookieError{typ: internalError, msg: "failed to generate random iv"}
|
||||||
|
|
||||||
|
errNoCodecs = cookieError{typ: usageError, msg: "no codecs provided"}
|
||||||
|
errHashKeyNotSet = cookieError{typ: usageError, msg: "hash key is not set"}
|
||||||
|
errBlockKeyNotSet = cookieError{typ: usageError, msg: "block key is not set"}
|
||||||
|
errEncodedValueTooLong = cookieError{typ: usageError, msg: "the value is too long"}
|
||||||
|
|
||||||
|
errValueToDecodeTooLong = cookieError{typ: decodeError, msg: "the value is too long"}
|
||||||
|
errTimestampInvalid = cookieError{typ: decodeError, msg: "invalid timestamp"}
|
||||||
|
errTimestampTooNew = cookieError{typ: decodeError, msg: "timestamp is too new"}
|
||||||
|
errTimestampExpired = cookieError{typ: decodeError, msg: "expired timestamp"}
|
||||||
|
errDecryptionFailed = cookieError{typ: decodeError, msg: "the value could not be decrypted"}
|
||||||
|
errValueNotByte = cookieError{typ: decodeError, msg: "value not a []byte."}
|
||||||
|
errValueNotBytePtr = cookieError{typ: decodeError, msg: "value not a pointer to []byte."}
|
||||||
|
|
||||||
|
// ErrMacInvalid indicates that cookie decoding failed because the HMAC
|
||||||
|
// could not be extracted and verified. Direct use of this error
|
||||||
|
// variable is deprecated; it is public only for legacy compatibility,
|
||||||
|
// and may be privatized in the future, as it is rarely useful to
|
||||||
|
// distinguish between this error and other Error implementations.
|
||||||
|
ErrMacInvalid = cookieError{typ: decodeError, msg: "the value is not valid"}
|
||||||
|
)
|
||||||
|
|
||||||
|
// Codec defines an interface to encode and decode cookie values.
|
||||||
|
type Codec interface {
|
||||||
|
Encode(name string, value interface{}) (string, error)
|
||||||
|
Decode(name, value string, dst interface{}) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// New returns a new SecureCookie.
|
||||||
|
//
|
||||||
|
// hashKey is required, used to authenticate values using HMAC. Create it using
|
||||||
|
// GenerateRandomKey(). It is recommended to use a key with 32 or 64 bytes.
|
||||||
|
//
|
||||||
|
// blockKey is optional, used to encrypt values. Create it using
|
||||||
|
// GenerateRandomKey(). The key length must correspond to the block size
|
||||||
|
// of the encryption algorithm. For AES, used by default, valid lengths are
|
||||||
|
// 16, 24, or 32 bytes to select AES-128, AES-192, or AES-256.
|
||||||
|
// The default encoder used for cookie serialization is encoding/gob.
|
||||||
|
//
|
||||||
|
// Note that keys created using GenerateRandomKey() are not automatically
|
||||||
|
// persisted. New keys will be created when the application is restarted, and
|
||||||
|
// previously issued cookies will not be able to be decoded.
|
||||||
|
func New(hashKey, blockKey []byte) *SecureCookie {
|
||||||
|
s := &SecureCookie{
|
||||||
|
hashKey: hashKey,
|
||||||
|
blockKey: blockKey,
|
||||||
|
hashFunc: sha256.New,
|
||||||
|
maxAge: 86400 * 30,
|
||||||
|
maxLength: 4096,
|
||||||
|
sz: GobEncoder{},
|
||||||
|
}
|
||||||
|
if hashKey == nil {
|
||||||
|
s.err = errHashKeyNotSet
|
||||||
|
}
|
||||||
|
if blockKey != nil {
|
||||||
|
s.BlockFunc(aes.NewCipher)
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// SecureCookie encodes and decodes authenticated and optionally encrypted
|
||||||
|
// cookie values.
|
||||||
|
type SecureCookie struct {
|
||||||
|
hashKey []byte
|
||||||
|
hashFunc func() hash.Hash
|
||||||
|
blockKey []byte
|
||||||
|
block cipher.Block
|
||||||
|
maxLength int
|
||||||
|
maxAge int64
|
||||||
|
minAge int64
|
||||||
|
err error
|
||||||
|
sz Serializer
|
||||||
|
// For testing purposes, the function that returns the current timestamp.
|
||||||
|
// If not set, it will use time.Now().UTC().Unix().
|
||||||
|
timeFunc func() int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serializer provides an interface for providing custom serializers for cookie
|
||||||
|
// values.
|
||||||
|
type Serializer interface {
|
||||||
|
Serialize(src interface{}) ([]byte, error)
|
||||||
|
Deserialize(src []byte, dst interface{}) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GobEncoder encodes cookie values using encoding/gob. This is the simplest
|
||||||
|
// encoder and can handle complex types via gob.Register.
|
||||||
|
type GobEncoder struct{}
|
||||||
|
|
||||||
|
// JSONEncoder encodes cookie values using encoding/json. Users who wish to
|
||||||
|
// encode complex types need to satisfy the json.Marshaller and
|
||||||
|
// json.Unmarshaller interfaces.
|
||||||
|
type JSONEncoder struct{}
|
||||||
|
|
||||||
|
// NopEncoder does not encode cookie values, and instead simply accepts a []byte
|
||||||
|
// (as an interface{}) and returns a []byte. This is particularly useful when
|
||||||
|
// you encoding an object upstream and do not wish to re-encode it.
|
||||||
|
type NopEncoder struct{}
|
||||||
|
|
||||||
|
// MaxLength restricts the maximum length, in bytes, for the cookie value.
|
||||||
|
//
|
||||||
|
// Default is 4096, which is the maximum value accepted by Internet Explorer.
|
||||||
|
func (s *SecureCookie) MaxLength(value int) *SecureCookie {
|
||||||
|
s.maxLength = value
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaxAge restricts the maximum age, in seconds, for the cookie value.
|
||||||
|
//
|
||||||
|
// Default is 86400 * 30. Set it to 0 for no restriction.
|
||||||
|
func (s *SecureCookie) MaxAge(value int) *SecureCookie {
|
||||||
|
s.maxAge = int64(value)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// MinAge restricts the minimum age, in seconds, for the cookie value.
|
||||||
|
//
|
||||||
|
// Default is 0 (no restriction).
|
||||||
|
func (s *SecureCookie) MinAge(value int) *SecureCookie {
|
||||||
|
s.minAge = int64(value)
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// HashFunc sets the hash function used to create HMAC.
|
||||||
|
//
|
||||||
|
// Default is crypto/sha256.New.
|
||||||
|
func (s *SecureCookie) HashFunc(f func() hash.Hash) *SecureCookie {
|
||||||
|
s.hashFunc = f
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// BlockFunc sets the encryption function used to create a cipher.Block.
|
||||||
|
//
|
||||||
|
// Default is crypto/aes.New.
|
||||||
|
func (s *SecureCookie) BlockFunc(f func([]byte) (cipher.Block, error)) *SecureCookie {
|
||||||
|
if s.blockKey == nil {
|
||||||
|
s.err = errBlockKeyNotSet
|
||||||
|
} else if block, err := f(s.blockKey); err == nil {
|
||||||
|
s.block = block
|
||||||
|
} else {
|
||||||
|
s.err = cookieError{cause: err, typ: usageError}
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encoding sets the encoding/serialization method for cookies.
|
||||||
|
//
|
||||||
|
// Default is encoding/gob. To encode special structures using encoding/gob,
|
||||||
|
// they must be registered first using gob.Register().
|
||||||
|
func (s *SecureCookie) SetSerializer(sz Serializer) *SecureCookie {
|
||||||
|
s.sz = sz
|
||||||
|
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode encodes a cookie value.
|
||||||
|
//
|
||||||
|
// It serializes, optionally encrypts, signs with a message authentication code,
|
||||||
|
// and finally encodes the value.
|
||||||
|
//
|
||||||
|
// The name argument is the cookie name. It is stored with the encoded value.
|
||||||
|
// The value argument is the value to be encoded. It can be any value that can
|
||||||
|
// be encoded using the currently selected serializer; see SetSerializer().
|
||||||
|
//
|
||||||
|
// It is the client's responsibility to ensure that value, when encoded using
|
||||||
|
// the current serialization/encryption settings on s and then base64-encoded,
|
||||||
|
// is shorter than the maximum permissible length.
|
||||||
|
func (s *SecureCookie) Encode(name string, value interface{}) (string, error) {
|
||||||
|
if s.err != nil {
|
||||||
|
return "", s.err
|
||||||
|
}
|
||||||
|
if s.hashKey == nil {
|
||||||
|
s.err = errHashKeyNotSet
|
||||||
|
return "", s.err
|
||||||
|
}
|
||||||
|
var err error
|
||||||
|
var b []byte
|
||||||
|
// 1. Serialize.
|
||||||
|
if b, err = s.sz.Serialize(value); err != nil {
|
||||||
|
return "", cookieError{cause: err, typ: usageError}
|
||||||
|
}
|
||||||
|
// 2. Encrypt (optional).
|
||||||
|
if s.block != nil {
|
||||||
|
if b, err = encrypt(s.block, b); err != nil {
|
||||||
|
return "", cookieError{cause: err, typ: usageError}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
b = encode(b)
|
||||||
|
// 3. Create MAC for "name|date|value". Extra pipe to be used later.
|
||||||
|
b = []byte(fmt.Sprintf("%s|%d|%s|", name, s.timestamp(), b))
|
||||||
|
mac := createMac(hmac.New(s.hashFunc, s.hashKey), b[:len(b)-1])
|
||||||
|
// Append mac, remove name.
|
||||||
|
b = append(b, mac...)[len(name)+1:]
|
||||||
|
// 4. Encode to base64.
|
||||||
|
b = encode(b)
|
||||||
|
// 5. Check length.
|
||||||
|
if s.maxLength != 0 && len(b) > s.maxLength {
|
||||||
|
return "", errEncodedValueTooLong
|
||||||
|
}
|
||||||
|
// Done.
|
||||||
|
return string(b), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode decodes a cookie value.
|
||||||
|
//
|
||||||
|
// It decodes, verifies a message authentication code, optionally decrypts and
|
||||||
|
// finally deserializes the value.
|
||||||
|
//
|
||||||
|
// The name argument is the cookie name. It must be the same name used when
|
||||||
|
// it was stored. The value argument is the encoded cookie value. The dst
|
||||||
|
// argument is where the cookie will be decoded. It must be a pointer.
|
||||||
|
func (s *SecureCookie) Decode(name, value string, dst interface{}) error {
|
||||||
|
if s.err != nil {
|
||||||
|
return s.err
|
||||||
|
}
|
||||||
|
if s.hashKey == nil {
|
||||||
|
s.err = errHashKeyNotSet
|
||||||
|
return s.err
|
||||||
|
}
|
||||||
|
// 1. Check length.
|
||||||
|
if s.maxLength != 0 && len(value) > s.maxLength {
|
||||||
|
return errValueToDecodeTooLong
|
||||||
|
}
|
||||||
|
// 2. Decode from base64.
|
||||||
|
b, err := decode([]byte(value))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// 3. Verify MAC. Value is "date|value|mac".
|
||||||
|
parts := bytes.SplitN(b, []byte("|"), 3)
|
||||||
|
if len(parts) != 3 {
|
||||||
|
return ErrMacInvalid
|
||||||
|
}
|
||||||
|
h := hmac.New(s.hashFunc, s.hashKey)
|
||||||
|
b = append([]byte(name+"|"), b[:len(b)-len(parts[2])-1]...)
|
||||||
|
if err = verifyMac(h, b, parts[2]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// 4. Verify date ranges.
|
||||||
|
var t1 int64
|
||||||
|
if t1, err = strconv.ParseInt(string(parts[0]), 10, 64); err != nil {
|
||||||
|
return errTimestampInvalid
|
||||||
|
}
|
||||||
|
t2 := s.timestamp()
|
||||||
|
if s.minAge != 0 && t1 > t2-s.minAge {
|
||||||
|
return errTimestampTooNew
|
||||||
|
}
|
||||||
|
if s.maxAge != 0 && t1 < t2-s.maxAge {
|
||||||
|
return errTimestampExpired
|
||||||
|
}
|
||||||
|
// 5. Decrypt (optional).
|
||||||
|
b, err = decode(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if s.block != nil {
|
||||||
|
if b, err = decrypt(s.block, b); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 6. Deserialize.
|
||||||
|
if err = s.sz.Deserialize(b, dst); err != nil {
|
||||||
|
return cookieError{cause: err, typ: decodeError}
|
||||||
|
}
|
||||||
|
// Done.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// timestamp returns the current timestamp, in seconds.
|
||||||
|
//
|
||||||
|
// For testing purposes, the function that generates the timestamp can be
|
||||||
|
// overridden. If not set, it will return time.Now().UTC().Unix().
|
||||||
|
func (s *SecureCookie) timestamp() int64 {
|
||||||
|
if s.timeFunc == nil {
|
||||||
|
return time.Now().UTC().Unix()
|
||||||
|
}
|
||||||
|
return s.timeFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authentication -------------------------------------------------------------
|
||||||
|
|
||||||
|
// createMac creates a message authentication code (MAC).
|
||||||
|
func createMac(h hash.Hash, value []byte) []byte {
|
||||||
|
h.Write(value)
|
||||||
|
return h.Sum(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// verifyMac verifies that a message authentication code (MAC) is valid.
|
||||||
|
func verifyMac(h hash.Hash, value []byte, mac []byte) error {
|
||||||
|
mac2 := createMac(h, value)
|
||||||
|
// Check that both MACs are of equal length, as subtle.ConstantTimeCompare
|
||||||
|
// does not do this prior to Go 1.4.
|
||||||
|
if len(mac) == len(mac2) && subtle.ConstantTimeCompare(mac, mac2) == 1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return ErrMacInvalid
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encryption -----------------------------------------------------------------
|
||||||
|
|
||||||
|
// encrypt encrypts a value using the given block in counter mode.
|
||||||
|
//
|
||||||
|
// A random initialization vector (http://goo.gl/zF67k) with the length of the
|
||||||
|
// block size is prepended to the resulting ciphertext.
|
||||||
|
func encrypt(block cipher.Block, value []byte) ([]byte, error) {
|
||||||
|
iv := GenerateRandomKey(block.BlockSize())
|
||||||
|
if iv == nil {
|
||||||
|
return nil, errGeneratingIV
|
||||||
|
}
|
||||||
|
// Encrypt it.
|
||||||
|
stream := cipher.NewCTR(block, iv)
|
||||||
|
stream.XORKeyStream(value, value)
|
||||||
|
// Return iv + ciphertext.
|
||||||
|
return append(iv, value...), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// decrypt decrypts a value using the given block in counter mode.
|
||||||
|
//
|
||||||
|
// The value to be decrypted must be prepended by a initialization vector
|
||||||
|
// (http://goo.gl/zF67k) with the length of the block size.
|
||||||
|
func decrypt(block cipher.Block, value []byte) ([]byte, error) {
|
||||||
|
size := block.BlockSize()
|
||||||
|
if len(value) > size {
|
||||||
|
// Extract iv.
|
||||||
|
iv := value[:size]
|
||||||
|
// Extract ciphertext.
|
||||||
|
value = value[size:]
|
||||||
|
// Decrypt it.
|
||||||
|
stream := cipher.NewCTR(block, iv)
|
||||||
|
stream.XORKeyStream(value, value)
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
|
return nil, errDecryptionFailed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serialization --------------------------------------------------------------
|
||||||
|
|
||||||
|
// Serialize encodes a value using gob.
|
||||||
|
func (e GobEncoder) Serialize(src interface{}) ([]byte, error) {
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
enc := gob.NewEncoder(buf)
|
||||||
|
if err := enc.Encode(src); err != nil {
|
||||||
|
return nil, cookieError{cause: err, typ: usageError}
|
||||||
|
}
|
||||||
|
return buf.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deserialize decodes a value using gob.
|
||||||
|
func (e GobEncoder) Deserialize(src []byte, dst interface{}) error {
|
||||||
|
dec := gob.NewDecoder(bytes.NewBuffer(src))
|
||||||
|
if err := dec.Decode(dst); err != nil {
|
||||||
|
return cookieError{cause: err, typ: decodeError}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serialize encodes a value using encoding/json.
|
||||||
|
func (e JSONEncoder) Serialize(src interface{}) ([]byte, error) {
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
enc := json.NewEncoder(buf)
|
||||||
|
if err := enc.Encode(src); err != nil {
|
||||||
|
return nil, cookieError{cause: err, typ: usageError}
|
||||||
|
}
|
||||||
|
return buf.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deserialize decodes a value using encoding/json.
|
||||||
|
func (e JSONEncoder) Deserialize(src []byte, dst interface{}) error {
|
||||||
|
dec := json.NewDecoder(bytes.NewReader(src))
|
||||||
|
if err := dec.Decode(dst); err != nil {
|
||||||
|
return cookieError{cause: err, typ: decodeError}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serialize passes a []byte through as-is.
|
||||||
|
func (e NopEncoder) Serialize(src interface{}) ([]byte, error) {
|
||||||
|
if b, ok := src.([]byte); ok {
|
||||||
|
return b, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errValueNotByte
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deserialize passes a []byte through as-is.
|
||||||
|
func (e NopEncoder) Deserialize(src []byte, dst interface{}) error {
|
||||||
|
if dat, ok := dst.(*[]byte); ok {
|
||||||
|
*dat = src
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return errValueNotBytePtr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encoding -------------------------------------------------------------------
|
||||||
|
|
||||||
|
// encode encodes a value using base64.
|
||||||
|
func encode(value []byte) []byte {
|
||||||
|
encoded := make([]byte, base64.URLEncoding.EncodedLen(len(value)))
|
||||||
|
base64.URLEncoding.Encode(encoded, value)
|
||||||
|
return encoded
|
||||||
|
}
|
||||||
|
|
||||||
|
// decode decodes a cookie using base64.
|
||||||
|
func decode(value []byte) ([]byte, error) {
|
||||||
|
decoded := make([]byte, base64.URLEncoding.DecodedLen(len(value)))
|
||||||
|
b, err := base64.URLEncoding.Decode(decoded, value)
|
||||||
|
if err != nil {
|
||||||
|
return nil, cookieError{cause: err, typ: decodeError, msg: "base64 decode failed"}
|
||||||
|
}
|
||||||
|
return decoded[:b], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helpers --------------------------------------------------------------------
|
||||||
|
|
||||||
|
// GenerateRandomKey creates a random key with the given length in bytes.
|
||||||
|
// On failure, returns nil.
|
||||||
|
//
|
||||||
|
// Callers should explicitly check for the possibility of a nil return, treat
|
||||||
|
// it as a failure of the system random number generator, and not continue.
|
||||||
|
func GenerateRandomKey(length int) []byte {
|
||||||
|
k := make([]byte, length)
|
||||||
|
if _, err := io.ReadFull(rand.Reader, k); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return k
|
||||||
|
}
|
||||||
|
|
||||||
|
// CodecsFromPairs returns a slice of SecureCookie instances.
|
||||||
|
//
|
||||||
|
// It is a convenience function to create a list of codecs for key rotation. Note
|
||||||
|
// that the generated Codecs will have the default options applied: callers
|
||||||
|
// should iterate over each Codec and type-assert the underlying *SecureCookie to
|
||||||
|
// change these.
|
||||||
|
//
|
||||||
|
// Example:
|
||||||
|
//
|
||||||
|
// codecs := securecookie.CodecsFromPairs(
|
||||||
|
// []byte("new-hash-key"),
|
||||||
|
// []byte("new-block-key"),
|
||||||
|
// []byte("old-hash-key"),
|
||||||
|
// []byte("old-block-key"),
|
||||||
|
// )
|
||||||
|
//
|
||||||
|
// // Modify each instance.
|
||||||
|
// for _, s := range codecs {
|
||||||
|
// if cookie, ok := s.(*securecookie.SecureCookie); ok {
|
||||||
|
// cookie.MaxAge(86400 * 7)
|
||||||
|
// cookie.SetSerializer(securecookie.JSONEncoder{})
|
||||||
|
// cookie.HashFunc(sha512.New512_256)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
func CodecsFromPairs(keyPairs ...[]byte) []Codec {
|
||||||
|
codecs := make([]Codec, len(keyPairs)/2+len(keyPairs)%2)
|
||||||
|
for i := 0; i < len(keyPairs); i += 2 {
|
||||||
|
var blockKey []byte
|
||||||
|
if i+1 < len(keyPairs) {
|
||||||
|
blockKey = keyPairs[i+1]
|
||||||
|
}
|
||||||
|
codecs[i/2] = New(keyPairs[i], blockKey)
|
||||||
|
}
|
||||||
|
return codecs
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodeMulti encodes a cookie value using a group of codecs.
|
||||||
|
//
|
||||||
|
// The codecs are tried in order. Multiple codecs are accepted to allow
|
||||||
|
// key rotation.
|
||||||
|
//
|
||||||
|
// On error, may return a MultiError.
|
||||||
|
func EncodeMulti(name string, value interface{}, codecs ...Codec) (string, error) {
|
||||||
|
if len(codecs) == 0 {
|
||||||
|
return "", errNoCodecs
|
||||||
|
}
|
||||||
|
|
||||||
|
var errors MultiError
|
||||||
|
for _, codec := range codecs {
|
||||||
|
encoded, err := codec.Encode(name, value)
|
||||||
|
if err == nil {
|
||||||
|
return encoded, nil
|
||||||
|
}
|
||||||
|
errors = append(errors, err)
|
||||||
|
}
|
||||||
|
return "", errors
|
||||||
|
}
|
||||||
|
|
||||||
|
// DecodeMulti decodes a cookie value using a group of codecs.
|
||||||
|
//
|
||||||
|
// The codecs are tried in order. Multiple codecs are accepted to allow
|
||||||
|
// key rotation.
|
||||||
|
//
|
||||||
|
// On error, may return a MultiError.
|
||||||
|
func DecodeMulti(name string, value string, dst interface{}, codecs ...Codec) error {
|
||||||
|
if len(codecs) == 0 {
|
||||||
|
return errNoCodecs
|
||||||
|
}
|
||||||
|
|
||||||
|
var errors MultiError
|
||||||
|
for _, codec := range codecs {
|
||||||
|
err := codec.Decode(name, value, dst)
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
errors = append(errors, err)
|
||||||
|
}
|
||||||
|
return errors
|
||||||
|
}
|
||||||
|
|
||||||
|
// MultiError groups multiple errors.
|
||||||
|
type MultiError []error
|
||||||
|
|
||||||
|
func (m MultiError) IsUsage() bool { return m.any(func(e Error) bool { return e.IsUsage() }) }
|
||||||
|
func (m MultiError) IsDecode() bool { return m.any(func(e Error) bool { return e.IsDecode() }) }
|
||||||
|
func (m MultiError) IsInternal() bool { return m.any(func(e Error) bool { return e.IsInternal() }) }
|
||||||
|
|
||||||
|
// Cause returns nil for MultiError; there is no unique underlying cause in the
|
||||||
|
// general case.
|
||||||
|
//
|
||||||
|
// Note: we could conceivably return a non-nil Cause only when there is exactly
|
||||||
|
// one child error with a Cause. However, it would be brittle for client code
|
||||||
|
// to rely on the arity of causes inside a MultiError, so we have opted not to
|
||||||
|
// provide this functionality. Clients which really wish to access the Causes
|
||||||
|
// of the underlying errors are free to iterate through the errors themselves.
|
||||||
|
func (m MultiError) Cause() error { return nil }
|
||||||
|
|
||||||
|
func (m MultiError) Error() string {
|
||||||
|
s, n := "", 0
|
||||||
|
for _, e := range m {
|
||||||
|
if e != nil {
|
||||||
|
if n == 0 {
|
||||||
|
s = e.Error()
|
||||||
|
}
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
switch n {
|
||||||
|
case 0:
|
||||||
|
return "(0 errors)"
|
||||||
|
case 1:
|
||||||
|
return s
|
||||||
|
case 2:
|
||||||
|
return s + " (and 1 other error)"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s (and %d other errors)", s, n-1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// any returns true if any element of m is an Error for which pred returns true.
|
||||||
|
func (m MultiError) any(pred func(Error) bool) bool {
|
||||||
|
for _, e := range m {
|
||||||
|
if ourErr, ok := e.(Error); ok && pred(ourErr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
|
@ -0,0 +1,27 @@
|
||||||
|
Copyright (c) 2012 Rodrigo Moraes. All rights reserved.
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions are
|
||||||
|
met:
|
||||||
|
|
||||||
|
* Redistributions of source code must retain the above copyright
|
||||||
|
notice, this list of conditions and the following disclaimer.
|
||||||
|
* Redistributions in binary form must reproduce the above
|
||||||
|
copyright notice, this list of conditions and the following disclaimer
|
||||||
|
in the documentation and/or other materials provided with the
|
||||||
|
distribution.
|
||||||
|
* Neither the name of Google Inc. nor the names of its
|
||||||
|
contributors may be used to endorse or promote products derived from
|
||||||
|
this software without specific prior written permission.
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||||
|
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||||
|
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||||
|
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||||
|
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||||
|
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||||
|
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||||
|
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||||
|
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@ -0,0 +1,81 @@
|
||||||
|
sessions
|
||||||
|
========
|
||||||
|
[![GoDoc](https://godoc.org/github.com/gorilla/sessions?status.svg)](https://godoc.org/github.com/gorilla/sessions) [![Build Status](https://travis-ci.org/gorilla/sessions.png?branch=master)](https://travis-ci.org/gorilla/sessions)
|
||||||
|
|
||||||
|
gorilla/sessions provides cookie and filesystem sessions and infrastructure for
|
||||||
|
custom session backends.
|
||||||
|
|
||||||
|
The key features are:
|
||||||
|
|
||||||
|
* Simple API: use it as an easy way to set signed (and optionally
|
||||||
|
encrypted) cookies.
|
||||||
|
* Built-in backends to store sessions in cookies or the filesystem.
|
||||||
|
* Flash messages: session values that last until read.
|
||||||
|
* Convenient way to switch session persistency (aka "remember me") and set
|
||||||
|
other attributes.
|
||||||
|
* Mechanism to rotate authentication and encryption keys.
|
||||||
|
* Multiple sessions per request, even using different backends.
|
||||||
|
* Interfaces and infrastructure for custom session backends: sessions from
|
||||||
|
different stores can be retrieved and batch-saved using a common API.
|
||||||
|
|
||||||
|
Let's start with an example that shows the sessions API in a nutshell:
|
||||||
|
|
||||||
|
```go
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"github.com/gorilla/sessions"
|
||||||
|
)
|
||||||
|
|
||||||
|
var store = sessions.NewCookieStore([]byte("something-very-secret"))
|
||||||
|
|
||||||
|
func MyHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Get a session. We're ignoring the error resulted from decoding an
|
||||||
|
// existing session: Get() always returns a session, even if empty.
|
||||||
|
session, _ := store.Get(r, "session-name")
|
||||||
|
// Set some session values.
|
||||||
|
session.Values["foo"] = "bar"
|
||||||
|
session.Values[42] = 43
|
||||||
|
// Save it before we write to the response/return from the handler.
|
||||||
|
session.Save(r, w)
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
First we initialize a session store calling `NewCookieStore()` and passing a
|
||||||
|
secret key used to authenticate the session. Inside the handler, we call
|
||||||
|
`store.Get()` to retrieve an existing session or a new one. Then we set some
|
||||||
|
session values in session.Values, which is a `map[interface{}]interface{}`.
|
||||||
|
And finally we call `session.Save()` to save the session in the response.
|
||||||
|
|
||||||
|
Important Note: If you aren't using gorilla/mux, you need to wrap your handlers
|
||||||
|
with
|
||||||
|
[`context.ClearHandler`](http://www.gorillatoolkit.org/pkg/context#ClearHandler)
|
||||||
|
as or else you will leak memory! An easy way to do this is to wrap the top-level
|
||||||
|
mux when calling http.ListenAndServe:
|
||||||
|
|
||||||
|
More examples are available [on the Gorilla
|
||||||
|
website](http://www.gorillatoolkit.org/pkg/sessions).
|
||||||
|
|
||||||
|
## Store Implementations
|
||||||
|
|
||||||
|
Other implementations of the `sessions.Store` interface:
|
||||||
|
|
||||||
|
* [github.com/starJammer/gorilla-sessions-arangodb](https://github.com/starJammer/gorilla-sessions-arangodb) - ArangoDB
|
||||||
|
* [github.com/yosssi/boltstore](https://github.com/yosssi/boltstore) - Bolt
|
||||||
|
* [github.com/srinathgs/couchbasestore](https://github.com/srinathgs/couchbasestore) - Couchbase
|
||||||
|
* [github.com/denizeren/dynamostore](https://github.com/denizeren/dynamostore) - Dynamodb on AWS
|
||||||
|
* [github.com/bradleypeabody/gorilla-sessions-memcache](https://github.com/bradleypeabody/gorilla-sessions-memcache) - Memcache
|
||||||
|
* [github.com/dsoprea/go-appengine-sessioncascade](https://github.com/dsoprea/go-appengine-sessioncascade) - Memcache/Datastore/Context in AppEngine
|
||||||
|
* [github.com/kidstuff/mongostore](https://github.com/kidstuff/mongostore) - MongoDB
|
||||||
|
* [github.com/srinathgs/mysqlstore](https://github.com/srinathgs/mysqlstore) - MySQL
|
||||||
|
* [github.com/EnumApps/clustersqlstore](https://github.com/EnumApps/clustersqlstore) - MySQL Cluster
|
||||||
|
* [github.com/antonlindstrom/pgstore](https://github.com/antonlindstrom/pgstore) - PostgreSQL
|
||||||
|
* [github.com/boj/redistore](https://github.com/boj/redistore) - Redis
|
||||||
|
* [github.com/boj/rethinkstore](https://github.com/boj/rethinkstore) - RethinkDB
|
||||||
|
* [github.com/boj/riakstore](https://github.com/boj/riakstore) - Riak
|
||||||
|
* [github.com/michaeljs1990/sqlitestore](https://github.com/michaeljs1990/sqlitestore) - SQLite
|
||||||
|
* [github.com/wader/gormstore](https://github.com/wader/gormstore) - GORM (MySQL, PostgreSQL, SQLite)
|
||||||
|
* [github.com/gernest/qlstore](https://github.com/gernest/qlstore) - ql
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
BSD licensed. See the LICENSE file for details.
|
|
@ -0,0 +1,199 @@
|
||||||
|
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
/*
|
||||||
|
Package sessions provides cookie and filesystem sessions and
|
||||||
|
infrastructure for custom session backends.
|
||||||
|
|
||||||
|
The key features are:
|
||||||
|
|
||||||
|
* Simple API: use it as an easy way to set signed (and optionally
|
||||||
|
encrypted) cookies.
|
||||||
|
* Built-in backends to store sessions in cookies or the filesystem.
|
||||||
|
* Flash messages: session values that last until read.
|
||||||
|
* Convenient way to switch session persistency (aka "remember me") and set
|
||||||
|
other attributes.
|
||||||
|
* Mechanism to rotate authentication and encryption keys.
|
||||||
|
* Multiple sessions per request, even using different backends.
|
||||||
|
* Interfaces and infrastructure for custom session backends: sessions from
|
||||||
|
different stores can be retrieved and batch-saved using a common API.
|
||||||
|
|
||||||
|
Let's start with an example that shows the sessions API in a nutshell:
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"github.com/gorilla/sessions"
|
||||||
|
)
|
||||||
|
|
||||||
|
var store = sessions.NewCookieStore([]byte("something-very-secret"))
|
||||||
|
|
||||||
|
func MyHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Get a session. We're ignoring the error resulted from decoding an
|
||||||
|
// existing session: Get() always returns a session, even if empty.
|
||||||
|
session, err := store.Get(r, "session-name")
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set some session values.
|
||||||
|
session.Values["foo"] = "bar"
|
||||||
|
session.Values[42] = 43
|
||||||
|
// Save it before we write to the response/return from the handler.
|
||||||
|
session.Save(r, w)
|
||||||
|
}
|
||||||
|
|
||||||
|
First we initialize a session store calling NewCookieStore() and passing a
|
||||||
|
secret key used to authenticate the session. Inside the handler, we call
|
||||||
|
store.Get() to retrieve an existing session or a new one. Then we set some
|
||||||
|
session values in session.Values, which is a map[interface{}]interface{}.
|
||||||
|
And finally we call session.Save() to save the session in the response.
|
||||||
|
|
||||||
|
Note that in production code, we should check for errors when calling
|
||||||
|
session.Save(r, w), and either display an error message or otherwise handle it.
|
||||||
|
|
||||||
|
Save must be called before writing to the response, otherwise the session
|
||||||
|
cookie will not be sent to the client.
|
||||||
|
|
||||||
|
Important Note: If you aren't using gorilla/mux, you need to wrap your handlers
|
||||||
|
with context.ClearHandler as or else you will leak memory! An easy way to do this
|
||||||
|
is to wrap the top-level mux when calling http.ListenAndServe:
|
||||||
|
|
||||||
|
http.ListenAndServe(":8080", context.ClearHandler(http.DefaultServeMux))
|
||||||
|
|
||||||
|
The ClearHandler function is provided by the gorilla/context package.
|
||||||
|
|
||||||
|
That's all you need to know for the basic usage. Let's take a look at other
|
||||||
|
options, starting with flash messages.
|
||||||
|
|
||||||
|
Flash messages are session values that last until read. The term appeared with
|
||||||
|
Ruby On Rails a few years back. When we request a flash message, it is removed
|
||||||
|
from the session. To add a flash, call session.AddFlash(), and to get all
|
||||||
|
flashes, call session.Flashes(). Here is an example:
|
||||||
|
|
||||||
|
func MyHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Get a session.
|
||||||
|
session, err := store.Get(r, "session-name")
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the previously flashes, if any.
|
||||||
|
if flashes := session.Flashes(); len(flashes) > 0 {
|
||||||
|
// Use the flash values.
|
||||||
|
} else {
|
||||||
|
// Set a new flash.
|
||||||
|
session.AddFlash("Hello, flash messages world!")
|
||||||
|
}
|
||||||
|
session.Save(r, w)
|
||||||
|
}
|
||||||
|
|
||||||
|
Flash messages are useful to set information to be read after a redirection,
|
||||||
|
like after form submissions.
|
||||||
|
|
||||||
|
There may also be cases where you want to store a complex datatype within a
|
||||||
|
session, such as a struct. Sessions are serialised using the encoding/gob package,
|
||||||
|
so it is easy to register new datatypes for storage in sessions:
|
||||||
|
|
||||||
|
import(
|
||||||
|
"encoding/gob"
|
||||||
|
"github.com/gorilla/sessions"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Person struct {
|
||||||
|
FirstName string
|
||||||
|
LastName string
|
||||||
|
Email string
|
||||||
|
Age int
|
||||||
|
}
|
||||||
|
|
||||||
|
type M map[string]interface{}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
|
||||||
|
gob.Register(&Person{})
|
||||||
|
gob.Register(&M{})
|
||||||
|
}
|
||||||
|
|
||||||
|
As it's not possible to pass a raw type as a parameter to a function, gob.Register()
|
||||||
|
relies on us passing it a value of the desired type. In the example above we've passed
|
||||||
|
it a pointer to a struct and a pointer to a custom type representing a
|
||||||
|
map[string]interface. (We could have passed non-pointer values if we wished.) This will
|
||||||
|
then allow us to serialise/deserialise values of those types to and from our sessions.
|
||||||
|
|
||||||
|
Note that because session values are stored in a map[string]interface{}, there's
|
||||||
|
a need to type-assert data when retrieving it. We'll use the Person struct we registered above:
|
||||||
|
|
||||||
|
func MyHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
session, err := store.Get(r, "session-name")
|
||||||
|
if err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieve our struct and type-assert it
|
||||||
|
val := session.Values["person"]
|
||||||
|
var person = &Person{}
|
||||||
|
if person, ok := val.(*Person); !ok {
|
||||||
|
// Handle the case that it's not an expected type
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now we can use our person object
|
||||||
|
}
|
||||||
|
|
||||||
|
By default, session cookies last for a month. This is probably too long for
|
||||||
|
some cases, but it is easy to change this and other attributes during
|
||||||
|
runtime. Sessions can be configured individually or the store can be
|
||||||
|
configured and then all sessions saved using it will use that configuration.
|
||||||
|
We access session.Options or store.Options to set a new configuration. The
|
||||||
|
fields are basically a subset of http.Cookie fields. Let's change the
|
||||||
|
maximum age of a session to one week:
|
||||||
|
|
||||||
|
session.Options = &sessions.Options{
|
||||||
|
Path: "/",
|
||||||
|
MaxAge: 86400 * 7,
|
||||||
|
HttpOnly: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
Sometimes we may want to change authentication and/or encryption keys without
|
||||||
|
breaking existing sessions. The CookieStore supports key rotation, and to use
|
||||||
|
it you just need to set multiple authentication and encryption keys, in pairs,
|
||||||
|
to be tested in order:
|
||||||
|
|
||||||
|
var store = sessions.NewCookieStore(
|
||||||
|
[]byte("new-authentication-key"),
|
||||||
|
[]byte("new-encryption-key"),
|
||||||
|
[]byte("old-authentication-key"),
|
||||||
|
[]byte("old-encryption-key"),
|
||||||
|
)
|
||||||
|
|
||||||
|
New sessions will be saved using the first pair. Old sessions can still be
|
||||||
|
read because the first pair will fail, and the second will be tested. This
|
||||||
|
makes it easy to "rotate" secret keys and still be able to validate existing
|
||||||
|
sessions. Note: for all pairs the encryption key is optional; set it to nil
|
||||||
|
or omit it and and encryption won't be used.
|
||||||
|
|
||||||
|
Multiple sessions can be used in the same request, even with different
|
||||||
|
session backends. When this happens, calling Save() on each session
|
||||||
|
individually would be cumbersome, so we have a way to save all sessions
|
||||||
|
at once: it's sessions.Save(). Here's an example:
|
||||||
|
|
||||||
|
var store = sessions.NewCookieStore([]byte("something-very-secret"))
|
||||||
|
|
||||||
|
func MyHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Get a session and set a value.
|
||||||
|
session1, _ := store.Get(r, "session-one")
|
||||||
|
session1.Values["foo"] = "bar"
|
||||||
|
// Get another session and set another value.
|
||||||
|
session2, _ := store.Get(r, "session-two")
|
||||||
|
session2.Values[42] = 43
|
||||||
|
// Save all sessions.
|
||||||
|
sessions.Save(r, w)
|
||||||
|
}
|
||||||
|
|
||||||
|
This is possible because when we call Get() from a session store, it adds the
|
||||||
|
session to a common registry. Save() uses it to save all registered sessions.
|
||||||
|
*/
|
||||||
|
package sessions
|
|
@ -0,0 +1,102 @@
|
||||||
|
// This file contains code adapted from the Go standard library
|
||||||
|
// https://github.com/golang/go/blob/39ad0fd0789872f9469167be7fe9578625ff246e/src/net/http/lex.go
|
||||||
|
|
||||||
|
package sessions
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
var isTokenTable = [127]bool{
|
||||||
|
'!': true,
|
||||||
|
'#': true,
|
||||||
|
'$': true,
|
||||||
|
'%': true,
|
||||||
|
'&': true,
|
||||||
|
'\'': true,
|
||||||
|
'*': true,
|
||||||
|
'+': true,
|
||||||
|
'-': true,
|
||||||
|
'.': true,
|
||||||
|
'0': true,
|
||||||
|
'1': true,
|
||||||
|
'2': true,
|
||||||
|
'3': true,
|
||||||
|
'4': true,
|
||||||
|
'5': true,
|
||||||
|
'6': true,
|
||||||
|
'7': true,
|
||||||
|
'8': true,
|
||||||
|
'9': true,
|
||||||
|
'A': true,
|
||||||
|
'B': true,
|
||||||
|
'C': true,
|
||||||
|
'D': true,
|
||||||
|
'E': true,
|
||||||
|
'F': true,
|
||||||
|
'G': true,
|
||||||
|
'H': true,
|
||||||
|
'I': true,
|
||||||
|
'J': true,
|
||||||
|
'K': true,
|
||||||
|
'L': true,
|
||||||
|
'M': true,
|
||||||
|
'N': true,
|
||||||
|
'O': true,
|
||||||
|
'P': true,
|
||||||
|
'Q': true,
|
||||||
|
'R': true,
|
||||||
|
'S': true,
|
||||||
|
'T': true,
|
||||||
|
'U': true,
|
||||||
|
'W': true,
|
||||||
|
'V': true,
|
||||||
|
'X': true,
|
||||||
|
'Y': true,
|
||||||
|
'Z': true,
|
||||||
|
'^': true,
|
||||||
|
'_': true,
|
||||||
|
'`': true,
|
||||||
|
'a': true,
|
||||||
|
'b': true,
|
||||||
|
'c': true,
|
||||||
|
'd': true,
|
||||||
|
'e': true,
|
||||||
|
'f': true,
|
||||||
|
'g': true,
|
||||||
|
'h': true,
|
||||||
|
'i': true,
|
||||||
|
'j': true,
|
||||||
|
'k': true,
|
||||||
|
'l': true,
|
||||||
|
'm': true,
|
||||||
|
'n': true,
|
||||||
|
'o': true,
|
||||||
|
'p': true,
|
||||||
|
'q': true,
|
||||||
|
'r': true,
|
||||||
|
's': true,
|
||||||
|
't': true,
|
||||||
|
'u': true,
|
||||||
|
'v': true,
|
||||||
|
'w': true,
|
||||||
|
'x': true,
|
||||||
|
'y': true,
|
||||||
|
'z': true,
|
||||||
|
'|': true,
|
||||||
|
'~': true,
|
||||||
|
}
|
||||||
|
|
||||||
|
func isToken(r rune) bool {
|
||||||
|
i := int(r)
|
||||||
|
return i < len(isTokenTable) && isTokenTable[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
func isNotToken(r rune) bool {
|
||||||
|
return !isToken(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isCookieNameValid(raw string) bool {
|
||||||
|
if raw == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.IndexFunc(raw, isNotToken) < 0
|
||||||
|
}
|
|
@ -0,0 +1,241 @@
|
||||||
|
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package sessions
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/gob"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Default flashes key.
|
||||||
|
const flashesKey = "_flash"
|
||||||
|
|
||||||
|
// Options --------------------------------------------------------------------
|
||||||
|
|
||||||
|
// Options stores configuration for a session or session store.
|
||||||
|
//
|
||||||
|
// Fields are a subset of http.Cookie fields.
|
||||||
|
type Options struct {
|
||||||
|
Path string
|
||||||
|
Domain string
|
||||||
|
// MaxAge=0 means no 'Max-Age' attribute specified.
|
||||||
|
// MaxAge<0 means delete cookie now, equivalently 'Max-Age: 0'.
|
||||||
|
// MaxAge>0 means Max-Age attribute present and given in seconds.
|
||||||
|
MaxAge int
|
||||||
|
Secure bool
|
||||||
|
HttpOnly bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Session --------------------------------------------------------------------
|
||||||
|
|
||||||
|
// NewSession is called by session stores to create a new session instance.
|
||||||
|
func NewSession(store Store, name string) *Session {
|
||||||
|
return &Session{
|
||||||
|
Values: make(map[interface{}]interface{}),
|
||||||
|
store: store,
|
||||||
|
name: name,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Session stores the values and optional configuration for a session.
|
||||||
|
type Session struct {
|
||||||
|
// The ID of the session, generated by stores. It should not be used for
|
||||||
|
// user data.
|
||||||
|
ID string
|
||||||
|
// Values contains the user-data for the session.
|
||||||
|
Values map[interface{}]interface{}
|
||||||
|
Options *Options
|
||||||
|
IsNew bool
|
||||||
|
store Store
|
||||||
|
name string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flashes returns a slice of flash messages from the session.
|
||||||
|
//
|
||||||
|
// A single variadic argument is accepted, and it is optional: it defines
|
||||||
|
// the flash key. If not defined "_flash" is used by default.
|
||||||
|
func (s *Session) Flashes(vars ...string) []interface{} {
|
||||||
|
var flashes []interface{}
|
||||||
|
key := flashesKey
|
||||||
|
if len(vars) > 0 {
|
||||||
|
key = vars[0]
|
||||||
|
}
|
||||||
|
if v, ok := s.Values[key]; ok {
|
||||||
|
// Drop the flashes and return it.
|
||||||
|
delete(s.Values, key)
|
||||||
|
flashes = v.([]interface{})
|
||||||
|
}
|
||||||
|
return flashes
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddFlash adds a flash message to the session.
|
||||||
|
//
|
||||||
|
// A single variadic argument is accepted, and it is optional: it defines
|
||||||
|
// the flash key. If not defined "_flash" is used by default.
|
||||||
|
func (s *Session) AddFlash(value interface{}, vars ...string) {
|
||||||
|
key := flashesKey
|
||||||
|
if len(vars) > 0 {
|
||||||
|
key = vars[0]
|
||||||
|
}
|
||||||
|
var flashes []interface{}
|
||||||
|
if v, ok := s.Values[key]; ok {
|
||||||
|
flashes = v.([]interface{})
|
||||||
|
}
|
||||||
|
s.Values[key] = append(flashes, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save is a convenience method to save this session. It is the same as calling
|
||||||
|
// store.Save(request, response, session). You should call Save before writing to
|
||||||
|
// the response or returning from the handler.
|
||||||
|
func (s *Session) Save(r *http.Request, w http.ResponseWriter) error {
|
||||||
|
return s.store.Save(r, w, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name returns the name used to register the session.
|
||||||
|
func (s *Session) Name() string {
|
||||||
|
return s.name
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store returns the session store used to register the session.
|
||||||
|
func (s *Session) Store() Store {
|
||||||
|
return s.store
|
||||||
|
}
|
||||||
|
|
||||||
|
// Registry -------------------------------------------------------------------
|
||||||
|
|
||||||
|
// sessionInfo stores a session tracked by the registry.
|
||||||
|
type sessionInfo struct {
|
||||||
|
s *Session
|
||||||
|
e error
|
||||||
|
}
|
||||||
|
|
||||||
|
// contextKey is the type used to store the registry in the context.
|
||||||
|
type contextKey int
|
||||||
|
|
||||||
|
// registryKey is the key used to store the registry in the context.
|
||||||
|
const registryKey contextKey = 0
|
||||||
|
|
||||||
|
// GetRegistry returns a registry instance for the current request.
|
||||||
|
func GetRegistry(r *http.Request) *Registry {
|
||||||
|
registry := context.Get(r, registryKey)
|
||||||
|
if registry != nil {
|
||||||
|
return registry.(*Registry)
|
||||||
|
}
|
||||||
|
newRegistry := &Registry{
|
||||||
|
request: r,
|
||||||
|
sessions: make(map[string]sessionInfo),
|
||||||
|
}
|
||||||
|
context.Set(r, registryKey, newRegistry)
|
||||||
|
return newRegistry
|
||||||
|
}
|
||||||
|
|
||||||
|
// Registry stores sessions used during a request.
|
||||||
|
type Registry struct {
|
||||||
|
request *http.Request
|
||||||
|
sessions map[string]sessionInfo
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get registers and returns a session for the given name and session store.
|
||||||
|
//
|
||||||
|
// It returns a new session if there are no sessions registered for the name.
|
||||||
|
func (s *Registry) Get(store Store, name string) (session *Session, err error) {
|
||||||
|
if !isCookieNameValid(name) {
|
||||||
|
return nil, fmt.Errorf("sessions: invalid character in cookie name: %s", name)
|
||||||
|
}
|
||||||
|
if info, ok := s.sessions[name]; ok {
|
||||||
|
session, err = info.s, info.e
|
||||||
|
} else {
|
||||||
|
session, err = store.New(s.request, name)
|
||||||
|
session.name = name
|
||||||
|
s.sessions[name] = sessionInfo{s: session, e: err}
|
||||||
|
}
|
||||||
|
session.store = store
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save saves all sessions registered for the current request.
|
||||||
|
func (s *Registry) Save(w http.ResponseWriter) error {
|
||||||
|
var errMulti MultiError
|
||||||
|
for name, info := range s.sessions {
|
||||||
|
session := info.s
|
||||||
|
if session.store == nil {
|
||||||
|
errMulti = append(errMulti, fmt.Errorf(
|
||||||
|
"sessions: missing store for session %q", name))
|
||||||
|
} else if err := session.store.Save(s.request, w, session); err != nil {
|
||||||
|
errMulti = append(errMulti, fmt.Errorf(
|
||||||
|
"sessions: error saving session %q -- %v", name, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if errMulti != nil {
|
||||||
|
return errMulti
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helpers --------------------------------------------------------------------
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
gob.Register([]interface{}{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save saves all sessions used during the current request.
|
||||||
|
func Save(r *http.Request, w http.ResponseWriter) error {
|
||||||
|
return GetRegistry(r).Save(w)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCookie returns an http.Cookie with the options set. It also sets
|
||||||
|
// the Expires field calculated based on the MaxAge value, for Internet
|
||||||
|
// Explorer compatibility.
|
||||||
|
func NewCookie(name, value string, options *Options) *http.Cookie {
|
||||||
|
cookie := &http.Cookie{
|
||||||
|
Name: name,
|
||||||
|
Value: value,
|
||||||
|
Path: options.Path,
|
||||||
|
Domain: options.Domain,
|
||||||
|
MaxAge: options.MaxAge,
|
||||||
|
Secure: options.Secure,
|
||||||
|
HttpOnly: options.HttpOnly,
|
||||||
|
}
|
||||||
|
if options.MaxAge > 0 {
|
||||||
|
d := time.Duration(options.MaxAge) * time.Second
|
||||||
|
cookie.Expires = time.Now().Add(d)
|
||||||
|
} else if options.MaxAge < 0 {
|
||||||
|
// Set it to the past to expire now.
|
||||||
|
cookie.Expires = time.Unix(1, 0)
|
||||||
|
}
|
||||||
|
return cookie
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error ----------------------------------------------------------------------
|
||||||
|
|
||||||
|
// MultiError stores multiple errors.
|
||||||
|
//
|
||||||
|
// Borrowed from the App Engine SDK.
|
||||||
|
type MultiError []error
|
||||||
|
|
||||||
|
func (m MultiError) Error() string {
|
||||||
|
s, n := "", 0
|
||||||
|
for _, e := range m {
|
||||||
|
if e != nil {
|
||||||
|
if n == 0 {
|
||||||
|
s = e.Error()
|
||||||
|
}
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
switch n {
|
||||||
|
case 0:
|
||||||
|
return "(0 errors)"
|
||||||
|
case 1:
|
||||||
|
return s
|
||||||
|
case 2:
|
||||||
|
return s + " (and 1 other error)"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s (and %d other errors)", s, n-1)
|
||||||
|
}
|
|
@ -0,0 +1,295 @@
|
||||||
|
// Copyright 2012 The Gorilla Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package sessions
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base32"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/gorilla/securecookie"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Store is an interface for custom session stores.
|
||||||
|
//
|
||||||
|
// See CookieStore and FilesystemStore for examples.
|
||||||
|
type Store interface {
|
||||||
|
// Get should return a cached session.
|
||||||
|
Get(r *http.Request, name string) (*Session, error)
|
||||||
|
|
||||||
|
// New should create and return a new session.
|
||||||
|
//
|
||||||
|
// Note that New should never return a nil session, even in the case of
|
||||||
|
// an error if using the Registry infrastructure to cache the session.
|
||||||
|
New(r *http.Request, name string) (*Session, error)
|
||||||
|
|
||||||
|
// Save should persist session to the underlying store implementation.
|
||||||
|
Save(r *http.Request, w http.ResponseWriter, s *Session) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// CookieStore ----------------------------------------------------------------
|
||||||
|
|
||||||
|
// NewCookieStore returns a new CookieStore.
|
||||||
|
//
|
||||||
|
// Keys are defined in pairs to allow key rotation, but the common case is
|
||||||
|
// to set a single authentication key and optionally an encryption key.
|
||||||
|
//
|
||||||
|
// The first key in a pair is used for authentication and the second for
|
||||||
|
// encryption. The encryption key can be set to nil or omitted in the last
|
||||||
|
// pair, but the authentication key is required in all pairs.
|
||||||
|
//
|
||||||
|
// It is recommended to use an authentication key with 32 or 64 bytes.
|
||||||
|
// The encryption key, if set, must be either 16, 24, or 32 bytes to select
|
||||||
|
// AES-128, AES-192, or AES-256 modes.
|
||||||
|
//
|
||||||
|
// Use the convenience function securecookie.GenerateRandomKey() to create
|
||||||
|
// strong keys.
|
||||||
|
func NewCookieStore(keyPairs ...[]byte) *CookieStore {
|
||||||
|
cs := &CookieStore{
|
||||||
|
Codecs: securecookie.CodecsFromPairs(keyPairs...),
|
||||||
|
Options: &Options{
|
||||||
|
Path: "/",
|
||||||
|
MaxAge: 86400 * 30,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cs.MaxAge(cs.Options.MaxAge)
|
||||||
|
return cs
|
||||||
|
}
|
||||||
|
|
||||||
|
// CookieStore stores sessions using secure cookies.
|
||||||
|
type CookieStore struct {
|
||||||
|
Codecs []securecookie.Codec
|
||||||
|
Options *Options // default configuration
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns a session for the given name after adding it to the registry.
|
||||||
|
//
|
||||||
|
// It returns a new session if the sessions doesn't exist. Access IsNew on
|
||||||
|
// the session to check if it is an existing session or a new one.
|
||||||
|
//
|
||||||
|
// It returns a new session and an error if the session exists but could
|
||||||
|
// not be decoded.
|
||||||
|
func (s *CookieStore) Get(r *http.Request, name string) (*Session, error) {
|
||||||
|
return GetRegistry(r).Get(s, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// New returns a session for the given name without adding it to the registry.
|
||||||
|
//
|
||||||
|
// The difference between New() and Get() is that calling New() twice will
|
||||||
|
// decode the session data twice, while Get() registers and reuses the same
|
||||||
|
// decoded session after the first call.
|
||||||
|
func (s *CookieStore) New(r *http.Request, name string) (*Session, error) {
|
||||||
|
session := NewSession(s, name)
|
||||||
|
opts := *s.Options
|
||||||
|
session.Options = &opts
|
||||||
|
session.IsNew = true
|
||||||
|
var err error
|
||||||
|
if c, errCookie := r.Cookie(name); errCookie == nil {
|
||||||
|
err = securecookie.DecodeMulti(name, c.Value, &session.Values,
|
||||||
|
s.Codecs...)
|
||||||
|
if err == nil {
|
||||||
|
session.IsNew = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return session, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save adds a single session to the response.
|
||||||
|
func (s *CookieStore) Save(r *http.Request, w http.ResponseWriter,
|
||||||
|
session *Session) error {
|
||||||
|
encoded, err := securecookie.EncodeMulti(session.Name(), session.Values,
|
||||||
|
s.Codecs...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
http.SetCookie(w, NewCookie(session.Name(), encoded, session.Options))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaxAge sets the maximum age for the store and the underlying cookie
|
||||||
|
// implementation. Individual sessions can be deleted by setting Options.MaxAge
|
||||||
|
// = -1 for that session.
|
||||||
|
func (s *CookieStore) MaxAge(age int) {
|
||||||
|
s.Options.MaxAge = age
|
||||||
|
|
||||||
|
// Set the maxAge for each securecookie instance.
|
||||||
|
for _, codec := range s.Codecs {
|
||||||
|
if sc, ok := codec.(*securecookie.SecureCookie); ok {
|
||||||
|
sc.MaxAge(age)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FilesystemStore ------------------------------------------------------------
|
||||||
|
|
||||||
|
var fileMutex sync.RWMutex
|
||||||
|
|
||||||
|
// NewFilesystemStore returns a new FilesystemStore.
|
||||||
|
//
|
||||||
|
// The path argument is the directory where sessions will be saved. If empty
|
||||||
|
// it will use os.TempDir().
|
||||||
|
//
|
||||||
|
// See NewCookieStore() for a description of the other parameters.
|
||||||
|
func NewFilesystemStore(path string, keyPairs ...[]byte) *FilesystemStore {
|
||||||
|
if path == "" {
|
||||||
|
path = os.TempDir()
|
||||||
|
}
|
||||||
|
fs := &FilesystemStore{
|
||||||
|
Codecs: securecookie.CodecsFromPairs(keyPairs...),
|
||||||
|
Options: &Options{
|
||||||
|
Path: "/",
|
||||||
|
MaxAge: 86400 * 30,
|
||||||
|
},
|
||||||
|
path: path,
|
||||||
|
}
|
||||||
|
|
||||||
|
fs.MaxAge(fs.Options.MaxAge)
|
||||||
|
return fs
|
||||||
|
}
|
||||||
|
|
||||||
|
// FilesystemStore stores sessions in the filesystem.
|
||||||
|
//
|
||||||
|
// It also serves as a reference for custom stores.
|
||||||
|
//
|
||||||
|
// This store is still experimental and not well tested. Feedback is welcome.
|
||||||
|
type FilesystemStore struct {
|
||||||
|
Codecs []securecookie.Codec
|
||||||
|
Options *Options // default configuration
|
||||||
|
path string
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaxLength restricts the maximum length of new sessions to l.
|
||||||
|
// If l is 0 there is no limit to the size of a session, use with caution.
|
||||||
|
// The default for a new FilesystemStore is 4096.
|
||||||
|
func (s *FilesystemStore) MaxLength(l int) {
|
||||||
|
for _, c := range s.Codecs {
|
||||||
|
if codec, ok := c.(*securecookie.SecureCookie); ok {
|
||||||
|
codec.MaxLength(l)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns a session for the given name after adding it to the registry.
|
||||||
|
//
|
||||||
|
// See CookieStore.Get().
|
||||||
|
func (s *FilesystemStore) Get(r *http.Request, name string) (*Session, error) {
|
||||||
|
return GetRegistry(r).Get(s, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// New returns a session for the given name without adding it to the registry.
|
||||||
|
//
|
||||||
|
// See CookieStore.New().
|
||||||
|
func (s *FilesystemStore) New(r *http.Request, name string) (*Session, error) {
|
||||||
|
session := NewSession(s, name)
|
||||||
|
opts := *s.Options
|
||||||
|
session.Options = &opts
|
||||||
|
session.IsNew = true
|
||||||
|
var err error
|
||||||
|
if c, errCookie := r.Cookie(name); errCookie == nil {
|
||||||
|
err = securecookie.DecodeMulti(name, c.Value, &session.ID, s.Codecs...)
|
||||||
|
if err == nil {
|
||||||
|
err = s.load(session)
|
||||||
|
if err == nil {
|
||||||
|
session.IsNew = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return session, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save adds a single session to the response.
|
||||||
|
//
|
||||||
|
// If the Options.MaxAge of the session is <= 0 then the session file will be
|
||||||
|
// deleted from the store path. With this process it enforces the properly
|
||||||
|
// session cookie handling so no need to trust in the cookie management in the
|
||||||
|
// web browser.
|
||||||
|
func (s *FilesystemStore) Save(r *http.Request, w http.ResponseWriter,
|
||||||
|
session *Session) error {
|
||||||
|
// Delete if max-age is <= 0
|
||||||
|
if session.Options.MaxAge <= 0 {
|
||||||
|
if err := s.erase(session); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
http.SetCookie(w, NewCookie(session.Name(), "", session.Options))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if session.ID == "" {
|
||||||
|
// Because the ID is used in the filename, encode it to
|
||||||
|
// use alphanumeric characters only.
|
||||||
|
session.ID = strings.TrimRight(
|
||||||
|
base32.StdEncoding.EncodeToString(
|
||||||
|
securecookie.GenerateRandomKey(32)), "=")
|
||||||
|
}
|
||||||
|
if err := s.save(session); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
encoded, err := securecookie.EncodeMulti(session.Name(), session.ID,
|
||||||
|
s.Codecs...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
http.SetCookie(w, NewCookie(session.Name(), encoded, session.Options))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MaxAge sets the maximum age for the store and the underlying cookie
|
||||||
|
// implementation. Individual sessions can be deleted by setting Options.MaxAge
|
||||||
|
// = -1 for that session.
|
||||||
|
func (s *FilesystemStore) MaxAge(age int) {
|
||||||
|
s.Options.MaxAge = age
|
||||||
|
|
||||||
|
// Set the maxAge for each securecookie instance.
|
||||||
|
for _, codec := range s.Codecs {
|
||||||
|
if sc, ok := codec.(*securecookie.SecureCookie); ok {
|
||||||
|
sc.MaxAge(age)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// save writes encoded session.Values to a file.
|
||||||
|
func (s *FilesystemStore) save(session *Session) error {
|
||||||
|
encoded, err := securecookie.EncodeMulti(session.Name(), session.Values,
|
||||||
|
s.Codecs...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
filename := filepath.Join(s.path, "session_"+session.ID)
|
||||||
|
fileMutex.Lock()
|
||||||
|
defer fileMutex.Unlock()
|
||||||
|
return ioutil.WriteFile(filename, []byte(encoded), 0600)
|
||||||
|
}
|
||||||
|
|
||||||
|
// load reads a file and decodes its content into session.Values.
|
||||||
|
func (s *FilesystemStore) load(session *Session) error {
|
||||||
|
filename := filepath.Join(s.path, "session_"+session.ID)
|
||||||
|
fileMutex.RLock()
|
||||||
|
defer fileMutex.RUnlock()
|
||||||
|
fdata, err := ioutil.ReadFile(filename)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err = securecookie.DecodeMulti(session.Name(), string(fdata),
|
||||||
|
&session.Values, s.Codecs...); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// delete session file
|
||||||
|
func (s *FilesystemStore) erase(session *Session) error {
|
||||||
|
filename := filepath.Join(s.path, "session_"+session.ID)
|
||||||
|
|
||||||
|
fileMutex.RLock()
|
||||||
|
defer fileMutex.RUnlock()
|
||||||
|
|
||||||
|
err := os.Remove(filename)
|
||||||
|
return err
|
||||||
|
}
|
|
@ -0,0 +1,22 @@
|
||||||
|
Copyright (c) 2014 Mark Bates
|
||||||
|
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining
|
||||||
|
a copy of this software and associated documentation files (the
|
||||||
|
"Software"), to deal in the Software without restriction, including
|
||||||
|
without limitation the rights to use, copy, modify, merge, publish,
|
||||||
|
distribute, sublicense, and/or sell copies of the Software, and to
|
||||||
|
permit persons to whom the Software is furnished to do so, subject to
|
||||||
|
the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be
|
||||||
|
included in all copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||||
|
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||||
|
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||||
|
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
|
||||||
|
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||||
|
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
|
||||||
|
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
@ -0,0 +1,143 @@
|
||||||
|
# Goth: Multi-Provider Authentication for Go [![GoDoc](https://godoc.org/github.com/markbates/goth?status.svg)](https://godoc.org/github.com/markbates/goth) [![Build Status](https://travis-ci.org/markbates/goth.svg)](https://travis-ci.org/markbates/goth)
|
||||||
|
|
||||||
|
Package goth provides a simple, clean, and idiomatic way to write authentication
|
||||||
|
packages for Go web applications.
|
||||||
|
|
||||||
|
Unlike other similar packages, Goth, lets you write OAuth, OAuth2, or any other
|
||||||
|
protocol providers, as long as they implement the `Provider` and `Session` interfaces.
|
||||||
|
|
||||||
|
This package was inspired by [https://github.com/intridea/omniauth](https://github.com/intridea/omniauth).
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
```text
|
||||||
|
$ go get github.com/markbates/goth
|
||||||
|
```
|
||||||
|
|
||||||
|
## Supported Providers
|
||||||
|
|
||||||
|
* Amazon
|
||||||
|
* Auth0
|
||||||
|
* Bitbucket
|
||||||
|
* Box
|
||||||
|
* Cloud Foundry
|
||||||
|
* Dailymotion
|
||||||
|
* Deezer
|
||||||
|
* Digital Ocean
|
||||||
|
* Discord
|
||||||
|
* Dropbox
|
||||||
|
* Facebook
|
||||||
|
* Fitbit
|
||||||
|
* GitHub
|
||||||
|
* Gitlab
|
||||||
|
* Google+
|
||||||
|
* Heroku
|
||||||
|
* InfluxCloud
|
||||||
|
* Instagram
|
||||||
|
* Intercom
|
||||||
|
* Lastfm
|
||||||
|
* Linkedin
|
||||||
|
* Meetup
|
||||||
|
* OneDrive
|
||||||
|
* OpenID Connect (auto discovery)
|
||||||
|
* Paypal
|
||||||
|
* SalesForce
|
||||||
|
* Slack
|
||||||
|
* Soundcloud
|
||||||
|
* Spotify
|
||||||
|
* Steam
|
||||||
|
* Stripe
|
||||||
|
* Twitch
|
||||||
|
* Twitter
|
||||||
|
* Uber
|
||||||
|
* Wepay
|
||||||
|
* Yahoo
|
||||||
|
* Yammer
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
|
||||||
|
See the [examples](examples) folder for a working application that lets users authenticate
|
||||||
|
through Twitter, Facebook, Google Plus etc.
|
||||||
|
|
||||||
|
To run the example either clone the source from GitHub
|
||||||
|
|
||||||
|
```text
|
||||||
|
$ git clone git@github.com:markbates/goth.git
|
||||||
|
```
|
||||||
|
or use
|
||||||
|
```text
|
||||||
|
$ go get github.com/markbates/goth
|
||||||
|
```
|
||||||
|
```text
|
||||||
|
$ cd goth/examples
|
||||||
|
$ go get -v
|
||||||
|
$ go build
|
||||||
|
$ ./examples
|
||||||
|
```
|
||||||
|
|
||||||
|
Now open up your browser and go to [http://localhost:3000](http://localhost:3000) to see the example.
|
||||||
|
|
||||||
|
To actually use the different providers, please make sure you configure them given the system environments as defined in the examples/main.go file
|
||||||
|
|
||||||
|
## Issues
|
||||||
|
|
||||||
|
Issues always stand a significantly better chance of getting fixed if the are accompanied by a
|
||||||
|
pull request.
|
||||||
|
|
||||||
|
## Contributing
|
||||||
|
|
||||||
|
Would I love to see more providers? Certainly! Would you love to contribute one? Hopefully, yes!
|
||||||
|
|
||||||
|
1. Fork it
|
||||||
|
2. Create your feature branch (git checkout -b my-new-feature)
|
||||||
|
3. Write Tests!
|
||||||
|
4. Commit your changes (git commit -am 'Add some feature')
|
||||||
|
5. Push to the branch (git push origin my-new-feature)
|
||||||
|
6. Create new Pull Request
|
||||||
|
|
||||||
|
## Contributors
|
||||||
|
|
||||||
|
* Mark Bates
|
||||||
|
* Tyler Bunnell
|
||||||
|
* Corey McGrillis
|
||||||
|
* willemvd
|
||||||
|
* Rakesh Goyal
|
||||||
|
* Andy Grunwald
|
||||||
|
* Glenn Walker
|
||||||
|
* Kevin Fitzpatrick
|
||||||
|
* Ben Tranter
|
||||||
|
* Sharad Ganapathy
|
||||||
|
* Andrew Chilton
|
||||||
|
* sharadgana
|
||||||
|
* Aurorae
|
||||||
|
* Craig P Jolicoeur
|
||||||
|
* Zac Bergquist
|
||||||
|
* Geoff Franks
|
||||||
|
* Raphael Geronimi
|
||||||
|
* Noah Shibley
|
||||||
|
* lumost
|
||||||
|
* oov
|
||||||
|
* Felix Lamouroux
|
||||||
|
* Rafael Quintela
|
||||||
|
* Tyler
|
||||||
|
* DenSm
|
||||||
|
* Samy KACIMI
|
||||||
|
* dante gray
|
||||||
|
* Noah
|
||||||
|
* Jacob Walker
|
||||||
|
* Marin Martinic
|
||||||
|
* Roy
|
||||||
|
* Omni Adams
|
||||||
|
* Sasa Brankovic
|
||||||
|
* dkhamsing
|
||||||
|
* Dante Swift
|
||||||
|
* Attila Domokos
|
||||||
|
* Albin Gilles
|
||||||
|
* Syed Zubairuddin
|
||||||
|
* Johnny Boursiquot
|
||||||
|
* Jerome Touffe-Blin
|
||||||
|
* bryanl
|
||||||
|
* Masanobu YOSHIOKA
|
||||||
|
* Jonathan Hall
|
||||||
|
* HaiMing.Yin
|
||||||
|
* Sairam Kunala
|
|
@ -0,0 +1,10 @@
|
||||||
|
/*
|
||||||
|
Package goth provides a simple, clean, and idiomatic way to write authentication
|
||||||
|
packages for Go web applications.
|
||||||
|
|
||||||
|
This package was inspired by https://github.com/intridea/omniauth.
|
||||||
|
|
||||||
|
See the examples folder for a working application that lets users authenticate
|
||||||
|
through Twitter or Facebook.
|
||||||
|
*/
|
||||||
|
package goth
|
|
@ -0,0 +1,219 @@
|
||||||
|
/*
|
||||||
|
Package gothic wraps common behaviour when using Goth. This makes it quick, and easy, to get up
|
||||||
|
and running with Goth. Of course, if you want complete control over how things flow, in regards
|
||||||
|
to the authentication process, feel free and use Goth directly.
|
||||||
|
|
||||||
|
See https://github.com/markbates/goth/examples/main.go to see this in action.
|
||||||
|
*/
|
||||||
|
package gothic
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/gorilla/sessions"
|
||||||
|
"github.com/markbates/goth"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SessionName is the key used to access the session store.
|
||||||
|
const SessionName = "_gothic_session"
|
||||||
|
|
||||||
|
// Store can/should be set by applications using gothic. The default is a cookie store.
|
||||||
|
var Store sessions.Store
|
||||||
|
var defaultStore sessions.Store
|
||||||
|
|
||||||
|
var keySet = false
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
key := []byte(os.Getenv("SESSION_SECRET"))
|
||||||
|
keySet = len(key) != 0
|
||||||
|
Store = sessions.NewCookieStore([]byte(key))
|
||||||
|
defaultStore = Store
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
BeginAuthHandler is a convienence handler for starting the authentication process.
|
||||||
|
It expects to be able to get the name of the provider from the query parameters
|
||||||
|
as either "provider" or ":provider".
|
||||||
|
|
||||||
|
BeginAuthHandler will redirect the user to the appropriate authentication end-point
|
||||||
|
for the requested provider.
|
||||||
|
|
||||||
|
See https://github.com/markbates/goth/examples/main.go to see this in action.
|
||||||
|
*/
|
||||||
|
func BeginAuthHandler(res http.ResponseWriter, req *http.Request) {
|
||||||
|
url, err := GetAuthURL(res, req)
|
||||||
|
if err != nil {
|
||||||
|
res.WriteHeader(http.StatusBadRequest)
|
||||||
|
fmt.Fprintln(res, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
http.Redirect(res, req, url, http.StatusTemporaryRedirect)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetState sets the state string associated with the given request.
|
||||||
|
// If no state string is associated with the request, one will be generated.
|
||||||
|
// This state is sent to the provider and can be retrieved during the
|
||||||
|
// callback.
|
||||||
|
var SetState = func(req *http.Request) string {
|
||||||
|
state := req.URL.Query().Get("state")
|
||||||
|
if len(state) > 0 {
|
||||||
|
return state
|
||||||
|
}
|
||||||
|
|
||||||
|
return "state"
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetState gets the state returned by the provider during the callback.
|
||||||
|
// This is used to prevent CSRF attacks, see
|
||||||
|
// http://tools.ietf.org/html/rfc6749#section-10.12
|
||||||
|
var GetState = func(req *http.Request) string {
|
||||||
|
return req.URL.Query().Get("state")
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
GetAuthURL starts the authentication process with the requested provided.
|
||||||
|
It will return a URL that should be used to send users to.
|
||||||
|
|
||||||
|
It expects to be able to get the name of the provider from the query parameters
|
||||||
|
as either "provider" or ":provider".
|
||||||
|
|
||||||
|
I would recommend using the BeginAuthHandler instead of doing all of these steps
|
||||||
|
yourself, but that's entirely up to you.
|
||||||
|
*/
|
||||||
|
func GetAuthURL(res http.ResponseWriter, req *http.Request) (string, error) {
|
||||||
|
|
||||||
|
if !keySet && defaultStore == Store {
|
||||||
|
fmt.Println("goth/gothic: no SESSION_SECRET environment variable is set. The default cookie store is not available and any calls will fail. Ignore this warning if you are using a different store.")
|
||||||
|
}
|
||||||
|
|
||||||
|
providerName, err := GetProviderName(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
provider, err := goth.GetProvider(providerName)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
sess, err := provider.BeginAuth(SetState(req))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
url, err := sess.GetAuthURL()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = storeInSession(providerName, sess.Marshal(), req, res)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return url, err
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
CompleteUserAuth does what it says on the tin. It completes the authentication
|
||||||
|
process and fetches all of the basic information about the user from the provider.
|
||||||
|
|
||||||
|
It expects to be able to get the name of the provider from the query parameters
|
||||||
|
as either "provider" or ":provider".
|
||||||
|
|
||||||
|
See https://github.com/markbates/goth/examples/main.go to see this in action.
|
||||||
|
*/
|
||||||
|
var CompleteUserAuth = func(res http.ResponseWriter, req *http.Request) (goth.User, error) {
|
||||||
|
|
||||||
|
if !keySet && defaultStore == Store {
|
||||||
|
fmt.Println("goth/gothic: no SESSION_SECRET environment variable is set. The default cookie store is not available and any calls will fail. Ignore this warning if you are using a different store.")
|
||||||
|
}
|
||||||
|
|
||||||
|
providerName, err := GetProviderName(req)
|
||||||
|
if err != nil {
|
||||||
|
return goth.User{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
provider, err := goth.GetProvider(providerName)
|
||||||
|
if err != nil {
|
||||||
|
return goth.User{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
value, err := getFromSession(providerName, req)
|
||||||
|
if err != nil {
|
||||||
|
return goth.User{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
sess, err := provider.UnmarshalSession(value)
|
||||||
|
if err != nil {
|
||||||
|
return goth.User{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := provider.FetchUser(sess)
|
||||||
|
if err == nil {
|
||||||
|
// user can be found with existing session data
|
||||||
|
return user, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// get new token and retry fetch
|
||||||
|
_, err = sess.Authorize(provider, req.URL.Query())
|
||||||
|
if err != nil {
|
||||||
|
return goth.User{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = storeInSession(providerName, sess.Marshal(), req, res)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return goth.User{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return provider.FetchUser(sess)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProviderName is a function used to get the name of a provider
|
||||||
|
// for a given request. By default, this provider is fetched from
|
||||||
|
// the URL query string. If you provide it in a different way,
|
||||||
|
// assign your own function to this variable that returns the provider
|
||||||
|
// name for your request.
|
||||||
|
var GetProviderName = getProviderName
|
||||||
|
|
||||||
|
func getProviderName(req *http.Request) (string, error) {
|
||||||
|
provider := req.URL.Query().Get("provider")
|
||||||
|
if provider == "" {
|
||||||
|
if p, ok := mux.Vars(req)["provider"]; ok {
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if provider == "" {
|
||||||
|
provider = req.URL.Query().Get(":provider")
|
||||||
|
}
|
||||||
|
if provider == "" {
|
||||||
|
return provider, errors.New("you must select a provider")
|
||||||
|
}
|
||||||
|
return provider, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func storeInSession(key string, value string, req *http.Request, res http.ResponseWriter) error {
|
||||||
|
session, _ := Store.Get(req, key + SessionName)
|
||||||
|
|
||||||
|
session.Values[key] = value
|
||||||
|
|
||||||
|
return session.Save(req, res)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getFromSession(key string, req *http.Request) (string, error) {
|
||||||
|
session, _ := Store.Get(req, key + SessionName)
|
||||||
|
|
||||||
|
value := session.Values[key]
|
||||||
|
if value == nil {
|
||||||
|
return "", errors.New("could not find a matching session for this request")
|
||||||
|
}
|
||||||
|
|
||||||
|
return value.(string), nil
|
||||||
|
}
|
|
@ -0,0 +1,75 @@
|
||||||
|
package goth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Provider needs to be implemented for each 3rd party authentication provider
|
||||||
|
// e.g. Facebook, Twitter, etc...
|
||||||
|
type Provider interface {
|
||||||
|
Name() string
|
||||||
|
SetName(name string)
|
||||||
|
BeginAuth(state string) (Session, error)
|
||||||
|
UnmarshalSession(string) (Session, error)
|
||||||
|
FetchUser(Session) (User, error)
|
||||||
|
Debug(bool)
|
||||||
|
RefreshToken(refreshToken string) (*oauth2.Token, error) //Get new access token based on the refresh token
|
||||||
|
RefreshTokenAvailable() bool //Refresh token is provided by auth provider or not
|
||||||
|
}
|
||||||
|
|
||||||
|
const NoAuthUrlErrorMessage = "an AuthURL has not been set"
|
||||||
|
|
||||||
|
// Providers is list of known/available providers.
|
||||||
|
type Providers map[string]Provider
|
||||||
|
|
||||||
|
var providers = Providers{}
|
||||||
|
|
||||||
|
// UseProviders adds a list of available providers for use with Goth.
|
||||||
|
// Can be called multiple times. If you pass the same provider more
|
||||||
|
// than once, the last will be used.
|
||||||
|
func UseProviders(viders ...Provider) {
|
||||||
|
for _, provider := range viders {
|
||||||
|
providers[provider.Name()] = provider
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProviders returns a list of all the providers currently in use.
|
||||||
|
func GetProviders() Providers {
|
||||||
|
return providers
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProvider returns a previously created provider. If Goth has not
|
||||||
|
// been told to use the named provider it will return an error.
|
||||||
|
func GetProvider(name string) (Provider, error) {
|
||||||
|
provider := providers[name]
|
||||||
|
if provider == nil {
|
||||||
|
return nil, fmt.Errorf("no provider for %s exists", name)
|
||||||
|
}
|
||||||
|
return provider, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearProviders will remove all providers currently in use.
|
||||||
|
// This is useful, mostly, for testing purposes.
|
||||||
|
func ClearProviders() {
|
||||||
|
providers = Providers{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ContextForClient provides a context for use with oauth2.
|
||||||
|
func ContextForClient(h *http.Client) context.Context {
|
||||||
|
if h == nil {
|
||||||
|
return oauth2.NoContext
|
||||||
|
}
|
||||||
|
return context.WithValue(oauth2.NoContext, oauth2.HTTPClient, h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPClientWithFallBack to be used in all fetch operations.
|
||||||
|
func HTTPClientWithFallBack(h *http.Client) *http.Client {
|
||||||
|
if h != nil {
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
return http.DefaultClient
|
||||||
|
}
|
|
@ -0,0 +1,224 @@
|
||||||
|
// Package github implements the OAuth2 protocol for authenticating users through Github.
|
||||||
|
// This package can be used as a reference implementation of an OAuth2 provider for Goth.
|
||||||
|
package github
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/markbates/goth"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// These vars define the Authentication, Token, and API URLS for GitHub. If
|
||||||
|
// using GitHub enterprise you should change these values before calling New.
|
||||||
|
//
|
||||||
|
// Examples:
|
||||||
|
// github.AuthURL = "https://github.acme.com/login/oauth/authorize
|
||||||
|
// github.TokenURL = "https://github.acme.com/login/oauth/access_token
|
||||||
|
// github.ProfileURL = "https://github.acme.com/api/v3/user
|
||||||
|
// github.EmailURL = "https://github.acme.com/api/v3/user/emails
|
||||||
|
var (
|
||||||
|
AuthURL = "https://github.com/login/oauth/authorize"
|
||||||
|
TokenURL = "https://github.com/login/oauth/access_token"
|
||||||
|
ProfileURL = "https://api.github.com/user"
|
||||||
|
EmailURL = "https://api.github.com/user/emails"
|
||||||
|
)
|
||||||
|
|
||||||
|
// New creates a new Github provider, and sets up important connection details.
|
||||||
|
// You should always call `github.New` to get a new Provider. Never try to create
|
||||||
|
// one manually.
|
||||||
|
func New(clientKey, secret, callbackURL string, scopes ...string) *Provider {
|
||||||
|
p := &Provider{
|
||||||
|
ClientKey: clientKey,
|
||||||
|
Secret: secret,
|
||||||
|
CallbackURL: callbackURL,
|
||||||
|
providerName: "github",
|
||||||
|
}
|
||||||
|
p.config = newConfig(p, scopes)
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
// Provider is the implementation of `goth.Provider` for accessing Github.
|
||||||
|
type Provider struct {
|
||||||
|
ClientKey string
|
||||||
|
Secret string
|
||||||
|
CallbackURL string
|
||||||
|
HTTPClient *http.Client
|
||||||
|
config *oauth2.Config
|
||||||
|
providerName string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name is the name used to retrieve this provider later.
|
||||||
|
func (p *Provider) Name() string {
|
||||||
|
return p.providerName
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetName is to update the name of the provider (needed in case of multiple providers of 1 type)
|
||||||
|
func (p *Provider) SetName(name string) {
|
||||||
|
p.providerName = name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Provider) Client() *http.Client {
|
||||||
|
return goth.HTTPClientWithFallBack(p.HTTPClient)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debug is a no-op for the github package.
|
||||||
|
func (p *Provider) Debug(debug bool) {}
|
||||||
|
|
||||||
|
// BeginAuth asks Github for an authentication end-point.
|
||||||
|
func (p *Provider) BeginAuth(state string) (goth.Session, error) {
|
||||||
|
url := p.config.AuthCodeURL(state)
|
||||||
|
session := &Session{
|
||||||
|
AuthURL: url,
|
||||||
|
}
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchUser will go to Github and access basic information about the user.
|
||||||
|
func (p *Provider) FetchUser(session goth.Session) (goth.User, error) {
|
||||||
|
sess := session.(*Session)
|
||||||
|
user := goth.User{
|
||||||
|
AccessToken: sess.AccessToken,
|
||||||
|
Provider: p.Name(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.AccessToken == "" {
|
||||||
|
// data is not yet retrieved since accessToken is still empty
|
||||||
|
return user, fmt.Errorf("%s cannot get user information without accessToken", p.providerName)
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err := p.Client().Get(ProfileURL + "?access_token=" + url.QueryEscape(sess.AccessToken))
|
||||||
|
if err != nil {
|
||||||
|
return user, err
|
||||||
|
}
|
||||||
|
defer response.Body.Close()
|
||||||
|
|
||||||
|
if response.StatusCode != http.StatusOK {
|
||||||
|
return user, fmt.Errorf("GitHub API responded with a %d trying to fetch user information", response.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
bits, err := ioutil.ReadAll(response.Body)
|
||||||
|
if err != nil {
|
||||||
|
return user, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = json.NewDecoder(bytes.NewReader(bits)).Decode(&user.RawData)
|
||||||
|
if err != nil {
|
||||||
|
return user, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = userFromReader(bytes.NewReader(bits), &user)
|
||||||
|
if err != nil {
|
||||||
|
return user, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.Email == "" {
|
||||||
|
for _, scope := range p.config.Scopes {
|
||||||
|
if strings.TrimSpace(scope) == "user" || strings.TrimSpace(scope) == "user:email" {
|
||||||
|
user.Email, err = getPrivateMail(p, sess)
|
||||||
|
if err != nil {
|
||||||
|
return user, err
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return user, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func userFromReader(reader io.Reader, user *goth.User) error {
|
||||||
|
u := struct {
|
||||||
|
ID int `json:"id"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
Bio string `json:"bio"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Login string `json:"login"`
|
||||||
|
Picture string `json:"avatar_url"`
|
||||||
|
Location string `json:"location"`
|
||||||
|
}{}
|
||||||
|
|
||||||
|
err := json.NewDecoder(reader).Decode(&u)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
user.Name = u.Name
|
||||||
|
user.NickName = u.Login
|
||||||
|
user.Email = u.Email
|
||||||
|
user.Description = u.Bio
|
||||||
|
user.AvatarURL = u.Picture
|
||||||
|
user.UserID = strconv.Itoa(u.ID)
|
||||||
|
user.Location = u.Location
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func getPrivateMail(p *Provider, sess *Session) (email string, err error) {
|
||||||
|
response, err := p.Client().Get(EmailURL + "?access_token=" + url.QueryEscape(sess.AccessToken))
|
||||||
|
if err != nil {
|
||||||
|
if response != nil {
|
||||||
|
response.Body.Close()
|
||||||
|
}
|
||||||
|
return email, err
|
||||||
|
}
|
||||||
|
defer response.Body.Close()
|
||||||
|
|
||||||
|
if response.StatusCode != http.StatusOK {
|
||||||
|
return email, fmt.Errorf("GitHub API responded with a %d trying to fetch user email", response.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
var mailList = []struct {
|
||||||
|
Email string `json:"email"`
|
||||||
|
Primary bool `json:"primary"`
|
||||||
|
Verified bool `json:"verified"`
|
||||||
|
}{}
|
||||||
|
err = json.NewDecoder(response.Body).Decode(&mailList)
|
||||||
|
if err != nil {
|
||||||
|
return email, err
|
||||||
|
}
|
||||||
|
for _, v := range mailList {
|
||||||
|
if v.Primary && v.Verified {
|
||||||
|
return v.Email, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// can't get primary email - shouldn't be possible
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func newConfig(provider *Provider, scopes []string) *oauth2.Config {
|
||||||
|
c := &oauth2.Config{
|
||||||
|
ClientID: provider.ClientKey,
|
||||||
|
ClientSecret: provider.Secret,
|
||||||
|
RedirectURL: provider.CallbackURL,
|
||||||
|
Endpoint: oauth2.Endpoint{
|
||||||
|
AuthURL: AuthURL,
|
||||||
|
TokenURL: TokenURL,
|
||||||
|
},
|
||||||
|
Scopes: []string{},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, scope := range scopes {
|
||||||
|
c.Scopes = append(c.Scopes, scope)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
//RefreshToken refresh token is not provided by github
|
||||||
|
func (p *Provider) RefreshToken(refreshToken string) (*oauth2.Token, error) {
|
||||||
|
return nil, errors.New("Refresh token is not provided by github")
|
||||||
|
}
|
||||||
|
|
||||||
|
//RefreshTokenAvailable refresh token is not provided by github
|
||||||
|
func (p *Provider) RefreshTokenAvailable() bool {
|
||||||
|
return false
|
||||||
|
}
|
|
@ -0,0 +1,56 @@
|
||||||
|
package github
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/markbates/goth"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Session stores data during the auth process with Github.
|
||||||
|
type Session struct {
|
||||||
|
AuthURL string
|
||||||
|
AccessToken string
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAuthURL will return the URL set by calling the `BeginAuth` function on the Github provider.
|
||||||
|
func (s Session) GetAuthURL() (string, error) {
|
||||||
|
if s.AuthURL == "" {
|
||||||
|
return "", errors.New(goth.NoAuthUrlErrorMessage)
|
||||||
|
}
|
||||||
|
return s.AuthURL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Authorize the session with Github and return the access token to be stored for future use.
|
||||||
|
func (s *Session) Authorize(provider goth.Provider, params goth.Params) (string, error) {
|
||||||
|
p := provider.(*Provider)
|
||||||
|
token, err := p.config.Exchange(goth.ContextForClient(p.Client()), params.Get("code"))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !token.Valid() {
|
||||||
|
return "", errors.New("Invalid token received from provider")
|
||||||
|
}
|
||||||
|
|
||||||
|
s.AccessToken = token.AccessToken
|
||||||
|
return token.AccessToken, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal the session into a string
|
||||||
|
func (s Session) Marshal() string {
|
||||||
|
b, _ := json.Marshal(s)
|
||||||
|
return string(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s Session) String() string {
|
||||||
|
return s.Marshal()
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalSession will unmarshal a JSON string into a session.
|
||||||
|
func (p *Provider) UnmarshalSession(data string) (goth.Session, error) {
|
||||||
|
sess := &Session{}
|
||||||
|
err := json.NewDecoder(strings.NewReader(data)).Decode(sess)
|
||||||
|
return sess, err
|
||||||
|
}
|
|
@ -0,0 +1,21 @@
|
||||||
|
package goth
|
||||||
|
|
||||||
|
// Params is used to pass data to sessions for authorization. An existing
|
||||||
|
// implementation, and the one most likely to be used, is `url.Values`.
|
||||||
|
type Params interface {
|
||||||
|
Get(string) string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Session needs to be implemented as part of the provider package.
|
||||||
|
// It will be marshaled and persisted between requests to "tie"
|
||||||
|
// the start and the end of the authorization process with a
|
||||||
|
// 3rd party provider.
|
||||||
|
type Session interface {
|
||||||
|
// GetAuthURL returns the URL for the authentication end-point for the provider.
|
||||||
|
GetAuthURL() (string, error)
|
||||||
|
// Marshal generates a string representation of the Session for storing between requests.
|
||||||
|
Marshal() string
|
||||||
|
// Authorize should validate the data from the provider and return an access token
|
||||||
|
// that can be stored for later access to the provider.
|
||||||
|
Authorize(Provider, Params) (string, error)
|
||||||
|
}
|
|
@ -0,0 +1,30 @@
|
||||||
|
package goth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/gob"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
gob.Register(User{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// User contains the information common amongst most OAuth and OAuth2 providers.
|
||||||
|
// All of the "raw" datafrom the provider can be found in the `RawData` field.
|
||||||
|
type User struct {
|
||||||
|
RawData map[string]interface{}
|
||||||
|
Provider string
|
||||||
|
Email string
|
||||||
|
Name string
|
||||||
|
FirstName string
|
||||||
|
LastName string
|
||||||
|
NickName string
|
||||||
|
Description string
|
||||||
|
UserID string
|
||||||
|
AvatarURL string
|
||||||
|
Location string
|
||||||
|
AccessToken string
|
||||||
|
AccessTokenSecret string
|
||||||
|
RefreshToken string
|
||||||
|
ExpiresAt time.Time
|
||||||
|
}
|
|
@ -0,0 +1,3 @@
|
||||||
|
# This source code refers to The Go Authors for copyright purposes.
|
||||||
|
# The master list of authors is in the main Go distribution,
|
||||||
|
# visible at http://tip.golang.org/AUTHORS.
|
|
@ -0,0 +1,31 @@
|
||||||
|
# Contributing to Go
|
||||||
|
|
||||||
|
Go is an open source project.
|
||||||
|
|
||||||
|
It is the work of hundreds of contributors. We appreciate your help!
|
||||||
|
|
||||||
|
|
||||||
|
## Filing issues
|
||||||
|
|
||||||
|
When [filing an issue](https://github.com/golang/oauth2/issues), make sure to answer these five questions:
|
||||||
|
|
||||||
|
1. What version of Go are you using (`go version`)?
|
||||||
|
2. What operating system and processor architecture are you using?
|
||||||
|
3. What did you do?
|
||||||
|
4. What did you expect to see?
|
||||||
|
5. What did you see instead?
|
||||||
|
|
||||||
|
General questions should go to the [golang-nuts mailing list](https://groups.google.com/group/golang-nuts) instead of the issue tracker.
|
||||||
|
The gophers there will answer or ask you to file an issue if you've tripped over a bug.
|
||||||
|
|
||||||
|
## Contributing code
|
||||||
|
|
||||||
|
Please read the [Contribution Guidelines](https://golang.org/doc/contribute.html)
|
||||||
|
before sending patches.
|
||||||
|
|
||||||
|
**We do not accept GitHub pull requests**
|
||||||
|
(we use [Gerrit](https://code.google.com/p/gerrit/) instead for code review).
|
||||||
|
|
||||||
|
Unless otherwise noted, the Go source files are distributed under
|
||||||
|
the BSD-style license found in the LICENSE file.
|
||||||
|
|
|
@ -0,0 +1,3 @@
|
||||||
|
# This source code was written by the Go contributors.
|
||||||
|
# The master list of contributors is in the main Go distribution,
|
||||||
|
# visible at http://tip.golang.org/CONTRIBUTORS.
|
|
@ -0,0 +1,27 @@
|
||||||
|
Copyright (c) 2009 The oauth2 Authors. All rights reserved.
|
||||||
|
|
||||||
|
Redistribution and use in source and binary forms, with or without
|
||||||
|
modification, are permitted provided that the following conditions are
|
||||||
|
met:
|
||||||
|
|
||||||
|
* Redistributions of source code must retain the above copyright
|
||||||
|
notice, this list of conditions and the following disclaimer.
|
||||||
|
* Redistributions in binary form must reproduce the above
|
||||||
|
copyright notice, this list of conditions and the following disclaimer
|
||||||
|
in the documentation and/or other materials provided with the
|
||||||
|
distribution.
|
||||||
|
* Neither the name of Google Inc. nor the names of its
|
||||||
|
contributors may be used to endorse or promote products derived from
|
||||||
|
this software without specific prior written permission.
|
||||||
|
|
||||||
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||||
|
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||||
|
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||||
|
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||||
|
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||||
|
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||||
|
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||||
|
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||||
|
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||||
|
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||||
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@ -0,0 +1,65 @@
|
||||||
|
# OAuth2 for Go
|
||||||
|
|
||||||
|
[![Build Status](https://travis-ci.org/golang/oauth2.svg?branch=master)](https://travis-ci.org/golang/oauth2)
|
||||||
|
[![GoDoc](https://godoc.org/golang.org/x/oauth2?status.svg)](https://godoc.org/golang.org/x/oauth2)
|
||||||
|
|
||||||
|
oauth2 package contains a client implementation for OAuth 2.0 spec.
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
~~~~
|
||||||
|
go get golang.org/x/oauth2
|
||||||
|
~~~~
|
||||||
|
|
||||||
|
See godoc for further documentation and examples.
|
||||||
|
|
||||||
|
* [godoc.org/golang.org/x/oauth2](http://godoc.org/golang.org/x/oauth2)
|
||||||
|
* [godoc.org/golang.org/x/oauth2/google](http://godoc.org/golang.org/x/oauth2/google)
|
||||||
|
|
||||||
|
|
||||||
|
## App Engine
|
||||||
|
|
||||||
|
In change 96e89be (March 2015) we removed the `oauth2.Context2` type in favor
|
||||||
|
of the [`context.Context`](https://golang.org/x/net/context#Context) type from
|
||||||
|
the `golang.org/x/net/context` package
|
||||||
|
|
||||||
|
This means its no longer possible to use the "Classic App Engine"
|
||||||
|
`appengine.Context` type with the `oauth2` package. (You're using
|
||||||
|
Classic App Engine if you import the package `"appengine"`.)
|
||||||
|
|
||||||
|
To work around this, you may use the new `"google.golang.org/appengine"`
|
||||||
|
package. This package has almost the same API as the `"appengine"` package,
|
||||||
|
but it can be fetched with `go get` and used on "Managed VMs" and well as
|
||||||
|
Classic App Engine.
|
||||||
|
|
||||||
|
See the [new `appengine` package's readme](https://github.com/golang/appengine#updating-a-go-app-engine-app)
|
||||||
|
for information on updating your app.
|
||||||
|
|
||||||
|
If you don't want to update your entire app to use the new App Engine packages,
|
||||||
|
you may use both sets of packages in parallel, using only the new packages
|
||||||
|
with the `oauth2` package.
|
||||||
|
|
||||||
|
import (
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
"golang.org/x/oauth2/google"
|
||||||
|
newappengine "google.golang.org/appengine"
|
||||||
|
newurlfetch "google.golang.org/appengine/urlfetch"
|
||||||
|
|
||||||
|
"appengine"
|
||||||
|
)
|
||||||
|
|
||||||
|
func handler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var c appengine.Context = appengine.NewContext(r)
|
||||||
|
c.Infof("Logging a message with the old package")
|
||||||
|
|
||||||
|
var ctx context.Context = newappengine.NewContext(r)
|
||||||
|
client := &http.Client{
|
||||||
|
Transport: &oauth2.Transport{
|
||||||
|
Source: google.AppEngineTokenSource(ctx, "scope"),
|
||||||
|
Base: &newurlfetch.Transport{Context: ctx},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
client.Get("...")
|
||||||
|
}
|
||||||
|
|
|
@ -0,0 +1,25 @@
|
||||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// +build appengine
|
||||||
|
|
||||||
|
// App Engine hooks.
|
||||||
|
|
||||||
|
package oauth2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
"golang.org/x/oauth2/internal"
|
||||||
|
"google.golang.org/appengine/urlfetch"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
internal.RegisterContextClientFunc(contextClientAppEngine)
|
||||||
|
}
|
||||||
|
|
||||||
|
func contextClientAppEngine(ctx context.Context) (*http.Client, error) {
|
||||||
|
return urlfetch.Client(ctx), nil
|
||||||
|
}
|
|
@ -0,0 +1,76 @@
|
||||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// Package internal contains support packages for oauth2 package.
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ParseKey converts the binary contents of a private key file
|
||||||
|
// to an *rsa.PrivateKey. It detects whether the private key is in a
|
||||||
|
// PEM container or not. If so, it extracts the the private key
|
||||||
|
// from PEM container before conversion. It only supports PEM
|
||||||
|
// containers with no passphrase.
|
||||||
|
func ParseKey(key []byte) (*rsa.PrivateKey, error) {
|
||||||
|
block, _ := pem.Decode(key)
|
||||||
|
if block != nil {
|
||||||
|
key = block.Bytes
|
||||||
|
}
|
||||||
|
parsedKey, err := x509.ParsePKCS8PrivateKey(key)
|
||||||
|
if err != nil {
|
||||||
|
parsedKey, err = x509.ParsePKCS1PrivateKey(key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("private key should be a PEM or plain PKSC1 or PKCS8; parse error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
parsed, ok := parsedKey.(*rsa.PrivateKey)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("private key is invalid")
|
||||||
|
}
|
||||||
|
return parsed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseINI(ini io.Reader) (map[string]map[string]string, error) {
|
||||||
|
result := map[string]map[string]string{
|
||||||
|
"": map[string]string{}, // root section
|
||||||
|
}
|
||||||
|
scanner := bufio.NewScanner(ini)
|
||||||
|
currentSection := ""
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := strings.TrimSpace(scanner.Text())
|
||||||
|
if strings.HasPrefix(line, ";") {
|
||||||
|
// comment.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(line, "[") && strings.HasSuffix(line, "]") {
|
||||||
|
currentSection = strings.TrimSpace(line[1 : len(line)-1])
|
||||||
|
result[currentSection] = map[string]string{}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
parts := strings.SplitN(line, "=", 2)
|
||||||
|
if len(parts) == 2 && parts[0] != "" {
|
||||||
|
result[currentSection][strings.TrimSpace(parts[0])] = strings.TrimSpace(parts[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := scanner.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("error scanning ini: %v", err)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func CondVal(v string) []string {
|
||||||
|
if v == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return []string{v}
|
||||||
|
}
|
|
@ -0,0 +1,227 @@
|
||||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// Package internal contains support packages for oauth2 package.
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"mime"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Token represents the crendentials used to authorize
|
||||||
|
// the requests to access protected resources on the OAuth 2.0
|
||||||
|
// provider's backend.
|
||||||
|
//
|
||||||
|
// This type is a mirror of oauth2.Token and exists to break
|
||||||
|
// an otherwise-circular dependency. Other internal packages
|
||||||
|
// should convert this Token into an oauth2.Token before use.
|
||||||
|
type Token struct {
|
||||||
|
// AccessToken is the token that authorizes and authenticates
|
||||||
|
// the requests.
|
||||||
|
AccessToken string
|
||||||
|
|
||||||
|
// TokenType is the type of token.
|
||||||
|
// The Type method returns either this or "Bearer", the default.
|
||||||
|
TokenType string
|
||||||
|
|
||||||
|
// RefreshToken is a token that's used by the application
|
||||||
|
// (as opposed to the user) to refresh the access token
|
||||||
|
// if it expires.
|
||||||
|
RefreshToken string
|
||||||
|
|
||||||
|
// Expiry is the optional expiration time of the access token.
|
||||||
|
//
|
||||||
|
// If zero, TokenSource implementations will reuse the same
|
||||||
|
// token forever and RefreshToken or equivalent
|
||||||
|
// mechanisms for that TokenSource will not be used.
|
||||||
|
Expiry time.Time
|
||||||
|
|
||||||
|
// Raw optionally contains extra metadata from the server
|
||||||
|
// when updating a token.
|
||||||
|
Raw interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// tokenJSON is the struct representing the HTTP response from OAuth2
|
||||||
|
// providers returning a token in JSON form.
|
||||||
|
type tokenJSON struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
ExpiresIn expirationTime `json:"expires_in"` // at least PayPal returns string, while most return number
|
||||||
|
Expires expirationTime `json:"expires"` // broken Facebook spelling of expires_in
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *tokenJSON) expiry() (t time.Time) {
|
||||||
|
if v := e.ExpiresIn; v != 0 {
|
||||||
|
return time.Now().Add(time.Duration(v) * time.Second)
|
||||||
|
}
|
||||||
|
if v := e.Expires; v != 0 {
|
||||||
|
return time.Now().Add(time.Duration(v) * time.Second)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
type expirationTime int32
|
||||||
|
|
||||||
|
func (e *expirationTime) UnmarshalJSON(b []byte) error {
|
||||||
|
var n json.Number
|
||||||
|
err := json.Unmarshal(b, &n)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
i, err := n.Int64()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*e = expirationTime(i)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var brokenAuthHeaderProviders = []string{
|
||||||
|
"https://accounts.google.com/",
|
||||||
|
"https://api.dropbox.com/",
|
||||||
|
"https://api.dropboxapi.com/",
|
||||||
|
"https://api.instagram.com/",
|
||||||
|
"https://api.netatmo.net/",
|
||||||
|
"https://api.odnoklassniki.ru/",
|
||||||
|
"https://api.pushbullet.com/",
|
||||||
|
"https://api.soundcloud.com/",
|
||||||
|
"https://api.twitch.tv/",
|
||||||
|
"https://app.box.com/",
|
||||||
|
"https://connect.stripe.com/",
|
||||||
|
"https://login.microsoftonline.com/",
|
||||||
|
"https://login.salesforce.com/",
|
||||||
|
"https://oauth.sandbox.trainingpeaks.com/",
|
||||||
|
"https://oauth.trainingpeaks.com/",
|
||||||
|
"https://oauth.vk.com/",
|
||||||
|
"https://openapi.baidu.com/",
|
||||||
|
"https://slack.com/",
|
||||||
|
"https://test-sandbox.auth.corp.google.com",
|
||||||
|
"https://test.salesforce.com/",
|
||||||
|
"https://user.gini.net/",
|
||||||
|
"https://www.douban.com/",
|
||||||
|
"https://www.googleapis.com/",
|
||||||
|
"https://www.linkedin.com/",
|
||||||
|
"https://www.strava.com/oauth/",
|
||||||
|
"https://www.wunderlist.com/oauth/",
|
||||||
|
"https://api.patreon.com/",
|
||||||
|
"https://sandbox.codeswholesale.com/oauth/token",
|
||||||
|
"https://api.codeswholesale.com/oauth/token",
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterBrokenAuthHeaderProvider(tokenURL string) {
|
||||||
|
brokenAuthHeaderProviders = append(brokenAuthHeaderProviders, tokenURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// providerAuthHeaderWorks reports whether the OAuth2 server identified by the tokenURL
|
||||||
|
// implements the OAuth2 spec correctly
|
||||||
|
// See https://code.google.com/p/goauth2/issues/detail?id=31 for background.
|
||||||
|
// In summary:
|
||||||
|
// - Reddit only accepts client secret in the Authorization header
|
||||||
|
// - Dropbox accepts either it in URL param or Auth header, but not both.
|
||||||
|
// - Google only accepts URL param (not spec compliant?), not Auth header
|
||||||
|
// - Stripe only accepts client secret in Auth header with Bearer method, not Basic
|
||||||
|
func providerAuthHeaderWorks(tokenURL string) bool {
|
||||||
|
for _, s := range brokenAuthHeaderProviders {
|
||||||
|
if strings.HasPrefix(tokenURL, s) {
|
||||||
|
// Some sites fail to implement the OAuth2 spec fully.
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assume the provider implements the spec properly
|
||||||
|
// otherwise. We can add more exceptions as they're
|
||||||
|
// discovered. We will _not_ be adding configurable hooks
|
||||||
|
// to this package to let users select server bugs.
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func RetrieveToken(ctx context.Context, clientID, clientSecret, tokenURL string, v url.Values) (*Token, error) {
|
||||||
|
hc, err := ContextClient(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
v.Set("client_id", clientID)
|
||||||
|
bustedAuth := !providerAuthHeaderWorks(tokenURL)
|
||||||
|
if bustedAuth && clientSecret != "" {
|
||||||
|
v.Set("client_secret", clientSecret)
|
||||||
|
}
|
||||||
|
req, err := http.NewRequest("POST", tokenURL, strings.NewReader(v.Encode()))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
if !bustedAuth {
|
||||||
|
req.SetBasicAuth(clientID, clientSecret)
|
||||||
|
}
|
||||||
|
r, err := hc.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer r.Body.Close()
|
||||||
|
body, err := ioutil.ReadAll(io.LimitReader(r.Body, 1<<20))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v", err)
|
||||||
|
}
|
||||||
|
if code := r.StatusCode; code < 200 || code > 299 {
|
||||||
|
return nil, fmt.Errorf("oauth2: cannot fetch token: %v\nResponse: %s", r.Status, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
var token *Token
|
||||||
|
content, _, _ := mime.ParseMediaType(r.Header.Get("Content-Type"))
|
||||||
|
switch content {
|
||||||
|
case "application/x-www-form-urlencoded", "text/plain":
|
||||||
|
vals, err := url.ParseQuery(string(body))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
token = &Token{
|
||||||
|
AccessToken: vals.Get("access_token"),
|
||||||
|
TokenType: vals.Get("token_type"),
|
||||||
|
RefreshToken: vals.Get("refresh_token"),
|
||||||
|
Raw: vals,
|
||||||
|
}
|
||||||
|
e := vals.Get("expires_in")
|
||||||
|
if e == "" {
|
||||||
|
// TODO(jbd): Facebook's OAuth2 implementation is broken and
|
||||||
|
// returns expires_in field in expires. Remove the fallback to expires,
|
||||||
|
// when Facebook fixes their implementation.
|
||||||
|
e = vals.Get("expires")
|
||||||
|
}
|
||||||
|
expires, _ := strconv.Atoi(e)
|
||||||
|
if expires != 0 {
|
||||||
|
token.Expiry = time.Now().Add(time.Duration(expires) * time.Second)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
var tj tokenJSON
|
||||||
|
if err = json.Unmarshal(body, &tj); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
token = &Token{
|
||||||
|
AccessToken: tj.AccessToken,
|
||||||
|
TokenType: tj.TokenType,
|
||||||
|
RefreshToken: tj.RefreshToken,
|
||||||
|
Expiry: tj.expiry(),
|
||||||
|
Raw: make(map[string]interface{}),
|
||||||
|
}
|
||||||
|
json.Unmarshal(body, &token.Raw) // no error checks for optional fields
|
||||||
|
}
|
||||||
|
// Don't overwrite `RefreshToken` with an empty value
|
||||||
|
// if this was a token refreshing request.
|
||||||
|
if token.RefreshToken == "" {
|
||||||
|
token.RefreshToken = v.Get("refresh_token")
|
||||||
|
}
|
||||||
|
return token, nil
|
||||||
|
}
|
|
@ -0,0 +1,69 @@
|
||||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// Package internal contains support packages for oauth2 package.
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HTTPClient is the context key to use with golang.org/x/net/context's
|
||||||
|
// WithValue function to associate an *http.Client value with a context.
|
||||||
|
var HTTPClient ContextKey
|
||||||
|
|
||||||
|
// ContextKey is just an empty struct. It exists so HTTPClient can be
|
||||||
|
// an immutable public variable with a unique type. It's immutable
|
||||||
|
// because nobody else can create a ContextKey, being unexported.
|
||||||
|
type ContextKey struct{}
|
||||||
|
|
||||||
|
// ContextClientFunc is a func which tries to return an *http.Client
|
||||||
|
// given a Context value. If it returns an error, the search stops
|
||||||
|
// with that error. If it returns (nil, nil), the search continues
|
||||||
|
// down the list of registered funcs.
|
||||||
|
type ContextClientFunc func(context.Context) (*http.Client, error)
|
||||||
|
|
||||||
|
var contextClientFuncs []ContextClientFunc
|
||||||
|
|
||||||
|
func RegisterContextClientFunc(fn ContextClientFunc) {
|
||||||
|
contextClientFuncs = append(contextClientFuncs, fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ContextClient(ctx context.Context) (*http.Client, error) {
|
||||||
|
if ctx != nil {
|
||||||
|
if hc, ok := ctx.Value(HTTPClient).(*http.Client); ok {
|
||||||
|
return hc, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, fn := range contextClientFuncs {
|
||||||
|
c, err := fn(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if c != nil {
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return http.DefaultClient, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ContextTransport(ctx context.Context) http.RoundTripper {
|
||||||
|
hc, err := ContextClient(ctx)
|
||||||
|
// This is a rare error case (somebody using nil on App Engine).
|
||||||
|
if err != nil {
|
||||||
|
return ErrorTransport{err}
|
||||||
|
}
|
||||||
|
return hc.Transport
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorTransport returns the specified error on RoundTrip.
|
||||||
|
// This RoundTripper should be used in rare error cases where
|
||||||
|
// error handling can be postponed to response handling time.
|
||||||
|
type ErrorTransport struct{ Err error }
|
||||||
|
|
||||||
|
func (t ErrorTransport) RoundTrip(*http.Request) (*http.Response, error) {
|
||||||
|
return nil, t.Err
|
||||||
|
}
|
|
@ -0,0 +1,341 @@
|
||||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
// Package oauth2 provides support for making
|
||||||
|
// OAuth2 authorized and authenticated HTTP requests.
|
||||||
|
// It can additionally grant authorization with Bearer JWT.
|
||||||
|
package oauth2 // import "golang.org/x/oauth2"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
"golang.org/x/oauth2/internal"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NoContext is the default context you should supply if not using
|
||||||
|
// your own context.Context (see https://golang.org/x/net/context).
|
||||||
|
//
|
||||||
|
// Deprecated: Use context.Background() or context.TODO() instead.
|
||||||
|
var NoContext = context.TODO()
|
||||||
|
|
||||||
|
// RegisterBrokenAuthHeaderProvider registers an OAuth2 server
|
||||||
|
// identified by the tokenURL prefix as an OAuth2 implementation
|
||||||
|
// which doesn't support the HTTP Basic authentication
|
||||||
|
// scheme to authenticate with the authorization server.
|
||||||
|
// Once a server is registered, credentials (client_id and client_secret)
|
||||||
|
// will be passed as query parameters rather than being present
|
||||||
|
// in the Authorization header.
|
||||||
|
// See https://code.google.com/p/goauth2/issues/detail?id=31 for background.
|
||||||
|
func RegisterBrokenAuthHeaderProvider(tokenURL string) {
|
||||||
|
internal.RegisterBrokenAuthHeaderProvider(tokenURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config describes a typical 3-legged OAuth2 flow, with both the
|
||||||
|
// client application information and the server's endpoint URLs.
|
||||||
|
// For the client credentials 2-legged OAuth2 flow, see the clientcredentials
|
||||||
|
// package (https://golang.org/x/oauth2/clientcredentials).
|
||||||
|
type Config struct {
|
||||||
|
// ClientID is the application's ID.
|
||||||
|
ClientID string
|
||||||
|
|
||||||
|
// ClientSecret is the application's secret.
|
||||||
|
ClientSecret string
|
||||||
|
|
||||||
|
// Endpoint contains the resource server's token endpoint
|
||||||
|
// URLs. These are constants specific to each server and are
|
||||||
|
// often available via site-specific packages, such as
|
||||||
|
// google.Endpoint or github.Endpoint.
|
||||||
|
Endpoint Endpoint
|
||||||
|
|
||||||
|
// RedirectURL is the URL to redirect users going through
|
||||||
|
// the OAuth flow, after the resource owner's URLs.
|
||||||
|
RedirectURL string
|
||||||
|
|
||||||
|
// Scope specifies optional requested permissions.
|
||||||
|
Scopes []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// A TokenSource is anything that can return a token.
|
||||||
|
type TokenSource interface {
|
||||||
|
// Token returns a token or an error.
|
||||||
|
// Token must be safe for concurrent use by multiple goroutines.
|
||||||
|
// The returned Token must not be modified.
|
||||||
|
Token() (*Token, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Endpoint contains the OAuth 2.0 provider's authorization and token
|
||||||
|
// endpoint URLs.
|
||||||
|
type Endpoint struct {
|
||||||
|
AuthURL string
|
||||||
|
TokenURL string
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
// AccessTypeOnline and AccessTypeOffline are options passed
|
||||||
|
// to the Options.AuthCodeURL method. They modify the
|
||||||
|
// "access_type" field that gets sent in the URL returned by
|
||||||
|
// AuthCodeURL.
|
||||||
|
//
|
||||||
|
// Online is the default if neither is specified. If your
|
||||||
|
// application needs to refresh access tokens when the user
|
||||||
|
// is not present at the browser, then use offline. This will
|
||||||
|
// result in your application obtaining a refresh token the
|
||||||
|
// first time your application exchanges an authorization
|
||||||
|
// code for a user.
|
||||||
|
AccessTypeOnline AuthCodeOption = SetAuthURLParam("access_type", "online")
|
||||||
|
AccessTypeOffline AuthCodeOption = SetAuthURLParam("access_type", "offline")
|
||||||
|
|
||||||
|
// ApprovalForce forces the users to view the consent dialog
|
||||||
|
// and confirm the permissions request at the URL returned
|
||||||
|
// from AuthCodeURL, even if they've already done so.
|
||||||
|
ApprovalForce AuthCodeOption = SetAuthURLParam("approval_prompt", "force")
|
||||||
|
)
|
||||||
|
|
||||||
|
// An AuthCodeOption is passed to Config.AuthCodeURL.
|
||||||
|
type AuthCodeOption interface {
|
||||||
|
setValue(url.Values)
|
||||||
|
}
|
||||||
|
|
||||||
|
type setParam struct{ k, v string }
|
||||||
|
|
||||||
|
func (p setParam) setValue(m url.Values) { m.Set(p.k, p.v) }
|
||||||
|
|
||||||
|
// SetAuthURLParam builds an AuthCodeOption which passes key/value parameters
|
||||||
|
// to a provider's authorization endpoint.
|
||||||
|
func SetAuthURLParam(key, value string) AuthCodeOption {
|
||||||
|
return setParam{key, value}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthCodeURL returns a URL to OAuth 2.0 provider's consent page
|
||||||
|
// that asks for permissions for the required scopes explicitly.
|
||||||
|
//
|
||||||
|
// State is a token to protect the user from CSRF attacks. You must
|
||||||
|
// always provide a non-zero string and validate that it matches the
|
||||||
|
// the state query parameter on your redirect callback.
|
||||||
|
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
|
||||||
|
//
|
||||||
|
// Opts may include AccessTypeOnline or AccessTypeOffline, as well
|
||||||
|
// as ApprovalForce.
|
||||||
|
func (c *Config) AuthCodeURL(state string, opts ...AuthCodeOption) string {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
buf.WriteString(c.Endpoint.AuthURL)
|
||||||
|
v := url.Values{
|
||||||
|
"response_type": {"code"},
|
||||||
|
"client_id": {c.ClientID},
|
||||||
|
"redirect_uri": internal.CondVal(c.RedirectURL),
|
||||||
|
"scope": internal.CondVal(strings.Join(c.Scopes, " ")),
|
||||||
|
"state": internal.CondVal(state),
|
||||||
|
}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt.setValue(v)
|
||||||
|
}
|
||||||
|
if strings.Contains(c.Endpoint.AuthURL, "?") {
|
||||||
|
buf.WriteByte('&')
|
||||||
|
} else {
|
||||||
|
buf.WriteByte('?')
|
||||||
|
}
|
||||||
|
buf.WriteString(v.Encode())
|
||||||
|
return buf.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// PasswordCredentialsToken converts a resource owner username and password
|
||||||
|
// pair into a token.
|
||||||
|
//
|
||||||
|
// Per the RFC, this grant type should only be used "when there is a high
|
||||||
|
// degree of trust between the resource owner and the client (e.g., the client
|
||||||
|
// is part of the device operating system or a highly privileged application),
|
||||||
|
// and when other authorization grant types are not available."
|
||||||
|
// See https://tools.ietf.org/html/rfc6749#section-4.3 for more info.
|
||||||
|
//
|
||||||
|
// The HTTP client to use is derived from the context.
|
||||||
|
// If nil, http.DefaultClient is used.
|
||||||
|
func (c *Config) PasswordCredentialsToken(ctx context.Context, username, password string) (*Token, error) {
|
||||||
|
return retrieveToken(ctx, c, url.Values{
|
||||||
|
"grant_type": {"password"},
|
||||||
|
"username": {username},
|
||||||
|
"password": {password},
|
||||||
|
"scope": internal.CondVal(strings.Join(c.Scopes, " ")),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exchange converts an authorization code into a token.
|
||||||
|
//
|
||||||
|
// It is used after a resource provider redirects the user back
|
||||||
|
// to the Redirect URI (the URL obtained from AuthCodeURL).
|
||||||
|
//
|
||||||
|
// The HTTP client to use is derived from the context.
|
||||||
|
// If a client is not provided via the context, http.DefaultClient is used.
|
||||||
|
//
|
||||||
|
// The code will be in the *http.Request.FormValue("code"). Before
|
||||||
|
// calling Exchange, be sure to validate FormValue("state").
|
||||||
|
func (c *Config) Exchange(ctx context.Context, code string) (*Token, error) {
|
||||||
|
return retrieveToken(ctx, c, url.Values{
|
||||||
|
"grant_type": {"authorization_code"},
|
||||||
|
"code": {code},
|
||||||
|
"redirect_uri": internal.CondVal(c.RedirectURL),
|
||||||
|
"scope": internal.CondVal(strings.Join(c.Scopes, " ")),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client returns an HTTP client using the provided token.
|
||||||
|
// The token will auto-refresh as necessary. The underlying
|
||||||
|
// HTTP transport will be obtained using the provided context.
|
||||||
|
// The returned client and its Transport should not be modified.
|
||||||
|
func (c *Config) Client(ctx context.Context, t *Token) *http.Client {
|
||||||
|
return NewClient(ctx, c.TokenSource(ctx, t))
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenSource returns a TokenSource that returns t until t expires,
|
||||||
|
// automatically refreshing it as necessary using the provided context.
|
||||||
|
//
|
||||||
|
// Most users will use Config.Client instead.
|
||||||
|
func (c *Config) TokenSource(ctx context.Context, t *Token) TokenSource {
|
||||||
|
tkr := &tokenRefresher{
|
||||||
|
ctx: ctx,
|
||||||
|
conf: c,
|
||||||
|
}
|
||||||
|
if t != nil {
|
||||||
|
tkr.refreshToken = t.RefreshToken
|
||||||
|
}
|
||||||
|
return &reuseTokenSource{
|
||||||
|
t: t,
|
||||||
|
new: tkr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// tokenRefresher is a TokenSource that makes "grant_type"=="refresh_token"
|
||||||
|
// HTTP requests to renew a token using a RefreshToken.
|
||||||
|
type tokenRefresher struct {
|
||||||
|
ctx context.Context // used to get HTTP requests
|
||||||
|
conf *Config
|
||||||
|
refreshToken string
|
||||||
|
}
|
||||||
|
|
||||||
|
// WARNING: Token is not safe for concurrent access, as it
|
||||||
|
// updates the tokenRefresher's refreshToken field.
|
||||||
|
// Within this package, it is used by reuseTokenSource which
|
||||||
|
// synchronizes calls to this method with its own mutex.
|
||||||
|
func (tf *tokenRefresher) Token() (*Token, error) {
|
||||||
|
if tf.refreshToken == "" {
|
||||||
|
return nil, errors.New("oauth2: token expired and refresh token is not set")
|
||||||
|
}
|
||||||
|
|
||||||
|
tk, err := retrieveToken(tf.ctx, tf.conf, url.Values{
|
||||||
|
"grant_type": {"refresh_token"},
|
||||||
|
"refresh_token": {tf.refreshToken},
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if tf.refreshToken != tk.RefreshToken {
|
||||||
|
tf.refreshToken = tk.RefreshToken
|
||||||
|
}
|
||||||
|
return tk, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// reuseTokenSource is a TokenSource that holds a single token in memory
|
||||||
|
// and validates its expiry before each call to retrieve it with
|
||||||
|
// Token. If it's expired, it will be auto-refreshed using the
|
||||||
|
// new TokenSource.
|
||||||
|
type reuseTokenSource struct {
|
||||||
|
new TokenSource // called when t is expired.
|
||||||
|
|
||||||
|
mu sync.Mutex // guards t
|
||||||
|
t *Token
|
||||||
|
}
|
||||||
|
|
||||||
|
// Token returns the current token if it's still valid, else will
|
||||||
|
// refresh the current token (using r.Context for HTTP client
|
||||||
|
// information) and return the new one.
|
||||||
|
func (s *reuseTokenSource) Token() (*Token, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
if s.t.Valid() {
|
||||||
|
return s.t, nil
|
||||||
|
}
|
||||||
|
t, err := s.new.Token()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
s.t = t
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// StaticTokenSource returns a TokenSource that always returns the same token.
|
||||||
|
// Because the provided token t is never refreshed, StaticTokenSource is only
|
||||||
|
// useful for tokens that never expire.
|
||||||
|
func StaticTokenSource(t *Token) TokenSource {
|
||||||
|
return staticTokenSource{t}
|
||||||
|
}
|
||||||
|
|
||||||
|
// staticTokenSource is a TokenSource that always returns the same Token.
|
||||||
|
type staticTokenSource struct {
|
||||||
|
t *Token
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s staticTokenSource) Token() (*Token, error) {
|
||||||
|
return s.t, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPClient is the context key to use with golang.org/x/net/context's
|
||||||
|
// WithValue function to associate an *http.Client value with a context.
|
||||||
|
var HTTPClient internal.ContextKey
|
||||||
|
|
||||||
|
// NewClient creates an *http.Client from a Context and TokenSource.
|
||||||
|
// The returned client is not valid beyond the lifetime of the context.
|
||||||
|
//
|
||||||
|
// As a special case, if src is nil, a non-OAuth2 client is returned
|
||||||
|
// using the provided context. This exists to support related OAuth2
|
||||||
|
// packages.
|
||||||
|
func NewClient(ctx context.Context, src TokenSource) *http.Client {
|
||||||
|
if src == nil {
|
||||||
|
c, err := internal.ContextClient(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return &http.Client{Transport: internal.ErrorTransport{Err: err}}
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
return &http.Client{
|
||||||
|
Transport: &Transport{
|
||||||
|
Base: internal.ContextTransport(ctx),
|
||||||
|
Source: ReuseTokenSource(nil, src),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReuseTokenSource returns a TokenSource which repeatedly returns the
|
||||||
|
// same token as long as it's valid, starting with t.
|
||||||
|
// When its cached token is invalid, a new token is obtained from src.
|
||||||
|
//
|
||||||
|
// ReuseTokenSource is typically used to reuse tokens from a cache
|
||||||
|
// (such as a file on disk) between runs of a program, rather than
|
||||||
|
// obtaining new tokens unnecessarily.
|
||||||
|
//
|
||||||
|
// The initial token t may be nil, in which case the TokenSource is
|
||||||
|
// wrapped in a caching version if it isn't one already. This also
|
||||||
|
// means it's always safe to wrap ReuseTokenSource around any other
|
||||||
|
// TokenSource without adverse effects.
|
||||||
|
func ReuseTokenSource(t *Token, src TokenSource) TokenSource {
|
||||||
|
// Don't wrap a reuseTokenSource in itself. That would work,
|
||||||
|
// but cause an unnecessary number of mutex operations.
|
||||||
|
// Just build the equivalent one.
|
||||||
|
if rt, ok := src.(*reuseTokenSource); ok {
|
||||||
|
if t == nil {
|
||||||
|
// Just use it directly.
|
||||||
|
return rt
|
||||||
|
}
|
||||||
|
src = rt.new
|
||||||
|
}
|
||||||
|
return &reuseTokenSource{
|
||||||
|
t: t,
|
||||||
|
new: src,
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,158 @@
|
||||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package oauth2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
"golang.org/x/oauth2/internal"
|
||||||
|
)
|
||||||
|
|
||||||
|
// expiryDelta determines how earlier a token should be considered
|
||||||
|
// expired than its actual expiration time. It is used to avoid late
|
||||||
|
// expirations due to client-server time mismatches.
|
||||||
|
const expiryDelta = 10 * time.Second
|
||||||
|
|
||||||
|
// Token represents the crendentials used to authorize
|
||||||
|
// the requests to access protected resources on the OAuth 2.0
|
||||||
|
// provider's backend.
|
||||||
|
//
|
||||||
|
// Most users of this package should not access fields of Token
|
||||||
|
// directly. They're exported mostly for use by related packages
|
||||||
|
// implementing derivative OAuth2 flows.
|
||||||
|
type Token struct {
|
||||||
|
// AccessToken is the token that authorizes and authenticates
|
||||||
|
// the requests.
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
|
||||||
|
// TokenType is the type of token.
|
||||||
|
// The Type method returns either this or "Bearer", the default.
|
||||||
|
TokenType string `json:"token_type,omitempty"`
|
||||||
|
|
||||||
|
// RefreshToken is a token that's used by the application
|
||||||
|
// (as opposed to the user) to refresh the access token
|
||||||
|
// if it expires.
|
||||||
|
RefreshToken string `json:"refresh_token,omitempty"`
|
||||||
|
|
||||||
|
// Expiry is the optional expiration time of the access token.
|
||||||
|
//
|
||||||
|
// If zero, TokenSource implementations will reuse the same
|
||||||
|
// token forever and RefreshToken or equivalent
|
||||||
|
// mechanisms for that TokenSource will not be used.
|
||||||
|
Expiry time.Time `json:"expiry,omitempty"`
|
||||||
|
|
||||||
|
// raw optionally contains extra metadata from the server
|
||||||
|
// when updating a token.
|
||||||
|
raw interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type returns t.TokenType if non-empty, else "Bearer".
|
||||||
|
func (t *Token) Type() string {
|
||||||
|
if strings.EqualFold(t.TokenType, "bearer") {
|
||||||
|
return "Bearer"
|
||||||
|
}
|
||||||
|
if strings.EqualFold(t.TokenType, "mac") {
|
||||||
|
return "MAC"
|
||||||
|
}
|
||||||
|
if strings.EqualFold(t.TokenType, "basic") {
|
||||||
|
return "Basic"
|
||||||
|
}
|
||||||
|
if t.TokenType != "" {
|
||||||
|
return t.TokenType
|
||||||
|
}
|
||||||
|
return "Bearer"
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAuthHeader sets the Authorization header to r using the access
|
||||||
|
// token in t.
|
||||||
|
//
|
||||||
|
// This method is unnecessary when using Transport or an HTTP Client
|
||||||
|
// returned by this package.
|
||||||
|
func (t *Token) SetAuthHeader(r *http.Request) {
|
||||||
|
r.Header.Set("Authorization", t.Type()+" "+t.AccessToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithExtra returns a new Token that's a clone of t, but using the
|
||||||
|
// provided raw extra map. This is only intended for use by packages
|
||||||
|
// implementing derivative OAuth2 flows.
|
||||||
|
func (t *Token) WithExtra(extra interface{}) *Token {
|
||||||
|
t2 := new(Token)
|
||||||
|
*t2 = *t
|
||||||
|
t2.raw = extra
|
||||||
|
return t2
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extra returns an extra field.
|
||||||
|
// Extra fields are key-value pairs returned by the server as a
|
||||||
|
// part of the token retrieval response.
|
||||||
|
func (t *Token) Extra(key string) interface{} {
|
||||||
|
if raw, ok := t.raw.(map[string]interface{}); ok {
|
||||||
|
return raw[key]
|
||||||
|
}
|
||||||
|
|
||||||
|
vals, ok := t.raw.(url.Values)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
v := vals.Get(key)
|
||||||
|
switch s := strings.TrimSpace(v); strings.Count(s, ".") {
|
||||||
|
case 0: // Contains no "."; try to parse as int
|
||||||
|
if i, err := strconv.ParseInt(s, 10, 64); err == nil {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
case 1: // Contains a single "."; try to parse as float
|
||||||
|
if f, err := strconv.ParseFloat(s, 64); err == nil {
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
// expired reports whether the token is expired.
|
||||||
|
// t must be non-nil.
|
||||||
|
func (t *Token) expired() bool {
|
||||||
|
if t.Expiry.IsZero() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return t.Expiry.Add(-expiryDelta).Before(time.Now())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Valid reports whether t is non-nil, has an AccessToken, and is not expired.
|
||||||
|
func (t *Token) Valid() bool {
|
||||||
|
return t != nil && t.AccessToken != "" && !t.expired()
|
||||||
|
}
|
||||||
|
|
||||||
|
// tokenFromInternal maps an *internal.Token struct into
|
||||||
|
// a *Token struct.
|
||||||
|
func tokenFromInternal(t *internal.Token) *Token {
|
||||||
|
if t == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &Token{
|
||||||
|
AccessToken: t.AccessToken,
|
||||||
|
TokenType: t.TokenType,
|
||||||
|
RefreshToken: t.RefreshToken,
|
||||||
|
Expiry: t.Expiry,
|
||||||
|
raw: t.Raw,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// retrieveToken takes a *Config and uses that to retrieve an *internal.Token.
|
||||||
|
// This token is then mapped from *internal.Token into an *oauth2.Token which is returned along
|
||||||
|
// with an error..
|
||||||
|
func retrieveToken(ctx context.Context, c *Config, v url.Values) (*Token, error) {
|
||||||
|
tk, err := internal.RetrieveToken(ctx, c.ClientID, c.ClientSecret, c.Endpoint.TokenURL, v)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return tokenFromInternal(tk), nil
|
||||||
|
}
|
|
@ -0,0 +1,132 @@
|
||||||
|
// Copyright 2014 The Go Authors. All rights reserved.
|
||||||
|
// Use of this source code is governed by a BSD-style
|
||||||
|
// license that can be found in the LICENSE file.
|
||||||
|
|
||||||
|
package oauth2
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Transport is an http.RoundTripper that makes OAuth 2.0 HTTP requests,
|
||||||
|
// wrapping a base RoundTripper and adding an Authorization header
|
||||||
|
// with a token from the supplied Sources.
|
||||||
|
//
|
||||||
|
// Transport is a low-level mechanism. Most code will use the
|
||||||
|
// higher-level Config.Client method instead.
|
||||||
|
type Transport struct {
|
||||||
|
// Source supplies the token to add to outgoing requests'
|
||||||
|
// Authorization headers.
|
||||||
|
Source TokenSource
|
||||||
|
|
||||||
|
// Base is the base RoundTripper used to make HTTP requests.
|
||||||
|
// If nil, http.DefaultTransport is used.
|
||||||
|
Base http.RoundTripper
|
||||||
|
|
||||||
|
mu sync.Mutex // guards modReq
|
||||||
|
modReq map[*http.Request]*http.Request // original -> modified
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoundTrip authorizes and authenticates the request with an
|
||||||
|
// access token. If no token exists or token is expired,
|
||||||
|
// tries to refresh/fetch a new token.
|
||||||
|
func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
if t.Source == nil {
|
||||||
|
return nil, errors.New("oauth2: Transport's Source is nil")
|
||||||
|
}
|
||||||
|
token, err := t.Source.Token()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
req2 := cloneRequest(req) // per RoundTripper contract
|
||||||
|
token.SetAuthHeader(req2)
|
||||||
|
t.setModReq(req, req2)
|
||||||
|
res, err := t.base().RoundTrip(req2)
|
||||||
|
if err != nil {
|
||||||
|
t.setModReq(req, nil)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
res.Body = &onEOFReader{
|
||||||
|
rc: res.Body,
|
||||||
|
fn: func() { t.setModReq(req, nil) },
|
||||||
|
}
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CancelRequest cancels an in-flight request by closing its connection.
|
||||||
|
func (t *Transport) CancelRequest(req *http.Request) {
|
||||||
|
type canceler interface {
|
||||||
|
CancelRequest(*http.Request)
|
||||||
|
}
|
||||||
|
if cr, ok := t.base().(canceler); ok {
|
||||||
|
t.mu.Lock()
|
||||||
|
modReq := t.modReq[req]
|
||||||
|
delete(t.modReq, req)
|
||||||
|
t.mu.Unlock()
|
||||||
|
cr.CancelRequest(modReq)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Transport) base() http.RoundTripper {
|
||||||
|
if t.Base != nil {
|
||||||
|
return t.Base
|
||||||
|
}
|
||||||
|
return http.DefaultTransport
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Transport) setModReq(orig, mod *http.Request) {
|
||||||
|
t.mu.Lock()
|
||||||
|
defer t.mu.Unlock()
|
||||||
|
if t.modReq == nil {
|
||||||
|
t.modReq = make(map[*http.Request]*http.Request)
|
||||||
|
}
|
||||||
|
if mod == nil {
|
||||||
|
delete(t.modReq, orig)
|
||||||
|
} else {
|
||||||
|
t.modReq[orig] = mod
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cloneRequest returns a clone of the provided *http.Request.
|
||||||
|
// The clone is a shallow copy of the struct and its Header map.
|
||||||
|
func cloneRequest(r *http.Request) *http.Request {
|
||||||
|
// shallow copy of the struct
|
||||||
|
r2 := new(http.Request)
|
||||||
|
*r2 = *r
|
||||||
|
// deep copy of the Header
|
||||||
|
r2.Header = make(http.Header, len(r.Header))
|
||||||
|
for k, s := range r.Header {
|
||||||
|
r2.Header[k] = append([]string(nil), s...)
|
||||||
|
}
|
||||||
|
return r2
|
||||||
|
}
|
||||||
|
|
||||||
|
type onEOFReader struct {
|
||||||
|
rc io.ReadCloser
|
||||||
|
fn func()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *onEOFReader) Read(p []byte) (n int, err error) {
|
||||||
|
n, err = r.rc.Read(p)
|
||||||
|
if err == io.EOF {
|
||||||
|
r.runFunc()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *onEOFReader) Close() error {
|
||||||
|
err := r.rc.Close()
|
||||||
|
r.runFunc()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *onEOFReader) runFunc() {
|
||||||
|
if fn := r.fn; fn != nil {
|
||||||
|
fn()
|
||||||
|
r.fn = nil
|
||||||
|
}
|
||||||
|
}
|
|
@ -550,6 +550,24 @@
|
||||||
"revision": "d8eeeb8bae8896dd8e1b7e514ab0d396c4f12a1b",
|
"revision": "d8eeeb8bae8896dd8e1b7e514ab0d396c4f12a1b",
|
||||||
"revisionTime": "2016-11-03T02:43:54Z"
|
"revisionTime": "2016-11-03T02:43:54Z"
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"checksumSHA1": "O3KUfEXQPfdQ+tCMpP2RAIRJJqY=",
|
||||||
|
"path": "github.com/markbates/goth",
|
||||||
|
"revision": "450379d2950a65070b23cc93c53436553add4484",
|
||||||
|
"revisionTime": "2017-02-06T19:46:32Z"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"checksumSHA1": "MkFKwLV3icyUo4oP0BgEs+7+R1Y=",
|
||||||
|
"path": "github.com/markbates/goth/gothic",
|
||||||
|
"revision": "450379d2950a65070b23cc93c53436553add4484",
|
||||||
|
"revisionTime": "2017-02-06T19:46:32Z"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"checksumSHA1": "ZFqznX3/ZW65I4QeepiHQdE69nA=",
|
||||||
|
"path": "github.com/markbates/goth/providers/github",
|
||||||
|
"revision": "450379d2950a65070b23cc93c53436553add4484",
|
||||||
|
"revisionTime": "2017-02-06T19:46:32Z"
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"checksumSHA1": "9FJUwn3EIgASVki+p8IHgWVC5vQ=",
|
"checksumSHA1": "9FJUwn3EIgASVki+p8IHgWVC5vQ=",
|
||||||
"path": "github.com/mattn/go-sqlite3",
|
"path": "github.com/mattn/go-sqlite3",
|
||||||
|
|
Loading…
Reference in New Issue