helper.go 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. package activity
  2. import (
  3. "context"
  4. "fmt"
  5. "reflect"
  6. "strings"
  7. "github.com/qor5/admin/presets"
  8. "gorm.io/gorm"
  9. )
  10. func findOldWithSlug(obj interface{}, slug string, db *gorm.DB) (interface{}, bool) {
  11. if slug == "" {
  12. return findOld(obj, db)
  13. }
  14. var (
  15. objValue = reflect.Indirect(reflect.ValueOf(obj))
  16. old = reflect.New(objValue.Type()).Interface()
  17. )
  18. if slugger, ok := obj.(presets.SlugDecoder); ok {
  19. cs := slugger.PrimaryColumnValuesBySlug(slug)
  20. for key, value := range cs {
  21. db = db.Where(fmt.Sprintf("%s = ?", key), value)
  22. }
  23. } else {
  24. db = db.Where("id = ?", slug)
  25. }
  26. if db.First(old).Error != nil {
  27. return nil, false
  28. }
  29. return old, true
  30. }
  31. func findOld(obj interface{}, db *gorm.DB) (interface{}, bool) {
  32. var (
  33. objValue = reflect.Indirect(reflect.ValueOf(obj))
  34. old = reflect.New(objValue.Type()).Interface()
  35. sqls []string
  36. vars []interface{}
  37. )
  38. stmt := &gorm.Statement{DB: db}
  39. if err := stmt.Parse(obj); err != nil {
  40. return nil, false
  41. }
  42. for _, dbName := range stmt.Schema.DBNames {
  43. if field := stmt.Schema.LookUpField(dbName); field != nil && field.PrimaryKey {
  44. if value, isZero := field.ValueOf(db.Statement.Context, objValue); !isZero {
  45. sqls = append(sqls, fmt.Sprintf("%v = ?", dbName))
  46. vars = append(vars, value)
  47. }
  48. }
  49. }
  50. if len(sqls) == 0 || len(vars) == 0 || len(sqls) != len(vars) {
  51. return nil, false
  52. }
  53. if db.Where(strings.Join(sqls, " AND "), vars...).First(old).Error != nil {
  54. return nil, false
  55. }
  56. return old, true
  57. }
  58. // getPrimaryKey get primary keys from a model
  59. func getPrimaryKey(t reflect.Type) (keys []string) {
  60. if t.Kind() != reflect.Struct {
  61. return
  62. }
  63. for i := 0; i < t.NumField(); i++ {
  64. if strings.Contains(t.Field(i).Tag.Get("gorm"), "primary") {
  65. keys = append(keys, t.Field(i).Name)
  66. continue
  67. }
  68. if t.Field(i).Type.Kind() == reflect.Ptr && t.Field(i).Anonymous {
  69. keys = append(keys, getPrimaryKey(t.Field(i).Type.Elem())...)
  70. }
  71. if t.Field(i).Type.Kind() == reflect.Struct && t.Field(i).Anonymous {
  72. keys = append(keys, getPrimaryKey(t.Field(i).Type)...)
  73. }
  74. }
  75. return
  76. }
  77. func ContextWithCreator(ctx context.Context, name string) context.Context {
  78. return context.WithValue(ctx, CreatorContextKey, name)
  79. }
  80. func ContextWithDB(ctx context.Context, db *gorm.DB) context.Context {
  81. return context.WithValue(ctx, DBContextKey, db)
  82. }
  83. func getBasicModel(m interface{}) interface{} {
  84. if preset, ok := m.(*presets.ModelBuilder); ok {
  85. return preset.NewModel()
  86. }
  87. return m
  88. }