Files
tv/webdav_simulator.py

363 lines
14 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
from io import BytesIO
from datetime import datetime
from wsgidav.dav_provider import DAVProvider,_DAVResource
from wsgidav.wsgidav_app import WsgiDAVApp
from wsgidav.util import init_logging
from wsgidav.mw.base_mw import BaseMiddleware
from wsgidav.dir_browser import WsgiDavDirBrowser
from wsgidav.dav_error import HTTP_MEDIATYPE_NOT_SUPPORTED, HTTP_OK, DAVError
from cheroot import wsgi
try:
    from wsgidav.util import send_redirect
except ImportError:
    # Compatibility shim used when wsgidav.util does not export send_redirect.
    from http import HTTPStatus

    def send_redirect(environ, start_response, new_url, status="307"):
        """Start a WSGI redirect response pointing at *new_url*.

        Args:
            environ: WSGI environment (unused, kept for signature parity).
            start_response: WSGI ``start_response`` callable.
            new_url: Value of the ``Location`` header.
            status: HTTP status code as a string or int. Anything that is
                not a known status code falls back to 307.

        Returns:
            An empty body iterable.
        """
        try:
            status_code = HTTPStatus(int(status))
        except (TypeError, ValueError):
            # Unknown or malformed code: keep the historical 307 behaviour.
            status_code = HTTPStatus.TEMPORARY_REDIRECT
        start_response(f"{status_code.value} {status_code.phrase}", [("Location", new_url)])
        return []
#import yaml
import sys
import os
import requests
import json
import traceback
import time
import re
from functools import lru_cache
def convert_path(webdav_path: str) -> str:
    """Remove the first ``/dav`` occurrence from *webdav_path* and prepend ``/``.

    Only the first occurrence is removed, wherever it appears in the string.
    """
    # NOTE(review): for an input such as "/dav/a" the result is "//a"
    # (leading double slash) -- confirm that this is what callers expect.
    without_mount = webdav_path.replace('/dav', '', 1)
    return "/" + without_mount
class Node:
    """One entry (file or directory) of the in-memory virtual file system."""

    def __init__(self, name, is_dir=False, size=0):
        """Create a node.

        Args:
            name: Name of the entry (a single path component, or '/' for the root).
            is_dir: True for a directory, False for a file.
            size: Size reported for the entry.
        """
        # Identity and kind of the entry.
        self.name, self.is_dir = name, is_dir
        # Size reported for the entry.
        self.size = size
        # Child nodes keyed by their name.
        self.children = dict()
        # Content placeholder; starts out empty.
        self.content = bytes()
        # Timestamp taken when the node object is created.
        self.last_modified = datetime.now()
def parse_paths(text_lines):
    """Build the virtual tree from an iterable of path-listing lines.

    Each non-blank line holds a path, optionally followed by a TAB and an
    integer size. A path ending in ``/`` declares a directory; any other
    path declares a file, and its parent directories are created on demand.
    An entry that already exists is left untouched.

    Args:
        text_lines: Iterable of strings, one listing line each.

    Returns:
        The root ``Node`` (named ``'/'``) of the resulting tree.
    """
    root = Node('/', is_dir=True)
    for raw_line in text_lines:
        line = raw_line.strip()
        if not line:
            continue
        fields = line.split('\t')
        path = fields[0]
        size = 0
        if len(fields) > 1:
            try:
                size = int(fields[1])
            except ValueError:
                # A malformed size is reported and the entry keeps size 0.
                print("error line:" + path)
        is_dir = path.endswith('/')
        # Empty components (leading, trailing or doubled slashes) are ignored.
        parts = [part for part in path.strip('/').split('/') if part]
        if not parts:
            # The line names the root itself; there is nothing to create.
            continue
        # For a file, the last component is the file name, not a directory.
        dir_parts = parts if is_dir else parts[:-1]
        current = root
        for part in dir_parts:
            child = current.children.get(part)
            if child is None:
                child = current.children[part] = Node(part, is_dir=True)
            current = child
        if not is_dir:
            file_name = parts[-1]
            if file_name not in current.children:
                current.children[file_name] = Node(file_name, is_dir=False, size=size)
    return root
