180 lines
6.7 KiB
Kotlin
180 lines
6.7 KiB
Kotlin
|
package backup
|
||
|
|
||
|
import aws.sdk.kotlin.services.s3.*
|
||
|
import aws.sdk.kotlin.services.s3.model.CompletedMultipartUpload
|
||
|
import aws.sdk.kotlin.services.s3.model.CompletedPart
|
||
|
import aws.sdk.kotlin.services.s3.model.GetObjectRequest
|
||
|
import aws.sdk.kotlin.services.s3.model.UploadPartResponse
|
||
|
import aws.smithy.kotlin.runtime.content.ByteStream
|
||
|
import aws.smithy.kotlin.runtime.content.toInputStream
|
||
|
import kotlinx.coroutines.*
|
||
|
import java.io.*
|
||
|
import java.nio.file.Files
|
||
|
import java.nio.file.Path
|
||
|
import java.nio.file.attribute.BasicFileAttributeView
|
||
|
import java.nio.file.attribute.BasicFileAttributes
|
||
|
import java.nio.file.attribute.FileAttribute
|
||
|
import java.time.Instant
|
||
|
import java.util.zip.ZipEntry
|
||
|
import java.util.zip.ZipInputStream
|
||
|
import java.util.zip.ZipOutputStream
|
||
|
import kotlin.io.path.createDirectory
|
||
|
|
||
|
/**
 * Backs up a file or directory tree to S3 as a zip archive and restores such archives.
 *
 * @param s3 client used for every S3 request.
 * @param bucketName bucket the archives are written to and read from.
 * @param bufSize size in bytes of the in-memory chunk buffer. An archive that fills the
 * buffer on the first read is sent as a multipart upload with parts of this size; a
 * smaller archive is sent with a single `putObject` request.
 * NOTE(review): S3 enforces a minimum size on non-final multipart parts - confirm that
 * callers never pass a [bufSize] below that limit.
 */
class BackupClient(
    private val s3: S3Client,
    private val bucketName: String,
    private val bufSize: Int = 1024 * 1024 * 100
) {
    /**
     * Zips [file] (recursively when it is a directory) and streams the archive to S3.
     *
     * Compression runs in a child coroutine on [Dispatchers.IO] and feeds a pipe that this
     * function drains chunk by chunk, so the archive is never held in memory as a whole.
     *
     * @return the object key the archive was stored under (`<file name>/<timestamp>.zip`).
     */
    suspend fun upload(file: File) = coroutineScope {
        val backupKey = "${file.name}/${Instant.now()}.zip"
        // The JDK default pipe buffer is only 1 KiB; a larger one means far fewer
        // hand-offs between the zipping thread and the reader.
        PipedInputStream(PIPE_BUFFER_BYTES).use { inputStream ->
            val outputStream = PipedOutputStream(inputStream)
            val zipper = launch(Dispatchers.IO) {
                file.compressToZip(outputStream)
            }

            val data = ByteArray(bufSize)
            val initialRead = inputStream.readNBytes(data, 0, bufSize)
            if (initialRead == bufSize) {
                // Large upload, use multipart
                // TODO: multipart uploads can be asynchronous, which would improve
                // performance a little bit for big uploads.
                val upload = s3.createMultipartUpload {
                    bucket = bucketName
                    key = backupKey
                }
                try {
                    val uploadParts = mutableListOf<CompletedPart>()
                    var number = 1
                    var bytesRead = initialRead
                    // readNBytes returns 0 once the zipper has closed its end of the pipe.
                    while (bytesRead > 0) {
                        val part = s3.uploadPart {
                            bucket = bucketName
                            key = backupKey
                            partNumber = number
                            uploadId = upload.uploadId
                            body = ByteStream.fromBytes(data.take(bytesRead))
                        }.asCompletedPart(number)
                        uploadParts.add(part)
                        number++
                        bytesRead = inputStream.readNBytes(data, 0, bufSize)
                    }
                    s3.completeMultipartUpload {
                        bucket = bucketName
                        key = backupKey
                        uploadId = upload.uploadId
                        multipartUpload = CompletedMultipartUpload {
                            parts = uploadParts
                        }
                    }
                } catch (e: Exception) {
                    // The failure may be a cancellation, in which case a plain suspending
                    // call would be cancelled immediately and the abort would never reach
                    // S3. Run the cleanup non-cancellably so the upload is not orphaned.
                    try {
                        withContext(NonCancellable) {
                            s3.abortMultipartUpload {
                                bucket = bucketName
                                key = backupKey
                                uploadId = upload.uploadId
                            }
                        }
                    } catch (abortFailure: Exception) {
                        // Keep the original failure as the primary error.
                        e.addSuppressed(abortFailure)
                    }
                    throw e
                }
            } else {
                // Small upload, use single request
                s3.putObject {
                    bucket = bucketName
                    key = backupKey
                    body = ByteStream.fromBytes(data.take(initialRead))
                }
            }
            zipper.join() // Should be instant
        }
        backupKey
    }

    /**
     * Downloads the archive stored under [backupKey] and extracts it into [destination].
     *
     * @param destination directory the archive entries are extracted into.
     * @param backupKey object key of the archive, as returned by [upload].
     * @throws IOException if the S3 response carries no body or extraction fails.
     */
    suspend fun restore(destination: File, backupKey: String) = coroutineScope {
        val req = GetObjectRequest {
            bucket = bucketName
            key = backupKey
        }
        s3.getObject(req) { resp ->
            val body = resp.body ?: throw IOException("S3 response is missing body")
            // Extraction is blocking stream I/O, so keep it off the caller's dispatcher.
            withContext(Dispatchers.IO) {
                ZipInputStream(body.toInputStream()).use { zipStream ->
                    zipStream.decompressToDirectory(destination.toPath())
                }
            }
        }
    }

    private companion object {
        /** Capacity of the pipe between the zipping coroutine and the uploader. */
        const val PIPE_BUFFER_BYTES = 64 * 1024
    }
}
|
||
|
|
||
|
/**
 * Builds the [CompletedPart] descriptor for this part-upload response.
 *
 * @param number the part number this response belongs to.
 * @return a [CompletedPart] carrying [number] plus the ETag and checksum fields copied
 * from this response.
 */
