#!/usr/bin/env python3
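"""Minimal HTTP file-sharing tool built on aiohttp.

Server mode serves files from the current working directory and exposes
transfer statistics at /_stats:

    ./share.py server -b 0.0.0.0 -p 8080

Client mode reads a sha1sum-style hash list ("<hash>  <filename>" per line),
verifies local copies, and downloads whatever is missing or corrupt:

    sha1sum * > hashes.txt                        # on the serving side
    ./share.py client -u http://10.7.0.2:8080 hashes.txt

(`share.py` stands in for whatever name this script is saved under.)
"""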
import os
import argparse
import hashlib
import asyncio
import aiohttp
from aiohttp import web
from tqdm import tqdm
from datetime import datetime
from math import ceil

# Adaptive chunking: server reads start at CHUNK_SIZE and double after each
# successful write, clamped to [MIN_CHUNK_SIZE, MAX_BUFFER]; the client
# writes downloads in MAX_BUFFER chunks.
CHUNK_SIZE = 256 * 1024
MAX_BUFFER = 4 * 1024 * 1024
MIN_CHUNK_SIZE = 32 * 1024

def format_size(size_bytes):
    for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
        if size_bytes < 1024:
            return f"{size_bytes:.2f} {unit}"
        size_bytes /= 1024
    return f"{size_bytes:.2f} PB"

async def file_handler(request):
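    """Stream a file from the current working directory to the client.

    The '{filename}' route segment cannot contain '/', so requests are
    confined to files directly in the shared directory.
    """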
    file_path = os.path.join(os.getcwd(), request.match_info['filename'])

    if not os.path.exists(file_path):
        return web.Response(status=404, text="File not found")
    if not os.path.isfile(file_path):
        return web.Response(status=403, text="Invalid file type")

    stats = request.app['stats']
    stats['total_requests'] += 1
    file_size = os.path.getsize(file_path)
    stats['total_bytes'] += file_size

    headers = {
        "Content-Disposition": f'attachment; filename="{os.path.basename(file_path)}"',
        "Content-Length": str(file_size)
    }

    response = web.StreamResponse(
        status=200,
        headers=headers,
        reason='OK'
    )
    response.content_type = 'application/octet-stream'

    await response.prepare(request)

    try:
        start_time = datetime.now()
        loop = asyncio.get_running_loop()

        with open(file_path, 'rb') as f:
            chunk_size = CHUNK_SIZE
            while True:
                chunk = await loop.run_in_executor(None, f.read, chunk_size)
                if not chunk:
                    break
                await response.write(chunk)
                chunk_size = min(MAX_BUFFER, max(MIN_CHUNK_SIZE, chunk_size * 2))

        duration = (datetime.now() - start_time).total_seconds()
        stats['transfer_speeds'].append(file_size / max(duration, 0.001))

    except Exception:
        stats['transfer_errors'] += 1
        raise

    # Only finish the stream on success; calling write_eof() after a failed
    # write would raise a second error that masks the original one.
    await response.write_eof()

    return response

async def stats_handler(request):
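    """Report request totals, bytes served, and mean transfer speed as JSON."""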
    stats = request.app['stats']
    speeds = stats['transfer_speeds']
    avg_speed = sum(speeds) / len(speeds) if speeds else 0
    return web.json_response({
        "total_requests": stats['total_requests'],
        "total_bytes": stats['total_bytes'],
        "average_speed": f"{avg_speed/1e6:.2f} MB/s",
        "transfer_errors": stats['transfer_errors']
    })

def run_server(host='0.0.0.0', port=8080):
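    """Build the aiohttp application and serve the working directory."""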
    app = web.Application()
    app['stats'] = {
        'total_requests': 0,
        'total_bytes': 0,
        'transfer_speeds': [],
        'transfer_errors': 0
    }

    # Register the fixed /_stats route before the catch-all route; aiohttp
    # matches routes in registration order, so the reverse order would let
    # '/{filename}' shadow '/_stats'.
    app.router.add_get('/_stats', stats_handler)
    app.router.add_get('/{filename}', file_handler)

    print(f"šŸ“” Server running at {host}:{port}")
    print(f"šŸ“‚ Shared directory: {os.getcwd()}")
    web.run_app(app, host=host, port=port)

async def calculate_sha1(file_path):
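    """Compute a file's SHA-1 in a worker thread, showing a progress bar.

    Reading and hashing run inside the default executor so the event loop
    is never blocked on disk I/O.
    """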
    loop = asyncio.get_running_loop()
    sha1 = hashlib.sha1()
    try:
        file_size = os.path.getsize(file_path)
        with tqdm(total=file_size, unit='B', unit_scale=True,
                  desc=f"šŸ” {os.path.basename(file_path)[:15]}", leave=False,
                  bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}") as pbar:
            
            def sync_read():
                with open(file_path, 'rb') as f:
                    while True:
                        chunk = f.read(CHUNK_SIZE)
                        if not chunk:
                            return
                        sha1.update(chunk)
                        pbar.update(len(chunk))
            
            await loop.run_in_executor(None, sync_read)
            
        return sha1.hexdigest()
    except Exception as e:
        # Chain the original exception so the real cause stays visible.
        raise RuntimeError(f"SHA1 calculation failed: {e}") from e

async def check_local_files(hash_list):
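    """Split the hash list into locally valid files and files to download.

    Files that exist but fail SHA-1 verification are deleted so the
    download step can replace them.
    """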
    valid_files = []
    download_list = []

    with tqdm(total=len(hash_list), desc="šŸ“‚ Checking local files",
              unit="file", bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}") as pbar:
        for expected_hash, filename in hash_list:
            if not os.path.exists(filename):
                download_list.append((expected_hash, filename))
                pbar.update(1)
                continue

            try:
                current_hash = await calculate_sha1(filename)
                if current_hash == expected_hash:
                    valid_files.append(filename)
                else:
                    download_list.append((expected_hash, filename))
                    await asyncio.to_thread(os.remove, filename)
                    tqdm.write(f"ā™»ļø {filename[:20]}... - Invalid hash, deleted")
            except Exception as e:
                download_list.append((expected_hash, filename))
                tqdm.write(f"āš ļø {filename[:20]}... - Error: {str(e)}")
            pbar.update(1)

    return download_list, valid_files

async def download_file(session, base_url, expected_hash, filename, semaphore, timeout, retries=3):
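    """Fetch one file with retries, exponential backoff, and verification.

    Data is written to '<filename>.downloading' and only renamed into place
    once its SHA-1 matches, so a corrupt or interrupted transfer never
    clobbers the target name.
    """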
    for attempt in range(retries):
        temp_filename = f"{filename}.downloading"
        downloaded = 0
        start_time = datetime.now()

        try:
            async with semaphore:
                async with session.get(
                    f"{base_url}/{filename}",
                    timeout=aiohttp.ClientTimeout(total=timeout)
                ) as response:
                    if response.status != 200:
                        raise RuntimeError(f"HTTP {response.status}")

                    total_size = int(response.headers.get('Content-Length', 0))

                    with tqdm(total=total_size, unit='B', unit_scale=True,
                              desc=f"ā¬‡ļø {filename[:15]}", leave=False,
                              bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{rate_fmt}]") as pbar:
                        loop = asyncio.get_running_loop()
                        with open(temp_filename, 'wb') as f:
                            async for chunk in response.content.iter_chunked(MAX_BUFFER):
                                await loop.run_in_executor(None, f.write, chunk)
                                downloaded += len(chunk)
                                pbar.update(len(chunk))

            # Verify before renaming into place so a corrupt transfer never
            # replaces the target file.
            current_hash = await calculate_sha1(temp_filename)
            if current_hash != expected_hash:
                raise ValueError(f"Hash mismatch ({current_hash[:8]}..)")
            await asyncio.to_thread(os.rename, temp_filename, filename)

            duration = (datetime.now() - start_time).total_seconds()
            speed = downloaded / max(duration, 0.001)
            tqdm.write(f"āœ… {filename[:20]}... - {format_size(speed)}/s (Attempt {attempt+1})")
            return True
        except Exception as e:
            if os.path.exists(temp_filename):
                await asyncio.to_thread(os.remove, temp_filename)
            tqdm.write(f"āŒ {filename[:15]}... - {e} (Attempt {attempt+1})")
            # Back off outside the semaphore so a failing file does not hold
            # a download slot while it waits to retry.
            await asyncio.sleep(2 ** attempt)
    return False

async def main_client(base_url, hash_file, parallel=None, timeout=3600):
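    """Verify local files against the hash list, then download the rest."""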
    with open(hash_file) as f:
        # Expect sha1sum-style lines ("<hash>  <filename>"); skip anything
        # that does not split into exactly two fields.
        hash_list = []
        for line in f:
            parts = line.strip().split(maxsplit=1)
            if len(parts) == 2:
                hash_list.append((parts[0], parts[1]))

    download_list, valid_files = await check_local_files(hash_list)

    print(f"\nāœ… Valid files: {len(valid_files)}")
    print(f"ā¬‡ļø Files to download: {len(download_list)}")

    if not download_list:
        print("\nšŸŽ‰ All files are up to date!")
        return

    if parallel is None:
        parallel = min(ceil(len(download_list)/2), 20)
    parallel = max(1, min(parallel, 50))

    semaphore = asyncio.Semaphore(parallel)
    async with aiohttp.ClientSession() as session:
        tasks = [
            download_file(session, base_url, h, f, semaphore, timeout)
            for h, f in download_list
        ]

        results = []
        with tqdm(total=len(tasks), desc="šŸš€ Download Progress",
                  unit="file", bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}") as pbar:
            for future in asyncio.as_completed(tasks):
                result = await future
                results.append(result)
                pbar.update(1)

        success_count = sum(results)
        print(f"\nSuccess: {success_count}/{len(results)} ({success_count/len(results):.1%})")

def parse_args():
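    """Build the CLI: 'server' shares the current directory; 'client' fetches files from a hash list."""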
    parser = argparse.ArgumentParser(description='File Sharing System')
    subparsers = parser.add_subparsers(dest='mode', required=True)

    server_parser = subparsers.add_parser('server', aliases=['s'], help='Start in server mode')
    server_parser.add_argument('-b', '--bind', default='0.0.0.0', help='IP address to bind (default: 0.0.0.0)')
    server_parser.add_argument('-p', '--port', type=int, default=8080, help='Server port (default: 8080)')

    client_parser = subparsers.add_parser('client', aliases=['c'], help='Start in client mode')
    client_parser.add_argument('-u', '--url', required=True, help='Server URL (e.g.: http://10.7.0.2:8080)')
    client_parser.add_argument('-p', '--parallel', type=int, default=None, 
                              help='Parallel downloads (default: auto)')
    client_parser.add_argument('-t', '--timeout', type=int, default=3600,
                              help='Download timeout in seconds (default: 3600)')
    client_parser.add_argument('hash_list', help='File containing hash list')

    return parser.parse_args()

if __name__ == '__main__':
    args = parse_args()

    if args.mode in ['server', 's']:
        run_server(host=args.bind, port=args.port)
    elif args.mode in ['client', 'c']:
        asyncio.run(main_client(
            base_url=args.url,
            hash_file=args.hash_list,
            parallel=args.parallel,
            timeout=args.timeout
        ))