修正项目结构

This commit is contained in:
lkhsss
2026-01-07 10:36:09 +08:00
parent 91f96a07a3
commit 96673d6fe0
41 changed files with 3366 additions and 1 deletions

44
src/constant.rs Normal file
View File

@ -0,0 +1,44 @@
use include_dir::include_dir;
pub const PROJECT_DIR: include_dir::Dir = include_dir!("./template/dist"); //将前端硬编码到项目
pub const DIST_DIR: &str = "./dist"; //前端释放目录
// pub const DIST_DIR: &str = "./template/dist"; //前端释放目录
pub const DATABASE: &str = "./data.sqlite"; //保存数据的目录
pub const PORT: u16 = 3000;
pub mod sql {
pub const CREATE_TABLE_TASKS: &str = "
CREATE TABLE IF NOT EXISTS TASKS (
ID INTEGER PRIMARY KEY AUTOINCREMENT,
NAME TEXT NOT NULL,
DESCRIPTION TEXT,
TIME_START INTEGER,
TIME_END INTEGER,
STATUS TEXT,
NODE JSON
);
";
pub const CREATE_TABLE_USERS: &str = "create table if not exists users
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT NOT NULL UNIQUE,
password TEXT NOT NULL
);";
pub const CREATE_TABLE_USER_INFO: &str = "create table if not exists users
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT NOT NULL UNIQUE,
password TEXT NOT NULL
);";//TODO
pub const CREATE_USER_ADMIN: &str = "INSERT OR REPLACE INTO users (username, password)
values ('admin', '{}');";
pub const CREATE_USER: &str = "INSERT INTO users (username, password)
values (?, ?);";
}
pub const VERSION: &str = env!("CARGO_PKG_VERSION");

74
src/data.rs Normal file
View File

@ -0,0 +1,74 @@
use std::fmt::Display;
use serde_json::json;
#[derive(Debug, Clone)]
pub struct Task {
id: u128,
name: String,
description: Option<String>,
time_start: Option<u64>,
time_end: Option<u64>,
status: TaskStatus,
node: Vec<u128>, //存储节点的id json
}
impl Task {
pub fn new(
name: String,
description: Option<String>,
time_start: Option<u64>,
time_end: Option<u64>,
status: TaskStatus,
) -> Self {
Self {
id: 0,
name,
description,
time_start,
time_end,
status,
node: vec![],
}
}
}
pub trait ToSql {
fn insert(&self) -> String;
}
impl ToSql for Task {
fn insert(&self) -> String {
format!(
"INSERT INTO TASKS (NAME, DESCRIPTION, TIME_START, TIME_END, STATUS, NODE) VALUES ('{}',{},{},{},'{}','{}');",
&self.name,
take_or_null(&self.description),
take_or_null(&self.time_start),
take_or_null(&self.time_end),
&self.status,
json!(&self.node)
)
}
}
#[derive(Debug, Clone)]
pub enum TaskStatus {
TODO,
Finish,
}
impl Display for TaskStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let status = match self {
TaskStatus::TODO => "TODO",
TaskStatus::Finish => "Finish",
};
write!(f, "{}", status)
}
}
fn take_or_null<T: Display>(s: &Option<T>) -> String {
match s {
Some(x) => format!("'{}'", x.to_string()),
None => "NULL".to_string(),
}
}

94
src/database.rs Normal file
View File