class VirtualFSProvider(DAVProvider):
    """WebDAV provider that serves a virtual tree and redirects file GETs to AList.

    Resources are resolved against the in-memory ``Node`` tree. For a ``GET``
    on a file path the configured AList servers are asked, through
    ``/api/fs/get``, for the file's ``raw_url`` and the client is redirected
    to it.

    ``alist_config`` is a mapping with the keys ``api_url``, ``token``,
    ``prefix`` and ``config``. A non-empty ``config`` lists several servers
    as ``url,token,prefix`` entries separated by ``#`` and takes precedence
    over the single-server keys.
    """

    # Upper bound, in seconds, for each AList API request so that an
    # unresponsive server cannot block a request thread indefinitely.
    REQUEST_TIMEOUT = 30

    _USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/134.0.0.0 Safari/537.36"

    class _RedirectNotFound(Exception):
        """Internal signal: no server produced a redirect URL for the path."""

    def __init__(self, root_node, alist_config, sort_reverse=False):
        """Store the tree and build the list of AList servers to query.

        Args:
            root_node: Root ``Node`` of the virtual tree.
            alist_config: Mapping described in the class docstring. Missing
                or ``None`` values are treated as empty strings.
            sort_reverse: Passed on to every ``VirtualResource`` created.
        """
        super().__init__()
        self.root_node = root_node
        self.alist_config = alist_config
        self.sort_reverse = sort_reverse
        self.serverlist = list()
        self.sharedir_re = re.compile(r'我的.*分享')
        # The cache lives on the instance (not on the class-level function),
        # and only successful lookups are stored: failures are signalled by
        # an exception, which lru_cache does not memoise.
        # NOTE(review): cached URLs are kept for the process lifetime; if the
        # AList backend hands out expiring links this may need a TTL -- confirm.
        self._cached_redirect_url = lru_cache(maxsize=1024, typed=True)(
            self._resolve_redirect_url
        )

        config_text = alist_config.get("config") or ""
        if config_text == "":
            api_url = alist_config.get("api_url") or ""
            if api_url != "":
                self._add_server(
                    api_url,
                    alist_config.get("token") or "",
                    alist_config.get("prefix") or "",
                )
        else:
            # Multi-server form: "url,token,prefix#url,token,prefix#...".
            for entry in config_text.split("#"):
                fields = entry.split(",")
                fields[0] = fields[0].strip()
                if fields[0] == "":
                    continue
                # Token and prefix are optional; pad the entry to three fields.
                fields.extend([""] * (3 - len(fields)))
                self._add_server(fields[0], fields[1], fields[2])

    def _add_server(self, api_url, token, prefix):
        """Append one server entry, dropping trailing slashes from its URL."""
        self.serverlist.append(
            {"api_url": api_url.rstrip("/"), "token": token, "prefix": prefix}
        )

    def get_resource_inst(self, path, environ):
        """Return the ``VirtualResource`` for *path*, or None if it is not in the tree."""
        if path == '/':
            return VirtualResource('/', self.root_node, environ, self.alist_config, sort_reverse=self.sort_reverse)
        current = self.root_node
        for part in path.strip('/').split('/'):
            if not part:
                continue
            if part not in current.children:
                return None
            current = current.children[part]
        return VirtualResource(path, current, environ, self.alist_config, sort_reverse=self.sort_reverse)

    def get_redirect_url(self, path):
        """Return the download URL for *path*, or ``""`` when none was found.

        Successful lookups are cached per path. A failed lookup is not
        cached, so a transient AList error does not disable the redirect for
        that path until the process is restarted.
        """
        try:
            return self._cached_redirect_url(path)
        except self._RedirectNotFound:
            return ""

    def _resolve_redirect_url(self, path):
        """Query every server in order; return the first raw URL or raise."""
        for server in self.serverlist:
            print("server:", server)
            self._refresh_rootdirmap(server)
            candidate = self._candidate_path(server, path)
            if candidate is not None:
                raw_url = self._query_raw_url(server, candidate)
                if raw_url:
                    return raw_url
            # Pause before moving on to the next server.
            time.sleep(1)
        raise self._RedirectNotFound(path)

    def _request_headers(self, server):
        """Build the request headers, adding the token when one is configured."""
        headers = {"User-Agent": self._USER_AGENT}
        if server["token"] != "":
            headers["Authorization"] = server["token"]
        return headers

    def _refresh_rootdirmap(self, server):
        """Update ``server['rootdirmap']`` from the server's root listing.

        Every root entry whose name matches ``sharedir_re`` is recorded as
        ``matched text -> full entry name``. Any failure is logged and the
        previously known mapping is kept, so one unreachable server does not
        abort the whole lookup.
        """
        # NOTE(review): the listing always targets "/" and does not take the
        # server's prefix into account -- confirm that this is intended.
        rootdirmap = server.setdefault("rootdirmap", dict())
        headers = self._request_headers(server)
        params = {"path": "/"}
        try:
            resp = requests.get(
                server["api_url"] + "/api/fs/list",
                headers=headers,
                params=params,
                timeout=self.REQUEST_TIMEOUT,
            )
            resp.raise_for_status()
            body = resp.content.decode('utf-8')
            print(f"detect rootdir alist_api:{server['api_url']}, headers:{headers}, params:{params}, respcode:{resp.status_code}, respbody:{body[:4096]}")
            # A missing/None content list is treated as an empty listing.
            entries = json.loads(body)['data']['content'] or []
            for entry in entries:
                dirname = entry["name"]
                match = self.sharedir_re.findall(dirname)
                if match:
                    rootdirmap[match[0]] = dirname
        except Exception:
            # Best effort: detection problems must not break the lookup.
            traceback.print_exc()

    def _candidate_path(self, server, path):
        """Return the server-side path to query for *path*, or None.

        When the first path segment matches ``sharedir_re`` it is replaced by
        the full directory name recorded in ``rootdirmap``; if no such name
        is known, None is returned and the server is not queried. Paths
        whose first segment does not match are used unchanged.
        """
        pathsegs = path.split('/')
        first_segment = pathsegs[1] if len(pathsegs) > 1 else ""
        match = self.sharedir_re.findall(first_segment)
        if not match:
            return path
        real_name = server["rootdirmap"].get(match[0])
        if real_name is None:
            return None
        remainder = [segment for segment in pathsegs[2:] if segment]
        # Joined with "/" explicitly: os.path.join would use "\\" on Windows.
        return "/" + "/".join([real_name, *remainder])

    @staticmethod
    def _join_alist_path(prefix, path):
        """Join *prefix* and *path* into one absolute, slash-separated path.

        os.path.join is deliberately avoided: *path* is absolute, and
        os.path.join discards every component before an absolute one, which
        silently dropped the configured prefix.
        """
        pieces = [piece for piece in (prefix or "").split("/") if piece]
        pieces.extend(piece for piece in path.split("/") if piece)
        return "/" + "/".join(pieces)

    def _query_raw_url(self, server, tmppath):
        """Ask one server for the ``raw_url`` of *tmppath*; return ``""`` on failure."""
        try:
            headers = self._request_headers(server)
            params = {"path": self._join_alist_path(server['prefix'], tmppath)}
            resp = requests.get(
                server["api_url"] + "/api/fs/get",
                headers=headers,
                params=params,
                timeout=self.REQUEST_TIMEOUT,
            )
            resp.raise_for_status()
            body = resp.content.decode('utf-8')
            print(f"query alist_api:{server['api_url']}, headers:{headers}, params:{params}, respcode:{resp.status_code}, respbody:{body[:4096]}")
            raw_url = json.loads(body)['data']['raw_url']
            if raw_url:
                return raw_url
        except Exception:
            # Best effort: log the failure, pause, and let the caller go on.
            traceback.print_exc()
            time.sleep(1)
        return ""

    def custom_request_handler(self, environ, start_response, default_handler):
        """Redirect eligible file GETs; delegate everything else.

        A ``GET`` whose path does not end in ``/`` and has at least two
        segments is answered with a redirect when a download URL is found.
        All other requests, and GETs without a URL, go to
        ``default_handler(environ, start_response)``.
        """
        requestmethod = environ["REQUEST_METHOD"]
        path = environ["PATH_INFO"]
        if requestmethod == "GET" and not path.endswith('/') and len(path.split('/')) > 2:
            redirect_url = self.get_redirect_url(path)
            if redirect_url != "":
                print(f"redirect_urL:{redirect_url}")
                send_redirect(environ, start_response, redirect_url)
                return [b"Redirecting..."]
        return default_handler(environ, start_response)
