Add WAL incremental sync and auto-decrypt debounce

Introduces experimental WAL (Write-Ahead Logging) incremental sync for WeChat database decryption, allowing real-time monitoring and incremental updates to the working directory database. Adds UI and config options to enable WAL support and configure auto-decrypt debounce interval. Updates related services, context, and data source logic to support WAL file handling, incremental decryption, and improved process detection. Also improves single-instance process checks and updates documentation for these new features.
This commit is contained in:
lx1056758714-glitch
2026-01-22 21:08:36 +08:00
parent fcda41f03a
commit 8fe3d595e7
18 changed files with 896 additions and 101 deletions

View File

@@ -3,6 +3,7 @@ package chatlog
import (
"fmt"
"path/filepath"
"strconv"
"time"
"github.com/rs/zerolog/log"
@@ -141,11 +142,13 @@ func (a *App) refresh() {
case <-a.stopRefresh:
return
case <-tick.C:
var processErr error
// 如果当前账号为空,尝试查找微信进程
if a.ctx.Current == nil {
// 获取微信实例
instances := a.m.wechat.GetWeChatInstances()
if len(instances) > 0 {
instances, err := a.m.wechat.GetWeChatInstancesWithError()
processErr = err
if err == nil && len(instances) > 0 {
// 找到微信进程,设置第一个为当前账号
a.ctx.SwitchCurrent(instances[0])
log.Info().Msgf("检测到微信进程PID: %d已设置为当前账号", instances[0].PID)
@@ -168,7 +171,11 @@ func (a *App) refresh() {
}
a.infoBar.UpdateAccount(a.ctx.Account)
a.infoBar.UpdateBasicInfo(a.ctx.PID, a.ctx.FullVersion, a.ctx.ExePath)
a.infoBar.UpdateStatus(a.ctx.Status)
statusText := a.ctx.Status
if a.ctx.PID == 0 && processErr != nil {
statusText = fmt.Sprintf("[red]获取进程失败: %v[white]", processErr)
}
a.infoBar.UpdateStatus(statusText)
a.infoBar.UpdateDataKey(a.ctx.DataKey)
a.infoBar.UpdateImageKey(a.ctx.ImgKey)
a.infoBar.UpdatePlatform(a.ctx.Platform)
@@ -182,10 +189,19 @@ func (a *App) refresh() {
} else {
a.infoBar.UpdateHTTPServer("[未启动]")
}
autoDecryptText := "[未开启]"
if a.ctx.AutoDecrypt {
a.infoBar.UpdateAutoDecrypt("[green][已开启][white]")
if a.ctx.AutoDecryptDebounce > 0 {
autoDecryptText = fmt.Sprintf("[green][已开启][white] %dms", a.ctx.AutoDecryptDebounce)
} else {
autoDecryptText = "[green][已开启][white]"
}
}
a.infoBar.UpdateAutoDecrypt(autoDecryptText)
if a.ctx.WalEnabled {
a.infoBar.UpdateWal("[green][已启用][white]")
} else {
a.infoBar.UpdateAutoDecrypt("[未启]")
a.infoBar.UpdateWal("[未启]")
}
// Update latest message in footer
@@ -602,6 +618,16 @@ func (a *App) settingSelected(i *menu.Item) {
description: "配置微信数据文件所在目录",
action: a.settingDataDir,
},
{
name: "启用 WAL 支持",
description: "同步并监控 .db-wal/.db-shm 文件",
action: a.settingWalEnabled,
},
{
name: "设置自动解密去抖",
description: "配置自动解密触发间隔(ms)",
action: a.settingAutoDecryptDebounce,
},
}
subMenu := menu.NewSubMenu("设置")
@@ -759,6 +785,80 @@ func (a *App) settingDataDir() {
a.SetFocus(formView)
}
func (a *App) settingWalEnabled() {
formView := form.NewForm("设置 WAL 支持")
tempWalEnabled := a.ctx.WalEnabled
formView.AddCheckbox("启用 WAL 支持", tempWalEnabled, func(checked bool) {
tempWalEnabled = checked
})
formView.AddButton("保存", func() {
a.ctx.SetWalEnabled(tempWalEnabled)
a.mainPages.RemovePage("submenu2")
if tempWalEnabled {
a.showInfo("WAL 支持已开启")
} else {
a.showInfo("WAL 支持已关闭")
}
})
formView.AddButton("取消", func() {
a.mainPages.RemovePage("submenu2")
})
a.mainPages.AddPage("submenu2", formView, true, true)
a.SetFocus(formView)
}
func (a *App) settingAutoDecryptDebounce() {
formView := form.NewForm("设置自动解密去抖")
tempDebounceText := ""
if a.ctx.AutoDecryptDebounce > 0 {
tempDebounceText = strconv.Itoa(a.ctx.AutoDecryptDebounce)
}
formView.AddInputField("去抖时长(ms)", tempDebounceText, 0, func(textToCheck string, lastChar rune) bool {
if textToCheck == "" {
return true
}
for _, r := range textToCheck {
if r < '0' || r > '9' {
return false
}
}
return true
}, func(text string) {
tempDebounceText = text
})
formView.AddButton("保存", func() {
if tempDebounceText == "" {
a.ctx.SetAutoDecryptDebounce(0)
a.mainPages.RemovePage("submenu2")
a.showInfo("已恢复默认去抖时长")
return
}
value, err := strconv.Atoi(tempDebounceText)
if err != nil {
a.showError(fmt.Errorf("去抖时长必须为数字"))
return
}
a.ctx.SetAutoDecryptDebounce(value)
a.mainPages.RemovePage("submenu2")
a.showInfo(fmt.Sprintf("去抖时长已设置为 %dms", value))
})
formView.AddButton("取消", func() {
a.mainPages.RemovePage("submenu2")
})
a.mainPages.AddPage("submenu2", formView, true, true)
a.SetFocus(formView)
}
// selectAccountSelected 处理切换账号菜单项的选择事件
func (a *App) selectAccountSelected(i *menu.Item) {
// 创建子菜单

View File

@@ -15,6 +15,8 @@ type ServerConfig struct {
WorkDir string `mapstructure:"work_dir"`
HTTPAddr string `mapstructure:"http_addr"`
AutoDecrypt bool `mapstructure:"auto_decrypt"`
WalEnabled bool `mapstructure:"wal_enabled"`
AutoDecryptDebounce int `mapstructure:"auto_decrypt_debounce"`
SaveDecryptedMedia bool `mapstructure:"save_decrypted_media"`
Webhook *Webhook `mapstructure:"webhook"`
}
@@ -51,6 +53,14 @@ func (c *ServerConfig) GetAutoDecrypt() bool {
return c.AutoDecrypt
}
func (c *ServerConfig) GetWalEnabled() bool {
return c.WalEnabled
}
func (c *ServerConfig) GetAutoDecryptDebounce() int {
return c.AutoDecryptDebounce
}
func (c *ServerConfig) GetHTTPAddr() string {
if c.HTTPAddr == "" {
c.HTTPAddr = DefalutHTTPAddr

View File

@@ -21,6 +21,8 @@ type ProcessConfig struct {
WorkDir string `mapstructure:"work_dir" json:"work_dir"`
HTTPEnabled bool `mapstructure:"http_enabled" json:"http_enabled"`
HTTPAddr string `mapstructure:"http_addr" json:"http_addr"`
WalEnabled bool `mapstructure:"wal_enabled" json:"wal_enabled"`
AutoDecryptDebounce int `mapstructure:"auto_decrypt_debounce" json:"auto_decrypt_debounce"`
LastTime int64 `mapstructure:"last_time" json:"last_time"`
Files []File `mapstructure:"files" json:"files"`
}

View File

@@ -50,6 +50,8 @@ type Context struct {
// 自动解密
AutoDecrypt bool
LastSession time.Time
WalEnabled bool
AutoDecryptDebounce int
// 当前选中的微信实例
Current *wechat.Account
@@ -103,6 +105,8 @@ func (c *Context) SwitchHistory(account string) {
c.WorkDir = history.WorkDir
c.HTTPEnabled = history.HTTPEnabled
c.HTTPAddr = history.HTTPAddr
c.WalEnabled = history.WalEnabled
c.AutoDecryptDebounce = history.AutoDecryptDebounce
} else {
c.Account = ""
c.Platform = ""
@@ -114,6 +118,8 @@ func (c *Context) SwitchHistory(account string) {
c.WorkDir = ""
c.HTTPEnabled = false
c.HTTPAddr = ""
c.WalEnabled = false
c.AutoDecryptDebounce = 0
}
}
@@ -182,6 +188,18 @@ func (c *Context) GetDataKey() string {
return c.DataKey
}
func (c *Context) GetWalEnabled() bool {
c.mu.RLock()
defer c.mu.RUnlock()
return c.WalEnabled
}
func (c *Context) GetAutoDecryptDebounce() int {
c.mu.RLock()
defer c.mu.RUnlock()
return c.AutoDecryptDebounce
}
func (c *Context) GetHTTPAddr() string {
if c.HTTPAddr == "" {
c.HTTPAddr = DefalutHTTPAddr
@@ -260,6 +278,26 @@ func (c *Context) SetAutoDecrypt(enabled bool) {
c.UpdateConfig()
}
func (c *Context) SetWalEnabled(enabled bool) {
c.mu.Lock()
defer c.mu.Unlock()
if c.WalEnabled == enabled {
return
}
c.WalEnabled = enabled
c.UpdateConfig()
}
func (c *Context) SetAutoDecryptDebounce(debounce int) {
c.mu.Lock()
defer c.mu.Unlock()
if c.AutoDecryptDebounce == debounce {
return
}
c.AutoDecryptDebounce = debounce
c.UpdateConfig()
}
// 更新配置
func (c *Context) UpdateConfig() {
@@ -275,6 +313,8 @@ func (c *Context) UpdateConfig() {
WorkDir: c.WorkDir,
HTTPEnabled: c.HTTPEnabled,
HTTPAddr: c.HTTPAddr,
WalEnabled: c.WalEnabled,
AutoDecryptDebounce: c.AutoDecryptDebounce,
}
if c.conf.History == nil {

View File

@@ -33,6 +33,7 @@ type Config interface {
GetPlatform() string
GetVersion() int
GetWebhook() *conf.Webhook
GetWalEnabled() bool
}
func NewService(conf Config) *Service {
@@ -43,7 +44,7 @@ func NewService(conf Config) *Service {
}
func (s *Service) Start() error {
db, err := wechatdb.New(s.conf.GetWorkDir(), s.conf.GetPlatform(), s.conf.GetVersion())
db, err := wechatdb.New(s.conf.GetWorkDir(), s.conf.GetPlatform(), s.conf.GetVersion(), s.conf.GetWalEnabled())
if err != nil {
return err
}

View File

@@ -2,9 +2,14 @@ package wechat
import (
"context"
"encoding/binary"
"encoding/hex"
"fmt"
"io"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"time"
@@ -14,6 +19,7 @@ import (
"github.com/sjzar/chatlog/internal/errors"
"github.com/sjzar/chatlog/internal/wechat"
"github.com/sjzar/chatlog/internal/wechat/decrypt"
"github.com/sjzar/chatlog/internal/wechat/decrypt/common"
"github.com/sjzar/chatlog/pkg/filemonitor"
"github.com/sjzar/chatlog/pkg/util"
)
@@ -27,17 +33,37 @@ type Service struct {
conf Config
lastEvents map[string]time.Time
pendingActions map[string]bool
pendingEvents map[string]*pendingEvent
walStates map[string]*walState
mutex sync.Mutex
fm *filemonitor.FileMonitor
errorHandler func(error)
}
type pendingEvent struct {
sawDB bool
sawWal bool
}
type walState struct {
offset int64
salt1 uint32
salt2 uint32
}
type walFrame struct {
pageNo uint32
data []byte
}
type Config interface {
GetDataKey() string
GetDataDir() string
GetWorkDir() string
GetPlatform() string
GetVersion() int
GetWalEnabled() bool
GetAutoDecryptDebounce() int
}
func NewService(conf Config) *Service {
@@ -45,6 +71,8 @@ func NewService(conf Config) *Service {
conf: conf,
lastEvents: make(map[string]time.Time),
pendingActions: make(map[string]bool),
pendingEvents: make(map[string]*pendingEvent),
walStates: make(map[string]*walState),
}
}
@@ -55,8 +83,15 @@ func (s *Service) SetAutoDecryptErrorHandler(handler func(error)) {
// GetWeChatInstances returns all running WeChat instances
func (s *Service) GetWeChatInstances() []*wechat.Account {
wechat.Load()
return wechat.GetAccounts()
instances, _ := s.GetWeChatInstancesWithError()
return instances
}
func (s *Service) GetWeChatInstancesWithError() ([]*wechat.Account, error) {
if err := wechat.Load(); err != nil {
return nil, err
}
return wechat.GetAccounts(), nil
}
// GetDataKey extracts the encryption key from a WeChat process
@@ -84,7 +119,11 @@ func (s *Service) GetImageKey(info *wechat.Account) (string, error) {
func (s *Service) StartAutoDecrypt() error {
log.Info().Msgf("start auto decrypt, data dir: %s", s.conf.GetDataDir())
dbGroup, err := filemonitor.NewFileGroup("wechat", s.conf.GetDataDir(), `.*\.db$`, []string{"fts"})
pattern := `.*\.db$`
if s.conf.GetWalEnabled() {
pattern = `.*\.db(-wal|-shm)?$`
}
dbGroup, err := filemonitor.NewFileGroup("wechat", s.conf.GetDataDir(), pattern, []string{"fts"})
if err != nil {
return err
}
@@ -122,13 +161,25 @@ func (s *Service) DecryptFileCallback(event fsnotify.Event) error {
return nil
}
dbFile := s.normalizeDBFile(event.Name)
isWal := isWalFile(event.Name)
s.mutex.Lock()
s.lastEvents[event.Name] = time.Now()
s.lastEvents[dbFile] = time.Now()
flags, ok := s.pendingEvents[dbFile]
if !ok {
flags = &pendingEvent{}
s.pendingEvents[dbFile] = flags
}
if isWal {
flags.sawWal = true
} else {
flags.sawDB = true
}
if !s.pendingActions[event.Name] {
s.pendingActions[event.Name] = true
if !s.pendingActions[dbFile] {
s.pendingActions[dbFile] = true
s.mutex.Unlock()
go s.waitAndProcess(event.Name)
go s.waitAndProcess(dbFile)
} else {
s.mutex.Unlock()
}
@@ -139,21 +190,85 @@ func (s *Service) DecryptFileCallback(event fsnotify.Event) error {
func (s *Service) waitAndProcess(dbFile string) {
start := time.Now()
for {
time.Sleep(DebounceTime)
debounce := s.getDebounceTimeForFile(dbFile)
maxWait := s.getMaxWaitTimeForFile(dbFile)
time.Sleep(debounce)
s.mutex.Lock()
lastEventTime := s.lastEvents[dbFile]
elapsed := time.Since(lastEventTime)
totalElapsed := time.Since(start)
if elapsed >= DebounceTime || totalElapsed >= MaxWaitTime {
if elapsed >= debounce || totalElapsed >= maxWait {
s.pendingActions[dbFile] = false
flags := pendingEvent{}
if state, ok := s.pendingEvents[dbFile]; ok && state != nil {
flags = *state
}
s.pendingEvents[dbFile] = &pendingEvent{}
s.mutex.Unlock()
if _, err := os.Stat(dbFile); err != nil {
return
}
log.Debug().Msgf("Processing file: %s", dbFile)
if err := s.DecryptDBFile(dbFile); err != nil {
if s.errorHandler != nil {
s.errorHandler(err)
workCopyExists := false
if s.conf.GetWorkDir() != "" {
if relPath, err := filepath.Rel(s.conf.GetDataDir(), dbFile); err == nil {
output := filepath.Join(s.conf.GetWorkDir(), relPath)
if _, err := os.Stat(output); err == nil {
workCopyExists = true
}
}
}
if flags.sawDB {
if !s.conf.GetWalEnabled() || !workCopyExists {
if err := s.DecryptDBFile(dbFile); err != nil {
if s.errorHandler != nil {
s.errorHandler(err)
}
}
return
}
if flags.sawWal {
handled, err := s.IncrementalDecryptDBFile(dbFile)
if err != nil {
if s.errorHandler != nil {
s.errorHandler(err)
}
return
}
if handled {
return
}
}
return
}
if flags.sawWal && s.conf.GetWalEnabled() {
handled, err := s.IncrementalDecryptDBFile(dbFile)
if err != nil {
if s.errorHandler != nil {
s.errorHandler(err)
}
return
}
if handled {
return
}
if !workCopyExists {
if err := s.DecryptDBFile(dbFile); err != nil {
if s.errorHandler != nil {
s.errorHandler(err)
}
}
}
return
}
if !s.conf.GetWalEnabled() || !workCopyExists {
if err := s.DecryptDBFile(dbFile); err != nil {
if s.errorHandler != nil {
s.errorHandler(err)
}
}
}
return
@@ -195,6 +310,11 @@ func (s *Service) DecryptDBFile(dbFile string) error {
if data, err := os.ReadFile(dbFile); err == nil {
outputFile.Write(data)
}
if s.conf.GetWalEnabled() {
if err := s.syncWalFiles(dbFile, output); err != nil {
log.Debug().Err(err).Msgf("failed to sync wal files for %s", dbFile)
}
}
return nil
}
log.Err(err).Msgf("failed to decrypt %s", dbFile)
@@ -203,9 +323,147 @@ func (s *Service) DecryptDBFile(dbFile string) error {
log.Debug().Msgf("Decrypted %s to %s", dbFile, output)
if s.conf.GetWalEnabled() {
if err := s.syncWalFiles(dbFile, output); err != nil {
log.Debug().Err(err).Msgf("failed to sync wal files for %s", dbFile)
}
}
return nil
}
func (s *Service) getDebounceTime() time.Duration {
debounce := s.conf.GetAutoDecryptDebounce()
if debounce <= 0 {
return DebounceTime
}
return time.Duration(debounce) * time.Millisecond
}
func (s *Service) getMaxWaitTime() time.Duration {
if !s.conf.GetWalEnabled() {
return MaxWaitTime
}
debounce := s.getDebounceTime()
maxWait := 2 * debounce
if maxWait < time.Second {
return time.Second
}
if maxWait > 3*time.Second {
return 3 * time.Second
}
return maxWait
}
func (s *Service) getDebounceTimeForFile(dbFile string) time.Duration {
debounce := s.getDebounceTime()
if !s.conf.GetWalEnabled() {
return debounce
}
if isRealtimeDBFile(dbFile) {
if debounce > 300*time.Millisecond {
return 300 * time.Millisecond
}
}
return debounce
}
func (s *Service) getMaxWaitTimeForFile(dbFile string) time.Duration {
if !s.conf.GetWalEnabled() {
return s.getMaxWaitTime()
}
if isRealtimeDBFile(dbFile) {
debounce := s.getDebounceTimeForFile(dbFile)
maxWait := 2 * debounce
if maxWait > time.Second {
return time.Second
}
return maxWait
}
return s.getMaxWaitTime()
}
func isRealtimeDBFile(dbFile string) bool {
base := filepath.Base(dbFile)
if base == "session.db" {
return true
}
return strings.HasPrefix(base, "message_") && strings.HasSuffix(base, ".db")
}
func (s *Service) normalizeDBFile(path string) string {
if strings.HasSuffix(path, ".db-wal") {
return strings.TrimSuffix(path, "-wal")
}
if strings.HasSuffix(path, ".db-shm") {
return strings.TrimSuffix(path, "-shm")
}
return path
}
func isWalFile(path string) bool {
return strings.HasSuffix(path, ".db-wal") || strings.HasSuffix(path, ".db-shm")
}
func (s *Service) syncWalFiles(dbFile, output string) error {
walSrc := dbFile + "-wal"
walDst := output + "-wal"
if err := syncAuxFile(walSrc, walDst); err != nil {
return err
}
shmSrc := dbFile + "-shm"
shmDst := output + "-shm"
if err := syncAuxFile(shmSrc, shmDst); err != nil {
return err
}
return nil
}
func syncAuxFile(src, dst string) error {
if _, err := os.Stat(src); err != nil {
if os.IsNotExist(err) {
if err := os.Remove(dst); err != nil && !os.IsNotExist(err) {
return err
}
return nil
}
return err
}
if err := util.PrepareDir(filepath.Dir(dst)); err != nil {
return err
}
return copyFileAtomic(src, dst)
}
func copyFileAtomic(src, dst string) error {
input, err := os.Open(src)
if err != nil {
return err
}
defer input.Close()
temp := dst + ".tmp"
output, err := os.Create(temp)
if err != nil {
return err
}
if _, err := io.Copy(output, input); err != nil {
output.Close()
os.Remove(temp)
return err
}
if err := output.Sync(); err != nil {
output.Close()
os.Remove(temp)
return err
}
if err := output.Close(); err != nil {
os.Remove(temp)
return err
}
return os.Rename(temp, dst)
}
func (s *Service) DecryptDBFiles() error {
dbGroup, err := filemonitor.NewFileGroup("wechat", s.conf.GetDataDir(), `.*\.db$`, []string{"fts"})
if err != nil {
@@ -216,6 +474,14 @@ func (s *Service) DecryptDBFiles() error {
if err != nil {
return err
}
sort.SliceStable(dbFiles, func(i, j int) bool {
pi := dbFilePriority(dbFiles[i])
pj := dbFilePriority(dbFiles[j])
if pi != pj {
return pi < pj
}
return filepath.Base(dbFiles[i]) < filepath.Base(dbFiles[j])
})
var lastErr error
failCount := 0
@@ -235,3 +501,243 @@ func (s *Service) DecryptDBFiles() error {
return nil
}
func dbFilePriority(path string) int {
base := filepath.Base(path)
if strings.HasPrefix(base, "message_") && strings.HasSuffix(base, ".db") {
return 0
}
if base == "session.db" {
return 1
}
return 2
}
func (s *Service) IncrementalDecryptDBFile(dbFile string) (bool, error) {
if !s.conf.GetWalEnabled() {
return false, nil
}
walPath := dbFile + "-wal"
if _, err := os.Stat(walPath); err != nil {
if os.IsNotExist(err) {
return false, nil
}
return false, err
}
relPath, err := filepath.Rel(s.conf.GetDataDir(), dbFile)
if err != nil {
return false, fmt.Errorf("failed to get relative path for %s: %w", dbFile, err)
}
output := filepath.Join(s.conf.GetWorkDir(), relPath)
if _, err := os.Stat(output); err != nil {
if os.IsNotExist(err) {
return false, nil
}
return false, err
}
decryptor, err := decrypt.NewDecryptor(s.conf.GetPlatform(), s.conf.GetVersion())
if err != nil {
return true, err
}
dbInfo, err := common.OpenDBFile(dbFile, decryptor.GetPageSize())
if err != nil {
if err == errors.ErrAlreadyDecrypted {
return false, nil
}
return true, err
}
keyBytes, err := hex.DecodeString(s.conf.GetDataKey())
if err != nil {
return true, errors.DecodeKeyFailed(err)
}
if !decryptor.Validate(dbInfo.FirstPage, keyBytes) {
return true, errors.ErrDecryptIncorrectKey
}
encKey, macKey, err := decryptor.DeriveKeys(keyBytes, dbInfo.Salt)
if err != nil {
return true, err
}
walFile, err := os.Open(walPath)
if err != nil {
return true, err
}
defer walFile.Close()
info, err := walFile.Stat()
if err != nil {
return true, err
}
if info.Size() < walHeaderSize {
return false, nil
}
headerBuf := make([]byte, walHeaderSize)
if _, err := io.ReadFull(walFile, headerBuf); err != nil {
return true, err
}
order, pageSize, salt1, salt2, err := parseWalHeader(headerBuf)
if err != nil {
return true, err
}
if pageSize != 0 && pageSize != uint32(decryptor.GetPageSize()) {
return true, fmt.Errorf("unexpected wal page size: %d", pageSize)
}
s.mutex.Lock()
state := s.walStates[dbFile]
if state != nil && (state.salt1 != salt1 || state.salt2 != salt2 || info.Size() < state.offset) {
delete(s.walStates, dbFile)
state = nil
}
startOffset := int64(walHeaderSize)
if state != nil && state.offset > startOffset {
startOffset = state.offset
}
s.mutex.Unlock()
if _, err := walFile.Seek(startOffset, io.SeekStart); err != nil {
return true, err
}
outputFile, err := os.OpenFile(output, os.O_RDWR, 0)
if err != nil {
return true, err
}
defer outputFile.Close()
frameHeader := make([]byte, walFrameHeaderSize)
pageBuf := make([]byte, decryptor.GetPageSize())
txFrames := make([]walFrame, 0)
var lastCommitOffset int64
var applied bool
curOffset := startOffset
for curOffset+int64(walFrameHeaderSize)+int64(decryptor.GetPageSize()) <= info.Size() {
if _, err := io.ReadFull(walFile, frameHeader); err != nil {
break
}
curOffset += int64(walFrameHeaderSize)
frameSalt1 := order.Uint32(frameHeader[8:12])
frameSalt2 := order.Uint32(frameHeader[12:16])
if frameSalt1 != salt1 || frameSalt2 != salt2 {
s.mutex.Lock()
delete(s.walStates, dbFile)
s.mutex.Unlock()
return false, nil
}
if _, err := io.ReadFull(walFile, pageBuf); err != nil {
break
}
curOffset += int64(decryptor.GetPageSize())
pageNo := order.Uint32(frameHeader[0:4])
commit := order.Uint32(frameHeader[4:8])
data := make([]byte, len(pageBuf))
copy(data, pageBuf)
txFrames = append(txFrames, walFrame{pageNo: pageNo, data: data})
if commit != 0 {
if err := applyWalFrames(outputFile, txFrames, decryptor, encKey, macKey); err != nil {
return true, err
}
txFrames = txFrames[:0]
lastCommitOffset = curOffset
applied = true
}
}
if lastCommitOffset > 0 {
s.mutex.Lock()
s.walStates[dbFile] = &walState{
offset: lastCommitOffset,
salt1: salt1,
salt2: salt2,
}
s.mutex.Unlock()
}
if err := s.syncWalFiles(dbFile, output); err != nil {
return true, err
}
if applied {
return true, nil
}
return true, nil
}
func parseWalHeader(buf []byte) (binary.ByteOrder, uint32, uint32, uint32, error) {
if len(buf) < walHeaderSize {
return nil, 0, 0, 0, fmt.Errorf("wal header too short")
}
magic := binary.BigEndian.Uint32(buf[0:4])
var order binary.ByteOrder
switch magic {
case 0x377f0682:
order = binary.BigEndian
case 0x377f0683:
order = binary.LittleEndian
default:
return nil, 0, 0, 0, fmt.Errorf("invalid wal magic: %x", magic)
}
pageSize := order.Uint32(buf[8:12])
salt1 := order.Uint32(buf[16:20])
salt2 := order.Uint32(buf[20:24])
if pageSize == 0 {
pageSize = 65536
}
return order, pageSize, salt1, salt2, nil
}
func applyWalFrames(output *os.File, frames []walFrame, decryptor decrypt.Decryptor, encKey, macKey []byte) error {
pageSize := decryptor.GetPageSize()
reserve := decryptor.GetReserve()
hmacSize := decryptor.GetHMACSize()
hashFunc := decryptor.GetHashFunc()
for _, frame := range frames {
pageNo := int64(frame.pageNo) - 1
if pageNo < 0 {
continue
}
allZeros := true
for _, b := range frame.data {
if b != 0 {
allZeros = false
break
}
}
var pageData []byte
if allZeros {
pageData = frame.data
} else {
decrypted, err := common.DecryptPage(frame.data, encKey, macKey, pageNo, hashFunc, hmacSize, reserve, pageSize)
if err != nil {
return err
}
if pageNo == 0 {
fullPage := make([]byte, pageSize)
copy(fullPage, []byte(common.SQLiteHeader))
copy(fullPage[len(common.SQLiteHeader):], decrypted)
pageData = fullPage
} else {
pageData = decrypted
}
}
if _, err := output.WriteAt(pageData, pageNo*int64(pageSize)); err != nil {
return err
}
}
return nil
}
const (
walHeaderSize = 32
walFrameHeaderSize = 24
)

View File

@@ -157,6 +157,13 @@ func New() *InfoBar {
)
table.SetCell(autoDecryptRow, valueCol1, tview.NewTableCell(""))
table.SetCell(
autoDecryptRow,
labelCol2,
tview.NewTableCell(fmt.Sprintf(" [%s::]%s", headerColor, "WAL:")),
)
table.SetCell(autoDecryptRow, valueCol2, tview.NewTableCell(""))
// infobar
infoBar := &InfoBar{
Box: tview.NewBox(),
@@ -217,6 +224,10 @@ func (info *InfoBar) UpdateAutoDecrypt(text string) {
info.table.GetCell(autoDecryptRow, valueCol1).SetText(text)
}
func (info *InfoBar) UpdateWal(text string) {
info.table.GetCell(autoDecryptRow, valueCol2).SetText(text)
}
// Draw draws this primitive onto the screen.
func (info *InfoBar) Draw(screen tcell.Screen) {
info.Box.DrawForSubclass(screen, info)

View File

@@ -2,6 +2,7 @@ package decrypt
import (
"context"
"hash"
"io"
"github.com/sjzar/chatlog/internal/errors"
@@ -25,6 +26,9 @@ type Decryptor interface {
// GetHMACSize 返回HMAC大小
GetHMACSize() int
GetHashFunc() func() hash.Hash
DeriveKeys(key []byte, salt []byte) ([]byte, []byte, error)
// GetVersion 返回解密器版本
GetVersion() string
}

View File

@@ -65,6 +65,14 @@ func (d *V4Decryptor) deriveKeys(key []byte, salt []byte) ([]byte, []byte) {
return encKey, macKey
}
func (d *V4Decryptor) DeriveKeys(key []byte, salt []byte) ([]byte, []byte, error) {
if len(key) != common.KeySize {
return nil, nil, errors.ErrKeyLengthMust32
}
encKey, macKey := d.deriveKeys(key, salt)
return encKey, macKey, nil
}
// Validate 验证密钥是否有效
func (d *V4Decryptor) Validate(page1 []byte, key []byte) bool {
if len(page1) < d.pageSize || len(key) != common.KeySize {
@@ -183,6 +191,10 @@ func (d *V4Decryptor) GetHMACSize() int {
return d.hmacSize
}
func (d *V4Decryptor) GetHashFunc() func() hash.Hash {
return d.hashFunc
}
// GetVersion 返回解密器版本
func (d *V4Decryptor) GetVersion() string {
return d.version

View File

@@ -39,17 +39,7 @@ func (d *Detector) FindProcesses() ([]*model.Process, error) {
continue
}
// v4 存在同名进程,需要继续判断 cmdline
if name == V4ProcessName {
cmdline, err := p.Cmdline()
if err != nil {
log.Err(err).Msg("获取进程命令行失败")
continue
}
if strings.Contains(cmdline, "--") {
continue
}
}
cmdline, cmdlineErr := p.Cmdline()
// 获取进程信息
procInfo, err := d.getProcessInfo(p)
@@ -58,6 +48,10 @@ func (d *Detector) FindProcesses() ([]*model.Process, error) {
continue
}
if cmdlineErr == nil && strings.Contains(cmdline, "--") && procInfo.DataDir == "" {
continue
}
result = append(result, procInfo)
}

View File

@@ -51,10 +51,10 @@ type DataSource interface {
Close() error
}
func New(path string, platform string, version int) (DataSource, error) {
func New(path string, platform string, version int, walEnabled bool) (DataSource, error) {
switch {
case platform == "windows" && version == 4:
return v4.New(path)
return v4.New(path, walEnabled)
default:
return nil, errors.PlatformUnsupported(platform, version)
}

View File

@@ -2,8 +2,11 @@ package dbm
import (
"database/sql"
"io"
"os"
"path/filepath"
"runtime"
"strings"
"sync"
"time"
@@ -17,23 +20,25 @@ import (
)
type DBManager struct {
path string
id string
fm *filemonitor.FileMonitor
fgs map[string]*filemonitor.FileGroup
dbs map[string]*sql.DB
dbPaths map[string][]string
mutex sync.RWMutex
path string
id string
walEnabled bool
fm *filemonitor.FileMonitor
fgs map[string]*filemonitor.FileGroup
dbs map[string]*sql.DB
dbPaths map[string][]string
mutex sync.RWMutex
}
func NewDBManager(path string) *DBManager {
func NewDBManager(path string, walEnabled bool) *DBManager {
return &DBManager{
path: path,
id: filepath.Base(path),
fm: filemonitor.NewFileMonitor(),
fgs: make(map[string]*filemonitor.FileGroup),
dbs: make(map[string]*sql.DB),
dbPaths: make(map[string][]string),
path: path,
id: filepath.Base(path),
walEnabled: walEnabled,
fm: filemonitor.NewFileMonitor(),
fgs: make(map[string]*filemonitor.FileGroup),
dbs: make(map[string]*sql.DB),
dbPaths: make(map[string][]string),
}
}
@@ -103,7 +108,7 @@ func (d *DBManager) GetDBPath(name string) ([]string, error) {
if len(list) == 0 {
return nil, errors.DBFileNotFound(d.path, fg.PatternStr, nil)
}
dbPaths = list
dbPaths = filterPrimaryDBs(list)
d.mutex.Lock()
d.dbPaths[name] = dbPaths
d.mutex.Unlock()
@@ -126,6 +131,12 @@ func (d *DBManager) OpenDB(path string) (*sql.DB, error) {
log.Err(err).Msgf("获取临时拷贝文件 %s 失败", path)
return nil, err
}
if d.walEnabled {
if err := d.syncWalFiles(path, tempPath); err != nil {
log.Err(err).Msgf("同步 WAL 文件失败: %s", path)
return nil, err
}
}
}
db, err = sql.Open("sqlite3", tempPath)
if err != nil {
@@ -139,19 +150,23 @@ func (d *DBManager) OpenDB(path string) (*sql.DB, error) {
}
func (d *DBManager) Callback(event fsnotify.Event) error {
if !event.Op.Has(fsnotify.Create) {
if !(event.Op.Has(fsnotify.Create) || event.Op.Has(fsnotify.Write) || event.Op.Has(fsnotify.Rename)) {
return nil
}
basePath := normalizeDBPath(event.Name)
d.mutex.Lock()
db, ok := d.dbs[event.Name]
db, ok := d.dbs[basePath]
if ok {
delete(d.dbs, event.Name)
delete(d.dbs, basePath)
go func(db *sql.DB) {
time.Sleep(time.Second * 5)
db.Close()
}(db)
}
if (event.Op.Has(fsnotify.Create) || event.Op.Has(fsnotify.Rename)) && isPrimaryDBFile(event.Name) {
d.dbPaths = make(map[string][]string)
}
d.mutex.Unlock()
return nil
@@ -171,3 +186,76 @@ func (d *DBManager) Close() error {
}
return d.fm.Stop()
}
func (d *DBManager) syncWalFiles(dbPath, tempPath string) error {
if err := syncAuxFile(dbPath+"-wal", tempPath+"-wal"); err != nil {
return err
}
if err := syncAuxFile(dbPath+"-shm", tempPath+"-shm"); err != nil {
return err
}
return nil
}
func syncAuxFile(src, dst string) error {
if _, err := os.Stat(src); err != nil {
if os.IsNotExist(err) {
if err := os.Remove(dst); err != nil && !os.IsNotExist(err) {
return err
}
return nil
}
return err
}
return copyFileAtomic(src, dst)
}
func copyFileAtomic(src, dst string) error {
input, err := os.Open(src)
if err != nil {
return err
}
defer input.Close()
temp := dst + ".tmp"
output, err := os.Create(temp)
if err != nil {
return err
}
if _, err := io.Copy(output, input); err != nil {
output.Close()
os.Remove(temp)
return err
}
if err := output.Sync(); err != nil {
output.Close()
os.Remove(temp)
return err
}
if err := output.Close(); err != nil {
os.Remove(temp)
return err
}
return os.Rename(temp, dst)
}
func normalizeDBPath(path string) string {
if strings.HasSuffix(path, "-wal") || strings.HasSuffix(path, "-shm") {
return strings.TrimSuffix(strings.TrimSuffix(path, "-wal"), "-shm")
}
return path
}
func isPrimaryDBFile(path string) bool {
return strings.HasSuffix(path, ".db") && !strings.HasSuffix(path, ".db-wal") && !strings.HasSuffix(path, ".db-shm")
}
func filterPrimaryDBs(paths []string) []string {
result := make([]string, 0, len(paths))
for _, path := range paths {
if isPrimaryDBFile(path) {
result = append(result, path)
}
}
return result
}

View File

@@ -15,7 +15,7 @@ func TestXxx(t *testing.T) {
BlackList: []string{},
}
d := NewDBManager(path)
d := NewDBManager(path, false)
d.AddGroup(g)
d.Start()

View File

@@ -34,32 +34,32 @@ const (
var Groups = []*dbm.Group{
{
Name: Message,
Pattern: `^message_([0-9]?[0-9])?\.db$`,
Pattern: `^message_([0-9]?[0-9])?\.db(-wal|-shm)?$`,
BlackList: []string{},
},
{
Name: Contact,
Pattern: `^contact\.db$`,
Pattern: `^contact\.db(-wal|-shm)?$`,
BlackList: []string{},
},
{
Name: Session,
Pattern: `session\.db$`,
Pattern: `^session\.db(-wal|-shm)?$`,
BlackList: []string{},
},
{
Name: Media,
Pattern: `^hardlink\.db$`,
Pattern: `^hardlink\.db(-wal|-shm)?$`,
BlackList: []string{},
},
{
Name: Voice,
Pattern: `^media_([0-9]?[0-9])?\.db$`,
Pattern: `^media_([0-9]?[0-9])?\.db(-wal|-shm)?$`,
BlackList: []string{},
},
{
Name: SNS,
Pattern: `^sns\.db$`,
Pattern: `^sns\.db(-wal|-shm)?$`,
BlackList: []string{},
},
}
@@ -79,11 +79,11 @@ type DataSource struct {
messageInfos []MessageDBInfo
}
func New(path string) (*DataSource, error) {
func New(path string, walEnabled bool) (*DataSource, error) {
ds := &DataSource{
path: path,
dbm: dbm.NewDBManager(path),
dbm: dbm.NewDBManager(path, walEnabled),
messageInfos: make([]MessageDBInfo, 0),
}

View File

@@ -13,19 +13,21 @@ import (
)
type DB struct {
path string
platform string
version int
ds datasource.DataSource
repo *repository.Repository
path string
platform string
version int
walEnabled bool
ds datasource.DataSource
repo *repository.Repository
}
func New(path string, platform string, version int) (*DB, error) {
func New(path string, platform string, version int, walEnabled bool) (*DB, error) {
w := &DB{
path: path,
platform: platform,
version: version,
path: path,
platform: platform,
version: version,
walEnabled: walEnabled,
}
// 初始化,加载数据库文件信息
@@ -45,7 +47,7 @@ func (w *DB) Close() error {
func (w *DB) Initialize() error {
var err error
w.ds, err = datasource.New(w.path, w.platform, w.version)
w.ds, err = datasource.New(w.path, w.platform, w.version, w.walEnabled)
if err != nil {
return err
}