@ -0,0 +1,94 @@
use std::path::Path;
use axum_login::{AuthnBackend, UserId};
use password_auth::verify_password;
use sqlx::SqlitePool;
use sqlx::{
Executor, Pool, Sqlite,
sqlite::{SqlitePoolOptions, SqliteQueryResult},
};
use tokio::fs::create_dir_all;
use tokio::task;
use tracing_subscriber::registry::Data;
use crate::constant::sql::{self, CREATE_USER};
use crate::data::{Task, ToSql};
use crate::error::Error;
use crate::users::{Credentials, User};
#[derive(Clone)]
pub struct Database {
pub pool: SqlitePool,
}
impl Database {
pub async fn new(url: &str) -> Self {
if !Path::new(url).is_file() {
std::fs::File::create_new(url).expect("无法创建数据库文件!");
};
let pool = SqlitePoolOptions::new().connect(url).await.unwrap();
Self { pool }
}
pub async fn init(&self) -> Result<u64, sqlx::Error> {
// debug!("创建表");
self.pool.execute(sql::CREATE_TABLE_TASKS).await?; //任务表
self.pool.execute(sql::CREATE_TABLE_USERS).await?; //用户密码表
// self.pool.execute(sql::CREATE_TABLE_UNCLASSIFIED).await?; //创建未分类表
Ok(0)
}
pub async fn new_task(&self, task: Task) -> Result<SqliteQueryResult, sqlx::Error> {
self.pool.execute(task.insert().as_str()).await
}
pub async fn add_new_user<S: ToString>(
&self,
username: S,
password: S,
) -> Result<SqliteQueryResult, sqlx::Error> {
sqlx::query(CREATE_USER)
.bind(username.to_string())
.bind(password_auth::generate_hash(password.to_string()))
.execute(&self.pool)
.await
}
}
impl AuthnBackend for Database {
type User = User;
type Credentials = Credentials;
type Error = Error;
async fn authenticate(&self, creds: Credentials) -> Result<Option<User>, Error> {
let user: Option<User> = sqlx::query_as("select * from users where username = ? ")
.bind(creds.username)
.fetch_optional(&self.pool)
.await?;
// Verifying the password is blocking and potentially slow, so we'll do so via
// `spawn_blocking`.
task::spawn_blocking(|| {
// We're using password-based authentication--this works by comparing our form
// input with an argon2 password hash.
Ok(user.filter(|user| verify_password(creds.password, &user.password).is_ok()))
})
.await?
}
async fn get_user(&self, user_id: &UserId<Self>) -> Result<Option<Self::User>, Self::Error> {
let user = sqlx::query_as("select * from users where id = ?")
.bind(user_id)
.fetch_optional(&self.pool)
.await?;
Ok(user)
}
}
// We use a type alias for convenience.
//
// Note that we've supplied our concrete backend here.
pub type AuthSession = axum_login::AuthSession<Database>;

10
src/error.rs Normal file
View File

