package login
import (
"context"
"errors"
"fmt"
"io/fs"
"net/http"
"net/url"
"reflect"
"strings"
"time"
"github.com/golang-jwt/jwt/v4"
"github.com/markbates/goth"
"github.com/markbates/goth/gothic"
"github.com/pquerna/otp"
"github.com/pquerna/otp/totp"
"github.com/qor5/web"
"github.com/qor5/x/i18n"
h "github.com/theplant/htmlgo"
"golang.org/x/text/language"
"gorm.io/gorm"
)
// Sentinel errors returned by the login flows. Handlers compare against them
// to choose the fail code shown to the user.
var (
// ErrUserNotFound is returned when no user matches the given id or account
// (findUserByID, authUserPass).
ErrUserNotFound = errors.New("user not found")
// ErrPasswordChanged is not returned by code in this section; presumably it
// is used elsewhere in the package — confirm before relying on it.
ErrPasswordChanged = errors.New("password changed")
// ErrWrongPassword is returned when the supplied (old) password is incorrect.
ErrWrongPassword = errors.New("wrong password")
// ErrUserLocked is returned when the account was already locked before this
// login attempt.
ErrUserLocked = errors.New("user locked")
// ErrUserGetLocked is returned when this failed attempt pushed the retry count
// to maxRetryCount and the account has just been locked.
ErrUserGetLocked = errors.New("user get locked")
// ErrWrongTOTPCode is returned when a TOTP passcode does not validate.
ErrWrongTOTPCode = errors.New("wrong totp code")
// ErrTOTPCodeHasBeenUsed is returned when a passcode equals the last one the
// user consumed within the reuse window.
ErrTOTPCodeHasBeenUsed = errors.New("totp code has been used")
// ErrEmptyPassword is returned by ChangePassword for an empty new password.
ErrEmptyPassword = errors.New("empty password")
// ErrPasswordNotMatch is returned by ChangePassword when the confirmation
// differs from the new password.
ErrPasswordNotMatch = errors.New("password not match")
)
// HomeURLFunc returns the URL a user is redirected to after a successful login
// when no continue URL was saved.
type HomeURLFunc func(r *http.Request, user interface{}) string
// HookFunc is the signature of all Builder hooks. user is the user object the
// flow is dealing with (may be nil in failure paths), and extraVals carries
// hook-specific values documented on each hook setter. A non-nil error aborts
// the calling flow.
type HookFunc func(r *http.Request, user interface{}, extraVals ...interface{}) error
// Provider describes one OAuth provider offered on the login page.
type Provider struct {
// Goth is the underlying goth provider; OAuthProviders registers it via
// goth.UseProviders.
Goth goth.Provider
// Key, Text and Logo are not read in this section; presumably they identify
// the provider and render its login button — confirm in the view code.
Key string
Text string
Logo h.HTMLComponent
}
// CookieConfig holds the attributes applied to the cookies the Builder sets.
type CookieConfig struct {
// Path and Domain are applied to the auth cookies and to the continue-URL
// cookie.
Path string
Domain string
// SameSite is applied to the auth cookies. SameSiteStrictMode also makes the
// OAuth callback go through an intermediate HTML page.
SameSite http.SameSite
}
// TOTPConfig configures TOTP key generation.
type TOTPConfig struct {
// Issuer is passed to totp.Generate. It must be non-empty when TOTP is
// enabled, otherwise Builder.TOTP panics.
Issuer string
}
// RecaptchaConfig holds the Google reCAPTCHA key pair. Both fields must be
// non-empty when reCAPTCHA is enabled, otherwise Builder.Recaptcha panics.
type RecaptchaConfig struct {
// SiteKey presumably goes to the login page; SecretKey presumably goes to the
// server-side token check — confirm in recaptchaTokenCheck and the views.
SiteKey string
SecretKey string
}
// Builder configures and mounts the login, logout, password, TOTP and OAuth
// handlers and pages. Create one with New, configure it with the chainable
// setters, then call Mount (pages + API) or MountAPI (API only).
type Builder struct {
// secret signs session tokens; MountAPI panics when it is empty.
secret string
// providers are the OAuth providers set via OAuthProviders.
providers []*Provider
// authCookieName names the session-token cookie.
authCookieName string
// authSecureCookieName names the extra cookie signed with secret plus the
// user's secure salt (only set when the user model is a SessionSecurer).
authSecureCookieName string
// continueUrlCookieName names the cookie holding the URL to return to after
// login.
continueUrlCookieName string
// seconds
sessionMaxAge int
cookieConfig CookieConfig
totpEnabled bool
totpConfig TOTPConfig
recaptchaEnabled bool
recaptchaConfig RecaptchaConfig
// autoExtendSession: extend the session if successfully authenticated.
autoExtendSession bool
// maxRetryCount is the number of failed password attempts that locks the
// account; <= 0 disables locking.
maxRetryCount int
// noForgetPasswordLink, when true, skips mounting the forget-password pages
// and the send-reset-link endpoint.
noForgetPasswordLink bool
i18nBuilder *i18n.Builder
// Common URLs
homePageURLFunc HomeURLFunc
loginPageURL string
LogoutURL string
// TOTP URLs
validateTOTPURL string
totpSetupPageURL string
totpValidatePageURL string
// OAuth URLs
oauthBeginURL string
oauthCallbackURL string
oauthCallbackCompleteURL string
// UserPass URLs
passwordLoginURL string
resetPasswordURL string
resetPasswordPageURL string
changePasswordURL string
changePasswordPageURL string
forgetPasswordPageURL string
sendResetPasswordLinkURL string
resetPasswordLinkSentPageURL string
// Page renderers mounted by Mount.
loginPageFunc web.PageFunc
forgetPasswordPageFunc web.PageFunc
resetPasswordLinkSentPageFunc web.PageFunc
resetPasswordPageFunc web.PageFunc
changePasswordPageFunc web.PageFunc
totpSetupPageFunc web.PageFunc
totpValidatePageFunc web.PageFunc
// Hooks, already wrapped by wrapHook when set through the public setters.
beforeSetPasswordHook HookFunc
afterLoginHook HookFunc
afterFailedToLoginHook HookFunc
afterUserLockedHook HookFunc
afterLogoutHook HookFunc
afterConfirmSendResetPasswordLinkHook HookFunc
afterResetPasswordHook HookFunc
afterChangePasswordHook HookFunc
afterExtendSessionHook HookFunc
afterTOTPCodeReusedHook HookFunc
afterOAuthCompleteHook HookFunc
// db is required by MountAPI whenever userModel is set.
db *gorm.DB
// userModel is the prototype passed to UserModel; its implemented interfaces
// decide which feature flags below are switched on.
userModel interface{}
// snakePrimaryField is the primary-key column used by findUserByID.
snakePrimaryField string
// tUser is the reflected model type newUserObject instantiates.
tUser reflect.Type
userPassEnabled bool
oauthEnabled bool
sessionSecureEnabled bool
}
// New returns a Builder with the default configuration:
//   - the /auth/... routes listed below, with "/" as the home page;
//   - one-hour sessions and cookies scoped to "/" with SameSite=Strict;
//   - automatic session extension and account locking after five failed
//     attempts;
//   - TOTP enabled with issuer "QOR5";
//   - an i18n builder supporting English, Simplified Chinese and Japanese;
//   - the built-in page renderers.
//
// A Secret must still be set before mounting.
func New() *Builder {
	b := &Builder{
		// cookies and session
		authCookieName:        "auth",
		authSecureCookieName:  "qor5_auth_secure",
		continueUrlCookieName: "qor5_continue_url",
		sessionMaxAge:         60 * 60,
		cookieConfig: CookieConfig{
			Path:     "/",
			Domain:   "",
			SameSite: http.SameSiteStrictMode,
		},
		autoExtendSession: true,
		maxRetryCount:     5,
		totpEnabled:       true,
		totpConfig:        TOTPConfig{Issuer: "QOR5"},

		// common routes
		homePageURLFunc: func(r *http.Request, user interface{}) string {
			return "/"
		},
		loginPageURL: "/auth/login",
		LogoutURL:    "/auth/logout",

		// TOTP routes
		validateTOTPURL:     "/auth/2fa/totp/do",
		totpSetupPageURL:    "/auth/2fa/totp/setup",
		totpValidatePageURL: "/auth/2fa/totp/validate",

		// OAuth routes
		oauthBeginURL:            "/auth/begin",
		oauthCallbackURL:         "/auth/callback",
		oauthCallbackCompleteURL: "/auth/callback-complete",

		// user/password routes
		passwordLoginURL:             "/auth/userpass/login",
		resetPasswordURL:             "/auth/do-reset-password",
		resetPasswordPageURL:         "/auth/reset-password",
		changePasswordURL:            "/auth/do-change-password",
		changePasswordPageURL:        "/auth/change-password",
		forgetPasswordPageURL:        "/auth/forget-password",
		sendResetPasswordLinkURL:     "/auth/send-reset-password-link",
		resetPasswordLinkSentPageURL: "/auth/reset-password-link-sent",
	}

	translations := i18n.New()
	translations.SupportLanguages(language.English, language.SimplifiedChinese, language.Japanese)
	b.I18n(translations)

	helper := b.ViewHelper()
	b.loginPageFunc = defaultLoginPage(helper)
	b.forgetPasswordPageFunc = defaultForgetPasswordPage(helper)
	b.resetPasswordLinkSentPageFunc = defaultResetPasswordLinkSentPage(helper)
	b.resetPasswordPageFunc = defaultResetPasswordPage(helper)
	b.changePasswordPageFunc = defaultChangePasswordPage(helper)
	b.totpSetupPageFunc = defaultTOTPSetupPage(helper)
	b.totpValidatePageFunc = defaultTOTPValidatePage(helper)
	return b
}
// Secret sets the key used to sign session tokens. MountAPI refuses to run
// without one.
func (b *Builder) Secret(v string) (r *Builder) {
	r = b
	r.secret = v
	return
}

