Browse Source

Merge pull request #177 from qor5/login-panic-unexpected-error

login: panic for unexpected errors
xuxinx 1 year ago
parent
commit
40997f230e
4 changed files with 112 additions and 115 deletions
  1. 96 96
      login/builder.go
  2. 11 0
      login/flash.go
  3. 3 15
      login/middleware.go
  4. 2 4
      login/views.go

+ 96 - 96
login/builder.go

@@ -509,10 +509,13 @@ func (b *Builder) completeUserAuthCallbackComplete(w http.ResponseWriter, r *htt
 	var user interface{}
 	failRedirectURL := b.LogoutURL
 	defer func() {
+		if perr := recover(); perr != nil {
+			panic(perr)
+		}
 		if err != nil {
 			if b.afterFailedToLoginHook != nil {
 				if herr := b.afterFailedToLoginHook(r, user, err); herr != nil {
-					setNoticeOrFailCodeFlash(w, herr, FailCodeSystemError)
+					setNoticeOrPanic(w, herr)
 				}
 			}
 			http.Redirect(w, r, failRedirectURL, http.StatusFound)
@@ -528,7 +531,7 @@ func (b *Builder) completeUserAuthCallbackComplete(w http.ResponseWriter, r *htt
 
 	if b.afterOAuthCompleteHook != nil {
 		if err = b.afterOAuthCompleteHook(r, ouser); err != nil {
-			setNoticeOrFailCodeFlash(w, err, FailCodeSystemError)
+			setNoticeOrPanic(w, err)
 			return
 		}
 	}
@@ -539,24 +542,21 @@ func (b *Builder) completeUserAuthCallbackComplete(w http.ResponseWriter, r *htt
 		user, err = b.userModel.(OAuthUser).FindUserByOAuthUserID(b.db, b.newUserObject(), ouser.Provider, ouser.UserID)
 		if err != nil {
 			if err != gorm.ErrRecordNotFound {
-				setFailCodeFlash(w, FailCodeSystemError)
-				return
+				panic(err)
 			}
 			// TODO: maybe the identifier of some providers is not email
 			identifier := ouser.Email
 			user, err = b.userModel.(OAuthUser).FindUserByOAuthIdentifier(b.db, b.newUserObject(), ouser.Provider, identifier)
 			if err != nil {
-				if err == gorm.ErrRecordNotFound {
-					setFailCodeFlash(w, FailCodeUserNotFound)
-				} else {
-					setFailCodeFlash(w, FailCodeSystemError)
+				if err != gorm.ErrRecordNotFound {
+					panic(err)
 				}
+				setFailCodeFlash(w, FailCodeUserNotFound)
 				return
 			}
 			err = user.(OAuthUser).InitOAuthUserID(b.db, b.newUserObject(), ouser.Provider, identifier, ouser.UserID)
 			if err != nil {
-				setFailCodeFlash(w, FailCodeSystemError)
-				return
+				panic(err)
 			}
 		}
 		userID = objectID(user)
@@ -577,14 +577,13 @@ func (b *Builder) completeUserAuthCallbackComplete(w http.ResponseWriter, r *htt
 	if b.afterLoginHook != nil {
 		setCookieForRequest(r, &http.Cookie{Name: b.authCookieName, Value: b.mustGetSessionToken(claims)})
 		if err = b.afterLoginHook(r, user); err != nil {
-			setNoticeOrFailCodeFlash(w, err, FailCodeSystemError)
+			setNoticeOrPanic(w, err)
 			return
 		}
 	}
 
 	if err = b.setSecureCookiesByClaims(w, user, claims); err != nil {
-		setFailCodeFlash(w, FailCodeSystemError)
-		return
+		panic(err)
 	}
 
 	redirectURL := b.homePageURLFunc(r, user)
@@ -655,10 +654,13 @@ func (b *Builder) userpassLogin(w http.ResponseWriter, r *http.Request) {
 	var user interface{}
 	failRedirectURL := b.LogoutURL
 	defer func() {
+		if perr := recover(); perr != nil {
+			panic(perr)
+		}
 		if err != nil {
 			if b.afterFailedToLoginHook != nil {
 				if herr := b.afterFailedToLoginHook(r, user, err); herr != nil {
-					setNoticeOrFailCodeFlash(w, herr, FailCodeSystemError)
+					setNoticeOrPanic(w, herr)
 				}
 			}
 			http.Redirect(w, r, failRedirectURL, http.StatusFound)
@@ -671,17 +673,19 @@ func (b *Builder) userpassLogin(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		if err == ErrUserGetLocked && b.afterUserLockedHook != nil {
 			if err = b.afterUserLockedHook(r, user); err != nil {
-				setNoticeOrFailCodeFlash(w, err, FailCodeSystemError)
+				setNoticeOrPanic(w, err)
 				return
 			}
 		}
 
-		code := FailCodeSystemError
+		var code FailCode
 		switch err {
 		case ErrWrongPassword, ErrUserNotFound:
 			code = FailCodeIncorrectAccountNameOrPassword
 		case ErrUserLocked, ErrUserGetLocked:
 			code = FailCodeUserLocked
+		default:
+			panic(err)
 		}
 		setFailCodeFlash(w, code)
 		setWrongLoginInputFlash(w, WrongLoginInputFlash{
@@ -703,15 +707,14 @@ func (b *Builder) userpassLogin(w http.ResponseWriter, r *http.Request) {
 		if b.afterLoginHook != nil {
 			setCookieForRequest(r, &http.Cookie{Name: b.authCookieName, Value: b.mustGetSessionToken(claims)})
 			if err = b.afterLoginHook(r, user); err != nil {
-				setNoticeOrFailCodeFlash(w, err, FailCodeSystemError)
+				setNoticeOrPanic(w, err)
 				return
 			}
 		}
 	}
 
 	if err = b.setSecureCookiesByClaims(w, user, claims); err != nil {
-		setFailCodeFlash(w, FailCodeSystemError)
-		return
+		panic(err)
 	}
 
 	if b.totpEnabled {
@@ -727,13 +730,11 @@ func (b *Builder) userpassLogin(w http.ResponseWriter, r *http.Request) {
 				AccountName: u.GetAccountName(),
 			},
 		); err != nil {
-			setFailCodeFlash(w, FailCodeSystemError)
-			return
+			panic(err)
 		}
 
 		if err = u.SetTOTPSecret(b.db, b.newUserObject(), key.Secret()); err != nil {
-			setFailCodeFlash(w, FailCodeSystemError)
-			return
+			panic(err)
 		}
 
 		http.Redirect(w, r, b.totpSetupPageURL, http.StatusFound)
@@ -757,7 +758,7 @@ func (b *Builder) mustGetSessionToken(claims UserClaims) string {
 	return mustSignClaims(claims, b.secret)
 }
 
-func (b *Builder) setAuthCookiesFromUserClaims(w http.ResponseWriter, claims *UserClaims, secureSalt string) error {
+func (b *Builder) setAuthCookiesFromUserClaims(w http.ResponseWriter, claims *UserClaims, secureSalt string) {
 	http.SetCookie(w, &http.Cookie{
 		Name:     b.authCookieName,
 		Value:    b.mustGetSessionToken(*claims),
@@ -783,8 +784,6 @@ func (b *Builder) setAuthCookiesFromUserClaims(w http.ResponseWriter, claims *Us
 			SameSite: b.cookieConfig.SameSite,
 		})
 	}
-
-	return nil
 }
 
 func (b *Builder) cleanAuthCookies(w http.ResponseWriter) {
@@ -877,10 +876,7 @@ func (b *Builder) setSecureCookiesByClaims(w http.ResponseWriter, user interface
 		}
 		secureSalt = user.(SessionSecurer).GetSecure()
 	}
-	if err = b.setAuthCookiesFromUserClaims(w, &claims, secureSalt); err != nil {
-		return err
-	}
-
+	b.setAuthCookiesFromUserClaims(w, &claims, secureSalt)
 	return nil
 }
 
@@ -906,18 +902,6 @@ func (b *Builder) consumeTOTPCode(r *http.Request, up UserPasser, passcode strin
 	return nil
 }
 
-func (b *Builder) getFailCodeFromTOTPCodeConsumeError(verr error) FailCode {
-	fc := FailCodeSystemError
-	switch verr {
-	case ErrWrongTOTPCode:
-		fc = FailCodeIncorrectTOTPCode
-	case ErrTOTPCodeHasBeenUsed:
-		fc = FailCodeTOTPCodeHasBeenUsed
-	}
-
-	return fc
-}
-
 // logout is for url "/logout/{provider}"
 func (b *Builder) logout(w http.ResponseWriter, r *http.Request) {
 	err := gothic.Logout(w, r)
@@ -931,7 +915,7 @@ func (b *Builder) logout(w http.ResponseWriter, r *http.Request) {
 		user := GetCurrentUser(r)
 		if user != nil {
 			if herr := b.afterLogoutHook(r, user); herr != nil {
-				setNoticeOrFailCodeFlash(w, herr, FailCodeSystemError)
+				setNoticeOrPanic(w, herr)
 				http.Redirect(w, r, b.loginPageURL, http.StatusFound)
 				return
 			}
@@ -982,14 +966,13 @@ func (b *Builder) sendResetPasswordLink(w http.ResponseWriter, r *http.Request)
 	if err != nil {
 		if err == gorm.ErrRecordNotFound {
 			setFailCodeFlash(w, FailCodeUserNotFound)
-		} else {
-			setFailCodeFlash(w, FailCodeSystemError)
+			setWrongForgetPasswordInputFlash(w, WrongForgetPasswordInputFlash{
+				Account: account,
+			})
+			http.Redirect(w, r, failRedirectURL, http.StatusFound)
+			return
 		}
-		setWrongForgetPasswordInputFlash(w, WrongForgetPasswordInputFlash{
-			Account: account,
-		})
-		http.Redirect(w, r, failRedirectURL, http.StatusFound)
-		return
+		panic(err)
 	}
 
 	_, createdAt, _ := u.(UserPasser).GetResetPasswordToken()
@@ -1016,7 +999,15 @@ func (b *Builder) sendResetPasswordLink(w http.ResponseWriter, r *http.Request)
 		}
 
 		if err = b.consumeTOTPCode(r, u.(UserPasser), passcode); err != nil {
-			fc := b.getFailCodeFromTOTPCodeConsumeError(err)
+			var fc FailCode
+			switch err {
+			case ErrWrongTOTPCode:
+				fc = FailCodeIncorrectTOTPCode
+			case ErrTOTPCodeHasBeenUsed:
+				fc = FailCodeTOTPCodeHasBeenUsed
+			default:
+				panic(err)
+			}
 			setNoticeOrFailCodeFlash(w, err, fc)
 			setWrongForgetPasswordInputFlash(w, WrongForgetPasswordInputFlash{
 				Account: account,
@@ -1029,12 +1020,7 @@ func (b *Builder) sendResetPasswordLink(w http.ResponseWriter, r *http.Request)
 
 	token, err := u.(UserPasser).GenerateResetPasswordToken(b.db, b.newUserObject())
 	if err != nil {
-		setFailCodeFlash(w, FailCodeSystemError)
-		setWrongForgetPasswordInputFlash(w, WrongForgetPasswordInputFlash{
-			Account: account,
-		})
-		http.Redirect(w, r, failRedirectURL, http.StatusFound)
-		return
+		panic(err)
 	}
 
 	scheme := "https"
@@ -1047,7 +1033,7 @@ func (b *Builder) sendResetPasswordLink(w http.ResponseWriter, r *http.Request)
 	}
 	if b.afterConfirmSendResetPasswordLinkHook != nil {
 		if herr := b.afterConfirmSendResetPasswordLinkHook(r, u, link); herr != nil {
-			setNoticeOrFailCodeFlash(w, herr, FailCodeSystemError)
+			setNoticeOrPanic(w, herr)
 			http.Redirect(w, r, failRedirectURL, http.StatusFound)
 			return
 		}
@@ -1104,11 +1090,10 @@ func (b *Builder) doResetPassword(w http.ResponseWriter, r *http.Request) {
 	if err != nil {
 		if err == ErrUserNotFound {
 			setFailCodeFlash(w, FailCodeUserNotFound)
-		} else {
-			setFailCodeFlash(w, FailCodeSystemError)
+			http.Redirect(w, r, failRedirectURL, http.StatusFound)
+			return
 		}
-		http.Redirect(w, r, failRedirectURL, http.StatusFound)
-		return
+		panic(err)
 	}
 
 	storedToken, _, expired := u.(UserPasser).GetResetPasswordToken()
@@ -1125,7 +1110,7 @@ func (b *Builder) doResetPassword(w http.ResponseWriter, r *http.Request) {
 
 	if b.beforeSetPasswordHook != nil {
 		if herr := b.beforeSetPasswordHook(r, u, password); herr != nil {
-			setNoticeOrFailCodeFlash(w, herr, FailCodeSystemError)
+			setNoticeOrPanic(w, herr)
 			setWrongResetPasswordInputFlash(w, WrongResetPasswordInputFlash{
 				Password:        password,
 				ConfirmPassword: confirmPassword,
@@ -1147,7 +1132,15 @@ func (b *Builder) doResetPassword(w http.ResponseWriter, r *http.Request) {
 		}
 
 		if err = b.consumeTOTPCode(r, u.(UserPasser), passcode); err != nil {
-			fc := b.getFailCodeFromTOTPCodeConsumeError(err)
+			var fc FailCode
+			switch err {
+			case ErrWrongTOTPCode:
+				fc = FailCodeIncorrectTOTPCode
+			case ErrTOTPCodeHasBeenUsed:
+				fc = FailCodeTOTPCodeHasBeenUsed
+			default:
+				panic(err)
+			}
 			setFailCodeFlash(w, fc)
 			setWrongResetPasswordInputFlash(w, WrongResetPasswordInputFlash{
 				Password:        password,
@@ -1161,21 +1154,17 @@ func (b *Builder) doResetPassword(w http.ResponseWriter, r *http.Request) {
 
 	err = u.(UserPasser).ConsumeResetPasswordToken(b.db, b.newUserObject())
 	if err != nil {
-		setFailCodeFlash(w, FailCodeSystemError)
-		http.Redirect(w, r, failRedirectURL, http.StatusFound)
-		return
+		panic(err)
 	}
 
 	err = u.(UserPasser).SetPassword(b.db, b.newUserObject(), password)
 	if err != nil {
-		setFailCodeFlash(w, FailCodeSystemError)
-		http.Redirect(w, r, failRedirectURL, http.StatusFound)
-		return
+		panic(err)
 	}
 
 	if b.afterResetPasswordHook != nil {
 		if herr := b.afterResetPasswordHook(r, u); herr != nil {
-			setNoticeOrFailCodeFlash(w, herr, FailCodeSystemError)
+			setNoticeOrPanic(w, herr)
 			http.Redirect(w, r, failRedirectURL, http.StatusFound)
 			return
 		}
@@ -1256,7 +1245,7 @@ func (b *Builder) doFormChangePassword(w http.ResponseWriter, r *http.Request) {
 		if ne, ok := err.(*NoticeError); ok {
 			setNoticeFlash(w, ne)
 		} else {
-			fc := FailCodeSystemError
+			var fc FailCode
 			switch err {
 			case ErrWrongPassword:
 				fc = FailCodeIncorrectPassword
@@ -1268,6 +1257,8 @@ func (b *Builder) doFormChangePassword(w http.ResponseWriter, r *http.Request) {
 				fc = FailCodeIncorrectTOTPCode
 			case ErrTOTPCodeHasBeenUsed:
 				fc = FailCodeTOTPCodeHasBeenUsed
+			default:
+				panic(err)
 			}
 			setFailCodeFlash(w, fc)
 		}
@@ -1292,42 +1283,53 @@ func (b *Builder) totpDo(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	var err error
-	var user interface{}
+	var claims *UserClaims
+	claims, err := parseUserClaimsFromCookie(r, b.authCookieName, b.secret)
+	if err != nil {
+		http.Redirect(w, r, b.LogoutURL, http.StatusFound)
+		return
+	}
+
+	user, err := b.findUserByID(claims.UserID)
+	if err != nil {
+		if err == ErrUserNotFound {
+			setFailCodeFlash(w, FailCodeUserNotFound)
+			http.Redirect(w, r, b.LogoutURL, http.StatusFound)
+			return
+		}
+		panic(err)
+	}
+
 	failRedirectURL := b.LogoutURL
 	defer func() {
+		if perr := recover(); perr != nil {
+			panic(perr)
+		}
 		if err != nil {
 			if b.afterFailedToLoginHook != nil {
 				if herr := b.afterFailedToLoginHook(r, user, err); herr != nil {
-					setNoticeOrFailCodeFlash(w, herr, FailCodeSystemError)
+					setNoticeOrPanic(w, herr)
 				}
 			}
 			http.Redirect(w, r, failRedirectURL, http.StatusFound)
 		}
 	}()
 
-	var claims *UserClaims
-	claims, err = parseUserClaimsFromCookie(r, b.authCookieName, b.secret)
-	if err != nil {
-		return
-	}
-
-	user, err = b.findUserByID(claims.UserID)
-	if err != nil {
-		if err == ErrUserNotFound {
-			setFailCodeFlash(w, FailCodeUserNotFound)
-		} else {
-			setFailCodeFlash(w, FailCodeSystemError)
-		}
-		return
-	}
 	u := user.(UserPasser)
 
 	otp := r.FormValue("otp")
 	isTOTPSetup := u.GetIsTOTPSetup()
 
 	if err = b.consumeTOTPCode(r, u, otp); err != nil {
-		fc := b.getFailCodeFromTOTPCodeConsumeError(err)
+		var fc FailCode
+		switch err {
+		case ErrWrongTOTPCode:
+			fc = FailCodeIncorrectTOTPCode
+		case ErrTOTPCodeHasBeenUsed:
+			fc = FailCodeTOTPCodeHasBeenUsed
+		default:
+			panic(err)
+		}
 		setFailCodeFlash(w, fc)
 		failRedirectURL = b.totpValidatePageURL
 		if !isTOTPSetup {
@@ -1338,8 +1340,7 @@ func (b *Builder) totpDo(w http.ResponseWriter, r *http.Request) {
 
 	if !isTOTPSetup {
 		if err = u.SetIsTOTPSetup(b.db, b.newUserObject(), true); err != nil {
-			setFailCodeFlash(w, FailCodeSystemError)
-			return
+			panic(err)
 		}
 	}
 
@@ -1347,15 +1348,14 @@ func (b *Builder) totpDo(w http.ResponseWriter, r *http.Request) {
 	if b.afterLoginHook != nil {
 		setCookieForRequest(r, &http.Cookie{Name: b.authCookieName, Value: b.mustGetSessionToken(*claims)})
 		if err = b.afterLoginHook(r, user); err != nil {
-			setNoticeOrFailCodeFlash(w, err, FailCodeSystemError)
+			setNoticeOrPanic(w, err)
 			return
 		}
 	}
 
 	err = b.setSecureCookiesByClaims(w, user, *claims)
 	if err != nil {
-		setFailCodeFlash(w, FailCodeSystemError)
-		return
+		panic(err)
 	}
 
 	redirectURL := b.homePageURLFunc(r, user)

+ 11 - 0
login/flash.go

@@ -113,6 +113,17 @@ func setNoticeOrFailCodeFlash(w http.ResponseWriter, err error, c FailCode) {
 	setFailCodeFlash(w, c)
 }
 
+func setNoticeOrPanic(w http.ResponseWriter, err error) {
+	if err == nil {
+		return
+	}
+	ne, ok := err.(*NoticeError)
+	if !ok {
+		panic(err)
+	}
+	setNoticeFlash(w, ne)
+}
+
 const wrongLoginInputFlashCookieName = "qor5_wli_flash"
 
 type WrongLoginInputFlash struct {

+ 3 - 15
login/middleware.go

@@ -126,7 +126,7 @@ func (b *Builder) Middleware(cfgs ...MiddlewareConfig) func(next http.Handler) h
 							setWarnCodeFlash(w, WarnCodePasswordHasBeenChanged)
 						}
 					default:
-						setFailCodeFlash(w, FailCodeSystemError)
+						panic(err)
 					}
 					if path == b.LogoutURL {
 						next.ServeHTTP(w, r)
@@ -160,19 +160,7 @@ func (b *Builder) Middleware(cfgs ...MiddlewareConfig) func(next http.Handler) h
 				oldSessionToken := b.mustGetSessionToken(*claims)
 
 				claims.RegisteredClaims = b.genBaseSessionClaim(claims.UserID)
-				if err := b.setAuthCookiesFromUserClaims(w, claims, secureSalt); err != nil {
-					if !mustLogin {
-						next.ServeHTTP(w, r)
-						return
-					}
-					setFailCodeFlash(w, FailCodeSystemError)
-					if path == b.LogoutURL {
-						next.ServeHTTP(w, r)
-					} else {
-						http.Redirect(w, r, b.LogoutURL, http.StatusFound)
-					}
-					return
-				}
+				b.setAuthCookiesFromUserClaims(w, claims, secureSalt)
 
 				if b.afterExtendSessionHook != nil {
 					setCookieForRequest(r, &http.Cookie{Name: b.authCookieName, Value: b.mustGetSessionToken(*claims)})
@@ -181,7 +169,7 @@ func (b *Builder) Middleware(cfgs ...MiddlewareConfig) func(next http.Handler) h
 							next.ServeHTTP(w, r)
 							return
 						}
-						setNoticeOrFailCodeFlash(w, herr, FailCodeSystemError)
+						setNoticeOrPanic(w, herr)
 						http.Redirect(w, r, b.LogoutURL, http.StatusFound)
 						return
 					}

+ 2 - 4
login/views.go

@@ -14,7 +14,6 @@ import (
 	. "github.com/theplant/htmlgo"
 	"golang.org/x/text/language"
 	"golang.org/x/text/language/display"
-	"gorm.io/gorm"
 )
 
 func defaultLoginPage(vh *ViewHelper) web.PageFunc {
@@ -273,12 +272,11 @@ func defaultResetPasswordPage(vh *ViewHelper) web.PageFunc {
 		} else {
 			user, err = vh.FindUserByID(id)
 			if err != nil {
-				if err == gorm.ErrRecordNotFound {
+				if err == ErrUserNotFound {
 					r.Body = Div(Text("user not found"))
 					return r, nil
 				}
-				r.Body = Div(Text("system error"))
-				return r, nil
+				panic(err)
 			}
 		}
 		token := query.Get("token")