Some checks failed
continuous-integration/drone/push Build is failing
- 后端:钉钉 OAuth 认证服务 - 后端:系统设置 API(钉钉配置) - 前端:登录页钉钉扫码入口 - 前端:系统设置页面 - 数据库迁移脚本
434 lines
14 KiB
Python
434 lines
14 KiB
Python
"""
|
||
用户服务
|
||
"""
|
||
|
||
from datetime import datetime
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
from sqlalchemy import and_, or_, select, func
|
||
from sqlalchemy.exc import IntegrityError
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy.orm import selectinload
|
||
|
||
from app.core.exceptions import ConflictError, NotFoundError
|
||
from app.core.logger import logger
|
||
from app.core.security import get_password_hash, verify_password
|
||
from app.models.user import Team, User, user_teams
|
||
from app.schemas.user import UserCreate, UserFilter, UserUpdate
|
||
from app.services.base_service import BaseService
|
||
|
||
|
||
class UserService(BaseService[User]):
|
||
"""用户服务"""
|
||
|
||
def __init__(self, db: AsyncSession):
|
||
super().__init__(User)
|
||
self.db = db
|
||
|
||
async def get_by_id(self, user_id: int) -> Optional[User]:
|
||
"""根据ID获取用户"""
|
||
result = await self.db.execute(
|
||
select(User).where(User.id == user_id, User.is_deleted == False)
|
||
)
|
||
return result.scalar_one_or_none()
|
||
|
||
async def get_by_username(self, username: str) -> Optional[User]:
|
||
"""根据用户名获取用户"""
|
||
result = await self.db.execute(
|
||
select(User).where(
|
||
User.username == username,
|
||
User.is_deleted == False,
|
||
)
|
||
)
|
||
return result.scalar_one_or_none()
|
||
|
||
async def get_by_email(self, email: str) -> Optional[User]:
|
||
"""根据邮箱获取用户"""
|
||
result = await self.db.execute(
|
||
select(User).where(
|
||
User.email == email,
|
||
User.is_deleted == False,
|
||
)
|
||
)
|
||
return result.scalar_one_or_none()
|
||
|
||
async def get_by_phone(self, phone: str) -> Optional[User]:
|
||
"""根据手机号获取用户"""
|
||
result = await self.db.execute(
|
||
select(User).where(
|
||
User.phone == phone,
|
||
User.is_deleted == False,
|
||
)
|
||
)
|
||
return result.scalar_one_or_none()
|
||
|
||
async def get_by_dingtalk_id(self, dingtalk_id: str) -> Optional[User]:
|
||
"""根据钉钉用户ID获取用户"""
|
||
result = await self.db.execute(
|
||
select(User).where(
|
||
User.dingtalk_id == dingtalk_id,
|
||
User.is_deleted == False,
|
||
)
|
||
)
|
||
return result.scalar_one_or_none()
|
||
|
||
async def _check_username_exists_all(self, username: str) -> Optional[User]:
|
||
"""
|
||
检查用户名是否已存在(包括已删除的用户)
|
||
用于创建用户时检查唯一性约束
|
||
"""
|
||
result = await self.db.execute(
|
||
select(User).where(User.username == username)
|
||
)
|
||
return result.scalar_one_or_none()
|
||
|
||
async def _check_email_exists_all(self, email: str) -> Optional[User]:
|
||
"""
|
||
检查邮箱是否已存在(包括已删除的用户)
|
||
用于创建用户时检查唯一性约束
|
||
"""
|
||
result = await self.db.execute(
|
||
select(User).where(User.email == email)
|
||
)
|
||
return result.scalar_one_or_none()
|
||
|
||
async def _check_phone_exists_all(self, phone: str) -> Optional[User]:
|
||
"""
|
||
检查手机号是否已存在(包括已删除的用户)
|
||
用于创建用户时检查唯一性约束
|
||
"""
|
||
result = await self.db.execute(
|
||
select(User).where(User.phone == phone)
|
||
)
|
||
return result.scalar_one_or_none()
|
||
|
||
async def create_user(
|
||
self,
|
||
*,
|
||
obj_in: UserCreate,
|
||
created_by: Optional[int] = None,
|
||
) -> User:
|
||
"""创建用户"""
|
||
# 检查用户名是否已存在(包括已删除的用户,防止唯一键冲突)
|
||
existing_user = await self._check_username_exists_all(obj_in.username)
|
||
if existing_user:
|
||
if existing_user.is_deleted:
|
||
raise ConflictError(f"用户名 {obj_in.username} 已被使用(历史用户),请更换其他用户名")
|
||
else:
|
||
raise ConflictError(f"用户名 {obj_in.username} 已存在")
|
||
|
||
# 检查邮箱是否已存在(包括已删除的用户)
|
||
if obj_in.email:
|
||
existing_email = await self._check_email_exists_all(obj_in.email)
|
||
if existing_email:
|
||
if existing_email.is_deleted:
|
||
raise ConflictError(f"邮箱 {obj_in.email} 已被使用(历史用户),请更换其他邮箱")
|
||
else:
|
||
raise ConflictError(f"邮箱 {obj_in.email} 已存在")
|
||
|
||
# 检查手机号是否已存在(包括已删除的用户)
|
||
if obj_in.phone:
|
||
existing_phone = await self._check_phone_exists_all(obj_in.phone)
|
||
if existing_phone:
|
||
if existing_phone.is_deleted:
|
||
raise ConflictError(f"手机号 {obj_in.phone} 已被使用(历史用户),请更换其他手机号")
|
||
else:
|
||
raise ConflictError(f"手机号 {obj_in.phone} 已存在")
|
||
|
||
# 创建用户数据
|
||
user_data = obj_in.model_dump(exclude={"password"})
|
||
user_data["hashed_password"] = get_password_hash(obj_in.password)
|
||
# 注意:User模型不包含created_by字段,该信息记录在日志中
|
||
# user_data["created_by"] = created_by
|
||
|
||
try:
|
||
# 创建用户
|
||
user = await self.create(db=self.db, obj_in=user_data)
|
||
except IntegrityError as e:
|
||
# 捕获数据库唯一键冲突异常,返回友好错误信息
|
||
await self.db.rollback()
|
||
error_msg = str(e.orig) if e.orig else str(e)
|
||
logger.warning(
|
||
"创建用户时发生唯一键冲突",
|
||
username=obj_in.username,
|
||
email=obj_in.email,
|
||
error=error_msg,
|
||
)
|
||
if "username" in error_msg.lower():
|
||
raise ConflictError(f"用户名 {obj_in.username} 已被占用,请更换其他用户名")
|
||
elif "email" in error_msg.lower():
|
||
raise ConflictError(f"邮箱 {obj_in.email} 已被占用,请更换其他邮箱")
|
||
elif "phone" in error_msg.lower():
|
||
raise ConflictError(f"手机号 {obj_in.phone} 已被占用,请更换其他手机号")
|
||
else:
|
||
raise ConflictError(f"创建用户失败:数据冲突,请检查用户名、邮箱或手机号是否重复")
|
||
|
||
# 记录日志
|
||
logger.info(
|
||
"用户创建成功",
|
||
user_id=user.id,
|
||
username=user.username,
|
||
role=user.role,
|
||
created_by=created_by,
|
||
)
|
||
|
||
return user
|
||
|
||
async def update_user(
|
||
self,
|
||
*,
|
||
user_id: int,
|
||
obj_in: UserUpdate,
|
||
updated_by: Optional[int] = None,
|
||
) -> User:
|
||
"""更新用户"""
|
||
user = await self.get_by_id(user_id)
|
||
if not user:
|
||
raise NotFoundError("用户不存在")
|
||
|
||
# 如果更新邮箱,检查是否已存在
|
||
if obj_in.email and obj_in.email != user.email:
|
||
if await self.get_by_email(obj_in.email):
|
||
raise ConflictError(f"邮箱 {obj_in.email} 已存在")
|
||
|
||
# 如果更新手机号,检查是否已存在
|
||
if obj_in.phone and obj_in.phone != user.phone:
|
||
if await self.get_by_phone(obj_in.phone):
|
||
raise ConflictError(f"手机号 {obj_in.phone} 已存在")
|
||
|
||
# 更新用户数据
|
||
update_data = obj_in.model_dump(exclude_unset=True)
|
||
update_data["updated_by"] = updated_by
|
||
|
||
user = await self.update(db=self.db, db_obj=user, obj_in=update_data)
|
||
|
||
# 记录日志
|
||
logger.info(
|
||
"用户更新成功",
|
||
user_id=user.id,
|
||
username=user.username,
|
||
updated_fields=list(update_data.keys()),
|
||
updated_by=updated_by,
|
||
)
|
||
|
||
return user
|
||
|
||
async def update_password(
|
||
self,
|
||
*,
|
||
user_id: int,
|
||
old_password: str,
|
||
new_password: str,
|
||
) -> User:
|
||
"""更新密码"""
|
||
user = await self.get_by_id(user_id)
|
||
if not user:
|
||
raise NotFoundError("用户不存在")
|
||
|
||
# 验证旧密码
|
||
if not verify_password(old_password, user.hashed_password):
|
||
raise ConflictError("旧密码错误")
|
||
|
||
# 更新密码
|
||
update_data = {
|
||
"hashed_password": get_password_hash(new_password),
|
||
"password_changed_at": datetime.now(),
|
||
}
|
||
user = await self.update(db=self.db, db_obj=user, obj_in=update_data)
|
||
|
||
# 记录日志
|
||
logger.info(
|
||
"用户密码更新成功",
|
||
user_id=user.id,
|
||
username=user.username,
|
||
)
|
||
|
||
return user
|
||
|
||
async def update_last_login(self, user_id: int) -> None:
|
||
"""更新最后登录时间"""
|
||
user = await self.get_by_id(user_id)
|
||
if user:
|
||
await self.update(
|
||
db=self.db,
|
||
db_obj=user,
|
||
obj_in={"last_login_at": datetime.now()},
|
||
)
|
||
|
||
async def get_users_with_filter(
|
||
self,
|
||
*,
|
||
skip: int = 0,
|
||
limit: int = 100,
|
||
filter_params: UserFilter,
|
||
) -> tuple[List[User], int]:
|
||
"""根据筛选条件获取用户列表"""
|
||
# 构建筛选条件
|
||
filters = [User.is_deleted == False]
|
||
|
||
if filter_params.role:
|
||
filters.append(User.role == filter_params.role)
|
||
|
||
if filter_params.is_active is not None:
|
||
filters.append(User.is_active == filter_params.is_active)
|
||
|
||
if filter_params.keyword:
|
||
keyword = f"%{filter_params.keyword}%"
|
||
filters.append(
|
||
or_(
|
||
User.username.like(keyword),
|
||
User.email.like(keyword),
|
||
User.full_name.like(keyword),
|
||
)
|
||
)
|
||
|
||
if filter_params.team_id:
|
||
# 通过团队ID筛选用户
|
||
subquery = select(user_teams.c.user_id).where(
|
||
user_teams.c.team_id == filter_params.team_id
|
||
)
|
||
filters.append(User.id.in_(subquery))
|
||
|
||
# 构建查询
|
||
query = select(User).where(and_(*filters))
|
||
|
||
# 获取用户列表
|
||
users = await self.get_multi(self.db, skip=skip, limit=limit, query=query)
|
||
|
||
# 获取总数
|
||
count_query = select(func.count(User.id)).where(and_(*filters))
|
||
count_result = await self.db.execute(count_query)
|
||
total = count_result.scalar()
|
||
|
||
return users, total
|
||
|
||
async def add_user_to_team(
|
||
self,
|
||
*,
|
||
user_id: int,
|
||
team_id: int,
|
||
role: str = "member",
|
||
) -> None:
|
||
"""将用户添加到团队"""
|
||
# 检查用户是否存在
|
||
user = await self.get_by_id(user_id)
|
||
if not user:
|
||
raise NotFoundError("用户不存在")
|
||
|
||
# 检查团队是否存在
|
||
team_result = await self.db.execute(
|
||
select(Team).where(Team.id == team_id, Team.is_deleted == False)
|
||
)
|
||
team = team_result.scalar_one_or_none()
|
||
if not team:
|
||
raise NotFoundError("团队不存在")
|
||
|
||
# 检查是否已在团队中
|
||
existing = await self.db.execute(
|
||
select(user_teams).where(
|
||
user_teams.c.user_id == user_id,
|
||
user_teams.c.team_id == team_id,
|
||
)
|
||
)
|
||
if existing.first():
|
||
raise ConflictError("用户已在该团队中")
|
||
|
||
# 添加到团队
|
||
await self.db.execute(
|
||
user_teams.insert().values(
|
||
user_id=user_id,
|
||
team_id=team_id,
|
||
role=role,
|
||
joined_at=datetime.now(),
|
||
)
|
||
)
|
||
await self.db.commit()
|
||
|
||
# 记录日志
|
||
logger.info(
|
||
"用户加入团队",
|
||
user_id=user_id,
|
||
username=user.username,
|
||
team_id=team_id,
|
||
team_name=team.name,
|
||
role=role,
|
||
)
|
||
|
||
async def remove_user_from_team(
|
||
self,
|
||
*,
|
||
user_id: int,
|
||
team_id: int,
|
||
) -> None:
|
||
"""从团队中移除用户"""
|
||
# 删除关联
|
||
result = await self.db.execute(
|
||
user_teams.delete().where(
|
||
user_teams.c.user_id == user_id,
|
||
user_teams.c.team_id == team_id,
|
||
)
|
||
)
|
||
|
||
if result.rowcount == 0:
|
||
raise NotFoundError("用户不在该团队中")
|
||
|
||
await self.db.commit()
|
||
|
||
# 记录日志
|
||
logger.info(
|
||
"用户离开团队",
|
||
user_id=user_id,
|
||
team_id=team_id,
|
||
)
|
||
|
||
async def soft_delete(self, *, db_obj: User) -> User:
|
||
"""
|
||
软删除用户
|
||
|
||
Args:
|
||
db_obj: 用户对象
|
||
|
||
Returns:
|
||
软删除后的用户对象
|
||
"""
|
||
db_obj.is_deleted = True
|
||
db_obj.deleted_at = datetime.now()
|
||
self.db.add(db_obj)
|
||
await self.db.commit()
|
||
await self.db.refresh(db_obj)
|
||
|
||
logger.info(
|
||
"用户软删除成功",
|
||
user_id=db_obj.id,
|
||
username=db_obj.username,
|
||
)
|
||
|
||
return db_obj
|
||
|
||
async def authenticate(
|
||
self,
|
||
*,
|
||
username: str,
|
||
password: str,
|
||
) -> Optional[User]:
|
||
"""用户认证"""
|
||
# 尝试用户名登录
|
||
user = await self.get_by_username(username)
|
||
|
||
# 尝试邮箱登录
|
||
if not user:
|
||
user = await self.get_by_email(username)
|
||
|
||
# 尝试手机号登录
|
||
if not user:
|
||
user = await self.get_by_phone(username)
|
||
|
||
if not user:
|
||
return None
|
||
|
||
# 验证密码
|
||
if not verify_password(password, user.hashed_password):
|
||
return None
|
||
|
||
return user
|