from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import date, datetime, timedelta
from sqlalchemy import and_, func, select
from sqlalchemy.orm import Session, selectinload
from app.models import Tag, WorkEntry
@dataclass(frozen=True)
class DateRange:
    """Immutable pair of datetimes bounding a reporting window.

    Queries in this module treat the window as half-open: entries match when
    ``start <= created_at < end``.
    """

    # Inclusive lower bound of the window.
    start: datetime
    # Exclusive upper bound of the window.
    end: datetime
def normalize_tag_names(tag_names: Sequence[str] | None) -> list[str]:
    """Strip, drop blanks and de-duplicate tag names, keeping first-seen order.

    Returns an empty list when *tag_names* is ``None`` or empty.
    """
    if not tag_names:
        return []
    stripped = (raw.strip() for raw in tag_names)
    # dict keeps insertion order, so fromkeys() de-duplicates while preserving
    # the position of each name's first occurrence.
    return list(dict.fromkeys(name for name in stripped if name))
def _start_of_week(target: date) -> date:
    """Return the Monday of the week containing *target*.

    ``date.weekday()`` is 0 for Monday, so subtracting it lands on Monday.
    """
    days_since_monday = target.weekday()
    return target - timedelta(days=days_since_monday)
def compute_date_range(range_type: str, target_date: date | None = None) -> DateRange:
    """Build the half-open window for a daily or weekly report.

    Args:
        range_type: ``"daily"`` for the single day of *target_date*, or
            ``"weekly"`` for the Monday-based week that contains it.
        target_date: Anchor day; defaults to ``date.today()``.

    Returns:
        A ``DateRange`` starting at midnight of the first day and ending at
        midnight 1 day (daily) or 7 days (weekly) later.

    Raises:
        ValueError: if *range_type* is neither ``"daily"`` nor ``"weekly"``.
    """
    anchor = target_date or date.today()
    if range_type not in ("daily", "weekly"):
        raise ValueError("range must be 'daily' or 'weekly'")
    is_daily = range_type == "daily"
    first_day = anchor if is_daily else _start_of_week(anchor)
    window_start = datetime.combine(first_day, datetime.min.time())
    span = timedelta(days=1 if is_daily else 7)
    return DateRange(start=window_start, end=window_start + span)
class TagRepository:
    """Data-access helpers for ``Tag`` rows bound to one SQLAlchemy session."""

    def __init__(self, session: Session):
        self.session = session

    def list_tags(self) -> list[Tag]:
        """Return every tag ordered alphabetically by name."""
        statement = select(Tag).order_by(Tag.name)
        # ScalarResult.all() is typed as a Sequence; materialize it so the
        # return value matches the ``list[Tag]`` annotation.
        return list(self.session.execute(statement).scalars().all())

    def create_tag(self, name: str, description: str | None = None) -> Tag:
        """Create, commit and return a new tag.

        Args:
            name: Tag name; surrounding whitespace is stripped.
            description: Optional free-text description, stored as given.

        Returns:
            The committed and refreshed ``Tag``.

        Raises:
            ValueError: if *name* is blank after stripping, or if
                ``session.get(Tag, name)`` already finds a tag.
        """
        cleaned = name.strip()
        if not cleaned:
            raise ValueError("Tag name不能为空")
        # session.get() looks up by primary key, so this duplicate check
        # assumes the tag name is Tag's primary key.
        # NOTE(review): confirm against app.models.
        existing = self.session.get(Tag, cleaned)
        if existing:
            raise ValueError(f"标签 {cleaned} 已存在")
        tag = Tag(name=cleaned, description=description)
        self.session.add(tag)
        try:
            self.session.commit()
        except Exception:
            # A failed flush/commit leaves the session's transaction unusable
            # until it is rolled back; restore it before propagating the error.
            self.session.rollback()
            raise
        self.session.refresh(tag)
        return tag

    def fetch_by_names(self, tag_names: Sequence[str]) -> dict[str, Tag]:
        """Return the existing tags among *tag_names*, keyed by tag name.

        Names are normalized first (stripped, blanks dropped, de-duplicated).
        Names with no matching row are simply absent from the result.
        """
        normalized = normalize_tag_names(tag_names)
        if not normalized:
            return {}
        statement = select(Tag).where(Tag.name.in_(normalized))
        rows = self.session.execute(statement).scalars().all()
        return {tag.name: tag for tag in rows}
