import argparse
import pandas as pd

def parse_attributes(attributes_str):
    """ 解析 GTF 文件中的属性字段 """
    attributes = {}
    for attr in attributes_str.split(';'):
        attr = attr.strip()
        if attr:
            parts = attr.split(' ', 1)
            if len(parts) == 2:
                key, value = parts
                attributes[key] = value.strip('"')
    return attributes

def read_gtf(gtf_file):
    """ 从 GTF 文件读取数据 """
    data = []
    with open(gtf_file, 'r') as file:
        for line in file:
            if line.startswith('#'):
                continue
            fields = line.strip().split('\t')
            attributes = parse_attributes(fields[8])
            data.append({
                'chrom': fields[0],
                'source': fields[1],
                'feature': fields[2],
                'start': int(fields[3]),
                'end': int(fields[4]),
                'score': fields[5],
                'strand': fields[6],
                'frame': fields[7],
                'attributes': fields[8],
                'transcript_id': attributes.get('transcript_id', ''),
                'gene_id': attributes.get('gene_id', ''),
                'line': line.strip()  # 保存整行
            })
    return pd.DataFrame(data)

def longest_transcripts(df):
    """ 找到每个基因的最长转录本 """
    return df[df['feature'] == 'transcript'].sort_values(['gene_id', 'end'], ascending=[True, False]).drop_duplicates('gene_id')

def to_bed6(df):
    """ 将数据转换为 BED6 格式 """
    df['score'] = "."
    df['name'] = df['gene_id']
    return df[['chrom', 'start', 'end', 'name', 'score', 'strand']]

def write_filtered_gtf(df, longest_transcripts, gtf_out_file):
    """ 将最长转录本的相关行和对应基因的行写入新的 GTF 文件 """
    longest_transcript_ids = set(longest_transcripts['transcript_id'])
    with open(gtf_out_file, 'w') as f:
        for _, row in df.iterrows():
            if row['feature'] == 'gene' or row['transcript_id'] in longest_transcript_ids:
                f.write(row['line'] + '\n')

def main():
    parser = argparse.ArgumentParser(description="从 GTF 文件提取最长转录本并保存为 BED6 和 GTF 格式")
    parser.add_argument("gtf_file", help="输入的 GTF 文件路径")
    parser.add_argument("bed_file", help="输出的 BED6 格式文件路径")
    parser.add_argument("gtf_out_file", help="输出的 GTF 格式文件路径")
    args = parser.parse_args()

    df = read_gtf(args.gtf_file)
    longest_df = longest_transcripts(df)
    bed6_df = to_bed6(longest_df)
    bed6_df.to_csv(args.bed_file, sep='\t', index=False, header=False)

    write_filtered_gtf(df, longest_df, args.gtf_out_file)

if __name__ == "__main__":
    main()