mirror of
https://github.com/teest114514/chatlog_alpha.git
synced 2026-03-20 17:07:50 +08:00
同步本地代码
This commit is contained in:
25
pkg/appver/version.go
Normal file
25
pkg/appver/version.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package appver
|
||||
|
||||
type Info struct {
|
||||
FilePath string `json:"file_path"`
|
||||
CompanyName string `json:"company_name"`
|
||||
FileDescription string `json:"file_description"`
|
||||
Version int `json:"version"`
|
||||
FullVersion string `json:"full_version"`
|
||||
LegalCopyright string `json:"legal_copyright"`
|
||||
ProductName string `json:"product_name"`
|
||||
ProductVersion string `json:"product_version"`
|
||||
}
|
||||
|
||||
func New(filePath string) (*Info, error) {
|
||||
i := &Info{
|
||||
FilePath: filePath,
|
||||
}
|
||||
|
||||
err := i.initialize()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return i, nil
|
||||
}
|
||||
41
pkg/appver/version_darwin.go
Normal file
41
pkg/appver/version_darwin.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package appver
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"howett.net/plist"
|
||||
)
|
||||
|
||||
const (
|
||||
InfoFile = "Info.plist"
|
||||
)
|
||||
|
||||
type Plist struct {
|
||||
CFBundleShortVersionString string `plist:"CFBundleShortVersionString"`
|
||||
NSHumanReadableCopyright string `plist:"NSHumanReadableCopyright"`
|
||||
}
|
||||
|
||||
func (i *Info) initialize() error {
|
||||
|
||||
parts := strings.Split(i.FilePath, string(filepath.Separator))
|
||||
file := filepath.Join(append(parts[:len(parts)-2], InfoFile)...)
|
||||
b, err := os.ReadFile("/" + file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p := Plist{}
|
||||
_, err = plist.Unmarshal(b, &p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
i.FullVersion = p.CFBundleShortVersionString
|
||||
i.Version, _ = strconv.Atoi(strings.Split(i.FullVersion, ".")[0])
|
||||
i.CompanyName = p.NSHumanReadableCopyright
|
||||
|
||||
return nil
|
||||
}
|
||||
7
pkg/appver/version_others.go
Normal file
7
pkg/appver/version_others.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build !windows && !darwin
|
||||
|
||||
package appver
|
||||
|
||||
func (i *Info) initialize() error {
|
||||
return nil
|
||||
}
|
||||
142
pkg/appver/version_windows.go
Normal file
142
pkg/appver/version_windows.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package appver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
var (
|
||||
modversion = syscall.NewLazyDLL("version.dll")
|
||||
procGetFileVersionInfoSize = modversion.NewProc("GetFileVersionInfoSizeW")
|
||||
procGetFileVersionInfo = modversion.NewProc("GetFileVersionInfoW")
|
||||
procVerQueryValue = modversion.NewProc("VerQueryValueW")
|
||||
)
|
||||
|
||||
// VS_FIXEDFILEINFO 结构体
|
||||
type VS_FIXEDFILEINFO struct {
|
||||
Signature uint32
|
||||
StrucVersion uint32
|
||||
FileVersionMS uint32
|
||||
FileVersionLS uint32
|
||||
ProductVersionMS uint32
|
||||
ProductVersionLS uint32
|
||||
FileFlagsMask uint32
|
||||
FileFlags uint32
|
||||
FileOS uint32
|
||||
FileType uint32
|
||||
FileSubtype uint32
|
||||
FileDateMS uint32
|
||||
FileDateLS uint32
|
||||
}
|
||||
|
||||
// initialize 初始化版本信息
|
||||
func (i *Info) initialize() error {
|
||||
// 转换路径为 UTF16
|
||||
pathPtr, err := syscall.UTF16PtrFromString(i.FilePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取版本信息大小
|
||||
var handle uintptr
|
||||
size, _, err := procGetFileVersionInfoSize.Call(
|
||||
uintptr(unsafe.Pointer(pathPtr)),
|
||||
uintptr(unsafe.Pointer(&handle)),
|
||||
)
|
||||
if size == 0 {
|
||||
return fmt.Errorf("GetFileVersionInfoSize failed: %v", err)
|
||||
}
|
||||
|
||||
// 分配内存
|
||||
verInfo := make([]byte, size)
|
||||
ret, _, err := procGetFileVersionInfo.Call(
|
||||
uintptr(unsafe.Pointer(pathPtr)),
|
||||
0,
|
||||
size,
|
||||
uintptr(unsafe.Pointer(&verInfo[0])),
|
||||
)
|
||||
if ret == 0 {
|
||||
return fmt.Errorf("GetFileVersionInfo failed: %v", err)
|
||||
}
|
||||
|
||||
// 获取固定的文件信息
|
||||
var fixedFileInfo *VS_FIXEDFILEINFO
|
||||
var uLen uint32
|
||||
rootPtr, _ := syscall.UTF16PtrFromString("\\")
|
||||
ret, _, err = procVerQueryValue.Call(
|
||||
uintptr(unsafe.Pointer(&verInfo[0])),
|
||||
uintptr(unsafe.Pointer(rootPtr)),
|
||||
uintptr(unsafe.Pointer(&fixedFileInfo)),
|
||||
uintptr(unsafe.Pointer(&uLen)),
|
||||
)
|
||||
if ret == 0 {
|
||||
return fmt.Errorf("VerQueryValue failed: %v", err)
|
||||
}
|
||||
|
||||
// 解析文件版本
|
||||
i.FullVersion = fmt.Sprintf("%d.%d.%d.%d",
|
||||
(fixedFileInfo.FileVersionMS>>16)&0xffff,
|
||||
(fixedFileInfo.FileVersionMS>>0)&0xffff,
|
||||
(fixedFileInfo.FileVersionLS>>16)&0xffff,
|
||||
(fixedFileInfo.FileVersionLS>>0)&0xffff,
|
||||
)
|
||||
i.Version = int((fixedFileInfo.FileVersionMS >> 16) & 0xffff)
|
||||
|
||||
i.ProductVersion = fmt.Sprintf("%d.%d.%d.%d",
|
||||
(fixedFileInfo.ProductVersionMS>>16)&0xffff,
|
||||
(fixedFileInfo.ProductVersionMS>>0)&0xffff,
|
||||
(fixedFileInfo.ProductVersionLS>>16)&0xffff,
|
||||
(fixedFileInfo.ProductVersionLS>>0)&0xffff,
|
||||
)
|
||||
|
||||
// 获取翻译信息
|
||||
type langAndCodePage struct {
|
||||
language uint16
|
||||
codePage uint16
|
||||
}
|
||||
|
||||
var lpTranslate *langAndCodePage
|
||||
var cbTranslate uint32
|
||||
transPtr, _ := syscall.UTF16PtrFromString("\\VarFileInfo\\Translation")
|
||||
ret, _, _ = procVerQueryValue.Call(
|
||||
uintptr(unsafe.Pointer(&verInfo[0])),
|
||||
uintptr(unsafe.Pointer(transPtr)),
|
||||
uintptr(unsafe.Pointer(&lpTranslate)),
|
||||
uintptr(unsafe.Pointer(&cbTranslate)),
|
||||
)
|
||||
|
||||
if ret != 0 && cbTranslate > 0 {
|
||||
// 获取所有需要的字符串信息
|
||||
stringInfos := map[string]*string{
|
||||
"CompanyName": &i.CompanyName,
|
||||
"FileDescription": &i.FileDescription,
|
||||
"FileVersion": &i.FullVersion,
|
||||
"LegalCopyright": &i.LegalCopyright,
|
||||
"ProductName": &i.ProductName,
|
||||
"ProductVersion": &i.ProductVersion,
|
||||
}
|
||||
|
||||
for name, ptr := range stringInfos {
|
||||
subBlock := fmt.Sprintf("\\StringFileInfo\\%04x%04x\\%s",
|
||||
lpTranslate.language, lpTranslate.codePage, name)
|
||||
|
||||
subBlockPtr, _ := syscall.UTF16PtrFromString(subBlock)
|
||||
var buffer *uint16
|
||||
var bufLen uint32
|
||||
|
||||
ret, _, _ = procVerQueryValue.Call(
|
||||
uintptr(unsafe.Pointer(&verInfo[0])),
|
||||
uintptr(unsafe.Pointer(subBlockPtr)),
|
||||
uintptr(unsafe.Pointer(&buffer)),
|
||||
uintptr(unsafe.Pointer(&bufLen)),
|
||||
)
|
||||
|
||||
if ret != 0 && bufLen > 0 {
|
||||
*ptr = syscall.UTF16ToString((*[1 << 20]uint16)(unsafe.Pointer(buffer))[:bufLen:bufLen])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
159
pkg/config/config.go
Normal file
159
pkg/config/config.go
Normal file
@@ -0,0 +1,159 @@
|
||||
/*
|
||||
* Copyright (c) 2023 shenjunzheng@gmail.com
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultConfigType = "json"
|
||||
)
|
||||
|
||||
var (
|
||||
// ERROR
|
||||
ErrInvalidDirectory = errors.New("invalid directory path")
|
||||
ErrMissingConfigName = errors.New("config name not specified")
|
||||
)
|
||||
|
||||
type Manager struct {
|
||||
App string
|
||||
EnvPrefix string
|
||||
Path string
|
||||
Name string
|
||||
WriteConfig bool
|
||||
|
||||
Viper *viper.Viper
|
||||
}
|
||||
|
||||
// New initializes the configuration settings.
|
||||
// It sets up the name, type, and path for the configuration file.
|
||||
func New(app, path, name, envPrefix string, writeConfig bool) (*Manager, error) {
|
||||
if len(app) == 0 {
|
||||
return nil, ErrMissingConfigName
|
||||
}
|
||||
|
||||
v := viper.New()
|
||||
v.SetConfigType(DefaultConfigType)
|
||||
var err error
|
||||
|
||||
// Path
|
||||
if len(path) == 0 {
|
||||
path, err = os.UserHomeDir()
|
||||
if err != nil {
|
||||
path = os.TempDir()
|
||||
}
|
||||
path += string(os.PathSeparator) + "." + app
|
||||
}
|
||||
if err := PrepareDir(path); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
v.AddConfigPath(path)
|
||||
|
||||
// Name
|
||||
if len(name) == 0 {
|
||||
name = app
|
||||
}
|
||||
v.SetConfigName(name)
|
||||
|
||||
// Env
|
||||
if len(envPrefix) != 0 {
|
||||
v.SetEnvPrefix(strings.ToUpper(app))
|
||||
v.AutomaticEnv()
|
||||
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||||
}
|
||||
|
||||
return &Manager{
|
||||
App: app,
|
||||
EnvPrefix: envPrefix,
|
||||
Path: path,
|
||||
Name: name,
|
||||
Viper: v,
|
||||
WriteConfig: writeConfig,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Load loads the configuration from the previously initialized file.
|
||||
// It unmarshals the configuration into the provided conf interface.
|
||||
func (c *Manager) Load(conf interface{}) error {
|
||||
if err := c.Viper.ReadInConfig(); err != nil {
|
||||
log.Error().Err(err).Msg("read config failed")
|
||||
if c.WriteConfig {
|
||||
if err := c.Viper.SafeWriteConfig(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := c.Viper.Unmarshal(conf, decoderConfig()); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadFile loads the configuration from a specified file.
|
||||
// It unmarshals the configuration into the provided conf interface.
|
||||
func (c *Manager) LoadFile(file string, conf interface{}) error {
|
||||
c.Viper.SetConfigFile(file)
|
||||
if err := c.Viper.ReadInConfig(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.Viper.Unmarshal(conf, decoderConfig()); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetConfig sets a configuration key to a specified value.
|
||||
// It also writes the updated configuration back to the file.
|
||||
func (c *Manager) SetConfig(key string, value interface{}) error {
|
||||
c.Viper.Set(key, value)
|
||||
if c.WriteConfig {
|
||||
if err := c.Viper.WriteConfig(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetConfig retrieves all configuration settings as a map.
|
||||
func (c *Manager) GetConfig() map[string]interface{} {
|
||||
return c.Viper.AllSettings()
|
||||
}
|
||||
|
||||
// PrepareDir ensures that the specified directory path exists.
|
||||
// If the directory does not exist, it attempts to create it.
|
||||
func PrepareDir(path string) error {
|
||||
stat, err := os.Stat(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
if err := os.MkdirAll(path, 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
} else if !stat.IsDir() {
|
||||
log.Debug().Msgf("%s is not a directory", path)
|
||||
return ErrInvalidDirectory
|
||||
}
|
||||
return nil
|
||||
}
|
||||
33
pkg/config/default.go
Normal file
33
pkg/config/default.go
Normal file
@@ -0,0 +1,33 @@
|
||||
/*
|
||||
* Copyright (c) 2023 shenjunzheng@gmail.com
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package config
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
func SetDefaults(v *viper.Viper, c any, defaults map[string]any) {
|
||||
keys := GetStructKeys(reflect.TypeOf(c), "mapstructure", "squash")
|
||||
for _, key := range keys {
|
||||
v.SetDefault(key, nil)
|
||||
}
|
||||
for key := range defaults {
|
||||
v.SetDefault(key, defaults[key])
|
||||
}
|
||||
}
|
||||
149
pkg/config/struct_kerys.go
Normal file
149
pkg/config/struct_kerys.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const sep = "."
|
||||
|
||||
// GetStructKeys returns all keys in a nested struct type, taking the name from the tag name or
|
||||
// the field name. It handles an additional suffix squashValue like mapstructure does: if
|
||||
// present on an embedded struct, name components for that embedded struct should not be
|
||||
// included. It does not handle maps, does chase pointers, but does not check for loops in
|
||||
// nesting.
|
||||
func GetStructKeys(typ reflect.Type, tag, squashValue string) []string {
|
||||
return appendStructKeys(typ, tag, ","+squashValue, nil, nil)
|
||||
}
|
||||
|
||||
// appendStructKeys recursively appends to keys all keys of nested struct type typ, taking tag
|
||||
// and squashValue from GetStructKeys. prefix holds all components of the path from the typ
|
||||
// passed to GetStructKeys down to this typ.
|
||||
func appendStructKeys(typ reflect.Type, tag, squashValue string, prefix []string, keys []string) []string {
|
||||
// Dereference any pointers. This is a finite loop: Go types are well-founded.
|
||||
for ; typ.Kind() == reflect.Ptr; typ = typ.Elem() {
|
||||
}
|
||||
|
||||
// Handle only struct containers; terminate the recursion on anything else.
|
||||
if typ.Kind() != reflect.Struct {
|
||||
return append(keys, strings.Join(prefix, sep))
|
||||
}
|
||||
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
fieldType := typ.Field(i)
|
||||
var (
|
||||
// fieldName is the name to use for the field.
|
||||
fieldName string
|
||||
// If squash is true, squash the sub-struct no additional accessor.
|
||||
squash bool
|
||||
ok bool
|
||||
)
|
||||
if fieldName, ok = fieldType.Tag.Lookup(tag); ok {
|
||||
if strings.HasSuffix(fieldName, squashValue) {
|
||||
squash = true
|
||||
fieldName = strings.TrimSuffix(fieldName, squashValue)
|
||||
}
|
||||
} else {
|
||||
fieldName = strings.ToLower(fieldType.Name)
|
||||
}
|
||||
// Update prefix to recurse into this field.
|
||||
if !squash {
|
||||
prefix = append(prefix, fieldName)
|
||||
}
|
||||
keys = appendStructKeys(fieldType.Type, tag, squashValue, prefix, keys)
|
||||
// Restore prefix.
|
||||
if !squash {
|
||||
prefix = prefix[:len(prefix)-1]
|
||||
}
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
// ValidateMissingRequiredKeys returns all keys of value in GetStructKeys format that have an
|
||||
// additional required tag set but are unset.
|
||||
func ValidateMissingRequiredKeys(value interface{}, tag, squashValue string) []string {
|
||||
return appendStructKeysIfZero(reflect.ValueOf(value), tag, ","+squashValue, "validate", "required", nil, nil)
|
||||
}
|
||||
|
||||
// isScalar returns true if kind is "scalar", i.e. has no Elem(). This
|
||||
// recites the list from reflect/type.go and may start to give incorrect
|
||||
// results if new kinds are added to the language.
|
||||
func isScalar(kind reflect.Kind) bool {
|
||||
switch kind {
|
||||
case reflect.Array:
|
||||
case reflect.Chan:
|
||||
case reflect.Map:
|
||||
case reflect.Ptr:
|
||||
case reflect.Slice:
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func appendStructKeysIfZero(value reflect.Value, tag, squashValue, validateTag, requiredValue string, prefix []string, keys []string) []string {
|
||||
// finite loop: Go types are well-founded.
|
||||
for value.Kind() == reflect.Ptr {
|
||||
if value.IsZero() { // If required, would already have errored out.
|
||||
return keys
|
||||
}
|
||||
value = value.Elem()
|
||||
}
|
||||
|
||||
if !isScalar(value.Kind()) {
|
||||
// Why use Type().Elem() when reflect.Value provides a perfectly good Elem()
|
||||
// method? The two are *not* the same, e.g. for nil pointers value.Elem() is
|
||||
// invalid and has no Kind(). (See https://play.golang.org/p/M3ZV19AZAW0)
|
||||
if !isScalar(value.Type().Elem().Kind()) {
|
||||
// TODO(ariels): Possible to add, but need to define the semantics. One
|
||||
// way might be to validate each field according to its type.
|
||||
panic("No support for detecting required keys inside " + value.Kind().String() + " of structs")
|
||||
}
|
||||
}
|
||||
|
||||
// Handle only struct containers; terminate the recursion on anything else.
|
||||
if value.Kind() != reflect.Struct {
|
||||
return keys
|
||||
}
|
||||
|
||||
for i := 0; i < value.NumField(); i++ {
|
||||
fieldType := value.Type().Field(i)
|
||||
fieldValue := value.Field(i)
|
||||
|
||||
var (
|
||||
// fieldName is the name to use for the field.
|
||||
fieldName string
|
||||
// If squash is true, squash the sub-struct no additional accessor.
|
||||
squash bool
|
||||
ok bool
|
||||
)
|
||||
if fieldName, ok = fieldType.Tag.Lookup(tag); ok {
|
||||
if strings.HasSuffix(fieldName, squashValue) {
|
||||
squash = true
|
||||
fieldName = strings.TrimSuffix(fieldName, squashValue)
|
||||
}
|
||||
} else {
|
||||
fieldName = strings.ToLower(fieldType.Name)
|
||||
}
|
||||
|
||||
// Perform any needed validations.
|
||||
if validationsString, ok := fieldType.Tag.Lookup(validateTag); ok {
|
||||
for _, validation := range strings.Split(validationsString, ",") {
|
||||
// Validate "required" field.
|
||||
if validation == requiredValue && fieldValue.IsZero() {
|
||||
keys = append(keys, strings.Join(append(prefix, fieldName), sep))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update prefix to recurse into this field.
|
||||
if !squash {
|
||||
prefix = append(prefix, fieldName)
|
||||
}
|
||||
keys = appendStructKeysIfZero(fieldValue, tag, squashValue, validateTag, requiredValue, prefix, keys)
|
||||
// Restore prefix.
|
||||
if !squash {
|
||||
prefix = prefix[:len(prefix)-1]
|
||||
}
|
||||
}
|
||||
return keys
|
||||
}
|
||||
121
pkg/config/types.go
Normal file
121
pkg/config/types.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
// DecodeStringToMap returns a DecodeHookFunc that converts a string to a map[string]string.
|
||||
// The string is expected to be a comma-separated list of key-value pairs, where the key and value
|
||||
// are separated by an equal sign.
|
||||
func DecodeStringToMap() mapstructure.DecodeHookFunc {
|
||||
return func(f reflect.Kind, t reflect.Kind, data interface{}) (interface{}, error) {
|
||||
// check if field is a string and target is a map
|
||||
if f != reflect.String || t != reflect.Map {
|
||||
return data, nil
|
||||
}
|
||||
// check if target is map[string]string
|
||||
if t != reflect.TypeOf(map[string]string{}).Kind() {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
raw := data.(string)
|
||||
if raw == "" {
|
||||
return map[string]string{}, nil
|
||||
}
|
||||
// parse raw string as key1=value1,key2=value2
|
||||
const pairSep = ","
|
||||
const valueSep = "="
|
||||
pairs := strings.Split(raw, pairSep)
|
||||
m := make(map[string]string, len(pairs))
|
||||
for _, pair := range pairs {
|
||||
key, value, found := strings.Cut(pair, valueSep)
|
||||
if !found {
|
||||
return nil, fmt.Errorf("invalid key-value pair: %s", pair)
|
||||
}
|
||||
m[strings.TrimSpace(key)] = strings.TrimSpace(value)
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
}
|
||||
|
||||
// StringToSliceWithBracketHookFunc returns a DecodeHookFunc that converts a string to a slice of strings.
|
||||
// Useful when configuration values are provided as JSON arrays in string form, but need to be parsed into slices.
|
||||
// The string is expected to be a JSON array.
|
||||
// If the string is empty, an empty slice is returned.
|
||||
// If the string cannot be parsed as a JSON array, the original data is returned unchanged.
|
||||
func StringToSliceWithBracketHookFunc() mapstructure.DecodeHookFunc {
|
||||
return func(f reflect.Kind, t reflect.Kind, data interface{}) (interface{}, error) {
|
||||
if f != reflect.String || t != reflect.Slice {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
raw := data.(string)
|
||||
if raw == "" {
|
||||
return []string{}, nil
|
||||
}
|
||||
var result any
|
||||
err := json.Unmarshal([]byte(raw), &result)
|
||||
if err != nil {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// Verify that the result matches the target (slice)
|
||||
if reflect.TypeOf(result).Kind() != t {
|
||||
return data, nil
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
// StringToStructHookFunc returns a DecodeHookFunc that converts a string to a struct.
|
||||
// Useful for parsing configuration values that are provided as JSON strings but need to be converted to sturcts.
|
||||
// The string is expected to be a JSON object that can be unmarshaled into the target struct.
|
||||
// If the string is empty, a new instance of the target struct is returned.
|
||||
// If the string cannot be parsed as a JSON object, the original data is returned unchanged.
|
||||
func StringToStructHookFunc() mapstructure.DecodeHookFunc {
|
||||
return func(f reflect.Type, t reflect.Type, data interface{}) (interface{}, error) {
|
||||
if f.Kind() != reflect.String ||
|
||||
(t.Kind() != reflect.Struct && !(t.Kind() == reflect.Pointer && t.Elem().Kind() == reflect.Struct)) {
|
||||
return data, nil
|
||||
}
|
||||
raw := data.(string)
|
||||
var val reflect.Value
|
||||
// Struct or the pointer to a struct
|
||||
if t.Kind() == reflect.Struct {
|
||||
val = reflect.New(t)
|
||||
} else {
|
||||
val = reflect.New(t.Elem())
|
||||
}
|
||||
|
||||
if raw == "" {
|
||||
return val, nil
|
||||
}
|
||||
var m map[string]interface{}
|
||||
err := json.Unmarshal([]byte(raw), &m)
|
||||
if err != nil {
|
||||
return data, nil
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
}
|
||||
|
||||
// CompositeDecodeHook 组合所有解码钩子
|
||||
func CompositeDecodeHook() mapstructure.DecodeHookFunc {
|
||||
return mapstructure.ComposeDecodeHookFunc(
|
||||
mapstructure.StringToTimeDurationHookFunc(),
|
||||
DecodeStringToMap(),
|
||||
StringToStructHookFunc(),
|
||||
StringToSliceWithBracketHookFunc(),
|
||||
)
|
||||
}
|
||||
|
||||
func decoderConfig() viper.DecoderConfigOption {
|
||||
return viper.DecodeHook(CompositeDecodeHook())
|
||||
}
|
||||
1114
pkg/filecopy/filecopy.go
Normal file
1114
pkg/filecopy/filecopy.go
Normal file
File diff suppressed because it is too large
Load Diff
182
pkg/filemonitor/filegroup.go
Normal file
182
pkg/filemonitor/filegroup.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package filemonitor
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// FileChangeCallback defines the callback function signature for file change events
|
||||
type FileChangeCallback func(event fsnotify.Event) error
|
||||
|
||||
// FileGroup represents a group of files with the same processing logic
|
||||
type FileGroup struct {
|
||||
ID string // Unique identifier
|
||||
RootDir string // Root directory
|
||||
Pattern *regexp.Regexp // File matching pattern
|
||||
PatternStr string // Original pattern string for rebuilding
|
||||
Blacklist []string // Blacklist patterns
|
||||
Callbacks []FileChangeCallback // File change callbacks
|
||||
mutex sync.RWMutex // Concurrency control
|
||||
}
|
||||
|
||||
// NewFileGroup creates a new file group
|
||||
func NewFileGroup(id, rootDir, pattern string, blacklist []string) (*FileGroup, error) {
|
||||
// Compile the regular expression
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid pattern '%s': %w", pattern, err)
|
||||
}
|
||||
|
||||
// Normalize root directory path
|
||||
rootDir = filepath.Clean(rootDir)
|
||||
|
||||
return &FileGroup{
|
||||
ID: id,
|
||||
RootDir: rootDir,
|
||||
Pattern: re,
|
||||
PatternStr: pattern,
|
||||
Blacklist: blacklist,
|
||||
Callbacks: []FileChangeCallback{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AddCallback adds a callback function to the file group
|
||||
func (fg *FileGroup) AddCallback(callback FileChangeCallback) {
|
||||
fg.mutex.Lock()
|
||||
defer fg.mutex.Unlock()
|
||||
|
||||
fg.Callbacks = append(fg.Callbacks, callback)
|
||||
}
|
||||
|
||||
// RemoveCallback removes a callback function from the file group
|
||||
func (fg *FileGroup) RemoveCallback(callbackToRemove FileChangeCallback) bool {
|
||||
fg.mutex.Lock()
|
||||
defer fg.mutex.Unlock()
|
||||
|
||||
for i, callback := range fg.Callbacks {
|
||||
// Compare function addresses
|
||||
if fmt.Sprintf("%p", callback) == fmt.Sprintf("%p", callbackToRemove) {
|
||||
// Remove the callback
|
||||
fg.Callbacks = append(fg.Callbacks[:i], fg.Callbacks[i+1:]...)
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// Match checks if a file path matches this group's criteria
|
||||
func (fg *FileGroup) Match(path string) bool {
|
||||
// Normalize paths for comparison
|
||||
path = filepath.Clean(path)
|
||||
rootDir := filepath.Clean(fg.RootDir)
|
||||
|
||||
// Check if path is under root directory
|
||||
// Use filepath.Rel to handle path comparison safely across different OSes
|
||||
relPath, err := filepath.Rel(rootDir, path)
|
||||
if err != nil || strings.HasPrefix(relPath, "..") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if file matches pattern
|
||||
if !fg.Pattern.MatchString(filepath.Base(path)) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check blacklist
|
||||
for _, blackItem := range fg.Blacklist {
|
||||
if strings.Contains(relPath, blackItem) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// List returns a list of files in the group (real-time scan)
|
||||
func (fg *FileGroup) List() ([]string, error) {
|
||||
files := []string{}
|
||||
|
||||
// Scan directory for matching files using fs.WalkDir
|
||||
err := fs.WalkDir(os.DirFS(fg.RootDir), ".", func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return fs.SkipDir
|
||||
}
|
||||
|
||||
// Skip directories
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Convert relative path to absolute
|
||||
absPath := filepath.Join(fg.RootDir, path)
|
||||
|
||||
// Use Match function to check if file belongs to this group
|
||||
if fg.Match(absPath) {
|
||||
files = append(files, absPath)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||
return nil, fmt.Errorf("error listing files: %w", err)
|
||||
}
|
||||
|
||||
return files, nil
|
||||
}
|
||||
|
||||
// ListMatchingDirectories returns directories containing matching files
|
||||
func (fg *FileGroup) ListMatchingDirectories() (map[string]bool, error) {
|
||||
directories := make(map[string]bool)
|
||||
|
||||
// Get matching files
|
||||
files, err := fg.List()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Extract directories from matching files
|
||||
for _, file := range files {
|
||||
dir := filepath.Dir(file)
|
||||
directories[dir] = true
|
||||
}
|
||||
|
||||
return directories, nil
|
||||
}
|
||||
|
||||
// HandleEvent processes a file event and triggers callbacks if the file matches
|
||||
func (fg *FileGroup) HandleEvent(event fsnotify.Event) {
|
||||
// Check if this event is relevant for this group
|
||||
if !fg.Match(event.Name) {
|
||||
return
|
||||
}
|
||||
|
||||
// Get callbacks under read lock
|
||||
fg.mutex.RLock()
|
||||
callbacks := make([]FileChangeCallback, len(fg.Callbacks))
|
||||
copy(callbacks, fg.Callbacks)
|
||||
fg.mutex.RUnlock()
|
||||
|
||||
// Asynchronously call callbacks
|
||||
for _, callback := range callbacks {
|
||||
go func(cb FileChangeCallback) {
|
||||
if err := cb(event); err != nil {
|
||||
log.Error().
|
||||
Str("file", event.Name).
|
||||
Str("op", event.Op.String()).
|
||||
Err(err).
|
||||
Msg("Callback error")
|
||||
}
|
||||
}(callback)
|
||||
}
|
||||
}
|
||||
430
pkg/filemonitor/filemonitor.go
Normal file
430
pkg/filemonitor/filemonitor.go
Normal file
@@ -0,0 +1,430 @@
|
||||
package filemonitor
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// FileMonitor manages multiple file groups
|
||||
type FileMonitor struct {
|
||||
groups map[string]*FileGroup // Map of file groups
|
||||
watcher *fsnotify.Watcher // File system watcher
|
||||
watchDirs map[string]bool // Monitored directories
|
||||
blacklist []string // Global blacklist patterns
|
||||
mutex sync.RWMutex // Concurrency control for groups and watchDirs
|
||||
stopCh chan struct{} // Stop signal
|
||||
wg sync.WaitGroup // Wait group
|
||||
isRunning bool // Running state flag
|
||||
stateMutex sync.RWMutex // State mutex
|
||||
}
|
||||
|
||||
func (fm *FileMonitor) Watcher() *fsnotify.Watcher {
|
||||
return fm.watcher
|
||||
}
|
||||
|
||||
// NewFileMonitor creates a new file monitor
|
||||
func NewFileMonitor() *FileMonitor {
|
||||
return &FileMonitor{
|
||||
groups: make(map[string]*FileGroup),
|
||||
watchDirs: make(map[string]bool),
|
||||
blacklist: []string{},
|
||||
isRunning: false,
|
||||
}
|
||||
}
|
||||
|
||||
// SetBlacklist sets the global directory blacklist
|
||||
func (fm *FileMonitor) SetBlacklist(blacklist []string) {
|
||||
fm.mutex.Lock()
|
||||
defer fm.mutex.Unlock()
|
||||
|
||||
fm.blacklist = make([]string, len(blacklist))
|
||||
copy(fm.blacklist, blacklist)
|
||||
}
|
||||
|
||||
// AddGroup adds a new file group
|
||||
func (fm *FileMonitor) AddGroup(group *FileGroup) error {
|
||||
if group == nil {
|
||||
return errors.New("group cannot be nil")
|
||||
}
|
||||
|
||||
// First check if monitor is running
|
||||
isRunning := fm.IsRunning()
|
||||
|
||||
// Add group to monitor
|
||||
fm.mutex.Lock()
|
||||
// Check if ID already exists
|
||||
if _, exists := fm.groups[group.ID]; exists {
|
||||
fm.mutex.Unlock()
|
||||
return fmt.Errorf("group with ID '%s' already exists", group.ID)
|
||||
}
|
||||
// Add to monitor
|
||||
fm.groups[group.ID] = group
|
||||
fm.mutex.Unlock()
|
||||
|
||||
// If monitor is running, set up watching
|
||||
if isRunning {
|
||||
if err := fm.setupWatchForGroup(group); err != nil {
|
||||
// Remove group on failure
|
||||
fm.mutex.Lock()
|
||||
delete(fm.groups, group.ID)
|
||||
fm.mutex.Unlock()
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateGroup creates and adds a new file group (convenience method)
|
||||
func (fm *FileMonitor) CreateGroup(id, rootDir, pattern string, blacklist []string) (*FileGroup, error) {
|
||||
// Create file group
|
||||
group, err := NewFileGroup(id, rootDir, pattern, blacklist)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Add to monitor
|
||||
if err := fm.AddGroup(group); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return group, nil
|
||||
}
|
||||
|
||||
// RemoveGroup removes a file group
|
||||
func (fm *FileMonitor) RemoveGroup(id string) error {
|
||||
fm.mutex.Lock()
|
||||
defer fm.mutex.Unlock()
|
||||
|
||||
// Check if group exists
|
||||
_, exists := fm.groups[id]
|
||||
if !exists {
|
||||
return fmt.Errorf("group with ID '%s' does not exist", id)
|
||||
}
|
||||
|
||||
// Remove group
|
||||
delete(fm.groups, id)
|
||||
// log.Info().Str("groupID", id).Msg("Removed file group")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetGroups returns a list of all file group IDs
|
||||
func (fm *FileMonitor) GetGroups() []*FileGroup {
|
||||
fm.mutex.RLock()
|
||||
defer fm.mutex.RUnlock()
|
||||
|
||||
groups := make([]*FileGroup, 0, len(fm.groups))
|
||||
for _, group := range fm.groups {
|
||||
groups = append(groups, group)
|
||||
}
|
||||
|
||||
return groups
|
||||
}
|
||||
|
||||
// GetGroup returns the specified file group
|
||||
func (fm *FileMonitor) GetGroup(id string) (*FileGroup, bool) {
|
||||
fm.mutex.RLock()
|
||||
defer fm.mutex.RUnlock()
|
||||
|
||||
group, exists := fm.groups[id]
|
||||
return group, exists
|
||||
}
|
||||
|
||||
// Start starts the file monitor
|
||||
func (fm *FileMonitor) Start() error {
|
||||
// Check if already running
|
||||
fm.stateMutex.Lock()
|
||||
if fm.isRunning {
|
||||
fm.stateMutex.Unlock()
|
||||
return errors.New("file monitor is already running")
|
||||
}
|
||||
|
||||
// Create new watcher
|
||||
watcher, err := fsnotify.NewWatcher()
|
||||
if err != nil {
|
||||
fm.stateMutex.Unlock()
|
||||
return fmt.Errorf("failed to create watcher: %w", err)
|
||||
}
|
||||
fm.watcher = watcher
|
||||
|
||||
// Reset stop channel
|
||||
fm.stopCh = make(chan struct{})
|
||||
|
||||
// Get groups to monitor (without holding the state lock)
|
||||
fm.mutex.RLock()
|
||||
groups := make([]*FileGroup, 0, len(fm.groups))
|
||||
for _, group := range fm.groups {
|
||||
groups = append(groups, group)
|
||||
}
|
||||
fm.mutex.RUnlock()
|
||||
|
||||
// Reset monitored directories
|
||||
fm.mutex.Lock()
|
||||
fm.watchDirs = make(map[string]bool)
|
||||
fm.mutex.Unlock()
|
||||
|
||||
// Mark as running before setting up watches
|
||||
fm.isRunning = true
|
||||
fm.stateMutex.Unlock()
|
||||
|
||||
// Set up monitoring for all groups (without holding any locks)
|
||||
for _, group := range groups {
|
||||
if err := fm.setupWatchForGroup(group); err != nil {
|
||||
// Clean up resources on failure
|
||||
_ = fm.watcher.Close()
|
||||
|
||||
// Reset running state
|
||||
fm.stateMutex.Lock()
|
||||
fm.watcher = nil
|
||||
fm.isRunning = false
|
||||
fm.stateMutex.Unlock()
|
||||
|
||||
return fmt.Errorf("failed to setup watch for group '%s': %w", group.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Start watch loop
|
||||
fm.wg.Add(1)
|
||||
go fm.watchLoop()
|
||||
|
||||
// log.Info().Msg("File monitor started")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the file monitor
|
||||
func (fm *FileMonitor) Stop() error {
|
||||
// Check if already stopped
|
||||
fm.stateMutex.Lock()
|
||||
if !fm.isRunning {
|
||||
fm.stateMutex.Unlock()
|
||||
return errors.New("file monitor is not running")
|
||||
}
|
||||
|
||||
// Get watcher reference before changing state
|
||||
watcher := fm.watcher
|
||||
|
||||
// Send stop signal
|
||||
close(fm.stopCh)
|
||||
|
||||
// Mark as not running
|
||||
fm.isRunning = false
|
||||
fm.stateMutex.Unlock()
|
||||
|
||||
// Wait for all goroutines to exit
|
||||
fm.wg.Wait()
|
||||
|
||||
// Close watcher
|
||||
if watcher != nil {
|
||||
if err := watcher.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close watcher: %w", err)
|
||||
}
|
||||
|
||||
fm.stateMutex.Lock()
|
||||
fm.watcher = nil
|
||||
fm.stateMutex.Unlock()
|
||||
}
|
||||
|
||||
// log.Info().Msg("File monitor stopped")
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsRunning returns whether the file monitor is running
|
||||
func (fm *FileMonitor) IsRunning() bool {
|
||||
fm.stateMutex.RLock()
|
||||
defer fm.stateMutex.RUnlock()
|
||||
return fm.isRunning
|
||||
}
|
||||
|
||||
// addWatchDir adds a directory to monitoring
|
||||
func (fm *FileMonitor) addWatchDir(dirPath string) error {
|
||||
// Check global blacklist first
|
||||
fm.mutex.RLock()
|
||||
for _, pattern := range fm.blacklist {
|
||||
if strings.Contains(dirPath, pattern) {
|
||||
fm.mutex.RUnlock()
|
||||
log.Debug().Str("dir", dirPath).Msg("Skipping blacklisted directory")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
fm.mutex.RUnlock()
|
||||
|
||||
fm.mutex.Lock()
|
||||
defer fm.mutex.Unlock()
|
||||
|
||||
// Check if directory is already being monitored
|
||||
if _, watched := fm.watchDirs[dirPath]; watched {
|
||||
return nil // Already monitored, no need to add again
|
||||
}
|
||||
|
||||
// Add to monitoring
|
||||
if err := fm.watcher.Add(dirPath); err != nil {
|
||||
return fmt.Errorf("failed to watch directory '%s': %w", dirPath, err)
|
||||
}
|
||||
|
||||
fm.watchDirs[dirPath] = true
|
||||
// log.Debug().Str("dir", dirPath).Msg("Added watch for directory")
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupWatchForGroup sets up monitoring for a file group
|
||||
func (fm *FileMonitor) setupWatchForGroup(group *FileGroup) error {
|
||||
// Check if file monitor is running
|
||||
if !fm.IsRunning() {
|
||||
return errors.New("file monitor is not running")
|
||||
}
|
||||
|
||||
// Find directories containing matching files
|
||||
matchingDirs, err := group.ListMatchingDirectories()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list matching directories: %w", err)
|
||||
}
|
||||
|
||||
// Always watch the root directory to catch new files
|
||||
rootDir := filepath.Clean(group.RootDir)
|
||||
if err := fm.addWatchDir(rootDir); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Watch directories containing matching files
|
||||
for dir := range matchingDirs {
|
||||
if err := fm.addWatchDir(dir); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RefreshWatches updates the watched directories based on current matching files
|
||||
func (fm *FileMonitor) RefreshWatches() error {
|
||||
// Check if file monitor is running
|
||||
if !fm.IsRunning() {
|
||||
return errors.New("file monitor is not running")
|
||||
}
|
||||
|
||||
// Get groups to refresh
|
||||
fm.mutex.RLock()
|
||||
groups := make([]*FileGroup, 0, len(fm.groups))
|
||||
for _, group := range fm.groups {
|
||||
groups = append(groups, group)
|
||||
}
|
||||
fm.mutex.RUnlock()
|
||||
|
||||
// Reset watched directories
|
||||
fm.mutex.Lock()
|
||||
oldWatchDirs := fm.watchDirs
|
||||
fm.watchDirs = make(map[string]bool)
|
||||
fm.mutex.Unlock()
|
||||
|
||||
// Setup watches for each group
|
||||
for _, group := range groups {
|
||||
if err := fm.setupWatchForGroup(group); err != nil {
|
||||
return fmt.Errorf("failed to refresh watches for group '%s': %w", group.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Remove watches for directories no longer needed
|
||||
for dir := range oldWatchDirs {
|
||||
fm.mutex.RLock()
|
||||
_, stillWatched := fm.watchDirs[dir]
|
||||
fm.mutex.RUnlock()
|
||||
|
||||
if !stillWatched && fm.watcher != nil {
|
||||
_ = fm.watcher.Remove(dir)
|
||||
log.Debug().Str("dir", dir).Msg("Removed watch for directory")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// watchLoop monitors for file system events
|
||||
func (fm *FileMonitor) watchLoop() {
|
||||
defer fm.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-fm.stopCh:
|
||||
return
|
||||
|
||||
case event, ok := <-fm.watcher.Events:
|
||||
if !ok {
|
||||
// Channel closed, exit loop
|
||||
return
|
||||
}
|
||||
|
||||
// Handle directory creation events to add new watches
|
||||
info, err := os.Stat(event.Name)
|
||||
if err == nil && info.IsDir() && event.Op&(fsnotify.Create|fsnotify.Rename) != 0 {
|
||||
// Add new directory to monitoring
|
||||
if err := fm.addWatchDir(event.Name); err != nil {
|
||||
log.Error().
|
||||
Str("dir", event.Name).
|
||||
Err(err).
|
||||
Msg("Error watching new directory")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// For file creation/modification, check if we need to watch its directory
|
||||
if event.Op&(fsnotify.Create|fsnotify.Write) != 0 {
|
||||
// Check if this file matches any group
|
||||
shouldWatch := false
|
||||
|
||||
fm.mutex.RLock()
|
||||
for _, group := range fm.groups {
|
||||
if group.Match(event.Name) {
|
||||
shouldWatch = true
|
||||
break
|
||||
}
|
||||
}
|
||||
fm.mutex.RUnlock()
|
||||
|
||||
// If file matches, ensure its directory is watched
|
||||
if shouldWatch {
|
||||
dir := filepath.Dir(event.Name)
|
||||
if err := fm.addWatchDir(dir); err != nil {
|
||||
log.Error().
|
||||
Str("dir", dir).
|
||||
Err(err).
|
||||
Msg("Error watching directory of matching file")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Forward event to all groups
|
||||
fm.forwardEventToGroups(event)
|
||||
|
||||
case err, ok := <-fm.watcher.Errors:
|
||||
if !ok {
|
||||
// Channel closed, exit loop
|
||||
return
|
||||
}
|
||||
log.Error().Err(err).Msg("Watcher error")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// forwardEventToGroups forwards file events to matching groups
|
||||
func (fm *FileMonitor) forwardEventToGroups(event fsnotify.Event) {
|
||||
// Get a copy of groups to avoid holding lock during processing
|
||||
fm.mutex.RLock()
|
||||
groupsCopy := make([]*FileGroup, 0, len(fm.groups))
|
||||
for _, group := range fm.groups {
|
||||
groupsCopy = append(groupsCopy, group)
|
||||
}
|
||||
fm.mutex.RUnlock()
|
||||
|
||||
// Forward to all groups - each group will check if the event is relevant
|
||||
for _, group := range groupsCopy {
|
||||
group.HandleEvent(event)
|
||||
}
|
||||
}
|
||||
322
pkg/util/dat2img/dat2img.go
Normal file
322
pkg/util/dat2img/dat2img.go
Normal file
@@ -0,0 +1,322 @@
|
||||
package dat2img
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Format defines the header and extension for different image types
|
||||
type Format struct {
|
||||
Header []byte
|
||||
AesKey []byte
|
||||
Ext string
|
||||
}
|
||||
|
||||
var (
|
||||
// Common image format definitions
|
||||
JPG = Format{Header: []byte{0xFF, 0xD8, 0xFF}, Ext: "jpg"}
|
||||
PNG = Format{Header: []byte{0x89, 0x50, 0x4E, 0x47}, Ext: "png"}
|
||||
GIF = Format{Header: []byte{0x47, 0x49, 0x46, 0x38}, Ext: "gif"}
|
||||
TIFF = Format{Header: []byte{0x49, 0x49, 0x2A, 0x00}, Ext: "tiff"}
|
||||
BMP = Format{Header: []byte{0x42, 0x4D}, Ext: "bmp"}
|
||||
WXGF = Format{Header: []byte{0x77, 0x78, 0x67, 0x66}, Ext: "wxgf"}
|
||||
Formats = []Format{JPG, PNG, GIF, TIFF, BMP, WXGF}
|
||||
|
||||
// Updated V4 definitions to match Dart implementation (6 bytes signature)
|
||||
// V4 Type 1: 0x07 0x08 0x56 0x31 0x08 0x07
|
||||
V4Format1 = Format{Header: []byte{0x07, 0x08, 0x56, 0x31, 0x08, 0x07}, AesKey: []byte("cfcd208495d565ef")}
|
||||
// V4 Type 2: 0x07 0x08 0x56 0x32 0x08 0x07
|
||||
V4Format2 = Format{Header: []byte{0x07, 0x08, 0x56, 0x32, 0x08, 0x07}, AesKey: []byte("0000000000000000")} // User needs to provide key
|
||||
V4Formats = []*Format{&V4Format1, &V4Format2}
|
||||
|
||||
// WeChat v4 related constants
|
||||
V4XorKey byte = 0x37 // Default XOR key for WeChat v4 dat files
|
||||
JpgTail = []byte{0xFF, 0xD9} // JPG file tail marker
|
||||
)
|
||||
|
||||
// Dat2Image converts WeChat dat file data to image data
|
||||
func Dat2Image(data []byte) ([]byte, string, error) {
|
||||
if len(data) < 4 {
|
||||
return nil, "", fmt.Errorf("data length is too short: %d", len(data))
|
||||
}
|
||||
|
||||
// Check if this is a WeChat v4 dat file (Check first 6 bytes)
|
||||
if len(data) >= 6 {
|
||||
for _, format := range V4Formats {
|
||||
if bytes.Equal(data[:6], format.Header) {
|
||||
return Dat2ImageV4(data, format.AesKey)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// For older WeChat versions (V3), use XOR decryption
|
||||
findFormat := func(data []byte, header []byte) bool {
|
||||
xorBit := data[0] ^ header[0]
|
||||
for i := 0; i < len(header); i++ {
|
||||
if data[i]^header[i] != xorBit {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
var xorBit byte
|
||||
var found bool
|
||||
var ext string
|
||||
for _, format := range Formats {
|
||||
if found = findFormat(data, format.Header); found {
|
||||
xorBit = data[0] ^ format.Header[0]
|
||||
ext = format.Ext
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
// Fallback check: if no known header found, verify if it's a V4 file with only 4 bytes matching (loose check)
|
||||
// This handles cases where the file might be truncated or slightly different, but it's risky.
|
||||
// Original code checked 4 bytes, let's keep strict 6 bytes above, and if failed, maybe return error.
|
||||
return nil, "", fmt.Errorf("unknown image type: %x %x", data[0], data[1])
|
||||
}
|
||||
|
||||
// Apply XOR decryption (V3)
|
||||
out := make([]byte, len(data))
|
||||
for i := range data {
|
||||
out[i] = data[i] ^ xorBit
|
||||
}
|
||||
|
||||
return out, ext, nil
|
||||
}
|
||||
|
||||
// calculateXorKeyV4 calculates the XOR key for WeChat v4 dat files
|
||||
func calculateXorKeyV4(data []byte) (byte, error) {
|
||||
if len(data) < 2 {
|
||||
return 0, fmt.Errorf("data too short to calculate XOR key")
|
||||
}
|
||||
fileTail := data[len(data)-2:]
|
||||
xorKeys := make([]byte, 2)
|
||||
for i := 0; i < 2; i++ {
|
||||
xorKeys[i] = fileTail[i] ^ JpgTail[i]
|
||||
}
|
||||
if xorKeys[0] == xorKeys[1] {
|
||||
return xorKeys[0], nil
|
||||
}
|
||||
return xorKeys[0], fmt.Errorf("inconsistent XOR key, using first byte: 0x%x", xorKeys[0])
|
||||
}
|
||||
|
||||
// ScanAndSetXorKey scans a directory to calculate and set the global XOR key
|
||||
func ScanAndSetXorKey(dirPath string) (byte, error) {
|
||||
err := filepath.Walk(dirPath, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
if !strings.HasSuffix(info.Name(), "_t.dat") {
|
||||
return nil
|
||||
}
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check header (looser check for scan is acceptable, or use exact)
|
||||
isV4 := false
|
||||
if len(data) >= 4 {
|
||||
if bytes.Equal(data[:4], V4Format1.Header[:4]) || bytes.Equal(data[:4], V4Format2.Header[:4]) {
|
||||
isV4 = true
|
||||
}
|
||||
}
|
||||
if !isV4 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(data) < 15 {
|
||||
return nil
|
||||
}
|
||||
|
||||
xorEncryptLen := binary.LittleEndian.Uint32(data[10:14])
|
||||
fileData := data[15:]
|
||||
|
||||
if xorEncryptLen == 0 || uint32(len(fileData)) <= uint32(len(fileData))-xorEncryptLen {
|
||||
return nil
|
||||
}
|
||||
|
||||
xorData := fileData[uint32(len(fileData))-xorEncryptLen:]
|
||||
key, err := calculateXorKeyV4(xorData)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
V4XorKey = key
|
||||
return filepath.SkipAll
|
||||
})
|
||||
|
||||
if err != nil && err != filepath.SkipAll {
|
||||
return V4XorKey, fmt.Errorf("error scanning directory: %v", err)
|
||||
}
|
||||
return V4XorKey, nil
|
||||
}
|
||||
|
||||
func SetAesKey(key string) {
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
// Dart implementation uses asciiKey16 (taking first 16 chars/bytes)
|
||||
// If the input is hex, decoding is fine, but if it's a raw string like "cfcd...",
|
||||
// we should handle it carefully. Assuming input is hex string here as per original Go.
|
||||
decoded, err := hex.DecodeString(key)
|
||||
if err != nil {
|
||||
// Fallback: if hex decode fails, use raw bytes if length is 16 (matching Dart logic partially)
|
||||
if len(key) == 16 {
|
||||
V4Format2.AesKey = []byte(key)
|
||||
return
|
||||
}
|
||||
log.Error().Err(err).Msg("invalid aes key")
|
||||
return
|
||||
}
|
||||
V4Format2.AesKey = decoded
|
||||
}
|
||||
|
||||
// Dat2ImageV4 processes WeChat v4 dat image files
|
||||
// Refactored to match Dart implementation logic
|
||||
func Dat2ImageV4(data []byte, aesKey []byte) ([]byte, string, error) {
|
||||
if len(data) < 15 {
|
||||
return nil, "", fmt.Errorf("data length is too short for WeChat v4 format")
|
||||
}
|
||||
|
||||
// 1. Parse Headers (Little Endian)
|
||||
// Offset 6-10: AES Encryption Length
|
||||
aesSize := binary.LittleEndian.Uint32(data[6:10])
|
||||
// Offset 10-14: XOR Encryption Length
|
||||
xorSize := binary.LittleEndian.Uint32(data[10:14])
|
||||
|
||||
// Skip header (15 bytes)
|
||||
fileData := data[15:]
|
||||
|
||||
// 2. AES Decryption Logic
|
||||
// Calculate aligned size: size + (16 - size % 16)
|
||||
// This ensures we read the full PKCS7 padded block
|
||||
alignedAesSize := aesSize + (16 - (aesSize % 16))
|
||||
|
||||
if uint32(len(fileData)) < alignedAesSize {
|
||||
return nil, "", fmt.Errorf("file data too short for declared AES length")
|
||||
}
|
||||
|
||||
// Split data: [AES Part] [Middle Raw Part] [XOR Part]
|
||||
aesPart := fileData[:alignedAesSize]
|
||||
remainingPart := fileData[alignedAesSize:]
|
||||
|
||||
var unpaddedAesData []byte
|
||||
var err error
|
||||
|
||||
// Decrypt AES part
|
||||
if len(aesPart) > 0 {
|
||||
unpaddedAesData, err = decryptAESECBStrict(aesPart, aesKey)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("AES decryption failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Handle Middle and XOR Parts
|
||||
// XOR size validation
|
||||
if uint32(len(remainingPart)) < xorSize {
|
||||
return nil, "", fmt.Errorf("file data too short for declared XOR length")
|
||||
}
|
||||
|
||||
rawLen := uint32(len(remainingPart)) - xorSize
|
||||
rawMiddleData := remainingPart[:rawLen]
|
||||
xorTailData := remainingPart[rawLen:]
|
||||
|
||||
// Decrypt XOR part
|
||||
decryptedXorData := make([]byte, len(xorTailData))
|
||||
for i := range xorTailData {
|
||||
decryptedXorData[i] = xorTailData[i] ^ V4XorKey
|
||||
}
|
||||
|
||||
// 4. Reassemble: [Unpadded AES] + [Raw Middle] + [Decrypted XOR]
|
||||
// Pre-allocate exact size
|
||||
totalLen := len(unpaddedAesData) + len(rawMiddleData) + len(decryptedXorData)
|
||||
result := make([]byte, 0, totalLen)
|
||||
|
||||
result = append(result, unpaddedAesData...)
|
||||
result = append(result, rawMiddleData...)
|
||||
result = append(result, decryptedXorData...)
|
||||
|
||||
// Identify image type
|
||||
imgType := ""
|
||||
for _, format := range Formats {
|
||||
// Only check headers for image types, not V4 types
|
||||
if format.Ext == "wxgf" || format.Ext == "jpg" || format.Ext == "png" || format.Ext == "gif" || format.Ext == "tiff" || format.Ext == "bmp" {
|
||||
if len(result) >= len(format.Header) && bytes.Equal(result[:len(format.Header)], format.Header) {
|
||||
imgType = format.Ext
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if imgType == "wxgf" {
|
||||
return Wxam2pic(result)
|
||||
}
|
||||
|
||||
if imgType == "" {
|
||||
// Fallback detection (check first bytes if header match failed)
|
||||
if len(result) > 2 {
|
||||
return nil, "", fmt.Errorf("unknown image type after decryption: %x %x", result[0], result[1])
|
||||
}
|
||||
return nil, "", errors.New("unknown image type")
|
||||
}
|
||||
|
||||
return result, imgType, nil
|
||||
}
|
||||
|
||||
// decryptAESECBStrict decrypts data using AES in ECB mode and strictly removes PKCS7 padding
|
||||
// This matches Dart's _strictRemovePadding logic
|
||||
func decryptAESECBStrict(data, key []byte) ([]byte, error) {
|
||||
if len(data) == 0 {
|
||||
return []byte{}, nil
|
||||
}
|
||||
|
||||
cipher, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(data)%aes.BlockSize != 0 {
|
||||
return nil, fmt.Errorf("data length %d is not a multiple of block size", len(data))
|
||||
}
|
||||
|
||||
decrypted := make([]byte, len(data))
|
||||
for bs, be := 0, aes.BlockSize; bs < len(data); bs, be = bs+aes.BlockSize, be+aes.BlockSize {
|
||||
cipher.Decrypt(decrypted[bs:be], data[bs:be])
|
||||
}
|
||||
|
||||
// Strict PKCS7 Unpadding
|
||||
length := len(decrypted)
|
||||
if length == 0 {
|
||||
return nil, errors.New("decrypted data is empty")
|
||||
}
|
||||
|
||||
paddingLen := int(decrypted[length-1])
|
||||
if paddingLen == 0 || paddingLen > aes.BlockSize || paddingLen > length {
|
||||
return nil, fmt.Errorf("invalid PKCS7 padding length: %d", paddingLen)
|
||||
}
|
||||
|
||||
// Verify all padding bytes
|
||||
for i := length - paddingLen; i < length; i++ {
|
||||
if decrypted[i] != byte(paddingLen) {
|
||||
return nil, errors.New("invalid PKCS7 padding content")
|
||||
}
|
||||
}
|
||||
|
||||
return decrypted[:length-paddingLen], nil
|
||||
}
|
||||
72
pkg/util/dat2img/imgkey.go
Normal file
72
pkg/util/dat2img/imgkey.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package dat2img
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type AesKeyValidator struct {
|
||||
Path string
|
||||
EncryptedData []byte
|
||||
}
|
||||
|
||||
func NewImgKeyValidator(path string) *AesKeyValidator {
|
||||
validator := &AesKeyValidator{
|
||||
Path: path,
|
||||
}
|
||||
|
||||
// Walk the directory to find *.dat files (excluding *_t.dat files)
|
||||
filepath.Walk(path, func(filePath string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Skip directories
|
||||
if info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Only process *.dat files but exclude *_t.dat files
|
||||
if !strings.HasSuffix(info.Name(), ".dat") || strings.HasSuffix(info.Name(), "_t.dat") {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Read file content
|
||||
data, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if header matches V4Format2.Header
|
||||
// Get aes.BlockSize (16) bytes starting from position 15
|
||||
if len(data) >= 15+aes.BlockSize && bytes.Equal(data[:4], V4Format2.Header) {
|
||||
validator.EncryptedData = make([]byte, aes.BlockSize)
|
||||
copy(validator.EncryptedData, data[15:15+aes.BlockSize])
|
||||
return filepath.SkipAll // Found what we need, stop walking
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
return validator
|
||||
}
|
||||
|
||||
func (v *AesKeyValidator) Validate(key []byte) bool {
|
||||
if len(key) < 16 {
|
||||
return false
|
||||
}
|
||||
aesKey := key[:16]
|
||||
|
||||
cipher, err := aes.NewCipher(aesKey)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
decrypted := make([]byte, len(v.EncryptedData))
|
||||
cipher.Decrypt(decrypted, v.EncryptedData)
|
||||
|
||||
return bytes.HasPrefix(decrypted, JPG.Header) || bytes.HasPrefix(decrypted, WXGF.Header)
|
||||
}
|
||||
464
pkg/util/dat2img/wxgf.go
Normal file
464
pkg/util/dat2img/wxgf.go
Normal file
@@ -0,0 +1,464 @@
|
||||
package dat2img
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/Eyevinn/mp4ff/avc"
|
||||
"github.com/Eyevinn/mp4ff/bits"
|
||||
"github.com/Eyevinn/mp4ff/hevc"
|
||||
"github.com/Eyevinn/mp4ff/mp4"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
ENV_FFMPEG_PATH = "FFMPEG_PATH"
|
||||
MinRatio = 0.6
|
||||
FixSliceHeaders = true
|
||||
)
|
||||
|
||||
var (
|
||||
FFmpegMode = false
|
||||
FFMpegPath = "ffmpeg"
|
||||
)
|
||||
|
||||
func init() {
|
||||
ffmpegPath := os.Getenv(ENV_FFMPEG_PATH)
|
||||
if len(ffmpegPath) > 0 {
|
||||
FFmpegMode = true
|
||||
FFMpegPath = ffmpegPath
|
||||
}
|
||||
if isFFmpegAvailable() {
|
||||
FFmpegMode = true
|
||||
}
|
||||
}
|
||||
|
||||
func Wxam2pic(data []byte) ([]byte, string, error) {
|
||||
|
||||
if len(data) < 15 || !bytes.Equal(data[0:4], WXGF.Header) {
|
||||
return nil, "", fmt.Errorf("invalid wxgf")
|
||||
}
|
||||
|
||||
partitions, err := findDataPartition(data)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
if partitions.LikeAnime() {
|
||||
// FIXME mask frame not work
|
||||
animeFrames := make([][]byte, 0)
|
||||
maskFrames := make([][]byte, 0)
|
||||
for i, partition := range partitions.Partitions {
|
||||
if i%2 == 0 {
|
||||
maskFrames = append(maskFrames, data[partition.Offset:partition.Offset+partition.Size])
|
||||
} else {
|
||||
animeFrames = append(animeFrames, data[partition.Offset:partition.Offset+partition.Size])
|
||||
}
|
||||
}
|
||||
if FFmpegMode {
|
||||
mp4Data, err := ConvertAnime2GIF(animeFrames, maskFrames)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return mp4Data, "gif", nil
|
||||
}
|
||||
gifData, err := ConvertAnime2GIF(animeFrames, maskFrames)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return gifData, "gif", nil
|
||||
}
|
||||
|
||||
offset := partitions.Partitions[partitions.MaxIndex].Offset
|
||||
size := partitions.Partitions[partitions.MaxIndex].Size
|
||||
|
||||
if FFmpegMode {
|
||||
jpgData, err := Convert2JPG(data[offset : offset+size])
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return jpgData, JPG.Ext, nil
|
||||
}
|
||||
return nil, "", fmt.Errorf("ffmpeg is not available, cannot convert this type of wxgf image")
|
||||
}
|
||||
|
||||
type Partitions struct {
|
||||
Partitions []Partition
|
||||
MaxRatio float64
|
||||
MaxIndex int
|
||||
}
|
||||
|
||||
func (p *Partitions) LikeAnime() bool {
|
||||
return len(p.Partitions) > 1 && p.MaxRatio < MinRatio
|
||||
}
|
||||
|
||||
type Partition struct {
|
||||
Offset int
|
||||
Size int
|
||||
Ratio float64
|
||||
}
|
||||
|
||||
func findDataPartition(data []byte) (*Partitions, error) {
|
||||
|
||||
headerLen := int(data[4])
|
||||
if headerLen >= len(data) {
|
||||
return nil, fmt.Errorf("invalid wxgf")
|
||||
}
|
||||
|
||||
patterns := [][]byte{
|
||||
{0x00, 0x00, 0x00, 0x01},
|
||||
{0x00, 0x00, 0x01},
|
||||
}
|
||||
|
||||
for _, pattern := range patterns {
|
||||
ret := &Partitions{
|
||||
Partitions: make([]Partition, 0),
|
||||
}
|
||||
offset := 0
|
||||
for {
|
||||
if headerLen+offset > len(data) {
|
||||
break
|
||||
}
|
||||
|
||||
index := bytes.Index(data[headerLen+offset:], pattern)
|
||||
if index == -1 {
|
||||
break
|
||||
}
|
||||
|
||||
absIndex := headerLen + offset + index
|
||||
|
||||
if absIndex < 4 {
|
||||
offset += index + 1
|
||||
continue
|
||||
}
|
||||
|
||||
length := int(data[absIndex-4])<<24 | int(data[absIndex-3])<<16 |
|
||||
int(data[absIndex-2])<<8 | int(data[absIndex-1])
|
||||
|
||||
if length <= 0 || absIndex+length > len(data) {
|
||||
offset += index + 1
|
||||
continue
|
||||
}
|
||||
|
||||
partition := Partition{
|
||||
Offset: absIndex,
|
||||
Size: length,
|
||||
Ratio: float64(length) / float64(len(data)),
|
||||
}
|
||||
ret.Partitions = append(ret.Partitions, partition)
|
||||
if partition.Ratio > ret.MaxRatio {
|
||||
ret.MaxRatio = partition.Ratio
|
||||
ret.MaxIndex = len(ret.Partitions) - 1
|
||||
}
|
||||
offset += index + length
|
||||
}
|
||||
|
||||
if len(ret.Partitions) > 0 {
|
||||
return ret, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("no partition found")
|
||||
}
|
||||
|
||||
func Convert2JPG(data []byte) ([]byte, error) {
|
||||
cmd := exec.Command(FFMpegPath,
|
||||
"-i", "-",
|
||||
"-vframes", "1",
|
||||
"-c:v", "mjpeg",
|
||||
"-q:v", "4",
|
||||
"-f", "image2",
|
||||
"-")
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdin = bytes.NewReader(data)
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return nil, fmt.Errorf("ffmpeg failed: %w", err)
|
||||
}
|
||||
|
||||
jpegData := stdout.Bytes()
|
||||
if len(jpegData) == 0 {
|
||||
return nil, fmt.Errorf("ffmpeg output is empty")
|
||||
}
|
||||
|
||||
return jpegData, nil
|
||||
}
|
||||
|
||||
func writeTempFile(data [][]byte) (string, error) {
|
||||
path := filepath.Join(os.TempDir(), fmt.Sprintf("anime-%s", uuid.New().String()))
|
||||
file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0644)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to open anime temp file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
for _, frame := range data {
|
||||
_, err := file.Write(frame)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to write anime frame to temp file: %w", err)
|
||||
}
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
|
||||
// ConvertAnime2GIF convert anime frames and mask frames to mp4
|
||||
// FIXME No longer need to write to temporary files
|
||||
func ConvertAnime2GIF(animeFrames [][]byte, maskFrames [][]byte) ([]byte, error) {
|
||||
animeFilePath, err := writeTempFile(animeFrames)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to write anime temp file: %w", err)
|
||||
}
|
||||
defer os.Remove(animeFilePath)
|
||||
|
||||
maskFilePath, err := writeTempFile(maskFrames)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to write mask temp file: %w", err)
|
||||
}
|
||||
defer os.Remove(maskFilePath)
|
||||
|
||||
cmd := exec.Command(FFMpegPath,
|
||||
"-i", animeFilePath,
|
||||
"-i", maskFilePath,
|
||||
"-filter_complex", "[0:v][1:v]alphamerge,split[s0][s1];[s0]palettegen[p];[s1][p]paletteuse",
|
||||
"-f", "gif",
|
||||
"-")
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
return nil, fmt.Errorf("ffmpeg failed: %w", err)
|
||||
}
|
||||
|
||||
gifData := stdout.Bytes()
|
||||
if len(gifData) == 0 {
|
||||
return nil, fmt.Errorf("ffmpeg output is empty")
|
||||
}
|
||||
|
||||
return gifData, nil
|
||||
}
|
||||
|
||||
func isFFmpegAvailable() bool {
|
||||
cmd := exec.Command(FFMpegPath, "-version")
|
||||
return cmd.Run() == nil
|
||||
}
|
||||
|
||||
func Transmux2MP4(data []byte) ([]byte, error) {
|
||||
|
||||
vpsNALUs, spsNALUs, ppsNALUs := hevc.GetParameterSetsFromByteStream(data)
|
||||
|
||||
videoTimescale := uint32(1000)
|
||||
init := mp4.CreateEmptyInit()
|
||||
init.AddEmptyTrack(videoTimescale, "video", "und")
|
||||
|
||||
trak := init.Moov.Trak
|
||||
err := trak.SetHEVCDescriptor("hvc1", vpsNALUs, spsNALUs, ppsNALUs, nil, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
seg := mp4.NewMediaSegment()
|
||||
seg.EncOptimize = mp4.OptimizeTrun
|
||||
frag, err := mp4.CreateFragment(1, mp4.DefaultTrakID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
seg.AddFragment(frag)
|
||||
|
||||
sampleData := avc.ConvertByteStreamToNaluSample(data)
|
||||
sample := mp4.FullSample{
|
||||
Sample: mp4.Sample{
|
||||
Flags: 0x02000000,
|
||||
Dur: 1000,
|
||||
Size: uint32(len(sampleData)),
|
||||
CompositionTimeOffset: 0,
|
||||
},
|
||||
DecodeTime: 0,
|
||||
Data: sampleData,
|
||||
}
|
||||
|
||||
frag.AddFullSample(sample)
|
||||
|
||||
totalSize := init.Size() + seg.Size()
|
||||
sw := bits.NewFixedSliceWriter(int(totalSize))
|
||||
|
||||
init.EncodeSW(sw)
|
||||
seg.EncodeSW(sw)
|
||||
|
||||
return sw.Bytes(), nil
|
||||
}
|
||||
|
||||
func TransmuxAnime2MP4(animeFrames [][]byte, maskFrames [][]byte) ([]byte, error) {
|
||||
|
||||
if len(maskFrames) != len(animeFrames) {
|
||||
return nil, fmt.Errorf("mask frame num (%d) not equal to anime frame num (%d)", len(maskFrames), len(animeFrames))
|
||||
}
|
||||
|
||||
init := mp4.CreateEmptyInit()
|
||||
seg := mp4.NewMediaSegment()
|
||||
trackIDs := []uint32{1, 2}
|
||||
frag, err := mp4.CreateMultiTrackFragment(1, trackIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
seg.AddFragment(frag)
|
||||
|
||||
err = Add2Trak(init, frag, 0, animeFrames)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("add full sample to track failed: %w", err)
|
||||
}
|
||||
|
||||
err = Add2Trak(init, frag, 1, maskFrames)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("add full sample to track failed: %w", err)
|
||||
}
|
||||
|
||||
totalSize := init.Size() + seg.Size()
|
||||
sw := bits.NewFixedSliceWriter(int(totalSize))
|
||||
|
||||
init.EncodeSW(sw)
|
||||
seg.EncodeSW(sw)
|
||||
|
||||
return sw.Bytes(), nil
|
||||
}
|
||||
|
||||
func Add2Trak(init *mp4.InitSegment, frag *mp4.Fragment, index int, data [][]byte) error {
|
||||
videoTimescale := uint32(90000)
|
||||
init.AddEmptyTrack(videoTimescale, "video", "und")
|
||||
trak := init.Moov.Traks[index]
|
||||
|
||||
vps, sps, pps := hevc.GetParameterSetsFromByteStream(data[0])
|
||||
|
||||
// FIXME Two slices reporting being the first in the same frame.
|
||||
if FixSliceHeaders {
|
||||
spsMap, ppsMap, err := createSPSPPSMaps(vps, sps, pps)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create sps pps map failed: %w", err)
|
||||
}
|
||||
for i := range data {
|
||||
fixedFrame, err := fixSliceHeadersInFrame(data[i], spsMap, ppsMap)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fix slice header failed: %w", err)
|
||||
}
|
||||
data[i] = fixedFrame
|
||||
}
|
||||
}
|
||||
|
||||
err := trak.SetHEVCDescriptor("hev1", vps, sps, pps, nil, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("set trak failed: %w", err)
|
||||
}
|
||||
|
||||
var decodeTime uint64 = 0
|
||||
frameDuration := uint32(3000)
|
||||
for i := 0; i < len(data); i++ {
|
||||
sampleData := avc.ConvertByteStreamToNaluSample(removeParameterSets(data[i]))
|
||||
simple := mp4.FullSample{
|
||||
Sample: mp4.Sample{
|
||||
Flags: getSampleFlags(sampleData, i == 0),
|
||||
Dur: frameDuration,
|
||||
Size: uint32(len(sampleData)),
|
||||
},
|
||||
DecodeTime: decodeTime,
|
||||
Data: sampleData,
|
||||
}
|
||||
|
||||
err = frag.AddFullSampleToTrack(simple, uint32(index+1))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
decodeTime += uint64(frameDuration)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getSampleFlags(frameData []byte, isFirstFrame bool) uint32 {
|
||||
if isFirstFrame || hevc.IsRAPSample(frameData) {
|
||||
return 0x02000000
|
||||
}
|
||||
return 0x01010000
|
||||
}
|
||||
|
||||
func createSPSPPSMaps(vpsNalus, spsNalus, ppsNalus [][]byte) (map[uint32]*hevc.SPS, map[uint32]*hevc.PPS, error) {
|
||||
spsMap := make(map[uint32]*hevc.SPS)
|
||||
for _, spsNalu := range spsNalus {
|
||||
sps, err := hevc.ParseSPSNALUnit(spsNalu)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("parse sps failed: %w", err)
|
||||
}
|
||||
spsMap[uint32(sps.SpsID)] = sps
|
||||
}
|
||||
|
||||
ppsMap := make(map[uint32]*hevc.PPS)
|
||||
for _, ppsNalu := range ppsNalus {
|
||||
pps, err := hevc.ParsePPSNALUnit(ppsNalu, spsMap)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("parse pps failed: %w", err)
|
||||
}
|
||||
ppsMap[pps.PicParameterSetID] = pps
|
||||
}
|
||||
|
||||
return spsMap, ppsMap, nil
|
||||
}
|
||||
|
||||
func fixSliceHeadersInFrame(frameData []byte, spsMap map[uint32]*hevc.SPS, ppsMap map[uint32]*hevc.PPS) ([]byte, error) {
|
||||
nalus := avc.ExtractNalusFromByteStream(frameData)
|
||||
var fixedNalus [][]byte
|
||||
var firstSliceFound bool
|
||||
|
||||
for _, nalu := range nalus {
|
||||
naluType := hevc.GetNaluType(nalu[0])
|
||||
|
||||
if naluType < hevc.NALU_TRAIL_N || naluType > hevc.NALU_CRA {
|
||||
fixedNalus = append(fixedNalus, nalu)
|
||||
continue
|
||||
}
|
||||
|
||||
if !firstSliceFound {
|
||||
fixedNalus = append(fixedNalus, nalu)
|
||||
firstSliceFound = true
|
||||
} else {
|
||||
sliceHeader, err := hevc.ParseSliceHeader(nalu, spsMap, ppsMap)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse slice header failed: %w", err)
|
||||
}
|
||||
if !sliceHeader.FirstSliceSegmentInPicFlag {
|
||||
fixedNalus = append(fixedNalus, nalu)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return reconstructAnnexBStream(fixedNalus), nil
|
||||
}
|
||||
|
||||
func removeParameterSets(annexBData []byte) []byte {
|
||||
nalus := avc.ExtractNalusFromByteStream(annexBData)
|
||||
var videoNalus [][]byte
|
||||
|
||||
for _, nalu := range nalus {
|
||||
naluType := hevc.GetNaluType(nalu[0])
|
||||
if naluType <= 31 {
|
||||
videoNalus = append(videoNalus, nalu)
|
||||
}
|
||||
}
|
||||
|
||||
return reconstructAnnexBStream(videoNalus)
|
||||
}
|
||||
|
||||
func reconstructAnnexBStream(nalus [][]byte) []byte {
|
||||
var result []byte
|
||||
startCode := []byte{0x00, 0x00, 0x01}
|
||||
|
||||
for _, nalu := range nalus {
|
||||
result = append(result, startCode...)
|
||||
result = append(result, nalu...)
|
||||
}
|
||||
return result
|
||||
}
|
||||
16
pkg/util/lz4/lz4.go
Normal file
16
pkg/util/lz4/lz4.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package lz4
|
||||
|
||||
import (
|
||||
"github.com/pierrec/lz4/v4"
|
||||
)
|
||||
|
||||
func Decompress(src []byte) ([]byte, error) {
|
||||
// FIXME: lz4 的压缩率预计不到 3,这里设置了 4 保险一点
|
||||
out := make([]byte, len(src)*4)
|
||||
|
||||
n, err := lz4.UncompressBlock(src, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out[:n], nil
|
||||
}
|
||||
135
pkg/util/os.go
Normal file
135
pkg/util/os.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// FindFilesWithPatterns 在指定目录下查找匹配多个正则表达式的文件
|
||||
// directory: 要搜索的目录路径
|
||||
// patterns: 正则表达式模式列表
|
||||
// recursive: 是否递归搜索子目录
|
||||
// 返回匹配的文件路径列表和可能的错误
|
||||
func FindFilesWithPatterns(directory string, pattern string, recursive bool) ([]string, error) {
|
||||
// 编译所有正则表达式
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("无效的正则表达式 '%s': %v", pattern, err)
|
||||
}
|
||||
|
||||
// 检查目录是否存在
|
||||
dirInfo, err := os.Stat(directory)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("无法访问目录 '%s': %v", directory, err)
|
||||
}
|
||||
if !dirInfo.IsDir() {
|
||||
return nil, fmt.Errorf("'%s' 不是一个目录", directory)
|
||||
}
|
||||
|
||||
// 存储匹配的文件路径
|
||||
var matchedFiles []string
|
||||
|
||||
// 创建文件系统
|
||||
fsys := os.DirFS(directory)
|
||||
|
||||
// 遍历文件系统
|
||||
err = fs.WalkDir(fsys, ".", func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果是目录且不递归,则跳过子目录
|
||||
if d.IsDir() {
|
||||
if !recursive && path != "." {
|
||||
return fs.SkipDir
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查文件名是否匹配任何一个正则表达式
|
||||
if re.MatchString(d.Name()) {
|
||||
// 添加完整路径到结果列表
|
||||
fullPath := filepath.Join(directory, path)
|
||||
matchedFiles = append(matchedFiles, fullPath)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("遍历目录时出错: %v", err)
|
||||
}
|
||||
|
||||
return matchedFiles, nil
|
||||
}
|
||||
|
||||
func DefaultWorkDir(account string) string {
|
||||
if len(account) == 0 {
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
return filepath.Join(os.ExpandEnv("${USERPROFILE}"), "Documents", "chatlog")
|
||||
case "darwin":
|
||||
return filepath.Join(os.ExpandEnv("${HOME}"), "Documents", "chatlog")
|
||||
default:
|
||||
return filepath.Join(os.ExpandEnv("${HOME}"), "chatlog")
|
||||
}
|
||||
}
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
return filepath.Join(os.ExpandEnv("${USERPROFILE}"), "Documents", "chatlog", account)
|
||||
case "darwin":
|
||||
return filepath.Join(os.ExpandEnv("${HOME}"), "Documents", "chatlog", account)
|
||||
default:
|
||||
return filepath.Join(os.ExpandEnv("${HOME}"), "chatlog", account)
|
||||
}
|
||||
}
|
||||
|
||||
func GetDirSize(dir string) string {
|
||||
var size int64
|
||||
filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
|
||||
if err == nil {
|
||||
size += info.Size()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return ByteCountSI(size)
|
||||
}
|
||||
|
||||
func ByteCountSI(b int64) string {
|
||||
const unit = 1000
|
||||
if b < unit {
|
||||
return fmt.Sprintf("%d B", b)
|
||||
}
|
||||
div, exp := int64(unit), 0
|
||||
for n := b / unit; n >= unit; n /= unit {
|
||||
div *= unit
|
||||
exp++
|
||||
}
|
||||
return fmt.Sprintf("%.1f %cB",
|
||||
float64(b)/float64(div), "kMGTPE"[exp])
|
||||
}
|
||||
|
||||
// PrepareDir ensures that the specified directory path exists.
|
||||
// If the directory does not exist, it attempts to create it.
|
||||
func PrepareDir(path string) error {
|
||||
stat, err := os.Stat(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
if err := os.MkdirAll(path, 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
} else if !stat.IsDir() {
|
||||
log.Debug().Msgf("%s is not a directory", path)
|
||||
return fmt.Errorf("%s is not a directory", path)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
15
pkg/util/os_windows.go
Normal file
15
pkg/util/os_windows.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func Is64Bit(handle windows.Handle) (bool, error) {
|
||||
var is32Bit bool
|
||||
if err := windows.IsWow64Process(handle, &is32Bit); err != nil {
|
||||
return false, fmt.Errorf("检查进程位数失败: %w", err)
|
||||
}
|
||||
return !is32Bit, nil
|
||||
}
|
||||
36
pkg/util/silk/silk.go
Normal file
36
pkg/util/silk/silk.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package silk
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/sjzar/go-lame"
|
||||
"github.com/sjzar/go-silk"
|
||||
)
|
||||
|
||||
func Silk2MP3(data []byte) ([]byte, error) {
|
||||
|
||||
sd := silk.SilkInit()
|
||||
defer sd.Close()
|
||||
|
||||
pcmdata := sd.Decode(data)
|
||||
if len(pcmdata) == 0 {
|
||||
return nil, fmt.Errorf("silk decode failed")
|
||||
}
|
||||
|
||||
le := lame.Init()
|
||||
defer le.Close()
|
||||
|
||||
le.SetInSamplerate(24000)
|
||||
le.SetOutSamplerate(24000)
|
||||
le.SetNumChannels(1)
|
||||
le.SetBitrate(16)
|
||||
// IMPORTANT!
|
||||
le.InitParams()
|
||||
|
||||
mp3data := le.Encode(pcmdata)
|
||||
if len(mp3data) == 0 {
|
||||
return nil, fmt.Errorf("mp3 encode failed")
|
||||
}
|
||||
|
||||
return mp3data, nil
|
||||
}
|
||||
71
pkg/util/strings.go
Normal file
71
pkg/util/strings.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
func IsNormalString(b []byte) bool {
|
||||
str := string(b)
|
||||
|
||||
// 检查是否为有效的 UTF-8
|
||||
if !utf8.ValidString(str) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查是否全部为可打印字符
|
||||
for _, r := range str {
|
||||
if !unicode.IsPrint(r) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func MustAnyToInt(v interface{}) int {
|
||||
str := fmt.Sprintf("%v", v)
|
||||
if i, err := strconv.Atoi(str); err == nil {
|
||||
return i
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func IsNumeric(s string) bool {
|
||||
for _, r := range s {
|
||||
if !unicode.IsDigit(r) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return len(s) > 0
|
||||
}
|
||||
|
||||
func SplitInt64ToTwoInt32(input int64) (int64, int64) {
|
||||
return input & 0xFFFFFFFF, input >> 32
|
||||
}
|
||||
|
||||
func Str2List(str string, sep string) []string {
|
||||
list := make([]string, 0)
|
||||
|
||||
if str == "" {
|
||||
return list
|
||||
}
|
||||
|
||||
listMap := make(map[string]bool)
|
||||
for _, elem := range strings.Split(str, sep) {
|
||||
elem = strings.TrimSpace(elem)
|
||||
if len(elem) == 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := listMap[elem]; ok {
|
||||
continue
|
||||
}
|
||||
listMap[elem] = true
|
||||
list = append(list, elem)
|
||||
}
|
||||
|
||||
return list
|
||||
}
|
||||
658
pkg/util/time.go
Normal file
658
pkg/util/time.go
Normal file
@@ -0,0 +1,658 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var zoneStr = time.Now().Format("-0700")
|
||||
|
||||
// 时间粒度常量
|
||||
type TimeGranularity int
|
||||
|
||||
const (
|
||||
GranularityUnknown TimeGranularity = iota // 未知粒度
|
||||
GranularitySecond // 精确到秒
|
||||
GranularityMinute // 精确到分钟
|
||||
GranularityHour // 精确到小时
|
||||
GranularityDay // 精确到天
|
||||
GranularityMonth // 精确到月
|
||||
GranularityQuarter // 精确到季度
|
||||
GranularityYear // 精确到年
|
||||
)
|
||||
|
||||
// timeOf 内部函数,解析各种格式的时间点,并返回时间粒度
|
||||
// 支持以下格式:
|
||||
// 1. 时间戳(秒): 1609459200 (GranularitySecond)
|
||||
// 2. 标准日期: 20060102, 2006-01-02 (GranularityDay)
|
||||
// 3. 带时间的日期: 20060102/15:04, 2006-01-02/15:04 (GranularityMinute)
|
||||
// 4. 完整时间: 20060102150405 (GranularitySecond)
|
||||
// 5. RFC3339: 2006-01-02T15:04:05Z07:00 (GranularitySecond)
|
||||
// 6. 相对时间: 5h-ago, 3d-ago, 1w-ago, 1m-ago, 1y-ago (根据单位确定粒度)
|
||||
// 7. 自然语言: now (GranularitySecond), today, yesterday (GranularityDay)
|
||||
// 8. 年份: 2006 (GranularityYear)
|
||||
// 9. 月份: 200601, 2006-01 (GranularityMonth)
|
||||
// 10. 季度: 2006Q1, 2006Q2, 2006Q3, 2006Q4 (GranularityQuarter)
|
||||
// 11. 年月日时分: 200601021504 (GranularityMinute)
|
||||
func timeOf(str string) (t time.Time, g TimeGranularity, ok bool) {
|
||||
if str == "" {
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
|
||||
str = strings.TrimSpace(str)
|
||||
|
||||
// 处理自然语言时间
|
||||
switch strings.ToLower(str) {
|
||||
case "now":
|
||||
return time.Now(), GranularitySecond, true
|
||||
case "today":
|
||||
now := time.Now()
|
||||
return time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location()), GranularityDay, true
|
||||
case "yesterday":
|
||||
now := time.Now().AddDate(0, 0, -1)
|
||||
return time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location()), GranularityDay, true
|
||||
case "this-week":
|
||||
now := time.Now()
|
||||
weekday := int(now.Weekday())
|
||||
if weekday == 0 { // 周日
|
||||
weekday = 7
|
||||
}
|
||||
// 本周一
|
||||
monday := now.AddDate(0, 0, -(weekday - 1))
|
||||
return time.Date(monday.Year(), monday.Month(), monday.Day(), 0, 0, 0, 0, now.Location()), GranularityDay, true
|
||||
case "last-week":
|
||||
now := time.Now()
|
||||
weekday := int(now.Weekday())
|
||||
if weekday == 0 { // 周日
|
||||
weekday = 7
|
||||
}
|
||||
// 上周一
|
||||
lastMonday := now.AddDate(0, 0, -(weekday-1)-7)
|
||||
return time.Date(lastMonday.Year(), lastMonday.Month(), lastMonday.Day(), 0, 0, 0, 0, now.Location()), GranularityDay, true
|
||||
case "this-month":
|
||||
now := time.Now()
|
||||
return time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, now.Location()), GranularityMonth, true
|
||||
case "last-month":
|
||||
now := time.Now()
|
||||
return time.Date(now.Year(), now.Month()-1, 1, 0, 0, 0, 0, now.Location()), GranularityMonth, true
|
||||
case "this-year":
|
||||
now := time.Now()
|
||||
return time.Date(now.Year(), 1, 1, 0, 0, 0, 0, now.Location()), GranularityYear, true
|
||||
case "last-year":
|
||||
now := time.Now()
|
||||
return time.Date(now.Year()-1, 1, 1, 0, 0, 0, 0, now.Location()), GranularityYear, true
|
||||
case "all":
|
||||
// 返回零值时间
|
||||
return time.Time{}, GranularityYear, true
|
||||
}
|
||||
|
||||
// 处理相对时间: 5h-ago, 3d-ago, 1w-ago, 1m-ago, 1y-ago
|
||||
if strings.HasSuffix(str, "-ago") {
|
||||
str = strings.TrimSuffix(str, "-ago")
|
||||
|
||||
// 特殊处理 0d-ago 为当天开始
|
||||
if str == "0d" {
|
||||
now := time.Now()
|
||||
return time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location()), GranularityDay, true
|
||||
}
|
||||
|
||||
// 解析数字和单位
|
||||
re := regexp.MustCompile(`^(\d+)([hdwmy])$`)
|
||||
matches := re.FindStringSubmatch(str)
|
||||
if len(matches) == 3 {
|
||||
num, err := strconv.Atoi(matches[1])
|
||||
if err != nil {
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
|
||||
// 确保数字是正数
|
||||
if num <= 0 {
|
||||
// 对于0d-ago已经特殊处理,其他0或负数都是无效的
|
||||
if num == 0 && matches[2] != "d" {
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
var resultTime time.Time
|
||||
var granularity TimeGranularity
|
||||
|
||||
switch matches[2] {
|
||||
case "h": // 小时
|
||||
resultTime = now.Add(-time.Duration(num) * time.Hour)
|
||||
granularity = GranularityHour
|
||||
case "d": // 天
|
||||
resultTime = now.AddDate(0, 0, -num)
|
||||
granularity = GranularityDay
|
||||
case "w": // 周
|
||||
resultTime = now.AddDate(0, 0, -num*7)
|
||||
granularity = GranularityDay
|
||||
case "m": // 月
|
||||
resultTime = now.AddDate(0, -num, 0)
|
||||
granularity = GranularityMonth
|
||||
case "y": // 年
|
||||
resultTime = now.AddDate(-num, 0, 0)
|
||||
granularity = GranularityYear
|
||||
default:
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
|
||||
return resultTime, granularity, true
|
||||
}
|
||||
|
||||
// 尝试标准 duration 解析
|
||||
dur, err := time.ParseDuration(str)
|
||||
if err == nil {
|
||||
// 根据duration单位确定粒度
|
||||
hours := dur.Hours()
|
||||
if hours < 1 {
|
||||
return time.Now().Add(-dur), GranularitySecond, true
|
||||
} else if hours < 24 {
|
||||
return time.Now().Add(-dur), GranularityHour, true
|
||||
} else {
|
||||
return time.Now().Add(-dur), GranularityDay, true
|
||||
}
|
||||
}
|
||||
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
|
||||
// 处理季度: 2006Q1, 2006Q2, 2006Q3, 2006Q4
|
||||
if matched, _ := regexp.MatchString(`^\d{4}Q[1-4]$`, str); matched {
|
||||
re := regexp.MustCompile(`^(\d{4})Q([1-4])$`)
|
||||
matches := re.FindStringSubmatch(str)
|
||||
if len(matches) == 3 {
|
||||
year, _ := strconv.Atoi(matches[1])
|
||||
quarter, _ := strconv.Atoi(matches[2])
|
||||
|
||||
// 验证年份范围
|
||||
if year < 1970 || year > 9999 {
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
|
||||
// 计算季度的开始月份
|
||||
startMonth := time.Month((quarter-1)*3 + 1)
|
||||
|
||||
return time.Date(year, startMonth, 1, 0, 0, 0, 0, time.Local), GranularityQuarter, true
|
||||
}
|
||||
}
|
||||
|
||||
// 处理年份: 2006
|
||||
if len(str) == 4 && isDigitsOnly(str) {
|
||||
year, err := strconv.Atoi(str)
|
||||
if err == nil && year >= 1970 && year <= 9999 {
|
||||
return time.Date(year, 1, 1, 0, 0, 0, 0, time.Local), GranularityYear, true
|
||||
}
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
|
||||
// 处理月份: 200601 或 2006-01
|
||||
if (len(str) == 6 && isDigitsOnly(str)) || (len(str) == 7 && strings.Count(str, "-") == 1) {
|
||||
var year, month int
|
||||
var err error
|
||||
|
||||
if len(str) == 6 && isDigitsOnly(str) {
|
||||
year, err = strconv.Atoi(str[0:4])
|
||||
if err != nil {
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
month, err = strconv.Atoi(str[4:6])
|
||||
if err != nil {
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
} else { // 2006-01
|
||||
parts := strings.Split(str, "-")
|
||||
if len(parts) != 2 {
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
year, err = strconv.Atoi(parts[0])
|
||||
if err != nil {
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
month, err = strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
}
|
||||
|
||||
if year < 1970 || year > 9999 || month < 1 || month > 12 {
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
|
||||
return time.Date(year, time.Month(month), 1, 0, 0, 0, 0, time.Local), GranularityMonth, true
|
||||
}
|
||||
|
||||
// 处理日期格式: 20060102 或 2006-01-02
|
||||
if len(str) == 8 && isDigitsOnly(str) {
|
||||
// 验证年月日
|
||||
year, _ := strconv.Atoi(str[0:4])
|
||||
month, _ := strconv.Atoi(str[4:6])
|
||||
day, _ := strconv.Atoi(str[6:8])
|
||||
|
||||
if year < 1970 || year > 9999 || month < 1 || month > 12 || day < 1 || day > 31 {
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
|
||||
// 进一步验证日期是否有效
|
||||
if !isValidDate(year, month, day) {
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
|
||||
// 直接构造时间
|
||||
result := time.Date(year, time.Month(month), day, 0, 0, 0, 0, time.Local)
|
||||
return result, GranularityDay, true
|
||||
} else if len(str) == 10 && strings.Count(str, "-") == 2 {
|
||||
// 验证年月日
|
||||
parts := strings.Split(str, "-")
|
||||
if len(parts) != 3 {
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
|
||||
year, err1 := strconv.Atoi(parts[0])
|
||||
month, err2 := strconv.Atoi(parts[1])
|
||||
day, err3 := strconv.Atoi(parts[2])
|
||||
|
||||
if err1 != nil || err2 != nil || err3 != nil {
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
|
||||
if year < 1970 || year > 9999 || month < 1 || month > 12 || day < 1 || day > 31 {
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
|
||||
// 进一步验证日期是否有效
|
||||
if !isValidDate(year, month, day) {
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
|
||||
// 直接构造时间
|
||||
result := time.Date(year, time.Month(month), day, 0, 0, 0, 0, time.Local)
|
||||
return result, GranularityDay, true
|
||||
}
|
||||
|
||||
// 处理年月日时分: 200601021504
|
||||
if len(str) == 12 && isDigitsOnly(str) {
|
||||
year, _ := strconv.Atoi(str[0:4])
|
||||
month, _ := strconv.Atoi(str[4:6])
|
||||
day, _ := strconv.Atoi(str[6:8])
|
||||
hour, _ := strconv.Atoi(str[8:10])
|
||||
minute, _ := strconv.Atoi(str[10:12])
|
||||
|
||||
if year < 1970 || year > 9999 || month < 1 || month > 12 || day < 1 || day > 31 ||
|
||||
hour < 0 || hour > 23 || minute < 0 || minute > 59 {
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
|
||||
// 进一步验证日期是否有效
|
||||
if !isValidDate(year, month, day) {
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
|
||||
// 直接构造时间
|
||||
result := time.Date(year, time.Month(month), day, hour, minute, 0, 0, time.Local)
|
||||
return result, GranularityMinute, true
|
||||
}
|
||||
|
||||
// 处理带时间的日期: 20060102/15:04 或 2006-01-02/15:04
|
||||
if strings.Contains(str, "/") {
|
||||
parts := strings.Split(str, "/")
|
||||
if len(parts) != 2 {
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
|
||||
datePart := parts[0]
|
||||
timePart := parts[1]
|
||||
|
||||
// 验证日期部分
|
||||
var year, month, day int
|
||||
var err1, err2, err3 error
|
||||
|
||||
if len(datePart) == 8 && isDigitsOnly(datePart) {
|
||||
year, err1 = strconv.Atoi(datePart[0:4])
|
||||
month, err2 = strconv.Atoi(datePart[4:6])
|
||||
day, err3 = strconv.Atoi(datePart[6:8])
|
||||
} else if len(datePart) == 10 && strings.Count(datePart, "-") == 2 {
|
||||
dateParts := strings.Split(datePart, "-")
|
||||
if len(dateParts) != 3 {
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
year, err1 = strconv.Atoi(dateParts[0])
|
||||
month, err2 = strconv.Atoi(dateParts[1])
|
||||
day, err3 = strconv.Atoi(dateParts[2])
|
||||
} else {
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
|
||||
if err1 != nil || err2 != nil || err3 != nil {
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
|
||||
if year < 1970 || year > 9999 || month < 1 || month > 12 || day < 1 || day > 31 {
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
|
||||
// 进一步验证日期是否有效
|
||||
if !isValidDate(year, month, day) {
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
|
||||
// 验证时间部分
|
||||
if !regexp.MustCompile(`^\d{2}:\d{2}$`).MatchString(timePart) {
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
|
||||
timeParts := strings.Split(timePart, ":")
|
||||
hour, err1 := strconv.Atoi(timeParts[0])
|
||||
minute, err2 := strconv.Atoi(timeParts[1])
|
||||
|
||||
if err1 != nil || err2 != nil {
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
|
||||
if hour < 0 || hour > 23 || minute < 0 || minute > 59 {
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
|
||||
// 直接构造时间
|
||||
result := time.Date(year, time.Month(month), day, hour, minute, 0, 0, time.Local)
|
||||
return result, GranularityMinute, true
|
||||
}
|
||||
|
||||
// 处理完整时间: 20060102150405
|
||||
if len(str) == 14 && isDigitsOnly(str) {
|
||||
year, _ := strconv.Atoi(str[0:4])
|
||||
month, _ := strconv.Atoi(str[4:6])
|
||||
day, _ := strconv.Atoi(str[6:8])
|
||||
hour, _ := strconv.Atoi(str[8:10])
|
||||
minute, _ := strconv.Atoi(str[10:12])
|
||||
second, _ := strconv.Atoi(str[12:14])
|
||||
|
||||
if year < 1970 || year > 9999 || month < 1 || month > 12 || day < 1 || day > 31 ||
|
||||
hour < 0 || hour > 23 || minute < 0 || minute > 59 || second < 0 || second > 59 {
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
|
||||
// 进一步验证日期是否有效
|
||||
if !isValidDate(year, month, day) {
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
|
||||
// 直接构造时间
|
||||
result := time.Date(year, time.Month(month), day, hour, minute, second, 0, time.Local)
|
||||
return result, GranularitySecond, true
|
||||
}
|
||||
|
||||
// 处理时间戳(秒)
|
||||
if isDigitsOnly(str) {
|
||||
n, err := strconv.ParseInt(str, 10, 64)
|
||||
if err == nil {
|
||||
// 检查是否是合理的时间戳范围
|
||||
if n >= 1000000000 && n <= 253402300799 { // 2001年到2286年的秒级时间戳
|
||||
return time.Unix(n, 0), GranularitySecond, true
|
||||
}
|
||||
}
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
|
||||
// 处理 RFC3339: 2006-01-02T15:04:05Z07:00
|
||||
if strings.Contains(str, "T") && (strings.Contains(str, "Z") || strings.Contains(str, "+") || strings.Contains(str, "-")) {
|
||||
t, err := time.Parse(time.RFC3339, str)
|
||||
if err != nil {
|
||||
// 尝试不带秒的格式
|
||||
t, err = time.Parse("2006-01-02T15:04Z07:00", str)
|
||||
}
|
||||
if err == nil {
|
||||
return t, GranularitySecond, true
|
||||
}
|
||||
}
|
||||
|
||||
// 排除所有其他不支持的格式
|
||||
return time.Time{}, GranularityUnknown, false
|
||||
}
|
||||
|
||||
// TimeOf 解析各种格式的时间点
|
||||
// 支持以下格式:
|
||||
// 1. 时间戳(秒): 1609459200
|
||||
// 2. 标准日期: 20060102, 2006-01-02
|
||||
// 3. 带时间的日期: 20060102/15:04, 2006-01-02/15:04
|
||||
// 4. 完整时间: 20060102150405
|
||||
// 5. RFC3339: 2006-01-02T15:04:05Z07:00
|
||||
// 6. 相对时间: 5h-ago, 3d-ago, 1w-ago, 1m-ago, 1y-ago (小时、天、周、月、年)
|
||||
// 7. 自然语言: now, today, yesterday
|
||||
// 8. 年份: 2006
|
||||
// 9. 月份: 200601, 2006-01
|
||||
// 10. 季度: 2006Q1, 2006Q2, 2006Q3, 2006Q4
|
||||
// 11. 年月日时分: 200601021504
|
||||
func TimeOf(str string) (t time.Time, ok bool) {
|
||||
t, _, ok = timeOf(str)
|
||||
return
|
||||
}
|
||||
|
||||
// TimeRangeOf 解析各种格式的时间范围
|
||||
// 支持以下格式:
|
||||
// 1. 单个时间点: 根据时间粒度确定合适的时间范围
|
||||
// - 精确到秒/分钟/小时: 扩展为当天范围
|
||||
// - 精确到天: 当天 00:00:00 ~ 23:59:59
|
||||
// - 精确到月: 当月第一天 ~ 最后一天
|
||||
// - 精确到季度: 季度第一天 ~ 最后一天
|
||||
// - 精确到年: 当年第一天 ~ 最后一天
|
||||
//
|
||||
// 2. 时间区间: 2006-01-01~2006-01-31, 2006-01-01,2006-01-31, 2006-01-01 to 2006-01-31
|
||||
// 3. 相对时间: last-7d, last-30d, last-3m, last-1y (最近7天、30天、3个月、1年)
|
||||
// 4. 特定时间段: today, yesterday, this-week, last-week, this-month, last-month, this-year, last-year
|
||||
// 5. all: 表示所有时间
|
||||
func TimeRangeOf(str string) (start, end time.Time, ok bool) {
|
||||
if str == "" {
|
||||
return time.Time{}, time.Time{}, false
|
||||
}
|
||||
|
||||
str = strings.TrimSpace(str)
|
||||
|
||||
// 处理 all 特殊情况
|
||||
if strings.ToLower(str) == "all" {
|
||||
start = time.Date(1970, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
end = time.Date(9999, 12, 31, 23, 59, 59, 999999999, time.UTC)
|
||||
return start, end, true
|
||||
}
|
||||
|
||||
// 处理相对时间范围: last-7d, last-30d, last-3m, last-1y
|
||||
if matched, _ := regexp.MatchString(`^last-\d+[dwmy]$`, str); matched {
|
||||
re := regexp.MustCompile(`^last-(\d+)([dwmy])$`)
|
||||
matches := re.FindStringSubmatch(str)
|
||||
if len(matches) == 3 {
|
||||
num, err := strconv.Atoi(matches[1])
|
||||
if err != nil || num <= 0 {
|
||||
return time.Time{}, time.Time{}, false
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
end = time.Date(now.Year(), now.Month(), now.Day(), 23, 59, 59, 999999999, now.Location())
|
||||
|
||||
switch matches[2] {
|
||||
case "d": // 天
|
||||
start = now.AddDate(0, 0, -num)
|
||||
start = time.Date(start.Year(), start.Month(), start.Day(), 0, 0, 0, 0, start.Location())
|
||||
return start, end, true
|
||||
case "w": // 周
|
||||
start = now.AddDate(0, 0, -num*7)
|
||||
start = time.Date(start.Year(), start.Month(), start.Day(), 0, 0, 0, 0, start.Location())
|
||||
return start, end, true
|
||||
case "m": // 月
|
||||
start = now.AddDate(0, -num, 0)
|
||||
start = time.Date(start.Year(), start.Month(), start.Day(), 0, 0, 0, 0, start.Location())
|
||||
return start, end, true
|
||||
case "y": // 年
|
||||
start = now.AddDate(-num, 0, 0)
|
||||
start = time.Date(start.Year(), start.Month(), start.Day(), 0, 0, 0, 0, start.Location())
|
||||
return start, end, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理时间区间: 2006-01-01~2006-01-31, 2006-01-01,2006-01-31, 2006-01-01 to 2006-01-31
|
||||
separators := []string{"~", ",", " to "}
|
||||
for _, sep := range separators {
|
||||
if strings.Contains(str, sep) {
|
||||
parts := strings.Split(str, sep)
|
||||
if len(parts) == 2 {
|
||||
startTime, startGran, startOk := timeOf(strings.TrimSpace(parts[0]))
|
||||
endTime, endGran, endOk := timeOf(strings.TrimSpace(parts[1]))
|
||||
|
||||
if startOk && endOk {
|
||||
// 根据粒度调整时间范围
|
||||
start = adjustStartTime(startTime, startGran)
|
||||
end = adjustEndTime(endTime, endGran)
|
||||
|
||||
// 确保开始时间早于结束时间
|
||||
if start.After(end) {
|
||||
// 正确交换开始和结束时间
|
||||
start, end = adjustStartTime(endTime, endGran), adjustEndTime(startTime, startGran)
|
||||
}
|
||||
|
||||
return start, end, true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理单个时间点,根据粒度确定合适的时间范围
|
||||
t, g, ok := timeOf(str)
|
||||
if ok {
|
||||
switch g {
|
||||
case GranularitySecond, GranularityMinute, GranularityHour:
|
||||
// 精确到秒/分钟/小时的时间点,扩展为当天范围
|
||||
start = time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location())
|
||||
end = time.Date(t.Year(), t.Month(), t.Day(), 23, 59, 59, 999999999, t.Location())
|
||||
case GranularityDay:
|
||||
// 精确到天的时间点
|
||||
start = time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location())
|
||||
end = time.Date(t.Year(), t.Month(), t.Day(), 23, 59, 59, 999999999, t.Location())
|
||||
case GranularityMonth:
|
||||
// 精确到月的时间点
|
||||
start = time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, t.Location())
|
||||
end = time.Date(t.Year(), t.Month()+1, 0, 23, 59, 59, 999999999, t.Location())
|
||||
case GranularityQuarter:
|
||||
// 精确到季度的时间点
|
||||
quarter := (t.Month()-1)/3 + 1
|
||||
startMonth := time.Month((int(quarter)-1)*3 + 1)
|
||||
endMonth := startMonth + 2
|
||||
start = time.Date(t.Year(), startMonth, 1, 0, 0, 0, 0, t.Location())
|
||||
end = time.Date(t.Year(), endMonth+1, 0, 23, 59, 59, 999999999, t.Location())
|
||||
case GranularityYear:
|
||||
// 精确到年的时间点
|
||||
start = time.Date(t.Year(), 1, 1, 0, 0, 0, 0, t.Location())
|
||||
end = time.Date(t.Year(), 12, 31, 23, 59, 59, 999999999, t.Location())
|
||||
}
|
||||
return start, end, true
|
||||
}
|
||||
|
||||
return time.Time{}, time.Time{}, false
|
||||
}
|
||||
|
||||
// adjustStartTime 根据时间粒度调整开始时间
|
||||
func adjustStartTime(t time.Time, g TimeGranularity) time.Time {
|
||||
switch g {
|
||||
case GranularitySecond, GranularityMinute, GranularityHour:
|
||||
// 对于精确到秒/分钟/小时的时间,保持原样
|
||||
return t
|
||||
case GranularityDay:
|
||||
// 精确到天,设置为当天开始
|
||||
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location())
|
||||
case GranularityMonth:
|
||||
// 精确到月,设置为当月第一天
|
||||
return time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, t.Location())
|
||||
case GranularityQuarter:
|
||||
// 精确到季度,设置为季度第一天
|
||||
quarter := (t.Month()-1)/3 + 1
|
||||
startMonth := time.Month((int(quarter)-1)*3 + 1)
|
||||
return time.Date(t.Year(), startMonth, 1, 0, 0, 0, 0, t.Location())
|
||||
case GranularityYear:
|
||||
// 精确到年,设置为当年第一天
|
||||
return time.Date(t.Year(), 1, 1, 0, 0, 0, 0, t.Location())
|
||||
default:
|
||||
// 未知粒度,默认为当天开始
|
||||
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, t.Location())
|
||||
}
|
||||
}
|
||||
|
||||
// adjustEndTime 根据时间粒度调整结束时间
|
||||
func adjustEndTime(t time.Time, g TimeGranularity) time.Time {
|
||||
switch g {
|
||||
case GranularitySecond, GranularityMinute, GranularityHour:
|
||||
// 对于精确到秒/分钟/小时的时间,保持原样
|
||||
return t
|
||||
case GranularityDay:
|
||||
// 精确到天,设置为当天结束
|
||||
return time.Date(t.Year(), t.Month(), t.Day(), 23, 59, 59, 999999999, t.Location())
|
||||
case GranularityMonth:
|
||||
// 精确到月,设置为当月最后一天
|
||||
return time.Date(t.Year(), t.Month()+1, 0, 23, 59, 59, 999999999, t.Location())
|
||||
case GranularityQuarter:
|
||||
// 精确到季度,设置为季度最后一天
|
||||
quarter := (t.Month()-1)/3 + 1
|
||||
startMonth := time.Month((int(quarter)-1)*3 + 1)
|
||||
endMonth := startMonth + 2
|
||||
return time.Date(t.Year(), endMonth+1, 0, 23, 59, 59, 999999999, t.Location())
|
||||
case GranularityYear:
|
||||
// 精确到年,设置为当年最后一天
|
||||
return time.Date(t.Year(), 12, 31, 23, 59, 59, 999999999, t.Location())
|
||||
default:
|
||||
// 未知粒度,默认为当天结束
|
||||
return time.Date(t.Year(), t.Month(), t.Day(), 23, 59, 59, 999999999, t.Location())
|
||||
}
|
||||
}
|
||||
|
||||
// isDigitsOnly 检查字符串是否只包含数字
|
||||
func isDigitsOnly(s string) bool {
|
||||
for _, c := range s {
|
||||
if c < '0' || c > '9' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return len(s) > 0
|
||||
}
|
||||
|
||||
// isValidDate 检查日期是否有效
|
||||
func isValidDate(year, month, day int) bool {
|
||||
// 检查月份的天数
|
||||
daysInMonth := 31
|
||||
|
||||
switch month {
|
||||
case 4, 6, 9, 11:
|
||||
daysInMonth = 30
|
||||
case 2:
|
||||
// 闰年判断
|
||||
if (year%4 == 0 && year%100 != 0) || year%400 == 0 {
|
||||
daysInMonth = 29
|
||||
} else {
|
||||
daysInMonth = 28
|
||||
}
|
||||
}
|
||||
|
||||
return day <= daysInMonth
|
||||
}
|
||||
|
||||
func PerfectTimeFormat(start time.Time, end time.Time) string {
|
||||
endTime := end
|
||||
|
||||
// 如果结束时间是某一天的 0 点整,将其减去 1 秒,视为前一天的结束
|
||||
if endTime.Hour() == 0 && endTime.Minute() == 0 && endTime.Second() == 0 && endTime.Nanosecond() == 0 {
|
||||
endTime = endTime.Add(-time.Second) // 减去 1 秒
|
||||
}
|
||||
|
||||
// 判断是否跨年
|
||||
if start.Year() != endTime.Year() {
|
||||
return "2006-01-02 15:04:05" // 完整格式,包含年月日时分秒
|
||||
}
|
||||
|
||||
// 判断是否跨天(但在同一年内)
|
||||
if start.YearDay() != endTime.YearDay() {
|
||||
return "01-02 15:04:05" // 月日时分秒格式
|
||||
}
|
||||
|
||||
// 在同一天内
|
||||
return "15:04:05" // 只显示时分秒
|
||||
}
|
||||
1004
pkg/util/time_test.go
Normal file
1004
pkg/util/time_test.go
Normal file
File diff suppressed because it is too large
Load Diff
11
pkg/util/zstd/zstd.go
Normal file
11
pkg/util/zstd/zstd.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package zstd
|
||||
|
||||
import (
|
||||
"github.com/klauspost/compress/zstd"
|
||||
)
|
||||
|
||||
var decoder, _ = zstd.NewReader(nil, zstd.WithDecoderConcurrency(0))
|
||||
|
||||
func Decompress(src []byte) ([]byte, error) {
|
||||
return decoder.DecodeAll(src, nil)
|
||||
}
|
||||
32
pkg/version/version.go
Normal file
32
pkg/version/version.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package version
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
Version = "(dev)"
|
||||
buildInfo = debug.BuildInfo{}
|
||||
)
|
||||
|
||||
func init() {
|
||||
if bi, ok := debug.ReadBuildInfo(); ok {
|
||||
buildInfo = *bi
|
||||
if len(bi.Main.Version) > 0 {
|
||||
Version = bi.Main.Version
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func GetMore(mod bool) string {
|
||||
if mod {
|
||||
mod := buildInfo.String()
|
||||
if len(mod) > 0 {
|
||||
return fmt.Sprintf("\t%s\n", strings.ReplaceAll(mod[:len(mod)-1], "\n", "\n\t"))
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("version %s %s %s/%s\n", Version, runtime.Version(), runtime.GOOS, runtime.GOARCH)
|
||||
}
|
||||
Reference in New Issue
Block a user