class VirtualResource(_DAVResource):
    """WebDAV resource backed by one ``Node`` of the virtual tree."""

    def __init__(self, path, node, environ, alist_config, sort_reverse=False):
        """Wrap *node* as the resource found at *path*.

        Args:
            path: Resource path; a trailing ``/`` is appended for directories.
            node: The ``Node`` this resource represents.
            environ: WSGI environment of the current request.
            alist_config: AList configuration mapping, handed on to children.
            sort_reverse: When true, member names are listed in descending order.
        """
        resource_path = path
        if node.is_dir and not path.endswith('/'):
            resource_path = path + '/'
        super().__init__(resource_path, environ=environ, is_collection=bool(node.is_dir))
        self.node = node
        self.environ = environ
        self.alist_config = alist_config
        self.sort_reverse = sort_reverse

    def get_display_name(self):
        """Return the node's name."""
        return self.node.name

    def is_collection(self):
        """Return True when the node is a directory."""
        # NOTE(review): the base constructor is also given is_collection=...;
        # presumably it stores that as an instance attribute which shadows
        # this method -- confirm against the wsgidav version in use.
        return self.node.is_dir

    def get_content_type(self):
        """Return the MIME type: a directory marker or a generic binary type."""
        return 'httpd/unix-directory' if self.node.is_dir else 'application/octet-stream'

    def get_member_names(self):
        """Return child names: insertion order, or descending when sort_reverse is set."""
        names = list(self.node.children.keys())
        if self.sort_reverse:
            names.sort(reverse=True)
        return names

    def get_member(self, name):
        """Return the child resource called *name*, or None if there is none."""
        child = self.node.children.get(name)
        if child is None:
            return None
        child_path = f"{self.path.rstrip('/')}/{name}"
        return VirtualResource(child_path, child, self.environ, self.alist_config, sort_reverse=self.sort_reverse)

    def support_ranges(self):
        """Report range support for files only."""
        return not self.node.is_dir

    def support_etag(self):
        """Report ETag support for files only."""
        return not self.node.is_dir

    def get_content(self):
        """Always raise: this resource never serves content itself.

        Raises:
            ValueError: On every call.
        """
        raise ValueError("not implement")

    def get_content_length(self):
        """Return the size recorded on the node."""
        return self.node.size

    def get_property_names(self, is_allprop):
        """Return the property names advertised for this resource.

        Directories expose resource type and display name; files also
        expose their content length. *is_allprop* is not used.
        """
        # The node itself is tested here rather than ``self.is_collection``,
        # because that name is also a method on this class and a bound
        # method is always truthy.
        names = ["{DAV:}resourcetype", "{DAV:}displayname"]
        if not self.node.is_dir:
            names.append("{DAV:}getcontentlength")
        return names