// CookieConfig overrides the Path, Domain and SameSite attributes used for
// the cookies this Builder sets.
func (b *Builder) CookieConfig(v CookieConfig) (r *Builder) {
	r = b
	r.cookieConfig = v
	return
}
// Recaptcha toggles Google reCAPTCHA verification on the password-login and
// send-reset-password-link endpoints. An optional config replaces the stored
// key pair. Enabling without both SiteKey and SecretKey panics.
func (b *Builder) Recaptcha(enable bool, config ...RecaptchaConfig) (r *Builder) {
	b.recaptchaEnabled = enable
	if len(config) != 0 {
		b.recaptchaConfig = config[0]
	}
	if !enable {
		return b
	}
	switch {
	case b.recaptchaConfig.SiteKey == "":
		panic("SiteKey is empty")
	case b.recaptchaConfig.SecretKey == "":
		panic("SecretKey is empty")
	}
	return b
}
// OAuthProviders enables the OAuth routes and registers every given
// provider's goth implementation globally via goth.UseProviders. Calling it
// with no providers leaves the builder untouched.
func (b *Builder) OAuthProviders(vs ...*Provider) (r *Builder) {
	if len(vs) > 0 {
		b.oauthEnabled = true
		b.providers = vs
		gps := make([]goth.Provider, 0, len(vs))
		for i := range vs {
			gps = append(gps, vs[i].Goth)
		}
		goth.UseProviders(gps...)
	}
	return b
}
// AuthCookieName sets the name of the session-token cookie (default "auth").
func (b *Builder) AuthCookieName(v string) (r *Builder) {
	r = b
	r.authCookieName = v
	return
}

// HomeURLFunc sets the function choosing where a user lands after login when
// no continue URL was saved (default: always "/").
func (b *Builder) HomeURLFunc(v HomeURLFunc) (r *Builder) {
	r = b
	r.homePageURLFunc = v
	return
}

// LoginPageURL sets the login page path (default "/auth/login").
func (b *Builder) LoginPageURL(v string) (r *Builder) {
	r = b
	r.loginPageURL = v
	return
}

// ResetPasswordPageURL sets the path the emailed reset link points to
// (default "/auth/reset-password").
func (b *Builder) ResetPasswordPageURL(v string) (r *Builder) {
	r = b
	r.resetPasswordPageURL = v
	return
}

// ChangePasswordPageURL sets the change-password page path
// (default "/auth/change-password").
func (b *Builder) ChangePasswordPageURL(v string) (r *Builder) {
	r = b
	r.changePasswordPageURL = v
	return
}

// ForgetPasswordPageURL sets the forget-password page path
// (default "/auth/forget-password").
func (b *Builder) ForgetPasswordPageURL(v string) (r *Builder) {
	r = b
	r.forgetPasswordPageURL = v
	return
}

// ResetPasswordLinkSentPageURL sets the confirmation page shown after a reset
// link was sent (default "/auth/reset-password-link-sent").
func (b *Builder) ResetPasswordLinkSentPageURL(v string) (r *Builder) {
	r = b
	r.resetPasswordLinkSentPageURL = v
	return
}

// TOTPSetupPageURL sets the TOTP enrolment page path
// (default "/auth/2fa/totp/setup").
func (b *Builder) TOTPSetupPageURL(v string) (r *Builder) {
	r = b
	r.totpSetupPageURL = v
	return
}

// TOTPValidatePageURL sets the TOTP code entry page path
// (default "/auth/2fa/totp/validate").
func (b *Builder) TOTPValidatePageURL(v string) (r *Builder) {
	r = b
	r.totpValidatePageURL = v
	return
}

// LoginPageFunc replaces the login page renderer.
func (b *Builder) LoginPageFunc(v web.PageFunc) (r *Builder) {
	r = b
	r.loginPageFunc = v
	return
}

// ForgetPasswordPageFunc replaces the forget-password page renderer.
func (b *Builder) ForgetPasswordPageFunc(v web.PageFunc) (r *Builder) {
	r = b
	r.forgetPasswordPageFunc = v
	return
}

// ResetPasswordLinkSentPageFunc replaces the "link sent" page renderer.
func (b *Builder) ResetPasswordLinkSentPageFunc(v web.PageFunc) (r *Builder) {
	r = b
	r.resetPasswordLinkSentPageFunc = v
	return
}

// ResetPasswordPageFunc replaces the reset-password page renderer.
func (b *Builder) ResetPasswordPageFunc(v web.PageFunc) (r *Builder) {
	r = b
	r.resetPasswordPageFunc = v
	return
}

// ChangePasswordPageFunc replaces the change-password page renderer.
func (b *Builder) ChangePasswordPageFunc(v web.PageFunc) (r *Builder) {
	r = b
	r.changePasswordPageFunc = v
	return
}

// TOTPSetupPageFunc replaces the TOTP setup page renderer.
func (b *Builder) TOTPSetupPageFunc(v web.PageFunc) (r *Builder) {
	r = b
	r.totpSetupPageFunc = v
	return
}

// TOTPValidatePageFunc replaces the TOTP validation page renderer.
func (b *Builder) TOTPValidatePageFunc(v web.PageFunc) (r *Builder) {
	r = b
	r.totpValidatePageFunc = v
	return
}
// wrapHook adapts a user-supplied hook so that GetCurrentUser works inside it.
// When the hook is given a non-nil user but the request context has none, the
// user is stored under UserKey on a derived request before the hook runs.
// A nil hook stays nil so callers can keep their nil checks.
func (b *Builder) wrapHook(v HookFunc) HookFunc {
	if v == nil {
		return nil
	}
	wrapped := func(r *http.Request, user interface{}, extraVals ...interface{}) error {
		req := r
		if user != nil && GetCurrentUser(r) == nil {
			req = r.WithContext(context.WithValue(r.Context(), UserKey, user))
		}
		return v(req, user, extraVals...)
	}
	return wrapped
}
// BeforeSetPassword registers a hook run by the reset- and change-password
// flows before the new password is stored; a non-nil error aborts the change.
// Extra vals: the new password (string).
func (b *Builder) BeforeSetPassword(v HookFunc) (r *Builder) {
	r = b
	r.beforeSetPasswordHook = b.wrapHook(v)
	return
}