private fun UploadPartResponse.asCompletedPart(number: Int): CompletedPart = CompletedPart {
    // Inside the builder lambda `this` is the builder, so the response is
    // addressed through the labelled receiver.
    partNumber = number
    eTag = this@asCompletedPart.eTag
    checksumCrc32 = this@asCompletedPart.checksumCrc32
    checksumCrc32C = this@asCompletedPart.checksumCrc32C
    checksumSha1 = this@asCompletedPart.checksumSha1
    checksumSha256 = this@asCompletedPart.checksumSha256
}
|
||
|
|
||
|
/**
 * Returns the first [n] bytes of this array.
 *
 * When [n] equals the array size the receiver itself is returned (no copy), so the
 * caller must not modify the array while the result is still in use. Otherwise a
 * single primitive copy of the prefix is made.
 *
 * @throws IndexOutOfBoundsException if [n] is greater than the array size.
 * @throws IllegalArgumentException if [n] is negative.
 */
private fun ByteArray.take(n: Int): ByteArray =
    if (n == size) this // No copy
    else copyOfRange(0, n) // One primitive array copy, without boxing each byte
|
||
|
|
||
|
/**
 * Writes this file, or this directory and everything below it, to [outputStream] as a
 * zip archive. The stream is closed when the function returns or throws.
 *
 * Entry names are relative to the parent directory of the receiver, so the receiver's
 * own name is the first path segment of every entry. Names always use `/` as the
 * separator, and directory entries end with `/`. The tree is walked breadth-first, so
 * a directory entry is always written before the entries of its children.
 *
 * @param outputStream destination of the archive bytes.
 */