def get_htdocs_path():
    """Return the absolute path of the ``htdocs`` directory.

    When ``sys.frozen`` is set (packaged build, per the original comment)
    the directory is looked up under the current working directory;
    otherwise it is looked up next to this source file. The chosen base
    path is printed.
    """
    if getattr(sys, 'frozen', False):
        base_path = os.getcwd()
        print(f"is frozen path:{base_path}")
    else:
        # __file__ may be relative (or a bare file name) depending on how the
        # script was started; make it absolute so the result never depends
        # on a later change of the working directory.
        base_path = os.path.dirname(os.path.abspath(__file__))
        print(f"is normal path:{base_path}")
    return os.path.join(base_path, "htdocs")
def _str_to_bool(value):
    """Interpret a command-line string as a boolean.

    The empty string and the usual "off" words are False; every other
    value is True, as it was when ``bool`` was used as the argparse type.
    """
    if isinstance(value, bool):
        return value
    return value.strip().lower() not in ("", "0", "false", "no", "off", "n", "f")


def main():
    """Parse the command line, build the virtual tree and run the WebDAV server."""
    parser = argparse.ArgumentParser(description='WebDAV路径模拟服务器')
    parser.add_argument('input', help='包含完整路径的配置文件')
    parser.add_argument('--port', type=int, default=5678, help='监听端口号')
    # Empty-string defaults: the provider compares these values with "" and
    # calls string methods on them, so None must not be passed through.
    parser.add_argument('--alist_url', type=str, default="", help='真实AList地址')
    parser.add_argument('--alist_token', type=str, default="", help='真实AList Token')
    parser.add_argument('--alist_prefix', type=str, default="", help='真实AList的路径前缀')
    parser.add_argument('--alist_config', type=str, default="", help='一组AList配置,单个配置内逗号分隔,配置间#号分隔')
    # type=bool would turn any non-empty string, including "False", into True.
    parser.add_argument('--reverse', type=_str_to_bool, default=False, help='是否对WebDAV内容逆序输出')
    args = parser.parse_args()

    with open(args.input, 'r', encoding='utf-8') as f:
        root = parse_paths(f)

    alist_config = {
        "api_url": args.alist_url,
        "token": args.alist_token,
        "prefix": args.alist_prefix,
        "config": args.alist_config,
    }
    config = {
        'host': '0.0.0.0',
        'port': args.port,
        'provider_mapping': {
            '/dav': VirtualFSProvider(root, alist_config, sort_reverse=args.reverse),
        },
        'verbose': 9,
        'http_authenticator': {"accept_basic": True},
        'auth': 'basic',
        # NOTE(review): the credentials below are hard-coded in the source;
        # consider loading them from configuration instead.
        'simple_dc': {
            'user_mapping': {
                '/dav': {
                    'guest': {
                        'password': 'guest_Api789',
                    }
                }
            }
        },
        "dir_browser": {
            "enable": True,  # turn on the directory browser
            "htdocs_path": get_htdocs_path(),  # explicitly supplied path
        },
    }
    print(config)
    app = WsgiDAVApp(config)
    init_logging(config)
    server = wsgi.Server(
        (config['host'], config['port']),
        app,
        server_name='Path-Based WebDAV Server'
    )
    print(f"Server running on http://{config['host']}:{config['port']}/dav")
    try:
        server.start()
    except KeyboardInterrupt:
        # Shut the server down cleanly instead of only leaving the loop.
        server.stop()
        print("\nServer stopped.")
# Start the server only when this file is executed as a script.
if __name__ == '__main__':
    main()