summaryrefslogtreecommitdiff
path: root/src/main/kotlin/io/dico/parcels2/storage/exposed/CoroutineTransactionManager.kt
blob: ab707afa6ce750edd29baa399719cc5c2eb9ac64 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
package io.dico.parcels2.storage.exposed

import kotlinx.coroutines.experimental.*
import org.jetbrains.exposed.sql.*
import org.jetbrains.exposed.sql.statements.StatementContext
import org.jetbrains.exposed.sql.statements.StatementInterceptor
import org.jetbrains.exposed.sql.statements.expandArgs
import org.jetbrains.exposed.sql.transactions.*
import org.slf4j.LoggerFactory
import java.sql.Connection
import kotlin.coroutines.experimental.CoroutineContext

/**
 * Convenience overload of [ctransaction] that uses the current
 * [TransactionManager]'s default isolation level and 3 repetition attempts.
 *
 * @param db forwarded unchanged to the four-argument overload; defaults to `null`
 * @param statement suspending block invoked with the transaction as its receiver
 * @return the value produced by [statement]
 */
fun <T> ctransaction(db: Database? = null, statement: suspend Transaction.() -> T): T =
    ctransaction(
        transactionIsolation = TransactionManager.manager.defaultIsolationLevel,
        repetitionAttempts = 3,
        db = db,
        statement = statement
    )

/**
 * Runs a suspending [statement] inside an Exposed `transaction` block.
 *
 * The transaction handed to the block must be a [CoroutineTransaction]. The
 * statement is started with `async` on that transaction's manager dispatcher
 * using [CoroutineStart.UNDISPATCHED]; if the resulting job is still active once
 * `async` returns, the calling thread blocks until the job has been joined.
 *
 * @param transactionIsolation isolation level passed through to `transaction`
 * @param repetitionAttempts repetition attempts passed through to `transaction`
 * @param db database passed through to `transaction`; defaults to `null`
 * @param statement suspending block invoked with the transaction as its receiver
 * @return the completed value of the job that ran [statement]
 * @throws IllegalStateException if the transaction is not a [CoroutineTransaction]
 */
fun <T> ctransaction(transactionIsolation: Int, repetitionAttempts: Int, db: Database? = null, statement: suspend Transaction.() -> T): T =
    transaction(transactionIsolation, repetitionAttempts, db) {
        val coroutineTx = this as? CoroutineTransaction
            ?: throw IllegalStateException("ctransaction requires CoroutineTransactionManager.")

        // Start the statement right away; it only goes through the manager's dispatcher when it is re-dispatched.
        val deferred = async(context = coroutineTx.manager.context, start = CoroutineStart.UNDISPATCHED) {
            coroutineTx.statement()
        }

        // The job did not finish during the undispatched start, so wait for it here.
        if (deferred.isActive) {
            runBlocking(context = Unconfined) { deferred.join() }
        }

        deferred.getCompleted()
    }

/**
 * [TransactionManager] that tracks the current [CoroutineTransaction] in a
 * [ThreadLocal] and offers a dispatcher ([context]) which hands that transaction
 * over to whatever thread the wrapped dispatcher runs a block on.
 *
 * @param db database passed to every [CoroutineTransactionInterface] created by [newTransaction]
 * @param dispatcher dispatcher that actually schedules blocks; wrapped by [context]
 * @property defaultIsolationLevel isolation level exposed as this manager's default
 */
class CoroutineTransactionManager(private val db: Database,
                                  dispatcher: CoroutineDispatcher,
                                  override var defaultIsolationLevel: Int = DEFAULT_ISOLATION_LEVEL) : TransactionManager {
    /** Dispatcher wrapping the constructor's `dispatcher`; see [TransactionCoroutineDispatcher]. */
    val context: CoroutineDispatcher = TransactionCoroutineDispatcher(dispatcher)

    /** Transaction bound to the calling thread, or `null` when none is set. */
    private val transaction = ThreadLocal<CoroutineTransaction?>()

    /** Returns the transaction stored for the calling thread, or `null`. */
    override fun currentOrNull(): Transaction? = transaction.get()

    /**
     * Creates a [CoroutineTransaction] with the given [isolation] level and stores it
     * as the calling thread's current transaction.
     */
    override fun newTransaction(isolation: Int): Transaction {
        val created = CoroutineTransaction(this, CoroutineTransactionInterface(db, isolation, transaction))
        transaction.set(created)
        return created
    }

    /**
     * Dispatcher that forwards to [delegate] and carries the dispatching thread's
     * transaction along with the block.
     */
    private inner class TransactionCoroutineDispatcher(val delegate: CoroutineDispatcher) : CoroutineDispatcher() {

        // When the thread changes, move the transaction to the new thread
        override fun dispatch(context: CoroutineContext, block: Runnable) {
            // Detach the transaction from the dispatching thread, if it has one.
            val carried = transaction.get()
            if (carried != null) {
                transaction.set(null)
            }

            // The coroutine context is forwarded unchanged; the transaction is not added to it.
            delegate.dispatch(context, Runnable {
                // Re-attach on the thread chosen by the delegate before running the block.
                // NOTE(review): the thread-local is not cleared again after block.run() — confirm this is intended.
                if (carried != null) {
                    transaction.set(carried)
                }

                block.run()
            })
        }

    }

}

/**
 * [Transaction] that holds a reference to its [CoroutineTransactionManager] and
 * is also a [CoroutineContext.Element], keyed by its companion object [Key].
 */
private class CoroutineTransaction(
    val manager: CoroutineTransactionManager,
    itf: CoroutineTransactionInterface
) : Transaction(itf), CoroutineContext.Element {

    /** Always the companion [Key]. */
    override val key: CoroutineContext.Key<CoroutineTransaction>
        get() = Key

    companion object Key : CoroutineContext.Key<CoroutineTransaction>
}

/**
 * [TransactionInterface] whose JDBC connection is opened lazily on first access
 * and which restores the previously current transaction into [threadLocal] on close.
 *
 * @param db database whose `connector` supplies the connection
 * @param isolation transaction isolation level applied to the connection when it is opened
 * @param threadLocal holder of the current transaction; read once at construction for
 *   [outerTransaction] and written again in [close]
 */
private class CoroutineTransactionInterface(override val db: Database, isolation: Int, val threadLocal: ThreadLocal<CoroutineTransaction?>) : TransactionInterface {
    // Unsynchronized lazy holder; kept as a Lazy so isInitialized() can be queried.
    private val lazyConnection = lazy(LazyThreadSafetyMode.NONE) {
        val opened = db.connector()
        opened.autoCommit = false
        opened.transactionIsolation = isolation
        opened
    }

    override val connection: Connection
        get() = lazyConnection.value

    /** Transaction that was in [threadLocal] when this object was constructed. */
    override val outerTransaction: CoroutineTransaction? = threadLocal.get()

    /** Commits only if the connection has been opened. */
    override fun commit() {
        if (!lazyConnection.isInitialized()) return
        connection.commit()
    }

    /** Rolls back only if the connection has been opened and is not closed. */
    override fun rollback() {
        if (!lazyConnection.isInitialized()) return
        if (connection.isClosed) return
        connection.rollback()
    }

    /** Closes the connection if it was opened, then puts [outerTransaction] back into [threadLocal]. */
    override fun close() {
        try {
            if (lazyConnection.isInitialized()) {
                connection.close()
            }
        } finally {
            // Runs even when closing the connection throws.
            threadLocal.set(outerTransaction)
        }
    }

}