summaryrefslogtreecommitdiff
path: root/src/main/kotlin/io/dico/parcels2/storage/ExposedExtensions.kt
blob: bb6133d39261dbbf4b1d63bb975ba54cd34d855a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
package io.dico.parcels2.storage

import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.statements.InsertStatement
import org.jetbrains.exposed.sql.statements.UpdateStatement
import org.jetbrains.exposed.sql.transactions.TransactionManager

/**
 * Inserts a row into this table and, on a duplicate key, overwrites [onDuplicateUpdateKeys] with the values
 * from the attempted insert (MySQL-style `INSERT ... ON DUPLICATE KEY UPDATE`).
 *
 * Adapted from https://github.com/JetBrains/Exposed/issues/167#issuecomment-403837917
 *
 * The statement is executed against `TransactionManager.current()`, so this is meant to be called from inside
 * an Exposed transaction.
 *
 * @param onDuplicateUpdateKeys columns to overwrite when the insert collides with an existing key.
 * @param body fills in the statement's column values; invoked with this table as receiver.
 * @return the statement after it has been executed.
 */
inline fun <T : Table> T.insertOrUpdate(vararg onDuplicateUpdateKeys: Column<*>, body: T.(InsertStatement<Number>) -> Unit): InsertOrUpdate<Number> {
    val statement = InsertOrUpdate<Number>(onDuplicateUpdateKeys, this)
    body(this, statement)
    statement.execute(TransactionManager.current())
    return statement
}

/**
 * An [InsertStatement] that appends MySQL's `ON DUPLICATE KEY UPDATE` clause.
 *
 * Each column in [onDuplicateUpdateKeys] is set to `VALUES(column)`, i.e. the value the insert tried to write.
 * With no keys the statement is a plain insert.
 *
 * @param onDuplicateUpdateKeys columns to overwrite on a duplicate key.
 * @param table target table.
 * @param isIgnore forwarded to [InsertStatement].
 */
class InsertOrUpdate<Key : Any>(
    private val onDuplicateUpdateKeys: Array<out Column<*>>,
    table: Table,
    isIgnore: Boolean = false
) : InsertStatement<Key>(table, isIgnore) {
    override fun prepareSQL(transaction: Transaction): String {
        val insertSql = super.prepareSQL(transaction)
        if (onDuplicateUpdateKeys.isEmpty()) return insertSql

        val assignments = onDuplicateUpdateKeys.joinToString(separator = ", ") { column ->
            val quoted = transaction.identity(column)
            "$quoted=VALUES($quoted)"
        }
        return "$insertSql ON DUPLICATE KEY UPDATE $assignments"
    }
}


/**
 * An [InsertStatement] that becomes an "upsert" via the `ON CONFLICT (...) DO UPDATE SET ...` clause
 * (PostgreSQL / SQLite syntax).
 *
 * The conflict target is the columns of [conflictIndex] if given, otherwise the single [conflictColumn].
 * On conflict, every inserted column that is not part of the conflict target is overwritten with the value
 * from the rejected insert (`EXCLUDED.column`). If no such column exists, the clause is `DO NOTHING`.
 *
 * @param table target table.
 * @param conflictColumn single-column conflict target; ignored when [conflictIndex] is given.
 * @param conflictIndex (unique) index whose columns form the conflict target.
 * @throws IllegalArgumentException if neither [conflictColumn] nor [conflictIndex] is given.
 */
class UpsertStatement<Key : Any>(table: Table, conflictColumn: Column<*>? = null, conflictIndex: Index? = null)
    : InsertStatement<Key>(table, false) {
    /** Name of the conflict target as supplied: the index name, or the column name. Not used in the SQL itself. */
    val indexName: String
    /** Columns forming the conflict target; these are rendered into `ON CONFLICT (...)`. */
    val indexColumns: List<Column<*>>

    init {
        if (conflictIndex != null) {
            indexName = conflictIndex.indexName
            indexColumns = conflictIndex.columns
        } else if (conflictColumn != null) {
            indexName = conflictColumn.name
            indexColumns = listOf(conflictColumn)
        } else {
            throw IllegalArgumentException("UpsertStatement requires either a conflictColumn or a conflictIndex")
        }
    }

    override fun prepareSQL(transaction: Transaction): String {
        val insertSQL = super.prepareSQL(transaction)

        // ON CONFLICT expects a column list (not an index name); quote each column like the INSERT does.
        val conflictTarget = indexColumns.joinToString { transaction.identity(it) }

        // `arguments` holds the (column, value) pairs bound by the INSERT; only the first row is relevant.
        val insertedColumns = arguments?.firstOrNull().orEmpty().map { it.first }

        // Conflict-target columns already match the existing row, so only the rest need updating.
        val updatedColumns = insertedColumns.filter { it !in indexColumns }

        return buildString {
            append(insertSQL)
            append(" ON CONFLICT (").append(conflictTarget).append(')')
            if (updatedColumns.isEmpty()) {
                append(" DO NOTHING")
            } else {
                append(" DO UPDATE SET ")
                // EXCLUDED refers to the row proposed for insertion, so no extra parameters need binding.
                append(updatedColumns.joinToString { col ->
                    val quoted = transaction.identity(col)
                    "$quoted=EXCLUDED.$quoted"
                })
            }
        }
    }
}

/**
 * Builds an [UpsertStatement] on this table, lets [body] fill in the values, then executes it.
 *
 * The statement is executed against `TransactionManager.current()`, so this is meant to be called from inside
 * an Exposed transaction.
 *
 * @param conflictColumn single-column conflict target.
 * @param conflictIndex index whose columns form the conflict target; takes precedence over [conflictColumn].
 * @param body fills in the statement's column values; invoked with this table as receiver.
 * @return the statement after it has been executed.
 */
inline fun <T : Table> T.upsert(conflictColumn: Column<*>? = null, conflictIndex: Index? = null, body: T.(UpsertStatement<Number>) -> Unit): UpsertStatement<Number> {
    val statement = UpsertStatement<Number>(this, conflictColumn, conflictIndex)
    body(this, statement)
    statement.execute(TransactionManager.current())
    return statement
}

/**
 * Registers an index over [columns] on this table and returns it, so the result can be kept (e.g. as an upsert
 * conflict target).
 *
 * @param customIndexName explicit index name, or null to leave naming to [Index].
 * @param isUnique whether the index is unique.
 * @param columns indexed columns, in order.
 * @return the [Index] that was added to [Table.indices].
 */
fun Table.indexR(customIndexName: String? = null, isUnique: Boolean = false, vararg columns: Column<*>): Index =
    Index(columns.toList(), isUnique, customIndexName).also { created -> indices.add(created) }

/**
 * Registers a unique index over [columns] on this table and returns it; shorthand for [indexR] with `isUnique = true`.
 */
fun Table.uniqueIndexR(customIndexName: String? = null, vararg columns: Column<*>): Index {
    return indexR(customIndexName = customIndexName, isUnique = true, columns = *columns)
}