Files
tv/webdav_simulator.py

363 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
from io import BytesIO
from datetime import datetime
from wsgidav.dav_provider import DAVProvider,_DAVResource
from wsgidav.wsgidav_app import WsgiDAVApp
from wsgidav.util import init_logging
from wsgidav.mw.base_mw import BaseMiddleware
from wsgidav.dir_browser import WsgiDavDirBrowser
from wsgidav.dav_error import HTTP_MEDIATYPE_NOT_SUPPORTED, HTTP_OK, DAVError
from cheroot import wsgi
try:
    from wsgidav.util import send_redirect
except ImportError:
    # Compatibility shim for wsgidav releases that do not provide
    # ``wsgidav.util.send_redirect``.
    from http import HTTPStatus

    def send_redirect(environ, start_response, new_url, status="307"):
        """Start a redirect response that points the client at *new_url*.

        Args:
            environ: WSGI environment (unused; kept for signature compatibility).
            start_response: WSGI ``start_response`` callable.
            new_url: target URL, sent in the ``Location`` header.
            status: redirect status code as str or int; defaults to 307.
                (Previously this argument was ignored and 307 was always sent.)

        Returns:
            An empty body iterable; callers may return their own body instead.
        """
        status_code = HTTPStatus(int(status))
        start_response(f"{status_code.value} {status_code.phrase}", [("Location", new_url)])
        return []
#import yaml
import json
import os
import posixpath
import re
import sys
import time
import traceback
from functools import lru_cache

import requests
def convert_path(webdav_path: str, mount: str = "/dav") -> str:
    """Map a WebDAV request path under *mount* to an absolute AList path.

    The *mount* prefix is removed only when it forms a whole leading path
    segment, and the result always starts with exactly one slash (the old
    implementation produced ``//x`` for ``/dav/x`` and mangled ``/david``).

        >>> convert_path("/dav/movies/a.mkv")
        '/movies/a.mkv'
        >>> convert_path("/dav")
        '/'
        >>> convert_path("/david/x")
        '/david/x'

    Args:
        webdav_path: request path, e.g. ``/dav/dir/file``.
        mount: share prefix to strip; an empty value strips nothing.

    Returns:
        The path with the mount removed and a single leading ``/``.
    """
    mount = mount.rstrip("/")
    if mount and (webdav_path == mount or webdav_path.startswith(mount + "/")):
        webdav_path = webdav_path[len(mount):]
    return "/" + webdav_path.lstrip("/")
class Node:
    """One entry (file or directory) of the in-memory virtual file tree."""

    def __init__(self, name, is_dir=False, size=0):
        """
        Args:
            name: this entry's own path segment (not the full path).
            is_dir: True for directories.
            size: size in bytes reported for files; 0 when unknown.
        """
        self.name = name
        self.is_dir = is_dir
        # Child name -> Node; populated for directories by parse_paths.
        self.children = {}
        # File bytes; never filled in this module (contents are not served).
        self.content = b""
        # Creation time of this Node object.
        self.last_modified = datetime.now()
        self.size = size


def _ensure_dirs(root, names):
    """Walk (creating as needed) the directory chain *names* under *root*; return the last node."""
    current = root
    for name in names:
        child = current.children.get(name)
        if child is None:
            child = current.children[name] = Node(name, is_dir=True)
        # NOTE(review): if an earlier line declared *name* as a file, the walk
        # descends into that file node unchanged (same as before).
        current = child
    return current


def parse_paths(text_lines):
    """Build the virtual directory tree from a listing of full paths.

    Each non-blank line holds a path, optionally followed by a TAB and the
    file size in bytes. A path ending in ``/`` denotes a directory; any other
    path is a file, and its missing parent directories are created. Empty
    path segments (``a//b``) are ignored. The first occurrence of a file
    wins; later duplicates do not change its size.

    Args:
        text_lines: iterable of text lines (e.g. ``file.readlines()``).

    Returns:
        The root ``Node`` named ``/``.
    """
    root = Node('/', is_dir=True)
    for raw_line in text_lines:
        raw_line = raw_line.strip()
        if not raw_line:
            continue
        fields = raw_line.split('\t')
        path = fields[0]
        size = 0
        if len(fields) > 1:
            try:
                size = int(fields[1])
            except ValueError:
                # Keep the entry with size 0, but report the bad size field.
                print("error line:" + path)
        is_dir = path.endswith('/')
        parts = [part for part in path.strip('/').split('/') if part]
        if not parts:
            # The path consists only of slashes, i.e. the root itself.
            continue
        if is_dir:
            _ensure_dirs(root, parts)
        else:
            parent = _ensure_dirs(root, parts[:-1])
            file_name = parts[-1]
            if file_name not in parent.children:
                parent.children[file_name] = Node(file_name, is_dir=False, size=size)
    return root
