middleware.go 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. package login
  2. import (
  3. "context"
  4. "net/http"
  5. "regexp"
  6. "strconv"
  7. "strings"
  8. "time"
  9. )
  10. type ContextUserKey int
  11. const (
  12. UserKey ContextUserKey = iota
  13. loginWIPKey
  14. )
  15. type MiddlewareConfig interface {
  16. middlewareConfig()
  17. }
  18. // LoginNotRequired executes the next handler regardless of whether the user is logged in or not
  19. type LoginNotRequired struct{}
  20. func (*LoginNotRequired) middlewareConfig() {}
  21. // DisableAutoRedirectToHomePage makes it possible to visit login page when user is logged in
  22. type DisableAutoRedirectToHomePage struct{}
  23. func (*DisableAutoRedirectToHomePage) middlewareConfig() {}
  24. func (b *Builder) Middleware(cfgs ...MiddlewareConfig) func(next http.Handler) http.Handler {
  25. mustLogin := true
  26. autoRedirectToHomePage := true
  27. for _, cfg := range cfgs {
  28. switch cfg.(type) {
  29. case *LoginNotRequired:
  30. mustLogin = false
  31. case *DisableAutoRedirectToHomePage:
  32. autoRedirectToHomePage = false
  33. }
  34. }
  35. whiteList := map[string]struct{}{
  36. b.oauthBeginURL: {},
  37. b.oauthCallbackURL: {},
  38. b.oauthCallbackCompleteURL: {},
  39. b.passwordLoginURL: {},
  40. b.forgetPasswordPageURL: {},
  41. b.sendResetPasswordLinkURL: {},
  42. b.resetPasswordLinkSentPageURL: {},
  43. b.resetPasswordURL: {},
  44. b.resetPasswordPageURL: {},
  45. b.validateTOTPURL: {},
  46. }
  47. staticFileRe := regexp.MustCompile(`\.(css|js|gif|jpg|jpeg|png|ico|svg|ttf|eot|woff|woff2)$`)
  48. return func(next http.Handler) http.Handler {
  49. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  50. if staticFileRe.MatchString(strings.ToLower(r.URL.Path)) {
  51. next.ServeHTTP(w, r)
  52. return
  53. }
  54. if _, ok := whiteList[r.URL.Path]; ok {
  55. next.ServeHTTP(w, r)
  56. return
  57. }
  58. path := strings.TrimRight(r.URL.Path, "/")
  59. claims, err := parseUserClaimsFromCookie(r, b.authCookieName, b.secret)
  60. if err != nil {
  61. if !mustLogin {
  62. next.ServeHTTP(w, r)
  63. return
  64. }
  65. if r.Method == http.MethodGet {
  66. b.setContinueURL(w, r)
  67. }
  68. if path == b.loginPageURL {
  69. next.ServeHTTP(w, r)
  70. } else {
  71. http.Redirect(w, r, b.loginPageURL, http.StatusFound)
  72. }
  73. return
  74. }
  75. var user interface{}
  76. var secureSalt string
  77. if b.userModel != nil {
  78. var err error
  79. user, err = b.findUserByID(claims.UserID)
  80. if err == nil {
  81. if claims.Provider == "" {
  82. if user.(UserPasser).GetPasswordUpdatedAt() != claims.PassUpdatedAt {
  83. err = ErrPasswordChanged
  84. }
  85. if user.(UserPasser).GetLocked() {
  86. err = ErrUserLocked
  87. }
  88. } else {
  89. user.(OAuthUser).SetAvatar(claims.AvatarURL)
  90. }
  91. }
  92. if err != nil {
  93. if !mustLogin {
  94. next.ServeHTTP(w, r)
  95. return
  96. }
  97. switch err {
  98. case ErrUserNotFound:
  99. setFailCodeFlash(w, FailCodeUserNotFound)
  100. case ErrUserLocked:
  101. setFailCodeFlash(w, FailCodeUserLocked)
  102. case ErrPasswordChanged:
  103. isSelfChange := false
  104. if c, err := r.Cookie(infoCodeFlashCookieName); err == nil {
  105. v, _ := strconv.Atoi(c.Value)
  106. if InfoCode(v) == InfoCodePasswordSuccessfullyChanged {
  107. isSelfChange = true
  108. }
  109. }
  110. if !isSelfChange {
  111. setWarnCodeFlash(w, WarnCodePasswordHasBeenChanged)
  112. }
  113. default:
  114. setFailCodeFlash(w, FailCodeSystemError)
  115. }
  116. if path == b.LogoutURL {
  117. next.ServeHTTP(w, r)
  118. } else {
  119. http.Redirect(w, r, b.LogoutURL, http.StatusFound)
  120. }
  121. return
  122. }
  123. if b.sessionSecureEnabled {
  124. secureSalt = user.(SessionSecurer).GetSecure()
  125. _, err := parseBaseClaimsFromCookie(r, b.authSecureCookieName, b.secret+secureSalt)
  126. if err != nil {
  127. if !mustLogin {
  128. next.ServeHTTP(w, r)
  129. return
  130. }
  131. if path == b.LogoutURL {
  132. next.ServeHTTP(w, r)
  133. } else {
  134. http.Redirect(w, r, b.LogoutURL, http.StatusFound)
  135. }
  136. return
  137. }
  138. }
  139. } else {
  140. user = claims
  141. }
  142. if b.autoExtendSession && time.Now().Sub(claims.IssuedAt.Time).Seconds() > float64(b.sessionMaxAge)/10 {
  143. oldSessionToken := b.mustGetSessionToken(*claims)
  144. claims.RegisteredClaims = b.genBaseSessionClaim(claims.UserID)
  145. if err := b.setAuthCookiesFromUserClaims(w, claims, secureSalt); err != nil {
  146. if !mustLogin {
  147. next.ServeHTTP(w, r)
  148. return
  149. }
  150. setFailCodeFlash(w, FailCodeSystemError)
  151. if path == b.LogoutURL {
  152. next.ServeHTTP(w, r)
  153. } else {
  154. http.Redirect(w, r, b.LogoutURL, http.StatusFound)
  155. }
  156. return
  157. }
  158. if b.afterExtendSessionHook != nil {
  159. setCookieForRequest(r, &http.Cookie{Name: b.authCookieName, Value: b.mustGetSessionToken(*claims)})
  160. if herr := b.afterExtendSessionHook(r, user, oldSessionToken); herr != nil {
  161. if !mustLogin {
  162. next.ServeHTTP(w, r)
  163. return
  164. }
  165. setNoticeOrFailCodeFlash(w, herr, FailCodeSystemError)
  166. http.Redirect(w, r, b.LogoutURL, http.StatusFound)
  167. return
  168. }
  169. }
  170. }
  171. r = r.WithContext(context.WithValue(r.Context(), UserKey, user))
  172. if path == b.LogoutURL {
  173. next.ServeHTTP(w, r)
  174. return
  175. }
  176. if claims.Provider == "" && b.totpEnabled {
  177. if !user.(UserPasser).GetIsTOTPSetup() {
  178. if path == b.loginPageURL {
  179. next.ServeHTTP(w, r)
  180. return
  181. }
  182. r = r.WithContext(context.WithValue(r.Context(), loginWIPKey, true))
  183. if path == b.totpSetupPageURL {
  184. next.ServeHTTP(w, r)
  185. return
  186. }
  187. http.Redirect(w, r, b.totpSetupPageURL, http.StatusFound)
  188. return
  189. }
  190. if !claims.TOTPValidated {
  191. if path == b.loginPageURL {
  192. next.ServeHTTP(w, r)
  193. return
  194. }
  195. r = r.WithContext(context.WithValue(r.Context(), loginWIPKey, true))
  196. if path == b.totpValidatePageURL {
  197. next.ServeHTTP(w, r)
  198. return
  199. }
  200. http.Redirect(w, r, b.totpValidatePageURL, http.StatusFound)
  201. return
  202. }
  203. }
  204. if autoRedirectToHomePage {
  205. if path == b.loginPageURL || path == b.totpSetupPageURL || path == b.totpValidatePageURL {
  206. http.Redirect(w, r, b.homePageURLFunc(r, user), http.StatusFound)
  207. return
  208. }
  209. }
  210. next.ServeHTTP(w, r)
  211. })
  212. }
  213. }
  214. func GetCurrentUser(r *http.Request) (u interface{}) {
  215. return r.Context().Value(UserKey)
  216. }
  217. // IsLoginWIP indicates whether the user is in an intermediate step of login process,
  218. // such as on the TOTP validation page
  219. func IsLoginWIP(r *http.Request) bool {
  220. v, ok := r.Context().Value(loginWIPKey).(bool)
  221. if !ok {
  222. return false
  223. }
  224. return v
  225. }