// AfterLogin registers a hook run after successful authentication, before the
// session cookies are written. The request already carries the new session
// token as the auth cookie. A non-nil error aborts the login.
func (b *Builder) AfterLogin(v HookFunc) (r *Builder) {
	r = b
	r.afterLoginHook = b.wrapHook(v)
	return
}

// AfterFailedToLogin registers a hook run whenever a password, TOTP or OAuth
// login attempt fails.
// Extra vals: the login error.
func (b *Builder) AfterFailedToLogin(v HookFunc) (r *Builder) {
	r = b
	r.afterFailedToLoginHook = b.wrapHook(v)
	return
}

// AfterUserLocked registers a hook run when a failed password attempt has
// just locked the account.
func (b *Builder) AfterUserLocked(v HookFunc) (r *Builder) {
	r = b
	r.afterUserLockedHook = b.wrapHook(v)
	return
}

// AfterLogout registers a hook run after the auth cookies were cleared, only
// when the request has a current user.
func (b *Builder) AfterLogout(v HookFunc) (r *Builder) {
	r = b
	r.afterLogoutHook = b.wrapHook(v)
	return
}

// AfterConfirmSendResetPasswordLink registers a hook that receives the
// generated reset link (e.g. to deliver it).
// Extra vals: the reset link (string).
func (b *Builder) AfterConfirmSendResetPasswordLink(v HookFunc) (r *Builder) {
	r = b
	r.afterConfirmSendResetPasswordLinkHook = b.wrapHook(v)
	return
}

// AfterResetPassword registers a hook run after a password was reset via a
// reset link.
func (b *Builder) AfterResetPassword(v HookFunc) (r *Builder) {
	r = b
	r.afterResetPasswordHook = b.wrapHook(v)
	return
}

// AfterChangePassword registers a hook run after a logged-in user changed
// their password.
func (b *Builder) AfterChangePassword(v HookFunc) (r *Builder) {
	r = b
	r.afterChangePasswordHook = b.wrapHook(v)
	return
}

// AfterExtendSession registers a hook for session extension.
// Extra vals: the old session token.
func (b *Builder) AfterExtendSession(v HookFunc) (r *Builder) {
	r = b
	r.afterExtendSessionHook = b.wrapHook(v)
	return
}

// AfterTOTPCodeReused registers a hook run when a TOTP code is submitted
// again within the reuse window; its non-nil error is returned to the caller
// in place of ErrTOTPCodeHasBeenUsed.
func (b *Builder) AfterTOTPCodeReused(v HookFunc) (r *Builder) {
	r = b
	r.afterTOTPCodeReusedHook = b.wrapHook(v)
	return
}

// AfterOAuthComplete registers a hook run right after the OAuth provider
// returned; user is the goth.User.
func (b *Builder) AfterOAuthComplete(v HookFunc) (r *Builder) {
	r = b
	r.afterOAuthCompleteHook = b.wrapHook(v)
	return
}
// SessionMaxAge sets the session lifetime in seconds (default 3600). It is
// passed to the session claims and used as the auth cookies' MaxAge.
func (b *Builder) SessionMaxAge(v int) (r *Builder) {
	r = b
	r.sessionMaxAge = v
	return
}

// AutoExtendSession controls whether the session is extended after each
// successful authentication (default true).
func (b *Builder) AutoExtendSession(v bool) (r *Builder) {
	r = b
	r.autoExtendSession = v
	return
}

// MaxRetryCount sets how many failed password attempts lock the account
// (default 5). A value <= 0 means no max retry count limit.
func (b *Builder) MaxRetryCount(v int) (r *Builder) {
	r = b
	r.maxRetryCount = v
	return
}
// TOTP toggles TOTP two-factor authentication for password logins (enabled by
// default). An optional config replaces the stored one. Enabling with an empty
// Issuer panics.
func (b *Builder) TOTP(enable bool, config ...TOTPConfig) (r *Builder) {
	b.totpEnabled = enable
	if len(config) != 0 {
		b.totpConfig = config[0]
	}
	if enable && b.totpConfig.Issuer == "" {
		panic("Issuer is empty")
	}
	return b
}
// NoForgetPasswordLink, when true, stops Mount/MountAPI from registering the
// forget-password pages and the send-reset-link endpoint.
func (b *Builder) NoForgetPasswordLink(v bool) (r *Builder) {
	r = b
	r.noForgetPasswordLink = v
	return
}

// DB sets the database handle used for all user lookups and updates. It is
// required by MountAPI whenever a user model is configured.
func (b *Builder) DB(v *gorm.DB) (r *Builder) {
	r = b
	r.db = v
	return
}
// I18n registers the login module's English, Simplified Chinese and Japanese
// messages on v and makes v the Builder's i18n builder.
func (b *Builder) I18n(v *i18n.Builder) (r *Builder) {
	r = b
	v.RegisterForModule(language.English, I18nLoginKey, Messages_en_US).
		RegisterForModule(language.SimplifiedChinese, I18nLoginKey, Messages_zh_CN).
		RegisterForModule(language.Japanese, I18nLoginKey, Messages_ja_JP)
	r.i18nBuilder = v
	return
}
// GetSessionMaxAge reports the configured session lifetime in seconds.
func (b *Builder) GetSessionMaxAge() int {
	maxAge := b.sessionMaxAge
	return maxAge
}

