diff --git a/.claude/settings.local.json b/.claude/settings.local.json deleted file mode 100644 index 17b48e8..0000000 --- a/.claude/settings.local.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "permissions": { - "allow": [ - "Bash(go build:*)", - "Bash(python:*)", - "Bash(go run:*)", - "Bash(git add:*)", - "Bash(git commit:*)", - "Bash(find:*)" - ] - } -} diff --git a/README.md b/README.md index 6022bee..746bdf7 100644 --- a/README.md +++ b/README.md @@ -16,6 +16,11 @@ ## 更新日志 +### 2026年1月22日 +- **自动解密增量优化(实验性)**: + - **WAL 增量写入**:开启 WAL 后,首次全量解密到工作目录,后续监听 data dir 的 WAL/SHM 并增量写入工作目录数据库。 + - **实验性提示**:该功能目前为实验性,遇到异常会回退全量解密以保证一致性。 + ### 2026年1月19日 - **消息内容解压缩支持**: - **ZSTD 解压缩**:为消息表 (`message`) 的 `message_content` 字段添加 ZSTD 解压缩功能,现在可以正确显示被压缩存储的消息内容,确保消息以纯文本形式正常展示。 diff --git a/internal/chatlog/app.go b/internal/chatlog/app.go index 3dbdd37..9c21cd6 100644 --- a/internal/chatlog/app.go +++ b/internal/chatlog/app.go @@ -3,6 +3,7 @@ package chatlog import ( "fmt" "path/filepath" + "strconv" "time" "github.com/rs/zerolog/log" @@ -141,11 +142,13 @@ func (a *App) refresh() { case <-a.stopRefresh: return case <-tick.C: + var processErr error // 如果当前账号为空,尝试查找微信进程 if a.ctx.Current == nil { // 获取微信实例 - instances := a.m.wechat.GetWeChatInstances() - if len(instances) > 0 { + instances, err := a.m.wechat.GetWeChatInstancesWithError() + processErr = err + if err == nil && len(instances) > 0 { // 找到微信进程,设置第一个为当前账号 a.ctx.SwitchCurrent(instances[0]) log.Info().Msgf("检测到微信进程,PID: %d,已设置为当前账号", instances[0].PID) @@ -168,7 +171,11 @@ func (a *App) refresh() { } a.infoBar.UpdateAccount(a.ctx.Account) a.infoBar.UpdateBasicInfo(a.ctx.PID, a.ctx.FullVersion, a.ctx.ExePath) - a.infoBar.UpdateStatus(a.ctx.Status) + statusText := a.ctx.Status + if a.ctx.PID == 0 && processErr != nil { + statusText = fmt.Sprintf("[red]获取进程失败: %v[white]", processErr) + } + a.infoBar.UpdateStatus(statusText) a.infoBar.UpdateDataKey(a.ctx.DataKey) a.infoBar.UpdateImageKey(a.ctx.ImgKey) a.infoBar.UpdatePlatform(a.ctx.Platform) @@ -182,10 +189,19 @@ func (a *App) refresh() { } else { a.infoBar.UpdateHTTPServer("[未启动]") } + autoDecryptText := "[未开启]" if a.ctx.AutoDecrypt { - a.infoBar.UpdateAutoDecrypt("[green][已开启][white]") + if a.ctx.AutoDecryptDebounce > 0 { + autoDecryptText = fmt.Sprintf("[green][已开启][white] %dms", a.ctx.AutoDecryptDebounce) + } else { + autoDecryptText = "[green][已开启][white]" + } + } + a.infoBar.UpdateAutoDecrypt(autoDecryptText) + if a.ctx.WalEnabled { + a.infoBar.UpdateWal("[green][已启用][white]") } else { - a.infoBar.UpdateAutoDecrypt("[未开启]") + a.infoBar.UpdateWal("[未启用]") } // Update latest message in footer @@ -602,6 +618,16 @@ func (a *App) settingSelected(i *menu.Item) { description: "配置微信数据文件所在目录", action: a.settingDataDir, }, + { + name: "启用 WAL 支持", + description: "同步并监控 .db-wal/.db-shm 文件", + action: a.settingWalEnabled, + }, + { + name: "设置自动解密去抖", + description: "配置自动解密触发间隔(ms)", + action: a.settingAutoDecryptDebounce, + }, } subMenu := menu.NewSubMenu("设置") @@ -759,6 +785,80 @@ func (a *App) settingDataDir() { a.SetFocus(formView) } +func (a *App) settingWalEnabled() { + formView := form.NewForm("设置 WAL 支持") + + tempWalEnabled := a.ctx.WalEnabled + + formView.AddCheckbox("启用 WAL 支持", tempWalEnabled, func(checked bool) { + tempWalEnabled = checked + }) + + formView.AddButton("保存", func() { + a.ctx.SetWalEnabled(tempWalEnabled) + a.mainPages.RemovePage("submenu2") + if tempWalEnabled { + a.showInfo("WAL 支持已开启") + } else { + a.showInfo("WAL 支持已关闭") + } + }) + + formView.AddButton("取消", func() { + a.mainPages.RemovePage("submenu2") + }) + + a.mainPages.AddPage("submenu2", formView, true, true) + a.SetFocus(formView) +} + +func (a *App) settingAutoDecryptDebounce() { + formView := form.NewForm("设置自动解密去抖") + + tempDebounceText := "" + if a.ctx.AutoDecryptDebounce > 0 { + tempDebounceText = strconv.Itoa(a.ctx.AutoDecryptDebounce) + } + + formView.AddInputField("去抖时长(ms)", tempDebounceText, 0, func(textToCheck string, lastChar rune) bool { + if textToCheck == "" { + return true + } + for _, r := range textToCheck { + if r < '0' || r > '9' { + return false + } + } + return true + }, func(text string) { + tempDebounceText = text + }) + + formView.AddButton("保存", func() { + if tempDebounceText == "" { + a.ctx.SetAutoDecryptDebounce(0) + a.mainPages.RemovePage("submenu2") + a.showInfo("已恢复默认去抖时长") + return + } + value, err := strconv.Atoi(tempDebounceText) + if err != nil { + a.showError(fmt.Errorf("去抖时长必须为数字")) + return + } + a.ctx.SetAutoDecryptDebounce(value) + a.mainPages.RemovePage("submenu2") + a.showInfo(fmt.Sprintf("去抖时长已设置为 %dms", value)) + }) + + formView.AddButton("取消", func() { + a.mainPages.RemovePage("submenu2") + }) + + a.mainPages.AddPage("submenu2", formView, true, true) + a.SetFocus(formView) +} + // selectAccountSelected 处理切换账号菜单项的选择事件 func (a *App) selectAccountSelected(i *menu.Item) { // 创建子菜单 diff --git a/internal/chatlog/conf/server.go b/internal/chatlog/conf/server.go index 5c52a13..085ac7f 100644 --- a/internal/chatlog/conf/server.go +++ b/internal/chatlog/conf/server.go @@ -15,6 +15,8 @@ type ServerConfig struct { WorkDir string `mapstructure:"work_dir"` HTTPAddr string `mapstructure:"http_addr"` AutoDecrypt bool `mapstructure:"auto_decrypt"` + WalEnabled bool `mapstructure:"wal_enabled"` + AutoDecryptDebounce int `mapstructure:"auto_decrypt_debounce"` SaveDecryptedMedia bool `mapstructure:"save_decrypted_media"` Webhook *Webhook `mapstructure:"webhook"` } @@ -51,6 +53,14 @@ func (c *ServerConfig) GetAutoDecrypt() bool { return c.AutoDecrypt } +func (c *ServerConfig) GetWalEnabled() bool { + return c.WalEnabled +} + +func (c *ServerConfig) GetAutoDecryptDebounce() int { + return c.AutoDecryptDebounce +} + func (c *ServerConfig) GetHTTPAddr() string { if c.HTTPAddr == "" { c.HTTPAddr = DefalutHTTPAddr diff --git a/internal/chatlog/conf/tui.go b/internal/chatlog/conf/tui.go index 2d54fb4..860de6f 100644 --- a/internal/chatlog/conf/tui.go +++ b/internal/chatlog/conf/tui.go @@ -21,6 +21,8 @@ type ProcessConfig struct { WorkDir string `mapstructure:"work_dir" json:"work_dir"` HTTPEnabled bool `mapstructure:"http_enabled" json:"http_enabled"` HTTPAddr string `mapstructure:"http_addr" json:"http_addr"` + WalEnabled bool `mapstructure:"wal_enabled" json:"wal_enabled"` + AutoDecryptDebounce int `mapstructure:"auto_decrypt_debounce" json:"auto_decrypt_debounce"` LastTime int64 `mapstructure:"last_time" json:"last_time"` Files []File `mapstructure:"files" json:"files"` } diff --git a/internal/chatlog/ctx/context.go b/internal/chatlog/ctx/context.go index 30f0f70..f239525 100644 --- a/internal/chatlog/ctx/context.go +++ b/internal/chatlog/ctx/context.go @@ -50,6 +50,8 @@ type Context struct { // 自动解密 AutoDecrypt bool LastSession time.Time + WalEnabled bool + AutoDecryptDebounce int // 当前选中的微信实例 Current *wechat.Account @@ -103,6 +105,8 @@ func (c *Context) SwitchHistory(account string) { c.WorkDir = history.WorkDir c.HTTPEnabled = history.HTTPEnabled c.HTTPAddr = history.HTTPAddr + c.WalEnabled = history.WalEnabled + c.AutoDecryptDebounce = history.AutoDecryptDebounce } else { c.Account = "" c.Platform = "" @@ -114,6 +118,8 @@ func (c *Context) SwitchHistory(account string) { c.WorkDir = "" c.HTTPEnabled = false c.HTTPAddr = "" + c.WalEnabled = false + c.AutoDecryptDebounce = 0 } } @@ -182,6 +188,18 @@ func (c *Context) GetDataKey() string { return c.DataKey } +func (c *Context) GetWalEnabled() bool { + c.mu.RLock() + defer c.mu.RUnlock() + return c.WalEnabled +} + +func (c *Context) GetAutoDecryptDebounce() int { + c.mu.RLock() + defer c.mu.RUnlock() + return c.AutoDecryptDebounce +} + func (c *Context) GetHTTPAddr() string { if c.HTTPAddr == "" { c.HTTPAddr = DefalutHTTPAddr @@ -260,6 +278,26 @@ func (c *Context) SetAutoDecrypt(enabled bool) { c.UpdateConfig() } +func (c *Context) SetWalEnabled(enabled bool) { + c.mu.Lock() + defer c.mu.Unlock() + if c.WalEnabled == enabled { + return + } + c.WalEnabled = enabled + c.UpdateConfig() +} + +func (c *Context) SetAutoDecryptDebounce(debounce int) { + c.mu.Lock() + defer c.mu.Unlock() + if c.AutoDecryptDebounce == debounce { + return + } + c.AutoDecryptDebounce = debounce + c.UpdateConfig() +} + // 更新配置 func (c *Context) UpdateConfig() { @@ -275,6 +313,8 @@ func (c *Context) UpdateConfig() { WorkDir: c.WorkDir, HTTPEnabled: c.HTTPEnabled, HTTPAddr: c.HTTPAddr, + WalEnabled: c.WalEnabled, + AutoDecryptDebounce: c.AutoDecryptDebounce, } if c.conf.History == nil { diff --git a/internal/chatlog/database/service.go b/internal/chatlog/database/service.go index 1adb755..54b5668 100644 --- a/internal/chatlog/database/service.go +++ b/internal/chatlog/database/service.go @@ -33,6 +33,7 @@ type Config interface { GetPlatform() string GetVersion() int GetWebhook() *conf.Webhook + GetWalEnabled() bool } func NewService(conf Config) *Service { @@ -43,7 +44,7 @@ func NewService(conf Config) *Service { } func (s *Service) Start() error { - db, err := wechatdb.New(s.conf.GetWorkDir(), s.conf.GetPlatform(), s.conf.GetVersion()) + db, err := wechatdb.New(s.conf.GetWorkDir(), s.conf.GetPlatform(), s.conf.GetVersion(), s.conf.GetWalEnabled()) if err != nil { return err } diff --git a/internal/chatlog/wechat/service.go b/internal/chatlog/wechat/service.go index 465abfd..680f13e 100644 --- a/internal/chatlog/wechat/service.go +++ b/internal/chatlog/wechat/service.go @@ -2,9 +2,14 @@ package wechat import ( "context" + "encoding/binary" + "encoding/hex" "fmt" + "io" "os" "path/filepath" + "sort" + "strings" "sync" "time" @@ -14,6 +19,7 @@ import ( "github.com/sjzar/chatlog/internal/errors" "github.com/sjzar/chatlog/internal/wechat" "github.com/sjzar/chatlog/internal/wechat/decrypt" + "github.com/sjzar/chatlog/internal/wechat/decrypt/common" "github.com/sjzar/chatlog/pkg/filemonitor" "github.com/sjzar/chatlog/pkg/util" ) @@ -27,17 +33,37 @@ type Service struct { conf Config lastEvents map[string]time.Time pendingActions map[string]bool + pendingEvents map[string]*pendingEvent + walStates map[string]*walState mutex sync.Mutex fm *filemonitor.FileMonitor errorHandler func(error) } +type pendingEvent struct { + sawDB bool + sawWal bool +} + +type walState struct { + offset int64 + salt1 uint32 + salt2 uint32 +} + +type walFrame struct { + pageNo uint32 + data []byte +} + type Config interface { GetDataKey() string GetDataDir() string GetWorkDir() string GetPlatform() string GetVersion() int + GetWalEnabled() bool + GetAutoDecryptDebounce() int } func NewService(conf Config) *Service { @@ -45,6 +71,8 @@ func NewService(conf Config) *Service { conf: conf, lastEvents: make(map[string]time.Time), pendingActions: make(map[string]bool), + pendingEvents: make(map[string]*pendingEvent), + walStates: make(map[string]*walState), } } @@ -55,8 +83,15 @@ func (s *Service) SetAutoDecryptErrorHandler(handler func(error)) { // GetWeChatInstances returns all running WeChat instances func (s *Service) GetWeChatInstances() []*wechat.Account { - wechat.Load() - return wechat.GetAccounts() + instances, _ := s.GetWeChatInstancesWithError() + return instances +} + +func (s *Service) GetWeChatInstancesWithError() ([]*wechat.Account, error) { + if err := wechat.Load(); err != nil { + return nil, err + } + return wechat.GetAccounts(), nil } // GetDataKey extracts the encryption key from a WeChat process @@ -84,7 +119,11 @@ func (s *Service) GetImageKey(info *wechat.Account) (string, error) { func (s *Service) StartAutoDecrypt() error { log.Info().Msgf("start auto decrypt, data dir: %s", s.conf.GetDataDir()) - dbGroup, err := filemonitor.NewFileGroup("wechat", s.conf.GetDataDir(), `.*\.db$`, []string{"fts"}) + pattern := `.*\.db$` + if s.conf.GetWalEnabled() { + pattern = `.*\.db(-wal|-shm)?$` + } + dbGroup, err := filemonitor.NewFileGroup("wechat", s.conf.GetDataDir(), pattern, []string{"fts"}) if err != nil { return err } @@ -122,13 +161,25 @@ func (s *Service) DecryptFileCallback(event fsnotify.Event) error { return nil } + dbFile := s.normalizeDBFile(event.Name) + isWal := isWalFile(event.Name) s.mutex.Lock() - s.lastEvents[event.Name] = time.Now() + s.lastEvents[dbFile] = time.Now() + flags, ok := s.pendingEvents[dbFile] + if !ok { + flags = &pendingEvent{} + s.pendingEvents[dbFile] = flags + } + if isWal { + flags.sawWal = true + } else { + flags.sawDB = true + } - if !s.pendingActions[event.Name] { - s.pendingActions[event.Name] = true + if !s.pendingActions[dbFile] { + s.pendingActions[dbFile] = true s.mutex.Unlock() - go s.waitAndProcess(event.Name) + go s.waitAndProcess(dbFile) } else { s.mutex.Unlock() } @@ -139,21 +190,85 @@ func (s *Service) DecryptFileCallback(event fsnotify.Event) error { func (s *Service) waitAndProcess(dbFile string) { start := time.Now() for { - time.Sleep(DebounceTime) + debounce := s.getDebounceTimeForFile(dbFile) + maxWait := s.getMaxWaitTimeForFile(dbFile) + time.Sleep(debounce) s.mutex.Lock() lastEventTime := s.lastEvents[dbFile] elapsed := time.Since(lastEventTime) totalElapsed := time.Since(start) - if elapsed >= DebounceTime || totalElapsed >= MaxWaitTime { + if elapsed >= debounce || totalElapsed >= maxWait { s.pendingActions[dbFile] = false + flags := pendingEvent{} + if state, ok := s.pendingEvents[dbFile]; ok && state != nil { + flags = *state + } + s.pendingEvents[dbFile] = &pendingEvent{} s.mutex.Unlock() + if _, err := os.Stat(dbFile); err != nil { + return + } log.Debug().Msgf("Processing file: %s", dbFile) - if err := s.DecryptDBFile(dbFile); err != nil { - if s.errorHandler != nil { - s.errorHandler(err) + workCopyExists := false + if s.conf.GetWorkDir() != "" { + if relPath, err := filepath.Rel(s.conf.GetDataDir(), dbFile); err == nil { + output := filepath.Join(s.conf.GetWorkDir(), relPath) + if _, err := os.Stat(output); err == nil { + workCopyExists = true + } + } + } + if flags.sawDB { + if !s.conf.GetWalEnabled() || !workCopyExists { + if err := s.DecryptDBFile(dbFile); err != nil { + if s.errorHandler != nil { + s.errorHandler(err) + } + } + return + } + if flags.sawWal { + handled, err := s.IncrementalDecryptDBFile(dbFile) + if err != nil { + if s.errorHandler != nil { + s.errorHandler(err) + } + return + } + if handled { + return + } + } + return + } + if flags.sawWal && s.conf.GetWalEnabled() { + handled, err := s.IncrementalDecryptDBFile(dbFile) + if err != nil { + if s.errorHandler != nil { + s.errorHandler(err) + } + return + } + if handled { + return + } + if !workCopyExists { + if err := s.DecryptDBFile(dbFile); err != nil { + if s.errorHandler != nil { + s.errorHandler(err) + } + } + } + return + } + if !s.conf.GetWalEnabled() || !workCopyExists { + if err := s.DecryptDBFile(dbFile); err != nil { + if s.errorHandler != nil { + s.errorHandler(err) + } } } return @@ -195,6 +310,11 @@ func (s *Service) DecryptDBFile(dbFile string) error { if data, err := os.ReadFile(dbFile); err == nil { outputFile.Write(data) } + if s.conf.GetWalEnabled() { + if err := s.syncWalFiles(dbFile, output); err != nil { + log.Debug().Err(err).Msgf("failed to sync wal files for %s", dbFile) + } + } return nil } log.Err(err).Msgf("failed to decrypt %s", dbFile) @@ -203,9 +323,147 @@ func (s *Service) DecryptDBFile(dbFile string) error { log.Debug().Msgf("Decrypted %s to %s", dbFile, output) + if s.conf.GetWalEnabled() { + if err := s.syncWalFiles(dbFile, output); err != nil { + log.Debug().Err(err).Msgf("failed to sync wal files for %s", dbFile) + } + } + return nil } +func (s *Service) getDebounceTime() time.Duration { + debounce := s.conf.GetAutoDecryptDebounce() + if debounce <= 0 { + return DebounceTime + } + return time.Duration(debounce) * time.Millisecond +} + +func (s *Service) getMaxWaitTime() time.Duration { + if !s.conf.GetWalEnabled() { + return MaxWaitTime + } + debounce := s.getDebounceTime() + maxWait := 2 * debounce + if maxWait < time.Second { + return time.Second + } + if maxWait > 3*time.Second { + return 3 * time.Second + } + return maxWait +} + +func (s *Service) getDebounceTimeForFile(dbFile string) time.Duration { + debounce := s.getDebounceTime() + if !s.conf.GetWalEnabled() { + return debounce + } + if isRealtimeDBFile(dbFile) { + if debounce > 300*time.Millisecond { + return 300 * time.Millisecond + } + } + return debounce +} + +func (s *Service) getMaxWaitTimeForFile(dbFile string) time.Duration { + if !s.conf.GetWalEnabled() { + return s.getMaxWaitTime() + } + if isRealtimeDBFile(dbFile) { + debounce := s.getDebounceTimeForFile(dbFile) + maxWait := 2 * debounce + if maxWait > time.Second { + return time.Second + } + return maxWait + } + return s.getMaxWaitTime() +} + +func isRealtimeDBFile(dbFile string) bool { + base := filepath.Base(dbFile) + if base == "session.db" { + return true + } + return strings.HasPrefix(base, "message_") && strings.HasSuffix(base, ".db") +} + +func (s *Service) normalizeDBFile(path string) string { + if strings.HasSuffix(path, ".db-wal") { + return strings.TrimSuffix(path, "-wal") + } + if strings.HasSuffix(path, ".db-shm") { + return strings.TrimSuffix(path, "-shm") + } + return path +} + +func isWalFile(path string) bool { + return strings.HasSuffix(path, ".db-wal") || strings.HasSuffix(path, ".db-shm") +} + +func (s *Service) syncWalFiles(dbFile, output string) error { + walSrc := dbFile + "-wal" + walDst := output + "-wal" + if err := syncAuxFile(walSrc, walDst); err != nil { + return err + } + shmSrc := dbFile + "-shm" + shmDst := output + "-shm" + if err := syncAuxFile(shmSrc, shmDst); err != nil { + return err + } + return nil +} + +func syncAuxFile(src, dst string) error { + if _, err := os.Stat(src); err != nil { + if os.IsNotExist(err) { + if err := os.Remove(dst); err != nil && !os.IsNotExist(err) { + return err + } + return nil + } + return err + } + if err := util.PrepareDir(filepath.Dir(dst)); err != nil { + return err + } + return copyFileAtomic(src, dst) +} + +func copyFileAtomic(src, dst string) error { + input, err := os.Open(src) + if err != nil { + return err + } + defer input.Close() + + temp := dst + ".tmp" + output, err := os.Create(temp) + if err != nil { + return err + } + if _, err := io.Copy(output, input); err != nil { + output.Close() + os.Remove(temp) + return err + } + if err := output.Sync(); err != nil { + output.Close() + os.Remove(temp) + return err + } + if err := output.Close(); err != nil { + os.Remove(temp) + return err + } + return os.Rename(temp, dst) +} + func (s *Service) DecryptDBFiles() error { dbGroup, err := filemonitor.NewFileGroup("wechat", s.conf.GetDataDir(), `.*\.db$`, []string{"fts"}) if err != nil { @@ -216,6 +474,14 @@ func (s *Service) DecryptDBFiles() error { if err != nil { return err } + sort.SliceStable(dbFiles, func(i, j int) bool { + pi := dbFilePriority(dbFiles[i]) + pj := dbFilePriority(dbFiles[j]) + if pi != pj { + return pi < pj + } + return filepath.Base(dbFiles[i]) < filepath.Base(dbFiles[j]) + }) var lastErr error failCount := 0 @@ -235,3 +501,243 @@ func (s *Service) DecryptDBFiles() error { return nil } + +func dbFilePriority(path string) int { + base := filepath.Base(path) + if strings.HasPrefix(base, "message_") && strings.HasSuffix(base, ".db") { + return 0 + } + if base == "session.db" { + return 1 + } + return 2 +} + +func (s *Service) IncrementalDecryptDBFile(dbFile string) (bool, error) { + if !s.conf.GetWalEnabled() { + return false, nil + } + walPath := dbFile + "-wal" + if _, err := os.Stat(walPath); err != nil { + if os.IsNotExist(err) { + return false, nil + } + return false, err + } + relPath, err := filepath.Rel(s.conf.GetDataDir(), dbFile) + if err != nil { + return false, fmt.Errorf("failed to get relative path for %s: %w", dbFile, err) + } + output := filepath.Join(s.conf.GetWorkDir(), relPath) + if _, err := os.Stat(output); err != nil { + if os.IsNotExist(err) { + return false, nil + } + return false, err + } + + decryptor, err := decrypt.NewDecryptor(s.conf.GetPlatform(), s.conf.GetVersion()) + if err != nil { + return true, err + } + + dbInfo, err := common.OpenDBFile(dbFile, decryptor.GetPageSize()) + if err != nil { + if err == errors.ErrAlreadyDecrypted { + return false, nil + } + return true, err + } + + keyBytes, err := hex.DecodeString(s.conf.GetDataKey()) + if err != nil { + return true, errors.DecodeKeyFailed(err) + } + if !decryptor.Validate(dbInfo.FirstPage, keyBytes) { + return true, errors.ErrDecryptIncorrectKey + } + + encKey, macKey, err := decryptor.DeriveKeys(keyBytes, dbInfo.Salt) + if err != nil { + return true, err + } + + walFile, err := os.Open(walPath) + if err != nil { + return true, err + } + defer walFile.Close() + + info, err := walFile.Stat() + if err != nil { + return true, err + } + if info.Size() < walHeaderSize { + return false, nil + } + + headerBuf := make([]byte, walHeaderSize) + if _, err := io.ReadFull(walFile, headerBuf); err != nil { + return true, err + } + order, pageSize, salt1, salt2, err := parseWalHeader(headerBuf) + if err != nil { + return true, err + } + if pageSize != 0 && pageSize != uint32(decryptor.GetPageSize()) { + return true, fmt.Errorf("unexpected wal page size: %d", pageSize) + } + + s.mutex.Lock() + state := s.walStates[dbFile] + if state != nil && (state.salt1 != salt1 || state.salt2 != salt2 || info.Size() < state.offset) { + delete(s.walStates, dbFile) + state = nil + } + startOffset := int64(walHeaderSize) + if state != nil && state.offset > startOffset { + startOffset = state.offset + } + s.mutex.Unlock() + + if _, err := walFile.Seek(startOffset, io.SeekStart); err != nil { + return true, err + } + + outputFile, err := os.OpenFile(output, os.O_RDWR, 0) + if err != nil { + return true, err + } + defer outputFile.Close() + + frameHeader := make([]byte, walFrameHeaderSize) + pageBuf := make([]byte, decryptor.GetPageSize()) + txFrames := make([]walFrame, 0) + var lastCommitOffset int64 + var applied bool + curOffset := startOffset + + for curOffset+int64(walFrameHeaderSize)+int64(decryptor.GetPageSize()) <= info.Size() { + if _, err := io.ReadFull(walFile, frameHeader); err != nil { + break + } + curOffset += int64(walFrameHeaderSize) + + frameSalt1 := order.Uint32(frameHeader[8:12]) + frameSalt2 := order.Uint32(frameHeader[12:16]) + if frameSalt1 != salt1 || frameSalt2 != salt2 { + s.mutex.Lock() + delete(s.walStates, dbFile) + s.mutex.Unlock() + return false, nil + } + + if _, err := io.ReadFull(walFile, pageBuf); err != nil { + break + } + curOffset += int64(decryptor.GetPageSize()) + + pageNo := order.Uint32(frameHeader[0:4]) + commit := order.Uint32(frameHeader[4:8]) + data := make([]byte, len(pageBuf)) + copy(data, pageBuf) + txFrames = append(txFrames, walFrame{pageNo: pageNo, data: data}) + + if commit != 0 { + if err := applyWalFrames(outputFile, txFrames, decryptor, encKey, macKey); err != nil { + return true, err + } + txFrames = txFrames[:0] + lastCommitOffset = curOffset + applied = true + } + } + + if lastCommitOffset > 0 { + s.mutex.Lock() + s.walStates[dbFile] = &walState{ + offset: lastCommitOffset, + salt1: salt1, + salt2: salt2, + } + s.mutex.Unlock() + } + + if err := s.syncWalFiles(dbFile, output); err != nil { + return true, err + } + + if applied { + return true, nil + } + return true, nil +} + +func parseWalHeader(buf []byte) (binary.ByteOrder, uint32, uint32, uint32, error) { + if len(buf) < walHeaderSize { + return nil, 0, 0, 0, fmt.Errorf("wal header too short") + } + magic := binary.BigEndian.Uint32(buf[0:4]) + var order binary.ByteOrder + switch magic { + case 0x377f0682: + order = binary.BigEndian + case 0x377f0683: + order = binary.LittleEndian + default: + return nil, 0, 0, 0, fmt.Errorf("invalid wal magic: %x", magic) + } + pageSize := order.Uint32(buf[8:12]) + salt1 := order.Uint32(buf[16:20]) + salt2 := order.Uint32(buf[20:24]) + if pageSize == 0 { + pageSize = 65536 + } + return order, pageSize, salt1, salt2, nil +} + +func applyWalFrames(output *os.File, frames []walFrame, decryptor decrypt.Decryptor, encKey, macKey []byte) error { + pageSize := decryptor.GetPageSize() + reserve := decryptor.GetReserve() + hmacSize := decryptor.GetHMACSize() + hashFunc := decryptor.GetHashFunc() + for _, frame := range frames { + pageNo := int64(frame.pageNo) - 1 + if pageNo < 0 { + continue + } + allZeros := true + for _, b := range frame.data { + if b != 0 { + allZeros = false + break + } + } + var pageData []byte + if allZeros { + pageData = frame.data + } else { + decrypted, err := common.DecryptPage(frame.data, encKey, macKey, pageNo, hashFunc, hmacSize, reserve, pageSize) + if err != nil { + return err + } + if pageNo == 0 { + fullPage := make([]byte, pageSize) + copy(fullPage, []byte(common.SQLiteHeader)) + copy(fullPage[len(common.SQLiteHeader):], decrypted) + pageData = fullPage + } else { + pageData = decrypted + } + } + if _, err := output.WriteAt(pageData, pageNo*int64(pageSize)); err != nil { + return err + } + } + return nil +} + +const ( + walHeaderSize = 32 + walFrameHeaderSize = 24 +) diff --git a/internal/ui/infobar/infobar.go b/internal/ui/infobar/infobar.go index 63f1f92..47cc0c2 100644 --- a/internal/ui/infobar/infobar.go +++ b/internal/ui/infobar/infobar.go @@ -157,6 +157,13 @@ func New() *InfoBar { ) table.SetCell(autoDecryptRow, valueCol1, tview.NewTableCell("")) + table.SetCell( + autoDecryptRow, + labelCol2, + tview.NewTableCell(fmt.Sprintf(" [%s::]%s", headerColor, "WAL:")), + ) + table.SetCell(autoDecryptRow, valueCol2, tview.NewTableCell("")) + // infobar infoBar := &InfoBar{ Box: tview.NewBox(), @@ -217,6 +224,10 @@ func (info *InfoBar) UpdateAutoDecrypt(text string) { info.table.GetCell(autoDecryptRow, valueCol1).SetText(text) } +func (info *InfoBar) UpdateWal(text string) { + info.table.GetCell(autoDecryptRow, valueCol2).SetText(text) +} + // Draw draws this primitive onto the screen. func (info *InfoBar) Draw(screen tcell.Screen) { info.Box.DrawForSubclass(screen, info) diff --git a/internal/wechat/decrypt/decryptor.go b/internal/wechat/decrypt/decryptor.go index daeaec8..5d6e0c4 100644 --- a/internal/wechat/decrypt/decryptor.go +++ b/internal/wechat/decrypt/decryptor.go @@ -2,6 +2,7 @@ package decrypt import ( "context" + "hash" "io" "github.com/sjzar/chatlog/internal/errors" @@ -25,6 +26,9 @@ type Decryptor interface { // GetHMACSize 返回HMAC大小 GetHMACSize() int + GetHashFunc() func() hash.Hash + DeriveKeys(key []byte, salt []byte) ([]byte, []byte, error) + // GetVersion 返回解密器版本 GetVersion() string } diff --git a/internal/wechat/decrypt/windows/v4.go b/internal/wechat/decrypt/windows/v4.go index 9bcfdc7..4e5f698 100644 --- a/internal/wechat/decrypt/windows/v4.go +++ b/internal/wechat/decrypt/windows/v4.go @@ -65,6 +65,14 @@ func (d *V4Decryptor) deriveKeys(key []byte, salt []byte) ([]byte, []byte) { return encKey, macKey } +func (d *V4Decryptor) DeriveKeys(key []byte, salt []byte) ([]byte, []byte, error) { + if len(key) != common.KeySize { + return nil, nil, errors.ErrKeyLengthMust32 + } + encKey, macKey := d.deriveKeys(key, salt) + return encKey, macKey, nil +} + // Validate 验证密钥是否有效 func (d *V4Decryptor) Validate(page1 []byte, key []byte) bool { if len(page1) < d.pageSize || len(key) != common.KeySize { @@ -183,6 +191,10 @@ func (d *V4Decryptor) GetHMACSize() int { return d.hmacSize } +func (d *V4Decryptor) GetHashFunc() func() hash.Hash { + return d.hashFunc +} + // GetVersion 返回解密器版本 func (d *V4Decryptor) GetVersion() string { return d.version diff --git a/internal/wechat/process/windows/detector.go b/internal/wechat/process/windows/detector.go index c4eb960..9320a81 100644 --- a/internal/wechat/process/windows/detector.go +++ b/internal/wechat/process/windows/detector.go @@ -39,17 +39,7 @@ func (d *Detector) FindProcesses() ([]*model.Process, error) { continue } - // v4 存在同名进程,需要继续判断 cmdline - if name == V4ProcessName { - cmdline, err := p.Cmdline() - if err != nil { - log.Err(err).Msg("获取进程命令行失败") - continue - } - if strings.Contains(cmdline, "--") { - continue - } - } + cmdline, cmdlineErr := p.Cmdline() // 获取进程信息 procInfo, err := d.getProcessInfo(p) @@ -58,6 +48,10 @@ func (d *Detector) FindProcesses() ([]*model.Process, error) { continue } + if cmdlineErr == nil && strings.Contains(cmdline, "--") && procInfo.DataDir == "" { + continue + } + result = append(result, procInfo) } diff --git a/internal/wechatdb/datasource/datasource.go b/internal/wechatdb/datasource/datasource.go index 1266404..fe32ad9 100644 --- a/internal/wechatdb/datasource/datasource.go +++ b/internal/wechatdb/datasource/datasource.go @@ -51,10 +51,10 @@ type DataSource interface { Close() error } -func New(path string, platform string, version int) (DataSource, error) { +func New(path string, platform string, version int, walEnabled bool) (DataSource, error) { switch { case platform == "windows" && version == 4: - return v4.New(path) + return v4.New(path, walEnabled) default: return nil, errors.PlatformUnsupported(platform, version) } diff --git a/internal/wechatdb/datasource/dbm/dbm.go b/internal/wechatdb/datasource/dbm/dbm.go index e61dcb8..77bc5b1 100644 --- a/internal/wechatdb/datasource/dbm/dbm.go +++ b/internal/wechatdb/datasource/dbm/dbm.go @@ -2,8 +2,11 @@ package dbm import ( "database/sql" + "io" + "os" "path/filepath" "runtime" + "strings" "sync" "time" @@ -17,23 +20,25 @@ import ( ) type DBManager struct { - path string - id string - fm *filemonitor.FileMonitor - fgs map[string]*filemonitor.FileGroup - dbs map[string]*sql.DB - dbPaths map[string][]string - mutex sync.RWMutex + path string + id string + walEnabled bool + fm *filemonitor.FileMonitor + fgs map[string]*filemonitor.FileGroup + dbs map[string]*sql.DB + dbPaths map[string][]string + mutex sync.RWMutex } -func NewDBManager(path string) *DBManager { +func NewDBManager(path string, walEnabled bool) *DBManager { return &DBManager{ - path: path, - id: filepath.Base(path), - fm: filemonitor.NewFileMonitor(), - fgs: make(map[string]*filemonitor.FileGroup), - dbs: make(map[string]*sql.DB), - dbPaths: make(map[string][]string), + path: path, + id: filepath.Base(path), + walEnabled: walEnabled, + fm: filemonitor.NewFileMonitor(), + fgs: make(map[string]*filemonitor.FileGroup), + dbs: make(map[string]*sql.DB), + dbPaths: make(map[string][]string), } } @@ -103,7 +108,7 @@ func (d *DBManager) GetDBPath(name string) ([]string, error) { if len(list) == 0 { return nil, errors.DBFileNotFound(d.path, fg.PatternStr, nil) } - dbPaths = list + dbPaths = filterPrimaryDBs(list) d.mutex.Lock() d.dbPaths[name] = dbPaths d.mutex.Unlock() @@ -126,6 +131,12 @@ func (d *DBManager) OpenDB(path string) (*sql.DB, error) { log.Err(err).Msgf("获取临时拷贝文件 %s 失败", path) return nil, err } + if d.walEnabled { + if err := d.syncWalFiles(path, tempPath); err != nil { + log.Err(err).Msgf("同步 WAL 文件失败: %s", path) + return nil, err + } + } } db, err = sql.Open("sqlite3", tempPath) if err != nil { @@ -139,19 +150,23 @@ func (d *DBManager) OpenDB(path string) (*sql.DB, error) { } func (d *DBManager) Callback(event fsnotify.Event) error { - if !event.Op.Has(fsnotify.Create) { + if !(event.Op.Has(fsnotify.Create) || event.Op.Has(fsnotify.Write) || event.Op.Has(fsnotify.Rename)) { return nil } + basePath := normalizeDBPath(event.Name) d.mutex.Lock() - db, ok := d.dbs[event.Name] + db, ok := d.dbs[basePath] if ok { - delete(d.dbs, event.Name) + delete(d.dbs, basePath) go func(db *sql.DB) { time.Sleep(time.Second * 5) db.Close() }(db) } + if (event.Op.Has(fsnotify.Create) || event.Op.Has(fsnotify.Rename)) && isPrimaryDBFile(event.Name) { + d.dbPaths = make(map[string][]string) + } d.mutex.Unlock() return nil @@ -171,3 +186,76 @@ func (d *DBManager) Close() error { } return d.fm.Stop() } + +func (d *DBManager) syncWalFiles(dbPath, tempPath string) error { + if err := syncAuxFile(dbPath+"-wal", tempPath+"-wal"); err != nil { + return err + } + if err := syncAuxFile(dbPath+"-shm", tempPath+"-shm"); err != nil { + return err + } + return nil +} + +func syncAuxFile(src, dst string) error { + if _, err := os.Stat(src); err != nil { + if os.IsNotExist(err) { + if err := os.Remove(dst); err != nil && !os.IsNotExist(err) { + return err + } + return nil + } + return err + } + return copyFileAtomic(src, dst) +} + +func copyFileAtomic(src, dst string) error { + input, err := os.Open(src) + if err != nil { + return err + } + defer input.Close() + + temp := dst + ".tmp" + output, err := os.Create(temp) + if err != nil { + return err + } + if _, err := io.Copy(output, input); err != nil { + output.Close() + os.Remove(temp) + return err + } + if err := output.Sync(); err != nil { + output.Close() + os.Remove(temp) + return err + } + if err := output.Close(); err != nil { + os.Remove(temp) + return err + } + return os.Rename(temp, dst) +} + +func normalizeDBPath(path string) string { + if strings.HasSuffix(path, "-wal") || strings.HasSuffix(path, "-shm") { + return strings.TrimSuffix(strings.TrimSuffix(path, "-wal"), "-shm") + } + return path +} + +func isPrimaryDBFile(path string) bool { + return strings.HasSuffix(path, ".db") && !strings.HasSuffix(path, ".db-wal") && !strings.HasSuffix(path, ".db-shm") +} + +func filterPrimaryDBs(paths []string) []string { + result := make([]string, 0, len(paths)) + for _, path := range paths { + if isPrimaryDBFile(path) { + result = append(result, path) + } + } + return result +} diff --git a/internal/wechatdb/datasource/dbm/dbm_test.go b/internal/wechatdb/datasource/dbm/dbm_test.go index 67aa9db..b6bcc59 100644 --- a/internal/wechatdb/datasource/dbm/dbm_test.go +++ b/internal/wechatdb/datasource/dbm/dbm_test.go @@ -15,7 +15,7 @@ func TestXxx(t *testing.T) { BlackList: []string{}, } - d := NewDBManager(path) + d := NewDBManager(path, false) d.AddGroup(g) d.Start() diff --git a/internal/wechatdb/datasource/v4/datasource.go b/internal/wechatdb/datasource/v4/datasource.go index 38e8bda..145dff9 100644 --- a/internal/wechatdb/datasource/v4/datasource.go +++ b/internal/wechatdb/datasource/v4/datasource.go @@ -34,32 +34,32 @@ const ( var Groups = []*dbm.Group{ { Name: Message, - Pattern: `^message_([0-9]?[0-9])?\.db$`, + Pattern: `^message_([0-9]?[0-9])?\.db(-wal|-shm)?$`, BlackList: []string{}, }, { Name: Contact, - Pattern: `^contact\.db$`, + Pattern: `^contact\.db(-wal|-shm)?$`, BlackList: []string{}, }, { Name: Session, - Pattern: `session\.db$`, + Pattern: `^session\.db(-wal|-shm)?$`, BlackList: []string{}, }, { Name: Media, - Pattern: `^hardlink\.db$`, + Pattern: `^hardlink\.db(-wal|-shm)?$`, BlackList: []string{}, }, { Name: Voice, - Pattern: `^media_([0-9]?[0-9])?\.db$`, + Pattern: `^media_([0-9]?[0-9])?\.db(-wal|-shm)?$`, BlackList: []string{}, }, { Name: SNS, - Pattern: `^sns\.db$`, + Pattern: `^sns\.db(-wal|-shm)?$`, BlackList: []string{}, }, } @@ -79,11 +79,11 @@ type DataSource struct { messageInfos []MessageDBInfo } -func New(path string) (*DataSource, error) { +func New(path string, walEnabled bool) (*DataSource, error) { ds := &DataSource{ path: path, - dbm: dbm.NewDBManager(path), + dbm: dbm.NewDBManager(path, walEnabled), messageInfos: make([]MessageDBInfo, 0), } diff --git a/internal/wechatdb/wechatdb.go b/internal/wechatdb/wechatdb.go index 916c5d3..dad9284 100644 --- a/internal/wechatdb/wechatdb.go +++ b/internal/wechatdb/wechatdb.go @@ -13,19 +13,21 @@ import ( ) type DB struct { - path string - platform string - version int - ds datasource.DataSource - repo *repository.Repository + path string + platform string + version int + walEnabled bool + ds datasource.DataSource + repo *repository.Repository } -func New(path string, platform string, version int) (*DB, error) { +func New(path string, platform string, version int, walEnabled bool) (*DB, error) { w := &DB{ - path: path, - platform: platform, - version: version, + path: path, + platform: platform, + version: version, + walEnabled: walEnabled, } // 初始化,加载数据库文件信息 @@ -45,7 +47,7 @@ func (w *DB) Close() error { func (w *DB) Initialize() error { var err error - w.ds, err = datasource.New(w.path, w.platform, w.version) + w.ds, err = datasource.New(w.path, w.platform, w.version, w.walEnabled) if err != nil { return err } diff --git a/pkg/process/process.go b/pkg/process/process.go index b04805e..3ed6b41 100644 --- a/pkg/process/process.go +++ b/pkg/process/process.go @@ -15,37 +15,40 @@ import ( // If another instance is found, it prompts the user to force close it. // Returns a cleanup function to be called on exit. func CheckSingleInstance(workDir string) (func(), error) { + if err := os.MkdirAll(workDir, 0755); err != nil { + return nil, fmt.Errorf("failed to create work dir: %w", err) + } pidFile := filepath.Join(workDir, "chatlog.pid") // Read existing PID file if content, err := os.ReadFile(pidFile); err == nil { pidStr := strings.TrimSpace(string(content)) if pid, err := strconv.Atoi(pidStr); err == nil { - // Check if process exists - if exists, _ := process.PidExists(int32(pid)); exists { - // Process exists, check if it's really us (optional, but good practice) - // For now, just assume if PID exists it might be us or a zombie. - // We can check process name if needed, but pid file is strong hint. - - fmt.Printf("Detected another instance running (PID: %d).\n", pid) - fmt.Print("Do you want to force close it and continue? [y/N]: ") - - reader := bufio.NewReader(os.Stdin) - input, _ := reader.ReadString('\n') - input = strings.TrimSpace(strings.ToLower(input)) + if pid != os.Getpid() { + if exists, _ := process.PidExists(int32(pid)); exists { + if isSameExecutable(pid) { + fmt.Printf("Detected another instance running (PID: %d).\n", pid) + fmt.Print("Do you want to force close it and continue? [y/N]: ") - if input == "y" || input == "yes" { - if p, err := process.NewProcess(int32(pid)); err == nil { - if err := p.Kill(); err != nil { - return nil, fmt.Errorf("failed to kill process: %w", err) + reader := bufio.NewReader(os.Stdin) + input, _ := reader.ReadString('\n') + input = strings.TrimSpace(strings.ToLower(input)) + + if input == "y" || input == "yes" { + if p, err := process.NewProcess(int32(pid)); err == nil { + if err := p.Kill(); err != nil { + return nil, fmt.Errorf("failed to kill process: %w", err) + } + fmt.Println("Process killed.") + } else { + fmt.Println("Process not found, continuing...") + } + } else { + return nil, fmt.Errorf("application already running") } - fmt.Println("Process killed.") } else { - // Process might have exited in the meantime - fmt.Println("Process not found, continuing...") + os.Remove(pidFile) } - } else { - return nil, fmt.Errorf("application already running") } } } @@ -63,3 +66,32 @@ func CheckSingleInstance(workDir string) (func(), error) { }, nil } + +func isSameExecutable(pid int) bool { + currentExe, err := os.Executable() + if err != nil { + currentExe = "" + } + currentBase := strings.ToLower(filepath.Base(currentExe)) + + p, err := process.NewProcess(int32(pid)) + if err != nil { + return false + } + if exe, err := p.Exe(); err == nil && exe != "" { + if currentExe != "" && strings.EqualFold(exe, currentExe) { + return true + } + if currentBase != "" && strings.EqualFold(filepath.Base(exe), currentBase) { + return true + } + } + if name, err := p.Name(); err == nil && name != "" { + name = strings.TrimSuffix(strings.ToLower(name), ".exe") + base := strings.TrimSuffix(currentBase, ".exe") + if base != "" && name == base { + return true + } + } + return false +}