class VirtualFSProvider(DAVProvider):
    """WebDAV provider that serves the in-memory path tree and redirects file
    downloads to one or more AList servers.

    Listings come from ``root_node``. A GET for a file is answered with a
    redirect to the ``raw_url`` that an AList server reports for that path.
    """

    # Seconds to wait for each AList HTTP call before giving up on it.
    request_timeout = 15

    _USER_AGENT = (
        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 "
        "(KHTML, like Gecko) Chrome/134.0.0.0 Safari/537.36"
    )

    def __init__(self, root_node, alist_config, sort_reverse=False):
        """
        Args:
            root_node: root ``Node`` of the virtual tree.
            alist_config: dict with keys ``api_url``, ``token``, ``prefix`` and
                ``config``. A non-empty ``config`` takes precedence and is parsed
                as ``url,token,prefix`` entries separated by ``#``. Otherwise the
                single ``api_url``/``token``/``prefix`` triple is used, but only
                if ``api_url`` is set.
            sort_reverse: list directory members in reverse-sorted order.
        """
        super().__init__()
        self.root_node = root_node
        self.alist_config = alist_config
        self.sort_reverse = sort_reverse
        self.serverlist = []
        # Matches share-style root folder names ("我的...分享").
        self.sharedir_re = re.compile(r'我的.*分享')
        multi = alist_config.get("config") or ""
        if multi == "":
            api_url = alist_config.get("api_url")
            # api_url is None when --alist_url was not given; skip instead of crashing.
            if api_url:
                self.serverlist.append(
                    self._make_server(api_url, alist_config.get("token"), alist_config.get("prefix")))
        else:
            for entry in multi.split("#"):
                fields = entry.split(",") + ["", ""]  # pad missing token/prefix
                if fields[0].strip():
                    self.serverlist.append(self._make_server(fields[0], fields[1], fields[2]))

    @staticmethod
    def _make_server(api_url, token, prefix):
        """Return one normalized server entry (no trailing slash, no stray spaces)."""
        return {
            "api_url": api_url.strip().rstrip("/"),
            "token": (token or "").strip(),
            "prefix": (prefix or "").strip(),
        }

    @staticmethod
    def _redact(mapping):
        """Return a copy of *mapping* with credential values masked, for logging."""
        return {key: ("***" if key in ("token", "Authorization") and value else value)
                for key, value in mapping.items()}

    @staticmethod
    def _join_prefix(prefix, path):
        """Join an AList path *prefix* and *path* with POSIX separators.

        ``os.path.join(prefix, "/x")`` dropped the prefix because ``/x`` is
        absolute (and would use backslashes on Windows); here it is always kept.
        """
        return posixpath.join("/", prefix.strip("/"), path.lstrip("/"))

    def _find_node(self, path):
        """Return the tree node at slash-separated *path*, or None if absent."""
        node = self.root_node
        for part in path.strip('/').split('/'):
            if not part:
                continue
            node = node.children.get(part)
            if node is None:
                return None
        return node

    def get_resource_inst(self, path, environ):
        """Return a ``VirtualResource`` for *path*, or None if it is not in the tree."""
        node = self._find_node(path)
        if node is None:
            return None
        return VirtualResource(path, node, environ, self.alist_config, sort_reverse=self.sort_reverse)

    def _headers(self, server):
        """HTTP headers for requests to *server* (token sent only when set)."""
        headers = {"User-Agent": self._USER_AGENT}
        if server["token"]:
            headers["Authorization"] = server["token"]
        return headers

    def _refresh_rootdirmap(self, server, headers):
        """List the server root and record, in ``server["rootdirmap"]``, the real
        root entry name for each share-folder pattern match.

        Failures are logged and leave the existing map untouched, so one
        unreachable server no longer aborts the whole lookup.
        """
        rootdirmap = server.setdefault("rootdirmap", {})
        params = {"path": "/"}
        try:
            resp = requests.get(server["api_url"] + "/api/fs/list", headers=headers,
                                params=params, timeout=self.request_timeout)
            resp.raise_for_status()
            body = resp.content.decode('utf-8')
            print(f"detect rootdir alist_api:{server['api_url']}, headers:{self._redact(headers)}, params:{params}, respcode:{resp.status_code}, respbody:{body[:4096]}")
            entries = json.loads(body)['data']['content'] or []
            # Every root entry is considered, files included.
            for entry in entries:
                name = entry["name"]
                match = self.sharedir_re.findall(name)
                if match:
                    rootdirmap[match[0]] = name
        except Exception:
            traceback.print_exc()

    def _candidate_paths(self, server, path):
        """Return the AList paths to try for WebDAV *path* on *server*.

        If the first segment looks like a share folder, it is replaced by the
        server's actual root entry of that kind. When no such entry is known,
        nothing is tried on this server.
        """
        segs = path.split('/')
        match = self.sharedir_re.findall(segs[1]) if len(segs) > 1 else []
        if not match:
            return [path]
        real_root = server["rootdirmap"].get(match[0])
        if real_root is None:
            return []
        return [posixpath.join("/", real_root, *segs[2:])]

    def _query_raw_url(self, server, headers, path):
        """Ask *server* for the ``raw_url`` of *path*; return "" when it has none or on error."""
        params = {"path": self._join_prefix(server["prefix"], path)}
        try:
            resp = requests.get(server["api_url"] + "/api/fs/get", headers=headers,
                                params=params, timeout=self.request_timeout)
            resp.raise_for_status()
            body = resp.content.decode('utf-8')
            print(f"query alist_api:{server['api_url']}, headers:{self._redact(headers)}, params:{params}, respcode:{resp.status_code}, respbody:{body[:4096]}")
            return json.loads(body)['data']['raw_url'] or ""
        except Exception:
            traceback.print_exc()
            time.sleep(1)  # brief pause before trying the next candidate/server
            return ""

    @lru_cache(maxsize=1024, typed=True)
    def _lookup_redirect_url(self, path):
        """Return the first non-empty ``raw_url`` for *path* across servers.

        Raises:
            LookupError: when no server yields one. Exceptions are not cached
                by ``lru_cache``, so failed paths are retried on later requests.
        """
        # NOTE(review): successful URLs stay cached for the process lifetime;
        # if the storage behind AList issues expiring raw_urls, consider a TTL.
        for server in self.serverlist:
            print("server:", self._redact(server))
            headers = self._headers(server)
            self._refresh_rootdirmap(server, headers)
            for candidate in self._candidate_paths(server, path):
                raw_url = self._query_raw_url(server, headers, candidate)
                if raw_url:
                    return raw_url
            time.sleep(1)
        raise LookupError(path)

    def get_redirect_url(self, path):
        """Return an AList ``raw_url`` for *path*, or "" if no server provides one.

        Only successful lookups are memoized (up to 1024 paths).
        """
        try:
            return self._lookup_redirect_url(path)
        except LookupError:
            return ""

    def custom_request_handler(self, environ, start_response, default_handler):
        """Redirect file downloads to AList; delegate everything else.

        A GET whose path does not end in "/" and has at least two segments
        (``/dir/file``) is redirected to the AList ``raw_url`` when one is
        found. Paths that resolve to a directory in the local tree skip the
        AList lookup. All other requests, and GETs with no redirect target,
        go to ``default_handler(environ, start_response)``.
        """
        # NOTE(review): PATH_INFO is assumed to be relative to the /dav share
        # (the same form get_resource_inst receives) — confirm per wsgidav version.
        method = environ["REQUEST_METHOD"]
        path = environ["PATH_INFO"]
        if method == "GET" and not path.endswith('/') and len(path.split('/')) > 2:
            node = self._find_node(path)
            if node is None or not node.is_dir:
                redirect_url = self.get_redirect_url(path)
                if redirect_url != "":
                    print(f"redirect_urL:{redirect_url}")
                    send_redirect(environ, start_response, redirect_url)
                    return [b"Redirecting..."]
        return default_handler(environ, start_response)