// ViewHelper returns a helper bound to this Builder for use by page renderers.
func (b *Builder) ViewHelper() *ViewHelper {
	vh := &ViewHelper{b: b}
	return vh
}
// UserModel sets the user model prototype and caches its reflected type and
// primary-key column. Features are switched on according to the interfaces m
// implements: UserPasser (password login), OAuthUser (OAuth) and
// SessionSecurer (per-user secure cookie). This never switches a feature off.
func (b *Builder) UserModel(m interface{}) (r *Builder) {
	b.userModel = m
	b.tUser = underlyingReflectType(reflect.TypeOf(m))
	b.snakePrimaryField = snakePrimaryField(m)

	_, isUserPasser := m.(UserPasser)
	_, isOAuthUser := m.(OAuthUser)
	_, isSessionSecurer := m.(SessionSecurer)
	if isUserPasser {
		b.userPassEnabled = true
	}
	if isOAuthUser {
		b.oauthEnabled = true
	}
	if isSessionSecurer {
		b.sessionSecureEnabled = true
	}
	return b
}
// newUserObject returns a pointer to a fresh zero value of the user model
// type, used as a query/update target.
func (b *Builder) newUserObject() interface{} {
	ptr := reflect.New(b.tUser)
	return ptr.Interface()
}
// findUserByID loads the user whose primary-key column equals id into a fresh
// user-model object.
// It returns ErrUserNotFound when no row matches; any other database error is
// returned unchanged.
func (b *Builder) findUserByID(id string) (user interface{}, err error) {
	m := b.newUserObject()
	// The column name comes from the model's reflected primary field (set in
	// UserModel), not from request input; id stays a bind parameter.
	err = b.db.Where(fmt.Sprintf("%s = ?", b.snakePrimaryField), id).
		First(m).
		Error
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return nil, ErrUserNotFound
	}
	if err != nil {
		return nil, err
	}
	return m, nil
}
// completeUserAuthCallback is for url "/auth/{provider}/callback".
//
// When the auth cookies use SameSite=Strict, the handler does not finish the
// login on this request. It answers with a small HTML page that navigates to
// oauthCallbackCompleteURL with the same query string, and that follow-up
// request does the real work. This is presumably because Strict cookies are
// not attached to the cross-site redirect coming back from the provider, but
// are attached to the same-site follow-up. Other SameSite modes are handled
// directly by completeUserAuthCallbackComplete.
func (b *Builder) completeUserAuthCallback(w http.ResponseWriter, r *http.Request) {
	if b.cookieConfig.SameSite != http.SameSiteStrictMode {
		b.completeUserAuthCallbackComplete(w, r)
		return
	}
	// Query().Encode() percent-escapes every key and value, so request data
	// cannot inject quotes or angle brackets. The href is still HTML-escaped
	// because oauthCallbackCompleteURL itself is configurable.
	completeURL := fmt.Sprintf("%s?%s", b.oauthCallbackCompleteURL, r.URL.Query().Encode())
	attrEscaper := strings.NewReplacer(`&`, "&amp;", `'`, "&#39;", `<`, "&lt;", `>`, "&gt;", `"`, "&#34;")
	w.Header().Set("Content-Type", "text/html; charset=utf-8")
	// The script navigates automatically; the link is a fallback when scripts
	// are disabled.
	fmt.Fprintf(w, `<script>
window.onload = function() {
	window.location.href = %q;
};
</script>
<a href="%s">complete</a>
`, completeURL, attrEscaper.Replace(completeURL))
}
// completeUserAuthCallbackComplete finishes an OAuth login.
//
// It completes the gothic flow and runs afterOAuthCompleteHook. When a user
// model is configured, it resolves the local user: first by provider user id,
// then by provider identifier (email), in which case the OAuth user id is
// bound to that user. It then runs afterLoginHook, sets the session cookies
// and redirects to the continue URL or the home page.
//
// Any failure leaves err non-nil. The deferred block then runs
// afterFailedToLoginHook and redirects to LogoutURL.
func (b *Builder) completeUserAuthCallbackComplete(w http.ResponseWriter, r *http.Request) {
	var err error
	var user interface{}
	failRedirectURL := b.LogoutURL
	defer func() {
		if err != nil {
			if b.afterFailedToLoginHook != nil {
				if herr := b.afterFailedToLoginHook(r, user, err); herr != nil {
					setNoticeOrFailCodeFlash(w, herr, FailCodeSystemError)
				}
			}
			http.Redirect(w, r, failRedirectURL, http.StatusFound)
		}
	}()

	var ouser goth.User
	ouser, err = gothic.CompleteUserAuth(w, r)
	if err != nil {
		setFailCodeFlash(w, FailCodeCompleteUserAuthFailed)
		return
	}
	if b.afterOAuthCompleteHook != nil {
		if err = b.afterOAuthCompleteHook(r, ouser); err != nil {
			setNoticeOrFailCodeFlash(w, err, FailCodeSystemError)
			return
		}
	}

	userID := ouser.UserID
	if b.userModel != nil {
		// OAuth can be enabled via OAuthProviders even when the user model is
		// not an OAuthUser. Fail the login instead of panicking on the assertion.
		oauthModel, ok := b.userModel.(OAuthUser)
		if !ok {
			err = errors.New("login: user model does not implement OAuthUser")
			setFailCodeFlash(w, FailCodeSystemError)
			return
		}
		user, err = oauthModel.FindUserByOAuthUserID(b.db, b.newUserObject(), ouser.Provider, ouser.UserID)
		if err != nil {
			if !errors.Is(err, gorm.ErrRecordNotFound) {
				setFailCodeFlash(w, FailCodeSystemError)
				return
			}
			// TODO: maybe the identifier of some providers is not email
			identifier := ouser.Email
			user, err = oauthModel.FindUserByOAuthIdentifier(b.db, b.newUserObject(), ouser.Provider, identifier)
			if err != nil {
				if errors.Is(err, gorm.ErrRecordNotFound) {
					setFailCodeFlash(w, FailCodeUserNotFound)
				} else {
					setFailCodeFlash(w, FailCodeSystemError)
				}
				return
			}
			// First OAuth login for this user: remember the provider's user id.
			err = user.(OAuthUser).InitOAuthUserID(b.db, b.newUserObject(), ouser.Provider, identifier, ouser.UserID)
			if err != nil {
				setFailCodeFlash(w, FailCodeSystemError)
				return
			}
		}
		userID = objectID(user)
	}

	claims := UserClaims{
		Provider:         ouser.Provider,
		Email:            ouser.Email,
		Name:             ouser.Name,
		UserID:           userID,
		AvatarURL:        ouser.AvatarURL,
		RegisteredClaims: b.genBaseSessionClaim(userID),
	}
	// Without a user model, the claims themselves stand in for the user.
	if user == nil {
		user = &claims
	}
	if b.afterLoginHook != nil {
		setCookieForRequest(r, &http.Cookie{Name: b.authCookieName, Value: b.mustGetSessionToken(claims)})
		if err = b.afterLoginHook(r, user); err != nil {
			setNoticeOrFailCodeFlash(w, err, FailCodeSystemError)
			return
		}
	}
	if err = b.setSecureCookiesByClaims(w, user, claims); err != nil {
		setFailCodeFlash(w, FailCodeSystemError)
		return
	}
	redirectURL := b.homePageURLFunc(r, user)
	if v := b.getContinueURL(w, r); v != "" {
		redirectURL = v
	}
	http.Redirect(w, r, redirectURL, http.StatusFound)
}
// authUserPass checks account/password against the user model.
//
// The user is returned whenever the account exists, even alongside an error,
// so failure hooks can see it.
//
// Returned errors:
//   - ErrUserNotFound: no matching account.
//   - ErrUserLocked: the account was already locked.
//   - ErrWrongPassword: the password is incorrect.
//   - ErrUserGetLocked: this failure reached maxRetryCount and locked the
//     account.
//   - Database errors from the retry/lock/unlock updates.
//
// A successful login resets a non-zero retry count via UnlockUser.
func (b *Builder) authUserPass(account string, password string) (user interface{}, err error) {
	user, err = b.userModel.(UserPasser).FindUser(b.db, b.newUserObject(), account)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, ErrUserNotFound
		}
		return nil, err
	}
	u := user.(UserPasser)
	if u.GetLocked() {
		return user, ErrUserLocked
	}
	if !u.IsPasswordCorrect(password) {
		// maxRetryCount <= 0 disables counting and locking entirely.
		if b.maxRetryCount > 0 {
			if err = u.IncreaseRetryCount(b.db, b.newUserObject()); err != nil {
				return user, err
			}
			if u.GetLoginRetryCount() >= b.maxRetryCount {
				if err = u.LockUser(b.db, b.newUserObject()); err != nil {
					return user, err
				}
				return user, ErrUserGetLocked
			}
		}
		return user, ErrWrongPassword
	}
	if u.GetLoginRetryCount() != 0 {
		if err = u.UnlockUser(b.db, b.newUserObject()); err != nil {
			return user, err
		}
	}
	return user, nil
}
// userpassLogin handles POST passwordLoginURL. Non-POST requests get 404.
//
// It checks the reCAPTCHA token when enabled, then authenticates
// account/password.
//   - Without TOTP: it runs afterLoginHook, sets the session cookies and
//     redirects to the continue URL or the home page.
//   - With TOTP: it sets session cookies whose claims are not yet
//     TOTP-validated and redirects to the validation page, or, for users
//     without TOTP, generates a secret and redirects to the setup page.
//
// Failures after the reCAPTCHA check leave err non-nil. The deferred block
// then runs afterFailedToLoginHook and redirects to LogoutURL.
func (b *Builder) userpassLogin(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		w.WriteHeader(http.StatusNotFound)
		return
	}
	// check reCAPTCHA token
	if b.recaptchaEnabled {
		token := r.FormValue("token")
		if !recaptchaTokenCheck(b, token) {
			setFailCodeFlash(w, FailCodeIncorrectRecaptchaToken)
			http.Redirect(w, r, b.loginPageURL, http.StatusFound)
			return
		}
	}

	var err error
	var user interface{}
	failRedirectURL := b.LogoutURL
	// A return with err != nil is a failed login. A return with err == nil
	// must already have written its own redirect.
	defer func() {
		if err != nil {
			if b.afterFailedToLoginHook != nil {
				if herr := b.afterFailedToLoginHook(r, user, err); herr != nil {
					setNoticeOrFailCodeFlash(w, herr, FailCodeSystemError)
				}
			}
			http.Redirect(w, r, failRedirectURL, http.StatusFound)
		}
	}()

	account := r.FormValue("account")
	password := r.FormValue("password")
	user, err = b.authUserPass(account, password)
	if err != nil {
		if err == ErrUserGetLocked && b.afterUserLockedHook != nil {
			// Keep the hook's result out of err. A successful (nil) hook would
			// otherwise clear err, and the deferred block would neither redirect
			// nor run the failure hook.
			if herr := b.afterUserLockedHook(r, user); herr != nil {
				err = herr
				setNoticeOrFailCodeFlash(w, herr, FailCodeSystemError)
				return
			}
		}
		code := FailCodeSystemError
		switch err {
		case ErrWrongPassword, ErrUserNotFound:
			code = FailCodeIncorrectAccountNameOrPassword
		case ErrUserLocked, ErrUserGetLocked:
			code = FailCodeUserLocked
		}
		setFailCodeFlash(w, code)
		// NOTE(review): the submitted password is written to a flash cookie so
		// the form can be refilled — confirm that is acceptable.
		setWrongLoginInputFlash(w, WrongLoginInputFlash{
			Account:  account,
			Password: password,
		})
		return
	}

	u := user.(UserPasser)
	userID := objectID(user)
	claims := UserClaims{
		UserID:           userID,
		PassUpdatedAt:    u.GetPasswordUpdatedAt(),
		RegisteredClaims: b.genBaseSessionClaim(userID),
	}
	// With TOTP enabled, afterLoginHook runs only after the code is validated
	// (totpDo).
	if !b.totpEnabled {
		if b.afterLoginHook != nil {
			setCookieForRequest(r, &http.Cookie{Name: b.authCookieName, Value: b.mustGetSessionToken(claims)})
			if err = b.afterLoginHook(r, user); err != nil {
				setNoticeOrFailCodeFlash(w, err, FailCodeSystemError)
				return
			}
		}
	}
	if err = b.setSecureCookiesByClaims(w, user, claims); err != nil {
		setFailCodeFlash(w, FailCodeSystemError)
		return
	}

	if b.totpEnabled {
		if u.GetIsTOTPSetup() {
			http.Redirect(w, r, b.totpValidatePageURL, http.StatusFound)
			return
		}
		var key *otp.Key
		if key, err = totp.Generate(
			totp.GenerateOpts{
				Issuer:      b.totpConfig.Issuer,
				AccountName: u.GetAccountName(),
			},
		); err != nil {
			setFailCodeFlash(w, FailCodeSystemError)
			return
		}
		if err = u.SetTOTPSecret(b.db, b.newUserObject(), key.Secret()); err != nil {
			setFailCodeFlash(w, FailCodeSystemError)
			return
		}
		http.Redirect(w, r, b.totpSetupPageURL, http.StatusFound)
		return
	}

	redirectURL := b.homePageURLFunc(r, user)
	if v := b.getContinueURL(w, r); v != "" {
		redirectURL = v
	}
	http.Redirect(w, r, redirectURL, http.StatusFound)
}
// genBaseSessionClaim builds the registered claims for a session owned by id,
// passing sessionMaxAge (seconds) as the lifetime to genBaseClaims.
func (b *Builder) genBaseSessionClaim(id string) jwt.RegisteredClaims {
	lifetime := b.sessionMaxAge
	return genBaseClaims(id, lifetime)
}

