Add biomedical multimodal dataset preparation tools for Gemma fine-tuning (#248)
Status: Open — wants to merge 1 commit into base `main`
**BIOMEDICAL_DATASET_README.md** (147 additions, 0 deletions)
# Biomedical Multimodal Dataset Preparation for Gemma Fine-tuning

This repository contains tools and scripts for preparing biomedical multimodal datasets (text + images/tables/formulas) for fine-tuning Gemma models.

## Overview

The process involves three main steps:

1. **Preprocessing PDFs**: Extract text, images, tables, and formulas from biomedical PDFs
2. **Creating a Dataset**: Structure the extracted content into a format suitable for Gemma fine-tuning
3. **Fine-tuning Gemma**: Fine-tune a Gemma model on the prepared dataset

## Requirements

Install the required dependencies:

```bash
pip install -r requirements.txt
```

## 1. Preprocessing PDFs

The `preprocess_pdfs.py` script extracts content from PDFs and converts it to appropriate formats:

- **Text**: Extracted as Markdown
- **Images**: Extracted as PNG/JPEG files
- **Tables**: Converted to Markdown format
- **Formulas**: Converted to LaTeX format

### Usage

```bash
python preprocess_pdfs.py --input_dir /path/to/pdfs --output_dir /path/to/preprocessed
```

### Output Structure

```
/path/to/preprocessed/
├── document1/
│ ├── document1.md
│ ├── images/
│ │ ├── img_1_1.png
│ │ ├── img_1_2.png
│ │ └── ...
│ ├── tables/
│ │ ├── table_1.md
│ │ ├── table_2.md
│ │ └── ...
│ └── formulas/
│ ├── formula_1_1.tex
│ ├── formula_1_1.png
│ └── ...
└── document2/
└── ...
```
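The per-document markdown is presumably organized page by page — `create_dataset.py` later splits its content on `## Page` headings and matches image IDs that appear in each section. A minimal sketch of how one page's markdown could be assembled (the helper name is illustrative, not part of the scripts above):

```python
def page_to_markdown(page_num, text, image_ids):
    """Assemble one page's markdown: a `## Page N` heading, the page text,
    then one `<start_of_image>` tag per image extracted from that page."""
    parts = [f"## Page {page_num}", text.strip()]
    for image_id in image_ids:
        parts.append(f"<start_of_image> {image_id}")
    return "\n\n".join(parts)
```

Keeping the image ID next to its `<start_of_image>` tag is what lets the dataset step associate each section with its image files by simple substring matching.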

## 2. Creating a Dataset

The `create_dataset.py` script structures the preprocessed content into a format suitable for Gemma fine-tuning:

### Usage

```bash
python create_dataset.py --input_dir /path/to/preprocessed --output_dir /path/to/dataset
```

### Output Structure

```
/path/to/dataset/
├── train/
│ └── data.json
├── validation/
│ └── data.json
└── images/
├── img_1_1.png
├── img_1_2.png
└── ...
```

### Dataset Format

The dataset is structured as a JSON file with the following format:

```json
[
{
"input_text": "Text with <start_of_image> tags and LaTeX formulas",
"output_text": "Expected output for instruction tuning",
"images": ["path/to/image1.png", "path/to/image2.png"]
},
...
]
```
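A split written by `create_dataset.py` can be loaded and sanity-checked with a few lines of standard-library Python (the function name is illustrative):

```python
import json

def load_examples(path):
    """Load one split's data.json and sanity-check each record's shape."""
    with open(path) as f:
        examples = json.load(f)
    for ex in examples:
        # Every record must carry the three fields described above
        assert set(ex) >= {"input_text", "output_text", "images"}, "missing keys"
        assert isinstance(ex["images"], list), "images must be a list of paths"
    return examples
```

Note that the image paths are relative to the dataset root, so resolve them against the dataset directory, not against the JSON file's own location.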

## 3. Fine-tuning Gemma

The `finetune_gemma.py` script fine-tunes a Gemma model on the prepared dataset:

### Usage

```bash
python finetune_gemma.py --dataset_dir /path/to/dataset --output_dir /path/to/model
```

## Best Practices

### Handling Images

- Use `<start_of_image>` tags in the text to indicate where images should appear
- Ensure images are properly sized (800x800 pixels is recommended)
- Include descriptive alt text or captions for images
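A minimal resizing sketch under the 800x800 guideline (Pillow is assumed to be available; the function names are illustrative):

```python
def fit_within(size, bound=(800, 800)):
    """Return (w, h) scaled to fit inside `bound`, preserving aspect
    ratio; never upscales images that already fit."""
    w, h = size
    bw, bh = bound
    scale = min(bw / w, bh / h, 1.0)
    return (round(w * scale), round(h * scale))

def resize_for_gemma(src_path, dst_path, bound=(800, 800)):
    """Downscale an extracted image so neither side exceeds `bound`."""
    from PIL import Image  # Pillow, assumed to be in requirements.txt
    with Image.open(src_path) as img:
        img.resize(fit_within(img.size, bound), Image.LANCZOS).save(dst_path)
```

Preserving aspect ratio rather than forcing a square avoids distorting figures such as gel images or plots.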

### Handling Tables

- Use Markdown table format to preserve structure
- Keep tables simple and well-formatted
- Ensure column headers are clear and descriptive
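As a sketch of the target Markdown shape, a simple rows-to-table helper (illustrative, not part of the scripts above):

```python
def rows_to_markdown(header, rows):
    """Render a header row plus data rows as a simple Markdown table."""
    lines = ["| " + " | ".join(header) + " |",
             "| " + " | ".join("---" for _ in header) + " |"]
    for row in rows:
        lines.append("| " + " | ".join(str(c) for c in row) + " |")
    return "\n".join(lines)
```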

### Handling Formulas

- Use LaTeX format for mathematical formulas
- Enclose inline formulas with single dollar signs: `$formula$`
- Enclose block formulas with double dollar signs: `$$formula$$`

### Dataset Preparation

- Balance the dataset with a variety of content types
- Ensure high-quality text and image content
- Provide clear instruction-response pairs for fine-tuning
- Split the dataset into training and validation sets (90/10 split is recommended)
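The 90/10 split amounts to a shuffle followed by a slice, as in `create_dataset.py`; the seeded variant below (the `seed` parameter is an addition, for reproducibility) shows the idea:

```python
import random

def split_examples(examples, ratio=0.9, seed=0):
    """Shuffle deterministically, then split into train/validation lists."""
    rng = random.Random(seed)
    shuffled = examples[:]  # copy so the caller's list is untouched
    rng.shuffle(shuffled)
    cut = int(len(shuffled) * ratio)
    return shuffled[:cut], shuffled[cut:]
```

Shuffling before splitting prevents the validation set from being dominated by whichever documents happen to be processed last.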

## Limitations and Considerations

- Complex tables may not be perfectly preserved in Markdown format
- Formula extraction accuracy depends on the quality of the PDF
- Very large images may need to be resized or split
- Some special characters in LaTeX formulas may require escaping

## Requirements File

A `requirements.txt` file is included with all necessary dependencies.

## License

This project is licensed under the Apache License 2.0 - see the LICENSE file for details.
**create_dataset.py** (145 additions, 0 deletions)
#!/usr/bin/env python3
"""
Create a dataset for Gemma fine-tuning from preprocessed biomedical PDFs.

This script creates a dataset structure suitable for Gemma fine-tuning from
preprocessed biomedical PDFs.

Usage:
python create_dataset.py --input_dir /path/to/preprocessed --output_dir /path/to/dataset
"""

import argparse
import json
import os
import random
import shutil
from typing import Any, Dict, List

def create_training_example(markdown_path: str, image_map: Dict[str, str]) -> List[Dict[str, Any]]:
    """
    Create training examples from a preprocessed markdown file.

    Args:
        markdown_path: Path to the markdown file
        image_map: Dictionary mapping image IDs to file paths

    Returns:
        List of training example dictionaries
    """
    # Read markdown content
    with open(markdown_path, "r") as md_file:
        markdown_content = md_file.read()

    # Split content into per-page sections (simplified approach)
    sections = markdown_content.split("## Page")

    # Create examples
    examples = []

    for section in sections:
        if not section.strip():
            continue

        # Check if the section contains image references
        has_images = "<start_of_image>" in section

        # Get image paths for this section
        section_images = []
        if has_images:
            # Match image IDs mentioned in the section text.
            # Note: substring matching can over-match IDs that share a
            # prefix (e.g. "img_1_1" also matches inside "img_1_10").
            for image_id in image_map:
                if image_id in section:
                    section_images.append(image_map[image_id])

        # Create example
        example = {
            "input_text": section.strip(),
            "output_text": "",  # Filled with the expected output for instruction tuning
            "images": section_images,
        }

        examples.append(example)

    return examples

def create_dataset(input_dir: str, output_dir: str, split_ratio: float = 0.9) -> None:
    """
    Create a dataset for Gemma fine-tuning.

    Args:
        input_dir: Directory containing preprocessed content
        output_dir: Directory to save the dataset
        split_ratio: Train/validation split ratio
    """
    # Create output directories
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(os.path.join(output_dir, "train"), exist_ok=True)
    os.makedirs(os.path.join(output_dir, "validation"), exist_ok=True)
    os.makedirs(os.path.join(output_dir, "images"), exist_ok=True)

    # Get list of preprocessed directories
    preprocessed_dirs = [d for d in os.listdir(input_dir) if os.path.isdir(os.path.join(input_dir, d))]

    # Process each directory
    all_examples = []

    for dir_name in preprocessed_dirs:
        dir_path = os.path.join(input_dir, dir_name)

        # Find markdown file
        markdown_files = [f for f in os.listdir(dir_path) if f.endswith(".md")]
        if not markdown_files:
            continue

        markdown_path = os.path.join(dir_path, markdown_files[0])

        # Build the image map (image ID -> file path)
        image_dir = os.path.join(dir_path, "images")
        if os.path.exists(image_dir):
            image_files = [f for f in os.listdir(image_dir) if f.endswith((".png", ".jpg", ".jpeg"))]
            image_map = {os.path.splitext(f)[0]: os.path.join(image_dir, f) for f in image_files}
        else:
            image_map = {}

        # Create examples
        examples = create_training_example(markdown_path, image_map)
        all_examples.extend(examples)

    # Shuffle examples
    random.shuffle(all_examples)

    # Split into train and validation sets
    split_idx = int(len(all_examples) * split_ratio)
    train_examples = all_examples[:split_idx]
    val_examples = all_examples[split_idx:]

    # Copy images to the dataset directory and update paths.
    # Basenames are prefixed with the source document name so that
    # identically named images from different documents do not overwrite
    # each other in the shared images/ directory.
    for example in all_examples:
        for i, img_path in enumerate(example["images"]):
            doc_name = os.path.basename(os.path.dirname(os.path.dirname(img_path)))
            img_filename = f"{doc_name}_{os.path.basename(img_path)}"
            new_img_path = os.path.join(output_dir, "images", img_filename)
            shutil.copy(img_path, new_img_path)
            example["images"][i] = os.path.relpath(new_img_path, output_dir)

    # Save train and validation sets
    with open(os.path.join(output_dir, "train", "data.json"), "w") as f:
        json.dump(train_examples, f, indent=2)

    with open(os.path.join(output_dir, "validation", "data.json"), "w") as f:
        json.dump(val_examples, f, indent=2)

    print(f"Created dataset with {len(train_examples)} training examples and {len(val_examples)} validation examples")

def main():
    parser = argparse.ArgumentParser(description="Create a dataset for Gemma fine-tuning")
    parser.add_argument("--input_dir", required=True, help="Directory containing preprocessed content")
    parser.add_argument("--output_dir", required=True, help="Directory to save the dataset")
    parser.add_argument("--split_ratio", type=float, default=0.9, help="Train/validation split ratio")
    args = parser.parse_args()

    create_dataset(args.input_dir, args.output_dir, args.split_ratio)

if __name__ == "__main__":
    main()