class EntryRepository:
    """Data-access helpers for ``WorkEntry`` rows bound to one SQLAlchemy session."""

    # Escape character for LIKE patterns. "/" avoids the backslash-quoting
    # differences between database dialects.
    _LIKE_ESCAPE = "/"

    def __init__(self, session: Session):
        self.session = session

    @staticmethod
    def _day_start(day: date) -> datetime:
        """Return the naive datetime at midnight at the start of *day*."""
        return datetime.combine(day, datetime.min.time())

    @classmethod
    def _escape_like(cls, text: str) -> str:
        """Escape the LIKE wildcards (``%``, ``_``) and the escape char in *text*."""
        esc = cls._LIKE_ESCAPE
        # Escape the escape character first so later replacements stay unambiguous.
        return (
            text.replace(esc, esc + esc)
            .replace("%", esc + "%")
            .replace("_", esc + "_")
        )

    def _apply_date_bounds(self, statement, start_date: date | None, end_date: date | None):
        """Restrict *statement* to entries created on whole days start_date..end_date."""
        if start_date:
            statement = statement.where(WorkEntry.created_at >= self._day_start(start_date))
        if end_date:
            # Exclusive bound at the next midnight, so the whole end day is included.
            statement = statement.where(
                WorkEntry.created_at < self._day_start(end_date + timedelta(days=1))
            )
        return statement

    def _apply_tag_filter(self, statement, tag_names: list[str]):
        """Keep only entries that carry *every* tag in *tag_names*.

        *tag_names* must already be de-duplicated (see ``normalize_tag_names``).
        The HAVING clause compares a distinct-name count with ``len(tag_names)``,
        so duplicates would make the filter match nothing.
        """
        if not tag_names:
            return statement
        return (
            statement.join(WorkEntry.tags)
            .where(Tag.name.in_(tag_names))
            .group_by(WorkEntry.id)
            .having(func.count(func.distinct(Tag.name)) == len(tag_names))
        )

    def create_entry(self, description: str, tags: Sequence[Tag]) -> WorkEntry:
        """Create, commit and return a work entry with the given tags.

        Raises:
            ValueError: if *description* is blank after stripping.
        """
        text = description.strip()
        if not text:
            raise ValueError("description 不能为空")
        entry = WorkEntry(description=text)
        entry.tags = list(tags)
        self.session.add(entry)
        try:
            self.session.commit()
        except Exception:
            # A failed flush/commit leaves the session's transaction unusable
            # until it is rolled back; restore it before propagating the error.
            self.session.rollback()
            raise
        self.session.refresh(entry)
        return entry

    def list_entries(
        self,
        *,
        range_type: str,
        target_date: date | None = None,
        tag_names: Sequence[str] | None = None,
    ) -> list[WorkEntry]:
        """List entries in the daily/weekly window around *target_date*, newest first.

        When *tag_names* is given, only entries carrying all of those tags are returned.

        Raises:
            ValueError: if *range_type* is not ``"daily"`` or ``"weekly"``.
        """
        normalized_tags = normalize_tag_names(tag_names)
        date_range = compute_date_range(range_type, target_date)
        statement = (
            select(WorkEntry)
            .options(selectinload(WorkEntry.tags))
            .where(
                and_(
                    WorkEntry.created_at >= date_range.start,
                    WorkEntry.created_at < date_range.end,
                )
            )
            .order_by(WorkEntry.created_at.desc())
        )
        statement = self._apply_tag_filter(statement, normalized_tags)
        return list(self.session.execute(statement).scalars().all())

    def search_entries(
        self,
        *,
        query: str,
        start_date: date | None = None,
        end_date: date | None = None,
    ) -> list[WorkEntry]:
        """Case-insensitive substring search over entry descriptions, newest first.

        *query* is stripped and matched literally: ``%`` and ``_`` are escaped
        rather than treated as LIKE wildcards. A blank query returns ``[]``.
        Optional *start_date*/*end_date* bound the search to whole days (inclusive).
        """
        text = query.strip()
        if not text:
            return []
        pattern = f"%{self._escape_like(text)}%"
        statement = (
            select(WorkEntry)
            .options(selectinload(WorkEntry.tags))
            .where(WorkEntry.description.ilike(pattern, escape=self._LIKE_ESCAPE))
        )
        statement = self._apply_date_bounds(statement, start_date, end_date)
        statement = statement.order_by(WorkEntry.created_at.desc())
        return list(self.session.execute(statement).scalars().all())

    def export_entries(
        self,
        *,
        start_date: date | None = None,
        end_date: date | None = None,
    ) -> list[WorkEntry]:
        """Return all entries (optionally within whole days start..end), newest first."""
        statement = (
            select(WorkEntry)
            .options(selectinload(WorkEntry.tags))
            .order_by(WorkEntry.created_at.desc())
        )
        statement = self._apply_date_bounds(statement, start_date, end_date)
        return list(self.session.execute(statement).scalars().all())

    def has_entry_between(self, start_dt: datetime, end_dt: datetime) -> bool:
        """Return True if any entry has ``start_dt <= created_at < end_dt``."""
        # Probe for a single row instead of counting every match.
        statement = (
            select(WorkEntry.id)
            .where(and_(WorkEntry.created_at >= start_dt, WorkEntry.created_at < end_dt))
            .limit(1)
        )
        return self.session.execute(statement).first() is not None

    def summarize_last_days(self, days: int) -> list[dict[str, object]]:
        """Count entries per tag over the last *days* calendar days (including today).

        Returns:
            ``{"tag": name, "count": n}`` dicts ordered by count descending.
            If there are untagged entries, a final ``{"tag": None, "count": n}``
            entry holds their count.

        Raises:
            ValueError: if *days* is not positive.
        """
        if days <= 0:
            raise ValueError("days 必须大于0")
        # NOTE(review): window is based on the local date from date.today();
        # confirm created_at is stored in the same (naive) time base.
        start_dt = self._day_start(date.today()) - timedelta(days=days - 1)
        statement = (
            select(Tag.name, func.count(WorkEntry.id))
            .join(WorkEntry.tags, isouter=True)
            .where(WorkEntry.created_at >= start_dt)
            .group_by(Tag.name)
            .order_by(func.count(WorkEntry.id).desc())
        )
        result: list[dict[str, object]] = []
        untagged = 0
        for name, count in self.session.execute(statement).all():
            if name is None:
                # The outer join gives untagged entries a NULL tag name; this
                # group is exactly the untagged count, so no second query is needed.
                untagged = int(count)
                continue
            result.append({"tag": name, "count": int(count)})
        if untagged:
            result.append({"tag": None, "count": untagged})
        return result