237 lines
7.4 KiB
Python
237 lines
7.4 KiB
Python
from fastapi import APIRouter, Depends, HTTPException, HTTPException, Query
|
||
from sqlalchemy.orm import Session
|
||
from typing import List, Optional, Dict, Any
|
||
from datetime import datetime, timedelta
|
||
import jwt
|
||
from passlib.context import CryptContext
|
||
from core.database import get_db
|
||
|
||
# Router on which the auth / user-management endpoints in this module are registered.
router = APIRouter()

# Password hashing context used by verify_password and to hash the demo users'
# passwords; bcrypt is the only configured scheme.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||
|
||
import os as _os

# JWT configuration.
# The HMAC signing secret is read from the JWT_SECRET_KEY environment variable.
# An unset or empty variable falls back to the placeholder below.
# WARNING: the placeholder is public, so anyone who knows it can forge tokens.
# Always set JWT_SECRET_KEY outside local development.
SECRET_KEY = _os.environ.get("JWT_SECRET_KEY") or "your-secret-key-here"

# Algorithm passed to jwt.encode when signing tokens.
ALGORITHM = "HS256"

# Lifetime, in minutes, of tokens issued by /login and /refresh.
ACCESS_TOKEN_EXPIRE_MINUTES = 30
|
||
|
||
# Mock user store, keyed by username. Passwords are hashed once, at import time.
# NOTE(review): hard-coded demo credentials ("admin123" / "operator123").
# Presumably a placeholder until a real user table exists; do not ship as-is.
USERS = dict(
    admin=dict(
        id=1,
        username="admin",
        email="admin@example.com",
        hashed_password=pwd_context.hash("admin123"),
        role="admin",
        permissions=["read", "write", "admin"],
    ),
    operator=dict(
        id=2,
        username="operator",
        email="operator@example.com",
        hashed_password=pwd_context.hash("operator123"),
        role="operator",
        permissions=["read", "write"],
    ),
)
|
||
|
||
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plaintext password against a stored hash.

    Delegates to the module-level ``pwd_context``.

    Returns:
        True when the password matches the hash, otherwise False.
    """
    is_match = pwd_context.verify(plain_password, hashed_password)
    return is_match
|
||
|
||
def get_user(username: str):
    """Look up a user record in the in-memory ``USERS`` store.

    Returns:
        The user dict for *username*, or None when no such user exists.
    """
    return USERS.get(username)
|
||
|
||
def authenticate_user(username: str, password: str):
    """Validate a username/password pair.

    Returns:
        The user dict from ``get_user`` when both the username and the
        password are correct, otherwise ``False``.
    """
    user = get_user(username)
    if not user:
        # Spend roughly the same hashing time as a real check. Without this,
        # unknown usernames return much faster, and the timing difference
        # reveals which accounts exist.
        pwd_context.dummy_verify()
        return False
    if not verify_password(password, user["hashed_password"]):
        return False
    return user
|
||
|
||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    """Create a signed JWT containing *data* plus an ``exp`` claim.

    Args:
        data: Claims to embed. It is copied, so the caller's dict is not mutated.
        expires_delta: Token lifetime. ``None`` means the 15-minute default.
            Any other value, including ``timedelta(0)``, is used as given.

    Returns:
        The token produced by ``jwt.encode`` with ``SECRET_KEY`` / ``ALGORITHM``.
    """
    from datetime import timezone

    to_encode = data.copy()
    # Compare against None instead of testing truthiness. timedelta(0) is falsy
    # but is still an explicit request, so it must not fall back to the default.
    if expires_delta is None:
        expires_delta = timedelta(minutes=15)
    # Use an aware UTC timestamp; datetime.utcnow() is deprecated since Python 3.12.
    to_encode["exp"] = datetime.now(timezone.utc) + expires_delta
    return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||
|
||
@router.post("/login", summary="用户登录")
async def login(
    username: str,
    password: str,
    db: Session = Depends(get_db)
):
    """Authenticate a user and issue a bearer access token.

    NOTE(review): ``username``/``password`` are plain ``str`` parameters, so
    FastAPI takes them from the query string. Credentials may then appear in
    URLs and access logs. Moving them to a form/JSON body would break existing
    clients, so that change is not made here.
    ``db`` is currently unused.

    Returns:
        Dict with ``access_token``, ``token_type`` ("bearer"), ``expires_in``
        (seconds) and the user's id, username, email and role.

    Raises:
        HTTPException: 401 on bad credentials; 500 on any unexpected error.
    """
    import logging

    try:
        user = authenticate_user(username, password)
        if not user:
            raise HTTPException(
                status_code=401,
                detail="用户名或密码错误",
                headers={"WWW-Authenticate": "Bearer"},
            )

        access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
        access_token = create_access_token(
            data={"sub": user["username"]}, expires_delta=access_token_expires
        )

        return {
            "access_token": access_token,
            "token_type": "bearer",
            "expires_in": ACCESS_TOKEN_EXPIRE_MINUTES * 60,
            "user": {
                "id": user["id"],
                "username": user["username"],
                "email": user["email"],
                "role": user["role"],
            },
        }
    except HTTPException:
        raise
    except Exception as e:
        # Log the full details server-side. Do not echo internal error text
        # (hashing/signing internals) back to an unauthenticated client.
        logging.getLogger(__name__).exception("Unexpected error during login")
        raise HTTPException(status_code=500, detail="登录失败") from e
|
||
|
||
@router.get("/profile", summary="获取用户信息")
async def get_user_profile(
    current_user: str = Depends(lambda: "admin"),  # TODO: implement real JWT verification
    db: Session = Depends(get_db)
):
    """Return the profile of the current user.

    NOTE(review): ``current_user`` comes from a placeholder dependency that
    always yields "admin", so every caller receives the admin profile.

    Returns:
        Dict with the user's id, username, email, role and permissions.

    Raises:
        HTTPException: 404 if the user is unknown; 500 on unexpected errors.
    """
    try:
        record = get_user(current_user)
        if not record:
            raise HTTPException(status_code=404, detail="用户不存在")
        return {
            field: record[field]
            for field in ("id", "username", "email", "role", "permissions")
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"获取用户信息失败: {str(e)}")
|
||
|
||
@router.post("/logout", summary="用户登出")
async def logout(
    current_user: str = Depends(lambda: "admin"),  # TODO: implement real JWT verification
    db: Session = Depends(get_db)
):
    """Log the current user out.

    This is currently a stub: it only returns a confirmation message, and the
    issued token stays valid until it expires.
    TODO: implement real logout (e.g. add the token to a blacklist).

    The handler only returns a constant dict, so there is no error path to
    convert into an HTTP 500.
    """
    return {"message": "登出成功"}
|
||
|
||
@router.post("/refresh", summary="刷新访问令牌")
async def refresh_token(
    current_user: str = Depends(lambda: "admin"),  # TODO: implement real JWT verification
    db: Session = Depends(get_db)
):
    """Issue a fresh access token for the current user.

    NOTE(review): ``current_user`` comes from a placeholder dependency that
    always yields "admin", so no existing token is checked before a new one
    is issued.

    Returns:
        Dict with ``access_token``, ``token_type`` ("bearer") and
        ``expires_in`` (seconds).

    Raises:
        HTTPException: 404 if the user is unknown; 500 on unexpected errors.
    """
    try:
        account = get_user(current_user)
        if not account:
            raise HTTPException(status_code=404, detail="用户不存在")

        lifetime = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
        claims = {"sub": account["username"]}
        new_token = create_access_token(data=claims, expires_delta=lifetime)

        response = dict(
            access_token=new_token,
            token_type="bearer",
            expires_in=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
        )
        return response
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"刷新令牌失败: {str(e)}")
|
||
|
||
@router.get("/users", summary="获取用户列表")
async def get_users(
    page: int = Query(1, ge=1, description="页码"),
    size: int = Query(20, ge=1, le=100, description="每页数量"),
    role: Optional[str] = Query(None, description="角色筛选"),
    db: Session = Depends(get_db)
):
    """Return one page of users, optionally filtered by role.

    TODO: query real users; the rows below are mock data.

    Returns:
        Dict with ``users`` (the requested page), ``total`` (count after
        filtering), ``page`` and ``size``.

    Raises:
        HTTPException: 500 on unexpected errors.
    """
    try:
        mock_rows = [
            dict(
                id=1,
                username="admin",
                email="admin@example.com",
                role="admin",
                status="active",
                created_at="2024-01-01T00:00:00Z",
            ),
            dict(
                id=2,
                username="operator",
                email="operator@example.com",
                role="operator",
                status="active",
                created_at="2024-01-02T00:00:00Z",
            ),
        ]

        # An empty or missing role means "no filtering".
        matching = (
            mock_rows if not role
            else [row for row in mock_rows if row["role"] == role]
        )

        offset = (page - 1) * size
        return {
            "users": matching[offset:offset + size],
            "total": len(matching),
            "page": page,
            "size": size,
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"获取用户列表失败: {str(e)}")
|
||
|
||
@router.post("/users", summary="创建用户")
async def create_user(
    user_data: Dict[str, Any],
    db: Session = Depends(get_db)
):
    """Create a new user (mock: nothing is persisted).

    TODO: implement real user creation. The id is fixed at 3 and the record
    is only echoed back.
    NOTE(review): ``user_data`` is unvalidated, and no caller permission is
    checked, so any caller can request e.g. role "admin".

    Returns:
        The new user dict. ``created_at`` is an ISO-8601 UTC timestamp in the
        same "YYYY-MM-DDTHH:MM:SSZ" form used by the mock rows in /users.

    Raises:
        HTTPException: 500 on unexpected errors.
    """
    from datetime import timezone

    try:
        new_user = {
            "id": 3,
            "username": user_data.get("username", "newuser"),
            "email": user_data.get("email", "newuser@example.com"),
            "role": user_data.get("role", "operator"),
            "status": "active",
            # Use explicit UTC instead of naive server-local time, which has
            # no timezone and would be interpreted differently by each client.
            "created_at": datetime.now(timezone.utc)
            .isoformat(timespec="seconds")
            .replace("+00:00", "Z"),
        }

        return new_user
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"创建用户失败: {str(e)}")