// mustGetSessionToken signs claims with the Builder secret, delegating to
// mustSignClaims (which, per its name, panics on failure).
func (b *Builder) mustGetSessionToken(claims UserClaims) string {
	key := b.secret
	return mustSignClaims(claims, key)
}
// setAuthCookiesFromUserClaims writes the session-token cookie built from
// claims.
//
// When secureSalt is non-empty, it also writes the secure cookie. That cookie
// holds the registered claims signed with secret+secureSalt, so rotating the
// user's salt invalidates it. Both cookies are HttpOnly and Secure, and last
// sessionMaxAge seconds. The returned error is always nil.
func (b *Builder) setAuthCookiesFromUserClaims(w http.ResponseWriter, claims *UserClaims, secureSalt string) error {
	newCookie := func(name, value string) *http.Cookie {
		return &http.Cookie{
			Name:     name,
			Value:    value,
			Path:     b.cookieConfig.Path,
			Domain:   b.cookieConfig.Domain,
			MaxAge:   b.sessionMaxAge,
			Expires:  time.Now().Add(time.Duration(b.sessionMaxAge) * time.Second),
			HttpOnly: true,
			Secure:   true,
			SameSite: b.cookieConfig.SameSite,
		}
	}

	http.SetCookie(w, newCookie(b.authCookieName, b.mustGetSessionToken(*claims)))
	if secureSalt == "" {
		return nil
	}
	http.SetCookie(w, newCookie(b.authSecureCookieName, mustSignClaims(&claims.RegisteredClaims, b.secret+secureSalt)))
	return nil
}
// cleanAuthCookies expires both the session-token cookie and the secure
// cookie by sending empty values with MaxAge -1 and an expiry in the past.
func (b *Builder) cleanAuthCookies(w http.ResponseWriter) {
	for _, name := range []string{b.authCookieName, b.authSecureCookieName} {
		http.SetCookie(w, &http.Cookie{
			Name:     name,
			Value:    "",
			Path:     b.cookieConfig.Path,
			Domain:   b.cookieConfig.Domain,
			MaxAge:   -1,
			Expires:  time.Unix(1, 0),
			HttpOnly: true,
			Secure:   true,
		})
	}
}
// setContinueURL remembers, in a cookie, where the user was heading so the
// login handlers can send them back there (see getContinueURL).
//
// For event requests ("?__execute_event__=") the page URL comes from the
// Referer header instead of the request URI. Nothing is remembered when:
//   - the URL cannot be parsed;
//   - it is one of the auth pages;
//   - it points at a different host. Referer (and a "//host/..." request URI)
//     is client-controlled, so storing a foreign URL would turn the post-login
//     redirect into an open redirect.
//
// Same-host absolute URLs are stored in path-and-query form.
func (b *Builder) setContinueURL(w http.ResponseWriter, r *http.Request) {
	continueURL := r.RequestURI
	if strings.Contains(continueURL, "?__execute_event__=") {
		continueURL = r.Referer()
	}

	u, err := url.Parse(continueURL)
	if err != nil {
		return
	}
	if u.Scheme != "" || u.Host != "" {
		if !strings.EqualFold(u.Host, r.Host) {
			return
		}
		continueURL = u.RequestURI()
	}

	ignoreURLs := map[string]struct{}{
		b.loginPageURL:                 {},
		b.resetPasswordPageURL:         {},
		b.forgetPasswordPageURL:        {},
		b.resetPasswordLinkSentPageURL: {},
		b.totpSetupPageURL:             {},
		b.totpValidatePageURL:          {},
		b.LogoutURL:                    {},
	}
	if _, ok := ignoreURLs[u.Path]; ok {
		return
	}

	http.SetCookie(w, &http.Cookie{
		Name:     b.continueUrlCookieName,
		Value:    continueURL,
		Path:     b.cookieConfig.Path,
		Domain:   b.cookieConfig.Domain,
		HttpOnly: true,
	})
}
// getContinueURL returns the URL saved by setContinueURL, or "" when none is
// stored. A found URL is consumed: the response expires the cookie, so the
// saved URL is used only once.
func (b *Builder) getContinueURL(w http.ResponseWriter, r *http.Request) string {
	c, err := r.Cookie(b.continueUrlCookieName)
	hasValue := err == nil && c.Value != ""
	if !hasValue {
		return ""
	}
	http.SetCookie(w, &http.Cookie{
		Name:     b.continueUrlCookieName,
		Value:    "",
		Path:     b.cookieConfig.Path,
		Domain:   b.cookieConfig.Domain,
		MaxAge:   -1,
		Expires:  time.Unix(1, 0),
		HttpOnly: true,
	})
	return c.Value
}
// setSecureCookiesByClaims writes the auth cookies for claims.
//
// When the user model is a SessionSecurer, the user's secure salt is also
// used for the secure cookie. A salt is generated and stored first if the
// user has none yet.
func (b *Builder) setSecureCookiesByClaims(w http.ResponseWriter, user interface{}, claims UserClaims) (err error) {
	secureSalt := ""
	if b.sessionSecureEnabled {
		securer := user.(SessionSecurer)
		if securer.GetSecure() == "" {
			if err = securer.UpdateSecure(b.db, b.newUserObject(), objectID(user)); err != nil {
				return err
			}
		}
		secureSalt = securer.GetSecure()
	}
	return b.setAuthCookiesFromUserClaims(w, &claims, secureSalt)
}
// consumeTOTPCode validates passcode against up's TOTP secret and records it
// as the last used code.
//
// Returned errors:
//   - ErrWrongTOTPCode: the code is invalid.
//   - ErrTOTPCodeHasBeenUsed: the code equals the last consumed one (within
//     90s). If afterTOTPCodeReusedHook is set, its non-nil error is returned
//     instead.
//   - The error from persisting the used code.
func (b *Builder) consumeTOTPCode(r *http.Request, up UserPasser, passcode string) error {
	if !totp.Validate(passcode, up.GetTOTPSecret()) {
		return ErrWrongTOTPCode
	}
	lastCode, usedAt := up.GetLastUsedTOTPCode()
	// Forget the remembered code after 90s. Presumably this matches
	// totp.Validate's default acceptance window (30s period, skew 1) — confirm.
	if usedAt != nil && time.Since(*usedAt) > 90*time.Second {
		lastCode = ""
	}
	if passcode == lastCode {
		if b.afterTOTPCodeReusedHook != nil {
			// Pass the user being verified. In the TOTP-login and password-reset
			// flows there is no current user on the request, so
			// GetCurrentUser(r) would be nil here.
			if herr := b.afterTOTPCodeReusedHook(r, up); herr != nil {
				return herr
			}
		}
		return ErrTOTPCodeHasBeenUsed
	}
	return up.SetLastUsedTOTPCode(b.db, b.newUserObject(), passcode)
}
// getFailCodeFromTOTPCodeConsumeError maps an error from consumeTOTPCode to
// the fail code shown to the user. Anything unrecognized is a system error.
func (b *Builder) getFailCodeFromTOTPCodeConsumeError(verr error) FailCode {
	switch verr {
	case ErrWrongTOTPCode:
		return FailCodeIncorrectTOTPCode
	case ErrTOTPCodeHasBeenUsed:
		return FailCodeTOTPCodeHasBeenUsed
	default:
		return FailCodeSystemError
	}
}
// logout is for url "/logout/{provider}".
//
// It ends any gothic OAuth session, expires the auth cookies, runs
// afterLogoutHook when there is a current user, and redirects to the login
// page. A hook error is surfaced as a flash on that page.
func (b *Builder) logout(w http.ResponseWriter, r *http.Request) {
	// Best effort: the local auth cookies are cleared whatever gothic reports,
	// so its error is deliberately ignored.
	_ = gothic.Logout(w, r)
	b.cleanAuthCookies(w)
	if b.afterLogoutHook != nil {
		if user := GetCurrentUser(r); user != nil {
			if herr := b.afterLogoutHook(r, user); herr != nil {
				setNoticeOrFailCodeFlash(w, herr, FailCodeSystemError)
			}
		}
	}
	http.Redirect(w, r, b.loginPageURL, http.StatusFound)
}
// beginAuth is for url "/auth/{provider}". It delegates entirely to
// gothic.BeginAuthHandler, which starts the OAuth flow for the provider it
// resolves from the request.
func (b *Builder) beginAuth(w http.ResponseWriter, r *http.Request) {
	start := gothic.BeginAuthHandler
	start(w, r)
}
// sendResetPasswordLink handles POST sendResetPasswordLinkURL. Non-POST
// requests get 404.
//
// Steps:
//  1. Check the reCAPTCHA token when enabled.
//  2. Look up the account.
//  3. Enforce a 60s wait between links.
//  4. Require a valid TOTP code for users with TOTP set up (first bouncing back
//     with totp=1 to ask for it).
//  5. Generate a reset token and hand the reset link to
//     afterConfirmSendResetPasswordLinkHook.
//  6. Redirect to the "link sent" page.
//
// Failures redirect back to the forget-password page with a flash.
func (b *Builder) sendResetPasswordLink(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		w.WriteHeader(http.StatusNotFound)
		return
	}
	failRedirectURL := b.forgetPasswordPageURL
	// check reCAPTCHA token
	if b.recaptchaEnabled {
		token := r.FormValue("token")
		if !recaptchaTokenCheck(b, token) {
			setFailCodeFlash(w, FailCodeIncorrectRecaptchaToken)
			http.Redirect(w, r, failRedirectURL, http.StatusFound)
			return
		}
	}

	account := strings.TrimSpace(r.FormValue("account"))
	passcode := r.FormValue("otp")
	doTOTP := r.URL.Query().Get("totp") == "1"
	if doTOTP {
		failRedirectURL = MustSetQuery(failRedirectURL, "totp", "1")
	}
	if account == "" {
		setFailCodeFlash(w, FailCodeAccountIsRequired)
		http.Redirect(w, r, failRedirectURL, http.StatusFound)
		return
	}

	u, err := b.userModel.(UserPasser).FindUser(b.db, b.newUserObject(), account)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			setFailCodeFlash(w, FailCodeUserNotFound)
		} else {
			setFailCodeFlash(w, FailCodeSystemError)
		}
		setWrongForgetPasswordInputFlash(w, WrongForgetPasswordInputFlash{
			Account: account,
		})
		http.Redirect(w, r, failRedirectURL, http.StatusFound)
		return
	}
	up := u.(UserPasser)

	// Rate limit: at most one reset link per user every 60 seconds.
	_, createdAt, _ := up.GetResetPasswordToken()
	if createdAt != nil {
		if wait := 60 - int(time.Since(*createdAt).Seconds()); wait > 0 {
			setSecondsToRedoFlash(w, wait)
			setWrongForgetPasswordInputFlash(w, WrongForgetPasswordInputFlash{
				Account: account,
			})
			http.Redirect(w, r, failRedirectURL, http.StatusFound)
			return
		}
	}

	if up.GetIsTOTPSetup() {
		if !doTOTP {
			// Bounce back to the form with totp=1 so it asks for a code.
			setWrongForgetPasswordInputFlash(w, WrongForgetPasswordInputFlash{
				Account: account,
			})
			failRedirectURL = MustSetQuery(failRedirectURL, "totp", "1")
			http.Redirect(w, r, failRedirectURL, http.StatusFound)
			return
		}
		if err = b.consumeTOTPCode(r, up, passcode); err != nil {
			fc := b.getFailCodeFromTOTPCodeConsumeError(err)
			setNoticeOrFailCodeFlash(w, err, fc)
			setWrongForgetPasswordInputFlash(w, WrongForgetPasswordInputFlash{
				Account: account,
				TOTP:    passcode,
			})
			http.Redirect(w, r, failRedirectURL, http.StatusFound)
			return
		}
	}

	token, err := up.GenerateResetPasswordToken(b.db, b.newUserObject())
	if err != nil {
		setFailCodeFlash(w, FailCodeSystemError)
		setWrongForgetPasswordInputFlash(w, WrongForgetPasswordInputFlash{
			Account: account,
		})
		http.Redirect(w, r, failRedirectURL, http.StatusFound)
		return
	}

	scheme := "https"
	if r.TLS == nil {
		scheme = "http"
	}
	// NOTE(review): the link's host is taken from the request's Host header.
	// Unless a proxy pins it, a forged Host could make the emailed link point
	// at another site. Consider a configured base URL.
	link := fmt.Sprintf("%s://%s%s?id=%s&token=%s", scheme, r.Host, b.resetPasswordPageURL,
		url.QueryEscape(objectID(u)), url.QueryEscape(token))
	if doTOTP {
		link = MustSetQuery(link, "totp", "1")
	}
	if b.afterConfirmSendResetPasswordLinkHook != nil {
		if herr := b.afterConfirmSendResetPasswordLinkHook(r, u, link); herr != nil {
			setNoticeOrFailCodeFlash(w, herr, FailCodeSystemError)
			http.Redirect(w, r, failRedirectURL, http.StatusFound)
			return
		}
	}
	// Escape the account: a raw '+', '&' or '#' (common in email addresses)
	// would be decoded wrongly on the "link sent" page.
	http.Redirect(w, r, b.resetPasswordLinkSentPageURL+"?a="+url.QueryEscape(account), http.StatusFound)
}
// doResetPassword handles POST resetPasswordURL, the form behind the emailed
// reset link. Non-POST requests get 404.
//
// It validates:
//   - the user id and token;
//   - the new password and its confirmation;
//   - beforeSetPasswordHook;
//   - a TOTP code, for users with TOTP set up.
//
// It then consumes the reset token, stores the password, runs
// afterResetPasswordHook and redirects to the login page. Every failure
// redirects back to the reset page with the same id/token (plus totp=1 when
// applicable) and a flash.
func (b *Builder) doResetPassword(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		w.WriteHeader(http.StatusNotFound)
		return
	}
	userID := r.FormValue("user_id")
	token := r.FormValue("token")
	passcode := r.FormValue("otp")
	doTOTP := r.URL.Query().Get("totp") == "1"
	// Escape the form values. A raw '+', '&' or '#' would otherwise corrupt the
	// query string the reset page reads id/token back from.
	failRedirectURL := fmt.Sprintf("%s?id=%s&token=%s", b.resetPasswordPageURL, url.QueryEscape(userID), url.QueryEscape(token))
	if doTOTP {
		failRedirectURL = MustSetQuery(failRedirectURL, "totp", "1")
	}
	if userID == "" {
		setFailCodeFlash(w, FailCodeUserNotFound)
		http.Redirect(w, r, failRedirectURL, http.StatusFound)
		return
	}
	if token == "" {
		setFailCodeFlash(w, FailCodeInvalidToken)
		http.Redirect(w, r, failRedirectURL, http.StatusFound)
		return
	}

	password := r.FormValue("password")
	confirmPassword := r.FormValue("confirm_password")
	if password == "" {
		setFailCodeFlash(w, FailCodePasswordCannotBeEmpty)
		http.Redirect(w, r, failRedirectURL, http.StatusFound)
		return
	}
	if confirmPassword != password {
		setFailCodeFlash(w, FailCodePasswordNotMatch)
		setWrongResetPasswordInputFlash(w, WrongResetPasswordInputFlash{
			Password:        password,
			ConfirmPassword: confirmPassword,
		})
		http.Redirect(w, r, failRedirectURL, http.StatusFound)
		return
	}

	u, err := b.findUserByID(userID)
	if err != nil {
		if errors.Is(err, ErrUserNotFound) {
			setFailCodeFlash(w, FailCodeUserNotFound)
		} else {
			setFailCodeFlash(w, FailCodeSystemError)
		}
		http.Redirect(w, r, failRedirectURL, http.StatusFound)
		return
	}
	up := u.(UserPasser)

	storedToken, _, expired := up.GetResetPasswordToken()
	if expired {
		setFailCodeFlash(w, FailCodeTokenExpired)
		http.Redirect(w, r, failRedirectURL, http.StatusFound)
		return
	}
	// NOTE(review): != is not constant-time; consider
	// crypto/subtle.ConstantTimeCompare for the reset token.
	if token != storedToken {
		setFailCodeFlash(w, FailCodeInvalidToken)
		http.Redirect(w, r, failRedirectURL, http.StatusFound)
		return
	}

	if b.beforeSetPasswordHook != nil {
		if herr := b.beforeSetPasswordHook(r, u, password); herr != nil {
			setNoticeOrFailCodeFlash(w, herr, FailCodeSystemError)
			setWrongResetPasswordInputFlash(w, WrongResetPasswordInputFlash{
				Password:        password,
				ConfirmPassword: confirmPassword,
			})
			http.Redirect(w, r, failRedirectURL, http.StatusFound)
			return
		}
	}

	if up.GetIsTOTPSetup() {
		if !doTOTP {
			// Bounce back with totp=1 so the form asks for a code.
			setWrongResetPasswordInputFlash(w, WrongResetPasswordInputFlash{
				Password:        password,
				ConfirmPassword: confirmPassword,
			})
			failRedirectURL = MustSetQuery(failRedirectURL, "totp", "1")
			http.Redirect(w, r, failRedirectURL, http.StatusFound)
			return
		}
		if err = b.consumeTOTPCode(r, up, passcode); err != nil {
			fc := b.getFailCodeFromTOTPCodeConsumeError(err)
			// consumeTOTPCode may return the reuse hook's own error. Surface it
			// the same way sendResetPasswordLink does.
			setNoticeOrFailCodeFlash(w, err, fc)
			setWrongResetPasswordInputFlash(w, WrongResetPasswordInputFlash{
				Password:        password,
				ConfirmPassword: confirmPassword,
				TOTP:            passcode,
			})
			http.Redirect(w, r, failRedirectURL, http.StatusFound)
			return
		}
	}

	if err = up.ConsumeResetPasswordToken(b.db, b.newUserObject()); err != nil {
		setFailCodeFlash(w, FailCodeSystemError)
		http.Redirect(w, r, failRedirectURL, http.StatusFound)
		return
	}
	if err = up.SetPassword(b.db, b.newUserObject(), password); err != nil {
		setFailCodeFlash(w, FailCodeSystemError)
		http.Redirect(w, r, failRedirectURL, http.StatusFound)
		return
	}
	if b.afterResetPasswordHook != nil {
		if herr := b.afterResetPasswordHook(r, u); herr != nil {
			setNoticeOrFailCodeFlash(w, herr, FailCodeSystemError)
			http.Redirect(w, r, failRedirectURL, http.StatusFound)
			return
		}
	}
	setInfoCodeFlash(w, InfoCodePasswordSuccessfullyReset)
	http.Redirect(w, r, b.loginPageURL, http.StatusFound)
}
// ChangePassword changes the current user's password after verifying the old
// password, the confirmation, beforeSetPasswordHook and, when TOTP is
// enabled, the TOTP code. afterChangePasswordHook runs after the password is
// stored.
//
// Possible errors:
//   - NoticeError (from hooks)
//   - ErrUserNotFound: the request has no current user
//   - ErrWrongPassword
//   - ErrEmptyPassword
//   - ErrPasswordNotMatch
//   - ErrWrongTOTPCode
//   - ErrTOTPCodeHasBeenUsed
//   - database errors from storing the password
func (b *Builder) ChangePassword(
	r *http.Request,
	oldPassword string,
	password string,
	confirmPassword string,
	otp string,
) error {
	// Guard the assertion: an unauthenticated request would otherwise panic.
	user, ok := GetCurrentUser(r).(UserPasser)
	if !ok {
		return ErrUserNotFound
	}
	if !user.IsPasswordCorrect(oldPassword) {
		return ErrWrongPassword
	}
	if password == "" {
		return ErrEmptyPassword
	}
	if confirmPassword != password {
		return ErrPasswordNotMatch
	}
	if b.beforeSetPasswordHook != nil {
		if herr := b.beforeSetPasswordHook(r, user, password); herr != nil {
			return herr
		}
	}
	if b.totpEnabled {
		if err := b.consumeTOTPCode(r, user, otp); err != nil {
			return err
		}
	}
	if err := user.SetPassword(b.db, b.newUserObject(), password); err != nil {
		return err
	}
	if b.afterChangePasswordHook != nil {
		if herr := b.afterChangePasswordHook(r, user); herr != nil {
			return herr
		}
	}
	return nil
}
// doFormChangePassword handles POST changePasswordURL by calling
// ChangePassword with the submitted form values. Non-POST requests get 404.
//
// On success it redirects to the login page with a success flash. On failure
// it redirects back to the change-password page with a notice (for
// NoticeError) or a fail code, plus the submitted inputs.
func (b *Builder) doFormChangePassword(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		w.WriteHeader(http.StatusNotFound)
		return
	}
	oldPassword := r.FormValue("old_password")
	password := r.FormValue("password")
	confirmPassword := r.FormValue("confirm_password")
	otp := r.FormValue("otp")
	redirectURL := b.changePasswordPageURL

	err := b.ChangePassword(r, oldPassword, password, confirmPassword, otp)
	if err != nil {
		// errors.As/Is also recognize errors that hooks have wrapped.
		var ne *NoticeError
		if errors.As(err, &ne) {
			setNoticeFlash(w, ne)
		} else {
			fc := FailCodeSystemError
			switch {
			case errors.Is(err, ErrUserNotFound):
				fc = FailCodeUserNotFound
			case errors.Is(err, ErrWrongPassword):
				fc = FailCodeIncorrectPassword
			case errors.Is(err, ErrEmptyPassword):
				fc = FailCodePasswordCannotBeEmpty
			case errors.Is(err, ErrPasswordNotMatch):
				fc = FailCodePasswordNotMatch
			case errors.Is(err, ErrWrongTOTPCode):
				fc = FailCodeIncorrectTOTPCode
			case errors.Is(err, ErrTOTPCodeHasBeenUsed):
				fc = FailCodeTOTPCodeHasBeenUsed
			}
			setFailCodeFlash(w, fc)
		}
		setWrongChangePasswordInputFlash(w, WrongChangePasswordInputFlash{
			OldPassword:     oldPassword,
			NewPassword:     password,
			ConfirmPassword: confirmPassword,
			TOTP:            otp,
		})
		http.Redirect(w, r, redirectURL, http.StatusFound)
		return
	}
	setInfoCodeFlash(w, InfoCodePasswordSuccessfullyChanged)
	http.Redirect(w, r, b.loginPageURL, http.StatusFound)
}
// totpDo handles POST validateTOTPURL, the second login step after
// userpassLogin when TOTP is enabled. Non-POST requests get 404.
//
// It reads the (not yet TOTP-validated) session claims from the auth cookie,
// loads the user and consumes the submitted code. For users without TOTP set
// up, it also marks TOTP as set up. It then re-issues the session with
// TOTPValidated=true, runs afterLoginHook and redirects to the continue URL
// or the home page.
//
// Any failure leaves err non-nil. The deferred block then runs
// afterFailedToLoginHook and redirects to failRedirectURL.
func (b *Builder) totpDo(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
w.WriteHeader(http.StatusNotFound)
return
}
var err error
var user interface{}
// Default failure target; a wrong code below switches it to the TOTP page.
failRedirectURL := b.LogoutURL
defer func() {
if err != nil {
if b.afterFailedToLoginHook != nil {
if herr := b.afterFailedToLoginHook(r, user, err); herr != nil {
setNoticeOrFailCodeFlash(w, herr, FailCodeSystemError)
}
}
http.Redirect(w, r, failRedirectURL, http.StatusFound)
}
}()
var claims *UserClaims
// A missing or invalid session cookie fails without a specific flash.
claims, err = parseUserClaimsFromCookie(r, b.authCookieName, b.secret)
if err != nil {
return
}
user, err = b.findUserByID(claims.UserID)
if err != nil {
if err == ErrUserNotFound {
setFailCodeFlash(w, FailCodeUserNotFound)
} else {
setFailCodeFlash(w, FailCodeSystemError)
}
return
}
u := user.(UserPasser)
otp := r.FormValue("otp")
// Read before consuming the code: it decides which page a failure returns to
// and whether setup must be recorded afterwards.
isTOTPSetup := u.GetIsTOTPSetup()
if err = b.consumeTOTPCode(r, u, otp); err != nil {
fc := b.getFailCodeFromTOTPCodeConsumeError(err)
setFailCodeFlash(w, fc)
failRedirectURL = b.totpValidatePageURL
if !isTOTPSetup {
failRedirectURL = b.totpSetupPageURL
}
return
}
// First successful code after enrolment completes TOTP setup.
if !isTOTPSetup {
if err = u.SetIsTOTPSetup(b.db, b.newUserObject(), true); err != nil {
setFailCodeFlash(w, FailCodeSystemError)
return
}
}
claims.TOTPValidated = true
if b.afterLoginHook != nil {
// Expose the upgraded session token to the hook through the request's cookie.
setCookieForRequest(r, &http.Cookie{Name: b.authCookieName, Value: b.mustGetSessionToken(*claims)})
if err = b.afterLoginHook(r, user); err != nil {
setNoticeOrFailCodeFlash(w, err, FailCodeSystemError)
return
}
}
err = b.setSecureCookiesByClaims(w, user, *claims)
if err != nil {
setFailCodeFlash(w, FailCodeSystemError)
return
}
redirectURL := b.homePageURLFunc(r, user)
if v := b.getContinueURL(w, r); v != "" {
redirectURL = v
}
http.Redirect(w, r, redirectURL, http.StatusFound)
}
// Mount registers everything on mux:
//   - the API handlers (via MountAPI);
//   - the HTML pages, each wrapped in the i18n builder's EnsureLanguage
//     handler;
//   - the embedded static assets under assetsPathPrefix.
//
// Password-related pages are mounted only when the user model supports
// passwords; the forget-password and TOTP pages are further gated by their
// own flags.
func (b *Builder) Mount(mux *http.ServeMux) {
	b.MountAPI(mux)

	wb := web.New()
	page := func(pattern string, f web.PageFunc) {
		mux.Handle(pattern, b.i18nBuilder.EnsureLanguage(wb.Page(f)))
	}

	page(b.loginPageURL, b.loginPageFunc)
	if b.userPassEnabled {
		page(b.resetPasswordPageURL, b.resetPasswordPageFunc)
		page(b.changePasswordPageURL, b.changePasswordPageFunc)
		if !b.noForgetPasswordLink {
			page(b.forgetPasswordPageURL, b.forgetPasswordPageFunc)
			page(b.resetPasswordLinkSentPageURL, b.resetPasswordLinkSentPageFunc)
		}
		if b.totpEnabled {
			page(b.totpSetupPageURL, b.totpSetupPageFunc)
			page(b.totpValidatePageURL, b.totpValidatePageFunc)
		}
	}

	staticFS, err := fs.Sub(assetsFS, "assets")
	if err != nil {
		panic(err)
	}
	mux.Handle(assetsPathPrefix, http.StripPrefix(assetsPathPrefix, http.FileServer(http.FS(staticFS))))
}
// MountAPI registers only the non-page endpoints on mux. It panics when no
// secret is configured, or when a user model is set without a DB.
//
// Routes mounted:
//   - logout: always.
//   - Password endpoints: when the user model implements UserPasser, with the
//     send-reset-link and TOTP endpoints further gated by their flags.
//   - OAuth endpoints: when OAuth is enabled.
func (b *Builder) MountAPI(mux *http.ServeMux) {
	switch {
	case b.secret == "":
		panic("secret is empty")
	case b.userModel != nil && b.db == nil:
		panic("db is required")
	}

	mux.HandleFunc(b.LogoutURL, b.logout)

	if b.userPassEnabled {
		mux.HandleFunc(b.passwordLoginURL, b.userpassLogin)
		mux.HandleFunc(b.resetPasswordURL, b.doResetPassword)
		mux.HandleFunc(b.changePasswordURL, b.doFormChangePassword)
		if !b.noForgetPasswordLink {
			mux.HandleFunc(b.sendResetPasswordLinkURL, b.sendResetPasswordLink)
		}
		if b.totpEnabled {
			mux.HandleFunc(b.validateTOTPURL, b.totpDo)
		}
	}

	if !b.oauthEnabled {
		return
	}
	mux.HandleFunc(b.oauthBeginURL, b.beginAuth)
	mux.HandleFunc(b.oauthCallbackURL, b.completeUserAuthCallback)
	mux.HandleFunc(b.oauthCallbackCompleteURL, b.completeUserAuthCallbackComplete)
}