123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240 |
- package login
import (
	"context"
	"errors"
	"net/http"
	"regexp"
	"strconv"
	"strings"
	"time"
)
// ContextUserKey is the type of the keys under which this package stores
// request-scoped values in a request's context.Context.
type ContextUserKey int

// Context keys written by Middleware.
const (
	// UserKey holds the authenticated user; read it with GetCurrentUser.
	UserKey ContextUserKey = 0
	// loginWIPKey is set to true while a login still has a pending step
	// (TOTP setup or validation); read it with IsLoginWIP.
	loginWIPKey ContextUserKey = 1
)
// MiddlewareConfig is an option accepted by (*Builder).Middleware. Its
// unexported marker method means only types in this package can implement it.
type MiddlewareConfig interface {
	middlewareConfig()
}

// LoginNotRequired makes the middleware call the next handler whether or not
// the user is logged in.
type LoginNotRequired struct{}

func (*LoginNotRequired) middlewareConfig() {}

// DisableAutoRedirectToHomePage lets a logged-in user stay on the login page
// (and the TOTP setup/validation pages) instead of being redirected to the
// home page.
type DisableAutoRedirectToHomePage struct{}

func (*DisableAutoRedirectToHomePage) middlewareConfig() {}

// Compile-time checks that every option satisfies MiddlewareConfig.
var (
	_ MiddlewareConfig = (*LoginNotRequired)(nil)
	_ MiddlewareConfig = (*DisableAutoRedirectToHomePage)(nil)
)
- func (b *Builder) Middleware(cfgs ...MiddlewareConfig) func(next http.Handler) http.Handler {
- mustLogin := true
- autoRedirectToHomePage := true
- for _, cfg := range cfgs {
- switch cfg.(type) {
- case *LoginNotRequired:
- mustLogin = false
- case *DisableAutoRedirectToHomePage:
- autoRedirectToHomePage = false
- }
- }
- whiteList := map[string]struct{}{
- b.oauthBeginURL: {},
- b.oauthCallbackURL: {},
- b.oauthCallbackCompleteURL: {},
- b.passwordLoginURL: {},
- b.forgetPasswordPageURL: {},
- b.sendResetPasswordLinkURL: {},
- b.resetPasswordLinkSentPageURL: {},
- b.resetPasswordURL: {},
- b.resetPasswordPageURL: {},
- b.validateTOTPURL: {},
- }
- staticFileRe := regexp.MustCompile(`\.(css|js|gif|jpg|jpeg|png|ico|svg|ttf|eot|woff|woff2)$`)
- return func(next http.Handler) http.Handler {
- return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- if staticFileRe.MatchString(strings.ToLower(r.URL.Path)) {
- next.ServeHTTP(w, r)
- return
- }
- if _, ok := whiteList[r.URL.Path]; ok {
- next.ServeHTTP(w, r)
- return
- }
- path := strings.TrimRight(r.URL.Path, "/")
- claims, err := parseUserClaimsFromCookie(r, b.authCookieName, b.secret)
- if err != nil {
- if !mustLogin {
- next.ServeHTTP(w, r)
- return
- }
- if r.Method == http.MethodGet {
- b.setContinueURL(w, r)
- }
- if path == b.loginPageURL {
- next.ServeHTTP(w, r)
- } else {
- http.Redirect(w, r, b.loginPageURL, http.StatusFound)
- }
- return
- }
- var user interface{}
- var secureSalt string
- if b.userModel != nil {
- var err error
- user, err = b.findUserByID(claims.UserID)
- if err == nil {
- if claims.Provider == "" {
- if user.(UserPasser).GetPasswordUpdatedAt() != claims.PassUpdatedAt {
- err = ErrPasswordChanged
- }
- if user.(UserPasser).GetLocked() {
- err = ErrUserLocked
- }
- } else {
- user.(OAuthUser).SetAvatar(claims.AvatarURL)
- }
- }
- if err != nil {
- if !mustLogin {
- next.ServeHTTP(w, r)
- return
- }
- switch err {
- case ErrUserNotFound:
- setFailCodeFlash(w, FailCodeUserNotFound)
- case ErrUserLocked:
- setFailCodeFlash(w, FailCodeUserLocked)
- case ErrPasswordChanged:
- isSelfChange := false
- if c, err := r.Cookie(infoCodeFlashCookieName); err == nil {
- v, _ := strconv.Atoi(c.Value)
- if InfoCode(v) == InfoCodePasswordSuccessfullyChanged {
- isSelfChange = true
- }
- }
- if !isSelfChange {
- setWarnCodeFlash(w, WarnCodePasswordHasBeenChanged)
- }
- default:
- panic(err)
- }
- if path == b.LogoutURL {
- next.ServeHTTP(w, r)
- } else {
- http.Redirect(w, r, b.LogoutURL, http.StatusFound)
- }
- return
- }
- if b.sessionSecureEnabled {
- secureSalt = user.(SessionSecurer).GetSecure()
- _, err := parseBaseClaimsFromCookie(r, b.authSecureCookieName, b.secret+secureSalt)
- if err != nil {
- if !mustLogin {
- next.ServeHTTP(w, r)
- return
- }
- if path == b.LogoutURL {
- next.ServeHTTP(w, r)
- } else {
- http.Redirect(w, r, b.LogoutURL, http.StatusFound)
- }
- return
- }
- }
- } else {
- user = claims
- }
- if b.autoExtendSession && time.Now().Sub(claims.IssuedAt.Time).Seconds() > float64(b.sessionMaxAge)/10 {
- oldSessionToken := b.mustGetSessionToken(*claims)
- claims.RegisteredClaims = b.genBaseSessionClaim(claims.UserID)
- b.setAuthCookiesFromUserClaims(w, claims, secureSalt)
- if b.afterExtendSessionHook != nil {
- setCookieForRequest(r, &http.Cookie{Name: b.authCookieName, Value: b.mustGetSessionToken(*claims)})
- if herr := b.afterExtendSessionHook(r, user, oldSessionToken); herr != nil {
- if !mustLogin {
- next.ServeHTTP(w, r)
- return
- }
- setNoticeOrPanic(w, herr)
- http.Redirect(w, r, b.LogoutURL, http.StatusFound)
- return
- }
- }
- }
- r = r.WithContext(context.WithValue(r.Context(), UserKey, user))
- if path == b.LogoutURL {
- next.ServeHTTP(w, r)
- return
- }
- if claims.Provider == "" && b.totpEnabled {
- if !user.(UserPasser).GetIsTOTPSetup() {
- if path == b.loginPageURL {
- next.ServeHTTP(w, r)
- return
- }
- r = r.WithContext(context.WithValue(r.Context(), loginWIPKey, true))
- if path == b.totpSetupPageURL {
- next.ServeHTTP(w, r)
- return
- }
- http.Redirect(w, r, b.totpSetupPageURL, http.StatusFound)
- return
- }
- if !claims.TOTPValidated {
- if path == b.loginPageURL {
- next.ServeHTTP(w, r)
- return
- }
- r = r.WithContext(context.WithValue(r.Context(), loginWIPKey, true))
- if path == b.totpValidatePageURL {
- next.ServeHTTP(w, r)
- return
- }
- http.Redirect(w, r, b.totpValidatePageURL, http.StatusFound)
- return
- }
- }
- if autoRedirectToHomePage {
- if path == b.loginPageURL || path == b.totpSetupPageURL || path == b.totpValidatePageURL {
- http.Redirect(w, r, b.homePageURLFunc(r, user), http.StatusFound)
- return
- }
- }
- next.ServeHTTP(w, r)
- })
- }
- }
- func GetCurrentUser(r *http.Request) (u interface{}) {
- return r.Context().Value(UserKey)
- }
- // IsLoginWIP indicates whether the user is in an intermediate step of login process,
- // such as on the TOTP validation page
- func IsLoginWIP(r *http.Request) bool {
- v, ok := r.Context().Value(loginWIPKey).(bool)
- if !ok {
- return false
- }
- return v
- }
|