private fun File.compressToZip(outputStream: OutputStream) = ZipOutputStream(outputStream).use { zipStream ->
    // Prefix stripped from every absolute path to obtain the entry name. Use the platform
    // separator, and do not double it when the parent already ends with one (e.g. "/").
    val parentPrefix = absoluteFile.parent
        ?.let { parent -> if (parent.endsWith(File.separator)) parent else parent + File.separator }
        ?: ""
    val pending = ArrayDeque<File>()
    pending.add(this)
    // Drain the queue explicitly: appending to a collection while iterating over it
    // with forEach is a concurrent modification.
    while (pending.isNotEmpty()) {
        val current = pending.removeFirst()
        val path = current.absolutePath
            .removePrefix(parentPrefix)
            .replace(File.separatorChar, '/') // zip entry names are '/'-separated
        val children = current.listFiles()
        if (children != null) { // Is a directory
            val entry = ZipEntry("$path/")
            setZipAttributes(entry, current.toPath())
            zipStream.putNextEntry(entry)
            pending.addAll(children)
        } else { // Otherwise, treat it as a file
            BufferedInputStream(current.inputStream()).use { origin ->
                val entry = ZipEntry(path)
                setZipAttributes(entry, current.toPath())
                zipStream.putNextEntry(entry)
                origin.copyTo(zipStream)
            }
        }
    }
}
|
||
|
|
||
|
/**
 * Extracts every remaining entry of this zip stream below [directory].
 *
 * Directories, including missing parents of file entries, are created as needed, and
 * directories that already exist are reused. After an entry has been written, the
 * timestamps recorded in the archive are applied to it on a best-effort basis.
 *
 * @param directory root directory that receives the extracted entries.
 * @param bufSize size in bytes of the copy buffer.
 * @throws IOException if an entry name would resolve outside [directory], or if
 * reading the archive or writing a file fails.
 */
private fun ZipInputStream.decompressToDirectory(directory: Path, bufSize: Int = 1024 * 1024) {
    val root = directory.toAbsolutePath().normalize()
    // One buffer for the whole archive instead of a fresh allocation per file.
    val buf = ByteArray(bufSize)
    var entry = this.nextEntry
    while (entry != null) {
        val path = root.resolve(entry.name).normalize()
        // Archive content is untrusted: refuse names such as "../x" or absolute paths
        // that would otherwise be written outside the target directory.
        if (!path.startsWith(root)) {
            throw IOException("Zip entry is outside of the target directory: ${entry.name}")
        }
        if (entry.isDirectory) {
            Files.createDirectories(path)
        } else {
            // Archives are not required to contain an entry for every parent directory.
            path.parent?.let { Files.createDirectories(it) }
            path.toFile().outputStream().use { fileStream ->
                var bytesRead = this.read(buf)
                while (bytesRead > 0) {
                    fileStream.write(buf, 0, bytesRead)
                    bytesRead = this.read(buf)
                }
            }
        }
        applyZipAttributes(entry, path)
        entry = this.nextEntry
    }
}
|
||
|
|
||
|
/**
 * Copies the creation, last-modified and last-access times of the file at [path]
 * onto [entry].
 *
 * Best-effort: if the attributes cannot be read because of an [IOException], the
 * entry is left untouched and no error is reported.
 */
private fun setZipAttributes(entry: ZipEntry, path: Path) {
    val fileTimes = try {
        Files.getFileAttributeView(path, BasicFileAttributeView::class.java).readAttributes()
    } catch (_: IOException) {
        // Timestamps are optional metadata; the entry is still usable without them.
        return
    }
    with(entry) {
        setCreationTime(fileTimes.creationTime())
        setLastModifiedTime(fileTimes.lastModifiedTime())
        setLastAccessTime(fileTimes.lastAccessTime())
    }
}
|
||
|
|
||
|
/**
 * Applies the last-modified, last-access and creation times stored in [entry] to the
 * file at [path].
 *
 * Best-effort: an [IOException] raised while updating the timestamps is ignored.
 */
private fun applyZipAttributes(entry: ZipEntry, path: Path) {
    val modified = entry.lastModifiedTime
    val accessed = entry.lastAccessTime
    val created = entry.creationTime
    try {
        Files.getFileAttributeView(path, BasicFileAttributeView::class.java)
            .setTimes(modified, accessed, created)
    } catch (_: IOException) {
        // Timestamps are optional metadata; the extracted content is already in place.
    }
}
|