@ -0,0 +1,10 @@
use tokio::task;
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error(transparent)]
Sqlx(#[from] sqlx::Error),
#[error(transparent)]
TaskJoin(#[from] task::JoinError),
}

111
src/handlers.rs Normal file
View File

@ -0,0 +1,111 @@
use axum::Form;
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use tracing::{error, info, warn};
use crate::constant::VERSION;
use crate::data::Task;
use crate::database::{AuthSession, Database};
use crate::response::{self, Common};
use crate::users::Credentials;
// 创建新任务
pub async fn new_task(State(db): State<Database>) -> impl IntoResponse {
let t = Task::new(
"测试".to_string(),
None,
None,
None,
crate::data::TaskStatus::TODO,
);
(
StatusCode::OK,
format!("{}", db.new_task(t).await.unwrap().rows_affected()),
)
.into_response()
}
pub async fn reg(
mut auth_session: AuthSession,
Form(creds): Form<Credentials>,
) -> impl IntoResponse {
info!("注册开始");
match auth_session
.backend
.add_new_user(&creds.username, &creds.password)
.await
{
Ok(_) => {
info!("用户:[{}]注册成功", creds.username);
// 自动登录
let user = auth_session
.authenticate(creds.clone())
.await
.unwrap()
.unwrap();
auth_session.login(&user).await.unwrap();
info!("用户:[{}]已自动登录", creds.username);
Common::success("注册成功".to_string())
}
Err(e) => match e {
sqlx::Error::Database(_) => Common::failure("该账号已注册".to_string()),
_ => Common::failure(e.to_string()),
},
}
}
pub async fn login(
mut auth_session: AuthSession,
Form(creds): Form<Credentials>,
) -> impl IntoResponse {
let user = match auth_session.authenticate(creds.clone()).await {
Ok(Some(user)) => user,
Ok(None) => {
warn!("未找到用户: {}", creds.username);
return Common::failure(format!("未找到用户: {}", creds.username)).into_response();
}
Err(e) => {
error!("查询数据库失败:{:?}", e);
return Common::failure(format!("查询数据库失败:{:?}", e)).into_response();
}
};
match auth_session.login(&user).await {
Ok(_) => {
info!("用户:[{}]登陆成功", &user.username);
Common::success(&user.username).into_response()
}
Err(e) => {
error!("用户:[{}]登陆失败", &user.username);
Common::failure(format!("登陆失败:{:?}", e)).into_response()
}
}
}
pub async fn logout(mut auth_session: AuthSession, username: String) -> impl IntoResponse {
match auth_session.logout().await {
Ok(_) => {
info!("用户:[{}]退出登陆", username);
Common::success("退出登陆成功".to_string())
}
Err(e) => {
error!("用户:[{}]退出登陆失败:{:?}", username, e);
Common::success(format!("退出登陆失败:{:?}", e))
}
}
}
pub async fn check_login(auth_session: AuthSession) -> impl IntoResponse {
match auth_session.user {
Some(u) => response::Common::new(response::Status::Success, u.username).into_response(),
None => {
response::Common::new(response::Status::Failure, "未登录".to_string()).into_response()
}
}
}
pub async fn get_username(auth_session: AuthSession) -> impl IntoResponse {}
pub async fn version() -> impl IntoResponse {
VERSION
}

19
src/logger.rs Normal file
View File

@ -0,0 +1,19 @@
use tracing::level_filters::LevelFilter;
use tracing::{Level, level_filters};
use tracing_subscriber::EnvFilter;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
pub fn init() {
// 注册一个全局的日志记录器
tracing_subscriber::registry()
.with(
tracing_subscriber::fmt::layer()
.with_file(true) //打印文件名
.with_line_number(true) //打印行号
.with_thread_ids(true) //打印线程ID
.with_thread_names(true) //打印线程名称
.with_target(false), //不打印target
)
.init();
}

133
src/main.rs Normal file
View File

@ -0,0 +1,133 @@
use std::path::PathBuf;
use axum::{Router, routing::get};
use axum_login::tower_sessions::ExpiredDeletion;
use axum_login::{AuthManagerLayerBuilder, login_required};
use tower_http::services::{ServeDir, ServeFile};
use tower_http::trace::TraceLayer;
use tower_sessions::{
Expiry, SessionManagerLayer,
cookie::{Key, time::Duration},
};
use tower_sessions_sqlx_store::SqliteStore;
use tracing::{error, info};
mod constant;
mod data;
mod database;
mod error;
mod handlers;
mod logger;
mod response;
mod routers;
mod users;
use constant::DIST_DIR;
use constant::PORT;
use database::Database;
use crate::constant::DATABASE;
use crate::constant::PROJECT_DIR;
use crate::routers::api_routes_init;
#[tokio::main]
async fn main() {
// 日志系统
tracing_subscriber::fmt()
.with_max_level(tracing::Level::INFO)
// .with_env_filter(EnvFilter::from_default_env())
.init();
//新建数据库
let db = Database::new(DATABASE).await;
//新建表
db.init().await.unwrap();
//前端硬编码
// FIXME
//TODO 释放目录换成随机temp目录
PROJECT_DIR
.extract(PathBuf::from(DIST_DIR))
.expect("无法提取项目目录");
let session_store = SqliteStore::new(db.pool.clone());
session_store.migrate().await.unwrap(); //自动创建tower_sessions表的关键没有的话不会自动创建tower_sessions表
let deletion_task = tokio::task::spawn(
session_store
.clone()
.continuously_delete_expired(tokio::time::Duration::from_mins(5)), //过期时间
);
// Generate a cryptographic key to sign the session cookie.
let key = Key::generate();
let session_layer = SessionManagerLayer::new(session_store)
.with_secure(false)
.with_expiry(Expiry::OnInactivity(Duration::hours(1)))
.with_signed(key);
let auth_layer = AuthManagerLayerBuilder::new(db.clone(), session_layer).build();
let api_routes: Router<Database> = api_routes_init();
let app = Router::new()
.route_service(
"/",
ServeFile::new(format!("{}{}", DIST_DIR, "/index.html")),
)
.route_service(
"/{*path}",
ServeDir::new(DIST_DIR)
.not_found_service(ServeFile::new(format!("{}{}", DIST_DIR, "/index.html"))),
)
// .route("/new", get(new_task))
.nest("/api", api_routes) //api的路径
.with_state(db)
.layer(auth_layer)
.layer(TraceLayer::new_for_http());
// .route("/test", get())// 测试接口
//.layer(middleware::from_fn(logging_middleware));
let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{PORT}"))
.await
.unwrap();
info!("成功启用服务http://0.0.0.0:{PORT}");
match webbrowser::open(&format!("http://127.0.0.1:{PORT}")) {
Ok(_) => info!(
"成功打开浏览器并访问: {}",
&format!("http://127.0.0.1:{PORT}")
),
Err(e) => error!("无法打开浏览器: {}", e),
};
axum::serve(listener, app)
.with_graceful_shutdown(shutdown_signal(deletion_task.abort_handle()))
.await
.unwrap();
}
//优雅关机
async fn shutdown_signal(deletion_task_abort_handle: tokio::task::AbortHandle) {
let ctrl_c = async {
tokio::signal::ctrl_c()
.await
.expect("failed to install Ctrl+C handler");
};
#[cfg(unix)]
let terminate = async {
signal::unix::signal(signal::unix::SignalKind::terminate())
.expect("failed to install signal handler")
.recv()
.await;
};
#[cfg(not(unix))]
let terminate = std::future::pending::<()>();
tokio::select! {
_ = ctrl_c => { deletion_task_abort_handle.abort() },
_ = terminate => { deletion_task_abort_handle.abort() },
}
}

80
src/response.rs Normal file
View File

@ -0,0 +1,80 @@
use std::fmt::Display;
use axum::{Json, http::StatusCode, response::IntoResponse};
use serde::Serialize;
use serde_json::json;
use crate::users::User;
pub struct CheckLogin(pub Option<User>);
impl Display for CheckLogin {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match &self.0 {
Some(u) => write!(
f,
"{}",
Common::new(Status::Success, format!("\"{}\"", u.username.clone()))
),
None => write!(f, "{}", Common::new(Status::Failure, "\"未登录\"")),
}
}
}
/// 通用的返回值
pub struct Common<D: Serialize> {
status: Status,
data: D,
}
impl<D: Serialize> Common<D> {
pub fn new(status: Status, data: D) -> Self {
Self { status, data }
}
pub fn success(data: D) -> Self {
Self {
status: Status::Success,
data,
}
}
pub fn failure(data: D) -> Self {
Self {
status: Status::Failure,
data,
}
}
}
pub enum Status {
Success,
Failure,
}
impl Display for Status {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let s = match self {
Status::Success => "Success",
Status::Failure => "Failure",
};
write!(f, "{}", s)
}
}
impl<D: Serialize> Display for Common<D> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
json!({"status":self.status.to_string(),"data":self.data}),
)
}
}
impl<D: Serialize> IntoResponse for Common<D> {
fn into_response(self) -> axum::response::Response {
(
StatusCode::OK,
Json::from(json!({"status":self.status.to_string(),"data":self.data})),
)
.into_response()
}
}

