# audit_logger.py
"""
审计日志本地队列 + 后台上传 worker。
核心保证:事件一旦 log_use / log_login 返回,就已经 fsync 到本地 NDJSON 文件。
后台 worker 负责把本地队列异步上传到 MySQL;失败指数退避重试,成功后 compaction
重写队列文件删除已送达行。应用退出时 flush 一次尽量送达。
公开接口:
- init_audit_logger(db_config, queue_path, logs_dir): 启动单例
- get_audit_logger(): 获取单例(未初始化返回 None)
- AuditLogger.log_use(...)
- AuditLogger.log_login(...)
- AuditLogger.shutdown(timeout=5.0)
"""
from __future__ import annotations
import json
import logging
import os
import threading
import time
from datetime import datetime
from pathlib import Path
from typing import Optional
import pymysql
from PySide6.QtCore import QThread
logger = logging.getLogger(__name__)
_instance: Optional["AuditLogger"] = None
_instance_lock = threading.Lock()
def init_audit_logger(db_config: dict, queue_path: Path, logs_dir: Path) -> "AuditLogger":
"""在 preflight 通过后调用;幂等。"""
global _instance
with _instance_lock:
if _instance is None:
_instance = AuditLogger(db_config, queue_path, logs_dir)
_instance.start()
return _instance
def get_audit_logger() -> Optional["AuditLogger"]:
return _instance
class AuditLogger:
"""
对外门面。只负责:
1. 落盘(log_use / log_login, fsync 后返回)
2. 拉起/关闭 worker
真正上传逻辑在 _UploadWorker。
"""
def __init__(self, db_config: dict, queue_path: Path, logs_dir: Path):
self._db_config = db_config
self._queue_path = Path(queue_path)
self._logs_dir = Path(logs_dir)
self._file_lock = threading.Lock()
self._worker = _UploadWorker(
db_config=db_config,
queue_path=self._queue_path,
file_lock=self._file_lock,
)
def start(self) -> None:
self._queue_path.parent.mkdir(parents=True, exist_ok=True)
self._worker.start()
def log_use(
self,
user_name: str,
device_name: str,
prompt: str,
result_path: Optional[str],
status: str,
error_message: Optional[str],
model: Optional[str],
duration_ms: Optional[int],
finish_reason: Optional[str],
) -> None:
record = {
"kind": "use_log",
"ts": datetime.now().isoformat(timespec="seconds"),
"user_name": user_name or "未知用户",
"device_name": device_name or "未知设备",
"prompt": prompt or "",
"result_path": result_path,
"status": status,
"error_message": error_message,
"model": model,
"duration_ms": duration_ms,
"finish_reason": finish_reason,
}
self._append(record)
def log_login(
self,
user_name: str,
local_ip: Optional[str],
public_ip: Optional[str],
device_name: Optional[str],
) -> None:
record = {
"kind": "login_log",
"ts": datetime.now().isoformat(timespec="seconds"),
"user_name": user_name,
"local_ip": local_ip,
"public_ip": public_ip,
"device_name": device_name,
"login_time": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
}
self._append(record)
def shutdown(self, timeout: float = 5.0) -> None:
"""应用退出前调用,尽量 flush。"""
self._worker.stop(timeout)
def _append(self, record: dict) -> None:
"""落盘 + fsync。发生任何异常都不向上抛,但会落 error 日志(而不是 pass 吞掉)。"""
try:
line = json.dumps(record, ensure_ascii=False, default=str)
except Exception as e:
logger.error(f"审计事件序列化失败,已丢弃: {e}; record keys={list(record.keys())}")
return
try:
with self._file_lock:
with open(self._queue_path, "a", encoding="utf-8") as f:
f.write(line + "\n")
f.flush()
os.fsync(f.fileno())
self._worker.wake()
except Exception as e:
            # Not even the local disk is writable: a genuinely severe failure.
            # Degrade to the error log instead of raising.
            logger.error(f"audit event failed to persist to disk: {e}; event kept in error log as fallback: {line[:200]}")
class _UploadWorker(QThread):
"""后台线程:循环 drain 队列文件 → 批量 INSERT → compaction。"""
def __init__(self, db_config: dict, queue_path: Path, file_lock: threading.Lock):
super().__init__()
self._db_config = db_config
self._queue_path = Path(queue_path)
self._file_lock = file_lock
self._stop_event = threading.Event()
self._wake_event = threading.Event()
self._backoff = 1.0
    # --- external control ---
def wake(self) -> None:
self._wake_event.set()
def stop(self, timeout: float = 5.0) -> None:
self._stop_event.set()
self._wake_event.set()
self.wait(int(timeout * 1000))
    # --- main loop ---
def run(self) -> None:
logger.info("audit UploadWorker started")
while not self._stop_event.is_set():
try:
sent, unsent = self._drain_once()
except Exception as e:
logger.error(f"audit drain 抛出未预期异常: {e}", exc_info=True)
sent, unsent = 0, 1 # 当做失败处理
if unsent > 0:
self._backoff = min(self._backoff * 2, 300.0)
logger.debug(f"audit: unsent={unsent}, backoff={self._backoff}s")
else:
self._backoff = 1.0
            # Stop was requested during the drain: exit the loop (a final drain runs below).
if self._stop_event.is_set():
break
wait_s = self._backoff if unsent > 0 else 60.0
self._wake_event.wait(wait_s)
self._wake_event.clear()
        # one last drain before exiting
try:
self._drain_once()
except Exception:
pass
logger.info("audit UploadWorker stopped")
    # --- core drain ---
def _drain_once(self) -> tuple[int, int]:
"""
        Read a snapshot -> batch upload -> compaction.
        Returns (sent_count, unsent_count).
"""
        # 1. snapshot read
with self._file_lock:
if not self._queue_path.exists():
return 0, 0
eof_at_read = self._queue_path.stat().st_size
if eof_at_read == 0:
return 0, 0
with open(self._queue_path, "rb") as f:
head_bytes = f.read(eof_at_read)
        try:
            head_text = head_bytes.decode("utf-8")
        except UnicodeDecodeError as e:
            logger.error(f"audit queue file is not valid UTF-8, skipping this round: {e}")
            return 0, 1
        lines = [ln for ln in head_text.split("\n") if ln.strip()]
        if not lines:
            return 0, 0
        # 2. connect to DB + batch INSERT
try:
conn = pymysql.connect(
host=self._db_config["host"],
port=int(self._db_config.get("port", 3306)),
user=self._db_config["user"],
password=self._db_config["password"],
database=self._db_config["database"],
connect_timeout=5,
read_timeout=10,
write_timeout=10,
charset="utf8mb4",
)
except Exception as e:
logger.warning(f"audit connect 失败,稍后重试: {e}")
return 0, len(lines)
sent = 0
unsent_lines: list[str] = []
try:
with conn.cursor() as cursor:
for i, line in enumerate(lines):
try:
record = json.loads(line)
except json.JSONDecodeError as e:
logger.error(f"audit 队列出现坏行,已跳过: {e}; line={line[:120]!r}")
# 不保留到 unsent(避免无限重试坏行)
continue
                    try:
                        self._insert_one(cursor, record)
                        sent += 1
                    except ValueError as e:
                        # Raised by _insert_one for an unknown event kind: a poison
                        # record; skip it rather than requeueing it forever and
                        # blocking everything behind it.
                        logger.error(f"audit record skipped: {e}; line={line[:120]!r}")
                        continue
                    except Exception as e:
                        logger.warning(
                            f"audit INSERT failed (remaining lines stay queued): {type(e).__name__}: {e}"
                        )
                        unsent_lines = lines[i:]
                        break
conn.commit()
except Exception as e:
logger.warning(f"audit commit 失败: {e}")
unsent_lines = lines
sent = 0
finally:
try:
conn.close()
except Exception:
pass
        # 3. Compaction: rewrite the queue file = unsent_lines + any tail appended since the snapshot
with self._file_lock:
try:
                # tail appended after the snapshot read
current_size = self._queue_path.stat().st_size
tail = b""
if current_size > eof_at_read:
with open(self._queue_path, "rb") as f:
f.seek(eof_at_read)
tail = f.read()
with open(self._queue_path, "wb") as f:
for ln in unsent_lines:
f.write((ln + "\n").encode("utf-8"))
if tail:
f.write(tail)
f.flush()
os.fsync(f.fileno())
except Exception as e:
                # Compaction failure is not fatal: already-sent lines stay in the queue
                # and get resent on the next drain. Only the audit rows may then appear
                # twice (as new auto-increment ids); no business data is duplicated.
                logger.error(f"audit compaction failed: {e}", exc_info=True)
if sent > 0:
logger.info(f"audit drained: sent={sent}, unsent={len(unsent_lines)}")
return sent, len(unsent_lines)
    # --- row insertion ---
def _insert_one(self, cursor, record: dict) -> None:
kind = record.get("kind")
if kind == "use_log":
sql = """
INSERT INTO `nano_banana_user_use_log`
(`user_name`, `device_name`, `prompt`, `result_path`, `status`,
`error_message`, `model`, `duration_ms`, `finish_reason`)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
"""
cursor.execute(
sql,
(
record.get("user_name", "未知用户"),
record.get("device_name", "未知设备"),
record.get("prompt", ""),
record.get("result_path"),
record.get("status", "unknown"),
record.get("error_message"),
record.get("model"),
record.get("duration_ms"),
record.get("finish_reason"),
),
)
elif kind == "login_log":
sql = """
INSERT INTO `nano_banana_user_log`
(`user_name`, `local_ip`, `public_ip`, `device_name`, `login_time`)
VALUES (%s, %s, %s, %s, %s)
"""
login_time_val = record.get("login_time") or record.get("ts")
cursor.execute(
sql,
(
record.get("user_name"),
record.get("local_ip"),
record.get("public_ip"),
record.get("device_name"),
login_time_val,
),
)
else:
raise ValueError(f"未知审计事件 kind={kind!r}")
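# The INSERTs above assume tables roughly shaped like the sketch below. This is an
# assumption for reference only; the authoritative DDL lives on the MySQL server
# and may differ in types, lengths and defaults:
#
#   CREATE TABLE nano_banana_user_use_log (
#       id            BIGINT AUTO_INCREMENT PRIMARY KEY,
#       user_name     VARCHAR(64),
#       device_name   VARCHAR(128),
#       prompt        TEXT,
#       result_path   VARCHAR(512),
#       status        VARCHAR(32),
#       error_message TEXT,
#       model         VARCHAR(64),
#       duration_ms   INT,
#       finish_reason VARCHAR(64)
#   );
#
#   CREATE TABLE nano_banana_user_log (
#       id          BIGINT AUTO_INCREMENT PRIMARY KEY,
#       user_name   VARCHAR(64),
#       local_ip    VARCHAR(45),
#       public_ip   VARCHAR(45),
#       device_name VARCHAR(128),
#       login_time  DATETIME
#   );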