importer.go 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
  1. package exchange
  2. import (
  3. "fmt"
  4. "net/url"
  5. "reflect"
  6. "strings"
  7. "github.com/google/go-cmp/cmp"
  8. "github.com/mohae/deepcopy"
  9. "gorm.io/gorm"
  10. "gorm.io/gorm/clause"
  11. )
  12. type Importer struct {
  13. resource interface{}
  14. resourceParams int
  15. rtResource reflect.Type
  16. metas []*Meta
  17. pkMetas []*Meta
  18. associations []string
  19. associationsParams []int
  20. validators []func(metaValues MetaValues) error
  21. }
  22. func NewImporter(resource interface{}) *Importer {
  23. return &Importer{
  24. resource: resource,
  25. }
  26. }
  27. func (ip *Importer) Metas(ms ...*Meta) *Importer {
  28. ip.metas = ms
  29. return ip
  30. }
  31. func (ip *Importer) Associations(ts ...string) *Importer {
  32. ip.associations = ts
  33. return ip
  34. }
  35. func (ip *Importer) Validators(vs ...func(metaValues MetaValues) error) *Importer {
  36. ip.validators = vs
  37. return ip
  38. }
  39. func (ip *Importer) Exec(db *gorm.DB, r Reader, opts ...ImporterExecOption) error {
  40. err := ip.validateAndInit()
  41. if err != nil {
  42. return err
  43. }
  44. maxParamsPerSQL := ip.parseOptions(opts...)
  45. fullPrimaryKeyValues := make([][]string, 0, r.Total())
  46. allMetaValues := make([]url.Values, 0, r.Total())
  47. {
  48. headerIdxMetas := make(map[int]*Meta)
  49. header := r.Header()
  50. for i, _ := range ip.metas {
  51. m := ip.metas[i]
  52. hasCol := false
  53. for hi, h := range header {
  54. if h == m.columnHeader {
  55. hasCol = true
  56. headerIdxMetas[hi] = m
  57. break
  58. }
  59. }
  60. if !hasCol {
  61. return fmt.Errorf("column %s not found", m.columnHeader)
  62. }
  63. }
  64. for r.Next() {
  65. metaValues := make(url.Values)
  66. row, err := r.ReadRow()
  67. if err != nil {
  68. return err
  69. }
  70. notEmptyPrimaryKeyValues := make([]string, 0, len(ip.pkMetas))
  71. for i, v := range row {
  72. m, ok := headerIdxMetas[i]
  73. if !ok {
  74. continue
  75. }
  76. metaValues.Set(m.field, v)
  77. if m.primaryKey && v != "" && m.setter == nil {
  78. notEmptyPrimaryKeyValues = append(notEmptyPrimaryKeyValues, v)
  79. }
  80. }
  81. if len(ip.pkMetas) > 0 && len(notEmptyPrimaryKeyValues) == len(ip.pkMetas) {
  82. fullPrimaryKeyValues = append(fullPrimaryKeyValues, notEmptyPrimaryKeyValues)
  83. }
  84. for _, vd := range ip.validators {
  85. err = vd(metaValues)
  86. if err != nil {
  87. return err
  88. }
  89. }
  90. allMetaValues = append(allMetaValues, metaValues)
  91. }
  92. }
  93. // primarykeys:record
  94. oldRecordsMap := make(map[string]interface{})
  95. if len(fullPrimaryKeyValues) > 0 {
  96. oldRecords := reflect.New(reflect.SliceOf(ip.rtResource)).Elem()
  97. tx := preloadDB(db, ip.associations)
  98. searchInKeys := false
  99. {
  100. var total int64
  101. err = db.Model(ip.resource).Count(&total).Error
  102. if err != nil {
  103. return err
  104. }
  105. if int64(len(fullPrimaryKeyValues))*100 < total {
  106. searchInKeys = true
  107. }
  108. }
  109. if searchInKeys {
  110. pkvsGroups := splitStringSliceSlice(fullPrimaryKeyValues, maxParamsPerSQL/len(ip.pkMetas))
  111. for _, g := range pkvsGroups {
  112. if len(ip.pkMetas) == 1 {
  113. vs := make([]string, 0, len(g))
  114. for _, pkvs := range g {
  115. vs = append(vs, pkvs[0])
  116. }
  117. tx = tx.Where(fmt.Sprintf("%s in (?)", ip.pkMetas[0].snakeField), vs)
  118. } else {
  119. var pks []string
  120. for _, m := range ip.pkMetas {
  121. pks = append(pks, m.snakeField)
  122. }
  123. // only test this on Postgres, not sure if this is valid for other databases
  124. tx = tx.Where(fmt.Sprintf("(%s) in (?)", strings.Join(pks, ",")), g)
  125. }
  126. chunkRecords := reflect.New(reflect.SliceOf(ip.rtResource)).Interface()
  127. err = tx.Find(chunkRecords).Error
  128. oldRecords = reflect.AppendSlice(oldRecords, reflect.ValueOf(chunkRecords).Elem())
  129. }
  130. } else {
  131. chunkRecords := reflect.New(reflect.SliceOf(ip.rtResource)).Interface()
  132. err = tx.FindInBatches(chunkRecords, maxParamsPerSQL/len(ip.pkMetas), func(tx *gorm.DB, batch int) error {
  133. oldRecords = reflect.AppendSlice(oldRecords, reflect.ValueOf(chunkRecords).Elem())
  134. return nil
  135. }).Error
  136. }
  137. if err != nil {
  138. return err
  139. }
  140. for i := 0; i < oldRecords.Len(); i++ {
  141. record := oldRecords.Index(i)
  142. pkvs := make([]string, 0, len(ip.pkMetas))
  143. for _, m := range ip.pkMetas {
  144. pkvs = append(pkvs, fmt.Sprintf("%v", record.Elem().FieldByName(m.field).Interface()))
  145. }
  146. oldRecordsMap[strings.Join(pkvs, "$")] = record.Interface()
  147. }
  148. }
  149. records := reflect.New(reflect.SliceOf(ip.rtResource)).Elem()
  150. // key is association
  151. recordsToClearAssociations := make(map[string]reflect.Value)
  152. // key is association
  153. recordsToReplaceAssociations := make(map[string]reflect.Value)
  154. for _, a := range ip.associations {
  155. recordsToClearAssociations[a] = reflect.New(reflect.SliceOf(ip.rtResource)).Elem()
  156. recordsToReplaceAssociations[a] = reflect.New(reflect.SliceOf(ip.rtResource)).Elem()
  157. }
  158. // key is association
  159. toReplaceAssociations := make(map[string][]interface{})
  160. maxAssociationsRecordsLen := make(map[string]int)
  161. for _, metaValues := range allMetaValues {
  162. var record reflect.Value
  163. var oldRecord reflect.Value
  164. {
  165. pkvs := make([]string, 0, len(ip.pkMetas))
  166. for _, m := range ip.pkMetas {
  167. pkvs = append(pkvs, metaValues.Get(m.field))
  168. }
  169. cpkvs := strings.Join(pkvs, "$")
  170. if v, ok := oldRecordsMap[cpkvs]; cpkvs != "" && ok {
  171. record = reflect.ValueOf(v)
  172. oldRecord = reflect.ValueOf(deepcopy.Copy(v))
  173. } else {
  174. record = reflect.New(ip.rtResource.Elem())
  175. }
  176. }
  177. for _, m := range ip.metas {
  178. if m.setter != nil {
  179. err = m.setter(record.Interface(), metaValues.Get(m.field), metaValues)
  180. if err != nil {
  181. return err
  182. }
  183. continue
  184. }
  185. fv := record.Elem().FieldByName(m.field)
  186. err = setValueFromString(fv, metaValues.Get(m.field))
  187. if err != nil {
  188. return err
  189. }
  190. }
  191. if oldRecord.IsValid() {
  192. for _, a := range ip.associations {
  193. newV := record.Elem().FieldByName(a)
  194. oldV := oldRecord.Elem().FieldByName(a)
  195. if !cmp.Equal(newV.Interface(), oldV.Interface()) {
  196. if newV.IsZero() {
  197. recordsToClearAssociations[a] = reflect.Append(recordsToClearAssociations[a], record)
  198. } else {
  199. ft, _ := ip.rtResource.Elem().FieldByName(a)
  200. if !strings.Contains(ft.Tag.Get("gorm"), "many2many") {
  201. ip.clearPrimaryKeyValueForAssociation(newV)
  202. }
  203. recordsToReplaceAssociations[a] = reflect.Append(recordsToReplaceAssociations[a], record)
  204. iNewV := newV.Interface()
  205. if newV.Kind() == reflect.Struct {
  206. v := reflect.New(newV.Type()).Elem()
  207. v.Set(reflect.ValueOf(deepcopy.Copy(iNewV)))
  208. iNewV = v.Addr().Interface()
  209. }
  210. toReplaceAssociations[a] = append(toReplaceAssociations[a], iNewV)
  211. }
  212. if err != nil {
  213. return err
  214. }
  215. }
  216. oldV.Set(reflect.New(oldV.Type()).Elem())
  217. newV.Set(reflect.New(newV.Type()).Elem())
  218. }
  219. if cmp.Equal(record.Interface(), oldRecord.Interface()) {
  220. continue
  221. }
  222. }
  223. for _, a := range ip.associations {
  224. afv := record.Elem().FieldByName(a)
  225. if afv.Type().Kind() == reflect.Slice {
  226. maxAssociationsRecordsLen[a] = afv.Len()
  227. }
  228. }
  229. records = reflect.Append(records, record)
  230. }
  231. batchSize := 10000
  232. {
  233. max := maxParamsPerSQL / ip.resourceParams
  234. for i, a := range ip.associations {
  235. l, ok := maxAssociationsRecordsLen[a]
  236. if ok {
  237. if l > 0 {
  238. if v := maxParamsPerSQL / (l * ip.associationsParams[i]); max > v {
  239. max = v
  240. }
  241. }
  242. } else {
  243. if v := maxParamsPerSQL / ip.associationsParams[i]; max > v {
  244. max = v
  245. }
  246. }
  247. }
  248. if batchSize > max {
  249. batchSize = max
  250. }
  251. }
  252. return db.Transaction(func(tx *gorm.DB) error {
  253. for _, a := range ip.associations {
  254. if recordsToClearAssociations[a].Len() > 0 {
  255. rgs := splitReflectSliceValue(recordsToClearAssociations[a], maxParamsPerSQL/len(ip.pkMetas))
  256. for _, g := range rgs {
  257. err = db.Model(g.Interface()).Association(a).Clear()
  258. if err != nil {
  259. return err
  260. }
  261. }
  262. }
  263. if recordsToReplaceAssociations[a].Len() > 0 {
  264. // TODO: limit batch size
  265. // TODO: it seems not updated in batch from the gorm log
  266. err = db.Model(recordsToReplaceAssociations[a].Interface()).Association(a).Replace(toReplaceAssociations[a]...)
  267. if err != nil {
  268. return err
  269. }
  270. }
  271. }
  272. if records.Len() == 0 {
  273. return nil
  274. }
  275. ocPrimaryCols := make([]clause.Column, 0, len(ip.metas))
  276. for _, m := range ip.metas {
  277. if m.primaryKey {
  278. ocPrimaryCols = append(ocPrimaryCols, clause.Column{
  279. Name: m.snakeField,
  280. })
  281. }
  282. }
  283. // .Session(&gorm.Session{FullSaveAssociations: true}) cannot auto delete associations and not work with many-to-many
  284. return db.Clauses(clause.OnConflict{
  285. Columns: ocPrimaryCols,
  286. UpdateAll: true,
  287. }).Model(ip.resource).CreateInBatches(records.Interface(), batchSize).Error
  288. })
  289. }
  290. func (ip *Importer) clearPrimaryKeyValueForAssociation(v reflect.Value) {
  291. rv := getIndirect(v)
  292. rt := rv.Type()
  293. switch rt.Kind() {
  294. case reflect.Struct:
  295. clearPrimaryKeyValue(rv)
  296. case reflect.Slice:
  297. for i := 0; i < rv.Len(); i++ {
  298. clearPrimaryKeyValue(getIndirect(rv.Index(i)))
  299. }
  300. }
  301. }
  302. func (ip *Importer) validateAndInit() error {
  303. if err := validateResourceAndMetas(ip.resource, ip.metas); err != nil {
  304. return err
  305. }
  306. ip.pkMetas = []*Meta{}
  307. for i, _ := range ip.metas {
  308. m := ip.metas[i]
  309. if m.primaryKey {
  310. ip.pkMetas = append(ip.pkMetas, m)
  311. }
  312. }
  313. ip.rtResource = reflect.TypeOf(ip.resource)
  314. getParamsNumbers(&ip.resourceParams, ip.rtResource.Elem(), ip.associations)
  315. for _, a := range ip.associations {
  316. n := 0
  317. fv, _ := ip.rtResource.Elem().FieldByName(a)
  318. getParamsNumbers(&n, getIndirectStruct(fv.Type), nil)
  319. ip.associationsParams = append(ip.associationsParams, n)
  320. }
  321. return nil
  322. }
  323. func (ip *Importer) parseOptions(opts ...ImporterExecOption) (
  324. maxParamsPerSQL int,
  325. ) {
  326. maxParamsPerSQL = 65000
  327. for _, opt := range opts {
  328. switch v := opt.(type) {
  329. case *maxParamsPerSQLOption:
  330. maxParamsPerSQL = v.v
  331. }
  332. }
  333. return maxParamsPerSQL
  334. }