18
src/routers.rs Normal file
View File

@ -0,0 +1,18 @@
use axum::{
Router,
routing::{get, post},
};
use crate::{
database::Database,
handlers::{check_login, login, logout, reg, version},
};
// api路径的集合
pub fn api_routes_init() -> Router<Database> {
Router::new()
.route("/login", post(login).get(check_login))
.route("/reg", post(reg))
.route("/logout", post(logout))
.route("/version", get(version))
}

55
src/users.rs Normal file
View File

@ -0,0 +1,55 @@
use std::fmt::Display;
use axum_login::AuthUser;
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
#[derive(Clone, Serialize, Deserialize, FromRow)]
pub struct User {
id: i64,
pub username: String,
pub password: String,
}
// Here we've implemented `Debug` manually to avoid accidentally logging the
// password hash.
impl std::fmt::Debug for User {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("User")
.field("id", &self.id)
.field("username", &self.username)
.field("password", &"[redacted]")
.finish()
}
}
impl AuthUser for User {
type Id = i64;
fn id(&self) -> Self::Id {
self.id
}
fn session_auth_hash(&self) -> &[u8] {
self.password.as_bytes() // We use the password hash as the auth
// hash--what this means
// is when the user changes their password the
// auth session becomes invalid.
}
}
// This allows us to extract the authentication fields from forms. We use this
// to authenticate requests with the backend.
/// 用于检验用户
/// 用户名和密码
#[derive(Debug, Clone, Deserialize)]
pub struct Credentials {
pub username: String,
pub password: String,
}
#[derive(Debug, Clone, Deserialize)]
pub struct Credentials_logout {
pub username: String,
}