class VirtualResource(_DAVResource):
    """WebDAV resource view of one ``Node`` of the virtual tree.

    Directory nodes become collections whose path ends in "/". File bodies
    are not available here (``get_content`` raises); only metadata such as
    name, size and member listing is served.
    """

    def __init__(self, path, node, environ, alist_config, sort_reverse=False):
        """
        Args:
            path: resource path; "/" is appended for directories lacking it.
            node: the ``Node`` this resource represents.
            environ: WSGI environment of the current request.
            alist_config: AList settings, passed on to child resources.
            sort_reverse: list members in reverse-sorted order.
        """
        if node.is_dir and not path.endswith('/'):
            path = path + '/'
        super().__init__(path, environ=environ, is_collection=bool(node.is_dir))
        self.node = node
        self.environ = environ
        self.alist_config = alist_config
        self.sort_reverse = sort_reverse

    def get_display_name(self):
        return self.node.name

    def is_collection(self):
        # NOTE(review): the base constructor receives ``is_collection=``; if it
        # stores that as an instance attribute, this method is shadowed on
        # instances. Kept for backward compatibility; code in this class
        # reads ``self.node.is_dir`` directly instead.
        return self.node.is_dir

    def get_content_type(self):
        return 'httpd/unix-directory' if self.node.is_dir else 'application/octet-stream'

    def get_member_names(self):
        """Child names in tree order, or reverse-sorted when ``sort_reverse`` is set."""
        if not self.sort_reverse:
            return list(self.node.children.keys())
        return sorted(self.node.children.keys(), reverse=self.sort_reverse)

    def get_member(self, name):
        """Return the child resource *name*, or None if it does not exist."""
        child = self.node.children.get(name)
        if not child:
            return None
        child_path = f"{self.path.rstrip('/')}/{name}"
        return VirtualResource(child_path, child, self.environ, self.alist_config, sort_reverse=self.sort_reverse)

    def support_ranges(self):
        return not self.node.is_dir

    def support_etag(self):
        return not self.node.is_dir

    def get_content(self):
        # File bodies are never served from this process.
        raise ValueError("not implement")

    def get_content_length(self):
        return self.node.size

    def get_property_names(self, is_allprop):
        """Return the live property names advertised for this resource.

        ``{DAV:}getcontentlength`` is listed for files only. The check uses
        the node directly, so it does not depend on how ``is_collection`` is
        resolved on the instance.
        """
        names = ["{DAV:}resourcetype", "{DAV:}displayname"]
        if not self.node.is_dir:
            names.append("{DAV:}getcontentlength")
        return names
