middlewares.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. package login
import (
	"context"
	"errors"
	"log"
	"net/http"
	"regexp"
	"strings"
	"time"
)
  10. type ContextUserKey int
  11. const (
  12. UserKey ContextUserKey = iota
  13. loginWIPKey
  14. )
  15. var staticFileRe = regexp.MustCompile(`\.(css|js|gif|jpg|jpeg|png|ico|svg|ttf|eot|woff|woff2)$`)
  16. func Authenticate(b *Builder) func(next http.Handler) http.Handler {
  17. return func(next http.Handler) http.Handler {
  18. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  19. if staticFileRe.MatchString(strings.ToLower(r.URL.Path)) {
  20. next.ServeHTTP(w, r)
  21. return
  22. }
  23. if _, ok := b.allowURLs[r.URL.Path]; ok {
  24. next.ServeHTTP(w, r)
  25. return
  26. }
  27. path := strings.TrimRight(r.URL.Path, "/")
  28. claims, err := parseUserClaimsFromCookie(r, b.authCookieName, b.secret)
  29. if err != nil {
  30. log.Println(err)
  31. b.setContinueURL(w, r)
  32. if path == b.loginPageURL {
  33. next.ServeHTTP(w, r)
  34. } else {
  35. http.Redirect(w, r, b.loginPageURL, http.StatusFound)
  36. }
  37. return
  38. }
  39. var user interface{}
  40. var secureSalt string
  41. if b.userModel != nil {
  42. var err error
  43. user, err = b.findUserByID(claims.UserID)
  44. if err == nil {
  45. if claims.Provider == "" {
  46. if user.(UserPasser).GetPasswordUpdatedAt() != claims.PassUpdatedAt {
  47. err = ErrUserPassChanged
  48. }
  49. if user.(UserPasser).GetLocked() {
  50. err = ErrUserLocked
  51. }
  52. } else {
  53. user.(OAuthUser).SetAvatar(claims.AvatarURL)
  54. }
  55. }
  56. if err != nil {
  57. log.Println(err)
  58. switch err {
  59. case ErrUserNotFound:
  60. setFailCodeFlash(w, FailCodeUserNotFound)
  61. case ErrUserLocked:
  62. setFailCodeFlash(w, FailCodeUserLocked)
  63. case ErrUserPassChanged:
  64. setWarnCodeFlash(w, WarnCodePasswordHasBeenChanged)
  65. default:
  66. setFailCodeFlash(w, FailCodeSystemError)
  67. }
  68. if path == b.LogoutURL {
  69. next.ServeHTTP(w, r)
  70. } else {
  71. http.Redirect(w, r, b.LogoutURL, http.StatusFound)
  72. }
  73. return
  74. }
  75. if b.sessionSecureEnabled {
  76. secureSalt = user.(SessionSecurer).GetSecure()
  77. _, err := parseBaseClaimsFromCookie(r, b.authSecureCookieName, b.secret+secureSalt)
  78. if err != nil {
  79. if path == b.LogoutURL {
  80. next.ServeHTTP(w, r)
  81. } else {
  82. http.Redirect(w, r, b.LogoutURL, http.StatusFound)
  83. }
  84. return
  85. }
  86. }
  87. } else {
  88. user = claims
  89. }
  90. if b.autoExtendSession && time.Now().Sub(claims.IssuedAt.Time).Seconds() > float64(b.sessionMaxAge)/10 {
  91. oldSessionToken := b.mustGetSessionToken(*claims)
  92. claims.RegisteredClaims = b.genBaseSessionClaim(claims.UserID)
  93. if err := b.setAuthCookiesFromUserClaims(w, claims, secureSalt); err != nil {
  94. setFailCodeFlash(w, FailCodeSystemError)
  95. if path == b.LogoutURL {
  96. next.ServeHTTP(w, r)
  97. } else {
  98. http.Redirect(w, r, b.LogoutURL, http.StatusFound)
  99. }
  100. return
  101. }
  102. if b.afterExtendSessionHook != nil {
  103. setCookieForRequest(r, &http.Cookie{Name: b.authCookieName, Value: b.mustGetSessionToken(*claims)})
  104. if herr := b.afterExtendSessionHook(r, user, oldSessionToken); herr != nil {
  105. setFailCodeFlash(w, FailCodeSystemError)
  106. http.Redirect(w, r, b.LogoutURL, http.StatusFound)
  107. return
  108. }
  109. }
  110. }
  111. r = r.WithContext(context.WithValue(r.Context(), UserKey, user))
  112. if path == b.LogoutURL {
  113. next.ServeHTTP(w, r)
  114. return
  115. }
  116. if claims.Provider == "" && b.totpEnabled {
  117. if !user.(UserPasser).GetIsTOTPSetup() {
  118. if path == b.loginPageURL {
  119. next.ServeHTTP(w, r)
  120. return
  121. }
  122. r = r.WithContext(context.WithValue(r.Context(), loginWIPKey, true))
  123. if path == b.totpSetupPageURL {
  124. next.ServeHTTP(w, r)
  125. return
  126. }
  127. http.Redirect(w, r, b.totpSetupPageURL, http.StatusFound)
  128. return
  129. }
  130. if !claims.TOTPValidated {
  131. if path == b.loginPageURL {
  132. next.ServeHTTP(w, r)
  133. return
  134. }
  135. r = r.WithContext(context.WithValue(r.Context(), loginWIPKey, true))
  136. if path == b.totpValidatePageURL {
  137. next.ServeHTTP(w, r)
  138. return
  139. }
  140. http.Redirect(w, r, b.totpValidatePageURL, http.StatusFound)
  141. return
  142. }
  143. }
  144. if path == b.loginPageURL || path == b.totpSetupPageURL || path == b.totpValidatePageURL {
  145. http.Redirect(w, r, b.homePageURLFunc(r, user), http.StatusFound)
  146. return
  147. }
  148. next.ServeHTTP(w, r)
  149. })
  150. }
  151. }
  152. func GetCurrentUser(r *http.Request) (u interface{}) {
  153. return r.Context().Value(UserKey)
  154. }
  155. func IsLoginWIP(r *http.Request) bool {
  156. v, ok := r.Context().Value(loginWIPKey).(bool)
  157. if !ok {
  158. return false
  159. }
  160. return v
  161. }