2017-02-22 00:14:37 -07:00
/*
Package gothic wraps common behaviour when using Goth. This makes it quick, and easy, to get up
and running with Goth. Of course, if you want complete control over how things flow, with regard
to the authentication process, feel free to use Goth directly.

See https://github.com/markbates/goth/blob/master/examples/main.go to see this in action.
*/
package gothic
import (
2018-02-18 22:10:51 -07:00
"bytes"
"compress/gzip"
2020-02-24 10:08:43 -07:00
"context"
2018-04-29 19:05:59 -06:00
"crypto/rand"
2018-02-18 22:10:51 -07:00
"encoding/base64"
2017-02-22 00:14:37 -07:00
"errors"
"fmt"
2018-04-29 19:05:59 -06:00
"io"
2018-02-18 22:10:51 -07:00
"io/ioutil"
2017-02-22 00:14:37 -07:00
"net/http"
2018-02-18 22:10:51 -07:00
"net/url"
2017-02-22 00:14:37 -07:00
"os"
2018-02-18 22:10:51 -07:00
"strings"
2017-02-22 00:14:37 -07:00
"github.com/gorilla/mux"
"github.com/gorilla/sessions"
"github.com/markbates/goth"
)
// SessionName is the key used to access the session store.
const SessionName = "_gothic_session"
// Store can/should be set by applications using gothic. The default is a cookie store.
var Store sessions . Store
var defaultStore sessions . Store
var keySet = false
2020-02-24 10:08:43 -07:00
type key int
// ProviderParamKey can be used as a key in context when passing in a provider
const ProviderParamKey key = iota
2017-02-22 00:14:37 -07:00
func init ( ) {
key := [ ] byte ( os . Getenv ( "SESSION_SECRET" ) )
keySet = len ( key ) != 0
2018-02-18 22:10:51 -07:00
cookieStore := sessions . NewCookieStore ( [ ] byte ( key ) )
cookieStore . Options . HttpOnly = true
Store = cookieStore
2017-02-22 00:14:37 -07:00
defaultStore = Store
}
/ *
2018-02-18 22:10:51 -07:00
BeginAuthHandler is a convenience handler for starting the authentication process .
2017-02-22 00:14:37 -07:00
It expects to be able to get the name of the provider from the query parameters
as either "provider" or ":provider" .
BeginAuthHandler will redirect the user to the appropriate authentication end - point
for the requested provider .
See https : //github.com/markbates/goth/examples/main.go to see this in action.
* /
func BeginAuthHandler ( res http . ResponseWriter , req * http . Request ) {
url , err := GetAuthURL ( res , req )
if err != nil {
res . WriteHeader ( http . StatusBadRequest )
fmt . Fprintln ( res , err )
return
}
http . Redirect ( res , req , url , http . StatusTemporaryRedirect )
}
// SetState sets the state string associated with the given request.
// If no state string is associated with the request, one will be generated.
// This state is sent to the provider and can be retrieved during the
// callback.
//
// A non-empty "state" query parameter is returned unchanged. Otherwise the
// result is 64 bytes read from crypto/rand, encoded with base64.URLEncoding.
// It panics if the randomness source cannot be read.
var SetState = func(req *http.Request) string {
	if supplied := req.URL.Query().Get("state"); supplied != "" {
		return supplied
	}

	// With no state query param, generate a random base64-encoded nonce so
	// that the state on the auth URL is unguessable, preventing CSRF attacks,
	// as described in
	//
	// https://auth0.com/docs/protocols/oauth2/oauth-state#keep-reading
	var nonce [64]byte
	if _, err := io.ReadFull(rand.Reader, nonce[:]); err != nil {
		panic("gothic: source of randomness unavailable: " + err.Error())
	}
	return base64.URLEncoding.EncodeToString(nonce[:])
}
// GetState gets the state returned by the provider during the callback.
// This is used to prevent CSRF attacks, see
// http://tools.ietf.org/html/rfc6749#section-10.12
//
// The value is read from the URL query string. The one exception is a POST
// request whose query string is empty; there it is read with req.FormValue.
var GetState = func(req *http.Request) string {
	query := req.URL.Query()
	if req.Method != http.MethodPost || query.Encode() != "" {
		return query.Get("state")
	}
	return req.FormValue("state")
}
/ *
GetAuthURL starts the authentication process with the requested provided .
It will return a URL that should be used to send users to .
It expects to be able to get the name of the provider from the query parameters
as either "provider" or ":provider" .
I would recommend using the BeginAuthHandler instead of doing all of these steps
yourself , but that ' s entirely up to you .
* /
func GetAuthURL ( res http . ResponseWriter , req * http . Request ) ( string , error ) {
if ! keySet && defaultStore == Store {
fmt . Println ( "goth/gothic: no SESSION_SECRET environment variable is set. The default cookie store is not available and any calls will fail. Ignore this warning if you are using a different store." )
}
providerName , err := GetProviderName ( req )
if err != nil {
return "" , err
}
provider , err := goth . GetProvider ( providerName )
if err != nil {
return "" , err
}
sess , err := provider . BeginAuth ( SetState ( req ) )
if err != nil {
return "" , err
}
url , err := sess . GetAuthURL ( )
if err != nil {
return "" , err
}
2018-03-12 17:35:46 -06:00
err = StoreInSession ( providerName , sess . Marshal ( ) , req , res )
2017-02-22 00:14:37 -07:00
if err != nil {
return "" , err
}
return url , err
}
/ *
CompleteUserAuth does what it says on the tin . It completes the authentication
process and fetches all of the basic information about the user from the provider .
It expects to be able to get the name of the provider from the query parameters
as either "provider" or ":provider" .
See https : //github.com/markbates/goth/examples/main.go to see this in action.
* /
var CompleteUserAuth = func ( res http . ResponseWriter , req * http . Request ) ( goth . User , error ) {
if ! keySet && defaultStore == Store {
fmt . Println ( "goth/gothic: no SESSION_SECRET environment variable is set. The default cookie store is not available and any calls will fail. Ignore this warning if you are using a different store." )
}
providerName , err := GetProviderName ( req )
if err != nil {
return goth . User { } , err
}
provider , err := goth . GetProvider ( providerName )
if err != nil {
return goth . User { } , err
}
2018-03-12 17:35:46 -06:00
value , err := GetFromSession ( providerName , req )
2017-02-22 00:14:37 -07:00
if err != nil {
return goth . User { } , err
}
2021-02-28 16:08:33 -07:00
defer Logout ( res , req )
2017-02-22 00:14:37 -07:00
sess , err := provider . UnmarshalSession ( value )
if err != nil {
return goth . User { } , err
}
2018-02-18 22:10:51 -07:00
err = validateState ( req , sess )
if err != nil {
return goth . User { } , err
}
2017-02-22 00:14:37 -07:00
user , err := provider . FetchUser ( sess )
if err == nil {
// user can be found with existing session data
return user , err
}
2020-10-15 23:06:27 -06:00
params := req . URL . Query ( )
if params . Encode ( ) == "" && req . Method == "POST" {
req . ParseForm ( )
params = req . Form
}
2017-02-22 00:14:37 -07:00
// get new token and retry fetch
2020-10-15 23:06:27 -06:00
_ , err = sess . Authorize ( provider , params )
2017-02-22 00:14:37 -07:00
if err != nil {
return goth . User { } , err
}
2018-03-12 17:35:46 -06:00
err = StoreInSession ( providerName , sess . Marshal ( ) , req , res )
2017-02-22 00:14:37 -07:00
if err != nil {
return goth . User { } , err
}
2018-02-18 22:10:51 -07:00
gu , err := provider . FetchUser ( sess )
return gu , err
}
// validateState ensures that the state token param from the original
// AuthURL matches the one included in the current (callback) request.
func validateState ( req * http . Request , sess goth . Session ) error {
rawAuthURL , err := sess . GetAuthURL ( )
if err != nil {
return err
}
authURL , err := url . Parse ( rawAuthURL )
if err != nil {
return err
}
2020-10-15 23:06:27 -06:00
reqState := GetState ( req )
2018-02-18 22:10:51 -07:00
originalState := authURL . Query ( ) . Get ( "state" )
2020-10-15 23:06:27 -06:00
if originalState != "" && ( originalState != reqState ) {
2018-02-18 22:10:51 -07:00
return errors . New ( "state token mismatch" )
}
return nil
}
// Logout invalidates a user session.
func Logout ( res http . ResponseWriter , req * http . Request ) error {
session , err := Store . Get ( req , SessionName )
if err != nil {
return err
}
session . Options . MaxAge = - 1
session . Values = make ( map [ interface { } ] interface { } )
err = session . Save ( req , res )
if err != nil {
return errors . New ( "Could not delete user session " )
}
return nil
2017-02-22 00:14:37 -07:00
}
// GetProviderName is a function used to get the name of a provider
// for a given request. By default, this provider is fetched from
// the URL query string. If you provide it in a different way,
// assign your own function to this variable that returns the provider
// name for your request.
var GetProviderName = getProviderName
func getProviderName ( req * http . Request ) ( string , error ) {
2018-02-18 22:10:51 -07:00
// try to get it from the url param "provider"
if p := req . URL . Query ( ) . Get ( "provider" ) ; p != "" {
return p , nil
2017-02-22 00:14:37 -07:00
}
2018-02-18 22:10:51 -07:00
// try to get it from the url param ":provider"
if p := req . URL . Query ( ) . Get ( ":provider" ) ; p != "" {
return p , nil
}
// try to get it from the context's value of "provider" key
if p , ok := mux . Vars ( req ) [ "provider" ] ; ok {
return p , nil
2017-02-22 00:14:37 -07:00
}
2018-02-18 22:10:51 -07:00
// try to get it from the go-context's value of "provider" key
if p , ok := req . Context ( ) . Value ( "provider" ) . ( string ) ; ok {
return p , nil
}
2020-02-24 10:08:43 -07:00
// try to get it from the go-context's value of providerContextKey key
if p , ok := req . Context ( ) . Value ( ProviderParamKey ) . ( string ) ; ok {
return p , nil
}
2019-09-12 20:15:36 -06:00
// As a fallback, loop over the used providers, if we already have a valid session for any provider (ie. user has already begun authentication with a provider), then return that provider name
providers := goth . GetProviders ( )
session , _ := Store . Get ( req , SessionName )
for _ , provider := range providers {
p := provider . Name ( )
value := session . Values [ p ]
if _ , ok := value . ( string ) ; ok {
return p , nil
}
}
2018-02-18 22:10:51 -07:00
// if not found then return an empty string with the corresponding error
return "" , errors . New ( "you must select a provider" )
2017-02-22 00:14:37 -07:00
}
2020-02-24 10:08:43 -07:00
// GetContextWithProvider returns a copy of req whose context carries provider
// under ProviderParamKey. The default GetProviderName checks that key.
func GetContextWithProvider(req *http.Request, provider string) *http.Request {
	ctx := context.WithValue(req.Context(), ProviderParamKey, provider)
	return req.WithContext(ctx)
}
2018-03-12 17:35:46 -06:00
// StoreInSession stores a specified key/value pair in the session.
func StoreInSession ( key string , value string , req * http . Request , res http . ResponseWriter ) error {
session , _ := Store . New ( req , SessionName )
2017-02-22 00:14:37 -07:00
2018-02-18 22:10:51 -07:00
if err := updateSessionValue ( session , key , value ) ; err != nil {
return err
}
2017-02-22 00:14:37 -07:00
return session . Save ( req , res )
}
2018-03-12 17:35:46 -06:00
// GetFromSession retrieves a previously-stored value from the session.
// If no value has previously been stored at the specified key, it will return an error.
func GetFromSession ( key string , req * http . Request ) ( string , error ) {
2018-02-18 22:10:51 -07:00
session , _ := Store . Get ( req , SessionName )
value , err := getSessionValue ( session , key )
if err != nil {
return "" , errors . New ( "could not find a matching session for this request" )
}
2017-02-22 00:14:37 -07:00
2018-02-18 22:10:51 -07:00
return value , nil
}
func getSessionValue ( session * sessions . Session , key string ) ( string , error ) {
2017-02-22 00:14:37 -07:00
value := session . Values [ key ]
if value == nil {
2018-02-18 22:10:51 -07:00
return "" , fmt . Errorf ( "could not find a matching session for this request" )
2017-02-22 00:14:37 -07:00
}
2018-02-18 22:10:51 -07:00
rdata := strings . NewReader ( value . ( string ) )
r , err := gzip . NewReader ( rdata )
if err != nil {
return "" , err
}
s , err := ioutil . ReadAll ( r )
if err != nil {
return "" , err
}
return string ( s ) , nil
}
// updateSessionValue gzip-compresses value and stores the compressed bytes,
// as a string, under key in session.Values. It does not save the session;
// StoreInSession does that after calling it.
func updateSessionValue(session *sessions.Session, key, value string) error {
	var b bytes.Buffer
	gz := gzip.NewWriter(&b)
	if _, err := gz.Write([]byte(value)); err != nil {
		return err
	}
	// Close flushes any pending data and writes the gzip footer, so a separate
	// Flush beforehand is redundant and would only lengthen the stored value.
	if err := gz.Close(); err != nil {
		return err
	}
	session.Values[key] = b.String()
	return nil
}