package login import ( "context" "errors" "fmt" "io/fs" "log" "net/http" "reflect" "strings" "time" "github.com/golang-jwt/jwt/v4" "github.com/markbates/goth" "github.com/markbates/goth/gothic" "github.com/pquerna/otp" "github.com/pquerna/otp/totp" "github.com/qor5/web" "github.com/qor5/x/i18n" h "github.com/theplant/htmlgo" "golang.org/x/text/language" "gorm.io/gorm" ) var ( ErrUserNotFound = errors.New("user not found") ErrUserPassChanged = errors.New("password changed") ErrWrongPassword = errors.New("wrong password") ErrUserLocked = errors.New("user locked") ErrUserGetLocked = errors.New("user get locked") ErrWrongTOTPCode = errors.New("wrong totp code") ErrTOTPCodeHasBeenUsed = errors.New("totp code has been used") ErrEmptyPassword = errors.New("empty password") ErrPasswordNotMatch = errors.New("password not match") ) type HomeURLFunc func(r *http.Request, user interface{}) string type NotifyUserOfResetPasswordLinkFunc func(user interface{}, resetLink string) error type PasswordValidationFunc func(password string) (message string, ok bool) type HookFunc func(r *http.Request, user interface{}, vals ...interface{}) error type void struct{} type Provider struct { Goth goth.Provider Key string Text string Logo h.HTMLComponent } type CookieConfig struct { Path string Domain string SameSite http.SameSite } type RecaptchaConfig struct { SiteKey string SecretKey string } type Builder struct { secret string providers []*Provider authCookieName string authSecureCookieName string continueUrlCookieName string // seconds sessionMaxAge int cookieConfig CookieConfig recaptchaEnabled bool recaptchaConfig RecaptchaConfig autoExtendSession bool maxRetryCount int noForgetPasswordLink bool // Common URLs homePageURLFunc HomeURLFunc loginPageURL string LogoutURL string allowURLs map[string]void // TOTP URLs validateTOTPURL string totpSetupPageURL string totpValidatePageURL string // OAuth URLs oauthBeginURL string oauthCallbackURL string oauthCallbackCompleteURL string // UserPass URLs passwordLoginURL string resetPasswordURL string resetPasswordPageURL string changePasswordURL string changePasswordPageURL string forgetPasswordPageURL string sendResetPasswordLinkURL string resetPasswordLinkSentPageURL string loginPageFunc web.PageFunc forgetPasswordPageFunc web.PageFunc resetPasswordLinkSentPageFunc web.PageFunc resetPasswordPageFunc web.PageFunc changePasswordPageFunc web.PageFunc totpSetupPageFunc web.PageFunc totpValidatePageFunc web.PageFunc notifyUserOfResetPasswordLinkFunc NotifyUserOfResetPasswordLinkFunc passwordValidationFunc PasswordValidationFunc afterLoginHook HookFunc afterFailedToLoginHook HookFunc afterUserLockedHook HookFunc afterLogoutHook HookFunc afterSendResetPasswordLinkHook HookFunc afterResetPasswordHook HookFunc afterChangePasswordHook HookFunc afterExtendSessionHook HookFunc afterTOTPCodeReusedHook HookFunc afterOAuthCompleteHook HookFunc db *gorm.DB userModel interface{} snakePrimaryField string tUser reflect.Type userPassEnabled bool oauthEnabled bool sessionSecureEnabled bool totpEnabled bool totpIssuer string i18nBuilder *i18n.Builder } func New() *Builder { r := &Builder{ authCookieName: "auth", authSecureCookieName: "qor5_auth_secure", continueUrlCookieName: "qor5_continue_url", homePageURLFunc: func(r *http.Request, user interface{}) string { return "/" }, loginPageURL: "/auth/login", LogoutURL: "/auth/logout", validateTOTPURL: "/auth/2fa/totp/do", totpSetupPageURL: "/auth/2fa/totp/setup", totpValidatePageURL: "/auth/2fa/totp/validate", oauthBeginURL: "/auth/begin", oauthCallbackURL: "/auth/callback", oauthCallbackCompleteURL: "/auth/callback-complete", passwordLoginURL: "/auth/userpass/login", resetPasswordURL: "/auth/do-reset-password", resetPasswordPageURL: "/auth/reset-password", changePasswordURL: "/auth/do-change-password", changePasswordPageURL: "/auth/change-password", forgetPasswordPageURL: "/auth/forget-password", sendResetPasswordLinkURL: "/auth/send-reset-password-link", resetPasswordLinkSentPageURL: "/auth/reset-password-link-sent", sessionMaxAge: 60 * 60, cookieConfig: CookieConfig{ Path: "/", Domain: "", SameSite: http.SameSiteStrictMode, }, autoExtendSession: true, maxRetryCount: 5, totpEnabled: true, totpIssuer: "qor5", i18nBuilder: i18n.New(), } r.registerI18n() r.initAllowURLs() vh := r.ViewHelper() r.loginPageFunc = defaultLoginPage(vh) r.forgetPasswordPageFunc = defaultForgetPasswordPage(vh) r.resetPasswordLinkSentPageFunc = defaultResetPasswordLinkSentPage(vh) r.resetPasswordPageFunc = defaultResetPasswordPage(vh) r.changePasswordPageFunc = defaultChangePasswordPage(vh) r.totpSetupPageFunc = defaultTOTPSetupPage(vh) r.totpValidatePageFunc = defaultTOTPValidatePage(vh) return r } func (b *Builder) initAllowURLs() { b.allowURLs = map[string]void{ b.oauthBeginURL: {}, b.oauthCallbackURL: {}, b.oauthCallbackCompleteURL: {}, b.passwordLoginURL: {}, b.forgetPasswordPageURL: {}, b.sendResetPasswordLinkURL: {}, b.resetPasswordLinkSentPageURL: {}, b.resetPasswordURL: {}, b.resetPasswordPageURL: {}, b.validateTOTPURL: {}, } } func (b *Builder) AllowURL(v string) { b.allowURLs[v] = void{} } func (b *Builder) Secret(v string) (r *Builder) { b.secret = v return b } func (b *Builder) CookieConfig(v CookieConfig) (r *Builder) { b.cookieConfig = v return b } // RecaptchaConfig should be set if you want to enable Google reCAPTCHA. func (b *Builder) RecaptchaConfig(v RecaptchaConfig) (r *Builder) { b.recaptchaConfig = v b.recaptchaEnabled = b.recaptchaConfig.SiteKey != "" && b.recaptchaConfig.SecretKey != "" return b } func (b *Builder) OAuthProviders(vs ...*Provider) (r *Builder) { if len(vs) == 0 { return b } b.oauthEnabled = true b.providers = vs var gothProviders []goth.Provider for _, v := range vs { gothProviders = append(gothProviders, v.Goth) } goth.UseProviders(gothProviders...) return b } func (b *Builder) AuthCookieName(v string) (r *Builder) { b.authCookieName = v return b } func (b *Builder) LoginURL(v string) (r *Builder) { b.loginPageURL = v return b } func (b *Builder) HomeURLFunc(v HomeURLFunc) (r *Builder) { b.homePageURLFunc = v return b } func (b *Builder) LoginPageFunc(v web.PageFunc) (r *Builder) { b.loginPageFunc = v return b } func (b *Builder) ForgetPasswordPageFunc(v web.PageFunc) (r *Builder) { b.forgetPasswordPageFunc = v return b } func (b *Builder) ResetPasswordLinkSentPageFunc(v web.PageFunc) (r *Builder) { b.resetPasswordLinkSentPageFunc = v return b } func (b *Builder) ResetPasswordPageFunc(v web.PageFunc) (r *Builder) { b.resetPasswordPageFunc = v return b } func (b *Builder) ChangePasswordPageFunc(v web.PageFunc) (r *Builder) { b.changePasswordPageFunc = v return b } func (b *Builder) TOTPSetupPageFunc(v web.PageFunc) (r *Builder) { b.totpSetupPageFunc = v return b } func (b *Builder) TOTPValidatePageFunc(v web.PageFunc) (r *Builder) { b.totpValidatePageFunc = v return b } func (b *Builder) NotifyUserOfResetPasswordLinkFunc(v NotifyUserOfResetPasswordLinkFunc) (r *Builder) { b.notifyUserOfResetPasswordLinkFunc = v return b } func (b *Builder) PasswordValidationFunc(v PasswordValidationFunc) (r *Builder) { b.passwordValidationFunc = v return b } func (b *Builder) wrapHook(v HookFunc) HookFunc { if v == nil { return nil } return func(r *http.Request, user interface{}, vals ...interface{}) error { if GetCurrentUser(r) == nil { r = r.WithContext(context.WithValue(r.Context(), UserKey, user)) } return v(r, user, vals...) } } func (b *Builder) AfterLogin(v HookFunc) (r *Builder) { b.afterLoginHook = b.wrapHook(v) return b } func (b *Builder) AfterFailedToLogin(v HookFunc) (r *Builder) { b.afterFailedToLoginHook = b.wrapHook(v) return b } func (b *Builder) AfterUserLocked(v HookFunc) (r *Builder) { b.afterUserLockedHook = b.wrapHook(v) return b } func (b *Builder) AfterLogout(v HookFunc) (r *Builder) { b.afterLogoutHook = b.wrapHook(v) return b } func (b *Builder) AfterSendResetPasswordLink(v HookFunc) (r *Builder) { b.afterSendResetPasswordLinkHook = b.wrapHook(v) return b } func (b *Builder) AfterResetPassword(v HookFunc) (r *Builder) { b.afterResetPasswordHook = b.wrapHook(v) return b } func (b *Builder) AfterChangePassword(v HookFunc) (r *Builder) { b.afterChangePasswordHook = b.wrapHook(v) return b } // vals: // - old session token func (b *Builder) AfterExtendSession(v HookFunc) (r *Builder) { b.afterExtendSessionHook = b.wrapHook(v) return b } func (b *Builder) AfterTOTPCodeReused(v HookFunc) (r *Builder) { b.afterTOTPCodeReusedHook = b.wrapHook(v) return b } // user is goth.User func (b *Builder) AfterOAuthComplete(v HookFunc) (r *Builder) { b.afterOAuthCompleteHook = b.wrapHook(v) return b } // seconds // default 1h func (b *Builder) SessionMaxAge(v int) (r *Builder) { b.sessionMaxAge = v return b } // extend the session if successfully authenticated // default true func (b *Builder) AutoExtendSession(v bool) (r *Builder) { b.autoExtendSession = v return b } // default 5 // MaxRetryCount <= 0 means no max retry count limit func (b *Builder) MaxRetryCount(v int) (r *Builder) { b.maxRetryCount = v return b } func (b *Builder) TOTPEnabled(v bool) (r *Builder) { b.totpEnabled = v return b } func (b *Builder) TOTPIssuer(v string) (r *Builder) { b.totpIssuer = v return b } func (b *Builder) NoForgetPasswordLink(v bool) (r *Builder) { b.noForgetPasswordLink = v return b } func (b *Builder) DB(v *gorm.DB) (r *Builder) { b.db = v return b } func (b *Builder) I18n(v *i18n.Builder) (r *Builder) { b.i18nBuilder = v b.registerI18n() return b } func (b *Builder) GetSessionMaxAge() int { return b.sessionMaxAge } func (b *Builder) ViewHelper() *ViewHelper { return &ViewHelper{ b: b, } } func (b *Builder) registerI18n() { b.i18nBuilder.RegisterForModule(language.English, I18nLoginKey, Messages_en_US). RegisterForModule(language.SimplifiedChinese, I18nLoginKey, Messages_zh_CN). RegisterForModule(language.Japanese, I18nLoginKey, Messages_ja_JP) } func (b *Builder) UserModel(m interface{}) (r *Builder) { b.userModel = m b.tUser = underlyingReflectType(reflect.TypeOf(m)) b.snakePrimaryField = snakePrimaryField(m) if _, ok := m.(UserPasser); ok { b.userPassEnabled = true } if _, ok := m.(OAuthUser); ok { b.oauthEnabled = true } if _, ok := m.(SessionSecurer); ok { b.sessionSecureEnabled = true } return b } func (b *Builder) newUserObject() interface{} { return reflect.New(b.tUser).Interface() } func (b *Builder) findUserByID(id string) (user interface{}, err error) { m := b.newUserObject() err = b.db.Where(fmt.Sprintf("%s = ?", b.snakePrimaryField), id). First(m). Error if err != nil { if err == gorm.ErrRecordNotFound { return nil, ErrUserNotFound } return nil, err } return m, nil } // completeUserAuthCallback is for url "/auth/{provider}/callback" func (b *Builder) completeUserAuthCallback(w http.ResponseWriter, r *http.Request) { if b.cookieConfig.SameSite != http.SameSiteStrictMode { b.completeUserAuthCallbackComplete(w, r) return } completeURL := fmt.Sprintf("%s?%s", b.oauthCallbackCompleteURL, r.URL.Query().Encode()) w.Header().Set("Content-Type", "text/html; charset=utf-8") w.Write([]byte(fmt.Sprintf(` complete `, completeURL, completeURL))) return } func (b *Builder) completeUserAuthCallbackComplete(w http.ResponseWriter, r *http.Request) { var err error var user interface{} defer func() { if b.afterFailedToLoginHook != nil && err != nil && user != nil { b.afterFailedToLoginHook(r, user) } }() var ouser goth.User ouser, err = gothic.CompleteUserAuth(w, r) if err != nil { log.Println("completeUserAuthWithSetCookie", err) setFailCodeFlash(w, FailCodeCompleteUserAuthFailed) http.Redirect(w, r, b.LogoutURL, http.StatusFound) return } if b.afterOAuthCompleteHook != nil { if herr := b.afterOAuthCompleteHook(r, ouser); herr != nil { setFailCodeFlash(w, FailCodeSystemError) http.Redirect(w, r, b.LogoutURL, http.StatusFound) return } } userID := ouser.UserID if b.userModel != nil { user, err = b.userModel.(OAuthUser).FindUserByOAuthUserID(b.db, b.newUserObject(), ouser.Provider, ouser.UserID) if err != nil { if err != gorm.ErrRecordNotFound { setFailCodeFlash(w, FailCodeSystemError) http.Redirect(w, r, b.LogoutURL, http.StatusFound) return } // TODO: maybe the indentifier of some providers is not email indentifier := ouser.Email user, err = b.userModel.(OAuthUser).FindUserByOAuthIndentifier(b.db, b.newUserObject(), ouser.Provider, indentifier) if err != nil { if err == gorm.ErrRecordNotFound { setFailCodeFlash(w, FailCodeUserNotFound) } else { setFailCodeFlash(w, FailCodeSystemError) } http.Redirect(w, r, b.LogoutURL, http.StatusFound) return } err = user.(OAuthUser).InitOAuthUserID(b.db, b.newUserObject(), ouser.Provider, indentifier, ouser.UserID) if err != nil { setFailCodeFlash(w, FailCodeSystemError) http.Redirect(w, r, b.LogoutURL, http.StatusFound) return } } userID = objectID(user) } claims := UserClaims{ Provider: ouser.Provider, Email: ouser.Email, Name: ouser.Name, UserID: userID, AvatarURL: ouser.AvatarURL, RegisteredClaims: b.genBaseSessionClaim(userID), } if user == nil { user = &claims } if b.afterLoginHook != nil { setCookieForRequest(r, &http.Cookie{Name: b.authCookieName, Value: b.mustGetSessionToken(claims)}) if herr := b.afterLoginHook(r, user); herr != nil { setFailCodeFlash(w, FailCodeSystemError) http.Redirect(w, r, b.loginPageURL, http.StatusFound) return } } if err := b.setSecureCookiesByClaims(w, user, claims); err != nil { setFailCodeFlash(w, FailCodeSystemError) http.Redirect(w, r, b.LogoutURL, http.StatusFound) return } redirectURL := b.homePageURLFunc(r, user) if v := b.getContinueURL(w, r); v != "" { redirectURL = v } http.Redirect(w, r, redirectURL, http.StatusFound) return } // return user if account exists even if there is an error returned func (b *Builder) authUserPass(account string, password string) (user interface{}, err error) { user, err = b.userModel.(UserPasser).FindUser(b.db, b.newUserObject(), account) if err != nil { if err == gorm.ErrRecordNotFound { return nil, ErrUserNotFound } return nil, err } u := user.(UserPasser) if u.GetLocked() { return user, ErrUserLocked } if !u.IsPasswordCorrect(password) { if b.maxRetryCount > 0 { if err = u.IncreaseRetryCount(b.db, b.newUserObject()); err != nil { return user, err } if u.GetLoginRetryCount() >= b.maxRetryCount { if err = u.LockUser(b.db, b.newUserObject()); err != nil { return user, err } return user, ErrUserGetLocked } } return user, ErrWrongPassword } if u.GetLoginRetryCount() != 0 { if err = u.UnlockUser(b.db, b.newUserObject()); err != nil { return user, err } } return user, nil } func (b *Builder) userpassLogin(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { w.WriteHeader(http.StatusNotFound) return } // check reCAPTCHA token if b.recaptchaEnabled { token := r.FormValue("token") if !recaptchaTokenCheck(b, token) { setFailCodeFlash(w, FailCodeIncorrectRecaptchaToken) http.Redirect(w, r, b.loginPageURL, http.StatusFound) return } } var err error var user interface{} defer func() { if b.afterFailedToLoginHook != nil && err != nil && user != nil { b.afterFailedToLoginHook(r, user) } }() account := r.FormValue("account") password := r.FormValue("password") user, err = b.authUserPass(account, password) if err != nil { if err == ErrUserGetLocked && b.afterUserLockedHook != nil { if herr := b.afterUserLockedHook(r, user); herr != nil { setFailCodeFlash(w, FailCodeSystemError) http.Redirect(w, r, b.loginPageURL, http.StatusFound) return } } code := FailCodeSystemError switch err { case ErrWrongPassword, ErrUserNotFound: code = FailCodeIncorrectAccountNameOrPassword case ErrUserLocked, ErrUserGetLocked: code = FailCodeUserLocked } setFailCodeFlash(w, code) setWrongLoginInputFlash(w, WrongLoginInputFlash{ Account: account, Password: password, }) http.Redirect(w, r, b.LogoutURL, http.StatusFound) return } u := user.(UserPasser) userID := objectID(user) claims := UserClaims{ UserID: userID, PassUpdatedAt: u.GetPasswordUpdatedAt(), RegisteredClaims: b.genBaseSessionClaim(userID), } if !b.totpEnabled { if b.afterLoginHook != nil { setCookieForRequest(r, &http.Cookie{Name: b.authCookieName, Value: b.mustGetSessionToken(claims)}) if herr := b.afterLoginHook(r, user); herr != nil { setFailCodeFlash(w, FailCodeSystemError) http.Redirect(w, r, b.loginPageURL, http.StatusFound) return } } } if err = b.setSecureCookiesByClaims(w, user, claims); err != nil { setFailCodeFlash(w, FailCodeSystemError) http.Redirect(w, r, b.LogoutURL, http.StatusFound) return } if b.totpEnabled { if u.GetIsTOTPSetup() { http.Redirect(w, r, b.totpValidatePageURL, http.StatusFound) return } var key *otp.Key if key, err = totp.Generate( totp.GenerateOpts{ Issuer: b.totpIssuer, AccountName: u.GetAccountName(), }, ); err != nil { setFailCodeFlash(w, FailCodeSystemError) http.Redirect(w, r, b.LogoutURL, http.StatusFound) return } if err = u.SetTOTPSecret(b.db, b.newUserObject(), key.Secret()); err != nil { setFailCodeFlash(w, FailCodeSystemError) http.Redirect(w, r, b.LogoutURL, http.StatusFound) return } http.Redirect(w, r, b.totpSetupPageURL, http.StatusFound) return } redirectURL := b.homePageURLFunc(r, user) if v := b.getContinueURL(w, r); v != "" { redirectURL = v } http.Redirect(w, r, redirectURL, http.StatusFound) return } func (b *Builder) genBaseSessionClaim(id string) jwt.RegisteredClaims { return genBaseClaims(id, b.sessionMaxAge) } func (b *Builder) mustGetSessionToken(claims UserClaims) string { return mustSignClaims(claims, b.secret) } func (b *Builder) setAuthCookiesFromUserClaims(w http.ResponseWriter, claims *UserClaims, secureSalt string) error { http.SetCookie(w, &http.Cookie{ Name: b.authCookieName, Value: b.mustGetSessionToken(*claims), Path: b.cookieConfig.Path, Domain: b.cookieConfig.Domain, MaxAge: b.sessionMaxAge, Expires: time.Now().Add(time.Duration(b.sessionMaxAge) * time.Second), HttpOnly: true, Secure: true, SameSite: b.cookieConfig.SameSite, }) if secureSalt != "" { http.SetCookie(w, &http.Cookie{ Name: b.authSecureCookieName, Value: mustSignClaims(&claims.RegisteredClaims, b.secret+secureSalt), Path: b.cookieConfig.Path, Domain: b.cookieConfig.Domain, MaxAge: b.sessionMaxAge, Expires: time.Now().Add(time.Duration(b.sessionMaxAge) * time.Second), HttpOnly: true, Secure: true, SameSite: b.cookieConfig.SameSite, }) } return nil } func (b *Builder) cleanAuthCookies(w http.ResponseWriter) { http.SetCookie(w, &http.Cookie{ Name: b.authCookieName, Value: "", Path: b.cookieConfig.Path, Domain: b.cookieConfig.Domain, MaxAge: -1, Expires: time.Unix(1, 0), HttpOnly: true, Secure: true, }) http.SetCookie(w, &http.Cookie{ Name: b.authSecureCookieName, Value: "", Path: b.cookieConfig.Path, Domain: b.cookieConfig.Domain, MaxAge: -1, Expires: time.Unix(1, 0), HttpOnly: true, Secure: true, }) } func (b *Builder) setContinueURL(w http.ResponseWriter, r *http.Request) { continueURL := r.RequestURI if strings.Contains(continueURL, "?__execute_event__=") { continueURL = r.Referer() } if !strings.HasPrefix(continueURL, "/auth/") { http.SetCookie(w, &http.Cookie{ Name: b.continueUrlCookieName, Value: continueURL, Path: b.cookieConfig.Path, Domain: b.cookieConfig.Domain, HttpOnly: true, }) } } func (b *Builder) getContinueURL(w http.ResponseWriter, r *http.Request) string { c, err := r.Cookie(b.continueUrlCookieName) if err != nil || c.Value == "" { return "" } http.SetCookie(w, &http.Cookie{ Name: b.continueUrlCookieName, Value: "", MaxAge: -1, Expires: time.Unix(1, 0), Path: b.cookieConfig.Path, Domain: b.cookieConfig.Domain, HttpOnly: true, }) return c.Value } func (b *Builder) setSecureCookiesByClaims(w http.ResponseWriter, user interface{}, claims UserClaims) (err error) { var secureSalt string if b.sessionSecureEnabled { if user.(SessionSecurer).GetSecure() == "" { err = user.(SessionSecurer).UpdateSecure(b.db, b.newUserObject(), objectID(user)) if err != nil { return err } } secureSalt = user.(SessionSecurer).GetSecure() } if err = b.setAuthCookiesFromUserClaims(w, &claims, secureSalt); err != nil { return err } return nil } func (b *Builder) consumeTOTPCode(r *http.Request, up UserPasser, passcode string) error { if !totp.Validate(passcode, up.GetTOTPSecret()) { return ErrWrongTOTPCode } lastCode, usedAt := up.GetLastUsedTOTPCode() if usedAt != nil && time.Now().Sub(*usedAt) > 90*time.Second { lastCode = "" } if passcode == lastCode { if b.afterTOTPCodeReusedHook != nil { if herr := b.afterTOTPCodeReusedHook(r, GetCurrentUser(r)); herr != nil { return herr } } return ErrTOTPCodeHasBeenUsed } if err := up.SetLastUsedTOTPCode(b.db, b.newUserObject(), passcode); err != nil { return err } return nil } func (b *Builder) getFailCodeFromTOTPCodeConsumeError(verr error) FailCode { fc := FailCodeSystemError switch verr { case ErrWrongTOTPCode: fc = FailCodeIncorrectTOTPCode case ErrTOTPCodeHasBeenUsed: fc = FailCodeTOTPCodeHasBeenUsed } return fc } // logout is for url "/logout/{provider}" func (b *Builder) logout(w http.ResponseWriter, r *http.Request) { err := gothic.Logout(w, r) if err != nil { // } b.cleanAuthCookies(w) if b.afterLogoutHook != nil { user := GetCurrentUser(r) if user != nil { if herr := b.afterLogoutHook(r, user); herr != nil { setFailCodeFlash(w, FailCodeSystemError) http.Redirect(w, r, b.loginPageURL, http.StatusFound) return } } } http.Redirect(w, r, b.loginPageURL, http.StatusFound) } // beginAuth is for url "/auth/{provider}" func (b *Builder) beginAuth(w http.ResponseWriter, r *http.Request) { gothic.BeginAuthHandler(w, r) } func (b *Builder) sendResetPasswordLink(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { w.WriteHeader(http.StatusNotFound) return } failRedirectURL := b.forgetPasswordPageURL // check reCAPTCHA token if b.recaptchaEnabled { token := r.FormValue("token") if !recaptchaTokenCheck(b, token) { setFailCodeFlash(w, FailCodeIncorrectRecaptchaToken) http.Redirect(w, r, failRedirectURL, http.StatusFound) return } } account := strings.TrimSpace(r.FormValue("account")) passcode := r.FormValue("otp") doTOTP := r.URL.Query().Get("totp") == "1" if doTOTP { failRedirectURL = MustSetQuery(failRedirectURL, "totp", "1") } if account == "" { setFailCodeFlash(w, FailCodeAccountIsRequired) http.Redirect(w, r, failRedirectURL, http.StatusFound) return } u, err := b.userModel.(UserPasser).FindUser(b.db, b.newUserObject(), account) if err != nil { if err == gorm.ErrRecordNotFound { setFailCodeFlash(w, FailCodeUserNotFound) } else { setFailCodeFlash(w, FailCodeSystemError) } setWrongForgetPasswordInputFlash(w, WrongForgetPasswordInputFlash{ Account: account, }) http.Redirect(w, r, failRedirectURL, http.StatusFound) return } _, createdAt, _ := u.(UserPasser).GetResetPasswordToken() if createdAt != nil { v := 60 - int(time.Now().Sub(*createdAt).Seconds()) if v > 0 { setSecondsToRedoFlash(w, v) setWrongForgetPasswordInputFlash(w, WrongForgetPasswordInputFlash{ Account: account, }) http.Redirect(w, r, failRedirectURL, http.StatusFound) return } } if u.(UserPasser).GetIsTOTPSetup() { if !doTOTP { setWrongForgetPasswordInputFlash(w, WrongForgetPasswordInputFlash{ Account: account, }) failRedirectURL = MustSetQuery(failRedirectURL, "totp", "1") http.Redirect(w, r, failRedirectURL, http.StatusFound) return } if err = b.consumeTOTPCode(r, u.(UserPasser), passcode); err != nil { fc := b.getFailCodeFromTOTPCodeConsumeError(err) setFailCodeFlash(w, fc) setWrongForgetPasswordInputFlash(w, WrongForgetPasswordInputFlash{ Account: account, TOTP: passcode, }) http.Redirect(w, r, failRedirectURL, http.StatusFound) return } } token, err := u.(UserPasser).GenerateResetPasswordToken(b.db, b.newUserObject()) if err != nil { setFailCodeFlash(w, FailCodeSystemError) setWrongForgetPasswordInputFlash(w, WrongForgetPasswordInputFlash{ Account: account, }) http.Redirect(w, r, failRedirectURL, http.StatusFound) return } scheme := "https" if r.TLS == nil { scheme = "http" } link := fmt.Sprintf("%s://%s%s?id=%s&token=%s", scheme, r.Host, b.resetPasswordPageURL, objectID(u), token) if doTOTP { link = MustSetQuery(link, "totp", "1") } if err = b.notifyUserOfResetPasswordLinkFunc(u, link); err != nil { setFailCodeFlash(w, FailCodeSystemError) setWrongForgetPasswordInputFlash(w, WrongForgetPasswordInputFlash{ Account: account, }) http.Redirect(w, r, failRedirectURL, http.StatusFound) return } if b.afterSendResetPasswordLinkHook != nil { if herr := b.afterSendResetPasswordLinkHook(r, u); herr != nil { setFailCodeFlash(w, FailCodeSystemError) http.Redirect(w, r, failRedirectURL, http.StatusFound) return } } http.Redirect(w, r, fmt.Sprintf("%s?a=%s", b.resetPasswordLinkSentPageURL, account), http.StatusFound) return } func (b *Builder) doResetPassword(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { w.WriteHeader(http.StatusNotFound) return } userID := r.FormValue("user_id") token := r.FormValue("token") passcode := r.FormValue("otp") doTOTP := r.URL.Query().Get("totp") == "1" failRedirectURL := fmt.Sprintf("%s?id=%s&token=%s", b.resetPasswordPageURL, userID, token) if doTOTP { failRedirectURL = MustSetQuery(failRedirectURL, "totp", "1") } if userID == "" { setFailCodeFlash(w, FailCodeUserNotFound) http.Redirect(w, r, failRedirectURL, http.StatusFound) return } if token == "" { setFailCodeFlash(w, FailCodeInvalidToken) http.Redirect(w, r, failRedirectURL, http.StatusFound) return } password := r.FormValue("password") confirmPassword := r.FormValue("confirm_password") if password == "" { setFailCodeFlash(w, FailCodePasswordCannotBeEmpty) http.Redirect(w, r, failRedirectURL, http.StatusFound) return } if confirmPassword != password { setFailCodeFlash(w, FailCodePasswordNotMatch) setWrongResetPasswordInputFlash(w, WrongResetPasswordInputFlash{ Password: password, ConfirmPassword: confirmPassword, }) http.Redirect(w, r, failRedirectURL, http.StatusFound) return } if b.passwordValidationFunc != nil { msg, ok := b.passwordValidationFunc(password) if !ok { setCustomErrorMessageFlash(w, msg) setWrongResetPasswordInputFlash(w, WrongResetPasswordInputFlash{ Password: password, ConfirmPassword: confirmPassword, }) http.Redirect(w, r, failRedirectURL, http.StatusFound) return } } u, err := b.findUserByID(userID) if err != nil { if err == ErrUserNotFound { setFailCodeFlash(w, FailCodeUserNotFound) } else { setFailCodeFlash(w, FailCodeSystemError) } http.Redirect(w, r, failRedirectURL, http.StatusFound) return } storedToken, _, expired := u.(UserPasser).GetResetPasswordToken() if expired { setFailCodeFlash(w, FailCodeTokenExpired) http.Redirect(w, r, failRedirectURL, http.StatusFound) return } if token != storedToken { setFailCodeFlash(w, FailCodeInvalidToken) http.Redirect(w, r, failRedirectURL, http.StatusFound) return } if u.(UserPasser).GetIsTOTPSetup() { if !doTOTP { setWrongResetPasswordInputFlash(w, WrongResetPasswordInputFlash{ Password: password, ConfirmPassword: confirmPassword, }) failRedirectURL = MustSetQuery(failRedirectURL, "totp", "1") http.Redirect(w, r, failRedirectURL, http.StatusFound) return } if err = b.consumeTOTPCode(r, u.(UserPasser), passcode); err != nil { fc := b.getFailCodeFromTOTPCodeConsumeError(err) setFailCodeFlash(w, fc) setWrongResetPasswordInputFlash(w, WrongResetPasswordInputFlash{ Password: password, ConfirmPassword: confirmPassword, TOTP: passcode, }) http.Redirect(w, r, failRedirectURL, http.StatusFound) return } } err = u.(UserPasser).ConsumeResetPasswordToken(b.db, b.newUserObject()) if err != nil { setFailCodeFlash(w, FailCodeSystemError) http.Redirect(w, r, failRedirectURL, http.StatusFound) return } err = u.(UserPasser).SetPassword(b.db, b.newUserObject(), password) if err != nil { setFailCodeFlash(w, FailCodeSystemError) http.Redirect(w, r, failRedirectURL, http.StatusFound) return } if b.afterResetPasswordHook != nil { if herr := b.afterResetPasswordHook(r, u); herr != nil { setFailCodeFlash(w, FailCodeSystemError) http.Redirect(w, r, failRedirectURL, http.StatusFound) return } } setInfoCodeFlash(w, InfoCodePasswordSuccessfullyReset) http.Redirect(w, r, b.loginPageURL, http.StatusFound) return } type ValidationError struct { Msg string } func (e *ValidationError) Error() string { return e.Msg } // validationError // errWrongPassword // errEmptyPassword // errPasswordNotMatch // errWrongTOTPCode // errTOTPCodeHasBeenUsed func (b *Builder) ChangePassword( r *http.Request, oldPassword string, password string, confirmPassword string, otp string, ) error { user := GetCurrentUser(r).(UserPasser) if !user.IsPasswordCorrect(oldPassword) { return ErrWrongPassword } if password == "" { return ErrEmptyPassword } if confirmPassword != password { return ErrPasswordNotMatch } if b.passwordValidationFunc != nil { msg, ok := b.passwordValidationFunc(password) if !ok { return &ValidationError{Msg: msg} } } if b.totpEnabled { if err := b.consumeTOTPCode(r, user, otp); err != nil { return err } } err := user.SetPassword(b.db, b.newUserObject(), password) if err != nil { return err } if b.afterChangePasswordHook != nil { if herr := b.afterChangePasswordHook(r, user); herr != nil { return herr } } return nil } func (b *Builder) doFormChangePassword(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { w.WriteHeader(http.StatusNotFound) return } oldPassword := r.FormValue("old_password") password := r.FormValue("password") confirmPassword := r.FormValue("confirm_password") otp := r.FormValue("otp") redirectURL := b.changePasswordPageURL err := b.ChangePassword(r, oldPassword, password, confirmPassword, otp) if err != nil { if ve, ok := err.(*ValidationError); ok { setCustomErrorMessageFlash(w, ve.Msg) } else { fc := FailCodeSystemError switch err { case ErrWrongPassword: fc = FailCodeIncorrectPassword case ErrEmptyPassword: fc = FailCodePasswordCannotBeEmpty case ErrPasswordNotMatch: fc = FailCodePasswordNotMatch case ErrWrongTOTPCode: fc = FailCodeIncorrectTOTPCode case ErrTOTPCodeHasBeenUsed: fc = FailCodeTOTPCodeHasBeenUsed } setFailCodeFlash(w, fc) } setWrongChangePasswordInputFlash(w, WrongChangePasswordInputFlash{ OldPassword: oldPassword, NewPassword: password, ConfirmPassword: confirmPassword, TOTP: otp, }) http.Redirect(w, r, redirectURL, http.StatusFound) return } setInfoCodeFlash(w, InfoCodePasswordSuccessfullyChanged) http.Redirect(w, r, b.loginPageURL, http.StatusFound) } func (b *Builder) totpDo(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { w.WriteHeader(http.StatusNotFound) return } var err error var user interface{} defer func() { if b.afterFailedToLoginHook != nil && err != nil && user != nil { b.afterFailedToLoginHook(r, user) } }() var claims *UserClaims claims, err = parseUserClaimsFromCookie(r, b.authCookieName, b.secret) if err != nil { http.Redirect(w, r, b.LogoutURL, http.StatusFound) return } user, err = b.findUserByID(claims.UserID) if err != nil { if err == ErrUserNotFound { setFailCodeFlash(w, FailCodeUserNotFound) } else { setFailCodeFlash(w, FailCodeSystemError) } http.Redirect(w, r, b.LogoutURL, http.StatusFound) return } u := user.(UserPasser) otp := r.FormValue("otp") isTOTPSetup := u.GetIsTOTPSetup() if err := b.consumeTOTPCode(r, u, otp); err != nil { fc := b.getFailCodeFromTOTPCodeConsumeError(err) setFailCodeFlash(w, fc) redirectURL := b.totpValidatePageURL if !isTOTPSetup { redirectURL = b.totpSetupPageURL } http.Redirect(w, r, redirectURL, http.StatusFound) return } if !isTOTPSetup { if err = u.SetIsTOTPSetup(b.db, b.newUserObject(), true); err != nil { setFailCodeFlash(w, FailCodeSystemError) http.Redirect(w, r, b.LogoutURL, http.StatusFound) return } } claims.TOTPValidated = true if b.afterLoginHook != nil { setCookieForRequest(r, &http.Cookie{Name: b.authCookieName, Value: b.mustGetSessionToken(*claims)}) if herr := b.afterLoginHook(r, user); herr != nil { setFailCodeFlash(w, FailCodeSystemError) http.Redirect(w, r, b.loginPageURL, http.StatusFound) return } } err = b.setSecureCookiesByClaims(w, user, *claims) if err != nil { setFailCodeFlash(w, FailCodeSystemError) http.Redirect(w, r, b.LogoutURL, http.StatusFound) return } redirectURL := b.homePageURLFunc(r, user) if v := b.getContinueURL(w, r); v != "" { redirectURL = v } http.Redirect(w, r, redirectURL, http.StatusFound) } func (b *Builder) Mount(mux *http.ServeMux) { if len(b.secret) == 0 { panic("secret is empty") } if b.userModel != nil { if b.db == nil { panic("db is required") } } wb := web.New() mux.HandleFunc(b.LogoutURL, b.logout) mux.Handle(b.loginPageURL, b.i18nBuilder.EnsureLanguage(wb.Page(b.loginPageFunc))) if b.userPassEnabled { mux.HandleFunc(b.passwordLoginURL, b.userpassLogin) mux.HandleFunc(b.resetPasswordURL, b.doResetPassword) mux.HandleFunc(b.changePasswordURL, b.doFormChangePassword) mux.Handle(b.resetPasswordPageURL, b.i18nBuilder.EnsureLanguage(wb.Page(b.resetPasswordPageFunc))) mux.Handle(b.changePasswordPageURL, b.i18nBuilder.EnsureLanguage(wb.Page(b.changePasswordPageFunc))) if !b.noForgetPasswordLink { mux.HandleFunc(b.sendResetPasswordLinkURL, b.sendResetPasswordLink) mux.Handle(b.forgetPasswordPageURL, b.i18nBuilder.EnsureLanguage(wb.Page(b.forgetPasswordPageFunc))) mux.Handle(b.resetPasswordLinkSentPageURL, b.i18nBuilder.EnsureLanguage(wb.Page(b.resetPasswordLinkSentPageFunc))) } if b.totpEnabled { mux.HandleFunc(b.validateTOTPURL, b.totpDo) mux.Handle(b.totpSetupPageURL, b.i18nBuilder.EnsureLanguage(wb.Page(b.totpSetupPageFunc))) mux.Handle(b.totpValidatePageURL, b.i18nBuilder.EnsureLanguage(wb.Page(b.totpValidatePageFunc))) } } if b.oauthEnabled { mux.HandleFunc(b.oauthBeginURL, b.beginAuth) mux.HandleFunc(b.oauthCallbackURL, b.completeUserAuthCallback) mux.HandleFunc(b.oauthCallbackCompleteURL, b.completeUserAuthCallbackComplete) } // assets assetsSubFS, err := fs.Sub(assetsFS, "assets") if err != nil { panic(err) } mux.Handle(assetsPathPrefix, http.StripPrefix(assetsPathPrefix, http.FileServer(http.FS(assetsSubFS)))) }