[refactor] refactor device ID generation with std::random_device and ensure uniqueness

This commit is contained in:
dijunkun
2025-12-09 21:34:05 +08:00
parent 76f5101343
commit 028acaf269
5 changed files with 102 additions and 49 deletions

View File

@@ -104,58 +104,55 @@ std::string DeviceDBManager::HashPasswordWithSalt(const std::string& salt,
return Sha256(salt + password);
}
bool DeviceDBManager::DeviceIdExists(const std::string& device_id) {
if (db_ == nullptr || device_id.empty()) {
return false;
}
const char* sql = "SELECT 1 FROM devices WHERE device_id = ? LIMIT 1;";
sqlite3_stmt* stmt = nullptr;
if (sqlite3_prepare_v2(db_, sql, -1, &stmt, nullptr) != SQLITE_OK) {
return false;
}
sqlite3_bind_text(stmt, 1, device_id.c_str(), -1, SQLITE_TRANSIENT);
bool exists = (sqlite3_step(stmt) == SQLITE_ROW);
sqlite3_finalize(stmt);
return exists;
}
std::string DeviceDBManager::GenerateDeviceId() {
if (db_ == nullptr) {
LOG_ERROR("Database is not initialized in GenerateDeviceId.");
return {};
}
sqlite3_exec(db_, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr);
sqlite3_stmt* stmt = nullptr;
int rc = sqlite3_prepare_v2(db_, "SELECT next_id FROM device_id_seq;", -1,
&stmt, nullptr);
if (rc != SQLITE_OK) {
sqlite3_exec(db_, "ROLLBACK;", nullptr, nullptr, nullptr);
return {};
}
rc = sqlite3_step(stmt);
if (rc != SQLITE_ROW) {
sqlite3_finalize(stmt);
sqlite3_exec(db_, "ROLLBACK;", nullptr, nullptr, nullptr);
return {};
}
int next_id = sqlite3_column_int(stmt, 0);
sqlite3_finalize(stmt);
const int MIN_ID = 100000000;
const int MAX_ID = 999999999;
const int MAX_RETRIES = 100;
std::mt19937 rng(next_id);
std::random_device rd;
std::mt19937 rng(rd());
std::uniform_int_distribution<int> dist(MIN_ID, MAX_ID);
int obfuscated_id = dist(rng);
rc = sqlite3_prepare_v2(db_, "UPDATE device_id_seq SET next_id = ?;", -1,
&stmt, nullptr);
if (rc != SQLITE_OK) {
sqlite3_exec(db_, "ROLLBACK;", nullptr, nullptr, nullptr);
return {};
}
sqlite3_bind_int(stmt, 1, next_id + 1);
rc = sqlite3_step(stmt);
sqlite3_finalize(stmt);
// try to generate unique ID
for (int attempt = 0; attempt < MAX_RETRIES; ++attempt) {
int obfuscated_id = dist(rng);
char buf[10] = {0};
snprintf(buf, sizeof(buf), "%09d", obfuscated_id);
std::string device_id(buf);
if (rc != SQLITE_DONE) {
sqlite3_exec(db_, "ROLLBACK;", nullptr, nullptr, nullptr);
return {};
// check if ID already exists
if (!DeviceIdExists(device_id)) {
return device_id;
}
}
sqlite3_exec(db_, "COMMIT;", nullptr, nullptr, nullptr);
char buf[10] = {0};
snprintf(buf, sizeof(buf), "%09d", obfuscated_id);
return std::string(buf);
LOG_ERROR("Failed to generate unique device ID after {} attempts.", MAX_RETRIES);
return {};
}
std::string DeviceDBManager::GeneratePassword() {
@@ -230,17 +227,29 @@ DeviceCredential DeviceDBManager::AddDevice(const std::string& device_id,
sqlite3_finalize(stmt);
}
// Device not exists or device_id is empty — generate new, try 10 times
for (int i = 0; i < 10; ++i) {
// Device not exists or device_id is empty — generate new
const int MAX_RETRIES = 10;
for (int i = 0; i < MAX_RETRIES; ++i) {
std::string new_id;
if (device_id == "web") {
std::string new_id = device_id + "-" + GenerateDeviceId();
return {new_id, "", false};
std::string generated_id = GenerateDeviceId();
if (generated_id.empty()) {
LOG_ERROR("Failed to generate device ID for web client.");
return {};
}
new_id = device_id + "-" + generated_id;
} else {
new_id = GenerateDeviceId();
if (new_id.empty()) {
LOG_ERROR("Failed to generate device ID.");
return {};
}
}
std::string new_id = GenerateDeviceId();
if (new_id.empty()) {
LOG_ERROR("Failed to generate device ID.");
return {};
// Check if the generated ID (including web- prefix) already exists
if (DeviceIdExists(new_id)) {
LOG_WARN("Generated ID {} already exists, retrying...", new_id);
continue;
}
std::string new_pwd = GeneratePassword();
@@ -251,12 +260,24 @@ DeviceCredential DeviceDBManager::AddDevice(const std::string& device_id,
std::string salt = GenerateSalt();
std::string hash = HashPasswordWithSalt(salt, new_pwd);
// Use transaction to reduce race condition
sqlite3_exec(db_, "BEGIN TRANSACTION;", nullptr, nullptr, nullptr);
// Double-check ID uniqueness within transaction
if (DeviceIdExists(new_id)) {
sqlite3_exec(db_, "ROLLBACK;", nullptr, nullptr, nullptr);
LOG_WARN("Generated ID {} already exists, retrying...", new_id);
continue;
}
const char* insert_sql =
"INSERT INTO devices (device_id, password_hash, password_salt) VALUES "
"(?, ?, ?);";
sqlite3_stmt* stmt = nullptr;
if (sqlite3_prepare_v2(db_, insert_sql, -1, &stmt, nullptr) != SQLITE_OK) {
sqlite3_exec(db_, "ROLLBACK;", nullptr, nullptr, nullptr);
LOG_ERROR("Failed to prepare insert statement.");
return {};
}
@@ -268,18 +289,24 @@ DeviceCredential DeviceDBManager::AddDevice(const std::string& device_id,
sqlite3_finalize(stmt);
if (rc == SQLITE_DONE) {
sqlite3_exec(db_, "COMMIT;", nullptr, nullptr, nullptr);
// For web clients, return empty password
if (device_id == "web") {
return {new_id, "", false};
}
return {new_id, new_pwd, false};
} else if (rc == SQLITE_CONSTRAINT) {
LOG_ERROR("{}:{} Insert failed: rc={}, err={}", new_id, new_pwd, rc,
sqlite3_errmsg(db_));
sqlite3_exec(db_, "ROLLBACK;", nullptr, nullptr, nullptr);
LOG_WARN("Insert failed due to constraint (ID may have been inserted concurrently): {}", new_id);
continue;
} else {
sqlite3_exec(db_, "ROLLBACK;", nullptr, nullptr, nullptr);
LOG_ERROR("Insert device failed: {}", sqlite3_errmsg(db_));
return {};
}
}
LOG_ERROR("Failed to generate unique device_id after multiple attempts.");
LOG_ERROR("Failed to generate unique device_id after {} attempts.", MAX_RETRIES);
return {};
}

View File

@@ -40,6 +40,7 @@ class DeviceDBManager {
std::string GenerateDeviceId();
std::string GeneratePassword();
std::string GenerateSalt();
bool DeviceIdExists(const std::string& device_id);
std::string HashPasswordWithSalt(const std::string& salt,
const std::string& password);

View File

@@ -328,3 +328,19 @@ bool SignalNegotiation::new_candidate_mid(websocketpp::connection_hdl hdl,
return true;
}
void SignalNegotiation::OnWebClientDisconnect(const std::string& user_id) {
// Extract pure user_id (remove password part if exists)
std::string pure_user_id = user_id;
size_t at_pos = user_id.find("@");
if (at_pos != std::string::npos) {
pure_user_id = user_id.substr(0, at_pos);
}
// Check if this is a web client (starts with "web-")
if (pure_user_id.find("web-") == 0) {
if (!device_db_manager_->RemoveDevice(pure_user_id)) {
LOG_WARN("Failed to remove web client device [{}] from database", pure_user_id);
}
}
}

View File

@@ -33,6 +33,7 @@ class SignalNegotiation {
bool answer(websocketpp::connection_hdl hdl, const json& j);
bool new_candidate(websocketpp::connection_hdl hdl, const json& j);
bool new_candidate_mid(websocketpp::connection_hdl hdl, const json& j);
void OnWebClientDisconnect(const std::string& user_id);
private:
std::shared_ptr<TransmissionManager> transmission_manager_;

View File

@@ -88,6 +88,10 @@ bool SignalServer::OnClose(websocketpp::connection_hdl hdl) {
if (!user_id.empty()) {
LOG_INFO("Websocket connection [{}|{}] closed", ws_connections_[hdl],
user_id);
// Remove web client from database on disconnect
if (signal_negotiation_) {
signal_negotiation_->OnWebClientDisconnect(user_id);
}
}
ws_connections_.erase(hdl);
return true;
@@ -98,6 +102,10 @@ bool SignalServer::OnFail(websocketpp::connection_hdl hdl) {
if (!user_id.empty()) {
LOG_INFO("Websocket connection [{}|{}] failed", ws_connections_[hdl],
user_id);
// Remove web client from database on disconnect
if (signal_negotiation_) {
signal_negotiation_->OnWebClientDisconnect(user_id);
}
}
return true;
}