def get_htdocs_path():
    """Return the absolute path of the ``htdocs`` directory for the WsgiDAV
    directory browser.

    In a frozen build (``sys.frozen`` set), the directory is looked up under
    the current working directory. Otherwise it sits next to this source file.
    """
    if getattr(sys, 'frozen', False):
        # Frozen build: resolve against the working directory, not the
        # packager's temporary unpack dir (sys._MEIPASS) or the executable dir.
        base_path = os.getcwd()
        print(f"is frozen path:{base_path}")
    else:
        # abspath guards against a relative or empty __file__ when the script
        # is launched from its own directory.
        base_path = os.path.dirname(os.path.abspath(__file__))
        print(f"is normal path:{base_path}")
    return os.path.join(base_path, "htdocs")
def _str2bool(value):
    """Convert a command-line string such as "true", "0" or "no" to a bool.

    Used for ``--reverse``; ``type=bool`` would turn every non-empty string,
    including "False", into True.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognized boolean.
    """
    text = str(value).strip().lower()
    if text in ("1", "true", "t", "yes", "y", "on"):
        return True
    if text in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"invalid boolean value: {value!r}")


def main(argv=None):
    """Parse arguments, build the virtual tree and run the WebDAV server.

    Args:
        argv: optional argument list; defaults to ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description='WebDAV路径模拟服务器')
    parser.add_argument('input', help='包含完整路径的配置文件')
    parser.add_argument('--port', type=int, default=5678, help='监听端口号')
    parser.add_argument('--alist_url', type=str, help='真实AList地址')
    parser.add_argument('--alist_token', type=str, help='真实AList Token')
    parser.add_argument('--alist_prefix', type=str, default="", help='真实AList的路径前缀')
    parser.add_argument('--alist_config', type=str, default="", help='一组AList配置,单个配置内逗号分隔,配置间#号分隔')
    parser.add_argument('--reverse', type=_str2bool, default=False, help='是否对WebDAV内容逆序输出')
    # Credentials were hard-coded; the defaults keep the previous login working.
    parser.add_argument('--dav_user', type=str, default='guest', help='WebDAV登录用户名')
    parser.add_argument('--dav_password', type=str, default='guest_Api789', help='WebDAV登录密码')
    args = parser.parse_args(argv)

    with open(args.input, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    root = parse_paths(lines)

    alist_config = {
        "api_url": args.alist_url,
        "token": args.alist_token,
        "prefix": args.alist_prefix,
        "config": args.alist_config,
    }
    config = {
        'host': '0.0.0.0',
        'port': args.port,
        'provider_mapping': {
            '/dav': VirtualFSProvider(root, alist_config, sort_reverse=args.reverse),
        },
        'verbose': 9,
        'http_authenticator': {"accept_basic": True},
        'auth': 'basic',
        'simple_dc': {
            'user_mapping': {
                '/dav': {
                    args.dav_user: {
                        'password': args.dav_password,
                    }
                }
            }
        },
        "dir_browser": {
            "enable": True,  # enable the built-in directory browser
            "htdocs_path": get_htdocs_path(),  # explicit absolute path
        },
    }
    # Log the configuration without the credential table.
    print({**config, 'simple_dc': '<redacted>'})
    app = WsgiDAVApp(config)
    init_logging(config)
    server = wsgi.Server(
        (config['host'], config['port']),
        app,
        server_name='Path-Based WebDAV Server'
    )
    print(f"Server running on http://{config['host']}:{config['port']}/dav")
    try:
        server.start()
    except KeyboardInterrupt:
        # Shut down worker threads and the listening socket cleanly.
        server.stop()
        print("\nServer stopped.")


if __name__ == '__main__':
    main()