BayernMessenger/matrix-sdk-android/src/main/java/im/vector/matrix/android/internal/crypto/keysbackup/KeysBackupPassword.kt

152 lines
4.6 KiB
Kotlin

/*
* Copyright 2019 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
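/**
 * Utilities to derive a keys-backup private key from a password: either with a freshly generated
 * salt, or by re-deriving the same key from a known salt and iteration count.
 */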
package im.vector.matrix.android.internal.crypto.keysbackup
import androidx.annotation.WorkerThread
import im.vector.matrix.android.api.listeners.ProgressListener
import timber.log.Timber
import java.util.UUID
import javax.crypto.Mac
import javax.crypto.spec.SecretKeySpec
import kotlin.experimental.xor
/** Number of characters in a generated salt. */
private const val SALT_LENGTH = 32

/** Iteration count used when generating a new private key from a password. */
private const val DEFAULT_ITERATION = 500_000

/**
 * Result of [generatePrivateKeyWithPassword].
 *
 * `equals`/`hashCode` are overridden because the generated data-class versions would compare
 * [privateKey] by reference; here the key bytes are compared by content.
 */
data class GeneratePrivateKeyResult(
        // The private key
        val privateKey: ByteArray,
        // the salt used to generate the private key
        val salt: String,
        // number of key derivations done on the generated private key.
        val iterations: Int) {

    override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (other !is GeneratePrivateKeyResult) return false
        return privateKey.contentEquals(other.privateKey) &&
                salt == other.salt &&
                iterations == other.iterations
    }

    override fun hashCode(): Int {
        var result = privateKey.contentHashCode()
        result = 31 * result + salt.hashCode()
        result = 31 * result + iterations
        return result
    }
}
/**
 * Compute a private key from a password, using a freshly generated salt and the default
 * iteration count.
 *
 * @param password the password to use.
 * @param progressListener an optional listener to follow the derivation progress.
 *
 * @return a {privateKey, salt, iterations} tuple. Passing the same password, salt and iterations to
 * [retrievePrivateKeyWithPassword] re-derives the same private key.
 */
@WorkerThread
fun generatePrivateKeyWithPassword(password: String, progressListener: ProgressListener? = null): GeneratePrivateKeyResult {
    val salt = generateSalt()
    val iterations = DEFAULT_ITERATION
    val privateKey = deriveKey(password, salt, iterations, progressListener)
    return GeneratePrivateKeyResult(privateKey, salt, iterations)
}
/**
 * Re-derive a private key from a password, given the salt and iteration count that were used
 * when the key was generated.
 *
 * @param password the password that was used to generate the private key.
 * @param salt the salt that was used to generate the private key.
 * @param iterations the number of key derivations to perform.
 * @param progressListener optional listener notified as the derivation advances.
 *
 * @return the derived private key.
 */
@WorkerThread
fun retrievePrivateKeyWithPassword(
        password: String,
        salt: String,
        iterations: Int,
        progressListener: ProgressListener? = null
): ByteArray = deriveKey(
        password = password,
        salt = salt,
        iterations = iterations,
        progressListener = progressListener
)
/**
 * Compute a private key by deriving a password and a salt with PBKDF2-HMAC-SHA512.
 *
 * Only the first PBKDF2 block (block index 1) is computed. That block is truncated to 32 bytes,
 * because the requested key length (256 bits) is shorter than the HMAC-SHA512 output (512 bits).
 *
 * @param password the password, encoded as UTF-8 and used as the HMAC key.
 * @param salt the salt, encoded as UTF-8.
 * @param iterations number of derivations (PBKDF2 iteration count). A value below 2 yields U1 only.
 * @param progressListener a listener to follow progress. It receives values in 0..100 with a total
 * of 100, and always receives a final 100.
 *
 * @return a 32-byte private key.
 */
@WorkerThread
private fun deriveKey(password: String,
                      salt: String,
                      iterations: Int,
                      progressListener: ProgressListener?): ByteArray {
    // Note: copied and adapted from MXMegolmExportEncryption
    val t0 = System.currentTimeMillis()

    // based on https://en.wikipedia.org/wiki/PBKDF2 algorithm
    // It is simpler than the generic algorithm because a single block is enough:
    // dkLen = 256 bits <= hLen = 512 bits.
    val prf = Mac.getInstance("HmacSHA512")
    prf.init(SecretKeySpec(password.toByteArray(Charsets.UTF_8), "HmacSHA512"))

    // 256 bits key length
    val dk = ByteArray(32)
    val uc = ByteArray(64)

    // U1 = PRF(Password, Salt || INT_32_BE(1))
    prf.update(salt.toByteArray(Charsets.UTF_8))
    prf.update(byteArrayOf(0, 0, 0, 1))
    prf.doFinal(uc, 0)

    // copy to the key (truncated to dkLen)
    System.arraycopy(uc, 0, dk, 0, dk.size)

    var lastProgress = -1

    for (index in 2..iterations) {
        // Uc = PRF(Password, Uc-1); doFinal() resets the Mac, so it can be reused with the same key
        prf.update(uc)
        prf.doFinal(uc, 0)

        // F(Password, Salt, c, 1) = U1 ^ U2 ^ ... ^ Uc
        for (byteIndex in dk.indices) {
            dk[byteIndex] = dk[byteIndex] xor uc[byteIndex]
        }

        // Compute in Long so that large iteration counts cannot overflow Int.
        // Use `index` (not `index + 1`) so the reported value never exceeds 100.
        val progress = (index.toLong() * 100 / iterations).toInt()
        if (progress != lastProgress) {
            lastProgress = progress
            progressListener?.onProgress(lastProgress, 100)
        }
    }

    // Listeners always see completion, even when the loop did not run (iterations < 2).
    if (lastProgress != 100) {
        progressListener?.onProgress(100, 100)
    }

    // Wipe the intermediate HMAC output, which is derived from the password.
    uc.fill(0)

    // Timber.v(message, vararg args): the first argument is the message itself, not a tag.
    Timber.v("## deriveKeys() : $iterations in ${System.currentTimeMillis() - t0} ms")
    return dk
}
/**
 * Build a salt of exactly [SALT_LENGTH] characters by concatenating random UUID strings and
 * keeping the leading characters.
 */
private fun generateSalt(): String {
    val builder = StringBuilder()
    while (builder.length < SALT_LENGTH) {
        builder.append(UUID.randomUUID().toString())
    }
    return builder.take(SALT_LENGTH).toString()
}